Mirror of https://github.com/getzep/graphiti.git, synced 2024-09-08 19:13:11 +03:00

commit 63b9790026 (parent 8141a783b1), committed by GitHub

search updates (#14)

* search updates
* test updates
* add opinionated search
* update

core/graphiti.py (123 changes)
@@ -1,15 +1,15 @@
 import asyncio
 from datetime import datetime
 import logging
-from typing import Callable, LiteralString
+from typing import Callable
 from neo4j import AsyncGraphDatabase
 from dotenv import load_dotenv
 from time import time
 import os

-from core.llm_client.config import EMBEDDING_DIM
-from core.nodes import EntityNode, EpisodicNode, Node
-from core.edges import EntityEdge, Edge, EpisodicEdge
+from core.nodes import EntityNode, EpisodicNode
+from core.edges import EntityEdge, EpisodicEdge
+from core.search.search import SearchConfig, hybrid_search
 from core.utils import (
     build_episodic_edges,
     retrieve_episodes,
@@ -19,22 +19,21 @@ from core.utils.bulk_utils import (
     BulkEpisode,
     extract_nodes_and_edges_bulk,
     retrieve_previous_episodes_bulk,
-    compress_nodes,
     dedupe_nodes_bulk,
     resolve_edge_pointers,
     dedupe_edges_bulk,
 )
 from core.utils.maintenance.edge_operations import extract_edges, dedupe_extracted_edges
-from core.utils.maintenance.graph_data_operations import EPISODE_WINDOW_LEN
+from core.utils.maintenance.graph_data_operations import (
+    EPISODE_WINDOW_LEN,
+    build_indices_and_constraints,
+)
 from core.utils.maintenance.node_operations import dedupe_extracted_nodes, extract_nodes
 from core.utils.maintenance.temporal_operations import (
     invalidate_edges,
     prepare_edges_for_invalidation,
 )
-from core.utils.search.search_utils import (
-    edge_similarity_search,
-    entity_fulltext_search,
-    bfs,
+from core.search.search_utils import (
     get_relevant_nodes,
     get_relevant_edges,
 )
@@ -64,10 +63,13 @@ class Graphiti:
     def close(self):
         self.driver.close()

+    async def build_indices_and_constraints(self):
+        await build_indices_and_constraints(self.driver)
+
     async def retrieve_episodes(
         self,
         reference_time: datetime,
-        last_n: int,
+        last_n: int = EPISODE_WINDOW_LEN,
         sources: list[str] | None = "messages",
     ) -> list[EpisodicNode]:
         """Retrieve the last n episodic nodes from the graph"""
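Because last_n now defaults to EPISODE_WINDOW_LEN, callers can omit the window size. A one-line sketch against an assumed, already-constructed graphiti client:

    episodes = await graphiti.retrieve_episodes(datetime.now())  # last_n defaults to EPISODE_WINDOW_LEN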
@@ -103,9 +105,7 @@ class Graphiti:
         embedder = self.llm_client.client.embeddings
         now = datetime.now()

-        previous_episodes = await self.retrieve_episodes(
-            reference_time, last_n=EPISODE_WINDOW_LEN
-        )
+        previous_episodes = await self.retrieve_episodes(reference_time)
         episode = EpisodicNode(
             name=name,
             labels=[],
@@ -220,80 +220,6 @@ class Graphiti:
             else:
                 raise e

-    async def build_indices(self):
-        index_queries: list[LiteralString] = [
-            "CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)",
-            "CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)",
-            "CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)",
-            "CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)",
-            "CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)",
-            "CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)",
-            "CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)",
-            "CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)",
-            "CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.name)",
-            "CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.created_at)",
-            "CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.expired_at)",
-            "CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.valid_at)",
-            "CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.invalid_at)",
-            "CREATE FULLTEXT INDEX name_and_summary IF NOT EXISTS FOR (n:Entity) ON EACH [n.name, n.summary]",
-            "CREATE FULLTEXT INDEX name_and_fact IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact]",
-            """
-            CREATE VECTOR INDEX fact_embedding IF NOT EXISTS
-            FOR ()-[r:RELATES_TO]-() ON (r.fact_embedding)
-            OPTIONS {indexConfig: {
-                `vector.dimensions`: 1024,
-                `vector.similarity_function`: 'cosine'
-            }}
-            """,
-            """
-            CREATE VECTOR INDEX name_embedding IF NOT EXISTS
-            FOR (n:Entity) ON (n.name_embedding)
-            OPTIONS {indexConfig: {
-                `vector.dimensions`: 1024,
-                `vector.similarity_function`: 'cosine'
-            }}
-            """,
-            """
-            CREATE CONSTRAINT entity_name IF NOT EXISTS
-            FOR (n:Entity) REQUIRE n.name IS UNIQUE
-            """,
-            """
-            CREATE CONSTRAINT edge_facts IF NOT EXISTS
-            FOR ()-[e:RELATES_TO]-() REQUIRE e.fact IS UNIQUE
-            """,
-        ]
-
-        await asyncio.gather(
-            *[self.driver.execute_query(query) for query in index_queries]
-        )
-
-    async def search(self, query: str) -> list[tuple[EntityNode, list[EntityEdge]]]:
-        text = query.replace("\n", " ")
-        search_vector = (
-            (
-                await self.llm_client.client.embeddings.create(
-                    input=[text], model="text-embedding-3-small"
-                )
-            )
-            .data[0]
-            .embedding[:EMBEDDING_DIM]
-        )
-
-        edges = await edge_similarity_search(search_vector, self.driver)
-        nodes = await entity_fulltext_search(query, self.driver)
-
-        node_ids = [node.uuid for node in nodes]
-
-        for edge in edges:
-            node_ids.append(edge.source_node_uuid)
-            node_ids.append(edge.target_node_uuid)
-
-        node_ids = list(dict.fromkeys(node_ids))
-
-        context = await bfs(node_ids, self.driver)
-
-        return context
-
     async def add_episode_bulk(
         self,
         bulk_episodes: list[BulkEpisode],
@@ -368,3 +294,24 @@ class Graphiti:

         except Exception as e:
             raise e
+
+    async def search(self, query: str, num_results=10):
+        search_config = SearchConfig(num_episodes=0, num_results=num_results)
+        edges = (
+            await hybrid_search(
+                self.driver,
+                self.llm_client.client.embeddings,
+                query,
+                datetime.now(),
+                search_config,
+            )
+        )["edges"]
+
+        facts = [edge.fact for edge in edges]
+
+        return facts
+
+    async def _search(self, query: str, timestamp: datetime, config: SearchConfig):
+        return await hybrid_search(
+            self.driver, self.llm_client.client.embeddings, query, timestamp, config
+        )
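A minimal end-to-end sketch of the new opinionated search API added above. The connection details are placeholders, and passing None for the LLM client mirrors the integration test at the bottom of this diff:

    import asyncio
    from core.graphiti import Graphiti

    async def demo():
        # Placeholder URI/credentials; search() returns a list of fact strings.
        client = Graphiti("bolt://localhost:7687", "neo4j", "password", None)
        facts = await client.search("Freakenomics guest", num_results=10)
        for fact in facts:
            print(fact)
        client.close()

    asyncio.run(demo())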
@@ -112,7 +112,7 @@ def node_list(context: dict[str, any]) -> list[Message]:

        Task:
        1. Group nodes together such that all duplicate nodes are in the same list of names
-       2. All dupolicate names should be grouped together in the same list
+       2. All duplicate names should be grouped together in the same list

        Guidelines:
        1. Each name from the list of nodes should appear EXACTLY once in your response
core/search/search.py (104 lines, new file)
@@ -0,0 +1,104 @@
+import asyncio
+import logging
+from datetime import datetime
+from time import time
+
+from neo4j import AsyncDriver
+from pydantic import BaseModel
+
+from core.edges import EntityEdge, Edge
+from core.llm_client.config import EMBEDDING_DIM
+from core.nodes import Node
+from core.search.search_utils import (
+    edge_similarity_search,
+    edge_fulltext_search,
+    get_mentioned_nodes,
+    rrf,
+)
+from core.utils import retrieve_episodes
+from core.utils.maintenance.graph_data_operations import EPISODE_WINDOW_LEN
+
+logger = logging.getLogger(__name__)
+
+
+class SearchConfig(BaseModel):
+    num_results: int = 10
+    num_episodes: int = EPISODE_WINDOW_LEN
+    similarity_search: str = "cosine"
+    text_search: str = "BM25"
+    reranker: str = "rrf"
+
+
+async def hybrid_search(
+    driver: AsyncDriver, embedder, query: str, timestamp: datetime, config: SearchConfig
+) -> dict[str, [Node | Edge]]:
+    start = time()
+
+    episodes = []
+    nodes = []
+    edges = []
+
+    search_results = []
+
+    if config.num_episodes > 0:
+        episodes.extend(await retrieve_episodes(driver, timestamp))
+        nodes.extend(await get_mentioned_nodes(driver, episodes))
+
+    if config.text_search == "BM25":
+        text_search = await edge_fulltext_search(query, driver)
+        search_results.append(text_search)
+
+    if config.similarity_search == "cosine":
+        query_text = query.replace("\n", " ")
+        search_vector = (
+            (await embedder.create(input=[query_text], model="text-embedding-3-small"))
+            .data[0]
+            .embedding[:EMBEDDING_DIM]
+        )
+
+        similarity_search = await edge_similarity_search(search_vector, driver)
+        search_results.append(similarity_search)
+
+    if len(search_results) == 1:
+        edges = search_results[0]
+
+    elif len(search_results) > 1 and not config.reranker == "rrf":
+        logger.exception("Multiple searches enabled without a reranker")
+        raise Exception("Multiple searches enabled without a reranker")
+
+    elif config.reranker == "rrf":
+        edge_uuid_map = {}
+        search_result_uuids = []
+
+        logger.info([[edge.fact for edge in result] for result in search_results])
+
+        for result in search_results:
+            result_uuids = []
+            for edge in result:
+                result_uuids.append(edge.uuid)
+                edge_uuid_map[edge.uuid] = edge
+
+            search_result_uuids.append(result_uuids)
+
+        search_result_uuids = [
+            [edge.uuid for edge in result] for result in search_results
+        ]
+
+        reranked_uuids = rrf(search_result_uuids)
+
+        reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids]
+        edges.extend(reranked_edges)
+
+    context = {
+        "episodes": episodes,
+        "nodes": nodes,
+        "edges": edges,
+    }
+
+    end = time()
+
+    logger.info(
+        f"search returned context for query {query} in {(end - start) * 1000} ms"
+    )
+
+    return context
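For callers that want different retrieval behavior than the opinionated Graphiti.search wrapper, hybrid_search can also be driven directly. A sketch, assuming driver is a neo4j AsyncDriver and embedder an OpenAI embeddings client:

    from datetime import datetime

    from core.search.search import SearchConfig, hybrid_search

    async def similarity_only(driver, embedder, query: str):
        # Setting text_search to anything but "BM25" skips the fulltext pass,
        # so only the cosine-similarity pass runs and no reranker is needed.
        config = SearchConfig(num_episodes=0, text_search="none")
        context = await hybrid_search(driver, embedder, query, datetime.now(), config)
        return [edge.fact for edge in context["edges"]]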
core/search/search_utils.py
@@ -1,23 +1,54 @@
 import asyncio
 import logging
+from collections import defaultdict
 from datetime import datetime
 from time import time

 from neo4j import AsyncDriver

 from core.edges import EntityEdge
-from core.nodes import EntityNode
+from core.nodes import EntityNode, EpisodicNode

 logger = logging.getLogger(__name__)

 RELEVANT_SCHEMA_LIMIT = 3


+async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode]):
+    episode_uuids = [episode.uuid for episode in episodes]
+    records, _, _ = await driver.execute_query(
+        """
+        MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids
+        RETURN DISTINCT
+            n.uuid As uuid,
+            n.name AS name,
+            n.created_at AS created_at,
+            n.summary AS summary
+        """,
+        uuids=episode_uuids,
+    )
+
+    nodes: list[EntityNode] = []
+
+    for record in records:
+        nodes.append(
+            EntityNode(
+                uuid=record["uuid"],
+                name=record["name"],
+                labels=["Entity"],
+                created_at=datetime.now(),
+                summary=record["summary"],
+            )
+        )
+
+    return nodes
+
+
 async def bfs(node_ids: list[str], driver: AsyncDriver):
     records, _, _ = await driver.execute_query(
         """
         MATCH (n WHERE n.uuid in $node_ids)-[r]->(m)
-        RETURN
+        RETURN DISTINCT
             n.uuid AS source_node_uuid,
             n.name AS source_name,
             n.summary AS source_summary,
@@ -138,7 +169,7 @@ async def entity_similarity_search(
             EntityNode(
                 uuid=record["uuid"],
                 name=record["name"],
-                labels=[],
+                labels=["Entity"],
                 created_at=datetime.now(),
                 summary=record["summary"],
             )
@@ -155,7 +186,7 @@ async def entity_fulltext_search(
     records, _, _ = await driver.execute_query(
         """
         CALL db.index.fulltext.queryNodes("name_and_summary", $query) YIELD node, score
         RETURN
             node.uuid As uuid,
             node.name AS name,
             node.created_at AS created_at,
@@ -173,7 +204,7 @@ async def entity_fulltext_search(
             EntityNode(
                 uuid=record["uuid"],
                 name=record["name"],
-                labels=[],
+                labels=["Entity"],
                 created_at=datetime.now(),
                 summary=record["summary"],
             )
@@ -193,7 +224,7 @@ async def edge_fulltext_search(
         CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
         YIELD relationship AS r, score
         MATCH (n:Entity)-[r]->(m:Entity)
         RETURN
             r.uuid AS uuid,
             n.uuid AS source_node_uuid,
             m.uuid AS target_node_uuid,
@@ -291,3 +322,18 @@ async def get_relevant_edges(
     )

     return relevant_edges
+
+
+# takes in a list of rankings of uuids
+def rrf(results: list[list[str]], rank_const=1) -> list[str]:
+    scores: dict[str, int] = defaultdict(int)
+    for result in results:
+        for i, uuid in enumerate(result):
+            scores[uuid] += 1 / (i + rank_const)
+
+    scored_uuids = [term for term in scores.items()]
+    scored_uuids.sort(reverse=True, key=lambda term: term[1])
+
+    sorted_uuids = [term[0] for term in scored_uuids]
+
+    return sorted_uuids
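A worked example of the reciprocal-rank fusion above, using a standalone copy of the same scoring so it runs without a database; the uuids are made up:

    from collections import defaultdict

    def rrf(results: list[list[str]], rank_const=1) -> list[str]:
        # Each uuid scores 1 / (rank + rank_const), summed across result lists.
        scores: dict[str, float] = defaultdict(float)
        for result in results:
            for i, uuid in enumerate(result):
                scores[uuid] += 1 / (i + rank_const)
        return [uuid for uuid, _ in sorted(scores.items(), key=lambda t: t[1], reverse=True)]

    bm25 = ["a", "b", "c"]    # fulltext ranking
    cosine = ["b", "a"]       # similarity ranking
    print(rrf([bm25, cosine]))
    # a: 1/1 + 1/2 = 1.5, b: 1/2 + 1/1 = 1.5, c: 1/3; the sort is stable,
    # so the tie keeps insertion order: ['a', 'b', 'c']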
core/utils/bulk_utils.py
@@ -1,5 +1,4 @@
 import asyncio
-from collections import defaultdict
 from datetime import datetime

 from neo4j import AsyncDriver
@@ -21,7 +20,7 @@ from core.utils.maintenance.node_operations import (
     dedupe_node_list,
     dedupe_extracted_nodes,
 )
-from core.utils.search.search_utils import get_relevant_nodes, get_relevant_edges
+from core.search.search_utils import get_relevant_nodes, get_relevant_edges

 CHUNK_SIZE = 10
core/utils/maintenance/graph_data_operations.py
@@ -1,4 +1,6 @@
+import asyncio
 from datetime import datetime, timezone
+from typing import LiteralString

 from core.nodes import EpisodicNode
 from neo4j import AsyncDriver
@@ -9,6 +11,64 @@ EPISODE_WINDOW_LEN = 3
 logger = logging.getLogger(__name__)


+async def build_indices_and_constraints(driver: AsyncDriver):
+    constraints: list[LiteralString] = [
+        """
+        CREATE CONSTRAINT entity_name IF NOT EXISTS
+        FOR (n:Entity) REQUIRE n.name IS UNIQUE
+        """,
+        """
+        CREATE CONSTRAINT edge_facts IF NOT EXISTS
+        FOR ()-[e:RELATES_TO]-() REQUIRE e.fact IS UNIQUE
+        """,
+    ]
+
+    range_indices: list[LiteralString] = [
+        "CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)",
+        "CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)",
+        "CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)",
+        "CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)",
+        "CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)",
+        "CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)",
+        "CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)",
+        "CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)",
+        "CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.name)",
+        "CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.created_at)",
+        "CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.expired_at)",
+        "CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.valid_at)",
+        "CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.invalid_at)",
+    ]
+
+    fulltext_indices: list[LiteralString] = [
+        "CREATE FULLTEXT INDEX name_and_summary IF NOT EXISTS FOR (n:Entity) ON EACH [n.name, n.summary]",
+        "CREATE FULLTEXT INDEX name_and_fact IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact]",
+    ]
+
+    vector_indices: list[LiteralString] = [
+        """
+        CREATE VECTOR INDEX fact_embedding IF NOT EXISTS
+        FOR ()-[r:RELATES_TO]-() ON (r.fact_embedding)
+        OPTIONS {indexConfig: {
+            `vector.dimensions`: 1024,
+            `vector.similarity_function`: 'cosine'
+        }}
+        """,
+        """
+        CREATE VECTOR INDEX name_embedding IF NOT EXISTS
+        FOR (n:Entity) ON (n.name_embedding)
+        OPTIONS {indexConfig: {
+            `vector.dimensions`: 1024,
+            `vector.similarity_function`: 'cosine'
+        }}
+        """,
+    ]
+    index_queries: list[LiteralString] = (
+        constraints + range_indices + fulltext_indices + vector_indices
+    )
+
+    await asyncio.gather(*[driver.execute_query(query) for query in index_queries])
+
+
 async def clear_data(driver: AsyncDriver):
     async with driver.session() as session:
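Since Graphiti now exposes this helper as build_indices_and_constraints (replacing the old build_indices method removed above), a one-time setup sketch looks like this; URI and credentials are placeholders:

    import asyncio
    from core.graphiti import Graphiti

    async def setup():
        client = Graphiti("bolt://localhost:7687", "neo4j", "password", None)
        # Runs all constraint/index DDL concurrently via asyncio.gather.
        await client.build_indices_and_constraints()
        client.close()

    asyncio.run(setup())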
@@ -21,7 +81,7 @@ async def clear_data(driver: AsyncDriver):
 async def retrieve_episodes(
     driver: AsyncDriver,
     reference_time: datetime,
-    last_n: int,
+    last_n: int = EPISODE_WINDOW_LEN,
     sources: list[str] | None = "messages",
 ) -> list[EpisodicNode]:
     """Retrieve the last n episodic nodes from the graph"""
@@ -61,7 +61,7 @@ async def main(use_bulk: bool = True):
                 episode_type="string",
                 reference_time=message.actual_timestamp,
             )
-            for i, message in enumerate(messages[3:7])
+            for i, message in enumerate(messages[3:14])
         ]

         await client.add_episode_bulk(episodes)
@@ -4,6 +4,8 @@ import os

 import pytest

+from core.search.search import SearchConfig
+
 pytestmark = pytest.mark.integration

 import asyncio
@@ -51,16 +53,13 @@ def setup_logging():
     return logger


-def format_context(context):
+def format_context(facts):
     formatted_string = ""
-    for uuid, data in context.items():
-        formatted_string += f"UUID: {uuid}\n"
-        formatted_string += f"  Name: {data['name']}\n"
-        formatted_string += f"  Summary: {data['summary']}\n"
-        formatted_string += "  Facts:\n"
-        for fact in data["facts"]:
-            formatted_string += f"    - {fact}\n"
-        formatted_string += "\n"
+    formatted_string += "FACTS:\n"
+    for fact in facts:
+        formatted_string += f"  - {fact}\n"
+    formatted_string += "\n"
     return formatted_string.strip()
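A quick check of the reshaped helper, which now takes a flat list of fact strings instead of the old per-node context dict; the sample fact is illustrative:

    facts = ["Tania Tetlow is the president of Fordham University"]
    print(format_context(facts))
    # FACTS:
    #   - Tania Tetlow is the president of Fordham University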
@@ -68,19 +67,18 @@ def format_context(context):
 async def test_graphiti_init():
     logger = setup_logging()
     graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD, None)
-    await graphiti.build_indices()

-    context = await graphiti.search("Freakenomics guest")
+    facts = await graphiti.search("Freakenomics guest")

-    logger.info("QUERY: Freakenomics guest" + "RESULT:" + format_context(context))
+    logger.info("\nQUERY: Freakenomics guest\n" + format_context(facts))

-    context = await graphiti.search("tania tetlow")
+    facts = await graphiti.search("tania tetlow\n")

-    logger.info("QUERY: Tania Tetlow" + "RESULT:" + format_context(context))
+    logger.info("\nQUERY: Tania Tetlow\n" + format_context(facts))

-    context = await graphiti.search("issues with higher ed")
+    facts = await graphiti.search("issues with higher ed")

-    logger.info("QUERY: issues with higher ed" + "RESULT:" + format_context(context))
+    logger.info("\nQUERY: issues with higher ed\n" + format_context(facts))
     graphiti.close()