Mirror of https://github.com/QData/TextAttack.git, synced 2021-10-13 00:05:06 +03:00
support training for entailment and fix test, update README
Changed files include README.md (17 lines changed) and the training / model helper code shown in the hunks below.
````diff
@@ -174,6 +174,23 @@ of a string or a list of strings. Here's an example of how to use the `Embedding
 ['What I notable create, I do not understand.', 'What I significant create, I do not understand.', 'What I cannot engender, I do not understand.', 'What I cannot creating, I do not understand.', 'What I cannot creations, I do not understand.', 'What I cannot create, I do not comprehend.', 'What I cannot create, I do not fathom.', 'What I cannot create, I do not understanding.', 'What I cannot create, I do not understands.', 'What I cannot create, I do not understood.', 'What I cannot create, I do not realise.']
 ```
 
+### Training Models
+
+Our model training code is available via `textattack train` to help you train LSTMs,
+CNNs, and `transformers` models using TextAttack out-of-the-box. Datasets are
+automatically loaded using the `nlp` package.
+
+#### Training Examples
+
+*Train our default LSTM for 5 epochs on the Yelp Polarity dataset:*
+```bash
+textattack train --model lstm --dataset yelp_polarity --batch-size 64 --epochs 5
+```
+
+*Fine-tune `bert-base` on the `CoLA` dataset for 5 epochs:*
+```bash
+textattack train --model bert-base-uncased --dataset glue:cola --batch-size 32 --epochs 5
+```
 ## Design
 
 ### TokenizedText
````
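The commit message mentions entailment support, which the README hunk itself does not show. Since the training pipeline loads data through `textattack.datasets.HuggingFaceNLPDataset` (see the hunks below), a three-class NLI corpus should load the same way; the dataset name `snli` here is an assumption about what the `nlp` package exposes, not something this diff adds.

```python
import textattack

# Assumption-laden sketch: load a 3-class entailment dataset through the same
# wrapper the training code uses. Availability of "snli" via `nlp` is assumed.
dataset = textattack.datasets.HuggingFaceNLPDataset("snli", split="train")
```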
```diff
@@ -65,7 +65,7 @@ def train_model(args):
             f"Number of test examples ({len(eval_text)}) does not match number of labels ({len(eval_label_id_list)})"
         )
 
-    model = model_from_args(args)
+    model = model_from_args(args, num_labels)
 
     logger.info(f"Tokenizing training data. (len: {train_examples_len})")
     train_text_ids = encode_batch(model.tokenizer, train_text)
```
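This hunk is where the new `num_labels` argument first appears; the diff does not show how it is computed. A minimal sketch of one way it could be derived from the loaded labels follows (the derivation itself is an assumption, only the function names come from this diff):

```python
# Hedged sketch, not repo code: size the classification head from the training
# labels so that 3-class entailment and 2-class sentiment both work.
train_text, train_labels, eval_text, eval_labels = dataset_from_args(args)
num_labels = len(set(train_labels))        # 2 for yelp_polarity, 3 for an NLI task
model = model_from_args(args, num_labels)  # matches the updated call in the hunk above
```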
```diff
@@ -184,7 +184,7 @@ def train_model(args):
         model_to_save = model.module if hasattr(model, "module") else model
         model_to_save.save_pretrained(output_dir)
         torch.save(args, os.path.join(output_dir, "training_args.bin"))
-        logger.info("Checkpoint saved to %s.", output_dir)
+        logger.info(f"Checkpoint saved to {output_dir}.")
 
     model.train()
     best_eval_acc = 0
```
```diff
@@ -210,9 +210,7 @@ def train_model(args):
             }
             logits = textattack.shared.utils.model_predict(model, input_ids)
-            loss_fct = torch.nn.CrossEntropyLoss()
-            loss = torch.nn.CrossEntropyLoss()(
-                logits.view(-1, num_labels), labels.view(-1)
-            )
+            loss = torch.nn.CrossEntropyLoss()(logits, labels)
             if global_step % args.tb_writer_step == 0:
                 tb_writer.add_scalar("loss", loss, global_step)
                 tb_writer.add_scalar("lr", scheduler.get_last_lr()[0], global_step)
```
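The rewritten loss call passes raw logits and integer labels straight to `CrossEntropyLoss`; a standalone illustration of why that single call covers both binary and multi-class training (all values made up):

```python
import torch

# CrossEntropyLoss takes logits of shape (batch, num_labels) and integer class ids
# of shape (batch,), so the same one-liner works for 2-class sentiment and
# 3-class entailment.
logits = torch.randn(4, 3)             # batch of 4 examples, 3 entailment classes
labels = torch.tensor([0, 2, 1, 1])    # gold class indices
loss = torch.nn.CrossEntropyLoss()(logits, labels)
print(loss.item())
```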
```diff
@@ -222,9 +220,8 @@ def train_model(args):
                 optimizer.step()
                 scheduler.step()
                 optimizer.zero_grad()
             global_step += 1
             # Save model checkpoint to file.
-            if global_step % args.checkpoint_steps == 0:
+            if global_step > 0 and global_step % args.checkpoint_steps == 0:
                 save_model_checkpoint()
 
         model.zero_grad()
```
```diff
@@ -31,7 +31,7 @@ def dataset_from_args(args):
     except KeyError:
         raise KeyError(f"Error: no `train` split found in `{args.dataset}` dataset")
     train_text, train_labels = prepare_dataset_for_training(train_dataset)
 
     if args.dataset_split:
         eval_dataset = textattack.datasets.HuggingFaceNLPDataset(
             *dataset_args, split=args.dataset_split
```
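The next hunk ends in an error message about missing `dev`/`test` splits, which implies the evaluation split is resolved by falling back from `dev` to `test`. A rough sketch of that pattern follows; the helper name and exception type are assumptions, not repo code.

```python
import textattack

# Hypothetical helper illustrating the dev -> test fallback implied by the error
# message in the next hunk; TextAttack's actual control flow may differ.
def eval_dataset_from_args(args, dataset_args):
    for split in ("dev", "test"):
        try:
            return textattack.datasets.HuggingFaceNLPDataset(*dataset_args, split=split)
        except KeyError:
            continue
    raise KeyError(f"Could not find `dev` or `test` split in dataset {args.dataset}.")
```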
```diff
@@ -57,23 +57,22 @@ def dataset_from_args(args):
             f"Could not find `dev` or `test` split in dataset {args.dataset}."
         )
     eval_text, eval_labels = prepare_dataset_for_training(eval_dataset)
 
     return train_text, train_labels, eval_text, eval_labels
 
 
-def model_from_args(args):
+def model_from_args(args, num_labels):
     if args.model == "lstm":
         textattack.shared.logger.info("Loading textattack model: LSTMForClassification")
         model = textattack.models.helpers.LSTMForClassification(
-            max_seq_length=args.max_length
+            max_seq_length=args.max_length, num_labels=num_labels
         )
     elif args.model == "cnn":
         textattack.shared.logger.info(
             "Loading textattack model: WordCNNForClassification"
         )
         model = textattack.models.helpers.WordCNNForClassification(
-            max_seq_length=args.max_length
+            max_seq_length=args.max_length, num_labels=num_labels
         )
     else:
         import transformers
```
```diff
@@ -81,8 +80,11 @@ def model_from_args(args):
         textattack.shared.logger.info(
             f"Loading transformers AutoModelForSequenceClassification: {args.model}"
         )
+        config = transformers.AutoConfig.from_pretrained(
+            args.model, num_labels=num_labels, finetuning_task=args.dataset
+        )
         model = transformers.AutoModelForSequenceClassification.from_pretrained(
-            args.model
+            args.model, config=config,
         )
         tokenizer = textattack.models.tokenizers.AutoTokenizer(
             args.model, use_fast=False, max_length=args.max_length
```
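In the `transformers` branch, the added `AutoConfig` call is what actually sizes the classification head. A standalone illustration with placeholder values (`bert-base-uncased` and `num_labels=3` are examples, not defaults from this diff):

```python
import transformers

# num_labels in the config sets the output dimension of the sequence-classification
# head, e.g. 3 classes for an entailment task.
config = transformers.AutoConfig.from_pretrained("bert-base-uncased", num_labels=3)
model = transformers.AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased", config=config
)
print(model.config.num_labels)  # 3
```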
```diff
@@ -19,7 +19,7 @@ class LSTMForClassification(nn.Module):
         hidden_size=150,
         depth=1,
         dropout=0.3,
-        nclasses=2,
+        num_labels=2,
         max_seq_length=128,
         model_path=None,
     ):
@@ -40,7 +40,7 @@ class LSTMForClassification(nn.Module):
             bidirectional=True,
         )
         d_out = hidden_size
-        self.out = nn.Linear(d_out, nclasses)
+        self.out = nn.Linear(d_out, num_labels)
         self.tokenizer = textattack.models.tokenizers.SpacyTokenizer(
             self.word2id, self.emb_layer.oovid, self.emb_layer.padid, max_seq_length
         )
```
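With the label count exposed on the constructor, the word-level models can be sized for entailment as well; a brief usage sketch assuming the `num_labels` keyword shown in the hunk above:

```python
import textattack

# Usage sketch (assumes the num_labels keyword above): the default LSTM classifier
# built with a 3-way output layer for entailment.
model = textattack.models.helpers.LSTMForClassification(
    max_seq_length=128, num_labels=3
)
```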
```diff
@@ -19,7 +19,7 @@ class WordCNNForClassification(nn.Module):
         self,
         hidden_size=150,
         dropout=0.3,
-        nclasses=2,
+        num_labels=2,
         max_seq_length=128,
         model_path=None,
     ):
@@ -31,7 +31,7 @@ class WordCNNForClassification(nn.Module):
             self.emb_layer.n_d, widths=[3, 4, 5], filters=hidden_size
         )
         d_out = 3 * hidden_size
-        self.out = nn.Linear(d_out, nclasses)
+        self.out = nn.Linear(d_out, num_labels)
         self.tokenizer = textattack.models.tokenizers.SpacyTokenizer(
             self.word2id, self.emb_layer.oovid, self.emb_layer.padid, max_seq_length
         )
```
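The only structural change in both classifiers is the width of the final linear layer; a toy check of what that changes, using the defaults from the hunks above:

```python
import torch
import torch.nn as nn

# Toy check, not repo code: the WordCNN concatenates three filter widths of
# hidden_size=150 each, so d_out = 3 * 150; with num_labels=3 the head emits
# one logit per class.
head = nn.Linear(3 * 150, 3)
features = torch.randn(8, 3 * 150)   # batch of 8 pooled feature vectors
print(head(features).shape)          # torch.Size([8, 3])
```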