1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

update docs & tests

This commit is contained in:
Jack Morris
2020-06-20 18:14:14 -04:00
parent e4a4f54202
commit 98bf92e1ce
8 changed files with 132 additions and 10 deletions

View File

@@ -1,5 +1,5 @@
language: python
python: '3.8'
python: '3.8.2'
cache: pip
before_install:
- python --version

View File

@@ -20,3 +20,4 @@ terminaltables
tqdm
visdom
wandb
flair

View File

@@ -15,6 +15,7 @@ Attack(
(include_unknown_words): True
)
(1): PartOfSpeech(
(tagger_type): nltk
(tagset): universal
(allow_verb_noun_swap): True
)

View File

@@ -0,0 +1,66 @@
/.*/Attack(
(search_method): GreedyWordSwapWIR(
(wir_method): unk
)
(goal_function): UntargetedClassification
(transformation): WordSwapEmbedding(
(max_candidates): 15
(embedding_type): paragramcf
)
(constraints):
(0): WordEmbeddingDistance(
(embedding_type): paragramcf
(min_cos_sim): 0.8
(cased): False
(include_unknown_words): True
)
(1): PartOfSpeech(
(tagger_type): flair
(tagset): universal
(allow_verb_noun_swap): True
)
(2): RepeatModification
(3): StopwordModification
(is_black_box): True
)
--------------------------------------------- Result 1 ---------------------------------------------
1-->[FAILED]
this is a film well worth seeing , talking and singing heads and all .
--------------------------------------------- Result 2 ---------------------------------------------
1-->0
what really surprises about wisegirls is its low-key quality and genuine tenderness .
what really dumbfounded about wisegirls is its low-vital quality and veritable sensibility .
--------------------------------------------- Result 3 ---------------------------------------------
1-->[FAILED]
( wendigo is ) why we go to the cinema : to be fed through the eye , the heart , the mind .
--------------------------------------------- Result 4 ---------------------------------------------
1-->[FAILED]
one of the greatest family-oriented , fantasy-adventure movies ever .
+-------------------------------+--------+
| Attack Results | |
+-------------------------------+--------+
| Number of successful attacks: | 1 |
| Number of failed attacks: | 3 |
| Number of skipped attacks: | 0 |
| Original accuracy: | 100.0% |
| Accuracy under attack: | 75.0% |
| Attack success rate: | 25.0% |
| Average perturbed word %: | 30.77% |
| Average num. words per input: | 13.5 |
| Avg num queries: | 45.0 |
+-------------------------------+--------+

View File

@@ -17,6 +17,7 @@ Attack(
(include_unknown_words): True
)
(2): PartOfSpeech(
(tagger_type): nltk
(tagset): universal
(allow_verb_noun_swap): True
)

View File

@@ -97,7 +97,20 @@ attack_test_params = [
"tests/sample_outputs/run_attack_targetedclassification2_wordnet_langtool_enable_csv_beamsearch2_attack_n.txt",
),
#
# fmt: off
# test: run_attack untargeted classification on BERT MR using word embedding transformation and greedy-word-WIR search
# using Flair's part-of-speech tagger as constraint.
#
(
"run_attack_flair_pos_tagger",
(
"textattack attack --model bert-base-uncased-mr --search greedy-word-wir --transformation word-swap-embedding "
"--constraints repeat stopword embedding:min_cos_sim=0.8 part-of-speech:tagger_type=\\'flair\\' "
"--num-examples 4 --num-examples-offset 10"
),
"tests/sample_outputs/run_attack_flair_pos_tagger.txt",
),
# fmt: on
#
]

View File

@@ -5,7 +5,7 @@ from helpers import run_command_and_get_result
list_test_params = [
(
"list_augmentation_recipes",
"textattack list augmentation_recipes",
"textattack list augmentation-recipes",
"tests/sample_outputs/list_augmentation_recipes.txt",
)
]

View File

@@ -1,3 +1,5 @@
from flair.data import Sentence
from flair.models import SequenceTagger
import lru
import nltk
@@ -10,13 +12,24 @@ class PartOfSpeech(Constraint):
""" Constraints word swaps to only swap words with the same part of speech.
Uses the NLTK universal part-of-speech tagger by default.
An implementation of `<https://arxiv.org/abs/1907.11932>`_
adapted from `<https://github.com/jind11/TextFooler>`_.
adapted from `<https://github.com/jind11/TextFooler>`_.
A POS tagger from Flair `<https://github.com/flairNLP/flair>`_ is also available.
"""
def __init__(self, tagset="universal", allow_verb_noun_swap=True):
# Configure the part-of-speech constraint.
#
# tagger_type: which POS tagger backs the constraint. "nltk" is the default;
#     "flair" additionally loads a Flair SequenceTagger below.
# tagset: tag inventory to use. "universal" selects Flair's "upos-fast" model;
#     any other value selects "pos-fast". The value is also stored on the
#     instance for later use.
# allow_verb_noun_swap: stored as-is on the instance.
def __init__(
self, tagger_type="nltk", tagset="universal", allow_verb_noun_swap=True
):
self.tagger_type = tagger_type
self.tagset = tagset
self.allow_verb_noun_swap = allow_verb_noun_swap
# LRU cache constructed with size 2 ** 14; holds previously computed tagging results.
self._pos_tag_cache = lru.LRU(2 ** 14)
# The Flair model is only loaded when explicitly requested.
# NOTE(review): a tagger_type other than "nltk"/"flair" is accepted here
# without error -- confirm whether it should be rejected up front.
if tagger_type == "flair":
if tagset == "universal":
self._flair_pos_tagger = SequenceTagger.load("upos-fast")
else:
self._flair_pos_tagger = SequenceTagger.load("pos-fast")
def _can_replace_pos(self, pos_a, pos_b):
return (pos_a == pos_b) or (
@@ -27,11 +40,24 @@ class PartOfSpeech(Constraint):
context_words = before_ctx + [word] + after_ctx
context_key = " ".join(context_words)
if context_key in self._pos_tag_cache:
pos_list = self._pos_tag_cache[context_key]
word_list, pos_list = self._pos_tag_cache[context_key]
else:
_, pos_list = zip(*nltk.pos_tag(context_words, tagset=self.tagset))
self._pos_tag_cache[context_key] = pos_list
return pos_list
if self.tagger_type == "nltk":
word_list, pos_list = zip(
*nltk.pos_tag(context_words, tagset=self.tagset)
)
if self.tagger_type == "flair":
word_list, pos_list = zip_flair_result(
self._flair_pos_tagger.predict(context_key)[0]
)
self._pos_tag_cache[context_key] = (word_list, pos_list)
# idx of `word` in `context_words`
idx = len(before_ctx)
assert word_list[idx] == word, "POS list not matched with original word list."
return pos_list[idx]
def _check_constraint(self, transformed_text, current_text, original_text=None):
try:
@@ -45,7 +71,7 @@ class PartOfSpeech(Constraint):
current_word = current_text.words[i]
transformed_word = transformed_text.words[i]
before_ctx = current_text.words[max(i - 4, 0) : i]
after_ctx = current_text.words[i + 1 : min(i + 5, len(current_text.words))]
after_ctx = current_text.words[i + 1 : min(i + 4, len(current_text.words))]
cur_pos = self._get_pos(before_ctx, current_word, after_ctx)
replace_pos = self._get_pos(before_ctx, transformed_word, after_ctx)
if not self._can_replace_pos(cur_pos, replace_pos):
@@ -57,4 +83,18 @@ class PartOfSpeech(Constraint):
return transformation_consists_of_word_swaps(transformation)
def extra_repr_keys(self):
return ["tagset", "allow_verb_noun_swap"]
return ["tagger_type", "tagset", "allow_verb_noun_swap"]
def zip_flair_result(pred):
    """Split a Flair-tagged sentence into parallel word and POS-tag lists.

    Args:
        pred: A Flair ``Sentence`` whose tokens carry a ``"pos"`` annotation
            layer.

    Returns:
        A ``(word_list, pos_list)`` tuple of two equal-length lists:
        ``word_list[i]`` is the text of the i-th token and ``pos_list[i]``
        is the value of the first label in that token's ``"pos"`` layer.

    Raises:
        TypeError: If ``pred`` is not a ``Sentence``.
    """
    if not isinstance(pred, Sentence):
        raise TypeError("Result from Flair POS tagger must be a `Sentence` object.")

    tokens = pred.tokens
    word_list = [token.text for token in tokens]
    # Read the label through its public `value` property instead of the
    # private `_value` attribute.
    pos_list = [token.annotation_layers["pos"][0].value for token in tokens]
    return word_list, pos_list