reasoning-gym-server & cli tool (#154)

* feat: Add initial server structure with configuration, registry, and middleware

* feat: Add chain_sum dataset to experiment registry test

* fix: Update test_registry to use DatasetSpec for composite config validation

* refactor: Update Pydantic config to use json_schema_extra and ConfigDict

* feat: Add Pydantic models for API request/response data

* feat: Implement basic experiment management endpoints with tests

* feat: Implement composite configuration endpoints for experiments

* fix: Add missing DatasetConfigUpdate import in server.py

* refactor: Update dataset config update method to properly merge config updates

* fix: Correctly retrieve current dataset config in composite endpoint

* feat: Add basic CLI structure with experiments and config commands

* feat: Add initial CLI tool with basic experiment management commands

* refactor: Reorganize CLI package structure and fix import paths

* refactor: Implement initial CLI commands for experiment management

* feat: Implement HTTP client for Reasoning Gym server in RGC CLI tool

* fix: Move print statements inside try block to resolve SyntaxError

* fix: Resolve SyntaxError in edit_config function by adding missing except block

* feat: Add default app instance in server module for easier uvicorn startup

* docs: Add README.md with server and RGC tool documentation

* remove unused files

* refactor: Remove unsupported type annotation in registry.py

* refactor: Move ExperimentRegistry to coaching module and add Experiment class

* fix: Add missing CompositeDataset import in test_registry.py

* refactor: Implement lazy ASGI app creation for server initialization

* feat: Add health check command to RGC CLI for server connection

* feat: Add version tracking support to CompositeDataset

* feat: Add DatasetVersionManager for tracking dataset versions

* feat: Add entry_id metadata and score_answer_with_id method to CompositeDataset

* feat: Add entry_id metadata combining version and index

* fix: Resolve undefined variable by storing version_id before use

* test: Add comprehensive unit tests for score_answer_with_id() function

* test: Add comprehensive version tracking test for dataset config updates

* feat: Validate dataset weights are positive in CompositeDataset initialization

* feat: Add weight update and normalization methods to CompositeDataset

* refactor: Centralize weight normalization in CompositeDataset and allow zero-weight datasets

* feat: Add negative weight validation to CompositeDataset constructor

* feat: Add duplicate dataset name check in CompositeDataset and update test

* refactor: Move duplicate dataset name check inside dataset iteration loop

* refactor: Update CompositeDataset weight management to use config as source of truth

* refactor: Move duplicate dataset name check to CompositeConfig.validate()

* test: Update composite dataset weight test assertions and validation

* feat: Add methods to add and remove datasets in CompositeDataset

* refactor: Remove weight normalization and use unnormalized weights directly

* refactor: Remove redundant total weight check in update_dataset_weights

* feat: Add batch generation and scoring endpoints to server

* fix: Import BatchEntry in server.py to resolve undefined name error

* refactor: Update ReasoningGymDataset to use server for batch generation and scoring

* fix: Add missing List and Dict type imports

* feat: Add get_batch() and score_outputs() methods to RGClient

* test: Add unit tests for generate_batch and score_outputs endpoints

* refactor: Add DatasetVersionManager to Experiment class and CompositeDataset constructor

* feat: Add validation for base_index and batch_size in generate_batch endpoint

* refactor: Remove unused BatchRequest type from imports

* refactor: Convert models to use Pydantic exclusively

* test: Update scoring endpoint tests to use correct request model format

* refactor: Rename ScoreItem to AnswerItem and update related code

* feat: Update scoring endpoint to return ordered ScoringResponse with scores and entry_ids

* fix: Add missing ScoringResponse import in server.py

* move verl ppo sample with server into own file

* refactor: Use Pydantic models for get_batch() and score_outputs() in RGClient

* refactor: Update client methods to use Pydantic models for type safety

* refactor: Use Pydantic models for experiment and dataset config operations

* refactor: Clean up duplicate methods and improve error handling in main.py

* first bits of rg server use for verl

* refactor: Optimize scoring with single HTTP request in _score_output

* fix: Correct experiment creation with ExperimentCreate object

* grpo tests with server
Author: Andreas Köpf
Date: 2025-02-19 22:41:33 +01:00 (committed by GitHub)
Parent: bec6aefd11
Commit: e2702092f4
23 changed files with 1968 additions and 22 deletions


@@ -0,0 +1,9 @@
#!/bin/bash
export N_GPUS=2
export BASE_MODEL=meta-llama/Llama-3.2-1B-Instruct
export ROLLOUT_TP_SIZE=2
export EXPERIMENT_NAME=chain_sum_llama
export VLLM_ATTENTION_BACKEND=XFORMERS
bash ./train_grpo_server.sh


@@ -0,0 +1,344 @@
# This example is an adapted version of Bytedance's code:
# https://github.com/volcengine/verl/blob/a65c9157bc0b85b64cd753de19f94e80a11bd871/verl/trainer/main_ppo.py
import os
from typing import Dict, List, Optional
import hydra
import ray
import torch
import verl.utils.torch_functional as verl_F
from omegaconf import OmegaConf, open_dict
from torch.utils.data import DataLoader, Dataset
from transformers import PreTrainedTokenizer
from verl import DataProto
from verl.trainer.ppo.ray_trainer import RayPPOTrainer
from verl.utils.dataset.rl_dataset import collate_fn
from verl.utils.model import compute_position_id_with_mask
import reasoning_gym
import reasoning_gym.utils
from reasoning_gym.utils import extract_answer
from tools.server.models import AnswerItem, BatchEntry, ExperimentCreate
class ReasoningGymDataset(Dataset):
def __init__(
self,
tokenizer: PreTrainedTokenizer,
dataset_name: str,
seed: int,
size: int,
developer_prompt: Optional[str] = None,
developer_role: str = "system",
max_prompt_length: int = 2048,
truncation: str = "error", ## ['left', 'right', 'error']
return_raw_chat: bool = False,
server_url: str = "http://localhost:8000",
api_key: Optional[str] = None,
batch_size: int = 32,
):
from tools.cli.rgc.client import RGClient
self.tokenizer = tokenizer
self.dataset_name = dataset_name
self.developer_prompt = developer_prompt
self.developer_role = developer_role
self.max_prompt_length = max_prompt_length
self.truncation = truncation
self.return_raw_chat = return_raw_chat
self.size = size
self.batch_size = batch_size
# Initialize client and create experiment if needed
self.client = RGClient(base_url=server_url, api_key=api_key)
# Check if experiment exists, create if not
experiments = self.client.list_experiments()
if dataset_name not in experiments.experiments:
config = ExperimentCreate(
name=dataset_name,
size=size,
seed=seed,
datasets={dataset_name: {"weight": 1.0, "config": {"seed": seed, "size": size}}},
)
self.client.create_experiment(dataset_name, config)
# Cache for batches
self._batch_cache: dict[int, List[BatchEntry]] = {}
def __len__(self) -> int:
return self.size
def _get_batch(self, batch_idx: int) -> List[BatchEntry]:
"""Fetch or retrieve cached batch"""
if batch_idx not in self._batch_cache:
base_index = batch_idx * self.batch_size
response = self.client.get_batch(self.dataset_name, base_index=base_index, batch_size=self.batch_size)
self._batch_cache[batch_idx] = response.entries
# # Basic cache management - keep only last N batches
# if len(self._batch_cache) > 10:
# oldest_batch = min(self._batch_cache.keys())
# del self._batch_cache[oldest_batch]
return self._batch_cache[batch_idx]
def __getitem__(self, index):
# Get batch containing this index
batch_idx = index // self.batch_size
batch = self._get_batch(batch_idx)
entry = batch[index % self.batch_size]
# Format chat/prompt
chat = []
if self.developer_prompt is not None:
chat.append({"role": self.developer_role, "content": self.developer_prompt})
chat.append({"role": "user", "content": entry.question})
prompt = self.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
# Tokenize
input_ids, attention_mask = verl_F.tokenize_and_postprocess_data(
prompt=prompt,
tokenizer=self.tokenizer,
max_length=self.max_prompt_length,
pad_token_id=self.tokenizer.pad_token_id,
left_pad=True,
truncation=self.truncation,
)
position_ids = compute_position_id_with_mask(attention_mask)
row_dict = {
"data_source": "reasoning_gym/" + self.dataset_name,
"input_ids": input_ids[0],
"attention_mask": attention_mask[0],
"position_ids": position_ids[0],
"entry_id": entry.entry_id,
"metadata": entry.metadata,
"index": index,
}
# Add raw chat if requested
if self.return_raw_chat:
row_dict["raw_prompt"] = chat
return row_dict
class RayPPOTrainerCustom(RayPPOTrainer):
def __init__(
self,
config,
tokenizer,
role_worker_mapping: dict,
resource_pool_manager,
ray_worker_group_cls,
dataset_name: str = "chain_sum",
dataset_size: int = 10000,
):
self.dataset_name = dataset_name
self.dataset_size = dataset_size
developer_prompt = reasoning_gym.utils.SYSTEM_PROMPTS["DeepSeekZero"]
rg_api_key = os.getenv("REASONING_GYM_API_KEY", "your-secret-key")
self.train_dataset = ReasoningGymDataset(
tokenizer=tokenizer,
dataset_name=self.dataset_name,
seed=1,
size=self.dataset_size,
developer_prompt=developer_prompt,
api_key=rg_api_key,
)
self.val_dataset = ReasoningGymDataset(
tokenizer=tokenizer,
dataset_name=self.dataset_name,
seed=2,
size=self.dataset_size,
developer_prompt=developer_prompt,
api_key=rg_api_key,
)
train_reward_fn = lambda data: self._score_output(data, num_examine=0)
val_reward_fn = lambda data: self._score_output(data, num_examine=1)
super().__init__(
config,
tokenizer,
role_worker_mapping,
resource_pool_manager,
ray_worker_group_cls,
train_reward_fn,
val_reward_fn,
)
def _score_output(self, data: DataProto, num_examine: int = 0) -> torch.Tensor:
reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
# Prepare batch of answers to score
answer_items = []
valid_response_lengths = []
sequences_strs = []
for i in range(len(data)):
data_item = data[i]
# Get prompt and response
prompt_ids = data_item.batch["prompts"]
prompt_length = prompt_ids.shape[-1]
valid_prompt_length = data_item.batch["attention_mask"][:prompt_length].sum()
valid_prompt_ids = prompt_ids[-valid_prompt_length:]
response_ids = data_item.batch["responses"]
valid_response_length = data_item.batch["attention_mask"][prompt_length:].sum()
valid_response_ids = response_ids[:valid_response_length]
valid_response_lengths.append(valid_response_length)
# Decode full sequence
sequences = torch.cat((valid_prompt_ids, valid_response_ids))
sequences_str = self.tokenizer.decode(sequences)
sequences_strs.append(sequences_str)
# Extract answer and prepare scoring item
found_answer = extract_answer(sequences_str, tag_name="answer")
index = data_item.non_tensor_batch["index"]
entry_id = self.train_dataset[index]["entry_id"]
# print(
# "found_answer",
# entry_id,
# found_answer,
# )
answer_items.append(AnswerItem(entry_id=entry_id, answer=found_answer))
# Score all answers in one request
response = self.train_dataset.client.score_outputs(self.train_dataset.dataset_name, answer_items)
# print("response", response)
# Fill reward tensor
for i, (score, valid_response_length) in enumerate(zip(response.scores, valid_response_lengths)):
reward_tensor[i, valid_response_length - 1] = score
if i < num_examine:
print(f"reward={score}, seq={sequences_strs[i]}")
return reward_tensor
def _create_dataloader(self):
self.train_dataloader = DataLoader(
dataset=self.train_dataset,
batch_size=self.config.data.train_batch_size,
shuffle=False,
drop_last=True,
collate_fn=collate_fn,
)
self.val_dataloader = DataLoader(
dataset=self.val_dataset,
batch_size=len(self.val_dataset),
shuffle=False,
drop_last=True,
collate_fn=collate_fn,
)
assert len(self.train_dataloader) >= 1
assert len(self.val_dataloader) >= 1
print(f"Size of train dataloader: {len(self.train_dataloader)}")
print(f"Size of val dataloader: {len(self.val_dataloader)}")
# inject total_training_steps to actor/critic optim_config. This is hacky.
total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs
if self.config.trainer.total_training_steps is not None:
total_training_steps = self.config.trainer.total_training_steps
self.total_training_steps = total_training_steps
print(f"Total training steps: {self.total_training_steps}")
OmegaConf.set_struct(self.config, True)
with open_dict(self.config):
self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps
self.config.critic.optim.total_training_steps = total_training_steps
@ray.remote
def main_task(config):
# print initial config
from pprint import pprint
from verl.utils import hf_tokenizer
from verl.utils.fs import copy_local_path_from_hdfs
pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values
OmegaConf.resolve(config)
# download the checkpoint from hdfs
local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path)
# instantiate tokenizer
tokenizer = hf_tokenizer(local_path)
# define worker classes
if config.actor_rollout_ref.actor.strategy == "fsdp":
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
from verl.single_controller.ray import RayWorkerGroup
from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
ray_worker_group_cls = RayWorkerGroup
elif config.actor_rollout_ref.actor.strategy == "megatron":
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
ray_worker_group_cls = NVMegatronRayWorkerGroup
else:
raise NotImplementedError
from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role
role_worker_mapping = {
Role.ActorRollout: ray.remote(ActorRolloutRefWorker),
Role.Critic: ray.remote(CriticWorker),
Role.RefPolicy: ray.remote(ActorRolloutRefWorker),
}
global_pool_id = "global_pool"
resource_pool_spec = {
global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
}
mapping = {
Role.ActorRollout: global_pool_id,
Role.Critic: global_pool_id,
Role.RefPolicy: global_pool_id,
}
resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
trainer = RayPPOTrainerCustom(
config=config,
tokenizer=tokenizer,
role_worker_mapping=role_worker_mapping,
resource_pool_manager=resource_pool_manager,
ray_worker_group_cls=ray_worker_group_cls,
)
trainer.init_workers()
trainer.fit()
@hydra.main(config_path="config", config_name="ppo_trainer", version_base=None)
def main(config):
if not ray.is_initialized():
# this is for local ray cluster
ray.init(runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}})
ray.get(main_task.remote(config))
if __name__ == "__main__":
main()


@@ -0,0 +1,39 @@
#!/bin/bash
set -x
python3 -u main_ppo_custom_reward_server.py \
algorithm.adv_estimator=grpo \
data.train_files=$DATA_DIR/train.parquet \
data.val_files=$DATA_DIR/test.parquet \
data.train_batch_size=32 \
data.val_batch_size=32 \
data.max_prompt_length=512 \
data.max_response_length=1024 \
actor_rollout_ref.model.path=$BASE_MODEL \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=32 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \
actor_rollout_ref.actor.use_kl_loss=True \
actor_rollout_ref.actor.kl_loss_coef=0.001 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \
actor_rollout_ref.rollout.tensor_model_parallel_size=$ROLLOUT_TP_SIZE \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
actor_rollout_ref.rollout.n=8 \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
algorithm.kl_ctrl.kl_coef=0.001 \
trainer.critic_warmup=0 \
trainer.logger=['console'] \
trainer.project_name='verl_chain_sum_grpo' \
trainer.experiment_name=$EXPERIMENT_NAME \
trainer.n_gpus_per_node=$N_GPUS \
trainer.nnodes=1 \
trainer.save_freq=100 \
trainer.test_freq=100 \
trainer.total_epochs=15 $@ 2>&1 | tee verl_output.log


@@ -32,7 +32,22 @@ license = "Apache-2.0"
license-files = ["LICENSE*"]
[project.optional-dependencies]
test = ["pytest>=7.0.0", "pytest-cov>=4.0.0"]
test = [
"pytest>=7.0.0",
"pytest-cov>=4.0.0",
"httpx>=0.27.0"
]
server = [
"fastapi>=0.109.0",
"uvicorn>=0.27.0",
"pydantic-settings>=2.1.0",
]
cli = [
"typer>=0.9.0",
"rich>=13.7.0",
"pyyaml>=6.0.1",
"httpx>=0.27.0",
]
[project.urls]
"Homepage" = "https://github.com/open-thought/reasoning-gym"
@@ -40,12 +55,19 @@ test = ["pytest>=7.0.0", "pytest-cov>=4.0.0"]
[tool.hatch.build]
packages = ["reasoning_gym"]
include = [
"reasoning_gym/**/*.py",
"reasoning_gym/**/*.txt",
"reasoning_gym/**/levels/*",
packages = [
"reasoning_gym",
"tools.cli.rgc"
]
include = [
"reasoning_gym/**/*.py",
"reasoning_gym/**/*.txt",
"reasoning_gym/**/levels/*",
"tools/cli/rgc/**/*.py"
]
[project.scripts]
rgc = "tools.cli.rgc.main:main"
[tool.black]
line-length = 120


@@ -0,0 +1,36 @@
"""Experiment class combining dataset, scoreboard and curriculum."""
from dataclasses import dataclass
from typing import Optional
from ..composite import CompositeConfig, CompositeDataset
from ..version_manager import DatasetVersionManager
from .coach import ScoreBoard
@dataclass
class Experiment:
"""
An experiment combines a dataset with scoring and curriculum management.
Attributes:
name: Unique identifier for the experiment
dataset: The composite dataset for generating examples
score_board: Tracks performance metrics
config: The configuration used to create the dataset
version_manager: Manages dataset versions for scoring
"""
name: str
dataset: CompositeDataset
score_board: ScoreBoard
config: CompositeConfig
version_manager: DatasetVersionManager
@classmethod
def create(cls, name: str, config: CompositeConfig) -> "Experiment":
"""Create a new experiment from a configuration."""
version_manager = DatasetVersionManager()
dataset = CompositeDataset(config, version_manager=version_manager)
score_board = ScoreBoard()
return cls(name=name, dataset=dataset, score_board=score_board, config=config, version_manager=version_manager)
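`Experiment.create()` wires the composite dataset, version manager and scoreboard together. A minimal usage sketch under the modules added in this PR (the `chain_sum` spec values are illustrative):

```python
from reasoning_gym.coaching.experiment import Experiment
from reasoning_gym.composite import CompositeConfig, DatasetSpec

# One chain_sum dataset, tracked by the DatasetVersionManager built inside create().
config = CompositeConfig(
    size=100,
    seed=42,
    datasets=[DatasetSpec("chain_sum", 1.0, {"min_terms": 2, "max_terms": 4})],
)
exp = Experiment.create("my_experiment", config)

entry = exp.dataset[0]
print(entry["metadata"]["entry_id"])  # "version_id.index", e.g. "1.0"
```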


@@ -0,0 +1,34 @@
"""Registry for managing active experiments."""
from typing import Dict, List, Optional
from ..composite import CompositeConfig
from .experiment import Experiment
class ExperimentRegistry:
"""Singleton registry for managing active experiments."""
_instance = None
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._experiments = {}
return cls._instance
def register_experiment(self, name: str, config: CompositeConfig) -> None:
"""Register a new experiment with the given name and configuration."""
self._experiments[name] = Experiment.create(name, config)
def get_experiment(self, name: str) -> Optional[Experiment]:
"""Get an experiment by name."""
return self._experiments.get(name)
def list_experiments(self) -> List[str]:
"""List all registered experiment names."""
return list(self._experiments.keys())
def remove_experiment(self, name: str) -> bool:
"""Remove an experiment by name. Returns True if removed, False if not found."""
return bool(self._experiments.pop(name, None))
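Because `ExperimentRegistry` is a singleton, every instantiation shares the same experiment table; the server endpoints rely on this to see experiments registered elsewhere in the process. A short sketch (config values are illustrative):

```python
from reasoning_gym.coaching.registry import ExperimentRegistry
from reasoning_gym.composite import CompositeConfig, DatasetSpec

config = CompositeConfig(
    size=100,
    seed=42,
    datasets=[DatasetSpec("chain_sum", 1.0, {"min_terms": 2})],
)

registry = ExperimentRegistry()
registry.register_experiment("my_experiment", config)

# A second instantiation sees the same state.
assert "my_experiment" in ExperimentRegistry().list_experiments()
registry.remove_experiment("my_experiment")
```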


@@ -1,4 +1,4 @@
from dataclasses import dataclass
from dataclasses import dataclass, replace
from random import Random
from typing import Any, Dict, List, Optional
@@ -6,6 +6,7 @@ import yaml
from .dataset import ProceduralDataset
from .factory import create_dataset, register_dataset
from .version_manager import DatasetVersionManager
@dataclass
@@ -37,6 +38,11 @@ class CompositeConfig:
assert self.datasets, "Must specify at least one dataset"
assert len(self.datasets) > 0, "Must specify at least one dataset"
# Check for duplicate dataset names
dataset_names = [ds.name for ds in self.datasets]
if len(dataset_names) != len(set(dataset_names)):
raise ValueError("Duplicate dataset names are not allowed in CompositeDataset")
# Validate each dataset spec
for ds in self.datasets:
ds.validate()
@@ -57,13 +63,14 @@ class CompositeConfig:
class CompositeDataset(ProceduralDataset):
"""A dataset that combines multiple datasets with weighted sampling"""
def __init__(self, config: CompositeConfig):
def __init__(self, config: CompositeConfig, version_manager: Optional[DatasetVersionManager] = None):
super().__init__(config=config, seed=config.seed, size=config.size)
self.version_manager = version_manager
self.dataset_versions = {} # dataset_name -> version_id
# Initialize sub-datasets with incremented seeds
self.datasets = {}
self.weights = []
total_weight = 0.0
for i, ds_spec in enumerate(config.datasets):
# Create dataset with derived seed
@@ -73,12 +80,18 @@ class CompositeDataset(ProceduralDataset):
if "size" not in ds_config:
ds_config["size"] = self.size
self.datasets[ds_spec.name] = create_dataset(ds_spec.name, **ds_config)
total_weight += ds_spec.weight
self.weights.append(ds_spec.weight)
if ds_spec.weight < 0:
raise ValueError(f"Dataset '{ds_spec.name}' has invalid weight {ds_spec.weight}, must be non-negative")
# Normalize weights
self.weights = [w / total_weight for w in self.weights]
dataset = create_dataset(ds_spec.name, **ds_config)
self.datasets[ds_spec.name] = dataset
# Register version if tracking enabled
if version_manager is not None:
version_id = version_manager.register_dataset(ds_spec.name, dataset)
self.dataset_versions[ds_spec.name] = version_id
self.weights.append(ds_spec.weight) # Store unnormalized weights directly
self.dataset_names = [ds.name for ds in config.datasets]
def __getitem__(self, idx: int) -> dict:
@@ -98,6 +111,13 @@ class CompositeDataset(ProceduralDataset):
item["metadata"]["source_dataset"] = dataset_name
item["metadata"]["source_index"] = idx
# Add version info if tracking enabled
if self.version_manager is not None:
version_id = self.dataset_versions[dataset_name]
item["metadata"]["version_id"] = version_id
# Add entry_id combining version and index
item["metadata"]["entry_id"] = f"{version_id}.{idx}"
return item
def update_dataset_config(self, dataset_name: str, config_updates: Dict[str, Any]) -> None:
@@ -116,23 +136,151 @@ class CompositeDataset(ProceduralDataset):
dataset = self.datasets[dataset_name]
# Create new config with updates
new_config = dataset.config.__class__(**vars(dataset.config))
for key, value in config_updates.items():
setattr(new_config, key, value)
# Update the current config
new_config = replace(dataset.config, **config_updates)
# Validate new config
new_config.validate()
# Create new dataset instance with updated config
dataset_cls = dataset.__class__
self.datasets[dataset_name] = dataset_cls(new_config)
new_dataset = dataset_cls(new_config)
self.datasets[dataset_name] = new_dataset
# Register new version if tracking enabled
if self.version_manager is not None:
version_id = self.version_manager.register_dataset(dataset_name, new_dataset)
self.dataset_versions[dataset_name] = version_id
def update_dataset_weight(self, dataset_name: str, weight: float) -> None:
"""Update weight for a specific dataset in the configuration
Args:
dataset_name: Name of the dataset to update
weight: New weight value
Raises:
KeyError: If dataset_name not found
ValueError: If weight is negative
"""
if dataset_name not in self.datasets:
raise KeyError(f"Dataset '{dataset_name}' not found")
if weight < 0:
raise ValueError(f"Weight must be non-negative, got {weight}")
# Update weight in both config and weights list
for i, ds_spec in enumerate(self.config.datasets):
if ds_spec.name == dataset_name:
ds_spec.weight = weight
self.weights[i] = weight
break
def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
"""Forward scoring to appropriate dataset"""
dataset_name = entry["metadata"]["source_dataset"]
return self.datasets[dataset_name].score_answer(answer, entry)
def add_dataset(self, dataset_spec: DatasetSpec) -> None:
"""Add a new dataset to the composite
Args:
dataset_spec: Specification for the dataset to add
Raises:
ValueError: If dataset name already exists
"""
# Validate spec
dataset_spec.validate()
# Check for duplicate name
if dataset_spec.name in self.datasets:
raise ValueError(f"Dataset '{dataset_spec.name}' already exists in composite")
# Create dataset with derived seed
ds_config = dataset_spec.config.copy()
if "seed" not in ds_config:
ds_config["seed"] = self.seed + len(self.datasets) + 1
if "size" not in ds_config:
ds_config["size"] = self.size
# Create and add dataset
dataset = create_dataset(dataset_spec.name, **ds_config)
self.datasets[dataset_spec.name] = dataset
# Register version if tracking enabled
if self.version_manager is not None:
version_id = self.version_manager.register_dataset(dataset_spec.name, dataset)
self.dataset_versions[dataset_spec.name] = version_id
# Add to config and update internal state
self.config.datasets.append(dataset_spec)
self.dataset_names.append(dataset_spec.name)
self.weights.append(dataset_spec.weight) # Use weight directly from spec
def remove_dataset(self, dataset_name: str) -> None:
"""Remove a dataset from the composite
Args:
dataset_name: Name of the dataset to remove
Raises:
KeyError: If dataset not found
ValueError: If trying to remove last dataset
"""
if dataset_name not in self.datasets:
raise KeyError(f"Dataset '{dataset_name}' not found")
if len(self.datasets) <= 1:
raise ValueError("Cannot remove last dataset from composite")
# Remove from all internal structures
del self.datasets[dataset_name]
if self.version_manager is not None:
del self.dataset_versions[dataset_name]
# Remove from config
self.config.datasets = [ds for ds in self.config.datasets if ds.name != dataset_name]
# Update internal state
idx = self.dataset_names.index(dataset_name)
self.dataset_names.pop(idx)
self.weights.pop(idx)
def score_answer_with_id(self, answer: Optional[str], entry_id: str) -> float:
"""Score an answer using an entry_id to lookup the original entry
Args:
answer: The answer to score
entry_id: String in format "version_id.index"
Returns:
Score between 0 and 1
Raises:
ValueError: If entry_id format is invalid
KeyError: If version not found in version manager
"""
if self.version_manager is None:
raise RuntimeError("Version manager required for scoring with entry_id")
try:
version_id, index = map(int, entry_id.split("."))
except ValueError:
raise ValueError(f"Invalid entry_id format: {entry_id}, expected 'version_id.index'")
# Get dataset from version manager
dataset_info = self.version_manager.get_dataset(version_id)
if dataset_info is None:
raise KeyError(f"Version {version_id} not found in version manager")
dataset_name, dataset = dataset_info
# Get entry from dataset
entry = dataset[index]
# Score answer using dataset's scoring function
return dataset.score_answer(answer, entry)
# Register the dataset
register_dataset("composite", CompositeDataset, CompositeConfig)


@@ -0,0 +1,76 @@
"""Version manager for tracking dataset versions."""
from typing import Any, Dict, Optional, Tuple
from .dataset import ProceduralDataset
class DatasetVersionManager:
"""Manages versioned ProceduralDataset instances and their configurations."""
def __init__(self):
"""Initialize the version manager."""
self.current_version = 0
# version_id -> (dataset_name, dataset_instance)
self.datasets: Dict[int, Tuple[str, ProceduralDataset]] = {}
def register_dataset(self, name: str, dataset: ProceduralDataset) -> int:
"""
Register a new dataset version.
Args:
name: Name/identifier of the dataset type
dataset: Instance of ProceduralDataset
Returns:
version_id: Unique identifier for this dataset version
"""
self.current_version += 1
self.datasets[self.current_version] = (name, dataset)
return self.current_version
def get_dataset(self, version_id: int) -> Optional[Tuple[str, ProceduralDataset]]:
"""
Retrieve a dataset by its version ID.
Args:
version_id: The version identifier
Returns:
Tuple of (dataset_name, dataset_instance) if found, None otherwise
"""
return self.datasets.get(version_id)
def get_entry(self, version_id: int, index: int) -> Dict[str, Any]:
"""
Get a specific entry from a versioned dataset.
Args:
version_id: The version identifier
index: Index of the entry to retrieve
Returns:
The dataset entry
Raises:
KeyError: If version_id is not found
"""
if version_id not in self.datasets:
raise KeyError(f"Dataset version {version_id} not found")
_, dataset = self.datasets[version_id]
return dataset[index]
def cleanup_old_versions(self, keep_latest: int = 10):
"""
Remove old dataset versions to free memory.
Args:
keep_latest: Number of most recent versions to keep
"""
if len(self.datasets) <= keep_latest:
return
versions_to_remove = sorted(self.datasets.keys())[:-keep_latest]
for version in versions_to_remove:
del self.datasets[version]
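Version tracking is what makes `score_answer_with_id` stable across config updates: every dataset instance registered with the manager gets a monotonically increasing `version_id`, and an `entry_id` of the form `"{version_id}.{index}"` resolves against the instance that produced the entry. A sketch mirroring the tests further down:

```python
from reasoning_gym.composite import CompositeConfig, CompositeDataset, DatasetSpec
from reasoning_gym.version_manager import DatasetVersionManager

vm = DatasetVersionManager()
config = CompositeConfig(
    size=10,
    seed=42,
    datasets=[DatasetSpec("chain_sum", 1.0, {"min_terms": 2, "max_terms": 4})],
)
dataset = CompositeDataset(config, version_manager=vm)

entry = dataset[0]
entry_id = entry["metadata"]["entry_id"]  # e.g. "1.0"

# A config update registers a new version; the old entry_id still scores
# against the dataset instance that generated it.
dataset.update_dataset_config("chain_sum", {"min_terms": 3, "max_terms": 5})
assert dataset.score_answer_with_id(entry["answer"], entry_id) == 1.0
```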


@@ -4,6 +4,7 @@ import pytest
import yaml
from reasoning_gym.composite import CompositeConfig, CompositeDataset, DatasetSpec
from reasoning_gym.version_manager import DatasetVersionManager
def create_test_config(tmp_path):
@@ -85,13 +86,165 @@ def test_composite_dataset_weights():
seed=42,
datasets=[
DatasetSpec("chain_sum", 2.0, {"min_terms": 2}),
DatasetSpec("chain_sum", 3.0, {"min_terms": 3}),
DatasetSpec("products", 3.0, {"min_terms": 2}),
],
)
dataset = CompositeDataset(config)
assert abs(dataset.weights[0] - 0.4) < 1e-6
assert abs(dataset.weights[1] - 0.6) < 1e-6
assert abs(dataset.weights[0] - 2.0) < 1e-6
assert abs(dataset.weights[1] - 3.0) < 1e-6
# Test weight updates
dataset.update_dataset_weight("chain_sum", 1.0)
print(dataset.weights)
assert abs(dataset.weights[0] - 1.0) < 1e-6
assert abs(dataset.weights[1] - 3.0) < 1e-6
# Test invalid weight
with pytest.raises(ValueError, match="Weight must be non-negative"):
dataset.update_dataset_weight("chain_sum", -1.0)
# Test invalid dataset name
with pytest.raises(KeyError):
dataset.update_dataset_weight("invalid_dataset", 1.0)
# Test zero total weight
dataset.update_dataset_weight("chain_sum", 0.0)
with pytest.raises(ValueError, match="Total of weights must be greater than zero"):
dataset.update_dataset_weight("products", 0.0)
_ = dataset[0] # access item with all weights 0
# Test duplicate dataset names
with pytest.raises(ValueError, match="Duplicate dataset names"):
CompositeConfig(
size=1000,
seed=42,
datasets=[
DatasetSpec("chain_sum", 1.0, {"min_terms": 2}),
DatasetSpec("chain_sum", 1.0, {"min_terms": 3}),
],
).validate()
def test_version_tracking_with_config_updates():
"""Test that version tracking works correctly when updating dataset configs"""
# Create composite dataset with version manager
version_manager = DatasetVersionManager()
config = CompositeConfig(
size=10, seed=42, datasets=[DatasetSpec("chain_sum", 1.0, {"min_terms": 2, "max_terms": 4})]
)
dataset = CompositeDataset(config, version_manager=version_manager)
# Get an entry and its id from initial version
entry_1 = dataset[0]
entry_id_1 = entry_1["metadata"]["entry_id"]
answer_1 = entry_1["answer"]
# Update dataset config
dataset.update_dataset_config("chain_sum", {"min_terms": 3, "max_terms": 5})
# Get new entry after config update
entry_2 = dataset[0]
entry_id_2 = entry_2["metadata"]["entry_id"]
answer_2 = entry_2["answer"]
# Verify entries have different version IDs
version_1 = int(entry_id_1.split(".")[0])
version_2 = int(entry_id_2.split(".")[0])
assert version_1 != version_2, "New config should create new version"
# Verify original answer still works with original version
score_1 = dataset.score_answer_with_id(answer_1, entry_id_1)
assert score_1 == 1.0, "Original answer should still work with original version"
# Verify new answer works with new version
score_2 = dataset.score_answer_with_id(answer_2, entry_id_2)
assert score_2 == 1.0, "New answer should work with new version"
# Verify original answer fails with new version
score_3 = dataset.score_answer_with_id(answer_1, entry_id_2)
assert score_3 < 1.0, "Original answer should not work with new version"
def test_score_answer_with_id():
"""Test scoring answers using entry_id"""
# Create composite dataset with version manager
version_manager = DatasetVersionManager()
config = CompositeConfig(
size=10, seed=42, datasets=[DatasetSpec("chain_sum", 1.0, {"min_terms": 2, "max_terms": 4})]
)
dataset = CompositeDataset(config, version_manager=version_manager)
# Get an entry and its id
entry = dataset[0]
entry_id = entry["metadata"]["entry_id"]
# Test successful scoring
answer = entry["answer"]
score = dataset.score_answer_with_id(answer, entry_id)
assert score == 1.0 # Correct answer should get full score
# Test wrong answer
wrong_answer = "wrong"
score = dataset.score_answer_with_id(wrong_answer, entry_id)
assert score < 1.0 # Wrong answer should get lower score
# Test invalid entry_id format
with pytest.raises(ValueError, match="Invalid entry_id format"):
dataset.score_answer_with_id(answer, "invalid")
# Test non-existent version
with pytest.raises(KeyError, match="Version .* not found"):
dataset.score_answer_with_id(answer, "999.0")
# Test without version manager
dataset_no_vm = CompositeDataset(config)
with pytest.raises(RuntimeError, match="Version manager required"):
dataset_no_vm.score_answer_with_id(answer, entry_id)
def test_add_remove_dataset():
"""Test adding and removing datasets from composite"""
config = CompositeConfig(
size=1000,
seed=42,
datasets=[
DatasetSpec("chain_sum", 1.0, {"min_terms": 2}),
],
)
dataset = CompositeDataset(config)
# Test adding new dataset
new_spec = DatasetSpec("products", 2.0, {"min_terms": 2})
dataset.add_dataset(new_spec)
assert len(dataset.datasets) == 2
assert "products" in dataset.datasets
assert len(dataset.config.datasets) == 2
assert dataset.dataset_names[0] == "chain_sum"
assert dataset.dataset_names[1] == "products"
assert abs(dataset.weights[0] - 1.0) < 1e-6 # chain_sum weight
assert abs(dataset.weights[1] - 2.0) < 1e-6 # products weight
# Test duplicate name
with pytest.raises(ValueError, match="already exists"):
dataset.add_dataset(new_spec)
# Test removing dataset
dataset.remove_dataset("products")
assert len(dataset.datasets) == 1
assert "products" not in dataset.datasets
assert len(dataset.config.datasets) == 1
# Test removing non-existent dataset
with pytest.raises(KeyError):
dataset.remove_dataset("nonexistent")
# Test removing last dataset
with pytest.raises(ValueError, match="Cannot remove last dataset"):
dataset.remove_dataset("chain_sum")
def test_yaml_loading(tmp_path):

tools/README.md (new file)

@@ -0,0 +1,83 @@
# Reasoning Gym Tools
This directory contains additional tools for working with Reasoning Gym:
## Server
A FastAPI server that manages reasoning gym experiments, allowing runtime configuration and monitoring.
### Starting the Server
1. Install server dependencies:
```bash
pip install -e ".[server]"
```
2. Set the API key environment variable:
```bash
export REASONING_GYM_API_KEY=your-secret-key
```
3. Start the server:
```bash
uvicorn tools.server.server:app
```
The server will be available at http://localhost:8000. You can access the API documentation at http://localhost:8000/docs.
## RGC (Reasoning Gym Client)
A command-line interface for interacting with the Reasoning Gym server.
### Installation
```bash
pip install -e ".[cli]"
```
### Usage
First, set the API key to match your server:
```bash
export REASONING_GYM_API_KEY=your-secret-key
```
Then you can use the CLI:
```bash
# List all commands
rgc --help
# List experiments
rgc experiments list
# Create a new experiment interactively
rgc experiments create my-experiment
# Create from config file
rgc experiments create my-experiment -f config.yaml
# Show experiment details
rgc experiments show my-experiment
# Edit dataset configuration
rgc config edit my-experiment chain_sum
```
### Example Configuration File
Here's an example `config.yaml` for creating an experiment:
```yaml
size: 500
seed: 42
datasets:
chain_sum:
weight: 1.0
config:
min_terms: 2
max_terms: 4
min_digits: 1
max_digits: 2
allow_negation: false
```
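
For programmatic use outside the `rgc` CLI, the same operations are available through the `RGClient` class in `tools/cli/rgc/client.py`. A minimal sketch against a locally running server (the answer string is a placeholder):

```python
from tools.cli.rgc.client import RGClient
from tools.server.models import AnswerItem, ExperimentCreate

client = RGClient(base_url="http://localhost:8000", api_key="your-secret-key")

client.create_experiment(
    "my-experiment",
    ExperimentCreate(
        name="my-experiment",
        size=500,
        seed=42,
        datasets={"chain_sum": {"weight": 1.0, "config": {"min_terms": 2, "max_terms": 4}}},
    ),
)

batch = client.get_batch("my-experiment", base_index=0, batch_size=4)
answers = [AnswerItem(entry_id=e.entry_id, answer="42") for e in batch.entries]
print(client.score_outputs("my-experiment", answers).scores)
```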

tools/cli/__init__.py (new, empty file)


@@ -0,0 +1,5 @@
"""Reasoning Gym CLI tool."""
from .main import main
__all__ = ["main"]

tools/cli/rgc/client.py (new file)

@@ -0,0 +1,125 @@
"""HTTP client for interacting with the Reasoning Gym server."""
import os
from typing import List, Optional
import httpx
from rich.console import Console
from tools.server.models import (
AnswerItem,
BatchResponse,
DatasetConfigUpdate,
ExperimentCreate,
ExperimentList,
ExperimentResponse,
ScoringRequest,
ScoringResponse,
)
console = Console()
DEFAULT_SERVER = "http://localhost:8000"
API_KEY = os.getenv("REASONING_GYM_API_KEY", "default-key")
class RGClient:
"""Client for interacting with Reasoning Gym server."""
def __init__(self, base_url: str = DEFAULT_SERVER, api_key: str = API_KEY):
"""Initialize client with server URL and API key."""
self.base_url = base_url.rstrip("/")
self.headers = {"X-API-Key": api_key}
def _url(self, path: str) -> str:
"""Construct full URL for given path."""
return f"{self.base_url}/{path.lstrip('/')}"
def check_health(self) -> bool:
"""Check server health status."""
try:
response = httpx.get(self._url("/health"), headers=self.headers)
response.raise_for_status()
return response.json()["status"] == "healthy"
except Exception:
return False
def list_experiments(self) -> ExperimentList:
"""List all registered experiments."""
response = httpx.get(self._url("/experiments"), headers=self.headers)
response.raise_for_status()
return ExperimentList.model_validate(response.json())
def create_experiment(self, name: str, config: ExperimentCreate) -> ExperimentResponse:
"""Create a new experiment."""
response = httpx.post(
self._url("/experiments"),
headers=self.headers,
json=config.model_dump(),
)
response.raise_for_status()
return ExperimentResponse.model_validate(response.json())
def delete_experiment(self, name: str) -> None:
"""Delete an experiment."""
response = httpx.delete(
self._url(f"/experiments/{name}"),
headers=self.headers,
)
response.raise_for_status()
def get_experiment_config(self, name: str) -> ExperimentResponse:
"""Get experiment configuration."""
response = httpx.get(
self._url(f"/experiments/{name}/composite"),
headers=self.headers,
)
response.raise_for_status()
return ExperimentResponse.model_validate(response.json())
def update_dataset_config(self, experiment: str, dataset: str, config: DatasetConfigUpdate) -> None:
"""Update dataset configuration."""
response = httpx.post(
self._url(f"/experiments/{experiment}/composite/{dataset}"),
headers=self.headers,
json=config.model_dump(),
)
response.raise_for_status()
def get_batch(self, experiment: str, base_index: int, batch_size: int) -> BatchResponse:
"""Get a batch of entries from an experiment.
Args:
experiment: Name of the experiment
base_index: Starting index for the batch
batch_size: Number of entries to retrieve
Returns:
BatchResponse containing entries with questions and metadata
"""
response = httpx.get(
self._url(f"/experiments/{experiment}/batch"),
headers=self.headers,
params={"base_index": base_index, "batch_size": batch_size},
)
response.raise_for_status()
return BatchResponse.model_validate(response.json())
def score_outputs(self, experiment: str, entry_answers: List[AnswerItem]) -> ScoringResponse:
"""Score a batch of answers.
Args:
experiment: Name of the experiment
entry_answers: List of AnswerItems with entry_ids and answers to score
Returns:
ScoringResponse containing scores and entry_ids
"""
request = ScoringRequest(answers=entry_answers)
response = httpx.post(
self._url(f"/experiments/{experiment}/score"),
headers=self.headers,
json=request.model_dump(),
)
response.raise_for_status()
return ScoringResponse.model_validate(response.json())

tools/cli/rgc/main.py (new file)

@@ -0,0 +1,231 @@
"""Main entry point for the Reasoning Gym CLI."""
import os
from typing import Optional
import typer
import yaml
from rich.console import Console
from rich.prompt import Confirm, Prompt
from rich.syntax import Syntax
from rich.table import Table
from tools.server.models import DatasetConfigUpdate, ExperimentCreate
# Initialize Typer apps
app = typer.Typer(
name="rgc",
help="Reasoning Gym CLI - Manage and monitor reasoning gym experiments",
add_completion=True,
)
experiments_app = typer.Typer(help="Manage experiments")
config_app = typer.Typer(help="Manage configurations")
app.add_typer(experiments_app, name="experiments")
app.add_typer(config_app, name="config")
@app.command("health")
def check_health():
"""Check server connection and health status."""
try:
if client.check_health():
console.print("[green]Server is healthy[/]")
else:
console.print("[red]Server is not responding correctly[/]")
raise typer.Exit(1)
except Exception as e:
console.print(f"[red]Error connecting to server: {e}[/]")
raise typer.Exit(1)
# Initialize client and console
from .client import RGClient
client = RGClient()
console = Console()
@experiments_app.command("list")
def list_experiments():
"""List all registered experiments with their status."""
table = Table(title="Registered Experiments")
table.add_column("Name", style="cyan")
table.add_column("Datasets", style="magenta")
table.add_column("Size", style="blue")
table.add_column("Seed", style="green")
try:
experiments = client.list_experiments()
for exp_name in experiments.experiments:
try:
config = client.get_experiment_config(exp_name)
datasets = ", ".join(config.datasets.keys())
table.add_row(exp_name, datasets, str(config.size), str(config.seed or ""))
except Exception as e:
console.print(f"[yellow]Warning: Could not get config for {exp_name}: {e}[/]")
table.add_row(exp_name, "?", "?", "?")
except Exception as e:
console.print(f"[red]Error listing experiments: {e}[/]")
raise typer.Exit(1)
console.print(table)
@experiments_app.command("create")
def create_experiment(
name: str = typer.Argument(..., help="Name of the experiment"),
config_file: Optional[str] = typer.Option(None, "--file", "-f", help="YAML configuration file"),
):
"""Create a new experiment."""
if config_file:
try:
with open(config_file, "r") as f:
exp_config = yaml.safe_load(f)
config = ExperimentCreate(**exp_config)
response = client.create_experiment(name, config)
console.print(f"[green]Created experiment[/] [cyan]{response.name}[/]")
except Exception as e:
console.print(f"[red]Error creating experiment: {e}[/]")
raise typer.Exit(1)
else:
# Interactive creation
size = Prompt.ask("Dataset size", default="500")
seed = Prompt.ask("Random seed (optional)", default="")
datasets = {}
while Confirm.ask("Add dataset?"):
ds_name = Prompt.ask("Dataset name")
weight = float(Prompt.ask("Weight", default="1.0"))
# Get dataset-specific config
console.print("\nEnter dataset configuration:")
config = {}
while Confirm.ask("Add config parameter?"):
key = Prompt.ask("Parameter name")
value = Prompt.ask("Parameter value")
try:
# Try to convert to appropriate type
if value.isdigit():
value = int(value)
elif value.lower() in ("true", "false"):
value = value.lower() == "true"
elif "." in value and value.replace(".", "").isdigit():
value = float(value)
except ValueError:
pass
config[key] = value
datasets[ds_name] = {"weight": weight, "config": config}
# Create experiment config
exp_config = {"name": name, "size": int(size), "seed": int(seed) if seed else None, "datasets": datasets}
# Show final config
console.print("\nFinal configuration:")
console.print(Syntax(yaml.dump(exp_config), "yaml"))
if Confirm.ask("Create experiment with this configuration?"):
try:
config = ExperimentCreate(**exp_config)
response = client.create_experiment(name, config)
console.print(f"[green]Created experiment[/] [cyan]{response.name}[/]")
except Exception as e:
console.print(f"[red]Error creating experiment: {e}[/]")
raise typer.Exit(1)
else:
console.print("[yellow]Experiment creation cancelled[/]")
raise typer.Exit()
@experiments_app.command("delete")
def delete_experiment(
name: str = typer.Argument(..., help="Name of the experiment to delete"),
force: bool = typer.Option(False, "--force", "-f", help="Force deletion without confirmation"),
):
"""Delete an experiment."""
if not force and not Confirm.ask(f"Delete experiment [cyan]{name}[/]?"):
raise typer.Exit()
try:
client.delete_experiment(name)
console.print(f"[green]Deleted experiment[/] [cyan]{name}[/]")
except Exception as e:
console.print(f"[red]Error deleting experiment: {e}[/]")
raise typer.Exit(1)
@experiments_app.command("show")
def show_experiment(
name: str = typer.Argument(..., help="Name of the experiment"),
):
"""Show experiment details."""
try:
config = client.get_experiment_config(name)
console.print(Syntax(yaml.dump(config.model_dump()), "yaml"))
except Exception as e:
console.print(f"[red]Error getting experiment config: {e}[/]")
raise typer.Exit(1)
@config_app.command("edit")
def edit_config(
experiment: str = typer.Argument(..., help="Name of the experiment"),
dataset: str = typer.Argument(..., help="Name of the dataset to edit"),
):
"""Interactive configuration editor."""
try:
exp_config = client.get_experiment_config(experiment)
if dataset not in exp_config.datasets:
console.print(f"[red]Dataset {dataset} not found in experiment[/]")
raise typer.Exit(1)
current_config = exp_config.datasets[dataset]["config"]
console.print(f"\nCurrent configuration for [cyan]{dataset}[/]:")
console.print(Syntax(yaml.dump(current_config), "yaml"))
# Interactive editing
new_config = {}
for key, value in current_config.items():
new_value = Prompt.ask(f"{key}", default=str(value), show_default=True)
# Try to convert to appropriate type
try:
if isinstance(value, bool):
new_value = new_value.lower() == "true"
elif isinstance(value, int):
new_value = int(new_value)
elif isinstance(value, float):
new_value = float(new_value)
except ValueError:
console.print(f"[yellow]Warning: Could not convert {new_value} to {type(value)}[/]")
new_config[key] = new_value
# Show changes
console.print("\nNew configuration:")
console.print(Syntax(yaml.dump(new_config), "yaml"))
if Confirm.ask("Apply these changes?"):
try:
config_update = DatasetConfigUpdate(config=new_config)
client.update_dataset_config(experiment, dataset, config_update)
console.print("[green]Configuration updated successfully[/]")
except Exception as e:
console.print(f"[red]Error updating configuration: {e}[/]")
raise typer.Exit(1)
else:
console.print("[yellow]Update cancelled[/]")
except Exception as e:
console.print(f"[red]Error getting experiment configuration: {e}[/]")
raise typer.Exit(1)
def main():
"""Entry point for the CLI."""
app()
if __name__ == "__main__":
main()

tools/server/__init__.py (new file)

@@ -0,0 +1,8 @@
"""
Reasoning Gym Server - A FastAPI server for managing reasoning gym experiments.
"""
from .config import ServerConfig
from .server import create_app
__all__ = ["create_app", "ServerConfig"]

tools/server/config.py (new file)

@@ -0,0 +1,17 @@
"""Server configuration using Pydantic settings management."""
from pydantic import ConfigDict, Field
from pydantic_settings import BaseSettings
class ServerConfig(BaseSettings):
"""Configuration settings for the Reasoning Gym server."""
host: str = Field(default="localhost", description="Server host address")
port: int = Field(default=8000, description="Server port")
api_key: str = Field(
default=..., description="API key for authentication", json_schema_extra={"env": "REASONING_GYM_API_KEY"}
)
log_level: str = Field(default="INFO", description="Logging level")
model_config = ConfigDict(env_prefix="REASONING_GYM_")
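Since `ServerConfig` is a `pydantic-settings` `BaseSettings` with the `REASONING_GYM_` prefix, every field can come from the environment; only the API key has no default. A short sketch of how the values resolve (variable values are illustrative):

```python
import os

from tools.server.config import ServerConfig

os.environ["REASONING_GYM_API_KEY"] = "your-secret-key"
os.environ["REASONING_GYM_PORT"] = "9000"

config = ServerConfig()
assert config.api_key == "your-secret-key"
assert config.port == 9000          # parsed from REASONING_GYM_PORT
assert config.host == "localhost"   # default
```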


@@ -0,0 +1,23 @@
"""API key middleware for FastAPI."""
from fastapi import HTTPException, Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.status import HTTP_401_UNAUTHORIZED
class APIKeyMiddleware(BaseHTTPMiddleware):
"""Middleware to check for valid API key in request headers."""
def __init__(self, app, api_key: str):
super().__init__(app)
self.api_key = api_key
async def dispatch(self, request: Request, call_next):
if request.url.path == "/health":
return await call_next(request)
api_key = request.headers.get("X-API-Key")
if not api_key or api_key != self.api_key:
raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Invalid or missing API key")
return await call_next(request)
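Every endpoint except `/health` therefore expects an `X-API-Key` header matching the server's configured key. A quick check against a running server, using `httpx` directly (key and URL are whatever the server was started with):

```python
import httpx

base_url = "http://localhost:8000"
headers = {"X-API-Key": "your-secret-key"}

# /health is exempt from the API-key check.
print(httpx.get(f"{base_url}/health").json())  # {'status': 'healthy'}

# All other routes are rejected by the middleware unless the header matches.
print(httpx.get(f"{base_url}/experiments", headers=headers).json())
```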

tools/server/models.py (new file)

@@ -0,0 +1,75 @@
"""Pydantic models for API request/response data."""
from typing import Any, Dict, List, Optional, Tuple
from pydantic import BaseModel, Field
class ExperimentCreate(BaseModel):
"""Request model for creating a new experiment."""
name: str = Field(..., description="Unique name for the experiment")
size: int = Field(500, description="Size of the dataset")
seed: Optional[int] = Field(None, description="Random seed for reproducibility")
datasets: Dict[str, Dict[str, Any]] = Field(..., description="Dictionary of datasets configurations")
class ExperimentResponse(BaseModel):
"""Response model for experiment operations."""
name: str = Field(..., description="Name of the experiment")
size: int = Field(..., description="Size of the dataset")
seed: Optional[int] = Field(None, description="Random seed used")
datasets: Dict[str, Dict[str, Any]] = Field(..., description="Current dataset configurations")
class ExperimentList(BaseModel):
"""Response model for listing experiments."""
experiments: List[str] = Field(default_factory=list, description="List of registered experiment names")
class DatasetConfigUpdate(BaseModel):
"""Request model for updating dataset configuration."""
config: Dict[str, Any] = Field(..., description="Configuration parameters to update")
class ErrorResponse(BaseModel):
"""Response model for error conditions."""
detail: str = Field(..., description="Error message")
class BatchEntry(BaseModel):
"""Single entry in a batch"""
question: str = Field(..., description="The question text")
entry_id: str = Field(..., description="Unique identifier in format '{version}.{index}'")
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata about the entry")
class BatchResponse(BaseModel):
"""Response containing a batch of entries"""
entries: List[BatchEntry] = Field(..., description="List of batch entries")
class AnswerItem(BaseModel):
"""Single score item containing entry_id and answer"""
entry_id: str = Field(..., description="Entry identifier to score")
answer: str = Field(..., description="Answer to evaluate")
class ScoringRequest(BaseModel):
"""Request for scoring model outputs"""
answers: List[AnswerItem] = Field(..., description="List of entries to score")
class ScoringResponse(BaseModel):
"""Response containing scores for answers"""
scores: List[float] = Field(..., description="List of scores in same order as request")
entry_ids: List[str] = Field(..., description="List of entry_ids in same order as request")

tools/server/server.py (new file)

@@ -0,0 +1,169 @@
"""FastAPI server implementation for Reasoning Gym."""
import logging
from fastapi import FastAPI, HTTPException
from reasoning_gym.coaching.registry import ExperimentRegistry
from reasoning_gym.composite import CompositeConfig, DatasetSpec
from .config import ServerConfig
from .middleware import APIKeyMiddleware
from .models import (
BatchEntry,
BatchResponse,
DatasetConfigUpdate,
ExperimentCreate,
ExperimentList,
ExperimentResponse,
ScoringRequest,
ScoringResponse,
)
def create_app(config: ServerConfig) -> FastAPI:
"""Create and configure the FastAPI application."""
# Configure logging
logging.basicConfig(level=config.log_level)
logger = logging.getLogger(__name__)
# Create FastAPI app
app = FastAPI(title="Reasoning Gym Server")
# Add middleware
app.add_middleware(APIKeyMiddleware, api_key=config.api_key)
# Initialize registry
registry = ExperimentRegistry()
@app.get("/health")
async def health_check():
"""Health check endpoint."""
return {"status": "healthy"}
@app.post("/experiments", response_model=ExperimentResponse)
async def create_experiment(experiment: ExperimentCreate):
"""Create a new experiment."""
# Convert dict format to DatasetSpec list
dataset_specs = []
for name, spec in experiment.datasets.items():
dataset_specs.append(DatasetSpec(name=name, weight=spec.get("weight", 1.0), config=spec.get("config", {})))
config = CompositeConfig(size=experiment.size, seed=experiment.seed, datasets=dataset_specs)
try:
registry.register_experiment(experiment.name, config)
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
return ExperimentResponse(
name=experiment.name, size=experiment.size, seed=experiment.seed, datasets=experiment.datasets
)
@app.get("/experiments", response_model=ExperimentList)
async def list_experiments():
"""List all registered experiments."""
return ExperimentList(experiments=registry.list_experiments())
@app.delete("/experiments/{name}")
async def delete_experiment(name: str):
"""Delete an experiment."""
if not registry.remove_experiment(name):
raise HTTPException(status_code=404, detail=f"Experiment '{name}' not found")
return {"status": "deleted"}
@app.get("/experiments/{name}/batch", response_model=BatchResponse)
async def generate_batch(name: str, base_index: int, batch_size: int):
"""Generate a batch of raw entries"""
# Validate parameters
if base_index < 0:
raise HTTPException(status_code=400, detail="base_index must be non-negative")
if batch_size <= 0:
raise HTTPException(status_code=400, detail="batch_size must be positive")
experiment = registry.get_experiment(name)
if not experiment:
raise HTTPException(status_code=404, detail=f"Experiment '{name}' not found")
try:
entries = []
for i in range(base_index, base_index + batch_size):
entry = experiment.dataset[i]
# Create BatchEntry with minimal required data
batch_entry = BatchEntry(
question=entry["question"],
entry_id=f"{entry['metadata']['version_id']}.{i}",
metadata=entry["metadata"],
)
entries.append(batch_entry)
return BatchResponse(entries=entries)
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
@app.post("/experiments/{name}/score", response_model=ScoringResponse)
async def score_outputs(name: str, request: ScoringRequest):
"""Score extracted answers"""
experiment = registry.get_experiment(name)
if not experiment:
raise HTTPException(status_code=404, detail=f"Experiment '{name}' not found")
try:
scores = []
entry_ids = []
for item in request.answers:
score = experiment.dataset.score_answer_with_id(item.answer, item.entry_id)
scores.append(score)
entry_ids.append(item.entry_id)
return ScoringResponse(scores=scores, entry_ids=entry_ids)
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
@app.get("/experiments/{name}/composite", response_model=ExperimentResponse)
async def get_composite_config(name: str):
"""Get composite configuration for an experiment."""
experiment = registry.get_experiment(name)
if not experiment:
raise HTTPException(status_code=404, detail=f"Experiment '{name}' not found")
# Convert internal config to API response format
datasets = {}
for ds_spec in experiment.config.datasets:
dataset = experiment.dataset.datasets[ds_spec.name]
datasets[ds_spec.name] = {
"weight": ds_spec.weight,
"config": vars(dataset.config), # Get current config from dataset instance
}
return ExperimentResponse(
name=name, size=experiment.config.size, seed=experiment.config.seed, datasets=datasets
)
@app.post("/experiments/{name}/composite/{dataset_name}")
async def update_dataset_config(name: str, dataset_name: str, config_update: DatasetConfigUpdate):
"""Update configuration for a specific dataset in the composite."""
experiment = registry.get_experiment(name)
if not experiment:
raise HTTPException(status_code=404, detail=f"Experiment '{name}' not found")
try:
experiment.dataset.update_dataset_config(dataset_name, config_update.config)
return {"status": "updated"}
except KeyError:
raise HTTPException(status_code=404, detail=f"Dataset '{dataset_name}' not found in experiment")
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
return app
async def app(scope, receive, send):
"""ASGI application that lazily creates the FastAPI app."""
if not hasattr(app, "server_app"):
app.server_app = create_app(ServerConfig())
await app.server_app(scope, receive, send)
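The module-level `app` callable lazily builds the FastAPI application from the environment-driven `ServerConfig`, which is what `uvicorn tools.server.server:app` relies on; for tests or embedding, `create_app` can be called directly with an explicit config (as the test fixture below does). A sketch:

```python
import uvicorn

from tools.server.config import ServerConfig
from tools.server.server import create_app

# Explicit config instead of environment variables; values are illustrative.
config = ServerConfig(api_key="your-secret-key")
app = create_app(config)

if __name__ == "__main__":
    uvicorn.run(app, host=config.host, port=config.port)
```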


@@ -0,0 +1 @@
"""Tests for the Reasoning Gym server."""


@@ -0,0 +1,27 @@
"""Tests for server configuration."""
import os
import pytest
from ..config import ServerConfig
def test_default_config():
"""Test default configuration values."""
os.environ["REASONING_GYM_API_KEY"] = "test-key"
config = ServerConfig()
assert config.host == "localhost"
assert config.port == 8000
assert config.api_key == "test-key"
assert config.log_level == "INFO"
def test_missing_api_key():
"""Test that missing API key raises an error."""
if "REASONING_GYM_API_KEY" in os.environ:
del os.environ["REASONING_GYM_API_KEY"]
with pytest.raises(ValueError):
ServerConfig()
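
The two tests above mutate os.environ directly, so they only pass when run in this order. A minimal sketch (not part of this PR) of the same checks using pytest's monkeypatch fixture, which restores the environment after each test:

# Minimal sketch, not part of the PR: order-independent versions of the tests above.
import pytest

from ..config import ServerConfig


def test_default_config(monkeypatch):
    """monkeypatch.setenv is reverted automatically once the test finishes."""
    monkeypatch.setenv("REASONING_GYM_API_KEY", "test-key")
    config = ServerConfig()
    assert config.api_key == "test-key"


def test_missing_api_key(monkeypatch):
    """Remove the key for this test only, without leaking state into other tests."""
    monkeypatch.delenv("REASONING_GYM_API_KEY", raising=False)
    with pytest.raises(ValueError):
        ServerConfig()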

View File

@@ -0,0 +1,277 @@
"""Tests for API endpoints."""
import pytest
from fastapi.testclient import TestClient
from ..config import ServerConfig
from ..server import create_app
@pytest.fixture
def client():
"""Create a test client."""
config = ServerConfig(host="localhost", port=8000, api_key="test-key", log_level="INFO")
app = create_app(config)
return TestClient(app)
def test_health_check(client):
"""Test health check endpoint."""
response = client.get("/health")
assert response.status_code == 200
assert response.json() == {"status": "healthy"}
def test_experiment_endpoints(client):
    """Test experiment management endpoints."""
    # Set API key
    headers = {"X-API-Key": "test-key"}

    # Create experiment
    create_data = {
        "name": "test_exp",
        "size": 10,
        "seed": 42,
        "datasets": {
            "chain_sum": {
                "weight": 1.0,
                "config": {
                    "min_terms": 2,
                    "max_terms": 4,
                    "min_digits": 1,
                    "max_digits": 2,
                    "allow_negation": False,
                    "size": 10,
                    "seed": 42,
                },
            }
        },
    }
    response = client.post("/experiments", json=create_data, headers=headers)
    assert response.status_code == 200
    assert response.json()["name"] == "test_exp"

    # List experiments
    response = client.get("/experiments", headers=headers)
    assert response.status_code == 200
    assert "test_exp" in response.json()["experiments"]

    # Delete experiment
    response = client.delete("/experiments/test_exp", headers=headers)
    assert response.status_code == 200

    # Verify deletion
    response = client.get("/experiments", headers=headers)
    assert response.status_code == 200
    assert "test_exp" not in response.json()["experiments"]

    # Try to delete non-existent experiment
    response = client.delete("/experiments/nonexistent", headers=headers)
    assert response.status_code == 404
def test_batch_generation_endpoint(client):
    """Test batch generation endpoint."""
    headers = {"X-API-Key": "test-key"}

    # Create test experiment
    create_data = {
        "name": "test_exp",
        "size": 10,
        "seed": 42,
        "datasets": {
            "chain_sum": {
                "weight": 1.0,
                "config": {
                    "min_terms": 2,
                    "max_terms": 4,
                    "min_digits": 1,
                    "max_digits": 2,
                    "allow_negation": False,
                    "size": 10,
                    "seed": 42,
                },
            }
        },
    }
    response = client.post("/experiments", json=create_data, headers=headers)
    assert response.status_code == 200

    # Test batch generation
    response = client.get(
        "/experiments/test_exp/batch",
        params={"base_index": 0, "batch_size": 2},
        headers=headers,
    )
    assert response.status_code == 200
    data = response.json()

    # Verify batch structure
    assert "entries" in data
    assert len(data["entries"]) == 2

    # Verify entry structure
    entry = data["entries"][0]
    assert "question" in entry
    assert "entry_id" in entry
    assert "metadata" in entry

    # Test error cases
    # Non-existent experiment
    response = client.get(
        "/experiments/nonexistent/batch",
        params={"base_index": 0, "batch_size": 2},
        headers=headers,
    )
    assert response.status_code == 404

    # Invalid parameters
    response = client.get(
        "/experiments/test_exp/batch",
        params={"base_index": -1, "batch_size": 2},
        headers=headers,
    )
    assert response.status_code == 400
def test_scoring_endpoint(client):
    """Test answer scoring endpoint."""
    headers = {"X-API-Key": "test-key"}

    # Create test experiment
    create_data = {
        "name": "test_exp",
        "size": 10,
        "seed": 42,
        "datasets": {
            "chain_sum": {
                "weight": 1.0,
                "config": {
                    "min_terms": 2,
                    "max_terms": 4,
                    "min_digits": 1,
                    "max_digits": 2,
                    "allow_negation": False,
                    "size": 10,
                    "seed": 42,
                },
            }
        },
    }
    response = client.post("/experiments", json=create_data, headers=headers)
    assert response.status_code == 200

    # Get a batch to get valid entry_ids
    response = client.get(
        "/experiments/test_exp/batch",
        params={"base_index": 0, "batch_size": 2},
        headers=headers,
    )
    assert response.status_code == 200
    batch = response.json()
    entry_id = batch["entries"][0]["entry_id"]

    # Test scoring with correct answer
    response = client.post(
        "/experiments/test_exp/score",
        json={"answers": [{"entry_id": entry_id, "answer": "4"}]},  # Assuming 2+2=4 is the first question
        headers=headers,
    )
    assert response.status_code == 200
    result = response.json()
    assert "scores" in result
    assert "entry_ids" in result
    assert len(result["scores"]) == 1
    assert len(result["entry_ids"]) == 1
    assert result["entry_ids"][0] == entry_id
    assert isinstance(result["scores"][0], float)
    assert 0 <= result["scores"][0] <= 1

    # Test scoring with wrong answer
    response = client.post(
        "/experiments/test_exp/score",
        json={"answers": [{"entry_id": entry_id, "answer": "wrong"}]},
        headers=headers,
    )
    assert response.status_code == 200
    result = response.json()
    assert result["scores"][0] < 1.0
    assert result["entry_ids"][0] == entry_id

    # Test error cases
    # Invalid entry_id format
    response = client.post(
        "/experiments/test_exp/score",
        json={"answers": [{"entry_id": "invalid_id", "answer": "4"}]},
        headers=headers,
    )
    assert response.status_code == 400

    # Non-existent experiment
    response = client.post(
        "/experiments/nonexistent/score",
        json={"answers": [{"entry_id": entry_id, "answer": "4"}]},
        headers=headers,
    )
    assert response.status_code == 404
def test_composite_config_endpoints(client):
    """Test composite configuration endpoints."""
    headers = {"X-API-Key": "test-key"}

    # Create an experiment first
    create_data = {
        "name": "test_exp",
        "size": 10,
        "seed": 42,
        "datasets": {
            "chain_sum": {
                "weight": 1.0,
                "config": {
                    "min_terms": 2,
                    "max_terms": 4,
                    "min_digits": 1,
                    "max_digits": 2,
                    "allow_negation": False,
                    "size": 10,
                    "seed": 42,
                },
            }
        },
    }
    response = client.post("/experiments", json=create_data, headers=headers)
    assert response.status_code == 200

    # Get composite config
    response = client.get("/experiments/test_exp/composite", headers=headers)
    assert response.status_code == 200
    config = response.json()
    assert config["name"] == "test_exp"
    assert "chain_sum" in config["datasets"]

    # Update dataset config
    update_data = {"config": {"min_terms": 3, "max_terms": 5}}
    response = client.post("/experiments/test_exp/composite/chain_sum", json=update_data, headers=headers)
    assert response.status_code == 200

    # Verify update
    response = client.get("/experiments/test_exp/composite", headers=headers)
    assert response.status_code == 200
    config = response.json()
    assert config["datasets"]["chain_sum"]["config"]["min_terms"] == 3
    assert config["datasets"]["chain_sum"]["config"]["max_terms"] == 5

    # Test error cases
    # Non-existent experiment
    response = client.get("/experiments/nonexistent/composite", headers=headers)
    assert response.status_code == 404

    # Non-existent dataset
    response = client.post("/experiments/test_exp/composite/nonexistent", json=update_data, headers=headers)
    assert response.status_code == 404
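
The three endpoint tests above build the same create_data payload; a minimal sketch (not part of this PR) of a fixture that creates and tears down that shared experiment once, so each test only exercises its own endpoint. It reuses the file's client fixture; the HEADERS and CREATE_DATA names are introduced here for illustration only.

# Minimal sketch, not part of the PR: factor the repeated experiment setup into a fixture.
import pytest

HEADERS = {"X-API-Key": "test-key"}
CREATE_DATA = {
    "name": "test_exp",
    "size": 10,
    "seed": 42,
    "datasets": {
        "chain_sum": {
            "weight": 1.0,
            "config": {
                "min_terms": 2,
                "max_terms": 4,
                "min_digits": 1,
                "max_digits": 2,
                "allow_negation": False,
                "size": 10,
                "seed": 42,
            },
        }
    },
}


@pytest.fixture
def experiment(client):
    """Create the shared test experiment and delete it once the test is done."""
    response = client.post("/experiments", json=CREATE_DATA, headers=HEADERS)
    assert response.status_code == 200
    yield CREATE_DATA["name"]
    client.delete(f"/experiments/{CREATE_DATA['name']}", headers=HEADERS)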

View File

@@ -0,0 +1,44 @@
"""Tests for experiment registry."""
import pytest
from reasoning_gym.arithmetic.chain_sum import ChainSumConfig
from reasoning_gym.coaching.registry import ExperimentRegistry
from reasoning_gym.composite import CompositeConfig, CompositeDataset, DatasetSpec
def test_singleton():
"""Test that ExperimentRegistry is a singleton."""
registry1 = ExperimentRegistry()
registry2 = ExperimentRegistry()
assert registry1 is registry2
def test_experiment_management():
"""Test basic experiment management operations."""
registry = ExperimentRegistry()
# Clear any existing experiments
for name in registry.list_experiments():
registry.remove_experiment(name)
# Test registration with chain_sum dataset
chain_sum_spec = DatasetSpec(name="chain_sum", weight=1.0, config=vars(ChainSumConfig(size=10, seed=42)))
config = CompositeConfig(size=10, seed=42, datasets=[chain_sum_spec])
registry.register_experiment("test_exp", config)
# Test listing
assert "test_exp" in registry.list_experiments()
# Test retrieval
exp = registry.get_experiment("test_exp")
assert exp is not None
assert exp.name == "test_exp"
assert isinstance(exp.dataset, CompositeDataset)
assert exp.config == config
# Test removal
assert registry.remove_experiment("test_exp")
assert "test_exp" not in registry.list_experiments()
assert not registry.remove_experiment("nonexistent")
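
To close the loop, a minimal sketch (not part of this PR) of the entry_id round trip the server's /batch and /score endpoints perform, done directly against the registry. The experiment name "local_exp" is arbitrary, and the sketch assumes each dataset entry exposes its reference answer under the "answer" key alongside "question" and "metadata".

# Minimal sketch, not part of the PR: register an experiment and score an answer
# through the registry directly, mirroring what the HTTP endpoints do internally.
from reasoning_gym.arithmetic.chain_sum import ChainSumConfig
from reasoning_gym.coaching.registry import ExperimentRegistry
from reasoning_gym.composite import CompositeConfig, DatasetSpec

registry = ExperimentRegistry()
spec = DatasetSpec(name="chain_sum", weight=1.0, config=vars(ChainSumConfig(size=10, seed=42)))
registry.register_experiment("local_exp", CompositeConfig(size=10, seed=42, datasets=[spec]))

exp = registry.get_experiment("local_exp")
entry = exp.dataset[0]
# entry_id combines the dataset version with the index, as in the /batch endpoint above.
entry_id = f"{entry['metadata']['version_id']}.0"
print(exp.dataset.score_answer_with_id(entry["answer"], entry_id))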