1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

add validator and do each step in 1 backwards pass

This commit is contained in:
Jack Morris
2020-05-05 14:28:48 -04:00
parent 1a5ff59190
commit bd57fa642f
6 changed files with 60 additions and 36 deletions

View File

@@ -1,6 +1,7 @@
import collections
import re
import textattack
from textattack.goal_functions import *
from .utils import get_logger
@@ -59,4 +60,16 @@ def validate_model_goal_function_compatibility(goal_function_class, model_class)
# Otherwise, this is an unknown modelperhaps user-provided, or we forgot to
# update the corresponding dictionary. Warn user and return.
logger.warn(f'Unknown if model {model} compatible with goal function {goal_function}.')
return True
return True
def validate_model_gradient_word_swap_compatibility(model):
"""
Determines if `model` is task-compatible with `GradientBasedWordSwap`.
We can only take the gradient with respect to an individual word if the
model uses a word-based tokenizer.
"""
if isinstance(model, textattack.models.helpers.LSTMForClassification):
return True
else:
raise ValueError(f'Cannot perform GradientBasedWordSwap on model {model}.')