1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

Flair ner to AttackedText

This commit is contained in:
Hanyu Liu
2020-09-09 19:37:55 -04:00
parent 9d0a668210
commit 7062355a3a
5 changed files with 61 additions and 62 deletions

View File

@@ -168,17 +168,17 @@ def color_text(text, color=None, method=None):
# Loaded flair taggers keyed by model name. A single shared slot would make
# the first model loaded (e.g. "pos-fast") be reused for every later request,
# including requests for a different model such as "ner".
_flair_taggers = {}


def flair_tag(sentence, tag_type="pos-fast"):
    """Run a `flair` sequence tagger over a `Sentence` object.

    The tagger's ``predict`` is called on ``sentence`` and nothing is
    returned; callers read the predictions back from the sentence itself.

    Args:
        sentence: The `flair` ``Sentence`` to tag.
        tag_type: Model name handed to ``SequenceTagger.load``
            (``"pos-fast"`` by default; ``"ner"`` is the other value used
            alongside this function).
    """
    tagger = _flair_taggers.get(tag_type)
    if tagger is None:
        # Imported on first use rather than at module import time.
        from flair.models import SequenceTagger

        tagger = SequenceTagger.load(tag_type)
        _flair_taggers[tag_type] = tagger
    tagger.predict(sentence)
def zip_flair_result(pred):
def zip_flair_result(pred, tag_type="pos-fast"):
"""Takes a sentence tagging from `flair` and returns two lists, of words
and their corresponding parts-of-speech."""
from flair.data import Sentence
@@ -191,6 +191,9 @@ def zip_flair_result(pred):
pos_list = []
for token in tokens:
word_list.append(token.text)
pos_list.append(token.annotation_layers["pos"][0]._value)
if tag_type == "pos-fast":
pos_list.append(token.annotation_layers["pos"][0]._value)
elif tag_type == "ner":
pos_list.append(token.get_tag("ner"))
return word_list, pos_list