Add probability dataset (initial: Coin Flip dataset + curriculum) (#505)

This commit is contained in:
Kumar Anant
2025-09-06 20:29:23 +05:30
committed by GitHub
parent b399c658ca
commit b0815043a2
3 changed files with 282 additions and 0 deletions

View File

@@ -0,0 +1,7 @@
"""
Probability reasoning tasks.
"""
from .coin_flip import CoinFlipConfig, CoinFlipCurriculum, CoinFlipDataset
__all__ = ["CoinFlipDataset", "CoinFlipConfig", "CoinFlipCurriculum"]

View File

@@ -0,0 +1,169 @@
import math
import random
from dataclasses import dataclass
from fractions import Fraction
from typing import Optional
from reasoning_gym.dataset import ProceduralDataset
from ..coaching import BaseCurriculum, RangeAttributeDefinition
from ..factory import register_dataset
DATASET_NAME = "coin_flip"
@dataclass
class CoinFlipConfig:
"""Configuration for coin flip probability task generation."""
min_trials: int = 3
max_trials: int = 15
allow_exact: bool = True # whether to allow "exactly k heads" problems
allow_at_least: bool = True # whether to allow "at least k heads" problems
seed: Optional[int] = None
size: int = 500
def validate(self) -> None:
assert self.size > 0, "size must be positive"
assert self.min_trials > 0, "min_trials must be positive"
assert self.max_trials >= self.min_trials, "max_trials must be >= min_trials"
assert self.allow_exact or self.allow_at_least, "At least one of allow_exact or allow_at_least must be True"
class CoinFlipDataset(ProceduralDataset):
"""Generates coin-flip probability problems (exact k heads / at-least k heads)."""
def __init__(self, config: CoinFlipConfig):
super().__init__(config=config, seed=config.seed, size=config.size)
def __getitem__(self, idx: int) -> dict:
"""
Generate a single N coin flip probability problem.
Args:
idx: Index of the item to generate
Returns:
dict with keys:
- question: str, the formatted arithmetic expression
- answer: str, the ground truth result
- metadata: dict with generation parameters
"""
# Create deterministic RNG from base seed and idx
rng = random.Random(self.seed + idx)
# Pick number of trials
n = rng.randint(self.config.min_trials, self.config.max_trials)
available_types = []
if self.config.allow_exact:
available_types.append("exact")
if self.config.allow_at_least:
available_types.append("at_least")
problem_type = rng.choice(available_types)
if problem_type == "exact":
k = rng.randint(0, n)
question = f"What is the probability of getting exactly {k} heads in {n} fair coin flips?"
prob = self._prob_exact_heads(n, k) # compute actual answer as float
else:
k = rng.randint(0, n)
question = f"What is the probability of getting at least {k} heads in {n} fair coin flips?"
prob = self._prob_at_least_heads(n, k) # compute actual answer as float
answer_str = format(prob, ".10g")
return {
"question": question,
"answer": answer_str,
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"num_trials": n,
"k_heads": k,
"problem_type": problem_type,
"rational": {
"numerator": self._rational_numerator(n, k, problem_type),
"denominator": 2**n,
},
"difficulty": {
"num_trials": (self.config.min_trials, self.config.max_trials),
},
},
}
def _prob_exact_heads(self, n: int, k: int) -> float:
"""Return probability of exactly k heads in n fair coin tosses."""
comb = math.comb(n, k)
return comb * (0.5**n)
def _prob_at_least_heads(self, n: int, k: int) -> float:
"""Return probability of at least k heads in n fair coin tosses."""
total = sum(math.comb(n, i) for i in range(k, n + 1))
return total * (0.5**n)
def _rational_numerator(self, n: int, k: int, problem_type: str) -> int:
"""Return the numerator of the probability as a rational number."""
if problem_type == "exact":
return math.comb(n, k)
else:
return sum(math.comb(n, i) for i in range(k, n + 1))
def score_answer(self, answer: Optional[str], entry: dict, tol: float = 1e-4) -> float:
"""
Compute reward for LLM answer against oracle probability.
Handles decimals, fractions, small numeric errors, and extra text.
"""
reward = 0.0
oracle_answer = entry["answer"]
if answer is None or len(answer.strip()) == 0:
return reward
answer = answer.replace(",", "")
oracle_answer = oracle_answer.replace(",", "")
try:
answer_float = float(Fraction(answer))
oracle_answer_float = float(Fraction(oracle_answer))
except (ValueError, ZeroDivisionError):
return reward
if abs(answer_float - oracle_answer_float) <= tol:
return 1.0
answer_str = f"{answer_float:.10g}"
oracle_answer_str = f"{oracle_answer_float:.10g}"
# Partial Reward for matching prefix
match_len = 0
for a_char, o_char in zip(answer_str, oracle_answer_str):
if a_char == o_char:
match_len += 1
else:
break
reward = match_len / min(len(oracle_answer_str), len(answer_str))
return reward
class CoinFlipCurriculum(BaseCurriculum):
"""Curriculum that allows scaling the number of tosses."""
def __init__(self):
super().__init__(CoinFlipCurriculum.__name__, CoinFlipConfig)
self._define_attributes(
RangeAttributeDefinition(
name="num_trials",
levels=list(range(3, 16)), # starting from 3 upto 15 tosses
default_level=0,
description="Number of coin tosses (difficulty)",
lower_field_name="min_trials",
upper_field_name="max_trials",
),
)
register_dataset(DATASET_NAME, CoinFlipDataset, CoinFlipConfig, CoinFlipCurriculum)

106
tests/test_coin_flip.py Normal file
View File

@@ -0,0 +1,106 @@
from fractions import Fraction
import pytest
from reasoning_gym.probability import CoinFlipConfig, CoinFlipCurriculum, CoinFlipDataset
def test_coin_flip_config_validation():
"""Test that invalid configs raise errors"""
with pytest.raises(AssertionError):
config = CoinFlipConfig(size=0)
config.validate()
with pytest.raises(AssertionError):
config = CoinFlipConfig(min_trials=0)
config.validate()
with pytest.raises(AssertionError):
config = CoinFlipConfig(min_trials=5, max_trials=3)
config.validate()
with pytest.raises(AssertionError):
config = CoinFlipConfig(allow_exact=False, allow_at_least=False)
config.validate()
def test_coin_flip_deterministic():
"""Dataset generates same items with same seed"""
config = CoinFlipConfig(size=10, seed=42)
dataset1 = CoinFlipDataset(config)
dataset2 = CoinFlipDataset(config)
for i in range(len(dataset1)):
assert dataset1[i] == dataset2[i]
def test_coin_flip_items():
"""Test basic properties of generated items"""
config = CoinFlipConfig(min_trials=3, max_trials=6, size=7, seed=42)
dataset = CoinFlipDataset(config)
for i in range(len(dataset)):
item = dataset[i]
assert isinstance(item, dict)
assert "question" in item
assert "answer" in item
assert 0.0 <= float(item["answer"]) <= 1.0
assert "metadata" in item
metadata = item["metadata"]
assert "num_trials" in metadata
assert "k_heads" in metadata
assert "problem_type" in metadata
assert metadata["problem_type"] in ["exact", "at_least"]
rational = metadata["rational"]
assert rational["denominator"] == 2 ** metadata["num_trials"]
assert rational["numerator"] > 0
def test_coin_flip_score_answer():
"""Test full and partial reward behavior"""
config = CoinFlipConfig(size=200, seed=42)
dataset = CoinFlipDataset(config)
for i in range(len(dataset)):
entry = dataset[i]
answer = entry["answer"]
# Exact answer -> full reward
reward = dataset.score_answer(answer, entry)
assert reward == 1.0
# Slightly wrong answer -> partial reward
if float(answer) + 0.01 <= 1.0:
slightly_wrong = str(float(answer) + 0.01)
else:
slightly_wrong = str(float(answer) - 0.01)
reward_partial = dataset.score_answer(slightly_wrong, entry)
assert 0.0 <= reward_partial <= 1.0
def test_coin_flip_curriculum():
"""Test curriculum generates valid configurations and increments attributes"""
curriculum = CoinFlipCurriculum()
base_value = {"size": 100, "seed": 32}
cfg = curriculum.generate_configuration(base_value)
assert isinstance(cfg, CoinFlipConfig)
assert cfg.size == 100
assert cfg.seed == 32
assert cfg.min_trials == 3
assert cfg.max_trials == 3
# Increment attribute level for num_trials
curriculum.increment_attr_level("num_trials")
cfg_inc = curriculum.generate_configuration(base_value)
assert cfg_inc.min_trials == 3
assert cfg_inc.max_trials == 4
# Decrement attribute level
curriculum.decrement_attr_level("num_trials")
cfg_dec = curriculum.generate_configuration(base_value)
assert cfg_dec.min_trials == 3
assert cfg_dec.max_trials == 3