1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

update flair version & centralize POS tagging

This commit is contained in:
Jack Morris
2020-08-31 20:54:03 -04:00
parent c9931daf43
commit e47e34ebe1
6 changed files with 88 additions and 89 deletions

View File

@@ -1,12 +1,16 @@
from collections import OrderedDict
import math
import flair
from flair.data import Sentence
import numpy as np
import torch
import textattack
from .utils import words_from_text
from .utils import device, words_from_text
flair.device = device
class AttackedText:
@@ -40,9 +44,10 @@ class AttackedText:
raise TypeError(
f"Invalid text_input type {type(text_input)} (required str or OrderedDict)"
)
# Find words in input lazily.
# Process input lazily.
self._words = None
self._words_per_input = None
self._pos_tags = None
# Format text inputs.
self._text_input = OrderedDict([(k, v) for k, v in self._text_input.items()])
if attack_attrs is None:
@@ -113,6 +118,34 @@ class AttackedText:
text_idx_end = self._text_index_of_word_index(end) + len(self.words[end])
return self.text[text_idx_start:text_idx_end]
def pos_of_word_index(self, desired_word_idx):
    """Returns the part-of-speech of the word at index `desired_word_idx`.

    Uses FLAIR part-of-speech tagger. The tagged sentence is built from
    ``self.text`` on first use and cached in ``self._pos_tags``.

    Args:
        desired_word_idx: Index into ``self.words`` of the word to look up.

    Returns:
        The part-of-speech entry that ``zip_flair_result`` pairs with the
        flair token matching the requested word.

    Raises:
        AssertionError: If a word of ``self.words``, up to and including the
            requested one, cannot be found among the remaining flair tokens.
        ValueError: If no word in ``self.words`` has index
            ``desired_word_idx`` (for example a negative or too-large index).
    """
    # Test the sentinel explicitly: a cached value that evaluates falsy
    # (presumably an empty Sentence -- TODO confirm) must not be re-tagged
    # on every call.
    if self._pos_tags is None:
        sentence = Sentence(self.text)
        textattack.shared.utils.flair_tag(sentence)
        self._pos_tags = sentence
    flair_word_list, flair_pos_list = textattack.shared.utils.zip_flair_result(
        self._pos_tags
    )

    # The flair tokens are searched rather than indexed directly, presumably
    # because flair's tokenization can differ from `self.words`. Both
    # sequences are walked in order; `search_from` marks where the next word
    # may match, so earlier tokens are never matched twice and the lists do
    # not have to be re-sliced on every step.
    search_from = 0
    for word_idx, word in enumerate(self.words):
        try:
            word_idx_in_flair_tags = flair_word_list.index(word, search_from)
        except ValueError:
            # Raised explicitly (not via `assert`) so the check survives
            # `python -O`; the exception type and message are unchanged.
            raise AssertionError(
                "word absent in flair returned part-of-speech tags"
            ) from None
        if word_idx == desired_word_idx:
            return flair_pos_list[word_idx_in_flair_tags]
        search_from = word_idx_in_flair_tags + 1

    raise ValueError(
        f"Did not find word from index {desired_word_idx} in flair POS tag"
    )
def _text_index_of_word_index(self, i):
"""Returns the index of word ``i`` in self.text."""
pre_words = self.words[: i + 1]

View File

@@ -163,3 +163,34 @@ def color_text(text, color=None, method=None):
return color + text + ANSI_ESCAPE_CODES.STOP
elif method == "file":
return "[[" + text + "]]"
# Lazily-created flair part-of-speech tagger, shared by every `flair_tag` call.
_flair_pos_tagger = None


def flair_tag(sentence):
    """Tags a `Sentence` object using `flair` part-of-speech tagger.

    The "pos-fast" ``SequenceTagger`` is loaded on the first call and kept in
    the module-level ``_flair_pos_tagger`` so later calls reuse it.

    Args:
        sentence: The object handed to the tagger's ``predict`` method
            (a `flair` ``Sentence``).

    Returns:
        None. The result of ``predict`` is not returned; callers read the
        tags from ``sentence`` afterwards.
    """
    global _flair_pos_tagger
    # `None` is the "not loaded yet" sentinel, so test for it explicitly
    # instead of relying on the tagger object's truthiness.
    if _flair_pos_tagger is None:
        # Imported here so flair's model code is only pulled in when tagging
        # is first requested.
        from flair.models import SequenceTagger

        _flair_pos_tagger = SequenceTagger.load("pos-fast")
    _flair_pos_tagger.predict(sentence)
def zip_flair_result(pred):
    """Takes a sentence tagging from `flair` and returns two lists, of words
    and their corresponding parts-of-speech.

    Args:
        pred: A `flair` ``Sentence``. Every token is expected to carry at
            least one label in its ``"pos"`` annotation layer.

    Returns:
        A tuple ``(word_list, pos_list)`` of equal-length lists:
        ``word_list[i]`` is the text of the i-th token and ``pos_list[i]`` is
        the value of that token's first ``"pos"`` label.

    Raises:
        TypeError: If ``pred`` is not a ``flair.data.Sentence``.
    """
    # Imported here so flair is only required when this helper is used.
    from flair.data import Sentence

    if not isinstance(pred, Sentence):
        raise TypeError("Result from Flair POS tagger must be a `Sentence` object.")

    tokens = pred.tokens
    word_list = [token.text for token in tokens]
    # Read the label through its public `value` attribute instead of the
    # private `_value` field backing it.
    pos_list = [token.annotation_layers["pos"][0].value for token in tokens]
    return word_list, pos_list