mirror of
https://github.com/getzep/graphiti.git
synced 2024-09-08 19:13:11 +03:00
* Makefile and format * fix podcast stuff * refactor: update import statement for transcript_parser in podcast_runner.py * format and linting * chore: Update import statements and remove unused code in maintenance module
104 lines
2.8 KiB
Python
104 lines
2.8 KiB
Python
import logging
|
|
from datetime import datetime
|
|
from time import time
|
|
|
|
from neo4j import AsyncDriver
|
|
from pydantic import BaseModel
|
|
|
|
from core.edges import Edge
|
|
from core.llm_client.config import EMBEDDING_DIM
|
|
from core.nodes import Node
|
|
from core.search.search_utils import (
|
|
edge_fulltext_search,
|
|
edge_similarity_search,
|
|
get_mentioned_nodes,
|
|
rrf,
|
|
)
|
|
from core.utils import retrieve_episodes
|
|
from core.utils.maintenance.graph_data_operations import EPISODE_WINDOW_LEN
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class SearchConfig(BaseModel):
    """Options that control how ``hybrid_search`` assembles its context.

    The string-valued fields act as switches: ``hybrid_search`` compares them
    against the literals ``"BM25"``, ``"cosine"`` and ``"rrf"`` and skips the
    corresponding step for any other value.
    """

    # Requested number of results.
    num_results: int = 10
    # Episode retrieval (and mentioned-node lookup) runs only when this is > 0.
    num_episodes: int = EPISODE_WINDOW_LEN
    # "cosine" enables the embedding-based edge similarity search.
    similarity_search: str = "cosine"
    # "BM25" enables the full-text edge search.
    text_search: str = "BM25"
    # "rrf" merges multiple result lists with the `rrf` helper.
    reranker: str = "rrf"
|
|
|
|
|
|
async def hybrid_search(
    driver: AsyncDriver, embedder, query: str, timestamp: datetime, config: SearchConfig
) -> dict[str, list[Node | Edge]]:
    """Build a search context of episodes, nodes and edges for ``query``.

    Steps, each gated by ``config``:

    1. If ``config.num_episodes > 0``, fetch episodes via ``retrieve_episodes``
       and the nodes they mention via ``get_mentioned_nodes``.
    2. If ``config.text_search == "BM25"``, run ``edge_fulltext_search``.
    3. If ``config.similarity_search == "cosine"``, embed the query and run
       ``edge_similarity_search``.
    4. A single result list is used directly; several lists are merged with
       ``rrf`` when ``config.reranker == "rrf"``.

    Args:
        driver: Neo4j async driver handed to every graph helper.
        embedder: Client exposing an awaitable ``create(input=..., model=...)``
            whose result provides ``.data[0].embedding``
            (presumably an OpenAI-style embeddings client -- TODO confirm).
        query: Free-text search query.
        timestamp: Reference time passed to ``retrieve_episodes``.
        config: Search switches, see ``SearchConfig``.

    Returns:
        A dict with the keys ``"episodes"``, ``"nodes"`` and ``"edges"``.

    Raises:
        Exception: If more than one search ran but ``config.reranker`` is not
            ``"rrf"``, since there is then no way to merge the result lists.
    """
    # NOTE(review): config.num_results is never consulted here, and
    # config.num_episodes only gates retrieval (it is not forwarded to
    # retrieve_episodes). Confirm whether the limits should be applied.
    start = time()

    episodes = []
    nodes = []
    edges = []

    search_results = []

    if config.num_episodes > 0:
        episodes.extend(await retrieve_episodes(driver, timestamp))
        nodes.extend(await get_mentioned_nodes(driver, episodes))

    if config.text_search == "BM25":
        text_search = await edge_fulltext_search(query, driver)
        search_results.append(text_search)

    if config.similarity_search == "cosine":
        # Newlines are flattened to spaces before the query is embedded.
        query_text = query.replace("\n", " ")
        embedding_response = await embedder.create(
            input=[query_text], model="text-embedding-3-small"
        )
        # Truncate the embedding to the first EMBEDDING_DIM components.
        search_vector = embedding_response.data[0].embedding[:EMBEDDING_DIM]

        similarity_search = await edge_similarity_search(search_vector, driver)
        search_results.append(similarity_search)

    if len(search_results) == 1:
        # Only one search ran: nothing to merge.
        edges = search_results[0]

    elif len(search_results) > 1 and config.reranker != "rrf":
        # logger.error, not logger.exception: there is no active exception
        # here, so .exception() would append a bogus "NoneType: None" trace.
        logger.error("Multiple searches enabled without a reranker")
        raise Exception("Multiple searches enabled without a reranker")

    elif config.reranker == "rrf":
        logger.info([[edge.fact for edge in result] for result in search_results])

        # Map uuid -> edge so the reranked uuids can be turned back into
        # edges; when a uuid occurs in several lists the last one wins.
        edge_uuid_map = {}
        for result in search_results:
            edge_uuid_map.update((edge.uuid, edge) for edge in result)

        # One uuid list per search, in each search's own order.
        search_result_uuids = [
            [edge.uuid for edge in result] for result in search_results
        ]

        reranked_uuids = rrf(search_result_uuids)
        edges.extend(edge_uuid_map[uuid] for uuid in reranked_uuids)

    context = {
        "episodes": episodes,
        "nodes": nodes,
        "edges": edges,
    }

    end = time()

    logger.info(
        "search returned context for query %s in %s ms", query, (end - start) * 1000
    )

    return context
|