mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
Flair ner to AttackedText
This commit is contained in:
@@ -168,17 +168,17 @@ def color_text(text, color=None, method=None):
|
||||
_flair_pos_tagger = None
|
||||
|
||||
|
||||
def flair_tag(sentence):
|
||||
def flair_tag(sentence, tag_type="pos-fast"):
|
||||
"""Tags a `Sentence` object using `flair` part-of-speech tagger."""
|
||||
global _flair_pos_tagger
|
||||
if not _flair_pos_tagger:
|
||||
from flair.models import SequenceTagger
|
||||
|
||||
_flair_pos_tagger = SequenceTagger.load("pos-fast")
|
||||
_flair_pos_tagger = SequenceTagger.load(tag_type)
|
||||
_flair_pos_tagger.predict(sentence)
|
||||
|
||||
|
||||
def zip_flair_result(pred):
|
||||
def zip_flair_result(pred, tag_type="pos-fast"):
|
||||
"""Takes a sentence tagging from `flair` and returns two lists, of words
|
||||
and their corresponding parts-of-speech."""
|
||||
from flair.data import Sentence
|
||||
@@ -191,6 +191,9 @@ def zip_flair_result(pred):
|
||||
pos_list = []
|
||||
for token in tokens:
|
||||
word_list.append(token.text)
|
||||
pos_list.append(token.annotation_layers["pos"][0]._value)
|
||||
if tag_type == "pos-fast":
|
||||
pos_list.append(token.annotation_layers["pos"][0]._value)
|
||||
elif tag_type == "ner":
|
||||
pos_list.append(token.get_tag("ner"))
|
||||
|
||||
return word_list, pos_list
|
||||
|
||||
Reference in New Issue
Block a user