diff --git a/examples/veRL/config/ppo_trainer.yaml b/examples/veRL/basic_curriculum/config/ppo_trainer.yaml similarity index 100% rename from examples/veRL/config/ppo_trainer.yaml rename to examples/veRL/basic_curriculum/config/ppo_trainer.yaml diff --git a/examples/veRL/basic_curriculum/launch.sh b/examples/veRL/basic_curriculum/launch.sh new file mode 100755 index 00000000..6681a2c8 --- /dev/null +++ b/examples/veRL/basic_curriculum/launch.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +export N_GPUS=4 +export BASE_MODEL=meta-llama/Llama-3.2-3B-Instruct # meta-llama/Llama-3.2-1B-Instruct +export ROLLOUT_TP_SIZE=2 +export EXPERIMENT_NAME=basic_curriculum +export VLLM_ATTENTION_BACKEND=XFORMERS + +bash ./train_grpo.sh diff --git a/examples/veRL/basic_curriculum/ppo_curriculum.py b/examples/veRL/basic_curriculum/ppo_curriculum.py new file mode 100644 index 00000000..4c583e13 --- /dev/null +++ b/examples/veRL/basic_curriculum/ppo_curriculum.py @@ -0,0 +1,298 @@ +# This example is an adapted version of Bytedance's code: +# https://github.com/volcengine/verl/blob/a65c9157bc0b85b64cd753de19f94e80a11bd871/verl/trainer/main_ppo.py +from io import StringIO +from typing import Optional + +import hydra +import ray +import torch +import verl.utils.torch_functional as verl_F +from omegaconf import OmegaConf, open_dict +from torch.utils.data import DataLoader, Dataset +from transformers import PreTrainedTokenizer +from verl import DataProto +from verl.trainer.ppo.ray_trainer import RayPPOTrainer +from verl.utils.dataset.rl_dataset import collate_fn +from verl.utils.model import compute_position_id_with_mask + +import reasoning_gym +import reasoning_gym.utils +from reasoning_gym.coaching.curriculum_config import CurriculumExperimentConfig +from reasoning_gym.coaching.experiment import CurriculumExperiment +from reasoning_gym.utils import extract_answer + +curriculum_config_yaml = """ + curricula: + leg_counting: + attribute_levels: + num_animals: 2 + products: + attribute_levels: + 
num_terms: 4 + num_digits: 4 + chain_sum: + attribute_levels: + num_terms: 4 + num_digits: 4 + weight: 1.0 + """ + + +class ReasoningGymDataset(Dataset): + def __init__( + self, + tokenizer: PreTrainedTokenizer, + experiment_name: str, + seed: int, + size: int, + developer_prompt: Optional[str] = None, + developer_role: str = "system", + max_prompt_length: int = 2048, + truncation: str = "error", # one of: 'left', 'right', 'error' + return_raw_chat: bool = False, + ): + self.tokenizer = tokenizer + curriculum_config = CurriculumExperimentConfig.from_yaml_stream(StringIO(curriculum_config_yaml)) + self.experiment = CurriculumExperiment(experiment_name, curriculum_config, size=size, seed=seed) + self.developer_prompt = developer_prompt + self.developer_role = developer_role + self.max_prompt_length = max_prompt_length + self.truncation = truncation + self.return_raw_chat = return_raw_chat + + def __len__(self) -> int: + return len(self.experiment.composite) + + def __getitem__(self, index: int): + row_dict = self.experiment.get_dataset_entry(index).copy() + q = row_dict["question"] + + chat = [] + if self.developer_prompt is not None: + chat.append({"role": self.developer_role, "content": self.developer_prompt}) + chat.append({"role": "user", "content": q}) + + prompt = self.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True) + + input_ids, attention_mask = verl_F.tokenize_and_postprocess_data( + prompt=prompt, + tokenizer=self.tokenizer, + max_length=self.max_prompt_length, + pad_token_id=self.tokenizer.pad_token_id, + left_pad=True, + truncation=self.truncation, + ) + + position_ids = compute_position_id_with_mask(attention_mask) + + row_dict["data_source"] = "reasoning_gym/" + self.experiment.name + row_dict["input_ids"] = input_ids[0] + row_dict["attention_mask"] = attention_mask[0] + row_dict["position_ids"] = position_ids[0] + + # also return the un-templated chat messages (already a plain list of role/content dicts) + if self.return_raw_chat: + row_dict["raw_prompt"] = chat + + 
return row_dict + + +class RayPPOTrainerCustom(RayPPOTrainer): + def __init__( + self, + config, + tokenizer, + role_worker_mapping: dict, + resource_pool_manager, + ray_worker_group_cls, + experiment_name: str = "basic_curriculum", + dataset_size: int = 10000, + ): + self.dataset_size = dataset_size + + developer_prompt = reasoning_gym.utils.SYSTEM_PROMPTS["DeepSeekZero"] + self.train_dataset = ReasoningGymDataset( + tokenizer=tokenizer, + experiment_name=experiment_name, + seed=1, + size=self.dataset_size, + developer_prompt=developer_prompt, + ) + + self.val_dataset = ReasoningGymDataset( + tokenizer=tokenizer, + experiment_name=experiment_name, + seed=2, + size=self.dataset_size, + developer_prompt=developer_prompt, + ) + + train_reward_fn = lambda data: self._score_output(data, num_examine=0) + val_reward_fn = lambda data: self._score_output(data, num_examine=1) + + super().__init__( + config, + tokenizer, + role_worker_mapping, + resource_pool_manager, + ray_worker_group_cls, + train_reward_fn, + val_reward_fn, + ) + + def _score_output(self, data: DataProto, num_examine: int = 0) -> torch.Tensor: + reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32) + + num_printed = 0 + for i in range(len(data)): + data_item = data[i] # DataProtoItem + + prompt_ids = data_item.batch["prompts"] # tokenized prompts + prompt_length = prompt_ids.shape[-1] + + valid_prompt_length = data_item.batch["attention_mask"][:prompt_length].sum() + valid_prompt_ids = prompt_ids[-valid_prompt_length:] + + response_ids = data_item.batch["responses"] + valid_response_length = data_item.batch["attention_mask"][prompt_length:].sum() + valid_response_ids = response_ids[:valid_response_length] + + # decode + sequences = torch.cat((valid_prompt_ids, valid_response_ids)) + sequences_str = self.tokenizer.decode(sequences) + + entry_id = data_item.non_tensor_batch["metadata"]["entry_id"] + + score = self._compute_score( + solution_str=sequences_str, + entry_id=entry_id, + 
) + reward_tensor[i, valid_response_length - 1] = score + + if num_printed < num_examine: + print(f"reward={score}, seq={sequences_str}") + num_printed += 1 + + return reward_tensor + + def _compute_score(self, solution_str: str, entry_id: str) -> float: + found_answer = extract_answer(solution_str, tag_name="answer") + reward = self.train_dataset.experiment.score_answer_with_id(found_answer, entry_id=entry_id) + print(f"entry_id: {entry_id}; found answer={found_answer}; reward: {reward};") + return reward + + def _create_dataloader(self): + self.train_dataloader = DataLoader( + dataset=self.train_dataset, + batch_size=self.config.data.train_batch_size, + shuffle=True, + drop_last=True, + collate_fn=collate_fn, + ) + + self.val_dataloader = DataLoader( + dataset=self.val_dataset, + batch_size=len(self.val_dataset), + shuffle=True, + drop_last=True, + collate_fn=collate_fn, + ) + + assert len(self.train_dataloader) >= 1 + assert len(self.val_dataloader) >= 1 + + print(f"Size of train dataloader: {len(self.train_dataloader)}") + print(f"Size of val dataloader: {len(self.val_dataloader)}") + + # inject total_training_steps to actor/critic optim_config. This is hacky. 
+ total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs + + if self.config.trainer.total_training_steps is not None: + total_training_steps = self.config.trainer.total_training_steps + + self.total_training_steps = total_training_steps + print(f"Total training steps: {self.total_training_steps}") + + OmegaConf.set_struct(self.config, True) + with open_dict(self.config): + self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps + self.config.critic.optim.total_training_steps = total_training_steps + + +@ray.remote +def main_task(config): + # print initial config + from pprint import pprint + + from verl.utils import hf_tokenizer + from verl.utils.fs import copy_local_path_from_hdfs + + pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values + OmegaConf.resolve(config) + + # download the checkpoint from hdfs + local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path) + + # instantiate tokenizer + tokenizer = hf_tokenizer(local_path) + + # define worker classes + if config.actor_rollout_ref.actor.strategy == "fsdp": + assert config.actor_rollout_ref.actor.strategy == config.critic.strategy + from verl.single_controller.ray import RayWorkerGroup + from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker + + ray_worker_group_cls = RayWorkerGroup + + elif config.actor_rollout_ref.actor.strategy == "megatron": + assert config.actor_rollout_ref.actor.strategy == config.critic.strategy + from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup + from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker + + ray_worker_group_cls = NVMegatronRayWorkerGroup + + else: + raise NotImplementedError + + from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role + + role_worker_mapping = { + Role.ActorRollout: ray.remote(ActorRolloutRefWorker), + Role.Critic: ray.remote(CriticWorker), + Role.RefPolicy: 
ray.remote(ActorRolloutRefWorker), + } + + global_pool_id = "global_pool" + resource_pool_spec = { + global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, + } + mapping = { + Role.ActorRollout: global_pool_id, + Role.Critic: global_pool_id, + Role.RefPolicy: global_pool_id, + } + + resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) + + trainer = RayPPOTrainerCustom( + config=config, + tokenizer=tokenizer, + role_worker_mapping=role_worker_mapping, + resource_pool_manager=resource_pool_manager, + ray_worker_group_cls=ray_worker_group_cls, + ) + trainer.init_workers() + trainer.fit() + + +@hydra.main(config_path="config", config_name="ppo_trainer", version_base=None) +def main(config): + if not ray.is_initialized(): + # this is for local ray cluster + ray.init(runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}}) + + ray.get(main_task.remote(config)) + + +if __name__ == "__main__": + main() diff --git a/examples/veRL/basic_curriculum/train_grpo.sh b/examples/veRL/basic_curriculum/train_grpo.sh new file mode 100644 index 00000000..6bfa35be --- /dev/null +++ b/examples/veRL/basic_curriculum/train_grpo.sh @@ -0,0 +1,39 @@ +#!/bin/bash +set -x + +python3 -u ppo_curriculum.py \ + algorithm.adv_estimator=grpo \ + data.train_files=$DATA_DIR/train.parquet \ + data.val_files=$DATA_DIR/test.parquet \ + data.train_batch_size=512 \ + data.val_batch_size=512 \ + data.max_prompt_length=512 \ + data.max_response_length=1024 \ + actor_rollout_ref.model.path=$BASE_MODEL \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=80 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + 
actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=160 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=$ROLLOUT_TP_SIZE \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.n=8 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=160 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.kl_ctrl.kl_coef=0.001 \ + trainer.critic_warmup=0 \ + trainer.logger=['wandb'] \ + trainer.project_name='verl_chain_sum_grpo' \ + trainer.experiment_name=$EXPERIMENT_NAME \ + trainer.n_gpus_per_node=$N_GPUS \ + trainer.nnodes=1 \ + trainer.save_freq=100 \ + trainer.test_freq=100 \ + trainer.total_epochs=15 "$@" 2>&1 | tee verl_output.log diff --git a/examples/veRL/config/grpo_trainer.yaml b/examples/veRL/chain_sum/config/grpo_trainer.yaml similarity index 100% rename from examples/veRL/config/grpo_trainer.yaml rename to examples/veRL/chain_sum/config/grpo_trainer.yaml diff --git a/examples/veRL/chain_sum/config/ppo_trainer.yaml b/examples/veRL/chain_sum/config/ppo_trainer.yaml new file mode 100644 index 00000000..a3d167ea --- /dev/null +++ b/examples/veRL/chain_sum/config/ppo_trainer.yaml @@ -0,0 +1,171 @@ +data: + tokenizer: null + train_files: ~/data/rlhf/gsm8k/train.parquet + val_files: ~/data/rlhf/gsm8k/test.parquet + prompt_key: prompt + max_prompt_length: 512 + max_response_length: 512 + train_batch_size: 1024 + val_batch_size: 1312 + return_raw_input_ids: False # This should be set to true when the tokenizer between policy and rm differs + return_raw_chat: False + +actor_rollout_ref: + hybrid_engine: True + model: + path: ~/models/deepseek-llm-7b-chat + external_lib: null + override_config: { } + enable_gradient_checkpointing: True + use_remove_padding: False + actor: + strategy: fsdp # This is for backward-compatibility + 
ppo_mini_batch_size: 256 + ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu + ppo_micro_batch_size_per_gpu: null + use_dynamic_bsz: False + ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length} + grad_clip: 1.0 + clip_ratio: 0.2 + entropy_coeff: 0.001 + use_kl_loss: False # True for GRPO + kl_loss_coef: 0.001 # for grpo + kl_loss_type: low_var_kl # for grpo + ppo_epochs: 1 + shuffle: False + ulysses_sequence_parallel_size: 1 # sp size + optim: + lr: 1e-6 + lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime + min_lr_ratio: null # only useful for warmup with cosine + warmup_style: constant # select from constant/cosine + total_training_steps: -1 # must be override by program + fsdp_config: + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + param_offload: False + optimizer_offload: False + fsdp_size: -1 + ref: + fsdp_config: + param_offload: False + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu + log_prob_micro_batch_size_per_gpu: null + log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} + ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size + rollout: + name: vllm + temperature: 1.0 + top_k: -1 # 0 for hf rollout, -1 for vllm rollout + top_p: 1 + prompt_length: ${data.max_prompt_length} # not use for opensource + response_length: ${data.max_response_length} + # for vllm rollout + dtype: bfloat16 # should align with FSDP + gpu_memory_utilization: 0.5 + ignore_eos: False + enforce_eager: True + free_cache_engine: True + load_format: dummy_dtensor + tensor_model_parallel_size: 2 + max_num_batched_tokens: 8192 + max_num_seqs: 1024 + log_prob_micro_batch_size: null # will be 
deprecated, use log_prob_micro_batch_size_per_gpu + log_prob_micro_batch_size_per_gpu: null + log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} + disable_log_stats: True + enable_chunked_prefill: True # could get higher throughput + # for hf rollout + do_sample: True + # number of responses (i.e. num sample times) + n: 1 # > 1 for grpo + +critic: + strategy: fsdp + optim: + lr: 1e-5 + lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime + min_lr_ratio: null # only useful for warmup with cosine + warmup_style: constant # select from constant/cosine + total_training_steps: -1 # must be override by program + model: + path: ~/models/deepseek-llm-7b-chat + tokenizer_path: ${actor_rollout_ref.model.path} + override_config: { } + external_lib: ${actor_rollout_ref.model.external_lib} + enable_gradient_checkpointing: True + use_remove_padding: False + fsdp_config: + param_offload: False + optimizer_offload: False + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + fsdp_size: -1 + ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size} + ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu + ppo_micro_batch_size_per_gpu: null + forward_micro_batch_size: ${critic.ppo_micro_batch_size} + forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu} + use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + ppo_max_token_len_per_gpu: 32768 # (${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}) * 2 + forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu} + ulysses_sequence_parallel_size: 1 # sp size + ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs} + shuffle: ${actor_rollout_ref.actor.shuffle} + grad_clip: 1.0 + cliprange_value: 0.5 + +reward_model: + enable: False + strategy: fsdp + model: + input_tokenizer: ${actor_rollout_ref.model.path} # set this to 
null if the chat template is identical + path: ~/models/FsfairX-LLaMA3-RM-v0.1 + external_lib: ${actor_rollout_ref.model.external_lib} + use_remove_padding: False + fsdp_config: + min_num_params: 0 + param_offload: False + fsdp_size: -1 + micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu + micro_batch_size_per_gpu: null # set a number + max_length: null + ulysses_sequence_parallel_size: 1 # sp size + use_dynamic_bsz: ${critic.use_dynamic_bsz} + forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu} + +algorithm: + gamma: 1.0 + lam: 1.0 + adv_estimator: gae + kl_penalty: kl # how to estimate kl divergence + kl_ctrl: + type: fixed + kl_coef: 0.001 + +trainer: + total_epochs: 30 + total_training_steps: null + project_name: verl_examples + experiment_name: gsm8k + logger: [ 'console', 'wandb' ] + val_generations_to_log_to_wandb: 0 + nnodes: 1 + n_gpus_per_node: 8 + save_freq: -1 + # auto: find the last ckpt to resume. If can't find, start from scratch + resume_mode: auto # or disable or resume_path if resume_from_path is set + resume_from_path: False + test_freq: -1 + critic_warmup: 0 + default_hdfs_dir: null + remove_previous_ckpt_in_save: False + del_local_ckpt_after_load: False + default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} diff --git a/examples/veRL/launch_on_2gpu_server.sh b/examples/veRL/chain_sum/launch_on_2gpu_server.sh similarity index 100% rename from examples/veRL/launch_on_2gpu_server.sh rename to examples/veRL/chain_sum/launch_on_2gpu_server.sh diff --git a/examples/veRL/launch_on_4gpu.sh b/examples/veRL/chain_sum/launch_on_4gpu.sh similarity index 100% rename from examples/veRL/launch_on_4gpu.sh rename to examples/veRL/chain_sum/launch_on_4gpu.sh diff --git a/examples/veRL/main_ppo_custom_reward.py b/examples/veRL/chain_sum/main_ppo_custom_reward.py similarity index 100% rename from examples/veRL/main_ppo_custom_reward.py rename to examples/veRL/chain_sum/main_ppo_custom_reward.py diff --git 
a/examples/veRL/main_ppo_custom_reward_server.py b/examples/veRL/chain_sum/main_ppo_custom_reward_server.py similarity index 100% rename from examples/veRL/main_ppo_custom_reward_server.py rename to examples/veRL/chain_sum/main_ppo_custom_reward_server.py diff --git a/examples/veRL/train_grpo.sh b/examples/veRL/chain_sum/train_grpo.sh similarity index 100% rename from examples/veRL/train_grpo.sh rename to examples/veRL/chain_sum/train_grpo.sh diff --git a/examples/veRL/train_grpo_server.sh b/examples/veRL/chain_sum/train_grpo_server.sh similarity index 100% rename from examples/veRL/train_grpo_server.sh rename to examples/veRL/chain_sum/train_grpo_server.sh diff --git a/examples/veRL/train_ppo.sh b/examples/veRL/chain_sum/train_ppo.sh similarity index 100% rename from examples/veRL/train_ppo.sh rename to examples/veRL/chain_sum/train_ppo.sh diff --git a/reasoning_gym/arithmetic/chain_sum.py b/reasoning_gym/arithmetic/chain_sum.py index 351b258f..1daffbf9 100644 --- a/reasoning_gym/arithmetic/chain_sum.py +++ b/reasoning_gym/arithmetic/chain_sum.py @@ -123,7 +123,7 @@ class ChainSumCurriculum(BaseCurriculum): self._define_attributes( RangeAttributeDefinition( name="num_terms", - levels=[2, 3, 4, 5], + levels=list(range(2, 13)), default_level=0, # Start with 2 terms description="Maximum number of terms in the expression", attr_type=AttributeType.APPEND, @@ -133,7 +133,7 @@ class ChainSumCurriculum(BaseCurriculum): ), RangeAttributeDefinition( name="num_digits", - levels=[1, 2, 4, 10], + levels=list(range(1, 11)), default_level=0, # Start with 1-digit numbers description="Number of digits in each operand", attr_type=AttributeType.APPEND, @@ -145,4 +145,4 @@ class ChainSumCurriculum(BaseCurriculum): # Register the dataset -register_dataset("chain_sum", ChainSumDataset, ChainSumConfig) +register_dataset("chain_sum", ChainSumDataset, ChainSumConfig, ChainSumCurriculum) diff --git a/reasoning_gym/arithmetic/count_bits.py b/reasoning_gym/arithmetic/count_bits.py index 
059df677..cd55bdc7 100644 --- a/reasoning_gym/arithmetic/count_bits.py +++ b/reasoning_gym/arithmetic/count_bits.py @@ -65,4 +65,4 @@ class CountBitsCurriculum(BaseCurriculum): ) -register_dataset("count_bits", CountBitsDataset, CountBitsConfig) +register_dataset("count_bits", CountBitsDataset, CountBitsConfig, CountBitsCurriculum) diff --git a/reasoning_gym/arithmetic/leg_counting.py b/reasoning_gym/arithmetic/leg_counting.py index 3acc2f32..b68e133d 100644 --- a/reasoning_gym/arithmetic/leg_counting.py +++ b/reasoning_gym/arithmetic/leg_counting.py @@ -4,6 +4,9 @@ from dataclasses import dataclass from random import Random from typing import Optional +from reasoning_gym.coaching.attributes import AttributeType, RangeAttributeDefinition +from reasoning_gym.coaching.base_curriculum import BaseCurriculum + from ..factory import ProceduralDataset, register_dataset ANIMALS = { @@ -124,4 +127,23 @@ class LegCountingDataset(ProceduralDataset): } -register_dataset("leg_counting", LegCountingDataset, LegCountingConfig) +class LegCountingCurriculum(BaseCurriculum): + def __init__(self): + super().__init__(LegCountingCurriculum.__name__, LegCountingConfig) + + # Define attributes + self._define_attributes( + RangeAttributeDefinition( + name="num_animals", + levels=list(range(1, 20)), + default_level=0, # Start with 1 animal (levels[0]) + description="Number of animals in question", + attr_type=AttributeType.APPEND, + min_value=1, # Ensure at least 1 animal + lower_field_name="min_animals", + upper_field_name="max_animals", + ), + ) + + +register_dataset("leg_counting", LegCountingDataset, LegCountingConfig, LegCountingCurriculum) diff --git a/reasoning_gym/arithmetic/products.py b/reasoning_gym/arithmetic/products.py index dbe89cd0..c39c4377 100644 --- a/reasoning_gym/arithmetic/products.py +++ b/reasoning_gym/arithmetic/products.py @@ -115,7 +115,7 @@ class ProductsCurriculum(BaseCurriculum): self._define_attributes( RangeAttributeDefinition( name="num_terms", - levels=[2, 3, 4, 5], + 
levels=list(range(2, 13)), default_level=0, # Start with 2 terms description="Maximum number of terms in the expression", attr_type=AttributeType.APPEND, @@ -125,7 +125,7 @@ class ProductsCurriculum(BaseCurriculum): ), RangeAttributeDefinition( name="num_digits", - levels=[1, 2, 3, 4], + levels=list(range(1, 11)), default_level=0, # Start with 1-digit numbers description="Number of digits in each operand", attr_type=AttributeType.APPEND, @@ -137,4 +137,4 @@ class ProductsCurriculum(BaseCurriculum): # Register the dataset -register_dataset("products", ProductsDataset, ProductsConfig) +register_dataset("products", ProductsDataset, ProductsConfig, ProductsCurriculum) diff --git a/reasoning_gym/coaching/base_curriculum.py b/reasoning_gym/coaching/base_curriculum.py index 20684f75..22c031a1 100644 --- a/reasoning_gym/coaching/base_curriculum.py +++ b/reasoning_gym/coaching/base_curriculum.py @@ -1,8 +1,9 @@ -from typing import Any, Iterable, Optional +from typing import Any, Optional, TypeVar -from ..factory import ConfigT from .attributes import AttributeDefinition, RangeAttributeDefinition, ScalarAttributeDefinition +ConfigT = TypeVar("ConfigT") + class BaseCurriculum: def __init__(self, name: str, config_cls: ConfigT): @@ -21,7 +22,6 @@ class BaseCurriculum: elif isinstance(attr, ScalarAttributeDefinition): val = self.get_attr_value(attr.name) config_args[attr.field_name] = val - print(config_args) return self._config_cls(**config_args) @property diff --git a/reasoning_gym/coaching/curriculum_config.py b/reasoning_gym/coaching/curriculum_config.py new file mode 100644 index 00000000..7b431204 --- /dev/null +++ b/reasoning_gym/coaching/curriculum_config.py @@ -0,0 +1,91 @@ +from dataclasses import dataclass +from typing import Dict, Optional + +import yaml + + +@dataclass +class CurriculumAttributeConfig: + """Configuration for curriculum attribute levels""" + + # Dictionary mapping attribute names to levels + # Special key "*" means apply that level to all attributes 
+ attribute_levels: Dict[str, int] + # Weight for sampling this dataset + weight: float = 1.0 + + def validate(self): + """Validate the configuration""" + if not self.attribute_levels: + raise ValueError("Must specify at least one attribute level") + + +@dataclass +class CurriculumExperimentConfig: + """Configuration for curriculum experiments""" + + # Dictionary mapping dataset names to their curriculum configurations + curricula: Dict[str, CurriculumAttributeConfig] + + def validate(self): + """Validate the configuration""" + if not self.curricula: + raise ValueError("Must specify at least one curriculum") + + for dataset_name, attr_config in self.curricula.items(): + if not isinstance(attr_config, CurriculumAttributeConfig): + raise ValueError(f"Invalid attribute config for dataset {dataset_name}") + attr_config.validate() + + @classmethod + def from_yaml_stream(cls, stream) -> "CurriculumExperimentConfig": + """Load configuration from a YAML stream + + Args: + stream: A file-like object containing YAML data + + Returns: + CurriculumExperimentConfig instance + + Raises: + ValueError: If YAML data has invalid format + """ + data = yaml.safe_load(stream) + + if not isinstance(data, dict): + raise ValueError("YAML data must contain a dictionary") + + if "curricula" not in data: + raise ValueError("YAML data must contain a 'curricula' key") + + # Convert curriculum configs + curricula = {} + for dataset_name, config in data["curricula"].items(): + if not isinstance(config, dict): + raise ValueError(f"Curriculum config for {dataset_name} must be a dictionary") + + if "attribute_levels" not in config: + raise ValueError(f"Curriculum config for {dataset_name} must contain 'attribute_levels'") + + weight = config.get("weight", 1.0) + curricula[dataset_name] = CurriculumAttributeConfig( + attribute_levels=config["attribute_levels"], weight=weight + ) + + return cls(curricula=curricula) + + @classmethod + def from_yaml(cls, yaml_path: str) -> "CurriculumExperimentConfig": 
+ """Load configuration from YAML file + + Args: + yaml_path: Path to YAML configuration file + + Returns: + CurriculumExperimentConfig instance + + Raises: + ValueError: If YAML file has invalid format + """ + with open(yaml_path, "r") as f: + return cls.from_yaml_stream(f) diff --git a/reasoning_gym/coaching/experiment.py b/reasoning_gym/coaching/experiment.py index d3a9e00f..bbb02b88 100644 --- a/reasoning_gym/coaching/experiment.py +++ b/reasoning_gym/coaching/experiment.py @@ -1,36 +1,93 @@ """Experiment class combining dataset, scoreboard and curriculum.""" -from dataclasses import dataclass -from typing import Optional +from typing import Any, Optional -from ..composite import CompositeConfig, CompositeDataset +from ..composite import CompositeConfig, CompositeDataset, DatasetSpec +from ..factory import create_curriculum from ..version_manager import DatasetVersionManager from .coach import ScoreBoard +from .curriculum_config import CurriculumExperimentConfig -@dataclass class Experiment: - """ - An experiment combines a dataset with scoring and curriculum management. 
+ def __init__(self, name: str, composite: CompositeDataset): + self.name = name + self.composite = composite + self.score_board = ScoreBoard() - Attributes: - name: Unique identifier for the experiment - dataset: The composite dataset for generating examples - score_board: Tracks performance metrics - config: The configuration used to create the dataset - version_manager: Manages dataset versions for scoring - """ + def get_dataset_entry(self, index: int) -> dict: + return self.composite[index] - name: str - dataset: CompositeDataset - score_board: ScoreBoard - config: CompositeConfig - version_manager: DatasetVersionManager + def score_answer_with_id( + self, answer: Optional[str], entry_id: str, conversation: Optional[list[dict]] = None + ) -> float: + dataset, index, dataset_name = self.composite.resolve_entry_id(entry_id) + entry = dataset[index] + score = dataset.score_answer(answer, entry) + metadata = entry["metadata"] + self.score_board.add_score(score, metadata, conversation) + return score @classmethod def create(cls, name: str, config: CompositeConfig) -> "Experiment": """Create a new experiment from a configuration.""" version_manager = DatasetVersionManager() dataset = CompositeDataset(config, version_manager=version_manager) - score_board = ScoreBoard() - return cls(name=name, dataset=dataset, score_board=score_board, config=config, version_manager=version_manager) + return cls(name=name, composite=dataset) # __init__ takes the dataset under the keyword 'composite' + + +class CurriculumExperiment(Experiment): + def __init__(self, name: str, config: CurriculumExperimentConfig, size: int, seed: Optional[int] = None): + """Initialize curriculum experiment with configured datasets and their curricula. 
+ + Args: + name: Name of the experiment + config: Configuration specifying datasets and their attribute levels + size: Number of examples to generate + seed: Random seed for reproducibility + """ + # Initialize curricula and build dataset specs + self.curricula = {} + dataset_specs = [] + + # Process each dataset in the curriculum config + for dataset_name, attr_config in config.curricula.items(): + # Create and store curriculum + curriculum = create_curriculum(dataset_name) + self.curricula[dataset_name] = curriculum + + # Handle special "*" attribute that sets all levels + if "*" in attr_config.attribute_levels: + level = attr_config.attribute_levels["*"] + for attr_name in curriculum.attributes: + curriculum.set_attr_level(attr_name, level) + + # Set individual attribute levels (overriding "*" if specified) + for attr_name, level in attr_config.attribute_levels.items(): + if attr_name != "*": + curriculum.set_attr_level(attr_name, level) + + # Generate dataset config from curriculum + dataset_config = curriculum.generate_configuration() + + # Create dataset spec + spec = DatasetSpec(name=dataset_name, weight=attr_config.weight, config=dataset_config.__dict__) + dataset_specs.append(spec) + + # Create composite config with all datasets + composite_config = CompositeConfig(size=size, seed=seed, datasets=dataset_specs) + + # Create composite dataset + version_manager = DatasetVersionManager() + composite = CompositeDataset(config=composite_config, version_manager=version_manager) + + # Initialize base experiment + super().__init__(name=name, composite=composite) + + # Store curriculum config + self.curriculum_config = config + + def update_difficulty(self): + """Update difficulty levels based on performance metrics""" + # TODO: Implement difficulty adjustment logic + pass diff --git a/reasoning_gym/composite.py b/reasoning_gym/composite.py index 0e15bb0e..05700151 100644 --- a/reasoning_gym/composite.py +++ b/reasoning_gym/composite.py @@ -48,10 +48,16 @@ class 
CompositeConfig: ds.validate() @classmethod - def from_yaml(cls, yaml_path: str) -> "CompositeConfig": - """Load configuration from YAML file""" - with open(yaml_path, "r") as f: - data = yaml.safe_load(f) + def from_yaml_stream(cls, stream) -> "CompositeConfig": + """Load configuration from a YAML stream + + Args: + stream: A file-like object containing YAML data + + Returns: + CompositeConfig instance + """ + data = yaml.safe_load(stream) # Convert dataset specs to DatasetSpec objects if "datasets" in data: @@ -59,6 +65,19 @@ class CompositeConfig: return cls(**data) + @classmethod + def from_yaml(cls, yaml_path: str) -> "CompositeConfig": + """Load configuration from YAML file + + Args: + yaml_path: Path to YAML configuration file + + Returns: + CompositeConfig instance + """ + with open(yaml_path, "r") as f: + return cls.from_yaml_stream(f) + class CompositeDataset(ProceduralDataset): """A dataset that combines multiple datasets with weighted sampling""" @@ -246,20 +265,7 @@ class CompositeDataset(ProceduralDataset): self.dataset_names.pop(idx) self.weights.pop(idx) - def score_answer_with_id(self, answer: Optional[str], entry_id: str) -> float: - """Score an answer using an entry_id to lookup the original entry - - Args: - answer: The answer to score - entry_id: String in format "version_id.index" - - Returns: - Score between 0 and 1 - - Raises: - ValueError: If entry_id format is invalid - KeyError: If version not found in version manager - """ + def resolve_entry_id(self, entry_id: str) -> tuple[ProceduralDataset, int, str]: if self.version_manager is None: raise RuntimeError("Version manager required for scoring with entry_id") @@ -274,6 +280,24 @@ class CompositeDataset(ProceduralDataset): raise KeyError(f"Version {version_id} not found in version manager") dataset_name, dataset = dataset_info + return dataset, index, dataset_name + + def score_answer_with_id(self, answer: Optional[str], entry_id: str) -> float: + """Score an answer using an entry_id to 
lookup the original entry + + Args: + answer: The answer to score + entry_id: String in format "version_id.index" + + Returns: + Score between 0 and 1 + + Raises: + ValueError: If entry_id format is invalid + KeyError: If version not found in version manager + """ + + dataset, index, _ = self.resolve_entry_id(entry_id) # Get entry from dataset entry = dataset[index] diff --git a/reasoning_gym/factory.py b/reasoning_gym/factory.py index 3f5a62b2..95567da1 100644 --- a/reasoning_gym/factory.py +++ b/reasoning_gym/factory.py @@ -1,24 +1,34 @@ from dataclasses import is_dataclass -from typing import Type, TypeVar +from typing import Optional, Type, TypeVar + +from reasoning_gym.coaching.base_curriculum import BaseCurriculum, ConfigT from .dataset import ProceduralDataset # Type variables for generic type hints -ConfigT = TypeVar("ConfigT") + DatasetT = TypeVar("DatasetT", bound=ProceduralDataset) +CurriculumT = TypeVar("CurriculumT", bound=BaseCurriculum) # Global registry of datasets DATASETS: dict[str, tuple[Type[ProceduralDataset], Type]] = {} +CURRICULA: dict[str, BaseCurriculum] = {} -def register_dataset(name: str, dataset_cls: Type[DatasetT], config_cls: Type[ConfigT]) -> None: +def register_dataset( + name: str, + dataset_cls: Type[DatasetT], + config_cls: Type[ConfigT], + curriculum_cls: Optional[CurriculumT] = None, +) -> None: """ - Register a dataset class with its configuration class. + Register a dataset class with its configuration class and optional curriculum. 
Args: name: Unique identifier for the dataset dataset_cls: Class derived from ProceduralDataset config_cls: Configuration dataclass for the dataset + curriculum_cls: Optional curriculum class for progressive difficulty Raises: ValueError: If name is already registered or invalid types provided @@ -34,6 +44,9 @@ def register_dataset(name: str, dataset_cls: Type[DatasetT], config_cls: Type[Co DATASETS[name] = (dataset_cls, config_cls) + if curriculum_cls: + CURRICULA[name] = curriculum_cls + def create_dataset(name: str, **kwargs) -> ProceduralDataset: """ @@ -56,3 +69,28 @@ def create_dataset(name: str, **kwargs) -> ProceduralDataset: config = config_cls(**kwargs) return dataset_cls(config=config) + + +def create_curriculum(name: str) -> BaseCurriculum: + """ + Create a curriculum instance for the named dataset. + + Args: + name: Registered dataset name + + Returns: + Configured curriculum instance + + Raises: + ValueError: If dataset not found or has no curriculum registered + """ + if name not in CURRICULA: + raise ValueError(f"No curriculum registered for dataset '{name}'") + + curriculum_cls = CURRICULA[name] + + return curriculum_cls() + + +def has_curriculum(name: str) -> bool: + return name in CURRICULA diff --git a/reasoning_gym/games/n_queens.py b/reasoning_gym/games/n_queens.py index 723a4ef8..fe43583e 100644 --- a/reasoning_gym/games/n_queens.py +++ b/reasoning_gym/games/n_queens.py @@ -179,4 +179,4 @@ class NQueensCurriculum(BaseCurriculum): ) -register_dataset("n_queens", NQueensDataset, NQueensConfig) +register_dataset("n_queens", NQueensDataset, NQueensConfig, NQueensCurriculum) diff --git a/reasoning_gym/graphs/course_schedule.py b/reasoning_gym/graphs/course_schedule.py index ad3b6ee3..cf25a786 100644 --- a/reasoning_gym/graphs/course_schedule.py +++ b/reasoning_gym/graphs/course_schedule.py @@ -174,4 +174,4 @@ class CourseScheduleCurriculum(BaseCurriculum): ) -register_dataset("course_schedule", CourseScheduleDataset, CourseScheduleConfig) 
+register_dataset("course_schedule", CourseScheduleDataset, CourseScheduleConfig, CourseScheduleCurriculum) diff --git a/reasoning_gym/graphs/largest_island.py b/reasoning_gym/graphs/largest_island.py index 42b08095..17826a67 100644 --- a/reasoning_gym/graphs/largest_island.py +++ b/reasoning_gym/graphs/largest_island.py @@ -191,4 +191,4 @@ class LargestIslandCurriculum(BaseCurriculum): ) -register_dataset("largest_island", LargestIslandDataset, LargestIslandConfig) +register_dataset("largest_island", LargestIslandDataset, LargestIslandConfig, LargestIslandCurriculum) diff --git a/reasoning_gym/graphs/shortest_path.py b/reasoning_gym/graphs/shortest_path.py index f7924242..bcf40a2b 100644 --- a/reasoning_gym/graphs/shortest_path.py +++ b/reasoning_gym/graphs/shortest_path.py @@ -193,4 +193,4 @@ class ShortestPathCurriculum(BaseCurriculum): ) -register_dataset("shortest_path", ShortestPathDataset, ShortestPathConfig) +register_dataset("shortest_path", ShortestPathDataset, ShortestPathConfig, ShortestPathCurriculum) diff --git a/tests/coaching/test_curriculum_experiment.py b/tests/coaching/test_curriculum_experiment.py new file mode 100644 index 00000000..5088d683 --- /dev/null +++ b/tests/coaching/test_curriculum_experiment.py @@ -0,0 +1,131 @@ +import io + +import pytest +import yaml + +from reasoning_gym.coaching.curriculum_config import CurriculumAttributeConfig, CurriculumExperimentConfig +from reasoning_gym.coaching.experiment import CurriculumExperiment + + +def test_curriculum_experiment_initialization(): + """Test basic initialization of CurriculumExperiment""" + + # Create config with leg_counting dataset + config = CurriculumExperimentConfig( + curricula={"leg_counting": CurriculumAttributeConfig(attribute_levels={"num_animals": 2}, weight=1.0)} + ) + + # Create experiment + experiment = CurriculumExperiment(name="test_experiment", config=config, size=10, seed=42) + + # Check experiment was created correctly + assert experiment.name == "test_experiment" + 
assert "leg_counting" in experiment.curricula + assert "leg_counting" in experiment.composite.datasets + + # Check curriculum was configured correctly + curriculum = experiment.curricula["leg_counting"] + assert curriculum.get_attr_level("num_animals") == 2 + + # Check dataset was created with correct config + dataset = experiment.composite.datasets["leg_counting"] + assert dataset.config.min_animals == 1 + assert dataset.config.max_animals == 3 + + # Check we can get entries from the dataset + entry = experiment.get_dataset_entry(0) + assert "question" in entry + assert "answer" in entry + assert "metadata" in entry + assert entry["metadata"]["source_dataset"] == "leg_counting" + + +def test_curriculum_experiment_wildcard_level(): + """Test using "*" to set all attribute levels""" + + config = CurriculumExperimentConfig( + curricula={"leg_counting": CurriculumAttributeConfig(attribute_levels={"*": 3}, weight=1.0)} + ) + + experiment = CurriculumExperiment(name="test_experiment", config=config, size=10, seed=42) + + # Check all attributes were set to level 3 + curriculum = experiment.curricula["leg_counting"] + for attr_name in curriculum.attributes: + assert curriculum.get_attr_level(attr_name) == 3 + + +def test_curriculum_experiment_mixed_levels(): + """Test mixing "*" with specific attribute levels""" + + config = CurriculumExperimentConfig( + curricula={ + "leg_counting": CurriculumAttributeConfig( + attribute_levels={"*": 2, "num_animals": 4}, weight=1.0 # Should override the "*" level + ) + } + ) + + experiment = CurriculumExperiment(name="test_experiment", config=config, size=10, seed=42) + + curriculum = experiment.curricula["leg_counting"] + assert curriculum.get_attr_level("num_animals") == 4 # Specific override + + +def test_curriculum_experiment_from_yaml(): + """Test loading curriculum experiment config from YAML using a string stream""" + + # Create a YAML string + yaml_content = """ + curricula: + leg_counting: + attribute_levels: + "*": 2 + 
num_animals: 4 + weight: 1.5 + chain_sum: + attribute_levels: + num_terms: 1 + num_digits: 2 + weight: 0.8 + """ + + # Use StringIO to create a file-like object from the string + from io import StringIO + + yaml_stream = StringIO(yaml_content) + + # Load config from YAML stream + config = CurriculumExperimentConfig.from_yaml_stream(yaml_stream) + + # Verify config was loaded correctly + assert len(config.curricula) == 2 + assert "leg_counting" in config.curricula + assert "chain_sum" in config.curricula + + # Check leg_counting curriculum + leg_counting = config.curricula["leg_counting"] + assert leg_counting.attribute_levels["*"] == 2 + assert leg_counting.attribute_levels["num_animals"] == 4 + assert leg_counting.weight == 1.5 + + # Check chain_sum curriculum + chain_sum = config.curricula["chain_sum"] + assert chain_sum.attribute_levels["num_terms"] == 1 + assert chain_sum.attribute_levels["num_digits"] == 2 + assert chain_sum.weight == 0.8 + + # Create experiment from the loaded config + experiment = CurriculumExperiment(name="yaml_test", config=config, size=10, seed=42) + + # Verify experiment was created correctly + assert "leg_counting" in experiment.curricula + assert "chain_sum" in experiment.curricula + + # Check attribute levels were applied + leg_curriculum = experiment.curricula["leg_counting"] + assert leg_curriculum.get_attr_level("num_animals") == 4 + + chain_sum_curriculum = experiment.curricula["chain_sum"] + assert chain_sum_curriculum.get_attr_level("num_terms") == 1 + assert chain_sum_curriculum.get_attr_level("num_digits") == 2 diff --git a/tools/server/server.py b/tools/server/server.py index 09ded0d9..6f437bc3 100644 --- a/tools/server/server.py +++ b/tools/server/server.py @@ -89,12 +89,12 @@ def create_app(config: ServerConfig) -> FastAPI: try: entries = [] for i in range(base_index, base_index + batch_size): - entry = experiment.dataset[i] + entry = experiment.get_dataset_entry(i) # Create BatchEntry with minimal required data batch_entry 
= BatchEntry( question=entry["question"], - entry_id=f"{entry['metadata']['version_id']}.{i}", + entry_id=f"{entry['metadata']['entry_id']}", metadata=entry["metadata"], ) entries.append(batch_entry) @@ -115,7 +115,7 @@ def create_app(config: ServerConfig) -> FastAPI: scores = [] entry_ids = [] for item in request.answers: - score = experiment.dataset.score_answer_with_id(item.answer, item.entry_id) + score = experiment.score_answer_with_id(item.answer, item.entry_id) scores.append(score) entry_ids.append(item.entry_id) @@ -134,7 +134,7 @@ def create_app(config: ServerConfig) -> FastAPI: # Convert internal config to API response format datasets = {} for ds_spec in experiment.config.datasets: - dataset = experiment.dataset.datasets[ds_spec.name] + dataset = experiment.composite.datasets[ds_spec.name] datasets[ds_spec.name] = { "weight": ds_spec.weight, "config": vars(dataset.config), # Get current config from dataset instance @@ -152,7 +152,7 @@ def create_app(config: ServerConfig) -> FastAPI: raise HTTPException(status_code=404, detail=f"Experiment '{name}' not found") try: - experiment.dataset.update_dataset_config(dataset_name, config_update.config) + experiment.composite.update_dataset_config(dataset_name, config_update.config) return {"status": "updated"} except KeyError: raise HTTPException(status_code=404, detail=f"Dataset '{dataset_name}' not found in experiment") diff --git a/tools/server/tests/test_registry.py b/tools/server/tests/test_registry.py index 9e19df03..e84f2db2 100644 --- a/tools/server/tests/test_registry.py +++ b/tools/server/tests/test_registry.py @@ -35,7 +35,7 @@ def test_experiment_management(): exp = registry.get_experiment("test_exp") assert exp is not None assert exp.name == "test_exp" - assert isinstance(exp.dataset, CompositeDataset) + assert isinstance(exp.composite, CompositeDataset) assert exp.config == config # Test removal