1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

fix flair+isort version bug; fix alzantot abstract constraint class bug

This commit is contained in:
Jack Morris
2020-07-06 10:57:50 -04:00
parent 99fedba104
commit 489fd92ace
15 changed files with 30 additions and 23 deletions

View File

@@ -1,11 +1,11 @@
format: FORCE ## Run black and isort (rewriting files)
black .
isort --atomic --recursive tests textattack
isort --atomic tests textattack
lint: FORCE ## Run black, isort, flake8 (in check mode)
black . --check
isort --check-only --recursive tests textattack
isort --check-only tests textattack
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --exclude=./.*,build,dist # catch certain syntax errors using flake8
test: FORCE ## Run tests using pytest

View File

@@ -1,5 +1,6 @@
bert-score
editdistance
flair==0.5.1
filelock
language_tool_python
lru-dict
@@ -20,4 +21,3 @@ tokenizers==0.8.0-rc4
tqdm
visdom
wandb
flair

View File

@@ -10,7 +10,7 @@ extras = {}
# Packages required for installing docs.
extras["docs"] = ["recommonmark", "nbsphinx", "sphinx-autobuild", "sphinx-rtd-theme"]
# Packages required for formatting code & running tests.
extras["test"] = ["black", "isort", "flake8", "pytest", "pytest-xdist"]
extras["test"] = ["black", "isort==5.0.3", "flake8", "pytest", "pytest-xdist"]
# For developers, install development tools along with all optional dependencies.
extras["dev"] = extras["docs"] + extras["test"]

View File

@@ -1,9 +1,8 @@
import pdb
import re
import pytest
from helpers import run_command_and_get_result
import pytest
DEBUG = False

View File

@@ -1,6 +1,5 @@
import pytest
from helpers import run_command_and_get_result
import pytest
augment_test_params = [
(

View File

@@ -1,6 +1,5 @@
import pytest
from helpers import run_command_and_get_result
import pytest
list_test_params = [
(

View File

@@ -1,7 +1,8 @@
def test_imports():
import textattack
import torch
import textattack
del textattack, torch

View File

@@ -122,13 +122,12 @@ class CharSwapAugmenter(Augmenter):
""" Augments words by swapping characters out for other characters. """
def __init__(self, **kwargs):
from textattack.transformations import CompositeTransformation
from textattack.transformations import (
CompositeTransformation,
WordSwapNeighboringCharacterSwap,
WordSwapRandomCharacterDeletion,
WordSwapRandomCharacterInsertion,
WordSwapRandomCharacterSubstitution,
WordSwapNeighboringCharacterSwap,
)
transformation = CompositeTransformation(

View File

@@ -51,6 +51,9 @@ class GoogleLanguageModel(Constraint):
def get_probs(current_text, transformed_texts):
word_swap_index = current_text.first_word_diff_index(transformed_texts[0])
if word_swap_index is None:
return []
prefix = current_text.words[word_swap_index - 1]
swapped_words = np.array(
[t.words[word_swap_index] for t in transformed_texts]
@@ -75,6 +78,9 @@ class GoogleLanguageModel(Constraint):
probs_of_swaps_at_index = list(
zip(item_indices, get_probs(current_text, this_transformed_texts))
)
# if len(probs_of_swaps_at_index) == 0:
# probs.extend(0)
# continue
# Sort by probability in descending order and take the top n for this index.
probs_of_swaps_at_index.sort(key=lambda x: -x[1])
if self.top_n_per_index:
@@ -104,6 +110,11 @@ class GoogleLanguageModel(Constraint):
return [transformed_texts[i] for i in max_el_indices]
def _check_constraint(self, transformed_text, current_text, original_text=None):
return self._check_constraint_many(
[transformed_text], current_text, original_text=original_text
)
def __call__(self, x, x_adv):
raise NotImplementedError()

View File

@@ -1,5 +1,5 @@
from torch import nn as nn
from torch.autograd import Variable
import torch.nn as nn
from .adaptive_softmax import AdaptiveSoftmax

View File

@@ -48,9 +48,9 @@ class PartOfSpeech(Constraint):
)
if self.tagger_type == "flair":
word_list, pos_list = zip_flair_result(
self._flair_pos_tagger.predict(context_key)[0]
)
context_key_sentence = Sentence(context_key)
self._flair_pos_tagger.predict(context_key_sentence)
word_list, pos_list = zip_flair_result(context_key_sentence)
self._pos_tag_cache[context_key] = (word_list, pos_list)

View File

@@ -8,12 +8,11 @@
"""
This file contains the definition of encoders used in https://arxiv.org/pdf/1705.02364.pdf
"""
import time
import numpy as np
import torch
import torch.nn as nn
from torch import nn as nn
class InferSentModel(nn.Module):

View File

@@ -2,7 +2,7 @@ import os
import numpy as np
import torch
import torch.nn as nn
from torch import nn as nn
from textattack.shared import logger, utils

View File

@@ -1,5 +1,5 @@
import torch
import torch.nn as nn
from torch import nn as nn
import textattack
from textattack.models.helpers import GloveEmbeddingLayer

View File

@@ -1,6 +1,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import nn as nn
from torch.nn import functional as F
import textattack
from textattack.models.helpers import GloveEmbeddingLayer