Files
reasoning-gym/tools/server/server.py
Andreas Köpf c69bc5d4e6 Basic curriculum (#198)
* feat: Add optional curriculum support to dataset registration and creation
* docs: Add docstrings to create_curriculum() and register_dataset()
* feat: Add curriculum configuration classes for CurriculumExperiment
* feat: Add weight parameter to CurriculumAttributeConfig and use in DatasetSpec
* refactor: Simplify CurriculumAttributeConfig with "*" attribute level support
* test: Add unit tests for CurriculumExperiment class
* feat: Add from_yaml() method to CurriculumExperimentConfig with unit test
2025-03-07 11:22:12 +01:00

170 lines
6.3 KiB
Python

"""FastAPI server implementation for Reasoning Gym."""
import logging
from fastapi import FastAPI, HTTPException
from reasoning_gym.coaching.registry import ExperimentRegistry
from reasoning_gym.composite import CompositeConfig, DatasetSpec
from .config import ServerConfig
from .middleware import APIKeyMiddleware
from .models import (
BatchEntry,
BatchResponse,
DatasetConfigUpdate,
ExperimentCreate,
ExperimentList,
ExperimentResponse,
ScoringRequest,
ScoringResponse,
)
def create_app(config: ServerConfig) -> FastAPI:
    """Create and configure the FastAPI application.

    Builds a FastAPI app wrapped in ``APIKeyMiddleware`` and backed by a fresh
    ``ExperimentRegistry``. Every endpoint below closes over that registry, so
    experiments registered through one app instance are not visible to another.

    Args:
        config: Server settings. ``config.log_level`` is passed to
            ``logging.basicConfig`` and ``config.api_key`` to ``APIKeyMiddleware``.

    Returns:
        The configured ``FastAPI`` application.
    """
    # Configure logging
    logging.basicConfig(level=config.log_level)
    logger = logging.getLogger(__name__)

    # Create FastAPI app
    app = FastAPI(title="Reasoning Gym Server")

    # Add middleware
    app.add_middleware(APIKeyMiddleware, api_key=config.api_key)

    # Initialize registry (one per app instance)
    registry = ExperimentRegistry()

    def _get_experiment_or_404(name: str):
        """Return the registered experiment ``name`` or raise an HTTP 404."""
        experiment = registry.get_experiment(name)
        # Explicit None check: a plain truthiness test would also report a
        # registered experiment as missing if its type defines __len__/__bool__
        # and happens to evaluate falsy.
        # NOTE(review): assumes get_experiment returns None for unknown names —
        # confirm against ExperimentRegistry.
        if experiment is None:
            raise HTTPException(status_code=404, detail=f"Experiment '{name}' not found")
        return experiment

    @app.get("/health")
    async def health_check():
        """Health check endpoint."""
        return {"status": "healthy"}

    @app.post("/experiments", response_model=ExperimentResponse)
    async def create_experiment(experiment: ExperimentCreate):
        """Create a new experiment.

        ``experiment.datasets`` maps a dataset name to a spec dict whose optional
        ``weight`` defaults to 1.0 and optional ``config`` defaults to ``{}``.
        Any error raised while registering the experiment is returned as HTTP 400.
        """
        # Convert dict format to DatasetSpec list
        dataset_specs = [
            DatasetSpec(name=ds_name, weight=spec.get("weight", 1.0), config=spec.get("config", {}))
            for ds_name, spec in experiment.datasets.items()
        ]
        # Named composite_config (not config) so it does not shadow the
        # ServerConfig parameter of the enclosing create_app.
        composite_config = CompositeConfig(size=experiment.size, seed=experiment.seed, datasets=dataset_specs)
        try:
            registry.register_experiment(experiment.name, composite_config)
        except Exception as e:
            logger.warning("Failed to register experiment '%s': %s", experiment.name, e)
            raise HTTPException(status_code=400, detail=str(e)) from e
        return ExperimentResponse(
            name=experiment.name, size=experiment.size, seed=experiment.seed, datasets=experiment.datasets
        )

    @app.get("/experiments", response_model=ExperimentList)
    async def list_experiments():
        """List all registered experiments."""
        return ExperimentList(experiments=registry.list_experiments())

    @app.delete("/experiments/{name}")
    async def delete_experiment(name: str):
        """Delete an experiment; respond 404 if no experiment was removed."""
        if not registry.remove_experiment(name):
            raise HTTPException(status_code=404, detail=f"Experiment '{name}' not found")
        return {"status": "deleted"}

    @app.get("/experiments/{name}/batch", response_model=BatchResponse)
    async def generate_batch(name: str, base_index: int, batch_size: int):
        """Generate a batch of raw entries.

        Returns the entries at indices ``base_index`` through
        ``base_index + batch_size - 1``. Responds 400 for a negative
        ``base_index``, a non-positive ``batch_size``, or any error while
        building entries, and 404 for an unknown experiment.
        """
        # Validate parameters
        if base_index < 0:
            raise HTTPException(status_code=400, detail="base_index must be non-negative")
        if batch_size <= 0:
            raise HTTPException(status_code=400, detail="batch_size must be positive")

        experiment = _get_experiment_or_404(name)

        try:
            entries = []
            for i in range(base_index, base_index + batch_size):
                entry = experiment.get_dataset_entry(i)
                metadata = entry["metadata"]
                # Create BatchEntry with minimal required data; entry_id is
                # sent as a string regardless of its stored type.
                entries.append(
                    BatchEntry(
                        question=entry["question"],
                        entry_id=str(metadata["entry_id"]),
                        metadata=metadata,
                    )
                )
            # Kept inside the try so response-model validation errors are still
            # reported as 400, as before.
            return BatchResponse(entries=entries)
        except Exception as e:
            logger.warning("Failed to generate batch for experiment '%s': %s", name, e)
            raise HTTPException(status_code=400, detail=str(e)) from e

    @app.post("/experiments/{name}/score", response_model=ScoringResponse)
    async def score_outputs(name: str, request: ScoringRequest):
        """Score extracted answers.

        Each item in ``request.answers`` is scored with
        ``experiment.score_answer_with_id``; scores and entry ids are returned
        in request order. Scoring errors are returned as HTTP 400.
        """
        experiment = _get_experiment_or_404(name)

        try:
            scores = [experiment.score_answer_with_id(item.answer, item.entry_id) for item in request.answers]
            entry_ids = [item.entry_id for item in request.answers]
            return ScoringResponse(scores=scores, entry_ids=entry_ids)
        except Exception as e:
            logger.warning("Failed to score answers for experiment '%s': %s", name, e)
            raise HTTPException(status_code=400, detail=str(e)) from e

    @app.get("/experiments/{name}/composite", response_model=ExperimentResponse)
    async def get_composite_config(name: str):
        """Get composite configuration for an experiment.

        Weights come from the experiment's stored config; per-dataset configs
        are read from the live dataset instances, so they reflect any updates
        applied through ``update_dataset_config``.
        """
        experiment = _get_experiment_or_404(name)

        # Convert internal config to API response format
        datasets = {}
        for ds_spec in experiment.config.datasets:
            dataset = experiment.composite.datasets[ds_spec.name]
            datasets[ds_spec.name] = {
                "weight": ds_spec.weight,
                "config": vars(dataset.config),  # Get current config from dataset instance
            }

        return ExperimentResponse(
            name=name, size=experiment.config.size, seed=experiment.config.seed, datasets=datasets
        )

    @app.post("/experiments/{name}/composite/{dataset_name}")
    async def update_dataset_config(name: str, dataset_name: str, config_update: DatasetConfigUpdate):
        """Update configuration for a specific dataset in the composite.

        A ``KeyError`` from the composite is reported as 404 (dataset not in
        the experiment); any other error is reported as 400.
        """
        experiment = _get_experiment_or_404(name)

        try:
            experiment.composite.update_dataset_config(dataset_name, config_update.config)
        except KeyError as e:
            raise HTTPException(status_code=404, detail=f"Dataset '{dataset_name}' not found in experiment") from e
        except Exception as e:
            logger.warning("Failed to update dataset '%s' in experiment '%s': %s", dataset_name, name, e)
            raise HTTPException(status_code=400, detail=str(e)) from e
        return {"status": "updated"}

    return app
async def app(scope, receive, send):
    """ASGI entry point that builds the FastAPI application on first use.

    On the first invocation a FastAPI app is created from a default
    ``ServerConfig`` and cached as the ``server_app`` attribute of this very
    function. Every invocation, including the first, forwards
    ``scope``/``receive``/``send`` to that cached application.
    """
    try:
        server_app = app.server_app
    except AttributeError:
        # Nothing cached yet: build the real application and remember it.
        server_app = app.server_app = create_app(ServerConfig())
    await server_app(scope, receive, send)