Files
reasoning-gym/notebooks/check-collisions-in-reasoning-gym-dataset.ipynb
Adefioye 5b653b346c Data collisions notebooks and data (#406)
* Add collisions data

* Fix logic issues in basic_arithmetic and gsm_symbolic data
2025-04-02 09:36:09 +02:00

419 lines
73 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"id": "e81636e0-fa3c-462d-ae2e-2bb35cafd544",
"metadata": {},
"source": [
"## Investigating collisions in reasoning-gym datasets\n",
"\n",
"This notebook checks for collisions (identical questions) between training and validation datasets generated from the same reasoning-gym task with different seeds, as they are intended to be used for RL training."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "42323371-404e-4e86-b8b8-a420b4c79303",
"metadata": {},
"outputs": [],
"source": [
"import reasoning_gym"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "f06e7932-6c77-4609-8a33-7c4d815841d6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total number of data: 15\n"
]
}
],
"source": [
"# Read the dataset names to check, one per line. Blank lines are skipped so\n",
"# that a stray empty line never reaches reasoning_gym.create_dataset as \"\".\n",
"with open(\"data.txt\") as f:\n",
"    data_names = [line.strip() for line in f if line.strip()]\n",
"print(\"Total number of data: \", len(data_names))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "d7a5a5bf-7428-46f5-a7f5-a46238df2543",
"metadata": {},
"outputs": [],
"source": [
"# Number of items generated per dataset and per seed in the collision check.\n",
"TOTAL = 10000"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7138aced-d61a-4e2a-9935-a9b251e6d554",
"metadata": {},
"outputs": [],
"source": [
"%%time\n",
"# Seeds for the simulated training / validation splits of each dataset.\n",
"TRAIN_SEED, VAL_SEED = 1, 2\n",
"# One \"name, count\" line is appended per dataset, so several runs over\n",
"# different data.txt batches accumulate in the same file.\n",
"# NOTE(review): the report section reads collisions.txt, not this file - confirm how they are merged.\n",
"OUTPUT_PATH = \"collisions_1.txt\"\n",
"\n",
"with open(OUTPUT_PATH, \"a\") as file:\n",
"    for name in data_names:\n",
"        data_1 = reasoning_gym.create_dataset(name, size=TOTAL, seed=TRAIN_SEED)\n",
"        data_2 = reasoning_gym.create_dataset(name, size=TOTAL, seed=VAL_SEED)\n",
"\n",
"        # A collision is any validation question that also appears anywhere in the\n",
"        # training set. Comparing item i with item i (zip) misses duplicates that\n",
"        # land at different positions.\n",
"        # NOTE(review): if items are seeded from seed + index, seeds 1 and 2 produce the\n",
"        # same items shifted by one position, which a zip comparison never detects.\n",
"        train_questions = {item[\"question\"] for item in data_1}\n",
"        count = sum(item[\"question\"] in train_questions for item in data_2)\n",
"\n",
"        file.write(f\"{name}, {count}\\n\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "73297817-0eda-4bb4-9850-1375b2fcfa4d",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"id": "86d9535a-73fd-4930-a2df-393efd73cb75",
"metadata": {},
"source": [
"# Report on the generated collision data"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "d41efd7a-3d23-4433-98d5-617b5aa66f07",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"import pandas as pd"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "e732e125-e232-4e1a-bc66-1ee37405d6ff",
"metadata": {},
"outputs": [],
"source": [
"# Each line has the form \"name, count\". Blank lines (e.g. left over from\n",
"# concatenating several result files) are skipped: they would otherwise become\n",
"# [''] and break the two-field unpacking in the cleaning step.\n",
"with open('collisions.txt', 'r') as file:\n",
"    data = [line.strip().split(\",\") for line in file if line.strip()]"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "77a90340-f760-41d4-ab2a-0b14b6dbfc08",
"metadata": {},
"outputs": [],
"source": [
"# Clean data: strip the whitespace that follows the comma in each \"name, count\" row\n",
"data = [(dataset_name, raw_count.strip()) for dataset_name, raw_count in data]"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "36054684-470c-4ea4-8a60-25efb8218926",
"metadata": {},
"outputs": [],
"source": [
"# `data` already holds (name, count) pairs, so it is passed to the constructor directly.\n",
"df = pd.DataFrame(data, columns=[\"name\", \"collisions\"])"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "0d2e6752-8d04-46c9-8057-3e71a39819f9",
"metadata": {},
"outputs": [],
"source": [
"# Convert the collision counts from text to integers\n",
"df = df.astype({\"collisions\": int})"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "a2b05dc7-4319-4709-ae57-b3a6637bc66a",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>name</th>\n",
" <th>collisions</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>complex_arithmetic</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>intermediate_integration</td>\n",
" <td>12</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>polynomial_equations</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>polynomial_multiplication</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>simple_equations</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" name collisions\n",
"0 complex_arithmetic 0\n",
"1 intermediate_integration 12\n",
"2 polynomial_equations 0\n",
"3 polynomial_multiplication 0\n",
"4 simple_equations 0"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Preview the first five rows of the collision table\n",
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "f817273d-3875-4b7f-9f7e-a8ea2cb0b455",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total number of data: 86\n"
]
}
],
"source": [
"# Number of datasets (rows) in the collision report\n",
"print(\"Total number of data: \", df.shape[0])"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "5c056fd2-2f05-443e-8ccc-17811d810852",
"metadata": {},
"outputs": [],
"source": [
"# Keep only the datasets whose collision count is exactly zero\n",
"df_zero_collision = df.loc[df[\"collisions\"].eq(0)]"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "6c8df99f-9d20-4d06-852f-98e4576e1a1d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The number of data with zero collision: 66\n",
"Percentage of data with zero collision: 0.7674418604651163\n"
]
}
],
"source": [
"# The ratio is scaled by 100 so the printed value matches its \"Percentage\" label.\n",
"print(\"The number of data with zero collision: \", len(df_zero_collision))\n",
"print(f\"Percentage of data with zero collision: {100 * len(df_zero_collision) / len(df):.2f}%\")"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "f09f6190-0436-4a76-9acf-0f775d0b8c2a",
"metadata": {},
"outputs": [],
"source": [
"# Keep only the datasets with at least one collision\n",
"df_non_zero_collision = df.loc[df[\"collisions\"].gt(0)]"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "9a260c44-e492-43e5-8e55-77b78457e142",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The number of data with non-zero collision: 20\n",
"Percentage of data with non-zero collision: 0.23255813953488372\n"
]
}
],
"source": [
"# The ratio is scaled by 100 so the printed value matches its \"Percentage\" label.\n",
"print(\"The number of data with non-zero collision: \", len(df_non_zero_collision))\n",
"print(f\"Percentage of data with non-zero collision: {100 * len(df_non_zero_collision) / len(df):.2f}%\")"
]
},
{
"cell_type": "markdown",
"id": "5a59e38b-5914-4dbd-b6fc-4be425f28d4b",
"metadata": {},
"source": [
"## Visualize datasets with non-zero collisions"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "1906a0a9-264c-4b45-8def-e94ed69a7c44",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1200x600 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Sort so the datasets with the most collisions appear at the top of the chart\n",
"df_non_zero_collision = df_non_zero_collision.sort_values(by=\"collisions\", ascending=False)\n",
"\n",
"fig, ax = plt.subplots(figsize=(12, 6))\n",
"sns.barplot(\n",
"    data=df_non_zero_collision,\n",
"    x=\"collisions\",\n",
"    y=\"name\",\n",
"    hue=\"name\",  # hue mirrors y only to color each bar individually\n",
"    palette=\"viridis\",\n",
"    legend=False,  # the legend would just repeat the y-axis labels\n",
"    ax=ax,\n",
")\n",
"\n",
"# Write each count just right of its bar; assumes bars are drawn in row order (0, 1, ...)\n",
"for bar_position, n_collisions in enumerate(df_non_zero_collision[\"collisions\"]):\n",
"    ax.text(n_collisions + 0.5, bar_position, str(n_collisions), va='center', fontsize=10, color='black')\n",
"\n",
"ax.set_xlabel(\"Collisions\")\n",
"ax.set_ylabel(\"Dataset Name\")\n",
"ax.set_title(\"Number of Collisions Per Reasoning-gym dataset\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "f3d1b1f6-bcf6-4f70-9168-33b3179ae171",
"metadata": {},
"outputs": [],
"source": [
"# Dataset not yet done, \n",
"# futoshiki\n",
"# knight_swap\n",
"# mahjong_puzzle\n",
"# maze\n",
"# mini_sudoku\n",
"# n_queens\n",
"# puzzle24\n",
"# rush_hour\n",
"# sokoban\n",
"# sudoku\n",
"# tower_of_hanoi\n",
"# tsumego\n",
"# zebra_puzzles"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b3c11508-0a4a-4578-82f5-3908c5a45604",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "51263e94-0867-4b2a-b205-ebafb316f811",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}