mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
move config.yaml to aws; add post-install hook; make logger nice
This commit is contained in:
@@ -1 +0,0 @@
|
||||
include textattack/config.json
|
||||
21
README.md
21
README.md
@@ -30,22 +30,11 @@ You should be running Python 3.6+ to use this package. A CUDA-compatible GPU is
|
||||
pip install textattack
|
||||
```
|
||||
|
||||
We use the NLTK package for its list of stopwords and access to the WordNet lexical database. To download them run in Python shell:
|
||||
|
||||
```
|
||||
import nltk
|
||||
nltk.download('stopwords')
|
||||
nltk.download('wordnet')
|
||||
```
|
||||
|
||||
We use spaCy's English model. To download it, after installing spaCy run:
|
||||
|
||||
```
|
||||
python -m spacy download en
|
||||
```
|
||||
|
||||
### Cache
|
||||
TextAttack provides pretrained models and datasets for user convenience. By default, all this stuff is downloaded to `~/.cache`. You can change this location by editing the `CACHE_DIR` field in `config.json`.
|
||||
### Configuration
|
||||
TextAttack downloads files to `~/.cache/textattack/` by default. This includes
|
||||
pretrained models, dataset samples, and the configuration file `config.yaml`.
|
||||
Update `config.yaml` to change run-specific variables like the batch size
|
||||
for model calls and the sizes of TextAttack's various caches.
|
||||
|
||||
## Usage
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ lru-dict
|
||||
nltk
|
||||
numpy
|
||||
pandas
|
||||
pyyaml
|
||||
scikit-learn
|
||||
scipy
|
||||
sentence_transformers
|
||||
|
||||
2
setup.py
2
setup.py
@@ -5,7 +5,7 @@ with open("README.md", "r") as fh:
|
||||
|
||||
setuptools.setup(
|
||||
name="textattack",
|
||||
version="0.0.1.7",
|
||||
version="0.0.1.8",
|
||||
author="QData Lab at the University of Virginia",
|
||||
author_email="jm8wx@virginia.edu",
|
||||
description="A library for generating text adversarial examples",
|
||||
|
||||
@@ -1,6 +0,0 @@
|
||||
{
|
||||
"CACHE_DIR": "~/.cache/textattack/",
|
||||
"CONSTRAINT_CACHE_SIZE": 262144,
|
||||
"MODEL_BATCH_SIZE": 32,
|
||||
"MODEL_CACHE_SIZE": 262144
|
||||
}
|
||||
@@ -1,5 +1,4 @@
|
||||
import filelock
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
@@ -8,33 +7,30 @@ import shutil
|
||||
import tempfile
|
||||
import torch
|
||||
import tqdm
|
||||
import yaml
|
||||
import zipfile
|
||||
|
||||
dir_path = os.path.dirname(os.path.realpath(__file__))
|
||||
config_path = os.path.join(dir_path, os.pardir, 'config.json')
|
||||
CONFIG = json.load(open(config_path, 'r'))
|
||||
CONFIG['CACHE_DIR'] = os.path.expanduser(CONFIG['CACHE_DIR'])
|
||||
|
||||
def config(key):
|
||||
return CONFIG[key]
|
||||
|
||||
def get_logger():
|
||||
return logging.getLogger(__name__)
|
||||
|
||||
def get_device():
|
||||
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
def path_in_cache(file_path):
|
||||
if not os.path.exists(CONFIG['CACHE_DIR']):
|
||||
os.makedirs(CONFIG['CACHE_DIR'])
|
||||
return os.path.join(CONFIG['CACHE_DIR'], file_path)
|
||||
textattack_cache_dir = config('CACHE_DIR')
|
||||
if not os.path.exists(textattack_cache_dir):
|
||||
os.makedirs(textattack_cache_dir)
|
||||
return os.path.join(textattack_cache_dir, file_path)
|
||||
|
||||
def s3_url(uri):
|
||||
return 'https://textattack.s3.amazonaws.com/' + uri
|
||||
|
||||
def download_if_needed(folder_name):
|
||||
""" Folder name will be saved as `.cache/textattack/[folder name]`. If it
|
||||
doesn't exist, the zip file will be downloaded and extracted.
|
||||
doesn't exist on disk, the zip file will be downloaded and extracted.
|
||||
|
||||
Args:
|
||||
folder_name (str): path to folder or file in cache
|
||||
|
||||
Returns:
|
||||
str: path to the downloaded folder or file on disk
|
||||
"""
|
||||
cache_dest_path = path_in_cache(folder_name)
|
||||
os.makedirs(os.path.dirname(cache_dest_path), exist_ok=True)
|
||||
@@ -48,7 +44,7 @@ def download_if_needed(folder_name):
|
||||
return cache_dest_path
|
||||
# If the file isn't found yet, download the zip file to the cache.
|
||||
downloaded_file = tempfile.NamedTemporaryFile(
|
||||
dir=CONFIG['CACHE_DIR'],
|
||||
dir=config('CACHE_DIR'),
|
||||
suffix='.zip', delete=False)
|
||||
http_get(folder_name, downloaded_file)
|
||||
# Move or unzip the file.
|
||||
@@ -222,4 +218,51 @@ def has_letter(word):
|
||||
""" Returns true if `word` contains at least one character in [A-Za-z]. """
|
||||
for c in word:
|
||||
if c.isalpha(): return True
|
||||
return False
|
||||
return False
|
||||
|
||||
LOG_STRING = f'\033[34;1mtextattack\033[0m'
|
||||
logger = None
|
||||
def get_logger():
|
||||
global logger
|
||||
if not logger:
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
formatter = logging.Formatter(f'{LOG_STRING} - %(message)s')
|
||||
stream_handler = logging.StreamHandler()
|
||||
stream_handler.setFormatter(formatter)
|
||||
logger.addHandler(stream_handler)
|
||||
logger.propagate = False
|
||||
return logger
|
||||
|
||||
def _post_install():
|
||||
logger = get_logger()
|
||||
logger.info('First time importing textattack: downloading remaining required packages.')
|
||||
logger.info('Downloading spaCy required packages.')
|
||||
import spacy
|
||||
spacy.cli.download('en')
|
||||
logger.info('Downloading NLTK required packages.')
|
||||
import nltk
|
||||
nltk.download('wordnet')
|
||||
nltk.download('averaged_perceptron_tagger')
|
||||
nltk.download('universal_tagset')
|
||||
nltk.download('stopwords')
|
||||
|
||||
def _post_install_if_needed():
|
||||
""" Runs _post_install if hasn't been run since install. """
|
||||
# Check for post-install file.
|
||||
post_install_file_path = os.path.join(config('CACHE_DIR'), 'post_install_check')
|
||||
if os.path.exists(post_install_file_path):
|
||||
return
|
||||
# Run post-install.
|
||||
_post_install()
|
||||
# Create file that indicates post-install completed.
|
||||
open(post_install_file_path, 'w').close()
|
||||
|
||||
def config(key):
|
||||
return config_dict[key]
|
||||
|
||||
config_dict = {'CACHE_DIR': os.path.expanduser('~/.cache/textattack')}
|
||||
config_path = download_if_needed('config.yaml')
|
||||
config_dict.update(yaml.load(open(config_path, 'r'), Loader=yaml.FullLoader))
|
||||
config_dict['CACHE_DIR'] = os.path.expanduser(config_dict['CACHE_DIR'])
|
||||
_post_install_if_needed()
|
||||
Reference in New Issue
Block a user