mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

pyarrow<1.0 (temporary fix for nlp package)

Jack Morris
2020-07-29 10:26:49 -04:00
parent 7acb17aa27
commit a04a6a847a
4 changed files with 8 additions and 4 deletions

View File

@@ -9,6 +9,7 @@ nlp
 nltk
 numpy
 pandas>=1.0.1
+pyarrow<1.0
 scikit-learn
 scipy==1.4.1
 sentence_transformers>0.2.6

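The added pin keeps any fresh install on the last pyarrow 0.x release, presumably because the nlp package does not yet work with pyarrow 1.0 (the commit message calls this a temporary fix). As an illustrative check that is not part of the commit, the packaging library shows which versions the new requirements line admits:

    from packaging.requirements import Requirement
    from packaging.version import Version

    req = Requirement("pyarrow<1.0")           # the line added above
    print(Version("0.17.1") in req.specifier)  # True  -> still installable
    print(Version("1.0.0") in req.specifier)   # False -> excluded by the pin
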
View File

@@ -19,7 +19,6 @@ def test_command_line_augmentation(name, command, outfile, sample_output_file):
import os
# desired_text = open(sample_output_file).read().strip()
# Run command and validate outputs.
result = run_command_and_get_result(command)

View File

@@ -14,7 +14,7 @@ def _cb(s):
 def get_nlp_dataset_columns(dataset):
-    schema = set(dataset.schema.names)
+    schema = set(dataset.column_names)
     if {"premise", "hypothesis", "label"} <= schema:
         input_columns = ("premise", "hypothesis")
         output_column = "label"

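The change above swaps the pyarrow-backed schema lookup (dataset.schema.names) for the plain column_names list that nlp datasets expose. A minimal standalone sketch, assuming the HuggingFace nlp package (later renamed datasets) and using SNLI only as an example of a premise/hypothesis/label dataset:

    import nlp

    dataset = nlp.load_dataset("snli", split="validation")

    # column_names is an ordinary Python list, so no pyarrow schema object is touched
    schema = set(dataset.column_names)
    if {"premise", "hypothesis", "label"} <= schema:
        input_columns = ("premise", "hypothesis")
        output_column = "label"
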
View File

@@ -117,7 +117,9 @@ class WordSwapMaskedLM(WordSwap):
         )
         current_inputs = self._encode_text(masked_text.text)
         current_ids = current_inputs["input_ids"].tolist()[0]
-        word_tokens = self._lm_tokenizer.encode(current_text.words[index], add_special_tokens=False)
+        word_tokens = self._lm_tokenizer.encode(
+            current_text.words[index], add_special_tokens=False
+        )
         try:
             # Need try-except b/c mask-token located past max_length might be truncated by tokenizer
@@ -126,7 +128,9 @@ class WordSwapMaskedLM(WordSwap):
             return []
         # List of indices of tokens that are part of the target word
-        target_ids_pos = list(range(masked_index, min(masked_index + len(word_tokens), self.max_length)))
+        target_ids_pos = list(
+            range(masked_index, min(masked_index + len(word_tokens), self.max_length))
+        )
         if not len(target_ids_pos):
             return []
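The two hunks above are formatting-only: each call is wrapped across lines without changing behavior. A hedged sketch of the logic they touch, using HuggingFace transformers directly; the model name, sentence, word, and max_length value are illustrative and not taken from the commit:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    max_length = 256

    masked_text = "the acting was [MASK] good"  # text with the target word masked out
    original_word = "surprisingly"              # candidate word being looked up

    current_ids = tokenizer.encode(masked_text, add_special_tokens=True)
    word_tokens = tokenizer.encode(original_word, add_special_tokens=False)

    try:
        # mask token located past max_length might be truncated by the tokenizer
        masked_index = current_ids.index(tokenizer.mask_token_id)
    except ValueError:
        masked_index = None

    if masked_index is not None:
        # positions that the word's sub-word tokens would occupy after the mask
        target_ids_pos = list(
            range(masked_index, min(masked_index + len(word_tokens), max_length))
        )
        print(masked_index, word_tokens, target_ids_pos)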