mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

support training for entailment and fix test, update README

This commit is contained in:
Jack Morris
2020-06-20 23:46:52 -04:00
parent 6b8edfa4f9
commit 05e86a8f03
5 changed files with 33 additions and 17 deletions


@@ -174,6 +174,23 @@ of a string or a list of strings. Here's an example of how to use the `Embedding
['What I notable create, I do not understand.', 'What I significant create, I do not understand.', 'What I cannot engender, I do not understand.', 'What I cannot creating, I do not understand.', 'What I cannot creations, I do not understand.', 'What I cannot create, I do not comprehend.', 'What I cannot create, I do not fathom.', 'What I cannot create, I do not understanding.', 'What I cannot create, I do not understands.', 'What I cannot create, I do not understood.', 'What I cannot create, I do not realise.']
```
### Training Models
Our model training code is available via `textattack train` to help you train LSTMs,
CNNs, and `transformers` models using TextAttack out-of-the-box. Datasets are
automatically loaded using the `nlp` package.
#### Training Examples
*Train our default LSTM for 50 epochs on the Yelp Polarity dataset:*
```bash
textattack train --model lstm --dataset yelp_polarity --batch-size 64 --epochs 50
```
*Fine-tune `bert-base` on the `CoLA` dataset for 5 epochs:*
```bash
textattack train --model bert-base-uncased --dataset glue:cola --batch-size 32 --epochs 5
```
## Design
### TokenizedText
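
The README section above notes that datasets are loaded automatically through the `nlp` package. A minimal sketch of the equivalent direct load, using the same dataset identifiers as the CLI examples (this is an illustration of the underlying package, not TextAttack's exact internal code):

```python
# Load the same datasets the CLI examples above refer to, directly via `nlp`.
import nlp

yelp_train = nlp.load_dataset("yelp_polarity", split="train")
cola_train = nlp.load_dataset("glue", "cola", split="train")

# Each example is a dict of text fields plus an integer label.
print(yelp_train[0])
```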


@@ -65,7 +65,7 @@ def train_model(args):
f"Number of teste xamples ({len(eval_text)}) does not match number of labels ({len(eval_label_id_list)})"
)
model = model_from_args(args)
model = model_from_args(args, num_labels)
logger.info(f"Tokenizing training data. (len: {train_examples_len})")
train_text_ids = encode_batch(model.tokenizer, train_text)
@@ -184,7 +184,7 @@ def train_model(args):
model_to_save = model.module if hasattr(model, "module") else model
model_to_save.save_pretrained(output_dir)
torch.save(args, os.path.join(output_dir, "training_args.bin"))
logger.info("Checkpoint saved to %s.", output_dir)
logger.info(f"Checkpoint saved to {output_dir}.")
model.train()
best_eval_acc = 0
@@ -210,9 +210,7 @@ def train_model(args):
}
logits = textattack.shared.utils.model_predict(model, input_ids)
loss_fct = torch.nn.CrossEntropyLoss()
loss = torch.nn.CrossEntropyLoss()(
logits.view(-1, num_labels), labels.view(-1)
)
loss = torch.nn.CrossEntropyLoss()(logits, labels)
if global_step % args.tb_writer_step == 0:
tb_writer.add_scalar("loss", loss, global_step)
tb_writer.add_scalar("lr", loss, global_step)
@@ -222,9 +220,8 @@ def train_model(args):
optimizer.step()
scheduler.step()
optimizer.zero_grad()
global_step += 1
# Save model checkpoint to file.
if global_step % args.checkpoint_steps == 0:
if global_step > 0 and global_step % args.checkpoint_steps == 0:
save_model_checkpoint()
model.zero_grad()
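
The loss change in the hunk above drops the `.view()` reshaping: for single-label tasks such as entailment, the logits already have shape `(batch_size, num_labels)` and the labels `(batch_size,)`, which `CrossEntropyLoss` accepts directly. A small self-contained sketch of the shapes involved (illustrative values, not TextAttack code):

```python
import torch

num_labels = 3                          # e.g. entailment / neutral / contradiction
logits = torch.randn(8, num_labels)     # model output for a batch of 8 examples
labels = torch.randint(num_labels, (8,))

loss_fct = torch.nn.CrossEntropyLoss()
loss = loss_fct(logits, labels)         # equal to the older .view(-1, num_labels) form here
print(loss.item())
```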


@@ -31,7 +31,7 @@ def dataset_from_args(args):
except KeyError:
raise KeyError(f"Error: no `train` split found in `{args.dataset}` dataset")
train_text, train_labels = prepare_dataset_for_training(train_dataset)
if args.dataset_split:
eval_dataset = textattack.datasets.HuggingFaceNLPDataset(
*dataset_args, split=args.dataset_split
@@ -57,23 +57,22 @@ def dataset_from_args(args):
f"Could not find `dev` or `test` split in dataset {args.dataset}."
)
eval_text, eval_labels = prepare_dataset_for_training(eval_dataset)
return train_text, train_labels, eval_text, eval_labels
def model_from_args(args):
def model_from_args(args, num_labels):
if args.model == "lstm":
textattack.shared.logger.info("Loading textattack model: LSTMForClassification")
model = textattack.models.helpers.LSTMForClassification(
max_seq_length=args.max_length
max_seq_length=args.max_length, num_labels=num_labels
)
elif args.model == "cnn":
textattack.shared.logger.info(
"Loading textattack model: WordCNNForClassification"
)
model = textattack.models.helpers.WordCNNForClassification(
max_seq_length=args.max_length
max_seq_length=args.max_length, num_labels=num_labels
)
else:
import transformers
@@ -81,8 +80,11 @@ def model_from_args(args):
textattack.shared.logger.info(
f"Loading transformers AutoModelForSequenceClassification: {args.model}"
)
config = transformers.AutoConfig.from_pretrained(
args.model, num_labels=num_labels, finetuning_task=args.dataset
)
model = transformers.AutoModelForSequenceClassification.from_pretrained(
args.model
args.model, config=config,
)
tokenizer = textattack.models.tokenizers.AutoTokenizer(
args.model, use_fast=False, max_length=args.max_length


@@ -19,7 +19,7 @@ class LSTMForClassification(nn.Module):
hidden_size=150,
depth=1,
dropout=0.3,
nclasses=2,
num_labels=2,
max_seq_length=128,
model_path=None,
):
@@ -40,7 +40,7 @@ class LSTMForClassification(nn.Module):
bidirectional=True,
)
d_out = hidden_size
self.out = nn.Linear(d_out, nclasses)
self.out = nn.Linear(d_out, num_labels)
self.tokenizer = textattack.models.tokenizers.SpacyTokenizer(
self.word2id, self.emb_layer.oovid, self.emb_layer.padid, max_seq_length
)


@@ -19,7 +19,7 @@ class WordCNNForClassification(nn.Module):
self,
hidden_size=150,
dropout=0.3,
nclasses=2,
num_labels=2,
max_seq_length=128,
model_path=None,
):
@@ -31,7 +31,7 @@ class WordCNNForClassification(nn.Module):
self.emb_layer.n_d, widths=[3, 4, 5], filters=hidden_size
)
d_out = 3 * hidden_size
self.out = nn.Linear(d_out, nclasses)
self.out = nn.Linear(d_out, num_labels)
self.tokenizer = textattack.models.tokenizers.SpacyTokenizer(
self.word2id, self.emb_layer.oovid, self.emb_layer.padid, max_seq_length
)
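
With the renamed keyword above, both helper models size their final linear layer from the label count that `model_from_args` now passes through. A hypothetical construction for a three-class entailment task (the `num_labels` keyword is the corrected parameter name assumed in this edit, and the values are illustrative):

```python
import textattack

# Hypothetical: built-in LSTM and CNN classifiers with a 3-way output head.
lstm = textattack.models.helpers.LSTMForClassification(
    max_seq_length=128, num_labels=3
)
cnn = textattack.models.helpers.WordCNNForClassification(
    max_seq_length=128, num_labels=3
)
print(lstm.out)  # Linear(..., out_features=3) -- output layer follows num_labels
```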