textattack-nlp-transformer/textattack/models/tokenizers/spacy_tokenizer.py

import spacy

from textattack.models.tokenizers import Tokenizer


class SpacyTokenizer(Tokenizer):
    """A basic implementation of the spaCy English tokenizer.

    Params:
        word2id (dict<string, int>): A dictionary that maps words to IDs
        oov_id (int): The ID assigned to out-of-vocabulary tokens
        pad_id (int): The ID used to pad token sequences up to ``max_length``
        max_length (int): The maximum sequence length, in tokens
    """

    def __init__(self, word2id, oov_id, pad_id, max_length=128):
        self.tokenizer = spacy.load("en").tokenizer
        self.word2id = word2id
        self.id2word = {v: k for k, v in word2id.items()}
        self.oov_id = oov_id
        self.pad_id = pad_id
        self.max_length = max_length

    def convert_text_to_tokens(self, text):
        """Tokenizes ``text`` with spaCy and truncates the result to
        ``max_length`` tokens."""
        if isinstance(text, tuple):
            if len(text) > 1:
                raise TypeError(
                    "Cannot train LSTM/CNN models with multi-sequence inputs."
                )
            text = text[0]
        if not isinstance(text, str):
            raise TypeError(
                f"SpacyTokenizer can only tokenize `str`, got type {type(text)}"
            )
        spacy_tokens = [t.text for t in self.tokenizer(text)]
        return spacy_tokens[: self.max_length]

    def convert_tokens_to_ids(self, tokens):
        """Maps each lowercased token to its vocabulary ID, substituting
        ``oov_id`` for unknown words, then pads the sequence to ``max_length``."""
        ids = []
        for raw_token in tokens:
            token = raw_token.lower()
            if token in self.word2id:
                ids.append(self.word2id[token])
            else:
                ids.append(self.oov_id)
        pad_ids_to_add = [self.pad_id] * (self.max_length - len(ids))
        ids += pad_ids_to_add
        return ids

    def convert_id_to_word(self, _id):
        """Takes an integer input and returns the corresponding word from the
        vocabulary.

        Raises: KeyError on OOV.
        """
        return self.id2word[_id]
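

# --- Illustrative usage sketch (not part of the original module) ---
# A minimal example of how SpacyTokenizer might be driven end-to-end. The tiny
# `word2id` vocabulary below is hypothetical; in TextAttack this mapping comes
# from a trained word-embedding vocabulary. Running it requires the spaCy
# English model that `spacy.load("en")` expects to be installed.
if __name__ == "__main__":
    word2id = {"<pad>": 0, "<oov>": 1, "the": 2, "movie": 3, "was": 4, "great": 5}
    tokenizer = SpacyTokenizer(word2id, oov_id=1, pad_id=0, max_length=8)

    tokens = tokenizer.convert_text_to_tokens("The movie was great!")
    ids = tokenizer.convert_tokens_to_ids(tokens)

    print(tokens)  # ['The', 'movie', 'was', 'great', '!']
    print(ids)     # [2, 3, 4, 5, 1, 0, 0, 0] -- "!" maps to oov_id, rest is padding
    print(tokenizer.convert_id_to_word(ids[0]))  # "the"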