mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
Entailment needed [CLS] token. Now all tests pass!
This commit is contained in:
@@ -91,6 +91,7 @@ class GreedyWordSwapWIR(Attack):
|
||||
original_result,
|
||||
best_result
|
||||
)
|
||||
else:
|
||||
tokenized_text = results[0].tokenized_text
|
||||
|
||||
return FailedAttackResult(original_result, results[0])
|
||||
|
||||
@@ -7,7 +7,6 @@ class UntargetedClassification(ClassificationGoalFunction):
|
||||
"""
|
||||
|
||||
def _is_goal_complete(self, model_output, correct_output):
|
||||
import pdb; pdb.set_trace()
|
||||
return model_output.argmax() != correct_output
|
||||
|
||||
def _get_score(self, model_output, correct_output):
|
||||
|
||||
@@ -43,7 +43,6 @@ class GoalFunction:
|
||||
display purposes, and a score.
|
||||
"""
|
||||
model_outputs = self._call_model(tokenized_text_list)
|
||||
import pdb; pdb.set_trace()
|
||||
results = []
|
||||
for tokenized_text, raw_output in zip(tokenized_text_list, model_outputs):
|
||||
succeeded = self._is_goal_complete(raw_output, ground_truth_output)
|
||||
|
||||
@@ -44,7 +44,7 @@ class BERTEntailmentTokenizer(BERTTokenizer):
|
||||
# Ensure they will fit in self.max_seq_length.
|
||||
self._truncate_seq_pair(tokens_a, tokens_b)
|
||||
# Concatenate and return.
|
||||
return tokens_a + ['[SEP]'] + tokens_b + ['[SEP]']
|
||||
return ['[CLS]'] + tokens_a + ['[SEP]'] + tokens_b + ['[SEP]']
|
||||
|
||||
|
||||
def convert_tokens_to_ids(self, tokens):
|
||||
|
||||
Reference in New Issue
Block a user