1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

examples for tensorflow and sklearn

This commit is contained in:
Jack Morris
2020-07-31 15:35:03 -04:00
parent d9ad73b8ff
commit b0f473685d
11 changed files with 1198 additions and 112 deletions

View File

@@ -165,15 +165,9 @@ class Attack:
original_text: The original ``AttackedText`` from which the attack started.
"""
# Remove any occurences of current_text in transformed_texts
original_num_texts = len(transformed_texts)
transformed_texts = [
t for t in transformed_texts if t.text != current_text.text
]
if len(transformed_texts) < original_num_texts:
# If this happened, warn the user
utils.logger.warn(
"Warning: transformation returned text with no changes. Skipping."
)
# Populate cache with transformed_texts
uncached_texts = []
for transformed_text in transformed_texts:
@@ -239,25 +233,25 @@ class Attack:
i = indices.popleft()
try:
text_input, ground_truth_output = dataset[i]
try:
# get label names from dataset, if possible
label_names = dataset.label_names
except AttributeError:
label_names = None
attacked_text = AttackedText(
text_input, attack_attrs={"label_names": label_names}
)
goal_function_result, _ = self.goal_function.init_attack_example(
attacked_text, ground_truth_output
)
yield goal_function_result
except IndexError:
utils.logger.warn(
f"Dataset has {len(dataset)} samples but tried to access index {i}. Ending attack early."
)
break
try:
# get label names from dataset, if possible
label_names = dataset.label_names
except AttributeError:
label_names = None
attacked_text = AttackedText(
text_input, attack_attrs={"label_names": label_names}
)
goal_function_result, _ = self.goal_function.init_attack_example(
attacked_text, ground_truth_output
)
yield goal_function_result
def attack_dataset(self, dataset, indices=None):
"""Runs an attack on the given dataset and outputs the results to the
console and the output file.