mirror of
https://github.com/getzep/graphiti.git
synced 2024-09-08 19:13:11 +03:00
54 lines
1.6 KiB
Python
54 lines
1.6 KiB
Python
from fastapi import APIRouter, status
|
|
|
|
from graph_service.dto import (
|
|
GetMemoryRequest,
|
|
GetMemoryResponse,
|
|
Message,
|
|
SearchQuery,
|
|
SearchResults,
|
|
)
|
|
from graph_service.zep_graphiti import ZepGraphitiDep, get_fact_result_from_edge
|
|
|
|
# Router on which the POST endpoints in this module are registered via @router.post.
router = APIRouter()
|
|
|
|
|
|
@router.post('/search', status_code=status.HTTP_200_OK)
async def search(query: SearchQuery, graphiti: ZepGraphitiDep):
    """Run a search scoped to the request's group and return the hits as facts.

    For the ``'user_centered_facts'`` search type the group's user node is
    looked up first; when one is found, its uuid is forwarded to the search
    as ``center_node_uuid``. In every other case ``None`` is forwarded.
    Each returned edge is converted with ``get_fact_result_from_edge``.
    """
    anchor_uuid: str | None = None
    if query.search_type == 'user_centered_facts':
        user_node = await graphiti.get_user_node(query.group_id)
        # A missing user node simply leaves the search un-centered.
        anchor_uuid = user_node.uuid if user_node else None

    edges = await graphiti.search(
        group_ids=[query.group_id],
        query=query.query,
        num_results=query.max_facts,
        center_node_uuid=anchor_uuid,
    )
    return SearchResults(facts=list(map(get_fact_result_from_edge, edges)))
|
|
|
|
|
|
@router.post('/get-memory', status_code=status.HTTP_200_OK)
async def get_memory(
    request: GetMemoryRequest,
    graphiti: ZepGraphitiDep,
):
    """Return the facts found by searching with the request's messages.

    The messages are flattened into a single query string by
    ``compose_query_from_messages``. That string is searched within the
    request's group, with ``request.max_facts`` passed as the result limit,
    and each returned edge is converted with ``get_fact_result_from_edge``.
    """
    search_text = compose_query_from_messages(request.messages)
    edges = await graphiti.search(
        group_ids=[request.group_id],
        query=search_text,
        num_results=request.max_facts,
    )
    return GetMemoryResponse(facts=list(map(get_fact_result_from_edge, edges)))
|
|
|
|
|
|
def compose_query_from_messages(messages: list[Message]) -> str:
    """Flatten a list of messages into one newline-delimited query string.

    Each message contributes one line of the form
    ``<role_type>(<role>): <content>`` followed by a newline. A falsy
    ``role_type`` or ``role`` (for example ``None``) is rendered as an empty
    string. An empty ``messages`` list yields an empty string.

    Args:
        messages: Objects exposing ``role_type``, ``role`` and ``content``.

    Returns:
        The concatenated lines, in the order the messages were given.
    """
    # Join once instead of growing the string with += on every iteration.
    return ''.join(
        f"{message.role_type or ''}({message.role or ''}): {message.content}\n"
        for message in messages
    )
|