mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
[CODE] Add new attack recipe A2T
This commit is contained in:
@@ -128,6 +128,15 @@ To run an attack recipe: `textattack attack --recipe [recipe_name]`
|
||||
<tbody>
|
||||
<tr><td style="text-align: center;" colspan="6"><strong><br>Attacks on classification tasks, like sentiment classification and entailment:<br></strong></td></tr>
|
||||
|
||||
<tr>
|
||||
<td><code>a2t</code>
|
||||
<span class="citation" data-cites="yoo2021a2t"></span></td>
|
||||
<td><sub>Untargeted {Classification, Entailment}</sub></td>
|
||||
<td><sub>Percentage of words perturbed, Word embedding distance, DistilBERT sentence encoding cosine similarity, part-of-speech consistency</sub></td>
|
||||
<td><sub>Counter-fitted word embedding swap (or) BERT Masked Token Prediction</sub></td>
|
||||
<td><sub>Greedy-WIR (gradient)</sub></td>
|
||||
<td ><sub>from (["Towards Improving Adversarial Training of NLP Models" (Yoo et al., 2021)](https://arxiv.org/abs/2109.00544))</sub></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><code>alzantot</code> <span class="citation" data-cites="Alzantot2018GeneratingNL Jia2019CertifiedRT"></span></td>
|
||||
<td><sub>Untargeted {Classification, Entailment}</sub></td>
|
||||
|
||||
@@ -35,6 +35,7 @@ ATTACK_RECIPE_NAMES = {
|
||||
"pso": "textattack.attack_recipes.PSOZang2020",
|
||||
"checklist": "textattack.attack_recipes.CheckList2020",
|
||||
"clare": "textattack.attack_recipes.CLARE2020",
|
||||
"a2t": "textattack.attack_recipes.A2TYoo2021",
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ TextAttack supports the following attack recipes (each recipe's documentation co
|
||||
|
||||
from .attack_recipe import AttackRecipe
|
||||
|
||||
from .a2t_yoo_2021 import A2TYoo2021
|
||||
from .bae_garg_2019 import BAEGarg2019
|
||||
from .bert_attack_li_2020 import BERTAttackLi2020
|
||||
from .genetic_algorithm_alzantot_2018 import GeneticAlgorithmAlzantot2018
|
||||
|
||||
68
textattack/attack_recipes/a2t_yoo_2021.py
Normal file
68
textattack/attack_recipes/a2t_yoo_2021.py
Normal file
@@ -0,0 +1,68 @@
|
||||
from textattack import Attack
|
||||
from textattack.constraints.grammaticality import PartOfSpeech
|
||||
from textattack.constraints.pre_transformation import (
|
||||
InputColumnModification,
|
||||
MaxModificationRate,
|
||||
RepeatModification,
|
||||
StopwordModification,
|
||||
)
|
||||
from textattack.constraints.semantics import WordEmbeddingDistance
|
||||
from textattack.constraints.semantics.sentence_encoders import BERT
|
||||
from textattack.goal_functions import UntargetedClassification
|
||||
from textattack.search_methods import GreedyWordSwapWIR
|
||||
from textattack.transformations import WordSwapEmbedding, WordSwapMaskedLM
|
||||
|
||||
from .attack_recipe import AttackRecipe
|
||||
|
||||
|
||||
class A2TYoo2021(AttackRecipe):
    """Towards Improving Adversarial Training of NLP Models.

    (Yoo et al., 2021)

    https://arxiv.org/abs/2109.00544
    """

    @staticmethod
    def build(model_wrapper, mlm=False):
        """Build attack recipe.

        Args:
            model_wrapper (:class:`~textattack.models.wrappers.ModelWrapper`):
                Model wrapper containing both the model and the tokenizer.
            mlm (:obj:`bool`, `optional`, defaults to :obj:`False`):
                If :obj:`True`, load `A2T-MLM` attack. Otherwise, load regular `A2T` attack.

        Returns:
            :class:`~textattack.Attack`: A2T attack.
        """
        #
        # Constraints shared by both the `A2T` and `A2T-MLM` variants.
        #
        constraints = [RepeatModification(), StopwordModification()]
        # NOTE(review): presumably this keeps the premise fixed for
        # (premise, hypothesis) inputs so that only the hypothesis is
        # perturbed -- confirm against `InputColumnModification`.
        input_column_modification = InputColumnModification(
            ["premise", "hypothesis"], {"premise"}
        )
        constraints.append(input_column_modification)
        constraints.append(PartOfSpeech(allow_verb_noun_swap=False))
        constraints.append(MaxModificationRate(max_rate=0.1, min_threshold=4))
        # Sentence-encoder constraint: cosine metric with a 0.9 threshold,
        # using the `stsb-distilbert-base` model.
        sent_encoder = BERT(
            model_name="stsb-distilbert-base", threshold=0.9, metric="cosine"
        )
        constraints.append(sent_encoder)

        #
        # Candidate generation: masked-LM swaps for `A2T-MLM`, otherwise
        # word-embedding swaps plus an extra embedding-distance constraint.
        #
        if mlm:
            transformation = WordSwapMaskedLM(
                method="bae", max_candidates=20, min_confidence=0.0, batch_size=16
            )
        else:
            transformation = WordSwapEmbedding(max_candidates=20)
            constraints.append(WordEmbeddingDistance(min_cos_sim=0.8))

        #
        # Goal is untargeted classification
        #
        goal_function = UntargetedClassification(model_wrapper, model_batch_size=32)
        #
        # Greedily swap words with "Word Importance Ranking".
        #
        search_method = GreedyWordSwapWIR(wir_method="gradient")

        return Attack(goal_function, constraints, transformation, search_method)
|
||||
Reference in New Issue
Block a user