mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2025-10-09 13:40:09 +03:00
add kakurasu env (#460)
* add kakurasu env * add kakurasu curriculum * add kakurasu tests
This commit is contained in:
@@ -10,6 +10,7 @@ from .boxnet import BoxnetConfig, BoxnetCurriculum, BoxnetDataset
|
||||
from .countdown import CountdownConfig, CountdownCurriculum, CountdownDataset
|
||||
from .emoji_mystery import EmojiMysteryConfig, EmojiMysteryCurriculum, EmojiMysteryDataset
|
||||
from .futoshiki import FutoshikiConfig, FutoshikiCurriculum, FutoshikiDataset
|
||||
from .kakurasu import KakurasuConfig, KakurasuCurriculum, KakurasuDataset
|
||||
from .knight_swap import KnightSwapConfig, KnightSwapCurriculum, KnightSwapDataset
|
||||
from .mahjong import MahjongPuzzleConfig, MahjongPuzzleCurriculum, MahjongPuzzleDataset
|
||||
from .maze import MazeConfig, MazeCurriculum, MazeDataset
|
||||
@@ -35,6 +36,9 @@ __all__ = [
|
||||
"FutoshikiConfig",
|
||||
"FutoshikiCurriculum",
|
||||
"FutoshikiDataset",
|
||||
"KakurasuConfig",
|
||||
"KakurasuCurriculum",
|
||||
"KakurasuDataset",
|
||||
"MiniSudokuConfig",
|
||||
"MiniSudokuDataset",
|
||||
"MiniSudokuCurriculum",
|
||||
|
||||
253
reasoning_gym/games/kakurasu.py
Normal file
253
reasoning_gym/games/kakurasu.py
Normal file
@@ -0,0 +1,253 @@
|
||||
"""
|
||||
Kakurasu puzzle dataset, adapted for Reasoning Gym from the SynLogic repository: https://github.com/MiniMax-AI/SynLogic/tree/main/games/tasks/kukurasu
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import Any, Optional
|
||||
|
||||
from ..coaching import BaseCurriculum, RangeAttributeDefinition, ScalarAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
DATASET_NAME = "kakurasu"
|
||||
|
||||
PROMPT_TEMPLATES = [
|
||||
"You are given a {n_rows} x {n_cols} grid representing a Kukurasu puzzle. In this puzzle, you need to place 1s in the grid so that the weighted sum of each row and column matches the given constraints. The row sums are {row_sums} and the column sums are {col_sums}.\n1. Rules:\n 1. Each cell can contain either a 1 or an 0.\n 2. The weight of a 1 in a row is its column position (1 to {n_cols}).\n 3. The weight of a 1 in a column is its row position (1 to {n_rows}).\n 4. The weighted sum of each row must match the corresponding row constraint.\n 5. The weighted sum of each column must match the corresponding column constraint.\n2. Input:\n{puzzle}",
|
||||
"This is a {n_rows} x {n_cols} Kukurasu puzzle grid. Your task is to fill in the grid with 1s and 0s such that the weighted sums match the given constraints. The row sums are {row_sums} and the column sums are {col_sums}.\n1. Rules:\n 1. Each cell must contain either a 1 or an 0.\n 2. In each row, a 1 in position j contributes j points to that row's sum (positions are 1-indexed).\n 3. In each column, a 1 in position i contributes i points to that column's sum (positions are 1-indexed).\n 4. The weighted sum of each row must equal its constraint value.\n 5. The weighted sum of each column must equal its constraint value.\n2. Input:\n{puzzle}",
|
||||
"You're presented with a {n_rows} x {n_cols} Kukurasu puzzle grid. The goal is to place 1s in the grid so that the weighted sums of rows and columns match the given constraints: row sums {row_sums} and column sums {col_sums}.\n1. Rules:\n 1. Each cell must be filled with either a 1 or an 0.\n 2. A 1 in column j of any row contributes j points to that row's sum (j ranges from 1 to {n_cols}).\n 3. A 1 in row i of any column contributes i points to that column's sum (i ranges from 1 to {n_rows}).\n 4. Each row's weighted sum must match its constraint value.\n 5. Each column's weighted sum must match its constraint value.\n2. Input:\n{puzzle}",
|
||||
"Below is a {n_rows} x {n_cols} Kukurasu puzzle grid. Your objective is to place 1s in the grid such that the weighted sums of rows and columns match the given constraints. Row sums: {row_sums}. Column sums: {col_sums}.\n1. Rules:\n 1. Each cell must contain either a 1 or an 0.\n 2. The weight of a 1 in a row equals its column number (1 to {n_cols}).\n 3. The weight of a 1 in a column equals its row number (1 to {n_rows}).\n 4. The sum of weighted 1s in each row must equal the row constraint.\n 5. The sum of weighted 1s in each column must equal the column constraint.\n2. Input:\n{puzzle}",
|
||||
"Here's a {n_rows} x {n_cols} Kukurasu logic puzzle. You need to place 1s in the grid so that the weighted sums match the constraints. Row sums: {row_sums}. Column sums: {col_sums}.\n1. Rules:\n 1. Each cell can be filled with either a 1 or an 0.\n 2. A 1 in the jth position of a row contributes j points to that row's sum.\n 3. A 1 in the ith position of a column contributes i points to that column's sum.\n 4. The weighted sum of each row must equal its constraint value.\n 5. The weighted sum of each column must equal its constraint value.\n2. Input:\n{puzzle}",
|
||||
"I'm presenting you with a {n_rows} x {n_cols} Kukurasu puzzle. Your task is to place 1s in the grid so that the weighted sums match the given constraints: row sums {row_sums} and column sums {col_sums}.\n1. Rules:\n 1. Each cell must be filled with either a 1 or an 0.\n 2. In each row, a 1 in position j has a weight of j (where j ranges from 1 to {n_cols}).\n 3. In each column, a 1 in position i has a weight of i (where i ranges from 1 to {n_rows}).\n 4. The weighted sum of each row must match its constraint.\n 5. The weighted sum of each column must match its constraint.\n2. Input:\n{puzzle}",
|
||||
"Consider this {n_rows} x {n_cols} Kukurasu puzzle grid. You need to place 1s in the grid such that the weighted sums match the constraints. Row sums: {row_sums}. Column sums: {col_sums}.\n1. Rules:\n 1. Each cell must contain either a 1 or an 0.\n 2. A 1 in column position j contributes j points to its row's sum.\n 3. A 1 in row position i contributes i points to its column's sum.\n 4. Each row's weighted sum must equal its constraint value.\n 5. Each column's weighted sum must equal its constraint value.\n2. Input:\n{puzzle}",
|
||||
"You have a {n_rows} x {n_cols} Kukurasu puzzle grid. Your goal is to place 1s in the grid so that the weighted sums match the given constraints: row sums {row_sums} and column sums {col_sums}.\n1. Rules:\n 1. Each cell must be filled with either a 1 or an 0.\n 2. The weight of a 1 in a row is its column position (1 to {n_cols}).\n 3. The weight of a 1 in a column is its row position (1 to {n_rows}).\n 4. The weighted sum of each row must match its constraint.\n 5. The weighted sum of each column must match its constraint.\n2. Input:\n{puzzle}",
|
||||
"This {n_rows} x {n_cols} grid represents a Kukurasu puzzle. Your task is to place 1s in the grid so that the weighted sums match the constraints. Row sums: {row_sums}. Column sums: {col_sums}.\n1. Rules:\n 1. Each cell must contain either a 1 or an 0.\n 2. A 1 in the jth position of a row contributes j points to that row's sum.\n 3. A 1 in the ith position of a column contributes i points to that column's sum.\n 4. The weighted sum of each row must equal its constraint value.\n 5. The weighted sum of each column must equal its constraint value.\n2. Input:\n{puzzle}",
|
||||
"Examine this {n_rows} x {n_cols} Kukurasu puzzle grid. Your objective is to place 1s in the grid such that the weighted sums match the given constraints: row sums {row_sums} and column sums {col_sums}.\n1. Rules:\n 1. Each cell must be filled with either a 1 or an 0.\n 2. The weight of a 1 in a row equals its column number (1 to {n_cols}).\n 3. The weight of a 1 in a column equals its row number (1 to {n_rows}).\n 4. The weighted sum of each row must match its constraint.\n 5. The weighted sum of each column must match its constraint.\n2. Input:\n{puzzle}",
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class KakurasuConfig:
|
||||
min_rows: int = 4
|
||||
max_rows: int = 5
|
||||
min_cols: int = 4
|
||||
max_cols: int = 5
|
||||
p_ones: float = 0.3
|
||||
seed: Optional[int] = None
|
||||
size: int = 500 # Virtual dataset size
|
||||
max_retries: int = 1000 # Max retries to find a unique puzzle. If exceeded, a non-unique puzzle may be returned
|
||||
|
||||
def validate(self):
|
||||
"""Validate configuration parameters"""
|
||||
assert 3 <= self.min_rows <= 9, "n_rows must be between 3 and 9"
|
||||
assert 3 <= self.max_rows <= 9, "n_cols must be between 3 and 9"
|
||||
assert 3 <= self.min_rows <= 9, "n_rows must be between 3 and 9"
|
||||
assert 3 <= self.max_cols <= 9, "n_cols must be between 3 and 9"
|
||||
assert self.min_rows <= self.max_rows, "min_rows must be less than or equal to max_rows"
|
||||
assert self.min_cols <= self.max_cols, "min_cols must be less than or equal to max_cols"
|
||||
assert 0 <= self.p_ones <= 1, "p_ones must be between 0 and 1"
|
||||
|
||||
|
||||
class KakurasuDataset(ProceduralDataset):
|
||||
"""Generates Kakurasu puzzles with configurable size."""
|
||||
|
||||
def __init__(self, config: KakurasuConfig):
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.config.size
|
||||
|
||||
def __iter__(self):
|
||||
self._current_idx = 0
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
if self._current_idx >= self.config.size:
|
||||
raise StopIteration
|
||||
item = self[self._current_idx]
|
||||
self._current_idx += 1
|
||||
return item
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
"""Generate Kakurasu puzzles that have at least one solution."""
|
||||
rng = Random(self.seed + idx)
|
||||
|
||||
n_rows = rng.randint(self.config.min_rows, self.config.max_rows)
|
||||
n_cols = rng.randint(self.config.min_cols, self.config.max_cols)
|
||||
|
||||
for retry in range(self.config.max_retries):
|
||||
solution_grid = self._generate_random_grid(rng, n_rows, n_cols)
|
||||
self._repair_grid(rng, solution_grid)
|
||||
|
||||
row_sums, col_sums = self._calculate_row_col_sums(solution_grid)
|
||||
empty_grid = [[0 for _ in range(n_cols)] for _ in range(n_rows)]
|
||||
|
||||
if retry < self.config.max_retries - 1:
|
||||
if 0 in row_sums or 0 in col_sums or sum(row_sums) != sum(col_sums):
|
||||
continue
|
||||
if self._count_solutions(n_rows, n_cols, row_sums, col_sums) != 1:
|
||||
continue
|
||||
|
||||
prompt = rng.choice(PROMPT_TEMPLATES).format(
|
||||
n_rows=n_rows,
|
||||
n_cols=n_cols,
|
||||
row_sums=row_sums,
|
||||
col_sums=col_sums,
|
||||
puzzle="\n".join(
|
||||
[" ".join(str(cell) for cell in row) for row in empty_grid],
|
||||
),
|
||||
)
|
||||
|
||||
return {
|
||||
"question": prompt,
|
||||
"answer": "\n".join(
|
||||
[" ".join(str(cell) for cell in row) for row in solution_grid],
|
||||
),
|
||||
"metadata": {
|
||||
"source_dataset": DATASET_NAME,
|
||||
"source_idx": idx,
|
||||
"n_rows": n_rows,
|
||||
"n_cols": n_cols,
|
||||
"p_ones": self.config.p_ones,
|
||||
"puzzle": empty_grid,
|
||||
"row_sums": row_sums,
|
||||
"col_sums": col_sums,
|
||||
"solution": solution_grid,
|
||||
"difficulty": {
|
||||
"rows": (self.config.min_rows, self.config.max_rows),
|
||||
"cols": (self.config.min_cols, self.config.max_cols),
|
||||
"p_ones": self.config.p_ones,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def _generate_random_grid(self, rng: Random, n_rows: int, n_cols: int) -> list[list[int]]:
|
||||
"""Generate a random valid solution grid."""
|
||||
return [[1 if rng.random() < self.config.p_ones else 0 for _ in range(n_cols)] for _ in range(n_rows)]
|
||||
|
||||
def _calculate_row_col_sums(self, grid) -> tuple[list[int], list[int]]:
|
||||
"""Calculate row and column sums based on the solution grid"""
|
||||
n_rows = len(grid)
|
||||
n_cols = len(grid[0]) if n_rows > 0 else 0
|
||||
row_sums = [sum((j + 1) for j, cell in enumerate(row) if cell == 1) for row in grid]
|
||||
col_sums = [sum((i + 1) for i in range(n_rows) if grid[i][j] == 1) for j in range(n_cols)]
|
||||
return row_sums, col_sums
|
||||
|
||||
def _repair_grid(self, rng: Random, grid: list[list[int]]):
|
||||
"""Ensure every row/col has at least one '1'."""
|
||||
n_rows = len(grid)
|
||||
n_cols = len(grid[0]) if n_rows > 0 else 0
|
||||
|
||||
for i, row in enumerate(grid):
|
||||
if 1 not in row:
|
||||
grid[i][rng.randrange(n_cols)] = 1
|
||||
|
||||
for j in range(n_cols):
|
||||
if all(grid[i][j] == 0 for i in range(n_rows)):
|
||||
grid[rng.randrange(n_rows)][j] = 1
|
||||
|
||||
def _count_solutions(self, n: int, m: int, row_sums: list[int], col_sums: list[int], limit: int = 2) -> int:
|
||||
"""Return number of solutions, stopping at `limit`."""
|
||||
row_patterns: list[list[list[int]]] = []
|
||||
for target in row_sums:
|
||||
patterns = []
|
||||
for mask in range(1 << m):
|
||||
if sum(((j + 1) if (mask >> j) & 1 else 0) for j in range(m)) == target:
|
||||
patterns.append([(mask >> j) & 1 for j in range(m)])
|
||||
if not patterns:
|
||||
return 0
|
||||
row_patterns.append(patterns)
|
||||
|
||||
col_remaining = col_sums[:]
|
||||
solutions = 0
|
||||
|
||||
def dfs(r: int):
|
||||
nonlocal solutions
|
||||
if solutions >= limit:
|
||||
return
|
||||
if r == n:
|
||||
if all(c == 0 for c in col_remaining):
|
||||
solutions += 1
|
||||
return
|
||||
for pat in row_patterns[r]:
|
||||
ok = True
|
||||
for j, bit in enumerate(pat):
|
||||
col_remaining[j] -= (r + 1) * bit
|
||||
if col_remaining[j] < 0:
|
||||
ok = False
|
||||
if ok:
|
||||
dfs(r + 1)
|
||||
for j, bit in enumerate(pat):
|
||||
col_remaining[j] += (r + 1) * bit
|
||||
if solutions >= limit:
|
||||
break
|
||||
|
||||
dfs(0)
|
||||
return solutions
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
if not isinstance(answer, str):
|
||||
return 0.0
|
||||
|
||||
metadata = entry["metadata"]
|
||||
row_sums, col_sums = metadata["row_sums"], metadata["col_sums"]
|
||||
n_rows, n_cols = metadata["n_rows"], metadata["n_cols"]
|
||||
|
||||
try:
|
||||
grid = self._parse_grid(answer)
|
||||
|
||||
if len(grid) != n_rows or any(len(row) != n_cols for row in grid):
|
||||
return 0.0
|
||||
|
||||
if any(cell not in [1, 0] for row in grid for cell in row):
|
||||
return 0.0
|
||||
|
||||
ans_row_sums = [sum((j + 1) for j, cell in enumerate(row) if cell == 1) for row in grid]
|
||||
|
||||
if ans_row_sums != row_sums:
|
||||
return 0.0
|
||||
|
||||
ans_col_sums = [sum((i + 1) for i in range(n_rows) if grid[i][j] == 1) for j in range(n_cols)]
|
||||
|
||||
if ans_col_sums != col_sums:
|
||||
return 0.0
|
||||
|
||||
return 1.0
|
||||
except Exception:
|
||||
return 0.0
|
||||
|
||||
def _parse_grid(self, answer: str) -> list[list[str]]:
|
||||
grid = []
|
||||
for line in answer.strip().split("\n"):
|
||||
grid.append([int(c) for c in line.strip() if c in "01"])
|
||||
return grid
|
||||
|
||||
|
||||
class KakurasuCurriculum(BaseCurriculum):
|
||||
def __init__(self):
|
||||
super().__init__(KakurasuCurriculum.__name__, KakurasuConfig)
|
||||
|
||||
self._define_attributes(
|
||||
RangeAttributeDefinition(
|
||||
name="rows",
|
||||
levels=[4, 6, 7, 9],
|
||||
description="Row count",
|
||||
lower_field_name="min_rows",
|
||||
upper_field_name="max_rows",
|
||||
),
|
||||
RangeAttributeDefinition(
|
||||
name="cols",
|
||||
levels=[4, 6, 7, 9],
|
||||
description="Column count",
|
||||
lower_field_name="min_cols",
|
||||
upper_field_name="max_cols",
|
||||
),
|
||||
ScalarAttributeDefinition(
|
||||
name="p_ones",
|
||||
levels=[0.50, 0.40, 0.30, 0.20],
|
||||
description="Probability of a cell being filled",
|
||||
field_name="p_ones",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
register_dataset(DATASET_NAME, KakurasuDataset, KakurasuConfig, KakurasuCurriculum)
|
||||
181
tests/test_kakurasu.py
Normal file
181
tests/test_kakurasu.py
Normal file
@@ -0,0 +1,181 @@
|
||||
import pytest
|
||||
|
||||
from reasoning_gym.coaching.base_curriculum import DefaultCurriculumContext, RangeAttributeMode
|
||||
from reasoning_gym.games import KakurasuConfig, KakurasuCurriculum, KakurasuDataset
|
||||
|
||||
|
||||
def test_kakurasu_config_validation():
|
||||
"""Test that invalid configs raise appropriate errors"""
|
||||
with pytest.raises(AssertionError):
|
||||
config = KakurasuConfig(min_rows=5, max_rows=4) # max_rows < min_rows
|
||||
config.validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
config = KakurasuConfig(p_ones=-0.1) # negative probability
|
||||
config.validate()
|
||||
|
||||
|
||||
def test_kakurasu_deterministic():
|
||||
"""Test that dataset generates same puzzles with same seed"""
|
||||
config = KakurasuConfig(seed=42, size=10, min_rows=3, max_rows=9, min_cols=3, max_cols=9, p_ones=0.2)
|
||||
dataset1 = KakurasuDataset(config)
|
||||
dataset2 = KakurasuDataset(config)
|
||||
|
||||
for i in range(len(dataset1)):
|
||||
assert dataset1[i] == dataset2[i]
|
||||
|
||||
|
||||
def test_kakurasu_items():
|
||||
"""Test basic properties of generated items"""
|
||||
config = KakurasuConfig(seed=42, size=10, min_rows=3, max_rows=9, min_cols=3, max_cols=9, p_ones=0.3)
|
||||
dataset = KakurasuDataset(config)
|
||||
|
||||
for i in range(len(dataset)):
|
||||
item = dataset[i]
|
||||
assert isinstance(item, dict)
|
||||
assert "question" in item
|
||||
assert "answer" in item
|
||||
assert "metadata" in item
|
||||
|
||||
# Verify key metadata contents
|
||||
metadata = item["metadata"]
|
||||
assert "puzzle" in metadata
|
||||
assert "solution" in metadata
|
||||
assert "row_sums" in metadata
|
||||
assert "col_sums" in metadata
|
||||
|
||||
# Verify board dimensions for both puzzle and solution
|
||||
puzzle, solution = metadata["puzzle"], metadata["solution"]
|
||||
assert len(puzzle) >= config.min_rows
|
||||
assert len(puzzle) <= config.max_rows
|
||||
assert len(solution) >= config.min_rows
|
||||
assert len(solution) <= config.max_rows
|
||||
for row in puzzle:
|
||||
assert len(row) >= config.min_cols
|
||||
assert len(row) <= config.max_cols
|
||||
for row in solution:
|
||||
assert len(row) >= config.min_cols
|
||||
assert len(row) <= config.max_cols
|
||||
|
||||
# Verify row and column sums
|
||||
row_sums, col_sums = metadata["row_sums"], metadata["col_sums"]
|
||||
assert len(row_sums) == len(puzzle)
|
||||
assert len(col_sums) == len(puzzle[0]) if puzzle else 0
|
||||
|
||||
|
||||
def test_kakurasu_solution_validity():
|
||||
"""Test that solutions are valid according to Kakurasu rules"""
|
||||
config = KakurasuConfig(seed=42, size=10, min_rows=3, max_rows=9, min_cols=3, max_cols=9, p_ones=0.3)
|
||||
dataset = KakurasuDataset(config)
|
||||
|
||||
def is_valid_solution(solution, n_rows, n_cols, row_sums, col_sums):
|
||||
"""Check if the solution is valid according to Kakurasu rules"""
|
||||
if len(solution) != n_rows or any(len(row) != n_cols for row in solution):
|
||||
return False
|
||||
|
||||
# Check row sums
|
||||
for i, row in enumerate(solution):
|
||||
if sum((j + 1) for j, val in enumerate(row) if val == 1) != row_sums[i]:
|
||||
return False
|
||||
|
||||
# Check column sums
|
||||
for j in range(n_cols):
|
||||
if sum((i + 1) for i, row in enumerate(solution) if row[j] == 1) != col_sums[j]:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
for i in range(len(dataset)):
|
||||
item = dataset[i]
|
||||
metadata = item["metadata"]
|
||||
solution = metadata["solution"]
|
||||
n_rows, n_cols = metadata["n_rows"], metadata["n_cols"]
|
||||
row_sums, col_sums = metadata["row_sums"], metadata["col_sums"]
|
||||
assert is_valid_solution(solution, n_rows, n_cols, row_sums, col_sums)
|
||||
|
||||
|
||||
def test_kakurasu_puzzle_solvability():
|
||||
"""Test that generated puzzles are solvable and have unique solutions"""
|
||||
config = KakurasuConfig(seed=42, size=10, min_rows=3, max_rows=9, min_cols=3, max_cols=9, p_ones=0.3)
|
||||
dataset = KakurasuDataset(config)
|
||||
|
||||
for i in range(len(dataset)):
|
||||
item = dataset[i]
|
||||
metadata = item["metadata"]
|
||||
n_rows, n_cols = metadata["n_rows"], metadata["n_cols"]
|
||||
row_sums, col_sums = metadata["row_sums"], metadata["col_sums"]
|
||||
|
||||
# Verify puzzle has exactly one solution
|
||||
assert dataset._count_solutions(n_rows, n_cols, row_sums, col_sums) == 1
|
||||
|
||||
|
||||
def test_kakurasu_answer_scoring():
|
||||
"""Test the answer scoring mechanism"""
|
||||
config = KakurasuConfig(seed=42, size=10, min_rows=3, max_rows=9, min_cols=3, max_cols=9, p_ones=0.3)
|
||||
dataset = KakurasuDataset(config)
|
||||
|
||||
for item in dataset:
|
||||
# Correct answer should score 1.0
|
||||
assert dataset.score_answer(item["answer"], item) == 1.0
|
||||
|
||||
# Wrong answer should score lower
|
||||
wrong_answer = item["answer"].replace("1", "2")
|
||||
assert dataset.score_answer(wrong_answer, item) < 1.0
|
||||
|
||||
# None or empty answer should score 0.0
|
||||
assert dataset.score_answer(None, item) == 0.0
|
||||
assert dataset.score_answer("", item) == 0.0
|
||||
|
||||
|
||||
def test_futoshiki_curriculum():
|
||||
"""Test the KakurasuCurriculum works as expected"""
|
||||
curriculum = KakurasuCurriculum()
|
||||
|
||||
base_value = {"size": 150, "seed": 1}
|
||||
|
||||
context = DefaultCurriculumContext(mode=RangeAttributeMode.UPPER_BOUND)
|
||||
base_cfg: KakurasuConfig = curriculum.generate_configuration(base_value, context=context)
|
||||
assert base_cfg.seed == 1
|
||||
assert base_cfg.size == 150
|
||||
assert base_cfg.min_rows == 4 and base_cfg.max_rows == 4
|
||||
assert base_cfg.min_cols == 4 and base_cfg.max_cols == 4
|
||||
assert base_cfg.p_ones == 0.50
|
||||
|
||||
# Test incrementing attribute levels
|
||||
curriculum.increment_attr_level("rows")
|
||||
curriculum.increment_attr_level("p_ones")
|
||||
increased_cfg = curriculum.generate_configuration(base_value, context=context)
|
||||
assert increased_cfg.min_rows == 6 and increased_cfg.max_rows == 6
|
||||
assert increased_cfg.p_ones == 0.4
|
||||
|
||||
# Test incrementing again
|
||||
curriculum.increment_attr_level("cols")
|
||||
curriculum.increment_attr_level("p_ones")
|
||||
increased_cfg2 = curriculum.generate_configuration(base_value, context=context)
|
||||
assert increased_cfg2.min_cols == 6 and increased_cfg2.max_cols == 6
|
||||
assert increased_cfg2.p_ones == 0.3
|
||||
|
||||
# Test incrementing to max level
|
||||
curriculum.increment_attr_level("p_ones")
|
||||
max_cfg = curriculum.generate_configuration(base_value, context=context)
|
||||
assert max_cfg.p_ones == 0.2
|
||||
|
||||
# Test that we can't go beyond max level
|
||||
curriculum.increment_attr_level("p_ones")
|
||||
still_max_cfg = curriculum.generate_configuration(base_value, context=context)
|
||||
assert still_max_cfg.p_ones == 0.2
|
||||
|
||||
# Test decrementing attribute levels
|
||||
curriculum.decrement_attr_level("p_ones")
|
||||
decreased_cfg = curriculum.generate_configuration(base_value, context=context)
|
||||
assert decreased_cfg.p_ones == 0.3
|
||||
|
||||
# Test global level setting
|
||||
curriculum.set_global_level(0)
|
||||
global_lvl0_cfg = curriculum.generate_configuration(base_value, context=context)
|
||||
assert global_lvl0_cfg.min_rows == 4 and global_lvl0_cfg.max_rows == 4
|
||||
|
||||
# Test global level increment
|
||||
curriculum.increment_global_level()
|
||||
global_lvl1_cfg = curriculum.generate_configuration(base_value, context=context)
|
||||
assert global_lvl1_cfg.min_rows == 6 and global_lvl1_cfg.max_rows == 6
|
||||
Reference in New Issue
Block a user