mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2025-10-09 13:40:09 +03:00
246 lines
9.6 KiB
Python
246 lines
9.6 KiB
Python
from dataclasses import dataclass, field
|
|
from random import Random
|
|
from string import Template
|
|
from typing import Optional
|
|
|
|
from ..coaching import BaseCurriculum, ScalarAttributeDefinition
|
|
from ..factory import ProceduralDataset, register_dataset
|
|
from ..utils import StrEnum
|
|
|
|
# Identifier of this dataset: used as the key when registering the dataset and
# written into every generated item's metadata under "source_dataset".
DATASET_NAME = "aiw"
|
|
|
|
|
|
class TaskType(StrEnum):
|
|
"""Defines the type of task for the Alice in Wonderland dataset."""
|
|
|
|
SIBLINGS = "siblings"
|
|
FRIENDS = "friends"
|
|
COLLEAGUES = "colleagues" # Added colleagues task
|
|
|
|
|
|
@dataclass
class AliceInWonderlandConfig:
    """Configuration options for the Alice in Wonderland dataset.

    Attributes:
        male_names (list[str]): List of male names to use in questions.
        female_names (list[str]): List of female names to use in questions. Must include 'Alice'.
        task_types (list[TaskType]): List of task types to include in dataset.
        task_type_weights (list[float]): Sampling weight of each entry in
            ``task_types``, in the same order. Must have exactly one
            non-negative weight per task type and a positive sum.
        seed (Optional[int]): Seed for random number generation.
        size (int): Number of samples in the dataset.
        max_entities (int): Max number of siblings/friends/colleagues in questions.
    """

    male_names: list[str] = field(
        default_factory=lambda: [
            "James",
            "John",
            "Robert",
            "Michael",
            "William",
            "David",
            "Richard",
            "Joseph",
            "Thomas",
            "Charles",
            "Bob",
        ]
    )
    female_names: list[str] = field(
        default_factory=lambda: [
            "Mary",
            "Patricia",
            "Jennifer",
            "Linda",
            "Elizabeth",
            "Barbara",
            "Susan",
            "Jessica",
            "Sarah",
            "Margaret",
            "Alice",
        ]
    )
    task_types: list[TaskType] = field(
        default_factory=lambda: [TaskType.SIBLINGS, TaskType.FRIENDS, TaskType.COLLEAGUES]
    )
    task_type_weights: list[float] = field(default_factory=lambda: [1 / 3, 1 / 3, 1 / 3])
    seed: Optional[int] = None
    size: int = 10
    max_entities: int = 6

    def validate(self) -> None:
        """Validates the configuration parameters.

        Raises:
            AssertionError: If a name list or the task type list is empty,
                'Alice' is missing from the female names, ``max_entities`` is
                not positive, or ``task_type_weights`` cannot be used to sample
                from ``task_types``.
        """
        assert len(self.male_names) > 0, "must provide male names"
        assert len(self.female_names) > 0, "must provide female names"
        assert "Alice" in self.female_names, "'Alice' must be in female names"
        assert len(self.task_types) > 0, "must provide at least one task type"
        # The weights are passed to Random.choices() together with task_types;
        # a length mismatch or an all-zero total would only fail later, while
        # generating an item. Reject such configurations up front.
        assert len(self.task_type_weights) == len(
            self.task_types
        ), "task_type_weights must contain exactly one weight per task type"
        assert all(weight >= 0 for weight in self.task_type_weights), "task_type_weights must be non-negative"
        assert sum(self.task_type_weights) > 0, "task_type_weights must not all be zero"
        assert self.max_entities > 0, "max_entities must be positive"
|
|
|
|
|
|
class AliceInWonderlandDataset(ProceduralDataset):
    """
    A procedural dataset inspired by the "Alice in Wonderland" paper.

    Every item is a short relationship-counting question (siblings, friends or
    colleagues) built from a string template, together with its numeric answer
    rendered as a string.

    The dataset is inspired by the following paper:
       @inproceedings{nezhurina2024alice,
       title={Alice in Wonderland: Simple Tasks Reveal Severe Generalization and
              Basic Reasoning Deficits in State-Of-the-Art Large Language Models},
       author={Marianna Nezhurina and Lucia Cipolina-Kun and Mehdi Cherti and
              Jenia Jitsev},
       booktitle={NeurIPS 2024 Workshop on Scientific Methods for Understanding
                  Deep Learning},
       year={2024},
       url={https://openreview.net/forum?id=Mkl7dzjYiW}
       }

    """

    def __init__(self, config: AliceInWonderlandConfig):
        """Stores the configuration and builds the question templates per task type.

        Args:
            config (AliceInWonderlandConfig): Dataset configuration; its ``seed``
                and ``size`` are forwarded to the base class.
        """
        super().__init__(config=config, seed=config.seed, size=config.size)
        self.templates = {
            TaskType.SIBLINGS: [
                Template(
                    "$female_name has $num_brothers brothers and she also has "
                    "$num_sisters sisters. How many sisters does "
                    "$female_name's brother have?"
                ),
                # NOTE(review): this variant asks about "$male_name's brother"
                # without introducing $male_name first - presumably he is meant
                # to be one of the brothers; confirm the intended wording.
                Template(
                    "$female_name has $num_sisters sisters and she also has "
                    "$num_brothers brothers. How many sisters does "
                    "$male_name's brother have?"
                ),
            ],
            TaskType.FRIENDS: [
                Template(
                    "$female_name has $num_male male friends and she also has "
                    "$num_female female friends. They all are friends with each "
                    "other and have no other friends aside. How many female "
                    "friends does $male_name, a male friend of $female_name, "
                    "have?"
                )
            ],
            TaskType.COLLEAGUES: [
                Template(
                    "$female_name has $num_male_colleagues_alice_circle male colleagues and she also has "
                    "$num_female_colleagues_alice_circle female colleagues. These are all colleagues that $female_name has. "
                    "All these mentioned persons around $female_name are colleagues of each other. "
                    "$male_name has $num_male_colleagues_bob_circle male colleagues "
                    "and $num_female_colleagues_bob_circle female colleagues in total. "
                    "All these mentioned persons around $male_name are colleagues of each other. "
                    "The people in the circle around $male_name do not have "
                    "other colleagues aside - with the only exception of Matilda. "
                    "She is colleague of $male_name and she is also colleague of $female_name, "
                    "being part of $female_name's circle. How many female colleagues does Matilda have?"
                ),
            ],
        }

    def _get_aiw(self, rng: Random, idx: int) -> dict:
        """Generates a single Alice in Wonderland question.

        The order of the ``rng`` calls (task type, female name, male name,
        counts, template) determines the generated item for a given seed and
        must not be changed.

        Args:
            rng (Random): Random number generator.
            idx (int): Index of the item; recorded as ``source_index`` in the metadata.

        Returns:
            dict: A dictionary containing the generated question, the right answer
                (as a string) and metadata describing the example.

        Raises:
            ValueError: If the sampled task type is not one of the supported ``TaskType`` members.
        """
        task_type = rng.choices(self.config.task_types, weights=self.config.task_type_weights, k=1)[0]
        female_name = rng.choice(self.config.female_names)
        male_name = rng.choice(self.config.male_names)

        if task_type == TaskType.SIBLINGS:
            num_brothers = rng.randint(1, self.config.max_entities)
            num_sisters = rng.randint(1, self.config.max_entities)

            # A brother's sisters are her sisters plus herself.
            answer = num_sisters + 1
            template = rng.choice(self.templates[TaskType.SIBLINGS])
            question = template.substitute(
                female_name=female_name,
                male_name=male_name,
                num_brothers=num_brothers,
                num_sisters=num_sisters,
            )
        elif task_type == TaskType.FRIENDS:
            num_male = rng.randint(1, self.config.max_entities)
            num_female = rng.randint(1, self.config.max_entities)

            # A male friend's female friends are her female friends plus herself.
            answer = num_female + 1
            template = rng.choice(self.templates[TaskType.FRIENDS])
            question = template.substitute(
                female_name=female_name,
                male_name=male_name,
                num_male=num_male,
                num_female=num_female,
            )
        elif task_type == TaskType.COLLEAGUES:
            num_male_colleagues_alice_circle = rng.randint(1, self.config.max_entities)
            num_female_colleagues_alice_circle = rng.randint(1, self.config.max_entities)
            num_male_colleagues_bob_circle = rng.randint(1, self.config.max_entities)
            num_female_colleagues_bob_circle = rng.randint(1, self.config.max_entities)

            # Matilda is counted among the female colleagues of both named
            # persons. Her own female colleagues are therefore:
            #   - $female_name herself,
            #   - the other (N - 1) women of $female_name's circle,
            #   - the other (M - 1) women of $male_name's circle,
            # i.e. N + M - 1 in total.
            answer = num_female_colleagues_alice_circle + num_female_colleagues_bob_circle - 1
            template = rng.choice(self.templates[TaskType.COLLEAGUES])
            question = template.substitute(
                female_name=female_name,
                male_name=male_name,
                num_male_colleagues_alice_circle=num_male_colleagues_alice_circle,
                num_female_colleagues_alice_circle=num_female_colleagues_alice_circle,
                num_male_colleagues_bob_circle=num_male_colleagues_bob_circle,
                num_female_colleagues_bob_circle=num_female_colleagues_bob_circle,
            )
        else:
            # Without this branch an unknown task type would surface as an
            # opaque UnboundLocalError on `question` below.
            raise ValueError(f"Unsupported task type: {task_type!r}")

        return {
            "question": question,
            "answer": str(answer),
            "metadata": {
                "source_dataset": DATASET_NAME,
                "source_index": idx,
                "task_type": task_type.value,
                "difficulty": {
                    "task_type_weights": self.config.task_type_weights,
                    "num_entities": self.config.max_entities,
                },
            },
        }

    def __getitem__(self, idx: int) -> dict:
        """Returns the item at ``idx``, generated from an RNG seeded with ``self.seed + idx``.

        Args:
            idx (int): Index of the item to generate.

        Returns:
            dict: The generated item (see ``_get_aiw``).
        """
        rng = Random(self.seed + idx)
        return self._get_aiw(rng, idx)
|
|
|
|
|
|
class AliceInWonderlandCurriculum(BaseCurriculum):
    """Curriculum describing the difficulty levels of the Alice in Wonderland dataset.

    Two attributes are defined: the sampling weights of the task types and the
    upper bound on the number of entities mentioned in a question.
    """

    def __init__(self):
        super().__init__(AliceInWonderlandCurriculum.__name__, AliceInWonderlandConfig)

        # One weight triple per level; each triple is written to the config's
        # `task_type_weights` field. The first weight shrinks from level to
        # level while the other two grow.
        weight_levels = [
            [1.0, 0.0, 0.0],
            [0.9, 0.05, 0.05],
            [0.7, 0.15, 0.15],
            [0.6, 0.2, 0.2],
            [0.5, 0.25, 0.25],
            [0.4, 0.3, 0.3],
            [0.3, 0.35, 0.35],
            [0.2, 0.4, 0.4],
            [0.1, 0.45, 0.45],
        ]
        # Even values from 4 up to and including 16, written to `max_entities`.
        entity_levels = [4, 6, 8, 10, 12, 14, 16]

        weights_attribute = ScalarAttributeDefinition(
            name="task_type_weights",
            field_name="task_type_weights",
            description="The weight of the task type",
            levels=weight_levels,
        )
        entities_attribute = ScalarAttributeDefinition(
            name="num_entities",
            field_name="max_entities",
            description="The number of entities in the question",
            levels=entity_levels,
        )
        self._define_attributes(weights_attribute, entities_attribute)
|
|
|
|
|
|
# Register the dataset class together with its config and curriculum classes
# under DATASET_NAME.
register_dataset(DATASET_NAME, AliceInWonderlandDataset, AliceInWonderlandConfig, AliceInWonderlandCurriculum)
|