mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
add validator and do each step in 1 backwards pass
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import collections
|
||||
import re
|
||||
|
||||
import textattack
|
||||
from textattack.goal_functions import *
|
||||
|
||||
from .utils import get_logger
|
||||
@@ -59,4 +60,16 @@ def validate_model_goal_function_compatibility(goal_function_class, model_class)
|
||||
# Otherwise, this is an unknown model–perhaps user-provided, or we forgot to
|
||||
# update the corresponding dictionary. Warn user and return.
|
||||
logger.warn(f'Unknown if model {model} compatible with goal function {goal_function}.')
|
||||
return True
|
||||
return True
|
||||
|
||||
def validate_model_gradient_word_swap_compatibility(model):
|
||||
"""
|
||||
Determines if `model` is task-compatible with `GradientBasedWordSwap`.
|
||||
|
||||
We can only take the gradient with respect to an individual word if the
|
||||
model uses a word-based tokenizer.
|
||||
"""
|
||||
if isinstance(model, textattack.models.helpers.LSTMForClassification):
|
||||
return True
|
||||
else:
|
||||
raise ValueError(f'Cannot perform GradientBasedWordSwap on model {model}.')
|
||||
Reference in New Issue
Block a user