mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
fix crossover constraint checking
This commit is contained in:
@@ -77,6 +77,9 @@ class Attack:
|
||||
self.search_method.get_transformations = self.get_transformations
|
||||
self.search_method.get_goal_results = self.goal_function.get_results
|
||||
|
||||
if isinstance(self.search_method, textattack.search_methods.GeneticAlgorithm):
|
||||
self.search_method.filter_transformations = self.filter_transformations
|
||||
|
||||
def get_transformations(self, current_text, original_text=None, **kwargs):
|
||||
"""
|
||||
Applies ``self.transformation`` to ``text``, then filters the list of possible transformations
|
||||
@@ -103,7 +106,7 @@ class Attack:
|
||||
**kwargs,
|
||||
)
|
||||
)
|
||||
return self._filter_transformations(
|
||||
return self.filter_transformations(
|
||||
transformed_texts, current_text, original_text
|
||||
)
|
||||
|
||||
@@ -133,7 +136,7 @@ class Attack:
|
||||
self.constraints_cache[(current_text, filtered_text)] = True
|
||||
return filtered_texts
|
||||
|
||||
def _filter_transformations(
|
||||
def filter_transformations(
|
||||
self, transformed_texts, current_text, original_text=None
|
||||
):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user