Add episode refactor (#85)

* temp commit while moving

* fix name embedding bug

* invalidation

* format

* tests on runner examples

* format

* ellipsis

* ruff

* fix

* format

* minor prompt change
Author: Preston Rasmussen
Date: 2024-09-05 12:05:44 -04:00
Committed by: GitHub
Parent: 1d31442751
Commit: 299021173b

8 changed files with 261 additions and 106 deletions

View File

@@ -94,7 +94,7 @@ async def main():
 async def ingest_products_data(client: Graphiti):
     script_dir = Path(__file__).parent
-    json_file_path = script_dir / 'allbirds_products.json'
+    json_file_path = script_dir / '../data/manybirds_products.json'

     with open(json_file_path) as file:
         products = json.load(file)['products']
@@ -110,7 +110,14 @@ async def ingest_products_data(client: Graphiti):
         for i, product in enumerate(products)
     ]

-    await client.add_episode_bulk(episodes)
+    for episode in episodes:
+        await client.add_episode(
+            episode.name,
+            episode.content,
+            episode.source_description,
+            episode.reference_time,
+            episode.source,
+        )


 asyncio.run(main())

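The example switches from add_episode_bulk to per-episode add_episode, so each product episode exercises the refactored resolution pipeline (node dedup, edge resolution, temporal invalidation) against the graph state left by the previous one. A hedged sketch of the two ingestion styles; the helper names here are illustrative, not part of the library:

    # Sketch only: assumes `client` is a connected Graphiti instance and
    # `episodes` holds objects with the fields referenced below.
    async def ingest_sequential(client, episodes):
        # Each call extracts, dedupes, and invalidates against the live graph.
        for episode in episodes:
            await client.add_episode(
                episode.name,
                episode.content,
                episode.source_description,
                episode.reference_time,
                episode.source,
            )

    async def ingest_bulk(client, episodes):
        # Batched path; as of this refactor it does not run the new
        # per-episode invalidation pass.
        await client.add_episode_bulk(episodes)
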
View File

@@ -86,4 +86,4 @@ async def main(use_bulk: bool = True):
     await client.add_episode_bulk(episodes)

-asyncio.run(main(True))
+asyncio.run(main(False))

View File

@@ -59,11 +59,6 @@ from graphiti_core.utils.maintenance.node_operations import (
     extract_nodes,
     resolve_extracted_nodes,
 )
-from graphiti_core.utils.maintenance.temporal_operations import (
-    extract_edge_dates,
-    invalidate_edges,
-    prepare_edges_for_invalidation,
-)

 logger = logging.getLogger(__name__)
@@ -293,7 +288,7 @@ class Graphiti:
                 *[node.generate_name_embedding(embedder) for node in extracted_nodes]
             )

-            # Resolve extracted nodes with nodes already in the graph
+            # Resolve extracted nodes with nodes already in the graph and extract facts
            existing_nodes_lists: list[list[EntityNode]] = list(
                 await asyncio.gather(
                     *[get_relevant_nodes([node], self.driver) for node in extracted_nodes]
@@ -302,22 +297,27 @@ class Graphiti:
             logger.info(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')

-            mentioned_nodes, _ = await resolve_extracted_nodes(
-                self.llm_client, extracted_nodes, existing_nodes_lists
+            (mentioned_nodes, uuid_map), extracted_edges = await asyncio.gather(
+                resolve_extracted_nodes(self.llm_client, extracted_nodes, existing_nodes_lists),
+                extract_edges(self.llm_client, episode, extracted_nodes, previous_episodes),
             )
             logger.info(f'Adjusted mentioned nodes: {[(n.name, n.uuid) for n in mentioned_nodes]}')
             nodes.extend(mentioned_nodes)

-            # Extract facts as edges given entity nodes
-            extracted_edges = await extract_edges(
-                self.llm_client, episode, mentioned_nodes, previous_episodes
+            extracted_edges_with_resolved_pointers = resolve_edge_pointers(
+                extracted_edges, uuid_map
             )

             # calculate embeddings
-            await asyncio.gather(*[edge.generate_embedding(embedder) for edge in extracted_edges])
+            await asyncio.gather(
+                *[
+                    edge.generate_embedding(embedder)
+                    for edge in extracted_edges_with_resolved_pointers
+                ]
+            )

-            # Resolve extracted edges with edges already in the graph
-            existing_edges_list: list[list[EntityEdge]] = list(
+            # Resolve extracted edges with related edges already in the graph
+            related_edges_list: list[list[EntityEdge]] = list(
                 await asyncio.gather(
                     *[
                         get_relevant_edges(
@@ -327,74 +327,66 @@ class Graphiti:
                             edge.target_node_uuid,
                             RELEVANT_SCHEMA_LIMIT,
                         )
-                        for edge in extracted_edges
+                        for edge in extracted_edges_with_resolved_pointers
                     ]
                 )
             )
             logger.info(
-                f'Existing edges lists: {[(e.name, e.uuid) for edges_lst in existing_edges_list for e in edges_lst]}'
+                f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_list for e in edges_lst]}'
             )
-            logger.info(f'Extracted edges: {[(e.name, e.uuid) for e in extracted_edges]}')
-
-            deduped_edges: list[EntityEdge] = await resolve_extracted_edges(
-                self.llm_client, extracted_edges, existing_edges_list
+            logger.info(
+                f'Extracted edges: {[(e.name, e.uuid) for e in extracted_edges_with_resolved_pointers]}'
             )

-            # Extract dates for the newly extracted edges
-            edge_dates = await asyncio.gather(
-                *[
-                    extract_edge_dates(
-                        self.llm_client,
-                        edge,
-                        episode,
-                        previous_episodes,
-                    )
-                    for edge in deduped_edges
-                ]
+            existing_source_edges_list: list[list[EntityEdge]] = list(
+                await asyncio.gather(
+                    *[
+                        get_relevant_edges(
+                            self.driver,
+                            [edge],
+                            edge.source_node_uuid,
+                            None,
+                            RELEVANT_SCHEMA_LIMIT,
+                        )
+                        for edge in extracted_edges_with_resolved_pointers
+                    ]
+                )
             )

-            for i, edge in enumerate(deduped_edges):
-                valid_at = edge_dates[i][0]
-                invalid_at = edge_dates[i][1]
+            existing_target_edges_list: list[list[EntityEdge]] = list(
+                await asyncio.gather(
+                    *[
+                        get_relevant_edges(
+                            self.driver,
+                            [edge],
+                            None,
+                            edge.target_node_uuid,
+                            RELEVANT_SCHEMA_LIMIT,
+                        )
+                        for edge in extracted_edges_with_resolved_pointers
+                    ]
+                )
+            )

-                edge.valid_at = valid_at
-                edge.invalid_at = invalid_at
-                if edge.invalid_at is not None:
-                    edge.expired_at = now
-
-            entity_edges.extend(deduped_edges)
-
-            existing_edges: list[EntityEdge] = [
-                e for edge_lst in existing_edges_list for e in edge_lst
+            existing_edges_list: list[list[EntityEdge]] = [
+                source_lst + target_lst
+                for source_lst, target_lst in zip(
+                    existing_source_edges_list, existing_target_edges_list
+                )
             ]

-            (
-                old_edges_with_nodes_pending_invalidation,
-                new_edges_with_nodes,
-            ) = prepare_edges_for_invalidation(
-                existing_edges=existing_edges, new_edges=deduped_edges, nodes=nodes
-            )
-
-            invalidated_edges = await invalidate_edges(
+            resolved_edges, invalidated_edges = await resolve_extracted_edges(
                 self.llm_client,
-                old_edges_with_nodes_pending_invalidation,
-                new_edges_with_nodes,
+                extracted_edges_with_resolved_pointers,
+                related_edges_list,
+                existing_edges_list,
+                episode,
+                previous_episodes,
             )

-            for edge in invalidated_edges:
-                for existing_edge in existing_edges:
-                    if existing_edge.uuid == edge.uuid:
-                        existing_edge.expired_at = edge.expired_at
-                for deduped_edge in deduped_edges:
-                    if deduped_edge.uuid == edge.uuid:
-                        deduped_edge.expired_at = edge.expired_at
-
             logger.info(f'Invalidated edges: {[(e.name, e.uuid) for e in invalidated_edges]}')

-            entity_edges.extend(existing_edges)
-
-            logger.info(f'Deduped edges: {[(e.name, e.uuid) for e in deduped_edges]}')
+            entity_edges.extend(resolved_edges + invalidated_edges)
+
+            logger.info(f'Resolved edges: {[(e.name, e.uuid) for e in resolved_edges]}')

             episodic_edges: list[EpisodicEdge] = build_episodic_edges(
                 mentioned_nodes,

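Taken together, the graphiti.py changes above collapse the old dedupe → date-extraction → invalidation stages into a single resolution pass. A condensed sketch of the new control flow using the diff's own names (an abridgement of the method body, not a standalone program):

    # Inside Graphiti.add_episode (abridged sketch):
    # 1. Resolve nodes and extract edges concurrently.
    (mentioned_nodes, uuid_map), extracted_edges = await asyncio.gather(
        resolve_extracted_nodes(self.llm_client, extracted_nodes, existing_nodes_lists),
        extract_edges(self.llm_client, episode, extracted_nodes, previous_episodes),
    )

    # 2. Re-point extracted edges at the canonical node UUIDs chosen in step 1.
    edges = resolve_edge_pointers(extracted_edges, uuid_map)

    # 3. One call now dedupes each edge, extracts its dates, and returns
    #    invalidated edges alongside the resolved ones.
    resolved_edges, invalidated_edges = await resolve_extracted_edges(
        self.llm_client, edges, related_edges_list, existing_edges_list,
        episode, previous_episodes,
    )
    entity_edges.extend(resolved_edges + invalidated_edges)
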
View File

@@ -129,7 +129,7 @@ def v3(context: dict[str, Any]) -> list[Message]:
         Given the following context, determine whether the New Edge represents any of the edges in the list of Existing Edges.

         Existing Edges:
-        {json.dumps(context['existing_edges'], indent=2)}
+        {json.dumps(context['related_edges'], indent=2)}

         New Edge:
         {json.dumps(context['extracted_edges'], indent=2)}

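After the rename, the v3 dedupe prompt reads its graph-side candidates from context['related_edges']. A hypothetical context of the shape the template consumes (field names inferred from the edge_operations diff further down):

    context = {
        'related_edges': [
            {'uuid': 'edge-1', 'name': 'WORKS_AT', 'fact': 'Alice works at Acme'},
        ],
        'extracted_edges': {'uuid': 'edge-2', 'name': 'WORKS_AT', 'fact': 'Alice is employed by Acme'},
    }
    messages = v3(context)  # list[Message] ready to send to the LLM client
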
View File

@@ -21,10 +21,12 @@ from .models import Message, PromptFunction, PromptVersion
 class Prompt(Protocol):
     v1: PromptVersion
+    v2: PromptVersion


 class Versions(TypedDict):
     v1: PromptFunction
+    v2: PromptFunction


 def v1(context: dict[str, Any]) -> list[Message]:
@@ -71,4 +73,38 @@ def v1(context: dict[str, Any]) -> list[Message]:
     ]

-versions: Versions = {'v1': v1}
+
+def v2(context: dict[str, Any]) -> list[Message]:
+    return [
+        Message(
+            role='system',
+            content='You are an AI assistant that helps determine which relationships in a knowledge graph should be invalidated based solely on explicit contradictions in newer information.',
+        ),
+        Message(
+            role='user',
+            content=f"""
+        Based on the provided Existing Edges and a New Edge, determine which existing edges, if any, should be marked as invalidated due to contradictions with the New Edge.
+
+        Existing Edges:
+        {context['existing_edges']}
+
+        New Edge:
+        {context['new_edge']}
+
+        For each existing edge that should be invalidated, respond with a JSON object in the following format:
+        {{
+            "invalidated_edges": [
+                {{
+                    "uuid": "The UUID of the edge to be invalidated",
+                    "fact": "Updated fact of the edge"
+                }}
+            ]
+        }}
+
+        If no relationships need to be invalidated based on these strict criteria, return an empty list for "invalidated_edges".
+        """,
+        ),
+    ]
+
+versions: Versions = {'v1': v1, 'v2': v2}

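The v2 prompt pins the response format, which the new get_edge_contradictions helper (last file below) parses via llm_response.get('invalidated_edges', []). A hypothetical well-formed response, written as the Python dict the LLM client would return:

    llm_response = {
        'invalidated_edges': [
            {
                'uuid': 'edge-1',  # UUID of the contradicted existing edge
                'fact': 'Alice no longer works at Acme',  # updated fact text
            },
        ],
    }
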
View File

@@ -96,11 +96,11 @@ async def bfs(node_ids: list[str], driver: AsyncDriver):
 async def edge_similarity_search(
-        driver: AsyncDriver,
-        search_vector: list[float],
-        source_node_uuid: str | None,
-        target_node_uuid: str | None,
-        limit: int = RELEVANT_SCHEMA_LIMIT,
+    driver: AsyncDriver,
+    search_vector: list[float],
+    source_node_uuid: str | None,
+    target_node_uuid: str | None,
+    limit: int = RELEVANT_SCHEMA_LIMIT,
 ) -> list[EntityEdge]:
     # vector similarity search over embedded facts
     query = Query("""
@@ -211,7 +211,7 @@ async def edge_similarity_search(
 async def entity_similarity_search(
-        search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
+    search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
 ) -> list[EntityNode]:
     # vector similarity search over entity names
     records, _, _ = await driver.execute_query(
@@ -247,7 +247,7 @@ async def entity_similarity_search(
 async def entity_fulltext_search(
-        query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
+    query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
 ) -> list[EntityNode]:
     # BM25 search to get top nodes
     fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
@@ -284,11 +284,11 @@ async def entity_fulltext_search(
 async def edge_fulltext_search(
-        driver: AsyncDriver,
-        query: str,
-        source_node_uuid: str | None,
-        target_node_uuid: str | None,
-        limit=RELEVANT_SCHEMA_LIMIT,
+    driver: AsyncDriver,
+    query: str,
+    source_node_uuid: str | None,
+    target_node_uuid: str | None,
+    limit=RELEVANT_SCHEMA_LIMIT,
 ) -> list[EntityEdge]:
     # fulltext search over facts
     cypher_query = Query("""
@@ -401,10 +401,10 @@ async def edge_fulltext_search(
 async def hybrid_node_search(
-        queries: list[str],
-        embeddings: list[list[float]],
-        driver: AsyncDriver,
-        limit: int = RELEVANT_SCHEMA_LIMIT,
+    queries: list[str],
+    embeddings: list[list[float]],
+    driver: AsyncDriver,
+    limit: int = RELEVANT_SCHEMA_LIMIT,
 ) -> list[EntityNode]:
     """
     Perform a hybrid search for nodes using both text queries and embeddings.
@@ -466,8 +466,8 @@ async def hybrid_node_search(
 async def get_relevant_nodes(
-        nodes: list[EntityNode],
-        driver: AsyncDriver,
+    nodes: list[EntityNode],
+    driver: AsyncDriver,
 ) -> list[EntityNode]:
     """
     Retrieve relevant nodes based on the provided list of EntityNodes.
@@ -503,11 +503,11 @@ async def get_relevant_nodes(
 async def get_relevant_edges(
-        driver: AsyncDriver,
-        edges: list[EntityEdge],
-        source_node_uuid: str | None,
-        target_node_uuid: str | None,
-        limit: int = RELEVANT_SCHEMA_LIMIT,
+    driver: AsyncDriver,
+    edges: list[EntityEdge],
+    source_node_uuid: str | None,
+    target_node_uuid: str | None,
+    limit: int = RELEVANT_SCHEMA_LIMIT,
 ) -> list[EntityEdge]:
     start = time()
     relevant_edges: list[EntityEdge] = []
@@ -557,7 +557,7 @@ def rrf(results: list[list[str]], rank_const=1) -> list[str]:
 async def node_distance_reranker(
-        driver: AsyncDriver, results: list[list[str]], center_node_uuid: str
+    driver: AsyncDriver, results: list[list[str]], center_node_uuid: str
 ) -> list[str]:
     # use rrf as a preliminary ranker
     sorted_uuids = rrf(results)
@@ -579,8 +579,8 @@ async def node_distance_reranker(
     for record in records:
         if (
-                record['source_uuid'] == center_node_uuid
-                or record['target_uuid'] == center_node_uuid
+            record['source_uuid'] == center_node_uuid
+            or record['target_uuid'] == center_node_uuid
         ):
             continue
         distance = record['score']

View File

@@ -24,6 +24,10 @@ from graphiti_core.edges import EntityEdge, EpisodicEdge
 from graphiti_core.llm_client import LLMClient
 from graphiti_core.nodes import EntityNode, EpisodicNode
 from graphiti_core.prompts import prompt_library
+from graphiti_core.utils.maintenance.temporal_operations import (
+    extract_edge_dates,
+    get_edge_contradictions,
+)

 logger = logging.getLogger(__name__)
@@ -149,28 +153,110 @@ async def dedupe_extracted_edges(
 async def resolve_extracted_edges(
     llm_client: LLMClient,
     extracted_edges: list[EntityEdge],
+    related_edges_lists: list[list[EntityEdge]],
     existing_edges_lists: list[list[EntityEdge]],
-) -> list[EntityEdge]:
-    resolved_edges: list[EntityEdge] = list(
+    current_episode: EpisodicNode,
+    previous_episodes: list[EpisodicNode],
+) -> tuple[list[EntityEdge], list[EntityEdge]]:
+    # resolve edges with related edges in the graph, extract temporal information, and find invalidation candidates
+    results: list[tuple[EntityEdge, list[EntityEdge]]] = list(
         await asyncio.gather(
             *[
-                resolve_extracted_edge(llm_client, extracted_edge, existing_edges)
-                for extracted_edge, existing_edges in zip(extracted_edges, existing_edges_lists)
+                resolve_extracted_edge(
+                    llm_client,
+                    extracted_edge,
+                    related_edges,
+                    existing_edges,
+                    current_episode,
+                    previous_episodes,
+                )
+                for extracted_edge, related_edges, existing_edges in zip(
+                    extracted_edges, related_edges_lists, existing_edges_lists
+                )
             ]
         )
     )

-    return resolved_edges
+    resolved_edges: list[EntityEdge] = []
+    invalidated_edges: list[EntityEdge] = []
+    for result in results:
+        resolved_edge = result[0]
+        invalidated_edge_chunk = result[1]
+
+        resolved_edges.append(resolved_edge)
+        invalidated_edges.extend(invalidated_edge_chunk)
+
+    return resolved_edges, invalidated_edges


 async def resolve_extracted_edge(
-    llm_client: LLMClient, extracted_edge: EntityEdge, existing_edges: list[EntityEdge]
+    llm_client: LLMClient,
+    extracted_edge: EntityEdge,
+    related_edges: list[EntityEdge],
+    existing_edges: list[EntityEdge],
+    current_episode: EpisodicNode,
+    previous_episodes: list[EpisodicNode],
+) -> tuple[EntityEdge, list[EntityEdge]]:
+    resolved_edge, (valid_at, invalid_at), invalidation_candidates = await asyncio.gather(
+        dedupe_extracted_edge(llm_client, extracted_edge, related_edges),
+        extract_edge_dates(llm_client, extracted_edge, current_episode, previous_episodes),
+        get_edge_contradictions(llm_client, extracted_edge, existing_edges),
+    )
+
+    now = datetime.now()
+
+    resolved_edge.valid_at = valid_at if valid_at is not None else resolved_edge.valid_at
+    resolved_edge.invalid_at = invalid_at if invalid_at is not None else resolved_edge.invalid_at
+    if invalid_at is not None and resolved_edge.expired_at is None:
+        resolved_edge.expired_at = now
+
+    # Determine if the new_edge needs to be expired
+    if resolved_edge.expired_at is None:
+        invalidation_candidates.sort(key=lambda c: (c.valid_at is None, c.valid_at))
+        for candidate in invalidation_candidates:
+            if (
+                candidate.valid_at is not None and resolved_edge.valid_at is not None
+            ) and candidate.valid_at > resolved_edge.valid_at:
+                # Expire new edge since we have information about more recent events
+                resolved_edge.invalid_at = candidate.valid_at
+                resolved_edge.expired_at = now
+                break
+
+    # Determine which contradictory edges need to be expired
+    invalidated_edges: list[EntityEdge] = []
+    for edge in invalidation_candidates:
+        # (Edge invalid before new edge becomes valid) or (new edge invalid before edge becomes valid)
+        if (
+            edge.invalid_at is not None
+            and resolved_edge.valid_at is not None
+            and edge.invalid_at < resolved_edge.valid_at
+        ) or (
+            edge.valid_at is not None
+            and resolved_edge.invalid_at is not None
+            and resolved_edge.invalid_at < edge.valid_at
+        ):
+            continue
+        # New edge invalidates edge
+        elif (
+            edge.valid_at is not None
+            and resolved_edge.valid_at is not None
+            and edge.valid_at < resolved_edge.valid_at
+        ):
+            edge.invalid_at = resolved_edge.valid_at
+            edge.expired_at = edge.expired_at if edge.expired_at is not None else now
+            invalidated_edges.append(edge)
+
+    return resolved_edge, invalidated_edges
+
+
+async def dedupe_extracted_edge(
+    llm_client: LLMClient, extracted_edge: EntityEdge, related_edges: list[EntityEdge]
 ) -> EntityEdge:
     start = time()

     # Prepare context for LLM
-    existing_edges_context = [
-        {'uuid': edge.uuid, 'name': edge.name, 'fact': edge.fact} for edge in existing_edges
+    related_edges_context = [
+        {'uuid': edge.uuid, 'name': edge.name, 'fact': edge.fact} for edge in related_edges
     ]

     extracted_edge_context = {
@@ -180,7 +266,7 @@ async def resolve_extracted_edge(
     }

     context = {
-        'existing_edges': existing_edges_context,
+        'related_edges': related_edges_context,
         'extracted_edges': extracted_edge_context,
     }

@@ -191,14 +277,14 @@ async def resolve_extracted_edge(
     edge = extracted_edge
     if is_duplicate:
-        for existing_edge in existing_edges:
+        for existing_edge in related_edges:
             if existing_edge.uuid != uuid:
                 continue
             edge = existing_edge

     end = time()
     logger.info(
-        f'Resolved node: {extracted_edge.name} is {edge.name}, in {(end - start) * 1000} ms'
+        f'Resolved Edge: {extracted_edge.name} is {edge.name}, in {(end - start) * 1000} ms'
     )
     return edge

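The invalidation rules in resolve_extracted_edge reduce to interval comparisons on valid_at/invalid_at. A self-contained sketch of just those two rules, using a hypothetical stand-in for EntityEdge's temporal fields:

    from dataclasses import dataclass
    from datetime import datetime

    @dataclass
    class TemporalEdge:  # hypothetical stand-in for EntityEdge's date fields
        valid_at: datetime | None = None
        invalid_at: datetime | None = None
        expired_at: datetime | None = None

    now = datetime.now()
    new_edge = TemporalEdge(valid_at=datetime(2024, 1, 1))

    # Rule 1: a contradiction candidate that became valid *after* the new
    # edge expires the new edge itself (more recent information wins).
    candidate = TemporalEdge(valid_at=datetime(2024, 6, 1))
    if candidate.valid_at and new_edge.valid_at and candidate.valid_at > new_edge.valid_at:
        new_edge.invalid_at = candidate.valid_at
        new_edge.expired_at = now

    # Rule 2: a candidate that became valid *before* the new edge is itself
    # invalidated, unless their validity intervals do not overlap.
    old = TemporalEdge(valid_at=datetime(2023, 1, 1))
    if old.valid_at and new_edge.valid_at and old.valid_at < new_edge.valid_at:
        old.invalid_at = new_edge.valid_at
        old.expired_at = old.expired_at or now
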
View File

@@ -16,6 +16,7 @@ limitations under the License.
 import logging
 from datetime import datetime
 from time import time
+from typing import List

 from graphiti_core.edges import EntityEdge
@@ -181,3 +182,36 @@ async def extract_edge_dates(
     logger.info(f'Edge date extraction explanation: {explanation}')

     return valid_at_datetime, invalid_at_datetime
+
+
+async def get_edge_contradictions(
+    llm_client: LLMClient, new_edge: EntityEdge, existing_edges: list[EntityEdge]
+) -> list[EntityEdge]:
+    start = time()
+    existing_edge_map = {edge.uuid: edge for edge in existing_edges}
+
+    new_edge_context = {'uuid': new_edge.uuid, 'name': new_edge.name, 'fact': new_edge.fact}
+    existing_edge_context = [
+        {'uuid': existing_edge.uuid, 'name': existing_edge.name, 'fact': existing_edge.fact}
+        for existing_edge in existing_edges
+    ]
+
+    context = {'new_edge': new_edge_context, 'existing_edges': existing_edge_context}
+
+    llm_response = await llm_client.generate_response(prompt_library.invalidate_edges.v2(context))
+
+    contradicted_edge_data = llm_response.get('invalidated_edges', [])
+
+    contradicted_edges: list[EntityEdge] = []
+    for edge_data in contradicted_edge_data:
+        if edge_data['uuid'] in existing_edge_map:
+            contradicted_edge = existing_edge_map[edge_data['uuid']]
+            contradicted_edge.fact = edge_data['fact']
+            contradicted_edges.append(contradicted_edge)
+
+    end = time()
+    logger.info(
+        f'Found invalidated edge candidates from {new_edge.fact}, in {(end - start) * 1000} ms'
+    )
+
+    return contradicted_edges
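
A sketch of how the new helper slots into resolution (assuming an initialized llm_client and EntityEdge objects already loaded; in the real flow, resolve_extracted_edge applies the date rules above before expiring anything):

    from datetime import datetime

    async def expire_contradictions(llm_client, new_edge, existing_edges):
        # Hypothetical wrapper: ask the LLM which existing edges the new
        # fact contradicts ...
        contradicted = await get_edge_contradictions(llm_client, new_edge, existing_edges)
        # ... then mark them expired (shown unconditionally for brevity).
        for edge in contradicted:
            edge.expired_at = datetime.now()
        return contradicted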