Files
reasoning-gym/reasoning_gym/arithmetic/power_function.py
Zafir Stojanovski dced3bfc45 fix(curriculum): Make boundaries in curriculum more sensible (#407)
* init

* fix tests

* unify codeio

* filtered for libraries not present in reasoning-gym

* fix more bounds

* puzzle24

* knight swap curriculum

* fix number sorting

* fix attributes

* add validation of config in creation of dataset

* dry run for instantiating and validating the datasets

* remove unused imports

* fix curriculum tests to reference newly updated attribute names
2025-04-04 20:24:14 +02:00

105 lines
3.4 KiB
Python

"""Compute the power of a number."""
from dataclasses import dataclass
from decimal import Decimal
from math import pow
from random import Random
from typing import Any, Optional
from ..coaching import BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
# Prompt shown for every item; ``{base}`` and ``{exponent}`` are filled in by
# ``str.format`` when a question is generated.
QUESTION_TEMPLATE = (
    "Your task is to compute an exponentiation of a number.\n"
    "Compute {base}^{exponent}. Return your final answer correct to 3 significant figures.\n"
    "Provide your answer in scientific notation using 'e' notation (e.g., 1.23e+4).\n"
)
# Name this dataset is registered under; also stored in each item's metadata.
DATASET_NAME = "power_function"
@dataclass
class PowerFunctionConfig:
    """Configuration for Power Function dataset generation."""

    min_base: float = -1e3  # Minimum base value
    max_base: float = 1e3  # Maximum base value
    min_exponent: int = 0  # Minimum exponent magnitude (a random sign is applied per item)
    max_exponent: int = 8  # Maximum exponent magnitude (a random sign is applied per item)
    size: int = 500  # Virtual dataset size
    seed: Optional[int] = None  # RNG seed for item generation (None -> chosen by the dataset)

    def validate(self) -> None:
        """Reject configurations that could never generate an item.

        Only ranges that would make every item fail are rejected, so any
        configuration that produced items before still validates.

        Raises:
            ValueError: if ``min_exponent > max_exponent`` (``Random.randint``
                would otherwise fail for every index).
        """
        if self.min_exponent > self.max_exponent:
            raise ValueError(
                f"min_exponent ({self.min_exponent}) must be <= max_exponent ({self.max_exponent})"
            )
class PowerFunctionDataset(ProceduralDataset):
    """Generates Power Function exercises with configurable difficulty.

    Each item asks for ``base ^ exponent``. ``base`` is drawn uniformly from
    ``[min_base, max_base]`` and rounded to 4 decimal places. The exponent
    magnitude is drawn from ``[min_exponent, max_exponent]`` and negated with
    probability 0.5.
    """

    def __init__(self, config: PowerFunctionConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)

    @staticmethod
    def _sig_fig_key(value: Decimal) -> str:
        """Return a canonical 3-significant-figure string for ``value``.

        The ``.2e`` presentation always rounds or pads the coefficient to
        exactly three digits. Numerically equal inputs therefore share a key,
        e.g. ``Decimal("1.0")`` and ``Decimal("1.00e+0")`` both map to
        ``"1.00e+0"``. ``.3g`` cannot be used because, for ``Decimal``, it keeps
        each value's own trailing zeros (``"1.0"`` vs ``"1.00"``). Rounding
        follows the active ``decimal`` context.
        """
        return f"{value:.2e}"

    def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
        """Score the answer by checking if it matches the expected answer to 3 significant figures.

        Args:
            answer: The model's answer, parsed with ``Decimal``
                (e.g. ``"1.23e+4"`` or ``"12300"``).
            entry: A generated item; its ``"answer"`` field holds the oracle value.

        Returns:
            The score, as follows:
            - 1.0 when both values agree after rounding to 3 significant figures.
            - 1.0 when the oracle is zero and the answer is exactly zero.
            - 0.01 for a wrong or unparseable answer.
            - 0.0 when ``answer`` is None.
        """
        oracle_answer = entry["answer"]
        if answer is None:
            return 0.0
        try:
            user_value = Decimal(answer)
            oracle_value = Decimal(oracle_answer)
            if oracle_value == 0:
                # Significant figures are undefined for zero; require an exact zero.
                matched = user_value == 0
            else:
                matched = self._sig_fig_key(user_value) == self._sig_fig_key(oracle_value)
        except Exception:
            # Deliberately best-effort: arbitrary model output that cannot be
            # parsed or compared is scored as a wrong answer, never raised.
            return 0.01
        return 1.0 if matched else 0.01

    def __getitem__(self, idx: int) -> dict:
        """Generate a single Power Function question.

        Items are deterministic: the RNG is seeded with ``self.seed + idx``.
        """
        rng = Random(self.seed + idx)
        base = round(rng.uniform(self.config.min_base, self.config.max_base), 4)
        exponent = rng.randint(self.config.min_exponent, self.config.max_exponent)
        if rng.random() < 0.5:
            exponent = -exponent
        # 0 raised to a negative power is undefined (math.pow raises ValueError),
        # so use the positive exponent for a zero base. This runs after the sign
        # draw, so the RNG stream and every item with a non-zero base are unchanged.
        if base == 0 and exponent < 0:
            exponent = -exponent
        answer = pow(base, exponent)
        return {
            "question": QUESTION_TEMPLATE.format(base=base, exponent=exponent),
            "answer": str(answer),
            "metadata": {
                "source_dataset": DATASET_NAME,
                "source_index": idx,
                "base": base,
                "exponent": exponent,
                "solution": answer,
                "difficulty": {
                    "exponent": (self.config.min_exponent, self.config.max_exponent),
                },
            },
        }
class PowerFunctionCurriculum(BaseCurriculum):
    """Curriculum for the power-function dataset.

    It exposes a single range attribute, ``exponent``. Its level values drive
    the ``min_exponent`` / ``max_exponent`` fields of ``PowerFunctionConfig``.
    """

    def __init__(self):
        super().__init__(PowerFunctionCurriculum.__name__, PowerFunctionConfig)
        # Level values 2, 4, 6, 8, 10.
        exponent_levels = list(range(2, 11, 2))
        exponent_attribute = RangeAttributeDefinition(
            name="exponent",
            levels=exponent_levels,
            lower_field_name="min_exponent",
            upper_field_name="max_exponent",
        )
        self._define_attributes(exponent_attribute)
# Register this dataset under DATASET_NAME together with its config and curriculum classes.
register_dataset(DATASET_NAME, PowerFunctionDataset, PowerFunctionConfig, PowerFunctionCurriculum)