mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
add args
This commit is contained in:
@@ -1,25 +1,25 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from abc import abstractmethod
|
||||
|
||||
from textattack.constraints import Constraint
|
||||
|
||||
|
||||
class LanguageModelConstraint(ABC, Constraint):
|
||||
class LanguageModelConstraint(Constraint):
|
||||
"""
|
||||
Determines if two sentences have a swapped word that has a similar
|
||||
probability according to a language model.
|
||||
|
||||
Args:
|
||||
max_log_prob_diff (float): the maximum decrease in log-probability
|
||||
in swapped words from x to x_adv
|
||||
compare_against_original (bool): whether to compare against the original
|
||||
text or the most recent
|
||||
in swapped words from `x` to `x_adv`
|
||||
compare_against_original (bool): If `True`, compare new `x_adv` against the original `x`.
|
||||
Otherwise, compare it against the previous `x_adv`.
|
||||
"""
|
||||
|
||||
def __init__(self, max_log_prob_diff=None, compare_against_original=True):
|
||||
if max_log_prob_diff is None:
|
||||
raise ValueError("Must set max_log_prob_diff")
|
||||
self.max_log_prob_diff = max_log_prob_diff
|
||||
self.compare_against_original = compare_against_original
|
||||
super().__init__(compare_against_original)
|
||||
|
||||
@abstractmethod
|
||||
def get_log_probs_at_index(self, text_list, word_index):
|
||||
|
||||
Reference in New Issue
Block a user