mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2025-10-09 13:40:09 +03:00
* feat: Add optional curriculum support to dataset registration and creation * docs: Add docstrings to create_curriculum() and register_dataset() * feat: Add curriculum configuration classes for CurriculumExperiment * feat: Add weight parameter to CurriculumAttributeConfig and use in DatasetSpec * refactor: Simplify CurriculumAttributeConfig with "*" attribute level support * test: Add unit tests for CurriculumExperiment class * feat: Add from_yaml() method to CurriculumExperimentConfig with unit test
150 lines
4.3 KiB
Python
150 lines
4.3 KiB
Python
"""Leg counting task generator"""
|
|
|
|
from dataclasses import dataclass
|
|
from random import Random
|
|
from typing import Optional
|
|
|
|
from reasoning_gym.coaching.attributes import AttributeType, RangeAttributeDefinition
|
|
from reasoning_gym.coaching.base_curriculum import BaseCurriculum
|
|
|
|
from ..factory import ProceduralDataset, register_dataset
|
|
|
|
# Maps each animal name to its number of legs.  Entries are grouped by leg
# count and flattened in order, so the key order of the resulting dict is the
# order in which the names are written below.
ANIMALS = {
    name: leg_count
    for leg_count, names in (
        (0, ("snake", "sea slug", "jellyfish", "flatworm", "leech")),
        (2, ("chicken", "bird", "human", "duck")),
        (
            4,
            (
                "dog",
                "cat",
                "cow",
                "horse",
                "lion",
                "elephant",
                "giraffe",
                "tiger",
                "deer",
                "sheep",
            ),
        ),
        (5, ("starfish",)),
        (
            6,
            (
                "insect",
                "ant",
                "butterfly",
                "beetle",
                "bee",
                "wasp",
                "grasshopper",
                "cricket",
                "cockroach",
                "praying mantis",
                "firefly",
            ),
        ),
        (8, ("spider", "scorpion")),
        (10, ("crab", "lobster", "shrimp")),
        (14, ("woodlouse",)),
    )
    for name in names
}
|
|
|
|
# Prompt text for a single item; the ``{animals}`` placeholder is filled with
# the comma-separated list of counted animals.
QUESTION_TEMPLATE = (
    "Your task is to count how many legs there are in total when given a list of animals.\n"
    "\n"
    "Now, how many legs are there in total if you have {animals}?\n"
)
|
|
|
|
|
|
@dataclass
class LegCountingConfig:
    """Settings that control how leg counting problems are generated."""

    # Lower bound on how many animals appear in one problem.
    min_animals: int = 3
    # Upper bound on how many animals appear in one problem.
    max_animals: int = 10
    # Upper bound on the count drawn for each individual animal.
    max_instances: int = 15
    # Seed for the dataset; ``None`` by default.
    seed: Optional[int] = None
    # Virtual dataset size.
    size: int = 500

    def validate(self) -> None:
        """Check that the numeric bounds are consistent.

        Raises:
            AssertionError: if ``min_animals`` or ``max_instances`` is not
                positive, or if ``max_animals`` is smaller than ``min_animals``.
        """
        fewest, most = self.min_animals, self.max_animals
        assert 0 < fewest, "min_animals must be positive"
        assert fewest <= most, "max_animals must be >= min_animals"
        assert 0 < self.max_instances, "max_instances must be positive"
|
|
|
|
|
|
class LegCountingDataset(ProceduralDataset):
    """Generates leg counting arithmetic tasks.

    Every item lists a random selection of animals, each with a count, and
    asks for the total number of legs. Leg counts per animal come from the
    module-level ``ANIMALS`` table.
    """

    # Plurals that the regular suffix rules in ``_pluralize`` would get wrong.
    _IRREGULAR_PLURALS = {
        "deer": "deer",
        "sheep": "sheep",
        "jellyfish": "jellyfish",
        "starfish": "starfish",
        "woodlouse": "woodlice",
    }

    def __init__(self, config: LegCountingConfig):
        """Create the dataset, forwarding ``config``, its seed and its size to the base class."""
        super().__init__(config=config, seed=config.seed, size=config.size)

    @classmethod
    def _pluralize(cls, animal: str) -> str:
        """Return the English plural of an animal name.

        Irregular names are looked up first; otherwise the usual suffix rules
        apply ("leech" -> "leeches", "butterfly" -> "butterflies",
        "dog" -> "dogs").
        """
        irregular = cls._IRREGULAR_PLURALS.get(animal)
        if irregular is not None:
            return irregular
        if animal.endswith(("s", "x", "z", "ch", "sh")):
            return f"{animal}es"
        # Consonant + "y" becomes "ies"; vowel + "y" just takes "s".
        if len(animal) > 1 and animal.endswith("y") and animal[-2] not in "aeiou":
            return f"{animal[:-1]}ies"
        return f"{animal}s"

    def _generate_animals(self, rng: Random) -> dict[str, int]:
        """Generate a random set of animals and their counts.

        Args:
            rng: random generator used for every draw.

        Returns:
            Mapping from animal name to the number of that animal (at least 1,
            at most ``config.max_instances``).
        """
        num_types = rng.randint(self.config.min_animals, self.config.max_animals)
        # ``rng.sample`` raises ValueError when asked for more distinct animals
        # than exist, so cap the request at the table size.  The cap is applied
        # after the draw so the random stream is unchanged for in-range configs.
        num_types = min(num_types, len(ANIMALS))

        selected_animals = rng.sample(list(ANIMALS), num_types)
        return {animal: rng.randint(1, self.config.max_instances) for animal in selected_animals}

    def __getitem__(self, idx: int) -> dict:
        """Generate a single leg counting task.

        Args:
            idx: item index; it is added to ``self.seed`` to seed the item's
                random generator, so the same index yields the same item.

        Returns:
            Dict with the ``question`` text, the ``answer`` (total legs as a
            string) and ``metadata`` holding the difficulty, the drawn animals
            and the numeric total.
        """
        rng = Random(self.seed + idx)

        animals = self._generate_animals(rng)
        total_legs = sum(count * ANIMALS[animal] for animal, count in animals.items())

        # Singular name for a count of one, proper plural otherwise.
        animal_list = [
            f"{count} {self._pluralize(animal) if count > 1 else animal}" for animal, count in animals.items()
        ]

        return {
            "question": QUESTION_TEMPLATE.format(animals=", ".join(animal_list)),
            "answer": str(total_legs),
            "metadata": {
                "difficulty": {
                    "num_animals": len(animals),
                },
                "animals": animals,
                "total_legs": total_legs,
            },
        }
|
|
|
|
|
|
class LegCountingCurriculum(BaseCurriculum):
    """Curriculum for the leg counting task, exposing one ``num_animals`` range attribute."""

    def __init__(self):
        """Register the curriculum under its class name and define its attributes."""
        super().__init__(LegCountingCurriculum.__name__, LegCountingConfig)

        # The attribute's bounds are written to the config's
        # ``min_animals`` / ``max_animals`` fields.
        num_animals = RangeAttributeDefinition(
            name="num_animals",
            levels=list(range(1, 20)),
            # First level; presumably an index into ``levels`` (i.e. 1 animal) - confirm.
            default_level=0,
            description="Number of animals in question",
            attr_type=AttributeType.APPEND,
            min_value=1,  # Ensure at least 1 animal
            lower_field_name="min_animals",
            upper_field_name="max_animals",
        )
        self._define_attributes(num_animals)
|
|
|
|
|
|
# Register the dataset under the name "leg_counting", passing its dataset,
# config and curriculum classes.
register_dataset("leg_counting", LegCountingDataset, LegCountingConfig, LegCountingCurriculum)
|