mirror of
https://github.com/getzep/graphiti.git
synced 2024-09-08 19:13:11 +03:00
* set and retrieve group ids * update add episode with group id support * add episode and search functional * update bulk * mypy updates * remove unused imports * update unit tests * unit tests * add optional uuid field * format * mypy * ellipsis
369 lines
12 KiB
Python
369 lines
12 KiB
Python
"""
|
|
Copyright 2024, Zep Software, Inc.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
"""
|
|
|
|
import unittest
|
|
from datetime import datetime, timedelta
|
|
|
|
import pytest
|
|
|
|
from graphiti_core.edges import EntityEdge
|
|
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
|
|
from graphiti_core.utils.maintenance.temporal_operations import (
|
|
extract_date_strings_from_edge,
|
|
prepare_edges_for_invalidation,
|
|
prepare_invalidation_context,
|
|
)
|
|
|
|
|
|
# Helper function to create test data
|
|
def create_test_data():
    """Build the three-node fixture graph shared by the prepare_edges_for_invalidation tests.

    Returns:
        A dict with:
          - 'nodes': Person entities Node1, Node2, Node3 (uuids '1', '2', '3').
          - 'existing_edges': e1 (1 -> 2, KNOWS) and e2 (2 -> 3, LIKES).
          - 'new_edges': e3 (1 -> 3, WORKS_WITH) and e4 (1 -> 2, DISLIKES).
        Every node and edge shares a single creation timestamp and group_id '1'.
    """
    timestamp = datetime.now()

    def make_edge(edge_uuid, source_uuid, target_uuid, relation, fact):
        # Fixture edges differ only in identity, endpoints and wording.
        return EntityEdge(
            uuid=edge_uuid,
            source_node_uuid=source_uuid,
            target_node_uuid=target_uuid,
            name=relation,
            fact=fact,
            created_at=timestamp,
            group_id='1',
        )

    graph_nodes = [
        EntityNode(
            uuid=node_uuid,
            name=node_name,
            labels=['Person'],
            created_at=timestamp,
            group_id='1',
        )
        for node_uuid, node_name in (('1', 'Node1'), ('2', 'Node2'), ('3', 'Node3'))
    ]

    previously_stored = [
        make_edge('e1', '1', '2', 'KNOWS', 'Node1 knows Node2'),
        make_edge('e2', '2', '3', 'LIKES', 'Node2 likes Node3'),
    ]
    freshly_extracted = [
        make_edge('e3', '1', '3', 'WORKS_WITH', 'Node1 works with Node3'),
        make_edge('e4', '1', '2', 'DISLIKES', 'Node1 dislikes Node2'),
    ]

    return {
        'nodes': graph_nodes,
        'existing_edges': previously_stored,
        'new_edges': freshly_extracted,
    }
|
|
|
|
|
|
def test_prepare_edges_for_invalidation_basic():
    """Every edge whose endpoints are all known is returned as (source, edge, target)."""
    test_data = create_test_data()

    existing_edges_pending_invalidation, new_edges_with_nodes = prepare_edges_for_invalidation(
        test_data['existing_edges'], test_data['new_edges'], test_data['nodes']
    )

    assert len(existing_edges_pending_invalidation) == 2
    assert len(new_edges_with_nodes) == 2

    # Each input edge must come back in the list matching its origin.
    assert {edge.uuid for _, edge, _ in existing_edges_pending_invalidation} == {'e1', 'e2'}
    assert {edge.uuid for _, edge, _ in new_edges_with_nodes} == {'e3', 'e4'}

    # Check that the edges are correctly associated with nodes: beyond the types, the
    # wrapped nodes must be the edge's actual source and target endpoints.
    for source_node, edge, target_node in (
        existing_edges_pending_invalidation + new_edges_with_nodes
    ):
        assert isinstance(source_node, EntityNode)
        assert isinstance(edge, EntityEdge)
        assert isinstance(target_node, EntityNode)
        assert source_node.uuid == edge.source_node_uuid
        assert target_node.uuid == edge.target_node_uuid
|
|
|
|
|
|
def test_prepare_edges_for_invalidation_no_existing_edges():
    """With no existing edges, nothing is pending invalidation and both new edges come back."""
    fixture = create_test_data()
    fresh_edges, known_nodes = fixture['new_edges'], fixture['nodes']

    pending, new_triplets = prepare_edges_for_invalidation([], fresh_edges, known_nodes)

    assert (len(pending), len(new_triplets)) == (0, 2)
|
|
|
|
|
|
def test_prepare_edges_for_invalidation_no_new_edges():
    """With no new edges, both existing edges are still returned and the new list is empty."""
    fixture = create_test_data()
    stored_edges, known_nodes = fixture['existing_edges'], fixture['nodes']

    pending, new_triplets = prepare_edges_for_invalidation(stored_edges, [], known_nodes)

    assert (len(pending), len(new_triplets)) == (2, 0)
|
|
|
|
|
|
def test_prepare_edges_for_invalidation_missing_nodes():
    """Edges with an endpoint absent from ``nodes`` are dropped from both result lists."""
    test_data = create_test_data()

    # Remove Node3 to simulate a missing node: e2 (2 -> 3) and e3 (1 -> 3) can no longer be
    # resolved, while e1 and e4 (both 1 -> 2) still can.
    nodes = test_data['nodes'][:-1]

    existing_edges_pending_invalidation, new_edges_with_nodes = prepare_edges_for_invalidation(
        test_data['existing_edges'], test_data['new_edges'], nodes
    )

    assert len(existing_edges_pending_invalidation) == 1
    assert len(new_edges_with_nodes) == 1

    # Counting alone would pass even if the wrong edge survived; pin which ones were kept.
    assert existing_edges_pending_invalidation[0][1].uuid == 'e1'
    assert new_edges_with_nodes[0][1].uuid == 'e4'
|
|
|
|
|
|
def test_prepare_invalidation_context():
    """prepare_invalidation_context renders edge triplets and episodes into a context dict.

    Each rendered edge string must mention the edge's uuid, name and fact, together with
    the names of both endpoint nodes. The current episode appears as its content, and
    there is one entry per previous episode.
    """
    now = datetime.now()

    node1, node2, node3 = (
        EntityNode(uuid=node_id, name=label, labels=['Person'], created_at=now, group_id='1')
        for node_id, label in (('1', 'Node1'), ('2', 'Node2'), ('3', 'Node3'))
    )

    edge1, edge2 = (
        EntityEdge(
            uuid=edge_id,
            source_node_uuid=source_id,
            target_node_uuid=target_id,
            name=relation,
            fact=fact,
            created_at=now,
            group_id='1',
        )
        for edge_id, source_id, target_id, relation, fact in (
            ('e1', '1', '2', 'KNOWS', 'Node1 knows Node2'),
            ('e2', '2', '3', 'LIKES', 'Node2 likes Node3'),
        )
    )

    # NodeEdgeNodeTriplet objects are laid out as (source node, edge, target node).
    existing_edges = [(node1, edge1, node2)]
    new_edges = [(node2, edge2, node3)]

    def episode(name, content, description, when):
        # Message episodes whose creation and validity times coincide.
        return EpisodicNode(
            name=name,
            content=content,
            created_at=when,
            valid_at=when,
            source=EpisodeType.message,
            source_description=description,
            group_id='1',
        )

    current_episode = episode(
        'Current Episode',
        'This is the current episode content.',
        'Test episode for unit testing',
        now,
    )
    previous_episodes = [
        episode(
            'Previous Episode 1',
            'This is the content of previous episode 1.',
            'Test previous episode 1 for unit testing',
            now - timedelta(days=1),
        ),
        episode(
            'Previous Episode 2',
            'This is the content of previous episode 2.',
            'Test previous episode 2 for unit testing',
            now - timedelta(days=2),
        ),
    ]

    result = prepare_invalidation_context(
        existing_edges, new_edges, current_episode, previous_episodes
    )

    assert isinstance(result, dict)
    for key in ('existing_edges', 'new_edges', 'current_episode', 'previous_episodes'):
        assert key in result
    assert len(result['existing_edges']) == 1
    assert len(result['new_edges']) == 1
    assert result['current_episode'] == current_episode.content
    assert len(result['previous_episodes']) == 2

    # Each rendered edge must carry its uuid, both endpoint names, its name and its fact.
    for rendered, (source, edge, target) in (
        (result['existing_edges'][0], existing_edges[0]),
        (result['new_edges'][0], new_edges[0]),
    ):
        for fragment in (edge.uuid, source.name, edge.name, target.name, edge.fact):
            assert fragment in rendered
|
|
|
|
|
|
def test_prepare_invalidation_context_empty_input():
    """With no edges and no history, the context still has every key, with empty lists."""
    moment = datetime.now()
    episode = EpisodicNode(
        name='Current Episode',
        content='Empty episode',
        created_at=moment,
        valid_at=moment,
        source=EpisodeType.message,
        source_description='Test empty episode for unit testing',
        group_id='1',
    )

    context = prepare_invalidation_context([], [], episode, [])

    assert isinstance(context, dict)
    expected_keys = ('existing_edges', 'new_edges', 'current_episode', 'previous_episodes')
    assert all(key in context for key in expected_keys)
    for collection_key in ('existing_edges', 'new_edges', 'previous_episodes'):
        assert len(context[collection_key]) == 0
    assert context['current_episode'] == episode.content
|
|
|
|
|
|
def test_prepare_invalidation_context_sorting():
    """Existing edges are rendered newest first, ordered by their created_at timestamp."""
    now = datetime.now()
    an_hour_later = now + timedelta(hours=1)
    yesterday = now - timedelta(days=1)

    node1 = EntityNode(uuid='1', name='Node1', labels=['Person'], created_at=now, group_id='1')
    node2 = EntityNode(uuid='2', name='Node2', labels=['Person'], created_at=now, group_id='1')

    # Two edges between the same pair of nodes; the second is created one hour later.
    older_edge = EntityEdge(
        uuid='e1',
        source_node_uuid='1',
        target_node_uuid='2',
        name='KNOWS',
        fact='Node1 knows Node2',
        created_at=now,
        group_id='1',
    )
    newer_edge = EntityEdge(
        uuid='e2',
        source_node_uuid='2',
        target_node_uuid='1',
        name='LIKES',
        fact='Node2 likes Node1',
        created_at=an_hour_later,
        group_id='1',
    )

    # The older edge is supplied first, so the assertions below fail unless the
    # function reorders the edges.
    existing_edges = [(node1, older_edge, node2), (node2, newer_edge, node1)]

    current_episode = EpisodicNode(
        name='Current Episode',
        content='This is the current episode content.',
        created_at=now,
        valid_at=now,
        source=EpisodeType.message,
        source_description='Test episode for unit testing',
        group_id='1',
    )
    previous_episodes = [
        EpisodicNode(
            name='Previous Episode',
            content='This is the content of a previous episode.',
            created_at=yesterday,
            valid_at=yesterday,
            source=EpisodeType.message,
            source_description='Test previous episode for unit testing',
            group_id='1',
        ),
    ]

    result = prepare_invalidation_context(existing_edges, [], current_episode, previous_episodes)

    rendered_edges = result['existing_edges']
    assert len(rendered_edges) == 2
    assert newer_edge.uuid in rendered_edges[0]  # newest edge first
    assert older_edge.uuid in rendered_edges[1]  # oldest edge last
    assert result['current_episode'] == current_episode.content
    assert len(result['previous_episodes']) == 1
    assert result['previous_episodes'][0] == previous_episodes[0].content
|
|
|
|
|
|
class TestExtractDateStringsFromEdge(unittest.TestCase):
    """Rendering of an edge's valid_at / invalid_at as 'Start Date: ... (End Date: ...)'."""

    def generate_entity_edge(self, valid_at, invalid_at):
        """Return a KNOWS edge from node '1' to node '2' with the given validity bounds."""
        return EntityEdge(
            source_node_uuid='1',
            target_node_uuid='2',
            name='KNOWS',
            fact='Node1 knows Node2',
            created_at=datetime.now(),
            valid_at=valid_at,
            invalid_at=invalid_at,
            group_id='1',
        )

    def _assert_rendered(self, valid_at, invalid_at, expected):
        # Build an edge with the given bounds and compare its rendered date string.
        rendered = extract_date_strings_from_edge(self.generate_entity_edge(valid_at, invalid_at))
        self.assertEqual(rendered, expected)

    def test_both_dates_present(self):
        self._assert_rendered(
            datetime(2024, 1, 1, 12, 0),
            datetime(2024, 1, 2, 12, 0),
            'Start Date: 2024-01-01T12:00:00 (End Date: 2024-01-02T12:00:00)',
        )

    def test_only_valid_at_present(self):
        self._assert_rendered(
            datetime(2024, 1, 1, 12, 0),
            None,
            'Start Date: 2024-01-01T12:00:00',
        )

    def test_only_invalid_at_present(self):
        self._assert_rendered(
            None,
            datetime(2024, 1, 2, 12, 0),
            ' (End Date: 2024-01-02T12:00:00)',
        )

    def test_no_dates_present(self):
        self._assert_rendered(None, None, '')
|
|
|
|
|
|
# Run the tests
|
|
if __name__ == '__main__':
    # pytest.main returns the session's exit code; propagate it so that running this file
    # directly exits non-zero when a test fails, instead of always exiting 0.
    raise SystemExit(pytest.main([__file__]))
|