1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

BERT sentence embedding constraint

This commit is contained in:
Jack Morris
2019-11-27 17:14:53 -05:00
parent 6f3b704be2
commit 47b34c8bc6
7 changed files with 28 additions and 2 deletions

3
.gitignore vendored
View File

@@ -15,3 +15,6 @@ docs/_build/
# Packaging
*.egg-info/
# Files from IDES
.*.py

View File

@@ -1,6 +1,7 @@
language_check
nltk
numpy<1.17
sentence_transformers
scipy
torch
transformers==2.0.0

View File

@@ -121,7 +121,7 @@ class Attack:
"""
raise NotImplementedError()
def _call_model(self, tokenized_text_list, batch_size=16):
def _call_model(self, tokenized_text_list, batch_size=8):
"""
Returns model predictions for a list of TokenizedText objects.

View File

@@ -1,4 +1,5 @@
from .sentence_encoder import SentenceEncoder
from .bert import BERT
from .infer_sent import InferSent
from .universal_sentence_encoder import UniversalSentenceEncoder

View File

@@ -0,0 +1 @@
from .bert import BERT

View File

@@ -0,0 +1,18 @@
from sentence_transformers import SentenceTransformer
from textattack.constraints.semantics.sentence_encoders import SentenceEncoder
from textattack.utils import get_device
class BERT(SentenceEncoder):
"""
Constraint using similarity between sentence encodings of x and x_adv where
the text embeddings are created using BERT, trained on NLI data, and fine-
tuned on the STS benchmark dataset.
"""
def __init__(self, use_version=3, *args, **kwargs):
super().__init__(*args, **kwargs)
self.model = SentenceTransformer('bert-base-nli-mean-tokens')
self.model.to(get_device())
def encode(self, sentences):
return self.model.encode(sentences)

View File

@@ -70,10 +70,12 @@ TRANSFORMATION_CLASS_NAMES = {
}
CONSTRAINT_CLASS_NAMES = {
'embedding': 'constraints.semantics.WordEmbeddingDistance',
'goog-lm': 'constraints.semantics.language_models.GoogleLanguageModel',
'bert': 'constraints.semantics.sentence_encoders.BERT',
'infer-sent': 'constraints.semantics.sentence_encoders.InferSent',
'use': 'constraints.semantics.sentence_encoders.UniversalSentenceEncoder',
'lang-tool': 'constraints.syntax.LanguageTool',
'goog-lm': 'constraints.semantics.language_models.GoogleLanguageModel',
}
ATTACK_CLASS_NAMES = {