Files
graphiti/core/search/search.py
Daniel Chalef 50da9d0f31 format and linting (#18)
* Makefile and format

* fix podcast stuff

* refactor: update import statement for transcript_parser in podcast_runner.py

* format and linting

* chore: Update import statements and remove unused code in maintenance module
2024-08-22 12:26:13 -07:00

104 lines
2.8 KiB
Python

import logging
from datetime import datetime
from time import time
from neo4j import AsyncDriver
from pydantic import BaseModel
from core.edges import Edge
from core.llm_client.config import EMBEDDING_DIM
from core.nodes import Node
from core.search.search_utils import (
edge_fulltext_search,
edge_similarity_search,
get_mentioned_nodes,
rrf,
)
from core.utils import retrieve_episodes
from core.utils.maintenance.graph_data_operations import EPISODE_WINDOW_LEN
logger = logging.getLogger(__name__)
class SearchConfig(BaseModel):
num_results: int = 10
num_episodes: int = EPISODE_WINDOW_LEN
similarity_search: str = "cosine"
text_search: str = "BM25"
reranker: str = "rrf"
async def hybrid_search(
driver: AsyncDriver, embedder, query: str, timestamp: datetime, config: SearchConfig
) -> dict[str, [Node | Edge]]:
start = time()
episodes = []
nodes = []
edges = []
search_results = []
if config.num_episodes > 0:
episodes.extend(await retrieve_episodes(driver, timestamp))
nodes.extend(await get_mentioned_nodes(driver, episodes))
if config.text_search == "BM25":
text_search = await edge_fulltext_search(query, driver)
search_results.append(text_search)
if config.similarity_search == "cosine":
query_text = query.replace("\n", " ")
search_vector = (
(await embedder.create(input=[query_text], model="text-embedding-3-small"))
.data[0]
.embedding[:EMBEDDING_DIM]
)
similarity_search = await edge_similarity_search(search_vector, driver)
search_results.append(similarity_search)
if len(search_results) == 1:
edges = search_results[0]
elif len(search_results) > 1 and not config.reranker == "rrf":
logger.exception("Multiple searches enabled without a reranker")
raise Exception("Multiple searches enabled without a reranker")
elif config.reranker == "rrf":
edge_uuid_map = {}
search_result_uuids = []
logger.info([[edge.fact for edge in result] for result in search_results])
for result in search_results:
result_uuids = []
for edge in result:
result_uuids.append(edge.uuid)
edge_uuid_map[edge.uuid] = edge
search_result_uuids.append(result_uuids)
search_result_uuids = [
[edge.uuid for edge in result] for result in search_results
]
reranked_uuids = rrf(search_result_uuids)
reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids]
edges.extend(reranked_edges)
context = {
"episodes": episodes,
"nodes": nodes,
"edges": edges,
}
end = time()
logger.info(
f"search returned context for query {query} in {(end - start) * 1000} ms"
)
return context