textattack/attack_recipes/bert_attack_li_2020.py

from textattack.constraints.overlap import MaxWordsPerturbed
from textattack.constraints.pre_transformation import (
    RepeatModification,
    StopwordModification,
)
from textattack.constraints.semantics.sentence_encoders import UniversalSentenceEncoder
from textattack.goal_functions import UntargetedClassification
from textattack.search_methods import GreedyWordSwapWIR
from textattack.shared.attack import Attack
from textattack.transformations import WordSwapMaskedLM


def BERTAttackLi2020(model):
    """
    Li, L., Ma, R., Guo, Q., Xue, X., Qiu, X. (2020).

    BERT-ATTACK: Adversarial Attack Against BERT Using BERT

    https://arxiv.org/abs/2004.09984

    This is "attack mode" 1 from the paper: word replacement, with substitute
    candidates generated by a BERT masked language model.
    """
    # [from correspondence with the author]
    # Candidate size K is set to 48 for all data-sets.
    transformation = WordSwapMaskedLM(method="bert-attack", max_candidates=48)
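    # For intuition (illustrative only; the candidates below are assumptions,
    # not the output of an actual model run): to replace "film" in
    # "the film was great", the masked language model proposes the 48 most
    # probable in-context substitutes, e.g. "movie", "show", "picture".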
    #
    # Don't modify the same word twice or stopwords.
    #
    constraints = [RepeatModification(), StopwordModification()]
    # "We only take ε percent of the most important words since we tend to keep
    # perturbations minimum."
    #
    # [from correspondence with the author]
    # "Word percentage allowed to change is set to 0.4 for most data-sets, this
    # parameter is trivial since most attacks only need a few changes. This
    # epsilon is only used to avoid too much queries on those very hard samples."
    constraints.append(MaxWordsPerturbed(max_percent=0.4))
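    # e.g., with max_percent=0.4, at most 4 words of a 10-word input may be
    # perturbed.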
# "As used in TextFooler (Jin et al., 2019), we also use Universal Sentence
# Encoder (Cer et al., 2018) to measure the semantic consistency between the
# adversarial sample and the original sequence. To balance between semantic
# preservation and attack success rate, we set up a threshold of semantic
# similarity score to filter the less similar examples."
#
# [from correspondence with author]
# "Over the full texts, after generating all the adversarial samples, we filter
# out low USE score samples. Thus the success rate is lower but the USE score
# can be higher. (actually USE score is not a golden metric, so we simply
# measure the USE score over the final texts for a comparison with TextFooler).
# For datasets like IMDB, we set a higher threshold between 0.4-0.7; for
# datasets like MNLI, we set threshold between 0-0.2."
#
# Since the threshold in the real world can't be determined from the training
# data, the TextAttack implementation uses a fixed threshold - determined to
# be 0.2 to be most fair.
use_constraint = UniversalSentenceEncoder(
threshold=0.2, metric="cosine", compare_with_original=True, window_size=None,
)
constraints.append(use_constraint)
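    # Note: compare_with_original=True scores each candidate against the
    # original text (rather than the previous perturbed text), and
    # window_size=None disables windowed comparison, so similarity is computed
    # over the full texts.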
    #
    # Goal is untargeted classification.
    #
    goal_function = UntargetedClassification(model)
    #
    # "We first select the words in the sequence which have a high significance
    # influence on the final output logit. Let S = [w0, ···, wi, ···] denote
    # the input sentence, and oy(S) denote the logit output by the target model
    # for correct label y, the importance score Iwi is defined as
    # Iwi = oy(S) - oy(S\wi), where S\wi = [w0, ···, wi-1, [MASK], wi+1, ···]
    # is the sentence after replacing wi with [MASK]. Then we rank all the words
    # according to the ranking score Iwi in descending order to create word list
    # L."
    search_method = GreedyWordSwapWIR(wir_method="unk")
    return Attack(goal_function, constraints, transformation, search_method)
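

# Usage sketch (not part of the original file): this recipe is registered
# with the TextAttack CLI under the name "bert-attack". A typical invocation,
# assuming a pre-trained TextAttack model name such as
# "bert-base-uncased-imdb", would be:
#
#   textattack attack --recipe bert-attack --model bert-base-uncased-imdb \
#       --num-examples 10
#
# The exact model names and flags depend on the installed TextAttack version.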