Revert "figlet font curriculum"

This reverts commit 29bf78293f.
This commit is contained in:
Andreas Koepf
2025-03-18 23:36:06 +01:00
parent 6c95811278
commit 37170afb50
3 changed files with 117 additions and 186 deletions

View File

@@ -3,7 +3,7 @@ Cognition tasks for training reasoning capabilities.
"""
from .color_cube_rotation import ColorCubeRotationConfig, ColorCubeRotationCurriculum, ColorCubeRotationDataset
from .figlet_fonts import FigletFontConfig, FigletFontCurriculum, FigletFontDataset
from .figlet_fonts import FigletFontConfig, FigletFontDataset
from .modulo_grid import ModuloGridConfig, ModuloGridDataset
from .needle_haystack import NeedleHaystackConfig, NeedleHaystackCurriculum, NeedleHaystackDataset
from .number_sequences import NumberSequenceConfig, NumberSequenceCurriculum, NumberSequenceDataset
@@ -16,7 +16,6 @@ __all__ = [
"ColorCubeRotationCurriculum",
"FigletFontConfig",
"FigletFontDataset",
"FigletFontCurriculum",
"NumberSequenceConfig",
"NumberSequenceDataset",
"NumberSequenceCurriculum",

View File

@@ -1,16 +1,55 @@
import json
from dataclasses import dataclass
from random import Random
from typing import Any, Optional
import pyfiglet
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
from ..data import get_data_file_path
from ..factory import ProceduralDataset, register_dataset
@dataclass
class FigletFontConfig:
"""Configuration for FigletFont task generation"""
static_word: Optional[str] = None
static_font: Optional[str] = None
space_letters: bool = True
seed: Optional[int] = None
size: int = 500
class FigletFontDataset(ProceduralDataset):
"""Generates FigletFont tasks"""
def __init__(self, config: FigletFontConfig):
from ..data.wordle_words import wordle_words
self.wordle_words = wordle_words
self._prompt_templates = [
"What word does this say?\n\n{figlet_render}",
"Please read the following figlet font:\n\n{figlet_render}",
]
super().__init__(config=config, seed=config.seed, size=config.size)
def __getitem__(self, idx: int) -> dict:
"""Generate a single FigletFont task
Returns:
dict with keys:
- question: str, the task description with figlet string
- answer: str, the figlet encoded word
- metadata: dict with generation parameters
"""
rng = Random(self.seed + idx)
word = self.config.static_word if self.config.static_word is not None else rng.choice(self.wordle_words).upper()
if self.config.space_letters:
render_word = " ".join(word)
else:
render_word = word
# These ones are funky and probably aren't good for train/testing
BAD_FONTS = [
bad_fonts = [
"pyramid",
"runyc",
"assalt_m",
@@ -117,69 +156,9 @@ BAD_FONTS = [
"rally_sp",
"charset_",
]
ALL_FONTS = pyfiglet.FigletFont.getFonts()
OK_FONTS = list(filter(lambda x: x not in BAD_FONTS, ALL_FONTS))
@dataclass
class FigletFontConfig:
"""Configuration for FigletFont task generation"""
static_word: Optional[str] = None
static_font: Optional[str] = None
min_word_len: int = 3
max_word_len: int = 7
space_letters: bool = True
seed: Optional[int] = None
size: int = 500
def validate(self):
assert self.min_word_len > 0, "min_word_len must be greater than 0"
assert self.min_word_len <= self.max_word_len, "min_word_len must be less than or equal to max_word_len"
if self.static_word:
assert len(self.static_word) > 0, "static_word must have at least one character"
if self.static_font:
assert len(self.static_font) > 0, "static_font must have at least one character"
assert self.static_font in OK_FONTS, f"static_font must be one of {OK_FONTS}"
class FigletFontDataset(ProceduralDataset):
"""Generates FigletFont tasks"""
def __init__(self, config: FigletFontConfig):
with get_data_file_path("anagrams.jsonl").open() as f:
self.words = [
word
for line in f
for word in json.loads(line)["words"]
if config.min_word_len <= len(word) <= config.max_word_len
]
assert len(self.words) > 0, "No words found in the dataset with the specified length range"
self._prompt_templates = [
"What word does this say?\n\n{figlet_render}",
"Please read the following figlet font:\n\n{figlet_render}",
]
super().__init__(config=config, seed=config.seed, size=config.size)
def __getitem__(self, idx: int) -> dict:
"""Generate a single FigletFont task
Returns:
dict with keys:
- question: str, the task description with figlet string
- answer: str, the figlet encoded word
- metadata: dict with generation parameters
"""
rng = Random(self.seed + idx)
word = self.config.static_word if self.config.static_word is not None else rng.choice(self.words).upper()
if self.config.space_letters:
render_word = " ".join(word)
else:
render_word = word
chosen_font = self.config.static_font if self.config.static_font is not None else rng.choice(OK_FONTS)
all_fonts = pyfiglet.FigletFont.getFonts()
ok_fonts = list(filter(lambda x: x not in bad_fonts, all_fonts))
chosen_font = self.config.static_font if self.config.static_font is not None else rng.choice(ok_fonts)
figlet_render = pyfiglet.figlet_format(render_word, font=chosen_font)
return {
@@ -223,24 +202,5 @@ class FigletFontDataset(ProceduralDataset):
return score
class FigletFontCurriculum(BaseCurriculum):
def __init__(self):
super().__init__(FigletFontCurriculum.__name__, FigletFontConfig)
# Define attributes
self._define_attributes(
RangeAttributeDefinition(
name="word_len",
levels=[3, 5, 10, 15, 20, 30],
default_level=1,
description="The length of the word to be displayed",
attr_type=AttributeType.APPEND,
min_value=1,
lower_field_name="min_word_len",
upper_field_name="max_word_len",
),
)
# Register the dataset
register_dataset("figlet_font", FigletFontDataset, FigletFontConfig, FigletFontCurriculum)
register_dataset("figlet_font", FigletFontDataset, FigletFontConfig)

View File

@@ -1,13 +1,6 @@
import pytest
from reasoning_gym.cognition.figlet_fonts import FigletFontConfig, FigletFontCurriculum, FigletFontDataset
def test_figlet_config_validation():
"""Test that invalid configs raise appropriate errors"""
with pytest.raises(AssertionError):
config = FigletFontConfig(min_word_len=2, max_word_len=1) # max_word_len < min_word_len
config.validate()
from reasoning_gym.cognition.figlet_fonts import FigletFontConfig, FigletFontDataset
def test_figlet_deterministic():
@@ -48,24 +41,3 @@ def test_static_figlet():
assert dataset.score_answer(answer="TESTY", entry=item) == 1.0
assert dataset.score_answer(answer="WESTY", entry=item) == 0.4
assert dataset.score_answer(answer=None, entry=item) == 0
def test_figlet_curriculum():
curriculum = FigletFontCurriculum()
base_value = {"size": 150, "seed": 1}
base_cfg: FigletFontConfig = curriculum.generate_configuration(base_value)
assert base_cfg.seed == 1
assert base_cfg.size == 150
assert base_cfg.min_word_len == 3 and base_cfg.max_word_len == 5
# test incrementing attribute levels
curriculum.increment_attr_level("word_len")
increased_cfg = curriculum.generate_configuration(base_value)
assert increased_cfg.min_word_len == 3 and increased_cfg.max_word_len == 10
# test decrementing attribute level
curriculum.decrement_attr_level("word_len")
partially_decreased_cfg = curriculum.generate_configuration(base_value)
assert partially_decreased_cfg.min_word_len == 3 and partially_decreased_cfg.max_word_len == 5