1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

Merge pull request #509 from wenh06/master

Fix incorrect `__eq__` method of `AttackedText` in `textattack/shared/attacked_text.py`
This commit is contained in:
Jack Morris
2021-09-22 19:55:45 -04:00
committed by GitHub
3 changed files with 10 additions and 4 deletions

View File

@@ -46,7 +46,8 @@ class CSVLogger(Logger):
self._flushed = True
def close(self):
self.fout.close()
# self.fout.close()
super().close()
def __del__(self):
if not self._flushed:

View File

@@ -84,7 +84,7 @@ class WordCNNForClassification(nn.Module):
@classmethod
def from_pretrained(cls, name_or_path):
"""Load trained LSTM model by name or from path.
"""Load trained Word CNN model by name or from path.
Args:
name_or_path (:obj:`str`): Name of the model (e.g. "cnn-imdb") or model saved via :meth:`save_pretrained`.

View File

@@ -80,6 +80,8 @@ class AttackedText:
"""
if not (self.text == other.text):
return False
if len(self.attack_attrs) != len(other.attack_attrs):
return False
for key in self.attack_attrs:
if key not in other.attack_attrs:
return False
@@ -193,7 +195,10 @@ class AttackedText:
# Find all words until `i` in string.
look_after_index = 0
for word in pre_words:
look_after_index = lower_text.find(word.lower(), look_after_index)
look_after_index = lower_text.find(word.lower(), look_after_index) + len(
word
)
look_after_index -= len(self.words[i])
return look_after_index
def text_until_word_index(self, i):
@@ -217,7 +222,7 @@ class AttackedText:
w2 = other_attacked_text.words
for i in range(min(len(w1), len(w2))):
if w1[i] != w2[i]:
return w1
return w1[i]
return None
def first_word_diff_index(self, other_attacked_text):