Add group ids (#89)

* set and retrieve group ids

* update add episode with group id support

* add episode and search functional

* update bulk

* mypy updates

* remove unused imports

* update unit tests

* unit tests

* add optional uuid field

* format

* mypy

* ellipsis
Preston Rasmussen
2024-09-06 12:33:42 -04:00
committed by GitHub
parent c7fc057106
commit 42fb590606
15 changed files with 329 additions and 356 deletions
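In practice, the new group_id plumbing partitions the graph, e.g. one partition per tenant or dataset. A minimal usage sketch against the API introduced here (import path and connection details are assumptions; parameter names follow the diffs below):

from datetime import datetime, timezone

from graphiti_core import Graphiti  # import path assumed
from graphiti_core.nodes import EpisodeType


async def demo():
    # Placeholder Neo4j credentials.
    graphiti = Graphiti('bolt://localhost:7687', 'neo4j', 'password')

    # Episodes written with a group_id land in that graph partition.
    await graphiti.add_episode(
        name='episode-1',
        episode_body='Alice (host): welcome to the show',
        source_description='Podcast Transcript',
        reference_time=datetime.now(timezone.utc),
        source=EpisodeType.message,
        group_id='1',
    )

    # Searches scoped with group_ids only see those partitions.
    return await graphiti.search('Alice', group_ids=['1'])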

View File

@@ -69,6 +69,7 @@ async def main(use_bulk: bool = True):
episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
reference_time=message.actual_timestamp,
source_description='Podcast Transcript',
group_id='1',
)
return

View File

@@ -18,6 +18,7 @@ import logging
from abc import ABC, abstractmethod
from datetime import datetime
from time import time
from typing import Any
from uuid import uuid4
from neo4j import AsyncDriver
@@ -32,6 +33,7 @@ logger = logging.getLogger(__name__)
class Edge(BaseModel, ABC):
uuid: str = Field(default_factory=lambda: uuid4().hex)
group_id: str | None = Field(description='partition of the graph')
source_node_uuid: str
target_node_uuid: str
created_at: datetime
@@ -61,11 +63,12 @@ class EpisodicEdge(Edge):
MATCH (episode:Episodic {uuid: $episode_uuid})
MATCH (node:Entity {uuid: $entity_uuid})
MERGE (episode)-[r:MENTIONS {uuid: $uuid}]->(node)
SET r = {uuid: $uuid, created_at: $created_at}
SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
RETURN r.uuid AS uuid""",
episode_uuid=self.source_node_uuid,
entity_uuid=self.target_node_uuid,
uuid=self.uuid,
group_id=self.group_id,
created_at=self.created_at,
)
@@ -93,6 +96,7 @@ class EpisodicEdge(Edge):
MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
RETURN
e.uuid AS uuid,
e.group_id AS group_id,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
e.created_at AS created_at
@@ -100,17 +104,7 @@ class EpisodicEdge(Edge):
uuid=uuid,
)
edges: list[EpisodicEdge] = []
for record in records:
edges.append(
EpisodicEdge(
uuid=record['uuid'],
source_node_uuid=record['source_node_uuid'],
target_node_uuid=record['target_node_uuid'],
created_at=record['created_at'].to_native(),
)
)
edges = [get_episodic_edge_from_record(record) for record in records]
logger.info(f'Found Edge: {uuid}')
@@ -153,7 +147,7 @@ class EntityEdge(Edge):
MATCH (source:Entity {uuid: $source_uuid})
MATCH (target:Entity {uuid: $target_uuid})
MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target)
SET r = {uuid: $uuid, name: $name, fact: $fact, fact_embedding: $fact_embedding,
SET r = {uuid: $uuid, name: $name, group_id: $group_id, fact: $fact, fact_embedding: $fact_embedding,
episodes: $episodes, created_at: $created_at, expired_at: $expired_at,
valid_at: $valid_at, invalid_at: $invalid_at}
RETURN r.uuid AS uuid""",
@@ -161,6 +155,7 @@ class EntityEdge(Edge):
target_uuid=self.target_node_uuid,
uuid=self.uuid,
name=self.name,
group_id=self.group_id,
fact=self.fact,
fact_embedding=self.fact_embedding,
episodes=self.episodes,
@@ -198,6 +193,7 @@ class EntityEdge(Edge):
m.uuid AS target_node_uuid,
e.created_at AS created_at,
e.name AS name,
e.group_id AS group_id,
e.fact AS fact,
e.fact_embedding AS fact_embedding,
e.episodes AS episodes,
@@ -208,25 +204,36 @@ class EntityEdge(Edge):
uuid=uuid,
)
edges: list[EntityEdge] = []
for record in records:
edges.append(
EntityEdge(
uuid=record['uuid'],
source_node_uuid=record['source_node_uuid'],
target_node_uuid=record['target_node_uuid'],
fact=record['fact'],
name=record['name'],
episodes=record['episodes'],
fact_embedding=record['fact_embedding'],
created_at=record['created_at'].to_native(),
expired_at=parse_db_date(record['expired_at']),
valid_at=parse_db_date(record['valid_at']),
invalid_at=parse_db_date(record['invalid_at']),
)
)
edges = [get_entity_edge_from_record(record) for record in records]
logger.info(f'Found Edge: {uuid}')
return edges[0]
# Edge helpers
def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
return EpisodicEdge(
uuid=record['uuid'],
group_id=record['group_id'],
source_node_uuid=record['source_node_uuid'],
target_node_uuid=record['target_node_uuid'],
created_at=record['created_at'].to_native(),
)
def get_entity_edge_from_record(record: Any) -> EntityEdge:
return EntityEdge(
uuid=record['uuid'],
source_node_uuid=record['source_node_uuid'],
target_node_uuid=record['target_node_uuid'],
fact=record['fact'],
name=record['name'],
group_id=record['group_id'],
episodes=record['episodes'],
fact_embedding=record['fact_embedding'],
created_at=record['created_at'].to_native(),
expired_at=parse_db_date(record['expired_at']),
valid_at=parse_db_date(record['valid_at']),
invalid_at=parse_db_date(record['invalid_at']),
)
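These helpers centralize the record-to-model mapping that was previously duplicated at every query site. A hedged sketch of feeding get_entity_edge_from_record a hand-built record (neo4j.time.DateTime is real driver API; the field values are invented):

from datetime import datetime

from neo4j.time import DateTime

record = {
    'uuid': 'e1',
    'source_node_uuid': '1',
    'target_node_uuid': '2',
    'name': 'KNOWS',
    'fact': 'Alice knows Bob',
    'group_id': '1',
    'episodes': [],
    'fact_embedding': None,
    # The driver returns neo4j DateTime objects, hence .to_native() in the helper.
    'created_at': DateTime.from_native(datetime.now()),
    'expired_at': None,
    'valid_at': None,
    'invalid_at': None,
}
edge = get_entity_edge_from_record(record)
assert edge.group_id == '1'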

View File

@@ -18,7 +18,6 @@ import asyncio
import logging
from datetime import datetime
from time import time
from typing import Callable
from dotenv import load_dotenv
from neo4j import AsyncGraphDatabase
@@ -120,7 +119,7 @@ class Graphiti:
Parameters
----------
None
self
Returns
-------
@@ -151,7 +150,7 @@ class Graphiti:
Parameters
----------
None
self
Returns
-------
@@ -178,6 +177,7 @@ class Graphiti:
self,
reference_time: datetime,
last_n: int = EPISODE_WINDOW_LEN,
group_ids: list[str | None] | None = None,
) -> list[EpisodicNode]:
"""
Retrieve the last n episodic nodes from the graph.
@@ -191,6 +191,8 @@ class Graphiti:
The reference time to retrieve episodes before.
last_n : int, optional
The number of episodes to retrieve. Defaults to EPISODE_WINDOW_LEN.
group_ids : list[str | None], optional
The group ids to return data from.
Returns
-------
@@ -202,7 +204,7 @@ class Graphiti:
The actual retrieval is performed by the `retrieve_episodes` function
from the `graphiti_core.utils` module.
"""
return await retrieve_episodes(self.driver, reference_time, last_n)
return await retrieve_episodes(self.driver, reference_time, last_n, group_ids)
async def add_episode(
self,
@@ -211,8 +213,8 @@ class Graphiti:
source_description: str,
reference_time: datetime,
source: EpisodeType = EpisodeType.message,
success_callback: Callable | None = None,
error_callback: Callable | None = None,
group_id: str | None = None,
uuid: str | None = None,
):
"""
Process an episode and update the graph.
@@ -232,10 +234,10 @@ class Graphiti:
The reference time for the episode.
source : EpisodeType, optional
The type of the episode. Defaults to EpisodeType.message.
success_callback : Callable | None, optional
A callback function to be called upon successful processing.
error_callback : Callable | None, optional
A callback function to be called if an error occurs during processing.
group_id : str | None
An id for the graph partition the episode is a part of.
uuid : str | None
Optional uuid of the episode.
Returns
-------
@@ -266,9 +268,12 @@ class Graphiti:
embedder = self.llm_client.get_embedder()
now = datetime.now()
previous_episodes = await self.retrieve_episodes(reference_time, last_n=3)
previous_episodes = await self.retrieve_episodes(
reference_time, last_n=3, group_ids=[group_id]
)
episode = EpisodicNode(
name=name,
group_id=group_id,
labels=[],
source=source,
content=episode_body,
@@ -276,6 +281,7 @@ class Graphiti:
created_at=now,
valid_at=reference_time,
)
episode.uuid = uuid if uuid is not None else episode.uuid
# Extract entities as nodes
@@ -299,7 +305,9 @@ class Graphiti:
(mentioned_nodes, uuid_map), extracted_edges = await asyncio.gather(
resolve_extracted_nodes(self.llm_client, extracted_nodes, existing_nodes_lists),
extract_edges(self.llm_client, episode, extracted_nodes, previous_episodes),
extract_edges(
self.llm_client, episode, extracted_nodes, previous_episodes, group_id
),
)
logger.info(f'Adjusted mentioned nodes: {[(n.name, n.uuid) for n in mentioned_nodes]}')
nodes.extend(mentioned_nodes)
@@ -388,11 +396,7 @@ class Graphiti:
logger.info(f'Resolved edges: {[(e.name, e.uuid) for e in resolved_edges]}')
episodic_edges: list[EpisodicEdge] = build_episodic_edges(
mentioned_nodes,
episode,
now,
)
episodic_edges: list[EpisodicEdge] = build_episodic_edges(mentioned_nodes, episode, now)
logger.info(f'Built episodic edges: {episodic_edges}')
@@ -405,18 +409,10 @@ class Graphiti:
end = time()
logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
if success_callback:
await success_callback(episode)
except Exception as e:
if error_callback:
await error_callback(episode, e)
else:
raise e
raise e
async def add_episode_bulk(
self,
bulk_episodes: list[RawEpisode],
):
async def add_episode_bulk(self, bulk_episodes: list[RawEpisode], group_id: str | None):
"""
Process multiple episodes in bulk and update the graph.
@@ -427,6 +423,8 @@ class Graphiti:
----------
bulk_episodes : list[RawEpisode]
A list of RawEpisode objects to be processed and added to the graph.
group_id : str | None
An id for the graph partition the episode is a part of.
Returns
-------
@@ -463,6 +461,7 @@ class Graphiti:
source=episode.source,
content=episode.content,
source_description=episode.source_description,
group_id=group_id,
created_at=now,
valid_at=episode.reference_time,
)
@@ -527,7 +526,13 @@ class Graphiti:
except Exception as e:
raise e
async def search(self, query: str, center_node_uuid: str | None = None, num_results=10):
async def search(
self,
query: str,
center_node_uuid: str | None = None,
group_ids: list[str | None] | None = None,
num_results=10,
):
"""
Perform a hybrid search on the knowledge graph.
@@ -540,6 +545,8 @@ class Graphiti:
The search query string.
center_node_uuid: str, optional
Facts will be reranked based on proximity to this node
group_ids : list[str | None] | None, optional
The graph partitions to return data from.
num_results : int, optional
The maximum number of results to return. Defaults to 10.
@@ -562,6 +569,7 @@ class Graphiti:
num_episodes=0,
num_edges=num_results,
num_nodes=0,
group_ids=group_ids,
search_methods=[SearchMethod.bm25, SearchMethod.cosine_similarity],
reranker=reranker,
)
@@ -590,7 +598,10 @@ class Graphiti:
)
async def get_nodes_by_query(
self, query: str, limit: int = RELEVANT_SCHEMA_LIMIT
self,
query: str,
group_ids: list[str | None] | None = None,
limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]:
"""
Retrieve nodes from the graph database based on a text query.
@@ -602,6 +613,8 @@ class Graphiti:
----------
query : str
The text query to search for in the graph.
group_ids : list[str | None] | None, optional
The graph partitions to return data from.
limit : int | None, optional
The maximum number of results to return per search method.
If None, a default limit will be applied.
@@ -626,5 +639,7 @@ class Graphiti:
"""
embedder = self.llm_client.get_embedder()
query_embedding = await generate_embedding(embedder, query)
relevant_nodes = await hybrid_node_search([query], [query_embedding], self.driver, limit)
relevant_nodes = await hybrid_node_search(
[query], [query_embedding], self.driver, group_ids, limit
)
return relevant_nodes
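With the widened signatures above, episode retrieval and node queries can be scoped the same way, e.g. (inside an async function, reusing a graphiti instance):

episodes = await graphiti.retrieve_episodes(
    reference_time=datetime.now(), last_n=5, group_ids=['1']
)
nodes = await graphiti.get_nodes_by_query('Alice', group_ids=['1'])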

View File

@@ -19,10 +19,10 @@ from abc import ABC, abstractmethod
from datetime import datetime
from enum import Enum
from time import time
from typing import Any
from uuid import uuid4
from neo4j import AsyncDriver
from openai import OpenAI
from pydantic import BaseModel, Field
from graphiti_core.llm_client.config import EMBEDDING_DIM
@@ -69,6 +69,7 @@ class EpisodeType(Enum):
class Node(BaseModel, ABC):
uuid: str = Field(default_factory=lambda: uuid4().hex)
name: str = Field(description='name of the node')
group_id: str | None = Field(description='partition of the graph')
labels: list[str] = Field(default_factory=list)
created_at: datetime = Field(default_factory=lambda: datetime.now())
@@ -106,11 +107,12 @@ class EpisodicNode(Node):
result = await driver.execute_query(
"""
MERGE (n:Episodic {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, source_description: $source_description, source: $source, content: $content,
SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
RETURN n.uuid AS uuid""",
uuid=self.uuid,
name=self.name,
group_id=self.group_id,
source_description=self.source_description,
content=self.content,
entity_edges=self.entity_edges,
@@ -141,29 +143,19 @@ class EpisodicNode(Node):
records, _, _ = await driver.execute_query(
"""
MATCH (e:Episodic {uuid: $uuid})
RETURN e.content as content,
e.created_at as created_at,
e.valid_at as valid_at,
e.uuid as uuid,
e.name as name,
e.source_description as source_description,
e.source as source
RETURN e.content AS content,
e.created_at AS created_at,
e.valid_at AS valid_at,
e.uuid AS uuid,
e.name AS name,
e.group_id AS group_id,
e.source_description AS source_description,
e.source AS source
""",
uuid=uuid,
)
episodes = [
EpisodicNode(
content=record['content'],
created_at=record['created_at'].to_native().timestamp(),
valid_at=(record['valid_at'].to_native()),
uuid=record['uuid'],
source=EpisodeType.from_str(record['source']),
name=record['name'],
source_description=record['source_description'],
)
for record in records
]
episodes = [get_episodic_node_from_record(record) for record in records]
logger.info(f'Found Node: {uuid}')
@@ -174,10 +166,6 @@ class EntityNode(Node):
name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
summary: str = Field(description='regional summary of surrounding edges', default_factory=str)
async def update_summary(self, driver: AsyncDriver): ...
async def refresh_summary(self, driver: AsyncDriver, llm_client: OpenAI): ...
async def generate_name_embedding(self, embedder, model='text-embedding-3-small'):
start = time()
text = self.name.replace('\n', ' ')
@@ -192,10 +180,11 @@ class EntityNode(Node):
result = await driver.execute_query(
"""
MERGE (n:Entity {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, name_embedding: $name_embedding, summary: $summary, created_at: $created_at}
SET n = {uuid: $uuid, name: $name, name_embedding: $name_embedding, group_id: $group_id, summary: $summary, created_at: $created_at}
RETURN n.uuid AS uuid""",
uuid=self.uuid,
name=self.name,
group_id=self.group_id,
summary=self.summary,
name_embedding=self.name_embedding,
created_at=self.created_at,
@@ -227,25 +216,14 @@ class EntityNode(Node):
n.uuid AS uuid,
n.name AS name,
n.name_embedding AS name_embedding,
n.group_id AS group_id,
n.created_at AS created_at,
n.summary AS summary
""",
uuid=uuid,
)
nodes: list[EntityNode] = []
for record in records:
nodes.append(
EntityNode(
uuid=record['uuid'],
name=record['name'],
name_embedding=record['name_embedding'],
labels=['Entity'],
created_at=record['created_at'].to_native(),
summary=record['summary'],
)
)
nodes = [get_entity_node_from_record(record) for record in records]
logger.info(f'Found Node: {uuid}')
@@ -253,3 +231,26 @@ class EntityNode(Node):
# Node helpers
def get_episodic_node_from_record(record: Any) -> EpisodicNode:
return EpisodicNode(
content=record['content'],
created_at=record['created_at'].to_native().timestamp(),
valid_at=(record['valid_at'].to_native()),
uuid=record['uuid'],
group_id=record['group_id'],
source=EpisodeType.from_str(record['source']),
name=record['name'],
source_description=record['source_description'],
)
def get_entity_node_from_record(record: Any) -> EntityNode:
return EntityNode(
uuid=record['uuid'],
name=record['name'],
group_id=record['group_id'],
name_embedding=record['name_embedding'],
labels=['Entity'],
created_at=record['created_at'].to_native(),
summary=record['summary'],
)

View File

@@ -52,6 +52,7 @@ class SearchConfig(BaseModel):
num_edges: int = Field(default=10)
num_nodes: int = Field(default=10)
num_episodes: int = EPISODE_WINDOW_LEN
group_ids: list[str | None] | None
search_methods: list[SearchMethod]
reranker: Reranker | None
@@ -83,7 +84,9 @@ async def hybrid_search(
nodes.extend(await get_mentioned_nodes(driver, episodes))
if SearchMethod.bm25 in config.search_methods:
text_search = await edge_fulltext_search(driver, query, None, None, 2 * config.num_edges)
text_search = await edge_fulltext_search(
driver, query, None, None, config.group_ids, 2 * config.num_edges
)
search_results.append(text_search)
if SearchMethod.cosine_similarity in config.search_methods:
@@ -95,7 +98,7 @@ async def hybrid_search(
)
similarity_search = await edge_similarity_search(
driver, search_vector, None, None, 2 * config.num_edges
driver, search_vector, None, None, config.group_ids, 2 * config.num_edges
)
search_results.append(similarity_search)
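SearchConfig now carries the partition filter, so hybrid_search threads group_ids to every search method instead of each call site plumbing it separately. A sketch of building a config the way Graphiti.search does above (reranker left unset here):

config = SearchConfig(
    num_episodes=0,
    num_edges=10,
    num_nodes=0,
    group_ids=['1'],
    search_methods=[SearchMethod.bm25, SearchMethod.cosine_similarity],
    reranker=None,
)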

View File

@@ -3,13 +3,11 @@ import logging
import re
from collections import defaultdict
from time import time
from typing import Any
from neo4j import AsyncDriver, Query
from graphiti_core.edges import EntityEdge
from graphiti_core.helpers import parse_db_date
from graphiti_core.nodes import EntityNode, EpisodicNode
from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
from graphiti_core.nodes import EntityNode, EpisodicNode, get_entity_node_from_record
logger = logging.getLogger(__name__)
@@ -23,6 +21,7 @@ async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode])
MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids
RETURN DISTINCT
n.uuid AS uuid,
n.group_id AS group_id,
n.name AS name,
n.name_embedding AS name_embedding,
n.created_at AS created_at,
@@ -31,86 +30,29 @@ async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode])
uuids=episode_uuids,
)
nodes: list[EntityNode] = []
for record in records:
nodes.append(
EntityNode(
uuid=record['uuid'],
name=record['name'],
name_embedding=record['name_embedding'],
labels=['Entity'],
created_at=record['created_at'].to_native(),
summary=record['summary'],
)
)
nodes = [get_entity_node_from_record(record) for record in records]
return nodes
async def bfs(node_ids: list[str], driver: AsyncDriver):
records, _, _ = await driver.execute_query(
"""
MATCH (n WHERE n.uuid in $node_ids)-[r]->(m)
RETURN DISTINCT
n.uuid AS source_node_uuid,
n.name AS source_name,
n.summary AS source_summary,
m.uuid AS target_node_uuid,
m.name AS target_name,
m.summary AS target_summary,
r.uuid AS uuid,
r.created_at AS created_at,
r.name AS name,
r.fact AS fact,
r.fact_embedding AS fact_embedding,
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at
""",
node_ids=node_ids,
)
context: dict[str, Any] = {}
for record in records:
n_uuid = record['source_node_uuid']
if n_uuid in context:
context[n_uuid]['facts'].append(record['fact'])
else:
context[n_uuid] = {
'name': record['source_name'],
'summary': record['source_summary'],
'facts': [record['fact']],
}
m_uuid = record['target_node_uuid']
if m_uuid not in context:
context[m_uuid] = {
'name': record['target_name'],
'summary': record['target_summary'],
'facts': [],
}
logger.info(f'bfs search returned context: {context}')
return context
async def edge_similarity_search(
driver: AsyncDriver,
search_vector: list[float],
source_node_uuid: str | None,
target_node_uuid: str | None,
group_ids: list[str | None] | None = None,
limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EntityEdge]:
group_ids = group_ids if group_ids is not None else [None]
# vector similarity search over embedded facts
query = Query("""
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
YIELD relationship AS rel, score
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
WHERE r.group_id IN $group_ids
RETURN
r.uuid AS uuid,
r.group_id AS group_id,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
r.created_at AS created_at,
@@ -129,8 +71,10 @@ async def edge_similarity_search(
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
YIELD relationship AS rel, score
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
WHERE r.group_id IN $group_ids
RETURN
r.uuid AS uuid,
r.group_id AS group_id,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
r.created_at AS created_at,
@@ -148,8 +92,10 @@ async def edge_similarity_search(
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
YIELD relationship AS rel, score
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
WHERE r.group_id IN $group_ids
RETURN
r.uuid AS uuid,
r.group_id AS group_id,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
r.created_at AS created_at,
@@ -167,8 +113,10 @@ async def edge_similarity_search(
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
YIELD relationship AS rel, score
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity)
WHERE r.group_id IN $group_ids
RETURN
r.uuid AS uuid,
r.group_id AS group_id,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
r.created_at AS created_at,
@@ -187,41 +135,32 @@ async def edge_similarity_search(
search_vector=search_vector,
source_uuid=source_node_uuid,
target_uuid=target_node_uuid,
group_ids=group_ids,
limit=limit,
)
edges: list[EntityEdge] = []
for record in records:
edge = EntityEdge(
uuid=record['uuid'],
source_node_uuid=record['source_node_uuid'],
target_node_uuid=record['target_node_uuid'],
fact=record['fact'],
name=record['name'],
episodes=record['episodes'],
fact_embedding=record['fact_embedding'],
created_at=record['created_at'].to_native(),
expired_at=parse_db_date(record['expired_at']),
valid_at=parse_db_date(record['valid_at']),
invalid_at=parse_db_date(record['invalid_at']),
)
edges.append(edge)
edges = [get_entity_edge_from_record(record) for record in records]
return edges
async def entity_similarity_search(
search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
search_vector: list[float],
driver: AsyncDriver,
group_ids: list[str | None] | None = None,
limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]:
group_ids = group_ids if group_ids is not None else [None]
# vector similarity search over entity names
records, _, _ = await driver.execute_query(
"""
CALL db.index.vector.queryNodes("name_embedding", $limit, $search_vector)
YIELD node AS n, score
MATCH (n WHERE n.group_id IN $group_ids)
RETURN
n.uuid AS uuid,
n.group_id AS group_id,
n.name AS name,
n.name_embedding AS name_embedding,
n.created_at AS created_at,
@@ -229,58 +168,44 @@ async def entity_similarity_search(
ORDER BY score DESC
""",
search_vector=search_vector,
group_ids=group_ids,
limit=limit,
)
nodes: list[EntityNode] = []
for record in records:
nodes.append(
EntityNode(
uuid=record['uuid'],
name=record['name'],
name_embedding=record['name_embedding'],
labels=['Entity'],
created_at=record['created_at'].to_native(),
summary=record['summary'],
)
)
nodes = [get_entity_node_from_record(record) for record in records]
return nodes
async def entity_fulltext_search(
query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
query: str,
driver: AsyncDriver,
group_ids: list[str | None] | None = None,
limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]:
group_ids = group_ids if group_ids is not None else [None]
# BM25 search to get top nodes
fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
records, _, _ = await driver.execute_query(
"""
CALL db.index.fulltext.queryNodes("name_and_summary", $query) YIELD node, score
CALL db.index.fulltext.queryNodes("name_and_summary", $query)
YIELD node AS n, score
MATCH (n WHERE n.group_id in $group_ids)
RETURN
node.uuid AS uuid,
node.name AS name,
node.name_embedding AS name_embedding,
node.created_at AS created_at,
node.summary AS summary
n.uuid AS uuid,
n.group_id AS group_id,
n.name AS name,
n.name_embedding AS name_embedding,
n.created_at AS created_at,
n.summary AS summary
ORDER BY score DESC
LIMIT $limit
""",
query=fuzzy_query,
group_ids=group_ids,
limit=limit,
)
nodes: list[EntityNode] = []
for record in records:
nodes.append(
EntityNode(
uuid=record['uuid'],
name=record['name'],
name_embedding=record['name_embedding'],
labels=['Entity'],
created_at=record['created_at'].to_native(),
summary=record['summary'],
)
)
nodes = [get_entity_node_from_record(record) for record in records]
return nodes
@@ -290,15 +215,20 @@ async def edge_fulltext_search(
query: str,
source_node_uuid: str | None,
target_node_uuid: str | None,
group_ids: list[str | None] | None = None,
limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EntityEdge]:
group_ids = group_ids if group_ids is not None else [None]
# fulltext search over facts
cypher_query = Query("""
CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
YIELD relationship AS rel, score
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
WHERE r.group_id IN $group_ids
RETURN
r.uuid AS uuid,
r.group_id AS group_id,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
r.created_at AS created_at,
@@ -317,8 +247,10 @@ async def edge_fulltext_search(
CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
YIELD relationship AS rel, score
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
WHERE r.group_id IN $group_ids
RETURN
r.uuid AS uuid,
r.group_id AS group_id,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
r.created_at AS created_at,
@@ -336,8 +268,10 @@ async def edge_fulltext_search(
CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
YIELD relationship AS rel, score
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
WHERE r.group_id IN $group_ids
RETURN
r.uuid AS uuid,
r.group_id AS group_id,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
r.created_at AS created_at,
@@ -355,8 +289,10 @@ async def edge_fulltext_search(
CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
YIELD relationship AS rel, score
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity)
WHERE r.group_id IN $group_ids
RETURN
r.uuid AS uuid,
r.group_id AS group_id,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
r.created_at AS created_at,
@@ -377,27 +313,11 @@ async def edge_fulltext_search(
query=fuzzy_query,
source_uuid=source_node_uuid,
target_uuid=target_node_uuid,
group_ids=group_ids,
limit=limit,
)
edges: list[EntityEdge] = []
for record in records:
edge = EntityEdge(
uuid=record['uuid'],
source_node_uuid=record['source_node_uuid'],
target_node_uuid=record['target_node_uuid'],
fact=record['fact'],
name=record['name'],
episodes=record['episodes'],
fact_embedding=record['fact_embedding'],
created_at=record['created_at'].to_native(),
expired_at=parse_db_date(record['expired_at']),
valid_at=parse_db_date(record['valid_at']),
invalid_at=parse_db_date(record['invalid_at']),
)
edges.append(edge)
edges = [get_entity_edge_from_record(record) for record in records]
return edges
@@ -406,6 +326,7 @@ async def hybrid_node_search(
queries: list[str],
embeddings: list[list[float]],
driver: AsyncDriver,
group_ids: list[str | None] | None = None,
limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]:
"""
@@ -422,6 +343,8 @@ async def hybrid_node_search(
A list of embedding vectors corresponding to the queries. If empty only fulltext search is performed.
driver : AsyncDriver
The Neo4j driver instance for database operations.
group_ids : list[str | None] | None, optional
The list of group ids to retrieve nodes from.
limit : int | None, optional
The maximum number of results to return per search method. If None, a default limit will be applied.
@@ -448,8 +371,8 @@ async def hybrid_node_search(
results: list[list[EntityNode]] = list(
await asyncio.gather(
*[entity_fulltext_search(q, driver, 2 * limit) for q in queries],
*[entity_similarity_search(e, driver, 2 * limit) for e in embeddings],
*[entity_fulltext_search(q, driver, group_ids, 2 * limit) for q in queries],
*[entity_similarity_search(e, driver, group_ids, 2 * limit) for e in embeddings],
)
)
@@ -500,6 +423,7 @@ async def get_relevant_nodes(
[node.name for node in nodes],
[node.name_embedding for node in nodes if node.name_embedding is not None],
driver,
[node.group_id for node in nodes],
)
return relevant_nodes
@@ -518,13 +442,20 @@ async def get_relevant_edges(
results = await asyncio.gather(
*[
edge_similarity_search(
driver, edge.fact_embedding, source_node_uuid, target_node_uuid, limit
driver,
edge.fact_embedding,
source_node_uuid,
target_node_uuid,
[edge.group_id],
limit,
)
for edge in edges
if edge.fact_embedding is not None
],
*[
edge_fulltext_search(driver, edge.fact, source_node_uuid, target_node_uuid, limit)
edge_fulltext_search(
driver, edge.fact, source_node_uuid, target_node_uuid, [edge.group_id], limit
)
for edge in edges
],
)
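A sketch of the widened hybrid_node_search call (inside an async function; the embedding vector is a stand-in for a real generate_embedding result):

nodes = await hybrid_node_search(
    queries=['Alice'],
    embeddings=[[0.1, 0.2, 0.3]],
    driver=driver,
    group_ids=['1'],
    limit=10,
)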

View File

@@ -17,6 +17,7 @@ limitations under the License.
import asyncio
import logging
import typing
from collections import defaultdict
from datetime import datetime
from math import ceil
@@ -42,7 +43,6 @@ from graphiti_core.utils.maintenance.node_operations import (
extract_nodes,
)
from graphiti_core.utils.maintenance.temporal_operations import extract_edge_dates
from graphiti_core.utils.utils import chunk_edges_by_nodes
logger = logging.getLogger(__name__)
@@ -62,7 +62,9 @@ async def retrieve_previous_episodes_bulk(
) -> list[tuple[EpisodicNode, list[EpisodicNode]]]:
previous_episodes_list = await asyncio.gather(
*[
retrieve_episodes(driver, episode.valid_at, last_n=EPISODE_WINDOW_LEN)
retrieve_episodes(
driver, episode.valid_at, last_n=EPISODE_WINDOW_LEN, group_ids=[episode.group_id]
)
for episode in episodes
]
)
@@ -90,7 +92,13 @@ async def extract_nodes_and_edges_bulk(
extracted_edges_bulk = await asyncio.gather(
*[
extract_edges(llm_client, episode, extracted_nodes_bulk[i], previous_episodes_list[i])
extract_edges(
llm_client,
episode,
extracted_nodes_bulk[i],
previous_episodes_list[i],
episode.group_id,
)
for i, episode in enumerate(episodes)
]
)
@@ -343,3 +351,23 @@ async def extract_edge_dates_bulk(
edge.expired_at = datetime.now()
return edges
def chunk_edges_by_nodes(edges: list[EntityEdge]) -> list[list[EntityEdge]]:
# We only want to dedupe edges that are between the same pair of nodes
# We build a map of the edges based on their source and target nodes.
edge_chunk_map: dict[str, list[EntityEdge]] = defaultdict(list)
for edge in edges:
# We drop loop edges
if edge.source_node_uuid == edge.target_node_uuid:
continue
# Keep the order of the two nodes consistent, we want to be direction agnostic during edge resolution
pointers = [edge.source_node_uuid, edge.target_node_uuid]
pointers.sort()
edge_chunk_map[pointers[0] + pointers[1]].append(edge)
edge_chunks = [chunk for chunk in edge_chunk_map.values()]
return edge_chunks
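Because the two endpoint uuids are sorted before keying, edges in either direction between the same pair land in one chunk, and self-loops are skipped. A quick illustrative check (edge values invented):

now = datetime.now()
e1 = EntityEdge(source_node_uuid='1', target_node_uuid='2', name='KNOWS',
                fact='Alice knows Bob', group_id='1', created_at=now)
e2 = EntityEdge(source_node_uuid='2', target_node_uuid='1', name='KNOWS',
                fact='Bob knows Alice', group_id='1', created_at=now)
loop = EntityEdge(source_node_uuid='1', target_node_uuid='1', name='SELF',
                  fact='self reference', group_id='1', created_at=now)
chunks = chunk_edges_by_nodes([e1, e2, loop])
assert len(chunks) == 1 and len(chunks[0]) == 2  # loop dropped, directions merged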

View File

@@ -37,15 +37,15 @@ def build_episodic_edges(
episode: EpisodicNode,
created_at: datetime,
) -> List[EpisodicEdge]:
edges: List[EpisodicEdge] = []
for node in entity_nodes:
edge = EpisodicEdge(
edges: List[EpisodicEdge] = [
EpisodicEdge(
source_node_uuid=episode.uuid,
target_node_uuid=node.uuid,
created_at=created_at,
group_id=episode.group_id,
)
edges.append(edge)
for node in entity_nodes
]
return edges
@@ -55,6 +55,7 @@ async def extract_edges(
episode: EpisodicNode,
nodes: list[EntityNode],
previous_episodes: list[EpisodicNode],
group_id: str | None,
) -> list[EntityEdge]:
start = time()
@@ -88,6 +89,7 @@ async def extract_edges(
source_node_uuid=edge_data['source_node_uuid'],
target_node_uuid=edge_data['target_node_uuid'],
name=edge_data['relation_type'],
group_id=group_id,
fact=edge_data['fact'],
episodes=[episode.uuid],
created_at=datetime.now(),
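build_episodic_edges now stamps every MENTIONS edge with the episode's group_id, so partition membership flows from episode to edge automatically. A quick check (all values invented):

now = datetime.now()
episode = EpisodicNode(
    name='ep', group_id='1', source=EpisodeType.message, content='...',
    source_description='test', created_at=now, valid_at=now,
)
nodes = [EntityNode(name='Alice', group_id='1', created_at=now)]
edges = build_episodic_edges(nodes, episode, now)
assert edges[0].group_id == '1' and edges[0].source_node_uuid == episode.uuid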

View File

@@ -34,6 +34,10 @@ async def build_indices_and_constraints(driver: AsyncDriver):
'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
'CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)',
'CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)',
'CREATE INDEX entity_group_id IF NOT EXISTS FOR (n:Entity) ON (n.group_id)',
'CREATE INDEX episode_group_id IF NOT EXISTS FOR (n:Episodic) ON (n.group_id)',
'CREATE INDEX relation_group_id IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.group_id)',
'CREATE INDEX mention_group_id IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.group_id)',
'CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)',
'CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)',
'CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)',
@@ -86,6 +90,7 @@ async def retrieve_episodes(
driver: AsyncDriver,
reference_time: datetime,
last_n: int = EPISODE_WINDOW_LEN,
group_ids: list[str | None] | None = None,
) -> list[EpisodicNode]:
"""
Retrieve the last n episodic nodes from the graph.
@@ -96,25 +101,28 @@ async def retrieve_episodes(
less than or equal to this reference_time will be retrieved. This allows for
querying the graph's state at a specific point in time.
last_n (int, optional): The number of most recent episodes to retrieve, relative to the reference_time.
group_ids (list[str | None], optional): The list of group ids to return data from.
Returns:
list[EpisodicNode]: A list of EpisodicNode objects representing the retrieved episodes.
"""
result = await driver.execute_query(
"""
MATCH (e:Episodic) WHERE e.valid_at <= $reference_time
RETURN e.content as content,
e.created_at as created_at,
e.valid_at as valid_at,
e.uuid as uuid,
e.name as name,
e.source_description as source_description,
e.source as source
MATCH (e:Episodic) WHERE e.valid_at <= $reference_time AND e.group_id in $group_ids
RETURN e.content AS content,
e.created_at AS created_at,
e.valid_at AS valid_at,
e.uuid AS uuid,
e.group_id AS group_id,
e.name AS name,
e.source_description AS source_description,
e.source AS source
ORDER BY e.created_at DESC
LIMIT $num_episodes
""",
reference_time=reference_time,
num_episodes=last_n,
group_ids=group_ids,
)
episodes = [
EpisodicNode(
@@ -124,6 +132,7 @@ async def retrieve_episodes(
),
valid_at=(record['valid_at'].to_native()),
uuid=record['uuid'],
group_id=record['group_id'],
source=EpisodeType.from_str(record['source']),
name=record['name'],
source_description=record['source_description'],
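The group_id indices added above back these filtered reads. For reference, the same filter issued directly through the driver (inside an async function; URI and credentials are placeholders):

from neo4j import AsyncGraphDatabase

driver = AsyncGraphDatabase.driver('bolt://localhost:7687', auth=('neo4j', 'password'))
records, _, _ = await driver.execute_query(
    """
    MATCH (e:Episodic)
    WHERE e.valid_at <= $reference_time AND e.group_id IN $group_ids
    RETURN e.uuid AS uuid
    ORDER BY e.created_at DESC
    LIMIT $num_episodes
    """,
    reference_time=datetime.now(),
    group_ids=['1'],
    num_episodes=5,
)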

View File

@@ -85,6 +85,7 @@ async def extract_nodes(
for node_data in extracted_node_data:
new_node = EntityNode(
name=node_data['name'],
group_id=episode.group_id,
labels=node_data['labels'],
summary=node_data['summary'],
created_at=datetime.now(),

View File

@@ -1,60 +0,0 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import logging
from collections import defaultdict
from graphiti_core.edges import EntityEdge, EpisodicEdge
from graphiti_core.nodes import EntityNode, EpisodicNode
logger = logging.getLogger(__name__)
def build_episodic_edges(
entity_nodes: list[EntityNode], episode: EpisodicNode
) -> list[EpisodicEdge]:
edges: list[EpisodicEdge] = []
for node in entity_nodes:
edges.append(
EpisodicEdge(
source_node_uuid=episode.uuid,
target_node_uuid=node.uuid,
created_at=episode.created_at,
)
)
return edges
def chunk_edges_by_nodes(edges: list[EntityEdge]) -> list[list[EntityEdge]]:
# We only want to dedupe edges that are between the same pair of nodes
# We build a map of the edges based on their source and target nodes.
edge_chunk_map: dict[str, list[EntityEdge]] = defaultdict(list)
for edge in edges:
# We drop loop edges
if edge.source_node_uuid == edge.target_node_uuid:
continue
# Keep the order of the two nodes consistent, we want to be direction agnostic during edge resolution
pointers = [edge.source_node_uuid, edge.target_node_uuid]
pointers.sort()
edge_chunk_map[pointers[0] + pointers[1]].append(edge)
edge_chunks = [chunk for chunk in edge_chunk_map.values()]
return edge_chunks

View File

@@ -74,15 +74,15 @@ async def test_graphiti_init():
logger = setup_logging()
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
edges = await graphiti.search('Freakenomics guest')
edges = await graphiti.search('Freakenomics guest', group_ids=['1'])
logger.info('\nQUERY: Freakenomics guest\n' + format_context([edge.fact for edge in edges]))
edges = await graphiti.search('tania tetlow\n')
edges = await graphiti.search('tania tetlow', group_ids=['1'])
logger.info('\nQUERY: Tania Tetlow\n' + format_context([edge.fact for edge in edges]))
edges = await graphiti.search('issues with higher ed')
edges = await graphiti.search('issues with higher ed', group_ids=['1'])
logger.info('\nQUERY: issues with higher ed\n' + format_context([edge.fact for edge in edges]))
graphiti.close()

View File

@@ -33,9 +33,9 @@ def create_test_data():
now = datetime.now()
# Create nodes
node1 = EntityNode(uuid='1', name='Node1', labels=['Person'], created_at=now)
node2 = EntityNode(uuid='2', name='Node2', labels=['Person'], created_at=now)
node3 = EntityNode(uuid='3', name='Node3', labels=['Person'], created_at=now)
node1 = EntityNode(uuid='1', name='Node1', labels=['Person'], created_at=now, group_id='1')
node2 = EntityNode(uuid='2', name='Node2', labels=['Person'], created_at=now, group_id='1')
node3 = EntityNode(uuid='3', name='Node3', labels=['Person'], created_at=now, group_id='1')
# Create edges
existing_edge1 = EntityEdge(
@@ -45,6 +45,7 @@ def create_test_data():
name='KNOWS',
fact='Node1 knows Node2',
created_at=now,
group_id='1',
)
existing_edge2 = EntityEdge(
uuid='e2',
@@ -53,6 +54,7 @@ def create_test_data():
name='LIKES',
fact='Node2 likes Node3',
created_at=now,
group_id='1',
)
new_edge1 = EntityEdge(
uuid='e3',
@@ -61,6 +63,7 @@ def create_test_data():
name='WORKS_WITH',
fact='Node1 works with Node3',
created_at=now,
group_id='1',
)
new_edge2 = EntityEdge(
uuid='e4',
@@ -69,6 +72,7 @@ def create_test_data():
name='DISLIKES',
fact='Node1 dislikes Node2',
created_at=now,
group_id='1',
)
return {
@@ -135,9 +139,9 @@ def test_prepare_invalidation_context():
now = datetime.now()
# Create nodes
node1 = EntityNode(uuid='1', name='Node1', labels=['Person'], created_at=now)
node2 = EntityNode(uuid='2', name='Node2', labels=['Person'], created_at=now)
node3 = EntityNode(uuid='3', name='Node3', labels=['Person'], created_at=now)
node1 = EntityNode(uuid='1', name='Node1', labels=['Person'], created_at=now, group_id='1')
node2 = EntityNode(uuid='2', name='Node2', labels=['Person'], created_at=now, group_id='1')
node3 = EntityNode(uuid='3', name='Node3', labels=['Person'], created_at=now, group_id='1')
# Create edges
edge1 = EntityEdge(
@@ -147,6 +151,7 @@ def test_prepare_invalidation_context():
name='KNOWS',
fact='Node1 knows Node2',
created_at=now,
group_id='1',
)
edge2 = EntityEdge(
uuid='e2',
@@ -155,6 +160,7 @@ def test_prepare_invalidation_context():
name='LIKES',
fact='Node2 likes Node3',
created_at=now,
group_id='1',
)
# Create NodeEdgeNodeTriplet objects
@@ -173,6 +179,7 @@ def test_prepare_invalidation_context():
valid_at=now,
source=EpisodeType.message,
source_description='Test episode for unit testing',
group_id='1',
)
previous_episodes = [
EpisodicNode(
@@ -182,6 +189,7 @@ def test_prepare_invalidation_context():
valid_at=now - timedelta(days=1),
source=EpisodeType.message,
source_description='Test previous episode 1 for unit testing',
group_id='1',
),
EpisodicNode(
name='Previous Episode 2',
@@ -190,6 +198,7 @@ def test_prepare_invalidation_context():
valid_at=now - timedelta(days=2),
source=EpisodeType.message,
source_description='Test previous episode 2 for unit testing',
group_id='1',
),
]
@@ -235,6 +244,7 @@ def test_prepare_invalidation_context_empty_input():
valid_at=now,
source=EpisodeType.message,
source_description='Test empty episode for unit testing',
group_id='1',
)
result = prepare_invalidation_context([], [], current_episode, [])
assert isinstance(result, dict)
@@ -252,8 +262,8 @@ def test_prepare_invalidation_context_sorting():
now = datetime.now()
# Create nodes
node1 = EntityNode(uuid='1', name='Node1', labels=['Person'], created_at=now)
node2 = EntityNode(uuid='2', name='Node2', labels=['Person'], created_at=now)
node1 = EntityNode(uuid='1', name='Node1', labels=['Person'], created_at=now, group_id='1')
node2 = EntityNode(uuid='2', name='Node2', labels=['Person'], created_at=now, group_id='1')
# Create edges with different timestamps
edge1 = EntityEdge(
@@ -263,6 +273,7 @@ def test_prepare_invalidation_context_sorting():
name='KNOWS',
fact='Node1 knows Node2',
created_at=now,
group_id='1',
)
edge2 = EntityEdge(
uuid='e2',
@@ -271,6 +282,7 @@ def test_prepare_invalidation_context_sorting():
name='LIKES',
fact='Node2 likes Node1',
created_at=now + timedelta(hours=1),
group_id='1',
)
edge_with_nodes1 = (node1, edge1, node2)
@@ -287,6 +299,7 @@ def test_prepare_invalidation_context_sorting():
valid_at=now,
source=EpisodeType.message,
source_description='Test episode for unit testing',
group_id='1',
)
previous_episodes = [
EpisodicNode(
@@ -296,6 +309,7 @@ def test_prepare_invalidation_context_sorting():
valid_at=now - timedelta(days=1),
source=EpisodeType.message,
source_description='Test previous episode for unit testing',
group_id='1',
),
]
@@ -321,6 +335,7 @@ class TestExtractDateStringsFromEdge(unittest.TestCase):
created_at=datetime.now(),
valid_at=valid_at,
invalid_at=invalid_at,
group_id='1',
)
def test_both_dates_present(self):

View File

@@ -76,6 +76,7 @@ def create_test_data():
valid_at=now,
source=EpisodeType.message,
source_description='Test episode for unit testing',
group_id='1',
)
# Create previous episodes
@@ -87,6 +88,7 @@ def create_test_data():
valid_at=now - timedelta(days=1),
source=EpisodeType.message,
source_description='Test previous episode for unit testing',
group_id='1',
)
]
@@ -142,10 +144,12 @@ def create_complex_test_data():
now = datetime.now()
# Create nodes
node1 = EntityNode(uuid='1', name='Alice', labels=['Person'], created_at=now)
node2 = EntityNode(uuid='2', name='Bob', labels=['Person'], created_at=now)
node3 = EntityNode(uuid='3', name='Charlie', labels=['Person'], created_at=now)
node4 = EntityNode(uuid='4', name='Company XYZ', labels=['Organization'], created_at=now)
node1 = EntityNode(uuid='1', name='Alice', labels=['Person'], created_at=now, group_id='1')
node2 = EntityNode(uuid='2', name='Bob', labels=['Person'], created_at=now, group_id='1')
node3 = EntityNode(uuid='3', name='Charlie', labels=['Person'], created_at=now, group_id='1')
node4 = EntityNode(
uuid='4', name='Company XYZ', labels=['Organization'], created_at=now, group_id='1'
)
# Create edges
edge1 = EntityEdge(
@@ -154,6 +158,7 @@ def create_complex_test_data():
target_node_uuid='2',
name='LIKES',
fact='Alice likes Bob',
group_id='1',
created_at=now - timedelta(days=5),
)
edge2 = EntityEdge(
@@ -162,6 +167,7 @@ def create_complex_test_data():
target_node_uuid='3',
name='FRIENDS_WITH',
fact='Alice is friends with Charlie',
group_id='1',
created_at=now - timedelta(days=3),
)
edge3 = EntityEdge(
@@ -170,6 +176,7 @@ def create_complex_test_data():
target_node_uuid='4',
name='WORKS_FOR',
fact='Bob works for Company XYZ',
group_id='1',
created_at=now - timedelta(days=2),
)
@@ -199,6 +206,7 @@ async def test_invalidate_edges_complex():
target_node_uuid='2',
name='DISLIKES',
fact='Alice dislikes Bob',
group_id='1',
created_at=datetime.now(),
),
nodes[1],
@@ -225,6 +233,7 @@ async def test_invalidate_edges_temporal_update():
target_node_uuid='4',
name='LEFT_JOB',
fact='Bob left his job at Company XYZ',
group_id='1',
created_at=datetime.now(),
),
nodes[3],
@@ -251,6 +260,7 @@ async def test_invalidate_edges_multiple_invalidations():
target_node_uuid='2',
name='ENEMIES_WITH',
fact='Alice and Bob are now enemies',
group_id='1',
created_at=datetime.now(),
),
nodes[1],
@@ -263,6 +273,7 @@ async def test_invalidate_edges_multiple_invalidations():
target_node_uuid='3',
name='ENDED_FRIENDSHIP',
fact='Alice ended her friendship with Charlie',
group_id='1',
created_at=datetime.now(),
),
nodes[2],
@@ -292,6 +303,7 @@ async def test_invalidate_edges_no_effect():
target_node_uuid='4',
name='APPLIED_TO',
fact='Charlie applied to Company XYZ',
group_id='1',
created_at=datetime.now(),
),
nodes[3],
@@ -316,6 +328,7 @@ async def test_invalidate_edges_partial_update():
target_node_uuid='4',
name='CHANGED_POSITION',
fact='Bob changed his position at Company XYZ',
group_id='1',
created_at=datetime.now(),
),
nodes[3],

View File

@@ -19,12 +19,12 @@ async def test_hybrid_node_search_deduplication():
) as mock_similarity_search:
# Set up mock return values
mock_fulltext_search.side_effect = [
[EntityNode(uuid='1', name='Alice', labels=['Entity'])],
[EntityNode(uuid='2', name='Bob', labels=['Entity'])],
[EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1')],
[EntityNode(uuid='2', name='Bob', labels=['Entity'], group_id='1')],
]
mock_similarity_search.side_effect = [
[EntityNode(uuid='1', name='Alice', labels=['Entity'])],
[EntityNode(uuid='3', name='Charlie', labels=['Entity'])],
[EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1')],
[EntityNode(uuid='3', name='Charlie', labels=['Entity'], group_id='1')],
]
# Call the function with test data
@@ -70,7 +70,9 @@ async def test_hybrid_node_search_only_fulltext():
) as mock_fulltext_search, patch(
'graphiti_core.search.search_utils.entity_similarity_search'
) as mock_similarity_search:
mock_fulltext_search.return_value = [EntityNode(uuid='1', name='Alice', labels=['Entity'])]
mock_fulltext_search.return_value = [
EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1')
]
mock_similarity_search.return_value = []
queries = ['Alice']
@@ -93,18 +95,23 @@ async def test_hybrid_node_search_with_limit():
'graphiti_core.search.search_utils.entity_similarity_search'
) as mock_similarity_search:
mock_fulltext_search.return_value = [
EntityNode(uuid='1', name='Alice', labels=['Entity']),
EntityNode(uuid='2', name='Bob', labels=['Entity']),
EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1'),
EntityNode(uuid='2', name='Bob', labels=['Entity'], group_id='1'),
]
mock_similarity_search.return_value = [
EntityNode(uuid='3', name='Charlie', labels=['Entity']),
EntityNode(uuid='4', name='David', labels=['Entity']),
EntityNode(uuid='3', name='Charlie', labels=['Entity'], group_id='1'),
EntityNode(
uuid='4',
name='David',
labels=['Entity'],
group_id='1',
),
]
queries = ['Test']
embeddings = [[0.1, 0.2, 0.3]]
limit = 1
results = await hybrid_node_search(queries, embeddings, mock_driver, limit)
results = await hybrid_node_search(queries, embeddings, mock_driver, ['1'], limit)
# We expect 4 results because the limit is applied per search method
# before deduplication, and we're not actually limiting the results
@@ -113,8 +120,8 @@ async def test_hybrid_node_search_with_limit():
assert mock_fulltext_search.call_count == 1
assert mock_similarity_search.call_count == 1
# Verify that the limit was passed to the search functions
mock_fulltext_search.assert_called_with('Test', mock_driver, 2)
mock_similarity_search.assert_called_with([0.1, 0.2, 0.3], mock_driver, 2)
mock_fulltext_search.assert_called_with('Test', mock_driver, ['1'], 2)
mock_similarity_search.assert_called_with([0.1, 0.2, 0.3], mock_driver, ['1'], 2)
@pytest.mark.asyncio
@@ -127,18 +134,18 @@ async def test_hybrid_node_search_with_limit_and_duplicates():
'graphiti_core.search.search_utils.entity_similarity_search'
) as mock_similarity_search:
mock_fulltext_search.return_value = [
EntityNode(uuid='1', name='Alice', labels=['Entity']),
EntityNode(uuid='2', name='Bob', labels=['Entity']),
EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1'),
EntityNode(uuid='2', name='Bob', labels=['Entity'], group_id='1'),
]
mock_similarity_search.return_value = [
EntityNode(uuid='1', name='Alice', labels=['Entity']), # Duplicate
EntityNode(uuid='3', name='Charlie', labels=['Entity']),
EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1'), # Duplicate
EntityNode(uuid='3', name='Charlie', labels=['Entity'], group_id='1'),
]
queries = ['Test']
embeddings = [[0.1, 0.2, 0.3]]
limit = 2
results = await hybrid_node_search(queries, embeddings, mock_driver, limit)
results = await hybrid_node_search(queries, embeddings, mock_driver, ['1'], limit)
# We expect 3 results because:
# 1. The limit of 2 is applied to each search method
@@ -148,5 +155,5 @@ async def test_hybrid_node_search_with_limit_and_duplicates():
assert set(node.name for node in results) == {'Alice', 'Bob', 'Charlie'}
assert mock_fulltext_search.call_count == 1
assert mock_similarity_search.call_count == 1
mock_fulltext_search.assert_called_with('Test', mock_driver, 4)
mock_similarity_search.assert_called_with([0.1, 0.2, 0.3], mock_driver, 4)
mock_fulltext_search.assert_called_with('Test', mock_driver, ['1'], 4)
mock_similarity_search.assert_called_with([0.1, 0.2, 0.3], mock_driver, ['1'], 4)