mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
speed up and fixes
This commit is contained in:
@@ -60,4 +60,4 @@ def Alzantot2018(model):
|
||||
#
|
||||
search_method = GeneticAlgorithm(pop_size=60, max_iters=20)
|
||||
|
||||
return Attack(goal_function, constraint, transformation, search_method)
|
||||
return Attack(goal_function, constraints, transformation, search_method)
|
||||
|
||||
@@ -13,7 +13,7 @@ class ModificationConstraint(Constraint):
|
||||
def __call__(self, x, transformation):
|
||||
""" Returns the word indices in x which are able to be modified """
|
||||
if not self.check_compatibility(transformation):
|
||||
return True
|
||||
return set(range(len(x.words)))
|
||||
return self._get_modifiable_indices(x)
|
||||
|
||||
def _get_modifiable_indices(x):
|
||||
|
||||
@@ -23,7 +23,7 @@ class BeamSearch(SearchMethod):
|
||||
potential_next_beam = []
|
||||
for text in beam:
|
||||
transformations = self.get_transformations(
|
||||
text, original_text=original_tokenized_text)
|
||||
text, original_text=initial_result.tokenized_text)
|
||||
for next_text in transformations:
|
||||
potential_next_beam.append(next_text)
|
||||
if len(potential_next_beam) == 0:
|
||||
|
||||
@@ -135,7 +135,7 @@ class GeneticAlgorithm(SearchMethod):
|
||||
|
||||
def __call__(self, initial_result):
|
||||
self.original_tokenized_text = initial_result.tokenized_text
|
||||
self.correct_output = intial_result.output
|
||||
self.correct_output = initial_result.output
|
||||
neighbors_len = self._get_neighbors_len(self.original_tokenized_text)
|
||||
pop = self._generate_population(neighbors_len)
|
||||
cur_score = initial_result.score
|
||||
|
||||
@@ -76,7 +76,7 @@ class Attack:
|
||||
"""
|
||||
if not self.transformation:
|
||||
raise RuntimeError('Cannot call `get_transformations` without a transformation.')
|
||||
|
||||
|
||||
transformations = np.array(self.transformation(text,
|
||||
modification_constraints=self.modification_constraints, **kwargs))
|
||||
if apply_constraints:
|
||||
@@ -138,7 +138,7 @@ class Attack:
|
||||
if final_result.succeeded:
|
||||
return SuccessfulAttackResult(initial_result, final_result)
|
||||
else:
|
||||
return FailedAttackResult(intial_result, final_result)
|
||||
return FailedAttackResult(initial_result, final_result)
|
||||
|
||||
def _get_examples_from_dataset(self, dataset, num_examples=None, shuffle=False,
|
||||
attack_n=False, attack_skippable_examples=False):
|
||||
|
||||
@@ -162,7 +162,7 @@ class TokenizedText:
|
||||
"""
|
||||
final_sentence = ''
|
||||
text = self.text
|
||||
new_attack_attrs = deepcopy(self.attack_attrs)
|
||||
new_attack_attrs = dict()
|
||||
new_attack_attrs['stopword_indices'] = set()
|
||||
new_attack_attrs['modified_indices'] = set()
|
||||
new_attack_attrs['newly_modified_indices'] = set()
|
||||
|
||||
@@ -14,8 +14,7 @@ class Transformation:
|
||||
else:
|
||||
indices_to_modify = set(indices_to_modify)
|
||||
for constraint in modification_constraints:
|
||||
if constraint.check_compatibility(self):
|
||||
indices_to_modify = indices_to_modify & constraint(tokenized_text, self)
|
||||
indices_to_modify = indices_to_modify & constraint(tokenized_text, self)
|
||||
transformed_texts = self._get_transformations(tokenized_text, indices_to_modify)
|
||||
for text in transformed_texts:
|
||||
text.attack_attrs['last_transformation'] = self
|
||||
|
||||
@@ -67,6 +67,7 @@ class WordSwapGradientBased(Transformation):
|
||||
# grad differences between all flips and original word (eq. 1 from paper)
|
||||
vocab_size = lookup_table.size(0)
|
||||
diffs = torch.zeros(len(indices_to_replace), vocab_size)
|
||||
indices_to_replace = list(indices_to_replace)
|
||||
for j, word_idx in enumerate(indices_to_replace):
|
||||
# Get the grad w.r.t the one-hot index of the word.
|
||||
b_grads = emb_grad[word_idx].view(1,-1).mm(lookup_table_transpose).squeeze()
|
||||
|
||||
Reference in New Issue
Block a user