mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2025-10-09 13:40:09 +03:00
First version of CodeI/O reasoning data (#264)
* notebook for prepping first set of raw code files * updated codeio processing notebook for repo-level processing * fix for edge case in codeio scoring * Add reformat notebook * filtering pass * add non-determinism filtering * Tweak CodeIODataset & include first real data * add basic codeio test, metadata
This commit is contained in:
1
notebooks/codeio/.gitignore
vendored
1
notebooks/codeio/.gitignore
vendored
@@ -1 +1,2 @@
|
||||
raw_files/
|
||||
output/
|
||||
|
||||
File diff suppressed because one or more lines are too long
694
notebooks/codeio/ReformatAndFilter.ipynb
Normal file
694
notebooks/codeio/ReformatAndFilter.ipynb
Normal file
@@ -0,0 +1,694 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Reformat the output JSON & code from the preprocessing step in `notebooks/codeio/PreprocessCode.ipynb`.\n",
|
||||
"\n",
|
||||
"The output format will align with the data we extract from existing CodeI/O dataset, in `notebooks/codeio.ipynb`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import json\n",
|
||||
"from pathlib import Path\n",
|
||||
"\n",
|
||||
"with open(Path(\"output/processed_code.jsonl\"), \"r\") as f:\n",
|
||||
" samples = [json.loads(line) for line in f]\n",
|
||||
"\n",
|
||||
"for sample in samples:\n",
|
||||
" main_code = sample[\"reference_code\"]\n",
|
||||
" del sample[\"reference_code\"]\n",
|
||||
" if \"def main(\" in main_code:\n",
|
||||
" main_code = main_code.replace(\"def main(\", \"def main_solution(\")\n",
|
||||
" sample[\"code_sample\"] = main_code\n",
|
||||
"\n",
|
||||
" input_generator = sample[\"input_generator\"]\n",
|
||||
" if \"def input_generator()\" in input_generator:\n",
|
||||
" input_generator = input_generator.replace(\"def input_generator()\", \"def generate_inputs(random: Random)\")\n",
|
||||
" if \"import random\" in input_generator:\n",
|
||||
" input_generator = input_generator.replace(\"import random\\n \", \"\").replace(\"import random\\n\", \"\")\n",
|
||||
" sample[\"input_generator\"] = input_generator\n",
|
||||
"\n",
|
||||
" sample[\"input_output_spec\"] = sample[\"parameters\"]\n",
|
||||
" del sample[\"parameters\"]\n",
|
||||
"\n",
|
||||
" sample[\"task_description\"] = sample[\"query\"]\n",
|
||||
" del sample[\"query\"]\n",
|
||||
"\n",
|
||||
"with open(Path(\"output/formatted_code.jsonl\"), \"w\") as f:\n",
|
||||
" for sample in samples:\n",
|
||||
" f.write(json.dumps(sample) + \"\\n\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now we need to filter out unsuitable samples from the data. First we prioritise samples which are inherently random, reliant on external services (e.g. network requests), or whose input generators do not match the correct random usage requirements, as this could cause irreproducibility in RL training."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Removing sample 6 due to bad input generator\n",
|
||||
"Removing sample 8 due to bad input generator\n",
|
||||
"Removing sample 28 due to bad input generator\n",
|
||||
"Removing sample 30 due to bad input generator\n",
|
||||
"Removing sample 39 due to bad main solution\n",
|
||||
"Removing sample 43 due to bad main solution\n",
|
||||
"Removing sample 47 due to bad main solution\n",
|
||||
"Removing sample 53 due to bad input generator\n",
|
||||
"Removing sample 59 due to bad input generator\n",
|
||||
"Removing sample 64 due to bad main solution\n",
|
||||
"Removing sample 87 due to bad main solution\n",
|
||||
"Removing sample 112 due to bad main solution\n",
|
||||
"Removing sample 116 due to bad main solution\n",
|
||||
"Removing sample 121 due to bad input generator\n",
|
||||
"Removing sample 141 due to bad main solution\n",
|
||||
"Removing sample 144 due to bad main solution\n",
|
||||
"Removing sample 150 due to bad main solution\n",
|
||||
"Removing sample 155 due to bad main solution\n",
|
||||
"Removing sample 159 due to bad main solution\n",
|
||||
"Removing sample 162 due to bad input generator\n",
|
||||
"Removing sample 168 due to bad input generator\n",
|
||||
"Removing sample 170 due to bad main solution\n",
|
||||
"Removing sample 189 due to bad input generator\n",
|
||||
"Removing sample 206 due to bad input generator\n",
|
||||
"Removing sample 236 due to bad main solution\n",
|
||||
"Removing sample 245 due to bad main solution\n",
|
||||
"Removing sample 253 due to bad main solution\n",
|
||||
"Removing sample 255 due to bad main solution\n",
|
||||
"Removing sample 279 due to bad main solution\n",
|
||||
"Removing sample 320 due to bad input generator\n",
|
||||
"Removing sample 324 due to bad main solution\n",
|
||||
"Removing sample 339 due to bad main solution\n",
|
||||
"Removing sample 346 due to bad main solution\n",
|
||||
"Removing sample 371 due to bad input generator\n",
|
||||
"Removing sample 372 due to bad input generator\n",
|
||||
"Removing sample 375 due to bad main solution\n",
|
||||
"Removing sample 390 due to bad input generator\n",
|
||||
"Removing sample 415 due to bad input generator\n",
|
||||
"Removing sample 422 due to bad input generator\n",
|
||||
"Removing sample 429 due to bad input generator\n",
|
||||
"Removing sample 434 due to bad main solution\n",
|
||||
"Removing sample 453 due to bad input generator\n",
|
||||
"Removing sample 461 due to bad main solution\n",
|
||||
"Removing sample 463 due to bad main solution\n",
|
||||
"Removing sample 465 due to bad main solution\n",
|
||||
"Removing sample 471 due to bad input generator\n",
|
||||
"Removing sample 475 due to bad input generator\n",
|
||||
"Removing sample 482 due to bad main solution\n",
|
||||
"Removing sample 500 due to bad main solution\n",
|
||||
"Removing sample 507 due to bad input generator\n",
|
||||
"Removing sample 508 due to bad input generator\n",
|
||||
"Removing sample 510 due to bad input generator\n",
|
||||
"Removing sample 516 due to bad main solution\n",
|
||||
"Removing sample 517 due to bad main solution\n",
|
||||
"Removing sample 529 due to bad input generator\n",
|
||||
"Removing sample 558 due to bad main solution\n",
|
||||
"Removing sample 570 due to bad main solution\n",
|
||||
"Removing sample 595 due to bad main solution\n",
|
||||
"Removing sample 596 due to bad input generator\n",
|
||||
"Removing sample 605 due to bad main solution\n",
|
||||
"Removing sample 622 due to bad main solution\n",
|
||||
"Removing sample 635 due to bad main solution\n",
|
||||
"Removing sample 639 due to bad main solution\n",
|
||||
"Removing sample 653 due to bad main solution\n",
|
||||
"Removing sample 662 due to bad input generator\n",
|
||||
"Removing sample 663 due to bad main solution\n",
|
||||
"Removing sample 678 due to bad input generator\n",
|
||||
"Removing sample 686 due to bad input generator\n",
|
||||
"Removing sample 687 due to bad main solution\n",
|
||||
"Removing sample 704 due to bad main solution\n",
|
||||
"Removing sample 737 due to bad main solution\n",
|
||||
"Removing sample 773 due to bad main solution\n",
|
||||
"Removing sample 778 due to bad input generator\n",
|
||||
"Removing sample 793 due to bad input generator\n",
|
||||
"Removing sample 798 due to bad main solution\n",
|
||||
"Removing sample 819 due to bad main solution\n",
|
||||
"Removing sample 823 due to bad input generator\n",
|
||||
"Removing sample 834 due to bad main solution\n",
|
||||
"Removing sample 840 due to bad main solution\n",
|
||||
"Removing sample 844 due to bad input generator\n",
|
||||
"Removing sample 861 due to bad input generator\n",
|
||||
"Removed 81 samples\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"def verify_input_generator(input_generator_code):\n",
|
||||
" if \"def generate_inputs(random: Random)\" not in input_generator_code and \"def generate_inputs(rng: Random)\" not in input_generator_code:\n",
|
||||
" return False\n",
|
||||
" if \"import numpy\" in input_generator_code or \"np.random\" in input_generator_code:\n",
|
||||
" return False\n",
|
||||
" if \"import random\" in input_generator_code:\n",
|
||||
" return False\n",
|
||||
" return True\n",
|
||||
"\n",
|
||||
"def verify_main_solution(main_solution_code):\n",
|
||||
" if \"def main_solution(\" not in main_solution_code:\n",
|
||||
" return False\n",
|
||||
" if \"import random\" in main_solution_code:\n",
|
||||
" return False\n",
|
||||
" if \"from random import\" in main_solution_code:\n",
|
||||
" return False\n",
|
||||
" if \"np.random\" in main_solution_code:\n",
|
||||
" return False\n",
|
||||
" if \"import requests\" in main_solution_code or \" requests.\" in main_solution_code or \"from requests import\" in main_solution_code:\n",
|
||||
" return False\n",
|
||||
" return True\n",
|
||||
"\n",
|
||||
"remove = set()\n",
|
||||
"for i, sample in enumerate(samples):\n",
|
||||
" if not verify_input_generator(sample[\"input_generator\"]):\n",
|
||||
" remove.add(i)\n",
|
||||
" print(f\"Removing sample {i} due to bad input generator\")\n",
|
||||
" elif not verify_main_solution(sample[\"code_sample\"]):\n",
|
||||
" remove.add(i)\n",
|
||||
" print(f\"Removing sample {i} due to bad main solution\")\n",
|
||||
"\n",
|
||||
"removed_samples = [sample for i, sample in enumerate(samples) if i in remove]\n",
|
||||
"samples = [sample for i, sample in enumerate(samples) if i not in remove]\n",
|
||||
"print(f\"Removed {len(remove)} samples\")\n",
|
||||
"\n",
|
||||
"with open(Path(\"output/filtered_code.jsonl\"), \"w\") as f:\n",
|
||||
" for sample in samples:\n",
|
||||
" f.write(json.dumps(sample) + \"\\n\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'def generate_inputs(random: Random):\\n import numpy as np\\n \\n height = random.randint(10, 20)\\n width = random.randint(10, 20)\\n image0 = np.random.rand(height, width)\\n image1 = np.random.rand(height, width)\\n num_iter = random.randint(10, 100)\\n alpha = random.uniform(0.01, 1.0) if random.choice([True, False]) else None\\n\\n return {\"image0\": image0, \"image1\": image1, \"num_iter\": num_iter, \"alpha\": alpha}'"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"removed_samples[0][\"input_generator\"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'def main_solution(search_terms):\\n import requests\\n from bs4 import BeautifulSoup\\n from fake_useragent import UserAgent\\n import webbrowser\\n\\n url = \"https://www.google.com/search?q=\" + \" \".join(search_terms)\\n res = requests.get(url, headers={\"UserAgent\": UserAgent().random}, timeout=10)\\n soup = BeautifulSoup(res.text, \"html.parser\")\\n links = list(soup.select(\".eZt8xd\"))[:5]\\n\\n opened_links = []\\n for link in links:\\n if link.text == \"Maps\":\\n opened_links.append(link.get(\"href\"))\\n webbrowser.open(link.get(\"href\"))\\n else:\\n opened_links.append(f\"https://google.com{link.get(\\'href\\')}\")\\n webbrowser.open(f\"https://google.com{link.get(\\'href\\')}\")\\n\\n return {\"opened_links\": opened_links}'"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"removed_samples[43][\"code_sample\"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from dotenv import load_dotenv\n",
|
||||
"load_dotenv()\n",
|
||||
"import asyncio\n",
|
||||
"import os\n",
|
||||
"from openai import AsyncOpenAI\n",
|
||||
"from openai.types.chat import ChatCompletion, ChatCompletionMessageParam\n",
|
||||
"from typing import Any, Iterable\n",
|
||||
"\n",
|
||||
"VERIFY_PROMPT = \"\"\"\n",
|
||||
"Given the following code snippet, you must verify whether it is deterministic.\n",
|
||||
"\n",
|
||||
"It is not deterministic if it utilises potentially non-deterministic functions such as random number generators, network requests, or time functions. It also qualifies as non-deterministic if it calls another function or library which in turn produces non-deterministic outputs.\n",
|
||||
"\n",
|
||||
"Code snippet:\n",
|
||||
"\n",
|
||||
"{0}\n",
|
||||
"\n",
|
||||
"If the function is deterministic, return True. Otherwise, return False. Respond only with this one work, no other content or explanation.\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
"# Cap concurrent requests. I had to set this to 1 for the DeepSeek API to work, YMMV\n",
|
||||
"semaphore = asyncio.Semaphore(1)\n",
|
||||
"\n",
|
||||
"async def llm_generate(\n",
|
||||
" client: AsyncOpenAI,\n",
|
||||
" messages: Iterable[ChatCompletionMessageParam],\n",
|
||||
" sampling_params: dict[str, Any],\n",
|
||||
" retry_empty_response: bool = True,\n",
|
||||
" max_retries: int = 3,\n",
|
||||
") -> ChatCompletion:\n",
|
||||
" for trial in range(max_retries):\n",
|
||||
" async with semaphore:\n",
|
||||
" try:\n",
|
||||
" completion = await client.chat.completions.create(\n",
|
||||
" messages=messages, **sampling_params\n",
|
||||
" )\n",
|
||||
" if completion.choices[0].message.content or not retry_empty_response:\n",
|
||||
" return completion\n",
|
||||
" await asyncio.sleep(5)\n",
|
||||
" except Exception as e:\n",
|
||||
" print(f\"Failure response (trial {trial}):\", e)\n",
|
||||
" await asyncio.sleep(3 * (trial + 1))\n",
|
||||
" if trial == max_retries - 1:\n",
|
||||
" raise\n",
|
||||
"\n",
|
||||
"client = AsyncOpenAI(\n",
|
||||
" base_url=os.getenv(\"API_BASE_URL\"),\n",
|
||||
" api_key=os.getenv(\"API_KEY\"),\n",
|
||||
" timeout=120.0,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"sampling_params = {\n",
|
||||
" \"model\": \"deepseek-chat\", # For DeepSeek API\n",
|
||||
" #\"model\": \"deepseek/deepseek-chat:free\", # For OpenRouter\n",
|
||||
" \"max_tokens\": 8192,\n",
|
||||
"}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"33it [04:49, 8.14s/it]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Sample 32 is non-deterministic\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"58it [08:49, 9.66s/it]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Sample 57 is non-deterministic\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"147it [23:40, 12.39s/it]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Sample 146 is non-deterministic\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"152it [24:19, 8.55s/it]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Sample 151 is non-deterministic\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"158it [25:30, 10.53s/it]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Sample 157 is non-deterministic\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"172it [27:33, 7.87s/it]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Sample 171 is non-deterministic\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"173it [27:47, 9.64s/it]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Sample 172 is non-deterministic\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"231it [37:31, 9.87s/it]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Sample 230 is non-deterministic\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"285it [48:06, 10.91s/it]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Sample 284 is non-deterministic\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"343it [58:49, 15.48s/it]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Sample 342 is non-deterministic\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"363it [1:02:19, 11.92s/it]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Sample 362 is non-deterministic\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"374it [1:04:16, 11.96s/it]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Sample 373 is non-deterministic\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"394it [1:07:47, 11.56s/it]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Sample 393 is non-deterministic\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"429it [1:14:50, 11.54s/it]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Sample 428 is non-deterministic\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"451it [1:19:16, 12.64s/it]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Sample 450 is non-deterministic\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"555it [1:40:31, 9.80s/it]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Sample 554 is non-deterministic\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"603it [1:48:46, 9.54s/it]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Sample 602 is non-deterministic\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"634it [1:53:27, 10.77s/it]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Sample 633 is non-deterministic\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"638it [1:53:59, 8.85s/it]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Sample 637 is non-deterministic\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"685it [2:01:43, 10.44s/it]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Sample 684 is non-deterministic\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"689it [2:02:21, 9.03s/it]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Sample 688 is non-deterministic\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"782it [2:19:05, 10.67s/it]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Removed 81 samples\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from tqdm import tqdm\n",
|
||||
"\n",
|
||||
"remove_nondeterministic = set()\n",
|
||||
"for i, sample in tqdm(enumerate(samples)):\n",
|
||||
" messages = [\n",
|
||||
" {\"role\": \"user\", \"content\": VERIFY_PROMPT.format(sample[\"code_sample\"])},\n",
|
||||
" ]\n",
|
||||
" completion = await llm_generate(client, messages, sampling_params)\n",
|
||||
" content = completion.choices[0].message.content\n",
|
||||
" if not content or content.strip() not in [\"True\", \"False\"]:\n",
|
||||
" print(f\"Sample {i} failed to verify\")\n",
|
||||
" print(content)\n",
|
||||
" elif content.strip() == \"False\":\n",
|
||||
" print(f\"Sample {i} is non-deterministic\")\n",
|
||||
" remove_nondeterministic.add(i)\n",
|
||||
"\n",
|
||||
"removed_samples = [sample for i, sample in enumerate(samples) if i in remove]\n",
|
||||
"samples = [sample for i, sample in enumerate(samples) if i not in remove]\n",
|
||||
"print(f\"Removed {len(remove)} samples\")\n",
|
||||
"\n",
|
||||
"with open(Path(\"output/filtered_code_2.jsonl\"), \"w\") as f:\n",
|
||||
" for sample in samples:\n",
|
||||
" f.write(json.dumps(sample) + \"\\n\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'def main_solution(message, word_percentage=20, letter_percentage=85):\\n ENGLISH_WORDS = {}\\n with open(\"dictionary.txt\") as dictionary_file:\\n for word in dictionary_file.read().split(\"\\\\n\"):\\n ENGLISH_WORDS[word] = None\\n\\n def remove_non_letters(message):\\n return \"\".join(symbol for symbol in message if symbol in ascii_letters + \" \\\\t\\\\n\")\\n\\n def get_english_count(message):\\n message = message.upper()\\n message = remove_non_letters(message)\\n possible_words = message.split()\\n matches = len([word for word in possible_words if word in ENGLISH_WORDS])\\n return float(matches) / len(possible_words)\\n\\n words_match = get_english_count(message) * 100 >= word_percentage\\n num_letters = len(remove_non_letters(message))\\n message_letters_percentage = (float(num_letters) / len(message)) * 100\\n letters_match = message_letters_percentage >= letter_percentage\\n is_english = words_match and letters_match\\n\\n return {\"is_english\": is_english}'"
|
||||
]
|
||||
},
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"removed_samples[0][\"code_sample\"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Note: following the above steps, two further filtering steps were taken:\n",
|
||||
"\n",
|
||||
"- manually review every code snippet for security issues, dependencies on libraries, or non-determinism missed by the LLM classification\n",
|
||||
"- run every code snippet and input generator 100 times, dropping any which caused an error"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.1"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
@@ -1,7 +1,6 @@
|
||||
import gzip
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from random import Random
|
||||
from typing import Any, Optional
|
||||
|
||||
@@ -79,16 +78,19 @@ class CodeIODataset(ProceduralDataset):
|
||||
with gzip.open(self._data_path, "rt", encoding="utf-8") as f:
|
||||
CodeIODataset._jsonl_data = [json.loads(line) for line in f]
|
||||
|
||||
def _generate_io_pairs(self, main_code: str, input_generator_code: str, num_pairs: int = 1):
|
||||
def _generate_io_pair(self, main_code: str, input_generator_code: str, rng: Random, max_retries: int = 3):
|
||||
local_vars = {}
|
||||
exec(main_code, {}, local_vars)
|
||||
exec(input_generator_code, {}, local_vars)
|
||||
io_pairs = []
|
||||
for _ in range(num_pairs):
|
||||
inputs = local_vars["input_generator"]()
|
||||
outputs = local_vars["main"](**inputs)
|
||||
io_pairs.append((inputs, outputs))
|
||||
return io_pairs
|
||||
exec(main_code, {"Random": Random}, local_vars)
|
||||
exec(input_generator_code, {"Random": Random}, local_vars)
|
||||
for _ in range(max_retries):
|
||||
try:
|
||||
inputs = local_vars["generate_inputs"](rng)
|
||||
outputs = local_vars["main_solution"](**inputs)
|
||||
except Exception:
|
||||
# Retry
|
||||
continue
|
||||
return inputs, outputs
|
||||
return {}, {}
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
"""Generate a single CodeI/O reasoning task"""
|
||||
@@ -96,12 +98,12 @@ class CodeIODataset(ProceduralDataset):
|
||||
|
||||
json_data = rng.choice(CodeIODataset._jsonl_data)
|
||||
|
||||
query = json_data["query"]
|
||||
parameters = json_data["parameters"]
|
||||
reference_code = json_data["reference_code"]
|
||||
query = json_data["task_description"]
|
||||
parameters = json_data["input_output_spec"]
|
||||
reference_code = json_data["code_sample"]
|
||||
input_generator_code = json_data["input_generator"]
|
||||
|
||||
input_data, output_data = self._generate_io_pairs(reference_code, input_generator_code, num_pairs=1)[0]
|
||||
input_data, output_data = self._generate_io_pair(reference_code, input_generator_code, rng)
|
||||
|
||||
if rng.random() < self.config.input_prediction_probability:
|
||||
question = OUTPUT_PREDICTION_PROMPT_TEMPLATE.format(query, parameters, input_data, reference_code)
|
||||
@@ -113,7 +115,7 @@ class CodeIODataset(ProceduralDataset):
|
||||
return {
|
||||
"question": question,
|
||||
"answer": solution,
|
||||
"metadata": {},
|
||||
"metadata": {"input_data": input_data, "output_data": output_data},
|
||||
}
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
@@ -142,15 +144,17 @@ class CodeIODataset(ProceduralDataset):
|
||||
reward = 0.1
|
||||
else:
|
||||
# At least we got a JSON object, I guess?
|
||||
reward = 0.05
|
||||
reward = 0.01
|
||||
except json.JSONDecodeError:
|
||||
if oracle_answer in answer:
|
||||
reward = len(oracle_answer) / len(answer)
|
||||
else:
|
||||
reward = 0.00
|
||||
elif oracle_answer in answer:
|
||||
# max() to avoid penalising too heavily, since correct answers are short here
|
||||
reward = max(len(oracle_answer) / len(answer), 0.2)
|
||||
else:
|
||||
reward = 0.01
|
||||
reward = 0.00
|
||||
|
||||
return reward
|
||||
|
||||
|
||||
Binary file not shown.
42
tests/test_codeio.py
Normal file
42
tests/test_codeio.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import pytest
|
||||
|
||||
from reasoning_gym.code.codeio import CodeIOConfig, CodeIODataset
|
||||
|
||||
|
||||
def test_codeio_dataset():
|
||||
# Create a small CodeI/O reasoning dataset
|
||||
config = CodeIOConfig(size=10, seed=42)
|
||||
dataset = CodeIODataset(config)
|
||||
|
||||
for i in range(10):
|
||||
item = dataset[i]
|
||||
|
||||
assert isinstance(item, dict)
|
||||
assert "question" in item
|
||||
assert "answer" in item
|
||||
assert "metadata" in item
|
||||
|
||||
assert "input_data" in item["metadata"]
|
||||
assert "output_data" in item["metadata"]
|
||||
|
||||
# Score some correct and incorrect answers
|
||||
score = dataset.score_answer(answer=item["answer"], entry=item)
|
||||
assert score == 1.0
|
||||
# Incorrect answer (None)
|
||||
score = dataset.score_answer(answer=None, entry=item)
|
||||
assert score == 0.00
|
||||
# Incorrect answer (empty dict)
|
||||
score = dataset.score_answer(answer="{}", entry=item)
|
||||
assert score == 0.01
|
||||
|
||||
|
||||
def test_codeio_config():
|
||||
# Test constraints on input probability
|
||||
with pytest.raises(AssertionError):
|
||||
CodeIOConfig(size=10, seed=42, input_prediction_probability=1.1).validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
CodeIOConfig(size=10, seed=42, input_prediction_probability=-0.1).validate()
|
||||
|
||||
CodeIOConfig(size=10, seed=42, input_prediction_probability=0.1).validate()
|
||||
CodeIOConfig(size=10, seed=42, input_prediction_probability=0.9).validate()
|
||||
Reference in New Issue
Block a user