1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

speed up and fixes

This commit is contained in:
uvafan
2020-05-17 20:33:01 -04:00
parent b94684971e
commit edf01eb4c8
8 changed files with 9 additions and 9 deletions

View File

@@ -60,4 +60,4 @@ def Alzantot2018(model):
#
search_method = GeneticAlgorithm(pop_size=60, max_iters=20)
return Attack(goal_function, constraint, transformation, search_method)
return Attack(goal_function, constraints, transformation, search_method)

View File

@@ -13,7 +13,7 @@ class ModificationConstraint(Constraint):
def __call__(self, x, transformation):
""" Returns the word indices in x which are able to be modified """
if not self.check_compatibility(transformation):
return True
return set(range(len(x.words)))
return self._get_modifiable_indices(x)
def _get_modifiable_indices(x):

View File

@@ -23,7 +23,7 @@ class BeamSearch(SearchMethod):
potential_next_beam = []
for text in beam:
transformations = self.get_transformations(
text, original_text=original_tokenized_text)
text, original_text=initial_result.tokenized_text)
for next_text in transformations:
potential_next_beam.append(next_text)
if len(potential_next_beam) == 0:

View File

@@ -135,7 +135,7 @@ class GeneticAlgorithm(SearchMethod):
def __call__(self, initial_result):
self.original_tokenized_text = initial_result.tokenized_text
self.correct_output = intial_result.output
self.correct_output = initial_result.output
neighbors_len = self._get_neighbors_len(self.original_tokenized_text)
pop = self._generate_population(neighbors_len)
cur_score = initial_result.score

View File

@@ -76,7 +76,7 @@ class Attack:
"""
if not self.transformation:
raise RuntimeError('Cannot call `get_transformations` without a transformation.')
transformations = np.array(self.transformation(text,
modification_constraints=self.modification_constraints, **kwargs))
if apply_constraints:
@@ -138,7 +138,7 @@ class Attack:
if final_result.succeeded:
return SuccessfulAttackResult(initial_result, final_result)
else:
return FailedAttackResult(intial_result, final_result)
return FailedAttackResult(initial_result, final_result)
def _get_examples_from_dataset(self, dataset, num_examples=None, shuffle=False,
attack_n=False, attack_skippable_examples=False):

View File

@@ -162,7 +162,7 @@ class TokenizedText:
"""
final_sentence = ''
text = self.text
new_attack_attrs = deepcopy(self.attack_attrs)
new_attack_attrs = dict()
new_attack_attrs['stopword_indices'] = set()
new_attack_attrs['modified_indices'] = set()
new_attack_attrs['newly_modified_indices'] = set()

View File

@@ -14,8 +14,7 @@ class Transformation:
else:
indices_to_modify = set(indices_to_modify)
for constraint in modification_constraints:
if constraint.check_compatibility(self):
indices_to_modify = indices_to_modify & constraint(tokenized_text, self)
indices_to_modify = indices_to_modify & constraint(tokenized_text, self)
transformed_texts = self._get_transformations(tokenized_text, indices_to_modify)
for text in transformed_texts:
text.attack_attrs['last_transformation'] = self

View File

@@ -67,6 +67,7 @@ class WordSwapGradientBased(Transformation):
# grad differences between all flips and original word (eq. 1 from paper)
vocab_size = lookup_table.size(0)
diffs = torch.zeros(len(indices_to_replace), vocab_size)
indices_to_replace = list(indices_to_replace)
for j, word_idx in enumerate(indices_to_replace):
# Get the grad w.r.t the one-hot index of the word.
b_grads = emb_grad[word_idx].view(1,-1).mm(lookup_table_transpose).squeeze()