mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
flair workaround
This commit is contained in:
@@ -12,6 +12,34 @@ from textattack.shared.validators import transformation_consists_of_word_swaps
|
||||
flair.device = textattack.shared.utils.device
|
||||
|
||||
|
||||
def load_flair_upos_fast():
    """Loads flair 'upos-fast' SequenceTagger.

    This is a temporary workaround for flair v0.6. Will be fixed when
    flair pushes the bug fix.

    Instead of going through ``SequenceTagger.load("upos-fast")``, this
    function spells out the model URL itself, resolves it to a local file
    via ``flair.file_utils.cached_path`` (using a relative ``models``
    cache directory), deserializes the checkpoint with ``torch.load`` and
    builds the tagger from the resulting state.

    Returns:
        The object produced by ``SequenceTagger._init_model_with_state_dict``,
        after ``eval()`` has been called on it and it has been moved to
        ``textattack.shared.utils.device``.
    """
    # Imports are kept function-local so they are only needed when this
    # workaround is actually invoked.
    import pathlib
    import warnings

    from flair import file_utils
    import torch

    # Assemble the full URL of the 'upos-fast' checkpoint on the HU Berlin
    # model server.
    hu_path: str = "https://nlp.informatik.hu-berlin.de/resources/models"
    upos_path = "/".join([hu_path, "upos-fast", "en-upos-ontonotes-fast-v0.4.pt"])
    # Resolve the URL through flair's cache helper; presumably this downloads
    # the file on first use and returns the cached local path afterwards
    # (verify against flair's `cached_path`).
    model_path = file_utils.cached_path(upos_path, cache_dir=pathlib.Path("models"))
    # NOTE(review): `_fetch_model` is a private flair helper; it is handed the
    # already-resolved path here rather than a model name -- confirm it passes
    # unknown names/paths through unchanged in the targeted flair version.
    model_file = SequenceTagger._fetch_model(model_path)

    # Silence every warning raised while reading the checkpoint; the previous
    # warning filters are restored when the `with` block exits.
    # NOTE(review): the extent of this block was reconstructed from a
    # flattened diff view -- confirm that only the two load calls belong in it.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore")
        # load_big_file is a workaround by https://github.com/highway11git to load models on some Mac/Windows setups
        # see https://github.com/zalandoresearch/flair/issues/351
        f = file_utils.load_big_file(str(model_file))
        # Deserialize onto the CPU first; the model is moved to the target
        # device further down.
        state = torch.load(f, map_location="cpu")

    # Rebuild the tagger from the loaded state, switch it to evaluation mode
    # and move it to the device TextAttack is configured to use.
    model = SequenceTagger._init_model_with_state_dict(state)
    model.eval()
    model.to(textattack.shared.utils.device)
    return model
|
||||
|
||||
|
||||
class PartOfSpeech(Constraint):
|
||||
"""Constraints word swaps to only swap words with the same part of speech.
|
||||
Uses the NLTK universal part-of-speech tagger by default. An implementation
|
||||
@@ -43,7 +71,7 @@ class PartOfSpeech(Constraint):
|
||||
self._pos_tag_cache = lru.LRU(2 ** 14)
|
||||
if tagger_type == "flair":
|
||||
if tagset == "universal":
|
||||
self._flair_pos_tagger = SequenceTagger.load("upos-fast")
|
||||
self._flair_pos_tagger = load_flair_upos_fast()
|
||||
else:
|
||||
self._flair_pos_tagger = SequenceTagger.load("pos-fast")
|
||||
|
||||
|
||||
@@ -7,6 +7,8 @@ To decide which two tokens to compare, it greedily chooses the most
|
||||
similar token from one text and matches it to a token in the second
|
||||
text.
|
||||
"""
|
||||
import warnings
|
||||
|
||||
import bert_score
|
||||
|
||||
from textattack.constraints import Constraint
|
||||
@@ -55,7 +57,11 @@ class BERTScore(Constraint):
|
||||
`reference_text` is lower than minimum BERT Score."""
|
||||
cand = transformed_text.text
|
||||
ref = reference_text.text
|
||||
result = self._bert_scorer.score([cand], [ref])
|
||||
with warnings.catch_warnings():
|
||||
# Catch the many warnings that the huggingface API throws when using
|
||||
# BERT scorer.
|
||||
warnings.filterwarnings("ignore")
|
||||
result = self._bert_scorer.score([cand], [ref])
|
||||
score = result[BERTScore.SCORE_TYPE2IDX[self.score_type]].item()
|
||||
if score >= self.min_bert_score:
|
||||
return True
|
||||
|
||||
Reference in New Issue
Block a user