Mirror of https://github.com/open-thought/reasoning-gym.git, synced 2025-10-09 13:40:09 +03:00
reasoning-gym-server & cli tool (#154)
* feat: Add initial server structure with configuration, registry, and middleware
* feat: Add chain_sum dataset to experiment registry test
* fix: Update test_registry to use DatasetSpec for composite config validation
* refactor: Update Pydantic config to use json_schema_extra and ConfigDict
* feat: Add Pydantic models for API request/response data
* feat: Implement basic experiment management endpoints with tests
* feat: Implement composite configuration endpoints for experiments
* fix: Add missing DatasetConfigUpdate import in server.py
* refactor: Update dataset config update method to properly merge config updates
* fix: Correctly retrieve current dataset config in composite endpoint
* feat: Add basic CLI structure with experiments and config commands
* feat: Add initial CLI tool with basic experiment management commands
* refactor: Reorganize CLI package structure and fix import paths
* refactor: Implement initial CLI commands for experiment management
* feat: Implement HTTP client for Reasoning Gym server in RGC CLI tool
* fix: Move print statements inside try block to resolve SyntaxError
* fix: Resolve SyntaxError in edit_config function by adding missing except block
* feat: Add default app instance in server module for easier uvicorn startup
* docs: Add README.md with server and RGC tool documentation
* remove unused files
* refactor: Remove unsupported type annotation in registry.py
* refactor: Move ExperimentRegistry to coaching module and add Experiment class
* fix: Add missing CompositeDataset import in test_registry.py
* refactor: Implement lazy ASGI app creation for server initialization
* feat: Add health check command to RGC CLI for server connection
* feat: Add version tracking support to CompositeDataset
* feat: Add DatasetVersionManager for tracking dataset versions
* feat: Add entry_id metadata and score_answer_with_id method to CompositeDataset
* feat: Add entry_id metadata combining version and index
* fix: Resolve undefined variable by storing version_id before use
* test: Add comprehensive unit tests for score_answer_with_id() function
* test: Add comprehensive version tracking test for dataset config updates
* feat: Validate dataset weights are positive in CompositeDataset initialization
* feat: Add weight update and normalization methods to CompositeDataset
* refactor: Centralize weight normalization in CompositeDataset and allow zero-weight datasets
* feat: Add negative weight validation to CompositeDataset constructor
* feat: Add duplicate dataset name check in CompositeDataset and update test
* refactor: Move duplicate dataset name check inside dataset iteration loop
* refactor: Update CompositeDataset weight management to use config as source of truth
* refactor: Move duplicate dataset name check to CompositeConfig.validate()
* test: Update composite dataset weight test assertions and validation
* feat: Add methods to add and remove datasets in CompositeDataset
* refactor: Remove weight normalization and use unnormalized weights directly
* refactor: Remove redundant total weight check in update_dataset_weights
* feat: Add batch generation and scoring endpoints to server
* fix: Import BatchEntry in server.py to resolve undefined name error
* refactor: Update ReasoningGymDataset to use server for batch generation and scoring
* fix: Add missing List and Dict type imports
* feat: Add get_batch() and score_outputs() methods to RGClient
* test: Add unit tests for generate_batch and score_outputs endpoints
* refactor: Add DatasetVersionManager to Experiment class and CompositeDataset constructor
* feat: Add validation for base_index and batch_size in generate_batch endpoint
* refactor: Remove unused BatchRequest type from imports
* refactor: Convert models to use Pydantic exclusively
* test: Update scoring endpoint tests to use correct request model format
* refactor: Rename ScoreItem to AnswerItem and update related code
* feat: Update scoring endpoint to return ordered ScoringResponse with scores and entry_ids
* fix: Add missing ScoringResponse import in server.py
* move verl ppo sample with server into own file
* refactor: Use Pydantic models for get_batch() and score_outputs() in RGClient
* refactor: Update client methods to use Pydantic models for type safety
* refactor: Use Pydantic models for experiment and dataset config operations
* refactor: Clean up duplicate methods and improve error handling in main.py
* first bits of rg server use for verl
* refactor: Optimize scoring with single HTTP request in _score_output
* fix: Correct experiment creation with ExperimentCreate object
* grpo tests with server
examples/veRL/launch_on_2gpu_server.sh (Executable file, 9 lines)
@@ -0,0 +1,9 @@
#!/bin/bash

export N_GPUS=2
export BASE_MODEL=meta-llama/Llama-3.2-1B-Instruct
export ROLLOUT_TP_SIZE=2
export EXPERIMENT_NAME=chain_sum_llama
export VLLM_ATTENTION_BACKEND=XFORMERS

bash ./train_grpo_server.sh
examples/veRL/main_ppo_custom_reward_server.py (Normal file, 344 lines)
@@ -0,0 +1,344 @@
# This example is an adapted version of Bytedance's code:
# https://github.com/volcengine/verl/blob/a65c9157bc0b85b64cd753de19f94e80a11bd871/verl/trainer/main_ppo.py
import os
from typing import Dict, List, Optional

import hydra
import ray
import torch
import verl.utils.torch_functional as verl_F
from omegaconf import OmegaConf, open_dict
from torch.utils.data import DataLoader, Dataset
from transformers import PreTrainedTokenizer
from verl import DataProto
from verl.trainer.ppo.ray_trainer import RayPPOTrainer
from verl.utils.dataset.rl_dataset import collate_fn
from verl.utils.model import compute_position_id_with_mask

import reasoning_gym
import reasoning_gym.utils
from reasoning_gym.utils import extract_answer
from tools.server.models import AnswerItem, BatchEntry, ExperimentCreate


class ReasoningGymDataset(Dataset):
    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        dataset_name: str,
        seed: int,
        size: int,
        developer_prompt: Optional[str] = None,
        developer_role: str = "system",
        max_prompt_length: int = 2048,
        truncation: str = "error",  ## ['left', 'right', 'error']
        return_raw_chat: bool = False,
        server_url: str = "http://localhost:8000",
        api_key: Optional[str] = None,
        batch_size: int = 32,
    ):
        from tools.cli.rgc.client import RGClient

        self.tokenizer = tokenizer
        self.dataset_name = dataset_name
        self.developer_prompt = developer_prompt
        self.developer_role = developer_role
        self.max_prompt_length = max_prompt_length
        self.truncation = truncation
        self.return_raw_chat = return_raw_chat
        self.size = size
        self.batch_size = batch_size

        # Initialize client and create experiment if needed
        self.client = RGClient(base_url=server_url, api_key=api_key)

        # Check if experiment exists, create if not
        experiments = self.client.list_experiments()
        if dataset_name not in experiments.experiments:
            config = ExperimentCreate(
                name=dataset_name,
                size=size,
                seed=seed,
                datasets={dataset_name: {"weight": 1.0, "config": {"seed": seed, "size": size}}},
            )
            self.client.create_experiment(dataset_name, config)

        # Cache for batches
        self._batch_cache: dict[int, List[BatchEntry]] = {}

    def __len__(self) -> int:
        return self.size

    def _get_batch(self, batch_idx: int) -> List[BatchEntry]:
        """Fetch or retrieve cached batch"""
        if batch_idx not in self._batch_cache:
            base_index = batch_idx * self.batch_size
            response = self.client.get_batch(self.dataset_name, base_index=base_index, batch_size=self.batch_size)
            self._batch_cache[batch_idx] = response.entries

            # # Basic cache management - keep only last N batches
            # if len(self._batch_cache) > 10:
            #     oldest_batch = min(self._batch_cache.keys())
            #     del self._batch_cache[oldest_batch]

        return self._batch_cache[batch_idx]

    def __getitem__(self, index):
        # Get batch containing this index
        batch_idx = index // self.batch_size

        batch = self._get_batch(batch_idx)
        entry = batch[index % self.batch_size]

        # Format chat/prompt
        chat = []
        if self.developer_prompt is not None:
            chat.append({"role": self.developer_role, "content": self.developer_prompt})
        chat.append({"role": "user", "content": entry.question})

        prompt = self.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)

        # Tokenize
        input_ids, attention_mask = verl_F.tokenize_and_postprocess_data(
            prompt=prompt,
            tokenizer=self.tokenizer,
            max_length=self.max_prompt_length,
            pad_token_id=self.tokenizer.pad_token_id,
            left_pad=True,
            truncation=self.truncation,
        )

        position_ids = compute_position_id_with_mask(attention_mask)

        row_dict = {
            "data_source": "reasoning_gym/" + self.dataset_name,
            "input_ids": input_ids[0],
            "attention_mask": attention_mask[0],
            "position_ids": position_ids[0],
            "entry_id": entry.entry_id,
            "metadata": entry.metadata,
            "index": index,
        }

        # Add raw chat if requested
        if self.return_raw_chat:
            row_dict["raw_prompt"] = chat

        return row_dict


class RayPPOTrainerCustom(RayPPOTrainer):
    def __init__(
        self,
        config,
        tokenizer,
        role_worker_mapping: dict,
        resource_pool_manager,
        ray_worker_group_cls,
        dataset_name: str = "chain_sum",
        dataset_size: int = 10000,
    ):
        self.dataset_name = dataset_name
        self.dataset_size = dataset_size

        developer_prompt = reasoning_gym.utils.SYSTEM_PROMPTS["DeepSeekZero"]
        rg_api_key = os.getenv("REASONING_GYM_API_KEY", "your-secret-key")
        self.train_dataset = ReasoningGymDataset(
            tokenizer=tokenizer,
            dataset_name=self.dataset_name,
            seed=1,
            size=self.dataset_size,
            developer_prompt=developer_prompt,
            api_key=rg_api_key,
        )

        self.val_dataset = ReasoningGymDataset(
            tokenizer=tokenizer,
            dataset_name=self.dataset_name,
            seed=2,
            size=self.dataset_size,
            developer_prompt=developer_prompt,
            api_key=rg_api_key,
        )

        train_reward_fn = lambda data: self._score_output(data, num_examine=0)
        val_reward_fn = lambda data: self._score_output(data, num_examine=1)

        super().__init__(
            config,
            tokenizer,
            role_worker_mapping,
            resource_pool_manager,
            ray_worker_group_cls,
            train_reward_fn,
            val_reward_fn,
        )

    def _score_output(self, data: DataProto, num_examine: int = 0) -> torch.Tensor:
        reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)

        # Prepare batch of answers to score
        answer_items = []
        valid_response_lengths = []
        sequences_strs = []

        for i in range(len(data)):
            data_item = data[i]

            # Get prompt and response
            prompt_ids = data_item.batch["prompts"]
            prompt_length = prompt_ids.shape[-1]
            valid_prompt_length = data_item.batch["attention_mask"][:prompt_length].sum()
            valid_prompt_ids = prompt_ids[-valid_prompt_length:]

            response_ids = data_item.batch["responses"]
            valid_response_length = data_item.batch["attention_mask"][prompt_length:].sum()
            valid_response_ids = response_ids[:valid_response_length]
            valid_response_lengths.append(valid_response_length)

            # Decode full sequence
            sequences = torch.cat((valid_prompt_ids, valid_response_ids))
            sequences_str = self.tokenizer.decode(sequences)
            sequences_strs.append(sequences_str)

            # Extract answer and prepare scoring item
            found_answer = extract_answer(sequences_str, tag_name="answer")

            index = data_item.non_tensor_batch["index"]
            entry_id = self.train_dataset[index]["entry_id"]
            # print(
            #     "found_answer",
            #     entry_id,
            #     found_answer,
            # )

            answer_items.append(AnswerItem(entry_id=entry_id, answer=found_answer))

        # Score all answers in one request
        response = self.train_dataset.client.score_outputs(self.train_dataset.dataset_name, answer_items)
        # print("response", response)

        # Fill reward tensor
        for i, (score, valid_response_length) in enumerate(zip(response.scores, valid_response_lengths)):
            reward_tensor[i, valid_response_length - 1] = score

            if i < num_examine:
                print(f"reward={score}, seq={sequences_strs[i]}")

        return reward_tensor

    def _create_dataloader(self):
        self.train_dataloader = DataLoader(
            dataset=self.train_dataset,
            batch_size=self.config.data.train_batch_size,
            shuffle=False,
            drop_last=True,
            collate_fn=collate_fn,
        )

        self.val_dataloader = DataLoader(
            dataset=self.val_dataset,
            batch_size=len(self.val_dataset),
            shuffle=False,
            drop_last=True,
            collate_fn=collate_fn,
        )

        assert len(self.train_dataloader) >= 1
        assert len(self.val_dataloader) >= 1

        print(f"Size of train dataloader: {len(self.train_dataloader)}")
        print(f"Size of val dataloader: {len(self.val_dataloader)}")

        # inject total_training_steps to actor/critic optim_config. This is hacky.
        total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs

        if self.config.trainer.total_training_steps is not None:
            total_training_steps = self.config.trainer.total_training_steps

        self.total_training_steps = total_training_steps
        print(f"Total training steps: {self.total_training_steps}")

        OmegaConf.set_struct(self.config, True)
        with open_dict(self.config):
            self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps
            self.config.critic.optim.total_training_steps = total_training_steps


@ray.remote
def main_task(config):
    # print initial config
    from pprint import pprint

    from verl.utils import hf_tokenizer
    from verl.utils.fs import copy_local_path_from_hdfs

    pprint(OmegaConf.to_container(config, resolve=True))  # resolve=True will eval symbol values
    OmegaConf.resolve(config)

    # download the checkpoint from hdfs
    local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path)

    # instantiate tokenizer
    tokenizer = hf_tokenizer(local_path)

    # define worker classes
    if config.actor_rollout_ref.actor.strategy == "fsdp":
        assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
        from verl.single_controller.ray import RayWorkerGroup
        from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker

        ray_worker_group_cls = RayWorkerGroup

    elif config.actor_rollout_ref.actor.strategy == "megatron":
        assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
        from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
        from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker

        ray_worker_group_cls = NVMegatronRayWorkerGroup

    else:
        raise NotImplementedError

    from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role

    role_worker_mapping = {
        Role.ActorRollout: ray.remote(ActorRolloutRefWorker),
        Role.Critic: ray.remote(CriticWorker),
        Role.RefPolicy: ray.remote(ActorRolloutRefWorker),
    }

    global_pool_id = "global_pool"
    resource_pool_spec = {
        global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
    }
    mapping = {
        Role.ActorRollout: global_pool_id,
        Role.Critic: global_pool_id,
        Role.RefPolicy: global_pool_id,
    }

    resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)

    trainer = RayPPOTrainerCustom(
        config=config,
        tokenizer=tokenizer,
        role_worker_mapping=role_worker_mapping,
        resource_pool_manager=resource_pool_manager,
        ray_worker_group_cls=ray_worker_group_cls,
    )
    trainer.init_workers()
    trainer.fit()


@hydra.main(config_path="config", config_name="ppo_trainer", version_base=None)
def main(config):
    if not ray.is_initialized():
        # this is for local ray cluster
        ray.init(runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}})

    ray.get(main_task.remote(config))


if __name__ == "__main__":
    main()
examples/veRL/train_grpo_server.sh (Normal file, 39 lines)
@@ -0,0 +1,39 @@
#!/bin/bash
set -x

python3 -u main_ppo_custom_reward_server.py \
    algorithm.adv_estimator=grpo \
    data.train_files=$DATA_DIR/train.parquet \
    data.val_files=$DATA_DIR/test.parquet \
    data.train_batch_size=32 \
    data.val_batch_size=32 \
    data.max_prompt_length=512 \
    data.max_response_length=1024 \
    actor_rollout_ref.model.path=$BASE_MODEL \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=32 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=$ROLLOUT_TP_SIZE \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
    actor_rollout_ref.rollout.n=8 \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    algorithm.kl_ctrl.kl_coef=0.001 \
    trainer.critic_warmup=0 \
    trainer.logger=['console'] \
    trainer.project_name='verl_chain_sum_grpo' \
    trainer.experiment_name=$EXPERIMENT_NAME \
    trainer.n_gpus_per_node=$N_GPUS \
    trainer.nnodes=1 \
    trainer.save_freq=100 \
    trainer.test_freq=100 \
    trainer.total_epochs=15 $@ 2>&1 | tee verl_output.log
pyproject.toml
@@ -32,7 +32,22 @@ license = "Apache-2.0"
 license-files = ["LICENSE*"]
 
 [project.optional-dependencies]
-test = ["pytest>=7.0.0", "pytest-cov>=4.0.0"]
+test = [
+    "pytest>=7.0.0",
+    "pytest-cov>=4.0.0",
+    "httpx>=0.27.0"
+]
+server = [
+    "fastapi>=0.109.0",
+    "uvicorn>=0.27.0",
+    "pydantic-settings>=2.1.0",
+]
+cli = [
+    "typer>=0.9.0",
+    "rich>=13.7.0",
+    "pyyaml>=6.0.1",
+    "httpx>=0.27.0",
+]
 
 [project.urls]
 "Homepage" = "https://github.com/open-thought/reasoning-gym"
@@ -40,12 +55,19 @@ test = ["pytest>=7.0.0", "pytest-cov>=4.0.0"]
 
 
 [tool.hatch.build]
-packages = ["reasoning_gym"]
-include = [
-    "reasoning_gym/**/*.py",
-    "reasoning_gym/**/*.txt",
-    "reasoning_gym/**/levels/*",
+packages = [
+    "reasoning_gym",
+    "tools.cli.rgc"
+]
+include = [
+    "reasoning_gym/**/*.py",
+    "reasoning_gym/**/*.txt",
+    "reasoning_gym/**/levels/*",
+    "tools/cli/rgc/**/*.py"
 ]
 
+[project.scripts]
+rgc = "tools.cli.rgc.main:main"
+
 [tool.black]
 line-length = 120
reasoning_gym/coaching/experiment.py (Normal file, 36 lines)
@@ -0,0 +1,36 @@
"""Experiment class combining dataset, scoreboard and curriculum."""

from dataclasses import dataclass
from typing import Optional

from ..composite import CompositeConfig, CompositeDataset
from ..version_manager import DatasetVersionManager
from .coach import ScoreBoard


@dataclass
class Experiment:
    """
    An experiment combines a dataset with scoring and curriculum management.

    Attributes:
        name: Unique identifier for the experiment
        dataset: The composite dataset for generating examples
        score_board: Tracks performance metrics
        config: The configuration used to create the dataset
        version_manager: Manages dataset versions for scoring
    """

    name: str
    dataset: CompositeDataset
    score_board: ScoreBoard
    config: CompositeConfig
    version_manager: DatasetVersionManager

    @classmethod
    def create(cls, name: str, config: CompositeConfig) -> "Experiment":
        """Create a new experiment from a configuration."""
        version_manager = DatasetVersionManager()
        dataset = CompositeDataset(config, version_manager=version_manager)
        score_board = ScoreBoard()
        return cls(name=name, dataset=dataset, score_board=score_board, config=config, version_manager=version_manager)
reasoning_gym/coaching/registry.py (Normal file, 34 lines)
@@ -0,0 +1,34 @@
"""Registry for managing active experiments."""

from typing import Dict, List, Optional

from ..composite import CompositeConfig
from .experiment import Experiment


class ExperimentRegistry:
    """Singleton registry for managing active experiments."""

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._experiments = {}
        return cls._instance

    def register_experiment(self, name: str, config: CompositeConfig) -> None:
        """Register a new experiment with the given name and configuration."""
        self._experiments[name] = Experiment.create(name, config)

    def get_experiment(self, name: str) -> Optional[Experiment]:
        """Get an experiment by name."""
        return self._experiments.get(name)

    def list_experiments(self) -> List[str]:
        """List all registered experiment names."""
        return list(self._experiments.keys())

    def remove_experiment(self, name: str) -> bool:
        """Remove an experiment by name. Returns True if removed, False if not found."""
        return bool(self._experiments.pop(name, None))
reasoning_gym/composite.py
@@ -1,4 +1,4 @@
-from dataclasses import dataclass
+from dataclasses import dataclass, replace
 from random import Random
 from typing import Any, Dict, List, Optional
 
@@ -6,6 +6,7 @@ import yaml
 
 from .dataset import ProceduralDataset
 from .factory import create_dataset, register_dataset
+from .version_manager import DatasetVersionManager
 
 
 @dataclass
@@ -37,6 +38,11 @@ class CompositeConfig:
-        assert self.datasets, "Must specify at least one dataset"
+        assert len(self.datasets) > 0, "Must specify at least one dataset"
 
+        # Check for duplicate dataset names
+        dataset_names = [ds.name for ds in self.datasets]
+        if len(dataset_names) != len(set(dataset_names)):
+            raise ValueError("Duplicate dataset names are not allowed in CompositeDataset")
+
         # Validate each dataset spec
         for ds in self.datasets:
             ds.validate()
@@ -57,13 +63,14 @@
 class CompositeDataset(ProceduralDataset):
     """A dataset that combines multiple datasets with weighted sampling"""
 
-    def __init__(self, config: CompositeConfig):
+    def __init__(self, config: CompositeConfig, version_manager: Optional[DatasetVersionManager] = None):
         super().__init__(config=config, seed=config.seed, size=config.size)
+        self.version_manager = version_manager
+        self.dataset_versions = {}  # dataset_name -> version_id
 
         # Initialize sub-datasets with incremented seeds
         self.datasets = {}
         self.weights = []
-        total_weight = 0.0
 
         for i, ds_spec in enumerate(config.datasets):
            # Create dataset with derived seed
@@ -73,12 +80,18 @@ class CompositeDataset(ProceduralDataset):
             if "size" not in ds_config:
                 ds_config["size"] = self.size
 
-            self.datasets[ds_spec.name] = create_dataset(ds_spec.name, **ds_config)
-            total_weight += ds_spec.weight
-            self.weights.append(ds_spec.weight)
+            if ds_spec.weight < 0:
+                raise ValueError(f"Dataset '{ds_spec.name}' has invalid weight {ds_spec.weight}, must be non-negative")
 
-        # Normalize weights
-        self.weights = [w / total_weight for w in self.weights]
+            dataset = create_dataset(ds_spec.name, **ds_config)
+            self.datasets[ds_spec.name] = dataset
+
+            # Register version if tracking enabled
+            if version_manager is not None:
+                version_id = version_manager.register_dataset(ds_spec.name, dataset)
+                self.dataset_versions[ds_spec.name] = version_id
+
+            self.weights.append(ds_spec.weight)  # Store unnormalized weights directly
 
         self.dataset_names = [ds.name for ds in config.datasets]
 
     def __getitem__(self, idx: int) -> dict:
@@ -98,6 +111,13 @@ class CompositeDataset(ProceduralDataset):
         item["metadata"]["source_dataset"] = dataset_name
         item["metadata"]["source_index"] = idx
 
+        # Add version info if tracking enabled
+        if self.version_manager is not None:
+            version_id = self.dataset_versions[dataset_name]
+            item["metadata"]["version_id"] = version_id
+            # Add entry_id combining version and index
+            item["metadata"]["entry_id"] = f"{version_id}.{idx}"
+
         return item
 
     def update_dataset_config(self, dataset_name: str, config_updates: Dict[str, Any]) -> None:
@@ -116,23 +136,151 @@ class CompositeDataset(ProceduralDataset):
 
         dataset = self.datasets[dataset_name]
 
-        # Create new config with updates
-        new_config = dataset.config.__class__(**vars(dataset.config))
-        for key, value in config_updates.items():
-            setattr(new_config, key, value)
+        # Update the current config
+        new_config = replace(dataset.config, **config_updates)
 
         # Validate new config
         new_config.validate()
 
         # Create new dataset instance with updated config
         dataset_cls = dataset.__class__
-        self.datasets[dataset_name] = dataset_cls(new_config)
+        new_dataset = dataset_cls(new_config)
+        self.datasets[dataset_name] = new_dataset
+
+        # Register new version if tracking enabled
+        if self.version_manager is not None:
+            version_id = self.version_manager.register_dataset(dataset_name, new_dataset)
+            self.dataset_versions[dataset_name] = version_id
+
+    def update_dataset_weight(self, dataset_name: str, weight: float) -> None:
+        """Update weight for a specific dataset in the configuration
+
+        Args:
+            dataset_name: Name of the dataset to update
+            weight: New weight value
+
+        Raises:
+            KeyError: If dataset_name not found
+            ValueError: If weight is negative
+        """
+        if dataset_name not in self.datasets:
+            raise KeyError(f"Dataset '{dataset_name}' not found")
+        if weight < 0:
+            raise ValueError(f"Weight must be non-negative, got {weight}")
+
+        # Update weight in both config and weights list
+        for i, ds_spec in enumerate(self.config.datasets):
+            if ds_spec.name == dataset_name:
+                ds_spec.weight = weight
+                self.weights[i] = weight
+                break
 
     def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
         """Forward scoring to appropriate dataset"""
         dataset_name = entry["metadata"]["source_dataset"]
         return self.datasets[dataset_name].score_answer(answer, entry)
 
+    def add_dataset(self, dataset_spec: DatasetSpec) -> None:
+        """Add a new dataset to the composite
+
+        Args:
+            dataset_spec: Specification for the dataset to add
+
+        Raises:
+            ValueError: If dataset name already exists
+        """
+        # Validate spec
+        dataset_spec.validate()
+
+        # Check for duplicate name
+        if dataset_spec.name in self.datasets:
+            raise ValueError(f"Dataset '{dataset_spec.name}' already exists in composite")
+
+        # Create dataset with derived seed
+        ds_config = dataset_spec.config.copy()
+        if "seed" not in ds_config:
+            ds_config["seed"] = self.seed + len(self.datasets) + 1
+        if "size" not in ds_config:
+            ds_config["size"] = self.size
+
+        # Create and add dataset
+        dataset = create_dataset(dataset_spec.name, **ds_config)
+        self.datasets[dataset_spec.name] = dataset
+
+        # Register version if tracking enabled
+        if self.version_manager is not None:
+            version_id = self.version_manager.register_dataset(dataset_spec.name, dataset)
+            self.dataset_versions[dataset_spec.name] = version_id
+
+        # Add to config and update internal state
+        self.config.datasets.append(dataset_spec)
+        self.dataset_names.append(dataset_spec.name)
+        self.weights.append(dataset_spec.weight)  # Use weight directly from spec
+
+    def remove_dataset(self, dataset_name: str) -> None:
+        """Remove a dataset from the composite
+
+        Args:
+            dataset_name: Name of the dataset to remove
+
+        Raises:
+            KeyError: If dataset not found
+            ValueError: If trying to remove last dataset
+        """
+        if dataset_name not in self.datasets:
+            raise KeyError(f"Dataset '{dataset_name}' not found")
+
+        if len(self.datasets) <= 1:
+            raise ValueError("Cannot remove last dataset from composite")
+
+        # Remove from all internal structures
+        del self.datasets[dataset_name]
+        if self.version_manager is not None:
+            del self.dataset_versions[dataset_name]
+
+        # Remove from config
+        self.config.datasets = [ds for ds in self.config.datasets if ds.name != dataset_name]
+
+        # Update internal state
+        idx = self.dataset_names.index(dataset_name)
+        self.dataset_names.pop(idx)
+        self.weights.pop(idx)
+
+    def score_answer_with_id(self, answer: Optional[str], entry_id: str) -> float:
+        """Score an answer using an entry_id to lookup the original entry
+
+        Args:
+            answer: The answer to score
+            entry_id: String in format "version_id.index"
+
+        Returns:
+            Score between 0 and 1
+
+        Raises:
+            ValueError: If entry_id format is invalid
+            KeyError: If version not found in version manager
+        """
+        if self.version_manager is None:
+            raise RuntimeError("Version manager required for scoring with entry_id")
+
+        try:
+            version_id, index = map(int, entry_id.split("."))
+        except ValueError:
+            raise ValueError(f"Invalid entry_id format: {entry_id}, expected 'version_id.index'")
+
+        # Get dataset from version manager
+        dataset_info = self.version_manager.get_dataset(version_id)
+        if dataset_info is None:
+            raise KeyError(f"Version {version_id} not found in version manager")
+
+        dataset_name, dataset = dataset_info
+
+        # Get entry from dataset
+        entry = dataset[index]
+
+        # Score answer using dataset's scoring function
+        return dataset.score_answer(answer, entry)
 
 
 # Register the dataset
 register_dataset("composite", CompositeDataset, CompositeConfig)
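Taken together, these changes let answers be scored against the exact dataset version that produced them, even after a reconfiguration. A minimal sketch of the flow, distilled from the tests added later in this commit (dataset name and config values are illustrative):

```python
from reasoning_gym.composite import CompositeConfig, CompositeDataset, DatasetSpec
from reasoning_gym.version_manager import DatasetVersionManager

version_manager = DatasetVersionManager()
config = CompositeConfig(size=10, seed=42, datasets=[DatasetSpec("chain_sum", 1.0, {"min_terms": 2})])
dataset = CompositeDataset(config, version_manager=version_manager)

entry = dataset[0]
entry_id = entry["metadata"]["entry_id"]  # "<version_id>.<index>", e.g. "1.0"

# Reconfiguring registers a new dataset version ...
dataset.update_dataset_config("chain_sum", {"min_terms": 3})

# ... but the old entry can still be scored against the version it came from
score = dataset.score_answer_with_id(entry["answer"], entry_id)
assert score == 1.0
```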
reasoning_gym/version_manager.py (Normal file, 76 lines)
@@ -0,0 +1,76 @@
"""Version manager for tracking dataset versions."""

from typing import Any, Dict, Optional, Tuple

from .dataset import ProceduralDataset


class DatasetVersionManager:
    """Manages versioned ProceduralDataset instances and their configurations."""

    def __init__(self):
        """Initialize the version manager."""
        self.current_version = 0
        # version_id -> (dataset_name, dataset_instance)
        self.datasets: Dict[int, Tuple[str, ProceduralDataset]] = {}

    def register_dataset(self, name: str, dataset: ProceduralDataset) -> int:
        """
        Register a new dataset version.

        Args:
            name: Name/identifier of the dataset type
            dataset: Instance of ProceduralDataset

        Returns:
            version_id: Unique identifier for this dataset version
        """
        self.current_version += 1
        self.datasets[self.current_version] = (name, dataset)
        return self.current_version

    def get_dataset(self, version_id: int) -> Optional[Tuple[str, ProceduralDataset]]:
        """
        Retrieve a dataset by its version ID.

        Args:
            version_id: The version identifier

        Returns:
            Tuple of (dataset_name, dataset_instance) if found, None otherwise
        """
        return self.datasets.get(version_id)

    def get_entry(self, version_id: int, index: int) -> Dict[str, Any]:
        """
        Get a specific entry from a versioned dataset.

        Args:
            version_id: The version identifier
            index: Index of the entry to retrieve

        Returns:
            The dataset entry

        Raises:
            KeyError: If version_id is not found
        """
        if version_id not in self.datasets:
            raise KeyError(f"Dataset version {version_id} not found")

        _, dataset = self.datasets[version_id]
        return dataset[index]

    def cleanup_old_versions(self, keep_latest: int = 10):
        """
        Remove old dataset versions to free memory.

        Args:
            keep_latest: Number of most recent versions to keep
        """
        if len(self.datasets) <= keep_latest:
            return

        versions_to_remove = sorted(self.datasets.keys())[:-keep_latest]
        for version in versions_to_remove:
            del self.datasets[version]
tests/test_composite.py
@@ -4,6 +4,7 @@ import pytest
 import yaml
 
 from reasoning_gym.composite import CompositeConfig, CompositeDataset, DatasetSpec
+from reasoning_gym.version_manager import DatasetVersionManager
 
 
 def create_test_config(tmp_path):
@@ -85,13 +86,165 @@ def test_composite_dataset_weights():
         seed=42,
         datasets=[
             DatasetSpec("chain_sum", 2.0, {"min_terms": 2}),
-            DatasetSpec("chain_sum", 3.0, {"min_terms": 3}),
+            DatasetSpec("products", 3.0, {"min_terms": 2}),
         ],
     )
 
     dataset = CompositeDataset(config)
-    assert abs(dataset.weights[0] - 0.4) < 1e-6
-    assert abs(dataset.weights[1] - 0.6) < 1e-6
+    assert abs(dataset.weights[0] - 2.0) < 1e-6
+    assert abs(dataset.weights[1] - 3.0) < 1e-6
+
+    # Test weight updates
+    dataset.update_dataset_weight("chain_sum", 1.0)
+    print(dataset.weights)
+    assert abs(dataset.weights[0] - 1.0) < 1e-6
+    assert abs(dataset.weights[1] - 3.0) < 1e-6
+
+    # Test invalid weight
+    with pytest.raises(ValueError, match="Weight must be non-negative"):
+        dataset.update_dataset_weight("chain_sum", -1.0)
+
+    # Test invalid dataset name
+    with pytest.raises(KeyError):
+        dataset.update_dataset_weight("invalid_dataset", 1.0)
+
+    # Test zero total weight
+    dataset.update_dataset_weight("chain_sum", 0.0)
+    with pytest.raises(ValueError, match="Total of weights must be greater than zero"):
+        dataset.update_dataset_weight("products", 0.0)
+        _ = dataset[0]  # access item with all weights 0
+
+    # Test duplicate dataset names
+    with pytest.raises(ValueError, match="Duplicate dataset names"):
+        CompositeConfig(
+            size=1000,
+            seed=42,
+            datasets=[
+                DatasetSpec("chain_sum", 1.0, {"min_terms": 2}),
+                DatasetSpec("chain_sum", 1.0, {"min_terms": 3}),
+            ],
+        ).validate()
+
+
+def test_version_tracking_with_config_updates():
+    """Test that version tracking works correctly when updating dataset configs"""
+    # Create composite dataset with version manager
+    version_manager = DatasetVersionManager()
+    config = CompositeConfig(
+        size=10, seed=42, datasets=[DatasetSpec("chain_sum", 1.0, {"min_terms": 2, "max_terms": 4})]
+    )
+    dataset = CompositeDataset(config, version_manager=version_manager)
+
+    # Get an entry and its id from initial version
+    entry_1 = dataset[0]
+    entry_id_1 = entry_1["metadata"]["entry_id"]
+    answer_1 = entry_1["answer"]
+
+    # Update dataset config
+    dataset.update_dataset_config("chain_sum", {"min_terms": 3, "max_terms": 5})
+
+    # Get new entry after config update
+    entry_2 = dataset[0]
+    entry_id_2 = entry_2["metadata"]["entry_id"]
+    answer_2 = entry_2["answer"]
+
+    # Verify entries have different version IDs
+    version_1 = int(entry_id_1.split(".")[0])
+    version_2 = int(entry_id_2.split(".")[0])
+    assert version_1 != version_2, "New config should create new version"
+
+    # Verify original answer still works with original version
+    score_1 = dataset.score_answer_with_id(answer_1, entry_id_1)
+    assert score_1 == 1.0, "Original answer should still work with original version"
+
+    # Verify new answer works with new version
+    score_2 = dataset.score_answer_with_id(answer_2, entry_id_2)
+    assert score_2 == 1.0, "New answer should work with new version"
+
+    # Verify original answer fails with new version
+    score_3 = dataset.score_answer_with_id(answer_1, entry_id_2)
+    assert score_3 < 1.0, "Original answer should not work with new version"
+
+
+def test_score_answer_with_id():
+    """Test scoring answers using entry_id"""
+    # Create composite dataset with version manager
+    version_manager = DatasetVersionManager()
+    config = CompositeConfig(
+        size=10, seed=42, datasets=[DatasetSpec("chain_sum", 1.0, {"min_terms": 2, "max_terms": 4})]
+    )
+    dataset = CompositeDataset(config, version_manager=version_manager)
+
+    # Get an entry and its id
+    entry = dataset[0]
+    entry_id = entry["metadata"]["entry_id"]
+
+    # Test successful scoring
+    answer = entry["answer"]
+    score = dataset.score_answer_with_id(answer, entry_id)
+    assert score == 1.0  # Correct answer should get full score
+
+    # Test wrong answer
+    wrong_answer = "wrong"
+    score = dataset.score_answer_with_id(wrong_answer, entry_id)
+    assert score < 1.0  # Wrong answer should get lower score
+
+    # Test invalid entry_id format
+    with pytest.raises(ValueError, match="Invalid entry_id format"):
+        dataset.score_answer_with_id(answer, "invalid")
+
+    # Test non-existent version
+    with pytest.raises(KeyError, match="Version .* not found"):
+        dataset.score_answer_with_id(answer, "999.0")
+
+    # Test without version manager
+    dataset_no_vm = CompositeDataset(config)
+    with pytest.raises(RuntimeError, match="Version manager required"):
+        dataset_no_vm.score_answer_with_id(answer, entry_id)
+
+
+def test_add_remove_dataset():
+    """Test adding and removing datasets from composite"""
+    config = CompositeConfig(
+        size=1000,
+        seed=42,
+        datasets=[
+            DatasetSpec("chain_sum", 1.0, {"min_terms": 2}),
+        ],
+    )
+
+    dataset = CompositeDataset(config)
+
+    # Test adding new dataset
+    new_spec = DatasetSpec("products", 2.0, {"min_terms": 2})
+    dataset.add_dataset(new_spec)
+
+    assert len(dataset.datasets) == 2
+    assert "products" in dataset.datasets
+    assert len(dataset.config.datasets) == 2
+
+    assert dataset.dataset_names[0] == "chain_sum"
+    assert dataset.dataset_names[1] == "products"
+    assert abs(dataset.weights[0] - 1.0) < 1e-6  # chain_sum weight
+    assert abs(dataset.weights[1] - 2.0) < 1e-6  # products weight
+
+    # Test duplicate name
+    with pytest.raises(ValueError, match="already exists"):
+        dataset.add_dataset(new_spec)
+
+    # Test removing dataset
+    dataset.remove_dataset("products")
+    assert len(dataset.datasets) == 1
+    assert "products" not in dataset.datasets
+    assert len(dataset.config.datasets) == 1
+
+    # Test removing non-existent dataset
+    with pytest.raises(KeyError):
+        dataset.remove_dataset("nonexistent")
+
+    # Test removing last dataset
+    with pytest.raises(ValueError, match="Cannot remove last dataset"):
+        dataset.remove_dataset("chain_sum")
+
+
 def test_yaml_loading(tmp_path):
tools/README.md (Normal file, 83 lines)
@@ -0,0 +1,83 @@
# Reasoning Gym Tools

This directory contains additional tools for working with Reasoning Gym:

## Server

A FastAPI server that manages reasoning gym experiments, allowing runtime configuration and monitoring.

### Starting the Server

1. Install server dependencies:
```bash
pip install -e ".[server]"
```

2. Set the API key environment variable:
```bash
export REASONING_GYM_API_KEY=your-secret-key
```

3. Start the server:
```bash
uvicorn tools.server.server:app
```

The server will be available at http://localhost:8000. You can access the API documentation at http://localhost:8000/docs.
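To verify the server is reachable, query the health endpoint. The `/health` route is exempt from the API key middleware, so no header is needed; the default uvicorn host and port are assumed here:

```bash
curl http://localhost:8000/health
# the CLI client treats a JSON response containing "status": "healthy" as healthy
```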
## RGC (Reasoning Gym Client)

A command-line interface for interacting with the Reasoning Gym server.

### Installation

```bash
pip install -e ".[cli]"
```

### Usage

First, set the API key to match your server:
```bash
export REASONING_GYM_API_KEY=your-secret-key
```

Then you can use the CLI:

```bash
# List all commands
rgc --help

# List experiments
rgc experiments list

# Create a new experiment interactively
rgc experiments create my-experiment

# Create from config file
rgc experiments create my-experiment -f config.yaml

# Show experiment details
rgc experiments show my-experiment

# Edit dataset configuration
rgc config edit my-experiment chain_sum
```

### Example Configuration File

Here's an example `config.yaml` for creating an experiment:

```yaml
size: 500
seed: 42
datasets:
  chain_sum:
    weight: 1.0
    config:
      min_terms: 2
      max_terms: 4
      min_digits: 1
      max_digits: 2
      allow_negation: false
```
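The same operations are available programmatically through the HTTP client that ships with the CLI. A minimal sketch (server URL, API key, experiment name, and the dummy answer are placeholders; `RGClient`, `ExperimentCreate`, and `AnswerItem` are the classes added in this commit):

```python
from tools.cli.rgc.client import RGClient
from tools.server.models import AnswerItem, ExperimentCreate

client = RGClient(base_url="http://localhost:8000", api_key="your-secret-key")

# Create an experiment (mirrors the YAML example above)
config = ExperimentCreate(
    name="my-experiment",
    size=500,
    seed=42,
    datasets={"chain_sum": {"weight": 1.0, "config": {"min_terms": 2, "max_terms": 4}}},
)
client.create_experiment("my-experiment", config)

# Fetch a batch of questions, then score answers by entry_id
batch = client.get_batch("my-experiment", base_index=0, batch_size=4)
answers = [AnswerItem(entry_id=entry.entry_id, answer="42") for entry in batch.entries]
result = client.score_outputs("my-experiment", answers)
print(result.scores)
```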
tools/cli/__init__.py (Normal file, 0 lines)
tools/cli/rgc/__init__.py (Normal file, 5 lines)
@@ -0,0 +1,5 @@
"""Reasoning Gym CLI tool."""

from .main import main

__all__ = ["main"]
tools/cli/rgc/client.py (Normal file, 125 lines)
@@ -0,0 +1,125 @@
"""HTTP client for interacting with the Reasoning Gym server."""

import os
from typing import List, Optional

import httpx
from rich.console import Console

from tools.server.models import (
    AnswerItem,
    BatchResponse,
    DatasetConfigUpdate,
    ExperimentCreate,
    ExperimentList,
    ExperimentResponse,
    ScoringRequest,
    ScoringResponse,
)

console = Console()

DEFAULT_SERVER = "http://localhost:8000"
API_KEY = os.getenv("REASONING_GYM_API_KEY", "default-key")


class RGClient:
    """Client for interacting with Reasoning Gym server."""

    def __init__(self, base_url: str = DEFAULT_SERVER, api_key: str = API_KEY):
        """Initialize client with server URL and API key."""
        self.base_url = base_url.rstrip("/")
        self.headers = {"X-API-Key": api_key}

    def _url(self, path: str) -> str:
        """Construct full URL for given path."""
        return f"{self.base_url}/{path.lstrip('/')}"

    def check_health(self) -> bool:
        """Check server health status."""
        try:
            response = httpx.get(self._url("/health"), headers=self.headers)
            response.raise_for_status()
            return response.json()["status"] == "healthy"
        except Exception:
            return False

    def list_experiments(self) -> ExperimentList:
        """List all registered experiments."""
        response = httpx.get(self._url("/experiments"), headers=self.headers)
        response.raise_for_status()
        return ExperimentList.model_validate(response.json())

    def create_experiment(self, name: str, config: ExperimentCreate) -> ExperimentResponse:
        """Create a new experiment."""
        response = httpx.post(
            self._url("/experiments"),
            headers=self.headers,
            json=config.model_dump(),
        )
        response.raise_for_status()
        return ExperimentResponse.model_validate(response.json())

    def delete_experiment(self, name: str) -> None:
        """Delete an experiment."""
        response = httpx.delete(
            self._url(f"/experiments/{name}"),
            headers=self.headers,
        )
        response.raise_for_status()

    def get_experiment_config(self, name: str) -> ExperimentResponse:
        """Get experiment configuration."""
        response = httpx.get(
            self._url(f"/experiments/{name}/composite"),
            headers=self.headers,
        )
        response.raise_for_status()
        return ExperimentResponse.model_validate(response.json())

    def update_dataset_config(self, experiment: str, dataset: str, config: DatasetConfigUpdate) -> None:
        """Update dataset configuration."""
        response = httpx.post(
            self._url(f"/experiments/{experiment}/composite/{dataset}"),
            headers=self.headers,
            json=config.model_dump(),
        )
        response.raise_for_status()

    def get_batch(self, experiment: str, base_index: int, batch_size: int) -> BatchResponse:
        """Get a batch of entries from an experiment.

        Args:
            experiment: Name of the experiment
            base_index: Starting index for the batch
            batch_size: Number of entries to retrieve

        Returns:
            BatchResponse containing entries with questions and metadata
        """
        response = httpx.get(
            self._url(f"/experiments/{experiment}/batch"),
            headers=self.headers,
            params={"base_index": base_index, "batch_size": batch_size},
        )
        response.raise_for_status()
        return BatchResponse.model_validate(response.json())

    def score_outputs(self, experiment: str, entry_answers: List[AnswerItem]) -> ScoringResponse:
        """Score a batch of answers.

        Args:
            experiment: Name of the experiment
            entry_answers: List of AnswerItems with entry_ids and answers to score

        Returns:
            ScoringResponse containing scores and entry_ids
        """
        request = ScoringRequest(answers=entry_answers)
        response = httpx.post(
            self._url(f"/experiments/{experiment}/score"),
            headers=self.headers,
            json=request.model_dump(),
        )
        response.raise_for_status()
        return ScoringResponse.model_validate(response.json())
tools/cli/rgc/main.py (Normal file, 231 lines)
@@ -0,0 +1,231 @@
|
||||
"""Main entry point for the Reasoning Gym CLI."""
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import typer
|
||||
import yaml
|
||||
from rich.console import Console
|
||||
from rich.prompt import Confirm, Prompt
|
||||
from rich.syntax import Syntax
|
||||
from rich.table import Table
|
||||
|
||||
from tools.server.models import DatasetConfigUpdate, ExperimentCreate
|
||||
|
||||
# Initialize Typer apps
|
||||
app = typer.Typer(
|
||||
name="rgc",
|
||||
help="Reasoning Gym CLI - Manage and monitor reasoning gym experiments",
|
||||
add_completion=True,
|
||||
)
|
||||
experiments_app = typer.Typer(help="Manage experiments")
|
||||
config_app = typer.Typer(help="Manage configurations")
|
||||
|
||||
app.add_typer(experiments_app, name="experiments")
|
||||
app.add_typer(config_app, name="config")
|
||||
|
||||
|
||||
@app.command("health")
|
||||
def check_health():
|
||||
"""Check server connection and health status."""
|
||||
try:
|
||||
if client.check_health():
|
||||
console.print("[green]Server is healthy[/]")
|
||||
else:
|
||||
console.print("[red]Server is not responding correctly[/]")
|
||||
raise typer.Exit(1)
|
||||
except Exception as e:
|
||||
console.print(f"[red]Error connecting to server: {e}[/]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
|
||||
# Initialize client and console
|
||||
from .client import RGClient
|
||||
|
||||
client = RGClient()
|
||||
console = Console()
|
||||
|
||||
|
||||
@experiments_app.command("list")
|
||||
def list_experiments():
|
||||
"""List all registered experiments with their status."""
|
||||
table = Table(title="Registered Experiments")
|
||||
table.add_column("Name", style="cyan")
|
||||
table.add_column("Datasets", style="magenta")
|
||||
table.add_column("Size", style="blue")
|
||||
table.add_column("Seed", style="green")
|
||||
|
||||
try:
|
||||
experiments = client.list_experiments()
|
||||
for exp_name in experiments.experiments:
|
||||
try:
|
||||
config = client.get_experiment_config(exp_name)
|
||||
datasets = ", ".join(config.datasets.keys())
|
||||
table.add_row(exp_name, datasets, str(config.size), str(config.seed or ""))
|
||||
except Exception as e:
|
||||
console.print(f"[yellow]Warning: Could not get config for {exp_name}: {e}[/]")
|
||||
table.add_row(exp_name, "?", "?", "?")
|
||||
except Exception as e:
|
||||
console.print(f"[red]Error listing experiments: {e}[/]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
console.print(table)
|
||||
|
||||
|
||||
@experiments_app.command("create")
|
||||
def create_experiment(
|
||||
name: str = typer.Argument(..., help="Name of the experiment"),
|
||||
config_file: Optional[str] = typer.Option(None, "--file", "-f", help="YAML configuration file"),
|
||||
):
|
||||
"""Create a new experiment."""
|
||||
if config_file:
|
||||
try:
|
||||
with open(config_file, "r") as f:
|
||||
exp_config = yaml.safe_load(f)
|
||||
config = ExperimentCreate(**exp_config)
|
||||
response = client.create_experiment(name, config)
|
||||
console.print(f"[green]Created experiment[/] [cyan]{response.name}[/]")
|
||||
except Exception as e:
|
||||
console.print(f"[red]Error creating experiment: {e}[/]")
|
||||
raise typer.Exit(1)
|
||||
else:
|
||||
# Interactive creation
|
||||
size = Prompt.ask("Dataset size", default="500")
|
||||
seed = Prompt.ask("Random seed (optional)", default="")
|
||||
|
||||
datasets = {}
|
||||
while Confirm.ask("Add dataset?"):
|
||||
ds_name = Prompt.ask("Dataset name")
|
||||
weight = float(Prompt.ask("Weight", default="1.0"))
|
||||
|
||||
# Get dataset-specific config
|
||||
console.print("\nEnter dataset configuration:")
|
||||
config = {}
|
||||
while Confirm.ask("Add config parameter?"):
|
||||
key = Prompt.ask("Parameter name")
|
||||
value = Prompt.ask("Parameter value")
|
||||
try:
|
||||
# Try to convert to appropriate type
|
||||
if value.isdigit():
|
||||
                    value = int(value)
                elif value.lower() in ("true", "false"):
                    value = value.lower() == "true"
                elif "." in value and value.replace(".", "").isdigit():
                    value = float(value)
            except ValueError:
                pass
            config[key] = value

        datasets[ds_name] = {"weight": weight, "config": config}

    # Create experiment config
    exp_config = {"name": name, "size": int(size), "seed": int(seed) if seed else None, "datasets": datasets}

    # Show final config
    console.print("\nFinal configuration:")
    console.print(Syntax(yaml.dump(exp_config), "yaml"))

    if Confirm.ask("Create experiment with this configuration?"):
        try:
            config = ExperimentCreate(**exp_config)
            response = client.create_experiment(name, config)
            console.print(f"[green]Created experiment[/] [cyan]{response.name}[/]")
        except Exception as e:
            console.print(f"[red]Error creating experiment: {e}[/]")
            raise typer.Exit(1)
    else:
        console.print("[yellow]Experiment creation cancelled[/]")
        raise typer.Exit()


@experiments_app.command("delete")
def delete_experiment(
    name: str = typer.Argument(..., help="Name of the experiment to delete"),
    force: bool = typer.Option(False, "--force", "-f", help="Force deletion without confirmation"),
):
    """Delete an experiment."""
    if not force and not Confirm.ask(f"Delete experiment [cyan]{name}[/]?"):
        raise typer.Exit()

    try:
        client.delete_experiment(name)
        console.print(f"[green]Deleted experiment[/] [cyan]{name}[/]")
    except Exception as e:
        console.print(f"[red]Error deleting experiment: {e}[/]")
        raise typer.Exit(1)


@experiments_app.command("show")
def show_experiment(
    name: str = typer.Argument(..., help="Name of the experiment"),
):
    """Show experiment details."""
    try:
        config = client.get_experiment_config(name)
        console.print(Syntax(yaml.dump(config.model_dump()), "yaml"))
    except Exception as e:
        console.print(f"[red]Error getting experiment config: {e}[/]")
        raise typer.Exit(1)


@config_app.command("edit")
def edit_config(
    experiment: str = typer.Argument(..., help="Name of the experiment"),
    dataset: str = typer.Argument(..., help="Name of the dataset to edit"),
):
    """Interactive configuration editor."""
    try:
        exp_config = client.get_experiment_config(experiment)
        if dataset not in exp_config.datasets:
            console.print(f"[red]Dataset {dataset} not found in experiment[/]")
            raise typer.Exit(1)
        current_config = exp_config.datasets[dataset]["config"]

        console.print(f"\nCurrent configuration for [cyan]{dataset}[/]:")
        console.print(Syntax(yaml.dump(current_config), "yaml"))

        # Interactive editing
        new_config = {}
        for key, value in current_config.items():
            new_value = Prompt.ask(f"{key}", default=str(value), show_default=True)

            # Try to convert the input back to the type of the current value
            try:
                if isinstance(value, bool):
                    new_value = new_value.lower() == "true"
                elif isinstance(value, int):
                    new_value = int(new_value)
                elif isinstance(value, float):
                    new_value = float(new_value)
            except ValueError:
                console.print(f"[yellow]Warning: Could not convert {new_value} to {type(value).__name__}[/]")

            new_config[key] = new_value

        # Show changes
        console.print("\nNew configuration:")
        console.print(Syntax(yaml.dump(new_config), "yaml"))

        if Confirm.ask("Apply these changes?"):
            try:
                config_update = DatasetConfigUpdate(config=new_config)
                client.update_dataset_config(experiment, dataset, config_update)
                console.print("[green]Configuration updated successfully[/]")
            except Exception as e:
                console.print(f"[red]Error updating configuration: {e}[/]")
                raise typer.Exit(1)
        else:
            console.print("[yellow]Update cancelled[/]")

    except typer.Exit:
        # Re-raise clean exits: typer.Exit derives from Exception, so without
        # this clause the handler below would swallow it and misreport an error.
        raise
    except Exception as e:
        console.print(f"[red]Error getting experiment configuration: {e}[/]")
        raise typer.Exit(1)


def main():
    """Entry point for the CLI."""
    app()


if __name__ == "__main__":
    main()
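A quick way to smoke-test these commands without installing a console script is Typer's test runner. A minimal sketch, assuming the top-level Typer object is named `app` and that `experiments_app` / `config_app` are mounted as the `experiments` and `config` sub-commands (those mount names are not shown in this hunk):

    from typer.testing import CliRunner

    runner = CliRunner()
    # Equivalent to: rgc experiments show my-exp
    result = runner.invoke(app, ["experiments", "show", "my-exp"])
    print(result.output)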
tools/server/__init__.py (new file)
"""
Reasoning Gym Server - A FastAPI server for managing reasoning gym experiments.
"""

from .config import ServerConfig
from .server import create_app

__all__ = ["create_app", "ServerConfig"]
tools/server/config.py (new file)
"""Server configuration using Pydantic settings management."""

from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict


class ServerConfig(BaseSettings):
    """Configuration settings for the Reasoning Gym server."""

    host: str = Field(default="localhost", description="Server host address")
    port: int = Field(default=8000, description="Server port")
    api_key: str = Field(
        default=..., description="API key for authentication", json_schema_extra={"env": "REASONING_GYM_API_KEY"}
    )
    log_level: str = Field(default="INFO", description="Logging level")

    # SettingsConfigDict (not plain ConfigDict) is what pydantic-settings
    # expects here; each field maps to a REASONING_GYM_<FIELD> variable.
    model_config = SettingsConfigDict(env_prefix="REASONING_GYM_")
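Only the API key is required; every other setting falls back to a default. A minimal sketch of building the config from the environment (the key value is illustrative):

    import os

    os.environ["REASONING_GYM_API_KEY"] = "dev-key"  # resolved via the REASONING_GYM_ prefix

    config = ServerConfig()
    assert config.host == "localhost" and config.port == 8000
    assert config.api_key == "dev-key"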
tools/server/middleware.py (new file)
"""API key middleware for FastAPI."""

from fastapi import Request
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.status import HTTP_401_UNAUTHORIZED


class APIKeyMiddleware(BaseHTTPMiddleware):
    """Middleware to check for a valid API key in request headers."""

    def __init__(self, app, api_key: str):
        super().__init__(app)
        self.api_key = api_key

    async def dispatch(self, request: Request, call_next):
        if request.url.path == "/health":
            return await call_next(request)

        api_key = request.headers.get("X-API-Key")
        if not api_key or api_key != self.api_key:
            # Return the response directly: exceptions raised inside
            # BaseHTTPMiddleware bypass FastAPI's HTTPException handlers
            # and would surface as a 500 instead of a 401.
            return JSONResponse(
                status_code=HTTP_401_UNAUTHORIZED,
                content={"detail": "Invalid or missing API key"},
            )

        return await call_next(request)
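Every route except /health therefore requires the key in an X-API-Key header. A minimal sketch of an authenticated request with httpx, assuming a server on localhost:8000 started with REASONING_GYM_API_KEY=dev-key:

    import httpx

    headers = {"X-API-Key": "dev-key"}
    resp = httpx.get("http://localhost:8000/experiments", headers=headers)
    resp.raise_for_status()
    print(resp.json())  # e.g. {"experiments": ["test_exp"]}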
tools/server/models.py (new file)
"""Pydantic models for API request/response data."""

from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Field


class ExperimentCreate(BaseModel):
    """Request model for creating a new experiment."""

    name: str = Field(..., description="Unique name for the experiment")
    size: int = Field(500, description="Size of the dataset")
    seed: Optional[int] = Field(None, description="Random seed for reproducibility")
    datasets: Dict[str, Dict[str, Any]] = Field(..., description="Dictionary of dataset configurations")


class ExperimentResponse(BaseModel):
    """Response model for experiment operations."""

    name: str = Field(..., description="Name of the experiment")
    size: int = Field(..., description="Size of the dataset")
    seed: Optional[int] = Field(None, description="Random seed used")
    datasets: Dict[str, Dict[str, Any]] = Field(..., description="Current dataset configurations")


class ExperimentList(BaseModel):
    """Response model for listing experiments."""

    experiments: List[str] = Field(default_factory=list, description="List of registered experiment names")


class DatasetConfigUpdate(BaseModel):
    """Request model for updating dataset configuration."""

    config: Dict[str, Any] = Field(..., description="Configuration parameters to update")


class ErrorResponse(BaseModel):
    """Response model for error conditions."""

    detail: str = Field(..., description="Error message")


class BatchEntry(BaseModel):
    """Single entry in a batch."""

    question: str = Field(..., description="The question text")
    entry_id: str = Field(..., description="Unique identifier in format '{version}.{index}'")
    metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata about the entry")


class BatchResponse(BaseModel):
    """Response containing a batch of entries."""

    entries: List[BatchEntry] = Field(..., description="List of batch entries")


class AnswerItem(BaseModel):
    """Single score item containing entry_id and answer."""

    entry_id: str = Field(..., description="Entry identifier to score")
    answer: str = Field(..., description="Answer to evaluate")


class ScoringRequest(BaseModel):
    """Request for scoring model outputs."""

    answers: List[AnswerItem] = Field(..., description="List of entries to score")


class ScoringResponse(BaseModel):
    """Response containing scores for answers."""

    scores: List[float] = Field(..., description="List of scores in the same order as the request")
    entry_ids: List[str] = Field(..., description="List of entry_ids in the same order as the request")
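These models pin down the JSON wire format on both sides. A minimal sketch of building the two main request payloads (the chain_sum config keys mirror the test fixtures below; the entry_id value is illustrative, real ones come from a batch response):

    payload = ExperimentCreate(
        name="demo",
        size=10,
        seed=42,
        datasets={"chain_sum": {"weight": 1.0, "config": {"min_terms": 2, "max_terms": 4}}},
    )
    print(payload.model_dump_json())

    # entry_id follows the '{version}.{index}' format described above
    scoring = ScoringRequest(answers=[AnswerItem(entry_id="1.0", answer="4")])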
tools/server/server.py (new file)
"""FastAPI server implementation for Reasoning Gym."""

import logging

from fastapi import FastAPI, HTTPException

from reasoning_gym.coaching.registry import ExperimentRegistry
from reasoning_gym.composite import CompositeConfig, DatasetSpec

from .config import ServerConfig
from .middleware import APIKeyMiddleware
from .models import (
    BatchEntry,
    BatchResponse,
    DatasetConfigUpdate,
    ExperimentCreate,
    ExperimentList,
    ExperimentResponse,
    ScoringRequest,
    ScoringResponse,
)


def create_app(config: ServerConfig) -> FastAPI:
    """Create and configure the FastAPI application."""

    # Configure logging
    logging.basicConfig(level=config.log_level)
    logger = logging.getLogger(__name__)

    # Create FastAPI app
    app = FastAPI(title="Reasoning Gym Server")

    # Add middleware
    app.add_middleware(APIKeyMiddleware, api_key=config.api_key)

    # Initialize registry
    registry = ExperimentRegistry()

    @app.get("/health")
    async def health_check():
        """Health check endpoint."""
        return {"status": "healthy"}

    @app.post("/experiments", response_model=ExperimentResponse)
    async def create_experiment(experiment: ExperimentCreate):
        """Create a new experiment."""
        try:
            # Convert dict format to DatasetSpec list; invalid specs should
            # surface as a 400 rather than an unhandled 500.
            dataset_specs = []
            for name, spec in experiment.datasets.items():
                dataset_specs.append(
                    DatasetSpec(name=name, weight=spec.get("weight", 1.0), config=spec.get("config", {}))
                )

            config = CompositeConfig(size=experiment.size, seed=experiment.seed, datasets=dataset_specs)
            registry.register_experiment(experiment.name, config)
        except Exception as e:
            raise HTTPException(status_code=400, detail=str(e))

        return ExperimentResponse(
            name=experiment.name, size=experiment.size, seed=experiment.seed, datasets=experiment.datasets
        )

    @app.get("/experiments", response_model=ExperimentList)
    async def list_experiments():
        """List all registered experiments."""
        return ExperimentList(experiments=registry.list_experiments())

    @app.delete("/experiments/{name}")
    async def delete_experiment(name: str):
        """Delete an experiment."""
        if not registry.remove_experiment(name):
            raise HTTPException(status_code=404, detail=f"Experiment '{name}' not found")
        return {"status": "deleted"}

    @app.get("/experiments/{name}/batch", response_model=BatchResponse)
    async def generate_batch(name: str, base_index: int, batch_size: int):
        """Generate a batch of raw entries."""
        # Validate parameters
        if base_index < 0:
            raise HTTPException(status_code=400, detail="base_index must be non-negative")
        if batch_size <= 0:
            raise HTTPException(status_code=400, detail="batch_size must be positive")

        experiment = registry.get_experiment(name)
        if not experiment:
            raise HTTPException(status_code=404, detail=f"Experiment '{name}' not found")

        try:
            entries = []
            for i in range(base_index, base_index + batch_size):
                entry = experiment.dataset[i]

                # Create BatchEntry with minimal required data
                batch_entry = BatchEntry(
                    question=entry["question"],
                    entry_id=f"{entry['metadata']['version_id']}.{i}",
                    metadata=entry["metadata"],
                )
                entries.append(batch_entry)

            return BatchResponse(entries=entries)

        except Exception as e:
            raise HTTPException(status_code=400, detail=str(e))

    @app.post("/experiments/{name}/score", response_model=ScoringResponse)
    async def score_outputs(name: str, request: ScoringRequest):
        """Score extracted answers."""
        experiment = registry.get_experiment(name)
        if not experiment:
            raise HTTPException(status_code=404, detail=f"Experiment '{name}' not found")

        try:
            scores = []
            entry_ids = []
            for item in request.answers:
                score = experiment.dataset.score_answer_with_id(item.answer, item.entry_id)
                scores.append(score)
                entry_ids.append(item.entry_id)

            return ScoringResponse(scores=scores, entry_ids=entry_ids)

        except Exception as e:
            raise HTTPException(status_code=400, detail=str(e))

    @app.get("/experiments/{name}/composite", response_model=ExperimentResponse)
    async def get_composite_config(name: str):
        """Get composite configuration for an experiment."""
        experiment = registry.get_experiment(name)
        if not experiment:
            raise HTTPException(status_code=404, detail=f"Experiment '{name}' not found")

        # Convert internal config to API response format
        datasets = {}
        for ds_spec in experiment.config.datasets:
            dataset = experiment.dataset.datasets[ds_spec.name]
            datasets[ds_spec.name] = {
                "weight": ds_spec.weight,
                "config": vars(dataset.config),  # Get current config from dataset instance
            }

        return ExperimentResponse(
            name=name, size=experiment.config.size, seed=experiment.config.seed, datasets=datasets
        )

    @app.post("/experiments/{name}/composite/{dataset_name}")
    async def update_dataset_config(name: str, dataset_name: str, config_update: DatasetConfigUpdate):
        """Update configuration for a specific dataset in the composite."""
        experiment = registry.get_experiment(name)
        if not experiment:
            raise HTTPException(status_code=404, detail=f"Experiment '{name}' not found")

        try:
            experiment.dataset.update_dataset_config(dataset_name, config_update.config)
            return {"status": "updated"}
        except KeyError:
            raise HTTPException(status_code=404, detail=f"Dataset '{dataset_name}' not found in experiment")
        except Exception as e:
            raise HTTPException(status_code=400, detail=str(e))

    return app


async def app(scope, receive, send):
    """ASGI application that lazily creates the FastAPI app."""
    # A function attribute serves as a one-shot cache, so uvicorn can target
    # this module's `app` without loading ServerConfig at import time.
    if not hasattr(app, "server_app"):
        app.server_app = create_app(ServerConfig())
    await app.server_app(scope, receive, send)
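The module-level `app` wrapper lets `uvicorn tools.server.server:app` start the server without config being read at import time; for explicit configuration, drive the factory directly. A minimal sketch, assuming the package is importable as `tools.server`:

    import uvicorn

    from tools.server import ServerConfig, create_app

    config = ServerConfig(api_key="dev-key")  # normally supplied via REASONING_GYM_API_KEY
    uvicorn.run(create_app(config), host=config.host, port=config.port)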
tools/server/tests/__init__.py (new file)
"""Tests for the Reasoning Gym server."""
tools/server/tests/test_config.py (new file)
"""Tests for server configuration."""

import os

import pytest

from ..config import ServerConfig


def test_default_config():
    """Test default configuration values."""
    os.environ["REASONING_GYM_API_KEY"] = "test-key"
    config = ServerConfig()

    assert config.host == "localhost"
    assert config.port == 8000
    assert config.api_key == "test-key"
    assert config.log_level == "INFO"


def test_missing_api_key():
    """Test that missing API key raises an error."""
    if "REASONING_GYM_API_KEY" in os.environ:
        del os.environ["REASONING_GYM_API_KEY"]

    with pytest.raises(ValueError):
        ServerConfig()
tools/server/tests/test_endpoints.py (new file)
"""Tests for API endpoints."""

import pytest
from fastapi.testclient import TestClient

from ..config import ServerConfig
from ..server import create_app


@pytest.fixture
def client():
    """Create a test client."""
    config = ServerConfig(host="localhost", port=8000, api_key="test-key", log_level="INFO")
    app = create_app(config)
    return TestClient(app)


def test_health_check(client):
    """Test health check endpoint."""
    response = client.get("/health")
    assert response.status_code == 200
    assert response.json() == {"status": "healthy"}


def test_experiment_endpoints(client):
    """Test experiment management endpoints."""
    # Set API key
    headers = {"X-API-Key": "test-key"}

    # Create experiment
    create_data = {
        "name": "test_exp",
        "size": 10,
        "seed": 42,
        "datasets": {
            "chain_sum": {
                "weight": 1.0,
                "config": {
                    "min_terms": 2,
                    "max_terms": 4,
                    "min_digits": 1,
                    "max_digits": 2,
                    "allow_negation": False,
                    "size": 10,
                    "seed": 42,
                },
            }
        },
    }

    response = client.post("/experiments", json=create_data, headers=headers)
    assert response.status_code == 200
    assert response.json()["name"] == "test_exp"

    # List experiments
    response = client.get("/experiments", headers=headers)
    assert response.status_code == 200
    assert "test_exp" in response.json()["experiments"]

    # Delete experiment
    response = client.delete("/experiments/test_exp", headers=headers)
    assert response.status_code == 200

    # Verify deletion
    response = client.get("/experiments", headers=headers)
    assert response.status_code == 200
    assert "test_exp" not in response.json()["experiments"]

    # Try to delete non-existent experiment
    response = client.delete("/experiments/nonexistent", headers=headers)
    assert response.status_code == 404


def test_batch_generation_endpoint(client):
    """Test batch generation endpoint."""
    headers = {"X-API-Key": "test-key"}

    # Create test experiment
    create_data = {
        "name": "test_exp",
        "size": 10,
        "seed": 42,
        "datasets": {
            "chain_sum": {
                "weight": 1.0,
                "config": {
                    "min_terms": 2,
                    "max_terms": 4,
                    "min_digits": 1,
                    "max_digits": 2,
                    "allow_negation": False,
                    "size": 10,
                    "seed": 42,
                },
            }
        },
    }

    response = client.post("/experiments", json=create_data, headers=headers)
    assert response.status_code == 200

    # Test batch generation
    response = client.get(
        "/experiments/test_exp/batch",
        params={"base_index": 0, "batch_size": 2},
        headers=headers,
    )
    assert response.status_code == 200
    data = response.json()

    # Verify batch structure
    assert "entries" in data
    assert len(data["entries"]) == 2

    # Verify entry structure
    entry = data["entries"][0]
    assert "question" in entry
    assert "entry_id" in entry
    assert "metadata" in entry

    # Test error cases
    # Non-existent experiment
    response = client.get(
        "/experiments/nonexistent/batch",
        params={"base_index": 0, "batch_size": 2},
        headers=headers,
    )
    assert response.status_code == 404

    # Invalid parameters
    response = client.get(
        "/experiments/test_exp/batch",
        params={"base_index": -1, "batch_size": 2},
        headers=headers,
    )
    assert response.status_code == 400


def test_scoring_endpoint(client):
    """Test answer scoring endpoint."""
    headers = {"X-API-Key": "test-key"}

    # Create test experiment
    create_data = {
        "name": "test_exp",
        "size": 10,
        "seed": 42,
        "datasets": {
            "chain_sum": {
                "weight": 1.0,
                "config": {
                    "min_terms": 2,
                    "max_terms": 4,
                    "min_digits": 1,
                    "max_digits": 2,
                    "allow_negation": False,
                    "size": 10,
                    "seed": 42,
                },
            }
        },
    }

    response = client.post("/experiments", json=create_data, headers=headers)
    assert response.status_code == 200

    # Get a batch to get valid entry_ids
    response = client.get(
        "/experiments/test_exp/batch",
        params={"base_index": 0, "batch_size": 2},
        headers=headers,
    )
    assert response.status_code == 200
    batch = response.json()
    entry_id = batch["entries"][0]["entry_id"]

    # Test scoring with correct answer
    response = client.post(
        "/experiments/test_exp/score",
        json={"answers": [{"entry_id": entry_id, "answer": "4"}]},  # assuming 2+2=4 is the first question
        headers=headers,
    )
    assert response.status_code == 200
    result = response.json()
    assert "scores" in result
    assert "entry_ids" in result
    assert len(result["scores"]) == 1
    assert len(result["entry_ids"]) == 1
    assert result["entry_ids"][0] == entry_id
    assert isinstance(result["scores"][0], float)
    assert 0 <= result["scores"][0] <= 1

    # Test scoring with wrong answer
    response = client.post(
        "/experiments/test_exp/score",
        json={"answers": [{"entry_id": entry_id, "answer": "wrong"}]},
        headers=headers,
    )
    assert response.status_code == 200
    result = response.json()
    assert result["scores"][0] < 1.0
    assert result["entry_ids"][0] == entry_id

    # Test error cases
    # Invalid entry_id format
    response = client.post(
        "/experiments/test_exp/score",
        json={"answers": [{"entry_id": "invalid_id", "answer": "4"}]},
        headers=headers,
    )
    assert response.status_code == 400

    # Non-existent experiment
    response = client.post(
        "/experiments/nonexistent/score",
        json={"answers": [{"entry_id": entry_id, "answer": "4"}]},
        headers=headers,
    )
    assert response.status_code == 404


def test_composite_config_endpoints(client):
    """Test composite configuration endpoints."""
    headers = {"X-API-Key": "test-key"}

    # Create an experiment first
    create_data = {
        "name": "test_exp",
        "size": 10,
        "seed": 42,
        "datasets": {
            "chain_sum": {
                "weight": 1.0,
                "config": {
                    "min_terms": 2,
                    "max_terms": 4,
                    "min_digits": 1,
                    "max_digits": 2,
                    "allow_negation": False,
                    "size": 10,
                    "seed": 42,
                },
            }
        },
    }

    response = client.post("/experiments", json=create_data, headers=headers)
    assert response.status_code == 200

    # Get composite config
    response = client.get("/experiments/test_exp/composite", headers=headers)
    assert response.status_code == 200
    config = response.json()
    assert config["name"] == "test_exp"
    assert "chain_sum" in config["datasets"]

    # Update dataset config
    update_data = {"config": {"min_terms": 3, "max_terms": 5}}
    response = client.post("/experiments/test_exp/composite/chain_sum", json=update_data, headers=headers)
    assert response.status_code == 200

    # Verify update
    response = client.get("/experiments/test_exp/composite", headers=headers)
    assert response.status_code == 200
    config = response.json()
    assert config["datasets"]["chain_sum"]["config"]["min_terms"] == 3
    assert config["datasets"]["chain_sum"]["config"]["max_terms"] == 5

    # Test error cases
    # Non-existent experiment
    response = client.get("/experiments/nonexistent/composite", headers=headers)
    assert response.status_code == 404

    # Non-existent dataset
    response = client.post("/experiments/test_exp/composite/nonexistent", json=update_data, headers=headers)
    assert response.status_code == 404
tools/server/tests/test_registry.py (new file)
"""Tests for experiment registry."""

from reasoning_gym.arithmetic.chain_sum import ChainSumConfig
from reasoning_gym.coaching.registry import ExperimentRegistry
from reasoning_gym.composite import CompositeConfig, CompositeDataset, DatasetSpec


def test_singleton():
    """Test that ExperimentRegistry is a singleton."""
    registry1 = ExperimentRegistry()
    registry2 = ExperimentRegistry()
    assert registry1 is registry2


def test_experiment_management():
    """Test basic experiment management operations."""
    registry = ExperimentRegistry()

    # Clear any existing experiments
    for name in registry.list_experiments():
        registry.remove_experiment(name)

    # Test registration with chain_sum dataset
    chain_sum_spec = DatasetSpec(name="chain_sum", weight=1.0, config=vars(ChainSumConfig(size=10, seed=42)))

    config = CompositeConfig(size=10, seed=42, datasets=[chain_sum_spec])
    registry.register_experiment("test_exp", config)

    # Test listing
    assert "test_exp" in registry.list_experiments()

    # Test retrieval
    exp = registry.get_experiment("test_exp")
    assert exp is not None
    assert exp.name == "test_exp"
    assert isinstance(exp.dataset, CompositeDataset)
    assert exp.config == config

    # Test removal
    assert registry.remove_experiment("test_exp")
    assert "test_exp" not in registry.list_experiments()
    assert not registry.remove_experiment("nonexistent")