Refactor (#45)

* progress

* refactor

* backend refactor

* test

* minor stdio changes

* streeeeeaming

* stream statuses

* change required user input

* something

* progress

* new beginnings

* progress

* edge case

* uuhh the claude wrapper def worked before this commit

* progress

* ai slop that works

* tests

* tests

* plan mode handling

* progress

* clean up

* bump

* migrations

* readmes

* no text truncation

* consistent poll interval

* fix time

---------

Co-authored-by: Kartik Sarangmath <kartiksarangmath@Kartiks-MacBook-Air.local>
This commit is contained in:
ksarangmath
2025-08-04 01:44:42 -07:00
committed by GitHub
parent 140be5b512
commit eaca5a0ad0
44 changed files with 5178 additions and 1816 deletions

View File

@@ -40,7 +40,8 @@ omnara/
- **PostgreSQL** with **SQLAlchemy 2.0+**
- **Alembic** for migrations - ALWAYS create migrations for schema changes
- Multi-tenant design - all data is scoped by user_id
- Key tables: users, agent_types, agent_instances, agent_steps, agent_questions, agent_user_feedback, api_keys
- Key tables: users, user_agents, agent_instances, messages, api_keys
- **Unified messaging system**: All agent interactions (steps, questions, feedback) are now stored in the `messages` table with `sender_type` and `requires_user_input` fields
### Server Architecture
- **Unified server** (`servers/app.py`) supports both MCP and REST
@@ -107,6 +108,13 @@ make test-integration # Integration tests (needs Docker)
## Common Tasks
### Working with Messages
The unified messaging system uses a single `messages` table:
- **Agent messages**: Set `sender_type=AGENT`, use `requires_user_input=True` for questions
- **User messages**: Set `sender_type=USER` for feedback/responses
- **Reading messages**: Use `last_read_message_id` to track reading progress
- **Queued messages**: Agent receives unread user messages when sending new messages
### Adding a New API Endpoint
1. Add route in `backend/api/` or `servers/fastapi_server/routers.py`
2. Create Pydantic models for request/response in `models.py`

View File

@@ -186,25 +186,27 @@ client = OmnaraClient(api_key="your-api-key")
instance_id = str(uuid.uuid4())
# Log progress and check for user feedback
response = client.log_step(
response = client.send_message(
agent_type="claude-code",
step_description="Analyzing codebase structure",
agent_instance_id=instance_id
content="Analyzing codebase structure",
agent_instance_id=instance_id,
requires_user_input=False
)
# Ask for user input when needed
answer = client.ask_question(
question="Should I refactor this legacy module?",
agent_instance_id=instance_id
answer = client.send_message(
content="Should I refactor this legacy module?",
agent_instance_id=instance_id,
requires_user_input=True
)
```
### Method 3: REST API
```bash
curl -X POST https://api.omnara.ai/api/v1/steps \
curl -X POST https://api.omnara.ai/api/v1/messages/agent \
-H "Authorization: Bearer YOUR_API_KEY" \
-H "Content-Type: application/json" \
-d '{"step_description": "Starting deployment process"}'
-d '{"content": "Starting deployment process", "agent_type": "claude-code", "requires_user_input": false}'
```
## 🤝 Contributing

View File

@@ -27,7 +27,7 @@ The backend provides a REST API for accessing and managing agent-related data. I
## Key Features
- **Agent Monitoring** - View agent types, instances, and execution history
- **User Interactions** - Handle questions from agents and user feedback
- **Unified Messaging** - All agent interactions (steps, questions, feedback) through a single messaging system
- **Multi-tenancy** - User-scoped data isolation and access control
- **Authentication** - Support for both web dashboard users and programmatic agent access
- **User Agent Management** - Custom agent configurations and webhook integrations

View File

@@ -1,10 +1,15 @@
from uuid import UUID
import asyncio
import json
from typing import AsyncGenerator
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import StreamingResponse
from shared.database.models import User
from shared.database.session import get_db
from shared.database.enums import AgentStatus
from sqlalchemy.orm import Session
import asyncpg
from ..auth.dependencies import get_current_user
from ..db import (
@@ -14,14 +19,14 @@ from ..db import (
get_all_agent_instances,
get_all_agent_types_with_instances,
mark_instance_completed,
submit_user_feedback,
submit_user_message,
)
from ..models import (
AgentInstanceDetail,
AgentInstanceResponse,
AgentTypeOverview,
UserFeedbackRequest,
UserFeedbackResponse,
UserMessageRequest,
UserMessageResponse,
)
router = APIRouter(tags=["agents"])
@@ -87,21 +92,129 @@ async def get_instance_detail(
@router.post(
"/agent-instances/{instance_id}/feedback", response_model=UserFeedbackResponse
"/agent-instances/{instance_id}/messages", response_model=UserMessageResponse
)
async def add_user_feedback(
async def create_user_message(
instance_id: UUID,
request: UserFeedbackRequest,
request: UserMessageRequest,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""Submit user feedback for an agent instance for the current user"""
result = submit_user_feedback(db, instance_id, request.feedback, current_user.id)
"""Send a message to an agent instance (answers questions or provides feedback)"""
result = submit_user_message(db, instance_id, request.content, current_user.id)
if not result:
raise HTTPException(status_code=404, detail="Agent instance not found")
return result
@router.get("/agent-instances/{instance_id}/messages/stream")
async def stream_messages(
request: Request,
instance_id: UUID,
token: str | None = None,
db: Session = Depends(get_db),
):
"""Stream new messages for an agent instance using Server-Sent Events"""
# Handle SSE authentication - token comes from query param
if not token:
raise HTTPException(status_code=401, detail="Token required for SSE")
try:
# Verify token and get user
from ..auth.supabase_client import get_supabase_anon_client
supabase = get_supabase_anon_client()
user_response = supabase.auth.get_user(token)
if not user_response or not user_response.user:
raise HTTPException(status_code=401, detail="Invalid token")
user_id = UUID(user_response.user.id)
except Exception as e:
raise HTTPException(status_code=401, detail=f"Authentication failed: {str(e)}")
# Verify the user has access to this instance
instance = get_agent_instance_detail(db, instance_id, user_id)
if not instance:
raise HTTPException(status_code=404, detail="Agent instance not found")
async def message_generator() -> AsyncGenerator[str, None]:
# Import settings here to avoid circular imports
from shared.config.settings import settings
# Create connection to PostgreSQL for LISTEN/NOTIFY
conn = await asyncpg.connect(settings.database_url)
try:
# Listen to the channel for this instance
channel_name = f"message_channel_{instance_id}"
# Execute LISTEN command (quote channel name for UUIDs with hyphens)
await conn.execute(f'LISTEN "{channel_name}"')
# Create a queue to receive notifications
notification_queue = asyncio.Queue()
# Define callback to put notifications in queue
def notification_callback(connection, pid, channel, payload):
asyncio.create_task(notification_queue.put(payload))
# Add listener with callback
await conn.add_listener(channel_name, notification_callback)
# Send initial connection event
yield f"event: connected\ndata: {json.dumps({'instance_id': str(instance_id)})}\n\n"
while True:
# Check if client disconnected
if await request.is_disconnected():
break
try:
# Wait for notification with timeout for heartbeat
payload = await asyncio.wait_for(
notification_queue.get(), timeout=30.0
)
# Parse the JSON payload
data = json.loads(payload)
# Check event type and send appropriate SSE event
event_type = data.get("event_type")
if event_type == "status_update":
# Send status_update event
yield f"event: status_update\ndata: {json.dumps(data)}\n\n"
elif event_type == "message_update":
# Send message_update event for frontend to handle
yield f"event: message_update\ndata: {json.dumps(data)}\n\n"
else:
# Regular message event (either message_insert or legacy without event_type)
yield f"event: message\ndata: {json.dumps(data)}\n\n"
except asyncio.TimeoutError:
# Send heartbeat to keep connection alive
yield f"event: heartbeat\ndata: {json.dumps({'timestamp': asyncio.get_event_loop().time()})}\n\n"
except Exception as e:
# Send error event
yield f"event: error\ndata: {json.dumps({'error': str(e)})}\n\n"
break
finally:
# Clean up listener and connection
await conn.remove_listener(channel_name, notification_callback)
await conn.close()
return StreamingResponse(
message_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no", # Disable Nginx buffering
},
)
@router.put(
"/agent-instances/{instance_id}/status",
response_model=AgentInstanceResponse,

View File

@@ -1,28 +0,0 @@
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException
from shared.database.models import User
from shared.database.session import get_db
from sqlalchemy.orm import Session
from ..auth.dependencies import get_current_user
from ..db import submit_answer
from ..models import AnswerRequest
router = APIRouter(prefix="/questions", tags=["questions"])
@router.post("/{question_id}/answer")
async def answer_question(
question_id: UUID,
request: AnswerRequest,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""Submit an answer to a pending question for the current user"""
result = submit_answer(db, question_id, request.answer, current_user.id)
if not result:
raise HTTPException(
status_code=404, detail="Question not found or already answered"
)
return {"success": True, "message": "Answer submitted successfully"}

View File

@@ -5,8 +5,7 @@ from .queries import (
get_all_agent_types_with_instances,
get_agent_summary,
mark_instance_completed,
submit_answer,
submit_user_feedback,
submit_user_message,
)
from .user_agent_queries import (
create_user_agent,
@@ -24,8 +23,7 @@ __all__ = [
"get_agent_summary",
"get_agent_instance_detail",
"mark_instance_completed",
"submit_answer",
"submit_user_feedback",
"submit_user_message",
"create_user_agent",
"get_user_agents",
"update_user_agent",

View File

@@ -5,12 +5,11 @@ from uuid import UUID
from shared.config import settings
from shared.database import (
AgentInstance,
AgentQuestion,
AgentStatus,
AgentStep,
AgentUserFeedback,
APIKey,
Message,
PushToken,
SenderType,
User,
UserAgent,
)
@@ -19,52 +18,48 @@ from shared.database.subscription_models import BillingEvent, Subscription
from sqlalchemy import desc, func
from sqlalchemy.orm import Session, joinedload
# Import Pydantic models for type-safe returns
from backend.models import (
AgentInstanceResponse,
AgentInstanceDetail,
MessageResponse,
UserMessageResponse,
AgentTypeOverview,
)
def _format_instance(instance: AgentInstance) -> dict:
def _format_instance(instance: AgentInstance) -> AgentInstanceResponse:
"""Helper function to format an agent instance consistently"""
# Get latest step
latest_step = None
if instance.steps:
latest_step = max(instance.steps, key=lambda s: s.created_at).description
# Get all messages for this instance
messages = instance.messages if hasattr(instance, "messages") else []
# Get step count
step_count = len(instance.steps) if instance.steps else 0
# Get latest message and its timestamp
latest_message = None
latest_message_at = None
if messages:
last_msg = max(messages, key=lambda m: m.created_at)
latest_message = last_msg.content
latest_message_at = last_msg.created_at
# Check for pending questions
pending_questions = [q for q in instance.questions if q.is_active]
pending_questions_count = len(pending_questions)
has_pending = pending_questions_count > 0
pending_age = None
if has_pending:
oldest_pending = min(pending_questions, key=lambda q: q.asked_at)
# All database times are stored as UTC but may be naive
now_utc = datetime.now(timezone.utc)
asked_at = oldest_pending.asked_at
if asked_at.tzinfo is None:
asked_at = asked_at.replace(tzinfo=timezone.utc)
pending_age = int((now_utc - asked_at).total_seconds())
# Get total message count (chat length)
chat_length = len(messages)
return {
"id": str(instance.id),
"agent_type_id": str(instance.user_agent_id) if instance.user_agent_id else "",
"agent_type_name": instance.user_agent.name
if instance.user_agent
else "Unknown",
"status": instance.status,
"started_at": instance.started_at,
"ended_at": instance.ended_at,
"latest_step": latest_step,
"has_pending_question": has_pending,
"pending_question_age": pending_age,
"pending_questions_count": pending_questions_count,
"step_count": step_count,
"last_signal_at": instance.steps[-1].created_at
if instance.steps
else instance.started_at,
}
return AgentInstanceResponse(
id=str(instance.id),
agent_type_id=str(instance.user_agent_id) if instance.user_agent_id else "",
agent_type_name=instance.user_agent.name if instance.user_agent else "Unknown",
status=instance.status,
started_at=instance.started_at,
ended_at=instance.ended_at,
latest_message=latest_message,
latest_message_at=latest_message_at,
chat_length=chat_length,
)
def get_all_agent_types_with_instances(db: Session, user_id: UUID) -> list[dict]:
def get_all_agent_types_with_instances(
db: Session, user_id: UUID
) -> list[AgentTypeOverview]:
"""Get all user agents with their instances for a specific user"""
# Get all user agents for this user
@@ -79,25 +74,29 @@ def get_all_agent_types_with_instances(db: Session, user_id: UUID) -> list[dict]
AgentInstance.user_agent_id == user_agent.id,
)
.options(
joinedload(AgentInstance.steps),
joinedload(AgentInstance.questions),
joinedload(AgentInstance.messages),
joinedload(AgentInstance.user_agent),
)
.all()
)
# Sort instances: pending questions first, then by most recent activity
# Sort instances: AWAITING_INPUT instances first, then by most recent activity
def sort_key(instance):
pending_questions = [q for q in instance.questions if q.is_active]
if pending_questions:
oldest_question = min(pending_questions, key=lambda q: q.asked_at)
return (0, oldest_question.asked_at)
messages = instance.messages if hasattr(instance, "messages") else []
# If instance is awaiting input, prioritize it
if instance.status == AgentStatus.AWAITING_INPUT:
# Sort by when the question was asked (last message time)
if messages:
last_msg_time = max(messages, key=lambda m: m.created_at).created_at
return (0, last_msg_time)
else:
return (0, instance.started_at)
# Otherwise sort by last activity
last_activity = instance.started_at
if instance.steps:
last_activity = max(
instance.steps, key=lambda s: s.created_at
).created_at
if messages:
last_activity = max(messages, key=lambda m: m.created_at).created_at
return (1, -last_activity.timestamp())
sorted_instances = sorted(instances, key=sort_key)
@@ -108,16 +107,16 @@ def get_all_agent_types_with_instances(db: Session, user_id: UUID) -> list[dict]
]
result.append(
{
"id": str(user_agent.id),
"name": user_agent.name,
"created_at": user_agent.created_at,
"recent_instances": formatted_instances,
"total_instances": len(instances),
"active_instances": sum(
AgentTypeOverview(
id=str(user_agent.id),
name=user_agent.name,
created_at=user_agent.created_at,
recent_instances=formatted_instances,
total_instances=len(instances),
active_instances=sum(
1 for i in instances if i.status == AgentStatus.ACTIVE
),
}
)
)
return result
@@ -125,15 +124,14 @@ def get_all_agent_types_with_instances(db: Session, user_id: UUID) -> list[dict]
def get_all_agent_instances(
db: Session, user_id: UUID, limit: int | None = None
) -> list[dict]:
) -> list[AgentInstanceResponse]:
"""Get all agent instances for a specific user, sorted by most recent activity"""
query = (
db.query(AgentInstance)
.filter(AgentInstance.user_id == user_id)
.options(
joinedload(AgentInstance.steps),
joinedload(AgentInstance.questions),
joinedload(AgentInstance.messages),
joinedload(AgentInstance.user_agent),
)
.order_by(desc(AgentInstance.started_at))
@@ -216,7 +214,7 @@ def get_agent_summary(db: Session, user_id: UUID) -> dict:
def get_agent_type_instances(
db: Session, agent_type_id: UUID, user_id: UUID
) -> list[dict] | None:
) -> list[AgentInstanceResponse] | None:
"""Get all instances for a specific user agent"""
user_agent = (
@@ -233,8 +231,7 @@ def get_agent_type_instances(
AgentInstance.user_agent_id == agent_type_id,
)
.options(
joinedload(AgentInstance.steps),
joinedload(AgentInstance.questions),
joinedload(AgentInstance.messages),
joinedload(AgentInstance.user_agent),
)
.order_by(desc(AgentInstance.started_at))
@@ -247,7 +244,7 @@ def get_agent_type_instances(
def get_agent_instance_detail(
db: Session, instance_id: UUID, user_id: UUID
) -> dict | None:
) -> AgentInstanceDetail | None:
"""Get detailed information about a specific agent instance for a specific user"""
instance = (
@@ -255,9 +252,7 @@ def get_agent_instance_detail(
.filter(AgentInstance.id == instance_id, AgentInstance.user_id == user_id)
.options(
joinedload(AgentInstance.user_agent),
joinedload(AgentInstance.steps),
joinedload(AgentInstance.questions),
joinedload(AgentInstance.user_feedback),
joinedload(AgentInstance.messages),
)
.first()
)
@@ -265,122 +260,45 @@ def get_agent_instance_detail(
if not instance:
return None
# Sort steps by step number
sorted_steps = sorted(instance.steps, key=lambda s: s.step_number)
# Sort questions by asked_at
sorted_questions = sorted(instance.questions, key=lambda q: q.asked_at)
# Sort user feedback by created_at
sorted_feedback = sorted(instance.user_feedback, key=lambda f: f.created_at)
return {
"id": str(instance.id),
"agent_type_id": str(instance.user_agent_id) if instance.user_agent_id else "",
"agent_type": {
"id": str(instance.user_agent.id) if instance.user_agent else "",
"name": instance.user_agent.name if instance.user_agent else "Unknown",
"created_at": instance.user_agent.created_at
if instance.user_agent
else datetime.now(timezone.utc),
"recent_instances": [],
"total_instances": 0,
"active_instances": 0,
},
"status": instance.status,
"started_at": instance.started_at,
"ended_at": instance.ended_at,
"git_diff": instance.git_diff,
"steps": [
{
"id": str(step.id),
"step_number": step.step_number,
"description": step.description,
"created_at": step.created_at,
}
for step in sorted_steps
],
"questions": [
{
"id": str(question.id),
"question_text": question.question_text,
"answer_text": question.answer_text,
"asked_at": question.asked_at,
"answered_at": question.answered_at,
"is_active": question.is_active,
}
for question in sorted_questions
],
"user_feedback": [
{
"id": str(feedback.id),
"feedback_text": feedback.feedback_text,
"created_at": feedback.created_at,
"retrieved_at": feedback.retrieved_at,
}
for feedback in sorted_feedback
],
}
def submit_answer(
db: Session, question_id: UUID, answer: str, user_id: UUID
) -> dict | None:
"""Submit an answer to a question for a specific user"""
question = (
db.query(AgentQuestion)
.filter(AgentQuestion.id == question_id, AgentQuestion.is_active)
.join(AgentInstance)
.filter(AgentInstance.user_id == user_id)
.first()
# Get all messages and sort by created_at
messages = (
sorted(instance.messages, key=lambda m: m.created_at)
if hasattr(instance, "messages")
else []
)
if not question:
return None
question.answer_text = answer
question.answered_at = datetime.now(timezone.utc)
question.is_active = False
question.answered_by_user_id = user_id
# Update agent instance status back to ACTIVE if it was AWAITING_INPUT
instance = (
db.query(AgentInstance)
.filter(AgentInstance.id == question.agent_instance_id)
.first()
)
if instance and instance.status == AgentStatus.AWAITING_INPUT:
# Check if there are other active questions for this instance
other_active_questions = (
db.query(AgentQuestion)
.filter(
AgentQuestion.agent_instance_id == instance.id,
AgentQuestion.id != question_id,
AgentQuestion.is_active,
# Format messages for chat display
formatted_messages = []
for msg in messages:
formatted_messages.append(
MessageResponse(
id=str(msg.id),
content=msg.content,
sender_type=msg.sender_type.value, # "agent" or "user"
created_at=msg.created_at,
requires_user_input=msg.requires_user_input,
)
.count()
)
# Only change status back to ACTIVE if no other questions are pending
if other_active_questions == 0:
instance.status = AgentStatus.ACTIVE
db.commit()
return {
"id": str(question.id),
"question_text": question.question_text,
"answer_text": question.answer_text,
"asked_at": question.asked_at,
"answered_at": question.answered_at,
"is_active": question.is_active,
}
return AgentInstanceDetail(
id=str(instance.id),
agent_type_id=str(instance.user_agent_id) if instance.user_agent_id else "",
agent_type_name=instance.user_agent.name if instance.user_agent else "Unknown",
status=instance.status,
started_at=instance.started_at,
ended_at=instance.ended_at,
git_diff=instance.git_diff,
messages=formatted_messages,
last_read_message_id=str(instance.last_read_message_id)
if instance.last_read_message_id
else None,
)
def submit_user_feedback(
db: Session, instance_id: UUID, feedback_text: str, user_id: UUID
) -> dict | None:
"""Submit user feedback for an agent instance for a specific user"""
def submit_user_message(
db: Session, instance_id: UUID, content: str, user_id: UUID
) -> UserMessageResponse | None:
"""Submit a user message to an agent instance"""
# Check if instance exists and belongs to user
instance = (
@@ -391,28 +309,33 @@ def submit_user_feedback(
if not instance:
return None
# Create new feedback
feedback = AgentUserFeedback(
# Create new user message
user_message = Message(
agent_instance_id=instance_id,
feedback_text=feedback_text,
created_by_user_id=user_id,
sender_type=SenderType.USER,
content=content,
requires_user_input=False,
)
db.add(user_message)
if instance.status != AgentStatus.COMPLETED:
instance.status = AgentStatus.ACTIVE
db.add(feedback)
db.commit()
db.refresh(feedback)
db.refresh(user_message)
return {
"id": str(feedback.id),
"feedback_text": feedback.feedback_text,
"created_at": feedback.created_at,
"retrieved_at": feedback.retrieved_at,
}
return UserMessageResponse(
id=str(user_message.id),
content=user_message.content,
sender_type=user_message.sender_type.value,
created_at=user_message.created_at,
requires_user_input=user_message.requires_user_input,
)
def mark_instance_completed(
db: Session, instance_id: UUID, user_id: UUID
) -> dict | None:
) -> AgentInstanceResponse | None:
"""Mark an agent instance as completed for a specific user"""
# Check if instance exists and belongs to user
@@ -428,10 +351,7 @@ def mark_instance_completed(
instance.status = AgentStatus.COMPLETED
instance.ended_at = datetime.now(timezone.utc)
# Deactivate any pending questions
db.query(AgentQuestion).filter(
AgentQuestion.agent_instance_id == instance_id, AgentQuestion.is_active
).update({"is_active": False})
# No need to deactivate questions - they're handled by checking for user responses
db.commit()
@@ -441,8 +361,7 @@ def mark_instance_completed(
.filter(AgentInstance.id == instance_id)
.options(
joinedload(AgentInstance.user_agent),
joinedload(AgentInstance.steps),
joinedload(AgentInstance.questions),
joinedload(AgentInstance.messages),
)
.first()
)
@@ -483,16 +402,6 @@ def delete_user_account(db: Session, user_id: UUID) -> None:
)
# Delete in order of foreign key dependencies
# 1. Delete AgentUserFeedback (depends on AgentInstance and User)
db.query(AgentUserFeedback).filter(
AgentUserFeedback.created_by_user_id == user_id
).delete(synchronize_session=False)
# 2. Delete AgentQuestions (depends on AgentInstance and User)
db.query(AgentQuestion).filter(
AgentQuestion.answered_by_user_id == user_id
).delete(synchronize_session=False)
# Get all agent instances for this user to delete their related data
instance_ids = [
instance.id
@@ -502,19 +411,9 @@ def delete_user_account(db: Session, user_id: UUID) -> None:
]
if instance_ids:
# Delete questions for user's instances
db.query(AgentQuestion).filter(
AgentQuestion.agent_instance_id.in_(instance_ids)
).delete(synchronize_session=False)
# Delete steps for user's instances
db.query(AgentStep).filter(
AgentStep.agent_instance_id.in_(instance_ids)
).delete(synchronize_session=False)
# Delete feedback for user's instances
db.query(AgentUserFeedback).filter(
AgentUserFeedback.agent_instance_id.in_(instance_ids)
# Delete messages for user's instances
db.query(Message).filter(
Message.agent_instance_id.in_(instance_ids)
).delete(synchronize_session=False)
# 3. Delete AgentInstances (depends on UserAgent and User)

View File

@@ -12,9 +12,7 @@ from shared.database import (
AgentInstance,
AgentStatus,
APIKey,
AgentStep,
AgentQuestion,
AgentUserFeedback,
Message,
)
from shared.database.billing_operations import check_agent_limit
from sqlalchemy import and_, func
@@ -157,6 +155,7 @@ async def trigger_webhook_agent(
payload = {
"agent_instance_id": str(agent_instance_id),
"prompt": prompt,
"agent_type": user_agent.name,
}
if name is not None:
@@ -281,9 +280,7 @@ def get_user_agent_instances(db: Session, agent_id: UUID, user_id: UUID) -> list
instances = (
db.query(AgentInstance)
.options(
joinedload(AgentInstance.questions),
joinedload(AgentInstance.steps),
joinedload(AgentInstance.user_feedback),
joinedload(AgentInstance.messages),
joinedload(AgentInstance.user_agent),
)
.filter(AgentInstance.user_agent_id == agent_id)
@@ -315,18 +312,8 @@ def delete_user_agent(db: Session, agent_id: UUID, user_id: UUID) -> bool:
# For each agent instance, delete all related data
for instance in agent_instances:
# Delete agent steps
db.query(AgentStep).filter(AgentStep.agent_instance_id == instance.id).delete()
# Delete agent questions
db.query(AgentQuestion).filter(
AgentQuestion.agent_instance_id == instance.id
).delete()
# Delete user feedback
db.query(AgentUserFeedback).filter(
AgentUserFeedback.agent_instance_id == instance.id
).delete()
# Delete messages
db.query(Message).filter(Message.agent_instance_id == instance.id).delete()
# Delete all agent instances
db.query(AgentInstance).filter(AgentInstance.user_agent_id == agent_id).delete()

View File

@@ -9,7 +9,6 @@ import sentry_sdk
from shared.config import settings
from .api import (
agents,
questions,
user_agents,
push_notifications,
billing,
@@ -73,7 +72,6 @@ app.add_middleware(
# Include routers with versioned API prefix
app.include_router(auth_routes.router, prefix=settings.api_v1_prefix)
app.include_router(agents.router, prefix=settings.api_v1_prefix)
app.include_router(questions.router, prefix=settings.api_v1_prefix)
app.include_router(user_agents.router, prefix=settings.api_v1_prefix)
app.include_router(push_notifications.router, prefix=settings.api_v1_prefix)
app.include_router(user_settings.router, prefix=settings.api_v1_prefix)

View File

@@ -13,31 +13,24 @@ from pydantic import BaseModel, ConfigDict, Field, field_serializer
from shared.database.enums import AgentStatus
# ============================================================================
# Question Models
# Message Models
# ============================================================================
# This is when the Agent prompts the user for an answer and this format
# is what the user responds with.
class AnswerRequest(BaseModel):
answer: str = Field(..., description="User's answer to the question")
# Unified message models
class UserMessageRequest(BaseModel):
content: str = Field(..., description="Message content from the user")
# User feedback that agents can retrieve during their operations
class UserFeedbackRequest(BaseModel):
feedback: str = Field(..., description="User's feedback or additional information")
class UserFeedbackResponse(BaseModel):
class UserMessageResponse(BaseModel):
id: str
feedback_text: str
content: str
sender_type: str
created_at: datetime
retrieved_at: datetime | None
requires_user_input: bool
@field_serializer("created_at", "retrieved_at")
def serialize_datetime(self, dt: datetime | None, _info):
if dt is None:
return None
@field_serializer("created_at")
def serialize_datetime(self, dt: datetime, _info):
return dt.isoformat() + "Z"
model_config = ConfigDict(from_attributes=True)
@@ -75,20 +68,6 @@ class UserNotificationSettingsResponse(BaseModel):
# ============================================================================
# Represents individual steps/actions taken by an agent
class AgentStepResponse(BaseModel):
id: str
step_number: int
description: str
created_at: datetime
@field_serializer("created_at")
def serialize_datetime(self, dt: datetime, _info):
return dt.isoformat() + "Z"
model_config = ConfigDict(from_attributes=True)
# Summary view of an agent instance (a single agent session/run)
class AgentInstanceResponse(BaseModel):
id: str
@@ -97,13 +76,11 @@ class AgentInstanceResponse(BaseModel):
status: AgentStatus
started_at: datetime
ended_at: datetime | None
latest_step: str | None = None
has_pending_question: bool = False
pending_question_age: int | None = None # Age in seconds
pending_questions_count: int = 0
step_count: int = 0
latest_message: str | None = None
latest_message_at: datetime | None = None # Timestamp of the latest message
chat_length: int = 0 # Total message count
@field_serializer("started_at", "ended_at")
@field_serializer("started_at", "ended_at", "latest_message_at")
def serialize_datetime(self, dt: datetime | None, _info):
if dt is None:
return None
@@ -134,37 +111,33 @@ class AgentTypeOverview(BaseModel):
# ============================================================================
# Detailed information about a question asked by an agent, including answer status
class QuestionDetail(BaseModel):
# Message model for the chat interface
class MessageResponse(BaseModel):
id: str
question_text: str
answer_text: str | None
asked_at: datetime
answered_at: datetime | None
is_active: bool
content: str
sender_type: str
created_at: datetime
requires_user_input: bool
@field_serializer("asked_at", "answered_at")
def serialize_datetime(self, dt: datetime | None, _info):
if dt is None:
return None
@field_serializer("created_at")
def serialize_datetime(self, dt: datetime, _info):
return dt.isoformat() + "Z"
model_config = ConfigDict(from_attributes=True)
# Complete detailed view of a specific agent instance
# with full step and question history
# with full message history
class AgentInstanceDetail(BaseModel):
id: str
agent_type_id: str
agent_type: AgentTypeOverview
agent_type_name: str
status: AgentStatus
started_at: datetime
ended_at: datetime | None
git_diff: str | None = None
steps: list[AgentStepResponse] = []
questions: list[QuestionDetail] = []
user_feedback: list[UserFeedbackResponse] = []
messages: list[MessageResponse] = []
last_read_message_id: str | None = None
@field_serializer("started_at", "ended_at")
def serialize_datetime(self, dt: datetime | None, _info):

View File

@@ -7,4 +7,5 @@ cryptography==42.0.5
email-validator==2.1.0
exponent-server-sdk>=2.1.0
stripe==10.8.0
asyncpg==0.29.0
-r ../shared/requirements.txt

View File

@@ -7,9 +7,6 @@ from shared.database.models import (
User,
UserAgent,
AgentInstance,
AgentStep,
AgentQuestion,
AgentUserFeedback,
)
from shared.database.enums import AgentStatus
@@ -74,20 +71,23 @@ class TestAgentEndpoints:
id=uuid4(),
user_agent_id=test_user_agent.id,
user_id=test_user.id,
status=AgentStatus.ACTIVE,
status=AgentStatus.AWAITING_INPUT,
started_at=datetime.now(timezone.utc),
)
test_db.add(instance)
# Create a pending question with timezone-aware datetime
question = AgentQuestion(
# Create a message with requires_user_input=True to simulate a question
from shared.database import Message, SenderType
question_msg = Message(
id=uuid4(),
agent_instance_id=instance.id,
question_text="Test question?",
asked_at=datetime.now(timezone.utc),
is_active=True,
sender_type=SenderType.AGENT,
content="Test question?",
requires_user_input=True,
created_at=datetime.now(timezone.utc),
)
test_db.add(question)
test_db.add(question_msg)
test_db.commit()
# This should not raise a timezone error
@@ -99,12 +99,10 @@ class TestAgentEndpoints:
agent_type = data[0]
assert len(agent_type["recent_instances"]) == 1
# Check that pending question info is populated
# Check that the instance has AWAITING_INPUT status
instance_data = agent_type["recent_instances"][0]
assert instance_data["has_pending_question"] is True
assert instance_data["pending_questions_count"] == 1
assert instance_data["pending_question_age"] is not None
assert instance_data["pending_question_age"] >= 0
assert instance_data["status"] == "AWAITING_INPUT"
assert instance_data["latest_message"] == "Test question?"
def test_list_all_agent_instances(self, authenticated_client, test_agent_instance):
"""Test listing all agent instances."""
@@ -115,7 +113,7 @@ class TestAgentEndpoints:
assert len(data) == 1
instance = data[0]
assert instance["id"] == str(test_agent_instance.id)
assert instance["status"] == "active"
assert instance["status"] == "ACTIVE"
def test_list_agent_instances_with_limit(
self, authenticated_client, test_db, test_user, test_user_agent
@@ -158,15 +156,19 @@ class TestAgentEndpoints:
)
test_db.add(completed_instance)
# Add a question to the active instance
question = AgentQuestion(
# Add a message with requires_user_input to the active instance
from shared.database import Message, SenderType
question_msg = Message(
id=uuid4(),
agent_instance_id=test_agent_instance.id,
question_text="Test question?",
asked_at=datetime.now(timezone.utc),
is_active=True,
sender_type=SenderType.AGENT,
content="Test question?",
requires_user_input=True,
created_at=datetime.now(timezone.utc),
)
test_db.add(question)
test_db.add(question_msg)
test_agent_instance.status = AgentStatus.AWAITING_INPUT
test_db.commit()
response = authenticated_client.get("/api/v1/agent-summary")
@@ -174,7 +176,7 @@ class TestAgentEndpoints:
data = response.json()
assert data["total_instances"] == 2
assert data["active_instances"] == 1
assert data["active_instances"] == 0 # AWAITING_INPUT doesn't count as active
assert data["completed_instances"] == 1
assert "agent_types" in data
assert len(data["agent_types"]) == 1
@@ -203,39 +205,35 @@ class TestAgentEndpoints:
self, authenticated_client, test_db, test_agent_instance
):
"""Test getting detailed agent instance information."""
# Add steps and questions
step1 = AgentStep(
# Add messages to simulate conversation
from shared.database import Message, SenderType
msg1 = Message(
id=uuid4(),
agent_instance_id=test_agent_instance.id,
step_number=1,
description="First step",
sender_type=SenderType.AGENT,
content="First step completed",
requires_user_input=False,
created_at=datetime.now(timezone.utc),
)
step2 = AgentStep(
msg2 = Message(
id=uuid4(),
agent_instance_id=test_agent_instance.id,
step_number=2,
description="Second step",
sender_type=SenderType.AGENT,
content="Need input?",
requires_user_input=True,
created_at=datetime.now(timezone.utc),
)
msg3 = Message(
id=uuid4(),
agent_instance_id=test_agent_instance.id,
sender_type=SenderType.USER,
content="Great work!",
requires_user_input=False,
created_at=datetime.now(timezone.utc),
)
question = AgentQuestion(
id=uuid4(),
agent_instance_id=test_agent_instance.id,
question_text="Need input?",
asked_at=datetime.now(timezone.utc),
is_active=True,
)
feedback = AgentUserFeedback(
id=uuid4(),
agent_instance_id=test_agent_instance.id,
created_by_user_id=test_agent_instance.user_id,
feedback_text="Great work!",
created_at=datetime.now(timezone.utc),
)
test_db.add_all([step1, step2, question, feedback])
test_db.add_all([msg1, msg2, msg3])
test_db.commit()
response = authenticated_client.get(
@@ -245,12 +243,14 @@ class TestAgentEndpoints:
data = response.json()
assert data["id"] == str(test_agent_instance.id)
assert len(data["steps"]) == 2
assert data["steps"][0]["description"] == "First step"
assert len(data["questions"]) == 1
assert data["questions"][0]["question_text"] == "Need input?"
assert len(data["user_feedback"]) == 1
assert data["user_feedback"][0]["feedback_text"] == "Great work!"
assert "messages" in data
assert len(data["messages"]) == 3
assert data["messages"][0]["content"] == "First step completed"
assert data["messages"][0]["sender_type"] == "AGENT"
assert data["messages"][1]["content"] == "Need input?"
assert data["messages"][1]["requires_user_input"] is True
assert data["messages"][2]["content"] == "Great work!"
assert data["messages"][2]["sender_type"] == "USER"
def test_get_instance_detail_not_found(self, authenticated_client):
"""Test getting non-existent instance detail."""
@@ -265,71 +265,140 @@ class TestAgentEndpoints:
"""Test adding user feedback to an agent instance."""
feedback_text = "Please use TypeScript for this component"
response = authenticated_client.post(
f"/api/v1/agent-instances/{test_agent_instance.id}/feedback",
json={"feedback": feedback_text},
f"/api/v1/agent-instances/{test_agent_instance.id}/messages",
json={"content": feedback_text},
)
assert response.status_code == 200
data = response.json()
assert data["feedback_text"] == feedback_text
assert data["content"] == feedback_text
assert data["sender_type"] == "USER"
assert "id" in data
assert "created_at" in data
assert data["requires_user_input"] is False
# Verify in database
feedback = (
test_db.query(AgentUserFeedback)
.filter_by(agent_instance_id=test_agent_instance.id)
from shared.database import Message, SenderType
message = (
test_db.query(Message)
.filter_by(
agent_instance_id=test_agent_instance.id, sender_type=SenderType.USER
)
.first()
)
assert feedback is not None
assert feedback.feedback_text == feedback_text
assert feedback.retrieved_at is None
assert message is not None
assert message.content == feedback_text
assert message.sender_type == SenderType.USER
def test_add_feedback_to_nonexistent_instance(self, authenticated_client):
"""Test adding feedback to non-existent instance."""
fake_id = uuid4()
response = authenticated_client.post(
f"/api/v1/agent-instances/{fake_id}/feedback",
json={"feedback": "Test feedback"},
f"/api/v1/agent-instances/{fake_id}/messages",
json={"content": "Test feedback"},
)
assert response.status_code == 404
assert response.json()["detail"] == "Agent instance not found"
def test_update_agent_status_completed(
def test_instance_status_changes_with_messages(
self, authenticated_client, test_db, test_agent_instance
):
"""Test marking an agent instance as completed."""
response = authenticated_client.put(
f"/api/v1/agent-instances/{test_agent_instance.id}/status",
json={"status": "completed"},
)
"""Test that instance status changes based on message flow."""
from shared.database import Message, SenderType
# Initially instance should be ACTIVE
assert test_agent_instance.status == AgentStatus.ACTIVE
# Agent sends a message requiring user input
question_msg = Message(
id=uuid4(),
agent_instance_id=test_agent_instance.id,
sender_type=SenderType.AGENT,
content="Should I use TypeScript or JavaScript?",
requires_user_input=True,
created_at=datetime.now(timezone.utc),
)
test_db.add(question_msg)
test_db.commit()
# Status should change to AWAITING_INPUT (this would be done by the agent)
test_agent_instance.status = AgentStatus.AWAITING_INPUT
test_db.commit()
# User responds
response = authenticated_client.post(
f"/api/v1/agent-instances/{test_agent_instance.id}/messages",
json={"content": "Use TypeScript for better type safety"},
)
assert response.status_code == 200
# Check that status changed back to ACTIVE
test_db.refresh(test_agent_instance)
assert test_agent_instance.status == AgentStatus.ACTIVE
def test_message_creates_status_update_notification(
self, authenticated_client, test_db, test_agent_instance
):
"""Test that sending messages triggers appropriate notifications."""
# This would test the notification system if implemented
# For now, just verify message creation works
response = authenticated_client.post(
f"/api/v1/agent-instances/{test_agent_instance.id}/messages",
json={"content": "Test message for notifications"},
)
assert response.status_code == 200
# Verify message was created
from shared.database import Message
messages = (
test_db.query(Message)
.filter_by(agent_instance_id=test_agent_instance.id)
.all()
)
assert len(messages) >= 1
assert any(msg.content == "Test message for notifications" for msg in messages)
def test_agent_instance_latest_message_tracking(
self, authenticated_client, test_db, test_agent_instance
):
"""Test that latest_message is properly tracked in instance listing."""
from shared.database import Message, SenderType
import time
# Create messages with different timestamps
msg1 = Message(
id=uuid4(),
agent_instance_id=test_agent_instance.id,
sender_type=SenderType.AGENT,
content="First message",
requires_user_input=False,
created_at=datetime.now(timezone.utc),
)
test_db.add(msg1)
test_db.commit()
# Small delay to ensure different timestamps
time.sleep(0.1)
msg2 = Message(
id=uuid4(),
agent_instance_id=test_agent_instance.id,
sender_type=SenderType.USER,
content="Latest message",
requires_user_input=False,
created_at=datetime.now(timezone.utc),
)
test_db.add(msg2)
test_db.commit()
# Get instance list
response = authenticated_client.get("/api/v1/agent-instances")
assert response.status_code == 200
data = response.json()
assert data["status"] == "completed"
assert data["ended_at"] is not None
# Verify in database
test_db.refresh(test_agent_instance)
assert test_agent_instance.status == AgentStatus.COMPLETED
assert test_agent_instance.ended_at is not None
def test_update_agent_status_unsupported(
self, authenticated_client, test_agent_instance
):
"""Test unsupported status update."""
response = authenticated_client.put(
f"/api/v1/agent-instances/{test_agent_instance.id}/status",
json={"status": "paused"},
)
assert response.status_code == 400
assert response.json()["detail"] == "Status update not supported"
def test_update_status_nonexistent_instance(self, authenticated_client):
"""Test updating status of non-existent instance."""
fake_id = uuid4()
response = authenticated_client.put(
f"/api/v1/agent-instances/{fake_id}/status", json={"status": "completed"}
)
assert response.status_code == 404
assert response.json()["detail"] == "Agent instance not found"
assert len(data) == 1
instance = data[0]
assert instance["latest_message"] == "Latest message"
assert instance["chat_length"] == 2

View File

@@ -0,0 +1,298 @@
"""Comprehensive tests for the unified message system."""
from datetime import datetime, timezone
from uuid import uuid4
import time
from shared.database.models import Message, AgentInstance
from shared.database.enums import AgentStatus, SenderType
class TestMessageSystem:
"""Test the core message system functionality."""
def test_message_flow_creates_conversation(
self, authenticated_client, test_db, test_agent_instance
):
"""Test that messages create a proper conversation flow."""
# Agent sends initial message
agent_msg1 = Message(
id=uuid4(),
agent_instance_id=test_agent_instance.id,
sender_type=SenderType.AGENT,
content="I'm starting to work on your request",
requires_user_input=False,
created_at=datetime.now(timezone.utc),
)
test_db.add(agent_msg1)
test_db.commit()
# Agent asks a question
agent_question = Message(
id=uuid4(),
agent_instance_id=test_agent_instance.id,
sender_type=SenderType.AGENT,
content="Which framework would you prefer: React or Vue?",
requires_user_input=True,
created_at=datetime.now(timezone.utc),
)
test_db.add(agent_question)
test_agent_instance.status = AgentStatus.AWAITING_INPUT
test_db.commit()
# User responds
response = authenticated_client.post(
f"/api/v1/agent-instances/{test_agent_instance.id}/messages",
json={"content": "Let's use React with TypeScript"},
)
assert response.status_code == 200
user_response = response.json()
assert user_response["sender_type"] == "USER"
assert user_response["requires_user_input"] is False
# Verify conversation in detail endpoint
detail_response = authenticated_client.get(
f"/api/v1/agent-instances/{test_agent_instance.id}"
)
assert detail_response.status_code == 200
detail = detail_response.json()
assert len(detail["messages"]) == 3
assert (
detail["messages"][0]["content"] == "I'm starting to work on your request"
)
assert (
detail["messages"][1]["content"]
== "Which framework would you prefer: React or Vue?"
)
assert detail["messages"][1]["requires_user_input"] is True
assert detail["messages"][2]["content"] == "Let's use React with TypeScript"
assert detail["messages"][2]["sender_type"] == "USER"
def test_multiple_user_messages_allowed(
self, authenticated_client, test_db, test_agent_instance
):
"""Test that users can send multiple messages in a row."""
# Send first user message
response1 = authenticated_client.post(
f"/api/v1/agent-instances/{test_agent_instance.id}/messages",
json={"content": "First instruction"},
)
assert response1.status_code == 200
# Send second user message immediately
response2 = authenticated_client.post(
f"/api/v1/agent-instances/{test_agent_instance.id}/messages",
json={"content": "Additional instruction"},
)
assert response2.status_code == 200
# Send third user message
response3 = authenticated_client.post(
f"/api/v1/agent-instances/{test_agent_instance.id}/messages",
json={"content": "One more thing..."},
)
assert response3.status_code == 200
# Verify all messages were created
messages = (
test_db.query(Message)
.filter_by(
agent_instance_id=test_agent_instance.id, sender_type=SenderType.USER
)
.all()
)
assert len(messages) == 3
assert [msg.content for msg in messages] == [
"First instruction",
"Additional instruction",
"One more thing...",
]
def test_message_ordering_preserved(
self, authenticated_client, test_db, test_agent_instance
):
"""Test that message ordering is preserved correctly."""
# Create messages with specific order
messages_data = [
(SenderType.AGENT, "Starting task", False),
(SenderType.AGENT, "Found an issue", False),
(SenderType.AGENT, "How should I proceed?", True),
(SenderType.USER, "Fix it with option A", False),
(SenderType.AGENT, "Implementing option A", False),
]
created_messages = []
for i, (sender_type, content, requires_input) in enumerate(messages_data):
time.sleep(0.01) # Ensure different timestamps
if sender_type == SenderType.AGENT:
msg = Message(
id=uuid4(),
agent_instance_id=test_agent_instance.id,
sender_type=sender_type,
content=content,
requires_user_input=requires_input,
created_at=datetime.now(timezone.utc),
)
test_db.add(msg)
test_db.commit()
created_messages.append(msg)
else:
# User message via API
response = authenticated_client.post(
f"/api/v1/agent-instances/{test_agent_instance.id}/messages",
json={"content": content},
)
assert response.status_code == 200
# Get instance detail
response = authenticated_client.get(
f"/api/v1/agent-instances/{test_agent_instance.id}"
)
assert response.status_code == 200
data = response.json()
# Verify order is preserved
assert len(data["messages"]) == 5
for i, (sender_type, content, requires_input) in enumerate(messages_data):
msg = data["messages"][i]
assert msg["content"] == content
assert msg["sender_type"] == sender_type.value
if sender_type == SenderType.AGENT:
assert msg["requires_user_input"] == requires_input
def test_agent_instance_summary_includes_message_stats(
self, authenticated_client, test_db, test_user, test_user_agent
):
"""Test that agent instance summaries include message statistics."""
# Create multiple instances with different message counts
instances = []
for i in range(3):
instance = AgentInstance(
id=uuid4(),
user_agent_id=test_user_agent.id,
user_id=test_user.id,
status=AgentStatus.ACTIVE,
started_at=datetime.now(timezone.utc),
)
test_db.add(instance)
instances.append(instance)
# Add different number of messages to each instance
for j in range(i + 1):
msg = Message(
id=uuid4(),
agent_instance_id=instance.id,
sender_type=SenderType.AGENT,
content=f"Message {j} for instance {i}",
requires_user_input=False,
created_at=datetime.now(timezone.utc),
)
test_db.add(msg)
test_db.commit()
# Get agent types overview
response = authenticated_client.get("/api/v1/agent-types")
assert response.status_code == 200
data = response.json()
assert len(data) == 1
agent_type = data[0]
assert len(agent_type["recent_instances"]) >= 3
# Verify chat_length is included
for instance in agent_type["recent_instances"]:
assert "chat_length" in instance
assert instance["chat_length"] >= 0
def test_message_with_empty_content_allowed(
self, authenticated_client, test_agent_instance
):
"""Test that empty messages are allowed."""
response = authenticated_client.post(
f"/api/v1/agent-instances/{test_agent_instance.id}/messages",
json={"content": ""},
)
assert response.status_code == 200
data = response.json()
assert data["content"] == ""
assert data["sender_type"] == "USER"
def test_message_status_transitions(
self, authenticated_client, test_db, test_agent_instance
):
"""Test that message flow properly triggers status transitions."""
# Start with ACTIVE
assert test_agent_instance.status == AgentStatus.ACTIVE
# Agent asks question -> AWAITING_INPUT
question = Message(
id=uuid4(),
agent_instance_id=test_agent_instance.id,
sender_type=SenderType.AGENT,
content="Should I continue?",
requires_user_input=True,
created_at=datetime.now(timezone.utc),
)
test_db.add(question)
test_agent_instance.status = AgentStatus.AWAITING_INPUT
test_db.commit()
# User responds -> back to ACTIVE
response = authenticated_client.post(
f"/api/v1/agent-instances/{test_agent_instance.id}/messages",
json={"content": "Yes, continue"},
)
assert response.status_code == 200
test_db.refresh(test_agent_instance)
assert test_agent_instance.status == AgentStatus.ACTIVE
def test_awaiting_input_instances_prioritized(
self, authenticated_client, test_db, test_user, test_user_agent
):
"""Test that instances awaiting input are prioritized in listings."""
# Create active instance
active_instance = AgentInstance(
id=uuid4(),
user_agent_id=test_user_agent.id,
user_id=test_user.id,
status=AgentStatus.ACTIVE,
started_at=datetime.now(timezone.utc),
)
test_db.add(active_instance)
# Create awaiting input instance (older)
awaiting_instance = AgentInstance(
id=uuid4(),
user_agent_id=test_user_agent.id,
user_id=test_user.id,
status=AgentStatus.AWAITING_INPUT,
started_at=datetime.now(timezone.utc),
)
test_db.add(awaiting_instance)
# Add question message to awaiting instance
question = Message(
id=uuid4(),
agent_instance_id=awaiting_instance.id,
sender_type=SenderType.AGENT,
content="Need your input",
requires_user_input=True,
created_at=datetime.now(timezone.utc),
)
test_db.add(question)
test_db.commit()
# Get agent types
response = authenticated_client.get("/api/v1/agent-types")
assert response.status_code == 200
data = response.json()
# Awaiting input instance should be first
instances = data[0]["recent_instances"]
assert len(instances) >= 2
assert instances[0]["status"] == "AWAITING_INPUT"
assert instances[0]["latest_message"] == "Need your input"

View File

@@ -3,8 +3,8 @@
from datetime import datetime, timezone
from uuid import uuid4
from shared.database.models import AgentQuestion, AgentInstance, User
from shared.database.enums import AgentStatus
from shared.database.models import Message, AgentInstance, User
from shared.database.enums import AgentStatus, SenderType
class TestQuestionEndpoints:
@@ -14,78 +14,90 @@ class TestQuestionEndpoints:
self, authenticated_client, test_db, test_agent_instance, test_user
):
"""Test answering a pending question."""
# Create a pending question
question = AgentQuestion(
# Create a message with requires_user_input=True (question)
from shared.database import Message, SenderType
question_msg = Message(
id=uuid4(),
agent_instance_id=test_agent_instance.id,
question_text="Should I use async/await?",
asked_at=datetime.now(timezone.utc),
is_active=True,
sender_type=SenderType.AGENT,
content="Should I use async/await?",
requires_user_input=True,
created_at=datetime.now(timezone.utc),
)
test_db.add(question)
test_db.add(question_msg)
# Set instance status to AWAITING_INPUT
test_agent_instance.status = AgentStatus.AWAITING_INPUT
test_db.commit()
# Submit answer
# Submit answer as a new message
answer_text = "Yes, use async/await for better performance"
response = authenticated_client.post(
f"/api/v1/questions/{question.id}/answer", json={"answer": answer_text}
f"/api/v1/agent-instances/{test_agent_instance.id}/messages",
json={"content": answer_text},
)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["message"] == "Answer submitted successfully"
# Verify in database
test_db.refresh(question)
assert question.answer_text == answer_text
assert question.answered_at is not None
assert question.answered_by_user_id == test_user.id
assert question.is_active is False
assert data["content"] == answer_text
assert data["sender_type"] == "USER"
assert data["requires_user_input"] is False
# Verify agent instance status changed back to active
test_db.refresh(test_agent_instance)
assert test_agent_instance.status == AgentStatus.ACTIVE
def test_answer_question_not_found(self, authenticated_client):
"""Test answering a non-existent question."""
"""Test sending message to non-existent agent instance."""
fake_id = uuid4()
response = authenticated_client.post(
f"/api/v1/questions/{fake_id}/answer", json={"answer": "Some answer"}
f"/api/v1/agent-instances/{fake_id}/messages",
json={"content": "Some answer"},
)
assert response.status_code == 404
assert response.json()["detail"] == "Question not found or already answered"
assert response.json()["detail"] == "Agent instance not found"
def test_answer_already_answered_question(
self, authenticated_client, test_db, test_agent_instance, test_user
):
"""Test answering an already answered question."""
# Create an already answered question
question = AgentQuestion(
"""Test that you can continue sending messages after answering a question."""
# In the new system, you can always send more messages
# This test verifies that behavior
from shared.database import Message, SenderType
# Create a question message
question_msg = Message(
id=uuid4(),
agent_instance_id=test_agent_instance.id,
question_text="Already answered?",
answer_text="Previous answer",
asked_at=datetime.now(timezone.utc),
answered_at=datetime.now(timezone.utc),
answered_by_user_id=test_user.id,
is_active=False,
sender_type=SenderType.AGENT,
content="Already answered?",
requires_user_input=True,
created_at=datetime.now(timezone.utc),
)
test_db.add(question)
test_db.add(question_msg)
# Create answer message
answer_msg = Message(
id=uuid4(),
agent_instance_id=test_agent_instance.id,
sender_type=SenderType.USER,
content="Previous answer",
requires_user_input=False,
created_at=datetime.now(timezone.utc),
)
test_db.add(answer_msg)
test_db.commit()
# Try to answer again
# Send another message - should work fine
response = authenticated_client.post(
f"/api/v1/questions/{question.id}/answer", json={"answer": "New answer"}
f"/api/v1/agent-instances/{test_agent_instance.id}/messages",
json={"content": "New message"},
)
assert response.status_code == 404
assert response.json()["detail"] == "Question not found or already answered"
# Verify answer didn't change
test_db.refresh(question)
assert question.answer_text == "Previous answer"
assert response.status_code == 200
assert response.json()["content"] == "New message"
def test_answer_question_wrong_user(self, authenticated_client, test_db):
"""Test answering a question from another user's agent."""
@@ -121,78 +133,67 @@ class TestQuestionEndpoints:
)
test_db.add(other_instance)
# Create question for other user's agent
question = AgentQuestion(
# Create question message for other user's agent
question_msg = Message(
id=uuid4(),
agent_instance_id=other_instance.id,
question_text="Other user's question?",
asked_at=datetime.now(timezone.utc),
is_active=True,
sender_type=SenderType.AGENT,
content="Other user's question?",
requires_user_input=True,
created_at=datetime.now(timezone.utc),
)
test_db.add(question)
test_db.add(question_msg)
test_db.commit()
# Try to answer as current user
# Try to send message as current user to other user's instance
response = authenticated_client.post(
f"/api/v1/questions/{question.id}/answer",
json={"answer": "Trying to answer"},
f"/api/v1/agent-instances/{other_instance.id}/messages",
json={"content": "Trying to answer"},
)
assert response.status_code == 404
assert response.json()["detail"] == "Question not found or already answered"
# Verify question remains unanswered
test_db.refresh(question)
assert question.answer_text is None
assert question.is_active is True
assert response.json()["detail"] == "Agent instance not found"
def test_answer_inactive_question(
self, authenticated_client, test_db, test_agent_instance
):
"""Test answering an inactive question."""
# Create an inactive question (but not answered)
question = AgentQuestion(
id=uuid4(),
agent_instance_id=test_agent_instance.id,
question_text="Inactive question?",
asked_at=datetime.now(timezone.utc),
is_active=False, # Inactive but not answered
)
test_db.add(question)
"""Test sending message to completed agent instance."""
# Set instance to COMPLETED status
test_agent_instance.status = AgentStatus.COMPLETED
test_agent_instance.ended_at = datetime.now(timezone.utc)
test_db.commit()
# Try to answer
# Try to send message - should still work since status updates happen server-side
response = authenticated_client.post(
f"/api/v1/questions/{question.id}/answer",
json={"answer": "Trying to answer inactive"},
f"/api/v1/agent-instances/{test_agent_instance.id}/messages",
json={"content": "Message to completed instance"},
)
assert response.status_code == 404
assert response.json()["detail"] == "Question not found or already answered"
# Based on the backend code, messages can still be sent to completed instances
assert response.status_code == 200
def test_answer_question_empty_answer(
self, authenticated_client, test_db, test_agent_instance
):
"""Test submitting an empty answer."""
# Create a pending question
question = AgentQuestion(
"""Test submitting an empty message."""
# Create a question message
question_msg = Message(
id=uuid4(),
agent_instance_id=test_agent_instance.id,
question_text="Can I submit empty?",
asked_at=datetime.now(timezone.utc),
is_active=True,
sender_type=SenderType.AGENT,
content="Can I submit empty?",
requires_user_input=True,
created_at=datetime.now(timezone.utc),
)
test_db.add(question)
test_db.add(question_msg)
test_agent_instance.status = AgentStatus.AWAITING_INPUT
test_db.commit()
# Submit empty answer - should still work
# Submit empty message - should still work
response = authenticated_client.post(
f"/api/v1/questions/{question.id}/answer", json={"answer": ""}
f"/api/v1/agent-instances/{test_agent_instance.id}/messages",
json={"content": ""},
)
assert response.status_code == 200
# Verify empty answer was saved
test_db.refresh(question)
assert question.answer_text == ""
assert question.is_active is False
assert response.json()["content"] == ""

View File

@@ -186,8 +186,8 @@ class TestUserAgentEndpoints:
assert len(data) == 2
statuses = [inst["status"] for inst in data]
assert "active" in statuses
assert "completed" in statuses
assert "ACTIVE" in statuses
assert "COMPLETED" in statuses
def test_get_user_agent_instances_not_found(self, authenticated_client):
"""Test getting instances for non-existent user agent."""

View File

@@ -57,6 +57,29 @@ def run_webhook_server(
subprocess.run(cmd)
def run_claude_wrapper(api_key, base_url=None, claude_args=None):
"""Run the Claude wrapper V3 for Omnara integration"""
# Import and run directly instead of subprocess
from webhooks.claude_wrapper_v3 import main as claude_wrapper_main
# Prepare sys.argv for the claude wrapper
original_argv = sys.argv
new_argv = ["claude_wrapper_v3", "--api-key", api_key]
if base_url:
new_argv.extend(["--base-url", base_url])
# Add any additional Claude arguments
if claude_args:
new_argv.extend(claude_args)
try:
sys.argv = new_argv
claude_wrapper_main()
finally:
sys.argv = original_argv
def main():
"""Main entry point that dispatches based on command line arguments"""
parser = argparse.ArgumentParser(
@@ -79,6 +102,12 @@ Examples:
# Run webhook server on custom port
omnara --claude-code-webhook --port 8080
# Run Claude wrapper V3
omnara --claude --api-key YOUR_API_KEY
# Run Claude wrapper with custom base URL
omnara --claude --api-key YOUR_API_KEY --base-url http://localhost:8000
# Run with custom API base URL
omnara --stdio --api-key YOUR_API_KEY --base-url http://localhost:8000
@@ -99,6 +128,11 @@ Examples:
action="store_true",
help="Run the Claude Code webhook server",
)
mode_group.add_argument(
"--claude",
action="store_true",
help="Run the Claude wrapper V3 for Omnara integration",
)
# Arguments for webhook mode
parser.add_argument(
@@ -142,7 +176,8 @@ Examples:
help="Pre-existing agent instance ID to use for this session (stdio mode only)",
)
args = parser.parse_args()
# Use parse_known_args to capture remaining args for Claude
args, unknown_args = parser.parse_known_args()
if args.cloudflare_tunnel and not args.claude_code_webhook:
parser.error("--cloudflare-tunnel can only be used with --claude-code-webhook")
@@ -156,6 +191,10 @@ Examples:
dangerously_skip_permissions=args.dangerously_skip_permissions,
port=args.port,
)
elif args.claude:
if not args.api_key:
parser.error("--api-key is required for --claude mode")
run_claude_wrapper(args.api_key, args.base_url, unknown_args)
else:
if not args.api_key:
parser.error("--api-key is required for stdio mode")

View File

@@ -2,7 +2,8 @@
import asyncio
import ssl
from typing import Optional, Dict, Any
import uuid
from typing import Optional, Dict, Any, Union, List
from urllib.parse import urljoin
import aiohttp
@@ -11,10 +12,14 @@ from aiohttp import ClientTimeout
from .exceptions import AuthenticationError, TimeoutError, APIError
from .models import (
LogStepResponse,
QuestionResponse,
QuestionStatus,
EndSessionResponse,
CreateMessageResponse,
PendingMessagesResponse,
Message,
)
from .utils import (
validate_agent_instance_id,
build_message_request_data,
)
@@ -76,6 +81,7 @@ class AsyncOmnaraClient:
method: str,
endpoint: str,
json: Optional[Dict[str, Any]] = None,
params: Optional[Dict[str, Any]] = None,
timeout: Optional[int] = None,
) -> Dict[str, Any]:
"""Make an async HTTP request to the API.
@@ -84,6 +90,7 @@ class AsyncOmnaraClient:
method: HTTP method (GET, POST, etc.)
endpoint: API endpoint path
json: JSON body for the request
params: Query parameters for the request
timeout: Request timeout in seconds
Returns:
@@ -104,7 +111,11 @@ class AsyncOmnaraClient:
try:
async with self.session.request(
method=method, url=url, json=json, timeout=request_timeout
method=method,
url=url,
json=json,
params=params,
timeout=request_timeout,
) as response:
if response.status == 401:
raise AuthenticationError(
@@ -128,138 +139,246 @@ class AsyncOmnaraClient:
except aiohttp.ClientError as e:
raise APIError(0, f"Request failed: {str(e)}")
async def log_step(
async def send_message(
self,
agent_type: str,
step_description: str,
agent_instance_id: Optional[str] = None,
send_push: Optional[bool] = None,
send_email: Optional[bool] = None,
send_sms: Optional[bool] = None,
git_diff: Optional[str] = None,
) -> LogStepResponse:
"""Log a high-level step the agent is performing.
Args:
agent_type: Type of agent (e.g., 'Claude Code', 'Cursor')
step_description: Clear description of what the agent is doing
agent_instance_id: Existing agent instance ID (optional)
send_push: Send push notification for this step (default: False)
send_email: Send email notification for this step (default: False)
send_sms: Send SMS notification for this step (default: False)
Returns:
LogStepResponse with success status, instance ID, and user feedback
"""
data: Dict[str, Any] = {
"agent_type": agent_type,
"step_description": step_description,
}
if agent_instance_id:
data["agent_instance_id"] = agent_instance_id
if send_push is not None:
data["send_push"] = send_push
if send_email is not None:
data["send_email"] = send_email
if send_sms is not None:
data["send_sms"] = send_sms
if git_diff is not None:
data["git_diff"] = git_diff
response = await self._make_request("POST", "/api/v1/steps", json=data)
return LogStepResponse(
success=response["success"],
agent_instance_id=response["agent_instance_id"],
step_number=response["step_number"],
user_feedback=response.get("user_feedback", []),
)
async def ask_question(
self,
agent_instance_id: str,
question_text: str,
content: str,
agent_type: Optional[str] = None,
agent_instance_id: Optional[Union[str, uuid.UUID]] = None,
requires_user_input: bool = False,
timeout_minutes: int = 1440,
poll_interval: float = 10.0,
send_push: Optional[bool] = None,
send_email: Optional[bool] = None,
send_sms: Optional[bool] = None,
git_diff: Optional[str] = None,
) -> QuestionResponse:
"""Ask the user a question and wait for their response.
This method submits the question and then polls for the answer.
) -> CreateMessageResponse:
"""Send a message to the dashboard.
Args:
agent_instance_id: Agent instance ID
question_text: Question to ask the user
timeout_minutes: Maximum time to wait for answer in minutes (default: 1440 = 24 hours)
poll_interval: Time between polls in seconds (default: 10.0)
send_push: Send push notification for this question (default: user preference)
send_email: Send email notification for this question (default: user preference)
send_sms: Send SMS notification for this question (default: user preference)
content: The message content (step description or question text)
agent_type: Type of agent (required if agent_instance_id not provided)
agent_instance_id: Existing agent instance ID (optional)
requires_user_input: Whether this message requires user input (default: False)
timeout_minutes: If requires_user_input, max time to wait in minutes (default: 1440)
poll_interval: If requires_user_input, time between polls in seconds (default: 10.0)
send_push: Send push notification (default: False for steps, user pref for questions)
send_email: Send email notification (default: False for steps, user pref for questions)
send_sms: Send SMS notification (default: False for steps, user pref for questions)
git_diff: Git diff content to include (optional)
Returns:
QuestionResponse with the user's answer
CreateMessageResponse with any queued user messages
Raises:
TimeoutError: If no answer is received within timeout
ValueError: If neither agent_type nor agent_instance_id is provided
TimeoutError: If requires_user_input and no answer is received within timeout
"""
# Submit the question
data: Dict[str, Any] = {
"agent_instance_id": agent_instance_id,
"question_text": question_text,
}
if send_push is not None:
data["send_push"] = send_push
if send_email is not None:
data["send_email"] = send_email
if send_sms is not None:
data["send_sms"] = send_sms
if git_diff is not None:
data["git_diff"] = git_diff
# If no agent_instance_id provided, generate one client-side
if not agent_instance_id:
if not agent_type:
raise ValueError("agent_type is required when creating a new instance")
agent_instance_id = uuid.uuid4()
# First, try the non-blocking endpoint to create the question
response = await self._make_request(
"POST", "/api/v1/questions", json=data, timeout=5
# Validate and convert agent_instance_id to string
agent_instance_id_str = validate_agent_instance_id(agent_instance_id)
# Build request data using shared utility
data = build_message_request_data(
content=content,
agent_instance_id=agent_instance_id_str,
requires_user_input=requires_user_input,
agent_type=agent_type,
send_push=send_push,
send_email=send_email,
send_sms=send_sms,
git_diff=git_diff,
)
question_id = response["question_id"]
# Convert timeout from minutes to seconds
# Send the message
response = await self._make_request("POST", "/api/v1/messages/agent", json=data)
response_agent_instance_id = response["agent_instance_id"]
message_id = response["message_id"]
queued_contents = [
msg["content"] if isinstance(msg, dict) else msg
for msg in response.get("queued_user_messages", [])
]
create_response = CreateMessageResponse(
success=response["success"],
agent_instance_id=response_agent_instance_id,
message_id=message_id,
queued_user_messages=queued_contents,
)
# If it doesn't require user input, return immediately
if not requires_user_input:
return create_response
# Otherwise, poll for the answer
# Use the message ID we just created as our starting point
last_read_message_id = message_id
timeout_seconds = timeout_minutes * 60
# Poll for the answer
start_time = asyncio.get_event_loop().time()
while asyncio.get_event_loop().time() - start_time < timeout_seconds:
status = await self.get_question_status(question_id)
all_messages = []
if status.status == "answered" and status.answer:
return QuestionResponse(answer=status.answer, question_id=question_id)
while asyncio.get_event_loop().time() - start_time < timeout_seconds:
# Poll for pending messages
pending_response = await self.get_pending_messages(
agent_instance_id_str, last_read_message_id
)
# If status is "stale", another process has read the messages
if pending_response.status == "stale":
raise TimeoutError("Another process has read the messages")
# Check if we got any messages
if pending_response.messages:
# Collect all messages
all_messages.extend(pending_response.messages)
# Return the response with all collected messages
create_response.queued_user_messages = [
msg.content for msg in all_messages
]
return create_response
await asyncio.sleep(poll_interval)
raise TimeoutError(f"Question timed out after {timeout_minutes} minutes")
async def get_question_status(self, question_id: str) -> QuestionStatus:
"""Get the current status of a question.
async def get_pending_messages(
self,
agent_instance_id: Union[str, uuid.UUID],
last_read_message_id: Optional[str] = None,
) -> PendingMessagesResponse:
"""Get pending user messages for an agent instance.
Args:
question_id: ID of the question to check
agent_instance_id: Agent instance ID
last_read_message_id: The last message ID that was read (optional)
Returns:
QuestionStatus with current status and answer (if available)
PendingMessagesResponse with messages and status
"""
response = await self._make_request("GET", f"/api/v1/questions/{question_id}")
# Validate and convert agent_instance_id to string
agent_instance_id_str = validate_agent_instance_id(agent_instance_id)
return QuestionStatus(
question_id=response["question_id"],
status=response["status"],
answer=response.get("answer"),
asked_at=response["asked_at"],
answered_at=response.get("answered_at"),
params = {"agent_instance_id": agent_instance_id_str}
if last_read_message_id:
params["last_read_message_id"] = last_read_message_id
response = await self._make_request(
"GET", "/api/v1/messages/pending", params=params
)
async def end_session(self, agent_instance_id: str) -> EndSessionResponse:
return PendingMessagesResponse(
agent_instance_id=response["agent_instance_id"],
messages=[Message(**msg) for msg in response["messages"]],
status=response["status"],
)
async def send_user_message(
self,
agent_instance_id: Union[str, uuid.UUID],
content: str,
mark_as_read: bool = True,
) -> Dict[str, Any]:
"""Send a user message to an agent instance.
Args:
agent_instance_id: The agent instance ID to send the message to
content: Message content
mark_as_read: Whether to mark as read (update last_read_message_id) (default: True)
Returns:
Dict containing:
- success: Whether the message was created
- message_id: ID of the created message
- marked_as_read: Whether the message was marked as read
Raises:
ValueError: If agent instance not found or access denied
APIError: If the API request fails
"""
# Validate and convert agent_instance_id
agent_instance_id = validate_agent_instance_id(agent_instance_id)
data = {
"agent_instance_id": str(agent_instance_id),
"content": content,
"mark_as_read": mark_as_read,
}
return await self._make_request("POST", "/api/v1/messages/user", json=data)
async def request_user_input(
self,
message_id: Union[str, uuid.UUID],
timeout_minutes: int = 1440,
poll_interval: float = 10.0,
) -> List[str]:
"""Request user input for a previously sent agent message.
This method updates an agent message to require user input and polls for responses.
It's useful when you initially send a message without requiring input, but later
decide you need user feedback.
Args:
message_id: The message ID to update (must be an agent message)
timeout_minutes: Max time to wait for user response in minutes (default: 1440)
poll_interval: Time between polls in seconds (default: 10.0)
Returns:
List of user message contents received as responses
Raises:
ValueError: If message not found, already requires input, or not an agent message
TimeoutError: If no user response is received within timeout
APIError: If the API request fails
"""
# Convert message_id to string if it's a UUID
message_id_str = str(message_id)
# Call the endpoint to update the message
response = await self._make_request(
"PATCH", f"/api/v1/messages/{message_id_str}/request-input"
)
agent_instance_id = response["agent_instance_id"]
messages = response.get("messages", [])
if messages:
return [msg["content"] for msg in messages]
# Otherwise, poll for user response
timeout_seconds = timeout_minutes * 60
start_time = asyncio.get_event_loop().time()
all_messages = []
while asyncio.get_event_loop().time() - start_time < timeout_seconds:
# Poll for pending messages using the message_id as last_read
pending_response = await self.get_pending_messages(
agent_instance_id, message_id_str
)
# If status is "stale", another process has read the messages
if pending_response.status == "stale":
raise TimeoutError("Another process has read the messages")
# Check if we got any messages
if pending_response.messages:
# Collect all message contents
all_messages.extend([msg.content for msg in pending_response.messages])
return all_messages
await asyncio.sleep(poll_interval)
raise TimeoutError(f"No user response received after {timeout_minutes} minutes")
async def end_session(
self, agent_instance_id: Union[str, uuid.UUID]
) -> EndSessionResponse:
"""End an agent session and mark it as completed.
Args:
@@ -268,7 +387,10 @@ class AsyncOmnaraClient:
Returns:
EndSessionResponse with success status and final details
"""
data: Dict[str, Any] = {"agent_instance_id": agent_instance_id}
# Validate and convert agent_instance_id to string
agent_instance_id_str = validate_agent_instance_id(agent_instance_id)
data: Dict[str, Any] = {"agent_instance_id": agent_instance_id_str}
response = await self._make_request("POST", "/api/v1/sessions/end", json=data)
return EndSessionResponse(

View File

@@ -1,7 +1,8 @@
"""Main client for interacting with the Omnara Agent Dashboard API."""
import time
from typing import Optional, Dict, Any
import uuid
from typing import Optional, Dict, Any, Union, List
from urllib.parse import urljoin
import requests
@@ -10,10 +11,14 @@ from urllib3.util.retry import Retry
from .exceptions import AuthenticationError, TimeoutError, APIError
from .models import (
LogStepResponse,
QuestionResponse,
QuestionStatus,
EndSessionResponse,
CreateMessageResponse,
PendingMessagesResponse,
Message,
)
from .utils import (
validate_agent_instance_id,
build_message_request_data,
)
@@ -63,6 +68,7 @@ class OmnaraClient:
method: str,
endpoint: str,
json: Optional[Dict[str, Any]] = None,
params: Optional[Dict[str, Any]] = None,
timeout: Optional[int] = None,
) -> Dict[str, Any]:
"""Make an HTTP request to the API.
@@ -71,6 +77,7 @@ class OmnaraClient:
method: HTTP method (GET, POST, etc.)
endpoint: API endpoint path
json: JSON body for the request
params: Query parameters for the request
timeout: Request timeout in seconds
Returns:
@@ -86,7 +93,7 @@ class OmnaraClient:
try:
response = self.session.request(
method=method, url=url, json=json, timeout=timeout
method=method, url=url, json=json, params=params, timeout=timeout
)
if response.status_code == 401:
@@ -106,138 +113,245 @@ class OmnaraClient:
except requests.exceptions.RequestException as e:
raise APIError(0, f"Request failed: {str(e)}")
def log_step(
def send_message(
self,
agent_type: str,
step_description: str,
agent_instance_id: Optional[str] = None,
send_push: Optional[bool] = None,
send_email: Optional[bool] = None,
send_sms: Optional[bool] = None,
git_diff: Optional[str] = None,
) -> LogStepResponse:
"""Log a high-level step the agent is performing.
Args:
agent_type: Type of agent (e.g., 'Claude Code', 'Cursor')
step_description: Clear description of what the agent is doing
agent_instance_id: Existing agent instance ID (optional)
send_push: Send push notification for this step (default: False)
send_email: Send email notification for this step (default: False)
send_sms: Send SMS notification for this step (default: False)
git_diff: Git diff content to include with this step (optional)
Returns:
LogStepResponse with success status, instance ID, and user feedback
"""
data: Dict[str, Any] = {
"agent_type": agent_type,
"step_description": step_description,
}
if agent_instance_id:
data["agent_instance_id"] = agent_instance_id
if send_push is not None:
data["send_push"] = send_push
if send_email is not None:
data["send_email"] = send_email
if send_sms is not None:
data["send_sms"] = send_sms
if git_diff is not None:
data["git_diff"] = git_diff
response = self._make_request("POST", "/api/v1/steps", json=data)
return LogStepResponse(
success=response["success"],
agent_instance_id=response["agent_instance_id"],
step_number=response["step_number"],
user_feedback=response.get("user_feedback", []),
)
def ask_question(
self,
agent_instance_id: str,
question_text: str,
content: str,
agent_type: Optional[str] = None,
agent_instance_id: Optional[Union[str, uuid.UUID]] = None,
requires_user_input: bool = False,
timeout_minutes: int = 1440,
poll_interval: float = 10.0,
send_push: Optional[bool] = None,
send_email: Optional[bool] = None,
send_sms: Optional[bool] = None,
git_diff: Optional[str] = None,
) -> QuestionResponse:
"""Ask the user a question and wait for their response.
This method submits the question and then polls for the answer.
) -> CreateMessageResponse:
"""Send a message to the dashboard.
Args:
agent_instance_id: Agent instance ID
question_text: Question to ask the user
timeout_minutes: Maximum time to wait for answer in minutes (default: 1440 = 24 hours)
poll_interval: Time between polls in seconds (default: 10.0)
send_push: Send push notification for this question (default: user preference)
send_email: Send email notification for this question (default: user preference)
send_sms: Send SMS notification for this question (default: user preference)
git_diff: Git diff content to include with this question (optional)
content: The message content (step description or question text)
agent_type: Type of agent (required if agent_instance_id not provided)
agent_instance_id: Existing agent instance ID (optional)
requires_user_input: Whether this message requires user input (default: False)
timeout_minutes: If requires_user_input, max time to wait in minutes (default: 1440)
poll_interval: If requires_user_input, time between polls in seconds (default: 10.0)
send_push: Send push notification (default: False for steps, user pref for questions)
send_email: Send email notification (default: False for steps, user pref for questions)
send_sms: Send SMS notification (default: False for steps, user pref for questions)
git_diff: Git diff content to include (optional)
Returns:
QuestionResponse with the user's answer
CreateMessageResponse with any queued user messages
Raises:
TimeoutError: If no answer is received within timeout
ValueError: If neither agent_type nor agent_instance_id is provided
TimeoutError: If requires_user_input and no answer is received within timeout
"""
# Submit the question
data: Dict[str, Any] = {
"agent_instance_id": agent_instance_id,
"question_text": question_text,
}
if send_push is not None:
data["send_push"] = send_push
if send_email is not None:
data["send_email"] = send_email
if send_sms is not None:
data["send_sms"] = send_sms
if git_diff is not None:
data["git_diff"] = git_diff
# If no agent_instance_id provided, generate one client-side
if not agent_instance_id:
if not agent_type:
raise ValueError("agent_type is required when creating a new instance")
agent_instance_id = uuid.uuid4()
# First, try the non-blocking endpoint to create the question
response = self._make_request("POST", "/api/v1/questions", json=data, timeout=5)
question_id = response["question_id"]
# Validate and convert agent_instance_id to string
agent_instance_id_str = validate_agent_instance_id(agent_instance_id)
# Build request data using shared utility
data = build_message_request_data(
content=content,
agent_instance_id=agent_instance_id_str,
requires_user_input=requires_user_input,
agent_type=agent_type,
send_push=send_push,
send_email=send_email,
send_sms=send_sms,
git_diff=git_diff,
)
# Send the message
response = self._make_request("POST", "/api/v1/messages/agent", json=data)
response_agent_instance_id = response["agent_instance_id"]
message_id = response["message_id"]
queued_contents = [
msg["content"] if isinstance(msg, dict) else msg
for msg in response.get("queued_user_messages", [])
]
create_response = CreateMessageResponse(
success=response["success"],
agent_instance_id=response_agent_instance_id,
message_id=message_id,
queued_user_messages=queued_contents,
)
# If it doesn't require user input, return immediately with any queued messages
if not requires_user_input:
return create_response
# Otherwise, we need to poll for user response
# Use the message ID we just created as our starting point
last_read_message_id = message_id
# Convert timeout from minutes to seconds
timeout_seconds = timeout_minutes * 60
# Poll for the answer
start_time = time.time()
while time.time() - start_time < timeout_seconds:
status = self.get_question_status(question_id)
all_messages = []
if status.status == "answered" and status.answer:
return QuestionResponse(answer=status.answer, question_id=question_id)
while time.time() - start_time < timeout_seconds:
# Poll for pending messages
pending_response = self.get_pending_messages(
agent_instance_id_str, last_read_message_id
)
# If status is "stale", another process has read the messages
if pending_response.status == "stale":
raise TimeoutError("Another process has read the messages")
# Check if we got any messages
if pending_response.messages:
# Collect all messages
all_messages.extend(pending_response.messages)
# Return the response with all collected messages
create_response.queued_user_messages = [
msg.content for msg in all_messages
]
return create_response
time.sleep(poll_interval)
raise TimeoutError(f"Question timed out after {timeout_minutes} minutes")
def get_question_status(self, question_id: str) -> QuestionStatus:
"""Get the current status of a question.
def get_pending_messages(
self,
agent_instance_id: Union[str, uuid.UUID],
last_read_message_id: Optional[str] = None,
) -> PendingMessagesResponse:
"""Get pending user messages for an agent instance.
Args:
question_id: ID of the question to check
agent_instance_id: Agent instance ID
last_read_message_id: The last message ID that was read (optional)
Returns:
QuestionStatus with current status and answer (if available)
PendingMessagesResponse with messages and status
"""
response = self._make_request("GET", f"/api/v1/questions/{question_id}")
# Validate and convert agent_instance_id to string
agent_instance_id_str = validate_agent_instance_id(agent_instance_id)
return QuestionStatus(
question_id=response["question_id"],
params = {"agent_instance_id": agent_instance_id_str}
if last_read_message_id:
params["last_read_message_id"] = last_read_message_id
response = self._make_request("GET", "/api/v1/messages/pending", params=params)
return PendingMessagesResponse(
agent_instance_id=response["agent_instance_id"],
messages=[Message(**msg) for msg in response["messages"]],
status=response["status"],
answer=response.get("answer"),
asked_at=response["asked_at"],
answered_at=response.get("answered_at"),
)
def end_session(self, agent_instance_id: str) -> EndSessionResponse:
def send_user_message(
self,
agent_instance_id: Union[str, uuid.UUID],
content: str,
mark_as_read: bool = True,
) -> Dict[str, Any]:
"""Send a user message to an agent instance.
Args:
agent_instance_id: The agent instance ID to send the message to
content: Message content
mark_as_read: Whether to mark as read (update last_read_message_id) (default: True)
Returns:
Dict containing:
- success: Whether the message was created
- message_id: ID of the created message
- marked_as_read: Whether the message was marked as read
Raises:
ValueError: If agent instance not found or access denied
APIError: If the API request fails
"""
# Validate and convert agent_instance_id
agent_instance_id = validate_agent_instance_id(agent_instance_id)
data = {
"agent_instance_id": str(agent_instance_id),
"content": content,
"mark_as_read": mark_as_read,
}
return self._make_request("POST", "/api/v1/messages/user", json=data)
def request_user_input(
self,
message_id: Union[str, uuid.UUID],
timeout_minutes: int = 1440,
poll_interval: float = 10.0,
) -> List[str]:
"""Request user input for a previously sent agent message.
This method updates an agent message to require user input and polls for responses.
It's useful when you initially send a message without requiring input, but later
decide you need user feedback.
Args:
message_id: The message ID to update (must be an agent message)
timeout_minutes: Max time to wait for user response in minutes (default: 1440)
poll_interval: Time between polls in seconds (default: 10.0)
Returns:
List of user message contents received as responses
Raises:
ValueError: If message not found, already requires input, or not an agent message
TimeoutError: If no user response is received within timeout
APIError: If the API request fails
"""
# Convert message_id to string if it's a UUID
message_id_str = str(message_id)
# Call the endpoint to update the message
response = self._make_request(
"PATCH", f"/api/v1/messages/{message_id_str}/request-input"
)
agent_instance_id = response["agent_instance_id"]
messages = response.get("messages", [])
if messages:
return [msg["content"] for msg in messages]
# Otherwise, poll for user response
timeout_seconds = timeout_minutes * 60
start_time = time.time()
all_messages = []
while time.time() - start_time < timeout_seconds:
# Poll for pending messages using the message_id as last_read
pending_response = self.get_pending_messages(
agent_instance_id, message_id_str
)
# If status is "stale", another process has read the messages
if pending_response.status == "stale":
raise TimeoutError("Another process has read the messages")
# Check if we got any messages
if pending_response.messages:
# Collect all message contents
all_messages.extend([msg.content for msg in pending_response.messages])
return all_messages
time.sleep(poll_interval)
raise TimeoutError(f"No user response received after {timeout_minutes} minutes")
def end_session(
self, agent_instance_id: Union[str, uuid.UUID]
) -> EndSessionResponse:
"""End an agent session and mark it as completed.
Args:
@@ -246,7 +360,10 @@ class OmnaraClient:
Returns:
EndSessionResponse with success status and final details
"""
data: Dict[str, Any] = {"agent_instance_id": agent_instance_id}
# Validate and convert agent_instance_id to string
agent_instance_id_str = validate_agent_instance_id(agent_instance_id)
data: Dict[str, Any] = {"agent_instance_id": agent_instance_id_str}
response = self._make_request("POST", "/api/v1/sessions/end", json=data)
return EndSessionResponse(

View File

@@ -1,6 +1,6 @@
"""Data models for the Omnara SDK."""
from typing import List, Optional
from typing import List
from dataclasses import dataclass
@@ -14,25 +14,6 @@ class LogStepResponse:
user_feedback: List[str]
@dataclass
class QuestionResponse:
"""Response from asking a question."""
answer: str
question_id: str
@dataclass
class QuestionStatus:
"""Status of a question."""
question_id: str
status: str # 'pending' or 'answered'
answer: Optional[str]
asked_at: str
answered_at: Optional[str]
@dataclass
class EndSessionResponse:
"""Response from ending a session."""
@@ -40,3 +21,33 @@ class EndSessionResponse:
success: bool
agent_instance_id: str
final_status: str
@dataclass
class CreateMessageResponse:
"""Response from creating a message."""
success: bool
agent_instance_id: str
message_id: str
queued_user_messages: List[str]
@dataclass
class Message:
"""A message in the conversation."""
id: str
content: str
sender_type: str # 'agent' or 'user'
created_at: str
requires_user_input: bool
@dataclass
class PendingMessagesResponse:
"""Response from getting pending messages."""
agent_instance_id: str
messages: List[Message]
status: str # 'ok' or 'stale'

78
omnara/sdk/utils.py Normal file
View File

@@ -0,0 +1,78 @@
"""Utility functions for the Omnara SDK."""
import uuid
from typing import Optional, Union, Dict, Any
def validate_agent_instance_id(
agent_instance_id: Optional[Union[str, uuid.UUID]],
) -> str:
"""Validate and convert agent_instance_id to string.
Args:
agent_instance_id: UUID string, UUID object, or None
Returns:
Validated UUID string
Raises:
ValueError: If agent_instance_id is not a valid UUID
"""
if agent_instance_id is None:
raise ValueError("agent_instance_id cannot be None")
if isinstance(agent_instance_id, str):
try:
# Validate it's a valid UUID
uuid.UUID(agent_instance_id)
return agent_instance_id
except ValueError:
raise ValueError("agent_instance_id must be a valid UUID string")
elif isinstance(agent_instance_id, uuid.UUID):
return str(agent_instance_id)
else:
raise ValueError("agent_instance_id must be a string or UUID object")
def build_message_request_data(
content: str,
agent_instance_id: str,
requires_user_input: bool,
agent_type: Optional[str] = None,
send_push: Optional[bool] = None,
send_email: Optional[bool] = None,
send_sms: Optional[bool] = None,
git_diff: Optional[str] = None,
) -> Dict[str, Any]:
"""Build request data for creating a message.
Args:
content: Message content
agent_instance_id: Agent instance ID (already validated)
requires_user_input: Whether message requires user input
agent_type: Optional agent type
send_push: Optional push notification flag
send_email: Optional email notification flag
send_sms: Optional SMS notification flag
git_diff: Optional git diff content
Returns:
Dictionary of request data
"""
data: Dict[str, Any] = {
"content": content,
"requires_user_input": requires_user_input,
"agent_instance_id": agent_instance_id,
}
if agent_type:
data["agent_type"] = agent_type
if send_push is not None:
data["send_push"] = send_push
if send_email is not None:
data["send_email"] = send_email
if send_sms is not None:
data["send_sms"] = send_sms
if git_diff is not None:
data["git_diff"] = git_diff
return data

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "omnara"
version = "1.3.15"
version = "1.3.16"
description = "Omnara Agent Dashboard - MCP Server and Python SDK"
readme = "README.md"
requires-python = ">=3.10"

View File

@@ -4,9 +4,9 @@ This directory contains the write operations server for the Agent Dashboard syst
## Overview
The servers directory implements all write operations that agents need:
- Logging their progress and receiving user feedback
- Asking questions to users
The servers directory implements all write operations that agents need through a unified messaging system:
- Sending messages (both informational steps and questions requiring user input)
- Receiving user responses and feedback
- Managing session lifecycle
All operations are authenticated and multi-tenant, ensuring data isolation between users.
@@ -36,10 +36,10 @@ The servers use a separate authentication system from the main backend:
## Key Features
- **Write-only operations**: Designed for agent interactions, not data retrieval
- **Unified messaging**: All agent interactions use the same message-based API
- **Automatic session management**: Creates sessions on first interaction
- **User feedback delivery**: Agents receive feedback when logging steps
- **Non-blocking questions**: Async implementation for user interactions
- **Message queuing**: Agents receive unread user messages when sending new messages
- **Flexible interactions**: Messages can be informational or require user input
- **Multi-protocol support**: Same functionality via MCP or REST API
## Running the Server

View File

@@ -1,29 +1,58 @@
"""Pydantic models for FastAPI request/response schemas."""
from typing import Optional
from pydantic import BaseModel, Field
from servers.shared.models import (
BaseLogStepRequest,
BaseLogStepResponse,
BaseAskQuestionRequest,
BaseEndSessionRequest,
BaseEndSessionResponse,
)
# Request models
class LogStepRequest(BaseLogStepRequest):
"""FastAPI-specific request model for logging a step."""
class CreateMessageRequest(BaseModel):
"""Request model for creating a new message."""
pass
agent_instance_id: str = Field(
...,
description="Existing agent instance ID. Creates a new agent instance if ID doesn't exist.",
)
agent_type: str | None = Field(
None, description="Type of agent (e.g., 'claude_code', 'cursor')"
)
content: str = Field(
..., description="Message content (step description or question text)"
)
requires_user_input: bool = Field(
False, description="Whether this message requires user input (is a question)"
)
send_email: bool | None = Field(
None,
description="Whether to send email notification (overrides user preference)",
)
send_sms: bool | None = Field(
None, description="Whether to send SMS notification (overrides user preference)"
)
send_push: bool | None = Field(
None,
description="Whether to send push notification (overrides user preference)",
)
git_diff: str | None = Field(
None,
description="Git diff content to store with the instance",
)
class AskQuestionRequest(BaseAskQuestionRequest):
"""FastAPI-specific request model for asking a question."""
class CreateUserMessageRequest(BaseModel):
"""Request model for creating a user message."""
pass
agent_instance_id: str = Field(
..., description="Agent instance ID to send the message to"
)
content: str = Field(..., description="Message content")
mark_as_read: bool = Field(
True,
description="Whether to mark this message as read (update last_read_message_id)",
)
class EndSessionRequest(BaseEndSessionRequest):
@@ -33,35 +62,60 @@ class EndSessionRequest(BaseEndSessionRequest):
# Response models
class LogStepResponse(BaseLogStepResponse):
"""FastAPI-specific response model for log step endpoint."""
class MessageResponse(BaseModel):
"""Response model for individual messages."""
pass
id: str = Field(..., description="Message ID")
content: str = Field(..., description="Message content")
sender_type: str = Field(..., description="Sender type: 'agent' or 'user'")
created_at: str = Field(..., description="ISO timestamp when message was created")
requires_user_input: bool = Field(
..., description="Whether this message requires user input"
)
# FastAPI-specific: Response only contains question ID (non-blocking)
class AskQuestionResponse(BaseModel):
"""FastAPI-specific response model for ask question endpoint."""
class CreateMessageResponse(BaseModel):
"""Response model for create message endpoint."""
question_id: str = Field(..., description="ID of the created question")
success: bool = Field(
..., description="Whether the message was created successfully"
)
agent_instance_id: str = Field(
..., description="Agent instance ID (new or existing)"
)
message_id: str = Field(..., description="ID of the message that was created")
queued_user_messages: list[MessageResponse] = Field(
default_factory=list,
description="List of queued user messages with full metadata",
)
class CreateUserMessageResponse(BaseModel):
"""Response model for create user message endpoint."""
success: bool = Field(
..., description="Whether the message was created successfully"
)
message_id: str = Field(..., description="ID of the created message")
marked_as_read: bool = Field(
..., description="Whether the message was marked as read"
)
class GetMessagesResponse(BaseModel):
"""Response model for get messages endpoint."""
agent_instance_id: str = Field(..., description="Agent instance ID")
messages: list[MessageResponse] = Field(
default_factory=list, description="List of messages"
)
status: str = Field(
"ok",
description="Status: 'ok' if messages retrieved, 'stale' if last_read_message_id is outdated",
)
class EndSessionResponse(BaseEndSessionResponse):
"""FastAPI-specific response model for end session endpoint."""
pass
# FastAPI-specific: Additional model for polling question status
class QuestionStatusResponse(BaseModel):
"""Response model for question status endpoint."""
question_id: str
status: str = Field(
..., description="Status of the question: 'pending' or 'answered'"
)
answer: Optional[str] = Field(
None, description="Answer text if status is 'answered'"
)
asked_at: str
answered_at: Optional[str] = None

View File

@@ -1,66 +1,93 @@
"""API routes for agent operations."""
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, status
from shared.database.session import get_db
from servers.shared.db import get_question, get_agent_instance
from servers.shared.core import (
process_log_step,
create_agent_question,
process_end_session,
from shared.database import Message, AgentInstance, SenderType, AgentStatus
from servers.shared.db import (
send_agent_message,
end_session,
get_or_create_agent_instance,
get_queued_user_messages,
create_user_message,
)
from servers.shared.notification_utils import send_message_notifications
from .auth import get_current_user_id
from .models import (
AskQuestionRequest,
AskQuestionResponse,
CreateMessageRequest,
CreateMessageResponse,
CreateUserMessageRequest,
CreateUserMessageResponse,
EndSessionRequest,
EndSessionResponse,
LogStepRequest,
LogStepResponse,
QuestionStatusResponse,
GetMessagesResponse,
MessageResponse,
)
agent_router = APIRouter(tags=["agents"])
@agent_router.post("/steps", response_model=LogStepResponse)
def log_step(
request: LogStepRequest, user_id: Annotated[str, Depends(get_current_user_id)]
) -> LogStepResponse:
"""Log a high-level step the agent is performing.
@agent_router.post("/messages/agent", response_model=CreateMessageResponse)
async def create_agent_message_endpoint(
request: CreateMessageRequest, user_id: Annotated[str, Depends(get_current_user_id)]
) -> CreateMessageResponse:
"""Create a new agent message.
This endpoint:
- Creates or retrieves an agent instance
- Logs the step with a sequential number
- Returns any unretrieved user feedback
User feedback is automatically marked as retrieved.
- Creates a new message
- Returns the message ID and any queued user messages
- Sends notifications if requested
"""
db = next(get_db())
try:
# Use shared business logic
instance_id, step_number, user_feedback = process_log_step(
# Use the unified send_agent_message function
instance_id, message_id, queued_messages = await send_agent_message(
db=db,
agent_type=request.agent_type,
step_description=request.step_description,
user_id=user_id,
agent_instance_id=request.agent_instance_id,
send_email=request.send_email,
send_sms=request.send_sms,
send_push=request.send_push,
content=request.content,
user_id=user_id,
agent_type=request.agent_type,
requires_user_input=request.requires_user_input,
git_diff=request.git_diff,
)
return LogStepResponse(
# Send notifications if requested
await send_message_notifications(
db=db,
instance_id=UUID(instance_id),
content=request.content,
requires_user_input=request.requires_user_input,
send_email=request.send_email,
send_sms=request.send_sms,
send_push=request.send_push,
)
db.commit()
message_responses = [
MessageResponse(
id=str(msg.id),
content=msg.content,
sender_type=msg.sender_type.value,
created_at=msg.created_at.isoformat(),
requires_user_input=msg.requires_user_input,
)
for msg in queued_messages
]
return CreateMessageResponse(
success=True,
agent_instance_id=instance_id,
step_number=step_number,
user_feedback=user_feedback,
message_id=message_id,
queued_user_messages=message_responses,
)
except ValueError as e:
db.rollback()
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
except Exception as e:
db.rollback()
@@ -72,40 +99,202 @@ def log_step(
db.close()
@agent_router.post("/questions", response_model=AskQuestionResponse)
async def ask_question(
request: AskQuestionRequest, user_id: Annotated[str, Depends(get_current_user_id)]
) -> AskQuestionResponse:
"""Create a question for the user to answer.
@agent_router.post("/messages/user", response_model=CreateUserMessageResponse)
async def create_user_message_endpoint(
request: CreateUserMessageRequest,
user_id: Annotated[str, Depends(get_current_user_id)],
) -> CreateUserMessageResponse:
"""Create a user message.
This endpoint:
- Creates a question record in the database
- Returns immediately with the question ID
- Client should poll GET /questions/{question_id} for the answer
Note: This endpoint is non-blocking.
- Creates a user message for an existing agent instance
- Optionally marks it as read (updates last_read_message_id)
- Returns the message ID
"""
db = next(get_db())
try:
# Use shared business logic to create question
question = await create_agent_question(
# Create the user message
message_id, marked_as_read = create_user_message(
db=db,
agent_instance_id=request.agent_instance_id,
question_text=request.question_text,
content=request.content,
user_id=user_id,
send_email=request.send_email,
send_sms=request.send_sms,
send_push=request.send_push,
git_diff=request.git_diff,
mark_as_read=request.mark_as_read,
)
# Commit the transaction
db.commit()
# FastAPI-specific: Return immediately with question ID (non-blocking)
return AskQuestionResponse(
question_id=str(question.id),
return CreateUserMessageResponse(
success=True,
message_id=message_id,
marked_as_read=marked_as_read,
)
except ValueError as e:
db.rollback()
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Internal server error: {str(e)}",
)
finally:
db.close()
@agent_router.get("/messages/pending", response_model=GetMessagesResponse)
async def get_pending_messages(
agent_instance_id: str,
last_read_message_id: str | None,
user_id: Annotated[str, Depends(get_current_user_id)],
) -> GetMessagesResponse:
"""Get pending user messages for an agent instance.
This endpoint:
- Returns all user messages since the provided last_read_message_id
- Updates the last_read_message_id to the latest message
- Returns None status if another process has already read the messages
"""
db = next(get_db())
try:
# Validate access (agent_instance_id is required here)
instance = get_or_create_agent_instance(db, agent_instance_id, user_id)
# Parse last_read_message_id if provided
last_read_uuid = UUID(last_read_message_id) if last_read_message_id else None
# Get queued messages
messages = get_queued_user_messages(db, instance.id, last_read_uuid)
# If messages is None, another process has read the messages
if messages is None:
return GetMessagesResponse(
agent_instance_id=agent_instance_id,
messages=[],
status="stale", # Indicate that the last_read_message_id is stale
)
db.commit()
# Convert to response format
message_responses = [
MessageResponse(
id=str(msg.id),
content=msg.content,
sender_type=msg.sender_type.value,
created_at=msg.created_at.isoformat(),
requires_user_input=msg.requires_user_input,
)
for msg in messages
]
return GetMessagesResponse(
agent_instance_id=agent_instance_id,
messages=message_responses,
status="ok",
)
except ValueError as e:
db.rollback()
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Internal server error: {str(e)}",
)
finally:
db.close()
@agent_router.patch("/messages/{message_id}/request-input")
async def request_user_input_endpoint(
message_id: UUID,
user_id: Annotated[str, Depends(get_current_user_id)],
) -> dict:
"""Update an agent message to request user input.
This endpoint:
- Updates the requires_user_input field from false to true
- Only works on agent messages that don't already require input
- Returns any queued user messages since this message
- Triggers a notification via the database trigger
"""
db = next(get_db())
try:
# Find the message and verify it's an agent message belonging to the user
message = (
db.query(Message)
.join(AgentInstance, Message.agent_instance_id == AgentInstance.id)
.filter(
Message.id == message_id,
Message.sender_type == SenderType.AGENT,
AgentInstance.user_id == user_id,
)
.first()
)
if not message:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Agent message not found or access denied",
)
# Check if it already requires user input
if message.requires_user_input:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Message already requires user input",
)
# Update the field
message.requires_user_input = True
queued_messages = get_queued_user_messages(
db, message.agent_instance_id, message_id
)
if not queued_messages:
agent_instance = (
db.query(AgentInstance)
.filter(AgentInstance.id == message.agent_instance_id)
.first()
)
if agent_instance:
agent_instance.status = AgentStatus.AWAITING_INPUT
await send_message_notifications(
db=db,
instance_id=message.agent_instance_id,
content=message.content,
requires_user_input=True,
)
db.commit()
message_responses = [
MessageResponse(
id=str(msg.id),
content=msg.content,
sender_type=msg.sender_type.value,
created_at=msg.created_at.isoformat(),
requires_user_input=msg.requires_user_input,
)
for msg in (queued_messages or [])
]
return {
"success": True,
"message_id": str(message_id),
"agent_instance_id": str(message.agent_instance_id),
"messages": message_responses,
"status": "ok" if queued_messages is not None else "stale",
}
except HTTPException:
db.rollback()
raise
except Exception as e:
db.rollback()
@@ -117,48 +306,8 @@ async def ask_question(
db.close()
@agent_router.get("/questions/{question_id}", response_model=QuestionStatusResponse)
async def get_question_status(
question_id: str, user_id: Annotated[str, Depends(get_current_user_id)]
) -> QuestionStatusResponse:
"""Get the status of a question.
This endpoint allows polling for question answers without blocking.
Returns the current status and answer (if available).
"""
db = next(get_db())
try:
# Get the question
question = get_question(db, question_id)
if not question:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Question not found"
)
# Verify the question belongs to the authenticated user
instance = get_agent_instance(db, str(question.agent_instance_id))
if not instance or str(instance.user_id) != user_id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="Access denied"
)
# Return question status
return QuestionStatusResponse(
question_id=str(question.id),
status="answered" if question.answer_text else "pending",
answer=question.answer_text,
asked_at=question.asked_at.isoformat(),
answered_at=question.answered_at.isoformat()
if question.answered_at
else None,
)
finally:
db.close()
@agent_router.post("/sessions/end", response_model=EndSessionResponse)
async def end_session(
async def end_session_endpoint(
request: EndSessionRequest, user_id: Annotated[str, Depends(get_current_user_id)]
) -> EndSessionResponse:
"""End an agent session and mark it as completed.
@@ -166,24 +315,26 @@ async def end_session(
This endpoint:
- Marks the agent instance as COMPLETED
- Sets the session end time
- Deactivates any pending questions
"""
db = next(get_db())
try:
# Use shared business logic
instance_id, final_status = process_end_session(
# Use the end_session function from queries
instance_id, final_status = end_session(
db=db,
agent_instance_id=request.agent_instance_id,
user_id=user_id,
)
db.commit()
return EndSessionResponse(
success=True,
agent_instance_id=instance_id,
final_status=final_status,
)
except ValueError as e:
db.rollback()
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
except Exception as e:
db.rollback()

View File

@@ -73,13 +73,13 @@ mcp = FastMCP("Agent Dashboard MCP Server", auth=auth)
@mcp.tool(name="log_step", description=LOG_STEP_DESCRIPTION)
@require_auth
def log_step_tool(
async def log_step_tool(
agent_instance_id: str | None = None,
step_description: str = "",
_user_id: str = "", # Injected by decorator
) -> LogStepResponse:
agent_type = detect_agent_type_from_headers()
return log_step_impl(
return await log_step_impl(
agent_instance_id=agent_instance_id,
agent_type=agent_type,
step_description=step_description,

View File

@@ -211,21 +211,23 @@ async def log_step_tool(
# Get git diff if enabled
git_diff = get_git_diff()
response = await client.log_step(
response = await client.send_message(
agent_type=agent_type,
step_description=step_description,
content=step_description,
agent_instance_id=agent_instance_id,
requires_user_input=False, # Log steps don't require user input
git_diff=git_diff,
)
if current_agent_instance_id is None:
current_agent_instance_id = response.agent_instance_id
# Return LogStepResponse with queued messages as user_feedback
return LogStepResponse(
success=response.success,
agent_instance_id=response.agent_instance_id,
step_number=response.step_number,
user_feedback=response.user_feedback,
step_number=1, # We don't track step numbers anymore
user_feedback=response.queued_user_messages,
)
@@ -244,6 +246,7 @@ async def ask_question_tool(
if not question_text:
raise ValueError("question_text is required")
agent_type = detect_agent_type_from_environment()
client = get_client()
# Get git diff if enabled
@@ -253,17 +256,24 @@ async def ask_question_tool(
current_agent_instance_id = agent_instance_id
try:
response = await client.ask_question(
response = await client.send_message(
agent_type=agent_type,
agent_instance_id=agent_instance_id,
question_text=question_text,
content=question_text,
requires_user_input=True, # Questions require user input
timeout_minutes=1440, # 24 hours default
poll_interval=10.0,
git_diff=git_diff,
)
# Get the answer from queued_user_messages
answer = (
response.queued_user_messages[0] if response.queued_user_messages else ""
)
return AskQuestionResponse(
answer=response.answer,
question_id=response.question_id,
answer=answer,
question_id=response.message_id,
)
except OmnaraTimeoutError:
raise TimeoutError("Question timed out waiting for user response")
@@ -316,10 +326,11 @@ async def approve_tool(
instance_id = current_agent_instance_id
else:
# Only create a new instance if we don't have one
response = await client.log_step(
response = await client.send_message(
agent_type="Claude Code",
step_description="Permission request",
content="Permission request",
agent_instance_id=None,
requires_user_input=False,
)
instance_id = response.agent_instance_id
current_agent_instance_id = instance_id
@@ -384,15 +395,21 @@ async def approve_tool(
try:
# Ask the permission question
answer_response = await client.ask_question(
response = await client.send_message(
agent_instance_id=instance_id,
question_text=question_text,
content=question_text,
requires_user_input=True,
timeout_minutes=1440,
poll_interval=10.0,
)
# Parse the answer to determine approval
answer = answer_response.answer.strip()
# Get the answer from queued_user_messages
answer = (
response.queued_user_messages[0].strip()
if response.queued_user_messages
else ""
)
# Handle option selections by comparing with actual option text
if answer == option_yes:
@@ -432,7 +449,7 @@ async def approve_tool(
# Custom text response - treat as denial with message
return {
"behavior": "deny",
"message": f"Permission denied by user: {answer_response.answer}",
"message": f"Permission denied by user: {answer}",
}
except OmnaraTimeoutError:

View File

@@ -5,21 +5,23 @@ the hosted server and stdio server. The authentication logic is handled
by the individual servers.
"""
import uuid
from uuid import UUID
from fastmcp import Context
from shared.database.session import get_db
from servers.shared.db import wait_for_answer
from servers.shared.core import (
process_log_step,
create_agent_question,
process_end_session,
from servers.shared.db import (
send_agent_message,
end_session,
wait_for_answer,
create_agent_message,
get_or_create_agent_instance,
)
from .models import AskQuestionResponse, EndSessionResponse, LogStepResponse
def log_step_impl(
async def log_step_impl(
agent_instance_id: str | None = None,
agent_type: str = "",
step_description: str = "",
@@ -29,20 +31,25 @@ def log_step_impl(
Args:
agent_instance_id: Existing agent instance ID (optional)
agent_type: Name of the agent (e.g., 'Claude Code', 'Cursor')
agent_type: Name of the agent (e.g., 'claude_code', 'cursor')
step_description: High-level description of the current step
user_id: Authenticated user ID
Returns:
LogStepResponse with success status, instance details, and user feedback
"""
if agent_instance_id:
# Generate a new UUID if agent_instance_id is not provided
if not agent_instance_id:
agent_instance_id = str(uuid.uuid4())
else:
# Validate the provided UUID
try:
UUID(agent_instance_id)
except ValueError:
raise ValueError(
f"Invalid agent_instance_id format: must be a valid UUID, got '{agent_instance_id}'"
)
if not agent_type:
raise ValueError("agent_type is required")
if not step_description:
@@ -53,19 +60,37 @@ def log_step_impl(
db = next(get_db())
try:
instance_id, step_number, user_feedback = process_log_step(
# Use send_agent_message for steps (requires_user_input=False)
instance_id, message_id, queued_messages = await send_agent_message(
db=db,
agent_type=agent_type,
step_description=step_description,
user_id=user_id,
agent_instance_id=agent_instance_id,
content=step_description,
user_id=user_id,
agent_type=agent_type,
requires_user_input=False,
)
# For backward compatibility, we need to return a step number
# Count the number of agent messages (steps) for this instance
from shared.database import Message, SenderType
step_count = (
db.query(Message)
.filter(
Message.agent_instance_id == UUID(instance_id),
Message.sender_type == SenderType.AGENT,
Message.requires_user_input.is_(False),
)
.count()
)
db.commit()
return LogStepResponse(
success=True,
agent_instance_id=instance_id,
step_number=step_number,
user_feedback=user_feedback,
step_number=step_count,
user_feedback=[msg.content for msg in queued_messages],
)
except Exception:
@@ -108,13 +133,21 @@ async def ask_question_impl(
db = next(get_db())
try:
question = await create_agent_question(
# Validate access first (agent_instance_id is required here)
instance = get_or_create_agent_instance(db, agent_instance_id, user_id)
# Create question message (requires_user_input=True)
question = create_agent_message(
db=db,
agent_instance_id=agent_instance_id,
question_text=question_text,
user_id=user_id,
instance_id=instance.id,
content=question_text,
requires_user_input=True,
)
# Commit to make the question visible
db.commit()
# Wait for answer
answer = await wait_for_answer(db, question.id, tool_context=tool_context)
if answer is None:
@@ -156,12 +189,15 @@ def end_session_impl(
db = next(get_db())
try:
instance_id, final_status = process_end_session(
instance_id, final_status = end_session(
db=db,
agent_instance_id=agent_instance_id,
user_id=user_id,
)
# Commit the transaction
db.commit()
return EndSessionResponse(
success=True,
agent_instance_id=instance_id,

View File

@@ -1,15 +0,0 @@
"""Core business logic shared between servers."""
from .agents import (
validate_agent_access,
process_log_step,
create_agent_question,
process_end_session,
)
__all__ = [
"validate_agent_access",
"process_log_step",
"create_agent_question",
"process_end_session",
]

View File

@@ -1,167 +0,0 @@
"""Shared business logic for agent operations.
This module contains the common logic used by both MCP and FastAPI servers,
avoiding code duplication while allowing protocol-specific implementations.
"""
import logging
from sqlalchemy.orm import Session
from servers.shared.db import (
get_agent_instance,
create_agent_instance,
log_step,
create_question,
get_and_mark_unretrieved_feedback,
create_or_get_user_agent,
end_session,
)
from shared.database.utils import sanitize_git_diff
logger = logging.getLogger(__name__)
def validate_agent_access(db: Session, agent_instance_id: str, user_id: str):
"""Validate that a user has access to an agent instance.
Args:
db: Database session
agent_instance_id: Agent instance ID to validate
user_id: User ID requesting access
Returns:
The agent instance if validation passes
Raises:
ValueError: If instance not found or user doesn't have access
"""
instance = get_agent_instance(db, agent_instance_id)
if not instance:
raise ValueError(f"Agent instance {agent_instance_id} not found")
if str(instance.user_id) != user_id:
raise ValueError(
"Access denied. Agent instance does not belong to authenticated user."
)
return instance
def process_log_step(
db: Session,
agent_type: str,
step_description: str,
user_id: str,
agent_instance_id: str | None = None,
send_email: bool | None = None,
send_sms: bool | None = None,
send_push: bool | None = None,
git_diff: str | None = None,
) -> tuple[str, int, list[str]]:
"""Process a log step operation with all common logic.
Args:
db: Database session
agent_type: Type of agent
step_description: Description of the step
user_id: Authenticated user ID
agent_instance_id: Optional existing instance ID
Returns:
Tuple of (agent_instance_id, step_number, user_feedback)
"""
# Get or create user agent type
agent_type_obj = create_or_get_user_agent(db, agent_type, user_id)
# Get or create instance
if agent_instance_id:
instance = validate_agent_access(db, agent_instance_id, user_id)
else:
instance = create_agent_instance(db, agent_type_obj.id, user_id)
# Create step with notification preferences
step = log_step(db, instance.id, step_description, send_email, send_sms, send_push)
# Update git diff if provided
if git_diff is not None:
# Validate and sanitize git diff
sanitized_diff = sanitize_git_diff(git_diff)
if sanitized_diff is not None: # Allow empty string (cleared diff)
instance.git_diff = sanitized_diff
db.commit()
else:
logger.warning(
f"Invalid git diff format for instance {instance.id}, skipping git diff update"
)
# Get unretrieved feedback
feedback = get_and_mark_unretrieved_feedback(db, instance.id)
return str(instance.id), step.step_number, feedback
async def create_agent_question(
db: Session,
agent_instance_id: str,
question_text: str,
user_id: str,
send_email: bool | None = None,
send_sms: bool | None = None,
send_push: bool | None = None,
git_diff: str | None = None,
):
"""Create a question with validation and send push notification.
Args:
db: Database session
agent_instance_id: Agent instance ID
question_text: Question to ask
user_id: Authenticated user ID
Returns:
The created question object
"""
# Validate access
instance = validate_agent_access(db, agent_instance_id, user_id)
# Update git diff if provided
if git_diff is not None:
# Validate and sanitize git diff
sanitized_diff = sanitize_git_diff(git_diff)
if sanitized_diff is not None: # Allow empty string (cleared diff)
instance.git_diff = sanitized_diff
db.commit()
else:
logger.warning(
f"Invalid git diff format for instance {instance.id}, skipping git diff update"
)
# Create question
# Note: Notifications sent by create_question() function based on parameters
question = await create_question(
db, instance.id, question_text, send_email, send_sms, send_push
)
return question
def process_end_session(
db: Session,
agent_instance_id: str,
user_id: str,
) -> tuple[str, str]:
"""Process ending a session with validation.
Args:
db: Database session
agent_instance_id: Agent instance ID to end
user_id: Authenticated user ID
Returns:
Tuple of (agent_instance_id, final_status)
"""
# Validate access
instance = validate_agent_access(db, agent_instance_id, user_id)
# End the session
updated_instance = end_session(db, instance.id)
return str(updated_instance.id), updated_instance.status.value

View File

@@ -1,25 +1,31 @@
"""Database queries and operations for servers."""
from .queries import (
# Low-level functions
create_agent_instance,
create_or_get_user_agent,
create_question,
create_agent_message,
create_user_message,
end_session,
get_agent_instance,
get_and_mark_unretrieved_feedback,
get_question,
log_step,
get_queued_user_messages,
get_or_create_agent_instance,
wait_for_answer,
# High-level functions
send_agent_message,
)
__all__ = [
# Low-level functions
"create_agent_instance",
"create_or_get_user_agent",
"create_question",
"create_agent_message",
"create_user_message",
"end_session",
"get_agent_instance",
"get_and_mark_unretrieved_feedback",
"get_question",
"log_step",
"get_queued_user_messages",
"get_or_create_agent_instance",
"wait_for_answer",
# High-level functions
"send_agent_message",
]

View File

@@ -6,19 +6,15 @@ from uuid import UUID
from shared.database import (
AgentInstance,
AgentQuestion,
AgentStatus,
AgentStep,
AgentUserFeedback,
Message,
SenderType,
UserAgent,
User,
)
from shared.database.billing_operations import check_agent_limit
from sqlalchemy import func
from shared.database.utils import sanitize_git_diff
from sqlalchemy.orm import Session
from fastmcp import Context
from servers.shared.notifications import push_service
from servers.shared.twilio_service import twilio_service
logger = logging.getLogger(__name__)
@@ -40,8 +36,7 @@ def create_or_get_user_agent(db: Session, name: str, user_id: str) -> UserAgent:
is_active=True,
)
db.add(user_agent)
db.commit()
db.refresh(user_agent)
db.flush() # Flush to get the user_agent ID
return user_agent
@@ -56,8 +51,6 @@ def create_agent_instance(
user_agent_id=user_agent_id, user_id=UUID(user_id), status=AgentStatus.ACTIVE
)
db.add(instance)
db.commit()
db.refresh(instance)
return instance
@@ -66,215 +59,153 @@ def get_agent_instance(db: Session, instance_id: str) -> AgentInstance | None:
return db.query(AgentInstance).filter(AgentInstance.id == instance_id).first()
def log_step(
db: Session,
instance_id: UUID,
description: str,
send_email: bool | None = None,
send_sms: bool | None = None,
send_push: bool | None = None,
) -> AgentStep:
"""Log a new step for an agent instance"""
# Get the next step number
max_step = (
db.query(func.max(AgentStep.step_number))
.filter(AgentStep.agent_instance_id == instance_id)
.scalar()
)
next_step_number = (max_step or 0) + 1
def get_or_create_agent_instance(
db: Session, agent_instance_id: str, user_id: str, agent_type: str | None = None
) -> AgentInstance:
"""Get an existing agent instance or create a new one.
# Create the step
step = AgentStep(
agent_instance_id=instance_id,
step_number=next_step_number,
description=description,
)
db.add(step)
db.commit()
db.refresh(step)
Args:
db: Database session
agent_instance_id: Agent instance ID (always required)
user_id: User ID requesting access
agent_type: Agent type name (required only when creating new instance)
# Send notifications if requested (all default to False for log steps)
if send_email or send_sms or send_push:
# Get instance details for notifications
instance = (
db.query(AgentInstance).filter(AgentInstance.id == instance_id).first()
)
if instance:
user = db.query(User).filter(User.id == instance.user_id).first()
Returns:
The agent instance (existing or newly created)
if user:
agent_name = (
instance.user_agent.name if instance.user_agent else "Agent"
)
Raises:
ValueError: If instance not found, user doesn't have access, or agent_type missing when creating
"""
# Try to get existing instance
instance = get_agent_instance(db, agent_instance_id)
# Override defaults - for log steps, all notifications default to False
should_send_push = send_push if send_push is not None else False
should_send_email = send_email if send_email is not None else False
should_send_sms = send_sms if send_sms is not None else False
# Send push notification if explicitly enabled
if should_send_push:
try:
asyncio.create_task(
push_service.send_step_notification(
db=db,
user_id=instance.user_id,
instance_id=str(instance.id),
step_number=step.step_number,
agent_name=agent_name,
step_description=description,
)
)
except Exception as e:
logger.error(
f"Failed to send push notification for step {step.id}: {e}"
)
# Send Twilio notifications if explicitly enabled
if should_send_email or should_send_sms:
try:
asyncio.create_task(
twilio_service.send_step_notification(
db=db,
user_id=instance.user_id,
instance_id=str(instance.id),
step_number=step.step_number,
agent_name=agent_name,
step_description=description,
send_email=should_send_email,
send_sms=should_send_sms,
)
)
except Exception as e:
logger.error(
f"Failed to send Twilio notification for step {step.id}: {e}"
)
return step
async def create_question(
db: Session,
instance_id: UUID,
question_text: str,
send_email: bool | None = None,
send_sms: bool | None = None,
send_push: bool | None = None,
) -> AgentQuestion:
"""Create a new question for an agent instance"""
# Mark any existing active questions as inactive
db.query(AgentQuestion).filter(
AgentQuestion.agent_instance_id == instance_id, AgentQuestion.is_active
).update({"is_active": False})
# Update agent instance status to awaiting_input
instance = db.query(AgentInstance).filter(AgentInstance.id == instance_id).first()
if instance and instance.status == AgentStatus.ACTIVE:
instance.status = AgentStatus.AWAITING_INPUT
# Create new question
question = AgentQuestion(
agent_instance_id=instance_id, question_text=question_text, is_active=True
)
db.add(question)
db.commit()
db.refresh(question)
# Send notifications based on user preferences
if instance:
# Get user for checking preferences
user = db.query(User).filter(User.id == instance.user_id).first()
if user:
agent_name = instance.user_agent.name if instance.user_agent else "Agent"
# Determine notification preferences
# For questions: push defaults to True (or user preference), email/SMS default to False
should_send_push = (
send_push if send_push is not None else user.push_notifications_enabled
)
should_send_email = (
send_email
if send_email is not None
else user.email_notifications_enabled
)
should_send_sms = (
send_sms if send_sms is not None else user.sms_notifications_enabled
# Validate access to existing instance
if str(instance.user_id) != user_id:
raise ValueError(
"Access denied. Agent instance does not belong to authenticated user."
)
return instance
else:
# Create new instance with the provided ID
if not agent_type:
raise ValueError("agent_type is required when creating new instance")
# Send push notification if enabled
if should_send_push:
try:
await push_service.send_question_notification(
db=db,
user_id=instance.user_id,
instance_id=str(instance.id),
question_id=str(question.id),
agent_name=agent_name,
question_text=question_text,
)
except Exception as e:
logger.error(
f"Failed to send push notification for question {question.id}: {e}"
)
agent_type_obj = create_or_get_user_agent(db, agent_type, user_id)
# Send Twilio notification if enabled (email and/or SMS)
if should_send_email or should_send_sms:
try:
twilio_service.send_question_notification(
db=db,
user_id=instance.user_id,
instance_id=str(instance.id),
question_id=str(question.id),
agent_name=agent_name,
question_text=question_text,
send_email=should_send_email,
send_sms=should_send_sms,
)
except Exception as e:
logger.error(
f"Failed to send Twilio notification for question {question.id}: {e}"
)
# Check usage limits if billing is enabled
check_agent_limit(UUID(user_id), db)
return question
# Create instance with the specific ID
instance = AgentInstance(
id=UUID(agent_instance_id),
user_agent_id=agent_type_obj.id,
user_id=UUID(user_id),
status=AgentStatus.ACTIVE,
)
db.add(instance)
db.flush() # Flush to ensure the instance is in the session with its ID
return instance
def end_session(db: Session, agent_instance_id: str, user_id: str) -> tuple[str, str]:
"""End an agent session by marking it as completed.
Args:
db: Database session
agent_instance_id: Agent instance ID to end
user_id: Authenticated user ID
Returns:
Tuple of (agent_instance_id, final_status)
"""
instance = get_or_create_agent_instance(db, agent_instance_id, user_id)
instance.status = AgentStatus.COMPLETED
instance.ended_at = datetime.now(timezone.utc)
return str(instance.id), instance.status.value
def create_agent_message(
db: Session,
instance_id: UUID,
content: str,
requires_user_input: bool = False,
) -> Message:
"""Create a new agent message without committing"""
instance = db.query(AgentInstance).filter(AgentInstance.id == instance_id).first()
if instance and instance.status != AgentStatus.COMPLETED:
if requires_user_input:
instance.status = AgentStatus.AWAITING_INPUT
else:
instance.status = AgentStatus.ACTIVE
message = Message(
agent_instance_id=instance_id,
sender_type=SenderType.AGENT,
content=content,
requires_user_input=requires_user_input,
)
db.add(message)
db.flush() # Flush to get the message ID
# Update last read message
if instance:
instance.last_read_message_id = message.id
return message
async def wait_for_answer(
db: Session,
question_id: UUID,
timeout: int = 86400,
timeout_seconds: int = 86400, # 24 hours default
tool_context: Context | None = None,
) -> str | None:
"""
Wait for an answer to a question (async non-blocking)
Args:
db: Database session
question_id: Question ID to wait for
timeout: Maximum time to wait in seconds (default 24 hours)
Returns:
Answer text if received, None if timeout
"""
"""Wait for an answer to a question using polling"""
start_time = time.time()
last_progress_report = start_time
total_minutes = int(timeout / 60)
total_minutes = timeout_seconds // 60
# Report initial progress (0 minutes elapsed)
if tool_context:
await tool_context.report_progress(0, total_minutes)
# Get the question message
question = db.query(Message).filter(Message.id == question_id).first()
if not question or not question.requires_user_input:
return None
while time.time() - start_time < timeout:
# Check for answer
db.commit() # Ensure we see latest data
question = (
db.query(AgentQuestion).filter(AgentQuestion.id == question_id).first()
while time.time() - start_time < timeout_seconds:
# Check if agent has moved on (last read message changed)
instance = (
db.query(AgentInstance)
.filter(AgentInstance.id == question.agent_instance_id)
.first()
)
if question and question.answer_text is not None:
# If last_read_message_id has changed from our question, agent has moved on
if instance and instance.last_read_message_id != question_id:
return None
# Check for a user message after this question
answer = (
db.query(Message)
.filter(
Message.agent_instance_id == question.agent_instance_id,
Message.sender_type == SenderType.USER,
Message.created_at > question.created_at,
)
.order_by(Message.created_at)
.first()
)
if answer:
# Update last read message to this answer
if instance:
instance.last_read_message_id = answer.id
if tool_context:
await tool_context.report_progress(total_minutes, total_minutes)
return question.answer_text
return answer.content
# Report progress every minute if tool_context is provided
current_time = time.time()
@@ -285,59 +216,178 @@ async def wait_for_answer(
await asyncio.sleep(1)
# Timeout - mark question as inactive
db.query(AgentQuestion).filter(AgentQuestion.id == question_id).update(
{"is_active": False}
)
db.commit()
return None
def get_question(db: Session, question_id: str) -> AgentQuestion | None:
"""Get a question by ID"""
return db.query(AgentQuestion).filter(AgentQuestion.id == question_id).first()
def get_queued_user_messages(
db: Session, instance_id: UUID, last_read_message_id: UUID | None = None
) -> list[Message] | None:
"""Get all user messages since the agent last read them.
Args:
db: Database session
instance_id: Agent instance ID
last_read_message_id: The message ID the agent last read (optional)
Returns:
- None if last_read_message_id doesn't match the instance's current last_read_message_id
- Empty list if no new messages
- List of messages if there are new user messages
"""
# Get the agent instance to check last read message
instance = db.query(AgentInstance).filter(AgentInstance.id == instance_id).first()
if not instance:
return []
if (
last_read_message_id is not None
and instance.last_read_message_id != last_read_message_id
):
return None
# If no last read message, get all user messages
if not instance.last_read_message_id:
messages = (
db.query(Message)
.filter(
Message.agent_instance_id == instance_id,
Message.sender_type == SenderType.USER,
)
.order_by(Message.created_at)
.all()
)
else:
last_read_message = (
db.query(Message)
.filter(Message.id == instance.last_read_message_id)
.first()
)
if not last_read_message:
return []
# Get all user messages after the last read message
messages = (
db.query(Message)
.filter(
Message.agent_instance_id == instance_id,
Message.sender_type == SenderType.USER,
Message.created_at > last_read_message.created_at,
)
.order_by(Message.created_at)
.all()
)
if messages and last_read_message_id is not None:
instance.last_read_message_id = messages[-1].id
return messages
def get_and_mark_unretrieved_feedback(
db: Session, instance_id: UUID, since_time: datetime | None = None
) -> list[str]:
"""Get unretrieved user feedback for an agent instance and mark as retrieved"""
async def send_agent_message(
db: Session,
agent_instance_id: str,
content: str,
user_id: str,
agent_type: str | None = None,
requires_user_input: bool = False,
git_diff: str | None = None,
) -> tuple[str, str, list[Message]]:
"""High-level function to send an agent message and get queued user messages.
query = db.query(AgentUserFeedback).filter(
AgentUserFeedback.agent_instance_id == instance_id,
AgentUserFeedback.retrieved_at.is_(None),
This combines the common pattern of:
1. Getting or creating an agent instance
2. Validating access (if existing instance)
3. Creating a message
4. Updating git diff if provided
5. Getting any queued user messages
Args:
db: Database session
agent_instance_id: Agent instance ID (pass None to create new)
content: Message content
user_id: Authenticated user ID
agent_type: Type of agent (required if creating new instance)
requires_user_input: Whether this is a question requiring response
git_diff: Optional git diff to update on the instance
Returns:
Tuple of (agent_instance_id, message_id, list of queued user message contents)
"""
# Get or create instance using the unified function
instance = get_or_create_agent_instance(db, agent_instance_id, user_id, agent_type)
# Update git diff if provided (but don't commit yet)
if git_diff is not None:
sanitized_diff = sanitize_git_diff(git_diff)
if sanitized_diff is not None: # Allow empty string (cleared diff)
instance.git_diff = sanitized_diff
else:
logger.warning(
f"Invalid git diff format for instance {instance.id}, skipping git diff update"
)
queued_messages = get_queued_user_messages(db, instance.id, None)
# Create the message (this will update last_read_message_id)
message = create_agent_message(
db=db,
instance_id=instance.id,
content=content,
requires_user_input=requires_user_input,
)
if since_time:
query = query.filter(AgentUserFeedback.created_at > since_time)
# Handle the None case (shouldn't happen here since we just created the message)
if queued_messages is None:
queued_messages = []
feedback_list = query.order_by(AgentUserFeedback.created_at).all()
# Mark all feedback as retrieved
for feedback in feedback_list:
feedback.retrieved_at = datetime.now(timezone.utc)
db.commit()
return [feedback.feedback_text for feedback in feedback_list]
return str(instance.id), str(message.id), queued_messages
def end_session(db: Session, instance_id: UUID) -> AgentInstance:
"""End an agent session by marking it as completed"""
instance = db.query(AgentInstance).filter(AgentInstance.id == instance_id).first()
def create_user_message(
db: Session,
agent_instance_id: str,
content: str,
user_id: str,
mark_as_read: bool = True,
) -> tuple[str, bool]:
"""Create a user message for an agent instance.
Args:
db: Database session
agent_instance_id: Agent instance ID to send the message to
content: Message content
user_id: Authenticated user ID
mark_as_read: Whether to update last_read_message_id (default: True)
Returns:
Tuple of (message_id, marked_as_read)
Raises:
ValueError: If instance not found or user doesn't have access
"""
# Get the instance and validate access
instance = get_agent_instance(db, agent_instance_id)
if not instance:
raise ValueError(f"Agent instance {instance_id} not found")
raise ValueError("Agent instance not found")
# Update status to completed
instance.status = AgentStatus.COMPLETED
instance.ended_at = datetime.now(timezone.utc)
if str(instance.user_id) != user_id:
raise ValueError(
"Access denied. Agent instance does not belong to authenticated user."
)
# Mark any active questions as inactive
db.query(AgentQuestion).filter(
AgentQuestion.agent_instance_id == instance_id, AgentQuestion.is_active
).update({"is_active": False})
# Create the user message
message = Message(
agent_instance_id=UUID(agent_instance_id),
sender_type=SenderType.USER,
content=content,
requires_user_input=False,
)
db.add(message)
db.flush() # Get the message ID
db.commit()
db.refresh(instance)
return instance
# Update last_read_message_id if requested
if mark_as_read:
instance.last_read_message_id = message.id
return str(message.id), mark_as_read

View File

@@ -27,7 +27,6 @@ class NotificationServiceBase(ABC):
db: Session,
user_id: UUID,
instance_id: str,
question_id: str,
agent_name: str,
question_text: str,
**kwargs,
@@ -41,7 +40,6 @@ class NotificationServiceBase(ABC):
db: Session,
user_id: UUID,
instance_id: str,
step_number: int,
agent_name: str,
step_description: str,
**kwargs,

View File

@@ -0,0 +1,111 @@
"""Notification utilities for sending push, email, and SMS notifications."""
import logging
from uuid import UUID
from sqlalchemy.orm import Session
from shared.database import User, AgentInstance
from .notifications import push_service
from .twilio_service import twilio_service
logger = logging.getLogger(__name__)
async def send_message_notifications(
db: Session,
instance_id: UUID,
content: str,
requires_user_input: bool,
send_email: bool | None = None,
send_sms: bool | None = None,
send_push: bool | None = None,
) -> None:
"""Send notifications for a message (either step or question).
Args:
db: Database session
instance_id: Agent instance ID
content: Message content
requires_user_input: Whether this message requires user input
send_email: Override email notification preference
send_sms: Override SMS notification preference
send_push: Override push notification preference
"""
# Get instance and user
instance = db.query(AgentInstance).filter(AgentInstance.id == instance_id).first()
if not instance:
logger.warning(f"Instance {instance_id} not found for notifications")
return
user = db.query(User).filter(User.id == instance.user_id).first()
if not user:
logger.warning(f"User {instance.user_id} not found for notifications")
return
agent_name = instance.user_agent.name if instance.user_agent else "Agent"
# Determine notification preferences based on message type
if requires_user_input:
# For questions: respect user preferences
should_send_push = (
send_push if send_push is not None else user.push_notifications_enabled
)
should_send_email = (
send_email if send_email is not None else user.email_notifications_enabled
)
should_send_sms = (
send_sms if send_sms is not None else user.sms_notifications_enabled
)
else:
# For steps: notifications default to False unless explicitly enabled
should_send_push = send_push if send_push is not None else False
should_send_email = send_email if send_email is not None else False
should_send_sms = send_sms if send_sms is not None else False
# Send push notification if enabled
if should_send_push:
try:
if requires_user_input:
await push_service.send_question_notification(
db=db,
user_id=instance.user_id,
instance_id=str(instance.id),
agent_name=agent_name,
question_text=content,
)
else:
await push_service.send_step_notification(
db=db,
user_id=instance.user_id,
instance_id=str(instance.id),
agent_name=agent_name,
step_description=content,
)
except Exception as e:
logger.error(f"Failed to send push notification: {e}")
# Send Twilio notifications if enabled
if should_send_email or should_send_sms:
try:
if requires_user_input:
await twilio_service.send_question_notification(
db=db,
user_id=instance.user_id,
instance_id=str(instance.id),
agent_name=agent_name,
question_text=content,
send_email=should_send_email,
send_sms=should_send_sms,
)
else:
await twilio_service.send_step_notification(
db=db,
user_id=instance.user_id,
instance_id=str(instance.id),
agent_name=agent_name,
step_description=content,
send_email=should_send_email,
send_sms=should_send_sms,
)
except Exception as e:
logger.error(f"Failed to send Twilio notification: {e}")

View File

@@ -189,7 +189,6 @@ class PushNotificationService(NotificationServiceBase):
db: Session,
user_id: UUID,
instance_id: str,
question_id: str,
agent_name: str,
question_text: str,
) -> bool:
@@ -206,7 +205,6 @@ class PushNotificationService(NotificationServiceBase):
data = {
"type": "new_question",
"instanceId": instance_id,
"questionId": question_id,
}
return await self.send_notification(
@@ -222,14 +220,13 @@ class PushNotificationService(NotificationServiceBase):
db: Session,
user_id: UUID,
instance_id: str,
step_number: int,
agent_name: str,
step_description: str,
) -> bool:
"""Send notification for new agent step"""
# Format agent name for display
display_name = agent_name.replace("_", " ").title()
title = f"{display_name} - Step {step_number}"
title = f"{display_name} - New Step"
# Truncate step description for notification
body = step_description
@@ -239,7 +236,6 @@ class PushNotificationService(NotificationServiceBase):
data = {
"type": "new_step",
"instanceId": instance_id,
"stepNumber": step_number,
}
return await self.send_notification(

View File

@@ -172,12 +172,11 @@ class TwilioNotificationService(NotificationServiceBase):
logger.error(f"Error sending Twilio notifications: {e}")
return results
def send_question_notification(
async def send_question_notification(
self,
db: Session,
user_id: UUID,
instance_id: str,
question_id: str,
agent_name: str,
question_text: str,
send_email: bool | None = None,
@@ -218,7 +217,6 @@ The Omnara Team
db: Session,
user_id: UUID,
instance_id: str,
step_number: int,
agent_name: str,
step_description: str,
send_email: bool | None = None,
@@ -227,13 +225,13 @@ The Omnara Team
"""Send notification for new agent step"""
# Format agent name for display
display_name = agent_name.replace("_", " ").title()
title = f"{display_name} - Step {step_number}"
title = f"{display_name} - New Step"
# Email body with more detail
email_body = f"""
Your agent {display_name} has logged a new step:
Step {step_number}: {step_description}
{step_description}
You can view the full session at: https://omnara.com/dashboard/instances/{instance_id}
@@ -242,9 +240,9 @@ The Omnara Team
"""
# SMS body (shorter)
sms_body = f"{display_name} Step {step_number}: {step_description}"
sms_body = f"{display_name}: {step_description}"
if len(sms_body) > 160:
sms_body = f"{display_name} Step {step_number}: {step_description[:140]}..."
sms_body = f"{display_name}: {step_description[:140]}..."
return self.send_notification(
db=db,

File diff suppressed because it is too large Load Diff

View File

@@ -1,274 +0,0 @@
"""Integration tests using PostgreSQL."""
import pytest
from datetime import datetime, timezone
from uuid import uuid4
# Database fixtures come from conftest.py
# Import the real models
from shared.database.models import (
User,
UserAgent,
AgentInstance,
AgentStep,
AgentQuestion,
AgentUserFeedback,
)
from shared.database.enums import AgentStatus
# Import the core functions we want to test
from servers.shared.core import (
process_log_step,
create_agent_question,
process_end_session,
)
# Using test_db fixture from conftest.py which provides PostgreSQL via testcontainers
@pytest.fixture
def test_user(test_db):
"""Create a test user."""
user = User(
id=uuid4(),
email="integration@test.com",
display_name="Integration Test User",
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
test_db.add(user)
test_db.commit()
return user
@pytest.fixture
def test_user_agent(test_db, test_user):
"""Create a test user agent."""
user_agent = UserAgent(
id=uuid4(),
user_id=test_user.id,
name="Claude Code Test",
is_active=True,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
test_db.add(user_agent)
test_db.commit()
return user_agent
class TestIntegrationFlow:
"""Test the complete integration flow with PostgreSQL."""
@pytest.mark.integration
@pytest.mark.asyncio
async def test_complete_agent_session_flow(
self, test_db, test_user, test_user_agent
):
"""Test a complete agent session from start to finish."""
# Step 1: Create new agent instance
instance_id, step_number, user_feedback = process_log_step(
db=test_db,
user_id=str(test_user.id),
agent_instance_id=None,
agent_type="Claude Code Test",
step_description="Starting integration test task",
)
assert instance_id is not None
assert step_number == 1
assert user_feedback == []
# Verify instance was created in database
instance = test_db.query(AgentInstance).filter_by(id=instance_id).first()
assert instance is not None
assert instance.status == AgentStatus.ACTIVE
assert instance.user_id == test_user.id
# Step 2: Log another step
_, step_number2, _ = process_log_step(
db=test_db,
user_id=str(test_user.id),
agent_instance_id=instance_id,
agent_type="Claude Code Test",
step_description="Processing files",
)
assert step_number2 == 2
# Step 3: Create a question
question = await create_agent_question(
db=test_db,
agent_instance_id=instance_id,
question_text="Should I refactor this module?",
user_id=str(test_user.id),
)
assert question is not None
question_id = question.id
# Verify question in database
question = test_db.query(AgentQuestion).filter_by(id=question_id).first()
assert question is not None
assert question.question_text == "Should I refactor this module?"
assert question.is_active is True
# Step 4: Add user feedback
feedback = AgentUserFeedback(
id=uuid4(),
agent_instance_id=instance_id,
created_by_user_id=test_user.id,
feedback_text="Please use async/await pattern",
created_at=datetime.now(timezone.utc),
)
test_db.add(feedback)
test_db.commit()
# Step 5: Next log_step should retrieve feedback
_, step_number3, feedback_list = process_log_step(
db=test_db,
user_id=str(test_user.id),
agent_instance_id=instance_id,
agent_type="Claude Code Test",
step_description="Implementing async pattern",
)
assert step_number3 == 3
assert len(feedback_list) == 1
assert feedback_list[0] == "Please use async/await pattern"
# Verify feedback was marked as retrieved
test_db.refresh(feedback)
assert feedback.retrieved_at is not None
# Step 6: End the session
ended_instance_id, final_status = process_end_session(
db=test_db, agent_instance_id=instance_id, user_id=str(test_user.id)
)
assert ended_instance_id == instance_id
assert final_status == "completed"
# Verify final state
test_db.refresh(instance)
assert instance.status == AgentStatus.COMPLETED
assert instance.ended_at is not None
# Verify questions were deactivated
test_db.refresh(question)
assert question.is_active is False
# Verify all steps were logged
steps = (
test_db.query(AgentStep)
.filter_by(agent_instance_id=instance_id)
.order_by(AgentStep.step_number)
.all()
)
assert len(steps) == 3
assert steps[0].description == "Starting integration test task"
assert steps[1].description == "Processing files"
assert steps[2].description == "Implementing async pattern"
@pytest.mark.integration
def test_multiple_feedback_handling(self, test_db, test_user, test_user_agent):
"""Test handling multiple feedback items."""
# Create instance
instance_id, _, _ = process_log_step(
db=test_db,
user_id=str(test_user.id),
agent_instance_id=None,
agent_type="Claude Code Test",
step_description="Starting task",
)
# Add multiple feedback items
feedback_items = []
for i in range(3):
feedback = AgentUserFeedback(
id=uuid4(),
agent_instance_id=instance_id,
created_by_user_id=test_user.id,
feedback_text=f"Feedback {i + 1}",
created_at=datetime.now(timezone.utc),
)
feedback_items.append(feedback)
test_db.add(feedback)
test_db.commit()
# Next log_step should retrieve all feedback
_, _, feedback_list = process_log_step(
db=test_db,
user_id=str(test_user.id),
agent_instance_id=instance_id,
agent_type="Claude Code Test",
step_description="Processing feedback",
)
assert len(feedback_list) == 3
assert set(feedback_list) == {"Feedback 1", "Feedback 2", "Feedback 3"}
# All feedback should be marked as retrieved
for feedback in feedback_items:
test_db.refresh(feedback)
assert feedback.retrieved_at is not None
# Subsequent log_step should not retrieve same feedback
_, _, feedback_list2 = process_log_step(
db=test_db,
user_id=str(test_user.id),
agent_instance_id=instance_id,
agent_type="Claude Code Test",
step_description="Continuing work",
)
assert len(feedback_list2) == 0
@pytest.mark.integration
def test_user_agent_creation_and_reuse(self, test_db, test_user):
"""Test that user agents are created and reused correctly."""
# First call should create a new user agent
instance1_id, _, _ = process_log_step(
db=test_db,
user_id=str(test_user.id),
agent_instance_id=None,
agent_type="New Agent Type",
step_description="First task",
)
# Check user agent was created (name is stored in lowercase)
user_agents = (
test_db.query(UserAgent)
.filter_by(user_id=test_user.id, name="new agent type")
.all()
)
assert len(user_agents) == 1
# Second call with same agent type should reuse the user agent
instance2_id, _, _ = process_log_step(
db=test_db,
user_id=str(test_user.id),
agent_instance_id=None,
agent_type="New Agent Type",
step_description="Second task",
)
# Should still only have one user agent (name is stored in lowercase)
user_agents = (
test_db.query(UserAgent)
.filter_by(user_id=test_user.id, name="new agent type")
.all()
)
assert len(user_agents) == 1
# But two different instances
assert instance1_id != instance2_id
# Both instances should reference the same user agent
instance1 = test_db.query(AgentInstance).filter_by(id=instance1_id).first()
instance2 = test_db.query(AgentInstance).filter_by(id=instance2_id).first()
assert instance1.user_agent_id == instance2.user_agent_id

View File

@@ -3,12 +3,8 @@
from datetime import datetime, timezone
from uuid import uuid4
from shared.database.models import (
AgentStep,
AgentQuestion,
AgentUserFeedback,
)
from shared.database.enums import AgentStatus
from shared.database.models import Message
from shared.database.enums import AgentStatus, SenderType
class TestDatabaseModels:
@@ -21,117 +17,125 @@ class TestDatabaseModels:
assert test_agent_instance.started_at is not None
assert test_agent_instance.ended_at is None
def test_create_agent_step(self, test_db, test_agent_instance):
"""Test creating agent steps."""
step1 = AgentStep(
def test_create_agent_messages(self, test_db, test_agent_instance):
"""Test creating agent messages."""
message1 = Message(
id=uuid4(),
agent_instance_id=test_agent_instance.id,
step_number=1,
description="First step",
sender_type=SenderType.AGENT,
content="First step",
created_at=datetime.now(timezone.utc),
requires_user_input=False,
)
step2 = AgentStep(
message2 = Message(
id=uuid4(),
agent_instance_id=test_agent_instance.id,
step_number=2,
description="Second step",
sender_type=SenderType.AGENT,
content="Second step",
created_at=datetime.now(timezone.utc),
requires_user_input=False,
)
test_db.add_all([step1, step2])
test_db.add_all([message1, message2])
test_db.commit()
# Query steps
steps = (
test_db.query(AgentStep)
# Query messages
messages = (
test_db.query(Message)
.filter_by(agent_instance_id=test_agent_instance.id)
.order_by(AgentStep.step_number)
.order_by(Message.created_at)
.all()
)
assert len(steps) == 2
assert steps[0].step_number == 1
assert steps[0].description == "First step"
assert steps[1].step_number == 2
assert steps[1].description == "Second step"
assert len(messages) == 2
assert messages[0].content == "First step"
assert messages[1].content == "Second step"
assert all(msg.sender_type == SenderType.AGENT for msg in messages)
def test_create_agent_question(self, test_db, test_agent_instance):
"""Test creating agent questions."""
question = AgentQuestion(
def test_create_agent_question_message(self, test_db, test_agent_instance):
"""Test creating agent question as a message."""
question = Message(
id=uuid4(),
agent_instance_id=test_agent_instance.id,
question_text="Should I continue?",
asked_at=datetime.now(timezone.utc),
is_active=True,
sender_type=SenderType.AGENT,
content="Should I continue?",
created_at=datetime.now(timezone.utc),
requires_user_input=True,
)
test_db.add(question)
test_db.commit()
# Query question
saved_question = test_db.query(AgentQuestion).filter_by(id=question.id).first()
saved_question = test_db.query(Message).filter_by(id=question.id).first()
assert saved_question is not None
assert saved_question.question_text == "Should I continue?"
assert saved_question.is_active is True
assert saved_question.answer_text is None
assert saved_question.answered_at is None
assert saved_question.content == "Should I continue?"
assert saved_question.requires_user_input is True
assert saved_question.sender_type == SenderType.AGENT
def test_create_user_feedback(self, test_db, test_agent_instance):
"""Test creating user feedback."""
feedback = AgentUserFeedback(
def test_create_user_feedback_message(self, test_db, test_agent_instance):
"""Test creating user feedback as a message."""
feedback = Message(
id=uuid4(),
agent_instance_id=test_agent_instance.id,
created_by_user_id=test_agent_instance.user_id,
feedback_text="Please use TypeScript",
sender_type=SenderType.USER,
content="Please use TypeScript",
created_at=datetime.now(timezone.utc),
requires_user_input=False,
message_metadata={
"source": "user_feedback",
"created_by_user_id": str(test_agent_instance.user_id),
},
)
test_db.add(feedback)
test_db.commit()
# Query feedback
saved_feedback = (
test_db.query(AgentUserFeedback).filter_by(id=feedback.id).first()
)
saved_feedback = test_db.query(Message).filter_by(id=feedback.id).first()
assert saved_feedback is not None
assert saved_feedback.feedback_text == "Please use TypeScript"
assert saved_feedback.retrieved_at is None
assert saved_feedback.content == "Please use TypeScript"
assert saved_feedback.sender_type == SenderType.USER
assert saved_feedback.message_metadata["created_by_user_id"] == str(
test_agent_instance.user_id
)
def test_agent_instance_relationships(self, test_db, test_agent_instance):
"""Test agent instance relationships."""
# Add a step
step = AgentStep(
def test_agent_instance_message_relationships(self, test_db, test_agent_instance):
"""Test agent instance message relationships."""
# Add an agent message
agent_msg = Message(
id=uuid4(),
agent_instance_id=test_agent_instance.id,
step_number=1,
description="Test step",
sender_type=SenderType.AGENT,
content="Test step",
created_at=datetime.now(timezone.utc),
requires_user_input=False,
)
# Add a question
question = AgentQuestion(
# Add a question message
question_msg = Message(
id=uuid4(),
agent_instance_id=test_agent_instance.id,
question_text="Test question?",
asked_at=datetime.now(timezone.utc),
is_active=True,
sender_type=SenderType.AGENT,
content="Test question?",
created_at=datetime.now(timezone.utc),
requires_user_input=True,
)
test_db.add_all([step, question])
test_db.add_all([agent_msg, question_msg])
test_db.commit()
# Refresh instance to load relationships
test_db.refresh(test_agent_instance)
# Test relationships
assert len(test_agent_instance.steps) == 1
assert test_agent_instance.steps[0].description == "Test step"
assert len(test_agent_instance.questions) == 1
assert test_agent_instance.questions[0].question_text == "Test question?"
assert len(test_agent_instance.messages) == 2
assert test_agent_instance.messages[0].content == "Test step"
assert test_agent_instance.messages[1].content == "Test question?"
assert test_agent_instance.messages[1].requires_user_input is True
class TestAgentStatusTransitions:

View File

@@ -0,0 +1,293 @@
"""Add messages table
Revision ID: 40d4252deb5b
Revises: dc285eabea90
Create Date: 2025-07-31 11:25:30.567076
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = "40d4252deb5b"
down_revision: Union[str, None] = "dc285eabea90"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"messages",
sa.Column("id", sa.UUID(), nullable=False),
sa.Column("agent_instance_id", sa.UUID(), nullable=False),
sa.Column(
"sender_type", sa.Enum("AGENT", "USER", name="sendertype"), nullable=False
),
sa.Column("content", sa.Text(), nullable=False),
sa.Column("created_at", sa.DateTime(), nullable=False),
sa.Column("requires_user_input", sa.Boolean(), nullable=False),
sa.Column(
"message_metadata", postgresql.JSONB(astext_type=sa.Text()), nullable=True
),
sa.ForeignKeyConstraint(
["agent_instance_id"],
["agent_instances.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"idx_messages_instance_created",
"messages",
["agent_instance_id", "created_at"],
unique=False,
)
# Migrate data from old tables to messages table
# This combines AgentStep, AgentQuestion, and AgentUserFeedback into Messages
# ordered by timestamp
op.execute("""
-- Insert agent steps as agent messages
INSERT INTO messages (id, agent_instance_id, sender_type, content, created_at, requires_user_input, message_metadata)
SELECT
id,
agent_instance_id,
'AGENT',
description,
created_at,
false,
jsonb_build_object('source', 'agent_step', 'step_number', step_number)
FROM agent_steps;
-- Insert agent questions as agent messages
INSERT INTO messages (id, agent_instance_id, sender_type, content, created_at, requires_user_input, message_metadata)
SELECT
id,
agent_instance_id,
'AGENT',
question_text,
asked_at,
CASE WHEN answer_text IS NULL THEN true ELSE false END,
jsonb_build_object('source', 'agent_question')
FROM agent_questions;
-- Insert question answers as user messages (only for answered questions)
INSERT INTO messages (id, agent_instance_id, sender_type, content, created_at, requires_user_input, message_metadata)
SELECT
gen_random_uuid(),
agent_instance_id,
'USER',
answer_text,
answered_at,
false,
jsonb_build_object('source', 'question_answer', 'question_id', id::text, 'answered_by_user_id', answered_by_user_id::text)
FROM agent_questions
WHERE answer_text IS NOT NULL AND answered_at IS NOT NULL;
-- Insert user feedback as user messages
INSERT INTO messages (id, agent_instance_id, sender_type, content, created_at, requires_user_input, message_metadata)
SELECT
id,
agent_instance_id,
'USER',
feedback_text,
created_at,
false,
jsonb_build_object('source', 'user_feedback', 'created_by_user_id', created_by_user_id::text)
FROM agent_user_feedback;
-- Mark all agent instances as completed
UPDATE agent_instances SET status = 'COMPLETED' WHERE status != 'COMPLETED';
""")
# Add last_read_message_id to agent_instances
op.add_column(
"agent_instances",
sa.Column("last_read_message_id", postgresql.UUID(as_uuid=True), nullable=True),
)
# Create the foreign key constraint after both tables exist to avoid circular dependency
op.create_foreign_key(
"fk_agent_instances_last_read_message",
"agent_instances",
"messages",
["last_read_message_id"],
["id"],
ondelete="SET NULL",
)
# Create combined function for message notifications
op.execute("""
CREATE OR REPLACE FUNCTION notify_message_change() RETURNS trigger AS $$
DECLARE
channel_name text;
payload text;
event_type text;
BEGIN
-- Create channel name based on instance ID
channel_name := 'message_channel_' || NEW.agent_instance_id::text;
-- Determine event type
IF TG_OP = 'INSERT' THEN
event_type := 'message_insert';
ELSIF TG_OP = 'UPDATE' THEN
event_type := 'message_update';
END IF;
-- Create JSON payload with message data
payload := json_build_object(
'event_type', event_type,
'id', NEW.id,
'agent_instance_id', NEW.agent_instance_id,
'sender_type', NEW.sender_type,
'content', NEW.content,
'created_at', to_char(NEW.created_at AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.MS"Z"'),
'requires_user_input', NEW.requires_user_input,
'message_metadata', NEW.message_metadata,
'old_requires_user_input', CASE
WHEN TG_OP = 'UPDATE' THEN OLD.requires_user_input
ELSE NULL
END
)::text;
-- Send notification (quote channel name for UUIDs with hyphens)
EXECUTE format('NOTIFY %I, %L', channel_name, payload);
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
""")
# Create single trigger for both INSERT and UPDATE on messages table
op.execute("""
CREATE TRIGGER message_change_notify
AFTER INSERT OR UPDATE ON messages
FOR EACH ROW
EXECUTE FUNCTION notify_message_change();
""")
# Create function for status change notifications
op.execute("""
CREATE OR REPLACE FUNCTION notify_status_change() RETURNS trigger AS $$
DECLARE
channel_name text;
payload text;
BEGIN
-- Only notify if status actually changed
IF OLD.status IS DISTINCT FROM NEW.status THEN
-- Create channel name based on instance ID
channel_name := 'message_channel_' || NEW.id::text;
-- Create JSON payload with status update data
payload := json_build_object(
'event_type', 'status_update',
'instance_id', NEW.id,
'status', NEW.status,
'timestamp', to_char(NOW() AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.MS"Z"')
)::text;
-- Send notification (quote channel name for UUIDs with hyphens)
EXECUTE format('NOTIFY %I, %L', channel_name, payload);
END IF;
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
""")
# Create trigger on agent_instances table for status updates
op.execute("""
CREATE TRIGGER agent_instance_status_notify
AFTER UPDATE OF status ON agent_instances
FOR EACH ROW
EXECUTE FUNCTION notify_status_change();
""")
# Drop the old tables after migration
op.drop_table("agent_user_feedback")
op.drop_table("agent_questions")
op.drop_table("agent_steps")
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
# Recreate the old tables first
op.create_table(
"agent_steps",
sa.Column("id", sa.UUID(), nullable=False),
sa.Column("agent_instance_id", sa.UUID(), nullable=False),
sa.Column("step_number", sa.Integer(), nullable=False),
sa.Column("description", sa.Text(), nullable=False),
sa.Column("created_at", sa.DateTime(), nullable=False),
sa.ForeignKeyConstraint(
["agent_instance_id"],
["agent_instances.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"agent_questions",
sa.Column("id", sa.UUID(), nullable=False),
sa.Column("agent_instance_id", sa.UUID(), nullable=False),
sa.Column("question_text", sa.Text(), nullable=False),
sa.Column("answer_text", sa.Text(), nullable=True),
sa.Column("answered_by_user_id", sa.UUID(), nullable=True),
sa.Column("asked_at", sa.DateTime(), nullable=False),
sa.Column("answered_at", sa.DateTime(), nullable=True),
sa.Column("is_active", sa.Boolean(), nullable=False),
sa.ForeignKeyConstraint(
["agent_instance_id"],
["agent_instances.id"],
),
sa.ForeignKeyConstraint(
["answered_by_user_id"],
["users.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"agent_user_feedback",
sa.Column("id", sa.UUID(), nullable=False),
sa.Column("agent_instance_id", sa.UUID(), nullable=False),
sa.Column("created_by_user_id", sa.UUID(), nullable=False),
sa.Column("feedback_text", sa.Text(), nullable=False),
sa.Column("created_at", sa.DateTime(), nullable=False),
sa.Column("retrieved_at", sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(
["agent_instance_id"],
["agent_instances.id"],
),
sa.ForeignKeyConstraint(
["created_by_user_id"],
["users.id"],
),
sa.PrimaryKeyConstraint("id"),
)
# Note: We cannot fully restore the original data as some information is lost
# (e.g., step_number ordering, which user created feedback vs answered questions)
# This is a best-effort restoration
# Drop status change trigger and function
op.execute(
"DROP TRIGGER IF EXISTS agent_instance_status_notify ON agent_instances;"
)
op.execute("DROP FUNCTION IF EXISTS notify_status_change();")
# Drop message trigger and function
op.execute("DROP TRIGGER IF EXISTS message_change_notify ON messages;")
op.execute("DROP FUNCTION IF EXISTS notify_message_change();")
op.drop_constraint(
"fk_agent_instances_last_read_message", "agent_instances", type_="foreignkey"
)
op.drop_column("agent_instances", "last_read_message_id")
op.drop_index("idx_messages_instance_created", table_name="messages")
op.drop_table("messages")
# ### end Alembic commands ###

View File

@@ -1,11 +1,9 @@
from .enums import AgentStatus
from .enums import AgentStatus, SenderType
from .models import (
AgentInstance,
AgentQuestion,
AgentStep,
AgentUserFeedback,
APIKey,
Base,
Message,
PushToken,
User,
UserAgent,
@@ -20,12 +18,11 @@ __all__ = [
"User",
"UserAgent",
"AgentInstance",
"AgentStep",
"AgentQuestion",
"AgentStatus",
"AgentUserFeedback",
"APIKey",
"Message",
"PushToken",
"SenderType",
"Subscription",
"BillingEvent",
]

View File

@@ -2,11 +2,16 @@ from enum import Enum
class AgentStatus(str, Enum):
ACTIVE = "active"
AWAITING_INPUT = "awaiting_input"
PAUSED = "paused"
STALE = "stale"
COMPLETED = "completed"
FAILED = "failed"
KILLED = "killed"
DISCONNECTED = "disconnected"
ACTIVE = "ACTIVE"
AWAITING_INPUT = "AWAITING_INPUT"
PAUSED = "PAUSED"
STALE = "STALE"
COMPLETED = "COMPLETED"
FAILED = "FAILED"
KILLED = "KILLED"
DISCONNECTED = "DISCONNECTED"
class SenderType(str, Enum):
AGENT = "AGENT"
USER = "USER"

View File

@@ -3,7 +3,7 @@ from uuid import UUID, uuid4
from typing import TYPE_CHECKING
from sqlalchemy import ForeignKey, Index, String, Text, UniqueConstraint
from sqlalchemy.dialects.postgresql import UUID as PostgresUUID
from sqlalchemy.dialects.postgresql import UUID as PostgresUUID, JSONB
from sqlalchemy.orm import (
DeclarativeBase, # type: ignore[attr-defined]
Mapped, # type: ignore[attr-defined]
@@ -12,7 +12,7 @@ from sqlalchemy.orm import (
validates,
)
from .enums import AgentStatus
from .enums import AgentStatus, SenderType
from .utils import is_valid_git_diff
if TYPE_CHECKING:
@@ -57,12 +57,6 @@ class User(Base):
agent_instances: Mapped[list["AgentInstance"]] = relationship(
"AgentInstance", back_populates="user"
)
answered_questions: Mapped[list["AgentQuestion"]] = relationship(
"AgentQuestion", back_populates="answered_by_user"
)
feedback: Mapped[list["AgentUserFeedback"]] = relationship(
"AgentUserFeedback", back_populates="created_by_user"
)
api_keys: Mapped[list["APIKey"]] = relationship("APIKey", back_populates="user")
user_agents: Mapped[list["UserAgent"]] = relationship(
"UserAgent", back_populates="user"
@@ -130,22 +124,33 @@ class AgentInstance(Base):
)
ended_at: Mapped[datetime | None] = mapped_column(default=None)
git_diff: Mapped[str | None] = mapped_column(Text, default=None)
last_read_message_id: Mapped[UUID | None] = mapped_column(
ForeignKey(
"messages.id",
use_alter=True,
name="fk_agent_instances_last_read_message",
ondelete="SET NULL",
),
type_=PostgresUUID(as_uuid=True),
default=None,
)
# Relationships
user_agent: Mapped["UserAgent"] = relationship(
"UserAgent", back_populates="instances"
)
user: Mapped["User"] = relationship("User", back_populates="agent_instances")
steps: Mapped[list["AgentStep"]] = relationship(
"AgentStep", back_populates="instance", order_by="AgentStep.created_at"
)
questions: Mapped[list["AgentQuestion"]] = relationship(
"AgentQuestion", back_populates="instance", order_by="AgentQuestion.asked_at"
)
user_feedback: Mapped[list["AgentUserFeedback"]] = relationship(
"AgentUserFeedback",
messages: Mapped[list["Message"]] = relationship(
"Message",
back_populates="instance",
order_by="AgentUserFeedback.created_at",
order_by="Message.created_at",
foreign_keys="Message.agent_instance_id",
)
last_read_message: Mapped["Message | None"] = relationship(
"Message",
foreign_keys=[last_read_message_id],
post_update=True,
passive_deletes=True,
)
@validates("git_diff")
@@ -154,7 +159,7 @@ class AgentInstance(Base):
Raises ValueError if the git diff is invalid.
"""
if value is None:
if value is None or value == "":
return value
if not is_valid_git_diff(value):
@@ -163,81 +168,6 @@ class AgentInstance(Base):
return value
class AgentStep(Base):
__tablename__ = "agent_steps"
id: Mapped[UUID] = mapped_column(
PostgresUUID(as_uuid=True), primary_key=True, default=uuid4
)
agent_instance_id: Mapped[UUID] = mapped_column(
ForeignKey("agent_instances.id"), type_=PostgresUUID(as_uuid=True)
)
step_number: Mapped[int] = mapped_column()
description: Mapped[str] = mapped_column(Text)
created_at: Mapped[datetime] = mapped_column(
default=lambda: datetime.now(timezone.utc)
)
# Relationships
instance: Mapped["AgentInstance"] = relationship(
"AgentInstance", back_populates="steps"
)
class AgentQuestion(Base):
__tablename__ = "agent_questions"
id: Mapped[UUID] = mapped_column(
PostgresUUID(as_uuid=True), primary_key=True, default=uuid4
)
agent_instance_id: Mapped[UUID] = mapped_column(
ForeignKey("agent_instances.id"), type_=PostgresUUID(as_uuid=True)
)
question_text: Mapped[str] = mapped_column(Text)
answer_text: Mapped[str | None] = mapped_column(Text, default=None)
answered_by_user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("users.id"), type_=PostgresUUID(as_uuid=True), default=None
)
asked_at: Mapped[datetime] = mapped_column(
default=lambda: datetime.now(timezone.utc)
)
answered_at: Mapped[datetime | None] = mapped_column(default=None)
is_active: Mapped[bool] = mapped_column(default=True)
# Relationships
instance: Mapped["AgentInstance"] = relationship(
"AgentInstance", back_populates="questions"
)
answered_by_user: Mapped["User | None"] = relationship(
"User", back_populates="answered_questions"
)
class AgentUserFeedback(Base):
__tablename__ = "agent_user_feedback"
id: Mapped[UUID] = mapped_column(
PostgresUUID(as_uuid=True), primary_key=True, default=uuid4
)
agent_instance_id: Mapped[UUID] = mapped_column(
ForeignKey("agent_instances.id"), type_=PostgresUUID(as_uuid=True)
)
created_by_user_id: Mapped[UUID] = mapped_column(
ForeignKey("users.id"), type_=PostgresUUID(as_uuid=True)
)
feedback_text: Mapped[str] = mapped_column(Text)
created_at: Mapped[datetime] = mapped_column(
default=lambda: datetime.now(timezone.utc)
)
retrieved_at: Mapped[datetime | None] = mapped_column(default=None)
# Relationships
instance: Mapped["AgentInstance"] = relationship(
"AgentInstance", back_populates="user_feedback"
)
created_by_user: Mapped["User"] = relationship("User", back_populates="feedback")
class APIKey(Base):
__tablename__ = "api_keys"
@@ -286,3 +216,31 @@ class PushToken(Base):
# Relationships
user: Mapped["User"] = relationship("User", back_populates="push_tokens")
class Message(Base):
__tablename__ = "messages"
__table_args__ = (
Index("idx_messages_instance_created", "agent_instance_id", "created_at"),
)
id: Mapped[UUID] = mapped_column(
PostgresUUID(as_uuid=True), primary_key=True, default=uuid4
)
agent_instance_id: Mapped[UUID] = mapped_column(
ForeignKey("agent_instances.id"), type_=PostgresUUID(as_uuid=True)
)
sender_type: Mapped[SenderType] = mapped_column()
content: Mapped[str] = mapped_column(Text)
created_at: Mapped[datetime] = mapped_column(
default=lambda: datetime.now(timezone.utc)
)
requires_user_input: Mapped[bool] = mapped_column(default=False)
message_metadata: Mapped[dict | None] = mapped_column(JSONB, default=None)
# Relationships
instance: Mapped["AgentInstance"] = relationship(
"AgentInstance",
back_populates="messages",
foreign_keys=[agent_instance_id],
)

View File

@@ -339,6 +339,7 @@ class WebhookRequest(BaseModel):
prompt: str
name: str | None = None # Branch name
worktree_name: str | None = None
agent_type: str | None = None # Agent type name
@field_validator("agent_instance_id")
def validate_instance_id(cls, v):
@@ -589,6 +590,7 @@ async def start_claude(
prompt = webhook_data.prompt
worktree_name = webhook_data.worktree_name
branch_name = webhook_data.name
agent_type = webhook_data.agent_type
print("\n[INFO] Received webhook request:")
print(f" - Instance ID: {agent_instance_id}")
@@ -806,21 +808,26 @@ async def start_claude(
mcp_config = {
"mcpServers": {
"omnara": {
"command": "pipx",
"command": "omnara",
"args": [
"run",
"--no-cache",
"omnara",
"--api-key",
omnara_api_key,
"--claude-code-permission-tool",
"--git-diff",
"--agent-instance-id",
agent_instance_id,
"--base-url",
"http://localhost:8080",
],
}
}
}
# Add environment variable for agent type if provided
if agent_type:
mcp_config["mcpServers"]["omnara"]["env"] = {
"OMNARA_CLIENT_TYPE": agent_type
}
mcp_config_str = json.dumps(mcp_config)
# Build claude command with MCP config as string

1247
webhooks/claude_wrapper_v3.py Executable file

File diff suppressed because it is too large Load Diff