1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

Entailment needed [CLS] token. Now all tests pass!

This commit is contained in:
Jack Morris
2020-04-29 17:59:48 -04:00
parent f47734dff7
commit 57b11e74a0
4 changed files with 3 additions and 4 deletions

View File

@@ -91,6 +91,7 @@ class GreedyWordSwapWIR(Attack):
original_result,
best_result
)
else:
tokenized_text = results[0].tokenized_text
return FailedAttackResult(original_result, results[0])

View File

@@ -7,7 +7,6 @@ class UntargetedClassification(ClassificationGoalFunction):
"""
def _is_goal_complete(self, model_output, correct_output):
import pdb; pdb.set_trace()
return model_output.argmax() != correct_output
def _get_score(self, model_output, correct_output):

View File

@@ -43,7 +43,6 @@ class GoalFunction:
display purposes, and a score.
"""
model_outputs = self._call_model(tokenized_text_list)
import pdb; pdb.set_trace()
results = []
for tokenized_text, raw_output in zip(tokenized_text_list, model_outputs):
succeeded = self._is_goal_complete(raw_output, ground_truth_output)

View File

@@ -44,7 +44,7 @@ class BERTEntailmentTokenizer(BERTTokenizer):
# Ensure they will fit in self.max_seq_length.
self._truncate_seq_pair(tokens_a, tokens_b)
# Concatenate and return.
return tokens_a + ['[SEP]'] + tokens_b + ['[SEP]']
return ['[CLS]'] + tokens_a + ['[SEP]'] + tokens_b + ['[SEP]']
def convert_tokens_to_ids(self, tokens):