mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
Reformatting-try2
This commit is contained in:
3
.idea/.gitignore
generated
vendored
Normal file
3
.idea/.gitignore
generated
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
# Default ignored files
|
||||
/shelf/
|
||||
/workspace.xml
|
||||
15
.idea/TextAttack.iml
generated
Normal file
15
.idea/TextAttack.iml
generated
Normal file
@@ -0,0 +1,15 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$" />
|
||||
<orderEntry type="inheritedJdk" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
<component name="PyDocumentationSettings">
|
||||
<option name="format" value="PLAIN" />
|
||||
<option name="myDocStringFormat" value="Plain" />
|
||||
</component>
|
||||
<component name="TestRunnerService">
|
||||
<option name="PROJECT_TEST_RUNNER" value="pytest" />
|
||||
</component>
|
||||
</module>
|
||||
27
.idea/inspectionProfiles/Project_Default.xml
generated
Normal file
27
.idea/inspectionProfiles/Project_Default.xml
generated
Normal file
@@ -0,0 +1,27 @@
|
||||
<component name="InspectionProjectProfileManager">
|
||||
<profile version="1.0">
|
||||
<option name="myName" value="Project Default" />
|
||||
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
|
||||
<option name="ignoredPackages">
|
||||
<value>
|
||||
<list size="14">
|
||||
<item index="0" class="java.lang.String" itemvalue="tokenizers" />
|
||||
<item index="1" class="java.lang.String" itemvalue="transformers" />
|
||||
<item index="2" class="java.lang.String" itemvalue="tensorboardX" />
|
||||
<item index="3" class="java.lang.String" itemvalue="bert-score" />
|
||||
<item index="4" class="java.lang.String" itemvalue="flair" />
|
||||
<item index="5" class="java.lang.String" itemvalue="pandas" />
|
||||
<item index="6" class="java.lang.String" itemvalue="tqdm" />
|
||||
<item index="7" class="java.lang.String" itemvalue="six" />
|
||||
<item index="8" class="java.lang.String" itemvalue="joblib" />
|
||||
<item index="9" class="java.lang.String" itemvalue="scikit-learn" />
|
||||
<item index="10" class="java.lang.String" itemvalue="certifi" />
|
||||
<item index="11" class="java.lang.String" itemvalue="sklearn" />
|
||||
<item index="12" class="java.lang.String" itemvalue="numpy" />
|
||||
<item index="13" class="java.lang.String" itemvalue="pytz" />
|
||||
</list>
|
||||
</value>
|
||||
</option>
|
||||
</inspection_tool>
|
||||
</profile>
|
||||
</component>
|
||||
6
.idea/inspectionProfiles/profiles_settings.xml
generated
Normal file
6
.idea/inspectionProfiles/profiles_settings.xml
generated
Normal file
@@ -0,0 +1,6 @@
|
||||
<component name="InspectionProjectProfileManager">
|
||||
<settings>
|
||||
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||
<version value="1.0" />
|
||||
</settings>
|
||||
</component>
|
||||
4
.idea/misc.xml
generated
Normal file
4
.idea/misc.xml
generated
Normal file
@@ -0,0 +1,4 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.7 (textattack)" project-jdk-type="Python SDK" />
|
||||
</project>
|
||||
8
.idea/modules.xml
generated
Normal file
8
.idea/modules.xml
generated
Normal file
@@ -0,0 +1,8 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectModuleManager">
|
||||
<modules>
|
||||
<module fileurl="file://$PROJECT_DIR$/.idea/TextAttack.iml" filepath="$PROJECT_DIR$/.idea/TextAttack.iml" />
|
||||
</modules>
|
||||
</component>
|
||||
</project>
|
||||
6
.idea/vcs.xml
generated
Normal file
6
.idea/vcs.xml
generated
Normal file
@@ -0,0 +1,6 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="VcsDirectoryMappings">
|
||||
<mapping directory="$PROJECT_DIR$" vcs="Git" />
|
||||
</component>
|
||||
</project>
|
||||
@@ -18,7 +18,7 @@ augment_test_params = [
|
||||
def test_command_line_augmentation(name, command, outfile, sample_output_file):
|
||||
import os
|
||||
|
||||
desired_text = open(sample_output_file).read().strip()
|
||||
# desired_text = open(sample_output_file).read().strip()
|
||||
|
||||
# Run command and validate outputs.
|
||||
result = run_command_and_get_result(command)
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
name = "textattack"
|
||||
|
||||
from . import (
|
||||
attack_recipes,
|
||||
attack_results,
|
||||
@@ -15,3 +13,5 @@ from . import (
|
||||
shared,
|
||||
transformations,
|
||||
)
|
||||
|
||||
name = "textattack"
|
||||
|
||||
@@ -1,9 +1,5 @@
|
||||
from textattack.constraints.grammaticality.language_models import (
|
||||
Google1BillionWordsLanguageModel,
|
||||
)
|
||||
from textattack.constraints.overlap import MaxWordsPerturbed
|
||||
from textattack.constraints.pre_transformation import (
|
||||
RepeatModification,
|
||||
StopwordModification,
|
||||
)
|
||||
from textattack.constraints.semantics import WordEmbeddingDistance
|
||||
|
||||
@@ -6,7 +6,7 @@ from textattack.constraints.pre_transformation import (
|
||||
from textattack.goal_functions import UntargetedClassification
|
||||
from textattack.search_methods import ParticleSwarmOptimization
|
||||
from textattack.shared.attack import Attack
|
||||
from textattack.transformations import WordSwapEmbedding, WordSwapHowNet
|
||||
from textattack.transformations import WordSwapHowNet
|
||||
|
||||
|
||||
def PSOZang2020(model):
|
||||
|
||||
@@ -3,7 +3,7 @@ from textattack.constraints.pre_transformation import (
|
||||
RepeatModification,
|
||||
StopwordModification,
|
||||
)
|
||||
from textattack.constraints.semantics import WordEmbeddingDistance
|
||||
# from textattack.constraints.semantics import WordEmbeddingDistance
|
||||
from textattack.goal_functions import NonOverlappingOutput
|
||||
from textattack.search_methods import GreedyWordSwapWIR
|
||||
from textattack.shared.attack import Attack
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from textattack.constraints.grammaticality import PartOfSpeech
|
||||
# from textattack.constraints.grammaticality import PartOfSpeech
|
||||
from textattack.constraints.pre_transformation import (
|
||||
RepeatModification,
|
||||
StopwordModification,
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
|
||||
|
||||
import textattack
|
||||
from textattack.commands import TextAttackCommand
|
||||
from textattack.commands.attack.attack_args import *
|
||||
from textattack.commands.attack.attack_args import HUGGINGFACE_DATASET_BY_MODEL, TEXTATTACK_DATASET_BY_MODEL, SEARCH_METHOD_CLASS_NAMES, BLACK_BOX_TRANSFORMATION_CLASS_NAMES, WHITE_BOX_TRANSFORMATION_CLASS_NAMES, CONSTRAINT_CLASS_NAMES, GOAL_FUNCTION_CLASS_NAMES, ATTACK_RECIPE_NAMES
|
||||
from textattack.commands.augment import AUGMENTATION_RECIPE_NAMES
|
||||
|
||||
|
||||
@@ -56,7 +57,7 @@ class ListThingsCommand(TextAttackCommand):
|
||||
try:
|
||||
list_of_things = ListThingsCommand.things()[args.feature]
|
||||
except KeyError:
|
||||
raise ValuError(f"Unknown list key {args.thing}")
|
||||
raise ValueError(f"Unknown list key {args.thing}")
|
||||
self._list(list_of_things, plain=args.plain)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#!/usr/bin/env python
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
# import os
|
||||
# import sys
|
||||
|
||||
from textattack.commands.attack import AttackCommand, AttackResumeCommand
|
||||
from textattack.commands.augment import AugmentCommand
|
||||
|
||||
@@ -21,7 +21,7 @@ class Constraint(ABC):
|
||||
def call_many(self, transformed_texts, reference_text):
|
||||
"""Filters ``transformed_texts`` based on which transformations fulfill
|
||||
the constraint. First checks compatibility with latest
|
||||
``Transformation``, then calls ``_check_constraint_many``\.
|
||||
``Transformation``, then calls ``_check_constraint_many``
|
||||
|
||||
Args:
|
||||
transformed_texts (list[AttackedText]): The candidate transformed ``AttackedText``'s.
|
||||
@@ -48,10 +48,10 @@ class Constraint(ABC):
|
||||
|
||||
def _check_constraint_many(self, transformed_texts, reference_text):
|
||||
"""Filters ``transformed_texts`` based on which transformations fulfill
|
||||
the constraint. Calls ``check_constraint``\.
|
||||
the constraint. Calls ``check_constraint``
|
||||
|
||||
Args:
|
||||
transformed_texts (list[AttackedText]): The candidate transformed ``AttackedText``\s.
|
||||
transformed_texts (list[AttackedText]): The candidate transformed ``AttackedText``
|
||||
reference_texts (AttackedText): The ``AttackedText`` to compare against.
|
||||
"""
|
||||
return [
|
||||
@@ -63,7 +63,7 @@ class Constraint(ABC):
|
||||
def __call__(self, transformed_text, reference_text):
|
||||
"""Returns True if the constraint is fulfilled, False otherwise. First
|
||||
checks compatibility with latest ``Transformation``, then calls
|
||||
``_check_constraint``\.
|
||||
``_check_constraint``
|
||||
|
||||
Args:
|
||||
transformed_text (AttackedText): The candidate transformed ``AttackedText``.
|
||||
|
||||
@@ -3,9 +3,9 @@
|
||||
All rights reserved.
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
# import sys
|
||||
|
||||
from google.protobuf import text_format
|
||||
# from google.protobuf import text_format
|
||||
import lru
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
@@ -62,7 +62,7 @@ class GoogleLanguageModel(Constraint):
|
||||
for word_swap_index, item_list in word_swap_index_map.items():
|
||||
# zip(*some_list) is the inverse operator of zip!
|
||||
item_indices, this_transformed_texts = zip(*item_list)
|
||||
t1 = time.time()
|
||||
# t1 = time.time()
|
||||
probs_of_swaps_at_index = list(
|
||||
zip(item_indices, get_probs(reference_text, this_transformed_texts))
|
||||
)
|
||||
@@ -73,7 +73,7 @@ class GoogleLanguageModel(Constraint):
|
||||
: self.top_n_per_index
|
||||
]
|
||||
probs.extend(probs_of_swaps_at_index)
|
||||
t2 = time.time()
|
||||
# t2 = time.time()
|
||||
|
||||
# Probs is a list of (index, prob) where index is the corresponding
|
||||
# position in transformed_texts.
|
||||
|
||||
@@ -25,7 +25,7 @@ class QueryHandler:
|
||||
"""
|
||||
try:
|
||||
return self.try_query(sentences, swapped_words, batch_size=batch_size)
|
||||
except:
|
||||
except Exception:
|
||||
probs = []
|
||||
for s, w in zip(sentences, swapped_words):
|
||||
probs.append(self.try_query([s], [w], batch_size=1)[0])
|
||||
|
||||
@@ -4,7 +4,7 @@ import lru
|
||||
import nltk
|
||||
|
||||
from textattack.constraints import Constraint
|
||||
from textattack.shared import AttackedText
|
||||
# from textattack.shared import AttackedText
|
||||
from textattack.shared.validators import transformation_consists_of_word_swaps
|
||||
|
||||
|
||||
|
||||
@@ -3,5 +3,3 @@ from .repeat_modification import RepeatModification
|
||||
from .input_column_modification import InputColumnModification
|
||||
from .max_word_index_modification import MaxWordIndexModification
|
||||
from .min_word_length import MinWordLength
|
||||
from .repeat_modification import RepeatModification
|
||||
from .stopword_modification import StopwordModification
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from textattack.constraints import PreTransformationConstraint
|
||||
from textattack.shared.utils import default_class_repr
|
||||
# from textattack.shared.utils import default_class_repr
|
||||
|
||||
|
||||
class MaxWordIndexModification(PreTransformationConstraint):
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from textattack.constraints import PreTransformationConstraint
|
||||
from textattack.shared.utils import default_class_repr
|
||||
# from textattack.shared.utils import default_class_repr
|
||||
|
||||
|
||||
class RepeatModification(PreTransformationConstraint):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import nltk
|
||||
|
||||
from textattack.constraints import PreTransformationConstraint
|
||||
from textattack.shared.utils import default_class_repr
|
||||
# from textattack.shared.utils import default_class_repr
|
||||
from textattack.shared.validators import transformation_consists_of_word_swaps
|
||||
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ class PreTransformationConstraint(ABC):
|
||||
def __call__(self, current_text, transformation):
|
||||
"""Returns the word indices in ``current_text`` which are able to be
|
||||
modified. First checks compatibility with ``transformation`` then calls
|
||||
``_get_modifiable_indices``\.
|
||||
``_get_modifiable_indices``
|
||||
|
||||
Args:
|
||||
current_text: The ``AttackedText`` input to consider.
|
||||
|
||||
@@ -9,7 +9,7 @@ text.
|
||||
"""
|
||||
|
||||
import bert_score
|
||||
import nltk
|
||||
# import nltk
|
||||
|
||||
from textattack.constraints import Constraint
|
||||
from textattack.shared import utils
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
# import numpy as np
|
||||
import torch
|
||||
|
||||
from textattack.constraints import Constraint
|
||||
# from textattack.constraints import Constraint
|
||||
from textattack.constraints.semantics.sentence_encoders import SentenceEncoder
|
||||
from textattack.shared import utils
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import math
|
||||
import os
|
||||
# import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -67,7 +67,7 @@ class SentenceEncoder(Constraint):
|
||||
|
||||
Args:
|
||||
starting_text: The ``AttackedText``to use as a starting point.
|
||||
transformed_text: A transformed ``AttackedText``\.
|
||||
transformed_text: A transformed ``AttackedText``
|
||||
|
||||
Returns:
|
||||
The similarity between the starting and transformed text using the metric.
|
||||
@@ -106,7 +106,7 @@ class SentenceEncoder(Constraint):
|
||||
|
||||
Args:
|
||||
starting_text: The ``AttackedText``to use as a starting point.
|
||||
transformed_texts: A list of transformed ``AttackedText``\s.
|
||||
transformed_texts: A list of transformed ``AttackedText``
|
||||
|
||||
Returns:
|
||||
A list with the similarity between the ``starting_text`` and each of
|
||||
|
||||
@@ -2,8 +2,7 @@ import functools
|
||||
|
||||
import torch
|
||||
|
||||
from textattack.shared import AttackedText, WordEmbedding, utils
|
||||
|
||||
from textattack.shared import WordEmbedding, utils
|
||||
from .sentence_encoder import SentenceEncoder
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
# import os
|
||||
|
||||
import tensorflow as tf
|
||||
# import tensorflow as tf
|
||||
import tensorflow_hub as hub
|
||||
|
||||
from textattack.constraints.semantics.sentence_encoders import SentenceEncoder
|
||||
|
||||
@@ -5,7 +5,7 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from textattack.constraints import Constraint
|
||||
from textattack.shared import AttackedText, utils
|
||||
from textattack.shared import utils
|
||||
from textattack.shared.validators import transformation_consists_of_word_swaps
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
import pickle
|
||||
import random
|
||||
from abc import ABC
|
||||
|
||||
from textattack.shared import utils
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ import nlp
|
||||
|
||||
import textattack
|
||||
from textattack.datasets import TextAttackDataset
|
||||
from textattack.shared import AttackedText
|
||||
# from textattack.shared import AttackedText
|
||||
|
||||
|
||||
def _cb(s):
|
||||
@@ -60,7 +60,7 @@ class HuggingFaceNlpDataset(TextAttackDataset):
|
||||
provided in the ``nlp`` version of the dataset.
|
||||
- output_scale_factor (float): Factor to divide ground-truth outputs by.
|
||||
Generally, TextAttack goal functions require model outputs
|
||||
between 0 and 1. Some datasets test the model's \*correlation\*
|
||||
between 0 and 1. Some datasets test the model's correlation
|
||||
with ground-truth output, instead of its accuracy, so these
|
||||
outputs may be scaled arbitrarily.
|
||||
- shuffle (bool): Whether to shuffle the dataset on load.
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from abc import ABC, abstractmethod
|
||||
import math
|
||||
# import math
|
||||
|
||||
import lru
|
||||
import numpy as np
|
||||
# import numpy as np
|
||||
import torch
|
||||
|
||||
from textattack.goal_function_results.goal_function_result import (
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
# import torch
|
||||
|
||||
from textattack.attack_results import FailedAttackResult, SkippedAttackResult
|
||||
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import csv
|
||||
import os
|
||||
import sys
|
||||
# import os
|
||||
# import sys
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from textattack.attack_results import FailedAttackResult
|
||||
# from textattack.attack_results import FailedAttackResult
|
||||
from textattack.shared import logger
|
||||
|
||||
from .logger import Logger
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import copy
|
||||
# import copy
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import copy
|
||||
# import copy
|
||||
import socket
|
||||
|
||||
from visdom import Visdom
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import torch
|
||||
# import torch
|
||||
import transformers
|
||||
|
||||
from textattack.shared import AttackedText
|
||||
# from textattack.shared import AttackedText
|
||||
|
||||
|
||||
class AutoTokenizer:
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import json
|
||||
import os
|
||||
# import os
|
||||
import tempfile
|
||||
|
||||
import numpy as np
|
||||
# import numpy as np
|
||||
import tokenizers as hf_tokenizers
|
||||
|
||||
import textattack
|
||||
# import textattack
|
||||
|
||||
|
||||
class WordLevelTokenizer(hf_tokenizers.implementations.BaseTokenizer):
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from collections import deque
|
||||
import os
|
||||
# import os
|
||||
|
||||
import lru
|
||||
import numpy as np
|
||||
@@ -117,10 +117,10 @@ class Attack:
|
||||
self, transformed_texts, current_text, original_text=None
|
||||
):
|
||||
"""Filters a list of potential transformaed texts based on
|
||||
``self.constraints``\.
|
||||
``self.constraints``
|
||||
|
||||
Args:
|
||||
transformed_texts: A list of candidate transformed ``AttackedText``\s to filter.
|
||||
transformed_texts: A list of candidate transformed ``AttackedText`` to filter.
|
||||
current_text: The current ``AttackedText`` on which the transformation was applied.
|
||||
original_text: The original ``AttackedText`` from which the attack started.
|
||||
"""
|
||||
@@ -149,10 +149,10 @@ class Attack:
|
||||
self, transformed_texts, current_text, original_text=None
|
||||
):
|
||||
"""Filters a list of potential transformed texts based on
|
||||
``self.constraints``\. Checks cache first.
|
||||
``self.constraints`` Checks cache first.
|
||||
|
||||
Args:
|
||||
transformed_texts: A list of candidate transformed ``AttackedText``\s to filter.
|
||||
transformed_texts: A list of candidate transformed ``AttackedText`` to filter.
|
||||
current_text: The current ``AttackedText`` on which the transformation was applied.
|
||||
original_text: The original ``AttackedText`` from which the attack started.
|
||||
"""
|
||||
|
||||
@@ -308,7 +308,7 @@ class AttackedText:
|
||||
pass
|
||||
new_attack_attrs["modified_indices"] = shifted_modified_indices
|
||||
# Track insertions and deletions wrt original text.
|
||||
original_modification_idx = i
|
||||
# original_modification_idx = i
|
||||
new_idx_map = new_attack_attrs["original_index_map"].copy()
|
||||
if num_words_diff == -1:
|
||||
new_idx_map[new_idx_map == i] = -1
|
||||
@@ -400,11 +400,11 @@ class AttackedText:
|
||||
# color the key.
|
||||
else:
|
||||
if key_color_method:
|
||||
ck = lambda k: textattack.shared.utils.color_text(
|
||||
k, key_color, key_color_method
|
||||
)
|
||||
def ck(k):
|
||||
return textattack.shared.utils.color_text(k, key_color, key_color_method)
|
||||
else:
|
||||
ck = lambda k: k
|
||||
def ck(k):
|
||||
return k
|
||||
return "\n".join(
|
||||
f"{ck(key.capitalize())}: {value}"
|
||||
for key, value in self._text_input.items()
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import logging
|
||||
# import logging
|
||||
import logging.config
|
||||
import os
|
||||
import pathlib
|
||||
@@ -8,7 +8,7 @@ import zipfile
|
||||
|
||||
import filelock
|
||||
import requests
|
||||
import torch
|
||||
# import torch
|
||||
import tqdm
|
||||
|
||||
# Hide an error message from `tokenizers` if this process is forked.
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import importlib
|
||||
# import importlib
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import torch
|
||||
|
||||
import textattack
|
||||
from textattack.shared import utils
|
||||
# from textattack.shared import utils
|
||||
|
||||
|
||||
def batch_tokenize(tokenizer, attacked_text_list):
|
||||
@@ -112,7 +112,7 @@ def get_list_dim(ids):
|
||||
|
||||
def pad_lists(lists, pad_token=0):
|
||||
"""Pads lists with trailing zeros to make them all the same length."""
|
||||
max_list_len = max(len(l) for l in lists)
|
||||
max_list_len = max(len(list) for list in lists)
|
||||
for i in range(len(lists)):
|
||||
lists[i] += [pad_token] * (max_list_len - len(lists[i]))
|
||||
return lists
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import collections
|
||||
# import collections
|
||||
import re
|
||||
|
||||
import textattack
|
||||
from textattack.goal_functions import *
|
||||
from textattack.goal_functions import TargetedClassification, UntargetedClassification, InputReduction, NonOverlappingOutput
|
||||
|
||||
from . import logger
|
||||
|
||||
@@ -43,7 +43,7 @@ def validate_model_goal_function_compatibility(goal_function_class, model_class)
|
||||
except KeyError:
|
||||
raise ValueError(f"No entry found for goal function {goal_function_class}.")
|
||||
# Get options for this goal function.
|
||||
model_module = model_class.__module__
|
||||
# model_module = model_class.__module__
|
||||
model_module_path = ".".join((model_class.__module__, model_class.__name__))
|
||||
# Ensure the model matches one of these options.
|
||||
for glob in matching_model_globs:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import numpy as np
|
||||
# import numpy as np
|
||||
|
||||
from textattack.shared import utils
|
||||
from textattack.transformations.transformation import Transformation
|
||||
@@ -9,7 +9,7 @@ class CompositeTransformation(Transformation):
|
||||
returning a set of all optoins.
|
||||
|
||||
Args:
|
||||
transformations: The list of ``Transformation``\s to apply.
|
||||
transformations: The list of ``Transformation`` to apply.
|
||||
"""
|
||||
|
||||
def __init__(self, transformations):
|
||||
|
||||
@@ -20,7 +20,7 @@ class Transformation(ABC):
|
||||
|
||||
Args:
|
||||
current_text: The ``AttackedText`` to transform.
|
||||
pre_transformation_constraints: The ``PreTransformationConstraint``\s to apply before
|
||||
pre_transformation_constraints: The ``PreTransformationConstraint`` to apply before
|
||||
beginning the transformation.
|
||||
indices_to_modify: Which word indices should be modified as dictated by the
|
||||
``SearchMethod``.
|
||||
|
||||
@@ -9,7 +9,7 @@ class WordDeletion(Transformation):
|
||||
"""
|
||||
|
||||
def _get_transformations(self, current_text, indices_to_modify):
|
||||
words = current_text.words
|
||||
# words = current_text.words
|
||||
transformed_texts = []
|
||||
if len(current_text.words) > 1:
|
||||
for i in indices_to_modify:
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import random
|
||||
import string
|
||||
|
||||
import nltk
|
||||
from nltk.corpus import stopwords
|
||||
# import nltk
|
||||
# from nltk.corpus import stopwords
|
||||
|
||||
from .transformation import Transformation
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ class WordSwapEmbedding(WordSwap):
|
||||
self.max_candidates = max_candidates
|
||||
self.embedding_type = embedding_type
|
||||
if embedding_type == "paragramcf":
|
||||
word_embeddings_folder = "paragramcf"
|
||||
# word_embeddings_folder = "paragramcf"
|
||||
word_embeddings_file = "paragram.npy"
|
||||
word_list_file = "wordlist.pickle"
|
||||
nn_matrix_file = "nn.npy"
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import numpy as np
|
||||
|
||||
from textattack.shared import utils
|
||||
# from textattack.shared import utils
|
||||
from textattack.transformations.word_swap import WordSwap
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import itertools
|
||||
|
||||
import numpy as np
|
||||
# import numpy as np
|
||||
import torch
|
||||
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
||||
|
||||
@@ -87,7 +87,7 @@ class WordSwapMaskedLM(WordSwap):
|
||||
|
||||
mask_token_probs = preds[0, masked_index]
|
||||
topk = torch.topk(mask_token_probs, self.max_candidates)
|
||||
top_logits = topk[0].tolist()
|
||||
# top_logits = topk[0].tolist()
|
||||
top_ids = topk[1].tolist()
|
||||
|
||||
replacement_words = []
|
||||
@@ -181,7 +181,7 @@ class WordSwapMaskedLM(WordSwap):
|
||||
raise ValueError(f"Unrecognized value {self.method} for `self.method`.")
|
||||
|
||||
def _get_transformations(self, current_text, indices_to_modify):
|
||||
extra_args = {}
|
||||
# extra_args = {}
|
||||
if self.method == "bert-attack":
|
||||
current_inputs = self._encode_text(current_text.text)
|
||||
with torch.no_grad():
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import numpy as np
|
||||
|
||||
from textattack.shared import utils
|
||||
# from textattack.shared import utils
|
||||
from textattack.transformations.word_swap import WordSwap
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import copy
|
||||
# import copy
|
||||
import random
|
||||
|
||||
from textattack.transformations.word_swap import WordSwap
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import numpy as np
|
||||
|
||||
from textattack.shared import utils
|
||||
# from textattack.shared import utils
|
||||
from textattack.transformations.word_swap import WordSwap
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import numpy as np
|
||||
|
||||
from textattack.shared import utils
|
||||
# from textattack.shared import utils
|
||||
from textattack.transformations.word_swap import WordSwap
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import numpy as np
|
||||
|
||||
from textattack.shared import utils
|
||||
# from textattack.shared import utils
|
||||
from textattack.transformations.word_swap import WordSwap
|
||||
|
||||
|
||||
|
||||
@@ -13,8 +13,8 @@ class WordSwapWordNet(WordSwap):
|
||||
replaced by a homoglyph."""
|
||||
synonyms = set()
|
||||
for syn in wordnet.synsets(word):
|
||||
for l in syn.lemmas():
|
||||
syn_word = l.name()
|
||||
for lemma in syn.lemmas():
|
||||
syn_word = lemma.name()
|
||||
if (
|
||||
(syn_word != word)
|
||||
and ("_" not in syn_word)
|
||||
|
||||
Reference in New Issue
Block a user