mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
update compatibility check with current modules
This commit is contained in:
@@ -36,7 +36,7 @@ class GoalFunction(ABC):
|
||||
model_cache_size=2 ** 20,
|
||||
):
|
||||
validators.validate_model_goal_function_compatibility(
|
||||
self.__class__, model_wrapper.__class__
|
||||
self.__class__, model_wrapper.model.__class__
|
||||
)
|
||||
self.model = model_wrapper
|
||||
self.maximizable = maximizable
|
||||
|
||||
@@ -20,12 +20,12 @@ from . import logger
|
||||
# A list of goal functions and the corresponding available models.
|
||||
MODELS_BY_GOAL_FUNCTIONS = {
|
||||
(TargetedClassification, UntargetedClassification, InputReduction): [
|
||||
r"^textattack.models.lstm_for_classification.*",
|
||||
r"^textattack.models.helpers.lstm_for_classification.*",
|
||||
r"^textattack.models.helpers.word_cnn_for_classification.*",
|
||||
r"^transformers.modeling_\w*\.\w*ForSequenceClassification$",
|
||||
],
|
||||
(NonOverlappingOutput, MinimizeBleu,): [
|
||||
r"^textattack.models.translation.*",
|
||||
r"^textattack.models.summarization.*",
|
||||
r"^textattack.models.helpers.t5_for_text_to_text.*",
|
||||
],
|
||||
}
|
||||
|
||||
@@ -62,15 +62,16 @@ def validate_model_goal_function_compatibility(goal_function_class, model_class)
|
||||
logger.info(
|
||||
f"Goal function {goal_function_class} compatible with model {model_class.__name__}."
|
||||
)
|
||||
return True
|
||||
return
|
||||
# If we got here, the model does not match the intended goal function.
|
||||
for goal_functions, globs in MODELS_BY_GOAL_FUNCTIONS.items():
|
||||
for glob in globs:
|
||||
if re.match(glob, model_module_path):
|
||||
logger.warn(
|
||||
f"Unknown if model {model_class.__name__} compatible with provided goal function {goal_function_class}."
|
||||
" Found match with other goal functions: {goal_functions}."
|
||||
f" Found match with other goal functions: {goal_functions}."
|
||||
)
|
||||
return
|
||||
# If it matches another goal function, warn user.
|
||||
|
||||
# Otherwise, this is an unknown model–perhaps user-provided, or we forgot to
|
||||
@@ -78,7 +79,6 @@ def validate_model_goal_function_compatibility(goal_function_class, model_class)
|
||||
logger.warn(
|
||||
f"Unknown if model of class {model_class} compatible with goal function {goal_function_class}."
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
def validate_model_gradient_word_swap_compatibility(model):
|
||||
|
||||
Reference in New Issue
Block a user