mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
Merge pull request #509 from wenh06/master
Fix incorrect `__eq__` method of `AttackedText` in `textattack/shared/attacked_text.py`
This commit is contained in:
@@ -46,7 +46,8 @@ class CSVLogger(Logger):
|
||||
self._flushed = True
|
||||
|
||||
def close(self):
|
||||
self.fout.close()
|
||||
# self.fout.close()
|
||||
super().close()
|
||||
|
||||
def __del__(self):
|
||||
if not self._flushed:
|
||||
|
||||
@@ -84,7 +84,7 @@ class WordCNNForClassification(nn.Module):
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, name_or_path):
|
||||
"""Load trained LSTM model by name or from path.
|
||||
"""Load trained Word CNN model by name or from path.
|
||||
|
||||
Args:
|
||||
name_or_path (:obj:`str`): Name of the model (e.g. "cnn-imdb") or model saved via :meth:`save_pretrained`.
|
||||
|
||||
@@ -80,6 +80,8 @@ class AttackedText:
|
||||
"""
|
||||
if not (self.text == other.text):
|
||||
return False
|
||||
if len(self.attack_attrs) != len(other.attack_attrs):
|
||||
return False
|
||||
for key in self.attack_attrs:
|
||||
if key not in other.attack_attrs:
|
||||
return False
|
||||
@@ -193,7 +195,10 @@ class AttackedText:
|
||||
# Find all words until `i` in string.
|
||||
look_after_index = 0
|
||||
for word in pre_words:
|
||||
look_after_index = lower_text.find(word.lower(), look_after_index)
|
||||
look_after_index = lower_text.find(word.lower(), look_after_index) + len(
|
||||
word
|
||||
)
|
||||
look_after_index -= len(self.words[i])
|
||||
return look_after_index
|
||||
|
||||
def text_until_word_index(self, i):
|
||||
@@ -217,7 +222,7 @@ class AttackedText:
|
||||
w2 = other_attacked_text.words
|
||||
for i in range(min(len(w1), len(w2))):
|
||||
if w1[i] != w2[i]:
|
||||
return w1
|
||||
return w1[i]
|
||||
return None
|
||||
|
||||
def first_word_diff_index(self, other_attacked_text):
|
||||
|
||||
Reference in New Issue
Block a user