mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
merge in master
This commit is contained in:
@@ -1,21 +1,29 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
|
||||
from textattack.shared import utils
|
||||
|
||||
from .language_model_constraint import LanguageModelConstraint
|
||||
|
||||
# temporarily silence W&B to ignore log-in warning
|
||||
os.environ["WANDB_SILENT"] = "1"
|
||||
|
||||
|
||||
class GPT2(LanguageModelConstraint):
|
||||
""" A constraint based on the GPT-2 language model.
|
||||
|
||||
|
||||
from "Better Language Models and Their Implications"
|
||||
""" A constraint based on the GPT-2 language model.
|
||||
|
||||
|
||||
from "Better Language Models and Their Implications"
|
||||
(openai.com/blog/better-language-models/)
|
||||
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
import transformers
|
||||
|
||||
# re-enable notifications
|
||||
os.environ["WANDB_SILENT"] = "0"
|
||||
self.model = transformers.GPT2LMHeadModel.from_pretrained("gpt2")
|
||||
self.model.to(utils.device)
|
||||
self.tokenizer = transformers.GPT2Tokenizer.from_pretrained("gpt2")
|
||||
|
||||
Reference in New Issue
Block a user