mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2025-10-09 13:40:09 +03:00
120 lines
3.1 KiB
Python
120 lines
3.1 KiB
Python
from dataclasses import is_dataclass
|
|
from typing import Callable, Optional, Type, TypeVar
|
|
|
|
from reasoning_gym.coaching.base_curriculum import BaseCurriculum, ConfigT
|
|
|
|
from .dataset import ProceduralDataset
|
|
|
|
# Type variables for generic type hints
DatasetT = TypeVar("DatasetT", bound=ProceduralDataset)
CurriculumT = TypeVar("CurriculumT", bound=BaseCurriculum)

# Global registry of datasets: name -> (dataset class, config class)
DATASETS: dict[str, tuple[Type[ProceduralDataset], Type]] = {}
# Global registry of curricula: name -> curriculum class.
# The stored value is called with no arguments by create_curriculum(), so the
# registry holds classes (factories), not curriculum instances.
CURRICULA: dict[str, Type[BaseCurriculum]] = {}
|
|
|
|
|
|
def register_dataset(
    name: str,
    dataset_cls: Type[DatasetT],
    config_cls: Type[ConfigT],
    curriculum_cls: Optional[Type[CurriculumT]] = None,
) -> None:
    """
    Register a dataset class with its configuration class and optional curriculum.

    Args:
        name: Unique identifier for the dataset
        dataset_cls: Class derived from ProceduralDataset
        config_cls: Configuration dataclass (the class itself, not an instance)
        curriculum_cls: Optional curriculum class for progressive difficulty;
            it is stored as-is and later called with no arguments by
            create_curriculum()

    Raises:
        ValueError: If name is already registered or invalid types provided
    """
    if name in DATASETS:
        raise ValueError(f"Dataset '{name}' is already registered")

    # issubclass() raises TypeError when handed a non-class; check isinstance
    # first so every invalid input surfaces as the documented ValueError.
    if not (isinstance(dataset_cls, type) and issubclass(dataset_cls, ProceduralDataset)):
        raise ValueError(f"Dataset class must inherit from ProceduralDataset, got {dataset_cls}")

    # is_dataclass() is also True for dataclass *instances*; an instance could
    # not be called as config_cls(**kwargs) later, so require an actual class.
    if not (isinstance(config_cls, type) and is_dataclass(config_cls)):
        raise ValueError(f"Config class must be a dataclass, got {config_cls}")

    DATASETS[name] = (dataset_cls, config_cls)

    # Only datasets registered with a curriculum get an entry in CURRICULA.
    if curriculum_cls is not None:
        CURRICULA[name] = curriculum_cls
|
|
|
|
|
|
def create_dataset(name: str, **kwargs) -> ProceduralDataset:
    """
    Build a dataset instance for a registered name.

    Args:
        name: Registered dataset name
        **kwargs: Passed unchanged to the registered config class constructor

    Returns:
        Dataset instance constructed with the resulting config

    Raises:
        ValueError: If no dataset is registered under ``name``. Exceptions
            raised while constructing or validating the config propagate
            unchanged.
    """
    if name in DATASETS:
        dataset_type, config_type = DATASETS[name]
    else:
        raise ValueError(f"Dataset '{name}' not registered")

    cfg = config_type(**kwargs)

    # A validate() hook is optional on config objects; run it only when present.
    if hasattr(cfg, "validate"):
        cfg.validate()

    return dataset_type(config=cfg)
|
|
|
|
|
|
def create_curriculum(name: str) -> BaseCurriculum:
    """
    Instantiate the curriculum registered for a dataset name.

    Args:
        name: Registered dataset name

    Returns:
        Result of calling the registered curriculum entry with no arguments

    Raises:
        ValueError: If no curriculum is registered under ``name``
    """
    if name in CURRICULA:
        # The registry stores a callable; invoking it yields the curriculum.
        return CURRICULA[name]()

    raise ValueError(f"No curriculum registered for dataset '{name}'")
|
|
|
|
|
|
def has_curriculum(name: str) -> bool:
    """Return True when a curriculum is registered under ``name``."""
    is_registered = name in CURRICULA
    return is_registered
|
|
|
|
|
|
def get_score_answer_fn(name: str) -> Callable[..., float]:
    """
    Get the score answer function for the named dataset.

    A dataset instance is built from the registered config class called with
    no arguments, and that instance's bound ``score_answer`` method is returned.

    Args:
        name: Registered dataset name

    Returns:
        Bound ``score_answer`` method of the freshly created dataset instance

    Raises:
        ValueError: If dataset not found
    """
    # NOTE(review): the return type is Callable[..., float] rather than a
    # zero-argument callable because score_answer's parameters are defined on
    # the dataset class, outside this file -- confirm the exact signature.
    if name not in DATASETS:
        raise ValueError(f"Dataset '{name}' not registered")

    dataset_cls, config_cls = DATASETS[name]

    # The config class is called without arguments, so it must be
    # constructible from its defaults alone.
    default_dataset = dataset_cls(config=config_cls())
    return default_dataset.score_answer
|