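"""Core operations for HiRAG: text chunking, (hierarchical) entity and relation
extraction, community report generation, and query-context construction.
"""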
import re
import json
import asyncio
import tiktoken
import networkx as nx
import time
import logging
from contextlib import contextmanager
from typing import Union
from collections import Counter, defaultdict

from ._splitter import SeparatorSplitter
from ._utils import (
    logger,
    clean_str,
    compute_mdhash_id,
    decode_tokens_by_tiktoken,
    encode_string_by_tiktoken,
    is_float_regex,
    list_of_list_to_csv,
    pack_user_ass_to_openai_messages,
    split_string_by_multi_markers,
    truncate_list_by_token_size,
)
from .base import (
    BaseGraphStorage,
    BaseKVStorage,
    BaseVectorStorage,
    SingleCommunitySchema,
    CommunitySchema,
    TextChunkSchema,
    QueryParam,
)
from .prompt import GRAPH_FIELD_SEP, PROMPTS
from ._cluster_utils import Hierarchical_Clustering


@contextmanager
def timer():
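    """Context manager that logs the wall-clock time of the enclosed block."""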
    start_time = time.perf_counter()
    try:
        yield
    finally:
        end_time = time.perf_counter()
        elapsed_time = end_time - start_time
        logging.info(f"[Retrieval Time: {elapsed_time:.6f} seconds]")


def chunking_by_token_size(
    tokens_list: list[list[int]],
    doc_keys,
    tiktoken_model,
    overlap_token_size=128,
    max_token_size=1024,
):
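    """Split each document's token list into fixed-size, overlapping chunks.

    Returns a list of dicts with token count, text content, chunk order index,
    and the id of the source document.
    """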
    # tokenizer
    results = []
    for index, tokens in enumerate(tokens_list):
        chunk_token = []
        lengths = []
        for start in range(0, len(tokens), max_token_size - overlap_token_size):
            chunk_token.append(tokens[start : start + max_token_size])
            lengths.append(min(max_token_size, len(tokens) - start))

        # Note: the full token structure is list[list[list[int]]] (corpus -> doc -> chunk),
        # so it cannot be decoded in a single call
        chunk_token = tiktoken_model.decode_batch(chunk_token)
        for i, chunk in enumerate(chunk_token):
            results.append(
                {
                    "tokens": lengths[i],
                    "content": chunk.strip(),
                    "chunk_order_index": i,
                    "full_doc_id": doc_keys[index],
                }
            )

    return results


def chunking_by_seperators(
    tokens_list: list[list[int]],
    doc_keys,
    tiktoken_model,
    overlap_token_size=128,
    max_token_size=1024,
):
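    """Split each document's token list into chunks along the default text
    separators, with a token-size cap and overlap.
    """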
    splitter = SeparatorSplitter(
        separators=[
            tiktoken_model.encode(s) for s in PROMPTS["default_text_separator"]
        ],
        chunk_size=max_token_size,
        chunk_overlap=overlap_token_size,
    )
    results = []
    for index, tokens in enumerate(tokens_list):
        chunk_token = splitter.split_tokens(tokens)
        lengths = [len(c) for c in chunk_token]

        # Note: the full token structure is list[list[list[int]]] (corpus -> doc -> chunk),
        # so it cannot be decoded in a single call
        chunk_token = tiktoken_model.decode_batch(chunk_token)
        for i, chunk in enumerate(chunk_token):
            results.append(
                {
                    "tokens": lengths[i],
                    "content": chunk.strip(),
                    "chunk_order_index": i,
                    "full_doc_id": doc_keys[index],
                }
            )

    return results


def get_chunks(new_docs, chunk_func=chunking_by_token_size, **chunk_func_params):
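    """Tokenize the given documents and chunk them with `chunk_func`, keyed by
    the md5 hash of each chunk's content.
    """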
    inserting_chunks = {}

    new_docs_list = list(new_docs.items())
    docs = [new_doc[1]["content"] for new_doc in new_docs_list]
    doc_keys = [new_doc[0] for new_doc in new_docs_list]

    ENCODER = tiktoken.encoding_for_model("gpt-4o")
    tokens = ENCODER.encode_batch(docs, num_threads=16)
    chunks = chunk_func(
        tokens, doc_keys=doc_keys, tiktoken_model=ENCODER, **chunk_func_params
    )

    for chunk in chunks:
        inserting_chunks.update(
            {compute_mdhash_id(chunk["content"], prefix="chunk-"): chunk}
        )

    return inserting_chunks


async def _handle_entity_relation_summary(
    entity_or_relation_name: str,
    description: str,
    global_config: dict,
) -> str:
    """Summarize an entity or relation description; used during entity
    extraction and when merging nodes or edges in the knowledge graph.

    Args:
        entity_or_relation_name: entity or relation name
        description: description
        global_config: global configuration
    """
    use_llm_func: callable = global_config["cheap_model_func"]
    llm_max_tokens = global_config["cheap_model_max_token_size"]
    tiktoken_model_name = global_config["tiktoken_model_name"]
    summary_max_tokens = global_config["entity_summary_to_max_tokens"]

    tokens = encode_string_by_tiktoken(description, model_name=tiktoken_model_name)
    if len(tokens) < summary_max_tokens:  # no need to summarize
        return description
    prompt_template = PROMPTS["summarize_entity_descriptions"]
    use_description = decode_tokens_by_tiktoken(
        tokens[:llm_max_tokens], model_name=tiktoken_model_name
    )
    context_base = dict(
        entity_name=entity_or_relation_name,
        description_list=use_description.split(GRAPH_FIELD_SEP),
    )
    use_prompt = prompt_template.format(**context_base)
    logger.debug(f"Trigger summary: {entity_or_relation_name}")
    summary = await use_llm_func(use_prompt, max_tokens=summary_max_tokens)
    return summary


async def _handle_single_entity_extraction(
    record_attributes: list[str],
    chunk_key: str,
):
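    """Parse one extracted record into an entity dict (name, type, description,
    source chunk id); returns None if the record is not a valid entity.
    """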
    if len(record_attributes) < 4 or record_attributes[0] != '"entity"':
        return None
    # add this record as a node in the graph
    entity_name = clean_str(record_attributes[1].upper())
    if not entity_name.strip():
        return None
    entity_type = clean_str(record_attributes[2].upper())
    entity_description = clean_str(record_attributes[3])
    entity_source_id = chunk_key
    return dict(
        entity_name=entity_name,
        entity_type=entity_type,
        description=entity_description,
        source_id=entity_source_id,
    )


async def _handle_single_relationship_extraction(
    record_attributes: list[str],
    chunk_key: str,
):
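    """Parse one extracted record into a relationship dict (source, target,
    weight, description, source chunk id); returns None if not a relationship.
    """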
    if len(record_attributes) < 5 or record_attributes[0] != '"relationship"':
        return None
    # add this record as an edge
    source = clean_str(record_attributes[1].upper())
    target = clean_str(record_attributes[2].upper())
    edge_description = clean_str(record_attributes[3])
    edge_source_id = chunk_key
    weight = (
        float(record_attributes[-1]) if is_float_regex(record_attributes[-1]) else 1.0
    )
    return dict(
        src_id=source,
        tgt_id=target,
        weight=weight,
        description=edge_description,
        source_id=edge_source_id,
    )


async def _merge_nodes_then_upsert(
    entity_name: str,
    nodes_data: list[dict],
    knwoledge_graph_inst: BaseGraphStorage,
    global_config: dict,
):
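    """Merge newly extracted node records with any existing node of the same
    name, summarize the combined description, and upsert the result into the
    knowledge graph.
    """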
    already_entity_types = []
    already_source_ids = []
    already_description = []

    already_node = await knwoledge_graph_inst.get_node(entity_name)
    if already_node is not None:  # the node already exists
        already_entity_types.append(already_node["entity_type"])
        already_source_ids.extend(
            split_string_by_multi_markers(already_node["source_id"], [GRAPH_FIELD_SEP])
        )
        already_description.append(already_node["description"])

    entity_type = sorted(
        Counter(
            [dp["entity_type"] for dp in nodes_data] + already_entity_types
        ).items(),
        key=lambda x: x[1],
        reverse=True,
    )[0][0]
    description = GRAPH_FIELD_SEP.join(
        sorted(set([dp["description"] for dp in nodes_data] + already_description))
    )
    source_id = GRAPH_FIELD_SEP.join(
        set([dp["source_id"] for dp in nodes_data] + already_source_ids)
    )
    description = await _handle_entity_relation_summary(
        entity_name, description, global_config
    )
    node_data = dict(
        entity_type=entity_type,
        description=description,
        source_id=source_id,
    )
    await knwoledge_graph_inst.upsert_node(
        entity_name,
        node_data=node_data,
    )
    node_data["entity_name"] = entity_name
    return node_data


async def _merge_edges_then_upsert(
    src_id: str,
    tgt_id: str,
    edges_data: list[dict],
    knwoledge_graph_inst: BaseGraphStorage,
    global_config: dict,
):
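    """Merge newly extracted edge records with any existing edge between the
    same pair of nodes, inserting placeholder nodes for missing endpoints, and
    upsert the merged edge into the knowledge graph.
    """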
    already_weights = []
    already_source_ids = []
    already_description = []
    already_order = []
    if await knwoledge_graph_inst.has_edge(src_id, tgt_id):
        already_edge = await knwoledge_graph_inst.get_edge(src_id, tgt_id)
        already_weights.append(already_edge["weight"])
        already_source_ids.extend(
            split_string_by_multi_markers(already_edge["source_id"], [GRAPH_FIELD_SEP])
        )
        already_description.append(already_edge["description"])
        already_order.append(already_edge.get("order", 1))

    # [numberchiffre]: `Relationship.order` is only returned from DSPy's predictions
    order = min([dp.get("order", 1) for dp in edges_data] + already_order)
    weight = sum([dp["weight"] for dp in edges_data] + already_weights)
    description = GRAPH_FIELD_SEP.join(
        sorted(set([dp["description"] for dp in edges_data] + already_description))
    )
    source_id = GRAPH_FIELD_SEP.join(
        set([dp["source_id"] for dp in edges_data] + already_source_ids)
    )
    for need_insert_id in [src_id, tgt_id]:
        if not (await knwoledge_graph_inst.has_node(need_insert_id)):
            await knwoledge_graph_inst.upsert_node(
                need_insert_id,
                node_data={
                    "source_id": source_id,
                    "description": description,
                    "entity_type": '"UNKNOWN"',
                },
            )
    description = await _handle_entity_relation_summary(
        (src_id, tgt_id), description, global_config
    )
    await knwoledge_graph_inst.upsert_edge(
        src_id,
        tgt_id,
        edge_data=dict(
            weight=weight, description=description, source_id=source_id, order=order
        ),
    )


# TODO:
# extract entities with normal and attribute entities
async def extract_hierarchical_entities(
    chunks: dict[str, TextChunkSchema],
    knowledge_graph_inst: BaseGraphStorage,
    entity_vdb: BaseVectorStorage,
    global_config: dict,
) -> Union[BaseGraphStorage, None]:
    """Extract entities and relations from text chunks, cluster the entities
    hierarchically, and upsert everything into the knowledge graph.

    Args:
        chunks: text chunks
        knowledge_graph_inst: knowledge graph instance
        entity_vdb: entity vector database
        global_config: global configuration

    Returns:
        Union[BaseGraphStorage, None]: knowledge graph instance
    """
    use_llm_func: callable = global_config["best_model_func"]
    entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]

    ordered_chunks = list(chunks.items())
    entity_extract_prompt = PROMPTS["hi_entity_extraction"]  # gives 3 examples in the prompt context
    relation_extract_prompt = PROMPTS["hi_relation_extraction"]

    context_base_entity = dict(
        tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
        record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"],
        completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
        entity_types=",".join(PROMPTS["META_ENTITY_TYPES"]),
    )
    continue_prompt = PROMPTS["entiti_continue_extraction"]  # asks the LLM to continue if the last extraction was incomplete
    if_loop_prompt = PROMPTS["entiti_if_loop_extraction"]  # judges whether entities remain to be extracted

    already_processed = 0
    already_entities = 0
    already_relations = 0

    async def _process_single_content_entity(chunk_key_dp: tuple[str, TextChunkSchema]):  # run for each chunk
        nonlocal already_processed, already_entities, already_relations
        chunk_key = chunk_key_dp[0]
        chunk_dp = chunk_key_dp[1]
        content = chunk_dp["content"]
        hint_prompt = entity_extract_prompt.format(**context_base_entity, input_text=content)  # fill in the parameters
        final_result = await use_llm_func(hint_prompt)  # feed the prompt into the LLM

        history = pack_user_ass_to_openai_messages(hint_prompt, final_result)  # set as history
        for now_glean_index in range(entity_extract_max_gleaning):
            glean_result = await use_llm_func(continue_prompt, history_messages=history)

            history += pack_user_ass_to_openai_messages(continue_prompt, glean_result)  # add to history
            final_result += glean_result
            if now_glean_index == entity_extract_max_gleaning - 1:
                break

            if_loop_result: str = await use_llm_func(  # judge whether another iteration is needed
                if_loop_prompt, history_messages=history
            )
            if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
            if if_loop_result != "yes":
                break

        records = split_string_by_multi_markers(  # split the result into a list of records
            final_result,
            [context_base_entity["record_delimiter"], context_base_entity["completion_delimiter"]],
        )
        # resolve the entities
        maybe_nodes = defaultdict(list)
        maybe_edges = defaultdict(list)
        for record in records:
            record = re.search(r"\((.*)\)", record)
            if record is None:
                continue
            record = record.group(1)
            record_attributes = split_string_by_multi_markers(  # split one record into attributes
                record, [context_base_entity["tuple_delimiter"]]
            )
            if_entities = await _handle_single_entity_extraction(  # get the name, type, desc, source_id of the entity --> dict
                record_attributes, chunk_key
            )
            if if_entities is not None:
                maybe_nodes[if_entities["entity_name"]].append(if_entities)
                continue

            if_relation = await _handle_single_relationship_extraction(
                record_attributes, chunk_key
            )
            if if_relation is not None:
                maybe_edges[(if_relation["src_id"], if_relation["tgt_id"])].append(
                    if_relation
                )
        already_processed += 1  # chunks processed so far
        already_entities += len(maybe_nodes)
        already_relations += len(maybe_edges)
        now_ticks = PROMPTS["process_tickers"][  # for progress display
            already_processed % len(PROMPTS["process_tickers"])
        ]
        print(
            f"{now_ticks} Processed {already_processed}({already_processed*100//len(ordered_chunks)}%) chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r",
            end="",
            flush=True,
        )
        return dict(maybe_nodes), dict(maybe_edges)

    # extract entities
    # use_llm_func is wrapped in asyncio.Semaphore, limiting concurrent calls to max_async
    entity_results = await asyncio.gather(
        *[_process_single_content_entity(c) for c in ordered_chunks]
    )
    print()  # clear the progress bar

    # fetch all entities from results
    all_entities = {}
    for item in entity_results:
        for k, v in item[0].items():
            all_entities[k] = v[0]
    context_entities = {key[0]: list(x[0].keys()) for key, x in zip(ordered_chunks, entity_results)}

    # fetch embeddings
    entity_descriptions = [v["description"] for k, v in all_entities.items()]
    entity_sequence_embeddings = []
    embeddings_batch_size = 64
    num_embeddings_batches = (len(entity_descriptions) + embeddings_batch_size - 1) // embeddings_batch_size
    for i in range(num_embeddings_batches):
        start_index = i * embeddings_batch_size
        end_index = min((i + 1) * embeddings_batch_size, len(entity_descriptions))
        batch = entity_descriptions[start_index:end_index]
        result = await entity_vdb.embedding_func(batch)
        entity_sequence_embeddings.extend(result)
    entity_embeddings = entity_sequence_embeddings
    for (k, v), x in zip(all_entities.items(), entity_embeddings):
        v["embedding"] = x
        all_entities[k] = v

    already_processed = 0

    async def _process_single_content_relation(chunk_key_dp: tuple[str, TextChunkSchema]):  # run for each chunk
        nonlocal already_processed, already_entities, already_relations
        chunk_key = chunk_key_dp[0]
        chunk_dp = chunk_key_dp[1]
        content = chunk_dp["content"]

        entities = context_entities[chunk_key]
        context_base_relation = dict(
            tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
            record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"],
            completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
            entities=",".join(entities),
        )
        hint_prompt = relation_extract_prompt.format(**context_base_relation, input_text=content)  # fill in the parameters
        final_result = await use_llm_func(hint_prompt)  # feed the prompt into the LLM

        history = pack_user_ass_to_openai_messages(hint_prompt, final_result)  # set as history
        for now_glean_index in range(entity_extract_max_gleaning):
            glean_result = await use_llm_func(continue_prompt, history_messages=history)

            history += pack_user_ass_to_openai_messages(continue_prompt, glean_result)  # add to history
            final_result += glean_result
            if now_glean_index == entity_extract_max_gleaning - 1:
                break

            if_loop_result: str = await use_llm_func(  # judge whether another iteration is needed
                if_loop_prompt, history_messages=history
            )
            if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
            if if_loop_result != "yes":
                break

        records = split_string_by_multi_markers(  # split the result into a list of records
            final_result,
            [context_base_relation["record_delimiter"], context_base_relation["completion_delimiter"]],
        )
        # resolve the entities
        maybe_nodes = defaultdict(list)
        maybe_edges = defaultdict(list)
        for record in records:
            record = re.search(r"\((.*)\)", record)
            if record is None:
                continue
            record = record.group(1)
            record_attributes = split_string_by_multi_markers(  # split one record into attributes
                record, [context_base_relation["tuple_delimiter"]]
            )
            if_entities = await _handle_single_entity_extraction(  # get the name, type, desc, source_id of the entity --> dict
                record_attributes, chunk_key
            )
            if if_entities is not None:
                maybe_nodes[if_entities["entity_name"]].append(if_entities)
                continue

            if_relation = await _handle_single_relationship_extraction(
                record_attributes, chunk_key
            )
            if if_relation is not None:
                maybe_edges[(if_relation["src_id"], if_relation["tgt_id"])].append(
                    if_relation
                )
        already_processed += 1  # chunks processed so far
        already_entities += len(maybe_nodes)
        already_relations += len(maybe_edges)
        now_ticks = PROMPTS["process_tickers"][  # for progress display
            already_processed % len(PROMPTS["process_tickers"])
        ]
        print(
            f"{now_ticks} Processed {already_processed}({already_processed*100//len(ordered_chunks)}%) chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r",
            end="",
            flush=True,
        )
        return dict(maybe_nodes), dict(maybe_edges)

    # extract relations
    # use_llm_func is wrapped in asyncio.Semaphore, limiting concurrent calls to max_async
    relation_results = await asyncio.gather(
        *[_process_single_content_relation(c) for c in ordered_chunks]
    )
    print()

    # fetch all relations from results
    all_relations = {}
    for item in relation_results:
        for k, v in item[1].items():
            all_relations[k] = v

    # TODO: hierarchical clustering
    logger.info("[Hierarchical Clustering]")
    hierarchical_cluster = Hierarchical_Clustering()
    hierarchical_clustered_entities_relations = await hierarchical_cluster.perform_clustering(entity_vdb=entity_vdb, global_config=global_config, entities=all_entities)
    hierarchical_clustered_entities = [[x for x in y if "entity_name" in x.keys()] for y in hierarchical_clustered_entities_relations]
    hierarchical_clustered_relations = [[x for x in y if "src_id" in x.keys()] for y in hierarchical_clustered_entities_relations]

    maybe_nodes = defaultdict(list)  # for all chunks
    maybe_edges = defaultdict(list)
    # extracted entities and relations
    for m_nodes, m_edges in zip(entity_results, relation_results):
        for k, v in m_nodes[0].items():
            maybe_nodes[k].extend(v)
        for k, v in m_edges[1].items():
            # the graph is undirected
            maybe_edges[tuple(sorted(k))].extend(v)
    # clustered entities
    for cluster_layer in hierarchical_clustered_entities:
        for item in cluster_layer:
            maybe_nodes[item["entity_name"]].extend([item])
    # clustered relations
    for cluster_layer in hierarchical_clustered_relations:
        for item in cluster_layer:
            maybe_edges[tuple(sorted((item["src_id"], item["tgt_id"])))].extend([item])
    # store the nodes
    all_entities_data = await asyncio.gather(
        *[
            _merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config)
            for k, v in maybe_nodes.items()
        ]
    )
    # store the edges
    await asyncio.gather(
        *[
            _merge_edges_then_upsert(k[0], k[1], v, knowledge_graph_inst, global_config)
            for k, v in maybe_edges.items()
        ]
    )
    if not len(all_entities_data):
        logger.warning("Didn't extract any entities, maybe your LLM is not working")
        return None
    if entity_vdb is not None:
        data_for_vdb = {  # keyed by the md5 hash of the entity name
            compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
                "content": dp["entity_name"] + dp["description"],  # the content is the entity name plus its description
                "entity_name": dp["entity_name"],
            }
            for dp in all_entities_data
        }
        await entity_vdb.upsert(data_for_vdb)
    return knowledge_graph_inst


async def extract_entities(
    chunks: dict[str, TextChunkSchema],
    knwoledge_graph_inst: BaseGraphStorage,
    entity_vdb: BaseVectorStorage,
    global_config: dict,
) -> Union[BaseGraphStorage, None]:
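    """Extract entities and relations from text chunks and upsert them into
    the knowledge graph and the entity vector database (the flat,
    non-hierarchical version of the extraction pipeline).
    """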
    use_llm_func: callable = global_config["best_model_func"]
    entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]

    ordered_chunks = list(chunks.items())  # chunks

    entity_extract_prompt = PROMPTS["entity_extraction"]  # gives 3 examples in the prompt context
    context_base = dict(
        tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
        record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"],
        completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
        entity_types=",".join(PROMPTS["DEFAULT_ENTITY_TYPES"]),
    )
    continue_prompt = PROMPTS["entiti_continue_extraction"]  # asks the LLM to continue if the last extraction was incomplete
    if_loop_prompt = PROMPTS["entiti_if_loop_extraction"]  # judges whether entities remain to be extracted

    already_processed = 0
    already_entities = 0
    already_relations = 0

    async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):  # run for each chunk
        nonlocal already_processed, already_entities, already_relations
        chunk_key = chunk_key_dp[0]
        chunk_dp = chunk_key_dp[1]
        content = chunk_dp["content"]
        hint_prompt = entity_extract_prompt.format(**context_base, input_text=content)  # fill in the parameters
        final_result = await use_llm_func(hint_prompt)  # feed the prompt into the LLM

        history = pack_user_ass_to_openai_messages(hint_prompt, final_result)  # set as history
        for now_glean_index in range(entity_extract_max_gleaning):
            glean_result = await use_llm_func(continue_prompt, history_messages=history)

            history += pack_user_ass_to_openai_messages(continue_prompt, glean_result)  # add to history
            final_result += glean_result
            if now_glean_index == entity_extract_max_gleaning - 1:
                break

            if_loop_result: str = await use_llm_func(  # judge whether another iteration is needed
                if_loop_prompt, history_messages=history
            )
            if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
            if if_loop_result != "yes":
                break

        records = split_string_by_multi_markers(  # split the result into a list of records
            final_result,
            [context_base["record_delimiter"], context_base["completion_delimiter"]],
        )

        maybe_nodes = defaultdict(list)
        maybe_edges = defaultdict(list)
        for record in records:
            record = re.search(r"\((.*)\)", record)
            if record is None:
                continue
            record = record.group(1)
            record_attributes = split_string_by_multi_markers(  # split one record into attributes
                record, [context_base["tuple_delimiter"]]
            )
            if_entities = await _handle_single_entity_extraction(  # get the name, type, desc, source_id of the entity --> dict
                record_attributes, chunk_key
            )
            if if_entities is not None:
                maybe_nodes[if_entities["entity_name"]].append(if_entities)
                continue

            if_relation = await _handle_single_relationship_extraction(
                record_attributes, chunk_key
            )
            if if_relation is not None:
                maybe_edges[(if_relation["src_id"], if_relation["tgt_id"])].append(
                    if_relation
                )
        already_processed += 1  # chunks processed so far
        already_entities += len(maybe_nodes)
        already_relations += len(maybe_edges)
        now_ticks = PROMPTS["process_tickers"][  # for progress display
            already_processed % len(PROMPTS["process_tickers"])
        ]
        print(
            f"{now_ticks} Processed {already_processed}({already_processed*100//len(ordered_chunks)}%) chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r",
            end="",
            flush=True,
        )
        return dict(maybe_nodes), dict(maybe_edges)

    # use_llm_func is wrapped in asyncio.Semaphore, limiting concurrent calls to max_async
    results = await asyncio.gather(
        *[_process_single_content(c) for c in ordered_chunks]
    )
    print()  # clear the progress bar
    maybe_nodes = defaultdict(list)  # for all chunks
    maybe_edges = defaultdict(list)
    for m_nodes, m_edges in results:
        for k, v in m_nodes.items():
            maybe_nodes[k].extend(v)
        for k, v in m_edges.items():
            # the graph is undirected
            maybe_edges[tuple(sorted(k))].extend(v)
    all_entities_data = await asyncio.gather(  # store the nodes
        *[
            _merge_nodes_then_upsert(k, v, knwoledge_graph_inst, global_config)
            for k, v in maybe_nodes.items()
        ]
    )
    await asyncio.gather(  # store the edges
        *[
            _merge_edges_then_upsert(k[0], k[1], v, knwoledge_graph_inst, global_config)
            for k, v in maybe_edges.items()
        ]
    )
    if not len(all_entities_data):
        logger.warning("Didn't extract any entities, maybe your LLM is not working")
        return None
    if entity_vdb is not None:
        data_for_vdb = {  # keyed by the md5 hash of the entity name
            compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
                "content": dp["entity_name"] + dp["description"],  # the content is the entity name plus its description
                "entity_name": dp["entity_name"],
            }
            for dp in all_entities_data
        }
        await entity_vdb.upsert(data_for_vdb)
    return knwoledge_graph_inst


def _pack_single_community_by_sub_communities(
    community: SingleCommunitySchema,
    max_token_size: int,
    already_reports: dict[str, CommunitySchema],
) -> tuple[str, int, set, set]:
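    """Describe a community through the reports of its sub-communities,
    returning the CSV description, its token length, and the sets of nodes and
    edges covered by those sub-communities.
    """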
    # TODO
    all_sub_communities = [
        already_reports[k] for k in community["sub_communities"] if k in already_reports
    ]
    all_sub_communities = sorted(
        all_sub_communities, key=lambda x: x["occurrence"], reverse=True
    )
    may_trun_all_sub_communities = truncate_list_by_token_size(
        all_sub_communities,
        key=lambda x: x["report_string"],
        max_token_size=max_token_size,
    )
    sub_fields = ["id", "report", "rating", "importance"]
    sub_communities_describe = list_of_list_to_csv(
        [sub_fields]
        + [
            [
                i,
                c["report_string"],
                c["report_json"].get("rating", -1),
                c["occurrence"],
            ]
            for i, c in enumerate(may_trun_all_sub_communities)
        ]
    )
    already_nodes = []
    already_edges = []
    for c in may_trun_all_sub_communities:
        already_nodes.extend(c["nodes"])
        already_edges.extend([tuple(e) for e in c["edges"]])
    return (
        sub_communities_describe,
        len(encode_string_by_tiktoken(sub_communities_describe)),
        set(already_nodes),
        set(already_edges),
    )


async def _pack_single_community_describe(
    knwoledge_graph_inst: BaseGraphStorage,
    community: SingleCommunitySchema,
    max_token_size: int = 12000,
    already_reports: dict[str, CommunitySchema] = {},
    global_config: dict = {},
) -> str:
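    """Build the textual description of a community (reports, entities, and
    relationships as CSV blocks), truncated to fit within `max_token_size` and
    falling back to sub-community reports when the context is too large.
    """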
    nodes_in_order = sorted(community["nodes"])
    edges_in_order = sorted(community["edges"], key=lambda x: x[0] + x[1])

    nodes_data = await asyncio.gather(
        *[knwoledge_graph_inst.get_node(n) for n in nodes_in_order]
    )
    edges_data = await asyncio.gather(
        *[knwoledge_graph_inst.get_edge(src, tgt) for src, tgt in edges_in_order]
    )
    node_fields = ["id", "entity", "type", "description", "degree"]
    edge_fields = ["id", "source", "target", "description", "rank"]
    nodes_list_data = [
        [
            i,
            node_name,
            node_data.get("entity_type", "UNKNOWN"),
            node_data.get("description", "UNKNOWN"),
            await knwoledge_graph_inst.node_degree(node_name),
        ]
        for i, (node_name, node_data) in enumerate(zip(nodes_in_order, nodes_data))
    ]
    nodes_list_data = sorted(nodes_list_data, key=lambda x: x[-1], reverse=True)
    nodes_may_truncate_list_data = truncate_list_by_token_size(
        nodes_list_data, key=lambda x: x[3], max_token_size=max_token_size // 2
    )
    edges_list_data = [
        [
            i,
            edge_name[0],
            edge_name[1],
            edge_data.get("description", "UNKNOWN"),
            await knwoledge_graph_inst.edge_degree(*edge_name),
        ]
        for i, (edge_name, edge_data) in enumerate(zip(edges_in_order, edges_data))
    ]
    edges_list_data = sorted(edges_list_data, key=lambda x: x[-1], reverse=True)
    edges_may_truncate_list_data = truncate_list_by_token_size(
        edges_list_data, key=lambda x: x[3], max_token_size=max_token_size // 2
    )

    truncated = len(nodes_list_data) > len(nodes_may_truncate_list_data) or len(
        edges_list_data
    ) > len(edges_may_truncate_list_data)

    # if the context exceeds the limit and the community has sub-communities:
    report_describe = ""
    need_to_use_sub_communities = (
        truncated and len(community["sub_communities"]) and len(already_reports)
    )
    force_to_use_sub_communities = global_config["addon_params"].get(
        "force_to_use_sub_communities", False
    )
    if need_to_use_sub_communities or force_to_use_sub_communities:
        logger.debug(
            f"Community {community['title']} exceeds the limit or you set force_to_use_sub_communities to True, using its sub-communities"
        )
        report_describe, report_size, contain_nodes, contain_edges = (
            _pack_single_community_by_sub_communities(
                community, max_token_size, already_reports
            )
        )
        report_exclude_nodes_list_data = [
            n for n in nodes_list_data if n[1] not in contain_nodes
        ]
        report_include_nodes_list_data = [
            n for n in nodes_list_data if n[1] in contain_nodes
        ]
        report_exclude_edges_list_data = [
            e for e in edges_list_data if (e[1], e[2]) not in contain_edges
        ]
        report_include_edges_list_data = [
            e for e in edges_list_data if (e[1], e[2]) in contain_edges
        ]
        # if the report size exceeds max_token_size, nodes and edges are []
        nodes_may_truncate_list_data = truncate_list_by_token_size(
            report_exclude_nodes_list_data + report_include_nodes_list_data,
            key=lambda x: x[3],
            max_token_size=(max_token_size - report_size) // 2,
        )
        edges_may_truncate_list_data = truncate_list_by_token_size(
            report_exclude_edges_list_data + report_include_edges_list_data,
            key=lambda x: x[3],
            max_token_size=(max_token_size - report_size) // 2,
        )
    nodes_describe = list_of_list_to_csv([node_fields] + nodes_may_truncate_list_data)
    edges_describe = list_of_list_to_csv([edge_fields] + edges_may_truncate_list_data)
    return f"""-----Reports-----
```csv
{report_describe}
```
-----Entities-----
```csv
{nodes_describe}
```
-----Relationships-----
```csv
{edges_describe}
```"""


def _community_report_json_to_str(parsed_output: dict) -> str:
    """Refer to the official GraphRAG: index/graph/extractors/community_reports."""
    title = parsed_output.get("title", "Report")
    summary = parsed_output.get("summary", "")
    findings = parsed_output.get("findings", [])

    def finding_summary(finding: dict):
        if isinstance(finding, str):
            return finding
        return finding.get("summary")

    def finding_explanation(finding: dict):
        if isinstance(finding, str):
            return ""
        return finding.get("explanation")

    report_sections = "\n\n".join(
        f"## {finding_summary(f)}\n\n{finding_explanation(f)}" for f in findings
    )
    return f"# {title}\n\n{summary}\n\n{report_sections}"


async def generate_community_report(
    community_report_kv: BaseKVStorage[CommunitySchema],
    knwoledge_graph_inst: BaseGraphStorage,
    global_config: dict,
):
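    """Generate an LLM-written report for every community in the graph, level
    by level (deepest first) so parent communities can reuse sub-community
    reports, and persist all reports to `community_report_kv`.
    """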
    llm_extra_kwargs = global_config["special_community_report_llm_kwargs"]
    use_llm_func: callable = global_config["best_model_func"]
    use_string_json_convert_func: callable = global_config[
        "convert_response_to_json_func"
    ]

    community_report_prompt = PROMPTS["community_report"]

    communities_schema = await knwoledge_graph_inst.community_schema()
    community_keys, community_values = list(communities_schema.keys()), list(
        communities_schema.values()
    )
    already_processed = 0

    async def _form_single_community_report(
        community: SingleCommunitySchema, already_reports: dict[str, CommunitySchema]
    ):
        nonlocal already_processed
        describe = await _pack_single_community_describe(
            knwoledge_graph_inst,
            community,
            max_token_size=global_config["best_model_max_token_size"],
            already_reports=already_reports,
            global_config=global_config,
        )
        prompt = community_report_prompt.format(input_text=describe)
        response = await use_llm_func(prompt, **llm_extra_kwargs)
        data = use_string_json_convert_func(response)
        already_processed += 1
        now_ticks = PROMPTS["process_tickers"][
            already_processed % len(PROMPTS["process_tickers"])
        ]
        print(
            f"{now_ticks} Processed {already_processed} communities\r",
            end="",
            flush=True,
        )
        return data

    levels = sorted(set([c["level"] for c in community_values]), reverse=True)
    logger.info(f"Generating by levels: {levels}")
    community_datas = {}
    for level in levels:
        this_level_community_keys, this_level_community_values = zip(
            *[
                (k, v)
                for k, v in zip(community_keys, community_values)
                if v["level"] == level
            ]
        )
        this_level_communities_reports = await asyncio.gather(
            *[
                _form_single_community_report(c, community_datas)
                for c in this_level_community_values
            ]
        )
        community_datas.update(
            {
                k: {
                    "report_string": _community_report_json_to_str(r),
                    "report_json": r,
                    **v,
                }
                for k, r, v in zip(
                    this_level_community_keys,
                    this_level_communities_reports,
                    this_level_community_values,
                )
            }
        )
    print()  # clear the progress bar
    await community_report_kv.upsert(community_datas)


async def _find_most_related_community_from_entities(
    node_datas: list[dict],
    query_param: QueryParam,
    community_reports: BaseKVStorage[CommunitySchema],
):
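    """Collect the communities the given entities belong to (up to the query's
    level), rank them by occurrence count and report rating, and truncate the
    reports to the query's token budget.
    """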
    related_communities = []
    for node_d in node_datas:
        if "clusters" not in node_d:
            continue
        related_communities.extend(json.loads(node_d["clusters"]))
    related_community_dup_keys = [
        str(dp["cluster"])
        for dp in related_communities
        if dp["level"] <= query_param.level
    ]
    related_community_keys_counts = dict(Counter(related_community_dup_keys))
    _related_community_datas = await asyncio.gather(  # get community reports
        *[community_reports.get_by_id(k) for k in related_community_keys_counts.keys()]
    )
    related_community_datas = {
        k: v
        for k, v in zip(related_community_keys_counts.keys(), _related_community_datas)
        if v is not None
    }
    related_community_keys = sorted(  # sort by occurrence count, then by rating
        related_community_keys_counts.keys(),
        key=lambda k: (
            related_community_keys_counts[k],
            related_community_datas[k]["report_json"].get("rating", -1),
        ),
        reverse=True,
    )
    sorted_community_datas = [  # community reports in the sorted order above
        related_community_datas[k] for k in related_community_keys
    ]

    use_community_reports = truncate_list_by_token_size(  # in case the community reports exceed the token limit
        sorted_community_datas,
        key=lambda x: x["report_string"],
        max_token_size=query_param.max_token_for_community_report,
    )
    if query_param.community_single_one:
        use_community_reports = use_community_reports[:1]
    return use_community_reports


async def _find_most_related_text_unit_from_entities(
    node_datas: list[dict],
    query_param: QueryParam,
    text_chunks_db: BaseKVStorage[TextChunkSchema],
    knowledge_graph_inst: BaseGraphStorage,
):
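    """Gather the text chunks the retrieved entities came from, rank them by
    retrieval order and by how many one-hop relations point back to each
    chunk, and truncate to the query's token budget.
    """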
    text_units = [  # the text chunks the retrieved entities came from
        split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
        for dp in node_datas
    ]
    edges = await asyncio.gather(  # get relations related to the retrieved entities
        *[knowledge_graph_inst.get_node_edges(dp["entity_name"]) for dp in node_datas]
    )  # where the source entities are the retrieved entities
    all_one_hop_nodes = set()  # find the one-hop neighbors
    for this_edges in edges:
        if not this_edges:
            continue
        all_one_hop_nodes.update([e[1] for e in this_edges])
    all_one_hop_nodes = list(all_one_hop_nodes)
    all_one_hop_nodes_data = await asyncio.gather(  # get node information from storage
        *[knowledge_graph_inst.get_node(e) for e in all_one_hop_nodes]
    )
    all_one_hop_text_units_lookup = {  # find the text chunks of the one-hop neighbor entities
        k: set(split_string_by_multi_markers(v["source_id"], [GRAPH_FIELD_SEP]))
        for k, v in zip(all_one_hop_nodes, all_one_hop_nodes_data)
        if v is not None
    }
    all_text_units_lookup = {}
    for index, (this_text_units, this_edges) in enumerate(zip(text_units, edges)):
        for c_id in this_text_units:
            if c_id in all_text_units_lookup:
                continue
            relation_counts = 0
            for e in this_edges:
                if (
                    e[1] in all_one_hop_text_units_lookup
                    and c_id in all_one_hop_text_units_lookup[e[1]]
                ):
                    relation_counts += 1
            all_text_units_lookup[c_id] = {
                "data": await text_chunks_db.get_by_id(c_id),
                "order": index,
                "relation_counts": relation_counts,  # number of relations related to this chunk
            }
    if any([v is None for v in all_text_units_lookup.values()]):
        logger.warning("Text chunks are missing, maybe the storage is damaged")
    all_text_units = [
        {"id": k, **v} for k, v in all_text_units_lookup.items() if v is not None
    ]
    all_text_units = sorted(  # sort by retrieval order, then by relation counts
        all_text_units, key=lambda x: (x["order"], -x["relation_counts"])
    )
    all_text_units = truncate_list_by_token_size(
        all_text_units,
        key=lambda x: x["data"]["content"],
        max_token_size=query_param.max_token_for_text_unit,
    )
    all_text_units: list[TextChunkSchema] = [t["data"] for t in all_text_units]
    return all_text_units


async def _find_most_related_edges_from_entities(
    node_datas: list[dict],
    query_param: QueryParam,
    knowledge_graph_inst: BaseGraphStorage,
):
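    """Collect the edges incident to the retrieved entities, rank them by
    degree and weight, and truncate to the local-context token budget.
    """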
    all_related_edges = await asyncio.gather(
        *[knowledge_graph_inst.get_node_edges(dp["entity_name"]) for dp in node_datas]
    )
    all_edges = set()
    for this_edges in all_related_edges:
        all_edges.update([tuple(sorted(e)) for e in this_edges])
    all_edges = list(all_edges)
    all_edges_pack = await asyncio.gather(
        *[knowledge_graph_inst.get_edge(e[0], e[1]) for e in all_edges]
    )
    all_edges_degree = await asyncio.gather(
        *[knowledge_graph_inst.edge_degree(e[0], e[1]) for e in all_edges]
    )
    all_edges_data = [
        {"src_tgt": k, "rank": d, **v}
        for k, v, d in zip(all_edges, all_edges_pack, all_edges_degree)
        if v is not None
    ]
    all_edges_data = sorted(
        all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True
    )
    all_edges_data = truncate_list_by_token_size(
        all_edges_data,
        key=lambda x: x["description"],
        max_token_size=query_param.max_token_for_local_context,
    )
    return all_edges_data


async def _find_most_related_edges_from_paths(
    path_datas: list[dict],
    path: list[str],
    query_param: QueryParam,
    knowledge_graph_inst: BaseGraphStorage,
):
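    """Collect the edges of the subgraph induced by the reasoning path, rank
    them by degree and weight, and truncate to the bridge-knowledge token
    budget.
    """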
    # all_related_edges = await asyncio.gather(
    #     *[knowledge_graph_inst.get_node_edges(dp["entity_name"]) for dp in node_datas]
    # )
    # all_reasoning_path = await asyncio.gather(
    #     *[knowledge_graph_inst.get_edge(e[0], e[1]) for e in knowledge_graph_inst._graph.subgraph(path).edges()]
    # )
    all_reasoning_path = knowledge_graph_inst._graph.subgraph(path).edges()
    all_edges = set()
    all_edges.update([tuple(sorted(e)) for e in all_reasoning_path])
    all_edges = list(all_edges)
    all_edges_pack = await asyncio.gather(
        *[knowledge_graph_inst.get_edge(e[0], e[1]) for e in all_edges]
    )
    all_edges_degree = await asyncio.gather(
        *[knowledge_graph_inst.edge_degree(e[0], e[1]) for e in all_edges]
    )
    all_edges_data = [
        {"src_tgt": k, "rank": d, **v}
        for k, v, d in zip(all_edges, all_edges_pack, all_edges_degree)
        if v is not None
    ]
    all_edges_data = sorted(
        all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True
    )
    all_edges_data = truncate_list_by_token_size(
        all_edges_data,
        key=lambda x: x["description"],
        max_token_size=query_param.max_token_for_bridge_knowledge,
    )
    return all_edges_data


# context functions
async def _build_local_query_context(
    query,
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
    community_reports: BaseKVStorage[CommunitySchema],
    text_chunks_db: BaseKVStorage[TextChunkSchema],
    query_param: QueryParam,
):
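    """Build the local query context: top-k entities from the vector store plus
    their related communities, relations, and source text chunks, formatted as
    CSV sections.
    """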
    results = await entities_vdb.query(query, top_k=query_param.top_k)  # find the top-k related entities
    if not len(results):
        return None
    node_datas = await asyncio.gather(
        *[knowledge_graph_inst.get_node(r["entity_name"]) for r in results]
    )
    if not all([n is not None for n in node_datas]):
        logger.warning("Some nodes are missing, maybe the storage is damaged")
    node_degrees = await asyncio.gather(
        *[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results]
    )
    node_datas = [
        {**n, "entity_name": k["entity_name"], "rank": d}
        for k, n, d in zip(results, node_datas, node_degrees)
        if n is not None
    ]
    use_communities = await _find_most_related_community_from_entities(
        node_datas, query_param, community_reports
    )
    use_text_units = await _find_most_related_text_unit_from_entities(
        node_datas, query_param, text_chunks_db, knowledge_graph_inst
    )
    use_relations = await _find_most_related_edges_from_entities(
        node_datas, query_param, knowledge_graph_inst
    )
    logger.info(
        f"Using {len(node_datas)} entities, {len(use_communities)} communities, {len(use_relations)} relations, {len(use_text_units)} text units"
    )
    entities_section_list = [["id", "entity", "type", "description", "rank"]]
    for i, n in enumerate(node_datas):
        entities_section_list.append(
            [
                i,
                n["entity_name"],
                n.get("entity_type", "UNKNOWN"),
                n.get("description", "UNKNOWN"),
                n["rank"],
            ]
        )
    entities_context = list_of_list_to_csv(entities_section_list)

    relations_section_list = [
        ["id", "source", "target", "description", "weight", "rank"]
    ]
    for i, e in enumerate(use_relations):
        relations_section_list.append(
            [
                i,
                e["src_tgt"][0],
                e["src_tgt"][1],
                e["description"],
                e["weight"],
                e["rank"],
            ]
        )
    relations_context = list_of_list_to_csv(relations_section_list)

    communities_section_list = [["id", "content"]]
    for i, c in enumerate(use_communities):
        communities_section_list.append([i, c["report_string"]])
    communities_context = list_of_list_to_csv(communities_section_list)

    text_units_section_list = [["id", "content"]]
    for i, t in enumerate(use_text_units):
        text_units_section_list.append([i, t["content"]])
    text_units_context = list_of_list_to_csv(text_units_section_list)
    return f"""
-----Reports-----
```csv
{communities_context}
```
-----Entities-----
```csv
{entities_context}
```
-----Relationships-----
```csv
{relations_context}
```
-----Sources-----
```csv
{text_units_context}
```
"""


async def _build_hierarchical_query_context(
    query,
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
    community_reports: BaseKVStorage[CommunitySchema],
    text_chunks_db: BaseKVStorage[TextChunkSchema],
    query_param: QueryParam,
):
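    """Build the hierarchical query context: top-k entities, their related
    communities, a reasoning path through key entities across those
    communities, and the source text chunks, formatted as CSV sections.
    """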
    results = await entities_vdb.query(query, top_k=query_param.top_k * 10)  # find the top-k related entities (over-fetched by 10x)

    if not len(results):  # results only carry the entity names
        return None
    node_datas = await asyncio.gather(  # get the full information of the retrieved entities
        *[knowledge_graph_inst.get_node(r["entity_name"]) for r in results]
    )
    if not all([n is not None for n in node_datas]):  # for robustness
        logger.warning("Some nodes are missing, maybe the storage is damaged")
    node_degrees = await asyncio.gather(
        *[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results]
    )
    node_datas = [  # add rank, which is the node degree
        {**n, "entity_name": k["entity_name"], "rank": d}
        for k, n, d in zip(results, node_datas, node_degrees)
        if n is not None
    ]
    overall_node_datas = node_datas
    node_datas = node_datas[:query_param.top_k]

    use_communities = await _find_most_related_community_from_entities(  # related communities
        node_datas, query_param, community_reports
    )
    use_text_units = await _find_most_related_text_unit_from_entities(
        node_datas, query_param, text_chunks_db, knowledge_graph_inst
    )
    # use_relations = await _find_most_related_edges_from_entities(
    #     node_datas, query_param, knowledge_graph_inst
    # )

    def find_path_with_required_nodes(graph, source, target, required_nodes):
        # initialize the final path
        final_path = []
        # start from the source node
        current_node = source

        # walk through the required nodes
        for next_node in required_nodes:
            # find the shortest path from the current node to the next required node
            try:
                sub_path = nx.shortest_path(graph, source=current_node, target=next_node)
            except nx.NetworkXNoPath:
                # raise ValueError(f"No path between {current_node} and {next_node}.")
                final_path.extend([next_node])
                current_node = next_node
                continue

            # merge the paths (avoid adding the current node twice)
            if final_path:
                final_path.extend(sub_path[1:])  # start from the second node to avoid duplication
            else:
                final_path.extend(sub_path)

            # advance the current node to this required node
            current_node = next_node

        # finally, the path from the last required node to the target node
        try:
            sub_path = nx.shortest_path(graph, source=current_node, target=target)
            final_path.extend(sub_path[1:])  # start from the second node to avoid duplication
        except nx.NetworkXNoPath:
            # raise ValueError(f"No path between {current_node} and {target}.")
            final_path.extend([target])

        return final_path

    # find the top-m entities within each community in use_communities
    key_entities = []
    max_entity_num = query_param.top_m
    if use_communities:
        for c in use_communities:
            cur_community_key_entities = []
            community_entities = c["nodes"]
            # find the top-m entities in this community
            cur_community_key_entities.extend(
                [e for e in overall_node_datas if e["entity_name"] in community_entities][:max_entity_num]
            )
            key_entities.append(cur_community_key_entities)
    else:
        key_entities = [overall_node_datas[:max_entity_num]]
    # deduplicate the key entities
    key_entities = [[e["entity_name"] for e in k] for k in key_entities]
    key_entities = list(set([k for kk in key_entities for k in kk]))
    # find the shortest path through the key entities
    use_reasoning_path = []  # default so the context can still be built if path finding fails
    try:
        path = find_path_with_required_nodes(knowledge_graph_inst._graph, key_entities[0], key_entities[-1], key_entities[1:-1])
        # path = list(set(path))
        path_datas = await asyncio.gather(  # get the full information of the path entities
            *[knowledge_graph_inst.get_node(r) for r in path]
        )
        path_degrees = await asyncio.gather(
            *[knowledge_graph_inst.node_degree(r) for r in path]
        )
        path_datas = [  # add rank, which is the node degree
            {**n, "entity_name": k, "rank": d}
            for k, n, d in zip(path, path_datas, path_degrees)
            if n is not None
        ]
        # use_reasoning_path = await _find_most_related_edges_from_entities(
        #     path_datas, query_param, knowledge_graph_inst
        # )
        use_reasoning_path = await _find_most_related_edges_from_paths(
            path_datas, path, query_param, knowledge_graph_inst
        )
    except ValueError as e:
        print(e)

    # # fetch the relations of the reasoning paths
    # reasoning_path = []
    # for i in range(len(path) - 1):
    #     src = path[i]
    #     tgt = path[i + 1]
    #     cur_relation = (await knowledge_graph_inst.get_edge(src, tgt))['description']
    #     reasoning_path.append(cur_relation)
    # reasoning_path = list(set(reasoning_path))

    logger.info(
        f"Using {len(node_datas)} entities, {len(use_communities)} communities, {len(use_reasoning_path)} reasoning path items, {len(use_text_units)} text units"
    )
    entities_section_list = [["id", "entity", "type", "description", "rank"]]
    for i, n in enumerate(node_datas):
        entities_section_list.append(
            [
                i,
                n["entity_name"],
                n.get("entity_type", "UNKNOWN"),
                n.get("description", "UNKNOWN"),
                n["rank"],
            ]
        )
    entities_context = list_of_list_to_csv(entities_section_list)

    reasoning_path_section_list = [
        ["id", "source", "target", "description", "weight", "rank"]
    ]
    for i, e in enumerate(use_reasoning_path):
        reasoning_path_section_list.append(
            [
                i,
                e["src_tgt"][0],
                e["src_tgt"][1],
                e["description"],
                e["weight"],
                e["rank"],
            ]
        )
    reasoning_path_context = list_of_list_to_csv(reasoning_path_section_list)

    # reasoning_path_context = list_of_list_to_csv([["id", "content"]] + [[i, p] for i, p in enumerate(reasoning_path)])

    communities_section_list = [["id", "content"]]
    for i, c in enumerate(use_communities):
        communities_section_list.append([i, c["report_string"].replace("\n", " ")])
    communities_context = list_of_list_to_csv(communities_section_list)

    text_units_section_list = [["id", "content"]]
    for i, t in enumerate(use_text_units):
        text_units_section_list.append([i, t["content"]])
    text_units_context = list_of_list_to_csv(text_units_section_list)

    # display reference info
    entities = [n["entity_name"] for n in node_datas]
    communities = [(c["level"], c["title"]) for c in use_communities]
    chunks = [(t["full_doc_id"], t["chunk_order_index"]) for t in use_text_units]

    references_context = (
        f"Entities ({len(entities)}): {entities}\n\n"
        f"Communities (level, cluster_id) ({len(communities)}): {communities}\n\n"
        f"Chunks (doc_id, chunk_index) ({len(chunks)}): {chunks}\n"
    )

    logging.info(f"====== References ======:\n{references_context}")
    return f"""
-----Backgrounds-----
```csv
{communities_context}
```
-----Reasoning Path-----
```csv
{reasoning_path_context}
```
-----Detail Entity Information-----
```csv
{entities_context}
```
-----Source Documents-----
```csv
{text_units_context}
```
"""


async def _build_hibridge_query_context(
    query,
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
    community_reports: BaseKVStorage[CommunitySchema],
    text_chunks_db: BaseKVStorage[TextChunkSchema],
    query_param: QueryParam,
):
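    """Build the bridge query context: only the reasoning path between key
    entities and the source text chunks are returned, formatted as CSV
    sections.
    """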
|
|
results = await entities_vdb.query(query, top_k=query_param.top_k * 10) # find the top-k(20) related entities
|
|
|
|
if not len(results): # results just with entity name
|
|
return None
|
|
node_datas = await asyncio.gather( # get full information of retrieved entities
|
|
*[knowledge_graph_inst.get_node(r["entity_name"]) for r in results]
|
|
)
|
|
if not all([n is not None for n in node_datas]): # for robustness
|
|
logger.warning("Some nodes are missing, maybe the storage is damaged")
|
|
node_degrees = await asyncio.gather(
|
|
*[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results]
|
|
)
|
|
node_datas = [ # add rank, which is the degree
|
|
{**n, "entity_name": k["entity_name"], "rank": d}
|
|
for k, n, d in zip(results, node_datas, node_degrees)
|
|
if n is not None
|
|
]
|
|
overall_node_datas = node_datas
|
|
node_datas = node_datas[:query_param.top_k]
|
|
|
|
use_communities = await _find_most_related_community_from_entities( # related communities
|
|
node_datas, query_param, community_reports
|
|
)
|
|
use_text_units = await _find_most_related_text_unit_from_entities(
|
|
node_datas, query_param, text_chunks_db, knowledge_graph_inst
|
|
)
|
|
# use_relations = await _find_most_related_edges_from_entities(
|
|
# node_datas, query_param, knowledge_graph_inst
|
|
# )
|
|
|
|
def find_path_with_required_nodes(graph, source, target, required_nodes):
|
|
# inital final path
|
|
final_path = []
|
|
# 起点设置为当前节点
|
|
current_node = source
|
|
|
|
# 遍历必经节点
|
|
for next_node in required_nodes:
|
|
# 找到从当前节点到下一个必经节点的最短路径
|
|
try:
|
|
sub_path = nx.shortest_path(graph, source=current_node, target=next_node)
|
|
except nx.NetworkXNoPath:
|
|
# raise ValueError(f"No path between {current_node} and {next_node}.")
|
|
final_path.extend([next_node])
|
|
current_node = next_node
|
|
continue
|
|
|
|
# 合并路径(避免重复添加当前节点)
|
|
if final_path:
|
|
final_path.extend(sub_path[1:]) # 从第二个节点开始添加,避免重复
|
|
else:
|
|
final_path.extend(sub_path)
|
|
|
|
# 更新当前节点为下一个必经节点
|
|
current_node = next_node
|
|
|
|
# 最后,从最后一个必经节点到目标节点的路径
|
|
try:
|
|
sub_path = nx.shortest_path(graph, source=current_node, target=target)
|
|
final_path.extend(sub_path[1:]) # 从第二个节点开始添加,避免重复
|
|
except nx.NetworkXNoPath:
|
|
# raise ValueError(f"No path between {current_node} and {target}.")
|
|
final_path.extend([target])
|
|
|
|
return final_path
|
|
|
|
# find some top-k entities in each communities in use_communities
|
|
key_entities = []
|
|
max_entity_num = query_param.top_m
|
|
if use_communities:
|
|
for c in use_communities:
|
|
cur_community_key_entities = []
|
|
community_entities = c['nodes']
|
|
# find the top-k entities in this community
|
|
cur_community_key_entities.extend(
|
|
[e for e in overall_node_datas if e['entity_name'] in community_entities][:max_entity_num]
|
|
)
|
|
key_entities.append(cur_community_key_entities)
|
|
else:
|
|
key_entities = [overall_node_datas[:max_entity_num]]
|
|
# unique key entities
|
|
key_entities = [[e['entity_name'] for e in k] for k in key_entities]
|
|
key_entities = list(set([k for kk in key_entities for k in kk]))
|
|
# find the shortest path between the key entities
|
|
try:
|
|
path = find_path_with_required_nodes(knowledge_graph_inst._graph, key_entities[0], key_entities[-1], key_entities[1:-1])
|
|
# path = list(set(path))
|
|
path_datas = await asyncio.gather( # get full information of retrieved entities
|
|
*[knowledge_graph_inst.get_node(r) for r in path]
|
|
)
|
|
path_degrees = await asyncio.gather(
|
|
*[knowledge_graph_inst.node_degree(r) for r in path]
|
|
)
|
|
path_datas = [ # add rank, which is the degree
|
|
{**n, "entity_name": k, "rank": d}
|
|
for k, n, d in zip(path, path_datas, path_degrees)
|
|
if n is not None
|
|
]
|
|
use_reasoning_path = await _find_most_related_edges_from_paths(
|
|
path_datas, path, query_param, knowledge_graph_inst
|
|
)
|
|
except ValueError as e:
|
|
print(e)

    logger.info(
        f"Using {len(node_datas)} entities, {len(use_communities)} communities, {len(use_reasoning_path)} reasoning path items, {len(use_text_units)} text units"
    )
    # entities and communities are assembled below, but only the reasoning path
    # and the source documents are emitted in the returned context
    entities_section_list = [["id", "entity", "type", "description", "rank"]]
    for i, n in enumerate(node_datas):
        entities_section_list.append(
            [
                i,
                n["entity_name"],
                n.get("entity_type", "UNKNOWN"),
                n.get("description", "UNKNOWN"),
                n["rank"],
            ]
        )
    entities_context = list_of_list_to_csv(entities_section_list)

    reasoning_path_section_list = [
        ["id", "source", "target", "description", "weight", "rank"]
    ]
    for i, e in enumerate(use_reasoning_path):
        reasoning_path_section_list.append(
            [
                i,
                e["src_tgt"][0],
                e["src_tgt"][1],
                e["description"],
                e["weight"],
                e["rank"],
            ]
        )
    reasoning_path_context = list_of_list_to_csv(reasoning_path_section_list)

    # reasoning_path_context = list_of_list_to_csv([["id", "content"]] + [[i, p] for i, p in enumerate(reasoning_path)])

    communities_section_list = [["id", "content"]]
    for i, c in enumerate(use_communities):
        communities_section_list.append([i, c["report_string"]])
    communities_context = list_of_list_to_csv(communities_section_list)

    text_units_section_list = [["id", "content"]]
    for i, t in enumerate(use_text_units):
        text_units_section_list.append([i, t["content"]])
    text_units_context = list_of_list_to_csv(text_units_section_list)
    return f"""
-----Reasoning Path-----
```csv
{reasoning_path_context}
```
-----Source Documents-----
```csv
{text_units_context}
```
"""


async def _build_higlobal_query_context(
    query,
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
    community_reports: BaseKVStorage[CommunitySchema],
    text_chunks_db: BaseKVStorage[TextChunkSchema],
    query_param: QueryParam,
):
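    """Build the global retrieval context: community reports plus source text units.

    Entities are retrieved only to locate their communities and related chunks;
    they are not emitted in the returned context themselves.
    """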
    results = await entities_vdb.query(query, top_k=query_param.top_k * 10)  # retrieve an enlarged candidate pool of related entities

    if not len(results):  # results just carry entity names
        return None
    node_datas = await asyncio.gather(  # get full information of retrieved entities
        *[knowledge_graph_inst.get_node(r["entity_name"]) for r in results]
    )
    if not all([n is not None for n in node_datas]):  # for robustness
        logger.warning("Some nodes are missing, maybe the storage is damaged")
    node_degrees = await asyncio.gather(
        *[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results]
    )
    node_datas = [  # add rank, which is the degree
        {**n, "entity_name": k["entity_name"], "rank": d}
        for k, n, d in zip(results, node_datas, node_degrees)
        if n is not None
    ]
    overall_node_datas = node_datas  # not used further in this builder
    node_datas = node_datas[:query_param.top_k]

    use_communities = await _find_most_related_community_from_entities(  # related communities
        node_datas, query_param, community_reports
    )
    use_text_units = await _find_most_related_text_unit_from_entities(
        node_datas, query_param, text_chunks_db, knowledge_graph_inst
    )

    logger.info(
        f"Using {len(use_communities)} communities, {len(use_text_units)} text units"
    )

    communities_section_list = [["id", "content"]]
    for i, c in enumerate(use_communities):
        communities_section_list.append([i, c["report_string"]])
    communities_context = list_of_list_to_csv(communities_section_list)

    text_units_section_list = [["id", "content"]]
    for i, t in enumerate(use_text_units):
        text_units_section_list.append([i, t["content"]])
    text_units_context = list_of_list_to_csv(text_units_section_list)
    return f"""
-----Backgrounds-----
```csv
{communities_context}
```
-----Source Documents-----
```csv
{text_units_context}
```
"""


async def _build_hilocal_query_context(
    query,
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
    text_chunks_db: BaseKVStorage[TextChunkSchema],
    query_param: QueryParam,
):
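    """Build the local retrieval context: entities, their relations, and source text units."""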
    results = await entities_vdb.query(query, top_k=query_param.top_k)  # find the top-k related entities

    if not len(results):  # results just carry entity names
        return None
    node_datas = await asyncio.gather(  # get full information of retrieved entities
        *[knowledge_graph_inst.get_node(r["entity_name"]) for r in results]
    )
    if not all([n is not None for n in node_datas]):  # for robustness
        logger.warning("Some nodes are missing, maybe the storage is damaged")
    node_degrees = await asyncio.gather(
        *[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results]
    )
    node_datas = [  # add rank, which is the degree
        {**n, "entity_name": k["entity_name"], "rank": d}
        for k, n, d in zip(results, node_datas, node_degrees)
        if n is not None
    ]

    use_text_units = await _find_most_related_text_unit_from_entities(
        node_datas, query_param, text_chunks_db, knowledge_graph_inst
    )
    use_relations = await _find_most_related_edges_from_entities(
        node_datas, query_param, knowledge_graph_inst
    )

    logger.info(
        f"Using {len(node_datas)} entities, {len(use_relations)} relations, {len(use_text_units)} text units"
    )
    entities_section_list = [["id", "entity", "type", "description", "rank"]]
    for i, n in enumerate(node_datas):
        entities_section_list.append(
            [
                i,
                n["entity_name"],
                n.get("entity_type", "UNKNOWN"),
                n.get("description", "UNKNOWN"),
                n["rank"],
            ]
        )
    entities_context = list_of_list_to_csv(entities_section_list)

    relation_section_list = [
        ["id", "source", "target", "description", "weight", "rank"]
    ]
    for i, e in enumerate(use_relations):
        relation_section_list.append(
            [
                i,
                e["src_tgt"][0],
                e["src_tgt"][1],
                e["description"],
                e["weight"],
                e["rank"],
            ]
        )
    relation_context = list_of_list_to_csv(relation_section_list)

    text_units_section_list = [["id", "content"]]
    for i, t in enumerate(use_text_units):
        text_units_section_list.append([i, t["content"]])
    text_units_context = list_of_list_to_csv(text_units_section_list)
    return f"""
-----Entities-----
```csv
{entities_context}
```
-----Relations-----
```csv
{relation_context}
```
-----Sources-----
```csv
{text_units_context}
```
"""


# query functions
async def hierarchical_query(
    query,
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
    community_reports: BaseKVStorage[CommunitySchema],
    text_chunks_db: BaseKVStorage[TextChunkSchema],
    query_param: QueryParam,
    global_config: dict,
) -> str:
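    """Answer a query using the full hierarchical retrieval context."""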
    use_model_func = global_config["best_model_func"]
    with timer():  # log only the retrieval time
        context = await _build_hierarchical_query_context(
            query,
            knowledge_graph_inst,
            entities_vdb,
            community_reports,
            text_chunks_db,
            query_param,
        )
    if query_param.only_need_context:
        return context
    if context is None:
        return PROMPTS["fail_response"]
    sys_prompt_temp = PROMPTS["local_rag_response"]
    sys_prompt = sys_prompt_temp.format(
        context_data=context, response_type=query_param.response_type
    )
    response = await use_model_func(
        query,
        system_prompt=sys_prompt,
    )
    return response


async def hierarchical_bridge_query(
    query,
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
    community_reports: BaseKVStorage[CommunitySchema],
    text_chunks_db: BaseKVStorage[TextChunkSchema],
    query_param: QueryParam,
    global_config: dict,
) -> str:
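    """Answer a query using the bridge context (reasoning path between key entities plus sources)."""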
    use_model_func = global_config["best_model_func"]
    with timer():  # log only the retrieval time
        context = await _build_hibridge_query_context(
            query,
            knowledge_graph_inst,
            entities_vdb,
            community_reports,
            text_chunks_db,
            query_param,
        )
    if query_param.only_need_context:
        return context
    if context is None:
        return PROMPTS["fail_response"]
    sys_prompt_temp = PROMPTS["local_rag_response"]
    sys_prompt = sys_prompt_temp.format(
        context_data=context, response_type=query_param.response_type
    )
    response = await use_model_func(
        query,
        system_prompt=sys_prompt,
    )
    return response


async def hierarchical_local_query(
    query,
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
    community_reports: BaseKVStorage[CommunitySchema],
    text_chunks_db: BaseKVStorage[TextChunkSchema],
    query_param: QueryParam,
    global_config: dict,
) -> str:
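    """Answer a query using the local context (entities, relations, and sources)."""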
    use_model_func = global_config["best_model_func"]
    with timer():  # log only the retrieval time
        # note: _build_hilocal_query_context takes no community reports (see its signature above)
        context = await _build_hilocal_query_context(
            query,
            knowledge_graph_inst,
            entities_vdb,
            text_chunks_db,
            query_param,
        )
    if query_param.only_need_context:
        return context
    if context is None:
        return PROMPTS["fail_response"]
    sys_prompt_temp = PROMPTS["local_rag_response"]
    sys_prompt = sys_prompt_temp.format(
        context_data=context, response_type=query_param.response_type
    )
    response = await use_model_func(
        query,
        system_prompt=sys_prompt,
    )
    return response


async def hierarchical_global_query(
    query,
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
    community_reports: BaseKVStorage[CommunitySchema],
    text_chunks_db: BaseKVStorage[TextChunkSchema],
    query_param: QueryParam,
    global_config: dict,
) -> str:
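    """Answer a query using the global context (community reports plus sources)."""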
    use_model_func = global_config["best_model_func"]
    with timer():  # log only the retrieval time
        context = await _build_higlobal_query_context(
            query,
            knowledge_graph_inst,
            entities_vdb,
            community_reports,
            text_chunks_db,
            query_param,
        )
    if query_param.only_need_context:
        return context
    if context is None:
        return PROMPTS["fail_response"]
    sys_prompt_temp = PROMPTS["local_rag_response"]
    sys_prompt = sys_prompt_temp.format(
        context_data=context, response_type=query_param.response_type
    )
    response = await use_model_func(
        query,
        system_prompt=sys_prompt,
    )
    return response


async def hierarchical_nobridge_query(
    query,
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
    community_reports: BaseKVStorage[CommunitySchema],
    text_chunks_db: BaseKVStorage[TextChunkSchema],
    query_param: QueryParam,
    global_config: dict,
) -> str:
    """Retrieve with only query-related entities; no bridge reasoning path is built."""
use_model_func = global_config["best_model_func"]
|
|
with timer():
|
|
context = await _build_local_query_context(
|
|
query,
|
|
knowledge_graph_inst,
|
|
entities_vdb,
|
|
community_reports,
|
|
text_chunks_db,
|
|
query_param,
|
|
)
|
|
if query_param.only_need_context:
|
|
return context
|
|
if context is None:
|
|
return PROMPTS["fail_response"]
|
|
sys_prompt_temp = PROMPTS["local_rag_response"]
|
|
sys_prompt = sys_prompt_temp.format(
|
|
context_data=context, response_type=query_param.response_type
|
|
)
|
|
response = await use_model_func(
|
|
query,
|
|
system_prompt=sys_prompt,
|
|
)
|
|
return response
|


async def naive_query(
    query,
    chunks_vdb: BaseVectorStorage,
    text_chunks_db: BaseKVStorage[TextChunkSchema],
    query_param: QueryParam,
    global_config: dict,
):
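    """Vanilla RAG baseline: retrieve top-k chunks by vector similarity and answer from them."""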
    use_model_func = global_config["best_model_func"]
    with timer():  # log only the retrieval time
        results = await chunks_vdb.query(query, top_k=query_param.top_k)
        if not len(results):
            return PROMPTS["fail_response"]
        chunks_ids = [r["id"] for r in results]
        chunks = await text_chunks_db.get_by_ids(chunks_ids)

        maybe_trun_chunks = truncate_list_by_token_size(
            chunks,
            key=lambda x: x["content"],
            max_token_size=query_param.naive_max_token_for_text_unit,
        )
        logger.info(f"Truncate {len(chunks)} to {len(maybe_trun_chunks)} chunks")
        section = "--New Chunk--\n".join([c["content"] for c in maybe_trun_chunks])
    if query_param.only_need_context:
        return section
    sys_prompt_temp = PROMPTS["naive_rag_response"]
    sys_prompt = sys_prompt_temp.format(
        content_data=section, response_type=query_param.response_type
    )
    response = await use_model_func(
        query,
        system_prompt=sys_prompt,
    )
    return response
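

# A minimal usage sketch of the baseline above (hypothetical storage instances
# and model function, not part of this module):
#
#   import asyncio
#   answer = asyncio.run(
#       naive_query(
#           "What does HiRAG retrieve?",
#           chunks_vdb=my_chunks_vdb,          # any BaseVectorStorage implementation
#           text_chunks_db=my_text_chunks_db,  # any BaseKVStorage[TextChunkSchema]
#           query_param=QueryParam(),
#           global_config={"best_model_func": my_llm_func},
#       )
#   )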