1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

Reformatting-try2

This commit is contained in:
Hanyu Liu
2020-07-10 21:49:20 -04:00
parent 974061c0aa
commit 4f8b227ef9
59 changed files with 156 additions and 93 deletions

3
.idea/.gitignore generated vendored Normal file
View File

@@ -0,0 +1,3 @@
# Default ignored files
/shelf/
/workspace.xml

15
.idea/TextAttack.iml generated Normal file
View File

@@ -0,0 +1,15 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="PyDocumentationSettings">
<option name="format" value="PLAIN" />
<option name="myDocStringFormat" value="Plain" />
</component>
<component name="TestRunnerService">
<option name="PROJECT_TEST_RUNNER" value="pytest" />
</component>
</module>

View File

@@ -0,0 +1,27 @@
<component name="InspectionProjectProfileManager">
<profile version="1.0">
<option name="myName" value="Project Default" />
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
<option name="ignoredPackages">
<value>
<list size="14">
<item index="0" class="java.lang.String" itemvalue="tokenizers" />
<item index="1" class="java.lang.String" itemvalue="transformers" />
<item index="2" class="java.lang.String" itemvalue="tensorboardX" />
<item index="3" class="java.lang.String" itemvalue="bert-score" />
<item index="4" class="java.lang.String" itemvalue="flair" />
<item index="5" class="java.lang.String" itemvalue="pandas" />
<item index="6" class="java.lang.String" itemvalue="tqdm" />
<item index="7" class="java.lang.String" itemvalue="six" />
<item index="8" class="java.lang.String" itemvalue="joblib" />
<item index="9" class="java.lang.String" itemvalue="scikit-learn" />
<item index="10" class="java.lang.String" itemvalue="certifi" />
<item index="11" class="java.lang.String" itemvalue="sklearn" />
<item index="12" class="java.lang.String" itemvalue="numpy" />
<item index="13" class="java.lang.String" itemvalue="pytz" />
</list>
</value>
</option>
</inspection_tool>
</profile>
</component>

View File

@@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>

4
.idea/misc.xml generated Normal file
View File

@@ -0,0 +1,4 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.7 (textattack)" project-jdk-type="Python SDK" />
</project>

8
.idea/modules.xml generated Normal file
View File

@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/TextAttack.iml" filepath="$PROJECT_DIR$/.idea/TextAttack.iml" />
</modules>
</component>
</project>

6
.idea/vcs.xml generated Normal file
View File

@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="$PROJECT_DIR$" vcs="Git" />
</component>
</project>

View File

@@ -18,7 +18,7 @@ augment_test_params = [
def test_command_line_augmentation(name, command, outfile, sample_output_file):
import os
desired_text = open(sample_output_file).read().strip()
# desired_text = open(sample_output_file).read().strip()
# Run command and validate outputs.
result = run_command_and_get_result(command)

View File

@@ -1,5 +1,3 @@
name = "textattack"
from . import (
attack_recipes,
attack_results,
@@ -15,3 +13,5 @@ from . import (
shared,
transformations,
)
name = "textattack"

View File

@@ -1,9 +1,5 @@
from textattack.constraints.grammaticality.language_models import (
Google1BillionWordsLanguageModel,
)
from textattack.constraints.overlap import MaxWordsPerturbed
from textattack.constraints.pre_transformation import (
RepeatModification,
StopwordModification,
)
from textattack.constraints.semantics import WordEmbeddingDistance

View File

@@ -6,7 +6,7 @@ from textattack.constraints.pre_transformation import (
from textattack.goal_functions import UntargetedClassification
from textattack.search_methods import ParticleSwarmOptimization
from textattack.shared.attack import Attack
from textattack.transformations import WordSwapEmbedding, WordSwapHowNet
from textattack.transformations import WordSwapHowNet
def PSOZang2020(model):

View File

@@ -3,7 +3,7 @@ from textattack.constraints.pre_transformation import (
RepeatModification,
StopwordModification,
)
from textattack.constraints.semantics import WordEmbeddingDistance
# from textattack.constraints.semantics import WordEmbeddingDistance
from textattack.goal_functions import NonOverlappingOutput
from textattack.search_methods import GreedyWordSwapWIR
from textattack.shared.attack import Attack

View File

@@ -1,4 +1,4 @@
from textattack.constraints.grammaticality import PartOfSpeech
# from textattack.constraints.grammaticality import PartOfSpeech
from textattack.constraints.pre_transformation import (
RepeatModification,
StopwordModification,

View File

@@ -1,7 +1,8 @@
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
import textattack
from textattack.commands import TextAttackCommand
from textattack.commands.attack.attack_args import *
from textattack.commands.attack.attack_args import HUGGINGFACE_DATASET_BY_MODEL, TEXTATTACK_DATASET_BY_MODEL, SEARCH_METHOD_CLASS_NAMES, BLACK_BOX_TRANSFORMATION_CLASS_NAMES, WHITE_BOX_TRANSFORMATION_CLASS_NAMES, CONSTRAINT_CLASS_NAMES, GOAL_FUNCTION_CLASS_NAMES, ATTACK_RECIPE_NAMES
from textattack.commands.augment import AUGMENTATION_RECIPE_NAMES
@@ -56,7 +57,7 @@ class ListThingsCommand(TextAttackCommand):
try:
list_of_things = ListThingsCommand.things()[args.feature]
except KeyError:
raise ValuError(f"Unknown list key {args.thing}")
raise ValueError(f"Unknown list key {args.thing}")
self._list(list_of_things, plain=args.plain)
@staticmethod

View File

@@ -1,7 +1,7 @@
#!/usr/bin/env python
import argparse
import os
import sys
# import os
# import sys
from textattack.commands.attack import AttackCommand, AttackResumeCommand
from textattack.commands.augment import AugmentCommand

View File

@@ -21,7 +21,7 @@ class Constraint(ABC):
def call_many(self, transformed_texts, reference_text):
"""Filters ``transformed_texts`` based on which transformations fulfill
the constraint. First checks compatibility with latest
``Transformation``, then calls ``_check_constraint_many``\.
``Transformation``, then calls ``_check_constraint_many``
Args:
transformed_texts (list[AttackedText]): The candidate transformed ``AttackedText``'s.
@@ -48,10 +48,10 @@ class Constraint(ABC):
def _check_constraint_many(self, transformed_texts, reference_text):
"""Filters ``transformed_texts`` based on which transformations fulfill
the constraint. Calls ``check_constraint``\.
the constraint. Calls ``check_constraint``
Args:
transformed_texts (list[AttackedText]): The candidate transformed ``AttackedText``\s.
transformed_texts (list[AttackedText]): The candidate transformed ``AttackedText`` objects.
reference_texts (AttackedText): The ``AttackedText`` to compare against.
"""
return [
@@ -63,7 +63,7 @@ class Constraint(ABC):
def __call__(self, transformed_text, reference_text):
"""Returns True if the constraint is fulfilled, False otherwise. First
checks compatibility with latest ``Transformation``, then calls
``_check_constraint``\.
``_check_constraint``
Args:
transformed_text (AttackedText): The candidate transformed ``AttackedText``.

View File

@@ -3,9 +3,9 @@
All rights reserved.
"""
import os
import sys
# import sys
from google.protobuf import text_format
# from google.protobuf import text_format
import lru
import numpy as np
import tensorflow as tf

View File

@@ -62,7 +62,7 @@ class GoogleLanguageModel(Constraint):
for word_swap_index, item_list in word_swap_index_map.items():
# zip(*some_list) is the inverse operator of zip!
item_indices, this_transformed_texts = zip(*item_list)
t1 = time.time()
# t1 = time.time()
probs_of_swaps_at_index = list(
zip(item_indices, get_probs(reference_text, this_transformed_texts))
)
@@ -73,7 +73,7 @@ class GoogleLanguageModel(Constraint):
: self.top_n_per_index
]
probs.extend(probs_of_swaps_at_index)
t2 = time.time()
# t2 = time.time()
# Probs is a list of (index, prob) where index is the corresponding
# position in transformed_texts.

View File

@@ -25,7 +25,7 @@ class QueryHandler:
"""
try:
return self.try_query(sentences, swapped_words, batch_size=batch_size)
except:
except Exception:
probs = []
for s, w in zip(sentences, swapped_words):
probs.append(self.try_query([s], [w], batch_size=1)[0])

View File

@@ -4,7 +4,7 @@ import lru
import nltk
from textattack.constraints import Constraint
from textattack.shared import AttackedText
# from textattack.shared import AttackedText
from textattack.shared.validators import transformation_consists_of_word_swaps

View File

@@ -3,5 +3,3 @@ from .repeat_modification import RepeatModification
from .input_column_modification import InputColumnModification
from .max_word_index_modification import MaxWordIndexModification
from .min_word_length import MinWordLength
from .repeat_modification import RepeatModification
from .stopword_modification import StopwordModification

View File

@@ -1,5 +1,5 @@
from textattack.constraints import PreTransformationConstraint
from textattack.shared.utils import default_class_repr
# from textattack.shared.utils import default_class_repr
class MaxWordIndexModification(PreTransformationConstraint):

View File

@@ -1,5 +1,5 @@
from textattack.constraints import PreTransformationConstraint
from textattack.shared.utils import default_class_repr
# from textattack.shared.utils import default_class_repr
class RepeatModification(PreTransformationConstraint):

View File

@@ -1,7 +1,7 @@
import nltk
from textattack.constraints import PreTransformationConstraint
from textattack.shared.utils import default_class_repr
# from textattack.shared.utils import default_class_repr
from textattack.shared.validators import transformation_consists_of_word_swaps

View File

@@ -15,7 +15,7 @@ class PreTransformationConstraint(ABC):
def __call__(self, current_text, transformation):
"""Returns the word indices in ``current_text`` which are able to be
modified. First checks compatibility with ``transformation`` then calls
``_get_modifiable_indices``\.
``_get_modifiable_indices``
Args:
current_text: The ``AttackedText`` input to consider.

View File

@@ -9,7 +9,7 @@ text.
"""
import bert_score
import nltk
# import nltk
from textattack.constraints import Constraint
from textattack.shared import utils

View File

@@ -1,9 +1,9 @@
import os
import numpy as np
# import numpy as np
import torch
from textattack.constraints import Constraint
# from textattack.constraints import Constraint
from textattack.constraints.semantics.sentence_encoders import SentenceEncoder
from textattack.shared import utils

View File

@@ -1,5 +1,5 @@
import math
import os
# import os
import numpy as np
import torch
@@ -67,7 +67,7 @@ class SentenceEncoder(Constraint):
Args:
starting_text: The ``AttackedText`` to use as a starting point.
transformed_text: A transformed ``AttackedText``\.
transformed_text: A transformed ``AttackedText``
Returns:
The similarity between the starting and transformed text using the metric.
@@ -106,7 +106,7 @@ class SentenceEncoder(Constraint):
Args:
starting_text: The ``AttackedText`` to use as a starting point.
transformed_texts: A list of transformed ``AttackedText``\s.
transformed_texts: A list of transformed ``AttackedText`` objects.
Returns:
A list with the similarity between the ``starting_text`` and each of

View File

@@ -2,8 +2,7 @@ import functools
import torch
from textattack.shared import AttackedText, WordEmbedding, utils
from textattack.shared import WordEmbedding, utils
from .sentence_encoder import SentenceEncoder

View File

@@ -1,6 +1,6 @@
import os
# import os
import tensorflow as tf
# import tensorflow as tf
import tensorflow_hub as hub
from textattack.constraints.semantics.sentence_encoders import SentenceEncoder

View File

@@ -5,7 +5,7 @@ import numpy as np
import torch
from textattack.constraints import Constraint
from textattack.shared import AttackedText, utils
from textattack.shared import utils
from textattack.shared.validators import transformation_consists_of_word_swaps

View File

@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
import pickle
import random
from abc import ABC
from textattack.shared import utils

View File

@@ -5,7 +5,7 @@ import nlp
import textattack
from textattack.datasets import TextAttackDataset
from textattack.shared import AttackedText
# from textattack.shared import AttackedText
def _cb(s):
@@ -60,7 +60,7 @@ class HuggingFaceNlpDataset(TextAttackDataset):
provided in the ``nlp`` version of the dataset.
- output_scale_factor (float): Factor to divide ground-truth outputs by.
Generally, TextAttack goal functions require model outputs
between 0 and 1. Some datasets test the model's \*correlation\*
between 0 and 1. Some datasets test the model's correlation
with ground-truth output, instead of its accuracy, so these
outputs may be scaled arbitrarily.
- shuffle (bool): Whether to shuffle the dataset on load.

View File

@@ -1,8 +1,8 @@
from abc import ABC, abstractmethod
import math
# import math
import lru
import numpy as np
# import numpy as np
import torch
from textattack.goal_function_results.goal_function_result import (

View File

@@ -1,5 +1,5 @@
import numpy as np
import torch
# import torch
from textattack.attack_results import FailedAttackResult, SkippedAttackResult

View File

@@ -1,10 +1,10 @@
import csv
import os
import sys
# import os
# import sys
import pandas as pd
from textattack.attack_results import FailedAttackResult
# from textattack.attack_results import FailedAttackResult
from textattack.shared import logger
from .logger import Logger

View File

@@ -1,4 +1,4 @@
import copy
# import copy
import os
import sys

View File

@@ -1,4 +1,4 @@
import copy
# import copy
import socket
from visdom import Visdom

View File

@@ -1,7 +1,7 @@
import torch
# import torch
import transformers
from textattack.shared import AttackedText
# from textattack.shared import AttackedText
class AutoTokenizer:

View File

@@ -1,11 +1,11 @@
import json
import os
# import os
import tempfile
import numpy as np
# import numpy as np
import tokenizers as hf_tokenizers
import textattack
# import textattack
class WordLevelTokenizer(hf_tokenizers.implementations.BaseTokenizer):

View File

@@ -1,5 +1,5 @@
from collections import deque
import os
# import os
import lru
import numpy as np
@@ -117,10 +117,10 @@ class Attack:
self, transformed_texts, current_text, original_text=None
):
"""Filters a list of potential transformed texts based on
``self.constraints``\.
``self.constraints``
Args:
transformed_texts: A list of candidate transformed ``AttackedText``\s to filter.
transformed_texts: A list of candidate transformed ``AttackedText`` objects to filter.
current_text: The current ``AttackedText`` on which the transformation was applied.
original_text: The original ``AttackedText`` from which the attack started.
"""
@@ -149,10 +149,10 @@ class Attack:
self, transformed_texts, current_text, original_text=None
):
"""Filters a list of potential transformed texts based on
``self.constraints``\. Checks cache first.
``self.constraints``. Checks cache first.
Args:
transformed_texts: A list of candidate transformed ``AttackedText``\s to filter.
transformed_texts: A list of candidate transformed ``AttackedText`` objects to filter.
current_text: The current ``AttackedText`` on which the transformation was applied.
original_text: The original ``AttackedText`` from which the attack started.
"""

View File

@@ -308,7 +308,7 @@ class AttackedText:
pass
new_attack_attrs["modified_indices"] = shifted_modified_indices
# Track insertions and deletions wrt original text.
original_modification_idx = i
# original_modification_idx = i
new_idx_map = new_attack_attrs["original_index_map"].copy()
if num_words_diff == -1:
new_idx_map[new_idx_map == i] = -1
@@ -400,11 +400,11 @@ class AttackedText:
# color the key.
else:
if key_color_method:
ck = lambda k: textattack.shared.utils.color_text(
k, key_color, key_color_method
)
def ck(k):
return textattack.shared.utils.color_text(k, key_color, key_color_method)
else:
ck = lambda k: k
def ck(k):
return k
return "\n".join(
f"{ck(key.capitalize())}: {value}"
for key, value in self._text_input.items()

View File

@@ -1,4 +1,4 @@
import logging
# import logging
import logging.config
import os
import pathlib
@@ -8,7 +8,7 @@ import zipfile
import filelock
import requests
import torch
# import torch
import tqdm
# Hide an error message from `tokenizers` if this process is forked.

View File

@@ -1,4 +1,4 @@
import importlib
# import importlib
import json
import os
import random

View File

@@ -1,7 +1,7 @@
import torch
import textattack
from textattack.shared import utils
# from textattack.shared import utils
def batch_tokenize(tokenizer, attacked_text_list):
@@ -112,7 +112,7 @@ def get_list_dim(ids):
def pad_lists(lists, pad_token=0):
"""Pads lists with trailing zeros to make them all the same length."""
max_list_len = max(len(l) for l in lists)
max_list_len = max(len(list) for list in lists)
for i in range(len(lists)):
lists[i] += [pad_token] * (max_list_len - len(lists[i]))
return lists

View File

@@ -1,8 +1,8 @@
import collections
# import collections
import re
import textattack
from textattack.goal_functions import *
from textattack.goal_functions import TargetedClassification, UntargetedClassification, InputReduction, NonOverlappingOutput
from . import logger
@@ -43,7 +43,7 @@ def validate_model_goal_function_compatibility(goal_function_class, model_class)
except KeyError:
raise ValueError(f"No entry found for goal function {goal_function_class}.")
# Get options for this goal function.
model_module = model_class.__module__
# model_module = model_class.__module__
model_module_path = ".".join((model_class.__module__, model_class.__name__))
# Ensure the model matches one of these options.
for glob in matching_model_globs:

View File

@@ -1,4 +1,4 @@
import numpy as np
# import numpy as np
from textattack.shared import utils
from textattack.transformations.transformation import Transformation
@@ -9,7 +9,7 @@ class CompositeTransformation(Transformation):
returning a set of all options.
Args:
transformations: The list of ``Transformation``\s to apply.
transformations: The list of ``Transformation`` objects to apply.
"""
def __init__(self, transformations):

View File

@@ -20,7 +20,7 @@ class Transformation(ABC):
Args:
current_text: The ``AttackedText`` to transform.
pre_transformation_constraints: The ``PreTransformationConstraint``\s to apply before
pre_transformation_constraints: The ``PreTransformationConstraint`` objects to apply before
beginning the transformation.
indices_to_modify: Which word indices should be modified as dictated by the
``SearchMethod``.

View File

@@ -9,7 +9,7 @@ class WordDeletion(Transformation):
"""
def _get_transformations(self, current_text, indices_to_modify):
words = current_text.words
# words = current_text.words
transformed_texts = []
if len(current_text.words) > 1:
for i in indices_to_modify:

View File

@@ -1,8 +1,8 @@
import random
import string
import nltk
from nltk.corpus import stopwords
# import nltk
# from nltk.corpus import stopwords
from .transformation import Transformation

View File

@@ -17,7 +17,7 @@ class WordSwapEmbedding(WordSwap):
self.max_candidates = max_candidates
self.embedding_type = embedding_type
if embedding_type == "paragramcf":
word_embeddings_folder = "paragramcf"
# word_embeddings_folder = "paragramcf"
word_embeddings_file = "paragram.npy"
word_list_file = "wordlist.pickle"
nn_matrix_file = "nn.npy"

View File

@@ -1,6 +1,6 @@
import numpy as np
from textattack.shared import utils
# from textattack.shared import utils
from textattack.transformations.word_swap import WordSwap

View File

@@ -1,6 +1,6 @@
import itertools
import numpy as np
# import numpy as np
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer
@@ -87,7 +87,7 @@ class WordSwapMaskedLM(WordSwap):
mask_token_probs = preds[0, masked_index]
topk = torch.topk(mask_token_probs, self.max_candidates)
top_logits = topk[0].tolist()
# top_logits = topk[0].tolist()
top_ids = topk[1].tolist()
replacement_words = []
@@ -181,7 +181,7 @@ class WordSwapMaskedLM(WordSwap):
raise ValueError(f"Unrecognized value {self.method} for `self.method`.")
def _get_transformations(self, current_text, indices_to_modify):
extra_args = {}
# extra_args = {}
if self.method == "bert-attack":
current_inputs = self._encode_text(current_text.text)
with torch.no_grad():

View File

@@ -1,6 +1,6 @@
import numpy as np
from textattack.shared import utils
# from textattack.shared import utils
from textattack.transformations.word_swap import WordSwap

View File

@@ -1,4 +1,4 @@
import copy
# import copy
import random
from textattack.transformations.word_swap import WordSwap

View File

@@ -1,6 +1,6 @@
import numpy as np
from textattack.shared import utils
# from textattack.shared import utils
from textattack.transformations.word_swap import WordSwap

View File

@@ -1,6 +1,6 @@
import numpy as np
from textattack.shared import utils
# from textattack.shared import utils
from textattack.transformations.word_swap import WordSwap

View File

@@ -1,6 +1,6 @@
import numpy as np
from textattack.shared import utils
# from textattack.shared import utils
from textattack.transformations.word_swap import WordSwap

View File

@@ -13,8 +13,8 @@ class WordSwapWordNet(WordSwap):
replaced by a homoglyph."""
synonyms = set()
for syn in wordnet.synsets(word):
for l in syn.lemmas():
syn_word = l.name()
for lemma in syn.lemmas():
syn_word = lemma.name()
if (
(syn_word != word)
and ("_" not in syn_word)