mirror of
https://github.com/tanyuqian/knowledge-harvest-from-lms.git
synced 2023-06-02 01:35:42 +03:00
82 lines
1.8 KiB
Python
82 lines
1.8 KiB
Python
import re
|
|
from nltk.corpus import stopwords
|
|
|
|
|
|
stopwords = stopwords.words('english')
|
|
stopwords.extend([
|
|
'everything', 'everybody', 'everyone',
|
|
'anything', 'anybody', 'anyone',
|
|
'something', 'somebody', 'someone',
|
|
'nothing', 'nobody',
|
|
'one', 'neither', 'either', 'many',
|
|
'us', 'first', 'second', 'next',
|
|
'following', 'last', 'new', 'main', 'also'])
|
|
|
|
|
|
def is_valid_prompt(prompt):
|
|
for i in range(1, len(prompt)):
|
|
if prompt[i:].startswith('<ENT') and prompt[i - 1] not in [' ', '\"']:
|
|
return False
|
|
|
|
return True
|
|
|
|
|
|
def get_n_ents(prompt):
|
|
n = 0
|
|
while f'<ENT{n}>' in prompt:
|
|
n += 1
|
|
return n
|
|
|
|
|
|
def get_sent(prompt, ent_tuple):
|
|
sent = prompt
|
|
for idx, ent in enumerate(ent_tuple):
|
|
sent = sent.replace(f'<ENT{idx}>', ent)
|
|
|
|
return sent
|
|
|
|
|
|
def get_mask_place(ent_idx, n_masks, prompt):
|
|
mask_idx = 0
|
|
for t in re.findall(r'<ENT[0-9]+>', prompt):
|
|
t_idx = int(t[len('<ENT'):-1])
|
|
if t_idx != ent_idx:
|
|
mask_idx += n_masks[t_idx]
|
|
else:
|
|
break
|
|
|
|
return mask_idx
|
|
|
|
|
|
def get_n_masks(t, n_ents, max_ent_subwords):
|
|
n_masks = []
|
|
for i in range(n_ents):
|
|
n_masks.append(t % max_ent_subwords + 1)
|
|
t //= max_ent_subwords
|
|
|
|
return n_masks
|
|
|
|
|
|
def get_masked_prompt(prompt, n_masks, mask_token):
|
|
input_text = prompt
|
|
for ent_idx, n_mask in enumerate(n_masks):
|
|
input_text = input_text.replace(f'<ENT{ent_idx}>', mask_token * n_mask)
|
|
|
|
return input_text
|
|
|
|
|
|
def fix_prompt_style(prompt):
|
|
prompt = prompt.strip(' .')
|
|
if prompt[0].isalpha():
|
|
prompt = prompt[0].upper() + prompt[1:]
|
|
|
|
return prompt + ' .'
|
|
|
|
|
|
def find_sublist(a, b):
|
|
for l in range(len(a)):
|
|
if a[l:l+len(b)] == b:
|
|
return l
|
|
|
|
return None
|