figlet font curriculum

This commit is contained in:
Zafir Stojanovski
2025-03-14 22:21:19 +01:00
committed by Rich Jones
parent 9234aa77bf
commit 29bf78293f
3 changed files with 186 additions and 117 deletions

View File

@@ -3,7 +3,7 @@ Cognition tasks for training reasoning capabilities.
"""
from .color_cube_rotation import ColorCubeRotationConfig, ColorCubeRotationCurriculum, ColorCubeRotationDataset
from .figlet_fonts import FigletFontConfig, FigletFontDataset
from .figlet_fonts import FigletFontConfig, FigletFontCurriculum, FigletFontDataset
from .modulo_grid import ModuloGridConfig, ModuloGridDataset
from .needle_haystack import NeedleHaystackConfig, NeedleHaystackCurriculum, NeedleHaystackDataset
from .number_sequences import NumberSequenceConfig, NumberSequenceCurriculum, NumberSequenceDataset
@@ -16,6 +16,7 @@ __all__ = [
"ColorCubeRotationCurriculum",
"FigletFontConfig",
"FigletFontDataset",
"FigletFontCurriculum",
"NumberSequenceConfig",
"NumberSequenceDataset",
"NumberSequenceCurriculum",

View File

@@ -1,11 +1,125 @@
import json
from dataclasses import dataclass
from random import Random
from typing import Any, Optional
import pyfiglet
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
from ..data import get_data_file_path
from ..factory import ProceduralDataset, register_dataset
# Fonts excluded from random selection: these ones are funky and probably
# aren't good for train/testing.
# Every entry must be followed by a comma. Adjacent string literals are
# concatenated by Python, so a missing comma silently merges two names into one
# bogus entry. Each font is listed exactly once.
BAD_FONTS = [
    "pyramid",
    "runyc",
    "assalt_m",
    "term",
    "tengwar",
    "heart_right",
    "faces_of",
    "heroboti",
    "hieroglyphs",
    "rainbow_",
    "notie_ca",
    "ghost",
    "rampage_",
    "atc_____",
    "pacos_pe",
    "mad_nurs",
    "icl-1900",
    "joust___",
    "dcs_bfmo",
    "letter_w",
    "flyn_sh",
    "fun_face",
    "morse2",
    "tecrvs__",
    "ntgreek",
    "tsalagi",
    "etcrvs__",
    "future_8",
    "efti_robot",
    "danc4",
    "p_s_h_m_",
    "smkeyboard",
    "konto",
    "odel_lak",
    "courb",
    "jerusalem",
    "nfi1____",
    "keyboard",
    "konto_slant",
    "rot13",
    "mirror",
    "katakana",
    "cards",
    "eftichess",
    "heart_left",
    "trashman",
    "morse",
    "eftipiti",
    "smtengwar",
    "e__fist_",
    "mike",
    "bear",
    "hills___",
    "rotated",
    "wow",
    "relief2",
    "mshebrew210",
    "kik_star",
    "puzzle",
    "p_skateb",
    "hypa_bal",
    "tomahawk",
    "timesofl",
    "moscow",
    "cola",
    "baz__bil",
    "stencil1",
    "battlesh",
    "tsn_base",
    "kgames_i",
    "binary",
    "greek",
    "mnemonic",
    "panther_",
    "b1ff",
    "c_consen",
    "horizontal_right",
    "dwhistled",
    "hex",
    "flipped",
    "high_noo",
    "patorjk-hex",
    "amc_3_liv1",
    "gauntlet",
    "cybersmall",
    "octal",
    "js_cursive",
    "battle_s",
    "deep_str",
    "rally_s2",
    "convoy__",
    "atc_gran",
    "grand_pr",
    "ivrit",
    "rammstein",
    "horizontal_left",
    "eftiwall",
    "decimal",
    "goofy",
    "subteran",
    "rally_sp",
    "charset_",
]
# Font names reported by pyfiglet, and the subset of them not listed in BAD_FONTS.
# Computed once at import time.
ALL_FONTS = pyfiglet.FigletFont.getFonts()
OK_FONTS = [font_name for font_name in ALL_FONTS if font_name not in BAD_FONTS]
@dataclass
class FigletFontConfig:
@@ -13,18 +127,35 @@ class FigletFontConfig:
static_word: Optional[str] = None
static_font: Optional[str] = None
min_word_len: int = 3
max_word_len: int = 7
space_letters: bool = True
seed: Optional[int] = None
size: int = 500
def validate(self):
    """Check that the configuration values are usable.

    Raises:
        AssertionError: if ``min_word_len`` is not positive, if it exceeds
            ``max_word_len``, if ``static_word`` or ``static_font`` is set to
            an empty string, or if ``static_font`` is not one of ``OK_FONTS``.
    """
    assert self.min_word_len > 0, "min_word_len must be greater than 0"
    assert self.min_word_len <= self.max_word_len, "min_word_len must be less than or equal to max_word_len"
    # Guard on `is not None`, not on truthiness: an empty string is falsy, so a
    # truthiness guard would skip the very check that is meant to reject it.
    if self.static_word is not None:
        assert len(self.static_word) > 0, "static_word must have at least one character"
    if self.static_font is not None:
        assert len(self.static_font) > 0, "static_font must have at least one character"
        assert self.static_font in OK_FONTS, f"static_font must be one of {OK_FONTS}"
class FigletFontDataset(ProceduralDataset):
"""Generates FigletFont tasks"""
def __init__(self, config: FigletFontConfig):
from ..data.wordle_words import wordle_words
with get_data_file_path("anagrams.jsonl").open() as f:
self.words = [
word
for line in f
for word in json.loads(line)["words"]
if config.min_word_len <= len(word) <= config.max_word_len
]
assert len(self.words) > 0, "No words found in the dataset with the specified length range"
self.wordle_words = wordle_words
self._prompt_templates = [
"What word does this say?\n\n{figlet_render}",
"Please read the following figlet font:\n\n{figlet_render}",
@@ -42,123 +173,13 @@ class FigletFontDataset(ProceduralDataset):
"""
rng = Random(self.seed + idx)
word = self.config.static_word if self.config.static_word is not None else rng.choice(self.wordle_words).upper()
word = self.config.static_word if self.config.static_word is not None else rng.choice(self.words).upper()
if self.config.space_letters:
render_word = " ".join(word)
else:
render_word = word
# These ones are funky and probably aren't good for train/testing
bad_fonts = [
"pyramid",
"runyc",
"assalt_m",
"term",
"tengwar",
"heart_right",
"faces_of",
"heroboti",
"hieroglyphs",
"rainbow_",
"notie_ca",
"ghost",
"rampage_",
"atc_____",
"pacos_pe",
"mad_nurs",
"icl-1900",
"joust___",
"dcs_bfmo",
"letter_w",
"flyn_sh",
"fun_face",
"morse2",
"tecrvs__",
"ntgreek",
"tsalagi",
"etcrvs__",
"faces_of",
"future_8",
"efti_robot",
"danc4",
"p_s_h_m_",
"smkeyboard",
"konto",
"odel_lak",
"courb",
"jerusalem",
"nfi1____",
"keyboard",
"konto_slant" "rot13",
"mirror",
"katakana",
"cards",
"eftichess",
"heart_left",
"trashman",
"morse",
"eftipiti",
"smtengwar",
"e__fist_",
"mike",
"bear",
"hills___",
"rotated",
"wow",
"eftipiti",
"relief2",
"mshebrew210",
"kik_star",
"puzzle",
"p_skateb",
"hypa_bal",
"tomahawk",
"timesofl",
"moscow",
"cola",
"baz__bil",
"stencil1",
"battlesh",
"tsn_base",
"kgames_i",
"binary",
"greek",
"mnemonic",
"panther_",
"b1ff",
"c_consen",
"horizontal_right",
"dwhistled",
"hex",
"flipped",
"high_noo",
"patorjk-hex",
"amc_3_liv1",
"gauntlet",
"cybersmall",
"octal",
"js_cursive",
"battle_s",
"deep_str",
"rally_s2",
"convoy__",
"atc_gran",
"grand_pr",
"ivrit",
"rammstein",
"horizontal_left",
"eftiwall",
"decimal",
"goofy",
"rot13",
"konto_slant",
"subteran",
"rally_sp",
"charset_",
]
all_fonts = pyfiglet.FigletFont.getFonts()
ok_fonts = list(filter(lambda x: x not in bad_fonts, all_fonts))
chosen_font = self.config.static_font if self.config.static_font is not None else rng.choice(ok_fonts)
chosen_font = self.config.static_font if self.config.static_font is not None else rng.choice(OK_FONTS)
figlet_render = pyfiglet.figlet_format(render_word, font=chosen_font)
return {
@@ -202,5 +223,24 @@ class FigletFontDataset(ProceduralDataset):
return score
class FigletFontCurriculum(BaseCurriculum):
    """Curriculum for the figlet font task, with word length as its only attribute."""

    def __init__(self):
        super().__init__(FigletFontCurriculum.__name__, FigletFontConfig)

        # "word_len" is bound to the config's min_word_len / max_word_len fields.
        word_len_attribute = RangeAttributeDefinition(
            name="word_len",
            levels=[3, 5, 10, 15, 20, 30],
            default_level=1,
            description="The length of the word to be displayed",
            attr_type=AttributeType.APPEND,
            min_value=1,
            lower_field_name="min_word_len",
            upper_field_name="max_word_len",
        )
        self._define_attributes(word_len_attribute)
# Register the dataset
register_dataset("figlet_font", FigletFontDataset, FigletFontConfig)
register_dataset("figlet_font", FigletFontDataset, FigletFontConfig, FigletFontCurriculum)

View File

@@ -1,6 +1,13 @@
import pytest
from reasoning_gym.cognition.figlet_fonts import FigletFontConfig, FigletFontDataset
from reasoning_gym.cognition.figlet_fonts import FigletFontConfig, FigletFontCurriculum, FigletFontDataset
def test_figlet_config_validation():
    """Test that invalid configs raise appropriate errors"""
    # Build each config outside the `raises` block so that only validate() is
    # expected to fail; an error from the constructor must not pass the test.
    inverted_range = FigletFontConfig(min_word_len=2, max_word_len=1)  # max_word_len < min_word_len
    with pytest.raises(AssertionError):
        inverted_range.validate()

    non_positive_min = FigletFontConfig(min_word_len=0)  # min_word_len must be > 0
    with pytest.raises(AssertionError):
        non_positive_min.validate()
def test_figlet_deterministic():
@@ -41,3 +48,24 @@ def test_static_figlet():
assert dataset.score_answer(answer="TESTY", entry=item) == 1.0
assert dataset.score_answer(answer="WESTY", entry=item) == 0.4
assert dataset.score_answer(answer=None, entry=item) == 0
def test_figlet_curriculum():
    """Check the generated config at the default level and after level changes."""
    curriculum = FigletFontCurriculum()
    overrides = {"size": 150, "seed": 1}

    # Default level: overrides are passed through and the word range is 3..5.
    initial_cfg: FigletFontConfig = curriculum.generate_configuration(overrides)
    assert initial_cfg.seed == 1
    assert initial_cfg.size == 150
    assert (initial_cfg.min_word_len, initial_cfg.max_word_len) == (3, 5)

    # One level up widens the upper bound.
    curriculum.increment_attr_level("word_len")
    harder_cfg = curriculum.generate_configuration(overrides)
    assert (harder_cfg.min_word_len, harder_cfg.max_word_len) == (3, 10)

    # One level back down restores the initial range.
    curriculum.decrement_attr_level("word_len")
    restored_cfg = curriculum.generate_configuration(overrides)
    assert (restored_cfg.min_word_len, restored_cfg.max_word_len) == (3, 5)