import importlib
import json
import os
import random

import numpy as np
import torch

import textattack

# Device to use: CUDA when it is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def html_style_from_dict(style_dict):
    """Build an HTML ``style`` attribute from a mapping of CSS properties.

    Turns ``{'color': 'red', 'height': '100px'}`` into
    ``'style="color: red;height: 100px;"'``.

    Args:
        style_dict: mapping of CSS property name -> value (both strings).

    Returns:
        The ``style="..."`` attribute string.
    """
    style_str = "".join(key + ": " + style_dict[key] + ";" for key in style_dict)
    return 'style="{}"'.format(style_str)


def html_table_from_rows(rows, title=None, header=None, style_dict=None):
    """Render rows of values as an HTML table inside a container ``<div>``.

    Args:
        rows: iterable of rows; each row is an iterable of cell values, each
            rendered with ``str()``.
        title: optional heading placed in an ``<h1>`` above the table.
        header: optional iterable of column headings, rendered as ``<th>`` cells.
        style_dict: optional CSS properties applied to the container ``<div>``
            via ``html_style_from_dict``.

    Returns:
        The HTML string.

    Cell values are inserted without HTML escaping, so pre-formatted HTML in a
    cell is passed through as-is; escape untrusted text before calling.
    """
    # Stylize the container div.
    if style_dict:
        parts = ["<div {}>".format(html_style_from_dict(style_dict))]
    else:
        parts = ["<div>"]
    # Print the title string.
    if title:
        parts.append("<h1>{}</h1>".format(title))

    # Construct each row as HTML. Appended (not assigned) so the container div
    # and title built above are kept in the output.
    parts.append('<table class="table">')
    if header:
        parts.append("<tr>")
        parts.extend("<th>" + str(element) + "</th>" for element in header)
        parts.append("</tr>")
    for row in rows:
        parts.append("<tr>")
        parts.extend("<td>" + str(element) + "</td>" for element in row)
        parts.append("</tr>")

    # Close the table and the container div.
    parts.append("</table></div>")
    return "".join(parts)


def load_textattack_model_from_path(model_name, model_path):
    """Loads a pre-trained TextAttack model from its name and path.

    The architecture is selected by the prefix of ``model_name``:
    ``lstm``, ``cnn``, ``bert`` or ``t5``.

    Args:
        model_name: name of the model; its prefix selects the architecture.
        model_path: location of the model. For ``bert`` models this must be a
            ``(model_path, num_labels)`` pair, which is unpacked below.

    Returns:
        The instantiated model.

    Raises:
        ValueError: if ``model_name`` does not start with a known prefix.
    """

    def get_num_labels():
        # Read the label count from train_args.json next to the model,
        # falling back to 2 labels when that file does not exist.
        model_cache_path = textattack.shared.utils.download_if_needed(model_path)
        train_args_path = os.path.join(model_cache_path, "train_args.json")
        if not os.path.exists(train_args_path):
            textattack.shared.logger.warning(
                f"train_args.json not found in model path {model_path}. Defaulting to 2 labels."
            )
            return 2
        # Context manager so the file handle is closed promptly.
        with open(train_args_path, encoding="utf-8") as f:
            args = json.load(f)
        return args["num_labels"]

    colored_model_name = textattack.shared.utils.color_text(
        model_name, color="blue", method="ansi"
    )
    if model_name.startswith("lstm"):
        num_labels = get_num_labels()
        textattack.shared.logger.info(
            f"Loading pre-trained TextAttack LSTM: {colored_model_name}"
        )
        model = textattack.models.helpers.LSTMForClassification(
            model_path=model_path, num_labels=num_labels
        )
    elif model_name.startswith("cnn"):
        num_labels = get_num_labels()
        textattack.shared.logger.info(
            f"Loading pre-trained TextAttack CNN: {colored_model_name}"
        )
        model = textattack.models.helpers.WordCNNForClassification(
            model_path=model_path, num_labels=num_labels
        )
    elif model_name.startswith("bert"):
        # For BERT the caller passes (path, num_labels) instead of a bare path.
        model_path, num_labels = model_path
        textattack.shared.logger.info(
            f"Loading pre-trained TextAttack BERT model: {colored_model_name}"
        )
        model = textattack.models.helpers.BERTForClassification(
            model_path=model_path, num_labels=num_labels
        )
    elif model_name.startswith("t5"):
        textattack.shared.logger.info(
            f"Loading pre-trained TextAttack T5 model: {colored_model_name}"
        )
        model = textattack.models.helpers.T5ForTextToText(model_path)
    else:
        # Report the name that failed dispatch, not the path.
        raise ValueError(f"Unknown textattack model {model_name}")
    return model


def set_seed(random_seed):
    """Seed Python's, NumPy's and PyTorch's random number generators.

    Args:
        random_seed: integer seed applied to every generator.
    """
    random.seed(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)