mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2025-10-09 13:40:09 +03:00
* feat: Add optional curriculum support to dataset registration and creation * docs: Add docstrings to create_curriculum() and register_dataset() * feat: Add curriculum configuration classes for CurriculumExperiment * feat: Add weight parameter to CurriculumAttributeConfig and use in DatasetSpec * refactor: Simplify CurriculumAttributeConfig with "*" attribute level support * test: Add unit tests for CurriculumExperiment class * feat: Add from_yaml() method to CurriculumExperimentConfig with unit test
92 lines
2.9 KiB
Python
92 lines
2.9 KiB
Python
from dataclasses import dataclass
|
|
from typing import Dict, Optional
|
|
|
|
import yaml
|
|
|
|
|
|
@dataclass
|
|
class CurriculumAttributeConfig:
|
|
"""Configuration for curriculum attribute levels"""
|
|
|
|
# Dictionary mapping attribute names to levels
|
|
# Special key "*" means apply that level to all attributes
|
|
attribute_levels: Dict[str, int]
|
|
# Weight for sampling this dataset
|
|
weight: float = 1.0
|
|
|
|
def validate(self):
|
|
"""Validate the configuration"""
|
|
if not self.attribute_levels:
|
|
raise ValueError("Must specify at least one attribute level")
|
|
|
|
|
|
@dataclass
|
|
class CurriculumExperimentConfig:
|
|
"""Configuration for curriculum experiments"""
|
|
|
|
# Dictionary mapping dataset names to their curriculum configurations
|
|
curricula: Dict[str, CurriculumAttributeConfig]
|
|
|
|
def validate(self):
|
|
"""Validate the configuration"""
|
|
if not self.curricula:
|
|
raise ValueError("Must specify at least one curriculum")
|
|
|
|
for dataset_name, attr_config in self.curricula.items():
|
|
if not isinstance(attr_config, CurriculumAttributeConfig):
|
|
raise ValueError(f"Invalid attribute config for dataset {dataset_name}")
|
|
attr_config.validate()
|
|
|
|
@classmethod
|
|
def from_yaml_stream(cls, stream) -> "CurriculumExperimentConfig":
|
|
"""Load configuration from a YAML stream
|
|
|
|
Args:
|
|
stream: A file-like object containing YAML data
|
|
|
|
Returns:
|
|
CurriculumExperimentConfig instance
|
|
|
|
Raises:
|
|
ValueError: If YAML data has invalid format
|
|
"""
|
|
data = yaml.safe_load(stream)
|
|
|
|
if not isinstance(data, dict):
|
|
raise ValueError("YAML data must contain a dictionary")
|
|
|
|
if "curricula" not in data:
|
|
raise ValueError("YAML data must contain a 'curricula' key")
|
|
|
|
# Convert curriculum configs
|
|
curricula = {}
|
|
for dataset_name, config in data["curricula"].items():
|
|
if not isinstance(config, dict):
|
|
raise ValueError(f"Curriculum config for {dataset_name} must be a dictionary")
|
|
|
|
if "attribute_levels" not in config:
|
|
raise ValueError(f"Curriculum config for {dataset_name} must contain 'attribute_levels'")
|
|
|
|
weight = config.get("weight", 1.0)
|
|
curricula[dataset_name] = CurriculumAttributeConfig(
|
|
attribute_levels=config["attribute_levels"], weight=weight
|
|
)
|
|
|
|
return cls(curricula=curricula)
|
|
|
|
@classmethod
|
|
def from_yaml(cls, yaml_path: str) -> "CurriculumExperimentConfig":
|
|
"""Load configuration from YAML file
|
|
|
|
Args:
|
|
yaml_path: Path to YAML configuration file
|
|
|
|
Returns:
|
|
CurriculumExperimentConfig instance
|
|
|
|
Raises:
|
|
ValueError: If YAML file has invalid format
|
|
"""
|
|
with open(yaml_path, "r") as f:
|
|
return cls.from_yaml_stream(f)
|