mirror of https://github.com/getzep/graphiti.git
synced 2024-09-08 19:13:11 +03:00
Add group ids (#89)
* set and retrieve group ids
* update add episode with group id support
* add episode and search functional
* update bulk
* mypy updates
* remove unused imports
* update unit tests
* unit tests
* add optional uuid field
* format
* mypy
* ellipsis
This commit is contained in:
committed by GitHub
parent c7fc057106
commit 42fb590606
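This commit threads a `group_id` partition key through nodes, edges, episode retrieval, and search. A minimal sketch of the resulting usage, assuming a local Neo4j instance and the `graphiti_core.graphiti` import path (connection details are placeholders; the keyword arguments follow the `add_episode` and `search` signatures in the diff below):

# A minimal sketch, not the repo's example script: import path and
# connection details are assumptions; the kwargs come from the diff below.
import asyncio
from datetime import datetime, timezone

from graphiti_core.graphiti import Graphiti
from graphiti_core.nodes import EpisodeType


async def main():
    graphiti = Graphiti('bolt://localhost:7687', 'neo4j', 'password')

    # Episodes written with a group_id land in that graph partition;
    # omitting it (group_id=None) keeps the old single-partition behavior.
    await graphiti.add_episode(
        name='episode-1',
        episode_body='Alice: hello Bob',
        source_description='chat transcript',
        reference_time=datetime.now(timezone.utc),
        source=EpisodeType.message,
        group_id='1',
    )

    # Searches scoped with group_ids only see data from those partitions.
    edges = await graphiti.search('greetings', group_ids=['1'])
    print([edge.fact for edge in edges])

    graphiti.close()


asyncio.run(main())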
@@ -69,6 +69,7 @@ async def main(use_bulk: bool = True):
            episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
            reference_time=message.actual_timestamp,
            source_description='Podcast Transcript',
            group_id='1',
        )
    return
@@ -18,6 +18,7 @@ import logging
from abc import ABC, abstractmethod
from datetime import datetime
from time import time
from typing import Any
from uuid import uuid4

from neo4j import AsyncDriver
@@ -32,6 +33,7 @@ logger = logging.getLogger(__name__)

class Edge(BaseModel, ABC):
    uuid: str = Field(default_factory=lambda: uuid4().hex)
    group_id: str | None = Field(description='partition of the graph')
    source_node_uuid: str
    target_node_uuid: str
    created_at: datetime
@@ -61,11 +63,12 @@ class EpisodicEdge(Edge):
        MATCH (episode:Episodic {uuid: $episode_uuid})
        MATCH (node:Entity {uuid: $entity_uuid})
        MERGE (episode)-[r:MENTIONS {uuid: $uuid}]->(node)
        SET r = {uuid: $uuid, created_at: $created_at}
        SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
        RETURN r.uuid AS uuid""",
            episode_uuid=self.source_node_uuid,
            entity_uuid=self.target_node_uuid,
            uuid=self.uuid,
            group_id=self.group_id,
            created_at=self.created_at,
        )
@@ -93,6 +96,7 @@ class EpisodicEdge(Edge):
        MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
        RETURN
            e.uuid As uuid,
            e.group_id AS group_id,
            n.uuid AS source_node_uuid,
            m.uuid AS target_node_uuid,
            e.created_at AS created_at
@@ -100,17 +104,7 @@ class EpisodicEdge(Edge):
            uuid=uuid,
        )

        edges: list[EpisodicEdge] = []

        for record in records:
            edges.append(
                EpisodicEdge(
                    uuid=record['uuid'],
                    source_node_uuid=record['source_node_uuid'],
                    target_node_uuid=record['target_node_uuid'],
                    created_at=record['created_at'].to_native(),
                )
            )
        edges = [get_episodic_edge_from_record(record) for record in records]

        logger.info(f'Found Edge: {uuid}')
@@ -153,7 +147,7 @@ class EntityEdge(Edge):
        MATCH (source:Entity {uuid: $source_uuid})
        MATCH (target:Entity {uuid: $target_uuid})
        MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target)
        SET r = {uuid: $uuid, name: $name, fact: $fact, fact_embedding: $fact_embedding,
        SET r = {uuid: $uuid, name: $name, group_id: $group_id, fact: $fact, fact_embedding: $fact_embedding,
        episodes: $episodes, created_at: $created_at, expired_at: $expired_at,
        valid_at: $valid_at, invalid_at: $invalid_at}
        RETURN r.uuid AS uuid""",
@@ -161,6 +155,7 @@ class EntityEdge(Edge):
            target_uuid=self.target_node_uuid,
            uuid=self.uuid,
            name=self.name,
            group_id=self.group_id,
            fact=self.fact,
            fact_embedding=self.fact_embedding,
            episodes=self.episodes,
@@ -198,6 +193,7 @@ class EntityEdge(Edge):
            m.uuid AS target_node_uuid,
            e.created_at AS created_at,
            e.name AS name,
            e.group_id AS group_id,
            e.fact AS fact,
            e.fact_embedding AS fact_embedding,
            e.episodes AS episodes,
@@ -208,25 +204,36 @@ class EntityEdge(Edge):
            uuid=uuid,
        )

        edges: list[EntityEdge] = []

        for record in records:
            edges.append(
                EntityEdge(
                    uuid=record['uuid'],
                    source_node_uuid=record['source_node_uuid'],
                    target_node_uuid=record['target_node_uuid'],
                    fact=record['fact'],
                    name=record['name'],
                    episodes=record['episodes'],
                    fact_embedding=record['fact_embedding'],
                    created_at=record['created_at'].to_native(),
                    expired_at=parse_db_date(record['expired_at']),
                    valid_at=parse_db_date(record['valid_at']),
                    invalid_at=parse_db_date(record['invalid_at']),
                )
            )
        edges = [get_entity_edge_from_record(record) for record in records]

        logger.info(f'Found Edge: {uuid}')

        return edges[0]


# Edge helpers
def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
    return EpisodicEdge(
        uuid=record['uuid'],
        group_id=record['group_id'],
        source_node_uuid=record['source_node_uuid'],
        target_node_uuid=record['target_node_uuid'],
        created_at=record['created_at'].to_native(),
    )


def get_entity_edge_from_record(record: Any) -> EntityEdge:
    return EntityEdge(
        uuid=record['uuid'],
        source_node_uuid=record['source_node_uuid'],
        target_node_uuid=record['target_node_uuid'],
        fact=record['fact'],
        name=record['name'],
        group_id=record['group_id'],
        episodes=record['episodes'],
        fact_embedding=record['fact_embedding'],
        created_at=record['created_at'].to_native(),
        expired_at=parse_db_date(record['expired_at']),
        valid_at=parse_db_date(record['valid_at']),
        invalid_at=parse_db_date(record['invalid_at']),
    )
@@ -18,7 +18,6 @@ import asyncio
import logging
from datetime import datetime
from time import time
from typing import Callable

from dotenv import load_dotenv
from neo4j import AsyncGraphDatabase
@@ -120,7 +119,7 @@ class Graphiti:

        Parameters
        ----------
        None
        self

        Returns
        -------
@@ -151,7 +150,7 @@ class Graphiti:

        Parameters
        ----------
        None
        self

        Returns
        -------
@@ -178,6 +177,7 @@ class Graphiti:
        self,
        reference_time: datetime,
        last_n: int = EPISODE_WINDOW_LEN,
        group_ids: list[str | None] | None = None,
    ) -> list[EpisodicNode]:
        """
        Retrieve the last n episodic nodes from the graph.
@@ -191,6 +191,8 @@ class Graphiti:
            The reference time to retrieve episodes before.
        last_n : int, optional
            The number of episodes to retrieve. Defaults to EPISODE_WINDOW_LEN.
        group_ids : list[str | None], optional
            The group ids to return data from.

        Returns
        -------
@@ -202,7 +204,7 @@ class Graphiti:
        The actual retrieval is performed by the `retrieve_episodes` function
        from the `graphiti_core.utils` module.
        """
        return await retrieve_episodes(self.driver, reference_time, last_n)
        return await retrieve_episodes(self.driver, reference_time, last_n, group_ids)

    async def add_episode(
        self,
@@ -211,8 +213,8 @@ class Graphiti:
        source_description: str,
        reference_time: datetime,
        source: EpisodeType = EpisodeType.message,
        success_callback: Callable | None = None,
        error_callback: Callable | None = None,
        group_id: str | None = None,
        uuid: str | None = None,
    ):
        """
        Process an episode and update the graph.
@@ -232,10 +234,10 @@ class Graphiti:
            The reference time for the episode.
        source : EpisodeType, optional
            The type of the episode. Defaults to EpisodeType.message.
        success_callback : Callable | None, optional
            A callback function to be called upon successful processing.
        error_callback : Callable | None, optional
            A callback function to be called if an error occurs during processing.
        group_id : str | None
            An id for the graph partition the episode is a part of.
        uuid : str | None
            Optional uuid of the episode.

        Returns
        -------
@@ -266,9 +268,12 @@ class Graphiti:
            embedder = self.llm_client.get_embedder()
            now = datetime.now()

            previous_episodes = await self.retrieve_episodes(reference_time, last_n=3)
            previous_episodes = await self.retrieve_episodes(
                reference_time, last_n=3, group_ids=[group_id]
            )
            episode = EpisodicNode(
                name=name,
                group_id=group_id,
                labels=[],
                source=source,
                content=episode_body,
@@ -276,6 +281,7 @@ class Graphiti:
                created_at=now,
                valid_at=reference_time,
            )
            episode.uuid = uuid if uuid is not None else episode.uuid

            # Extract entities as nodes

@@ -299,7 +305,9 @@ class Graphiti:

            (mentioned_nodes, uuid_map), extracted_edges = await asyncio.gather(
                resolve_extracted_nodes(self.llm_client, extracted_nodes, existing_nodes_lists),
                extract_edges(self.llm_client, episode, extracted_nodes, previous_episodes),
                extract_edges(
                    self.llm_client, episode, extracted_nodes, previous_episodes, group_id
                ),
            )
            logger.info(f'Adjusted mentioned nodes: {[(n.name, n.uuid) for n in mentioned_nodes]}')
            nodes.extend(mentioned_nodes)
@@ -388,11 +396,7 @@ class Graphiti:

            logger.info(f'Resolved edges: {[(e.name, e.uuid) for e in resolved_edges]}')

            episodic_edges: list[EpisodicEdge] = build_episodic_edges(
                mentioned_nodes,
                episode,
                now,
            )
            episodic_edges: list[EpisodicEdge] = build_episodic_edges(mentioned_nodes, episode, now)

            logger.info(f'Built episodic edges: {episodic_edges}')

@@ -405,18 +409,10 @@ class Graphiti:
            end = time()
            logger.info(f'Completed add_episode in {(end - start) * 1000} ms')

            if success_callback:
                await success_callback(episode)
        except Exception as e:
            if error_callback:
                await error_callback(episode, e)
            else:
                raise e
            raise e

    async def add_episode_bulk(
        self,
        bulk_episodes: list[RawEpisode],
    ):
    async def add_episode_bulk(self, bulk_episodes: list[RawEpisode], group_id: str | None):
        """
        Process multiple episodes in bulk and update the graph.

@@ -427,6 +423,8 @@ class Graphiti:
        ----------
        bulk_episodes : list[RawEpisode]
            A list of RawEpisode objects to be processed and added to the graph.
        group_id : str | None
            An id for the graph partition the episode is a part of.

        Returns
        -------
@@ -463,6 +461,7 @@ class Graphiti:
                    source=episode.source,
                    content=episode.content,
                    source_description=episode.source_description,
                    group_id=group_id,
                    created_at=now,
                    valid_at=episode.reference_time,
                )
@@ -527,7 +526,13 @@ class Graphiti:
        except Exception as e:
            raise e

    async def search(self, query: str, center_node_uuid: str | None = None, num_results=10):
    async def search(
        self,
        query: str,
        center_node_uuid: str | None = None,
        group_ids: list[str | None] | None = None,
        num_results=10,
    ):
        """
        Perform a hybrid search on the knowledge graph.

@@ -540,6 +545,8 @@ class Graphiti:
            The search query string.
        center_node_uuid: str, optional
            Facts will be reranked based on proximity to this node
        group_ids : list[str | None] | None, optional
            The graph partitions to return data from.
        num_results : int, optional
            The maximum number of results to return. Defaults to 10.

@@ -562,6 +569,7 @@ class Graphiti:
            num_episodes=0,
            num_edges=num_results,
            num_nodes=0,
            group_ids=group_ids,
            search_methods=[SearchMethod.bm25, SearchMethod.cosine_similarity],
            reranker=reranker,
        )
@@ -590,7 +598,10 @@ class Graphiti:
        )

    async def get_nodes_by_query(
        self, query: str, limit: int = RELEVANT_SCHEMA_LIMIT
        self,
        query: str,
        group_ids: list[str | None] | None = None,
        limit: int = RELEVANT_SCHEMA_LIMIT,
    ) -> list[EntityNode]:
        """
        Retrieve nodes from the graph database based on a text query.
@@ -602,6 +613,8 @@ class Graphiti:
        ----------
        query : str
            The text query to search for in the graph.
        group_ids : list[str | None] | None, optional
            The graph partitions to return data from.
        limit : int | None, optional
            The maximum number of results to return per search method.
            If None, a default limit will be applied.
@@ -626,5 +639,7 @@ class Graphiti:
        """
        embedder = self.llm_client.get_embedder()
        query_embedding = await generate_embedding(embedder, query)
        relevant_nodes = await hybrid_node_search([query], [query_embedding], self.driver, limit)
        relevant_nodes = await hybrid_node_search(
            [query], [query_embedding], self.driver, group_ids, limit
        )
        return relevant_nodes
@@ -19,10 +19,10 @@ from abc import ABC, abstractmethod
from datetime import datetime
from enum import Enum
from time import time
from typing import Any
from uuid import uuid4

from neo4j import AsyncDriver
from openai import OpenAI
from pydantic import BaseModel, Field

from graphiti_core.llm_client.config import EMBEDDING_DIM
@@ -69,6 +69,7 @@ class EpisodeType(Enum):
class Node(BaseModel, ABC):
    uuid: str = Field(default_factory=lambda: uuid4().hex)
    name: str = Field(description='name of the node')
    group_id: str | None = Field(description='partition of the graph')
    labels: list[str] = Field(default_factory=list)
    created_at: datetime = Field(default_factory=lambda: datetime.now())
@@ -106,11 +107,12 @@ class EpisodicNode(Node):
        result = await driver.execute_query(
            """
        MERGE (n:Episodic {uuid: $uuid})
        SET n = {uuid: $uuid, name: $name, source_description: $source_description, source: $source, content: $content,
        SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
        entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
        RETURN n.uuid AS uuid""",
            uuid=self.uuid,
            name=self.name,
            group_id=self.group_id,
            source_description=self.source_description,
            content=self.content,
            entity_edges=self.entity_edges,
@@ -141,29 +143,19 @@ class EpisodicNode(Node):
        records, _, _ = await driver.execute_query(
            """
        MATCH (e:Episodic {uuid: $uuid})
        RETURN e.content as content,
            e.created_at as created_at,
            e.valid_at as valid_at,
            e.uuid as uuid,
            e.name as name,
            e.source_description as source_description,
            e.source as source
        RETURN e.content AS content,
            e.created_at AS created_at,
            e.valid_at AS valid_at,
            e.uuid AS uuid,
            e.name AS name,
            e.group_id AS group_id
            e.source_description AS source_description,
            e.source AS source
        """,
            uuid=uuid,
        )

        episodes = [
            EpisodicNode(
                content=record['content'],
                created_at=record['created_at'].to_native().timestamp(),
                valid_at=(record['valid_at'].to_native()),
                uuid=record['uuid'],
                source=EpisodeType.from_str(record['source']),
                name=record['name'],
                source_description=record['source_description'],
            )
            for record in records
        ]
        episodes = [get_episodic_node_from_record(record) for record in records]

        logger.info(f'Found Node: {uuid}')
@@ -174,10 +166,6 @@ class EntityNode(Node):
    name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
    summary: str = Field(description='regional summary of surrounding edges', default_factory=str)

    async def update_summary(self, driver: AsyncDriver): ...

    async def refresh_summary(self, driver: AsyncDriver, llm_client: OpenAI): ...

    async def generate_name_embedding(self, embedder, model='text-embedding-3-small'):
        start = time()
        text = self.name.replace('\n', ' ')
@@ -192,10 +180,11 @@ class EntityNode(Node):
        result = await driver.execute_query(
            """
        MERGE (n:Entity {uuid: $uuid})
        SET n = {uuid: $uuid, name: $name, name_embedding: $name_embedding, summary: $summary, created_at: $created_at}
        SET n = {uuid: $uuid, name: $name, name_embedding: $name_embedding, group_id: $group_id, summary: $summary, created_at: $created_at}
        RETURN n.uuid AS uuid""",
            uuid=self.uuid,
            name=self.name,
            group_id=self.group_id,
            summary=self.summary,
            name_embedding=self.name_embedding,
            created_at=self.created_at,
@@ -227,25 +216,14 @@ class EntityNode(Node):
            n.uuid As uuid,
            n.name AS name,
            n.name_embedding AS name_embedding,
            n.group_id AS group_id
            n.created_at AS created_at,
            n.summary AS summary
        """,
            uuid=uuid,
        )

        nodes: list[EntityNode] = []

        for record in records:
            nodes.append(
                EntityNode(
                    uuid=record['uuid'],
                    name=record['name'],
                    name_embedding=record['name_embedding'],
                    labels=['Entity'],
                    created_at=record['created_at'].to_native(),
                    summary=record['summary'],
                )
            )
        nodes = [get_entity_node_from_record(record) for record in records]

        logger.info(f'Found Node: {uuid}')
@@ -253,3 +231,26 @@ class EntityNode(Node):


# Node helpers
def get_episodic_node_from_record(record: Any) -> EpisodicNode:
    return EpisodicNode(
        content=record['content'],
        created_at=record['created_at'].to_native().timestamp(),
        valid_at=(record['valid_at'].to_native()),
        uuid=record['uuid'],
        group_id=record['group_id'],
        source=EpisodeType.from_str(record['source']),
        name=record['name'],
        source_description=record['source_description'],
    )


def get_entity_node_from_record(record: Any) -> EntityNode:
    return EntityNode(
        uuid=record['uuid'],
        name=record['name'],
        group_id=record['group_id'],
        name_embedding=record['name_embedding'],
        labels=['Entity'],
        created_at=record['created_at'].to_native(),
        summary=record['summary'],
    )
@@ -52,6 +52,7 @@ class SearchConfig(BaseModel):
    num_edges: int = Field(default=10)
    num_nodes: int = Field(default=10)
    num_episodes: int = EPISODE_WINDOW_LEN
    group_ids: list[str | None] | None
    search_methods: list[SearchMethod]
    reranker: Reranker | None

@@ -83,7 +84,9 @@ async def hybrid_search(
        nodes.extend(await get_mentioned_nodes(driver, episodes))

    if SearchMethod.bm25 in config.search_methods:
        text_search = await edge_fulltext_search(driver, query, None, None, 2 * config.num_edges)
        text_search = await edge_fulltext_search(
            driver, query, None, None, config.group_ids, 2 * config.num_edges
        )
        search_results.append(text_search)

    if SearchMethod.cosine_similarity in config.search_methods:
@@ -95,7 +98,7 @@ async def hybrid_search(
        )

        similarity_search = await edge_similarity_search(
            driver, search_vector, None, None, 2 * config.num_edges
            driver, search_vector, None, None, config.group_ids, 2 * config.num_edges
        )
        search_results.append(similarity_search)
@@ -3,13 +3,11 @@ import logging
import re
from collections import defaultdict
from time import time
from typing import Any

from neo4j import AsyncDriver, Query

from graphiti_core.edges import EntityEdge
from graphiti_core.helpers import parse_db_date
from graphiti_core.nodes import EntityNode, EpisodicNode
from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
from graphiti_core.nodes import EntityNode, EpisodicNode, get_entity_node_from_record

logger = logging.getLogger(__name__)

@@ -23,6 +21,7 @@ async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode])
        MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids
        RETURN DISTINCT
            n.uuid As uuid,
            n.group_id AS group_id,
            n.name AS name,
            n.name_embedding AS name_embedding
            n.created_at AS created_at,
@@ -31,86 +30,29 @@ async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode])
        uuids=episode_uuids,
    )

    nodes: list[EntityNode] = []

    for record in records:
        nodes.append(
            EntityNode(
                uuid=record['uuid'],
                name=record['name'],
                name_embedding=record['name_embedding'],
                labels=['Entity'],
                created_at=record['created_at'].to_native(),
                summary=record['summary'],
            )
        )
    nodes = [get_entity_node_from_record(record) for record in records]

    return nodes


async def bfs(node_ids: list[str], driver: AsyncDriver):
    records, _, _ = await driver.execute_query(
        """
        MATCH (n WHERE n.uuid in $node_ids)-[r]->(m)
        RETURN DISTINCT
            n.uuid AS source_node_uuid,
            n.name AS source_name,
            n.summary AS source_summary,
            m.uuid AS target_node_uuid,
            m.name AS target_name,
            m.summary AS target_summary,
            r.uuid AS uuid,
            r.created_at AS created_at,
            r.name AS name,
            r.fact AS fact,
            r.fact_embedding AS fact_embedding,
            r.episodes AS episodes,
            r.expired_at AS expired_at,
            r.valid_at AS valid_at,
            r.invalid_at AS invalid_at

        """,
        node_ids=node_ids,
    )

    context: dict[str, Any] = {}

    for record in records:
        n_uuid = record['source_node_uuid']
        if n_uuid in context:
            context[n_uuid]['facts'].append(record['fact'])
        else:
            context[n_uuid] = {
                'name': record['source_name'],
                'summary': record['source_summary'],
                'facts': [record['fact']],
            }

        m_uuid = record['target_node_uuid']
        if m_uuid not in context:
            context[m_uuid] = {
                'name': record['target_name'],
                'summary': record['target_summary'],
                'facts': [],
            }
    logger.info(f'bfs search returned context: {context}')
    return context


async def edge_similarity_search(
    driver: AsyncDriver,
    search_vector: list[float],
    source_node_uuid: str | None,
    target_node_uuid: str | None,
    group_ids: list[str | None] | None = None,
    limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EntityEdge]:
    group_ids = group_ids if group_ids is not None else [None]
    # vector similarity search over embedded facts
    query = Query("""
        CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
        YIELD relationship AS rel, score
        MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
        WHERE r.group_id IN $group_ids
        RETURN
            r.uuid AS uuid,
            r.group_id AS group_id,
            n.uuid AS source_node_uuid,
            m.uuid AS target_node_uuid,
            r.created_at AS created_at,
@@ -129,8 +71,10 @@ async def edge_similarity_search(
        CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
        YIELD relationship AS rel, score
        MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
        WHERE r.group_id IN $group_ids
        RETURN
            r.uuid AS uuid,
            r.group_id AS group_id,
            n.uuid AS source_node_uuid,
            m.uuid AS target_node_uuid,
            r.created_at AS created_at,
@@ -148,8 +92,10 @@ async def edge_similarity_search(
        CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
        YIELD relationship AS rel, score
        MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
        WHERE r.group_id IN $group_ids
        RETURN
            r.uuid AS uuid,
            r.group_id AS group_id,
            n.uuid AS source_node_uuid,
            m.uuid AS target_node_uuid,
            r.created_at AS created_at,
@@ -167,8 +113,10 @@ async def edge_similarity_search(
        CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
        YIELD relationship AS rel, score
        MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity)
        WHERE r.group_id IN $group_ids
        RETURN
            r.uuid AS uuid,
            r.group_id AS group_id,
            n.uuid AS source_node_uuid,
            m.uuid AS target_node_uuid,
            r.created_at AS created_at,
@@ -187,41 +135,32 @@ async def edge_similarity_search(
        search_vector=search_vector,
        source_uuid=source_node_uuid,
        target_uuid=target_node_uuid,
        group_ids=group_ids,
        limit=limit,
    )

    edges: list[EntityEdge] = []

    for record in records:
        edge = EntityEdge(
            uuid=record['uuid'],
            source_node_uuid=record['source_node_uuid'],
            target_node_uuid=record['target_node_uuid'],
            fact=record['fact'],
            name=record['name'],
            episodes=record['episodes'],
            fact_embedding=record['fact_embedding'],
            created_at=record['created_at'].to_native(),
            expired_at=parse_db_date(record['expired_at']),
            valid_at=parse_db_date(record['valid_at']),
            invalid_at=parse_db_date(record['invalid_at']),
        )

        edges.append(edge)
    edges = [get_entity_edge_from_record(record) for record in records]

    return edges


async def entity_similarity_search(
    search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
    search_vector: list[float],
    driver: AsyncDriver,
    group_ids: list[str | None] | None = None,
    limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]:
    group_ids = group_ids if group_ids is not None else [None]

    # vector similarity search over entity names
    records, _, _ = await driver.execute_query(
        """
        CALL db.index.vector.queryNodes("name_embedding", $limit, $search_vector)
        YIELD node AS n, score
        MATCH (n WHERE n.group_id IN $group_ids)
        RETURN
            n.uuid As uuid,
            n.group_id AS group_id,
            n.name AS name,
            n.name_embedding AS name_embedding,
            n.created_at AS created_at,
@@ -229,58 +168,44 @@ async def entity_similarity_search(
        ORDER BY score DESC
        """,
        search_vector=search_vector,
        group_ids=group_ids,
        limit=limit,
    )
    nodes: list[EntityNode] = []

    for record in records:
        nodes.append(
            EntityNode(
                uuid=record['uuid'],
                name=record['name'],
                name_embedding=record['name_embedding'],
                labels=['Entity'],
                created_at=record['created_at'].to_native(),
                summary=record['summary'],
            )
        )
    nodes = [get_entity_node_from_record(record) for record in records]

    return nodes


async def entity_fulltext_search(
    query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
    query: str,
    driver: AsyncDriver,
    group_ids: list[str | None] | None = None,
    limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]:
    group_ids = group_ids if group_ids is not None else [None]

    # BM25 search to get top nodes
    fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
    records, _, _ = await driver.execute_query(
        """
        CALL db.index.fulltext.queryNodes("name_and_summary", $query) YIELD node, score
        CALL db.index.fulltext.queryNodes("name_and_summary", $query)
        YIELD node AS n, score
        MATCH (n WHERE n.group_id in $group_ids)
        RETURN
            node.uuid AS uuid,
            node.name AS name,
            node.name_embedding AS name_embedding,
            node.created_at AS created_at,
            node.summary AS summary
            n.uuid AS uuid,
            n.group_id AS group_id,
            n.name AS name,
            n.name_embedding AS name_embedding,
            n.created_at AS created_at,
            n.summary AS summary
        ORDER BY score DESC
        LIMIT $limit
        """,
        query=fuzzy_query,
        group_ids=group_ids,
        limit=limit,
    )
    nodes: list[EntityNode] = []

    for record in records:
        nodes.append(
            EntityNode(
                uuid=record['uuid'],
                name=record['name'],
                name_embedding=record['name_embedding'],
                labels=['Entity'],
                created_at=record['created_at'].to_native(),
                summary=record['summary'],
            )
        )
    nodes = [get_entity_node_from_record(record) for record in records]

    return nodes

@@ -290,15 +215,20 @@ async def edge_fulltext_search(
    query: str,
    source_node_uuid: str | None,
    target_node_uuid: str | None,
    group_ids: list[str | None] | None = None,
    limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EntityEdge]:
    group_ids = group_ids if group_ids is not None else [None]

    # fulltext search over facts
    cypher_query = Query("""
        CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
        YIELD relationship AS rel, score
        MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
        WHERE r.group_id IN $group_ids
        RETURN
            r.uuid AS uuid,
            r.group_id AS group_id,
            n.uuid AS source_node_uuid,
            m.uuid AS target_node_uuid,
            r.created_at AS created_at,
@@ -317,8 +247,10 @@ async def edge_fulltext_search(
        CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
        YIELD relationship AS rel, score
        MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
        WHERE r.group_id IN $group_ids
        RETURN
            r.uuid AS uuid,
            r.group_id AS group_id,
            n.uuid AS source_node_uuid,
            m.uuid AS target_node_uuid,
            r.created_at AS created_at,
@@ -336,8 +268,10 @@ async def edge_fulltext_search(
        CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
        YIELD relationship AS rel, score
        MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
        WHERE r.group_id IN $group_ids
        RETURN
            r.uuid AS uuid,
            r.group_id AS group_id,
            n.uuid AS source_node_uuid,
            m.uuid AS target_node_uuid,
            r.created_at AS created_at,
@@ -355,8 +289,10 @@ async def edge_fulltext_search(
        CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
        YIELD relationship AS rel, score
        MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity)
        WHERE r.group_id IN $group_ids
        RETURN
            r.uuid AS uuid,
            r.group_id AS group_id,
            n.uuid AS source_node_uuid,
            m.uuid AS target_node_uuid,
            r.created_at AS created_at,
@@ -377,27 +313,11 @@ async def edge_fulltext_search(
        query=fuzzy_query,
        source_uuid=source_node_uuid,
        target_uuid=target_node_uuid,
        group_ids=group_ids,
        limit=limit,
    )

    edges: list[EntityEdge] = []

    for record in records:
        edge = EntityEdge(
            uuid=record['uuid'],
            source_node_uuid=record['source_node_uuid'],
            target_node_uuid=record['target_node_uuid'],
            fact=record['fact'],
            name=record['name'],
            episodes=record['episodes'],
            fact_embedding=record['fact_embedding'],
            created_at=record['created_at'].to_native(),
            expired_at=parse_db_date(record['expired_at']),
            valid_at=parse_db_date(record['valid_at']),
            invalid_at=parse_db_date(record['invalid_at']),
        )

        edges.append(edge)
    edges = [get_entity_edge_from_record(record) for record in records]

    return edges

@@ -406,6 +326,7 @@ async def hybrid_node_search(
    queries: list[str],
    embeddings: list[list[float]],
    driver: AsyncDriver,
    group_ids: list[str | None] | None = None,
    limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]:
    """
@@ -422,6 +343,8 @@ async def hybrid_node_search(
        A list of embedding vectors corresponding to the queries. If empty only fulltext search is performed.
    driver : AsyncDriver
        The Neo4j driver instance for database operations.
    group_ids : list[str] | None, optional
        The list of group ids to retrieve nodes from.
    limit : int | None, optional
        The maximum number of results to return per search method. If None, a default limit will be applied.

@@ -448,8 +371,8 @@ async def hybrid_node_search(

    results: list[list[EntityNode]] = list(
        await asyncio.gather(
            *[entity_fulltext_search(q, driver, 2 * limit) for q in queries],
            *[entity_similarity_search(e, driver, 2 * limit) for e in embeddings],
            *[entity_fulltext_search(q, driver, group_ids, 2 * limit) for q in queries],
            *[entity_similarity_search(e, driver, group_ids, 2 * limit) for e in embeddings],
        )
    )

@@ -500,6 +423,7 @@ async def get_relevant_nodes(
        [node.name for node in nodes],
        [node.name_embedding for node in nodes if node.name_embedding is not None],
        driver,
        [node.group_id for node in nodes],
    )
    return relevant_nodes

@@ -518,13 +442,20 @@ async def get_relevant_edges(
    results = await asyncio.gather(
        *[
            edge_similarity_search(
                driver, edge.fact_embedding, source_node_uuid, target_node_uuid, limit
                driver,
                edge.fact_embedding,
                source_node_uuid,
                target_node_uuid,
                [edge.group_id],
                limit,
            )
            for edge in edges
            if edge.fact_embedding is not None
        ],
        *[
            edge_fulltext_search(driver, edge.fact, source_node_uuid, target_node_uuid, limit)
            edge_fulltext_search(
                driver, edge.fact, source_node_uuid, target_node_uuid, [edge.group_id], limit
            )
            for edge in edges
        ],
    )
@@ -17,6 +17,7 @@ limitations under the License.
import asyncio
import logging
import typing
from collections import defaultdict
from datetime import datetime
from math import ceil

@@ -42,7 +43,6 @@ from graphiti_core.utils.maintenance.node_operations import (
    extract_nodes,
)
from graphiti_core.utils.maintenance.temporal_operations import extract_edge_dates
from graphiti_core.utils.utils import chunk_edges_by_nodes

logger = logging.getLogger(__name__)

@@ -62,7 +62,9 @@ async def retrieve_previous_episodes_bulk(
) -> list[tuple[EpisodicNode, list[EpisodicNode]]]:
    previous_episodes_list = await asyncio.gather(
        *[
            retrieve_episodes(driver, episode.valid_at, last_n=EPISODE_WINDOW_LEN)
            retrieve_episodes(
                driver, episode.valid_at, last_n=EPISODE_WINDOW_LEN, group_ids=[episode.group_id]
            )
            for episode in episodes
        ]
    )
@@ -90,7 +92,13 @@ async def extract_nodes_and_edges_bulk(

    extracted_edges_bulk = await asyncio.gather(
        *[
            extract_edges(llm_client, episode, extracted_nodes_bulk[i], previous_episodes_list[i])
            extract_edges(
                llm_client,
                episode,
                extracted_nodes_bulk[i],
                previous_episodes_list[i],
                episode.group_id,
            )
            for i, episode in enumerate(episodes)
        ]
    )
@@ -343,3 +351,23 @@ async def extract_edge_dates_bulk(
            edge.expired_at = datetime.now()

    return edges


def chunk_edges_by_nodes(edges: list[EntityEdge]) -> list[list[EntityEdge]]:
    # We only want to dedupe edges that are between the same pair of nodes
    # We build a map of the edges based on their source and target nodes.
    edge_chunk_map: dict[str, list[EntityEdge]] = defaultdict(list)
    for edge in edges:
        # We drop loop edges
        if edge.source_node_uuid == edge.target_node_uuid:
            continue

        # Keep the order of the two nodes consistent, we want to be direction agnostic during edge resolution
        pointers = [edge.source_node_uuid, edge.target_node_uuid]
        pointers.sort()

        edge_chunk_map[pointers[0] + pointers[1]].append(edge)

    edge_chunks = [chunk for chunk in edge_chunk_map.values()]

    return edge_chunks
@@ -37,15 +37,15 @@ def build_episodic_edges(
    episode: EpisodicNode,
    created_at: datetime,
) -> List[EpisodicEdge]:
    edges: List[EpisodicEdge] = []

    for node in entity_nodes:
        edge = EpisodicEdge(
    edges: List[EpisodicEdge] = [
        EpisodicEdge(
            source_node_uuid=episode.uuid,
            target_node_uuid=node.uuid,
            created_at=created_at,
            group_id=episode.group_id,
        )
        edges.append(edge)
        for node in entity_nodes
    ]

    return edges

@@ -55,6 +55,7 @@ async def extract_edges(
    episode: EpisodicNode,
    nodes: list[EntityNode],
    previous_episodes: list[EpisodicNode],
    group_id: str | None,
) -> list[EntityEdge]:
    start = time()

@@ -88,6 +89,7 @@ async def extract_edges(
            source_node_uuid=edge_data['source_node_uuid'],
            target_node_uuid=edge_data['target_node_uuid'],
            name=edge_data['relation_type'],
            group_id=group_id,
            fact=edge_data['fact'],
            episodes=[episode.uuid],
            created_at=datetime.now(),
@@ -34,6 +34,10 @@ async def build_indices_and_constraints(driver: AsyncDriver):
        'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
        'CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)',
        'CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)',
        'CREATE INDEX entity_group_id IF NOT EXISTS FOR (n:Entity) ON (n.group_id)',
        'CREATE INDEX episode_group_id IF NOT EXISTS FOR (n:Episodic) ON (n.group_id)',
        'CREATE INDEX relation_group_id IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.group_id)',
        'CREATE INDEX mention_group_id IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.group_id)',
        'CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)',
        'CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)',
        'CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)',
@@ -86,6 +90,7 @@ async def retrieve_episodes(
    driver: AsyncDriver,
    reference_time: datetime,
    last_n: int = EPISODE_WINDOW_LEN,
    group_ids: list[str | None] | None = None,
) -> list[EpisodicNode]:
    """
    Retrieve the last n episodic nodes from the graph.
@@ -96,25 +101,28 @@ async def retrieve_episodes(
            less than or equal to this reference_time will be retrieved. This allows for
            querying the graph's state at a specific point in time.
        last_n (int, optional): The number of most recent episodes to retrieve, relative to the reference_time.
        group_ids (list[str], optional): The list of group ids to return data from.

    Returns:
        list[EpisodicNode]: A list of EpisodicNode objects representing the retrieved episodes.
    """
    result = await driver.execute_query(
        """
        MATCH (e:Episodic) WHERE e.valid_at <= $reference_time
        RETURN e.content as content,
            e.created_at as created_at,
            e.valid_at as valid_at,
            e.uuid as uuid,
            e.name as name,
            e.source_description as source_description,
            e.source as source
        MATCH (e:Episodic) WHERE e.valid_at <= $reference_time AND e.group_id in $group_ids
        RETURN e.content AS content,
            e.created_at AS created_at,
            e.valid_at AS valid_at,
            e.uuid AS uuid,
            e.group_id AS group_id,
            e.name AS name,
            e.source_description AS source_description,
            e.source AS source
        ORDER BY e.created_at DESC
        LIMIT $num_episodes
        """,
        reference_time=reference_time,
        num_episodes=last_n,
        group_ids=group_ids,
    )
    episodes = [
        EpisodicNode(
@@ -124,6 +132,7 @@ async def retrieve_episodes(
            ),
            valid_at=(record['valid_at'].to_native()),
            uuid=record['uuid'],
            group_id=record['group_id'],
            source=EpisodeType.from_str(record['source']),
            name=record['name'],
            source_description=record['source_description'],
@@ -85,6 +85,7 @@ async def extract_nodes(
    for node_data in extracted_node_data:
        new_node = EntityNode(
            name=node_data['name'],
            group_id=episode.group_id,
            labels=node_data['labels'],
            summary=node_data['summary'],
            created_at=datetime.now(),
@@ -1,60 +0,0 @@
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import logging
from collections import defaultdict

from graphiti_core.edges import EntityEdge, EpisodicEdge
from graphiti_core.nodes import EntityNode, EpisodicNode

logger = logging.getLogger(__name__)


def build_episodic_edges(
    entity_nodes: list[EntityNode], episode: EpisodicNode
) -> list[EpisodicEdge]:
    edges: list[EpisodicEdge] = []

    for node in entity_nodes:
        edges.append(
            EpisodicEdge(
                source_node_uuid=episode.uuid,
                target_node_uuid=node.uuid,
                created_at=episode.created_at,
            )
        )

    return edges


def chunk_edges_by_nodes(edges: list[EntityEdge]) -> list[list[EntityEdge]]:
    # We only want to dedupe edges that are between the same pair of nodes
    # We build a map of the edges based on their source and target nodes.
    edge_chunk_map: dict[str, list[EntityEdge]] = defaultdict(list)
    for edge in edges:
        # We drop loop edges
        if edge.source_node_uuid == edge.target_node_uuid:
            continue

        # Keep the order of the two nodes consistent, we want to be direction agnostic during edge resolution
        pointers = [edge.source_node_uuid, edge.target_node_uuid]
        pointers.sort()

        edge_chunk_map[pointers[0] + pointers[1]].append(edge)

    edge_chunks = [chunk for chunk in edge_chunk_map.values()]

    return edge_chunks
@@ -74,15 +74,15 @@ async def test_graphiti_init():
    logger = setup_logging()
    graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)

    edges = await graphiti.search('Freakenomics guest')
    edges = await graphiti.search('Freakenomics guest', group_ids=['1'])

    logger.info('\nQUERY: Freakenomics guest\n' + format_context([edge.fact for edge in edges]))

    edges = await graphiti.search('tania tetlow\n')
    edges = await graphiti.search('tania tetlow', group_ids=['1'])

    logger.info('\nQUERY: Tania Tetlow\n' + format_context([edge.fact for edge in edges]))

    edges = await graphiti.search('issues with higher ed')
    edges = await graphiti.search('issues with higher ed', group_ids=['1'])

    logger.info('\nQUERY: issues with higher ed\n' + format_context([edge.fact for edge in edges]))
    graphiti.close()
@@ -33,9 +33,9 @@ def create_test_data():
    now = datetime.now()

    # Create nodes
    node1 = EntityNode(uuid='1', name='Node1', labels=['Person'], created_at=now)
    node2 = EntityNode(uuid='2', name='Node2', labels=['Person'], created_at=now)
    node3 = EntityNode(uuid='3', name='Node3', labels=['Person'], created_at=now)
    node1 = EntityNode(uuid='1', name='Node1', labels=['Person'], created_at=now, group_id='1')
    node2 = EntityNode(uuid='2', name='Node2', labels=['Person'], created_at=now, group_id='1')
    node3 = EntityNode(uuid='3', name='Node3', labels=['Person'], created_at=now, group_id='1')

    # Create edges
    existing_edge1 = EntityEdge(
@@ -45,6 +45,7 @@ def create_test_data():
        name='KNOWS',
        fact='Node1 knows Node2',
        created_at=now,
        group_id='1',
    )
    existing_edge2 = EntityEdge(
        uuid='e2',
@@ -53,6 +54,7 @@ def create_test_data():
        name='LIKES',
        fact='Node2 likes Node3',
        created_at=now,
        group_id='1',
    )
    new_edge1 = EntityEdge(
        uuid='e3',
@@ -61,6 +63,7 @@ def create_test_data():
        name='WORKS_WITH',
        fact='Node1 works with Node3',
        created_at=now,
        group_id='1',
    )
    new_edge2 = EntityEdge(
        uuid='e4',
@@ -69,6 +72,7 @@ def create_test_data():
        name='DISLIKES',
        fact='Node1 dislikes Node2',
        created_at=now,
        group_id='1',
    )

    return {
@@ -135,9 +139,9 @@ def test_prepare_invalidation_context():
    now = datetime.now()

    # Create nodes
    node1 = EntityNode(uuid='1', name='Node1', labels=['Person'], created_at=now)
    node2 = EntityNode(uuid='2', name='Node2', labels=['Person'], created_at=now)
    node3 = EntityNode(uuid='3', name='Node3', labels=['Person'], created_at=now)
    node1 = EntityNode(uuid='1', name='Node1', labels=['Person'], created_at=now, group_id='1')
    node2 = EntityNode(uuid='2', name='Node2', labels=['Person'], created_at=now, group_id='1')
    node3 = EntityNode(uuid='3', name='Node3', labels=['Person'], created_at=now, group_id='1')

    # Create edges
    edge1 = EntityEdge(
@@ -147,6 +151,7 @@ def test_prepare_invalidation_context():
        name='KNOWS',
        fact='Node1 knows Node2',
        created_at=now,
        group_id='1',
    )
    edge2 = EntityEdge(
        uuid='e2',
@@ -155,6 +160,7 @@ def test_prepare_invalidation_context():
        name='LIKES',
        fact='Node2 likes Node3',
        created_at=now,
        group_id='1',
    )

    # Create NodeEdgeNodeTriplet objects
@@ -173,6 +179,7 @@ def test_prepare_invalidation_context():
        valid_at=now,
        source=EpisodeType.message,
        source_description='Test episode for unit testing',
        group_id='1',
    )
    previous_episodes = [
        EpisodicNode(
@@ -182,6 +189,7 @@ def test_prepare_invalidation_context():
            valid_at=now - timedelta(days=1),
            source=EpisodeType.message,
            source_description='Test previous episode 1 for unit testing',
            group_id='1',
        ),
        EpisodicNode(
            name='Previous Episode 2',
@@ -190,6 +198,7 @@ def test_prepare_invalidation_context():
            valid_at=now - timedelta(days=2),
            source=EpisodeType.message,
            source_description='Test previous episode 2 for unit testing',
            group_id='1',
        ),
    ]

@@ -235,6 +244,7 @@ def test_prepare_invalidation_context_empty_input():
        valid_at=now,
        source=EpisodeType.message,
        source_description='Test empty episode for unit testing',
        group_id='1',
    )
    result = prepare_invalidation_context([], [], current_episode, [])
    assert isinstance(result, dict)
@@ -252,8 +262,8 @@ def test_prepare_invalidation_context_sorting():
    now = datetime.now()

    # Create nodes
    node1 = EntityNode(uuid='1', name='Node1', labels=['Person'], created_at=now)
    node2 = EntityNode(uuid='2', name='Node2', labels=['Person'], created_at=now)
    node1 = EntityNode(uuid='1', name='Node1', labels=['Person'], created_at=now, group_id='1')
    node2 = EntityNode(uuid='2', name='Node2', labels=['Person'], created_at=now, group_id='1')

    # Create edges with different timestamps
    edge1 = EntityEdge(
@@ -263,6 +273,7 @@ def test_prepare_invalidation_context_sorting():
        name='KNOWS',
        fact='Node1 knows Node2',
        created_at=now,
        group_id='1',
    )
    edge2 = EntityEdge(
        uuid='e2',
@@ -271,6 +282,7 @@ def test_prepare_invalidation_context_sorting():
        name='LIKES',
        fact='Node2 likes Node1',
        created_at=now + timedelta(hours=1),
        group_id='1',
    )

    edge_with_nodes1 = (node1, edge1, node2)
@@ -287,6 +299,7 @@ def test_prepare_invalidation_context_sorting():
        valid_at=now,
        source=EpisodeType.message,
        source_description='Test episode for unit testing',
        group_id='1',
    )
    previous_episodes = [
        EpisodicNode(
@@ -296,6 +309,7 @@ def test_prepare_invalidation_context_sorting():
            valid_at=now - timedelta(days=1),
            source=EpisodeType.message,
            source_description='Test previous episode for unit testing',
            group_id='1',
        ),
    ]

@@ -321,6 +335,7 @@ class TestExtractDateStringsFromEdge(unittest.TestCase):
            created_at=datetime.now(),
            valid_at=valid_at,
            invalid_at=invalid_at,
            group_id='1',
        )

    def test_both_dates_present(self):

@@ -76,6 +76,7 @@ def create_test_data():
        valid_at=now,
        source=EpisodeType.message,
        source_description='Test episode for unit testing',
        group_id='1',
    )

    # Create previous episodes
@@ -87,6 +88,7 @@ def create_test_data():
            valid_at=now - timedelta(days=1),
            source=EpisodeType.message,
            source_description='Test previous episode for unit testing',
            group_id='1',
        )
    ]

@@ -142,10 +144,12 @@ def create_complex_test_data():
    now = datetime.now()

    # Create nodes
    node1 = EntityNode(uuid='1', name='Alice', labels=['Person'], created_at=now)
    node2 = EntityNode(uuid='2', name='Bob', labels=['Person'], created_at=now)
    node3 = EntityNode(uuid='3', name='Charlie', labels=['Person'], created_at=now)
    node4 = EntityNode(uuid='4', name='Company XYZ', labels=['Organization'], created_at=now)
    node1 = EntityNode(uuid='1', name='Alice', labels=['Person'], created_at=now, group_id='1')
    node2 = EntityNode(uuid='2', name='Bob', labels=['Person'], created_at=now, group_id='1')
    node3 = EntityNode(uuid='3', name='Charlie', labels=['Person'], created_at=now, group_id='1')
    node4 = EntityNode(
        uuid='4', name='Company XYZ', labels=['Organization'], created_at=now, group_id='1'
    )

    # Create edges
    edge1 = EntityEdge(
@@ -154,6 +158,7 @@ def create_complex_test_data():
        target_node_uuid='2',
        name='LIKES',
        fact='Alice likes Bob',
        group_id='1',
        created_at=now - timedelta(days=5),
    )
    edge2 = EntityEdge(
@@ -162,6 +167,7 @@ def create_complex_test_data():
        target_node_uuid='3',
        name='FRIENDS_WITH',
        fact='Alice is friends with Charlie',
        group_id='1',
        created_at=now - timedelta(days=3),
    )
    edge3 = EntityEdge(
@@ -170,6 +176,7 @@ def create_complex_test_data():
        target_node_uuid='4',
        name='WORKS_FOR',
        fact='Bob works for Company XYZ',
        group_id='1',
        created_at=now - timedelta(days=2),
    )

@@ -199,6 +206,7 @@ async def test_invalidate_edges_complex():
            target_node_uuid='2',
            name='DISLIKES',
            fact='Alice dislikes Bob',
            group_id='1',
            created_at=datetime.now(),
        ),
        nodes[1],
@@ -225,6 +233,7 @@ async def test_invalidate_edges_temporal_update():
            target_node_uuid='4',
            name='LEFT_JOB',
            fact='Bob left his job at Company XYZ',
            group_id='1',
            created_at=datetime.now(),
        ),
        nodes[3],
@@ -251,6 +260,7 @@ async def test_invalidate_edges_multiple_invalidations():
            target_node_uuid='2',
            name='ENEMIES_WITH',
            fact='Alice and Bob are now enemies',
            group_id='1',
            created_at=datetime.now(),
        ),
        nodes[1],
@@ -263,6 +273,7 @@ async def test_invalidate_edges_multiple_invalidations():
            target_node_uuid='3',
            name='ENDED_FRIENDSHIP',
            fact='Alice ended her friendship with Charlie',
            group_id='1',
            created_at=datetime.now(),
        ),
        nodes[2],
@@ -292,6 +303,7 @@ async def test_invalidate_edges_no_effect():
            target_node_uuid='4',
            name='APPLIED_TO',
            fact='Charlie applied to Company XYZ',
            group_id='1',
            created_at=datetime.now(),
        ),
        nodes[3],
@@ -316,6 +328,7 @@ async def test_invalidate_edges_partial_update():
            target_node_uuid='4',
            name='CHANGED_POSITION',
            fact='Bob changed his position at Company XYZ',
            group_id='1',
            created_at=datetime.now(),
        ),
        nodes[3],
@@ -19,12 +19,12 @@ async def test_hybrid_node_search_deduplication():
    ) as mock_similarity_search:
        # Set up mock return values
        mock_fulltext_search.side_effect = [
            [EntityNode(uuid='1', name='Alice', labels=['Entity'])],
            [EntityNode(uuid='2', name='Bob', labels=['Entity'])],
            [EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1')],
            [EntityNode(uuid='2', name='Bob', labels=['Entity'], group_id='1')],
        ]
        mock_similarity_search.side_effect = [
            [EntityNode(uuid='1', name='Alice', labels=['Entity'])],
            [EntityNode(uuid='3', name='Charlie', labels=['Entity'])],
            [EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1')],
            [EntityNode(uuid='3', name='Charlie', labels=['Entity'], group_id='1')],
        ]

        # Call the function with test data
@@ -70,7 +70,9 @@ async def test_hybrid_node_search_only_fulltext():
    ) as mock_fulltext_search, patch(
        'graphiti_core.search.search_utils.entity_similarity_search'
    ) as mock_similarity_search:
        mock_fulltext_search.return_value = [EntityNode(uuid='1', name='Alice', labels=['Entity'])]
        mock_fulltext_search.return_value = [
            EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1')
        ]
        mock_similarity_search.return_value = []

        queries = ['Alice']
@@ -93,18 +95,23 @@ async def test_hybrid_node_search_with_limit():
        'graphiti_core.search.search_utils.entity_similarity_search'
    ) as mock_similarity_search:
        mock_fulltext_search.return_value = [
            EntityNode(uuid='1', name='Alice', labels=['Entity']),
            EntityNode(uuid='2', name='Bob', labels=['Entity']),
            EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1'),
            EntityNode(uuid='2', name='Bob', labels=['Entity'], group_id='1'),
        ]
        mock_similarity_search.return_value = [
            EntityNode(uuid='3', name='Charlie', labels=['Entity']),
            EntityNode(uuid='4', name='David', labels=['Entity']),
            EntityNode(uuid='3', name='Charlie', labels=['Entity'], group_id='1'),
            EntityNode(
                uuid='4',
                name='David',
                labels=['Entity'],
                group_id='1',
            ),
        ]

        queries = ['Test']
        embeddings = [[0.1, 0.2, 0.3]]
        limit = 1
        results = await hybrid_node_search(queries, embeddings, mock_driver, limit)
        results = await hybrid_node_search(queries, embeddings, mock_driver, ['1'], limit)

        # We expect 4 results because the limit is applied per search method
        # before deduplication, and we're not actually limiting the results
@@ -113,8 +120,8 @@ async def test_hybrid_node_search_with_limit():
        assert mock_fulltext_search.call_count == 1
        assert mock_similarity_search.call_count == 1
        # Verify that the limit was passed to the search functions
        mock_fulltext_search.assert_called_with('Test', mock_driver, 2)
        mock_similarity_search.assert_called_with([0.1, 0.2, 0.3], mock_driver, 2)
        mock_fulltext_search.assert_called_with('Test', mock_driver, ['1'], 2)
        mock_similarity_search.assert_called_with([0.1, 0.2, 0.3], mock_driver, ['1'], 2)


@pytest.mark.asyncio
@@ -127,18 +134,18 @@ async def test_hybrid_node_search_with_limit_and_duplicates():
        'graphiti_core.search.search_utils.entity_similarity_search'
    ) as mock_similarity_search:
        mock_fulltext_search.return_value = [
            EntityNode(uuid='1', name='Alice', labels=['Entity']),
            EntityNode(uuid='2', name='Bob', labels=['Entity']),
            EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1'),
            EntityNode(uuid='2', name='Bob', labels=['Entity'], group_id='1'),
        ]
        mock_similarity_search.return_value = [
            EntityNode(uuid='1', name='Alice', labels=['Entity']),  # Duplicate
            EntityNode(uuid='3', name='Charlie', labels=['Entity']),
            EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1'),  # Duplicate
            EntityNode(uuid='3', name='Charlie', labels=['Entity'], group_id='1'),
        ]

        queries = ['Test']
        embeddings = [[0.1, 0.2, 0.3]]
        limit = 2
        results = await hybrid_node_search(queries, embeddings, mock_driver, limit)
        results = await hybrid_node_search(queries, embeddings, mock_driver, ['1'], limit)

        # We expect 3 results because:
        # 1. The limit of 2 is applied to each search method
@@ -148,5 +155,5 @@ async def test_hybrid_node_search_with_limit_and_duplicates():
        assert set(node.name for node in results) == {'Alice', 'Bob', 'Charlie'}
        assert mock_fulltext_search.call_count == 1
        assert mock_similarity_search.call_count == 1
        mock_fulltext_search.assert_called_with('Test', mock_driver, 4)
        mock_similarity_search.assert_called_with([0.1, 0.2, 0.3], mock_driver, 4)
        mock_fulltext_search.assert_called_with('Test', mock_driver, ['1'], 4)
        mock_similarity_search.assert_called_with([0.1, 0.2, 0.3], mock_driver, ['1'], 4)