mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
@@ -9,10 +9,63 @@ def test_imports():
|
||||
def test_embedding_augmenter():
|
||||
from textattack.augmentation import EmbeddingAugmenter
|
||||
|
||||
augmenter = EmbeddingAugmenter(transformations_per_example=64)
|
||||
augmenter = EmbeddingAugmenter(
|
||||
pct_words_to_swap=0.01, transformations_per_example=64
|
||||
)
|
||||
s = "There is nothing either good or bad, but thinking makes it so."
|
||||
augmented_text_list = augmenter.augment(s)
|
||||
augmented_s = (
|
||||
"There is nothing either good or unfavourable, but thinking makes it so."
|
||||
)
|
||||
assert augmented_s in augmented_text_list
|
||||
|
||||
|
||||
def test_checklist_augmenter():
|
||||
from textattack.augmentation import CheckListAugmenter
|
||||
|
||||
augmenter = CheckListAugmenter(
|
||||
pct_words_to_swap=0.01, transformations_per_example=64
|
||||
)
|
||||
s = "I'll be happy to assist you."
|
||||
augmented_text_list = augmenter.augment(s)
|
||||
augmented_s = "I will be happy to assist you."
|
||||
assert augmented_s in augmented_text_list
|
||||
|
||||
s = "I will be happy to assist you."
|
||||
augmented_text_list = augmenter.augment(s)
|
||||
augmented_s = "I'll be happy to assist you."
|
||||
assert augmented_s in augmented_text_list
|
||||
|
||||
|
||||
def test_charwap_augmenter():
|
||||
from textattack.augmentation import CharSwapAugmenter
|
||||
|
||||
augmenter = CharSwapAugmenter(
|
||||
pct_words_to_swap=0.01, transformations_per_example=64
|
||||
)
|
||||
s = "To be or not to be"
|
||||
augmented_text_list = augmenter.augment(s)
|
||||
augmented_s = "T be or not to be"
|
||||
assert augmented_s in augmented_text_list
|
||||
|
||||
|
||||
def test_easydata_augmenter():
|
||||
from textattack.augmentation import EasyDataAugmenter
|
||||
|
||||
augmenter = EasyDataAugmenter(
|
||||
pct_words_to_swap=0.01, transformations_per_example=64
|
||||
)
|
||||
s = "Hakuna Montana"
|
||||
augmented_text_list = augmenter.augment(s)
|
||||
augmented_s = "Montana Hakuna"
|
||||
assert augmented_s in augmented_text_list
|
||||
|
||||
|
||||
def test_wordnet_augmenter():
|
||||
from textattack.augmentation import WordNetAugmenter
|
||||
|
||||
augmenter = WordNetAugmenter(pct_words_to_swap=0.01, transformations_per_example=64)
|
||||
s = "The Dragon warrior is a panda"
|
||||
augmented_text_list = augmenter.augment(s)
|
||||
augmented_s = "The firedrake warrior is a panda"
|
||||
assert augmented_s in augmented_text_list
|
||||
|
||||
@@ -5,4 +5,5 @@ from .recipes import (
|
||||
CharSwapAugmenter,
|
||||
EasyDataAugmenter,
|
||||
CheckListAugmenter,
|
||||
DeletionAugmenter,
|
||||
)
|
||||
|
||||
@@ -73,28 +73,32 @@ class Augmenter:
|
||||
int(self.pct_words_to_swap * len(attacked_text.words)), 1
|
||||
)
|
||||
for _ in range(self.transformations_per_example):
|
||||
index_order = list(range(len(attacked_text.words)))
|
||||
random.shuffle(index_order)
|
||||
current_text = attacked_text
|
||||
words_swapped = 0
|
||||
for i in index_order:
|
||||
words_swapped = len(current_text.attack_attrs["modified_indices"])
|
||||
|
||||
while words_swapped < num_words_to_swap:
|
||||
transformed_texts = self.transformation(
|
||||
current_text, self.pre_transformation_constraints, [i]
|
||||
current_text, self.pre_transformation_constraints
|
||||
)
|
||||
|
||||
# Get rid of transformations we already have
|
||||
transformed_texts = [
|
||||
t for t in transformed_texts if t not in all_transformed_texts
|
||||
]
|
||||
|
||||
# Filter out transformations that don't match the constraints.
|
||||
transformed_texts = self._filter_transformations(
|
||||
transformed_texts, current_text, original_text
|
||||
)
|
||||
|
||||
# if there's no more transformed texts after filter, terminate
|
||||
if not len(transformed_texts):
|
||||
continue
|
||||
current_text = random.choice(transformed_texts)
|
||||
words_swapped += 1
|
||||
if words_swapped == num_words_to_swap:
|
||||
break
|
||||
|
||||
current_text = random.choice(transformed_texts)
|
||||
|
||||
# update words_swapped based on modified indices
|
||||
words_swapped = len(current_text.attack_attrs["modified_indices"])
|
||||
all_transformed_texts.add(current_text)
|
||||
return sorted([at.printable_text() for at in all_transformed_texts])
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@ class AugmentCommand(TextAttackCommand):
|
||||
|
||||
while True:
|
||||
print(
|
||||
'\nEnter a sentence to augment, "q" to quit, "c" to change arguments:\n'
|
||||
'\nEnter a sentence to augment, "q" to quit, "c" to view/change arguments:\n'
|
||||
)
|
||||
text = input()
|
||||
|
||||
@@ -48,29 +48,42 @@ class AugmentCommand(TextAttackCommand):
|
||||
break
|
||||
|
||||
elif text == "c":
|
||||
print("\nChanging augmenter arguments...\n")
|
||||
recipe = input(
|
||||
"\tAugmentation recipe name ('r' to see available recipes): "
|
||||
)
|
||||
if recipe == "r":
|
||||
print("\n\twordnet, embedding, charswap, eda\n")
|
||||
args.recipe = input("\tAugmentation recipe name: ")
|
||||
else:
|
||||
args.recipe = recipe
|
||||
|
||||
args.pct_words_to_swap = float(
|
||||
input("\tPercentage of words to swap (0.0 ~ 1.0): ")
|
||||
)
|
||||
args.transformations_per_example = int(
|
||||
input("\tTransformations per input example: ")
|
||||
print(
|
||||
f"\nCurrent Arguments:\n\n\t augmentation recipe: {args.recipe}, "
|
||||
f"\n\t pct_words_to_swap: {args.pct_words_to_swap}, "
|
||||
f"\n\t transformations_per_example: {args.transformations_per_example}\n"
|
||||
)
|
||||
|
||||
print("\nGenerating new augmenter...\n")
|
||||
augmenter = eval(AUGMENTATION_RECIPE_NAMES[args.recipe])(
|
||||
pct_words_to_swap=args.pct_words_to_swap,
|
||||
transformations_per_example=args.transformations_per_example,
|
||||
change = input(
|
||||
"Enter 'c' again to change arguments, any other keys to opt out\n"
|
||||
)
|
||||
print("--------------------------------------------------------")
|
||||
if change == "c":
|
||||
print("\nChanging augmenter arguments...\n")
|
||||
recipe = input(
|
||||
"\tAugmentation recipe name ('r' to see available recipes): "
|
||||
)
|
||||
if recipe == "r":
|
||||
print("\n\twordnet, embedding, charswap, eda, checklist\n")
|
||||
args.recipe = input("\tAugmentation recipe name: ")
|
||||
else:
|
||||
args.recipe = recipe
|
||||
|
||||
args.pct_words_to_swap = float(
|
||||
input("\tPercentage of words to swap (0.0 ~ 1.0): ")
|
||||
)
|
||||
args.transformations_per_example = int(
|
||||
input("\tTransformations per input example: ")
|
||||
)
|
||||
|
||||
print("\nGenerating new augmenter...\n")
|
||||
augmenter = eval(AUGMENTATION_RECIPE_NAMES[args.recipe])(
|
||||
pct_words_to_swap=args.pct_words_to_swap,
|
||||
transformations_per_example=args.transformations_per_example,
|
||||
)
|
||||
print(
|
||||
"--------------------------------------------------------"
|
||||
)
|
||||
|
||||
continue
|
||||
|
||||
elif not text:
|
||||
|
||||
Reference in New Issue
Block a user