Files
graphiti/server/graph_service/routers/retrieve.py
2024-09-06 15:37:19 -04:00

54 lines
1.6 KiB
Python

from fastapi import APIRouter, status
from graph_service.dto import (
GetMemoryRequest,
GetMemoryResponse,
Message,
SearchQuery,
SearchResults,
)
from graph_service.zep_graphiti import ZepGraphitiDep, get_fact_result_from_edge
# Router that the endpoint functions in this module register themselves on
# through the @router.post(...) decorators.
router = APIRouter()
@router.post('/search', status_code=status.HTTP_200_OK)
async def search(query: SearchQuery, graphiti: ZepGraphitiDep):
    # Search the graph for edges matching `query.query` within `query.group_id`
    # and return them as facts wrapped in a SearchResults payload.
    #
    # Documented with comments rather than a docstring on purpose: FastAPI
    # publishes a handler's docstring as the route description, so adding one
    # would change the generated API schema.

    # The user node is only looked up for 'user_centered_facts' searches; every
    # other search type runs without a center node.
    is_user_centered = query.search_type == 'user_centered_facts'
    user_node = await graphiti.get_user_node(query.group_id) if is_user_centered else None

    # A missing (falsy) user node falls back to an un-centered search.
    center_node_uuid: str | None = user_node.uuid if user_node else None

    matching_edges = await graphiti.search(
        group_ids=[query.group_id],
        query=query.query,
        num_results=query.max_facts,
        center_node_uuid=center_node_uuid,
    )

    return SearchResults(
        facts=[get_fact_result_from_edge(edge) for edge in matching_edges],
    )
@router.post('/get-memory', status_code=status.HTTP_200_OK)
async def get_memory(
    request: GetMemoryRequest,
    graphiti: ZepGraphitiDep,
):
    # Turn the request's messages into one query string, search the graph for
    # that string within `request.group_id`, and return the matching edges as
    # facts in a GetMemoryResponse.
    #
    # Documented with comments rather than a docstring on purpose: FastAPI
    # publishes a handler's docstring as the route description, so adding one
    # would change the generated API schema.
    query_text = compose_query_from_messages(request.messages)

    matching_edges = await graphiti.search(
        group_ids=[request.group_id],
        query=query_text,
        num_results=request.max_facts,
    )

    return GetMemoryResponse(
        facts=[get_fact_result_from_edge(edge) for edge in matching_edges],
    )
def compose_query_from_messages(messages: list[Message]) -> str:
    """Flatten a list of messages into one newline-delimited query string.

    Each message contributes one line of the form
    ``<role_type>(<role>): <content>`` followed by a newline. A falsy
    ``role_type`` or ``role`` (for example ``None``) is rendered as an
    empty string.

    Args:
        messages: Messages to serialise, in the order they should appear.

    Returns:
        The concatenated lines, or an empty string when ``messages`` is empty.
    """
    # str.join assembles the result in a single pass instead of growing a
    # string with repeated `+=` inside a loop.
    return ''.join(
        f"{message.role_type or ''}({message.role or ''}): {message.content}\n"
        for message in messages
    )