mirror of https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00

update docs & tests
@@ -1,5 +1,5 @@
 language: python
-python: '3.8'
+python: '3.8.2'
 cache: pip
 before_install:
   - python --version

@@ -20,3 +20,4 @@ terminaltables
 tqdm
 visdom
 wandb
+flair

@@ -15,6 +15,7 @@ Attack(
       (include_unknown_words): True
     )
     (1): PartOfSpeech(
+      (tagger_type): nltk
       (tagset): universal
       (allow_verb_noun_swap): True
     )

tests/sample_outputs/run_attack_flair_pos_tagger.txt (new file, 66 lines)
@@ -0,0 +1,66 @@
+/.*/Attack(
+  (search_method): GreedyWordSwapWIR(
+    (wir_method): unk
+  )
+  (goal_function): UntargetedClassification
+  (transformation): WordSwapEmbedding(
+    (max_candidates): 15
+    (embedding_type): paragramcf
+  )
+  (constraints):
+    (0): WordEmbeddingDistance(
+      (embedding_type): paragramcf
+      (min_cos_sim): 0.8
+      (cased): False
+      (include_unknown_words): True
+    )
+    (1): PartOfSpeech(
+      (tagger_type): flair
+      (tagset): universal
+      (allow_verb_noun_swap): True
+    )
+    (2): RepeatModification
+    (3): StopwordModification
+  (is_black_box): True
+)
+
+--------------------------------------------- Result 1 ---------------------------------------------
+[92m1[0m-->[91m[FAILED][0m
+
+this is a film well worth seeing , talking and singing heads and all .
+
+
+--------------------------------------------- Result 2 ---------------------------------------------
+[92m1[0m-->[91m0[0m
+
+what really [92msurprises[0m about wisegirls is its low-[92mkey[0m quality and [92mgenuine[0m [92mtenderness[0m .
+
+what really [91mdumbfounded[0m about wisegirls is its low-[91mvital[0m quality and [91mveritable[0m [91msensibility[0m .
+
+
+--------------------------------------------- Result 3 ---------------------------------------------
+[92m1[0m-->[91m[FAILED][0m
+
+( wendigo is ) why we go to the cinema : to be fed through the eye , the heart , the mind .
+
+
+--------------------------------------------- Result 4 ---------------------------------------------
+[92m1[0m-->[91m[FAILED][0m
+
+one of the greatest family-oriented , fantasy-adventure movies ever .
+
+
+
++-------------------------------+--------+
+| Attack Results                |        |
++-------------------------------+--------+
+| Number of successful attacks: | 1      |
+| Number of failed attacks:     | 3      |
+| Number of skipped attacks:    | 0      |
+| Original accuracy:            | 100.0% |
+| Accuracy under attack:        | 75.0%  |
+| Attack success rate:          | 25.0%  |
+| Average perturbed word %:     | 30.77% |
+| Average num. words per input: | 13.5   |
+| Avg num queries:              | 45.0   |
++-------------------------------+--------+

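The summary table follows directly from the four results above: one success, three failures, no skips. A quick sketch of the arithmetic behind the percentage rows (variable names are illustrative, not TextAttack internals):

    # Recomputing the summary table from the four results above.
    successful, failed, skipped = 1, 3, 0
    total = successful + failed + skipped

    original_accuracy = (successful + failed) / total * 100         # 100.0 - model was right on all 4
    accuracy_under_attack = failed / total * 100                    # 75.0  - only failed attacks stay correct
    attack_success_rate = successful / (successful + failed) * 100  # 25.0
    print(original_accuracy, accuracy_under_attack, attack_success_rate)
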
@@ -17,6 +17,7 @@ Attack(
       (include_unknown_words): True
     )
     (2): PartOfSpeech(
+      (tagger_type): nltk
       (tagset): universal
       (allow_verb_noun_swap): True
     )

@@ -97,7 +97,20 @@ attack_test_params = [
         "tests/sample_outputs/run_attack_targetedclassification2_wordnet_langtool_enable_csv_beamsearch2_attack_n.txt",
     ),
+    #
+    # fmt: off
+    # test: run_attack untargeted classification on BERT MR using word embedding transformation and greedy-word-WIR search
+    # using Flair's part-of-speech tagger as a constraint.
+    #
+    (
+        "run_attack_flair_pos_tagger",
+        (
+            "textattack attack --model bert-base-uncased-mr --search greedy-word-wir --transformation word-swap-embedding "
+            "--constraints repeat stopword embedding:min_cos_sim=0.8 part-of-speech:tagger_type=\\'flair\\' "
+            "--num-examples 4 --num-examples-offset 10"
+        ),
+        "tests/sample_outputs/run_attack_flair_pos_tagger.txt",
+    ),
+    # fmt: on
+    #
 ]

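For context, each attack_test_params entry is a (name, command, sample_output_path) triple. A minimal sketch of how such a triple could drive the regression test, assuming the harness uses the run_command_and_get_result helper imported in the next hunk; the handling of the leading /.*/ marker is an assumption based on the sample file's first line:

    import re

    import pytest
    from helpers import run_command_and_get_result  # imported the same way in the test files below

    @pytest.mark.parametrize("name, command, sample_output_file", attack_test_params)
    def test_attack(name, command, sample_output_file):
        result = run_command_and_get_result(command)
        expected_lines = open(sample_output_file).read().splitlines()
        actual_lines = result.stdout.decode().splitlines()
        for expected, actual in zip(expected_lines, actual_lines):
            if expected.startswith("/.*/"):
                # Assumed convention: a /.*/ prefix lets machine-dependent
                # output vary; only the tail of the line must match.
                assert actual.endswith(expected[len("/.*/"):])
            else:
                assert expected == actual
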
@@ -5,7 +5,7 @@ from helpers import run_command_and_get_result
 list_test_params = [
     (
         "list_augmentation_recipes",
-        "textattack list augmentation_recipes",
+        "textattack list augmentation-recipes",
         "tests/sample_outputs/list_augmentation_recipes.txt",
     )
 ]

@@ -1,3 +1,5 @@
+from flair.data import Sentence
+from flair.models import SequenceTagger
 import lru
 import nltk

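Here lru is the lru-dict package: a fixed-capacity mapping that evicts its least-recently-used entry, and it backs the _pos_tag_cache created in the hunk below. A tiny demonstration:

    import lru

    cache = lru.LRU(2)  # capacity 2; the constraint below uses lru.LRU(2 ** 14)
    cache["a"] = 1
    cache["b"] = 2
    cache["c"] = 3  # a third insert evicts the least-recently-used key, "a"
    print("a" in cache, "c" in cache)  # False True
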
@@ -10,13 +12,24 @@ class PartOfSpeech(Constraint):
     """ Constrains word swaps to only swap words with the same part of speech.
     Uses the NLTK universal part-of-speech tagger by default.
     An implementation of `<https://arxiv.org/abs/1907.11932>`_
     adapted from `<https://github.com/jind11/TextFooler>`_.
+
+    POS tagger from Flair `<https://github.com/flairNLP/flair>`_ is also available.
     """

-    def __init__(self, tagset="universal", allow_verb_noun_swap=True):
+    def __init__(
+        self, tagger_type="nltk", tagset="universal", allow_verb_noun_swap=True
+    ):
+        self.tagger_type = tagger_type
         self.tagset = tagset
         self.allow_verb_noun_swap = allow_verb_noun_swap

         self._pos_tag_cache = lru.LRU(2 ** 14)
+        if tagger_type == "flair":
+            if tagset == "universal":
+                self._flair_pos_tagger = SequenceTagger.load("upos-fast")
+            else:
+                self._flair_pos_tagger = SequenceTagger.load("pos-fast")

     def _can_replace_pos(self, pos_a, pos_b):
         return (pos_a == pos_b) or (

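With this change the tagger backend is chosen at construction time. A minimal sketch of the new constructor in use; the import path is an assumption based on where this file lives in the repo:

    # Hypothetical usage of the updated constraint.
    from textattack.constraints.grammaticality import PartOfSpeech

    # Default: NLTK tagging with the universal tagset.
    nltk_constraint = PartOfSpeech()

    # New in this commit: tag with Flair instead. With tagset="universal"
    # this loads the "upos-fast" SequenceTagger shown above.
    flair_constraint = PartOfSpeech(tagger_type="flair", tagset="universal")
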
@@ -27,11 +40,24 @@ class PartOfSpeech(Constraint):
         context_words = before_ctx + [word] + after_ctx
         context_key = " ".join(context_words)
         if context_key in self._pos_tag_cache:
-            pos_list = self._pos_tag_cache[context_key]
+            word_list, pos_list = self._pos_tag_cache[context_key]
         else:
-            _, pos_list = zip(*nltk.pos_tag(context_words, tagset=self.tagset))
-            self._pos_tag_cache[context_key] = pos_list
-        return pos_list
+            if self.tagger_type == "nltk":
+                word_list, pos_list = zip(
+                    *nltk.pos_tag(context_words, tagset=self.tagset)
+                )
+
+            if self.tagger_type == "flair":
+                word_list, pos_list = zip_flair_result(
+                    self._flair_pos_tagger.predict(context_key)[0]
+                )
+
+            self._pos_tag_cache[context_key] = (word_list, pos_list)
+
+        # idx of `word` in `context_words`
+        idx = len(before_ctx)
+        assert word_list[idx] == word, "POS list not matched with original word list."
+        return pos_list[idx]

     def _check_constraint(self, transformed_text, current_text, original_text=None):
         try:

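As a standalone illustration of what _get_pos now caches per context window: NLTK's universal tagset collapses fine-grained tags into coarse classes (VERB, NOUN, ...), which is what makes verb/noun swap checks possible. A minimal sketch; the nltk.download calls fetch the standard tagger resources:

    import nltk

    # One-time resource downloads for the default tagger and the universal mapping.
    nltk.download("averaged_perceptron_tagger")
    nltk.download("universal_tagset")

    context_words = ["talking", "and", "singing", "heads"]
    print(nltk.pos_tag(context_words, tagset="universal"))
    # e.g. [('talking', 'VERB'), ('and', 'CONJ'), ('singing', 'VERB'), ('heads', 'NOUN')]
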
@@ -45,7 +71,7 @@ class PartOfSpeech(Constraint):
             current_word = current_text.words[i]
             transformed_word = transformed_text.words[i]
             before_ctx = current_text.words[max(i - 4, 0) : i]
-            after_ctx = current_text.words[i + 1 : min(i + 5, len(current_text.words))]
+            after_ctx = current_text.words[i + 1 : min(i + 4, len(current_text.words))]
             cur_pos = self._get_pos(before_ctx, current_word, after_ctx)
             replace_pos = self._get_pos(before_ctx, transformed_word, after_ctx)
             if not self._can_replace_pos(cur_pos, replace_pos):

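Note the behavioral effect of the one-character change above: since slice ends are exclusive, the after-context shrinks from four words to three, so the tagger now sees up to four words before and three after the swapped word. A small sketch with sample data:

    # Illustrating the slice change: `i + 4` (exclusive) yields 3 words of
    # after-context where `i + 5` yielded 4.
    words = "what really surprises about wisegirls is its low-key quality".split()
    i = 2  # position of the swapped word, "surprises"

    before_ctx = words[max(i - 4, 0) : i]              # up to 4 words before
    after_ctx = words[i + 1 : min(i + 4, len(words))]  # now up to 3 words after
    print(before_ctx)  # ['what', 'really']
    print(after_ctx)   # ['about', 'wisegirls', 'is']
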
@@ -57,4 +83,18 @@ class PartOfSpeech(Constraint):
         return transformation_consists_of_word_swaps(transformation)

     def extra_repr_keys(self):
-        return ["tagset", "allow_verb_noun_swap"]
+        return ["tagger_type", "tagset", "allow_verb_noun_swap"]
+
+
+def zip_flair_result(pred):
+    if not isinstance(pred, Sentence):
+        raise TypeError("Result from Flair POS tagger must be a `Sentence` object.")
+
+    tokens = pred.tokens
+    word_list = []
+    pos_list = []
+    for token in tokens:
+        word_list.append(token.text)
+        pos_list.append(token.annotation_layers["pos"][0]._value)
+
+    return word_list, pos_list
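
A sketch of the new helper in use, mirroring the _get_pos call above. Indexing into predict(...)'s return value follows this diff's own usage and assumes the Flair version pinned at the time, whose predict accepted raw text and returned the tagged sentences (newer Flair annotates a Sentence in place and returns None):

    from flair.models import SequenceTagger

    tagger = SequenceTagger.load("upos-fast")
    # Same call shape as `self._flair_pos_tagger.predict(context_key)[0]` above.
    tagged_sentence = tagger.predict("talking and singing heads")[0]

    word_list, pos_list = zip_flair_result(tagged_sentence)
    print(word_list)  # ['talking', 'and', 'singing', 'heads']
    print(pos_list)   # e.g. ['VERB', 'CCONJ', 'VERB', 'NOUN']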