Speculative prompt caching

This commit introduces a new cookbook demonstrating how to use the
speculative prompt caching pattern to reduce time-to-first-token (TTFT).
Eric Harmeling
2025-05-15 17:10:33 -04:00
parent eaacd9cddf
commit 25232bb014


@@ -0,0 +1,485 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "5c4d70b6",
"metadata": {},
"source": [
"# Speculative Prompt Caching\n",
"\n",
"This cookbook demonstrates \"Speculative Prompt Caching\" - a pattern that reduces time-to-first-token (TTFT) by warming up the cache while users are still formulating their queries.\n",
"\n",
"**Without Speculative Caching:**\n",
"1. User types their question (3 seconds)\n",
"2. User submits question\n",
"3. API loads context into cache AND generates response\n",
"\n",
"**With Speculative Caching:**\n",
"1. User starts typing (cache warming begins immediately)\n",
"2. User continues typing (cache warming completes in background)\n",
"3. User submits question\n",
"4. API uses warm cache to generate response"
]
},
{
"cell_type": "markdown",
"id": "66720f9a",
"metadata": {},
"source": [
"## Setup\n",
"\n",
"First, let's install the required packages:"
]
},
{
"cell_type": "code",
"execution_count": 82,
"id": "9a0adb31",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"%pip install anthropic httpx --quiet"
]
},
{
"cell_type": "code",
"execution_count": 83,
"id": "c4a98035",
"metadata": {},
"outputs": [],
"source": [
"import copy\n",
"import datetime\n",
"import time\n",
"import asyncio\n",
"import httpx\n",
"from anthropic import AsyncAnthropic\n",
"\n",
"# Configuration constants\n",
"MODEL = \"claude-3-5-sonnet-20241022\"\n",
"SQLITE_SOURCES = {\n",
" \"btree.h\": \"https://sqlite.org/src/raw/18e5e7b2124c23426a283523e5f31a4bff029131b795bb82391f9d2f3136fc50?at=btree.h\",\n",
" \"btree.c\": \"https://sqlite.org/src/raw/63ca6b647342e8cef643863cd0962a542f133e1069460725ba4461dcda92b03c?at=btree.c\",\n",
"}\n",
"DEFAULT_CLIENT_ARGS = {\n",
" \"system\": \"You are an expert systems programmer helping analyze database internals.\",\n",
" \"max_tokens\": 4096,\n",
" \"temperature\": 0,\n",
" \"extra_headers\": {\"anthropic-beta\": \"prompt-caching-2024-07-31\"},\n",
"}"
]
},
{
"cell_type": "markdown",
"id": "49d56e2f",
"metadata": {},
"source": [
"## Helper Functions\n",
"\n",
"Let's set up the functions to download our large context and prepare messages:"
]
},
{
"cell_type": "code",
"execution_count": 84,
"id": "08b7a5a2",
"metadata": {},
"outputs": [],
"source": [
"async def get_sqlite_sources() -> dict[str, str]:\n",
" print(\"Downloading SQLite source files...\")\n",
"\n",
" source_files = {}\n",
" start_time = time.time()\n",
"\n",
" async with httpx.AsyncClient(timeout=30.0) as client:\n",
" tasks = []\n",
"\n",
" async def download_file(filename: str, url: str) -> tuple[str, str]:\n",
" response = await client.get(url, follow_redirects=True)\n",
" response.raise_for_status()\n",
" print(f\"Successfully downloaded {filename}\")\n",
" return filename, response.text\n",
"\n",
" for filename, url in SQLITE_SOURCES.items():\n",
" tasks.append(download_file(filename, url))\n",
"\n",
" results = await asyncio.gather(*tasks)\n",
" source_files = dict(results)\n",
"\n",
" duration = time.time() - start_time\n",
" print(f\"Downloaded {len(source_files)} files in {duration:.2f} seconds\")\n",
" return source_files\n",
"\n",
"\n",
"async def create_initial_message():\n",
" sources = await get_sqlite_sources()\n",
" # Prepare the initial message with the source code as context.\n",
" # A Timestamp is included to prevent cache sharing across different runs.\n",
" initial_message = {\n",
" \"role\": \"user\",\n",
" \"content\": [\n",
" {\n",
" \"type\": \"text\",\n",
" \"text\": f\"\"\"\n",
"Current time: {datetime.datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")}\n",
"\n",
"Source to Analyze:\n",
"\n",
"btree.h:\n",
"```c\n",
"{sources[\"btree.h\"]}\n",
"```\n",
"\n",
"btree.c:\n",
"```c\n",
"{sources[\"btree.c\"]}\n",
"```\"\"\",\n",
" \"cache_control\": {\"type\": \"ephemeral\"},\n",
" }\n",
" ],\n",
" }\n",
" return initial_message\n",
"\n",
"\n",
"async def sample_one_token(client: AsyncAnthropic, messages: list):\n",
" \"\"\"Send a single-token request to warm up the cache\"\"\"\n",
" args = copy.deepcopy(DEFAULT_CLIENT_ARGS)\n",
" args[\"max_tokens\"] = 1\n",
" await client.messages.create(\n",
" messages=messages,\n",
" model=MODEL,\n",
" **args,\n",
" )\n",
"\n",
"\n",
"def print_query_statistics(response, query_type: str) -> None:\n",
" print(f\"\\n{query_type} query statistics:\")\n",
" print(f\"\\tInput tokens: {response.usage.input_tokens}\")\n",
" print(f\"\\tOutput tokens: {response.usage.output_tokens}\")\n",
" print(\n",
" f\"\\tCache read input tokens: {getattr(response.usage, 'cache_read_input_tokens', '---')}\"\n",
" )\n",
" print(\n",
" f\"\\tCache creation input tokens: {getattr(response.usage, 'cache_creation_input_tokens', '---')}\"\n",
" )"
]
},
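{
"cell_type": "markdown",
"id": "a1f3c7e2",
"metadata": {},
"source": [
"Optionally, you can verify that a warm-up call actually wrote a cache entry. The cell below is a hypothetical helper (not used by the demos that follow, and not executed here) that behaves like `sample_one_token` but returns the response so its `usage` fields can be inspected:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d8b2f4c6",
"metadata": {},
"outputs": [],
"source": [
"async def sample_one_token_with_stats(client: AsyncAnthropic, messages: list):\n",
"    # Hypothetical variant of sample_one_token that returns the warm-up\n",
"    # response so the cache usage it reports can be inspected.\n",
"    args = copy.deepcopy(DEFAULT_CLIENT_ARGS)\n",
"    args[\"max_tokens\"] = 1\n",
"    response = await client.messages.create(\n",
"        messages=messages,\n",
"        model=MODEL,\n",
"        **args,\n",
"    )\n",
"    # On a cold cache, cache_creation_input_tokens should roughly equal the\n",
"    # size of the cached context; on an already-warm cache the tokens show\n",
"    # up under cache_read_input_tokens instead.\n",
"    print(\n",
"        f\"Cache creation: {getattr(response.usage, 'cache_creation_input_tokens', 0)} tokens, \"\n",
"        f\"cache read: {getattr(response.usage, 'cache_read_input_tokens', 0)} tokens\"\n",
"    )\n",
"    return response"
]
},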
{
"cell_type": "markdown",
"id": "7b183621",
"metadata": {},
"source": [
"## Example 1: Standard Prompt Caching (Without Speculative Caching)\n",
"\n",
"First, let's see how standard prompt caching works. The user types their question, then we send the entire context + question to the API:"
]
},
{
"cell_type": "code",
"execution_count": 85,
"id": "b16e9048",
"metadata": {},
"outputs": [],
"source": [
"async def standard_prompt_caching_demo():\n",
" client = AsyncAnthropic()\n",
" \n",
" # Prepare the large context\n",
" initial_message = await create_initial_message()\n",
" \n",
" # Simulate user typing time (in real app, this would be actual user input)\n",
" print(\"User is typing their question...\")\n",
" await asyncio.sleep(3) # Simulate 3 seconds of typing\n",
" user_question = \"What is the purpose of the BtShared structure?\"\n",
" print(f\"User submitted: {user_question}\")\n",
" \n",
" # Now send the full request (context + question)\n",
" full_message = copy.deepcopy(initial_message)\n",
" full_message[\"content\"].append(\n",
" {\"type\": \"text\", \"text\": f\"Answer the user's question: {user_question}\"}\n",
" )\n",
" \n",
" print(\"\\nSending request to API...\")\n",
" start_time = time.time()\n",
" \n",
" # Measure time to first token\n",
" first_token_time = None\n",
" async with client.messages.stream(\n",
" messages=[full_message],\n",
" model=MODEL,\n",
" **DEFAULT_CLIENT_ARGS,\n",
" ) as stream:\n",
" async for text in stream.text_stream:\n",
" if first_token_time is None and text.strip():\n",
" first_token_time = time.time() - start_time\n",
" print(f\"\\n🕐 Time to first token: {first_token_time:.2f} seconds\")\n",
" break\n",
" \n",
" # Get the full response\n",
" response = await stream.get_final_message()\n",
" \n",
" total_time = time.time() - start_time\n",
" print(f\"Total response time: {total_time:.2f} seconds\")\n",
" print_query_statistics(response, \"Standard Caching\")\n",
" \n",
" return first_token_time, total_time"
]
},
{
"cell_type": "code",
"execution_count": 86,
"id": "6dfa2ad1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading SQLite source files...\n",
"Successfully downloaded btree.h\n",
"Successfully downloaded btree.c\n",
"Downloaded 2 files in 0.30 seconds\n",
"User is typing their question...\n",
"User submitted: What is the purpose of the BtShared structure?\n",
"\n",
"Sending request to API...\n",
"\n",
"🕐 Time to first token: 20.87 seconds\n",
"Total response time: 28.32 seconds\n",
"\n",
"Standard Caching query statistics:\n",
"\tInput tokens: 22\n",
"\tOutput tokens: 362\n",
"\tCache read input tokens: 0\n",
"\tCache creation input tokens: 151629\n"
]
}
],
"source": [
"# Run the standard demo\n",
"standard_ttft, standard_total = await standard_prompt_caching_demo()"
]
},
{
"cell_type": "markdown",
"id": "57d0c669",
"metadata": {},
"source": [
"## Example 2: Speculative Prompt Caching\n",
"\n",
"Now let's see how speculative prompt caching improves TTFT by warming the cache while the user is typing:"
]
},
{
"cell_type": "code",
"execution_count": 87,
"id": "c4ca7484",
"metadata": {},
"outputs": [],
"source": [
"async def speculative_prompt_caching_demo():\n",
" client = AsyncAnthropic()\n",
" \n",
" # The user has a large amount of context they want to interact with,\n",
" # in this case it's the sqlite b-tree implementation (~150k tokens).\n",
" initial_message = await create_initial_message()\n",
" \n",
" # Start speculative caching while user is typing\n",
" print(\"User is typing their question...\")\n",
" print(\"🔥 Starting cache warming in background...\")\n",
" \n",
" # While the user is typing out their question, we sample a single token\n",
" # from the context the user is going to be interacting with with explicit\n",
" # prompt caching turned on to warm up the cache.\n",
" cache_task = asyncio.create_task(sample_one_token(client, [initial_message]))\n",
" \n",
" # Simulate user typing time\n",
" await asyncio.sleep(3) # Simulate 3 seconds of typing\n",
" user_question = \"What is the purpose of the BtShared structure?\"\n",
" print(f\"User submitted: {user_question}\")\n",
" \n",
" # Ensure cache warming is complete\n",
" await cache_task\n",
" print(\"✅ Cache warming completed!\")\n",
" \n",
" # Prepare messages for cached query. We make sure we\n",
" # reuse the same initial message as was cached to ensure we have a cache hit.\n",
" cached_message = copy.deepcopy(initial_message)\n",
" cached_message[\"content\"].append(\n",
" {\"type\": \"text\", \"text\": f\"Answer the user's question: {user_question}\"}\n",
" )\n",
" \n",
" print(\"\\nSending request to API (with warm cache)...\")\n",
" start_time = time.time()\n",
" \n",
" # Measure time to first token\n",
" first_token_time = None\n",
" async with client.messages.stream(\n",
" messages=[cached_message],\n",
" model=MODEL,\n",
" **DEFAULT_CLIENT_ARGS,\n",
" ) as stream:\n",
" async for text in stream.text_stream:\n",
" if first_token_time is None and text.strip():\n",
" first_token_time = time.time() - start_time\n",
" print(f\"\\n🚀 Time to first token: {first_token_time:.2f} seconds\")\n",
" break\n",
" \n",
" # Get the full response\n",
" response = await stream.get_final_message()\n",
" \n",
" total_time = time.time() - start_time\n",
" print(f\"Total response time: {total_time:.2f} seconds\")\n",
" print_query_statistics(response, \"Speculative Caching\")\n",
" \n",
" return first_token_time, total_time"
]
},
{
"cell_type": "code",
"execution_count": 88,
"id": "6c960975",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading SQLite source files...\n",
"Successfully downloaded btree.h\n",
"Successfully downloaded btree.c\n",
"Downloaded 2 files in 0.36 seconds\n",
"User is typing their question...\n",
"🔥 Starting cache warming in background...\n",
"User submitted: What is the purpose of the BtShared structure?\n",
"✅ Cache warming completed!\n",
"\n",
"Sending request to API (with warm cache)...\n",
"\n",
"🚀 Time to first token: 1.94 seconds\n",
"Total response time: 8.40 seconds\n",
"\n",
"Speculative Caching query statistics:\n",
"\tInput tokens: 22\n",
"\tOutput tokens: 330\n",
"\tCache read input tokens: 151629\n",
"\tCache creation input tokens: 0\n"
]
}
],
"source": [
"# Run the speculative caching demo \n",
"speculative_ttft, speculative_total = await speculative_prompt_caching_demo()"
]
},
{
"cell_type": "markdown",
"id": "deed1898",
"metadata": {},
"source": [
"## Performance Comparison\n",
"\n",
"Let's compare the results to see the benefit of speculative caching:"
]
},
{
"cell_type": "code",
"execution_count": 89,
"id": "66b3009e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"============================================================\n",
"PERFORMANCE COMPARISON\n",
"============================================================\n",
"\n",
"Standard Prompt Caching:\n",
" Time to First Token: 20.87 seconds\n",
" Total Response Time: 28.32 seconds\n",
"\n",
"Speculative Prompt Caching:\n",
" Time to First Token: 1.94 seconds\n",
" Total Response Time: 8.40 seconds\n",
"\n",
"🎯 IMPROVEMENTS:\n",
" TTFT Improvement: 90.7% (18.93s faster)\n",
" Total Time Improvement: 70.4% (19.92s faster)\n"
]
}
],
"source": [
"print(\"=\" * 60)\n",
"print(\"PERFORMANCE COMPARISON\")\n",
"print(\"=\" * 60)\n",
"\n",
"print(f\"\\nStandard Prompt Caching:\")\n",
"print(f\" Time to First Token: {standard_ttft:.2f} seconds\")\n",
"print(f\" Total Response Time: {standard_total:.2f} seconds\")\n",
"\n",
"print(f\"\\nSpeculative Prompt Caching:\")\n",
"print(f\" Time to First Token: {speculative_ttft:.2f} seconds\")\n",
"print(f\" Total Response Time: {speculative_total:.2f} seconds\")\n",
"\n",
"ttft_improvement = (standard_ttft - speculative_ttft) / standard_ttft * 100\n",
"total_improvement = (standard_total - speculative_total) / standard_total * 100\n",
"\n",
"print(f\"\\n🎯 IMPROVEMENTS:\")\n",
"print(f\" TTFT Improvement: {ttft_improvement:.1f}% ({standard_ttft - speculative_ttft:.2f}s faster)\")\n",
"print(f\" Total Time Improvement: {total_improvement:.1f}% ({standard_total - speculative_total:.2f}s faster)\")"
]
},
{
"cell_type": "markdown",
"id": "d93c9bdd",
"metadata": {},
"source": [
"## Key Takeaways\n",
"\n",
"1. **Speculative caching dramatically reduces TTFT** by warming the cache while users are typing\n",
"2. **The pattern is most effective** with large contexts (>1000 tokens) that are reused across queries\n",
"3. **Implementation is simple** - just send a 1-token request while the user is typing\n",
"4. **Cache warming happens in parallel** with user input, effectively \"hiding\" the cache creation time\n",
"\n",
"## Best Practices\n",
"\n",
"- Start cache warming as early as possible (e.g., when a user focuses an input field)\n",
"- Use exactly the same context for warming and actual requests to ensure cache hits\n",
"- Monitor `cache_read_input_tokens` to verify cache hits\n",
"- Add timestamps to prevent unwanted cache sharing across sessions"
]
}
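{
"cell_type": "markdown",
"id": "e7c1a9b3",
"metadata": {},
"source": [
"To illustrate the first bullet, here is a minimal sketch of driving cache warming from a UI event such as the input field gaining focus. The class name, the `on_focus`/`on_blur` hooks, and the debounce value are hypothetical and would be wired to your own UI framework; the sketch only assumes the `sample_one_token` helper and the `initial_message` structure defined earlier, and it is not executed here. Keep in mind that an ephemeral cache entry lasts roughly five minutes (refreshed each time it is read), so warming pays off most when the real request follows shortly after."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f2d6b8a4",
"metadata": {},
"outputs": [],
"source": [
"# Minimal sketch (hypothetical, not executed here): warm the prompt cache as\n",
"# soon as the user focuses the input field, with a short debounce so that a\n",
"# stray focus/blur does not fire a redundant warm-up request.\n",
"\n",
"class SpeculativeCacheWarmer:\n",
"    def __init__(self, client: AsyncAnthropic, initial_message: dict, debounce_s: float = 0.3):\n",
"        self.client = client\n",
"        self.initial_message = initial_message\n",
"        self.debounce_s = debounce_s\n",
"        self._task: asyncio.Task | None = None\n",
"\n",
"    def on_focus(self) -> None:\n",
"        # Called by the UI when the user focuses the input field.\n",
"        if self._task is None:\n",
"            self._task = asyncio.create_task(self._warm())\n",
"\n",
"    def on_blur(self) -> None:\n",
"        # Called by the UI when the user leaves the field; cancel a pending warm-up.\n",
"        if self._task is not None and not self._task.done():\n",
"            self._task.cancel()\n",
"            self._task = None\n",
"\n",
"    async def _warm(self) -> None:\n",
"        await asyncio.sleep(self.debounce_s)  # debounce rapid focus/blur\n",
"        await sample_one_token(self.client, [self.initial_message])\n",
"\n",
"    async def wait_until_warm(self) -> None:\n",
"        # Call right before sending the real request. A cancelled or failed\n",
"        # warm-up is not fatal; the request simply pays the cache-write cost.\n",
"        if self._task is not None:\n",
"            try:\n",
"                await self._task\n",
"            except (asyncio.CancelledError, Exception):\n",
"                pass\n",
"\n",
"# Usage sketch:\n",
"#   warmer = SpeculativeCacheWarmer(AsyncAnthropic(), initial_message)\n",
"#   warmer.on_focus()              # user clicks into the text box\n",
"#   ...user types their question...\n",
"#   await warmer.wait_until_warm() # ensure the cache entry exists\n",
"#   ...send the real request, reusing the same initial_message"
]
}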
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}