mirror of
https://github.com/getzep/graphiti.git
synced 2024-09-08 19:13:11 +03:00
Add Missing Node and edge CRUD (#51)
* add CRUD operations and fix search limit bugs * format * update tests * å * update tests to double limit call * add default field * format * import correct field
This commit is contained in:
committed by
GitHub
parent
3f3fb60a55
commit
06d8d9359f
@@ -23,6 +23,7 @@ from uuid import uuid4
|
||||
from neo4j import AsyncDriver
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from graphiti_core.helpers import parse_db_date
|
||||
from graphiti_core.llm_client.config import EMBEDDING_DIM
|
||||
from graphiti_core.nodes import Node
|
||||
|
||||
@@ -38,6 +39,9 @@ class Edge(BaseModel, ABC):
|
||||
@abstractmethod
|
||||
async def save(self, driver: AsyncDriver): ...
|
||||
|
||||
@abstractmethod
|
||||
async def delete(self, driver: AsyncDriver): ...
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.uuid)
|
||||
|
||||
@@ -46,6 +50,9 @@ class Edge(BaseModel, ABC):
|
||||
return self.uuid == other.uuid
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): ...
|
||||
|
||||
|
||||
class EpisodicEdge(Edge):
|
||||
async def save(self, driver: AsyncDriver):
|
||||
@@ -66,9 +73,48 @@ class EpisodicEdge(Edge):
|
||||
|
||||
return result
|
||||
|
||||
async def delete(self, driver: AsyncDriver):
|
||||
result = await driver.execute_query(
|
||||
"""
|
||||
MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
|
||||
DELETE e
|
||||
""",
|
||||
uuid=self.uuid,
|
||||
)
|
||||
|
||||
# TODO: Neo4j doesn't support variables for edge types and labels.
|
||||
# Right now we have all edge nodes as type RELATES_TO
|
||||
logger.info(f'Deleted Edge: {self.uuid}')
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
|
||||
RETURN
|
||||
e.uuid As uuid,
|
||||
n.uuid AS source_node_uuid,
|
||||
m.uuid AS target_node_uuid,
|
||||
e.created_at AS created_at
|
||||
""",
|
||||
uuid=uuid,
|
||||
)
|
||||
|
||||
edges: list[EpisodicEdge] = []
|
||||
|
||||
for record in records:
|
||||
edges.append(
|
||||
EpisodicEdge(
|
||||
uuid=record['uuid'],
|
||||
source_node_uuid=record['source_node_uuid'],
|
||||
target_node_uuid=record['target_node_uuid'],
|
||||
created_at=record['created_at'].to_native(),
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(f'Found Edge: {uuid}')
|
||||
|
||||
return edges[0]
|
||||
|
||||
|
||||
class EntityEdge(Edge):
|
||||
@@ -97,7 +143,7 @@ class EntityEdge(Edge):
|
||||
self.fact_embedding = embedding[:EMBEDDING_DIM]
|
||||
|
||||
end = time()
|
||||
logger.info(f'embedded {text} in {end-start} ms')
|
||||
logger.info(f'embedded {text} in {end - start} ms')
|
||||
|
||||
return embedding
|
||||
|
||||
@@ -127,3 +173,60 @@ class EntityEdge(Edge):
|
||||
logger.info(f'Saved edge to neo4j: {self.uuid}')
|
||||
|
||||
return result
|
||||
|
||||
async def delete(self, driver: AsyncDriver):
|
||||
result = await driver.execute_query(
|
||||
"""
|
||||
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
|
||||
DELETE e
|
||||
""",
|
||||
uuid=self.uuid,
|
||||
)
|
||||
|
||||
logger.info(f'Deleted Edge: {self.uuid}')
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
|
||||
RETURN
|
||||
e.uuid AS uuid,
|
||||
n.uuid AS source_node_uuid,
|
||||
m.uuid AS target_node_uuid,
|
||||
e.created_at AS created_at,
|
||||
e.name AS name,
|
||||
e.fact AS fact,
|
||||
e.fact_embedding AS fact_embedding,
|
||||
e.episodes AS episodes,
|
||||
e.expired_at AS expired_at,
|
||||
e.valid_at AS valid_at,
|
||||
e.invalid_at AS invalid_at
|
||||
""",
|
||||
uuid=uuid,
|
||||
)
|
||||
|
||||
edges: list[EntityEdge] = []
|
||||
|
||||
for record in records:
|
||||
edges.append(
|
||||
EntityEdge(
|
||||
uuid=record['uuid'],
|
||||
source_node_uuid=record['source_node_uuid'],
|
||||
target_node_uuid=record['target_node_uuid'],
|
||||
fact=record['fact'],
|
||||
name=record['name'],
|
||||
episodes=record['episodes'],
|
||||
fact_embedding=record['fact_embedding'],
|
||||
created_at=record['created_at'].to_native(),
|
||||
expired_at=parse_db_date(record['expired_at']),
|
||||
valid_at=parse_db_date(record['valid_at']),
|
||||
invalid_at=parse_db_date(record['invalid_at']),
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(f'Found Edge: {uuid}')
|
||||
|
||||
return edges[0]
|
||||
|
||||
7
graphiti_core/helpers.py
Normal file
7
graphiti_core/helpers.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from datetime import datetime
|
||||
|
||||
from neo4j import time as neo4j_time
|
||||
|
||||
|
||||
def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
|
||||
return neo_date.to_native() if neo_date else None
|
||||
@@ -75,6 +75,9 @@ class Node(BaseModel, ABC):
|
||||
@abstractmethod
|
||||
async def save(self, driver: AsyncDriver): ...
|
||||
|
||||
@abstractmethod
|
||||
async def delete(self, driver: AsyncDriver): ...
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.uuid)
|
||||
|
||||
@@ -83,6 +86,9 @@ class Node(BaseModel, ABC):
|
||||
return self.uuid == other.uuid
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): ...
|
||||
|
||||
|
||||
class EpisodicNode(Node):
|
||||
source: EpisodeType = Field(description='source type')
|
||||
@@ -111,13 +117,58 @@ class EpisodicNode(Node):
|
||||
created_at=self.created_at,
|
||||
valid_at=self.valid_at,
|
||||
source=self.source.value,
|
||||
_database='neo4j',
|
||||
)
|
||||
|
||||
logger.info(f'Saved Node to neo4j: {self.uuid}')
|
||||
|
||||
return result
|
||||
|
||||
async def delete(self, driver: AsyncDriver):
|
||||
result = await driver.execute_query(
|
||||
"""
|
||||
MATCH (n:Episodic {uuid: $uuid})
|
||||
DETACH DELETE n
|
||||
""",
|
||||
uuid=self.uuid,
|
||||
)
|
||||
|
||||
logger.info(f'Deleted Node: {self.uuid}')
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
MATCH (e:Episodic {uuid: $uuid})
|
||||
RETURN e.content as content,
|
||||
e.created_at as created_at,
|
||||
e.valid_at as valid_at,
|
||||
e.uuid as uuid,
|
||||
e.name as name,
|
||||
e.source_description as source_description,
|
||||
e.source as source
|
||||
""",
|
||||
uuid=uuid,
|
||||
)
|
||||
|
||||
episodes = [
|
||||
EpisodicNode(
|
||||
content=record['content'],
|
||||
created_at=record['created_at'].to_native().timestamp(),
|
||||
valid_at=(record['valid_at'].to_native()),
|
||||
uuid=record['uuid'],
|
||||
source=EpisodeType.from_str(record['source']),
|
||||
name=record['name'],
|
||||
source_description=record['source_description'],
|
||||
)
|
||||
for record in records
|
||||
]
|
||||
|
||||
logger.info(f'Found Node: {uuid}')
|
||||
|
||||
return episodes[0]
|
||||
|
||||
|
||||
class EntityNode(Node):
|
||||
name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
|
||||
@@ -153,3 +204,47 @@ class EntityNode(Node):
|
||||
logger.info(f'Saved Node to neo4j: {self.uuid}')
|
||||
|
||||
return result
|
||||
|
||||
async def delete(self, driver: AsyncDriver):
|
||||
result = await driver.execute_query(
|
||||
"""
|
||||
MATCH (n:Entity {uuid: $uuid})
|
||||
DETACH DELETE n
|
||||
""",
|
||||
uuid=self.uuid,
|
||||
)
|
||||
|
||||
logger.info(f'Deleted Node: {self.uuid}')
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
MATCH (n:Entity {uuid: $uuid})
|
||||
RETURN
|
||||
n.uuid As uuid,
|
||||
n.name AS name,
|
||||
n.created_at AS created_at,
|
||||
n.summary AS summary
|
||||
""",
|
||||
uuid=uuid,
|
||||
)
|
||||
|
||||
nodes: list[EntityNode] = []
|
||||
|
||||
for record in records:
|
||||
nodes.append(
|
||||
EntityNode(
|
||||
uuid=record['uuid'],
|
||||
name=record['name'],
|
||||
labels=['Entity'],
|
||||
created_at=record['created_at'].to_native(),
|
||||
summary=record['summary'],
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(f'Found Node: {uuid}')
|
||||
|
||||
return nodes[0]
|
||||
|
||||
@@ -20,7 +20,7 @@ from enum import Enum
|
||||
from time import time
|
||||
|
||||
from neo4j import AsyncDriver
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from graphiti_core.edges import EntityEdge
|
||||
from graphiti_core.llm_client.config import EMBEDDING_DIM
|
||||
@@ -49,8 +49,8 @@ class Reranker(Enum):
|
||||
|
||||
|
||||
class SearchConfig(BaseModel):
|
||||
num_edges: int = 10
|
||||
num_nodes: int = 10
|
||||
num_edges: int = Field(default=10)
|
||||
num_nodes: int = Field(default=10)
|
||||
num_episodes: int = EPISODE_WINDOW_LEN
|
||||
search_methods: list[SearchMethod]
|
||||
reranker: Reranker | None
|
||||
@@ -63,12 +63,12 @@ class SearchResults(BaseModel):
|
||||
|
||||
|
||||
async def hybrid_search(
|
||||
driver: AsyncDriver,
|
||||
embedder,
|
||||
query: str,
|
||||
timestamp: datetime,
|
||||
config: SearchConfig,
|
||||
center_node_uuid: str | None = None,
|
||||
driver: AsyncDriver,
|
||||
embedder,
|
||||
query: str,
|
||||
timestamp: datetime,
|
||||
config: SearchConfig,
|
||||
center_node_uuid: str | None = None,
|
||||
) -> SearchResults:
|
||||
start = time()
|
||||
|
||||
@@ -79,11 +79,11 @@ async def hybrid_search(
|
||||
search_results = []
|
||||
|
||||
if config.num_episodes > 0:
|
||||
episodes.extend(await retrieve_episodes(driver, timestamp))
|
||||
episodes.extend(await retrieve_episodes(driver, timestamp, config.num_episodes))
|
||||
nodes.extend(await get_mentioned_nodes(driver, episodes))
|
||||
|
||||
if SearchMethod.bm25 in config.search_methods:
|
||||
text_search = await edge_fulltext_search(query, driver)
|
||||
text_search = await edge_fulltext_search(query, driver, 2 * config.num_edges)
|
||||
search_results.append(text_search)
|
||||
|
||||
if SearchMethod.cosine_similarity in config.search_methods:
|
||||
@@ -94,7 +94,9 @@ async def hybrid_search(
|
||||
.embedding[:EMBEDDING_DIM]
|
||||
)
|
||||
|
||||
similarity_search = await edge_similarity_search(search_vector, driver)
|
||||
similarity_search = await edge_similarity_search(
|
||||
search_vector, driver, 2 * config.num_edges
|
||||
)
|
||||
search_results.append(similarity_search)
|
||||
|
||||
if len(search_results) > 1 and config.reranker is None:
|
||||
|
||||
@@ -3,13 +3,12 @@ import logging
|
||||
import re
|
||||
import typing
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from time import time
|
||||
|
||||
from neo4j import AsyncDriver
|
||||
from neo4j import time as neo4j_time
|
||||
|
||||
from graphiti_core.edges import EntityEdge
|
||||
from graphiti_core.helpers import parse_db_date
|
||||
from graphiti_core.nodes import EntityNode, EpisodicNode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -17,10 +16,6 @@ logger = logging.getLogger(__name__)
|
||||
RELEVANT_SCHEMA_LIMIT = 3
|
||||
|
||||
|
||||
def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
|
||||
return neo_date.to_native() if neo_date else None
|
||||
|
||||
|
||||
async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode]):
|
||||
episode_uuids = [episode.uuid for episode in episodes]
|
||||
records, _, _ = await driver.execute_query(
|
||||
@@ -106,7 +101,7 @@ async def edge_similarity_search(
|
||||
# vector similarity search over embedded facts
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
CALL db.index.vector.queryRelationships("fact_embedding", 5, $search_vector)
|
||||
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
|
||||
YIELD relationship AS r, score
|
||||
MATCH (n)-[r:RELATES_TO]->(m)
|
||||
RETURN
|
||||
@@ -121,7 +116,7 @@ async def edge_similarity_search(
|
||||
r.expired_at AS expired_at,
|
||||
r.valid_at AS valid_at,
|
||||
r.invalid_at AS invalid_at
|
||||
ORDER BY score DESC LIMIT $limit
|
||||
ORDER BY score DESC
|
||||
""",
|
||||
search_vector=search_vector,
|
||||
limit=limit,
|
||||
@@ -316,8 +311,11 @@ async def hybrid_node_search(
|
||||
relevant_node_uuids = set()
|
||||
|
||||
results = await asyncio.gather(
|
||||
*[entity_fulltext_search(q, driver, limit or RELEVANT_SCHEMA_LIMIT) for q in queries],
|
||||
*[entity_similarity_search(e, driver, limit or RELEVANT_SCHEMA_LIMIT) for e in embeddings],
|
||||
*[entity_fulltext_search(q, driver, 2 * (limit or RELEVANT_SCHEMA_LIMIT)) for q in queries],
|
||||
*[
|
||||
entity_similarity_search(e, driver, 2 * (limit or RELEVANT_SCHEMA_LIMIT))
|
||||
for e in embeddings
|
||||
],
|
||||
)
|
||||
|
||||
for result in results:
|
||||
|
||||
@@ -22,8 +22,6 @@ from datetime import datetime
|
||||
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
from neo4j import AsyncGraphDatabase
|
||||
from openai import OpenAI
|
||||
|
||||
from graphiti_core.edges import EntityEdge, EpisodicEdge
|
||||
from graphiti_core.graphiti import Graphiti
|
||||
@@ -74,7 +72,7 @@ def format_context(facts):
|
||||
@pytest.mark.asyncio
|
||||
async def test_graphiti_init():
|
||||
logger = setup_logging()
|
||||
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD, None)
|
||||
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
|
||||
|
||||
edges = await graphiti.search('Freakenomics guest')
|
||||
|
||||
@@ -92,11 +90,9 @@ async def test_graphiti_init():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_integration():
|
||||
driver = AsyncGraphDatabase.driver(
|
||||
NEO4J_URI,
|
||||
auth=(NEO4j_USER, NEO4j_PASSWORD),
|
||||
)
|
||||
embedder = OpenAI().embeddings
|
||||
client = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
|
||||
embedder = client.llm_client.get_embedder()
|
||||
driver = client.driver
|
||||
|
||||
now = datetime.now()
|
||||
episode = EpisodicNode(
|
||||
@@ -139,10 +135,21 @@ async def test_graph_integration():
|
||||
invalid_at=now,
|
||||
)
|
||||
|
||||
entity_edge.generate_embedding(embedder)
|
||||
await entity_edge.generate_embedding(embedder)
|
||||
|
||||
nodes = [episode, alice_node, bob_node]
|
||||
edges = [episodic_edge_1, episodic_edge_2, entity_edge]
|
||||
|
||||
# test save
|
||||
await asyncio.gather(*[node.save(driver) for node in nodes])
|
||||
await asyncio.gather(*[edge.save(driver) for edge in edges])
|
||||
|
||||
# test get
|
||||
assert await EpisodicNode.get_by_uuid(driver, episode.uuid) is not None
|
||||
assert await EntityNode.get_by_uuid(driver, alice_node.uuid) is not None
|
||||
assert await EpisodicEdge.get_by_uuid(driver, episodic_edge_1.uuid) is not None
|
||||
assert await EntityEdge.get_by_uuid(driver, entity_edge.uuid) is not None
|
||||
|
||||
# test delete
|
||||
await asyncio.gather(*[node.delete(driver) for node in nodes])
|
||||
await asyncio.gather(*[edge.delete(driver) for edge in edges])
|
||||
|
||||
@@ -113,8 +113,8 @@ async def test_hybrid_node_search_with_limit():
|
||||
assert mock_fulltext_search.call_count == 1
|
||||
assert mock_similarity_search.call_count == 1
|
||||
# Verify that the limit was passed to the search functions
|
||||
mock_fulltext_search.assert_called_with('Test', mock_driver, 1)
|
||||
mock_similarity_search.assert_called_with([0.1, 0.2, 0.3], mock_driver, 1)
|
||||
mock_fulltext_search.assert_called_with('Test', mock_driver, 2)
|
||||
mock_similarity_search.assert_called_with([0.1, 0.2, 0.3], mock_driver, 2)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -148,5 +148,5 @@ async def test_hybrid_node_search_with_limit_and_duplicates():
|
||||
assert set(node.name for node in results) == {'Alice', 'Bob', 'Charlie'}
|
||||
assert mock_fulltext_search.call_count == 1
|
||||
assert mock_similarity_search.call_count == 1
|
||||
mock_fulltext_search.assert_called_with('Test', mock_driver, 2)
|
||||
mock_similarity_search.assert_called_with([0.1, 0.2, 0.3], mock_driver, 2)
|
||||
mock_fulltext_search.assert_called_with('Test', mock_driver, 4)
|
||||
mock_similarity_search.assert_called_with([0.1, 0.2, 0.3], mock_driver, 4)
|
||||
|
||||
Reference in New Issue
Block a user