1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

eda tests passing

This commit is contained in:
jakegrigsby
2020-06-29 19:02:29 -04:00
parent 20234b5203
commit c88798073b
7 changed files with 45 additions and 19 deletions

View File

@@ -1,3 +1,4 @@
charswap (textattack.augmentation.CharSwapAugmenter)
eda (textattack.augmentation.EasyDataAugmenter)
embedding (textattack.augmentation.EmbeddingAugmenter)
wordnet (textattack.augmentation.WordNetAugmenter)

View File

@@ -1,2 +1,7 @@
from .augmenter import Augmenter
from .recipes import WordNetAugmenter, EmbeddingAugmenter, CharSwapAugmenter, EasyDataAugmenter
from .recipes import (
WordNetAugmenter,
EmbeddingAugmenter,
CharSwapAugmenter,
EasyDataAugmenter,
)

View File

@@ -10,7 +10,6 @@ DEFAULT_CONSTRAINTS = [
]
class EasyDataAugmenter(Augmenter):
"""
@@ -36,10 +35,14 @@ class EasyDataAugmenter(Augmenter):
self.transformations_per_example = n_aug
n_aug_each = max(n_aug // 4, 1)
self.synonym_replacement = WordNetAugmenter(transformations_per_example=n_aug_each)
self.synonym_replacement = WordNetAugmenter(
transformations_per_example=n_aug_each
)
self.random_deletion = DeletionAugmenter(transformations_per_example=n_aug_each)
self.random_swap = SwapAugmenter(transformations_per_example=n_aug_each)
self.random_insertion = SynonymInsertionAugmenter(transformations_per_example=n_aug_each)
self.random_insertion = SynonymInsertionAugmenter(
transformations_per_example=n_aug_each
)
def _set_words_to_swap(self, num):
self.synonym_replacement.num_words_to_swap = num
@@ -49,9 +52,9 @@ class EasyDataAugmenter(Augmenter):
def augment(self, text):
attacked_text = textattack.shared.AttackedText(text)
num_words_to_swap = max(1, int(self.alpha*len(attacked_text.words)))
num_words_to_swap = max(1, int(self.alpha * len(attacked_text.words)))
self._set_words_to_swap(num_words_to_swap)
augmented_text = [attacked_text.printable_text()]
augmented_text += self.synonym_replacement.augment(text)
augmented_text += self.random_deletion.augment(text)
@@ -59,17 +62,21 @@ class EasyDataAugmenter(Augmenter):
augmented_text += self.random_insertion.augment(text)
random.shuffle(augmented_text)
return augmented_text[:self.transformations_per_example]
return augmented_text[: self.transformations_per_example]
class SwapAugmenter(Augmenter):
    """Augmenter built on the ``RandomSwap`` transformation.

    Every keyword argument is forwarded unchanged to ``Augmenter.__init__``;
    the constraint set is always ``DEFAULT_CONSTRAINTS``.
    """

    def __init__(self, **kwargs):
        # Imported at construction time rather than at module import --
        # presumably to avoid a circular import; confirm before hoisting.
        from textattack.transformations import RandomSwap

        super().__init__(RandomSwap(), constraints=DEFAULT_CONSTRAINTS, **kwargs)
class SynonymInsertionAugmenter(Augmenter):
    """Augmenter built on the ``RandomSynonymInsertion`` transformation.

    Every keyword argument is forwarded unchanged to ``Augmenter.__init__``;
    the constraint set is always ``DEFAULT_CONSTRAINTS``.
    """

    def __init__(self, **kwargs):
        # Imported at construction time rather than at module import --
        # presumably to avoid a circular import; confirm before hoisting.
        from textattack.transformations import RandomSynonymInsertion

        super().__init__(
            RandomSynonymInsertion(), constraints=DEFAULT_CONSTRAINTS, **kwargs
        )
@@ -83,9 +90,11 @@ class WordNetAugmenter(Augmenter):
transformation = WordSwapWordNet()
super().__init__(transformation, constraints=DEFAULT_CONSTRAINTS, **kwargs)
class DeletionAugmenter(Augmenter):
    """Augmenter built on the ``WordDeletion`` transformation.

    Every keyword argument is forwarded unchanged to ``Augmenter.__init__``;
    the constraint set is always ``DEFAULT_CONSTRAINTS``.
    """

    def __init__(self, **kwargs):
        # Imported at construction time rather than at module import --
        # presumably to avoid a circular import; confirm before hoisting.
        from textattack.transformations import WordDeletion

        super().__init__(WordDeletion(), constraints=DEFAULT_CONSTRAINTS, **kwargs)

View File

@@ -57,10 +57,9 @@ class AugmentCommand(TextAttackCommand):
f"Read {len(rows)} rows from {args.csv}. Found columns {row_keys}."
)
if args.recipe == 'eda':
if args.recipe == "eda":
augmenter = eval(AUGMENTATION_RECIPE_NAMES[args.recipe])(
alpha=args.alpha,
n_aug=args.transformations_per_example,
alpha=args.alpha, n_aug=args.transformations_per_example,
)
else:
@@ -130,7 +129,7 @@ class AugmentCommand(TextAttackCommand):
"--a",
help="fraction of words to modify (EasyDataAugmenter)",
type=float,
default=.1,
default=0.1,
)
parser.add_argument(

View File

@@ -209,13 +209,15 @@ class AttackedText:
raise ValueError(f"Cannot assign word at index {i}")
words[i] = new_word
return self.generate_new_attacked_text(words)
def replace_word_at_index(self, index, new_word):
    """Return a new AttackedText object with the word at ``index`` replaced.

    This is the single-word form of ``replace_words_at_indices``: the index
    and the replacement are each wrapped in a one-element list and handed
    to that method.

    Raises:
        TypeError: if ``new_word`` is not a ``str``.
    """
    if isinstance(new_word, str):
        return self.replace_words_at_indices([index], [new_word])
    raise TypeError(
        f"replace_word_at_index requires ``str`` new_word, got {type(new_word)}"
    )
def delete_word_at_index(self, index):

View File

@@ -4,10 +4,12 @@ from nltk.corpus import wordnet
from textattack.transformations import Transformation
class RandomSynonymInsertion(Transformation):
"""
Transformation that inserts synonyms of words that are already in the sequence.
"""
def _get_synonyms(self, word):
synonyms = set()
for syn in wordnet.synsets(word):
@@ -20,13 +22,20 @@ class RandomSynonymInsertion(Transformation):
transformed_texts = []
for idx in indices_to_modify:
synonyms = []
# try to find a word with synonyms, and deal with edge case where there aren't any
for attempt in range(7):
synonyms = self._get_synonyms(random.choice(current_text.words))
if synonyms: break
if synonyms:
break
elif attempt == 6:
return [current_text]
random_synonym = random.choice(synonyms)
transformed_texts.append(current_text.insert_text_after_word_index(idx, random_synonym))
transformed_texts.append(
current_text.insert_text_after_word_index(idx, random_synonym)
)
return transformed_texts
def check_if_one_word(word):
for c in word:
if not c.isalpha():

View File

@@ -2,19 +2,20 @@ import random
from textattack.transformations import Transformation
class RandomSwap(Transformation):
    """Transformation that swaps the order of words in a sequence.

    For each index to modify, the word at that index trades places with the
    word at a different, randomly chosen index.
    """

    def _get_transformations(self, current_text, indices_to_modify):
        """Return one word-swapped text per entry of ``indices_to_modify``.

        Args:
            current_text: text object exposing ``words`` and
                ``replace_word_at_index`` (presumably an ``AttackedText`` --
                confirm against callers).
            indices_to_modify: iterable of word indices; each one is swapped
                with a randomly chosen other index.

        Returns:
            A list with one swapped text per index. The list is empty when
            the text has fewer than two words, because there is no second
            position to swap with.
        """
        words = current_text.words
        # A swap needs a partner position. Without this guard the candidate
        # list below is empty for one-word texts and random.choice raises
        # IndexError.
        if len(words) < 2:
            return []

        transformed_texts = []
        for idx in indices_to_modify:
            word = words[idx]
            # Every position except the one being modified is a valid partner.
            candidates = [i for i in range(len(words)) if i != idx]
            swap_idx = random.choice(candidates)
            swapped_text = current_text.replace_word_at_index(
                idx, words[swap_idx]
            ).replace_word_at_index(swap_idx, word)
            transformed_texts.append(swapped_text)
        return transformed_texts