mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
eda tests passing
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
[94mcharswap[0m (textattack.augmentation.CharSwapAugmenter)
|
||||
[94meda[0m (textattack.augmentation.EasyDataAugmenter)
|
||||
[94membedding[0m (textattack.augmentation.EmbeddingAugmenter)
|
||||
[94mwordnet[0m (textattack.augmentation.WordNetAugmenter)
|
||||
|
||||
@@ -1,2 +1,7 @@
|
||||
from .augmenter import Augmenter
|
||||
from .recipes import WordNetAugmenter, EmbeddingAugmenter, CharSwapAugmenter, EasyDataAugmenter
|
||||
from .recipes import (
|
||||
WordNetAugmenter,
|
||||
EmbeddingAugmenter,
|
||||
CharSwapAugmenter,
|
||||
EasyDataAugmenter,
|
||||
)
|
||||
|
||||
@@ -10,7 +10,6 @@ DEFAULT_CONSTRAINTS = [
|
||||
]
|
||||
|
||||
|
||||
|
||||
class EasyDataAugmenter(Augmenter):
|
||||
"""
|
||||
|
||||
@@ -36,10 +35,14 @@ class EasyDataAugmenter(Augmenter):
|
||||
self.transformations_per_example = n_aug
|
||||
n_aug_each = max(n_aug // 4, 1)
|
||||
|
||||
self.synonym_replacement = WordNetAugmenter(transformations_per_example=n_aug_each)
|
||||
self.synonym_replacement = WordNetAugmenter(
|
||||
transformations_per_example=n_aug_each
|
||||
)
|
||||
self.random_deletion = DeletionAugmenter(transformations_per_example=n_aug_each)
|
||||
self.random_swap = SwapAugmenter(transformations_per_example=n_aug_each)
|
||||
self.random_insertion = SynonymInsertionAugmenter(transformations_per_example=n_aug_each)
|
||||
self.random_insertion = SynonymInsertionAugmenter(
|
||||
transformations_per_example=n_aug_each
|
||||
)
|
||||
|
||||
def _set_words_to_swap(self, num):
|
||||
self.synonym_replacement.num_words_to_swap = num
|
||||
@@ -49,9 +52,9 @@ class EasyDataAugmenter(Augmenter):
|
||||
|
||||
def augment(self, text):
|
||||
attacked_text = textattack.shared.AttackedText(text)
|
||||
num_words_to_swap = max(1, int(self.alpha*len(attacked_text.words)))
|
||||
num_words_to_swap = max(1, int(self.alpha * len(attacked_text.words)))
|
||||
self._set_words_to_swap(num_words_to_swap)
|
||||
|
||||
|
||||
augmented_text = [attacked_text.printable_text()]
|
||||
augmented_text += self.synonym_replacement.augment(text)
|
||||
augmented_text += self.random_deletion.augment(text)
|
||||
@@ -59,17 +62,21 @@ class EasyDataAugmenter(Augmenter):
|
||||
augmented_text += self.random_insertion.augment(text)
|
||||
|
||||
random.shuffle(augmented_text)
|
||||
return augmented_text[:self.transformations_per_example]
|
||||
return augmented_text[: self.transformations_per_example]
|
||||
|
||||
|
||||
class SwapAugmenter(Augmenter):
|
||||
def __init__(self, **kwargs):
|
||||
from textattack.transformations import RandomSwap
|
||||
|
||||
transformation = RandomSwap()
|
||||
super().__init__(transformation, constraints=DEFAULT_CONSTRAINTS, **kwargs)
|
||||
|
||||
|
||||
class SynonymInsertionAugmenter(Augmenter):
|
||||
def __init__(self, **kwargs):
|
||||
from textattack.transformations import RandomSynonymInsertion
|
||||
|
||||
transformation = RandomSynonymInsertion()
|
||||
super().__init__(transformation, constraints=DEFAULT_CONSTRAINTS, **kwargs)
|
||||
|
||||
@@ -83,9 +90,11 @@ class WordNetAugmenter(Augmenter):
|
||||
transformation = WordSwapWordNet()
|
||||
super().__init__(transformation, constraints=DEFAULT_CONSTRAINTS, **kwargs)
|
||||
|
||||
|
||||
class DeletionAugmenter(Augmenter):
|
||||
def __init__(self, **kwargs):
|
||||
from textattack.transformations import WordDeletion
|
||||
|
||||
transformation = WordDeletion()
|
||||
super().__init__(transformation, constraints=DEFAULT_CONSTRAINTS, **kwargs)
|
||||
|
||||
|
||||
@@ -57,10 +57,9 @@ class AugmentCommand(TextAttackCommand):
|
||||
f"Read {len(rows)} rows from {args.csv}. Found columns {row_keys}."
|
||||
)
|
||||
|
||||
if args.recipe == 'eda':
|
||||
if args.recipe == "eda":
|
||||
augmenter = eval(AUGMENTATION_RECIPE_NAMES[args.recipe])(
|
||||
alpha=args.alpha,
|
||||
n_aug=args.transformations_per_example,
|
||||
alpha=args.alpha, n_aug=args.transformations_per_example,
|
||||
)
|
||||
|
||||
else:
|
||||
@@ -130,7 +129,7 @@ class AugmentCommand(TextAttackCommand):
|
||||
"--a",
|
||||
help="fraction of words to modify (EasyDataAugmenter)",
|
||||
type=float,
|
||||
default=.1,
|
||||
default=0.1,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
|
||||
@@ -209,13 +209,15 @@ class AttackedText:
|
||||
raise ValueError(f"Cannot assign word at index {i}")
|
||||
words[i] = new_word
|
||||
return self.generate_new_attacked_text(words)
|
||||
|
||||
|
||||
def replace_word_at_index(self, index, new_word):
|
||||
""" This code returns a new AttackedText object where the word at
|
||||
``index`` is replaced with a new word.
|
||||
"""
|
||||
if not isinstance(new_word, str):
|
||||
raise TypeError(f"replace_word_at_index requires ``str`` new_word, got {type(new_word)}")
|
||||
raise TypeError(
|
||||
f"replace_word_at_index requires ``str`` new_word, got {type(new_word)}"
|
||||
)
|
||||
return self.replace_words_at_indices([index], [new_word])
|
||||
|
||||
def delete_word_at_index(self, index):
|
||||
|
||||
@@ -4,10 +4,12 @@ from nltk.corpus import wordnet
|
||||
|
||||
from textattack.transformations import Transformation
|
||||
|
||||
|
||||
class RandomSynonymInsertion(Transformation):
|
||||
"""
|
||||
Transformation that inserts synonyms of words that are already in the sequence.
|
||||
"""
|
||||
|
||||
def _get_synonyms(self, word):
|
||||
synonyms = set()
|
||||
for syn in wordnet.synsets(word):
|
||||
@@ -20,13 +22,20 @@ class RandomSynonymInsertion(Transformation):
|
||||
transformed_texts = []
|
||||
for idx in indices_to_modify:
|
||||
synonyms = []
|
||||
# try to find a word with synonyms, and deal with edge case where there aren't any
|
||||
for attempt in range(7):
|
||||
synonyms = self._get_synonyms(random.choice(current_text.words))
|
||||
if synonyms: break
|
||||
if synonyms:
|
||||
break
|
||||
elif attempt == 6:
|
||||
return [current_text]
|
||||
random_synonym = random.choice(synonyms)
|
||||
transformed_texts.append(current_text.insert_text_after_word_index(idx, random_synonym))
|
||||
transformed_texts.append(
|
||||
current_text.insert_text_after_word_index(idx, random_synonym)
|
||||
)
|
||||
return transformed_texts
|
||||
|
||||
|
||||
def check_if_one_word(word):
|
||||
for c in word:
|
||||
if not c.isalpha():
|
||||
|
||||
@@ -2,19 +2,20 @@ import random
|
||||
|
||||
from textattack.transformations import Transformation
|
||||
|
||||
|
||||
class RandomSwap(Transformation):
|
||||
"""
|
||||
Transformation that swaps the order of words in a sequnce.
|
||||
"""
|
||||
|
||||
def _get_transformations(self, current_text, indices_to_modify):
|
||||
transformed_texts = []
|
||||
words = current_text.words
|
||||
for idx in indices_to_modify:
|
||||
word = words[idx]
|
||||
swap_idx = random.choice(list(set(range(len(words))) - {idx}))
|
||||
swapped_text = current_text.replace_word_at_index(idx, words[swap_idx]).replace_word_at_index(swap_idx, word)
|
||||
swapped_text = current_text.replace_word_at_index(
|
||||
idx, words[swap_idx]
|
||||
).replace_word_at_index(swap_idx, word)
|
||||
transformed_texts.append(swapped_text)
|
||||
return transformed_texts
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user