Files
graphiti/core/search/search_utils.py
Daniel Chalef 50da9d0f31 format and linting (#18)
* Makefile and format

* fix podcast stuff

* refactor: update import statement for transcript_parser in podcast_runner.py

* format and linting

* chore: Update import statements and remove unused code in maintenance module
2024-08-22 12:26:13 -07:00

340 lines
9.7 KiB
Python

import asyncio
import logging
from collections import defaultdict
from datetime import datetime
from time import time
from neo4j import AsyncDriver
from core.edges import EntityEdge
from core.nodes import EntityNode, EpisodicNode
# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(__name__)

# Default ``limit`` used by the single-query search helpers in this module.
RELEVANT_SCHEMA_LIMIT: int = 3
async def get_mentioned_nodes(
    driver: AsyncDriver, episodes: list[EpisodicNode]
) -> list[EntityNode]:
    """Return the distinct entity nodes mentioned by any of the given episodes.

    Follows ``MENTIONS`` relationships from every ``Episodic`` node whose uuid
    belongs to one of ``episodes`` to the ``Entity`` nodes they point at.

    Args:
        driver: Neo4j async driver used to run the query.
        episodes: Episodes whose mentioned entities should be collected.

    Returns:
        One ``EntityNode`` per distinct matched row, carrying the node's stored
        ``created_at`` (or the current time when none is stored). Returns an
        empty list without querying when ``episodes`` is empty.
    """
    if not episodes:
        # Nothing can match an empty uuid list; skip the database round trip.
        return []

    episode_uuids = [episode.uuid for episode in episodes]
    records, _, _ = await driver.execute_query(
        """
MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids
RETURN DISTINCT
n.uuid As uuid,
n.name AS name,
n.created_at AS created_at,
n.summary AS summary
""",
        uuids=episode_uuids,
    )

    nodes: list[EntityNode] = []
    for record in records:
        created_at = record["created_at"]
        # Keep the stored creation time instead of stamping every node with the
        # query time. Temporal values from the driver expose to_native(); fall
        # back to "now" only when the property is missing.
        if hasattr(created_at, "to_native"):
            created_at = created_at.to_native()
        elif created_at is None:
            created_at = datetime.now()
        nodes.append(
            EntityNode(
                uuid=record["uuid"],
                name=record["name"],
                labels=["Entity"],
                created_at=created_at,
                summary=record["summary"],
            )
        )
    return nodes
async def bfs(node_ids: list[str], driver: AsyncDriver):
    """Gather one-hop context around the given nodes.

    A single query walks every outgoing relationship ``(n)-[r]->(m)`` of the
    nodes whose ``uuid`` is in ``node_ids``. Despite the name, only one hop is
    traversed.

    The rows are folded into a dict keyed by node uuid. Each value holds the
    node's ``name``, ``summary`` and a ``facts`` list:

    - A source node accumulates ``r.fact`` for each row in which it is the
      source. The value is appended as-is, so it is None if the relationship
      has no ``fact`` property.
    - A node seen only as a target keeps an empty ``facts`` list.
    - ``name``/``summary`` come from the first row in which a uuid appears.

    Args:
        node_ids: Uuids of the nodes to expand.
        driver: Neo4j async driver used to run the query.

    Returns:
        The uuid -> ``{"name", "summary", "facts"}`` context dict.
    """
    records, _, _ = await driver.execute_query(
        """
MATCH (n WHERE n.uuid in $node_ids)-[r]->(m)
RETURN DISTINCT
n.uuid AS source_node_uuid,
n.name AS source_name,
n.summary AS source_summary,
m.uuid AS target_node_uuid,
m.name AS target_name,
m.summary AS target_summary,
r.uuid AS uuid,
r.created_at AS created_at,
r.name AS name,
r.fact AS fact,
r.fact_embedding AS fact_embedding,
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at
""",
        node_ids=node_ids,
    )

    context = {}
    for row in records:
        # Register the source on first sight, then record this row's fact.
        source_entry = context.setdefault(
            row["source_node_uuid"],
            {
                "name": row["source_name"],
                "summary": row["source_summary"],
                "facts": [],
            },
        )
        source_entry["facts"].append(row["fact"])

        # Targets are registered (with no facts) only if not already known.
        context.setdefault(
            row["target_node_uuid"],
            {
                "name": row["target_name"],
                "summary": row["target_summary"],
                "facts": [],
            },
        )

    logger.info(f"bfs search returned context: {context}")
    return context
async def edge_similarity_search(
    search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
) -> list[EntityEdge]:
    """Return the ``RELATES_TO`` edges whose fact embedding is closest to a vector.

    Args:
        search_vector: Query embedding compared against the ``fact_embedding``
            vector index.
        driver: Neo4j async driver used to run the query.
        limit: Maximum number of edges to return; it is also the number of
            nearest neighbours requested from the index.

    Returns:
        Up to ``limit`` edges ordered by descending similarity score. Each edge
        carries the timestamps stored on its relationship.
    """

    def _to_native(value):
        # Driver temporal values expose to_native(); anything else
        # (including None) is passed through unchanged.
        return value.to_native() if hasattr(value, "to_native") else value

    # Vector similarity search over embedded facts. Ask the index for `limit`
    # neighbours rather than a fixed 5, so that larger limits are honoured.
    records, _, _ = await driver.execute_query(
        """
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
YIELD relationship AS r, score
MATCH (n)-[r:RELATES_TO]->(m)
RETURN
r.uuid AS uuid,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
r.created_at AS created_at,
r.name AS name,
r.fact AS fact,
r.fact_embedding AS fact_embedding,
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at
ORDER BY score DESC LIMIT $limit
""",
        search_vector=search_vector,
        limit=limit,
    )

    now = datetime.now()
    edges: list[EntityEdge] = []
    for record in records:
        created_at = _to_native(record["created_at"])
        # Use the stored timestamps. Previously expired_at/valid_at were forced
        # to "now" (marking every edge expired) and the misspelled invalid_At
        # keyword meant invalid_at was never set.
        edges.append(
            EntityEdge(
                uuid=record["uuid"],
                source_node_uuid=record["source_node_uuid"],
                target_node_uuid=record["target_node_uuid"],
                fact=record["fact"],
                name=record["name"],
                episodes=record["episodes"],
                fact_embedding=record["fact_embedding"],
                created_at=created_at if created_at is not None else now,
                expired_at=_to_native(record["expired_at"]),
                valid_at=_to_native(record["valid_at"]),
                invalid_at=_to_native(record["invalid_at"]),
            )
        )
    return edges
async def entity_similarity_search(
    search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
) -> list[EntityNode]:
    """Return the nodes whose name embedding is closest to ``search_vector``.

    Args:
        search_vector: Query embedding compared against the ``name_embedding``
            vector index.
        driver: Neo4j async driver used to run the query.
        limit: Number of nearest neighbours requested from the index.

    Returns:
        Up to ``limit`` nodes ordered by descending similarity score. Each node
        is labelled ``Entity`` and carries its stored ``created_at`` (the
        current time when none is stored).
    """
    # vector similarity search over entity names
    records, _, _ = await driver.execute_query(
        """
CALL db.index.vector.queryNodes("name_embedding", $limit, $search_vector)
YIELD node AS n, score
RETURN
n.uuid As uuid,
n.name AS name,
n.created_at AS created_at,
n.summary AS summary
ORDER BY score DESC
""",
        search_vector=search_vector,
        limit=limit,
    )

    nodes: list[EntityNode] = []
    for record in records:
        created_at = record["created_at"]
        # Keep the stored creation time instead of stamping every node with the
        # query time. Temporal values from the driver expose to_native(); fall
        # back to "now" only when the property is missing.
        if hasattr(created_at, "to_native"):
            created_at = created_at.to_native()
        elif created_at is None:
            created_at = datetime.now()
        nodes.append(
            EntityNode(
                uuid=record["uuid"],
                name=record["name"],
                labels=["Entity"],
                created_at=created_at,
                summary=record["summary"],
            )
        )
    return nodes
async def entity_fulltext_search(
    query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
) -> list[EntityNode]:
    """Return nodes whose name or summary matches ``query`` in the fulltext index.

    A trailing ``~`` is appended to ``query``, so its last term is matched
    fuzzily against the ``name_and_summary`` index.

    Args:
        query: Free text to search for.
        driver: Neo4j async driver used to run the query.
        limit: Maximum number of nodes to return.

    Returns:
        Up to ``limit`` nodes ordered by descending fulltext score. Each node
        is labelled ``Entity`` and carries its stored ``created_at`` (the
        current time when none is stored).
    """
    # BM25 search to get top nodes.
    # NOTE(review): `query` reaches Lucene's query parser unescaped, so text
    # containing syntax characters such as ( ) : " may change the query's
    # meaning or fail to parse. Consider escaping it.
    fuzzy_query = query + "~"
    records, _, _ = await driver.execute_query(
        """
CALL db.index.fulltext.queryNodes("name_and_summary", $query) YIELD node, score
RETURN
node.uuid As uuid,
node.name AS name,
node.created_at AS created_at,
node.summary AS summary
ORDER BY score DESC
LIMIT $limit
""",
        query=fuzzy_query,
        limit=limit,
    )

    nodes: list[EntityNode] = []
    for record in records:
        created_at = record["created_at"]
        # Keep the stored creation time instead of stamping every node with the
        # query time. Temporal values from the driver expose to_native(); fall
        # back to "now" only when the property is missing.
        if hasattr(created_at, "to_native"):
            created_at = created_at.to_native()
        elif created_at is None:
            created_at = datetime.now()
        nodes.append(
            EntityNode(
                uuid=record["uuid"],
                name=record["name"],
                labels=["Entity"],
                created_at=created_at,
                summary=record["summary"],
            )
        )
    return nodes
async def edge_fulltext_search(
    query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
) -> list[EntityEdge]:
    """Return Entity-to-Entity edges whose name or fact matches ``query``.

    A trailing ``~`` is appended to ``query``, so its last term is matched
    fuzzily against the ``name_and_fact`` fulltext index.

    Args:
        query: Free text to search for.
        driver: Neo4j async driver used to run the query.
        limit: Maximum number of edges to return.

    Returns:
        Up to ``limit`` edges ordered by descending fulltext score. Each edge
        carries the timestamps stored on its relationship.
    """

    def _to_native(value):
        # Driver temporal values expose to_native(); anything else
        # (including None) is passed through unchanged.
        return value.to_native() if hasattr(value, "to_native") else value

    # Fulltext search over facts.
    # NOTE(review): `query` reaches Lucene's query parser unescaped, so text
    # containing syntax characters such as ( ) : " may change the query's
    # meaning or fail to parse. Consider escaping it.
    fuzzy_query = query + "~"
    records, _, _ = await driver.execute_query(
        """
CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
YIELD relationship AS r, score
MATCH (n:Entity)-[r]->(m:Entity)
RETURN
r.uuid AS uuid,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
r.created_at AS created_at,
r.name AS name,
r.fact AS fact,
r.fact_embedding AS fact_embedding,
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at
ORDER BY score DESC LIMIT $limit
""",
        query=fuzzy_query,
        limit=limit,
    )

    now = datetime.now()
    edges: list[EntityEdge] = []
    for record in records:
        created_at = _to_native(record["created_at"])
        # Use the stored timestamps. Previously expired_at/valid_at were forced
        # to "now" (marking every edge expired) and the misspelled invalid_At
        # keyword meant invalid_at was never set.
        edges.append(
            EntityEdge(
                uuid=record["uuid"],
                source_node_uuid=record["source_node_uuid"],
                target_node_uuid=record["target_node_uuid"],
                fact=record["fact"],
                name=record["name"],
                episodes=record["episodes"],
                fact_embedding=record["fact_embedding"],
                created_at=created_at if created_at is not None else now,
                expired_at=_to_native(record["expired_at"]),
                valid_at=_to_native(record["valid_at"]),
                invalid_at=_to_native(record["invalid_at"]),
            )
        )
    return edges
async def get_relevant_nodes(
    nodes: list[EntityNode],
    driver: AsyncDriver,
) -> list[EntityNode]:
    """Find stored entity nodes that look related to ``nodes``.

    For every input node, this runs a fulltext search on its ``name`` and, when
    the node has one, a vector search on its ``name_embedding``. All searches
    run concurrently.

    Results are de-duplicated by uuid, keeping the first occurrence. Fulltext
    results (in input order) are visited before vector results.

    Args:
        nodes: Nodes to find related existing nodes for.
        driver: Neo4j async driver passed to each search.

    Returns:
        The de-duplicated related nodes.
    """
    start = time()
    relevant_nodes: list[EntityNode] = []
    relevant_node_uuids: set[str] = set()

    # Only nodes that carry an embedding are vector-searched. None is not a
    # usable query vector, and one failing search would abort the whole gather.
    embeddings = [
        node.name_embedding for node in nodes if node.name_embedding is not None
    ]
    results = await asyncio.gather(
        *[entity_fulltext_search(node.name, driver) for node in nodes],
        *[entity_similarity_search(embedding, driver) for embedding in embeddings],
    )

    for result in results:
        for node in result:
            if node.uuid in relevant_node_uuids:
                continue
            relevant_node_uuids.add(node.uuid)
            relevant_nodes.append(node)

    end = time()
    logger.info(
        "Found relevant nodes: %s in %s ms",
        relevant_node_uuids,
        (end - start) * 1000,
    )
    return relevant_nodes
async def get_relevant_edges(
    edges: list[EntityEdge],
    driver: AsyncDriver,
) -> list[EntityEdge]:
    """Find stored entity edges that look related to ``edges``.

    For every input edge, this runs a vector search on its ``fact_embedding``
    (when it has one) and a fulltext search on its ``fact``. All searches run
    concurrently.

    Results are de-duplicated by uuid, keeping the first occurrence. Vector
    results (in input order) are visited before fulltext results.

    Args:
        edges: Edges to find related existing edges for.
        driver: Neo4j async driver passed to each search.

    Returns:
        The de-duplicated related edges.
    """
    start = time()
    relevant_edges: list[EntityEdge] = []
    relevant_edge_uuids: set[str] = set()

    # Only edges that carry an embedding are vector-searched. None is not a
    # usable query vector, and one failing search would abort the whole gather.
    embeddings = [
        edge.fact_embedding for edge in edges if edge.fact_embedding is not None
    ]
    results = await asyncio.gather(
        *[edge_similarity_search(embedding, driver) for embedding in embeddings],
        *[edge_fulltext_search(edge.fact, driver) for edge in edges],
    )

    for result in results:
        for edge in result:
            if edge.uuid in relevant_edge_uuids:
                continue
            relevant_edge_uuids.add(edge.uuid)
            relevant_edges.append(edge)

    end = time()
    logger.info(
        "Found relevant edges: %s in %s ms",
        relevant_edge_uuids,
        (end - start) * 1000,
    )
    return relevant_edges
def rrf(results: list[list[str]], rank_const=1) -> list[str]:
    """Fuse several rankings of uuids with reciprocal rank fusion.

    Every time a uuid appears in a ranking, it earns
    ``1 / (position + rank_const)``, where ``position`` is 0-based. The
    contributions are summed across rankings.

    Args:
        results: Rankings to fuse; each is a list of uuids, best first.
        rank_const: Constant added to each 0-based position. With the default
            of 1, the top item of a ranking contributes 1.0. It should be
            positive; 0 divides by zero at the first position.

    Returns:
        Every uuid seen, ordered by descending fused score. Ties keep the order
        in which the uuids were first encountered.
    """
    # Scores are fractional, so annotate (and default) them as floats.
    scores: dict[str, float] = defaultdict(float)
    for ranking in results:
        for position, uuid in enumerate(ranking):
            scores[uuid] += 1 / (position + rank_const)

    # sorted() is stable even with reverse=True, so ties stay in first-seen order.
    return sorted(scores, key=scores.__getitem__, reverse=True)