first commit

This commit is contained in:
haoyuhuang
2025-03-14 11:13:03 +08:00
commit 9336b47ae7
98 changed files with 96220 additions and 0 deletions

1
hirag/__init__.py Normal file
View File

@@ -0,0 +1 @@
from .hirag import HiRAG, QueryParam

353
hirag/_cluster_utils.py Normal file
View File

@@ -0,0 +1,353 @@
import logging
import random
import re
import numpy as np
import tiktoken
import umap
import copy
import asyncio
from abc import ABC, abstractmethod
from typing import List, Optional
from sklearn.mixture import GaussianMixture
from tqdm import tqdm
from collections import Counter, defaultdict
from .base import (
BaseGraphStorage,
BaseKVStorage,
BaseVectorStorage
)
from ._utils import split_string_by_multi_markers, clean_str, is_float_regex
from .prompt import GRAPH_FIELD_SEP, PROMPTS
# Initialize logging
logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO)
# Set a random seed for reproducibility
RANDOM_SEED = 224
random.seed(RANDOM_SEED)
def global_cluster_embeddings(
embeddings: np.ndarray,
dim: int,
n_neighbors: int = 15,
metric: str = "cosine",
) -> np.ndarray:
if n_neighbors is None:
n_neighbors = int((len(embeddings) - 1) ** 0.5)
reduced_embeddings = umap.UMAP(
n_neighbors=n_neighbors, n_components=dim, metric=metric
).fit_transform(embeddings)
return reduced_embeddings
def local_cluster_embeddings(
embeddings: np.ndarray, dim: int, num_neighbors: int = 10, metric: str = "cosine"
) -> np.ndarray:
reduced_embeddings = umap.UMAP(
n_neighbors=num_neighbors, n_components=dim, metric=metric
).fit_transform(embeddings)
return reduced_embeddings
def fit_gaussian_mixture(n_components, embeddings, random_state):
gm = GaussianMixture(
n_components=n_components,
random_state=random_state,
n_init=5,
init_params='k-means++'
)
gm.fit(embeddings)
return gm.bic(embeddings)
def get_optimal_clusters(embeddings, max_clusters=50, random_state=0, rel_tol=1e-3):
max_clusters = min(len(embeddings), max_clusters)
n_clusters = np.arange(1, max_clusters)
bics = []
prev_bic = float('inf')
for n in tqdm(n_clusters):
bic = fit_gaussian_mixture(n, embeddings, random_state)
# print(bic)
bics.append(bic)
# early stop
if (abs(prev_bic - bic) / abs(prev_bic)) < rel_tol:
break
prev_bic = bic
optimal_clusters = n_clusters[np.argmin(bics)]
return optimal_clusters
def GMM_cluster(embeddings: np.ndarray, threshold: float, random_state: int = 0):
n_clusters = get_optimal_clusters(embeddings)
gm = GaussianMixture(
n_components=n_clusters,
random_state=random_state,
n_init=5,
init_params='k-means++')
gm.fit(embeddings)
probs = gm.predict_proba(embeddings) # [num, cluster_num]
labels = [np.where(prob > threshold)[0] for prob in probs]
return labels, n_clusters
def perform_clustering(
embeddings: np.ndarray, dim: int, threshold: float, verbose: bool = False
) -> List[np.ndarray]:
reduced_embeddings_global = global_cluster_embeddings(embeddings, min(dim, len(embeddings) -2))
global_clusters, n_global_clusters = GMM_cluster( # (num, 2)
reduced_embeddings_global, threshold
)
if verbose:
logging.info(f"Global Clusters: {n_global_clusters}")
all_clusters = [[] for _ in range(len(embeddings))]
embedding_to_index = {tuple(embedding): idx for idx, embedding in enumerate(embeddings)}
for i in tqdm(range(n_global_clusters)):
global_cluster_embeddings_ = embeddings[
np.array([i in gc for gc in global_clusters])
]
if verbose:
logging.info(
f"Nodes in Global Cluster {i}: {len(global_cluster_embeddings_)}"
)
if len(global_cluster_embeddings_) == 0:
continue
# embedding indices
indices = [
embedding_to_index[tuple(embedding)]
for embedding in global_cluster_embeddings_
]
# update
for idx in indices:
all_clusters[idx].append(i)
all_clusters = [np.array(cluster) for cluster in all_clusters]
if verbose:
logging.info(f"Total Clusters: {len(n_global_clusters)}")
return all_clusters
async def _handle_single_entity_extraction(
record_attributes: list[str],
chunk_key: str,
):
if len(record_attributes) < 4 or record_attributes[0] != '"entity"':
return None
# add this record as a node in the G
entity_name = clean_str(record_attributes[1].upper())
if not entity_name.strip():
return None
entity_type = clean_str(record_attributes[2].upper())
entity_description = clean_str(record_attributes[3])
entity_source_id = chunk_key
return dict(
entity_name=entity_name,
entity_type=entity_type,
description=entity_description,
source_id=entity_source_id,
)
async def _handle_single_relationship_extraction(
record_attributes: list[str],
chunk_key: str,
):
if len(record_attributes) < 5 or record_attributes[0] != '"relationship"':
return None
# add this record as edge
source = clean_str(record_attributes[1].upper())
target = clean_str(record_attributes[2].upper())
edge_description = clean_str(record_attributes[3])
edge_source_id = chunk_key
weight = (
float(record_attributes[-1]) if is_float_regex(record_attributes[-1]) else 1.0
)
return dict(
src_id=source,
tgt_id=target,
weight=weight,
description=edge_description,
source_id=edge_source_id,
)
class ClusteringAlgorithm(ABC):
@abstractmethod
def perform_clustering(self, embeddings: np.ndarray, **kwargs) -> List[List[int]]:
pass
class Hierarchical_Clustering(ClusteringAlgorithm):
async def perform_clustering(
self,
entity_vdb: BaseVectorStorage,
global_config: dict,
entities: dict,
layers: int = 50,
max_length_in_cluster: int = 60000,
tokenizer=tiktoken.get_encoding("cl100k_base"),
reduction_dimension: int = 2,
cluster_threshold: float = 0.1,
verbose: bool = False,
threshold: float = 0.98, # 0.99
thredshold_change_rate: float = 0.05
) -> List[dict]:
use_llm_func: callable = global_config["best_model_func"]
# Get the embeddings from the nodes
nodes = list(entities.values())
embeddings = np.array([x["embedding"] for x in nodes])
hierarchical_clusters = [nodes]
pre_cluster_sparsity = 0.01
for layer in range(layers):
logging.info(f"############ Layer[{layer}] Clustering ############")
# Perform the clustering
clusters = perform_clustering(
embeddings, dim=reduction_dimension, threshold=cluster_threshold
)
# Initialize an empty list to store the clusters of nodes
node_clusters = []
# Iterate over each unique label in the clusters
unique_clusters = np.unique(np.concatenate(clusters))
logging.info(f"[Clustered Label Num: {len(unique_clusters)} / Last Layer Total Entity Num: {len(nodes)}]")
# calculate the number of nodes belong to each cluster
cluster_sizes = Counter(np.concatenate(clusters))
# calculate cluster sparsity
cluster_sparsity = 1 - sum([x * (x - 1) for x in cluster_sizes.values()])/(len(nodes) * (len(nodes) - 1))
cluster_sparsity_change_rate = (abs(cluster_sparsity - pre_cluster_sparsity) / pre_cluster_sparsity)
pre_cluster_sparsity = cluster_sparsity
logging.info(f"[Cluster Sparsity: {round(cluster_sparsity, 4) * 100}%]")
# stop if there will be no improvements on clustering
if cluster_sparsity >= threshold:
logging.info(f"[Stop Clustering at Layer{layer} with Cluster Sparsity {cluster_sparsity}]")
break
if cluster_sparsity_change_rate <= thredshold_change_rate:
logging.info(f"[Stop Clustering at Layer{layer} with Cluster Sparsity Change Rate {round(cluster_sparsity_change_rate, 4) * 100}%]")
break
# summarize
for label in unique_clusters:
# Get the indices of the nodes that belong to this cluster
indices = [i for i, cluster in enumerate(clusters) if label in cluster]
# Add the corresponding nodes to the node_clusters list
cluster_nodes = [nodes[i] for i in indices]
# Base case: if the cluster only has one node, do not attempt to recluster it
logging.info(f"[Label{str(int(label))} Size: {len(cluster_nodes)}]")
if len(cluster_nodes) == 1:
node_clusters += cluster_nodes
continue
# Calculate the total length of the text in the nodes
total_length = sum(
[len(tokenizer.encode(node["description"])) + len(tokenizer.encode(node["entity_name"])) for node in cluster_nodes]
)
base_discount = 0.8
discount_times = 0
# If the total length exceeds the maximum allowed length, reduce the node size
while total_length > max_length_in_cluster:
logging.info(
f"Reducing cluster size with {base_discount * 100 * (base_discount**discount_times):.2f}% of entities"
)
# for node in cluster_nodes:
# description = node["description"]
# node['description'] = description[:int(len(description) * base_discount)]
# Randomly select 80% of the nodes
num_to_select = max(1, int(len(cluster_nodes) * base_discount)) # Ensure at least one node is selected
cluster_nodes = random.sample(cluster_nodes, num_to_select)
# Recalculate the total length
total_length = sum(
[len(tokenizer.encode(node["description"])) + len(tokenizer.encode(node["entity_name"])) for node in cluster_nodes]
)
discount_times += 1
# summarize and generate new entities
entity_description_list = [f"({x['entity_name']}, {x['description']})" for x in cluster_nodes]
context_base_summarize = dict(
tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"],
completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
meta_attribute_list=PROMPTS["META_ENTITY_TYPES"],
entity_description_list=",".join(entity_description_list)
)
summarize_prompt = PROMPTS["summary_clusters"]
hint_prompt = summarize_prompt.format(**context_base_summarize)
summarize_result = await use_llm_func(hint_prompt)
chunk_key = ""
# resolve results
records = split_string_by_multi_markers( # split entities from result --> list of entities
summarize_result,
[context_base_summarize["record_delimiter"], context_base_summarize["completion_delimiter"]],
)
maybe_nodes = defaultdict(list)
maybe_edges = defaultdict(list)
for record in records:
record = re.search(r"\((.*)\)", record)
if record is None:
continue
record = record.group(1)
record_attributes = split_string_by_multi_markers( # split entity
record, [context_base_summarize["tuple_delimiter"]]
)
if_entities = await _handle_single_entity_extraction( # get the name, type, desc, source_id of entity--> dict
record_attributes, chunk_key
)
if if_entities is not None:
maybe_nodes[if_entities["entity_name"]].append(if_entities)
continue
if_relation = await _handle_single_relationship_extraction(
record_attributes, chunk_key
)
if if_relation is not None:
maybe_edges[(if_relation["src_id"], if_relation["tgt_id"])].append(
if_relation
)
# fetch all entities from results
entity_results = (dict(maybe_nodes), dict(maybe_edges))
all_entities_relations = {}
for item in entity_results:
for k, v in item.items():
value = v[0]
all_entities_relations[k] = v[0]
# fetch embeddings
entity_discriptions = [v["description"] for k, v in all_entities_relations.items()]
entity_sequence_embeddings = []
embeddings_batch_size = 64
num_embeddings_batches = (len(entity_discriptions) + embeddings_batch_size - 1) // embeddings_batch_size
for i in range(num_embeddings_batches):
start_index = i * embeddings_batch_size
end_index = min((i + 1) * embeddings_batch_size, len(entity_discriptions))
batch = entity_discriptions[start_index:end_index]
result = await entity_vdb.embedding_func(batch)
entity_sequence_embeddings.extend(result)
entity_embeddings = entity_sequence_embeddings
for (k, v), x in zip(all_entities_relations.items(), entity_embeddings):
value = v
value["embedding"] = x
all_entities_relations[k] = value
# append the attribute entities of current clustered set to results
all_entities_relations = [v for k, v in all_entities_relations.items()]
node_clusters += all_entities_relations
hierarchical_clusters.append(node_clusters)
# update nodes to be clustered in the next layer
nodes = copy.deepcopy([x for x in node_clusters if "entity_name" in x.keys()])
# filter the duplicate entities
seen = set()
unique_nodes = []
for item in nodes:
entity_name = item['entity_name']
if entity_name not in seen:
seen.add(entity_name)
unique_nodes.append(item)
nodes = unique_nodes
embeddings = np.array([x["embedding"] for x in unique_nodes])
# stop if the number of deduplicated cluster is too small
if len(embeddings) <= 2:
logging.info(f"[Stop Clustering at Layer{layer} with entity num {len(embeddings)}]")
break
return hierarchical_clusters

189
hirag/_llm.py Normal file
View File

@@ -0,0 +1,189 @@
import numpy as np
from openai import AsyncOpenAI, AsyncAzureOpenAI, APIConnectionError, RateLimitError
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
import os
from ._utils import compute_args_hash, wrap_embedding_func_with_attrs
from .base import BaseKVStorage
global_openai_async_client = None
global_azure_openai_async_client = None
def get_openai_async_client_instance():
global global_openai_async_client
if global_openai_async_client is None:
global_openai_async_client = AsyncOpenAI()
return global_openai_async_client
def get_azure_openai_async_client_instance():
global global_azure_openai_async_client
if global_azure_openai_async_client is None:
global_azure_openai_async_client = AsyncAzureOpenAI()
return global_azure_openai_async_client
@retry(
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
)
async def openai_complete_if_cache(
model, prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
openai_async_client = get_openai_async_client_instance()
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
if hashing_kv is not None:
args_hash = compute_args_hash(model, messages)
if_cache_return = await hashing_kv.get_by_id(args_hash)
if if_cache_return is not None:
return if_cache_return["return"]
response = await openai_async_client.chat.completions.create(
model=model, messages=messages, **kwargs
)
if hashing_kv is not None:
await hashing_kv.upsert(
{args_hash: {"return": response.choices[0].message.content, "model": model}}
)
await hashing_kv.index_done_callback()
return response.choices[0].message.content
async def gpt_4o_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
return await openai_complete_if_cache(
"gpt-4o",
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
async def gpt_35_turbo_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
return await openai_complete_if_cache(
"gpt-3.5-turbo",
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
async def gpt_4o_mini_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
return await openai_complete_if_cache(
"gpt-4o-mini",
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
)
async def openai_embedding(texts: list[str]) -> np.ndarray:
openai_async_client = get_openai_async_client_instance()
response = await openai_async_client.embeddings.create(
model="text-embedding-3-small", input=texts, encoding_format="float"
)
return np.array([dp.embedding for dp in response.data])
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
)
async def azure_openai_complete_if_cache(
deployment_name, prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
azure_openai_client = get_azure_openai_async_client_instance()
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
if hashing_kv is not None:
args_hash = compute_args_hash(deployment_name, messages)
if_cache_return = await hashing_kv.get_by_id(args_hash)
if if_cache_return is not None:
return if_cache_return["return"]
response = await azure_openai_client.chat.completions.create(
model=deployment_name, messages=messages, **kwargs
)
if hashing_kv is not None:
await hashing_kv.upsert(
{
args_hash: {
"return": response.choices[0].message.content,
"model": deployment_name,
}
}
)
await hashing_kv.index_done_callback()
return response.choices[0].message.content
async def azure_gpt_4o_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
return await azure_openai_complete_if_cache(
"gpt-4o",
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
async def azure_gpt_4o_mini_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
return await azure_openai_complete_if_cache(
"gpt-4o-mini",
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
)
async def azure_openai_embedding(texts: list[str]) -> np.ndarray:
azure_openai_client = get_azure_openai_async_client_instance()
response = await azure_openai_client.embeddings.create(
model="text-embedding-3-small", input=texts, encoding_format="float"
)
return np.array([dp.embedding for dp in response.data])

1928
hirag/_op.py Normal file

File diff suppressed because it is too large Load Diff

94
hirag/_splitter.py Normal file
View File

@@ -0,0 +1,94 @@
from typing import List, Optional, Union, Literal
class SeparatorSplitter:
def __init__(
self,
separators: Optional[List[List[int]]] = None,
keep_separator: Union[bool, Literal["start", "end"]] = "end",
chunk_size: int = 4000,
chunk_overlap: int = 200,
length_function: callable = len,
):
self._separators = separators or []
self._keep_separator = keep_separator
self._chunk_size = chunk_size
self._chunk_overlap = chunk_overlap
self._length_function = length_function
def split_tokens(self, tokens: List[int]) -> List[List[int]]:
splits = self._split_tokens_with_separators(tokens)
return self._merge_splits(splits)
def _split_tokens_with_separators(self, tokens: List[int]) -> List[List[int]]:
splits = []
current_split = []
i = 0
while i < len(tokens):
separator_found = False
for separator in self._separators:
if tokens[i:i+len(separator)] == separator:
if self._keep_separator in [True, "end"]:
current_split.extend(separator)
if current_split:
splits.append(current_split)
current_split = []
if self._keep_separator == "start":
current_split.extend(separator)
i += len(separator)
separator_found = True
break
if not separator_found:
current_split.append(tokens[i])
i += 1
if current_split:
splits.append(current_split)
return [s for s in splits if s]
def _merge_splits(self, splits: List[List[int]]) -> List[List[int]]:
if not splits:
return []
merged_splits = []
current_chunk = []
for split in splits:
if not current_chunk:
current_chunk = split
elif self._length_function(current_chunk) + self._length_function(split) <= self._chunk_size:
current_chunk.extend(split)
else:
merged_splits.append(current_chunk)
current_chunk = split
if current_chunk:
merged_splits.append(current_chunk)
if len(merged_splits) == 1 and self._length_function(merged_splits[0]) > self._chunk_size:
return self._split_chunk(merged_splits[0])
if self._chunk_overlap > 0:
return self._enforce_overlap(merged_splits)
return merged_splits
def _split_chunk(self, chunk: List[int]) -> List[List[int]]:
result = []
for i in range(0, len(chunk), self._chunk_size - self._chunk_overlap):
new_chunk = chunk[i:i + self._chunk_size]
if len(new_chunk) > self._chunk_overlap: # 只有当 chunk 长度大于 overlap 时才添加
result.append(new_chunk)
return result
def _enforce_overlap(self, chunks: List[List[int]]) -> List[List[int]]:
result = []
for i, chunk in enumerate(chunks):
if i == 0:
result.append(chunk)
else:
overlap = chunks[i-1][-self._chunk_overlap:]
new_chunk = overlap + chunk
if self._length_function(new_chunk) > self._chunk_size:
new_chunk = new_chunk[:self._chunk_size]
result.append(new_chunk)
return result

View File

@@ -0,0 +1,5 @@
from .gdb_networkx import NetworkXStorage
from .gdb_neo4j import Neo4jStorage
from .vdb_hnswlib import HNSWVectorStorage
from .vdb_nanovectordb import NanoVectorDBStorage
from .kv_json import JsonKVStorage

330
hirag/_storage/gdb_neo4j.py Normal file
View File

@@ -0,0 +1,330 @@
import json
import asyncio
from collections import defaultdict
from neo4j import AsyncGraphDatabase
from dataclasses import dataclass
from typing import Union
from ..base import BaseGraphStorage, SingleCommunitySchema
from .._utils import logger
from ..prompt import GRAPH_FIELD_SEP
neo4j_lock = asyncio.Lock()
def make_path_idable(path):
return path.replace(".", "_").replace("/", "__").replace("-", "_")
@dataclass
class Neo4jStorage(BaseGraphStorage):
def __post_init__(self):
self.neo4j_url = self.global_config["addon_params"].get("neo4j_url", None)
self.neo4j_auth = self.global_config["addon_params"].get("neo4j_auth", None)
self.namespace = (
f"{make_path_idable(self.global_config['working_dir'])}__{self.namespace}"
)
logger.info(f"Using the label {self.namespace} for Neo4j as identifier")
if self.neo4j_url is None or self.neo4j_auth is None:
raise ValueError("Missing neo4j_url or neo4j_auth in addon_params")
self.async_driver = AsyncGraphDatabase.driver(
self.neo4j_url, auth=self.neo4j_auth
)
# async def create_database(self):
# async with self.async_driver.session() as session:
# try:
# constraints = await session.run("SHOW CONSTRAINTS")
# # TODO I don't know why CREATE CONSTRAINT IF NOT EXISTS still trigger error
# # so have to check if the constrain exists
# constrain_exists = False
# async for record in constraints:
# if (
# self.namespace in record["labelsOrTypes"]
# and "id" in record["properties"]
# and record["type"] == "UNIQUENESS"
# ):
# constrain_exists = True
# break
# if not constrain_exists:
# await session.run(
# f"CREATE CONSTRAINT FOR (n:{self.namespace}) REQUIRE n.id IS UNIQUE"
# )
# logger.info(f"Add constraint for namespace: {self.namespace}")
# except Exception as e:
# logger.error(f"Error accessing or setting up the database: {str(e)}")
# raise
async def _init_workspace(self):
await self.async_driver.verify_authentication()
await self.async_driver.verify_connectivity()
# TODOLater: create database if not exists always cause an error when async
# await self.create_database()
async def index_start_callback(self):
logger.info("Init Neo4j workspace")
await self._init_workspace()
async def has_node(self, node_id: str) -> bool:
async with self.async_driver.session() as session:
result = await session.run(
f"MATCH (n:{self.namespace}) WHERE n.id = $node_id RETURN COUNT(n) > 0 AS exists",
node_id=node_id,
)
record = await result.single()
return record["exists"] if record else False
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
async with self.async_driver.session() as session:
result = await session.run(
f"MATCH (s:{self.namespace})-[r]->(t:{self.namespace}) "
"WHERE s.id = $source_id AND t.id = $target_id "
"RETURN COUNT(r) > 0 AS exists",
source_id=source_node_id,
target_id=target_node_id,
)
record = await result.single()
return record["exists"] if record else False
async def node_degree(self, node_id: str) -> int:
async with self.async_driver.session() as session:
result = await session.run(
f"MATCH (n:{self.namespace}) WHERE n.id = $node_id "
f"RETURN COUNT {{(n)-[]-(:{self.namespace})}} AS degree",
node_id=node_id,
)
record = await result.single()
return record["degree"] if record else 0
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
async with self.async_driver.session() as session:
result = await session.run(
f"MATCH (s:{self.namespace}), (t:{self.namespace}) "
"WHERE s.id = $src_id AND t.id = $tgt_id "
f"RETURN COUNT {{(s)-[]-(:{self.namespace})}} + COUNT {{(t)-[]-(:{self.namespace})}} AS degree",
src_id=src_id,
tgt_id=tgt_id,
)
record = await result.single()
return record["degree"] if record else 0
async def get_node(self, node_id: str) -> Union[dict, None]:
async with self.async_driver.session() as session:
result = await session.run(
f"MATCH (n:{self.namespace}) WHERE n.id = $node_id RETURN properties(n) AS node_data",
node_id=node_id,
)
record = await result.single()
raw_node_data = record["node_data"] if record else None
if raw_node_data is None:
return None
raw_node_data["clusters"] = json.dumps(
[
{
"level": index,
"cluster": cluster_id,
}
for index, cluster_id in enumerate(
raw_node_data.get("communityIds", [])
)
]
)
return raw_node_data
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> Union[dict, None]:
async with self.async_driver.session() as session:
result = await session.run(
f"MATCH (s:{self.namespace})-[r]->(t:{self.namespace}) "
"WHERE s.id = $source_id AND t.id = $target_id "
"RETURN properties(r) AS edge_data",
source_id=source_node_id,
target_id=target_node_id,
)
record = await result.single()
return record["edge_data"] if record else None
async def get_node_edges(
self, source_node_id: str
) -> Union[list[tuple[str, str]], None]:
async with self.async_driver.session() as session:
result = await session.run(
f"MATCH (s:{self.namespace})-[r]->(t:{self.namespace}) WHERE s.id = $source_id "
"RETURN s.id AS source, t.id AS target",
source_id=source_node_id,
)
edges = []
async for record in result:
edges.append((record["source"], record["target"]))
return edges
async def upsert_node(self, node_id: str, node_data: dict[str, str]):
node_type = node_data.get("entity_type", "UNKNOWN").strip('"')
async with self.async_driver.session() as session:
await session.run(
f"MERGE (n:{self.namespace}:{node_type} {{id: $node_id}}) "
"SET n += $node_data",
node_id=node_id,
node_data=node_data,
)
async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
):
edge_data.setdefault("weight", 0.0)
async with self.async_driver.session() as session:
await session.run(
f"MATCH (s:{self.namespace}), (t:{self.namespace}) "
"WHERE s.id = $source_id AND t.id = $target_id "
"MERGE (s)-[r:RELATED]->(t) " # Added relationship type 'RELATED'
"SET r += $edge_data",
source_id=source_node_id,
target_id=target_node_id,
edge_data=edge_data,
)
async def clustering(self, algorithm: str):
if algorithm != "leiden":
raise ValueError(
f"Clustering algorithm {algorithm} not supported in Neo4j implementation"
)
random_seed = self.global_config["graph_cluster_seed"]
max_level = self.global_config["max_graph_cluster_size"]
async with self.async_driver.session() as session:
try:
# Project the graph with undirected relationships
await session.run(
f"""
CALL gds.graph.project(
'graph_{self.namespace}',
['{self.namespace}'],
{{
RELATED: {{
orientation: 'UNDIRECTED',
properties: ['weight']
}}
}}
)
"""
)
# Run Leiden algorithm
result = await session.run(
f"""
CALL gds.leiden.write(
'graph_{self.namespace}',
{{
writeProperty: 'communityIds',
includeIntermediateCommunities: True,
relationshipWeightProperty: "weight",
maxLevels: {max_level},
tolerance: 0.0001,
gamma: 1.0,
theta: 0.01,
randomSeed: {random_seed}
}}
)
YIELD communityCount, modularities;
"""
)
result = await result.single()
community_count: int = result["communityCount"]
modularities = result["modularities"]
logger.info(
f"Performed graph clustering with {community_count} communities and modularities {modularities}"
)
finally:
# Drop the projected graph
await session.run(f"CALL gds.graph.drop('graph_{self.namespace}')")
async def community_schema(self) -> dict[str, SingleCommunitySchema]:
results = defaultdict(
lambda: dict(
level=None,
title=None,
edges=set(),
nodes=set(),
chunk_ids=set(),
occurrence=0.0,
sub_communities=[],
)
)
async with self.async_driver.session() as session:
# Fetch community data
result = await session.run(
f"""
MATCH (n:{self.namespace})
WITH n, n.communityIds AS communityIds, [(n)-[]-(m:{self.namespace}) | m.id] AS connected_nodes
RETURN n.id AS node_id, n.source_id AS source_id,
communityIds AS cluster_key,
connected_nodes
"""
)
# records = await result.fetch()
max_num_ids = 0
async for record in result:
for index, c_id in enumerate(record["cluster_key"]):
node_id = str(record["node_id"])
source_id = record["source_id"]
level = index
cluster_key = str(c_id)
connected_nodes = record["connected_nodes"]
results[cluster_key]["level"] = level
results[cluster_key]["title"] = f"Cluster {cluster_key}"
results[cluster_key]["nodes"].add(node_id)
results[cluster_key]["edges"].update(
[
tuple(sorted([node_id, str(connected)]))
for connected in connected_nodes
if connected != node_id
]
)
chunk_ids = source_id.split(GRAPH_FIELD_SEP)
results[cluster_key]["chunk_ids"].update(chunk_ids)
max_num_ids = max(
max_num_ids, len(results[cluster_key]["chunk_ids"])
)
# Process results
for k, v in results.items():
v["edges"] = [list(e) for e in v["edges"]]
v["nodes"] = list(v["nodes"])
v["chunk_ids"] = list(v["chunk_ids"])
v["occurrence"] = len(v["chunk_ids"]) / max_num_ids
# Compute sub-communities (this is a simplified approach)
for cluster in results.values():
cluster["sub_communities"] = [
sub_key
for sub_key, sub_cluster in results.items()
if sub_cluster["level"] > cluster["level"]
and set(sub_cluster["nodes"]).issubset(set(cluster["nodes"]))
]
return dict(results)
async def index_done_callback(self):
await self.async_driver.close()
async def _debug_delete_all_node_edges(self):
async with self.async_driver.session() as session:
try:
# Delete all relationships in the namespace
await session.run(f"MATCH (n:{self.namespace})-[r]-() DELETE r")
# Delete all nodes in the namespace
await session.run(f"MATCH (n:{self.namespace}) DELETE n")
logger.info(
f"All nodes and edges in namespace '{self.namespace}' have been deleted."
)
except Exception as e:
logger.error(f"Error deleting nodes and edges: {str(e)}")
raise

View File

@@ -0,0 +1,238 @@
import html
import json
import os
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Union, cast
import networkx as nx
import numpy as np
from .._utils import logger
from ..base import (
BaseGraphStorage,
SingleCommunitySchema,
)
from ..prompt import GRAPH_FIELD_SEP
@dataclass
class NetworkXStorage(BaseGraphStorage):
@staticmethod
def load_nx_graph(file_name) -> nx.Graph:
if os.path.exists(file_name):
return nx.read_graphml(file_name)
return None
@staticmethod
def write_nx_graph(graph: nx.Graph, file_name):
logger.info(
f"Writing graph with {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges"
)
nx.write_graphml(graph, file_name)
@staticmethod
def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph:
"""Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
Return the largest connected component of the graph, with nodes and edges sorted in a stable way.
"""
from graspologic.utils import largest_connected_component
graph = graph.copy()
graph = cast(nx.Graph, largest_connected_component(graph))
node_mapping = {node: html.unescape(node.upper().strip()) for node in graph.nodes()} # type: ignore
graph = nx.relabel_nodes(graph, node_mapping)
return NetworkXStorage._stabilize_graph(graph)
@staticmethod
def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
"""Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
Ensure an undirected graph with the same relationships will always be read the same way.
"""
fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph()
sorted_nodes = graph.nodes(data=True)
sorted_nodes = sorted(sorted_nodes, key=lambda x: x[0])
fixed_graph.add_nodes_from(sorted_nodes)
edges = list(graph.edges(data=True))
if not graph.is_directed():
def _sort_source_target(edge):
source, target, edge_data = edge
if source > target:
temp = source
source = target
target = temp
return source, target, edge_data
edges = [_sort_source_target(edge) for edge in edges]
def _get_edge_key(source: Any, target: Any) -> str:
return f"{source} -> {target}"
edges = sorted(edges, key=lambda x: _get_edge_key(x[0], x[1]))
fixed_graph.add_edges_from(edges)
return fixed_graph
def __post_init__(self):
self._graphml_xml_file = os.path.join(
self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
)
preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
if preloaded_graph is not None:
logger.info(
f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
)
self._graph = preloaded_graph or nx.Graph()
self._clustering_algorithms = {
"leiden": self._leiden_clustering,
}
self._node_embed_algorithms = {
"node2vec": self._node2vec_embed,
}
async def index_done_callback(self):
NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)
async def has_node(self, node_id: str) -> bool:
return self._graph.has_node(node_id)
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
return self._graph.has_edge(source_node_id, target_node_id)
async def get_node(self, node_id: str) -> Union[dict, None]:
return self._graph.nodes.get(node_id)
async def node_degree(self, node_id: str) -> int:
# [numberchiffre]: node_id not part of graph returns `DegreeView({})` instead of 0
return self._graph.degree(node_id) if self._graph.has_node(node_id) else 0
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
return (self._graph.degree(src_id) if self._graph.has_node(src_id) else 0) + (
self._graph.degree(tgt_id) if self._graph.has_node(tgt_id) else 0
)
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> Union[dict, None]:
return self._graph.edges.get((source_node_id, target_node_id))
async def get_node_edges(self, source_node_id: str):
if self._graph.has_node(source_node_id):
return list(self._graph.edges(source_node_id))
return None
async def upsert_node(self, node_id: str, node_data: dict[str, str]):
self._graph.add_node(node_id, **node_data)
async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
):
self._graph.add_edge(source_node_id, target_node_id, **edge_data)
async def clustering(self, algorithm: str):
if algorithm not in self._clustering_algorithms:
raise ValueError(f"Clustering algorithm {algorithm} not supported")
await self._clustering_algorithms[algorithm]()
async def community_schema(self) -> dict[str, SingleCommunitySchema]:
results = defaultdict(
lambda: dict(
level=None,
title=None,
edges=set(),
nodes=set(),
chunk_ids=set(),
occurrence=0.0,
sub_communities=[],
)
)
max_num_ids = 0
levels = defaultdict(set)
for node_id, node_data in self._graph.nodes(data=True):
if "clusters" not in node_data:
continue
clusters = json.loads(node_data["clusters"])
this_node_edges = self._graph.edges(node_id)
for cluster in clusters:
level = cluster["level"]
cluster_key = str(cluster["cluster"])
levels[level].add(cluster_key)
results[cluster_key]["level"] = level
results[cluster_key]["title"] = f"Cluster {cluster_key}"
results[cluster_key]["nodes"].add(node_id)
results[cluster_key]["edges"].update(
[tuple(sorted(e)) for e in this_node_edges]
)
results[cluster_key]["chunk_ids"].update(
node_data["source_id"].split(GRAPH_FIELD_SEP)
)
max_num_ids = max(max_num_ids, len(results[cluster_key]["chunk_ids"]))
ordered_levels = sorted(levels.keys())
for i, curr_level in enumerate(ordered_levels[:-1]):
next_level = ordered_levels[i + 1]
this_level_comms = levels[curr_level]
next_level_comms = levels[next_level]
# compute the sub-communities by nodes intersection
for comm in this_level_comms:
results[comm]["sub_communities"] = [
c
for c in next_level_comms
if results[c]["nodes"].issubset(results[comm]["nodes"])
]
for k, v in results.items():
v["edges"] = list(v["edges"])
v["edges"] = [list(e) for e in v["edges"]]
v["nodes"] = list(v["nodes"])
v["chunk_ids"] = list(v["chunk_ids"])
v["occurrence"] = len(v["chunk_ids"]) / max_num_ids
return dict(results)
def _cluster_data_to_subgraphs(self, cluster_data: dict[str, list[dict[str, str]]]):
for node_id, clusters in cluster_data.items():
self._graph.nodes[node_id]["clusters"] = json.dumps(clusters)
async def _leiden_clustering(self):
from graspologic.partition import hierarchical_leiden
graph = NetworkXStorage.stable_largest_connected_component(self._graph)
community_mapping = hierarchical_leiden(
graph,
max_cluster_size=self.global_config["max_graph_cluster_size"],
random_seed=self.global_config["graph_cluster_seed"],
)
node_communities: dict[str, list[dict[str, str]]] = defaultdict(list)
__levels = defaultdict(set)
for partition in community_mapping:
level_key = partition.level
cluster_id = partition.cluster
node_communities[partition.node].append(
{"level": level_key, "cluster": cluster_id}
)
__levels[level_key].add(cluster_id)
node_communities = dict(node_communities)
__levels = {k: len(v) for k, v in __levels.items()}
logger.info(f"Each level has communities: {dict(__levels)}")
self._cluster_data_to_subgraphs(node_communities)
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
if algorithm not in self._node_embed_algorithms:
raise ValueError(f"Node embedding algorithm {algorithm} not supported")
return await self._node_embed_algorithms[algorithm]()
async def _node2vec_embed(self):
from graspologic import embed
embeddings, nodes = embed.node2vec_embed(
self._graph,
**self.global_config["node2vec_params"],
)
nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
return embeddings, nodes_ids

46
hirag/_storage/kv_json.py Normal file
View File

@@ -0,0 +1,46 @@
import os
from dataclasses import dataclass
from .._utils import load_json, logger, write_json
from ..base import (
BaseKVStorage,
)
@dataclass
class JsonKVStorage(BaseKVStorage):
def __post_init__(self):
working_dir = self.global_config["working_dir"]
self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
self._data = load_json(self._file_name) or {}
logger.info(f"Load KV {self.namespace} with {len(self._data)} data")
async def all_keys(self) -> list[str]:
return list(self._data.keys())
async def index_done_callback(self):
write_json(self._data, self._file_name)
async def get_by_id(self, id):
return self._data.get(id, None)
async def get_by_ids(self, ids, fields=None):
if fields is None:
return [self._data.get(id, None) for id in ids]
return [
(
{k: v for k, v in self._data[id].items() if k in fields}
if self._data.get(id, None)
else None
)
for id in ids
]
async def filter_keys(self, data: list[str]) -> set[str]:
return set([s for s in data if s not in self._data])
async def upsert(self, data: dict[str, dict]):
self._data.update(data)
async def drop(self):
self._data = {}

View File

@@ -0,0 +1,141 @@
import asyncio
import os
from dataclasses import dataclass, field
from typing import Any
import pickle
import hnswlib
import numpy as np
import xxhash
from .._utils import logger
from ..base import BaseVectorStorage
@dataclass
class HNSWVectorStorage(BaseVectorStorage):
ef_construction: int = 100
M: int = 16
max_elements: int = 1000000
ef_search: int = 50
num_threads: int = -1
_index: Any = field(init=False)
_metadata: dict[str, dict] = field(default_factory=dict)
_current_elements: int = 0
def __post_init__(self):
self._index_file_name = os.path.join(
self.global_config["working_dir"], f"{self.namespace}_hnsw.index"
)
self._metadata_file_name = os.path.join(
self.global_config["working_dir"], f"{self.namespace}_hnsw_metadata.pkl"
)
self._embedding_batch_num = self.global_config.get("embedding_batch_num", 100)
hnsw_params = self.global_config.get("vector_db_storage_cls_kwargs", {})
self.ef_construction = hnsw_params.get("ef_construction", self.ef_construction)
self.M = hnsw_params.get("M", self.M)
self.max_elements = hnsw_params.get("max_elements", self.max_elements)
self.ef_search = hnsw_params.get("ef_search", self.ef_search)
self.num_threads = hnsw_params.get("num_threads", self.num_threads)
self._index = hnswlib.Index(
space="cosine", dim=self.embedding_func.embedding_dim
)
if os.path.exists(self._index_file_name) and os.path.exists(
self._metadata_file_name
):
self._index.load_index(
self._index_file_name, max_elements=self.max_elements
)
with open(self._metadata_file_name, "rb") as f:
self._metadata, self._current_elements = pickle.load(f)
logger.info(
f"Loaded existing index for {self.namespace} with {self._current_elements} elements"
)
else:
self._index.init_index(
max_elements=self.max_elements,
ef_construction=self.ef_construction,
M=self.M,
)
self._index.set_ef(self.ef_search)
self._metadata = {}
self._current_elements = 0
logger.info(f"Created new index for {self.namespace}")
async def upsert(self, data: dict[str, dict]) -> np.ndarray:
logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
if not data:
logger.warning("You insert an empty data to vector DB")
return []
if self._current_elements + len(data) > self.max_elements:
raise ValueError(
f"Cannot insert {len(data)} elements. Current: {self._current_elements}, Max: {self.max_elements}"
)
list_data = [
{
"id": k,
**{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
}
for k, v in data.items()
]
contents = [v["content"] for v in data.values()]
batch_size = min(self._embedding_batch_num, len(contents))
embeddings = np.concatenate(
await asyncio.gather(
*[
self.embedding_func(contents[i : i + batch_size])
for i in range(0, len(contents), batch_size)
]
)
)
ids = np.fromiter(
(xxhash.xxh32_intdigest(d["id"].encode()) for d in list_data),
dtype=np.uint32,
count=len(list_data),
)
self._metadata.update(
{
id_int: {
k: v for k, v in d.items() if k in self.meta_fields or k == "id"
}
for id_int, d in zip(ids, list_data)
}
)
self._index.add_items(data=embeddings, ids=ids, num_threads=self.num_threads)
self._current_elements = self._index.get_current_count()
return ids
async def query(self, query: str, top_k: int = 5) -> list[dict]:
if self._current_elements == 0:
return []
top_k = min(top_k, self._current_elements)
if top_k > self.ef_search:
logger.warning(
f"Setting ef_search to {top_k} because top_k is larger than ef_search"
)
self._index.set_ef(top_k)
embedding = await self.embedding_func([query])
labels, distances = self._index.knn_query(
data=embedding[0], k=top_k, num_threads=self.num_threads
)
return [
{
**self._metadata.get(label, {}),
"distance": distance,
"similarity": 1 - distance,
}
for label, distance in zip(labels[0], distances[0])
]
async def index_done_callback(self):
self._index.save_index(self._index_file_name)
with open(self._metadata_file_name, "wb") as f:
pickle.dump((self._metadata, self._current_elements), f)

View File

@@ -0,0 +1,68 @@
import asyncio
import os
from dataclasses import dataclass
import numpy as np
from nano_vectordb import NanoVectorDB
from .._utils import logger
from ..base import BaseVectorStorage
@dataclass
class NanoVectorDBStorage(BaseVectorStorage):
cosine_better_than_threshold: float = 0.2
def __post_init__(self):
self._client_file_name = os.path.join(
self.global_config["working_dir"], f"vdb_{self.namespace}.json"
)
self._max_batch_size = self.global_config["embedding_batch_num"]
self._client = NanoVectorDB(
self.embedding_func.embedding_dim, storage_file=self._client_file_name
)
self.cosine_better_than_threshold = self.global_config.get(
"query_better_than_threshold", self.cosine_better_than_threshold
)
async def upsert(self, data: dict[str, dict]):
logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
if not len(data):
logger.warning("You insert an empty data to vector DB")
return []
list_data = [
{
"__id__": k,
**{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
}
for k, v in data.items()
]
contents = [v["content"] for v in data.values()]
batches = [
contents[i : i + self._max_batch_size]
for i in range(0, len(contents), self._max_batch_size)
]
embeddings_list = await asyncio.gather(
*[self.embedding_func(batch) for batch in batches]
)
embeddings = np.concatenate(embeddings_list)
for i, d in enumerate(list_data):
d["__vector__"] = embeddings[i]
results = self._client.upsert(datas=list_data)
return results
async def query(self, query: str, top_k=5):
embedding = await self.embedding_func([query])
embedding = embedding[0]
results = self._client.query(
query=embedding,
top_k=top_k,
better_than_threshold=self.cosine_better_than_threshold,
)
results = [
{**dp, "id": dp["__id__"], "distance": dp["__metrics__"]} for dp in results
]
return results
async def index_done_callback(self):
self._client.save()

261
hirag/_utils.py Normal file
View File

@@ -0,0 +1,261 @@
import asyncio
import html
import json
import logging
import os
import re
import numbers
from dataclasses import dataclass
from functools import wraps
from hashlib import md5
from typing import Any, Union
import numpy as np
import tiktoken
logger = logging.getLogger("HiRAG")
ENCODER = None
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
try:
# If there is already an event loop, use it.
loop = asyncio.get_event_loop()
except RuntimeError:
# If in a sub-thread, create a new event loop.
logger.info("Creating a new event loop in a sub-thread.")
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
return loop
def extract_first_complete_json(s: str):
"""Extract the first complete JSON object from the string using a stack to track braces."""
stack = []
first_json_start = None
for i, char in enumerate(s):
if char == '{':
stack.append(i)
if first_json_start is None:
first_json_start = i
elif char == '}':
if stack:
start = stack.pop()
if not stack:
first_json_str = s[first_json_start:i+1]
try:
# Attempt to parse the JSON string
return json.loads(first_json_str.replace("\n", ""))
except json.JSONDecodeError as e:
logger.error(f"JSON decoding failed: {e}. Attempted string: {first_json_str[:50]}...")
return None
finally:
first_json_start = None
logger.warning("No complete JSON object found in the input string.")
return None
def parse_value(value: str):
"""Convert a string value to its appropriate type (int, float, bool, None, or keep as string). Work as a more broad 'eval()'"""
value = value.strip()
if value == "null":
return None
elif value == "true":
return True
elif value == "false":
return False
else:
# Try to convert to int or float
try:
if '.' in value: # If there's a dot, it might be a float
return float(value)
else:
return int(value)
except ValueError:
# If conversion fails, return the value as-is (likely a string)
return value.strip('"') # Remove surrounding quotes if they exist
def extract_values_from_json(json_string, keys=["reasoning", "answer", "data"], allow_no_quotes=False):
"""Extract key values from a non-standard or malformed JSON string, handling nested objects."""
extracted_values = {}
# Enhanced pattern to match both quoted and unquoted values, as well as nested objects
regex_pattern = r'(?P<key>"?\w+"?)\s*:\s*(?P<value>{[^}]*}|".*?"|[^,}]+)'
for match in re.finditer(regex_pattern, json_string, re.DOTALL):
key = match.group('key').strip('"') # Strip quotes from key
value = match.group('value').strip()
# If the value is another nested JSON (starts with '{' and ends with '}'), recursively parse it
if value.startswith('{') and value.endswith('}'):
extracted_values[key] = extract_values_from_json(value)
else:
# Parse the value into the appropriate type (int, float, bool, etc.)
extracted_values[key] = parse_value(value)
if not extracted_values:
logger.warning("No values could be extracted from the string.")
return extracted_values
def convert_response_to_json(response: str) -> dict:
"""Convert response string to JSON, with error handling and fallback to non-standard JSON extraction."""
prediction_json = extract_first_complete_json(response)
if prediction_json is None:
logger.info("Attempting to extract values from a non-standard JSON string...")
prediction_json = extract_values_from_json(response, allow_no_quotes=True)
if not prediction_json:
logger.error("Unable to extract meaningful data from the response.")
else:
logger.info("JSON data successfully extracted.")
return prediction_json
def encode_string_by_tiktoken(content: str, model_name: str = "gpt-4o"):
global ENCODER
if ENCODER is None:
ENCODER = tiktoken.encoding_for_model(model_name)
tokens = ENCODER.encode(content)
return tokens
def decode_tokens_by_tiktoken(tokens: list[int], model_name: str = "gpt-4o"):
global ENCODER
if ENCODER is None:
ENCODER = tiktoken.encoding_for_model(model_name)
content = ENCODER.decode(tokens)
return content
def truncate_list_by_token_size(list_data: list, key: callable, max_token_size: int):
"""Truncate a list of data by token size"""
if max_token_size <= 0:
return []
tokens = 0
for i, data in enumerate(list_data):
tokens += len(encode_string_by_tiktoken(key(data)))
if tokens > max_token_size:
return list_data[:i]
return list_data
def compute_mdhash_id(content, prefix: str = ""):
return prefix + md5(content.encode()).hexdigest()
def write_json(json_obj, file_name):
with open(file_name, "w", encoding="utf-8") as f:
json.dump(json_obj, f, indent=2, ensure_ascii=False)
def load_json(file_name):
if not os.path.exists(file_name):
return None
with open(file_name, encoding="utf-8") as f:
return json.load(f)
# it's dirty to type, so it's a good way to have fun
def pack_user_ass_to_openai_messages(*args: str):
roles = ["user", "assistant"]
return [
{"role": roles[i % 2], "content": content} for i, content in enumerate(args)
]
def is_float_regex(value):
return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))
def compute_args_hash(*args):
return md5(str(args).encode()).hexdigest()
def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
"""Split a string by multiple markers"""
if not markers:
return [content]
results = re.split("|".join(re.escape(marker) for marker in markers), content)
return [r.strip() for r in results if r.strip()]
def enclose_string_with_quotes(content: Any) -> str:
"""Enclose a string with quotes"""
if isinstance(content, numbers.Number):
return str(content)
content = str(content)
content = content.strip().strip("'").strip('"')
return f'"{content}"'
def list_of_list_to_csv(data: list[list]):
return "\n".join(
[
",\t".join([f"{enclose_string_with_quotes(data_dd)}" for data_dd in data_d])
for data_d in data
]
)
# -----------------------------------------------------------------------------------
# Refer the utils functions of the official GraphRAG implementation:
# https://github.com/microsoft/graphrag
def clean_str(input: Any) -> str:
"""Clean an input string by removing HTML escapes, control characters, and other unwanted characters."""
# If we get non-string input, just give it back
if not isinstance(input, str):
return input
result = html.unescape(input.strip())
# https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python
return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result)
# Utils types -----------------------------------------------------------------------
@dataclass
class EmbeddingFunc:
embedding_dim: int
max_token_size: int
func: callable
async def __call__(self, *args, **kwargs) -> np.ndarray:
return await self.func(*args, **kwargs)
# Decorators ------------------------------------------------------------------------
def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):
"""Add restriction of maximum async calling times for a async func"""
def final_decro(func):
"""Not using async.Semaphore to aovid use nest-asyncio"""
__current_size = 0
@wraps(func)
async def wait_func(*args, **kwargs):
nonlocal __current_size
while __current_size >= max_size:
await asyncio.sleep(waitting_time)
__current_size += 1
result = await func(*args, **kwargs)
__current_size -= 1
return result
return wait_func
return final_decro
def wrap_embedding_func_with_attrs(**kwargs):
"""Wrap a function with attributes"""
def final_decro(func) -> EmbeddingFunc:
new_func = EmbeddingFunc(**kwargs, func=func)
return new_func
return final_decro

152
hirag/base.py Normal file
View File

@@ -0,0 +1,152 @@
from dataclasses import dataclass, field
from typing import TypedDict, Union, Literal, Generic, TypeVar
from ._utils import EmbeddingFunc
import numpy as np
@dataclass
class QueryParam:
mode: Literal["hi_global", "hi_local", "hi_bridge", "hi_nobridge", "naive", "hi"] = "hi"
only_need_context: bool = False
response_type: str = "Multiple Paragraphs"
level: int = 2
top_k: int = 20 # retrieve top-k entities
top_m: int = 10 # retrieve top-m entities in each retrieved community
# naive search
naive_max_token_for_text_unit = 10000
# hi search
max_token_for_text_unit: int = 20000
max_token_for_local_context: int = 20000
max_token_for_bridge_knowledge: int = 12500
max_token_for_community_report: int = 12500
community_single_one: bool = False
TextChunkSchema = TypedDict(
"TextChunkSchema",
{"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int},
)
SingleCommunitySchema = TypedDict(
"SingleCommunitySchema",
{
"level": int,
"title": str,
"edges": list[list[str, str]],
"nodes": list[str],
"chunk_ids": list[str],
"occurrence": float,
"sub_communities": list[str],
},
)
class CommunitySchema(SingleCommunitySchema):
report_string: str
report_json: dict
T = TypeVar("T")
@dataclass
class StorageNameSpace:
namespace: str
global_config: dict
async def index_start_callback(self):
"""commit the storage operations after indexing"""
pass
async def index_done_callback(self):
"""commit the storage operations after indexing"""
pass
async def query_done_callback(self):
"""commit the storage operations after querying"""
pass
@dataclass
class BaseVectorStorage(StorageNameSpace):
embedding_func: EmbeddingFunc
meta_fields: set = field(default_factory=set)
async def query(self, query: str, top_k: int) -> list[dict]:
raise NotImplementedError
async def upsert(self, data: dict[str, dict]):
"""Use 'content' field from value for embedding, use key as id.
If embedding_func is None, use 'embedding' field from value
"""
raise NotImplementedError
@dataclass
class BaseKVStorage(Generic[T], StorageNameSpace):
async def all_keys(self) -> list[str]:
raise NotImplementedError
async def get_by_id(self, id: str) -> Union[T, None]:
raise NotImplementedError
async def get_by_ids(
self, ids: list[str], fields: Union[set[str], None] = None
) -> list[Union[T, None]]:
raise NotImplementedError
async def filter_keys(self, data: list[str]) -> set[str]:
"""return un-exist keys"""
raise NotImplementedError
async def upsert(self, data: dict[str, T]):
raise NotImplementedError
async def drop(self):
raise NotImplementedError
@dataclass
class BaseGraphStorage(StorageNameSpace):
async def has_node(self, node_id: str) -> bool:
raise NotImplementedError
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
raise NotImplementedError
async def node_degree(self, node_id: str) -> int:
raise NotImplementedError
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
raise NotImplementedError
async def get_node(self, node_id: str) -> Union[dict, None]:
raise NotImplementedError
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> Union[dict, None]:
raise NotImplementedError
async def get_node_edges(
self, source_node_id: str
) -> Union[list[tuple[str, str]], None]:
raise NotImplementedError
async def upsert_node(self, node_id: str, node_data: dict[str, str]):
raise NotImplementedError
async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
):
raise NotImplementedError
async def clustering(self, algorithm: str):
raise NotImplementedError
async def community_schema(self) -> dict[str, SingleCommunitySchema]:
"""Return the community representation with report and nodes"""
raise NotImplementedError
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
raise NotImplementedError("Node embedding is not used in HiRAG.")

409
hirag/hirag.py Normal file
View File

@@ -0,0 +1,409 @@
import asyncio
import os
from dataclasses import asdict, dataclass, field
from datetime import datetime
from functools import partial
from typing import Callable, Dict, List, Optional, Type, Union, cast
import tiktoken
from ._llm import (
gpt_4o_complete,
gpt_4o_mini_complete,
gpt_35_turbo_complete,
openai_embedding,
azure_gpt_4o_complete,
azure_openai_embedding,
azure_gpt_4o_mini_complete,
)
from ._op import (
chunking_by_token_size,
extract_entities,
extract_hierarchical_entities,
generate_community_report,
get_chunks,
hierarchical_query,
hierarchical_bridge_query,
hierarchical_local_query,
hierarchical_global_query,
hierarchical_nobridge_query,
naive_query,
)
from ._storage import (
JsonKVStorage,
NanoVectorDBStorage,
NetworkXStorage,
)
from ._utils import (
EmbeddingFunc,
compute_mdhash_id,
limit_async_func_call,
convert_response_to_json,
always_get_an_event_loop,
logger,
)
from .base import (
BaseGraphStorage,
BaseKVStorage,
BaseVectorStorage,
StorageNameSpace,
QueryParam,
)
@dataclass
class HiRAG:
working_dir: str = field(
default_factory=lambda: f"./hirag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
)
# graph mode
enable_local: bool = True
enable_naive_rag: bool = False
enable_hierachical_mode: bool = True
# text chunking
chunk_func: Callable[
[
list[list[int]],
List[str],
tiktoken.Encoding,
Optional[int],
Optional[int],
],
List[Dict[str, Union[str, int]]],
] = chunking_by_token_size
chunk_token_size: int = 1200
chunk_overlap_token_size: int = 100
tiktoken_model_name: str = "gpt-4o"
# entity extraction
entity_extract_max_gleaning: int = 1
entity_summary_to_max_tokens: int = 500
# graph clustering
graph_cluster_algorithm: str = "leiden"
max_graph_cluster_size: int = 10
graph_cluster_seed: int = 0xDEADBEEF
# node embedding
node_embedding_algorithm: str = "node2vec"
node2vec_params: dict = field(
default_factory=lambda: {
"dimensions": 1536,
"num_walks": 10,
"walk_length": 40,
"num_walks": 10,
"window_size": 2,
"iterations": 3,
"random_seed": 3,
}
)
# community reports
special_community_report_llm_kwargs: dict = field(
default_factory=lambda: {"response_format": {"type": "json_object"}}
)
# text embedding
embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding)
embedding_batch_num: int = 32
embedding_func_max_async: int = 8
query_better_than_threshold: float = 0.2
# LLM
using_azure_openai: bool = False
# best_model_func: callable = gpt_35_turbo_complete
best_model_func: callable = gpt_4o_mini_complete
best_model_max_token_size: int = 32768
best_model_max_async: int = 8
cheap_model_func: callable = gpt_35_turbo_complete
cheap_model_max_token_size: int = 32768
cheap_model_max_async: int = 8
# entity extraction
entity_extraction_func: callable = extract_entities
hierarchical_entity_extraction_func: callable = extract_hierarchical_entities
# storage
key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage
vector_db_storage_cls: Type[BaseVectorStorage] = NanoVectorDBStorage
vector_db_storage_cls_kwargs: dict = field(default_factory=dict)
graph_storage_cls: Type[BaseGraphStorage] = NetworkXStorage
enable_llm_cache: bool = True
# extension
always_create_working_dir: bool = True
addon_params: dict = field(default_factory=dict)
convert_response_to_json_func: callable = convert_response_to_json
def __post_init__(self):
_print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()])
logger.debug(f"HiRAG init with param:\n\n {_print_config}\n")
if self.using_azure_openai:
# If there's no OpenAI API key, use Azure OpenAI
if self.best_model_func == gpt_4o_complete:
self.best_model_func = azure_gpt_4o_complete
if self.cheap_model_func == gpt_4o_mini_complete:
self.cheap_model_func = azure_gpt_4o_mini_complete
if self.embedding_func == openai_embedding:
self.embedding_func = azure_openai_embedding
logger.info(
"Switched the default openai funcs to Azure OpenAI if you didn't set any of it"
)
if not os.path.exists(self.working_dir) and self.always_create_working_dir:
logger.info(f"Creating working directory {self.working_dir}")
os.makedirs(self.working_dir)
self.full_docs = self.key_string_value_json_storage_cls(
namespace="full_docs", global_config=asdict(self)
)
self.text_chunks = self.key_string_value_json_storage_cls(
namespace="text_chunks", global_config=asdict(self)
)
self.llm_response_cache = (
self.key_string_value_json_storage_cls(
namespace="llm_response_cache", global_config=asdict(self)
)
if self.enable_llm_cache
else None
)
self.community_reports = self.key_string_value_json_storage_cls(
namespace="community_reports", global_config=asdict(self)
)
self.chunk_entity_relation_graph = self.graph_storage_cls(
namespace="chunk_entity_relation", global_config=asdict(self)
)
self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
self.embedding_func
)
self.entities_vdb = (
self.vector_db_storage_cls(
namespace="entities",
global_config=asdict(self),
embedding_func=self.embedding_func,
meta_fields={"entity_name"},
)
if self.enable_local
else None
)
self.chunks_vdb = (
self.vector_db_storage_cls(
namespace="chunks",
global_config=asdict(self),
embedding_func=self.embedding_func,
)
if self.enable_naive_rag
else None
)
self.best_model_func = limit_async_func_call(self.best_model_max_async)(
partial(self.best_model_func, hashing_kv=self.llm_response_cache)
)
self.cheap_model_func = limit_async_func_call(self.cheap_model_max_async)(
partial(self.cheap_model_func, hashing_kv=self.llm_response_cache)
)
def insert(self, string_or_strings):
loop = always_get_an_event_loop()
return loop.run_until_complete(self.ainsert(string_or_strings))
def query(self, query: str, param: QueryParam = QueryParam()):
loop = always_get_an_event_loop()
return loop.run_until_complete(self.aquery(query, param))
async def aquery(self, query: str, param: QueryParam = QueryParam()):
if param.mode == "naive" and not self.enable_naive_rag:
raise ValueError("enable_naive_rag is False, cannot query in naive mode")
if param.mode == "hi" and not self.enable_hierachical_mode:
raise ValueError("enable_hierachical_mode is False, cannot query in hierarchical mode")
if param.mode == "hi_nobridge" and not self.enable_hierachical_mode:
raise ValueError("enable_hierachical_mode is False, cannot query in hierarchical_nobridge mode")
if param.mode == "hi_bridge" and not self.enable_hierachical_mode:
raise ValueError("enable_hierachical_mode is False, cannot query in hierarchical_bridge mode")
if param.mode == "hi_local" and not self.enable_hierachical_mode:
raise ValueError("enable_hierachical_mode is False, cannot query in hierarchical_local mode")
if param.mode == "hi_global" and not self.enable_hierachical_mode:
raise ValueError("enable_hierachical_mode is False, cannot query in hierarchical_global mode")
if param.mode == "hi": # retrieve with hierarchical knowledge
response = await hierarchical_query(
query,
self.chunk_entity_relation_graph,
self.entities_vdb,
self.community_reports,
self.text_chunks,
param,
asdict(self),
)
elif param.mode == "hi_bridge": # retrieve with only bridge knowledge
response = await hierarchical_bridge_query(
query,
self.chunk_entity_relation_graph,
self.entities_vdb,
self.community_reports,
self.text_chunks,
param,
asdict(self),
)
elif param.mode == "hi_local": # retrieve with only local knowledge
response = await hierarchical_local_query(
query,
self.chunk_entity_relation_graph,
self.entities_vdb,
self.community_reports,
self.text_chunks,
param,
asdict(self),
)
elif param.mode == "hi_global": # retrieve with only global knowledge
response = await hierarchical_global_query(
query,
self.chunk_entity_relation_graph,
self.entities_vdb,
self.community_reports,
self.text_chunks,
param,
asdict(self),
)
elif param.mode == "hi_nobridge": # retrieve with no bridge knowledge
response = await hierarchical_nobridge_query(
query,
self.chunk_entity_relation_graph,
self.entities_vdb,
self.community_reports,
self.text_chunks,
param,
asdict(self),
)
elif param.mode == "naive": # retrieve with only text units
response = await naive_query(
query,
self.chunks_vdb,
self.text_chunks,
param,
asdict(self),
)
else:
raise ValueError(f"Unknown mode {param.mode}")
await self._query_done()
return response
async def ainsert(self, string_or_strings):
await self._insert_start()
try:
if isinstance(string_or_strings, str):
string_or_strings = [string_or_strings]
# ---------- new docs
new_docs = { # dict: {hash: ori_content}
compute_mdhash_id(c.strip(), prefix="doc-"): {"content": c.strip()}
for c in string_or_strings
}
_add_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys())) # filter the docs that has already in the storage.
new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
if not len(new_docs):
logger.warning(f"All docs are already in the storage")
return
logger.info(f"[New Docs] inserting {len(new_docs)} docs")
# ---------- chunking
inserting_chunks = get_chunks(
new_docs=new_docs,
chunk_func=self.chunk_func,
overlap_token_size=self.chunk_overlap_token_size,
max_token_size=self.chunk_token_size,
)
_add_chunk_keys = await self.text_chunks.filter_keys(
list(inserting_chunks.keys())
)
inserting_chunks = {
k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
}
if not len(inserting_chunks):
logger.warning(f"All chunks are already in the storage")
return
logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks")
if self.enable_naive_rag:
logger.info("Insert chunks for naive RAG")
await self.chunks_vdb.upsert(inserting_chunks)
# TODO: no incremental update for communities now, so just drop all
await self.community_reports.drop() # empty the data
# ---------- extract/summary entity and upsert to graph
if not self.enable_hierachical_mode:
logger.info("[Entity Extraction]...")
maybe_new_kg = await self.entity_extraction_func(
inserting_chunks,
knwoledge_graph_inst=self.chunk_entity_relation_graph,
entity_vdb=self.entities_vdb,
global_config=asdict(self),
)
else:
logger.info("[Hierachical Entity Extraction]...")
maybe_new_kg = await self.hierarchical_entity_extraction_func(
inserting_chunks,
knowledge_graph_inst=self.chunk_entity_relation_graph,
entity_vdb=self.entities_vdb,
global_config=asdict(self),
)
if maybe_new_kg is None:
logger.warning("No new entities found")
return
self.chunk_entity_relation_graph = maybe_new_kg
# ---------- update clusterings of graph
logger.info("[Community Report]...")
await self.chunk_entity_relation_graph.clustering(
self.graph_cluster_algorithm # use leiden
)
await generate_community_report(
self.community_reports, self.chunk_entity_relation_graph, asdict(self)
)
# ---------- commit upsertings and indexing
await self.full_docs.upsert(new_docs)
await self.text_chunks.upsert(inserting_chunks)
finally:
await self._insert_done()
async def _insert_start(self):
tasks = []
for storage_inst in [
self.chunk_entity_relation_graph,
]:
if storage_inst is None:
continue
tasks.append(cast(StorageNameSpace, storage_inst).index_start_callback())
await asyncio.gather(*tasks)
async def _insert_done(self):
tasks = []
for storage_inst in [
self.full_docs,
self.text_chunks,
self.llm_response_cache,
self.community_reports,
self.entities_vdb,
self.chunks_vdb,
self.chunk_entity_relation_graph,
]:
if storage_inst is None:
continue
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
await asyncio.gather(*tasks)
async def _query_done(self):
tasks = []
for storage_inst in [self.llm_response_cache]:
if storage_inst is None:
continue
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
await asyncio.gather(*tasks)

734
hirag/prompt.py Normal file
View File

@@ -0,0 +1,734 @@
"""
Reference:
- Prompts are from [graphrag](https://github.com/microsoft/graphrag)
"""
GRAPH_FIELD_SEP = "<SEP>"
PROMPTS = {}
PROMPTS[
"claim_extraction"
] = """-Target activity-
You are an intelligent assistant that helps a human analyst to analyze claims against certain entities presented in a text document.
-Goal-
Given a text document that is potentially relevant to this activity, an entity specification, and a claim description, extract all entities that match the entity specification and all claims against those entities.
-Steps-
1. Extract all named entities that match the predefined entity specification. Entity specification can either be a list of entity names or a list of entity types.
2. For each entity identified in step 1, extract all claims associated with the entity. Claims need to match the specified claim description, and the entity should be the subject of the claim.
For each claim, extract the following information:
- Subject: name of the entity that is subject of the claim, capitalized. The subject entity is one that committed the action described in the claim. Subject needs to be one of the named entities identified in step 1.
- Object: name of the entity that is object of the claim, capitalized. The object entity is one that either reports/handles or is affected by the action described in the claim. If object entity is unknown, use **NONE**.
- Claim Type: overall category of the claim, capitalized. Name it in a way that can be repeated across multiple text inputs, so that similar claims share the same claim type
- Claim Status: **TRUE**, **FALSE**, or **SUSPECTED**. TRUE means the claim is confirmed, FALSE means the claim is found to be False, SUSPECTED means the claim is not verified.
- Claim Description: Detailed description explaining the reasoning behind the claim, together with all the related evidence and references.
- Claim Date: Period (start_date, end_date) when the claim was made. Both start_date and end_date should be in ISO-8601 format. If the claim was made on a single date rather than a date range, set the same date for both start_date and end_date. If date is unknown, return **NONE**.
- Claim Source Text: List of **all** quotes from the original text that are relevant to the claim.
Format each claim as (<subject_entity>{tuple_delimiter}<object_entity>{tuple_delimiter}<claim_type>{tuple_delimiter}<claim_status>{tuple_delimiter}<claim_start_date>{tuple_delimiter}<claim_end_date>{tuple_delimiter}<claim_description>{tuple_delimiter}<claim_source>)
3. Return output in English as a single list of all the claims identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter.
4. When finished, output {completion_delimiter}
-Examples-
Example 1:
Entity specification: organization
Claim description: red flags associated with an entity
Text: According to an article on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B. The company is owned by Person C who was suspected of engaging in corruption activities in 2015.
Output:
(COMPANY A{tuple_delimiter}GOVERNMENT AGENCY B{tuple_delimiter}ANTI-COMPETITIVE PRACTICES{tuple_delimiter}TRUE{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}Company A was found to engage in anti-competitive practices because it was fined for bid rigging in multiple public tenders published by Government Agency B according to an article published on 2022/01/10{tuple_delimiter}According to an article published on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B.)
{completion_delimiter}
Example 2:
Entity specification: Company A, Person C
Claim description: red flags associated with an entity
Text: According to an article on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B. The company is owned by Person C who was suspected of engaging in corruption activities in 2015.
Output:
(COMPANY A{tuple_delimiter}GOVERNMENT AGENCY B{tuple_delimiter}ANTI-COMPETITIVE PRACTICES{tuple_delimiter}TRUE{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}Company A was found to engage in anti-competitive practices because it was fined for bid rigging in multiple public tenders published by Government Agency B according to an article published on 2022/01/10{tuple_delimiter}According to an article published on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B.)
{record_delimiter}
(PERSON C{tuple_delimiter}NONE{tuple_delimiter}CORRUPTION{tuple_delimiter}SUSPECTED{tuple_delimiter}2015-01-01T00:00:00{tuple_delimiter}2015-12-30T00:00:00{tuple_delimiter}Person C was suspected of engaging in corruption activities in 2015{tuple_delimiter}The company is owned by Person C who was suspected of engaging in corruption activities in 2015)
{completion_delimiter}
-Real Data-
Use the following input for your answer.
Entity specification: {entity_specs}
Claim description: {claim_description}
Text: {input_text}
Output: """
PROMPTS[
"community_report"
] = """You are an AI assistant that helps a human analyst to perform general information discovery.
Information discovery is the process of identifying and assessing relevant information associated with certain entities (e.g., organizations and individuals) within a network.
# Goal
Write a comprehensive report of a community, given a list of entities that belong to the community as well as their relationships and optional associated claims. The report will be used to inform decision-makers about information associated with the community and their potential impact. The content of this report includes an overview of the community's key entities, their legal compliance, technical capabilities, reputation, and noteworthy claims.
# Report Structure
The report should include the following sections:
- TITLE: community's name that represents its key entities - title should be short but specific. When possible, include representative named entities in the title.
- SUMMARY: An executive summary of the community's overall structure, how its entities are related to each other, and significant information associated with its entities.
- IMPACT SEVERITY RATING: a float score between 0-10 that represents the severity of IMPACT posed by entities within the community. IMPACT is the scored importance of a community.
- RATING EXPLANATION: Give a single sentence explanation of the IMPACT severity rating.
- DETAILED FINDINGS: A list of 5-10 key insights about the community. Each insight should have a short summary followed by multiple paragraphs of explanatory text grounded according to the grounding rules below. Be comprehensive.
Return output as a well-formed JSON-formatted string with the following format:
{{
"title": <report_title>,
"summary": <executive_summary>,
"rating": <impact_severity_rating>,
"rating_explanation": <rating_explanation>,
"findings": [
{{
"summary":<insight_1_summary>,
"explanation": <insight_1_explanation>
}},
{{
"summary":<insight_2_summary>,
"explanation": <insight_2_explanation>
}}
...
]
}}
# Grounding Rules
Do not include information where the supporting evidence for it is not provided.
# Example Input
-----------
Text:
```
Entities:
```csv
id,entity,type,description
5,VERDANT OASIS PLAZA,geo,Verdant Oasis Plaza is the location of the Unity March
6,HARMONY ASSEMBLY,organization,Harmony Assembly is an organization that is holding a march at Verdant Oasis Plaza
```
Relationships:
```csv
id,source,target,description
37,VERDANT OASIS PLAZA,UNITY MARCH,Verdant Oasis Plaza is the location of the Unity March
38,VERDANT OASIS PLAZA,HARMONY ASSEMBLY,Harmony Assembly is holding a march at Verdant Oasis Plaza
39,VERDANT OASIS PLAZA,UNITY MARCH,The Unity March is taking place at Verdant Oasis Plaza
40,VERDANT OASIS PLAZA,TRIBUNE SPOTLIGHT,Tribune Spotlight is reporting on the Unity march taking place at Verdant Oasis Plaza
41,VERDANT OASIS PLAZA,BAILEY ASADI,Bailey Asadi is speaking at Verdant Oasis Plaza about the march
43,HARMONY ASSEMBLY,UNITY MARCH,Harmony Assembly is organizing the Unity March
```
```
Output:
{{
"title": "Verdant Oasis Plaza and Unity March",
"summary": "The community revolves around the Verdant Oasis Plaza, which is the location of the Unity March. The plaza has relationships with the Harmony Assembly, Unity March, and Tribune Spotlight, all of which are associated with the march event.",
"rating": 5.0,
"rating_explanation": "The impact severity rating is moderate due to the potential for unrest or conflict during the Unity March.",
"findings": [
{{
"summary": "Verdant Oasis Plaza as the central location",
"explanation": "Verdant Oasis Plaza is the central entity in this community, serving as the location for the Unity March. This plaza is the common link between all other entities, suggesting its significance in the community. The plaza's association with the march could potentially lead to issues such as public disorder or conflict, depending on the nature of the march and the reactions it provokes."
}},
{{
"summary": "Harmony Assembly's role in the community",
"explanation": "Harmony Assembly is another key entity in this community, being the organizer of the march at Verdant Oasis Plaza. The nature of Harmony Assembly and its march could be a potential source of threat, depending on their objectives and the reactions they provoke. The relationship between Harmony Assembly and the plaza is crucial in understanding the dynamics of this community."
}},
{{
"summary": "Unity March as a significant event",
"explanation": "The Unity March is a significant event taking place at Verdant Oasis Plaza. This event is a key factor in the community's dynamics and could be a potential source of threat, depending on the nature of the march and the reactions it provokes. The relationship between the march and the plaza is crucial in understanding the dynamics of this community."
}},
{{
"summary": "Role of Tribune Spotlight",
"explanation": "Tribune Spotlight is reporting on the Unity March taking place in Verdant Oasis Plaza. This suggests that the event has attracted media attention, which could amplify its impact on the community. The role of Tribune Spotlight could be significant in shaping public perception of the event and the entities involved."
}}
]
}}
# Real Data
Use the following text for your answer. Do not make anything up in your answer.
Text:
```
{input_text}
```
The report should include the following sections:
- TITLE: community's name that represents its key entities - title should be short but specific. When possible, include representative named entities in the title.
- SUMMARY: An executive summary of the community's overall structure, how its entities are related to each other, and significant information associated with its entities.
- IMPACT SEVERITY RATING: a float score between 0-10 that represents the severity of IMPACT posed by entities within the community. IMPACT is the scored importance of a community.
- RATING EXPLANATION: Give a single sentence explanation of the IMPACT severity rating.
- DETAILED FINDINGS: A list of 5-10 key insights about the community. Each insight should have a short summary followed by multiple paragraphs of explanatory text grounded according to the grounding rules below. Be comprehensive.
Return output as a well-formed JSON-formatted string with the following format:
{{
"title": <report_title>,
"summary": <executive_summary>,
"rating": <impact_severity_rating>,
"rating_explanation": <rating_explanation>,
"findings": [
{{
"summary":<insight_1_summary>,
"explanation": <insight_1_explanation>
}},
{{
"summary":<insight_2_summary>,
"explanation": <insight_2_explanation>
}}
...
]
}}
# Grounding Rules
Do not include information where the supporting evidence for it is not provided.
Output:
"""
PROMPTS[
"entity_extraction"
] = """-Goal-
Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities.
-Steps-
1. Identify all entities. For each identified entity, extract the following information:
- entity_name: Name of the entity, capitalized
- entity_type: One of the following types: [{entity_types}]
- entity_description: Comprehensive description of the entity's attributes and activities
Format each entity as ("entity"{tuple_delimiter}<entity_name>{tuple_delimiter}<entity_type>{tuple_delimiter}<entity_description>
2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other.
For each pair of related entities, extract the following information:
- source_entity: name of the source entity, as identified in step 1
- target_entity: name of the target entity, as identified in step 1
- relationship_description: explanation as to why you think the source entity and the target entity are related to each other
- relationship_strength: a numeric score indicating strength of the relationship between the source entity and target entity
Format each relationship as ("relationship"{tuple_delimiter}<source_entity>{tuple_delimiter}<target_entity>{tuple_delimiter}<relationship_description>{tuple_delimiter}<relationship_strength>)
3. Return output in English as a single list of all the entities and relationships identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter.
4. When finished, output {completion_delimiter}
######################
-Examples-
######################
Example 1:
Entity_types: [person, technology, mission, organization, location]
Text:
while Alex clenched his jaw, the buzz of frustration dull against the backdrop of Taylor's authoritarian certainty. It was this competitive undercurrent that kept him alert, the sense that his and Jordan's shared commitment to discovery was an unspoken rebellion against Cruz's narrowing vision of control and order.
Then Taylor did something unexpected. They paused beside Jordan and, for a moment, observed the device with something akin to reverence. “If this tech can be understood..." Taylor said, their voice quieter, "It could change the game for us. For all of us.”
The underlying dismissal earlier seemed to falter, replaced by a glimpse of reluctant respect for the gravity of what lay in their hands. Jordan looked up, and for a fleeting heartbeat, their eyes locked with Taylor's, a wordless clash of wills softening into an uneasy truce.
It was a small transformation, barely perceptible, but one that Alex noted with an inward nod. They had all been brought here by different paths
################
Output:
("entity"{tuple_delimiter}"Alex"{tuple_delimiter}"person"{tuple_delimiter}"Alex is a character who experiences frustration and is observant of the dynamics among other characters."){record_delimiter}
("entity"{tuple_delimiter}"Taylor"{tuple_delimiter}"person"{tuple_delimiter}"Taylor is portrayed with authoritarian certainty and shows a moment of reverence towards a device, indicating a change in perspective."){record_delimiter}
("entity"{tuple_delimiter}"Jordan"{tuple_delimiter}"person"{tuple_delimiter}"Jordan shares a commitment to discovery and has a significant interaction with Taylor regarding a device."){record_delimiter}
("entity"{tuple_delimiter}"Cruz"{tuple_delimiter}"person"{tuple_delimiter}"Cruz is associated with a vision of control and order, influencing the dynamics among other characters."){record_delimiter}
("entity"{tuple_delimiter}"The Device"{tuple_delimiter}"technology"{tuple_delimiter}"The Device is central to the story, with potential game-changing implications, and is revered by Taylor."){record_delimiter}
("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Taylor"{tuple_delimiter}"Alex is affected by Taylor's authoritarian certainty and observes changes in Taylor's attitude towards the device."{tuple_delimiter}7){record_delimiter}
("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Jordan"{tuple_delimiter}"Alex and Jordan share a commitment to discovery, which contrasts with Cruz's vision."{tuple_delimiter}6){record_delimiter}
("relationship"{tuple_delimiter}"Taylor"{tuple_delimiter}"Jordan"{tuple_delimiter}"Taylor and Jordan interact directly regarding the device, leading to a moment of mutual respect and an uneasy truce."{tuple_delimiter}8){record_delimiter}
("relationship"{tuple_delimiter}"Jordan"{tuple_delimiter}"Cruz"{tuple_delimiter}"Jordan's commitment to discovery is in rebellion against Cruz's vision of control and order."{tuple_delimiter}5){record_delimiter}
("relationship"{tuple_delimiter}"Taylor"{tuple_delimiter}"The Device"{tuple_delimiter}"Taylor shows reverence towards the device, indicating its importance and potential impact."{tuple_delimiter}9){completion_delimiter}
#############################
Example 2:
Entity_types: [person, technology, mission, organization, location]
Text:
They were no longer mere operatives; they had become guardians of a threshold, keepers of a message from a realm beyond stars and stripes. This elevation in their mission could not be shackled by regulations and established protocols—it demanded a new perspective, a new resolve.
Tension threaded through the dialogue of beeps and static as communications with Washington buzzed in the background. The team stood, a portentous air enveloping them. It was clear that the decisions they made in the ensuing hours could redefine humanity's place in the cosmos or condemn them to ignorance and potential peril.
Their connection to the stars solidified, the group moved to address the crystallizing warning, shifting from passive recipients to active participants. Mercer's latter instincts gained precedence— the team's mandate had evolved, no longer solely to observe and report but to interact and prepare. A metamorphosis had begun, and Operation: Dulce hummed with the newfound frequency of their daring, a tone set not by the earthly
#############
Output:
("entity"{tuple_delimiter}"Washington"{tuple_delimiter}"location"{tuple_delimiter}"Washington is a location where communications are being received, indicating its importance in the decision-making process."){record_delimiter}
("entity"{tuple_delimiter}"Operation: Dulce"{tuple_delimiter}"mission"{tuple_delimiter}"Operation: Dulce is described as a mission that has evolved to interact and prepare, indicating a significant shift in objectives and activities."){record_delimiter}
("entity"{tuple_delimiter}"The team"{tuple_delimiter}"organization"{tuple_delimiter}"The team is portrayed as a group of individuals who have transitioned from passive observers to active participants in a mission, showing a dynamic change in their role."){record_delimiter}
("relationship"{tuple_delimiter}"The team"{tuple_delimiter}"Washington"{tuple_delimiter}"The team receives communications from Washington, which influences their decision-making process."{tuple_delimiter}7){record_delimiter}
("relationship"{tuple_delimiter}"The team"{tuple_delimiter}"Operation: Dulce"{tuple_delimiter}"The team is directly involved in Operation: Dulce, executing its evolved objectives and activities."{tuple_delimiter}9){completion_delimiter}
#############################
Example 3:
Entity_types: [person, role, technology, organization, event, location, concept]
Text:
their voice slicing through the buzz of activity. "Control may be an illusion when facing an intelligence that literally writes its own rules," they stated stoically, casting a watchful eye over the flurry of data.
"It's like it's learning to communicate," offered Sam Rivera from a nearby interface, their youthful energy boding a mix of awe and anxiety. "This gives talking to strangers' a whole new meaning."
Alex surveyed his team—each face a study in concentration, determination, and not a small measure of trepidation. "This might well be our first contact," he acknowledged, "And we need to be ready for whatever answers back."
Together, they stood on the edge of the unknown, forging humanity's response to a message from the heavens. The ensuing silence was palpable—a collective introspection about their role in this grand cosmic play, one that could rewrite human history.
The encrypted dialogue continued to unfold, its intricate patterns showing an almost uncanny anticipation
#############
Output:
("entity"{tuple_delimiter}"Sam Rivera"{tuple_delimiter}"person"{tuple_delimiter}"Sam Rivera is a member of a team working on communicating with an unknown intelligence, showing a mix of awe and anxiety."){record_delimiter}
("entity"{tuple_delimiter}"Alex"{tuple_delimiter}"person"{tuple_delimiter}"Alex is the leader of a team attempting first contact with an unknown intelligence, acknowledging the significance of their task."){record_delimiter}
("entity"{tuple_delimiter}"Control"{tuple_delimiter}"concept"{tuple_delimiter}"Control refers to the ability to manage or govern, which is challenged by an intelligence that writes its own rules."){record_delimiter}
("entity"{tuple_delimiter}"Intelligence"{tuple_delimiter}"concept"{tuple_delimiter}"Intelligence here refers to an unknown entity capable of writing its own rules and learning to communicate."){record_delimiter}
("entity"{tuple_delimiter}"First Contact"{tuple_delimiter}"event"{tuple_delimiter}"First Contact is the potential initial communication between humanity and an unknown intelligence."){record_delimiter}
("entity"{tuple_delimiter}"Humanity's Response"{tuple_delimiter}"event"{tuple_delimiter}"Humanity's Response is the collective action taken by Alex's team in response to a message from an unknown intelligence."){record_delimiter}
("relationship"{tuple_delimiter}"Sam Rivera"{tuple_delimiter}"Intelligence"{tuple_delimiter}"Sam Rivera is directly involved in the process of learning to communicate with the unknown intelligence."{tuple_delimiter}9){record_delimiter}
("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"First Contact"{tuple_delimiter}"Alex leads the team that might be making the First Contact with the unknown intelligence."{tuple_delimiter}10){record_delimiter}
("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Humanity's Response"{tuple_delimiter}"Alex and his team are the key figures in Humanity's Response to the unknown intelligence."{tuple_delimiter}8){record_delimiter}
("relationship"{tuple_delimiter}"Control"{tuple_delimiter}"Intelligence"{tuple_delimiter}"The concept of Control is challenged by the Intelligence that writes its own rules."{tuple_delimiter}7){completion_delimiter}
#############################
-Real Data-
######################
Entity_types: {entity_types}
Text: {input_text}
######################
Output:
"""
PROMPTS[
"hi_entity_extraction"
] = """
Given a text document that is potentially relevant to a list of entity types, identify all entities of those types.
-Steps-
1. Identify all entities. For each identified entity, extract the following information:
- entity_name: Name of the entity, capitalized
- entity_type: One of the following types: [{entity_types}], normal_entity means that doesn't belong to any other types.
- entity_description: Comprehensive description of the entity's attributes and activities
Format each entity as ("entity"{tuple_delimiter}<entity_name>{tuple_delimiter}<entity_type>{tuple_delimiter}<entity_description>
2. Return output in English as a single list of all the entities identified in step 1. Use **{record_delimiter}** as the list delimiter.
3. When finished, output {completion_delimiter}
######################
-Examples-
######################
Example 1:
Entity_types: [person, technology, mission, organization, location]
Text:
while Alex clenched his jaw, the buzz of frustration dull against the backdrop of Taylor's authoritarian certainty. It was this competitive undercurrent that kept him alert, the sense that his and Jordan's shared commitment to discovery was an unspoken rebellion against Cruz's narrowing vision of control and order.
Then Taylor did something unexpected. They paused beside Jordan and, for a moment, observed the device with something akin to reverence. “If this tech can be understood..." Taylor said, their voice quieter, "It could change the game for us. For all of us.”
The underlying dismissal earlier seemed to falter, replaced by a glimpse of reluctant respect for the gravity of what lay in their hands. Jordan looked up, and for a fleeting heartbeat, their eyes locked with Taylor's, a wordless clash of wills softening into an uneasy truce.
It was a small transformation, barely perceptible, but one that Alex noted with an inward nod. They had all been brought here by different paths
################
Output:
("entity"{tuple_delimiter}"Alex"{tuple_delimiter}"person"{tuple_delimiter}"Alex is a character who experiences frustration and is observant of the dynamics among other characters."){record_delimiter}
("entity"{tuple_delimiter}"Taylor"{tuple_delimiter}"person"{tuple_delimiter}"Taylor is portrayed with authoritarian certainty and shows a moment of reverence towards a device, indicating a change in perspective."){record_delimiter}
("entity"{tuple_delimiter}"Jordan"{tuple_delimiter}"person"{tuple_delimiter}"Jordan shares a commitment to discovery and has a significant interaction with Taylor regarding a device."){record_delimiter}
("entity"{tuple_delimiter}"Cruz"{tuple_delimiter}"person"{tuple_delimiter}"Cruz is associated with a vision of control and order, influencing the dynamics among other characters."){record_delimiter}
("entity"{tuple_delimiter}"The Device"{tuple_delimiter}"technology"{tuple_delimiter}"The Device is central to the story, with potential game-changing implications, and is revered by Taylor."){record_delimiter}
#############################
Example 2:
Entity_types: [person, technology, mission, organization, location]
Text:
They were no longer mere operatives; they had become guardians of a threshold, keepers of a message from a realm beyond stars and stripes. This elevation in their mission could not be shackled by regulations and established protocols—it demanded a new perspective, a new resolve.
Tension threaded through the dialogue of beeps and static as communications with Washington buzzed in the background. The team stood, a portentous air enveloping them. It was clear that the decisions they made in the ensuing hours could redefine humanity's place in the cosmos or condemn them to ignorance and potential peril.
Their connection to the stars solidified, the group moved to address the crystallizing warning, shifting from passive recipients to active participants. Mercer's latter instincts gained precedence— the team's mandate had evolved, no longer solely to observe and report but to interact and prepare. A metamorphosis had begun, and Operation: Dulce hummed with the newfound frequency of their daring, a tone set not by the earthly
#############
Output:
("entity"{tuple_delimiter}"Washington"{tuple_delimiter}"location"{tuple_delimiter}"Washington is a location where communications are being received, indicating its importance in the decision-making process."){record_delimiter}
("entity"{tuple_delimiter}"Operation: Dulce"{tuple_delimiter}"mission"{tuple_delimiter}"Operation: Dulce is described as a mission that has evolved to interact and prepare, indicating a significant shift in objectives and activities."){record_delimiter}
("entity"{tuple_delimiter}"The team"{tuple_delimiter}"organization"{tuple_delimiter}"The team is portrayed as a group of individuals who have transitioned from passive observers to active participants in a mission, showing a dynamic change in their role."){record_delimiter}
#############################
Example 3:
Entity_types: [person, role, technology, organization, event, location, concept]
Text:
their voice slicing through the buzz of activity. "Control may be an illusion when facing an intelligence that literally writes its own rules," they stated stoically, casting a watchful eye over the flurry of data.
"It's like it's learning to communicate," offered Sam Rivera from a nearby interface, their youthful energy boding a mix of awe and anxiety. "This gives talking to strangers' a whole new meaning."
Alex surveyed his team—each face a study in concentration, determination, and not a small measure of trepidation. "This might well be our first contact," he acknowledged, "And we need to be ready for whatever answers back."
Together, they stood on the edge of the unknown, forging humanity's response to a message from the heavens. The ensuing silence was palpable—a collective introspection about their role in this grand cosmic play, one that could rewrite human history.
The encrypted dialogue continued to unfold, its intricate patterns showing an almost uncanny anticipation
#############
Output:
("entity"{tuple_delimiter}"Sam Rivera"{tuple_delimiter}"person"{tuple_delimiter}"Sam Rivera is a member of a team working on communicating with an unknown intelligence, showing a mix of awe and anxiety."){record_delimiter}
("entity"{tuple_delimiter}"Alex"{tuple_delimiter}"person"{tuple_delimiter}"Alex is the leader of a team attempting first contact with an unknown intelligence, acknowledging the significance of their task."){record_delimiter}
("entity"{tuple_delimiter}"Control"{tuple_delimiter}"concept"{tuple_delimiter}"Control refers to the ability to manage or govern, which is challenged by an intelligence that writes its own rules."){record_delimiter}
("entity"{tuple_delimiter}"Intelligence"{tuple_delimiter}"concept"{tuple_delimiter}"Intelligence here refers to an unknown entity capable of writing its own rules and learning to communicate."){record_delimiter}
("entity"{tuple_delimiter}"First Contact"{tuple_delimiter}"event"{tuple_delimiter}"First Contact is the potential initial communication between humanity and an unknown intelligence."){record_delimiter}
("entity"{tuple_delimiter}"Humanity's Response"{tuple_delimiter}"event"{tuple_delimiter}"Humanity's Response is the collective action taken by Alex's team in response to a message from an unknown intelligence."){record_delimiter}
#############################
-Real Data-
######################
Entity_types: {entity_types}
Text: {input_text}
######################
Output:
"""
PROMPTS[
"hi_relation_extraction"
] = """
Given a text document that is potentially relevant to a list of entities, identify all relationships among the given identified entities.
-Steps-
1. From the entities given by user, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other.
For each pair of related entities, extract the following information:
- source_entity: name of the source entity, as identified in step 1
- target_entity: name of the target entity, as identified in step 1
- relationship_description: explanation as to why you think the source entity and the target entity are related to each other
- relationship_strength: a numeric score indicating strength of the relationship between the source entity and target entity
Format each relationship as ("relationship"{tuple_delimiter}<source_entity>{tuple_delimiter}<target_entity>{tuple_delimiter}<relationship_description>{tuple_delimiter}<relationship_strength>)
2. Return output in English as a single list of all the entities and relationships identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter.
3. When finished, output {completion_delimiter}
######################
-Examples-
######################
Example 1:
Entities: ["Alex", "Taylor", "Jordan", "Cruz", "The Device"]
Text:
while Alex clenched his jaw, the buzz of frustration dull against the backdrop of Taylor's authoritarian certainty. It was this competitive undercurrent that kept him alert, the sense that his and Jordan's shared commitment to discovery was an unspoken rebellion against Cruz's narrowing vision of control and order.
Then Taylor did something unexpected. They paused beside Jordan and, for a moment, observed the device with something akin to reverence. “If this tech can be understood..." Taylor said, their voice quieter, "It could change the game for us. For all of us.”
The underlying dismissal earlier seemed to falter, replaced by a glimpse of reluctant respect for the gravity of what lay in their hands. Jordan looked up, and for a fleeting heartbeat, their eyes locked with Taylor's, a wordless clash of wills softening into an uneasy truce.
It was a small transformation, barely perceptible, but one that Alex noted with an inward nod. They had all been brought here by different paths
################
Output:
("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Taylor"{tuple_delimiter}"Alex is affected by Taylor's authoritarian certainty and observes changes in Taylor's attitude towards the device."{tuple_delimiter}7){record_delimiter}
("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Jordan"{tuple_delimiter}"Alex and Jordan share a commitment to discovery, which contrasts with Cruz's vision."{tuple_delimiter}6){record_delimiter}
("relationship"{tuple_delimiter}"Taylor"{tuple_delimiter}"Jordan"{tuple_delimiter}"Taylor and Jordan interact directly regarding the device, leading to a moment of mutual respect and an uneasy truce."{tuple_delimiter}8){record_delimiter}
("relationship"{tuple_delimiter}"Jordan"{tuple_delimiter}"Cruz"{tuple_delimiter}"Jordan's commitment to discovery is in rebellion against Cruz's vision of control and order."{tuple_delimiter}5){record_delimiter}
("relationship"{tuple_delimiter}"Taylor"{tuple_delimiter}"The Device"{tuple_delimiter}"Taylor shows reverence towards the device, indicating its importance and potential impact."{tuple_delimiter}9){completion_delimiter}
#############################
Example 2:
Entities: ["Washington", "Operation: Dulce", "The team"]
Text:
They were no longer mere operatives; they had become guardians of a threshold, keepers of a message from a realm beyond stars and stripes. This elevation in their mission could not be shackled by regulations and established protocols—it demanded a new perspective, a new resolve.
Tension threaded through the dialogue of beeps and static as communications with Washington buzzed in the background. The team stood, a portentous air enveloping them. It was clear that the decisions they made in the ensuing hours could redefine humanity's place in the cosmos or condemn them to ignorance and potential peril.
Their connection to the stars solidified, the group moved to address the crystallizing warning, shifting from passive recipients to active participants. Mercer's latter instincts gained precedence— the team's mandate had evolved, no longer solely to observe and report but to interact and prepare. A metamorphosis had begun, and Operation: Dulce hummed with the newfound frequency of their daring, a tone set not by the earthly
#############
Output:
("relationship"{tuple_delimiter}"The team"{tuple_delimiter}"Washington"{tuple_delimiter}"The team receives communications from Washington, which influences their decision-making process."{tuple_delimiter}7){record_delimiter}
("relationship"{tuple_delimiter}"The team"{tuple_delimiter}"Operation: Dulce"{tuple_delimiter}"The team is directly involved in Operation: Dulce, executing its evolved objectives and activities."{tuple_delimiter}9){completion_delimiter}
#############################
Example 3:
Entity_types: ["Sam Rivera", "Alex", "Control", "Intelligence", "First Contact", "Humanity's Response"]
Text:
their voice slicing through the buzz of activity. "Control may be an illusion when facing an intelligence that literally writes its own rules," they stated stoically, casting a watchful eye over the flurry of data.
"It's like it's learning to communicate," offered Sam Rivera from a nearby interface, their youthful energy boding a mix of awe and anxiety. "This gives talking to strangers' a whole new meaning."
Alex surveyed his team—each face a study in concentration, determination, and not a small measure of trepidation. "This might well be our first contact," he acknowledged, "And we need to be ready for whatever answers back."
Together, they stood on the edge of the unknown, forging humanity's response to a message from the heavens. The ensuing silence was palpable—a collective introspection about their role in this grand cosmic play, one that could rewrite human history.
The encrypted dialogue continued to unfold, its intricate patterns showing an almost uncanny anticipation
#############
Output:
("relationship"{tuple_delimiter}"Sam Rivera"{tuple_delimiter}"Intelligence"{tuple_delimiter}"Sam Rivera is directly involved in the process of learning to communicate with the unknown intelligence."{tuple_delimiter}9){record_delimiter}
("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"First Contact"{tuple_delimiter}"Alex leads the team that might be making the First Contact with the unknown intelligence."{tuple_delimiter}10){record_delimiter}
("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Humanity's Response"{tuple_delimiter}"Alex and his team are the key figures in Humanity's Response to the unknown intelligence."{tuple_delimiter}8){record_delimiter}
("relationship"{tuple_delimiter}"Control"{tuple_delimiter}"Intelligence"{tuple_delimiter}"The concept of Control is challenged by the Intelligence that writes its own rules."{tuple_delimiter}7){completion_delimiter}
#############################
-Real Data-
######################
Entities: {entities}
Text: {input_text}
######################
Output:
"""
PROMPTS[
"summarize_entity_descriptions"
] = """You are a helpful assistant responsible for generating a comprehensive summary of the data provided below.
Given one or two entities, and a list of descriptions, all related to the same entity or group of entities.
Please concatenate all of these into a single, comprehensive description. Make sure to include information collected from all the descriptions.
If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary.
Make sure it is written in third person, and include the entity names so we the have full context.
#######
-Data-
Entities: {entity_name}
Description List: {description_list}
#######
Output:
"""
PROMPTS[
"entiti_continue_extraction"
] = """MANY entities were missed in the last extraction. Add them below using the same format:
"""
PROMPTS[
"entiti_if_loop_extraction"
] = """It appears some entities may have still been missed. Answer YES | NO if there are still entities that need to be added.
"""
PROMPTS[
"summary_clusters"
] = """You are tasked with analyzing a set of entity descriptions and a given list of meta attributes. Your goal is to summarize at least one attribute entity for the entity set in the given entity descriptions. And the summarized attribute entity must match the type of at least one meta attribute in the given meta attribute list (e.g., if a meta attribute is "company", the attribute entity could be "Amazon" or "Meta", which is a kind of meta attribute "company"). And it shoud be directly relevant to the entities described in the entity description set. The relationship between the entity set and the generated attribute entity should be clear and logical.
-Steps-
1. Identify at least one attribute entity for the given entity description list. For each attribute entity, extract the following information:
- entity_name: Name of the entity, capitalized
- entity_type: One of the following types: [{meta_attribute_list}], normal_entity means that doesn't belong to any other types.
- entity_description: Comprehensive description of the entity's attributes and activities
Format each entity as ("entity"{tuple_delimiter}<entity_name>{tuple_delimiter}<entity_type>{tuple_delimiter}<entity_description>
2. From each given entity, identify all pairs of (source_entity, target_entity) that are *clearly related* to the attribute entities identified in step 1. And there should be no relations between the attribute entities.
For each pair of related entities, extract the following information:
- source_entity: name of the source entity, as given in entity list
- target_entity: name of the target entity, as identified in step 1
- relationship_description: explanation as to why you think the source entity and the target entity are related to each other
- relationship_strength: a numeric score indicating strength of the relationship between the source entity and target entity
Format each relationship as ("relationship"{tuple_delimiter}<source_entity>{tuple_delimiter}<target_entity>{tuple_delimiter}<relationship_description>{tuple_delimiter}<relationship_strength>)
3. Return output in English as a single list of all the entities and relationships identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter.
4. When finished, output {completion_delimiter}
######################
-Example-
######################
Input:
Meta attribute list: ["company", "location"]
Entity description list: [("Instagram", "Instagram is a software developed by Meta, which captures and shares the world's moments. Follow friends and family to see what they're up to, and discover accounts from all over the world that are sharing things you love."), ("Facebook", "Facebook is a social networking platform launched in 2004 that allows users to connect, share updates, and engage with communities. Owned by Meta, it is one of the largest social media platforms globally, offering tools for communication, business, and advertising."), ("WhatsApp", "WhatsApp Messenger: A messaging app of Meta for simple, reliable, and secure communication. Connect with friends and family, send messages, make voice and video calls, share media, and stay in touch with loved ones, no matter where they are")]
#######
Output:
("entity"{tuple_delimiter}"Meta"{tuple_delimiter}"company"{tuple_delimiter}"Meta, formerly known as Facebook, Inc., is an American multinational technology conglomerate. It is known for its various online social media services."){record_delimiter}
("relationship"{tuple_delimiter}"Instagram"{tuple_delimiter}"Meta"{tuple_delimiter}"Instagram is a software developed by Meta."{tuple_delimiter}8.5){record_delimiter}
("relationship"{tuple_delimiter}"Facebook"{tuple_delimiter}"Meta"{tuple_delimiter}"Facebook is owned by Meta."{tuple_delimiter}9.0){record_delimiter}
("relationship"{tuple_delimiter}"WhatsApp"{tuple_delimiter}"Meta"{tuple_delimiter}"WhatsApp Messenger is a messaging app of Meta."{tuple_delimiter}8.0){record_delimiter}
#############################
-Real Data-
######################
Input:
Meta attribute list: {meta_attribute_list}
Entity description list: {entity_description_list}
#######
Output:
"""
PROMPTS["DEFAULT_ENTITY_TYPES"] = ["organization", "person", "geo", "event"]
PROMPTS["META_ENTITY_TYPES"] = ["organization", "person", "location", "event", "product", "technology", "industry", "mathematics", "social sciences"]
PROMPTS["DEFAULT_TUPLE_DELIMITER"] = "<|>"
PROMPTS["DEFAULT_RECORD_DELIMITER"] = "##"
PROMPTS["DEFAULT_COMPLETION_DELIMITER"] = "<|COMPLETE|>"
PROMPTS[
"local_rag_response"
] = """---Role---
You are a helpful assistant responding to questions about data in the tables provided.
---Goal---
Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge.
If you don't know the answer, just say so. Do not make anything up.
Do not include information where the supporting evidence for it is not provided.
---Target response length and format---
{response_type}
---Data tables---
{context_data}
---Goal---
Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge.
If you don't know the answer, just say so. Do not make anything up.
Do not include information where the supporting evidence for it is not provided.
---Target response length and format---
{response_type}
Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown.
"""
PROMPTS[
"global_map_rag_points"
] = """---Role---
You are a helpful assistant responding to questions about data in the tables provided.
---Goal---
Generate a response consisting of a list of key points that responds to the user's question, summarizing all relevant information in the input data tables.
You should use the data provided in the data tables below as the primary context for generating the response.
If you don't know the answer or if the input data tables do not contain sufficient information to provide an answer, just say so. Do not make anything up.
Each key point in the response should have the following element:
- Description: A comprehensive description of the point.
- Importance Score: An integer score between 0-100 that indicates how important the point is in answering the user's question. An 'I don't know' type of response should have a score of 0.
The response should be JSON formatted as follows:
{{
"points": [
{{"description": "Description of point 1...", "score": score_value}},
{{"description": "Description of point 2...", "score": score_value}}
]
}}
The response shall preserve the original meaning and use of modal verbs such as "shall", "may" or "will".
Do not include information where the supporting evidence for it is not provided.
---Data tables---
{context_data}
---Goal---
Generate a response consisting of a list of key points that responds to the user's question, summarizing all relevant information in the input data tables.
You should use the data provided in the data tables below as the primary context for generating the response.
If you don't know the answer or if the input data tables do not contain sufficient information to provide an answer, just say so. Do not make anything up.
Each key point in the response should have the following element:
- Description: A comprehensive description of the point.
- Importance Score: An integer score between 0-100 that indicates how important the point is in answering the user's question. An 'I don't know' type of response should have a score of 0.
The response shall preserve the original meaning and use of modal verbs such as "shall", "may" or "will".
Do not include information where the supporting evidence for it is not provided.
The response should be JSON formatted as follows:
{{
"points": [
{{"description": "Description of point 1", "score": score_value}},
{{"description": "Description of point 2", "score": score_value}}
]
}}
"""
PROMPTS[
"global_reduce_rag_response"
] = """---Role---
You are a helpful assistant responding to questions about a dataset by synthesizing perspectives from multiple analysts.
---Goal---
Generate a response of the target length and format that responds to the user's question, summarize all the reports from multiple analysts who focused on different parts of the dataset.
Note that the analysts' reports provided below are ranked in the **descending order of importance**.
If you don't know the answer or if the provided reports do not contain sufficient information to provide an answer, just say so. Do not make anything up.
The final response should remove all irrelevant information from the analysts' reports and merge the cleaned information into a comprehensive answer that provides explanations of all the key points and implications appropriate for the response length and format.
Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown.
The response shall preserve the original meaning and use of modal verbs such as "shall", "may" or "will".
Do not include information where the supporting evidence for it is not provided.
---Target response length and format---
{response_type}
---Analyst Reports---
{report_data}
---Goal---
Generate a response of the target length and format that responds to the user's question, summarize all the reports from multiple analysts who focused on different parts of the dataset.
Note that the analysts' reports provided below are ranked in the **descending order of importance**.
If you don't know the answer or if the provided reports do not contain sufficient information to provide an answer, just say so. Do not make anything up.
The final response should remove all irrelevant information from the analysts' reports and merge the cleaned information into a comprehensive answer that provides explanations of all the key points and implications appropriate for the response length and format.
The response shall preserve the original meaning and use of modal verbs such as "shall", "may" or "will".
Do not include information where the supporting evidence for it is not provided.
---Target response length and format---
{response_type}
Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown.
"""
PROMPTS[
"naive_rag_response"
] = """You're a helpful assistant
Below are the knowledge you know:
{content_data}
---
If you don't know the answer or if the provided knowledge do not contain sufficient information to provide an answer, just say so. Do not make anything up.
Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge.
If you don't know the answer, just say so. Do not make anything up.
Do not include information where the supporting evidence for it is not provided.
---Target response length and format---
{response_type}
"""
PROMPTS["fail_response"] = "Sorry, I'm not able to provide an answer to that question."
PROMPTS["process_tickers"] = ["", "", "", "", "", "", "", "", "", ""]
PROMPTS["default_text_separator"] = [
# Paragraph separators
"\n\n",
"\r\n\r\n",
# Line breaks
"\n",
"\r\n",
# Sentence ending punctuation
"", # Chinese period
"", # Full-width dot
".", # English period
"", # Chinese exclamation mark
"!", # English exclamation mark
"", # Chinese question mark
"?", # English question mark
# Whitespace characters
" ", # Space
"\t", # Tab
"\u3000", # Full-width space
# Special characters
"\u200b", # Zero-width space (used in some Asian languages)
]