corrected countdown issue (#479)

This commit is contained in:
joesharratt1229
2025-06-25 21:37:04 +01:00
committed by GitHub
parent c2ac6fae32
commit 876e0aa440
2 changed files with 32 additions and 5 deletions

View File

@@ -3,6 +3,7 @@ from dataclasses import dataclass
from random import Random
from typing import Any, Optional
import numpy as np
import sympy
from sympy import Symbol, symbols
from sympy.parsing.sympy_parser import parse_expr
@@ -23,6 +24,16 @@ Final answer format instructions:
DATASET_NAME = "countdown"
_num_re = re.compile(r"\b\d+\b") # pre-compile once, reuse
def _extract_ints(expr_str: str) -> list[int]:
"""
Fast path: grab the literal integers that appear in the source text.
Handles duplicates correctly (e.g. “1 + 1 + 81” ⇒ [1, 1, 81]).
"""
return [int(m) for m in _num_re.findall(expr_str)]
@dataclass
class CountdownConfig:
@@ -192,12 +203,14 @@ class CountdownDataset(ProceduralDataset):
return reward
try:
answer = answer.strip()
user_answer = int(parse_expr(answer))
used_numbers = [int(num) for num in re.findall(r"\b\d+\b", answer)]
target_numbers = set(entry["metadata"]["numbers"])
user_answer = float(parse_expr(answer))
used_numbers = _extract_ints(answer)
target_numbers = entry["metadata"]["numbers"]
if (user_answer == entry["metadata"]["target"]) and (set(used_numbers) == target_numbers):
if sorted(used_numbers) != sorted(target_numbers):
return 0.05
if np.isclose(user_answer, entry["metadata"]["target"], atol=1e-6):
return 1.0
return 0.05 if answer else 0.01

View File

@@ -100,6 +100,20 @@ def test_answer_without_all_numbers():
assert dataset.score_answer(answer=answer, entry=item) == 0.05
def test_edge_cases_1():
dataset = CountdownDataset(CountdownConfig(size=10, seed=42))
answer = "1*81"
item = {"metadata": {"numbers": [1, 1, 1, 81], "target": 81}}
assert dataset.score_answer(answer=answer, entry=item) != 1.0
def test_edge_cases_2():
dataset = CountdownDataset(CountdownConfig(size=10, seed=42))
answer = "6*34/11-1"
item = {"metadata": {"numbers": [6, 34, 1, 11], "target": 17}}
assert dataset.score_answer(answer=answer, entry=item) != 1.0
def test_countdown_game_randomization():
"""Test number randomization configuration"""
config = CountdownConfig(min_numbers=4, max_numbers=4, shuffle=False, size=10, seed=42) # Fixed size for testing