Mirror of https://github.com/open-thought/reasoning-gym.git, synced 2025-10-09 13:40:09 +03:00
Basic curriculum (#198)
* feat: Add optional curriculum support to dataset registration and creation
* docs: Add docstrings to create_curriculum() and register_dataset()
* feat: Add curriculum configuration classes for CurriculumExperiment
* feat: Add weight parameter to CurriculumAttributeConfig and use in DatasetSpec
* refactor: Simplify CurriculumAttributeConfig with "*" attribute level support
* test: Add unit tests for CurriculumExperiment class
* feat: Add from_yaml() method to CurriculumExperimentConfig with unit test
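
Taken together, these changes let an experiment derive per-dataset difficulty from a small YAML config. A minimal usage sketch based on the classes added in this commit (the dataset name and levels are illustrative, not part of the diff):

    from io import StringIO

    from reasoning_gym.coaching.curriculum_config import CurriculumExperimentConfig
    from reasoning_gym.coaching.experiment import CurriculumExperiment

    yaml_text = """
    curricula:
      leg_counting:
        attribute_levels:
          num_animals: 2
        weight: 1.0
    """

    # Parse the YAML and build a weighted composite dataset behind one experiment.
    config = CurriculumExperimentConfig.from_yaml_stream(StringIO(yaml_text))
    experiment = CurriculumExperiment("demo", config, size=100, seed=42)

    entry = experiment.get_dataset_entry(0)  # dict with "question", "answer", "metadata"
    reward = experiment.score_answer_with_id(entry["answer"], entry_id=entry["metadata"]["entry_id"])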
examples/veRL/basic_curriculum/launch.sh (new executable file, 9 lines)
@@ -0,0 +1,9 @@
#!/bin/bash

export N_GPUS=4
export BASE_MODEL=meta-llama/Llama-3.2-3B-Instruct # meta-llama/Llama-3.2-1B-Instruct
export ROLLOUT_TP_SIZE=2
export EXPERIMENT_NAME=basic_curriculum
export VLLM_ATTENTION_BACKEND=XFORMERS

bash ./train_grpo.sh
examples/veRL/basic_curriculum/ppo_curriculum.py (new file, 298 lines)
@@ -0,0 +1,298 @@
# This example is an adapted version of Bytedance's code:
# https://github.com/volcengine/verl/blob/a65c9157bc0b85b64cd753de19f94e80a11bd871/verl/trainer/main_ppo.py
from io import StringIO
from typing import Optional

import hydra
import ray
import torch
import verl.utils.torch_functional as verl_F
from omegaconf import OmegaConf, open_dict
from torch.utils.data import DataLoader, Dataset
from transformers import PreTrainedTokenizer
from verl import DataProto
from verl.trainer.ppo.ray_trainer import RayPPOTrainer
from verl.utils.dataset.rl_dataset import collate_fn
from verl.utils.model import compute_position_id_with_mask

import reasoning_gym
import reasoning_gym.utils
from reasoning_gym.coaching.curriculum_config import CurriculumExperimentConfig
from reasoning_gym.coaching.experiment import CurriculumExperiment
from reasoning_gym.utils import extract_answer

curriculum_config_yaml = """
curricula:
  leg_counting:
    attribute_levels:
      num_animals: 2
  products:
    attribute_levels:
      num_terms: 4
      num_digits: 4
  chain_sum:
    attribute_levels:
      num_terms: 4
      num_digits: 4
    weight: 1.0
"""

class ReasoningGymDataset(Dataset):
    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        experiment_name: str,
        seed: int,
        size: int,
        developer_prompt: Optional[str] = None,
        developer_role: str = "system",
        max_prompt_length: int = 2048,
        truncation: str = "error",  # one of: 'left', 'right', 'error'
        return_raw_chat: bool = False,
    ):
        self.tokenizer = tokenizer
        self.experiment_name = experiment_name  # also used as the data_source suffix in __getitem__
        curriculum_config = CurriculumExperimentConfig.from_yaml_stream(StringIO(curriculum_config_yaml))
        self.experiment = CurriculumExperiment(experiment_name, curriculum_config, size=size, seed=seed)
        self.developer_prompt = developer_prompt
        self.developer_role = developer_role
        self.max_prompt_length = max_prompt_length
        self.truncation = truncation
        self.return_raw_chat = return_raw_chat

    def __len__(self) -> int:
        return len(self.experiment.composite)

    def __getitem__(self, index: int):
        row_dict = self.experiment.get_dataset_entry(index).copy()
        q = row_dict["question"]

        chat = []
        if self.developer_prompt is not None:
            chat.append({"role": self.developer_role, "content": self.developer_prompt})
        chat.append({"role": "user", "content": q})

        prompt = self.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)

        input_ids, attention_mask = verl_F.tokenize_and_postprocess_data(
            prompt=prompt,
            tokenizer=self.tokenizer,
            max_length=self.max_prompt_length,
            pad_token_id=self.tokenizer.pad_token_id,
            left_pad=True,
            truncation=self.truncation,
        )

        position_ids = compute_position_id_with_mask(attention_mask)

        row_dict["data_source"] = "reasoning_gym/" + self.experiment_name
        row_dict["input_ids"] = input_ids[0]
        row_dict["attention_mask"] = attention_mask[0]
        row_dict["position_ids"] = position_ids[0]

        # keep the raw chat messages (without the applied chat template)
        if self.return_raw_chat:
            row_dict["raw_prompt"] = list(chat)

        return row_dict

class RayPPOTrainerCustom(RayPPOTrainer):
    def __init__(
        self,
        config,
        tokenizer,
        role_worker_mapping: dict,
        resource_pool_manager,
        ray_worker_group_cls,
        experiment_name: str = "basic_curriculum",
        dataset_size: int = 10000,
    ):
        self.dataset_size = dataset_size

        developer_prompt = reasoning_gym.utils.SYSTEM_PROMPTS["DeepSeekZero"]
        self.train_dataset = ReasoningGymDataset(
            tokenizer=tokenizer,
            experiment_name=experiment_name,
            seed=1,
            size=self.dataset_size,
            developer_prompt=developer_prompt,
        )

        self.val_dataset = ReasoningGymDataset(
            tokenizer=tokenizer,
            experiment_name=experiment_name,
            seed=2,
            size=self.dataset_size,
            developer_prompt=developer_prompt,
        )

        train_reward_fn = lambda data: self._score_output(data, num_examine=0)
        val_reward_fn = lambda data: self._score_output(data, num_examine=1)

        super().__init__(
            config,
            tokenizer,
            role_worker_mapping,
            resource_pool_manager,
            ray_worker_group_cls,
            train_reward_fn,
            val_reward_fn,
        )

    def _score_output(self, data: DataProto, num_examine: int = 0) -> torch.Tensor:
        reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)

        num_printed = 0
        for i in range(len(data)):
            data_item = data[i]  # DataProtoItem

            prompt_ids = data_item.batch["prompts"]  # tokenized prompts
            prompt_length = prompt_ids.shape[-1]

            valid_prompt_length = data_item.batch["attention_mask"][:prompt_length].sum()
            valid_prompt_ids = prompt_ids[-valid_prompt_length:]

            response_ids = data_item.batch["responses"]
            valid_response_length = data_item.batch["attention_mask"][prompt_length:].sum()
            valid_response_ids = response_ids[:valid_response_length]

            # decode
            sequences = torch.cat((valid_prompt_ids, valid_response_ids))
            sequences_str = self.tokenizer.decode(sequences)

            entry_id = data_item.non_tensor_batch["metadata"]["entry_id"]

            score = self._compute_score(
                solution_str=sequences_str,
                entry_id=entry_id,
            )
            reward_tensor[i, valid_response_length - 1] = score

            if num_printed < num_examine:
                print(f"reward={score}, seq={sequences_str}")
                num_printed += 1

        return reward_tensor

    def _compute_score(self, solution_str: str, entry_id: str) -> float:
        found_answer = extract_answer(solution_str, tag_name="answer")
        reward = self.train_dataset.experiment.score_answer_with_id(found_answer, entry_id=entry_id)
        print(f"entry_id: {entry_id}; found answer={found_answer}; reward: {reward};")
        return reward

    def _create_dataloader(self):
        self.train_dataloader = DataLoader(
            dataset=self.train_dataset,
            batch_size=self.config.data.train_batch_size,
            shuffle=True,
            drop_last=True,
            collate_fn=collate_fn,
        )

        self.val_dataloader = DataLoader(
            dataset=self.val_dataset,
            batch_size=len(self.val_dataset),
            shuffle=True,
            drop_last=True,
            collate_fn=collate_fn,
        )

        assert len(self.train_dataloader) >= 1
        assert len(self.val_dataloader) >= 1

        print(f"Size of train dataloader: {len(self.train_dataloader)}")
        print(f"Size of val dataloader: {len(self.val_dataloader)}")

        # inject total_training_steps to actor/critic optim_config. This is hacky.
        total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs

        if self.config.trainer.total_training_steps is not None:
            total_training_steps = self.config.trainer.total_training_steps

        self.total_training_steps = total_training_steps
        print(f"Total training steps: {self.total_training_steps}")

        OmegaConf.set_struct(self.config, True)
        with open_dict(self.config):
            self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps
            self.config.critic.optim.total_training_steps = total_training_steps

@ray.remote
def main_task(config):
    # print initial config
    from pprint import pprint

    from verl.utils import hf_tokenizer
    from verl.utils.fs import copy_local_path_from_hdfs

    pprint(OmegaConf.to_container(config, resolve=True))  # resolve=True will eval symbol values
    OmegaConf.resolve(config)

    # download the checkpoint from hdfs
    local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path)

    # instantiate tokenizer
    tokenizer = hf_tokenizer(local_path)

    # define worker classes
    if config.actor_rollout_ref.actor.strategy == "fsdp":
        assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
        from verl.single_controller.ray import RayWorkerGroup
        from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker

        ray_worker_group_cls = RayWorkerGroup

    elif config.actor_rollout_ref.actor.strategy == "megatron":
        assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
        from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
        from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker

        ray_worker_group_cls = NVMegatronRayWorkerGroup

    else:
        raise NotImplementedError

    from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role

    role_worker_mapping = {
        Role.ActorRollout: ray.remote(ActorRolloutRefWorker),
        Role.Critic: ray.remote(CriticWorker),
        Role.RefPolicy: ray.remote(ActorRolloutRefWorker),
    }

    global_pool_id = "global_pool"
    resource_pool_spec = {
        global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
    }
    mapping = {
        Role.ActorRollout: global_pool_id,
        Role.Critic: global_pool_id,
        Role.RefPolicy: global_pool_id,
    }

    resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)

    trainer = RayPPOTrainerCustom(
        config=config,
        tokenizer=tokenizer,
        role_worker_mapping=role_worker_mapping,
        resource_pool_manager=resource_pool_manager,
        ray_worker_group_cls=ray_worker_group_cls,
    )
    trainer.init_workers()
    trainer.fit()

@hydra.main(config_path="config", config_name="ppo_trainer", version_base=None)
def main(config):
    if not ray.is_initialized():
        # this is for local ray cluster
        ray.init(runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}})

    ray.get(main_task.remote(config))


if __name__ == "__main__":
    main()
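
Note that the reward path above only works if completions carry an <answer> tag, which the DeepSeekZero system prompt requests. A small sketch of the extraction step with a toy completion (the completion string is illustrative):

    from reasoning_gym.utils import extract_answer

    completion = "<think>Two dogs have 8 legs, one chicken has 2.</think><answer>10</answer>"
    found = extract_answer(completion, tag_name="answer")  # -> "10"
    # _compute_score passes `found` to experiment.score_answer_with_id(), which
    # re-resolves the entry by id, scores the answer in [0, 1], and records the
    # result on the experiment's ScoreBoard.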
examples/veRL/basic_curriculum/train_grpo.sh (new file, 39 lines)
@@ -0,0 +1,39 @@
#!/bin/bash
set -x

python3 -u ppo_curriculum.py \
    algorithm.adv_estimator=grpo \
    data.train_files=$DATA_DIR/train.parquet \
    data.val_files=$DATA_DIR/test.parquet \
    data.train_batch_size=512 \
    data.val_batch_size=512 \
    data.max_prompt_length=512 \
    data.max_response_length=1024 \
    actor_rollout_ref.model.path=$BASE_MODEL \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=256 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=80 \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=160 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=$ROLLOUT_TP_SIZE \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
    actor_rollout_ref.rollout.n=8 \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=160 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    algorithm.kl_ctrl.kl_coef=0.001 \
    trainer.critic_warmup=0 \
    trainer.logger=['wandb'] \
    trainer.project_name='verl_chain_sum_grpo' \
    trainer.experiment_name=$EXPERIMENT_NAME \
    trainer.n_gpus_per_node=$N_GPUS \
    trainer.nnodes=1 \
    trainer.save_freq=100 \
    trainer.test_freq=100 \
    trainer.total_epochs=15 $@ 2>&1 | tee verl_output.log
examples/veRL/chain_sum/config/ppo_trainer.yaml (new file, 171 lines)
@@ -0,0 +1,171 @@
data:
  tokenizer: null
  train_files: ~/data/rlhf/gsm8k/train.parquet
  val_files: ~/data/rlhf/gsm8k/test.parquet
  prompt_key: prompt
  max_prompt_length: 512
  max_response_length: 512
  train_batch_size: 1024
  val_batch_size: 1312
  return_raw_input_ids: False # This should be set to true when the tokenizer between policy and rm differs
  return_raw_chat: False

actor_rollout_ref:
  hybrid_engine: True
  model:
    path: ~/models/deepseek-llm-7b-chat
    external_lib: null
    override_config: { }
    enable_gradient_checkpointing: True
    use_remove_padding: False
  actor:
    strategy: fsdp # This is for backward-compatibility
    ppo_mini_batch_size: 256
    ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu
    ppo_micro_batch_size_per_gpu: null
    use_dynamic_bsz: False
    ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length}
    grad_clip: 1.0
    clip_ratio: 0.2
    entropy_coeff: 0.001
    use_kl_loss: False # True for GRPO
    kl_loss_coef: 0.001 # for grpo
    kl_loss_type: low_var_kl # for grpo
    ppo_epochs: 1
    shuffle: False
    ulysses_sequence_parallel_size: 1 # sp size
    optim:
      lr: 1e-6
      lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
      min_lr_ratio: null # only useful for warmup with cosine
      warmup_style: constant # select from constant/cosine
      total_training_steps: -1 # must be overridden by program
    fsdp_config:
      wrap_policy:
        # transformer_layer_cls_to_wrap: None
        min_num_params: 0
      param_offload: False
      optimizer_offload: False
      fsdp_size: -1
  ref:
    fsdp_config:
      param_offload: False
      wrap_policy:
        # transformer_layer_cls_to_wrap: None
        min_num_params: 0
    log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu
    log_prob_micro_batch_size_per_gpu: null
    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
    ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size
  rollout:
    name: vllm
    temperature: 1.0
    top_k: -1 # 0 for hf rollout, -1 for vllm rollout
    top_p: 1
    prompt_length: ${data.max_prompt_length} # not used for open-source
    response_length: ${data.max_response_length}
    # for vllm rollout
    dtype: bfloat16 # should align with FSDP
    gpu_memory_utilization: 0.5
    ignore_eos: False
    enforce_eager: True
    free_cache_engine: True
    load_format: dummy_dtensor
    tensor_model_parallel_size: 2
    max_num_batched_tokens: 8192
    max_num_seqs: 1024
    log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu
    log_prob_micro_batch_size_per_gpu: null
    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
    disable_log_stats: True
    enable_chunked_prefill: True # could get higher throughput
    # for hf rollout
    do_sample: True
    # number of responses (i.e. num sample times)
    n: 1 # > 1 for grpo

critic:
  strategy: fsdp
  optim:
    lr: 1e-5
    lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
    min_lr_ratio: null # only useful for warmup with cosine
    warmup_style: constant # select from constant/cosine
    total_training_steps: -1 # must be overridden by program
  model:
    path: ~/models/deepseek-llm-7b-chat
    tokenizer_path: ${actor_rollout_ref.model.path}
    override_config: { }
    external_lib: ${actor_rollout_ref.model.external_lib}
    enable_gradient_checkpointing: True
    use_remove_padding: False
    fsdp_config:
      param_offload: False
      optimizer_offload: False
      wrap_policy:
        # transformer_layer_cls_to_wrap: None
        min_num_params: 0
      fsdp_size: -1
  ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
  ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu
  ppo_micro_batch_size_per_gpu: null
  forward_micro_batch_size: ${critic.ppo_micro_batch_size}
  forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu}
  use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
  ppo_max_token_len_per_gpu: 32768 # (${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}) * 2
  forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu}
  ulysses_sequence_parallel_size: 1 # sp size
  ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}
  shuffle: ${actor_rollout_ref.actor.shuffle}
  grad_clip: 1.0
  cliprange_value: 0.5

reward_model:
  enable: False
  strategy: fsdp
  model:
    input_tokenizer: ${actor_rollout_ref.model.path} # set this to null if the chat template is identical
    path: ~/models/FsfairX-LLaMA3-RM-v0.1
    external_lib: ${actor_rollout_ref.model.external_lib}
    use_remove_padding: False
    fsdp_config:
      min_num_params: 0
      param_offload: False
      fsdp_size: -1
  micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu
  micro_batch_size_per_gpu: null # set a number
  max_length: null
  ulysses_sequence_parallel_size: 1 # sp size
  use_dynamic_bsz: ${critic.use_dynamic_bsz}
  forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu}

algorithm:
  gamma: 1.0
  lam: 1.0
  adv_estimator: gae
  kl_penalty: kl # how to estimate kl divergence
  kl_ctrl:
    type: fixed
    kl_coef: 0.001

trainer:
  total_epochs: 30
  total_training_steps: null
  project_name: verl_examples
  experiment_name: gsm8k
  logger: [ 'console', 'wandb' ]
  val_generations_to_log_to_wandb: 0
  nnodes: 1
  n_gpus_per_node: 8
  save_freq: -1
  # auto: find the last ckpt to resume. If can't find, start from scratch
  resume_mode: auto # or auto or resume_path if
  resume_from_path: False
  test_freq: -1
  critic_warmup: 0
  default_hdfs_dir: null
  remove_previous_ckpt_in_save: False
  del_local_ckpt_after_load: False
  default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}
@@ -123,7 +123,7 @@ class ChainSumCurriculum(BaseCurriculum):
         self._define_attributes(
             RangeAttributeDefinition(
                 name="num_terms",
-                levels=[2, 3, 4, 5],
+                levels=list(range(2, 13)),
                 default_level=0,  # Start with 2 terms
                 description="Maximum number of terms in the expression",
                 attr_type=AttributeType.APPEND,
@@ -133,7 +133,7 @@ class ChainSumCurriculum(BaseCurriculum):
             ),
             RangeAttributeDefinition(
                 name="num_digits",
-                levels=[1, 2, 4, 10],
+                levels=list(range(1, 11)),
                 default_level=0,  # Start with 1-digit numbers
                 description="Number of digits in each operand",
                 attr_type=AttributeType.APPEND,
@@ -145,4 +145,4 @@ class ChainSumCurriculum(BaseCurriculum):
 
 
 # Register the dataset
-register_dataset("chain_sum", ChainSumDataset, ChainSumConfig)
+register_dataset("chain_sum", ChainSumDataset, ChainSumConfig, ChainSumCurriculum)
@@ -65,4 +65,4 @@ class CountBitsCurriculum(BaseCurriculum):
         )
 
 
-register_dataset("count_bits", CountBitsDataset, CountBitsConfig)
+register_dataset("count_bits", CountBitsDataset, CountBitsConfig, CountBitsCurriculum)
@@ -4,6 +4,9 @@ from dataclasses import dataclass
 from random import Random
 from typing import Optional
 
+from reasoning_gym.coaching.attributes import AttributeType, RangeAttributeDefinition
+from reasoning_gym.coaching.base_curriculum import BaseCurriculum
+
 from ..factory import ProceduralDataset, register_dataset
 
 ANIMALS = {
@@ -124,4 +127,23 @@ class LegCountingDataset(ProceduralDataset):
         }
 
 
-register_dataset("leg_counting", LegCountingDataset, LegCountingConfig)
+class LegCountingCurriculum(BaseCurriculum):
+    def __init__(self):
+        super().__init__(LegCountingCurriculum.__name__, LegCountingConfig)
+
+        # Define attributes
+        self._define_attributes(
+            RangeAttributeDefinition(
+                name="num_animals",
+                levels=list(range(1, 20)),
+                default_level=0,  # Start with 1 animal
+                description="Number of animals in question",
+                attr_type=AttributeType.APPEND,
+                min_value=1,  # Ensure at least 1 animal
+                lower_field_name="min_animals",
+                upper_field_name="max_animals",
+            ),
+        )
+
+
+register_dataset("leg_counting", LegCountingDataset, LegCountingConfig, LegCountingCurriculum)
@@ -115,7 +115,7 @@ class ProductsCurriculum(BaseCurriculum):
         self._define_attributes(
             RangeAttributeDefinition(
                 name="num_terms",
-                levels=[2, 3, 4, 5],
+                levels=list(range(2, 13)),
                 default_level=0,  # Start with 2 terms
                 description="Maximum number of terms in the expression",
                 attr_type=AttributeType.APPEND,
@@ -125,7 +125,7 @@ class ProductsCurriculum(BaseCurriculum):
             ),
             RangeAttributeDefinition(
                 name="num_digits",
-                levels=[1, 2, 3, 4],
+                levels=list(range(1, 11)),
                 default_level=0,  # Start with 1-digit numbers
                 description="Number of digits in each operand",
                 attr_type=AttributeType.APPEND,
@@ -137,4 +137,4 @@ class ProductsCurriculum(BaseCurriculum):
 
 
 # Register the dataset
-register_dataset("products", ProductsDataset, ProductsConfig)
+register_dataset("products", ProductsDataset, ProductsConfig, ProductsCurriculum)
@@ -1,8 +1,9 @@
-from typing import Any, Iterable, Optional
+from typing import Any, Optional, TypeVar
 
-from ..factory import ConfigT
 from .attributes import AttributeDefinition, RangeAttributeDefinition, ScalarAttributeDefinition
 
+ConfigT = TypeVar("ConfigT")
+
 
 class BaseCurriculum:
     def __init__(self, name: str, config_cls: ConfigT):
@@ -21,7 +22,6 @@ class BaseCurriculum:
             elif isinstance(attr, ScalarAttributeDefinition):
                 val = self.get_attr_value(attr.name)
                 config_args[attr.field_name] = val
-        print(config_args)
         return self._config_cls(**config_args)
 
     @property
reasoning_gym/coaching/curriculum_config.py (new file, 91 lines)
@@ -0,0 +1,91 @@
from dataclasses import dataclass
from typing import Dict, Optional

import yaml


@dataclass
class CurriculumAttributeConfig:
    """Configuration for curriculum attribute levels"""

    # Dictionary mapping attribute names to levels
    # Special key "*" means apply that level to all attributes
    attribute_levels: Dict[str, int]
    # Weight for sampling this dataset
    weight: float = 1.0

    def validate(self):
        """Validate the configuration"""
        if not self.attribute_levels:
            raise ValueError("Must specify at least one attribute level")


@dataclass
class CurriculumExperimentConfig:
    """Configuration for curriculum experiments"""

    # Dictionary mapping dataset names to their curriculum configurations
    curricula: Dict[str, CurriculumAttributeConfig]

    def validate(self):
        """Validate the configuration"""
        if not self.curricula:
            raise ValueError("Must specify at least one curriculum")

        for dataset_name, attr_config in self.curricula.items():
            if not isinstance(attr_config, CurriculumAttributeConfig):
                raise ValueError(f"Invalid attribute config for dataset {dataset_name}")
            attr_config.validate()

    @classmethod
    def from_yaml_stream(cls, stream) -> "CurriculumExperimentConfig":
        """Load configuration from a YAML stream

        Args:
            stream: A file-like object containing YAML data

        Returns:
            CurriculumExperimentConfig instance

        Raises:
            ValueError: If YAML data has invalid format
        """
        data = yaml.safe_load(stream)

        if not isinstance(data, dict):
            raise ValueError("YAML data must contain a dictionary")

        if "curricula" not in data:
            raise ValueError("YAML data must contain a 'curricula' key")

        # Convert curriculum configs
        curricula = {}
        for dataset_name, config in data["curricula"].items():
            if not isinstance(config, dict):
                raise ValueError(f"Curriculum config for {dataset_name} must be a dictionary")

            if "attribute_levels" not in config:
                raise ValueError(f"Curriculum config for {dataset_name} must contain 'attribute_levels'")

            weight = config.get("weight", 1.0)
            curricula[dataset_name] = CurriculumAttributeConfig(
                attribute_levels=config["attribute_levels"], weight=weight
            )

        return cls(curricula=curricula)

    @classmethod
    def from_yaml(cls, yaml_path: str) -> "CurriculumExperimentConfig":
        """Load configuration from YAML file

        Args:
            yaml_path: Path to YAML configuration file

        Returns:
            CurriculumExperimentConfig instance

        Raises:
            ValueError: If YAML file has invalid format
        """
        with open(yaml_path, "r") as f:
            return cls.from_yaml_stream(f)
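
For reference, the "*" key documented above composes with explicit keys: the wildcard seeds every attribute, and specific entries override it (the override order is implemented in CurriculumExperiment, in the next hunk). A sketch of a config using both, with illustrative dataset and attribute names:

    curricula:
      leg_counting:
        attribute_levels:
          "*": 2            # start every attribute at level 2 ...
          num_animals: 4    # ... but pin num_animals to level 4
        weight: 1.5         # sampling weight forwarded to DatasetSpec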
@@ -1,36 +1,93 @@
 """Experiment class combining dataset, scoreboard and curriculum."""
 
-from dataclasses import dataclass
-from typing import Optional
+from typing import Any, Optional
 
-from ..composite import CompositeConfig, CompositeDataset
+from ..composite import CompositeConfig, CompositeDataset, DatasetSpec
+from ..factory import create_curriculum
 from ..version_manager import DatasetVersionManager
 from .coach import ScoreBoard
+from .curriculum_config import CurriculumExperimentConfig
 
 
-@dataclass
 class Experiment:
-    """
-    An experiment combines a dataset with scoring and curriculum management.
+    def __init__(self, name: str, composite: CompositeDataset):
+        self.name = name
+        self.composite = composite
+        self.score_board = ScoreBoard()
 
-    Attributes:
-        name: Unique identifier for the experiment
-        dataset: The composite dataset for generating examples
-        score_board: Tracks performance metrics
-        config: The configuration used to create the dataset
-        version_manager: Manages dataset versions for scoring
-    """
+    def get_dataset_entry(self, index: int) -> dict:
+        return self.composite[index]
 
-    name: str
-    dataset: CompositeDataset
-    score_board: ScoreBoard
-    config: CompositeConfig
-    version_manager: DatasetVersionManager
+    def score_answer_with_id(
+        self, answer: Optional[str], entry_id: str, conversation: Optional[list[dict]] = None
+    ) -> float:
+        dataset, index, dataset_name = self.composite.resolve_entry_id(entry_id)
+        entry = dataset[index]
+        score = dataset.score_answer(answer, entry)
+        metadata = entry["metadata"]
+        self.score_board.add_score(score, metadata, conversation)
+        return score
 
     @classmethod
     def create(cls, name: str, config: CompositeConfig) -> "Experiment":
         """Create a new experiment from a configuration."""
         version_manager = DatasetVersionManager()
         dataset = CompositeDataset(config, version_manager=version_manager)
-        score_board = ScoreBoard()
-        return cls(name=name, dataset=dataset, score_board=score_board, config=config, version_manager=version_manager)
+        return cls(name=name, composite=dataset)  # __init__ takes `composite`, not `dataset`
+
+
+class CurriculumExperiment(Experiment):
+    def __init__(self, name: str, config: CurriculumExperimentConfig, size: int, seed: Optional[int] = None):
+        """Initialize curriculum experiment with configured datasets and their curricula.
+
+        Args:
+            name: Name of the experiment
+            config: Configuration specifying datasets and their attribute levels
+            size: Number of examples to generate
+            seed: Random seed for reproducibility
+        """
+        # Initialize curricula and build dataset specs
+        self.curricula = {}
+        dataset_specs = []
+
+        # Process each dataset in the curriculum config
+        for dataset_name, attr_config in config.curricula.items():
+            # Create and store curriculum
+            curriculum = create_curriculum(dataset_name)
+            self.curricula[dataset_name] = curriculum
+
+            # Handle special "*" attribute that sets all levels
+            if "*" in attr_config.attribute_levels:
+                level = attr_config.attribute_levels["*"]
+                for attr_name in curriculum.attributes:
+                    curriculum.set_attr_level(attr_name, level)
+
+            # Set individual attribute levels (overriding "*" if specified)
+            for attr_name, level in attr_config.attribute_levels.items():
+                if attr_name != "*":
+                    curriculum.set_attr_level(attr_name, level)
+
+            # Generate dataset config from curriculum
+            dataset_config = curriculum.generate_configuration()
+
+            # Create dataset spec
+            spec = DatasetSpec(name=dataset_name, weight=attr_config.weight, config=dataset_config.__dict__)
+            dataset_specs.append(spec)
+
+        # Create composite config with all datasets
+        composite_config = CompositeConfig(size=size, seed=seed, datasets=dataset_specs)
+
+        # Create composite dataset
+        version_manager = DatasetVersionManager()
+        composite = CompositeDataset(config=composite_config, version_manager=version_manager)
+
+        # Initialize base experiment
+        super().__init__(name=name, composite=composite)
+
+        # Store curriculum config
+        self.curriculum_config = config
+
+    def update_difficulty(self):
+        """Update difficulty levels based on performance metrics"""
+        # TODO: Implement difficulty adjustment logic
+        pass
@@ -48,10 +48,16 @@ class CompositeConfig:
         ds.validate()
 
     @classmethod
-    def from_yaml(cls, yaml_path: str) -> "CompositeConfig":
-        """Load configuration from YAML file"""
-        with open(yaml_path, "r") as f:
-            data = yaml.safe_load(f)
+    def from_yaml_stream(cls, stream) -> "CompositeConfig":
+        """Load configuration from a YAML stream
+
+        Args:
+            stream: A file-like object containing YAML data
+
+        Returns:
+            CompositeConfig instance
+        """
+        data = yaml.safe_load(stream)
 
         # Convert dataset specs to DatasetSpec objects
         if "datasets" in data:
@@ -59,6 +65,19 @@ class CompositeConfig:
 
         return cls(**data)
 
+    @classmethod
+    def from_yaml(cls, yaml_path: str) -> "CompositeConfig":
+        """Load configuration from YAML file
+
+        Args:
+            yaml_path: Path to YAML configuration file
+
+        Returns:
+            CompositeConfig instance
+        """
+        with open(yaml_path, "r") as f:
+            return cls.from_yaml_stream(f)
+
 
 class CompositeDataset(ProceduralDataset):
     """A dataset that combines multiple datasets with weighted sampling"""
@@ -246,20 +265,7 @@ class CompositeDataset(ProceduralDataset):
         self.dataset_names.pop(idx)
         self.weights.pop(idx)
 
-    def score_answer_with_id(self, answer: Optional[str], entry_id: str) -> float:
-        """Score an answer using an entry_id to lookup the original entry
-
-        Args:
-            answer: The answer to score
-            entry_id: String in format "version_id.index"
-
-        Returns:
-            Score between 0 and 1
-
-        Raises:
-            ValueError: If entry_id format is invalid
-            KeyError: If version not found in version manager
-        """
+    def resolve_entry_id(self, entry_id: str) -> tuple[ProceduralDataset, int, str]:
         if self.version_manager is None:
             raise RuntimeError("Version manager required for scoring with entry_id")
@@ -274,6 +280,24 @@ class CompositeDataset(ProceduralDataset):
             raise KeyError(f"Version {version_id} not found in version manager")
 
         dataset_name, dataset = dataset_info
+        return dataset, index, dataset_name
+
+    def score_answer_with_id(self, answer: Optional[str], entry_id: str) -> float:
+        """Score an answer using an entry_id to lookup the original entry
+
+        Args:
+            answer: The answer to score
+            entry_id: String in format "version_id.index"
+
+        Returns:
+            Score between 0 and 1
+
+        Raises:
+            ValueError: If entry_id format is invalid
+            KeyError: If version not found in version manager
+        """
+
+        dataset, index, _ = self.resolve_entry_id(entry_id)
 
         # Get entry from dataset
         entry = dataset[index]
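
The entry_id strings consumed above use the "version_id.index" format from the docstring. A minimal sketch of the lookup path, assuming `composite` is a CompositeDataset built with a DatasetVersionManager (the ids and answer are illustrative):

    # Split "3.17" into version 3 and entry index 17, then look the owning
    # dataset up in the version manager.
    dataset, index, dataset_name = composite.resolve_entry_id("3.17")

    # Convenience wrapper: resolve, regenerate the entry, and score in one call.
    score = composite.score_answer_with_id("42", entry_id="3.17")  # float in [0, 1]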
@@ -1,24 +1,34 @@
 from dataclasses import is_dataclass
-from typing import Type, TypeVar
+from typing import Optional, Type, TypeVar
+
+from reasoning_gym.coaching.base_curriculum import BaseCurriculum, ConfigT
 
 from .dataset import ProceduralDataset
 
 # Type variables for generic type hints
-ConfigT = TypeVar("ConfigT")
 DatasetT = TypeVar("DatasetT", bound=ProceduralDataset)
+CurriculumT = TypeVar("CurriculumT", bound=BaseCurriculum)
 
 # Global registry of datasets
 DATASETS: dict[str, tuple[Type[ProceduralDataset], Type]] = {}
+CURRICULA: dict[str, BaseCurriculum] = {}
 
 
-def register_dataset(name: str, dataset_cls: Type[DatasetT], config_cls: Type[ConfigT]) -> None:
+def register_dataset(
+    name: str,
+    dataset_cls: Type[DatasetT],
+    config_cls: Type[ConfigT],
+    curriculum_cls: Optional[CurriculumT] = None,
+) -> None:
     """
-    Register a dataset class with its configuration class.
+    Register a dataset class with its configuration class and optional curriculum.
 
     Args:
         name: Unique identifier for the dataset
         dataset_cls: Class derived from ProceduralDataset
         config_cls: Configuration dataclass for the dataset
+        curriculum_cls: Optional curriculum class for progressive difficulty
 
     Raises:
         ValueError: If name is already registered or invalid types provided
@@ -34,6 +44,9 @@ def register_dataset(name: str, dataset_cls: Type[DatasetT], config_cls: Type[Co
 
     DATASETS[name] = (dataset_cls, config_cls)
 
+    if curriculum_cls:
+        CURRICULA[name] = curriculum_cls
+
 
 def create_dataset(name: str, **kwargs) -> ProceduralDataset:
     """
@@ -56,3 +69,28 @@ def create_dataset(name: str, **kwargs) -> ProceduralDataset:
     config = config_cls(**kwargs)
 
     return dataset_cls(config=config)
+
+
+def create_curriculum(name: str) -> BaseCurriculum:
+    """
+    Create a curriculum instance for the named dataset.
+
+    Args:
+        name: Registered dataset name
+
+    Returns:
+        Configured curriculum instance
+
+    Raises:
+        ValueError: If dataset not found or has no curriculum registered
+    """
+    if name not in CURRICULA:
+        raise ValueError(f"No curriculum registered for dataset '{name}'")
+
+    curriculum_cls = CURRICULA[name]
+
+    return curriculum_cls()
+
+
+def has_curriculum(name: str) -> bool:
+    return name in CURRICULA
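
To see the extended registry end to end, a hypothetical registration and curriculum lookup could look like this (MyDataset, MyConfig and MyCurriculum are placeholder names, not classes from this commit):

    from reasoning_gym.factory import create_curriculum, has_curriculum, register_dataset

    # Registering with a curriculum class populates both DATASETS and CURRICULA.
    register_dataset("my_dataset", MyDataset, MyConfig, MyCurriculum)

    if has_curriculum("my_dataset"):
        curriculum = create_curriculum("my_dataset")  # instantiates MyCurriculum()
        curriculum.set_attr_level("num_terms", 2)     # pick a difficulty level
        config = curriculum.generate_configuration()  # config consumable by create_dataset()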
@@ -179,4 +179,4 @@ class NQueensCurriculum(BaseCurriculum):
         )
 
 
-register_dataset("n_queens", NQueensDataset, NQueensConfig)
+register_dataset("n_queens", NQueensDataset, NQueensConfig, NQueensCurriculum)
@@ -174,4 +174,4 @@ class CourseScheduleCurriculum(BaseCurriculum):
         )
 
 
-register_dataset("course_schedule", CourseScheduleDataset, CourseScheduleConfig)
+register_dataset("course_schedule", CourseScheduleDataset, CourseScheduleConfig, CourseScheduleCurriculum)
@@ -191,4 +191,4 @@ class LargestIslandCurriculum(BaseCurriculum):
         )
 
 
-register_dataset("largest_island", LargestIslandDataset, LargestIslandConfig)
+register_dataset("largest_island", LargestIslandDataset, LargestIslandConfig, LargestIslandCurriculum)
@@ -193,4 +193,4 @@ class ShortestPathCurriculum(BaseCurriculum):
         )
 
 
-register_dataset("shortest_path", ShortestPathDataset, ShortestPathConfig)
+register_dataset("shortest_path", ShortestPathDataset, ShortestPathConfig, ShortestPathCurriculum)
tests/coaching/test_curriculum_experiment.py (new file, 131 lines)
@@ -0,0 +1,131 @@
import io

import pytest
import yaml

from reasoning_gym.coaching.curriculum_config import CurriculumAttributeConfig, CurriculumExperimentConfig
from reasoning_gym.coaching.experiment import CurriculumExperiment


def test_curriculum_experiment_initialization():
    """Test basic initialization of CurriculumExperiment"""

    # Create config with leg_counting dataset
    config = CurriculumExperimentConfig(
        curricula={"leg_counting": CurriculumAttributeConfig(attribute_levels={"num_animals": 2}, weight=1.0)}
    )

    # Create experiment
    experiment = CurriculumExperiment(name="test_experiment", config=config, size=10, seed=42)

    # Check experiment was created correctly
    assert experiment.name == "test_experiment"
    assert "leg_counting" in experiment.curricula
    assert "leg_counting" in experiment.composite.datasets

    # Check curriculum was configured correctly
    curriculum = experiment.curricula["leg_counting"]
    assert curriculum.get_attr_level("num_animals") == 2

    # Check dataset was created with correct config
    dataset = experiment.composite.datasets["leg_counting"]
    assert dataset.config.min_animals == 1
    assert dataset.config.max_animals == 3

    # Check we can get entries from the dataset
    entry = experiment.get_dataset_entry(0)
    assert "question" in entry
    assert "answer" in entry
    assert "metadata" in entry
    assert entry["metadata"]["source_dataset"] == "leg_counting"


def test_curriculum_experiment_wildcard_level():
    """Test using "*" to set all attribute levels"""

    config = CurriculumExperimentConfig(
        curricula={"leg_counting": CurriculumAttributeConfig(attribute_levels={"*": 3}, weight=1.0)}
    )

    experiment = CurriculumExperiment(name="test_experiment", config=config, size=10, seed=42)

    # Check all attributes were set to level 3
    curriculum = experiment.curricula["leg_counting"]
    for attr_name in curriculum.attributes:
        assert curriculum.get_attr_level(attr_name) == 3


def test_curriculum_experiment_mixed_levels():
    """Test mixing "*" with specific attribute levels"""

    config = CurriculumExperimentConfig(
        curricula={
            "leg_counting": CurriculumAttributeConfig(
                attribute_levels={"*": 2, "num_animals": 4},  # num_animals should override the "*" level
                weight=1.0,
            )
        }
    )

    experiment = CurriculumExperiment(name="test_experiment", config=config, size=10, seed=42)

    curriculum = experiment.curricula["leg_counting"]
    assert curriculum.get_attr_level("num_animals") == 4  # Specific override


def test_curriculum_experiment_from_yaml():
    """Test loading curriculum experiment config from YAML using a string stream"""

    # Create a YAML string
    yaml_content = """
curricula:
  leg_counting:
    attribute_levels:
      "*": 2
      num_animals: 4
    weight: 1.5
  chain_sum:
    attribute_levels:
      num_terms: 1
      num_digits: 2
    weight: 0.8
"""

    # Use StringIO to create a file-like object from the string
    from io import StringIO

    yaml_stream = StringIO(yaml_content)

    # Load config from YAML stream
    config = CurriculumExperimentConfig.from_yaml_stream(yaml_stream)

    # Verify config was loaded correctly
    assert len(config.curricula) == 2
    assert "leg_counting" in config.curricula
    assert "chain_sum" in config.curricula

    # Check leg_counting curriculum
    leg_counting = config.curricula["leg_counting"]
    assert leg_counting.attribute_levels["*"] == 2
    assert leg_counting.attribute_levels["num_animals"] == 4
    assert leg_counting.weight == 1.5

    # Check chain_sum curriculum
    chain_sum = config.curricula["chain_sum"]
    assert chain_sum.attribute_levels["num_terms"] == 1
    assert chain_sum.attribute_levels["num_digits"] == 2
    assert chain_sum.weight == 0.8

    # Create experiment from the loaded config
    experiment = CurriculumExperiment(name="yaml_test", config=config, size=10, seed=42)

    # Verify experiment was created correctly
    assert "leg_counting" in experiment.curricula
    assert "chain_sum" in experiment.curricula

    # Check attribute levels were applied
    leg_curriculum = experiment.curricula["leg_counting"]
    assert leg_curriculum.get_attr_level("num_animals") == 4

    chain_sum_curriculum = experiment.curricula["chain_sum"]
    assert chain_sum_curriculum.get_attr_level("num_terms") == 1
    assert chain_sum_curriculum.get_attr_level("num_digits") == 2
@@ -89,12 +89,12 @@ def create_app(config: ServerConfig) -> FastAPI:
         try:
             entries = []
             for i in range(base_index, base_index + batch_size):
-                entry = experiment.dataset[i]
+                entry = experiment.get_dataset_entry(i)
 
                 # Create BatchEntry with minimal required data
                 batch_entry = BatchEntry(
                     question=entry["question"],
-                    entry_id=f"{entry['metadata']['version_id']}.{i}",
+                    entry_id=f"{entry['metadata']['entry_id']}",
                     metadata=entry["metadata"],
                 )
                 entries.append(batch_entry)
@@ -115,7 +115,7 @@ def create_app(config: ServerConfig) -> FastAPI:
         scores = []
         entry_ids = []
         for item in request.answers:
-            score = experiment.dataset.score_answer_with_id(item.answer, item.entry_id)
+            score = experiment.score_answer_with_id(item.answer, item.entry_id)
             scores.append(score)
             entry_ids.append(item.entry_id)
 
@@ -134,7 +134,7 @@ def create_app(config: ServerConfig) -> FastAPI:
         # Convert internal config to API response format
         datasets = {}
         for ds_spec in experiment.config.datasets:
-            dataset = experiment.dataset.datasets[ds_spec.name]
+            dataset = experiment.composite.datasets[ds_spec.name]
             datasets[ds_spec.name] = {
                 "weight": ds_spec.weight,
                 "config": vars(dataset.config),  # Get current config from dataset instance
@@ -152,7 +152,7 @@ def create_app(config: ServerConfig) -> FastAPI:
             raise HTTPException(status_code=404, detail=f"Experiment '{name}' not found")
 
         try:
-            experiment.dataset.update_dataset_config(dataset_name, config_update.config)
+            experiment.composite.update_dataset_config(dataset_name, config_update.config)
             return {"status": "updated"}
         except KeyError:
             raise HTTPException(status_code=404, detail=f"Dataset '{dataset_name}' not found in experiment")
@@ -35,7 +35,7 @@ def test_experiment_management():
     exp = registry.get_experiment("test_exp")
     assert exp is not None
     assert exp.name == "test_exp"
-    assert isinstance(exp.dataset, CompositeDataset)
+    assert isinstance(exp.composite, CompositeDataset)
     assert exp.config == config
 
     # Test removal