format update

This commit is contained in:
blazickjp
2025-04-22 04:08:35 -07:00
parent 26fa8f3b0b
commit 4510855fc5
2 changed files with 18 additions and 15 deletions

View File

@@ -3,39 +3,40 @@
import pytest import pytest
from arxiv_mcp_server.prompts.handlers import list_prompts, get_prompt from arxiv_mcp_server.prompts.handlers import list_prompts, get_prompt
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_server_list_prompts(): async def test_server_list_prompts():
"""Test server list_prompts endpoint.""" """Test server list_prompts endpoint."""
prompts = await list_prompts() prompts = await list_prompts()
assert len(prompts) == 1 assert len(prompts) == 1
# Check that all prompts have required fields # Check that all prompts have required fields
for prompt in prompts: for prompt in prompts:
assert prompt.name assert prompt.name
assert prompt.description assert prompt.description
assert prompt.arguments is not None assert prompt.arguments is not None
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_server_get_analysis_prompt(): async def test_server_get_analysis_prompt():
"""Test server get_prompt endpoint with analysis prompt.""" """Test server get_prompt endpoint with analysis prompt."""
result = await get_prompt( result = await get_prompt("deep-paper-analysis", {"paper_id": "2401.00123"})
"deep-paper-analysis",
{"paper_id": "2401.00123"}
)
assert len(result.messages) == 1 assert len(result.messages) == 1
message = result.messages[0] message = result.messages[0]
assert message.role == "user" assert message.role == "user"
assert "2401.00123" in message.content.text assert "2401.00123" in message.content.text
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_server_get_prompt_invalid_name(): async def test_server_get_prompt_invalid_name():
"""Test server get_prompt endpoint with invalid prompt name.""" """Test server get_prompt endpoint with invalid prompt name."""
with pytest.raises(ValueError, match="Prompt not found"): with pytest.raises(ValueError, match="Prompt not found"):
await get_prompt("invalid-prompt", {}) await get_prompt("invalid-prompt", {})
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_server_get_prompt_missing_args(): async def test_server_get_prompt_missing_args():
"""Test server get_prompt endpoint with missing required arguments.""" """Test server get_prompt endpoint with missing required arguments."""
with pytest.raises(ValueError, match="Missing required argument"): with pytest.raises(ValueError, match="Missing required argument"):
await get_prompt("deep-paper-analysis", {}) await get_prompt("deep-paper-analysis", {})

View File

@@ -5,47 +5,49 @@ from typing import Dict
from arxiv_mcp_server.prompts.handlers import list_prompts, get_prompt from arxiv_mcp_server.prompts.handlers import list_prompts, get_prompt
from mcp.types import GetPromptResult, PromptMessage, TextContent from mcp.types import GetPromptResult, PromptMessage, TextContent
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_prompts(): async def test_list_prompts():
"""Test listing available prompts.""" """Test listing available prompts."""
prompts = await list_prompts() prompts = await list_prompts()
assert len(prompts) == 1 assert len(prompts) == 1
prompt_names = {p.name for p in prompts} prompt_names = {p.name for p in prompts}
expected_names = {"deep-paper-analysis"} expected_names = {"deep-paper-analysis"}
assert prompt_names == expected_names assert prompt_names == expected_names
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_paper_analysis_prompt(): async def test_get_paper_analysis_prompt():
"""Test getting paper analysis prompt.""" """Test getting paper analysis prompt."""
result = await get_prompt( result = await get_prompt("deep-paper-analysis", {"paper_id": "2401.00123"})
"deep-paper-analysis",
{"paper_id": "2401.00123"}
)
assert isinstance(result, GetPromptResult) assert isinstance(result, GetPromptResult)
assert len(result.messages) == 1 assert len(result.messages) == 1
message = result.messages[0] message = result.messages[0]
assert isinstance(message, PromptMessage) assert isinstance(message, PromptMessage)
assert message.role == "user" assert message.role == "user"
assert isinstance(message.content, TextContent) assert isinstance(message.content, TextContent)
assert "2401.00123" in message.content.text assert "2401.00123" in message.content.text
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_prompt_with_invalid_name(): async def test_get_prompt_with_invalid_name():
"""Test getting prompt with invalid name.""" """Test getting prompt with invalid name."""
with pytest.raises(ValueError, match="Prompt not found"): with pytest.raises(ValueError, match="Prompt not found"):
await get_prompt("invalid-prompt", {}) await get_prompt("invalid-prompt", {})
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_prompt_with_no_arguments(): async def test_get_prompt_with_no_arguments():
"""Test getting prompt with no arguments.""" """Test getting prompt with no arguments."""
with pytest.raises(ValueError, match="No arguments provided"): with pytest.raises(ValueError, match="No arguments provided"):
await get_prompt("deep-paper-analysis", None) await get_prompt("deep-paper-analysis", None)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_prompt_with_missing_required_argument(): async def test_get_prompt_with_missing_required_argument():
"""Test getting prompt with missing required argument.""" """Test getting prompt with missing required argument."""
with pytest.raises(ValueError, match="Missing required argument"): with pytest.raises(ValueError, match="Missing required argument"):
await get_prompt("deep-paper-analysis", {}) await get_prompt("deep-paper-analysis", {})