Files
reasoning-gym/reasoning_gym/coaching/curriculum_config.py
Andreas Köpf c69bc5d4e6 Basic curriculum (#198)
* feat: Add optional curriculum support to dataset registration and creation
* docs: Add docstrings to create_curriculum() and register_dataset()
* feat: Add curriculum configuration classes for CurriculumExperiment
* feat: Add weight parameter to CurriculumAttributeConfig and use in DatasetSpec
* refactor: Simplify CurriculumAttributeConfig with "*" attribute level support
* test: Add unit tests for CurriculumExperiment class
* feat: Add from_yaml() method to CurriculumExperimentConfig with unit test
2025-03-07 11:22:12 +01:00

92 lines
2.9 KiB
Python

from dataclasses import dataclass
from typing import Dict, Optional
import yaml
@dataclass
class CurriculumAttributeConfig:
    """Per-dataset curriculum settings: attribute levels plus a sampling weight."""

    # Maps an attribute name to the level it should be set to.
    # The special key "*" means "apply this level to all attributes".
    attribute_levels: Dict[str, int]

    # Weight used when sampling this dataset.
    weight: float = 1.0

    def validate(self):
        """Check that this configuration is usable.

        Raises:
            ValueError: If ``attribute_levels`` is empty (or otherwise falsy).
        """
        if self.attribute_levels:
            return
        raise ValueError("Must specify at least one attribute level")
@dataclass
class CurriculumExperimentConfig:
    """Configuration for curriculum experiments.

    Holds one ``CurriculumAttributeConfig`` per dataset and provides
    constructors that build the configuration from YAML.
    """

    # Dictionary mapping dataset names to their curriculum configurations
    curricula: Dict[str, CurriculumAttributeConfig]

    def validate(self):
        """Validate the configuration.

        Raises:
            ValueError: If ``curricula`` is empty, if any entry is not a
                ``CurriculumAttributeConfig``, or if an entry's own
                ``validate()`` fails.
        """
        if not self.curricula:
            raise ValueError("Must specify at least one curriculum")

        for dataset_name, attr_config in self.curricula.items():
            if not isinstance(attr_config, CurriculumAttributeConfig):
                raise ValueError(f"Invalid attribute config for dataset {dataset_name}")
            attr_config.validate()

    @classmethod
    def from_yaml_stream(cls, stream) -> "CurriculumExperimentConfig":
        """Load configuration from a YAML stream.

        The returned instance is built from the parsed data but ``validate()``
        is not called on it here.

        Args:
            stream: A file-like object containing YAML data

        Returns:
            CurriculumExperimentConfig instance

        Raises:
            ValueError: If YAML data has invalid format (top level or
                'curricula' is not a mapping, 'curricula' is missing, or a
                dataset entry is not a mapping with 'attribute_levels')
        """
        data = yaml.safe_load(stream)

        if not isinstance(data, dict):
            raise ValueError("YAML data must contain a dictionary")

        if "curricula" not in data:
            raise ValueError("YAML data must contain a 'curricula' key")

        # A bare `curricula:` key parses to None, and a list or scalar is just
        # as unusable; report these as a format error instead of letting
        # `.items()` fail with AttributeError.
        raw_curricula = data["curricula"]
        if not isinstance(raw_curricula, dict):
            raise ValueError("YAML 'curricula' entry must be a dictionary")

        # Convert curriculum configs
        curricula = {}
        for dataset_name, config in raw_curricula.items():
            if not isinstance(config, dict):
                raise ValueError(f"Curriculum config for {dataset_name} must be a dictionary")
            if "attribute_levels" not in config:
                raise ValueError(f"Curriculum config for {dataset_name} must contain 'attribute_levels'")

            # The sampling weight is optional in the YAML and defaults to 1.0.
            weight = config.get("weight", 1.0)
            curricula[dataset_name] = CurriculumAttributeConfig(
                attribute_levels=config["attribute_levels"], weight=weight
            )

        return cls(curricula=curricula)

    @classmethod
    def from_yaml(cls, yaml_path: str) -> "CurriculumExperimentConfig":
        """Load configuration from YAML file.

        Args:
            yaml_path: Path to YAML configuration file

        Returns:
            CurriculumExperimentConfig instance

        Raises:
            ValueError: If YAML file has invalid format
        """
        # Decode explicitly as UTF-8 so parsing does not depend on the
        # platform's default locale encoding.
        with open(yaml_path, "r", encoding="utf-8") as f:
            return cls.from_yaml_stream(f)