mirror of
https://github.com/omnara-ai/omnara.git
synced 2025-08-12 20:39:09 +03:00
better security (#46)
* better security * working * clean up * changes * raise exceptions * changes * fix stdio * clean up --------- Co-authored-by: Kartik Sarangmath <kartiksarangmath@Kartiks-MacBook-Air.local>
This commit is contained in:
392
omnara/cli.py
392
omnara/cli.py
@@ -8,6 +8,306 @@ This is the main entry point for the omnara command that dispatches to either:
|
||||
import argparse
|
||||
import sys
|
||||
import subprocess
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
import webbrowser
|
||||
import urllib.parse
|
||||
from http.server import HTTPServer, BaseHTTPRequestHandler
|
||||
import secrets
|
||||
import requests
|
||||
import time
|
||||
import threading
|
||||
|
||||
|
||||
def get_current_version():
|
||||
"""Get the current installed version of omnara"""
|
||||
try:
|
||||
from omnara import __version__
|
||||
|
||||
return __version__
|
||||
except Exception:
|
||||
return "unknown"
|
||||
|
||||
|
||||
def check_for_updates():
|
||||
"""Check PyPI for a newer version of omnara"""
|
||||
try:
|
||||
response = requests.get("https://pypi.org/pypi/omnara/json", timeout=2)
|
||||
latest_version = response.json()["info"]["version"]
|
||||
current_version = get_current_version()
|
||||
|
||||
if latest_version != current_version and current_version != "unknown":
|
||||
print(f"\n✨ Update available: {current_version} → {latest_version}")
|
||||
print(" Run: pip install --upgrade omnara\n")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def get_credentials_path():
|
||||
"""Get the path to the credentials file"""
|
||||
config_dir = Path.home() / ".omnara"
|
||||
return config_dir / "credentials.json"
|
||||
|
||||
|
||||
def load_stored_api_key():
|
||||
"""Load API key from credentials file if it exists"""
|
||||
credentials_path = get_credentials_path()
|
||||
|
||||
if not credentials_path.exists():
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(credentials_path, "r") as f:
|
||||
data = json.load(f)
|
||||
api_key = data.get("write_key")
|
||||
if api_key and isinstance(api_key, str):
|
||||
return api_key
|
||||
else:
|
||||
print("Warning: Invalid API key format in credentials file.")
|
||||
return None
|
||||
except json.JSONDecodeError:
|
||||
print(
|
||||
"Warning: Corrupted credentials file. Please re-authenticate with --reauth."
|
||||
)
|
||||
return None
|
||||
except (KeyError, IOError) as e:
|
||||
print(f"Warning: Error reading credentials file: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
def save_api_key(api_key):
|
||||
"""Save API key to credentials file"""
|
||||
credentials_path = get_credentials_path()
|
||||
|
||||
# Create directory if it doesn't exist
|
||||
credentials_path.parent.mkdir(mode=0o700, exist_ok=True)
|
||||
|
||||
# Save the API key
|
||||
data = {"write_key": api_key}
|
||||
with open(credentials_path, "w") as f:
|
||||
json.dump(data, f, indent=2)
|
||||
|
||||
# Set file permissions to 600 (read/write for owner only)
|
||||
os.chmod(credentials_path, 0o600)
|
||||
|
||||
|
||||
class AuthHTTPServer(HTTPServer):
|
||||
"""Custom HTTP server with attributes for authentication"""
|
||||
|
||||
api_key: str | None
|
||||
state: str | None
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.api_key = None
|
||||
self.state = None
|
||||
|
||||
|
||||
class AuthCallbackHandler(BaseHTTPRequestHandler):
|
||||
"""HTTP handler for the OAuth callback"""
|
||||
|
||||
def log_message(self, format, *args):
|
||||
# Suppress default logging
|
||||
pass
|
||||
|
||||
def do_GET(self):
|
||||
# Parse query parameters
|
||||
if "?" in self.path:
|
||||
query_string = self.path.split("?", 1)[1]
|
||||
params = urllib.parse.parse_qs(query_string)
|
||||
|
||||
# Verify state parameter
|
||||
server: AuthHTTPServer = self.server # type: ignore
|
||||
if "state" in params and params["state"][0] == server.state:
|
||||
if "api_key" in params:
|
||||
api_key = params["api_key"][0]
|
||||
# Store the API key in the server instance
|
||||
server.api_key = api_key
|
||||
print("\n✓ Authentication successful!")
|
||||
|
||||
# Send success response with nice styling
|
||||
self.send_response(200)
|
||||
self.send_header("Content-type", "text/html")
|
||||
self.end_headers()
|
||||
self.wfile.write(b"""
|
||||
<html>
|
||||
<head>
|
||||
<title>Omnara CLI - Authentication Successful</title>
|
||||
<style>
|
||||
body {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
min-height: 100vh;
|
||||
background: linear-gradient(135deg, #1a1618 0%, #2a1f3d 100%);
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
color: #fef3c7;
|
||||
}
|
||||
.card {
|
||||
background: rgba(26, 22, 24, 0.8);
|
||||
border: 1px solid rgba(245, 158, 11, 0.2);
|
||||
border-radius: 12px;
|
||||
padding: 48px;
|
||||
text-align: center;
|
||||
box-shadow: 0 20px 40px rgba(0, 0, 0, 0.3),
|
||||
0 0 60px rgba(245, 158, 11, 0.1);
|
||||
max-width: 400px;
|
||||
animation: fadeIn 0.5s ease-out;
|
||||
}
|
||||
@keyframes fadeIn {
|
||||
from { opacity: 0; transform: translateY(20px); }
|
||||
to { opacity: 1; transform: translateY(0); }
|
||||
}
|
||||
.icon {
|
||||
width: 64px;
|
||||
height: 64px;
|
||||
margin: 0 auto 24px;
|
||||
background: rgba(134, 239, 172, 0.2);
|
||||
border-radius: 50%;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
animation: scaleIn 0.5s ease-out 0.2s both;
|
||||
}
|
||||
@keyframes scaleIn {
|
||||
from { transform: scale(0); }
|
||||
to { transform: scale(1); }
|
||||
}
|
||||
.checkmark {
|
||||
width: 32px;
|
||||
height: 32px;
|
||||
stroke: #86efac;
|
||||
stroke-width: 3;
|
||||
fill: none;
|
||||
stroke-dasharray: 100;
|
||||
stroke-dashoffset: 100;
|
||||
animation: draw 0.5s ease-out 0.5s forwards;
|
||||
}
|
||||
@keyframes draw {
|
||||
to { stroke-dashoffset: 0; }
|
||||
}
|
||||
h1 {
|
||||
margin: 0 0 16px;
|
||||
font-size: 24px;
|
||||
font-weight: 600;
|
||||
color: #86efac;
|
||||
}
|
||||
p {
|
||||
margin: 0;
|
||||
opacity: 0.8;
|
||||
line-height: 1.5;
|
||||
}
|
||||
.close-hint {
|
||||
margin-top: 24px;
|
||||
font-size: 14px;
|
||||
opacity: 0.6;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="card">
|
||||
<div class="icon">
|
||||
<svg class="checkmark" viewBox="0 0 24 24">
|
||||
<path d="M20 6L9 17l-5-5" />
|
||||
</svg>
|
||||
</div>
|
||||
<h1>Authentication Successful!</h1>
|
||||
<p>Your CLI has been authorized to access Omnara.</p>
|
||||
<p class="close-hint">You can now close this window and return to your terminal.</p>
|
||||
</div>
|
||||
<script>
|
||||
setTimeout(() => { window.close(); }, 2000);
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
""")
|
||||
return
|
||||
else:
|
||||
# Invalid or missing state parameter
|
||||
self.send_response(403)
|
||||
self.send_header("Content-type", "text/html")
|
||||
self.end_headers()
|
||||
self.wfile.write(b"""
|
||||
<html>
|
||||
<head><title>Omnara CLI - Authentication Failed</title></head>
|
||||
<body style="font-family: sans-serif; text-align: center; padding: 50px;">
|
||||
<h1>Authentication Failed</h1>
|
||||
<p>Invalid authentication state. Please try again.</p>
|
||||
</body>
|
||||
</html>
|
||||
""")
|
||||
return
|
||||
|
||||
# Send error response
|
||||
self.send_response(400)
|
||||
self.send_header("Content-type", "text/html")
|
||||
self.end_headers()
|
||||
self.wfile.write(b"""
|
||||
<html>
|
||||
<head><title>Omnara CLI - Authentication Failed</title></head>
|
||||
<body style="font-family: sans-serif; text-align: center; padding: 50px;">
|
||||
<h1>Authentication Failed</h1>
|
||||
<p>No API key was received. Please try again.</p>
|
||||
</body>
|
||||
</html>
|
||||
""")
|
||||
|
||||
|
||||
def authenticate_via_browser(auth_url="https://omnara.com"):
|
||||
"""Authenticate via browser and return the API key"""
|
||||
|
||||
# Generate a secure random state parameter
|
||||
state = secrets.token_urlsafe(32)
|
||||
|
||||
# Start local server to receive the callback
|
||||
server = AuthHTTPServer(("localhost", 0), AuthCallbackHandler)
|
||||
server.state = state
|
||||
server.api_key = None
|
||||
port = server.server_port
|
||||
|
||||
# Construct the auth URL
|
||||
auth_base = auth_url.rstrip("/")
|
||||
callback_url = f"http://localhost:{port}"
|
||||
auth_url = f"{auth_base}/cli-auth?callback={urllib.parse.quote(callback_url)}&state={urllib.parse.quote(state)}"
|
||||
|
||||
print("\nOpening browser for authentication...")
|
||||
print("If your browser doesn't open automatically, please click this link:")
|
||||
print(f"\n {auth_url}\n")
|
||||
print("Waiting for authentication...")
|
||||
|
||||
# Run server in a thread
|
||||
server_thread = threading.Thread(target=server.serve_forever)
|
||||
server_thread.daemon = True
|
||||
server_thread.start()
|
||||
|
||||
# Open browser
|
||||
try:
|
||||
webbrowser.open(auth_url)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Wait for authentication (with timeout)
|
||||
start_time = time.time()
|
||||
while not server.api_key and (time.time() - start_time) < 300:
|
||||
time.sleep(0.1)
|
||||
|
||||
# Shutdown server in a separate thread to avoid deadlock
|
||||
def shutdown_server():
|
||||
server.shutdown()
|
||||
|
||||
shutdown_thread = threading.Thread(target=shutdown_server)
|
||||
shutdown_thread.start()
|
||||
shutdown_thread.join(timeout=1) # Wait max 1 second for shutdown
|
||||
|
||||
server.server_close()
|
||||
|
||||
if server.api_key:
|
||||
return server.api_key
|
||||
else:
|
||||
raise Exception("Authentication failed - no API key received")
|
||||
|
||||
|
||||
def run_stdio_server(args):
|
||||
@@ -87,10 +387,13 @@ def main():
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
# Run MCP stdio server (default)
|
||||
# Run Claude wrapper (default)
|
||||
omnara --api-key YOUR_API_KEY
|
||||
|
||||
# Run MCP stdio server explicitly
|
||||
# Run Claude wrapper with custom base URL
|
||||
omnara --api-key YOUR_API_KEY --base-url http://localhost:8000
|
||||
|
||||
# Run MCP stdio server
|
||||
omnara --stdio --api-key YOUR_API_KEY
|
||||
|
||||
# Run Claude Code webhook server
|
||||
@@ -102,17 +405,14 @@ Examples:
|
||||
# Run webhook server on custom port
|
||||
omnara --claude-code-webhook --port 8080
|
||||
|
||||
# Run Claude wrapper V3
|
||||
omnara --claude --api-key YOUR_API_KEY
|
||||
|
||||
# Run Claude wrapper with custom base URL
|
||||
omnara --claude --api-key YOUR_API_KEY --base-url http://localhost:8000
|
||||
|
||||
# Run with custom API base URL
|
||||
omnara --stdio --api-key YOUR_API_KEY --base-url http://localhost:8000
|
||||
|
||||
# Run with custom frontend URL for authentication
|
||||
omnara --auth-url http://localhost:3000
|
||||
|
||||
# Run with git diff capture enabled
|
||||
omnara --api-key YOUR_API_KEY --git-diff
|
||||
omnara --stdio --api-key YOUR_API_KEY --git-diff
|
||||
""",
|
||||
)
|
||||
|
||||
@@ -121,7 +421,7 @@ Examples:
|
||||
mode_group.add_argument(
|
||||
"--stdio",
|
||||
action="store_true",
|
||||
help="Run the MCP stdio server (default if no mode specified)",
|
||||
help="Run the MCP stdio server",
|
||||
)
|
||||
mode_group.add_argument(
|
||||
"--claude-code-webhook",
|
||||
@@ -131,7 +431,7 @@ Examples:
|
||||
mode_group.add_argument(
|
||||
"--claude",
|
||||
action="store_true",
|
||||
help="Run the Claude wrapper V3 for Omnara integration",
|
||||
help="Run the Claude wrapper V3 for Omnara integration (default if no mode specified)",
|
||||
)
|
||||
|
||||
# Arguments for webhook mode
|
||||
@@ -153,13 +453,26 @@ Examples:
|
||||
|
||||
# Arguments for stdio mode
|
||||
parser.add_argument(
|
||||
"--api-key", help="API key for authentication (required for stdio mode)"
|
||||
"--api-key", help="API key for authentication (uses stored key if not provided)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--reauth",
|
||||
action="store_true",
|
||||
help="Force re-authentication even if API key exists",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--version", action="store_true", help="Show the current version of omnara"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--base-url",
|
||||
default="https://agent-dashboard-mcp.onrender.com",
|
||||
help="Base URL of the Omnara API server (stdio mode only)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--auth-url",
|
||||
default="https://omnara.com",
|
||||
help="Base URL of the Omnara frontend for authentication (default: https://omnara.com)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--claude-code-permission-tool",
|
||||
action="store_true",
|
||||
@@ -179,26 +492,69 @@ Examples:
|
||||
# Use parse_known_args to capture remaining args for Claude
|
||||
args, unknown_args = parser.parse_known_args()
|
||||
|
||||
# Handle --version flag
|
||||
if args.version:
|
||||
print(f"omnara version {get_current_version()}")
|
||||
sys.exit(0)
|
||||
|
||||
# Check for updates (only when running actual commands, not --version)
|
||||
check_for_updates()
|
||||
|
||||
if args.cloudflare_tunnel and not args.claude_code_webhook:
|
||||
parser.error("--cloudflare-tunnel can only be used with --claude-code-webhook")
|
||||
|
||||
if args.port is not None and not args.claude_code_webhook:
|
||||
parser.error("--port can only be used with --claude-code-webhook")
|
||||
|
||||
# Handle re-authentication
|
||||
if args.reauth:
|
||||
try:
|
||||
print("Re-authenticating...")
|
||||
api_key = authenticate_via_browser(args.auth_url)
|
||||
save_api_key(api_key)
|
||||
args.api_key = api_key
|
||||
print("Re-authentication successful! API key saved.")
|
||||
except Exception as e:
|
||||
parser.error(f"Re-authentication failed: {str(e)}")
|
||||
else:
|
||||
# Load API key from storage if not provided
|
||||
api_key = args.api_key
|
||||
if not api_key and (args.stdio or not args.claude_code_webhook):
|
||||
api_key = load_stored_api_key()
|
||||
|
||||
# Update args with the loaded API key
|
||||
if api_key and not args.api_key:
|
||||
args.api_key = api_key
|
||||
|
||||
if args.claude_code_webhook:
|
||||
run_webhook_server(
|
||||
cloudflare_tunnel=args.cloudflare_tunnel,
|
||||
dangerously_skip_permissions=args.dangerously_skip_permissions,
|
||||
port=args.port,
|
||||
)
|
||||
elif args.claude:
|
||||
elif args.stdio:
|
||||
if not args.api_key:
|
||||
parser.error("--api-key is required for --claude mode")
|
||||
run_claude_wrapper(args.api_key, args.base_url, unknown_args)
|
||||
else:
|
||||
if not args.api_key:
|
||||
parser.error("--api-key is required for stdio mode")
|
||||
try:
|
||||
print("No API key found. Starting authentication...")
|
||||
api_key = authenticate_via_browser(args.auth_url)
|
||||
save_api_key(api_key)
|
||||
args.api_key = api_key
|
||||
print("Authentication successful! API key saved.")
|
||||
except Exception as e:
|
||||
parser.error(f"Authentication failed: {str(e)}")
|
||||
run_stdio_server(args)
|
||||
else:
|
||||
# Default to Claude mode when no mode is specified
|
||||
if not args.api_key:
|
||||
try:
|
||||
print("No API key found. Starting authentication...")
|
||||
api_key = authenticate_via_browser(args.auth_url)
|
||||
save_api_key(api_key)
|
||||
args.api_key = api_key
|
||||
print("Authentication successful! API key saved.")
|
||||
except Exception as e:
|
||||
parser.error(f"Authentication failed: {str(e)}")
|
||||
run_claude_wrapper(args.api_key, args.base_url, unknown_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -65,10 +65,17 @@ class AsyncOmnaraClient:
|
||||
# This fixes SSL verification issues with aiohttp on some systems
|
||||
ssl_context = ssl.create_default_context(cafile=certifi.where())
|
||||
|
||||
# Configure connector
|
||||
connector = aiohttp.TCPConnector(
|
||||
ssl=ssl_context,
|
||||
limit=100,
|
||||
ttl_dns_cache=300,
|
||||
)
|
||||
|
||||
self.session = aiohttp.ClientSession(
|
||||
headers=self.headers,
|
||||
timeout=self.timeout,
|
||||
connector=aiohttp.TCPConnector(ssl=ssl_context),
|
||||
connector=connector,
|
||||
)
|
||||
|
||||
async def close(self):
|
||||
@@ -84,7 +91,7 @@ class AsyncOmnaraClient:
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
timeout: Optional[int] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Make an async HTTP request to the API.
|
||||
"""Make an async HTTP request to the API with retry logic.
|
||||
|
||||
Args:
|
||||
method: HTTP method (GET, POST, etc.)
|
||||
@@ -105,39 +112,73 @@ class AsyncOmnaraClient:
|
||||
assert self.session is not None
|
||||
|
||||
url = urljoin(self.base_url, endpoint)
|
||||
|
||||
# Override timeout if specified
|
||||
request_timeout = ClientTimeout(total=timeout) if timeout else self.timeout
|
||||
|
||||
try:
|
||||
async with self.session.request(
|
||||
method=method,
|
||||
url=url,
|
||||
json=json,
|
||||
params=params,
|
||||
timeout=request_timeout,
|
||||
) as response:
|
||||
if response.status == 401:
|
||||
raise AuthenticationError(
|
||||
"Invalid API key or authentication failed"
|
||||
)
|
||||
# Retry configuration to match urllib3
|
||||
max_retries = 6 # Total attempts (1 initial + 5 retries)
|
||||
backoff_factor = 1.0
|
||||
status_forcelist = {429, 500, 502, 503, 504}
|
||||
|
||||
if not response.ok:
|
||||
try:
|
||||
error_data = await response.json()
|
||||
error_detail = error_data.get("detail", await response.text())
|
||||
except Exception:
|
||||
error_detail = await response.text()
|
||||
raise APIError(response.status, error_detail)
|
||||
last_error = None
|
||||
|
||||
return await response.json()
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
async with self.session.request(
|
||||
method=method,
|
||||
url=url,
|
||||
json=json,
|
||||
params=params,
|
||||
timeout=request_timeout,
|
||||
) as response:
|
||||
if response.status == 401:
|
||||
raise AuthenticationError(
|
||||
"Invalid API key or authentication failed"
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
raise TimeoutError(
|
||||
f"Request timed out after {timeout or self.timeout.total} seconds"
|
||||
)
|
||||
except aiohttp.ClientError as e:
|
||||
raise APIError(0, f"Request failed: {str(e)}")
|
||||
if not response.ok:
|
||||
try:
|
||||
error_data = await response.json()
|
||||
error_detail = error_data.get(
|
||||
"detail", await response.text()
|
||||
)
|
||||
except Exception:
|
||||
error_detail = await response.text()
|
||||
|
||||
# Check if we should retry this status code
|
||||
if response.status in status_forcelist:
|
||||
last_error = APIError(response.status, error_detail)
|
||||
# Continue to retry logic below
|
||||
else:
|
||||
# Don't retry client errors
|
||||
raise APIError(response.status, error_detail)
|
||||
else:
|
||||
# Success!
|
||||
return await response.json()
|
||||
|
||||
except (aiohttp.ClientConnectionError, aiohttp.ClientError) as e:
|
||||
# Connection errors - retry these
|
||||
last_error = APIError(0, f"Request failed: {str(e)}")
|
||||
except asyncio.TimeoutError:
|
||||
last_error = TimeoutError(
|
||||
f"Request timed out after {timeout or self.timeout.total} seconds"
|
||||
)
|
||||
except (AuthenticationError, APIError) as e:
|
||||
if isinstance(e, APIError) and e.status_code in status_forcelist:
|
||||
last_error = e
|
||||
else:
|
||||
# Don't retry auth errors or client errors
|
||||
raise
|
||||
|
||||
# If this is not the last attempt, sleep before retrying
|
||||
if attempt < max_retries - 1 and last_error:
|
||||
sleep_time = min(backoff_factor * (2**attempt), 60.0)
|
||||
await asyncio.sleep(sleep_time)
|
||||
elif last_error:
|
||||
# Last attempt failed, raise the error
|
||||
raise last_error
|
||||
|
||||
# Should never reach here
|
||||
raise APIError(0, "Unexpected retry exhaustion")
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
|
||||
@@ -41,10 +41,20 @@ class OmnaraClient:
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.timeout = timeout
|
||||
|
||||
# Set up session with retries
|
||||
# Set up session with urllib3 retry strategy
|
||||
self.session = requests.Session()
|
||||
|
||||
# Configure retry strategy
|
||||
retry_strategy = Retry(
|
||||
total=3, backoff_factor=0.3, status_forcelist=[500, 502, 503, 504]
|
||||
total=5, # Total number of retries
|
||||
backoff_factor=1.0, # Exponential backoff: 1s, 2s, 4s, 8s, 16s
|
||||
status_forcelist=[429, 500, 502, 503, 504], # Retry on these HTTP codes
|
||||
allowed_methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
|
||||
raise_on_status=False,
|
||||
# Important: retry on connection errors
|
||||
connect=5, # Number of connection-related errors to retry
|
||||
read=5, # Number of read errors to retry
|
||||
other=5, # Number of other errors to retry
|
||||
)
|
||||
adapter = HTTPAdapter(max_retries=retry_strategy)
|
||||
self.session.mount("http://", adapter)
|
||||
|
||||
Reference in New Issue
Block a user