1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

flair workaround

This commit is contained in:
Jack Morris
2020-09-01 12:33:17 -04:00
parent 02c9d76cad
commit 4d2585e608
2 changed files with 36 additions and 2 deletions

View File

@@ -12,6 +12,34 @@ from textattack.shared.validators import transformation_consists_of_word_swaps
flair.device = textattack.shared.utils.device
def load_flair_upos_fast():
"""Loads flair 'upos-fast' SequenceTagger.
This is a temporary workaround for flair v0.6. Will be fixed when
flair pushes the bug fix.
"""
import pathlib
import warnings
from flair import file_utils
import torch
hu_path: str = "https://nlp.informatik.hu-berlin.de/resources/models"
upos_path = "/".join([hu_path, "upos-fast", "en-upos-ontonotes-fast-v0.4.pt"])
model_path = file_utils.cached_path(upos_path, cache_dir=pathlib.Path("models"))
model_file = SequenceTagger._fetch_model(model_path)
with warnings.catch_warnings():
warnings.filterwarnings("ignore")
# load_big_file is a workaround by https://github.com/highway11git to load models on some Mac/Windows setups
# see https://github.com/zalandoresearch/flair/issues/351
f = file_utils.load_big_file(str(model_file))
state = torch.load(f, map_location="cpu")
model = SequenceTagger._init_model_with_state_dict(state)
model.eval()
model.to(textattack.shared.utils.device)
return model
class PartOfSpeech(Constraint):
"""Constraints word swaps to only swap words with the same part of speech.
Uses the NLTK universal part-of-speech tagger by default. An implementation
@@ -43,7 +71,7 @@ class PartOfSpeech(Constraint):
self._pos_tag_cache = lru.LRU(2 ** 14)
if tagger_type == "flair":
if tagset == "universal":
self._flair_pos_tagger = SequenceTagger.load("upos-fast")
self._flair_pos_tagger = load_flair_upos_fast()
else:
self._flair_pos_tagger = SequenceTagger.load("pos-fast")

View File

@@ -7,6 +7,8 @@ To decide which two tokens to compare, it greedily chooses the most
similar token from one text and matches it to a token in the second
text.
"""
import warnings
import bert_score
from textattack.constraints import Constraint
@@ -55,7 +57,11 @@ class BERTScore(Constraint):
`reference_text` is lower than minimum BERT Score."""
cand = transformed_text.text
ref = reference_text.text
result = self._bert_scorer.score([cand], [ref])
with warnings.catch_warnings():
# Catch the many warnings that the huggingface API throws when using
# BERT scorer.
warnings.filterwarnings("ignore")
result = self._bert_scorer.score([cand], [ref])
score = result[BERTScore.SCORE_TYPE2IDX[self.score_type]].item()
if score >= self.min_bert_score:
return True