mirror of
https://github.com/anthropics/claude-agent-sdk-python.git
synced 2025-10-06 01:00:03 +03:00
Initial hooks implementation
🏠 Remote-Dev: homespace
This commit is contained in:
@@ -14,16 +14,30 @@ from .types import (
|
||||
AssistantMessage,
|
||||
ClaudeCodeOptions,
|
||||
ContentBlock,
|
||||
HookCallback,
|
||||
HookCallbackMatcher,
|
||||
HookEvent,
|
||||
HookInput,
|
||||
HookJSONOutput,
|
||||
McpServerConfig,
|
||||
Message,
|
||||
NotificationHookInput,
|
||||
PermissionMode,
|
||||
PostToolUseHookInput,
|
||||
PreCompactHookInput,
|
||||
PreToolUseHookInput,
|
||||
ResultMessage,
|
||||
SessionEndHookInput,
|
||||
SessionStartHookInput,
|
||||
StopHookInput,
|
||||
SubagentStopHookInput,
|
||||
SystemMessage,
|
||||
TextBlock,
|
||||
ThinkingBlock,
|
||||
ToolResultBlock,
|
||||
ToolUseBlock,
|
||||
UserMessage,
|
||||
UserPromptSubmitHookInput,
|
||||
)
|
||||
|
||||
__version__ = "0.0.20"
|
||||
@@ -48,6 +62,21 @@ __all__ = [
|
||||
"ToolUseBlock",
|
||||
"ToolResultBlock",
|
||||
"ContentBlock",
|
||||
# Hook types
|
||||
"HookEvent",
|
||||
"HookCallback",
|
||||
"HookCallbackMatcher",
|
||||
"HookInput",
|
||||
"HookJSONOutput",
|
||||
"PreToolUseHookInput",
|
||||
"PostToolUseHookInput",
|
||||
"NotificationHookInput",
|
||||
"UserPromptSubmitHookInput",
|
||||
"SessionStartHookInput",
|
||||
"SessionEndHookInput",
|
||||
"StopHookInput",
|
||||
"SubagentStopHookInput",
|
||||
"PreCompactHookInput",
|
||||
# Errors
|
||||
"ClaudeSDKError",
|
||||
"CLIConnectionError",
|
||||
|
||||
143
src/claude_code_sdk/_internal/control_protocol.py
Normal file
143
src/claude_code_sdk/_internal/control_protocol.py
Normal file
@@ -0,0 +1,143 @@
|
||||
"""Control protocol types and handlers for SDK-CLI communication."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Literal
|
||||
|
||||
from ..types import HookEvent, PermissionMode
|
||||
|
||||
|
||||
@dataclass
|
||||
class SDKHookCallbackMatcher:
|
||||
"""Hook callback matcher for control protocol."""
|
||||
|
||||
matcher: str | None = None
|
||||
hook_callback_ids: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SDKControlInitializeRequest:
|
||||
"""Initialize request for control protocol."""
|
||||
|
||||
subtype: Literal["initialize"]
|
||||
hooks: dict[HookEvent, list[SDKHookCallbackMatcher]] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class SDKControlInterruptRequest:
|
||||
"""Interrupt request for control protocol."""
|
||||
|
||||
subtype: Literal["interrupt"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class SDKControlPermissionRequest:
|
||||
"""Permission request for control protocol."""
|
||||
|
||||
subtype: Literal["can_use_tool"]
|
||||
tool_name: str
|
||||
input: dict[str, Any]
|
||||
permission_suggestions: list[dict[str, Any]] | None = None
|
||||
blocked_path: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class SDKControlSetPermissionModeRequest:
|
||||
"""Set permission mode request for control protocol."""
|
||||
|
||||
subtype: Literal["set_permission_mode"]
|
||||
mode: PermissionMode
|
||||
|
||||
|
||||
@dataclass
|
||||
class SDKHookCallbackRequest:
|
||||
"""Hook callback request for control protocol."""
|
||||
|
||||
subtype: Literal["hook_callback"]
|
||||
callback_id: str
|
||||
input: dict[str, Any]
|
||||
tool_use_id: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class SDKControlMcpMessageRequest:
|
||||
"""MCP message request for control protocol."""
|
||||
|
||||
subtype: Literal["mcp_message"]
|
||||
server_name: str
|
||||
message: dict[str, Any]
|
||||
|
||||
|
||||
# Union type for all control request subtypes
|
||||
SDKControlRequestSubtype = (
|
||||
SDKControlInitializeRequest
|
||||
| SDKControlInterruptRequest
|
||||
| SDKControlPermissionRequest
|
||||
| SDKControlSetPermissionModeRequest
|
||||
| SDKHookCallbackRequest
|
||||
| SDKControlMcpMessageRequest
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SDKControlRequest:
|
||||
"""Control request wrapper."""
|
||||
|
||||
type: Literal["control_request"]
|
||||
request_id: str
|
||||
request: SDKControlRequestSubtype
|
||||
|
||||
|
||||
@dataclass
|
||||
class SDKControlCancelRequest:
|
||||
"""Cancel a control request."""
|
||||
|
||||
type: Literal["control_cancel_request"]
|
||||
request_id: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class ControlSuccessResponse:
|
||||
"""Successful control response."""
|
||||
|
||||
subtype: Literal["success"]
|
||||
request_id: str
|
||||
response: dict[str, Any] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ControlErrorResponse:
|
||||
"""Error control response."""
|
||||
|
||||
subtype: Literal["error"]
|
||||
request_id: str
|
||||
error: str
|
||||
|
||||
|
||||
# Union type for control responses
|
||||
ControlResponse = ControlSuccessResponse | ControlErrorResponse
|
||||
|
||||
|
||||
@dataclass
|
||||
class SDKControlResponse:
|
||||
"""Control response wrapper."""
|
||||
|
||||
type: Literal["control_response"]
|
||||
response: ControlResponse
|
||||
|
||||
|
||||
@dataclass
|
||||
class Command:
|
||||
"""Command metadata."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
argument_hint: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class SDKControlInitializeResponse:
|
||||
"""Initialize response from control protocol."""
|
||||
|
||||
commands: list[Command]
|
||||
output_style: str
|
||||
available_output_styles: list[str]
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Subprocess transport implementation using Claude Code CLI."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
@@ -17,7 +18,21 @@ from anyio.streams.text import TextReceiveStream, TextSendStream
|
||||
|
||||
from ..._errors import CLIConnectionError, CLINotFoundError, ProcessError
|
||||
from ..._errors import CLIJSONDecodeError as SDKJSONDecodeError
|
||||
from ...types import ClaudeCodeOptions
|
||||
from ...types import (
|
||||
ClaudeCodeOptions,
|
||||
HookCallback,
|
||||
HookEvent,
|
||||
HookInput,
|
||||
HookJSONOutput,
|
||||
)
|
||||
from ..control_protocol import (
|
||||
ControlErrorResponse,
|
||||
ControlSuccessResponse,
|
||||
SDKControlInitializeRequest,
|
||||
SDKControlInitializeResponse,
|
||||
SDKControlResponse,
|
||||
SDKHookCallbackMatcher,
|
||||
)
|
||||
from . import Transport
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -50,6 +65,13 @@ class SubprocessCLITransport(Transport):
|
||||
self._task_group: anyio.abc.TaskGroup | None = None
|
||||
self._stderr_file: Any = None # tempfile.NamedTemporaryFile
|
||||
|
||||
# Hooks support
|
||||
self._hook_callbacks: dict[str, HookCallback] = {}
|
||||
self._next_callback_id = 0
|
||||
self._cancel_controllers: dict[str, asyncio.Event] = {}
|
||||
self._initialized = False
|
||||
self._hooks_enabled = options.hooks is not None
|
||||
|
||||
def _find_cli(self) -> str:
|
||||
"""Find Claude Code CLI binary."""
|
||||
if cli := shutil.which("claude"):
|
||||
@@ -202,7 +224,13 @@ class SubprocessCLITransport(Transport):
|
||||
# Start streaming messages to stdin in background
|
||||
self._task_group = anyio.create_task_group()
|
||||
await self._task_group.__aenter__()
|
||||
|
||||
# Initialize hooks if enabled
|
||||
if self._hooks_enabled:
|
||||
await self._initialize_hooks()
|
||||
|
||||
self._task_group.start_soon(self._stream_to_stdin)
|
||||
self._task_group.start_soon(self._handle_control_requests)
|
||||
else:
|
||||
# String mode: close stdin immediately (backward compatible)
|
||||
if self._process.stdin:
|
||||
@@ -334,7 +362,7 @@ class SubprocessCLITransport(Transport):
|
||||
data = json.loads(json_buffer)
|
||||
json_buffer = ""
|
||||
|
||||
# Handle control responses separately
|
||||
# Handle control messages separately
|
||||
if data.get("type") == "control_response":
|
||||
response = data.get("response", {})
|
||||
request_id = response.get("request_id")
|
||||
@@ -343,6 +371,20 @@ class SubprocessCLITransport(Transport):
|
||||
self._pending_control_responses[request_id] = response
|
||||
continue
|
||||
|
||||
# Handle control requests (for hooks)
|
||||
if data.get("type") == "control_request":
|
||||
# Queue control request for processing
|
||||
if hasattr(self, "_control_request_queue"):
|
||||
await self._control_request_queue.put(data)
|
||||
continue
|
||||
|
||||
# Handle control cancel requests
|
||||
if data.get("type") == "control_cancel_request":
|
||||
request_id = data.get("request_id")
|
||||
if request_id in self._cancel_controllers:
|
||||
self._cancel_controllers[request_id].set()
|
||||
continue
|
||||
|
||||
try:
|
||||
yield data
|
||||
except GeneratorExit:
|
||||
@@ -444,3 +486,210 @@ class SubprocessCLITransport(Transport):
|
||||
raise CLIConnectionError(f"Control request failed: {response.get('error')}")
|
||||
|
||||
return response
|
||||
|
||||
async def _initialize_hooks(self) -> None:
|
||||
"""Initialize hooks with the CLI."""
|
||||
if not self._options.hooks:
|
||||
return
|
||||
|
||||
# Convert hooks to SDK format
|
||||
sdk_hooks: dict[HookEvent, list[SDKHookCallbackMatcher]] = {}
|
||||
|
||||
for event, matchers in self._options.hooks.items():
|
||||
sdk_matchers = []
|
||||
for matcher in matchers:
|
||||
callback_ids = []
|
||||
for callback in matcher.hooks:
|
||||
callback_id = f"hook_{self._next_callback_id}"
|
||||
self._next_callback_id += 1
|
||||
self._hook_callbacks[callback_id] = callback
|
||||
callback_ids.append(callback_id)
|
||||
|
||||
sdk_matchers.append(
|
||||
SDKHookCallbackMatcher(
|
||||
matcher=matcher.matcher, hook_callback_ids=callback_ids
|
||||
)
|
||||
)
|
||||
sdk_hooks[event] = sdk_matchers
|
||||
|
||||
# Send initialize request
|
||||
# sdk_hooks is dict[HookEvent, list[SDKHookCallbackMatcher]]
|
||||
# but protocol expects dict[str, list[dict]]
|
||||
hooks_dict: dict[str, Any] = {k: [m.__dict__ for m in v] for k, v in sdk_hooks.items()}
|
||||
init_request = SDKControlInitializeRequest(
|
||||
subtype="initialize",
|
||||
hooks=hooks_dict # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
response = await self._send_control_request(init_request.__dict__)
|
||||
self._initialized = True
|
||||
|
||||
# Initialize response is not returned, just stored
|
||||
# return SDKControlInitializeResponse(**response.get("response", {}))
|
||||
|
||||
async def _handle_control_requests(self) -> None:
|
||||
"""Handle incoming control requests from CLI."""
|
||||
# Create a queue for control requests
|
||||
self._control_request_queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
|
||||
|
||||
try:
|
||||
while True:
|
||||
# Wait for control request
|
||||
request_data = await self._control_request_queue.get()
|
||||
|
||||
# Process control request
|
||||
request_id = request_data.get("request_id")
|
||||
request = request_data.get("request", {})
|
||||
|
||||
if not request_id:
|
||||
continue
|
||||
|
||||
# Create cancel controller for this request
|
||||
cancel_event = asyncio.Event()
|
||||
self._cancel_controllers[request_id] = cancel_event
|
||||
|
||||
try:
|
||||
response = await self._process_control_request(
|
||||
request, cancel_event
|
||||
)
|
||||
|
||||
# Send success response
|
||||
control_response = SDKControlResponse(
|
||||
type="control_response",
|
||||
response=ControlSuccessResponse(
|
||||
subtype="success", request_id=request_id, response=response
|
||||
),
|
||||
)
|
||||
|
||||
if self._stdin_stream:
|
||||
await self._stdin_stream.send(
|
||||
json.dumps(
|
||||
control_response.__dict__, default=lambda o: o.__dict__
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
# Send error response
|
||||
control_response = SDKControlResponse(
|
||||
type="control_response",
|
||||
response=ControlErrorResponse(
|
||||
subtype="error", request_id=request_id, error=str(e)
|
||||
),
|
||||
)
|
||||
|
||||
if self._stdin_stream:
|
||||
await self._stdin_stream.send(
|
||||
json.dumps(
|
||||
control_response.__dict__, default=lambda o: o.__dict__
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
finally:
|
||||
# Clean up cancel controller
|
||||
if request_id in self._cancel_controllers:
|
||||
del self._cancel_controllers[request_id]
|
||||
|
||||
except asyncio.CancelledError:
|
||||
# Task group was cancelled
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"Error in control request handler: {e}")
|
||||
|
||||
async def _process_control_request(
|
||||
self, request: dict[str, Any], cancel_event: asyncio.Event
|
||||
) -> dict[str, Any]:
|
||||
"""Process a control request from CLI."""
|
||||
subtype = request.get("subtype")
|
||||
|
||||
if subtype == "hook_callback":
|
||||
# Handle hook callback
|
||||
callback_id = request.get("callback_id")
|
||||
if not callback_id:
|
||||
raise ValueError("Missing callback_id in hook callback request")
|
||||
input_data = request.get("input", {})
|
||||
tool_use_id = request.get("tool_use_id")
|
||||
|
||||
callback = self._hook_callbacks.get(callback_id)
|
||||
if not callback:
|
||||
raise ValueError(f"No hook callback found for ID: {callback_id}")
|
||||
|
||||
# Convert input data to appropriate HookInput type
|
||||
hook_input = self._parse_hook_input(input_data)
|
||||
|
||||
# Create options with abort signal
|
||||
options = {"signal": cancel_event}
|
||||
|
||||
# Call the hook
|
||||
result = await callback(hook_input, tool_use_id, options)
|
||||
|
||||
# Convert result to dict
|
||||
return self._hook_output_to_dict(result)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported control request subtype: {subtype}")
|
||||
|
||||
def _parse_hook_input(self, data: dict[str, Any]) -> HookInput:
|
||||
"""Parse hook input data into appropriate type."""
|
||||
from ...types import (
|
||||
NotificationHookInput,
|
||||
PostToolUseHookInput,
|
||||
PreCompactHookInput,
|
||||
PreToolUseHookInput,
|
||||
SessionEndHookInput,
|
||||
SessionStartHookInput,
|
||||
StopHookInput,
|
||||
SubagentStopHookInput,
|
||||
UserPromptSubmitHookInput,
|
||||
)
|
||||
|
||||
event_name = data.get("hook_event_name")
|
||||
if not event_name:
|
||||
raise ValueError("Missing hook_event_name in hook input")
|
||||
|
||||
# Map event names to input classes
|
||||
input_classes: dict[str, type[HookInput]] = {
|
||||
"PreToolUse": PreToolUseHookInput,
|
||||
"PostToolUse": PostToolUseHookInput,
|
||||
"Notification": NotificationHookInput,
|
||||
"UserPromptSubmit": UserPromptSubmitHookInput,
|
||||
"SessionStart": SessionStartHookInput,
|
||||
"SessionEnd": SessionEndHookInput,
|
||||
"Stop": StopHookInput,
|
||||
"SubagentStop": SubagentStopHookInput,
|
||||
"PreCompact": PreCompactHookInput,
|
||||
}
|
||||
|
||||
input_class = input_classes.get(event_name)
|
||||
if not input_class:
|
||||
raise ValueError(f"Unknown hook event: {event_name}")
|
||||
|
||||
# Create instance from dict
|
||||
return input_class(**data)
|
||||
|
||||
def _hook_output_to_dict(self, output: HookJSONOutput) -> dict[str, Any]:
|
||||
"""Convert HookJSONOutput to dict for JSON serialization."""
|
||||
result: dict[str, Any] = {}
|
||||
|
||||
# Map Python field names to JSON field names
|
||||
if output.continue_ is not None:
|
||||
result["continue"] = output.continue_
|
||||
if output.suppress_output is not None:
|
||||
result["suppressOutput"] = output.suppress_output
|
||||
if output.stop_reason is not None:
|
||||
result["stopReason"] = output.stop_reason
|
||||
if output.decision is not None:
|
||||
result["decision"] = output.decision
|
||||
if output.system_message is not None:
|
||||
result["systemMessage"] = output.system_message
|
||||
if output.permission_decision is not None:
|
||||
result["permissionDecision"] = output.permission_decision
|
||||
if output.permission_decision_reason is not None:
|
||||
result["permissionDecisionReason"] = output.permission_decision_reason
|
||||
if output.reason is not None:
|
||||
result["reason"] = output.reason
|
||||
if output.hook_specific_output is not None:
|
||||
result["hookSpecificOutput"] = output.hook_specific_output
|
||||
|
||||
return result
|
||||
|
||||
@@ -5,7 +5,13 @@ from collections.abc import AsyncIterable, AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from ._errors import CLIConnectionError
|
||||
from .types import ClaudeCodeOptions, Message, ResultMessage
|
||||
from .types import (
|
||||
ClaudeCodeOptions,
|
||||
HookCallbackMatcher,
|
||||
HookEvent,
|
||||
Message,
|
||||
ResultMessage,
|
||||
)
|
||||
|
||||
|
||||
class ClaudeSDKClient:
|
||||
@@ -88,12 +94,53 @@ class ClaudeSDKClient:
|
||||
|
||||
await client.disconnect()
|
||||
```
|
||||
|
||||
Example - With hooks for tool control:
|
||||
```python
|
||||
from claude_code_sdk import HookCallbackMatcher, PreToolUseHookInput, HookJSONOutput
|
||||
|
||||
async def pre_tool_hook(input: PreToolUseHookInput, tool_use_id, options):
|
||||
# Intercept and control tool execution
|
||||
if input.tool_name == "Bash":
|
||||
if "rm -rf" in str(input.tool_input):
|
||||
return HookJSONOutput(
|
||||
permission_decision="deny",
|
||||
reason="Dangerous command detected"
|
||||
)
|
||||
return HookJSONOutput(permission_decision="allow")
|
||||
|
||||
hooks = {
|
||||
"PreToolUse": [
|
||||
HookCallbackMatcher(hooks=[pre_tool_hook])
|
||||
]
|
||||
}
|
||||
|
||||
async with ClaudeSDKClient(hooks=hooks) as client:
|
||||
await client.query("Clean up temporary files")
|
||||
async for message in client.receive_response():
|
||||
print(message)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, options: ClaudeCodeOptions | None = None):
|
||||
"""Initialize Claude SDK client."""
|
||||
def __init__(
|
||||
self,
|
||||
options: ClaudeCodeOptions | None = None,
|
||||
hooks: dict[HookEvent, list[HookCallbackMatcher]] | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize Claude SDK client.
|
||||
|
||||
Args:
|
||||
options: Configuration options for Claude Code
|
||||
hooks: Optional dict of hook events to callback matchers for intercepting tool execution
|
||||
"""
|
||||
if options is None:
|
||||
options = ClaudeCodeOptions()
|
||||
|
||||
# Add hooks to options if provided
|
||||
if hooks is not None:
|
||||
options.hooks = hooks
|
||||
|
||||
self.options = options
|
||||
self._transport: Any | None = None
|
||||
os.environ["CLAUDE_CODE_ENTRYPOINT"] = "sdk-py-client"
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Any
|
||||
|
||||
from ._internal.client import InternalClient
|
||||
from ._internal.transport import Transport
|
||||
from .types import ClaudeCodeOptions, Message
|
||||
from .types import ClaudeCodeOptions, HookCallbackMatcher, HookEvent, Message
|
||||
|
||||
|
||||
async def query(
|
||||
@@ -14,6 +14,7 @@ async def query(
|
||||
prompt: str | AsyncIterable[dict[str, Any]],
|
||||
options: ClaudeCodeOptions | None = None,
|
||||
transport: Transport | None = None,
|
||||
hooks: dict[HookEvent, list[HookCallbackMatcher]] | None = None,
|
||||
) -> AsyncIterator[Message]:
|
||||
"""
|
||||
Query Claude Code for one-shot or unidirectional streaming interactions.
|
||||
@@ -61,6 +62,8 @@ async def query(
|
||||
transport: Optional transport implementation. If provided, this will be used
|
||||
instead of the default transport selection based on options.
|
||||
The transport will be automatically configured with the prompt and options.
|
||||
hooks: Optional dict of hook events to callback matchers. Hooks allow you to intercept
|
||||
and control tool execution. Only works in streaming mode (AsyncIterable prompt).
|
||||
|
||||
Yields:
|
||||
Messages from the conversation
|
||||
@@ -112,10 +115,45 @@ async def query(
|
||||
print(message)
|
||||
```
|
||||
|
||||
Example - With hooks (streaming mode required):
|
||||
```python
|
||||
from claude_code_sdk import query, HookCallbackMatcher, PreToolUseHookInput, HookJSONOutput
|
||||
|
||||
async def pre_tool_hook(input: PreToolUseHookInput, tool_use_id, options):
|
||||
if input.tool_name == "Edit":
|
||||
return HookJSONOutput(
|
||||
permission_decision="ask",
|
||||
reason="Edit operations require user confirmation"
|
||||
)
|
||||
return HookJSONOutput(permission_decision="allow")
|
||||
|
||||
async def prompts():
|
||||
yield {"type": "user", "message": {"role": "user", "content": "Fix the bug"}}
|
||||
|
||||
hooks = {
|
||||
"PreToolUse": [
|
||||
HookCallbackMatcher(
|
||||
matcher="Edit",
|
||||
hooks=[pre_tool_hook]
|
||||
)
|
||||
]
|
||||
}
|
||||
|
||||
async for message in query(prompt=prompts(), hooks=hooks):
|
||||
print(message)
|
||||
```
|
||||
|
||||
"""
|
||||
if options is None:
|
||||
options = ClaudeCodeOptions()
|
||||
|
||||
# Add hooks to options if provided
|
||||
if hooks is not None:
|
||||
# Hooks only work in streaming mode
|
||||
if isinstance(prompt, str):
|
||||
raise ValueError("Hooks require streaming mode (AsyncIterable prompt)")
|
||||
options.hooks = hooks
|
||||
|
||||
os.environ["CLAUDE_CODE_ENTRYPOINT"] = "sdk-py"
|
||||
|
||||
client = InternalClient()
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Type definitions for Claude SDK."""
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal, TypedDict
|
||||
@@ -118,6 +119,193 @@ class ResultMessage:
|
||||
Message = UserMessage | AssistantMessage | SystemMessage | ResultMessage
|
||||
|
||||
|
||||
# Hook events
|
||||
HookEvent = Literal[
|
||||
"PreToolUse",
|
||||
"PostToolUse",
|
||||
"Notification",
|
||||
"UserPromptSubmit",
|
||||
"SessionStart",
|
||||
"SessionEnd",
|
||||
"Stop",
|
||||
"SubagentStop",
|
||||
"PreCompact",
|
||||
]
|
||||
|
||||
|
||||
# Base hook input with common fields
|
||||
@dataclass
|
||||
class BaseHookInput:
|
||||
"""Base hook input with common fields."""
|
||||
|
||||
session_id: str
|
||||
transcript_path: str
|
||||
cwd: str
|
||||
permission_mode: str | None = None
|
||||
|
||||
|
||||
# Individual hook input types
|
||||
@dataclass
|
||||
class PreToolUseHookInput:
|
||||
"""Pre-tool use hook input."""
|
||||
|
||||
hook_event_name: Literal["PreToolUse"]
|
||||
session_id: str
|
||||
transcript_path: str
|
||||
cwd: str
|
||||
tool_name: str
|
||||
tool_input: Any
|
||||
permission_mode: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class PostToolUseHookInput:
|
||||
"""Post-tool use hook input."""
|
||||
|
||||
hook_event_name: Literal["PostToolUse"]
|
||||
session_id: str
|
||||
transcript_path: str
|
||||
cwd: str
|
||||
tool_name: str
|
||||
tool_input: Any
|
||||
tool_response: Any
|
||||
permission_mode: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class NotificationHookInput:
|
||||
"""Notification hook input."""
|
||||
|
||||
hook_event_name: Literal["Notification"]
|
||||
session_id: str
|
||||
transcript_path: str
|
||||
cwd: str
|
||||
message: str
|
||||
permission_mode: str | None = None
|
||||
title: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserPromptSubmitHookInput:
|
||||
"""User prompt submit hook input."""
|
||||
|
||||
hook_event_name: Literal["UserPromptSubmit"]
|
||||
session_id: str
|
||||
transcript_path: str
|
||||
cwd: str
|
||||
prompt: str
|
||||
permission_mode: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionStartHookInput:
|
||||
"""Session start hook input."""
|
||||
|
||||
hook_event_name: Literal["SessionStart"]
|
||||
session_id: str
|
||||
transcript_path: str
|
||||
cwd: str
|
||||
source: Literal["startup", "resume", "clear", "compact"]
|
||||
permission_mode: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionEndHookInput:
|
||||
"""Session end hook input."""
|
||||
|
||||
hook_event_name: Literal["SessionEnd"]
|
||||
session_id: str
|
||||
transcript_path: str
|
||||
cwd: str
|
||||
reason: Literal["clear", "logout", "prompt_input_exit", "other"]
|
||||
permission_mode: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class StopHookInput:
|
||||
"""Stop hook input."""
|
||||
|
||||
hook_event_name: Literal["Stop"]
|
||||
session_id: str
|
||||
transcript_path: str
|
||||
cwd: str
|
||||
stop_hook_active: bool
|
||||
permission_mode: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class SubagentStopHookInput:
|
||||
"""Subagent stop hook input."""
|
||||
|
||||
hook_event_name: Literal["SubagentStop"]
|
||||
session_id: str
|
||||
transcript_path: str
|
||||
cwd: str
|
||||
stop_hook_active: bool
|
||||
permission_mode: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class PreCompactHookInput:
|
||||
"""Pre-compact hook input."""
|
||||
|
||||
hook_event_name: Literal["PreCompact"]
|
||||
session_id: str
|
||||
transcript_path: str
|
||||
cwd: str
|
||||
trigger: Literal["manual", "auto"]
|
||||
permission_mode: str | None = None
|
||||
custom_instructions: str | None = None
|
||||
|
||||
|
||||
# Union type for all hook inputs
|
||||
HookInput = (
|
||||
PreToolUseHookInput
|
||||
| PostToolUseHookInput
|
||||
| NotificationHookInput
|
||||
| UserPromptSubmitHookInput
|
||||
| SessionStartHookInput
|
||||
| SessionEndHookInput
|
||||
| StopHookInput
|
||||
| SubagentStopHookInput
|
||||
| PreCompactHookInput
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class HookJSONOutput:
|
||||
"""Hook callback output."""
|
||||
|
||||
continue_: bool | None = None # 'continue' is reserved in Python
|
||||
suppress_output: bool | None = None
|
||||
stop_reason: str | None = None
|
||||
decision: Literal["approve", "block"] | None = None
|
||||
system_message: str | None = None
|
||||
|
||||
# PreToolUse specific
|
||||
permission_decision: Literal["allow", "deny", "ask"] | None = None
|
||||
permission_decision_reason: str | None = None
|
||||
reason: str | None = None
|
||||
|
||||
# Hook-specific outputs (for future extension)
|
||||
hook_specific_output: dict[str, Any] | None = None
|
||||
|
||||
|
||||
# Hook callback signature
|
||||
HookCallback = Callable[
|
||||
[HookInput, str | None, dict[str, Any]], # input, tool_use_id, options
|
||||
Awaitable[HookJSONOutput],
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class HookCallbackMatcher:
|
||||
"""Hook callback matcher with optional pattern matching."""
|
||||
|
||||
matcher: str | None = None
|
||||
hooks: list[HookCallback] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClaudeCodeOptions:
|
||||
"""Query options for Claude SDK."""
|
||||
@@ -141,3 +329,4 @@ class ClaudeCodeOptions:
|
||||
extra_args: dict[str, str | None] = field(
|
||||
default_factory=dict
|
||||
) # Pass arbitrary CLI flags
|
||||
hooks: dict[HookEvent, list[HookCallbackMatcher]] | None = None
|
||||
|
||||
226
tests/test_hooks.py
Normal file
226
tests/test_hooks.py
Normal file
@@ -0,0 +1,226 @@
|
||||
"""Test hooks functionality."""
|
||||
|
||||
import anyio
|
||||
import pytest
|
||||
|
||||
from claude_code_sdk import (
|
||||
HookCallbackMatcher,
|
||||
HookJSONOutput,
|
||||
PostToolUseHookInput,
|
||||
PreToolUseHookInput,
|
||||
UserPromptSubmitHookInput,
|
||||
query,
|
||||
)
|
||||
|
||||
|
||||
def test_hook_types():
|
||||
"""Test that hook types are properly defined."""
|
||||
|
||||
async def _test():
|
||||
# Test PreToolUseHookInput
|
||||
pre_tool_input = PreToolUseHookInput(
|
||||
hook_event_name="PreToolUse",
|
||||
session_id="test-session",
|
||||
transcript_path="/tmp/test",
|
||||
cwd="/home/test",
|
||||
tool_name="Edit",
|
||||
tool_input={"file": "test.py", "content": "print('hello')"},
|
||||
)
|
||||
assert pre_tool_input.hook_event_name == "PreToolUse"
|
||||
assert pre_tool_input.tool_name == "Edit"
|
||||
|
||||
# Test PostToolUseHookInput
|
||||
post_tool_input = PostToolUseHookInput(
|
||||
hook_event_name="PostToolUse",
|
||||
session_id="test-session",
|
||||
transcript_path="/tmp/test",
|
||||
cwd="/home/test",
|
||||
tool_name="Edit",
|
||||
tool_input={"file": "test.py", "content": "print('hello')"},
|
||||
tool_response={"success": True},
|
||||
)
|
||||
assert post_tool_input.hook_event_name == "PostToolUse"
|
||||
assert post_tool_input.tool_response == {"success": True}
|
||||
|
||||
# Test HookJSONOutput
|
||||
output = HookJSONOutput(
|
||||
continue_=True,
|
||||
permission_decision="allow",
|
||||
reason="Test reason",
|
||||
)
|
||||
assert output.continue_ is True
|
||||
assert output.permission_decision == "allow"
|
||||
|
||||
anyio.run(_test)
|
||||
|
||||
|
||||
def test_hook_callback_matcher():
|
||||
"""Test HookCallbackMatcher structure."""
|
||||
|
||||
async def _test():
|
||||
async def test_hook(input_data, tool_use_id, options):
|
||||
return HookJSONOutput(permission_decision="allow")
|
||||
|
||||
matcher = HookCallbackMatcher(matcher="Edit", hooks=[test_hook])
|
||||
|
||||
assert matcher.matcher == "Edit"
|
||||
assert len(matcher.hooks) == 1
|
||||
assert callable(matcher.hooks[0])
|
||||
|
||||
anyio.run(_test)
|
||||
|
||||
|
||||
def test_query_with_hooks_requires_streaming():
|
||||
"""Test that hooks require streaming mode."""
|
||||
|
||||
async def _test():
|
||||
async def test_hook(input_data, tool_use_id, options):
|
||||
return HookJSONOutput(permission_decision="allow")
|
||||
|
||||
hooks = {"PreToolUse": [HookCallbackMatcher(hooks=[test_hook])]}
|
||||
|
||||
# Should raise error with string prompt
|
||||
with pytest.raises(ValueError, match="Hooks require streaming mode"):
|
||||
async for _ in query(prompt="test", hooks=hooks):
|
||||
pass
|
||||
|
||||
anyio.run(_test)
|
||||
|
||||
|
||||
def test_hook_callback_signature():
|
||||
"""Test hook callback signature and execution."""
|
||||
|
||||
async def _test():
|
||||
hook_called = False
|
||||
received_input = None
|
||||
received_tool_id = None
|
||||
received_options = None
|
||||
|
||||
async def test_hook(
|
||||
input_data: PreToolUseHookInput, tool_use_id: str | None, options: dict
|
||||
):
|
||||
nonlocal hook_called, received_input, received_tool_id, received_options
|
||||
hook_called = True
|
||||
received_input = input_data
|
||||
received_tool_id = tool_use_id
|
||||
received_options = options
|
||||
|
||||
# Return a proper response
|
||||
return HookJSONOutput(permission_decision="allow", reason="Test hook executed")
|
||||
|
||||
# Create a hook input to test with
|
||||
test_input = PreToolUseHookInput(
|
||||
hook_event_name="PreToolUse",
|
||||
session_id="test",
|
||||
transcript_path="/tmp/test",
|
||||
cwd="/tmp",
|
||||
tool_name="Edit",
|
||||
tool_input={"file": "test.py"},
|
||||
)
|
||||
|
||||
# Call the hook directly to test signature
|
||||
result = await test_hook(test_input, "tool-123", {"signal": None})
|
||||
|
||||
assert hook_called is True
|
||||
assert received_input == test_input
|
||||
assert received_tool_id == "tool-123"
|
||||
assert "signal" in received_options
|
||||
assert isinstance(result, HookJSONOutput)
|
||||
assert result.permission_decision == "allow"
|
||||
|
||||
anyio.run(_test)
|
||||
|
||||
|
||||
def test_multiple_hooks_in_matcher():
|
||||
"""Test that multiple hooks can be added to a matcher."""
|
||||
|
||||
async def _test():
|
||||
hook1_called = False
|
||||
hook2_called = False
|
||||
|
||||
async def hook1(input_data, tool_use_id, options):
|
||||
nonlocal hook1_called
|
||||
hook1_called = True
|
||||
return HookJSONOutput(permission_decision="allow")
|
||||
|
||||
async def hook2(input_data, tool_use_id, options):
|
||||
nonlocal hook2_called
|
||||
hook2_called = True
|
||||
return HookJSONOutput(permission_decision="allow")
|
||||
|
||||
matcher = HookCallbackMatcher(matcher="Edit", hooks=[hook1, hook2])
|
||||
|
||||
assert len(matcher.hooks) == 2
|
||||
|
||||
# Test that both hooks are callable
|
||||
test_input = PreToolUseHookInput(
|
||||
hook_event_name="PreToolUse",
|
||||
session_id="test",
|
||||
transcript_path="/tmp/test",
|
||||
cwd="/tmp",
|
||||
tool_name="Edit",
|
||||
tool_input={},
|
||||
)
|
||||
|
||||
await matcher.hooks[0](test_input, None, {})
|
||||
await matcher.hooks[1](test_input, None, {})
|
||||
|
||||
assert hook1_called is True
|
||||
assert hook2_called is True
|
||||
|
||||
anyio.run(_test)
|
||||
|
||||
|
||||
def test_hook_output_fields():
|
||||
"""Test all HookJSONOutput fields."""
|
||||
|
||||
async def _test():
|
||||
output = HookJSONOutput(
|
||||
continue_=False,
|
||||
suppress_output=True,
|
||||
stop_reason="User requested stop",
|
||||
decision="block",
|
||||
system_message="System notification",
|
||||
permission_decision="deny",
|
||||
permission_decision_reason="Not allowed",
|
||||
reason="General reason",
|
||||
hook_specific_output={"custom": "data"},
|
||||
)
|
||||
|
||||
assert output.continue_ is False
|
||||
assert output.suppress_output is True
|
||||
assert output.stop_reason == "User requested stop"
|
||||
assert output.decision == "block"
|
||||
assert output.system_message == "System notification"
|
||||
assert output.permission_decision == "deny"
|
||||
assert output.permission_decision_reason == "Not allowed"
|
||||
assert output.reason == "General reason"
|
||||
assert output.hook_specific_output == {"custom": "data"}
|
||||
|
||||
anyio.run(_test)
|
||||
|
||||
|
||||
def test_different_hook_events():
|
||||
"""Test different hook event types."""
|
||||
|
||||
async def _test():
|
||||
events_tested = set()
|
||||
|
||||
async def generic_hook(input_data, tool_use_id, options):
|
||||
events_tested.add(input_data.hook_event_name)
|
||||
return HookJSONOutput()
|
||||
|
||||
# Test UserPromptSubmitHookInput
|
||||
user_prompt_input = UserPromptSubmitHookInput(
|
||||
hook_event_name="UserPromptSubmit",
|
||||
session_id="test",
|
||||
transcript_path="/tmp/test",
|
||||
cwd="/tmp",
|
||||
prompt="Test prompt",
|
||||
)
|
||||
await generic_hook(user_prompt_input, None, {})
|
||||
|
||||
assert "UserPromptSubmit" in events_tested
|
||||
assert user_prompt_input.prompt == "Test prompt"
|
||||
|
||||
anyio.run(_test)
|
||||
Reference in New Issue
Block a user