diff --git a/.github/workflows/check-formatting.yml b/.github/workflows/check-formatting.yml index f62d5caa..c7e3b10c 100644 --- a/.github/workflows/check-formatting.yml +++ b/.github/workflows/check-formatting.yml @@ -28,7 +28,7 @@ jobs: python -m pip install --upgrade pip setuptools wheel pip install black flake8 isort # Testing packages python setup.py install_egg_info # Workaround https://github.com/pypa/pip/issues/4537 - pip install -e . + pip install -e . ["dev"] - name: Check code format with black and isort run: | make lint diff --git a/.github/workflows/run-pytest.yml b/.github/workflows/run-pytest.yml index e185c68a..6c1b010a 100644 --- a/.github/workflows/run-pytest.yml +++ b/.github/workflows/run-pytest.yml @@ -15,7 +15,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.6, 3.7, 3.8] + python-version: [3.6, 3.7, 3.8, 3.9] steps: - uses: actions/checkout@v2 @@ -29,7 +29,7 @@ jobs: pip install pytest pytest-xdist # Testing packages pip uninstall textattack --yes # Remove TA if it's already installed python setup.py install_egg_info # Workaround https://github.com/pypa/pip/issues/4537 - pip install . + pip install -e . ["dev"] pip freeze - name: Test with pytest run: | diff --git a/requirements.txt b/requirements.txt index 975e530f..b1ec020b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,19 +10,10 @@ nltk numpy<1.19.0 #TF 2.0 requires this pandas>=1.0.1 scipy==1.4.1 -sentence_transformers>0.2.6 -stanza torch==1.6 transformers==3.3.0 -tensorflow>=2 -tensorflow_hub -tensorflow_text>=2 -tensorboardX terminaltables tokenizers==0.8.1-rc2 tqdm -visdom -wandb word2number num2words -more-itertools diff --git a/setup.py b/setup.py index 1993a6cc..07a5db3f 100644 --- a/setup.py +++ b/setup.py @@ -13,13 +13,30 @@ extras["docs"] = ["recommonmark", "nbsphinx", "sphinx-autobuild", "sphinx-rtd-th extras["test"] = [ "black==20.8b1", "docformatter", - "isort==5.4.2", + "isort==5.6.4", "flake8", "pytest", "pytest-xdist", ] + +extras["tensorflow"] = [ + "tensorflow>=2", + "tensorflow_hub", + "tensorflow_text>=2", + "tensorboardX", +] + +extras["optional"] = [ + "sentence_transformers>0.2.6", + "stanza", + "visdom", + "wandb", +] + # For developers, install development tools along with all optional dependencies. -extras["dev"] = extras["docs"] + extras["test"] +extras["dev"] = ( + extras["docs"] + extras["test"] + extras["tensorflow"] + extras["optional"] +) setuptools.setup( name="textattack", diff --git a/textattack/commands/train_model/run_training.py b/textattack/commands/train_model/run_training.py index fa6e7647..a0a92d8b 100644 --- a/textattack/commands/train_model/run_training.py +++ b/textattack/commands/train_model/run_training.py @@ -408,7 +408,7 @@ def train_model(args): # Use Weights & Biases, if enabled. if args.enable_wandb: global wandb - import wandb + wandb = textattack.shared.utils.LazyLoader("wandb", globals(), "wandb") wandb.init(sync_tensorboard=True) diff --git a/textattack/loggers/visdom_logger.py b/textattack/loggers/visdom_logger.py index e6c1bc8b..80ab1b90 100644 --- a/textattack/loggers/visdom_logger.py +++ b/textattack/loggers/visdom_logger.py @@ -6,12 +6,12 @@ Attack Logs to Visdom import socket -from visdom import Visdom - -from textattack.shared.utils import html_table_from_rows +from textattack.shared.utils import LazyLoader, html_table_from_rows from .logger import Logger +visdom = LazyLoader("visdom", globals(), "visdom") + def port_is_open(port_num, hostname="127.0.0.1"): sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) @@ -28,7 +28,7 @@ class VisdomLogger(Logger): def __init__(self, env="main", port=8097, hostname="localhost"): if not port_is_open(port, hostname=hostname): raise socket.error(f"Visdom not running on {hostname}:{port}") - self.vis = Visdom(port=port, server=hostname, env=env) + self.vis = visdom.Visdom(port=port, server=hostname, env=env) self.env = env self.port = port self.hostname = hostname @@ -41,7 +41,7 @@ class VisdomLogger(Logger): def __setstate__(self, state): self.__dict__ = state - self.vis = Visdom(port=self.port, server=self.hostname, env=self.env) + self.vis = visdom.Visdom(port=self.port, server=self.hostname, env=self.env) def log_attack_result(self, result): text_a, text_b = result.diff_color(color_method="html") diff --git a/textattack/loggers/weights_and_biases_logger.py b/textattack/loggers/weights_and_biases_logger.py index fa2cbe6b..523bb26e 100644 --- a/textattack/loggers/weights_and_biases_logger.py +++ b/textattack/loggers/weights_and_biases_logger.py @@ -4,7 +4,7 @@ Attack Logs to WandB """ -from textattack.shared.utils import html_table_from_rows +from textattack.shared.utils import LazyLoader, html_table_from_rows from .logger import Logger @@ -14,14 +14,14 @@ class WeightsAndBiasesLogger(Logger): def __init__(self, filename="", stdout=False): global wandb - import wandb + wandb = LazyLoader("wandb", globals(), "wandb") wandb.init(project="textattack", resume=True) self._result_table_rows = [] def __setstate__(self, state): global wandb - import wandb + wandb = LazyLoader("wandb", globals(), "wandb") self.__dict__ = state wandb.init(project="textattack", resume=True) diff --git a/textattack/shared/utils/importing.py b/textattack/shared/utils/importing.py index 21f8f453..dd570b61 100644 --- a/textattack/shared/utils/importing.py +++ b/textattack/shared/utils/importing.py @@ -23,7 +23,9 @@ class LazyLoader(types.ModuleType): module = importlib.import_module(self.__name__) except ModuleNotFoundError as e: raise ModuleNotFoundError( - f"Lazy module loader cannot find module named `{self.__name__}`. Please run `pip install {self.__name__}`." + f"Lazy module loader cannot find module named `{self.__name__}`. " + f"This might be because TextAttack does not automatically install some optional dependencies. " + f"Please run `pip install {self.__name__}` to install the package." ) from e self._parent_module_globals[self._local_name] = module diff --git a/textattack/shared/utils/install.py b/textattack/shared/utils/install.py index 83df8b05..2e0da3b7 100644 --- a/textattack/shared/utils/install.py +++ b/textattack/shared/utils/install.py @@ -116,10 +116,6 @@ def _post_install(): nltk.download("wordnet") nltk.download("punkt") - import stanza - - stanza.download("en") - def set_cache_dir(cache_dir): """Sets all relevant cache directories to ``TA_CACHE_DIR``.""" diff --git a/textattack/shared/utils/strings.py b/textattack/shared/utils/strings.py index 9c175a05..49243940 100644 --- a/textattack/shared/utils/strings.py +++ b/textattack/shared/utils/strings.py @@ -1,3 +1,6 @@ +import textattack + + def has_letter(word): """Returns true if `word` contains at least one character in [A-Za-z].""" # TODO implement w regex @@ -199,13 +202,14 @@ def zip_flair_result(pred, tag_type="pos-fast"): return word_list, pos_list +stanza = textattack.shared.utils.LazyLoader("stanza", globals(), "stanza") + + def zip_stanza_result(pred, tagset="universal"): """Takes the first sentence from a document from `stanza` and returns two lists, one of words and the other of their corresponding parts-of- speech.""" - from stanza.models.common.doc import Document - - if not isinstance(pred, Document): + if not isinstance(pred, stanza.models.common.doc.Document): raise TypeError("Result from Stanza POS tagger must be a `Document` object.") word_list = []