mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2025-10-09 13:40:09 +03:00
corrected countdown issue (#479)
This commit is contained in:
@@ -3,6 +3,7 @@ from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
import sympy
|
||||
from sympy import Symbol, symbols
|
||||
from sympy.parsing.sympy_parser import parse_expr
|
||||
@@ -23,6 +24,16 @@ Final answer format instructions:
|
||||
|
||||
DATASET_NAME = "countdown"
|
||||
|
||||
# Matches each standalone run of digits (bounded by word boundaries on both
# sides); used by _extract_ints to pull integer literals out of answer text.
_num_re = re.compile(r"\b\d+\b") # pre-compile once, reuse
|
||||
|
||||
|
||||
def _extract_ints(expr_str: str) -> list[int]:
    """Return the integer literals that appear in ``expr_str``, in textual order.

    The scan is purely textual (the expression is not parsed), so repeated
    values are kept: ``"1 + 1 + 81"`` yields ``[1, 1, 81]``.
    """
    matches = _num_re.finditer(expr_str)
    return [int(match.group(0)) for match in matches]
|
||||
|
||||
|
||||
@dataclass
|
||||
class CountdownConfig:
|
||||
@@ -192,12 +203,14 @@ class CountdownDataset(ProceduralDataset):
|
||||
return reward
|
||||
|
||||
try:
|
||||
answer = answer.strip()
|
||||
user_answer = int(parse_expr(answer))
|
||||
used_numbers = [int(num) for num in re.findall(r"\b\d+\b", answer)]
|
||||
target_numbers = set(entry["metadata"]["numbers"])
|
||||
user_answer = float(parse_expr(answer))
|
||||
used_numbers = _extract_ints(answer)
|
||||
target_numbers = entry["metadata"]["numbers"]
|
||||
|
||||
if (user_answer == entry["metadata"]["target"]) and (set(used_numbers) == target_numbers):
|
||||
if sorted(used_numbers) != sorted(target_numbers):
|
||||
return 0.05
|
||||
|
||||
if np.isclose(user_answer, entry["metadata"]["target"], atol=1e-6):
|
||||
return 1.0
|
||||
|
||||
return 0.05 if answer else 0.01
|
||||
|
||||
@@ -100,6 +100,20 @@ def test_answer_without_all_numbers():
|
||||
assert dataset.score_answer(answer=answer, entry=item) == 0.05
|
||||
|
||||
|
||||
def test_edge_cases_1():
    """An answer hitting the target must not get full credit if it omits required numbers."""
    # The expression uses a single 1, while the puzzle supplies three of them.
    entry = {"metadata": {"numbers": [1, 1, 1, 81], "target": 81}}
    dataset = CountdownDataset(CountdownConfig(size=10, seed=42))
    score = dataset.score_answer(answer="1*81", entry=entry)
    assert score != 1.0
|
||||
|
||||
|
||||
def test_edge_cases_2():
    """An answer whose value only truncates to the target must not get full credit."""
    # 6*34/11 - 1 is roughly 17.545, not 17, even though every number is used once.
    entry = {"metadata": {"numbers": [6, 34, 1, 11], "target": 17}}
    dataset = CountdownDataset(CountdownConfig(size=10, seed=42))
    score = dataset.score_answer(answer="6*34/11-1", entry=entry)
    assert score != 1.0
|
||||
|
||||
|
||||
def test_countdown_game_randomization():
|
||||
"""Test number randomization configuration"""
|
||||
config = CountdownConfig(min_numbers=4, max_numbers=4, shuffle=False, size=10, seed=42) # Fixed size for testing
|
||||
|
||||
Reference in New Issue
Block a user