Files
graphiti/tests/utils/maintenance/test_temporal_operations.py
Preston Rasmussen 42fb590606 Add group ids (#89)
* set and retrieve group ids

* update add episode with group id support

* add episode and search functional

* update bulk

* mypy updates

* remove unused imports

* update unit tests

* unit tests

* add optional uuid field

* format

* mypy

* ellipsis
2024-09-06 12:33:42 -04:00

369 lines
12 KiB
Python

"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from datetime import datetime, timedelta
import pytest
from graphiti_core.edges import EntityEdge
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.utils.maintenance.temporal_operations import (
extract_date_strings_from_edge,
prepare_edges_for_invalidation,
prepare_invalidation_context,
)
def create_test_data():
    """Build the small graph fixture shared by the ``prepare_edges_for_invalidation`` tests.

    Returns a dict with three keys:
      * ``'nodes'`` -- three ``Person`` entity nodes with uuids ``'1'``, ``'2'``, ``'3'``;
      * ``'existing_edges'`` -- ``e1`` (1 -> 2, KNOWS) and ``e2`` (2 -> 3, LIKES);
      * ``'new_edges'`` -- ``e3`` (1 -> 3, WORKS_WITH) and ``e4`` (1 -> 2, DISLIKES).

    Every node and edge shares one ``created_at`` timestamp and ``group_id='1'``.
    """
    timestamp = datetime.now()

    def make_edge(edge_uuid, source_uuid, target_uuid, relation, fact):
        # Fixture edges differ only in identity, endpoints, label and fact text.
        return EntityEdge(
            uuid=edge_uuid,
            source_node_uuid=source_uuid,
            target_node_uuid=target_uuid,
            name=relation,
            fact=fact,
            created_at=timestamp,
            group_id='1',
        )

    node_specs = (('1', 'Node1'), ('2', 'Node2'), ('3', 'Node3'))
    nodes = [
        EntityNode(
            uuid=node_uuid,
            name=node_name,
            labels=['Person'],
            created_at=timestamp,
            group_id='1',
        )
        for node_uuid, node_name in node_specs
    ]

    existing = [
        make_edge('e1', '1', '2', 'KNOWS', 'Node1 knows Node2'),
        make_edge('e2', '2', '3', 'LIKES', 'Node2 likes Node3'),
    ]
    incoming = [
        make_edge('e3', '1', '3', 'WORKS_WITH', 'Node1 works with Node3'),
        make_edge('e4', '1', '2', 'DISLIKES', 'Node1 dislikes Node2'),
    ]

    return {'nodes': nodes, 'existing_edges': existing, 'new_edges': incoming}
def test_prepare_edges_for_invalidation_basic():
    """Every edge whose endpoints are known is paired with its own source and target nodes."""
    test_data = create_test_data()
    existing_edges_pending_invalidation, new_edges_with_nodes = prepare_edges_for_invalidation(
        test_data['existing_edges'], test_data['new_edges'], test_data['nodes']
    )

    assert len(existing_edges_pending_invalidation) == 2
    assert len(new_edges_with_nodes) == 2

    # Each result is a (source node, edge, target node) triple. Checking only the element
    # types would accept a triple wired to the wrong nodes, so also check that the nodes
    # are really the edge's endpoints.
    for source_node, edge, target_node in existing_edges_pending_invalidation + new_edges_with_nodes:
        assert isinstance(source_node, EntityNode)
        assert isinstance(edge, EntityEdge)
        assert isinstance(target_node, EntityNode)
        assert source_node.uuid == edge.source_node_uuid
        assert target_node.uuid == edge.target_node_uuid

    # Existing and new edges must land in their own result lists, not be swapped or mixed.
    assert {edge.uuid for _, edge, _ in existing_edges_pending_invalidation} == {
        edge.uuid for edge in test_data['existing_edges']
    }
    assert {edge.uuid for _, edge, _ in new_edges_with_nodes} == {
        edge.uuid for edge in test_data['new_edges']
    }
def test_prepare_edges_for_invalidation_no_existing_edges():
    """With no existing edges supplied, only the new edges come back paired with nodes."""
    fixture = create_test_data()

    pending, paired_new = prepare_edges_for_invalidation(
        [], fixture['new_edges'], fixture['nodes']
    )

    assert (len(pending), len(paired_new)) == (0, 2)
def test_prepare_edges_for_invalidation_no_new_edges():
    """With no new edges supplied, only the existing edges come back paired with nodes."""
    fixture = create_test_data()

    pending, paired_new = prepare_edges_for_invalidation(
        fixture['existing_edges'], [], fixture['nodes']
    )

    assert (len(pending), len(paired_new)) == (2, 0)
def test_prepare_edges_for_invalidation_missing_nodes():
    """Edges with an endpoint absent from ``nodes`` are dropped from both result lists."""
    test_data = create_test_data()
    # Remove the last node to simulate a missing node; edges touching it become unresolvable.
    nodes = test_data['nodes'][:-1]
    missing_uuid = test_data['nodes'][-1].uuid

    existing_edges_pending_invalidation, new_edges_with_nodes = prepare_edges_for_invalidation(
        test_data['existing_edges'], test_data['new_edges'], nodes
    )

    assert len(existing_edges_pending_invalidation) == 1
    assert len(new_edges_with_nodes) == 1

    def resolvable_uuids(edges):
        # Uuids of the edges whose source and target are both still in `nodes`.
        return {
            edge.uuid
            for edge in edges
            if missing_uuid not in (edge.source_node_uuid, edge.target_node_uuid)
        }

    # Counts alone would also pass if the wrong edge were kept; pin which ones survive.
    assert {edge.uuid for _, edge, _ in existing_edges_pending_invalidation} == resolvable_uuids(
        test_data['existing_edges']
    )
    assert {edge.uuid for _, edge, _ in new_edges_with_nodes} == resolvable_uuids(
        test_data['new_edges']
    )
def test_prepare_invalidation_context():
    """Existing and new triplets are rendered with uuid, endpoint names, edge name and fact."""
    now = datetime.now()

    # Create nodes
    node1 = EntityNode(uuid='1', name='Node1', labels=['Person'], created_at=now, group_id='1')
    node2 = EntityNode(uuid='2', name='Node2', labels=['Person'], created_at=now, group_id='1')
    node3 = EntityNode(uuid='3', name='Node3', labels=['Person'], created_at=now, group_id='1')

    # Create edges. Their uuids must not occur anywhere else in the rendered strings.
    # Short ids such as 'e1'/'e2' are substrings of the node names 'Node1'/'Node2'. Those
    # names are themselves asserted to be present, so `edge.uuid in ...` would pass even
    # if the uuid were never rendered.
    edge1 = EntityEdge(
        uuid='edge-uuid-1',
        source_node_uuid='1',
        target_node_uuid='2',
        name='KNOWS',
        fact='Node1 knows Node2',
        created_at=now,
        group_id='1',
    )
    edge2 = EntityEdge(
        uuid='edge-uuid-2',
        source_node_uuid='2',
        target_node_uuid='3',
        name='LIKES',
        fact='Node2 likes Node3',
        created_at=now,
        group_id='1',
    )

    # Create NodeEdgeNodeTriplet objects: (source node, edge, target node)
    existing_edge = (node1, edge1, node2)
    new_edge = (node2, edge2, node3)

    # Prepare test input
    existing_edges = [existing_edge]
    new_edges = [new_edge]

    # Create a current episode and previous episodes
    current_episode = EpisodicNode(
        name='Current Episode',
        content='This is the current episode content.',
        created_at=now,
        valid_at=now,
        source=EpisodeType.message,
        source_description='Test episode for unit testing',
        group_id='1',
    )
    previous_episodes = [
        EpisodicNode(
            name='Previous Episode 1',
            content='This is the content of previous episode 1.',
            created_at=now - timedelta(days=1),
            valid_at=now - timedelta(days=1),
            source=EpisodeType.message,
            source_description='Test previous episode 1 for unit testing',
            group_id='1',
        ),
        EpisodicNode(
            name='Previous Episode 2',
            content='This is the content of previous episode 2.',
            created_at=now - timedelta(days=2),
            valid_at=now - timedelta(days=2),
            source=EpisodeType.message,
            source_description='Test previous episode 2 for unit testing',
            group_id='1',
        ),
    ]

    # Call the function
    result = prepare_invalidation_context(
        existing_edges, new_edges, current_episode, previous_episodes
    )

    # Assert the result
    assert isinstance(result, dict)
    assert 'existing_edges' in result
    assert 'new_edges' in result
    assert 'current_episode' in result
    assert 'previous_episodes' in result
    assert len(result['existing_edges']) == 1
    assert len(result['new_edges']) == 1
    assert result['current_episode'] == current_episode.content
    assert len(result['previous_episodes']) == 2

    # Check the format of the existing edge
    existing_edge_str = result['existing_edges'][0]
    assert edge1.uuid in existing_edge_str
    assert node1.name in existing_edge_str
    assert edge1.name in existing_edge_str
    assert node2.name in existing_edge_str
    assert edge1.fact in existing_edge_str

    # Check the format of the new edge
    new_edge_str = result['new_edges'][0]
    assert edge2.uuid in new_edge_str
    assert node2.name in new_edge_str
    assert edge2.name in new_edge_str
    assert node3.name in new_edge_str
    assert edge2.fact in new_edge_str
def test_prepare_invalidation_context_empty_input():
    """Empty edge and episode lists yield empty sections plus the current episode content."""
    timestamp = datetime.now()
    episode = EpisodicNode(
        name='Current Episode',
        content='Empty episode',
        created_at=timestamp,
        valid_at=timestamp,
        source=EpisodeType.message,
        source_description='Test empty episode for unit testing',
        group_id='1',
    )

    context = prepare_invalidation_context([], [], episode, [])

    assert isinstance(context, dict)
    for section in ('existing_edges', 'new_edges', 'current_episode', 'previous_episodes'):
        assert section in context
    # Every list-valued section must be empty when nothing was supplied.
    for section in ('existing_edges', 'new_edges', 'previous_episodes'):
        assert len(context[section]) == 0
    assert context['current_episode'] == episode.content
def test_prepare_invalidation_context_sorting():
    """Existing edges are rendered newest-first by ``created_at``."""
    now = datetime.now()

    # Create nodes
    node1 = EntityNode(uuid='1', name='Node1', labels=['Person'], created_at=now, group_id='1')
    node2 = EntityNode(uuid='2', name='Node2', labels=['Person'], created_at=now, group_id='1')

    # Create edges with different timestamps: edge2 is one hour newer than edge1.
    edge1 = EntityEdge(
        uuid='e1',
        source_node_uuid='1',
        target_node_uuid='2',
        name='KNOWS',
        fact='Node1 knows Node2',
        created_at=now,
        group_id='1',
    )
    edge2 = EntityEdge(
        uuid='e2',
        source_node_uuid='2',
        target_node_uuid='1',
        name='LIKES',
        fact='Node2 likes Node1',
        created_at=now + timedelta(hours=1),
        group_id='1',
    )

    edge_with_nodes1 = (node1, edge1, node2)
    edge_with_nodes2 = (node2, edge2, node1)

    # Prepare test input (oldest first, so a correct result must reorder it)
    existing_edges = [edge_with_nodes1, edge_with_nodes2]

    # Create a current episode and previous episodes
    current_episode = EpisodicNode(
        name='Current Episode',
        content='This is the current episode content.',
        created_at=now,
        valid_at=now,
        source=EpisodeType.message,
        source_description='Test episode for unit testing',
        group_id='1',
    )
    previous_episodes = [
        EpisodicNode(
            name='Previous Episode',
            content='This is the content of a previous episode.',
            created_at=now - timedelta(days=1),
            valid_at=now - timedelta(days=1),
            source=EpisodeType.message,
            source_description='Test previous episode for unit testing',
            group_id='1',
        ),
    ]

    # Call the function
    result = prepare_invalidation_context(existing_edges, [], current_episode, previous_episodes)

    # Assert the result
    assert len(result['existing_edges']) == 2
    first, second = result['existing_edges']

    # Identify each rendered edge by its fact text, not by `edge.uuid in ...`. The uuids
    # 'e1'/'e2' are substrings of 'Node1'/'Node2', and both rendered strings mention both
    # nodes, so a uuid check would pass even with the order reversed.
    assert edge2.fact in first  # The newer edge should be first
    assert edge1.fact not in first
    assert edge1.fact in second  # The older edge should be second
    assert edge2.fact not in second

    assert result['current_episode'] == current_episode.content
    assert len(result['previous_episodes']) == 1
    assert result['previous_episodes'][0] == previous_episodes[0].content
class TestExtractDateStringsFromEdge(unittest.TestCase):
    """Rendering of an edge's ``valid_at`` / ``invalid_at`` into a date summary string."""

    def generate_entity_edge(self, valid_at, invalid_at):
        """Return a KNOWS edge from node '1' to node '2' with the given validity bounds."""
        return EntityEdge(
            source_node_uuid='1',
            target_node_uuid='2',
            name='KNOWS',
            fact='Node1 knows Node2',
            created_at=datetime.now(),
            valid_at=valid_at,
            invalid_at=invalid_at,
            group_id='1',
        )

    def _assert_date_string(self, valid_at, invalid_at, expected):
        # Shared driver: build an edge with the given bounds and compare its summary.
        edge = self.generate_entity_edge(valid_at, invalid_at)
        self.assertEqual(extract_date_strings_from_edge(edge), expected)

    def test_both_dates_present(self):
        self._assert_date_string(
            datetime(2024, 1, 1, 12, 0),
            datetime(2024, 1, 2, 12, 0),
            'Start Date: 2024-01-01T12:00:00 (End Date: 2024-01-02T12:00:00)',
        )

    def test_only_valid_at_present(self):
        self._assert_date_string(
            datetime(2024, 1, 1, 12, 0),
            None,
            'Start Date: 2024-01-01T12:00:00',
        )

    def test_only_invalid_at_present(self):
        self._assert_date_string(
            None,
            datetime(2024, 1, 2, 12, 0),
            ' (End Date: 2024-01-02T12:00:00)',
        )

    def test_no_dates_present(self):
        self._assert_date_string(None, None, '')
# Run the tests when this file is executed directly as a script.
if __name__ == '__main__':
    # pytest.main returns an exit code; forward it so test failures give a non-zero
    # process status instead of being silently discarded (the script used to always exit 0).
    raise SystemExit(pytest.main([__file__]))