1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

Reformatting-try1

This commit is contained in:
Hanyu Liu
2020-07-10 19:21:02 -04:00
parent 18563814bd
commit 974061c0aa
130 changed files with 1126 additions and 1229 deletions

View File

@@ -16,12 +16,11 @@ from textattack.shared import AttackedText, utils
class Attack:
"""
An attack generates adversarial examples on text.
This is an abstract class that contains main helper functionality for
attacks. An attack is comprised of a search method, goal function,
a transformation, and a set of one or more linguistic constraints that
"""An attack generates adversarial examples on text.
This is an abstract class that contains main helper functionality for
attacks. An attack is comprised of a search method, goal function,
a transformation, and a set of one or more linguistic constraints that
successful examples must meet.
Args:
@@ -40,7 +39,10 @@ class Attack:
search_method=None,
constraint_cache_size=2 ** 18,
):
""" Initialize an attack object. Attacks can be run multiple times. """
"""Initialize an attack object.
Attacks can be run multiple times.
"""
self.goal_function = goal_function
if not self.goal_function:
raise NameError(
@@ -84,10 +86,9 @@ class Attack:
self.search_method.filter_transformations = self.filter_transformations
def get_transformations(self, current_text, original_text=None, **kwargs):
"""
Applies ``self.transformation`` to ``text``, then filters the list of possible transformations
through the applicable constraints.
"""Applies ``self.transformation`` to ``text``, then filters the list
of possible transformations through the applicable constraints.
Args:
current_text: The current ``AttackedText`` on which to perform the transformations.
original_text: The original ``AttackedText`` from which the attack started.
@@ -95,7 +96,6 @@ class Attack:
Returns:
A filtered list of transformations where each transformation matches the constraints
"""
if not self.transformation:
raise RuntimeError(
@@ -116,9 +116,9 @@ class Attack:
def _filter_transformations_uncached(
self, transformed_texts, current_text, original_text=None
):
"""
Filters a list of potential transformed texts based on ``self.constraints``\.
"""Filters a list of potential transformaed texts based on
``self.constraints``\.
Args:
transformed_texts: A list of candidate transformed ``AttackedText``\s to filter.
current_text: The current ``AttackedText`` on which the transformation was applied.
@@ -148,10 +148,9 @@ class Attack:
def filter_transformations(
self, transformed_texts, current_text, original_text=None
):
"""
Filters a list of potential transformed texts based on ``self.constraints``\.
Checks cache first.
"""Filters a list of potential transformed texts based on
``self.constraints``\. Checks cache first.
Args:
transformed_texts: A list of candidate transformed ``AttackedText``\s to filter.
current_text: The current ``AttackedText`` on which the transformation was applied.
@@ -179,15 +178,14 @@ class Attack:
return filtered_texts
def attack_one(self, initial_result):
"""
Calls the ``SearchMethod`` to perturb the ``AttackedText`` stored in
"""Calls the ``SearchMethod`` to perturb the ``AttackedText`` stored in
``initial_result``.
Args:
initial_result: The initial ``GoalFunctionResult`` from which to perturb.
Returns:
A ``SuccessfulAttackResult``, ``FailedAttackResult``,
A ``SuccessfulAttackResult``, ``FailedAttackResult``,
or ``MaximizedAttackResult``.
"""
final_result = self.search_method(initial_result)
@@ -201,13 +199,12 @@ class Attack:
raise ValueError(f"Unrecognized goal status {final_result.goal_status}")
def _get_examples_from_dataset(self, dataset, indices=None):
"""
Gets examples from a dataset and tokenizes them.
"""Gets examples from a dataset and tokenizes them.
Args:
dataset: An iterable of (text_input, ground_truth_output) pairs
indices: An iterable of indices of the dataset that we want to attack. If None, attack all samples in dataset.
Returns:
results (Iterable[GoalFunctionResult]): an iterable of GoalFunctionResults of the original examples
"""
@@ -243,8 +240,7 @@ class Attack:
break
def attack_dataset(self, dataset, indices=None):
"""
Runs an attack on the given dataset and outputs the results to the
"""Runs an attack on the given dataset and outputs the results to the
console and the output file.
Args:
@@ -262,9 +258,8 @@ class Attack:
yield result
def __repr__(self):
"""
Prints attack parameters in a human-readable string.
"""Prints attack parameters in a human-readable string.
Inspired by the readability of printing PyTorch nn.Modules:
https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py
"""