1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

free unnecessary memory in TokenizedText

This commit is contained in:
Jin Yong Yoo
2020-05-24 09:24:37 -04:00
parent 9152ba8d82
commit a793bd8aa3
5 changed files with 28 additions and 17 deletions

View File

@@ -28,8 +28,8 @@ class AttackResult:
# We don't want the TokenizedText `ids` sticking around clogging up
# space on our devices. Delete them here, if they're still present,
# because we won't need them anymore anyway.
self.original_result.tokenized_text.delete_tensors()
self.perturbed_result.tokenized_text.delete_tensors()
self.original_result.tokenized_text.free_memory()
self.perturbed_result.tokenized_text.free_memory()
def original_text(self):
""" Returns the text portion of `self.original_result`. Helper method.

View File

@@ -141,7 +141,6 @@ class GeneticAlgorithm(SearchMethod):
for idx, result in enumerate(pop_results):
pop[idx].result = pop_results[idx]
pop = sorted(pop, key=lambda x: -x.result.score)
#print('\t\t', i, ' -- ', float(pop[0].result.score))
pop_scores = torch.Tensor([r.score for r in pop_results])
logits = ((-pop_scores) / self.temp).exp()

View File

@@ -115,12 +115,12 @@ class Checkpoint:
logger.info('Saving checkpoint under "{}" at {} after {} attacks.'.format(path, self.datetime, self.results_count))
print('=' * 125 + '\n')
with open(path, 'wb') as f:
pickle.dump(self, f)
pickle.dump(self, f, protocol=pickle.HIGHEST_PROTOCOL)
@classmethod
def load(self, path):
with open(path, 'rb') as f:
checkpoint = pickle.load(f)
checkpoint = pickle.load(f, protocol=pickle.HIGHEST_PROTOCOL)
assert isinstance(checkpoint, Checkpoint)
return checkpoint

View File

@@ -37,6 +37,8 @@ def attack_from_queue(args, in_queue, out_queue):
results_gen = attack.attack_dataset([(output, text)], num_examples=1)
result = next(results_gen)
out_queue.put(result)
del output
del text
except Exception as e:
out_queue.put(e)
exit()
@@ -103,6 +105,7 @@ def run(args):
pbar = tqdm.tqdm(total=num_examples, smoothing=0)
while num_results < num_examples:
result = out_queue.get(block=True)
if isinstance(result, Exception):
raise result
attack_log_manager.log_result(result)
@@ -119,9 +122,12 @@ def run(args):
in_queue.put((label, text))
if args.checkpoint_interval and num_results % args.checkpoint_interval == 0:
attack_log_manager.flush()
checkpoint = textattack.shared.Checkpoint(args, attack_log_manager)
checkpoint.save()
attack_log_manager.flush()
else:
if num_results > 0 and num_results % 50 == 0:
attack_log_manager.flush()
pbar.close()
print()

View File

@@ -1,5 +1,4 @@
import torch
from copy import deepcopy
from .utils import get_device, words_from_text
class TokenizedText:
@@ -21,13 +20,16 @@ class TokenizedText:
def __init__(self, text, tokenizer, attack_attrs=dict()):
text = text.strip()
self.tokenizer = tokenizer
ids = tokenizer.encode(text)
if not isinstance(ids, tuple):
# Some tokenizers may tokenize text to a single vector.
# In this case, wrap the vector in a tuple to mirror the
# format of other tokenizers.
ids = (ids,)
self.ids = ids
if tokenizer:
ids = tokenizer.encode(text)
if not isinstance(ids, tuple):
# Some tokenizers may tokenize text to a single vector.
# In this case, wrap the vector in a tuple to mirror the
# format of other tokenizers.
ids = (ids,)
self.ids = ids
else:
self.ids = None
self.words = words_from_text(text, words_to_ignore=[TokenizedText.SPLIT_TOKEN])
self.text = text
self.attack_attrs = attack_attrs
@@ -39,11 +41,15 @@ class TokenizedText:
def __hash__(self):
return hash(self.text)
def delete_tensors(self):
""" Delete tensors to clear up GPU space. Only should be called
once the TokenizedText is only needed to display.
def free_memory(self):
""" Delete items that take up memory.
Delete tensors to clear up GPU space.
Only should be called once the TokenizedText is only needed to display.
"""
self.ids = None
self.tokenizer = None
if 'last_transformation' in self.attack_attrs:
del self.attack_attrs['last_transformation']
for key in self.attack_attrs:
if isinstance(self.attack_attrs[key], torch.Tensor):
del self.attack_attrs[key]