Repository: https://github.com/tanyuqian/knowledge-harvest-from-lms.git (mirror, synced 2023-06-02)
Commit message: updated.
main.py
@@ -47,7 +47,7 @@ def main(rel_set='conceptnet',
         seed_ent_tuples=info['seed_ent_tuples'])
     knowledge_harvester.set_prompts(
         prompts=info['init_prompts'] if use_init_prompts
-        else info['init_prompts'] + info['prompts'])
+        else list(set(info['init_prompts'] + info['prompts'])))

     knowledge_harvester.update_prompts()
     json.dump(knowledge_harvester.weighted_prompts, open(
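Note on the change above: wrapping the concatenation in list(set(...)) drops prompt strings that appear in both init_prompts and prompts, but a Python set does not preserve the original ordering. A minimal sketch of the behavior (the prompt strings below are made up for illustration, not taken from the repo):

    init_prompts = ['<ENT0> is a kind of <ENT1> .', '<ENT0> is a <ENT1> .']
    prompts = ['<ENT0> is a <ENT1> .', '<ENT0> can be found in <ENT1> .']

    # list(set(...)) removes the duplicate '<ENT0> is a <ENT1> .' but may reorder the result
    deduped = list(set(init_prompts + prompts))

    # an order-preserving alternative, if ordering mattered, would be:
    deduped_ordered = list(dict.fromkeys(init_prompts + prompts))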
@@ -54,10 +54,6 @@ class EntityTupleSearcher:
         if cur_ent_idx == n_ents:
             pred = [min(cur_logprobs), cur_ent_tuple]

-            # filter tuples with only very short entities
-            if sum([len(ent) for ent in cur_ent_tuple]) == 3 * n_ents:
-                return
-
             for ent in cur_ent_tuple:
                 for word in ent.split():
                     if repeat_cnt.get(word, 0) + 1 > max_word_repeat:
@@ -139,8 +135,8 @@ class EntityTupleSearcher:
         if any([word in stopwords for word in pred_ent.split()]):
             return

-        # filter entity with less than 3 characters
-        if len(pred_ent.replace(' ', '')) <= 2:
+        # filter entity with less than 4 characters
+        if len(pred_ent.replace(' ', '')) <= 3:
             return

         # filter entity with single-character words
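Taken together with the previous hunk, the tuple-level "very short entities" filter is removed and the per-entity length threshold is tightened from fewer than 3 to fewer than 4 non-space characters. A small illustration of the new per-entity condition (the example entities are hypothetical):

    def too_short(pred_ent):
        # new rule from the diff: reject entities with at most 3 non-space characters
        return len(pred_ent.replace(' ', '')) <= 3

    print(too_short('cat'))    # True  -> filtered out (the old <= 2 rule would have kept it)
    print(too_short('lion'))   # False -> kept
    print(too_short('a cat'))  # False -> 'acat' has 4 non-space characters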
@@ -170,7 +166,7 @@ class EntityTupleSearcher:
         for raw_prompt, weight in weighted_prompts:
             prompt = raw_prompt.replace(
                 f'<ENT{ent_idx}>',
-                self._model.tokenizer.decode(cur_token_ids) +
+                self._model.tokenizer.decode(cur_token_ids).lower() +
                 self._model.tokenizer.mask_token * (
                     n_masks[ent_idx] - len(cur_token_ids)))

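The one-line change above lowercases the already-decoded part of an entity before it is spliced back into the prompt, so partially predicted entities re-enter the masked prompt in lowercase. A rough, self-contained sketch of how such a masked prompt can be assembled with a Hugging Face tokenizer; the checkpoint name and example values are assumptions for illustration, not taken from the repo:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('roberta-base')  # assumed checkpoint

    raw_prompt = '<ENT0> is located in <ENT1> .'
    cur_token_ids = tokenizer.encode(' Paris', add_special_tokens=False)  # tokens predicted so far (example)
    n_masks = 3  # mask slots reserved for <ENT1> in this sketch

    prompt = raw_prompt.replace(
        '<ENT1>',
        tokenizer.decode(cur_token_ids).lower() +                # lowercase the partial entity
        tokenizer.mask_token * (n_masks - len(cur_token_ids)))   # pad remaining slots with <mask>
    print(prompt)  # roughly: '<ENT0> is located in  paris<mask><mask> .'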
@@ -45,7 +45,25 @@ class KnowledgeHarvester:
         scores = []
         for ent_tuple in self._seed_ent_tuples:
             ent_tuple = [ent.replace('_', ' ') for ent in ent_tuple]
-            scores.append(self.score(prompt=prompt, ent_tuple=ent_tuple))
+            score = self.score(prompt=prompt, ent_tuple=ent_tuple)
+
+            neg_scores = []
+            for ent_idx in range(len(ent_tuple)):
+                for ent_tuple1 in self._seed_ent_tuples:
+                    if ent_tuple1[ent_idx] == ent_tuple[ent_idx]:
+                        continue
+
+                    ent_tuple_neg = \
+                        ent_tuple[:ent_idx] + \
+                        [ent_tuple1[ent_idx]] + \
+                        ent_tuple[ent_idx + 1:]
+
+                    neg_scores.append(self.score(
+                        prompt=prompt, ent_tuple=ent_tuple_neg))
+
+            neg_score = sum(neg_scores) / len(neg_scores)
+            scores.append(score - neg_score)

         self._weighted_prompts[i][1] = \
             sum(scores) / len(scores) / self._prompt_temp
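The reworked scoring above is contrastive: each seed tuple's score under a prompt is offset by the average score of negative tuples, formed by swapping one entity for the corresponding entity of a different seed tuple. A minimal, self-contained sketch of that negative-sampling loop, with a stubbed score function standing in for the harvester's masked-LM scorer (the seed tuples and prompt are made up):

    def score(prompt, ent_tuple):
        # stand-in for KnowledgeHarvester.score, which queries a masked language model
        return float(len(' '.join(ent_tuple)))

    seed_ent_tuples = [['fork', 'kitchen'], ['shirt', 'closet'], ['car', 'garage']]
    prompt = '<ENT0> is usually located in <ENT1> .'

    scores = []
    for ent_tuple in seed_ent_tuples:
        pos_score = score(prompt=prompt, ent_tuple=ent_tuple)

        neg_scores = []
        for ent_idx in range(len(ent_tuple)):
            for other in seed_ent_tuples:
                if other[ent_idx] == ent_tuple[ent_idx]:
                    continue
                # swap in one entity from another seed tuple to build a negative
                ent_tuple_neg = (ent_tuple[:ent_idx]
                                 + [other[ent_idx]]
                                 + ent_tuple[ent_idx + 1:])
                neg_scores.append(score(prompt=prompt, ent_tuple=ent_tuple_neg))

        scores.append(pos_score - sum(neg_scores) / len(neg_scores))

    prompt_weight = sum(scores) / len(scores)  # the harvester further divides by a prompt temperature
    print(prompt_weight)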