mirror of
https://github.com/anthropics/claude-cookbooks.git
synced 2025-10-06 01:00:28 +03:00
Merge pull request #152 from anthropics/speculative-prompt-caching
Speculative prompt caching
485
misc/speculative_prompt_caching.ipynb
Normal file
@@ -0,0 +1,485 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "5c4d70b6",
"metadata": {},
"source": [
"# Speculative Prompt Caching\n",
"\n",
"This cookbook demonstrates \"Speculative Prompt Caching\" - a pattern that reduces time-to-first-token (TTFT) by warming up the cache while users are still formulating their queries.\n",
"\n",
"**Without Speculative Caching:**\n",
"1. User types their question (3 seconds)\n",
"2. User submits question\n",
"3. API loads context into cache AND generates response\n",
"\n",
"**With Speculative Caching:**\n",
"1. User starts typing (cache warming begins immediately)\n",
"2. User continues typing (cache warming completes in background)\n",
"3. User submits question\n",
"4. API uses warm cache to generate response\n",
]
},
{
"cell_type": "markdown",
"id": "66720f9a",
"metadata": {},
"source": [
"## Setup\n",
"\n",
"First, let's install the required packages:"
]
},
{
"cell_type": "code",
"execution_count": 82,
"id": "9a0adb31",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"%pip install anthropic httpx --quiet"
]
},
{
"cell_type": "code",
"execution_count": 83,
"id": "c4a98035",
"metadata": {},
"outputs": [],
"source": [
"import copy\n",
"import datetime\n",
"import time\n",
"import asyncio\n",
"import httpx\n",
"from anthropic import AsyncAnthropic\n",
"\n",
"# Configuration constants\n",
"MODEL = \"claude-3-5-sonnet-20241022\"\n",
"SQLITE_SOURCES = {\n",
" \"btree.h\": \"https://sqlite.org/src/raw/18e5e7b2124c23426a283523e5f31a4bff029131b795bb82391f9d2f3136fc50?at=btree.h\",\n",
" \"btree.c\": \"https://sqlite.org/src/raw/63ca6b647342e8cef643863cd0962a542f133e1069460725ba4461dcda92b03c?at=btree.c\",\n",
"}\n",
"DEFAULT_CLIENT_ARGS = {\n",
" \"system\": \"You are an expert systems programmer helping analyze database internals.\",\n",
" \"max_tokens\": 4096,\n",
" \"temperature\": 0,\n",
" \"extra_headers\": {\"anthropic-beta\": \"prompt-caching-2024-07-31\"},\n",
"}"
]
},
{
"cell_type": "markdown",
"id": "49d56e2f",
"metadata": {},
"source": [
"## Helper Functions\n",
"\n",
"Let's set up the functions to download our large context and prepare messages:"
]
},
{
"cell_type": "code",
"execution_count": 84,
"id": "08b7a5a2",
"metadata": {},
"outputs": [],
"source": [
"async def get_sqlite_sources() -> dict[str, str]:\n",
" print(\"Downloading SQLite source files...\")\n",
"\n",
" source_files = {}\n",
" start_time = time.time()\n",
"\n",
" async with httpx.AsyncClient(timeout=30.0) as client:\n",
" tasks = []\n",
"\n",
" async def download_file(filename: str, url: str) -> tuple[str, str]:\n",
" response = await client.get(url, follow_redirects=True)\n",
" response.raise_for_status()\n",
" print(f\"Successfully downloaded {filename}\")\n",
" return filename, response.text\n",
"\n",
" for filename, url in SQLITE_SOURCES.items():\n",
" tasks.append(download_file(filename, url))\n",
"\n",
" results = await asyncio.gather(*tasks)\n",
" source_files = dict(results)\n",
"\n",
" duration = time.time() - start_time\n",
" print(f\"Downloaded {len(source_files)} files in {duration:.2f} seconds\")\n",
" return source_files\n",
"\n",
"\n",
"async def create_initial_message():\n",
" sources = await get_sqlite_sources()\n",
" # Prepare the initial message with the source code as context.\n",
" # A timestamp is included to prevent cache sharing across different runs.\n",
" initial_message = {\n",
" \"role\": \"user\",\n",
" \"content\": [\n",
" {\n",
" \"type\": \"text\",\n",
" \"text\": f\"\"\"\n",
"Current time: {datetime.datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")}\n",
"\n",
"Source to Analyze:\n",
"\n",
"btree.h:\n",
"```c\n",
"{sources[\"btree.h\"]}\n",
"```\n",
"\n",
"btree.c:\n",
"```c\n",
"{sources[\"btree.c\"]}\n",
"```\"\"\",\n",
" \"cache_control\": {\"type\": \"ephemeral\"},\n",
" }\n",
" ],\n",
" }\n",
" return initial_message\n",
"\n",
"\n",
"async def sample_one_token(client: AsyncAnthropic, messages: list):\n",
" \"\"\"Send a single-token request to warm up the cache\"\"\"\n",
" args = copy.deepcopy(DEFAULT_CLIENT_ARGS)\n",
" args[\"max_tokens\"] = 1\n",
" await client.messages.create(\n",
" messages=messages,\n",
" model=MODEL,\n",
" **args,\n",
" )\n",
"\n",
"\n",
"def print_query_statistics(response, query_type: str) -> None:\n",
" print(f\"\\n{query_type} query statistics:\")\n",
" print(f\"\\tInput tokens: {response.usage.input_tokens}\")\n",
" print(f\"\\tOutput tokens: {response.usage.output_tokens}\")\n",
" print(\n",
" f\"\\tCache read input tokens: {getattr(response.usage, 'cache_read_input_tokens', '---')}\"\n",
" )\n",
" print(\n",
" f\"\\tCache creation input tokens: {getattr(response.usage, 'cache_creation_input_tokens', '---')}\"\n",
" )"
]
},
{
"cell_type": "markdown",
"id": "7b183621",
"metadata": {},
"source": [
"## Example 1: Standard Prompt Caching (Without Speculative Caching)\n",
"\n",
"First, let's see how standard prompt caching works. The user types their question, then we send the entire context + question to the API:"
]
},
{
"cell_type": "code",
"execution_count": 85,
"id": "b16e9048",
"metadata": {},
"outputs": [],
"source": [
"async def standard_prompt_caching_demo():\n",
" client = AsyncAnthropic()\n",
" \n",
" # Prepare the large context\n",
" initial_message = await create_initial_message()\n",
" \n",
" # Simulate user typing time (in real app, this would be actual user input)\n",
" print(\"User is typing their question...\")\n",
" await asyncio.sleep(3) # Simulate 3 seconds of typing\n",
" user_question = \"What is the purpose of the BtShared structure?\"\n",
" print(f\"User submitted: {user_question}\")\n",
" \n",
" # Now send the full request (context + question)\n",
" full_message = copy.deepcopy(initial_message)\n",
" full_message[\"content\"].append(\n",
" {\"type\": \"text\", \"text\": f\"Answer the user's question: {user_question}\"}\n",
" )\n",
" \n",
" print(\"\\nSending request to API...\")\n",
" start_time = time.time()\n",
" \n",
" # Measure time to first token\n",
" first_token_time = None\n",
" async with client.messages.stream(\n",
" messages=[full_message],\n",
" model=MODEL,\n",
" **DEFAULT_CLIENT_ARGS,\n",
" ) as stream:\n",
" async for text in stream.text_stream:\n",
" if first_token_time is None and text.strip():\n",
" first_token_time = time.time() - start_time\n",
" print(f\"\\n🕐 Time to first token: {first_token_time:.2f} seconds\")\n",
" break\n",
" \n",
" # Get the full response\n",
" response = await stream.get_final_message()\n",
" \n",
" total_time = time.time() - start_time\n",
" print(f\"Total response time: {total_time:.2f} seconds\")\n",
" print_query_statistics(response, \"Standard Caching\")\n",
" \n",
" return first_token_time, total_time"
]
},
{
"cell_type": "code",
"execution_count": 86,
"id": "6dfa2ad1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading SQLite source files...\n",
"Successfully downloaded btree.h\n",
"Successfully downloaded btree.c\n",
"Downloaded 2 files in 0.30 seconds\n",
"User is typing their question...\n",
"User submitted: What is the purpose of the BtShared structure?\n",
"\n",
"Sending request to API...\n",
"\n",
"🕐 Time to first token: 20.87 seconds\n",
"Total response time: 28.32 seconds\n",
"\n",
"Standard Caching query statistics:\n",
"\tInput tokens: 22\n",
"\tOutput tokens: 362\n",
"\tCache read input tokens: 0\n",
"\tCache creation input tokens: 151629\n"
]
}
],
"source": [
"# Run the standard demo\n",
"standard_ttft, standard_total = await standard_prompt_caching_demo()"
]
},
{
"cell_type": "markdown",
"id": "57d0c669",
"metadata": {},
"source": [
"## Example 2: Speculative Prompt Caching\n",
"\n",
"Now let's see how speculative prompt caching improves TTFT by warming the cache while the user is typing:"
]
},
{
"cell_type": "code",
"execution_count": 87,
"id": "c4ca7484",
"metadata": {},
"outputs": [],
"source": [
"async def speculative_prompt_caching_demo():\n",
" client = AsyncAnthropic()\n",
" \n",
" # The user has a large amount of context they want to interact with,\n",
" # in this case it's the SQLite B-tree implementation (~150k tokens).\n",
" initial_message = await create_initial_message()\n",
" \n",
" # Start speculative caching while user is typing\n",
" print(\"User is typing their question...\")\n",
" print(\"🔥 Starting cache warming in background...\")\n",
" \n",
" # While the user types their question, we sample a single token from the\n",
" # context they are about to interact with, with explicit prompt caching\n",
" # turned on, to warm up the cache.\n",
" cache_task = asyncio.create_task(sample_one_token(client, [initial_message]))\n",
" \n",
" # Simulate user typing time\n",
" await asyncio.sleep(3) # Simulate 3 seconds of typing\n",
" user_question = \"What is the purpose of the BtShared structure?\"\n",
" print(f\"User submitted: {user_question}\")\n",
" \n",
" # Ensure cache warming is complete\n",
" await cache_task\n",
" print(\"✅ Cache warming completed!\")\n",
" \n",
" # Prepare the message for the cached query. We reuse exactly the same\n",
" # initial message that was cached to ensure a cache hit.\n",
" cached_message = copy.deepcopy(initial_message)\n",
" cached_message[\"content\"].append(\n",
" {\"type\": \"text\", \"text\": f\"Answer the user's question: {user_question}\"}\n",
" )\n",
" \n",
" print(\"\\nSending request to API (with warm cache)...\")\n",
" start_time = time.time()\n",
" \n",
" # Measure time to first token\n",
" first_token_time = None\n",
" async with client.messages.stream(\n",
" messages=[cached_message],\n",
" model=MODEL,\n",
" **DEFAULT_CLIENT_ARGS,\n",
" ) as stream:\n",
" async for text in stream.text_stream:\n",
" if first_token_time is None and text.strip():\n",
" first_token_time = time.time() - start_time\n",
" print(f\"\\n🚀 Time to first token: {first_token_time:.2f} seconds\")\n",
" break\n",
" \n",
" # Get the full response\n",
" response = await stream.get_final_message()\n",
" \n",
" total_time = time.time() - start_time\n",
" print(f\"Total response time: {total_time:.2f} seconds\")\n",
" print_query_statistics(response, \"Speculative Caching\")\n",
" \n",
" return first_token_time, total_time"
]
},
{
"cell_type": "code",
"execution_count": 88,
"id": "6c960975",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading SQLite source files...\n",
"Successfully downloaded btree.h\n",
"Successfully downloaded btree.c\n",
"Downloaded 2 files in 0.36 seconds\n",
"User is typing their question...\n",
"🔥 Starting cache warming in background...\n",
"User submitted: What is the purpose of the BtShared structure?\n",
"✅ Cache warming completed!\n",
"\n",
"Sending request to API (with warm cache)...\n",
"\n",
"🚀 Time to first token: 1.94 seconds\n",
"Total response time: 8.40 seconds\n",
"\n",
"Speculative Caching query statistics:\n",
"\tInput tokens: 22\n",
"\tOutput tokens: 330\n",
"\tCache read input tokens: 151629\n",
"\tCache creation input tokens: 0\n"
]
}
],
"source": [
"# Run the speculative caching demo\n",
"speculative_ttft, speculative_total = await speculative_prompt_caching_demo()"
]
},
{
"cell_type": "markdown",
"id": "deed1898",
"metadata": {},
"source": [
"## Performance Comparison\n",
"\n",
"Let's compare the results to see the benefit of speculative caching:"
]
},
{
"cell_type": "code",
"execution_count": 89,
"id": "66b3009e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"============================================================\n",
"PERFORMANCE COMPARISON\n",
"============================================================\n",
"\n",
"Standard Prompt Caching:\n",
" Time to First Token: 20.87 seconds\n",
" Total Response Time: 28.32 seconds\n",
"\n",
"Speculative Prompt Caching:\n",
" Time to First Token: 1.94 seconds\n",
" Total Response Time: 8.40 seconds\n",
"\n",
"🎯 IMPROVEMENTS:\n",
" TTFT Improvement: 90.7% (18.93s faster)\n",
" Total Time Improvement: 70.4% (19.92s faster)\n"
]
}
],
"source": [
"print(\"=\" * 60)\n",
"print(\"PERFORMANCE COMPARISON\")\n",
"print(\"=\" * 60)\n",
"\n",
"print(\"\\nStandard Prompt Caching:\")\n",
"print(f\" Time to First Token: {standard_ttft:.2f} seconds\")\n",
"print(f\" Total Response Time: {standard_total:.2f} seconds\")\n",
"\n",
"print(\"\\nSpeculative Prompt Caching:\")\n",
"print(f\" Time to First Token: {speculative_ttft:.2f} seconds\")\n",
"print(f\" Total Response Time: {speculative_total:.2f} seconds\")\n",
"\n",
"ttft_improvement = (standard_ttft - speculative_ttft) / standard_ttft * 100\n",
"total_improvement = (standard_total - speculative_total) / standard_total * 100\n",
"\n",
"print(\"\\n🎯 IMPROVEMENTS:\")\n",
"print(f\" TTFT Improvement: {ttft_improvement:.1f}% ({standard_ttft - speculative_ttft:.2f}s faster)\")\n",
"print(f\" Total Time Improvement: {total_improvement:.1f}% ({standard_total - speculative_total:.2f}s faster)\")"
]
},
{
"cell_type": "markdown",
"id": "d93c9bdd",
"metadata": {},
"source": [
"## Key Takeaways\n",
"\n",
"1. **Speculative caching dramatically reduces TTFT** by warming the cache while users are typing\n",
"2. **The pattern is most effective** with large contexts that are reused across queries (prompts shorter than the model's minimum cacheable length, roughly 1,024 tokens, are not cached at all)\n",
"3. **Implementation is simple** - just send a 1-token request while the user is typing\n",
"4. **Cache warming happens in parallel** with user input, effectively \"hiding\" the cache creation time\n",
"\n",
"## Best Practices\n",
"\n",
"- Start cache warming as early as possible (e.g., when a user focuses an input field; see the sketch below)\n",
"- Use exactly the same context for warming and actual requests to ensure cache hits\n",
"- Monitor `cache_read_input_tokens` to verify cache hits\n",
"- Add timestamps to prevent unwanted cache sharing across sessions\n",
"- Re-warm the cache if the user may idle past its lifetime (ephemeral cache entries expire roughly five minutes after their last use)\n",
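"\n",
"A minimal sketch of warming on focus and then verifying the hit; the `on_input_focus` UI hook is hypothetical, while `client`, `initial_message`, `sample_one_token`, and `response` are as defined earlier in this notebook:\n",
"\n",
"```python\n",
"# Hypothetical sketch: warm the cache when the user focuses the input field,\n",
"# guarding against duplicate warm-up requests.\n",
"warm_task = None\n",
"\n",
"def on_input_focus():\n",
"    global warm_task\n",
"    if warm_task is None or warm_task.done():\n",
"        warm_task = asyncio.create_task(sample_one_token(client, [initial_message]))\n",
"\n",
"# After the real request completes, verify the warm-up actually paid off:\n",
"if getattr(response.usage, \"cache_read_input_tokens\", 0) == 0:\n",
"    print(\"Warning: no cache hit; check that the warmed prefix matches exactly\")\n",
"```"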
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}