1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

Merge pull request #329 from a1noack/new_goal_fcns

turn unknown goal function error into a logger warning
This commit is contained in:
Jack Morris
2020-11-09 10:38:57 -05:00
committed by GitHub
2 changed files with 11 additions and 10 deletions

View File

@@ -20,12 +20,12 @@ from . import logger
# A list of goal functions and the corresponding available models.
MODELS_BY_GOAL_FUNCTIONS = {
(TargetedClassification, UntargetedClassification, InputReduction): [
r"^textattack.models.lstm_for_classification.*",
r"^textattack.models.helpers.lstm_for_classification.*",
r"^textattack.models.helpers.word_cnn_for_classification.*",
r"^transformers.modeling_\w*\.\w*ForSequenceClassification$",
],
(NonOverlappingOutput, MinimizeBleu,): [
r"^textattack.models.translation.*",
r"^textattack.models.summarization.*",
r"^textattack.models.helpers.t5_for_text_to_text.*",
],
}
@@ -51,7 +51,8 @@ def validate_model_goal_function_compatibility(goal_function_class, model_class)
try:
matching_model_globs = MODELS_BY_GOAL_FUNCTION[goal_function_class]
except KeyError:
raise ValueError(f"No entry found for goal function {goal_function_class}.")
matching_model_globs = []
logger.warn(f"No entry found for goal function {goal_function_class}.")
# Get options for this goal function.
# model_module = model_class.__module__
model_module_path = ".".join((model_class.__module__, model_class.__name__))
@@ -61,23 +62,23 @@ def validate_model_goal_function_compatibility(goal_function_class, model_class)
logger.info(
f"Goal function {goal_function_class} compatible with model {model_class.__name__}."
)
return True
return
# If we got here, the model does not match the intended goal function.
for goal_functions, globs in MODELS_BY_GOAL_FUNCTIONS.items():
for glob in globs:
if re.match(glob, model_module_path):
raise ValueError(
logger.warn(
f"Unknown if model {model_class.__name__} compatible with provided goal function {goal_function_class}."
" Found match with other goal functions: {goal_functions}."
f" Found match with other goal functions: {goal_functions}."
)
# If it matches another goal function, throw an error.
return
# If it matches another goal function, warn user.
# Otherwise, this is an unknown modelperhaps user-provided, or we forgot to
# update the corresponding dictionary. Warn user and return.
logger.warn(
f"Unknown if model of class {model_class} compatible with goal function {goal_function_class}."
)
return True
def validate_model_gradient_word_swap_compatibility(model):