mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
Merge pull request #477 from cogeid/Fix-RandomSwap-and-RandomSynonymInsertion-bug
Fix-RandomSwap-and-RandomSynonymInsertion-bug pr 368 to pass pytest
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
"""
|
||||
Augmenter Class
|
||||
===================
|
||||
|
||||
"""
|
||||
import random
|
||||
|
||||
@@ -103,7 +102,10 @@ class Augmenter:
|
||||
current_text = random.choice(transformed_texts)
|
||||
|
||||
# update words_swapped based on modified indices
|
||||
words_swapped = len(current_text.attack_attrs["modified_indices"])
|
||||
words_swapped = max(
|
||||
len(current_text.attack_attrs["modified_indices"]),
|
||||
words_swapped + 1,
|
||||
)
|
||||
all_transformed_texts.add(current_text)
|
||||
return sorted([at.printable_text() for at in all_transformed_texts])
|
||||
|
||||
@@ -113,7 +115,6 @@ class Augmenter:
|
||||
|
||||
Args:
|
||||
text_list (list(string)): a list of strings for data augmentation
|
||||
|
||||
Returns a list(string) of augmented texts.
|
||||
"""
|
||||
if show_progress:
|
||||
|
||||
@@ -67,6 +67,7 @@ class EasyDataAugmenter(Augmenter):
|
||||
augmented_text += self.random_deletion.augment(text)
|
||||
augmented_text += self.random_swap.augment(text)
|
||||
augmented_text += self.random_insertion.augment(text)
|
||||
augmented_text = list(set(augmented_text))
|
||||
random.shuffle(augmented_text)
|
||||
return augmented_text[: self.transformations_per_example]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user