{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "1972a352-a0e3-41b7-81dc-dd4ae2b890c3",
   "metadata": {},
   "source": [
    "## In-Context Learning, Chain-of-Thought, Reasoning"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "05d999bc-83a3-454f-a8a4-44cbff1fcedc",
   "metadata": {},
   "source": [
    "<a target=\"_blank\" href=\"https://colab.research.google.com/github/microsoft/LLMLingua/blob/main/examples/CoT.ipynb\">\r\n",
    " <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\r\n",
    "</a>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fe3ed1ce-d38d-4048-9db6-9707b55dc642",
   "metadata": {},
   "source": [
    "**In-Context Learning (ICL)** is a unique capability of large language models (LLMs), allowing them to quickly learn a task from just a few examples. ICL is generally combined with the Chain-of-Thought (CoT) approach, in which the examples describe the reasoning process in detail to enhance the LLMs' reasoning abilities. For instance, Fu et al.'s Complexity-Based Prompting improved GSM8K accuracy from 74.9 to 78.85 with GPT-3.5-Turbo-0301. However, this also leads to increasingly lengthy prompts; the GSM8K prompt, for example, contains **2,366** tokens."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ae003ead-2f07-44a4-b641-2e33be920dd9",
   "metadata": {},
   "source": [
    "<center><img width=\"800\" src=\"../images/LLMLingua_framework.png\"></center>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0b39b33f-5860-4825-8f00-d60aed0dce86",
   "metadata": {},
   "source": [
    "To address this, we propose [**LLMLingua**](https://arxiv.org/abs/2310.05736), which uses a well-trained small language model after alignment, such as GPT2-small or LLaMA-7B, to detect the unimportant tokens in the prompt, enabling inference with the compressed prompt in black-box LLMs and achieving up to **20x** compression with minimal performance loss."
   ]
  },
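  {
   "cell_type": "markdown",
   "id": "ppl-sketch-leadin",
   "metadata": {},
   "source": [
    "To build intuition for the core idea, the sketch below scores each prompt token by its negative log-likelihood under a small causal LM (GPT-2 here, purely for illustration). Tokens the small model predicts easily carry little information and are natural candidates for removal. This is only a minimal sketch of perplexity-based token scoring, not the full LLMLingua algorithm, which adds budget control, coarse-grained demonstration filtering, iterative token-level compression, and distribution alignment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ppl-sketch-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# A minimal sketch of perplexity-based token scoring (illustration only,\n",
    "# not the full LLMLingua algorithm).\n",
    "import torch\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "\n",
    "tok = AutoTokenizer.from_pretrained(\"gpt2\")\n",
    "lm = AutoModelForCausalLM.from_pretrained(\"gpt2\")\n",
    "\n",
    "text = \"Let's think step by step. Josh buys a house for $80,000.\"\n",
    "ids = tok(text, return_tensors=\"pt\").input_ids\n",
    "with torch.no_grad():\n",
    "    logits = lm(ids).logits\n",
    "\n",
    "# Negative log-likelihood of each token given its prefix\n",
    "# (the first token has no prefix, so it is skipped).\n",
    "nll = torch.nn.functional.cross_entropy(logits[0, :-1], ids[0, 1:], reduction=\"none\")\n",
    "for token, score in zip(tok.convert_ids_to_tokens(ids[0, 1:].tolist()), nll.tolist()):\n",
    "    print(f\"{score:6.2f}  {token!r}\")  # low score -> candidate for removal"
   ]
  },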
  {
   "cell_type": "markdown",
   "id": "18422597-687a-43aa-a6ed-ce6244d0eb55",
   "metadata": {},
   "source": [
    "### GSM8K"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "51a7accd-5ec2-4ed2-9582-1afdb441a998",
   "metadata": {},
   "source": [
    "Next, we will demonstrate the use of LLMLingua to compress the chain-of-thought prompt for GSM8K. The original 8-shot prompt can be found at https://github.com/FranxYao/chain-of-thought-hub/blob/main/gsm8k/lib_prompt/prompt_hardest.txt and contains 2,366 tokens."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a970a901-11bd-43af-a8bc-7fb2fc6a1a07",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Defaulting to user installation because normal site-packages is not writeable\n",
      "Requirement already satisfied: llmlingua in /home/hjiang/Code/github/LLMLingua (0.1.2)\n",
      "Requirement already satisfied: datasets in /home/hjiang/.local/lib/python3.9/site-packages (2.14.4)\n",
      "Requirement already satisfied: nltk in /home/hjiang/.local/lib/python3.9/site-packages (from llmlingua) (3.8.1)\n",
      "Requirement already satisfied: numpy in /home/hjiang/.local/lib/python3.9/site-packages (from llmlingua) (1.23.5)\n",
      "Requirement already satisfied: tiktoken in /home/hjiang/.local/lib/python3.9/site-packages (from llmlingua) (0.4.0)\n",
      "Requirement already satisfied: torch in /home/hjiang/.local/lib/python3.9/site-packages (from llmlingua) (1.13.1+cu116)\n",
      "Requirement already satisfied: transformers>=4.26.0 in /home/hjiang/.local/lib/python3.9/site-packages (from llmlingua) (4.34.1)\n",
      "Requirement already satisfied: pyarrow>=8.0.0 in /home/hjiang/.local/lib/python3.9/site-packages (from datasets) (11.0.0)\n",
      "Requirement already satisfied: dill<0.3.8,>=0.3.0 in /home/hjiang/.local/lib/python3.9/site-packages (from datasets) (0.3.7)\n",
      "Requirement already satisfied: pandas in /home/hjiang/.local/lib/python3.9/site-packages (from datasets) (2.0.3)\n",
      "Requirement already satisfied: requests>=2.19.0 in /home/hjiang/.local/lib/python3.9/site-packages (from datasets) (2.29.0)\n",
      "Requirement already satisfied: tqdm>=4.62.1 in /home/hjiang/.local/lib/python3.9/site-packages (from datasets) (4.65.0)\n",
      "Requirement already satisfied: xxhash in /home/hjiang/.local/lib/python3.9/site-packages (from datasets) (3.3.0)\n",
      "Requirement already satisfied: multiprocess in /home/hjiang/.local/lib/python3.9/site-packages (from datasets) (0.70.15)\n",
      "Requirement already satisfied: fsspec[http]>=2021.11.1 in /home/hjiang/.local/lib/python3.9/site-packages (from datasets) (2023.6.0)\n",
      "Requirement already satisfied: aiohttp in /home/hjiang/.local/lib/python3.9/site-packages (from datasets) (3.8.5)\n",
      "Requirement already satisfied: huggingface-hub<1.0.0,>=0.14.0 in /home/hjiang/.local/lib/python3.9/site-packages (from datasets) (0.16.4)\n",
      "Requirement already satisfied: packaging in /home/hjiang/.local/lib/python3.9/site-packages (from datasets) (23.0)\n",
      "Requirement already satisfied: pyyaml>=5.1 in /usr/lib/python3/dist-packages (from datasets) (5.3.1)\n",
      "Requirement already satisfied: attrs>=17.3.0 in /home/hjiang/.local/lib/python3.9/site-packages (from aiohttp->datasets) (23.1.0)\n",
      "Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /home/hjiang/.local/lib/python3.9/site-packages (from aiohttp->datasets) (3.2.0)\n",
      "Requirement already satisfied: multidict<7.0,>=4.5 in /home/hjiang/.local/lib/python3.9/site-packages (from aiohttp->datasets) (6.0.4)\n",
      "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /home/hjiang/.local/lib/python3.9/site-packages (from aiohttp->datasets) (4.0.2)\n",
      "Requirement already satisfied: yarl<2.0,>=1.0 in /home/hjiang/.local/lib/python3.9/site-packages (from aiohttp->datasets) (1.9.2)\n",
      "Requirement already satisfied: frozenlist>=1.1.1 in /home/hjiang/.local/lib/python3.9/site-packages (from aiohttp->datasets) (1.4.0)\n",
      "Requirement already satisfied: aiosignal>=1.1.2 in /home/hjiang/.local/lib/python3.9/site-packages (from aiohttp->datasets) (1.3.1)\n",
      "Requirement already satisfied: filelock in /home/hjiang/.local/lib/python3.9/site-packages (from huggingface-hub<1.0.0,>=0.14.0->datasets) (3.12.2)\n",
      "Requirement already satisfied: typing-extensions>=3.7.4.3 in /home/hjiang/.local/lib/python3.9/site-packages (from huggingface-hub<1.0.0,>=0.14.0->datasets) (4.7.1)\n",
      "Requirement already satisfied: idna<4,>=2.5 in /usr/lib/python3/dist-packages (from requests>=2.19.0->datasets) (2.8)\n",
      "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /home/hjiang/.local/lib/python3.9/site-packages (from requests>=2.19.0->datasets) (1.26.16)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /usr/lib/python3/dist-packages (from requests>=2.19.0->datasets) (2019.11.28)\n",
      "Requirement already satisfied: regex!=2019.12.17 in /home/hjiang/.local/lib/python3.9/site-packages (from transformers>=4.26.0->llmlingua) (2023.6.3)\n",
      "Requirement already satisfied: tokenizers<0.15,>=0.14 in /home/hjiang/.local/lib/python3.9/site-packages (from transformers>=4.26.0->llmlingua) (0.14.1)\n",
      "Requirement already satisfied: safetensors>=0.3.1 in /home/hjiang/.local/lib/python3.9/site-packages (from transformers>=4.26.0->llmlingua) (0.3.1)\n",
      "Requirement already satisfied: click in /home/hjiang/.local/lib/python3.9/site-packages (from nltk->llmlingua) (8.1.6)\n",
      "Requirement already satisfied: joblib in /home/hjiang/.local/lib/python3.9/site-packages (from nltk->llmlingua) (1.3.1)\n",
      "Requirement already satisfied: python-dateutil>=2.8.2 in /home/hjiang/.local/lib/python3.9/site-packages (from pandas->datasets) (2.8.2)\n",
      "Requirement already satisfied: pytz>=2020.1 in /home/hjiang/.local/lib/python3.9/site-packages (from pandas->datasets) (2023.3)\n",
      "Requirement already satisfied: tzdata>=2022.1 in /home/hjiang/.local/lib/python3.9/site-packages (from pandas->datasets) (2023.3)\n",
      "Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.14.0)\n",
      "\n",
      "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.2.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.3.1\u001b[0m\n",
      "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython3.9 -m pip install --upgrade pip\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "# Install the dependencies.\n",
    "!pip install llmlingua datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "641235b6-71a5-4f2a-8eec-272c73931bef",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2023-10-30 09:15:31--  https://raw.githubusercontent.com/FranxYao/chain-of-thought-hub/main/gsm8k/lib_prompt/prompt_hardest.txt\n",
      "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.108.133, 185.199.110.133, ...\n",
      "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 8464 (8.3K) [text/plain]\n",
      "Saving to: ‘prompt_hardest.txt’\n",
      "\n",
      "prompt_hardest.txt  100%[===================>]   8.27K  --.-KB/s    in 0s\n",
      "\n",
      "2023-10-30 09:15:31 (78.8 MB/s) - ‘prompt_hardest.txt’ saved [8464/8464]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Download the original prompt and dataset\n",
    "from datasets import load_dataset\n",
    "\n",
    "!wget https://raw.githubusercontent.com/FranxYao/chain-of-thought-hub/main/gsm8k/lib_prompt/prompt_hardest.txt\n",
    "prompt_complex = open(\"./prompt_hardest.txt\").read()\n",
    "gsm8k = load_dataset(\"gsm8k\", \"main\")\n",
    "gsm8k_test = gsm8k[\"test\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "cbbbf3de-a9d6-46cf-afab-dcb72a6154ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the OpenAI API\n",
    "import openai\n",
    "\n",
    "openai.api_key = \"<insert_openai_key>\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "46506810-8565-43da-984b-d862c56b49c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Or use the Azure OpenAI (AOAI) API\n",
    "import openai\n",
    "\n",
    "openai.api_key = \"<insert_openai_key>\"\n",
    "openai.api_base = \"https://xxxx.openai.azure.com/\"\n",
    "openai.api_type = \"azure\"\n",
    "openai.api_version = \"2023-05-15\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f8676ffa-5117-44dc-9742-bb9ab1d56e0c",
   "metadata": {},
   "source": [
    "### Set Up the Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "cc17bbc5-86cb-4d15-a730-955af85a10b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select an example from the GSM8K test set\n",
    "question, answer = [gsm8k_test[2][key] for key in [\"question\", \"answer\"]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "58718a19-cc4e-4002-a92a-58ea3de9c9d0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Question: Josh decides to try flipping a house. He buys a house for $80,000 and then puts in $50,000 in repairs. This increased the value of the house by 150%. How much profit did he make?\n",
      "Answer: The cost of the house and repairs came out to 80,000+50,000=$<<80000+50000=130000>>130,000\n",
      "He increased the value of the house by 80,000*1.5=<<80000*1.5=120000>>120,000\n",
      "So the new value of the house is 120,000+80,000=$<<120000+80000=200000>>200,000\n",
      "So he made a profit of 200,000-130,000=$<<200000-130000=70000>>70,000\n",
      "#### 70000\n"
     ]
    }
   ],
   "source": [
    "# Ground-truth answer\n",
    "print(\"Question:\", question)\n",
    "print(\"Answer:\", answer)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ba1c6d52-dc87-434c-a41c-0bbc8a286504",
   "metadata": {},
   "source": [
    "#### The response to the original prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "3d441f10-c5c7-4d45-b09a-717e536b36bf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\n",
      "    \"id\": \"cmpl-8FZvcX70FH7ck9c9MegWmnUocH0A0\",\n",
      "    \"object\": \"text_completion\",\n",
      "    \"created\": 1698723720,\n",
      "    \"model\": \"gpt-35-turbo\",\n",
      "    \"choices\": [\n",
      "        {\n",
      "            \"text\": \" \\nLet's think step by step\\nThe repairs increased the value of the house by 150% so that means it increased by 80,000*1.5=$<<80000*1.5=120000>>120,000\\nSo the total value of the house is 80,000+120,000=$<<80000+120000=200000>>200,000\\nHe spent 80,000+50,000=$<<80000+50000=130000>>130,000\\nSo he made a profit of 200,000-130,000=$<<200000-130000=70000>>70,000\\nThe answer is 70,000\",\n",
      "            \"index\": 0,\n",
      "            \"finish_reason\": \"stop\",\n",
      "            \"logprobs\": null\n",
      "        }\n",
      "    ],\n",
      "    \"usage\": {\n",
      "        \"prompt_tokens\": 2428,\n",
      "        \"completion_tokens\": 142,\n",
      "        \"total_tokens\": 2570\n",
      "    }\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "# Get the response for the original prompt\n",
    "import json\n",
    "\n",
    "instruction = \"Please reference the following examples to answer the math question,\\n\"\n",
    "prompt = instruction + prompt_complex + \"\\n\\nQuestion: \" + question\n",
    "\n",
    "request_data = {\n",
    "    \"prompt\": prompt,\n",
    "    \"max_tokens\": 400,\n",
    "    \"temperature\": 0,\n",
    "    \"top_p\": 1,\n",
    "    \"n\": 1,\n",
    "    \"stream\": False,\n",
    "    \"stop\": \"\\n\\n\",\n",
    "}\n",
    "response = openai.Completion.create(\n",
    "    model=\"gpt-3.5-turbo-0301\",  # pass the model as a keyword argument; on Azure, use engine=\"<deployment_name>\" instead\n",
    "    **request_data,\n",
    ")\n",
    "print(json.dumps(response, indent=4))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9aa90492-8ad1-4a89-85c5-26b8472f1ff0",
   "metadata": {},
   "source": [
    "#### The response to the compressed prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "fa638dec-c9ec-4dce-9dac-d768145de714",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8ec90053e7274da59973427652f879a1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/hjiang/.local/lib/python3.9/site-packages/transformers/generation/configuration_utils.py:362: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.9` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.\n",
      "  warnings.warn(\n",
      "/home/hjiang/.local/lib/python3.9/site-packages/transformers/generation/configuration_utils.py:367: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.6` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "# Set up LLMLingua\n",
    "from llmlingua import PromptCompressor\n",
    "\n",
    "llm_lingua = PromptCompressor()"
   ]
  },
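  {
   "cell_type": "markdown",
   "id": "promptcompressor-note",
   "metadata": {},
   "source": [
    "By default, `PromptCompressor()` loads a LLaMA-2-7B-class small language model (hence the two checkpoint shards above), which needs a GPU with enough memory. If that is too heavy for your machine, a smaller compressor can be substituted; the cell below is a sketch, and the exact `model_name` / `device_map` values are assumptions you may need to adjust for your environment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "promptcompressor-small",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: a lighter-weight compressor (a sketch; adjust model_name and\n",
    "# device_map to your environment). Smaller models trade some compression\n",
    "# quality for speed and memory.\n",
    "# llm_lingua = PromptCompressor(model_name=\"gpt2\", device_map=\"cpu\")"
   ]
  },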
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "5f61a186-6641-4118-ad04-5245a53b6d79",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\n",
      "    \"compressed_prompt\": \"Question: Sam bought a dozen boxes, each with 30 highlighter pens inside, for $10 each. He reanged five of boxes into packages of sixlters each and sold them $3 per. He sold the rest theters separately at the of three pens $2. How much did make in total, dollars?\\nLets think step step\\nSam bought 1 boxes x00 oflters.\\nHe bought 12 00ters in total\\nSam then took5 boxes 6ters0ters\\nHe sold these boxes for 5 *5\\nAfterelling these boxes there were 30330ters remaining\\nese form 330 /30 of three\\n sold each for2 each, so made * =0 from\\n total, he0 $15\\nSince his original1 he earned $120 = $115 in profit.\\nThe answer is 115\",\n",
      "    \"origin_tokens\": 2365,\n",
      "    \"compressed_tokens\": 174,\n",
      "    \"ratio\": \"13.6x\",\n",
      "    \"saving\": \", Saving $0.1 in GPT-4.\"\n",
      "}\n",
      "Response: {\n",
      "  \"id\": \"cmpl-8FZwYp1QIwiQs6pEhy2cRK6wnLnAO\",\n",
      "  \"object\": \"text_completion\",\n",
      "  \"created\": 1698723778,\n",
      "  \"model\": \"gpt-35-turbo\",\n",
      "  \"choices\": [\n",
      "    {\n",
      "      \"text\": \" \\n\\nThe repairs increased the value of the house by 150% so that means it increased by 80000*1.5=$<<80000*1.5=120000>>120,000\\nSo the total value of the house is 120,000+80,000=$<<120000+80000=200000>>200,000\\nThat means he made a profit of 200,000-80,000-50,000=$<<200000-80000-50000=70000>>70,000. Answer: \\\\boxed{70,000}.<|im_end|>\",\n",
      "      \"index\": 0,\n",
      "      \"finish_reason\": \"stop\",\n",
      "      \"logprobs\": null\n",
      "    }\n",
      "  ],\n",
      "  \"usage\": {\n",
      "    \"prompt_tokens\": 237,\n",
      "    \"completion_tokens\": 120,\n",
      "    \"total_tokens\": 357\n",
      "  }\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "# Compress the prompt to ~174 tokens (13.6x compression)\n",
    "compressed_prompt = llm_lingua.compress_prompt(\n",
    "    prompt_complex.split(\"\\n\\n\"),\n",
    "    instruction=\"\",\n",
    "    question=\"\",\n",
    "    target_token=200,\n",
    "    context_budget=\"*1.5\",\n",
    "    iterative_size=100,\n",
    ")\n",
    "\n",
    "instruction = \"Please reference the following examples to answer the math question,\\n\"\n",
    "prompt = instruction + compressed_prompt[\"compressed_prompt\"] + \"\\n\\nQuestion: \" + question\n",
    "\n",
    "request_data = {\n",
    "    \"prompt\": prompt,\n",
    "    \"max_tokens\": 400,\n",
    "    \"temperature\": 0,\n",
    "    \"top_p\": 1,\n",
    "    \"n\": 1,\n",
    "    \"stream\": False,\n",
    "    \"stop\": \"\\r\\n\",\n",
    "}\n",
    "response = openai.Completion.create(\n",
    "    model=\"gpt-3.5-turbo-0301\",  # keyword argument, as above\n",
    "    **request_data,\n",
    ")\n",
    "print(json.dumps(compressed_prompt, indent=4))\n",
    "print(\"Response:\", response)"
   ]
  },
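  {
   "cell_type": "markdown",
   "id": "compress-args-note",
   "metadata": {},
   "source": [
    "A quick note on the arguments above (see the LLMLingua documentation for the full details): splitting the prompt on `\"\\n\\n\"` treats each demonstration as one compressible context; `target_token=200` sets the overall token budget for the compressed prompt; `context_budget=\"*1.5\"` lets the coarse-grained, demonstration-level stage keep up to 1.5x that budget before token-level compression trims it down; and `iterative_size=100` performs the token-level compression in 100-token segments, recomputing perplexities as earlier segments are pruned."
   ]
  },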
  {
   "cell_type": "markdown",
   "id": "1f89bb0f-7959-4a14-95be-dc80d88ce576",
   "metadata": {},
   "source": [
    "### Test on the GSM8K test set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "c1ac9bf5-23a9-446c-9394-8bb19aa1d89d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "\n",
    "def extract_ans(ans_model):\n",
    "    # Keep everything up to and including the first \"answer is\" line;\n",
    "    # whatever follows is returned separately as residual text.\n",
    "    ans_model = ans_model.split(\"\\n\")\n",
    "    ans = []\n",
    "    residual = []\n",
    "    for li, al in enumerate(ans_model):\n",
    "        ans.append(al)\n",
    "        if \"answer is\" in al:\n",
    "            break\n",
    "    residual = list(ans_model[li + 1 :])\n",
    "    ans = \"\\n\".join(ans)\n",
    "    residual = \"\\n\".join(residual)\n",
    "    return ans, residual\n",
    "\n",
    "\n",
    "def parse_pred_ans(filename):\n",
    "    with open(filename) as fd:\n",
    "        lines = fd.readlines()\n",
    "    am, a = None, None\n",
    "    num_q, acc = 0, 0\n",
    "    current_mode = \"none\"\n",
    "    questions = []\n",
    "    ans_pred = []\n",
    "    ans_gold = []\n",
    "    for l in lines:\n",
    "        # Drop thousands separators so \"70,000\" and \"70000\" compare equal.\n",
    "        l = l.replace(\",\", \"\")\n",
    "        if l.startswith(\"Q: \"):\n",
    "            if am is not None and a is not None:\n",
    "                questions.append(q)\n",
    "                ans_pred.append(am)\n",
    "                ans_gold.append(a)\n",
    "                if test_answer(am, a):\n",
    "                    acc += 1\n",
    "            current_mode = \"q\"\n",
    "            q = l\n",
    "            num_q += 1\n",
    "        elif l.startswith(\"A_model:\"):\n",
    "            current_mode = \"am\"\n",
    "            am = l\n",
    "        elif l.startswith(\"A:\"):\n",
    "            current_mode = \"a\"\n",
    "            a = l\n",
    "        else:\n",
    "            if current_mode == \"q\":\n",
    "                q += l\n",
    "            elif current_mode == \"am\":\n",
    "                am += l\n",
    "            elif current_mode == \"a\":\n",
    "                a += l\n",
    "            else:\n",
    "                raise ValueError(current_mode)\n",
    "\n",
    "    questions.append(q)\n",
    "    ans_pred.append(am)\n",
    "    ans_gold.append(a)\n",
    "    if test_answer(am, a):\n",
    "        acc += 1\n",
    "    print(\"num_q %d correct %d ratio %.4f\" % (num_q, acc, acc / num_q))\n",
    "    return questions, ans_pred, ans_gold\n",
    "\n",
    "\n",
    "def get_result(text: str):\n",
    "    # Return the last number in the text, which is taken as the answer.\n",
    "    pattern = r\"\\d*\\.?\\d+\"\n",
    "    res = re.findall(pattern, text)\n",
    "    return res[-1] if res else \"\"\n",
    "\n",
    "\n",
    "def test_answer(pred_str, ans_str):\n",
    "    pred, gold = get_result(pred_str), get_result(ans_str)\n",
    "    return pred == gold"
   ]
  },
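  {
   "cell_type": "markdown",
   "id": "helpers-sanity-leadin",
   "metadata": {},
   "source": [
    "As a quick sanity check of the helpers above, the hypothetical model answer below (written for this demo, not produced by the model) should parse to the prediction `70000` and match the gold answer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "helpers-sanity-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# A synthetic model answer (made up for this check, not model output).\n",
    "demo_answer = \"He spent 130000 in total\\nSo he made 70000\\nThe answer is 70000\\ntrailing text\"\n",
    "ans, residual = extract_ans(demo_answer)\n",
    "print(ans)  # everything up to and including the \"answer is\" line\n",
    "print(residual)  # \"trailing text\"\n",
    "print(test_answer(ans, \"#### 70000\"))  # True: the final numbers match"
   ]
  },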
  {
   "cell_type": "code",
   "execution_count": 66,
   "id": "cb209d5a-f822-4734-afc5-dafc07cc1bbc",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1319/1319 [47:55<00:00,  2.18s/it]\n"
     ]
    }
   ],
   "source": [
    "# Evaluate on the GSM8K test set\n",
    "import os\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "os.makedirs(\"outputs\", exist_ok=True)\n",
    "\n",
    "compressed_prompt = llm_lingua.compress_prompt(\n",
    "    prompt_complex.split(\"\\n\\n\"),\n",
    "    instruction=\"\",\n",
    "    question=\"\",\n",
    "    target_token=200,\n",
    "    context_budget=\"*1.5\",\n",
    "    iterative_size=100,\n",
    ")\n",
    "\n",
    "for q, a in tqdm(\n",
    "    zip(gsm8k_test[\"question\"], gsm8k_test[\"answer\"]),\n",
    "    total=len(gsm8k_test[\"question\"]),\n",
    "):\n",
    "    instruction = \"Please reference the following examples to answer the math question,\\n\"\n",
    "    prompt = instruction + compressed_prompt[\"compressed_prompt\"] + \"\\n\\nQuestion: \" + q + \"\\n\"\n",
    "\n",
    "    request_data = {\n",
    "        \"prompt\": prompt,\n",
    "        \"max_tokens\": 400,\n",
    "        \"temperature\": 0,\n",
    "        \"top_p\": 1,\n",
    "        \"n\": 1,\n",
    "        \"stream\": False,\n",
    "    }\n",
    "    response = openai.Completion.create(\n",
    "        model=\"gpt-3.5-turbo-0301\",  # keyword argument, as above\n",
    "        **request_data,\n",
    "    )\n",
    "    ans_model = response[\"choices\"][0][\"text\"]\n",
    "    ans_, residual = extract_ans(ans_model)\n",
    "    with open(\"outputs/test_gpt_3.5_turbo_LLMLingua_174.txt\", \"a\") as fd:\n",
    "        fd.write(\"Q: %s\\nA_model:\\n%s\\nA:\\n%s\\n\\n\" % (q, ans_.replace(\"Q:\", \"\").replace(\"A:\", \"\"), a))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "id": "3a35d298-8596-4b92-8dda-8da4250c873c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num_q 1319 correct 1032 ratio 0.7824\n"
     ]
    }
   ],
   "source": [
    "_ = parse_pred_ans(\"outputs/test_gpt_3.5_turbo_LLMLingua_174.txt\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}