mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
68 lines
1.9 KiB
Python
68 lines
1.9 KiB
Python
"""
|
||
Goal function that attempts to minimize the BLEU score
|
||
-------------------------------------------------------
|
||
|
||
|
||
"""
|
||
|
||
import functools
|
||
|
||
import nltk
|
||
|
||
import textattack
|
||
|
||
from .text_to_text_goal_function import TextToTextGoalFunction
|
||
|
||
|
||
class MinimizeBleu(TextToTextGoalFunction):
    """Attempts to minimize the BLEU score between the current output
    translation and the reference translation.

    BLEU score was defined in (BLEU: a Method for Automatic Evaluation of Machine Translation).

    `ArxivURL`_

    .. _ArxivURL: https://www.aclweb.org/anthology/P02-1040.pdf

    This goal function is defined in (It’s Morphin’ Time! Combating Linguistic Discrimination with Inflectional Perturbations).

    `ArxivURL2`_

    .. _ArxivURL2: https://www.aclweb.org/anthology/2020.acl-main.263

    Args:
        target_bleu (float): BLEU threshold. The goal is considered complete
            once the BLEU score is at or below this value (within ``EPS``).
            Defaults to ``0.0``.
    """

    # Tolerance added to ``target_bleu`` when comparing floating-point scores.
    EPS = 1e-10

    def __init__(self, *args, target_bleu=0.0, **kwargs):
        # Stored before delegating to the parent constructor; all remaining
        # positional and keyword arguments are forwarded unchanged.
        self.target_bleu = target_bleu
        super().__init__(*args, **kwargs)

    def clear_cache(self):
        """Clear the model-call cache (when caching is enabled) and the
        memoized BLEU scores held by ``get_bleu``."""
        if self.use_cache:
            self._call_model_cache.clear()
        # The BLEU memo is module-level, so it is reset regardless of
        # ``use_cache``.
        get_bleu.cache_clear()

    def _is_goal_complete(self, model_output, _):
        """Return ``True`` when the BLEU score of ``model_output`` is at or
        below ``target_bleu`` (within ``EPS``)."""
        # ``_get_score`` returns ``1 - BLEU``; invert it to recover BLEU.
        bleu_score = 1.0 - self._get_score(model_output, _)
        return bleu_score <= (self.target_bleu + MinimizeBleu.EPS)

    def _get_score(self, model_output, _):
        """Return ``1 - BLEU`` for ``model_output`` scored against the ground
        truth output, so that a lower BLEU yields a higher score."""
        model_output_at = textattack.shared.AttackedText(model_output)
        ground_truth_at = textattack.shared.AttackedText(self.ground_truth_output)
        # ``get_bleu`` uses its first argument as the reference and its second
        # as the hypothesis. BLEU is asymmetric, so the ground truth must be
        # the reference and the model output the hypothesis being scored.
        bleu_score = get_bleu(ground_truth_at, model_output_at)
        return 1.0 - bleu_score

    def extra_repr_keys(self):
        """Return the extra repr keys: always ``maximizable``, plus
        ``target_bleu`` when the goal function is not maximizable."""
        if self.maximizable:
            return ["maximizable"]
        else:
            return ["maximizable", "target_bleu"]
|
||
|
||
|
||
@functools.lru_cache(maxsize=2 ** 12)
def get_bleu(a, b):
    """Return the sentence-level BLEU score of ``b`` scored against ``a``.

    ``a.words`` is used as the single reference and ``b.words`` as the
    hypothesis. Results are memoized with an LRU cache, so both arguments
    must be hashable.
    """
    references = [a.words]
    candidate = b.words
    return nltk.translate.bleu_score.sentence_bleu(references, candidate)
|