mirror of
https://github.com/getzep/graphiti.git
synced 2024-09-08 19:13:11 +03:00
add bulk temporal extraction and improve bulk quality and performance (#67)
* parallelize edge deduping more * parallelize node insertion more * improve bulk behavior performance * dedupe nodes actually works * add a reranker to search * bulk dedupe episodes only across the same nodes * add temporal extraction bulk function * cleaned up bulk * default to 4o * format * mypy * mympy * mypy ignore
This commit is contained in:
committed by
GitHub
parent
aac06d9d24
commit
35a4e5172b
@@ -29,6 +29,7 @@ from graphiti_core.llm_client.utils import generate_embedding
|
||||
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
|
||||
from graphiti_core.search.search import Reranker, SearchConfig, SearchMethod, hybrid_search
|
||||
from graphiti_core.search.search_utils import (
|
||||
RELEVANT_SCHEMA_LIMIT,
|
||||
get_relevant_edges,
|
||||
get_relevant_nodes,
|
||||
hybrid_node_search,
|
||||
@@ -41,6 +42,7 @@ from graphiti_core.utils.bulk_utils import (
|
||||
RawEpisode,
|
||||
dedupe_edges_bulk,
|
||||
dedupe_nodes_bulk,
|
||||
extract_edge_dates_bulk,
|
||||
extract_nodes_and_edges_bulk,
|
||||
resolve_edge_pointers,
|
||||
retrieve_previous_episodes_bulk,
|
||||
@@ -319,26 +321,24 @@ class Graphiti:
|
||||
valid_at, invalid_at, _ = await extract_edge_dates(
|
||||
self.llm_client,
|
||||
edge,
|
||||
episode.valid_at,
|
||||
episode,
|
||||
previous_episodes,
|
||||
)
|
||||
edge.valid_at = valid_at
|
||||
edge.invalid_at = invalid_at
|
||||
if edge.invalid_at:
|
||||
edge.expired_at = datetime.now()
|
||||
edge.expired_at = now
|
||||
for edge in existing_edges:
|
||||
valid_at, invalid_at, _ = await extract_edge_dates(
|
||||
self.llm_client,
|
||||
edge,
|
||||
episode.valid_at,
|
||||
episode,
|
||||
previous_episodes,
|
||||
)
|
||||
edge.valid_at = valid_at
|
||||
edge.invalid_at = invalid_at
|
||||
if edge.invalid_at:
|
||||
edge.expired_at = datetime.now()
|
||||
edge.expired_at = now
|
||||
(
|
||||
old_edges_with_nodes_pending_invalidation,
|
||||
new_edges_with_nodes,
|
||||
@@ -481,15 +481,18 @@ class Graphiti:
|
||||
*[edge.generate_embedding(embedder) for edge in extracted_edges],
|
||||
)
|
||||
|
||||
# Dedupe extracted nodes
|
||||
nodes, uuid_map = await dedupe_nodes_bulk(self.driver, self.llm_client, extracted_nodes)
|
||||
# Dedupe extracted nodes, compress extracted edges
|
||||
(nodes, uuid_map), extracted_edges_timestamped = await asyncio.gather(
|
||||
dedupe_nodes_bulk(self.driver, self.llm_client, extracted_nodes),
|
||||
extract_edge_dates_bulk(self.llm_client, extracted_edges, episode_pairs),
|
||||
)
|
||||
|
||||
# save nodes to KG
|
||||
await asyncio.gather(*[node.save(self.driver) for node in nodes])
|
||||
|
||||
# re-map edge pointers so that they don't point to discard dupe nodes
|
||||
extracted_edges_with_resolved_pointers: list[EntityEdge] = resolve_edge_pointers(
|
||||
extracted_edges, uuid_map
|
||||
extracted_edges_timestamped, uuid_map
|
||||
)
|
||||
episodic_edges_with_resolved_pointers: list[EpisodicEdge] = resolve_edge_pointers(
|
||||
episodic_edges, uuid_map
|
||||
@@ -579,7 +582,9 @@ class Graphiti:
|
||||
self.driver, self.llm_client.get_embedder(), query, timestamp, config, center_node_uuid
|
||||
)
|
||||
|
||||
async def get_nodes_by_query(self, query: str, limit: int | None = None) -> list[EntityNode]:
|
||||
async def get_nodes_by_query(
|
||||
self, query: str, limit: int = RELEVANT_SCHEMA_LIMIT
|
||||
) -> list[EntityNode]:
|
||||
"""
|
||||
Retrieve nodes from the graph database based on a text query.
|
||||
|
||||
|
||||
@@ -53,7 +53,9 @@ def v1(context: dict[str, Any]) -> list[Message]:
|
||||
1. start with the list of nodes from New Nodes
|
||||
2. If any node in New Nodes is a duplicate of a node in Existing Nodes, replace the new node with the existing
|
||||
node in the list
|
||||
3. Respond with the resulting list of nodes
|
||||
3. when deduplicating nodes, synthesize their summaries into a short new summary that contains the relevant information
|
||||
of the summaries of the new and existing nodes
|
||||
4. Respond with the resulting list of nodes
|
||||
|
||||
Guidelines:
|
||||
1. Use both the name and summary of nodes to determine if they are duplicates,
|
||||
@@ -64,6 +66,7 @@ def v1(context: dict[str, Any]) -> list[Message]:
|
||||
"new_nodes": [
|
||||
{{
|
||||
"name": "Unique identifier for the node",
|
||||
"summary": "Brief summary of the node's role or significance"
|
||||
}}
|
||||
]
|
||||
}}
|
||||
@@ -92,6 +95,8 @@ def v2(context: dict[str, Any]) -> list[Message]:
|
||||
If a node in the new nodes is describing the same entity as a node in the existing nodes, mark it as a duplicate!!!
|
||||
Task:
|
||||
If any node in New Nodes is a duplicate of a node in Existing Nodes, add their names to the output list
|
||||
When finding duplicates nodes, synthesize their summaries into a short new summary that contains the
|
||||
relevant information of the summaries of the new and existing nodes.
|
||||
|
||||
Guidelines:
|
||||
1. Use both the name and summary of nodes to determine if they are duplicates,
|
||||
@@ -104,7 +109,8 @@ def v2(context: dict[str, Any]) -> list[Message]:
|
||||
"duplicates": [
|
||||
{{
|
||||
"name": "name of the new node",
|
||||
"duplicate_of": "name of the existing node"
|
||||
"duplicate_of": "name of the existing node",
|
||||
"summary": "Brief summary of the node's role or significance. Takes information from the new and existing nodes"
|
||||
}}
|
||||
]
|
||||
}}
|
||||
@@ -130,6 +136,7 @@ def node_list(context: dict[str, Any]) -> list[Message]:
|
||||
Task:
|
||||
1. Group nodes together such that all duplicate nodes are in the same list of names
|
||||
2. All duplicate names should be grouped together in the same list
|
||||
3. Also return a new summary that synthesizes the summary into a new short summary
|
||||
|
||||
Guidelines:
|
||||
1. Each name from the list of nodes should appear EXACTLY once in your response
|
||||
@@ -140,6 +147,7 @@ def node_list(context: dict[str, Any]) -> list[Message]:
|
||||
"nodes": [
|
||||
{{
|
||||
"names": ["myNode", "node that is a duplicate of myNode"],
|
||||
"summary": "Brief summary of the node summaries that appear in the list of names."
|
||||
}}
|
||||
]
|
||||
}}
|
||||
|
||||
@@ -110,10 +110,11 @@ def v2(context: dict[str, Any]) -> list[Message]:
|
||||
|
||||
Guidelines:
|
||||
1. Create edges only between the provided nodes.
|
||||
2. Each edge should represent a clear relationship between two nodes.
|
||||
2. Each edge should represent a clear relationship between two DISTINCT nodes.
|
||||
3. The relation_type should be a concise, all-caps description of the relationship (e.g., LOVES, IS_FRIENDS_WITH, WORKS_FOR).
|
||||
4. Provide a more detailed fact describing the relationship.
|
||||
5. Consider temporal aspects of relationships when relevant.
|
||||
6. Avoid using the same node as the source and target of a relationship
|
||||
|
||||
Respond with a JSON object in the following format:
|
||||
{{
|
||||
|
||||
@@ -63,12 +63,12 @@ class SearchResults(BaseModel):
|
||||
|
||||
|
||||
async def hybrid_search(
|
||||
driver: AsyncDriver,
|
||||
embedder,
|
||||
query: str,
|
||||
timestamp: datetime,
|
||||
config: SearchConfig,
|
||||
center_node_uuid: str | None = None,
|
||||
driver: AsyncDriver,
|
||||
embedder,
|
||||
query: str,
|
||||
timestamp: datetime,
|
||||
config: SearchConfig,
|
||||
center_node_uuid: str | None = None,
|
||||
) -> SearchResults:
|
||||
start = time()
|
||||
|
||||
|
||||
@@ -268,13 +268,13 @@ async def hybrid_node_search(
|
||||
queries: list[str],
|
||||
embeddings: list[list[float]],
|
||||
driver: AsyncDriver,
|
||||
limit: int | None = None,
|
||||
limit: int = RELEVANT_SCHEMA_LIMIT,
|
||||
) -> list[EntityNode]:
|
||||
"""
|
||||
Perform a hybrid search for nodes using both text queries and embeddings.
|
||||
|
||||
This method combines fulltext search and vector similarity search to find
|
||||
relevant nodes in the graph database.
|
||||
relevant nodes in the graph database. It uses an rrf reranker.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -307,27 +307,25 @@ async def hybrid_node_search(
|
||||
"""
|
||||
|
||||
start = time()
|
||||
relevant_nodes: list[EntityNode] = []
|
||||
relevant_node_uuids = set()
|
||||
|
||||
results = await asyncio.gather(
|
||||
*[entity_fulltext_search(q, driver, 2 * (limit or RELEVANT_SCHEMA_LIMIT)) for q in queries],
|
||||
*[
|
||||
entity_similarity_search(e, driver, 2 * (limit or RELEVANT_SCHEMA_LIMIT))
|
||||
for e in embeddings
|
||||
],
|
||||
results: list[list[EntityNode]] = list(
|
||||
await asyncio.gather(
|
||||
*[entity_fulltext_search(q, driver, 2 * limit) for q in queries],
|
||||
*[entity_similarity_search(e, driver, 2 * limit) for e in embeddings],
|
||||
)
|
||||
)
|
||||
|
||||
for result in results:
|
||||
for node in result:
|
||||
if node.uuid in relevant_node_uuids:
|
||||
continue
|
||||
node_uuid_map: dict[str, EntityNode] = {
|
||||
node.uuid: node for result in results for node in result
|
||||
}
|
||||
result_uuids = [[node.uuid for node in result] for result in results]
|
||||
|
||||
relevant_node_uuids.add(node.uuid)
|
||||
relevant_nodes.append(node)
|
||||
ranked_uuids = rrf(result_uuids)
|
||||
|
||||
relevant_nodes: list[EntityNode] = [node_uuid_map[uuid] for uuid in ranked_uuids]
|
||||
|
||||
end = time()
|
||||
logger.info(f'Found relevant nodes: {relevant_node_uuids} in {(end - start) * 1000} ms')
|
||||
logger.info(f'Found relevant nodes: {ranked_uuids} in {(end - start) * 1000} ms')
|
||||
return relevant_nodes
|
||||
|
||||
|
||||
|
||||
@@ -15,11 +15,14 @@ limitations under the License.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import typing
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from math import ceil
|
||||
|
||||
from neo4j import AsyncDriver
|
||||
from numpy import dot
|
||||
from numpy import dot, sqrt
|
||||
from pydantic import BaseModel
|
||||
|
||||
from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge
|
||||
@@ -39,8 +42,11 @@ from graphiti_core.utils.maintenance.node_operations import (
|
||||
dedupe_node_list,
|
||||
extract_nodes,
|
||||
)
|
||||
from graphiti_core.utils.maintenance.temporal_operations import extract_edge_dates
|
||||
|
||||
CHUNK_SIZE = 15
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CHUNK_SIZE = 10
|
||||
|
||||
|
||||
class RawEpisode(BaseModel):
|
||||
@@ -52,7 +58,7 @@ class RawEpisode(BaseModel):
|
||||
|
||||
|
||||
async def retrieve_previous_episodes_bulk(
|
||||
driver: AsyncDriver, episodes: list[EpisodicNode]
|
||||
driver: AsyncDriver, episodes: list[EpisodicNode]
|
||||
) -> list[tuple[EpisodicNode, list[EpisodicNode]]]:
|
||||
previous_episodes_list = await asyncio.gather(
|
||||
*[
|
||||
@@ -68,7 +74,7 @@ async def retrieve_previous_episodes_bulk(
|
||||
|
||||
|
||||
async def extract_nodes_and_edges_bulk(
|
||||
llm_client: LLMClient, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]]
|
||||
llm_client: LLMClient, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]]
|
||||
) -> tuple[list[EntityNode], list[EntityEdge], list[EpisodicEdge]]:
|
||||
extracted_nodes_bulk = await asyncio.gather(
|
||||
*[
|
||||
@@ -105,36 +111,67 @@ async def extract_nodes_and_edges_bulk(
|
||||
|
||||
|
||||
async def dedupe_nodes_bulk(
|
||||
driver: AsyncDriver,
|
||||
llm_client: LLMClient,
|
||||
extracted_nodes: list[EntityNode],
|
||||
driver: AsyncDriver,
|
||||
llm_client: LLMClient,
|
||||
extracted_nodes: list[EntityNode],
|
||||
) -> tuple[list[EntityNode], dict[str, str]]:
|
||||
# Compress nodes
|
||||
nodes, uuid_map = node_name_match(extracted_nodes)
|
||||
|
||||
compressed_nodes, compressed_map = await compress_nodes(llm_client, nodes, uuid_map)
|
||||
|
||||
existing_nodes = await get_relevant_nodes(compressed_nodes, driver)
|
||||
node_chunks = [nodes[i: i + CHUNK_SIZE] for i in range(0, len(nodes), CHUNK_SIZE)]
|
||||
|
||||
nodes, partial_uuid_map, _ = await dedupe_extracted_nodes(
|
||||
llm_client, compressed_nodes, existing_nodes
|
||||
existing_nodes_chunks: list[list[EntityNode]] = list(
|
||||
await asyncio.gather(
|
||||
*[get_relevant_nodes(node_chunk, driver) for node_chunk in node_chunks]
|
||||
)
|
||||
)
|
||||
|
||||
compressed_map.update(partial_uuid_map)
|
||||
results: list[tuple[list[EntityNode], dict[str, str], list[EntityNode]]] = list(
|
||||
await asyncio.gather(
|
||||
*[
|
||||
dedupe_extracted_nodes(llm_client, node_chunk, existing_nodes_chunks[i])
|
||||
for i, node_chunk in enumerate(node_chunks)
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
return nodes, compressed_map
|
||||
final_nodes: list[EntityNode] = []
|
||||
for result in results:
|
||||
final_nodes.extend(result[0])
|
||||
partial_uuid_map = result[1]
|
||||
compressed_map.update(partial_uuid_map)
|
||||
|
||||
return final_nodes, compressed_map
|
||||
|
||||
|
||||
async def dedupe_edges_bulk(
|
||||
driver: AsyncDriver, llm_client: LLMClient, extracted_edges: list[EntityEdge]
|
||||
driver: AsyncDriver, llm_client: LLMClient, extracted_edges: list[EntityEdge]
|
||||
) -> list[EntityEdge]:
|
||||
# Compress edges
|
||||
# First compress edges
|
||||
compressed_edges = await compress_edges(llm_client, extracted_edges)
|
||||
|
||||
existing_edges = await get_relevant_edges(compressed_edges, driver)
|
||||
edge_chunks = [
|
||||
compressed_edges[i: i + CHUNK_SIZE] for i in range(0, len(compressed_edges), CHUNK_SIZE)
|
||||
]
|
||||
|
||||
edges = await dedupe_extracted_edges(llm_client, compressed_edges, existing_edges)
|
||||
relevant_edges_chunks: list[list[EntityEdge]] = list(
|
||||
await asyncio.gather(
|
||||
*[get_relevant_edges(edge_chunk, driver) for edge_chunk in edge_chunks]
|
||||
)
|
||||
)
|
||||
|
||||
resolved_edge_chunks: list[list[EntityEdge]] = list(
|
||||
await asyncio.gather(
|
||||
*[
|
||||
dedupe_extracted_edges(llm_client, edge_chunk, relevant_edges_chunks[i])
|
||||
for i, edge_chunk in enumerate(edge_chunks)
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
edges = [edge for edge_chunk in resolved_edge_chunks for edge in edge_chunk]
|
||||
return edges
|
||||
|
||||
|
||||
@@ -152,15 +189,60 @@ def node_name_match(nodes: list[EntityNode]) -> tuple[list[EntityNode], dict[str
|
||||
|
||||
|
||||
async def compress_nodes(
|
||||
llm_client: LLMClient, nodes: list[EntityNode], uuid_map: dict[str, str]
|
||||
llm_client: LLMClient, nodes: list[EntityNode], uuid_map: dict[str, str]
|
||||
) -> tuple[list[EntityNode], dict[str, str]]:
|
||||
# We want to first compress the nodes by deduplicating nodes across each of the episodes added in bulk
|
||||
if len(nodes) == 0:
|
||||
return nodes, uuid_map
|
||||
|
||||
anchor = nodes[0]
|
||||
nodes.sort(key=lambda node: dot(anchor.name_embedding or [], node.name_embedding or []))
|
||||
# Our approach involves us deduplicating chunks of nodes in parallel.
|
||||
# We want n chunks of size n so that n ** 2 == len(nodes).
|
||||
# We want chunk sizes to be at least 10 for optimizing LLM processing time
|
||||
chunk_size = max(int(sqrt(len(nodes))), CHUNK_SIZE)
|
||||
|
||||
node_chunks = [nodes[i : i + CHUNK_SIZE] for i in range(0, len(nodes), CHUNK_SIZE)]
|
||||
# First calculate similarity scores between nodes
|
||||
similarity_scores: list[tuple[int, int, float]] = [
|
||||
(i, j, dot(n.name_embedding or [], m.name_embedding or []))
|
||||
for i, n in enumerate(nodes)
|
||||
for j, m in enumerate(nodes[:i])
|
||||
]
|
||||
|
||||
# We now sort by semantic similarity
|
||||
similarity_scores.sort(key=lambda score_tuple: score_tuple[2])
|
||||
|
||||
# initialize our chunks based on chunk size
|
||||
node_chunks: list[list[EntityNode]] = [[] for _ in range(ceil(len(nodes) / chunk_size))]
|
||||
|
||||
# Draft the most similar nodes into the same chunk
|
||||
while len(similarity_scores) > 0:
|
||||
i, j, _ = similarity_scores.pop()
|
||||
# determine if any of the nodes have already been drafted into a chunk
|
||||
n = nodes[i]
|
||||
m = nodes[j]
|
||||
# make sure the shortest chunks get preference
|
||||
node_chunks.sort(reverse=True, key=lambda chunk: len(chunk))
|
||||
|
||||
n_chunk = max([i if n in chunk else -1 for i, chunk in enumerate(node_chunks)])
|
||||
m_chunk = max([i if m in chunk else -1 for i, chunk in enumerate(node_chunks)])
|
||||
|
||||
# both nodes already in a chunk
|
||||
if n_chunk > -1 and m_chunk > -1:
|
||||
continue
|
||||
|
||||
# n has a chunk and that chunk is not full
|
||||
elif n_chunk > -1 and len(node_chunks[n_chunk]) < chunk_size:
|
||||
# put m in the same chunk as n
|
||||
node_chunks[n_chunk].append(m)
|
||||
|
||||
# m has a chunk and that chunk is not full
|
||||
elif m_chunk > -1 and len(node_chunks[m_chunk]) < chunk_size:
|
||||
# put n in the same chunk as m
|
||||
node_chunks[m_chunk].append(n)
|
||||
|
||||
# neither node has a chunk or the chunk is full
|
||||
else:
|
||||
# add both nodes to the shortest chunk
|
||||
node_chunks[-1].extend([n, m])
|
||||
|
||||
results = await asyncio.gather(*[dedupe_node_list(llm_client, chunk) for chunk in node_chunks])
|
||||
|
||||
@@ -181,13 +263,21 @@ async def compress_nodes(
|
||||
async def compress_edges(llm_client: LLMClient, edges: list[EntityEdge]) -> list[EntityEdge]:
|
||||
if len(edges) == 0:
|
||||
return edges
|
||||
# We only want to dedupe edges that are between the same pair of nodes
|
||||
# We build a map of the edges based on their source and target nodes.
|
||||
edge_chunk_map: dict[str, list[EntityEdge]] = defaultdict(list)
|
||||
for edge in edges:
|
||||
# We drop loop edges
|
||||
if edge.source_node_uuid == edge.target_node_uuid:
|
||||
continue
|
||||
|
||||
anchor = edges[0]
|
||||
edges.sort(
|
||||
key=lambda embedding: dot(anchor.fact_embedding or [], embedding.fact_embedding or [])
|
||||
)
|
||||
# Keep the order of the two nodes consistent, we want to be direction agnostic during edge resolution
|
||||
pointers = [edge.source_node_uuid, edge.target_node_uuid]
|
||||
pointers.sort()
|
||||
|
||||
edge_chunks = [edges[i : i + CHUNK_SIZE] for i in range(0, len(edges), CHUNK_SIZE)]
|
||||
edge_chunk_map[pointers[0] + pointers[1]].append(edge)
|
||||
|
||||
edge_chunks = [chunk for chunk in edge_chunk_map.values()]
|
||||
|
||||
results = await asyncio.gather(*[dedupe_edge_list(llm_client, chunk) for chunk in edge_chunks])
|
||||
|
||||
@@ -225,3 +315,43 @@ def resolve_edge_pointers(edges: list[E], uuid_map: dict[str, str]):
|
||||
edge.target_node_uuid = uuid_map.get(target_uuid, target_uuid)
|
||||
|
||||
return edges
|
||||
|
||||
|
||||
async def extract_edge_dates_bulk(
|
||||
llm_client: LLMClient,
|
||||
extracted_edges: list[EntityEdge],
|
||||
episode_pairs: list[tuple[EpisodicNode, list[EpisodicNode]]],
|
||||
) -> list[EntityEdge]:
|
||||
edges: list[EntityEdge] = []
|
||||
# confirm that all of our edges have at least one episode
|
||||
for edge in extracted_edges:
|
||||
if edge.episodes is not None and len(edge.episodes) > 0:
|
||||
edges.append(edge)
|
||||
|
||||
episode_uuid_map: dict[str, tuple[EpisodicNode, list[EpisodicNode]]] = {
|
||||
episode.uuid: (episode, previous_episodes) for episode, previous_episodes in episode_pairs
|
||||
}
|
||||
|
||||
results = await asyncio.gather(
|
||||
*[
|
||||
extract_edge_dates(
|
||||
llm_client,
|
||||
edge,
|
||||
episode_uuid_map[edge.episodes[0]][0], # type: ignore
|
||||
episode_uuid_map[edge.episodes[0]][1], # type: ignore
|
||||
)
|
||||
for edge in edges
|
||||
]
|
||||
)
|
||||
|
||||
for i, result in enumerate(results):
|
||||
valid_at = result[0]
|
||||
invalid_at = result[1]
|
||||
edge = edges[i]
|
||||
|
||||
edge.valid_at = valid_at
|
||||
edge.invalid_at = invalid_at
|
||||
if edge.invalid_at:
|
||||
edge.expired_at = datetime.now()
|
||||
|
||||
return edges
|
||||
|
||||
@@ -189,6 +189,7 @@ async def dedupe_node_list(
|
||||
uuid_map: dict[str, str] = {}
|
||||
for node_data in nodes_data:
|
||||
node = node_map[node_data['names'][0]]
|
||||
node.summary = node_data['summary']
|
||||
unique_nodes.append(node)
|
||||
|
||||
for name in node_data['names'][1:]:
|
||||
|
||||
@@ -147,7 +147,6 @@ def process_edge_invalidation_llm_response(
|
||||
async def extract_edge_dates(
|
||||
llm_client: LLMClient,
|
||||
edge: EntityEdge,
|
||||
reference_time: datetime,
|
||||
current_episode: EpisodicNode,
|
||||
previous_episodes: List[EpisodicNode],
|
||||
) -> tuple[datetime | None, datetime | None, str]:
|
||||
@@ -156,7 +155,7 @@ async def extract_edge_dates(
|
||||
'edge_fact': edge.fact,
|
||||
'current_episode': current_episode.content,
|
||||
'previous_episodes': [ep.content for ep in previous_episodes],
|
||||
'reference_timestamp': reference_time.isoformat(),
|
||||
'reference_timestamp': current_episode.valid_at.isoformat(),
|
||||
}
|
||||
llm_response = await llm_client.generate_response(prompt_library.extract_edge_dates.v1(context))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user