{ "cells": [ { "cell_type": "markdown", "id": "1972a352-a0e3-41b7-81dc-dd4ae2b890c3", "metadata": {}, "source": [ "## Retrieval-Augmented Generation (RAG)" ] }, { "cell_type": "markdown", "id": "05d999bc-83a3-454f-a8a4-44cbff1fcedc", "metadata": {}, "source": [ "\r\n", " \"Open\r\n", "" ] }, { "cell_type": "markdown", "id": "fe3ed1ce-d38d-4048-9db6-9707b55dc642", "metadata": {}, "source": [ "Retrieval-Augmented Generation (RAG) is a powerful and popular technique that applies specialized knowledge to large language models (LLMs). However, traditional RAG methods tend to have increasingly long prompts, sometimes exceeding **40k**, which can result in high financial and latency costs. Moreover, the decreased information density within the prompts can lead to performance degradation in LLMs, such as the \"lost in the middle\" issue." ] }, { "cell_type": "markdown", "id": "ae003ead-2f07-44a4-b641-2e33be920dd9", "metadata": {}, "source": [ "
" ] }, { "cell_type": "markdown", "id": "0b39b33f-5860-4825-8f00-d60aed0dce86", "metadata": {}, "source": [ "To address this, we propose [**LongLLMLingua**](https://arxiv.org/abs/2310.06839), which specifically tackles the low information density problem in long context scenarios via prompt compression, making it particularly suitable for RAG tasks. The main ideas involve a two-stage compression process, as shown by the **red line**, which significantly improves the original curve:\n", "\n", "- Coarse-grained compression through document-level perplexity;\n", "- Fine-grained compression of the remaining text using token perplexity;" ] }, { "cell_type": "markdown", "id": "c748f877-4bbf-443c-b72b-332be1df6f1a", "metadata": {}, "source": [ "Instead of fighting against positional effects, we aim to utilize them to our advantage through document reordering, as illustrated by the **green line**. In this approach, the most critical passages are placed at the beginning and the end of the context. Furthermore, the entire process becomes more **cost-effective and faster** since it only requires handling **1/4** of the original context." ] }, { "cell_type": "markdown", "id": "18422597-687a-43aa-a6ed-ce6244d0eb55", "metadata": {}, "source": [ "### NaturalQuestions Multi-document QA" ] }, { "cell_type": "markdown", "id": "51a7accd-5ec2-4ed2-9582-1afdb441a998", "metadata": {}, "source": [ "Next, we will demonstrate the use of LongLLMLingua on the NaturalQuestions dataset, which effectively alleviates the \"lost in the middle\" issue. This dataset closely resembles real-world RAG scenarios, as it first employs the Contriever retrieval system to recall 20 relevant documents (including 1 ground truth and 19 related documents), and then answers the respective questions based on the prompts composed of these 20 documents.\n", "\n", "The original dataset can be found in https://github.com/nelson-liu/lost-in-the-middle/tree/main/qa_data." ] }, { "cell_type": "code", "execution_count": 6, "id": "a970a901-11bd-43af-a8bc-7fb2fc6a1a07", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Cloning into 'lost-in-the-middle'...\n", "remote: Enumerating objects: 101, done.\u001b[K\n", "remote: Counting objects: 100% (78/78), done.\u001b[K\n", "remote: Compressing objects: 100% (52/52), done.\u001b[K\n", "remote: Total 101 (delta 33), reused 61 (delta 17), pack-reused 23\u001b[K\n", "Receiving objects: 100% (101/101), 254.44 MiB | 48.70 MiB/s, done.\n", "Resolving deltas: 100% (33/33), done.\n", "Defaulting to user installation because normal site-packages is not writeable\n", "Obtaining file:///home/hjiang/Code/github/LLMLingua/examples/lost-in-the-middle\n", " Installing build dependencies ... \u001b[?25ldone\n", "\u001b[?25h Checking if build backend supports build_editable ... \u001b[?25ldone\n", "\u001b[?25h Getting requirements to build editable ... \u001b[?25ldone\n", "\u001b[?25h Preparing editable metadata (pyproject.toml) ... \u001b[?25ldone\n", "\u001b[?25hRequirement already satisfied: xopen in /home/hjiang/.local/lib/python3.9/site-packages (from lost-in-the-middle==0.0.0) (1.7.0)\n", "Requirement already satisfied: isal>=1.0.0 in /home/hjiang/.local/lib/python3.9/site-packages (from xopen->lost-in-the-middle==0.0.0) (1.2.0)\n", "Building wheels for collected packages: lost-in-the-middle\n", " Building editable for lost-in-the-middle (pyproject.toml) ... 
{ "cell_type": "markdown", "id": "18422597-687a-43aa-a6ed-ce6244d0eb55", "metadata": {}, "source": [ "### NaturalQuestions Multi-document QA" ] }, { "cell_type": "markdown", "id": "51a7accd-5ec2-4ed2-9582-1afdb441a998", "metadata": {}, "source": [ "Next, we demonstrate LongLLMLingua on the NaturalQuestions dataset, where it effectively alleviates the \"lost in the middle\" issue. This dataset closely resembles real-world RAG scenarios: the Contriever retrieval system first recalls 20 relevant documents (1 ground-truth and 19 related documents), and each question is then answered from a prompt composed of these 20 documents.\n", "\n", "The original dataset can be found at https://github.com/nelson-liu/lost-in-the-middle/tree/main/qa_data." ] }, { "cell_type": "code", "execution_count": 6, "id": "a970a901-11bd-43af-a8bc-7fb2fc6a1a07", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Cloning into 'lost-in-the-middle'...\n", "remote: Enumerating objects: 101, done.\u001b[K\n", "remote: Counting objects: 100% (78/78), done.\u001b[K\n", "remote: Compressing objects: 100% (52/52), done.\u001b[K\n", "remote: Total 101 (delta 33), reused 61 (delta 17), pack-reused 23\u001b[K\n", "Receiving objects: 100% (101/101), 254.44 MiB | 48.70 MiB/s, done.\n", "Resolving deltas: 100% (33/33), done.\n", "Defaulting to user installation because normal site-packages is not writeable\n", "Obtaining file:///home/hjiang/Code/github/LLMLingua/examples/lost-in-the-middle\n", " Installing build dependencies ... \u001b[?25ldone\n", "\u001b[?25h Checking if build backend supports build_editable ... \u001b[?25ldone\n", "\u001b[?25h Getting requirements to build editable ... \u001b[?25ldone\n", "\u001b[?25h Preparing editable metadata (pyproject.toml) ... \u001b[?25ldone\n", "\u001b[?25hRequirement already satisfied: xopen in /home/hjiang/.local/lib/python3.9/site-packages (from lost-in-the-middle==0.0.0) (1.7.0)\n", "Requirement already satisfied: isal>=1.0.0 in /home/hjiang/.local/lib/python3.9/site-packages (from xopen->lost-in-the-middle==0.0.0) (1.2.0)\n", "Building wheels for collected packages: lost-in-the-middle\n", " Building editable for lost-in-the-middle (pyproject.toml) ... \u001b[?25ldone\n", "\u001b[?25h Created wheel for lost-in-the-middle: filename=lost_in_the_middle-0.0.0-0.editable-py3-none-any.whl size=4611 sha256=2c670631c3bce6e2ca5b87fdc43e73402f33cc2b96aceaa3c89b4ae22f3de741\n", " Stored in directory: /tmp/pip-ephem-wheel-cache-y7iw2jwb/wheels/1e/ff/75/6c31681b19235602b007f32c4ec397e7e2eeacc2c76fcefcde\n", "Successfully built lost-in-the-middle\n", "Installing collected packages: lost-in-the-middle\n", " Attempting uninstall: lost-in-the-middle\n", " Found existing installation: lost-in-the-middle 0.0.0\n", " Uninstalling lost-in-the-middle-0.0.0:\n", " Successfully uninstalled lost-in-the-middle-0.0.0\n", "Successfully installed lost-in-the-middle-0.0.0\n", "\n", "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.2.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.3.1\u001b[0m\n", "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython3.9 -m pip install --upgrade pip\u001b[0m\n" ] } ], "source": [ "# Install dependencies.\n", "## lost-in-the-middle (its requirements are trimmed to xopen to keep the install light)\n", "!git clone https://github.com/nelson-liu/lost-in-the-middle\n", "!cd lost-in-the-middle && echo \"xopen\" > requirements.txt && pip install -e .\n", "## LLMLingua\n", "!pip install llmlingua" ] }, { "cell_type": "code", "execution_count": 8, "id": "cbbbf3de-a9d6-46cf-afab-dcb72a6154ec", "metadata": {}, "outputs": [], "source": [ "# Using the OpenAI API\n", "import openai\n", "openai.api_key = \"\"  # your OpenAI API key" ] }, { "cell_type": "code", "execution_count": 42, "id": "46506810-8565-43da-984b-d862c56b49c2", "metadata": {}, "outputs": [], "source": [ "# Or using the Azure OpenAI API\n", "import openai\n", "openai.api_key = \"\"  # your Azure OpenAI API key\n", "openai.api_base = \"https://xxxx.openai.azure.com/\"\n", "openai.api_type = \"azure\"\n", "openai.api_version = \"2023-05-15\"" ] }, { "cell_type": "markdown", "id": "f8676ffa-5117-44dc-9742-bb9ab1d56e0c", "metadata": {}, "source": [ "### Setup Data" ] }, { "cell_type": "code", "execution_count": 12, "id": "bb349566-83d8-44ac-a683-b67ed9ddf7a6", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 2655/2655 [00:01<00:00, 1550.38it/s]\n" ] } ], "source": [ "import json\n", "from xopen import xopen\n", "from copy import deepcopy\n", "from tqdm import tqdm\n", "from lost_in_the_middle.prompting import (\n", "    Document,\n", "    get_qa_prompt,\n", ")\n", "\n", "datasets = []\n", "path = \"./lost-in-the-middle/qa_data/20_total_documents/nq-open-20_total_documents_gold_at_9.jsonl.gz\"\n", "with xopen(path) as f:\n", "    for ii, jj in tqdm(enumerate(f), total=2655):\n", "        input_example = json.loads(jj)\n", "        question = input_example[\"question\"]\n", "        documents = []\n", "        for ctx in deepcopy(input_example[\"ctxs\"]):\n", "            documents.append(Document.from_dict(ctx))\n", "\n", "        prompt = get_qa_prompt(\n", "            question,\n", "            documents,\n", "            mention_random_ordering=False,\n", "            query_aware_contextualization=False,\n", "        )\n", "\n", "        # Split the full prompt into instruction, document block, and question line.\n", "        c = prompt.split(\"\\n\\n\")\n", "        instruction, question = c[0], c[-1]\n", "        demonstration = \"\\n\".join(c[1:-1])\n", "        datasets.append({\"id\": ii, \"instruction\": instruction, \"demonstration\": demonstration, \"question\": question, \"answer\": input_example[\"answers\"]})" ] },
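{ "cell_type": "markdown", "id": "9f5e3c4a-1b6d-4e8f-a07c-3d4e5f607182", "metadata": {}, "source": [ "As a quick sanity check (an added illustration, not part of the original notebook), the next cell inspects one processed record: each record carries the task instruction, the 20 retrieved documents flattened into a single string, the question line, and the list of gold answers. The `Document [` marker and the document count are assumptions based on the prompt format used by the lost-in-the-middle repository." ] }, { "cell_type": "code", "execution_count": null, "id": "a06f4d5b-2c7e-4f90-b18d-4e5f60718293", "metadata": {}, "outputs": [], "source": [ "# Sanity check: inspect the structure of one processed example.\n", "example = datasets[0]\n", "print(example[\"instruction\"])  # task instruction for the QA prompt\n", "print(example[\"question\"])     # the \"Question: ...\" line\n", "print(example[\"answer\"])       # list of acceptable gold answers\n", "# The file name (gold_at_9) indicates the ground-truth document sits at\n", "# 0-based index 9, i.e., position 10 of the 20 retrieved documents.\n", "print(example[\"demonstration\"].count(\"Document [\"))  # expected: 20" ] },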
NaturalQuestions\n", "instruction, demonstration_str, question, answer = [datasets[23][key] for key in [\"instruction\", \"demonstration\", \"question\", \"answer\"]]" ] }, { "cell_type": "code", "execution_count": 23, "id": "58718a19-cc4e-4002-a92a-58ea3de9c9d0", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['14']" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Ground-truth Answer\n", "answer" ] }, { "cell_type": "markdown", "id": "ba1c6d52-dc87-434c-a41c-0bbc8a286504", "metadata": {}, "source": [ "### The response of Original prompt (Error)" ] }, { "cell_type": "code", "execution_count": 25, "id": "3d441f10-c5c7-4d45-b09a-717e536b36bf", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{\n", " \"id\": \"chatcmpl-8FFZIQCjv9Dv5Q9WQcDmNBT1OJIP8\",\n", " \"object\": \"chat.completion\",\n", " \"created\": 1698645456,\n", " \"model\": \"gpt-35-turbo\",\n", " \"choices\": [\n", " {\n", " \"index\": 0,\n", " \"finish_reason\": \"stop\",\n", " \"message\": {\n", " \"role\": \"assistant\",\n", " \"content\": \"As of the provided search results, OPEC has 15 member countries.\"\n", " }\n", " }\n", " ],\n", " \"usage\": {\n", " \"prompt_tokens\": 2897,\n", " \"completion_tokens\": 15,\n", " \"total_tokens\": 2912\n", " }\n", "}\n" ] } ], "source": [ "# The response from original prompt, error\n", "prompt = \"\\n\\n\".join([instruction, demonstration_str, question])\n", "\n", "message = [\n", " {\"role\": \"user\", \"content\": prompt},\n", "]\n", "\n", "request_data = {\n", " \"messages\": message,\n", " \"max_tokens\": 100,\n", " \"temperature\": 0,\n", " \"top_p\": 1,\n", " \"n\": 1,\n", " \"stream\": False,\n", "}\n", "response = openai.ChatCompletion.create(\n", " \"gpt-3.5-turbo\",\n", " **request_data,\n", ")\n", "print(json.dumps(response, indent=4))" ] }, { "cell_type": "markdown", "id": "9aa90492-8ad1-4a89-85c5-26b8472f1ff0", "metadata": {}, "source": [ "### The response of Compressed Prompt (Correct in 10x Compression)" ] }, { "cell_type": "code", "execution_count": 29, "id": "fa638dec-c9ec-4dce-9dac-d768145de714", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "0cbd44bf14024a3291cce2187b1ec363", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading checkpoint shards: 0%| | 0/2 [00:00