1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

support custom models; waiting on datasets integration

This commit is contained in:
Jack Morris
2020-06-05 17:35:52 -04:00
parent 20f3769963
commit b6dd920c08
16 changed files with 351 additions and 176 deletions

View File

@@ -27,16 +27,15 @@ class Attack:
def __init__(self, goal_function=None, constraints=[], transformation=None, search_method=None):
""" Initialize an attack object. Attacks can be run multiple times. """
self.search_method = search_method
self.goal_function = goal_function
if not self.goal_function:
raise NameError('Cannot instantiate attack without self.goal_function for predictions')
if not hasattr(self, 'tokenizer'):
if hasattr(self.goal_function.model, 'tokenizer'):
self.tokenizer = self.goal_function.model.tokenizer
else:
raise NameError('Cannot instantiate attack without tokenizer')
self.search_method = search_method
if not self.search_method:
raise NameError('Cannot instantiate attack without search method')
self.transformation = transformation
if not self.transformation:
raise NameError('Cannot instantiate attack without transformation')
self.is_black_box = getattr(transformation, 'is_black_box', True)
if not self.search_method.check_transformation_compatibility(self.transformation):
@@ -172,7 +171,7 @@ class Attack:
yield
for text, ground_truth_output in dataset:
tokenized_text = TokenizedText(text, self.tokenizer)
tokenized_text = TokenizedText(text, self.goal_function.tokenizer)
goal_function_result = self.goal_function.get_result(tokenized_text, ground_truth_output)
# We can skip examples for which the goal is already succeeded,
# unless `attack_skippable_examples` is True.