mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
Reformatting-try1
This commit is contained in:
@@ -16,12 +16,11 @@ from textattack.shared import AttackedText, utils
|
||||
|
||||
|
||||
class Attack:
|
||||
"""
|
||||
An attack generates adversarial examples on text.
|
||||
|
||||
This is an abstract class that contains main helper functionality for
|
||||
attacks. An attack is comprised of a search method, goal function,
|
||||
a transformation, and a set of one or more linguistic constraints that
|
||||
"""An attack generates adversarial examples on text.
|
||||
|
||||
This is an abstract class that contains main helper functionality for
|
||||
attacks. An attack is comprised of a search method, goal function,
|
||||
a transformation, and a set of one or more linguistic constraints that
|
||||
successful examples must meet.
|
||||
|
||||
Args:
|
||||
@@ -40,7 +39,10 @@ class Attack:
|
||||
search_method=None,
|
||||
constraint_cache_size=2 ** 18,
|
||||
):
|
||||
""" Initialize an attack object. Attacks can be run multiple times. """
|
||||
"""Initialize an attack object.
|
||||
|
||||
Attacks can be run multiple times.
|
||||
"""
|
||||
self.goal_function = goal_function
|
||||
if not self.goal_function:
|
||||
raise NameError(
|
||||
@@ -84,10 +86,9 @@ class Attack:
|
||||
self.search_method.filter_transformations = self.filter_transformations
|
||||
|
||||
def get_transformations(self, current_text, original_text=None, **kwargs):
|
||||
"""
|
||||
Applies ``self.transformation`` to ``text``, then filters the list of possible transformations
|
||||
through the applicable constraints.
|
||||
|
||||
"""Applies ``self.transformation`` to ``text``, then filters the list
|
||||
of possible transformations through the applicable constraints.
|
||||
|
||||
Args:
|
||||
current_text: The current ``AttackedText`` on which to perform the transformations.
|
||||
original_text: The original ``AttackedText`` from which the attack started.
|
||||
@@ -95,7 +96,6 @@ class Attack:
|
||||
|
||||
Returns:
|
||||
A filtered list of transformations where each transformation matches the constraints
|
||||
|
||||
"""
|
||||
if not self.transformation:
|
||||
raise RuntimeError(
|
||||
@@ -116,9 +116,9 @@ class Attack:
|
||||
def _filter_transformations_uncached(
|
||||
self, transformed_texts, current_text, original_text=None
|
||||
):
|
||||
"""
|
||||
Filters a list of potential transformaed texts based on ``self.constraints``\.
|
||||
|
||||
"""Filters a list of potential transformaed texts based on
|
||||
``self.constraints``\.
|
||||
|
||||
Args:
|
||||
transformed_texts: A list of candidate transformed ``AttackedText``\s to filter.
|
||||
current_text: The current ``AttackedText`` on which the transformation was applied.
|
||||
@@ -148,10 +148,9 @@ class Attack:
|
||||
def filter_transformations(
|
||||
self, transformed_texts, current_text, original_text=None
|
||||
):
|
||||
"""
|
||||
Filters a list of potential transformed texts based on ``self.constraints``\.
|
||||
Checks cache first.
|
||||
|
||||
"""Filters a list of potential transformed texts based on
|
||||
``self.constraints``\. Checks cache first.
|
||||
|
||||
Args:
|
||||
transformed_texts: A list of candidate transformed ``AttackedText``\s to filter.
|
||||
current_text: The current ``AttackedText`` on which the transformation was applied.
|
||||
@@ -179,15 +178,14 @@ class Attack:
|
||||
return filtered_texts
|
||||
|
||||
def attack_one(self, initial_result):
|
||||
"""
|
||||
Calls the ``SearchMethod`` to perturb the ``AttackedText`` stored in
|
||||
"""Calls the ``SearchMethod`` to perturb the ``AttackedText`` stored in
|
||||
``initial_result``.
|
||||
|
||||
Args:
|
||||
initial_result: The initial ``GoalFunctionResult`` from which to perturb.
|
||||
|
||||
Returns:
|
||||
A ``SuccessfulAttackResult``, ``FailedAttackResult``,
|
||||
A ``SuccessfulAttackResult``, ``FailedAttackResult``,
|
||||
or ``MaximizedAttackResult``.
|
||||
"""
|
||||
final_result = self.search_method(initial_result)
|
||||
@@ -201,13 +199,12 @@ class Attack:
|
||||
raise ValueError(f"Unrecognized goal status {final_result.goal_status}")
|
||||
|
||||
def _get_examples_from_dataset(self, dataset, indices=None):
|
||||
"""
|
||||
Gets examples from a dataset and tokenizes them.
|
||||
"""Gets examples from a dataset and tokenizes them.
|
||||
|
||||
Args:
|
||||
dataset: An iterable of (text_input, ground_truth_output) pairs
|
||||
indices: An iterable of indices of the dataset that we want to attack. If None, attack all samples in dataset.
|
||||
|
||||
|
||||
Returns:
|
||||
results (Iterable[GoalFunctionResult]): an iterable of GoalFunctionResults of the original examples
|
||||
"""
|
||||
@@ -243,8 +240,7 @@ class Attack:
|
||||
break
|
||||
|
||||
def attack_dataset(self, dataset, indices=None):
|
||||
"""
|
||||
Runs an attack on the given dataset and outputs the results to the
|
||||
"""Runs an attack on the given dataset and outputs the results to the
|
||||
console and the output file.
|
||||
|
||||
Args:
|
||||
@@ -262,9 +258,8 @@ class Attack:
|
||||
yield result
|
||||
|
||||
def __repr__(self):
|
||||
"""
|
||||
Prints attack parameters in a human-readable string.
|
||||
|
||||
"""Prints attack parameters in a human-readable string.
|
||||
|
||||
Inspired by the readability of printing PyTorch nn.Modules:
|
||||
https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user