mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
update flair version & centralize POS tagging
This commit is contained in:
@@ -1,12 +1,16 @@
|
||||
from collections import OrderedDict
|
||||
import math
|
||||
|
||||
import flair
|
||||
from flair.data import Sentence
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import textattack
|
||||
|
||||
from .utils import words_from_text
|
||||
from .utils import device, words_from_text
|
||||
|
||||
flair.device = device
|
||||
|
||||
|
||||
class AttackedText:
|
||||
@@ -40,9 +44,10 @@ class AttackedText:
|
||||
raise TypeError(
|
||||
f"Invalid text_input type {type(text_input)} (required str or OrderedDict)"
|
||||
)
|
||||
# Find words in input lazily.
|
||||
# Process input lazily.
|
||||
self._words = None
|
||||
self._words_per_input = None
|
||||
self._pos_tags = None
|
||||
# Format text inputs.
|
||||
self._text_input = OrderedDict([(k, v) for k, v in self._text_input.items()])
|
||||
if attack_attrs is None:
|
||||
@@ -113,6 +118,34 @@ class AttackedText:
|
||||
text_idx_end = self._text_index_of_word_index(end) + len(self.words[end])
|
||||
return self.text[text_idx_start:text_idx_end]
|
||||
|
||||
def pos_of_word_index(self, desired_word_idx):
    """Returns the part-of-speech of the word at index `desired_word_idx`.

    Uses FLAIR part-of-speech tagger. The text is tagged once, on the
    first call, and the tagged sentence is cached in ``self._pos_tags``
    for subsequent calls.

    The words in ``self.words`` are matched, in order, against the tokens
    returned by the tagger; tagger tokens that do not correspond to the
    next word are skipped over (presumably tokens such as punctuation
    that ``self.words`` does not contain -- TODO confirm).

    Args:
        desired_word_idx: Index into ``self.words`` of the word whose
            part-of-speech tag is wanted.

    Returns:
        The tag paired with that word by
        ``textattack.shared.utils.zip_flair_result``.

    Raises:
        AssertionError: If a word at or before ``desired_word_idx`` cannot
            be found among the remaining tagger tokens.
        ValueError: If ``desired_word_idx`` is never reached while walking
            ``self.words`` (e.g. it is out of range).
    """
    # `None` is the "not tagged yet" sentinel; compare against it explicitly
    # so that a cached result which merely evaluates as falsy is not
    # re-tagged on every call.
    if self._pos_tags is None:
        sentence = Sentence(self.text)
        textattack.shared.utils.flair_tag(sentence)
        self._pos_tags = sentence
    flair_word_list, flair_pos_list = textattack.shared.utils.zip_flair_result(
        self._pos_tags
    )

    # Position in the tagger output from which the next word is searched.
    # Advancing a cursor gives the same result as repeatedly slicing both
    # lists, without copying them for every word.
    search_from = 0
    for word_idx, word in enumerate(self.words):
        try:
            word_idx_in_flair_tags = flair_word_list.index(word, search_from)
        except ValueError:
            # Same exception type and message the original `assert` produced,
            # but raised explicitly so the check also runs under `python -O`.
            raise AssertionError(
                "word absent in flair returned part-of-speech tags"
            ) from None
        if word_idx == desired_word_idx:
            return flair_pos_list[word_idx_in_flair_tags]
        search_from = word_idx_in_flair_tags + 1

    raise ValueError(
        f"Did not find word from index {desired_word_idx} in flair POS tag"
    )
|
||||
|
||||
def _text_index_of_word_index(self, i):
|
||||
"""Returns the index of word ``i`` in self.text."""
|
||||
pre_words = self.words[: i + 1]
|
||||
|
||||
@@ -163,3 +163,34 @@ def color_text(text, color=None, method=None):
|
||||
return color + text + ANSI_ESCAPE_CODES.STOP
|
||||
elif method == "file":
|
||||
return "[[" + text + "]]"
|
||||
|
||||
|
||||
# Module-level cache for the flair part-of-speech tagger. It stays `None`
# until `flair_tag` is first called, which loads the model and stores it here.
_flair_pos_tagger = None


def flair_tag(sentence):
    """Tags a `Sentence` object using `flair` part-of-speech tagger.

    The "pos-fast" tagger is loaded on the first call and reused by every
    later call through the module-level ``_flair_pos_tagger`` cache.

    Args:
        sentence: The object handed to the tagger's ``predict`` method
            (a flair ``Sentence`` according to this function's contract).

    Returns:
        None. Nothing is returned; the result of tagging is expected to be
        read back from ``sentence`` itself.
    """
    global _flair_pos_tagger
    # Explicit `is None` check: the cache holds a model object, and relying
    # on its truthiness would reload the model if that object ever evaluated
    # as falsy.
    if _flair_pos_tagger is None:
        # Imported inside the function rather than at module import time
        # (presumably to keep flair's model code from loading unless tagging
        # is actually requested -- TODO confirm).
        from flair.models import SequenceTagger

        _flair_pos_tagger = SequenceTagger.load("pos-fast")
    _flair_pos_tagger.predict(sentence)
|
||||
|
||||
|
||||
def zip_flair_result(pred):
    """Splits a `flair` sentence tagging into parallel word and tag lists.

    Args:
        pred: A flair ``Sentence`` whose tokens carry a "pos" annotation.

    Returns:
        A ``(word_list, pos_list)`` pair: the text of every token in
        ``pred.tokens``, and for each of those tokens the value of its
        first "pos" annotation, in the same order.

    Raises:
        TypeError: If ``pred`` is not a flair ``Sentence``.
    """
    from flair.data import Sentence

    if isinstance(pred, Sentence):
        tagged_tokens = pred.tokens
        words = [tok.text for tok in tagged_tokens]
        # NOTE(review): `_value` is an underscore-prefixed attribute of the
        # annotation object; kept as-is to preserve behaviour.
        tags = [tok.annotation_layers["pos"][0]._value for tok in tagged_tokens]
        return words, tags

    raise TypeError("Result from Flair POS tagger must be a `Sentence` object.")
|
||||
|
||||
Reference in New Issue
Block a user