mirror of
https://github.com/anthropics/claude-cookbooks.git
synced 2025-10-06 01:00:28 +03:00
Merge pull request #1 from anthropics/zh/memory-cookbook-improvements
Add memory & context management cookbook
This commit is contained in:
6
.gitignore
vendored
6
.gitignore
vendored
@@ -149,4 +149,8 @@ lychee-report.md
|
|||||||
# Notebook validation
|
# Notebook validation
|
||||||
.notebook_validation_state.json
|
.notebook_validation_state.json
|
||||||
.notebook_validation_checkpoint.json
|
.notebook_validation_checkpoint.json
|
||||||
validation_report_*.md
|
validation_report_*.md
|
||||||
|
# Memory tool demo artifacts
|
||||||
|
tool_use/demo_memory/
|
||||||
|
tool_use/memory_storage/
|
||||||
|
tool_use/.env
|
||||||
|
|||||||
14
tool_use/.env.example
Normal file
14
tool_use/.env.example
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
# Anthropic API Configuration
|
||||||
|
# Copy this file to .env and fill in your actual values
|
||||||
|
|
||||||
|
# Your Anthropic API key from https://console.anthropic.com/
|
||||||
|
ANTHROPIC_API_KEY=your_api_key_here
|
||||||
|
|
||||||
|
# Model name - Use a model that supports memory_20250818 tool
|
||||||
|
# Supported models (as of launch):
|
||||||
|
# - claude-sonnet-4-20250514
|
||||||
|
# - claude-opus-4-20250514
|
||||||
|
# - claude-opus-4-1-20250805
|
||||||
|
# - claude-sonnet-4-5-20250929
|
||||||
|
|
||||||
|
ANTHROPIC_MODEL=claude-sonnet-4-5-20250929
|
||||||
File diff suppressed because it is too large
Load Diff
5
tool_use/memory_demo/.gitignore
vendored
Normal file
5
tool_use/memory_demo/.gitignore
vendored
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
# Ignore demo-generated directories and files
|
||||||
|
demo_memory/
|
||||||
|
memory_storage/
|
||||||
|
__pycache__/
|
||||||
|
*.pyc
|
||||||
339
tool_use/memory_demo/code_review_demo.py
Normal file
339
tool_use/memory_demo/code_review_demo.py
Normal file
@@ -0,0 +1,339 @@
|
|||||||
|
"""
|
||||||
|
Code Review Assistant Demo - Three-session demonstration.
|
||||||
|
|
||||||
|
This demo showcases:
|
||||||
|
1. Session 1: Claude learns debugging patterns
|
||||||
|
2. Session 2: Claude applies learned patterns (faster!)
|
||||||
|
3. Session 3: Long session with context editing
|
||||||
|
|
||||||
|
Requires:
|
||||||
|
- .env file with ANTHROPIC_API_KEY and ANTHROPIC_MODEL
|
||||||
|
- memory_tool.py in the same directory
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from anthropic import Anthropic
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Add parent directory to path to import memory_tool
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||||
|
|
||||||
|
from memory_tool import MemoryToolHandler
|
||||||
|
|
||||||
|
|
||||||
|
# Load environment variables
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
API_KEY = os.getenv("ANTHROPIC_API_KEY")
|
||||||
|
MODEL = os.getenv("ANTHROPIC_MODEL")
|
||||||
|
|
||||||
|
if not API_KEY:
|
||||||
|
raise ValueError(
|
||||||
|
"ANTHROPIC_API_KEY not found. Copy .env.example to .env and add your API key."
|
||||||
|
)
|
||||||
|
|
||||||
|
if not MODEL:
|
||||||
|
raise ValueError(
|
||||||
|
"ANTHROPIC_MODEL not found. Copy .env.example to .env and set the model."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Context management configuration
|
||||||
|
CONTEXT_MANAGEMENT = {
|
||||||
|
"edits": [
|
||||||
|
{
|
||||||
|
"type": "clear_tool_uses_20250919",
|
||||||
|
"trigger": {"type": "input_tokens", "value": 30000},
|
||||||
|
"keep": {"type": "tool_uses", "value": 3},
|
||||||
|
"clear_at_least": {"type": "input_tokens", "value": 5000},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class CodeReviewAssistant:
    """
    Code review assistant with memory and context editing capabilities.

    This assistant:
    - Checks memory for debugging patterns before reviewing code
    - Stores learned patterns for future sessions
    - Automatically clears old tool results when context grows large
    """

    def __init__(self, memory_storage_path: str = "./memory_storage"):
        """
        Initialize the code review assistant.

        Args:
            memory_storage_path: Path for the on-disk memory store that
                persists across conversations.
        """
        self.client = Anthropic(api_key=API_KEY)
        self.memory_handler = MemoryToolHandler(base_path=memory_storage_path)
        # Conversation history for the current session only; memory on disk
        # outlives it.
        self.messages: List[Dict[str, Any]] = []

    def _create_system_prompt(self) -> str:
        """Create system prompt with memory instructions."""
        return """You are an expert code reviewer focused on finding bugs and suggesting improvements.

MEMORY PROTOCOL:
1. Check your /memories directory for relevant debugging patterns or insights
2. When you find a bug or pattern, update your memory with what you learned
3. Keep your memory organized - use descriptive file names and clear content

When reviewing code:
- Identify bugs, security issues, and code quality problems
- Explain the issue clearly
- Provide a corrected version
- Store important patterns in memory for future reference

Remember: Your memory persists across conversations. Use it wisely."""

    @staticmethod
    def _build_user_message(code: str, filename: str, description: str = "") -> str:
        """
        Build the review-request text sent to Claude.

        Args:
            code: The code to review.
            filename: Name of the file being reviewed (shown to Claude).
            description: Optional reviewer context appended to the request.

        Returns:
            The formatted user message containing a fenced code block.
        """
        # FIX: the original f-string had no placeholder and read
        # "Please review this code from (unknown)", silently dropping the
        # filename parameter.
        user_message = f"Please review this code from {filename}:"
        if description:
            user_message += f"\n\nContext: {description}"
        user_message += f"\n\n```python\n{code}\n```"
        return user_message

    def _execute_tool_use(self, tool_use: Any) -> str:
        """Execute a tool use and return the result string."""
        if tool_use.name == "memory":
            result = self.memory_handler.execute(**tool_use.input)
            # Handler returns {"success": ...} or {"error": ...}.
            return result.get("success") or result.get("error", "Unknown error")
        return f"Unknown tool: {tool_use.name}"

    def review_code(
        self, code: str, filename: str, description: str = ""
    ) -> Dict[str, Any]:
        """
        Review code with memory-enhanced analysis.

        Args:
            code: The code to review
            filename: Name of the file being reviewed
            description: Optional description of what to look for

        Returns:
            Dict with "review" (text of the final turn), "input_tokens"
            (usage of the last API call), and "context_edits" (any
            context-management edits applied during the loop).
        """
        self.messages.append(
            {
                "role": "user",
                "content": self._build_user_message(code, filename, description),
            }
        )

        # Track token usage and context management across turns.
        total_input_tokens = 0
        context_edits_applied: List[Any] = []

        # Conversation loop: keep calling until Claude stops using tools.
        turn = 1
        while True:
            print(f" 🔄 Turn {turn}: Calling Claude API...", end="", flush=True)
            response = self.client.beta.messages.create(
                model=MODEL,
                max_tokens=4096,
                system=self._create_system_prompt(),
                messages=self.messages,
                tools=[{"type": "memory_20250818", "name": "memory"}],
                betas=["context-management-2025-06-27"],
                extra_body={"context_management": CONTEXT_MANAGEMENT},
            )
            print(" ✓")

            # Usage of the most recent call (reported after any clearing).
            total_input_tokens = response.usage.input_tokens

            # Record any context edits the API applied on this turn.
            if hasattr(response, "context_management") and response.context_management:
                applied = response.context_management.get("applied_edits", [])
                if applied:
                    context_edits_applied.extend(applied)

            # Process response content; final_text is rebuilt each turn so
            # the returned review is the text of the last turn only.
            assistant_content: List[Dict[str, Any]] = []
            tool_results: List[Dict[str, Any]] = []
            final_text: List[str] = []

            for content in response.content:
                if content.type == "text":
                    assistant_content.append({"type": "text", "text": content.text})
                    final_text.append(content.text)
                elif content.type == "tool_use":
                    cmd = content.input.get('command', 'unknown')
                    path = content.input.get('path', '')
                    print(f" 🔧 Memory: {cmd} {path}")

                    # Execute tool locally and collect the result to feed back.
                    result = self._execute_tool_use(content)

                    assistant_content.append(
                        {
                            "type": "tool_use",
                            "id": content.id,
                            "name": content.name,
                            "input": content.input,
                        }
                    )
                    tool_results.append(
                        {
                            "type": "tool_result",
                            "tool_use_id": content.id,
                            "content": result,
                        }
                    )

            self.messages.append({"role": "assistant", "content": assistant_content})

            if tool_results:
                # Feed tool results back and let Claude continue.
                self.messages.append({"role": "user", "content": tool_results})
                turn += 1
            else:
                # No more tool uses - the review is complete.
                print()
                break

        return {
            "review": "\n".join(final_text),
            "input_tokens": total_input_tokens,
            "context_edits": context_edits_applied,
        }

    def start_new_session(self) -> None:
        """Start a new conversation session (on-disk memory persists)."""
        self.messages = []
|
||||||
|
|
||||||
|
|
||||||
|
def run_session_1() -> None:
    """Session 1: Learn debugging patterns."""
    banner = "=" * 80
    print(banner)
    print("SESSION 1: Learning from First Code Review")
    print(banner)

    assistant = CodeReviewAssistant()

    # Load the buggy sample to review.
    code = Path("memory_demo/sample_code/web_scraper_v1.py").read_text()

    print("\n📋 Reviewing web_scraper_v1.py...")
    print("\nMulti-threaded web scraper that sometimes loses results.\n")

    result = assistant.review_code(
        code=code,
        filename="web_scraper_v1.py",
        description=(
            "This scraper sometimes returns fewer results than expected. "
            "The count is inconsistent across runs. Can you find the issue?"
        ),
    )

    print("\n🤖 Claude's Review:\n")
    print(result["review"])
    print(f"\n📊 Input tokens used: {result['input_tokens']:,}")

    if result["context_edits"]:
        print(f"\n🧹 Context edits applied: {result['context_edits']}")

    print("\n✅ Session 1 complete - Claude learned debugging patterns!\n")
|
||||||
|
|
||||||
|
|
||||||
|
def run_session_2() -> None:
    """Session 2: Apply learned patterns."""
    banner = "=" * 80
    print(banner)
    print("SESSION 2: Applying Learned Patterns (New Conversation)")
    print(banner)

    # Fresh conversation; the on-disk memory from Session 1 persists.
    assistant = CodeReviewAssistant()

    # Different sample with a similar class of bug.
    code = Path("memory_demo/sample_code/api_client_v1.py").read_text()

    print("\n📋 Reviewing api_client_v1.py...")
    print("\nAsync API client with concurrent requests.\n")

    result = assistant.review_code(
        code=code,
        filename="api_client_v1.py",
        description=(
            "Review this async API client. "
            "It fetches multiple endpoints concurrently. Are there any issues?"
        ),
    )

    print("\n🤖 Claude's Review:\n")
    print(result["review"])
    print(f"\n📊 Input tokens used: {result['input_tokens']:,}")

    print("\n✅ Session 2 complete - Claude applied learned patterns faster!\n")
|
||||||
|
|
||||||
|
|
||||||
|
def run_session_3() -> None:
    """Session 3: Long session with context editing."""
    banner = "=" * 80
    print(banner)
    print("SESSION 3: Long Session with Context Editing")
    print(banner)

    assistant = CodeReviewAssistant()

    # Large sample with multiple deliberate concurrency issues.
    code = Path("memory_demo/sample_code/data_processor_v1.py").read_text()

    print("\n📋 Reviewing data_processor_v1.py...")
    print("\nLarge file with multiple concurrent processing classes.\n")

    result = assistant.review_code(
        code=code,
        filename="data_processor_v1.py",
        description=(
            "This data processor handles files concurrently. "
            "There's also a SharedCache class. Review all components for issues."
        ),
    )

    print("\n🤖 Claude's Review:\n")
    print(result["review"])
    print(f"\n📊 Input tokens used: {result['input_tokens']:,}")

    if result["context_edits"]:
        print("\n🧹 Context Management Applied:")
        for edit in result["context_edits"]:
            print(f" - Type: {edit.get('type')}")
            print(f" - Cleared tool uses: {edit.get('cleared_tool_uses', 0)}")
            print(f" - Tokens saved: {edit.get('cleared_input_tokens', 0):,}")

    print("\n✅ Session 3 complete - Context editing kept conversation manageable!\n")
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
    """Run all three demo sessions."""
    print("\n🚀 Code Review Assistant Demo\n")
    print("This demo shows:")
    print("1. Session 1: Claude learns debugging patterns")
    print("2. Session 2: Claude applies learned patterns (new conversation)")
    print("3. Session 3: Long session with context editing\n")

    # Pause before each session so the user can follow along.
    sessions = (
        ("Session 1", run_session_1),
        ("Session 2", run_session_2),
        ("Session 3", run_session_3),
    )
    for label, session in sessions:
        input(f"Press Enter to start {label}...")
        session()

    print("=" * 80)
    print("🎉 Demo Complete!")
    print("=" * 80)
    print("\nKey Takeaways:")
    print("- Memory tool enabled cross-conversation learning")
    print("- Claude got faster at recognizing similar bugs")
    print("- Context editing handled long sessions gracefully")
    print("\n💡 For production GitHub PR reviews, check out:")
    print(" https://github.com/anthropics/claude-code-action\n")


if __name__ == "__main__":
    main()
|
||||||
194
tool_use/memory_demo/demo_helpers.py
Normal file
194
tool_use/memory_demo/demo_helpers.py
Normal file
@@ -0,0 +1,194 @@
|
|||||||
|
"""
|
||||||
|
Helper functions for memory cookbook demos.
|
||||||
|
|
||||||
|
This module provides reusable functions for running conversation loops
|
||||||
|
with Claude, handling tool execution, and managing context.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from anthropic import Anthropic
|
||||||
|
from memory_tool import MemoryToolHandler
|
||||||
|
|
||||||
|
|
||||||
|
def execute_tool(tool_use: Any, memory_handler: MemoryToolHandler) -> str:
    """
    Execute a tool use and return the result.

    Args:
        tool_use: The tool use object from Claude's response
        memory_handler: The memory tool handler instance

    Returns:
        str: The result of the tool execution
    """
    if tool_use.name != "memory":
        return f"Unknown tool: {tool_use.name}"
    # Handler reports either a success message or an error message.
    outcome = memory_handler.execute(**tool_use.input)
    return outcome.get("success") or outcome.get("error", "Unknown error")
|
||||||
|
|
||||||
|
|
||||||
|
def run_conversation_turn(
    client: Anthropic,
    model: str,
    messages: list[dict[str, Any]],
    memory_handler: MemoryToolHandler,
    system: str,
    context_management: dict[str, Any] | None = None,
    max_tokens: int = 1024,
    verbose: bool = False
) -> tuple[Any, list[dict[str, Any]], list[dict[str, Any]]]:
    """
    Run a single conversation turn, handling tool uses.

    Args:
        client: Anthropic client instance
        model: Model to use
        messages: Current conversation messages
        memory_handler: Memory tool handler instance
        system: System prompt
        context_management: Optional context management config
        max_tokens: Max tokens for response
        verbose: Whether to print tool operations

    Returns:
        Tuple of (response, assistant_content, tool_results)
    """
    request_params: dict[str, Any] = {
        "model": model,
        "max_tokens": max_tokens,
        "system": system,
        "messages": messages,
        "tools": [{"type": "memory_20250818", "name": "memory"}],
        "betas": ["context-management-2025-06-27"],
    }
    # Context management is opt-in via extra_body.
    if context_management:
        request_params["extra_body"] = {"context_management": context_management}

    response = client.beta.messages.create(**request_params)

    assistant_content: list[dict[str, Any]] = []
    tool_results: list[dict[str, Any]] = []

    for block in response.content:
        if block.type == "text":
            if verbose:
                print(f"💬 Claude: {block.text}\n")
            assistant_content.append({"type": "text", "text": block.text})
        elif block.type == "tool_use":
            if verbose:
                cmd = block.input.get('command')
                path = block.input.get('path', '')
                print(f" 🔧 Memory tool: {cmd} {path}")

            outcome = execute_tool(block, memory_handler)

            if verbose:
                preview = outcome if len(outcome) <= 80 else outcome[:80] + "..."
                print(f" ✓ Result: {preview}")

            assistant_content.append({
                "type": "tool_use",
                "id": block.id,
                "name": block.name,
                "input": block.input
            })
            tool_results.append({
                "type": "tool_result",
                "tool_use_id": block.id,
                "content": outcome
            })

    return response, assistant_content, tool_results
|
||||||
|
|
||||||
|
|
||||||
|
def run_conversation_loop(
    client: Anthropic,
    model: str,
    messages: list[dict[str, Any]],
    memory_handler: MemoryToolHandler,
    system: str,
    context_management: dict[str, Any] | None = None,
    max_tokens: int = 1024,
    max_turns: int = 5,
    verbose: bool = False
) -> Any:
    """
    Run a complete conversation loop until Claude stops using tools.

    Args:
        client: Anthropic client instance
        model: Model to use
        messages: Current conversation messages (will be modified in-place)
        memory_handler: Memory tool handler instance
        system: System prompt
        context_management: Optional context management config
        max_tokens: Max tokens for response
        max_turns: Maximum number of turns to prevent infinite loops
        verbose: Whether to print progress

    Returns:
        The final API response
    """
    response = None

    for turn in range(1, max_turns + 1):
        if verbose:
            print(f"\n🔄 Turn {turn}:")

        response, assistant_content, tool_results = run_conversation_turn(
            client=client,
            model=model,
            messages=messages,
            memory_handler=memory_handler,
            system=system,
            context_management=context_management,
            max_tokens=max_tokens,
            verbose=verbose
        )

        messages.append({"role": "assistant", "content": assistant_content})

        # No tool uses means the conversation is complete.
        if not tool_results:
            break
        messages.append({"role": "user", "content": tool_results})

    return response
|
||||||
|
|
||||||
|
|
||||||
|
def print_context_management_info(response: Any) -> tuple[bool, int]:
    """
    Print context management information from response.

    Args:
        response: API response to analyze

    Returns:
        Tuple of (context_cleared, saved_tokens)
    """
    # Guard clauses: no context-management field, or no edits applied.
    cm = getattr(response, "context_management", None)
    if not cm:
        print(" ℹ️ No context management applied")
        return False, 0

    edits = cm.get("applied_edits", [])
    if not edits:
        print(" ℹ️ Context below threshold - no clearing triggered")
        return False, 0

    first = edits[0]
    cleared_uses = first.get('cleared_tool_uses', 0)
    saved_tokens = first.get('cleared_input_tokens', 0)
    print(" ✂️ Context editing triggered!")
    print(f" • Cleared {cleared_uses} tool uses")
    print(f" • Saved {saved_tokens:,} tokens")
    print(f" • After clearing: {response.usage.input_tokens:,} tokens")
    return True, saved_tokens
|
||||||
99
tool_use/memory_demo/sample_code/api_client_v1.py
Normal file
99
tool_use/memory_demo/sample_code/api_client_v1.py
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
"""
|
||||||
|
Async API client with similar concurrency issues.
|
||||||
|
This demonstrates Claude applying thread-safety patterns to async code.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from typing import List, Dict, Optional, Any
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncAPIClient:
    """Async API client for fetching data from multiple endpoints.

    NOTE: this is a teaching sample — the unsynchronized shared state below
    is an intentional demo bug and must not be "fixed" here.
    """

    def __init__(self, base_url: str):
        self.base_url = base_url
        self.responses = []  # BUG: Shared state accessed from multiple coroutines!
        self.error_count = 0  # BUG: Race condition on counter increment!

    async def fetch_endpoint(
        self, session: aiohttp.ClientSession, endpoint: str
    ) -> Dict[str, Any]:
        """Fetch a single endpoint, returning its data or an error record."""
        url = f"{self.base_url}/{endpoint}"
        try:
            async with session.get(
                url, timeout=aiohttp.ClientTimeout(total=5)
            ) as response:
                payload = await response.json()
                return {
                    "endpoint": endpoint,
                    "status": response.status,
                    "data": payload,
                }
        except Exception as exc:
            return {
                "endpoint": endpoint,
                "error": str(exc),
            }

    async def fetch_all(self, endpoints: List[str]) -> List[Dict[str, Any]]:
        """
        Fetch multiple endpoints concurrently.

        BUG (intentional): just like the threading example, multiple
        coroutines modify self.responses and self.error_count without
        coordination. Python's GIL prevents some thread races, but async
        code can still interleave in surprising ways.
        """
        async with aiohttp.ClientSession() as session:
            tasks = [self.fetch_endpoint(session, ep) for ep in endpoints]

            for pending in asyncio.as_completed(tasks):
                outcome = await pending

                # RACE CONDITION: Multiple coroutines modify shared state
                if "error" in outcome:
                    self.error_count += 1  # Not atomic!
                else:
                    self.responses.append(outcome)  # Not thread-safe in async context!

            return self.responses

    def get_summary(self) -> Dict[str, Any]:
        """Get summary statistics."""
        ok = len(self.responses)
        attempted = ok + self.error_count
        return {
            "total_responses": ok,
            "errors": self.error_count,
            "success_rate": ok / attempted if attempted > 0 else 0,
        }
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
    """Test the async API client."""
    client = AsyncAPIClient("https://jsonplaceholder.typicode.com")

    base_endpoints = [
        "posts/1",
        "posts/2",
        "posts/3",
        "users/1",
        "users/2",
        "invalid/endpoint",  # Will error
    ]
    endpoints = base_endpoints * 20  # 120 requests total

    results = await client.fetch_all(endpoints)

    print(f"Expected: ~100 successful responses")
    print(f"Got: {len(results)} responses")
    print(f"Summary: {client.get_summary()}")
    print("\nNote: Counts may be incorrect due to race conditions!")


if __name__ == "__main__":
    asyncio.run(main())
|
||||||
115
tool_use/memory_demo/sample_code/cache_manager.py
Normal file
115
tool_use/memory_demo/sample_code/cache_manager.py
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
"""
|
||||||
|
Cache manager with mutable default argument bug.
|
||||||
|
This is one of Python's most common gotchas.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Dict, List, Optional, Any
|
||||||
|
|
||||||
|
|
||||||
|
class CacheManager:
    """Manage cached data with TTL support.

    Teaching sample for Python's mutable-default-argument gotcha: the
    un-suffixed methods deliberately share their default containers across
    calls, while the `*_fixed` variants show the correct pattern.
    """

    def __init__(self):
        # key -> list of item strings
        self.cache = {}

    def add_items(
        self, key: str, items: List[str] = []  # BUG: Mutable default argument!
    ) -> None:
        """Add items to cache.

        BUG (intentional): the `[]` default is created ONCE at definition
        time, so every call that omits `items` appends to the same shared
        list object.
        """
        items.append(f"Added at {datetime.now()}")
        self.cache[key] = items

    def add_items_fixed(self, key: str, items: Optional[List[str]] = None) -> None:
        """Add items with proper default handling."""
        stamped = [] if items is None else items.copy()  # copy avoids mutating the caller's list
        stamped.append(f"Added at {datetime.now()}")
        self.cache[key] = stamped

    def merge_configs(
        self, name: str, overrides: Dict[str, Any] = {}  # BUG: Mutable default!
    ) -> Dict[str, Any]:
        """Merge configuration with overrides.

        BUG (intentional): the default dict is shared across all calls, and
        `update(defaults)` also clobbers caller-supplied overrides.
        """
        defaults = {"timeout": 30, "retries": 3, "cache_enabled": True}
        overrides.update(defaults)
        return overrides

    def merge_configs_fixed(
        self, name: str, overrides: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Merge configs properly: build a new dict, overrides win."""
        defaults = {"timeout": 30, "retries": 3, "cache_enabled": True}
        return {**defaults, **(overrides or {})}
||||||
|
|
||||||
|
|
||||||
|
class DataProcessor:
    """Another example of the mutable default bug."""

    def process_batch(
        self, data: List[int], filters: List[str] = []  # BUG: Mutable default!
    ) -> List[int]:
        """
        Process data with optional filters.

        BUG (intentional): the default `filters` list is shared across
        calls, and appending to it mutates that shared object.
        """
        filters.append("default_filter")  # Modifies shared list!

        # `filters` is not changed below, so the membership test can be
        # hoisted out of the loop without altering behavior.
        drop_negatives = "positive" in filters
        doubled = []
        for value in data:
            if drop_negatives and value < 0:
                continue
            doubled.append(value * 2)
        return doubled
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
cache = CacheManager()
|
||||||
|
|
||||||
|
# Demonstrate the bug
|
||||||
|
print("=== Demonstrating Mutable Default Argument Bug ===\n")
|
||||||
|
|
||||||
|
# First call with no items
|
||||||
|
cache.add_items("key1")
|
||||||
|
print(f"key1: {cache.cache['key1']}")
|
||||||
|
|
||||||
|
# Second call with no items - SURPRISE! Gets the same list
|
||||||
|
cache.add_items("key2")
|
||||||
|
print(f"key2: {cache.cache['key2']}") # Will have TWO timestamps!
|
||||||
|
|
||||||
|
# Third call - even worse!
|
||||||
|
cache.add_items("key3")
|
||||||
|
print(f"key3: {cache.cache['key3']}") # Will have THREE timestamps!
|
||||||
|
|
||||||
|
print("\nAll keys share the same list object!")
|
||||||
|
print(f"key1 is key2: {cache.cache['key1'] is cache.cache['key2']}")
|
||||||
|
|
||||||
|
print("\n=== Using Fixed Version ===\n")
|
||||||
|
cache2 = CacheManager()
|
||||||
|
cache2.add_items_fixed("key1")
|
||||||
|
cache2.add_items_fixed("key2")
|
||||||
|
cache2.add_items_fixed("key3")
|
||||||
|
print(f"key1: {cache2.cache['key1']}")
|
||||||
|
print(f"key2: {cache2.cache['key2']}")
|
||||||
|
print(f"key3: {cache2.cache['key3']}")
|
||||||
|
print(f"\nkey1 is key2: {cache2.cache['key1'] is cache2.cache['key2']}")
|
||||||
145
tool_use/memory_demo/sample_code/data_processor_v1.py
Normal file
145
tool_use/memory_demo/sample_code/data_processor_v1.py
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
"""
|
||||||
|
Data processor with multiple concurrency and thread-safety issues.
|
||||||
|
Used for Session 3 to demonstrate context editing with multiple bugs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import threading
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Dict, Any
|
||||||
|
|
||||||
|
|
||||||
|
class DataProcessor:
|
||||||
|
"""Process data files concurrently with various thread-safety issues."""
|
||||||
|
|
||||||
|
def __init__(self, max_workers: int = 5):
|
||||||
|
self.max_workers = max_workers
|
||||||
|
self.processed_count = 0 # BUG: Race condition on counter
|
||||||
|
self.results = [] # BUG: Shared list without locking
|
||||||
|
self.errors = {} # BUG: Shared dict without locking
|
||||||
|
self.lock = threading.Lock() # Available but not used!
|
||||||
|
|
||||||
|
def process_file(self, file_path: str) -> Dict[str, Any]:
|
||||||
|
"""Process a single file."""
|
||||||
|
try:
|
||||||
|
with open(file_path, "r") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
|
||||||
|
# Simulate some processing
|
||||||
|
processed = {
|
||||||
|
"file": file_path,
|
||||||
|
"record_count": len(data) if isinstance(data, list) else 1,
|
||||||
|
"size_bytes": Path(file_path).stat().st_size,
|
||||||
|
}
|
||||||
|
|
||||||
|
return processed
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {"file": file_path, "error": str(e)}
|
||||||
|
|
||||||
|
def process_batch(self, file_paths: List[str]) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Process multiple files concurrently.
|
||||||
|
|
||||||
|
MULTIPLE BUGS:
|
||||||
|
1. self.processed_count is incremented without locking
|
||||||
|
2. self.results is appended to from multiple threads
|
||||||
|
3. self.errors is modified from multiple threads
|
||||||
|
4. We have a lock but don't use it!
|
||||||
|
"""
|
||||||
|
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
|
||||||
|
futures = [executor.submit(self.process_file, fp) for fp in file_paths]
|
||||||
|
|
||||||
|
for future in futures:
|
||||||
|
result = future.result()
|
||||||
|
|
||||||
|
# RACE CONDITION: Increment counter without lock
|
||||||
|
self.processed_count += 1 # BUG!
|
||||||
|
|
||||||
|
if "error" in result:
|
||||||
|
# RACE CONDITION: Modify dict without lock
|
||||||
|
self.errors[result["file"]] = result["error"] # BUG!
|
||||||
|
else:
|
||||||
|
# RACE CONDITION: Append to list without lock
|
||||||
|
self.results.append(result) # BUG!
|
||||||
|
|
||||||
|
return self.results
|
||||||
|
|
||||||
|
def get_statistics(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Get processing statistics.
|
||||||
|
|
||||||
|
BUG: Accessing shared state without ensuring thread-safety.
|
||||||
|
If called while processing, could get inconsistent values.
|
||||||
|
"""
|
||||||
|
total_files = self.processed_count
|
||||||
|
successful = len(self.results)
|
||||||
|
failed = len(self.errors)
|
||||||
|
|
||||||
|
# BUG: These counts might not add up correctly due to race conditions
|
||||||
|
return {
|
||||||
|
"total_processed": total_files,
|
||||||
|
"successful": successful,
|
||||||
|
"failed": failed,
|
||||||
|
"success_rate": successful / total_files if total_files > 0 else 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
"""
|
||||||
|
Reset processor state.
|
||||||
|
|
||||||
|
BUG: No locking - if called during processing, causes corruption.
|
||||||
|
"""
|
||||||
|
self.processed_count = 0 # RACE CONDITION
|
||||||
|
self.results = [] # RACE CONDITION
|
||||||
|
self.errors = {} # RACE CONDITION
|
||||||
|
|
||||||
|
|
||||||
|
class SharedCache:
|
||||||
|
"""
|
||||||
|
A shared cache with thread-safety issues.
|
||||||
|
|
||||||
|
BUG: Classic read-modify-write race condition pattern.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.cache = {} # BUG: Shared dict without locking
|
||||||
|
self.hit_count = 0 # BUG: Race condition
|
||||||
|
self.miss_count = 0 # BUG: Race condition
|
||||||
|
|
||||||
|
def get(self, key: str) -> Any:
|
||||||
|
"""Get from cache - RACE CONDITION on hit/miss counts."""
|
||||||
|
if key in self.cache:
|
||||||
|
self.hit_count += 1 # BUG: Not atomic!
|
||||||
|
return self.cache[key]
|
||||||
|
else:
|
||||||
|
self.miss_count += 1 # BUG: Not atomic!
|
||||||
|
return None
|
||||||
|
|
||||||
|
def set(self, key: str, value: Any):
|
||||||
|
"""Set in cache - RACE CONDITION on dict modification."""
|
||||||
|
self.cache[key] = value # BUG: Dict access not synchronized!
|
||||||
|
|
||||||
|
def get_stats(self) -> Dict[str, Any]:
|
||||||
|
"""Get cache statistics - may be inconsistent."""
|
||||||
|
total = self.hit_count + self.miss_count
|
||||||
|
return {
|
||||||
|
"hits": self.hit_count,
|
||||||
|
"misses": self.miss_count,
|
||||||
|
"hit_rate": self.hit_count / total if total > 0 else 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Create some test files (not included)
|
||||||
|
processor = DataProcessor(max_workers=10)
|
||||||
|
|
||||||
|
# Simulate processing many files
|
||||||
|
file_paths = [f"data/file_{i}.json" for i in range(100)]
|
||||||
|
|
||||||
|
print("Processing files concurrently...")
|
||||||
|
results = processor.process_batch(file_paths)
|
||||||
|
|
||||||
|
print(f"\nStatistics: {processor.get_statistics()}")
|
||||||
|
print("\nNote: Counts may be inconsistent due to race conditions!")
|
||||||
105
tool_use/memory_demo/sample_code/sql_query_builder.py
Normal file
105
tool_use/memory_demo/sample_code/sql_query_builder.py
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
"""
|
||||||
|
SQL query builder with SQL injection vulnerability.
|
||||||
|
Demonstrates dangerous string formatting in SQL queries.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
class UserDatabase:
|
||||||
|
"""Simple database interface (mock)."""
|
||||||
|
|
||||||
|
def execute(self, query: str) -> List[dict]:
|
||||||
|
"""Mock execute - just returns the query for inspection."""
|
||||||
|
print(f"Executing: {query}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
class QueryBuilder:
|
||||||
|
"""Build SQL queries for user operations."""
|
||||||
|
|
||||||
|
def __init__(self, db: UserDatabase):
|
||||||
|
self.db = db
|
||||||
|
|
||||||
|
def get_user_by_name(self, username: str) -> Optional[dict]:
|
||||||
|
"""
|
||||||
|
Get user by username.
|
||||||
|
|
||||||
|
BUG: SQL INJECTION VULNERABILITY!
|
||||||
|
Using string formatting with user input allows SQL injection.
|
||||||
|
"""
|
||||||
|
# DANGEROUS: Never use f-strings or % formatting with user input!
|
||||||
|
query = f"SELECT * FROM users WHERE username = '{username}'"
|
||||||
|
results = self.db.execute(query)
|
||||||
|
return results[0] if results else None
|
||||||
|
|
||||||
|
def get_user_by_name_safe(self, username: str) -> Optional[dict]:
|
||||||
|
"""Safe version using parameterized queries."""
|
||||||
|
# Use parameterized queries (this is pseudo-code for the concept)
|
||||||
|
query = "SELECT * FROM users WHERE username = ?"
|
||||||
|
# In real code: self.db.execute(query, (username,))
|
||||||
|
print(f"Safe query with parameter: {query}, params: ({username},)")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def search_users(self, search_term: str, limit: int = 10) -> List[dict]:
|
||||||
|
"""
|
||||||
|
Search users by term.
|
||||||
|
|
||||||
|
BUG: SQL INJECTION through LIKE clause!
|
||||||
|
"""
|
||||||
|
# DANGEROUS: User input directly in LIKE clause
|
||||||
|
query = f"SELECT * FROM users WHERE name LIKE '%{search_term}%' LIMIT {limit}"
|
||||||
|
return self.db.execute(query)
|
||||||
|
|
||||||
|
def delete_user(self, user_id: str) -> bool:
|
||||||
|
"""
|
||||||
|
Delete a user.
|
||||||
|
|
||||||
|
BUG: SQL INJECTION in DELETE statement!
|
||||||
|
This is especially dangerous as it can lead to data loss.
|
||||||
|
"""
|
||||||
|
# DANGEROUS: Unvalidated user input in DELETE
|
||||||
|
query = f"DELETE FROM users WHERE id = {user_id}"
|
||||||
|
self.db.execute(query)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_users_by_role(self, role: str, order_by: str = "name") -> List[dict]:
|
||||||
|
"""
|
||||||
|
Get users by role.
|
||||||
|
|
||||||
|
BUG: SQL INJECTION in ORDER BY clause!
|
||||||
|
Even the ORDER BY clause can be exploited.
|
||||||
|
"""
|
||||||
|
# DANGEROUS: User-controlled ORDER BY
|
||||||
|
query = f"SELECT * FROM users WHERE role = '{role}' ORDER BY {order_by}"
|
||||||
|
return self.db.execute(query)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
db = UserDatabase()
|
||||||
|
qb = QueryBuilder(db)
|
||||||
|
|
||||||
|
print("=== Demonstrating SQL Injection Vulnerabilities ===\n")
|
||||||
|
|
||||||
|
# Example 1: Basic injection
|
||||||
|
print("1. Basic username injection:")
|
||||||
|
qb.get_user_by_name("admin' OR '1'='1")
|
||||||
|
# Executes: SELECT * FROM users WHERE username = 'admin' OR '1'='1'
|
||||||
|
# Returns ALL users!
|
||||||
|
|
||||||
|
print("\n2. Search term injection:")
|
||||||
|
qb.search_users("test%' OR 1=1--")
|
||||||
|
# Can bypass the LIKE and return everything
|
||||||
|
|
||||||
|
print("\n3. DELETE injection:")
|
||||||
|
qb.delete_user("1 OR 1=1")
|
||||||
|
# Executes: DELETE FROM users WHERE id = 1 OR 1=1
|
||||||
|
# DELETES ALL USERS!
|
||||||
|
|
||||||
|
print("\n4. ORDER BY injection:")
|
||||||
|
qb.get_users_by_role("admin", "name; DROP TABLE users--")
|
||||||
|
# Can execute arbitrary SQL commands!
|
||||||
|
|
||||||
|
print("\n=== Safe Version ===")
|
||||||
|
qb.get_user_by_name_safe("admin' OR '1'='1")
|
||||||
|
# Parameters are properly escaped
|
||||||
84
tool_use/memory_demo/sample_code/web_scraper_v1.py
Normal file
84
tool_use/memory_demo/sample_code/web_scraper_v1.py
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
"""
|
||||||
|
Concurrent web scraper with a race condition bug.
|
||||||
|
Multiple threads modify shared state without synchronization.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
from typing import List, Dict, Any
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
class WebScraper:
|
||||||
|
"""Web scraper that fetches multiple URLs concurrently."""
|
||||||
|
|
||||||
|
def __init__(self, max_workers: int = 10):
|
||||||
|
self.max_workers = max_workers
|
||||||
|
self.results = [] # BUG: Shared mutable state accessed by multiple threads!
|
||||||
|
self.failed_urls = [] # BUG: Another race condition!
|
||||||
|
|
||||||
|
def fetch_url(self, url: str) -> Dict[str, Any]:
|
||||||
|
"""Fetch a single URL and return the result."""
|
||||||
|
try:
|
||||||
|
response = requests.get(url, timeout=5)
|
||||||
|
response.raise_for_status()
|
||||||
|
return {
|
||||||
|
"url": url,
|
||||||
|
"status": response.status_code,
|
||||||
|
"content_length": len(response.content),
|
||||||
|
}
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
return {"url": url, "error": str(e)}
|
||||||
|
|
||||||
|
def scrape_urls(self, urls: List[str]) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Scrape multiple URLs concurrently.
|
||||||
|
|
||||||
|
BUG: self.results is accessed from multiple threads without locking!
|
||||||
|
This causes race conditions where results can be lost or corrupted.
|
||||||
|
"""
|
||||||
|
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
|
||||||
|
futures = [executor.submit(self.fetch_url, url) for url in urls]
|
||||||
|
|
||||||
|
for future in as_completed(futures):
|
||||||
|
result = future.result()
|
||||||
|
|
||||||
|
# RACE CONDITION: Multiple threads append to self.results simultaneously
|
||||||
|
if "error" in result:
|
||||||
|
self.failed_urls.append(result["url"]) # RACE CONDITION
|
||||||
|
else:
|
||||||
|
self.results.append(result) # RACE CONDITION
|
||||||
|
|
||||||
|
return self.results
|
||||||
|
|
||||||
|
def get_stats(self) -> Dict[str, int]:
|
||||||
|
"""Get scraping statistics."""
|
||||||
|
return {
|
||||||
|
"total_results": len(self.results),
|
||||||
|
"failed_urls": len(self.failed_urls),
|
||||||
|
"success_rate": (
|
||||||
|
len(self.results) / (len(self.results) + len(self.failed_urls))
|
||||||
|
if (len(self.results) + len(self.failed_urls)) > 0
|
||||||
|
else 0
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Test with multiple URLs
|
||||||
|
urls = [
|
||||||
|
"https://httpbin.org/delay/1",
|
||||||
|
"https://httpbin.org/status/200",
|
||||||
|
"https://httpbin.org/status/404",
|
||||||
|
"https://httpbin.org/delay/2",
|
||||||
|
"https://httpbin.org/status/500",
|
||||||
|
] * 10 # 50 URLs total to increase race condition probability
|
||||||
|
|
||||||
|
scraper = WebScraper(max_workers=10)
|
||||||
|
results = scraper.scrape_urls(urls)
|
||||||
|
|
||||||
|
print(f"Expected: 50 results")
|
||||||
|
print(f"Got: {len(results)} results")
|
||||||
|
print(f"Stats: {scraper.get_stats()}")
|
||||||
|
print("\nNote: Results count may be less than expected due to race condition!")
|
||||||
366
tool_use/memory_tool.py
Normal file
366
tool_use/memory_tool.py
Normal file
@@ -0,0 +1,366 @@
|
|||||||
|
"""
|
||||||
|
Production-ready memory tool handler for Claude's memory_20250818 tool.
|
||||||
|
|
||||||
|
This implementation provides secure, client-side execution of memory operations
|
||||||
|
with path validation, error handling, and comprehensive security measures.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import shutil
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryToolHandler:
|
||||||
|
"""
|
||||||
|
Handles execution of Claude's memory tool commands.
|
||||||
|
|
||||||
|
The memory tool enables Claude to read, write, and manage files in a memory
|
||||||
|
system through a standardized tool interface. This handler provides client-side
|
||||||
|
implementation with security controls.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
base_path: Root directory for memory storage
|
||||||
|
memory_root: The /memories directory within base_path
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, base_path: str = "./memory_storage"):
|
||||||
|
"""
|
||||||
|
Initialize the memory tool handler.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
base_path: Root directory for all memory operations
|
||||||
|
"""
|
||||||
|
self.base_path = Path(base_path).resolve()
|
||||||
|
self.memory_root = self.base_path / "memories"
|
||||||
|
self.memory_root.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
def _validate_path(self, path: str) -> Path:
|
||||||
|
"""
|
||||||
|
Validate and resolve memory paths to prevent directory traversal attacks.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: The path to validate (must start with /memories)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Resolved absolute Path object within memory_root
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If path is invalid or attempts to escape memory directory
|
||||||
|
"""
|
||||||
|
if not path.startswith("/memories"):
|
||||||
|
raise ValueError(
|
||||||
|
f"Path must start with /memories, got: {path}. "
|
||||||
|
"All memory operations must be confined to the /memories directory."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Remove /memories prefix and any leading slashes
|
||||||
|
relative_path = path[len("/memories") :].lstrip("/")
|
||||||
|
|
||||||
|
# Resolve to absolute path within memory_root
|
||||||
|
if relative_path:
|
||||||
|
full_path = (self.memory_root / relative_path).resolve()
|
||||||
|
else:
|
||||||
|
full_path = self.memory_root.resolve()
|
||||||
|
|
||||||
|
# Verify the resolved path is still within memory_root
|
||||||
|
try:
|
||||||
|
full_path.relative_to(self.memory_root.resolve())
|
||||||
|
except ValueError as e:
|
||||||
|
raise ValueError(
|
||||||
|
f"Path '{path}' would escape /memories directory. "
|
||||||
|
"Directory traversal attempts are not allowed."
|
||||||
|
) from e
|
||||||
|
|
||||||
|
return full_path
|
||||||
|
|
||||||
|
def execute(self, **params: Any) -> dict[str, str]:
|
||||||
|
"""
|
||||||
|
Execute a memory tool command.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
**params: Command parameters from Claude's tool use
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with either 'success' or 'error' key
|
||||||
|
|
||||||
|
Supported commands:
|
||||||
|
- view: Show directory contents or file contents
|
||||||
|
- create: Create or overwrite a file
|
||||||
|
- str_replace: Replace text in a file
|
||||||
|
- insert: Insert text at a specific line
|
||||||
|
- delete: Delete a file or directory
|
||||||
|
- rename: Rename or move a file/directory
|
||||||
|
"""
|
||||||
|
command = params.get("command")
|
||||||
|
|
||||||
|
try:
|
||||||
|
if command == "view":
|
||||||
|
return self._view(params)
|
||||||
|
elif command == "create":
|
||||||
|
return self._create(params)
|
||||||
|
elif command == "str_replace":
|
||||||
|
return self._str_replace(params)
|
||||||
|
elif command == "insert":
|
||||||
|
return self._insert(params)
|
||||||
|
elif command == "delete":
|
||||||
|
return self._delete(params)
|
||||||
|
elif command == "rename":
|
||||||
|
return self._rename(params)
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"error": f"Unknown command: '{command}'. "
|
||||||
|
"Valid commands are: view, create, str_replace, insert, delete, rename"
|
||||||
|
}
|
||||||
|
except ValueError as e:
|
||||||
|
return {"error": str(e)}
|
||||||
|
except Exception as e:
|
||||||
|
return {"error": f"Unexpected error executing {command}: {e}"}
|
||||||
|
|
||||||
|
def _view(self, params: dict[str, Any]) -> dict[str, str]:
|
||||||
|
"""View directory contents or file contents."""
|
||||||
|
path = params.get("path")
|
||||||
|
view_range = params.get("view_range")
|
||||||
|
|
||||||
|
if not path:
|
||||||
|
return {"error": "Missing required parameter: path"}
|
||||||
|
|
||||||
|
full_path = self._validate_path(path)
|
||||||
|
|
||||||
|
# Handle directory listing
|
||||||
|
if full_path.is_dir():
|
||||||
|
try:
|
||||||
|
items = []
|
||||||
|
for item in sorted(full_path.iterdir()):
|
||||||
|
if item.name.startswith("."):
|
||||||
|
continue
|
||||||
|
items.append(f"{item.name}/" if item.is_dir() else item.name)
|
||||||
|
|
||||||
|
if not items:
|
||||||
|
return {"success": f"Directory: {path}\n(empty)"}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": f"Directory: {path}\n"
|
||||||
|
+ "\n".join([f"- {item}" for item in items])
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
return {"error": f"Cannot read directory {path}: {e}"}
|
||||||
|
|
||||||
|
# Handle file reading
|
||||||
|
elif full_path.is_file():
|
||||||
|
try:
|
||||||
|
content = full_path.read_text(encoding="utf-8")
|
||||||
|
lines = content.splitlines()
|
||||||
|
|
||||||
|
# Apply view range if specified
|
||||||
|
if view_range:
|
||||||
|
start_line = max(1, view_range[0]) - 1 # Convert to 0-indexed
|
||||||
|
end_line = len(lines) if view_range[1] == -1 else view_range[1]
|
||||||
|
lines = lines[start_line:end_line]
|
||||||
|
start_num = start_line + 1
|
||||||
|
else:
|
||||||
|
start_num = 1
|
||||||
|
|
||||||
|
# Format with line numbers
|
||||||
|
numbered_lines = [
|
||||||
|
f"{i + start_num:4d}: {line}" for i, line in enumerate(lines)
|
||||||
|
]
|
||||||
|
return {"success": "\n".join(numbered_lines)}
|
||||||
|
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
return {"error": f"Cannot read {path}: File is not valid UTF-8 text"}
|
||||||
|
except Exception as e:
|
||||||
|
return {"error": f"Cannot read file {path}: {e}"}
|
||||||
|
|
||||||
|
else:
|
||||||
|
return {"error": f"Path not found: {path}"}
|
||||||
|
|
||||||
|
def _create(self, params: dict[str, Any]) -> dict[str, str]:
|
||||||
|
"""Create or overwrite a file."""
|
||||||
|
path = params.get("path")
|
||||||
|
file_text = params.get("file_text", "")
|
||||||
|
|
||||||
|
if not path:
|
||||||
|
return {"error": "Missing required parameter: path"}
|
||||||
|
|
||||||
|
full_path = self._validate_path(path)
|
||||||
|
|
||||||
|
# Don't allow creating directories directly
|
||||||
|
if not path.endswith((".txt", ".md", ".json", ".py", ".yaml", ".yml")):
|
||||||
|
return {
|
||||||
|
"error": f"Cannot create {path}: Only text files are supported. "
|
||||||
|
"Use file extensions: .txt, .md, .json, .py, .yaml, .yml"
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Create parent directories if needed
|
||||||
|
full_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Write the file
|
||||||
|
full_path.write_text(file_text, encoding="utf-8")
|
||||||
|
return {"success": f"File created successfully at {path}"}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {"error": f"Cannot create file {path}: {e}"}
|
||||||
|
|
||||||
|
def _str_replace(self, params: dict[str, Any]) -> dict[str, str]:
|
||||||
|
"""Replace text in a file."""
|
||||||
|
path = params.get("path")
|
||||||
|
old_str = params.get("old_str")
|
||||||
|
new_str = params.get("new_str", "")
|
||||||
|
|
||||||
|
if not path or old_str is None:
|
||||||
|
return {"error": "Missing required parameters: path, old_str"}
|
||||||
|
|
||||||
|
full_path = self._validate_path(path)
|
||||||
|
|
||||||
|
if not full_path.is_file():
|
||||||
|
return {"error": f"File not found: {path}"}
|
||||||
|
|
||||||
|
try:
|
||||||
|
content = full_path.read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
# Check if old_str exists
|
||||||
|
count = content.count(old_str)
|
||||||
|
if count == 0:
|
||||||
|
return {
|
||||||
|
"error": f"String not found in {path}. "
|
||||||
|
"The exact text must exist in the file."
|
||||||
|
}
|
||||||
|
elif count > 1:
|
||||||
|
return {
|
||||||
|
"error": f"String appears {count} times in {path}. "
|
||||||
|
"The string must be unique. Use more specific context."
|
||||||
|
}
|
||||||
|
|
||||||
|
# Perform replacement
|
||||||
|
new_content = content.replace(old_str, new_str, 1)
|
||||||
|
full_path.write_text(new_content, encoding="utf-8")
|
||||||
|
|
||||||
|
return {"success": f"File {path} has been edited successfully"}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {"error": f"Cannot edit file {path}: {e}"}
|
||||||
|
|
||||||
|
def _insert(self, params: dict[str, Any]) -> dict[str, str]:
|
||||||
|
"""Insert text at a specific line."""
|
||||||
|
path = params.get("path")
|
||||||
|
insert_line = params.get("insert_line")
|
||||||
|
insert_text = params.get("insert_text", "")
|
||||||
|
|
||||||
|
if not path or insert_line is None:
|
||||||
|
return {"error": "Missing required parameters: path, insert_line"}
|
||||||
|
|
||||||
|
full_path = self._validate_path(path)
|
||||||
|
|
||||||
|
if not full_path.is_file():
|
||||||
|
return {"error": f"File not found: {path}"}
|
||||||
|
|
||||||
|
try:
|
||||||
|
lines = full_path.read_text(encoding="utf-8").splitlines()
|
||||||
|
|
||||||
|
# Validate insert_line
|
||||||
|
if insert_line < 0 or insert_line > len(lines):
|
||||||
|
return {
|
||||||
|
"error": f"Invalid insert_line {insert_line}. "
|
||||||
|
f"Must be between 0 and {len(lines)}"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Insert the text
|
||||||
|
lines.insert(insert_line, insert_text.rstrip("\n"))
|
||||||
|
|
||||||
|
# Write back
|
||||||
|
full_path.write_text("\n".join(lines) + "\n", encoding="utf-8")
|
||||||
|
|
||||||
|
return {"success": f"Text inserted at line {insert_line} in {path}"}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {"error": f"Cannot insert into {path}: {e}"}
|
||||||
|
|
||||||
|
def _delete(self, params: dict[str, Any]) -> dict[str, str]:
|
||||||
|
"""Delete a file or directory."""
|
||||||
|
path = params.get("path")
|
||||||
|
|
||||||
|
if not path:
|
||||||
|
return {"error": "Missing required parameter: path"}
|
||||||
|
|
||||||
|
# Prevent deletion of root memories directory
|
||||||
|
if path == "/memories":
|
||||||
|
return {"error": "Cannot delete the /memories directory itself"}
|
||||||
|
|
||||||
|
full_path = self._validate_path(path)
|
||||||
|
|
||||||
|
# Verify the path is within /memories to prevent accidental deletion outside the memory directory
|
||||||
|
# This provides an additional safety check beyond _validate_path
|
||||||
|
try:
|
||||||
|
full_path.relative_to(self.memory_root.resolve())
|
||||||
|
except ValueError:
|
||||||
|
return {
|
||||||
|
"error": f"Invalid operation: Path '{path}' is not within /memories directory. "
|
||||||
|
"Only paths within /memories can be deleted."
|
||||||
|
}
|
||||||
|
|
||||||
|
if not full_path.exists():
|
||||||
|
return {"error": f"Path not found: {path}"}
|
||||||
|
|
||||||
|
try:
|
||||||
|
if full_path.is_file():
|
||||||
|
full_path.unlink()
|
||||||
|
return {"success": f"File deleted: {path}"}
|
||||||
|
elif full_path.is_dir():
|
||||||
|
shutil.rmtree(full_path)
|
||||||
|
return {"success": f"Directory deleted: {path}"}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {"error": f"Cannot delete {path}: {e}"}
|
||||||
|
|
||||||
|
def _rename(self, params: dict[str, Any]) -> dict[str, str]:
|
||||||
|
"""Rename or move a file/directory."""
|
||||||
|
old_path = params.get("old_path")
|
||||||
|
new_path = params.get("new_path")
|
||||||
|
|
||||||
|
if not old_path or not new_path:
|
||||||
|
return {"error": "Missing required parameters: old_path, new_path"}
|
||||||
|
|
||||||
|
old_full_path = self._validate_path(old_path)
|
||||||
|
new_full_path = self._validate_path(new_path)
|
||||||
|
|
||||||
|
if not old_full_path.exists():
|
||||||
|
return {"error": f"Source path not found: {old_path}"}
|
||||||
|
|
||||||
|
if new_full_path.exists():
|
||||||
|
return {
|
||||||
|
"error": f"Destination already exists: {new_path}. "
|
||||||
|
"Cannot overwrite existing files/directories."
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Create parent directories if needed
|
||||||
|
new_full_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Perform rename/move
|
||||||
|
old_full_path.rename(new_full_path)
|
||||||
|
|
||||||
|
return {"success": f"Renamed {old_path} to {new_path}"}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {"error": f"Cannot rename {old_path} to {new_path}: {e}"}
|
||||||
|
|
||||||
|
def clear_all_memory(self) -> dict[str, str]:
|
||||||
|
"""
|
||||||
|
Clear all memory files (useful for testing or starting fresh).
|
||||||
|
|
||||||
|
⚠️ WARNING: This method is for demonstration and testing purposes only.
|
||||||
|
In production, you should carefully consider whether you need to delete
|
||||||
|
all memory files, as this will permanently remove all learned patterns
|
||||||
|
and stored knowledge. Consider using selective deletion instead.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with success message
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if self.memory_root.exists():
|
||||||
|
shutil.rmtree(self.memory_root)
|
||||||
|
self.memory_root.mkdir(parents=True, exist_ok=True)
|
||||||
|
return {"success": "All memory cleared successfully"}
|
||||||
|
except Exception as e:
|
||||||
|
return {"error": f"Cannot clear memory: {e}"}
|
||||||
3
tool_use/requirements.txt
Normal file
3
tool_use/requirements.txt
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
anthropic>=0.18.0
|
||||||
|
python-dotenv>=1.0.0
|
||||||
|
ipykernel>=6.29.0 # For Jupyter in VSCode
|
||||||
426
tool_use/tests/test_memory_tool.py
Normal file
426
tool_use/tests/test_memory_tool.py
Normal file
@@ -0,0 +1,426 @@
|
|||||||
|
"""
|
||||||
|
Unit tests for the memory tool handler.
|
||||||
|
|
||||||
|
Tests security validation, command execution, and error handling.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import shutil
|
||||||
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from memory_tool import MemoryToolHandler
|
||||||
|
|
||||||
|
|
||||||
|
class TestMemoryToolHandler(unittest.TestCase):
|
||||||
|
"""Test suite for MemoryToolHandler."""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
"""Create temporary directory for each test."""
|
||||||
|
self.test_dir = tempfile.mkdtemp()
|
||||||
|
self.handler = MemoryToolHandler(base_path=self.test_dir)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
"""Clean up temporary directory after each test."""
|
||||||
|
shutil.rmtree(self.test_dir)
|
||||||
|
|
||||||
|
# Security Tests
|
||||||
|
|
||||||
|
def test_path_validation_requires_memories_prefix(self):
|
||||||
|
"""Test that paths must start with /memories."""
|
||||||
|
result = self.handler.execute(command="view", path="/etc/passwd")
|
||||||
|
self.assertIn("error", result)
|
||||||
|
self.assertIn("must start with /memories", result["error"])
|
||||||
|
|
||||||
|
def test_path_validation_prevents_traversal_dotdot(self):
|
||||||
|
"""Test that .. traversal is blocked."""
|
||||||
|
result = self.handler.execute(
|
||||||
|
command="view", path="/memories/../../../etc/passwd"
|
||||||
|
)
|
||||||
|
self.assertIn("error", result)
|
||||||
|
self.assertIn("escape", result["error"].lower())
|
||||||
|
|
||||||
|
def test_path_validation_prevents_traversal_encoded(self):
|
||||||
|
"""Test that URL-encoded traversal is blocked."""
|
||||||
|
result = self.handler.execute(
|
||||||
|
command="view", path="/memories/%2e%2e/%2e%2e/etc/passwd"
|
||||||
|
)
|
||||||
|
# The path will be processed and should fail validation
|
||||||
|
self.assertIn("error", result)
|
||||||
|
|
||||||
|
def test_path_validation_allows_valid_paths(self):
|
||||||
|
"""Test that valid memory paths are accepted."""
|
||||||
|
result = self.handler.execute(
|
||||||
|
command="create", path="/memories/test.txt", file_text="test"
|
||||||
|
)
|
||||||
|
self.assertIn("success", result)
|
||||||
|
|
||||||
|
# View Command Tests
|
||||||
|
|
||||||
|
def test_view_empty_directory(self):
|
||||||
|
"""Test viewing an empty /memories directory."""
|
||||||
|
result = self.handler.execute(command="view", path="/memories")
|
||||||
|
self.assertIn("success", result)
|
||||||
|
self.assertIn("empty", result["success"].lower())
|
||||||
|
|
||||||
|
def test_view_directory_with_files(self):
|
||||||
|
"""Test viewing a directory with files."""
|
||||||
|
# Create some test files
|
||||||
|
self.handler.execute(
|
||||||
|
command="create", path="/memories/file1.txt", file_text="content1"
|
||||||
|
)
|
||||||
|
self.handler.execute(
|
||||||
|
command="create", path="/memories/file2.txt", file_text="content2"
|
||||||
|
)
|
||||||
|
|
||||||
|
result = self.handler.execute(command="view", path="/memories")
|
||||||
|
self.assertIn("success", result)
|
||||||
|
self.assertIn("file1.txt", result["success"])
|
||||||
|
self.assertIn("file2.txt", result["success"])
|
||||||
|
|
||||||
|
def test_view_file_with_line_numbers(self):
|
||||||
|
"""Test viewing a file with line numbers."""
|
||||||
|
content = "line 1\nline 2\nline 3"
|
||||||
|
self.handler.execute(
|
||||||
|
command="create", path="/memories/test.txt", file_text=content
|
||||||
|
)
|
||||||
|
|
||||||
|
result = self.handler.execute(command="view", path="/memories/test.txt")
|
||||||
|
self.assertIn("success", result)
|
||||||
|
self.assertIn(" 1: line 1", result["success"])
|
||||||
|
self.assertIn(" 2: line 2", result["success"])
|
||||||
|
self.assertIn(" 3: line 3", result["success"])
|
||||||
|
|
||||||
|
def test_view_file_with_range(self):
|
||||||
|
"""Test viewing specific line range."""
|
||||||
|
content = "line 1\nline 2\nline 3\nline 4"
|
||||||
|
self.handler.execute(
|
||||||
|
command="create", path="/memories/test.txt", file_text=content
|
||||||
|
)
|
||||||
|
|
||||||
|
result = self.handler.execute(
|
||||||
|
command="view", path="/memories/test.txt", view_range=[2, 3]
|
||||||
|
)
|
||||||
|
self.assertIn("success", result)
|
||||||
|
self.assertIn(" 2: line 2", result["success"])
|
||||||
|
self.assertIn(" 3: line 3", result["success"])
|
||||||
|
self.assertNotIn("line 1", result["success"])
|
||||||
|
self.assertNotIn("line 4", result["success"])
|
||||||
|
|
||||||
|
def test_view_nonexistent_path(self):
|
||||||
|
"""Test viewing a nonexistent path."""
|
||||||
|
result = self.handler.execute(command="view", path="/memories/notfound.txt")
|
||||||
|
self.assertIn("error", result)
|
||||||
|
self.assertIn("not found", result["error"].lower())
|
||||||
|
|
||||||
|
# Create Command Tests
|
||||||
|
|
||||||
|
def test_create_file(self):
|
||||||
|
"""Test creating a file."""
|
||||||
|
result = self.handler.execute(
|
||||||
|
command="create", path="/memories/test.txt", file_text="Hello, World!"
|
||||||
|
)
|
||||||
|
self.assertIn("success", result)
|
||||||
|
|
||||||
|
# Verify file exists
|
||||||
|
file_path = Path(self.test_dir) / "memories" / "test.txt"
|
||||||
|
self.assertTrue(file_path.exists())
|
||||||
|
self.assertEqual(file_path.read_text(), "Hello, World!")
|
||||||
|
|
||||||
|
def test_create_file_in_subdirectory(self):
|
||||||
|
"""Test creating a file in a subdirectory (auto-creates dirs)."""
|
||||||
|
result = self.handler.execute(
|
||||||
|
command="create",
|
||||||
|
path="/memories/subdir/test.txt",
|
||||||
|
file_text="Nested content",
|
||||||
|
)
|
||||||
|
self.assertIn("success", result)
|
||||||
|
|
||||||
|
file_path = Path(self.test_dir) / "memories" / "subdir" / "test.txt"
|
||||||
|
self.assertTrue(file_path.exists())
|
||||||
|
|
||||||
|
def test_create_requires_file_extension(self):
    """create should reject a path that lacks a text-file extension."""
    response = self.handler.execute(
        command="create", path="/memories/noext", file_text="content"
    )

    self.assertIn("error", response)
    self.assertIn("text files are supported", response["error"])
def test_create_overwrites_existing_file(self):
    """A second create on the same path should replace the contents."""
    self.handler.execute(
        command="create", path="/memories/test.txt", file_text="original"
    )

    response = self.handler.execute(
        command="create", path="/memories/test.txt", file_text="updated"
    )
    self.assertIn("success", response)

    # The newer write wins.
    target = Path(self.test_dir) / "memories" / "test.txt"
    self.assertEqual(target.read_text(), "updated")
# String Replace Command Tests
|
||||||
|
|
||||||
|
def test_str_replace_success(self):
    """str_replace should swap a unique substring in place."""
    self.handler.execute(
        command="create",
        path="/memories/test.txt",
        file_text="Hello World",
    )

    response = self.handler.execute(
        command="str_replace",
        path="/memories/test.txt",
        old_str="World",
        new_str="Universe",
    )
    self.assertIn("success", response)

    # The substitution is reflected on disk.
    target = Path(self.test_dir) / "memories" / "test.txt"
    self.assertEqual(target.read_text(), "Hello Universe")
def test_str_replace_string_not_found(self):
    """str_replace should error when old_str is absent from the file."""
    self.handler.execute(
        command="create", path="/memories/test.txt", file_text="Hello World"
    )

    response = self.handler.execute(
        command="str_replace",
        path="/memories/test.txt",
        old_str="Missing",
        new_str="Text",
    )

    self.assertIn("error", response)
    self.assertIn("not found", response["error"].lower())
def test_str_replace_multiple_occurrences(self):
    """str_replace should refuse an ambiguous (non-unique) old_str."""
    self.handler.execute(
        command="create",
        path="/memories/test.txt",
        file_text="Hello World Hello World",
    )

    response = self.handler.execute(
        command="str_replace",
        path="/memories/test.txt",
        old_str="Hello",
        new_str="Hi",
    )

    # The handler reports how many times the string occurs.
    self.assertIn("error", response)
    self.assertIn("appears 2 times", response["error"])
def test_str_replace_file_not_found(self):
    """str_replace on a missing file should report a not-found error."""
    response = self.handler.execute(
        command="str_replace",
        path="/memories/notfound.txt",
        old_str="old",
        new_str="new",
    )

    self.assertIn("error", response)
    self.assertIn("not found", response["error"].lower())
# Insert Command Tests
|
||||||
|
|
||||||
|
def test_insert_at_beginning(self):
    """insert at line 0 should prepend the text before all existing lines."""
    self.handler.execute(
        command="create", path="/memories/test.txt", file_text="line 1\nline 2"
    )

    response = self.handler.execute(
        command="insert",
        path="/memories/test.txt",
        insert_line=0,
        insert_text="new line",
    )
    self.assertIn("success", response)

    # Note: insert normalizes the file to end with a trailing newline.
    target = Path(self.test_dir) / "memories" / "test.txt"
    self.assertEqual(target.read_text(), "new line\nline 1\nline 2\n")
def test_insert_in_middle(self):
    """insert at line 1 should place the text between existing lines."""
    self.handler.execute(
        command="create", path="/memories/test.txt", file_text="line 1\nline 2"
    )

    response = self.handler.execute(
        command="insert",
        path="/memories/test.txt",
        insert_line=1,
        insert_text="inserted",
    )
    self.assertIn("success", response)

    # Note: insert normalizes the file to end with a trailing newline.
    target = Path(self.test_dir) / "memories" / "test.txt"
    self.assertEqual(target.read_text(), "line 1\ninserted\nline 2\n")
def test_insert_at_end(self):
    """insert at the last line index should append and succeed."""
    self.handler.execute(
        command="create", path="/memories/test.txt", file_text="line 1\nline 2"
    )

    response = self.handler.execute(
        command="insert",
        path="/memories/test.txt",
        insert_line=2,
        insert_text="last line",
    )

    self.assertIn("success", response)
def test_insert_invalid_line(self):
    """insert past the end of the file should report an invalid line error."""
    self.handler.execute(
        command="create", path="/memories/test.txt", file_text="line 1"
    )

    # Line 99 is far beyond the single-line file.
    response = self.handler.execute(
        command="insert",
        path="/memories/test.txt",
        insert_line=99,
        insert_text="text",
    )

    self.assertIn("error", response)
    self.assertIn("invalid", response["error"].lower())
# Delete Command Tests
|
||||||
|
|
||||||
|
def test_delete_file(self):
    """delete should remove an existing file from disk."""
    self.handler.execute(
        command="create", path="/memories/test.txt", file_text="content"
    )

    response = self.handler.execute(command="delete", path="/memories/test.txt")
    self.assertIn("success", response)

    removed = Path(self.test_dir) / "memories" / "test.txt"
    self.assertFalse(removed.exists())
def test_delete_directory(self):
    """delete on a directory should remove it along with its contents."""
    self.handler.execute(
        command="create", path="/memories/subdir/test.txt", file_text="content"
    )

    response = self.handler.execute(command="delete", path="/memories/subdir")
    self.assertIn("success", response)

    removed = Path(self.test_dir) / "memories" / "subdir"
    self.assertFalse(removed.exists())
def test_delete_cannot_delete_root(self):
    """delete must refuse to remove the /memories root itself."""
    response = self.handler.execute(command="delete", path="/memories")

    self.assertIn("error", response)
    self.assertIn("cannot delete", response["error"].lower())
def test_delete_nonexistent_path(self):
    """delete on a missing path should report a not-found error."""
    response = self.handler.execute(command="delete", path="/memories/notfound.txt")

    self.assertIn("error", response)
    self.assertIn("not found", response["error"].lower())
# Rename Command Tests
|
||||||
|
|
||||||
|
def test_rename_file(self):
    """rename should move a file so only the new path exists afterwards."""
    self.handler.execute(
        command="create", path="/memories/old.txt", file_text="content"
    )

    response = self.handler.execute(
        command="rename", old_path="/memories/old.txt", new_path="/memories/new.txt"
    )
    self.assertIn("success", response)

    memories = Path(self.test_dir) / "memories"
    self.assertFalse((memories / "old.txt").exists())
    self.assertTrue((memories / "new.txt").exists())
def test_rename_to_subdirectory(self):
    """rename should be able to move a file into a subdirectory."""
    self.handler.execute(
        command="create", path="/memories/file.txt", file_text="content"
    )

    response = self.handler.execute(
        command="rename",
        old_path="/memories/file.txt",
        new_path="/memories/subdir/file.txt",
    )
    self.assertIn("success", response)

    moved = Path(self.test_dir) / "memories" / "subdir" / "file.txt"
    self.assertTrue(moved.exists())
def test_rename_source_not_found(self):
    """rename should error when the source path does not exist."""
    response = self.handler.execute(
        command="rename",
        old_path="/memories/notfound.txt",
        new_path="/memories/new.txt",
    )

    self.assertIn("error", response)
    self.assertIn("not found", response["error"].lower())
def test_rename_destination_exists(self):
    """rename should refuse to clobber an existing destination file."""
    # Both source and destination already exist.
    self.handler.execute(
        command="create", path="/memories/file1.txt", file_text="content1"
    )
    self.handler.execute(
        command="create", path="/memories/file2.txt", file_text="content2"
    )

    response = self.handler.execute(
        command="rename",
        old_path="/memories/file1.txt",
        new_path="/memories/file2.txt",
    )

    self.assertIn("error", response)
    self.assertIn("already exists", response["error"].lower())
# Error Handling Tests
|
||||||
|
|
||||||
|
def test_unknown_command(self):
    """An unrecognized command should produce an 'unknown command' error."""
    response = self.handler.execute(command="invalid", path="/memories")

    self.assertIn("error", response)
    self.assertIn("unknown command", response["error"].lower())
def test_missing_required_parameters(self):
    """A command invoked without its required parameters should error."""
    # view requires a path; omit it deliberately.
    response = self.handler.execute(command="view")
    self.assertIn("error", response)
# Utility Tests
|
||||||
|
|
||||||
|
def test_clear_all_memory(self):
    """clear_all_memory should empty /memories but keep the directory itself."""
    # Populate memory with a couple of files to wipe.
    self.handler.execute(
        command="create", path="/memories/file1.txt", file_text="content1"
    )
    self.handler.execute(
        command="create", path="/memories/file2.txt", file_text="content2"
    )

    response = self.handler.clear_all_memory()
    self.assertIn("success", response)

    # The root survives the wipe, but holds no entries.
    memory_root = Path(self.test_dir) / "memories"
    self.assertTrue(memory_root.exists())
    self.assertEqual(len(list(memory_root.iterdir())), 0)
|
# Allow running this test module directly (e.g. `python test_memory_tool.py`).
if __name__ == "__main__":
    unittest.main()
Reference in New Issue
Block a user