mirror of
https://github.com/anthropics/claude-cookbooks.git
synced 2025-10-06 01:00:28 +03:00
Add memory & context management cookbook
Interactive notebook demonstrating Claude's memory tool and context editing capabilities with code review examples. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
6
.gitignore
vendored
6
.gitignore
vendored
@@ -149,4 +149,8 @@ lychee-report.md
|
||||
# Notebook validation
|
||||
.notebook_validation_state.json
|
||||
.notebook_validation_checkpoint.json
|
||||
validation_report_*.md
|
||||
validation_report_*.md
|
||||
# Memory tool demo artifacts
|
||||
tool_use/demo_memory/
|
||||
tool_use/memory_storage/
|
||||
tool_use/.env
|
||||
|
||||
14
tool_use/.env.example
Normal file
14
tool_use/.env.example
Normal file
@@ -0,0 +1,14 @@
|
||||
# Anthropic API Configuration
|
||||
# Copy this file to .env and fill in your actual values
|
||||
|
||||
# Your Anthropic API key from https://console.anthropic.com/
|
||||
ANTHROPIC_API_KEY=your_api_key_here
|
||||
|
||||
# Model name - Use a model that supports memory_20250818 tool
|
||||
# Supported models (as of launch):
|
||||
# - claude-sonnet-4-20250514
|
||||
# - claude-opus-4-20250514
|
||||
# - claude-opus-4-1-20250805
|
||||
# - claude-sonnet-4-5-20250929
|
||||
|
||||
ANTHROPIC_MODEL=claude-sonnet-4-20250514
|
||||
File diff suppressed because it is too large
Load Diff
5
tool_use/memory_demo/.gitignore
vendored
Normal file
5
tool_use/memory_demo/.gitignore
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
# Ignore demo-generated directories and files
|
||||
demo_memory/
|
||||
memory_storage/
|
||||
__pycache__/
|
||||
*.pyc
|
||||
340
tool_use/memory_demo/code_review_demo.py
Normal file
340
tool_use/memory_demo/code_review_demo.py
Normal file
@@ -0,0 +1,340 @@
|
||||
"""
|
||||
Code Review Assistant Demo - Three-session demonstration.
|
||||
|
||||
This demo showcases:
|
||||
1. Session 1: Claude learns debugging patterns
|
||||
2. Session 2: Claude applies learned patterns (faster!)
|
||||
3. Session 3: Long session with context editing
|
||||
|
||||
Requires:
|
||||
- .env file with ANTHROPIC_API_KEY and ANTHROPIC_MODEL
|
||||
- memory_tool.py in the same directory
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from anthropic import Anthropic
|
||||
from dotenv import load_dotenv
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add parent directory to path to import memory_tool
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from memory_tool import MemoryToolHandler
|
||||
|
||||
|
||||
# Load environment variables
load_dotenv()

# Credentials and model name are sourced from .env (see .env.example).
API_KEY = os.getenv("ANTHROPIC_API_KEY")
MODEL = os.getenv("ANTHROPIC_MODEL")

# Fail fast at import time with an actionable message, rather than failing
# later on the first API call.
if not API_KEY:
    raise ValueError(
        "ANTHROPIC_API_KEY not found. Copy .env.example to .env and add your API key."
    )

if not MODEL:
    raise ValueError(
        "ANTHROPIC_MODEL not found. Copy .env.example to .env and set the model."
    )


# Context management configuration
# Passed via extra_body to the beta Messages API. Per the config keys: once
# input exceeds 30k tokens, older tool-use turns are cleared (keeping the 3
# most recent, clearing at least 5k tokens' worth) so long sessions stay
# within the context window.
CONTEXT_MANAGEMENT = {
    "edits": [
        {
            "type": "clear_tool_uses_20250919",
            "trigger": {"type": "input_tokens", "value": 30000},
            "keep": {"type": "tool_uses", "value": 3},
            "clear_at_least": {"type": "input_tokens", "value": 5000},
        }
    ]
}
|
||||
|
||||
|
||||
class CodeReviewAssistant:
    """
    Code review assistant with memory and context editing capabilities.

    This assistant:
    - Checks memory for debugging patterns before reviewing code
    - Stores learned patterns for future sessions
    - Automatically clears old tool results when context grows large
    """

    def __init__(self, memory_storage_path: str = "./memory_storage"):
        """
        Initialize the code review assistant.

        Args:
            memory_storage_path: Path for memory storage. Memory persists on
                disk, which is how learning carries across sessions.
        """
        self.client = Anthropic(api_key=API_KEY)
        self.memory_handler = MemoryToolHandler(base_path=memory_storage_path)
        # Conversation history for the current session only.
        self.messages: List[Dict[str, Any]] = []

    def _create_system_prompt(self) -> str:
        """Create system prompt with memory instructions."""
        return """You are an expert code reviewer focused on finding bugs and suggesting improvements.

MEMORY PROTOCOL:
1. ALWAYS check your /memories directory FIRST using the memory tool
2. Look for relevant debugging patterns or insights from previous reviews
3. When you find a bug or pattern, update your memory with what you learned
4. Keep your memory organized - use descriptive file names and clear content

When reviewing code:
- Identify bugs, security issues, and code quality problems
- Explain the issue clearly
- Provide a corrected version
- Store important patterns in memory for future reference

Remember: Your memory persists across conversations. Use it wisely."""

    def _execute_tool_use(self, tool_use: Any) -> str:
        """Execute a tool use and return the result.

        Returns the handler's "success" message, or its "error" message when
        the memory operation failed.
        """
        if tool_use.name == "memory":
            result = self.memory_handler.execute(**tool_use.input)
            return result.get("success") or result.get("error", "Unknown error")
        return f"Unknown tool: {tool_use.name}"

    def review_code(
        self, code: str, filename: str, description: str = ""
    ) -> Dict[str, Any]:
        """
        Review code with memory-enhanced analysis.

        Args:
            code: The code to review
            filename: Name of the file being reviewed (included in the prompt)
            description: Optional description of what to look for

        Returns:
            Dict with keys "review" (text of the final assistant turn),
            "input_tokens" (usage of the last API call), and "context_edits"
            (any context-management edits applied during the conversation).
        """
        # Construct user message.
        # FIX: filename was previously not interpolated (the f-string had no
        # placeholder), so reviews never identified which file they covered.
        user_message = f"Please review this code from {filename}"
        if description:
            user_message += f"\n\nContext: {description}"
        user_message += f"\n\n```python\n{code}\n```"

        self.messages.append({"role": "user", "content": user_message})

        # Track token usage and context management
        total_input_tokens = 0
        context_edits_applied = []

        # Conversation loop: keep calling the API until Claude stops
        # requesting tool uses.
        turn = 1
        while True:
            print(f" 🔄 Turn {turn}: Calling Claude API...", end="", flush=True)
            response = self.client.beta.messages.create(
                model=MODEL,
                max_tokens=4096,
                system=self._create_system_prompt(),
                messages=self.messages,
                tools=[{"type": "memory_20250818", "name": "memory"}],
                betas=["context-management-2025-06-27"],
                extra_body={"context_management": CONTEXT_MANAGEMENT},
            )

            print(" ✓")

            # Track usage — last response wins, so the reported figure is the
            # final call's input size.
            total_input_tokens = response.usage.input_tokens

            # Check for context management edits applied server-side.
            if hasattr(response, "context_management") and response.context_management:
                applied = response.context_management.get("applied_edits", [])
                if applied:
                    context_edits_applied.extend(applied)

            # Process response content.
            # NOTE: final_text is rebuilt every turn, so the returned "review"
            # is the last turn's text only (commentary that accompanied
            # intermediate tool calls is intentionally dropped).
            assistant_content = []
            tool_results = []
            final_text = []

            for content in response.content:
                if content.type == "text":
                    assistant_content.append({"type": "text", "text": content.text})
                    final_text.append(content.text)
                elif content.type == "tool_use":
                    cmd = content.input.get('command', 'unknown')
                    path = content.input.get('path', '')
                    print(f" 🔧 Memory: {cmd} {path}")

                    # Execute tool locally and echo the result back to Claude.
                    result = self._execute_tool_use(content)

                    assistant_content.append(
                        {
                            "type": "tool_use",
                            "id": content.id,
                            "name": content.name,
                            "input": content.input,
                        }
                    )

                    tool_results.append(
                        {
                            "type": "tool_result",
                            "tool_use_id": content.id,
                            "content": result,
                        }
                    )

            # Add assistant message
            self.messages.append({"role": "assistant", "content": assistant_content})

            # If there are tool results, add them and continue
            if tool_results:
                self.messages.append({"role": "user", "content": tool_results})
                turn += 1
            else:
                # No more tool uses, we're done
                print()
                break

        return {
            "review": "\n".join(final_text),
            "input_tokens": total_input_tokens,
            "context_edits": context_edits_applied,
        }

    def start_new_session(self) -> None:
        """Start a new conversation session (on-disk memory persists)."""
        self.messages = []
|
||||
|
||||
|
||||
def run_session_1() -> None:
    """Session 1: Claude reviews buggy code and stores what it learns."""
    banner = "=" * 80
    print(banner)
    print("SESSION 1: Learning from First Code Review")
    print(banner)

    reviewer = CodeReviewAssistant()

    # Load the buggy sample that Claude will analyze.
    with open("memory_demo/sample_code/web_scraper_v1.py", "r") as source:
        scraper_code = source.read()

    print("\n📋 Reviewing web_scraper_v1.py...")
    print("\nMulti-threaded web scraper that sometimes loses results.\n")

    outcome = reviewer.review_code(
        code=scraper_code,
        filename="web_scraper_v1.py",
        description="This scraper sometimes returns fewer results than expected. "
        "The count is inconsistent across runs. Can you find the issue?",
    )

    print("\n🤖 Claude's Review:\n")
    print(outcome["review"])
    print(f"\n📊 Input tokens used: {outcome['input_tokens']:,}")

    if outcome["context_edits"]:
        print(f"\n🧹 Context edits applied: {outcome['context_edits']}")

    print("\n✅ Session 1 complete - Claude learned debugging patterns!\n")
|
||||
|
||||
|
||||
def run_session_2() -> None:
    """Session 2: a fresh conversation reuses the patterns stored on disk."""
    banner = "=" * 80
    print(banner)
    print("SESSION 2: Applying Learned Patterns (New Conversation)")
    print(banner)

    # Brand-new assistant instance: the conversation starts empty, but the
    # memory directory written in Session 1 is still there.
    reviewer = CodeReviewAssistant()

    # A different sample exhibiting a similar class of bug.
    with open("memory_demo/sample_code/api_client_v1.py", "r") as source:
        client_code = source.read()

    print("\n📋 Reviewing api_client_v1.py...")
    print("\nAsync API client with concurrent requests.\n")

    outcome = reviewer.review_code(
        code=client_code,
        filename="api_client_v1.py",
        description="Review this async API client. "
        "It fetches multiple endpoints concurrently. Are there any issues?",
    )

    print("\n🤖 Claude's Review:\n")
    print(outcome["review"])
    print(f"\n📊 Input tokens used: {outcome['input_tokens']:,}")

    print("\n✅ Session 2 complete - Claude applied learned patterns faster!\n")
|
||||
|
||||
|
||||
def run_session_3() -> None:
    """Session 3: a long review that exercises automatic context editing."""
    banner = "=" * 80
    print(banner)
    print("SESSION 3: Long Session with Context Editing")
    print(banner)

    reviewer = CodeReviewAssistant()

    # This sample is large and has several issues, which drives a long enough
    # conversation for the clear_tool_uses edit to trigger.
    with open("memory_demo/sample_code/data_processor_v1.py", "r") as source:
        processor_code = source.read()

    print("\n📋 Reviewing data_processor_v1.py...")
    print("\nLarge file with multiple concurrent processing classes.\n")

    outcome = reviewer.review_code(
        code=processor_code,
        filename="data_processor_v1.py",
        description="This data processor handles files concurrently. "
        "There's also a SharedCache class. Review all components for issues.",
    )

    print("\n🤖 Claude's Review:\n")
    print(outcome["review"])
    print(f"\n📊 Input tokens used: {outcome['input_tokens']:,}")

    if outcome["context_edits"]:
        print("\n🧹 Context Management Applied:")
        for applied_edit in outcome["context_edits"]:
            print(f" - Type: {applied_edit.get('type')}")
            print(f" - Cleared tool uses: {applied_edit.get('cleared_tool_uses', 0)}")
            print(f" - Tokens saved: {applied_edit.get('cleared_input_tokens', 0):,}")

    print("\n✅ Session 3 complete - Context editing kept conversation manageable!\n")
|
||||
|
||||
|
||||
def main() -> None:
    """Run all three demo sessions in order, pausing between each."""
    print("\n🚀 Code Review Assistant Demo\n")
    print("This demo shows:")
    print("1. Session 1: Claude learns debugging patterns")
    print("2. Session 2: Claude applies learned patterns (new conversation)")
    print("3. Session 3: Long session with context editing\n")

    # Wait for Enter before each session so the user can read the transcript.
    sessions = (
        ("Press Enter to start Session 1...", run_session_1),
        ("Press Enter to start Session 2...", run_session_2),
        ("Press Enter to start Session 3...", run_session_3),
    )
    for pause_prompt, session in sessions:
        input(pause_prompt)
        session()

    banner = "=" * 80
    print(banner)
    print("🎉 Demo Complete!")
    print(banner)
    print("\nKey Takeaways:")
    print("- Memory tool enabled cross-conversation learning")
    print("- Claude got faster at recognizing similar bugs")
    print("- Context editing handled long sessions gracefully")
    print("\n💡 For production GitHub PR reviews, check out:")
    print(" https://github.com/anthropics/claude-code-action\n")


if __name__ == "__main__":
    main()
|
||||
194
tool_use/memory_demo/demo_helpers.py
Normal file
194
tool_use/memory_demo/demo_helpers.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""
|
||||
Helper functions for memory cookbook demos.
|
||||
|
||||
This module provides reusable functions for running conversation loops
|
||||
with Claude, handling tool execution, and managing context.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from anthropic import Anthropic
|
||||
from memory_tool import MemoryToolHandler
|
||||
|
||||
|
||||
def execute_tool(tool_use: Any, memory_handler: MemoryToolHandler) -> str:
    """
    Execute a tool use and return the result.

    Args:
        tool_use: The tool use object from Claude's response
        memory_handler: The memory tool handler instance

    Returns:
        str: The result of the tool execution
    """
    # Only the memory tool is supported; anything else is reported back.
    if tool_use.name != "memory":
        return f"Unknown tool: {tool_use.name}"
    outcome = memory_handler.execute(**tool_use.input)
    return outcome.get("success") or outcome.get("error", "Unknown error")
|
||||
|
||||
|
||||
def run_conversation_turn(
    client: Anthropic,
    model: str,
    messages: list[dict[str, Any]],
    memory_handler: MemoryToolHandler,
    system: str,
    context_management: dict[str, Any] | None = None,
    max_tokens: int = 1024,
    verbose: bool = False
) -> tuple[Any, list[dict[str, Any]], list[dict[str, Any]]]:
    """
    Run a single conversation turn, handling tool uses.

    Args:
        client: Anthropic client instance
        model: Model to use
        messages: Current conversation messages
        memory_handler: Memory tool handler instance
        system: System prompt
        context_management: Optional context management config
        max_tokens: Max tokens for response
        verbose: Whether to print tool operations

    Returns:
        Tuple of (response, assistant_content, tool_results). A non-empty
        tool_results means the caller should feed them back as a user
        message and run another turn.
    """
    memory_tool: dict[str, Any] = {"type": "memory_20250818", "name": "memory"}

    request_params: dict[str, Any] = {
        "model": model,
        "max_tokens": max_tokens,
        "system": system,
        "messages": messages,
        "tools": [memory_tool],
        "betas": ["context-management-2025-06-27"]
    }

    # Context management config rides along via extra_body on the beta API.
    if context_management:
        request_params["extra_body"] = {"context_management": context_management}

    response = client.beta.messages.create(**request_params)

    assistant_content = []
    tool_results = []

    for content in response.content:
        if content.type == "text":
            if verbose:
                print(f"💬 Claude: {content.text}\n")
            assistant_content.append({"type": "text", "text": content.text})
        elif content.type == "tool_use":
            if verbose:
                cmd = content.input.get('command')
                path = content.input.get('path', '')
                print(f" 🔧 Memory tool: {cmd} {path}")

            # Run the requested memory operation locally.
            result = execute_tool(content, memory_handler)

            if verbose:
                # Truncate long results so verbose output stays readable.
                result_preview = result[:80] + "..." if len(result) > 80 else result
                print(f" ✓ Result: {result_preview}")

            # Mirror the tool_use block into the assistant message, and queue
            # the matching tool_result (linked by tool_use_id) for the caller.
            assistant_content.append({
                "type": "tool_use",
                "id": content.id,
                "name": content.name,
                "input": content.input
            })
            tool_results.append({
                "type": "tool_result",
                "tool_use_id": content.id,
                "content": result
            })

    return response, assistant_content, tool_results
|
||||
|
||||
|
||||
def run_conversation_loop(
    client: Anthropic,
    model: str,
    messages: list[dict[str, Any]],
    memory_handler: MemoryToolHandler,
    system: str,
    context_management: dict[str, Any] | None = None,
    max_tokens: int = 1024,
    max_turns: int = 5,
    verbose: bool = False
) -> Any:
    """
    Run a complete conversation loop until Claude stops using tools.

    Args:
        client: Anthropic client instance
        model: Model to use
        messages: Current conversation messages (will be modified in-place)
        memory_handler: Memory tool handler instance
        system: System prompt
        context_management: Optional context management config
        max_tokens: Max tokens for response
        max_turns: Maximum number of turns to prevent infinite loops
        verbose: Whether to print progress

    Returns:
        The final API response (None if max_turns is less than 1).
    """
    turn = 1
    response = None

    while turn <= max_turns:
        if verbose:
            print(f"\n🔄 Turn {turn}:")

        response, assistant_content, tool_results = run_conversation_turn(
            client=client,
            model=model,
            messages=messages,
            memory_handler=memory_handler,
            system=system,
            context_management=context_management,
            max_tokens=max_tokens,
            verbose=verbose
        )

        # Record the assistant turn before deciding whether to continue.
        messages.append({"role": "assistant", "content": assistant_content})

        if tool_results:
            # Tool results go back as a user message, prompting another turn.
            messages.append({"role": "user", "content": tool_results})
            turn += 1
        else:
            # No more tool uses, conversation complete
            break

    return response
|
||||
|
||||
|
||||
def print_context_management_info(response: Any) -> tuple[bool, int]:
    """
    Print context management information from response.

    Args:
        response: API response to analyze

    Returns:
        Tuple of (context_cleared, saved_tokens)
    """
    # Guard clauses: no context_management attribute (or a falsy one) means
    # nothing was applied.
    management = getattr(response, "context_management", None)
    if not management:
        print(f" ℹ️ No context management applied")
        return False, 0

    applied = management.get("applied_edits", [])
    if not applied:
        print(f" ℹ️ Context below threshold - no clearing triggered")
        return False, 0

    # Report the first applied edit, matching the original behavior.
    first_edit = applied[0]
    cleared_uses = first_edit.get('cleared_tool_uses', 0)
    saved_tokens = first_edit.get('cleared_input_tokens', 0)
    print(f" ✂️ Context editing triggered!")
    print(f" • Cleared {cleared_uses} tool uses")
    print(f" • Saved {saved_tokens:,} tokens")
    print(f" • After clearing: {response.usage.input_tokens:,} tokens")
    return True, saved_tokens
|
||||
99
tool_use/memory_demo/sample_code/api_client_v1.py
Normal file
99
tool_use/memory_demo/sample_code/api_client_v1.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""
|
||||
Async API client with similar concurrency issues.
|
||||
This demonstrates Claude applying thread-safety patterns to async code.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
import aiohttp
|
||||
|
||||
|
||||
class AsyncAPIClient:
    """Async API client for fetching data from multiple endpoints.

    The concurrency bugs below are INTENTIONAL — this file is a teaching
    sample the demo asks Claude to review.
    """

    def __init__(self, base_url: str):
        self.base_url = base_url
        self.responses = []  # BUG: Shared state accessed from multiple coroutines!
        self.error_count = 0  # BUG: Race condition on counter increment!

    # NOTE(review): the 'any' in these annotations is the builtin function,
    # not typing.Any — harmless at runtime but misleading to type checkers.
    async def fetch_endpoint(
        self, session: aiohttp.ClientSession, endpoint: str
    ) -> Dict[str, any]:
        """Fetch a single endpoint; errors are returned, not raised."""
        url = f"{self.base_url}/{endpoint}"
        try:
            async with session.get(
                url, timeout=aiohttp.ClientTimeout(total=5)
            ) as response:
                data = await response.json()
                return {
                    "endpoint": endpoint,
                    "status": response.status,
                    "data": data,
                }
        except Exception as e:
            # Any failure (timeout, connection, bad JSON) becomes an error dict.
            return {
                "endpoint": endpoint,
                "error": str(e),
            }

    async def fetch_all(self, endpoints: List[str]) -> List[Dict[str, any]]:
        """
        Fetch multiple endpoints concurrently.

        BUG: Similar to the threading issue, multiple coroutines
        modify self.responses and self.error_count without coordination!
        While Python's GIL prevents some race conditions in threads,
        async code can still have interleaving issues.
        """
        async with aiohttp.ClientSession() as session:
            tasks = [self.fetch_endpoint(session, endpoint) for endpoint in endpoints]

            # Results are consumed in completion order, not submission order.
            for coro in asyncio.as_completed(tasks):
                result = await coro

                # RACE CONDITION: Multiple coroutines modify shared state
                if "error" in result:
                    self.error_count += 1  # Not atomic!
                else:
                    self.responses.append(result)  # Not thread-safe in async context!

        return self.responses

    def get_summary(self) -> Dict[str, any]:
        """Get summary statistics over all calls made so far."""
        return {
            "total_responses": len(self.responses),
            "errors": self.error_count,
            "success_rate": (
                len(self.responses) / (len(self.responses) + self.error_count)
                if (len(self.responses) + self.error_count) > 0
                else 0
            ),
        }
|
||||
|
||||
|
||||
async def main():
    """Test the async API client against a public mock API."""
    client = AsyncAPIClient("https://jsonplaceholder.typicode.com")

    # Six endpoints repeated 20x = 120 requests; one endpoint is deliberately
    # invalid so the error path is exercised too.
    endpoints = [
        "posts/1",
        "posts/2",
        "posts/3",
        "users/1",
        "users/2",
        "invalid/endpoint",  # Will error
    ] * 20  # 120 requests total

    results = await client.fetch_all(endpoints)

    print(f"Expected: ~100 successful responses")
    print(f"Got: {len(results)} responses")
    print(f"Summary: {client.get_summary()}")
    print("\nNote: Counts may be incorrect due to race conditions!")


if __name__ == "__main__":
    asyncio.run(main())
|
||||
115
tool_use/memory_demo/sample_code/cache_manager.py
Normal file
115
tool_use/memory_demo/sample_code/cache_manager.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""
|
||||
Cache manager with mutable default argument bug.
|
||||
This is one of Python's most common gotchas.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
|
||||
class CacheManager:
    """Manage cached data with TTL support.

    The mutable-default bugs below are INTENTIONAL — this file is a teaching
    sample; each buggy method is paired with a *_fixed variant.
    """

    def __init__(self):
        # key -> list of entries; written by add_items / add_items_fixed.
        self.cache = {}

    def add_items(
        self, key: str, items: List[str] = []  # BUG: Mutable default argument!
    ) -> None:
        """
        Add items to cache.

        BUG: Using [] as default creates a SHARED list across all calls!
        This is one of Python's classic gotchas.
        """
        # The items list is shared across ALL calls that don't provide items
        items.append(f"Added at {datetime.now()}")
        self.cache[key] = items

    def add_items_fixed(self, key: str, items: Optional[List[str]] = None) -> None:
        """Add items with proper default handling."""
        if items is None:
            items = []
        items = items.copy()  # Also make a copy to avoid mutation
        items.append(f"Added at {datetime.now()}")
        self.cache[key] = items

    # NOTE(review): lowercase 'any' in these annotations is the builtin
    # function, not typing.Any — works at runtime but misleads type checkers.
    def merge_configs(
        self, name: str, overrides: Dict[str, any] = {}  # BUG: Mutable default!
    ) -> Dict[str, any]:
        """
        Merge configuration with overrides.

        BUG: The default dict is shared across all calls!
        """
        defaults = {"timeout": 30, "retries": 3, "cache_enabled": True}

        # This modifies the SHARED overrides dict
        overrides.update(defaults)
        return overrides

    def merge_configs_fixed(
        self, name: str, overrides: Optional[Dict[str, any]] = None
    ) -> Dict[str, any]:
        """Merge configs properly."""
        if overrides is None:
            overrides = {}

        defaults = {"timeout": 30, "retries": 3, "cache_enabled": True}

        # Create new dict to avoid mutation
        result = {**defaults, **overrides}
        return result
|
||||
|
||||
|
||||
class DataProcessor:
    """Another example of the mutable default bug (intentional, for the demo)."""

    def process_batch(
        self, data: List[int], filters: List[str] = []  # BUG: Mutable default!
    ) -> List[int]:
        """
        Process data with optional filters.

        BUG (intentional): the default filters list is one shared object, so
        every call that omits *filters* mutates the same list.
        """
        filters.append("default_filter")  # Modifies shared list!

        # Hoisted membership check — filters is not mutated below this point.
        skip_negatives = "positive" in filters
        doubled = []
        for value in data:
            if skip_negatives and value < 0:
                continue
            doubled.append(value * 2)
        return doubled
|
||||
|
||||
|
||||
if __name__ == "__main__":
    cache = CacheManager()

    # Demonstrate the bug: each add_items call without an explicit list
    # appends to (and stores) the SAME default list object.
    print("=== Demonstrating Mutable Default Argument Bug ===\n")

    # First call with no items
    cache.add_items("key1")
    print(f"key1: {cache.cache['key1']}")

    # Second call with no items - SURPRISE! Gets the same list
    cache.add_items("key2")
    print(f"key2: {cache.cache['key2']}")  # Will have TWO timestamps!

    # Third call - even worse!
    cache.add_items("key3")
    print(f"key3: {cache.cache['key3']}")  # Will have THREE timestamps!

    print("\nAll keys share the same list object!")
    print(f"key1 is key2: {cache.cache['key1'] is cache.cache['key2']}")

    # The fixed variant copies its input, so every key gets its own list.
    print("\n=== Using Fixed Version ===\n")
    cache2 = CacheManager()
    cache2.add_items_fixed("key1")
    cache2.add_items_fixed("key2")
    cache2.add_items_fixed("key3")
    print(f"key1: {cache2.cache['key1']}")
    print(f"key2: {cache2.cache['key2']}")
    print(f"key3: {cache2.cache['key3']}")
    print(f"\nkey1 is key2: {cache2.cache['key1'] is cache2.cache['key2']}")
|
||||
145
tool_use/memory_demo/sample_code/data_processor_v1.py
Normal file
145
tool_use/memory_demo/sample_code/data_processor_v1.py
Normal file
@@ -0,0 +1,145 @@
|
||||
"""
|
||||
Data processor with multiple concurrency and thread-safety issues.
|
||||
Used for Session 3 to demonstrate context editing with multiple bugs.
|
||||
"""
|
||||
|
||||
import json
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any
|
||||
|
||||
|
||||
class DataProcessor:
    """Process data files concurrently with various thread-safety issues.

    All races below are INTENTIONAL — this file is a teaching sample the
    demo asks Claude to review.
    """

    def __init__(self, max_workers: int = 5):
        self.max_workers = max_workers  # thread pool size for process_batch
        self.processed_count = 0  # BUG: Race condition on counter
        self.results = []  # BUG: Shared list without locking
        self.errors = {}  # BUG: Shared dict without locking
        self.lock = threading.Lock()  # Available but not used!

    def process_file(self, file_path: str) -> Dict[str, Any]:
        """Process a single file.

        Returns a summary dict on success, or {"file": ..., "error": ...}
        when reading/parsing fails (never raises).
        """
        try:
            with open(file_path, "r") as f:
                data = json.load(f)

            # Simulate some processing
            processed = {
                "file": file_path,
                "record_count": len(data) if isinstance(data, list) else 1,
                "size_bytes": Path(file_path).stat().st_size,
            }

            return processed

        except Exception as e:
            return {"file": file_path, "error": str(e)}

    def process_batch(self, file_paths: List[str]) -> List[Dict[str, Any]]:
        """
        Process multiple files concurrently.

        MULTIPLE BUGS (intentional):
        1. self.processed_count is incremented without locking
        2. self.results is appended to from multiple threads
        3. self.errors is modified from multiple threads
        4. We have a lock but don't use it!
        """
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            futures = [executor.submit(self.process_file, fp) for fp in file_paths]

            # Collect in submission order; blocks on each future in turn.
            for future in futures:
                result = future.result()

                # RACE CONDITION: Increment counter without lock
                self.processed_count += 1  # BUG!

                if "error" in result:
                    # RACE CONDITION: Modify dict without lock
                    self.errors[result["file"]] = result["error"]  # BUG!
                else:
                    # RACE CONDITION: Append to list without lock
                    self.results.append(result)  # BUG!

        return self.results

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get processing statistics.

        BUG: Accessing shared state without ensuring thread-safety.
        If called while processing, could get inconsistent values.
        """
        total_files = self.processed_count
        successful = len(self.results)
        failed = len(self.errors)

        # BUG: These counts might not add up correctly due to race conditions
        return {
            "total_processed": total_files,
            "successful": successful,
            "failed": failed,
            "success_rate": successful / total_files if total_files > 0 else 0,
        }

    def reset(self):
        """
        Reset processor state.

        BUG: No locking - if called during processing, causes corruption.
        """
        self.processed_count = 0  # RACE CONDITION
        self.results = []  # RACE CONDITION
        self.errors = {}  # RACE CONDITION
|
||||
|
||||
|
||||
class SharedCache:
    """
    A shared cache with thread-safety issues.

    BUG (intentional, for the demo): classic read-modify-write races — the
    dict and both counters are updated without any lock.
    """

    def __init__(self):
        self.cache = {}  # BUG: Shared dict without locking
        self.hit_count = 0  # BUG: Race condition
        self.miss_count = 0  # BUG: Race condition

    def get(self, key: str) -> Any:
        """Get from cache - RACE CONDITION on hit/miss counts."""
        # Sentinel lookup distinguishes "missing" from a stored None.
        absent = object()
        value = self.cache.get(key, absent)
        if value is absent:
            self.miss_count += 1  # BUG: Not atomic!
            return None
        self.hit_count += 1  # BUG: Not atomic!
        return value

    def set(self, key: str, value: Any):
        """Set in cache - RACE CONDITION on dict modification."""
        self.cache[key] = value  # BUG: Dict access not synchronized!

    def get_stats(self) -> Dict[str, Any]:
        """Get cache statistics - may be inconsistent."""
        lookups = self.hit_count + self.miss_count
        hit_rate = self.hit_count / lookups if lookups > 0 else 0
        return {
            "hits": self.hit_count,
            "misses": self.miss_count,
            "hit_rate": hit_rate,
        }
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Create some test files (not included)
    processor = DataProcessor(max_workers=10)

    # Simulate processing many files. The paths are not expected to exist,
    # so reads fail and land in processor.errors — the point is exercising
    # the thread pool, not the parsing.
    file_paths = [f"data/file_{i}.json" for i in range(100)]

    print("Processing files concurrently...")
    results = processor.process_batch(file_paths)

    print(f"\nStatistics: {processor.get_statistics()}")
    print("\nNote: Counts may be inconsistent due to race conditions!")
|
||||
105
tool_use/memory_demo/sample_code/sql_query_builder.py
Normal file
105
tool_use/memory_demo/sample_code/sql_query_builder.py
Normal file
@@ -0,0 +1,105 @@
|
||||
"""
|
||||
SQL query builder with SQL injection vulnerability.
|
||||
Demonstrates dangerous string formatting in SQL queries.
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
class UserDatabase:
    """Mock database interface: echoes queries instead of running them."""

    def execute(self, query: str) -> List[dict]:
        """Print the query so it can be inspected, then return no rows."""
        print(f"Executing: {query}")
        return []
|
||||
|
||||
|
||||
class QueryBuilder:
    """Build SQL queries for user operations.

    NOTE: Most methods here are intentionally vulnerable to SQL injection —
    they interpolate raw user input straight into query text. They exist as
    demo material for the review exercise; do not copy them into real code.
    """

    def __init__(self, db: UserDatabase):
        self.db = db

    def get_user_by_name(self, username: str) -> Optional[dict]:
        """
        Get user by username.

        BUG: SQL INJECTION VULNERABILITY!
        User input is spliced directly into the SQL text.
        """
        # DANGEROUS: raw string interpolation of user input.
        query = f"SELECT * FROM users WHERE username = '{username}'"
        rows = self.db.execute(query)
        if rows:
            return rows[0]
        return None

    def get_user_by_name_safe(self, username: str) -> Optional[dict]:
        """Safe counterpart using a parameterized query (conceptually)."""
        # Placeholders keep user input out of the SQL text entirely.
        query = "SELECT * FROM users WHERE username = ?"
        # In real code: self.db.execute(query, (username,))
        print(f"Safe query with parameter: {query}, params: ({username},)")
        return None

    def search_users(self, search_term: str, limit: int = 10) -> List[dict]:
        """
        Search users by term.

        BUG: SQL INJECTION through LIKE clause!
        """
        # DANGEROUS: the search term lands unescaped inside the pattern.
        query = f"SELECT * FROM users WHERE name LIKE '%{search_term}%' LIMIT {limit}"
        return self.db.execute(query)

    def delete_user(self, user_id: str) -> bool:
        """
        Delete a user.

        BUG: SQL INJECTION in DELETE statement — especially dangerous
        because an injected tautology can wipe the whole table.
        """
        # DANGEROUS: unvalidated user input inside a destructive statement.
        query = f"DELETE FROM users WHERE id = {user_id}"
        self.db.execute(query)
        return True

    def get_users_by_role(self, role: str, order_by: str = "name") -> List[dict]:
        """
        Get users by role.

        BUG: SQL INJECTION in ORDER BY clause — even sort columns
        can be exploited when user-controlled.
        """
        # DANGEROUS: caller controls the ORDER BY expression verbatim.
        query = f"SELECT * FROM users WHERE role = '{role}' ORDER BY {order_by}"
        return self.db.execute(query)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    database = UserDatabase()
    builder = QueryBuilder(database)

    print("=== Demonstrating SQL Injection Vulnerabilities ===\n")

    # 1) Classic tautology in the WHERE clause — matches every row.
    print("1. Basic username injection:")
    builder.get_user_by_name("admin' OR '1'='1")
    # Executes: SELECT * FROM users WHERE username = 'admin' OR '1'='1'

    # 2) Breaking out of the LIKE pattern to return everything.
    print("\n2. Search term injection:")
    builder.search_users("test%' OR 1=1--")

    # 3) Tautology in DELETE — would wipe the entire table.
    print("\n3. DELETE injection:")
    builder.delete_user("1 OR 1=1")
    # Executes: DELETE FROM users WHERE id = 1 OR 1=1

    # 4) Arbitrary SQL smuggled through a user-controlled ORDER BY.
    print("\n4. ORDER BY injection:")
    builder.get_users_by_role("admin", "name; DROP TABLE users--")

    # The parameterized version is immune to the same payload.
    print("\n=== Safe Version ===")
    builder.get_user_by_name_safe("admin' OR '1'='1")
|
||||
84
tool_use/memory_demo/sample_code/web_scraper_v1.py
Normal file
84
tool_use/memory_demo/sample_code/web_scraper_v1.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""
|
||||
Concurrent web scraper with a race condition bug.
|
||||
Multiple threads modify shared state without synchronization.
|
||||
"""
|
||||
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import List, Dict
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
class WebScraper:
    """Fetch many URLs concurrently with a worker pool.

    NOTE: This class intentionally demonstrates race conditions — the
    `results` and `failed_urls` lists are mutated from several worker
    threads without any locking.
    """

    def __init__(self, max_workers: int = 10):
        self.max_workers = max_workers
        # BUG: both lists below are shared mutable state touched by
        # multiple threads with no synchronization.
        self.results = []
        self.failed_urls = []

    def fetch_url(self, url: str) -> Dict[str, any]:
        """Fetch one URL; return a status dict, or an error dict on failure."""
        try:
            resp = requests.get(url, timeout=5)
            resp.raise_for_status()
        except requests.exceptions.RequestException as exc:
            return {"url": url, "error": str(exc)}
        return {
            "url": url,
            "status": resp.status_code,
            "content_length": len(resp.content),
        }

    def scrape_urls(self, urls: List[str]) -> List[Dict[str, any]]:
        """
        Scrape multiple URLs concurrently and return the successes.

        BUG: worker completions append to the shared lists without a lock,
        so entries can be lost or corrupted under contention.
        """
        with ThreadPoolExecutor(max_workers=self.max_workers) as pool:
            pending = [pool.submit(self.fetch_url, target) for target in urls]

            for done in as_completed(pending):
                outcome = done.result()

                if "error" in outcome:
                    self.failed_urls.append(outcome["url"])  # RACE CONDITION
                else:
                    self.results.append(outcome)  # RACE CONDITION

        return self.results

    def get_stats(self) -> Dict[str, int]:
        """Summarize counts; values may be inconsistent while threads run."""
        attempted = len(self.results) + len(self.failed_urls)
        return {
            "total_results": len(self.results),
            "failed_urls": len(self.failed_urls),
            "success_rate": (
                len(self.results) / attempted if attempted > 0 else 0
            ),
        }
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # 50 URLs (5 endpoints x 10) to raise the odds the race condition bites.
    targets = [
        "https://httpbin.org/delay/1",
        "https://httpbin.org/status/200",
        "https://httpbin.org/status/404",
        "https://httpbin.org/delay/2",
        "https://httpbin.org/status/500",
    ] * 10

    scraper = WebScraper(max_workers=10)
    fetched = scraper.scrape_urls(targets)

    print("Expected: 50 results")
    print(f"Got: {len(fetched)} results")
    print(f"Stats: {scraper.get_stats()}")
    print("\nNote: Results count may be less than expected due to race condition!")
|
||||
351
tool_use/memory_tool.py
Normal file
351
tool_use/memory_tool.py
Normal file
@@ -0,0 +1,351 @@
|
||||
"""
|
||||
Production-ready memory tool handler for Claude's memory_20250818 tool.
|
||||
|
||||
This implementation provides secure, client-side execution of memory operations
|
||||
with path validation, error handling, and comprehensive security measures.
|
||||
"""
|
||||
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
class MemoryToolHandler:
    """
    Handles execution of Claude's memory tool commands.

    The memory tool enables Claude to read, write, and manage files in a memory
    system through a standardized tool interface. This handler provides
    client-side implementation with security controls: every path is validated
    and confined to the /memories directory under ``base_path``.

    Attributes:
        base_path: Root directory for memory storage
        memory_root: The /memories directory within base_path
    """

    def __init__(self, base_path: str = "./memory_storage"):
        """
        Initialize the memory tool handler.

        Args:
            base_path: Root directory for all memory operations
        """
        self.base_path = Path(base_path).resolve()
        self.memory_root = self.base_path / "memories"
        # Make sure the storage root exists before any command runs.
        self.memory_root.mkdir(parents=True, exist_ok=True)

    def _validate_path(self, path: str) -> Path:
        """
        Validate and resolve memory paths to prevent directory traversal attacks.

        Args:
            path: The path to validate (must start with /memories)

        Returns:
            Resolved absolute Path object within memory_root

        Raises:
            ValueError: If path is invalid or attempts to escape memory directory
        """
        if not path.startswith("/memories"):
            raise ValueError(
                f"Path must start with /memories, got: {path}. "
                "All memory operations must be confined to the /memories directory."
            )

        # Remove /memories prefix and any leading slashes
        relative_path = path[len("/memories") :].lstrip("/")

        # Resolve to absolute path within memory_root
        if relative_path:
            full_path = (self.memory_root / relative_path).resolve()
        else:
            full_path = self.memory_root.resolve()

        # Verify the resolved path is still within memory_root.
        # resolve() collapses any ".." segments, so escapes surface here.
        try:
            full_path.relative_to(self.memory_root.resolve())
        except ValueError as e:
            raise ValueError(
                f"Path '{path}' would escape /memories directory. "
                "Directory traversal attempts are not allowed."
            ) from e

        return full_path

    def execute(self, **params: Any) -> dict[str, str]:
        """
        Execute a memory tool command.

        Args:
            **params: Command parameters from Claude's tool use

        Returns:
            Dict with either 'success' or 'error' key

        Supported commands:
            - view: Show directory contents or file contents
            - create: Create or overwrite a file
            - str_replace: Replace text in a file
            - insert: Insert text at a specific line
            - delete: Delete a file or directory
            - rename: Rename or move a file/directory
        """
        command = params.get("command")

        try:
            if command == "view":
                return self._view(params)
            elif command == "create":
                return self._create(params)
            elif command == "str_replace":
                return self._str_replace(params)
            elif command == "insert":
                return self._insert(params)
            elif command == "delete":
                return self._delete(params)
            elif command == "rename":
                return self._rename(params)
            else:
                return {
                    "error": f"Unknown command: '{command}'. "
                    "Valid commands are: view, create, str_replace, insert, delete, rename"
                }
        except ValueError as e:
            # Path-validation failures are reported as tool errors, not raised.
            return {"error": str(e)}
        except Exception as e:
            return {"error": f"Unexpected error executing {command}: {e}"}

    def _view(self, params: dict[str, Any]) -> dict[str, str]:
        """View directory contents or file contents (optionally a line range)."""
        path = params.get("path")
        view_range = params.get("view_range")

        if not path:
            return {"error": "Missing required parameter: path"}

        full_path = self._validate_path(path)

        # Handle directory listing
        if full_path.is_dir():
            try:
                items = []
                for item in sorted(full_path.iterdir()):
                    # Hidden files/directories are omitted from listings.
                    if item.name.startswith("."):
                        continue
                    items.append(f"{item.name}/" if item.is_dir() else item.name)

                if not items:
                    return {"success": f"Directory: {path}\n(empty)"}

                return {
                    "success": f"Directory: {path}\n"
                    + "\n".join([f"- {item}" for item in items])
                }
            except Exception as e:
                return {"error": f"Cannot read directory {path}: {e}"}

        # Handle file reading
        elif full_path.is_file():
            try:
                content = full_path.read_text(encoding="utf-8")
                lines = content.splitlines()

                # Apply view range if specified; end of -1 means "to EOF".
                if view_range:
                    start_line = max(1, view_range[0]) - 1  # convert to 0-indexed
                    end_line = len(lines) if view_range[1] == -1 else view_range[1]
                    lines = lines[start_line:end_line]
                    start_num = start_line + 1
                else:
                    start_num = 1

                # Format with 1-based line numbers, right-aligned to 4 columns.
                numbered_lines = [
                    f"{i + start_num:4d}: {line}" for i, line in enumerate(lines)
                ]
                return {"success": "\n".join(numbered_lines)}

            except UnicodeDecodeError:
                return {"error": f"Cannot read {path}: File is not valid UTF-8 text"}
            except Exception as e:
                return {"error": f"Cannot read file {path}: {e}"}

        else:
            return {"error": f"Path not found: {path}"}

    def _create(self, params: dict[str, Any]) -> dict[str, str]:
        """Create or overwrite a file (parent directories made as needed)."""
        path = params.get("path")
        file_text = params.get("file_text", "")

        if not path:
            return {"error": "Missing required parameter: path"}

        full_path = self._validate_path(path)

        # Don't allow creating directories directly; only known text extensions.
        if not path.endswith((".txt", ".md", ".json", ".py", ".yaml", ".yml")):
            return {
                "error": f"Cannot create {path}: Only text files are supported. "
                "Use file extensions: .txt, .md, .json, .py, .yaml, .yml"
            }

        try:
            # Create parent directories if needed
            full_path.parent.mkdir(parents=True, exist_ok=True)

            # Write the file (overwrites any existing content)
            full_path.write_text(file_text, encoding="utf-8")
            return {"success": f"File created successfully at {path}"}

        except Exception as e:
            return {"error": f"Cannot create file {path}: {e}"}

    def _str_replace(self, params: dict[str, Any]) -> dict[str, str]:
        """Replace a unique occurrence of text in a file."""
        path = params.get("path")
        old_str = params.get("old_str")
        new_str = params.get("new_str", "")

        # old_str may legitimately be falsy-but-present, so test against None.
        if not path or old_str is None:
            return {"error": "Missing required parameters: path, old_str"}

        full_path = self._validate_path(path)

        if not full_path.is_file():
            return {"error": f"File not found: {path}"}

        try:
            content = full_path.read_text(encoding="utf-8")

            # old_str must occur exactly once so the edit is unambiguous.
            count = content.count(old_str)
            if count == 0:
                return {
                    "error": f"String not found in {path}. "
                    "The exact text must exist in the file."
                }
            elif count > 1:
                return {
                    "error": f"String appears {count} times in {path}. "
                    "The string must be unique. Use more specific context."
                }

            # Perform replacement (count=1 guards against races on disk)
            new_content = content.replace(old_str, new_str, 1)
            full_path.write_text(new_content, encoding="utf-8")

            return {"success": f"File {path} has been edited successfully"}

        except Exception as e:
            return {"error": f"Cannot edit file {path}: {e}"}

    def _insert(self, params: dict[str, Any]) -> dict[str, str]:
        """Insert text before the given 0-based line index of a file."""
        path = params.get("path")
        insert_line = params.get("insert_line")
        insert_text = params.get("insert_text", "")

        if not path or insert_line is None:
            return {"error": "Missing required parameters: path, insert_line"}

        full_path = self._validate_path(path)

        if not full_path.is_file():
            return {"error": f"File not found: {path}"}

        try:
            lines = full_path.read_text(encoding="utf-8").splitlines()

            # Validate insert_line: 0 inserts at the top, len(lines) appends.
            if insert_line < 0 or insert_line > len(lines):
                return {
                    "error": f"Invalid insert_line {insert_line}. "
                    f"Must be between 0 and {len(lines)}"
                }

            # Insert the text (trailing newlines stripped; we re-join below)
            lines.insert(insert_line, insert_text.rstrip("\n"))

            # Write back with a trailing newline
            full_path.write_text("\n".join(lines) + "\n", encoding="utf-8")

            return {"success": f"Text inserted at line {insert_line} in {path}"}

        except Exception as e:
            return {"error": f"Cannot insert into {path}: {e}"}

    def _delete(self, params: dict[str, Any]) -> dict[str, str]:
        """Delete a file or directory (the /memories root itself is protected)."""
        path = params.get("path")

        if not path:
            return {"error": "Missing required parameter: path"}

        # Prevent deletion of root memories directory
        if path == "/memories":
            return {"error": "Cannot delete the /memories directory itself"}

        full_path = self._validate_path(path)

        if not full_path.exists():
            return {"error": f"Path not found: {path}"}

        try:
            if full_path.is_file():
                full_path.unlink()
                return {"success": f"File deleted: {path}"}
            elif full_path.is_dir():
                shutil.rmtree(full_path)
                return {"success": f"Directory deleted: {path}"}
            else:
                # BUG FIX: previously this fell off the end and returned None
                # for paths that exist but are neither regular files nor
                # directories (e.g. FIFOs/sockets); always return a result dict.
                return {"error": f"Cannot delete {path}: not a file or directory"}

        except Exception as e:
            return {"error": f"Cannot delete {path}: {e}"}

    def _rename(self, params: dict[str, Any]) -> dict[str, str]:
        """Rename or move a file/directory (never overwrites the destination)."""
        old_path = params.get("old_path")
        new_path = params.get("new_path")

        if not old_path or not new_path:
            return {"error": "Missing required parameters: old_path, new_path"}

        old_full_path = self._validate_path(old_path)
        new_full_path = self._validate_path(new_path)

        if not old_full_path.exists():
            return {"error": f"Source path not found: {old_path}"}

        if new_full_path.exists():
            return {
                "error": f"Destination already exists: {new_path}. "
                "Cannot overwrite existing files/directories."
            }

        try:
            # Create parent directories if needed
            new_full_path.parent.mkdir(parents=True, exist_ok=True)

            # Perform rename/move
            old_full_path.rename(new_full_path)

            return {"success": f"Renamed {old_path} to {new_path}"}

        except Exception as e:
            return {"error": f"Cannot rename {old_path} to {new_path}: {e}"}

    def clear_all_memory(self) -> dict[str, str]:
        """
        Clear all memory files (useful for testing or starting fresh).

        Returns:
            Dict with success message (or error on failure)
        """
        try:
            # Remove the whole tree, then recreate an empty /memories root.
            if self.memory_root.exists():
                shutil.rmtree(self.memory_root)
            self.memory_root.mkdir(parents=True, exist_ok=True)
            return {"success": "All memory cleared successfully"}
        except Exception as e:
            return {"error": f"Cannot clear memory: {e}"}
|
||||
3
tool_use/requirements.txt
Normal file
3
tool_use/requirements.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
anthropic>=0.18.0
|
||||
python-dotenv>=1.0.0
|
||||
ipykernel>=6.29.0 # For Jupyter in VSCode
|
||||
426
tool_use/tests/test_memory_tool.py
Normal file
426
tool_use/tests/test_memory_tool.py
Normal file
@@ -0,0 +1,426 @@
|
||||
"""
|
||||
Unit tests for the memory tool handler.
|
||||
|
||||
Tests security validation, command execution, and error handling.
|
||||
"""
|
||||
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from memory_tool import MemoryToolHandler
|
||||
|
||||
|
||||
class TestMemoryToolHandler(unittest.TestCase):
|
||||
"""Test suite for MemoryToolHandler."""
|
||||
|
||||
def setUp(self):
|
||||
"""Create temporary directory for each test."""
|
||||
self.test_dir = tempfile.mkdtemp()
|
||||
self.handler = MemoryToolHandler(base_path=self.test_dir)
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up temporary directory after each test."""
|
||||
shutil.rmtree(self.test_dir)
|
||||
|
||||
# Security Tests
|
||||
|
||||
def test_path_validation_requires_memories_prefix(self):
|
||||
"""Test that paths must start with /memories."""
|
||||
result = self.handler.execute(command="view", path="/etc/passwd")
|
||||
self.assertIn("error", result)
|
||||
self.assertIn("must start with /memories", result["error"])
|
||||
|
||||
def test_path_validation_prevents_traversal_dotdot(self):
|
||||
"""Test that .. traversal is blocked."""
|
||||
result = self.handler.execute(
|
||||
command="view", path="/memories/../../../etc/passwd"
|
||||
)
|
||||
self.assertIn("error", result)
|
||||
self.assertIn("escape", result["error"].lower())
|
||||
|
||||
def test_path_validation_prevents_traversal_encoded(self):
|
||||
"""Test that URL-encoded traversal is blocked."""
|
||||
result = self.handler.execute(
|
||||
command="view", path="/memories/%2e%2e/%2e%2e/etc/passwd"
|
||||
)
|
||||
# The path will be processed and should fail validation
|
||||
self.assertIn("error", result)
|
||||
|
||||
def test_path_validation_allows_valid_paths(self):
|
||||
"""Test that valid memory paths are accepted."""
|
||||
result = self.handler.execute(
|
||||
command="create", path="/memories/test.txt", file_text="test"
|
||||
)
|
||||
self.assertIn("success", result)
|
||||
|
||||
# View Command Tests
|
||||
|
||||
def test_view_empty_directory(self):
|
||||
"""Test viewing an empty /memories directory."""
|
||||
result = self.handler.execute(command="view", path="/memories")
|
||||
self.assertIn("success", result)
|
||||
self.assertIn("empty", result["success"].lower())
|
||||
|
||||
def test_view_directory_with_files(self):
|
||||
"""Test viewing a directory with files."""
|
||||
# Create some test files
|
||||
self.handler.execute(
|
||||
command="create", path="/memories/file1.txt", file_text="content1"
|
||||
)
|
||||
self.handler.execute(
|
||||
command="create", path="/memories/file2.txt", file_text="content2"
|
||||
)
|
||||
|
||||
result = self.handler.execute(command="view", path="/memories")
|
||||
self.assertIn("success", result)
|
||||
self.assertIn("file1.txt", result["success"])
|
||||
self.assertIn("file2.txt", result["success"])
|
||||
|
||||
def test_view_file_with_line_numbers(self):
|
||||
"""Test viewing a file with line numbers."""
|
||||
content = "line 1\nline 2\nline 3"
|
||||
self.handler.execute(
|
||||
command="create", path="/memories/test.txt", file_text=content
|
||||
)
|
||||
|
||||
result = self.handler.execute(command="view", path="/memories/test.txt")
|
||||
self.assertIn("success", result)
|
||||
self.assertIn(" 1: line 1", result["success"])
|
||||
self.assertIn(" 2: line 2", result["success"])
|
||||
self.assertIn(" 3: line 3", result["success"])
|
||||
|
||||
def test_view_file_with_range(self):
|
||||
"""Test viewing specific line range."""
|
||||
content = "line 1\nline 2\nline 3\nline 4"
|
||||
self.handler.execute(
|
||||
command="create", path="/memories/test.txt", file_text=content
|
||||
)
|
||||
|
||||
result = self.handler.execute(
|
||||
command="view", path="/memories/test.txt", view_range=[2, 3]
|
||||
)
|
||||
self.assertIn("success", result)
|
||||
self.assertIn(" 2: line 2", result["success"])
|
||||
self.assertIn(" 3: line 3", result["success"])
|
||||
self.assertNotIn("line 1", result["success"])
|
||||
self.assertNotIn("line 4", result["success"])
|
||||
|
||||
def test_view_nonexistent_path(self):
|
||||
"""Test viewing a nonexistent path."""
|
||||
result = self.handler.execute(command="view", path="/memories/notfound.txt")
|
||||
self.assertIn("error", result)
|
||||
self.assertIn("not found", result["error"].lower())
|
||||
|
||||
# Create Command Tests
|
||||
|
||||
def test_create_file(self):
|
||||
"""Test creating a file."""
|
||||
result = self.handler.execute(
|
||||
command="create", path="/memories/test.txt", file_text="Hello, World!"
|
||||
)
|
||||
self.assertIn("success", result)
|
||||
|
||||
# Verify file exists
|
||||
file_path = Path(self.test_dir) / "memories" / "test.txt"
|
||||
self.assertTrue(file_path.exists())
|
||||
self.assertEqual(file_path.read_text(), "Hello, World!")
|
||||
|
||||
def test_create_file_in_subdirectory(self):
|
||||
"""Test creating a file in a subdirectory (auto-creates dirs)."""
|
||||
result = self.handler.execute(
|
||||
command="create",
|
||||
path="/memories/subdir/test.txt",
|
||||
file_text="Nested content",
|
||||
)
|
||||
self.assertIn("success", result)
|
||||
|
||||
file_path = Path(self.test_dir) / "memories" / "subdir" / "test.txt"
|
||||
self.assertTrue(file_path.exists())
|
||||
|
||||
def test_create_requires_file_extension(self):
|
||||
"""Test that create only allows text file extensions."""
|
||||
result = self.handler.execute(
|
||||
command="create", path="/memories/noext", file_text="content"
|
||||
)
|
||||
self.assertIn("error", result)
|
||||
self.assertIn("text files are supported", result["error"])
|
||||
|
||||
def test_create_overwrites_existing_file(self):
|
||||
"""Test that create overwrites existing files."""
|
||||
self.handler.execute(
|
||||
command="create", path="/memories/test.txt", file_text="original"
|
||||
)
|
||||
result = self.handler.execute(
|
||||
command="create", path="/memories/test.txt", file_text="updated"
|
||||
)
|
||||
self.assertIn("success", result)
|
||||
|
||||
file_path = Path(self.test_dir) / "memories" / "test.txt"
|
||||
self.assertEqual(file_path.read_text(), "updated")
|
||||
|
||||
# String Replace Command Tests
|
||||
|
||||
def test_str_replace_success(self):
|
||||
"""Test successful string replacement."""
|
||||
self.handler.execute(
|
||||
command="create",
|
||||
path="/memories/test.txt",
|
||||
file_text="Hello World",
|
||||
)
|
||||
|
||||
result = self.handler.execute(
|
||||
command="str_replace",
|
||||
path="/memories/test.txt",
|
||||
old_str="World",
|
||||
new_str="Universe",
|
||||
)
|
||||
self.assertIn("success", result)
|
||||
|
||||
file_path = Path(self.test_dir) / "memories" / "test.txt"
|
||||
self.assertEqual(file_path.read_text(), "Hello Universe")
|
||||
|
||||
def test_str_replace_string_not_found(self):
|
||||
"""Test replacement when string doesn't exist."""
|
||||
self.handler.execute(
|
||||
command="create", path="/memories/test.txt", file_text="Hello World"
|
||||
)
|
||||
|
||||
result = self.handler.execute(
|
||||
command="str_replace",
|
||||
path="/memories/test.txt",
|
||||
old_str="Missing",
|
||||
new_str="Text",
|
||||
)
|
||||
self.assertIn("error", result)
|
||||
self.assertIn("not found", result["error"].lower())
|
||||
|
||||
def test_str_replace_multiple_occurrences(self):
|
||||
"""Test that replacement fails with multiple occurrences."""
|
||||
self.handler.execute(
|
||||
command="create",
|
||||
path="/memories/test.txt",
|
||||
file_text="Hello World Hello World",
|
||||
)
|
||||
|
||||
result = self.handler.execute(
|
||||
command="str_replace",
|
||||
path="/memories/test.txt",
|
||||
old_str="Hello",
|
||||
new_str="Hi",
|
||||
)
|
||||
self.assertIn("error", result)
|
||||
self.assertIn("appears 2 times", result["error"])
|
||||
|
||||
def test_str_replace_file_not_found(self):
|
||||
"""Test replacement on nonexistent file."""
|
||||
result = self.handler.execute(
|
||||
command="str_replace",
|
||||
path="/memories/notfound.txt",
|
||||
old_str="old",
|
||||
new_str="new",
|
||||
)
|
||||
self.assertIn("error", result)
|
||||
self.assertIn("not found", result["error"].lower())
|
||||
|
||||
# Insert Command Tests
|
||||
|
||||
def test_insert_at_beginning(self):
|
||||
"""Test inserting at line 0 (beginning)."""
|
||||
self.handler.execute(
|
||||
command="create", path="/memories/test.txt", file_text="line 1\nline 2"
|
||||
)
|
||||
|
||||
result = self.handler.execute(
|
||||
command="insert",
|
||||
path="/memories/test.txt",
|
||||
insert_line=0,
|
||||
insert_text="new line",
|
||||
)
|
||||
self.assertIn("success", result)
|
||||
|
||||
file_path = Path(self.test_dir) / "memories" / "test.txt"
|
||||
self.assertEqual(file_path.read_text(), "new line\nline 1\nline 2\n")
|
||||
|
||||
def test_insert_in_middle(self):
|
||||
"""Test inserting in the middle."""
|
||||
self.handler.execute(
|
||||
command="create", path="/memories/test.txt", file_text="line 1\nline 2"
|
||||
)
|
||||
|
||||
result = self.handler.execute(
|
||||
command="insert",
|
||||
path="/memories/test.txt",
|
||||
insert_line=1,
|
||||
insert_text="inserted",
|
||||
)
|
||||
self.assertIn("success", result)
|
||||
|
||||
file_path = Path(self.test_dir) / "memories" / "test.txt"
|
||||
self.assertEqual(file_path.read_text(), "line 1\ninserted\nline 2\n")
|
||||
|
||||
def test_insert_at_end(self):
|
||||
"""Test inserting at the end."""
|
||||
self.handler.execute(
|
||||
command="create", path="/memories/test.txt", file_text="line 1\nline 2"
|
||||
)
|
||||
|
||||
result = self.handler.execute(
|
||||
command="insert",
|
||||
path="/memories/test.txt",
|
||||
insert_line=2,
|
||||
insert_text="last line",
|
||||
)
|
||||
self.assertIn("success", result)
|
||||
|
||||
def test_insert_invalid_line(self):
|
||||
"""Test insert with invalid line number."""
|
||||
self.handler.execute(
|
||||
command="create", path="/memories/test.txt", file_text="line 1"
|
||||
)
|
||||
|
||||
result = self.handler.execute(
|
||||
command="insert",
|
||||
path="/memories/test.txt",
|
||||
insert_line=99,
|
||||
insert_text="text",
|
||||
)
|
||||
self.assertIn("error", result)
|
||||
self.assertIn("invalid", result["error"].lower())
|
||||
|
||||
# Delete Command Tests
|
||||
|
||||
def test_delete_file(self):
|
||||
"""Test deleting a file."""
|
||||
self.handler.execute(
|
||||
command="create", path="/memories/test.txt", file_text="content"
|
||||
)
|
||||
|
||||
result = self.handler.execute(command="delete", path="/memories/test.txt")
|
||||
self.assertIn("success", result)
|
||||
|
||||
file_path = Path(self.test_dir) / "memories" / "test.txt"
|
||||
self.assertFalse(file_path.exists())
|
||||
|
||||
def test_delete_directory(self):
|
||||
"""Test deleting a directory."""
|
||||
self.handler.execute(
|
||||
command="create", path="/memories/subdir/test.txt", file_text="content"
|
||||
)
|
||||
|
||||
result = self.handler.execute(command="delete", path="/memories/subdir")
|
||||
self.assertIn("success", result)
|
||||
|
||||
dir_path = Path(self.test_dir) / "memories" / "subdir"
|
||||
self.assertFalse(dir_path.exists())
|
||||
|
||||
def test_delete_cannot_delete_root(self):
|
||||
"""Test that root /memories directory cannot be deleted."""
|
||||
result = self.handler.execute(command="delete", path="/memories")
|
||||
self.assertIn("error", result)
|
||||
self.assertIn("cannot delete", result["error"].lower())
|
||||
|
||||
def test_delete_nonexistent_path(self):
|
||||
"""Test deleting a nonexistent path."""
|
||||
result = self.handler.execute(command="delete", path="/memories/notfound.txt")
|
||||
self.assertIn("error", result)
|
||||
self.assertIn("not found", result["error"].lower())
|
||||
|
||||
# Rename Command Tests
|
||||
|
||||
def test_rename_file(self):
|
||||
"""Test renaming a file."""
|
||||
self.handler.execute(
|
||||
command="create", path="/memories/old.txt", file_text="content"
|
||||
)
|
||||
|
||||
result = self.handler.execute(
|
||||
command="rename", old_path="/memories/old.txt", new_path="/memories/new.txt"
|
||||
)
|
||||
self.assertIn("success", result)
|
||||
|
||||
old_path = Path(self.test_dir) / "memories" / "old.txt"
|
||||
new_path = Path(self.test_dir) / "memories" / "new.txt"
|
||||
self.assertFalse(old_path.exists())
|
||||
self.assertTrue(new_path.exists())
|
||||
|
||||
def test_rename_to_subdirectory(self):
|
||||
"""Test moving a file to a subdirectory."""
|
||||
self.handler.execute(
|
||||
command="create", path="/memories/file.txt", file_text="content"
|
||||
)
|
||||
|
||||
result = self.handler.execute(
|
||||
command="rename",
|
||||
old_path="/memories/file.txt",
|
||||
new_path="/memories/subdir/file.txt",
|
||||
)
|
||||
self.assertIn("success", result)
|
||||
|
||||
new_path = Path(self.test_dir) / "memories" / "subdir" / "file.txt"
|
||||
self.assertTrue(new_path.exists())
|
||||
|
||||
def test_rename_source_not_found(self):
    """Renaming a missing source path yields a 'not found' error."""
    response = self.handler.execute(
        command="rename",
        old_path="/memories/notfound.txt",
        new_path="/memories/new.txt",
    )
    self.assertIn("error", response)
    self.assertIn("not found", response["error"].lower())
|
||||
|
||||
def test_rename_destination_exists(self):
    """Rename refuses to overwrite an existing destination file."""
    # Seed two distinct files so the rename target already exists.
    for name, text in (("file1.txt", "content1"), ("file2.txt", "content2")):
        self.handler.execute(
            command="create", path=f"/memories/{name}", file_text=text
        )

    response = self.handler.execute(
        command="rename",
        old_path="/memories/file1.txt",
        new_path="/memories/file2.txt",
    )
    self.assertIn("error", response)
    self.assertIn("already exists", response["error"].lower())
|
||||
|
||||
# Error Handling Tests
|
||||
|
||||
def test_unknown_command(self):
    """An unrecognized command is reported as an error, not raised."""
    response = self.handler.execute(command="invalid", path="/memories")
    self.assertIn("error", response)
    self.assertIn("unknown command", response["error"].lower())
|
||||
|
||||
def test_missing_required_parameters(self):
    """Omitting a required argument ('view' without a path) is an error."""
    response = self.handler.execute(command="view")
    self.assertIn("error", response)
|
||||
|
||||
# Utility Tests
|
||||
|
||||
def test_clear_all_memory(self):
    """clear_all_memory empties /memories but keeps the root directory."""
    # Populate the store with a couple of files to clear.
    for name, text in (("file1.txt", "content1"), ("file2.txt", "content2")):
        self.handler.execute(
            command="create", path=f"/memories/{name}", file_text=text
        )

    response = self.handler.clear_all_memory()
    self.assertIn("success", response)

    # The root directory must survive, but with no children left.
    memory_root = Path(self.test_dir) / "memories"
    self.assertTrue(memory_root.exists())
    self.assertEqual(list(memory_root.iterdir()), [])
|
||||
|
||||
|
||||
# Allow running this test module directly (e.g. `python <this file>.py`);
# under a test runner this guard is skipped and discovery takes over.
if __name__ == "__main__":
    unittest.main()
|
||||
Reference in New Issue
Block a user