mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
support custom models; waiting on datasets integration
This commit is contained in:
@@ -27,16 +27,15 @@ class Attack:
|
||||
|
||||
def __init__(self, goal_function=None, constraints=[], transformation=None, search_method=None):
|
||||
""" Initialize an attack object. Attacks can be run multiple times. """
|
||||
self.search_method = search_method
|
||||
self.goal_function = goal_function
|
||||
if not self.goal_function:
|
||||
raise NameError('Cannot instantiate attack without self.goal_function for predictions')
|
||||
if not hasattr(self, 'tokenizer'):
|
||||
if hasattr(self.goal_function.model, 'tokenizer'):
|
||||
self.tokenizer = self.goal_function.model.tokenizer
|
||||
else:
|
||||
raise NameError('Cannot instantiate attack without tokenizer')
|
||||
self.search_method = search_method
|
||||
if not self.search_method:
|
||||
raise NameError('Cannot instantiate attack without search method')
|
||||
self.transformation = transformation
|
||||
if not self.transformation:
|
||||
raise NameError('Cannot instantiate attack without transformation')
|
||||
self.is_black_box = getattr(transformation, 'is_black_box', True)
|
||||
|
||||
if not self.search_method.check_transformation_compatibility(self.transformation):
|
||||
@@ -172,7 +171,7 @@ class Attack:
|
||||
yield
|
||||
|
||||
for text, ground_truth_output in dataset:
|
||||
tokenized_text = TokenizedText(text, self.tokenizer)
|
||||
tokenized_text = TokenizedText(text, self.goal_function.tokenizer)
|
||||
goal_function_result = self.goal_function.get_result(tokenized_text, ground_truth_output)
|
||||
# We can skip examples for which the goal is already succeeded,
|
||||
# unless `attack_skippable_examples` is True.
|
||||
|
||||
Reference in New Issue
Block a user