This commit is contained in:
bwt09
2022-06-06 14:54:08 -07:00
parent 7a935ca451
commit 3bf52451d4
3 changed files with 23 additions and 9 deletions

View File

@@ -47,7 +47,7 @@ def main(rel_set='conceptnet',
seed_ent_tuples=info['seed_ent_tuples'])
knowledge_harvester.set_prompts(
prompts=info['init_prompts'] if use_init_prompts
else info['init_prompts'] + info['prompts'])
else list(set(info['init_prompts'] + info['prompts'])))
knowledge_harvester.update_prompts()
json.dump(knowledge_harvester.weighted_prompts, open(

View File

@@ -54,10 +54,6 @@ class EntityTupleSearcher:
if cur_ent_idx == n_ents:
pred = [min(cur_logprobs), cur_ent_tuple]
# filter tuples with only very short entities
if sum([len(ent) for ent in cur_ent_tuple]) == 3 * n_ents:
return
for ent in cur_ent_tuple:
for word in ent.split():
if repeat_cnt.get(word, 0) + 1 > max_word_repeat:
@@ -139,8 +135,8 @@ class EntityTupleSearcher:
if any([word in stopwords for word in pred_ent.split()]):
return
# filter entity with less than 3 characters
if len(pred_ent.replace(' ', '')) <= 2:
# filter entity with less than 4 characters
if len(pred_ent.replace(' ', '')) <= 3:
return
# filter entity with single-character words
@@ -170,7 +166,7 @@ class EntityTupleSearcher:
for raw_prompt, weight in weighted_prompts:
prompt = raw_prompt.replace(
f'<ENT{ent_idx}>',
self._model.tokenizer.decode(cur_token_ids) +
self._model.tokenizer.decode(cur_token_ids).lower() +
self._model.tokenizer.mask_token * (
n_masks[ent_idx] - len(cur_token_ids)))

View File

@@ -45,7 +45,25 @@ class KnowledgeHarvester:
scores = []
for ent_tuple in self._seed_ent_tuples:
ent_tuple = [ent.replace('_', ' ') for ent in ent_tuple]
scores.append(self.score(prompt=prompt, ent_tuple=ent_tuple))
score = self.score(prompt=prompt, ent_tuple=ent_tuple)
neg_scores = []
for ent_idx in range(len(ent_tuple)):
for ent_tuple1 in self._seed_ent_tuples:
if ent_tuple1[ent_idx] == ent_tuple[ent_idx]:
continue
ent_tuple_neg = \
ent_tuple[:ent_idx] + \
[ent_tuple1[ent_idx]] + \
ent_tuple[ent_idx + 1:]
neg_scores.append(self.score(
prompt=prompt, ent_tuple=ent_tuple_neg))
neg_score = sum(neg_scores) / len(neg_scores)
scores.append(score - neg_score)
self._weighted_prompts[i][1] = \
sum(scores) / len(scores) / self._prompt_temp