Files
reasoning-gym/reasoning_gym/utils.py
Oliver Stanley 1a727ecf4e support python 3.10 (#450)
* support python 3.10

* add 3.10 to tests

* new StrEnum
2025-06-04 10:34:01 +01:00

156 lines
6.3 KiB
Python

import math
import re
from decimal import Decimal, InvalidOperation
from enum import Enum
from fractions import Fraction
from typing import Any, Optional, Union
SYSTEM_PROMPTS = {
"DeepSeekZero": """A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think>
<answer>answer here</answer>
Do not explain your reasoning inside the answer tags, provide only the final answer. When an example is provided, you should strictly follow the format of the output/answer in that example.
""",
"default": """Given a problem, your task is to answer the question by thinking step-by-step in a clear and specific manner.
Once you have thought about the reasoning process, provide the answer in the following format:
<answer>answer here</answer>
Do not explain your reasoning inside the answer tags, provide only the final answer. When an example is provided, you should strictly follow the format of the output/answer in that example.
""",
"simple": "You are a helpful assistant that answers questions accurately and concisely. When asked to solve a problem, show your work step by step. Provide your final answer between <answer> and </answer> tags.",
"direct": "Answer the question directly. Provide your answer between <answer> and </answer> tags. Do not return any preamble, explanation, or reasoning.",
"chain_of_draft": "Think step by step, but only keep a minimum draft for each thinking step, with 5 words at most. Return the answer at the end of the response enclosed in <answer> </answer> tags.",
}
def extract_answer(completion: str, tag_name: str = "answer", strip: bool = True) -> Optional[str]:
regex = f"<{tag_name}>\\s?(.*?)\\s?</{tag_name}>"
matches = list(
re.finditer(
regex,
completion,
flags=re.DOTALL,
)
)
if not matches:
return None
answer = matches[-1].group(1)
if strip:
answer = answer.strip()
return answer
def format_number(num: Union[int, float], max_decimals: int = 2, round_if_needed: bool = False) -> str:
"""Convert a number to string representation with controlled decimal places.
Args:
num: Number to format
max_decimals: Maximum allowed decimal places
round_if_needed: If True, round the number to max_decimals instead of raising an error
Returns:
String representation of the number
Raises:
ValueError: If number requires more decimal places than allowed and round_if_needed is False
"""
if isinstance(num, int) or num.is_integer():
return str(int(num))
# Convert to Decimal for exact decimal arithmetic
d = Decimal(str(num))
# Find required decimals by removing trailing zeros
str_val = f"{d:f}"
str_val = str_val.rstrip("0").rstrip(".")
if "." in str_val:
required_decimals = len(str_val.split(".")[1])
if required_decimals > max_decimals and not round_if_needed:
raise ValueError(f"Number {num} requires {required_decimals} decimals but only {max_decimals} allowed")
# Format with required decimals (will round if needed)
result = f"{num:.{max_decimals}f}".rstrip("0").rstrip(".")
# Verify result parses back to original value (skip verification if rounding was applied)
if not round_if_needed:
try:
parsed = float(result)
if not math.isclose(parsed, num, rel_tol=1e-9):
raise ValueError(f"String representation {result} does not match original value {num}")
except (ValueError, InvalidOperation) as e:
raise ValueError(f"Failed to verify string representation: {e}")
return result
def is_integer(obj: Any) -> bool:
if isinstance(obj, (int, float)):
return isinstance(obj, int) or obj.is_integer()
elif isinstance(obj, Fraction):
return obj.denominator == 1
return False
def compute_decimal_reward(answer: Optional[str], oracle_answer: str, strip_commas: bool = True) -> float:
"""Compute the reward for a given answer compared to the oracle answer.
Tolerate sign, leading zeros and trailing decimals, optionally strip commas ("+01,000.00" == "1000")
Args:
answer: Answer provided by model
oracle_answer: Correct answer to the question
strip_commas: Whether to remove commas from answers e.g "1,000" = "1000"
Returns:
Reward value between 0.0 and 1.0
"""
reward = 0.0
if answer is not None and len(answer) > 0:
try:
if strip_commas:
answer = answer.replace(",", "")
oracle_answer = oracle_answer.replace(",", "")
if Decimal(answer) == Decimal(oracle_answer):
return 1.0
except:
pass
if oracle_answer in answer:
reward = len(oracle_answer) / len(answer)
return reward
class StrEnum(str, Enum):
"""
Taken from Python 3.11 StrEnum implementation, moved here to support Python 3.10.
Enum where members are also (and must be) strings
"""
def __new__(cls, *values):
"values must already be of type `str`"
if len(values) > 3:
raise TypeError("too many arguments for str(): %r" % (values,))
if len(values) == 1:
# it must be a string
if not isinstance(values[0], str):
raise TypeError("%r is not a string" % (values[0],))
if len(values) >= 2:
# check that encoding argument is a string
if not isinstance(values[1], str):
raise TypeError("encoding must be a string, not %r" % (values[1],))
if len(values) == 3:
# check that errors argument is a string
if not isinstance(values[2], str):
raise TypeError("errors must be a string, not %r" % (values[2]))
value = str(*values)
member = str.__new__(cls, value)
member._value_ = value
return member
@staticmethod
def _generate_next_value_(name, start, count, last_values):
"""
Return the lower-cased version of the member name.
"""
return name.lower()