Files
reasoning-gym/reasoning_gym/logic/aiw.py
Oliver Stanley 1a727ecf4e support python 3.10 (#450)
* support python 3.10

* add 3.10 to tests

* new StrEnum
2025-06-04 10:34:01 +01:00

246 lines
9.6 KiB
Python

from dataclasses import dataclass, field
from random import Random
from string import Template
from typing import Optional
from ..coaching import BaseCurriculum, ScalarAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
from ..utils import StrEnum
DATASET_NAME = "aiw"
class TaskType(StrEnum):
"""Defines the type of task for the Alice in Wonderland dataset."""
SIBLINGS = "siblings"
FRIENDS = "friends"
COLLEAGUES = "colleagues" # Added colleagues task
@dataclass
class AliceInWonderlandConfig:
"""Configuration options for the Alice in Wonderland dataset.
Attributes:
male_names (list[str]): List of male names to use in questions.
female_names (list[str]): List of female names to use in questions. Must include 'Alice'.
task_types (list[TaskType]): List of task types to include in dataset.
seed (Optional[int]): Seed for random number generation.
size (int): Number of samples in the dataset.
max_entities (int): Max number of siblings/friends/colleagues in questions.
"""
male_names: list[str] = field(
default_factory=lambda: [
"James",
"John",
"Robert",
"Michael",
"William",
"David",
"Richard",
"Joseph",
"Thomas",
"Charles",
"Bob",
]
)
female_names: list[str] = field(
default_factory=lambda: [
"Mary",
"Patricia",
"Jennifer",
"Linda",
"Elizabeth",
"Barbara",
"Susan",
"Jessica",
"Sarah",
"Margaret",
"Alice",
]
)
task_types: list[TaskType] = field(
default_factory=lambda: [TaskType.SIBLINGS, TaskType.FRIENDS, TaskType.COLLEAGUES] # Added Colleagues
)
task_type_weights: list[float] = field(default_factory=lambda: [1 / 3, 1 / 3, 1 / 3])
seed: Optional[int] = None
size: int = 10
max_entities: int = 6 # Added max_entities
def validate(self) -> None:
"""Validates the configuration parameters."""
assert len(self.male_names) > 0, "must provide male names"
assert len(self.female_names) > 0, "must provide female names"
assert "Alice" in self.female_names, "'Alice' must be in female names"
assert len(self.task_types) > 0, "must provide at least one task type"
assert self.max_entities > 0, "max_entities must be positive"
class AliceInWonderlandDataset(ProceduralDataset):
"""
A procedural dataset inspired by the "Alice in Wonderland" paper.
The dataset is inspired by the following paper:
@inproceedings{nezhurina2024alice,
title={Alice in Wonderland: Simple Tasks Reveal Severe Generalization and
Basic Reasoning Deficits in State-Of-the-Art Large Language Models},
author={Marianna Nezhurina and Lucia Cipolina-Kun and Mehdi Cherti and
Jenia Jitsev},
booktitle={NeurIPS 2024 Workshop on Scientific Methods for Understanding
Deep Learning},
year={2024},
url={https://openreview.net/forum?id=Mkl7dzjYiW}
}
"""
def __init__(self, config: AliceInWonderlandConfig):
super().__init__(config=config, seed=config.seed, size=config.size)
self.templates = {
TaskType.SIBLINGS: [
Template(
"$female_name has $num_brothers brothers and she also has "
"$num_sisters sisters. How many sisters does "
"$female_name's brother have?"
),
Template(
"$female_name has $num_sisters sisters and she also has "
"$num_brothers brothers. How many sisters does "
"$male_name's brother have?"
),
],
TaskType.FRIENDS: [
Template(
"$female_name has $num_male male friends and she also has "
"$num_female female friends. They all are friends with each "
"other and have no other friends aside. How many female "
"friends does $male_name, a male friend of $female_name, "
"have?"
)
],
TaskType.COLLEAGUES: [ # New colleagues templates
Template(
"$female_name has $num_male_colleagues_alice_circle male colleagues and she also has "
"$num_female_colleagues_alice_circle female colleagues. These are all colleagues that $female_name has. "
"All these mentioned persons around $female_name are colleagues of each other. "
"$male_name has $num_male_colleagues_bob_circle male colleagues "
"and $num_female_colleagues_bob_circle female colleagues in total. "
"All these mentioned persons around $male_name are colleagues of each other. "
"The people in the circle around $male_name do not have "
"other colleagues aside - with the only exception of Matilda. "
"She is colleague of $male_name and she is also colleague of $female_name, "
"being part of $female_name's circle. How many female colleagues does Matilda have?"
),
],
}
def _get_aiw(self, rng: Random, idx: int) -> dict:
"""Generates a single Alice in Wonderland question.
Args:
rng (Random): Random number generator.
Returns:
dict: A dictionary containing the generated question, the right answer
and a description of the example.
"""
task_type = rng.choices(self.config.task_types, weights=self.config.task_type_weights, k=1)[0]
female_name = rng.choice(self.config.female_names)
male_name = rng.choice(self.config.male_names)
if task_type == TaskType.SIBLINGS:
num_brothers = rng.randint(1, self.config.max_entities)
num_sisters = rng.randint(1, self.config.max_entities)
answer = num_sisters + 1
template = rng.choice(self.templates[TaskType.SIBLINGS])
question = template.substitute(
female_name=female_name,
male_name=male_name,
num_brothers=num_brothers,
num_sisters=num_sisters,
)
elif task_type == TaskType.FRIENDS:
num_male = rng.randint(1, self.config.max_entities)
num_female = rng.randint(1, self.config.max_entities)
answer = num_female + 1
template = rng.choice(self.templates[TaskType.FRIENDS])
question = template.substitute(
female_name=female_name,
male_name=male_name,
num_male=num_male,
num_female=num_female,
)
elif task_type == TaskType.COLLEAGUES:
num_male_colleagues_alice_circle = rng.randint(1, self.config.max_entities)
num_female_colleagues_alice_circle = rng.randint(1, self.config.max_entities)
num_male_colleagues_bob_circle = rng.randint(1, self.config.max_entities)
num_female_colleagues_bob_circle = rng.randint(1, self.config.max_entities)
answer = num_female_colleagues_alice_circle + 1
template = rng.choice(self.templates[TaskType.COLLEAGUES])
question = template.substitute(
female_name=female_name,
male_name=male_name,
num_male_colleagues_alice_circle=num_male_colleagues_alice_circle,
num_female_colleagues_alice_circle=num_female_colleagues_alice_circle,
num_male_colleagues_bob_circle=num_male_colleagues_bob_circle,
num_female_colleagues_bob_circle=num_female_colleagues_bob_circle,
)
return {
"question": question,
"answer": str(answer),
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"task_type": task_type.value,
"difficulty": {
"task_type_weights": self.config.task_type_weights,
"num_entities": self.config.max_entities,
},
},
}
def __getitem__(self, idx: int) -> dict:
rng = Random(self.seed + idx)
return self._get_aiw(rng, idx)
class AliceInWonderlandCurriculum(BaseCurriculum):
"""Curriculum for the Alice in Wonderland dataset."""
def __init__(self):
super().__init__(AliceInWonderlandCurriculum.__name__, AliceInWonderlandConfig)
self._define_attributes(
ScalarAttributeDefinition(
name="task_type_weights",
field_name="task_type_weights",
description="The weight of the task type",
levels=[
[1.0, 0.0, 0.0],
[0.9, 0.05, 0.05],
[0.7, 0.15, 0.15],
[0.6, 0.2, 0.2],
[0.5, 0.25, 0.25],
[0.4, 0.3, 0.3],
[0.3, 0.35, 0.35],
[0.2, 0.4, 0.4],
[0.1, 0.45, 0.45],
],
),
ScalarAttributeDefinition(
name="num_entities",
field_name="max_entities",
description="The number of entities in the question",
levels=list(range(4, 18, 2)),
),
)
register_dataset(DATASET_NAME, AliceInWonderlandDataset, AliceInWonderlandConfig, AliceInWonderlandCurriculum)