textattack-nlp-transformer/textattack/shared/utils.py

import filelock
import logging
import os
import pathlib
import requests
import shutil
import tempfile
import torch
import tqdm
import yaml
import zipfile


def get_device():
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

def path_in_cache(file_path):
    textattack_cache_dir = config('CACHE_DIR')
    if not os.path.exists(textattack_cache_dir):
        os.makedirs(textattack_cache_dir)
    return os.path.join(textattack_cache_dir, file_path)

def s3_url(uri):
    return 'https://textattack.s3.amazonaws.com/' + uri

def download_if_needed(folder_name):
    """ Folder name will be saved as `.cache/textattack/[folder name]`. If it
        doesn't exist on disk, the zip file will be downloaded and extracted.

        Args:
            folder_name (str): path to folder or file in cache

        Returns:
            str: path to the downloaded folder or file on disk
    """
    cache_dest_path = path_in_cache(folder_name)
    os.makedirs(os.path.dirname(cache_dest_path), exist_ok=True)
    # Use a lock to prevent concurrent downloads.
    cache_dest_lock_path = cache_dest_path + '.lock'
    cache_file_lock = filelock.FileLock(cache_dest_lock_path)
    cache_file_lock.acquire()
    # Check if already downloaded.
    if os.path.exists(cache_dest_path):
        cache_file_lock.release()
        return cache_dest_path
    # If the file isn't found yet, download the zip file to the cache.
    downloaded_file = tempfile.NamedTemporaryFile(
        dir=config('CACHE_DIR'),
        suffix='.zip', delete=False)
    http_get(folder_name, downloaded_file)
    # Move or unzip the file.
    downloaded_file.close()
    if zipfile.is_zipfile(downloaded_file.name):
        unzip_file(downloaded_file.name, cache_dest_path)
    else:
        print('Copying', downloaded_file.name, 'to', cache_dest_path + '.')
        shutil.copyfile(downloaded_file.name, cache_dest_path)
    cache_file_lock.release()
    # Remove the temporary file.
    os.remove(downloaded_file.name)
    print(f'Successfully saved {folder_name} to cache.')
    return cache_dest_path

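# Usage sketch (illustrative, not part of the original file): `folder_name` is a path
# relative both to the S3 bucket and to the local cache, so a call such as
#
#   >>> local_path = download_if_needed('models/bert-base-uncased-sst2')   # hypothetical resource name
#
# would download the matching S3 object on first use, unzip it if it is an archive,
# and return the cached path on later calls without re-downloading.
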
def unzip_file(path_to_zip_file, unzipped_folder_path):
    """ Unzips a .zip file to folder path. """
    print('Unzipping file', path_to_zip_file, 'to', unzipped_folder_path + '.')
    enclosing_unzipped_path = pathlib.Path(unzipped_folder_path).parent
    with zipfile.ZipFile(path_to_zip_file, 'r') as zip_ref:
        zip_ref.extractall(enclosing_unzipped_path)

def http_get(folder_name, out_file, proxies=None):
    """ Get contents of a URL and save to a file.
        https://github.com/huggingface/transformers/blob/master/src/transformers/file_utils.py
    """
    folder_s3_url = s3_url(folder_name)
    print(f'Downloading {folder_s3_url}.')
    req = requests.get(folder_s3_url, stream=True, proxies=proxies)
    content_length = req.headers.get('Content-Length')
    total = int(content_length) if content_length is not None else None
    if req.status_code == 403:  # Not found on AWS
        raise Exception(f'Could not find {folder_name} on server.')
    progress = tqdm.tqdm(unit="B", unit_scale=True, total=total)
    for chunk in req.iter_content(chunk_size=1024):
        if chunk:  # filter out keep-alive new chunks
            progress.update(len(chunk))
            out_file.write(chunk)
    progress.close()

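# Usage sketch (illustrative, not part of the original file): `out_file` can be any open,
# writable binary file object, e.g.
#
#   >>> with open('/tmp/example.zip', 'wb') as f:      # assumed destination path
#   ...     http_get('some/remote/file.zip', f)        # hypothetical S3 key
#
# A 403 status is treated as "not found" because S3 returns 403 rather than 404 for
# missing keys when the caller cannot list the bucket.
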
def add_indent(s_, numSpaces):
    s = s_.split('\n')
    # don't do anything for single-line stuff
    if len(s) == 1:
        return s_
    first = s.pop(0)
    s = [(numSpaces * ' ') + line for line in s]
    s = '\n'.join(s)
    s = first + '\n' + s
    return s

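# Behavior sketch (illustrative): the first line is left untouched and every following
# line is prefixed with `numSpaces` spaces, e.g.
#
#   >>> add_indent('a\nb\nc', 2)
#   'a\n  b\n  c'
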
def default_class_repr(self):
    extra_params = []
    for key in self.extra_repr_keys():
        extra_params.append(' (' + key + ')' + ': {' + key + '}')
    if len(extra_params):
        extra_str = '\n' + '\n'.join(extra_params) + '\n'
        extra_str = f'({extra_str})'
    else:
        extra_str = ''
    extra_str = extra_str.format(**self.__dict__)
    return f'{self.__class__.__name__}{extra_str}'

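# Usage sketch (illustrative, not part of the original file): a class that assigns this
# function as its __repr__ and defines extra_repr_keys() gets a readable multi-line repr,
# e.g. for a hypothetical
#
#   class Example:
#       def __init__(self):
#           self.threshold = 0.5
#       def extra_repr_keys(self):
#           return ['threshold']
#       __repr__ = default_class_repr
#
# repr(Example()) evaluates to 'Example(\n (threshold): 0.5\n)'.
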
LABEL_COLORS = [
    'red', 'green',
    'blue', 'purple',
    'yellow', 'orange',
    'pink', 'cyan',
    'gray', 'brown'
]

def color_from_label(label_num):
    """ Colors for labels (arbitrary). """
    label_num %= len(LABEL_COLORS)
    return LABEL_COLORS[label_num]

def color_text(text, color=None, method=None):
    if method is None:
        return text
    if method == 'html':
        return f'<font color = {color}>{text}</font>'
    elif method == 'stdout':
        if color == 'green':
            color = ANSI_ESCAPE_CODES.OKGREEN
        elif color == 'red':
            color = ANSI_ESCAPE_CODES.FAIL
        elif color == 'blue':
            color = ANSI_ESCAPE_CODES.OKBLUE
        elif color == 'gray':
            color = ANSI_ESCAPE_CODES.GRAY
        else:
            color = ANSI_ESCAPE_CODES.BOLD
        return color + text + ANSI_ESCAPE_CODES.STOP
    elif method == 'file':
        return '[[' + text + ']]'

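# Usage sketch (illustrative): the same text can be colored for different output targets, e.g.
#
#   >>> color_text('hello', color='red', method='html')
#   '<font color = red>hello</font>'
#   >>> color_text('hello', color='red', method='stdout')
#   '\x1b[91mhello\x1b[0m'
#   >>> color_text('hello', color='red', method='file')
#   '[[hello]]'
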
def words_from_text(s, words_to_ignore=[]):
    """ Splits a string into words, keeping only runs of alphabetic characters
        and skipping any word in `words_to_ignore`. """
    words = []
    word = ''
    for c in ' '.join(s.split()):
        if c.isalpha():
            word += c
        elif word:
            if word not in words_to_ignore:
                words.append(word)
            word = ''
    if len(word) and (word not in words_to_ignore):
        words.append(word)
    return words

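# Behavior sketch (illustrative): only alphabetic runs are kept, so punctuation and digits
# act as separators, e.g.
#
#   >>> words_from_text("The movie wasn't bad, 10/10!")
#   ['The', 'movie', 'wasn', 't', 'bad']
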
class ANSI_ESCAPE_CODES:
    """ Escape codes for printing color to the terminal. """
    HEADER = '\033[95m'
    OKBLUE = '\033[94m'
    OKGREEN = '\033[92m'
    WARNING = '\033[93m'
    GRAY = '\033[37m'
    FAIL = '\033[91m'
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'
    # Stops the current color sequence.
    STOP = '\033[0m'

def html_style_from_dict(style_dict):
    """ Turns
            { 'color': 'red', 'height': '100px' }
        into
            style="color: red; height: 100px"
    """
    style_str = ''
    for key in style_dict:
        style_str += key + ': ' + style_dict[key] + ';'
    return 'style="{}"'.format(style_str)

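# Behavior sketch (illustrative):
#
#   >>> html_style_from_dict({'color': 'red', 'height': '100px'})
#   'style="color: red;height: 100px;"'
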
def html_table_from_rows(rows, title=None, header=None, style_dict=None):
    # Stylize the container div.
    if style_dict:
        table_html = '<div {}>'.format(html_style_from_dict(style_dict))
    else:
        table_html = '<div>'
    # Add the title string.
    if title:
        table_html += '<h1>{}</h1>'.format(title)
    # Construct each row as HTML.
    table_html += '<table class="table">'
    if header:
        table_html += '<tr>'
        for element in header:
            table_html += '<th>'
            table_html += str(element)
            table_html += '</th>'
        table_html += '</tr>'
    for row in rows:
        table_html += '<tr>'
        for element in row:
            table_html += '<td>'
            table_html += str(element)
            table_html += '</td>'
        table_html += '</tr>'
    # Close the table and return the HTML string.
    table_html += '</table></div>'
    return table_html

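# Usage sketch (illustrative, not part of the original file): building a small two-column table, e.g.
#
#   >>> html = html_table_from_rows(
#   ...     rows=[['x', 1], ['y', 2]],          # example data
#   ...     title='Results',
#   ...     header=['name', 'value'])
#   >>> html.startswith('<div><h1>Results</h1><table class="table">')
#   True
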
def has_letter(word):
    """ Returns True if `word` contains at least one character in [A-Za-z]. """
    for c in word:
        if c.isalpha():
            return True
    return False

LOG_STRING = '\033[34;1mtextattack\033[0m'
logger = None

def get_logger():
    global logger
    if not logger:
        logging.basicConfig(level=logging.INFO)
        logger = logging.getLogger(__name__)
        formatter = logging.Formatter(f'{LOG_STRING} - %(message)s')
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(formatter)
        logger.addHandler(stream_handler)
        logger.propagate = False
    return logger

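# Usage sketch (illustrative): callers share a single configured logger, e.g.
#
#   >>> get_logger().info('loaded model')
#
# which writes "textattack - loaded model" (with an ANSI-colored prefix) to stderr.
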
def _post_install():
    logger = get_logger()
    logger.info('First time importing textattack: downloading remaining required packages.')
    logger.info('Downloading spaCy required packages.')
    import spacy
    spacy.cli.download('en')
    logger.info('Downloading NLTK required packages.')
    import nltk
    nltk.download('wordnet')
    nltk.download('averaged_perceptron_tagger')
    nltk.download('universal_tagset')
    nltk.download('stopwords')

def _post_install_if_needed():
    """ Runs _post_install if it hasn't been run since install. """
    # Check for post-install file.
    post_install_file_path = os.path.join(config('CACHE_DIR'), 'post_install_check')
    if os.path.exists(post_install_file_path):
        return
    # Run post-install.
    _post_install()
    # Create file that indicates post-install completed.
    open(post_install_file_path, 'w').close()

def config(key):
    return config_dict[key]

config_dict = {'CACHE_DIR': os.path.expanduser('~/.cache/textattack')}
config_path = download_if_needed('config.yaml')
config_dict.update(yaml.load(open(config_path, 'r'), Loader=yaml.FullLoader))
config_dict['CACHE_DIR'] = os.path.expanduser(config_dict['CACHE_DIR'])

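# Note (illustrative, not part of the original file): config('CACHE_DIR') defaults to
# ~/.cache/textattack; keys in the downloaded config.yaml override the defaults, and any
# overridden CACHE_DIR is re-expanded with os.path.expanduser afterwards.
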
_post_install_if_needed()