Fix bug in normalize_answer method (#444)

This commit is contained in:
Adefioye
2025-06-02 01:58:54 -05:00
committed by GitHub
parent c0e98f93b4
commit 9053009dbe
2 changed files with 8 additions and 1 deletions

View File

@@ -49,7 +49,10 @@ class PrimeFactorizationDataset(ProceduralDataset):
def _normalize_answer(self, answer: str) -> list[int]:
"""Parse and sort factors from a string"""
return sorted([int(factor.strip()) for factor in answer.split("×")])
if not answer or answer.strip() == "":
return []
return sorted([int(factor.strip()) for factor in answer.split("×") if factor.strip() != ""])
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
oracle_answer = entry["answer"]

View File

@@ -119,6 +119,10 @@ def test_prime_factorization_score_answer():
answer = None
assert dataset.score_answer(answer, item) == 0.0
# Answer is empty string
answer = ""
assert dataset.score_answer(answer, item) == 0.01
def is_prime(n: int) -> bool:
"""Helper function to check if a number is prime"""