mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2025-10-09 13:40:09 +03:00
* feat: Add initial server structure with configuration, registry, and middleware * feat: Add chain_sum dataset to experiment registry test * fix: Update test_registry to use DatasetSpec for composite config validation * refactor: Update Pydantic config to use json_schema_extra and ConfigDict * feat: Add Pydantic models for API request/response data * feat: Implement basic experiment management endpoints with tests * feat: Implement composite configuration endpoints for experiments * fix: Add missing DatasetConfigUpdate import in server.py * refactor: Update dataset config update method to properly merge config updates * fix: Correctly retrieve current dataset config in composite endpoint * feat: Add basic CLI structure with experiments and config commands * feat: Add initial CLI tool with basic experiment management commands * refactor: Reorganize CLI package structure and fix import paths * refactor: Implement initial CLI commands for experiment management * feat: Implement HTTP client for Reasoning Gym server in RGC CLI tool * fix: Move print statements inside try block to resolve SyntaxError * fix: Resolve SyntaxError in edit_config function by adding missing except block * feat: Add default app instance in server module for easier uvicorn startup * docs: Add README.md with server and RGC tool documentation * remove unused files * refactor: Remove unsupported type annotation in registry.py * refactor: Move ExperimentRegistry to coaching module and add Experiment class * fix: Add missing CompositeDataset import in test_registry.py * refactor: Implement lazy ASGI app creation for server initialization * feat: Add health check command to RGC CLI for server connection * feat: Add version tracking support to CompositeDataset * feat: Add DatasetVersionManager for tracking dataset versions * feat: Add entry_id metadata and score_answer_with_id method to CompositeDataset * feat: Add entry_id metadata combining version and index * fix: Resolve undefined variable by 
storing version_id before use * test: Add comprehensive unit tests for score_answer_with_id() function * test: Add comprehensive version tracking test for dataset config updates * feat: Validate dataset weights are positive in CompositeDataset initialization * feat: Add weight update and normalization methods to CompositeDataset * refactor: Centralize weight normalization in CompositeDataset and allow zero-weight datasets * feat: Add negative weight validation to CompositeDataset constructor * feat: Add duplicate dataset name check in CompositeDataset and update test * refactor: Move duplicate dataset name check inside dataset iteration loop * refactor: Update CompositeDataset weight management to use config as source of truth * refactor: Move duplicate dataset name check to CompositeConfig.validate() * test: Update composite dataset weight test assertions and validation * feat: Add methods to add and remove datasets in CompositeDataset * refactor: Remove weight normalization and use unnormalized weights directly * refactor: Remove redundant total weight check in update_dataset_weights * feat: Add batch generation and scoring endpoints to server * fix: Import BatchEntry in server.py to resolve undefined name error * refactor: Update ReasoningGymDataset to use server for batch generation and scoring * fix: Add missing List and Dict type imports * feat: Add get_batch() and score_outputs() methods to RGClient * test: Add unit tests for generate_batch and score_outputs endpoints * refactor: Add DatasetVersionManager to Experiment class and CompositeDataset constructor * feat: Add validation for base_index and batch_size in generate_batch endpoint * refactor: Remove unused BatchRequest type from imports * refactor: Convert models to use Pydantic exclusively * test: Update scoring endpoint tests to use correct request model format * refactor: Rename ScoreItem to AnswerItem and update related code * feat: Update scoring endpoint to return ordered ScoringResponse with 
scores and entry_ids * fix: Add missing ScoringResponse import in server.py * move verl ppo sample with server into own file * refactor: Use Pydantic models for get_batch() and score_outputs() in RGClient * refactor: Update client methods to use Pydantic models for type safety * refactor: Use Pydantic models for experiment and dataset config operations * refactor: Clean up duplicate methods and improve error handling in main.py * first bits of rg server use for verl * refactor: Optimize scoring with single HTTP request in _score_output * fix: Correct experiment creation with ExperimentCreate object * grpo tests with server
278 lines
8.1 KiB
Python
278 lines
8.1 KiB
Python
"""Tests for API endpoints."""
|
|
|
|
import pytest
|
|
from fastapi.testclient import TestClient
|
|
|
|
from ..config import ServerConfig
|
|
from ..server import create_app
|
|
|
|
|
|
@pytest.fixture
def client():
    """Provide a TestClient wired to a freshly configured server app."""
    server_config = ServerConfig(
        host="localhost",
        port=8000,
        api_key="test-key",
        log_level="INFO",
    )
    return TestClient(create_app(server_config))
|
|
|
|
|
|
def test_health_check(client):
    """The /health endpoint reports healthy status without authentication."""
    resp = client.get("/health")
    payload = resp.json()
    assert resp.status_code == 200
    assert payload == {"status": "healthy"}
|
|
|
|
|
|
def test_experiment_endpoints(client):
    """Exercise create / list / delete on the experiment endpoints."""
    auth = {"X-API-Key": "test-key"}

    # Single chain_sum dataset with a small, deterministic configuration.
    chain_sum_cfg = {
        "min_terms": 2,
        "max_terms": 4,
        "min_digits": 1,
        "max_digits": 2,
        "allow_negation": False,
        "size": 10,
        "seed": 42,
    }
    payload = {
        "name": "test_exp",
        "size": 10,
        "seed": 42,
        "datasets": {"chain_sum": {"weight": 1.0, "config": chain_sum_cfg}},
    }

    # Creating the experiment echoes its name back.
    created = client.post("/experiments", json=payload, headers=auth)
    assert created.status_code == 200
    assert created.json()["name"] == "test_exp"

    # The new experiment shows up in the listing.
    listed = client.get("/experiments", headers=auth)
    assert listed.status_code == 200
    assert "test_exp" in listed.json()["experiments"]

    # Deletion succeeds ...
    deleted = client.delete("/experiments/test_exp", headers=auth)
    assert deleted.status_code == 200

    # ... and the experiment no longer appears afterwards.
    relisted = client.get("/experiments", headers=auth)
    assert relisted.status_code == 200
    assert "test_exp" not in relisted.json()["experiments"]

    # Deleting an unknown experiment yields 404.
    missing = client.delete("/experiments/nonexistent", headers=auth)
    assert missing.status_code == 404
|
|
|
|
|
|
def test_batch_generation_endpoint(client):
    """Test batch generation endpoint.

    Covers the happy path (a batch of 2 entries with the expected fields)
    and the error paths: unknown experiment (404) and an invalid negative
    base_index (400).
    """
    headers = {"X-API-Key": "test-key"}

    # Create test experiment
    create_data = {
        "name": "test_exp",
        "size": 10,
        "seed": 42,
        "datasets": {
            "chain_sum": {
                "weight": 1.0,
                "config": {
                    "min_terms": 2,
                    "max_terms": 4,
                    "min_digits": 1,
                    "max_digits": 2,
                    "allow_negation": False,
                    "size": 10,
                    "seed": 42,
                },
            }
        },
    }

    response = client.post("/experiments", json=create_data, headers=headers)
    assert response.status_code == 200

    # Test batch generation
    response = client.get(
        "/experiments/test_exp/batch",
        params={"base_index": 0, "batch_size": 2},
        headers=headers,
    )
    assert response.status_code == 200
    data = response.json()

    # Verify batch structure: exactly batch_size entries are returned.
    assert "entries" in data
    assert len(data["entries"]) == 2

    # Verify entry structure: each entry carries question, id, and metadata.
    entry = data["entries"][0]
    assert "question" in entry
    assert "entry_id" in entry
    assert "metadata" in entry

    # Error case: non-existent experiment.
    response = client.get(
        "/experiments/nonexistent/batch",
        params={"base_index": 0, "batch_size": 2},
        headers=headers,
    )
    assert response.status_code == 404

    # Error case: negative base_index is rejected by parameter validation.
    response = client.get(
        "/experiments/test_exp/batch",
        params={"base_index": -1, "batch_size": 2},
        headers=headers,
    )
    assert response.status_code == 400
|
|
|
|
|
|
def test_scoring_endpoint(client):
    """End-to-end check of the answer scoring endpoint."""
    auth = {"X-API-Key": "test-key"}

    # Set up an experiment with a single deterministic chain_sum dataset.
    experiment = {
        "name": "test_exp",
        "size": 10,
        "seed": 42,
        "datasets": {
            "chain_sum": {
                "weight": 1.0,
                "config": {
                    "min_terms": 2,
                    "max_terms": 4,
                    "min_digits": 1,
                    "max_digits": 2,
                    "allow_negation": False,
                    "size": 10,
                    "seed": 42,
                },
            }
        },
    }
    assert client.post("/experiments", json=experiment, headers=auth).status_code == 200

    # Fetch a batch so we hold a valid entry_id to score against.
    batch_resp = client.get(
        "/experiments/test_exp/batch",
        params={"base_index": 0, "batch_size": 2},
        headers=auth,
    )
    assert batch_resp.status_code == 200
    entry_id = batch_resp.json()["entries"][0]["entry_id"]

    # Submitting an answer yields ordered scores/entry_ids with scores in [0, 1].
    resp = client.post(
        "/experiments/test_exp/score",
        json={"answers": [{"entry_id": entry_id, "answer": "4"}]},
        headers=auth,
    )
    assert resp.status_code == 200
    result = resp.json()
    assert "scores" in result
    assert "entry_ids" in result
    assert len(result["scores"]) == 1
    assert len(result["entry_ids"]) == 1
    assert result["entry_ids"][0] == entry_id
    assert isinstance(result["scores"][0], float)
    assert 0 <= result["scores"][0] <= 1

    # A clearly wrong answer must not receive full credit.
    resp = client.post(
        "/experiments/test_exp/score",
        json={"answers": [{"entry_id": entry_id, "answer": "wrong"}]},
        headers=auth,
    )
    assert resp.status_code == 200
    result = resp.json()
    assert result["scores"][0] < 1.0
    assert result["entry_ids"][0] == entry_id

    # Error case: malformed entry_id is rejected with 400.
    resp = client.post(
        "/experiments/test_exp/score",
        json={"answers": [{"entry_id": "invalid_id", "answer": "4"}]},
        headers=auth,
    )
    assert resp.status_code == 400

    # Error case: unknown experiment yields 404.
    resp = client.post(
        "/experiments/nonexistent/score",
        json={"answers": [{"entry_id": entry_id, "answer": "4"}]},
        headers=auth,
    )
    assert resp.status_code == 404
|
|
|
|
|
|
def test_composite_config_endpoints(client):
    """Read and update the composite configuration of an experiment."""
    auth = {"X-API-Key": "test-key"}

    # Create an experiment containing one chain_sum dataset.
    dataset_cfg = {
        "min_terms": 2,
        "max_terms": 4,
        "min_digits": 1,
        "max_digits": 2,
        "allow_negation": False,
        "size": 10,
        "seed": 42,
    }
    experiment = {
        "name": "test_exp",
        "size": 10,
        "seed": 42,
        "datasets": {"chain_sum": {"weight": 1.0, "config": dataset_cfg}},
    }
    assert client.post("/experiments", json=experiment, headers=auth).status_code == 200

    # The composite config reflects the experiment we just created.
    got = client.get("/experiments/test_exp/composite", headers=auth)
    assert got.status_code == 200
    composite = got.json()
    assert composite["name"] == "test_exp"
    assert "chain_sum" in composite["datasets"]

    # Patch two fields of the chain_sum dataset config.
    update_data = {"config": {"min_terms": 3, "max_terms": 5}}
    updated = client.post(
        "/experiments/test_exp/composite/chain_sum", json=update_data, headers=auth
    )
    assert updated.status_code == 200

    # Re-read the composite config and confirm both fields changed.
    reread = client.get("/experiments/test_exp/composite", headers=auth)
    assert reread.status_code == 200
    chain_cfg = reread.json()["datasets"]["chain_sum"]["config"]
    assert chain_cfg["min_terms"] == 3
    assert chain_cfg["max_terms"] == 5

    # Error cases: unknown experiment and unknown dataset both return 404.
    assert client.get("/experiments/nonexistent/composite", headers=auth).status_code == 404
    assert (
        client.post(
            "/experiments/test_exp/composite/nonexistent", json=update_data, headers=auth
        ).status_code
        == 404
    )
|