1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

Flair ner to AttackedText

This commit is contained in:
Hanyu Liu
2020-09-09 19:37:55 -04:00
parent 9d0a668210
commit 7062355a3a
5 changed files with 61 additions and 62 deletions

View File

@@ -48,6 +48,7 @@ class AttackedText:
self._words = None
self._words_per_input = None
self._pos_tags = None
self._ner_tags = None
# Format text inputs.
self._text_input = OrderedDict([(k, v) for k, v in self._text_input.items()])
if attack_attrs is None:
@@ -146,6 +147,31 @@ class AttackedText:
f"Did not find word from index {desired_word_idx} in flair POS tag"
)
def ner_of_word_index(self, desired_word_idx):
    """Returns the named-entity (NER) tag of the word at index
    ``desired_word_idx``.

    Uses the FLAIR ``"ner"`` tagger. The tagged sentence is cached on
    ``self._ner_tags`` so the tagger runs at most once per instance.

    Args:
        desired_word_idx: Position of the token whose tag is wanted.

    Returns:
        The entry that ``zip_flair_result(..., "ner")`` produced for that
        position (callers in this commit read ``.value`` and ``.score``
        from it).

    Raises:
        ValueError: If no token exists at ``desired_word_idx``.
    """
    # `is None` rather than truthiness: an already-tagged sentence must not
    # be re-tagged just because it evaluates as falsy.
    if self._ner_tags is None:
        sentence = Sentence(self.text)
        textattack.shared.utils.flair_tag(sentence, "ner")
        self._ner_tags = sentence
    _, flair_ner_list = textattack.shared.utils.zip_flair_result(
        self._ner_tags, "ner"
    )

    # NOTE(review): this is a positional lookup into flair's token list, as
    # the original loop was. It presumably assumes flair tokens line up
    # one-to-one with `self.words`; confirm for text where flair emits
    # punctuation as separate tokens.
    for token_idx, ner_tag in enumerate(flair_ner_list):
        if token_idx == desired_word_idx:
            return ner_tag

    raise ValueError(
        f"Did not find word from index {desired_word_idx} in flair NER tag"
    )
def _text_index_of_word_index(self, i):
"""Returns the index of word ``i`` in self.text."""
pre_words = self.words[: i + 1]
@@ -231,6 +257,7 @@ class AttackedText:
raise TypeError(
f"convert_from_original_idxs got invalid idxs type {type(idxs)}"
)
# Extend original_index_map when the AttackedText has grown, so that it covers every requested index.
while len(self.attack_attrs["original_index_map"]) < len(idxs):
missing_index = self.attack_attrs["original_index_map"][-1] + 1

View File

@@ -168,17 +168,17 @@ def color_text(text, color=None, method=None):
_flair_pos_tagger = None
# Loaded flair taggers, keyed by the model name passed to `flair_tag`.
_flair_taggers = {}


def flair_tag(sentence, tag_type="pos-fast"):
    """Tags a `Sentence` object in place using a `flair` sequence tagger.

    Args:
        sentence: The `flair` `Sentence` to annotate; it is handed to the
            tagger's ``predict`` method.
        tag_type: Name of the model given to ``SequenceTagger.load``.
            Defaults to ``"pos-fast"`` (part-of-speech); ``"ner"`` is the
            other value used in this codebase.

    Each model is loaded on first use and cached per ``tag_type``. A single
    shared tagger would make the first model loaded serve every later
    request, so a "ner" call after a "pos-fast" call would get POS tags.
    """
    tagger = _flair_taggers.get(tag_type)
    if tagger is None:
        # Imported lazily so `flair` is only loaded when tagging is needed.
        from flair.models import SequenceTagger

        tagger = SequenceTagger.load(tag_type)
        _flair_taggers[tag_type] = tagger
    tagger.predict(sentence)
def zip_flair_result(pred):
def zip_flair_result(pred, tag_type="pos-fast"):
"""Takes a sentence tagging from `flair` and returns two lists, of words
and their corresponding parts-of-speech."""
from flair.data import Sentence
@@ -191,6 +191,9 @@ def zip_flair_result(pred):
pos_list = []
for token in tokens:
word_list.append(token.text)
pos_list.append(token.annotation_layers["pos"][0]._value)
if tag_type == "pos-fast":
pos_list.append(token.annotation_layers["pos"][0]._value)
elif tag_type == "ner":
pos_list.append(token.get_tag("ner"))
return word_list, pos_list

View File

@@ -55,23 +55,13 @@ class WordSwapChangeLocation(Transformation):
def _get_transformations(self, current_text, indices_to_modify):
words = current_text.words
# TODO: move ner recognition to AttackedText
# really want to silent this line:
tagger = SequenceTagger.load("ner")
sentence = Sentence(current_text.text)
tagger.predict(sentence)
location_idx = []
# pre-screen for actual locations, using flair
# summarize location idx into a list (location_idx)
for token in sentence:
tag = token.get_tag("ner")
if (
"LOC" in tag.value
and tag.score > self.confidence_score
and (token.idx - 1) in indices_to_modify
):
location_idx.append(token.idx - 1)
for i in indices_to_modify:
word_to_replace = current_text.words[i]
tag = current_text.ner_of_word_index(i)
if "LOC" in tag.value and tag.score > self.confidence_score:
location_idx.append(i)
# Combine location idx and words to a list ([0] is idx, [1] is location name)
# For example, [1,2] to [ [1,2] , ["New York"] ]
@@ -98,6 +88,7 @@ class WordSwapChangeLocation(Transformation):
text = text.replace_word_at_index(idx[0], r)
transformed_texts.append(text)
print(transformed_texts)
return transformed_texts
def _get_new_location(self, word):

View File

@@ -3,10 +3,10 @@ from flair.models import SequenceTagger
import numpy as np
from textattack.shared.data import PERSON_NAMES
from textattack.transformations import Transformation
from textattack.transformations import WordSwap
class WordSwapChangeName(Transformation):
class WordSwapChangeName(WordSwap):
def __init__(
self, n=3, first_only=False, last_only=False, confidence_score=0.7, **kwargs
):
@@ -24,54 +24,31 @@ class WordSwapChangeName(Transformation):
self.confidence_score = confidence_score
def _get_transformations(self, current_text, indices_to_modify):
# really want to silent this line:
tagger = SequenceTagger.load("ner")
# TODO: move ner recognition to AttackedText
sentence = Sentence(current_text.text)
tagger.predict(sentence)
fir_name_idx = []
last_name_idx = []
# use flair to screen for actual names and eliminate false-positives
for token in sentence:
tag = token.get_tag("ner")
if tag.value == "B-PER" and tag.score > self.confidence_score:
fir_name_idx.append(token.idx - 1)
elif tag.value == "E-PER" and tag.score > self.confidence_score:
last_name_idx.append(token.idx - 1)
words = current_text.words
transformed_texts = []
for i in indices_to_modify:
word_to_replace = words[i]
word_to_replace = current_text.words[i]
word_to_replace_ner = current_text.ner_of_word_index(i)
replacement_words = self._get_replacement_words(
word_to_replace, word_to_replace_ner
)
for r in replacement_words:
transformed_texts.append(current_text.replace_word_at_index(i, r))
# search for first name replacement
if i in fir_name_idx and not self.last_only:
replacement_words = self._get_firstname(word_to_replace)
transformed_texts_idx = []
for r in replacement_words:
if r == word_to_replace:
continue
transformed_texts_idx.append(
current_text.replace_word_at_index(i, r)
)
transformed_texts.extend(transformed_texts_idx)
# print(transformed_texts)
# search for last name replacement
elif i in last_name_idx and not self.first_only:
replacement_words = self._get_lastname(word_to_replace)
transformed_texts_idx = []
for r in replacement_words:
if r == word_to_replace:
continue
transformed_texts_idx.append(
current_text.replace_word_at_index(i, r)
)
transformed_texts.extend(transformed_texts_idx)
return transformed_texts
def _get_replacement_words(self, word, word_part_of_speech):
replacement_words = []
tag = word_part_of_speech
if tag.value == "B-PER" and tag.score > self.confidence_score:
replacement_words = self._get_firstname(word)
elif tag.value == "E-PER" and tag.score > self.confidence_score:
replacement_words = self._get_lastname(word)
return replacement_words
def _get_lastname(self, word):
    """Draw ``self.n`` random last names from ``PERSON_NAMES["last"]``.

    ``word`` is accepted but not used in the draw.
    """
    surname_pool = PERSON_NAMES["last"]
    how_many = self.n
    return np.random.choice(surname_pool, how_many)

View File

@@ -19,6 +19,7 @@ class WordSwapInflections(WordSwap):
self._flair_to_lemminflect_pos_map = {"NN": "NOUN", "VB": "VERB", "JJ": "ADJ"}
def _get_replacement_words(self, word, word_part_of_speech):
if word_part_of_speech not in self._flair_to_lemminflect_pos_map:
# Only nouns, verbs, and adjectives have proper inflections.
return []