mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
WIP: mcts
WIP: mcts
This commit is contained in:
@@ -8,7 +8,7 @@ sentence_transformers
|
||||
spacy
|
||||
torch
|
||||
transformers>=2.0.0
|
||||
tensorflow-gpu==2
|
||||
tensorflow-gpu==2.0.0
|
||||
tensorflow_hub
|
||||
tqdm
|
||||
visdom
|
||||
|
||||
@@ -95,6 +95,7 @@ ATTACK_CLASS_NAMES = {
|
||||
'greedy-word': 'textattack.attack_methods.GreedyWordSwap',
|
||||
'ga-word': 'textattack.attack_methods.GeneticAlgorithm',
|
||||
'greedy-word-wir': 'textattack.attack_methods.GreedyWordSwapWIR',
|
||||
'mcts-old': 'textattack.attack_methods.MCTSOLD'
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from .greedy_word_swap import GreedyWordSwap
|
||||
from .greedy_word_swap_wir import GreedyWordSwapWIR
|
||||
from .genetic_algorithm import GeneticAlgorithm
|
||||
from .mcts_old import MCTSOLD
|
||||
@@ -13,7 +13,6 @@ class GreedyWordSwap(Attack):
|
||||
"""
|
||||
def __init__(self, model, transformation, constraints=[], max_depth=32):
|
||||
super().__init__(model, transformation, constraints=constraints)
|
||||
self.transformation = transformations[0]
|
||||
self.max_depth = max_depth
|
||||
|
||||
def attack_one(self, original_label, tokenized_text):
|
||||
|
||||
177
textattack/attack_methods/mcts.py
Normal file
177
textattack/attack_methods/mcts.py
Normal file
@@ -0,0 +1,177 @@
|
||||
from textattack.attacks import AttackResult, FailedAttackResult
|
||||
from textattack.attacks.blackbox import BlackBoxAttack
|
||||
import numpy as np
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.autograd import Variable
|
||||
|
||||
# Reward functions for MCTS
|
||||
def raw_prob_reward(prob, orig_label, new_label="None"):
    """Return ``max_{c != orig_label} P(c) - P(orig_label)`` as a Python float.

    Args:
        prob: 1-D tensor of per-class scores; it is exponentiated here, so it
            is presumably a vector of log-probabilities (confirm with caller).
        orig_label: index of the original class inside ``prob``.
        new_label: only meant for targeted attacks; currently unused.

    Returns:
        float: positive when some other class outscores the original label.
    """
    # torch.exp allocates a new tensor, so zeroing an entry below never
    # touches the caller's `prob`.
    prob_exp = torch.exp(prob.detach())
    orig_prob = prob_exp[orig_label].clone()
    # Mask the original class so that max() ranges over the other classes only.
    prob_exp[orig_label] = 0
    return (prob_exp.max() - orig_prob).item()
|
||||
|
||||
def entropy_reward(orig_label, prob):
    """Return the negative log-likelihood of ``orig_label`` under ``prob``.

    Args:
        orig_label: class index (int or integer tensor) of the original label.
        prob: tensor handed to ``F.nll_loss``; either a single 1-D score
            vector or a ``(batch, classes)`` matrix. ``nll_loss`` expects
            log-probabilities -- assumed, confirm with caller.

    Returns:
        float: the (mean) negative log-likelihood.
    """
    log_probs = prob.detach()
    if log_probs.dim() == 1:
        # A single score vector is treated as a batch of one.
        log_probs = log_probs.unsqueeze(0)
    target = torch.as_tensor(
        orig_label, dtype=torch.long, device=log_probs.device
    ).reshape(-1)
    # nll_loss already reduces to a 0-dim tensor; .item() converts it directly
    # (indexing a 0-dim tensor with [0] raises).
    return F.nll_loss(log_probs, target).item()
|
||||
|
||||
class Node:
    """ Helper node class used in implementation of MCTS"""

    def __init__(self, feature_set):
        """Create a node for the given state.

        Args:
            feature_set: 1-D boolean numpy array; True marks a selected word.
        """
        self.feature_set = feature_set  # Represents State
        self.f_size = self.feature_set.sum()  # Number of words selected for transformation
        self.childrens = {}
        self.T_f = .0  # visit counter
        self.av = .0  # running average reward
        # Indices currently set in feature_set, kept as a plain list of ints
        # (ndarray.nonzero() returns a tuple of arrays, which cannot be
        # appended to or tested for emptiness).
        self.allowed_features = np.flatnonzero(feature_set).tolist()
        # Per-feature statistics: one zero-initialised slot per feature.
        # np.array(shape) would instead build a single-element array holding
        # the feature count.
        self.lrave_count = np.zeros(feature_set.shape, dtype=int)
        self.lrave_reward = np.zeros(feature_set.shape, dtype=float)
        self.lrave_variance = np.zeros(feature_set.shape, dtype=float)
        self.lrave_score = np.zeros(feature_set.shape, dtype=float)
|
||||
|
||||
class Tree:
    """ Helper tree class used in implementation of MCTS; dictionary of nodes"""

    def __init__(self, nsize):
        """Create a tree holding only the root node (no word selected).

        Args:
            nsize: number of features (words) in a state.
        """
        # Nodes are keyed by the raw bytes of their boolean state array.
        self.tree = {}
        root = np.zeros(nsize, dtype=bool)
        self.tree[root.tobytes()] = Node(root)

    def find(self, feature_set):
        """ Returns node containing features_set if present in tree, else None"""
        # Single lookup: dict.get already yields None for a missing key.
        return self.tree.get(feature_set.tobytes())

    def save(self, feature_set, node):
        """ Adds node with feature_set to tree"""
        self.tree[feature_set.tobytes()] = node
|
||||
|
||||
class MCTS(BlackBoxAttack):
    """
    Uses Monte Carlo Tree Search (MCTS) to attempt to find the most important words in an input.

    NOTE(review): work in progress. ``UCB``, ``expansion``, ``simulation`` and
    ``back_up`` are empty stubs and ``selection`` still depends on helpers this
    class does not define (see its docstring). The class derives from
    ``BlackBoxAttack`` because ``__init__`` forwards arguments to
    ``super().__init__``, which a base-less class rejects -- confirm that this
    is the intended base.

    Args:
        model: A PyTorch or TensorFlow model to attack.
        transformation: The type of transformation to use. Should be a subclass of WordSwap.
        constraints: A list of constraints to add to the attack
        reward_type (str): Defines what type of function to use for MCTS.
            - raw_prob: Uses "max_{c'} P(c'|s) - P(c_l|s)"
            - entropy: Uses negative log-likelihood function
        max_iter (int) : Maximum iterations for MCTS. Default is 4000
        max_words_changed (int) : Maximum number of words we change during MCTS. Effectively represents depth of search tree.

    Raises:
        ValueError: if ``reward_type`` is neither "raw_prob" nor "entropy".
    """

    def __init__(self, model, transformation, constraints=[],
        reward_type="raw_prob", max_iter=4000, max_words_changed=10
    ):
        super().__init__(model, transformation, constraints=constraints)
        self.reward_type = reward_type
        self.max_iter = max_iter
        self.max_words_changed = max_words_changed
        self.alltimebest = -1e9
        self.bestfeature = []
        self.tree = None  # built by run_mcts

        if reward_type == "raw_prob":
            self.reward_func = raw_prob_reward
        elif reward_type == "entropy":
            self.reward_func = entropy_reward
        else:
            # Fail here rather than with an AttributeError on first use.
            raise ValueError(f"Unknown reward_type: {reward_type!r}")

    def get_reward(self, current_state, input_text):
        """Apply one random transformation per selected index and query the model.

        Args:
            current_state: boolean sequence; True marks an index to transform.
            input_text: tokenized text the transformations start from.

        Returns:
            The model output for the transformed text (first row of
            ``_call_model``).
            TODO: turn this into a scalar with ``self.reward_func``; that needs
            the original label, which is not available in this method yet.
        """
        transformed_text = input_text
        for i, selected in enumerate(current_state):
            if not selected:
                continue
            # NOTE(review): call shape aligned with _attack_one (transformation
            # passed first) -- confirm against BlackBoxAttack.get_transformations.
            transformed_candidates = self.get_transformations(
                self.transformation,
                transformed_text,
                indices_to_replace=[i]
            )
            if len(transformed_candidates) > 0:
                rand = np.random.randint(len(transformed_candidates))
                transformed_text = transformed_candidates[rand]

        # Evaluate current features against model
        return self._call_model([transformed_text])[0]

    def UCB(self):
        """Not implemented yet."""
        pass

    def expansion(self):
        """Not implemented yet."""
        pass

    def simulation(self):
        """Not implemented yet."""
        pass

    def back_up(self):
        """Not implemented yet."""
        pass

    def selection(self, current_state, depth):
        """Selection step of one MCTS iteration; returns the reward obtained.

        NOTE(review): unfinished port. It references ``self.params``,
        ``self.update_gRAVE``, ``self.update_Tree_And_Get_Address``,
        ``self.iterate_random``, ``self.iterate`` and ``self.update_Node``,
        none of which this class defines, and it calls ``self.UCB`` with a node
        argument the current stub does not accept.
        """
        if depth >= self.params.max_depth:
            reward_V = self.reward_func(current_state)
            self.update_gRAVE(current_state, reward_V)
        else:
            next_node = self.update_Tree_And_Get_Address(current_state)

            if next_node.T_f != 0:
                fi = self.UCB(next_node)
                if fi == -1:
                    # No feature has been selected: perform random exploration.
                    reward_V = self.iterate_random(self.tree, current_state)
                    self.update_gRAVE(current_state, reward_V)
                else:
                    # Add the feature to the feature set.
                    current_state[fi] = True
                    reward_V = self.iterate(current_state, depth+1)
            else:
                reward_V = self.iterate_random(self.tree, current_state)
                self.update_gRAVE(current_state, reward_V)
                fi = -1  # random exploration was performed, so no feature selected
            self.update_Node(next_node, fi, reward_V)
        return reward_V

    def run_mcts(self, orig_label, tokenized_input):
        """Run ``self.max_iter`` MCTS iterations over ``tokenized_input``.

        Args:
            orig_label: original label of the input (not used yet).
            tokenized_input: sequence of words; its length sets the state size.
        """
        input_size = len(tokenized_input)
        # Kept on the instance because selection() reads self.tree.
        self.tree = Tree(input_size)
        for i in range(self.max_iter):
            if i % 100 == 0:
                print(f"Running MCTS iteration {i}")
            current_state = np.array([False] * input_size, dtype = bool)

            self.selection(current_state, 0)
            self.expansion()
            self.simulation()
            self.back_up()

    def _attack_one(self, original_label, tokenized_text):
        """Attack one input: run the search, then transform the chosen indices.

        Returns:
            ``FailedAttackResult`` if the predicted label is unchanged,
            otherwise ``AttackResult`` carrying the perturbed text and label.
        """
        self.run_mcts(original_label, tokenized_text.words)

        new_tokenized_text = tokenized_text

        # Transform each index selected by MCTS using given transformation
        for k, selected in enumerate(self.bestfeature):
            if not selected:
                continue
            transformed_text_candidates = self.get_transformations(
                self.transformation,
                new_tokenized_text,
                indices_to_replace=[k])
            if len(transformed_text_candidates) > 0:
                rand = np.random.randint(len(transformed_text_candidates))
                new_tokenized_text = transformed_text_candidates[rand]

        # One model call is enough; the label is the argmax of that output.
        new_output = self._call_model([new_tokenized_text])[0]
        new_text_label = new_output.argmax().item()

        if original_label == new_text_label:
            return FailedAttackResult(tokenized_text, original_label)
        else:
            return AttackResult(
                tokenized_text,
                new_tokenized_text,
                original_label,
                new_text_label
            )
|
||||
258
textattack/attack_methods/mcts_old.py
Normal file
258
textattack/attack_methods/mcts_old.py
Normal file
@@ -0,0 +1,258 @@
|
||||
from .attack import Attack
|
||||
from textattack.attack_results import AttackResult, FailedAttackResult
|
||||
import numpy as np
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.autograd import Variable
|
||||
|
||||
class dotdict(dict):
    """Dictionary whose items can also be read, written and deleted as attributes."""
    # Attribute writes and deletes map straight onto item assignment/deletion.
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
    # Reads go through dict.get, so a missing key yields None instead of raising.
    __getattr__ = dict.get
|
||||
|
||||
class Node:
    """ Helper node class used in implementation of MCTS"""

    def __init__(self, feature_set):
        """Create a node for the given state.

        Args:
            feature_set: 1-D boolean numpy array; True marks a selected word.
        """
        self.feature_set = feature_set  # Represents State
        self.f_size = self.feature_set.sum()  # Number of words selected for transformation
        self.childrens = {}
        self.T_f = .0  # visit counter, incremented by MCTSOLD.update_Node
        self.av = .0  # running average reward, maintained by MCTSOLD.update_Node
        # Indices currently set in feature_set, kept as a plain list of ints:
        # MCTSOLD.UCB appends to it and tests it for emptiness, which the
        # tuple returned by ndarray.nonzero() does not support.
        self.allowed_features = np.flatnonzero(feature_set).tolist()
        # Per-feature statistics: one zero-initialised slot per feature.
        # np.array(shape) would instead build a single-element array holding
        # the feature count.
        self.lrave_count = np.zeros(feature_set.shape, dtype=int)
        self.lrave_reward = np.zeros(feature_set.shape, dtype=float)
        self.lrave_variance = np.zeros(feature_set.shape, dtype=float)
        self.lrave_score = np.zeros(feature_set.shape, dtype=float)
|
||||
|
||||
class Tree:
    """ Helper tree class used in implementation of MCTS; dictionary of nodes"""

    def __init__(self, nsize):
        """Create a tree holding only the root node (no word selected).

        Args:
            nsize: number of features (words) in a state.
        """
        # Nodes are keyed by the raw bytes of their boolean state array.
        self.tree = {}
        root = np.zeros(nsize, dtype=bool)
        self.tree[root.tobytes()] = Node(root)

    def find(self, feature_set):
        """ Returns node containing features_set if present in tree, else None"""
        # Single lookup: dict.get already yields None for a missing key.
        return self.tree.get(feature_set.tobytes())

    def save(self, feature_set, node):
        """ Adds node with feature_set to tree"""
        self.tree[feature_set.tobytes()] = node
|
||||
|
||||
class MCTSOLD(Attack):
    """ Uses Monte Carlo Tree Search (MCTS) to attempt to find the most important words in an input.

    Args:
        model: model to attack (forwarded to ``Attack``).
        transformations: forwarded to ``Attack``.
        constraints: forwarded to ``Attack``.
        rewardvalue (str): reward selector. A value containing 'combined' uses
            ``max_{c != label} exp(out[c]) - exp(out[label])``; one containing
            'entropy' uses ``-out[label]``; anything else uses
            ``1 - 2 * exp(out[label])``.
        numiter (int): number of MCTS iterations per attacked input.
        max_words_changed (int): maximum number of selected word positions;
            used as the maximum search depth.
    """

    def __init__(self, model, transformations=[], constraints=[], rewardvalue='combined', numiter=4000, max_words_changed=10):
        super().__init__(model, transformations, constraints)
        self.params = dotdict({})
        self.params.get_reward = None
        self.params.b = 0.26  # b < 1
        self.params.cl = 0.5
        self.params.ce = 0.5
        self.params.max_depth = 0
        self.params.nrand = 50
        self.gRAVE = None
        self.tree = None
        self.alltimebest = 0
        self.bestfeature = []
        self.rewardvalue = rewardvalue
        self.numiter = numiter
        self.max_words_changed = max_words_changed

    def update_Node(self, node, fi, reward_V):
        """ Updates the node's scores based on how its feature set performed against the model

        Args:
            node: node whose statistics are updated.
            fi: index of the feature that was selected, or -1 when random
                exploration was performed and no feature was selected.
            reward_V: reward obtained for this visit.
        """
        node.av = (node.av * node.T_f + reward_V)/(node.T_f+1)
        node.T_f += 1
        if fi < 0:
            # No feature was selected; indexing the per-feature arrays with -1
            # would silently update their last slot instead.
            return
        count = node.lrave_count[fi]
        previous_mean = node.lrave_score[fi]
        mean = (previous_mean * count + reward_V) / (count + 1)
        node.lrave_score[fi] = mean
        # Running spread estimate built from the old and new running means.
        node.lrave_variance[fi] = math.sqrt(
            ((reward_V - mean) * (reward_V - previous_mean)
             + count * node.lrave_variance[fi] * node.lrave_variance[fi]) / (count + 1))
        node.lrave_count[fi] = count + 1

    def update_gRAVE(self, F, reward_V):
        """ Update gRAVE score for each feature of feature subset F, by folding reward_V into the score
            gRAVE[0] -> Score (running average reward)
            gRAVE[1] -> Count
        """
        scores, counts = self.gRAVE
        for fi in range(scores.shape[0]):
            if F[fi]:
                scores[fi] = (scores[fi] * counts[fi] + reward_V) / (counts[fi] + 1)
                counts[fi] += 1

    def UCB(self, node):
        """ Upper confidence bound calculation to help balance exploration/exploitation tradeoff

        Returns:
            Index of the feature to add next, or -1 when no feature is chosen
            and the caller should perform random exploration.

        NOTE(review): with b < 1 the widening test below can never exceed 1
        for T_f >= 1 (the increment of x**b is at most b), so this method
        currently always returns -1 and the search degenerates to random
        roll-outs. The code after the test was reconstructed from undefined
        names (mu_f, t_f, sg_f, T_F mapped to lrave_score, lrave_count,
        lrave_variance, T_f) and needs validation before the test is changed.
        """
        if node.T_f < self.params.nrand:
            # Random exploration for the first `nrand` visits.
            return -1

        b = self.params.b
        if pow(node.T_f + 1, b) - pow(node.T_f, b) <= 1:
            return -1
        if not node.allowed_features:
            return -1

        # Widen the node with the not-yet-allowed feature whose blend of local
        # and global RAVE score is highest.
        best_candidate = None
        best_rave = -math.inf
        for ft in range(node.feature_set.shape[0]):
            if ft in node.allowed_features:
                continue
            beta = self.params.cl / (self.params.cl + node.lrave_count[ft])
            rst = (1 - beta) * node.lrave_score[ft] + beta * self.gRAVE[0][ft]
            if rst > best_rave:
                best_candidate = ft
                best_rave = rst
        if best_candidate is not None:
            node.allowed_features.append(best_candidate)

        # Pick the allowed feature with the highest UCB score.
        UCB_max_score = -math.inf
        UCB_max_feature = -1
        log_visits = math.log(node.T_f)
        for next_node in node.allowed_features:
            visits = node.lrave_count[next_node]
            if visits == 0:
                # Never tried from this node: explore it first.
                UCB_Score = math.inf
            else:
                UCB_Score = node.lrave_score[next_node] + math.sqrt(
                    self.params.ce * log_visits / visits
                    * min(0.25, pow(node.lrave_variance[next_node], 2)
                          + math.sqrt(2 * log_visits / visits)))
            if UCB_Score > UCB_max_score:
                UCB_max_score = UCB_Score
                UCB_max_feature = next_node
        return UCB_max_feature

    def update_Tree_And_Get_Address(self, current_features):
        """ Updates MCTS tree to contain the set of current features if necessary, returning the newly created or found node"""
        node = self.tree.find(current_features)
        if node is None:
            # Store a snapshot: the caller keeps mutating current_features.
            node = Node(current_features.copy())
            self.tree.save(current_features, node)
        return node

    def iterate(self, current_features, depth):
        """ Perform one iteration of MCTS search

        Args:
            current_features: boolean array where True means the index is
                selected for transformation; modified in place.
            depth: current depth.

        Returns:
            The reward obtained for this iteration.
        """
        if depth >= self.params.max_depth:
            reward_V = self.params.get_reward(current_features)
            self.update_gRAVE(current_features, reward_V)
        else:
            next_node = self.update_Tree_And_Get_Address(current_features)

            if next_node.T_f != 0:
                fi = self.UCB(next_node)
                if fi == -1:
                    # No feature has been selected: perform random exploration.
                    reward_V = self.iterate_random(self.tree, current_features)
                    self.update_gRAVE(current_features, reward_V)
                else:
                    # Add the feature to the feature set.
                    current_features[fi] = True
                    reward_V = self.iterate(current_features, depth+1)
            else:
                reward_V = self.iterate_random(self.tree, current_features)
                self.update_gRAVE(current_features, reward_V)
                fi = -1  # random exploration was performed, so no feature selected
            self.update_Node(next_node, fi, reward_V)
        return reward_V

    def iterate_random(self, tree, current_features):
        """ Roll-out: randomly set unselected features until max_depth are set, then score the result.

        Args:
            tree: unused.
            current_features: boolean array, modified in place.

        Returns:
            The reward of the completed feature set.
        """
        f_num = current_features.shape[0]
        f_size = current_features.sum()
        while f_size < self.params.max_depth and f_size < f_num:
            # Uniformly pick one of the features that are still unset.
            unset = np.flatnonzero(np.logical_not(current_features))
            current_features[np.random.choice(unset)] = True
            f_size += 1

        return self.params.get_reward(current_features)

    def runmcts(self, rewardfunc, maxdepth, lcount, nsize):
        """ Runs lcount iterations of MCTS

        Args:
            rewardfunc: callable mapping a boolean feature array to a reward.
            maxdepth: maximum depth of the tree.
            lcount: number of iterations.
            nsize: number of words.
        """
        self.params.get_reward = rewardfunc
        self.params.max_depth = maxdepth
        # (running average score, count) per word; see update_gRAVE.
        self.gRAVE = (np.array([.0] * nsize), np.array([.0] * nsize))
        self.tree = Tree(nsize)
        for i in range(lcount):
            current_features = np.array([False] * nsize, dtype = bool)
            self.iterate(current_features, 0)

    def attack_one(self, original_label, tokenized_text):
        """Attack one input: search for the best word set, then transform it.

        Returns:
            ``FailedAttackResult`` if the predicted label is unchanged,
            otherwise ``AttackResult`` carrying the perturbed text and label.
        """
        # The reward function used to evaluate how good current_features is at perturbing input
        def policyvaluefunc(current_features):
            test_input = tokenized_text
            # Transform each selected index using the given transformation
            for k, selected in enumerate(current_features):
                if not selected:
                    continue
                transformed_text_candidates = self.get_transformations(
                    test_input,
                    indices_to_replace=[k])
                if len(transformed_text_candidates) > 0:
                    rand = np.random.randint(len(transformed_text_candidates))
                    test_input = transformed_text_candidates[rand]

            # Evaluate current features against model; `output` is the score
            # vector of the single input, indexed by class below.
            output = self._call_model([test_input])[0]

            if 'combined' in self.rewardvalue:
                prob_exp = torch.exp(output)
                v1 = (prob_exp).data[original_label].clone()
                # Mask the original class so max() ranges over the others.
                (prob_exp).data[original_label] = 0
                v2 = prob_exp.max().data
                value = v2 - v1
            elif 'entropy' in self.rewardvalue:
                # Negative log-likelihood of one sample is minus its score.
                value = -output.data[original_label]
            else:
                value = 1 - (torch.exp(output)).data[original_label] * 2

            value = value.item()
            if self.alltimebest < value:
                self.alltimebest = value
                self.bestfeature = np.copy(current_features)
            return value

        self.alltimebest = -1e9
        self.bestfeature = []

        self.runmcts(policyvaluefunc, self.max_words_changed, self.numiter, len(tokenized_text.words))

        new_tokenized_text = tokenized_text

        # Transform each index selected by MCTS using given transformation
        for k, selected in enumerate(self.bestfeature):
            if not selected:
                continue
            transformed_text_candidates = self.get_transformations(
                new_tokenized_text,
                indices_to_replace=[k])
            if len(transformed_text_candidates) > 0:
                rand = np.random.randint(len(transformed_text_candidates))
                new_tokenized_text = transformed_text_candidates[rand]

        # One model call is enough; the label is the argmax of that output.
        new_output = self._call_model([new_tokenized_text])[0]
        new_text_label = new_output.argmax().item()

        if original_label == new_text_label:
            return FailedAttackResult(tokenized_text, original_label)
        else:
            return AttackResult(
                tokenized_text,
                new_tokenized_text,
                original_label,
                new_text_label
            )
|
||||
0
textattack/cpackages/mcts.cpp
Normal file
0
textattack/cpackages/mcts.cpp
Normal file
52
textattack/cpackages/mcts.hpp
Normal file
52
textattack/cpackages/mcts.hpp
Normal file
@@ -0,0 +1,52 @@
|
||||
#include <cstddef>
#include <string>
#include <vector>
|
||||
|
||||
|
||||
/**
 * Represents state in our MCTS tree search.
 * Uses a boolean vector to represent whether the word at each index is
 * selected for attack transformation:
 *   0 - Not selected
 *   1 - Selected for attack
 */
class State {
    std::vector<bool> state;

public:
    /** Creates an empty state (no word positions). */
    State() = default;

    /** Creates a state with n positions, all initially not selected. */
    State(int n) : state(n) {}

    /** Returns a copy of the selection flags. */
    std::vector<bool> getState() const { return state; }

    /** Sets the flag at index i to b; i must be smaller than the state size. */
    void setState(bool b, size_t i) { state[i] = b; }
};
|
||||
|
||||
class Node {
|
||||
State current_state;
|
||||
|
||||
public:
|
||||
Node()
|
||||
Node(bool [] cs) curr_selections(cs);
|
||||
};
|
||||
|
||||
class SearchTree {
|
||||
Node root;
|
||||
public:
|
||||
Tree() root(Node())
|
||||
Tree()
|
||||
};
|
||||
|
||||
|
||||
/**
 * Configuration and best-result bookkeeping for the MCTS attack.
 * Requires <string>, <vector> and <cstddef>.
 */
class MCTS {
    // Initial values mirror the defaults of the Python MCTS constructor
    // (reward_type="raw_prob", max_iter=4000, max_words_changed=10).
    std::string reward_type = "raw_prob";
    int max_iter = 4000;
    int max_words_changed = 10;
    double all_time_best = -1e9;
    std::vector<size_t> best_feature;

public:
    /** Creates a search object with the default settings above. */
    MCTS() = default;
};
|
||||
|
||||
//TODO extend MCTS-RAVE
|
||||
0
textattack/cpackages/mcts.pyx
Normal file
0
textattack/cpackages/mcts.pyx
Normal file
Reference in New Issue
Block a user