diff --git a/reasoning_gym/cognition/__init__.py b/reasoning_gym/cognition/__init__.py index 890ac81f..4cdb9164 100644 --- a/reasoning_gym/cognition/__init__.py +++ b/reasoning_gym/cognition/__init__.py @@ -3,7 +3,7 @@ Cognition tasks for training reasoning capabilities. """ from .color_cube_rotation import ColorCubeRotationConfig, ColorCubeRotationCurriculum, ColorCubeRotationDataset -from .figlet_fonts import FigletFontConfig, FigletFontCurriculum, FigletFontDataset +from .figlet_fonts import FigletFontConfig, FigletFontDataset from .modulo_grid import ModuloGridConfig, ModuloGridDataset from .needle_haystack import NeedleHaystackConfig, NeedleHaystackCurriculum, NeedleHaystackDataset from .number_sequences import NumberSequenceConfig, NumberSequenceCurriculum, NumberSequenceDataset @@ -16,7 +16,6 @@ __all__ = [ "ColorCubeRotationCurriculum", "FigletFontConfig", "FigletFontDataset", - "FigletFontCurriculum", "NumberSequenceConfig", "NumberSequenceDataset", "NumberSequenceCurriculum", diff --git a/reasoning_gym/cognition/figlet_fonts.py b/reasoning_gym/cognition/figlet_fonts.py index cb08919d..f4150fd5 100644 --- a/reasoning_gym/cognition/figlet_fonts.py +++ b/reasoning_gym/cognition/figlet_fonts.py @@ -1,125 +1,11 @@ -import json from dataclasses import dataclass from random import Random from typing import Any, Optional import pyfiglet -from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition -from ..data import get_data_file_path from ..factory import ProceduralDataset, register_dataset -# These ones are funky and probably aren't good for train/testing -BAD_FONTS = [ - "pyramid", - "runyc", - "assalt_m", - "term", - "tengwar", - "heart_right", - "faces_of", - "heroboti", - "hieroglyphs", - "rainbow_", - "notie_ca", - "ghost", - "rampage_", - "atc_____", - "pacos_pe", - "mad_nurs", - "icl-1900", - "joust___", - "dcs_bfmo", - "letter_w", - "flyn_sh", - "fun_face", - "morse2", - "tecrvs__", - "ntgreek", - "tsalagi", - "etcrvs__", - "faces_of", - "future_8", - "efti_robot", - "danc4", - "p_s_h_m_", - "smkeyboard", - "konto", - "odel_lak", - "courb", - "jerusalem", - "nfi1____", - "keyboard", - "konto_slant" "rot13", - "mirror", - "katakana", - "cards", - "eftichess", - "heart_left", - "trashman", - "morse", - "eftipiti", - "smtengwar", - "e__fist_", - "mike", - "bear", - "hills___", - "rotated", - "wow", - "eftipiti", - "relief2", - "mshebrew210", - "kik_star", - "puzzle", - "p_skateb", - "hypa_bal", - "tomahawk", - "timesofl", - "moscow", - "cola", - "baz__bil", - "stencil1", - "battlesh", - "tsn_base", - "kgames_i", - "binary", - "greek", - "mnemonic", - "panther_", - "b1ff", - "c_consen", - "horizontal_right", - "dwhistled", - "hex", - "flipped", - "high_noo", - "patorjk-hex", - "amc_3_liv1", - "gauntlet", - "cybersmall", - "octal", - "js_cursive", - "battle_s", - "deep_str", - "rally_s2", - "convoy__", - "atc_gran", - "grand_pr", - "ivrit", - "rammstein", - "horizontal_left", - "eftiwall", - "decimal", - "goofy", - "rot13", - "konto_slant", - "subteran", - "rally_sp", - "charset_", -] -ALL_FONTS = pyfiglet.FigletFont.getFonts() -OK_FONTS = list(filter(lambda x: x not in BAD_FONTS, ALL_FONTS)) - @dataclass class FigletFontConfig: @@ -127,35 +13,18 @@ class FigletFontConfig: static_word: Optional[str] = None static_font: Optional[str] = None - min_word_len: int = 3 - max_word_len: int = 7 space_letters: bool = True seed: Optional[int] = None size: int = 500 - def validate(self): - assert self.min_word_len > 0, "min_word_len must be greater than 0" - assert self.min_word_len <= self.max_word_len, "min_word_len must be less than or equal to max_word_len" - if self.static_word: - assert len(self.static_word) > 0, "static_word must have at least one character" - if self.static_font: - assert len(self.static_font) > 0, "static_font must have at least one character" - assert self.static_font in OK_FONTS, f"static_font must be one of {OK_FONTS}" - class FigletFontDataset(ProceduralDataset): """Generates FigletFont tasks""" def __init__(self, config: FigletFontConfig): - with get_data_file_path("anagrams.jsonl").open() as f: - self.words = [ - word - for line in f - for word in json.loads(line)["words"] - if config.min_word_len <= len(word) <= config.max_word_len - ] - assert len(self.words) > 0, "No words found in the dataset with the specified length range" + from ..data.wordle_words import wordle_words + self.wordle_words = wordle_words self._prompt_templates = [ "What word does this say?\n\n{figlet_render}", "Please read the following figlet font:\n\n{figlet_render}", @@ -173,13 +42,123 @@ class FigletFontDataset(ProceduralDataset): """ rng = Random(self.seed + idx) - word = self.config.static_word if self.config.static_word is not None else rng.choice(self.words).upper() + word = self.config.static_word if self.config.static_word is not None else rng.choice(self.wordle_words).upper() if self.config.space_letters: render_word = " ".join(word) else: render_word = word - chosen_font = self.config.static_font if self.config.static_font is not None else rng.choice(OK_FONTS) + # These ones are funky and probably aren't good for train/testing + bad_fonts = [ + "pyramid", + "runyc", + "assalt_m", + "term", + "tengwar", + "heart_right", + "faces_of", + "heroboti", + "hieroglyphs", + "rainbow_", + "notie_ca", + "ghost", + "rampage_", + "atc_____", + "pacos_pe", + "mad_nurs", + "icl-1900", + "joust___", + "dcs_bfmo", + "letter_w", + "flyn_sh", + "fun_face", + "morse2", + "tecrvs__", + "ntgreek", + "tsalagi", + "etcrvs__", + "faces_of", + "future_8", + "efti_robot", + "danc4", + "p_s_h_m_", + "smkeyboard", + "konto", + "odel_lak", + "courb", + "jerusalem", + "nfi1____", + "keyboard", + "konto_slant" "rot13", + "mirror", + "katakana", + "cards", + "eftichess", + "heart_left", + "trashman", + "morse", + "eftipiti", + "smtengwar", + "e__fist_", + "mike", + "bear", + "hills___", + "rotated", + "wow", + "eftipiti", + "relief2", + "mshebrew210", + "kik_star", + "puzzle", + "p_skateb", + "hypa_bal", + "tomahawk", + "timesofl", + "moscow", + "cola", + "baz__bil", + "stencil1", + "battlesh", + "tsn_base", + "kgames_i", + "binary", + "greek", + "mnemonic", + "panther_", + "b1ff", + "c_consen", + "horizontal_right", + "dwhistled", + "hex", + "flipped", + "high_noo", + "patorjk-hex", + "amc_3_liv1", + "gauntlet", + "cybersmall", + "octal", + "js_cursive", + "battle_s", + "deep_str", + "rally_s2", + "convoy__", + "atc_gran", + "grand_pr", + "ivrit", + "rammstein", + "horizontal_left", + "eftiwall", + "decimal", + "goofy", + "rot13", + "konto_slant", + "subteran", + "rally_sp", + "charset_", + ] + all_fonts = pyfiglet.FigletFont.getFonts() + ok_fonts = list(filter(lambda x: x not in bad_fonts, all_fonts)) + chosen_font = self.config.static_font if self.config.static_font is not None else rng.choice(ok_fonts) figlet_render = pyfiglet.figlet_format(render_word, font=chosen_font) return { @@ -223,24 +202,5 @@ class FigletFontDataset(ProceduralDataset): return score -class FigletFontCurriculum(BaseCurriculum): - def __init__(self): - super().__init__(FigletFontCurriculum.__name__, FigletFontConfig) - - # Define attributes - self._define_attributes( - RangeAttributeDefinition( - name="word_len", - levels=[3, 5, 10, 15, 20, 30], - default_level=1, - description="The length of the word to be displayed", - attr_type=AttributeType.APPEND, - min_value=1, - lower_field_name="min_word_len", - upper_field_name="max_word_len", - ), - ) - - # Register the dataset -register_dataset("figlet_font", FigletFontDataset, FigletFontConfig, FigletFontCurriculum) +register_dataset("figlet_font", FigletFontDataset, FigletFontConfig) diff --git a/tests/test_figlet_fonts.py b/tests/test_figlet_fonts.py index 915d1412..6f0aff40 100644 --- a/tests/test_figlet_fonts.py +++ b/tests/test_figlet_fonts.py @@ -1,13 +1,6 @@ import pytest -from reasoning_gym.cognition.figlet_fonts import FigletFontConfig, FigletFontCurriculum, FigletFontDataset - - -def test_figlet_config_validation(): - """Test that invalid configs raise appropriate errors""" - with pytest.raises(AssertionError): - config = FigletFontConfig(min_word_len=2, max_word_len=1) # max_word_len < min_word_len - config.validate() +from reasoning_gym.cognition.figlet_fonts import FigletFontConfig, FigletFontDataset def test_figlet_deterministic(): @@ -48,24 +41,3 @@ def test_static_figlet(): assert dataset.score_answer(answer="TESTY", entry=item) == 1.0 assert dataset.score_answer(answer="WESTY", entry=item) == 0.4 assert dataset.score_answer(answer=None, entry=item) == 0 - - -def test_figlet_curriculum(): - curriculum = FigletFontCurriculum() - - base_value = {"size": 150, "seed": 1} - - base_cfg: FigletFontConfig = curriculum.generate_configuration(base_value) - assert base_cfg.seed == 1 - assert base_cfg.size == 150 - assert base_cfg.min_word_len == 3 and base_cfg.max_word_len == 5 - - # test incrementing attribute levels - curriculum.increment_attr_level("word_len") - increased_cfg = curriculum.generate_configuration(base_value) - assert increased_cfg.min_word_len == 3 and increased_cfg.max_word_len == 10 - - # test decrementing attribute level - curriculum.decrement_attr_level("word_len") - partially_decreased_cfg = curriculum.generate_configuration(base_value) - assert partially_decreased_cfg.min_word_len == 3 and partially_decreased_cfg.max_word_len == 5