1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

Merge pull request #184 from QData/patch-1

Lint Python code for syntax errors and undefined names
This commit is contained in:
Jack Morris
2020-07-03 20:00:23 -04:00
committed by GitHub
21 changed files with 71 additions and 31 deletions

View File

@@ -26,10 +26,9 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip setuptools wheel
pip install black isort # Testing packages
pip install black flake8 isort # Testing packages
python setup.py install_egg_info # Workaround https://github.com/pypa/pip/issues/4537
pip install -e .
- name: Check code format with black and isort
run: |
black . --check
isort --check-only --recursive tests textattack
make lint

View File

@@ -3,9 +3,10 @@ format: FORCE ## Run black and isort (rewriting files)
isort --atomic --recursive tests textattack
lint: FORCE ## Run black (in check mode)
lint: FORCE ## Run black, isort, flake8 (in check mode)
black . --check
isort --check-only --recursive tests textattack
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --exclude=./.*,build,dist # catch certain syntax errors using flake8
test: FORCE ## Run tests using pytest
python -m pytest --dist=loadfile -n auto

View File

@@ -1,9 +1,3 @@
[flake8]
ignore = E203, E266, E501, W503
max-line-length = 120
per-file-ignores = __init__.py:F401
mypy_config = mypy.ini
[isort]
line_length = 88
skip = __init__.py
@@ -14,3 +8,11 @@ multi_line_output = 3
include_trailing_comma = True
use_parentheses = True
force_grid_wrap = 0
[flake8]
exclude = .git,__pycache__,wandb,build,dist
ignore = E203, E266, E501, W503, D203
max-complexity = 10
max-line-length = 120
mypy_config = mypy.ini
per-file-ignores = __init__.py:F401

View File

@@ -10,7 +10,7 @@ extras = {}
# Packages required for installing docs.
extras["docs"] = ["recommonmark", "nbsphinx", "sphinx-autobuild", "sphinx-rtd-theme"]
# Packages required for formatting code & running tests.
extras["test"] = ["black", "isort", "pytest", "pytest-xdist"]
extras["test"] = ["black", "isort", "flake8", "pytest", "pytest-xdist"]
# For developers, install development tools along with all optional dependencies.
extras["dev"] = extras["docs"] + extras["test"]

View File

@@ -165,7 +165,7 @@ class TestAttackedText:
)
for old_idx, new_idx in enumerate(new_text.attack_attrs["original_index_map"]):
assert (attacked_text.words[old_idx] == new_text.words[new_idx]) or (
new_i == -1
new_idx == -1
)
new_text = (
new_text.delete_word_at_index(0)

View File

@@ -17,7 +17,7 @@ class AttackResult:
if original_result is None:
raise ValueError("Attack original result cannot be None")
elif not isinstance(original_result, GoalFunctionResult):
raise TypeError(f"Invalid original goal function result: {original_text}")
raise TypeError(f"Invalid original goal function result: {original_result}")
if perturbed_result is None:
raise ValueError("Attack perturbed result cannot be None")
elif not isinstance(perturbed_result, GoalFunctionResult):

View File

@@ -1,7 +1,10 @@
from abc import ABC, abstractmethod
import textattack
from textattack.shared.utils import default_class_repr
class Constraint:
class Constraint(ABC):
"""
An abstract class that represents constraints on adversarial text examples.
Constraints evaluate whether transformations from a ``AttackedText`` to another
@@ -68,9 +71,9 @@ class Constraint:
current_text: The current ``AttackedText``.
original_text: The original ``AttackedText`` from which the attack began.
"""
if not isinstance(transformed_text, AttackedText):
if not isinstance(transformed_text, textattack.shared.AttackedText):
raise TypeError("transformed_text must be of type AttackedText")
if not isinstance(current_text, AttackedText):
if not isinstance(current_text, textattack.shared.AttackedText):
raise TypeError("current_text must be of type AttackedText")
try:
@@ -86,6 +89,7 @@ class Constraint:
transformed_text, current_text, original_text=original_text
)
@abstractmethod
def _check_constraint(self, transformed_text, current_text, original_text=None):
"""
Returns True if the constraint is fulfilled, False otherwise. Must be overridden by

View File

@@ -56,7 +56,7 @@ class GoogleLanguageModel(Constraint):
[t.words[word_swap_index] for t in transformed_texts]
)
if self.print_step:
print(prefix, swapped_words, suffix)
print(prefix, swapped_words)
probs = self.lm.get_words_probs(prefix, swapped_words)
return probs

View File

@@ -3,7 +3,7 @@ from abc import ABC, abstractmethod
from textattack.constraints import Constraint
class LanguageModelConstraint(ABC, Constraint):
class LanguageModelConstraint(Constraint, ABC):
"""
Determines if two sentences have a swapped word that has a similar
probability according to a language model.

View File

@@ -1,7 +1,9 @@
from abc import ABC, abstractmethod
from textattack.constraints import Constraint
class PreTransformationConstraint(Constraint):
class PreTransformationConstraint(Constraint, ABC):
"""
An abstract class that represents constraints which are applied before
the transformation. These restrict which words are allowed to be modified
@@ -22,6 +24,7 @@ class PreTransformationConstraint(Constraint):
return set(range(len(current_text.words)))
return self._get_modifiable_indices(current_text)
@abstractmethod
def _get_modifiable_indices(current_text):
"""
Returns the word indices in ``current_text`` which are able to be modified.
@@ -31,3 +34,8 @@ class PreTransformationConstraint(Constraint):
current_text: The ``AttackedText`` input to consider.
"""
raise NotImplementedError()
def _check_constraint(self):
raise RuntimeError(
"PreTransformationConstraints do not support `_check_constraint()`."
)

View File

@@ -72,7 +72,9 @@ class SentenceEncoder(Constraint):
The similarity between the starting and transformed text using the metric.
"""
try:
modified_index = next(iter(x_adv.attack_attrs["newly_modified_indices"]))
modified_index = next(
iter(transformed_text.attack_attrs["newly_modified_indices"])
)
except KeyError:
raise KeyError(
"Cannot apply sentence encoder constraint without `newly_modified_indices`"
@@ -111,7 +113,7 @@ class SentenceEncoder(Constraint):
``transformed_texts``. If ``transformed_texts`` is empty,
an empty tensor is returned
"""
# Return an empty tensor if x_adv_list is empty.
# Return an empty tensor if transformed_texts is empty.
# This prevents us from calling .repeat(x, 0), which throws an
# error on machines with multiple GPUs (pytorch 1.2).
if len(transformed_texts) == 0:
@@ -207,7 +209,7 @@ class SentenceEncoder(Constraint):
"Must provide original text when compare_with_original is true."
)
else:
scores = self._sim_score(current_text, transformed_texts)
scores = self._sim_score(current_text, transformed_text)
transformed_text.attack_attrs["similarity_score"] = score
return score >= self.threshold

View File

@@ -47,7 +47,7 @@ class WordEmbeddingDistance(Constraint):
mse_dist_file = "mse_dist.p"
cos_sim_file = "cos_sim.p"
else:
raise ValueError(f"Could not find word embedding {word_embedding}")
raise ValueError(f"Could not find word embedding {embedding_type}")
# Download embeddings if they're not cached.
word_embeddings_path = utils.download_if_needed(WordEmbeddingDistance.PATH)

View File

@@ -18,7 +18,7 @@ class InputReduction(ClassificationGoalFunction):
def _is_goal_complete(self, model_output, attacked_text):
return (
self.ground_truth_output == model_output.argmax()
and attacked_text.num_words <= target_num_words
and attacked_text.num_words <= self.target_num_words
)
def _should_skip(self, model_output, attacked_text):

View File

@@ -1,3 +1,4 @@
from abc import ABC, abstractmethod
import math
import lru
@@ -11,7 +12,7 @@ from textattack.shared import utils, validators
from textattack.shared.utils import batch_model_predict, default_class_repr
class GoalFunction:
class GoalFunction(ABC):
"""
Evaluates how well a perturbed attacked_text object is achieving a specified goal.
@@ -124,24 +125,28 @@ class GoalFunction:
return GoalFunctionResultStatus.SUCCEEDED
return GoalFunctionResultStatus.SEARCHING
@abstractmethod
def _is_goal_complete(self, model_output, attacked_text):
raise NotImplementedError()
def _should_skip(self, model_output, attacked_text):
return self._is_goal_complete(model_output, attacked_text)
@abstractmethod
def _get_score(self, model_output, attacked_text):
raise NotImplementedError()
def _get_displayed_output(self, raw_output):
return raw_output
@abstractmethod
def _goal_function_result_type(self):
"""
Returns the class of this goal function's results.
"""
raise NotImplementedError()
@abstractmethod
def _process_model_outputs(self, inputs, outputs):
"""
Processes and validates a list of model outputs.

View File

@@ -62,7 +62,11 @@ class WordLevelTokenizer(hf_tokenizers.implementations.BaseTokenizer):
normalizers = []
if unicode_normalizer:
normalizers += [unicode_normalizer_from_str(unicode_normalizer)]
normalizers += [
hf_tokenizers.normalizers.unicode_normalizer_from_str(
unicode_normalizer
)
]
if lowercase:
normalizers += [hf_tokenizers.normalizers.Lowercase()]

View File

@@ -105,7 +105,8 @@ class GeneticAlgorithm(SearchMethod):
A population member containing the crossover.
"""
x1_text = pop_member1.attacked_text
x2_words = pop_member2.attacked_text.words
x2_text = pop_member2.attacked_text
x2_words = x2_text.words
num_tries = 0
passed_constraints = False

View File

@@ -1,7 +1,9 @@
from abc import ABC, abstractmethod
from textattack.shared.utils import default_class_repr
class SearchMethod:
class SearchMethod(ABC):
"""
This is an abstract class that contains main helper functionality for
search methods. A search method is a strategy for applying transformations
@@ -26,6 +28,7 @@ class SearchMethod:
)
return self._perform_search(initial_result)
@abstractmethod
def _perform_search(self, initial_result):
"""
Perturbs `attacked_text` from ``initial_result`` until goal is reached or search is

View File

@@ -24,7 +24,7 @@ def html_style_from_dict(style_dict):
def html_table_from_rows(rows, title=None, header=None, style_dict=None):
# Stylize the container div.
if style_dict:
table_html = "<div {}>".format(style_from_dict(style_dict))
table_html = "<div {}>".format(html_style_from_dict(style_dict))
else:
table_html = "<div>"
# Print the title string.

View File

@@ -24,7 +24,7 @@ class WordEmbedding:
mse_dist_file = "mse_dist.p"
cos_sim_file = "cos_sim.p"
else:
raise ValueError(f"Could not find word embedding {word_embedding}")
raise ValueError(f"Could not find word embedding {embedding_type}")
# Download embeddings if they're not cached.
word_embeddings_root_path = textattack.shared.utils.download_if_needed(

View File

@@ -21,6 +21,14 @@ class CompositeTransformation(Transformation):
raise ValueError("transformations cannot be empty")
self.transformations = transformations
def _get_transformations(self, *_):
""" Placeholder method that would throw an error if a user tried to
treat the CompositeTransformation as a 'normal' transformation.
"""
raise RuntimeError(
"CompositeTransformation does not support _get_transformations()."
)
def __call__(self, *args, **kwargs):
new_attacked_texts = set()
for transformation in self.transformations:

View File

@@ -1,7 +1,9 @@
from abc import ABC, abstractmethod
from textattack.shared.utils import default_class_repr
class Transformation:
class Transformation(ABC):
"""
An abstract class for transforming a sequence of text to produce
a potential adversarial example.
@@ -44,6 +46,7 @@ class Transformation:
text.attack_attrs["last_transformation"] = self
return transformed_texts
@abstractmethod
def _get_transformations(self, current_text, indices_to_modify):
"""
Returns a list of all possible transformations for ``current_text``, only modifying