Initial hooks implementation

🏠 Remote-Dev: homespace
This commit is contained in:
Dickson Tsai
2025-08-30 00:46:09 +00:00
parent f794e17e78
commit c312ab9e74
7 changed files with 927 additions and 6 deletions

View File

@@ -14,16 +14,30 @@ from .types import (
AssistantMessage,
ClaudeCodeOptions,
ContentBlock,
HookCallback,
HookCallbackMatcher,
HookEvent,
HookInput,
HookJSONOutput,
McpServerConfig,
Message,
NotificationHookInput,
PermissionMode,
PostToolUseHookInput,
PreCompactHookInput,
PreToolUseHookInput,
ResultMessage,
SessionEndHookInput,
SessionStartHookInput,
StopHookInput,
SubagentStopHookInput,
SystemMessage,
TextBlock,
ThinkingBlock,
ToolResultBlock,
ToolUseBlock,
UserMessage,
UserPromptSubmitHookInput,
)
__version__ = "0.0.20"
@@ -48,6 +62,21 @@ __all__ = [
"ToolUseBlock",
"ToolResultBlock",
"ContentBlock",
# Hook types
"HookEvent",
"HookCallback",
"HookCallbackMatcher",
"HookInput",
"HookJSONOutput",
"PreToolUseHookInput",
"PostToolUseHookInput",
"NotificationHookInput",
"UserPromptSubmitHookInput",
"SessionStartHookInput",
"SessionEndHookInput",
"StopHookInput",
"SubagentStopHookInput",
"PreCompactHookInput",
# Errors
"ClaudeSDKError",
"CLIConnectionError",

View File

@@ -0,0 +1,143 @@
"""Control protocol types and handlers for SDK-CLI communication."""
from dataclasses import dataclass, field
from typing import Any, Literal
from ..types import HookEvent, PermissionMode
@dataclass
class SDKHookCallbackMatcher:
"""Hook callback matcher for control protocol."""
matcher: str | None = None
hook_callback_ids: list[str] = field(default_factory=list)
@dataclass
class SDKControlInitializeRequest:
"""Initialize request for control protocol."""
subtype: Literal["initialize"]
hooks: dict[HookEvent, list[SDKHookCallbackMatcher]] | None = None
@dataclass
class SDKControlInterruptRequest:
"""Interrupt request for control protocol."""
subtype: Literal["interrupt"]
@dataclass
class SDKControlPermissionRequest:
"""Permission request for control protocol."""
subtype: Literal["can_use_tool"]
tool_name: str
input: dict[str, Any]
permission_suggestions: list[dict[str, Any]] | None = None
blocked_path: str | None = None
@dataclass
class SDKControlSetPermissionModeRequest:
"""Set permission mode request for control protocol."""
subtype: Literal["set_permission_mode"]
mode: PermissionMode
@dataclass
class SDKHookCallbackRequest:
"""Hook callback request for control protocol."""
subtype: Literal["hook_callback"]
callback_id: str
input: dict[str, Any]
tool_use_id: str | None = None
@dataclass
class SDKControlMcpMessageRequest:
"""MCP message request for control protocol."""
subtype: Literal["mcp_message"]
server_name: str
message: dict[str, Any]
# Union type for all control request subtypes
SDKControlRequestSubtype = (
SDKControlInitializeRequest
| SDKControlInterruptRequest
| SDKControlPermissionRequest
| SDKControlSetPermissionModeRequest
| SDKHookCallbackRequest
| SDKControlMcpMessageRequest
)
@dataclass
class SDKControlRequest:
"""Control request wrapper."""
type: Literal["control_request"]
request_id: str
request: SDKControlRequestSubtype
@dataclass
class SDKControlCancelRequest:
"""Cancel a control request."""
type: Literal["control_cancel_request"]
request_id: str
@dataclass
class ControlSuccessResponse:
"""Successful control response."""
subtype: Literal["success"]
request_id: str
response: dict[str, Any] | None = None
@dataclass
class ControlErrorResponse:
"""Error control response."""
subtype: Literal["error"]
request_id: str
error: str
# Union type for control responses
ControlResponse = ControlSuccessResponse | ControlErrorResponse
@dataclass
class SDKControlResponse:
"""Control response wrapper."""
type: Literal["control_response"]
response: ControlResponse
@dataclass
class Command:
"""Command metadata."""
name: str
description: str
argument_hint: str
@dataclass
class SDKControlInitializeResponse:
"""Initialize response from control protocol."""
commands: list[Command]
output_style: str
available_output_styles: list[str]

View File

@@ -1,5 +1,6 @@
"""Subprocess transport implementation using Claude Code CLI."""
import asyncio
import json
import logging
import os
@@ -17,7 +18,21 @@ from anyio.streams.text import TextReceiveStream, TextSendStream
from ..._errors import CLIConnectionError, CLINotFoundError, ProcessError
from ..._errors import CLIJSONDecodeError as SDKJSONDecodeError
from ...types import ClaudeCodeOptions
from ...types import (
ClaudeCodeOptions,
HookCallback,
HookEvent,
HookInput,
HookJSONOutput,
)
from ..control_protocol import (
ControlErrorResponse,
ControlSuccessResponse,
SDKControlInitializeRequest,
SDKControlInitializeResponse,
SDKControlResponse,
SDKHookCallbackMatcher,
)
from . import Transport
logger = logging.getLogger(__name__)
@@ -50,6 +65,13 @@ class SubprocessCLITransport(Transport):
self._task_group: anyio.abc.TaskGroup | None = None
self._stderr_file: Any = None # tempfile.NamedTemporaryFile
# Hooks support
self._hook_callbacks: dict[str, HookCallback] = {}
self._next_callback_id = 0
self._cancel_controllers: dict[str, asyncio.Event] = {}
self._initialized = False
self._hooks_enabled = options.hooks is not None
def _find_cli(self) -> str:
"""Find Claude Code CLI binary."""
if cli := shutil.which("claude"):
@@ -202,7 +224,13 @@ class SubprocessCLITransport(Transport):
# Start streaming messages to stdin in background
self._task_group = anyio.create_task_group()
await self._task_group.__aenter__()
# Initialize hooks if enabled
if self._hooks_enabled:
await self._initialize_hooks()
self._task_group.start_soon(self._stream_to_stdin)
self._task_group.start_soon(self._handle_control_requests)
else:
# String mode: close stdin immediately (backward compatible)
if self._process.stdin:
@@ -334,7 +362,7 @@ class SubprocessCLITransport(Transport):
data = json.loads(json_buffer)
json_buffer = ""
# Handle control responses separately
# Handle control messages separately
if data.get("type") == "control_response":
response = data.get("response", {})
request_id = response.get("request_id")
@@ -343,6 +371,20 @@ class SubprocessCLITransport(Transport):
self._pending_control_responses[request_id] = response
continue
# Handle control requests (for hooks)
if data.get("type") == "control_request":
# Queue control request for processing
if hasattr(self, "_control_request_queue"):
await self._control_request_queue.put(data)
continue
# Handle control cancel requests
if data.get("type") == "control_cancel_request":
request_id = data.get("request_id")
if request_id in self._cancel_controllers:
self._cancel_controllers[request_id].set()
continue
try:
yield data
except GeneratorExit:
@@ -444,3 +486,210 @@ class SubprocessCLITransport(Transport):
raise CLIConnectionError(f"Control request failed: {response.get('error')}")
return response
async def _initialize_hooks(self) -> None:
"""Initialize hooks with the CLI."""
if not self._options.hooks:
return
# Convert hooks to SDK format
sdk_hooks: dict[HookEvent, list[SDKHookCallbackMatcher]] = {}
for event, matchers in self._options.hooks.items():
sdk_matchers = []
for matcher in matchers:
callback_ids = []
for callback in matcher.hooks:
callback_id = f"hook_{self._next_callback_id}"
self._next_callback_id += 1
self._hook_callbacks[callback_id] = callback
callback_ids.append(callback_id)
sdk_matchers.append(
SDKHookCallbackMatcher(
matcher=matcher.matcher, hook_callback_ids=callback_ids
)
)
sdk_hooks[event] = sdk_matchers
# Send initialize request
# sdk_hooks is dict[HookEvent, list[SDKHookCallbackMatcher]]
# but protocol expects dict[str, list[dict]]
hooks_dict: dict[str, Any] = {k: [m.__dict__ for m in v] for k, v in sdk_hooks.items()}
init_request = SDKControlInitializeRequest(
subtype="initialize",
hooks=hooks_dict # type: ignore[arg-type]
)
response = await self._send_control_request(init_request.__dict__)
self._initialized = True
# Initialize response is not returned, just stored
# return SDKControlInitializeResponse(**response.get("response", {}))
async def _handle_control_requests(self) -> None:
"""Handle incoming control requests from CLI."""
# Create a queue for control requests
self._control_request_queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
try:
while True:
# Wait for control request
request_data = await self._control_request_queue.get()
# Process control request
request_id = request_data.get("request_id")
request = request_data.get("request", {})
if not request_id:
continue
# Create cancel controller for this request
cancel_event = asyncio.Event()
self._cancel_controllers[request_id] = cancel_event
try:
response = await self._process_control_request(
request, cancel_event
)
# Send success response
control_response = SDKControlResponse(
type="control_response",
response=ControlSuccessResponse(
subtype="success", request_id=request_id, response=response
),
)
if self._stdin_stream:
await self._stdin_stream.send(
json.dumps(
control_response.__dict__, default=lambda o: o.__dict__
)
+ "\n"
)
except Exception as e:
# Send error response
control_response = SDKControlResponse(
type="control_response",
response=ControlErrorResponse(
subtype="error", request_id=request_id, error=str(e)
),
)
if self._stdin_stream:
await self._stdin_stream.send(
json.dumps(
control_response.__dict__, default=lambda o: o.__dict__
)
+ "\n"
)
finally:
# Clean up cancel controller
if request_id in self._cancel_controllers:
del self._cancel_controllers[request_id]
except asyncio.CancelledError:
# Task group was cancelled
pass
except Exception as e:
logger.error(f"Error in control request handler: {e}")
async def _process_control_request(
self, request: dict[str, Any], cancel_event: asyncio.Event
) -> dict[str, Any]:
"""Process a control request from CLI."""
subtype = request.get("subtype")
if subtype == "hook_callback":
# Handle hook callback
callback_id = request.get("callback_id")
if not callback_id:
raise ValueError("Missing callback_id in hook callback request")
input_data = request.get("input", {})
tool_use_id = request.get("tool_use_id")
callback = self._hook_callbacks.get(callback_id)
if not callback:
raise ValueError(f"No hook callback found for ID: {callback_id}")
# Convert input data to appropriate HookInput type
hook_input = self._parse_hook_input(input_data)
# Create options with abort signal
options = {"signal": cancel_event}
# Call the hook
result = await callback(hook_input, tool_use_id, options)
# Convert result to dict
return self._hook_output_to_dict(result)
else:
raise ValueError(f"Unsupported control request subtype: {subtype}")
def _parse_hook_input(self, data: dict[str, Any]) -> HookInput:
"""Parse hook input data into appropriate type."""
from ...types import (
NotificationHookInput,
PostToolUseHookInput,
PreCompactHookInput,
PreToolUseHookInput,
SessionEndHookInput,
SessionStartHookInput,
StopHookInput,
SubagentStopHookInput,
UserPromptSubmitHookInput,
)
event_name = data.get("hook_event_name")
if not event_name:
raise ValueError("Missing hook_event_name in hook input")
# Map event names to input classes
input_classes: dict[str, type[HookInput]] = {
"PreToolUse": PreToolUseHookInput,
"PostToolUse": PostToolUseHookInput,
"Notification": NotificationHookInput,
"UserPromptSubmit": UserPromptSubmitHookInput,
"SessionStart": SessionStartHookInput,
"SessionEnd": SessionEndHookInput,
"Stop": StopHookInput,
"SubagentStop": SubagentStopHookInput,
"PreCompact": PreCompactHookInput,
}
input_class = input_classes.get(event_name)
if not input_class:
raise ValueError(f"Unknown hook event: {event_name}")
# Create instance from dict
return input_class(**data)
def _hook_output_to_dict(self, output: HookJSONOutput) -> dict[str, Any]:
"""Convert HookJSONOutput to dict for JSON serialization."""
result: dict[str, Any] = {}
# Map Python field names to JSON field names
if output.continue_ is not None:
result["continue"] = output.continue_
if output.suppress_output is not None:
result["suppressOutput"] = output.suppress_output
if output.stop_reason is not None:
result["stopReason"] = output.stop_reason
if output.decision is not None:
result["decision"] = output.decision
if output.system_message is not None:
result["systemMessage"] = output.system_message
if output.permission_decision is not None:
result["permissionDecision"] = output.permission_decision
if output.permission_decision_reason is not None:
result["permissionDecisionReason"] = output.permission_decision_reason
if output.reason is not None:
result["reason"] = output.reason
if output.hook_specific_output is not None:
result["hookSpecificOutput"] = output.hook_specific_output
return result

View File

@@ -5,7 +5,13 @@ from collections.abc import AsyncIterable, AsyncIterator
from typing import Any
from ._errors import CLIConnectionError
from .types import ClaudeCodeOptions, Message, ResultMessage
from .types import (
ClaudeCodeOptions,
HookCallbackMatcher,
HookEvent,
Message,
ResultMessage,
)
class ClaudeSDKClient:
@@ -88,12 +94,53 @@ class ClaudeSDKClient:
await client.disconnect()
```
Example - With hooks for tool control:
```python
from claude_code_sdk import HookCallbackMatcher, PreToolUseHookInput, HookJSONOutput
async def pre_tool_hook(input: PreToolUseHookInput, tool_use_id, options):
# Intercept and control tool execution
if input.tool_name == "Bash":
if "rm -rf" in str(input.tool_input):
return HookJSONOutput(
permission_decision="deny",
reason="Dangerous command detected"
)
return HookJSONOutput(permission_decision="allow")
hooks = {
"PreToolUse": [
HookCallbackMatcher(hooks=[pre_tool_hook])
]
}
async with ClaudeSDKClient(hooks=hooks) as client:
await client.query("Clean up temporary files")
async for message in client.receive_response():
print(message)
```
"""
def __init__(self, options: ClaudeCodeOptions | None = None):
"""Initialize Claude SDK client."""
def __init__(
self,
options: ClaudeCodeOptions | None = None,
hooks: dict[HookEvent, list[HookCallbackMatcher]] | None = None,
):
"""
Initialize Claude SDK client.
Args:
options: Configuration options for Claude Code
hooks: Optional dict of hook events to callback matchers for intercepting tool execution
"""
if options is None:
options = ClaudeCodeOptions()
# Add hooks to options if provided
if hooks is not None:
options.hooks = hooks
self.options = options
self._transport: Any | None = None
os.environ["CLAUDE_CODE_ENTRYPOINT"] = "sdk-py-client"

View File

@@ -6,7 +6,7 @@ from typing import Any
from ._internal.client import InternalClient
from ._internal.transport import Transport
from .types import ClaudeCodeOptions, Message
from .types import ClaudeCodeOptions, HookCallbackMatcher, HookEvent, Message
async def query(
@@ -14,6 +14,7 @@ async def query(
prompt: str | AsyncIterable[dict[str, Any]],
options: ClaudeCodeOptions | None = None,
transport: Transport | None = None,
hooks: dict[HookEvent, list[HookCallbackMatcher]] | None = None,
) -> AsyncIterator[Message]:
"""
Query Claude Code for one-shot or unidirectional streaming interactions.
@@ -61,6 +62,8 @@ async def query(
transport: Optional transport implementation. If provided, this will be used
instead of the default transport selection based on options.
The transport will be automatically configured with the prompt and options.
hooks: Optional dict of hook events to callback matchers. Hooks allow you to intercept
and control tool execution. Only works in streaming mode (AsyncIterable prompt).
Yields:
Messages from the conversation
@@ -112,10 +115,45 @@ async def query(
print(message)
```
Example - With hooks (streaming mode required):
```python
from claude_code_sdk import query, HookCallbackMatcher, PreToolUseHookInput, HookJSONOutput
async def pre_tool_hook(input: PreToolUseHookInput, tool_use_id, options):
if input.tool_name == "Edit":
return HookJSONOutput(
permission_decision="ask",
reason="Edit operations require user confirmation"
)
return HookJSONOutput(permission_decision="allow")
async def prompts():
yield {"type": "user", "message": {"role": "user", "content": "Fix the bug"}}
hooks = {
"PreToolUse": [
HookCallbackMatcher(
matcher="Edit",
hooks=[pre_tool_hook]
)
]
}
async for message in query(prompt=prompts(), hooks=hooks):
print(message)
```
"""
if options is None:
options = ClaudeCodeOptions()
# Add hooks to options if provided
if hooks is not None:
# Hooks only work in streaming mode
if isinstance(prompt, str):
raise ValueError("Hooks require streaming mode (AsyncIterable prompt)")
options.hooks = hooks
os.environ["CLAUDE_CODE_ENTRYPOINT"] = "sdk-py"
client = InternalClient()

View File

@@ -1,5 +1,6 @@
"""Type definitions for Claude SDK."""
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Literal, TypedDict
@@ -118,6 +119,193 @@ class ResultMessage:
Message = UserMessage | AssistantMessage | SystemMessage | ResultMessage
# Hook events
HookEvent = Literal[
"PreToolUse",
"PostToolUse",
"Notification",
"UserPromptSubmit",
"SessionStart",
"SessionEnd",
"Stop",
"SubagentStop",
"PreCompact",
]
# Base hook input with common fields
@dataclass
class BaseHookInput:
"""Base hook input with common fields."""
session_id: str
transcript_path: str
cwd: str
permission_mode: str | None = None
# Individual hook input types
@dataclass
class PreToolUseHookInput:
"""Pre-tool use hook input."""
hook_event_name: Literal["PreToolUse"]
session_id: str
transcript_path: str
cwd: str
tool_name: str
tool_input: Any
permission_mode: str | None = None
@dataclass
class PostToolUseHookInput:
"""Post-tool use hook input."""
hook_event_name: Literal["PostToolUse"]
session_id: str
transcript_path: str
cwd: str
tool_name: str
tool_input: Any
tool_response: Any
permission_mode: str | None = None
@dataclass
class NotificationHookInput:
"""Notification hook input."""
hook_event_name: Literal["Notification"]
session_id: str
transcript_path: str
cwd: str
message: str
permission_mode: str | None = None
title: str | None = None
@dataclass
class UserPromptSubmitHookInput:
"""User prompt submit hook input."""
hook_event_name: Literal["UserPromptSubmit"]
session_id: str
transcript_path: str
cwd: str
prompt: str
permission_mode: str | None = None
@dataclass
class SessionStartHookInput:
"""Session start hook input."""
hook_event_name: Literal["SessionStart"]
session_id: str
transcript_path: str
cwd: str
source: Literal["startup", "resume", "clear", "compact"]
permission_mode: str | None = None
@dataclass
class SessionEndHookInput:
"""Session end hook input."""
hook_event_name: Literal["SessionEnd"]
session_id: str
transcript_path: str
cwd: str
reason: Literal["clear", "logout", "prompt_input_exit", "other"]
permission_mode: str | None = None
@dataclass
class StopHookInput:
"""Stop hook input."""
hook_event_name: Literal["Stop"]
session_id: str
transcript_path: str
cwd: str
stop_hook_active: bool
permission_mode: str | None = None
@dataclass
class SubagentStopHookInput:
"""Subagent stop hook input."""
hook_event_name: Literal["SubagentStop"]
session_id: str
transcript_path: str
cwd: str
stop_hook_active: bool
permission_mode: str | None = None
@dataclass
class PreCompactHookInput:
"""Pre-compact hook input."""
hook_event_name: Literal["PreCompact"]
session_id: str
transcript_path: str
cwd: str
trigger: Literal["manual", "auto"]
permission_mode: str | None = None
custom_instructions: str | None = None
# Union type for all hook inputs
HookInput = (
PreToolUseHookInput
| PostToolUseHookInput
| NotificationHookInput
| UserPromptSubmitHookInput
| SessionStartHookInput
| SessionEndHookInput
| StopHookInput
| SubagentStopHookInput
| PreCompactHookInput
)
@dataclass
class HookJSONOutput:
"""Hook callback output."""
continue_: bool | None = None # 'continue' is reserved in Python
suppress_output: bool | None = None
stop_reason: str | None = None
decision: Literal["approve", "block"] | None = None
system_message: str | None = None
# PreToolUse specific
permission_decision: Literal["allow", "deny", "ask"] | None = None
permission_decision_reason: str | None = None
reason: str | None = None
# Hook-specific outputs (for future extension)
hook_specific_output: dict[str, Any] | None = None
# Hook callback signature
HookCallback = Callable[
[HookInput, str | None, dict[str, Any]], # input, tool_use_id, options
Awaitable[HookJSONOutput],
]
@dataclass
class HookCallbackMatcher:
"""Hook callback matcher with optional pattern matching."""
matcher: str | None = None
hooks: list[HookCallback] = field(default_factory=list)
@dataclass
class ClaudeCodeOptions:
"""Query options for Claude SDK."""
@@ -141,3 +329,4 @@ class ClaudeCodeOptions:
extra_args: dict[str, str | None] = field(
default_factory=dict
) # Pass arbitrary CLI flags
hooks: dict[HookEvent, list[HookCallbackMatcher]] | None = None

226
tests/test_hooks.py Normal file
View File

@@ -0,0 +1,226 @@
"""Test hooks functionality."""
import anyio
import pytest
from claude_code_sdk import (
HookCallbackMatcher,
HookJSONOutput,
PostToolUseHookInput,
PreToolUseHookInput,
UserPromptSubmitHookInput,
query,
)
def test_hook_types():
"""Test that hook types are properly defined."""
async def _test():
# Test PreToolUseHookInput
pre_tool_input = PreToolUseHookInput(
hook_event_name="PreToolUse",
session_id="test-session",
transcript_path="/tmp/test",
cwd="/home/test",
tool_name="Edit",
tool_input={"file": "test.py", "content": "print('hello')"},
)
assert pre_tool_input.hook_event_name == "PreToolUse"
assert pre_tool_input.tool_name == "Edit"
# Test PostToolUseHookInput
post_tool_input = PostToolUseHookInput(
hook_event_name="PostToolUse",
session_id="test-session",
transcript_path="/tmp/test",
cwd="/home/test",
tool_name="Edit",
tool_input={"file": "test.py", "content": "print('hello')"},
tool_response={"success": True},
)
assert post_tool_input.hook_event_name == "PostToolUse"
assert post_tool_input.tool_response == {"success": True}
# Test HookJSONOutput
output = HookJSONOutput(
continue_=True,
permission_decision="allow",
reason="Test reason",
)
assert output.continue_ is True
assert output.permission_decision == "allow"
anyio.run(_test)
def test_hook_callback_matcher():
"""Test HookCallbackMatcher structure."""
async def _test():
async def test_hook(input_data, tool_use_id, options):
return HookJSONOutput(permission_decision="allow")
matcher = HookCallbackMatcher(matcher="Edit", hooks=[test_hook])
assert matcher.matcher == "Edit"
assert len(matcher.hooks) == 1
assert callable(matcher.hooks[0])
anyio.run(_test)
def test_query_with_hooks_requires_streaming():
"""Test that hooks require streaming mode."""
async def _test():
async def test_hook(input_data, tool_use_id, options):
return HookJSONOutput(permission_decision="allow")
hooks = {"PreToolUse": [HookCallbackMatcher(hooks=[test_hook])]}
# Should raise error with string prompt
with pytest.raises(ValueError, match="Hooks require streaming mode"):
async for _ in query(prompt="test", hooks=hooks):
pass
anyio.run(_test)
def test_hook_callback_signature():
"""Test hook callback signature and execution."""
async def _test():
hook_called = False
received_input = None
received_tool_id = None
received_options = None
async def test_hook(
input_data: PreToolUseHookInput, tool_use_id: str | None, options: dict
):
nonlocal hook_called, received_input, received_tool_id, received_options
hook_called = True
received_input = input_data
received_tool_id = tool_use_id
received_options = options
# Return a proper response
return HookJSONOutput(permission_decision="allow", reason="Test hook executed")
# Create a hook input to test with
test_input = PreToolUseHookInput(
hook_event_name="PreToolUse",
session_id="test",
transcript_path="/tmp/test",
cwd="/tmp",
tool_name="Edit",
tool_input={"file": "test.py"},
)
# Call the hook directly to test signature
result = await test_hook(test_input, "tool-123", {"signal": None})
assert hook_called is True
assert received_input == test_input
assert received_tool_id == "tool-123"
assert "signal" in received_options
assert isinstance(result, HookJSONOutput)
assert result.permission_decision == "allow"
anyio.run(_test)
def test_multiple_hooks_in_matcher():
"""Test that multiple hooks can be added to a matcher."""
async def _test():
hook1_called = False
hook2_called = False
async def hook1(input_data, tool_use_id, options):
nonlocal hook1_called
hook1_called = True
return HookJSONOutput(permission_decision="allow")
async def hook2(input_data, tool_use_id, options):
nonlocal hook2_called
hook2_called = True
return HookJSONOutput(permission_decision="allow")
matcher = HookCallbackMatcher(matcher="Edit", hooks=[hook1, hook2])
assert len(matcher.hooks) == 2
# Test that both hooks are callable
test_input = PreToolUseHookInput(
hook_event_name="PreToolUse",
session_id="test",
transcript_path="/tmp/test",
cwd="/tmp",
tool_name="Edit",
tool_input={},
)
await matcher.hooks[0](test_input, None, {})
await matcher.hooks[1](test_input, None, {})
assert hook1_called is True
assert hook2_called is True
anyio.run(_test)
def test_hook_output_fields():
"""Test all HookJSONOutput fields."""
async def _test():
output = HookJSONOutput(
continue_=False,
suppress_output=True,
stop_reason="User requested stop",
decision="block",
system_message="System notification",
permission_decision="deny",
permission_decision_reason="Not allowed",
reason="General reason",
hook_specific_output={"custom": "data"},
)
assert output.continue_ is False
assert output.suppress_output is True
assert output.stop_reason == "User requested stop"
assert output.decision == "block"
assert output.system_message == "System notification"
assert output.permission_decision == "deny"
assert output.permission_decision_reason == "Not allowed"
assert output.reason == "General reason"
assert output.hook_specific_output == {"custom": "data"}
anyio.run(_test)
def test_different_hook_events():
"""Test different hook event types."""
async def _test():
events_tested = set()
async def generic_hook(input_data, tool_use_id, options):
events_tested.add(input_data.hook_event_name)
return HookJSONOutput()
# Test UserPromptSubmitHookInput
user_prompt_input = UserPromptSubmitHookInput(
hook_event_name="UserPromptSubmit",
session_id="test",
transcript_path="/tmp/test",
cwd="/tmp",
prompt="Test prompt",
)
await generic_hook(user_prompt_input, None, {})
assert "UserPromptSubmit" in events_tested
assert user_prompt_input.prompt == "Test prompt"
anyio.run(_test)