diff --git a/tests/sample_outputs/list_augmentation_recipes.txt b/tests/sample_outputs/list_augmentation_recipes.txt index 92952f00..5565662c 100644 --- a/tests/sample_outputs/list_augmentation_recipes.txt +++ b/tests/sample_outputs/list_augmentation_recipes.txt @@ -1,3 +1,4 @@ charswap (textattack.augmentation.CharSwapAugmenter) +eda (textattack.augmentation.EasyDataAugmenter) embedding (textattack.augmentation.EmbeddingAugmenter) wordnet (textattack.augmentation.WordNetAugmenter) diff --git a/textattack/augmentation/__init__.py b/textattack/augmentation/__init__.py index 17aea77c..cd2dd856 100644 --- a/textattack/augmentation/__init__.py +++ b/textattack/augmentation/__init__.py @@ -1,2 +1,7 @@ from .augmenter import Augmenter -from .recipes import WordNetAugmenter, EmbeddingAugmenter, CharSwapAugmenter, EasyDataAugmenter +from .recipes import ( + WordNetAugmenter, + EmbeddingAugmenter, + CharSwapAugmenter, + EasyDataAugmenter, +) diff --git a/textattack/augmentation/recipes.py b/textattack/augmentation/recipes.py index d0906bec..748f0b98 100644 --- a/textattack/augmentation/recipes.py +++ b/textattack/augmentation/recipes.py @@ -10,7 +10,6 @@ DEFAULT_CONSTRAINTS = [ ] - class EasyDataAugmenter(Augmenter): """ @@ -36,10 +35,14 @@ class EasyDataAugmenter(Augmenter): self.transformations_per_example = n_aug n_aug_each = max(n_aug // 4, 1) - self.synonym_replacement = WordNetAugmenter(transformations_per_example=n_aug_each) + self.synonym_replacement = WordNetAugmenter( + transformations_per_example=n_aug_each + ) self.random_deletion = DeletionAugmenter(transformations_per_example=n_aug_each) self.random_swap = SwapAugmenter(transformations_per_example=n_aug_each) - self.random_insertion = SynonymInsertionAugmenter(transformations_per_example=n_aug_each) + self.random_insertion = SynonymInsertionAugmenter( + transformations_per_example=n_aug_each + ) def _set_words_to_swap(self, num): self.synonym_replacement.num_words_to_swap = num @@ -49,9 +52,9 @@ class EasyDataAugmenter(Augmenter): def augment(self, text): attacked_text = textattack.shared.AttackedText(text) - num_words_to_swap = max(1, int(self.alpha*len(attacked_text.words))) + num_words_to_swap = max(1, int(self.alpha * len(attacked_text.words))) self._set_words_to_swap(num_words_to_swap) - + augmented_text = [attacked_text.printable_text()] augmented_text += self.synonym_replacement.augment(text) augmented_text += self.random_deletion.augment(text) @@ -59,17 +62,21 @@ class EasyDataAugmenter(Augmenter): augmented_text += self.random_insertion.augment(text) random.shuffle(augmented_text) - return augmented_text[:self.transformations_per_example] + return augmented_text[: self.transformations_per_example] + class SwapAugmenter(Augmenter): def __init__(self, **kwargs): from textattack.transformations import RandomSwap + transformation = RandomSwap() super().__init__(transformation, constraints=DEFAULT_CONSTRAINTS, **kwargs) + class SynonymInsertionAugmenter(Augmenter): def __init__(self, **kwargs): from textattack.transformations import RandomSynonymInsertion + transformation = RandomSynonymInsertion() super().__init__(transformation, constraints=DEFAULT_CONSTRAINTS, **kwargs) @@ -83,9 +90,11 @@ class WordNetAugmenter(Augmenter): transformation = WordSwapWordNet() super().__init__(transformation, constraints=DEFAULT_CONSTRAINTS, **kwargs) + class DeletionAugmenter(Augmenter): def __init__(self, **kwargs): from textattack.transformations import WordDeletion + transformation = WordDeletion() super().__init__(transformation, constraints=DEFAULT_CONSTRAINTS, **kwargs) diff --git a/textattack/commands/augment.py b/textattack/commands/augment.py index 4db4d16e..3c37004b 100644 --- a/textattack/commands/augment.py +++ b/textattack/commands/augment.py @@ -57,10 +57,9 @@ class AugmentCommand(TextAttackCommand): f"Read {len(rows)} rows from {args.csv}. Found columns {row_keys}." ) - if args.recipe == 'eda': + if args.recipe == "eda": augmenter = eval(AUGMENTATION_RECIPE_NAMES[args.recipe])( - alpha=args.alpha, - n_aug=args.transformations_per_example, + alpha=args.alpha, n_aug=args.transformations_per_example, ) else: @@ -130,7 +129,7 @@ class AugmentCommand(TextAttackCommand): "--a", help="fraction of words to modify (EasyDataAugmenter)", type=float, - default=.1, + default=0.1, ) parser.add_argument( diff --git a/textattack/shared/attacked_text.py b/textattack/shared/attacked_text.py index 6e65d4e9..1fb26479 100644 --- a/textattack/shared/attacked_text.py +++ b/textattack/shared/attacked_text.py @@ -209,13 +209,15 @@ class AttackedText: raise ValueError(f"Cannot assign word at index {i}") words[i] = new_word return self.generate_new_attacked_text(words) - + def replace_word_at_index(self, index, new_word): """ This code returns a new AttackedText object where the word at ``index`` is replaced with a new word. """ if not isinstance(new_word, str): - raise TypeError(f"replace_word_at_index requires ``str`` new_word, got {type(new_word)}") + raise TypeError( + f"replace_word_at_index requires ``str`` new_word, got {type(new_word)}" + ) return self.replace_words_at_indices([index], [new_word]) def delete_word_at_index(self, index): diff --git a/textattack/transformations/random_synonym_insertion.py b/textattack/transformations/random_synonym_insertion.py index 071ccb0f..9a5ef008 100644 --- a/textattack/transformations/random_synonym_insertion.py +++ b/textattack/transformations/random_synonym_insertion.py @@ -4,10 +4,12 @@ from nltk.corpus import wordnet from textattack.transformations import Transformation + class RandomSynonymInsertion(Transformation): """ Transformation that inserts synonyms of words that are already in the sequence. """ + def _get_synonyms(self, word): synonyms = set() for syn in wordnet.synsets(word): @@ -20,13 +22,20 @@ class RandomSynonymInsertion(Transformation): transformed_texts = [] for idx in indices_to_modify: synonyms = [] + # try to find a word with synonyms, and deal with edge case where there aren't any for attempt in range(7): synonyms = self._get_synonyms(random.choice(current_text.words)) - if synonyms: break + if synonyms: + break + elif attempt == 6: + return [current_text] random_synonym = random.choice(synonyms) - transformed_texts.append(current_text.insert_text_after_word_index(idx, random_synonym)) + transformed_texts.append( + current_text.insert_text_after_word_index(idx, random_synonym) + ) return transformed_texts + def check_if_one_word(word): for c in word: if not c.isalpha(): diff --git a/textattack/transformations/word_swap_random_word.py b/textattack/transformations/word_swap_random_word.py index a3375aa6..719bd2d3 100644 --- a/textattack/transformations/word_swap_random_word.py +++ b/textattack/transformations/word_swap_random_word.py @@ -2,19 +2,20 @@ import random from textattack.transformations import Transformation + class RandomSwap(Transformation): """ Transformation that swaps the order of words in a sequnce. """ + def _get_transformations(self, current_text, indices_to_modify): transformed_texts = [] words = current_text.words for idx in indices_to_modify: word = words[idx] swap_idx = random.choice(list(set(range(len(words))) - {idx})) - swapped_text = current_text.replace_word_at_index(idx, words[swap_idx]).replace_word_at_index(swap_idx, word) + swapped_text = current_text.replace_word_at_index( + idx, words[swap_idx] + ).replace_word_at_index(swap_idx, word) transformed_texts.append(swapped_text) return transformed_texts - - -