Add memory & context management cookbook

Interactive notebook demonstrating Claude's memory tool and context editing capabilities with code review examples.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Alex Notov
2025-09-26 20:11:13 -06:00
parent 4b36a1e1f6
commit 4d3ed1f75b
14 changed files with 2152 additions and 999 deletions

6
.gitignore vendored
View File

@@ -149,4 +149,8 @@ lychee-report.md
# Notebook validation
.notebook_validation_state.json
.notebook_validation_checkpoint.json
validation_report_*.md
validation_report_*.md
# Memory tool demo artifacts
tool_use/demo_memory/
tool_use/memory_storage/
tool_use/.env

14
tool_use/.env.example Normal file
View File

@@ -0,0 +1,14 @@
# Anthropic API Configuration
# Copy this file to .env and fill in your actual values
# Your Anthropic API key from https://console.anthropic.com/
ANTHROPIC_API_KEY=your_api_key_here
# Model name - Use a model that supports memory_20250818 tool
# Supported models (as of launch):
# - claude-sonnet-4-20250514
# - claude-opus-4-20250514
# - claude-opus-4-1-20250805
# - claude-sonnet-4-5-20250929
ANTHROPIC_MODEL=claude-sonnet-4-20250514

File diff suppressed because it is too large Load Diff

5
tool_use/memory_demo/.gitignore vendored Normal file
View File

@@ -0,0 +1,5 @@
# Ignore demo-generated directories and files
demo_memory/
memory_storage/
__pycache__/
*.pyc

View File

@@ -0,0 +1,340 @@
"""
Code Review Assistant Demo - Three-session demonstration.
This demo showcases:
1. Session 1: Claude learns debugging patterns
2. Session 2: Claude applies learned patterns (faster!)
3. Session 3: Long session with context editing
Requires:
- .env file with ANTHROPIC_API_KEY and ANTHROPIC_MODEL
- memory_tool.py in the same directory
"""
import os
from typing import Any, Dict, List, Optional
from anthropic import Anthropic
from dotenv import load_dotenv
import sys
from pathlib import Path
# Add parent directory to path to import memory_tool
sys.path.insert(0, str(Path(__file__).parent.parent))
from memory_tool import MemoryToolHandler
# Load environment variables from .env (see .env.example for the template).
load_dotenv()
API_KEY = os.getenv("ANTHROPIC_API_KEY")
MODEL = os.getenv("ANTHROPIC_MODEL")
# Fail fast with actionable messages if the demo is not configured.
if not API_KEY:
    raise ValueError(
        "ANTHROPIC_API_KEY not found. Copy .env.example to .env and add your API key."
    )
if not MODEL:
    raise ValueError(
        "ANTHROPIC_MODEL not found. Copy .env.example to .env and set the model."
    )
# Context management configuration
# Once the conversation exceeds 30k input tokens, the API clears old tool
# uses (keeping the 3 most recent) until at least 5k tokens are reclaimed.
CONTEXT_MANAGEMENT = {
    "edits": [
        {
            "type": "clear_tool_uses_20250919",
            "trigger": {"type": "input_tokens", "value": 30000},
            "keep": {"type": "tool_uses", "value": 3},
            "clear_at_least": {"type": "input_tokens", "value": 5000},
        }
    ]
}
class CodeReviewAssistant:
    """
    Code review assistant with memory and context editing capabilities.

    This assistant:
    - Checks memory for debugging patterns before reviewing code
    - Stores learned patterns for future sessions
    - Automatically clears old tool results when context grows large
    """

    def __init__(self, memory_storage_path: str = "./memory_storage"):
        """
        Initialize the code review assistant.

        Args:
            memory_storage_path: Path for memory storage (persists on disk
                across sessions, unlike the in-process message history)
        """
        self.client = Anthropic(api_key=API_KEY)
        self.memory_handler = MemoryToolHandler(base_path=memory_storage_path)
        # Conversation history for the current session only.
        self.messages: List[Dict[str, Any]] = []

    def _create_system_prompt(self) -> str:
        """Create system prompt with memory instructions."""
        return """You are an expert code reviewer focused on finding bugs and suggesting improvements.
MEMORY PROTOCOL:
1. ALWAYS check your /memories directory FIRST using the memory tool
2. Look for relevant debugging patterns or insights from previous reviews
3. When you find a bug or pattern, update your memory with what you learned
4. Keep your memory organized - use descriptive file names and clear content
When reviewing code:
- Identify bugs, security issues, and code quality problems
- Explain the issue clearly
- Provide a corrected version
- Store important patterns in memory for future reference
Remember: Your memory persists across conversations. Use it wisely."""

    def _execute_tool_use(self, tool_use: Any) -> str:
        """Execute a tool use and return its result as text for Claude."""
        if tool_use.name == "memory":
            result = self.memory_handler.execute(**tool_use.input)
            # Handler returns {"success": ...} or {"error": ...}; surface either.
            return result.get("success") or result.get("error", "Unknown error")
        return f"Unknown tool: {tool_use.name}"

    def review_code(
        self, code: str, filename: str, description: str = ""
    ) -> Dict[str, Any]:
        """
        Review code with memory-enhanced analysis.

        Args:
            code: The code to review
            filename: Name of the file being reviewed
            description: Optional description of what to look for

        Returns:
            Dict with review results and metadata:
            - "review": Claude's full review text
            - "input_tokens": input tokens used on the final API call
            - "context_edits": any context-management edits that were applied
        """
        # Construct user message.
        # BUG FIX: interpolate the filename — it was previously dropped,
        # leaving the parameter unused and the prompt without the file name.
        user_message = f"Please review this code from {filename}"
        if description:
            user_message += f"\n\nContext: {description}"
        user_message += f"\n\n```python\n{code}\n```"
        self.messages.append({"role": "user", "content": user_message})
        # Track token usage and context management
        total_input_tokens = 0
        context_edits_applied = []
        # Conversation loop: keep calling the API until Claude stops
        # requesting tool uses.
        turn = 1
        while True:
            print(f" 🔄 Turn {turn}: Calling Claude API...", end="", flush=True)
            response = self.client.beta.messages.create(
                model=MODEL,
                max_tokens=4096,
                system=self._create_system_prompt(),
                messages=self.messages,
                tools=[{"type": "memory_20250818", "name": "memory"}],
                betas=["context-management-2025-06-27"],
                extra_body={"context_management": CONTEXT_MANAGEMENT},
            )
            print("")
            # Last response wins; reported in the summary dict below.
            total_input_tokens = response.usage.input_tokens
            # Record any context-management edits the API applied this turn.
            if hasattr(response, "context_management") and response.context_management:
                applied = response.context_management.get("applied_edits", [])
                if applied:
                    context_edits_applied.extend(applied)
            # Process response content
            assistant_content = []
            tool_results = []
            final_text = []
            for content in response.content:
                if content.type == "text":
                    assistant_content.append({"type": "text", "text": content.text})
                    final_text.append(content.text)
                elif content.type == "tool_use":
                    cmd = content.input.get('command', 'unknown')
                    path = content.input.get('path', '')
                    print(f" 🔧 Memory: {cmd} {path}")
                    # Execute tool locally (client-side memory handler).
                    result = self._execute_tool_use(content)
                    assistant_content.append(
                        {
                            "type": "tool_use",
                            "id": content.id,
                            "name": content.name,
                            "input": content.input,
                        }
                    )
                    tool_results.append(
                        {
                            "type": "tool_result",
                            "tool_use_id": content.id,
                            "content": result,
                        }
                    )
            # Add assistant message to the running conversation.
            self.messages.append({"role": "assistant", "content": assistant_content})
            # If there are tool results, feed them back and continue the loop.
            if tool_results:
                self.messages.append({"role": "user", "content": tool_results})
                turn += 1
            else:
                # No more tool uses, we're done
                print()
                break
        return {
            "review": "\n".join(final_text),
            "input_tokens": total_input_tokens,
            "context_edits": context_edits_applied,
        }

    def start_new_session(self) -> None:
        """Start a new conversation session (memory persists)."""
        self.messages = []
def run_session_1() -> None:
    """Session 1: Learn debugging patterns."""
    divider = "=" * 80
    print(divider)
    print("SESSION 1: Learning from First Code Review")
    print(divider)
    assistant = CodeReviewAssistant()
    # Load the sample under review.
    code = Path("memory_demo/sample_code/web_scraper_v1.py").read_text()
    print("\n📋 Reviewing web_scraper_v1.py...")
    print("\nMulti-threaded web scraper that sometimes loses results.\n")
    result = assistant.review_code(
        code=code,
        filename="web_scraper_v1.py",
        description=(
            "This scraper sometimes returns fewer results than expected. "
            "The count is inconsistent across runs. Can you find the issue?"
        ),
    )
    print("\n🤖 Claude's Review:\n")
    print(result["review"])
    print(f"\n📊 Input tokens used: {result['input_tokens']:,}")
    if result["context_edits"]:
        print(f"\n🧹 Context edits applied: {result['context_edits']}")
    print("\n✅ Session 1 complete - Claude learned debugging patterns!\n")
def run_session_2() -> None:
    """Session 2: Apply learned patterns."""
    divider = "=" * 80
    print(divider)
    print("SESSION 2: Applying Learned Patterns (New Conversation)")
    print(divider)
    # Fresh assistant: the conversation starts over, but the memory
    # directory written during Session 1 carries over on disk.
    assistant = CodeReviewAssistant()
    # Different sample file that contains a similar class of bug.
    code = Path("memory_demo/sample_code/api_client_v1.py").read_text()
    print("\n📋 Reviewing api_client_v1.py...")
    print("\nAsync API client with concurrent requests.\n")
    result = assistant.review_code(
        code=code,
        filename="api_client_v1.py",
        description=(
            "Review this async API client. "
            "It fetches multiple endpoints concurrently. Are there any issues?"
        ),
    )
    print("\n🤖 Claude's Review:\n")
    print(result["review"])
    print(f"\n📊 Input tokens used: {result['input_tokens']:,}")
    print("\n✅ Session 2 complete - Claude applied learned patterns faster!\n")
def run_session_3() -> None:
    """Session 3: Long session with context editing."""
    divider = "=" * 80
    print(divider)
    print("SESSION 3: Long Session with Context Editing")
    print(divider)
    assistant = CodeReviewAssistant()
    # Larger sample with several issues; long enough to trigger the
    # context-management threshold.
    code = Path("memory_demo/sample_code/data_processor_v1.py").read_text()
    print("\n📋 Reviewing data_processor_v1.py...")
    print("\nLarge file with multiple concurrent processing classes.\n")
    result = assistant.review_code(
        code=code,
        filename="data_processor_v1.py",
        description=(
            "This data processor handles files concurrently. "
            "There's also a SharedCache class. Review all components for issues."
        ),
    )
    print("\n🤖 Claude's Review:\n")
    print(result["review"])
    print(f"\n📊 Input tokens used: {result['input_tokens']:,}")
    if result["context_edits"]:
        print("\n🧹 Context Management Applied:")
        for applied in result["context_edits"]:
            print(f" - Type: {applied.get('type')}")
            print(f" - Cleared tool uses: {applied.get('cleared_tool_uses', 0)}")
            print(f" - Tokens saved: {applied.get('cleared_input_tokens', 0):,}")
    print("\n✅ Session 3 complete - Context editing kept conversation manageable!\n")
def main() -> None:
    """Run all three demo sessions."""
    print("\n🚀 Code Review Assistant Demo\n")
    print("This demo shows:")
    print("1. Session 1: Claude learns debugging patterns")
    print("2. Session 2: Claude applies learned patterns (new conversation)")
    print("3. Session 3: Long session with context editing\n")
    # Pause before each session so the user can follow along.
    sessions = (
        ("Session 1", run_session_1),
        ("Session 2", run_session_2),
        ("Session 3", run_session_3),
    )
    for label, session in sessions:
        input(f"Press Enter to start {label}...")
        session()
    divider = "=" * 80
    print(divider)
    print("🎉 Demo Complete!")
    print(divider)
    print("\nKey Takeaways:")
    print("- Memory tool enabled cross-conversation learning")
    print("- Claude got faster at recognizing similar bugs")
    print("- Context editing handled long sessions gracefully")
    print("\n💡 For production GitHub PR reviews, check out:")
    print(" https://github.com/anthropics/claude-code-action\n")


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,194 @@
"""
Helper functions for memory cookbook demos.
This module provides reusable functions for running conversation loops
with Claude, handling tool execution, and managing context.
"""
from typing import Any
from anthropic import Anthropic
from memory_tool import MemoryToolHandler
def execute_tool(tool_use: Any, memory_handler: MemoryToolHandler) -> str:
    """
    Execute a tool use and return the result.

    Args:
        tool_use: The tool use object from Claude's response
        memory_handler: The memory tool handler instance

    Returns:
        str: The result of the tool execution
    """
    # Guard clause: the only tool this demo wires up is "memory".
    if tool_use.name != "memory":
        return f"Unknown tool: {tool_use.name}"
    outcome = memory_handler.execute(**tool_use.input)
    # The handler reports either {"success": ...} or {"error": ...}.
    return outcome.get("success") or outcome.get("error", "Unknown error")
def run_conversation_turn(
    client: Anthropic,
    model: str,
    messages: list[dict[str, Any]],
    memory_handler: MemoryToolHandler,
    system: str,
    context_management: dict[str, Any] | None = None,
    max_tokens: int = 1024,
    verbose: bool = False
) -> tuple[Any, list[dict[str, Any]], list[dict[str, Any]]]:
    """
    Run a single conversation turn, handling tool uses.

    Args:
        client: Anthropic client instance
        model: Model to use
        messages: Current conversation messages
        memory_handler: Memory tool handler instance
        system: System prompt
        context_management: Optional context management config
        max_tokens: Max tokens for response
        verbose: Whether to print tool operations

    Returns:
        Tuple of (response, assistant_content, tool_results)
    """
    # Assemble the request; context management only when configured.
    request: dict[str, Any] = {
        "model": model,
        "max_tokens": max_tokens,
        "system": system,
        "messages": messages,
        "tools": [{"type": "memory_20250818", "name": "memory"}],
        "betas": ["context-management-2025-06-27"],
    }
    if context_management:
        request["extra_body"] = {"context_management": context_management}
    response = client.beta.messages.create(**request)

    assistant_blocks: list[dict[str, Any]] = []
    results: list[dict[str, Any]] = []
    for block in response.content:
        if block.type == "text":
            if verbose:
                print(f"💬 Claude: {block.text}\n")
            assistant_blocks.append({"type": "text", "text": block.text})
        elif block.type == "tool_use":
            if verbose:
                cmd = block.input.get('command')
                path = block.input.get('path', '')
                print(f" 🔧 Memory tool: {cmd} {path}")
            outcome = execute_tool(block, memory_handler)
            if verbose:
                # Truncate long results so the console stays readable.
                preview = outcome if len(outcome) <= 80 else outcome[:80] + "..."
                print(f" ✓ Result: {preview}")
            assistant_blocks.append({
                "type": "tool_use",
                "id": block.id,
                "name": block.name,
                "input": block.input
            })
            results.append({
                "type": "tool_result",
                "tool_use_id": block.id,
                "content": outcome
            })
    return response, assistant_blocks, results
def run_conversation_loop(
    client: Anthropic,
    model: str,
    messages: list[dict[str, Any]],
    memory_handler: MemoryToolHandler,
    system: str,
    context_management: dict[str, Any] | None = None,
    max_tokens: int = 1024,
    max_turns: int = 5,
    verbose: bool = False
) -> Any:
    """
    Run a complete conversation loop until Claude stops using tools.

    Args:
        client: Anthropic client instance
        model: Model to use
        messages: Current conversation messages (will be modified in-place)
        memory_handler: Memory tool handler instance
        system: System prompt
        context_management: Optional context management config
        max_tokens: Max tokens for response
        max_turns: Maximum number of turns to prevent infinite loops
        verbose: Whether to print progress

    Returns:
        The final API response
    """
    response = None
    # Bounded loop: at most max_turns API calls, fewer if Claude finishes early.
    for turn in range(1, max_turns + 1):
        if verbose:
            print(f"\n🔄 Turn {turn}:")
        response, assistant_content, tool_results = run_conversation_turn(
            client=client,
            model=model,
            messages=messages,
            memory_handler=memory_handler,
            system=system,
            context_management=context_management,
            max_tokens=max_tokens,
            verbose=verbose
        )
        messages.append({"role": "assistant", "content": assistant_content})
        if not tool_results:
            # No more tool uses, conversation complete
            break
        messages.append({"role": "user", "content": tool_results})
    return response
def print_context_management_info(response: Any) -> tuple[bool, int]:
    """
    Print context management information from response.

    Args:
        response: API response to analyze

    Returns:
        Tuple of (context_cleared, saved_tokens)
    """
    management = getattr(response, "context_management", None)
    if not management:
        print(" No context management applied")
        return False, 0
    edits = management.get("applied_edits", [])
    if not edits:
        print(" Context below threshold - no clearing triggered")
        return False, 0
    # Report the first applied edit (the demo config defines only one).
    first = edits[0]
    saved = first.get('cleared_input_tokens', 0)
    print(" ✂️ Context editing triggered!")
    print(f" • Cleared {first.get('cleared_tool_uses', 0)} tool uses")
    print(f" • Saved {saved:,} tokens")
    print(f" • After clearing: {response.usage.input_tokens:,} tokens")
    return True, saved

View File

@@ -0,0 +1,99 @@
"""
Async API client with similar concurrency issues.
This demonstrates Claude applying thread-safety patterns to async code.
"""
import asyncio
from typing import List, Dict, Optional
import aiohttp
class AsyncAPIClient:
    """Async API client for fetching data from multiple endpoints.

    NOTE(review): this file is intentional demo material for the memory
    cookbook — the concurrency bugs flagged below are deliberate and
    should NOT be fixed here; Claude is expected to find them in review.
    """

    def __init__(self, base_url: str):
        self.base_url = base_url
        self.responses = []  # BUG: Shared state accessed from multiple coroutines!
        self.error_count = 0  # BUG: Race condition on counter increment!

    async def fetch_endpoint(
        self, session: aiohttp.ClientSession, endpoint: str
    ) -> Dict[str, any]:
        """Fetch a single endpoint."""
        url = f"{self.base_url}/{endpoint}"
        try:
            # 5-second overall timeout per request.
            async with session.get(
                url, timeout=aiohttp.ClientTimeout(total=5)
            ) as response:
                data = await response.json()
                return {
                    "endpoint": endpoint,
                    "status": response.status,
                    "data": data,
                }
        except Exception as e:
            # Failures (timeout, connection error, bad JSON) are reported
            # as a result dict with an "error" key rather than raised.
            return {
                "endpoint": endpoint,
                "error": str(e),
            }

    async def fetch_all(self, endpoints: List[str]) -> List[Dict[str, any]]:
        """
        Fetch multiple endpoints concurrently.

        BUG: Similar to the threading issue, multiple coroutines
        modify self.responses and self.error_count without coordination!
        While Python's GIL prevents some race conditions in threads,
        async code can still have interleaving issues.
        """
        async with aiohttp.ClientSession() as session:
            tasks = [self.fetch_endpoint(session, endpoint) for endpoint in endpoints]
            # Results are consumed in completion order, not submission order.
            for coro in asyncio.as_completed(tasks):
                result = await coro
                # RACE CONDITION: Multiple coroutines modify shared state
                if "error" in result:
                    self.error_count += 1  # Not atomic!
                else:
                    self.responses.append(result)  # Not thread-safe in async context!
        return self.responses

    def get_summary(self) -> Dict[str, any]:
        """Get summary statistics."""
        return {
            "total_responses": len(self.responses),
            "errors": self.error_count,
            "success_rate": (
                len(self.responses) / (len(self.responses) + self.error_count)
                if (len(self.responses) + self.error_count) > 0
                else 0
            ),
        }
async def main():
    """Test the async API client.

    Issues 120 concurrent requests (6 endpoints x 20) against a public
    test API — enough volume to make the race conditions observable.
    """
    client = AsyncAPIClient("https://jsonplaceholder.typicode.com")
    endpoints = [
        "posts/1",
        "posts/2",
        "posts/3",
        "users/1",
        "users/2",
        "invalid/endpoint",  # Will error
    ] * 20  # 120 requests total
    results = await client.fetch_all(endpoints)
    print(f"Expected: ~100 successful responses")
    print(f"Got: {len(results)} responses")
    print(f"Summary: {client.get_summary()}")
    print("\nNote: Counts may be incorrect due to race conditions!")


if __name__ == "__main__":
    asyncio.run(main())

View File

@@ -0,0 +1,115 @@
"""
Cache manager with mutable default argument bug.
This is one of Python's most common gotchas.
"""
from datetime import datetime
from typing import Dict, List, Optional
class CacheManager:
    """Manage cached data with TTL support.

    NOTE(review): intentional demo material — the mutable-default bugs
    below are the teaching content of this file; the paired *_fixed
    methods show the correct pattern. Do not "fix" the buggy versions.
    """

    def __init__(self):
        # Maps cache key -> list of item strings.
        self.cache = {}

    def add_items(
        self, key: str, items: List[str] = []  # BUG: Mutable default argument!
    ) -> None:
        """
        Add items to cache.

        BUG: Using [] as default creates a SHARED list across all calls!
        This is one of Python's classic gotchas.
        """
        # The items list is shared across ALL calls that don't provide items
        items.append(f"Added at {datetime.now()}")
        self.cache[key] = items

    def add_items_fixed(self, key: str, items: Optional[List[str]] = None) -> None:
        """Add items with proper default handling."""
        if items is None:
            items = []
        items = items.copy()  # Also make a copy to avoid mutation
        items.append(f"Added at {datetime.now()}")
        self.cache[key] = items

    def merge_configs(
        self, name: str, overrides: Dict[str, any] = {}  # BUG: Mutable default!
    ) -> Dict[str, any]:
        """
        Merge configuration with overrides.

        BUG: The default dict is shared across all calls!
        """
        defaults = {"timeout": 30, "retries": 3, "cache_enabled": True}
        # This modifies the SHARED overrides dict
        overrides.update(defaults)
        return overrides

    def merge_configs_fixed(
        self, name: str, overrides: Optional[Dict[str, any]] = None
    ) -> Dict[str, any]:
        """Merge configs properly."""
        if overrides is None:
            overrides = {}
        defaults = {"timeout": 30, "retries": 3, "cache_enabled": True}
        # Create new dict to avoid mutation; overrides win over defaults.
        result = {**defaults, **overrides}
        return result
class DataProcessor:
    """Another example of the mutable default bug.

    NOTE(review): bug is intentional demo material — do not fix here.
    """

    def process_batch(
        self, data: List[int], filters: List[str] = []  # BUG: Mutable default!
    ) -> List[int]:
        """
        Process data with optional filters.

        BUG: filters list is shared across calls!
        """
        filters.append("default_filter")  # Modifies shared list!
        result = []
        for item in data:
            # Skip negatives only when the "positive" filter was requested.
            if "positive" in filters and item < 0:
                continue
            result.append(item * 2)
        return result
if __name__ == "__main__":
    # Demo script: show the shared-default-list bug, then the fixed variant.
    cache = CacheManager()
    # Demonstrate the bug
    print("=== Demonstrating Mutable Default Argument Bug ===\n")
    # First call with no items
    cache.add_items("key1")
    print(f"key1: {cache.cache['key1']}")
    # Second call with no items - SURPRISE! Gets the same list
    cache.add_items("key2")
    print(f"key2: {cache.cache['key2']}")  # Will have TWO timestamps!
    # Third call - even worse!
    cache.add_items("key3")
    print(f"key3: {cache.cache['key3']}")  # Will have THREE timestamps!
    print("\nAll keys share the same list object!")
    print(f"key1 is key2: {cache.cache['key1'] is cache.cache['key2']}")
    print("\n=== Using Fixed Version ===\n")
    # The fixed method creates a fresh list per call, so identity differs.
    cache2 = CacheManager()
    cache2.add_items_fixed("key1")
    cache2.add_items_fixed("key2")
    cache2.add_items_fixed("key3")
    print(f"key1: {cache2.cache['key1']}")
    print(f"key2: {cache2.cache['key2']}")
    print(f"key3: {cache2.cache['key3']}")
    print(f"\nkey1 is key2: {cache2.cache['key1'] is cache2.cache['key2']}")

View File

@@ -0,0 +1,145 @@
"""
Data processor with multiple concurrency and thread-safety issues.
Used for Session 3 to demonstrate context editing with multiple bugs.
"""
import json
import threading
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import List, Dict, Any
class DataProcessor:
    """Process data files concurrently with various thread-safety issues.

    NOTE(review): intentional demo material — the race conditions flagged
    below are the point of the file; do not fix them here.
    """

    def __init__(self, max_workers: int = 5):
        self.max_workers = max_workers
        self.processed_count = 0  # BUG: Race condition on counter
        self.results = []  # BUG: Shared list without locking
        self.errors = {}  # BUG: Shared dict without locking
        self.lock = threading.Lock()  # Available but not used!

    def process_file(self, file_path: str) -> Dict[str, Any]:
        """Process a single file."""
        try:
            with open(file_path, "r") as f:
                data = json.load(f)
            # Simulate some processing
            processed = {
                "file": file_path,
                "record_count": len(data) if isinstance(data, list) else 1,
                "size_bytes": Path(file_path).stat().st_size,
            }
            return processed
        except Exception as e:
            # Failures are reported as a result dict, not raised.
            return {"file": file_path, "error": str(e)}

    def process_batch(self, file_paths: List[str]) -> List[Dict[str, Any]]:
        """
        Process multiple files concurrently.

        MULTIPLE BUGS:
        1. self.processed_count is incremented without locking
        2. self.results is appended to from multiple threads
        3. self.errors is modified from multiple threads
        4. We have a lock but don't use it!
        """
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            futures = [executor.submit(self.process_file, fp) for fp in file_paths]
            for future in futures:
                result = future.result()
                # RACE CONDITION: Increment counter without lock
                self.processed_count += 1  # BUG!
                if "error" in result:
                    # RACE CONDITION: Modify dict without lock
                    self.errors[result["file"]] = result["error"]  # BUG!
                else:
                    # RACE CONDITION: Append to list without lock
                    self.results.append(result)  # BUG!
        return self.results

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get processing statistics.

        BUG: Accessing shared state without ensuring thread-safety.
        If called while processing, could get inconsistent values.
        """
        total_files = self.processed_count
        successful = len(self.results)
        failed = len(self.errors)
        # BUG: These counts might not add up correctly due to race conditions
        return {
            "total_processed": total_files,
            "successful": successful,
            "failed": failed,
            "success_rate": successful / total_files if total_files > 0 else 0,
        }

    def reset(self):
        """
        Reset processor state.

        BUG: No locking - if called during processing, causes corruption.
        """
        self.processed_count = 0  # RACE CONDITION
        self.results = []  # RACE CONDITION
        self.errors = {}  # RACE CONDITION
class SharedCache:
    """
    A shared cache with thread-safety issues.

    BUG: Classic read-modify-write race condition pattern.
    NOTE(review): intentional demo material — do not fix here.
    """

    def __init__(self):
        self.cache = {}  # BUG: Shared dict without locking
        self.hit_count = 0  # BUG: Race condition
        self.miss_count = 0  # BUG: Race condition

    def get(self, key: str) -> Any:
        """Get from cache - RACE CONDITION on hit/miss counts."""
        if key in self.cache:
            self.hit_count += 1  # BUG: Not atomic!
            return self.cache[key]
        else:
            self.miss_count += 1  # BUG: Not atomic!
            return None

    def set(self, key: str, value: Any):
        """Set in cache - RACE CONDITION on dict modification."""
        self.cache[key] = value  # BUG: Dict access not synchronized!

    def get_stats(self) -> Dict[str, Any]:
        """Get cache statistics - may be inconsistent."""
        total = self.hit_count + self.miss_count
        return {
            "hits": self.hit_count,
            "misses": self.miss_count,
            "hit_rate": self.hit_count / total if total > 0 else 0,
        }
if __name__ == "__main__":
    # Create some test files (not included)
    processor = DataProcessor(max_workers=10)
    # Simulate processing many files; most opens will fail and land in
    # processor.errors, which is fine for demonstrating the races.
    file_paths = [f"data/file_{i}.json" for i in range(100)]
    print("Processing files concurrently...")
    results = processor.process_batch(file_paths)
    print(f"\nStatistics: {processor.get_statistics()}")
    print("\nNote: Counts may be inconsistent due to race conditions!")

View File

@@ -0,0 +1,105 @@
"""
SQL query builder with SQL injection vulnerability.
Demonstrates dangerous string formatting in SQL queries.
"""
from typing import List, Optional
class UserDatabase:
    """Mock database interface.

    Stands in for a real driver: ``execute`` echoes the SQL it receives
    and returns no rows, so the demo can inspect generated query text.
    """

    def execute(self, query: str) -> List[dict]:
        """Print *query* for inspection and return an empty result set."""
        print(f"Executing: {query}")
        rows: List[dict] = []
        return rows
class QueryBuilder:
    """Build SQL queries for user operations.

    NOTE(review): the SQL injection vulnerabilities flagged below are
    intentional demo material for the memory cookbook — do not fix here.
    """

    def __init__(self, db: UserDatabase):
        self.db = db

    def get_user_by_name(self, username: str) -> Optional[dict]:
        """
        Get user by username.

        BUG: SQL INJECTION VULNERABILITY!
        Using string formatting with user input allows SQL injection.
        """
        # DANGEROUS: Never use f-strings or % formatting with user input!
        query = f"SELECT * FROM users WHERE username = '{username}'"
        results = self.db.execute(query)
        return results[0] if results else None

    def get_user_by_name_safe(self, username: str) -> Optional[dict]:
        """Safe version using parameterized queries."""
        # Use parameterized queries (this is pseudo-code for the concept)
        query = "SELECT * FROM users WHERE username = ?"
        # In real code: self.db.execute(query, (username,))
        print(f"Safe query with parameter: {query}, params: ({username},)")
        return None

    def search_users(self, search_term: str, limit: int = 10) -> List[dict]:
        """
        Search users by term.

        BUG: SQL INJECTION through LIKE clause!
        """
        # DANGEROUS: User input directly in LIKE clause
        query = f"SELECT * FROM users WHERE name LIKE '%{search_term}%' LIMIT {limit}"
        return self.db.execute(query)

    def delete_user(self, user_id: str) -> bool:
        """
        Delete a user.

        BUG: SQL INJECTION in DELETE statement!
        This is especially dangerous as it can lead to data loss.
        """
        # DANGEROUS: Unvalidated user input in DELETE
        query = f"DELETE FROM users WHERE id = {user_id}"
        self.db.execute(query)
        return True

    def get_users_by_role(self, role: str, order_by: str = "name") -> List[dict]:
        """
        Get users by role.

        BUG: SQL INJECTION in ORDER BY clause!
        Even the ORDER BY clause can be exploited.
        """
        # DANGEROUS: User-controlled ORDER BY (cannot be parameterized;
        # real code must validate against an allow-list of column names).
        query = f"SELECT * FROM users WHERE role = '{role}' ORDER BY {order_by}"
        return self.db.execute(query)
if __name__ == "__main__":
    # Demo script: print the injected queries the builder would execute.
    db = UserDatabase()
    qb = QueryBuilder(db)
    print("=== Demonstrating SQL Injection Vulnerabilities ===\n")
    # Example 1: Basic injection
    print("1. Basic username injection:")
    qb.get_user_by_name("admin' OR '1'='1")
    # Executes: SELECT * FROM users WHERE username = 'admin' OR '1'='1'
    # Returns ALL users!
    print("\n2. Search term injection:")
    qb.search_users("test%' OR 1=1--")
    # Can bypass the LIKE and return everything
    print("\n3. DELETE injection:")
    qb.delete_user("1 OR 1=1")
    # Executes: DELETE FROM users WHERE id = 1 OR 1=1
    # DELETES ALL USERS!
    print("\n4. ORDER BY injection:")
    qb.get_users_by_role("admin", "name; DROP TABLE users--")
    # Can execute arbitrary SQL commands!
    print("\n=== Safe Version ===")
    qb.get_user_by_name_safe("admin' OR '1'='1")
    # Parameters are properly escaped

View File

@@ -0,0 +1,84 @@
"""
Concurrent web scraper with a race condition bug.
Multiple threads modify shared state without synchronization.
"""
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Dict
import requests
class WebScraper:
    """Web scraper that fetches multiple URLs concurrently.

    NOTE(review): intentional demo material — the race conditions flagged
    below are what Claude is expected to find; do not fix them here.
    """

    def __init__(self, max_workers: int = 10):
        self.max_workers = max_workers
        self.results = []  # BUG: Shared mutable state accessed by multiple threads!
        self.failed_urls = []  # BUG: Another race condition!

    def fetch_url(self, url: str) -> Dict[str, any]:
        """Fetch a single URL and return the result."""
        try:
            response = requests.get(url, timeout=5)
            response.raise_for_status()
            return {
                "url": url,
                "status": response.status_code,
                "content_length": len(response.content),
            }
        except requests.exceptions.RequestException as e:
            # Network/HTTP failures are reported as a result dict, not raised.
            return {"url": url, "error": str(e)}

    def scrape_urls(self, urls: List[str]) -> List[Dict[str, any]]:
        """
        Scrape multiple URLs concurrently.

        BUG: self.results is accessed from multiple threads without locking!
        This causes race conditions where results can be lost or corrupted.
        """
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            futures = [executor.submit(self.fetch_url, url) for url in urls]
            for future in as_completed(futures):
                result = future.result()
                # RACE CONDITION: Multiple threads append to self.results simultaneously
                if "error" in result:
                    self.failed_urls.append(result["url"])  # RACE CONDITION
                else:
                    self.results.append(result)  # RACE CONDITION
        return self.results

    def get_stats(self) -> Dict[str, int]:
        """Get scraping statistics."""
        return {
            "total_results": len(self.results),
            "failed_urls": len(self.failed_urls),
            "success_rate": (
                len(self.results) / (len(self.results) + len(self.failed_urls))
                if (len(self.results) + len(self.failed_urls)) > 0
                else 0
            ),
        }
if __name__ == "__main__":
    # Test with multiple URLs; repetition raises the odds of observing
    # the race condition in scrape_urls.
    urls = [
        "https://httpbin.org/delay/1",
        "https://httpbin.org/status/200",
        "https://httpbin.org/status/404",
        "https://httpbin.org/delay/2",
        "https://httpbin.org/status/500",
    ] * 10  # 50 URLs total to increase race condition probability
    scraper = WebScraper(max_workers=10)
    results = scraper.scrape_urls(urls)
    print(f"Expected: 50 results")
    print(f"Got: {len(results)} results")
    print(f"Stats: {scraper.get_stats()}")
    print("\nNote: Results count may be less than expected due to race condition!")

351
tool_use/memory_tool.py Normal file
View File

@@ -0,0 +1,351 @@
"""
Production-ready memory tool handler for Claude's memory_20250818 tool.
This implementation provides secure, client-side execution of memory operations
with path validation, error handling, and comprehensive security measures.
"""
import shutil
from pathlib import Path
from typing import Any
class MemoryToolHandler:
"""
Handles execution of Claude's memory tool commands.
The memory tool enables Claude to read, write, and manage files in a memory
system through a standardized tool interface. This handler provides client-side
implementation with security controls.
Attributes:
base_path: Root directory for memory storage
memory_root: The /memories directory within base_path
"""
    def __init__(self, base_path: str = "./memory_storage"):
        """
        Initialize the memory tool handler.

        Args:
            base_path: Root directory for all memory operations
        """
        self.base_path = Path(base_path).resolve()
        # All Claude-visible /memories paths map under <base_path>/memories;
        # create it eagerly so the first tool call never hits a missing dir.
        self.memory_root = self.base_path / "memories"
        self.memory_root.mkdir(parents=True, exist_ok=True)
def _validate_path(self, path: str) -> Path:
"""
Validate and resolve memory paths to prevent directory traversal attacks.
Args:
path: The path to validate (must start with /memories)
Returns:
Resolved absolute Path object within memory_root
Raises:
ValueError: If path is invalid or attempts to escape memory directory
"""
if not path.startswith("/memories"):
raise ValueError(
f"Path must start with /memories, got: {path}. "
"All memory operations must be confined to the /memories directory."
)
# Remove /memories prefix and any leading slashes
relative_path = path[len("/memories") :].lstrip("/")
# Resolve to absolute path within memory_root
if relative_path:
full_path = (self.memory_root / relative_path).resolve()
else:
full_path = self.memory_root.resolve()
# Verify the resolved path is still within memory_root
try:
full_path.relative_to(self.memory_root.resolve())
except ValueError as e:
raise ValueError(
f"Path '{path}' would escape /memories directory. "
"Directory traversal attempts are not allowed."
) from e
return full_path
def execute(self, **params: Any) -> dict[str, str]:
"""
Execute a memory tool command.
Args:
**params: Command parameters from Claude's tool use
Returns:
Dict with either 'success' or 'error' key
Supported commands:
- view: Show directory contents or file contents
- create: Create or overwrite a file
- str_replace: Replace text in a file
- insert: Insert text at a specific line
- delete: Delete a file or directory
- rename: Rename or move a file/directory
"""
command = params.get("command")
try:
if command == "view":
return self._view(params)
elif command == "create":
return self._create(params)
elif command == "str_replace":
return self._str_replace(params)
elif command == "insert":
return self._insert(params)
elif command == "delete":
return self._delete(params)
elif command == "rename":
return self._rename(params)
else:
return {
"error": f"Unknown command: '{command}'. "
"Valid commands are: view, create, str_replace, insert, delete, rename"
}
except ValueError as e:
return {"error": str(e)}
except Exception as e:
return {"error": f"Unexpected error executing {command}: {e}"}
def _view(self, params: dict[str, Any]) -> dict[str, str]:
"""View directory contents or file contents."""
path = params.get("path")
view_range = params.get("view_range")
if not path:
return {"error": "Missing required parameter: path"}
full_path = self._validate_path(path)
# Handle directory listing
if full_path.is_dir():
try:
items = []
for item in sorted(full_path.iterdir()):
if item.name.startswith("."):
continue
items.append(f"{item.name}/" if item.is_dir() else item.name)
if not items:
return {"success": f"Directory: {path}\n(empty)"}
return {
"success": f"Directory: {path}\n"
+ "\n".join([f"- {item}" for item in items])
}
except Exception as e:
return {"error": f"Cannot read directory {path}: {e}"}
# Handle file reading
elif full_path.is_file():
try:
content = full_path.read_text(encoding="utf-8")
lines = content.splitlines()
# Apply view range if specified
if view_range:
start_line = max(1, view_range[0]) - 1 # Convert to 0-indexed
end_line = len(lines) if view_range[1] == -1 else view_range[1]
lines = lines[start_line:end_line]
start_num = start_line + 1
else:
start_num = 1
# Format with line numbers
numbered_lines = [
f"{i + start_num:4d}: {line}" for i, line in enumerate(lines)
]
return {"success": "\n".join(numbered_lines)}
except UnicodeDecodeError:
return {"error": f"Cannot read {path}: File is not valid UTF-8 text"}
except Exception as e:
return {"error": f"Cannot read file {path}: {e}"}
else:
return {"error": f"Path not found: {path}"}
def _create(self, params: dict[str, Any]) -> dict[str, str]:
"""Create or overwrite a file."""
path = params.get("path")
file_text = params.get("file_text", "")
if not path:
return {"error": "Missing required parameter: path"}
full_path = self._validate_path(path)
# Don't allow creating directories directly
if not path.endswith((".txt", ".md", ".json", ".py", ".yaml", ".yml")):
return {
"error": f"Cannot create {path}: Only text files are supported. "
"Use file extensions: .txt, .md, .json, .py, .yaml, .yml"
}
try:
# Create parent directories if needed
full_path.parent.mkdir(parents=True, exist_ok=True)
# Write the file
full_path.write_text(file_text, encoding="utf-8")
return {"success": f"File created successfully at {path}"}
except Exception as e:
return {"error": f"Cannot create file {path}: {e}"}
def _str_replace(self, params: dict[str, Any]) -> dict[str, str]:
"""Replace text in a file."""
path = params.get("path")
old_str = params.get("old_str")
new_str = params.get("new_str", "")
if not path or old_str is None:
return {"error": "Missing required parameters: path, old_str"}
full_path = self._validate_path(path)
if not full_path.is_file():
return {"error": f"File not found: {path}"}
try:
content = full_path.read_text(encoding="utf-8")
# Check if old_str exists
count = content.count(old_str)
if count == 0:
return {
"error": f"String not found in {path}. "
"The exact text must exist in the file."
}
elif count > 1:
return {
"error": f"String appears {count} times in {path}. "
"The string must be unique. Use more specific context."
}
# Perform replacement
new_content = content.replace(old_str, new_str, 1)
full_path.write_text(new_content, encoding="utf-8")
return {"success": f"File {path} has been edited successfully"}
except Exception as e:
return {"error": f"Cannot edit file {path}: {e}"}
def _insert(self, params: dict[str, Any]) -> dict[str, str]:
"""Insert text at a specific line."""
path = params.get("path")
insert_line = params.get("insert_line")
insert_text = params.get("insert_text", "")
if not path or insert_line is None:
return {"error": "Missing required parameters: path, insert_line"}
full_path = self._validate_path(path)
if not full_path.is_file():
return {"error": f"File not found: {path}"}
try:
lines = full_path.read_text(encoding="utf-8").splitlines()
# Validate insert_line
if insert_line < 0 or insert_line > len(lines):
return {
"error": f"Invalid insert_line {insert_line}. "
f"Must be between 0 and {len(lines)}"
}
# Insert the text
lines.insert(insert_line, insert_text.rstrip("\n"))
# Write back
full_path.write_text("\n".join(lines) + "\n", encoding="utf-8")
return {"success": f"Text inserted at line {insert_line} in {path}"}
except Exception as e:
return {"error": f"Cannot insert into {path}: {e}"}
def _delete(self, params: dict[str, Any]) -> dict[str, str]:
"""Delete a file or directory."""
path = params.get("path")
if not path:
return {"error": "Missing required parameter: path"}
# Prevent deletion of root memories directory
if path == "/memories":
return {"error": "Cannot delete the /memories directory itself"}
full_path = self._validate_path(path)
if not full_path.exists():
return {"error": f"Path not found: {path}"}
try:
if full_path.is_file():
full_path.unlink()
return {"success": f"File deleted: {path}"}
elif full_path.is_dir():
shutil.rmtree(full_path)
return {"success": f"Directory deleted: {path}"}
except Exception as e:
return {"error": f"Cannot delete {path}: {e}"}
def _rename(self, params: dict[str, Any]) -> dict[str, str]:
"""Rename or move a file/directory."""
old_path = params.get("old_path")
new_path = params.get("new_path")
if not old_path or not new_path:
return {"error": "Missing required parameters: old_path, new_path"}
old_full_path = self._validate_path(old_path)
new_full_path = self._validate_path(new_path)
if not old_full_path.exists():
return {"error": f"Source path not found: {old_path}"}
if new_full_path.exists():
return {
"error": f"Destination already exists: {new_path}. "
"Cannot overwrite existing files/directories."
}
try:
# Create parent directories if needed
new_full_path.parent.mkdir(parents=True, exist_ok=True)
# Perform rename/move
old_full_path.rename(new_full_path)
return {"success": f"Renamed {old_path} to {new_path}"}
except Exception as e:
return {"error": f"Cannot rename {old_path} to {new_path}: {e}"}
def clear_all_memory(self) -> dict[str, str]:
"""
Clear all memory files (useful for testing or starting fresh).
Returns:
Dict with success message
"""
try:
if self.memory_root.exists():
shutil.rmtree(self.memory_root)
self.memory_root.mkdir(parents=True, exist_ok=True)
return {"success": "All memory cleared successfully"}
except Exception as e:
return {"error": f"Cannot clear memory: {e}"}

View File

@@ -0,0 +1,3 @@
anthropic>=0.18.0
python-dotenv>=1.0.0
ipykernel>=6.29.0 # For Jupyter in VSCode

View File

@@ -0,0 +1,426 @@
"""
Unit tests for the memory tool handler.
Tests security validation, command execution, and error handling.
"""
import shutil
import tempfile
import unittest
from pathlib import Path
from memory_tool import MemoryToolHandler
class TestMemoryToolHandler(unittest.TestCase):
    """Test suite for MemoryToolHandler.

    Each test gets a fresh temporary directory so tests are fully isolated
    and leave no artifacts behind.
    """

    def setUp(self):
        """Create temporary directory for each test."""
        self.test_dir = tempfile.mkdtemp()
        self.handler = MemoryToolHandler(base_path=self.test_dir)

    def tearDown(self):
        """Clean up temporary directory after each test."""
        shutil.rmtree(self.test_dir)

    # Security Tests

    def test_path_validation_requires_memories_prefix(self):
        """Test that paths must start with /memories."""
        result = self.handler.execute(command="view", path="/etc/passwd")
        self.assertIn("error", result)
        self.assertIn("must start with /memories", result["error"])

    def test_path_validation_prevents_traversal_dotdot(self):
        """Test that .. traversal is blocked."""
        result = self.handler.execute(
            command="view", path="/memories/../../../etc/passwd"
        )
        self.assertIn("error", result)
        self.assertIn("escape", result["error"].lower())

    def test_path_validation_prevents_traversal_encoded(self):
        """Test that URL-encoded traversal is blocked."""
        result = self.handler.execute(
            command="view", path="/memories/%2e%2e/%2e%2e/etc/passwd"
        )
        # The path will be processed and should fail validation
        self.assertIn("error", result)

    def test_path_validation_allows_valid_paths(self):
        """Test that valid memory paths are accepted."""
        result = self.handler.execute(
            command="create", path="/memories/test.txt", file_text="test"
        )
        self.assertIn("success", result)

    # View Command Tests

    def test_view_empty_directory(self):
        """Test viewing an empty /memories directory."""
        result = self.handler.execute(command="view", path="/memories")
        self.assertIn("success", result)
        self.assertIn("empty", result["success"].lower())

    def test_view_directory_with_files(self):
        """Test viewing a directory with files."""
        # Create some test files
        self.handler.execute(
            command="create", path="/memories/file1.txt", file_text="content1"
        )
        self.handler.execute(
            command="create", path="/memories/file2.txt", file_text="content2"
        )
        result = self.handler.execute(command="view", path="/memories")
        self.assertIn("success", result)
        self.assertIn("file1.txt", result["success"])
        self.assertIn("file2.txt", result["success"])

    def test_view_file_with_line_numbers(self):
        """Test viewing a file with line numbers."""
        content = "line 1\nline 2\nline 3"
        self.handler.execute(
            command="create", path="/memories/test.txt", file_text=content
        )
        result = self.handler.execute(command="view", path="/memories/test.txt")
        self.assertIn("success", result)
        self.assertIn(" 1: line 1", result["success"])
        self.assertIn(" 2: line 2", result["success"])
        self.assertIn(" 3: line 3", result["success"])

    def test_view_file_with_range(self):
        """Test viewing specific line range."""
        content = "line 1\nline 2\nline 3\nline 4"
        self.handler.execute(
            command="create", path="/memories/test.txt", file_text=content
        )
        result = self.handler.execute(
            command="view", path="/memories/test.txt", view_range=[2, 3]
        )
        self.assertIn("success", result)
        self.assertIn(" 2: line 2", result["success"])
        self.assertIn(" 3: line 3", result["success"])
        self.assertNotIn("line 1", result["success"])
        self.assertNotIn("line 4", result["success"])

    def test_view_nonexistent_path(self):
        """Test viewing a nonexistent path."""
        result = self.handler.execute(command="view", path="/memories/notfound.txt")
        self.assertIn("error", result)
        self.assertIn("not found", result["error"].lower())

    # Create Command Tests

    def test_create_file(self):
        """Test creating a file."""
        result = self.handler.execute(
            command="create", path="/memories/test.txt", file_text="Hello, World!"
        )
        self.assertIn("success", result)
        # Verify file exists
        file_path = Path(self.test_dir) / "memories" / "test.txt"
        self.assertTrue(file_path.exists())
        self.assertEqual(file_path.read_text(), "Hello, World!")

    def test_create_file_in_subdirectory(self):
        """Test creating a file in a subdirectory (auto-creates dirs)."""
        result = self.handler.execute(
            command="create",
            path="/memories/subdir/test.txt",
            file_text="Nested content",
        )
        self.assertIn("success", result)
        file_path = Path(self.test_dir) / "memories" / "subdir" / "test.txt"
        self.assertTrue(file_path.exists())

    def test_create_requires_file_extension(self):
        """Test that create only allows text file extensions."""
        result = self.handler.execute(
            command="create", path="/memories/noext", file_text="content"
        )
        self.assertIn("error", result)
        self.assertIn("text files are supported", result["error"])

    def test_create_overwrites_existing_file(self):
        """Test that create overwrites existing files."""
        self.handler.execute(
            command="create", path="/memories/test.txt", file_text="original"
        )
        result = self.handler.execute(
            command="create", path="/memories/test.txt", file_text="updated"
        )
        self.assertIn("success", result)
        file_path = Path(self.test_dir) / "memories" / "test.txt"
        self.assertEqual(file_path.read_text(), "updated")

    # String Replace Command Tests

    def test_str_replace_success(self):
        """Test successful string replacement."""
        self.handler.execute(
            command="create",
            path="/memories/test.txt",
            file_text="Hello World",
        )
        result = self.handler.execute(
            command="str_replace",
            path="/memories/test.txt",
            old_str="World",
            new_str="Universe",
        )
        self.assertIn("success", result)
        file_path = Path(self.test_dir) / "memories" / "test.txt"
        self.assertEqual(file_path.read_text(), "Hello Universe")

    def test_str_replace_string_not_found(self):
        """Test replacement when string doesn't exist."""
        self.handler.execute(
            command="create", path="/memories/test.txt", file_text="Hello World"
        )
        result = self.handler.execute(
            command="str_replace",
            path="/memories/test.txt",
            old_str="Missing",
            new_str="Text",
        )
        self.assertIn("error", result)
        self.assertIn("not found", result["error"].lower())

    def test_str_replace_multiple_occurrences(self):
        """Test that replacement fails with multiple occurrences."""
        self.handler.execute(
            command="create",
            path="/memories/test.txt",
            file_text="Hello World Hello World",
        )
        result = self.handler.execute(
            command="str_replace",
            path="/memories/test.txt",
            old_str="Hello",
            new_str="Hi",
        )
        self.assertIn("error", result)
        self.assertIn("appears 2 times", result["error"])

    def test_str_replace_file_not_found(self):
        """Test replacement on nonexistent file."""
        result = self.handler.execute(
            command="str_replace",
            path="/memories/notfound.txt",
            old_str="old",
            new_str="new",
        )
        self.assertIn("error", result)
        self.assertIn("not found", result["error"].lower())

    # Insert Command Tests

    def test_insert_at_beginning(self):
        """Test inserting at line 0 (beginning)."""
        self.handler.execute(
            command="create", path="/memories/test.txt", file_text="line 1\nline 2"
        )
        result = self.handler.execute(
            command="insert",
            path="/memories/test.txt",
            insert_line=0,
            insert_text="new line",
        )
        self.assertIn("success", result)
        file_path = Path(self.test_dir) / "memories" / "test.txt"
        self.assertEqual(file_path.read_text(), "new line\nline 1\nline 2\n")

    def test_insert_in_middle(self):
        """Test inserting in the middle."""
        self.handler.execute(
            command="create", path="/memories/test.txt", file_text="line 1\nline 2"
        )
        result = self.handler.execute(
            command="insert",
            path="/memories/test.txt",
            insert_line=1,
            insert_text="inserted",
        )
        self.assertIn("success", result)
        file_path = Path(self.test_dir) / "memories" / "test.txt"
        self.assertEqual(file_path.read_text(), "line 1\ninserted\nline 2\n")

    def test_insert_at_end(self):
        """Test inserting at the end."""
        self.handler.execute(
            command="create", path="/memories/test.txt", file_text="line 1\nline 2"
        )
        result = self.handler.execute(
            command="insert",
            path="/memories/test.txt",
            insert_line=2,
            insert_text="last line",
        )
        self.assertIn("success", result)
        # Fix: also verify the resulting file content, matching the sibling
        # insert tests (previously only the success flag was checked).
        file_path = Path(self.test_dir) / "memories" / "test.txt"
        self.assertEqual(file_path.read_text(), "line 1\nline 2\nlast line\n")

    def test_insert_invalid_line(self):
        """Test insert with invalid line number."""
        self.handler.execute(
            command="create", path="/memories/test.txt", file_text="line 1"
        )
        result = self.handler.execute(
            command="insert",
            path="/memories/test.txt",
            insert_line=99,
            insert_text="text",
        )
        self.assertIn("error", result)
        self.assertIn("invalid", result["error"].lower())

    # Delete Command Tests

    def test_delete_file(self):
        """Test deleting a file."""
        self.handler.execute(
            command="create", path="/memories/test.txt", file_text="content"
        )
        result = self.handler.execute(command="delete", path="/memories/test.txt")
        self.assertIn("success", result)
        file_path = Path(self.test_dir) / "memories" / "test.txt"
        self.assertFalse(file_path.exists())

    def test_delete_directory(self):
        """Test deleting a directory."""
        self.handler.execute(
            command="create", path="/memories/subdir/test.txt", file_text="content"
        )
        result = self.handler.execute(command="delete", path="/memories/subdir")
        self.assertIn("success", result)
        dir_path = Path(self.test_dir) / "memories" / "subdir"
        self.assertFalse(dir_path.exists())

    def test_delete_cannot_delete_root(self):
        """Test that root /memories directory cannot be deleted."""
        result = self.handler.execute(command="delete", path="/memories")
        self.assertIn("error", result)
        self.assertIn("cannot delete", result["error"].lower())

    def test_delete_nonexistent_path(self):
        """Test deleting a nonexistent path."""
        result = self.handler.execute(command="delete", path="/memories/notfound.txt")
        self.assertIn("error", result)
        self.assertIn("not found", result["error"].lower())

    # Rename Command Tests

    def test_rename_file(self):
        """Test renaming a file."""
        self.handler.execute(
            command="create", path="/memories/old.txt", file_text="content"
        )
        result = self.handler.execute(
            command="rename", old_path="/memories/old.txt", new_path="/memories/new.txt"
        )
        self.assertIn("success", result)
        old_path = Path(self.test_dir) / "memories" / "old.txt"
        new_path = Path(self.test_dir) / "memories" / "new.txt"
        self.assertFalse(old_path.exists())
        self.assertTrue(new_path.exists())

    def test_rename_to_subdirectory(self):
        """Test moving a file to a subdirectory."""
        self.handler.execute(
            command="create", path="/memories/file.txt", file_text="content"
        )
        result = self.handler.execute(
            command="rename",
            old_path="/memories/file.txt",
            new_path="/memories/subdir/file.txt",
        )
        self.assertIn("success", result)
        new_path = Path(self.test_dir) / "memories" / "subdir" / "file.txt"
        self.assertTrue(new_path.exists())

    def test_rename_source_not_found(self):
        """Test rename when source doesn't exist."""
        result = self.handler.execute(
            command="rename",
            old_path="/memories/notfound.txt",
            new_path="/memories/new.txt",
        )
        self.assertIn("error", result)
        self.assertIn("not found", result["error"].lower())

    def test_rename_destination_exists(self):
        """Test rename when destination already exists."""
        self.handler.execute(
            command="create", path="/memories/file1.txt", file_text="content1"
        )
        self.handler.execute(
            command="create", path="/memories/file2.txt", file_text="content2"
        )
        result = self.handler.execute(
            command="rename",
            old_path="/memories/file1.txt",
            new_path="/memories/file2.txt",
        )
        self.assertIn("error", result)
        self.assertIn("already exists", result["error"].lower())

    # Error Handling Tests

    def test_unknown_command(self):
        """Test handling of unknown command."""
        result = self.handler.execute(command="invalid", path="/memories")
        self.assertIn("error", result)
        self.assertIn("unknown command", result["error"].lower())

    def test_missing_required_parameters(self):
        """Test error handling for missing parameters."""
        result = self.handler.execute(command="view")
        self.assertIn("error", result)

    # Utility Tests

    def test_clear_all_memory(self):
        """Test clearing all memory."""
        # Create some files
        self.handler.execute(
            command="create", path="/memories/file1.txt", file_text="content1"
        )
        self.handler.execute(
            command="create", path="/memories/file2.txt", file_text="content2"
        )
        result = self.handler.clear_all_memory()
        self.assertIn("success", result)
        # Verify directory exists but is empty
        memory_root = Path(self.test_dir) / "memories"
        self.assertTrue(memory_root.exists())
        self.assertEqual(len(list(memory_root.iterdir())), 0)
if __name__ == "__main__":
    # Run the full suite when this file is executed directly
    # (e.g. `python test_memory_tool.py`); under pytest this is skipped.
    unittest.main()