Basic curriculum (#198)

* feat: Add optional curriculum support to dataset registration and creation
* docs: Add docstrings to create_curriculum() and register_dataset()
* feat: Add curriculum configuration classes for CurriculumExperiment
* feat: Add weight parameter to CurriculumAttributeConfig and use in DatasetSpec
* refactor: Simplify CurriculumAttributeConfig with "*" attribute level support
* test: Add unit tests for CurriculumExperiment class
* feat: Add from_yaml() method to CurriculumExperimentConfig with unit test
Author: Andreas Köpf
Date: 2025-03-07 11:22:12 +01:00
Committed by: GitHub
Parent: cbfdf097a0
Commit: c69bc5d4e6
29 changed files with 943 additions and 63 deletions
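The bullets above introduce the new curriculum API end to end. As a quick orientation, here is a hedged usage sketch assembled from the files in this diff (the class names, YAML shape, and entry metadata come from the example trainer and tests below; the exact import paths are assumptions):

from io import StringIO

from reasoning_gym.coaching.curriculum_config import CurriculumExperimentConfig
from reasoning_gym.coaching.experiment import CurriculumExperiment

# Per-dataset attribute levels; "*" sets every attribute of a curriculum,
# and `weight` controls how often the dataset is sampled in the composite.
yaml_text = """
curricula:
  leg_counting:
    attribute_levels:
      "*": 2
      num_animals: 4
    weight: 1.5
"""

config = CurriculumExperimentConfig.from_yaml_stream(StringIO(yaml_text))
experiment = CurriculumExperiment("demo", config, size=100, seed=42)
entry = experiment.get_dataset_entry(0)  # dict with "question", "answer", "metadata"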

View File

@@ -0,0 +1,9 @@
#!/bin/bash
export N_GPUS=4
export BASE_MODEL=meta-llama/Llama-3.2-3B-Instruct # meta-llama/Llama-3.2-1B-Instruct
export ROLLOUT_TP_SIZE=2
export EXPERIMENT_NAME=basic_curriculum
export VLLM_ATTENTION_BACKEND=XFORMERS
bash ./train_grpo.sh

View File

@@ -0,0 +1,298 @@
# This example is an adapted version of Bytedance's code:
# https://github.com/volcengine/verl/blob/a65c9157bc0b85b64cd753de19f94e80a11bd871/verl/trainer/main_ppo.py
from io import StringIO
from typing import Optional
import hydra
import ray
import torch
import verl.utils.torch_functional as verl_F
from omegaconf import OmegaConf, open_dict
from torch.utils.data import DataLoader, Dataset
from transformers import PreTrainedTokenizer
from verl import DataProto
from verl.trainer.ppo.ray_trainer import RayPPOTrainer
from verl.utils.dataset.rl_dataset import collate_fn
from verl.utils.model import compute_position_id_with_mask
import reasoning_gym
import reasoning_gym.utils
from reasoning_gym.coaching.curriculum_config import CurriculumExperimentConfig
from reasoning_gym.coaching.experiment import CurriculumExperiment
from reasoning_gym.utils import extract_answer
curriculum_config_yaml = """
curricula:
  leg_counting:
    attribute_levels:
      num_animals: 2
  products:
    attribute_levels:
      num_terms: 4
      num_digits: 4
  chain_sum:
    attribute_levels:
      num_terms: 4
      num_digits: 4
    weight: 1.0
"""
class ReasoningGymDataset(Dataset):
def __init__(
self,
tokenizer: PreTrainedTokenizer,
experiment_name: str,
seed: int,
size: int,
developer_prompt: Optional[str] = None,
developer_role: str = "system",
max_prompt_length: int = 2048,
truncation: str = "error",  # one of: 'left', 'right', 'error'
return_raw_chat: bool = False,
):
self.tokenizer = tokenizer
curriculum_config = CurriculumExperimentConfig.from_yaml_stream(StringIO(curriculum_config_yaml))
self.experiment = CurriculumExperiment(experiment_name, curriculum_config, size=size, seed=seed)
self.developer_prompt = developer_prompt
self.developer_role = developer_role
self.max_prompt_length = max_prompt_length
self.truncation = truncation
self.return_raw_chat = return_raw_chat
def __len__(self) -> int:
return len(self.experiment.composite)
def __getitem__(self, index: int):
row_dict = self.experiment.get_dataset_entry(index).copy()
q = row_dict["question"]
chat = []
if self.developer_prompt is not None:
chat.append({"role": self.developer_role, "content": self.developer_prompt})
chat.append({"role": "user", "content": q})
prompt = self.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
input_ids, attention_mask = verl_F.tokenize_and_postprocess_data(
prompt=prompt,
tokenizer=self.tokenizer,
max_length=self.max_prompt_length,
pad_token_id=self.tokenizer.pad_token_id,
left_pad=True,
truncation=self.truncation,
)
position_ids = compute_position_id_with_mask(attention_mask)
row_dict["data_source"] = "reasoning_gym/" + self.dataset_name
row_dict["input_ids"] = input_ids[0]
row_dict["attention_mask"] = attention_mask[0]
row_dict["position_ids"] = position_ids[0]
# encode prompts without chat template
if self.return_raw_chat:
row_dict["raw_prompt"] = chat.tolist()
return row_dict
class RayPPOTrainerCustom(RayPPOTrainer):
def __init__(
self,
config,
tokenizer,
role_worker_mapping: dict,
resource_pool_manager,
ray_worker_group_cls,
experiment_name: str = "basic_curriculum",
dataset_size: int = 10000,
):
self.dataset_size = dataset_size
developer_prompt = reasoning_gym.utils.SYSTEM_PROMPTS["DeepSeekZero"]
self.train_dataset = ReasoningGymDataset(
tokenizer=tokenizer,
experiment_name=experiment_name,
seed=1,
size=self.dataset_size,
developer_prompt=developer_prompt,
)
self.val_dataset = ReasoningGymDataset(
tokenizer=tokenizer,
experiment_name=experiment_name,
seed=2,
size=self.dataset_size,
developer_prompt=developer_prompt,
)
train_reward_fn = lambda data: self._score_output(data, num_examine=0)
val_reward_fn = lambda data: self._score_output(data, num_examine=1)
super().__init__(
config,
tokenizer,
role_worker_mapping,
resource_pool_manager,
ray_worker_group_cls,
train_reward_fn,
val_reward_fn,
)
def _score_output(self, data: DataProto, num_examine: int = 0) -> torch.Tensor:
reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
num_printed = 0
for i in range(len(data)):
data_item = data[i] # DataProtoItem
prompt_ids = data_item.batch["prompts"] # tokenized prompts
prompt_length = prompt_ids.shape[-1]
valid_prompt_length = data_item.batch["attention_mask"][:prompt_length].sum()
valid_prompt_ids = prompt_ids[-valid_prompt_length:]
response_ids = data_item.batch["responses"]
valid_response_length = data_item.batch["attention_mask"][prompt_length:].sum()
valid_response_ids = response_ids[:valid_response_length]
# decode
sequences = torch.cat((valid_prompt_ids, valid_response_ids))
sequences_str = self.tokenizer.decode(sequences)
entry_id = data_item.non_tensor_batch["metadata"]["entry_id"]
score = self._compute_score(
solution_str=sequences_str,
entry_id=entry_id,
)
reward_tensor[i, valid_response_length - 1] = score
if num_printed < num_examine:
print(f"reward={score}, seq={sequences_str}")
num_printed += 1
return reward_tensor
def _compute_score(self, solution_str: str, entry_id: str) -> float:
found_answer = extract_answer(solution_str, tag_name="answer")
reward = self.train_dataset.experiment.score_answer_with_id(found_answer, entry_id=entry_id)
print(f"entry_id: {entry_id}; found answer={found_answer}; reward: {reward};")
return reward
def _create_dataloader(self):
self.train_dataloader = DataLoader(
dataset=self.train_dataset,
batch_size=self.config.data.train_batch_size,
shuffle=True,
drop_last=True,
collate_fn=collate_fn,
)
self.val_dataloader = DataLoader(
dataset=self.val_dataset,
batch_size=len(self.val_dataset),
shuffle=True,
drop_last=True,
collate_fn=collate_fn,
)
assert len(self.train_dataloader) >= 1
assert len(self.val_dataloader) >= 1
print(f"Size of train dataloader: {len(self.train_dataloader)}")
print(f"Size of val dataloader: {len(self.val_dataloader)}")
# inject total_training_steps to actor/critic optim_config. This is hacky.
total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs
if self.config.trainer.total_training_steps is not None:
total_training_steps = self.config.trainer.total_training_steps
self.total_training_steps = total_training_steps
print(f"Total training steps: {self.total_training_steps}")
OmegaConf.set_struct(self.config, True)
with open_dict(self.config):
self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps
self.config.critic.optim.total_training_steps = total_training_steps
@ray.remote
def main_task(config):
# print initial config
from pprint import pprint
from verl.utils import hf_tokenizer
from verl.utils.fs import copy_local_path_from_hdfs
pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values
OmegaConf.resolve(config)
# download the checkpoint from hdfs
local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path)
# instantiate tokenizer
tokenizer = hf_tokenizer(local_path)
# define worker classes
if config.actor_rollout_ref.actor.strategy == "fsdp":
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
from verl.single_controller.ray import RayWorkerGroup
from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
ray_worker_group_cls = RayWorkerGroup
elif config.actor_rollout_ref.actor.strategy == "megatron":
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
ray_worker_group_cls = NVMegatronRayWorkerGroup
else:
raise NotImplementedError
from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role
role_worker_mapping = {
Role.ActorRollout: ray.remote(ActorRolloutRefWorker),
Role.Critic: ray.remote(CriticWorker),
Role.RefPolicy: ray.remote(ActorRolloutRefWorker),
}
global_pool_id = "global_pool"
resource_pool_spec = {
global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
}
mapping = {
Role.ActorRollout: global_pool_id,
Role.Critic: global_pool_id,
Role.RefPolicy: global_pool_id,
}
resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
trainer = RayPPOTrainerCustom(
config=config,
tokenizer=tokenizer,
role_worker_mapping=role_worker_mapping,
resource_pool_manager=resource_pool_manager,
ray_worker_group_cls=ray_worker_group_cls,
)
trainer.init_workers()
trainer.fit()
@hydra.main(config_path="config", config_name="ppo_trainer", version_base=None)
def main(config):
if not ray.is_initialized():
# this is for local ray cluster
ray.init(runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}})
ray.get(main_task.remote(config))
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,39 @@
#!/bin/bash
set -x
python3 -u ppo_curriculum.py \
algorithm.adv_estimator=grpo \
data.train_files=$DATA_DIR/train.parquet \
data.val_files=$DATA_DIR/test.parquet \
data.train_batch_size=512 \
data.val_batch_size=512 \
data.max_prompt_length=512 \
data.max_response_length=1024 \
actor_rollout_ref.model.path=$BASE_MODEL \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=80 \
actor_rollout_ref.actor.use_kl_loss=True \
actor_rollout_ref.actor.kl_loss_coef=0.001 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=160 \
actor_rollout_ref.rollout.tensor_model_parallel_size=$ROLLOUT_TP_SIZE \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
actor_rollout_ref.rollout.n=8 \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=160 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
algorithm.kl_ctrl.kl_coef=0.001 \
trainer.critic_warmup=0 \
trainer.logger=['wandb'] \
trainer.project_name='verl_chain_sum_grpo' \
trainer.experiment_name=$EXPERIMENT_NAME \
trainer.n_gpus_per_node=$N_GPUS \
trainer.nnodes=1 \
trainer.save_freq=100 \
trainer.test_freq=100 \
trainer.total_epochs=15 $@ 2>&1 | tee verl_output.log

View File

@@ -0,0 +1,171 @@
data:
  tokenizer: null
  train_files: ~/data/rlhf/gsm8k/train.parquet
  val_files: ~/data/rlhf/gsm8k/test.parquet
  prompt_key: prompt
  max_prompt_length: 512
  max_response_length: 512
  train_batch_size: 1024
  val_batch_size: 1312
  return_raw_input_ids: False # This should be set to true when the tokenizer between policy and rm differs
  return_raw_chat: False
actor_rollout_ref:
  hybrid_engine: True
  model:
    path: ~/models/deepseek-llm-7b-chat
    external_lib: null
    override_config: { }
    enable_gradient_checkpointing: True
    use_remove_padding: False
  actor:
    strategy: fsdp # This is for backward-compatibility
    ppo_mini_batch_size: 256
    ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu
    ppo_micro_batch_size_per_gpu: null
    use_dynamic_bsz: False
    ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length}
    grad_clip: 1.0
    clip_ratio: 0.2
    entropy_coeff: 0.001
    use_kl_loss: False # True for GRPO
    kl_loss_coef: 0.001 # for grpo
    kl_loss_type: low_var_kl # for grpo
    ppo_epochs: 1
    shuffle: False
    ulysses_sequence_parallel_size: 1 # sp size
    optim:
      lr: 1e-6
      lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
      min_lr_ratio: null # only useful for warmup with cosine
      warmup_style: constant # select from constant/cosine
      total_training_steps: -1 # must be overridden by the program
    fsdp_config:
      wrap_policy:
        # transformer_layer_cls_to_wrap: None
        min_num_params: 0
      param_offload: False
      optimizer_offload: False
      fsdp_size: -1
  ref:
    fsdp_config:
      param_offload: False
      wrap_policy:
        # transformer_layer_cls_to_wrap: None
        min_num_params: 0
    log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu
    log_prob_micro_batch_size_per_gpu: null
    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
    ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size
  rollout:
    name: vllm
    temperature: 1.0
    top_k: -1 # 0 for hf rollout, -1 for vllm rollout
    top_p: 1
    prompt_length: ${data.max_prompt_length} # not used for open-source rollout
    response_length: ${data.max_response_length}
    # for vllm rollout
    dtype: bfloat16 # should align with FSDP
    gpu_memory_utilization: 0.5
    ignore_eos: False
    enforce_eager: True
    free_cache_engine: True
    load_format: dummy_dtensor
    tensor_model_parallel_size: 2
    max_num_batched_tokens: 8192
    max_num_seqs: 1024
    log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu
    log_prob_micro_batch_size_per_gpu: null
    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
    disable_log_stats: True
    enable_chunked_prefill: True # could get higher throughput
    # for hf rollout
    do_sample: True
    # number of responses (i.e. num sample times)
    n: 1 # > 1 for grpo
critic:
  strategy: fsdp
  optim:
    lr: 1e-5
    lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
    min_lr_ratio: null # only useful for warmup with cosine
    warmup_style: constant # select from constant/cosine
    total_training_steps: -1 # must be overridden by the program
  model:
    path: ~/models/deepseek-llm-7b-chat
    tokenizer_path: ${actor_rollout_ref.model.path}
    override_config: { }
    external_lib: ${actor_rollout_ref.model.external_lib}
    enable_gradient_checkpointing: True
    use_remove_padding: False
    fsdp_config:
      param_offload: False
      optimizer_offload: False
      wrap_policy:
        # transformer_layer_cls_to_wrap: None
        min_num_params: 0
      fsdp_size: -1
  ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
  ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu
  ppo_micro_batch_size_per_gpu: null
  forward_micro_batch_size: ${critic.ppo_micro_batch_size}
  forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu}
  use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
  ppo_max_token_len_per_gpu: 32768 # (${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}) * 2
  forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu}
  ulysses_sequence_parallel_size: 1 # sp size
  ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}
  shuffle: ${actor_rollout_ref.actor.shuffle}
  grad_clip: 1.0
  cliprange_value: 0.5
reward_model:
  enable: False
  strategy: fsdp
  model:
    input_tokenizer: ${actor_rollout_ref.model.path} # set this to null if the chat template is identical
    path: ~/models/FsfairX-LLaMA3-RM-v0.1
    external_lib: ${actor_rollout_ref.model.external_lib}
    use_remove_padding: False
    fsdp_config:
      min_num_params: 0
      param_offload: False
      fsdp_size: -1
  micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu
  micro_batch_size_per_gpu: null # set a number
  max_length: null
  ulysses_sequence_parallel_size: 1 # sp size
  use_dynamic_bsz: ${critic.use_dynamic_bsz}
  forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu}
algorithm:
  gamma: 1.0
  lam: 1.0
  adv_estimator: gae
  kl_penalty: kl # how to estimate kl divergence
  kl_ctrl:
    type: fixed
    kl_coef: 0.001
trainer:
  total_epochs: 30
  total_training_steps: null
  project_name: verl_examples
  experiment_name: gsm8k
  logger: [ 'console', 'wandb' ]
  val_generations_to_log_to_wandb: 0
  nnodes: 1
  n_gpus_per_node: 8
  save_freq: -1
  # auto: find the last ckpt to resume. If can't find, start from scratch
  resume_mode: auto # or resume_path if resume_from_path is set
  resume_from_path: False
  test_freq: -1
  critic_warmup: 0
  default_hdfs_dir: null
  remove_previous_ckpt_in_save: False
  del_local_ckpt_after_load: False
  default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}

View File

@@ -123,7 +123,7 @@ class ChainSumCurriculum(BaseCurriculum):
self._define_attributes(
RangeAttributeDefinition(
name="num_terms",
levels=[2, 3, 4, 5],
levels=list(range(2, 13)),
default_level=0, # Start with 2 terms
description="Maximum number of terms in the expression",
attr_type=AttributeType.APPEND,
@@ -133,7 +133,7 @@ class ChainSumCurriculum(BaseCurriculum):
),
RangeAttributeDefinition(
name="num_digits",
levels=[1, 2, 4, 10],
levels=list(range(1, 11)),
default_level=0, # Start with 1-digit numbers
description="Number of digits in each operand",
attr_type=AttributeType.APPEND,
@@ -145,4 +145,4 @@ class ChainSumCurriculum(BaseCurriculum):
# Register the dataset
register_dataset("chain_sum", ChainSumDataset, ChainSumConfig)
register_dataset("chain_sum", ChainSumDataset, ChainSumConfig, ChainSumCurriculum)

View File

@@ -65,4 +65,4 @@ class CountBitsCurriculum(BaseCurriculum):
)
register_dataset("count_bits", CountBitsDataset, CountBitsConfig)
register_dataset("count_bits", CountBitsDataset, CountBitsConfig, CountBitsCurriculum)

View File

@@ -4,6 +4,9 @@ from dataclasses import dataclass
from random import Random
from typing import Optional
from reasoning_gym.coaching.attributes import AttributeType, RangeAttributeDefinition
from reasoning_gym.coaching.base_curriculum import BaseCurriculum
from ..factory import ProceduralDataset, register_dataset
ANIMALS = {
@@ -124,4 +127,23 @@ class LegCountingDataset(ProceduralDataset):
}
register_dataset("leg_counting", LegCountingDataset, LegCountingConfig)
class LegCountingCurriculum(BaseCurriculum):
def __init__(self):
super().__init__(LegCountingCurriculum.__name__, LegCountingConfig)
# Define attributes
self._define_attributes(
RangeAttributeDefinition(
name="num_animals",
levels=list(range(1, 20)),
default_level=0, # Start with 1 animal
description="Number of animals in question",
attr_type=AttributeType.APPEND,
min_value=1, # Ensure at least 1 animal
lower_field_name="min_animals",
upper_field_name="max_animals",
),
)
register_dataset("leg_counting", LegCountingDataset, LegCountingConfig, LegCountingCurriculum)

View File

@@ -115,7 +115,7 @@ class ProductsCurriculum(BaseCurriculum):
self._define_attributes(
RangeAttributeDefinition(
name="num_terms",
levels=[2, 3, 4, 5],
levels=list(range(2, 13)),
default_level=0, # Start with 2 terms
description="Maximum number of terms in the expression",
attr_type=AttributeType.APPEND,
@@ -125,7 +125,7 @@ class ProductsCurriculum(BaseCurriculum):
),
RangeAttributeDefinition(
name="num_digits",
levels=[1, 2, 3, 4],
levels=list(range(1, 11)),
default_level=0, # Start with 1-digit numbers
description="Number of digits in each operand",
attr_type=AttributeType.APPEND,
@@ -137,4 +137,4 @@ class ProductsCurriculum(BaseCurriculum):
# Register the dataset
register_dataset("products", ProductsDataset, ProductsConfig)
register_dataset("products", ProductsDataset, ProductsConfig, ProductsCurriculum)

View File

@@ -1,8 +1,9 @@
from typing import Any, Iterable, Optional
from typing import Any, Optional, TypeVar
from ..factory import ConfigT
from .attributes import AttributeDefinition, RangeAttributeDefinition, ScalarAttributeDefinition
ConfigT = TypeVar("ConfigT")
class BaseCurriculum:
def __init__(self, name: str, config_cls: ConfigT):
@@ -21,7 +22,6 @@ class BaseCurriculum:
elif isinstance(attr, ScalarAttributeDefinition):
val = self.get_attr_value(attr.name)
config_args[attr.field_name] = val
print(config_args)
return self._config_cls(**config_args)
@property

View File

@@ -0,0 +1,91 @@
from dataclasses import dataclass
from typing import Dict, Optional
import yaml
@dataclass
class CurriculumAttributeConfig:
"""Configuration for curriculum attribute levels"""
# Dictionary mapping attribute names to levels
# Special key "*" means apply that level to all attributes
attribute_levels: Dict[str, int]
# Weight for sampling this dataset
weight: float = 1.0
def validate(self):
"""Validate the configuration"""
if not self.attribute_levels:
raise ValueError("Must specify at least one attribute level")
@dataclass
class CurriculumExperimentConfig:
"""Configuration for curriculum experiments"""
# Dictionary mapping dataset names to their curriculum configurations
curricula: Dict[str, CurriculumAttributeConfig]
def validate(self):
"""Validate the configuration"""
if not self.curricula:
raise ValueError("Must specify at least one curriculum")
for dataset_name, attr_config in self.curricula.items():
if not isinstance(attr_config, CurriculumAttributeConfig):
raise ValueError(f"Invalid attribute config for dataset {dataset_name}")
attr_config.validate()
@classmethod
def from_yaml_stream(cls, stream) -> "CurriculumExperimentConfig":
"""Load configuration from a YAML stream
Args:
stream: A file-like object containing YAML data
Returns:
CurriculumExperimentConfig instance
Raises:
ValueError: If YAML data has invalid format
"""
data = yaml.safe_load(stream)
if not isinstance(data, dict):
raise ValueError("YAML data must contain a dictionary")
if "curricula" not in data:
raise ValueError("YAML data must contain a 'curricula' key")
# Convert curriculum configs
curricula = {}
for dataset_name, config in data["curricula"].items():
if not isinstance(config, dict):
raise ValueError(f"Curriculum config for {dataset_name} must be a dictionary")
if "attribute_levels" not in config:
raise ValueError(f"Curriculum config for {dataset_name} must contain 'attribute_levels'")
weight = config.get("weight", 1.0)
curricula[dataset_name] = CurriculumAttributeConfig(
attribute_levels=config["attribute_levels"], weight=weight
)
return cls(curricula=curricula)
@classmethod
def from_yaml(cls, yaml_path: str) -> "CurriculumExperimentConfig":
"""Load configuration from YAML file
Args:
yaml_path: Path to YAML configuration file
Returns:
CurriculumExperimentConfig instance
Raises:
ValueError: If YAML file has invalid format
"""
with open(yaml_path, "r") as f:
return cls.from_yaml_stream(f)

View File

@@ -1,36 +1,93 @@
"""Experiment class combining dataset, scoreboard and curriculum."""
from dataclasses import dataclass
from typing import Optional
from typing import Any, Optional
from ..composite import CompositeConfig, CompositeDataset
from ..composite import CompositeConfig, CompositeDataset, DatasetSpec
from ..factory import create_curriculum
from ..version_manager import DatasetVersionManager
from .coach import ScoreBoard
from .curriculum_config import CurriculumExperimentConfig
@dataclass
class Experiment:
"""
An experiment combines a dataset with scoring and curriculum management.
def __init__(self, name: str, composite: CompositeDataset):
self.name = name
self.composite = composite
self.score_board = ScoreBoard()
Attributes:
name: Unique identifier for the experiment
dataset: The composite dataset for generating examples
score_board: Tracks performance metrics
config: The configuration used to create the dataset
version_manager: Manages dataset versions for scoring
"""
def get_dataset_entry(self, index: int) -> dict:
return self.composite[index]
name: str
dataset: CompositeDataset
score_board: ScoreBoard
config: CompositeConfig
version_manager: DatasetVersionManager
def score_answer_with_id(
self, answer: Optional[str], entry_id: str, conversation: Optional[list[dict]] = None
) -> float:
dataset, index, dataset_name = self.composite.resolve_entry_id(entry_id)
entry = dataset[index]
score = dataset.score_answer(answer, entry)
metadata = entry["metadata"]
self.score_board.add_score(score, metadata, conversation)
return score
@classmethod
def create(cls, name: str, config: CompositeConfig) -> "Experiment":
"""Create a new experiment from a configuration."""
version_manager = DatasetVersionManager()
dataset = CompositeDataset(config, version_manager=version_manager)
score_board = ScoreBoard()
return cls(name=name, dataset=dataset, score_board=score_board, config=config, version_manager=version_manager)
return cls(name=name, composite=dataset)
class CurriculumExperiment(Experiment):
def __init__(self, name: str, config: CurriculumExperimentConfig, size: int, seed: Optional[int] = None):
"""Initialize curriculum experiment with configured datasets and their curricula.
Args:
name: Name of the experiment
config: Configuration specifying datasets and their attribute levels
size: Number of examples to generate
seed: Random seed for reproducibility
"""
# Initialize curricula and build dataset specs
self.curricula = {}
dataset_specs = []
# Process each dataset in the curriculum config
for dataset_name, attr_config in config.curricula.items():
# Create and store curriculum
curriculum = create_curriculum(dataset_name)
self.curricula[dataset_name] = curriculum
# Handle special "*" attribute that sets all levels
if "*" in attr_config.attribute_levels:
level = attr_config.attribute_levels["*"]
for attr_name in curriculum.attributes:
curriculum.set_attr_level(attr_name, level)
# Set individual attribute levels (overriding "*" if specified)
for attr_name, level in attr_config.attribute_levels.items():
if attr_name != "*":
curriculum.set_attr_level(attr_name, level)
# Generate dataset config from curriculum
dataset_config = curriculum.generate_configuration()
# Create dataset spec
spec = DatasetSpec(name=dataset_name, weight=attr_config.weight, config=dataset_config.__dict__)
dataset_specs.append(spec)
# Create composite config with all datasets
composite_config = CompositeConfig(size=size, seed=seed, datasets=dataset_specs)
# Create composite dataset
version_manager = DatasetVersionManager()
composite = CompositeDataset(config=composite_config, version_manager=version_manager)
# Initialize base experiment
super().__init__(name=name, composite=composite)
# Store curriculum config
self.curriculum_config = config
def update_difficulty(self):
"""Update difficulty levels based on performance metrics"""
# TODO: Implement difficulty adjustment logic
pass
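A short usage sketch of the refactored Experiment surface (assuming an already constructed experiment, e.g. from Experiment.create() or CurriculumExperiment above): entries expose their entry_id through metadata, and scoring is routed through the experiment, which resolves the id back to the source dataset and records the result on its ScoreBoard.

entry = experiment.get_dataset_entry(0)
entry_id = entry["metadata"]["entry_id"]  # format "version_id.index" per the composite docstring
score = experiment.score_answer_with_id("42", entry_id=entry_id)  # "42" is a dummy answer
assert 0.0 <= score <= 1.0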

View File

@@ -48,10 +48,16 @@ class CompositeConfig:
ds.validate()
@classmethod
def from_yaml(cls, yaml_path: str) -> "CompositeConfig":
"""Load configuration from YAML file"""
with open(yaml_path, "r") as f:
data = yaml.safe_load(f)
def from_yaml_stream(cls, stream) -> "CompositeConfig":
"""Load configuration from a YAML stream
Args:
stream: A file-like object containing YAML data
Returns:
CompositeConfig instance
"""
data = yaml.safe_load(stream)
# Convert dataset specs to DatasetSpec objects
if "datasets" in data:
@@ -59,6 +65,19 @@ class CompositeConfig:
return cls(**data)
@classmethod
def from_yaml(cls, yaml_path: str) -> "CompositeConfig":
"""Load configuration from YAML file
Args:
yaml_path: Path to YAML configuration file
Returns:
CompositeConfig instance
"""
with open(yaml_path, "r") as f:
return cls.from_yaml_stream(f)
class CompositeDataset(ProceduralDataset):
"""A dataset that combines multiple datasets with weighted sampling"""
@@ -246,20 +265,7 @@ class CompositeDataset(ProceduralDataset):
self.dataset_names.pop(idx)
self.weights.pop(idx)
def score_answer_with_id(self, answer: Optional[str], entry_id: str) -> float:
"""Score an answer using an entry_id to lookup the original entry
Args:
answer: The answer to score
entry_id: String in format "version_id.index"
Returns:
Score between 0 and 1
Raises:
ValueError: If entry_id format is invalid
KeyError: If version not found in version manager
"""
def resolve_entry_id(self, entry_id: str) -> tuple[ProceduralDataset, int, str]:
if self.version_manager is None:
raise RuntimeError("Version manager required for scoring with entry_id")
@@ -274,6 +280,24 @@ class CompositeDataset(ProceduralDataset):
raise KeyError(f"Version {version_id} not found in version manager")
dataset_name, dataset = dataset_info
return dataset, index, dataset_name
def score_answer_with_id(self, answer: Optional[str], entry_id: str) -> float:
"""Score an answer using an entry_id to lookup the original entry
Args:
answer: The answer to score
entry_id: String in format "version_id.index"
Returns:
Score between 0 and 1
Raises:
ValueError: If entry_id format is invalid
KeyError: If version not found in version manager
"""
dataset, index, _ = self.resolve_entry_id(entry_id)
# Get entry from dataset
entry = dataset[index]

View File

@@ -1,24 +1,34 @@
from dataclasses import is_dataclass
from typing import Type, TypeVar
from typing import Optional, Type, TypeVar
from reasoning_gym.coaching.base_curriculum import BaseCurriculum, ConfigT
from .dataset import ProceduralDataset
# Type variables for generic type hints
ConfigT = TypeVar("ConfigT")
DatasetT = TypeVar("DatasetT", bound=ProceduralDataset)
CurriculumT = TypeVar("CurriculumT", bound=BaseCurriculum)
# Global registry of datasets
DATASETS: dict[str, tuple[Type[ProceduralDataset], Type]] = {}
CURRICULA: dict[str, BaseCurriculum] = {}
def register_dataset(name: str, dataset_cls: Type[DatasetT], config_cls: Type[ConfigT]) -> None:
def register_dataset(
name: str,
dataset_cls: Type[DatasetT],
config_cls: Type[ConfigT],
curriculum_cls: Optional[CurriculumT] = None,
) -> None:
"""
Register a dataset class with its configuration class.
Register a dataset class with its configuration class and optional curriculum.
Args:
name: Unique identifier for the dataset
dataset_cls: Class derived from ProceduralDataset
config_cls: Configuration dataclass for the dataset
curriculum_cls: Optional curriculum class for progressive difficulty
Raises:
ValueError: If name is already registered or invalid types provided
@@ -34,6 +44,9 @@ def register_dataset(name: str, dataset_cls: Type[DatasetT], config_cls: Type[Co
DATASETS[name] = (dataset_cls, config_cls)
if curriculum_cls:
CURRICULA[name] = curriculum_cls
def create_dataset(name: str, **kwargs) -> ProceduralDataset:
"""
@@ -56,3 +69,28 @@ def create_dataset(name: str, **kwargs) -> ProceduralDataset:
config = config_cls(**kwargs)
return dataset_cls(config=config)
def create_curriculum(name: str) -> BaseCurriculum:
"""
Create a curriculum instance for the named dataset.
Args:
name: Registered dataset name
Returns:
Configured curriculum instance
Raises:
ValueError: If dataset not found or has no curriculum registered
"""
if name not in CURRICULA:
raise ValueError(f"No curriculum registered for dataset '{name}'")
curriculum_cls = CURRICULA[name]
return curriculum_cls()
def has_curriculum(name: str) -> bool:
return name in CURRICULA
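Putting the factory additions together, a minimal sketch of the lookup side (registration itself happens at import time in each dataset module, e.g. register_dataset("chain_sum", ChainSumDataset, ChainSumConfig, ChainSumCurriculum) above; the reasoning_gym.factory import path is assumed):

import reasoning_gym  # assumed to import the dataset modules, which register themselves
from reasoning_gym.factory import create_curriculum, create_dataset, has_curriculum

if has_curriculum("chain_sum"):
    curriculum = create_curriculum("chain_sum")  # instantiates ChainSumCurriculum
    curriculum.set_attr_level("num_terms", 3)
    curriculum.set_attr_level("num_digits", 2)
    # create_dataset builds the config class from kwargs, so a generated config can be splatted in
    dataset = create_dataset("chain_sum", **vars(curriculum.generate_configuration()))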

View File

@@ -179,4 +179,4 @@ class NQueensCurriculum(BaseCurriculum):
)
register_dataset("n_queens", NQueensDataset, NQueensConfig)
register_dataset("n_queens", NQueensDataset, NQueensConfig, NQueensCurriculum)

View File

@@ -174,4 +174,4 @@ class CourseScheduleCurriculum(BaseCurriculum):
)
register_dataset("course_schedule", CourseScheduleDataset, CourseScheduleConfig)
register_dataset("course_schedule", CourseScheduleDataset, CourseScheduleConfig, CourseScheduleCurriculum)

View File

@@ -191,4 +191,4 @@ class LargestIslandCurriculum(BaseCurriculum):
)
register_dataset("largest_island", LargestIslandDataset, LargestIslandConfig)
register_dataset("largest_island", LargestIslandDataset, LargestIslandConfig, LargestIslandCurriculum)

View File

@@ -193,4 +193,4 @@ class ShortestPathCurriculum(BaseCurriculum):
)
register_dataset("shortest_path", ShortestPathDataset, ShortestPathConfig)
register_dataset("shortest_path", ShortestPathDataset, ShortestPathConfig, ShortestPathCurriculum)

View File

@@ -0,0 +1,131 @@
import io
import pytest
import yaml
from reasoning_gym.coaching.curriculum_config import CurriculumAttributeConfig, CurriculumExperimentConfig
from reasoning_gym.coaching.experiment import CurriculumExperiment
def test_curriculum_experiment_initialization():
"""Test basic initialization of CurriculumExperiment"""
# Create config with leg_counting dataset
config = CurriculumExperimentConfig(
curricula={"leg_counting": CurriculumAttributeConfig(attribute_levels={"num_animals": 2}, weight=1.0)}
)
# Create experiment
experiment = CurriculumExperiment(name="test_experiment", config=config, size=10, seed=42)
# Check experiment was created correctly
assert experiment.name == "test_experiment"
assert "leg_counting" in experiment.curricula
assert "leg_counting" in experiment.composite.datasets
# Check curriculum was configured correctly
curriculum = experiment.curricula["leg_counting"]
assert curriculum.get_attr_level("num_animals") == 2
# Check dataset was created with correct config
dataset = experiment.composite.datasets["leg_counting"]
assert dataset.config.min_animals == 1
assert dataset.config.max_animals == 3
# Check we can get entries from the dataset
entry = experiment.get_dataset_entry(0)
assert "question" in entry
assert "answer" in entry
assert "metadata" in entry
assert entry["metadata"]["source_dataset"] == "leg_counting"
def test_curriculum_experiment_wildcard_level():
"""Test using "*" to set all attribute levels"""
config = CurriculumExperimentConfig(
curricula={"leg_counting": CurriculumAttributeConfig(attribute_levels={"*": 3}, weight=1.0)}
)
experiment = CurriculumExperiment(name="test_experiment", config=config, size=10, seed=42)
# Check all attributes were set to level 3
curriculum = experiment.curricula["leg_counting"]
for attr_name in curriculum.attributes:
assert curriculum.get_attr_level(attr_name) == 3
def test_curriculum_experiment_mixed_levels():
"""Test mixing "*" with specific attribute levels"""
config = CurriculumExperimentConfig(
curricula={
"leg_counting": CurriculumAttributeConfig(
attribute_levels={"*": 2, "num_animals": 4}, weight=1.0 # Should override the "*" level
)
}
)
experiment = CurriculumExperiment(name="test_experiment", config=config, size=10, seed=42)
curriculum = experiment.curricula["leg_counting"]
assert curriculum.get_attr_level("num_animals") == 4 # Specific override
def test_curriculum_experiment_from_yaml():
"""Test loading curriculum experiment config from YAML using a string stream"""
# Create a YAML string
yaml_content = """
curricula:
  leg_counting:
    attribute_levels:
      "*": 2
      num_animals: 4
    weight: 1.5
  chain_sum:
    attribute_levels:
      num_terms: 1
      num_digits: 2
    weight: 0.8
"""
# Use StringIO to create a file-like object from the string
from io import StringIO
yaml_stream = StringIO(yaml_content)
# Load config from YAML stream
config = CurriculumExperimentConfig.from_yaml_stream(yaml_stream)
# Verify config was loaded correctly
assert len(config.curricula) == 2
assert "leg_counting" in config.curricula
assert "chain_sum" in config.curricula
# Check leg_counting curriculum
leg_counting = config.curricula["leg_counting"]
assert leg_counting.attribute_levels["*"] == 2
assert leg_counting.attribute_levels["num_animals"] == 4
assert leg_counting.weight == 1.5
# Check chain_sum curriculum
chain_sum = config.curricula["chain_sum"]
assert chain_sum.attribute_levels["num_terms"] == 1
assert chain_sum.attribute_levels["num_digits"] == 2
assert chain_sum.weight == 0.8
# Create experiment from the loaded config
experiment = CurriculumExperiment(name="yaml_test", config=config, size=10, seed=42)
# Verify experiment was created correctly
assert "leg_counting" in experiment.curricula
assert "chain_sum" in experiment.curricula
# Check attribute levels were applied
leg_curriculum = experiment.curricula["leg_counting"]
assert leg_curriculum.get_attr_level("num_animals") == 4
chain_sum_curriculum = experiment.curricula["chain_sum"]
assert chain_sum_curriculum.get_attr_level("num_terms") == 1
assert chain_sum_curriculum.get_attr_level("num_digits") == 2

View File

@@ -89,12 +89,12 @@ def create_app(config: ServerConfig) -> FastAPI:
try:
entries = []
for i in range(base_index, base_index + batch_size):
entry = experiment.dataset[i]
entry = experiment.get_dataset_entry(i)
# Create BatchEntry with minimal required data
batch_entry = BatchEntry(
question=entry["question"],
entry_id=f"{entry['metadata']['version_id']}.{i}",
entry_id=f"{entry['metadata']['entry_id']}",
metadata=entry["metadata"],
)
entries.append(batch_entry)
@@ -115,7 +115,7 @@ def create_app(config: ServerConfig) -> FastAPI:
scores = []
entry_ids = []
for item in request.answers:
score = experiment.dataset.score_answer_with_id(item.answer, item.entry_id)
score = experiment.score_answer_with_id(item.answer, item.entry_id)
scores.append(score)
entry_ids.append(item.entry_id)
@@ -134,7 +134,7 @@ def create_app(config: ServerConfig) -> FastAPI:
# Convert internal config to API response format
datasets = {}
for ds_spec in experiment.config.datasets:
dataset = experiment.dataset.datasets[ds_spec.name]
dataset = experiment.composite.datasets[ds_spec.name]
datasets[ds_spec.name] = {
"weight": ds_spec.weight,
"config": vars(dataset.config), # Get current config from dataset instance
@@ -152,7 +152,7 @@ def create_app(config: ServerConfig) -> FastAPI:
raise HTTPException(status_code=404, detail=f"Experiment '{name}' not found")
try:
experiment.dataset.update_dataset_config(dataset_name, config_update.config)
experiment.composite.update_dataset_config(dataset_name, config_update.config)
return {"status": "updated"}
except KeyError:
raise HTTPException(status_code=404, detail=f"Dataset '{dataset_name}' not found in experiment")

View File

@@ -35,7 +35,7 @@ def test_experiment_management():
exp = registry.get_experiment("test_exp")
assert exp is not None
assert exp.name == "test_exp"
assert isinstance(exp.dataset, CompositeDataset)
assert isinstance(exp.composite, CompositeDataset)
assert exp.config == config
# Test removal