mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2025-10-09 13:40:09 +03:00
Add probability dataset (initial: Coin Flip dataset + curriculum) (#505)
This commit is contained in:
7
reasoning_gym/probability/__init__.py
Normal file
7
reasoning_gym/probability/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
"""
|
||||||
|
Probability reasoning tasks.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .coin_flip import CoinFlipConfig, CoinFlipCurriculum, CoinFlipDataset
|
||||||
|
|
||||||
|
# Public API of the probability package: the names imported from .coin_flip above.
__all__ = ["CoinFlipDataset", "CoinFlipConfig", "CoinFlipCurriculum"]
|
||||||
169
reasoning_gym/probability/coin_flip.py
Normal file
169
reasoning_gym/probability/coin_flip.py
Normal file
@@ -0,0 +1,169 @@
|
|||||||
|
import math
|
||||||
|
import random
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from fractions import Fraction
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from reasoning_gym.dataset import ProceduralDataset
|
||||||
|
|
||||||
|
from ..coaching import BaseCurriculum, RangeAttributeDefinition
|
||||||
|
from ..factory import register_dataset
|
||||||
|
|
||||||
|
# Identifier for this dataset: passed to register_dataset() at the bottom of the
# module and written into every item's metadata as "source_dataset".
DATASET_NAME = "coin_flip"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class CoinFlipConfig:
    """Settings that control how coin-flip probability problems are generated."""

    # Inclusive bounds on the number of coin flips per problem.
    min_trials: int = 3
    max_trials: int = 15

    # Which question families may be generated; at least one must stay enabled.
    allow_exact: bool = True  # "exactly k heads" problems
    allow_at_least: bool = True  # "at least k heads" problems

    # Base seed for item generation and number of items in the dataset.
    seed: Optional[int] = None
    size: int = 500

    def validate(self) -> None:
        """Check the configuration for consistency.

        Raises:
            AssertionError: if ``size`` or ``min_trials`` is not positive, if
                ``max_trials`` is below ``min_trials``, or if both problem
                types are disabled.
        """
        assert self.size > 0, "size must be positive"
        assert self.min_trials > 0, "min_trials must be positive"
        assert self.max_trials >= self.min_trials, "max_trials must be >= min_trials"

        any_type_enabled = self.allow_exact or self.allow_at_least
        assert any_type_enabled, "At least one of allow_exact or allow_at_least must be True"
|
||||||
|
|
||||||
|
|
||||||
|
class CoinFlipDataset(ProceduralDataset):
    """Generates coin-flip probability problems (exact k heads / at-least k heads).

    Every item asks for the probability of a heads-count event in ``n`` fair
    coin flips. The ground-truth answer is a decimal string, and the exact
    rational value is exposed in the item metadata.
    """

    def __init__(self, config: CoinFlipConfig):
        """Forward the config, seed and size to the base dataset."""
        super().__init__(config=config, seed=config.seed, size=config.size)

    def __getitem__(self, idx: int) -> dict:
        """
        Generate a single N coin flip probability problem.

        Args:
            idx: Index of the item to generate

        Returns:
            dict with keys:
                - question: str, the natural-language probability question
                - answer: str, the probability formatted with 10 significant digits
                - metadata: dict with generation parameters, including the exact
                  probability as ``rational`` (numerator over ``2**num_trials``,
                  not reduced to lowest terms)
        """
        # Create deterministic RNG from base seed and idx
        # (self.seed / self.config are presumably set by ProceduralDataset.__init__).
        rng = random.Random(self.seed + idx)

        # Pick number of trials
        n = rng.randint(self.config.min_trials, self.config.max_trials)

        available_types = []
        if self.config.allow_exact:
            available_types.append("exact")
        if self.config.allow_at_least:
            available_types.append("at_least")

        # NOTE: the order of RNG draws (n, problem type, k) determines the
        # generated items; keep it stable.
        problem_type = rng.choice(available_types)
        k = rng.randint(0, n)

        if problem_type == "exact":
            question = f"What is the probability of getting exactly {k} heads in {n} fair coin flips?"
        else:
            question = f"What is the probability of getting at least {k} heads in {n} fair coin flips?"

        # Derive the decimal answer from the same exact numerator that is stored
        # in the metadata, so the two can never disagree.
        numerator = self._rational_numerator(n, k, problem_type)
        denominator = 2**n
        answer_str = format(numerator / denominator, ".10g")

        return {
            "question": question,
            "answer": answer_str,
            "metadata": {
                "source_dataset": DATASET_NAME,
                "source_index": idx,
                "num_trials": n,
                "k_heads": k,
                "problem_type": problem_type,
                "rational": {
                    "numerator": numerator,
                    "denominator": denominator,
                },
                "difficulty": {
                    "num_trials": (self.config.min_trials, self.config.max_trials),
                },
            },
        }

    def _prob_exact_heads(self, n: int, k: int) -> float:
        """Return probability of exactly k heads in n fair coin tosses."""
        # Integer true division is correctly rounded and works for arbitrarily
        # large n, whereas ``comb * 0.5**n`` underflows or raises OverflowError.
        return math.comb(n, k) / 2**n

    def _prob_at_least_heads(self, n: int, k: int) -> float:
        """Return probability of at least k heads in n fair coin tosses."""
        favourable = sum(math.comb(n, i) for i in range(k, n + 1))
        return favourable / 2**n

    def _rational_numerator(self, n: int, k: int, problem_type: str) -> int:
        """Return the numerator of the probability over the denominator 2**n.

        Any ``problem_type`` other than "exact" is treated as "at least k heads".
        """
        if problem_type == "exact":
            return math.comb(n, k)
        return sum(math.comb(n, i) for i in range(k, n + 1))

    def score_answer(self, answer: Optional[str], entry: dict, tol: float = 1e-4) -> float:
        """
        Compute reward for LLM answer against oracle probability.

        The answer must be parseable by ``fractions.Fraction`` as a whole, so
        decimals ("0.375"), fractions ("3/8") and scientific notation are
        accepted; thousands separators (",") are stripped first. Answers with
        surrounding prose are not parsed and score 0.

        Args:
            answer: Raw answer string, or None.
            entry: Dataset item; only ``entry["answer"]`` is read.
            tol: Maximum absolute error that still earns the full reward.

        Returns:
            1.0 if the answer is within ``tol`` of the oracle value; 0.0 if the
            answer is missing or unparseable; otherwise partial credit equal to
            the length of the common prefix of the two values (formatted with
            10 significant digits) divided by the longer string's length.
        """
        oracle_answer = entry["answer"]

        if answer is None or not answer.strip():
            return 0.0

        try:
            answer_float = float(Fraction(answer.replace(",", "")))
            oracle_answer_float = float(Fraction(oracle_answer.replace(",", "")))
        except (ValueError, ZeroDivisionError, OverflowError):
            # OverflowError: syntactically valid but astronomically large input
            # such as "1e400" cannot be converted to a float.
            return 0.0

        if abs(answer_float - oracle_answer_float) <= tol:
            return 1.0

        answer_str = f"{answer_float:.10g}"
        oracle_answer_str = f"{oracle_answer_float:.10g}"

        # Partial Reward for matching prefix
        match_len = 0
        for a_char, o_char in zip(answer_str, oracle_answer_str):
            if a_char != o_char:
                break
            match_len += 1

        # Normalise by the LONGER string: dividing by the shorter one would give
        # a wrong answer such as "0.55" full credit against an oracle of "0.5".
        return match_len / max(len(oracle_answer_str), len(answer_str))
|
||||||
|
|
||||||
|
|
||||||
|
class CoinFlipCurriculum(BaseCurriculum):
    """Curriculum whose single attribute scales the number of coin tosses."""

    def __init__(self):
        """Register the ``num_trials`` range attribute (levels 3 through 15)."""
        super().__init__(CoinFlipCurriculum.__name__, CoinFlipConfig)

        # One level per toss count, starting from 3 up to 15 tosses.
        toss_levels = list(range(3, 16))
        num_trials_attr = RangeAttributeDefinition(
            name="num_trials",
            levels=toss_levels,
            default_level=0,
            description="Number of coin tosses (difficulty)",
            lower_field_name="min_trials",
            upper_field_name="max_trials",
        )
        self._define_attributes(num_trials_attr)
|
||||||
|
|
||||||
|
|
||||||
|
# Register the dataset, its config class and its curriculum under DATASET_NAME
# (runs at import time).
register_dataset(DATASET_NAME, CoinFlipDataset, CoinFlipConfig, CoinFlipCurriculum)
|
||||||
106
tests/test_coin_flip.py
Normal file
106
tests/test_coin_flip.py
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
from fractions import Fraction
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from reasoning_gym.probability import CoinFlipConfig, CoinFlipCurriculum, CoinFlipDataset
|
||||||
|
|
||||||
|
|
||||||
|
def test_coin_flip_config_validation():
    """Each invalid configuration must be rejected by validate()."""
    invalid_kwargs = (
        {"size": 0},
        {"min_trials": 0},
        {"min_trials": 5, "max_trials": 3},
        {"allow_exact": False, "allow_at_least": False},
    )
    for kwargs in invalid_kwargs:
        with pytest.raises(AssertionError):
            CoinFlipConfig(**kwargs).validate()
|
||||||
|
|
||||||
|
|
||||||
|
def test_coin_flip_deterministic():
    """Two datasets built from the same seeded config yield identical items."""
    shared_config = CoinFlipConfig(size=10, seed=42)
    first = CoinFlipDataset(shared_config)
    second = CoinFlipDataset(shared_config)

    for idx in range(len(first)):
        assert first[idx] == second[idx]
|
||||||
|
|
||||||
|
|
||||||
|
def test_coin_flip_items():
    """Generated items have the expected structure and plausible values."""
    dataset = CoinFlipDataset(CoinFlipConfig(min_trials=3, max_trials=6, size=7, seed=42))

    for idx in range(len(dataset)):
        item = dataset[idx]
        assert isinstance(item, dict)
        for key in ("question", "answer", "metadata"):
            assert key in item

        # The answer is a probability, so it must parse into [0, 1].
        probability = float(item["answer"])
        assert 0.0 <= probability <= 1.0

        meta = item["metadata"]
        for key in ("num_trials", "k_heads", "problem_type"):
            assert key in meta
        assert meta["problem_type"] in ["exact", "at_least"]

        exact_value = meta["rational"]
        assert exact_value["denominator"] == 2 ** meta["num_trials"]
        assert exact_value["numerator"] > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_coin_flip_score_answer():
    """The oracle answer earns full reward; a near miss stays within [0, 1]."""
    dataset = CoinFlipDataset(CoinFlipConfig(size=200, seed=42))

    for idx in range(len(dataset)):
        entry = dataset[idx]
        expected = entry["answer"]

        # Exact answer -> full reward
        assert dataset.score_answer(expected, entry) == 1.0

        # Slightly wrong answer -> partial reward; shift by 0.01 in whichever
        # direction keeps the value at or below 1.0.
        value = float(expected)
        shifted = value + 0.01 if value + 0.01 <= 1.0 else value - 0.01
        near_miss_reward = dataset.score_answer(str(shifted), entry)
        assert 0.0 <= near_miss_reward <= 1.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_coin_flip_curriculum():
    """The curriculum builds valid configs and tracks num_trials level changes."""
    curriculum = CoinFlipCurriculum()
    base_value = {"size": 100, "seed": 32}

    # Default level: both bounds sit at the first level (3 tosses).
    initial = curriculum.generate_configuration(base_value)
    assert isinstance(initial, CoinFlipConfig)
    assert (initial.size, initial.seed) == (100, 32)
    assert initial.min_trials == 3
    assert initial.max_trials == 3

    # Increment attribute level for num_trials: only the upper bound moves.
    curriculum.increment_attr_level("num_trials")
    raised = curriculum.generate_configuration(base_value)
    assert raised.min_trials == 3
    assert raised.max_trials == 4

    # Decrement attribute level: back to the starting range.
    curriculum.decrement_attr_level("num_trials")
    lowered = curriculum.generate_configuration(base_value)
    assert lowered.min_trials == 3
    assert lowered.max_trials == 3
|
||||||
Reference in New Issue
Block a user