mirror of
https://github.com/omnara-ai/omnara.git
synced 2025-08-12 20:39:09 +03:00
better security (#46)
* better security * working * clean up * changes * raise exceptions * changes * fix stdio * clean up --------- Co-authored-by: Kartik Sarangmath <kartiksarangmath@Kartiks-MacBook-Air.local>
This commit is contained in:
@@ -243,6 +243,46 @@ async def revoke_api_key(
|
||||
return {"message": "API key revoked successfully"}
|
||||
|
||||
|
||||
@router.post("/cli-key", response_model=APIKeyResponse)
|
||||
async def create_cli_key(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Create a new CLI-specific API key for the current user"""
|
||||
|
||||
# Always generate a new CLI key
|
||||
try:
|
||||
jwt_token = create_api_key_jwt(
|
||||
user_id=str(current_user.id),
|
||||
expires_in_days=None, # No expiration for CLI keys
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to generate API key: {str(e)}"
|
||||
)
|
||||
|
||||
# Store the new CLI key
|
||||
api_key = APIKey(
|
||||
user_id=current_user.id,
|
||||
name="CLI Key",
|
||||
api_key_hash=get_token_hash(jwt_token),
|
||||
api_key=jwt_token,
|
||||
expires_at=None, # No expiration
|
||||
)
|
||||
|
||||
db.add(api_key)
|
||||
db.commit()
|
||||
db.refresh(api_key)
|
||||
|
||||
return APIKeyResponse(
|
||||
id=str(api_key.id),
|
||||
name=api_key.name,
|
||||
api_key=jwt_token,
|
||||
created_at=api_key.created_at.isoformat(),
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/me")
|
||||
async def delete_user_account(
|
||||
current_user: User = Depends(get_current_user),
|
||||
|
||||
@@ -103,7 +103,7 @@ function createServerConfig(
|
||||
default:
|
||||
return {
|
||||
command: "pipx",
|
||||
args: ["run", "--no-cache", "omnara", "--api-key", apiKey],
|
||||
args: ["run", "--no-cache", "omnara", "--stdio", "--api-key", apiKey],
|
||||
env: {
|
||||
OMNARA_CLIENT_TYPE: client
|
||||
},
|
||||
|
||||
392
omnara/cli.py
392
omnara/cli.py
@@ -8,6 +8,306 @@ This is the main entry point for the omnara command that dispatches to either:
|
||||
import argparse
|
||||
import sys
|
||||
import subprocess
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
import webbrowser
|
||||
import urllib.parse
|
||||
from http.server import HTTPServer, BaseHTTPRequestHandler
|
||||
import secrets
|
||||
import requests
|
||||
import time
|
||||
import threading
|
||||
|
||||
|
||||
def get_current_version():
|
||||
"""Get the current installed version of omnara"""
|
||||
try:
|
||||
from omnara import __version__
|
||||
|
||||
return __version__
|
||||
except Exception:
|
||||
return "unknown"
|
||||
|
||||
|
||||
def check_for_updates():
|
||||
"""Check PyPI for a newer version of omnara"""
|
||||
try:
|
||||
response = requests.get("https://pypi.org/pypi/omnara/json", timeout=2)
|
||||
latest_version = response.json()["info"]["version"]
|
||||
current_version = get_current_version()
|
||||
|
||||
if latest_version != current_version and current_version != "unknown":
|
||||
print(f"\n✨ Update available: {current_version} → {latest_version}")
|
||||
print(" Run: pip install --upgrade omnara\n")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def get_credentials_path():
|
||||
"""Get the path to the credentials file"""
|
||||
config_dir = Path.home() / ".omnara"
|
||||
return config_dir / "credentials.json"
|
||||
|
||||
|
||||
def load_stored_api_key():
|
||||
"""Load API key from credentials file if it exists"""
|
||||
credentials_path = get_credentials_path()
|
||||
|
||||
if not credentials_path.exists():
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(credentials_path, "r") as f:
|
||||
data = json.load(f)
|
||||
api_key = data.get("write_key")
|
||||
if api_key and isinstance(api_key, str):
|
||||
return api_key
|
||||
else:
|
||||
print("Warning: Invalid API key format in credentials file.")
|
||||
return None
|
||||
except json.JSONDecodeError:
|
||||
print(
|
||||
"Warning: Corrupted credentials file. Please re-authenticate with --reauth."
|
||||
)
|
||||
return None
|
||||
except (KeyError, IOError) as e:
|
||||
print(f"Warning: Error reading credentials file: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
def save_api_key(api_key):
|
||||
"""Save API key to credentials file"""
|
||||
credentials_path = get_credentials_path()
|
||||
|
||||
# Create directory if it doesn't exist
|
||||
credentials_path.parent.mkdir(mode=0o700, exist_ok=True)
|
||||
|
||||
# Save the API key
|
||||
data = {"write_key": api_key}
|
||||
with open(credentials_path, "w") as f:
|
||||
json.dump(data, f, indent=2)
|
||||
|
||||
# Set file permissions to 600 (read/write for owner only)
|
||||
os.chmod(credentials_path, 0o600)
|
||||
|
||||
|
||||
class AuthHTTPServer(HTTPServer):
|
||||
"""Custom HTTP server with attributes for authentication"""
|
||||
|
||||
api_key: str | None
|
||||
state: str | None
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.api_key = None
|
||||
self.state = None
|
||||
|
||||
|
||||
class AuthCallbackHandler(BaseHTTPRequestHandler):
|
||||
"""HTTP handler for the OAuth callback"""
|
||||
|
||||
def log_message(self, format, *args):
|
||||
# Suppress default logging
|
||||
pass
|
||||
|
||||
def do_GET(self):
|
||||
# Parse query parameters
|
||||
if "?" in self.path:
|
||||
query_string = self.path.split("?", 1)[1]
|
||||
params = urllib.parse.parse_qs(query_string)
|
||||
|
||||
# Verify state parameter
|
||||
server: AuthHTTPServer = self.server # type: ignore
|
||||
if "state" in params and params["state"][0] == server.state:
|
||||
if "api_key" in params:
|
||||
api_key = params["api_key"][0]
|
||||
# Store the API key in the server instance
|
||||
server.api_key = api_key
|
||||
print("\n✓ Authentication successful!")
|
||||
|
||||
# Send success response with nice styling
|
||||
self.send_response(200)
|
||||
self.send_header("Content-type", "text/html")
|
||||
self.end_headers()
|
||||
self.wfile.write(b"""
|
||||
<html>
|
||||
<head>
|
||||
<title>Omnara CLI - Authentication Successful</title>
|
||||
<style>
|
||||
body {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
min-height: 100vh;
|
||||
background: linear-gradient(135deg, #1a1618 0%, #2a1f3d 100%);
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
color: #fef3c7;
|
||||
}
|
||||
.card {
|
||||
background: rgba(26, 22, 24, 0.8);
|
||||
border: 1px solid rgba(245, 158, 11, 0.2);
|
||||
border-radius: 12px;
|
||||
padding: 48px;
|
||||
text-align: center;
|
||||
box-shadow: 0 20px 40px rgba(0, 0, 0, 0.3),
|
||||
0 0 60px rgba(245, 158, 11, 0.1);
|
||||
max-width: 400px;
|
||||
animation: fadeIn 0.5s ease-out;
|
||||
}
|
||||
@keyframes fadeIn {
|
||||
from { opacity: 0; transform: translateY(20px); }
|
||||
to { opacity: 1; transform: translateY(0); }
|
||||
}
|
||||
.icon {
|
||||
width: 64px;
|
||||
height: 64px;
|
||||
margin: 0 auto 24px;
|
||||
background: rgba(134, 239, 172, 0.2);
|
||||
border-radius: 50%;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
animation: scaleIn 0.5s ease-out 0.2s both;
|
||||
}
|
||||
@keyframes scaleIn {
|
||||
from { transform: scale(0); }
|
||||
to { transform: scale(1); }
|
||||
}
|
||||
.checkmark {
|
||||
width: 32px;
|
||||
height: 32px;
|
||||
stroke: #86efac;
|
||||
stroke-width: 3;
|
||||
fill: none;
|
||||
stroke-dasharray: 100;
|
||||
stroke-dashoffset: 100;
|
||||
animation: draw 0.5s ease-out 0.5s forwards;
|
||||
}
|
||||
@keyframes draw {
|
||||
to { stroke-dashoffset: 0; }
|
||||
}
|
||||
h1 {
|
||||
margin: 0 0 16px;
|
||||
font-size: 24px;
|
||||
font-weight: 600;
|
||||
color: #86efac;
|
||||
}
|
||||
p {
|
||||
margin: 0;
|
||||
opacity: 0.8;
|
||||
line-height: 1.5;
|
||||
}
|
||||
.close-hint {
|
||||
margin-top: 24px;
|
||||
font-size: 14px;
|
||||
opacity: 0.6;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="card">
|
||||
<div class="icon">
|
||||
<svg class="checkmark" viewBox="0 0 24 24">
|
||||
<path d="M20 6L9 17l-5-5" />
|
||||
</svg>
|
||||
</div>
|
||||
<h1>Authentication Successful!</h1>
|
||||
<p>Your CLI has been authorized to access Omnara.</p>
|
||||
<p class="close-hint">You can now close this window and return to your terminal.</p>
|
||||
</div>
|
||||
<script>
|
||||
setTimeout(() => { window.close(); }, 2000);
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
""")
|
||||
return
|
||||
else:
|
||||
# Invalid or missing state parameter
|
||||
self.send_response(403)
|
||||
self.send_header("Content-type", "text/html")
|
||||
self.end_headers()
|
||||
self.wfile.write(b"""
|
||||
<html>
|
||||
<head><title>Omnara CLI - Authentication Failed</title></head>
|
||||
<body style="font-family: sans-serif; text-align: center; padding: 50px;">
|
||||
<h1>Authentication Failed</h1>
|
||||
<p>Invalid authentication state. Please try again.</p>
|
||||
</body>
|
||||
</html>
|
||||
""")
|
||||
return
|
||||
|
||||
# Send error response
|
||||
self.send_response(400)
|
||||
self.send_header("Content-type", "text/html")
|
||||
self.end_headers()
|
||||
self.wfile.write(b"""
|
||||
<html>
|
||||
<head><title>Omnara CLI - Authentication Failed</title></head>
|
||||
<body style="font-family: sans-serif; text-align: center; padding: 50px;">
|
||||
<h1>Authentication Failed</h1>
|
||||
<p>No API key was received. Please try again.</p>
|
||||
</body>
|
||||
</html>
|
||||
""")
|
||||
|
||||
|
||||
def authenticate_via_browser(auth_url="https://omnara.com"):
|
||||
"""Authenticate via browser and return the API key"""
|
||||
|
||||
# Generate a secure random state parameter
|
||||
state = secrets.token_urlsafe(32)
|
||||
|
||||
# Start local server to receive the callback
|
||||
server = AuthHTTPServer(("localhost", 0), AuthCallbackHandler)
|
||||
server.state = state
|
||||
server.api_key = None
|
||||
port = server.server_port
|
||||
|
||||
# Construct the auth URL
|
||||
auth_base = auth_url.rstrip("/")
|
||||
callback_url = f"http://localhost:{port}"
|
||||
auth_url = f"{auth_base}/cli-auth?callback={urllib.parse.quote(callback_url)}&state={urllib.parse.quote(state)}"
|
||||
|
||||
print("\nOpening browser for authentication...")
|
||||
print("If your browser doesn't open automatically, please click this link:")
|
||||
print(f"\n {auth_url}\n")
|
||||
print("Waiting for authentication...")
|
||||
|
||||
# Run server in a thread
|
||||
server_thread = threading.Thread(target=server.serve_forever)
|
||||
server_thread.daemon = True
|
||||
server_thread.start()
|
||||
|
||||
# Open browser
|
||||
try:
|
||||
webbrowser.open(auth_url)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Wait for authentication (with timeout)
|
||||
start_time = time.time()
|
||||
while not server.api_key and (time.time() - start_time) < 300:
|
||||
time.sleep(0.1)
|
||||
|
||||
# Shutdown server in a separate thread to avoid deadlock
|
||||
def shutdown_server():
|
||||
server.shutdown()
|
||||
|
||||
shutdown_thread = threading.Thread(target=shutdown_server)
|
||||
shutdown_thread.start()
|
||||
shutdown_thread.join(timeout=1) # Wait max 1 second for shutdown
|
||||
|
||||
server.server_close()
|
||||
|
||||
if server.api_key:
|
||||
return server.api_key
|
||||
else:
|
||||
raise Exception("Authentication failed - no API key received")
|
||||
|
||||
|
||||
def run_stdio_server(args):
|
||||
@@ -87,10 +387,13 @@ def main():
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
# Run MCP stdio server (default)
|
||||
# Run Claude wrapper (default)
|
||||
omnara --api-key YOUR_API_KEY
|
||||
|
||||
# Run MCP stdio server explicitly
|
||||
# Run Claude wrapper with custom base URL
|
||||
omnara --api-key YOUR_API_KEY --base-url http://localhost:8000
|
||||
|
||||
# Run MCP stdio server
|
||||
omnara --stdio --api-key YOUR_API_KEY
|
||||
|
||||
# Run Claude Code webhook server
|
||||
@@ -102,17 +405,14 @@ Examples:
|
||||
# Run webhook server on custom port
|
||||
omnara --claude-code-webhook --port 8080
|
||||
|
||||
# Run Claude wrapper V3
|
||||
omnara --claude --api-key YOUR_API_KEY
|
||||
|
||||
# Run Claude wrapper with custom base URL
|
||||
omnara --claude --api-key YOUR_API_KEY --base-url http://localhost:8000
|
||||
|
||||
# Run with custom API base URL
|
||||
omnara --stdio --api-key YOUR_API_KEY --base-url http://localhost:8000
|
||||
|
||||
# Run with custom frontend URL for authentication
|
||||
omnara --auth-url http://localhost:3000
|
||||
|
||||
# Run with git diff capture enabled
|
||||
omnara --api-key YOUR_API_KEY --git-diff
|
||||
omnara --stdio --api-key YOUR_API_KEY --git-diff
|
||||
""",
|
||||
)
|
||||
|
||||
@@ -121,7 +421,7 @@ Examples:
|
||||
mode_group.add_argument(
|
||||
"--stdio",
|
||||
action="store_true",
|
||||
help="Run the MCP stdio server (default if no mode specified)",
|
||||
help="Run the MCP stdio server",
|
||||
)
|
||||
mode_group.add_argument(
|
||||
"--claude-code-webhook",
|
||||
@@ -131,7 +431,7 @@ Examples:
|
||||
mode_group.add_argument(
|
||||
"--claude",
|
||||
action="store_true",
|
||||
help="Run the Claude wrapper V3 for Omnara integration",
|
||||
help="Run the Claude wrapper V3 for Omnara integration (default if no mode specified)",
|
||||
)
|
||||
|
||||
# Arguments for webhook mode
|
||||
@@ -153,13 +453,26 @@ Examples:
|
||||
|
||||
# Arguments for stdio mode
|
||||
parser.add_argument(
|
||||
"--api-key", help="API key for authentication (required for stdio mode)"
|
||||
"--api-key", help="API key for authentication (uses stored key if not provided)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--reauth",
|
||||
action="store_true",
|
||||
help="Force re-authentication even if API key exists",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--version", action="store_true", help="Show the current version of omnara"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--base-url",
|
||||
default="https://agent-dashboard-mcp.onrender.com",
|
||||
help="Base URL of the Omnara API server (stdio mode only)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--auth-url",
|
||||
default="https://omnara.com",
|
||||
help="Base URL of the Omnara frontend for authentication (default: https://omnara.com)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--claude-code-permission-tool",
|
||||
action="store_true",
|
||||
@@ -179,26 +492,69 @@ Examples:
|
||||
# Use parse_known_args to capture remaining args for Claude
|
||||
args, unknown_args = parser.parse_known_args()
|
||||
|
||||
# Handle --version flag
|
||||
if args.version:
|
||||
print(f"omnara version {get_current_version()}")
|
||||
sys.exit(0)
|
||||
|
||||
# Check for updates (only when running actual commands, not --version)
|
||||
check_for_updates()
|
||||
|
||||
if args.cloudflare_tunnel and not args.claude_code_webhook:
|
||||
parser.error("--cloudflare-tunnel can only be used with --claude-code-webhook")
|
||||
|
||||
if args.port is not None and not args.claude_code_webhook:
|
||||
parser.error("--port can only be used with --claude-code-webhook")
|
||||
|
||||
# Handle re-authentication
|
||||
if args.reauth:
|
||||
try:
|
||||
print("Re-authenticating...")
|
||||
api_key = authenticate_via_browser(args.auth_url)
|
||||
save_api_key(api_key)
|
||||
args.api_key = api_key
|
||||
print("Re-authentication successful! API key saved.")
|
||||
except Exception as e:
|
||||
parser.error(f"Re-authentication failed: {str(e)}")
|
||||
else:
|
||||
# Load API key from storage if not provided
|
||||
api_key = args.api_key
|
||||
if not api_key and (args.stdio or not args.claude_code_webhook):
|
||||
api_key = load_stored_api_key()
|
||||
|
||||
# Update args with the loaded API key
|
||||
if api_key and not args.api_key:
|
||||
args.api_key = api_key
|
||||
|
||||
if args.claude_code_webhook:
|
||||
run_webhook_server(
|
||||
cloudflare_tunnel=args.cloudflare_tunnel,
|
||||
dangerously_skip_permissions=args.dangerously_skip_permissions,
|
||||
port=args.port,
|
||||
)
|
||||
elif args.claude:
|
||||
elif args.stdio:
|
||||
if not args.api_key:
|
||||
parser.error("--api-key is required for --claude mode")
|
||||
run_claude_wrapper(args.api_key, args.base_url, unknown_args)
|
||||
else:
|
||||
if not args.api_key:
|
||||
parser.error("--api-key is required for stdio mode")
|
||||
try:
|
||||
print("No API key found. Starting authentication...")
|
||||
api_key = authenticate_via_browser(args.auth_url)
|
||||
save_api_key(api_key)
|
||||
args.api_key = api_key
|
||||
print("Authentication successful! API key saved.")
|
||||
except Exception as e:
|
||||
parser.error(f"Authentication failed: {str(e)}")
|
||||
run_stdio_server(args)
|
||||
else:
|
||||
# Default to Claude mode when no mode is specified
|
||||
if not args.api_key:
|
||||
try:
|
||||
print("No API key found. Starting authentication...")
|
||||
api_key = authenticate_via_browser(args.auth_url)
|
||||
save_api_key(api_key)
|
||||
args.api_key = api_key
|
||||
print("Authentication successful! API key saved.")
|
||||
except Exception as e:
|
||||
parser.error(f"Authentication failed: {str(e)}")
|
||||
run_claude_wrapper(args.api_key, args.base_url, unknown_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -65,10 +65,17 @@ class AsyncOmnaraClient:
|
||||
# This fixes SSL verification issues with aiohttp on some systems
|
||||
ssl_context = ssl.create_default_context(cafile=certifi.where())
|
||||
|
||||
# Configure connector
|
||||
connector = aiohttp.TCPConnector(
|
||||
ssl=ssl_context,
|
||||
limit=100,
|
||||
ttl_dns_cache=300,
|
||||
)
|
||||
|
||||
self.session = aiohttp.ClientSession(
|
||||
headers=self.headers,
|
||||
timeout=self.timeout,
|
||||
connector=aiohttp.TCPConnector(ssl=ssl_context),
|
||||
connector=connector,
|
||||
)
|
||||
|
||||
async def close(self):
|
||||
@@ -84,7 +91,7 @@ class AsyncOmnaraClient:
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
timeout: Optional[int] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Make an async HTTP request to the API.
|
||||
"""Make an async HTTP request to the API with retry logic.
|
||||
|
||||
Args:
|
||||
method: HTTP method (GET, POST, etc.)
|
||||
@@ -105,39 +112,73 @@ class AsyncOmnaraClient:
|
||||
assert self.session is not None
|
||||
|
||||
url = urljoin(self.base_url, endpoint)
|
||||
|
||||
# Override timeout if specified
|
||||
request_timeout = ClientTimeout(total=timeout) if timeout else self.timeout
|
||||
|
||||
try:
|
||||
async with self.session.request(
|
||||
method=method,
|
||||
url=url,
|
||||
json=json,
|
||||
params=params,
|
||||
timeout=request_timeout,
|
||||
) as response:
|
||||
if response.status == 401:
|
||||
raise AuthenticationError(
|
||||
"Invalid API key or authentication failed"
|
||||
)
|
||||
# Retry configuration to match urllib3
|
||||
max_retries = 6 # Total attempts (1 initial + 5 retries)
|
||||
backoff_factor = 1.0
|
||||
status_forcelist = {429, 500, 502, 503, 504}
|
||||
|
||||
if not response.ok:
|
||||
try:
|
||||
error_data = await response.json()
|
||||
error_detail = error_data.get("detail", await response.text())
|
||||
except Exception:
|
||||
error_detail = await response.text()
|
||||
raise APIError(response.status, error_detail)
|
||||
last_error = None
|
||||
|
||||
return await response.json()
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
async with self.session.request(
|
||||
method=method,
|
||||
url=url,
|
||||
json=json,
|
||||
params=params,
|
||||
timeout=request_timeout,
|
||||
) as response:
|
||||
if response.status == 401:
|
||||
raise AuthenticationError(
|
||||
"Invalid API key or authentication failed"
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
raise TimeoutError(
|
||||
f"Request timed out after {timeout or self.timeout.total} seconds"
|
||||
)
|
||||
except aiohttp.ClientError as e:
|
||||
raise APIError(0, f"Request failed: {str(e)}")
|
||||
if not response.ok:
|
||||
try:
|
||||
error_data = await response.json()
|
||||
error_detail = error_data.get(
|
||||
"detail", await response.text()
|
||||
)
|
||||
except Exception:
|
||||
error_detail = await response.text()
|
||||
|
||||
# Check if we should retry this status code
|
||||
if response.status in status_forcelist:
|
||||
last_error = APIError(response.status, error_detail)
|
||||
# Continue to retry logic below
|
||||
else:
|
||||
# Don't retry client errors
|
||||
raise APIError(response.status, error_detail)
|
||||
else:
|
||||
# Success!
|
||||
return await response.json()
|
||||
|
||||
except (aiohttp.ClientConnectionError, aiohttp.ClientError) as e:
|
||||
# Connection errors - retry these
|
||||
last_error = APIError(0, f"Request failed: {str(e)}")
|
||||
except asyncio.TimeoutError:
|
||||
last_error = TimeoutError(
|
||||
f"Request timed out after {timeout or self.timeout.total} seconds"
|
||||
)
|
||||
except (AuthenticationError, APIError) as e:
|
||||
if isinstance(e, APIError) and e.status_code in status_forcelist:
|
||||
last_error = e
|
||||
else:
|
||||
# Don't retry auth errors or client errors
|
||||
raise
|
||||
|
||||
# If this is not the last attempt, sleep before retrying
|
||||
if attempt < max_retries - 1 and last_error:
|
||||
sleep_time = min(backoff_factor * (2**attempt), 60.0)
|
||||
await asyncio.sleep(sleep_time)
|
||||
elif last_error:
|
||||
# Last attempt failed, raise the error
|
||||
raise last_error
|
||||
|
||||
# Should never reach here
|
||||
raise APIError(0, "Unexpected retry exhaustion")
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
|
||||
@@ -41,10 +41,20 @@ class OmnaraClient:
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.timeout = timeout
|
||||
|
||||
# Set up session with retries
|
||||
# Set up session with urllib3 retry strategy
|
||||
self.session = requests.Session()
|
||||
|
||||
# Configure retry strategy
|
||||
retry_strategy = Retry(
|
||||
total=3, backoff_factor=0.3, status_forcelist=[500, 502, 503, 504]
|
||||
total=5, # Total number of retries
|
||||
backoff_factor=1.0, # Exponential backoff: 1s, 2s, 4s, 8s, 16s
|
||||
status_forcelist=[429, 500, 502, 503, 504], # Retry on these HTTP codes
|
||||
allowed_methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
|
||||
raise_on_status=False,
|
||||
# Important: retry on connection errors
|
||||
connect=5, # Number of connection-related errors to retry
|
||||
read=5, # Number of read errors to retry
|
||||
other=5, # Number of other errors to retry
|
||||
)
|
||||
adapter = HTTPAdapter(max_retries=retry_strategy)
|
||||
self.session.mount("http://", adapter)
|
||||
|
||||
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "omnara"
|
||||
version = "1.3.16"
|
||||
version = "1.3.17"
|
||||
description = "Omnara Agent Dashboard - MCP Server and Python SDK"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
|
||||
@@ -8,9 +8,11 @@ This server provides:
|
||||
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
import traceback
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi import FastAPI, Request, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
import sentry_sdk
|
||||
from shared.config import settings
|
||||
|
||||
@@ -68,6 +70,23 @@ app.add_middleware(
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def global_exception_handler(request: Request, exc: Exception):
|
||||
"""Global exception handler that logs all unhandled exceptions."""
|
||||
# Log the error with full traceback
|
||||
logger.error(f"Unhandled exception in {request.method} {request.url.path}")
|
||||
logger.error(f"Exception: {type(exc).__name__}: {str(exc)}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
# Re-raise HTTPExceptions to preserve their status codes
|
||||
if isinstance(exc, HTTPException):
|
||||
raise exc
|
||||
|
||||
# For all other exceptions, return 500
|
||||
return JSONResponse(status_code=500, content={"detail": "Internal server error"})
|
||||
|
||||
|
||||
app.include_router(agent_router, prefix="/api/v1")
|
||||
app.mount("/mcp", mcp_app)
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""API routes for agent operations."""
|
||||
|
||||
import logging
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
@@ -28,6 +29,7 @@ from .models import (
|
||||
)
|
||||
|
||||
agent_router = APIRouter(tags=["agents"])
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@agent_router.post("/messages/agent", response_model=CreateMessageResponse)
|
||||
@@ -89,11 +91,15 @@ async def create_agent_message_endpoint(
|
||||
except ValueError as e:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
except HTTPException:
|
||||
db.rollback()
|
||||
raise # Re-raise HTTPExceptions (including UsageLimitError) with their original status
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error in create_agent_message_endpoint: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Internal server error: {str(e)}",
|
||||
detail="Internal server error",
|
||||
)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@@ -813,6 +813,7 @@ async def start_claude(
|
||||
"run",
|
||||
"--no-cache",
|
||||
"omnara",
|
||||
"--stdio",
|
||||
"--api-key",
|
||||
omnara_api_key,
|
||||
"--claude-code-permission-tool",
|
||||
|
||||
@@ -28,11 +28,12 @@ from typing import Any, Dict, Optional
|
||||
|
||||
from omnara.sdk.async_client import AsyncOmnaraClient
|
||||
from omnara.sdk.client import OmnaraClient
|
||||
from omnara.sdk.exceptions import AuthenticationError, APIError
|
||||
|
||||
|
||||
# Constants
|
||||
CLAUDE_LOG_BASE = Path.home() / ".claude" / "projects"
|
||||
WRAPPER_DEBUG_LOG = Path.home() / ".claude" / "wrapper_v3_debug.log"
|
||||
OMNARA_WRAPPER_LOG_DIR = Path.home() / ".omnara" / "claude_wrapper"
|
||||
|
||||
|
||||
class MessageProcessor:
|
||||
@@ -191,8 +192,9 @@ class ClaudeWrapperV3:
|
||||
def _init_logging(self):
|
||||
"""Initialize debug logging"""
|
||||
try:
|
||||
WRAPPER_DEBUG_LOG.parent.mkdir(exist_ok=True, parents=True)
|
||||
self.debug_log_file = open(WRAPPER_DEBUG_LOG, "w")
|
||||
OMNARA_WRAPPER_LOG_DIR.mkdir(exist_ok=True, parents=True)
|
||||
log_file_path = OMNARA_WRAPPER_LOG_DIR / f"{self.session_uuid}.log"
|
||||
self.debug_log_file = open(log_file_path, "w")
|
||||
self.log(
|
||||
f"=== Claude Wrapper V3 Debug Log - {time.strftime('%Y-%m-%d %H:%M:%S')} ==="
|
||||
)
|
||||
@@ -1111,17 +1113,17 @@ class ClaudeWrapperV3:
|
||||
"""Run Claude with Omnara integration (main entry point)"""
|
||||
self.log("[INFO] Starting run() method")
|
||||
|
||||
# Initialize Omnara clients (sync)
|
||||
self.log("[INFO] Initializing Omnara clients...")
|
||||
self.init_omnara_clients()
|
||||
self.log("[INFO] Omnara clients initialized")
|
||||
|
||||
# Create initial session (sync)
|
||||
self.log("[INFO] Creating initial Omnara session...")
|
||||
try:
|
||||
# Initialize Omnara clients (sync)
|
||||
self.log("[INFO] Initializing Omnara clients...")
|
||||
self.init_omnara_clients()
|
||||
self.log("[INFO] Omnara clients initialized")
|
||||
|
||||
# Create initial session (sync)
|
||||
self.log("[INFO] Creating initial Omnara session...")
|
||||
if self.omnara_client_sync:
|
||||
response = self.omnara_client_sync.send_message(
|
||||
content="Claude wrapper V3 session started - waiting for your input...",
|
||||
content="Claude Code session started - waiting for your input...",
|
||||
agent_type="Claude Code",
|
||||
requires_user_input=False,
|
||||
)
|
||||
@@ -1132,8 +1134,72 @@ class ClaudeWrapperV3:
|
||||
if hasattr(self.message_processor, "last_message_id"):
|
||||
self.message_processor.last_message_id = response.message_id
|
||||
self.message_processor.last_message_time = time.time()
|
||||
except AuthenticationError as e:
|
||||
# Log the error
|
||||
self.log(f"[ERROR] Authentication failed: {e}")
|
||||
|
||||
# Print user-friendly error message
|
||||
print(
|
||||
"\nError: Authentication failed. Please check for valid Omnara API key in ~/.omnara/credentials.json.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
# Clean up and exit
|
||||
if self.omnara_client_sync:
|
||||
self.omnara_client_sync.close()
|
||||
if self.debug_log_file:
|
||||
self.debug_log_file.close()
|
||||
sys.exit(1)
|
||||
|
||||
except APIError as e:
|
||||
# Log the error
|
||||
self.log(f"[ERROR] API error: {e}")
|
||||
|
||||
# Print user-friendly error message based on status code
|
||||
if e.status_code >= 500:
|
||||
print(
|
||||
"\nError: Omnara server error. Please try again later.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
elif e.status_code == 404:
|
||||
print(
|
||||
"\nError: Omnara endpoint not found. Please check your base URL.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
else:
|
||||
print(f"\nError: Omnara API error: {e}", file=sys.stderr)
|
||||
|
||||
# Clean up and exit
|
||||
if self.omnara_client_sync:
|
||||
self.omnara_client_sync.close()
|
||||
if self.debug_log_file:
|
||||
self.debug_log_file.close()
|
||||
sys.exit(1)
|
||||
|
||||
except Exception as e:
|
||||
self.log(f"[ERROR] Failed to create initial session: {e}")
|
||||
# Log the error
|
||||
self.log(f"[ERROR] Failed to initialize Omnara connection: {e}")
|
||||
|
||||
# Print user-friendly error message
|
||||
error_msg = str(e)
|
||||
if "connection" in error_msg.lower() or "refused" in error_msg.lower():
|
||||
print("\nError: Could not connect to Omnara server.", file=sys.stderr)
|
||||
print(
|
||||
"Please check your internet connection and try again.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"\nError: Failed to connect to Omnara: {error_msg}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
# Clean up and exit
|
||||
if self.omnara_client_sync:
|
||||
self.omnara_client_sync.close()
|
||||
if self.debug_log_file:
|
||||
self.debug_log_file.close()
|
||||
sys.exit(1)
|
||||
|
||||
# Start Claude in PTY (in thread)
|
||||
claude_thread = threading.Thread(target=self.run_claude_with_pty)
|
||||
@@ -1185,6 +1251,10 @@ class ClaudeWrapperV3:
|
||||
self.log("=== Claude Wrapper V3 Log Ended ===")
|
||||
self.debug_log_file.close()
|
||||
|
||||
# Only print exit message if we're exiting normally (not due to errors)
|
||||
if not sys.exc_info()[0]:
|
||||
print("\nEnded Omnara Claude Session", file=sys.stderr)
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point"""
|
||||
@@ -1204,10 +1274,14 @@ def main():
|
||||
wrapper = ClaudeWrapperV3(api_key=args.api_key, base_url=args.base_url)
|
||||
|
||||
def signal_handler(sig, frame):
|
||||
# Just set the flag and let the finally block handle cleanup
|
||||
wrapper.running = False
|
||||
if wrapper.original_tty_attrs:
|
||||
termios.tcsetattr(sys.stdin, termios.TCSADRAIN, wrapper.original_tty_attrs)
|
||||
sys.exit(0)
|
||||
if wrapper.child_pid:
|
||||
try:
|
||||
# Kill Claude process to trigger exit
|
||||
os.kill(wrapper.child_pid, signal.SIGTERM)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def handle_resize(sig, frame):
|
||||
"""Handle terminal resize signal"""
|
||||
@@ -1230,8 +1304,6 @@ def main():
|
||||
|
||||
try:
|
||||
wrapper.run()
|
||||
except KeyboardInterrupt:
|
||||
signal_handler(None, None)
|
||||
except Exception as e:
|
||||
# Fatal errors still go to stderr
|
||||
print(f"Fatal error: {e}", file=sys.stderr)
|
||||
|
||||
Reference in New Issue
Block a user