mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
Add bert-score
This commit is contained in:
@@ -21,3 +21,4 @@ tqdm
|
||||
visdom
|
||||
wandb
|
||||
flair
|
||||
bert-score
|
||||
|
||||
@@ -222,6 +222,7 @@ CONSTRAINT_CLASS_NAMES = {
|
||||
"infer-sent": "textattack.constraints.semantics.sentence_encoders.InferSent",
|
||||
"thought-vector": "textattack.constraints.semantics.sentence_encoders.ThoughtVector",
|
||||
"use": "textattack.constraints.semantics.sentence_encoders.UniversalSentenceEncoder",
|
||||
"bert-score": "textattack.constraints.semantics.BERTScore",
|
||||
#
|
||||
# Grammaticality constraints
|
||||
#
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from . import sentence_encoders
|
||||
|
||||
from .word_embedding_distance import WordEmbeddingDistance
|
||||
from .bert_score import BERTScore
|
||||
|
||||
44
textattack/constraints/semantics/bert_score.py
Normal file
44
textattack/constraints/semantics/bert_score.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import bert_score
|
||||
import nltk
|
||||
|
||||
from textattack.constraints import Constraint
|
||||
from textattack.shared import utils
|
||||
|
||||
|
||||
class BERTScore(Constraint):
    """
    A constraint on BERTScore difference. BERTScore is introduced in the paper
    "BERTScore: Evaluating Text Generation with BERT" (Zhang et al, 2019)
    https://arxiv.org/abs/1904.09675

    A transformed text satisfies the constraint when its BERTScore against the
    reference text is greater than or equal to ``min_bert_score``.

    Args:
        min_bert_score (float): minimum threshold value for BERTScore,
            between 0.0 and 1.0 inclusive
        model (str): name of model to use for scoring
        score_type (str): Pick one of three choices: (1) "precision",
            (2) "recall", (3) "f1"

    Raises:
        TypeError: if ``min_bert_score`` is not a float.
        ValueError: if ``min_bert_score`` is outside [0.0, 1.0] or
            ``score_type`` is not one of the supported choices.
    """

    def __init__(self, min_bert_score, model="bert-base-uncased", score_type="f1"):
        if not isinstance(min_bert_score, float):
            raise TypeError("min_bert_score must be a float")
        if min_bert_score < 0.0 or min_bert_score > 1.0:
            raise ValueError("min_bert_score must be a value between 0.0 and 1.0")

        # Index of each score inside the tuple produced by the scorer's
        # `score` call (see `_check_constraint`).
        self._score_type2idx = {"precision": 0, "recall": 1, "f1": 2}
        # Reject an unknown score type up front, before the (potentially
        # expensive) scorer is built, instead of failing with a bare KeyError
        # on the first constraint check.
        if score_type not in self._score_type2idx:
            raise ValueError(
                "score_type must be one of {}".format(sorted(self._score_type2idx))
            )

        self.min_bert_score = min_bert_score
        self.model = model
        self.score_type = score_type
        # Turn off idf-weighting scheme b/c reference sentence set is small
        self._bert_scorer = bert_score.BERTScorer(
            model_type=model, idf=False, device=utils.device
        )

    def _check_constraint(self, transformed_text, current_text, original_text=None):
        """
        Return True if the BERTScore of ``transformed_text`` against the
        reference text is at least ``min_bert_score``.

        The reference is ``original_text`` when it is provided, otherwise
        ``current_text``.
        """
        cand = transformed_text.text
        reference = original_text if original_text is not None else current_text
        ref = reference.text
        # A single candidate/reference pair is scored; the configured score
        # type selects which entry of the result is compared to the threshold.
        result = self._bert_scorer.score([cand], [ref])
        score = result[self._score_type2idx[self.score_type]].item()
        return score >= self.min_bert_score

    def extra_repr_keys(self):
        """Return the names of the attributes listed for this constraint."""
        return ["min_bert_score"]
|
||||
Reference in New Issue
Block a user