diff --git a/examples/podcast/podcast_runner.py b/examples/podcast/podcast_runner.py index f100926..ec14498 100644 --- a/examples/podcast/podcast_runner.py +++ b/examples/podcast/podcast_runner.py @@ -69,6 +69,7 @@ async def main(use_bulk: bool = True): episode_body=f'{message.speaker_name} ({message.role}): {message.content}', reference_time=message.actual_timestamp, source_description='Podcast Transcript', + group_id='1', ) return diff --git a/graphiti_core/edges.py b/graphiti_core/edges.py index 645a3b3..1e60c94 100644 --- a/graphiti_core/edges.py +++ b/graphiti_core/edges.py @@ -18,6 +18,7 @@ import logging from abc import ABC, abstractmethod from datetime import datetime from time import time +from typing import Any from uuid import uuid4 from neo4j import AsyncDriver @@ -32,6 +33,7 @@ logger = logging.getLogger(__name__) class Edge(BaseModel, ABC): uuid: str = Field(default_factory=lambda: uuid4().hex) + group_id: str | None = Field(description='partition of the graph') source_node_uuid: str target_node_uuid: str created_at: datetime @@ -61,11 +63,12 @@ class EpisodicEdge(Edge): MATCH (episode:Episodic {uuid: $episode_uuid}) MATCH (node:Entity {uuid: $entity_uuid}) MERGE (episode)-[r:MENTIONS {uuid: $uuid}]->(node) - SET r = {uuid: $uuid, created_at: $created_at} + SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at} RETURN r.uuid AS uuid""", episode_uuid=self.source_node_uuid, entity_uuid=self.target_node_uuid, uuid=self.uuid, + group_id=self.group_id, created_at=self.created_at, ) @@ -92,7 +95,8 @@ class EpisodicEdge(Edge): """ MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity) RETURN - e.uuid As uuid, + e.uuid As uuid, + e.group_id AS group_id, n.uuid AS source_node_uuid, m.uuid AS target_node_uuid, e.created_at AS created_at @@ -100,17 +104,7 @@ class EpisodicEdge(Edge): uuid=uuid, ) - edges: list[EpisodicEdge] = [] - - for record in records: - edges.append( - EpisodicEdge( - uuid=record['uuid'], - source_node_uuid=record['source_node_uuid'], - target_node_uuid=record['target_node_uuid'], - created_at=record['created_at'].to_native(), - ) - ) + edges = [get_episodic_edge_from_record(record) for record in records] logger.info(f'Found Edge: {uuid}') @@ -153,7 +147,7 @@ class EntityEdge(Edge): MATCH (source:Entity {uuid: $source_uuid}) MATCH (target:Entity {uuid: $target_uuid}) MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target) - SET r = {uuid: $uuid, name: $name, fact: $fact, fact_embedding: $fact_embedding, + SET r = {uuid: $uuid, name: $name, group_id: $group_id, fact: $fact, fact_embedding: $fact_embedding, episodes: $episodes, created_at: $created_at, expired_at: $expired_at, valid_at: $valid_at, invalid_at: $invalid_at} RETURN r.uuid AS uuid""", @@ -161,6 +155,7 @@ class EntityEdge(Edge): target_uuid=self.target_node_uuid, uuid=self.uuid, name=self.name, + group_id=self.group_id, fact=self.fact, fact_embedding=self.fact_embedding, episodes=self.episodes, @@ -198,6 +193,7 @@ class EntityEdge(Edge): m.uuid AS target_node_uuid, e.created_at AS created_at, e.name AS name, + e.group_id AS group_id, e.fact AS fact, e.fact_embedding AS fact_embedding, e.episodes AS episodes, @@ -208,25 +204,36 @@ class EntityEdge(Edge): uuid=uuid, ) - edges: list[EntityEdge] = [] - - for record in records: - edges.append( - EntityEdge( - uuid=record['uuid'], - source_node_uuid=record['source_node_uuid'], - target_node_uuid=record['target_node_uuid'], - fact=record['fact'], - name=record['name'], - episodes=record['episodes'], - fact_embedding=record['fact_embedding'], - 
created_at=record['created_at'].to_native(), - expired_at=parse_db_date(record['expired_at']), - valid_at=parse_db_date(record['valid_at']), - invalid_at=parse_db_date(record['invalid_at']), - ) - ) + edges = [get_entity_edge_from_record(record) for record in records] logger.info(f'Found Edge: {uuid}') return edges[0] + + +# Edge helpers +def get_episodic_edge_from_record(record: Any) -> EpisodicEdge: + return EpisodicEdge( + uuid=record['uuid'], + group_id=record['group_id'], + source_node_uuid=record['source_node_uuid'], + target_node_uuid=record['target_node_uuid'], + created_at=record['created_at'].to_native(), + ) + + +def get_entity_edge_from_record(record: Any) -> EntityEdge: + return EntityEdge( + uuid=record['uuid'], + source_node_uuid=record['source_node_uuid'], + target_node_uuid=record['target_node_uuid'], + fact=record['fact'], + name=record['name'], + group_id=record['group_id'], + episodes=record['episodes'], + fact_embedding=record['fact_embedding'], + created_at=record['created_at'].to_native(), + expired_at=parse_db_date(record['expired_at']), + valid_at=parse_db_date(record['valid_at']), + invalid_at=parse_db_date(record['invalid_at']), + ) diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index 6684fcc..0b4e632 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -18,7 +18,6 @@ import asyncio import logging from datetime import datetime from time import time -from typing import Callable from dotenv import load_dotenv from neo4j import AsyncGraphDatabase @@ -120,7 +119,7 @@ class Graphiti: Parameters ---------- - None + self Returns ------- @@ -151,7 +150,7 @@ class Graphiti: Parameters ---------- - None + self Returns ------- @@ -178,6 +177,7 @@ class Graphiti: self, reference_time: datetime, last_n: int = EPISODE_WINDOW_LEN, + group_ids: list[str | None] | None = None, ) -> list[EpisodicNode]: """ Retrieve the last n episodic nodes from the graph. @@ -191,6 +191,8 @@ class Graphiti: The reference time to retrieve episodes before. last_n : int, optional The number of episodes to retrieve. Defaults to EPISODE_WINDOW_LEN. + group_ids : list[str | None], optional + The group ids to return data from. Returns ------- @@ -202,7 +204,7 @@ class Graphiti: The actual retrieval is performed by the `retrieve_episodes` function from the `graphiti_core.utils` module. """ - return await retrieve_episodes(self.driver, reference_time, last_n) + return await retrieve_episodes(self.driver, reference_time, last_n, group_ids) async def add_episode( self, @@ -211,8 +213,8 @@ class Graphiti: source_description: str, reference_time: datetime, source: EpisodeType = EpisodeType.message, - success_callback: Callable | None = None, - error_callback: Callable | None = None, + group_id: str | None = None, + uuid: str | None = None, ): """ Process an episode and update the graph. @@ -232,10 +234,10 @@ class Graphiti: The reference time for the episode. source : EpisodeType, optional The type of the episode. Defaults to EpisodeType.message. - success_callback : Callable | None, optional - A callback function to be called upon successful processing. - error_callback : Callable | None, optional - A callback function to be called if an error occurs during processing. + group_id : str | None + An id for the graph partition the episode is a part of. + uuid : str | None + Optional uuid of the episode. 
Returns ------- @@ -266,9 +268,12 @@ class Graphiti: embedder = self.llm_client.get_embedder() now = datetime.now() - previous_episodes = await self.retrieve_episodes(reference_time, last_n=3) + previous_episodes = await self.retrieve_episodes( + reference_time, last_n=3, group_ids=[group_id] + ) episode = EpisodicNode( name=name, + group_id=group_id, labels=[], source=source, content=episode_body, @@ -276,6 +281,7 @@ class Graphiti: created_at=now, valid_at=reference_time, ) + episode.uuid = uuid if uuid is not None else episode.uuid # Extract entities as nodes @@ -299,7 +305,9 @@ class Graphiti: (mentioned_nodes, uuid_map), extracted_edges = await asyncio.gather( resolve_extracted_nodes(self.llm_client, extracted_nodes, existing_nodes_lists), - extract_edges(self.llm_client, episode, extracted_nodes, previous_episodes), + extract_edges( + self.llm_client, episode, extracted_nodes, previous_episodes, group_id + ), ) logger.info(f'Adjusted mentioned nodes: {[(n.name, n.uuid) for n in mentioned_nodes]}') nodes.extend(mentioned_nodes) @@ -388,11 +396,7 @@ class Graphiti: logger.info(f'Resolved edges: {[(e.name, e.uuid) for e in resolved_edges]}') - episodic_edges: list[EpisodicEdge] = build_episodic_edges( - mentioned_nodes, - episode, - now, - ) + episodic_edges: list[EpisodicEdge] = build_episodic_edges(mentioned_nodes, episode, now) logger.info(f'Built episodic edges: {episodic_edges}') @@ -405,18 +409,10 @@ class Graphiti: end = time() logger.info(f'Completed add_episode in {(end - start) * 1000} ms') - if success_callback: - await success_callback(episode) except Exception as e: - if error_callback: - await error_callback(episode, e) - else: - raise e + raise e - async def add_episode_bulk( - self, - bulk_episodes: list[RawEpisode], - ): + async def add_episode_bulk(self, bulk_episodes: list[RawEpisode], group_id: str | None): """ Process multiple episodes in bulk and update the graph. @@ -427,6 +423,8 @@ class Graphiti: ---------- bulk_episodes : list[RawEpisode] A list of RawEpisode objects to be processed and added to the graph. + group_id : str | None + An id for the graph partition the episode is a part of. Returns ------- @@ -463,6 +461,7 @@ class Graphiti: source=episode.source, content=episode.content, source_description=episode.source_description, + group_id=group_id, created_at=now, valid_at=episode.reference_time, ) @@ -527,7 +526,13 @@ class Graphiti: except Exception as e: raise e - async def search(self, query: str, center_node_uuid: str | None = None, num_results=10): + async def search( + self, + query: str, + center_node_uuid: str | None = None, + group_ids: list[str | None] | None = None, + num_results=10, + ): """ Perform a hybrid search on the knowledge graph. @@ -540,6 +545,8 @@ class Graphiti: The search query string. center_node_uuid: str, optional Facts will be reranked based on proximity to this node + group_ids : list[str | None] | None, optional + The graph partitions to return data from. num_results : int, optional The maximum number of results to return. Defaults to 10. 
@@ -562,6 +569,7 @@ class Graphiti:
             num_episodes=0,
             num_edges=num_results,
             num_nodes=0,
+            group_ids=group_ids,
             search_methods=[SearchMethod.bm25, SearchMethod.cosine_similarity],
             reranker=reranker,
         )
@@ -590,7 +598,10 @@ class Graphiti:
         )
 
     async def get_nodes_by_query(
-        self, query: str, limit: int = RELEVANT_SCHEMA_LIMIT
+        self,
+        query: str,
+        group_ids: list[str | None] | None = None,
+        limit: int = RELEVANT_SCHEMA_LIMIT,
     ) -> list[EntityNode]:
         """
         Retrieve nodes from the graph database based on a text query.
@@ -602,6 +613,8 @@
         ----------
         query : str
             The text query to search for in the graph.
+        group_ids : list[str | None] | None, optional
+            The graph partitions to return data from.
         limit : int | None, optional
             The maximum number of results to return per search method.
             If None, a default limit will be applied.
@@ -626,5 +639,7 @@
         """
         embedder = self.llm_client.get_embedder()
         query_embedding = await generate_embedding(embedder, query)
-        relevant_nodes = await hybrid_node_search([query], [query_embedding], self.driver, limit)
+        relevant_nodes = await hybrid_node_search(
+            [query], [query_embedding], self.driver, group_ids, limit
+        )
         return relevant_nodes
diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py
index f30d001..907d52b 100644
--- a/graphiti_core/nodes.py
+++ b/graphiti_core/nodes.py
@@ -19,10 +19,10 @@ from abc import ABC, abstractmethod
 from datetime import datetime
 from enum import Enum
 from time import time
+from typing import Any
 from uuid import uuid4
 
 from neo4j import AsyncDriver
-from openai import OpenAI
 from pydantic import BaseModel, Field
 
 from graphiti_core.llm_client.config import EMBEDDING_DIM
@@ -69,6 +69,7 @@ class EpisodeType(Enum):
 class Node(BaseModel, ABC):
     uuid: str = Field(default_factory=lambda: uuid4().hex)
     name: str = Field(description='name of the node')
+    group_id: str | None = Field(description='partition of the graph')
     labels: list[str] = Field(default_factory=list)
     created_at: datetime = Field(default_factory=lambda: datetime.now())
 
@@ -106,11 +107,12 @@ class EpisodicNode(Node):
         result = await driver.execute_query(
             """
         MERGE (n:Episodic {uuid: $uuid})
-        SET n = {uuid: $uuid, name: $name, source_description: $source_description, source: $source, content: $content,
+        SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
         entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
         RETURN n.uuid AS uuid""",
             uuid=self.uuid,
             name=self.name,
+            group_id=self.group_id,
             source_description=self.source_description,
             content=self.content,
             entity_edges=self.entity_edges,
@@ -141,29 +143,19 @@ class EpisodicNode(Node):
         records, _, _ = await driver.execute_query(
             """
         MATCH (e:Episodic {uuid: $uuid})
-            RETURN e.content as content,
-            e.created_at as created_at,
-            e.valid_at as valid_at,
-            e.uuid as uuid,
-            e.name as name,
-            e.source_description as source_description,
-            e.source as source
+            RETURN e.content AS content,
+            e.created_at AS created_at,
+            e.valid_at AS valid_at,
+            e.uuid AS uuid,
+            e.name AS name,
+            e.group_id AS group_id,
+            e.source_description AS source_description,
+            e.source AS source
         """,
             uuid=uuid,
         )
 
-        episodes = [
-            EpisodicNode(
-                content=record['content'],
-                created_at=record['created_at'].to_native().timestamp(),
-                valid_at=(record['valid_at'].to_native()),
-                uuid=record['uuid'],
-                source=EpisodeType.from_str(record['source']),
-                name=record['name'],
-                source_description=record['source_description'],
-            )
-            for record in records
-        ]
+        episodes = [get_episodic_node_from_record(record) for record in records]
 
         logger.info(f'Found Node: {uuid}')
 
@@ -174,10 +166,6 @@ class EntityNode(Node):
     name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
     summary: str = Field(description='regional summary of surrounding edges', default_factory=str)
 
-    async def update_summary(self, driver: AsyncDriver): ...
-
-    async def refresh_summary(self, driver: AsyncDriver, llm_client: OpenAI): ...
-
     async def generate_name_embedding(self, embedder, model='text-embedding-3-small'):
         start = time()
         text = self.name.replace('\n', ' ')
@@ -192,10 +180,11 @@ class EntityNode(Node):
         result = await driver.execute_query(
             """
         MERGE (n:Entity {uuid: $uuid})
-        SET n = {uuid: $uuid, name: $name, name_embedding: $name_embedding, summary: $summary, created_at: $created_at}
+        SET n = {uuid: $uuid, name: $name, name_embedding: $name_embedding, group_id: $group_id, summary: $summary, created_at: $created_at}
         RETURN n.uuid AS uuid""",
             uuid=self.uuid,
             name=self.name,
+            group_id=self.group_id,
             summary=self.summary,
             name_embedding=self.name_embedding,
             created_at=self.created_at,
@@ -227,25 +216,14 @@ class EntityNode(Node):
             n.uuid As uuid,
             n.name AS name,
             n.name_embedding AS name_embedding,
+            n.group_id AS group_id,
             n.created_at AS created_at,
             n.summary AS summary
         """,
             uuid=uuid,
         )
 
-        nodes: list[EntityNode] = []
-
-        for record in records:
-            nodes.append(
-                EntityNode(
-                    uuid=record['uuid'],
-                    name=record['name'],
-                    name_embedding=record['name_embedding'],
-                    labels=['Entity'],
-                    created_at=record['created_at'].to_native(),
-                    summary=record['summary'],
-                )
-            )
+        nodes = [get_entity_node_from_record(record) for record in records]
 
         logger.info(f'Found Node: {uuid}')
 
@@ -253,3 +231,26 @@ class EntityNode(Node):
 
 
 # Node helpers
+def get_episodic_node_from_record(record: Any) -> EpisodicNode:
+    return EpisodicNode(
+        content=record['content'],
+        created_at=record['created_at'].to_native().timestamp(),
+        valid_at=(record['valid_at'].to_native()),
+        uuid=record['uuid'],
+        group_id=record['group_id'],
+        source=EpisodeType.from_str(record['source']),
+        name=record['name'],
+        source_description=record['source_description'],
+    )
+
+
+def get_entity_node_from_record(record: Any) -> EntityNode:
+    return EntityNode(
+        uuid=record['uuid'],
+        name=record['name'],
+        group_id=record['group_id'],
+        name_embedding=record['name_embedding'],
+        labels=['Entity'],
+        created_at=record['created_at'].to_native(),
+        summary=record['summary'],
+    )
diff --git a/graphiti_core/search/search.py b/graphiti_core/search/search.py
index 172cd12..3e4c59f 100644
--- a/graphiti_core/search/search.py
+++ b/graphiti_core/search/search.py
@@ -52,6 +52,7 @@ class SearchConfig(BaseModel):
     num_edges: int = Field(default=10)
     num_nodes: int = Field(default=10)
     num_episodes: int = EPISODE_WINDOW_LEN
+    group_ids: list[str | None] | None
     search_methods: list[SearchMethod]
     reranker: Reranker | None
 
@@ -83,7 +84,9 @@ async def hybrid_search(
         nodes.extend(await get_mentioned_nodes(driver, episodes))
 
     if SearchMethod.bm25 in config.search_methods:
-        text_search = await edge_fulltext_search(driver, query, None, None, 2 * config.num_edges)
+        text_search = await edge_fulltext_search(
+            driver, query, None, None, config.group_ids, 2 * config.num_edges
+        )
         search_results.append(text_search)
 
     if SearchMethod.cosine_similarity in config.search_methods:
@@ -95,7 +98,7 @@
         )
 
         similarity_search = await edge_similarity_search(
-            driver, search_vector, None, None, 2 * config.num_edges
+            driver, search_vector, None, None, config.group_ids, 2 * config.num_edges
         )
         search_results.append(similarity_search)
 
diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py
index 3f7987b..5b63d30 100644
--- a/graphiti_core/search/search_utils.py
+++ b/graphiti_core/search/search_utils.py
@@ -3,13 +3,11 @@ import logging
 import re
 from collections import defaultdict
 from time import time
-from typing import Any
 
 from neo4j import AsyncDriver, Query
 
-from graphiti_core.edges import EntityEdge
-from graphiti_core.helpers import parse_db_date
-from graphiti_core.nodes import EntityNode, EpisodicNode
+from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
+from graphiti_core.nodes import EntityNode, EpisodicNode, get_entity_node_from_record
 
 logger = logging.getLogger(__name__)
 
@@ -23,6 +21,7 @@ async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode])
         MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity)
         WHERE episode.uuid IN $uuids
         RETURN DISTINCT n.uuid As uuid,
+            n.group_id AS group_id,
             n.name AS name,
             n.name_embedding AS name_embedding,
             n.created_at AS created_at,
@@ -31,86 +30,29 @@ async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode])
         uuids=episode_uuids,
     )
 
-    nodes: list[EntityNode] = []
-
-    for record in records:
-        nodes.append(
-            EntityNode(
-                uuid=record['uuid'],
-                name=record['name'],
-                name_embedding=record['name_embedding'],
-                labels=['Entity'],
-                created_at=record['created_at'].to_native(),
-                summary=record['summary'],
-            )
-        )
+    nodes = [get_entity_node_from_record(record) for record in records]
 
     return nodes
 
 
-async def bfs(node_ids: list[str], driver: AsyncDriver):
-    records, _, _ = await driver.execute_query(
-        """
-        MATCH (n WHERE n.uuid in $node_ids)-[r]->(m)
-        RETURN DISTINCT
-            n.uuid AS source_node_uuid,
-            n.name AS source_name,
-            n.summary AS source_summary,
-            m.uuid AS target_node_uuid,
-            m.name AS target_name,
-            m.summary AS target_summary,
-            r.uuid AS uuid,
-            r.created_at AS created_at,
-            r.name AS name,
-            r.fact AS fact,
-            r.fact_embedding AS fact_embedding,
-            r.episodes AS episodes,
-            r.expired_at AS expired_at,
-            r.valid_at AS valid_at,
-            r.invalid_at AS invalid_at
-
-        """,
-        node_ids=node_ids,
-    )
-
-    context: dict[str, Any] = {}
-
-    for record in records:
-        n_uuid = record['source_node_uuid']
-        if n_uuid in context:
-            context[n_uuid]['facts'].append(record['fact'])
-        else:
-            context[n_uuid] = {
-                'name': record['source_name'],
-                'summary': record['source_summary'],
-                'facts': [record['fact']],
-            }
-
-        m_uuid = record['target_node_uuid']
-        if m_uuid not in context:
-            context[m_uuid] = {
-                'name': record['target_name'],
-                'summary': record['target_summary'],
-                'facts': [],
-            }
-    logger.info(f'bfs search returned context: {context}')
-    return context
-
-
 async def edge_similarity_search(
     driver: AsyncDriver,
     search_vector: list[float],
     source_node_uuid: str | None,
     target_node_uuid: str | None,
+    group_ids: list[str | None] | None = None,
     limit: int = RELEVANT_SCHEMA_LIMIT,
 ) -> list[EntityEdge]:
+    group_ids = group_ids if group_ids is not None else [None]
     # vector similarity search over embedded facts
     query = Query("""
                 CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
                 YIELD relationship AS rel, score
                 MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
+                WHERE r.group_id IN $group_ids
                 RETURN
                     r.uuid AS uuid,
+                    r.group_id AS group_id,
                     n.uuid AS source_node_uuid,
                     m.uuid AS target_node_uuid,
r.created_at AS created_at, @@ -129,8 +71,10 @@ async def edge_similarity_search( CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector) YIELD relationship AS rel, score MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity) + WHERE r.group_id IN $group_ids RETURN r.uuid AS uuid, + r.group_id AS group_id, n.uuid AS source_node_uuid, m.uuid AS target_node_uuid, r.created_at AS created_at, @@ -148,8 +92,10 @@ async def edge_similarity_search( CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector) YIELD relationship AS rel, score MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid}) + WHERE r.group_id IN $group_ids RETURN r.uuid AS uuid, + r.group_id AS group_id, n.uuid AS source_node_uuid, m.uuid AS target_node_uuid, r.created_at AS created_at, @@ -167,8 +113,10 @@ async def edge_similarity_search( CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector) YIELD relationship AS rel, score MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity) + WHERE r.group_id IN $group_ids RETURN r.uuid AS uuid, + r.group_id AS group_id, n.uuid AS source_node_uuid, m.uuid AS target_node_uuid, r.created_at AS created_at, @@ -187,41 +135,32 @@ async def edge_similarity_search( search_vector=search_vector, source_uuid=source_node_uuid, target_uuid=target_node_uuid, + group_ids=group_ids, limit=limit, ) - edges: list[EntityEdge] = [] - - for record in records: - edge = EntityEdge( - uuid=record['uuid'], - source_node_uuid=record['source_node_uuid'], - target_node_uuid=record['target_node_uuid'], - fact=record['fact'], - name=record['name'], - episodes=record['episodes'], - fact_embedding=record['fact_embedding'], - created_at=record['created_at'].to_native(), - expired_at=parse_db_date(record['expired_at']), - valid_at=parse_db_date(record['valid_at']), - invalid_at=parse_db_date(record['invalid_at']), - ) - - edges.append(edge) + edges = [get_entity_edge_from_record(record) for record in records] return edges async def entity_similarity_search( - search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT + search_vector: list[float], + driver: AsyncDriver, + group_ids: list[str | None] | None = None, + limit=RELEVANT_SCHEMA_LIMIT, ) -> list[EntityNode]: + group_ids = group_ids if group_ids is not None else [None] + # vector similarity search over entity names records, _, _ = await driver.execute_query( """ CALL db.index.vector.queryNodes("name_embedding", $limit, $search_vector) YIELD node AS n, score + MATCH (n WHERE n.group_id IN $group_ids) RETURN - n.uuid As uuid, + n.uuid As uuid, + n.group_id AS group_id, n.name AS name, n.name_embedding AS name_embedding, n.created_at AS created_at, @@ -229,58 +168,44 @@ async def entity_similarity_search( ORDER BY score DESC """, search_vector=search_vector, + group_ids=group_ids, limit=limit, ) - nodes: list[EntityNode] = [] - - for record in records: - nodes.append( - EntityNode( - uuid=record['uuid'], - name=record['name'], - name_embedding=record['name_embedding'], - labels=['Entity'], - created_at=record['created_at'].to_native(), - summary=record['summary'], - ) - ) + nodes = [get_entity_node_from_record(record) for record in records] return nodes async def entity_fulltext_search( - query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT + query: str, + driver: AsyncDriver, + group_ids: list[str | None] | None = None, + limit=RELEVANT_SCHEMA_LIMIT, ) -> list[EntityNode]: + group_ids = group_ids if group_ids is not None else [None] + # BM25 
search to get top nodes fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~' records, _, _ = await driver.execute_query( """ - CALL db.index.fulltext.queryNodes("name_and_summary", $query) YIELD node, score + CALL db.index.fulltext.queryNodes("name_and_summary", $query) + YIELD node AS n, score + MATCH (n WHERE n.group_id in $group_ids) RETURN - node.uuid AS uuid, - node.name AS name, - node.name_embedding AS name_embedding, - node.created_at AS created_at, - node.summary AS summary + n.uuid AS uuid, + n.group_id AS group_id, + n.name AS name, + n.name_embedding AS name_embedding, + n.created_at AS created_at, + n.summary AS summary ORDER BY score DESC LIMIT $limit """, query=fuzzy_query, + group_ids=group_ids, limit=limit, ) - nodes: list[EntityNode] = [] - - for record in records: - nodes.append( - EntityNode( - uuid=record['uuid'], - name=record['name'], - name_embedding=record['name_embedding'], - labels=['Entity'], - created_at=record['created_at'].to_native(), - summary=record['summary'], - ) - ) + nodes = [get_entity_node_from_record(record) for record in records] return nodes @@ -290,15 +215,20 @@ async def edge_fulltext_search( query: str, source_node_uuid: str | None, target_node_uuid: str | None, + group_ids: list[str | None] | None = None, limit=RELEVANT_SCHEMA_LIMIT, ) -> list[EntityEdge]: + group_ids = group_ids if group_ids is not None else [None] + # fulltext search over facts cypher_query = Query(""" CALL db.index.fulltext.queryRelationships("name_and_fact", $query) YIELD relationship AS rel, score MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid}) + WHERE r.group_id IN $group_ids RETURN r.uuid AS uuid, + r.group_id AS group_id, n.uuid AS source_node_uuid, m.uuid AS target_node_uuid, r.created_at AS created_at, @@ -317,8 +247,10 @@ async def edge_fulltext_search( CALL db.index.fulltext.queryRelationships("name_and_fact", $query) YIELD relationship AS rel, score MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity) + WHERE r.group_id IN $group_ids RETURN r.uuid AS uuid, + r.group_id AS group_id, n.uuid AS source_node_uuid, m.uuid AS target_node_uuid, r.created_at AS created_at, @@ -335,9 +267,11 @@ async def edge_fulltext_search( cypher_query = Query(""" CALL db.index.fulltext.queryRelationships("name_and_fact", $query) YIELD relationship AS rel, score - MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid}) + MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid}) + WHERE r.group_id IN $group_ids RETURN r.uuid AS uuid, + r.group_id AS group_id, n.uuid AS source_node_uuid, m.uuid AS target_node_uuid, r.created_at AS created_at, @@ -354,9 +288,11 @@ async def edge_fulltext_search( cypher_query = Query(""" CALL db.index.fulltext.queryRelationships("name_and_fact", $query) YIELD relationship AS rel, score - MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity) + MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity) + WHERE r.group_id IN $group_ids RETURN r.uuid AS uuid, + r.group_id AS group_id, n.uuid AS source_node_uuid, m.uuid AS target_node_uuid, r.created_at AS created_at, @@ -377,27 +313,11 @@ async def edge_fulltext_search( query=fuzzy_query, source_uuid=source_node_uuid, target_uuid=target_node_uuid, + group_ids=group_ids, limit=limit, ) - edges: list[EntityEdge] = [] - - for record in records: - edge = EntityEdge( - uuid=record['uuid'], - source_node_uuid=record['source_node_uuid'], - target_node_uuid=record['target_node_uuid'], - fact=record['fact'], - name=record['name'], - 
episodes=record['episodes'], - fact_embedding=record['fact_embedding'], - created_at=record['created_at'].to_native(), - expired_at=parse_db_date(record['expired_at']), - valid_at=parse_db_date(record['valid_at']), - invalid_at=parse_db_date(record['invalid_at']), - ) - - edges.append(edge) + edges = [get_entity_edge_from_record(record) for record in records] return edges @@ -406,6 +326,7 @@ async def hybrid_node_search( queries: list[str], embeddings: list[list[float]], driver: AsyncDriver, + group_ids: list[str | None] | None = None, limit: int = RELEVANT_SCHEMA_LIMIT, ) -> list[EntityNode]: """ @@ -422,6 +343,8 @@ async def hybrid_node_search( A list of embedding vectors corresponding to the queries. If empty only fulltext search is performed. driver : AsyncDriver The Neo4j driver instance for database operations. + group_ids : list[str] | None, optional + The list of group ids to retrieve nodes from. limit : int | None, optional The maximum number of results to return per search method. If None, a default limit will be applied. @@ -448,8 +371,8 @@ async def hybrid_node_search( results: list[list[EntityNode]] = list( await asyncio.gather( - *[entity_fulltext_search(q, driver, 2 * limit) for q in queries], - *[entity_similarity_search(e, driver, 2 * limit) for e in embeddings], + *[entity_fulltext_search(q, driver, group_ids, 2 * limit) for q in queries], + *[entity_similarity_search(e, driver, group_ids, 2 * limit) for e in embeddings], ) ) @@ -500,6 +423,7 @@ async def get_relevant_nodes( [node.name for node in nodes], [node.name_embedding for node in nodes if node.name_embedding is not None], driver, + [node.group_id for node in nodes], ) return relevant_nodes @@ -518,13 +442,20 @@ async def get_relevant_edges( results = await asyncio.gather( *[ edge_similarity_search( - driver, edge.fact_embedding, source_node_uuid, target_node_uuid, limit + driver, + edge.fact_embedding, + source_node_uuid, + target_node_uuid, + [edge.group_id], + limit, ) for edge in edges if edge.fact_embedding is not None ], *[ - edge_fulltext_search(driver, edge.fact, source_node_uuid, target_node_uuid, limit) + edge_fulltext_search( + driver, edge.fact, source_node_uuid, target_node_uuid, [edge.group_id], limit + ) for edge in edges ], ) diff --git a/graphiti_core/utils/bulk_utils.py b/graphiti_core/utils/bulk_utils.py index 4c8f12a..49bc2c6 100644 --- a/graphiti_core/utils/bulk_utils.py +++ b/graphiti_core/utils/bulk_utils.py @@ -17,6 +17,7 @@ limitations under the License. 
 import asyncio
 import logging
 import typing
+from collections import defaultdict
 from datetime import datetime
 from math import ceil
 
@@ -42,7 +43,6 @@ from graphiti_core.utils.maintenance.node_operations import (
     extract_nodes,
 )
 from graphiti_core.utils.maintenance.temporal_operations import extract_edge_dates
-from graphiti_core.utils.utils import chunk_edges_by_nodes
 
 logger = logging.getLogger(__name__)
 
@@ -62,7 +62,9 @@ async def retrieve_previous_episodes_bulk(
 ) -> list[tuple[EpisodicNode, list[EpisodicNode]]]:
     previous_episodes_list = await asyncio.gather(
         *[
-            retrieve_episodes(driver, episode.valid_at, last_n=EPISODE_WINDOW_LEN)
+            retrieve_episodes(
+                driver, episode.valid_at, last_n=EPISODE_WINDOW_LEN, group_ids=[episode.group_id]
+            )
             for episode in episodes
         ]
     )
@@ -90,7 +92,13 @@ async def extract_nodes_and_edges_bulk(
 
     extracted_edges_bulk = await asyncio.gather(
         *[
-            extract_edges(llm_client, episode, extracted_nodes_bulk[i], previous_episodes_list[i])
+            extract_edges(
+                llm_client,
+                episode,
+                extracted_nodes_bulk[i],
+                previous_episodes_list[i],
+                episode.group_id,
+            )
             for i, episode in enumerate(episodes)
         ]
     )
@@ -343,3 +351,23 @@ async def extract_edge_dates_bulk(
             edge.expired_at = datetime.now()
 
     return edges
+
+
+def chunk_edges_by_nodes(edges: list[EntityEdge]) -> list[list[EntityEdge]]:
+    # We only want to dedupe edges that are between the same pair of nodes
+    # We build a map of the edges based on their source and target nodes.
+    edge_chunk_map: dict[str, list[EntityEdge]] = defaultdict(list)
+    for edge in edges:
+        # We drop loop edges
+        if edge.source_node_uuid == edge.target_node_uuid:
+            continue
+
+        # Keep the order of the two nodes consistent, we want to be direction agnostic during edge resolution
+        pointers = [edge.source_node_uuid, edge.target_node_uuid]
+        pointers.sort()
+
+        edge_chunk_map[pointers[0] + pointers[1]].append(edge)
+
+    edge_chunks = [chunk for chunk in edge_chunk_map.values()]
+
+    return edge_chunks
diff --git a/graphiti_core/utils/maintenance/edge_operations.py b/graphiti_core/utils/maintenance/edge_operations.py
index 0d6aa9e..4518c8d 100644
--- a/graphiti_core/utils/maintenance/edge_operations.py
+++ b/graphiti_core/utils/maintenance/edge_operations.py
@@ -37,15 +37,15 @@ def build_episodic_edges(
     episode: EpisodicNode,
     created_at: datetime,
 ) -> List[EpisodicEdge]:
-    edges: List[EpisodicEdge] = []
-
-    for node in entity_nodes:
-        edge = EpisodicEdge(
+    edges: List[EpisodicEdge] = [
+        EpisodicEdge(
             source_node_uuid=episode.uuid,
             target_node_uuid=node.uuid,
             created_at=created_at,
+            group_id=episode.group_id,
         )
-        edges.append(edge)
+        for node in entity_nodes
+    ]
 
     return edges
 
@@ -55,6 +55,7 @@ async def extract_edges(
     episode: EpisodicNode,
     nodes: list[EntityNode],
     previous_episodes: list[EpisodicNode],
+    group_id: str | None,
 ) -> list[EntityEdge]:
     start = time()
 
@@ -88,6 +89,7 @@ async def extract_edges(
             source_node_uuid=edge_data['source_node_uuid'],
             target_node_uuid=edge_data['target_node_uuid'],
             name=edge_data['relation_type'],
+            group_id=group_id,
             fact=edge_data['fact'],
             episodes=[episode.uuid],
             created_at=datetime.now(),
diff --git a/graphiti_core/utils/maintenance/graph_data_operations.py b/graphiti_core/utils/maintenance/graph_data_operations.py
index 38620a8..a942a00 100644
--- a/graphiti_core/utils/maintenance/graph_data_operations.py
+++ b/graphiti_core/utils/maintenance/graph_data_operations.py
@@ -34,6 +34,10 @@ async def build_indices_and_constraints(driver: AsyncDriver):
         'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
         'CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)',
         'CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)',
+        'CREATE INDEX entity_group_id IF NOT EXISTS FOR (n:Entity) ON (n.group_id)',
+        'CREATE INDEX episode_group_id IF NOT EXISTS FOR (n:Episodic) ON (n.group_id)',
+        'CREATE INDEX relation_group_id IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.group_id)',
+        'CREATE INDEX mention_group_id IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.group_id)',
         'CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)',
         'CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)',
         'CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)',
@@ -86,6 +90,7 @@ async def retrieve_episodes(
     driver: AsyncDriver,
     reference_time: datetime,
     last_n: int = EPISODE_WINDOW_LEN,
+    group_ids: list[str | None] | None = None,
 ) -> list[EpisodicNode]:
     """
     Retrieve the last n episodic nodes from the graph.
@@ -96,25 +101,28 @@
         less than or equal to this reference_time will be retrieved. This allows for querying the
         graph's state at a specific point in time.
         last_n (int, optional): The number of most recent episodes to retrieve, relative to the reference_time.
+        group_ids (list[str], optional): The list of group ids to return data from.
 
     Returns:
         list[EpisodicNode]: A list of EpisodicNode objects representing the retrieved episodes.
     """
     result = await driver.execute_query(
         """
-        MATCH (e:Episodic) WHERE e.valid_at <= $reference_time
-        RETURN e.content as content,
-            e.created_at as created_at,
-            e.valid_at as valid_at,
-            e.uuid as uuid,
-            e.name as name,
-            e.source_description as source_description,
-            e.source as source
+        MATCH (e:Episodic) WHERE e.valid_at <= $reference_time AND e.group_id IN $group_ids
+        RETURN e.content AS content,
+            e.created_at AS created_at,
+            e.valid_at AS valid_at,
+            e.uuid AS uuid,
+            e.group_id AS group_id,
+            e.name AS name,
+            e.source_description AS source_description,
+            e.source AS source
         ORDER BY e.created_at DESC
         LIMIT $num_episodes
         """,
         reference_time=reference_time,
         num_episodes=last_n,
+        group_ids=group_ids,
     )
     episodes = [
         EpisodicNode(
             content=record['content'],
             created_at=datetime.fromtimestamp(
                 record['created_at'].to_native().timestamp()
             ),
             valid_at=(record['valid_at'].to_native()),
             uuid=record['uuid'],
+            group_id=record['group_id'],
             source=EpisodeType.from_str(record['source']),
             name=record['name'],
             source_description=record['source_description'],
diff --git a/graphiti_core/utils/maintenance/node_operations.py b/graphiti_core/utils/maintenance/node_operations.py
index 30673ee..1aa6c75 100644
--- a/graphiti_core/utils/maintenance/node_operations.py
+++ b/graphiti_core/utils/maintenance/node_operations.py
@@ -85,6 +85,7 @@ async def extract_nodes(
     for node_data in extracted_node_data:
         new_node = EntityNode(
             name=node_data['name'],
+            group_id=episode.group_id,
             labels=node_data['labels'],
             summary=node_data['summary'],
             created_at=datetime.now(),
diff --git a/graphiti_core/utils/utils.py b/graphiti_core/utils/utils.py
deleted file mode 100644
index 9782127..0000000
--- a/graphiti_core/utils/utils.py
+++ /dev/null
@@ -1,60 +0,0 @@
-"""
-Copyright 2024, Zep Software, Inc.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" - -import logging -from collections import defaultdict - -from graphiti_core.edges import EntityEdge, EpisodicEdge -from graphiti_core.nodes import EntityNode, EpisodicNode - -logger = logging.getLogger(__name__) - - -def build_episodic_edges( - entity_nodes: list[EntityNode], episode: EpisodicNode -) -> list[EpisodicEdge]: - edges: list[EpisodicEdge] = [] - - for node in entity_nodes: - edges.append( - EpisodicEdge( - source_node_uuid=episode.uuid, - target_node_uuid=node.uuid, - created_at=episode.created_at, - ) - ) - - return edges - - -def chunk_edges_by_nodes(edges: list[EntityEdge]) -> list[list[EntityEdge]]: - # We only want to dedupe edges that are between the same pair of nodes - # We build a map of the edges based on their source and target nodes. - edge_chunk_map: dict[str, list[EntityEdge]] = defaultdict(list) - for edge in edges: - # We drop loop edges - if edge.source_node_uuid == edge.target_node_uuid: - continue - - # Keep the order of the two nodes consistent, we want to be direction agnostic during edge resolution - pointers = [edge.source_node_uuid, edge.target_node_uuid] - pointers.sort() - - edge_chunk_map[pointers[0] + pointers[1]].append(edge) - - edge_chunks = [chunk for chunk in edge_chunk_map.values()] - - return edge_chunks diff --git a/tests/test_graphiti_int.py b/tests/test_graphiti_int.py index 68c54a0..2c2ebc3 100644 --- a/tests/test_graphiti_int.py +++ b/tests/test_graphiti_int.py @@ -74,15 +74,15 @@ async def test_graphiti_init(): logger = setup_logging() graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD) - edges = await graphiti.search('Freakenomics guest') + edges = await graphiti.search('Freakenomics guest', group_ids=['1']) logger.info('\nQUERY: Freakenomics guest\n' + format_context([edge.fact for edge in edges])) - edges = await graphiti.search('tania tetlow\n') + edges = await graphiti.search('tania tetlow', group_ids=['1']) logger.info('\nQUERY: Tania Tetlow\n' + format_context([edge.fact for edge in edges])) - edges = await graphiti.search('issues with higher ed') + edges = await graphiti.search('issues with higher ed', group_ids=['1']) logger.info('\nQUERY: issues with higher ed\n' + format_context([edge.fact for edge in edges])) graphiti.close() diff --git a/tests/utils/maintenance/test_temporal_operations.py b/tests/utils/maintenance/test_temporal_operations.py index 76224bc..0e86bd8 100644 --- a/tests/utils/maintenance/test_temporal_operations.py +++ b/tests/utils/maintenance/test_temporal_operations.py @@ -33,9 +33,9 @@ def create_test_data(): now = datetime.now() # Create nodes - node1 = EntityNode(uuid='1', name='Node1', labels=['Person'], created_at=now) - node2 = EntityNode(uuid='2', name='Node2', labels=['Person'], created_at=now) - node3 = EntityNode(uuid='3', name='Node3', labels=['Person'], created_at=now) + node1 = EntityNode(uuid='1', name='Node1', labels=['Person'], created_at=now, group_id='1') + node2 = EntityNode(uuid='2', name='Node2', labels=['Person'], created_at=now, group_id='1') + node3 = EntityNode(uuid='3', name='Node3', labels=['Person'], created_at=now, group_id='1') # Create edges existing_edge1 = EntityEdge( @@ 
-45,6 +45,7 @@ def create_test_data(): name='KNOWS', fact='Node1 knows Node2', created_at=now, + group_id='1', ) existing_edge2 = EntityEdge( uuid='e2', @@ -53,6 +54,7 @@ def create_test_data(): name='LIKES', fact='Node2 likes Node3', created_at=now, + group_id='1', ) new_edge1 = EntityEdge( uuid='e3', @@ -61,6 +63,7 @@ def create_test_data(): name='WORKS_WITH', fact='Node1 works with Node3', created_at=now, + group_id='1', ) new_edge2 = EntityEdge( uuid='e4', @@ -69,6 +72,7 @@ def create_test_data(): name='DISLIKES', fact='Node1 dislikes Node2', created_at=now, + group_id='1', ) return { @@ -135,9 +139,9 @@ def test_prepare_invalidation_context(): now = datetime.now() # Create nodes - node1 = EntityNode(uuid='1', name='Node1', labels=['Person'], created_at=now) - node2 = EntityNode(uuid='2', name='Node2', labels=['Person'], created_at=now) - node3 = EntityNode(uuid='3', name='Node3', labels=['Person'], created_at=now) + node1 = EntityNode(uuid='1', name='Node1', labels=['Person'], created_at=now, group_id='1') + node2 = EntityNode(uuid='2', name='Node2', labels=['Person'], created_at=now, group_id='1') + node3 = EntityNode(uuid='3', name='Node3', labels=['Person'], created_at=now, group_id='1') # Create edges edge1 = EntityEdge( @@ -147,6 +151,7 @@ def test_prepare_invalidation_context(): name='KNOWS', fact='Node1 knows Node2', created_at=now, + group_id='1', ) edge2 = EntityEdge( uuid='e2', @@ -155,6 +160,7 @@ def test_prepare_invalidation_context(): name='LIKES', fact='Node2 likes Node3', created_at=now, + group_id='1', ) # Create NodeEdgeNodeTriplet objects @@ -173,6 +179,7 @@ def test_prepare_invalidation_context(): valid_at=now, source=EpisodeType.message, source_description='Test episode for unit testing', + group_id='1', ) previous_episodes = [ EpisodicNode( @@ -182,6 +189,7 @@ def test_prepare_invalidation_context(): valid_at=now - timedelta(days=1), source=EpisodeType.message, source_description='Test previous episode 1 for unit testing', + group_id='1', ), EpisodicNode( name='Previous Episode 2', @@ -190,6 +198,7 @@ def test_prepare_invalidation_context(): valid_at=now - timedelta(days=2), source=EpisodeType.message, source_description='Test previous episode 2 for unit testing', + group_id='1', ), ] @@ -235,6 +244,7 @@ def test_prepare_invalidation_context_empty_input(): valid_at=now, source=EpisodeType.message, source_description='Test empty episode for unit testing', + group_id='1', ) result = prepare_invalidation_context([], [], current_episode, []) assert isinstance(result, dict) @@ -252,8 +262,8 @@ def test_prepare_invalidation_context_sorting(): now = datetime.now() # Create nodes - node1 = EntityNode(uuid='1', name='Node1', labels=['Person'], created_at=now) - node2 = EntityNode(uuid='2', name='Node2', labels=['Person'], created_at=now) + node1 = EntityNode(uuid='1', name='Node1', labels=['Person'], created_at=now, group_id='1') + node2 = EntityNode(uuid='2', name='Node2', labels=['Person'], created_at=now, group_id='1') # Create edges with different timestamps edge1 = EntityEdge( @@ -263,6 +273,7 @@ def test_prepare_invalidation_context_sorting(): name='KNOWS', fact='Node1 knows Node2', created_at=now, + group_id='1', ) edge2 = EntityEdge( uuid='e2', @@ -271,6 +282,7 @@ def test_prepare_invalidation_context_sorting(): name='LIKES', fact='Node2 likes Node1', created_at=now + timedelta(hours=1), + group_id='1', ) edge_with_nodes1 = (node1, edge1, node2) @@ -287,6 +299,7 @@ def test_prepare_invalidation_context_sorting(): valid_at=now, source=EpisodeType.message, 
source_description='Test episode for unit testing', + group_id='1', ) previous_episodes = [ EpisodicNode( @@ -296,6 +309,7 @@ def test_prepare_invalidation_context_sorting(): valid_at=now - timedelta(days=1), source=EpisodeType.message, source_description='Test previous episode for unit testing', + group_id='1', ), ] @@ -321,6 +335,7 @@ class TestExtractDateStringsFromEdge(unittest.TestCase): created_at=datetime.now(), valid_at=valid_at, invalid_at=invalid_at, + group_id='1', ) def test_both_dates_present(self): diff --git a/tests/utils/maintenance/test_temporal_operations_int.py b/tests/utils/maintenance/test_temporal_operations_int.py index 9e6b295..b08689f 100644 --- a/tests/utils/maintenance/test_temporal_operations_int.py +++ b/tests/utils/maintenance/test_temporal_operations_int.py @@ -76,6 +76,7 @@ def create_test_data(): valid_at=now, source=EpisodeType.message, source_description='Test episode for unit testing', + group_id='1', ) # Create previous episodes @@ -87,6 +88,7 @@ def create_test_data(): valid_at=now - timedelta(days=1), source=EpisodeType.message, source_description='Test previous episode for unit testing', + group_id='1', ) ] @@ -142,10 +144,12 @@ def create_complex_test_data(): now = datetime.now() # Create nodes - node1 = EntityNode(uuid='1', name='Alice', labels=['Person'], created_at=now) - node2 = EntityNode(uuid='2', name='Bob', labels=['Person'], created_at=now) - node3 = EntityNode(uuid='3', name='Charlie', labels=['Person'], created_at=now) - node4 = EntityNode(uuid='4', name='Company XYZ', labels=['Organization'], created_at=now) + node1 = EntityNode(uuid='1', name='Alice', labels=['Person'], created_at=now, group_id='1') + node2 = EntityNode(uuid='2', name='Bob', labels=['Person'], created_at=now, group_id='1') + node3 = EntityNode(uuid='3', name='Charlie', labels=['Person'], created_at=now, group_id='1') + node4 = EntityNode( + uuid='4', name='Company XYZ', labels=['Organization'], created_at=now, group_id='1' + ) # Create edges edge1 = EntityEdge( @@ -154,6 +158,7 @@ def create_complex_test_data(): target_node_uuid='2', name='LIKES', fact='Alice likes Bob', + group_id='1', created_at=now - timedelta(days=5), ) edge2 = EntityEdge( @@ -162,6 +167,7 @@ def create_complex_test_data(): target_node_uuid='3', name='FRIENDS_WITH', fact='Alice is friends with Charlie', + group_id='1', created_at=now - timedelta(days=3), ) edge3 = EntityEdge( @@ -170,6 +176,7 @@ def create_complex_test_data(): target_node_uuid='4', name='WORKS_FOR', fact='Bob works for Company XYZ', + group_id='1', created_at=now - timedelta(days=2), ) @@ -199,6 +206,7 @@ async def test_invalidate_edges_complex(): target_node_uuid='2', name='DISLIKES', fact='Alice dislikes Bob', + group_id='1', created_at=datetime.now(), ), nodes[1], @@ -225,6 +233,7 @@ async def test_invalidate_edges_temporal_update(): target_node_uuid='4', name='LEFT_JOB', fact='Bob left his job at Company XYZ', + group_id='1', created_at=datetime.now(), ), nodes[3], @@ -251,6 +260,7 @@ async def test_invalidate_edges_multiple_invalidations(): target_node_uuid='2', name='ENEMIES_WITH', fact='Alice and Bob are now enemies', + group_id='1', created_at=datetime.now(), ), nodes[1], @@ -263,6 +273,7 @@ async def test_invalidate_edges_multiple_invalidations(): target_node_uuid='3', name='ENDED_FRIENDSHIP', fact='Alice ended her friendship with Charlie', + group_id='1', created_at=datetime.now(), ), nodes[2], @@ -292,6 +303,7 @@ async def test_invalidate_edges_no_effect(): target_node_uuid='4', name='APPLIED_TO', fact='Charlie applied to 
Company XYZ', + group_id='1', created_at=datetime.now(), ), nodes[3], @@ -316,6 +328,7 @@ async def test_invalidate_edges_partial_update(): target_node_uuid='4', name='CHANGED_POSITION', fact='Bob changed his position at Company XYZ', + group_id='1', created_at=datetime.now(), ), nodes[3], diff --git a/tests/utils/search/search_utils_test.py b/tests/utils/search/search_utils_test.py index e476097..38837f0 100644 --- a/tests/utils/search/search_utils_test.py +++ b/tests/utils/search/search_utils_test.py @@ -19,12 +19,12 @@ async def test_hybrid_node_search_deduplication(): ) as mock_similarity_search: # Set up mock return values mock_fulltext_search.side_effect = [ - [EntityNode(uuid='1', name='Alice', labels=['Entity'])], - [EntityNode(uuid='2', name='Bob', labels=['Entity'])], + [EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1')], + [EntityNode(uuid='2', name='Bob', labels=['Entity'], group_id='1')], ] mock_similarity_search.side_effect = [ - [EntityNode(uuid='1', name='Alice', labels=['Entity'])], - [EntityNode(uuid='3', name='Charlie', labels=['Entity'])], + [EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1')], + [EntityNode(uuid='3', name='Charlie', labels=['Entity'], group_id='1')], ] # Call the function with test data @@ -70,7 +70,9 @@ async def test_hybrid_node_search_only_fulltext(): ) as mock_fulltext_search, patch( 'graphiti_core.search.search_utils.entity_similarity_search' ) as mock_similarity_search: - mock_fulltext_search.return_value = [EntityNode(uuid='1', name='Alice', labels=['Entity'])] + mock_fulltext_search.return_value = [ + EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1') + ] mock_similarity_search.return_value = [] queries = ['Alice'] @@ -93,18 +95,23 @@ async def test_hybrid_node_search_with_limit(): 'graphiti_core.search.search_utils.entity_similarity_search' ) as mock_similarity_search: mock_fulltext_search.return_value = [ - EntityNode(uuid='1', name='Alice', labels=['Entity']), - EntityNode(uuid='2', name='Bob', labels=['Entity']), + EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1'), + EntityNode(uuid='2', name='Bob', labels=['Entity'], group_id='1'), ] mock_similarity_search.return_value = [ - EntityNode(uuid='3', name='Charlie', labels=['Entity']), - EntityNode(uuid='4', name='David', labels=['Entity']), + EntityNode(uuid='3', name='Charlie', labels=['Entity'], group_id='1'), + EntityNode( + uuid='4', + name='David', + labels=['Entity'], + group_id='1', + ), ] queries = ['Test'] embeddings = [[0.1, 0.2, 0.3]] limit = 1 - results = await hybrid_node_search(queries, embeddings, mock_driver, limit) + results = await hybrid_node_search(queries, embeddings, mock_driver, ['1'], limit) # We expect 4 results because the limit is applied per search method # before deduplication, and we're not actually limiting the results @@ -113,8 +120,8 @@ async def test_hybrid_node_search_with_limit(): assert mock_fulltext_search.call_count == 1 assert mock_similarity_search.call_count == 1 # Verify that the limit was passed to the search functions - mock_fulltext_search.assert_called_with('Test', mock_driver, 2) - mock_similarity_search.assert_called_with([0.1, 0.2, 0.3], mock_driver, 2) + mock_fulltext_search.assert_called_with('Test', mock_driver, ['1'], 2) + mock_similarity_search.assert_called_with([0.1, 0.2, 0.3], mock_driver, ['1'], 2) @pytest.mark.asyncio @@ -127,18 +134,18 @@ async def test_hybrid_node_search_with_limit_and_duplicates(): 'graphiti_core.search.search_utils.entity_similarity_search' ) as 
mock_similarity_search: mock_fulltext_search.return_value = [ - EntityNode(uuid='1', name='Alice', labels=['Entity']), - EntityNode(uuid='2', name='Bob', labels=['Entity']), + EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1'), + EntityNode(uuid='2', name='Bob', labels=['Entity'], group_id='1'), ] mock_similarity_search.return_value = [ - EntityNode(uuid='1', name='Alice', labels=['Entity']), # Duplicate - EntityNode(uuid='3', name='Charlie', labels=['Entity']), + EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1'), # Duplicate + EntityNode(uuid='3', name='Charlie', labels=['Entity'], group_id='1'), ] queries = ['Test'] embeddings = [[0.1, 0.2, 0.3]] limit = 2 - results = await hybrid_node_search(queries, embeddings, mock_driver, limit) + results = await hybrid_node_search(queries, embeddings, mock_driver, ['1'], limit) # We expect 3 results because: # 1. The limit of 2 is applied to each search method @@ -148,5 +155,5 @@ async def test_hybrid_node_search_with_limit_and_duplicates(): assert set(node.name for node in results) == {'Alice', 'Bob', 'Charlie'} assert mock_fulltext_search.call_count == 1 assert mock_similarity_search.call_count == 1 - mock_fulltext_search.assert_called_with('Test', mock_driver, 4) - mock_similarity_search.assert_called_with([0.1, 0.2, 0.3], mock_driver, 4) + mock_fulltext_search.assert_called_with('Test', mock_driver, ['1'], 4) + mock_similarity_search.assert_called_with([0.1, 0.2, 0.3], mock_driver, ['1'], 4)
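
---

Usage notes (illustrative sketches, not part of the patch). First, end-to-end use of the new partitioning: an episode written with a group_id, then a search scoped to that partition. The connection details and episode text are placeholders; the import paths follow the files changed above.

    import asyncio
    from datetime import datetime

    from graphiti_core.graphiti import Graphiti
    from graphiti_core.nodes import EpisodeType


    async def main():
        # Placeholder Neo4j credentials; substitute your own instance.
        graphiti = Graphiti('bolt://localhost:7687', 'neo4j', 'password')

        # The episode node, its extracted entities, and all resulting edges
        # are stamped with group_id='1'.
        await graphiti.add_episode(
            name='tenant-1 message',
            episode_body='Alice: I just started a new job at Company XYZ',
            source_description='chat transcript',
            reference_time=datetime.now(),
            source=EpisodeType.message,
            group_id='1',
        )

        # Search is scoped to the listed partitions, so facts written under
        # other group_ids are filtered out by the WHERE clauses added above.
        edges = await graphiti.search('Where does Alice work?', group_ids=['1'])
        print([edge.fact for edge in edges])

        graphiti.close()


    asyncio.run(main())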
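The relocated chunk_edges_by_nodes helper buckets edges by the sorted pair of endpoint uuids, so grouping is direction-agnostic and self-loops are discarded before edge resolution. A minimal sketch of that behavior, using throwaway uuids and facts:

    from datetime import datetime

    from graphiti_core.edges import EntityEdge
    from graphiti_core.utils.bulk_utils import chunk_edges_by_nodes

    now = datetime.now()

    def make_edge(source: str, target: str, name: str) -> EntityEdge:
        # group_id is required by the updated Edge model.
        return EntityEdge(
            source_node_uuid=source,
            target_node_uuid=target,
            name=name,
            fact=f'{source} {name} {target}',
            group_id='1',
            created_at=now,
        )

    edges = [
        make_edge('a', 'b', 'KNOWS'),  # a -> b
        make_edge('b', 'a', 'LIKES'),  # b -> a: same pair, reversed direction
        make_edge('a', 'a', 'IS'),     # self-loop: dropped
    ]

    chunks = chunk_edges_by_nodes(edges)
    # One chunk: the first two edges share the sorted key 'ab', so they are
    # grouped together regardless of direction; the loop edge is skipped.
    print([[edge.name for edge in chunk] for chunk in chunks])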
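Node retrieval takes the same scoping: get_nodes_by_query threads group_ids through hybrid_node_search into both the fulltext and similarity entity searches. A small sketch, assuming a Graphiti instance like the one in the first example:

    from graphiti_core.graphiti import Graphiti


    async def tenant_nodes(graphiti: Graphiti) -> None:
        # Only nodes whose group_id is in the list are considered by either
        # search method before deduplication.
        nodes = await graphiti.get_nodes_by_query('Company XYZ', group_ids=['1'])
        print([(node.name, node.group_id) for node in nodes])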