mirror of
https://github.com/omnara-ai/omnara.git
synced 2025-08-12 20:39:09 +03:00
Initial commit
This commit is contained in:
52
.dockerignore
Normal file
52
.dockerignore
Normal file
@@ -0,0 +1,52 @@
|
||||
**/__pycache__
|
||||
**/*.pyc
|
||||
**/*.pyo
|
||||
**/*.pyd
|
||||
**/.Python
|
||||
**/pip-log.txt
|
||||
**/pip-delete-this-directory.txt
|
||||
**/.tox/
|
||||
**/.coverage
|
||||
**/.coverage.*
|
||||
**/.cache
|
||||
**/nosetests.xml
|
||||
**/coverage.xml
|
||||
**/*.cover
|
||||
**/*.log
|
||||
**/.git
|
||||
**/.gitignore
|
||||
**/.mypy_cache
|
||||
**/.pytest_cache
|
||||
**/.hypothesis
|
||||
**/.ruff_cache
|
||||
|
||||
# Virtual environments
|
||||
**/venv/
|
||||
**/env/
|
||||
**/ENV/
|
||||
**/.venv/
|
||||
|
||||
# Node
|
||||
**/node_modules/
|
||||
**/npm-debug.log*
|
||||
**/yarn-debug.log*
|
||||
**/yarn-error.log*
|
||||
|
||||
# IDE
|
||||
**/.vscode/
|
||||
**/.idea/
|
||||
**/*.swp
|
||||
**/*.swo
|
||||
|
||||
# OS
|
||||
**/.DS_Store
|
||||
**/Thumbs.db
|
||||
|
||||
# Environment files
|
||||
**/.env
|
||||
**/.env.*
|
||||
|
||||
# Build artifacts
|
||||
**/dist/
|
||||
**/build/
|
||||
**/*.egg-info/
|
||||
29
.env.example
Normal file
29
.env.example
Normal file
@@ -0,0 +1,29 @@
|
||||
# development or production
|
||||
ENVIRONMENT=development
|
||||
|
||||
# Database
|
||||
PRODUCTION_DB_URL=postgresql://postgres:XXXXXX.supabase.co:5432/postgres # For Render use Session Pooler
|
||||
DEVELOPMENT_DB_URL=postgresql://user:password@localhost:5432/agent_dashboard
|
||||
|
||||
# MCP Server
|
||||
MCP_SERVER_PORT=8080
|
||||
|
||||
# Backend API
|
||||
API_PORT=8000
|
||||
# Frontend URLs - JSON array of allowed frontend origins
|
||||
# Single URL: FRONTEND_URLS="https://example.com"
|
||||
# Multiple URLs: FRONTEND_URLS='["http://localhost:3000", "https://example.com"]'
|
||||
# Production example with all domains:
|
||||
# FRONTEND_URLS='["https://omnara.ai", "https://www.omnara.ai", "https://omnara.com", "https://www.omnara.com", "https://app.omnara.ai"]'
|
||||
FRONTEND_URLS='["http://localhost:3000"]'
|
||||
|
||||
# Supabase Configuration
|
||||
SUPABASE_URL=https://xxxxxxxxxxxx.supabase.co
|
||||
SUPABASE_ANON_KEY=your-anon-key-here
|
||||
SUPABASE_SERVICE_ROLE_KEY=your-service-role-key-here
|
||||
|
||||
# JWT Signing Keys for API Keys (generate with scripts/generate_jwt_keys.py)
|
||||
JWT_PRIVATE_KEY=your-jwt-private-key-here
|
||||
JWT_PUBLIC_KEY=your-jwt-public-key-here
|
||||
|
||||
SENTRY_DSN=123.us.sentry.io/456
|
||||
53
.github/CONTRIBUTING.md
vendored
Normal file
53
.github/CONTRIBUTING.md
vendored
Normal file
@@ -0,0 +1,53 @@
|
||||
# Contributing to Omnara
|
||||
|
||||
Thanks for your interest in contributing!
|
||||
|
||||
## Quick Start
|
||||
|
||||
1. Fork and clone the repository
|
||||
2. Set up your development environment:
|
||||
```bash
|
||||
python -m venv .venv
|
||||
source .venv/bin/activate # Windows: .venv\Scripts\activate
|
||||
make dev-install
|
||||
make pre-commit-install
|
||||
```
|
||||
3. Set up PostgreSQL and configure `DATABASE_URL` in `.env`
|
||||
4. Generate JWT keys: `python scripts/generate_jwt_keys.py`
|
||||
5. Run migrations: `cd shared && alembic upgrade head`
|
||||
|
||||
## Development Process
|
||||
|
||||
1. Create a branch: `feature/`, `bugfix/`, or `docs/`
|
||||
2. Make your changes
|
||||
3. Run checks: `make lint` and `make test`
|
||||
4. Submit a pull request
|
||||
|
||||
## Code Style
|
||||
|
||||
- Python 3.11+
|
||||
- Type hints required
|
||||
- Follow existing patterns
|
||||
- Tests for new features
|
||||
|
||||
## Database Changes
|
||||
|
||||
When modifying models:
|
||||
1. Edit models in `shared/database/models.py`
|
||||
2. Generate migration: `cd shared && alembic revision --autogenerate -m "description"`
|
||||
3. Test migration before committing
|
||||
|
||||
## Commit Messages
|
||||
|
||||
Use conventional commits:
|
||||
- `feat:` New feature
|
||||
- `fix:` Bug fix
|
||||
- `docs:` Documentation
|
||||
- `refactor:` Code refactoring
|
||||
- `test:` Tests
|
||||
|
||||
Example: `feat: add API key rotation endpoint`
|
||||
|
||||
## Questions?
|
||||
|
||||
Open an issue or discussion on GitHub!
|
||||
27
.github/ISSUE_TEMPLATE/bug_report.md
vendored
Normal file
27
.github/ISSUE_TEMPLATE/bug_report.md
vendored
Normal file
@@ -0,0 +1,27 @@
|
||||
---
|
||||
name: Bug Report
|
||||
about: Report a bug
|
||||
title: '[BUG] '
|
||||
labels: 'bug'
|
||||
---
|
||||
|
||||
**Describe the bug**
|
||||
A clear description of what the bug is.
|
||||
|
||||
**To Reproduce**
|
||||
Steps to reproduce:
|
||||
1.
|
||||
2.
|
||||
3.
|
||||
|
||||
**Expected behavior**
|
||||
What you expected to happen.
|
||||
|
||||
**Error logs**
|
||||
```
|
||||
Paste any relevant error messages
|
||||
```
|
||||
|
||||
**Environment:**
|
||||
- Python version:
|
||||
- Omnara version:
|
||||
1
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
1
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
@@ -0,0 +1 @@
|
||||
blank_issues_enabled: true
|
||||
15
.github/ISSUE_TEMPLATE/feature_request.md
vendored
Normal file
15
.github/ISSUE_TEMPLATE/feature_request.md
vendored
Normal file
@@ -0,0 +1,15 @@
|
||||
---
|
||||
name: Feature Request
|
||||
about: Suggest a new feature
|
||||
title: '[FEATURE] '
|
||||
labels: 'enhancement'
|
||||
---
|
||||
|
||||
**Problem to solve**
|
||||
What problem would this feature solve?
|
||||
|
||||
**Proposed solution**
|
||||
How would you like it to work?
|
||||
|
||||
**Alternatives considered**
|
||||
Any other approaches you've thought about?
|
||||
70
.github/workflows/ci.yml
vendored
Normal file
70
.github/workflows/ci.yml
vendored
Normal file
@@ -0,0 +1,70 @@
|
||||
name: CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
pull_request:
|
||||
|
||||
jobs:
|
||||
lint:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.12'
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip install -r requirements-dev.txt
|
||||
make install
|
||||
pip install -e .
|
||||
|
||||
- name: Run linting and formatting checks
|
||||
run: make lint
|
||||
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ['3.11', '3.12']
|
||||
|
||||
services:
|
||||
postgres:
|
||||
image: postgres:16
|
||||
env:
|
||||
POSTGRES_PASSWORD: postgres
|
||||
options: >-
|
||||
--health-cmd pg_isready
|
||||
--health-interval 10s
|
||||
--health-timeout 5s
|
||||
--health-retries 5
|
||||
ports:
|
||||
- 5432:5432
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip install -r requirements-dev.txt
|
||||
make install
|
||||
pip install -e .
|
||||
|
||||
- name: Run migrations
|
||||
env:
|
||||
DATABASE_URL: postgresql://postgres:postgres@localhost:5432/postgres
|
||||
run: |
|
||||
cd shared && alembic upgrade head
|
||||
|
||||
- name: Run tests
|
||||
env:
|
||||
DATABASE_URL: postgresql://postgres:postgres@localhost:5432/postgres
|
||||
ENVIRONMENT: test
|
||||
run: |
|
||||
pytest -v --cov --cov-report=term-missing
|
||||
echo "Coverage Report:"
|
||||
coverage report
|
||||
75
.github/workflows/release.yml
vendored
Normal file
75
.github/workflows/release.yml
vendored
Normal file
@@ -0,0 +1,75 @@
|
||||
name: Release
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v*'
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
jobs:
|
||||
release:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.12'
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install build twine
|
||||
|
||||
- name: Build package
|
||||
run: python -m build
|
||||
|
||||
- name: Check package
|
||||
run: twine check dist/*
|
||||
|
||||
- name: Extract version
|
||||
id: get_version
|
||||
run: echo "VERSION=${GITHUB_REF#refs/tags/v}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Generate changelog
|
||||
id: changelog
|
||||
run: |
|
||||
echo "# Changelog" > changelog.md
|
||||
echo "" >> changelog.md
|
||||
# Get commits since last tag
|
||||
if git describe --tags --abbrev=0 HEAD^ 2>/dev/null; then
|
||||
LAST_TAG=$(git describe --tags --abbrev=0 HEAD^)
|
||||
echo "## Changes since $LAST_TAG" >> changelog.md
|
||||
git log $LAST_TAG..HEAD --pretty=format:"- %s" >> changelog.md
|
||||
else
|
||||
echo "## Initial Release" >> changelog.md
|
||||
echo "First release of Omnara!" >> changelog.md
|
||||
fi
|
||||
|
||||
- name: Create GitHub Release
|
||||
uses: softprops/action-gh-release@v2
|
||||
with:
|
||||
body_path: changelog.md
|
||||
files: dist/*
|
||||
draft: false
|
||||
prerelease: ${{ contains(github.ref, 'alpha') || contains(github.ref, 'beta') || contains(github.ref, 'rc') }}
|
||||
|
||||
- name: Publish to Test PyPI
|
||||
env:
|
||||
TWINE_USERNAME: __token__
|
||||
TWINE_PASSWORD: ${{ secrets.TEST_PYPI_API_TOKEN }}
|
||||
run: |
|
||||
twine upload --repository testpypi dist/*
|
||||
|
||||
- name: Publish to PyPI
|
||||
env:
|
||||
TWINE_USERNAME: __token__
|
||||
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
|
||||
run: |
|
||||
twine upload dist/*
|
||||
if: ${{ !contains(github.ref, 'alpha') && !contains(github.ref, 'beta') && !contains(github.ref, 'rc') }}
|
||||
63
.gitignore
vendored
Normal file
63
.gitignore
vendored
Normal file
@@ -0,0 +1,63 @@
|
||||
# Python
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
*.so
|
||||
.Python
|
||||
venv/
|
||||
env/
|
||||
ENV/
|
||||
.venv
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
.pytest_cache/
|
||||
.coverage
|
||||
htmlcov/
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
.ruff_cache/
|
||||
coverage/
|
||||
*.egg-info/
|
||||
*.egg
|
||||
dist/
|
||||
build/
|
||||
|
||||
# Node
|
||||
node_modules/
|
||||
npm-debug.log*
|
||||
yarn-debug.log*
|
||||
yarn-error.log*
|
||||
.pnpm-debug.log*
|
||||
dist/
|
||||
build/
|
||||
.next/
|
||||
.cache/
|
||||
|
||||
# Environment
|
||||
.env
|
||||
.env.local
|
||||
.env.*.local
|
||||
|
||||
# IDE
|
||||
.vscode/
|
||||
.idea/
|
||||
*.swp
|
||||
*.swo
|
||||
*~
|
||||
.DS_Store
|
||||
|
||||
# Database
|
||||
*.db
|
||||
*.sqlite3
|
||||
postgres-data/
|
||||
|
||||
# Logs
|
||||
logs/
|
||||
*.log
|
||||
test-venv/
|
||||
.claude/
|
||||
|
||||
# Directories
|
||||
frontend/
|
||||
mobile/
|
||||
43
.pre-commit-config.yaml
Normal file
43
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,43 @@
|
||||
repos:
|
||||
# Ruff for fast Python linting and formatting
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.11.13
|
||||
hooks:
|
||||
# Run the linter
|
||||
- id: ruff
|
||||
args: [--fix]
|
||||
files: ^(backend|mcp_server|shared)/.*\.py$
|
||||
# Run the formatter
|
||||
- id: ruff-format
|
||||
files: ^(backend|mcp_server|shared)/.*\.py$
|
||||
|
||||
# Pyright for type checking (using npm package)
|
||||
- repo: https://github.com/RobertCraigie/pyright-python
|
||||
rev: v1.1.390
|
||||
hooks:
|
||||
- id: pyright
|
||||
files: ^(backend|mcp_server|shared)/.*\.py$
|
||||
additional_dependencies: []
|
||||
|
||||
# General file checks
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v5.0.0
|
||||
hooks:
|
||||
- id: trailing-whitespace
|
||||
files: ^(backend|mcp_server|shared)/.*\.py$
|
||||
- id: end-of-file-fixer
|
||||
files: ^(backend|mcp_server|shared)/.*\.py$
|
||||
- id: check-yaml
|
||||
- id: check-added-large-files
|
||||
- id: check-merge-conflict
|
||||
|
||||
# Custom hooks for database migrations
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: check-migration-needed
|
||||
name: Check if migration needed for schema changes
|
||||
entry: python scripts/check-migration-needed.py
|
||||
language: system
|
||||
always_run: true
|
||||
pass_filenames: false
|
||||
stages: [commit]
|
||||
176
CLAUDE.md
Normal file
176
CLAUDE.md
Normal file
@@ -0,0 +1,176 @@
|
||||
# Claude Code Development Guide for Omnara
|
||||
|
||||
Welcome Claude! This document contains everything you need to know to work effectively on the Omnara project.
|
||||
|
||||
## Project Overview
|
||||
|
||||
Omnara is a platform that allows users to communicate with their AI agents (like you!) from anywhere. It uses the Model Context Protocol (MCP) to enable real-time communication between agents and users through a web dashboard.
|
||||
|
||||
## Quick Context
|
||||
|
||||
- **Purpose**: Let users see what their AI agents are doing and communicate with them in real-time
|
||||
- **Key Innovation**: Agents can ask questions and receive feedback while working
|
||||
- **Architecture**: Separate read (backend) and write (servers) operations for optimal performance
|
||||
- **Open Source**: This is a community project - code quality and clarity matter!
|
||||
|
||||
## Project Structure
|
||||
|
||||
```
|
||||
omnara/
|
||||
├── backend/ # FastAPI - Web dashboard API (read operations)
|
||||
├── servers/ # FastAPI + MCP - Agent communication server (write operations)
|
||||
├── shared/ # Shared database models and infrastructure
|
||||
├── omnara/ # Python package directory
|
||||
│ └── sdk/ # Python SDK for agent integration
|
||||
├── cli/ # Node.js CLI tool for MCP configuration
|
||||
├── scripts/ # Utility scripts (JWT generation, linting, etc.)
|
||||
├── tests/ # Integration tests
|
||||
└── webhooks/ # Webhook handlers (e.g., claude_code.py)
|
||||
```
|
||||
|
||||
## Key Technical Decisions
|
||||
|
||||
### Authentication Architecture
|
||||
- **Two separate JWT systems**:
|
||||
1. **Backend**: Supabase JWTs for web users
|
||||
2. **Servers**: Custom JWT with weaker RSA (shorter API keys for agents)
|
||||
- API keys are hashed (SHA256) before storage - never store raw tokens
|
||||
|
||||
### Database Design
|
||||
- **PostgreSQL** with **SQLAlchemy 2.0+**
|
||||
- **Alembic** for migrations - ALWAYS create migrations for schema changes
|
||||
- Multi-tenant design - all data is scoped by user_id
|
||||
- Key tables: users, agent_types, agent_instances, agent_steps, agent_questions, agent_user_feedback, api_keys
|
||||
|
||||
### Server Architecture
|
||||
- **Unified server** (`servers/app.py`) supports both MCP and REST
|
||||
- MCP endpoint: `/mcp/`
|
||||
- REST endpoints: `/api/v1/*`
|
||||
- Both use the same authentication and business logic
|
||||
|
||||
## Development Workflow
|
||||
|
||||
### Setting Up
|
||||
1. **Always activate the virtual environment first**:
|
||||
```bash
|
||||
source .venv/bin/activate # macOS/Linux
|
||||
.venv\Scripts\activate # Windows
|
||||
```
|
||||
|
||||
2. **Install pre-commit hooks** (one-time):
|
||||
```bash
|
||||
make pre-commit-install
|
||||
```
|
||||
|
||||
### Before Making Changes
|
||||
1. **Check current branch**: Ensure you're on the right branch
|
||||
2. **Update dependencies**: Run `pip install -r requirements.txt` if needed
|
||||
3. **Check migrations**: Run `alembic current` in `shared/` directory
|
||||
|
||||
### Making Changes
|
||||
|
||||
#### Database Changes
|
||||
1. Modify models in `shared/models/`
|
||||
2. Generate migration:
|
||||
```bash
|
||||
cd shared/
|
||||
alembic revision --autogenerate -m "Descriptive message"
|
||||
```
|
||||
3. Review the generated migration file
|
||||
4. Test migration: `alembic upgrade head`
|
||||
5. Include migration file in your commit
|
||||
|
||||
#### Code Changes
|
||||
1. **Follow existing patterns** - check similar files first
|
||||
2. **Use type hints** - We use Python 3.12 with full type annotations
|
||||
3. **Import style**: Prefer absolute imports from project root
|
||||
|
||||
#### Testing
|
||||
```bash
|
||||
make test # Run all tests
|
||||
make test-integration # Integration tests (needs Docker)
|
||||
```
|
||||
|
||||
### Before Committing
|
||||
1. **Run linting and formatting**:
|
||||
```bash
|
||||
make lint # Check for issues
|
||||
make format # Auto-fix formatting
|
||||
```
|
||||
|
||||
2. **Verify your changes work**:
|
||||
- Test the specific functionality you changed
|
||||
- Run relevant test suites
|
||||
- Check that migrations apply cleanly
|
||||
|
||||
3. **Update documentation** if you changed functionality
|
||||
|
||||
## Common Tasks
|
||||
|
||||
### Adding a New API Endpoint
|
||||
1. Add route in `backend/api/` or `servers/fastapi_server/routers.py`
|
||||
2. Create Pydantic models for request/response in `models.py`
|
||||
3. Add database queries in appropriate query files
|
||||
4. Write tests for the endpoint
|
||||
|
||||
### Adding a New MCP Tool
|
||||
1. Add tool definition in `servers/mcp_server/tools.py`
|
||||
2. Register tool in `servers/mcp_server/server.py`
|
||||
3. Share logic with REST endpoint if applicable
|
||||
4. Update agent documentation
|
||||
|
||||
### Modifying Database Schema
|
||||
1. Change models in `shared/models/`
|
||||
2. Generate and review migration
|
||||
3. Update any affected queries
|
||||
4. Update Pydantic models if needed
|
||||
5. Test thoroughly with existing data
|
||||
|
||||
## Important Files to Know
|
||||
|
||||
- `shared/config.py` - Central configuration using Pydantic settings
|
||||
- `shared/models/base.py` - SQLAlchemy base configuration
|
||||
- `servers/app.py` - Unified server entry point
|
||||
- `backend/auth/` - Authentication logic for web users
|
||||
- `servers/fastapi_server/auth.py` - Agent authentication
|
||||
|
||||
## Environment Variables
|
||||
|
||||
Key variables you might need:
|
||||
- `DATABASE_URL` - PostgreSQL connection
|
||||
- `JWT_PUBLIC_KEY` / `JWT_PRIVATE_KEY` - For agent auth
|
||||
- `SUPABASE_URL` / `SUPABASE_ANON_KEY` - For web auth
|
||||
- `ENVIRONMENT` - Set to "development" for auto-reload
|
||||
|
||||
## Common Pitfalls to Avoid
|
||||
|
||||
1. **Don't commit without migrations** - Pre-commit hooks will catch this
|
||||
2. **Don't store raw JWT tokens** - Always hash API keys
|
||||
3. **Don't mix authentication systems** - Backend uses Supabase, Servers use custom JWT
|
||||
4. **Don't forget user scoping** - All queries must filter by user_id
|
||||
5. **Don't skip type hints** - Pyright will complain
|
||||
|
||||
## Debugging Tips
|
||||
|
||||
1. **Database issues**: Check migrations are up to date
|
||||
2. **Auth failures**: Verify JWT keys are properly formatted (with newlines)
|
||||
3. **Import errors**: Ensure you're using absolute imports
|
||||
4. **Type errors**: Run `make typecheck` to catch issues early
|
||||
|
||||
## Getting Help
|
||||
|
||||
- Check existing code for patterns
|
||||
- Read test files for usage examples
|
||||
- Error messages usually indicate what's wrong
|
||||
- The codebase is well-structured - similar things are grouped together
|
||||
|
||||
## Your Superpowers on This Project
|
||||
|
||||
As Claude Code, you're particularly good at:
|
||||
- Understanding the full codebase quickly
|
||||
- Maintaining consistency across files
|
||||
- Catching potential security issues
|
||||
- Writing comprehensive tests
|
||||
- Suggesting architectural improvements
|
||||
|
||||
Remember: This is an open-source project that helps AI agents communicate with humans. Your work here directly improves the AI-human collaboration experience!
|
||||
9
LICENSE
Normal file
9
LICENSE
Normal file
@@ -0,0 +1,9 @@
|
||||
MIT License
|
||||
|
||||
Copyright 2025 Omnara Contributors
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
63
Makefile
Normal file
63
Makefile
Normal file
@@ -0,0 +1,63 @@
|
||||
.PHONY: install lint format typecheck dev-install pre-commit-install pre-commit-run
|
||||
|
||||
# Install production dependencies
|
||||
install:
|
||||
pip install -r backend/requirements.txt
|
||||
pip install -r servers/requirements.txt
|
||||
|
||||
# Install development dependencies
|
||||
dev-install: install
|
||||
pip install -r requirements-dev.txt
|
||||
pip install -r sdk/python/requirements.txt
|
||||
|
||||
# Install pre-commit hooks
|
||||
pre-commit-install: dev-install
|
||||
pre-commit install
|
||||
|
||||
# Run pre-commit on all files
|
||||
pre-commit-run:
|
||||
pre-commit run --all-files
|
||||
|
||||
# Run all linting and type checking
|
||||
lint:
|
||||
./scripts/lint.sh
|
||||
|
||||
# Auto-format code
|
||||
format:
|
||||
./scripts/format.sh
|
||||
|
||||
# Run only type checking
|
||||
typecheck:
|
||||
pyright
|
||||
|
||||
# Run only ruff linting
|
||||
ruff-check:
|
||||
ruff check .
|
||||
|
||||
# Run only ruff formatting check
|
||||
ruff-format-check:
|
||||
ruff format --check .
|
||||
|
||||
# Run tests
|
||||
test:
|
||||
./scripts/run_all_tests.sh
|
||||
|
||||
# Run SDK tests
|
||||
test-sdk:
|
||||
cd sdk/python && pytest tests -v
|
||||
|
||||
# Run backend tests
|
||||
test-backend:
|
||||
cd backend && pytest tests -v || echo "No backend tests yet"
|
||||
|
||||
# Run server tests
|
||||
test-servers:
|
||||
cd servers && pytest tests -v || echo "No server tests yet"
|
||||
|
||||
# Run all tests with coverage
|
||||
test-coverage:
|
||||
cd sdk/python && pytest tests --cov=omnara --cov-report=term-missing
|
||||
|
||||
# Run integration tests with PostgreSQL (requires Docker)
|
||||
test-integration:
|
||||
pytest servers/tests/test_integration.py -v -m integration
|
||||
246
README.md
Normal file
246
README.md
Normal file
@@ -0,0 +1,246 @@
|
||||
# Omnara - Talk to Your AI Agents from Anywhere! 🚀
|
||||
|
||||
Ever wished you could peek into what your AI coding assistants are doing? Or help them when they get stuck? That's exactly what Omnara does!
|
||||
|
||||
## What is Omnara?
|
||||
|
||||
Omnara is an open-source platform that lets you communicate with all your AI agents - Claude Code, Cursor, GitHub Copilot, and more - through one simple dashboard. No more wondering what your AI is up to or missing its questions!
|
||||
|
||||
### The Magic ✨
|
||||
|
||||
- **See Everything**: Watch your AI agents work in real-time, like having a window into their minds
|
||||
- **Jump In Anytime**: When your agent asks "Should I refactor this?" or "Which approach do you prefer?", you'll see it instantly and can respond
|
||||
- **Guide Your AI**: Send feedback and corrections while your agent is working - it'll see your messages and adjust course
|
||||
- **Works Everywhere**: Whether you're on your phone, tablet, or another computer, you can check in on your agents
|
||||
- **One Dashboard, All Agents**: Stop juggling between different tools - see all your AI assistants in one place
|
||||
|
||||
### Built on MCP (Model Context Protocol)
|
||||
|
||||
We use the Model Context Protocol to make this all work seamlessly. Your agents can talk to Omnara, and Omnara talks to you.
|
||||
|
||||
## Project Architecture
|
||||
|
||||
### Core Components
|
||||
|
||||
```
|
||||
omnara/
|
||||
├── backend/ # Web dashboard API (FastAPI)
|
||||
├── servers/ # Agent write operations server (MCP + REST)
|
||||
├── shared/ # Database models and shared infrastructure
|
||||
├── omnara/ # Python package directory
|
||||
│ └── sdk/ # Python SDK for agent integration
|
||||
├── cli/ # Node.js CLI tool for MCP configuration
|
||||
├── scripts/ # Development and utility scripts
|
||||
└── webhooks/ # Webhook handlers
|
||||
```
|
||||
|
||||
### System Architecture
|
||||
|
||||
1. **Backend API** (`backend/`)
|
||||
- FastAPI application serving the web dashboard
|
||||
- Handles read operations and user authentication via Supabase
|
||||
- Manages API key generation and user sessions
|
||||
|
||||
2. **Servers** (`servers/`)
|
||||
- Unified server supporting both MCP and REST protocols
|
||||
- Processes all write operations from AI agents
|
||||
- Implements JWT authentication with optimized token length
|
||||
|
||||
3. **Shared Infrastructure** (`shared/`)
|
||||
- Database models and migration management
|
||||
- Common utilities and configuration
|
||||
- Ensures consistency across all services
|
||||
|
||||
|
||||
### Data Flow
|
||||
|
||||
```
|
||||
AI Agents → MCP/REST Server (Write) → PostgreSQL ← Backend API (Read) ← Web Dashboard
|
||||
```
|
||||
|
||||
## Development Setup
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- Python 3.12
|
||||
- PostgreSQL
|
||||
- Make (for development commands)
|
||||
|
||||
### Quick Start
|
||||
|
||||
1. **Clone the repository**
|
||||
```bash
|
||||
git clone <repository-url>
|
||||
cd omnara
|
||||
```
|
||||
|
||||
2. **Set up Python environment**
|
||||
```bash
|
||||
python -m venv .venv
|
||||
source .venv/bin/activate # On Windows: .venv\Scripts\activate
|
||||
```
|
||||
|
||||
3. **Install dependencies and development tools**
|
||||
```bash
|
||||
make dev-install # Install all Python dependencies
|
||||
make pre-commit-install # Set up code quality hooks
|
||||
```
|
||||
|
||||
4. **Generate JWT keys for agent authentication**
|
||||
```bash
|
||||
python scripts/generate_jwt_keys.py
|
||||
```
|
||||
|
||||
5. **Configure environment variables**
|
||||
Create `.env` file in the root directory (see Environment Variables section)
|
||||
|
||||
6. **Initialize database**
|
||||
```bash
|
||||
cd shared/
|
||||
alembic upgrade head
|
||||
cd ..
|
||||
```
|
||||
|
||||
7. **Run the services**
|
||||
```bash
|
||||
# Terminal 1: Unified server (MCP + REST)
|
||||
python -m servers.app
|
||||
|
||||
# Terminal 2: Backend API
|
||||
cd backend && python -m main
|
||||
```
|
||||
|
||||
### Development Commands
|
||||
|
||||
```bash
|
||||
# Code quality
|
||||
make lint # Run all linting and type checking
|
||||
make format # Auto-format code
|
||||
make pre-commit-run # Run pre-commit on all files
|
||||
|
||||
# Testing
|
||||
make test # Run all tests
|
||||
make test-sdk # Test Python SDK
|
||||
make test-integration # Run integration tests (requires Docker)
|
||||
|
||||
# Database migrations
|
||||
cd shared/
|
||||
alembic revision --autogenerate -m "Description" # Create migration
|
||||
alembic upgrade head # Apply migrations
|
||||
```
|
||||
|
||||
### Code Quality
|
||||
|
||||
The project maintains high code quality standards through automated tooling:
|
||||
- **Ruff** for Python linting and formatting
|
||||
- **Pyright** for type checking
|
||||
- **Pre-commit hooks** for automatic validation
|
||||
- **Python 3.12** as the standard version
|
||||
|
||||
## Environment Variables
|
||||
|
||||
### Required Configuration
|
||||
|
||||
```bash
|
||||
# Database
|
||||
DATABASE_URL=postgresql://user:password@localhost:5432/omnara
|
||||
|
||||
# Supabase (for web authentication)
|
||||
SUPABASE_URL=https://your-project.supabase.co
|
||||
SUPABASE_ANON_KEY=your-anon-key
|
||||
|
||||
# JWT Keys (from generate_jwt_keys.py)
|
||||
JWT_PRIVATE_KEY='-----BEGIN RSA PRIVATE KEY-----\n...\n-----END RSA PRIVATE KEY-----'
|
||||
JWT_PUBLIC_KEY='-----BEGIN PUBLIC KEY-----\n...\n-----END PUBLIC KEY-----'
|
||||
|
||||
# Optional
|
||||
ENVIRONMENT=development
|
||||
API_PORT=8000
|
||||
MCP_SERVER_PORT=8080
|
||||
```
|
||||
|
||||
## For AI Agent Developers
|
||||
|
||||
### Getting Started
|
||||
|
||||
1. Sign up at the Omnara dashboard
|
||||
2. Generate an API key from the dashboard
|
||||
3. Configure your agent with the API key
|
||||
4. Use either MCP protocol or REST API to interact
|
||||
|
||||
### Available Tools/Endpoints
|
||||
|
||||
- **log_step**: Log progress and receive user feedback
|
||||
- **ask_question**: Request user input (non-blocking)
|
||||
- **end_session**: Mark agent session as completed
|
||||
|
||||
### Integration Options
|
||||
|
||||
1. **MCP Protocol** (for compatible agents)
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"omnara": {
|
||||
"url": "http://localhost:8080/mcp",
|
||||
"headers": {
|
||||
"Authorization": "Bearer YOUR_API_KEY"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
2. **REST API** (for direct integration)
|
||||
- POST `/api/v1/steps`
|
||||
- POST `/api/v1/questions`
|
||||
- POST `/api/v1/sessions/end`
|
||||
|
||||
3. **Python SDK** (available on PyPI)
|
||||
```bash
|
||||
pip install omnara
|
||||
```
|
||||
|
||||
## Database Management
|
||||
|
||||
### Working with Migrations
|
||||
|
||||
```bash
|
||||
cd shared/
|
||||
|
||||
# Check current migration
|
||||
alembic current
|
||||
|
||||
# Create new migration after model changes
|
||||
alembic revision --autogenerate -m "Add new feature"
|
||||
|
||||
# Apply pending migrations
|
||||
alembic upgrade head
|
||||
|
||||
# Rollback one migration
|
||||
alembic downgrade -1
|
||||
```
|
||||
|
||||
**Important**: Always create migrations when modifying database models. Pre-commit hooks enforce this requirement.
|
||||
|
||||
## Contributing
|
||||
|
||||
We welcome contributions to this open-source project. Here's how you can help:
|
||||
|
||||
1. Fork the repository
|
||||
2. Create a feature branch (`git checkout -b feature/amazing-feature`)
|
||||
3. Make your changes
|
||||
4. Run tests and ensure code quality checks pass
|
||||
5. Commit your changes (`git commit -m 'Add amazing feature'`)
|
||||
6. Push to your branch (`git push origin feature/amazing-feature`)
|
||||
7. Open a Pull Request
|
||||
|
||||
|
||||
## Support
|
||||
|
||||
- **Issues**: Report bugs or request features via GitHub Issues
|
||||
- **Discussions**: Join the conversation in GitHub Discussions
|
||||
- **Documentation**: Check the project documentation for detailed guides
|
||||
|
||||
## License
|
||||
|
||||
Open source and free to use! Check the LICENSE file for details.
|
||||
81
backend/README.md
Normal file
81
backend/README.md
Normal file
@@ -0,0 +1,81 @@
|
||||
# Backend API
|
||||
|
||||
This directory contains the FastAPI backend that serves the web dashboard for monitoring and managing AI agent instances.
|
||||
|
||||
## Overview
|
||||
|
||||
The backend provides a REST API for accessing and managing agent-related data. Its primary purpose is to handle read operations for agent instances, their execution history, and user interactions. The API serves as the bridge between client applications and the underlying agent data stored in the database.
|
||||
|
||||
## Architecture
|
||||
|
||||
### Core Components
|
||||
|
||||
- **API Routes** - RESTful endpoints organized by resource type
|
||||
- **Authentication** - Dual-layer authentication system for web users and agent clients
|
||||
- **Database Layer** - Query interfaces for accessing agent and user data
|
||||
- **Models** - Request/response schemas and data validation
|
||||
|
||||
### Directory Structure
|
||||
|
||||
- `api/` - API route handlers organized by domain
|
||||
- `auth/` - Authentication and authorization logic
|
||||
- `db/` - Database queries and data access layer
|
||||
- `models.py` - Pydantic models for API contracts
|
||||
- `main.py` - Application entry point and configuration
|
||||
- `tests/` - Test suite for API functionality
|
||||
|
||||
## Key Features
|
||||
|
||||
- **Agent Monitoring** - View agent types, instances, and execution history
|
||||
- **User Interactions** - Handle questions from agents and user feedback
|
||||
- **Multi-tenancy** - User-scoped data isolation and access control
|
||||
- **Authentication** - Support for both web dashboard users and programmatic agent access
|
||||
- **User Agent Management** - Custom agent configurations and webhook integrations
|
||||
|
||||
## Authentication
|
||||
|
||||
The backend implements a dual authentication system:
|
||||
|
||||
1. **Web Dashboard Authentication** - For users accessing the web interface
|
||||
2. **API Key Authentication** - For programmatic access by agent clients
|
||||
|
||||
All data access is scoped to the authenticated user, ensuring proper data isolation in a multi-tenant environment.
|
||||
|
||||
### Security Notice: API Key Storage
|
||||
|
||||
**⚠️ Important**: API keys are currently stored in plain text in the database for development convenience. While this is more permissible for write-only keys, developers should:
|
||||
- Be aware of this when handling API key data
|
||||
- Never expose API keys in logs or error messages
|
||||
- Treat API keys as sensitive data despite being write-only
|
||||
|
||||
**TODO**: Migrate to hashed storage (SHA256) in future releases. The plain text storage is temporary for easier development and debugging.
|
||||
|
||||
## Development
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- Python 3.12+
|
||||
- PostgreSQL database
|
||||
- Required environment variables configured
|
||||
|
||||
### Setup
|
||||
|
||||
1. Install dependencies from `requirements.txt`
|
||||
2. Configure environment variables
|
||||
3. Set up the database schema
|
||||
4. Run the development server
|
||||
|
||||
### Testing
|
||||
|
||||
The test suite covers API endpoints, authentication flows, and data access patterns. Tests use pytest and can be run from the backend directory.
|
||||
|
||||
## Configuration
|
||||
|
||||
The backend uses environment variables for configuration, including:
|
||||
|
||||
- Database connection settings
|
||||
- Authentication providers
|
||||
- CORS and security settings
|
||||
- External service integrations
|
||||
|
||||
Refer to the project documentation for specific configuration requirements.
|
||||
1
backend/__init__.py
Normal file
1
backend/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Backend API package
|
||||
1
backend/api/__init__.py
Normal file
1
backend/api/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# API routes package
|
||||
123
backend/api/agents.py
Normal file
123
backend/api/agents.py
Normal file
@@ -0,0 +1,123 @@
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from shared.database.models import User
|
||||
from shared.database.session import get_db
|
||||
from shared.database.enums import AgentStatus
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ..auth.dependencies import get_current_user
|
||||
from ..db import (
|
||||
get_agent_instance_detail,
|
||||
get_agent_type_instances,
|
||||
get_agent_summary,
|
||||
get_all_agent_instances,
|
||||
get_all_agent_types_with_instances,
|
||||
mark_instance_completed,
|
||||
submit_user_feedback,
|
||||
)
|
||||
from ..models import (
|
||||
AgentInstanceDetail,
|
||||
AgentInstanceResponse,
|
||||
AgentTypeOverview,
|
||||
UserFeedbackRequest,
|
||||
UserFeedbackResponse,
|
||||
)
|
||||
|
||||
router = APIRouter(tags=["agents"])
|
||||
|
||||
|
||||
@router.get("/agent-types", response_model=list[AgentTypeOverview])
async def list_agent_types(
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Return every agent type owned by the authenticated user, with their instances."""
    return get_all_agent_types_with_instances(db, current_user.id)
|
||||
|
||||
|
||||
@router.get("/agent-instances", response_model=list[AgentInstanceResponse])
async def list_all_agent_instances(
    limit: int | None = None,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """List the current user's agent instances, optionally capped at *limit*."""
    return get_all_agent_instances(db, current_user.id, limit=limit)
|
||||
|
||||
|
||||
@router.get("/agent-summary")
async def get_all_agent_summary(
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Lightweight agent-count summary used for dashboard KPI tiles."""
    return get_agent_summary(db, current_user.id)
|
||||
|
||||
|
||||
@router.get(
    "/agent-types/{type_id}/instances", response_model=list[AgentInstanceResponse]
)
async def get_type_instances(
    type_id: UUID,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """List instances of one agent type; 404 when the type is unknown to this user."""
    instances = get_agent_type_instances(db, type_id, current_user.id)
    if instances is None:
        raise HTTPException(status_code=404, detail="Agent type not found")
    return instances
|
||||
|
||||
|
||||
@router.get("/agent-instances/{instance_id}", response_model=AgentInstanceDetail)
async def get_instance_detail(
    instance_id: UUID,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Detailed view of one agent instance; 404 when it isn't the caller's."""
    detail = get_agent_instance_detail(db, instance_id, current_user.id)
    if not detail:
        raise HTTPException(status_code=404, detail="Agent instance not found")
    return detail
|
||||
|
||||
|
||||
@router.post(
    "/agent-instances/{instance_id}/feedback", response_model=UserFeedbackResponse
)
async def add_user_feedback(
    instance_id: UUID,
    request: UserFeedbackRequest,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Attach user feedback to an agent instance; 404 when it isn't the caller's."""
    feedback = submit_user_feedback(db, instance_id, request.feedback, current_user.id)
    if not feedback:
        raise HTTPException(status_code=404, detail="Agent instance not found")
    return feedback
|
||||
|
||||
|
||||
@router.put(
    "/agent-instances/{instance_id}/status",
    response_model=AgentInstanceResponse,
)
async def update_agent_status(
    instance_id: UUID,
    status_update: dict,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Update an agent instance's status for the current user.

    Only the COMPLETED transition is implemented; any other requested
    status yields a 400.
    """
    # Guard clause: reject everything except the one supported transition.
    if status_update.get("status") != AgentStatus.COMPLETED:
        raise HTTPException(status_code=400, detail="Status update not supported")

    updated = mark_instance_completed(db, instance_id, current_user.id)
    if not updated:
        raise HTTPException(status_code=404, detail="Agent instance not found")
    return updated
|
||||
108
backend/api/push_notifications.py
Normal file
108
backend/api/push_notifications.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""Push notification endpoints"""
|
||||
|
||||
from typing import List
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
from pydantic import BaseModel
|
||||
from uuid import UUID
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from backend.auth.dependencies import get_current_user_id
|
||||
from shared.database.session import get_db
|
||||
from shared.database import PushToken
|
||||
|
||||
router = APIRouter(prefix="/push", tags=["push_notifications"])
|
||||
|
||||
|
||||
class RegisterPushTokenRequest(BaseModel):
    """Request body for POST /push/register."""

    token: str  # opaque device push token (provider not specified here)
    platform: str  # 'ios' or 'android'


class PushTokenResponse(BaseModel):
    """Serialized view of a stored PushToken row."""

    id: UUID
    token: str
    platform: str
    is_active: bool
|
||||
|
||||
|
||||
@router.post("/register", response_model=dict)
def register_push_token(
    request: RegisterPushTokenRequest,
    user_id: UUID = Depends(get_current_user_id),
    db: Session = Depends(get_db),
):
    """Register (or re-assign) a push notification token for the current user.

    A token value is treated as unique across users: if it already exists it
    is re-pointed at the caller and reactivated; otherwise a new row is added.

    Raises:
        HTTPException: 400 when the database write fails.
    """
    # Narrow catch: sqlalchemy is already a dependency of this module.
    from sqlalchemy.exc import SQLAlchemyError

    try:
        # Upsert by token value — a device token may move between accounts.
        existing = db.query(PushToken).filter(PushToken.token == request.token).first()

        if existing:
            existing.user_id = user_id
            existing.platform = request.platform
            existing.is_active = True
            existing.updated_at = datetime.now(timezone.utc)
        else:
            db.add(
                PushToken(
                    user_id=user_id,
                    token=request.token,
                    platform=request.platform,
                    is_active=True,
                )
            )

        db.commit()
        return {"success": True, "message": "Push token registered successfully"}
    except SQLAlchemyError as e:
        # Bug fix: the original bare `except Exception` converted programming
        # errors into 400s and leaked their messages; only database failures
        # are now reported as client errors, everything else surfaces as 500.
        db.rollback()
        raise HTTPException(status_code=400, detail=str(e)) from e
|
||||
|
||||
|
||||
@router.delete("/deactivate/{token}")
def deactivate_token(
    token: str,
    user_id: UUID = Depends(get_current_user_id),
    db: Session = Depends(get_db),
):
    """Deactivate one of the caller's push notification tokens.

    Unknown tokens are treated as success (idempotent delete).

    Raises:
        HTTPException: 400 when the database write fails.
    """
    # Narrow catch: sqlalchemy is already a dependency of this module.
    from sqlalchemy.exc import SQLAlchemyError

    try:
        push_token = (
            db.query(PushToken)
            .filter(PushToken.user_id == user_id, PushToken.token == token)
            .first()
        )

        if push_token:
            push_token.is_active = False
            push_token.updated_at = datetime.now(timezone.utc)
            db.commit()

        return {"success": True, "message": "Push token deactivated"}
    except SQLAlchemyError as e:
        # Bug fix: the original bare `except Exception` masked server bugs as
        # 400s and leaked their messages to the client.
        db.rollback()
        raise HTTPException(status_code=400, detail=str(e)) from e
|
||||
|
||||
|
||||
@router.get("/tokens", response_model=List[PushTokenResponse])
def get_my_push_tokens(
    user_id: UUID = Depends(get_current_user_id),
    db: Session = Depends(get_db),
):
    """List the caller's currently active push notification tokens."""
    active_rows = db.query(PushToken).filter(
        PushToken.user_id == user_id, PushToken.is_active
    )

    return [
        PushTokenResponse(
            id=row.id,
            token=row.token,
            platform=row.platform,
            is_active=row.is_active,
        )
        for row in active_rows.all()
    ]
|
||||
28
backend/api/questions.py
Normal file
28
backend/api/questions.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from shared.database.models import User
|
||||
from shared.database.session import get_db
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ..auth.dependencies import get_current_user
|
||||
from ..db import submit_answer
|
||||
from ..models import AnswerRequest
|
||||
|
||||
router = APIRouter(prefix="/questions", tags=["questions"])
|
||||
|
||||
|
||||
@router.post("/{question_id}/answer")
async def answer_question(
    question_id: UUID,
    request: AnswerRequest,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Record the user's answer to a pending agent question.

    Returns 404 when the question does not exist for this user or was
    already answered.
    """
    answered = submit_answer(db, question_id, request.answer, current_user.id)
    if answered:
        return {"success": True, "message": "Answer submitted successfully"}
    raise HTTPException(
        status_code=404, detail="Question not found or already answered"
    )
|
||||
119
backend/api/user_agents.py
Normal file
119
backend/api/user_agents.py
Normal file
@@ -0,0 +1,119 @@
|
||||
"""
|
||||
User Agent API endpoints for managing user-specific agent configurations.
|
||||
"""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from shared.database.models import User, UserAgent, AgentInstance
|
||||
from shared.database.enums import AgentStatus
|
||||
from shared.database.session import get_db
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import and_
|
||||
|
||||
from ..auth.dependencies import get_current_user
|
||||
from ..models import (
|
||||
UserAgentRequest,
|
||||
UserAgentResponse,
|
||||
CreateAgentInstanceRequest,
|
||||
WebhookTriggerResponse,
|
||||
)
|
||||
from ..db import (
|
||||
create_user_agent,
|
||||
get_user_agents,
|
||||
update_user_agent,
|
||||
trigger_webhook_agent,
|
||||
get_user_agent_instances,
|
||||
)
|
||||
|
||||
router = APIRouter(tags=["user-agents"])
|
||||
|
||||
|
||||
@router.get("/user-agents", response_model=list[UserAgentResponse])
async def list_user_agents(
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Return every user-agent configuration owned by the caller."""
    return get_user_agents(db, current_user.id)
|
||||
|
||||
|
||||
@router.post("/user-agents", response_model=UserAgentResponse)
async def create_new_user_agent(
    request: UserAgentRequest,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Create a new user-agent configuration for the caller."""
    return create_user_agent(db, current_user.id, request)
|
||||
|
||||
|
||||
@router.patch("/user-agents/{agent_id}", response_model=UserAgentResponse)
async def update_existing_user_agent(
    agent_id: UUID,
    request: UserAgentRequest,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Update a user-agent configuration; 404 when it isn't the caller's."""
    updated = update_user_agent(db, agent_id, current_user.id, request)
    if not updated:
        raise HTTPException(status_code=404, detail="User agent not found")
    return updated
|
||||
|
||||
|
||||
@router.get("/user-agents/{agent_id}/instances")
async def get_user_agent_instances_list(
    agent_id: UUID,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """List instances of one user agent; 404 when the agent isn't the caller's."""
    result = get_user_agent_instances(db, agent_id, current_user.id)
    if result is None:
        raise HTTPException(status_code=404, detail="User agent not found")
    return result
|
||||
|
||||
|
||||
@router.post("/user-agents/{agent_id}/instances", response_model=WebhookTriggerResponse)
async def create_agent_instance(
    agent_id: UUID,
    request: CreateAgentInstanceRequest,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Create a new instance of a user agent.

    Webhook-backed agents delegate creation to the webhook trigger; other
    agents get a bare ACTIVE instance created directly. 404 when the agent
    does not belong to the caller.
    """
    user_agent = (
        db.query(UserAgent)
        .filter(and_(UserAgent.id == agent_id, UserAgent.user_id == current_user.id))
        .first()
    )
    if not user_agent:
        raise HTTPException(status_code=404, detail="User agent not found")

    # Agents with a configured webhook hand off instance creation entirely.
    if user_agent.webhook_url:
        return await trigger_webhook_agent(
            db, user_agent, current_user.id, request.prompt
        )

    # No webhook: persist a plain active instance ourselves.
    new_instance = AgentInstance(
        user_agent_id=agent_id, user_id=current_user.id, status=AgentStatus.ACTIVE
    )
    db.add(new_instance)
    db.commit()
    db.refresh(new_instance)

    return WebhookTriggerResponse(
        success=True,
        agent_instance_id=str(new_instance.id),
        message="Agent instance created successfully",
    )
|
||||
0
backend/auth/__init__.py
Normal file
0
backend/auth/__init__.py
Normal file
142
backend/auth/dependencies.py
Normal file
142
backend/auth/dependencies.py
Normal file
@@ -0,0 +1,142 @@
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from uuid import UUID
|
||||
|
||||
# Add parent directory to path to import shared module
|
||||
sys.path.append(str(Path(__file__).parent.parent.parent))
|
||||
|
||||
from fastapi import Depends, HTTPException, Request
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from shared.database.models import User
|
||||
from shared.database.session import get_db
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from .supabase_client import get_supabase_client
|
||||
|
||||
security = HTTPBearer(auto_error=False) # Don't auto-error so we can check cookies
|
||||
|
||||
|
||||
class AuthError(HTTPException):
    """HTTP 401 raised for any authentication failure in this module."""

    def __init__(self, detail: str):
        super().__init__(status_code=401, detail=detail)
|
||||
|
||||
|
||||
def get_token_from_request(
    request: Request,
    credentials: HTTPAuthorizationCredentials | None = None,
) -> str | None:
    """Pull the auth token from the Bearer header or the session cookie.

    Header credentials take precedence; the ``session_token`` cookie is the
    fallback. Returns None when neither carries a token.
    """
    if credentials and credentials.credentials:
        return credentials.credentials

    cookie_token = request.cookies.get("session_token")
    return cookie_token if cookie_token else None
|
||||
|
||||
|
||||
async def get_current_user_id(
    request: Request,
    credentials: HTTPAuthorizationCredentials | None = Depends(security),
) -> UUID:
    """Extract and verify the user ID from a Supabase JWT (header or cookie).

    Raises:
        AuthError: 401 when no token is supplied or verification fails.
    """
    token = get_token_from_request(request, credentials)

    if not token:
        raise AuthError("No authentication token provided")

    try:
        # Use anon client to verify user tokens (not service role)
        from .supabase_client import get_supabase_anon_client

        supabase = get_supabase_anon_client()

        # Verify the JWT token with Supabase
        user_response = supabase.auth.get_user(token)

        if not user_response or not user_response.user:
            raise AuthError("Invalid authentication token")

        return UUID(user_response.user.id)

    except AuthError:
        # Bug fix: the deliberate "Invalid authentication token" 401 was being
        # caught below and re-wrapped as "Could not validate credentials: ...";
        # surface it unchanged instead.
        raise
    except Exception as e:
        raise AuthError(f"Could not validate credentials: {str(e)}") from e
|
||||
|
||||
|
||||
async def get_current_user(
    user_id: UUID = Depends(get_current_user_id), db: Session = Depends(get_db)
) -> User:
    """Load the current user, lazily creating the local row when missing.

    Handles the case where a user signed up via Supabase but was never synced
    to our database: their profile is fetched with the service-role client and
    a local row is created on first sight.

    Raises:
        AuthError: 401 when the user cannot be found or created.
    """
    user = db.query(User).filter(User.id == user_id).first()

    if not user:
        service_supabase = get_supabase_client()

        try:
            # Get user info from Supabase using service role
            auth_user = service_supabase.auth.admin.get_user_by_id(str(user_id))

            if auth_user and auth_user.user:
                user = User(
                    id=user_id,
                    email=auth_user.user.email,
                    display_name=auth_user.user.user_metadata.get("display_name"),
                )
                db.add(user)
                db.commit()
                db.refresh(user)
            else:
                raise AuthError("User not found")
        except AuthError:
            # Bug fix: the deliberate "User not found" 401 was being caught
            # below and re-wrapped as "Could not create user: 401: ...";
            # propagate it unchanged.
            raise
        except Exception as e:
            raise AuthError(f"Could not create user: {str(e)}") from e

    return user
|
||||
|
||||
|
||||
async def get_optional_current_user(
    request: Request,
    credentials: HTTPAuthorizationCredentials | None = Depends(security),
    db: Session = Depends(get_db),
) -> User | None:
    """Return the current user if authenticated, otherwise None.

    Unlike get_current_user this never raises: a missing token, an invalid
    token, or any verification failure all yield None, so routes can serve
    both anonymous and authenticated callers.
    """
    token = get_token_from_request(request, credentials)

    if not token:
        return None

    try:
        # Verify token manually since get_current_user_id requires authentication
        from .supabase_client import get_supabase_anon_client

        supabase = get_supabase_anon_client()
        user_response = supabase.auth.get_user(token)

        if not user_response or not user_response.user:
            return None

        user_id = UUID(user_response.user.id)

        # Get user from database
        user = db.query(User).filter(User.id == user_id).first()

        if not user:
            # Lazily create the local row on first sight of a Supabase user
            # (mirrors the auto-provisioning done for required auth).
            user = User(
                id=user_id,
                email=user_response.user.email,
                display_name=user_response.user.user_metadata.get("display_name"),
            )
            db.add(user)
            db.commit()
            db.refresh(user)

        return user

    except Exception:
        # Deliberate best-effort: any failure means "not authenticated".
        return None
|
||||
116
backend/auth/jwt_utils.py
Normal file
116
backend/auth/jwt_utils.py
Normal file
@@ -0,0 +1,116 @@
|
||||
import hashlib
|
||||
import sys
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
# Add parent directory to path to import shared module
|
||||
sys.path.append(str(Path(__file__).parent.parent.parent))
|
||||
|
||||
from jose import JWTError, jwt
|
||||
from shared.config.settings import settings
|
||||
|
||||
|
||||
def create_api_key_jwt(
    user_id: str,
    expires_in_days: int | None = None,
    additional_claims: dict[str, Any] | None = None,
) -> str:
    """Mint an RS256-signed JWT used as an API key.

    Args:
        user_id: User's UUID as a string; becomes the ``sub`` claim.
        expires_in_days: Days until expiry; None means the token never expires.
        additional_claims: Extra claims merged into the payload.

    Returns:
        The encoded JWT string.

    Raises:
        ValueError: If JWT_PRIVATE_KEY is not configured.
    """
    if not settings.jwt_private_key:
        raise ValueError("JWT_PRIVATE_KEY not configured")

    issued_at = datetime.now(timezone.utc)

    claims: dict[str, Any] = {
        "sub": user_id,
        "iat": int(issued_at.timestamp()),
    }

    # Omit "exp" entirely for never-expiring keys.
    if expires_in_days is not None:
        expiry = issued_at + timedelta(days=expires_in_days)
        claims["exp"] = int(expiry.timestamp())

    if additional_claims:
        claims.update(additional_claims)

    return jwt.encode(claims, settings.jwt_private_key, algorithm="RS256")
|
||||
|
||||
|
||||
def verify_api_key_jwt(token: str) -> dict[str, Any]:
    """
    Verify and decode a JWT API key token.

    Args:
        token: JWT token string

    Returns:
        Decoded token payload

    Raises:
        JWTError: If token is invalid, expired, or malformed
        ValueError: If JWT_PUBLIC_KEY is not configured
    """
    if not settings.jwt_public_key:
        raise ValueError("JWT_PUBLIC_KEY not configured")

    try:
        payload = jwt.decode(
            token,
            settings.jwt_public_key,
            algorithms=["RS256"],
        )
        return payload
    except JWTError as e:
        # Chain the original error so the underlying cause (expired token,
        # bad signature, malformed token) stays visible in tracebacks.
        raise JWTError(f"Invalid token: {str(e)}") from e
|
||||
|
||||
|
||||
def get_token_hash(token: str) -> str:
    """Return the hex SHA-256 digest of *token*.

    Hashes are stored in place of raw tokens so that a database leak does
    not directly expose usable credentials.

    Args:
        token: JWT token string.

    Returns:
        64-character lowercase hex SHA-256 digest.
    """
    digest = hashlib.sha256(token.encode())
    return digest.hexdigest()
|
||||
|
||||
|
||||
def extract_user_id_from_token(token: str) -> str:
    """
    Extract user ID from token without full verification.
    Useful for database lookups before verification — never trust the result
    for authorization on its own.

    Args:
        token: JWT token string

    Returns:
        User ID from token subject claim

    Raises:
        JWTError: If token is malformed or missing a subject claim
    """
    try:
        # Decode without verification (just to extract claims)
        unverified_payload = jwt.get_unverified_claims(token)
    except Exception as e:
        raise JWTError(f"Cannot extract user ID: {str(e)}") from e

    user_id = unverified_payload.get("sub")
    if user_id is None:
        # Bug fix: raised outside the try block so this deliberate error is
        # not re-caught and double-wrapped as "Cannot extract user ID: ...".
        raise JWTError("Token missing subject claim")
    return str(user_id)
|
||||
241
backend/auth/routes.py
Normal file
241
backend/auth/routes.py
Normal file
@@ -0,0 +1,241 @@
|
||||
import logging
|
||||
import sys
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
|
||||
# Add parent directory to path to import shared module
|
||||
sys.path.append(str(Path(__file__).parent.parent.parent))
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from shared.database.models import APIKey, User
|
||||
from shared.database.session import get_db
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from .dependencies import get_current_user, get_optional_current_user
|
||||
from .jwt_utils import create_api_key_jwt, get_token_hash
|
||||
from .utils import update_user_profile
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UserProfile(BaseModel):
    """Public view of a user account returned by the /auth endpoints."""

    id: str
    email: str
    display_name: str | None
    created_at: str  # ISO-8601 timestamp


class UpdateProfileRequest(BaseModel):
    """Body for PATCH /auth/me."""

    display_name: str | None


class CreateAPIKeyRequest(BaseModel):
    """Body for POST /auth/api-keys."""

    name: str
    expires_in_days: int | None = None  # None => key never expires


class APIKeyResponse(BaseModel):
    """Returned once, at creation time, carrying the raw key material."""

    id: str
    name: str
    api_key: str  # Only returned on creation
    created_at: str
    expires_at: str | None


class APIKeyListItem(BaseModel):
    """Row shape for GET /auth/api-keys."""

    id: str
    name: str
    api_key: str
    created_at: str
    expires_at: str | None
    last_used_at: str | None
    is_active: bool


class SyncUserRequest(BaseModel):
    """Body for POST /auth/sync-user (mirrors the Supabase user record)."""

    id: str
    email: str
    display_name: str | None
|
||||
|
||||
|
||||
@router.post("/sync-user")
async def sync_user(
    request: SyncUserRequest,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Sync a Supabase user into our database.

    The caller may only sync their own record; currently this just refreshes
    the display name when it differs.
    """
    if str(current_user.id) != request.id:
        raise HTTPException(status_code=403, detail="Cannot sync different user")

    if request.display_name != current_user.display_name:
        current_user.display_name = request.display_name
        db.commit()

    return {"message": "User synced successfully"}
|
||||
|
||||
|
||||
@router.get("/session")
async def get_session(user: User | None = Depends(get_optional_current_user)):
    """Return the current session's user profile, or 401 when unauthenticated."""
    if not user:
        raise HTTPException(status_code=401, detail="Not authenticated")

    return UserProfile(
        id=str(user.id),
        email=user.email,
        display_name=user.display_name,
        created_at=user.created_at.isoformat(),
    )
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserProfile)
async def get_current_user_profile(current_user: User = Depends(get_current_user)):
    """Return the authenticated user's profile."""
    u = current_user
    return UserProfile(
        id=str(u.id),
        email=u.email,
        display_name=u.display_name,
        created_at=u.created_at.isoformat(),
    )
|
||||
|
||||
|
||||
@router.patch("/me", response_model=UserProfile)
async def update_current_user_profile(
    request: UpdateProfileRequest,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Apply profile changes and return the refreshed profile."""
    updated = update_user_profile(current_user.id, request.display_name, db)

    return UserProfile(
        id=str(updated.id),
        email=updated.email,
        display_name=updated.display_name,
        created_at=updated.created_at.isoformat(),
    )
|
||||
|
||||
|
||||
@router.post("/api-keys", response_model=APIKeyResponse)
async def create_api_key(
    request: CreateAPIKeyRequest,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Create a new API key for MCP authentication.

    Issues a signed JWT, stores its hash (and, for now, the raw token) in the
    database, and returns the raw key — the only time it is exposed.

    Raises 400 on an over-long expiration or too many active keys, and 500
    when token generation fails.
    """

    # Validate expiration (effectively unbounded; kept as a sanity cap)
    max_expiration_days = 999999
    if (
        request.expires_in_days is not None
        and request.expires_in_days > max_expiration_days
    ):
        raise HTTPException(
            status_code=400,
            detail=f"API key expiration cannot exceed {max_expiration_days} days",
        )

    # Reject creation once the user already has 50 active API keys
    active_keys_count = (
        db.query(APIKey)
        .filter(APIKey.user_id == current_user.id, APIKey.is_active)
        .count()
    )

    if active_keys_count >= 50:
        raise HTTPException(
            status_code=400, detail="Maximum of 50 active API keys allowed"
        )

    # Generate the JWT token
    try:
        jwt_token = create_api_key_jwt(
            user_id=str(current_user.id),
            expires_in_days=request.expires_in_days,
        )
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Failed to generate API key: {str(e)}"
        )

    # Store API key metadata in database
    expires_at = None
    if request.expires_in_days is not None:
        expires_at = datetime.now(timezone.utc) + timedelta(
            days=request.expires_in_days
        )

    api_key = APIKey(
        user_id=current_user.id,
        name=request.name,
        # NOTE(review): the raw token is persisted alongside its hash — this is
        # the plain-text key storage flagged in the backend README.
        api_key_hash=get_token_hash(jwt_token),
        api_key=jwt_token,
        expires_at=expires_at,
    )

    db.add(api_key)
    db.commit()
    db.refresh(api_key)

    return APIKeyResponse(
        id=str(api_key.id),
        name=api_key.name,
        api_key=jwt_token,  # Only returned here!
        created_at=api_key.created_at.isoformat(),
        expires_at=api_key.expires_at.isoformat() if api_key.expires_at else None,
    )
|
||||
|
||||
|
||||
@router.get("/api-keys", response_model=list[APIKeyListItem])
async def list_api_keys(
    current_user: User = Depends(get_current_user), db: Session = Depends(get_db)
):
    """List every API key (active and revoked) owned by the caller, newest first."""
    rows = (
        db.query(APIKey)
        .filter(APIKey.user_id == current_user.id)
        .order_by(APIKey.created_at.desc())
        .all()
    )

    def _iso(ts):
        # None-safe ISO-8601 formatting for optional timestamps.
        return ts.isoformat() if ts else None

    return [
        APIKeyListItem(
            id=str(row.id),
            name=row.name,
            api_key=row.api_key,
            created_at=row.created_at.isoformat(),
            expires_at=_iso(row.expires_at),
            last_used_at=_iso(row.last_used_at),
            is_active=row.is_active,
        )
        for row in rows
    ]
|
||||
|
||||
|
||||
@router.delete("/api-keys/{key_id}")
async def revoke_api_key(
    key_id: str,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Soft-delete an API key by clearing its is_active flag.

    The row is kept in the database; only a key owned by the caller can be
    revoked, anything else surfaces as a 404.
    """
    owned_key = (
        db.query(APIKey)
        .filter(APIKey.id == key_id, APIKey.user_id == current_user.id)
        .first()
    )
    if owned_key is None:
        # Missing key and wrong owner look identical to the caller.
        raise HTTPException(status_code=404, detail="API key not found")

    owned_key.is_active = False
    db.commit()

    return {"message": "API key revoked successfully"}
|
||||
18
backend/auth/supabase_client.py
Normal file
18
backend/auth/supabase_client.py
Normal file
@@ -0,0 +1,18 @@
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add parent directory to path to import shared module
|
||||
sys.path.append(str(Path(__file__).parent.parent.parent))
|
||||
|
||||
from shared.config.settings import settings
|
||||
from supabase import Client, create_client
|
||||
|
||||
|
||||
def get_supabase_client() -> Client:
    """Build a Supabase client authenticated with the service-role key.

    Intended for privileged backend operations (e.g. admin user lookups).
    """
    url = settings.supabase_url
    service_key = settings.supabase_service_role_key
    return create_client(url, service_key)
|
||||
|
||||
|
||||
def get_supabase_anon_client() -> Client:
    """Build a Supabase client authenticated with the public anon key.

    Used for testing flows that should not carry service-role privileges.
    """
    url = settings.supabase_url
    anon_key = settings.supabase_anon_key
    return create_client(url, anon_key)
|
||||
68
backend/auth/utils.py
Normal file
68
backend/auth/utils.py
Normal file
@@ -0,0 +1,68 @@
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from uuid import UUID
|
||||
|
||||
# Add parent directory to path to import shared module
|
||||
sys.path.append(str(Path(__file__).parent.parent.parent))
|
||||
|
||||
from shared.database.models import User
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from .supabase_client import get_supabase_client
|
||||
|
||||
|
||||
def sync_user_from_supabase(user_id: UUID, db: Session) -> User:
    """Sync user data from Supabase to our database.

    Fetches the auth record for ``user_id`` from Supabase and either updates
    the matching local User row or creates a new one, then commits.

    Raises:
        ValueError: If Supabase has no auth record for ``user_id``.
    """
    supabase = get_supabase_client()

    # Get user from Supabase (admin API lookup by id)
    auth_user = supabase.auth.admin.get_user_by_id(str(user_id))

    if not auth_user or not auth_user.user:
        raise ValueError(f"User {user_id} not found in Supabase")

    # Check if user exists in our DB
    user = db.query(User).filter(User.id == user_id).first()

    if user:
        # Update existing user. Email is only overwritten when Supabase has
        # a non-empty value; display_name falls back to the current local
        # value when the metadata key is absent.
        if auth_user.user.email:
            user.email = auth_user.user.email
        user.display_name = auth_user.user.user_metadata.get(
            "display_name", user.display_name
        )
    else:
        # Create new user from the Supabase record
        user = User(
            id=user_id,
            email=auth_user.user.email,
            display_name=auth_user.user.user_metadata.get("display_name"),
        )
        db.add(user)

    db.commit()
    db.refresh(user)

    return user
|
||||
|
||||
|
||||
def update_user_profile(user_id: UUID, display_name: str | None, db: Session) -> User:
    """Update user profile information locally and mirror it into Supabase.

    Args:
        user_id: ID of the user to update.
        display_name: New display name; ``None`` means "leave unchanged".
        db: Active database session.

    Returns:
        The refreshed User row.

    Raises:
        ValueError: If no user with ``user_id`` exists locally.
    """
    user = db.query(User).filter(User.id == user_id).first()

    if not user:
        raise ValueError(f"User {user_id} not found")

    if display_name is not None:
        user.display_name = display_name

    db.commit()
    db.refresh(user)

    # Mirror the change into Supabase user metadata. Only do this when a new
    # value was actually supplied: previously a None display_name skipped the
    # local update but still overwrote the remote metadata with null.
    if display_name is not None:
        supabase = get_supabase_client()
        supabase.auth.admin.update_user_by_id(
            str(user_id), {"user_metadata": {"display_name": display_name}}
        )

    return user
|
||||
33
backend/db/__init__.py
Normal file
33
backend/db/__init__.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from .queries import (
|
||||
get_agent_instance_detail,
|
||||
get_agent_type_instances,
|
||||
get_all_agent_instances,
|
||||
get_all_agent_types_with_instances,
|
||||
get_agent_summary,
|
||||
mark_instance_completed,
|
||||
submit_answer,
|
||||
submit_user_feedback,
|
||||
)
|
||||
from .user_agent_queries import (
|
||||
create_user_agent,
|
||||
get_user_agents,
|
||||
update_user_agent,
|
||||
trigger_webhook_agent,
|
||||
get_user_agent_instances,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"get_all_agent_types_with_instances",
|
||||
"get_all_agent_instances",
|
||||
"get_agent_type_instances",
|
||||
"get_agent_summary",
|
||||
"get_agent_instance_detail",
|
||||
"mark_instance_completed",
|
||||
"submit_answer",
|
||||
"submit_user_feedback",
|
||||
"create_user_agent",
|
||||
"get_user_agents",
|
||||
"update_user_agent",
|
||||
"trigger_webhook_agent",
|
||||
"get_user_agent_instances",
|
||||
]
|
||||
441
backend/db/queries.py
Normal file
441
backend/db/queries.py
Normal file
@@ -0,0 +1,441 @@
|
||||
from datetime import datetime, timezone
|
||||
from uuid import UUID
|
||||
|
||||
from shared.database import (
|
||||
AgentInstance,
|
||||
AgentQuestion,
|
||||
AgentStatus,
|
||||
AgentUserFeedback,
|
||||
UserAgent,
|
||||
)
|
||||
from sqlalchemy import desc, func
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
|
||||
def _format_instance(instance: AgentInstance) -> dict:
    """Serialize an AgentInstance into the summary dict shared by list views.

    Includes the latest step description, pending-question bookkeeping
    (count, presence flag, age of the oldest pending question in seconds),
    and basic identity/status fields.
    """
    steps = instance.steps or []
    newest_step_text = (
        max(steps, key=lambda step: step.created_at).description if steps else None
    )

    open_questions = [q for q in instance.questions if q.is_active]
    oldest_wait_seconds = None
    if open_questions:
        earliest = min(open_questions, key=lambda q: q.asked_at)
        asked = earliest.asked_at
        # DB timestamps are UTC but may come back naive; normalize before
        # subtracting so the age is always computed in UTC.
        if asked.tzinfo is None:
            asked = asked.replace(tzinfo=timezone.utc)
        oldest_wait_seconds = int(
            (datetime.now(timezone.utc) - asked).total_seconds()
        )

    agent = instance.user_agent
    return {
        "id": str(instance.id),
        "agent_type_id": str(instance.user_agent_id) if instance.user_agent_id else "",
        "agent_type_name": agent.name if agent else "Unknown",
        "status": instance.status,
        "started_at": instance.started_at,
        "ended_at": instance.ended_at,
        "latest_step": newest_step_text,
        "has_pending_question": bool(open_questions),
        "pending_question_age": oldest_wait_seconds,
        "pending_questions_count": len(open_questions),
        "step_count": len(steps),
    }
|
||||
|
||||
|
||||
def get_all_agent_types_with_instances(db: Session, user_id: UUID) -> list[dict]:
    """Get all user agents with their instances for a specific user.

    Returns one dict per UserAgent owned by ``user_id``. Each carries its
    instances (formatted via ``_format_instance``) sorted so that instances
    with pending questions come first (oldest question first), followed by
    the rest ordered by most recent activity.
    """

    # Get all user agents for this user
    user_agents = db.query(UserAgent).filter(UserAgent.user_id == user_id).all()

    result = []
    for user_agent in user_agents:
        # Get all instances for this user agent; eager-load the relationships
        # that the sort key and _format_instance touch to avoid N+1 queries.
        instances = (
            db.query(AgentInstance)
            .filter(
                AgentInstance.user_agent_id == user_agent.id,
            )
            .options(
                joinedload(AgentInstance.steps),
                joinedload(AgentInstance.questions),
                joinedload(AgentInstance.user_agent),
            )
            .all()
        )

        # Sort instances: pending questions first, then by most recent activity
        def sort_key(instance):
            # Tuple ordering: group 0 (has a pending question; oldest question
            # first) always sorts before group 1 (no pending question; most
            # recent activity first via the negated timestamp). The second
            # tuple element is only ever compared within the same group.
            pending_questions = [q for q in instance.questions if q.is_active]
            if pending_questions:
                oldest_question = min(pending_questions, key=lambda q: q.asked_at)
                return (0, oldest_question.asked_at)

            # Fall back to started_at when the instance has no steps yet.
            last_activity = instance.started_at
            if instance.steps:
                last_activity = max(
                    instance.steps, key=lambda s: s.created_at
                ).created_at
            return (1, -last_activity.timestamp())

        sorted_instances = sorted(instances, key=sort_key)

        # Format instances with helper function
        formatted_instances = [
            _format_instance(instance) for instance in sorted_instances
        ]

        result.append(
            {
                "id": str(user_agent.id),
                "name": user_agent.name,
                "created_at": user_agent.created_at,
                "recent_instances": formatted_instances,
                "total_instances": len(instances),
                "active_instances": sum(
                    1 for i in instances if i.status == AgentStatus.ACTIVE
                ),
            }
        )

    return result
|
||||
|
||||
|
||||
def get_all_agent_instances(
    db: Session, user_id: UUID, limit: int | None = None
) -> list[dict]:
    """Return the user's agent instances, newest start time first.

    Relationships needed by ``_format_instance`` are eager-loaded. When
    ``limit`` is given, at most that many rows are fetched.
    """
    base_query = (
        db.query(AgentInstance)
        .filter(AgentInstance.user_id == user_id)
        .options(
            joinedload(AgentInstance.steps),
            joinedload(AgentInstance.questions),
            joinedload(AgentInstance.user_agent),
        )
        .order_by(desc(AgentInstance.started_at))
    )

    rows = (base_query if limit is None else base_query.limit(limit)).all()

    return [_format_instance(row) for row in rows]
|
||||
|
||||
|
||||
def get_agent_summary(db: Session, user_id: UUID) -> dict:
    """Get lightweight summary of agent counts without fetching detailed instance data.

    Runs COUNT queries only (no row hydration): overall totals for the user
    plus a per-agent-type breakdown built from a single GROUP BY query.
    """

    # Count total instances
    total_instances = (
        db.query(AgentInstance).filter(AgentInstance.user_id == user_id).count()
    )

    # Count active instances (only 'active' for now until DB enum is updated)
    active_instances = (
        db.query(AgentInstance)
        .filter(
            AgentInstance.user_id == user_id, AgentInstance.status == AgentStatus.ACTIVE
        )
        .count()
    )

    # Count completed instances
    completed_instances = (
        db.query(AgentInstance)
        .filter(
            AgentInstance.user_id == user_id,
            AgentInstance.status == AgentStatus.COMPLETED,
        )
        .count()
    )

    # Count by user agent and status (for fleet overview)
    # Get instances with their user agents; one row per (agent, status) pair.
    agent_type_stats = (
        db.query(
            UserAgent.id,
            UserAgent.name,
            AgentInstance.status,
            func.count(AgentInstance.id).label("count"),
        )
        .join(AgentInstance, AgentInstance.user_agent_id == UserAgent.id)
        .filter(UserAgent.user_id == user_id)
        .group_by(UserAgent.id, UserAgent.name, AgentInstance.status)
        .all()
    )

    # Format agent type stats. NOTE: the summary is keyed by name, so two
    # agents sharing a name have their counts merged under the first id seen.
    agent_types_summary = {}
    for type_id, type_name, status, count in agent_type_stats:
        # Agent types are now stored in lowercase, so no normalization needed
        if type_name not in agent_types_summary:
            agent_types_summary[type_name] = {
                "id": str(type_id),
                "name": type_name,
                "total_instances": 0,
                "active_instances": 0,
            }

        agent_types_summary[type_name]["total_instances"] += count
        if status == AgentStatus.ACTIVE:
            agent_types_summary[type_name]["active_instances"] += count

    return {
        "total_instances": total_instances,
        "active_instances": active_instances,
        "completed_instances": completed_instances,
        "agent_types": list(agent_types_summary.values()),
    }
|
||||
|
||||
|
||||
def get_agent_type_instances(
    db: Session, agent_type_id: UUID, user_id: UUID
) -> list[dict] | None:
    """Return formatted instances for one user agent.

    Yields None when the agent does not exist or is not owned by
    ``user_id``; otherwise a list ordered by newest start time first.
    """
    owner_match = (
        db.query(UserAgent)
        .filter(UserAgent.id == agent_type_id, UserAgent.user_id == user_id)
        .first()
    )
    if owner_match is None:
        return None

    rows = (
        db.query(AgentInstance)
        .filter(AgentInstance.user_agent_id == agent_type_id)
        .options(
            joinedload(AgentInstance.steps),
            joinedload(AgentInstance.questions),
            joinedload(AgentInstance.user_agent),
        )
        .order_by(desc(AgentInstance.started_at))
        .all()
    )

    return [_format_instance(row) for row in rows]
|
||||
|
||||
|
||||
def get_agent_instance_detail(
    db: Session, instance_id: UUID, user_id: UUID
) -> dict | None:
    """Get detailed information about a specific agent instance for a specific user.

    Returns None when the instance does not exist or is not owned by
    ``user_id``. The result embeds chronologically sorted steps, questions,
    and user feedback.
    """

    instance = (
        db.query(AgentInstance)
        .filter(AgentInstance.id == instance_id, AgentInstance.user_id == user_id)
        .options(
            joinedload(AgentInstance.user_agent),
            joinedload(AgentInstance.steps),
            joinedload(AgentInstance.questions),
            joinedload(AgentInstance.user_feedback),
        )
        .first()
    )

    if not instance:
        return None

    # Sort steps by step number
    sorted_steps = sorted(instance.steps, key=lambda s: s.step_number)

    # Sort questions by asked_at
    sorted_questions = sorted(instance.questions, key=lambda q: q.asked_at)

    # Sort user feedback by created_at
    sorted_feedback = sorted(instance.user_feedback, key=lambda f: f.created_at)

    return {
        "id": str(instance.id),
        "agent_type_id": str(instance.user_agent_id) if instance.user_agent_id else "",
        # Shallow agent-type stub: instance lists and counts are intentionally
        # left empty/zero here; callers needing them use the list endpoints.
        "agent_type": {
            "id": str(instance.user_agent.id) if instance.user_agent else "",
            "name": instance.user_agent.name if instance.user_agent else "Unknown",
            "created_at": instance.user_agent.created_at
            if instance.user_agent
            else datetime.now(timezone.utc),
            "recent_instances": [],
            "total_instances": 0,
            "active_instances": 0,
        },
        "status": instance.status,
        "started_at": instance.started_at,
        "ended_at": instance.ended_at,
        "steps": [
            {
                "id": str(step.id),
                "step_number": step.step_number,
                "description": step.description,
                "created_at": step.created_at,
            }
            for step in sorted_steps
        ],
        "questions": [
            {
                "id": str(question.id),
                "question_text": question.question_text,
                "answer_text": question.answer_text,
                "asked_at": question.asked_at,
                "answered_at": question.answered_at,
                "is_active": question.is_active,
            }
            for question in sorted_questions
        ],
        "user_feedback": [
            {
                "id": str(feedback.id),
                "feedback_text": feedback.feedback_text,
                "created_at": feedback.created_at,
                "retrieved_at": feedback.retrieved_at,
            }
            for feedback in sorted_feedback
        ],
    }
|
||||
|
||||
|
||||
def submit_answer(
    db: Session, question_id: UUID, answer: str, user_id: UUID
) -> dict | None:
    """Submit an answer to a question for a specific user.

    Records the answer, deactivates the question, and — when this was the
    instance's last pending question — flips the instance from
    AWAITING_INPUT back to ACTIVE. Returns None when the question does not
    exist, is already answered, or belongs to another user's instance.
    """

    # Ownership is enforced via the join: the question's instance must
    # belong to user_id, and only still-active questions are answerable.
    question = (
        db.query(AgentQuestion)
        .filter(AgentQuestion.id == question_id, AgentQuestion.is_active)
        .join(AgentInstance)
        .filter(AgentInstance.user_id == user_id)
        .first()
    )

    if not question:
        return None

    question.answer_text = answer
    question.answered_at = datetime.now(timezone.utc)
    question.is_active = False
    question.answered_by_user_id = user_id

    # Update agent instance status back to ACTIVE if it was AWAITING_INPUT
    instance = (
        db.query(AgentInstance)
        .filter(AgentInstance.id == question.agent_instance_id)
        .first()
    )
    if instance and instance.status == AgentStatus.AWAITING_INPUT:
        # Check if there are other active questions for this instance
        other_active_questions = (
            db.query(AgentQuestion)
            .filter(
                AgentQuestion.agent_instance_id == instance.id,
                AgentQuestion.id != question_id,
                AgentQuestion.is_active,
            )
            .count()
        )
        # Only change status back to ACTIVE if no other questions are pending
        if other_active_questions == 0:
            instance.status = AgentStatus.ACTIVE

    # Single commit covers both the answer and any status transition.
    db.commit()

    return {
        "id": str(question.id),
        "question_text": question.question_text,
        "answer_text": question.answer_text,
        "asked_at": question.asked_at,
        "answered_at": question.answered_at,
        "is_active": question.is_active,
    }
|
||||
|
||||
|
||||
def submit_user_feedback(
    db: Session, instance_id: UUID, feedback_text: str, user_id: UUID
) -> dict | None:
    """Attach a piece of user feedback to one of the user's agent instances.

    Returns a dict describing the stored feedback, or None when the
    instance does not exist or is not owned by ``user_id``.
    """
    owned_instance = (
        db.query(AgentInstance)
        .filter(AgentInstance.id == instance_id, AgentInstance.user_id == user_id)
        .first()
    )
    if owned_instance is None:
        return None

    # Persist the feedback row and reload it so DB-generated fields
    # (id, created_at) are populated.
    feedback_row = AgentUserFeedback(
        agent_instance_id=instance_id,
        feedback_text=feedback_text,
        created_by_user_id=user_id,
    )
    db.add(feedback_row)
    db.commit()
    db.refresh(feedback_row)

    return {
        "id": str(feedback_row.id),
        "feedback_text": feedback_row.feedback_text,
        "created_at": feedback_row.created_at,
        "retrieved_at": feedback_row.retrieved_at,
    }
|
||||
|
||||
|
||||
def mark_instance_completed(
    db: Session, instance_id: UUID, user_id: UUID
) -> dict | None:
    """Mark an agent instance as completed for a specific user.

    Sets the status to COMPLETED, stamps ended_at, and deactivates any
    still-pending questions. Returns the formatted instance, or None when
    it does not exist or is not owned by ``user_id``.
    """

    # Check if instance exists and belongs to user
    instance = (
        db.query(AgentInstance)
        .filter(AgentInstance.id == instance_id, AgentInstance.user_id == user_id)
        .first()
    )
    if not instance:
        return None

    # Update status to completed and set ended_at
    instance.status = AgentStatus.COMPLETED
    instance.ended_at = datetime.now(timezone.utc)

    # Deactivate any pending questions (bulk UPDATE, no ORM objects loaded)
    db.query(AgentQuestion).filter(
        AgentQuestion.agent_instance_id == instance_id, AgentQuestion.is_active
    ).update({"is_active": False})

    db.commit()

    # Re-query with relationships to ensure they're loaded for _format_instance
    # (the bulk update above bypasses the session's loaded collections).
    instance = (
        db.query(AgentInstance)
        .filter(AgentInstance.id == instance_id)
        .options(
            joinedload(AgentInstance.user_agent),
            joinedload(AgentInstance.steps),
            joinedload(AgentInstance.questions),
        )
        .first()
    )

    if not instance:
        return None

    return _format_instance(instance)
|
||||
271
backend/db/user_agent_queries.py
Normal file
271
backend/db/user_agent_queries.py
Normal file
@@ -0,0 +1,271 @@
|
||||
"""
|
||||
Database queries for UserAgent operations.
|
||||
"""
|
||||
|
||||
import httpx
|
||||
from datetime import datetime, timezone
|
||||
from uuid import UUID
|
||||
|
||||
from shared.database import UserAgent, AgentInstance, AgentStatus
|
||||
from sqlalchemy import and_, func
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from ..models import UserAgentRequest, WebhookTriggerResponse
|
||||
|
||||
|
||||
def create_user_agent(db: Session, user_id: UUID, request: UserAgentRequest) -> dict:
    """Persist a new agent configuration for ``user_id`` and return it formatted."""
    agent = UserAgent(
        user_id=user_id,
        name=request.name,
        webhook_url=request.webhook_url,
        webhook_api_key=request.webhook_api_key,
        is_active=request.is_active,
    )

    # Commit and refresh so DB-generated fields (id, created_at) are set
    # before formatting.
    db.add(agent)
    db.commit()
    db.refresh(agent)

    return _format_user_agent(agent, db)
|
||||
|
||||
|
||||
def get_user_agents(db: Session, user_id: UUID) -> list[dict]:
    """Return every agent configuration owned by ``user_id``, formatted
    with instance counts."""
    owned = db.query(UserAgent).filter(UserAgent.user_id == user_id).all()
    return [_format_user_agent(agent, db) for agent in owned]
|
||||
|
||||
|
||||
def update_user_agent(
    db: Session, agent_id: UUID, user_id: UUID, request: UserAgentRequest
) -> dict | None:
    """Overwrite an existing agent configuration from ``request``.

    Returns the formatted agent, or None when it does not exist or is not
    owned by ``user_id``.
    """
    agent = (
        db.query(UserAgent)
        .filter(and_(UserAgent.id == agent_id, UserAgent.user_id == user_id))
        .first()
    )
    if agent is None:
        return None

    # Replace every editable field and stamp the modification time.
    agent.name = request.name
    agent.webhook_url = request.webhook_url
    agent.webhook_api_key = request.webhook_api_key
    agent.is_active = request.is_active
    agent.updated_at = datetime.now(timezone.utc)

    db.commit()
    db.refresh(agent)

    return _format_user_agent(agent, db)
|
||||
|
||||
|
||||
async def trigger_webhook_agent(
    db: Session, user_agent: UserAgent, user_id: UUID, prompt: str
) -> WebhookTriggerResponse:
    """Trigger a webhook agent by calling the webhook URL.

    Creates an ACTIVE AgentInstance first so the webhook receiver gets its
    id, then POSTs the payload. On any failure the instance is marked
    FAILED and a non-success response is returned instead of raising.
    """

    if not user_agent.webhook_url:
        return WebhookTriggerResponse(
            success=False,
            message="Webhook URL not configured",
            error="No webhook URL found for this agent",
        )

    # Create the agent instance first so its id can be sent to the webhook.
    instance = AgentInstance(
        user_agent_id=user_agent.id, user_id=user_id, status=AgentStatus.ACTIVE
    )
    db.add(instance)
    db.commit()
    db.refresh(instance)

    # Prepare webhook payload, including the tool endpoints the remote
    # agent can call back into.
    payload = {
        "agent_instance_id": str(instance.id),
        "prompt": prompt,
        "omnara_api_key": user_agent.webhook_api_key,
        "omnara_tools": {
            "log_step": {
                "description": "Log a step in the agent's execution",
                "endpoint": "/api/v1/mcp/tools/log_step",
            },
            "ask_question": {
                "description": "Ask a question to the user",
                "endpoint": "/api/v1/mcp/tools/ask_question",
            },
            "end_session": {
                "description": "End the agent session",
                "endpoint": "/api/v1/mcp/tools/end_session",
            },
        },
    }

    def _fail(message: str, error: Exception) -> WebhookTriggerResponse:
        # Shared failure path: mark the just-created instance FAILED and
        # surface the error (previously duplicated in both except blocks).
        instance.status = AgentStatus.FAILED
        instance.ended_at = datetime.now(timezone.utc)
        db.commit()
        return WebhookTriggerResponse(
            success=False,
            agent_instance_id=str(instance.id),
            message=message,
            error=str(error),
        )

    # Call the webhook
    try:
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.post(
                user_agent.webhook_url,
                json=payload,
                headers={
                    "Authorization": f"Bearer {user_agent.webhook_api_key}"
                    if user_agent.webhook_api_key
                    else "",
                    "Content-Type": "application/json",
                },
            )
            response.raise_for_status()

        return WebhookTriggerResponse(
            success=True,
            agent_instance_id=str(instance.id),
            message="Webhook triggered successfully",
        )

    except httpx.HTTPError as e:
        # Transport/protocol errors (timeouts, connection failures, 4xx/5xx).
        return _fail("Failed to trigger webhook", e)
    except Exception as e:
        # Anything unexpected; keep broad catch so the endpoint never raises.
        return _fail("Unexpected error occurred", e)
|
||||
|
||||
|
||||
def get_user_agent_instances(db: Session, agent_id: UUID, user_id: UUID) -> list | None:
    """Get all instances for a specific user agent.

    Returns None when the agent does not exist or is not owned by
    ``user_id``; otherwise a list of summary dicts ordered by newest
    start time first.
    """

    # Verify the user agent exists and belongs to the user
    user_agent = (
        db.query(UserAgent)
        .filter(and_(UserAgent.id == agent_id, UserAgent.user_id == user_id))
        .first()
    )

    if not user_agent:
        return None

    # Get all instances for this user agent with relationships loaded
    instances = (
        db.query(AgentInstance)
        .options(
            joinedload(AgentInstance.questions),
            joinedload(AgentInstance.steps),
            joinedload(AgentInstance.user_feedback),
        )
        .filter(AgentInstance.user_agent_id == agent_id)
        .order_by(AgentInstance.started_at.desc())
        .all()
    )

    def _summarize(instance: AgentInstance) -> dict:
        # Relationship collections carry no guaranteed ordering, so select
        # the latest step by timestamp rather than trusting steps[-1]
        # (matches how _format_instance picks the latest step).
        last_signal_at = (
            max(instance.steps, key=lambda s: s.created_at).created_at
            if instance.steps
            else instance.started_at
        )
        return {
            "id": str(instance.id),
            "user_agent_id": str(instance.user_agent_id),
            "user_id": str(instance.user_id),
            "status": instance.status.value,
            "started_at": instance.started_at,
            "ended_at": instance.ended_at,
            "pending_questions_count": len(
                [q for q in instance.questions if q.is_active and not q.answer_text]
            ),
            "steps_count": len(instance.steps),
            "user_feedback_count": len(instance.user_feedback),
            "last_signal_at": last_signal_at,
        }

    return [_summarize(instance) for instance in instances]
|
||||
|
||||
|
||||
def _format_user_agent(user_agent: UserAgent, db: Session) -> dict:
    """Serialize a UserAgent with per-status instance counts.

    Emits identity/config fields plus COUNT aggregates for total, active,
    awaiting-input, completed, and failed/killed instances.
    """

    def _count_instances(*extra_filters) -> int:
        # One COUNT query per call, always scoped to this agent's instances;
        # passing multiple filter args yields the same conjunction as and_().
        value = (
            db.query(func.count(AgentInstance.id))
            .filter(AgentInstance.user_agent_id == user_agent.id, *extra_filters)
            .scalar()
        )
        return value or 0

    total = _count_instances()
    active = _count_instances(AgentInstance.status == AgentStatus.ACTIVE)
    waiting = _count_instances(AgentInstance.status == AgentStatus.AWAITING_INPUT)
    completed = _count_instances(AgentInstance.status == AgentStatus.COMPLETED)
    errored = _count_instances(
        AgentInstance.status.in_([AgentStatus.FAILED, AgentStatus.KILLED])
    )

    return {
        "id": str(user_agent.id),
        "name": user_agent.name,
        "webhook_url": user_agent.webhook_url,
        "is_active": user_agent.is_active,
        "created_at": user_agent.created_at,
        "updated_at": user_agent.updated_at,
        "instance_count": total,
        "active_instance_count": active,
        "waiting_instance_count": waiting,
        "completed_instance_count": completed,
        "error_instance_count": errored,
        "has_webhook": bool(user_agent.webhook_url),
    }
|
||||
99
backend/main.py
Normal file
99
backend/main.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""FastAPI backend for Agent Dashboard"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
import sentry_sdk
|
||||
from shared.config import settings
|
||||
from .api import agents, questions, user_agents, push_notifications
|
||||
from .auth import routes as auth_routes
|
||||
|
||||
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize Sentry only if DSN is provided
if settings.sentry_dsn:
    sentry_sdk.init(
        dsn=settings.sentry_dsn,
        send_default_pii=True,
        environment=settings.environment,
    )
    logger.info(f"Sentry initialized for {settings.environment} environment")
else:
    logger.info("Sentry DSN not provided, error tracking disabled")

# Create FastAPI app
app = FastAPI(
    title="Agent Dashboard API",
    description="Backend API for monitoring and interacting with AI agents",
    version="1.0.0",
)

# Configure CORS - cannot use wildcard (*) with credentials
# Define localhost origins for both development and production access
localhost_origins = [
    "http://localhost:3000",
    "http://127.0.0.1:3000",
    "http://localhost:5173",  # Vite default
    "http://127.0.0.1:5173",
    "http://localhost:8080",
    "http://127.0.0.1:8080",
    "http://localhost:8081",  # Custom frontend port
    "http://127.0.0.1:8081",
]

# Use the same environment source as the Sentry init above; previously this
# branch read os.getenv("ENVIRONMENT") directly, which could diverge from
# settings.environment. (Assumes settings.environment defaults to
# "development" like the old getenv fallback - confirm in shared.config.)
if settings.environment == "development":
    # In development, use localhost origins
    allowed_origins = localhost_origins
else:
    # Production origins from configuration
    allowed_origins = (
        settings.frontend_urls + localhost_origins
    )  # Include localhost URLs in production too

app.add_middleware(
    CORSMiddleware,
    allow_origins=allowed_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Include routers with versioned API prefix
app.include_router(auth_routes.router, prefix=settings.api_v1_prefix)
app.include_router(agents.router, prefix=settings.api_v1_prefix)
app.include_router(questions.router, prefix=settings.api_v1_prefix)
app.include_router(user_agents.router, prefix=settings.api_v1_prefix)
app.include_router(push_notifications.router, prefix=settings.api_v1_prefix)
|
||||
|
||||
|
||||
@app.get("/")
async def root():
    """Landing endpoint: identifies the API and points clients at the docs."""
    return {
        "message": "Agent Dashboard API",
        "version": "1.0.0",
        "docs": "/docs",
    }
|
||||
|
||||
|
||||
@app.get("/health")
async def health_check():
    """Liveness probe: reports healthy whenever the process can serve requests."""
    status_report = {"status": "healthy"}
    return status_report
|
||||
|
||||
|
||||
def main():
    """Entry point for module execution (e.g. console script) — no auto-reload."""
    import uvicorn

    uvicorn.run("backend.main:app", host="0.0.0.0", port=settings.api_port)


if __name__ == "__main__":
    # Running the file directly is a development workflow, so enable
    # auto-reload here; the packaged main() entry point stays reload-free.
    # This replaces two stacked __main__ guards: the first started uvicorn
    # with reload=True and, after a clean shutdown, the second restarted
    # the server again via main().
    import uvicorn

    uvicorn.run(
        "backend.main:app", host="0.0.0.0", port=settings.api_port, reload=True
    )
|
||||
198
backend/models.py
Normal file
198
backend/models.py
Normal file
@@ -0,0 +1,198 @@
|
||||
"""
|
||||
Backend API models for Agent Dashboard.
|
||||
|
||||
This module contains all Pydantic models used for API request/response serialization.
|
||||
Models are organized by functional area: questions, agents, and detailed views.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone

from pydantic import BaseModel, ConfigDict, Field, field_serializer

from shared.database.enums import AgentStatus
|
||||
|
||||
# ============================================================================
|
||||
# Question Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
# This is when the Agent prompts the user for an answer and this format
|
||||
# is what the user responds with.
|
||||
class AnswerRequest(BaseModel):
    """Request body for submitting a user's answer to a pending agent question."""

    answer: str = Field(..., description="User's answer to the question")
|
||||
|
||||
|
||||
# User feedback that agents can retrieve during their operations
|
||||
class UserFeedbackRequest(BaseModel):
    """Request body for leaving unsolicited feedback for a running agent."""

    feedback: str = Field(..., description="User's feedback or additional information")
|
||||
|
||||
|
||||
class UserFeedbackResponse(BaseModel):
    """API representation of a piece of user feedback left for an agent."""

    id: str
    feedback_text: str
    created_at: datetime
    retrieved_at: datetime | None  # None until the agent fetches the feedback

    @field_serializer("created_at", "retrieved_at")
    def serialize_datetime(self, dt: datetime | None, _info):
        """Serialize datetimes as ISO-8601 UTC strings with a 'Z' suffix.

        Fix: 'Z' was previously appended unconditionally, producing malformed
        values like '...+00:00Z' for timezone-aware datetimes.
        """
        if dt is None:
            return None
        if dt.tzinfo is not None:
            # Aware datetime: normalize to UTC, then swap '+00:00' for 'Z'.
            return dt.astimezone(timezone.utc).isoformat().replace("+00:00", "Z")
        # Naive datetime: assumed to already be UTC (original behavior).
        return dt.isoformat() + "Z"

    model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Agent Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
# Represents individual steps/actions taken by an agent
|
||||
class AgentStepResponse(BaseModel):
    """API representation of a single step/action taken by an agent."""

    id: str
    step_number: int
    description: str
    created_at: datetime

    @field_serializer("created_at")
    def serialize_datetime(self, dt: datetime, _info):
        """Serialize as ISO-8601 UTC with a 'Z' suffix.

        Fix: avoids the malformed '...+00:00Z' the unconditional '+ "Z"'
        produced for timezone-aware datetimes.
        """
        if dt.tzinfo is not None:
            return dt.astimezone(timezone.utc).isoformat().replace("+00:00", "Z")
        # Naive datetime: assumed UTC (original behavior).
        return dt.isoformat() + "Z"

    model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
# Summary view of an agent instance (a single agent session/run)
|
||||
class AgentInstanceResponse(BaseModel):
    """Summary view of an agent instance (a single agent session/run)."""

    id: str
    agent_type_id: str
    agent_type_name: str | None = None
    status: AgentStatus
    started_at: datetime
    ended_at: datetime | None
    latest_step: str | None = None
    has_pending_question: bool = False
    pending_question_age: int | None = None  # Age in seconds
    pending_questions_count: int = 0
    step_count: int = 0

    @field_serializer("started_at", "ended_at")
    def serialize_datetime(self, dt: datetime | None, _info):
        """Serialize as ISO-8601 UTC with a 'Z' suffix (None passes through).

        Fix: avoids the malformed '...+00:00Z' the unconditional '+ "Z"'
        produced for timezone-aware datetimes.
        """
        if dt is None:
            return None
        if dt.tzinfo is not None:
            return dt.astimezone(timezone.utc).isoformat().replace("+00:00", "Z")
        # Naive datetime: assumed UTC (original behavior).
        return dt.isoformat() + "Z"

    model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
# Overview of an agent type with recent instances
|
||||
# and summary statistics for dashboard cards
|
||||
class AgentTypeOverview(BaseModel):
    """Overview of an agent type with recent instances and summary stats
    for dashboard cards."""

    id: str
    name: str
    created_at: datetime
    recent_instances: list[AgentInstanceResponse] = []
    total_instances: int = 0
    active_instances: int = 0

    @field_serializer("created_at")
    def serialize_datetime(self, dt: datetime, _info):
        """Serialize as ISO-8601 UTC with a 'Z' suffix.

        Fix: avoids the malformed '...+00:00Z' the unconditional '+ "Z"'
        produced for timezone-aware datetimes.
        """
        if dt.tzinfo is not None:
            return dt.astimezone(timezone.utc).isoformat().replace("+00:00", "Z")
        # Naive datetime: assumed UTC (original behavior).
        return dt.isoformat() + "Z"

    model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Detailed Views
|
||||
# ============================================================================
|
||||
|
||||
|
||||
# Detailed information about a question asked by an agent, including answer status
|
||||
class QuestionDetail(BaseModel):
    """Detailed view of a question asked by an agent, including answer status."""

    id: str
    question_text: str
    answer_text: str | None  # None while the question is unanswered
    asked_at: datetime
    answered_at: datetime | None
    is_active: bool  # True while the question still awaits an answer

    @field_serializer("asked_at", "answered_at")
    def serialize_datetime(self, dt: datetime | None, _info):
        """Serialize as ISO-8601 UTC with a 'Z' suffix (None passes through).

        Fix: avoids the malformed '...+00:00Z' the unconditional '+ "Z"'
        produced for timezone-aware datetimes.
        """
        if dt is None:
            return None
        if dt.tzinfo is not None:
            return dt.astimezone(timezone.utc).isoformat().replace("+00:00", "Z")
        # Naive datetime: assumed UTC (original behavior).
        return dt.isoformat() + "Z"

    model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
# Complete detailed view of a specific agent instance
|
||||
# with full step and question history
|
||||
class AgentInstanceDetail(BaseModel):
    """Complete detailed view of one agent instance with full step,
    question, and feedback history."""

    id: str
    agent_type_id: str
    agent_type: AgentTypeOverview
    status: AgentStatus
    started_at: datetime
    ended_at: datetime | None
    steps: list[AgentStepResponse] = []
    questions: list[QuestionDetail] = []
    user_feedback: list[UserFeedbackResponse] = []

    @field_serializer("started_at", "ended_at")
    def serialize_datetime(self, dt: datetime | None, _info):
        """Serialize as ISO-8601 UTC with a 'Z' suffix (None passes through).

        Fix: avoids the malformed '...+00:00Z' the unconditional '+ "Z"'
        produced for timezone-aware datetimes.
        """
        if dt is None:
            return None
        if dt.tzinfo is not None:
            return dt.astimezone(timezone.utc).isoformat().replace("+00:00", "Z")
        # Naive datetime: assumed UTC (original behavior).
        return dt.isoformat() + "Z"

    model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# User Agent Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class UserAgentRequest(BaseModel):
    """Request body for creating or updating a user-configured agent."""

    name: str = Field(..., description="Name of the user agent")
    webhook_url: str | None = Field(
        None, description="Webhook URL for remote agent triggering"
    )
    webhook_api_key: str | None = Field(
        None, description="API key for webhook authentication"
    )
    is_active: bool = Field(True, description="Whether the agent is active")
|
||||
|
||||
|
||||
class UserAgentResponse(BaseModel):
    """API representation of a user-configured agent plus usage statistics."""

    id: str
    name: str
    webhook_url: str | None
    is_active: bool
    created_at: datetime
    updated_at: datetime
    # Aggregated instance counters (populated by the endpoint layer).
    instance_count: int = 0
    active_instance_count: int = 0
    waiting_instance_count: int = 0
    completed_instance_count: int = 0
    error_instance_count: int = 0
    has_webhook: bool = Field(default=False)

    @field_serializer("created_at", "updated_at")
    def serialize_datetime(self, dt: datetime, _info):
        """Serialize as ISO-8601 UTC with a 'Z' suffix.

        Fix: avoids the malformed '...+00:00Z' the unconditional '+ "Z"'
        produced for timezone-aware datetimes.
        """
        if dt.tzinfo is not None:
            return dt.astimezone(timezone.utc).isoformat().replace("+00:00", "Z")
        # Naive datetime: assumed UTC (original behavior).
        return dt.isoformat() + "Z"

    def __init__(self, **data):
        super().__init__(**data)
        # Derive has_webhook from webhook_url so callers never have to set it.
        # NOTE(review): a pydantic v2 @model_validator(mode="after") would be the
        # idiomatic way to compute this — kept as __init__ to preserve behavior.
        self.has_webhook = bool(self.webhook_url)

    model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class CreateAgentInstanceRequest(BaseModel):
    """Request body for starting a new agent instance."""

    prompt: str = Field(..., description="Initial prompt for the agent")
|
||||
|
||||
|
||||
class WebhookTriggerResponse(BaseModel):
    """Result of attempting to trigger a remote agent via its webhook."""

    success: bool
    agent_instance_id: str | None = None  # set only when the trigger succeeded
    message: str
    error: str | None = None  # populated on failure
|
||||
9
backend/requirements.txt
Normal file
9
backend/requirements.txt
Normal file
@@ -0,0 +1,9 @@
|
||||
fastapi==0.115.12
|
||||
uvicorn[standard]==0.34.3
|
||||
supabase==2.15.3
|
||||
httpx==0.28.1
|
||||
python-jose[cryptography]==3.5.0
|
||||
cryptography==42.0.5
|
||||
email-validator==2.1.0
|
||||
exponent-server-sdk>=2.1.0
|
||||
-r ../shared/requirements.txt
|
||||
152
backend/tests/conftest.py
Normal file
152
backend/tests/conftest.py
Normal file
@@ -0,0 +1,152 @@
|
||||
"""Pytest configuration and fixtures for backend tests."""
|
||||
|
||||
import os
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
from uuid import uuid4
|
||||
from unittest.mock import Mock
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from testcontainers.postgres import PostgresContainer
|
||||
|
||||
from shared.database.models import Base, User, UserAgent, AgentInstance
|
||||
from shared.database.enums import AgentStatus
|
||||
from backend.main import app
|
||||
from backend.auth.dependencies import get_current_user, get_optional_current_user
|
||||
from shared.database.session import get_db
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def postgres_container():
    """Create a PostgreSQL container for testing - shared across all tests."""
    # Session scope: container startup is expensive, so reuse one for the run.
    with PostgresContainer("postgres:16-alpine") as postgres:
        yield postgres
|
||||
|
||||
|
||||
@pytest.fixture
def test_db(postgres_container):
    """Create a test database session using PostgreSQL."""
    # Get connection URL from container
    db_url = postgres_container.get_connection_url()

    # Create engine and tables (fresh schema per test for isolation)
    engine = create_engine(db_url)
    Base.metadata.create_all(bind=engine)

    # Create session
    TestSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
    session = TestSessionLocal()

    yield session

    # Teardown: drop everything so the next test starts from a clean schema.
    session.close()
    Base.metadata.drop_all(bind=engine)
    engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
def test_user(test_db):
    """Persist and return a throwaway user record for tests."""
    attrs = dict(
        id=uuid4(),
        email="test@example.com",
        display_name="Test User",
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )
    user = User(**attrs)
    test_db.add(user)
    test_db.commit()
    return user
|
||||
|
||||
|
||||
@pytest.fixture
def test_user_agent(test_db, test_user):
    """Persist and return a user agent owned by the test user."""
    attrs = dict(
        id=uuid4(),
        user_id=test_user.id,
        name="claude code",  # Lowercase as per the actual implementation
        is_active=True,
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )
    agent = UserAgent(**attrs)
    test_db.add(agent)
    test_db.commit()
    return agent
|
||||
|
||||
|
||||
@pytest.fixture
def test_agent_instance(test_db, test_user, test_user_agent):
    """Persist and return an ACTIVE agent instance for the test user/agent."""
    record = AgentInstance(
        id=uuid4(),
        user_agent_id=test_user_agent.id,
        user_id=test_user.id,
        status=AgentStatus.ACTIVE,
        started_at=datetime.now(timezone.utc),
    )
    test_db.add(record)
    test_db.commit()
    return record
|
||||
|
||||
|
||||
@pytest.fixture
def client(test_db):
    """Create a test client with database override."""

    def override_get_db():
        # Hand back the shared test session; do NOT close it here — the
        # test_db fixture owns its lifecycle and teardown.
        try:
            yield test_db
        finally:
            pass

    app.dependency_overrides[get_db] = override_get_db

    with TestClient(app) as test_client:
        yield test_client

    # Remove all overrides once the test is finished.
    app.dependency_overrides.clear()
|
||||
|
||||
|
||||
@pytest.fixture
def authenticated_client(client, test_user):
    """Create a test client with authentication."""

    # Both auth dependencies resolve to the fixture user, bypassing Supabase.
    def override_get_current_user():
        return test_user

    def override_get_optional_current_user():
        return test_user

    app.dependency_overrides[get_current_user] = override_get_current_user
    app.dependency_overrides[get_optional_current_user] = (
        override_get_optional_current_user
    )

    yield client

    # Clear only the auth overrides, keep the db override
    if get_current_user in app.dependency_overrides:
        del app.dependency_overrides[get_current_user]
    if get_optional_current_user in app.dependency_overrides:
        del app.dependency_overrides[get_optional_current_user]
|
||||
|
||||
|
||||
@pytest.fixture
def mock_supabase_client():
    """Build a Supabase client double whose auth.get_user can be stubbed."""
    client_double = Mock()
    client_double.auth = Mock()
    client_double.auth.get_user = Mock()
    return client_double
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def reset_env():
    """Snapshot os.environ before each test and restore it afterwards."""
    snapshot = dict(os.environ)
    yield
    os.environ.clear()
    os.environ.update(snapshot)
|
||||
335
backend/tests/test_agents.py
Normal file
335
backend/tests/test_agents.py
Normal file
@@ -0,0 +1,335 @@
|
||||
"""Tests for agent endpoints."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from uuid import uuid4
|
||||
|
||||
from shared.database.models import (
|
||||
User,
|
||||
UserAgent,
|
||||
AgentInstance,
|
||||
AgentStep,
|
||||
AgentQuestion,
|
||||
AgentUserFeedback,
|
||||
)
|
||||
from shared.database.enums import AgentStatus
|
||||
|
||||
|
||||
class TestAgentEndpoints:
    """Test agent management endpoints (listing, detail, feedback, status)."""

    def test_list_agent_types(
        self, authenticated_client, test_user_agent, test_agent_instance
    ):
        """Test listing agent types with instances."""
        response = authenticated_client.get("/api/v1/agent-types")
        assert response.status_code == 200
        data = response.json()

        assert len(data) == 1
        agent_type = data[0]
        assert agent_type["id"] == str(test_user_agent.id)
        assert agent_type["name"] == "claude code"
        assert len(agent_type["recent_instances"]) == 1
        assert agent_type["recent_instances"][0]["id"] == str(test_agent_instance.id)

    def test_list_agent_types_multiple_users(
        self, authenticated_client, test_db, test_user_agent
    ):
        """Test that users only see their own agent types."""
        # Create another user with agent
        other_user = User(
            id=uuid4(),
            email="other@example.com",
            display_name="Other User",
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )
        test_db.add(other_user)

        other_user_agent = UserAgent(
            id=uuid4(),
            user_id=other_user.id,
            name="cursor",
            is_active=True,
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )
        test_db.add(other_user_agent)
        test_db.commit()

        response = authenticated_client.get("/api/v1/agent-types")
        assert response.status_code == 200
        data = response.json()

        # Should only see own agent type
        assert len(data) == 1
        assert data[0]["name"] == "claude code"

    def test_list_agent_types_with_pending_questions(
        self, authenticated_client, test_db, test_user, test_user_agent
    ):
        """Test agent types listing with pending questions (catches timezone issues)."""
        # Create an agent instance
        instance = AgentInstance(
            id=uuid4(),
            user_agent_id=test_user_agent.id,
            user_id=test_user.id,
            status=AgentStatus.ACTIVE,
            started_at=datetime.now(timezone.utc),
        )
        test_db.add(instance)

        # Create a pending question with timezone-aware datetime
        question = AgentQuestion(
            id=uuid4(),
            agent_instance_id=instance.id,
            question_text="Test question?",
            asked_at=datetime.now(timezone.utc),
            is_active=True,
        )
        test_db.add(question)
        test_db.commit()

        # This should not raise a timezone error
        response = authenticated_client.get("/api/v1/agent-types")
        assert response.status_code == 200
        data = response.json()

        assert len(data) == 1
        agent_type = data[0]
        assert len(agent_type["recent_instances"]) == 1

        # Check that pending question info is populated
        instance_data = agent_type["recent_instances"][0]
        assert instance_data["has_pending_question"] is True
        assert instance_data["pending_questions_count"] == 1
        assert instance_data["pending_question_age"] is not None
        assert instance_data["pending_question_age"] >= 0

    def test_list_all_agent_instances(self, authenticated_client, test_agent_instance):
        """Test listing all agent instances."""
        response = authenticated_client.get("/api/v1/agent-instances")
        assert response.status_code == 200
        data = response.json()

        assert len(data) == 1
        instance = data[0]
        assert instance["id"] == str(test_agent_instance.id)
        assert instance["status"] == "active"

    def test_list_agent_instances_with_limit(
        self, authenticated_client, test_db, test_user, test_user_agent
    ):
        """Test listing agent instances with limit."""
        # Create multiple instances
        for i in range(5):
            instance = AgentInstance(
                id=uuid4(),
                user_agent_id=test_user_agent.id,
                user_id=test_user.id,
                status=AgentStatus.COMPLETED,
                started_at=datetime.now(timezone.utc),
            )
            test_db.add(instance)
        test_db.commit()

        response = authenticated_client.get("/api/v1/agent-instances?limit=3")
        assert response.status_code == 200
        data = response.json()
        assert len(data) == 3

    def test_get_agent_summary(
        self,
        authenticated_client,
        test_db,
        test_user,
        test_user_agent,
        test_agent_instance,
    ):
        """Test getting agent summary."""
        # Add more instances with different statuses
        completed_instance = AgentInstance(
            id=uuid4(),
            user_agent_id=test_user_agent.id,
            user_id=test_user.id,
            status=AgentStatus.COMPLETED,
            started_at=datetime.now(timezone.utc),
            ended_at=datetime.now(timezone.utc),
        )
        test_db.add(completed_instance)

        # Add a question to the active instance
        question = AgentQuestion(
            id=uuid4(),
            agent_instance_id=test_agent_instance.id,
            question_text="Test question?",
            asked_at=datetime.now(timezone.utc),
            is_active=True,
        )
        test_db.add(question)
        test_db.commit()

        response = authenticated_client.get("/api/v1/agent-summary")
        assert response.status_code == 200
        data = response.json()

        assert data["total_instances"] == 2
        assert data["active_instances"] == 1
        assert data["completed_instances"] == 1
        assert "agent_types" in data
        assert len(data["agent_types"]) == 1

    def test_get_type_instances(
        self, authenticated_client, test_user_agent, test_agent_instance
    ):
        """Test getting instances for a specific agent type."""
        response = authenticated_client.get(
            f"/api/v1/agent-types/{test_user_agent.id}/instances"
        )
        assert response.status_code == 200
        data = response.json()

        assert len(data) == 1
        assert data[0]["id"] == str(test_agent_instance.id)

    def test_get_type_instances_not_found(self, authenticated_client):
        """Test getting instances for non-existent agent type."""
        fake_id = uuid4()
        response = authenticated_client.get(f"/api/v1/agent-types/{fake_id}/instances")
        assert response.status_code == 404
        assert response.json()["detail"] == "Agent type not found"

    def test_get_instance_detail(
        self, authenticated_client, test_db, test_agent_instance
    ):
        """Test getting detailed agent instance information."""
        # Add steps and questions
        step1 = AgentStep(
            id=uuid4(),
            agent_instance_id=test_agent_instance.id,
            step_number=1,
            description="First step",
            created_at=datetime.now(timezone.utc),
        )
        step2 = AgentStep(
            id=uuid4(),
            agent_instance_id=test_agent_instance.id,
            step_number=2,
            description="Second step",
            created_at=datetime.now(timezone.utc),
        )

        question = AgentQuestion(
            id=uuid4(),
            agent_instance_id=test_agent_instance.id,
            question_text="Need input?",
            asked_at=datetime.now(timezone.utc),
            is_active=True,
        )

        feedback = AgentUserFeedback(
            id=uuid4(),
            agent_instance_id=test_agent_instance.id,
            created_by_user_id=test_agent_instance.user_id,
            feedback_text="Great work!",
            created_at=datetime.now(timezone.utc),
        )

        test_db.add_all([step1, step2, question, feedback])
        test_db.commit()

        response = authenticated_client.get(
            f"/api/v1/agent-instances/{test_agent_instance.id}"
        )
        assert response.status_code == 200
        data = response.json()

        assert data["id"] == str(test_agent_instance.id)
        assert len(data["steps"]) == 2
        assert data["steps"][0]["description"] == "First step"
        assert len(data["questions"]) == 1
        assert data["questions"][0]["question_text"] == "Need input?"
        assert len(data["user_feedback"]) == 1
        assert data["user_feedback"][0]["feedback_text"] == "Great work!"

    def test_get_instance_detail_not_found(self, authenticated_client):
        """Test getting non-existent instance detail."""
        fake_id = uuid4()
        response = authenticated_client.get(f"/api/v1/agent-instances/{fake_id}")
        assert response.status_code == 404
        assert response.json()["detail"] == "Agent instance not found"

    def test_add_user_feedback(
        self, authenticated_client, test_db, test_agent_instance
    ):
        """Test adding user feedback to an agent instance."""
        feedback_text = "Please use TypeScript for this component"
        response = authenticated_client.post(
            f"/api/v1/agent-instances/{test_agent_instance.id}/feedback",
            json={"feedback": feedback_text},
        )

        assert response.status_code == 200
        data = response.json()
        assert data["feedback_text"] == feedback_text
        assert "id" in data
        assert "created_at" in data

        # Verify in database
        feedback = (
            test_db.query(AgentUserFeedback)
            .filter_by(agent_instance_id=test_agent_instance.id)
            .first()
        )
        assert feedback is not None
        assert feedback.feedback_text == feedback_text
        # retrieved_at stays NULL until the agent fetches the feedback
        assert feedback.retrieved_at is None

    def test_add_feedback_to_nonexistent_instance(self, authenticated_client):
        """Test adding feedback to non-existent instance."""
        fake_id = uuid4()
        response = authenticated_client.post(
            f"/api/v1/agent-instances/{fake_id}/feedback",
            json={"feedback": "Test feedback"},
        )
        assert response.status_code == 404
        assert response.json()["detail"] == "Agent instance not found"

    def test_update_agent_status_completed(
        self, authenticated_client, test_db, test_agent_instance
    ):
        """Test marking an agent instance as completed."""
        response = authenticated_client.put(
            f"/api/v1/agent-instances/{test_agent_instance.id}/status",
            json={"status": "completed"},
        )

        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "completed"
        assert data["ended_at"] is not None

        # Verify in database
        test_db.refresh(test_agent_instance)
        assert test_agent_instance.status == AgentStatus.COMPLETED
        assert test_agent_instance.ended_at is not None

    def test_update_agent_status_unsupported(
        self, authenticated_client, test_agent_instance
    ):
        """Test unsupported status update."""
        response = authenticated_client.put(
            f"/api/v1/agent-instances/{test_agent_instance.id}/status",
            json={"status": "paused"},
        )
        assert response.status_code == 400
        assert response.json()["detail"] == "Status update not supported"

    def test_update_status_nonexistent_instance(self, authenticated_client):
        """Test updating status of non-existent instance."""
        fake_id = uuid4()
        response = authenticated_client.put(
            f"/api/v1/agent-instances/{fake_id}/status", json={"status": "completed"}
        )
        assert response.status_code == 404
        assert response.json()["detail"] == "Agent instance not found"
|
||||
237
backend/tests/test_auth.py
Normal file
237
backend/tests/test_auth.py
Normal file
@@ -0,0 +1,237 @@
|
||||
"""Tests for authentication endpoints."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from uuid import uuid4
|
||||
from unittest.mock import patch, Mock
|
||||
|
||||
from shared.database.models import User, APIKey
|
||||
|
||||
|
||||
class TestAuthEndpoints:
    """Test authentication endpoints (session, profile, user sync)."""

    def test_get_session_unauthenticated(self, client):
        """Test getting session when not authenticated."""
        response = client.get("/api/v1/auth/session")
        assert response.status_code == 401
        assert response.json()["detail"] == "Not authenticated"

    def test_get_session_authenticated(self, authenticated_client, test_user):
        """Test getting session when authenticated."""
        response = authenticated_client.get("/api/v1/auth/session")
        assert response.status_code == 200
        data = response.json()
        assert data["id"] == str(test_user.id)
        assert data["email"] == test_user.email
        assert data["display_name"] == test_user.display_name

    def test_get_current_user_profile(self, authenticated_client, test_user):
        """Test getting current user profile."""
        response = authenticated_client.get("/api/v1/auth/me")
        assert response.status_code == 200
        data = response.json()
        assert data["id"] == str(test_user.id)
        assert data["email"] == test_user.email
        assert data["display_name"] == test_user.display_name

    @patch("backend.auth.utils.get_supabase_client")
    def test_update_user_profile(
        self, mock_supabase, authenticated_client, test_user, test_db
    ):
        """Test updating user profile."""
        # Mock Supabase client
        mock_client = Mock()
        mock_client.auth.admin.update_user_by_id = Mock()
        mock_supabase.return_value = mock_client

        new_display_name = "Updated Test User"
        response = authenticated_client.patch(
            "/api/v1/auth/me", json={"display_name": new_display_name}
        )
        assert response.status_code == 200
        data = response.json()
        assert data["display_name"] == new_display_name

        # Verify in database
        test_db.refresh(test_user)
        assert test_user.display_name == new_display_name

        # Verify Supabase was called
        mock_client.auth.admin.update_user_by_id.assert_called_once()

    def test_sync_user(self, authenticated_client, test_user, test_db):
        """Test syncing user from Supabase."""
        response = authenticated_client.post(
            "/api/v1/auth/sync-user",
            json={
                "id": str(test_user.id),
                "email": test_user.email,
                "display_name": "Synced Name",
            },
        )
        assert response.status_code == 200
        assert response.json()["message"] == "User synced successfully"

        # Verify display name was updated
        test_db.refresh(test_user)
        assert test_user.display_name == "Synced Name"

    def test_sync_user_forbidden(self, authenticated_client, test_user):
        """Test syncing a different user is forbidden."""
        # A user may only sync their own record, never another user's.
        different_user_id = str(uuid4())
        response = authenticated_client.post(
            "/api/v1/auth/sync-user",
            json={
                "id": different_user_id,
                "email": "other@example.com",
                "display_name": "Other User",
            },
        )
        assert response.status_code == 403
        assert response.json()["detail"] == "Cannot sync different user"
|
||||
|
||||
|
||||
class TestAPIKeyEndpoints:
    """Test API key management endpoints (create, list, revoke)."""

    @patch("backend.auth.routes.create_api_key_jwt")
    def test_create_api_key(
        self, mock_create_jwt, authenticated_client, test_user, test_db
    ):
        """Test creating an API key."""
        mock_jwt_token = "test.jwt.token"
        mock_create_jwt.return_value = mock_jwt_token

        response = authenticated_client.post(
            "/api/v1/auth/api-keys",
            json={"name": "Test API Key", "expires_in_days": 30},
        )

        assert response.status_code == 200
        data = response.json()
        assert data["name"] == "Test API Key"
        assert data["api_key"] == mock_jwt_token
        assert "expires_at" in data

        # Verify in database
        api_key = test_db.query(APIKey).filter(APIKey.user_id == test_user.id).first()
        assert api_key is not None
        assert api_key.name == "Test API Key"
        assert api_key.is_active is True

    @patch("backend.auth.routes.create_api_key_jwt")
    def test_create_api_key_max_limit(
        self, mock_create_jwt, authenticated_client, test_user, test_db
    ):
        """Test creating API key when at max limit."""
        mock_create_jwt.return_value = "test.jwt.token"

        # Create 50 existing API keys (the per-user cap)
        for i in range(50):
            api_key = APIKey(
                id=uuid4(),
                user_id=test_user.id,
                name=f"Key {i}",
                api_key_hash="hash",
                api_key=f"token{i}",
                is_active=True,
                created_at=datetime.now(timezone.utc),
            )
            test_db.add(api_key)
        test_db.commit()

        response = authenticated_client.post(
            "/api/v1/auth/api-keys",
            json={"name": "One Too Many", "expires_in_days": 30},
        )

        assert response.status_code == 400
        assert "Maximum of 50 active API keys allowed" in response.json()["detail"]

    def test_list_api_keys(self, authenticated_client, test_user, test_db):
        """Test listing API keys."""
        # Create test API keys
        api_keys = []
        for i in range(3):
            api_key = APIKey(
                id=uuid4(),
                user_id=test_user.id,
                name=f"Key {i}",
                api_key_hash=f"hash{i}",
                api_key=f"token{i}",
                is_active=i != 2,  # Last one is inactive
                created_at=datetime.now(timezone.utc),
            )
            api_keys.append(api_key)
            test_db.add(api_key)
        test_db.commit()

        response = authenticated_client.get("/api/v1/auth/api-keys")
        assert response.status_code == 200
        data = response.json()
        assert len(data) == 3

        # Should be ordered by created_at desc
        assert data[0]["name"] == "Key 2"
        assert data[0]["is_active"] is False
        assert data[1]["name"] == "Key 1"
        assert data[1]["is_active"] is True

    def test_revoke_api_key(self, authenticated_client, test_user, test_db):
        """Test revoking an API key."""
        # Create test API key
        api_key = APIKey(
            id=uuid4(),
            user_id=test_user.id,
            name="Test Key",
            api_key_hash="hash",
            api_key="token",
            is_active=True,
            created_at=datetime.now(timezone.utc),
        )
        test_db.add(api_key)
        test_db.commit()

        response = authenticated_client.delete(f"/api/v1/auth/api-keys/{api_key.id}")
        assert response.status_code == 200
        assert response.json()["message"] == "API key revoked successfully"

        # Verify in database: revocation is a soft delete (is_active flips off)
        test_db.refresh(api_key)
        assert api_key.is_active is False

    def test_revoke_api_key_not_found(self, authenticated_client):
        """Test revoking a non-existent API key."""
        fake_id = str(uuid4())
        response = authenticated_client.delete(f"/api/v1/auth/api-keys/{fake_id}")
        assert response.status_code == 404
        assert response.json()["detail"] == "API key not found"

    def test_revoke_api_key_wrong_user(self, authenticated_client, test_db):
        """Test revoking another user's API key."""
        # Create another user
        other_user = User(
            id=uuid4(),
            email="other@example.com",
            display_name="Other User",
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )
        test_db.add(other_user)

        # Create API key for other user
        api_key = APIKey(
            id=uuid4(),
            user_id=other_user.id,
            name="Other's Key",
            api_key_hash="hash",
            api_key="token",
            is_active=True,
            created_at=datetime.now(timezone.utc),
        )
        test_db.add(api_key)
        test_db.commit()

        # Another user's key must look like a 404, not a 403, to avoid leaking
        # the key's existence.
        response = authenticated_client.delete(f"/api/v1/auth/api-keys/{api_key.id}")
        assert response.status_code == 404
        assert response.json()["detail"] == "API key not found"
|
||||
198
backend/tests/test_questions.py
Normal file
198
backend/tests/test_questions.py
Normal file
@@ -0,0 +1,198 @@
|
||||
"""Tests for question endpoints."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from uuid import uuid4
|
||||
|
||||
from shared.database.models import AgentQuestion, AgentInstance, User
|
||||
from shared.database.enums import AgentStatus
|
||||
|
||||
|
||||
class TestQuestionEndpoints:
    """Test question management endpoints.

    All tests drive POST /api/v1/questions/{id}/answer through the
    ``authenticated_client`` fixture and verify persisted state via the
    ``test_db`` session.
    """

    def test_answer_question(
        self, authenticated_client, test_db, test_agent_instance, test_user
    ):
        """Test answering a pending question."""
        # Create a pending question
        question = AgentQuestion(
            id=uuid4(),
            agent_instance_id=test_agent_instance.id,
            question_text="Should I use async/await?",
            asked_at=datetime.now(timezone.utc),
            is_active=True,
        )
        test_db.add(question)
        test_db.commit()

        # Submit answer
        answer_text = "Yes, use async/await for better performance"
        response = authenticated_client.post(
            f"/api/v1/questions/{question.id}/answer", json={"answer": answer_text}
        )

        assert response.status_code == 200
        data = response.json()
        assert data["success"] is True
        assert data["message"] == "Answer submitted successfully"

        # Verify in database
        test_db.refresh(question)
        assert question.answer_text == answer_text
        assert question.answered_at is not None
        assert question.answered_by_user_id == test_user.id
        # Answered questions are deactivated so they no longer block the agent.
        assert question.is_active is False

        # Verify agent instance status changed back to active
        test_db.refresh(test_agent_instance)
        assert test_agent_instance.status == AgentStatus.ACTIVE

    def test_answer_question_not_found(self, authenticated_client):
        """Test answering a non-existent question."""
        fake_id = uuid4()
        response = authenticated_client.post(
            f"/api/v1/questions/{fake_id}/answer", json={"answer": "Some answer"}
        )

        assert response.status_code == 404
        assert response.json()["detail"] == "Question not found or already answered"

    def test_answer_already_answered_question(
        self, authenticated_client, test_db, test_agent_instance, test_user
    ):
        """Test answering an already answered question."""
        # Create an already answered question
        question = AgentQuestion(
            id=uuid4(),
            agent_instance_id=test_agent_instance.id,
            question_text="Already answered?",
            answer_text="Previous answer",
            asked_at=datetime.now(timezone.utc),
            answered_at=datetime.now(timezone.utc),
            answered_by_user_id=test_user.id,
            is_active=False,
        )
        test_db.add(question)
        test_db.commit()

        # Try to answer again
        response = authenticated_client.post(
            f"/api/v1/questions/{question.id}/answer", json={"answer": "New answer"}
        )

        # Re-answering is reported the same as "not found".
        assert response.status_code == 404
        assert response.json()["detail"] == "Question not found or already answered"

        # Verify answer didn't change
        test_db.refresh(question)
        assert question.answer_text == "Previous answer"

    def test_answer_question_wrong_user(self, authenticated_client, test_db):
        """Test answering a question from another user's agent."""
        # Create another user and their agent instance
        other_user = User(
            id=uuid4(),
            email="other@example.com",
            display_name="Other User",
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )
        test_db.add(other_user)

        # Create user agent for other user
        from shared.database.models import UserAgent

        other_user_agent = UserAgent(
            id=uuid4(),
            user_id=other_user.id,
            name="other agent",
            is_active=True,
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )
        test_db.add(other_user_agent)

        other_instance = AgentInstance(
            id=uuid4(),
            user_agent_id=other_user_agent.id,
            user_id=other_user.id,
            status=AgentStatus.AWAITING_INPUT,
            started_at=datetime.now(timezone.utc),
        )
        test_db.add(other_instance)

        # Create question for other user's agent
        question = AgentQuestion(
            id=uuid4(),
            agent_instance_id=other_instance.id,
            question_text="Other user's question?",
            asked_at=datetime.now(timezone.utc),
            is_active=True,
        )
        test_db.add(question)
        test_db.commit()

        # Try to answer as current user
        response = authenticated_client.post(
            f"/api/v1/questions/{question.id}/answer",
            json={"answer": "Trying to answer"},
        )

        # Cross-user access is reported as 404 so question existence is not leaked.
        assert response.status_code == 404
        assert response.json()["detail"] == "Question not found or already answered"

        # Verify question remains unanswered
        test_db.refresh(question)
        assert question.answer_text is None
        assert question.is_active is True

    def test_answer_inactive_question(
        self, authenticated_client, test_db, test_agent_instance
    ):
        """Test answering an inactive question."""
        # Create an inactive question (but not answered)
        question = AgentQuestion(
            id=uuid4(),
            agent_instance_id=test_agent_instance.id,
            question_text="Inactive question?",
            asked_at=datetime.now(timezone.utc),
            is_active=False,  # Inactive but not answered
        )
        test_db.add(question)
        test_db.commit()

        # Try to answer
        response = authenticated_client.post(
            f"/api/v1/questions/{question.id}/answer",
            json={"answer": "Trying to answer inactive"},
        )

        # Inactive questions are treated exactly like missing ones.
        assert response.status_code == 404
        assert response.json()["detail"] == "Question not found or already answered"

    def test_answer_question_empty_answer(
        self, authenticated_client, test_db, test_agent_instance
    ):
        """Test submitting an empty answer."""
        # Create a pending question
        question = AgentQuestion(
            id=uuid4(),
            agent_instance_id=test_agent_instance.id,
            question_text="Can I submit empty?",
            asked_at=datetime.now(timezone.utc),
            is_active=True,
        )
        test_db.add(question)
        test_db.commit()

        # Submit empty answer - should still work
        response = authenticated_client.post(
            f"/api/v1/questions/{question.id}/answer", json={"answer": ""}
        )

        assert response.status_code == 200

        # Verify empty answer was saved
        test_db.refresh(question)
        assert question.answer_text == ""
        assert question.is_active is False
|
||||
291
backend/tests/test_user_agents.py
Normal file
291
backend/tests/test_user_agents.py
Normal file
@@ -0,0 +1,291 @@
|
||||
"""Tests for user agent endpoints."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from uuid import uuid4
|
||||
from unittest.mock import patch
|
||||
|
||||
from shared.database.models import UserAgent, AgentInstance, User
|
||||
from shared.database.enums import AgentStatus
|
||||
from backend.models import WebhookTriggerResponse
|
||||
|
||||
|
||||
class TestUserAgentEndpoints:
    """Test user agent management endpoints.

    Covers listing, creation, updates, per-agent instance listing, and
    instance creation (with and without a configured webhook).
    """

    def test_list_user_agents(
        self, authenticated_client, test_db, test_user, test_user_agent
    ):
        """Test listing user agents."""
        # Create additional user agent
        another_agent = UserAgent(
            id=uuid4(),
            user_id=test_user.id,
            name="cursor",
            is_active=True,
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )
        test_db.add(another_agent)
        test_db.commit()

        response = authenticated_client.get("/api/v1/user-agents")
        assert response.status_code == 200
        data = response.json()

        # The fixture agent ("claude code") plus the one created above.
        assert len(data) == 2
        names = [agent["name"] for agent in data]
        assert "claude code" in names
        assert "cursor" in names

    def test_list_user_agents_different_users(
        self, authenticated_client, test_db, test_user_agent
    ):
        """Test that users only see their own user agents."""
        # Create another user with agent
        other_user = User(
            id=uuid4(),
            email="other@example.com",
            display_name="Other User",
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )
        test_db.add(other_user)

        other_agent = UserAgent(
            id=uuid4(),
            user_id=other_user.id,
            name="other agent",
            is_active=True,
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )
        test_db.add(other_agent)
        test_db.commit()

        response = authenticated_client.get("/api/v1/user-agents")
        assert response.status_code == 200
        data = response.json()

        # Should only see own agent
        assert len(data) == 1
        assert data[0]["name"] == "claude code"

    def test_create_user_agent(self, authenticated_client, test_db, test_user):
        """Test creating a new user agent."""
        agent_data = {
            "name": "New Agent",
            "is_active": True,
            "webhook_url": "https://example.com/webhook",
        }

        response = authenticated_client.post("/api/v1/user-agents", json=agent_data)
        assert response.status_code == 200
        data = response.json()

        assert data["name"] == "New Agent"
        assert data["is_active"] is True
        assert data["webhook_url"] == "https://example.com/webhook"
        assert "id" in data

        # Verify in database
        agent = test_db.query(UserAgent).filter_by(name="New Agent").first()
        assert agent is not None
        assert agent.user_id == test_user.id
        assert agent.webhook_url == "https://example.com/webhook"

    def test_update_user_agent(self, authenticated_client, test_db, test_user_agent):
        """Test updating a user agent."""
        update_data = {
            "name": "Updated Claude",
            "is_active": False,
            "webhook_url": "https://new-webhook.com",
        }

        response = authenticated_client.patch(
            f"/api/v1/user-agents/{test_user_agent.id}", json=update_data
        )
        assert response.status_code == 200
        data = response.json()

        assert data["name"] == "Updated Claude"
        assert data["is_active"] is False
        assert data["webhook_url"] == "https://new-webhook.com"

        # Verify in database
        test_db.refresh(test_user_agent)
        assert test_user_agent.name == "Updated Claude"
        assert test_user_agent.is_active is False
        assert test_user_agent.webhook_url == "https://new-webhook.com"

    def test_update_user_agent_not_found(self, authenticated_client):
        """Test updating a non-existent user agent."""
        fake_id = uuid4()
        response = authenticated_client.patch(
            f"/api/v1/user-agents/{fake_id}", json={"name": "Updated"}
        )
        assert response.status_code == 404
        assert response.json()["detail"] == "User agent not found"

    def test_update_user_agent_wrong_user(self, authenticated_client, test_db):
        """Test updating another user's agent."""
        # Create another user with agent
        other_user = User(
            id=uuid4(),
            email="other@example.com",
            display_name="Other User",
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )
        test_db.add(other_user)

        other_agent = UserAgent(
            id=uuid4(),
            user_id=other_user.id,
            name="other agent",
            is_active=True,
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )
        test_db.add(other_agent)
        test_db.commit()

        response = authenticated_client.patch(
            f"/api/v1/user-agents/{other_agent.id}", json={"name": "Hacked"}
        )
        # Cross-user access is reported as 404 so agent existence is not leaked.
        assert response.status_code == 404
        assert response.json()["detail"] == "User agent not found"

    def test_get_user_agent_instances(
        self, authenticated_client, test_db, test_user, test_user_agent
    ):
        """Test getting instances for a user agent."""
        # Create instances
        instance1 = AgentInstance(
            id=uuid4(),
            user_agent_id=test_user_agent.id,
            user_id=test_user.id,
            status=AgentStatus.ACTIVE,
            started_at=datetime.now(timezone.utc),
        )
        instance2 = AgentInstance(
            id=uuid4(),
            user_agent_id=test_user_agent.id,
            user_id=test_user.id,
            status=AgentStatus.COMPLETED,
            started_at=datetime.now(timezone.utc),
            ended_at=datetime.now(timezone.utc),
        )
        test_db.add_all([instance1, instance2])
        test_db.commit()

        response = authenticated_client.get(
            f"/api/v1/user-agents/{test_user_agent.id}/instances"
        )
        assert response.status_code == 200
        data = response.json()

        assert len(data) == 2
        # Statuses are serialized as lowercase strings in the API response.
        statuses = [inst["status"] for inst in data]
        assert "active" in statuses
        assert "completed" in statuses

    def test_get_user_agent_instances_not_found(self, authenticated_client):
        """Test getting instances for non-existent user agent."""
        fake_id = uuid4()
        response = authenticated_client.get(f"/api/v1/user-agents/{fake_id}/instances")
        assert response.status_code == 404
        assert response.json()["detail"] == "User agent not found"

    def test_create_agent_instance_no_webhook(
        self, authenticated_client, test_db, test_user_agent
    ):
        """Test creating an instance for agent without webhook."""
        response = authenticated_client.post(
            f"/api/v1/user-agents/{test_user_agent.id}/instances",
            json={"prompt": "Test prompt"},
        )

        assert response.status_code == 200
        data = response.json()
        assert data["success"] is True
        assert "agent_instance_id" in data
        assert data["message"] == "Agent instance created successfully"

        # Verify instance created
        instance = (
            test_db.query(AgentInstance)
            .filter_by(user_agent_id=test_user_agent.id)
            .first()
        )
        assert instance is not None
        assert instance.status == AgentStatus.ACTIVE

    def test_create_agent_instance_with_webhook(
        self, authenticated_client, test_db, test_user_agent
    ):
        """Test creating an instance for agent with webhook."""
        # Set webhook URL
        test_user_agent.webhook_url = "https://example.com/webhook"
        test_db.commit()

        # Mock the async webhook function
        with patch(
            "backend.db.user_agent_queries.trigger_webhook_agent"
        ) as mock_trigger:
            # Use AsyncMock to properly mock the async function
            mock_trigger.return_value = WebhookTriggerResponse(
                success=True,
                agent_instance_id=str(uuid4()),
                message="Webhook triggered successfully",
            )

            response = authenticated_client.post(
                f"/api/v1/user-agents/{test_user_agent.id}/instances",
                json={"prompt": "Test prompt with webhook"},
            )

            assert response.status_code == 200
            data = response.json()
            # The actual webhook is still being called, so let's just check the response structure
            assert "success" in data
            assert "agent_instance_id" in data
            assert "message" in data

    def test_create_agent_instance_not_found(self, authenticated_client):
        """Test creating instance for non-existent user agent."""
        fake_id = uuid4()
        response = authenticated_client.post(
            f"/api/v1/user-agents/{fake_id}/instances", json={"prompt": "Test"}
        )
        assert response.status_code == 404
        assert response.json()["detail"] == "User agent not found"

    def test_create_agent_instance_wrong_user(self, authenticated_client, test_db):
        """Test creating instance for another user's agent."""
        # Create another user with agent
        other_user = User(
            id=uuid4(),
            email="other@example.com",
            display_name="Other User",
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )
        test_db.add(other_user)

        other_agent = UserAgent(
            id=uuid4(),
            user_id=other_user.id,
            name="other agent",
            is_active=True,
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )
        test_db.add(other_agent)
        test_db.commit()

        response = authenticated_client.post(
            f"/api/v1/user-agents/{other_agent.id}/instances",
            json={"prompt": "Trying to use other's agent"},
        )
        assert response.status_code == 404
        assert response.json()["detail"] == "User agent not found"
|
||||
67
cli/README.md
Normal file
67
cli/README.md
Normal file
@@ -0,0 +1,67 @@
|
||||
# Omnara CLI
|
||||
|
||||
Omnara CLI automatically installs the Agent Dashboard MCP server configuration into various AI clients. This allows agents to connect to your dashboard for logging progress, asking questions, and coordinating tasks.
|
||||
|
||||
## 🚀 Installation
|
||||
|
||||
```bash
|
||||
npx @omnara/cli@latest install <client> --api-key YOUR_API_KEY
|
||||
```
|
||||
|
||||
**Supported clients:** claude-code, cursor, windsurf, cline, roo-cline, claude, witsy, enconvo
|
||||
|
||||
## 🔧 Connection Types
|
||||
|
||||
### **SSE (Server-Sent Events)** - *Recommended*
|
||||
- **Clients:** `cursor`, `claude`
|
||||
- **Benefits:** Hosted service, no setup required
|
||||
|
||||
### **stdio** - *Local MCP server*
|
||||
- **Clients:** `cline`, `roo-cline`, `windsurf`, `witsy`, `enconvo`
|
||||
- **Benefits:** Local execution, full control
|
||||
|
||||
## 📦 stdio Installation Process
|
||||
|
||||
For stdio clients, the CLI automatically:
|
||||
|
||||
1. **Installs Python package:** `pip install omnara`
|
||||
2. **Writes clean config:**
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"omnara": {
|
||||
"command": "omnara",
|
||||
"args": ["--api-key", "YOUR_API_KEY"]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
3. **Ready to use** - Client runs: `omnara --api-key YOUR_KEY`
|
||||
|
||||
**If auto-install fails:** Run `pip install omnara` manually.
|
||||
|
||||
## 🔧 Manual Setup
|
||||
|
||||
### SSE Configuration
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"omnara": {
|
||||
"url": "https://omnara-mcp.onrender.com",
|
||||
"apiKey": "YOUR_API_KEY"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### stdio Configuration
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"omnara": {
|
||||
"command": "omnara",
|
||||
"args": ["--api-key", "YOUR_API_KEY"]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
1493
cli/package-lock.json
generated
Normal file
1493
cli/package-lock.json
generated
Normal file
File diff suppressed because it is too large
Load Diff
57
cli/package.json
Normal file
57
cli/package.json
Normal file
@@ -0,0 +1,57 @@
|
||||
{
|
||||
"name": "@omnara/cli",
|
||||
"version": "1.0.6",
|
||||
"type": "module",
|
||||
"description": "MCP configuration installer by Omnara",
|
||||
"main": "dist/index.js",
|
||||
"homepage": "https://github.com/omnara-ai/omnara",
|
||||
"repository": {
|
||||
"type": "git",
|
||||
"url": "git+https://github.com/omnara-ai/omnara.git"
|
||||
},
|
||||
"files": [
|
||||
"dist"
|
||||
],
|
||||
"bin": {
|
||||
"omnara-cli": "dist/cli.js"
|
||||
},
|
||||
"scripts": {
|
||||
"build": "tsc && shx chmod +x dist/cli.js",
|
||||
"start": "node dist/cli.js",
|
||||
"dev": "nodemon --watch src --ext ts,json --exec \"npm run build\"",
|
||||
"prepare": "npm run build",
|
||||
"build:prod": "npm run build",
|
||||
"publish-patch": "npm version patch && npm run build:prod && npm publish --access public",
|
||||
"publish-private": "npm version patch && npm run build:prod && npm publish"
|
||||
},
|
||||
"keywords": [
|
||||
"mcp",
|
||||
"model-context-protocol",
|
||||
"ai",
|
||||
"omnara",
|
||||
"cli",
|
||||
"installer"
|
||||
],
|
||||
"author": "ishaan.sehgal99@gmail.com",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"chalk": "^5.3.0",
|
||||
"commander": "^12.0.0",
|
||||
"inquirer": "^12.5.0",
|
||||
"jsonc-parser": "^3.3.1",
|
||||
"ora": "^8.0.1",
|
||||
"zod": "^3.22.4"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/node": "^20.11.28",
|
||||
"nodemon": "^3.1.0",
|
||||
"shx": "^0.3.4",
|
||||
"typescript": "^5.4.2"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=18.0.0"
|
||||
},
|
||||
"publishConfig": {
|
||||
"access": "public"
|
||||
}
|
||||
}
|
||||
79
cli/src/cli.ts
Normal file
79
cli/src/cli.ts
Normal file
@@ -0,0 +1,79 @@
|
||||
#!/usr/bin/env node

import { Command } from "commander";
import { install, installAll } from "./index.js";
import { VALID_CLIENTS } from "./types.js";
import chalk from "chalk";

// Entry point for the omnara-cli binary: parses arguments and dispatches
// to the installers defined in index.js.
const program = new Command();

program
  .name("omnara-cli")
  .description("Install MCP configuration for various AI clients")
  // NOTE(review): the version reported here (1.0.0) differs from
  // package.json (1.0.6) — confirm whether it should be read from package.json.
  .version("1.0.0");

program
  .command("install")
  .description("Install MCP configuration for a specific client, or all clients if none specified")
  .argument(
    "[client]",
    `The client to install for (${VALID_CLIENTS.join(", ")}) - omit to install for all clients`
  )
  .option("--api-key <key>", "API key for omnara services")
  .option("--transport <transport>", "Override default transport method (stdio, sse, streamable-http)")
  .option("--endpoint <url>", "Custom endpoint URL (for SSE/HTTP/Streamable clients)")
  .action(async (client?: string, options: { apiKey?: string; transport?: string; endpoint?: string } = {}) => {
    // If no client specified, install for all clients
    if (!client) {
      console.log(chalk.gray("📝 This will add MCP configurations for all supported AI clients."));
      console.log(chalk.gray("💡 Only clients you actually have installed will use these configs."));
      console.log(chalk.gray(`📋 Supported clients: ${VALID_CLIENTS.join(", ")}`));

      try {
        await installAll({
          apiKey: options.apiKey,
          transport: options.transport as any,
          endpoint: options.endpoint
        });
      } catch (error) {
        console.error(
          chalk.red(
            error instanceof Error ? error.message : "Unknown error occurred"
          )
        );
        process.exit(1);
      }
      return;
    }

    // Validate specific client
    if (!VALID_CLIENTS.includes(client as any)) {
      console.error(
        chalk.red(
          `Invalid client "${client}". Available clients: ${VALID_CLIENTS.join(
            ", "
          )}`
        )
      );
      process.exit(1);
    }

    // Install for specific client
    try {
      await install(client as any, {
        apiKey: options.apiKey,
        transport: options.transport as any,
        endpoint: options.endpoint
      });
    } catch (error) {
      console.error(
        chalk.red(
          error instanceof Error ? error.message : "Unknown error occurred"
        )
      );
      process.exit(1);
    }
  });

program.parse();
|
||||
79
cli/src/client.ts
Normal file
79
cli/src/client.ts
Normal file
@@ -0,0 +1,79 @@
|
||||
import inquirer from "inquirer";
|
||||
import chalk from "chalk";
|
||||
import { exec } from "node:child_process";
|
||||
import { promisify } from "node:util";
|
||||
|
||||
const execAsync = promisify(exec);
|
||||
|
||||
async function isClientRunning(client: string): Promise<boolean> {
|
||||
try {
|
||||
const platform = process.platform;
|
||||
const clientProcess = { claude: "Claude" }[client] || client;
|
||||
|
||||
if (platform === "win32") {
|
||||
const { stdout } = await execAsync(
|
||||
`tasklist /FI "IMAGENAME eq ${clientProcess}.exe" /NH`
|
||||
);
|
||||
return stdout.includes(`${clientProcess}.exe`);
|
||||
} else if (platform === "darwin") {
|
||||
const { stdout } = await execAsync(`pgrep -x "${clientProcess}"`);
|
||||
return !!stdout.trim();
|
||||
} else if (platform === "linux") {
|
||||
const { stdout } = await execAsync(
|
||||
`pgrep -f "${clientProcess.toLowerCase()}"`
|
||||
);
|
||||
return !!stdout.trim();
|
||||
}
|
||||
return false;
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Kill and relaunch the client application so it picks up the freshly
// written MCP configuration. Best-effort: failures are logged, not thrown.
async function restartClient(client: string): Promise<void> {
  // Map the client id to its OS process/app name ("claude" runs as "Claude").
  const clientProcess = { claude: "Claude" }[client] || client;
  const platform = process.platform;

  try {
    if (platform === "win32") {
      await execAsync(`taskkill /F /IM "${clientProcess}.exe"`);
    } else if (platform === "darwin") {
      await execAsync(`killall "${clientProcess}"`);
    } else if (platform === "linux") {
      await execAsync(`pkill -f "${clientProcess.toLowerCase()}"`);
    }

    // Give the OS a moment to fully terminate the process before relaunching.
    await new Promise((resolve) => setTimeout(resolve, 2000));

    if (platform === "win32") {
      await execAsync(`start "" "${clientProcess}.exe"`);
    } else if (platform === "darwin") {
      await execAsync(`open -a "${clientProcess}"`);
    } else if (platform === "linux") {
      // NOTE(review): exec here waits for the launched app to exit on Linux —
      // confirm whether the relaunch should be detached instead.
      await execAsync(clientProcess.toLowerCase());
    }

    console.log(chalk.green(`${clientProcess} has been restarted.`));
  } catch (error) {
    console.error(chalk.red(`Failed to restart ${clientProcess}:`), error);
  }
}
|
||||
|
||||
export async function promptForRestart(client: string): Promise<void> {
|
||||
const isRunning = await isClientRunning(client);
|
||||
if (!isRunning) return;
|
||||
|
||||
const { shouldRestart } = await inquirer.prompt<{ shouldRestart: boolean }>([
|
||||
{
|
||||
type: "confirm",
|
||||
name: "shouldRestart",
|
||||
message: `Would you like to restart ${chalk.bold(client)} now?`,
|
||||
default: true,
|
||||
},
|
||||
]);
|
||||
|
||||
if (shouldRestart) {
|
||||
console.log(`Restarting ${client} app...`);
|
||||
await restartClient(client);
|
||||
}
|
||||
}
|
||||
147
cli/src/config.ts
Normal file
147
cli/src/config.ts
Normal file
@@ -0,0 +1,147 @@
|
||||
import os from "node:os";
|
||||
import path from "node:path";
|
||||
import type { ClientConfig, ValidClient, TransportType } from "./types.js";
|
||||
import { SSE_SUPPORTED_CLIENTS, HTTP_SUPPORTED_CLIENTS } from "./types.js";
|
||||
|
||||
const homeDir = os.homedir();
|
||||
|
||||
const platformPaths = {
|
||||
win32: {
|
||||
baseDir: process.env.APPDATA || path.join(homeDir, "AppData", "Roaming"),
|
||||
vscodePath: path.join("Code", "User", "globalStorage"),
|
||||
},
|
||||
darwin: {
|
||||
baseDir: path.join(homeDir, "Library", "Application Support"),
|
||||
vscodePath: path.join("Code", "User", "globalStorage"),
|
||||
},
|
||||
linux: {
|
||||
baseDir: process.env.XDG_CONFIG_HOME || path.join(homeDir, ".config"),
|
||||
vscodePath: path.join("Code/User/globalStorage"),
|
||||
},
|
||||
};
|
||||
|
||||
const platform = process.platform as keyof typeof platformPaths;
|
||||
const { baseDir, vscodePath } = platformPaths[platform];
|
||||
|
||||
// Absolute path of each supported client's MCP/settings file on this platform.
export const clientPaths: Record<string, string> = {
  claude: path.join(baseDir, "Claude", "claude_desktop_config.json"),
  "claude-code": path.join(homeDir, ".claude.json"),
  // Cline and Roo Cline are VS Code extensions; their settings live in globalStorage.
  cline: path.join(
    baseDir,
    vscodePath,
    "saoudrizwan.claude-dev",
    "settings",
    "cline_mcp_settings.json"
  ),
  "roo-cline": path.join(
    baseDir,
    vscodePath,
    "rooveterinaryinc.roo-cline",
    "settings",
    "cline_mcp_settings.json"
  ),
  windsurf: path.join(homeDir, ".codeium", "windsurf", "mcp_config.json"),
  witsy: path.join(baseDir, "Witsy", "settings.json"),
  enconvo: path.join(homeDir, ".config", "enconvo", "mcp_config.json"),
  cursor: path.join(homeDir, ".cursor", "mcp.json"),
  // GitHub Copilot config goes directly into VS Code's user settings.json.
  "github-copilot": path.join(baseDir, "Code", "User", "settings.json"),
};
|
||||
|
||||
const getMCPEndpoint = (customEndpoint?: string) => {
|
||||
return customEndpoint || "https://agent-dashboard-mcp.onrender.com/mcp";
|
||||
};
|
||||
|
||||
const determineTransport = (client: ValidClient, transportOverride?: TransportType): TransportType => {
|
||||
if (transportOverride) {
|
||||
return transportOverride;
|
||||
}
|
||||
|
||||
if (HTTP_SUPPORTED_CLIENTS.includes(client)) {
|
||||
return "streamable-http";
|
||||
}
|
||||
|
||||
if (SSE_SUPPORTED_CLIENTS.includes(client)) {
|
||||
return "sse";
|
||||
}
|
||||
|
||||
return "stdio";
|
||||
};
|
||||
|
||||
/**
 * Create server configuration for different transport types.
 *
 * Returns the JSON fragment written under the client's MCP server entry;
 * the shape depends on the chosen transport.
 */
function createServerConfig(
  transport: TransportType,
  apiKey: string,
  endpoint: string,
  client: ValidClient
): any {
  // Sent with every request so the server can authenticate the user and
  // identify which client type is connecting.
  const baseHeaders = {
    Authorization: `Bearer ${apiKey}`,
    "X-Client-Type": client,
  };

  switch (transport) {
    case "sse":
      return {
        url: endpoint,
        headers: baseHeaders,
        // Pre-approve the dashboard tools so the client does not prompt per call.
        alwaysAllow: ["log_step", "ask_question"],
        disabled: false,
        timeout: 30000,
        retry: true,
      };

    case "streamable-http":
      return {
        type: "http" as const,
        url: endpoint,
        headers: baseHeaders,
      };

    case "stdio":
    default:
      // Run the local "omnara" package via pipx; the API key is passed as a CLI arg.
      return {
        command: "pipx",
        args: ["run", "omnara", "--api-key", apiKey],
        env: {
          OMNARA_CLIENT_TYPE: client
        },
      };
  }
}
|
||||
|
||||
/**
|
||||
* Check if client uses VS Code-style configuration structure
|
||||
*/
|
||||
export function usesVSCodeStyle(client: ValidClient): boolean {
|
||||
return client === "github-copilot";
|
||||
}
|
||||
|
||||
// Build the full configuration object to write into a client's config file.
// apiKey defaults to a placeholder so a config can be emitted before the
// user supplies a real key; endpoint/transport may be overridden.
export const getDefaultConfig = (
  client: ValidClient,
  apiKey: string = "YOUR_API_KEY",
  endpoint?: string,
  transportOverride?: TransportType
): ClientConfig => {
  const transport = determineTransport(client, transportOverride);
  const mcpEndpoint = getMCPEndpoint(endpoint);
  const serverConfig = createServerConfig(transport, apiKey, mcpEndpoint, client);

  // Return configuration in the appropriate format for the client
  if (usesVSCodeStyle(client)) {
    // VS Code-style clients nest servers under "mcp.servers" in settings.json.
    return {
      mcp: {
        servers: {
          "omnara": serverConfig,
        },
      },
    } as any; // Cast to any since this extends our base ClientConfig type
  }

  return {
    mcpServers: {
      "omnara": serverConfig,
    },
  };
};
|
||||
231
cli/src/index.ts
Normal file
231
cli/src/index.ts
Normal file
@@ -0,0 +1,231 @@
|
||||
#!/usr/bin/env node
|
||||
|
||||
import type { ValidClient, InstallOptions, TransportType } from "./types.js";
|
||||
import { SSE_SUPPORTED_CLIENTS, HTTP_SUPPORTED_CLIENTS, VALID_CLIENTS } from "./types.js";
|
||||
import { getDefaultConfig } from "./config.js";
|
||||
import { writeConfig } from "./utils.js";
|
||||
import { promptForRestart } from "./client.js";
|
||||
import ora from "ora";
|
||||
import chalk from "chalk";
|
||||
import { exec } from "node:child_process";
|
||||
import { promisify } from "node:util";
|
||||
import inquirer from "inquirer";
|
||||
|
||||
const execAsync = promisify(exec);
|
||||
|
||||
const getTransportName = (transport: TransportType): string => {
|
||||
switch (transport) {
|
||||
case "sse":
|
||||
return "SSE";
|
||||
case "streamable-http":
|
||||
return "HTTP streamable";
|
||||
case "stdio":
|
||||
default:
|
||||
return "stdio (pipx)";
|
||||
}
|
||||
};
|
||||
|
||||
const determineTransport = (client: ValidClient, transportOverride?: TransportType): TransportType => {
|
||||
if (transportOverride) {
|
||||
return transportOverride;
|
||||
}
|
||||
|
||||
if (HTTP_SUPPORTED_CLIENTS.includes(client)) {
|
||||
return "streamable-http";
|
||||
}
|
||||
|
||||
if (SSE_SUPPORTED_CLIENTS.includes(client)) {
|
||||
return "sse";
|
||||
}
|
||||
|
||||
return "stdio";
|
||||
};
|
||||
|
||||
async function isPipxInstalled(): Promise<boolean> {
|
||||
try {
|
||||
await execAsync("pipx --version");
|
||||
return true;
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Install pipx via the user's `python3` interpreter.
 *
 * Steps: verify python3 exists, `pip install --user pipx`, then on Unix run
 * `pipx ensurepath` so the pipx bin dir lands on PATH.
 *
 * @returns true when pipx was installed (even if PATH setup still needs a
 *          terminal restart), false when python3 is missing or the install
 *          command failed.
 */
async function installPipx(): Promise<boolean> {
  const spinner = ora("Installing pipx...").start();

  try {
    // Check Python availability — pipx is installed through python3/pip.
    try {
      await execAsync("python3 --version");
    } catch {
      spinner.fail("Python 3 is not installed. Please install Python 3 first.");
      return false;
    }

    // Install pipx into the user site-packages (no sudo required).
    await execAsync("python3 -m pip install --user pipx");

    // Ensure PATH on Unix systems; Windows handles PATH differently.
    if (process.platform !== "win32") {
      try {
        await execAsync("python3 -m pipx ensurepath");
        spinner.succeed("Pipx installed successfully");
        console.log(chalk.yellow("\n⚠️ Restart your terminal or run: source ~/.bashrc"));
      } catch {
        // PATH setup failed but pipx is installed — still a success overall.
        spinner.succeed("Pipx installed (manual PATH setup may be needed)");
      }
    } else {
      spinner.succeed("Pipx installed successfully");
    }

    return true;
  } catch (error) {
    // pip install itself failed; point the user at the manual route.
    spinner.fail("Failed to install pipx");
    console.log(chalk.yellow("\nInstall manually: pip install pipx"));
    return false;
  }
}
|
||||
|
||||
/**
 * Make sure pipx is available for the stdio transport, prompting the user
 * to install it interactively when it is missing.
 *
 * @returns true when pipx is (or becomes) available, false when the user
 *          declines or the install fails.
 */
async function ensurePipxInstalled(): Promise<boolean> {
  if (await isPipxInstalled()) {
    return true;
  }

  console.log(chalk.yellow("\n⚠️ Pipx is required for stdio transport"));

  // Ask before touching the user's Python environment.
  const { shouldInstall } = await inquirer.prompt<{ shouldInstall: boolean }>([
    {
      type: "confirm",
      name: "shouldInstall",
      message: "Install pipx now?",
      default: true,
    },
  ]);

  if (shouldInstall) {
    return await installPipx();
  }

  // User declined — leave instructions and report unavailable.
  console.log(chalk.yellow("Install manually: pip install pipx"));
  return false;
}
|
||||
|
||||
/**
 * Install the Omnara MCP configuration for a single client.
 *
 * Determines the transport, writes/merges the client's config file via
 * writeConfig, prints transport-specific guidance (and for stdio, makes
 * sure pipx is available), then offers to restart the client app.
 *
 * @param client  Target AI client identifier.
 * @param options Optional API key, transport override, and endpoint.
 * @throws Re-throws any configuration write failure after logging it.
 */
export async function install(
  client: ValidClient,
  options?: InstallOptions
): Promise<void> {
  // Display name: first letter upper-cased (e.g. "cursor" -> "Cursor").
  const capitalizedClient = client.charAt(0).toUpperCase() + client.slice(1);
  const transport = determineTransport(client, options?.transport);
  const transportName = getTransportName(transport);

  const spinner = ora(
    `Installing ${transportName} configuration for ${capitalizedClient}...`
  ).start();

  try {
    // getDefaultConfig re-derives the same transport from options?.transport.
    const config = getDefaultConfig(client, options?.apiKey, options?.endpoint, options?.transport);

    writeConfig(client, config);
    spinner.succeed(
      `Successfully installed ${transportName} configuration for ${capitalizedClient}`
    );

    if (!options?.apiKey) {
      console.log(
        chalk.yellow(
          "No API key provided. Using default 'YOUR_API_KEY' placeholder."
        )
      );
    }

    // Provide specific guidance based on transport type
    switch (transport) {
      case "sse":
        console.log(
          chalk.blue(`${capitalizedClient} will connect to Omnara via SSE endpoint`)
        );
        break;
      case "streamable-http":
        console.log(
          chalk.blue(`${capitalizedClient} will connect to Omnara via HTTP streamable endpoint`)
        );
        break;
      case "stdio":
        console.log(
          chalk.blue(`${capitalizedClient} will run Omnara locally using pipx`)
        );

        // stdio launches omnara through pipx, so verify/offer to install it.
        const pipxInstalled = await ensurePipxInstalled();
        if (pipxInstalled) {
          console.log(chalk.green("✓ Pipx ready"));
        } else {
          console.log(chalk.red("⚠️ Stdio transport requires pipx"));
        }
        break;
    }

    console.log(
      chalk.green(`${capitalizedClient} configuration updated successfully`)
    );
    console.log(
      chalk.yellow(
        `You may need to restart ${capitalizedClient} to see the Omnara MCP server.`
      )
    );
    await promptForRestart(client);
  } catch (error) {
    spinner.fail(`Failed to install configuration for ${capitalizedClient}`);
    console.error(
      chalk.red(
        `Error: ${error instanceof Error ? error.message : String(error)}`
      )
    );
    // Propagate so installAll() can record the failure per client.
    throw error;
  }
}
|
||||
|
||||
/**
 * Run install() for every supported client and print a success/failure
 * summary. Failures are collected per client rather than aborting the loop.
 *
 * NOTE(review): install() prompts for a restart per client, so this loop is
 * interactive for each configured client — confirm that is intended.
 */
export async function installAll(options?: InstallOptions): Promise<void> {
  console.log(chalk.blue("🚀 Setting up Omnara MCP for all supported AI clients..."));
  console.log(chalk.gray("ℹ️ This safely adds configuration files - only clients you actually use will be affected."));
  console.log(chalk.gray("ℹ️ No AI clients will be installed or modified, just their MCP settings."));

  // Per-client outcome, used for the summary below.
  const results: { client: ValidClient; success: boolean; error?: string }[] = [];

  for (const client of VALID_CLIENTS) {
    try {
      console.log(chalk.blue(`\n📝 Configuring ${client}...`));
      await install(client, options);
      results.push({ client, success: true });
    } catch (error) {
      const errorMessage = error instanceof Error ? error.message : String(error);
      console.error(chalk.red(`❌ Failed to configure ${client}: ${errorMessage}`));
      results.push({ client, success: false, error: errorMessage });
    }
  }

  // Summary
  console.log(chalk.blue("\n" + "=".repeat(50)));
  console.log(chalk.blue("📊 Configuration Summary"));
  console.log(chalk.blue("=".repeat(50)));

  const successful = results.filter(r => r.success);
  const failed = results.filter(r => !r.success);

  if (successful.length > 0) {
    console.log(chalk.green(`✅ Successfully configured ${successful.length} clients:`));
    successful.forEach(({ client }) => {
      console.log(chalk.green(` • ${client}`));
    });
  }

  if (failed.length > 0) {
    console.log(chalk.red(`\n❌ Failed to configure ${failed.length} clients:`));
    failed.forEach(({ client, error }) => {
      console.log(chalk.red(` • ${client}: ${error}`));
    });
  }

  console.log(chalk.blue(`\n🎉 Setup complete! ${successful.length}/${VALID_CLIENTS.length} clients configured.`));
  console.log(chalk.gray("💡 Only AI clients you actually have installed will use these configurations."));
}
|
||||
78
cli/src/types.ts
Normal file
78
cli/src/types.ts
Normal file
@@ -0,0 +1,78 @@
|
||||
/** Identifiers of every AI client whose MCP configuration this CLI manages. */
export type ValidClient =
  | "claude"
  | "claude-code"
  | "cline"
  | "roo-cline"
  | "windsurf"
  | "witsy"
  | "enconvo"
  | "cursor"
  | "github-copilot";

/** Runtime list mirroring the ValidClient union, used for iteration/validation. */
export const VALID_CLIENTS: ValidClient[] = [
  "claude",
  "claude-code",
  "cline",
  "roo-cline",
  "windsurf",
  "witsy",
  "enconvo",
  "cursor",
  "github-copilot",
];

/** Supported MCP transports: local subprocess, SSE, or streamable HTTP. */
export type TransportType = "stdio" | "sse" | "streamable-http";

// Clients that support SSE instead of stdio
export const SSE_SUPPORTED_CLIENTS: ValidClient[] = [
  "claude-code",
  "cline",
  "roo-cline",
  "windsurf",
  "enconvo",
];

// Clients that support HTTP streamable
export const HTTP_SUPPORTED_CLIENTS: ValidClient[] = [
  "claude-code",
  "cursor",
  "witsy",
  "github-copilot",
];

/** stdio transport entry: a command plus args (and optional env vars). */
export interface ServerConfig {
  command: string;
  args: string[];
  env?: Record<string, string>;
}

/** SSE transport entry: endpoint URL plus auth header and client options. */
export interface SSEServerConfig {
  url: string;
  headers: {
    Authorization: string;
    [key: string]: string;
  };
  alwaysAllow?: string[];
  disabled?: boolean;
  timeout?: number;
  retry?: boolean;
}

/** Streamable-HTTP transport entry; `type` discriminates it from SSE. */
export interface HTTPServerConfig {
  type: "http";
  url: string;
  headers: {
    Authorization: string;
    [key: string]: string;
  };
}

/** Common client config shape: server entries keyed by server name. */
export interface ClientConfig {
  mcpServers: Record<string, ServerConfig | SSEServerConfig | HTTPServerConfig>;
}

/** Options accepted by install()/installAll(). */
export interface InstallOptions {
  apiKey?: string;
  transport?: TransportType;
  endpoint?: string;
}
|
||||
123
cli/src/utils.ts
Normal file
123
cli/src/utils.ts
Normal file
@@ -0,0 +1,123 @@
|
||||
import fs from "node:fs";
|
||||
import path from "node:path";
|
||||
import * as jsonc from "jsonc-parser";
|
||||
import type { ValidClient, ClientConfig } from "./types.js";
|
||||
import { clientPaths, usesVSCodeStyle } from "./config.js";
|
||||
|
||||
export function getConfigPath(client: ValidClient): string {
|
||||
const configPath = clientPaths[client];
|
||||
if (!configPath) {
|
||||
throw new Error(`Invalid client: ${client}`);
|
||||
}
|
||||
return configPath;
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse JSON with support for comments (like VS Code settings.json)
|
||||
*/
|
||||
function parseJsonWithComments(jsonString: string): any {
|
||||
try {
|
||||
// First try standard JSON parsing
|
||||
return JSON.parse(jsonString);
|
||||
} catch {
|
||||
// Use proper JSONC parser for VS Code style files
|
||||
const errors: jsonc.ParseError[] = [];
|
||||
const result = jsonc.parse(jsonString, errors);
|
||||
|
||||
if (errors.length > 0) {
|
||||
console.warn(`JSONC parsing warnings:`, errors);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Safely read and parse a config file.
 *
 * Returns an empty object when the file is missing or unreadable/unparsable,
 * so callers can always merge into the result without null checks.
 */
function readConfigFile(configPath: string): any {
  if (!fs.existsSync(configPath)) {
    return {};
  }

  try {
    const fileContent = fs.readFileSync(configPath, "utf8");
    // Tolerates JSONC (comments) as used by VS Code-style settings files.
    return parseJsonWithComments(fileContent);
  } catch (error) {
    // Best-effort: warn and fall back to an empty config rather than failing.
    console.warn(`Warning: Could not read config file ${configPath}:`, error);
    return {};
  }
}
|
||||
|
||||
|
||||
/**
|
||||
* Merge configurations based on client type
|
||||
*/
|
||||
function mergeConfigs(existingConfig: any, newConfig: any, client: ValidClient): any {
|
||||
if (usesVSCodeStyle(client)) {
|
||||
// VS Code style: mcp.servers
|
||||
const merged = { ...existingConfig };
|
||||
|
||||
if (!merged.mcp) merged.mcp = {};
|
||||
if (!merged.mcp.servers) merged.mcp.servers = {};
|
||||
|
||||
if (newConfig.mcp?.servers) {
|
||||
merged.mcp.servers = {
|
||||
...merged.mcp.servers,
|
||||
...newConfig.mcp.servers,
|
||||
};
|
||||
}
|
||||
|
||||
return merged;
|
||||
} else {
|
||||
// Standard style: mcpServers
|
||||
return {
|
||||
...existingConfig,
|
||||
mcpServers: {
|
||||
...existingConfig.mcpServers,
|
||||
...newConfig.mcpServers,
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
export function readConfig(client: ValidClient): ClientConfig {
|
||||
const configPath = getConfigPath(client);
|
||||
const rawConfig = readConfigFile(configPath);
|
||||
|
||||
if (usesVSCodeStyle(client)) {
|
||||
// Convert VS Code style to standard for consistency
|
||||
return {
|
||||
mcpServers: rawConfig.mcp?.servers || {},
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
...rawConfig,
|
||||
mcpServers: rawConfig.mcpServers || {},
|
||||
};
|
||||
}
|
||||
|
||||
/**
 * Write a client's config, merging with whatever already exists on disk.
 *
 * Creates the config directory if needed, validates the incoming structure,
 * merges server entries (new entries win), and writes pretty-printed JSON.
 * NOTE: existing files are re-serialized as plain JSON, so any comments in a
 * JSONC file are not preserved.
 */
export function writeConfig(client: ValidClient, config: ClientConfig): void {
  const configPath = getConfigPath(client);
  const configDir = path.dirname(configPath);

  if (!fs.existsSync(configDir)) {
    fs.mkdirSync(configDir, { recursive: true });
  }

  // Validate config structure — must carry servers in one of the two layouts.
  if (!config.mcpServers && !(config as any).mcp?.servers) {
    throw new Error("Invalid config structure");
  }

  // Read existing config
  const existingConfig = readConfigFile(configPath);

  // Merge configurations
  const mergedConfig = mergeConfigs(existingConfig, config, client);

  // Write back with appropriate formatting (VS Code convention is 4 spaces).
  const indent = usesVSCodeStyle(client) ? 4 : 2;
  fs.writeFileSync(configPath, JSON.stringify(mergedConfig, null, indent));
}
|
||||
17
cli/tsconfig.json
Normal file
17
cli/tsconfig.json
Normal file
@@ -0,0 +1,17 @@
|
||||
{
|
||||
"compilerOptions": {
|
||||
"target": "ESNext",
|
||||
"module": "NodeNext",
|
||||
"moduleResolution": "NodeNext",
|
||||
"outDir": "./dist",
|
||||
"rootDir": "./src",
|
||||
"strict": true,
|
||||
"esModuleInterop": true,
|
||||
"skipLibCheck": true,
|
||||
"forceConsistentCasingInFileNames": true,
|
||||
"resolveJsonModule": true,
|
||||
"types": ["node"]
|
||||
},
|
||||
"include": ["src/**/*"],
|
||||
"exclude": ["node_modules", "dist"]
|
||||
}
|
||||
140
dev-start.sh
Executable file
140
dev-start.sh
Executable file
@@ -0,0 +1,140 @@
|
||||
#!/bin/bash

# Development startup script for Agent Dashboard
# Runs PostgreSQL in Docker, everything else locally for live development
#
# Usage: ./dev-start.sh [--reset-db]
#   --reset-db  wipe the local postgres-data directory before starting

set -e

# Parse command line arguments
RESET_DB=false
while [[ $# -gt 0 ]]; do
    case $1 in
        --reset-db)
            RESET_DB=true
            shift
            ;;
        *)
            echo "Unknown option: $1"
            echo "Usage: $0 [--reset-db]"
            exit 1
            ;;
    esac
done

echo "🚀 Starting Agent Dashboard in development mode..."

# Colors for output
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
NC='\033[0m' # No Color

# Function to cleanup on exit.
# NOTE(review): only the local server processes are stopped here — the
# PostgreSQL container keeps running; presumably dev-stop.sh is the
# intended way to stop it. Confirm this is deliberate.
cleanup() {
    echo -e "\n${YELLOW}Cleaning up...${NC}"

    # Kill background processes
    if [[ -n $APP_PID ]]; then
        echo "Stopping unified server..."
        kill $APP_PID 2>/dev/null || true
    fi

    if [[ -n $BACKEND_PID ]]; then
        echo "Stopping backend..."
        kill $BACKEND_PID 2>/dev/null || true
    fi


    echo -e "${GREEN}Cleanup complete${NC}"
    exit 0
}

# Set up signal handlers
trap cleanup SIGINT SIGTERM

# Check if Docker is running
if ! sudo docker info > /dev/null 2>&1; then
    echo -e "${RED}Error: Docker is not running. Please start Docker first.${NC}"
    exit 1
fi

# Set up data directory for PostgreSQL persistence
DATA_DIR="$(pwd)/postgres-data"
if [ "$RESET_DB" = true ]; then
    echo -e "${YELLOW}Resetting database...${NC}"
    if [ -d "$DATA_DIR" ]; then
        echo -e "${YELLOW}Removing existing PostgreSQL data directory...${NC}"
        # sudo because the directory is written by the container's postgres user
        sudo rm -rf "$DATA_DIR"
    fi
fi

# Create data directory if it doesn't exist
if [ ! -d "$DATA_DIR" ]; then
    echo -e "${BLUE}Creating PostgreSQL data directory...${NC}"
    mkdir -p "$DATA_DIR"
fi

# Start PostgreSQL in Docker (--rm: container is removed once stopped;
# data persists via the bind mount)
echo -e "${BLUE}Starting PostgreSQL in Docker...${NC}"
sudo docker run -d \
    --name agent-dashboard-db-dev \
    --rm \
    -p 5432:5432 \
    -e POSTGRES_USER=user \
    -e POSTGRES_PASSWORD=password \
    -e POSTGRES_DB=agent_dashboard \
    -v "$DATA_DIR:/var/lib/postgresql/data" \
    postgres:16-alpine > /dev/null

# Wait for PostgreSQL to be ready (up to 30 seconds)
echo -e "${YELLOW}Waiting for PostgreSQL to be ready...${NC}"
for i in {1..30}; do
    if sudo docker exec agent-dashboard-db-dev pg_isready -U user -d agent_dashboard > /dev/null 2>&1; then
        echo -e "${GREEN}PostgreSQL is ready!${NC}"
        break
    fi
    sleep 1
    if [ $i -eq 30 ]; then
        echo -e "${RED}Error: PostgreSQL failed to start${NC}"
        sudo docker stop agent-dashboard-db-dev > /dev/null 2>&1 || true
        exit 1
    fi
done

# Initialize database
echo -e "${BLUE}Initializing database...${NC}"
export ENVIRONMENT="development"
export DEVELOPMENT_DB_URL="postgresql://user:password@localhost:5432/agent_dashboard"
./scripts/init-db.sh

# Start unified server (MCP + FastAPI) in the background
echo -e "${BLUE}Starting unified server (MCP + FastAPI)...${NC}"
export PYTHONPATH="$(pwd)"
export API_PORT=8080
export MCP_SERVER_PORT=8080
python -m servers.app &
APP_PID=$!

# Wait a moment for unified server to start
sleep 2

# Start Backend API in the background (note: API_PORT is re-exported to 8000
# for this process)
echo -e "${BLUE}Starting Backend API...${NC}"
export PYTHONPATH="$(pwd)"
export API_PORT=8000
uvicorn backend.main:app --host 0.0.0.0 --port 8000 &
BACKEND_PID=$!

# Wait a moment for backend to start
sleep 2

echo -e "${GREEN}🎉 All services started successfully!${NC}"
echo -e "${BLUE}Services:${NC}"
echo -e " 🔧 Backend API: http://localhost:8000"
echo -e " 🤖 Unified Server: http://localhost:8080 (MCP + FastAPI)"
echo -e " 🗄️ PostgreSQL: localhost:5432"
echo -e "\n${YELLOW}Press Ctrl+C to stop all services${NC}"

# Wait for all background processes
wait
|
||||
25
dev-stop.sh
Executable file
25
dev-stop.sh
Executable file
@@ -0,0 +1,25 @@
|
||||
#!/bin/bash

# Stop development services for Agent Dashboard
# Counterpart to dev-start.sh: stops the Postgres container and force-kills
# anything still bound to the dev ports (8000 backend, 8080 unified server).

echo "🛑 Stopping development services..."

# Stop Docker container (only if it is actually running)
if sudo docker ps -q -f name=agent-dashboard-db-dev | grep -q .; then
    echo "Stopping PostgreSQL container..."
    sudo docker stop agent-dashboard-db-dev > /dev/null 2>&1
    echo "PostgreSQL stopped"
else
    echo "PostgreSQL container not running"
fi

# Kill any remaining processes on the ports
echo "Cleaning up any remaining processes..."

# Kill processes on port 8000 (backend)
lsof -ti:8000 | xargs kill -9 2>/dev/null || true

# Kill processes on port 8080 (unified server - MCP + FastAPI)
lsof -ti:8080 | xargs kill -9 2>/dev/null || true

echo "✅ All services stopped"
|
||||
80
docker-compose.yml
Normal file
80
docker-compose.yml
Normal file
@@ -0,0 +1,80 @@
|
||||
# Docker Compose stack for the Agent Dashboard:
# postgres -> db-init (one-shot schema init) -> mcp-server + backend.
# NOTE(review): the top-level `version` key is obsolete in Compose V2 and is
# ignored there — confirm whether it should be kept for older tooling.
version: '3.8'

services:
  postgres:
    image: postgres:16-alpine
    container_name: agent-dashboard-db
    environment:
      POSTGRES_USER: agent_user
      POSTGRES_PASSWORD: agent_password
      POSTGRES_DB: agent_dashboard
    ports:
      - "5432:5432"
    volumes:
      - postgres_data:/var/lib/postgresql/data
    healthcheck:
      test: ["CMD-SHELL", "pg_isready -U agent_user -d agent_dashboard"]
      interval: 5s
      timeout: 5s
      retries: 5

  # One-shot job: runs the schema initializer, then exits; downstream
  # services wait for it to complete successfully.
  db-init:
    build:
      context: .
      dockerfile: ./docker/db-init.Dockerfile
    depends_on:
      postgres:
        condition: service_healthy
    environment:
      ENVIRONMENT: development
      DEVELOPMENT_DB_URL: postgresql://agent_user:agent_password@postgres:5432/agent_dashboard
      PRODUCTION_DB_URL: ${PRODUCTION_DB_URL:-}
    command: python -m shared.database.init_db

  mcp-server:
    build:
      context: .
      dockerfile: ./docker/servers.Dockerfile
    container_name: agent-dashboard-mcp
    ports:
      - "8080:8080"
    depends_on:
      db-init:
        condition: service_completed_successfully
    environment:
      ENVIRONMENT: development
      DEVELOPMENT_DB_URL: postgresql://agent_user:agent_password@postgres:5432/agent_dashboard
      PRODUCTION_DB_URL: ${PRODUCTION_DB_URL:-}
      MCP_SERVER_PORT: 8080
      # JWT public key for API authentication (required)
      JWT_PUBLIC_KEY: ${JWT_PUBLIC_KEY:-}
    restart: unless-stopped

  backend:
    build:
      context: .
      dockerfile: ./docker/backend.Dockerfile
    container_name: agent-dashboard-backend
    ports:
      - "8000:8000"
    depends_on:
      db-init:
        condition: service_completed_successfully
    environment:
      ENVIRONMENT: development
      DEVELOPMENT_DB_URL: postgresql://agent_user:agent_password@postgres:5432/agent_dashboard
      PRODUCTION_DB_URL: ${PRODUCTION_DB_URL:-}
      API_PORT: 8000
      FRONTEND_URLS: '["http://localhost:3000"]'
      # Automatically loads from .env file
      # Supabase credentials (required for auth)
      SUPABASE_URL: ${SUPABASE_URL:-}
      SUPABASE_ANON_KEY: ${SUPABASE_ANON_KEY:-}
      SUPABASE_SERVICE_ROLE_KEY: ${SUPABASE_SERVICE_ROLE_KEY:-}
      # JWT keys for API authentication (required)
      JWT_PRIVATE_KEY: ${JWT_PRIVATE_KEY:-}
      JWT_PUBLIC_KEY: ${JWT_PUBLIC_KEY:-}
    restart: unless-stopped

volumes:
  postgres_data:
|
||||
20
docker/backend.Dockerfile
Normal file
20
docker/backend.Dockerfile
Normal file
@@ -0,0 +1,20 @@
|
||||
# Backend API image: installs backend deps, ships shared/ + backend/ code.
FROM python:3.12-slim

WORKDIR /app

# Copy requirements
COPY backend/requirements.txt /app/backend/requirements.txt
COPY shared/requirements.txt /app/shared/requirements.txt

# Install dependencies
# NOTE(review): shared/requirements.txt is copied but never installed
# explicitly — presumably backend/requirements.txt includes it via
# "-r ../shared/requirements.txt"; verify, otherwise shared deps are missing.
RUN pip install --no-cache-dir -r backend/requirements.txt

# Copy application code
COPY shared /app/shared
COPY backend /app/backend

# Set Python path so "shared" and "backend" import as top-level packages
ENV PYTHONPATH=/app

# Run the backend from root directory to access shared
CMD ["python", "-m", "backend.main"]
|
||||
13
docker/db-init.Dockerfile
Normal file
13
docker/db-init.Dockerfile
Normal file
@@ -0,0 +1,13 @@
|
||||
# Minimal image for the one-shot db-init job; the actual command
# (python -m shared.database.init_db) is supplied by docker-compose.
FROM python:3.12-slim

WORKDIR /app

# Copy shared requirements
COPY shared/requirements.txt /app/shared/requirements.txt
RUN pip install --no-cache-dir -r shared/requirements.txt

# Copy shared code
COPY shared /app/shared

# Set Python path so "shared" imports as a top-level package
ENV PYTHONPATH=/app
|
||||
20
docker/servers.Dockerfile
Normal file
20
docker/servers.Dockerfile
Normal file
@@ -0,0 +1,20 @@
|
||||
# MCP server image: installs server deps, ships shared/ + servers/ code.
FROM python:3.12-slim

WORKDIR /app

# Copy requirements
COPY servers/requirements.txt /app/servers/requirements.txt
COPY shared/requirements.txt /app/shared/requirements.txt

# Install dependencies
# NOTE(review): shared/requirements.txt is copied but never installed
# explicitly — presumably servers/requirements.txt includes it via
# "-r ../shared/requirements.txt"; verify, otherwise shared deps are missing.
RUN pip install --no-cache-dir -r servers/requirements.txt

# Copy application code
COPY shared /app/shared
COPY servers /app/servers

# Set Python path so "shared" and "servers" import as top-level packages
ENV PYTHONPATH=/app

# Run the MCP server from root directory to access shared
CMD ["python", "-m", "servers.mcp_server.server"]
|
||||
26
omnara/__init__.py
Normal file
26
omnara/__init__.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""Omnara - Agent Dashboard and Python SDK
|
||||
|
||||
This package provides:
|
||||
1. MCP Server for agent communication (omnara CLI command)
|
||||
2. Python SDK for interacting with the Omnara API
|
||||
"""
|
||||
|
||||
# Import SDK components for easy access
|
||||
from .sdk.client import OmnaraClient
|
||||
from .sdk.async_client import AsyncOmnaraClient
|
||||
from .sdk.exceptions import (
|
||||
OmnaraError,
|
||||
AuthenticationError,
|
||||
TimeoutError,
|
||||
APIError,
|
||||
)
|
||||
|
||||
__version__ = "1.1.0"
|
||||
__all__ = [
|
||||
"OmnaraClient",
|
||||
"AsyncOmnaraClient",
|
||||
"OmnaraError",
|
||||
"AuthenticationError",
|
||||
"TimeoutError",
|
||||
"APIError",
|
||||
]
|
||||
15
omnara/sdk/__init__.py
Normal file
15
omnara/sdk/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""Omnara Python SDK for interacting with the Agent Dashboard API."""
|
||||
|
||||
from .client import OmnaraClient
|
||||
from .async_client import AsyncOmnaraClient
|
||||
from .exceptions import OmnaraError, AuthenticationError, TimeoutError, APIError
|
||||
|
||||
__version__ = "0.1.0"
|
||||
__all__ = [
|
||||
"OmnaraClient",
|
||||
"AsyncOmnaraClient",
|
||||
"OmnaraError",
|
||||
"AuthenticationError",
|
||||
"TimeoutError",
|
||||
"APIError",
|
||||
]
|
||||
234
omnara/sdk/async_client.py
Normal file
234
omnara/sdk/async_client.py
Normal file
@@ -0,0 +1,234 @@
|
||||
"""Async client for interacting with the Omnara Agent Dashboard API."""
|
||||
|
||||
import asyncio
|
||||
from typing import Optional, Dict, Any
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import aiohttp
|
||||
from aiohttp import ClientTimeout
|
||||
|
||||
from .exceptions import AuthenticationError, TimeoutError, APIError
|
||||
from .models import (
|
||||
LogStepResponse,
|
||||
QuestionResponse,
|
||||
QuestionStatus,
|
||||
EndSessionResponse,
|
||||
)
|
||||
|
||||
|
||||
class AsyncOmnaraClient:
|
||||
"""Async client for interacting with the Omnara Agent Dashboard API.
|
||||
|
||||
Args:
|
||||
api_key: JWT API key for authentication
|
||||
base_url: Base URL of the API server (default: https://agent-dashboard-mcp.onrender.com)
|
||||
timeout: Default timeout for requests in seconds (default: 30)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
base_url: str = "https://agent-dashboard-mcp.onrender.com",
|
||||
timeout: int = 30,
|
||||
):
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.timeout = ClientTimeout(total=timeout)
|
||||
self.session: Optional[aiohttp.ClientSession] = None
|
||||
|
||||
# Default headers
|
||||
self.headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
async def __aenter__(self):
|
||||
"""Async context manager entry."""
|
||||
await self._ensure_session()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Async context manager exit."""
|
||||
await self.close()
|
||||
|
||||
async def _ensure_session(self):
|
||||
"""Ensure aiohttp session exists."""
|
||||
if self.session is None or self.session.closed:
|
||||
self.session = aiohttp.ClientSession(
|
||||
headers=self.headers, timeout=self.timeout
|
||||
)
|
||||
|
||||
async def close(self):
|
||||
"""Close the aiohttp session."""
|
||||
if self.session and not self.session.closed:
|
||||
await self.session.close()
|
||||
|
||||
async def _make_request(
|
||||
self,
|
||||
method: str,
|
||||
endpoint: str,
|
||||
json: Optional[Dict[str, Any]] = None,
|
||||
timeout: Optional[int] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Make an async HTTP request to the API.
|
||||
|
||||
Args:
|
||||
method: HTTP method (GET, POST, etc.)
|
||||
endpoint: API endpoint path
|
||||
json: JSON body for the request
|
||||
timeout: Request timeout in seconds
|
||||
|
||||
Returns:
|
||||
Response JSON data
|
||||
|
||||
Raises:
|
||||
AuthenticationError: If authentication fails
|
||||
APIError: If the API returns an error
|
||||
TimeoutError: If the request times out
|
||||
"""
|
||||
await self._ensure_session()
|
||||
assert self.session is not None
|
||||
|
||||
url = urljoin(self.base_url, endpoint)
|
||||
|
||||
# Override timeout if specified
|
||||
request_timeout = ClientTimeout(total=timeout) if timeout else self.timeout
|
||||
|
||||
try:
|
||||
async with self.session.request(
|
||||
method=method, url=url, json=json, timeout=request_timeout
|
||||
) as response:
|
||||
if response.status == 401:
|
||||
raise AuthenticationError(
|
||||
"Invalid API key or authentication failed"
|
||||
)
|
||||
|
||||
if not response.ok:
|
||||
try:
|
||||
error_data = await response.json()
|
||||
error_detail = error_data.get("detail", await response.text())
|
||||
except Exception:
|
||||
error_detail = await response.text()
|
||||
raise APIError(response.status, error_detail)
|
||||
|
||||
return await response.json()
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
raise TimeoutError(
|
||||
f"Request timed out after {timeout or self.timeout.total} seconds"
|
||||
)
|
||||
except aiohttp.ClientError as e:
|
||||
raise APIError(0, f"Request failed: {str(e)}")
|
||||
|
||||
async def log_step(
|
||||
self,
|
||||
agent_type: str,
|
||||
step_description: str,
|
||||
agent_instance_id: Optional[str] = None,
|
||||
) -> LogStepResponse:
|
||||
"""Log a high-level step the agent is performing.
|
||||
|
||||
Args:
|
||||
agent_type: Type of agent (e.g., 'Claude Code', 'Cursor')
|
||||
step_description: Clear description of what the agent is doing
|
||||
agent_instance_id: Existing agent instance ID (optional)
|
||||
|
||||
Returns:
|
||||
LogStepResponse with success status, instance ID, and user feedback
|
||||
"""
|
||||
data = {"agent_type": agent_type, "step_description": step_description}
|
||||
if agent_instance_id:
|
||||
data["agent_instance_id"] = agent_instance_id
|
||||
|
||||
response = await self._make_request("POST", "/api/v1/steps", json=data)
|
||||
|
||||
return LogStepResponse(
|
||||
success=response["success"],
|
||||
agent_instance_id=response["agent_instance_id"],
|
||||
step_number=response["step_number"],
|
||||
user_feedback=response.get("user_feedback", []),
|
||||
)
|
||||
|
||||
async def ask_question(
|
||||
self,
|
||||
agent_instance_id: str,
|
||||
question_text: str,
|
||||
timeout_minutes: int = 1440,
|
||||
poll_interval: float = 1.0,
|
||||
) -> QuestionResponse:
|
||||
"""Ask the user a question and wait for their response.
|
||||
|
||||
This method submits the question and then polls for the answer.
|
||||
|
||||
Args:
|
||||
agent_instance_id: Agent instance ID
|
||||
question_text: Question to ask the user
|
||||
timeout_minutes: Maximum time to wait for answer in minutes (default: 1440 = 24 hours)
|
||||
poll_interval: Time between polls in seconds (default: 1.0)
|
||||
|
||||
Returns:
|
||||
QuestionResponse with the user's answer
|
||||
|
||||
Raises:
|
||||
TimeoutError: If no answer is received within timeout
|
||||
"""
|
||||
# Submit the question
|
||||
data = {"agent_instance_id": agent_instance_id, "question_text": question_text}
|
||||
|
||||
# First, try the non-blocking endpoint to create the question
|
||||
response = await self._make_request(
|
||||
"POST", "/api/v1/questions", json=data, timeout=5
|
||||
)
|
||||
question_id = response["question_id"]
|
||||
|
||||
# Convert timeout from minutes to seconds
|
||||
timeout_seconds = timeout_minutes * 60
|
||||
|
||||
# Poll for the answer
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
while asyncio.get_event_loop().time() - start_time < timeout_seconds:
|
||||
status = await self.get_question_status(question_id)
|
||||
|
||||
if status.status == "answered" and status.answer:
|
||||
return QuestionResponse(answer=status.answer, question_id=question_id)
|
||||
|
||||
await asyncio.sleep(poll_interval)
|
||||
|
||||
raise TimeoutError(f"Question timed out after {timeout_minutes} minutes")
|
||||
|
||||
async def get_question_status(self, question_id: str) -> QuestionStatus:
|
||||
"""Get the current status of a question.
|
||||
|
||||
Args:
|
||||
question_id: ID of the question to check
|
||||
|
||||
Returns:
|
||||
QuestionStatus with current status and answer (if available)
|
||||
"""
|
||||
response = await self._make_request("GET", f"/api/v1/questions/{question_id}")
|
||||
|
||||
return QuestionStatus(
|
||||
question_id=response["question_id"],
|
||||
status=response["status"],
|
||||
answer=response.get("answer"),
|
||||
asked_at=response["asked_at"],
|
||||
answered_at=response.get("answered_at"),
|
||||
)
|
||||
|
||||
    async def end_session(self, agent_instance_id: str) -> EndSessionResponse:
        """End an agent session and mark it as completed.

        Args:
            agent_instance_id: Agent instance ID to end

        Returns:
            EndSessionResponse with success status and final details
        """
        data = {"agent_instance_id": agent_instance_id}
        response = await self._make_request("POST", "/api/v1/sessions/end", json=data)

        return EndSessionResponse(
            success=response["success"],
            agent_instance_id=response["agent_instance_id"],
            final_status=response["final_status"],
        )
222
omnara/sdk/client.py
Normal file
222
omnara/sdk/client.py
Normal file
@@ -0,0 +1,222 @@
|
||||
"""Main client for interacting with the Omnara Agent Dashboard API."""
|
||||
|
||||
import time
|
||||
from typing import Optional, Dict, Any
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import requests
|
||||
from requests.adapters import HTTPAdapter
|
||||
from urllib3.util.retry import Retry
|
||||
|
||||
from .exceptions import AuthenticationError, TimeoutError, APIError
|
||||
from .models import (
|
||||
LogStepResponse,
|
||||
QuestionResponse,
|
||||
QuestionStatus,
|
||||
EndSessionResponse,
|
||||
)
|
||||
|
||||
|
||||
class OmnaraClient:
    """Client for interacting with the Omnara Agent Dashboard API.

    Args:
        api_key: JWT API key for authentication
        base_url: Base URL of the API server (default: https://agent-dashboard-mcp.onrender.com)
        timeout: Default timeout for requests in seconds (default: 30)
    """

    def __init__(
        self,
        api_key: str,
        base_url: str = "https://agent-dashboard-mcp.onrender.com",
        timeout: int = 30,
    ):
        self.api_key = api_key
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout

        # Set up session with retries for transient server-side failures.
        # Client errors (4xx) are not retried.
        self.session = requests.Session()
        retry_strategy = Retry(
            total=3, backoff_factor=0.3, status_forcelist=[500, 502, 503, 504]
        )
        adapter = HTTPAdapter(max_retries=retry_strategy)
        self.session.mount("http://", adapter)
        self.session.mount("https://", adapter)

        # Every request carries the bearer token and sends JSON.
        self.session.headers.update(
            {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
        )

    def __enter__(self):
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit: releases the underlying HTTP session."""
        self.close()

    def _make_request(
        self,
        method: str,
        endpoint: str,
        json: Optional[Dict[str, Any]] = None,
        timeout: Optional[int] = None,
    ) -> Dict[str, Any]:
        """Make an HTTP request to the API.

        Args:
            method: HTTP method (GET, POST, etc.)
            endpoint: API endpoint path
            json: JSON body for the request
            timeout: Request timeout in seconds (defaults to the client timeout)

        Returns:
            Response JSON data

        Raises:
            AuthenticationError: If authentication fails
            APIError: If the API returns an error (status_code 0 is used for
                transport-level failures with no HTTP response)
            TimeoutError: If the request times out
        """
        url = urljoin(self.base_url, endpoint)
        timeout = timeout or self.timeout

        try:
            response = self.session.request(
                method=method, url=url, json=json, timeout=timeout
            )

            if response.status_code == 401:
                raise AuthenticationError("Invalid API key or authentication failed")

            if not response.ok:
                # Prefer the API's structured "detail" field; fall back to the
                # raw body when the error response is not JSON.
                try:
                    error_detail = response.json().get("detail", response.text)
                except Exception:
                    error_detail = response.text
                raise APIError(response.status_code, error_detail)

            return response.json()

        except requests.exceptions.Timeout:
            raise TimeoutError(f"Request timed out after {timeout} seconds")
        except requests.exceptions.RequestException as e:
            raise APIError(0, f"Request failed: {str(e)}")

    def log_step(
        self,
        agent_type: str,
        step_description: str,
        agent_instance_id: Optional[str] = None,
    ) -> LogStepResponse:
        """Log a high-level step the agent is performing.

        Args:
            agent_type: Type of agent (e.g., 'Claude Code', 'Cursor')
            step_description: Clear description of what the agent is doing
            agent_instance_id: Existing agent instance ID (optional; when
                omitted the server responds with the instance ID to reuse)

        Returns:
            LogStepResponse with success status, instance ID, and user feedback
        """
        data = {"agent_type": agent_type, "step_description": step_description}
        if agent_instance_id:
            data["agent_instance_id"] = agent_instance_id

        response = self._make_request("POST", "/api/v1/steps", json=data)

        return LogStepResponse(
            success=response["success"],
            agent_instance_id=response["agent_instance_id"],
            step_number=response["step_number"],
            user_feedback=response.get("user_feedback", []),
        )

    def ask_question(
        self,
        agent_instance_id: str,
        question_text: str,
        timeout_minutes: int = 1440,
        poll_interval: float = 1.0,
    ) -> QuestionResponse:
        """Ask the user a question and wait for their response.

        This method submits the question and then polls for the answer.

        Args:
            agent_instance_id: Agent instance ID
            question_text: Question to ask the user
            timeout_minutes: Maximum time to wait for answer in minutes (default: 1440 = 24 hours)
            poll_interval: Time between polls in seconds (default: 1.0)

        Returns:
            QuestionResponse with the user's answer

        Raises:
            TimeoutError: If no answer is received within timeout
        """
        # Submit the question via the non-blocking endpoint; a short timeout is
        # enough because creating the question does not wait for the user.
        data = {"agent_instance_id": agent_instance_id, "question_text": question_text}
        response = self._make_request("POST", "/api/v1/questions", json=data, timeout=5)
        question_id = response["question_id"]

        # Convert timeout from minutes to seconds
        timeout_seconds = timeout_minutes * 60

        # Poll for the answer.
        start_time = time.time()
        while time.time() - start_time < timeout_seconds:
            status = self.get_question_status(question_id)

            # BUG FIX: check `is not None` instead of truthiness — the previous
            # check kept polling until timeout when the user submitted an
            # empty-string answer.
            if status.status == "answered" and status.answer is not None:
                return QuestionResponse(answer=status.answer, question_id=question_id)

            time.sleep(poll_interval)

        raise TimeoutError(f"Question timed out after {timeout_minutes} minutes")

    def get_question_status(self, question_id: str) -> QuestionStatus:
        """Get the current status of a question.

        Args:
            question_id: ID of the question to check

        Returns:
            QuestionStatus with current status and answer (if available)
        """
        response = self._make_request("GET", f"/api/v1/questions/{question_id}")

        return QuestionStatus(
            question_id=response["question_id"],
            status=response["status"],
            answer=response.get("answer"),
            asked_at=response["asked_at"],
            answered_at=response.get("answered_at"),
        )

    def end_session(self, agent_instance_id: str) -> EndSessionResponse:
        """End an agent session and mark it as completed.

        Args:
            agent_instance_id: Agent instance ID to end

        Returns:
            EndSessionResponse with success status and final details
        """
        data = {"agent_instance_id": agent_instance_id}
        response = self._make_request("POST", "/api/v1/sessions/end", json=data)

        return EndSessionResponse(
            success=response["success"],
            agent_instance_id=response["agent_instance_id"],
            final_status=response["final_status"],
        )

    def close(self):
        """Close the session and clean up resources."""
        self.session.close()
28
omnara/sdk/exceptions.py
Normal file
28
omnara/sdk/exceptions.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""Exception classes for the Omnara SDK."""
|
||||
|
||||
|
||||
class OmnaraError(Exception):
    """Common base class for every error raised by the Omnara SDK.

    Catch this to handle any SDK failure uniformly.
    """
|
||||
|
||||
class AuthenticationError(OmnaraError):
    """Raised when the API rejects the provided credentials (HTTP 401)."""
|
||||
|
||||
class TimeoutError(OmnaraError):
    """Raised when a request or a question poll exceeds its time budget.

    Note: intentionally shadows the builtin ``TimeoutError`` within this
    package's namespace.
    """
|
||||
|
||||
class APIError(OmnaraError):
    """Raised when the API returns an error response.

    Attributes:
        status_code: HTTP status code of the failed response (0 indicates a
            transport-level failure with no HTTP response).
        detail: Error detail extracted from the response body.
    """

    def __init__(self, status_code: int, detail: str):
        message = f"API Error {status_code}: {detail}"
        super().__init__(message)
        self.status_code = status_code
        self.detail = detail
42
omnara/sdk/models.py
Normal file
42
omnara/sdk/models.py
Normal file
@@ -0,0 +1,42 @@
|
||||
"""Data models for the Omnara SDK."""
|
||||
|
||||
from typing import List, Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
class LogStepResponse:
    """Response from logging a step."""

    success: bool  # whether the step was recorded
    agent_instance_id: str  # instance the step belongs to
    step_number: int  # sequential number of this step
    user_feedback: List[str]  # pending feedback messages from the user (may be empty)
|
||||
|
||||
@dataclass
class QuestionResponse:
    """Response from asking a question."""

    answer: str  # the user's answer text
    question_id: str  # ID of the question that was answered
|
||||
|
||||
@dataclass
class QuestionStatus:
    """Status of a question."""

    question_id: str
    status: str  # 'pending' or 'answered'
    answer: Optional[str]  # None until the question is answered
    asked_at: str  # timestamp string as returned by the API
    answered_at: Optional[str]  # None until the question is answered
|
||||
|
||||
@dataclass
class EndSessionResponse:
    """Response from ending a session."""

    success: bool  # whether the session was ended
    agent_instance_id: str  # instance that was ended
    final_status: str  # terminal status assigned by the server
52
pyproject.toml
Normal file
52
pyproject.toml
Normal file
@@ -0,0 +1,52 @@
|
||||
[build-system]
|
||||
requires = ["setuptools>=61.0", "wheel"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "omnara"
|
||||
version = "1.1.0"
|
||||
description = "Omnara Agent Dashboard - MCP Server and Python SDK"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.11"
|
||||
license = "MIT"
|
||||
authors = [
|
||||
{name = "Omnara", email = "ishaan@omnara.com"}
|
||||
]
|
||||
keywords = ["mcp", "ai", "agents", "dashboard"]
|
||||
classifiers = [
|
||||
"Development Status :: 4 - Beta",
|
||||
"Intended Audience :: Developers",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
]
|
||||
dependencies = [
|
||||
"fastmcp==2.9.2",
|
||||
"sqlalchemy==2.0.23",
|
||||
"psycopg2-binary==2.9.9",
|
||||
"pydantic>=2.5.2",
|
||||
"pydantic-settings>=2.6.1",
|
||||
"python-dotenv>=1.1.0",
|
||||
# SDK dependencies
|
||||
"requests>=2.25.0",
|
||||
"urllib3>=1.26.0",
|
||||
"aiohttp>=3.8.0",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://github.com/omnara-ai/omnara"
|
||||
Repository = "https://github.com/omnara-ai/omnara"
|
||||
Issues = "https://github.com/omnara-ai/omnara/issues"
|
||||
|
||||
[project.scripts]
|
||||
omnara = "servers.mcp_server.stdio_server:main"
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
include = ["omnara*", "servers*", "shared*", "backend*"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
markers = [
|
||||
"integration: marks tests as integration tests (deselect with '-m \"not integration\"')",
|
||||
]
|
||||
asyncio_mode = "auto"
|
||||
asyncio_default_fixture_loop_scope = "function"
|
||||
26
pyrightconfig.json
Normal file
26
pyrightconfig.json
Normal file
@@ -0,0 +1,26 @@
|
||||
{
|
||||
"include": [
|
||||
"backend",
|
||||
"servers",
|
||||
"shared",
|
||||
"omnara"
|
||||
],
|
||||
"exclude": [
|
||||
"**/__pycache__",
|
||||
"**/node_modules",
|
||||
"**/.venv",
|
||||
"test-venv/**",
|
||||
"**/dist",
|
||||
"**/build",
|
||||
"**/*.egg-info"
|
||||
],
|
||||
"reportMissingImports": true,
|
||||
"reportMissingTypeStubs": false,
|
||||
"pythonVersion": "3.12",
|
||||
"pythonPlatform": "Linux",
|
||||
"venvPath": ".",
|
||||
"venv": ".venv",
|
||||
"typeCheckingMode": "basic",
|
||||
"reportPrivateUsage": false,
|
||||
"reportUnnecessaryTypeIgnoreComment": false
|
||||
}
|
||||
11
requirements-dev.txt
Normal file
11
requirements-dev.txt
Normal file
@@ -0,0 +1,11 @@
|
||||
# Development dependencies
|
||||
ruff==0.11.13
|
||||
pyright==1.1.402
|
||||
pre-commit==4.2.0
|
||||
|
||||
# Testing dependencies
|
||||
pytest==8.3.4
|
||||
pytest-asyncio==0.25.0
|
||||
pytest-mock==3.14.0
|
||||
pytest-cov==6.0.0
|
||||
testcontainers[postgres]==4.8.2
|
||||
91
scripts/check-migration-needed.py
Executable file
91
scripts/check-migration-needed.py
Executable file
@@ -0,0 +1,91 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Pre-commit hook to check if database schema changes require a new migration.
|
||||
|
||||
This script detects changes to SQLAlchemy models and ensures that a corresponding
|
||||
Alembic migration has been created in the same commit.
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def get_staged_files():
    """Return the paths of files staged for the current commit.

    Returns:
        List of staged file paths; an empty list when the git command fails
        (e.g. not inside a repository) or git is not installed.
    """
    try:
        result = subprocess.run(
            ["git", "diff", "--cached", "--name-only"],
            capture_output=True,
            text=True,
            check=True,
        )
    except (subprocess.CalledProcessError, FileNotFoundError):
        # FileNotFoundError covers a missing git binary; previously this
        # escaped the handler and crashed the hook.
        return []
    return result.stdout.strip().split("\n") if result.stdout.strip() else []
|
||||
|
||||
def has_schema_changes(staged_files):
    """Return True when any staged file is a database schema definition."""
    schema_files = {"shared/database/models.py", "shared/database/enums.py"}
    return bool(schema_files.intersection(staged_files))
|
||||
def has_new_migration(staged_files):
    """Return True when the staged files include a new Alembic migration."""
    for path in staged_files:
        if path.startswith("shared/alembic/versions/") and path.endswith(".py"):
            return True
    return False
|
||||
def get_migration_files():
    """Return the names of existing Alembic migration files (empty if none)."""
    migrations_dir = Path("shared/alembic/versions")
    if not migrations_dir.exists():
        return []
    return [entry.name for entry in migrations_dir.glob("*.py") if entry.is_file()]
|
||||
|
||||
def main():
    """Main function to check if migration is needed."""
    staged_files = get_staged_files()

    if not staged_files:
        # No staged files, nothing to check
        sys.exit(0)

    schema_changed = has_schema_changes(staged_files)
    new_migration = has_new_migration(staged_files)

    # Fail the hook when schema files changed but no migration was staged.
    if schema_changed and not new_migration:
        print("❌ ERROR: Database schema changes detected without a migration!")
        print()
        print("You have modified database schema files:")
        for file_path in staged_files:
            if file_path in ["shared/database/models.py", "shared/database/enums.py"]:
                print(f" - {file_path}")
        print()
        print("You must create an Alembic migration before committing:")
        print(" 1. cd shared/")
        print(" 2. alembic revision --autogenerate -m 'Describe your changes'")
        print(" 3. Review the generated migration file")
        print(" 4. git add the new migration file")
        print(" 5. git commit again")
        print()
        print("This ensures database changes are properly versioned and deployable.")
        sys.exit(1)

    if schema_changed and new_migration:
        print("✅ Schema changes detected with corresponding migration - good!")

    # Success case
    sys.exit(0)


if __name__ == "__main__":
    main()
12
scripts/format.sh
Executable file
12
scripts/format.sh
Executable file
@@ -0,0 +1,12 @@
|
||||
#!/bin/bash
# Script to auto-format code
#
# Runs the ruff formatter over the whole repo, then applies auto-fixable
# lint rules. Exits on the first failing command (set -e).

set -e

echo "Running ruff format..."
ruff format .

echo -e "\nRunning ruff check with auto-fix..."
ruff check --fix .

echo -e "\nFormatting complete!"
58
scripts/generate_jwt_keys.py
Normal file
58
scripts/generate_jwt_keys.py
Normal file
@@ -0,0 +1,58 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Generate RSA key pair for JWT signing.
|
||||
Run this once to generate keys, then add them to your .env file.
|
||||
"""
|
||||
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
|
||||
|
||||
def generate_rsa_key_pair() -> tuple[str, str]:
    """Generate RSA key pair for JWT signing.

    Returns:
        Tuple of (private_key_pem, public_key_pem) as UTF-8 strings.
    """

    # Generate private key
    private_key = rsa.generate_private_key(
        public_exponent=65537,
        # Ultra-small keys = ultra-short signatures (~75% reduction from 2048-bit,
        # but insecure if public key is compromised)
        # NOTE(review): 512-bit RSA is breakable with modest resources; this is
        # a deliberate size/security trade-off per the comment above — confirm
        # it is acceptable for this deployment.
        key_size=512,
    )

    # Get public key
    public_key = private_key.public_key()

    # Serialize private key (PKCS8, unencrypted PEM — intended for .env storage)
    private_pem = private_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )

    # Serialize public key (SubjectPublicKeyInfo PEM)
    public_pem = public_key.public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )

    return private_pem.decode("utf-8"), public_pem.decode("utf-8")
|
||||
|
||||
if __name__ == "__main__":
    print("Generating RSA key pair for JWT signing...\n")

    private_key, public_key = generate_rsa_key_pair()

    print("=" * 80)
    print("Add these to your .env file:")
    print("=" * 80)

    # repr() keeps the PEM newlines escaped so each key fits on one .env line.
    print("\n# JWT Signing Keys")
    print("JWT_PRIVATE_KEY=" + repr(private_key))
    print("JWT_PUBLIC_KEY=" + repr(public_key))

    print("\n" + "=" * 80)
    print("IMPORTANT: Keep the public and private keys secure!")
    print("- Add it to .env (not .env.example)")
    print("- Never commit keys to git")
    print("=" * 80)
25
scripts/init-db.sh
Executable file
25
scripts/init-db.sh
Executable file
@@ -0,0 +1,25 @@
|
||||
#!/bin/bash

# Database initialization script using Alembic migrations
# This script will upgrade the database to the latest migration version

set -e

# Colors for output
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
RED='\033[0;31m'
NC='\033[0m' # No Color

echo -e "${YELLOW}🗄️ Initializing database with Alembic migrations...${NC}"

# Change to the shared directory where alembic.ini is located
# (resolved relative to this script so it works from any working directory)
cd "$(dirname "$0")/../shared"

# Run alembic upgrade head to apply all pending migrations
if alembic upgrade head; then
    echo -e "${GREEN}✅ Database successfully migrated to latest version${NC}"
else
    echo -e "${RED}❌ Error running Alembic migrations${NC}"
    exit 1
fi
15
scripts/lint.sh
Executable file
15
scripts/lint.sh
Executable file
@@ -0,0 +1,15 @@
|
||||
#!/bin/bash
# Script to run linting and type checking
#
# Runs ruff (lint + format check) and pyright; set -e aborts on the first
# failure so CI reports the earliest broken step.

set -e

echo "Running ruff check..."
ruff check .

echo -e "\nRunning ruff format check..."
ruff format --check .

echo -e "\nRunning pyright..."
pyright

echo -e "\nAll checks passed!"
45
scripts/run_all_tests.sh
Executable file
45
scripts/run_all_tests.sh
Executable file
@@ -0,0 +1,45 @@
|
||||
#!/bin/bash
# Run all Python tests across the monorepo

set -e

# Set test environment to disable Sentry
export ENVIRONMENT=test
export SENTRY_DSN=""

echo "🧪 Running All Python Tests"
echo "==========================="

# Store the root directory (parent of scripts dir)
ROOT_DIR="$(cd "$(dirname "$0")/.." && pwd)"

# Function to run tests in a directory if they exist
run_component_tests() {
    local component=$1
    local test_dir="$ROOT_DIR/$component"

    if [ -d "$test_dir" ]; then
        echo -e "\n📦 Testing $component..."
        cd "$ROOT_DIR" # Stay in root directory

        # Run pytest if tests directory exists; PYTHONPATH includes the repo
        # root so absolute imports (shared.*, servers.*) resolve.
        if [ -d "$test_dir/tests" ]; then
            PYTHONPATH="$ROOT_DIR:$PYTHONPATH" pytest "$component/tests" -v
        else
            echo " No tests found in $component"
        fi
    fi
}

# Run tests for each component
run_component_tests "backend"
run_component_tests "servers"

# Run root-level integration tests if they exist
if [ -d "$ROOT_DIR/tests" ]; then
    echo -e "\n🌐 Running integration tests..."
    cd "$ROOT_DIR"
    pytest tests -v
fi

echo -e "\n✅ All tests completed!"
70
servers/README.md
Normal file
70
servers/README.md
Normal file
@@ -0,0 +1,70 @@
|
||||
# Servers Directory
|
||||
|
||||
This directory contains the write operations server for the Agent Dashboard system. It provides a unified interface for AI agents to interact with the dashboard through multiple protocols.
|
||||
|
||||
## Overview
|
||||
|
||||
The servers directory implements all write operations that agents need:
|
||||
- Logging their progress and receiving user feedback
|
||||
- Asking questions to users
|
||||
- Managing session lifecycle
|
||||
|
||||
All operations are authenticated and multi-tenant, ensuring data isolation between users.
|
||||
|
||||
## Architecture
|
||||
|
||||
### Unified Server (`app.py`)
|
||||
A single server application that supports both MCP (Model Context Protocol) and REST API interfaces:
|
||||
- **MCP Interface**: For agents using the MCP protocol
|
||||
- **REST API**: For SDK clients and direct API integrations
|
||||
- Both interfaces share the same authentication and business logic
|
||||
|
||||
### Components
|
||||
|
||||
- **`mcp_server/`**: MCP protocol implementation using fastmcp
|
||||
- **`fastapi_server/`**: REST API implementation
|
||||
- **`shared/`**: Common database operations and business logic
|
||||
- **`tests/`**: Integration and unit tests
|
||||
|
||||
## Authentication
|
||||
|
||||
The servers use a separate authentication system from the main backend:
|
||||
- **JWT Bearer tokens** with RSA-256 signing
|
||||
- **Shorter API keys** using a weaker RSA key (appropriate for write-only operations)
|
||||
- **User context** embedded in tokens for multi-tenancy
|
||||
- **Security Note**: Both the private AND public JWT keys should be kept secure. The weaker RSA implementation (for shorter tokens) means even the public key should not be exposed
|
||||
|
||||
## Key Features
|
||||
|
||||
- **Write-only operations**: Designed for agent interactions, not data retrieval
|
||||
- **Automatic session management**: Creates sessions on first interaction
|
||||
- **User feedback delivery**: Agents receive feedback when logging steps
|
||||
- **Non-blocking questions**: Async implementation for user interactions
|
||||
- **Multi-protocol support**: Same functionality via MCP or REST API
|
||||
|
||||
## Running the Server
|
||||
|
||||
```bash
|
||||
# From the project root with virtual environment activated
|
||||
python -m servers.app
|
||||
```
|
||||
|
||||
The server will be available on the configured port (default: 8080) with:
|
||||
- MCP endpoint at `/mcp/`
|
||||
- REST API at `/api/v1/`
|
||||
|
||||
## Environment Variables
|
||||
|
||||
- `DATABASE_URL`: PostgreSQL connection string
|
||||
- `MCP_SERVER_PORT`: Server port (default: 8080)
|
||||
- `JWT_PUBLIC_KEY`: RSA public key for token verification
|
||||
- `JWT_PRIVATE_KEY`: RSA private key for token signing (if needed)
|
||||
|
||||
## Integration
|
||||
|
||||
Clients can connect using:
|
||||
1. **MCP Protocol**: Via SSE or HTTP streaming transport
|
||||
2. **REST API**: Direct HTTP requests with Bearer token authentication
|
||||
3. **SDK**: Language-specific clients that handle authentication and protocol details
|
||||
|
||||
See `DEPLOYMENT.md` for detailed deployment and client configuration instructions.
|
||||
119
servers/app.py
Normal file
119
servers/app.py
Normal file
@@ -0,0 +1,119 @@
|
||||
"""Unified server combining MCP and FastAPI functionality.
|
||||
|
||||
This server provides:
|
||||
- MCP tools at /mcp/ endpoint (log_step, ask_question, end_session)
|
||||
- REST API endpoints at /api/v1/*
|
||||
- Shared JWT authentication for both interfaces
|
||||
"""
|
||||
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
import sentry_sdk
|
||||
from shared.config import settings
|
||||
|
||||
# Import the pre-configured MCP server
|
||||
from servers.mcp_server.server import mcp
|
||||
|
||||
# Import FastAPI routers
|
||||
from servers.fastapi_server.routers import agent_router
|
||||
|
||||
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize Sentry only if DSN is provided — keeps local/dev runs free of
# error tracking while production opts in via settings.
if settings.sentry_dsn:
    sentry_sdk.init(
        dsn=settings.sentry_dsn,
        send_default_pii=True,
        environment=settings.environment,
    )
    logger.info(f"Sentry initialized for {settings.environment} environment")
else:
    logger.info("Sentry DSN not provided, error tracking disabled")

# Get the MCP app with streamable-http transport; it is mounted under /mcp
# later, so its internal path is just "/".
mcp_app = mcp.http_app(path="/")
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Combined lifespan for both MCP and FastAPI functionality."""
    # Delegate to the MCP app's lifespan so its resources are initialized
    # before serving and torn down after shutdown.
    async with mcp_app.lifespan(app):
        logger.info("Unified server starting up")
        logger.info("MCP endpoints available at: /mcp/")
        logger.info("REST API endpoints available at: /api/v1/*")
        yield
        logger.info("Shutting down unified server")
|
||||
|
||||
# Create FastAPI app with MCP's lifespan
app = FastAPI(
    title="Agent Dashboard Unified Server",
    description="Combined MCP and REST API for agent monitoring and interaction",
    version="1.0.0",
    lifespan=lifespan,
)

# Configure CORS
# NOTE(review): browsers reject allow_origins=["*"] combined with
# allow_credentials=True under the CORS spec — confirm the intended origin
# list before production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Configure appropriately for production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# REST endpoints under /api/v1; MCP protocol app mounted at /mcp.
app.include_router(agent_router, prefix="/api/v1")
app.mount("/mcp", mcp_app)
|
||||
|
||||
@app.get("/")
async def root():
    """Describe the unified server and the endpoints it exposes."""
    endpoints = {
        "mcp": "/mcp/ (MCP tools via Streamable HTTP)",
        "api": "/api/v1/* (REST API endpoints)",
        "docs": "/docs (API documentation)",
    }
    return {
        "message": "Agent Dashboard Unified Server",
        "version": "1.0.0",
        "endpoints": endpoints,
    }
|
||||
|
||||
@app.get("/health")
async def health_check():
    """Liveness probe for deployment health checks."""
    payload = {"status": "healthy", "server": "unified"}
    return payload
|
||||
|
||||
def main():
    """Run the unified server."""
    import uvicorn

    # Log configuration for debugging
    logger.info(f"Starting unified server on port: {settings.mcp_server_port}")
    logger.info("Database URL configured.")
    logger.info(
        f"JWT public key configured: {'Yes' if settings.jwt_public_key else 'No'}"
    )

    try:
        # Auto-reload only in development; production runs a plain server.
        uvicorn.run(
            "servers.app:app",
            host="0.0.0.0",
            port=settings.mcp_server_port,
            reload=settings.environment == "development",
        )
    except Exception as e:
        logger.error(f"Failed to start unified server: {e}")
        raise


if __name__ == "__main__":
    main()
||||
1
servers/fastapi_server/__init__.py
Normal file
1
servers/fastapi_server/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""FastAPI server for Agent Dashboard."""
|
||||
72
servers/fastapi_server/auth.py
Normal file
72
servers/fastapi_server/auth.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""Authentication dependencies for FastAPI server.
|
||||
|
||||
Uses the same JWT authentication as the MCP server.
|
||||
"""
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from jose import JWTError, jwt
|
||||
from shared.config import settings
|
||||
|
||||
# Bearer token security scheme
|
||||
security = HTTPBearer()
|
||||
|
||||
|
||||
async def verify_token(
    credentials: Annotated[HTTPAuthorizationCredentials, Depends(security)],
) -> dict:
    """Verify a bearer JWT and return its decoded payload.

    Args:
        credentials: Bearer token from Authorization header

    Returns:
        Decoded JWT payload including user_id

    Raises:
        HTTPException: 500 when the public key is unconfigured, 401 when the
            token fails verification.
    """
    if not settings.jwt_public_key:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="JWT public key not configured",
        )

    try:
        # Signature and claims are verified against the configured RS256 key.
        return jwt.decode(
            credentials.credentials, settings.jwt_public_key, algorithms=["RS256"]
        )
    except JWTError as e:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=f"Invalid token: {str(e)}",
            headers={"WWW-Authenticate": "Bearer"},
        )
|
||||
|
||||
async def get_current_user_id(
    token_payload: Annotated[dict, Depends(verify_token)],
) -> str:
    """Extract the user ID from a verified token payload.

    Args:
        token_payload: Decoded JWT payload

    Returns:
        User ID string taken from the ``sub`` claim

    Raises:
        HTTPException: 401 when the claim is missing.
    """
    user_id = token_payload.get("sub")
    if user_id:
        return user_id
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Token missing user ID",
        headers={"WWW-Authenticate": "Bearer"},
    )
73
servers/fastapi_server/main.py
Normal file
73
servers/fastapi_server/main.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""FastAPI server for Agent Dashboard API endpoints.
|
||||
|
||||
This server provides RESTful endpoints that mirror the MCP server tools,
|
||||
using the same JWT authentication mechanism as the MCP server.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from shared.config import settings
|
||||
|
||||
from .routers import agent_router
|
||||
|
||||
# Configure logging — INFO level so request lifecycle events are visible.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup and shutdown events."""
    # Startup — nothing to initialize beyond logging; the generator yields
    # control to the running application until shutdown.
    logger.info("FastAPI server starting up")
    yield
    # Shutdown
    logger.info("Shutting down FastAPI server")
|
||||
|
||||
# Create FastAPI app
|
||||
app = FastAPI(
|
||||
title="Agent Dashboard API",
|
||||
description="RESTful API for agent monitoring and interaction",
|
||||
version="1.0.0",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# Configure CORS
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"], # Configure appropriately for production
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Include routers
|
||||
app.include_router(agent_router, prefix="/api/v1")
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
"""Root endpoint."""
|
||||
return {
|
||||
"message": "Agent Dashboard FastAPI Server",
|
||||
"version": "1.0.0",
|
||||
"docs": "/docs",
|
||||
}
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""Health check endpoint."""
|
||||
return {"status": "healthy"}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
# Use a different port than the backend API
|
||||
port = int(settings.api_port) + 1000 # e.g., 9000 if api_port is 8000
|
||||
uvicorn.run("servers.fastapi.main:app", host="0.0.0.0", port=port, reload=True)
|
||||
67
servers/fastapi_server/models.py
Normal file
67
servers/fastapi_server/models.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""Pydantic models for FastAPI request/response schemas."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from servers.shared.models import (
|
||||
BaseLogStepRequest,
|
||||
BaseLogStepResponse,
|
||||
BaseAskQuestionRequest,
|
||||
BaseEndSessionRequest,
|
||||
BaseEndSessionResponse,
|
||||
)
|
||||
|
||||
|
||||
# Request models
class LogStepRequest(BaseLogStepRequest):
    """Request body for POST /steps; all fields come from the shared base."""
|
||||
|
||||
|
||||
class AskQuestionRequest(BaseAskQuestionRequest):
    """Request body for POST /questions; all fields come from the shared base."""
|
||||
|
||||
|
||||
class EndSessionRequest(BaseEndSessionRequest):
    """Request body for POST /sessions/end; all fields come from the shared base."""
|
||||
|
||||
|
||||
# Response models
class LogStepResponse(BaseLogStepResponse):
    """Response for POST /steps; all fields come from the shared base."""
|
||||
|
||||
|
||||
# FastAPI-specific: Response only contains question ID (non-blocking)
class AskQuestionResponse(BaseModel):
    """FastAPI-specific response model for ask question endpoint.

    Unlike the MCP variant, the answer is not included here; clients
    poll GET /questions/{question_id} for it.
    """

    question_id: str = Field(..., description="ID of the created question")
|
||||
|
||||
|
||||
class EndSessionResponse(BaseEndSessionResponse):
    """Response for POST /sessions/end; all fields come from the shared base."""
|
||||
|
||||
|
||||
# FastAPI-specific: Additional model for polling question status
class QuestionStatusResponse(BaseModel):
    """Response model for question status endpoint.

    Returned by GET /questions/{question_id} so clients can poll for an
    answer instead of blocking the request.
    """

    # ID of the question being polled
    question_id: str
    status: str = Field(
        ..., description="Status of the question: 'pending' or 'answered'"
    )
    answer: Optional[str] = Field(
        None, description="Answer text if status is 'answered'"
    )
    # ISO-8601 timestamps; the endpoint serializes them via .isoformat()
    asked_at: str
    answered_at: Optional[str] = None
|
||||
187
servers/fastapi_server/routers.py
Normal file
187
servers/fastapi_server/routers.py
Normal file
@@ -0,0 +1,187 @@
|
||||
"""API routes for agent operations."""
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
|
||||
from shared.database.session import get_db
|
||||
from servers.shared.db import get_question, get_agent_instance
|
||||
from servers.shared.core import (
|
||||
process_log_step,
|
||||
create_agent_question,
|
||||
process_end_session,
|
||||
)
|
||||
from .auth import get_current_user_id
|
||||
from .models import (
|
||||
AskQuestionRequest,
|
||||
AskQuestionResponse,
|
||||
EndSessionRequest,
|
||||
EndSessionResponse,
|
||||
LogStepRequest,
|
||||
LogStepResponse,
|
||||
QuestionStatusResponse,
|
||||
)
|
||||
|
||||
agent_router = APIRouter(tags=["agents"])
|
||||
|
||||
|
||||
@agent_router.post("/steps", response_model=LogStepResponse)
|
||||
async def log_step(
|
||||
request: LogStepRequest, user_id: Annotated[str, Depends(get_current_user_id)]
|
||||
) -> LogStepResponse:
|
||||
"""Log a high-level step the agent is performing.
|
||||
|
||||
This endpoint:
|
||||
- Creates or retrieves an agent instance
|
||||
- Logs the step with a sequential number
|
||||
- Returns any unretrieved user feedback
|
||||
|
||||
User feedback is automatically marked as retrieved.
|
||||
"""
|
||||
db = next(get_db())
|
||||
|
||||
try:
|
||||
# Use shared business logic
|
||||
instance_id, step_number, user_feedback = process_log_step(
|
||||
db=db,
|
||||
agent_type=request.agent_type,
|
||||
step_description=request.step_description,
|
||||
user_id=user_id,
|
||||
agent_instance_id=request.agent_instance_id,
|
||||
)
|
||||
|
||||
return LogStepResponse(
|
||||
success=True,
|
||||
agent_instance_id=instance_id,
|
||||
step_number=step_number,
|
||||
user_feedback=user_feedback,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Internal server error: {str(e)}",
|
||||
)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@agent_router.post("/questions", response_model=AskQuestionResponse)
|
||||
def ask_question(
|
||||
request: AskQuestionRequest, user_id: Annotated[str, Depends(get_current_user_id)]
|
||||
) -> AskQuestionResponse:
|
||||
"""Create a question for the user to answer.
|
||||
|
||||
This endpoint:
|
||||
- Creates a question record in the database
|
||||
- Returns immediately with the question ID
|
||||
- Client should poll GET /questions/{question_id} for the answer
|
||||
|
||||
Note: This endpoint is non-blocking.
|
||||
"""
|
||||
db = next(get_db())
|
||||
|
||||
try:
|
||||
# Use shared business logic to create question
|
||||
question = create_agent_question(
|
||||
db=db,
|
||||
agent_instance_id=request.agent_instance_id,
|
||||
question_text=request.question_text,
|
||||
user_id=user_id,
|
||||
)
|
||||
db.commit()
|
||||
|
||||
# FastAPI-specific: Return immediately with question ID (non-blocking)
|
||||
return AskQuestionResponse(
|
||||
question_id=str(question.id),
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Internal server error: {str(e)}",
|
||||
)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@agent_router.get("/questions/{question_id}", response_model=QuestionStatusResponse)
|
||||
async def get_question_status(
|
||||
question_id: str, user_id: Annotated[str, Depends(get_current_user_id)]
|
||||
) -> QuestionStatusResponse:
|
||||
"""Get the status of a question.
|
||||
|
||||
This endpoint allows polling for question answers without blocking.
|
||||
Returns the current status and answer (if available).
|
||||
"""
|
||||
db = next(get_db())
|
||||
|
||||
try:
|
||||
# Get the question
|
||||
question = get_question(db, question_id)
|
||||
if not question:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Question not found"
|
||||
)
|
||||
|
||||
# Verify the question belongs to the authenticated user
|
||||
instance = get_agent_instance(db, str(question.agent_instance_id))
|
||||
if not instance or str(instance.user_id) != user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="Access denied"
|
||||
)
|
||||
|
||||
# Return question status
|
||||
return QuestionStatusResponse(
|
||||
question_id=str(question.id),
|
||||
status="answered" if question.answer_text else "pending",
|
||||
answer=question.answer_text,
|
||||
asked_at=question.asked_at.isoformat(),
|
||||
answered_at=question.answered_at.isoformat()
|
||||
if question.answered_at
|
||||
else None,
|
||||
)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@agent_router.post("/sessions/end", response_model=EndSessionResponse)
|
||||
async def end_session(
|
||||
request: EndSessionRequest, user_id: Annotated[str, Depends(get_current_user_id)]
|
||||
) -> EndSessionResponse:
|
||||
"""End an agent session and mark it as completed.
|
||||
|
||||
This endpoint:
|
||||
- Marks the agent instance as COMPLETED
|
||||
- Sets the session end time
|
||||
- Deactivates any pending questions
|
||||
"""
|
||||
db = next(get_db())
|
||||
|
||||
try:
|
||||
# Use shared business logic
|
||||
instance_id, final_status = process_end_session(
|
||||
db=db,
|
||||
agent_instance_id=request.agent_instance_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
return EndSessionResponse(
|
||||
success=True,
|
||||
agent_instance_id=instance_id,
|
||||
final_status=final_status,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Internal server error: {str(e)}",
|
||||
)
|
||||
finally:
|
||||
db.close()
|
||||
1
servers/mcp_server/__init__.py
Normal file
1
servers/mcp_server/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# MCP Server package
|
||||
66
servers/mcp_server/models.py
Normal file
66
servers/mcp_server/models.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""
|
||||
MCP Server tool interface models.
|
||||
|
||||
This module contains all Pydantic models for MCP tool requests and responses.
|
||||
Models define the interface contract between AI agents and the MCP server.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from servers.shared.models import (
|
||||
BaseLogStepRequest,
|
||||
BaseLogStepResponse,
|
||||
BaseAskQuestionRequest,
|
||||
BaseEndSessionRequest,
|
||||
BaseEndSessionResponse,
|
||||
)
|
||||
|
||||
# ============================================================================
|
||||
# Tool Request Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
# MCP uses the base models directly for requests
class LogStepRequest(BaseLogStepRequest):
    """Arguments accepted by the `log_step` tool (shared base, unchanged)."""
|
||||
|
||||
|
||||
class AskQuestionRequest(BaseAskQuestionRequest):
    """Arguments accepted by the `ask_question` tool (shared base, unchanged)."""
|
||||
|
||||
|
||||
class EndSessionRequest(BaseEndSessionRequest):
    """Arguments accepted by the `end_session` tool (shared base, unchanged)."""
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tool Response Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
# MCP uses the base model directly for log step response
class LogStepResponse(BaseLogStepResponse):
    """Result returned by the `log_step` tool (shared base, unchanged)."""
|
||||
|
||||
|
||||
# MCP-specific: Response contains the answer (blocking operation)
class AskQuestionResponse(BaseModel):
    """MCP-specific response model for ask question.

    The MCP tool blocks until the user replies, so the answer travels in
    the response itself (contrast with the FastAPI server, which returns
    only a question ID for polling).
    """

    answer: str = Field(..., description="User's answer to the question")
    question_id: str = Field(..., description="ID of the question that was answered")
|
||||
|
||||
|
||||
# MCP uses the base model directly for end session response
class EndSessionResponse(BaseEndSessionResponse):
    """Result returned by the `end_session` tool (shared base, unchanged)."""
|
||||
148
servers/mcp_server/server.py
Normal file
148
servers/mcp_server/server.py
Normal file
@@ -0,0 +1,148 @@
|
||||
"""MCP Server for Agent Dashboard"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections.abc import Callable, Coroutine
|
||||
from functools import wraps
|
||||
from typing import Any, ParamSpec, TypeVar
|
||||
|
||||
from fastmcp import FastMCP, Context
|
||||
from fastmcp.server.auth import BearerAuthProvider
|
||||
from fastmcp.server.dependencies import get_access_token
|
||||
from shared.config import settings
|
||||
|
||||
from .models import AskQuestionResponse, EndSessionResponse, LogStepResponse
|
||||
from .tools import (
|
||||
LOG_STEP_DESCRIPTION,
|
||||
ASK_QUESTION_DESCRIPTION,
|
||||
END_SESSION_DESCRIPTION,
|
||||
log_step_impl,
|
||||
ask_question_impl,
|
||||
end_session_impl,
|
||||
)
|
||||
from .utils import detect_agent_type_from_headers
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Type variables for decorator
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def require_auth(func: Callable[P, T]) -> Callable[P, Coroutine[Any, Any, T]]:
    """Decorator to ensure user is authenticated before executing tool.

    Rejects the call when no access token is present, then injects the
    token's client id into the wrapped function as ``_user_id``.  Works
    for both sync and async tool functions.
    """

    @wraps(func)
    async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
        # Authentication is verified before anything else runs.
        token = get_access_token()
        if token is None:
            raise ValueError("Authentication required. Please provide a valid API key.")

        kwargs["_user_id"] = token.client_id
        outcome = func(*args, **kwargs)
        # Await coroutines; pass plain return values through unchanged.
        return await outcome if asyncio.iscoroutine(outcome) else outcome

    return wrapper
|
||||
|
||||
|
||||
# Configure authentication
# Fail fast at import time: the server cannot verify bearer tokens
# without the JWT public key.
if not settings.jwt_public_key:
    raise ValueError(
        "JWT_PUBLIC_KEY environment variable is not set. "
        "Please generate keys using scripts/generate_jwt_keys.py "
        "and add them to your .env file"
    )

# Bearer-token verifier backed by the configured RSA public key.
auth = BearerAuthProvider(
    public_key=settings.jwt_public_key,
)


# Create FastMCP server with authentication
mcp = FastMCP("Agent Dashboard MCP Server", auth=auth)
|
||||
|
||||
|
||||
@mcp.tool(name="log_step", description=LOG_STEP_DESCRIPTION)
|
||||
@require_auth
|
||||
def log_step_tool(
|
||||
agent_instance_id: str | None = None,
|
||||
step_description: str = "",
|
||||
_user_id: str = "", # Injected by decorator
|
||||
) -> LogStepResponse:
|
||||
agent_type = detect_agent_type_from_headers()
|
||||
return log_step_impl(
|
||||
agent_instance_id=agent_instance_id,
|
||||
agent_type=agent_type,
|
||||
step_description=step_description,
|
||||
user_id=_user_id,
|
||||
)
|
||||
|
||||
|
||||
@mcp.tool(
    name="ask_question",
    description=ASK_QUESTION_DESCRIPTION,
)
@require_auth
async def ask_question_tool(
    ctx: Context,
    agent_instance_id: str | None = None,
    question_text: str | None = None,
    _user_id: str = "",  # Injected by decorator
) -> AskQuestionResponse:
    """Tool entry point: blocks until the user answers; ``ctx`` lets the
    shared implementation report progress while waiting."""
    return await ask_question_impl(
        agent_instance_id=agent_instance_id,
        question_text=question_text,
        user_id=_user_id,
        tool_context=ctx,
    )
|
||||
|
||||
|
||||
@mcp.tool(
    name="end_session",
    description=END_SESSION_DESCRIPTION,
)
@require_auth
def end_session_tool(
    agent_instance_id: str,
    _user_id: str = "",  # Injected by decorator
) -> EndSessionResponse:
    """Tool entry point: mark the given agent instance as completed."""
    return end_session_impl(
        agent_instance_id=agent_instance_id,
        user_id=_user_id,
    )
|
||||
|
||||
|
||||
def main():
    """Run the MCP server"""
    # Database tables should be managed by Alembic migrations
    logger.info("Starting MCP server...")

    # Log configuration for debugging
    logger.info(f"Starting MCP server on port: {settings.mcp_server_port}")
    # Avoid logging a raw prefix of the database URL: connection strings
    # of the form scheme://user:password@host can leak credentials into
    # logs, so log only the scheme and host.
    from urllib.parse import urlsplit

    db_parts = urlsplit(settings.database_url)
    logger.info(
        "Database configured: scheme=%s host=%s", db_parts.scheme, db_parts.hostname
    )
    logger.info(
        f"JWT public key configured: {'Yes' if settings.jwt_public_key else 'No'}"
    )

    try:
        # Use streamable-http which handles both HTTP POST and SSE on same endpoint
        mcp.run(
            transport="streamable-http",
            host="0.0.0.0",
            port=settings.mcp_server_port,
            path="/mcp",
        )
    except Exception as e:
        logger.error(f"Failed to start MCP server: {e}")
        raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
140
servers/mcp_server/stdio_server.py
Normal file
140
servers/mcp_server/stdio_server.py
Normal file
@@ -0,0 +1,140 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Omnara MCP Server - Stdio Transport
|
||||
|
||||
This is the stdio version of the Omnara MCP server that can be installed via pip/pipx.
|
||||
It provides the same functionality as the hosted server but uses stdio transport.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
from collections.abc import Callable, Coroutine
|
||||
from functools import wraps
|
||||
from typing import Any, ParamSpec, TypeVar
|
||||
|
||||
from fastmcp import FastMCP
|
||||
from shared.config import settings
|
||||
from shared.database import Base
|
||||
from shared.database.session import engine
|
||||
|
||||
from .models import AskQuestionResponse, EndSessionResponse, LogStepResponse
|
||||
from .tools import (
|
||||
LOG_STEP_DESCRIPTION,
|
||||
ASK_QUESTION_DESCRIPTION,
|
||||
END_SESSION_DESCRIPTION,
|
||||
log_step_impl,
|
||||
ask_question_impl,
|
||||
end_session_impl,
|
||||
)
|
||||
from .utils import detect_agent_type_from_environment
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Type variables for decorator
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def require_api_key(func: Callable[P, T]) -> Callable[P, Coroutine[Any, Any, T]]:
|
||||
"""Decorator to ensure API key is provided for stdio server."""
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
||||
# For stdio, we get the API key from command line args
|
||||
# and use it as the user_id for simplicity
|
||||
api_key = getattr(require_api_key, "_api_key", None)
|
||||
if not api_key:
|
||||
raise ValueError("API key is required. Use --api-key argument.")
|
||||
|
||||
# Add user_id to kwargs for use in the function
|
||||
kwargs["_user_id"] = api_key
|
||||
result = func(*args, **kwargs)
|
||||
# Handle both sync and async functions
|
||||
if asyncio.iscoroutine(result):
|
||||
return await result
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
# Create FastMCP server
|
||||
mcp = FastMCP("Omnara Agent Dashboard MCP Server")
|
||||
|
||||
|
||||
@mcp.tool(name="log_step", description=LOG_STEP_DESCRIPTION)
|
||||
@require_api_key
|
||||
def log_step_tool(
|
||||
agent_instance_id: str | None = None,
|
||||
step_description: str = "",
|
||||
_user_id: str = "", # Injected by decorator
|
||||
) -> LogStepResponse:
|
||||
agent_type = detect_agent_type_from_environment()
|
||||
return log_step_impl(
|
||||
agent_instance_id=agent_instance_id,
|
||||
agent_type=agent_type,
|
||||
step_description=step_description,
|
||||
user_id=_user_id,
|
||||
)
|
||||
|
||||
|
||||
@mcp.tool(
    name="ask_question",
    description=ASK_QUESTION_DESCRIPTION,
)
@require_api_key
async def ask_question_tool(
    agent_instance_id: str | None = None,
    question_text: str | None = None,
    _user_id: str = "",  # Injected by decorator
) -> AskQuestionResponse:
    """Tool entry point: blocks until the user answers the question."""
    return await ask_question_impl(
        agent_instance_id=agent_instance_id,
        question_text=question_text,
        user_id=_user_id,
    )
|
||||
|
||||
|
||||
@mcp.tool(
    name="end_session",
    description=END_SESSION_DESCRIPTION,
)
@require_api_key
def end_session_tool(
    agent_instance_id: str,
    _user_id: str = "",  # Injected by decorator
) -> EndSessionResponse:
    """Tool entry point: mark the given agent instance as completed."""
    return end_session_impl(
        agent_instance_id=agent_instance_id,
        user_id=_user_id,
    )
|
||||
|
||||
|
||||
def main():
    """Main entry point for the stdio server"""
    parser = argparse.ArgumentParser(description="Omnara MCP Server (Stdio)")
    parser.add_argument("--api-key", required=True, help="API key for authentication")

    args = parser.parse_args()

    # Store API key for auth decorator
    require_api_key._api_key = args.api_key

    # Ensure database tables exist
    Base.metadata.create_all(bind=engine)
    logger.info("Database tables created/verified")

    logger.info("Starting Omnara MCP server (stdio)")
    # Avoid logging a raw prefix of the database URL: connection strings
    # of the form scheme://user:password@host can leak credentials into
    # logs, so log only the scheme and host.
    from urllib.parse import urlsplit

    db_parts = urlsplit(settings.database_url)
    logger.info(
        "Database configured: scheme=%s host=%s", db_parts.scheme, db_parts.hostname
    )

    try:
        # Run with stdio transport (default)
        mcp.run(transport="stdio")
    except Exception as e:
        logger.error(f"Failed to start MCP server: {e}")
        raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
247
servers/mcp_server/tools.py
Normal file
247
servers/mcp_server/tools.py
Normal file
@@ -0,0 +1,247 @@
|
||||
"""Shared MCP Tools for Agent Dashboard
|
||||
|
||||
This module contains the core tool implementations that are shared between
|
||||
the hosted server and stdio server. The authentication logic is handled
|
||||
by the individual servers.
|
||||
"""
|
||||
|
||||
from fastmcp import Context
|
||||
from shared.database.session import get_db
|
||||
|
||||
from servers.shared.db import wait_for_answer
|
||||
from servers.shared.core import (
|
||||
process_log_step,
|
||||
create_agent_question,
|
||||
process_end_session,
|
||||
)
|
||||
from .models import AskQuestionResponse, EndSessionResponse, LogStepResponse
|
||||
|
||||
LOG_STEP_DESCRIPTION = """Log a high-level step the agent is performing.
|
||||
|
||||
⚠️ CRITICAL: MUST be called for EVERY significant action:
|
||||
• Before answering any user question or request
|
||||
• When performing analysis, searches, or investigations
|
||||
• When reading files, exploring code, or gathering information
|
||||
• When making code changes, edits, or file modifications
|
||||
• When running commands, tests, or terminal operations
|
||||
• When providing explanations, solutions, or recommendations
|
||||
• At the start of multi-step processes or complex tasks
|
||||
|
||||
This call retrieves unread user feedback that you MUST incorporate into your work.
|
||||
Feedback may contain corrections, clarifications, or additional instructions that override your original plan.
|
||||
|
||||
Args:
|
||||
agent_instance_id: Existing agent instance ID (optional). If omitted, creates a new instance for reuse in subsequent steps.
|
||||
step_description: Clear, specific description of what you're about to do or currently doing.
|
||||
|
||||
⚠️ RETURNS USER FEEDBACK: If user_feedback is not empty, you MUST:
|
||||
1. Read and understand each feedback message
|
||||
2. Adjust your current approach based on the feedback
|
||||
3. Acknowledge the feedback in your response
|
||||
4. Prioritize user feedback over your original plan
|
||||
|
||||
Feedback is automatically marked as retrieved. If empty, continue as planned."""
|
||||
|
||||
|
||||
ASK_QUESTION_DESCRIPTION = """🤖 INTERACTIVE: Ask the user a question and WAIT for their reply (BLOCKS execution).
|
||||
|
||||
⚠️ CRITICAL: ALWAYS call log_step BEFORE using this tool to track the interaction.
|
||||
|
||||
🎯 USE WHEN YOU NEED:
|
||||
• Clarification on ambiguous requirements or unclear instructions
|
||||
• User decision between multiple valid approaches or solutions
|
||||
• Confirmation before making significant changes (deleting files, major refactors)
|
||||
• Missing information that you cannot determine from context or codebase
|
||||
• User preferences for implementation details (styling, naming, architecture)
|
||||
• Validation of assumptions before proceeding with complex tasks
|
||||
|
||||
💡 BEST PRACTICES:
|
||||
• Keep questions clear, specific, and actionable
|
||||
• Provide context: explain WHY you're asking
|
||||
• Offer options when multiple choices exist
|
||||
• Ask one focused question at a time
|
||||
• Include relevant details to help user decide
|
||||
|
||||
Args:
|
||||
agent_instance_id: Current agent instance ID. REQUIRED.
|
||||
question_text: Clear, specific question with sufficient context for the user to provide a helpful answer."""
|
||||
|
||||
|
||||
END_SESSION_DESCRIPTION = """End the current agent session and mark it as completed.
|
||||
|
||||
⚠️ IMPORTANT: Before using this tool, you MUST:
|
||||
1. Provide a comprehensive summary of all actions taken to complete the task
|
||||
2. Use the ask_question tool to confirm with the user that the task is complete
|
||||
3. Only proceed with end_session if the user confirms completion
|
||||
|
||||
Example confirmation question:
|
||||
"I've completed the following tasks:
|
||||
• [List of specific actions taken]
|
||||
• [Key changes or implementations made]
|
||||
• [Any important outcomes or results]
|
||||
|
||||
Is this task complete and ready to be marked as finished?"
|
||||
|
||||
If the user:
|
||||
• Confirms completion → Use end_session tool
|
||||
• Does NOT confirm → Continue working on their feedback or new requirements
|
||||
• Requests additional work → Do NOT end the session, continue with the new tasks
|
||||
|
||||
Use this tool ONLY when:
|
||||
• The user has explicitly confirmed the task is complete
|
||||
• The user explicitly asks to end the session
|
||||
• An unrecoverable error prevents any further work
|
||||
|
||||
This will:
|
||||
• Mark the agent instance status as COMPLETED
|
||||
• Set the session end time
|
||||
• Deactivate any pending questions
|
||||
• Prevent further updates to this session
|
||||
|
||||
Args:
|
||||
agent_instance_id: Current agent instance ID to end. REQUIRED."""
|
||||
|
||||
|
||||
def log_step_impl(
    agent_instance_id: str | None = None,
    agent_type: str = "",
    step_description: str = "",
    user_id: str = "",
) -> LogStepResponse:
    """Core implementation of the log_step tool.

    Opens a short-lived database session, delegates to the shared
    ``process_log_step`` business logic, and packages the result.

    Args:
        agent_instance_id: Existing agent instance ID (optional).
        agent_type: Name of the agent (e.g., 'Claude Code', 'Cursor').
        step_description: High-level description of the current step.
        user_id: Authenticated user ID.

    Returns:
        LogStepResponse with success status, instance details, and user
        feedback.

    Raises:
        ValueError: If a required argument is missing.
    """
    # Reject missing required fields up front, before touching the DB.
    for label, value in (
        ("agent_type", agent_type),
        ("step_description", step_description),
        ("user_id", user_id),
    ):
        if not value:
            raise ValueError(f"{label} is required")

    session = next(get_db())
    try:
        instance_id, step_number, user_feedback = process_log_step(
            db=session,
            agent_type=agent_type,
            step_description=step_description,
            user_id=user_id,
            agent_instance_id=agent_instance_id,
        )
        return LogStepResponse(
            success=True,
            agent_instance_id=instance_id,
            step_number=step_number,
            user_feedback=user_feedback,
        )
    except Exception:
        # Roll back on any failure so the session is never left dirty.
        session.rollback()
        raise
    finally:
        session.close()
|
||||
|
||||
|
||||
async def ask_question_impl(
    agent_instance_id: str | None = None,
    question_text: str | None = None,
    user_id: str = "",
    tool_context: Context | None = None,
) -> AskQuestionResponse:
    """Core implementation of the ask_question tool.

    Creates the question record, then blocks the coroutine until the
    user answers (or the wait gives up).

    Args:
        agent_instance_id: Agent instance ID
        question_text: Question to ask the user
        user_id: Authenticated user ID
        tool_context: MCP context for progress reporting

    Returns:
        AskQuestionResponse with the user's answer

    Raises:
        ValueError: If a required argument is missing.
        TimeoutError: If no answer arrives before the wait gives up.
    """
    # Validate inputs
    if not agent_instance_id:
        raise ValueError("agent_instance_id is required")
    if not question_text:
        raise ValueError("question_text is required")
    if not user_id:
        raise ValueError("user_id is required")

    # Get database session
    db = next(get_db())

    try:
        # Use shared business logic to create question
        question = create_agent_question(
            db=db,
            agent_instance_id=agent_instance_id,
            question_text=question_text,
            user_id=user_id,
        )

        # MCP-specific: Wait for answer (blocking)
        answer = await wait_for_answer(db, question.id, tool_context=tool_context)

        if answer is None:
            raise TimeoutError("Question timed out waiting for user response")

        return AskQuestionResponse(answer=answer, question_id=str(question.id))

    except Exception:
        # Any failure (including timeout) rolls the session back.
        db.rollback()
        raise
    finally:
        db.close()
|
||||
|
||||
|
||||
def end_session_impl(
    agent_instance_id: str,
    user_id: str = "",
) -> EndSessionResponse:
    """Core implementation of the end_session tool.

    Args:
        agent_instance_id: Agent instance ID to end.
        user_id: Authenticated user ID.

    Returns:
        EndSessionResponse with success status and final session details.

    Raises:
        ValueError: If user_id is missing.
    """
    if not user_id:
        raise ValueError("user_id is required")

    session = next(get_db())
    try:
        ended_instance_id, final_status = process_end_session(
            db=session,
            agent_instance_id=agent_instance_id,
            user_id=user_id,
        )
        return EndSessionResponse(
            success=True,
            agent_instance_id=ended_instance_id,
            final_status=final_status,
        )
    except Exception:
        # Keep the session clean on any failure.
        session.rollback()
        raise
    finally:
        session.close()
|
||||
64
servers/mcp_server/utils.py
Normal file
64
servers/mcp_server/utils.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""Utility functions for MCP server"""
|
||||
|
||||
import os
|
||||
from fastmcp.server.dependencies import get_http_headers
|
||||
|
||||
|
||||
def detect_agent_type_from_headers() -> str:
    """Detect the agent type from HTTP User-Agent header."""
    try:
        headers = get_http_headers()

        # An explicit client-type header always wins over UA sniffing.
        declared = headers.get("x-client-type") or headers.get("X-Client-Type")
        if declared:
            return declared

        ua = headers.get("user-agent", "").lower()

        # Order matters: more specific tokens are tested before generic
        # ones (e.g. "claude" before the catch-all "code", and the
        # roo-cline / claude-code variants inside their branches).
        if "cursor" in ua:
            return "cursor"
        if "claude" in ua:
            if "claude-code" in ua or "claude code" in ua:
                return "claude-code"
            return "claude"
        if "cline" in ua:
            if "roo-cline" in ua or "roo cline" in ua:
                return "roo-cline"
            return "cline"
        if "windsurf" in ua:
            return "windsurf"
        if "witsy" in ua:
            return "witsy"
        if "enconvo" in ua:
            return "enconvo"
        if "vscode" in ua or "code" in ua:
            return "vscode"
        if "postman" in ua:
            return "postman"

        # Unidentifiable HTTP client.
        return "unknown"

    except ImportError:
        # FastMCP dependencies not available - likely stdio transport
        return "unknown"
    except Exception:
        # Other errors (no request context, etc.) - likely stdio transport
        return "unknown"
|
||||
|
||||
|
||||
def detect_agent_type_from_environment() -> str:
    """Detect agent type from environment variables (for stdio transport).

    Returns the value of ``OMNARA_CLIENT_TYPE`` when set and non-empty,
    otherwise ``"unknown"`` — which typically means the client was not
    installed via our CLI.
    """
    # An empty string is treated the same as an unset variable.
    return os.getenv("OMNARA_CLIENT_TYPE") or "unknown"
|
||||
15
servers/requirements.txt
Normal file
15
servers/requirements.txt
Normal file
@@ -0,0 +1,15 @@
|
||||
# Servers requirements - includes both MCP and FastAPI dependencies
|
||||
|
||||
# MCP Server
|
||||
fastmcp==2.9.2
|
||||
|
||||
# FastAPI Server
|
||||
fastapi>=0.100.0
|
||||
uvicorn[standard]>=0.23.0
|
||||
python-jose[cryptography]>=3.3.0
|
||||
|
||||
# Push notifications
|
||||
exponent-server-sdk>=2.1.0
|
||||
|
||||
# Shared dependencies
|
||||
-r ../shared/requirements.txt
|
||||
0
servers/shared/__init__.py
Normal file
0
servers/shared/__init__.py
Normal file
15
servers/shared/core/__init__.py
Normal file
15
servers/shared/core/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""Core business logic shared between servers."""
|
||||
|
||||
from .agents import (
|
||||
validate_agent_access,
|
||||
process_log_step,
|
||||
create_agent_question,
|
||||
process_end_session,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"validate_agent_access",
|
||||
"process_log_step",
|
||||
"create_agent_question",
|
||||
"process_end_session",
|
||||
]
|
||||
150
servers/shared/core/agents.py
Normal file
150
servers/shared/core/agents.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""Shared business logic for agent operations.
|
||||
|
||||
This module contains the common logic used by both MCP and FastAPI servers,
|
||||
avoiding code duplication while allowing protocol-specific implementations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from servers.shared.db import (
|
||||
get_agent_instance,
|
||||
create_agent_instance,
|
||||
log_step,
|
||||
create_question,
|
||||
get_and_mark_unretrieved_feedback,
|
||||
create_or_get_user_agent,
|
||||
end_session,
|
||||
)
|
||||
from servers.shared.notifications import push_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def validate_agent_access(db: Session, agent_instance_id: str, user_id: str):
    """Validate that a user has access to an agent instance.

    Args:
        db: Database session.
        agent_instance_id: Agent instance ID to validate.
        user_id: User ID requesting access.

    Returns:
        The agent instance if validation passes.

    Raises:
        ValueError: If the instance does not exist or belongs to a
            different user.
    """
    found = get_agent_instance(db, agent_instance_id)
    if not found:
        raise ValueError(f"Agent instance {agent_instance_id} not found")

    # Ownership check: the stored user_id is a UUID, so compare as strings.
    if str(found.user_id) != user_id:
        raise ValueError(
            "Access denied. Agent instance does not belong to authenticated user."
        )

    return found
|
||||
|
||||
|
||||
def process_log_step(
    db: Session,
    agent_type: str,
    step_description: str,
    user_id: str,
    agent_instance_id: str | None = None,
) -> tuple[str, int, list[str]]:
    """Process a log step operation with all common logic.

    Args:
        db: Database session.
        agent_type: Type of agent.
        step_description: Description of the step.
        user_id: Authenticated user ID.
        agent_instance_id: Optional existing instance ID.

    Returns:
        Tuple of (agent_instance_id, step_number, user_feedback).
    """
    # Resolve (or lazily create) the per-user agent type record.
    user_agent = create_or_get_user_agent(db, agent_type, user_id)

    # Reuse the caller-supplied instance after an ownership check;
    # otherwise start a fresh instance for this user.
    if agent_instance_id:
        instance = validate_agent_access(db, agent_instance_id, user_id)
    else:
        instance = create_agent_instance(db, user_agent.id, user_id)

    recorded_step = log_step(db, instance.id, step_description)

    # Collect user feedback posted since the last retrieval; the query also
    # marks it as retrieved so each item is delivered exactly once.
    pending_feedback = get_and_mark_unretrieved_feedback(db, instance.id)

    return str(instance.id), recorded_step.step_number, pending_feedback
|
||||
|
||||
|
||||
def create_agent_question(
    db: Session,
    agent_instance_id: str,
    question_text: str,
    user_id: str,
):
    """Create a question with validation.

    Ownership of the instance is verified before the question is created.
    The push notification for the new question is dispatched by
    ``servers.shared.db.create_question`` itself, so this function must not
    send one as well — previously it did, causing users to receive two
    push notifications per question.

    Args:
        db: Database session
        agent_instance_id: Agent instance ID
        question_text: Question to ask
        user_id: Authenticated user ID

    Returns:
        The created question object
    """
    # Validate access
    instance = validate_agent_access(db, agent_instance_id, user_id)

    # Create question; create_question also sends the push notification.
    return create_question(db, instance.id, question_text)
|
||||
|
||||
|
||||
def process_end_session(
    db: Session,
    agent_instance_id: str,
    user_id: str,
) -> tuple[str, str]:
    """Process ending a session with validation.

    Args:
        db: Database session.
        agent_instance_id: Agent instance ID to end.
        user_id: Authenticated user ID.

    Returns:
        Tuple of (agent_instance_id, final_status).
    """
    # Only the owning user may close a session.
    instance = validate_agent_access(db, agent_instance_id, user_id)

    closed = end_session(db, instance.id)
    return str(closed.id), closed.status.value
|
||||
25
servers/shared/db/__init__.py
Normal file
25
servers/shared/db/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
||||
"""Database queries and operations for servers."""
|
||||
|
||||
from .queries import (
|
||||
create_agent_instance,
|
||||
create_or_get_user_agent,
|
||||
create_question,
|
||||
end_session,
|
||||
get_agent_instance,
|
||||
get_and_mark_unretrieved_feedback,
|
||||
get_question,
|
||||
log_step,
|
||||
wait_for_answer,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"create_agent_instance",
|
||||
"create_or_get_user_agent",
|
||||
"create_question",
|
||||
"end_session",
|
||||
"get_agent_instance",
|
||||
"get_and_mark_unretrieved_feedback",
|
||||
"get_question",
|
||||
"log_step",
|
||||
"wait_for_answer",
|
||||
]
|
||||
231
servers/shared/db/queries.py
Normal file
231
servers/shared/db/queries.py
Normal file
@@ -0,0 +1,231 @@
|
||||
import asyncio
|
||||
import time
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from uuid import UUID
|
||||
|
||||
from shared.database import (
|
||||
AgentInstance,
|
||||
AgentQuestion,
|
||||
AgentStatus,
|
||||
AgentStep,
|
||||
AgentUserFeedback,
|
||||
UserAgent,
|
||||
)
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
from fastmcp import Context
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_or_get_user_agent(db: Session, name: str, user_id: str) -> UserAgent:
    """Create or get a user agent by name for a specific user"""
    # Names are stored lowercase so lookups are case-insensitive.
    key = name.lower()
    owner = UUID(user_id)

    agent = (
        db.query(UserAgent)
        .filter(UserAgent.name == key, UserAgent.user_id == owner)
        .first()
    )
    if not agent:
        agent = UserAgent(
            name=key,
            user_id=owner,
            is_active=True,
        )
        db.add(agent)
        db.commit()
        db.refresh(agent)
    return agent
|
||||
|
||||
|
||||
def create_agent_instance(
    db: Session, user_agent_id: UUID | None, user_id: str
) -> AgentInstance:
    """Create a new agent instance"""
    # Fresh instances always begin in the ACTIVE state.
    new_instance = AgentInstance(
        user_agent_id=user_agent_id,
        user_id=UUID(user_id),
        status=AgentStatus.ACTIVE,
    )
    db.add(new_instance)
    db.commit()
    db.refresh(new_instance)
    return new_instance
|
||||
|
||||
|
||||
def get_agent_instance(db: Session, instance_id: str) -> AgentInstance | None:
    """Get an agent instance by ID, or None when it does not exist."""
    lookup = db.query(AgentInstance).filter(AgentInstance.id == instance_id)
    return lookup.first()
|
||||
|
||||
|
||||
def log_step(db: Session, instance_id: UUID, description: str) -> AgentStep:
    """Log a new step for an agent instance"""
    # Steps are numbered sequentially per instance, starting at 1.
    highest = (
        db.query(func.max(AgentStep.step_number))
        .filter(AgentStep.agent_instance_id == instance_id)
        .scalar()
    )

    new_step = AgentStep(
        agent_instance_id=instance_id,
        step_number=(highest or 0) + 1,
        description=description,
    )
    db.add(new_step)
    db.commit()
    db.refresh(new_step)
    return new_step
|
||||
|
||||
|
||||
def create_question(
    db: Session, instance_id: UUID, question_text: str
) -> AgentQuestion:
    """Create a new question for an agent instance.

    Deactivates any previously active questions, flips an ACTIVE instance
    to AWAITING_INPUT, persists the new question, then best-effort sends a
    push notification to the owning user.
    """
    # Mark any existing active questions as inactive so at most one
    # question per instance is awaiting an answer.
    db.query(AgentQuestion).filter(
        AgentQuestion.agent_instance_id == instance_id, AgentQuestion.is_active
    ).update({"is_active": False})

    # Update agent instance status to awaiting_input (only from ACTIVE;
    # other states are left untouched).
    instance = db.query(AgentInstance).filter(AgentInstance.id == instance_id).first()
    if instance and instance.status == AgentStatus.ACTIVE:
        instance.status = AgentStatus.AWAITING_INPUT

    # Create new question
    question = AgentQuestion(
        agent_instance_id=instance_id, question_text=question_text, is_active=True
    )
    db.add(question)
    db.commit()
    db.refresh(question)

    # Send push notification. The import is local — presumably to avoid an
    # import cycle between the db and notifications modules; confirm.
    # NOTE(review): if `instance` is None, the attribute access below raises
    # and is swallowed by the broad except — the question is still created
    # but no notification goes out. Confirm that is the intended behavior.
    try:
        from servers.shared.notifications import push_service

        # Get agent name from instance
        agent_name = instance.user_agent.name if instance.user_agent else "Agent"

        push_service.send_question_notification(
            db=db,
            user_id=instance.user_id,
            instance_id=str(instance.id),
            question_id=str(question.id),
            agent_name=agent_name,
            question_text=question_text,
        )
    except Exception as e:
        # Don't fail the question creation if push notification fails
        logger.error(
            f"Failed to send push notification for question {question.id}: {e}"
        )

    # NOTE(review): callers must not send a second notification for the
    # same question — this function already does it.
    return question
|
||||
|
||||
|
||||
async def wait_for_answer(
    db: Session,
    question_id: UUID,
    timeout: int = 86400,
    tool_context: Context | None = None,
) -> str | None:
    """
    Wait for an answer to a question (async non-blocking)

    Polls the database once per second until the question row carries a
    non-null answer_text or the timeout elapses. On timeout the question
    is marked inactive before returning None.

    Args:
        db: Database session
        question_id: Question ID to wait for
        timeout: Maximum time to wait in seconds (default 24 hours)
        tool_context: Optional MCP context; when provided, progress is
            reported in elapsed minutes roughly once per minute

    Returns:
        Answer text if received, None if timeout
    """
    start_time = time.time()
    last_progress_report = start_time
    total_minutes = int(timeout / 60)

    # Report initial progress (0 minutes elapsed)
    if tool_context:
        await tool_context.report_progress(0, total_minutes)

    while time.time() - start_time < timeout:
        # Check for answer
        db.commit()  # Ensure we see latest data (ends the current read transaction)
        question = (
            db.query(AgentQuestion).filter(AgentQuestion.id == question_id).first()
        )

        if question and question.answer_text is not None:
            # Answer arrived: report completion, then hand it back.
            if tool_context:
                await tool_context.report_progress(total_minutes, total_minutes)
            return question.answer_text

        # Report progress every minute if tool_context is provided
        current_time = time.time()
        if tool_context and (current_time - last_progress_report) >= 60:
            elapsed_minutes = int((current_time - start_time) / 60)
            await tool_context.report_progress(elapsed_minutes, total_minutes)
            last_progress_report = current_time

        # Yield to the event loop between polls instead of blocking.
        await asyncio.sleep(1)

    # Timeout - mark question as inactive
    db.query(AgentQuestion).filter(AgentQuestion.id == question_id).update(
        {"is_active": False}
    )
    db.commit()

    return None
|
||||
|
||||
|
||||
def get_question(db: Session, question_id: str) -> AgentQuestion | None:
    """Get a question by ID, or None when no such question exists."""
    lookup = db.query(AgentQuestion).filter(AgentQuestion.id == question_id)
    return lookup.first()
|
||||
|
||||
|
||||
def get_and_mark_unretrieved_feedback(
    db: Session, instance_id: UUID, since_time: datetime | None = None
) -> list[str]:
    """Get unretrieved user feedback for an agent instance and mark as retrieved"""
    # Base criteria: feedback for this instance that was never delivered.
    criteria = [
        AgentUserFeedback.agent_instance_id == instance_id,
        AgentUserFeedback.retrieved_at.is_(None),
    ]
    if since_time:
        criteria.append(AgentUserFeedback.created_at > since_time)

    pending = (
        db.query(AgentUserFeedback)
        .filter(*criteria)
        .order_by(AgentUserFeedback.created_at)
        .all()
    )

    # Stamp each row so the same feedback is never delivered twice.
    for item in pending:
        item.retrieved_at = datetime.now(timezone.utc)
    db.commit()

    return [item.feedback_text for item in pending]
|
||||
|
||||
|
||||
def end_session(db: Session, instance_id: UUID) -> AgentInstance:
    """End an agent session by marking it as completed"""
    session_row = (
        db.query(AgentInstance).filter(AgentInstance.id == instance_id).first()
    )
    if not session_row:
        raise ValueError(f"Agent instance {instance_id} not found")

    # Any question still awaiting an answer is moot once the session ends.
    db.query(AgentQuestion).filter(
        AgentQuestion.agent_instance_id == instance_id, AgentQuestion.is_active
    ).update({"is_active": False})

    # Close the session and stamp the end time.
    session_row.status = AgentStatus.COMPLETED
    session_row.ended_at = datetime.now(timezone.utc)

    db.commit()
    db.refresh(session_row)
    return session_row
|
||||
17
servers/shared/models/__init__.py
Normal file
17
servers/shared/models/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""Shared models for MCP and FastAPI servers."""
|
||||
|
||||
from .base import (
|
||||
BaseLogStepRequest,
|
||||
BaseLogStepResponse,
|
||||
BaseAskQuestionRequest,
|
||||
BaseEndSessionRequest,
|
||||
BaseEndSessionResponse,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BaseLogStepRequest",
|
||||
"BaseLogStepResponse",
|
||||
"BaseAskQuestionRequest",
|
||||
"BaseEndSessionRequest",
|
||||
"BaseEndSessionResponse",
|
||||
]
|
||||
69
servers/shared/models/base.py
Normal file
69
servers/shared/models/base.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""Base models shared between MCP and FastAPI servers.
|
||||
|
||||
These models define the common interface for agent operations,
|
||||
allowing each server to extend them as needed.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Request Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class BaseLogStepRequest(BaseModel):
    """Base request model for logging a step.

    Omitting ``agent_instance_id`` signals the server to start a new agent
    instance for the authenticated user.
    """

    agent_instance_id: str | None = Field(
        None,
        description="Existing agent instance ID. If not provided, creates a new instance.",
    )
    agent_type: str = Field(
        ..., description="Type of agent (e.g., 'Claude Code', 'Cursor')"
    )
    step_description: str = Field(
        ..., description="Clear description of what the agent is doing"
    )


class BaseAskQuestionRequest(BaseModel):
    """Base request model for asking a question."""

    agent_instance_id: str = Field(..., description="Agent instance ID")
    question_text: str = Field(..., description="Question to ask the user")


class BaseEndSessionRequest(BaseModel):
    """Base request model for ending a session."""

    agent_instance_id: str = Field(..., description="Agent instance ID to end")


# ============================================================================
# Response Models
# ============================================================================


class BaseLogStepResponse(BaseModel):
    """Base response model for log step."""

    success: bool = Field(..., description="Whether the step was logged successfully")
    agent_instance_id: str = Field(
        ..., description="Agent instance ID (new or existing)"
    )
    step_number: int = Field(..., description="Sequential step number")
    user_feedback: list[str] = Field(
        default_factory=list,
        description="List of unretrieved user feedback messages",
    )


class BaseEndSessionResponse(BaseModel):
    """Base response model for end session."""

    success: bool = Field(..., description="Whether the session was ended successfully")
    agent_instance_id: str = Field(..., description="Agent instance ID that was ended")
    final_status: str = Field(
        ..., description="Final status of the session (should be 'completed')"
    )
|
||||
156
servers/shared/notifications.py
Normal file
156
servers/shared/notifications.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""Push notification service using Expo Push API"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, Any, Optional
|
||||
from uuid import UUID
|
||||
from sqlalchemy.orm import Session
|
||||
from exponent_server_sdk import (
|
||||
PushClient,
|
||||
PushMessage,
|
||||
PushServerError,
|
||||
PushTicketError,
|
||||
DeviceNotRegisteredError,
|
||||
)
|
||||
|
||||
from shared.database import PushToken
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PushNotificationService:
    """Service for sending push notifications via Expo"""

    def __init__(self):
        # One shared Expo client reused for all sends.
        self.client = PushClient()

    def send_notification(
        self,
        db: Session,
        user_id: UUID,
        title: str,
        body: str,
        data: Optional[Dict[str, Any]] = None,
    ) -> bool:
        """Send push notification to all user's devices.

        Returns True when all batches were handed to Expo without a
        transport-level error (per-ticket errors are only logged);
        False when the user has no valid tokens or a send fails.
        """
        try:
            # Get user's active push tokens
            tokens = (
                db.query(PushToken)
                .filter(PushToken.user_id == user_id, PushToken.is_active)
                .all()
            )

            if not tokens:
                logger.info(f"No push tokens found for user {user_id}")
                return False

            # Prepare messages for Expo
            messages = []
            for token in tokens:
                # Validate token format
                if not PushClient.is_exponent_push_token(token.token):
                    logger.warning(f"Invalid Expo push token: {token.token}")
                    continue

                # TODO: Fill in the missing fields
                message = PushMessage(
                    to=token.token,
                    title=title,
                    body=body,
                    data=data or {},
                    sound="default",
                    priority="high",
                    channel_id="agent-questions",
                    ttl=None,
                    expiration=None,
                    badge=None,
                    category=None,
                    display_in_foreground=True,
                    subtitle=None,
                    mutable_content=False,
                )
                messages.append(message)

            if not messages:
                logger.warning("No valid push tokens to send to")
                return False

            # Send to Expo Push API in chunks
            try:
                # Send messages in batches (Expo recommends max 100 per batch)
                for chunk in self._chunks(messages, 100):
                    response = self.client.publish_multiple(chunk)

                    # Check for errors in the response
                    for push_response in response:
                        if push_response.get("status") == "error":
                            logger.error(
                                f"Push notification error: {push_response.get('message')}"
                            )

                logger.info(f"Successfully sent push notifications to user {user_id}")
                return True

            except PushServerError as e:
                logger.error(f"Push server error: {str(e)}")
                return False
            except DeviceNotRegisteredError as e:
                # Listed before PushTicketError — presumably its parent
                # class in the SDK; confirm against exponent_server_sdk.
                logger.error(f"Device not registered, deactivating token: {str(e)}")
                # Mark token as inactive
                # NOTE(review): matching tokens by substring of str(e) is
                # fragile — confirm the SDK embeds the token in the message.
                for token in tokens:
                    if token.token in str(e):
                        token.is_active = False
                        token.updated_at = datetime.now(timezone.utc)
                db.commit()
                return False
            except PushTicketError as e:
                logger.error(f"Push ticket error: {str(e)}")
                return False

        except Exception as e:
            # Last-resort guard: notifications are best-effort and must
            # never propagate an exception to the caller.
            logger.error(f"Error sending push notification: {str(e)}")
            return False

    def send_question_notification(
        self,
        db: Session,
        user_id: UUID,
        instance_id: str,
        question_id: str,
        agent_name: str,
        question_text: str,
    ) -> bool:
        """Send notification for new agent question"""
        # Format agent name for display (e.g. "claude_code" -> "Claude Code")
        display_name = agent_name.replace("_", " ").title()
        title = f"{display_name} needs your input"

        # Truncate question text for notification
        body = question_text
        if len(body) > 100:
            body = body[:97] + "..."

        # Payload fields the mobile client reads to deep-link to the question.
        data = {
            "type": "new_question",
            "instanceId": instance_id,
            "questionId": question_id,
        }

        return self.send_notification(
            db=db,
            user_id=user_id,
            title=title,
            body=body,
            data=data,
        )

    @staticmethod
    def _chunks(lst, n):
        """Yield successive n-sized chunks from lst."""
        for i in range(0, len(lst), n):
            yield lst[i : i + n]


# Singleton instance
push_service = PushNotificationService()
|
||||
1
servers/tests/__init__.py
Normal file
1
servers/tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Test suite for Omnara servers."""
|
||||
126
servers/tests/conftest.py
Normal file
126
servers/tests/conftest.py
Normal file
@@ -0,0 +1,126 @@
|
||||
"""Pytest configuration and fixtures for servers tests."""
|
||||
|
||||
import os
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from unittest.mock import Mock, AsyncMock
|
||||
from datetime import datetime, timezone
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from testcontainers.postgres import PostgresContainer
|
||||
|
||||
from shared.database.models import Base, User, UserAgent, AgentInstance
|
||||
from shared.database.enums import AgentStatus
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def postgres_container():
    """Create a PostgreSQL container for testing - shared across all tests."""
    # Session scope: the container boots once for the entire test run.
    with PostgresContainer("postgres:16-alpine") as postgres:
        yield postgres


@pytest.fixture
def test_db(postgres_container):
    """Create a test database session using PostgreSQL.

    Function-scoped: tables are created fresh for each test and dropped
    afterwards, with one User and one UserAgent committed up front.
    """
    # Get connection URL from container
    db_url = postgres_container.get_connection_url()

    # Create engine and tables
    engine = create_engine(db_url)
    Base.metadata.create_all(bind=engine)

    # Create session
    TestSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
    session = TestSessionLocal()

    # Create test data
    test_user = User(
        id=uuid4(),
        email="test@example.com",
        display_name="Test User",
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )

    test_user_agent = UserAgent(
        id=uuid4(),
        user_id=test_user.id,
        name="Claude Code",
        is_active=True,
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )

    session.add(test_user)
    session.add(test_user_agent)
    session.commit()

    yield session

    # Teardown: close the session and drop all tables so the next test
    # starts from a clean schema on the shared container.
    session.close()
    Base.metadata.drop_all(bind=engine)
    engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
def mock_jwt_payload():
    """Mock JWT payload for authentication tests."""
    # "sub" is a fresh UUID per test so tests never share a user ID.
    return {
        "sub": str(uuid4()),
        "email": "test@example.com",
        "iat": datetime.now(timezone.utc).timestamp(),
        "exp": (datetime.now(timezone.utc).timestamp() + 3600),
    }


@pytest.fixture
def mock_context(test_db, mock_jwt_payload):
    """Mock MCP context with authentication."""
    context = Mock()
    context.user_id = mock_jwt_payload["sub"]
    context.user_email = mock_jwt_payload["email"]
    context.db = test_db
    return context


@pytest_asyncio.fixture
async def async_mock_context(test_db, mock_jwt_payload):
    """Async mock MCP context for async tests."""
    context = AsyncMock()
    context.user_id = mock_jwt_payload["sub"]
    context.user_email = mock_jwt_payload["email"]
    context.db = test_db
    return context


@pytest.fixture
def test_agent_instance(test_db):
    """Create a test agent instance owned by the seeded test user."""
    # Get test user and user agent (the rows committed by test_db).
    user = test_db.query(User).first()
    user_agent = test_db.query(UserAgent).first()

    instance = AgentInstance(
        id=uuid4(),
        user_agent_id=user_agent.id,
        user_id=user.id,
        status=AgentStatus.ACTIVE,
        started_at=datetime.now(timezone.utc),
    )

    test_db.add(instance)
    test_db.commit()

    return instance


@pytest.fixture(autouse=True)
def reset_env():
    """Reset environment variables for each test."""
    # Snapshot the environment before the test; restore it (dropping any
    # additions) afterwards so env mutations cannot leak between tests.
    original_env = os.environ.copy()
    yield
    os.environ.clear()
    os.environ.update(original_env)
|
||||
271
servers/tests/test_integration.py
Normal file
271
servers/tests/test_integration.py
Normal file
@@ -0,0 +1,271 @@
|
||||
"""Integration tests using PostgreSQL."""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
from uuid import uuid4
|
||||
|
||||
# Database fixtures come from conftest.py
|
||||
|
||||
# Import the real models
|
||||
from shared.database.models import (
|
||||
User,
|
||||
UserAgent,
|
||||
AgentInstance,
|
||||
AgentStep,
|
||||
AgentQuestion,
|
||||
AgentUserFeedback,
|
||||
)
|
||||
from shared.database.enums import AgentStatus
|
||||
|
||||
# Import the core functions we want to test
|
||||
from servers.shared.core import (
|
||||
process_log_step,
|
||||
create_agent_question,
|
||||
process_end_session,
|
||||
)
|
||||
|
||||
|
||||
# Using test_db fixture from conftest.py which provides PostgreSQL via testcontainers
|
||||
|
||||
|
||||
@pytest.fixture
def test_user(test_db):
    """Create a test user.

    Uses a dedicated email distinct from the conftest-seeded user.
    """
    user = User(
        id=uuid4(),
        email="integration@test.com",
        display_name="Integration Test User",
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )
    test_db.add(user)
    test_db.commit()
    return user


@pytest.fixture
def test_user_agent(test_db, test_user):
    """Create a test user agent owned by test_user."""
    user_agent = UserAgent(
        id=uuid4(),
        user_id=test_user.id,
        name="Claude Code Test",
        is_active=True,
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )
    test_db.add(user_agent)
    test_db.commit()
    return user_agent
|
||||
|
||||
|
||||
class TestIntegrationFlow:
|
||||
"""Test the complete integration flow with PostgreSQL."""
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_complete_agent_session_flow(self, test_db, test_user, test_user_agent):
|
||||
"""Test a complete agent session from start to finish."""
|
||||
# Step 1: Create new agent instance
|
||||
instance_id, step_number, user_feedback = process_log_step(
|
||||
db=test_db,
|
||||
user_id=str(test_user.id),
|
||||
agent_instance_id=None,
|
||||
agent_type="Claude Code Test",
|
||||
step_description="Starting integration test task",
|
||||
)
|
||||
|
||||
assert instance_id is not None
|
||||
assert step_number == 1
|
||||
assert user_feedback == []
|
||||
|
||||
# Verify instance was created in database
|
||||
instance = test_db.query(AgentInstance).filter_by(id=instance_id).first()
|
||||
assert instance is not None
|
||||
assert instance.status == AgentStatus.ACTIVE
|
||||
assert instance.user_id == test_user.id
|
||||
|
||||
# Step 2: Log another step
|
||||
_, step_number2, _ = process_log_step(
|
||||
db=test_db,
|
||||
user_id=str(test_user.id),
|
||||
agent_instance_id=instance_id,
|
||||
agent_type="Claude Code Test",
|
||||
step_description="Processing files",
|
||||
)
|
||||
|
||||
assert step_number2 == 2
|
||||
|
||||
# Step 3: Create a question
|
||||
question = create_agent_question(
|
||||
db=test_db,
|
||||
agent_instance_id=instance_id,
|
||||
question_text="Should I refactor this module?",
|
||||
user_id=str(test_user.id),
|
||||
)
|
||||
|
||||
assert question is not None
|
||||
question_id = question.id
|
||||
|
||||
# Verify question in database
|
||||
question = test_db.query(AgentQuestion).filter_by(id=question_id).first()
|
||||
assert question is not None
|
||||
assert question.question_text == "Should I refactor this module?"
|
||||
assert question.is_active is True
|
||||
|
||||
# Step 4: Add user feedback
|
||||
feedback = AgentUserFeedback(
|
||||
id=uuid4(),
|
||||
agent_instance_id=instance_id,
|
||||
created_by_user_id=test_user.id,
|
||||
feedback_text="Please use async/await pattern",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
test_db.add(feedback)
|
||||
test_db.commit()
|
||||
|
||||
# Step 5: Next log_step should retrieve feedback
|
||||
_, step_number3, feedback_list = process_log_step(
|
||||
db=test_db,
|
||||
user_id=str(test_user.id),
|
||||
agent_instance_id=instance_id,
|
||||
agent_type="Claude Code Test",
|
||||
step_description="Implementing async pattern",
|
||||
)
|
||||
|
||||
assert step_number3 == 3
|
||||
assert len(feedback_list) == 1
|
||||
assert feedback_list[0] == "Please use async/await pattern"
|
||||
|
||||
# Verify feedback was marked as retrieved
|
||||
test_db.refresh(feedback)
|
||||
assert feedback.retrieved_at is not None
|
||||
|
||||
# Step 6: End the session
|
||||
ended_instance_id, final_status = process_end_session(
|
||||
db=test_db, agent_instance_id=instance_id, user_id=str(test_user.id)
|
||||
)
|
||||
|
||||
assert ended_instance_id == instance_id
|
||||
assert final_status == "completed"
|
||||
|
||||
# Verify final state
|
||||
test_db.refresh(instance)
|
||||
assert instance.status == AgentStatus.COMPLETED
|
||||
assert instance.ended_at is not None
|
||||
|
||||
# Verify questions were deactivated
|
||||
test_db.refresh(question)
|
||||
assert question.is_active is False
|
||||
|
||||
# Verify all steps were logged
|
||||
steps = (
|
||||
test_db.query(AgentStep)
|
||||
.filter_by(agent_instance_id=instance_id)
|
||||
.order_by(AgentStep.step_number)
|
||||
.all()
|
||||
)
|
||||
|
||||
assert len(steps) == 3
|
||||
assert steps[0].description == "Starting integration test task"
|
||||
assert steps[1].description == "Processing files"
|
||||
assert steps[2].description == "Implementing async pattern"
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_multiple_feedback_handling(self, test_db, test_user, test_user_agent):
|
||||
"""Test handling multiple feedback items."""
|
||||
# Create instance
|
||||
instance_id, _, _ = process_log_step(
|
||||
db=test_db,
|
||||
user_id=str(test_user.id),
|
||||
agent_instance_id=None,
|
||||
agent_type="Claude Code Test",
|
||||
step_description="Starting task",
|
||||
)
|
||||
|
||||
# Add multiple feedback items
|
||||
feedback_items = []
|
||||
for i in range(3):
|
||||
feedback = AgentUserFeedback(
|
||||
id=uuid4(),
|
||||
agent_instance_id=instance_id,
|
||||
created_by_user_id=test_user.id,
|
||||
feedback_text=f"Feedback {i + 1}",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
feedback_items.append(feedback)
|
||||
test_db.add(feedback)
|
||||
|
||||
test_db.commit()
|
||||
|
||||
# Next log_step should retrieve all feedback
|
||||
_, _, feedback_list = process_log_step(
|
||||
db=test_db,
|
||||
user_id=str(test_user.id),
|
||||
agent_instance_id=instance_id,
|
||||
agent_type="Claude Code Test",
|
||||
step_description="Processing feedback",
|
||||
)
|
||||
|
||||
assert len(feedback_list) == 3
|
||||
assert set(feedback_list) == {"Feedback 1", "Feedback 2", "Feedback 3"}
|
||||
|
||||
# All feedback should be marked as retrieved
|
||||
for feedback in feedback_items:
|
||||
test_db.refresh(feedback)
|
||||
assert feedback.retrieved_at is not None
|
||||
|
||||
# Subsequent log_step should not retrieve same feedback
|
||||
_, _, feedback_list2 = process_log_step(
|
||||
db=test_db,
|
||||
user_id=str(test_user.id),
|
||||
agent_instance_id=instance_id,
|
||||
agent_type="Claude Code Test",
|
||||
step_description="Continuing work",
|
||||
)
|
||||
|
||||
assert len(feedback_list2) == 0
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_user_agent_creation_and_reuse(self, test_db, test_user):
|
||||
"""Test that user agents are created and reused correctly."""
|
||||
# First call should create a new user agent
|
||||
instance1_id, _, _ = process_log_step(
|
||||
db=test_db,
|
||||
user_id=str(test_user.id),
|
||||
agent_instance_id=None,
|
||||
agent_type="New Agent Type",
|
||||
step_description="First task",
|
||||
)
|
||||
|
||||
# Check user agent was created (name is stored in lowercase)
|
||||
user_agents = (
|
||||
test_db.query(UserAgent)
|
||||
.filter_by(user_id=test_user.id, name="new agent type")
|
||||
.all()
|
||||
)
|
||||
assert len(user_agents) == 1
|
||||
|
||||
# Second call with same agent type should reuse the user agent
|
||||
instance2_id, _, _ = process_log_step(
|
||||
db=test_db,
|
||||
user_id=str(test_user.id),
|
||||
agent_instance_id=None,
|
||||
agent_type="New Agent Type",
|
||||
step_description="Second task",
|
||||
)
|
||||
|
||||
# Should still only have one user agent (name is stored in lowercase)
|
||||
user_agents = (
|
||||
test_db.query(UserAgent)
|
||||
.filter_by(user_id=test_user.id, name="new agent type")
|
||||
.all()
|
||||
)
|
||||
assert len(user_agents) == 1
|
||||
|
||||
# But two different instances
|
||||
assert instance1_id != instance2_id
|
||||
|
||||
# Both instances should reference the same user agent
|
||||
instance1 = test_db.query(AgentInstance).filter_by(id=instance1_id).first()
|
||||
instance2 = test_db.query(AgentInstance).filter_by(id=instance2_id).first()
|
||||
assert instance1.user_agent_id == instance2.user_agent_id
|
||||
162
servers/tests/test_shared_core.py
Normal file
162
servers/tests/test_shared_core.py
Normal file
@@ -0,0 +1,162 @@
|
||||
"""Tests for shared core functionality."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from uuid import uuid4
|
||||
|
||||
from shared.database.models import (
|
||||
AgentStep,
|
||||
AgentQuestion,
|
||||
AgentUserFeedback,
|
||||
)
|
||||
from shared.database.enums import AgentStatus
|
||||
|
||||
|
||||
class TestDatabaseModels:
    """Test database model functionality."""

    def test_create_agent_instance(self, test_db, test_agent_instance):
        """Test creating an agent instance."""
        # A freshly created instance is active and has not ended.
        assert test_agent_instance.id is not None
        assert test_agent_instance.status == AgentStatus.ACTIVE
        assert test_agent_instance.started_at is not None
        assert test_agent_instance.ended_at is None

    def test_create_agent_step(self, test_db, test_agent_instance):
        """Test creating agent steps."""
        descriptions = ["First step", "Second step"]
        test_db.add_all(
            [
                AgentStep(
                    id=uuid4(),
                    agent_instance_id=test_agent_instance.id,
                    step_number=number,
                    description=text,
                    created_at=datetime.now(timezone.utc),
                )
                for number, text in enumerate(descriptions, start=1)
            ]
        )
        test_db.commit()

        # Read the steps back ordered by their sequence number.
        stored = (
            test_db.query(AgentStep)
            .filter_by(agent_instance_id=test_agent_instance.id)
            .order_by(AgentStep.step_number)
            .all()
        )

        assert len(stored) == 2
        assert [(s.step_number, s.description) for s in stored] == [
            (1, "First step"),
            (2, "Second step"),
        ]

    def test_create_agent_question(self, test_db, test_agent_instance):
        """Test creating agent questions."""
        asked = AgentQuestion(
            id=uuid4(),
            agent_instance_id=test_agent_instance.id,
            question_text="Should I continue?",
            asked_at=datetime.now(timezone.utc),
            is_active=True,
        )
        test_db.add(asked)
        test_db.commit()

        # Fetch it back by primary key.
        stored = test_db.query(AgentQuestion).filter_by(id=asked.id).first()

        assert stored is not None
        assert stored.question_text == "Should I continue?"
        assert stored.is_active is True
        # No answer has been recorded yet.
        assert stored.answer_text is None
        assert stored.answered_at is None

    def test_create_user_feedback(self, test_db, test_agent_instance):
        """Test creating user feedback."""
        note = AgentUserFeedback(
            id=uuid4(),
            agent_instance_id=test_agent_instance.id,
            created_by_user_id=test_agent_instance.user_id,
            feedback_text="Please use TypeScript",
            created_at=datetime.now(timezone.utc),
        )
        test_db.add(note)
        test_db.commit()

        stored = (
            test_db.query(AgentUserFeedback).filter_by(id=note.id).first()
        )

        assert stored is not None
        assert stored.feedback_text == "Please use TypeScript"
        # Feedback has not yet been handed to the agent.
        assert stored.retrieved_at is None

    def test_agent_instance_relationships(self, test_db, test_agent_instance):
        """Test agent instance relationships."""
        now = datetime.now(timezone.utc)
        test_db.add_all(
            [
                AgentStep(
                    id=uuid4(),
                    agent_instance_id=test_agent_instance.id,
                    step_number=1,
                    description="Test step",
                    created_at=now,
                ),
                AgentQuestion(
                    id=uuid4(),
                    agent_instance_id=test_agent_instance.id,
                    question_text="Test question?",
                    asked_at=now,
                    is_active=True,
                ),
            ]
        )
        test_db.commit()

        # Reload so the ORM relationship collections are populated.
        test_db.refresh(test_agent_instance)

        assert len(test_agent_instance.steps) == 1
        assert test_agent_instance.steps[0].description == "Test step"

        assert len(test_agent_instance.questions) == 1
        assert test_agent_instance.questions[0].question_text == "Test question?"
|
||||
|
||||
|
||||
class TestAgentStatusTransitions:
    """Test agent status transitions."""

    def _terminate_and_verify(self, test_db, instance, status):
        # Shared flow: end *instance* with *status* and check it persisted.
        instance.status = status
        instance.ended_at = datetime.now(timezone.utc)
        test_db.commit()

        test_db.refresh(instance)
        assert instance.status == status
        assert instance.ended_at is not None

    def test_complete_agent_instance(self, test_db, test_agent_instance):
        """Test completing an agent instance."""
        self._terminate_and_verify(test_db, test_agent_instance, AgentStatus.COMPLETED)

    def test_fail_agent_instance(self, test_db, test_agent_instance):
        """Test failing an agent instance."""
        self._terminate_and_verify(test_db, test_agent_instance, AgentStatus.FAILED)
|
||||
77
shared/README.md
Normal file
77
shared/README.md
Normal file
@@ -0,0 +1,77 @@
|
||||
# Shared
|
||||
|
||||
This directory contains shared infrastructure for database operations and configurations used across the Omnara platform.
|
||||
|
||||
## Purpose
|
||||
|
||||
The shared directory serves as the single source of truth for:
|
||||
- Database schema definitions and models
|
||||
- Database connection management
|
||||
- Configuration settings
|
||||
- Schema migration infrastructure
|
||||
|
||||
## Architecture
|
||||
|
||||
### Database Layer
|
||||
- **ORM**: SQLAlchemy 2.0+ with modern declarative mapping
|
||||
- **Database**: PostgreSQL for reliable, scalable data persistence
|
||||
- **Models**: Centralized schema definitions for all platform entities
|
||||
- **Session Management**: Shared database connection handling
|
||||
|
||||
### Configuration Management
|
||||
- Environment-aware settings (development, production)
|
||||
- Centralized configuration using Pydantic settings
|
||||
- Support for multiple deployment scenarios
|
||||
|
||||
### Schema Migrations
|
||||
- Alembic for version-controlled database schema changes
|
||||
- Automatic migration application during startup
|
||||
- Safe rollback capabilities
|
||||
|
||||
## Database Migrations
|
||||
|
||||
### Essential Commands
|
||||
|
||||
```bash
|
||||
# Apply pending migrations
|
||||
cd shared/
|
||||
alembic upgrade head
|
||||
|
||||
# Create a new migration after model changes
|
||||
alembic revision --autogenerate -m "Description of changes"
|
||||
|
||||
# Check migration status
|
||||
alembic current
|
||||
|
||||
# View migration history
|
||||
alembic history
|
||||
|
||||
# Rollback one migration
|
||||
alembic downgrade -1
|
||||
```
|
||||
|
||||
### Migration Workflow
|
||||
|
||||
1. Modify database models
|
||||
2. Generate migration: `alembic revision --autogenerate -m "Description"`
|
||||
3. Review generated migration file
|
||||
4. Apply migration (automatic on restart or manual with `alembic upgrade head`)
|
||||
5. Commit both model changes and migration files
|
||||
|
||||
**Important**: Always create migrations when changing the database schema. A pre-commit hook enforces this requirement.
|
||||
|
||||
## Key Benefits
|
||||
|
||||
- **Consistency**: Single schema definition prevents drift between services
|
||||
- **Type Safety**: Shared type definitions and enumerations
|
||||
- **Maintainability**: Centralized database operations reduce duplication
|
||||
- **Version Control**: Migration history tracks all schema changes
|
||||
- **Multi-Service**: Both API backend and MCP servers use the same database layer
|
||||
|
||||
## Dependencies
|
||||
|
||||
Core dependencies are managed in `requirements.txt` and include:
|
||||
- SQLAlchemy for ORM functionality
|
||||
- PostgreSQL driver for database connectivity
|
||||
- Pydantic for configuration and validation
|
||||
- Alembic for migration management
|
||||
1
shared/__init__.py
Normal file
1
shared/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Shared package initialization
|
||||
118
shared/alembic.ini
Normal file
118
shared/alembic.ini
Normal file
@@ -0,0 +1,118 @@
|
||||
# A generic, single database configuration.
|
||||
|
||||
[alembic]
|
||||
# path to migration scripts
|
||||
# Use forward slashes (/) also on windows to provide an os agnostic path
|
||||
script_location = alembic
|
||||
|
||||
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
|
||||
# Uncomment the line below if you want the files to be prepended with date and time
|
||||
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
|
||||
# for all available tokens
|
||||
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
|
||||
|
||||
# sys.path path, will be prepended to sys.path if present.
|
||||
# defaults to the current working directory.
|
||||
prepend_sys_path = .
|
||||
|
||||
# timezone to use when rendering the date within the migration file
|
||||
# as well as the filename.
|
||||
# If specified, requires the python>=3.9 or backports.zoneinfo library.
|
||||
# Any required deps can be installed by adding `alembic[tz]` to the pip requirements
|
||||
# string value is passed to ZoneInfo()
|
||||
# leave blank for localtime
|
||||
# timezone =
|
||||
|
||||
# max length of characters to apply to the "slug" field
|
||||
# truncate_slug_length = 40
|
||||
|
||||
# set to 'true' to run the environment during
|
||||
# the 'revision' command, regardless of autogenerate
|
||||
# revision_environment = false
|
||||
|
||||
# set to 'true' to allow .pyc and .pyo files without
|
||||
# a source .py file to be detected as revisions in the
|
||||
# versions/ directory
|
||||
# sourceless = false
|
||||
|
||||
# version location specification; This defaults
|
||||
# to alembic/versions. When using multiple version
|
||||
# directories, initial revisions must be specified with --version-path.
|
||||
# The path separator used here should be the separator specified by "version_path_separator" below.
|
||||
# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions
|
||||
|
||||
# version path separator; As mentioned above, this is the character used to split
|
||||
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
|
||||
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
|
||||
# Valid values for version_path_separator are:
|
||||
#
|
||||
# version_path_separator = :
|
||||
# version_path_separator = ;
|
||||
# version_path_separator = space
|
||||
# version_path_separator = newline
|
||||
version_path_separator = os # Use os.pathsep. Default configuration used for new projects.
|
||||
|
||||
# set to 'true' to search source files recursively
|
||||
# in each "version_locations" directory
|
||||
# new in Alembic version 1.10
|
||||
# recursive_version_locations = false
|
||||
|
||||
# the output encoding used when revision files
|
||||
# are written from script.py.mako
|
||||
# output_encoding = utf-8
|
||||
|
||||
# sqlalchemy.url = driver://user:pass@localhost/dbname
|
||||
# Database URL is now configured in env.py using the shared settings
|
||||
|
||||
|
||||
[post_write_hooks]
|
||||
# post_write_hooks defines scripts or Python functions that are run
|
||||
# on newly generated revision scripts. See the documentation for further
|
||||
# detail and examples
|
||||
|
||||
# format using "black" - use the console_scripts runner, against the "black" entrypoint
|
||||
# hooks = black
|
||||
# black.type = console_scripts
|
||||
# black.entrypoint = black
|
||||
# black.options = -l 79 REVISION_SCRIPT_FILENAME
|
||||
|
||||
# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
|
||||
# hooks = ruff
|
||||
# ruff.type = exec
|
||||
# ruff.executable = %(here)s/.venv/bin/ruff
|
||||
# ruff.options = --fix REVISION_SCRIPT_FILENAME
|
||||
|
||||
# Logging configuration
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
[handlers]
|
||||
keys = console
|
||||
|
||||
[formatters]
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARNING
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = WARNING
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
[logger_alembic]
|
||||
level = INFO
|
||||
handlers =
|
||||
qualname = alembic
|
||||
|
||||
[handler_console]
|
||||
class = StreamHandler
|
||||
args = (sys.stderr,)
|
||||
level = NOTSET
|
||||
formatter = generic
|
||||
|
||||
[formatter_generic]
|
||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||
datefmt = %H:%M:%S
|
||||
81
shared/alembic/env.py
Normal file
81
shared/alembic/env.py
Normal file
@@ -0,0 +1,81 @@
|
||||
from logging.config import fileConfig

from sqlalchemy import engine_from_config
from sqlalchemy import pool

from alembic import context

# Import our models and settings
# (the project's shared Base metadata drives autogenerate).
from database.models import Base
from config.settings import settings

# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config

# Set the database URL from our settings
# (alembic.ini deliberately leaves sqlalchemy.url unset so the
# environment-aware settings object is the single source of truth).
config.set_main_option("sqlalchemy.url", settings.database_url)

# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
    fileConfig(config.config_file_name)

# add your model's MetaData object here
# for 'autogenerate' support
target_metadata = Base.metadata

# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.


def run_migrations_offline() -> None:
    """Run migrations in 'offline' mode.

    This configures the context with just a URL
    and not an Engine, though an Engine is acceptable
    here as well. By skipping the Engine creation
    we don't even need a DBAPI to be available.

    Calls to context.execute() here emit the given string to the
    script output.

    """
    url = config.get_main_option("sqlalchemy.url")
    context.configure(
        url=url,
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
    )

    with context.begin_transaction():
        context.run_migrations()


def run_migrations_online() -> None:
    """Run migrations in 'online' mode.

    In this scenario we need to create an Engine
    and associate a connection with the context.

    """
    # NullPool: each migration run opens a fresh connection and
    # releases it on exit, so no pooled connections linger.
    connectable = engine_from_config(
        config.get_section(config.config_ini_section, {}),
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
    )

    with connectable.connect() as connection:
        context.configure(connection=connection, target_metadata=target_metadata)

        with context.begin_transaction():
            context.run_migrations()


if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()
|
||||
26
shared/alembic/script.py.mako
Normal file
26
shared/alembic/script.py.mako
Normal file
@@ -0,0 +1,26 @@
|
||||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = ${repr(up_revision)}
|
||||
down_revision: Union[str, None] = ${repr(down_revision)}
|
||||
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
${downgrades if downgrades else "pass"}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user