Files
reasoning-gym/reasoning_gym/algorithmic/word_ladder.py
Zafir Stojanovski dced3bfc45 fix(curriculum): Make boundaries in curriculum more sensible (#407)
* init

* fix tests

* unify codeio

* filtered for libraries not present in reasoning-gym

* fix more bounds

* puzzle24

* knight swap curriculum

* fix number sorting

* fix attributes

* add validation of config in creation of dataset

* dry run for instantiating and validating the datasets

* remove unused imports

* fix curriculum tests to reference newly updated attribute names
2025-04-04 20:24:14 +02:00

294 lines
11 KiB
Python

"""Word ladder task generator"""
from collections import deque
from dataclasses import dataclass
from random import Random
from typing import Any, Optional
from ..coaching import BaseCurriculum, RangeAttributeDefinition
from ..data import get_data_file_path
from ..factory import ProceduralDataset, register_dataset
QUESTION_TEMPLATE = """Transform the word ladder '{start}' to '{end}' by changing one letter at a time.
Provide your answer as a comma-separated sequence of uppercase letters without spaces.
Each step must be a valid English word."""
DATASET_NAME = "word_ladder"
@dataclass
class WordLadderConfig:
"""Configuration for word ladder task generation"""
min_word_length: int = 4 # Minimum word length
max_word_length: int = 4 # Maximum word length
min_chain_length: int = -1 # Set to -1 for shortest path or a minimum of 3
max_chain_length: int = -1 # Set to -1 for shortest path or a max
seed: Optional[int] = None
size: int = 500
def validate(self) -> None:
"""Validate configuration parameters"""
assert self.min_word_length >= 3, "min_word_length must be >= 3"
assert self.max_word_length >= self.min_word_length, "max_word_length must be >= min_word_length"
assert self.max_word_length <= 5, "max_word_length must be <= 5"
# Add size validation
if self.size > 20000: # Add reasonable upper limit
raise ValueError("Dataset size too large for this algorithm and constraints")
# Modified validation logic
if self.min_chain_length == -1:
if self.max_chain_length != -1:
assert (
self.max_chain_length >= 3
), "When min_chain_length=-1 (shortest path), max_chain_length must be -1 or >=3"
elif self.max_chain_length == -1:
raise AssertionError("max_chain_length cannot be -1 unless min_chain_length is also -1")
else:
assert self.min_chain_length >= 3, "min_chain_length must be 3 or -1"
assert self.max_chain_length >= self.min_chain_length, "max_chain_length must be >= min_chain_length"
def is_valid_path_length(self, length: int) -> bool:
"""Check if a path length meets the configuration requirements"""
# When min_chain_length is -1, we accept any path of length >= 3
if self.min_chain_length == -1:
if self.max_chain_length == -1:
return length >= 3
return 3 <= length <= self.max_chain_length
# Otherwise check against both min and max
return (
self.min_chain_length <= length <= (self.max_chain_length if self.max_chain_length != -1 else float("inf"))
)
class WordLadderDataset(ProceduralDataset):
"""Generates word ladder transformation tasks"""
def __init__(self, config: WordLadderConfig):
self.config = config
self.word_sets = {}
self.word_graphs = {}
self._vocabulary = None # A large list of dictionary words to validate words against
# Load words from CSV
self.word_sets = self._load_words_from_csv(
min_length=self.config.min_word_length, max_length=self.config.max_word_length
)
# Precompute word graphs for all lengths
for length in range(self.config.min_word_length, self.config.max_word_length + 1):
self.word_graphs[length] = self._build_word_graph(length)
config.validate()
super().__init__(config=config, seed=config.seed, size=config.size)
@classmethod
def _load_words_from_csv(cls, min_length: int = 3, max_length: int = 5) -> dict[int, set[str]]:
"""Load words from CSV file organized by length"""
# Validate length range before processing
assert 3 <= min_length <= max_length <= 5, "Word length must be between 3 and 5 inclusive"
import csv
word_sets = {}
try:
# Get CSV content as string
with get_data_file_path("words.csv").open("r", encoding="utf-8") as csv_file:
reader = csv.DictReader(csv_file)
for row in reader:
# Process each word length column using config range
for length in range(min_length, max_length + 1):
col_name = f"{length}_letter"
word = row.get(col_name, "")
if not word: # Skip empty entries
continue
word_sets.setdefault(length, set()).add(word.upper())
except Exception as e:
raise RuntimeError(f"Error processing words.csv content: {e}") from e
# Validate we have words for each length
for length in range(min_length, max_length + 1):
if length not in word_sets or not word_sets[length]:
raise ValueError(f"No valid words found for length {length}")
return word_sets
def _get_neighbors(self, word: str, word_set: set[str]) -> set[str]:
"""Get neighbors from either precomputed graph or by computing on demand"""
# Try precomputed graph first
if len(word) in self.word_graphs and word in self.word_graphs[len(word)]:
return self.word_graphs[len(word)].get(word, set())
# Fall back to computing neighbors directly for custom word sets
neighbors = set()
for i in range(len(word)):
for c in "ABCDEFGHIJKLMNOPQRSTUVWXYZ":
neighbor = word[:i] + c + word[i + 1 :]
if neighbor != word and neighbor in word_set:
neighbors.add(neighbor)
return neighbors
def _build_word_graph(self, word_length: int) -> dict[str, set[str]]:
"""Build graph of word connections for given length, using caching"""
# Return cached graph if it exists
if word_length in self.word_graphs:
return self.word_graphs[word_length]
# Build new graph
word_set = self.word_sets[word_length]
graph = {}
# Build connections
for word in word_set:
neighbors = set()
for i in range(word_length):
for c in "ABCDEFGHIJKLMNOPQRSTUVWXYZ":
neighbor = word[:i] + c + word[i + 1 :]
if neighbor != word and neighbor in word_set:
neighbors.add(neighbor)
graph[word] = neighbors
# Cache and return
self.word_graphs[word_length] = graph
return self.word_graphs[word_length]
def _find_path(self, start: str, end: str, word_set: set[str]) -> Optional[list[str]]:
"""Simplified path finding using BFS for shortest paths"""
# Early exit if words are direct neighbors
if end in self._get_neighbors(start, word_set):
return [start, end]
# Use basic BFS for shortest path
queue = deque([(start, [start])])
visited = {start}
while queue:
current, path = queue.popleft()
if current == end:
if self.config.is_valid_path_length(len(path)):
return path
return None
for neighbor in self._get_neighbors(current, word_set):
if neighbor not in visited:
visited.add(neighbor)
new_path = path + [neighbor]
queue.append((neighbor, new_path))
return None
def _generate_word_pair(self, rng: Random, length: int) -> tuple[str, str, list[str]]:
"""Simplified word pair generation"""
word_set = self.word_sets[length]
words_list = sorted(word_set)
max_attempts = 100
for _ in range(max_attempts):
start = rng.choice(words_list)
end = rng.choice(words_list)
if start == end:
continue
path = self._find_path(start, end, word_set)
if path:
return start, end, path
raise RuntimeError(f"Failed to find valid pair for length {length}")
def __getitem__(self, idx: int) -> dict:
"""Generate a single word ladder task"""
if idx >= self.size:
raise IndexError(f"Dataset index {idx} out of range for size {self.size}")
try:
rng = Random(self.seed + idx)
length = rng.randint(self.config.min_word_length, self.config.max_word_length)
start, end, path = self._generate_word_pair(rng, length)
except RuntimeError as e:
# If we run out of valid paths, adjust the virtual size
self.size = idx
raise IndexError(f"Dataset exhausted at index {idx}. {str(e)}")
return {
"question": QUESTION_TEMPLATE.format(start=start, end=end),
"answer": ",".join(path),
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"start_word": start,
"end_word": end,
"word_length": length,
"chain_length": len(path),
"difficulty": {
"word_length": (self.config.min_word_length, self.config.max_word_length),
},
},
}
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
if not isinstance(answer, str):
return 0.0
answer_words = tuple(s.strip() for s in answer.upper().split(","))
metadata = entry["metadata"]
start_word = metadata["start_word"]
end_word = metadata["end_word"]
word_length = len(end_word)
known_words = self.word_sets[word_length]
# Check conditions:
# 1. start and end word match question
# 2. all words have the correct length
# 3. every changed word is a single letter change from the previous word
# 4. all words are in our vocabulary
if len(answer_words) < 2:
return 0.0
if answer_words[0] != start_word or answer_words[-1] != end_word:
return 0.0
if not all(len(w) == word_length for w in answer_words):
return 0.0
for i in range(1, len(answer_words)):
if sum(1 for a, b in zip(answer_words[i - 1], answer_words[i]) if a != b) != 1:
return 0.0
reward = 1.0
for word in answer_words:
if not word in known_words:
reward *= 0.5
return reward
class WordLadderCurriculum(BaseCurriculum):
def __init__(self):
super().__init__(WordLadderCurriculum.__name__, WordLadderConfig)
# Define attributes
self._define_attributes(
RangeAttributeDefinition(
name="word_length",
levels=[3, 4, 5],
description="Length of words in the puzzle",
lower_field_name="min_word_length",
upper_field_name="max_word_length",
ensure_interval=True,
)
)
register_dataset(DATASET_NAME, WordLadderDataset, WordLadderConfig, WordLadderCurriculum)