reasoning-gym/reasoning_gym/composite.py
Andreas Köpf c69bc5d4e6 Basic curriculum (#198)
* feat: Add optional curriculum support to dataset registration and creation
* docs: Add docstrings to create_curriculum() and register_dataset()
* feat: Add curriculum configuration classes for CurriculumExperiment
* feat: Add weight parameter to CurriculumAttributeConfig and use in DatasetSpec
* refactor: Simplify CurriculumAttributeConfig with "*" attribute level support
* test: Add unit tests for CurriculumExperiment class
* feat: Add from_yaml() method to CurriculumExperimentConfig with unit test
2025-03-07 11:22:12 +01:00


from dataclasses import dataclass, replace
from random import Random
from typing import Any, Optional

import yaml

from .dataset import ProceduralDataset
from .factory import create_dataset, register_dataset
from .version_manager import DatasetVersionManager

@dataclass
class DatasetSpec:
    """Specification for a single dataset within the composite"""

    name: str
    weight: float
    config: dict

    def validate(self):
        """Validate dataset specification"""
        assert self.name, "Dataset name cannot be empty"
        assert self.weight > 0, "Weight must be positive"
        assert isinstance(self.config, dict), "Config must be a dictionary"

@dataclass
class CompositeConfig:
    """Configuration for CompositeDataset"""

    size: int = 500
    seed: Optional[int] = None
    datasets: Optional[list[DatasetSpec]] = None

    def validate(self):
        """Validate configuration parameters"""
        assert self.size > 0, "size must be positive"
        assert self.datasets, "Must specify at least one dataset"

        # Check for duplicate dataset names
        dataset_names = [ds.name for ds in self.datasets]
        if len(dataset_names) != len(set(dataset_names)):
            raise ValueError("Duplicate dataset names are not allowed in CompositeDataset")

        # Validate each dataset spec
        for ds in self.datasets:
            ds.validate()

    @classmethod
    def from_yaml_stream(cls, stream) -> "CompositeConfig":
        """Load configuration from a YAML stream

        Args:
            stream: A file-like object containing YAML data

        Returns:
            CompositeConfig instance
        """
        data = yaml.safe_load(stream)
        # Convert dataset specs to DatasetSpec objects
        if "datasets" in data:
            data["datasets"] = [DatasetSpec(**ds) for ds in data["datasets"]]
        return cls(**data)

    @classmethod
    def from_yaml(cls, yaml_path: str) -> "CompositeConfig":
        """Load configuration from YAML file

        Args:
            yaml_path: Path to YAML configuration file

        Returns:
            CompositeConfig instance
        """
        with open(yaml_path, "r") as f:
            return cls.from_yaml_stream(f)
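
# Illustrative sketch of the YAML layout that from_yaml()/from_yaml_stream()
# accept, assuming the top-level keys mirror CompositeConfig's fields; the
# dataset names and per-dataset config values below are hypothetical:
#
#   size: 500
#   seed: 42
#   datasets:
#     - name: chain_sum
#       weight: 2.0
#       config:
#         min_terms: 2
#     - name: leg_counting
#       weight: 1.0
#       config: {}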

class CompositeDataset(ProceduralDataset):
    """A dataset that combines multiple datasets with weighted sampling"""

    def __init__(self, config: CompositeConfig, version_manager: Optional[DatasetVersionManager] = None):
        super().__init__(config=config, seed=config.seed, size=config.size)
        self.version_manager = version_manager
        self.dataset_versions = {}  # dataset_name -> version_id

        # Initialize sub-datasets with incremented seeds
        self.datasets = {}
        self.weights = []
        for i, ds_spec in enumerate(config.datasets):
            # Create dataset with derived seed
            ds_config = ds_spec.config.copy()
            if "seed" not in ds_config:
                ds_config["seed"] = self.seed + i + 1
            if "size" not in ds_config:
                ds_config["size"] = self.size
            if ds_spec.weight < 0:
                raise ValueError(f"Dataset '{ds_spec.name}' has invalid weight {ds_spec.weight}, must be non-negative")
            dataset = create_dataset(ds_spec.name, **ds_config)
            self.datasets[ds_spec.name] = dataset

            # Register version if tracking enabled
            if version_manager is not None:
                version_id = version_manager.register_dataset(ds_spec.name, dataset)
                self.dataset_versions[ds_spec.name] = version_id

            self.weights.append(ds_spec.weight)  # Store unnormalized weights directly

        self.dataset_names = [ds.name for ds in config.datasets]

    def __getitem__(self, idx: int) -> dict:
        """Generate a single dataset item by sampling from sub-datasets"""
        # Create deterministic RNG for this index
        rng = Random(self.seed + idx)

        # Sample dataset according to weights
        dataset_idx = rng.choices(range(len(self.dataset_names)), weights=self.weights, k=1)[0]
        dataset_name = self.dataset_names[dataset_idx]
        dataset = self.datasets[dataset_name]

        # Get item from selected dataset
        item = dataset[idx]

        # Add source dataset info to metadata
        item["metadata"]["source_dataset"] = dataset_name
        item["metadata"]["source_index"] = idx

        # Add version info if tracking enabled
        if self.version_manager is not None:
            version_id = self.dataset_versions[dataset_name]
            item["metadata"]["version_id"] = version_id
            # Add entry_id combining version and index
            item["metadata"]["entry_id"] = f"{version_id}.{idx}"

        return item
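
    # Note on sampling determinism: each index gets its own Random(self.seed + idx),
    # so composite[idx] always resolves to the same sub-dataset for a fixed seed,
    # regardless of access order. With weights [2.0, 1.0], for example, roughly two
    # thirds of indices sample from the first dataset. The composite index idx is
    # reused as the index into the selected sub-dataset, which is why sub-dataset
    # sizes default to the composite size in __init__.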

    def update_dataset_config(self, dataset_name: str, config_updates: dict[str, Any]) -> None:
        """Update configuration of a specific dataset

        Args:
            dataset_name: Name of the dataset to update
            config_updates: Dictionary of configuration parameters to update

        Raises:
            KeyError: If dataset_name is not found
            TypeError: If a config parameter doesn't exist (raised by dataclasses.replace)
        """
        if dataset_name not in self.datasets:
            raise KeyError(f"Dataset '{dataset_name}' not found")

        dataset = self.datasets[dataset_name]

        # Update the current config
        new_config = replace(dataset.config, **config_updates)

        # Validate new config
        new_config.validate()

        # Create new dataset instance with updated config
        dataset_cls = dataset.__class__
        new_dataset = dataset_cls(new_config)
        self.datasets[dataset_name] = new_dataset

        # Register new version if tracking enabled
        if self.version_manager is not None:
            version_id = self.version_manager.register_dataset(dataset_name, new_dataset)
            self.dataset_versions[dataset_name] = version_id

    def update_dataset_weight(self, dataset_name: str, weight: float) -> None:
        """Update weight for a specific dataset in the configuration

        Args:
            dataset_name: Name of the dataset to update
            weight: New weight value

        Raises:
            KeyError: If dataset_name not found
            ValueError: If weight is negative
        """
        if dataset_name not in self.datasets:
            raise KeyError(f"Dataset '{dataset_name}' not found")
        if weight < 0:
            raise ValueError(f"Weight must be non-negative, got {weight}")

        # Update weight in both config and weights list
        for i, ds_spec in enumerate(self.config.datasets):
            if ds_spec.name == dataset_name:
                ds_spec.weight = weight
                self.weights[i] = weight
                break

    def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
        """Forward scoring to the appropriate sub-dataset"""
        dataset_name = entry["metadata"]["source_dataset"]
        return self.datasets[dataset_name].score_answer(answer, entry)

    def add_dataset(self, dataset_spec: DatasetSpec) -> None:
        """Add a new dataset to the composite

        Args:
            dataset_spec: Specification for the dataset to add

        Raises:
            ValueError: If dataset name already exists
        """
        # Validate spec
        dataset_spec.validate()

        # Check for duplicate name
        if dataset_spec.name in self.datasets:
            raise ValueError(f"Dataset '{dataset_spec.name}' already exists in composite")

        # Create dataset with derived seed
        ds_config = dataset_spec.config.copy()
        if "seed" not in ds_config:
            ds_config["seed"] = self.seed + len(self.datasets) + 1
        if "size" not in ds_config:
            ds_config["size"] = self.size

        # Create and add dataset
        dataset = create_dataset(dataset_spec.name, **ds_config)
        self.datasets[dataset_spec.name] = dataset

        # Register version if tracking enabled
        if self.version_manager is not None:
            version_id = self.version_manager.register_dataset(dataset_spec.name, dataset)
            self.dataset_versions[dataset_spec.name] = version_id

        # Add to config and update internal state
        self.config.datasets.append(dataset_spec)
        self.dataset_names.append(dataset_spec.name)
        self.weights.append(dataset_spec.weight)  # Use weight directly from spec

    def remove_dataset(self, dataset_name: str) -> None:
        """Remove a dataset from the composite

        Args:
            dataset_name: Name of the dataset to remove

        Raises:
            KeyError: If dataset not found
            ValueError: If trying to remove last dataset
        """
        if dataset_name not in self.datasets:
            raise KeyError(f"Dataset '{dataset_name}' not found")
        if len(self.datasets) <= 1:
            raise ValueError("Cannot remove last dataset from composite")

        # Remove from all internal structures
        del self.datasets[dataset_name]
        if self.version_manager is not None:
            del self.dataset_versions[dataset_name]

        # Remove from config
        self.config.datasets = [ds for ds in self.config.datasets if ds.name != dataset_name]

        # Update internal state
        idx = self.dataset_names.index(dataset_name)
        self.dataset_names.pop(idx)
        self.weights.pop(idx)

    def resolve_entry_id(self, entry_id: str) -> tuple[ProceduralDataset, int, str]:
        """Resolve an entry_id of the form "version_id.index" to (dataset, index, dataset_name)"""
        if self.version_manager is None:
            raise RuntimeError("Version manager required for scoring with entry_id")

        try:
            version_id, index = map(int, entry_id.split("."))
        except ValueError:
            raise ValueError(f"Invalid entry_id format: {entry_id}, expected 'version_id.index'")

        # Get dataset from version manager
        dataset_info = self.version_manager.get_dataset(version_id)
        if dataset_info is None:
            raise KeyError(f"Version {version_id} not found in version manager")

        dataset_name, dataset = dataset_info
        return dataset, index, dataset_name

    def score_answer_with_id(self, answer: Optional[str], entry_id: str) -> float:
        """Score an answer using an entry_id to look up the original entry

        Args:
            answer: The answer to score
            entry_id: String in format "version_id.index"

        Returns:
            Score between 0 and 1

        Raises:
            ValueError: If entry_id format is invalid
            KeyError: If version not found in version manager
        """
        dataset, index, _ = self.resolve_entry_id(entry_id)

        # Get entry from dataset
        entry = dataset[index]

        # Score answer using dataset's scoring function
        return dataset.score_answer(answer, entry)

# Register the dataset
register_dataset("composite", CompositeDataset, CompositeConfig)
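
# Minimal usage sketch (illustrative only; the dataset names "chain_sum" and
# "leg_counting" and their empty configs are hypothetical placeholders):
#
#   config = CompositeConfig(
#       size=100,
#       seed=42,
#       datasets=[
#           DatasetSpec(name="chain_sum", weight=2.0, config={}),
#           DatasetSpec(name="leg_counting", weight=1.0, config={}),
#       ],
#   )
#   config.validate()
#   composite = CompositeDataset(config)
#   item = composite[0]  # deterministic for a fixed seed and index
#   print(item["metadata"]["source_dataset"])
#   score = composite.score_answer(answer="42", entry=item)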