mirror of
				https://github.com/getzep/graphiti.git
				synced 2024-09-08 19:13:11 +03:00 
			
		
		
		
	 42fb590606
			
		
	
	42fb590606
	
	
	
		
			
			* set and retrieve group ids * update add episode with group id support * add episode and search functional * update bulk * mypy updates * remove unused imports * update unit tests * unit tests * add optional uuid field * format * mypy * ellipsis
		
			
				
	
	
		
			160 lines
		
	
	
		
			6.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			160 lines
		
	
	
		
			6.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| from unittest.mock import AsyncMock, patch
 | |
| 
 | |
| import pytest
 | |
| 
 | |
| from graphiti_core.nodes import EntityNode
 | |
| from graphiti_core.search.search_utils import hybrid_node_search
 | |
| 
 | |
| 
 | |
@pytest.mark.asyncio
async def test_hybrid_node_search_deduplication():
    """A node found by both fulltext and similarity search is returned only once.

    Two queries and two embeddings are supplied. 'Alice' (uuid '1') is
    returned by both the first fulltext call and the first similarity call,
    so the merged result must contain exactly the three distinct nodes
    Alice, Bob and Charlie.
    """
    # Stand-in for the database driver; the search helpers that would use it
    # are patched out below.
    mock_driver = AsyncMock()

    with patch(
        'graphiti_core.search.search_utils.entity_fulltext_search'
    ) as mock_fulltext_search, patch(
        'graphiti_core.search.search_utils.entity_similarity_search'
    ) as mock_similarity_search:
        # side_effect lists hand out one result list per call, in call order;
        # each mock is expected to be called twice (asserted at the end).
        mock_fulltext_search.side_effect = [
            [EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1')],
            [EntityNode(uuid='2', name='Bob', labels=['Entity'], group_id='1')],
        ]
        mock_similarity_search.side_effect = [
            [EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1')],
            [EntityNode(uuid='3', name='Charlie', labels=['Entity'], group_id='1')],
        ]

        queries = ['Alice', 'Bob']
        embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
        results = await hybrid_node_search(queries, embeddings, mock_driver)

        # The length check catches duplicates that the set comparisons would hide.
        assert len(results) == 3
        assert {node.uuid for node in results} == {'1', '2', '3'}
        assert {node.name for node in results} == {'Alice', 'Bob', 'Charlie'}

        # Both side_effect lists must have been fully consumed.
        assert mock_fulltext_search.call_count == 2
        assert mock_similarity_search.call_count == 2
 | |
| 
 | |
| 
 | |
@pytest.mark.asyncio
async def test_hybrid_node_search_empty_results():
    """hybrid_node_search returns no nodes when neither search finds a match.

    The searches must still be executed: the empty result has to come from
    empty search hits, not from the searches being skipped.
    """
    # Stand-in for the database driver; the search helpers are patched out.
    mock_driver = AsyncMock()

    with patch(
        'graphiti_core.search.search_utils.entity_fulltext_search'
    ) as mock_fulltext_search, patch(
        'graphiti_core.search.search_utils.entity_similarity_search'
    ) as mock_similarity_search:
        # return_value (rather than side_effect) makes every call yield [].
        mock_fulltext_search.return_value = []
        mock_similarity_search.return_value = []

        queries = ['NonExistent']
        embeddings = [[0.1, 0.2, 0.3]]
        results = await hybrid_node_search(queries, embeddings, mock_driver)

        assert len(results) == 0
        # One query and one embedding -> one call to each search helper,
        # matching the call counts the other single-query tests expect.
        assert mock_fulltext_search.call_count == 1
        assert mock_similarity_search.call_count == 1
 | |
| 
 | |
| 
 | |
@pytest.mark.asyncio
async def test_hybrid_node_search_only_fulltext():
    """With no embeddings supplied, only the fulltext search is used.

    The single fulltext hit must come back as the sole result, and the
    similarity search must never be invoked.
    """
    driver = AsyncMock()
    alice = EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1')

    fulltext_target = 'graphiti_core.search.search_utils.entity_fulltext_search'
    similarity_target = 'graphiti_core.search.search_utils.entity_similarity_search'

    with patch(fulltext_target) as fulltext, patch(similarity_target) as similarity:
        fulltext.return_value = [alice]
        similarity.return_value = []

        # An empty embeddings list means there is nothing to run similarity on.
        results = await hybrid_node_search(['Alice'], [], driver)

        assert len(results) == 1
        assert results[0].name == 'Alice'
        fulltext.assert_called_once()
        similarity.assert_not_called()
 | |
| 
 | |
| 
 | |
@pytest.mark.asyncio
async def test_hybrid_node_search_with_limit():
    """``limit`` is forwarded to each search helper, doubled, and not applied to the merged result.

    The mocked searches ignore the limit they receive and return two distinct
    nodes each. Because hybrid_node_search does not truncate its merged
    output, all four nodes are returned even though ``limit == 1``.
    """
    # Stand-in for the database driver; the search helpers are patched out.
    mock_driver = AsyncMock()

    with patch(
        'graphiti_core.search.search_utils.entity_fulltext_search'
    ) as mock_fulltext_search, patch(
        'graphiti_core.search.search_utils.entity_similarity_search'
    ) as mock_similarity_search:
        mock_fulltext_search.return_value = [
            EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1'),
            EntityNode(uuid='2', name='Bob', labels=['Entity'], group_id='1'),
        ]
        mock_similarity_search.return_value = [
            EntityNode(uuid='3', name='Charlie', labels=['Entity'], group_id='1'),
            EntityNode(uuid='4', name='David', labels=['Entity'], group_id='1'),
        ]

        queries = ['Test']
        embeddings = [[0.1, 0.2, 0.3]]
        group_ids = ['1']
        limit = 1
        results = await hybrid_node_search(queries, embeddings, mock_driver, group_ids, limit)

        # No truncation of the merged, deduplicated result: all four come back.
        assert len(results) == 4

        # Each helper runs exactly once and receives twice the caller's limit
        # (presumably to over-fetch ahead of deduplication; confirm against
        # hybrid_node_search).
        mock_fulltext_search.assert_called_once_with('Test', mock_driver, group_ids, 2 * limit)
        mock_similarity_search.assert_called_once_with(
            [0.1, 0.2, 0.3], mock_driver, group_ids, 2 * limit
        )
 | |
| 
 | |
| 
 | |
@pytest.mark.asyncio
async def test_hybrid_node_search_with_limit_and_duplicates():
    """Results from both searches are merged and deduplicated; ``limit`` is forwarded doubled.

    Each mocked search returns two nodes, and 'Alice' (uuid '1') appears in
    both lists. The merged result must contain three distinct nodes. Each
    search helper must receive ``2 * limit``, not ``limit`` itself.
    """
    # Stand-in for the database driver; the search helpers are patched out.
    mock_driver = AsyncMock()

    with patch(
        'graphiti_core.search.search_utils.entity_fulltext_search'
    ) as mock_fulltext_search, patch(
        'graphiti_core.search.search_utils.entity_similarity_search'
    ) as mock_similarity_search:
        mock_fulltext_search.return_value = [
            EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1'),
            EntityNode(uuid='2', name='Bob', labels=['Entity'], group_id='1'),
        ]
        mock_similarity_search.return_value = [
            EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1'),  # Duplicate
            EntityNode(uuid='3', name='Charlie', labels=['Entity'], group_id='1'),
        ]

        queries = ['Test']
        embeddings = [[0.1, 0.2, 0.3]]
        group_ids = ['1']
        limit = 2
        results = await hybrid_node_search(queries, embeddings, mock_driver, group_ids, limit)

        # 2 fulltext hits + 2 similarity hits, with Alice in both, leaves three
        # distinct nodes after deduplication.
        assert len(results) == 3
        assert {node.name for node in results} == {'Alice', 'Bob', 'Charlie'}

        # Each helper runs exactly once and receives twice the caller's limit
        # (presumably to over-fetch ahead of deduplication; confirm against
        # hybrid_node_search).
        mock_fulltext_search.assert_called_once_with('Test', mock_driver, group_ids, 2 * limit)
        mock_similarity_search.assert_called_once_with(
            [0.1, 0.2, 0.3], mock_driver, group_ids, 2 * limit
        )
 |