Add Missing Node and edge CRUD (#51)

* add CRUD operations and fix search limit bugs

* format

* update tests

* å

* update tests to double limit call

* add default field

* format

* import correct field
This commit is contained in:
Preston Rasmussen
2024-08-27 16:18:01 -04:00
committed by GitHub
parent 3f3fb60a55
commit 06d8d9359f
7 changed files with 251 additions and 39 deletions

View File

@@ -23,6 +23,7 @@ from uuid import uuid4
from neo4j import AsyncDriver
from pydantic import BaseModel, Field
from graphiti_core.helpers import parse_db_date
from graphiti_core.llm_client.config import EMBEDDING_DIM
from graphiti_core.nodes import Node
@@ -38,6 +39,9 @@ class Edge(BaseModel, ABC):
# Abstract persistence hook: every concrete Edge subclass must implement
# save() and write itself to the database through the given async driver.
@abstractmethod
async def save(self, driver: AsyncDriver): ...
# Abstract removal hook: every concrete Edge subclass must implement
# delete() and remove itself from the database through the given async driver.
@abstractmethod
async def delete(self, driver: AsyncDriver): ...
def __hash__(self):
    # The hash is derived solely from the uuid.
    return hash(self.uuid)
@@ -46,6 +50,9 @@ class Edge(BaseModel, ABC):
return self.uuid == other.uuid
return False
# Lookup-by-uuid entry point. The base body is only `...`, so awaiting it
# yields None; it is not marked abstract here.
@classmethod
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): ...
class EpisodicEdge(Edge):
async def save(self, driver: AsyncDriver):
@@ -66,9 +73,48 @@ class EpisodicEdge(Edge):
return result
async def delete(self, driver: AsyncDriver):
    """Delete this MENTIONS relationship from the graph.

    The relationship is matched by ``self.uuid`` between an Episodic node and
    an Entity node; only the relationship itself is deleted.

    Returns:
        The value returned by ``driver.execute_query`` for the delete statement.
    """
    # TODO: Neo4j doesn't support variables for edge types and labels,
    # so the relationship type is written literally in the query text.
    delete_query = """
MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
DELETE e
"""
    outcome = await driver.execute_query(delete_query, uuid=self.uuid)

    logger.info(f'Deleted Edge: {self.uuid}')
    return outcome
@classmethod
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
    """Load the MENTIONS edge with the given uuid as an ``EpisodicEdge``.

    Args:
        driver: async Neo4j driver used to run the lookup query.
        uuid: uuid of the relationship to fetch.

    Returns:
        An ``EpisodicEdge`` built from the first matching record.

    Raises:
        IndexError: if no relationship with that uuid exists. This is the same
            exception type the previous unguarded ``edges[0]`` raised, now
            with a message naming the missing uuid.
    """
    records, _, _ = await driver.execute_query(
        """
MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
RETURN
e.uuid As uuid,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
e.created_at AS created_at
""",
        uuid=uuid,
    )

    # Check for a miss before logging success, so the log never claims an
    # edge was found when the query returned nothing.
    if not records:
        raise IndexError(f'EpisodicEdge not found: {uuid}')

    # Only the first record is returned, so only that one is materialized.
    record = records[0]
    edge = EpisodicEdge(
        uuid=record['uuid'],
        source_node_uuid=record['source_node_uuid'],
        target_node_uuid=record['target_node_uuid'],
        created_at=record['created_at'].to_native(),
    )

    logger.info(f'Found Edge: {uuid}')
    return edge
class EntityEdge(Edge):
@@ -97,7 +143,7 @@ class EntityEdge(Edge):
self.fact_embedding = embedding[:EMBEDDING_DIM]
end = time()
logger.info(f'embedded {text} in {end-start} ms')
logger.info(f'embedded {text} in {end - start} ms')
return embedding
@@ -127,3 +173,60 @@ class EntityEdge(Edge):
logger.info(f'Saved edge to neo4j: {self.uuid}')
return result
async def delete(self, driver: AsyncDriver):
    """Delete this RELATES_TO relationship from the graph.

    The relationship is matched by ``self.uuid`` between two Entity nodes;
    only the relationship itself is deleted.

    Returns:
        The value returned by ``driver.execute_query`` for the delete statement.
    """
    delete_query = """
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
DELETE e
"""
    outcome = await driver.execute_query(delete_query, uuid=self.uuid)

    logger.info(f'Deleted Edge: {self.uuid}')
    return outcome
@classmethod
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
    """Load the RELATES_TO edge with the given uuid as an ``EntityEdge``.

    Args:
        driver: async Neo4j driver used to run the lookup query.
        uuid: uuid of the relationship to fetch.

    Returns:
        An ``EntityEdge`` built from the first matching record.

    Raises:
        IndexError: if no relationship with that uuid exists. This is the same
            exception type the previous unguarded ``edges[0]`` raised, now
            with a message naming the missing uuid.
    """
    records, _, _ = await driver.execute_query(
        """
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
RETURN
e.uuid AS uuid,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
e.created_at AS created_at,
e.name AS name,
e.fact AS fact,
e.fact_embedding AS fact_embedding,
e.episodes AS episodes,
e.expired_at AS expired_at,
e.valid_at AS valid_at,
e.invalid_at AS invalid_at
""",
        uuid=uuid,
    )

    # Check for a miss before logging success, so the log never claims an
    # edge was found when the query returned nothing.
    if not records:
        raise IndexError(f'EntityEdge not found: {uuid}')

    # Only the first record is returned, so only that one is materialized.
    record = records[0]
    edge = EntityEdge(
        uuid=record['uuid'],
        source_node_uuid=record['source_node_uuid'],
        target_node_uuid=record['target_node_uuid'],
        fact=record['fact'],
        name=record['name'],
        episodes=record['episodes'],
        fact_embedding=record['fact_embedding'],
        created_at=record['created_at'].to_native(),
        # These three go through parse_db_date, which passes None through.
        expired_at=parse_db_date(record['expired_at']),
        valid_at=parse_db_date(record['valid_at']),
        invalid_at=parse_db_date(record['invalid_at']),
    )

    logger.info(f'Found Edge: {uuid}')
    return edge

7
graphiti_core/helpers.py Normal file
View File

@@ -0,0 +1,7 @@
from datetime import datetime
from neo4j import time as neo4j_time
def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
    """Convert a Neo4j temporal value to a native ``datetime``.

    Args:
        neo_date: a value exposing ``to_native()``, or ``None``.

    Returns:
        ``neo_date.to_native()`` for any non-``None`` input, otherwise ``None``.
    """
    # Compare against None explicitly instead of relying on truthiness, so a
    # value that happens to evaluate falsy is still converted.
    return neo_date.to_native() if neo_date is not None else None

View File

@@ -75,6 +75,9 @@ class Node(BaseModel, ABC):
# Abstract persistence hook: every concrete Node subclass must implement
# save() and write itself to the database through the given async driver.
@abstractmethod
async def save(self, driver: AsyncDriver): ...
# Abstract removal hook: every concrete Node subclass must implement
# delete() and remove itself from the database through the given async driver.
@abstractmethod
async def delete(self, driver: AsyncDriver): ...
def __hash__(self):
    # The hash is derived solely from the uuid.
    return hash(self.uuid)
@@ -83,6 +86,9 @@ class Node(BaseModel, ABC):
return self.uuid == other.uuid
return False
# Lookup-by-uuid entry point. The base body is only `...`, so awaiting it
# yields None; it is not marked abstract here.
@classmethod
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): ...
class EpisodicNode(Node):
source: EpisodeType = Field(description='source type')
@@ -111,13 +117,58 @@ class EpisodicNode(Node):
created_at=self.created_at,
valid_at=self.valid_at,
source=self.source.value,
_database='neo4j',
)
logger.info(f'Saved Node to neo4j: {self.uuid}')
return result
async def delete(self, driver: AsyncDriver):
    """Delete this Episodic node, together with its relationships.

    The node is matched by ``self.uuid`` and removed with ``DETACH DELETE``.

    Returns:
        The value returned by ``driver.execute_query`` for the delete statement.
    """
    delete_query = """
MATCH (n:Episodic {uuid: $uuid})
DETACH DELETE n
"""
    outcome = await driver.execute_query(delete_query, uuid=self.uuid)

    logger.info(f'Deleted Node: {self.uuid}')
    return outcome
@classmethod
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
    """Load the Episodic node with the given uuid as an ``EpisodicNode``.

    Args:
        driver: async Neo4j driver used to run the lookup query.
        uuid: uuid of the node to fetch.

    Returns:
        An ``EpisodicNode`` built from the first matching record.

    Raises:
        IndexError: if no node with that uuid exists. This is the same
            exception type the previous unguarded ``episodes[0]`` raised, now
            with a message naming the missing uuid.
    """
    records, _, _ = await driver.execute_query(
        """
MATCH (e:Episodic {uuid: $uuid})
RETURN e.content as content,
e.created_at as created_at,
e.valid_at as valid_at,
e.uuid as uuid,
e.name as name,
e.source_description as source_description,
e.source as source
""",
        uuid=uuid,
    )

    # Check for a miss before logging success, so the log never claims a
    # node was found when the query returned nothing.
    if not records:
        raise IndexError(f'EpisodicNode not found: {uuid}')

    # Only the first record is returned, so only that one is materialized.
    record = records[0]
    episode = EpisodicNode(
        content=record['content'],
        # NOTE(review): created_at is passed as a POSIX timestamp (float)
        # while valid_at is passed as a datetime. Kept unchanged here;
        # confirm the model field is meant to accept a float.
        created_at=record['created_at'].to_native().timestamp(),
        valid_at=(record['valid_at'].to_native()),
        uuid=record['uuid'],
        source=EpisodeType.from_str(record['source']),
        name=record['name'],
        source_description=record['source_description'],
    )

    logger.info(f'Found Node: {uuid}')
    return episode
class EntityNode(Node):
name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
@@ -153,3 +204,47 @@ class EntityNode(Node):
logger.info(f'Saved Node to neo4j: {self.uuid}')
return result
async def delete(self, driver: AsyncDriver):
    """Delete this Entity node, together with its relationships.

    The node is matched by ``self.uuid`` and removed with ``DETACH DELETE``.

    Returns:
        The value returned by ``driver.execute_query`` for the delete statement.
    """
    delete_query = """
MATCH (n:Entity {uuid: $uuid})
DETACH DELETE n
"""
    outcome = await driver.execute_query(delete_query, uuid=self.uuid)

    logger.info(f'Deleted Node: {self.uuid}')
    return outcome
@classmethod
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
    """Load the Entity node with the given uuid as an ``EntityNode``.

    Args:
        driver: async Neo4j driver used to run the lookup query.
        uuid: uuid of the node to fetch.

    Returns:
        An ``EntityNode`` built from the first matching record. Its labels
        are set to ``['Entity']``; the query does not read labels back.

    Raises:
        IndexError: if no node with that uuid exists. This is the same
            exception type the previous unguarded ``nodes[0]`` raised, now
            with a message naming the missing uuid.
    """
    records, _, _ = await driver.execute_query(
        """
MATCH (n:Entity {uuid: $uuid})
RETURN
n.uuid As uuid,
n.name AS name,
n.created_at AS created_at,
n.summary AS summary
""",
        uuid=uuid,
    )

    # Check for a miss before logging success, so the log never claims a
    # node was found when the query returned nothing.
    if not records:
        raise IndexError(f'EntityNode not found: {uuid}')

    # Only the first record is returned, so only that one is materialized.
    record = records[0]
    node = EntityNode(
        uuid=record['uuid'],
        name=record['name'],
        labels=['Entity'],
        created_at=record['created_at'].to_native(),
        summary=record['summary'],
    )

    logger.info(f'Found Node: {uuid}')
    return node

View File

@@ -20,7 +20,7 @@ from enum import Enum
from time import time
from neo4j import AsyncDriver
from pydantic import BaseModel
from pydantic import BaseModel, Field
from graphiti_core.edges import EntityEdge
from graphiti_core.llm_client.config import EMBEDDING_DIM
@@ -49,8 +49,8 @@ class Reranker(Enum):
class SearchConfig(BaseModel):
num_edges: int = 10
num_nodes: int = 10
num_edges: int = Field(default=10)
num_nodes: int = Field(default=10)
num_episodes: int = EPISODE_WINDOW_LEN
search_methods: list[SearchMethod]
reranker: Reranker | None
@@ -63,12 +63,12 @@ class SearchResults(BaseModel):
async def hybrid_search(
driver: AsyncDriver,
embedder,
query: str,
timestamp: datetime,
config: SearchConfig,
center_node_uuid: str | None = None,
driver: AsyncDriver,
embedder,
query: str,
timestamp: datetime,
config: SearchConfig,
center_node_uuid: str | None = None,
) -> SearchResults:
start = time()
@@ -79,11 +79,11 @@ async def hybrid_search(
search_results = []
if config.num_episodes > 0:
episodes.extend(await retrieve_episodes(driver, timestamp))
episodes.extend(await retrieve_episodes(driver, timestamp, config.num_episodes))
nodes.extend(await get_mentioned_nodes(driver, episodes))
if SearchMethod.bm25 in config.search_methods:
text_search = await edge_fulltext_search(query, driver)
text_search = await edge_fulltext_search(query, driver, 2 * config.num_edges)
search_results.append(text_search)
if SearchMethod.cosine_similarity in config.search_methods:
@@ -94,7 +94,9 @@ async def hybrid_search(
.embedding[:EMBEDDING_DIM]
)
similarity_search = await edge_similarity_search(search_vector, driver)
similarity_search = await edge_similarity_search(
search_vector, driver, 2 * config.num_edges
)
search_results.append(similarity_search)
if len(search_results) > 1 and config.reranker is None:

View File

@@ -3,13 +3,12 @@ import logging
import re
import typing
from collections import defaultdict
from datetime import datetime
from time import time
from neo4j import AsyncDriver
from neo4j import time as neo4j_time
from graphiti_core.edges import EntityEdge
from graphiti_core.helpers import parse_db_date
from graphiti_core.nodes import EntityNode, EpisodicNode
logger = logging.getLogger(__name__)
@@ -17,10 +16,6 @@ logger = logging.getLogger(__name__)
RELEVANT_SCHEMA_LIMIT = 3
def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
    """Convert a Neo4j temporal value to a native ``datetime``.

    Args:
        neo_date: a value exposing ``to_native()``, or ``None``.

    Returns:
        ``neo_date.to_native()`` for any non-``None`` input, otherwise ``None``.
    """
    # Compare against None explicitly instead of relying on truthiness, so a
    # value that happens to evaluate falsy is still converted.
    return neo_date.to_native() if neo_date is not None else None
async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode]):
episode_uuids = [episode.uuid for episode in episodes]
records, _, _ = await driver.execute_query(
@@ -106,7 +101,7 @@ async def edge_similarity_search(
# vector similarity search over embedded facts
records, _, _ = await driver.execute_query(
"""
CALL db.index.vector.queryRelationships("fact_embedding", 5, $search_vector)
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
YIELD relationship AS r, score
MATCH (n)-[r:RELATES_TO]->(m)
RETURN
@@ -121,7 +116,7 @@ async def edge_similarity_search(
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at
ORDER BY score DESC LIMIT $limit
ORDER BY score DESC
""",
search_vector=search_vector,
limit=limit,
@@ -316,8 +311,11 @@ async def hybrid_node_search(
relevant_node_uuids = set()
results = await asyncio.gather(
*[entity_fulltext_search(q, driver, limit or RELEVANT_SCHEMA_LIMIT) for q in queries],
*[entity_similarity_search(e, driver, limit or RELEVANT_SCHEMA_LIMIT) for e in embeddings],
*[entity_fulltext_search(q, driver, 2 * (limit or RELEVANT_SCHEMA_LIMIT)) for q in queries],
*[
entity_similarity_search(e, driver, 2 * (limit or RELEVANT_SCHEMA_LIMIT))
for e in embeddings
],
)
for result in results:

View File

@@ -22,8 +22,6 @@ from datetime import datetime
import pytest
from dotenv import load_dotenv
from neo4j import AsyncGraphDatabase
from openai import OpenAI
from graphiti_core.edges import EntityEdge, EpisodicEdge
from graphiti_core.graphiti import Graphiti
@@ -74,7 +72,7 @@ def format_context(facts):
@pytest.mark.asyncio
async def test_graphiti_init():
logger = setup_logging()
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD, None)
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
edges = await graphiti.search('Freakenomics guest')
@@ -92,11 +90,9 @@ async def test_graphiti_init():
@pytest.mark.asyncio
async def test_graph_integration():
driver = AsyncGraphDatabase.driver(
NEO4J_URI,
auth=(NEO4j_USER, NEO4j_PASSWORD),
)
embedder = OpenAI().embeddings
client = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
embedder = client.llm_client.get_embedder()
driver = client.driver
now = datetime.now()
episode = EpisodicNode(
@@ -139,10 +135,21 @@ async def test_graph_integration():
invalid_at=now,
)
entity_edge.generate_embedding(embedder)
await entity_edge.generate_embedding(embedder)
nodes = [episode, alice_node, bob_node]
edges = [episodic_edge_1, episodic_edge_2, entity_edge]
# test save
await asyncio.gather(*[node.save(driver) for node in nodes])
await asyncio.gather(*[edge.save(driver) for edge in edges])
# test get
assert await EpisodicNode.get_by_uuid(driver, episode.uuid) is not None
assert await EntityNode.get_by_uuid(driver, alice_node.uuid) is not None
assert await EpisodicEdge.get_by_uuid(driver, episodic_edge_1.uuid) is not None
assert await EntityEdge.get_by_uuid(driver, entity_edge.uuid) is not None
# test delete
await asyncio.gather(*[node.delete(driver) for node in nodes])
await asyncio.gather(*[edge.delete(driver) for edge in edges])

View File

@@ -113,8 +113,8 @@ async def test_hybrid_node_search_with_limit():
assert mock_fulltext_search.call_count == 1
assert mock_similarity_search.call_count == 1
# Verify that the limit was passed to the search functions
mock_fulltext_search.assert_called_with('Test', mock_driver, 1)
mock_similarity_search.assert_called_with([0.1, 0.2, 0.3], mock_driver, 1)
mock_fulltext_search.assert_called_with('Test', mock_driver, 2)
mock_similarity_search.assert_called_with([0.1, 0.2, 0.3], mock_driver, 2)
@pytest.mark.asyncio
@@ -148,5 +148,5 @@ async def test_hybrid_node_search_with_limit_and_duplicates():
assert set(node.name for node in results) == {'Alice', 'Bob', 'Charlie'}
assert mock_fulltext_search.call_count == 1
assert mock_similarity_search.call_count == 1
mock_fulltext_search.assert_called_with('Test', mock_driver, 2)
mock_similarity_search.assert_called_with([0.1, 0.2, 0.3], mock_driver, 2)
mock_fulltext_search.assert_called_with('Test', mock_driver, 4)
mock_similarity_search.assert_called_with([0.1, 0.2, 0.3], mock_driver, 4)