1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

TT->AttackedText, remove ids, give TT keys, batch tokenize

This commit is contained in:
Jack Morris
2020-06-16 19:46:20 -04:00
parent c155c0a390
commit 9bf213f7fd
23 changed files with 156 additions and 108 deletions

View File

@@ -10,7 +10,7 @@ from textattack.attack_results import (
SkippedAttackResult,
SuccessfulAttackResult,
)
from textattack.shared import TokenizedText, utils
from textattack.shared import AttackedText, utils
class Attack:
@@ -80,8 +80,8 @@ class Attack:
through the applicable constraints.
Args:
current_text: The current ``TokenizedText`` on which to perform the transformations.
original_text: The original ``TokenizedText`` from which the attack started.
current_text: The current ``AttackedText`` on which to perform the transformations.
original_text: The original ``AttackedText`` from which the attack started.
apply_constraints: Whether or not to apply post-transformation constraints.
Returns:
@@ -111,9 +111,9 @@ class Attack:
Filters a list of potential transformaed texts based on ``self.constraints``\.
Args:
transformed_texts: A list of candidate transformed ``TokenizedText``\s to filter.
current_text: The current ``TokenizedText`` on which the transformation was applied.
original_text: The original ``TokenizedText`` from which the attack started.
transformed_texts: A list of candidate transformed ``AttackedText``\s to filter.
current_text: The current ``AttackedText`` on which the transformation was applied.
original_text: The original ``AttackedText`` from which the attack started.
"""
filtered_texts = transformed_texts[:]
for C in self.constraints:
@@ -138,9 +138,9 @@ class Attack:
Checks cache first.
Args:
transformed_texts: A list of candidate transformed ``TokenizedText``\s to filter.
current_text: The current ``TokenizedText`` on which the transformation was applied.
original_text: The original ``TokenizedText`` from which the attack started.
transformed_texts: A list of candidate transformed ``AttackedText``\s to filter.
current_text: The current ``AttackedText`` on which the transformation was applied.
original_text: The original ``AttackedText`` from which the attack started.
"""
# Populate cache with transformed_texts
uncached_texts = []
@@ -163,7 +163,7 @@ class Attack:
def attack_one(self, initial_result):
"""
Calls the ``SearchMethod`` to perturb the ``TokenizedText`` stored in
Calls the ``SearchMethod`` to perturb the ``AttackedText`` stored in
``initial_result``.
Args:
@@ -217,7 +217,7 @@ class Attack:
yield
for text, ground_truth_output in dataset:
tokenized_text = TokenizedText(text, self.goal_function.tokenizer)
tokenized_text = AttackedText(text, self.goal_function.tokenizer)
self.goal_function.num_queries = 0
goal_function_result, _ = self.goal_function.get_result(
tokenized_text, ground_truth_output