mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
fix goal completion check
This commit is contained in:
@@ -12,10 +12,14 @@ class NonOverlappingOutput(TextToTextGoalFunction):
|
||||
"""
|
||||
|
||||
def _is_goal_complete(self, model_output, ground_truth_output):
|
||||
return word_difference_score(model_output, ground_truth_output) == len(get_words_cached(ground_truth_output))
|
||||
return self._get_score(model_output, ground_truth_output) == 1.0
|
||||
|
||||
def _get_score(self, model_output, ground_truth_output):
|
||||
return word_difference_score(model_output, ground_truth_output)
|
||||
num_words_diff = word_difference_score(model_output, ground_truth_output)
|
||||
if num_words_diff == 0:
|
||||
return 0.0
|
||||
else:
|
||||
return num_words_diff / len(get_words_cached(ground_truth_output))
|
||||
|
||||
@functools.lru_cache(maxsize=2**12)
|
||||
def get_words_cached(s):
|
||||
@@ -23,7 +27,7 @@ def get_words_cached(s):
|
||||
|
||||
@functools.lru_cache(maxsize=2**12)
|
||||
def word_difference_score(s1, s2):
|
||||
""" Returns the number of words that overlap between s1 and s2. """
|
||||
""" Returns the number of words that are non-overlapping between s1 and s2. """
|
||||
s1_words = get_words_cached(s1)
|
||||
s2_words = get_words_cached(s2)
|
||||
min_length = min(len(s1_words), len(s2_words))
|
||||
@@ -31,4 +35,4 @@ def word_difference_score(s1, s2):
|
||||
return 0
|
||||
s1_words = s1_words[:min_length]
|
||||
s2_words = s2_words[:min_length]
|
||||
return (s1_words != s2_words).sum() / min_length
|
||||
return (s1_words != s2_words).sum()
|
||||
|
||||
Reference in New Issue
Block a user