Files
graphiti/tests/utils/maintenance/test_temporal_operations_int.py
Preston Rasmussen 42fb590606 Add group ids (#89)
* set and retrieve group ids

* update add episode with group id support

* add episode and search functional

* update bulk

* mypy updates

* remove unused imports

* update unit tests

* unit tests

* add optional uuid field

* format

* mypy

* ellipsis
2024-09-06 12:33:42 -04:00

348 lines
9.8 KiB
Python

"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
from datetime import datetime, timedelta
import pytest
from dotenv import load_dotenv
from graphiti_core.edges import EntityEdge
from graphiti_core.llm_client import LLMConfig, OpenAIClient
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.utils.maintenance.temporal_operations import (
invalidate_edges,
)
# Populate os.environ from a local .env file (if one exists) before any test reads it.
load_dotenv()


def setup_llm_client():
    """Return an ``OpenAIClient`` configured for the integration tests.

    The API key and model name are read from the ``TEST_OPENAI_API_KEY`` and
    ``TEST_OPENAI_MODEL`` environment variables; ``os.getenv`` yields ``None``
    for either one that is unset. Requests go to the public OpenAI endpoint.
    """
    config = LLMConfig(
        api_key=os.getenv('TEST_OPENAI_API_KEY'),
        model=os.getenv('TEST_OPENAI_MODEL'),
        base_url='https://api.openai.com/v1',
    )
    return OpenAIClient(config)
def create_test_data():
    """Build the minimal Alice/Bob fixture used by the basic invalidation tests.

    Returns:
        A 4-tuple ``(existing_edge, new_edge, current_episode, previous_episodes)``:

        * ``existing_edge`` -- ``(Alice, LIKES edge 'e1', Bob)``, created one day ago.
        * ``new_edge`` -- ``(Alice, DISLIKES edge 'e2', Bob)``, created now.
        * ``current_episode`` -- the episode stating that Alice now dislikes Bob.
        * ``previous_episodes`` -- a one-element list with the earlier episode
          stating that Alice liked Bob.

        Every node, edge and episode is tagged with ``group_id='1'`` so the
        fixture lives in a single graph partition, matching
        ``create_complex_test_data``.
    """
    now = datetime.now()
    one_day_ago = now - timedelta(days=1)

    # Create nodes (group_id kept consistent with the episodes below)
    node1 = EntityNode(uuid='1', name='Alice', labels=['Person'], created_at=now, group_id='1')
    node2 = EntityNode(uuid='2', name='Bob', labels=['Person'], created_at=now, group_id='1')

    # Create edges: the older LIKES fact and the newer, contradicting DISLIKES fact
    edge1 = EntityEdge(
        uuid='e1',
        source_node_uuid='1',
        target_node_uuid='2',
        name='LIKES',
        fact='Alice likes Bob',
        group_id='1',
        created_at=one_day_ago,
    )
    edge2 = EntityEdge(
        uuid='e2',
        source_node_uuid='1',
        target_node_uuid='2',
        name='DISLIKES',
        fact='Alice dislikes Bob',
        group_id='1',
        created_at=now,
    )

    existing_edge = (node1, edge1, node2)
    new_edge = (node1, edge2, node2)

    # Create current episode
    current_episode = EpisodicNode(
        name='Current Episode',
        content='Alice now dislikes Bob',
        created_at=now,
        valid_at=now,
        source=EpisodeType.message,
        source_description='Test episode for unit testing',
        group_id='1',
    )

    # Create previous episodes
    previous_episodes = [
        EpisodicNode(
            name='Previous Episode',
            content='Alice liked Bob',
            created_at=one_day_ago,
            valid_at=one_day_ago,
            source=EpisodeType.message,
            source_description='Test previous episode for unit testing',
            group_id='1',
        )
    ]

    return existing_edge, new_edge, current_episode, previous_episodes
@pytest.mark.asyncio
@pytest.mark.integration
async def test_invalidate_edges():
    """The contradicting DISLIKES edge must expire exactly the existing LIKES edge."""
    stale_triplet, fresh_triplet, episode, history = create_test_data()
    client = setup_llm_client()

    result = await invalidate_edges(client, [stale_triplet], [fresh_triplet], episode, history)

    assert len(result) == 1
    (expired,) = result
    _, stale_edge, _ = stale_triplet
    assert expired.uuid == stale_edge.uuid
    # An invalidated edge is expected to carry an expiry timestamp.
    assert expired.expired_at is not None
@pytest.mark.asyncio
@pytest.mark.integration
async def test_invalidate_edges_no_invalidation():
    """With no new edges supplied there is nothing to contradict, so nothing expires."""
    stale_triplet, _unused_new_triplet, episode, history = create_test_data()
    no_new_edges = []

    result = await invalidate_edges(
        setup_llm_client(), [stale_triplet], no_new_edges, episode, history
    )

    assert len(result) == 0
@pytest.mark.asyncio
@pytest.mark.integration
async def test_invalidate_edges_multiple_existing():
    """Only the contradicted edge is expired when several existing edges are present.

    Existing edges: ``LIKES`` (e1, contradicted by the new ``DISLIKES`` edge) and
    an unrelated ``KNOWS`` edge (e3). Only e1 should be invalidated.
    """
    # create_test_data returns a 4-tuple; unpacking it into two names raised ValueError.
    existing_edge1, new_edge, current_episode, previous_episodes = create_test_data()
    # A second, independent fixture copy supplies the unrelated KNOWS edge, so
    # mutating it below cannot affect existing_edge1.
    existing_edge2, _, _, _ = create_test_data()
    existing_edge2[1].uuid = 'e3'
    existing_edge2[1].name = 'KNOWS'
    existing_edge2[1].fact = 'Alice knows Bob'

    # Pass the episode context positionally, exactly as test_invalidate_edges does.
    invalidated_edges = await invalidate_edges(
        setup_llm_client(),
        [existing_edge1, existing_edge2],
        [new_edge],
        current_episode,
        previous_episodes,
    )

    assert len(invalidated_edges) == 1
    assert invalidated_edges[0].uuid == existing_edge1[1].uuid
    assert invalidated_edges[0].expired_at is not None
# Helper function to create more complex test data
def create_complex_test_data():
    """Build a four-node, three-edge graph for the "complex" invalidation tests.

    Returns:
        ``(triplets, nodes)`` where ``nodes`` is ``[Alice, Bob, Charlie, Company XYZ]``
        (uuids '1'..'4') and ``triplets`` holds ``(source, edge, target)`` tuples for:

        * e1 -- Alice LIKES Bob (created 5 days ago)
        * e2 -- Alice FRIENDS_WITH Charlie (created 3 days ago)
        * e3 -- Bob WORKS_FOR Company XYZ (created 2 days ago)

        Every node and edge is tagged with ``group_id='1'``.
    """
    now = datetime.now()

    alice, bob, charlie = (
        EntityNode(uuid=node_uuid, name=node_name, labels=['Person'], created_at=now, group_id='1')
        for node_uuid, node_name in (('1', 'Alice'), ('2', 'Bob'), ('3', 'Charlie'))
    )
    company = EntityNode(
        uuid='4', name='Company XYZ', labels=['Organization'], created_at=now, group_id='1'
    )

    # (edge uuid, source node, target node, relation name, fact, age in days)
    edge_specs = [
        ('e1', alice, bob, 'LIKES', 'Alice likes Bob', 5),
        ('e2', alice, charlie, 'FRIENDS_WITH', 'Alice is friends with Charlie', 3),
        ('e3', bob, company, 'WORKS_FOR', 'Bob works for Company XYZ', 2),
    ]
    triplets = [
        (
            source,
            EntityEdge(
                uuid=edge_uuid,
                source_node_uuid=source.uuid,
                target_node_uuid=target.uuid,
                name=relation,
                fact=fact,
                group_id='1',
                created_at=now - timedelta(days=age_days),
            ),
            target,
        )
        for edge_uuid, source, target, relation, fact, age_days in edge_specs
    ]

    return triplets, [alice, bob, charlie, company]
@pytest.mark.asyncio
@pytest.mark.integration
async def test_invalidate_edges_complex():
    """A new DISLIKES edge should expire only the contradicted LIKES edge (e1)."""
    existing_edges, nodes = create_complex_test_data()
    now = datetime.now()

    # Create a new edge that contradicts an existing one
    new_edge = (
        nodes[0],
        EntityEdge(
            uuid='e4',
            source_node_uuid='1',
            target_node_uuid='2',
            name='DISLIKES',
            fact='Alice dislikes Bob',
            group_id='1',
            created_at=now,
        ),
        nodes[1],
    )

    # invalidate_edges takes the triggering episode and the prior episodes after
    # the two edge lists (see test_invalidate_edges); no earlier context is needed.
    current_episode = EpisodicNode(
        name='Current Episode',
        content='Alice now dislikes Bob',
        created_at=now,
        valid_at=now,
        source=EpisodeType.message,
        source_description='Test episode for integration testing',
        group_id='1',
    )
    invalidated_edges = await invalidate_edges(
        setup_llm_client(), existing_edges, [new_edge], current_episode, []
    )

    assert len(invalidated_edges) == 1
    assert invalidated_edges[0].uuid == 'e1'
    assert invalidated_edges[0].expired_at is not None
@pytest.mark.asyncio
@pytest.mark.integration
async def test_invalidate_edges_temporal_update():
    """Bob leaving Company XYZ should expire the WORKS_FOR edge (e3)."""
    existing_edges, nodes = create_complex_test_data()
    now = datetime.now()

    # Create a new edge that updates an existing one with new information
    new_edge = (
        nodes[1],
        EntityEdge(
            uuid='e5',
            source_node_uuid='2',
            target_node_uuid='4',
            name='LEFT_JOB',
            fact='Bob left his job at Company XYZ',
            group_id='1',
            created_at=now,
        ),
        nodes[3],
    )

    # invalidate_edges takes the triggering episode and the prior episodes after
    # the two edge lists (see test_invalidate_edges); no earlier context is needed.
    current_episode = EpisodicNode(
        name='Current Episode',
        content='Bob left his job at Company XYZ',
        created_at=now,
        valid_at=now,
        source=EpisodeType.message,
        source_description='Test episode for integration testing',
        group_id='1',
    )
    invalidated_edges = await invalidate_edges(
        setup_llm_client(), existing_edges, [new_edge], current_episode, []
    )

    assert len(invalidated_edges) == 1
    assert invalidated_edges[0].uuid == 'e3'
    assert invalidated_edges[0].expired_at is not None
@pytest.mark.asyncio
@pytest.mark.integration
async def test_invalidate_edges_multiple_invalidations():
    """Two contradicting new edges should expire both LIKES (e1) and FRIENDS_WITH (e2)."""
    existing_edges, nodes = create_complex_test_data()
    now = datetime.now()

    # Create new edges that invalidate multiple existing edges
    new_edge1 = (
        nodes[0],
        EntityEdge(
            uuid='e6',
            source_node_uuid='1',
            target_node_uuid='2',
            name='ENEMIES_WITH',
            fact='Alice and Bob are now enemies',
            group_id='1',
            created_at=now,
        ),
        nodes[1],
    )
    new_edge2 = (
        nodes[0],
        EntityEdge(
            uuid='e7',
            source_node_uuid='1',
            target_node_uuid='3',
            name='ENDED_FRIENDSHIP',
            fact='Alice ended her friendship with Charlie',
            group_id='1',
            created_at=now,
        ),
        nodes[2],
    )

    # invalidate_edges takes the triggering episode and the prior episodes after
    # the two edge lists (see test_invalidate_edges); no earlier context is needed.
    current_episode = EpisodicNode(
        name='Current Episode',
        content='Alice and Bob are now enemies, and Alice ended her friendship with Charlie',
        created_at=now,
        valid_at=now,
        source=EpisodeType.message,
        source_description='Test episode for integration testing',
        group_id='1',
    )
    invalidated_edges = await invalidate_edges(
        setup_llm_client(), existing_edges, [new_edge1, new_edge2], current_episode, []
    )

    assert len(invalidated_edges) == 2
    assert {edge.uuid for edge in invalidated_edges} == {'e1', 'e2'}
    for edge in invalidated_edges:
        assert edge.expired_at is not None
@pytest.mark.asyncio
@pytest.mark.integration
async def test_invalidate_edges_no_effect():
    """A new, non-contradicting fact (Charlie applied to XYZ) must expire nothing."""
    existing_edges, nodes = create_complex_test_data()
    now = datetime.now()

    # Create a new edge that doesn't invalidate any existing edges
    new_edge = (
        nodes[2],
        EntityEdge(
            uuid='e8',
            source_node_uuid='3',
            target_node_uuid='4',
            name='APPLIED_TO',
            fact='Charlie applied to Company XYZ',
            group_id='1',
            created_at=now,
        ),
        nodes[3],
    )

    # invalidate_edges takes the triggering episode and the prior episodes after
    # the two edge lists (see test_invalidate_edges); no earlier context is needed.
    current_episode = EpisodicNode(
        name='Current Episode',
        content='Charlie applied to Company XYZ',
        created_at=now,
        valid_at=now,
        source=EpisodeType.message,
        source_description='Test episode for integration testing',
        group_id='1',
    )
    invalidated_edges = await invalidate_edges(
        setup_llm_client(), existing_edges, [new_edge], current_episode, []
    )

    assert len(invalidated_edges) == 0
@pytest.mark.asyncio
@pytest.mark.integration
async def test_invalidate_edges_partial_update():
    """Changing position at the same company refines, not contradicts, WORKS_FOR."""
    existing_edges, nodes = create_complex_test_data()
    now = datetime.now()

    # Create a new edge that partially updates an existing one
    new_edge = (
        nodes[1],
        EntityEdge(
            uuid='e9',
            source_node_uuid='2',
            target_node_uuid='4',
            name='CHANGED_POSITION',
            fact='Bob changed his position at Company XYZ',
            group_id='1',
            created_at=now,
        ),
        nodes[3],
    )

    # invalidate_edges takes the triggering episode and the prior episodes after
    # the two edge lists (see test_invalidate_edges); no earlier context is needed.
    current_episode = EpisodicNode(
        name='Current Episode',
        content='Bob changed his position at Company XYZ',
        created_at=now,
        valid_at=now,
        source=EpisodeType.message,
        source_description='Test episode for integration testing',
        group_id='1',
    )
    invalidated_edges = await invalidate_edges(
        setup_llm_client(), existing_edges, [new_edge], current_episode, []
    )

    assert len(invalidated_edges) == 0  # The existing edge is not invalidated, just updated
@pytest.mark.asyncio
@pytest.mark.integration
async def test_invalidate_edges_empty_inputs():
    """With no existing and no new edges, nothing can be invalidated."""
    now = datetime.now()
    # invalidate_edges takes the triggering episode and the prior episodes after
    # the two edge lists (see test_invalidate_edges); supply a neutral episode.
    current_episode = EpisodicNode(
        name='Current Episode',
        content='Alice said hello to Bob',
        created_at=now,
        valid_at=now,
        source=EpisodeType.message,
        source_description='Test episode for integration testing',
        group_id='1',
    )

    invalidated_edges = await invalidate_edges(setup_llm_client(), [], [], current_episode, [])

    assert len(invalidated_edges) == 0