search update (#81)

* search update

* update string literals
Authored by Preston Rasmussen on 2024-09-04 10:05:45 -04:00; committed by GitHub
parent 2b6adb5279
commit e56a599a72
7 changed files with 212 additions and 88 deletions

View File

@@ -86,4 +86,4 @@ async def main(use_bulk: bool = True):
await client.add_episode_bulk(episodes)
-asyncio.run(main(False))
+asyncio.run(main(True))

View File

@@ -180,9 +180,9 @@ class Graphiti:
await build_indices_and_constraints(self.driver)
async def retrieve_episodes(
-self,
-reference_time: datetime,
-last_n: int = EPISODE_WINDOW_LEN,
+self,
+reference_time: datetime,
+last_n: int = EPISODE_WINDOW_LEN,
) -> list[EpisodicNode]:
"""
Retrieve the last n episodic nodes from the graph.
@@ -210,14 +210,14 @@ class Graphiti:
return await retrieve_episodes(self.driver, reference_time, last_n)
async def add_episode(
-self,
-name: str,
-episode_body: str,
-source_description: str,
-reference_time: datetime,
-source: EpisodeType = EpisodeType.message,
-success_callback: Callable | None = None,
-error_callback: Callable | None = None,
+self,
+name: str,
+episode_body: str,
+source_description: str,
+reference_time: datetime,
+source: EpisodeType = EpisodeType.message,
+success_callback: Callable | None = None,
+error_callback: Callable | None = None,
):
"""
Process an episode and update the graph.
@@ -321,11 +321,11 @@ class Graphiti:
await asyncio.gather(
*[
get_relevant_edges(
-[edge],
self.driver,
-RELEVANT_SCHEMA_LIMIT,
+[edge],
+edge.source_node_uuid,
+edge.target_node_uuid,
+RELEVANT_SCHEMA_LIMIT,
)
for edge in extracted_edges
]
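Note the argument reorder above: the driver now leads, followed by the single-edge list, that edge's own endpoint UUIDs, and the limit, matching the reworked get_relevant_edges signature later in this commit. A runnable toy sketch of the per-edge fan-out pattern used here (the graph search is stubbed; names in the sketch are illustrative only):

import asyncio

async def fake_get_relevant_edges(driver, edges, source_uuid, target_uuid, limit=10):
    # Stand-in for the real Neo4j-backed search.
    await asyncio.sleep(0)
    return [f'{source_uuid}->{target_uuid}']

async def demo():
    edges = [('a', 'b'), ('b', 'c')]
    # One endpoint-constrained search per extracted edge, awaited concurrently.
    results = await asyncio.gather(
        *[fake_get_relevant_edges(None, [e], e[0], e[1]) for e in edges]
    )
    print(results)  # [['a->b'], ['b->c']]

asyncio.run(demo())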
@@ -422,8 +422,8 @@ class Graphiti:
raise e
async def add_episode_bulk(
-self,
-bulk_episodes: list[RawEpisode],
+self,
+bulk_episodes: list[RawEpisode],
):
"""
Process multiple episodes in bulk and update the graph.
@@ -587,18 +587,18 @@ class Graphiti:
return edges
async def _search(
-self,
-query: str,
-timestamp: datetime,
-config: SearchConfig,
-center_node_uuid: str | None = None,
+self,
+query: str,
+timestamp: datetime,
+config: SearchConfig,
+center_node_uuid: str | None = None,
):
return await hybrid_search(
self.driver, self.llm_client.get_embedder(), query, timestamp, config, center_node_uuid
)
async def get_nodes_by_query(
-self, query: str, limit: int = RELEVANT_SCHEMA_LIMIT
+self, query: str, limit: int = RELEVANT_SCHEMA_LIMIT
) -> list[EntityNode]:
"""
Retrieve nodes from the graph database based on a text query.

View File

@@ -83,7 +83,7 @@ async def hybrid_search(
nodes.extend(await get_mentioned_nodes(driver, episodes))
if SearchMethod.bm25 in config.search_methods:
-text_search = await edge_fulltext_search(driver, query, 2 * config.num_edges)
+text_search = await edge_fulltext_search(driver, query, None, None, 2 * config.num_edges)
search_results.append(text_search)
if SearchMethod.cosine_similarity in config.search_methods:
@@ -95,7 +95,7 @@ async def hybrid_search(
)
similarity_search = await edge_similarity_search(
-driver, search_vector, 2 * config.num_edges
+driver, search_vector, None, None, 2 * config.num_edges
)
search_results.append(similarity_search)
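Both hybrid_search call sites now pass the optional endpoint UUIDs positionally; None in either slot leaves that side of the edge unconstrained. For reference, the updated signatures as they appear in the search-utility hunks later in this commit:

async def edge_fulltext_search(
    driver: AsyncDriver,
    query: str,
    source_node_uuid: str | None,
    target_node_uuid: str | None,
    limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EntityEdge]: ...

async def edge_similarity_search(
    driver: AsyncDriver,
    search_vector: list[float],
    source_node_uuid: str | None,
    target_node_uuid: str | None,
    limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EntityEdge]: ...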

View File

@@ -1,11 +1,11 @@
import asyncio
import logging
import re
-import typing
from collections import defaultdict
from time import time
+from typing import Any
-from neo4j import AsyncDriver
+from neo4j import AsyncDriver, Query
from graphiti_core.edges import EntityEdge
from graphiti_core.helpers import parse_db_date
@@ -66,12 +66,12 @@ async def bfs(node_ids: list[str], driver: AsyncDriver):
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at
""",
node_ids=node_ids,
)
-context: dict[str, typing.Any] = {}
+context: dict[str, Any] = {}
for record in records:
n_uuid = record['source_node_uuid']
@@ -96,15 +96,14 @@ async def bfs(node_ids: list[str], driver: AsyncDriver):
async def edge_similarity_search(
-driver: AsyncDriver,
-search_vector: list[float],
-limit: int = RELEVANT_SCHEMA_LIMIT,
-source_node_uuid: str = '*',
-target_node_uuid: str = '*',
+driver: AsyncDriver,
+search_vector: list[float],
+source_node_uuid: str | None,
+target_node_uuid: str | None,
+limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EntityEdge]:
# vector similarity search over embedded facts
-records, _, _ = await driver.execute_query(
-"""
+query = Query("""
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
YIELD relationship AS rel, score
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
@@ -121,7 +120,68 @@ async def edge_similarity_search(
r.valid_at AS valid_at,
r.invalid_at AS invalid_at
ORDER BY score DESC
""",
""")
if source_node_uuid is None and target_node_uuid is None:
query = Query("""
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
YIELD relationship AS rel, score
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
RETURN
r.uuid AS uuid,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
r.created_at AS created_at,
r.name AS name,
r.fact AS fact,
r.fact_embedding AS fact_embedding,
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at
ORDER BY score DESC
""")
elif source_node_uuid is None:
query = Query("""
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
YIELD relationship AS rel, score
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
RETURN
r.uuid AS uuid,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
r.created_at AS created_at,
r.name AS name,
r.fact AS fact,
r.fact_embedding AS fact_embedding,
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at
ORDER BY score DESC
""")
elif target_node_uuid is None:
query = Query("""
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
YIELD relationship AS rel, score
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity)
RETURN
r.uuid AS uuid,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
r.created_at AS created_at,
r.name AS name,
r.fact AS fact,
r.fact_embedding AS fact_embedding,
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at
ORDER BY score DESC
""")
records, _, _ = await driver.execute_query(
query,
search_vector=search_vector,
source_uuid=source_node_uuid,
target_uuid=target_node_uuid,
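The four Query bodies above are identical except for the MATCH pattern, which is selected by which endpoint UUIDs are provided. A hypothetical, runnable sketch of an equivalent table-driven dispatch (not part of this commit, shown only to make the branching explicit):

def match_pattern(source_uuid, target_uuid):
    # Pin an endpoint only when its UUID is supplied; otherwise match any node.
    left = '(n:Entity)' if source_uuid is None else '(n:Entity {uuid: $source_uuid})'
    right = '(m:Entity)' if target_uuid is None else '(m:Entity {uuid: $target_uuid})'
    return f'MATCH {left}-[r {{uuid: rel.uuid}}]-{right}'

print(match_pattern(None, 'abc'))
# MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})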
@@ -151,7 +211,7 @@ async def edge_similarity_search(
async def entity_similarity_search(
-search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
+search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
) -> list[EntityNode]:
# vector similarity search over entity names
records, _, _ = await driver.execute_query(
@@ -161,6 +221,7 @@ async def entity_similarity_search(
RETURN
n.uuid As uuid,
n.name AS name,
+n.name_embeddings AS name_embedding,
n.created_at AS created_at,
n.summary AS summary
ORDER BY score DESC
@@ -175,6 +236,7 @@ async def entity_similarity_search(
EntityNode(
uuid=record['uuid'],
name=record['name'],
+name_embedding=record['name_embedding'],
labels=['Entity'],
created_at=record['created_at'].to_native(),
summary=record['summary'],
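Both entity search helpers now hydrate name_embedding on the EntityNode objects they return (note the stored property is queried as name_embeddings but aliased to the singular name_embedding), presumably so downstream consumers can reuse the embedding instead of recomputing it. Consolidated, the construction now reads:

EntityNode(
    uuid=record['uuid'],
    name=record['name'],
    name_embedding=record['name_embedding'],  # newly returned by the query
    labels=['Entity'],
    created_at=record['created_at'].to_native(),
    summary=record['summary'],
)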
@@ -185,7 +247,7 @@ async def entity_similarity_search(
async def entity_fulltext_search(
-query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
+query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
) -> list[EntityNode]:
# BM25 search to get top nodes
fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
@@ -193,8 +255,9 @@ async def entity_fulltext_search(
"""
CALL db.index.fulltext.queryNodes("name_and_summary", $query) YIELD node, score
RETURN
-node.uuid As uuid,
+node.uuid AS uuid,
node.name AS name,
+node.name_embeddings AS name_embedding,
node.created_at AS created_at,
node.summary AS summary
ORDER BY score DESC
@@ -210,6 +273,7 @@ async def entity_fulltext_search(
EntityNode(
uuid=record['uuid'],
name=record['name'],
+name_embedding=record['name_embedding'],
labels=['Entity'],
created_at=record['created_at'].to_native(),
summary=record['summary'],
@@ -220,21 +284,18 @@ async def entity_fulltext_search(
async def edge_fulltext_search(
-driver: AsyncDriver,
-query: str,
-limit=RELEVANT_SCHEMA_LIMIT,
-source_node_uuid: str = '*',
-target_node_uuid: str = '*',
+driver: AsyncDriver,
+query: str,
+source_node_uuid: str | None,
+target_node_uuid: str | None,
+limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EntityEdge]:
# fulltext search over facts
-fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
-records, _, _ = await driver.execute_query(
-"""
-CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
-YIELD relationship AS rel, score
-MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
-RETURN
+cypher_query = Query("""
+CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
+YIELD relationship AS rel, score
+MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
+RETURN
r.uuid AS uuid,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
@@ -247,7 +308,70 @@ async def edge_fulltext_search(
r.valid_at AS valid_at,
r.invalid_at AS invalid_at
ORDER BY score DESC LIMIT $limit
""",
""")
if source_node_uuid is None and target_node_uuid is None:
cypher_query = Query("""
CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
YIELD relationship AS rel, score
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
RETURN
r.uuid AS uuid,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
r.created_at AS created_at,
r.name AS name,
r.fact AS fact,
r.fact_embedding AS fact_embedding,
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at
ORDER BY score DESC LIMIT $limit
""")
elif source_node_uuid is None:
cypher_query = Query("""
CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
YIELD relationship AS rel, score
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
RETURN
r.uuid AS uuid,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
r.created_at AS created_at,
r.name AS name,
r.fact AS fact,
r.fact_embedding AS fact_embedding,
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at
ORDER BY score DESC LIMIT $limit
""")
elif target_node_uuid is None:
cypher_query = Query("""
CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
YIELD relationship AS rel, score
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity)
RETURN
r.uuid AS uuid,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
r.created_at AS created_at,
r.name AS name,
r.fact AS fact,
r.fact_embedding AS fact_embedding,
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at
ORDER BY score DESC LIMIT $limit
""")
fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
records, _, _ = await driver.execute_query(
cypher_query,
query=fuzzy_query,
source_uuid=source_node_uuid,
target_uuid=target_node_uuid,
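As before, the Lucene query string is built by stripping punctuation and appending the fuzzy-match operator; the computation simply moved below the Query selection. For example:

import re

query = 'Where is Paris?'
fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
print(fuzzy_query)  # Where is Paris~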
@@ -277,16 +401,16 @@ async def edge_fulltext_search(
async def hybrid_node_search(
-queries: list[str],
-embeddings: list[list[float]],
-driver: AsyncDriver,
-limit: int = RELEVANT_SCHEMA_LIMIT,
+queries: list[str],
+embeddings: list[list[float]],
+driver: AsyncDriver,
+limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]:
"""
Perform a hybrid search for nodes using both text queries and embeddings.
This method combines fulltext search and vector similarity search to find
-relevant nodes in the graph database. It uses an rrf reranker.
+relevant nodes in the graph database. It uses a rrf reranker.
Parameters
----------
@@ -342,8 +466,8 @@ async def hybrid_node_search(
async def get_relevant_nodes(
-nodes: list[EntityNode],
-driver: AsyncDriver,
+nodes: list[EntityNode],
+driver: AsyncDriver,
) -> list[EntityNode]:
"""
Retrieve relevant nodes based on the provided list of EntityNodes.
@@ -379,11 +503,11 @@ async def get_relevant_nodes(
async def get_relevant_edges(
-edges: list[EntityEdge],
-driver: AsyncDriver,
-limit: int = RELEVANT_SCHEMA_LIMIT,
-source_node_uuid: str = '*',
-target_node_uuid: str = '*',
+driver: AsyncDriver,
+edges: list[EntityEdge],
+source_node_uuid: str | None,
+target_node_uuid: str | None,
+limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EntityEdge]:
start = time()
relevant_edges: list[EntityEdge] = []
@@ -392,13 +516,13 @@ async def get_relevant_edges(
results = await asyncio.gather(
*[
edge_similarity_search(
-driver, edge.fact_embedding, limit, source_node_uuid, target_node_uuid
+driver, edge.fact_embedding, source_node_uuid, target_node_uuid, limit
)
for edge in edges
if edge.fact_embedding is not None
],
*[
-edge_fulltext_search(driver, edge.fact, limit, source_node_uuid, target_node_uuid)
+edge_fulltext_search(driver, edge.fact, source_node_uuid, target_node_uuid, limit)
for edge in edges
],
)
@@ -433,14 +557,14 @@ def rrf(results: list[list[str]], rank_const=1) -> list[str]:
async def node_distance_reranker(
-driver: AsyncDriver, results: list[list[str]], center_node_uuid: str
+driver: AsyncDriver, results: list[list[str]], center_node_uuid: str
) -> list[str]:
# use rrf as a preliminary ranker
sorted_uuids = rrf(results)
scores: dict[str, float] = {}
for uuid in sorted_uuids:
-# Find shortest path to center node
+# Find the shortest path to center node
records, _, _ = await driver.execute_query(
"""
MATCH (source:Entity)-[r:RELATES_TO {uuid: $edge_uuid}]->(target:Entity)
@@ -455,8 +579,8 @@ async def node_distance_reranker(
for record in records:
if (
-record['source_uuid'] == center_node_uuid
-or record['target_uuid'] == center_node_uuid
+record['source_uuid'] == center_node_uuid
+or record['target_uuid'] == center_node_uuid
):
continue
distance = record['score']
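node_distance_reranker seeds its ordering with rrf, whose signature appears in the hunk header above. A minimal sketch of reciprocal rank fusion consistent with that signature (the exact scoring form is assumed, not shown in this diff):

from collections import defaultdict

def rrf_sketch(results, rank_const=1):
    scores = defaultdict(float)
    for result in results:
        for rank, uuid in enumerate(result):
            # Earlier ranks contribute more; rank_const damps the very top ranks.
            scores[uuid] += 1.0 / (rank + rank_const)
    return sorted(scores, key=scores.get, reverse=True)

print(rrf_sketch([['a', 'b'], ['b', 'c']]))  # ['b', 'a', 'c']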

View File

@@ -158,7 +158,7 @@ async def dedupe_edges_bulk(
relevant_edges_chunks: list[list[EntityEdge]] = list(
await asyncio.gather(
-*[get_relevant_edges(edge_chunk, driver) for edge_chunk in edge_chunks]
+*[get_relevant_edges(driver, edge_chunk, None, None) for edge_chunk in edge_chunks]
)
)

View File

@@ -28,7 +28,7 @@ logger = logging.getLogger(__name__)
async def extract_message_nodes(
-llm_client: LLMClient, episode: EpisodicNode, previous_episodes: list[EpisodicNode]
+llm_client: LLMClient, episode: EpisodicNode, previous_episodes: list[EpisodicNode]
) -> list[dict[str, Any]]:
# Prepare context for LLM
context = {
@@ -49,8 +49,8 @@ async def extract_message_nodes(
async def extract_json_nodes(
-llm_client: LLMClient,
-episode: EpisodicNode,
+llm_client: LLMClient,
+episode: EpisodicNode,
) -> list[dict[str, Any]]:
# Prepare context for LLM
context = {
@@ -67,9 +67,9 @@ async def extract_json_nodes(
async def extract_nodes(
-llm_client: LLMClient,
-episode: EpisodicNode,
-previous_episodes: list[EpisodicNode],
+llm_client: LLMClient,
+episode: EpisodicNode,
+previous_episodes: list[EpisodicNode],
) -> list[EntityNode]:
start = time()
extracted_node_data: list[dict[str, Any]] = []
@@ -96,9 +96,9 @@ async def extract_nodes(
async def dedupe_extracted_nodes(
-llm_client: LLMClient,
-extracted_nodes: list[EntityNode],
-existing_nodes: list[EntityNode],
+llm_client: LLMClient,
+extracted_nodes: list[EntityNode],
+existing_nodes: list[EntityNode],
) -> tuple[list[EntityNode], dict[str, str]]:
start = time()
@@ -146,9 +146,9 @@ async def dedupe_extracted_nodes(
async def resolve_extracted_nodes(
-llm_client: LLMClient,
-extracted_nodes: list[EntityNode],
-existing_nodes_lists: list[list[EntityNode]],
+llm_client: LLMClient,
+extracted_nodes: list[EntityNode],
+existing_nodes_lists: list[list[EntityNode]],
) -> tuple[list[EntityNode], dict[str, str]]:
uuid_map: dict[str, str] = {}
resolved_nodes: list[EntityNode] = []
@@ -169,7 +169,7 @@ async def resolve_extracted_nodes(
async def resolve_extracted_node(
-llm_client: LLMClient, extracted_node: EntityNode, existing_nodes: list[EntityNode]
+llm_client: LLMClient, extracted_node: EntityNode, existing_nodes: list[EntityNode]
) -> tuple[EntityNode, dict[str, str]]:
start = time()
@@ -214,8 +214,8 @@ async def resolve_extracted_node(
async def dedupe_node_list(
-llm_client: LLMClient,
-nodes: list[EntityNode],
+llm_client: LLMClient,
+nodes: list[EntityNode],
) -> tuple[list[EntityNode], dict[str, str]]:
start = time()

View File

@@ -1,6 +1,6 @@
[tool.poetry]
name = "graphiti-core"
version = "0.2.0"
version = "0.2.1"
description = "A temporal graph building library"
authors = [
"Paul Paliychuk <paul@getzep.com>",