add bulk temporal extraction and improve bulk quality and performance (#67)

* parallelize edge deduping more

* parallelize node insertion more

* improve bulk behavior performance

* dedupe nodes actually works

* add a reranker to search

* bulk dedupe episodes only across the same nodes

* add temporal extraction bulk function

* cleaned up bulk

* default to 4o

* format

* mypy

* mypy

* mypy ignore
Author: Preston Rasmussen
Date: 2024-08-30 10:48:28 -04:00
Committed by: GitHub
Parent: aac06d9d24
Commit: 35a4e5172b
8 changed files with 203 additions and 61 deletions


@@ -29,6 +29,7 @@ from graphiti_core.llm_client.utils import generate_embedding
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.search.search import Reranker, SearchConfig, SearchMethod, hybrid_search
from graphiti_core.search.search_utils import (
RELEVANT_SCHEMA_LIMIT,
get_relevant_edges,
get_relevant_nodes,
hybrid_node_search,
@@ -41,6 +42,7 @@ from graphiti_core.utils.bulk_utils import (
RawEpisode,
dedupe_edges_bulk,
dedupe_nodes_bulk,
extract_edge_dates_bulk,
extract_nodes_and_edges_bulk,
resolve_edge_pointers,
retrieve_previous_episodes_bulk,
@@ -319,26 +321,24 @@ class Graphiti:
valid_at, invalid_at, _ = await extract_edge_dates(
self.llm_client,
edge,
episode.valid_at,
episode,
previous_episodes,
)
edge.valid_at = valid_at
edge.invalid_at = invalid_at
if edge.invalid_at:
edge.expired_at = datetime.now()
edge.expired_at = now
for edge in existing_edges:
valid_at, invalid_at, _ = await extract_edge_dates(
self.llm_client,
edge,
episode.valid_at,
episode,
previous_episodes,
)
edge.valid_at = valid_at
edge.invalid_at = invalid_at
if edge.invalid_at:
edge.expired_at = datetime.now()
edge.expired_at = now
(
old_edges_with_nodes_pending_invalidation,
new_edges_with_nodes,
@@ -481,15 +481,18 @@ class Graphiti:
*[edge.generate_embedding(embedder) for edge in extracted_edges],
)
# Dedupe extracted nodes
nodes, uuid_map = await dedupe_nodes_bulk(self.driver, self.llm_client, extracted_nodes)
# Dedupe extracted nodes, extract dates for extracted edges
(nodes, uuid_map), extracted_edges_timestamped = await asyncio.gather(
dedupe_nodes_bulk(self.driver, self.llm_client, extracted_nodes),
extract_edge_dates_bulk(self.llm_client, extracted_edges, episode_pairs),
)
# save nodes to KG
await asyncio.gather(*[node.save(self.driver) for node in nodes])
# re-map edge pointers so that they don't point to discarded duplicate nodes
extracted_edges_with_resolved_pointers: list[EntityEdge] = resolve_edge_pointers(
extracted_edges, uuid_map
extracted_edges_timestamped, uuid_map
)
episodic_edges_with_resolved_pointers: list[EpisodicEdge] = resolve_edge_pointers(
episodic_edges, uuid_map
@@ -579,7 +582,9 @@ class Graphiti:
self.driver, self.llm_client.get_embedder(), query, timestamp, config, center_node_uuid
)
async def get_nodes_by_query(self, query: str, limit: int | None = None) -> list[EntityNode]:
async def get_nodes_by_query(
self, query: str, limit: int = RELEVANT_SCHEMA_LIMIT
) -> list[EntityNode]:
"""
Retrieve nodes from the graph database based on a text query.
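As an aside on the bulk hunk above: the bulk path now runs node deduplication and edge date extraction concurrently, then re-points edges at the surviving node uuids. A minimal sketch of that pattern, assuming the bulk helpers imported at the top of this file; `dedupe_and_timestamp` itself is a hypothetical wrapper, not part of the library:

```python
import asyncio

from graphiti_core.utils.bulk_utils import (
    dedupe_nodes_bulk,
    extract_edge_dates_bulk,
    resolve_edge_pointers,
)


async def dedupe_and_timestamp(driver, llm_client, extracted_nodes, extracted_edges, episode_pairs):
    # Node dedup and edge date extraction are independent, so they can run concurrently.
    (nodes, uuid_map), timestamped_edges = await asyncio.gather(
        dedupe_nodes_bulk(driver, llm_client, extracted_nodes),
        extract_edge_dates_bulk(llm_client, extracted_edges, episode_pairs),
    )
    # Re-point edges so they reference surviving nodes rather than discarded duplicates.
    return nodes, resolve_edge_pointers(timestamped_edges, uuid_map)
```

The two coroutines only share already-extracted data, so `asyncio.gather` lets their LLM calls overlap.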


@@ -53,7 +53,9 @@ def v1(context: dict[str, Any]) -> list[Message]:
1. start with the list of nodes from New Nodes
2. If any node in New Nodes is a duplicate of a node in Existing Nodes, replace the new node with the existing
node in the list
3. Respond with the resulting list of nodes
3. when deduplicating nodes, synthesize their summaries into a short new summary that contains the relevant information
of the summaries of the new and existing nodes
4. Respond with the resulting list of nodes
Guidelines:
1. Use both the name and summary of nodes to determine if they are duplicates,
@@ -64,6 +66,7 @@ def v1(context: dict[str, Any]) -> list[Message]:
"new_nodes": [
{{
"name": "Unique identifier for the node",
"summary": "Brief summary of the node's role or significance"
}}
]
}}
@@ -92,6 +95,8 @@ def v2(context: dict[str, Any]) -> list[Message]:
If a node in the new nodes is describing the same entity as a node in the existing nodes, mark it as a duplicate!!!
Task:
If any node in New Nodes is a duplicate of a node in Existing Nodes, add their names to the output list
When finding duplicate nodes, synthesize their summaries into a short new summary that contains the
relevant information of the summaries of the new and existing nodes.
Guidelines:
1. Use both the name and summary of nodes to determine if they are duplicates,
@@ -104,7 +109,8 @@ def v2(context: dict[str, Any]) -> list[Message]:
"duplicates": [
{{
"name": "name of the new node",
"duplicate_of": "name of the existing node"
"duplicate_of": "name of the existing node",
"summary": "Brief summary of the node's role or significance. Takes information from the new and existing nodes"
}}
]
}}
@@ -130,6 +136,7 @@ def node_list(context: dict[str, Any]) -> list[Message]:
Task:
1. Group nodes together such that all duplicate nodes are in the same list of names
2. All duplicate names should be grouped together in the same list
3. Also return a new short summary that synthesizes the summaries of the grouped nodes
Guidelines:
1. Each name from the list of nodes should appear EXACTLY once in your response
@@ -140,6 +147,7 @@ def node_list(context: dict[str, Any]) -> list[Message]:
"nodes": [
{{
"names": ["myNode", "node that is a duplicate of myNode"],
"summary": "Brief summary of the node summaries that appear in the list of names."
}}
]
}}
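For reference, a response matching the updated `node_list` schema might look like the following (illustrative only, not captured model output):

```python
# Illustrative response for the updated node_list prompt; the summary field is the
# new synthesized summary of the grouped duplicate nodes.
example_response = {
    'nodes': [
        {
            'names': ['myNode', 'node that is a duplicate of myNode'],
            'summary': 'Short summary combining the information from both grouped nodes.',
        }
    ]
}
```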


@@ -110,10 +110,11 @@ def v2(context: dict[str, Any]) -> list[Message]:
Guidelines:
1. Create edges only between the provided nodes.
2. Each edge should represent a clear relationship between two nodes.
2. Each edge should represent a clear relationship between two DISTINCT nodes.
3. The relation_type should be a concise, all-caps description of the relationship (e.g., LOVES, IS_FRIENDS_WITH, WORKS_FOR).
4. Provide a more detailed fact describing the relationship.
5. Consider temporal aspects of relationships when relevant.
6. Avoid using the same node as the source and target of a relationship
Respond with a JSON object in the following format:
{{


@@ -63,12 +63,12 @@ class SearchResults(BaseModel):
async def hybrid_search(
driver: AsyncDriver,
embedder,
query: str,
timestamp: datetime,
config: SearchConfig,
center_node_uuid: str | None = None,
driver: AsyncDriver,
embedder,
query: str,
timestamp: datetime,
config: SearchConfig,
center_node_uuid: str | None = None,
) -> SearchResults:
start = time()


@@ -268,13 +268,13 @@ async def hybrid_node_search(
queries: list[str],
embeddings: list[list[float]],
driver: AsyncDriver,
limit: int | None = None,
limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]:
"""
Perform a hybrid search for nodes using both text queries and embeddings.
This method combines fulltext search and vector similarity search to find
relevant nodes in the graph database.
relevant nodes in the graph database. Results are reranked with RRF (reciprocal rank fusion).
Parameters
----------
@@ -307,27 +307,25 @@ async def hybrid_node_search(
"""
start = time()
relevant_nodes: list[EntityNode] = []
relevant_node_uuids = set()
results = await asyncio.gather(
*[entity_fulltext_search(q, driver, 2 * (limit or RELEVANT_SCHEMA_LIMIT)) for q in queries],
*[
entity_similarity_search(e, driver, 2 * (limit or RELEVANT_SCHEMA_LIMIT))
for e in embeddings
],
results: list[list[EntityNode]] = list(
await asyncio.gather(
*[entity_fulltext_search(q, driver, 2 * limit) for q in queries],
*[entity_similarity_search(e, driver, 2 * limit) for e in embeddings],
)
)
for result in results:
for node in result:
if node.uuid in relevant_node_uuids:
continue
node_uuid_map: dict[str, EntityNode] = {
node.uuid: node for result in results for node in result
}
result_uuids = [[node.uuid for node in result] for result in results]
relevant_node_uuids.add(node.uuid)
relevant_nodes.append(node)
ranked_uuids = rrf(result_uuids)
relevant_nodes: list[EntityNode] = [node_uuid_map[uuid] for uuid in ranked_uuids]
end = time()
logger.info(f'Found relevant nodes: {relevant_node_uuids} in {(end - start) * 1000} ms')
logger.info(f'Found relevant nodes: {ranked_uuids} in {(end - start) * 1000} ms')
return relevant_nodes
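The reranking step relies on the `rrf` helper used above. Reciprocal rank fusion scores each uuid by summing 1 / (k + rank) over every result list it appears in; a minimal sketch of the idea (the library's actual `rrf` implementation and its rank constant may differ):

```python
from collections import defaultdict


def rrf_sketch(result_lists: list[list[str]], k: int = 60) -> list[str]:
    # Each uuid earns 1 / (k + rank) from every result list it appears in,
    # so uuids ranked highly by several searches rise to the top.
    scores: dict[str, float] = defaultdict(float)
    for uuids in result_lists:
        for rank, uuid in enumerate(uuids, start=1):
            scores[uuid] += 1.0 / (k + rank)
    return sorted(scores, key=lambda uuid: scores[uuid], reverse=True)
```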


@@ -15,11 +15,14 @@ limitations under the License.
"""
import asyncio
import logging
import typing
from collections import defaultdict
from datetime import datetime
from math import ceil
from neo4j import AsyncDriver
from numpy import dot
from numpy import dot, sqrt
from pydantic import BaseModel
from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge
@@ -39,8 +42,11 @@ from graphiti_core.utils.maintenance.node_operations import (
dedupe_node_list,
extract_nodes,
)
from graphiti_core.utils.maintenance.temporal_operations import extract_edge_dates
CHUNK_SIZE = 15
logger = logging.getLogger(__name__)
CHUNK_SIZE = 10
class RawEpisode(BaseModel):
@@ -52,7 +58,7 @@ class RawEpisode(BaseModel):
async def retrieve_previous_episodes_bulk(
driver: AsyncDriver, episodes: list[EpisodicNode]
driver: AsyncDriver, episodes: list[EpisodicNode]
) -> list[tuple[EpisodicNode, list[EpisodicNode]]]:
previous_episodes_list = await asyncio.gather(
*[
@@ -68,7 +74,7 @@ async def retrieve_previous_episodes_bulk(
async def extract_nodes_and_edges_bulk(
llm_client: LLMClient, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]]
llm_client: LLMClient, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]]
) -> tuple[list[EntityNode], list[EntityEdge], list[EpisodicEdge]]:
extracted_nodes_bulk = await asyncio.gather(
*[
@@ -105,36 +111,67 @@ async def extract_nodes_and_edges_bulk(
async def dedupe_nodes_bulk(
driver: AsyncDriver,
llm_client: LLMClient,
extracted_nodes: list[EntityNode],
driver: AsyncDriver,
llm_client: LLMClient,
extracted_nodes: list[EntityNode],
) -> tuple[list[EntityNode], dict[str, str]]:
# Compress nodes
nodes, uuid_map = node_name_match(extracted_nodes)
compressed_nodes, compressed_map = await compress_nodes(llm_client, nodes, uuid_map)
existing_nodes = await get_relevant_nodes(compressed_nodes, driver)
node_chunks = [nodes[i: i + CHUNK_SIZE] for i in range(0, len(nodes), CHUNK_SIZE)]
nodes, partial_uuid_map, _ = await dedupe_extracted_nodes(
llm_client, compressed_nodes, existing_nodes
existing_nodes_chunks: list[list[EntityNode]] = list(
await asyncio.gather(
*[get_relevant_nodes(node_chunk, driver) for node_chunk in node_chunks]
)
)
compressed_map.update(partial_uuid_map)
results: list[tuple[list[EntityNode], dict[str, str], list[EntityNode]]] = list(
await asyncio.gather(
*[
dedupe_extracted_nodes(llm_client, node_chunk, existing_nodes_chunks[i])
for i, node_chunk in enumerate(node_chunks)
]
)
)
return nodes, compressed_map
final_nodes: list[EntityNode] = []
for result in results:
final_nodes.extend(result[0])
partial_uuid_map = result[1]
compressed_map.update(partial_uuid_map)
return final_nodes, compressed_map
async def dedupe_edges_bulk(
driver: AsyncDriver, llm_client: LLMClient, extracted_edges: list[EntityEdge]
driver: AsyncDriver, llm_client: LLMClient, extracted_edges: list[EntityEdge]
) -> list[EntityEdge]:
# Compress edges
# First compress edges
compressed_edges = await compress_edges(llm_client, extracted_edges)
existing_edges = await get_relevant_edges(compressed_edges, driver)
edge_chunks = [
compressed_edges[i: i + CHUNK_SIZE] for i in range(0, len(compressed_edges), CHUNK_SIZE)
]
edges = await dedupe_extracted_edges(llm_client, compressed_edges, existing_edges)
relevant_edges_chunks: list[list[EntityEdge]] = list(
await asyncio.gather(
*[get_relevant_edges(edge_chunk, driver) for edge_chunk in edge_chunks]
)
)
resolved_edge_chunks: list[list[EntityEdge]] = list(
await asyncio.gather(
*[
dedupe_extracted_edges(llm_client, edge_chunk, relevant_edges_chunks[i])
for i, edge_chunk in enumerate(edge_chunks)
]
)
)
edges = [edge for edge_chunk in resolved_edge_chunks for edge in edge_chunk]
return edges
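Both bulk dedupe functions above follow the same shape: slice the input into chunks, fetch the relevant existing records per chunk, and resolve each chunk concurrently with `asyncio.gather`. A generic sketch of that chunk-and-gather pattern; `process_in_chunks` is illustrative, not library code:

```python
import asyncio
from typing import Awaitable, Callable, TypeVar

T = TypeVar('T')
R = TypeVar('R')


async def process_in_chunks(
    items: list[T],
    chunk_size: int,
    worker: Callable[[list[T]], Awaitable[list[R]]],
) -> list[R]:
    # Split the input into fixed-size chunks, run one worker coroutine per chunk
    # concurrently, then flatten the per-chunk results back into a single list.
    chunks = [items[i : i + chunk_size] for i in range(0, len(items), chunk_size)]
    results = await asyncio.gather(*[worker(chunk) for chunk in chunks])
    return [item for chunk_result in results for item in chunk_result]
```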
@@ -152,15 +189,60 @@ def node_name_match(nodes: list[EntityNode]) -> tuple[list[EntityNode], dict[str
async def compress_nodes(
llm_client: LLMClient, nodes: list[EntityNode], uuid_map: dict[str, str]
llm_client: LLMClient, nodes: list[EntityNode], uuid_map: dict[str, str]
) -> tuple[list[EntityNode], dict[str, str]]:
# We want to first compress the nodes by deduplicating nodes across each of the episodes added in bulk
if len(nodes) == 0:
return nodes, uuid_map
anchor = nodes[0]
nodes.sort(key=lambda node: dot(anchor.name_embedding or [], node.name_embedding or []))
# Our approach deduplicates chunks of nodes in parallel.
# We want n chunks of size n so that n ** 2 == len(nodes).
# Chunk sizes should be at least CHUNK_SIZE to keep each LLM call efficient.
chunk_size = max(int(sqrt(len(nodes))), CHUNK_SIZE)
node_chunks = [nodes[i : i + CHUNK_SIZE] for i in range(0, len(nodes), CHUNK_SIZE)]
# First calculate similarity scores between nodes
similarity_scores: list[tuple[int, int, float]] = [
(i, j, dot(n.name_embedding or [], m.name_embedding or []))
for i, n in enumerate(nodes)
for j, m in enumerate(nodes[:i])
]
# We now sort by semantic similarity
similarity_scores.sort(key=lambda score_tuple: score_tuple[2])
# initialize our chunks based on chunk size
node_chunks: list[list[EntityNode]] = [[] for _ in range(ceil(len(nodes) / chunk_size))]
# Draft the most similar nodes into the same chunk
while len(similarity_scores) > 0:
i, j, _ = similarity_scores.pop()
# determine if any of the nodes have already been drafted into a chunk
n = nodes[i]
m = nodes[j]
# make sure the shortest chunks get preference
node_chunks.sort(reverse=True, key=lambda chunk: len(chunk))
n_chunk = max([i if n in chunk else -1 for i, chunk in enumerate(node_chunks)])
m_chunk = max([i if m in chunk else -1 for i, chunk in enumerate(node_chunks)])
# both nodes already in a chunk
if n_chunk > -1 and m_chunk > -1:
continue
# n has a chunk and that chunk is not full
elif n_chunk > -1 and len(node_chunks[n_chunk]) < chunk_size:
# put m in the same chunk as n
node_chunks[n_chunk].append(m)
# m has a chunk and that chunk is not full
elif m_chunk > -1 and len(node_chunks[m_chunk]) < chunk_size:
# put n in the same chunk as m
node_chunks[m_chunk].append(n)
# neither node has a chunk or the chunk is full
else:
# add both nodes to the shortest chunk
node_chunks[-1].extend([n, m])
results = await asyncio.gather(*[dedupe_node_list(llm_client, chunk) for chunk in node_chunks])
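The drafting loop above greedily places the most similar nodes into the same chunk so that likely duplicates end up in the same LLM call. A simplified, self-contained sketch of the idea; `draft_chunks_by_similarity` is illustrative, assumes nodes carry name embeddings, and its placement rules are not identical to the code above:

```python
from math import ceil

from numpy import dot


def draft_chunks_by_similarity(nodes: list, chunk_size: int) -> list[list]:
    # Pairwise similarity between node name embeddings, most similar pair first.
    pairs = sorted(
        (
            (i, j, dot(nodes[i].name_embedding or [], nodes[j].name_embedding or []))
            for i in range(len(nodes))
            for j in range(i)
        ),
        key=lambda t: t[2],
        reverse=True,
    )
    chunks: list[list[int]] = [[] for _ in range(ceil(len(nodes) / chunk_size))]
    assigned: dict[int, int] = {}  # node index -> chunk index

    def place(idx: int, preferred: int | None) -> None:
        # Join the partner's chunk when it has room, otherwise fall back to the emptiest chunk.
        target = preferred if preferred is not None and len(chunks[preferred]) < chunk_size else None
        if target is None:
            target = min(range(len(chunks)), key=lambda c: len(chunks[c]))
        chunks[target].append(idx)
        assigned[idx] = target

    for i, j, _ in pairs:
        if i not in assigned:
            place(i, assigned.get(j))
        if j not in assigned:
            place(j, assigned.get(i))
    for idx in range(len(nodes)):  # isolated nodes (never paired) still need a chunk
        if idx not in assigned:
            place(idx, None)
    return [[nodes[idx] for idx in chunk] for chunk in chunks if chunk]
```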
@@ -181,13 +263,21 @@ async def compress_nodes(
async def compress_edges(llm_client: LLMClient, edges: list[EntityEdge]) -> list[EntityEdge]:
if len(edges) == 0:
return edges
# We only want to dedupe edges that are between the same pair of nodes
# We build a map of the edges based on their source and target nodes.
edge_chunk_map: dict[str, list[EntityEdge]] = defaultdict(list)
for edge in edges:
# We drop loop edges
if edge.source_node_uuid == edge.target_node_uuid:
continue
anchor = edges[0]
edges.sort(
key=lambda embedding: dot(anchor.fact_embedding or [], embedding.fact_embedding or [])
)
# Keep the order of the two nodes consistent; we want to be direction-agnostic during edge resolution
pointers = [edge.source_node_uuid, edge.target_node_uuid]
pointers.sort()
edge_chunks = [edges[i : i + CHUNK_SIZE] for i in range(0, len(edges), CHUNK_SIZE)]
edge_chunk_map[pointers[0] + pointers[1]].append(edge)
edge_chunks = [chunk for chunk in edge_chunk_map.values()]
results = await asyncio.gather(*[dedupe_edge_list(llm_client, chunk) for chunk in edge_chunks])
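Distilled, the grouping above buckets edges under a direction-agnostic key built from their endpoint uuids and drops self-loops, so only facts between the same pair of nodes are deduped against each other:

```python
from collections import defaultdict

from graphiti_core.edges import EntityEdge


def group_edges_by_node_pair(edges: list[EntityEdge]) -> list[list[EntityEdge]]:
    edge_chunk_map: dict[str, list[EntityEdge]] = defaultdict(list)
    for edge in edges:
        if edge.source_node_uuid == edge.target_node_uuid:
            continue  # drop loop edges
        # Sort the endpoint uuids so A->B and B->A land in the same bucket.
        key = ''.join(sorted((edge.source_node_uuid, edge.target_node_uuid)))
        edge_chunk_map[key].append(edge)
    return list(edge_chunk_map.values())
```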
@@ -225,3 +315,43 @@ def resolve_edge_pointers(edges: list[E], uuid_map: dict[str, str]):
edge.target_node_uuid = uuid_map.get(target_uuid, target_uuid)
return edges
async def extract_edge_dates_bulk(
llm_client: LLMClient,
extracted_edges: list[EntityEdge],
episode_pairs: list[tuple[EpisodicNode, list[EpisodicNode]]],
) -> list[EntityEdge]:
edges: list[EntityEdge] = []
# keep only edges that reference at least one episode
for edge in extracted_edges:
if edge.episodes is not None and len(edge.episodes) > 0:
edges.append(edge)
episode_uuid_map: dict[str, tuple[EpisodicNode, list[EpisodicNode]]] = {
episode.uuid: (episode, previous_episodes) for episode, previous_episodes in episode_pairs
}
results = await asyncio.gather(
*[
extract_edge_dates(
llm_client,
edge,
episode_uuid_map[edge.episodes[0]][0], # type: ignore
episode_uuid_map[edge.episodes[0]][1], # type: ignore
)
for edge in edges
]
)
for i, result in enumerate(results):
valid_at = result[0]
invalid_at = result[1]
edge = edges[i]
edge.valid_at = valid_at
edge.invalid_at = invalid_at
if edge.invalid_at:
edge.expired_at = datetime.now()
return edges


@@ -189,6 +189,7 @@ async def dedupe_node_list(
uuid_map: dict[str, str] = {}
for node_data in nodes_data:
node = node_map[node_data['names'][0]]
node.summary = node_data['summary']
unique_nodes.append(node)
for name in node_data['names'][1:]:


@@ -147,7 +147,6 @@ def process_edge_invalidation_llm_response(
async def extract_edge_dates(
llm_client: LLMClient,
edge: EntityEdge,
reference_time: datetime,
current_episode: EpisodicNode,
previous_episodes: List[EpisodicNode],
) -> tuple[datetime | None, datetime | None, str]:
@@ -156,7 +155,7 @@ async def extract_edge_dates(
'edge_fact': edge.fact,
'current_episode': current_episode.content,
'previous_episodes': [ep.content for ep in previous_episodes],
'reference_timestamp': reference_time.isoformat(),
'reference_timestamp': current_episode.valid_at.isoformat(),
}
llm_response = await llm_client.generate_response(prompt_library.extract_edge_dates.v1(context))
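With `reference_time` removed from the signature, the reference timestamp is now read from `current_episode.valid_at`, so call sites pass one fewer argument (as in the Graphiti hunk earlier). A call-site sketch with placeholder variable names:

```python
# valid_at/invalid_at are applied to the edge exactly as in the add_episode flow above.
valid_at, invalid_at, _ = await extract_edge_dates(
    llm_client,
    edge,
    episode,             # reference timestamp comes from episode.valid_at
    previous_episodes,
)
edge.valid_at = valid_at
edge.invalid_at = invalid_at
```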