Author: bwt09
Date: 2022-06-06 12:13:48 -07:00
Parent: 1ebaebfbd2
Commit: 595103af34
2 changed files with 10 additions and 15 deletions


@@ -196,13 +196,13 @@ class EntityTupleSearcher:
             logprobs = torch.log_softmax(mask_logits_total, dim=-1)
             logprobs, pred_ids = torch.sort(logprobs, descending=True)
-            for logprob, pred_id in zip(logprobs[:5 * n], pred_ids[:5 * n]):
-                sum_logprob_upd = sum(cur_logprobs) + logprob.item()
+            for logprob, pred_id in zip(logprobs, pred_ids):
+                min_logprob_upd = min(cur_logprobs + [logprob.item()])
                 if len(collected_ent_heap) == n and \
-                        sum_logprob_upd / 2. < collected_ent_heap[0][0]:
+                        min_logprob_upd < collected_ent_heap[0][0]:
                     break
-                if sum_logprob_upd / 2. < logprob_threashold:
+                if min_logprob_upd < logprob_threashold:
                     break
                 if not any([ch.isalpha() for ch in
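This hunk changes how the entity search scores a partially filled tuple: the old code averaged the accumulated log-probabilities (sum(cur_logprobs) plus the new candidate, divided by 2, i.e. treating the tuple as an entity pair) and only examined the top 5 * n predictions, while the new code scores a tuple by its minimum per-slot log-probability and walks the full sorted prediction list, breaking as soon as even the best remaining candidate cannot beat the weakest tuple in the size-n heap or falls below the threshold. A minimal, runnable sketch of that min-based pruning pattern follows; the function name collect_candidates, the heap layout (score, tuple), and the example numbers are assumptions for illustration, not code from the repository.

import heapq

def collect_candidates(sorted_logprobs, cur_logprobs, heap, n, logprob_threshold):
    # sorted_logprobs: candidate log-probs in descending order.
    # heap: min-heap of (score, tuple) holding the best tuples found so far.
    kept = []
    for logprob in sorted_logprobs:
        # Score of the extended tuple = its weakest slot.
        min_logprob_upd = min(cur_logprobs + [logprob])

        # Heap is full and even this best remaining candidate scores below
        # the worst kept tuple, so no later candidate can qualify either.
        if len(heap) == n and min_logprob_upd < heap[0][0]:
            break
        # Same early exit once candidates drop under the absolute threshold.
        if min_logprob_upd < logprob_threshold:
            break
        kept.append((min_logprob_upd, logprob))
    return kept

# Example: with a full heap whose weakest score is -1.0, the scan stops as soon
# as the next candidate's min score falls below -1.0.
heap = [(-1.0, ('paris', 'france')), (-0.5, ('rome', 'italy'))]
heapq.heapify(heap)
print(collect_candidates([-0.2, -0.8, -1.4, -2.0], [-0.6], heap, 2, -3.0))
# -> [(-0.6, -0.2), (-0.8, -0.8)]   (break triggered at -1.4 < -1.0)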


@@ -8,11 +8,9 @@ from data_utils.data_utils import stopwords, get_n_ents, get_sent, find_sublist
 class LanguageModelWrapper:
     def __init__(self, model_name):
         self._model_name = model_name
-        self._max_batch_size = 64
         self._tokenizer = AutoTokenizer.from_pretrained(model_name)
         self._model = AutoModelForMaskedLM.from_pretrained(model_name)
         self._encoder = getattr(self._model, model_name.split('-')[0])
         self._model.eval()
         self._model.to('cuda')
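One detail of the constructor kept above: getattr(self._model, model_name.split('-')[0]) works because HuggingFace masked-LM classes expose their base transformer under an attribute named after the model family (RobertaForMaskedLM.roberta, BertForMaskedLM.bert, and so on). A small sketch of that lookup, using 'roberta-large' purely as an example checkpoint:

from transformers import AutoModelForMaskedLM

model_name = 'roberta-large'          # example checkpoint, not fixed by the repo
model = AutoModelForMaskedLM.from_pretrained(model_name)

# 'roberta-large'.split('-')[0] == 'roberta'; the MLM wrapper stores its base
# transformer under that attribute (model.roberta here, model.bert for BERT).
encoder = getattr(model, model_name.split('-')[0])
print(type(encoder).__name__)         # e.g. RobertaModel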
@@ -42,25 +40,22 @@ class LanguageModelWrapper:
         mask_positions = []
         for mask_span in mask_spans:
-            for i in range(mask_span[0], mask_span[1]):
-                mask_positions.append(i)
+            mask_positions.extend([pos for pos in range(*mask_span)])
         masked_inputs = self.tokenizer(
             [sent] * len(mask_positions), return_tensors='pt').to('cuda')
         label_token_ids = []
-        for i in range(len(mask_positions)):
-            label_token_ids.append(
-                masked_inputs['input_ids'][i][mask_positions[i]])
-            for pos in mask_positions[i:]:
-                masked_inputs['input_ids'][i][pos] = \
-                    self.tokenizer.mask_token_id
+        for i, pos in enumerate(mask_positions):
+            label_token_ids.append(masked_inputs['input_ids'][i][pos])
+            masked_inputs['input_ids'][i][mask_positions[i:]] = \
+                self.tokenizer.mask_token_id
         with torch.no_grad():
             logits = self.model(**masked_inputs).logits
             logprobs = torch.log_softmax(logits, dim=-1)
             mask_logprobs = logprobs[
-                torch.arange(0, len(mask_positions)), mask_positions,
+                torch.arange(len(mask_positions)), mask_positions,
                 label_token_ids].tolist()
         torch.cuda.empty_cache()
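The refactor above keeps the same left-to-right masking scheme but replaces the inner Python loop with a single tensor assignment: row i of the batch records the gold token id at mask_positions[i], then masks that position together with every later position in mask_positions, so each token is predicted with earlier entity tokens visible and later ones still hidden; the final advanced-indexing expression gathers each gold token's log-probability at its own position. Below is a self-contained, CPU-only sketch of that scheme, assuming a bert-base-uncased checkpoint; the sentence and mask positions are illustrative, not values from this repository.

import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
model = AutoModelForMaskedLM.from_pretrained('bert-base-uncased').eval()

sent = 'paris is the capital of france .'
mask_positions = [1, 6]               # token positions of 'paris' and 'france'

# One copy of the sentence per scored position.
inputs = tokenizer([sent] * len(mask_positions), return_tensors='pt')

label_token_ids = []
for i, pos in enumerate(mask_positions):
    # Remember the gold token at this position, then mask it and every later
    # scored position in row i, so earlier entity tokens stay visible.
    label_token_ids.append(inputs['input_ids'][i][pos].item())
    inputs['input_ids'][i][mask_positions[i:]] = tokenizer.mask_token_id

with torch.no_grad():
    logprobs = torch.log_softmax(model(**inputs).logits, dim=-1)

# Row i, position mask_positions[i], gold token id -> one log-prob per position.
mask_logprobs = logprobs[
    torch.arange(len(mask_positions)), mask_positions, label_token_ids].tolist()
print(mask_logprobs)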