add kakurasu env (#460)

* add kakurasu env

* add kakurasu curriculum

* add kakurasu tests
This commit is contained in:
Oliver Stanley
2025-06-08 09:20:53 +01:00
committed by GitHub
parent be2babea9c
commit c2fdb11980
3 changed files with 438 additions and 0 deletions

View File

@@ -10,6 +10,7 @@ from .boxnet import BoxnetConfig, BoxnetCurriculum, BoxnetDataset
from .countdown import CountdownConfig, CountdownCurriculum, CountdownDataset
from .emoji_mystery import EmojiMysteryConfig, EmojiMysteryCurriculum, EmojiMysteryDataset
from .futoshiki import FutoshikiConfig, FutoshikiCurriculum, FutoshikiDataset
from .kakurasu import KakurasuConfig, KakurasuCurriculum, KakurasuDataset
from .knight_swap import KnightSwapConfig, KnightSwapCurriculum, KnightSwapDataset
from .mahjong import MahjongPuzzleConfig, MahjongPuzzleCurriculum, MahjongPuzzleDataset
from .maze import MazeConfig, MazeCurriculum, MazeDataset
@@ -35,6 +36,9 @@ __all__ = [
"FutoshikiConfig",
"FutoshikiCurriculum",
"FutoshikiDataset",
"KakurasuConfig",
"KakurasuCurriculum",
"KakurasuDataset",
"MiniSudokuConfig",
"MiniSudokuDataset",
"MiniSudokuCurriculum",

View File

@@ -0,0 +1,253 @@
"""
Kakurasu puzzle dataset, adapted for Reasoning Gym from the SynLogic repository: https://github.com/MiniMax-AI/SynLogic/tree/main/games/tasks/kukurasu
"""
from dataclasses import dataclass
from random import Random
from typing import Any, Optional
from ..coaching import BaseCurriculum, RangeAttributeDefinition, ScalarAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
# Name under which the dataset is registered; also stored as "source_dataset" in each item's metadata.
DATASET_NAME = "kakurasu"
# Question templates. One template is chosen per item and formatted with the fields
# n_rows, n_cols, row_sums, col_sums and puzzle (the blank grid rendered as text).
# NOTE(review): the templates spell the puzzle "Kukurasu" (matching the upstream SynLogic task
# directory) while the dataset itself is named "kakurasu" — confirm whether this is intended.
PROMPT_TEMPLATES = [
    "You are given a {n_rows} x {n_cols} grid representing a Kukurasu puzzle. In this puzzle, you need to place 1s in the grid so that the weighted sum of each row and column matches the given constraints. The row sums are {row_sums} and the column sums are {col_sums}.\n1. Rules:\n 1. Each cell can contain either a 1 or an 0.\n 2. The weight of a 1 in a row is its column position (1 to {n_cols}).\n 3. The weight of a 1 in a column is its row position (1 to {n_rows}).\n 4. The weighted sum of each row must match the corresponding row constraint.\n 5. The weighted sum of each column must match the corresponding column constraint.\n2. Input:\n{puzzle}",
    "This is a {n_rows} x {n_cols} Kukurasu puzzle grid. Your task is to fill in the grid with 1s and 0s such that the weighted sums match the given constraints. The row sums are {row_sums} and the column sums are {col_sums}.\n1. Rules:\n 1. Each cell must contain either a 1 or an 0.\n 2. In each row, a 1 in position j contributes j points to that row's sum (positions are 1-indexed).\n 3. In each column, a 1 in position i contributes i points to that column's sum (positions are 1-indexed).\n 4. The weighted sum of each row must equal its constraint value.\n 5. The weighted sum of each column must equal its constraint value.\n2. Input:\n{puzzle}",
    "You're presented with a {n_rows} x {n_cols} Kukurasu puzzle grid. The goal is to place 1s in the grid so that the weighted sums of rows and columns match the given constraints: row sums {row_sums} and column sums {col_sums}.\n1. Rules:\n 1. Each cell must be filled with either a 1 or an 0.\n 2. A 1 in column j of any row contributes j points to that row's sum (j ranges from 1 to {n_cols}).\n 3. A 1 in row i of any column contributes i points to that column's sum (i ranges from 1 to {n_rows}).\n 4. Each row's weighted sum must match its constraint value.\n 5. Each column's weighted sum must match its constraint value.\n2. Input:\n{puzzle}",
    "Below is a {n_rows} x {n_cols} Kukurasu puzzle grid. Your objective is to place 1s in the grid such that the weighted sums of rows and columns match the given constraints. Row sums: {row_sums}. Column sums: {col_sums}.\n1. Rules:\n 1. Each cell must contain either a 1 or an 0.\n 2. The weight of a 1 in a row equals its column number (1 to {n_cols}).\n 3. The weight of a 1 in a column equals its row number (1 to {n_rows}).\n 4. The sum of weighted 1s in each row must equal the row constraint.\n 5. The sum of weighted 1s in each column must equal the column constraint.\n2. Input:\n{puzzle}",
    "Here's a {n_rows} x {n_cols} Kukurasu logic puzzle. You need to place 1s in the grid so that the weighted sums match the constraints. Row sums: {row_sums}. Column sums: {col_sums}.\n1. Rules:\n 1. Each cell can be filled with either a 1 or an 0.\n 2. A 1 in the jth position of a row contributes j points to that row's sum.\n 3. A 1 in the ith position of a column contributes i points to that column's sum.\n 4. The weighted sum of each row must equal its constraint value.\n 5. The weighted sum of each column must equal its constraint value.\n2. Input:\n{puzzle}",
    "I'm presenting you with a {n_rows} x {n_cols} Kukurasu puzzle. Your task is to place 1s in the grid so that the weighted sums match the given constraints: row sums {row_sums} and column sums {col_sums}.\n1. Rules:\n 1. Each cell must be filled with either a 1 or an 0.\n 2. In each row, a 1 in position j has a weight of j (where j ranges from 1 to {n_cols}).\n 3. In each column, a 1 in position i has a weight of i (where i ranges from 1 to {n_rows}).\n 4. The weighted sum of each row must match its constraint.\n 5. The weighted sum of each column must match its constraint.\n2. Input:\n{puzzle}",
    "Consider this {n_rows} x {n_cols} Kukurasu puzzle grid. You need to place 1s in the grid such that the weighted sums match the constraints. Row sums: {row_sums}. Column sums: {col_sums}.\n1. Rules:\n 1. Each cell must contain either a 1 or an 0.\n 2. A 1 in column position j contributes j points to its row's sum.\n 3. A 1 in row position i contributes i points to its column's sum.\n 4. Each row's weighted sum must equal its constraint value.\n 5. Each column's weighted sum must equal its constraint value.\n2. Input:\n{puzzle}",
    "You have a {n_rows} x {n_cols} Kukurasu puzzle grid. Your goal is to place 1s in the grid so that the weighted sums match the given constraints: row sums {row_sums} and column sums {col_sums}.\n1. Rules:\n 1. Each cell must be filled with either a 1 or an 0.\n 2. The weight of a 1 in a row is its column position (1 to {n_cols}).\n 3. The weight of a 1 in a column is its row position (1 to {n_rows}).\n 4. The weighted sum of each row must match its constraint.\n 5. The weighted sum of each column must match its constraint.\n2. Input:\n{puzzle}",
    "This {n_rows} x {n_cols} grid represents a Kukurasu puzzle. Your task is to place 1s in the grid so that the weighted sums match the constraints. Row sums: {row_sums}. Column sums: {col_sums}.\n1. Rules:\n 1. Each cell must contain either a 1 or an 0.\n 2. A 1 in the jth position of a row contributes j points to that row's sum.\n 3. A 1 in the ith position of a column contributes i points to that column's sum.\n 4. The weighted sum of each row must equal its constraint value.\n 5. The weighted sum of each column must equal its constraint value.\n2. Input:\n{puzzle}",
    "Examine this {n_rows} x {n_cols} Kukurasu puzzle grid. Your objective is to place 1s in the grid such that the weighted sums match the given constraints: row sums {row_sums} and column sums {col_sums}.\n1. Rules:\n 1. Each cell must be filled with either a 1 or an 0.\n 2. The weight of a 1 in a row equals its column number (1 to {n_cols}).\n 3. The weight of a 1 in a column equals its row number (1 to {n_rows}).\n 4. The weighted sum of each row must match its constraint.\n 5. The weighted sum of each column must match its constraint.\n2. Input:\n{puzzle}",
]
@dataclass
class KakurasuConfig:
    """Configuration for generating Kakurasu puzzles.

    The number of rows and columns of each puzzle is bounded by the inclusive
    ranges ``[min_rows, max_rows]`` and ``[min_cols, max_cols]``.
    """

    min_rows: int = 4  # Inclusive lower bound on the number of grid rows
    max_rows: int = 5  # Inclusive upper bound on the number of grid rows
    min_cols: int = 4  # Inclusive lower bound on the number of grid columns
    max_cols: int = 5  # Inclusive upper bound on the number of grid columns
    p_ones: float = 0.3  # Probability used when randomly filling solution cells with 1
    seed: Optional[int] = None
    size: int = 500  # Virtual dataset size
    max_retries: int = 1000  # Max retries to find a unique puzzle. If exceeded, a non-unique puzzle may be returned

    def validate(self):
        """Validate configuration parameters.

        Raises:
            AssertionError: If any row/column bound lies outside 3..9, if a
                minimum exceeds its maximum, or if ``p_ones`` is not in [0, 1].
        """
        # Each of the four bounds is range-checked exactly once. Previously
        # ``min_rows`` was checked twice and ``min_cols`` not at all.
        assert 3 <= self.min_rows <= 9, "min_rows must be between 3 and 9"
        assert 3 <= self.max_rows <= 9, "max_rows must be between 3 and 9"
        assert 3 <= self.min_cols <= 9, "min_cols must be between 3 and 9"
        assert 3 <= self.max_cols <= 9, "max_cols must be between 3 and 9"
        assert self.min_rows <= self.max_rows, "min_rows must be less than or equal to max_rows"
        assert self.min_cols <= self.max_cols, "min_cols must be less than or equal to max_cols"
        assert 0 <= self.p_ones <= 1, "p_ones must be between 0 and 1"
class KakurasuDataset(ProceduralDataset):
    """Generates Kakurasu puzzles with configurable size.

    Items are generated on demand in ``__getitem__`` from a ``Random`` seeded
    with ``self.seed + idx``. Each item holds the question text, the solution
    grid rendered as text, and metadata with the row/column sum constraints.
    """
def __init__(self, config: KakurasuConfig) -> None:
    """Initialise the dataset, passing the config, its seed and its size to the base class."""
    super().__init__(config=config, seed=config.seed, size=config.size)
def __len__(self) -> int:
    """Return the number of items, i.e. the configured ``size``."""
    return self.config.size
def __iter__(self):
    """Reset the iteration cursor to the first item and return this dataset as its own iterator."""
    # The cursor lives on the instance, so starting a new iteration restarts any one in progress.
    self._current_idx = 0
    return self
def __next__(self):
if self._current_idx >= self.config.size:
raise StopIteration
item = self[self._current_idx]
self._current_idx += 1
return item
def __getitem__(self, idx: int) -> dict:
    """Generate the Kakurasu puzzle for index ``idx``.

    A random solution grid is drawn and repaired so every row and column
    holds a 1, and its weighted row/column sums become the constraints.
    Candidates are redrawn until the constraints admit exactly one solution.
    The final attempt is accepted unchecked, so a puzzle with more than one
    solution may be returned once ``config.max_retries`` is exhausted.

    Args:
        idx: Item index; the generator is seeded with ``self.seed + idx``.

    Returns:
        A dict with ``question`` (formatted prompt), ``answer`` (solution
        grid as space-separated rows) and ``metadata``.
    """
    rng = Random(self.seed + idx)
    n_rows = rng.randint(self.config.min_rows, self.config.max_rows)
    n_cols = rng.randint(self.config.min_cols, self.config.max_cols)

    # The blank puzzle grid depends only on the dimensions: build it once
    # instead of on every retry.
    empty_grid = [[0 for _ in range(n_cols)] for _ in range(n_rows)]

    # Always make at least one attempt so that a non-positive max_retries
    # still yields a puzzle instead of leaving the grid undefined.
    attempts = max(1, self.config.max_retries)
    for retry in range(attempts):
        solution_grid = self._generate_random_grid(rng, n_rows, n_cols)
        self._repair_grid(rng, solution_grid)
        row_sums, col_sums = self._calculate_row_col_sums(solution_grid)

        if retry == attempts - 1:
            # Out of retries: accept this candidate without further checks.
            break
        # NOTE(review): row sums weight each 1 by its column index and column
        # sums by its row index, so the two totals agree only for some grids.
        # This filter therefore rejects otherwise valid candidates. It is
        # kept so that existing seeds keep producing the same puzzles.
        if 0 in row_sums or 0 in col_sums or sum(row_sums) != sum(col_sums):
            continue
        if self._count_solutions(n_rows, n_cols, row_sums, col_sums) != 1:
            continue
        break

    prompt = rng.choice(PROMPT_TEMPLATES).format(
        n_rows=n_rows,
        n_cols=n_cols,
        row_sums=row_sums,
        col_sums=col_sums,
        puzzle="\n".join(
            [" ".join(str(cell) for cell in row) for row in empty_grid],
        ),
    )
    return {
        "question": prompt,
        "answer": "\n".join(
            [" ".join(str(cell) for cell in row) for row in solution_grid],
        ),
        "metadata": {
            "source_dataset": DATASET_NAME,
            "source_idx": idx,
            "n_rows": n_rows,
            "n_cols": n_cols,
            "p_ones": self.config.p_ones,
            "puzzle": empty_grid,
            "row_sums": row_sums,
            "col_sums": col_sums,
            "solution": solution_grid,
            "difficulty": {
                "rows": (self.config.min_rows, self.config.max_rows),
                "cols": (self.config.min_cols, self.config.max_cols),
                "p_ones": self.config.p_ones,
            },
        },
    }
def _generate_random_grid(self, rng: Random, n_rows: int, n_cols: int) -> list[list[int]]:
"""Generate a random valid solution grid."""
return [[1 if rng.random() < self.config.p_ones else 0 for _ in range(n_cols)] for _ in range(n_rows)]
def _calculate_row_col_sums(self, grid) -> tuple[list[int], list[int]]:
"""Calculate row and column sums based on the solution grid"""
n_rows = len(grid)
n_cols = len(grid[0]) if n_rows > 0 else 0
row_sums = [sum((j + 1) for j, cell in enumerate(row) if cell == 1) for row in grid]
col_sums = [sum((i + 1) for i in range(n_rows) if grid[i][j] == 1) for j in range(n_cols)]
return row_sums, col_sums
def _repair_grid(self, rng: Random, grid: list[list[int]]):
"""Ensure every row/col has at least one '1'."""
n_rows = len(grid)
n_cols = len(grid[0]) if n_rows > 0 else 0
for i, row in enumerate(grid):
if 1 not in row:
grid[i][rng.randrange(n_cols)] = 1
for j in range(n_cols):
if all(grid[i][j] == 0 for i in range(n_rows)):
grid[rng.randrange(n_rows)][j] = 1
def _count_solutions(self, n: int, m: int, row_sums: list[int], col_sums: list[int], limit: int = 2) -> int:
"""Return number of solutions, stopping at `limit`."""
row_patterns: list[list[list[int]]] = []
for target in row_sums:
patterns = []
for mask in range(1 << m):
if sum(((j + 1) if (mask >> j) & 1 else 0) for j in range(m)) == target:
patterns.append([(mask >> j) & 1 for j in range(m)])
if not patterns:
return 0
row_patterns.append(patterns)
col_remaining = col_sums[:]
solutions = 0
def dfs(r: int):
nonlocal solutions
if solutions >= limit:
return
if r == n:
if all(c == 0 for c in col_remaining):
solutions += 1
return
for pat in row_patterns[r]:
ok = True
for j, bit in enumerate(pat):
col_remaining[j] -= (r + 1) * bit
if col_remaining[j] < 0:
ok = False
if ok:
dfs(r + 1)
for j, bit in enumerate(pat):
col_remaining[j] += (r + 1) * bit
if solutions >= limit:
break
dfs(0)
return solutions
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
if not isinstance(answer, str):
return 0.0
metadata = entry["metadata"]
row_sums, col_sums = metadata["row_sums"], metadata["col_sums"]
n_rows, n_cols = metadata["n_rows"], metadata["n_cols"]
try:
grid = self._parse_grid(answer)
if len(grid) != n_rows or any(len(row) != n_cols for row in grid):
return 0.0
if any(cell not in [1, 0] for row in grid for cell in row):
return 0.0
ans_row_sums = [sum((j + 1) for j, cell in enumerate(row) if cell == 1) for row in grid]
if ans_row_sums != row_sums:
return 0.0
ans_col_sums = [sum((i + 1) for i in range(n_rows) if grid[i][j] == 1) for j in range(n_cols)]
if ans_col_sums != col_sums:
return 0.0
return 1.0
except Exception:
return 0.0
def _parse_grid(self, answer: str) -> list[list[str]]:
grid = []
for line in answer.strip().split("\n"):
grid.append([int(c) for c in line.strip() if c in "01"])
return grid
class KakurasuCurriculum(BaseCurriculum):
    """Curriculum for Kakurasu with three attributes: ``rows``, ``cols`` and ``p_ones``."""

    def __init__(self):
        super().__init__(KakurasuCurriculum.__name__, KakurasuConfig)

        # Row and column counts step through the same four levels.
        row_attr = RangeAttributeDefinition(
            name="rows",
            levels=[4, 6, 7, 9],
            description="Row count",
            lower_field_name="min_rows",
            upper_field_name="max_rows",
        )
        col_attr = RangeAttributeDefinition(
            name="cols",
            levels=[4, 6, 7, 9],
            description="Column count",
            lower_field_name="min_cols",
            upper_field_name="max_cols",
        )
        # Higher levels use a lower fill probability.
        fill_attr = ScalarAttributeDefinition(
            name="p_ones",
            levels=[0.50, 0.40, 0.30, 0.20],
            description="Probability of a cell being filled",
            field_name="p_ones",
        )

        self._define_attributes(row_attr, col_attr, fill_attr)
# Register the dataset class together with its config and curriculum classes under DATASET_NAME.
register_dataset(DATASET_NAME, KakurasuDataset, KakurasuConfig, KakurasuCurriculum)

181
tests/test_kakurasu.py Normal file
View File

@@ -0,0 +1,181 @@
import pytest
from reasoning_gym.coaching.base_curriculum import DefaultCurriculumContext, RangeAttributeMode
from reasoning_gym.games import KakurasuConfig, KakurasuCurriculum, KakurasuDataset
def test_kakurasu_config_validation():
    """Invalid configurations must be rejected with an AssertionError."""
    invalid_overrides = [
        {"min_rows": 5, "max_rows": 4},  # max_rows < min_rows
        {"p_ones": -0.1},  # negative probability
    ]
    for overrides in invalid_overrides:
        with pytest.raises(AssertionError):
            KakurasuConfig(**overrides).validate()
def test_kakurasu_deterministic():
    """Two datasets built from the same seeded config must yield identical items."""
    config = KakurasuConfig(seed=42, size=10, min_rows=3, max_rows=9, min_cols=3, max_cols=9, p_ones=0.2)
    first, second = KakurasuDataset(config), KakurasuDataset(config)

    for idx in range(len(first)):
        assert first[idx] == second[idx]
def test_kakurasu_items():
    """Test basic properties of generated items."""
    config = KakurasuConfig(seed=42, size=10, min_rows=3, max_rows=9, min_cols=3, max_cols=9, p_ones=0.3)
    dataset = KakurasuDataset(config)

    for i in range(len(dataset)):
        item = dataset[i]
        assert isinstance(item, dict)
        assert "question" in item
        assert "answer" in item
        assert "metadata" in item

        # Verify key metadata contents
        metadata = item["metadata"]
        assert "puzzle" in metadata
        assert "solution" in metadata
        assert "row_sums" in metadata
        assert "col_sums" in metadata

        # Verify board dimensions for both puzzle and solution
        puzzle, solution = metadata["puzzle"], metadata["solution"]
        for grid in (puzzle, solution):
            assert config.min_rows <= len(grid) <= config.max_rows
            for row in grid:
                assert config.min_cols <= len(row) <= config.max_cols

        # Verify row and column sums
        row_sums, col_sums = metadata["row_sums"], metadata["col_sums"]
        assert len(row_sums) == len(puzzle)
        # The parentheses matter: without them the conditional expression
        # wraps the whole comparison, so an empty puzzle would assert on the
        # constant 0 rather than compare lengths.
        assert len(col_sums) == (len(puzzle[0]) if puzzle else 0)
def test_kakurasu_solution_validity():
    """Every stored solution must satisfy its own row and column constraints."""
    config = KakurasuConfig(seed=42, size=10, min_rows=3, max_rows=9, min_cols=3, max_cols=9, p_ones=0.3)
    dataset = KakurasuDataset(config)

    def is_valid_solution(solution, n_rows, n_cols, row_sums, col_sums):
        """Return True when the grid has the right shape and matches all weighted sums."""
        if len(solution) != n_rows:
            return False
        if any(len(row) != n_cols for row in solution):
            return False

        # A 1 in a row is weighted by its 1-based column index.
        rows_match = all(
            sum((j + 1) for j, val in enumerate(row) if val == 1) == row_sums[i] for i, row in enumerate(solution)
        )
        if not rows_match:
            return False

        # A 1 in a column is weighted by its 1-based row index.
        return all(
            sum((i + 1) for i, row in enumerate(solution) if row[j] == 1) == col_sums[j] for j in range(n_cols)
        )

    for idx in range(len(dataset)):
        meta = dataset[idx]["metadata"]
        assert is_valid_solution(
            meta["solution"],
            meta["n_rows"],
            meta["n_cols"],
            meta["row_sums"],
            meta["col_sums"],
        )
def test_kakurasu_puzzle_solvability():
    """Each generated puzzle must have exactly one solution according to the dataset's own solver."""
    config = KakurasuConfig(seed=42, size=10, min_rows=3, max_rows=9, min_cols=3, max_cols=9, p_ones=0.3)
    dataset = KakurasuDataset(config)

    for idx in range(len(dataset)):
        meta = dataset[idx]["metadata"]
        solution_count = dataset._count_solutions(
            meta["n_rows"],
            meta["n_cols"],
            meta["row_sums"],
            meta["col_sums"],
        )
        # Verify puzzle has exactly one solution
        assert solution_count == 1
def test_kakurasu_answer_scoring():
    """Check scoring of correct, corrupted, missing and empty answers."""
    config = KakurasuConfig(seed=42, size=10, min_rows=3, max_rows=9, min_cols=3, max_cols=9, p_ones=0.3)
    dataset = KakurasuDataset(config)

    for entry in dataset:
        reference = entry["answer"]

        # The stored answer is a full-score solution.
        assert dataset.score_answer(reference, entry) == 1.0

        # Replacing every 1 with a 2 must not earn full score.
        corrupted = reference.replace("1", "2")
        assert dataset.score_answer(corrupted, entry) < 1.0

        # Missing or empty answers score zero.
        for blank in (None, ""):
            assert dataset.score_answer(blank, entry) == 0.0
def test_kakurasu_curriculum():
    """Test the KakurasuCurriculum works as expected.

    Renamed from ``test_futoshiki_curriculum``: the old name was a copy-paste
    leftover and did not match the curriculum under test.
    """
    curriculum = KakurasuCurriculum()
    base_value = {"size": 150, "seed": 1}
    context = DefaultCurriculumContext(mode=RangeAttributeMode.UPPER_BOUND)

    base_cfg: KakurasuConfig = curriculum.generate_configuration(base_value, context=context)
    assert base_cfg.seed == 1
    assert base_cfg.size == 150
    assert base_cfg.min_rows == 4 and base_cfg.max_rows == 4
    assert base_cfg.min_cols == 4 and base_cfg.max_cols == 4
    assert base_cfg.p_ones == 0.50

    # Test incrementing attribute levels
    curriculum.increment_attr_level("rows")
    curriculum.increment_attr_level("p_ones")
    increased_cfg = curriculum.generate_configuration(base_value, context=context)
    assert increased_cfg.min_rows == 6 and increased_cfg.max_rows == 6
    assert increased_cfg.p_ones == 0.4

    # Test incrementing again
    curriculum.increment_attr_level("cols")
    curriculum.increment_attr_level("p_ones")
    increased_cfg2 = curriculum.generate_configuration(base_value, context=context)
    assert increased_cfg2.min_cols == 6 and increased_cfg2.max_cols == 6
    assert increased_cfg2.p_ones == 0.3

    # Test incrementing to max level
    curriculum.increment_attr_level("p_ones")
    max_cfg = curriculum.generate_configuration(base_value, context=context)
    assert max_cfg.p_ones == 0.2

    # Test that we can't go beyond max level
    curriculum.increment_attr_level("p_ones")
    still_max_cfg = curriculum.generate_configuration(base_value, context=context)
    assert still_max_cfg.p_ones == 0.2

    # Test decrementing attribute levels
    curriculum.decrement_attr_level("p_ones")
    decreased_cfg = curriculum.generate_configuration(base_value, context=context)
    assert decreased_cfg.p_ones == 0.3

    # Test global level setting
    curriculum.set_global_level(0)
    global_lvl0_cfg = curriculum.generate_configuration(base_value, context=context)
    assert global_lvl0_cfg.min_rows == 4 and global_lvl0_cfg.max_rows == 4

    # Test global level increment
    curriculum.increment_global_level()
    global_lvl1_cfg = curriculum.generate_configuration(base_value, context=context)
    assert global_lvl1_cfg.min_rows == 6 and global_lvl1_cfg.max_rows == 6