Mirror of https://github.com/QData/TextAttack.git (synced 2021-10-13 00:05:06 +03:00)
pyarrow<1.0 (temporary fix for nlp package)
@@ -9,6 +9,7 @@ nlp
 nltk
 numpy
 pandas>=1.0.1
+pyarrow<1.0
 scikit-learn
 scipy==1.4.1
 sentence_transformers>0.2.6
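The added pin works around an incompatibility between the `nlp` package and pyarrow 1.0 (the commit title calls it a temporary fix). A minimal sanity check of the installed version against that constraint, assuming only that pyarrow exposes a `__version__` string:

    # Verify the environment satisfies the pyarrow<1.0 pin before importing nlp.
    import pyarrow

    major = int(pyarrow.__version__.split(".")[0])
    if major >= 1:
        raise RuntimeError(
            f"nlp currently needs pyarrow<1.0, found {pyarrow.__version__}"
        )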
@@ -19,7 +19,6 @@ def test_command_line_augmentation(name, command, outfile, sample_output_file):
     import os

     # desired_text = open(sample_output_file).read().strip()

     # Run command and validate outputs.
     result = run_command_and_get_result(command)
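For context, `run_command_and_get_result` is a test helper that runs the given command line and hands the result back to the test. A hypothetical sketch of what such a helper might look like (the real implementation in the repository may differ):

    import shlex
    import subprocess

    def run_command_and_get_result(command):
        # Run the CLI command and capture stdout/stderr so the test can inspect them.
        return subprocess.run(shlex.split(command), capture_output=True)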
@@ -14,7 +14,7 @@ def _cb(s):


 def get_nlp_dataset_columns(dataset):
-    schema = set(dataset.schema.names)
+    schema = set(dataset.column_names)
     if {"premise", "hypothesis", "label"} <= schema:
         input_columns = ("premise", "hypothesis")
         output_column = "label"
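The change above replaces the old `dataset.schema.names` accessor with `dataset.column_names`, which newer releases of the `nlp` package expose directly on the dataset object. A short usage sketch (assuming `nlp` is installed and SNLI loads with its usual columns):

    import nlp  # the Hugging Face package later renamed to `datasets`

    dataset = nlp.load_dataset("snli", split="train")
    print(dataset.column_names)  # expected: ['premise', 'hypothesis', 'label']

    # get_nlp_dataset_columns would map this schema to the input columns
    # ("premise", "hypothesis") and the output column "label".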
@@ -117,7 +117,9 @@ class WordSwapMaskedLM(WordSwap):
         )
         current_inputs = self._encode_text(masked_text.text)
         current_ids = current_inputs["input_ids"].tolist()[0]
-        word_tokens = self._lm_tokenizer.encode(current_text.words[index], add_special_tokens=False)
+        word_tokens = self._lm_tokenizer.encode(
+            current_text.words[index], add_special_tokens=False
+        )

         try:
             # Need try-except b/c mask-token located past max_length might be truncated by tokenizer
@@ -126,7 +128,9 @@ class WordSwapMaskedLM(WordSwap):
             return []

         # List of indices of tokens that are part of the target word
-        target_ids_pos = list(range(masked_index, min(masked_index + len(word_tokens), self.max_length)))
+        target_ids_pos = list(
+            range(masked_index, min(masked_index + len(word_tokens), self.max_length))
+        )

         if not len(target_ids_pos):
             return []
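Both reformatted statements in this class implement the same lookup: the replaced word is tokenized without special tokens, and its token positions are taken starting at the mask index and clipped to the model's maximum length. A standalone sketch of that computation with a Hugging Face tokenizer (model name and example values are illustrative, not taken from the diff):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    # Token ids of the target word, without [CLS]/[SEP] markers.
    word_tokens = tokenizer.encode("playing", add_special_tokens=False)

    masked_index = 5   # position of the [MASK] token in the encoded text (example value)
    max_length = 256   # truncation length used when encoding (example value)

    # Positions of the tokens that make up the target word, clipped to max_length.
    target_ids_pos = list(
        range(masked_index, min(masked_index + len(word_tokens), max_length))
    )
    print(target_ids_pos)  # [5] for a single-token word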