Mirror of https://github.com/tanyuqian/knowledge-harvest-from-lms.git, synced 2023-06-02 01:35:42 +03:00

Commit message: updated.
@@ -196,13 +196,13 @@ class EntityTupleSearcher:
         logprobs = torch.log_softmax(mask_logits_total, dim=-1)
         logprobs, pred_ids = torch.sort(logprobs, descending=True)
 
-        for logprob, pred_id in zip(logprobs[:5 * n], pred_ids[:5 * n]):
-            sum_logprob_upd = sum(cur_logprobs) + logprob.item()
+        for logprob, pred_id in zip(logprobs, pred_ids):
+            min_logprob_upd = min(cur_logprobs + [logprob.item()])
             if len(collected_ent_heap) == n and \
-                    sum_logprob_upd / 2. < collected_ent_heap[0][0]:
+                    min_logprob_upd < collected_ent_heap[0][0]:
                 break
 
-            if sum_logprob_upd / 2. < logprob_threashold:
+            if min_logprob_upd < logprob_threashold:
                 break
 
             if not any([ch.isalpha() for ch in
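This hunk changes how EntityTupleSearcher prunes candidates while collecting the top n entity tuples: the score of a partially built tuple moves from the halved sum of its log-probs (an average over the tuple's slots) to the minimum log-prob, so a tuple is only as strong as its weakest entity, and the scan now runs over all sorted predictions rather than just the top 5·n. A minimal sketch of the new rule, with an illustrative function name, plain floats in place of the repo's tensors, and a min-heap of (score, id) pairs standing in for collected_ent_heap:

    import heapq

    def prune_candidates(sorted_logprobs, cur_logprobs, heap, n, threshold):
        # Sketch only: `sorted_logprobs` must be in descending order;
        # `heap` is a min-heap of (score, candidate_id) pairs kept so far.
        for cand_id, logprob in enumerate(sorted_logprobs):
            # New rule: a candidate scores as its weakest component.
            min_logprob_upd = min(cur_logprobs + [logprob])

            # Later candidates only score worse, so once the full heap's
            # worst entry beats this score, nothing left can qualify.
            if len(heap) == n and min_logprob_upd < heap[0][0]:
                break

            # Absolute quality floor.
            if min_logprob_upd < threshold:
                break

            # Keep the candidate, evicting the worst if the heap is full.
            if len(heap) == n:
                heapq.heappushpop(heap, (min_logprob_upd, cand_id))
            else:
                heapq.heappush(heap, (min_logprob_upd, cand_id))

Called as prune_candidates([-0.2, -1.0, -3.0], [-0.5], [], n=2, threshold=-2.0), the heap ends up as [(-1.0, 1), (-0.5, 0)] and the scan stops at the third candidate via the heap bound (the threshold break would also have caught it).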
@@ -8,11 +8,9 @@ from data_utils.data_utils import stopwords, get_n_ents, get_sent, find_sublist
 class LanguageModelWrapper:
     def __init__(self, model_name):
         self._model_name = model_name
-        self._max_batch_size = 64
 
         self._tokenizer = AutoTokenizer.from_pretrained(model_name)
         self._model = AutoModelForMaskedLM.from_pretrained(model_name)
-        self._encoder = getattr(self._model, model_name.split('-')[0])
 
         self._model.eval()
         self._model.to('cuda')
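The constructor drops two members: _max_batch_size, and _encoder, which used the checkpoint name's prefix to fetch the bare encoder submodule (e.g. model.roberta for a roberta-* checkpoint). What remains is standard Hugging Face loading; a standalone equivalent, with roberta-base as an arbitrary example checkpoint and the hard-coded 'cuda' hedged behind an availability check:

    import torch
    from transformers import AutoModelForMaskedLM, AutoTokenizer

    model_name = 'roberta-base'  # example; any masked-LM checkpoint works
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForMaskedLM.from_pretrained(model_name)

    model.eval()  # the wrapper only ever runs inference
    if torch.cuda.is_available():  # the repo assumes a GPU unconditionally
        model.to('cuda')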
@@ -42,25 +40,22 @@ class LanguageModelWrapper:
 
         mask_positions = []
         for mask_span in mask_spans:
-            for i in range(mask_span[0], mask_span[1]):
-                mask_positions.append(i)
+            mask_positions.extend([pos for pos in range(*mask_span)])
 
         masked_inputs = self.tokenizer(
             [sent] * len(mask_positions), return_tensors='pt').to('cuda')
         label_token_ids = []
-        for i in range(len(mask_positions)):
-            label_token_ids.append(
-                masked_inputs['input_ids'][i][mask_positions[i]])
-            for pos in mask_positions[i:]:
-                masked_inputs['input_ids'][i][pos] = \
-                    self.tokenizer.mask_token_id
+        for i, pos in enumerate(mask_positions):
+            label_token_ids.append(masked_inputs['input_ids'][i][pos])
+            masked_inputs['input_ids'][i][mask_positions[i:]] = \
+                self.tokenizer.mask_token_id
 
         with torch.no_grad():
             logits = self.model(**masked_inputs).logits
             logprobs = torch.log_softmax(logits, dim=-1)
 
         mask_logprobs = logprobs[
-            torch.arange(0, len(mask_positions)), mask_positions,
+            torch.arange(len(mask_positions)), mask_positions,
             label_token_ids].tolist()
 
         torch.cuda.empty_cache()
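The rewrite keeps the scoring scheme intact while tightening the code: the nested position loop collapses into a single extend over range(*mask_span), per-copy masking becomes one tensor-sliced assignment, and torch.arange loses its redundant 0 start. The scheme itself: copy i of the sentence masks target position i and every later target, so token i is predicted from the non-target context plus the targets already revealed to its left. A standalone sketch (checkpoint, sentence, and token positions are illustrative, not from the repo):

    import torch
    from transformers import AutoModelForMaskedLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('roberta-base')  # example
    model = AutoModelForMaskedLM.from_pretrained('roberta-base').eval()

    sent = 'Paris is the capital of France.'
    mask_positions = [1, 6]  # illustrative target token positions

    # One copy of the tokenized sentence per target position.
    inputs = tokenizer([sent] * len(mask_positions), return_tensors='pt')

    label_token_ids = []
    for i, pos in enumerate(mask_positions):
        # Remember the true token, then mask this and all later targets:
        # copy i predicts target i before the targets to its right.
        label_token_ids.append(inputs['input_ids'][i][pos].item())
        inputs['input_ids'][i][mask_positions[i:]] = tokenizer.mask_token_id

    with torch.no_grad():
        logprobs = torch.log_softmax(model(**inputs).logits, dim=-1)

    # Log-prob of each true token at its own (row, position) slot.
    mask_logprobs = logprobs[
        torch.arange(len(mask_positions)), mask_positions,
        torch.tensor(label_token_ids)].tolist()
    print(mask_logprobs)  # one float per target position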