1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

Add bert-score

This commit is contained in:
Jin Yong Yoo
2020-06-22 14:43:33 -04:00
parent af21d42a47
commit 76d1e5e620
4 changed files with 47 additions and 0 deletions

View File

@@ -21,3 +21,4 @@ tqdm
visdom
wandb
flair
bert-score

View File

@@ -222,6 +222,7 @@ CONSTRAINT_CLASS_NAMES = {
"infer-sent": "textattack.constraints.semantics.sentence_encoders.InferSent",
"thought-vector": "textattack.constraints.semantics.sentence_encoders.ThoughtVector",
"use": "textattack.constraints.semantics.sentence_encoders.UniversalSentenceEncoder",
"bert-score": "textattack.constraints.semantics.BERTScore",
#
# Grammaticality constraints
#

View File

@@ -1,3 +1,4 @@
from . import sentence_encoders
from .word_embedding_distance import WordEmbeddingDistance
from .bert_score import BERTScore

View File

@@ -0,0 +1,44 @@
import bert_score
import nltk
from textattack.constraints import Constraint
from textattack.shared import utils
class BERTScore(Constraint):
    """Constraint on the BERTScore between a candidate and a reference text.

    BERTScore is introduced in "BERTScore: Evaluating Text Generation with
    BERT" (Zhang et al, 2019) https://arxiv.org/abs/1904.09675

    A transformed text satisfies the constraint when its score against the
    reference text is at least ``min_bert_score``.

    Args:
        min_bert_score (float): minimum threshold value for BERTScore,
            between 0.0 and 1.0 inclusive.
        model (str): name of model to use for scoring.
        score_type (str): which component of the score to compare against the
            threshold. Pick one of three choices: (1) "precision",
            (2) "recall", (3) "f1".

    Raises:
        TypeError: if ``min_bert_score`` is not a float.
        ValueError: if ``min_bert_score`` is outside [0.0, 1.0] or
            ``score_type`` is not one of the supported choices.
    """

    # Position of each score component in the tuple returned by the scorer.
    _score_type2idx = {"precision": 0, "recall": 1, "f1": 2}

    def __init__(self, min_bert_score, model="bert-base-uncased", score_type="f1"):
        if not isinstance(min_bert_score, float):
            raise TypeError("min_bert_score must be a float")
        if min_bert_score < 0.0 or min_bert_score > 1.0:
            raise ValueError("min_bert_score must be a value between 0.0 and 1.0")
        # Fail fast on a bad score type, before the scoring model is loaded,
        # instead of surfacing a bare KeyError at the first constraint check.
        if score_type not in self._score_type2idx:
            raise ValueError(
                "score_type must be one of "
                f"{sorted(self._score_type2idx)}, got {score_type!r}"
            )

        self.min_bert_score = min_bert_score
        self.model = model
        self.score_type = score_type
        # Turn off idf-weighting scheme b/c reference sentence set is small
        self._bert_scorer = bert_score.BERTScorer(
            model_type=model, idf=False, device=utils.device
        )

    def _check_constraint(self, transformed_text, current_text, original_text=None):
        """Return True if the transformed text scores at or above the threshold.

        The reference is ``original_text`` when one is supplied, otherwise
        ``current_text``.
        """
        cand = transformed_text.text
        # Compare against None explicitly so the choice of reference does not
        # depend on how the text object defines truthiness.
        ref = original_text.text if original_text is not None else current_text.text
        result = self._bert_scorer.score([cand], [ref])
        # Pick the requested component and unwrap the single-pair result.
        score = result[self._score_type2idx[self.score_type]].item()
        return score >= self.min_bert_score

    def extra_repr_keys(self):
        """Return the attribute names to include in this constraint's repr."""
        return ["min_bert_score"]