Mirror of https://github.com/getzep/graphiti.git, synced 2024-09-08 19:13:11 +03:00
Add episode refactor (#85)
* temp commit while moving
* fix name embedding bug
* invalidation
* format
* tests on runner examples
* format
* ellipsis
* ruff
* fix
* format
* minor prompt change
Committed by GitHub
Parent: 1d31442751
Commit: 299021173b
@@ -94,7 +94,7 @@ async def main():
 async def ingest_products_data(client: Graphiti):
     script_dir = Path(__file__).parent
-    json_file_path = script_dir / 'allbirds_products.json'
+    json_file_path = script_dir / '../data/manybirds_products.json'
 
     with open(json_file_path) as file:
         products = json.load(file)['products']
@@ -110,7 +110,14 @@ async def ingest_products_data(client: Graphiti):
         for i, product in enumerate(products)
     ]
 
-    await client.add_episode_bulk(episodes)
+    for episode in episodes:
+        await client.add_episode(
+            episode.name,
+            episode.content,
+            episode.source_description,
+            episode.reference_time,
+            episode.source,
+        )
 
 
 asyncio.run(main())
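For context, a minimal sketch of the per-episode ingestion path the example now uses. The add_episode argument order (name, content, source_description, reference_time, source) is taken from the hunk above; the import path and connection details are assumptions:

import asyncio

from graphiti_core import Graphiti  # import path assumed


async def ingest(episodes) -> None:
    # hypothetical connection details; point these at your own Neo4j instance
    client = Graphiti('bolt://localhost:7687', 'neo4j', 'password')

    # one add_episode call per episode runs the full extraction,
    # resolution, and invalidation pipeline shown later in this commit
    for episode in episodes:
        await client.add_episode(
            episode.name,
            episode.content,
            episode.source_description,
            episode.reference_time,
            episode.source,
        )


# episodes would be the objects built from the JSON products above
# asyncio.run(ingest(episodes))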
@@ -86,4 +86,4 @@ async def main(use_bulk: bool = True):
     await client.add_episode_bulk(episodes)
 
 
-asyncio.run(main(True))
+asyncio.run(main(False))
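The runner is switched from the bulk path to the per-episode path. The body of main is not shown in this hunk; a plausible reading of the use_bulk flag, offered as an assumption rather than the file's actual code:

async def main(use_bulk: bool = True):
    # ... build client and episodes as in the example above ...
    if use_bulk:
        await client.add_episode_bulk(episodes)  # single batched write
    else:
        for episode in episodes:
            await client.add_episode(  # full per-episode resolution pipeline
                episode.name,
                episode.content,
                episode.source_description,
                episode.reference_time,
                episode.source,
            )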
@@ -59,11 +59,6 @@ from graphiti_core.utils.maintenance.node_operations import (
     extract_nodes,
     resolve_extracted_nodes,
 )
-from graphiti_core.utils.maintenance.temporal_operations import (
-    extract_edge_dates,
-    invalidate_edges,
-    prepare_edges_for_invalidation,
-)
 
 logger = logging.getLogger(__name__)
 
@@ -293,7 +288,7 @@ class Graphiti:
             *[node.generate_name_embedding(embedder) for node in extracted_nodes]
         )
 
-        # Resolve extracted nodes with nodes already in the graph
+        # Resolve extracted nodes with nodes already in the graph and extract facts
         existing_nodes_lists: list[list[EntityNode]] = list(
             await asyncio.gather(
                 *[get_relevant_nodes([node], self.driver) for node in extracted_nodes]
@@ -302,22 +297,27 @@ class Graphiti:
 
         logger.info(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
 
-        mentioned_nodes, _ = await resolve_extracted_nodes(
-            self.llm_client, extracted_nodes, existing_nodes_lists
+        (mentioned_nodes, uuid_map), extracted_edges = await asyncio.gather(
+            resolve_extracted_nodes(self.llm_client, extracted_nodes, existing_nodes_lists),
+            extract_edges(self.llm_client, episode, extracted_nodes, previous_episodes),
         )
         logger.info(f'Adjusted mentioned nodes: {[(n.name, n.uuid) for n in mentioned_nodes]}')
         nodes.extend(mentioned_nodes)
 
-        # Extract facts as edges given entity nodes
-        extracted_edges = await extract_edges(
-            self.llm_client, episode, mentioned_nodes, previous_episodes
+        extracted_edges_with_resolved_pointers = resolve_edge_pointers(
+            extracted_edges, uuid_map
         )
 
         # calculate embeddings
-        await asyncio.gather(*[edge.generate_embedding(embedder) for edge in extracted_edges])
+        await asyncio.gather(
+            *[
+                edge.generate_embedding(embedder)
+                for edge in extracted_edges_with_resolved_pointers
+            ]
+        )
 
-        # Resolve extracted edges with edges already in the graph
-        existing_edges_list: list[list[EntityEdge]] = list(
+        # Resolve extracted edges with related edges already in the graph
+        related_edges_list: list[list[EntityEdge]] = list(
             await asyncio.gather(
                 *[
                     get_relevant_edges(
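resolve_edge_pointers is used above but not defined anywhere in this diff. A minimal sketch of what it plausibly does, assuming it simply remaps edge endpoints through the uuid_map returned by node resolution:

def resolve_edge_pointers(edges: list[EntityEdge], uuid_map: dict[str, str]) -> list[EntityEdge]:
    # point each edge at the canonical node chosen during deduplication,
    # falling back to the original uuid when no mapping exists
    for edge in edges:
        edge.source_node_uuid = uuid_map.get(edge.source_node_uuid, edge.source_node_uuid)
        edge.target_node_uuid = uuid_map.get(edge.target_node_uuid, edge.target_node_uuid)
    return edges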
@@ -327,74 +327,66 @@ class Graphiti:
                         edge.target_node_uuid,
                         RELEVANT_SCHEMA_LIMIT,
                     )
-                    for edge in extracted_edges
+                    for edge in extracted_edges_with_resolved_pointers
                 ]
             )
         )
         logger.info(
-            f'Existing edges lists: {[(e.name, e.uuid) for edges_lst in existing_edges_list for e in edges_lst]}'
+            f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_list for e in edges_lst]}'
         )
-        logger.info(f'Extracted edges: {[(e.name, e.uuid) for e in extracted_edges]}')
-
-        deduped_edges: list[EntityEdge] = await resolve_extracted_edges(
-            self.llm_client, extracted_edges, existing_edges_list
+        logger.info(
+            f'Extracted edges: {[(e.name, e.uuid) for e in extracted_edges_with_resolved_pointers]}'
         )
 
-        # Extract dates for the newly extracted edges
-        edge_dates = await asyncio.gather(
-            *[
-                extract_edge_dates(
-                    self.llm_client,
-                    edge,
-                    episode,
-                    previous_episodes,
-                )
-                for edge in deduped_edges
-            ]
-        )
-
-        for i, edge in enumerate(deduped_edges):
-            valid_at = edge_dates[i][0]
-            invalid_at = edge_dates[i][1]
-
-            edge.valid_at = valid_at
-            edge.invalid_at = invalid_at
-            if edge.invalid_at is not None:
-                edge.expired_at = now
-
-        entity_edges.extend(deduped_edges)
-
-        existing_edges: list[EntityEdge] = [
-            e for edge_lst in existing_edges_list for e in edge_lst
+        existing_source_edges_list: list[list[EntityEdge]] = list(
+            await asyncio.gather(
+                *[
+                    get_relevant_edges(
+                        self.driver,
+                        [edge],
+                        edge.source_node_uuid,
+                        None,
+                        RELEVANT_SCHEMA_LIMIT,
+                    )
+                    for edge in extracted_edges_with_resolved_pointers
+                ]
+            )
+        )
+
+        existing_target_edges_list: list[list[EntityEdge]] = list(
+            await asyncio.gather(
+                *[
+                    get_relevant_edges(
+                        self.driver,
+                        [edge],
+                        None,
+                        edge.target_node_uuid,
+                        RELEVANT_SCHEMA_LIMIT,
+                    )
+                    for edge in extracted_edges_with_resolved_pointers
+                ]
+            )
+        )
+
+        existing_edges_list: list[list[EntityEdge]] = [
+            source_lst + target_lst
+            for source_lst, target_lst in zip(
+                existing_source_edges_list, existing_target_edges_list
+            )
         ]
 
-        (
-            old_edges_with_nodes_pending_invalidation,
-            new_edges_with_nodes,
-        ) = prepare_edges_for_invalidation(
-            existing_edges=existing_edges, new_edges=deduped_edges, nodes=nodes
-        )
-
-        invalidated_edges = await invalidate_edges(
+        resolved_edges, invalidated_edges = await resolve_extracted_edges(
             self.llm_client,
-            old_edges_with_nodes_pending_invalidation,
-            new_edges_with_nodes,
+            extracted_edges_with_resolved_pointers,
+            related_edges_list,
+            existing_edges_list,
+            episode,
+            previous_episodes,
         )
 
-        for edge in invalidated_edges:
-            for existing_edge in existing_edges:
-                if existing_edge.uuid == edge.uuid:
-                    existing_edge.expired_at = edge.expired_at
-            for deduped_edge in deduped_edges:
-                if deduped_edge.uuid == edge.uuid:
-                    deduped_edge.expired_at = edge.expired_at
-        logger.info(f'Invalidated edges: {[(e.name, e.uuid) for e in invalidated_edges]}')
-
-        entity_edges.extend(existing_edges)
-
-        logger.info(f'Deduped edges: {[(e.name, e.uuid) for e in deduped_edges]}')
+        entity_edges.extend(resolved_edges + invalidated_edges)
+
+        logger.info(f'Resolved edges: {[(e.name, e.uuid) for e in resolved_edges]}')
 
         episodic_edges: list[EpisodicEdge] = build_episodic_edges(
             mentioned_nodes,
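The zip above merges the source-side and target-side candidate pools index-by-index, so extracted edge i ends up with one combined list of invalidation candidates. A runnable toy example with string stand-ins for EntityEdge objects:

existing_source_edges_list = [['e1'], ['e2', 'e3']]  # matches found via each edge's source node
existing_target_edges_list = [['e4'], []]            # matches found via each edge's target node

existing_edges_list = [
    source_lst + target_lst
    for source_lst, target_lst in zip(existing_source_edges_list, existing_target_edges_list)
]

assert existing_edges_list == [['e1', 'e4'], ['e2', 'e3']]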
@@ -129,7 +129,7 @@ def v3(context: dict[str, Any]) -> list[Message]:
         Given the following context, determine whether the New Edge represents any of the edges in the list of Existing Edges.
 
         Existing Edges:
-        {json.dumps(context['existing_edges'], indent=2)}
+        {json.dumps(context['related_edges'], indent=2)}
 
         New Edge:
         {json.dumps(context['extracted_edges'], indent=2)}
@@ -21,10 +21,12 @@ from .models import Message, PromptFunction, PromptVersion
 
 class Prompt(Protocol):
     v1: PromptVersion
+    v2: PromptVersion
 
 
 class Versions(TypedDict):
     v1: PromptFunction
+    v2: PromptFunction
 
 
 def v1(context: dict[str, Any]) -> list[Message]:
@@ -71,4 +73,38 @@ def v1(context: dict[str, Any]) -> list[Message]:
     ]
 
 
-versions: Versions = {'v1': v1}
+def v2(context: dict[str, Any]) -> list[Message]:
+    return [
+        Message(
+            role='system',
+            content='You are an AI assistant that helps determine which relationships in a knowledge graph should be invalidated based solely on explicit contradictions in newer information.',
+        ),
+        Message(
+            role='user',
+            content=f"""
+                Based on the provided Existing Edges and a New Edge, determine which existing edges, if any, should be marked as invalidated due to contradictions with the New Edge.
+
+                Existing Edges:
+                {context['existing_edges']}
+
+                New Edge:
+                {context['new_edge']}
+
+
+                For each existing edge that should be invalidated, respond with a JSON object in the following format:
+                {{
+                    "invalidated_edges": [
+                        {{
+                            "uuid": "The UUID of the edge to be invalidated",
+                            "fact": "Updated fact of the edge"
+                        }}
+                    ]
+                }}
+
+                If no relationships need to be invalidated based on these strict criteria, return an empty list for "invalidated_edges".
+            """,
+        ),
+    ]
+
+
+versions: Versions = {'v1': v1, 'v2': v2}
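A brief usage sketch of the new v2 prompt, mirroring how get_edge_contradictions consumes it later in this commit; the edge values are hypothetical, and the snippet assumes it runs inside an async function with an LLMClient that returns parsed JSON:

context = {
    'new_edge': {'uuid': 'edge-123', 'name': 'WORKS_AT', 'fact': 'Alice works at Acme'},
    'existing_edges': [
        {'uuid': 'edge-042', 'name': 'WORKS_AT', 'fact': 'Alice works at Initech'},
    ],
}

llm_response = await llm_client.generate_response(prompt_library.invalidate_edges.v2(context))
invalidated = llm_response.get('invalidated_edges', [])  # [] when nothing contradicts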
The remaining hunks in the search utilities appear to be formatting-only: the old and new lines differ only in whitespace, so each signature is shown once below.

@@ -96,11 +96,11 @@ async def bfs(node_ids: list[str], driver: AsyncDriver):
 async def edge_similarity_search(
     driver: AsyncDriver,
     search_vector: list[float],
     source_node_uuid: str | None,
     target_node_uuid: str | None,
     limit: int = RELEVANT_SCHEMA_LIMIT,
 ) -> list[EntityEdge]:
     # vector similarity search over embedded facts
     query = Query("""

@@ -211,7 +211,7 @@ async def edge_similarity_search(
 async def entity_similarity_search(
     search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
 ) -> list[EntityNode]:
     # vector similarity search over entity names
     records, _, _ = await driver.execute_query(

@@ -247,7 +247,7 @@ async def entity_similarity_search(
 async def entity_fulltext_search(
     query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
 ) -> list[EntityNode]:
     # BM25 search to get top nodes
     fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'

@@ -284,11 +284,11 @@ async def entity_fulltext_search(
 async def edge_fulltext_search(
     driver: AsyncDriver,
     query: str,
     source_node_uuid: str | None,
     target_node_uuid: str | None,
     limit=RELEVANT_SCHEMA_LIMIT,
 ) -> list[EntityEdge]:
     # fulltext search over facts
     cypher_query = Query("""

@@ -401,10 +401,10 @@ async def edge_fulltext_search(
 async def hybrid_node_search(
     queries: list[str],
     embeddings: list[list[float]],
     driver: AsyncDriver,
     limit: int = RELEVANT_SCHEMA_LIMIT,
 ) -> list[EntityNode]:
     """
     Perform a hybrid search for nodes using both text queries and embeddings.

@@ -466,8 +466,8 @@ async def hybrid_node_search(
 async def get_relevant_nodes(
     nodes: list[EntityNode],
     driver: AsyncDriver,
 ) -> list[EntityNode]:
     """
     Retrieve relevant nodes based on the provided list of EntityNodes.

@@ -503,11 +503,11 @@ async def get_relevant_nodes(
 async def get_relevant_edges(
     driver: AsyncDriver,
     edges: list[EntityEdge],
     source_node_uuid: str | None,
     target_node_uuid: str | None,
     limit: int = RELEVANT_SCHEMA_LIMIT,
 ) -> list[EntityEdge]:
     start = time()
     relevant_edges: list[EntityEdge] = []

@@ -557,7 +557,7 @@ def rrf(results: list[list[str]], rank_const=1) -> list[str]:
 async def node_distance_reranker(
     driver: AsyncDriver, results: list[list[str]], center_node_uuid: str
 ) -> list[str]:
     # use rrf as a preliminary ranker
     sorted_uuids = rrf(results)

@@ -579,8 +579,8 @@ async def node_distance_reranker(
     for record in records:
         if (
             record['source_uuid'] == center_node_uuid
             or record['target_uuid'] == center_node_uuid
         ):
             continue
         distance = record['score']
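For orientation, a hedged usage sketch of hybrid_node_search, whose signature appears above; the embedder object and its embed method are placeholders, not this library's API:

async def find_candidate_nodes(driver: AsyncDriver, embedder, names: list[str]) -> list[EntityNode]:
    # one embedding per query string; swap in your real embedding call
    embeddings = [await embedder.embed(name) for name in names]
    # fulltext and vector results are combined internally (the module
    # defines an rrf helper for exactly this kind of rank fusion)
    return await hybrid_node_search(names, embeddings, driver)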
@@ -24,6 +24,10 @@ from graphiti_core.edges import EntityEdge, EpisodicEdge
 from graphiti_core.llm_client import LLMClient
 from graphiti_core.nodes import EntityNode, EpisodicNode
 from graphiti_core.prompts import prompt_library
+from graphiti_core.utils.maintenance.temporal_operations import (
+    extract_edge_dates,
+    get_edge_contradictions,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -149,28 +153,110 @@ async def dedupe_extracted_edges(
 async def resolve_extracted_edges(
     llm_client: LLMClient,
     extracted_edges: list[EntityEdge],
+    related_edges_lists: list[list[EntityEdge]],
     existing_edges_lists: list[list[EntityEdge]],
-) -> list[EntityEdge]:
-    resolved_edges: list[EntityEdge] = list(
+    current_episode: EpisodicNode,
+    previous_episodes: list[EpisodicNode],
+) -> tuple[list[EntityEdge], list[EntityEdge]]:
+    # resolve edges with related edges in the graph, extract temporal information, and find invalidation candidates
+    results: list[tuple[EntityEdge, list[EntityEdge]]] = list(
         await asyncio.gather(
             *[
-                resolve_extracted_edge(llm_client, extracted_edge, existing_edges)
-                for extracted_edge, existing_edges in zip(extracted_edges, existing_edges_lists)
+                resolve_extracted_edge(
+                    llm_client,
+                    extracted_edge,
+                    related_edges,
+                    existing_edges,
+                    current_episode,
+                    previous_episodes,
+                )
+                for extracted_edge, related_edges, existing_edges in zip(
+                    extracted_edges, related_edges_lists, existing_edges_lists
+                )
             ]
         )
     )
 
-    return resolved_edges
+    resolved_edges: list[EntityEdge] = []
+    invalidated_edges: list[EntityEdge] = []
+    for result in results:
+        resolved_edge = result[0]
+        invalidated_edge_chunk = result[1]
+
+        resolved_edges.append(resolved_edge)
+        invalidated_edges.extend(invalidated_edge_chunk)
+
+    return resolved_edges, invalidated_edges
 
 
 async def resolve_extracted_edge(
-    llm_client: LLMClient, extracted_edge: EntityEdge, existing_edges: list[EntityEdge]
+    llm_client: LLMClient,
+    extracted_edge: EntityEdge,
+    related_edges: list[EntityEdge],
+    existing_edges: list[EntityEdge],
+    current_episode: EpisodicNode,
+    previous_episodes: list[EpisodicNode],
+) -> tuple[EntityEdge, list[EntityEdge]]:
+    resolved_edge, (valid_at, invalid_at), invalidation_candidates = await asyncio.gather(
+        dedupe_extracted_edge(llm_client, extracted_edge, related_edges),
+        extract_edge_dates(llm_client, extracted_edge, current_episode, previous_episodes),
+        get_edge_contradictions(llm_client, extracted_edge, existing_edges),
+    )
+
+    now = datetime.now()
+
+    resolved_edge.valid_at = valid_at if valid_at is not None else resolved_edge.valid_at
+    resolved_edge.invalid_at = invalid_at if invalid_at is not None else resolved_edge.invalid_at
+    if invalid_at is not None and resolved_edge.expired_at is None:
+        resolved_edge.expired_at = now
+
+    # Determine if the new_edge needs to be expired
+    if resolved_edge.expired_at is None:
+        invalidation_candidates.sort(key=lambda c: (c.valid_at is None, c.valid_at))
+        for candidate in invalidation_candidates:
+            if (
+                candidate.valid_at is not None and resolved_edge.valid_at is not None
+            ) and candidate.valid_at > resolved_edge.valid_at:
+                # Expire new edge since we have information about more recent events
+                resolved_edge.invalid_at = candidate.valid_at
+                resolved_edge.expired_at = now
+                break
+
+    # Determine which contradictory edges need to be expired
+    invalidated_edges: list[EntityEdge] = []
+    for edge in invalidation_candidates:
+        # (Edge invalid before new edge becomes valid) or (new edge invalid before edge becomes valid)
+        if (
+            edge.invalid_at is not None
+            and resolved_edge.valid_at is not None
+            and edge.invalid_at < resolved_edge.valid_at
+        ) or (
+            edge.valid_at is not None
+            and resolved_edge.invalid_at is not None
+            and resolved_edge.invalid_at < edge.valid_at
+        ):
+            continue
+        # New edge invalidates edge
+        elif (
+            edge.valid_at is not None
+            and resolved_edge.valid_at is not None
+            and edge.valid_at < resolved_edge.valid_at
+        ):
+            edge.invalid_at = resolved_edge.valid_at
+            edge.expired_at = edge.expired_at if edge.expired_at is not None else now
+            invalidated_edges.append(edge)
+
+    return resolved_edge, invalidated_edges
+
+
+async def dedupe_extracted_edge(
+    llm_client: LLMClient, extracted_edge: EntityEdge, related_edges: list[EntityEdge]
 ) -> EntityEdge:
     start = time()
 
     # Prepare context for LLM
-    existing_edges_context = [
-        {'uuid': edge.uuid, 'name': edge.name, 'fact': edge.fact} for edge in existing_edges
+    related_edges_context = [
+        {'uuid': edge.uuid, 'name': edge.name, 'fact': edge.fact} for edge in related_edges
     ]
 
     extracted_edge_context = {
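To make the temporal rules above concrete, a small runnable example with assumed dates (the edges and timestamps are hypothetical):

from datetime import datetime

new_edge_valid_at = datetime(2024, 5, 1)   # when the newly extracted fact became true

candidate_earlier = datetime(2024, 1, 1)   # candidate that became valid before the new edge
candidate_later = datetime(2024, 8, 1)     # candidate that became valid after the new edge

# A candidate valid *after* the new edge expires the new edge itself:
assert candidate_later > new_edge_valid_at      # -> resolved_edge.invalid_at = candidate_later

# A candidate valid *before* the new edge is invalidated by it:
assert candidate_earlier < new_edge_valid_at    # -> candidate.invalid_at = new_edge_valid_at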
@@ -180,7 +266,7 @@ async def resolve_extracted_edge(
     }
 
     context = {
-        'existing_edges': existing_edges_context,
+        'related_edges': related_edges_context,
         'extracted_edges': extracted_edge_context,
     }
 
@@ -191,14 +277,14 @@ async def resolve_extracted_edge(
     edge = extracted_edge
     if is_duplicate:
-        for existing_edge in existing_edges:
+        for existing_edge in related_edges:
             if existing_edge.uuid != uuid:
                 continue
             edge = existing_edge
 
     end = time()
     logger.info(
-        f'Resolved node: {extracted_edge.name} is {edge.name}, in {(end - start) * 1000} ms'
+        f'Resolved Edge: {extracted_edge.name} is {edge.name}, in {(end - start) * 1000} ms'
     )
 
     return edge
 
@@ -16,6 +16,7 @@ limitations under the License.
 
 import logging
 from datetime import datetime
+from time import time
 from typing import List
 
 from graphiti_core.edges import EntityEdge
@@ -181,3 +182,36 @@ async def extract_edge_dates(
     logger.info(f'Edge date extraction explanation: {explanation}')
 
     return valid_at_datetime, invalid_at_datetime
+
+
+async def get_edge_contradictions(
+    llm_client: LLMClient, new_edge: EntityEdge, existing_edges: list[EntityEdge]
+) -> list[EntityEdge]:
+    start = time()
+    existing_edge_map = {edge.uuid: edge for edge in existing_edges}
+
+    new_edge_context = {'uuid': new_edge.uuid, 'name': new_edge.name, 'fact': new_edge.fact}
+    existing_edge_context = [
+        {'uuid': existing_edge.uuid, 'name': existing_edge.name, 'fact': existing_edge.fact}
+        for existing_edge in existing_edges
+    ]
+
+    context = {'new_edge': new_edge_context, 'existing_edges': existing_edge_context}
+
+    llm_response = await llm_client.generate_response(prompt_library.invalidate_edges.v2(context))
+
+    contradicted_edge_data = llm_response.get('invalidated_edges', [])
+
+    contradicted_edges: list[EntityEdge] = []
+    for edge_data in contradicted_edge_data:
+        if edge_data['uuid'] in existing_edge_map:
+            contradicted_edge = existing_edge_map[edge_data['uuid']]
+            contradicted_edge.fact = edge_data['fact']
+            contradicted_edges.append(contradicted_edge)
+
+    end = time()
+    logger.info(
+        f'Found invalidated edge candidates from {new_edge.fact}, in {(end - start) * 1000} ms'
+    )
+
+    return contradicted_edges
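Read together, the new edge pipeline in add_episode reduces to the following condensed sketch (names are taken from the hunks above; candidate gathering and persistence are elided):

# inside Graphiti.add_episode, after node extraction
(mentioned_nodes, uuid_map), extracted_edges = await asyncio.gather(
    resolve_extracted_nodes(self.llm_client, extracted_nodes, existing_nodes_lists),
    extract_edges(self.llm_client, episode, extracted_nodes, previous_episodes),
)

# remap edge endpoints onto the deduplicated nodes
edges = resolve_edge_pointers(extracted_edges, uuid_map)

# dedupe against related edges, date the facts, and expire contradicted edges,
# all in one concurrent pass per edge
resolved_edges, invalidated_edges = await resolve_extracted_edges(
    self.llm_client, edges, related_edges_list, existing_edges_list, episode, previous_episodes
)
entity_edges.extend(resolved_edges + invalidated_edges)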