1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

Merge pull request #298 from QData/new-augmenter

New augmenter
This commit is contained in:
Jack Morris
2020-10-18 22:24:06 -04:00
committed by GitHub
4 changed files with 102 additions and 31 deletions

View File

@@ -9,10 +9,63 @@ def test_imports():
def test_embedding_augmenter():
    """Check that EmbeddingAugmenter yields an expected word substitution.

    The augmenter is asked for 64 transformations of a single sentence with
    a very small ``pct_words_to_swap``; the test passes when the sentence
    with "bad" replaced by "unfavourable" is among the returned strings.
    """
    from textattack.augmentation import EmbeddingAugmenter

    # Construct the augmenter exactly once. An earlier version of this test
    # first built one with default ``pct_words_to_swap`` and then immediately
    # overwrote it, which only cost an extra (discarded) construction.
    augmenter = EmbeddingAugmenter(
        pct_words_to_swap=0.01, transformations_per_example=64
    )
    s = "There is nothing either good or bad, but thinking makes it so."
    augmented_text_list = augmenter.augment(s)
    augmented_s = (
        "There is nothing either good or unfavourable, but thinking makes it so."
    )
    assert augmented_s in augmented_text_list
def test_checklist_augmenter():
    """Check that CheckListAugmenter both expands and contracts "I'll".

    A single augmenter instance is used for both directions: the contracted
    sentence must produce the expanded form, and the expanded sentence must
    produce the contracted form.
    """
    from textattack.augmentation import CheckListAugmenter

    checklist_augmenter = CheckListAugmenter(
        pct_words_to_swap=0.01, transformations_per_example=64
    )

    # Each pair is (input sentence, variant that must appear in the output).
    cases = (
        ("I'll be happy to assist you.", "I will be happy to assist you."),
        ("I will be happy to assist you.", "I'll be happy to assist you."),
    )
    for sentence, expected_variant in cases:
        produced = checklist_augmenter.augment(sentence)
        assert expected_variant in produced
def test_charwap_augmenter():
    """Check that CharSwapAugmenter can produce a one-character deletion.

    The expected output is the input sentence with the "o" of the leading
    "To" removed.
    """
    from textattack.augmentation import CharSwapAugmenter

    char_swapper = CharSwapAugmenter(
        pct_words_to_swap=0.01, transformations_per_example=64
    )
    original_sentence = "To be or not to be"
    expected_variant = "T be or not to be"

    variants = char_swapper.augment(original_sentence)
    assert expected_variant in variants
def test_easydata_augmenter():
    """Check that EasyDataAugmenter can swap the two words of a short input.

    For the two-word sentence "Hakuna Montana" the reversed word order must
    appear among the augmented strings.
    """
    from textattack.augmentation import EasyDataAugmenter

    eda = EasyDataAugmenter(pct_words_to_swap=0.01, transformations_per_example=64)
    results = eda.augment("Hakuna Montana")

    # The word-order swap is the variant this test pins down.
    assert "Montana Hakuna" in results
def test_wordnet_augmenter():
    """Check that WordNetAugmenter substitutes "Dragon" with "firedrake".

    The test passes when the sentence with that single replacement is among
    the strings returned by ``augment``.
    """
    from textattack.augmentation import WordNetAugmenter

    settings = {"pct_words_to_swap": 0.01, "transformations_per_example": 64}
    wordnet_augmenter = WordNetAugmenter(**settings)

    source_sentence = "The Dragon warrior is a panda"
    wanted = "The firedrake warrior is a panda"
    assert wanted in wordnet_augmenter.augment(source_sentence)

View File

@@ -5,4 +5,5 @@ from .recipes import (
CharSwapAugmenter,
EasyDataAugmenter,
CheckListAugmenter,
DeletionAugmenter,
)

View File

@@ -73,28 +73,32 @@ class Augmenter:
int(self.pct_words_to_swap * len(attacked_text.words)), 1
)
for _ in range(self.transformations_per_example):
index_order = list(range(len(attacked_text.words)))
random.shuffle(index_order)
current_text = attacked_text
words_swapped = 0
for i in index_order:
words_swapped = len(current_text.attack_attrs["modified_indices"])
while words_swapped < num_words_to_swap:
transformed_texts = self.transformation(
current_text, self.pre_transformation_constraints, [i]
current_text, self.pre_transformation_constraints
)
# Get rid of transformations we already have
transformed_texts = [
t for t in transformed_texts if t not in all_transformed_texts
]
# Filter out transformations that don't match the constraints.
transformed_texts = self._filter_transformations(
transformed_texts, current_text, original_text
)
# if there's no more transformed texts after filter, terminate
if not len(transformed_texts):
continue
current_text = random.choice(transformed_texts)
words_swapped += 1
if words_swapped == num_words_to_swap:
break
current_text = random.choice(transformed_texts)
# update words_swapped based on modified indices
words_swapped = len(current_text.attack_attrs["modified_indices"])
all_transformed_texts.add(current_text)
return sorted([at.printable_text() for at in all_transformed_texts])

View File

@@ -40,7 +40,7 @@ class AugmentCommand(TextAttackCommand):
while True:
print(
'\nEnter a sentence to augment, "q" to quit, "c" to change arguments:\n'
'\nEnter a sentence to augment, "q" to quit, "c" to view/change arguments:\n'
)
text = input()
@@ -48,29 +48,42 @@ class AugmentCommand(TextAttackCommand):
break
elif text == "c":
print("\nChanging augmenter arguments...\n")
recipe = input(
"\tAugmentation recipe name ('r' to see available recipes): "
)
if recipe == "r":
print("\n\twordnet, embedding, charswap, eda\n")
args.recipe = input("\tAugmentation recipe name: ")
else:
args.recipe = recipe
args.pct_words_to_swap = float(
input("\tPercentage of words to swap (0.0 ~ 1.0): ")
)
args.transformations_per_example = int(
input("\tTransformations per input example: ")
print(
f"\nCurrent Arguments:\n\n\t augmentation recipe: {args.recipe}, "
f"\n\t pct_words_to_swap: {args.pct_words_to_swap}, "
f"\n\t transformations_per_example: {args.transformations_per_example}\n"
)
print("\nGenerating new augmenter...\n")
augmenter = eval(AUGMENTATION_RECIPE_NAMES[args.recipe])(
pct_words_to_swap=args.pct_words_to_swap,
transformations_per_example=args.transformations_per_example,
change = input(
"Enter 'c' again to change arguments, any other keys to opt out\n"
)
print("--------------------------------------------------------")
if change == "c":
print("\nChanging augmenter arguments...\n")
recipe = input(
"\tAugmentation recipe name ('r' to see available recipes): "
)
if recipe == "r":
print("\n\twordnet, embedding, charswap, eda, checklist\n")
args.recipe = input("\tAugmentation recipe name: ")
else:
args.recipe = recipe
args.pct_words_to_swap = float(
input("\tPercentage of words to swap (0.0 ~ 1.0): ")
)
args.transformations_per_example = int(
input("\tTransformations per input example: ")
)
print("\nGenerating new augmenter...\n")
augmenter = eval(AUGMENTATION_RECIPE_NAMES[args.recipe])(
pct_words_to_swap=args.pct_words_to_swap,
transformations_per_example=args.transformations_per_example,
)
print(
"--------------------------------------------------------"
)
continue
elif not text: