mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
BERT sentence embedding constraint
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -15,3 +15,6 @@ docs/_build/
|
||||
|
||||
# Packaging
|
||||
*.egg-info/
|
||||
|
||||
# Files from IDES
|
||||
.*.py
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
language_check
|
||||
nltk
|
||||
numpy<1.17
|
||||
sentence_transformers
|
||||
scipy
|
||||
torch
|
||||
transformers==2.0.0
|
||||
|
||||
@@ -121,7 +121,7 @@ class Attack:
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _call_model(self, tokenized_text_list, batch_size=16):
|
||||
def _call_model(self, tokenized_text_list, batch_size=8):
|
||||
"""
|
||||
Returns model predictions for a list of TokenizedText objects.
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from .sentence_encoder import SentenceEncoder
|
||||
|
||||
from .bert import BERT
|
||||
from .infer_sent import InferSent
|
||||
from .universal_sentence_encoder import UniversalSentenceEncoder
|
||||
@@ -0,0 +1 @@
|
||||
from .bert import BERT
|
||||
@@ -0,0 +1,18 @@
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
from textattack.constraints.semantics.sentence_encoders import SentenceEncoder
|
||||
from textattack.utils import get_device
|
||||
|
||||
class BERT(SentenceEncoder):
|
||||
"""
|
||||
Constraint using similarity between sentence encodings of x and x_adv where
|
||||
the text embeddings are created using BERT, trained on NLI data, and fine-
|
||||
tuned on the STS benchmark dataset.
|
||||
"""
|
||||
def __init__(self, use_version=3, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.model = SentenceTransformer('bert-base-nli-mean-tokens')
|
||||
self.model.to(get_device())
|
||||
|
||||
def encode(self, sentences):
|
||||
return self.model.encode(sentences)
|
||||
@@ -70,10 +70,12 @@ TRANSFORMATION_CLASS_NAMES = {
|
||||
}
|
||||
|
||||
CONSTRAINT_CLASS_NAMES = {
|
||||
'embedding': 'constraints.semantics.WordEmbeddingDistance',
|
||||
'goog-lm': 'constraints.semantics.language_models.GoogleLanguageModel',
|
||||
'bert': 'constraints.semantics.sentence_encoders.BERT',
|
||||
'infer-sent': 'constraints.semantics.sentence_encoders.InferSent',
|
||||
'use': 'constraints.semantics.sentence_encoders.UniversalSentenceEncoder',
|
||||
'lang-tool': 'constraints.syntax.LanguageTool',
|
||||
'goog-lm': 'constraints.semantics.language_models.GoogleLanguageModel',
|
||||
}
|
||||
|
||||
ATTACK_CLASS_NAMES = {
|
||||
|
||||
Reference in New Issue
Block a user