Files
reasoning-gym/reasoning_gym/factory.py
rasdani 72e45e9401 Feat: expose score_answer function without needing to instantiate a dataset (#422)
* feat: get `score_answer` for a given dataset

* fix: `self` error

* add test
2025-04-18 10:36:44 +02:00

120 lines
3.1 KiB
Python

from dataclasses import is_dataclass
from typing import Callable, Optional, Type, TypeVar
from reasoning_gym.coaching.base_curriculum import BaseCurriculum, ConfigT
from .dataset import ProceduralDataset
# Type variables for generic type hints. Each one is bound, so it only
# matches subclasses of the named base class.
DatasetT = TypeVar("DatasetT", bound=ProceduralDataset)
CurriculumT = TypeVar("CurriculumT", bound=BaseCurriculum)
# Global registry of datasets: dataset name -> (dataset class, config class).
# Populated by register_dataset().
DATASETS: dict[str, tuple[Type[ProceduralDataset], Type]] = {}
# Global registry of curricula: dataset name -> curriculum class.
# The stored value is a class (it is called with no arguments to build an
# instance), and an entry exists only for datasets registered with a curriculum.
CURRICULA: dict[str, Type[BaseCurriculum]] = {}
def register_dataset(
    name: str,
    dataset_cls: Type[DatasetT],
    config_cls: Type[ConfigT],
    curriculum_cls: Optional[Type[CurriculumT]] = None,
) -> None:
    """
    Register a dataset class with its configuration class and optional curriculum.

    All arguments are validated before either registry is modified, so a
    rejected registration leaves ``DATASETS`` and ``CURRICULA`` untouched.

    Args:
        name: Unique identifier for the dataset
        dataset_cls: Class derived from ProceduralDataset
        config_cls: Configuration dataclass (the class itself, not an instance)
        curriculum_cls: Optional curriculum class for progressive difficulty;
            it is stored as-is and instantiated later on demand

    Raises:
        ValueError: If name is already registered or invalid types provided
    """
    if name in DATASETS:
        raise ValueError(f"Dataset '{name}' is already registered")

    # issubclass() raises TypeError when its first argument is not a class;
    # check isinstance first so bad input yields the documented ValueError.
    if not (isinstance(dataset_cls, type) and issubclass(dataset_cls, ProceduralDataset)):
        raise ValueError(f"Dataset class must inherit from ProceduralDataset, got {dataset_cls}")

    # is_dataclass() is also True for dataclass *instances*, which could not be
    # called later to build a config, so require an actual class here.
    if not (isinstance(config_cls, type) and is_dataclass(config_cls)):
        raise ValueError(f"Config class must be a dataclass, got {config_cls}")

    DATASETS[name] = (dataset_cls, config_cls)
    if curriculum_cls is not None:
        CURRICULA[name] = curriculum_cls
def create_dataset(name: str, **kwargs) -> ProceduralDataset:
    """
    Build a dataset instance for a registered dataset name.

    The keyword arguments are forwarded to the registered config class. If the
    resulting config object has a ``validate`` attribute, it is called before
    the dataset is constructed.

    Args:
        name: Registered dataset name
        **kwargs: Values passed to the config class constructor

    Returns:
        Dataset instance constructed with the new config

    Raises:
        ValueError: If no dataset is registered under ``name``
    """
    try:
        registry_entry = DATASETS[name]
    except KeyError:
        raise ValueError(f"Dataset '{name}' not registered") from None

    registered_dataset, registered_config = registry_entry
    dataset_config = registered_config(**kwargs)
    # Validation is opt-in: only configs that expose validate() are checked.
    if hasattr(dataset_config, "validate"):
        dataset_config.validate()
    return registered_dataset(config=dataset_config)
def create_curriculum(name: str) -> BaseCurriculum:
    """
    Build a curriculum instance for a registered dataset name.

    The curriculum registered under ``name`` is called with no arguments and
    the result is returned.

    Args:
        name: Registered dataset name

    Returns:
        Newly created curriculum instance

    Raises:
        ValueError: If no curriculum is registered under ``name``
    """
    try:
        curriculum_factory = CURRICULA[name]
    except KeyError:
        raise ValueError(f"No curriculum registered for dataset '{name}'") from None
    return curriculum_factory()
def has_curriculum(name: str) -> bool:
    """
    Report whether a curriculum is registered for the named dataset.

    Args:
        name: Dataset name to look up

    Returns:
        True if ``name`` is a key in the curriculum registry, False otherwise
    """
    is_registered = name in CURRICULA
    return is_registered
def get_score_answer_fn(name: str) -> Callable[..., float]:
    """
    Get the score answer function for the named dataset.

    A dataset is instantiated from the registered config class called with no
    arguments, and that instance's ``score_answer`` attribute is returned.

    Args:
        name: Registered dataset name

    Returns:
        The ``score_answer`` attribute of a default-configured dataset instance.
        Its parameters are not visible from this module, so the return type
        leaves them open.

    Raises:
        ValueError: If dataset not found
    """
    if name not in DATASETS:
        raise ValueError(f"Dataset '{name}' not registered")

    dataset_cls, config_cls = DATASETS[name]
    # The config is built without arguments, so this relies on the config
    # class being constructible from its defaults alone.
    default_dataset = dataset_cls(config=config_cls())
    return default_dataset.score_answer