1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

add cache clearing

This commit is contained in:
Jin Yong Yoo
2020-07-22 04:42:44 -04:00
parent 28952235dc
commit cdd9061cb4
7 changed files with 33 additions and 0 deletions

View File

@@ -86,6 +86,14 @@ class Attack:
)
self.search_method.filter_transformations = self.filter_transformations
def clear_cache(self, recursive=True):
self.constraints_cache.clear()
if recursive:
self.goal_function.clear_cache()
for constraint in self.constraints:
if hasattr(constraint, "clear_cache"):
constraint.clear_cache()
def get_transformations(self, current_text, original_text=None, **kwargs):
"""Applies ``self.transformation`` to ``text``, then filters the list
of possible transformations through the applicable constraints.
@@ -191,6 +199,7 @@ class Attack:
or ``MaximizedAttackResult``.
"""
final_result = self.search_method(initial_result)
self.clear_cache()
if final_result.goal_status == GoalFunctionResultStatus.SUCCEEDED:
return SuccessfulAttackResult(initial_result, final_result,)
elif final_result.goal_status == GoalFunctionResultStatus.SEARCHING: