LLMLingua/examples/CoT.ipynb
2023-11-01 02:56:51 +00:00

{
"cells": [
{
"cell_type": "markdown",
"id": "1972a352-a0e3-41b7-81dc-dd4ae2b890c3",
"metadata": {},
"source": [
"## In-Context Learning, Chain-of-Thought, Reasoning"
]
},
{
"cell_type": "markdown",
"id": "05d999bc-83a3-454f-a8a4-44cbff1fcedc",
"metadata": {},
"source": [
"<a target=\"_blank\" href=\"https://colab.research.google.com/github/microsoft/LLMLingua/blob/main/examples/CoT.ipynb\">\r\n",
" <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\r\n",
"</a>"
]
},
{
"cell_type": "markdown",
"id": "fe3ed1ce-d38d-4048-9db6-9707b55dc642",
"metadata": {},
"source": [
"**In-Context Learning (ICL)** is an emergent capability of large language models (LLMs) that allows them to quickly learn a task from only a few examples. ICL is often combined with the Chain-of-Thought (CoT) approach, in which the examples spell out the reasoning process in detail to strengthen the LLM's reasoning. For instance, Fu et al.'s Complexity-Based Prompting improved GSM8K performance from 74.9 to 78.85 with GPT-3.5-Turbo-0301. However, such prompts can become very long; the GSM8K prompt used here contains **2,366** tokens."
]
},
{
"cell_type": "markdown",
"id": "ae003ead-2f07-44a4-b641-2e33be920dd9",
"metadata": {},
"source": [
"<center><img width=\"800\" src=\"../images/LLMLingua_framework.png\"></center>"
]
},
{
"cell_type": "markdown",
"id": "0b39b33f-5860-4825-8f00-d60aed0dce86",
"metadata": {},
"source": [
"To address this, we propose [**LLMLingua**](https://arxiv.org/abs/2310.05736), which uses a well-trained, aligned small language model, such as GPT-2 small or LLaMA-7B, to detect unimportant tokens in the prompt, enabling inference with the compressed prompt on black-box LLMs and achieving up to **20x** compression with minimal performance loss."
]
},
{
"cell_type": "markdown",
"id": "18422597-687a-43aa-a6ed-ce6244d0eb55",
"metadata": {},
"source": [
"### GSM8K"
]
},
{
"cell_type": "markdown",
"id": "51a7accd-5ec2-4ed2-9582-1afdb441a998",
"metadata": {},
"source": [
"Next, we demonstrate the use of LLMLingua on the GSM8K dataset. The original 8-shot prompt, available at https://github.com/FranxYao/chain-of-thought-hub/blob/main/gsm8k/lib_prompt/prompt_hardest.txt, contains 2,366 tokens."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "a970a901-11bd-43af-a8bc-7fb2fc6a1a07",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Defaulting to user installation because normal site-packages is not writeable\n",
"Requirement already satisfied: llmlingua in /home/hjiang/Code/github/LLMLingua (0.1.2)\n",
"Requirement already satisfied: datasets in /home/hjiang/.local/lib/python3.9/site-packages (2.14.4)\n",
"Requirement already satisfied: nltk in /home/hjiang/.local/lib/python3.9/site-packages (from llmlingua) (3.8.1)\n",
"Requirement already satisfied: numpy in /home/hjiang/.local/lib/python3.9/site-packages (from llmlingua) (1.23.5)\n",
"Requirement already satisfied: tiktoken in /home/hjiang/.local/lib/python3.9/site-packages (from llmlingua) (0.4.0)\n",
"Requirement already satisfied: torch in /home/hjiang/.local/lib/python3.9/site-packages (from llmlingua) (1.13.1+cu116)\n",
"Requirement already satisfied: transformers>=4.26.0 in /home/hjiang/.local/lib/python3.9/site-packages (from llmlingua) (4.34.1)\n",
"Requirement already satisfied: pyarrow>=8.0.0 in /home/hjiang/.local/lib/python3.9/site-packages (from datasets) (11.0.0)\n",
"Requirement already satisfied: dill<0.3.8,>=0.3.0 in /home/hjiang/.local/lib/python3.9/site-packages (from datasets) (0.3.7)\n",
"Requirement already satisfied: pandas in /home/hjiang/.local/lib/python3.9/site-packages (from datasets) (2.0.3)\n",
"Requirement already satisfied: requests>=2.19.0 in /home/hjiang/.local/lib/python3.9/site-packages (from datasets) (2.29.0)\n",
"Requirement already satisfied: tqdm>=4.62.1 in /home/hjiang/.local/lib/python3.9/site-packages (from datasets) (4.65.0)\n",
"Requirement already satisfied: xxhash in /home/hjiang/.local/lib/python3.9/site-packages (from datasets) (3.3.0)\n",
"Requirement already satisfied: multiprocess in /home/hjiang/.local/lib/python3.9/site-packages (from datasets) (0.70.15)\n",
"Requirement already satisfied: fsspec[http]>=2021.11.1 in /home/hjiang/.local/lib/python3.9/site-packages (from datasets) (2023.6.0)\n",
"Requirement already satisfied: aiohttp in /home/hjiang/.local/lib/python3.9/site-packages (from datasets) (3.8.5)\n",
"Requirement already satisfied: huggingface-hub<1.0.0,>=0.14.0 in /home/hjiang/.local/lib/python3.9/site-packages (from datasets) (0.16.4)\n",
"Requirement already satisfied: packaging in /home/hjiang/.local/lib/python3.9/site-packages (from datasets) (23.0)\n",
"Requirement already satisfied: pyyaml>=5.1 in /usr/lib/python3/dist-packages (from datasets) (5.3.1)\n",
"Requirement already satisfied: attrs>=17.3.0 in /home/hjiang/.local/lib/python3.9/site-packages (from aiohttp->datasets) (23.1.0)\n",
"Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /home/hjiang/.local/lib/python3.9/site-packages (from aiohttp->datasets) (3.2.0)\n",
"Requirement already satisfied: multidict<7.0,>=4.5 in /home/hjiang/.local/lib/python3.9/site-packages (from aiohttp->datasets) (6.0.4)\n",
"Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /home/hjiang/.local/lib/python3.9/site-packages (from aiohttp->datasets) (4.0.2)\n",
"Requirement already satisfied: yarl<2.0,>=1.0 in /home/hjiang/.local/lib/python3.9/site-packages (from aiohttp->datasets) (1.9.2)\n",
"Requirement already satisfied: frozenlist>=1.1.1 in /home/hjiang/.local/lib/python3.9/site-packages (from aiohttp->datasets) (1.4.0)\n",
"Requirement already satisfied: aiosignal>=1.1.2 in /home/hjiang/.local/lib/python3.9/site-packages (from aiohttp->datasets) (1.3.1)\n",
"Requirement already satisfied: filelock in /home/hjiang/.local/lib/python3.9/site-packages (from huggingface-hub<1.0.0,>=0.14.0->datasets) (3.12.2)\n",
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /home/hjiang/.local/lib/python3.9/site-packages (from huggingface-hub<1.0.0,>=0.14.0->datasets) (4.7.1)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/lib/python3/dist-packages (from requests>=2.19.0->datasets) (2.8)\n",
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /home/hjiang/.local/lib/python3.9/site-packages (from requests>=2.19.0->datasets) (1.26.16)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/lib/python3/dist-packages (from requests>=2.19.0->datasets) (2019.11.28)\n",
"Requirement already satisfied: regex!=2019.12.17 in /home/hjiang/.local/lib/python3.9/site-packages (from transformers>=4.26.0->llmlingua) (2023.6.3)\n",
"Requirement already satisfied: tokenizers<0.15,>=0.14 in /home/hjiang/.local/lib/python3.9/site-packages (from transformers>=4.26.0->llmlingua) (0.14.1)\n",
"Requirement already satisfied: safetensors>=0.3.1 in /home/hjiang/.local/lib/python3.9/site-packages (from transformers>=4.26.0->llmlingua) (0.3.1)\n",
"Requirement already satisfied: click in /home/hjiang/.local/lib/python3.9/site-packages (from nltk->llmlingua) (8.1.6)\n",
"Requirement already satisfied: joblib in /home/hjiang/.local/lib/python3.9/site-packages (from nltk->llmlingua) (1.3.1)\n",
"Requirement already satisfied: python-dateutil>=2.8.2 in /home/hjiang/.local/lib/python3.9/site-packages (from pandas->datasets) (2.8.2)\n",
"Requirement already satisfied: pytz>=2020.1 in /home/hjiang/.local/lib/python3.9/site-packages (from pandas->datasets) (2023.3)\n",
"Requirement already satisfied: tzdata>=2022.1 in /home/hjiang/.local/lib/python3.9/site-packages (from pandas->datasets) (2023.3)\n",
"Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.14.0)\n",
"\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.2.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.3.1\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython3.9 -m pip install --upgrade pip\u001b[0m\n"
]
}
],
"source": [
"# Install dependency.\n",
"!pip install llmlingua datasets"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "641235b6-71a5-4f2a-8eec-272c73931bef",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--2023-10-30 09:15:31-- https://raw.githubusercontent.com/FranxYao/chain-of-thought-hub/main/gsm8k/lib_prompt/prompt_hardest.txt\n",
"Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.108.133, 185.199.110.133, ...\n",
"Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 8464 (8.3K) [text/plain]\n",
"Saving to: prompt_hardest.txt\n",
"\n",
"prompt_hardest.txt 100%[===================>] 8.27K --.-KB/s in 0s \n",
"\n",
"2023-10-30 09:15:31 (78.8 MB/s) - prompt_hardest.txt saved [8464/8464]\n",
"\n"
]
}
],
"source": [
"# Download the original prompt and dataset\n",
"from datasets import load_dataset\n",
"!wget https://raw.githubusercontent.com/FranxYao/chain-of-thought-hub/main/gsm8k/lib_prompt/prompt_hardest.txt\n",
"prompt_complex = open(\"./prompt_hardest.txt\").read()\n",
"gsm8k = load_dataset(\"gsm8k\", \"main\")\n",
"gsm8k_test = gsm8k[\"test\"]"
]
},
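{
"cell_type": "markdown",
"id": "b7e1c0aa-1111-4e21-9c10-5a2f3d4e5f60",
"metadata": {},
"source": [
"As an optional sanity check (a sketch using `tiktoken`, which is installed as an LLMLingua dependency), we can count the tokens in the downloaded prompt to confirm the **2,366**-token figure:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b7e1c0aa-2222-4e21-9c10-5a2f3d4e5f61",
"metadata": {},
"outputs": [],
"source": [
"# Optional: count the tokens in the original prompt with tiktoken.\n",
"import tiktoken\n",
"encoding = tiktoken.encoding_for_model(\"gpt-3.5-turbo\")\n",
"print(\"Original prompt tokens:\", len(encoding.encode(prompt_complex)))"
]
},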
{
"cell_type": "code",
"execution_count": 8,
"id": "cbbbf3de-a9d6-46cf-afab-dcb72a6154ec",
"metadata": {},
"outputs": [],
"source": [
"# Option 1: use the OpenAI API\n",
"import openai\n",
"openai.api_key = \"<insert_openai_key>\""
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "46506810-8565-43da-984b-d862c56b49c2",
"metadata": {},
"outputs": [],
"source": [
"# Option 2: use the Azure OpenAI (AOAI) API\n",
"import openai\n",
"openai.api_key = \"<insert_openai_key>\"\n",
"openai.api_base = \"https://xxxx.openai.azure.com/\"\n",
"openai.api_type = 'azure'\n",
"openai.api_version = '2023-05-15'"
]
},
{
"cell_type": "markdown",
"id": "f8676ffa-5117-44dc-9742-bb9ab1d56e0c",
"metadata": {},
"source": [
"### Setup Data"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "cc17bbc5-86cb-4d15-a730-955af85a10b2",
"metadata": {},
"outputs": [],
"source": [
"# select an example from GSM8K\n",
"question, answer = [gsm8k_test[2][key] for key in [\"question\", \"answer\"]]"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "58718a19-cc4e-4002-a92a-58ea3de9c9d0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Question: Josh decides to try flipping a house. He buys a house for $80,000 and then puts in $50,000 in repairs. This increased the value of the house by 150%. How much profit did he make?\n",
"Answer: The cost of the house and repairs came out to 80,000+50,000=$<<80000+50000=130000>>130,000\n",
"He increased the value of the house by 80,000*1.5=<<80000*1.5=120000>>120,000\n",
"So the new value of the house is 120,000+80,000=$<<120000+80000=200000>>200,000\n",
"So he made a profit of 200,000-130,000=$<<200000-130000=70000>>70,000\n",
"#### 70000\n"
]
}
],
"source": [
"# Ground-truth Answer\n",
"print(\"Question:\", question)\n",
"print(\"Answer:\", answer)"
]
},
{
"cell_type": "markdown",
"id": "ba1c6d52-dc87-434c-a41c-0bbc8a286504",
"metadata": {},
"source": [
"#### The response to the original prompt"
]
},
{
"cell_type": "code",
"execution_count": 40,
"id": "3d441f10-c5c7-4d45-b09a-717e536b36bf",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{\n",
" \"id\": \"cmpl-8FZvcX70FH7ck9c9MegWmnUocH0A0\",\n",
" \"object\": \"text_completion\",\n",
" \"created\": 1698723720,\n",
" \"model\": \"gpt-35-turbo\",\n",
" \"choices\": [\n",
" {\n",
" \"text\": \" \\nLet's think step by step\\nThe repairs increased the value of the house by 150% so that means it increased by 80,000*1.5=$<<80000*1.5=120000>>120,000\\nSo the total value of the house is 80,000+120,000=$<<80000+120000=200000>>200,000\\nHe spent 80,000+50,000=$<<80000+50000=130000>>130,000\\nSo he made a profit of 200,000-130,000=$<<200000-130000=70000>>70,000\\nThe answer is 70,000\",\n",
" \"index\": 0,\n",
" \"finish_reason\": \"stop\",\n",
" \"logprobs\": null\n",
" }\n",
" ],\n",
" \"usage\": {\n",
" \"prompt_tokens\": 2428,\n",
" \"completion_tokens\": 142,\n",
" \"total_tokens\": 2570\n",
" }\n",
"}\n"
]
}
],
"source": [
"# Query the LLM with the original, uncompressed prompt\n",
"import json\n",
"instruction = \"Please reference the following examples to answer the math question,\\n\"\n",
"prompt = instruction + prompt_complex + \"\\n\\nQuestion: \" + question\n",
"\n",
"request_data = {\n",
" \"prompt\": prompt,\n",
" \"max_tokens\": 400,\n",
" \"temperature\": 0,\n",
" \"top_p\": 1,\n",
" \"n\": 1,\n",
" \"stream\": False,\n",
" \"stop\": \"\\n\\n\",\n",
"}\n",
"response = openai.Completion.create(\n",
"    engine=\"gpt-3.5-turbo-0301\",  # deployment name on AOAI; with the OpenAI API, pass model= instead\n",
"    **request_data,\n",
")\n",
"print(json.dumps(response, indent=4))"
]
},
{
"cell_type": "markdown",
"id": "9aa90492-8ad1-4a89-85c5-26b8472f1ff0",
"metadata": {},
"source": [
"#### The response to the compressed prompt"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "fa638dec-c9ec-4dce-9dac-d768145de714",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8ec90053e7274da59973427652f879a1",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/hjiang/.local/lib/python3.9/site-packages/transformers/generation/configuration_utils.py:362: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.9` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.\n",
" warnings.warn(\n",
"/home/hjiang/.local/lib/python3.9/site-packages/transformers/generation/configuration_utils.py:367: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.6` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.\n",
" warnings.warn(\n"
]
}
],
"source": [
"# Set up LLMLingua\n",
"from llmlingua import PromptCompressor\n",
"llm_lingua = PromptCompressor()"
]
},
{
"cell_type": "code",
"execution_count": 43,
"id": "5f61a186-6641-4118-ad04-5245a53b6d79",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{\n",
" \"compressed_prompt\": \"Question: Sam bought a dozen boxes, each with 30 highlighter pens inside, for $10 each. He reanged five of boxes into packages of sixlters each and sold them $3 per. He sold the rest theters separately at the of three pens $2. How much did make in total, dollars?\\nLets think step step\\nSam bought 1 boxes x00 oflters.\\nHe bought 12 00ters in total\\nSam then took5 boxes 6ters0ters\\nHe sold these boxes for 5 *5\\nAfterelling these boxes there were 30330ters remaining\\nese form 330 /30 of three\\n sold each for2 each, so made * =0 from\\n total, he0 $15\\nSince his original1 he earned $120 = $115 in profit.\\nThe answer is 115\",\n",
" \"origin_tokens\": 2365,\n",
" \"compressed_tokens\": 174,\n",
" \"ratio\": \"13.6x\",\n",
" \"saving\": \", Saving $0.1 in GPT-4.\"\n",
"}\n",
"Response: {\n",
" \"id\": \"cmpl-8FZwYp1QIwiQs6pEhy2cRK6wnLnAO\",\n",
" \"object\": \"text_completion\",\n",
" \"created\": 1698723778,\n",
" \"model\": \"gpt-35-turbo\",\n",
" \"choices\": [\n",
" {\n",
" \"text\": \" \\n\\nThe repairs increased the value of the house by 150% so that means it increased by 80000*1.5=$<<80000*1.5=120000>>120,000\\nSo the total value of the house is 120,000+80,000=$<<120000+80000=200000>>200,000\\nThat means he made a profit of 200,000-80,000-50,000=$<<200000-80000-50000=70000>>70,000. Answer: \\\\boxed{70,000}.<|im_end|>\",\n",
" \"index\": 0,\n",
" \"finish_reason\": \"stop\",\n",
" \"logprobs\": null\n",
" }\n",
" ],\n",
" \"usage\": {\n",
" \"prompt_tokens\": 237,\n",
" \"completion_tokens\": 120,\n",
" \"total_tokens\": 357\n",
" }\n",
"}\n"
]
}
],
"source": [
"# Compress the prompt (target ~200 tokens; yields 174 tokens, a 13.6x ratio)\n",
"compressed_prompt = llm_lingua.compress_prompt(\n",
" prompt_complex.split(\"\\n\\n\"),\n",
" instruction=\"\",\n",
" question=\"\",\n",
" target_token=200,\n",
" context_budget=\"*1.5\",\n",
" iterative_size=100,\n",
")\n",
"\n",
"instruction = \"Please reference the following examples to answer the math question,\\n\"\n",
"prompt = instruction + compressed_prompt[\"compressed_prompt\"] + \"\\n\\nQuestion: \" + question\n",
"\n",
"request_data = {\n",
" \"prompt\": prompt,\n",
" \"max_tokens\": 400,\n",
" \"temperature\": 0,\n",
" \"top_p\": 1,\n",
" \"n\": 1,\n",
" \"stream\": False,\n",
" \"stop\": \"\\r\\n\",\n",
"}\n",
"response = openai.Completion.create(\n",
"    engine=\"gpt-3.5-turbo-0301\",  # deployment name on AOAI; with the OpenAI API, pass model= instead\n",
"    **request_data,\n",
")\n",
"print(json.dumps(compressed_prompt, indent=4))\n",
"print(\"Response:\", response)"
]
},
{
"cell_type": "markdown",
"id": "1f89bb0f-7959-4a14-95be-dc80d88ce576",
"metadata": {},
"source": [
"### Evaluate on the GSM8K test set"
]
},
{
"cell_type": "code",
"execution_count": 44,
"id": "c1ac9bf5-23a9-446c-9394-8bb19aa1d89d",
"metadata": {},
"outputs": [],
"source": [
"import re\n",
"\n",
"def extract_ans(ans_model):\n",
" ans_model = ans_model.split(\"\\n\")\n",
" ans = []\n",
" residual = []\n",
" for li, al in enumerate(ans_model):\n",
" ans.append(al)\n",
" if \"answer is\" in al:\n",
" break\n",
" residual = list(ans_model[li + 1 :])\n",
" ans = \"\\n\".join(ans)\n",
" residual = \"\\n\".join(residual)\n",
" return ans, residual\n",
"\n",
"def parse_pred_ans(filename):\n",
" with open(filename) as fd:\n",
" lines = fd.readlines()\n",
" am, a = None, None\n",
" num_q, acc = 0, 0\n",
" current_mode = \"none\"\n",
" questions = []\n",
" ans_pred = []\n",
" ans_gold = []\n",
" for l in lines:\n",
" l = l.replace(\",\", \"\")\n",
" if l.startswith(\"Q: \"):\n",
" if am is not None and a is not None:\n",
" questions.append(q)\n",
" ans_pred.append(am)\n",
" ans_gold.append(a)\n",
" if test_answer(am, a):\n",
" acc += 1\n",
" current_mode = \"q\"\n",
" q = l\n",
" num_q += 1\n",
" elif l.startswith(\"A_model:\"):\n",
" current_mode = \"am\"\n",
" am = l\n",
" elif l.startswith(\"A:\"):\n",
" current_mode = \"a\"\n",
" a = l\n",
" else:\n",
" if current_mode == \"q\":\n",
" q += l\n",
" elif current_mode == \"am\":\n",
" am += l\n",
" elif current_mode == \"a\":\n",
" a += l\n",
" else:\n",
" raise ValueError(current_mode)\n",
"\n",
" questions.append(q)\n",
" ans_pred.append(am)\n",
" ans_gold.append(a)\n",
" if test_answer(am, a):\n",
" acc += 1\n",
" print(\"num_q %d correct %d ratio %.4f\" % (num_q, acc, float(acc / num_q)))\n",
" return questions, ans_pred, ans_gold\n",
"\n",
"\n",
"def get_result(text: str):\n",
" pattern = r\"\\d*\\.?\\d+\"\n",
" res = re.findall(pattern, text)\n",
" return res[-1] if res else \"\"\n",
"\n",
"\n",
"def test_answer(pred_str, ans_str):\n",
" pred, gold = get_result(pred_str), get_result(ans_str)\n",
" return pred == gold"
]
},
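{
"cell_type": "markdown",
"id": "b7e1c0aa-5555-4e21-9c10-5a2f3d4e5f64",
"metadata": {},
"source": [
"A quick check of the helpers above on a toy (hypothetical) model answer: `extract_ans` keeps everything up to the line containing \"answer is\", and `test_answer` compares the last number found in each string."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b7e1c0aa-6666-4e21-9c10-5a2f3d4e5f65",
"metadata": {},
"outputs": [],
"source": [
"# Toy example: verify the answer-extraction and scoring helpers.\n",
"toy = \"Let's think step by step\\nHe made 200000-130000=70000\\nThe answer is 70000\\n(ignored trailing text)\"\n",
"ans, residual = extract_ans(toy)\n",
"print(ans)  # everything up to and including the 'answer is' line\n",
"print(test_answer(ans, \"#### 70000\"))  # True: both strings end in 70000"
]
},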
{
"cell_type": "code",
"execution_count": 66,
"id": "cb209d5a-f822-4734-afc5-dafc07cc1bbc",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1319/1319 [47:55<00:00, 2.18s/it] \n"
]
}
],
"source": [
"# Evaluate on the GSM8K test set\n",
"from tqdm import tqdm\n",
"import os\n",
"os.makedirs(\"outputs\", exist_ok=True)\n",
"i = 0\n",
"\n",
"compressed_prompt = llm_lingua.compress_prompt(\n",
" prompt_complex.split(\"\\n\\n\"),\n",
" instruction=\"\",\n",
" question=\"\",\n",
" target_token=200,\n",
" context_budget=\"*1.5\",\n",
" iterative_size=100,\n",
")\n",
"\n",
"for q, a in tqdm(zip(gsm8k_test['question'], gsm8k_test['answer']), \n",
" total=len(gsm8k_test['question'])):\n",
" instruction = \"Please reference the following examples to answer the math question,\\n\"\n",
" prompt = instruction + compressed_prompt[\"compressed_prompt\"] + \"\\n\\nQuestion: \" + q + \"\\n\"\n",
" \n",
" request_data = {\n",
" \"prompt\": prompt,\n",
" \"max_tokens\": 400,\n",
" \"temperature\": 0,\n",
" \"top_p\": 1,\n",
" \"n\": 1,\n",
" \"stream\": False,\n",
" }\n",
"    response = openai.Completion.create(\n",
"        engine=\"gpt-3.5-turbo-0301\",  # deployment name on AOAI; with the OpenAI API, pass model= instead\n",
"        **request_data,\n",
"    )\n",
" ans_model = response[\"choices\"][0][\"text\"]\n",
" ans_, residual = extract_ans(ans_model)\n",
" with open('outputs/test_gpt_3.5_turbo_LLMLingua_174.txt', 'a') as fd:\n",
" fd.write(\"Q: %s\\nA_model:\\n%s\\nA:\\n%s\\n\\n\" % (q, ans_.replace(\"Q:\", \"\").replace(\"A:\", \"\"), a))\n",
" i += 1"
]
},
{
"cell_type": "code",
"execution_count": 67,
"id": "3a35d298-8596-4b92-8dda-8da4250c873c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"num_q 1319 correct 1032 ratio 0.7824\n"
]
}
],
"source": [
"_ = parse_pred_ans(\"outputs/test_gpt_3.5_turbo_LLMLingua_174.txt\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
}
},
"nbformat": 4,
"nbformat_minor": 5
}