Files
reasoning-gym/tests/test_prime_factorization.py
2025-06-02 08:58:54 +02:00

156 lines
4.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Tests for prime factorization task generation"""
import pytest
from reasoning_gym.arithmetic.prime_factorization import (
PrimeFactorizationConfig,
PrimeFactorizationCurriculum,
PrimeFactorizationDataset,
)
def test_prime_factorization_config_validation():
    """Invalid configurations must be rejected with an AssertionError."""
    rejected_kwargs = (
        {"min_value": 1},  # Too small
        {"min_value": 100, "max_value": 50},  # max < min
    )
    for kwargs in rejected_kwargs:
        # Construction stays inside the context manager, so an assertion
        # raised by either the constructor or validate() is accepted.
        with pytest.raises(AssertionError):
            PrimeFactorizationConfig(**kwargs).validate()
def test_prime_factorization_dataset_deterministic():
    """Two datasets built from one seeded config must agree item by item."""
    shared_config = PrimeFactorizationConfig(seed=42, size=10)
    first = PrimeFactorizationDataset(shared_config)
    second = PrimeFactorizationDataset(shared_config)

    # Indexed access is deliberate: the comparison goes through item lookup
    # on both datasets rather than through iteration.
    for index in range(len(first)):
        assert first[index] == second[index]
def test_prime_factorization_dataset_items():
    """Check structure, value range, and factorization of generated items."""
    cfg = PrimeFactorizationConfig(min_value=2, max_value=100, size=10, seed=42)
    ds = PrimeFactorizationDataset(cfg)

    for idx in range(len(ds)):
        entry = ds[idx]

        # Each item is a dict carrying the three top-level keys.
        assert isinstance(entry, dict)
        for key in ("question", "answer", "metadata"):
            assert key in entry

        # Metadata exposes the generated number and its factor list.
        meta = entry["metadata"]
        for key in ("number", "factors"):
            assert key in meta

        # The number stays within the configured bounds.
        value = meta["number"]
        assert cfg.min_value <= value <= cfg.max_value

        # Multiplying the factors back together reproduces the number.
        primes = meta["factors"]
        running = 1
        for p in primes:
            running *= p
        assert running == value

        # Every listed factor is itself prime.
        for p in primes:
            assert is_prime(p), f"{p} is not prime"

        # The answer is the factor list joined by " × ".
        assert entry["answer"] == " × ".join(map(str, primes))
def test_prime_factorization_dataset_iteration():
    """Iterating yields exactly ``config.size`` items, repeatably."""
    cfg = PrimeFactorizationConfig(size=5, seed=42)
    ds = PrimeFactorizationDataset(cfg)

    first_pass = list(ds)
    assert len(first_pass) == cfg.size

    # A second pass over the same dataset must reproduce the first.
    second_pass = list(ds)
    assert first_pass == second_pass
def test_prime_factorization_known_values():
    """Pin the factorization of 12 by collapsing the value range onto it."""
    # min_value == max_value forces the specific number 12.
    cfg = PrimeFactorizationConfig(min_value=12, max_value=12, size=1, seed=42)
    entry = PrimeFactorizationDataset(cfg)[0]

    meta = entry["metadata"]
    assert meta["number"] == 12
    assert meta["factors"] == [2, 2, 3]
    assert entry["answer"] == "2 × 2 × 3"
def test_prime_factorization_score_answer():
    """Score a range of candidate answers against the item for 12."""
    # min_value == max_value forces the specific number 12.
    cfg = PrimeFactorizationConfig(min_value=12, max_value=12, size=1, seed=42)
    ds = PrimeFactorizationDataset(cfg)
    entry = ds[0]

    # (candidate answer, expected score), evaluated in this order.
    cases = [
        ("2 × 2 × 3", 1.0),  # perfectly ordered answer
        ("2×2×3", 1.0),  # no white space (still correct)
        ("2 × 3 × 2", 1.0),  # shuffled factors (still correct)
        ("2 × 6", 0.5),  # partially correct: not fully factorized
        ("2 × 5", 0.01),  # incorrect answer
        (None, 0.0),  # answer is None
        ("", 0.01),  # answer is an empty string
    ]
    for candidate, expected in cases:
        assert ds.score_answer(candidate, entry) == expected
def is_prime(n: int) -> bool:
    """Return True if ``n`` is prime, using trial division.

    Args:
        n: Integer to test. Values below 2 are never prime.

    Returns:
        True when no integer in ``[2, isqrt(n)]`` divides ``n``.
    """
    # Imported locally so this helper does not depend on a module-level import.
    from math import isqrt

    if n < 2:
        return False
    # isqrt is exact for arbitrarily large ints, whereas int(n ** 0.5) goes
    # through a float: it can be off by one past 2**53 and raises
    # OverflowError once n no longer fits in a double.
    return all(n % divisor for divisor in range(2, isqrt(n) + 1))
def test_prime_factorization_curriculum():
    """Walk the curriculum's "value" level up one step and back down."""
    curriculum = PrimeFactorizationCurriculum()
    overrides = {"size": 150, "seed": 1}

    # Starting level: the overrides are carried into the config.
    initial: PrimeFactorizationConfig = curriculum.generate_configuration(overrides)
    assert initial.seed == 1
    assert initial.size == 150
    assert (initial.min_value, initial.max_value) == (10, 1_000)

    # One level up: only the upper bound is expected to change.
    curriculum.increment_attr_level("value")
    raised = curriculum.generate_configuration(overrides)
    assert (raised.min_value, raised.max_value) == (10, 5_000)

    # One level back down: the starting range is expected again.
    curriculum.decrement_attr_level("value")
    restored = curriculum.generate_configuration(overrides)
    assert (restored.min_value, restored.max_value) == (10, 1_000)