1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00
Files
textattack-nlp-transformer/textattack/constraints/semantics/bert_score.py
2020-07-06 12:51:16 -04:00

67 lines
2.6 KiB
Python

"""
BERT-Score was introduced in the paper "BERTScore: Evaluating Text Generation with BERT" (Zhang et al., 2019): https://arxiv.org/abs/1904.09675
BERT-Score measures token similarity between two texts using contextual embeddings.
To decide which two tokens to compare, it greedily chooses the most similar token from one text and matches it to a token in the second text.
"""
import bert_score
import nltk
from textattack.constraints import Constraint
from textattack.shared import utils
class BERTScore(Constraint):
    """A constraint on the BERT-Score between a transformed text and a reference.

    A transformed text satisfies the constraint when its BERT-Score against
    the reference text is greater than or equal to ``min_bert_score``.

    Args:
        min_bert_score (float): minimum threshold value for BERT-Score; must
            be a float in the closed range [0.0, 1.0]
        model (str): name of model to use for scoring
        score_type (str): Pick one of following three choices
            (1) "precision", (2) "recall", (3) "f1"
            - "precision": match words from candidate text to reference text
            - "recall": match words from reference text to candidate text
            - "f1": harmonic mean of precision and recall (recommended)
        compare_against_original (bool): If `True`, compare new `x_adv`
            against the original `x`. Otherwise, compare it against the
            previous `x_adv`.

    Raises:
        TypeError: if ``min_bert_score`` is not a float.
        ValueError: if ``min_bert_score`` is outside [0.0, 1.0] (or NaN), or
            if ``score_type`` is not one of the keys of ``SCORE_TYPE2IDX``.
    """

    # Position of each score type in the result returned by the scorer's
    # ``score`` call (indexed in `_check_constraint`).
    SCORE_TYPE2IDX = {"precision": 0, "recall": 1, "f1": 2}

    def __init__(
        self,
        min_bert_score,
        model="bert-base-uncased",
        score_type="f1",
        compare_against_original=True,
    ):
        super().__init__(compare_against_original)
        if not isinstance(min_bert_score, float):
            raise TypeError("min_bert_score must be a float")
        # Written as a chained comparison so that NaN (for which every
        # comparison is False) is rejected instead of silently accepted.
        if not (0.0 <= min_bert_score <= 1.0):
            raise ValueError("min_bert_score must be a value between 0.0 and 1.0")
        # Fail fast on an unknown score type, before the scorer below is
        # constructed, rather than with a KeyError on the first check.
        if score_type not in BERTScore.SCORE_TYPE2IDX:
            valid_types = ", ".join(
                repr(name) for name in BERTScore.SCORE_TYPE2IDX
            )
            raise ValueError(
                f"score_type must be one of {valid_types}, got {score_type!r}"
            )

        self.min_bert_score = min_bert_score
        self.model = model
        self.score_type = score_type
        # Turn off idf-weighting scheme b/c reference sentence set is small
        self._bert_scorer = bert_score.BERTScorer(
            model_type=model, idf=False, device=utils.device
        )

    def _check_constraint(self, transformed_text, reference_text):
        """Return `True` if the BERT-Score between `transformed_text` and
        `reference_text` is greater than or equal to the minimum BERT-Score.

        Both arguments must expose their raw string through a ``.text``
        attribute.
        """
        cand = transformed_text.text
        ref = reference_text.text
        # Score a single candidate/reference pair; the result is indexed by
        # score type and the selected entry converted to a Python scalar.
        result = self._bert_scorer.score([cand], [ref])
        score = result[BERTScore.SCORE_TYPE2IDX[self.score_type]].item()
        return score >= self.min_bert_score

    def extra_repr_keys(self):
        """Return this constraint's repr keys followed by the parent's keys."""
        return ["min_bert_score", "model", "score_type"] + super().extra_repr_keys()