Files
reasoning-gym/tools/server/tests/test_endpoints.py
Andreas Köpf e2702092f4 reasoning-gym-server & cli tool (#154)
* feat: Add initial server structure with configuration, registry, and middleware

* feat: Add chain_sum dataset to experiment registry test

* fix: Update test_registry to use DatasetSpec for composite config validation

* refactor: Update Pydantic config to use json_schema_extra and ConfigDict

* feat: Add Pydantic models for API request/response data

* feat: Implement basic experiment management endpoints with tests

* feat: Implement composite configuration endpoints for experiments

* fix: Add missing DatasetConfigUpdate import in server.py

* refactor: Update dataset config update method to properly merge config updates

* fix: Correctly retrieve current dataset config in composite endpoint

* feat: Add basic CLI structure with experiments and config commands

* feat: Add initial CLI tool with basic experiment management commands

* refactor: Reorganize CLI package structure and fix import paths

* refactor: Implement initial CLI commands for experiment management

* feat: Implement HTTP client for Reasoning Gym server in RGC CLI tool

* fix: Move print statements inside try block to resolve SyntaxError

* fix: Resolve SyntaxError in edit_config function by adding missing except block

* feat: Add default app instance in server module for easier uvicorn startup

* docs: Add README.md with server and RGC tool documentation

* remove unused files

* refactor: Remove unsupported type annotation in registry.py

* refactor: Move ExperimentRegistry to coaching module and add Experiment class

* fix: Add missing CompositeDataset import in test_registry.py

* refactor: Implement lazy ASGI app creation for server initialization

* feat: Add health check command to RGC CLI for server connection

* feat: Add version tracking support to CompositeDataset

* feat: Add DatasetVersionManager for tracking dataset versions

* feat: Add entry_id metadata and score_answer_with_id method to CompositeDataset

* feat: Add entry_id metadata combining version and index

* fix: Resolve undefined variable by storing version_id before use

* test: Add comprehensive unit tests for score_answer_with_id() function

* test: Add comprehensive version tracking test for dataset config updates

* feat: Validate dataset weights are positive in CompositeDataset initialization

* feat: Add weight update and normalization methods to CompositeDataset

* refactor: Centralize weight normalization in CompositeDataset and allow zero-weight datasets

* feat: Add negative weight validation to CompositeDataset constructor

* feat: Add duplicate dataset name check in CompositeDataset and update test

* refactor: Move duplicate dataset name check inside dataset iteration loop

* refactor: Update CompositeDataset weight management to use config as source of truth

* refactor: Move duplicate dataset name check to CompositeConfig.validate()

* test: Update composite dataset weight test assertions and validation

* feat: Add methods to add and remove datasets in CompositeDataset

* refactor: Remove weight normalization and use unnormalized weights directly

* refactor: Remove redundant total weight check in update_dataset_weights

* feat: Add batch generation and scoring endpoints to server

* fix: Import BatchEntry in server.py to resolve undefined name error

* refactor: Update ReasoningGymDataset to use server for batch generation and scoring

* fix: Add missing List and Dict type imports

* feat: Add get_batch() and score_outputs() methods to RGClient

* test: Add unit tests for generate_batch and score_outputs endpoints

* refactor: Add DatasetVersionManager to Experiment class and CompositeDataset constructor

* feat: Add validation for base_index and batch_size in generate_batch endpoint

* refactor: Remove unused BatchRequest type from imports

* refactor: Convert models to use Pydantic exclusively

* test: Update scoring endpoint tests to use correct request model format

* refactor: Rename ScoreItem to AnswerItem and update related code

* feat: Update scoring endpoint to return ordered ScoringResponse with scores and entry_ids

* fix: Add missing ScoringResponse import in server.py

* move verl ppo sample with server into own file

* refactor: Use Pydantic models for get_batch() and score_outputs() in RGClient

* refactor: Update client methods to use Pydantic models for type safety

* refactor: Use Pydantic models for experiment and dataset config operations

* refactor: Clean up duplicate methods and improve error handling in main.py

* first bits of rg server use for verl

* refactor: Optimize scoring with single HTTP request in _score_output

* fix: Correct experiment creation with ExperimentCreate object

* grpo tests with server
2025-02-19 22:41:33 +01:00

278 lines
8.1 KiB
Python

"""Tests for API endpoints."""
import pytest
from fastapi.testclient import TestClient
from ..config import ServerConfig
from ..server import create_app
@pytest.fixture
def client():
    """Provide a TestClient backed by an app built from a fixed test config."""
    test_config = ServerConfig(
        host="localhost",
        port=8000,
        api_key="test-key",
        log_level="INFO",
    )
    return TestClient(create_app(test_config))
def test_health_check(client):
    """The /health endpoint reports a healthy status (no API key required)."""
    resp = client.get("/health")
    assert resp.status_code == 200
    body = resp.json()
    assert body == {"status": "healthy"}
def test_experiment_endpoints(client):
    """Exercise the full create/list/delete lifecycle for experiments."""
    headers = {"X-API-Key": "test-key"}

    # Build the creation payload from the inside out for readability.
    chain_sum_config = {
        "min_terms": 2,
        "max_terms": 4,
        "min_digits": 1,
        "max_digits": 2,
        "allow_negation": False,
        "size": 10,
        "seed": 42,
    }
    payload = {
        "name": "test_exp",
        "size": 10,
        "seed": 42,
        "datasets": {"chain_sum": {"weight": 1.0, "config": chain_sum_config}},
    }

    # Create an experiment.
    resp = client.post("/experiments", json=payload, headers=headers)
    assert resp.status_code == 200
    assert resp.json()["name"] == "test_exp"

    # The listing should now include it.
    resp = client.get("/experiments", headers=headers)
    assert resp.status_code == 200
    assert "test_exp" in resp.json()["experiments"]

    # Delete it again.
    resp = client.delete("/experiments/test_exp", headers=headers)
    assert resp.status_code == 200

    # After deletion it must be gone from the listing.
    resp = client.get("/experiments", headers=headers)
    assert resp.status_code == 200
    assert "test_exp" not in resp.json()["experiments"]

    # Deleting an unknown experiment yields 404.
    resp = client.delete("/experiments/nonexistent", headers=headers)
    assert resp.status_code == 404
def test_batch_generation_endpoint(client):
    """Test the batch generation endpoint: happy path and error cases.

    Creates an experiment, requests a small batch, and verifies the response
    shape; then checks that an unknown experiment returns 404 and that an
    invalid base_index returns 400.
    """
    headers = {"X-API-Key": "test-key"}

    # Create a test experiment to generate batches from.
    create_data = {
        "name": "test_exp",
        "size": 10,
        "seed": 42,
        "datasets": {
            "chain_sum": {
                "weight": 1.0,
                "config": {
                    "min_terms": 2,
                    "max_terms": 4,
                    "min_digits": 1,
                    "max_digits": 2,
                    "allow_negation": False,
                    "size": 10,
                    "seed": 42,
                },
            }
        },
    }
    response = client.post("/experiments", json=create_data, headers=headers)
    assert response.status_code == 200

    # Happy path: request a batch of two entries.
    response = client.get(
        "/experiments/test_exp/batch",
        params={"base_index": 0, "batch_size": 2},
        headers=headers,
    )
    assert response.status_code == 200
    data = response.json()

    # Verify batch structure.
    assert "entries" in data
    assert len(data["entries"]) == 2

    # Verify entry structure.
    entry = data["entries"][0]
    assert "question" in entry
    assert "entry_id" in entry
    assert "metadata" in entry

    # Error case: non-existent experiment.
    response = client.get(
        "/experiments/nonexistent/batch",
        params={"base_index": 0, "batch_size": 2},
        headers=headers,
    )
    assert response.status_code == 404

    # Error case: invalid parameters (negative base_index).
    response = client.get(
        "/experiments/test_exp/batch",
        params={"base_index": -1, "batch_size": 2},
        headers=headers,
    )
    assert response.status_code == 400
def test_scoring_endpoint(client):
    """Verify answer scoring: ordered responses, wrong answers, and errors."""
    headers = {"X-API-Key": "test-key"}

    # Set up an experiment to score against.
    experiment_spec = {
        "name": "test_exp",
        "size": 10,
        "seed": 42,
        "datasets": {
            "chain_sum": {
                "weight": 1.0,
                "config": {
                    "min_terms": 2,
                    "max_terms": 4,
                    "min_digits": 1,
                    "max_digits": 2,
                    "allow_negation": False,
                    "size": 10,
                    "seed": 42,
                },
            }
        },
    }
    resp = client.post("/experiments", json=experiment_spec, headers=headers)
    assert resp.status_code == 200

    # Fetch a batch so we have valid entry_ids to score.
    resp = client.get(
        "/experiments/test_exp/batch",
        params={"base_index": 0, "batch_size": 2},
        headers=headers,
    )
    assert resp.status_code == 200
    first_entry = resp.json()["entries"][0]
    entry_id = first_entry["entry_id"]

    # Score a presumably-correct answer.
    resp = client.post(
        "/experiments/test_exp/score",
        json={"answers": [{"entry_id": entry_id, "answer": "4"}]},  # Assuming 2+2=4 is the first question
        headers=headers,
    )
    assert resp.status_code == 200
    result = resp.json()
    assert "scores" in result
    assert "entry_ids" in result
    assert len(result["scores"]) == 1
    assert len(result["entry_ids"]) == 1
    assert result["entry_ids"][0] == entry_id
    first_score = result["scores"][0]
    assert isinstance(first_score, float)
    assert 0 <= first_score <= 1

    # A clearly wrong answer must score below 1.0 for the same entry.
    resp = client.post(
        "/experiments/test_exp/score",
        json={"answers": [{"entry_id": entry_id, "answer": "wrong"}]},
        headers=headers,
    )
    assert resp.status_code == 200
    result = resp.json()
    assert result["scores"][0] < 1.0
    assert result["entry_ids"][0] == entry_id

    # Error case: a malformed entry_id is rejected with 400.
    resp = client.post(
        "/experiments/test_exp/score",
        json={"answers": [{"entry_id": "invalid_id", "answer": "4"}]},
        headers=headers,
    )
    assert resp.status_code == 400

    # Error case: scoring against an unknown experiment yields 404.
    resp = client.post(
        "/experiments/nonexistent/score",
        json={"answers": [{"entry_id": entry_id, "answer": "4"}]},
        headers=headers,
    )
    assert resp.status_code == 404
def test_composite_config_endpoints(client):
    """Check reading and updating the composite dataset configuration."""
    headers = {"X-API-Key": "test-key"}

    # An experiment must exist before its composite config can be read.
    experiment_spec = {
        "name": "test_exp",
        "size": 10,
        "seed": 42,
        "datasets": {
            "chain_sum": {
                "weight": 1.0,
                "config": {
                    "min_terms": 2,
                    "max_terms": 4,
                    "min_digits": 1,
                    "max_digits": 2,
                    "allow_negation": False,
                    "size": 10,
                    "seed": 42,
                },
            }
        },
    }
    resp = client.post("/experiments", json=experiment_spec, headers=headers)
    assert resp.status_code == 200

    # Read back the composite config.
    resp = client.get("/experiments/test_exp/composite", headers=headers)
    assert resp.status_code == 200
    composite = resp.json()
    assert composite["name"] == "test_exp"
    assert "chain_sum" in composite["datasets"]

    # Patch the chain_sum dataset config.
    update_data = {"config": {"min_terms": 3, "max_terms": 5}}
    resp = client.post("/experiments/test_exp/composite/chain_sum", json=update_data, headers=headers)
    assert resp.status_code == 200

    # Re-fetch and confirm the update took effect.
    resp = client.get("/experiments/test_exp/composite", headers=headers)
    assert resp.status_code == 200
    composite = resp.json()
    chain_sum_cfg = composite["datasets"]["chain_sum"]["config"]
    assert chain_sum_cfg["min_terms"] == 3
    assert chain_sum_cfg["max_terms"] == 5

    # Error case: unknown experiment yields 404.
    resp = client.get("/experiments/nonexistent/composite", headers=headers)
    assert resp.status_code == 404

    # Error case: unknown dataset within the experiment yields 404.
    resp = client.post("/experiments/test_exp/composite/nonexistent", json=update_data, headers=headers)
    assert resp.status_code == 404