1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00
Files
textattack-nlp-transformer/textattack/goal_functions/text/minimize_bleu.py
2020-11-01 14:02:00 -05:00

68 lines
1.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Goal Function for Minimizing the BLEU Score
-------------------------------------------------------
"""
import functools
import nltk
import textattack
from .text_to_text_goal_function import TextToTextGoalFunction
class MinimizeBleu(TextToTextGoalFunction):
    """Attempts to minimize the BLEU score between the current output
    translation and the reference translation.

    BLEU score was defined in (BLEU: a Method for Automatic Evaluation of
    Machine Translation).

    `ArxivURL`_

    .. _ArxivURL: https://www.aclweb.org/anthology/P02-1040.pdf

    This goal function is defined in (It's Morphin' Time! Combating
    Linguistic Discrimination with Inflectional Perturbations).

    `ArxivURL2`_

    .. _ArxivURL2: https://www.aclweb.org/anthology/2020.acl-main.263

    Args:
        target_bleu (float): BLEU score at or below which (within ``EPS``)
            the goal is considered complete. Defaults to ``0.0``.
    """

    # Tolerance added to ``target_bleu`` when checking goal completion, so
    # that floating-point noise does not prevent a match at the threshold.
    EPS = 1e-10

    def __init__(self, *args, target_bleu=0.0, **kwargs):
        """Store ``target_bleu`` and forward all other arguments to the
        parent goal function."""
        # Set before calling the parent constructor, in case the parent
        # reads it during initialization (not visible from this file).
        self.target_bleu = target_bleu
        super().__init__(*args, **kwargs)

    def clear_cache(self):
        """Clear the model-call cache (when caching is enabled) and the
        memoized BLEU scores held by ``get_bleu``."""
        if self.use_cache:
            self._call_model_cache.clear()
        # The BLEU memo is a module-level ``lru_cache`` that is independent
        # of ``use_cache``, so it is always cleared.
        get_bleu.cache_clear()

    def _is_goal_complete(self, model_output, _):
        """Return ``True`` when the BLEU score of ``model_output`` is at or
        below ``target_bleu`` (plus ``EPS``)."""
        # ``_get_score`` returns ``1 - BLEU``; invert it to recover BLEU.
        bleu_score = 1.0 - self._get_score(model_output, _)
        return bleu_score <= (self.target_bleu + MinimizeBleu.EPS)

    def _get_score(self, model_output, _):
        """Return ``1 - BLEU`` for ``model_output`` against the ground-truth
        output, so that a higher score means a lower BLEU."""
        model_output_at = textattack.shared.AttackedText(model_output)
        ground_truth_at = textattack.shared.AttackedText(self.ground_truth_output)
        # ``get_bleu`` uses its first argument as the reference and its
        # second as the hypothesis. BLEU is not symmetric, so the ground
        # truth must be passed first and the model output second.
        bleu_score = get_bleu(ground_truth_at, model_output_at)
        return 1.0 - bleu_score

    def extra_repr_keys(self):
        """Return the attribute names to include in the repr;
        ``target_bleu`` is listed only when ``maximizable`` is falsy."""
        if self.maximizable:
            return ["maximizable"]
        else:
            return ["maximizable", "target_bleu"]
@functools.lru_cache(maxsize=2 ** 12)
def get_bleu(a, b):
    """Return the sentence-level BLEU score of ``b`` measured against ``a``.

    The ``words`` of ``a`` are handed to NLTK as the single reference and the
    ``words`` of ``b`` as the hypothesis. Results are memoized in an LRU
    cache holding up to ``2 ** 12`` entries, so both arguments must be
    hashable.
    """
    return nltk.translate.bleu_score.sentence_bleu([a.words], b.words)