Revert "figlet font curriculum"

This reverts commit 29bf78293f.
This commit is contained in:
Andreas Koepf
2025-03-18 23:36:06 +01:00
parent 6c95811278
commit 37170afb50
3 changed files with 117 additions and 186 deletions

View File

@@ -3,7 +3,7 @@ Cognition tasks for training reasoning capabilities.
""" """
from .color_cube_rotation import ColorCubeRotationConfig, ColorCubeRotationCurriculum, ColorCubeRotationDataset from .color_cube_rotation import ColorCubeRotationConfig, ColorCubeRotationCurriculum, ColorCubeRotationDataset
from .figlet_fonts import FigletFontConfig, FigletFontCurriculum, FigletFontDataset from .figlet_fonts import FigletFontConfig, FigletFontDataset
from .modulo_grid import ModuloGridConfig, ModuloGridDataset from .modulo_grid import ModuloGridConfig, ModuloGridDataset
from .needle_haystack import NeedleHaystackConfig, NeedleHaystackCurriculum, NeedleHaystackDataset from .needle_haystack import NeedleHaystackConfig, NeedleHaystackCurriculum, NeedleHaystackDataset
from .number_sequences import NumberSequenceConfig, NumberSequenceCurriculum, NumberSequenceDataset from .number_sequences import NumberSequenceConfig, NumberSequenceCurriculum, NumberSequenceDataset
@@ -16,7 +16,6 @@ __all__ = [
"ColorCubeRotationCurriculum", "ColorCubeRotationCurriculum",
"FigletFontConfig", "FigletFontConfig",
"FigletFontDataset", "FigletFontDataset",
"FigletFontCurriculum",
"NumberSequenceConfig", "NumberSequenceConfig",
"NumberSequenceDataset", "NumberSequenceDataset",
"NumberSequenceCurriculum", "NumberSequenceCurriculum",

View File

@@ -1,125 +1,11 @@
import json
from dataclasses import dataclass from dataclasses import dataclass
from random import Random from random import Random
from typing import Any, Optional from typing import Any, Optional
import pyfiglet import pyfiglet
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
from ..data import get_data_file_path
from ..factory import ProceduralDataset, register_dataset from ..factory import ProceduralDataset, register_dataset
# These ones are funky and probably aren't good for train/testing
BAD_FONTS = [
"pyramid",
"runyc",
"assalt_m",
"term",
"tengwar",
"heart_right",
"faces_of",
"heroboti",
"hieroglyphs",
"rainbow_",
"notie_ca",
"ghost",
"rampage_",
"atc_____",
"pacos_pe",
"mad_nurs",
"icl-1900",
"joust___",
"dcs_bfmo",
"letter_w",
"flyn_sh",
"fun_face",
"morse2",
"tecrvs__",
"ntgreek",
"tsalagi",
"etcrvs__",
"faces_of",
"future_8",
"efti_robot",
"danc4",
"p_s_h_m_",
"smkeyboard",
"konto",
"odel_lak",
"courb",
"jerusalem",
"nfi1____",
"keyboard",
"konto_slant" "rot13",
"mirror",
"katakana",
"cards",
"eftichess",
"heart_left",
"trashman",
"morse",
"eftipiti",
"smtengwar",
"e__fist_",
"mike",
"bear",
"hills___",
"rotated",
"wow",
"eftipiti",
"relief2",
"mshebrew210",
"kik_star",
"puzzle",
"p_skateb",
"hypa_bal",
"tomahawk",
"timesofl",
"moscow",
"cola",
"baz__bil",
"stencil1",
"battlesh",
"tsn_base",
"kgames_i",
"binary",
"greek",
"mnemonic",
"panther_",
"b1ff",
"c_consen",
"horizontal_right",
"dwhistled",
"hex",
"flipped",
"high_noo",
"patorjk-hex",
"amc_3_liv1",
"gauntlet",
"cybersmall",
"octal",
"js_cursive",
"battle_s",
"deep_str",
"rally_s2",
"convoy__",
"atc_gran",
"grand_pr",
"ivrit",
"rammstein",
"horizontal_left",
"eftiwall",
"decimal",
"goofy",
"rot13",
"konto_slant",
"subteran",
"rally_sp",
"charset_",
]
ALL_FONTS = pyfiglet.FigletFont.getFonts()
OK_FONTS = list(filter(lambda x: x not in BAD_FONTS, ALL_FONTS))
@dataclass @dataclass
class FigletFontConfig: class FigletFontConfig:
@@ -127,35 +13,18 @@ class FigletFontConfig:
static_word: Optional[str] = None static_word: Optional[str] = None
static_font: Optional[str] = None static_font: Optional[str] = None
min_word_len: int = 3
max_word_len: int = 7
space_letters: bool = True space_letters: bool = True
seed: Optional[int] = None seed: Optional[int] = None
size: int = 500 size: int = 500
def validate(self):
assert self.min_word_len > 0, "min_word_len must be greater than 0"
assert self.min_word_len <= self.max_word_len, "min_word_len must be less than or equal to max_word_len"
if self.static_word:
assert len(self.static_word) > 0, "static_word must have at least one character"
if self.static_font:
assert len(self.static_font) > 0, "static_font must have at least one character"
assert self.static_font in OK_FONTS, f"static_font must be one of {OK_FONTS}"
class FigletFontDataset(ProceduralDataset): class FigletFontDataset(ProceduralDataset):
"""Generates FigletFont tasks""" """Generates FigletFont tasks"""
def __init__(self, config: FigletFontConfig): def __init__(self, config: FigletFontConfig):
with get_data_file_path("anagrams.jsonl").open() as f: from ..data.wordle_words import wordle_words
self.words = [
word
for line in f
for word in json.loads(line)["words"]
if config.min_word_len <= len(word) <= config.max_word_len
]
assert len(self.words) > 0, "No words found in the dataset with the specified length range"
self.wordle_words = wordle_words
self._prompt_templates = [ self._prompt_templates = [
"What word does this say?\n\n{figlet_render}", "What word does this say?\n\n{figlet_render}",
"Please read the following figlet font:\n\n{figlet_render}", "Please read the following figlet font:\n\n{figlet_render}",
@@ -173,13 +42,123 @@ class FigletFontDataset(ProceduralDataset):
""" """
rng = Random(self.seed + idx) rng = Random(self.seed + idx)
word = self.config.static_word if self.config.static_word is not None else rng.choice(self.words).upper() word = self.config.static_word if self.config.static_word is not None else rng.choice(self.wordle_words).upper()
if self.config.space_letters: if self.config.space_letters:
render_word = " ".join(word) render_word = " ".join(word)
else: else:
render_word = word render_word = word
chosen_font = self.config.static_font if self.config.static_font is not None else rng.choice(OK_FONTS) # These ones are funky and probably aren't good for train/testing
bad_fonts = [
"pyramid",
"runyc",
"assalt_m",
"term",
"tengwar",
"heart_right",
"faces_of",
"heroboti",
"hieroglyphs",
"rainbow_",
"notie_ca",
"ghost",
"rampage_",
"atc_____",
"pacos_pe",
"mad_nurs",
"icl-1900",
"joust___",
"dcs_bfmo",
"letter_w",
"flyn_sh",
"fun_face",
"morse2",
"tecrvs__",
"ntgreek",
"tsalagi",
"etcrvs__",
"faces_of",
"future_8",
"efti_robot",
"danc4",
"p_s_h_m_",
"smkeyboard",
"konto",
"odel_lak",
"courb",
"jerusalem",
"nfi1____",
"keyboard",
"konto_slant" "rot13",
"mirror",
"katakana",
"cards",
"eftichess",
"heart_left",
"trashman",
"morse",
"eftipiti",
"smtengwar",
"e__fist_",
"mike",
"bear",
"hills___",
"rotated",
"wow",
"eftipiti",
"relief2",
"mshebrew210",
"kik_star",
"puzzle",
"p_skateb",
"hypa_bal",
"tomahawk",
"timesofl",
"moscow",
"cola",
"baz__bil",
"stencil1",
"battlesh",
"tsn_base",
"kgames_i",
"binary",
"greek",
"mnemonic",
"panther_",
"b1ff",
"c_consen",
"horizontal_right",
"dwhistled",
"hex",
"flipped",
"high_noo",
"patorjk-hex",
"amc_3_liv1",
"gauntlet",
"cybersmall",
"octal",
"js_cursive",
"battle_s",
"deep_str",
"rally_s2",
"convoy__",
"atc_gran",
"grand_pr",
"ivrit",
"rammstein",
"horizontal_left",
"eftiwall",
"decimal",
"goofy",
"rot13",
"konto_slant",
"subteran",
"rally_sp",
"charset_",
]
all_fonts = pyfiglet.FigletFont.getFonts()
ok_fonts = list(filter(lambda x: x not in bad_fonts, all_fonts))
chosen_font = self.config.static_font if self.config.static_font is not None else rng.choice(ok_fonts)
figlet_render = pyfiglet.figlet_format(render_word, font=chosen_font) figlet_render = pyfiglet.figlet_format(render_word, font=chosen_font)
return { return {
@@ -223,24 +202,5 @@ class FigletFontDataset(ProceduralDataset):
return score return score
class FigletFontCurriculum(BaseCurriculum):
def __init__(self):
super().__init__(FigletFontCurriculum.__name__, FigletFontConfig)
# Define attributes
self._define_attributes(
RangeAttributeDefinition(
name="word_len",
levels=[3, 5, 10, 15, 20, 30],
default_level=1,
description="The length of the word to be displayed",
attr_type=AttributeType.APPEND,
min_value=1,
lower_field_name="min_word_len",
upper_field_name="max_word_len",
),
)
# Register the dataset # Register the dataset
register_dataset("figlet_font", FigletFontDataset, FigletFontConfig, FigletFontCurriculum) register_dataset("figlet_font", FigletFontDataset, FigletFontConfig)

View File

@@ -1,13 +1,6 @@
import pytest import pytest
from reasoning_gym.cognition.figlet_fonts import FigletFontConfig, FigletFontCurriculum, FigletFontDataset from reasoning_gym.cognition.figlet_fonts import FigletFontConfig, FigletFontDataset
def test_figlet_config_validation():
"""Test that invalid configs raise appropriate errors"""
with pytest.raises(AssertionError):
config = FigletFontConfig(min_word_len=2, max_word_len=1) # max_word_len < min_word_len
config.validate()
def test_figlet_deterministic(): def test_figlet_deterministic():
@@ -48,24 +41,3 @@ def test_static_figlet():
assert dataset.score_answer(answer="TESTY", entry=item) == 1.0 assert dataset.score_answer(answer="TESTY", entry=item) == 1.0
assert dataset.score_answer(answer="WESTY", entry=item) == 0.4 assert dataset.score_answer(answer="WESTY", entry=item) == 0.4
assert dataset.score_answer(answer=None, entry=item) == 0 assert dataset.score_answer(answer=None, entry=item) == 0
def test_figlet_curriculum():
curriculum = FigletFontCurriculum()
base_value = {"size": 150, "seed": 1}
base_cfg: FigletFontConfig = curriculum.generate_configuration(base_value)
assert base_cfg.seed == 1
assert base_cfg.size == 150
assert base_cfg.min_word_len == 3 and base_cfg.max_word_len == 5
# test incrementing attribute levels
curriculum.increment_attr_level("word_len")
increased_cfg = curriculum.generate_configuration(base_value)
assert increased_cfg.min_word_len == 3 and increased_cfg.max_word_len == 10
# test decrementing attribute level
curriculum.decrement_attr_level("word_len")
partially_decreased_cfg = curriculum.generate_configuration(base_value)
assert partially_decreased_cfg.min_word_len == 3 and partially_decreased_cfg.max_word_len == 5