1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

make some dependencies optional

This commit is contained in:
Jin Yong Yoo
2020-10-30 17:31:00 -04:00
parent e6786b6bd7
commit 7a89aba559
10 changed files with 41 additions and 31 deletions

View File

@@ -28,7 +28,7 @@ jobs:
python -m pip install --upgrade pip setuptools wheel
pip install black flake8 isort # Testing packages
python setup.py install_egg_info # Workaround https://github.com/pypa/pip/issues/4537
pip install -e .
pip install -e . ["dev"]
- name: Check code format with black and isort
run: |
make lint

View File

@@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.6, 3.7, 3.8]
python-version: [3.6, 3.7, 3.8, 3.9]
steps:
- uses: actions/checkout@v2
@@ -29,7 +29,7 @@ jobs:
pip install pytest pytest-xdist # Testing packages
pip uninstall textattack --yes # Remove TA if it's already installed
python setup.py install_egg_info # Workaround https://github.com/pypa/pip/issues/4537
pip install .
pip install -e . ["dev"]
pip freeze
- name: Test with pytest
run: |

View File

@@ -10,19 +10,10 @@ nltk
numpy<1.19.0 #TF 2.0 requires this
pandas>=1.0.1
scipy==1.4.1
sentence_transformers>0.2.6
stanza
torch==1.6
transformers==3.3.0
tensorflow>=2
tensorflow_hub
tensorflow_text>=2
tensorboardX
terminaltables
tokenizers==0.8.1-rc2
tqdm
visdom
wandb
word2number
num2words
more-itertools

View File

@@ -13,13 +13,30 @@ extras["docs"] = ["recommonmark", "nbsphinx", "sphinx-autobuild", "sphinx-rtd-th
extras["test"] = [
"black==20.8b1",
"docformatter",
"isort==5.4.2",
"isort==5.6.4",
"flake8",
"pytest",
"pytest-xdist",
]
extras["tensorflow"] = [
"tensorflow>=2",
"tensorflow_hub",
"tensorflow_text>=2",
"tensorboardX",
]
extras["optional"] = [
"sentence_transformers>0.2.6",
"stanza",
"visdom",
"wandb",
]
# For developers, install development tools along with all optional dependencies.
extras["dev"] = extras["docs"] + extras["test"]
extras["dev"] = (
extras["docs"] + extras["test"] + extras["tensorflow"] + extras["optional"]
)
setuptools.setup(
name="textattack",

View File

@@ -408,7 +408,7 @@ def train_model(args):
# Use Weights & Biases, if enabled.
if args.enable_wandb:
global wandb
import wandb
wandb = textattack.shared.utils.LazyLoader("wandb", globals(), "wandb")
wandb.init(sync_tensorboard=True)

View File

@@ -6,12 +6,12 @@ Attack Logs to Visdom
import socket
from visdom import Visdom
from textattack.shared.utils import html_table_from_rows
from textattack.shared.utils import LazyLoader, html_table_from_rows
from .logger import Logger
visdom = LazyLoader("visdom", globals(), "visdom")
def port_is_open(port_num, hostname="127.0.0.1"):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
@@ -28,7 +28,7 @@ class VisdomLogger(Logger):
def __init__(self, env="main", port=8097, hostname="localhost"):
if not port_is_open(port, hostname=hostname):
raise socket.error(f"Visdom not running on {hostname}:{port}")
self.vis = Visdom(port=port, server=hostname, env=env)
self.vis = visdom.Visdom(port=port, server=hostname, env=env)
self.env = env
self.port = port
self.hostname = hostname
@@ -41,7 +41,7 @@ class VisdomLogger(Logger):
def __setstate__(self, state):
self.__dict__ = state
self.vis = Visdom(port=self.port, server=self.hostname, env=self.env)
self.vis = visdom.Visdom(port=self.port, server=self.hostname, env=self.env)
def log_attack_result(self, result):
text_a, text_b = result.diff_color(color_method="html")

View File

@@ -4,7 +4,7 @@ Attack Logs to WandB
"""
from textattack.shared.utils import html_table_from_rows
from textattack.shared.utils import LazyLoader, html_table_from_rows
from .logger import Logger
@@ -14,14 +14,14 @@ class WeightsAndBiasesLogger(Logger):
def __init__(self, filename="", stdout=False):
global wandb
import wandb
wandb = LazyLoader("wandb", globals(), "wandb")
wandb.init(project="textattack", resume=True)
self._result_table_rows = []
def __setstate__(self, state):
global wandb
import wandb
wandb = LazyLoader("wandb", globals(), "wandb")
self.__dict__ = state
wandb.init(project="textattack", resume=True)

View File

@@ -23,7 +23,9 @@ class LazyLoader(types.ModuleType):
module = importlib.import_module(self.__name__)
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
f"Lazy module loader cannot find module named `{self.__name__}`. Please run `pip install {self.__name__}`."
f"Lazy module loader cannot find module named `{self.__name__}`. "
f"This might be because TextAttack does not automatically install some optional dependencies. "
f"Please run `pip install {self.__name__}` to install the package."
) from e
self._parent_module_globals[self._local_name] = module

View File

@@ -116,10 +116,6 @@ def _post_install():
nltk.download("wordnet")
nltk.download("punkt")
import stanza
stanza.download("en")
def set_cache_dir(cache_dir):
"""Sets all relevant cache directories to ``TA_CACHE_DIR``."""

View File

@@ -1,3 +1,6 @@
import textattack
def has_letter(word):
"""Returns true if `word` contains at least one character in [A-Za-z]."""
# TODO implement w regex
@@ -199,13 +202,14 @@ def zip_flair_result(pred, tag_type="pos-fast"):
return word_list, pos_list
stanza = textattack.shared.utils.LazyLoader("stanza", globals(), "stanza")
def zip_stanza_result(pred, tagset="universal"):
"""Takes the first sentence from a document from `stanza` and returns two
lists, one of words and the other of their corresponding parts-of-
speech."""
from stanza.models.common.doc import Document
if not isinstance(pred, Document):
if not isinstance(pred, stanza.models.common.doc.Document):
raise TypeError("Result from Stanza POS tagger must be a `Document` object.")
word_list = []