mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
Flair ner to AttackedText
This commit is contained in:
@@ -48,6 +48,7 @@ class AttackedText:
|
||||
self._words = None
|
||||
self._words_per_input = None
|
||||
self._pos_tags = None
|
||||
self._ner_tags = None
|
||||
# Format text inputs.
|
||||
self._text_input = OrderedDict([(k, v) for k, v in self._text_input.items()])
|
||||
if attack_attrs is None:
|
||||
@@ -146,6 +147,31 @@ class AttackedText:
|
||||
f"Did not find word from index {desired_word_idx} in flair POS tag"
|
||||
)
|
||||
|
||||
def ner_of_word_index(self, desired_word_idx):
|
||||
"""Returns the ner tag of the word at index `word_idx`.
|
||||
|
||||
Uses FLAIR ner tagger.
|
||||
"""
|
||||
if not self._ner_tags:
|
||||
sentence = Sentence(self.text)
|
||||
textattack.shared.utils.flair_tag(sentence, "ner")
|
||||
self._ner_tags = sentence
|
||||
flair_word_list, flair_ner_list = textattack.shared.utils.zip_flair_result(
|
||||
self._ner_tags, "ner"
|
||||
)
|
||||
|
||||
for word_idx, word in enumerate(flair_word_list):
|
||||
word_idx_in_flair_tags = flair_word_list.index(word)
|
||||
if word_idx == desired_word_idx:
|
||||
return flair_ner_list[word_idx_in_flair_tags]
|
||||
else:
|
||||
flair_word_list = flair_word_list[word_idx_in_flair_tags + 1 :]
|
||||
flair_ner_list = flair_ner_list[word_idx_in_flair_tags + 1 :]
|
||||
|
||||
raise ValueError(
|
||||
f"Did not find word from index {desired_word_idx} in flair POS tag"
|
||||
)
|
||||
|
||||
def _text_index_of_word_index(self, i):
|
||||
"""Returns the index of word ``i`` in self.text."""
|
||||
pre_words = self.words[: i + 1]
|
||||
@@ -231,6 +257,7 @@ class AttackedText:
|
||||
raise TypeError(
|
||||
f"convert_from_original_idxs got invalid idxs type {type(idxs)}"
|
||||
)
|
||||
|
||||
# update length of original_index_map, when the length of AttackedText increased
|
||||
while len(self.attack_attrs["original_index_map"]) < len(idxs):
|
||||
missing_index = self.attack_attrs["original_index_map"][-1] + 1
|
||||
|
||||
@@ -168,17 +168,17 @@ def color_text(text, color=None, method=None):
|
||||
_flair_pos_tagger = None
|
||||
|
||||
|
||||
def flair_tag(sentence):
|
||||
def flair_tag(sentence, tag_type="pos-fast"):
|
||||
"""Tags a `Sentence` object using `flair` part-of-speech tagger."""
|
||||
global _flair_pos_tagger
|
||||
if not _flair_pos_tagger:
|
||||
from flair.models import SequenceTagger
|
||||
|
||||
_flair_pos_tagger = SequenceTagger.load("pos-fast")
|
||||
_flair_pos_tagger = SequenceTagger.load(tag_type)
|
||||
_flair_pos_tagger.predict(sentence)
|
||||
|
||||
|
||||
def zip_flair_result(pred):
|
||||
def zip_flair_result(pred, tag_type="pos-fast"):
|
||||
"""Takes a sentence tagging from `flair` and returns two lists, of words
|
||||
and their corresponding parts-of-speech."""
|
||||
from flair.data import Sentence
|
||||
@@ -191,6 +191,9 @@ def zip_flair_result(pred):
|
||||
pos_list = []
|
||||
for token in tokens:
|
||||
word_list.append(token.text)
|
||||
pos_list.append(token.annotation_layers["pos"][0]._value)
|
||||
if tag_type == "pos-fast":
|
||||
pos_list.append(token.annotation_layers["pos"][0]._value)
|
||||
elif tag_type == "ner":
|
||||
pos_list.append(token.get_tag("ner"))
|
||||
|
||||
return word_list, pos_list
|
||||
|
||||
@@ -55,23 +55,13 @@ class WordSwapChangeLocation(Transformation):
|
||||
|
||||
def _get_transformations(self, current_text, indices_to_modify):
|
||||
words = current_text.words
|
||||
# TODO: move ner recognition to AttackedText
|
||||
# really want to silent this line:
|
||||
tagger = SequenceTagger.load("ner")
|
||||
sentence = Sentence(current_text.text)
|
||||
tagger.predict(sentence)
|
||||
location_idx = []
|
||||
|
||||
# pre-screen for actual locations, using flair
|
||||
# summarize location idx into a list (location_idx)
|
||||
for token in sentence:
|
||||
tag = token.get_tag("ner")
|
||||
if (
|
||||
"LOC" in tag.value
|
||||
and tag.score > self.confidence_score
|
||||
and (token.idx - 1) in indices_to_modify
|
||||
):
|
||||
location_idx.append(token.idx - 1)
|
||||
for i in indices_to_modify:
|
||||
word_to_replace = current_text.words[i]
|
||||
tag = current_text.ner_of_word_index(i)
|
||||
if "LOC" in tag.value and tag.score > self.confidence_score:
|
||||
location_idx.append(i)
|
||||
|
||||
# Combine location idx and words to a list ([0] is idx, [1] is location name)
|
||||
# For example, [1,2] to [ [1,2] , ["New York"] ]
|
||||
@@ -98,6 +88,7 @@ class WordSwapChangeLocation(Transformation):
|
||||
text = text.replace_word_at_index(idx[0], r)
|
||||
|
||||
transformed_texts.append(text)
|
||||
print(transformed_texts)
|
||||
return transformed_texts
|
||||
|
||||
def _get_new_location(self, word):
|
||||
|
||||
@@ -3,10 +3,10 @@ from flair.models import SequenceTagger
|
||||
import numpy as np
|
||||
|
||||
from textattack.shared.data import PERSON_NAMES
|
||||
from textattack.transformations import Transformation
|
||||
from textattack.transformations import WordSwap
|
||||
|
||||
|
||||
class WordSwapChangeName(Transformation):
|
||||
class WordSwapChangeName(WordSwap):
|
||||
def __init__(
|
||||
self, n=3, first_only=False, last_only=False, confidence_score=0.7, **kwargs
|
||||
):
|
||||
@@ -24,54 +24,31 @@ class WordSwapChangeName(Transformation):
|
||||
self.confidence_score = confidence_score
|
||||
|
||||
def _get_transformations(self, current_text, indices_to_modify):
|
||||
# really want to silent this line:
|
||||
tagger = SequenceTagger.load("ner")
|
||||
|
||||
# TODO: move ner recognition to AttackedText
|
||||
sentence = Sentence(current_text.text)
|
||||
tagger.predict(sentence)
|
||||
fir_name_idx = []
|
||||
last_name_idx = []
|
||||
|
||||
# use flair to screen for actual names and eliminate false-positives
|
||||
for token in sentence:
|
||||
tag = token.get_tag("ner")
|
||||
if tag.value == "B-PER" and tag.score > self.confidence_score:
|
||||
fir_name_idx.append(token.idx - 1)
|
||||
elif tag.value == "E-PER" and tag.score > self.confidence_score:
|
||||
last_name_idx.append(token.idx - 1)
|
||||
|
||||
words = current_text.words
|
||||
transformed_texts = []
|
||||
|
||||
for i in indices_to_modify:
|
||||
word_to_replace = words[i]
|
||||
word_to_replace = current_text.words[i]
|
||||
word_to_replace_ner = current_text.ner_of_word_index(i)
|
||||
replacement_words = self._get_replacement_words(
|
||||
word_to_replace, word_to_replace_ner
|
||||
)
|
||||
for r in replacement_words:
|
||||
transformed_texts.append(current_text.replace_word_at_index(i, r))
|
||||
|
||||
# search for first name replacement
|
||||
if i in fir_name_idx and not self.last_only:
|
||||
replacement_words = self._get_firstname(word_to_replace)
|
||||
transformed_texts_idx = []
|
||||
for r in replacement_words:
|
||||
if r == word_to_replace:
|
||||
continue
|
||||
transformed_texts_idx.append(
|
||||
current_text.replace_word_at_index(i, r)
|
||||
)
|
||||
transformed_texts.extend(transformed_texts_idx)
|
||||
# print(transformed_texts)
|
||||
|
||||
# search for last name replacement
|
||||
elif i in last_name_idx and not self.first_only:
|
||||
replacement_words = self._get_lastname(word_to_replace)
|
||||
transformed_texts_idx = []
|
||||
for r in replacement_words:
|
||||
if r == word_to_replace:
|
||||
continue
|
||||
transformed_texts_idx.append(
|
||||
current_text.replace_word_at_index(i, r)
|
||||
)
|
||||
transformed_texts.extend(transformed_texts_idx)
|
||||
return transformed_texts
|
||||
|
||||
def _get_replacement_words(self, word, word_part_of_speech):
|
||||
|
||||
replacement_words = []
|
||||
tag = word_part_of_speech
|
||||
if tag.value == "B-PER" and tag.score > self.confidence_score:
|
||||
replacement_words = self._get_firstname(word)
|
||||
elif tag.value == "E-PER" and tag.score > self.confidence_score:
|
||||
replacement_words = self._get_lastname(word)
|
||||
return replacement_words
|
||||
|
||||
def _get_lastname(self, word):
|
||||
"""Return a list of random last names."""
|
||||
return np.random.choice(PERSON_NAMES["last"], self.n)
|
||||
|
||||
@@ -19,6 +19,7 @@ class WordSwapInflections(WordSwap):
|
||||
self._flair_to_lemminflect_pos_map = {"NN": "NOUN", "VB": "VERB", "JJ": "ADJ"}
|
||||
|
||||
def _get_replacement_words(self, word, word_part_of_speech):
|
||||
|
||||
if word_part_of_speech not in self._flair_to_lemminflect_pos_map:
|
||||
# Only nouns, verbs, and adjectives have proper inflections.
|
||||
return []
|
||||
|
||||
Reference in New Issue
Block a user