mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
add cache clearing
This commit is contained in:
@@ -86,6 +86,14 @@ class Attack:
|
||||
)
|
||||
self.search_method.filter_transformations = self.filter_transformations
|
||||
|
||||
def clear_cache(self, recursive=True):
|
||||
self.constraints_cache.clear()
|
||||
if recursive:
|
||||
self.goal_function.clear_cache()
|
||||
for constraint in self.constraints:
|
||||
if hasattr(constraint, "clear_cache"):
|
||||
constraint.clear_cache()
|
||||
|
||||
def get_transformations(self, current_text, original_text=None, **kwargs):
|
||||
"""Applies ``self.transformation`` to ``text``, then filters the list
|
||||
of possible transformations through the applicable constraints.
|
||||
@@ -191,6 +199,7 @@ class Attack:
|
||||
or ``MaximizedAttackResult``.
|
||||
"""
|
||||
final_result = self.search_method(initial_result)
|
||||
self.clear_cache()
|
||||
if final_result.goal_status == GoalFunctionResultStatus.SUCCEEDED:
|
||||
return SuccessfulAttackResult(initial_result, final_result,)
|
||||
elif final_result.goal_status == GoalFunctionResultStatus.SEARCHING:
|
||||
|
||||
Reference in New Issue
Block a user