mirror of
https://github.com/omnara-ai/omnara.git
synced 2025-08-12 20:39:09 +03:00
Refactor (#45)
* progress * refactor * backend refactor * test * minor stdio changes * streeeeeaming * stream statuses * change required user input * something * progress * new beginnings * progress * edge case * uuhh the claude wrapper def worked before this commit * progress * ai slop that works * tests * tests * plan mode handling * progress * clean up * bump * migrations * readmes * no text truncation * consistent poll interval * fix time --------- Co-authored-by: Kartik Sarangmath <kartiksarangmath@Kartiks-MacBook-Air.local>
This commit is contained in:
@@ -57,6 +57,29 @@ def run_webhook_server(
|
||||
subprocess.run(cmd)
|
||||
|
||||
|
||||
def run_claude_wrapper(api_key, base_url=None, claude_args=None):
|
||||
"""Run the Claude wrapper V3 for Omnara integration"""
|
||||
# Import and run directly instead of subprocess
|
||||
from webhooks.claude_wrapper_v3 import main as claude_wrapper_main
|
||||
|
||||
# Prepare sys.argv for the claude wrapper
|
||||
original_argv = sys.argv
|
||||
new_argv = ["claude_wrapper_v3", "--api-key", api_key]
|
||||
|
||||
if base_url:
|
||||
new_argv.extend(["--base-url", base_url])
|
||||
|
||||
# Add any additional Claude arguments
|
||||
if claude_args:
|
||||
new_argv.extend(claude_args)
|
||||
|
||||
try:
|
||||
sys.argv = new_argv
|
||||
claude_wrapper_main()
|
||||
finally:
|
||||
sys.argv = original_argv
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point that dispatches based on command line arguments"""
|
||||
parser = argparse.ArgumentParser(
|
||||
@@ -79,6 +102,12 @@ Examples:
|
||||
# Run webhook server on custom port
|
||||
omnara --claude-code-webhook --port 8080
|
||||
|
||||
# Run Claude wrapper V3
|
||||
omnara --claude --api-key YOUR_API_KEY
|
||||
|
||||
# Run Claude wrapper with custom base URL
|
||||
omnara --claude --api-key YOUR_API_KEY --base-url http://localhost:8000
|
||||
|
||||
# Run with custom API base URL
|
||||
omnara --stdio --api-key YOUR_API_KEY --base-url http://localhost:8000
|
||||
|
||||
@@ -99,6 +128,11 @@ Examples:
|
||||
action="store_true",
|
||||
help="Run the Claude Code webhook server",
|
||||
)
|
||||
mode_group.add_argument(
|
||||
"--claude",
|
||||
action="store_true",
|
||||
help="Run the Claude wrapper V3 for Omnara integration",
|
||||
)
|
||||
|
||||
# Arguments for webhook mode
|
||||
parser.add_argument(
|
||||
@@ -142,7 +176,8 @@ Examples:
|
||||
help="Pre-existing agent instance ID to use for this session (stdio mode only)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
# Use parse_known_args to capture remaining args for Claude
|
||||
args, unknown_args = parser.parse_known_args()
|
||||
|
||||
if args.cloudflare_tunnel and not args.claude_code_webhook:
|
||||
parser.error("--cloudflare-tunnel can only be used with --claude-code-webhook")
|
||||
@@ -156,6 +191,10 @@ Examples:
|
||||
dangerously_skip_permissions=args.dangerously_skip_permissions,
|
||||
port=args.port,
|
||||
)
|
||||
elif args.claude:
|
||||
if not args.api_key:
|
||||
parser.error("--api-key is required for --claude mode")
|
||||
run_claude_wrapper(args.api_key, args.base_url, unknown_args)
|
||||
else:
|
||||
if not args.api_key:
|
||||
parser.error("--api-key is required for stdio mode")
|
||||
|
||||
@@ -2,7 +2,8 @@
|
||||
|
||||
import asyncio
|
||||
import ssl
|
||||
from typing import Optional, Dict, Any
|
||||
import uuid
|
||||
from typing import Optional, Dict, Any, Union, List
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import aiohttp
|
||||
@@ -11,10 +12,14 @@ from aiohttp import ClientTimeout
|
||||
|
||||
from .exceptions import AuthenticationError, TimeoutError, APIError
|
||||
from .models import (
|
||||
LogStepResponse,
|
||||
QuestionResponse,
|
||||
QuestionStatus,
|
||||
EndSessionResponse,
|
||||
CreateMessageResponse,
|
||||
PendingMessagesResponse,
|
||||
Message,
|
||||
)
|
||||
from .utils import (
|
||||
validate_agent_instance_id,
|
||||
build_message_request_data,
|
||||
)
|
||||
|
||||
|
||||
@@ -76,6 +81,7 @@ class AsyncOmnaraClient:
|
||||
method: str,
|
||||
endpoint: str,
|
||||
json: Optional[Dict[str, Any]] = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
timeout: Optional[int] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Make an async HTTP request to the API.
|
||||
@@ -84,6 +90,7 @@ class AsyncOmnaraClient:
|
||||
method: HTTP method (GET, POST, etc.)
|
||||
endpoint: API endpoint path
|
||||
json: JSON body for the request
|
||||
params: Query parameters for the request
|
||||
timeout: Request timeout in seconds
|
||||
|
||||
Returns:
|
||||
@@ -104,7 +111,11 @@ class AsyncOmnaraClient:
|
||||
|
||||
try:
|
||||
async with self.session.request(
|
||||
method=method, url=url, json=json, timeout=request_timeout
|
||||
method=method,
|
||||
url=url,
|
||||
json=json,
|
||||
params=params,
|
||||
timeout=request_timeout,
|
||||
) as response:
|
||||
if response.status == 401:
|
||||
raise AuthenticationError(
|
||||
@@ -128,138 +139,246 @@ class AsyncOmnaraClient:
|
||||
except aiohttp.ClientError as e:
|
||||
raise APIError(0, f"Request failed: {str(e)}")
|
||||
|
||||
async def log_step(
|
||||
async def send_message(
|
||||
self,
|
||||
agent_type: str,
|
||||
step_description: str,
|
||||
agent_instance_id: Optional[str] = None,
|
||||
send_push: Optional[bool] = None,
|
||||
send_email: Optional[bool] = None,
|
||||
send_sms: Optional[bool] = None,
|
||||
git_diff: Optional[str] = None,
|
||||
) -> LogStepResponse:
|
||||
"""Log a high-level step the agent is performing.
|
||||
|
||||
Args:
|
||||
agent_type: Type of agent (e.g., 'Claude Code', 'Cursor')
|
||||
step_description: Clear description of what the agent is doing
|
||||
agent_instance_id: Existing agent instance ID (optional)
|
||||
send_push: Send push notification for this step (default: False)
|
||||
send_email: Send email notification for this step (default: False)
|
||||
send_sms: Send SMS notification for this step (default: False)
|
||||
|
||||
Returns:
|
||||
LogStepResponse with success status, instance ID, and user feedback
|
||||
"""
|
||||
data: Dict[str, Any] = {
|
||||
"agent_type": agent_type,
|
||||
"step_description": step_description,
|
||||
}
|
||||
if agent_instance_id:
|
||||
data["agent_instance_id"] = agent_instance_id
|
||||
if send_push is not None:
|
||||
data["send_push"] = send_push
|
||||
if send_email is not None:
|
||||
data["send_email"] = send_email
|
||||
if send_sms is not None:
|
||||
data["send_sms"] = send_sms
|
||||
if git_diff is not None:
|
||||
data["git_diff"] = git_diff
|
||||
|
||||
response = await self._make_request("POST", "/api/v1/steps", json=data)
|
||||
|
||||
return LogStepResponse(
|
||||
success=response["success"],
|
||||
agent_instance_id=response["agent_instance_id"],
|
||||
step_number=response["step_number"],
|
||||
user_feedback=response.get("user_feedback", []),
|
||||
)
|
||||
|
||||
async def ask_question(
|
||||
self,
|
||||
agent_instance_id: str,
|
||||
question_text: str,
|
||||
content: str,
|
||||
agent_type: Optional[str] = None,
|
||||
agent_instance_id: Optional[Union[str, uuid.UUID]] = None,
|
||||
requires_user_input: bool = False,
|
||||
timeout_minutes: int = 1440,
|
||||
poll_interval: float = 10.0,
|
||||
send_push: Optional[bool] = None,
|
||||
send_email: Optional[bool] = None,
|
||||
send_sms: Optional[bool] = None,
|
||||
git_diff: Optional[str] = None,
|
||||
) -> QuestionResponse:
|
||||
"""Ask the user a question and wait for their response.
|
||||
|
||||
This method submits the question and then polls for the answer.
|
||||
) -> CreateMessageResponse:
|
||||
"""Send a message to the dashboard.
|
||||
|
||||
Args:
|
||||
agent_instance_id: Agent instance ID
|
||||
question_text: Question to ask the user
|
||||
timeout_minutes: Maximum time to wait for answer in minutes (default: 1440 = 24 hours)
|
||||
poll_interval: Time between polls in seconds (default: 10.0)
|
||||
send_push: Send push notification for this question (default: user preference)
|
||||
send_email: Send email notification for this question (default: user preference)
|
||||
send_sms: Send SMS notification for this question (default: user preference)
|
||||
content: The message content (step description or question text)
|
||||
agent_type: Type of agent (required if agent_instance_id not provided)
|
||||
agent_instance_id: Existing agent instance ID (optional)
|
||||
requires_user_input: Whether this message requires user input (default: False)
|
||||
timeout_minutes: If requires_user_input, max time to wait in minutes (default: 1440)
|
||||
poll_interval: If requires_user_input, time between polls in seconds (default: 10.0)
|
||||
send_push: Send push notification (default: False for steps, user pref for questions)
|
||||
send_email: Send email notification (default: False for steps, user pref for questions)
|
||||
send_sms: Send SMS notification (default: False for steps, user pref for questions)
|
||||
git_diff: Git diff content to include (optional)
|
||||
|
||||
Returns:
|
||||
QuestionResponse with the user's answer
|
||||
CreateMessageResponse with any queued user messages
|
||||
|
||||
Raises:
|
||||
TimeoutError: If no answer is received within timeout
|
||||
ValueError: If neither agent_type nor agent_instance_id is provided
|
||||
TimeoutError: If requires_user_input and no answer is received within timeout
|
||||
"""
|
||||
# Submit the question
|
||||
data: Dict[str, Any] = {
|
||||
"agent_instance_id": agent_instance_id,
|
||||
"question_text": question_text,
|
||||
}
|
||||
if send_push is not None:
|
||||
data["send_push"] = send_push
|
||||
if send_email is not None:
|
||||
data["send_email"] = send_email
|
||||
if send_sms is not None:
|
||||
data["send_sms"] = send_sms
|
||||
if git_diff is not None:
|
||||
data["git_diff"] = git_diff
|
||||
# If no agent_instance_id provided, generate one client-side
|
||||
if not agent_instance_id:
|
||||
if not agent_type:
|
||||
raise ValueError("agent_type is required when creating a new instance")
|
||||
agent_instance_id = uuid.uuid4()
|
||||
|
||||
# First, try the non-blocking endpoint to create the question
|
||||
response = await self._make_request(
|
||||
"POST", "/api/v1/questions", json=data, timeout=5
|
||||
# Validate and convert agent_instance_id to string
|
||||
agent_instance_id_str = validate_agent_instance_id(agent_instance_id)
|
||||
|
||||
# Build request data using shared utility
|
||||
data = build_message_request_data(
|
||||
content=content,
|
||||
agent_instance_id=agent_instance_id_str,
|
||||
requires_user_input=requires_user_input,
|
||||
agent_type=agent_type,
|
||||
send_push=send_push,
|
||||
send_email=send_email,
|
||||
send_sms=send_sms,
|
||||
git_diff=git_diff,
|
||||
)
|
||||
question_id = response["question_id"]
|
||||
|
||||
# Convert timeout from minutes to seconds
|
||||
# Send the message
|
||||
response = await self._make_request("POST", "/api/v1/messages/agent", json=data)
|
||||
response_agent_instance_id = response["agent_instance_id"]
|
||||
message_id = response["message_id"]
|
||||
|
||||
queued_contents = [
|
||||
msg["content"] if isinstance(msg, dict) else msg
|
||||
for msg in response.get("queued_user_messages", [])
|
||||
]
|
||||
|
||||
create_response = CreateMessageResponse(
|
||||
success=response["success"],
|
||||
agent_instance_id=response_agent_instance_id,
|
||||
message_id=message_id,
|
||||
queued_user_messages=queued_contents,
|
||||
)
|
||||
|
||||
# If it doesn't require user input, return immediately
|
||||
if not requires_user_input:
|
||||
return create_response
|
||||
|
||||
# Otherwise, poll for the answer
|
||||
# Use the message ID we just created as our starting point
|
||||
last_read_message_id = message_id
|
||||
timeout_seconds = timeout_minutes * 60
|
||||
|
||||
# Poll for the answer
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
while asyncio.get_event_loop().time() - start_time < timeout_seconds:
|
||||
status = await self.get_question_status(question_id)
|
||||
all_messages = []
|
||||
|
||||
if status.status == "answered" and status.answer:
|
||||
return QuestionResponse(answer=status.answer, question_id=question_id)
|
||||
while asyncio.get_event_loop().time() - start_time < timeout_seconds:
|
||||
# Poll for pending messages
|
||||
pending_response = await self.get_pending_messages(
|
||||
agent_instance_id_str, last_read_message_id
|
||||
)
|
||||
|
||||
# If status is "stale", another process has read the messages
|
||||
if pending_response.status == "stale":
|
||||
raise TimeoutError("Another process has read the messages")
|
||||
|
||||
# Check if we got any messages
|
||||
if pending_response.messages:
|
||||
# Collect all messages
|
||||
all_messages.extend(pending_response.messages)
|
||||
|
||||
# Return the response with all collected messages
|
||||
create_response.queued_user_messages = [
|
||||
msg.content for msg in all_messages
|
||||
]
|
||||
return create_response
|
||||
|
||||
await asyncio.sleep(poll_interval)
|
||||
|
||||
raise TimeoutError(f"Question timed out after {timeout_minutes} minutes")
|
||||
|
||||
async def get_question_status(self, question_id: str) -> QuestionStatus:
|
||||
"""Get the current status of a question.
|
||||
async def get_pending_messages(
|
||||
self,
|
||||
agent_instance_id: Union[str, uuid.UUID],
|
||||
last_read_message_id: Optional[str] = None,
|
||||
) -> PendingMessagesResponse:
|
||||
"""Get pending user messages for an agent instance.
|
||||
|
||||
Args:
|
||||
question_id: ID of the question to check
|
||||
agent_instance_id: Agent instance ID
|
||||
last_read_message_id: The last message ID that was read (optional)
|
||||
|
||||
Returns:
|
||||
QuestionStatus with current status and answer (if available)
|
||||
PendingMessagesResponse with messages and status
|
||||
"""
|
||||
response = await self._make_request("GET", f"/api/v1/questions/{question_id}")
|
||||
# Validate and convert agent_instance_id to string
|
||||
agent_instance_id_str = validate_agent_instance_id(agent_instance_id)
|
||||
|
||||
return QuestionStatus(
|
||||
question_id=response["question_id"],
|
||||
status=response["status"],
|
||||
answer=response.get("answer"),
|
||||
asked_at=response["asked_at"],
|
||||
answered_at=response.get("answered_at"),
|
||||
params = {"agent_instance_id": agent_instance_id_str}
|
||||
if last_read_message_id:
|
||||
params["last_read_message_id"] = last_read_message_id
|
||||
|
||||
response = await self._make_request(
|
||||
"GET", "/api/v1/messages/pending", params=params
|
||||
)
|
||||
|
||||
async def end_session(self, agent_instance_id: str) -> EndSessionResponse:
|
||||
return PendingMessagesResponse(
|
||||
agent_instance_id=response["agent_instance_id"],
|
||||
messages=[Message(**msg) for msg in response["messages"]],
|
||||
status=response["status"],
|
||||
)
|
||||
|
||||
async def send_user_message(
|
||||
self,
|
||||
agent_instance_id: Union[str, uuid.UUID],
|
||||
content: str,
|
||||
mark_as_read: bool = True,
|
||||
) -> Dict[str, Any]:
|
||||
"""Send a user message to an agent instance.
|
||||
|
||||
Args:
|
||||
agent_instance_id: The agent instance ID to send the message to
|
||||
content: Message content
|
||||
mark_as_read: Whether to mark as read (update last_read_message_id) (default: True)
|
||||
|
||||
Returns:
|
||||
Dict containing:
|
||||
- success: Whether the message was created
|
||||
- message_id: ID of the created message
|
||||
- marked_as_read: Whether the message was marked as read
|
||||
|
||||
Raises:
|
||||
ValueError: If agent instance not found or access denied
|
||||
APIError: If the API request fails
|
||||
"""
|
||||
# Validate and convert agent_instance_id
|
||||
agent_instance_id = validate_agent_instance_id(agent_instance_id)
|
||||
|
||||
data = {
|
||||
"agent_instance_id": str(agent_instance_id),
|
||||
"content": content,
|
||||
"mark_as_read": mark_as_read,
|
||||
}
|
||||
|
||||
return await self._make_request("POST", "/api/v1/messages/user", json=data)
|
||||
|
||||
async def request_user_input(
|
||||
self,
|
||||
message_id: Union[str, uuid.UUID],
|
||||
timeout_minutes: int = 1440,
|
||||
poll_interval: float = 10.0,
|
||||
) -> List[str]:
|
||||
"""Request user input for a previously sent agent message.
|
||||
|
||||
This method updates an agent message to require user input and polls for responses.
|
||||
It's useful when you initially send a message without requiring input, but later
|
||||
decide you need user feedback.
|
||||
|
||||
Args:
|
||||
message_id: The message ID to update (must be an agent message)
|
||||
timeout_minutes: Max time to wait for user response in minutes (default: 1440)
|
||||
poll_interval: Time between polls in seconds (default: 10.0)
|
||||
|
||||
Returns:
|
||||
List of user message contents received as responses
|
||||
|
||||
Raises:
|
||||
ValueError: If message not found, already requires input, or not an agent message
|
||||
TimeoutError: If no user response is received within timeout
|
||||
APIError: If the API request fails
|
||||
"""
|
||||
# Convert message_id to string if it's a UUID
|
||||
message_id_str = str(message_id)
|
||||
|
||||
# Call the endpoint to update the message
|
||||
response = await self._make_request(
|
||||
"PATCH", f"/api/v1/messages/{message_id_str}/request-input"
|
||||
)
|
||||
|
||||
agent_instance_id = response["agent_instance_id"]
|
||||
messages = response.get("messages", [])
|
||||
|
||||
if messages:
|
||||
return [msg["content"] for msg in messages]
|
||||
|
||||
# Otherwise, poll for user response
|
||||
timeout_seconds = timeout_minutes * 60
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
all_messages = []
|
||||
|
||||
while asyncio.get_event_loop().time() - start_time < timeout_seconds:
|
||||
# Poll for pending messages using the message_id as last_read
|
||||
pending_response = await self.get_pending_messages(
|
||||
agent_instance_id, message_id_str
|
||||
)
|
||||
|
||||
# If status is "stale", another process has read the messages
|
||||
if pending_response.status == "stale":
|
||||
raise TimeoutError("Another process has read the messages")
|
||||
|
||||
# Check if we got any messages
|
||||
if pending_response.messages:
|
||||
# Collect all message contents
|
||||
all_messages.extend([msg.content for msg in pending_response.messages])
|
||||
return all_messages
|
||||
|
||||
await asyncio.sleep(poll_interval)
|
||||
|
||||
raise TimeoutError(f"No user response received after {timeout_minutes} minutes")
|
||||
|
||||
async def end_session(
|
||||
self, agent_instance_id: Union[str, uuid.UUID]
|
||||
) -> EndSessionResponse:
|
||||
"""End an agent session and mark it as completed.
|
||||
|
||||
Args:
|
||||
@@ -268,7 +387,10 @@ class AsyncOmnaraClient:
|
||||
Returns:
|
||||
EndSessionResponse with success status and final details
|
||||
"""
|
||||
data: Dict[str, Any] = {"agent_instance_id": agent_instance_id}
|
||||
# Validate and convert agent_instance_id to string
|
||||
agent_instance_id_str = validate_agent_instance_id(agent_instance_id)
|
||||
|
||||
data: Dict[str, Any] = {"agent_instance_id": agent_instance_id_str}
|
||||
response = await self._make_request("POST", "/api/v1/sessions/end", json=data)
|
||||
|
||||
return EndSessionResponse(
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
"""Main client for interacting with the Omnara Agent Dashboard API."""
|
||||
|
||||
import time
|
||||
from typing import Optional, Dict, Any
|
||||
import uuid
|
||||
from typing import Optional, Dict, Any, Union, List
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import requests
|
||||
@@ -10,10 +11,14 @@ from urllib3.util.retry import Retry
|
||||
|
||||
from .exceptions import AuthenticationError, TimeoutError, APIError
|
||||
from .models import (
|
||||
LogStepResponse,
|
||||
QuestionResponse,
|
||||
QuestionStatus,
|
||||
EndSessionResponse,
|
||||
CreateMessageResponse,
|
||||
PendingMessagesResponse,
|
||||
Message,
|
||||
)
|
||||
from .utils import (
|
||||
validate_agent_instance_id,
|
||||
build_message_request_data,
|
||||
)
|
||||
|
||||
|
||||
@@ -63,6 +68,7 @@ class OmnaraClient:
|
||||
method: str,
|
||||
endpoint: str,
|
||||
json: Optional[Dict[str, Any]] = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
timeout: Optional[int] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Make an HTTP request to the API.
|
||||
@@ -71,6 +77,7 @@ class OmnaraClient:
|
||||
method: HTTP method (GET, POST, etc.)
|
||||
endpoint: API endpoint path
|
||||
json: JSON body for the request
|
||||
params: Query parameters for the request
|
||||
timeout: Request timeout in seconds
|
||||
|
||||
Returns:
|
||||
@@ -86,7 +93,7 @@ class OmnaraClient:
|
||||
|
||||
try:
|
||||
response = self.session.request(
|
||||
method=method, url=url, json=json, timeout=timeout
|
||||
method=method, url=url, json=json, params=params, timeout=timeout
|
||||
)
|
||||
|
||||
if response.status_code == 401:
|
||||
@@ -106,138 +113,245 @@ class OmnaraClient:
|
||||
except requests.exceptions.RequestException as e:
|
||||
raise APIError(0, f"Request failed: {str(e)}")
|
||||
|
||||
def log_step(
|
||||
def send_message(
|
||||
self,
|
||||
agent_type: str,
|
||||
step_description: str,
|
||||
agent_instance_id: Optional[str] = None,
|
||||
send_push: Optional[bool] = None,
|
||||
send_email: Optional[bool] = None,
|
||||
send_sms: Optional[bool] = None,
|
||||
git_diff: Optional[str] = None,
|
||||
) -> LogStepResponse:
|
||||
"""Log a high-level step the agent is performing.
|
||||
|
||||
Args:
|
||||
agent_type: Type of agent (e.g., 'Claude Code', 'Cursor')
|
||||
step_description: Clear description of what the agent is doing
|
||||
agent_instance_id: Existing agent instance ID (optional)
|
||||
send_push: Send push notification for this step (default: False)
|
||||
send_email: Send email notification for this step (default: False)
|
||||
send_sms: Send SMS notification for this step (default: False)
|
||||
git_diff: Git diff content to include with this step (optional)
|
||||
|
||||
Returns:
|
||||
LogStepResponse with success status, instance ID, and user feedback
|
||||
"""
|
||||
data: Dict[str, Any] = {
|
||||
"agent_type": agent_type,
|
||||
"step_description": step_description,
|
||||
}
|
||||
if agent_instance_id:
|
||||
data["agent_instance_id"] = agent_instance_id
|
||||
if send_push is not None:
|
||||
data["send_push"] = send_push
|
||||
if send_email is not None:
|
||||
data["send_email"] = send_email
|
||||
if send_sms is not None:
|
||||
data["send_sms"] = send_sms
|
||||
if git_diff is not None:
|
||||
data["git_diff"] = git_diff
|
||||
|
||||
response = self._make_request("POST", "/api/v1/steps", json=data)
|
||||
|
||||
return LogStepResponse(
|
||||
success=response["success"],
|
||||
agent_instance_id=response["agent_instance_id"],
|
||||
step_number=response["step_number"],
|
||||
user_feedback=response.get("user_feedback", []),
|
||||
)
|
||||
|
||||
def ask_question(
|
||||
self,
|
||||
agent_instance_id: str,
|
||||
question_text: str,
|
||||
content: str,
|
||||
agent_type: Optional[str] = None,
|
||||
agent_instance_id: Optional[Union[str, uuid.UUID]] = None,
|
||||
requires_user_input: bool = False,
|
||||
timeout_minutes: int = 1440,
|
||||
poll_interval: float = 10.0,
|
||||
send_push: Optional[bool] = None,
|
||||
send_email: Optional[bool] = None,
|
||||
send_sms: Optional[bool] = None,
|
||||
git_diff: Optional[str] = None,
|
||||
) -> QuestionResponse:
|
||||
"""Ask the user a question and wait for their response.
|
||||
|
||||
This method submits the question and then polls for the answer.
|
||||
) -> CreateMessageResponse:
|
||||
"""Send a message to the dashboard.
|
||||
|
||||
Args:
|
||||
agent_instance_id: Agent instance ID
|
||||
question_text: Question to ask the user
|
||||
timeout_minutes: Maximum time to wait for answer in minutes (default: 1440 = 24 hours)
|
||||
poll_interval: Time between polls in seconds (default: 10.0)
|
||||
send_push: Send push notification for this question (default: user preference)
|
||||
send_email: Send email notification for this question (default: user preference)
|
||||
send_sms: Send SMS notification for this question (default: user preference)
|
||||
git_diff: Git diff content to include with this question (optional)
|
||||
content: The message content (step description or question text)
|
||||
agent_type: Type of agent (required if agent_instance_id not provided)
|
||||
agent_instance_id: Existing agent instance ID (optional)
|
||||
requires_user_input: Whether this message requires user input (default: False)
|
||||
timeout_minutes: If requires_user_input, max time to wait in minutes (default: 1440)
|
||||
poll_interval: If requires_user_input, time between polls in seconds (default: 10.0)
|
||||
send_push: Send push notification (default: False for steps, user pref for questions)
|
||||
send_email: Send email notification (default: False for steps, user pref for questions)
|
||||
send_sms: Send SMS notification (default: False for steps, user pref for questions)
|
||||
git_diff: Git diff content to include (optional)
|
||||
|
||||
Returns:
|
||||
QuestionResponse with the user's answer
|
||||
CreateMessageResponse with any queued user messages
|
||||
|
||||
Raises:
|
||||
TimeoutError: If no answer is received within timeout
|
||||
ValueError: If neither agent_type nor agent_instance_id is provided
|
||||
TimeoutError: If requires_user_input and no answer is received within timeout
|
||||
"""
|
||||
# Submit the question
|
||||
data: Dict[str, Any] = {
|
||||
"agent_instance_id": agent_instance_id,
|
||||
"question_text": question_text,
|
||||
}
|
||||
if send_push is not None:
|
||||
data["send_push"] = send_push
|
||||
if send_email is not None:
|
||||
data["send_email"] = send_email
|
||||
if send_sms is not None:
|
||||
data["send_sms"] = send_sms
|
||||
if git_diff is not None:
|
||||
data["git_diff"] = git_diff
|
||||
# If no agent_instance_id provided, generate one client-side
|
||||
if not agent_instance_id:
|
||||
if not agent_type:
|
||||
raise ValueError("agent_type is required when creating a new instance")
|
||||
agent_instance_id = uuid.uuid4()
|
||||
|
||||
# First, try the non-blocking endpoint to create the question
|
||||
response = self._make_request("POST", "/api/v1/questions", json=data, timeout=5)
|
||||
question_id = response["question_id"]
|
||||
# Validate and convert agent_instance_id to string
|
||||
agent_instance_id_str = validate_agent_instance_id(agent_instance_id)
|
||||
|
||||
# Build request data using shared utility
|
||||
data = build_message_request_data(
|
||||
content=content,
|
||||
agent_instance_id=agent_instance_id_str,
|
||||
requires_user_input=requires_user_input,
|
||||
agent_type=agent_type,
|
||||
send_push=send_push,
|
||||
send_email=send_email,
|
||||
send_sms=send_sms,
|
||||
git_diff=git_diff,
|
||||
)
|
||||
|
||||
# Send the message
|
||||
response = self._make_request("POST", "/api/v1/messages/agent", json=data)
|
||||
response_agent_instance_id = response["agent_instance_id"]
|
||||
message_id = response["message_id"]
|
||||
|
||||
queued_contents = [
|
||||
msg["content"] if isinstance(msg, dict) else msg
|
||||
for msg in response.get("queued_user_messages", [])
|
||||
]
|
||||
|
||||
create_response = CreateMessageResponse(
|
||||
success=response["success"],
|
||||
agent_instance_id=response_agent_instance_id,
|
||||
message_id=message_id,
|
||||
queued_user_messages=queued_contents,
|
||||
)
|
||||
|
||||
# If it doesn't require user input, return immediately with any queued messages
|
||||
if not requires_user_input:
|
||||
return create_response
|
||||
|
||||
# Otherwise, we need to poll for user response
|
||||
# Use the message ID we just created as our starting point
|
||||
last_read_message_id = message_id
|
||||
|
||||
# Convert timeout from minutes to seconds
|
||||
timeout_seconds = timeout_minutes * 60
|
||||
|
||||
# Poll for the answer
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < timeout_seconds:
|
||||
status = self.get_question_status(question_id)
|
||||
all_messages = []
|
||||
|
||||
if status.status == "answered" and status.answer:
|
||||
return QuestionResponse(answer=status.answer, question_id=question_id)
|
||||
while time.time() - start_time < timeout_seconds:
|
||||
# Poll for pending messages
|
||||
pending_response = self.get_pending_messages(
|
||||
agent_instance_id_str, last_read_message_id
|
||||
)
|
||||
|
||||
# If status is "stale", another process has read the messages
|
||||
if pending_response.status == "stale":
|
||||
raise TimeoutError("Another process has read the messages")
|
||||
|
||||
# Check if we got any messages
|
||||
if pending_response.messages:
|
||||
# Collect all messages
|
||||
all_messages.extend(pending_response.messages)
|
||||
|
||||
# Return the response with all collected messages
|
||||
create_response.queued_user_messages = [
|
||||
msg.content for msg in all_messages
|
||||
]
|
||||
return create_response
|
||||
|
||||
time.sleep(poll_interval)
|
||||
|
||||
raise TimeoutError(f"Question timed out after {timeout_minutes} minutes")
|
||||
|
||||
def get_question_status(self, question_id: str) -> QuestionStatus:
|
||||
"""Get the current status of a question.
|
||||
def get_pending_messages(
|
||||
self,
|
||||
agent_instance_id: Union[str, uuid.UUID],
|
||||
last_read_message_id: Optional[str] = None,
|
||||
) -> PendingMessagesResponse:
|
||||
"""Get pending user messages for an agent instance.
|
||||
|
||||
Args:
|
||||
question_id: ID of the question to check
|
||||
agent_instance_id: Agent instance ID
|
||||
last_read_message_id: The last message ID that was read (optional)
|
||||
|
||||
Returns:
|
||||
QuestionStatus with current status and answer (if available)
|
||||
PendingMessagesResponse with messages and status
|
||||
"""
|
||||
response = self._make_request("GET", f"/api/v1/questions/{question_id}")
|
||||
# Validate and convert agent_instance_id to string
|
||||
agent_instance_id_str = validate_agent_instance_id(agent_instance_id)
|
||||
|
||||
return QuestionStatus(
|
||||
question_id=response["question_id"],
|
||||
params = {"agent_instance_id": agent_instance_id_str}
|
||||
if last_read_message_id:
|
||||
params["last_read_message_id"] = last_read_message_id
|
||||
|
||||
response = self._make_request("GET", "/api/v1/messages/pending", params=params)
|
||||
|
||||
return PendingMessagesResponse(
|
||||
agent_instance_id=response["agent_instance_id"],
|
||||
messages=[Message(**msg) for msg in response["messages"]],
|
||||
status=response["status"],
|
||||
answer=response.get("answer"),
|
||||
asked_at=response["asked_at"],
|
||||
answered_at=response.get("answered_at"),
|
||||
)
|
||||
|
||||
def end_session(self, agent_instance_id: str) -> EndSessionResponse:
|
||||
def send_user_message(
|
||||
self,
|
||||
agent_instance_id: Union[str, uuid.UUID],
|
||||
content: str,
|
||||
mark_as_read: bool = True,
|
||||
) -> Dict[str, Any]:
|
||||
"""Send a user message to an agent instance.
|
||||
|
||||
Args:
|
||||
agent_instance_id: The agent instance ID to send the message to
|
||||
content: Message content
|
||||
mark_as_read: Whether to mark as read (update last_read_message_id) (default: True)
|
||||
|
||||
Returns:
|
||||
Dict containing:
|
||||
- success: Whether the message was created
|
||||
- message_id: ID of the created message
|
||||
- marked_as_read: Whether the message was marked as read
|
||||
|
||||
Raises:
|
||||
ValueError: If agent instance not found or access denied
|
||||
APIError: If the API request fails
|
||||
"""
|
||||
# Validate and convert agent_instance_id
|
||||
agent_instance_id = validate_agent_instance_id(agent_instance_id)
|
||||
|
||||
data = {
|
||||
"agent_instance_id": str(agent_instance_id),
|
||||
"content": content,
|
||||
"mark_as_read": mark_as_read,
|
||||
}
|
||||
|
||||
return self._make_request("POST", "/api/v1/messages/user", json=data)
|
||||
|
||||
def request_user_input(
|
||||
self,
|
||||
message_id: Union[str, uuid.UUID],
|
||||
timeout_minutes: int = 1440,
|
||||
poll_interval: float = 10.0,
|
||||
) -> List[str]:
|
||||
"""Request user input for a previously sent agent message.
|
||||
|
||||
This method updates an agent message to require user input and polls for responses.
|
||||
It's useful when you initially send a message without requiring input, but later
|
||||
decide you need user feedback.
|
||||
|
||||
Args:
|
||||
message_id: The message ID to update (must be an agent message)
|
||||
timeout_minutes: Max time to wait for user response in minutes (default: 1440)
|
||||
poll_interval: Time between polls in seconds (default: 10.0)
|
||||
|
||||
Returns:
|
||||
List of user message contents received as responses
|
||||
|
||||
Raises:
|
||||
ValueError: If message not found, already requires input, or not an agent message
|
||||
TimeoutError: If no user response is received within timeout
|
||||
APIError: If the API request fails
|
||||
"""
|
||||
# Convert message_id to string if it's a UUID
|
||||
message_id_str = str(message_id)
|
||||
|
||||
# Call the endpoint to update the message
|
||||
response = self._make_request(
|
||||
"PATCH", f"/api/v1/messages/{message_id_str}/request-input"
|
||||
)
|
||||
|
||||
agent_instance_id = response["agent_instance_id"]
|
||||
messages = response.get("messages", [])
|
||||
|
||||
if messages:
|
||||
return [msg["content"] for msg in messages]
|
||||
|
||||
# Otherwise, poll for user response
|
||||
timeout_seconds = timeout_minutes * 60
|
||||
start_time = time.time()
|
||||
all_messages = []
|
||||
|
||||
while time.time() - start_time < timeout_seconds:
|
||||
# Poll for pending messages using the message_id as last_read
|
||||
pending_response = self.get_pending_messages(
|
||||
agent_instance_id, message_id_str
|
||||
)
|
||||
|
||||
# If status is "stale", another process has read the messages
|
||||
if pending_response.status == "stale":
|
||||
raise TimeoutError("Another process has read the messages")
|
||||
|
||||
# Check if we got any messages
|
||||
if pending_response.messages:
|
||||
# Collect all message contents
|
||||
all_messages.extend([msg.content for msg in pending_response.messages])
|
||||
return all_messages
|
||||
|
||||
time.sleep(poll_interval)
|
||||
|
||||
raise TimeoutError(f"No user response received after {timeout_minutes} minutes")
|
||||
|
||||
def end_session(
|
||||
self, agent_instance_id: Union[str, uuid.UUID]
|
||||
) -> EndSessionResponse:
|
||||
"""End an agent session and mark it as completed.
|
||||
|
||||
Args:
|
||||
@@ -246,7 +360,10 @@ class OmnaraClient:
|
||||
Returns:
|
||||
EndSessionResponse with success status and final details
|
||||
"""
|
||||
data: Dict[str, Any] = {"agent_instance_id": agent_instance_id}
|
||||
# Validate and convert agent_instance_id to string
|
||||
agent_instance_id_str = validate_agent_instance_id(agent_instance_id)
|
||||
|
||||
data: Dict[str, Any] = {"agent_instance_id": agent_instance_id_str}
|
||||
response = self._make_request("POST", "/api/v1/sessions/end", json=data)
|
||||
|
||||
return EndSessionResponse(
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Data models for the Omnara SDK."""
|
||||
|
||||
from typing import List, Optional
|
||||
from typing import List
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@@ -14,25 +14,6 @@ class LogStepResponse:
|
||||
user_feedback: List[str]
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuestionResponse:
|
||||
"""Response from asking a question."""
|
||||
|
||||
answer: str
|
||||
question_id: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuestionStatus:
|
||||
"""Status of a question."""
|
||||
|
||||
question_id: str
|
||||
status: str # 'pending' or 'answered'
|
||||
answer: Optional[str]
|
||||
asked_at: str
|
||||
answered_at: Optional[str]
|
||||
|
||||
|
||||
@dataclass
|
||||
class EndSessionResponse:
|
||||
"""Response from ending a session."""
|
||||
@@ -40,3 +21,33 @@ class EndSessionResponse:
|
||||
success: bool
|
||||
agent_instance_id: str
|
||||
final_status: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class CreateMessageResponse:
|
||||
"""Response from creating a message."""
|
||||
|
||||
success: bool
|
||||
agent_instance_id: str
|
||||
message_id: str
|
||||
queued_user_messages: List[str]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Message:
|
||||
"""A message in the conversation."""
|
||||
|
||||
id: str
|
||||
content: str
|
||||
sender_type: str # 'agent' or 'user'
|
||||
created_at: str
|
||||
requires_user_input: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
class PendingMessagesResponse:
|
||||
"""Response from getting pending messages."""
|
||||
|
||||
agent_instance_id: str
|
||||
messages: List[Message]
|
||||
status: str # 'ok' or 'stale'
|
||||
|
||||
78
omnara/sdk/utils.py
Normal file
78
omnara/sdk/utils.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""Utility functions for the Omnara SDK."""
|
||||
|
||||
import uuid
|
||||
from typing import Optional, Union, Dict, Any
|
||||
|
||||
|
||||
def validate_agent_instance_id(
|
||||
agent_instance_id: Optional[Union[str, uuid.UUID]],
|
||||
) -> str:
|
||||
"""Validate and convert agent_instance_id to string.
|
||||
|
||||
Args:
|
||||
agent_instance_id: UUID string, UUID object, or None
|
||||
|
||||
Returns:
|
||||
Validated UUID string
|
||||
|
||||
Raises:
|
||||
ValueError: If agent_instance_id is not a valid UUID
|
||||
"""
|
||||
if agent_instance_id is None:
|
||||
raise ValueError("agent_instance_id cannot be None")
|
||||
|
||||
if isinstance(agent_instance_id, str):
|
||||
try:
|
||||
# Validate it's a valid UUID
|
||||
uuid.UUID(agent_instance_id)
|
||||
return agent_instance_id
|
||||
except ValueError:
|
||||
raise ValueError("agent_instance_id must be a valid UUID string")
|
||||
elif isinstance(agent_instance_id, uuid.UUID):
|
||||
return str(agent_instance_id)
|
||||
else:
|
||||
raise ValueError("agent_instance_id must be a string or UUID object")
|
||||
|
||||
|
||||
def build_message_request_data(
|
||||
content: str,
|
||||
agent_instance_id: str,
|
||||
requires_user_input: bool,
|
||||
agent_type: Optional[str] = None,
|
||||
send_push: Optional[bool] = None,
|
||||
send_email: Optional[bool] = None,
|
||||
send_sms: Optional[bool] = None,
|
||||
git_diff: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Build request data for creating a message.
|
||||
|
||||
Args:
|
||||
content: Message content
|
||||
agent_instance_id: Agent instance ID (already validated)
|
||||
requires_user_input: Whether message requires user input
|
||||
agent_type: Optional agent type
|
||||
send_push: Optional push notification flag
|
||||
send_email: Optional email notification flag
|
||||
send_sms: Optional SMS notification flag
|
||||
git_diff: Optional git diff content
|
||||
|
||||
Returns:
|
||||
Dictionary of request data
|
||||
"""
|
||||
data: Dict[str, Any] = {
|
||||
"content": content,
|
||||
"requires_user_input": requires_user_input,
|
||||
"agent_instance_id": agent_instance_id,
|
||||
}
|
||||
if agent_type:
|
||||
data["agent_type"] = agent_type
|
||||
if send_push is not None:
|
||||
data["send_push"] = send_push
|
||||
if send_email is not None:
|
||||
data["send_email"] = send_email
|
||||
if send_sms is not None:
|
||||
data["send_sms"] = send_sms
|
||||
if git_diff is not None:
|
||||
data["git_diff"] = git_diff
|
||||
|
||||
return data
|
||||
Reference in New Issue
Block a user