diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 157c993d..61380122 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -15,7 +15,7 @@ jobs: pull-requests: write strategy: matrix: - python-version: ["3.11", "3.12"] + python-version: ["3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v4 diff --git a/pyproject.toml b/pyproject.toml index 3129330b..d7b53aca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ authors = [ ] description = "A library of procedural dataset generators for training reasoning models" readme = "README.md" -requires-python = ">=3.11" +requires-python = ">=3.10" dependencies = [ "bfi==1.0.4", "cellpylib==2.4.0", diff --git a/reasoning_gym/algorithmic/word_sorting.py b/reasoning_gym/algorithmic/word_sorting.py index 6e4d1fbc..b83c8807 100644 --- a/reasoning_gym/algorithmic/word_sorting.py +++ b/reasoning_gym/algorithmic/word_sorting.py @@ -2,13 +2,13 @@ import re from dataclasses import dataclass -from enum import StrEnum from random import Random from typing import Any, Optional from ..coaching import BaseCurriculum, RangeAttributeDefinition from ..data import read_data_file from ..factory import ProceduralDataset, register_dataset +from ..utils import StrEnum class TextTransformation(StrEnum): diff --git a/reasoning_gym/arithmetic/calendar_arithmetic.py b/reasoning_gym/arithmetic/calendar_arithmetic.py index 40a223bd..023432d5 100644 --- a/reasoning_gym/arithmetic/calendar_arithmetic.py +++ b/reasoning_gym/arithmetic/calendar_arithmetic.py @@ -3,11 +3,12 @@ import math import random from dataclasses import dataclass from datetime import date, timedelta -from enum import Enum, StrEnum, auto +from enum import Enum, auto from typing import Any, Optional from ..coaching import BaseCurriculum, ScalarAttributeDefinition from ..factory import ProceduralDataset, register_dataset +from ..utils import StrEnum DATASET_NAME = "calendar_arithmetic" diff --git a/reasoning_gym/coaching/base_curriculum.py b/reasoning_gym/coaching/base_curriculum.py index f64f1234..2a6553c5 100644 --- a/reasoning_gym/coaching/base_curriculum.py +++ b/reasoning_gym/coaching/base_curriculum.py @@ -1,8 +1,8 @@ import abc from collections.abc import Iterable -from enum import StrEnum from typing import Any, Optional, TypeVar +from ..utils import StrEnum from .attributes import AttributeDefinition, RangeAttributeDefinition, ScalarAttributeDefinition ConfigT = TypeVar("ConfigT") diff --git a/reasoning_gym/cognition/color_cube_rotation.py b/reasoning_gym/cognition/color_cube_rotation.py index 335c4b7e..44100bfb 100644 --- a/reasoning_gym/cognition/color_cube_rotation.py +++ b/reasoning_gym/cognition/color_cube_rotation.py @@ -1,10 +1,10 @@ import random from dataclasses import dataclass -from enum import StrEnum from typing import Any, Optional from ..coaching import BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset +from ..utils import StrEnum class Color(StrEnum): diff --git a/reasoning_gym/cognition/number_sequences.py b/reasoning_gym/cognition/number_sequences.py index 170be7a9..6fd6d503 100644 --- a/reasoning_gym/cognition/number_sequences.py +++ b/reasoning_gym/cognition/number_sequences.py @@ -1,10 +1,10 @@ from dataclasses import dataclass -from enum import StrEnum from random import Random from typing import Optional from ..coaching import BaseCurriculum, RangeAttributeDefinition, ScalarAttributeDefinition from ..factory import ProceduralDataset, register_dataset +from ..utils import StrEnum DATASET_NAME = "number_sequence" diff --git a/reasoning_gym/graphs/family_relationships.py b/reasoning_gym/graphs/family_relationships.py index 5d52cb80..1de0c1c8 100644 --- a/reasoning_gym/graphs/family_relationships.py +++ b/reasoning_gym/graphs/family_relationships.py @@ -1,11 +1,11 @@ import random from dataclasses import dataclass, field -from enum import StrEnum from itertools import count from typing import Any, Optional from ..coaching import BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset +from ..utils import StrEnum DATASET_NAME = "family_relationships" diff --git a/reasoning_gym/logic/aiw.py b/reasoning_gym/logic/aiw.py index 2cdb299e..e6c1109f 100644 --- a/reasoning_gym/logic/aiw.py +++ b/reasoning_gym/logic/aiw.py @@ -1,11 +1,11 @@ from dataclasses import dataclass, field -from enum import StrEnum from random import Random from string import Template from typing import Optional -from ..coaching import BaseCurriculum, RangeAttributeDefinition, ScalarAttributeDefinition +from ..coaching import BaseCurriculum, ScalarAttributeDefinition from ..factory import ProceduralDataset, register_dataset +from ..utils import StrEnum DATASET_NAME = "aiw" diff --git a/reasoning_gym/logic/propositional_logic.py b/reasoning_gym/logic/propositional_logic.py index d3c6479c..961f94f3 100644 --- a/reasoning_gym/logic/propositional_logic.py +++ b/reasoning_gym/logic/propositional_logic.py @@ -2,12 +2,12 @@ import re from dataclasses import dataclass -from enum import StrEnum from random import Random from typing import Any, Optional from ..coaching import BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset +from ..utils import StrEnum DATASET_NAME = "propositional_logic" diff --git a/reasoning_gym/logic/syllogisms.py b/reasoning_gym/logic/syllogisms.py index 600f09bc..6a56594f 100644 --- a/reasoning_gym/logic/syllogisms.py +++ b/reasoning_gym/logic/syllogisms.py @@ -1,12 +1,12 @@ """Syllogism reasoning task generator""" from dataclasses import dataclass -from enum import StrEnum from random import Random from typing import Optional from ..coaching import BaseCurriculum, ScalarAttributeDefinition from ..factory import ProceduralDataset, register_dataset +from ..utils import StrEnum DATASET_NAME = "syllogism" diff --git a/reasoning_gym/utils.py b/reasoning_gym/utils.py index ad0a5472..83ab6d46 100644 --- a/reasoning_gym/utils.py +++ b/reasoning_gym/utils.py @@ -1,6 +1,7 @@ import math import re from decimal import Decimal, InvalidOperation +from enum import Enum from fractions import Fraction from typing import Any, Optional, Union @@ -117,3 +118,38 @@ def compute_decimal_reward(answer: Optional[str], oracle_answer: str, strip_comm reward = len(oracle_answer) / len(answer) return reward + + +class StrEnum(str, Enum): + """ + Taken from Python 3.11 StrEnum implementation, moved here to support Python 3.10. + Enum where members are also (and must be) strings + """ + + def __new__(cls, *values): + "values must already be of type `str`" + if len(values) > 3: + raise TypeError("too many arguments for str(): %r" % (values,)) + if len(values) == 1: + # it must be a string + if not isinstance(values[0], str): + raise TypeError("%r is not a string" % (values[0],)) + if len(values) >= 2: + # check that encoding argument is a string + if not isinstance(values[1], str): + raise TypeError("encoding must be a string, not %r" % (values[1],)) + if len(values) == 3: + # check that errors argument is a string + if not isinstance(values[2], str): + raise TypeError("errors must be a string, not %r" % (values[2])) + value = str(*values) + member = str.__new__(cls, value) + member._value_ = value + return member + + @staticmethod + def _generate_next_value_(name, start, count, last_values): + """ + Return the lower-cased version of the member name. + """ + return name.lower()