1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

fix goal completion check

This commit is contained in:
Jin Yong Yoo
2020-05-21 10:53:43 -04:00
parent 8934bb60aa
commit 8ba4f5bb54

View File

@@ -12,10 +12,14 @@ class NonOverlappingOutput(TextToTextGoalFunction):
"""
def _is_goal_complete(self, model_output, ground_truth_output):
return word_difference_score(model_output, ground_truth_output) == len(get_words_cached(ground_truth_output))
return self._get_score(model_output, ground_truth_output) == 1.0
def _get_score(self, model_output, ground_truth_output):
return word_difference_score(model_output, ground_truth_output)
num_words_diff = word_difference_score(model_output, ground_truth_output)
if num_words_diff == 0:
return 0.0
else:
return num_words_diff / len(get_words_cached(ground_truth_output))
@functools.lru_cache(maxsize=2**12)
def get_words_cached(s):
@@ -23,7 +27,7 @@ def get_words_cached(s):
@functools.lru_cache(maxsize=2**12)
def word_difference_score(s1, s2):
""" Returns the number of words that overlap between s1 and s2. """
""" Returns the number of words that are non-overlapping between s1 and s2. """
s1_words = get_words_cached(s1)
s2_words = get_words_cached(s2)
min_length = min(len(s1_words), len(s2_words))
@@ -31,4 +35,4 @@ def word_difference_score(s1, s2):
return 0
s1_words = s1_words[:min_length]
s2_words = s2_words[:min_length]
return (s1_words != s2_words).sum() / min_length
return (s1_words != s2_words).sum()