import spacy

from textattack.models.tokenizers import Tokenizer


class SpacyTokenizer(Tokenizer):
    """A basic implementation of the spaCy English tokenizer.

    Params:
        word2id (dict<string, int>): A dictionary that maps words to IDs.
        oov_id (int): The ID assigned to out-of-vocabulary tokens.
        pad_id (int): The ID used to pad token sequences up to max_length.
        max_length (int): The maximum number of tokens per input.
    """

    def __init__(self, word2id, oov_id, pad_id, max_length=128):
        # The "en" shortcut link assumes spaCy 2.x; under spaCy 3.x an
        # installed pipeline name such as "en_core_web_sm" is required.
        self.tokenizer = spacy.load("en").tokenizer
        self.word2id = word2id
        # Inverse vocabulary, used to recover words from IDs.
        self.id2word = {v: k for k, v in word2id.items()}
        self.oov_id = oov_id
        self.pad_id = pad_id
        self.max_length = max_length

    def convert_text_to_tokens(self, text):
        if isinstance(text, tuple):
            if len(text) > 1:
                raise TypeError(
                    "Cannot train LSTM/CNN models with multi-sequence inputs."
                )
            text = text[0]
        if not isinstance(text, str):
            raise TypeError(
                f"SpacyTokenizer can only tokenize `str`, got type {type(text)}"
            )
        spacy_tokens = [t.text for t in self.tokenizer(text)]
        # Truncate to at most max_length tokens.
        return spacy_tokens[: self.max_length]

    def convert_tokens_to_ids(self, tokens):
        ids = []
        for raw_token in tokens:
            # Lookup is case-insensitive; unknown tokens map to oov_id.
            token = raw_token.lower()
            if token in self.word2id:
                ids.append(self.word2id[token])
            else:
                ids.append(self.oov_id)
        # Pad with pad_id up to max_length (assumes tokens were already
        # truncated to max_length by convert_text_to_tokens).
        pad_ids_to_add = [self.pad_id] * (self.max_length - len(ids))
        ids += pad_ids_to_add
        return ids

    def convert_id_to_word(self, _id):
        """Takes an integer input and returns the corresponding word from the
        vocabulary.

        Raises: KeyError on OOV.
        """
        return self.id2word[_id]
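

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal, hypothetical
# example of driving SpacyTokenizer end to end. The toy vocabulary and the
# reserved IDs below are assumptions for illustration; in practice, word2id
# comes from the trained model's vocabulary. Running this also assumes a
# spaCy English model is installed.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Hypothetical toy vocabulary; 0 and 1 are reserved for OOV and padding.
    word2id = {"the": 2, "cat": 3, "sat": 4, "on": 5, "mat": 6, ".": 7}
    tokenizer = SpacyTokenizer(word2id, oov_id=0, pad_id=1, max_length=10)

    tokens = tokenizer.convert_text_to_tokens("The cat sat on the mat.")
    # spaCy splits punctuation: ['The', 'cat', 'sat', 'on', 'the', 'mat', '.']
    ids = tokenizer.convert_tokens_to_ids(tokens)
    # Tokens are lowercased before lookup, then padded with pad_id:
    # [2, 3, 4, 5, 2, 6, 7, 1, 1, 1]
    print(tokens)
    print(ids)
    print(tokenizer.convert_id_to_word(ids[1]))  # -> 'cat'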