graphiti/core/utils/maintenance/edge_operations.py

import json
import logging
from datetime import datetime
from time import time
from typing import Any, List

from core.edges import EntityEdge, EpisodicEdge
from core.llm_client import LLMClient
from core.nodes import EntityNode, EpisodicNode
from core.prompts import prompt_library

logger = logging.getLogger(__name__)


def build_episodic_edges(
    entity_nodes: List[EntityNode],
    episode: EpisodicNode,
    created_at: datetime,
) -> List[EpisodicEdge]:
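    """Build an EpisodicEdge from the episode to each extracted entity node."""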
    edges: List[EpisodicEdge] = []
    for node in entity_nodes:
        edge = EpisodicEdge(
            source_node_uuid=episode.uuid,
            target_node_uuid=node.uuid,
            created_at=created_at,
        )
        edges.append(edge)

    return edges


async def extract_new_edges(
    llm_client: LLMClient,
    episode: EpisodicNode,
    new_nodes: list[EntityNode],
    relevant_schema: dict[str, Any],
    previous_episodes: list[EpisodicNode],
) -> tuple[list[EntityEdge], list[EntityNode]]:
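    """Extract edges for an episode via the LLM and return them together with the affected nodes."""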
    # Prepare context for LLM
    context = {
        "episode_content": episode.content,
        "episode_timestamp": (
            episode.valid_at.isoformat() if episode.valid_at else None
        ),
        "relevant_schema": json.dumps(relevant_schema, indent=2),
        "new_nodes": [
            {"name": node.name, "summary": node.summary} for node in new_nodes
        ],
        "previous_episodes": [
            {
                "content": ep.content,
                "timestamp": ep.valid_at.isoformat() if ep.valid_at else None,
            }
            for ep in previous_episodes
        ],
    }

    llm_response = await llm_client.generate_response(
        prompt_library.extract_edges.v1(context)
    )
    new_edges_data = llm_response.get("new_edges", [])
    logger.info(f"Extracted new edges: {new_edges_data}")

    # Convert the extracted data into EntityEdge objects
    new_edges = []
    for edge_data in new_edges_data:
        source_node = next(
            (node for node in new_nodes if node.name == edge_data["source_node"]),
            None,
        )
        target_node = next(
            (node for node in new_nodes if node.name == edge_data["target_node"]),
            None,
        )

        # If source or target is not in new_nodes, check if it's an existing node
        if source_node is None and edge_data["source_node"] in relevant_schema["nodes"]:
            existing_node_data = relevant_schema["nodes"][edge_data["source_node"]]
            source_node = EntityNode(
                uuid=existing_node_data["uuid"],
                name=edge_data["source_node"],
                labels=[existing_node_data["label"]],
                summary="",
                created_at=datetime.now(),
            )
        if target_node is None and edge_data["target_node"] in relevant_schema["nodes"]:
            existing_node_data = relevant_schema["nodes"][edge_data["target_node"]]
            target_node = EntityNode(
                uuid=existing_node_data["uuid"],
                name=edge_data["target_node"],
                labels=[existing_node_data["label"]],
                summary="",
                created_at=datetime.now(),
            )

        if (
            source_node
            and target_node
            and not (
                source_node.name.startswith("Message")
                or target_node.name.startswith("Message")
            )
        ):
            valid_at = (
                datetime.fromisoformat(edge_data["valid_at"])
                if edge_data["valid_at"]
                else episode.valid_at or datetime.now()
            )
            invalid_at = (
                datetime.fromisoformat(edge_data["invalid_at"])
                if edge_data["invalid_at"]
                else None
            )
            new_edge = EntityEdge(
                source_node=source_node,
                target_node=target_node,
                name=edge_data["relation_type"],
                fact=edge_data["fact"],
                episodes=[episode.uuid],
                created_at=datetime.now(),
                valid_at=valid_at,
                invalid_at=invalid_at,
            )
            new_edges.append(new_edge)
            logger.info(
                f"Created new edge: {new_edge.name} from {source_node.name} (UUID: {source_node.uuid}) to {target_node.name} (UUID: {target_node.uuid})"
            )

    affected_nodes = set()
    for edge in new_edges:
        affected_nodes.add(edge.source_node)
        affected_nodes.add(edge.target_node)

    return new_edges, list(affected_nodes)


async def extract_edges(
    llm_client: LLMClient,
    episode: EpisodicNode,
    nodes: list[EntityNode],
    previous_episodes: list[EpisodicNode],
) -> list[EntityEdge]:
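    """Extract entity edges between the given nodes for an episode using the v2 extraction prompt."""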
    start = time()

    # Prepare context for LLM
    context = {
        "episode_content": episode.content,
        "episode_timestamp": (
            episode.valid_at.isoformat() if episode.valid_at else None
        ),
        "nodes": [
            {"uuid": node.uuid, "name": node.name, "summary": node.summary}
            for node in nodes
        ],
        "previous_episodes": [
            {
                "content": ep.content,
                "timestamp": ep.valid_at.isoformat() if ep.valid_at else None,
            }
            for ep in previous_episodes
        ],
    }

    llm_response = await llm_client.generate_response(
        prompt_library.extract_edges.v2(context)
    )
    edges_data = llm_response.get("edges", [])

    end = time()
    logger.info(f"Extracted new edges: {edges_data} in {(end - start) * 1000} ms")

    # Convert the extracted data into EntityEdge objects
    edges = []
    for edge_data in edges_data:
        if edge_data["target_node_uuid"] and edge_data["source_node_uuid"]:
            edge = EntityEdge(
                source_node_uuid=edge_data["source_node_uuid"],
                target_node_uuid=edge_data["target_node_uuid"],
                name=edge_data["relation_type"],
                fact=edge_data["fact"],
                episodes=[episode.uuid],
                created_at=datetime.now(),
                valid_at=None,
                invalid_at=None,
            )
            edges.append(edge)
            logger.info(
                f"Created new edge: {edge.name} from (UUID: {edge.source_node_uuid}) to (UUID: {edge.target_node_uuid})"
            )

    return edges


async def dedupe_extracted_edges(
    llm_client: LLMClient,
    extracted_edges: list[EntityEdge],
    existing_edges: list[EntityEdge],
) -> list[EntityEdge]:
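    """Filter extracted edges against existing edges, keeping only those the LLM marks as new."""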
    # Create edge map
    edge_map = {}
    for edge in existing_edges:
        edge_map[edge.fact] = edge
    for edge in extracted_edges:
        if edge.fact in edge_map:
            continue
        edge_map[edge.fact] = edge

    # Prepare context for LLM
    context = {
        "extracted_edges": [
            {"name": edge.name, "fact": edge.fact} for edge in extracted_edges
        ],
        "existing_edges": [
            {"name": edge.name, "fact": edge.fact} for edge in existing_edges
        ],
    }

    llm_response = await llm_client.generate_response(
        prompt_library.dedupe_edges.v1(context)
    )
    new_edges_data = llm_response.get("new_edges", [])
    logger.info(f"Extracted new edges: {new_edges_data}")

    # Get full edge data
    edges = []
    for edge_data in new_edges_data:
        edge = edge_map[edge_data["fact"]]
        edges.append(edge)

    return edges


async def dedupe_edge_list(
    llm_client: LLMClient,
    edges: list[EntityEdge],
) -> list[EntityEdge]:
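    """Collapse duplicate edges within a single list, keeping one edge per unique fact."""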
    start = time()

    # Create edge map
    edge_map = {}
    for edge in edges:
        edge_map[edge.fact] = edge

    # Prepare context for LLM
    context = {"edges": [{"name": edge.name, "fact": edge.fact} for edge in edges]}

    llm_response = await llm_client.generate_response(
        prompt_library.dedupe_edges.edge_list(context)
    )
    unique_edges_data = llm_response.get("unique_edges", [])

    end = time()
    logger.info(
        f"Extracted edge duplicates: {unique_edges_data} in {(end - start) * 1000} ms"
    )

    # Get full edge data
    unique_edges = []
    for edge_data in unique_edges_data:
        fact = edge_data["fact"]
        unique_edges.append(edge_map[fact])

    return unique_edges
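

# Illustrative only: a minimal sketch (not part of the original module) of how
# the helpers above could be chained for a single episode. It assumes the
# caller already has a configured LLMClient plus the episode's resolved nodes,
# recent previous episodes, and the existing edges loaded from the graph; the
# function name is hypothetical.
async def process_episode_edges(
    llm_client: LLMClient,
    episode: EpisodicNode,
    nodes: list[EntityNode],
    previous_episodes: list[EpisodicNode],
    existing_edges: list[EntityEdge],
) -> tuple[list[EntityEdge], list[EpisodicEdge]]:
    # Extract candidate entity edges between the resolved nodes for this episode.
    extracted_edges = await extract_edges(llm_client, episode, nodes, previous_episodes)
    # Drop candidates whose facts duplicate edges already stored in the graph.
    entity_edges = await dedupe_extracted_edges(
        llm_client, extracted_edges, existing_edges
    )
    # Connect the episode itself to every entity node it mentions.
    episodic_edges = build_episodic_edges(nodes, episode, datetime.now())
    return entity_edges, episodic_edges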