Files
reasoning-gym/reasoning_gym/arithmetic/decimal_arithmetic.py
Oliver Stanley 3df26e0fb2 fix decimal arithmetic curriculum to respect constraints (#430)
* fix decimal arithmetic curriculum to respect constraints

* update test accordingly
2025-05-06 19:23:19 +01:00

264 lines
9.8 KiB
Python

import ast
from dataclasses import dataclass
from decimal import ROUND_HALF_UP, Decimal, getcontext, localcontext
from random import Random
from typing import Any, Optional

from ..coaching import BaseCurriculum, RangeAttributeDefinition, ScalarAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
# Dataset identifier: the key passed to register_dataset below and the
# "source_dataset" value written into every generated item's metadata.
DATASET_NAME = "decimal_arithmetic"
@dataclass
class DecimalArithmeticConfig:
    """Configuration for decimal arithmetic dataset generation.

    Attributes:
        min_num_decimal_places: Fewest decimal places an operand may have (>= 0).
        max_num_decimal_places: Most decimal places an operand may have.
        min_terms: Fewest operands per expression (>= 1).
        max_terms: Most operands per expression.
        precision: Significant digits of the Decimal context used to compute the
            answer; must exceed ``max_num_decimal_places + 1``.
        seed: Base RNG seed, or None.
        size: Dataset size, passed through to ProceduralDataset.
    """

    min_num_decimal_places: int = 3
    max_num_decimal_places: int = 3
    min_terms: int = 2
    max_terms: int = 6
    precision: int = 12
    seed: Optional[int] = None
    size: int = 500

    def validate(self) -> None:
        """Check the configuration for internal consistency.

        Raises:
            AssertionError: if any constraint is violated. ``assert`` is kept so
                callers that expect AssertionError are unaffected; note that
                assertions are skipped when Python runs with ``-O``.
        """
        # Operands can be as large as 10 with ndp decimals (ndp + 2 significant
        # digits), so the context precision must be able to hold them exactly.
        assert (
            self.precision > self.max_num_decimal_places + 1
        ), "precision must be 2 or more higher than max_num_decimal_places"
        assert self.min_num_decimal_places >= 0, "min_num_decimal_places must be non-negative"
        assert (
            self.min_num_decimal_places <= self.max_num_decimal_places
        ), "min_num_decimal_places must not exceed max_num_decimal_places"
        assert self.min_terms >= 1, "min_terms must be at least 1"
        assert self.min_terms <= self.max_terms, "min_terms must not exceed max_terms"
def build_grouped_expression(operands: list[str], operators: list[str], rng: Random) -> str:
    """
    Join operands with operators into a single expression string, grouping at random.

    A pivot is drawn uniformly from 1..len(operands)-1. The operands left and right
    of the pivot are grouped recursively, and ``operators[pivot - 1]`` joins the two
    groups. Every joined group is then wrapped in parentheses with probability 1/2.

    Args:
        operands: Operand strings; must be non-empty (an empty list makes
            ``rng.randint`` raise ValueError).
        operators: Binary operator strings, one fewer than ``operands``.
        rng: Source of randomness for pivots and parenthesisation.

    Returns:
        The assembled expression, e.g. ``"(1.5+2.0)*3.25"``.
    """
    count = len(operands)
    if count == 1:
        return operands[0]

    pivot = rng.randint(1, count - 1)
    lhs = build_grouped_expression(operands[:pivot], operators[: pivot - 1], rng)
    rhs = build_grouped_expression(operands[pivot:], operators[pivot:], rng)
    joined = "".join((lhs, operators[pivot - 1], rhs))

    wrap = rng.choice([True, False])
    if wrap:
        joined = "(" + joined + ")"
    return joined
def generate_arithmetic_problem(
    rng: Random,
    min_num_decimal_places: int,
    max_num_decimal_places: int,
    terms: int = 2,
    operations: Optional[list[str]] = None,
) -> tuple[str, int]:
    """
    Generate a random arithmetic expression over fixed-point decimal operands.

    Each operand is k / 10**ndp for k drawn uniformly from 1..10 * 10**ndp (so a
    value in [10**-ndp, 10]). It is written with exactly ``ndp`` decimal places,
    where ``ndp`` is drawn independently per operand. Operators between
    consecutive operands are drawn from ``operations``, and parentheses are
    inserted at random by ``build_grouped_expression``.

    Parameters:
        rng: Random number generator; all randomness comes from it, so a seeded
            ``rng`` yields a reproducible problem.
        min_num_decimal_places (int): Minimum number of decimal places per operand.
        max_num_decimal_places (int): Maximum number of decimal places per operand.
        terms (int): Number of operands in the arithmetic expression.
        operations (list): Operators to choose from (default: ['+', '-', '*', '/']).

    Returns:
        tuple[str, int]: The expression followed by " = ?", and the largest number
        of decimal places used by any operand.
    """
    if operations is None:
        operations = ["+", "-", "*", "/"]

    max_integer_part = 10  # largest whole-number value an operand can take
    operands: list[str] = []
    operators: list[str] = []
    max_ndp = 0  # true maximum; an all-zero-decimal problem reports 0
    for i in range(terms):
        # Draw order (ndp, value, operator) is kept stable so seeds stay reproducible.
        ndp = rng.randint(min_num_decimal_places, max_num_decimal_places)
        max_ndp = max(max_ndp, ndp)
        raw_int = rng.randint(1, max_integer_part * (10**ndp))

        # Scale down to ndp decimal places; quantize pins the exponent so trailing
        # zeros are kept (e.g. 1 -> 1.000 for ndp=3).
        num = Decimal(raw_int) / (Decimal(10) ** ndp)
        num = num.quantize(Decimal("1." + "0" * ndp), rounding=ROUND_HALF_UP)
        operands.append(f"{num:.{ndp}f}")

        if i < terms - 1:
            operators.append(rng.choice(operations))

    expr = build_grouped_expression(operands, operators, rng)
    return expr + " = ?", max_ndp
def evaluate_expression(expr: str) -> Decimal:
    """
    Evaluate an arithmetic expression string with Decimal semantics.

    The string is parsed with ``ast.parse(..., mode="eval")`` rather than passed to
    ``eval``. The resulting tree is reduced by ``_eval_ast`` under the current
    decimal context. Malformed input raises SyntaxError from ``ast.parse``.

    Args:
        expr: An expression such as ``"(1.25+3.5)*2.0"``, without a trailing ``" = ?"``.

    Returns:
        Decimal: The value of the expression.
    """
    return _eval_ast(ast.parse(expr, mode="eval").body)
def _eval_ast(node: ast.AST) -> Decimal:
"""Recursively evaluate an AST node using Decimal arithmetic."""
if isinstance(node, ast.BinOp):
left: Decimal = _eval_ast(node.left)
right: Decimal = _eval_ast(node.right)
if isinstance(node.op, ast.Add):
return left + right
elif isinstance(node.op, ast.Sub):
return left - right
elif isinstance(node.op, ast.Mult):
return left * right
elif isinstance(node.op, ast.Div):
return left / right
else:
raise ValueError(f"Unsupported operator: {node.op}")
elif isinstance(node, ast.UnaryOp):
operand: Decimal = _eval_ast(node.operand)
if isinstance(node.op, ast.UAdd):
return operand
elif isinstance(node.op, ast.USub):
return -operand
else:
raise ValueError(f"Unsupported unary operator: {node.op}")
elif isinstance(node, ast.Constant): # For Python 3.8+
return Decimal(str(node.value))
elif isinstance(node, ast.Num): # For older Python versions
return Decimal(str(node.n))
else:
raise ValueError(f"Unsupported expression component: {node}")
class DecimalArithmeticDataset(ProceduralDataset):
    """Dataset that generates basic arithmetic tasks using Decimal arithmetic and proper operator precedence."""

    # Upper bound on regeneration attempts when an expression cannot be evaluated,
    # e.g. because it divides by a sub-expression that is exactly zero.
    _MAX_GENERATION_ATTEMPTS = 100

    def __init__(self, config: DecimalArithmeticConfig) -> None:
        super().__init__(config=config, seed=config.seed, size=config.size)

    def __getitem__(self, idx: int) -> dict[str, Any]:
        """
        Generate the arithmetic task at position ``idx``.

        The RNG is seeded with ``self.seed + idx``, so a given index always yields
        the same task. All Decimal work runs inside a local context with
        ``config.precision`` significant digits and ROUND_HALF_UP rounding. The
        caller's decimal context is left unchanged.

        Returns:
            dict: Contains:
                - 'question': The instructions followed by the expression and " = ?".
                - 'answer': The computed result, as a string.
                - 'metadata': Source dataset/index, the largest number of decimal
                  places used, the number of terms, and the configured difficulty.

        Raises:
            RuntimeError: if no evaluable expression is produced within
                ``_MAX_GENERATION_ATTEMPTS`` attempts.
        """
        # Create a deterministic RNG from base seed and index.
        rng: Random = Random(self.seed + idx if self.seed is not None else None)
        terms = rng.randint(self.config.min_terms, self.config.max_terms)

        # A local context avoids mutating the thread-wide decimal context.
        with localcontext() as ctx:
            ctx.prec = self.config.precision
            # The question asks for "rounding up from the half", so the reference
            # answer must be computed with the same rounding mode.
            ctx.rounding = ROUND_HALF_UP
            problem_str, decimal_places, answer = self._generate_evaluable_problem(rng, terms)
            answer_str = str(answer)

        question = (
            f"Please solve this problem to a maximum of {str(self.config.precision)} significant digits, rounding up from the half. Only reply with the final value.\n"
            + problem_str
        )
        return {
            "question": question,
            "answer": answer_str,
            "metadata": {
                "source_dataset": DATASET_NAME,
                "source_index": idx,
                "decimal_places": decimal_places,
                "num_terms": terms,
                "difficulty": {
                    "decimal_places": (self.config.min_num_decimal_places, self.config.max_num_decimal_places),
                    "num_terms": (self.config.min_terms, self.config.max_terms),
                },
            },
        }

    def _generate_evaluable_problem(self, rng: Random, terms: int) -> tuple[str, int, Decimal]:
        """Draw problems from ``rng`` until one evaluates cleanly; return (problem, max decimal places, answer)."""
        last_error: Optional[ArithmeticError] = None
        for _ in range(self._MAX_GENERATION_ATTEMPTS):
            problem_str, decimal_places = generate_arithmetic_problem(
                rng,
                self.config.min_num_decimal_places,
                self.config.max_num_decimal_places,
                terms=terms,
            )
            # Remove the trailing " = ?" to obtain the pure arithmetic expression.
            expr = problem_str.removesuffix(" = ?").strip()
            try:
                answer = evaluate_expression(expr)
            except ArithmeticError as err:
                # decimal.DivisionByZero / InvalidOperation, e.g. x/(a-a). The RNG has
                # advanced, so the next attempt draws a different problem deterministically.
                last_error = err
                continue
            return problem_str, decimal_places, answer
        raise RuntimeError(
            f"could not generate an evaluable expression after {self._MAX_GENERATION_ATTEMPTS} attempts"
        ) from last_error

    def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
        """
        Compare the user's answer with the reference answer as Decimals.

        Exact equality is not required. The answer is accepted when it lies within
        10 ** -max_num_decimal_places of the reference, i.e. one unit in the last
        decimal place an operand may have.

        Returns:
            float: 1.0 if the answer is within tolerance; otherwise 0.0. The result
            is also 0.0 when ``answer`` is not a string or either value cannot be
            parsed or compared (e.g. "abc" or "NaN").
        """
        if not isinstance(answer, str):
            return 0.0
        try:
            user_ans = Decimal(answer)
            correct_ans = Decimal(entry["answer"])
            tol = Decimal(10) ** (-self.config.max_num_decimal_places)
            within_tolerance = abs(user_ans - correct_ans) <= tol
        except (ArithmeticError, KeyError, TypeError, ValueError):
            # decimal.InvalidOperation (an ArithmeticError) covers unparsable
            # strings and ordering comparisons involving NaN.
            return 0.0
        return 1.0 if within_tolerance else 0.0
class DecimalArithmeticCurriculum(BaseCurriculum):
    """Curriculum for Decimal Arithmetic.

    Declares three difficulty attributes over DecimalArithmeticConfig: the operand
    decimal-place range, the Decimal precision, and the range of terms per expression.
    """

    def __init__(self):
        super().__init__(DecimalArithmeticCurriculum.__name__, DecimalArithmeticConfig)

        # Range attribute driving (min_num_decimal_places, max_num_decimal_places).
        decimal_places_attr = RangeAttributeDefinition(
            name="decimal_places",
            levels=[2, 4, 6, 8],
            description="Number of decimal places of the numbers in problem",
            lower_field_name="min_num_decimal_places",
            upper_field_name="max_num_decimal_places",
            ensure_interval=True,
        )
        # NOTE(review): DecimalArithmeticConfig.validate() requires
        # precision > max_num_decimal_places + 1 (decimal_places 8 needs precision >= 10).
        # Confirm that BaseCurriculum never pairs a high decimal_places level with a
        # low precision level.
        precision_attr = ScalarAttributeDefinition(
            name="precision",
            field_name="precision",
            description="Precision of the Decimal arithmetic operations",
            levels=[6, 8, 10, 12],
        )
        # Range attribute driving (min_terms, max_terms).
        num_terms_attr = RangeAttributeDefinition(
            name="num_terms",
            levels=[2, 5, 8, 10],
            description="Number of terms in the arithmetic expression",
            lower_field_name="min_terms",
            upper_field_name="max_terms",
        )

        self._define_attributes(decimal_places_attr, precision_attr, num_terms_attr)
# Register the dataset class, its config class and its curriculum with the
# factory under DATASET_NAME. This runs once, at module import time.
register_dataset(DATASET_NAME, DecimalArithmeticDataset, DecimalArithmeticConfig, DecimalArithmeticCurriculum)