Mirror of https://github.com/QData/TextAttack.git

Commit: free unnecessary memory in TokenizedText
@@ -28,8 +28,8 @@ class AttackResult:
         # We don't want the TokenizedText `ids` sticking around clogging up
         # space on our devices. Delete them here, if they're still present,
         # because we won't need them anymore anyway.
-        self.original_result.tokenized_text.delete_tensors()
-        self.perturbed_result.tokenized_text.delete_tensors()
+        self.original_result.tokenized_text.free_memory()
+        self.perturbed_result.tokenized_text.free_memory()
 
     def original_text(self):
         """ Returns the text portion of `self.original_result`. Helper method.
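For background on why dropping these references matters: PyTorch's caching allocator can only recycle a tensor's GPU memory once nothing references it. A minimal standalone sketch (illustrative, not TextAttack code) that makes the effect visible:

import torch

def report(tag):
    # memory_allocated() counts bytes held by live CUDA tensors.
    if torch.cuda.is_available():
        print(tag, torch.cuda.memory_allocated(), 'bytes')

device = 'cuda' if torch.cuda.is_available() else 'cpu'
ids = torch.zeros(512, 768, device=device)
report('before')
ids = None  # drop the last reference; the allocator can now recycle the block
report('after')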
@@ -141,7 +141,6 @@ class GeneticAlgorithm(SearchMethod):
         for idx, result in enumerate(pop_results):
             pop[idx].result = pop_results[idx]
         pop = sorted(pop, key=lambda x: -x.result.score)
-        #print('\t\t', i, ' -- ', float(pop[0].result.score))
 
         pop_scores = torch.Tensor([r.score for r in pop_results])
         logits = ((-pop_scores) / self.temp).exp()
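The logits line computes temperature-scaled selection weights, exp(-score / temp): lower temperatures concentrate the weights sharply on the smallest scores, higher temperatures flatten them toward uniform. A small sketch of that behavior with made-up scores:

import torch

pop_scores = torch.tensor([0.1, 0.5, 2.0])
for temp in (0.1, 1.0, 10.0):
    logits = ((-pop_scores) / temp).exp()
    probs = logits / logits.sum()  # normalize into selection probabilities
    print(temp, [round(p, 3) for p in probs.tolist()])
# At temp=0.1 nearly all mass sits on the first score;
# at temp=10 the distribution is close to uniform.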
@@ -115,12 +115,12 @@ class Checkpoint:
         logger.info('Saving checkpoint under "{}" at {} after {} attacks.'.format(path, self.datetime, self.results_count))
         print('=' * 125 + '\n')
         with open(path, 'wb') as f:
-            pickle.dump(self, f)
+            pickle.dump(self, f, protocol=pickle.HIGHEST_PROTOCOL)
 
     @classmethod
     def load(cls, path):
         with open(path, 'rb') as f:
             checkpoint = pickle.load(f)
         assert isinstance(checkpoint, Checkpoint)
 
         return checkpoint
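Note that pickle.load() takes no protocol argument: the protocol is recorded in the pickle stream itself, so only the writing side needs to specify it, which is why the hunk above only changes the dump call. A quick round-trip check (hypothetical file path):

import pickle

data = {'results_count': 42}
with open('/tmp/checkpoint_demo.pkl', 'wb') as f:
    # HIGHEST_PROTOCOL is the most compact, fastest format this
    # interpreter supports.
    pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)

with open('/tmp/checkpoint_demo.pkl', 'rb') as f:
    restored = pickle.load(f)  # protocol is auto-detected from the stream

assert restored == data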
@@ -37,6 +37,8 @@ def attack_from_queue(args, in_queue, out_queue):
         results_gen = attack.attack_dataset([(output, text)], num_examples=1)
         result = next(results_gen)
         out_queue.put(result)
+        del output
+        del text
     except Exception as e:
         out_queue.put(e)
         exit()
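The surrounding worker pattern, pushing any exception onto the out queue so the parent process can re-raise it (see the next hunk), can be sketched in isolation; everything below is illustrative, not TextAttack code:

import multiprocessing as mp

def worker(in_queue, out_queue):
    while True:
        item = in_queue.get()
        try:
            out_queue.put(item * 2)  # stand-in for running one attack
        except Exception as e:
            # Ship the exception to the parent instead of dying silently.
            out_queue.put(e)
            return

if __name__ == '__main__':
    in_q, out_q = mp.Queue(), mp.Queue()
    mp.Process(target=worker, args=(in_q, out_q), daemon=True).start()
    in_q.put(21)
    result = out_q.get(block=True)
    if isinstance(result, Exception):
        raise result  # surfaces the worker's failure in the parent
    print(result)  # 42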
@@ -103,6 +105,7 @@ def run(args):
     pbar = tqdm.tqdm(total=num_examples, smoothing=0)
     while num_results < num_examples:
         result = out_queue.get(block=True)
-
+        if isinstance(result, Exception):
+            raise result
         attack_log_manager.log_result(result)
@@ -119,9 +122,12 @@ def run(args):
             in_queue.put((label, text))
 
         if args.checkpoint_interval and num_results % args.checkpoint_interval == 0:
-            attack_log_manager.flush()
+            checkpoint = textattack.shared.Checkpoint(args, attack_log_manager)
+            checkpoint.save()
+            attack_log_manager.flush()
         else:
             if num_results > 0 and num_results % 50 == 0:
                 attack_log_manager.flush()
 
     pbar.close()
     print()
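The gating condition is plain modular arithmetic: save a checkpoint every checkpoint_interval results when the flag is set, otherwise fall back to flushing logs every 50 results. A sketch of just the condition (should_checkpoint is a made-up helper):

def should_checkpoint(num_results, checkpoint_interval):
    # A checkpoint_interval of 0 or None disables checkpointing.
    return bool(checkpoint_interval) and num_results % checkpoint_interval == 0

assert should_checkpoint(100, 25)
assert not should_checkpoint(101, 25)
assert not should_checkpoint(100, 0)
assert not should_checkpoint(100, None)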
@@ -1,5 +1,4 @@
 import torch
-from copy import deepcopy
 from .utils import get_device, words_from_text
 
 class TokenizedText:
@@ -21,13 +20,16 @@ class TokenizedText:
     def __init__(self, text, tokenizer, attack_attrs=dict()):
         text = text.strip()
         self.tokenizer = tokenizer
-        ids = tokenizer.encode(text)
-        if not isinstance(ids, tuple):
-            # Some tokenizers may tokenize text to a single vector.
-            # In this case, wrap the vector in a tuple to mirror the
-            # format of other tokenizers.
-            ids = (ids,)
-        self.ids = ids
+        if tokenizer:
+            ids = tokenizer.encode(text)
+            if not isinstance(ids, tuple):
+                # Some tokenizers may tokenize text to a single vector.
+                # In this case, wrap the vector in a tuple to mirror the
+                # format of other tokenizers.
+                ids = (ids,)
+            self.ids = ids
+        else:
+            self.ids = None
         self.words = words_from_text(text, words_to_ignore=[TokenizedText.SPLIT_TOKEN])
         self.text = text
         self.attack_attrs = attack_attrs
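The new guard makes the tokenizer optional, which is what lets free_memory() null it out later (next hunk) without breaking display-only objects. A standalone sketch of the same logic (build_ids and ToyTokenizer are made up for illustration):

def build_ids(text, tokenizer):
    # Mirrors the guard above: skip encoding when no tokenizer is given.
    if tokenizer:
        ids = tokenizer.encode(text)
        if not isinstance(ids, tuple):
            ids = (ids,)  # normalize single-vector tokenizers to a tuple
        return ids
    return None

class ToyTokenizer:
    def encode(self, text):
        return [len(word) for word in text.split()]  # stand-in for real ids

print(build_ids('an example sentence', ToyTokenizer()))  # ([2, 7, 8],)
print(build_ids('an example sentence', None))            # None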
@@ -39,11 +41,15 @@ class TokenizedText:
     def __hash__(self):
         return hash(self.text)
 
-    def delete_tensors(self):
-        """ Delete tensors to clear up GPU space. Only should be called
-            once the TokenizedText is only needed to display.
+    def free_memory(self):
+        """ Delete items that take up memory.
+            Deletes tensors to clear up GPU space.
+            Should only be called once the TokenizedText is needed only for display.
         """
         self.ids = None
+        self.tokenizer = None
         if 'last_transformation' in self.attack_attrs:
             del self.attack_attrs['last_transformation']
+        # Iterate over a snapshot of the keys: deleting from a dict while
+        # iterating it directly raises a RuntimeError.
+        for key in list(self.attack_attrs.keys()):
+            if isinstance(self.attack_attrs[key], torch.Tensor):
+                del self.attack_attrs[key]
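The list() snapshot in the final loop matters: in Python 3, removing keys from a dict while iterating it directly fails. A quick demonstration with made-up attack_attrs:

import torch

attack_attrs = {'modified_word_index': 3, 'prev_ids': torch.zeros(4)}

# Without list(), this loop would raise
# "RuntimeError: dictionary changed size during iteration".
for key in list(attack_attrs.keys()):
    if isinstance(attack_attrs[key], torch.Tensor):
        del attack_attrs[key]

print(attack_attrs)  # {'modified_word_index': 3}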