mirror of https://github.com/QData/TextAttack.git (synced 2021-10-13 00:05:06 +03:00)
fix flair+isort version bug; fix alzantot abstract constraint class bug

Makefile (4 changed lines)
@@ -1,11 +1,11 @@
 format: FORCE ## Run black and isort (rewriting files)
 	black .
-	isort --atomic --recursive tests textattack
+	isort --atomic tests textattack

 lint: FORCE ## Run black, isort, flake8 (in check mode)
 	black . --check
-	isort --check-only --recursive tests textattack
+	isort --check-only tests textattack
 	flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --exclude=./.*,build,dist # catch certain syntax errors using flake8

 test: FORCE ## Run tests using pytest
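
Note: isort 5 dropped the --recursive flag; directory arguments are now walked automatically, so both targets simply remove it. A minimal sketch of the equivalent call through isort's Python API, assuming isort>=5 (the file path is hypothetical, and `atomic` is isort's write-only-if-the-result-still-parses setting):

import isort

# Sort one file in place; directories passed to the CLI are recursed by default now.
isort.file("tests/test_example.py", atomic=True)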

requirements.txt
@@ -1,5 +1,6 @@
 bert-score
 editdistance
+flair==0.5.1
 filelock
 language_tool_python
 lru-dict
@@ -20,4 +21,3 @@ tokenizers==0.8.0-rc4
 tqdm
 visdom
 wandb
-flair
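
flair moves from an unpinned entry at the end of the file to a pinned flair==0.5.1 in sorted position, matching the PartOfSpeech fix further down that relies on the 0.5 predict API. A quick hedged runtime check (assuming flair exposes __version__, as recent releases do):

import flair

# Fails loudly if an environment still carries the old unpinned flair.
assert flair.__version__ == "0.5.1", flair.__version__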

setup.py (2 changed lines)
@@ -10,7 +10,7 @@ extras = {}
 # Packages required for installing docs.
 extras["docs"] = ["recommonmark", "nbsphinx", "sphinx-autobuild", "sphinx-rtd-theme"]
 # Packages required for formatting code & running tests.
-extras["test"] = ["black", "isort", "flake8", "pytest", "pytest-xdist"]
+extras["test"] = ["black", "isort==5.0.3", "flake8", "pytest", "pytest-xdist"]
 # For developers, install development tools along with all optional dependencies.
 extras["dev"] = extras["docs"] + extras["test"]
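
Since the pin lives in the "test" extra, plain installs are unaffected. A minimal sketch (package name and surrounding layout assumed from the hunk context) of how this extras dict typically feeds setuptools:

from setuptools import find_packages, setup

extras = {}
extras["docs"] = ["recommonmark", "nbsphinx", "sphinx-autobuild", "sphinx-rtd-theme"]
extras["test"] = ["black", "isort==5.0.3", "flake8", "pytest", "pytest-xdist"]
extras["dev"] = extras["docs"] + extras["test"]

setup(
    name="textattack",  # assumed
    packages=find_packages(),
    extras_require=extras,  # enables `pip install textattack[test]`
)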

@@ -1,9 +1,8 @@
 import pdb
 import re

-import pytest
-
 from helpers import run_command_and_get_result
+import pytest

 DEBUG = False

@@ -1,6 +1,5 @@
-import pytest
-
 from helpers import run_command_and_get_result
+import pytest

 augment_test_params = [
     (

@@ -1,6 +1,5 @@
-import pytest
-
 from helpers import run_command_and_get_result
+import pytest

 list_test_params = [
     (

@@ -1,7 +1,8 @@
 def test_imports():
-    import textattack
     import torch
+
+    import textattack

     del textattack, torch
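
The four test-file hunks above are mechanical re-sorts from the isort 4 -> 5 upgrade: straight imports and from-imports are interleaved within a section, and first-party packages such as textattack get their own section after third-party ones. A hedged sketch via isort's Python API (force_sort_within_sections is an assumption about the project's isort config, which this diff does not show):

import isort

before = "import pytest\n\nfrom helpers import run_command_and_get_result\n"
print(isort.code(before, force_sort_within_sections=True))
# Expected, given the assumed settings:
#   from helpers import run_command_and_get_result
#   import pytest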

@@ -122,13 +122,12 @@ class CharSwapAugmenter(Augmenter):
     """ Augments words by swapping characters out for other characters. """

     def __init__(self, **kwargs):
-        from textattack.transformations import CompositeTransformation
         from textattack.transformations import (
+            CompositeTransformation,
+            WordSwapNeighboringCharacterSwap,
             WordSwapRandomCharacterDeletion,
             WordSwapRandomCharacterInsertion,
             WordSwapRandomCharacterSubstitution,
-            WordSwapNeighboringCharacterSwap,
         )

         transformation = CompositeTransformation(
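
Here isort 5 also merges the two imports from textattack.transformations into one parenthesized import and alphabetizes its members. A hedged usage sketch for the augmenter itself (import path per TextAttack's public API at the time; check your version):

from textattack.augmentation import CharSwapAugmenter

augmenter = CharSwapAugmenter()
# Returns a list of character-perturbed variants of the input string.
print(augmenter.augment("The quick brown fox"))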

@@ -51,6 +51,9 @@ class GoogleLanguageModel(Constraint):

         def get_probs(current_text, transformed_texts):
             word_swap_index = current_text.first_word_diff_index(transformed_texts[0])
+            if word_swap_index is None:
+                return []
+
             prefix = current_text.words[word_swap_index - 1]
             swapped_words = np.array(
                 [t.words[word_swap_index] for t in transformed_texts]
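
The new guard covers the case where the transformed text differs from the current text by no word at all: first_word_diff_index then has nothing to return, and indexing with its result would fail. A simplified standalone sketch of that contract (names compressed from the source):

def first_word_diff_index(words_a, words_b):
    """Index of the first position where the two word lists differ, else None."""
    for i, (a, b) in enumerate(zip(words_a, words_b)):
        if a != b:
            return i
    return None

assert first_word_diff_index(["a", "b"], ["a", "c"]) == 1
assert first_word_diff_index(["a", "b"], ["a", "b"]) is None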

@@ -75,6 +78,9 @@ class GoogleLanguageModel(Constraint):
             probs_of_swaps_at_index = list(
                 zip(item_indices, get_probs(current_text, this_transformed_texts))
             )
+            # if len(probs_of_swaps_at_index) == 0:
+            #     probs.extend(0)
+            #     continue
             # Sort by probability in descending order and take the top n for this index.
             probs_of_swaps_at_index.sort(key=lambda x: -x[1])
             if self.top_n_per_index:

@@ -104,6 +110,11 @@ class GoogleLanguageModel(Constraint):

         return [transformed_texts[i] for i in max_el_indices]

+    def _check_constraint(self, transformed_text, current_text, original_text=None):
+        return self._check_constraint_many(
+            [transformed_text], current_text, original_text=original_text
+        )
+
     def __call__(self, x, x_adv):
         raise NotImplementedError()
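
This is the "alzantot abstract constraint class" half of the commit: GoogleLanguageModel implements its check in batch form, but the abstract base class requires the single-text _check_constraint, so the subclass must delegate to the batched method or it cannot be used. A simplified, self-contained sketch of the pattern (class names compressed from the source):

from abc import ABC, abstractmethod


class Constraint(ABC):
    def _check_constraint_many(self, transformed_texts, current_text):
        # Default batched check filters through the single-text check.
        return [t for t in transformed_texts if self._check_constraint(t, current_text)]

    @abstractmethod
    def _check_constraint(self, transformed_text, current_text):
        raise NotImplementedError()


class BatchedConstraint(Constraint):
    def _check_constraint_many(self, transformed_texts, current_text):
        return transformed_texts  # the real logic is batched, as in GoogleLanguageModel

    # The fix: satisfy the abstract method by wrapping the batched path.
    def _check_constraint(self, transformed_text, current_text):
        return self._check_constraint_many([transformed_text], current_text)


BatchedConstraint()  # without _check_constraint this would raise TypeError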

@@ -1,5 +1,5 @@
+from torch import nn as nn
 from torch.autograd import Variable
-import torch.nn as nn

 from .adaptive_softmax import AdaptiveSoftmax
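
This rewrite (import torch.nn as nn -> from torch import nn as nn), repeated in the model files below, is isort 5.0's normalization of dotted alias imports into equivalent from-imports. Both forms bind the same module object, which this two-liner demonstrates (requires torch to be installed):

import torch.nn as nn_a
from torch import nn as nn_b

assert nn_a is nn_b  # identical module object either way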

@@ -48,9 +48,9 @@ class PartOfSpeech(Constraint):
             )

         if self.tagger_type == "flair":
-            word_list, pos_list = zip_flair_result(
-                self._flair_pos_tagger.predict(context_key)[0]
-            )
+            context_key_sentence = Sentence(context_key)
+            self._flair_pos_tagger.predict(context_key_sentence)
+            word_list, pos_list = zip_flair_result(context_key_sentence)

         self._pos_tag_cache[context_key] = (word_list, pos_list)
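
This is the flair half of the commit: in flair 0.5, SequenceTagger.predict mutates a Sentence in place rather than accepting a raw string and returning tagged sentences, so the constraint now builds a Sentence first. A hedged sketch of the 0.5-era API (the "pos" model shortcut is assumed available):

from flair.data import Sentence
from flair.models import SequenceTagger

tagger = SequenceTagger.load("pos")
sentence = Sentence("TextAttack perturbs words")
tagger.predict(sentence)  # tags are written onto `sentence`; nothing is returned
print([token.get_tag("pos").value for token in sentence])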

@@ -8,12 +8,11 @@
 """
 This file contains the definition of encoders used in https://arxiv.org/pdf/1705.02364.pdf
 """

 import time

 import numpy as np
 import torch
-import torch.nn as nn
+from torch import nn as nn


 class InferSentModel(nn.Module):

@@ -2,7 +2,7 @@ import os

 import numpy as np
 import torch
-import torch.nn as nn
+from torch import nn as nn

 from textattack.shared import logger, utils

@@ -1,5 +1,5 @@
 import torch
-import torch.nn as nn
+from torch import nn as nn

 import textattack
 from textattack.models.helpers import GloveEmbeddingLayer

@@ -1,6 +1,6 @@
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
+from torch import nn as nn
+from torch.nn import functional as F

 import textattack
 from textattack.models.helpers import GloveEmbeddingLayer