1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

move config.yaml to aws; add post-install hook; make logger nice

This commit is contained in:
Jack Morris
2020-05-09 21:43:24 -04:00
parent e3e43e31eb
commit e1451e62ac
6 changed files with 68 additions and 42 deletions

View File

@@ -1 +0,0 @@
include textattack/config.json

View File

@@ -30,22 +30,11 @@ You should be running Python 3.6+ to use this package. A CUDA-compatible GPU is
pip install textattack
```
We use the NLTK package for its list of stopwords and for access to the WordNet lexical database. To download them, run the following in a Python shell:
```
import nltk
nltk.download('stopwords')
nltk.download('wordnet')
```
We use spaCy's English model. To download it, run the following after installing spaCy:
```
python -m spacy download en
```
### Cache
TextAttack provides pretrained models and datasets for user convenience. By default, these files are downloaded to `~/.cache`. You can change this location by editing the `CACHE_DIR` field in `config.json`.
### Configuration
TextAttack downloads files to `~/.cache/textattack/` by default. This includes
pretrained models, dataset samples, and the configuration file `config.yaml`.
Update `config.yaml` to change run-specific variables like the batch size
for model calls and the sizes of TextAttack's various caches.
## Usage

View File

@@ -5,6 +5,7 @@ lru-dict
nltk
numpy
pandas
pyyaml
scikit-learn
scipy
sentence_transformers

View File

@@ -5,7 +5,7 @@ with open("README.md", "r") as fh:
setuptools.setup(
name="textattack",
version="0.0.1.7",
version="0.0.1.8",
author="QData Lab at the University of Virginia",
author_email="jm8wx@virginia.edu",
description="A library for generating text adversarial examples",

View File

@@ -1,6 +0,0 @@
{
"CACHE_DIR": "~/.cache/textattack/",
"CONSTRAINT_CACHE_SIZE": 262144,
"MODEL_BATCH_SIZE": 32,
"MODEL_CACHE_SIZE": 262144
}

View File

@@ -1,5 +1,4 @@
import filelock
import json
import logging
import os
import pathlib
@@ -8,33 +7,30 @@ import shutil
import tempfile
import torch
import tqdm
import yaml
import zipfile
# Load `config.json` from the directory one level above this module, once,
# at import time.
dir_path = os.path.dirname(os.path.realpath(__file__))
config_path = os.path.join(dir_path, os.pardir, 'config.json')
# Read through a context manager so the file handle is closed as soon as the
# JSON has been parsed instead of being left to the garbage collector.
with open(config_path, 'r') as config_file:
    CONFIG = json.load(config_file)
# Expand a leading `~` so `CACHE_DIR` can be used directly in path operations.
CONFIG['CACHE_DIR'] = os.path.expanduser(CONFIG['CACHE_DIR'])
def config(key):
    """ Looks up `key` in the module-level `CONFIG` mapping and returns the
        value stored there.
    """
    value = CONFIG[key]
    return value
def get_logger():
    """ Returns the `logging` logger registered under this module's name. """
    module_logger = logging.getLogger(__name__)
    return module_logger
def get_device():
    """ Returns the `torch.device` to run on: "cuda" when
        `torch.cuda.is_available()` reports true, "cpu" otherwise.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")
def path_in_cache(file_path):
    """ Returns the location of `file_path` inside the TextAttack cache folder.

        The cache folder (the configured `CACHE_DIR`) is created first if it
        does not exist yet.

        Args:
            file_path (str): path relative to the cache folder

        Returns:
            str: `file_path` joined onto the cache folder path
    """
    textattack_cache_dir = config('CACHE_DIR')
    # `exist_ok=True` folds the exists-check and the creation into one call,
    # so a directory created in between by another process is not an error.
    os.makedirs(textattack_cache_dir, exist_ok=True)
    return os.path.join(textattack_cache_dir, file_path)
def s3_url(uri):
    """ Builds the full URL of `uri` inside the TextAttack S3 bucket by
        appending it to the bucket's base address.
    """
    bucket_root = 'https://textattack.s3.amazonaws.com/'
    return bucket_root + uri
def download_if_needed(folder_name):
""" Folder name will be saved as `.cache/textattack/[folder name]`. If it
doesn't exist, the zip file will be downloaded and extracted.
doesn't exist on disk, the zip file will be downloaded and extracted.
Args:
folder_name (str): path to folder or file in cache
Returns:
str: path to the downloaded folder or file on disk
"""
cache_dest_path = path_in_cache(folder_name)
os.makedirs(os.path.dirname(cache_dest_path), exist_ok=True)
@@ -48,7 +44,7 @@ def download_if_needed(folder_name):
return cache_dest_path
# If the file isn't found yet, download the zip file to the cache.
downloaded_file = tempfile.NamedTemporaryFile(
dir=CONFIG['CACHE_DIR'],
dir=config('CACHE_DIR'),
suffix='.zip', delete=False)
http_get(folder_name, downloaded_file)
# Move or unzip the file.
@@ -222,4 +218,51 @@ def has_letter(word):
""" Returns true if `word` contains at least one character in [A-Za-z]. """
for c in word:
if c.isalpha(): return True
return False
return False
# Bold blue "textattack" tag (ANSI escape codes) used as the log line prefix.
LOG_STRING = '\033[34;1mtextattack\033[0m'
# Shared logger instance; stays None until `get_logger` is first called.
logger = None


def get_logger():
    """ Returns the shared logger for this module, building it on first use.

        The first call runs `logging.basicConfig` at INFO level, attaches a
        stream handler whose format prefixes every message with `LOG_STRING`,
        and disables propagation on the logger. Later calls return the same
        object.
    """
    global logger
    if logger:
        return logger
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter(LOG_STRING + ' - %(message)s'))
    logger.addHandler(handler)
    logger.propagate = False
    return logger
def _post_install():
    """ Fetches the extra spaCy and NLTK data, logging each step.

        Calls spaCy's download command for 'en', then `nltk.download` for
        'wordnet', 'averaged_perceptron_tagger', 'universal_tagset' and
        'stopwords', in that order.
    """
    log = get_logger()
    log.info('First time importing textattack: downloading remaining required packages.')
    log.info('Downloading spaCy required packages.')
    # Imported here rather than at module level, as in the original code.
    import spacy
    spacy.cli.download('en')
    log.info('Downloading NLTK required packages.')
    import nltk
    nltk_resources = (
        'wordnet',
        'averaged_perceptron_tagger',
        'universal_tagset',
        'stopwords',
    )
    for resource in nltk_resources:
        nltk.download(resource)
def _post_install_if_needed():
    """ Runs `_post_install` unless the marker file shows it already ran.

        The marker is an empty file named `post_install_check` inside the
        configured `CACHE_DIR`. It is written only after `_post_install`
        returns, so a failed post-install is attempted again on the next call.
    """
    cache_dir = config('CACHE_DIR')
    post_install_file_path = os.path.join(cache_dir, 'post_install_check')
    # Marker present: post-install already completed.
    if os.path.exists(post_install_file_path):
        return
    # `CACHE_DIR` comes from configuration and is not guaranteed to exist on
    # disk. Create it before the downloads so that writing the marker
    # afterwards cannot fail on a missing directory.
    os.makedirs(cache_dir, exist_ok=True)
    _post_install()
    # Create the empty marker file that records a completed post-install.
    with open(post_install_file_path, 'w'):
        pass
def config(key):
    """ Returns the value stored under `key` in the module-level
        `config_dict`.
    """
    setting = config_dict[key]
    return setting
# Seed the configuration with a default `CACHE_DIR` before fetching
# `config.yaml`, so that `config('CACHE_DIR')` already resolves while the
# file is being downloaded into the cache.
config_dict = {'CACHE_DIR': os.path.expanduser('~/.cache/textattack')}
config_path = download_if_needed('config.yaml')
# NOTE(review): this parses a downloaded file with `yaml.FullLoader`; consider
# `yaml.safe_load` if the file only ever holds plain mappings and scalars.
with open(config_path, 'r') as config_file:
    loaded_config = yaml.load(config_file, Loader=yaml.FullLoader)
# An empty YAML document parses to `None`; treat that as "no overrides"
# instead of letting `dict.update(None)` raise a TypeError.
config_dict.update(loaded_config or {})
# The file may supply its own `CACHE_DIR`; expand a leading `~` in it.
config_dict['CACHE_DIR'] = os.path.expanduser(config_dict['CACHE_DIR'])
_post_install_if_needed()