mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
Merge pull request #184 from QData/patch-1
Lint Python code for syntax errors and undefined names
This commit is contained in:
5
.github/workflows/check-formatting.yml
vendored
5
.github/workflows/check-formatting.yml
vendored
@@ -26,10 +26,9 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip setuptools wheel
|
||||
pip install black isort # Testing packages
|
||||
pip install black flake8 isort # Testing packages
|
||||
python setup.py install_egg_info # Workaround https://github.com/pypa/pip/issues/4537
|
||||
pip install -e .
|
||||
- name: Check code format with black and isort
|
||||
run: |
|
||||
black . --check
|
||||
isort --check-only --recursive tests textattack
|
||||
make lint
|
||||
|
||||
3
Makefile
3
Makefile
@@ -3,9 +3,10 @@ format: FORCE ## Run black and isort (rewriting files)
|
||||
isort --atomic --recursive tests textattack
|
||||
|
||||
|
||||
lint: FORCE ## Run black (in check mode)
|
||||
lint: FORCE ## Run black, isort, flake8 (in check mode)
|
||||
black . --check
|
||||
isort --check-only --recursive tests textattack
|
||||
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --exclude=./.*,build,dist # catch certain syntax errors using flake8
|
||||
|
||||
test: FORCE ## Run tests using pytest
|
||||
python -m pytest --dist=loadfile -n auto
|
||||
|
||||
14
setup.cfg
14
setup.cfg
@@ -1,9 +1,3 @@
|
||||
[flake8]
|
||||
ignore = E203, E266, E501, W503
|
||||
max-line-length = 120
|
||||
per-file-ignores = __init__.py:F401
|
||||
mypy_config = mypy.ini
|
||||
|
||||
[isort]
|
||||
line_length = 88
|
||||
skip = __init__.py
|
||||
@@ -14,3 +8,11 @@ multi_line_output = 3
|
||||
include_trailing_comma = True
|
||||
use_parentheses = True
|
||||
force_grid_wrap = 0
|
||||
|
||||
[flake8]
|
||||
exclude = .git,__pycache__,wandb,build,dist
|
||||
ignore = E203, E266, E501, W503, D203
|
||||
max-complexity = 10
|
||||
max-line-length = 120
|
||||
mypy_config = mypy.ini
|
||||
per-file-ignores = __init__.py:F401
|
||||
|
||||
2
setup.py
2
setup.py
@@ -10,7 +10,7 @@ extras = {}
|
||||
# Packages required for installing docs.
|
||||
extras["docs"] = ["recommonmark", "nbsphinx", "sphinx-autobuild", "sphinx-rtd-theme"]
|
||||
# Packages required for formatting code & running tests.
|
||||
extras["test"] = ["black", "isort", "pytest", "pytest-xdist"]
|
||||
extras["test"] = ["black", "isort", "flake8", "pytest", "pytest-xdist"]
|
||||
# For developers, install development tools along with all optional dependencies.
|
||||
extras["dev"] = extras["docs"] + extras["test"]
|
||||
|
||||
|
||||
@@ -165,7 +165,7 @@ class TestAttackedText:
|
||||
)
|
||||
for old_idx, new_idx in enumerate(new_text.attack_attrs["original_index_map"]):
|
||||
assert (attacked_text.words[old_idx] == new_text.words[new_idx]) or (
|
||||
new_i == -1
|
||||
new_idx == -1
|
||||
)
|
||||
new_text = (
|
||||
new_text.delete_word_at_index(0)
|
||||
|
||||
@@ -17,7 +17,7 @@ class AttackResult:
|
||||
if original_result is None:
|
||||
raise ValueError("Attack original result cannot be None")
|
||||
elif not isinstance(original_result, GoalFunctionResult):
|
||||
raise TypeError(f"Invalid original goal function result: {original_text}")
|
||||
raise TypeError(f"Invalid original goal function result: {original_result}")
|
||||
if perturbed_result is None:
|
||||
raise ValueError("Attack perturbed result cannot be None")
|
||||
elif not isinstance(perturbed_result, GoalFunctionResult):
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import textattack
|
||||
from textattack.shared.utils import default_class_repr
|
||||
|
||||
|
||||
class Constraint:
|
||||
class Constraint(ABC):
|
||||
"""
|
||||
An abstract class that represents constraints on adversarial text examples.
|
||||
Constraints evaluate whether transformations from an ``AttackedText`` to another
|
||||
@@ -68,9 +71,9 @@ class Constraint:
|
||||
current_text: The current ``AttackedText``.
|
||||
original_text: The original ``AttackedText`` from which the attack began.
|
||||
"""
|
||||
if not isinstance(transformed_text, AttackedText):
|
||||
if not isinstance(transformed_text, textattack.shared.AttackedText):
|
||||
raise TypeError("transformed_text must be of type AttackedText")
|
||||
if not isinstance(current_text, AttackedText):
|
||||
if not isinstance(current_text, textattack.shared.AttackedText):
|
||||
raise TypeError("current_text must be of type AttackedText")
|
||||
|
||||
try:
|
||||
@@ -86,6 +89,7 @@ class Constraint:
|
||||
transformed_text, current_text, original_text=original_text
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def _check_constraint(self, transformed_text, current_text, original_text=None):
|
||||
"""
|
||||
Returns True if the constraint is fulfilled, False otherwise. Must be overridden by
|
||||
|
||||
@@ -56,7 +56,7 @@ class GoogleLanguageModel(Constraint):
|
||||
[t.words[word_swap_index] for t in transformed_texts]
|
||||
)
|
||||
if self.print_step:
|
||||
print(prefix, swapped_words, suffix)
|
||||
print(prefix, swapped_words)
|
||||
probs = self.lm.get_words_probs(prefix, swapped_words)
|
||||
return probs
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from abc import ABC, abstractmethod
|
||||
from textattack.constraints import Constraint
|
||||
|
||||
|
||||
class LanguageModelConstraint(ABC, Constraint):
|
||||
class LanguageModelConstraint(Constraint, ABC):
|
||||
"""
|
||||
Determines if two sentences have a swapped word that has a similar
|
||||
probability according to a language model.
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from textattack.constraints import Constraint
|
||||
|
||||
|
||||
class PreTransformationConstraint(Constraint):
|
||||
class PreTransformationConstraint(Constraint, ABC):
|
||||
"""
|
||||
An abstract class that represents constraints which are applied before
|
||||
the transformation. These restrict which words are allowed to be modified
|
||||
@@ -22,6 +24,7 @@ class PreTransformationConstraint(Constraint):
|
||||
return set(range(len(current_text.words)))
|
||||
return self._get_modifiable_indices(current_text)
|
||||
|
||||
@abstractmethod
|
||||
def _get_modifiable_indices(current_text):
|
||||
"""
|
||||
Returns the word indices in ``current_text`` which are able to be modified.
|
||||
@@ -31,3 +34,8 @@ class PreTransformationConstraint(Constraint):
|
||||
current_text: The ``AttackedText`` input to consider.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _check_constraint(self):
|
||||
raise RuntimeError(
|
||||
"PreTransformationConstraints do not support `_check_constraint()`."
|
||||
)
|
||||
|
||||
@@ -72,7 +72,9 @@ class SentenceEncoder(Constraint):
|
||||
The similarity between the starting and transformed text using the metric.
|
||||
"""
|
||||
try:
|
||||
modified_index = next(iter(x_adv.attack_attrs["newly_modified_indices"]))
|
||||
modified_index = next(
|
||||
iter(transformed_text.attack_attrs["newly_modified_indices"])
|
||||
)
|
||||
except KeyError:
|
||||
raise KeyError(
|
||||
"Cannot apply sentence encoder constraint without `newly_modified_indices`"
|
||||
@@ -111,7 +113,7 @@ class SentenceEncoder(Constraint):
|
||||
``transformed_texts``. If ``transformed_texts`` is empty,
|
||||
an empty tensor is returned
|
||||
"""
|
||||
# Return an empty tensor if x_adv_list is empty.
|
||||
# Return an empty tensor if transformed_texts is empty.
|
||||
# This prevents us from calling .repeat(x, 0), which throws an
|
||||
# error on machines with multiple GPUs (pytorch 1.2).
|
||||
if len(transformed_texts) == 0:
|
||||
@@ -207,7 +209,7 @@ class SentenceEncoder(Constraint):
|
||||
"Must provide original text when compare_with_original is true."
|
||||
)
|
||||
else:
|
||||
scores = self._sim_score(current_text, transformed_texts)
|
||||
scores = self._sim_score(current_text, transformed_text)
|
||||
transformed_text.attack_attrs["similarity_score"] = score
|
||||
return score >= self.threshold
|
||||
|
||||
|
||||
@@ -47,7 +47,7 @@ class WordEmbeddingDistance(Constraint):
|
||||
mse_dist_file = "mse_dist.p"
|
||||
cos_sim_file = "cos_sim.p"
|
||||
else:
|
||||
raise ValueError(f"Could not find word embedding {word_embedding}")
|
||||
raise ValueError(f"Could not find word embedding {embedding_type}")
|
||||
|
||||
# Download embeddings if they're not cached.
|
||||
word_embeddings_path = utils.download_if_needed(WordEmbeddingDistance.PATH)
|
||||
|
||||
@@ -18,7 +18,7 @@ class InputReduction(ClassificationGoalFunction):
|
||||
def _is_goal_complete(self, model_output, attacked_text):
|
||||
return (
|
||||
self.ground_truth_output == model_output.argmax()
|
||||
and attacked_text.num_words <= target_num_words
|
||||
and attacked_text.num_words <= self.target_num_words
|
||||
)
|
||||
|
||||
def _should_skip(self, model_output, attacked_text):
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from abc import ABC, abstractmethod
|
||||
import math
|
||||
|
||||
import lru
|
||||
@@ -11,7 +12,7 @@ from textattack.shared import utils, validators
|
||||
from textattack.shared.utils import batch_model_predict, default_class_repr
|
||||
|
||||
|
||||
class GoalFunction:
|
||||
class GoalFunction(ABC):
|
||||
"""
|
||||
Evaluates how well a perturbed attacked_text object is achieving a specified goal.
|
||||
|
||||
@@ -124,24 +125,28 @@ class GoalFunction:
|
||||
return GoalFunctionResultStatus.SUCCEEDED
|
||||
return GoalFunctionResultStatus.SEARCHING
|
||||
|
||||
@abstractmethod
|
||||
def _is_goal_complete(self, model_output, attacked_text):
|
||||
raise NotImplementedError()
|
||||
|
||||
def _should_skip(self, model_output, attacked_text):
|
||||
return self._is_goal_complete(model_output, attacked_text)
|
||||
|
||||
@abstractmethod
|
||||
def _get_score(self, model_output, attacked_text):
|
||||
raise NotImplementedError()
|
||||
|
||||
def _get_displayed_output(self, raw_output):
|
||||
return raw_output
|
||||
|
||||
@abstractmethod
|
||||
def _goal_function_result_type(self):
|
||||
"""
|
||||
Returns the class of this goal function's results.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def _process_model_outputs(self, inputs, outputs):
|
||||
"""
|
||||
Processes and validates a list of model outputs.
|
||||
|
||||
@@ -62,7 +62,11 @@ class WordLevelTokenizer(hf_tokenizers.implementations.BaseTokenizer):
|
||||
normalizers = []
|
||||
|
||||
if unicode_normalizer:
|
||||
normalizers += [unicode_normalizer_from_str(unicode_normalizer)]
|
||||
normalizers += [
|
||||
hf_tokenizers.normalizers.unicode_normalizer_from_str(
|
||||
unicode_normalizer
|
||||
)
|
||||
]
|
||||
|
||||
if lowercase:
|
||||
normalizers += [hf_tokenizers.normalizers.Lowercase()]
|
||||
|
||||
@@ -105,7 +105,8 @@ class GeneticAlgorithm(SearchMethod):
|
||||
A population member containing the crossover.
|
||||
"""
|
||||
x1_text = pop_member1.attacked_text
|
||||
x2_words = pop_member2.attacked_text.words
|
||||
x2_text = pop_member2.attacked_text
|
||||
x2_words = x2_text.words
|
||||
|
||||
num_tries = 0
|
||||
passed_constraints = False
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from textattack.shared.utils import default_class_repr
|
||||
|
||||
|
||||
class SearchMethod:
|
||||
class SearchMethod(ABC):
|
||||
"""
|
||||
This is an abstract class that contains main helper functionality for
|
||||
search methods. A search method is a strategy for applying transformations
|
||||
@@ -26,6 +28,7 @@ class SearchMethod:
|
||||
)
|
||||
return self._perform_search(initial_result)
|
||||
|
||||
@abstractmethod
|
||||
def _perform_search(self, initial_result):
|
||||
"""
|
||||
Perturbs `attacked_text` from ``initial_result`` until goal is reached or search is
|
||||
|
||||
@@ -24,7 +24,7 @@ def html_style_from_dict(style_dict):
|
||||
def html_table_from_rows(rows, title=None, header=None, style_dict=None):
|
||||
# Stylize the container div.
|
||||
if style_dict:
|
||||
table_html = "<div {}>".format(style_from_dict(style_dict))
|
||||
table_html = "<div {}>".format(html_style_from_dict(style_dict))
|
||||
else:
|
||||
table_html = "<div>"
|
||||
# Print the title string.
|
||||
|
||||
@@ -24,7 +24,7 @@ class WordEmbedding:
|
||||
mse_dist_file = "mse_dist.p"
|
||||
cos_sim_file = "cos_sim.p"
|
||||
else:
|
||||
raise ValueError(f"Could not find word embedding {word_embedding}")
|
||||
raise ValueError(f"Could not find word embedding {embedding_type}")
|
||||
|
||||
# Download embeddings if they're not cached.
|
||||
word_embeddings_root_path = textattack.shared.utils.download_if_needed(
|
||||
|
||||
@@ -21,6 +21,14 @@ class CompositeTransformation(Transformation):
|
||||
raise ValueError("transformations cannot be empty")
|
||||
self.transformations = transformations
|
||||
|
||||
def _get_transformations(self, *_):
|
||||
""" Placeholder method that would throw an error if a user tried to
|
||||
treat the CompositeTransformation as a 'normal' transformation.
|
||||
"""
|
||||
raise RuntimeError(
|
||||
"CompositeTransformation does not support _get_transformations()."
|
||||
)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
new_attacked_texts = set()
|
||||
for transformation in self.transformations:
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from textattack.shared.utils import default_class_repr
|
||||
|
||||
|
||||
class Transformation:
|
||||
class Transformation(ABC):
|
||||
"""
|
||||
An abstract class for transforming a sequence of text to produce
|
||||
a potential adversarial example.
|
||||
@@ -44,6 +46,7 @@ class Transformation:
|
||||
text.attack_attrs["last_transformation"] = self
|
||||
return transformed_texts
|
||||
|
||||
@abstractmethod
|
||||
def _get_transformations(self, current_text, indices_to_modify):
|
||||
"""
|
||||
Returns a list of all possible transformations for ``current_text``, only modifying
|
||||
|
||||
Reference in New Issue
Block a user