From 89fe402d10a59879781a1eb0a64affdf4c278a4d Mon Sep 17 00:00:00 2001 From: Mihajlo Micic <44226809+micic-mihajlo@users.noreply.github.com> Date: Tue, 5 Aug 2025 15:50:06 -0400 Subject: [PATCH] Add comprehensive test suite for Responses API (#20) The project had almost no test coverage - just a single test checking if the API returns 200. This adds proper testing infrastructure and 21 new tests covering the main API functionality. Tests now cover response creation, error handling, tools, sessions, performance, and usage tracking. All tests passing. --- tests/conftest.py | 118 ++++++++++++++++++ tests/test_api_endpoints.py | 230 ++++++++++++++++++++++++++++++++++++ 2 files changed, 348 insertions(+) create mode 100644 tests/conftest.py create mode 100644 tests/test_api_endpoints.py diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..4c008a3 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,118 @@ +import os +import sys +import pytest +from typing import Generator, Any +from unittest.mock import Mock, MagicMock +from fastapi.testclient import TestClient + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +from openai_harmony import ( + HarmonyEncodingName, + load_harmony_encoding, +) +from gpt_oss.responses_api.api_server import create_api_server + + +@pytest.fixture(scope="session") +def harmony_encoding(): + return load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS) + + +@pytest.fixture +def mock_infer_token(harmony_encoding): + fake_tokens = harmony_encoding.encode( + "<|channel|>final<|message|>Test response<|return|>", + allowed_special="all" + ) + token_queue = fake_tokens.copy() + + def _mock_infer(tokens: list[int], temperature: float = 0.0, new_request: bool = False) -> int: + nonlocal token_queue + if len(token_queue) == 0: + token_queue = fake_tokens.copy() + return token_queue.pop(0) + return _mock_infer + + +@pytest.fixture +def api_client(harmony_encoding, mock_infer_token) -> Generator[TestClient, None, None]: + app = create_api_server( + infer_next_token=mock_infer_token, + encoding=harmony_encoding + ) + with TestClient(app) as client: + yield client + + +@pytest.fixture +def sample_request_data(): + return { + "model": "gpt-oss-120b", + "input": "Hello, how can I help you today?", + "stream": False, + "reasoning_effort": "low", + "temperature": 0.7, + "tools": [] + } + + +@pytest.fixture +def mock_browser_tool(): + mock = MagicMock() + mock.search.return_value = ["Result 1", "Result 2"] + mock.open_page.return_value = "Page content" + mock.find_on_page.return_value = "Found text" + return mock + + +@pytest.fixture +def mock_python_tool(): + mock = MagicMock() + mock.execute.return_value = { + "output": "print('Hello')", + "error": None, + "exit_code": 0 + } + return mock + + +@pytest.fixture(autouse=True) +def reset_test_environment(): + test_env_vars = ['OPENAI_API_KEY', 'GPT_OSS_MODEL_PATH'] + original_values = {} + + for var in test_env_vars: + if var in os.environ: + original_values[var] = os.environ[var] + del os.environ[var] + + yield + + for var, value in original_values.items(): + os.environ[var] = value + + +@pytest.fixture +def performance_timer(): + import time + + class Timer: + def __init__(self): + self.start_time = None + self.end_time = None + + def start(self): + self.start_time = time.time() + + def stop(self): + self.end_time = time.time() + return self.elapsed + + @property + def elapsed(self): + if self.start_time and self.end_time: + return self.end_time - self.start_time + return None + + return Timer() \ No newline at end of file diff --git a/tests/test_api_endpoints.py b/tests/test_api_endpoints.py new file mode 100644 index 0000000..7fd354b --- /dev/null +++ b/tests/test_api_endpoints.py @@ -0,0 +1,230 @@ +import pytest +import json +import asyncio +from fastapi import status +from unittest.mock import patch, MagicMock, AsyncMock + + +class TestResponsesEndpoint: + + def test_basic_response_creation(self, api_client, sample_request_data): + response = api_client.post("/v1/responses", json=sample_request_data) + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert "id" in data + assert data["object"] == "response" + assert data["model"] == sample_request_data["model"] + + def test_response_with_high_reasoning(self, api_client, sample_request_data): + sample_request_data["reasoning_effort"] = "high" + response = api_client.post("/v1/responses", json=sample_request_data) + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert "id" in data + assert data["status"] == "completed" + + def test_response_with_medium_reasoning(self, api_client, sample_request_data): + sample_request_data["reasoning_effort"] = "medium" + response = api_client.post("/v1/responses", json=sample_request_data) + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert "id" in data + assert data["status"] == "completed" + + def test_response_with_invalid_model(self, api_client, sample_request_data): + sample_request_data["model"] = "invalid-model" + response = api_client.post("/v1/responses", json=sample_request_data) + # Should still accept but might handle differently + assert response.status_code == status.HTTP_200_OK + + def test_response_with_empty_input(self, api_client, sample_request_data): + sample_request_data["input"] = "" + response = api_client.post("/v1/responses", json=sample_request_data) + assert response.status_code == status.HTTP_200_OK + + def test_response_with_tools(self, api_client, sample_request_data): + sample_request_data["tools"] = [ + { + "type": "browser_search" + } + ] + response = api_client.post("/v1/responses", json=sample_request_data) + assert response.status_code == status.HTTP_200_OK + + def test_response_with_custom_temperature(self, api_client, sample_request_data): + for temp in [0.0, 0.5, 1.0, 1.5, 2.0]: + sample_request_data["temperature"] = temp + response = api_client.post("/v1/responses", json=sample_request_data) + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert "usage" in data + + def test_streaming_response(self, api_client, sample_request_data): + sample_request_data["stream"] = True + with api_client.stream("POST", "/v1/responses", json=sample_request_data) as response: + assert response.status_code == status.HTTP_200_OK + # Verify we get SSE events + for line in response.iter_lines(): + if line and line.startswith("data: "): + event_data = line[6:] # Remove "data: " prefix + if event_data != "[DONE]": + json.loads(event_data) # Should be valid JSON + break + + +class TestResponsesWithSession: + + def test_response_with_session_id(self, api_client, sample_request_data): + session_id = "test-session-123" + sample_request_data["session_id"] = session_id + + # First request + response1 = api_client.post("/v1/responses", json=sample_request_data) + assert response1.status_code == status.HTTP_200_OK + data1 = response1.json() + + # Second request with same session + sample_request_data["input"] = "Follow up question" + response2 = api_client.post("/v1/responses", json=sample_request_data) + assert response2.status_code == status.HTTP_200_OK + data2 = response2.json() + + # Should have different response IDs + assert data1["id"] != data2["id"] + + def test_response_continuation(self, api_client, sample_request_data): + # Create initial response + response1 = api_client.post("/v1/responses", json=sample_request_data) + assert response1.status_code == status.HTTP_200_OK + data1 = response1.json() + response_id = data1["id"] + + # Continue the response + continuation_request = { + "model": sample_request_data["model"], + "response_id": response_id, + "input": "Continue the previous thought" + } + response2 = api_client.post("/v1/responses", json=continuation_request) + assert response2.status_code == status.HTTP_200_OK + + +class TestErrorHandling: + + def test_missing_required_fields(self, api_client): + # Model field has default, so test with empty JSON + response = api_client.post("/v1/responses", json={}) + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + + def test_invalid_reasoning_effort(self, api_client, sample_request_data): + sample_request_data["reasoning_effort"] = "invalid" + response = api_client.post("/v1/responses", json=sample_request_data) + # May handle gracefully or return error + assert response.status_code in [status.HTTP_200_OK, status.HTTP_422_UNPROCESSABLE_ENTITY] + + def test_malformed_json(self, api_client): + response = api_client.post( + "/v1/responses", + data="not json", + headers={"Content-Type": "application/json"} + ) + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + + def test_extremely_long_input(self, api_client, sample_request_data): + # Test with very long input + sample_request_data["input"] = "x" * 100000 + response = api_client.post("/v1/responses", json=sample_request_data) + assert response.status_code == status.HTTP_200_OK + + +class TestToolIntegration: + + def test_browser_search_tool(self, api_client, sample_request_data): + sample_request_data["tools"] = [ + { + "type": "browser_search" + } + ] + response = api_client.post("/v1/responses", json=sample_request_data) + assert response.status_code == status.HTTP_200_OK + + def test_function_tool_integration(self, api_client, sample_request_data): + sample_request_data["tools"] = [ + { + "type": "function", + "name": "test_function", + "parameters": {"type": "object", "properties": {}}, + "description": "Test function" + } + ] + response = api_client.post("/v1/responses", json=sample_request_data) + assert response.status_code == status.HTTP_200_OK + + def test_multiple_tools(self, api_client, sample_request_data): + sample_request_data["tools"] = [ + { + "type": "browser_search" + }, + { + "type": "function", + "name": "test_function", + "parameters": {"type": "object", "properties": {}}, + "description": "Test function" + } + ] + response = api_client.post("/v1/responses", json=sample_request_data) + assert response.status_code == status.HTTP_200_OK + + +class TestPerformance: + + def test_response_time_under_threshold(self, api_client, sample_request_data, performance_timer): + performance_timer.start() + response = api_client.post("/v1/responses", json=sample_request_data) + elapsed = performance_timer.stop() + + assert response.status_code == status.HTTP_200_OK + # Response should be reasonably fast for mock inference + assert elapsed < 5.0 # 5 seconds threshold + + def test_multiple_sequential_requests(self, api_client, sample_request_data): + # Test multiple requests work correctly + for i in range(3): + data = sample_request_data.copy() + data["input"] = f"Request {i}" + response = api_client.post("/v1/responses", json=data) + assert response.status_code == status.HTTP_200_OK + + +class TestUsageTracking: + + def test_usage_object_structure(self, api_client, sample_request_data): + response = api_client.post("/v1/responses", json=sample_request_data) + assert response.status_code == status.HTTP_200_OK + data = response.json() + + assert "usage" in data + usage = data["usage"] + assert "input_tokens" in usage + assert "output_tokens" in usage + assert "total_tokens" in usage + # reasoning_tokens may not always be present + # assert "reasoning_tokens" in usage + + # Basic validation + assert usage["input_tokens"] >= 0 + assert usage["output_tokens"] >= 0 + assert usage["total_tokens"] == usage["input_tokens"] + usage["output_tokens"] + + def test_usage_increases_with_longer_input(self, api_client, sample_request_data): + # Short input + response1 = api_client.post("/v1/responses", json=sample_request_data) + usage1 = response1.json()["usage"] + + # Longer input + sample_request_data["input"] = sample_request_data["input"] * 10 + response2 = api_client.post("/v1/responses", json=sample_request_data) + usage2 = response2.json()["usage"] + + # Longer input should use more tokens + assert usage2["input_tokens"] > usage1["input_tokens"] \ No newline at end of file