mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
fix model
This commit is contained in:
@@ -45,12 +45,12 @@ def run(args):
|
||||
if not text:
|
||||
continue
|
||||
|
||||
tokenized_text = textattack.shared.tokenized_text.TokenizedText(text, model.tokenizer)
|
||||
tokenized_text = textattack.shared.tokenized_text.TokenizedText(text, goal_function.model.tokenizer)
|
||||
|
||||
result = goal_function.get_results([tokenized_text])[0]
|
||||
result = goal_function.get_results([tokenized_text], goal_function.get_output(tokenized_text))[0]
|
||||
print('Attacking...')
|
||||
|
||||
result = next(attack.attack_dataset([(result.output, text, False)]))
|
||||
result = next(attack.attack_dataset([(result.output, text)]))
|
||||
print(result.__str__(color_method='stdout'))
|
||||
|
||||
else:
|
||||
|
||||
@@ -14,12 +14,16 @@ class GoalFunction:
|
||||
"""
|
||||
def __init__(self, model):
|
||||
self.model = model
|
||||
self.num_queries = 0
|
||||
self._call_model_cache = lru.LRU(2**18)
|
||||
|
||||
def should_skip(self, tokenized_text, correct_output):
|
||||
model_outputs = self._call_model([tokenized_text])
|
||||
return self._is_goal_complete(model_outputs[0], correct_output)
|
||||
|
||||
def get_output(self, tokenized_text):
|
||||
return self._get_displayed_output(self._call_model([tokenized_text])[0])
|
||||
|
||||
def get_results(self, tokenized_text_list, correct_output):
|
||||
"""
|
||||
For each tokenized_text object in tokenized_text_list, returns a result consisting of whether or not the goal has been achieved, the output for display purposes, and a score.
|
||||
|
||||
@@ -17,7 +17,7 @@ class BERTForClassification:
|
||||
"""
|
||||
def __init__(self, model_path, num_labels=2, entailment=False):
|
||||
#model_file_path = utils.download_if_needed(model_path)
|
||||
model_file_path = '/p/qdata/jm8wx/research/text_attacks/RobustNLP/BertClassifier/outputs-counterfit/mr-uncased-2020-03-04-23:05'
|
||||
model_file_path = utils.download_if_needed(model_path)
|
||||
self.model = BertForSequenceClassification.from_pretrained(
|
||||
model_file_path, num_labels=num_labels)
|
||||
self.model.to(utils.get_device())
|
||||
|
||||
@@ -6,7 +6,7 @@ class BERTTokenizer(Tokenizer):
|
||||
any type of tokenization, be it word, wordpiece, or character-based.
|
||||
"""
|
||||
def __init__(self, model_path='bert-base-uncased', max_seq_length=256):
|
||||
self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
||||
self.tokenizer = BertTokenizer.from_pretrained(model_path)
|
||||
self.max_seq_length = max_seq_length
|
||||
|
||||
def convert_text_to_tokens(self, input_text):
|
||||
|
||||
Reference in New Issue
Block a user