mirror of
https://github.com/omnara-ai/omnara.git
synced 2025-08-12 20:39:09 +03:00
Refactor (#45)
* progress * refactor * backend refactor * test * minor stdio changes * streeeeeaming * stream statuses * change required user input * something * progress * new beginnings * progress * edge case * uuhh the claude wrapper def worked before this commit * progress * ai slop that works * tests * tests * plan mode handling * progress * clean up * bump * migrations * readmes * no text truncation * consistent poll interval * fix time --------- Co-authored-by: Kartik Sarangmath <kartiksarangmath@Kartiks-MacBook-Air.local>
This commit is contained in:
10
CLAUDE.md
10
CLAUDE.md
@@ -40,7 +40,8 @@ omnara/
|
||||
- **PostgreSQL** with **SQLAlchemy 2.0+**
|
||||
- **Alembic** for migrations - ALWAYS create migrations for schema changes
|
||||
- Multi-tenant design - all data is scoped by user_id
|
||||
- Key tables: users, agent_types, agent_instances, agent_steps, agent_questions, agent_user_feedback, api_keys
|
||||
- Key tables: users, user_agents, agent_instances, messages, api_keys
|
||||
- **Unified messaging system**: All agent interactions (steps, questions, feedback) are now stored in the `messages` table with `sender_type` and `requires_user_input` fields
|
||||
|
||||
### Server Architecture
|
||||
- **Unified server** (`servers/app.py`) supports both MCP and REST
|
||||
@@ -107,6 +108,13 @@ make test-integration # Integration tests (needs Docker)
|
||||
|
||||
## Common Tasks
|
||||
|
||||
### Working with Messages
|
||||
The unified messaging system uses a single `messages` table:
|
||||
- **Agent messages**: Set `sender_type=AGENT`, use `requires_user_input=True` for questions
|
||||
- **User messages**: Set `sender_type=USER` for feedback/responses
|
||||
- **Reading messages**: Use `last_read_message_id` to track reading progress
|
||||
- **Queued messages**: Agent receives unread user messages when sending new messages
|
||||
|
||||
### Adding a New API Endpoint
|
||||
1. Add route in `backend/api/` or `servers/fastapi_server/routers.py`
|
||||
2. Create Pydantic models for request/response in `models.py`
|
||||
|
||||
18
README.md
18
README.md
@@ -186,25 +186,27 @@ client = OmnaraClient(api_key="your-api-key")
|
||||
instance_id = str(uuid.uuid4())
|
||||
|
||||
# Log progress and check for user feedback
|
||||
response = client.log_step(
|
||||
response = client.send_message(
|
||||
agent_type="claude-code",
|
||||
step_description="Analyzing codebase structure",
|
||||
agent_instance_id=instance_id
|
||||
content="Analyzing codebase structure",
|
||||
agent_instance_id=instance_id,
|
||||
requires_user_input=False
|
||||
)
|
||||
|
||||
# Ask for user input when needed
|
||||
answer = client.ask_question(
|
||||
question="Should I refactor this legacy module?",
|
||||
agent_instance_id=instance_id
|
||||
answer = client.send_message(
|
||||
content="Should I refactor this legacy module?",
|
||||
agent_instance_id=instance_id,
|
||||
requires_user_input=True
|
||||
)
|
||||
```
|
||||
|
||||
### Method 3: REST API
|
||||
```bash
|
||||
curl -X POST https://api.omnara.ai/api/v1/steps \
|
||||
curl -X POST https://api.omnara.ai/api/v1/messages/agent \
|
||||
-H "Authorization: Bearer YOUR_API_KEY" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"step_description": "Starting deployment process"}'
|
||||
-d '{"content": "Starting deployment process", "agent_type": "claude-code", "requires_user_input": false}'
|
||||
```
|
||||
|
||||
## 🤝 Contributing
|
||||
|
||||
@@ -27,7 +27,7 @@ The backend provides a REST API for accessing and managing agent-related data. I
|
||||
## Key Features
|
||||
|
||||
- **Agent Monitoring** - View agent types, instances, and execution history
|
||||
- **User Interactions** - Handle questions from agents and user feedback
|
||||
- **Unified Messaging** - All agent interactions (steps, questions, feedback) through a single messaging system
|
||||
- **Multi-tenancy** - User-scoped data isolation and access control
|
||||
- **Authentication** - Support for both web dashboard users and programmatic agent access
|
||||
- **User Agent Management** - Custom agent configurations and webhook integrations
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
from uuid import UUID
|
||||
import asyncio
|
||||
import json
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from shared.database.models import User
|
||||
from shared.database.session import get_db
|
||||
from shared.database.enums import AgentStatus
|
||||
from sqlalchemy.orm import Session
|
||||
import asyncpg
|
||||
|
||||
from ..auth.dependencies import get_current_user
|
||||
from ..db import (
|
||||
@@ -14,14 +19,14 @@ from ..db import (
|
||||
get_all_agent_instances,
|
||||
get_all_agent_types_with_instances,
|
||||
mark_instance_completed,
|
||||
submit_user_feedback,
|
||||
submit_user_message,
|
||||
)
|
||||
from ..models import (
|
||||
AgentInstanceDetail,
|
||||
AgentInstanceResponse,
|
||||
AgentTypeOverview,
|
||||
UserFeedbackRequest,
|
||||
UserFeedbackResponse,
|
||||
UserMessageRequest,
|
||||
UserMessageResponse,
|
||||
)
|
||||
|
||||
router = APIRouter(tags=["agents"])
|
||||
@@ -87,21 +92,129 @@ async def get_instance_detail(
|
||||
|
||||
|
||||
@router.post(
|
||||
"/agent-instances/{instance_id}/feedback", response_model=UserFeedbackResponse
|
||||
"/agent-instances/{instance_id}/messages", response_model=UserMessageResponse
|
||||
)
|
||||
async def add_user_feedback(
|
||||
async def create_user_message(
|
||||
instance_id: UUID,
|
||||
request: UserFeedbackRequest,
|
||||
request: UserMessageRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Submit user feedback for an agent instance for the current user"""
|
||||
result = submit_user_feedback(db, instance_id, request.feedback, current_user.id)
|
||||
"""Send a message to an agent instance (answers questions or provides feedback)"""
|
||||
result = submit_user_message(db, instance_id, request.content, current_user.id)
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail="Agent instance not found")
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/agent-instances/{instance_id}/messages/stream")
|
||||
async def stream_messages(
|
||||
request: Request,
|
||||
instance_id: UUID,
|
||||
token: str | None = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Stream new messages for an agent instance using Server-Sent Events"""
|
||||
# Handle SSE authentication - token comes from query param
|
||||
if not token:
|
||||
raise HTTPException(status_code=401, detail="Token required for SSE")
|
||||
|
||||
try:
|
||||
# Verify token and get user
|
||||
from ..auth.supabase_client import get_supabase_anon_client
|
||||
|
||||
supabase = get_supabase_anon_client()
|
||||
user_response = supabase.auth.get_user(token)
|
||||
|
||||
if not user_response or not user_response.user:
|
||||
raise HTTPException(status_code=401, detail="Invalid token")
|
||||
|
||||
user_id = UUID(user_response.user.id)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=401, detail=f"Authentication failed: {str(e)}")
|
||||
|
||||
# Verify the user has access to this instance
|
||||
instance = get_agent_instance_detail(db, instance_id, user_id)
|
||||
if not instance:
|
||||
raise HTTPException(status_code=404, detail="Agent instance not found")
|
||||
|
||||
async def message_generator() -> AsyncGenerator[str, None]:
|
||||
# Import settings here to avoid circular imports
|
||||
from shared.config.settings import settings
|
||||
|
||||
# Create connection to PostgreSQL for LISTEN/NOTIFY
|
||||
conn = await asyncpg.connect(settings.database_url)
|
||||
try:
|
||||
# Listen to the channel for this instance
|
||||
channel_name = f"message_channel_{instance_id}"
|
||||
|
||||
# Execute LISTEN command (quote channel name for UUIDs with hyphens)
|
||||
await conn.execute(f'LISTEN "{channel_name}"')
|
||||
|
||||
# Create a queue to receive notifications
|
||||
notification_queue = asyncio.Queue()
|
||||
|
||||
# Define callback to put notifications in queue
|
||||
def notification_callback(connection, pid, channel, payload):
|
||||
asyncio.create_task(notification_queue.put(payload))
|
||||
|
||||
# Add listener with callback
|
||||
await conn.add_listener(channel_name, notification_callback)
|
||||
|
||||
# Send initial connection event
|
||||
yield f"event: connected\ndata: {json.dumps({'instance_id': str(instance_id)})}\n\n"
|
||||
|
||||
while True:
|
||||
# Check if client disconnected
|
||||
if await request.is_disconnected():
|
||||
break
|
||||
|
||||
try:
|
||||
# Wait for notification with timeout for heartbeat
|
||||
payload = await asyncio.wait_for(
|
||||
notification_queue.get(), timeout=30.0
|
||||
)
|
||||
|
||||
# Parse the JSON payload
|
||||
data = json.loads(payload)
|
||||
|
||||
# Check event type and send appropriate SSE event
|
||||
event_type = data.get("event_type")
|
||||
if event_type == "status_update":
|
||||
# Send status_update event
|
||||
yield f"event: status_update\ndata: {json.dumps(data)}\n\n"
|
||||
elif event_type == "message_update":
|
||||
# Send message_update event for frontend to handle
|
||||
yield f"event: message_update\ndata: {json.dumps(data)}\n\n"
|
||||
else:
|
||||
# Regular message event (either message_insert or legacy without event_type)
|
||||
yield f"event: message\ndata: {json.dumps(data)}\n\n"
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# Send heartbeat to keep connection alive
|
||||
yield f"event: heartbeat\ndata: {json.dumps({'timestamp': asyncio.get_event_loop().time()})}\n\n"
|
||||
|
||||
except Exception as e:
|
||||
# Send error event
|
||||
yield f"event: error\ndata: {json.dumps({'error': str(e)})}\n\n"
|
||||
break
|
||||
|
||||
finally:
|
||||
# Clean up listener and connection
|
||||
await conn.remove_listener(channel_name, notification_callback)
|
||||
await conn.close()
|
||||
|
||||
return StreamingResponse(
|
||||
message_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no", # Disable Nginx buffering
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.put(
|
||||
"/agent-instances/{instance_id}/status",
|
||||
response_model=AgentInstanceResponse,
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from shared.database.models import User
|
||||
from shared.database.session import get_db
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ..auth.dependencies import get_current_user
|
||||
from ..db import submit_answer
|
||||
from ..models import AnswerRequest
|
||||
|
||||
router = APIRouter(prefix="/questions", tags=["questions"])
|
||||
|
||||
|
||||
@router.post("/{question_id}/answer")
|
||||
async def answer_question(
|
||||
question_id: UUID,
|
||||
request: AnswerRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Submit an answer to a pending question for the current user"""
|
||||
result = submit_answer(db, question_id, request.answer, current_user.id)
|
||||
if not result:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Question not found or already answered"
|
||||
)
|
||||
return {"success": True, "message": "Answer submitted successfully"}
|
||||
@@ -5,8 +5,7 @@ from .queries import (
|
||||
get_all_agent_types_with_instances,
|
||||
get_agent_summary,
|
||||
mark_instance_completed,
|
||||
submit_answer,
|
||||
submit_user_feedback,
|
||||
submit_user_message,
|
||||
)
|
||||
from .user_agent_queries import (
|
||||
create_user_agent,
|
||||
@@ -24,8 +23,7 @@ __all__ = [
|
||||
"get_agent_summary",
|
||||
"get_agent_instance_detail",
|
||||
"mark_instance_completed",
|
||||
"submit_answer",
|
||||
"submit_user_feedback",
|
||||
"submit_user_message",
|
||||
"create_user_agent",
|
||||
"get_user_agents",
|
||||
"update_user_agent",
|
||||
|
||||
@@ -5,12 +5,11 @@ from uuid import UUID
|
||||
from shared.config import settings
|
||||
from shared.database import (
|
||||
AgentInstance,
|
||||
AgentQuestion,
|
||||
AgentStatus,
|
||||
AgentStep,
|
||||
AgentUserFeedback,
|
||||
APIKey,
|
||||
Message,
|
||||
PushToken,
|
||||
SenderType,
|
||||
User,
|
||||
UserAgent,
|
||||
)
|
||||
@@ -19,52 +18,48 @@ from shared.database.subscription_models import BillingEvent, Subscription
|
||||
from sqlalchemy import desc, func
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
# Import Pydantic models for type-safe returns
|
||||
from backend.models import (
|
||||
AgentInstanceResponse,
|
||||
AgentInstanceDetail,
|
||||
MessageResponse,
|
||||
UserMessageResponse,
|
||||
AgentTypeOverview,
|
||||
)
|
||||
|
||||
def _format_instance(instance: AgentInstance) -> dict:
|
||||
|
||||
def _format_instance(instance: AgentInstance) -> AgentInstanceResponse:
|
||||
"""Helper function to format an agent instance consistently"""
|
||||
# Get latest step
|
||||
latest_step = None
|
||||
if instance.steps:
|
||||
latest_step = max(instance.steps, key=lambda s: s.created_at).description
|
||||
# Get all messages for this instance
|
||||
messages = instance.messages if hasattr(instance, "messages") else []
|
||||
|
||||
# Get step count
|
||||
step_count = len(instance.steps) if instance.steps else 0
|
||||
# Get latest message and its timestamp
|
||||
latest_message = None
|
||||
latest_message_at = None
|
||||
if messages:
|
||||
last_msg = max(messages, key=lambda m: m.created_at)
|
||||
latest_message = last_msg.content
|
||||
latest_message_at = last_msg.created_at
|
||||
|
||||
# Check for pending questions
|
||||
pending_questions = [q for q in instance.questions if q.is_active]
|
||||
pending_questions_count = len(pending_questions)
|
||||
has_pending = pending_questions_count > 0
|
||||
pending_age = None
|
||||
if has_pending:
|
||||
oldest_pending = min(pending_questions, key=lambda q: q.asked_at)
|
||||
# All database times are stored as UTC but may be naive
|
||||
now_utc = datetime.now(timezone.utc)
|
||||
asked_at = oldest_pending.asked_at
|
||||
if asked_at.tzinfo is None:
|
||||
asked_at = asked_at.replace(tzinfo=timezone.utc)
|
||||
pending_age = int((now_utc - asked_at).total_seconds())
|
||||
# Get total message count (chat length)
|
||||
chat_length = len(messages)
|
||||
|
||||
return {
|
||||
"id": str(instance.id),
|
||||
"agent_type_id": str(instance.user_agent_id) if instance.user_agent_id else "",
|
||||
"agent_type_name": instance.user_agent.name
|
||||
if instance.user_agent
|
||||
else "Unknown",
|
||||
"status": instance.status,
|
||||
"started_at": instance.started_at,
|
||||
"ended_at": instance.ended_at,
|
||||
"latest_step": latest_step,
|
||||
"has_pending_question": has_pending,
|
||||
"pending_question_age": pending_age,
|
||||
"pending_questions_count": pending_questions_count,
|
||||
"step_count": step_count,
|
||||
"last_signal_at": instance.steps[-1].created_at
|
||||
if instance.steps
|
||||
else instance.started_at,
|
||||
}
|
||||
return AgentInstanceResponse(
|
||||
id=str(instance.id),
|
||||
agent_type_id=str(instance.user_agent_id) if instance.user_agent_id else "",
|
||||
agent_type_name=instance.user_agent.name if instance.user_agent else "Unknown",
|
||||
status=instance.status,
|
||||
started_at=instance.started_at,
|
||||
ended_at=instance.ended_at,
|
||||
latest_message=latest_message,
|
||||
latest_message_at=latest_message_at,
|
||||
chat_length=chat_length,
|
||||
)
|
||||
|
||||
|
||||
def get_all_agent_types_with_instances(db: Session, user_id: UUID) -> list[dict]:
|
||||
def get_all_agent_types_with_instances(
|
||||
db: Session, user_id: UUID
|
||||
) -> list[AgentTypeOverview]:
|
||||
"""Get all user agents with their instances for a specific user"""
|
||||
|
||||
# Get all user agents for this user
|
||||
@@ -79,25 +74,29 @@ def get_all_agent_types_with_instances(db: Session, user_id: UUID) -> list[dict]
|
||||
AgentInstance.user_agent_id == user_agent.id,
|
||||
)
|
||||
.options(
|
||||
joinedload(AgentInstance.steps),
|
||||
joinedload(AgentInstance.questions),
|
||||
joinedload(AgentInstance.messages),
|
||||
joinedload(AgentInstance.user_agent),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Sort instances: pending questions first, then by most recent activity
|
||||
# Sort instances: AWAITING_INPUT instances first, then by most recent activity
|
||||
def sort_key(instance):
|
||||
pending_questions = [q for q in instance.questions if q.is_active]
|
||||
if pending_questions:
|
||||
oldest_question = min(pending_questions, key=lambda q: q.asked_at)
|
||||
return (0, oldest_question.asked_at)
|
||||
messages = instance.messages if hasattr(instance, "messages") else []
|
||||
|
||||
# If instance is awaiting input, prioritize it
|
||||
if instance.status == AgentStatus.AWAITING_INPUT:
|
||||
# Sort by when the question was asked (last message time)
|
||||
if messages:
|
||||
last_msg_time = max(messages, key=lambda m: m.created_at).created_at
|
||||
return (0, last_msg_time)
|
||||
else:
|
||||
return (0, instance.started_at)
|
||||
|
||||
# Otherwise sort by last activity
|
||||
last_activity = instance.started_at
|
||||
if instance.steps:
|
||||
last_activity = max(
|
||||
instance.steps, key=lambda s: s.created_at
|
||||
).created_at
|
||||
if messages:
|
||||
last_activity = max(messages, key=lambda m: m.created_at).created_at
|
||||
return (1, -last_activity.timestamp())
|
||||
|
||||
sorted_instances = sorted(instances, key=sort_key)
|
||||
@@ -108,16 +107,16 @@ def get_all_agent_types_with_instances(db: Session, user_id: UUID) -> list[dict]
|
||||
]
|
||||
|
||||
result.append(
|
||||
{
|
||||
"id": str(user_agent.id),
|
||||
"name": user_agent.name,
|
||||
"created_at": user_agent.created_at,
|
||||
"recent_instances": formatted_instances,
|
||||
"total_instances": len(instances),
|
||||
"active_instances": sum(
|
||||
AgentTypeOverview(
|
||||
id=str(user_agent.id),
|
||||
name=user_agent.name,
|
||||
created_at=user_agent.created_at,
|
||||
recent_instances=formatted_instances,
|
||||
total_instances=len(instances),
|
||||
active_instances=sum(
|
||||
1 for i in instances if i.status == AgentStatus.ACTIVE
|
||||
),
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
return result
|
||||
@@ -125,15 +124,14 @@ def get_all_agent_types_with_instances(db: Session, user_id: UUID) -> list[dict]
|
||||
|
||||
def get_all_agent_instances(
|
||||
db: Session, user_id: UUID, limit: int | None = None
|
||||
) -> list[dict]:
|
||||
) -> list[AgentInstanceResponse]:
|
||||
"""Get all agent instances for a specific user, sorted by most recent activity"""
|
||||
|
||||
query = (
|
||||
db.query(AgentInstance)
|
||||
.filter(AgentInstance.user_id == user_id)
|
||||
.options(
|
||||
joinedload(AgentInstance.steps),
|
||||
joinedload(AgentInstance.questions),
|
||||
joinedload(AgentInstance.messages),
|
||||
joinedload(AgentInstance.user_agent),
|
||||
)
|
||||
.order_by(desc(AgentInstance.started_at))
|
||||
@@ -216,7 +214,7 @@ def get_agent_summary(db: Session, user_id: UUID) -> dict:
|
||||
|
||||
def get_agent_type_instances(
|
||||
db: Session, agent_type_id: UUID, user_id: UUID
|
||||
) -> list[dict] | None:
|
||||
) -> list[AgentInstanceResponse] | None:
|
||||
"""Get all instances for a specific user agent"""
|
||||
|
||||
user_agent = (
|
||||
@@ -233,8 +231,7 @@ def get_agent_type_instances(
|
||||
AgentInstance.user_agent_id == agent_type_id,
|
||||
)
|
||||
.options(
|
||||
joinedload(AgentInstance.steps),
|
||||
joinedload(AgentInstance.questions),
|
||||
joinedload(AgentInstance.messages),
|
||||
joinedload(AgentInstance.user_agent),
|
||||
)
|
||||
.order_by(desc(AgentInstance.started_at))
|
||||
@@ -247,7 +244,7 @@ def get_agent_type_instances(
|
||||
|
||||
def get_agent_instance_detail(
|
||||
db: Session, instance_id: UUID, user_id: UUID
|
||||
) -> dict | None:
|
||||
) -> AgentInstanceDetail | None:
|
||||
"""Get detailed information about a specific agent instance for a specific user"""
|
||||
|
||||
instance = (
|
||||
@@ -255,9 +252,7 @@ def get_agent_instance_detail(
|
||||
.filter(AgentInstance.id == instance_id, AgentInstance.user_id == user_id)
|
||||
.options(
|
||||
joinedload(AgentInstance.user_agent),
|
||||
joinedload(AgentInstance.steps),
|
||||
joinedload(AgentInstance.questions),
|
||||
joinedload(AgentInstance.user_feedback),
|
||||
joinedload(AgentInstance.messages),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
@@ -265,122 +260,45 @@ def get_agent_instance_detail(
|
||||
if not instance:
|
||||
return None
|
||||
|
||||
# Sort steps by step number
|
||||
sorted_steps = sorted(instance.steps, key=lambda s: s.step_number)
|
||||
|
||||
# Sort questions by asked_at
|
||||
sorted_questions = sorted(instance.questions, key=lambda q: q.asked_at)
|
||||
|
||||
# Sort user feedback by created_at
|
||||
sorted_feedback = sorted(instance.user_feedback, key=lambda f: f.created_at)
|
||||
|
||||
return {
|
||||
"id": str(instance.id),
|
||||
"agent_type_id": str(instance.user_agent_id) if instance.user_agent_id else "",
|
||||
"agent_type": {
|
||||
"id": str(instance.user_agent.id) if instance.user_agent else "",
|
||||
"name": instance.user_agent.name if instance.user_agent else "Unknown",
|
||||
"created_at": instance.user_agent.created_at
|
||||
if instance.user_agent
|
||||
else datetime.now(timezone.utc),
|
||||
"recent_instances": [],
|
||||
"total_instances": 0,
|
||||
"active_instances": 0,
|
||||
},
|
||||
"status": instance.status,
|
||||
"started_at": instance.started_at,
|
||||
"ended_at": instance.ended_at,
|
||||
"git_diff": instance.git_diff,
|
||||
"steps": [
|
||||
{
|
||||
"id": str(step.id),
|
||||
"step_number": step.step_number,
|
||||
"description": step.description,
|
||||
"created_at": step.created_at,
|
||||
}
|
||||
for step in sorted_steps
|
||||
],
|
||||
"questions": [
|
||||
{
|
||||
"id": str(question.id),
|
||||
"question_text": question.question_text,
|
||||
"answer_text": question.answer_text,
|
||||
"asked_at": question.asked_at,
|
||||
"answered_at": question.answered_at,
|
||||
"is_active": question.is_active,
|
||||
}
|
||||
for question in sorted_questions
|
||||
],
|
||||
"user_feedback": [
|
||||
{
|
||||
"id": str(feedback.id),
|
||||
"feedback_text": feedback.feedback_text,
|
||||
"created_at": feedback.created_at,
|
||||
"retrieved_at": feedback.retrieved_at,
|
||||
}
|
||||
for feedback in sorted_feedback
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def submit_answer(
|
||||
db: Session, question_id: UUID, answer: str, user_id: UUID
|
||||
) -> dict | None:
|
||||
"""Submit an answer to a question for a specific user"""
|
||||
|
||||
question = (
|
||||
db.query(AgentQuestion)
|
||||
.filter(AgentQuestion.id == question_id, AgentQuestion.is_active)
|
||||
.join(AgentInstance)
|
||||
.filter(AgentInstance.user_id == user_id)
|
||||
.first()
|
||||
# Get all messages and sort by created_at
|
||||
messages = (
|
||||
sorted(instance.messages, key=lambda m: m.created_at)
|
||||
if hasattr(instance, "messages")
|
||||
else []
|
||||
)
|
||||
|
||||
if not question:
|
||||
return None
|
||||
|
||||
question.answer_text = answer
|
||||
question.answered_at = datetime.now(timezone.utc)
|
||||
question.is_active = False
|
||||
question.answered_by_user_id = user_id
|
||||
|
||||
# Update agent instance status back to ACTIVE if it was AWAITING_INPUT
|
||||
instance = (
|
||||
db.query(AgentInstance)
|
||||
.filter(AgentInstance.id == question.agent_instance_id)
|
||||
.first()
|
||||
)
|
||||
if instance and instance.status == AgentStatus.AWAITING_INPUT:
|
||||
# Check if there are other active questions for this instance
|
||||
other_active_questions = (
|
||||
db.query(AgentQuestion)
|
||||
.filter(
|
||||
AgentQuestion.agent_instance_id == instance.id,
|
||||
AgentQuestion.id != question_id,
|
||||
AgentQuestion.is_active,
|
||||
# Format messages for chat display
|
||||
formatted_messages = []
|
||||
for msg in messages:
|
||||
formatted_messages.append(
|
||||
MessageResponse(
|
||||
id=str(msg.id),
|
||||
content=msg.content,
|
||||
sender_type=msg.sender_type.value, # "agent" or "user"
|
||||
created_at=msg.created_at,
|
||||
requires_user_input=msg.requires_user_input,
|
||||
)
|
||||
.count()
|
||||
)
|
||||
# Only change status back to ACTIVE if no other questions are pending
|
||||
if other_active_questions == 0:
|
||||
instance.status = AgentStatus.ACTIVE
|
||||
|
||||
db.commit()
|
||||
|
||||
return {
|
||||
"id": str(question.id),
|
||||
"question_text": question.question_text,
|
||||
"answer_text": question.answer_text,
|
||||
"asked_at": question.asked_at,
|
||||
"answered_at": question.answered_at,
|
||||
"is_active": question.is_active,
|
||||
}
|
||||
return AgentInstanceDetail(
|
||||
id=str(instance.id),
|
||||
agent_type_id=str(instance.user_agent_id) if instance.user_agent_id else "",
|
||||
agent_type_name=instance.user_agent.name if instance.user_agent else "Unknown",
|
||||
status=instance.status,
|
||||
started_at=instance.started_at,
|
||||
ended_at=instance.ended_at,
|
||||
git_diff=instance.git_diff,
|
||||
messages=formatted_messages,
|
||||
last_read_message_id=str(instance.last_read_message_id)
|
||||
if instance.last_read_message_id
|
||||
else None,
|
||||
)
|
||||
|
||||
|
||||
def submit_user_feedback(
|
||||
db: Session, instance_id: UUID, feedback_text: str, user_id: UUID
|
||||
) -> dict | None:
|
||||
"""Submit user feedback for an agent instance for a specific user"""
|
||||
def submit_user_message(
|
||||
db: Session, instance_id: UUID, content: str, user_id: UUID
|
||||
) -> UserMessageResponse | None:
|
||||
"""Submit a user message to an agent instance"""
|
||||
|
||||
# Check if instance exists and belongs to user
|
||||
instance = (
|
||||
@@ -391,28 +309,33 @@ def submit_user_feedback(
|
||||
if not instance:
|
||||
return None
|
||||
|
||||
# Create new feedback
|
||||
feedback = AgentUserFeedback(
|
||||
# Create new user message
|
||||
user_message = Message(
|
||||
agent_instance_id=instance_id,
|
||||
feedback_text=feedback_text,
|
||||
created_by_user_id=user_id,
|
||||
sender_type=SenderType.USER,
|
||||
content=content,
|
||||
requires_user_input=False,
|
||||
)
|
||||
db.add(user_message)
|
||||
|
||||
if instance.status != AgentStatus.COMPLETED:
|
||||
instance.status = AgentStatus.ACTIVE
|
||||
|
||||
db.add(feedback)
|
||||
db.commit()
|
||||
db.refresh(feedback)
|
||||
db.refresh(user_message)
|
||||
|
||||
return {
|
||||
"id": str(feedback.id),
|
||||
"feedback_text": feedback.feedback_text,
|
||||
"created_at": feedback.created_at,
|
||||
"retrieved_at": feedback.retrieved_at,
|
||||
}
|
||||
return UserMessageResponse(
|
||||
id=str(user_message.id),
|
||||
content=user_message.content,
|
||||
sender_type=user_message.sender_type.value,
|
||||
created_at=user_message.created_at,
|
||||
requires_user_input=user_message.requires_user_input,
|
||||
)
|
||||
|
||||
|
||||
def mark_instance_completed(
|
||||
db: Session, instance_id: UUID, user_id: UUID
|
||||
) -> dict | None:
|
||||
) -> AgentInstanceResponse | None:
|
||||
"""Mark an agent instance as completed for a specific user"""
|
||||
|
||||
# Check if instance exists and belongs to user
|
||||
@@ -428,10 +351,7 @@ def mark_instance_completed(
|
||||
instance.status = AgentStatus.COMPLETED
|
||||
instance.ended_at = datetime.now(timezone.utc)
|
||||
|
||||
# Deactivate any pending questions
|
||||
db.query(AgentQuestion).filter(
|
||||
AgentQuestion.agent_instance_id == instance_id, AgentQuestion.is_active
|
||||
).update({"is_active": False})
|
||||
# No need to deactivate questions - they're handled by checking for user responses
|
||||
|
||||
db.commit()
|
||||
|
||||
@@ -441,8 +361,7 @@ def mark_instance_completed(
|
||||
.filter(AgentInstance.id == instance_id)
|
||||
.options(
|
||||
joinedload(AgentInstance.user_agent),
|
||||
joinedload(AgentInstance.steps),
|
||||
joinedload(AgentInstance.questions),
|
||||
joinedload(AgentInstance.messages),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
@@ -483,16 +402,6 @@ def delete_user_account(db: Session, user_id: UUID) -> None:
|
||||
)
|
||||
|
||||
# Delete in order of foreign key dependencies
|
||||
# 1. Delete AgentUserFeedback (depends on AgentInstance and User)
|
||||
db.query(AgentUserFeedback).filter(
|
||||
AgentUserFeedback.created_by_user_id == user_id
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
# 2. Delete AgentQuestions (depends on AgentInstance and User)
|
||||
db.query(AgentQuestion).filter(
|
||||
AgentQuestion.answered_by_user_id == user_id
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
# Get all agent instances for this user to delete their related data
|
||||
instance_ids = [
|
||||
instance.id
|
||||
@@ -502,19 +411,9 @@ def delete_user_account(db: Session, user_id: UUID) -> None:
|
||||
]
|
||||
|
||||
if instance_ids:
|
||||
# Delete questions for user's instances
|
||||
db.query(AgentQuestion).filter(
|
||||
AgentQuestion.agent_instance_id.in_(instance_ids)
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
# Delete steps for user's instances
|
||||
db.query(AgentStep).filter(
|
||||
AgentStep.agent_instance_id.in_(instance_ids)
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
# Delete feedback for user's instances
|
||||
db.query(AgentUserFeedback).filter(
|
||||
AgentUserFeedback.agent_instance_id.in_(instance_ids)
|
||||
# Delete messages for user's instances
|
||||
db.query(Message).filter(
|
||||
Message.agent_instance_id.in_(instance_ids)
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
# 3. Delete AgentInstances (depends on UserAgent and User)
|
||||
|
||||
@@ -12,9 +12,7 @@ from shared.database import (
|
||||
AgentInstance,
|
||||
AgentStatus,
|
||||
APIKey,
|
||||
AgentStep,
|
||||
AgentQuestion,
|
||||
AgentUserFeedback,
|
||||
Message,
|
||||
)
|
||||
from shared.database.billing_operations import check_agent_limit
|
||||
from sqlalchemy import and_, func
|
||||
@@ -157,6 +155,7 @@ async def trigger_webhook_agent(
|
||||
payload = {
|
||||
"agent_instance_id": str(agent_instance_id),
|
||||
"prompt": prompt,
|
||||
"agent_type": user_agent.name,
|
||||
}
|
||||
|
||||
if name is not None:
|
||||
@@ -281,9 +280,7 @@ def get_user_agent_instances(db: Session, agent_id: UUID, user_id: UUID) -> list
|
||||
instances = (
|
||||
db.query(AgentInstance)
|
||||
.options(
|
||||
joinedload(AgentInstance.questions),
|
||||
joinedload(AgentInstance.steps),
|
||||
joinedload(AgentInstance.user_feedback),
|
||||
joinedload(AgentInstance.messages),
|
||||
joinedload(AgentInstance.user_agent),
|
||||
)
|
||||
.filter(AgentInstance.user_agent_id == agent_id)
|
||||
@@ -315,18 +312,8 @@ def delete_user_agent(db: Session, agent_id: UUID, user_id: UUID) -> bool:
|
||||
|
||||
# For each agent instance, delete all related data
|
||||
for instance in agent_instances:
|
||||
# Delete agent steps
|
||||
db.query(AgentStep).filter(AgentStep.agent_instance_id == instance.id).delete()
|
||||
|
||||
# Delete agent questions
|
||||
db.query(AgentQuestion).filter(
|
||||
AgentQuestion.agent_instance_id == instance.id
|
||||
).delete()
|
||||
|
||||
# Delete user feedback
|
||||
db.query(AgentUserFeedback).filter(
|
||||
AgentUserFeedback.agent_instance_id == instance.id
|
||||
).delete()
|
||||
# Delete messages
|
||||
db.query(Message).filter(Message.agent_instance_id == instance.id).delete()
|
||||
|
||||
# Delete all agent instances
|
||||
db.query(AgentInstance).filter(AgentInstance.user_agent_id == agent_id).delete()
|
||||
|
||||
@@ -9,7 +9,6 @@ import sentry_sdk
|
||||
from shared.config import settings
|
||||
from .api import (
|
||||
agents,
|
||||
questions,
|
||||
user_agents,
|
||||
push_notifications,
|
||||
billing,
|
||||
@@ -73,7 +72,6 @@ app.add_middleware(
|
||||
# Include routers with versioned API prefix
|
||||
app.include_router(auth_routes.router, prefix=settings.api_v1_prefix)
|
||||
app.include_router(agents.router, prefix=settings.api_v1_prefix)
|
||||
app.include_router(questions.router, prefix=settings.api_v1_prefix)
|
||||
app.include_router(user_agents.router, prefix=settings.api_v1_prefix)
|
||||
app.include_router(push_notifications.router, prefix=settings.api_v1_prefix)
|
||||
app.include_router(user_settings.router, prefix=settings.api_v1_prefix)
|
||||
|
||||
@@ -13,31 +13,24 @@ from pydantic import BaseModel, ConfigDict, Field, field_serializer
|
||||
from shared.database.enums import AgentStatus
|
||||
|
||||
# ============================================================================
|
||||
# Question Models
|
||||
# Message Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
# This is when the Agent prompts the user for an answer and this format
|
||||
# is what the user responds with.
|
||||
class AnswerRequest(BaseModel):
|
||||
answer: str = Field(..., description="User's answer to the question")
|
||||
# Unified message models
|
||||
class UserMessageRequest(BaseModel):
|
||||
content: str = Field(..., description="Message content from the user")
|
||||
|
||||
|
||||
# User feedback that agents can retrieve during their operations
|
||||
class UserFeedbackRequest(BaseModel):
|
||||
feedback: str = Field(..., description="User's feedback or additional information")
|
||||
|
||||
|
||||
class UserFeedbackResponse(BaseModel):
|
||||
class UserMessageResponse(BaseModel):
|
||||
id: str
|
||||
feedback_text: str
|
||||
content: str
|
||||
sender_type: str
|
||||
created_at: datetime
|
||||
retrieved_at: datetime | None
|
||||
requires_user_input: bool
|
||||
|
||||
@field_serializer("created_at", "retrieved_at")
|
||||
def serialize_datetime(self, dt: datetime | None, _info):
|
||||
if dt is None:
|
||||
return None
|
||||
@field_serializer("created_at")
|
||||
def serialize_datetime(self, dt: datetime, _info):
|
||||
return dt.isoformat() + "Z"
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
@@ -75,20 +68,6 @@ class UserNotificationSettingsResponse(BaseModel):
|
||||
# ============================================================================
|
||||
|
||||
|
||||
# Represents individual steps/actions taken by an agent
|
||||
class AgentStepResponse(BaseModel):
|
||||
id: str
|
||||
step_number: int
|
||||
description: str
|
||||
created_at: datetime
|
||||
|
||||
@field_serializer("created_at")
|
||||
def serialize_datetime(self, dt: datetime, _info):
|
||||
return dt.isoformat() + "Z"
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
# Summary view of an agent instance (a single agent session/run)
|
||||
class AgentInstanceResponse(BaseModel):
|
||||
id: str
|
||||
@@ -97,13 +76,11 @@ class AgentInstanceResponse(BaseModel):
|
||||
status: AgentStatus
|
||||
started_at: datetime
|
||||
ended_at: datetime | None
|
||||
latest_step: str | None = None
|
||||
has_pending_question: bool = False
|
||||
pending_question_age: int | None = None # Age in seconds
|
||||
pending_questions_count: int = 0
|
||||
step_count: int = 0
|
||||
latest_message: str | None = None
|
||||
latest_message_at: datetime | None = None # Timestamp of the latest message
|
||||
chat_length: int = 0 # Total message count
|
||||
|
||||
@field_serializer("started_at", "ended_at")
|
||||
@field_serializer("started_at", "ended_at", "latest_message_at")
|
||||
def serialize_datetime(self, dt: datetime | None, _info):
|
||||
if dt is None:
|
||||
return None
|
||||
@@ -134,37 +111,33 @@ class AgentTypeOverview(BaseModel):
|
||||
# ============================================================================
|
||||
|
||||
|
||||
# Detailed information about a question asked by an agent, including answer status
|
||||
class QuestionDetail(BaseModel):
|
||||
# Message model for the chat interface
|
||||
class MessageResponse(BaseModel):
|
||||
id: str
|
||||
question_text: str
|
||||
answer_text: str | None
|
||||
asked_at: datetime
|
||||
answered_at: datetime | None
|
||||
is_active: bool
|
||||
content: str
|
||||
sender_type: str
|
||||
created_at: datetime
|
||||
requires_user_input: bool
|
||||
|
||||
@field_serializer("asked_at", "answered_at")
|
||||
def serialize_datetime(self, dt: datetime | None, _info):
|
||||
if dt is None:
|
||||
return None
|
||||
@field_serializer("created_at")
|
||||
def serialize_datetime(self, dt: datetime, _info):
|
||||
return dt.isoformat() + "Z"
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
# Complete detailed view of a specific agent instance
|
||||
# with full step and question history
|
||||
# with full message history
|
||||
class AgentInstanceDetail(BaseModel):
|
||||
id: str
|
||||
agent_type_id: str
|
||||
agent_type: AgentTypeOverview
|
||||
agent_type_name: str
|
||||
status: AgentStatus
|
||||
started_at: datetime
|
||||
ended_at: datetime | None
|
||||
git_diff: str | None = None
|
||||
steps: list[AgentStepResponse] = []
|
||||
questions: list[QuestionDetail] = []
|
||||
user_feedback: list[UserFeedbackResponse] = []
|
||||
messages: list[MessageResponse] = []
|
||||
last_read_message_id: str | None = None
|
||||
|
||||
@field_serializer("started_at", "ended_at")
|
||||
def serialize_datetime(self, dt: datetime | None, _info):
|
||||
|
||||
@@ -7,4 +7,5 @@ cryptography==42.0.5
|
||||
email-validator==2.1.0
|
||||
exponent-server-sdk>=2.1.0
|
||||
stripe==10.8.0
|
||||
asyncpg==0.29.0
|
||||
-r ../shared/requirements.txt
|
||||
@@ -7,9 +7,6 @@ from shared.database.models import (
|
||||
User,
|
||||
UserAgent,
|
||||
AgentInstance,
|
||||
AgentStep,
|
||||
AgentQuestion,
|
||||
AgentUserFeedback,
|
||||
)
|
||||
from shared.database.enums import AgentStatus
|
||||
|
||||
@@ -74,20 +71,23 @@ class TestAgentEndpoints:
|
||||
id=uuid4(),
|
||||
user_agent_id=test_user_agent.id,
|
||||
user_id=test_user.id,
|
||||
status=AgentStatus.ACTIVE,
|
||||
status=AgentStatus.AWAITING_INPUT,
|
||||
started_at=datetime.now(timezone.utc),
|
||||
)
|
||||
test_db.add(instance)
|
||||
|
||||
# Create a pending question with timezone-aware datetime
|
||||
question = AgentQuestion(
|
||||
# Create a message with requires_user_input=True to simulate a question
|
||||
from shared.database import Message, SenderType
|
||||
|
||||
question_msg = Message(
|
||||
id=uuid4(),
|
||||
agent_instance_id=instance.id,
|
||||
question_text="Test question?",
|
||||
asked_at=datetime.now(timezone.utc),
|
||||
is_active=True,
|
||||
sender_type=SenderType.AGENT,
|
||||
content="Test question?",
|
||||
requires_user_input=True,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
test_db.add(question)
|
||||
test_db.add(question_msg)
|
||||
test_db.commit()
|
||||
|
||||
# This should not raise a timezone error
|
||||
@@ -99,12 +99,10 @@ class TestAgentEndpoints:
|
||||
agent_type = data[0]
|
||||
assert len(agent_type["recent_instances"]) == 1
|
||||
|
||||
# Check that pending question info is populated
|
||||
# Check that the instance has AWAITING_INPUT status
|
||||
instance_data = agent_type["recent_instances"][0]
|
||||
assert instance_data["has_pending_question"] is True
|
||||
assert instance_data["pending_questions_count"] == 1
|
||||
assert instance_data["pending_question_age"] is not None
|
||||
assert instance_data["pending_question_age"] >= 0
|
||||
assert instance_data["status"] == "AWAITING_INPUT"
|
||||
assert instance_data["latest_message"] == "Test question?"
|
||||
|
||||
def test_list_all_agent_instances(self, authenticated_client, test_agent_instance):
|
||||
"""Test listing all agent instances."""
|
||||
@@ -115,7 +113,7 @@ class TestAgentEndpoints:
|
||||
assert len(data) == 1
|
||||
instance = data[0]
|
||||
assert instance["id"] == str(test_agent_instance.id)
|
||||
assert instance["status"] == "active"
|
||||
assert instance["status"] == "ACTIVE"
|
||||
|
||||
def test_list_agent_instances_with_limit(
|
||||
self, authenticated_client, test_db, test_user, test_user_agent
|
||||
@@ -158,15 +156,19 @@ class TestAgentEndpoints:
|
||||
)
|
||||
test_db.add(completed_instance)
|
||||
|
||||
# Add a question to the active instance
|
||||
question = AgentQuestion(
|
||||
# Add a message with requires_user_input to the active instance
|
||||
from shared.database import Message, SenderType
|
||||
|
||||
question_msg = Message(
|
||||
id=uuid4(),
|
||||
agent_instance_id=test_agent_instance.id,
|
||||
question_text="Test question?",
|
||||
asked_at=datetime.now(timezone.utc),
|
||||
is_active=True,
|
||||
sender_type=SenderType.AGENT,
|
||||
content="Test question?",
|
||||
requires_user_input=True,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
test_db.add(question)
|
||||
test_db.add(question_msg)
|
||||
test_agent_instance.status = AgentStatus.AWAITING_INPUT
|
||||
test_db.commit()
|
||||
|
||||
response = authenticated_client.get("/api/v1/agent-summary")
|
||||
@@ -174,7 +176,7 @@ class TestAgentEndpoints:
|
||||
data = response.json()
|
||||
|
||||
assert data["total_instances"] == 2
|
||||
assert data["active_instances"] == 1
|
||||
assert data["active_instances"] == 0 # AWAITING_INPUT doesn't count as active
|
||||
assert data["completed_instances"] == 1
|
||||
assert "agent_types" in data
|
||||
assert len(data["agent_types"]) == 1
|
||||
@@ -203,39 +205,35 @@ class TestAgentEndpoints:
|
||||
self, authenticated_client, test_db, test_agent_instance
|
||||
):
|
||||
"""Test getting detailed agent instance information."""
|
||||
# Add steps and questions
|
||||
step1 = AgentStep(
|
||||
# Add messages to simulate conversation
|
||||
from shared.database import Message, SenderType
|
||||
|
||||
msg1 = Message(
|
||||
id=uuid4(),
|
||||
agent_instance_id=test_agent_instance.id,
|
||||
step_number=1,
|
||||
description="First step",
|
||||
sender_type=SenderType.AGENT,
|
||||
content="First step completed",
|
||||
requires_user_input=False,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
step2 = AgentStep(
|
||||
msg2 = Message(
|
||||
id=uuid4(),
|
||||
agent_instance_id=test_agent_instance.id,
|
||||
step_number=2,
|
||||
description="Second step",
|
||||
sender_type=SenderType.AGENT,
|
||||
content="Need input?",
|
||||
requires_user_input=True,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
msg3 = Message(
|
||||
id=uuid4(),
|
||||
agent_instance_id=test_agent_instance.id,
|
||||
sender_type=SenderType.USER,
|
||||
content="Great work!",
|
||||
requires_user_input=False,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
question = AgentQuestion(
|
||||
id=uuid4(),
|
||||
agent_instance_id=test_agent_instance.id,
|
||||
question_text="Need input?",
|
||||
asked_at=datetime.now(timezone.utc),
|
||||
is_active=True,
|
||||
)
|
||||
|
||||
feedback = AgentUserFeedback(
|
||||
id=uuid4(),
|
||||
agent_instance_id=test_agent_instance.id,
|
||||
created_by_user_id=test_agent_instance.user_id,
|
||||
feedback_text="Great work!",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
test_db.add_all([step1, step2, question, feedback])
|
||||
test_db.add_all([msg1, msg2, msg3])
|
||||
test_db.commit()
|
||||
|
||||
response = authenticated_client.get(
|
||||
@@ -245,12 +243,14 @@ class TestAgentEndpoints:
|
||||
data = response.json()
|
||||
|
||||
assert data["id"] == str(test_agent_instance.id)
|
||||
assert len(data["steps"]) == 2
|
||||
assert data["steps"][0]["description"] == "First step"
|
||||
assert len(data["questions"]) == 1
|
||||
assert data["questions"][0]["question_text"] == "Need input?"
|
||||
assert len(data["user_feedback"]) == 1
|
||||
assert data["user_feedback"][0]["feedback_text"] == "Great work!"
|
||||
assert "messages" in data
|
||||
assert len(data["messages"]) == 3
|
||||
assert data["messages"][0]["content"] == "First step completed"
|
||||
assert data["messages"][0]["sender_type"] == "AGENT"
|
||||
assert data["messages"][1]["content"] == "Need input?"
|
||||
assert data["messages"][1]["requires_user_input"] is True
|
||||
assert data["messages"][2]["content"] == "Great work!"
|
||||
assert data["messages"][2]["sender_type"] == "USER"
|
||||
|
||||
def test_get_instance_detail_not_found(self, authenticated_client):
|
||||
"""Test getting non-existent instance detail."""
|
||||
@@ -265,71 +265,140 @@ class TestAgentEndpoints:
|
||||
"""Test adding user feedback to an agent instance."""
|
||||
feedback_text = "Please use TypeScript for this component"
|
||||
response = authenticated_client.post(
|
||||
f"/api/v1/agent-instances/{test_agent_instance.id}/feedback",
|
||||
json={"feedback": feedback_text},
|
||||
f"/api/v1/agent-instances/{test_agent_instance.id}/messages",
|
||||
json={"content": feedback_text},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["feedback_text"] == feedback_text
|
||||
assert data["content"] == feedback_text
|
||||
assert data["sender_type"] == "USER"
|
||||
assert "id" in data
|
||||
assert "created_at" in data
|
||||
assert data["requires_user_input"] is False
|
||||
|
||||
# Verify in database
|
||||
feedback = (
|
||||
test_db.query(AgentUserFeedback)
|
||||
.filter_by(agent_instance_id=test_agent_instance.id)
|
||||
from shared.database import Message, SenderType
|
||||
|
||||
message = (
|
||||
test_db.query(Message)
|
||||
.filter_by(
|
||||
agent_instance_id=test_agent_instance.id, sender_type=SenderType.USER
|
||||
)
|
||||
.first()
|
||||
)
|
||||
assert feedback is not None
|
||||
assert feedback.feedback_text == feedback_text
|
||||
assert feedback.retrieved_at is None
|
||||
assert message is not None
|
||||
assert message.content == feedback_text
|
||||
assert message.sender_type == SenderType.USER
|
||||
|
||||
def test_add_feedback_to_nonexistent_instance(self, authenticated_client):
|
||||
"""Test adding feedback to non-existent instance."""
|
||||
fake_id = uuid4()
|
||||
response = authenticated_client.post(
|
||||
f"/api/v1/agent-instances/{fake_id}/feedback",
|
||||
json={"feedback": "Test feedback"},
|
||||
f"/api/v1/agent-instances/{fake_id}/messages",
|
||||
json={"content": "Test feedback"},
|
||||
)
|
||||
assert response.status_code == 404
|
||||
assert response.json()["detail"] == "Agent instance not found"
|
||||
|
||||
def test_update_agent_status_completed(
|
||||
def test_instance_status_changes_with_messages(
|
||||
self, authenticated_client, test_db, test_agent_instance
|
||||
):
|
||||
"""Test marking an agent instance as completed."""
|
||||
response = authenticated_client.put(
|
||||
f"/api/v1/agent-instances/{test_agent_instance.id}/status",
|
||||
json={"status": "completed"},
|
||||
)
|
||||
"""Test that instance status changes based on message flow."""
|
||||
from shared.database import Message, SenderType
|
||||
|
||||
# Initially instance should be ACTIVE
|
||||
assert test_agent_instance.status == AgentStatus.ACTIVE
|
||||
|
||||
# Agent sends a message requiring user input
|
||||
question_msg = Message(
|
||||
id=uuid4(),
|
||||
agent_instance_id=test_agent_instance.id,
|
||||
sender_type=SenderType.AGENT,
|
||||
content="Should I use TypeScript or JavaScript?",
|
||||
requires_user_input=True,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
test_db.add(question_msg)
|
||||
test_db.commit()
|
||||
|
||||
# Status should change to AWAITING_INPUT (this would be done by the agent)
|
||||
test_agent_instance.status = AgentStatus.AWAITING_INPUT
|
||||
test_db.commit()
|
||||
|
||||
# User responds
|
||||
response = authenticated_client.post(
|
||||
f"/api/v1/agent-instances/{test_agent_instance.id}/messages",
|
||||
json={"content": "Use TypeScript for better type safety"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Check that status changed back to ACTIVE
|
||||
test_db.refresh(test_agent_instance)
|
||||
assert test_agent_instance.status == AgentStatus.ACTIVE
|
||||
|
||||
def test_message_creates_status_update_notification(
|
||||
self, authenticated_client, test_db, test_agent_instance
|
||||
):
|
||||
"""Test that sending messages triggers appropriate notifications."""
|
||||
# This would test the notification system if implemented
|
||||
# For now, just verify message creation works
|
||||
response = authenticated_client.post(
|
||||
f"/api/v1/agent-instances/{test_agent_instance.id}/messages",
|
||||
json={"content": "Test message for notifications"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify message was created
|
||||
from shared.database import Message
|
||||
|
||||
messages = (
|
||||
test_db.query(Message)
|
||||
.filter_by(agent_instance_id=test_agent_instance.id)
|
||||
.all()
|
||||
)
|
||||
assert len(messages) >= 1
|
||||
assert any(msg.content == "Test message for notifications" for msg in messages)
|
||||
|
||||
def test_agent_instance_latest_message_tracking(
|
||||
self, authenticated_client, test_db, test_agent_instance
|
||||
):
|
||||
"""Test that latest_message is properly tracked in instance listing."""
|
||||
from shared.database import Message, SenderType
|
||||
import time
|
||||
|
||||
# Create messages with different timestamps
|
||||
msg1 = Message(
|
||||
id=uuid4(),
|
||||
agent_instance_id=test_agent_instance.id,
|
||||
sender_type=SenderType.AGENT,
|
||||
content="First message",
|
||||
requires_user_input=False,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
test_db.add(msg1)
|
||||
test_db.commit()
|
||||
|
||||
# Small delay to ensure different timestamps
|
||||
time.sleep(0.1)
|
||||
|
||||
msg2 = Message(
|
||||
id=uuid4(),
|
||||
agent_instance_id=test_agent_instance.id,
|
||||
sender_type=SenderType.USER,
|
||||
content="Latest message",
|
||||
requires_user_input=False,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
test_db.add(msg2)
|
||||
test_db.commit()
|
||||
|
||||
# Get instance list
|
||||
response = authenticated_client.get("/api/v1/agent-instances")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "completed"
|
||||
assert data["ended_at"] is not None
|
||||
|
||||
# Verify in database
|
||||
test_db.refresh(test_agent_instance)
|
||||
assert test_agent_instance.status == AgentStatus.COMPLETED
|
||||
assert test_agent_instance.ended_at is not None
|
||||
|
||||
def test_update_agent_status_unsupported(
|
||||
self, authenticated_client, test_agent_instance
|
||||
):
|
||||
"""Test unsupported status update."""
|
||||
response = authenticated_client.put(
|
||||
f"/api/v1/agent-instances/{test_agent_instance.id}/status",
|
||||
json={"status": "paused"},
|
||||
)
|
||||
assert response.status_code == 400
|
||||
assert response.json()["detail"] == "Status update not supported"
|
||||
|
||||
def test_update_status_nonexistent_instance(self, authenticated_client):
|
||||
"""Test updating status of non-existent instance."""
|
||||
fake_id = uuid4()
|
||||
response = authenticated_client.put(
|
||||
f"/api/v1/agent-instances/{fake_id}/status", json={"status": "completed"}
|
||||
)
|
||||
assert response.status_code == 404
|
||||
assert response.json()["detail"] == "Agent instance not found"
|
||||
assert len(data) == 1
|
||||
instance = data[0]
|
||||
assert instance["latest_message"] == "Latest message"
|
||||
assert instance["chat_length"] == 2
|
||||
|
||||
298
backend/tests/test_message_system.py
Normal file
298
backend/tests/test_message_system.py
Normal file
@@ -0,0 +1,298 @@
|
||||
"""Comprehensive tests for the unified message system."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from uuid import uuid4
|
||||
import time
|
||||
|
||||
from shared.database.models import Message, AgentInstance
|
||||
from shared.database.enums import AgentStatus, SenderType
|
||||
|
||||
|
||||
class TestMessageSystem:
|
||||
"""Test the core message system functionality."""
|
||||
|
||||
def test_message_flow_creates_conversation(
|
||||
self, authenticated_client, test_db, test_agent_instance
|
||||
):
|
||||
"""Test that messages create a proper conversation flow."""
|
||||
# Agent sends initial message
|
||||
agent_msg1 = Message(
|
||||
id=uuid4(),
|
||||
agent_instance_id=test_agent_instance.id,
|
||||
sender_type=SenderType.AGENT,
|
||||
content="I'm starting to work on your request",
|
||||
requires_user_input=False,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
test_db.add(agent_msg1)
|
||||
test_db.commit()
|
||||
|
||||
# Agent asks a question
|
||||
agent_question = Message(
|
||||
id=uuid4(),
|
||||
agent_instance_id=test_agent_instance.id,
|
||||
sender_type=SenderType.AGENT,
|
||||
content="Which framework would you prefer: React or Vue?",
|
||||
requires_user_input=True,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
test_db.add(agent_question)
|
||||
test_agent_instance.status = AgentStatus.AWAITING_INPUT
|
||||
test_db.commit()
|
||||
|
||||
# User responds
|
||||
response = authenticated_client.post(
|
||||
f"/api/v1/agent-instances/{test_agent_instance.id}/messages",
|
||||
json={"content": "Let's use React with TypeScript"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
user_response = response.json()
|
||||
assert user_response["sender_type"] == "USER"
|
||||
assert user_response["requires_user_input"] is False
|
||||
|
||||
# Verify conversation in detail endpoint
|
||||
detail_response = authenticated_client.get(
|
||||
f"/api/v1/agent-instances/{test_agent_instance.id}"
|
||||
)
|
||||
assert detail_response.status_code == 200
|
||||
detail = detail_response.json()
|
||||
|
||||
assert len(detail["messages"]) == 3
|
||||
assert (
|
||||
detail["messages"][0]["content"] == "I'm starting to work on your request"
|
||||
)
|
||||
assert (
|
||||
detail["messages"][1]["content"]
|
||||
== "Which framework would you prefer: React or Vue?"
|
||||
)
|
||||
assert detail["messages"][1]["requires_user_input"] is True
|
||||
assert detail["messages"][2]["content"] == "Let's use React with TypeScript"
|
||||
assert detail["messages"][2]["sender_type"] == "USER"
|
||||
|
||||
def test_multiple_user_messages_allowed(
|
||||
self, authenticated_client, test_db, test_agent_instance
|
||||
):
|
||||
"""Test that users can send multiple messages in a row."""
|
||||
# Send first user message
|
||||
response1 = authenticated_client.post(
|
||||
f"/api/v1/agent-instances/{test_agent_instance.id}/messages",
|
||||
json={"content": "First instruction"},
|
||||
)
|
||||
assert response1.status_code == 200
|
||||
|
||||
# Send second user message immediately
|
||||
response2 = authenticated_client.post(
|
||||
f"/api/v1/agent-instances/{test_agent_instance.id}/messages",
|
||||
json={"content": "Additional instruction"},
|
||||
)
|
||||
assert response2.status_code == 200
|
||||
|
||||
# Send third user message
|
||||
response3 = authenticated_client.post(
|
||||
f"/api/v1/agent-instances/{test_agent_instance.id}/messages",
|
||||
json={"content": "One more thing..."},
|
||||
)
|
||||
assert response3.status_code == 200
|
||||
|
||||
# Verify all messages were created
|
||||
messages = (
|
||||
test_db.query(Message)
|
||||
.filter_by(
|
||||
agent_instance_id=test_agent_instance.id, sender_type=SenderType.USER
|
||||
)
|
||||
.all()
|
||||
)
|
||||
assert len(messages) == 3
|
||||
assert [msg.content for msg in messages] == [
|
||||
"First instruction",
|
||||
"Additional instruction",
|
||||
"One more thing...",
|
||||
]
|
||||
|
||||
def test_message_ordering_preserved(
|
||||
self, authenticated_client, test_db, test_agent_instance
|
||||
):
|
||||
"""Test that message ordering is preserved correctly."""
|
||||
# Create messages with specific order
|
||||
messages_data = [
|
||||
(SenderType.AGENT, "Starting task", False),
|
||||
(SenderType.AGENT, "Found an issue", False),
|
||||
(SenderType.AGENT, "How should I proceed?", True),
|
||||
(SenderType.USER, "Fix it with option A", False),
|
||||
(SenderType.AGENT, "Implementing option A", False),
|
||||
]
|
||||
|
||||
created_messages = []
|
||||
for i, (sender_type, content, requires_input) in enumerate(messages_data):
|
||||
time.sleep(0.01) # Ensure different timestamps
|
||||
|
||||
if sender_type == SenderType.AGENT:
|
||||
msg = Message(
|
||||
id=uuid4(),
|
||||
agent_instance_id=test_agent_instance.id,
|
||||
sender_type=sender_type,
|
||||
content=content,
|
||||
requires_user_input=requires_input,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
test_db.add(msg)
|
||||
test_db.commit()
|
||||
created_messages.append(msg)
|
||||
else:
|
||||
# User message via API
|
||||
response = authenticated_client.post(
|
||||
f"/api/v1/agent-instances/{test_agent_instance.id}/messages",
|
||||
json={"content": content},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Get instance detail
|
||||
response = authenticated_client.get(
|
||||
f"/api/v1/agent-instances/{test_agent_instance.id}"
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
# Verify order is preserved
|
||||
assert len(data["messages"]) == 5
|
||||
for i, (sender_type, content, requires_input) in enumerate(messages_data):
|
||||
msg = data["messages"][i]
|
||||
assert msg["content"] == content
|
||||
assert msg["sender_type"] == sender_type.value
|
||||
if sender_type == SenderType.AGENT:
|
||||
assert msg["requires_user_input"] == requires_input
|
||||
|
||||
def test_agent_instance_summary_includes_message_stats(
|
||||
self, authenticated_client, test_db, test_user, test_user_agent
|
||||
):
|
||||
"""Test that agent instance summaries include message statistics."""
|
||||
# Create multiple instances with different message counts
|
||||
instances = []
|
||||
for i in range(3):
|
||||
instance = AgentInstance(
|
||||
id=uuid4(),
|
||||
user_agent_id=test_user_agent.id,
|
||||
user_id=test_user.id,
|
||||
status=AgentStatus.ACTIVE,
|
||||
started_at=datetime.now(timezone.utc),
|
||||
)
|
||||
test_db.add(instance)
|
||||
instances.append(instance)
|
||||
|
||||
# Add different number of messages to each instance
|
||||
for j in range(i + 1):
|
||||
msg = Message(
|
||||
id=uuid4(),
|
||||
agent_instance_id=instance.id,
|
||||
sender_type=SenderType.AGENT,
|
||||
content=f"Message {j} for instance {i}",
|
||||
requires_user_input=False,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
test_db.add(msg)
|
||||
|
||||
test_db.commit()
|
||||
|
||||
# Get agent types overview
|
||||
response = authenticated_client.get("/api/v1/agent-types")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
assert len(data) == 1
|
||||
agent_type = data[0]
|
||||
assert len(agent_type["recent_instances"]) >= 3
|
||||
|
||||
# Verify chat_length is included
|
||||
for instance in agent_type["recent_instances"]:
|
||||
assert "chat_length" in instance
|
||||
assert instance["chat_length"] >= 0
|
||||
|
||||
def test_message_with_empty_content_allowed(
|
||||
self, authenticated_client, test_agent_instance
|
||||
):
|
||||
"""Test that empty messages are allowed."""
|
||||
response = authenticated_client.post(
|
||||
f"/api/v1/agent-instances/{test_agent_instance.id}/messages",
|
||||
json={"content": ""},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["content"] == ""
|
||||
assert data["sender_type"] == "USER"
|
||||
|
||||
def test_message_status_transitions(
|
||||
self, authenticated_client, test_db, test_agent_instance
|
||||
):
|
||||
"""Test that message flow properly triggers status transitions."""
|
||||
# Start with ACTIVE
|
||||
assert test_agent_instance.status == AgentStatus.ACTIVE
|
||||
|
||||
# Agent asks question -> AWAITING_INPUT
|
||||
question = Message(
|
||||
id=uuid4(),
|
||||
agent_instance_id=test_agent_instance.id,
|
||||
sender_type=SenderType.AGENT,
|
||||
content="Should I continue?",
|
||||
requires_user_input=True,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
test_db.add(question)
|
||||
test_agent_instance.status = AgentStatus.AWAITING_INPUT
|
||||
test_db.commit()
|
||||
|
||||
# User responds -> back to ACTIVE
|
||||
response = authenticated_client.post(
|
||||
f"/api/v1/agent-instances/{test_agent_instance.id}/messages",
|
||||
json={"content": "Yes, continue"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
test_db.refresh(test_agent_instance)
|
||||
assert test_agent_instance.status == AgentStatus.ACTIVE
|
||||
|
||||
def test_awaiting_input_instances_prioritized(
|
||||
self, authenticated_client, test_db, test_user, test_user_agent
|
||||
):
|
||||
"""Test that instances awaiting input are prioritized in listings."""
|
||||
# Create active instance
|
||||
active_instance = AgentInstance(
|
||||
id=uuid4(),
|
||||
user_agent_id=test_user_agent.id,
|
||||
user_id=test_user.id,
|
||||
status=AgentStatus.ACTIVE,
|
||||
started_at=datetime.now(timezone.utc),
|
||||
)
|
||||
test_db.add(active_instance)
|
||||
|
||||
# Create awaiting input instance (older)
|
||||
awaiting_instance = AgentInstance(
|
||||
id=uuid4(),
|
||||
user_agent_id=test_user_agent.id,
|
||||
user_id=test_user.id,
|
||||
status=AgentStatus.AWAITING_INPUT,
|
||||
started_at=datetime.now(timezone.utc),
|
||||
)
|
||||
test_db.add(awaiting_instance)
|
||||
|
||||
# Add question message to awaiting instance
|
||||
question = Message(
|
||||
id=uuid4(),
|
||||
agent_instance_id=awaiting_instance.id,
|
||||
sender_type=SenderType.AGENT,
|
||||
content="Need your input",
|
||||
requires_user_input=True,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
test_db.add(question)
|
||||
test_db.commit()
|
||||
|
||||
# Get agent types
|
||||
response = authenticated_client.get("/api/v1/agent-types")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
# Awaiting input instance should be first
|
||||
instances = data[0]["recent_instances"]
|
||||
assert len(instances) >= 2
|
||||
assert instances[0]["status"] == "AWAITING_INPUT"
|
||||
assert instances[0]["latest_message"] == "Need your input"
|
||||
@@ -3,8 +3,8 @@
|
||||
from datetime import datetime, timezone
|
||||
from uuid import uuid4
|
||||
|
||||
from shared.database.models import AgentQuestion, AgentInstance, User
|
||||
from shared.database.enums import AgentStatus
|
||||
from shared.database.models import Message, AgentInstance, User
|
||||
from shared.database.enums import AgentStatus, SenderType
|
||||
|
||||
|
||||
class TestQuestionEndpoints:
|
||||
@@ -14,78 +14,90 @@ class TestQuestionEndpoints:
|
||||
self, authenticated_client, test_db, test_agent_instance, test_user
|
||||
):
|
||||
"""Test answering a pending question."""
|
||||
# Create a pending question
|
||||
question = AgentQuestion(
|
||||
# Create a message with requires_user_input=True (question)
|
||||
from shared.database import Message, SenderType
|
||||
|
||||
question_msg = Message(
|
||||
id=uuid4(),
|
||||
agent_instance_id=test_agent_instance.id,
|
||||
question_text="Should I use async/await?",
|
||||
asked_at=datetime.now(timezone.utc),
|
||||
is_active=True,
|
||||
sender_type=SenderType.AGENT,
|
||||
content="Should I use async/await?",
|
||||
requires_user_input=True,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
test_db.add(question)
|
||||
test_db.add(question_msg)
|
||||
|
||||
# Set instance status to AWAITING_INPUT
|
||||
test_agent_instance.status = AgentStatus.AWAITING_INPUT
|
||||
test_db.commit()
|
||||
|
||||
# Submit answer
|
||||
# Submit answer as a new message
|
||||
answer_text = "Yes, use async/await for better performance"
|
||||
response = authenticated_client.post(
|
||||
f"/api/v1/questions/{question.id}/answer", json={"answer": answer_text}
|
||||
f"/api/v1/agent-instances/{test_agent_instance.id}/messages",
|
||||
json={"content": answer_text},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert data["message"] == "Answer submitted successfully"
|
||||
|
||||
# Verify in database
|
||||
test_db.refresh(question)
|
||||
assert question.answer_text == answer_text
|
||||
assert question.answered_at is not None
|
||||
assert question.answered_by_user_id == test_user.id
|
||||
assert question.is_active is False
|
||||
assert data["content"] == answer_text
|
||||
assert data["sender_type"] == "USER"
|
||||
assert data["requires_user_input"] is False
|
||||
|
||||
# Verify agent instance status changed back to active
|
||||
test_db.refresh(test_agent_instance)
|
||||
assert test_agent_instance.status == AgentStatus.ACTIVE
|
||||
|
||||
def test_answer_question_not_found(self, authenticated_client):
|
||||
"""Test answering a non-existent question."""
|
||||
"""Test sending message to non-existent agent instance."""
|
||||
fake_id = uuid4()
|
||||
response = authenticated_client.post(
|
||||
f"/api/v1/questions/{fake_id}/answer", json={"answer": "Some answer"}
|
||||
f"/api/v1/agent-instances/{fake_id}/messages",
|
||||
json={"content": "Some answer"},
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
assert response.json()["detail"] == "Question not found or already answered"
|
||||
assert response.json()["detail"] == "Agent instance not found"
|
||||
|
||||
def test_answer_already_answered_question(
|
||||
self, authenticated_client, test_db, test_agent_instance, test_user
|
||||
):
|
||||
"""Test answering an already answered question."""
|
||||
# Create an already answered question
|
||||
question = AgentQuestion(
|
||||
"""Test that you can continue sending messages after answering a question."""
|
||||
# In the new system, you can always send more messages
|
||||
# This test verifies that behavior
|
||||
from shared.database import Message, SenderType
|
||||
|
||||
# Create a question message
|
||||
question_msg = Message(
|
||||
id=uuid4(),
|
||||
agent_instance_id=test_agent_instance.id,
|
||||
question_text="Already answered?",
|
||||
answer_text="Previous answer",
|
||||
asked_at=datetime.now(timezone.utc),
|
||||
answered_at=datetime.now(timezone.utc),
|
||||
answered_by_user_id=test_user.id,
|
||||
is_active=False,
|
||||
sender_type=SenderType.AGENT,
|
||||
content="Already answered?",
|
||||
requires_user_input=True,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
test_db.add(question)
|
||||
test_db.add(question_msg)
|
||||
|
||||
# Create answer message
|
||||
answer_msg = Message(
|
||||
id=uuid4(),
|
||||
agent_instance_id=test_agent_instance.id,
|
||||
sender_type=SenderType.USER,
|
||||
content="Previous answer",
|
||||
requires_user_input=False,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
test_db.add(answer_msg)
|
||||
test_db.commit()
|
||||
|
||||
# Try to answer again
|
||||
# Send another message - should work fine
|
||||
response = authenticated_client.post(
|
||||
f"/api/v1/questions/{question.id}/answer", json={"answer": "New answer"}
|
||||
f"/api/v1/agent-instances/{test_agent_instance.id}/messages",
|
||||
json={"content": "New message"},
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
assert response.json()["detail"] == "Question not found or already answered"
|
||||
|
||||
# Verify answer didn't change
|
||||
test_db.refresh(question)
|
||||
assert question.answer_text == "Previous answer"
|
||||
assert response.status_code == 200
|
||||
assert response.json()["content"] == "New message"
|
||||
|
||||
def test_answer_question_wrong_user(self, authenticated_client, test_db):
|
||||
"""Test answering a question from another user's agent."""
|
||||
@@ -121,78 +133,67 @@ class TestQuestionEndpoints:
|
||||
)
|
||||
test_db.add(other_instance)
|
||||
|
||||
# Create question for other user's agent
|
||||
question = AgentQuestion(
|
||||
# Create question message for other user's agent
|
||||
question_msg = Message(
|
||||
id=uuid4(),
|
||||
agent_instance_id=other_instance.id,
|
||||
question_text="Other user's question?",
|
||||
asked_at=datetime.now(timezone.utc),
|
||||
is_active=True,
|
||||
sender_type=SenderType.AGENT,
|
||||
content="Other user's question?",
|
||||
requires_user_input=True,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
test_db.add(question)
|
||||
test_db.add(question_msg)
|
||||
test_db.commit()
|
||||
|
||||
# Try to answer as current user
|
||||
# Try to send message as current user to other user's instance
|
||||
response = authenticated_client.post(
|
||||
f"/api/v1/questions/{question.id}/answer",
|
||||
json={"answer": "Trying to answer"},
|
||||
f"/api/v1/agent-instances/{other_instance.id}/messages",
|
||||
json={"content": "Trying to answer"},
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
assert response.json()["detail"] == "Question not found or already answered"
|
||||
|
||||
# Verify question remains unanswered
|
||||
test_db.refresh(question)
|
||||
assert question.answer_text is None
|
||||
assert question.is_active is True
|
||||
assert response.json()["detail"] == "Agent instance not found"
|
||||
|
||||
def test_answer_inactive_question(
|
||||
self, authenticated_client, test_db, test_agent_instance
|
||||
):
|
||||
"""Test answering an inactive question."""
|
||||
# Create an inactive question (but not answered)
|
||||
question = AgentQuestion(
|
||||
id=uuid4(),
|
||||
agent_instance_id=test_agent_instance.id,
|
||||
question_text="Inactive question?",
|
||||
asked_at=datetime.now(timezone.utc),
|
||||
is_active=False, # Inactive but not answered
|
||||
)
|
||||
test_db.add(question)
|
||||
"""Test sending message to completed agent instance."""
|
||||
# Set instance to COMPLETED status
|
||||
test_agent_instance.status = AgentStatus.COMPLETED
|
||||
test_agent_instance.ended_at = datetime.now(timezone.utc)
|
||||
test_db.commit()
|
||||
|
||||
# Try to answer
|
||||
# Try to send message - should still work since status updates happen server-side
|
||||
response = authenticated_client.post(
|
||||
f"/api/v1/questions/{question.id}/answer",
|
||||
json={"answer": "Trying to answer inactive"},
|
||||
f"/api/v1/agent-instances/{test_agent_instance.id}/messages",
|
||||
json={"content": "Message to completed instance"},
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
assert response.json()["detail"] == "Question not found or already answered"
|
||||
# Based on the backend code, messages can still be sent to completed instances
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_answer_question_empty_answer(
|
||||
self, authenticated_client, test_db, test_agent_instance
|
||||
):
|
||||
"""Test submitting an empty answer."""
|
||||
# Create a pending question
|
||||
question = AgentQuestion(
|
||||
"""Test submitting an empty message."""
|
||||
# Create a question message
|
||||
question_msg = Message(
|
||||
id=uuid4(),
|
||||
agent_instance_id=test_agent_instance.id,
|
||||
question_text="Can I submit empty?",
|
||||
asked_at=datetime.now(timezone.utc),
|
||||
is_active=True,
|
||||
sender_type=SenderType.AGENT,
|
||||
content="Can I submit empty?",
|
||||
requires_user_input=True,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
test_db.add(question)
|
||||
test_db.add(question_msg)
|
||||
test_agent_instance.status = AgentStatus.AWAITING_INPUT
|
||||
test_db.commit()
|
||||
|
||||
# Submit empty answer - should still work
|
||||
# Submit empty message - should still work
|
||||
response = authenticated_client.post(
|
||||
f"/api/v1/questions/{question.id}/answer", json={"answer": ""}
|
||||
f"/api/v1/agent-instances/{test_agent_instance.id}/messages",
|
||||
json={"content": ""},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify empty answer was saved
|
||||
test_db.refresh(question)
|
||||
assert question.answer_text == ""
|
||||
assert question.is_active is False
|
||||
assert response.json()["content"] == ""
|
||||
|
||||
@@ -186,8 +186,8 @@ class TestUserAgentEndpoints:
|
||||
|
||||
assert len(data) == 2
|
||||
statuses = [inst["status"] for inst in data]
|
||||
assert "active" in statuses
|
||||
assert "completed" in statuses
|
||||
assert "ACTIVE" in statuses
|
||||
assert "COMPLETED" in statuses
|
||||
|
||||
def test_get_user_agent_instances_not_found(self, authenticated_client):
|
||||
"""Test getting instances for non-existent user agent."""
|
||||
|
||||
@@ -57,6 +57,29 @@ def run_webhook_server(
|
||||
subprocess.run(cmd)
|
||||
|
||||
|
||||
def run_claude_wrapper(api_key, base_url=None, claude_args=None):
|
||||
"""Run the Claude wrapper V3 for Omnara integration"""
|
||||
# Import and run directly instead of subprocess
|
||||
from webhooks.claude_wrapper_v3 import main as claude_wrapper_main
|
||||
|
||||
# Prepare sys.argv for the claude wrapper
|
||||
original_argv = sys.argv
|
||||
new_argv = ["claude_wrapper_v3", "--api-key", api_key]
|
||||
|
||||
if base_url:
|
||||
new_argv.extend(["--base-url", base_url])
|
||||
|
||||
# Add any additional Claude arguments
|
||||
if claude_args:
|
||||
new_argv.extend(claude_args)
|
||||
|
||||
try:
|
||||
sys.argv = new_argv
|
||||
claude_wrapper_main()
|
||||
finally:
|
||||
sys.argv = original_argv
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point that dispatches based on command line arguments"""
|
||||
parser = argparse.ArgumentParser(
|
||||
@@ -79,6 +102,12 @@ Examples:
|
||||
# Run webhook server on custom port
|
||||
omnara --claude-code-webhook --port 8080
|
||||
|
||||
# Run Claude wrapper V3
|
||||
omnara --claude --api-key YOUR_API_KEY
|
||||
|
||||
# Run Claude wrapper with custom base URL
|
||||
omnara --claude --api-key YOUR_API_KEY --base-url http://localhost:8000
|
||||
|
||||
# Run with custom API base URL
|
||||
omnara --stdio --api-key YOUR_API_KEY --base-url http://localhost:8000
|
||||
|
||||
@@ -99,6 +128,11 @@ Examples:
|
||||
action="store_true",
|
||||
help="Run the Claude Code webhook server",
|
||||
)
|
||||
mode_group.add_argument(
|
||||
"--claude",
|
||||
action="store_true",
|
||||
help="Run the Claude wrapper V3 for Omnara integration",
|
||||
)
|
||||
|
||||
# Arguments for webhook mode
|
||||
parser.add_argument(
|
||||
@@ -142,7 +176,8 @@ Examples:
|
||||
help="Pre-existing agent instance ID to use for this session (stdio mode only)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
# Use parse_known_args to capture remaining args for Claude
|
||||
args, unknown_args = parser.parse_known_args()
|
||||
|
||||
if args.cloudflare_tunnel and not args.claude_code_webhook:
|
||||
parser.error("--cloudflare-tunnel can only be used with --claude-code-webhook")
|
||||
@@ -156,6 +191,10 @@ Examples:
|
||||
dangerously_skip_permissions=args.dangerously_skip_permissions,
|
||||
port=args.port,
|
||||
)
|
||||
elif args.claude:
|
||||
if not args.api_key:
|
||||
parser.error("--api-key is required for --claude mode")
|
||||
run_claude_wrapper(args.api_key, args.base_url, unknown_args)
|
||||
else:
|
||||
if not args.api_key:
|
||||
parser.error("--api-key is required for stdio mode")
|
||||
|
||||
@@ -2,7 +2,8 @@
|
||||
|
||||
import asyncio
|
||||
import ssl
|
||||
from typing import Optional, Dict, Any
|
||||
import uuid
|
||||
from typing import Optional, Dict, Any, Union, List
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import aiohttp
|
||||
@@ -11,10 +12,14 @@ from aiohttp import ClientTimeout
|
||||
|
||||
from .exceptions import AuthenticationError, TimeoutError, APIError
|
||||
from .models import (
|
||||
LogStepResponse,
|
||||
QuestionResponse,
|
||||
QuestionStatus,
|
||||
EndSessionResponse,
|
||||
CreateMessageResponse,
|
||||
PendingMessagesResponse,
|
||||
Message,
|
||||
)
|
||||
from .utils import (
|
||||
validate_agent_instance_id,
|
||||
build_message_request_data,
|
||||
)
|
||||
|
||||
|
||||
@@ -76,6 +81,7 @@ class AsyncOmnaraClient:
|
||||
method: str,
|
||||
endpoint: str,
|
||||
json: Optional[Dict[str, Any]] = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
timeout: Optional[int] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Make an async HTTP request to the API.
|
||||
@@ -84,6 +90,7 @@ class AsyncOmnaraClient:
|
||||
method: HTTP method (GET, POST, etc.)
|
||||
endpoint: API endpoint path
|
||||
json: JSON body for the request
|
||||
params: Query parameters for the request
|
||||
timeout: Request timeout in seconds
|
||||
|
||||
Returns:
|
||||
@@ -104,7 +111,11 @@ class AsyncOmnaraClient:
|
||||
|
||||
try:
|
||||
async with self.session.request(
|
||||
method=method, url=url, json=json, timeout=request_timeout
|
||||
method=method,
|
||||
url=url,
|
||||
json=json,
|
||||
params=params,
|
||||
timeout=request_timeout,
|
||||
) as response:
|
||||
if response.status == 401:
|
||||
raise AuthenticationError(
|
||||
@@ -128,138 +139,246 @@ class AsyncOmnaraClient:
|
||||
except aiohttp.ClientError as e:
|
||||
raise APIError(0, f"Request failed: {str(e)}")
|
||||
|
||||
async def log_step(
|
||||
async def send_message(
|
||||
self,
|
||||
agent_type: str,
|
||||
step_description: str,
|
||||
agent_instance_id: Optional[str] = None,
|
||||
send_push: Optional[bool] = None,
|
||||
send_email: Optional[bool] = None,
|
||||
send_sms: Optional[bool] = None,
|
||||
git_diff: Optional[str] = None,
|
||||
) -> LogStepResponse:
|
||||
"""Log a high-level step the agent is performing.
|
||||
|
||||
Args:
|
||||
agent_type: Type of agent (e.g., 'Claude Code', 'Cursor')
|
||||
step_description: Clear description of what the agent is doing
|
||||
agent_instance_id: Existing agent instance ID (optional)
|
||||
send_push: Send push notification for this step (default: False)
|
||||
send_email: Send email notification for this step (default: False)
|
||||
send_sms: Send SMS notification for this step (default: False)
|
||||
|
||||
Returns:
|
||||
LogStepResponse with success status, instance ID, and user feedback
|
||||
"""
|
||||
data: Dict[str, Any] = {
|
||||
"agent_type": agent_type,
|
||||
"step_description": step_description,
|
||||
}
|
||||
if agent_instance_id:
|
||||
data["agent_instance_id"] = agent_instance_id
|
||||
if send_push is not None:
|
||||
data["send_push"] = send_push
|
||||
if send_email is not None:
|
||||
data["send_email"] = send_email
|
||||
if send_sms is not None:
|
||||
data["send_sms"] = send_sms
|
||||
if git_diff is not None:
|
||||
data["git_diff"] = git_diff
|
||||
|
||||
response = await self._make_request("POST", "/api/v1/steps", json=data)
|
||||
|
||||
return LogStepResponse(
|
||||
success=response["success"],
|
||||
agent_instance_id=response["agent_instance_id"],
|
||||
step_number=response["step_number"],
|
||||
user_feedback=response.get("user_feedback", []),
|
||||
)
|
||||
|
||||
async def ask_question(
|
||||
self,
|
||||
agent_instance_id: str,
|
||||
question_text: str,
|
||||
content: str,
|
||||
agent_type: Optional[str] = None,
|
||||
agent_instance_id: Optional[Union[str, uuid.UUID]] = None,
|
||||
requires_user_input: bool = False,
|
||||
timeout_minutes: int = 1440,
|
||||
poll_interval: float = 10.0,
|
||||
send_push: Optional[bool] = None,
|
||||
send_email: Optional[bool] = None,
|
||||
send_sms: Optional[bool] = None,
|
||||
git_diff: Optional[str] = None,
|
||||
) -> QuestionResponse:
|
||||
"""Ask the user a question and wait for their response.
|
||||
|
||||
This method submits the question and then polls for the answer.
|
||||
) -> CreateMessageResponse:
|
||||
"""Send a message to the dashboard.
|
||||
|
||||
Args:
|
||||
agent_instance_id: Agent instance ID
|
||||
question_text: Question to ask the user
|
||||
timeout_minutes: Maximum time to wait for answer in minutes (default: 1440 = 24 hours)
|
||||
poll_interval: Time between polls in seconds (default: 10.0)
|
||||
send_push: Send push notification for this question (default: user preference)
|
||||
send_email: Send email notification for this question (default: user preference)
|
||||
send_sms: Send SMS notification for this question (default: user preference)
|
||||
content: The message content (step description or question text)
|
||||
agent_type: Type of agent (required if agent_instance_id not provided)
|
||||
agent_instance_id: Existing agent instance ID (optional)
|
||||
requires_user_input: Whether this message requires user input (default: False)
|
||||
timeout_minutes: If requires_user_input, max time to wait in minutes (default: 1440)
|
||||
poll_interval: If requires_user_input, time between polls in seconds (default: 10.0)
|
||||
send_push: Send push notification (default: False for steps, user pref for questions)
|
||||
send_email: Send email notification (default: False for steps, user pref for questions)
|
||||
send_sms: Send SMS notification (default: False for steps, user pref for questions)
|
||||
git_diff: Git diff content to include (optional)
|
||||
|
||||
Returns:
|
||||
QuestionResponse with the user's answer
|
||||
CreateMessageResponse with any queued user messages
|
||||
|
||||
Raises:
|
||||
TimeoutError: If no answer is received within timeout
|
||||
ValueError: If neither agent_type nor agent_instance_id is provided
|
||||
TimeoutError: If requires_user_input and no answer is received within timeout
|
||||
"""
|
||||
# Submit the question
|
||||
data: Dict[str, Any] = {
|
||||
"agent_instance_id": agent_instance_id,
|
||||
"question_text": question_text,
|
||||
}
|
||||
if send_push is not None:
|
||||
data["send_push"] = send_push
|
||||
if send_email is not None:
|
||||
data["send_email"] = send_email
|
||||
if send_sms is not None:
|
||||
data["send_sms"] = send_sms
|
||||
if git_diff is not None:
|
||||
data["git_diff"] = git_diff
|
||||
# If no agent_instance_id provided, generate one client-side
|
||||
if not agent_instance_id:
|
||||
if not agent_type:
|
||||
raise ValueError("agent_type is required when creating a new instance")
|
||||
agent_instance_id = uuid.uuid4()
|
||||
|
||||
# First, try the non-blocking endpoint to create the question
|
||||
response = await self._make_request(
|
||||
"POST", "/api/v1/questions", json=data, timeout=5
|
||||
# Validate and convert agent_instance_id to string
|
||||
agent_instance_id_str = validate_agent_instance_id(agent_instance_id)
|
||||
|
||||
# Build request data using shared utility
|
||||
data = build_message_request_data(
|
||||
content=content,
|
||||
agent_instance_id=agent_instance_id_str,
|
||||
requires_user_input=requires_user_input,
|
||||
agent_type=agent_type,
|
||||
send_push=send_push,
|
||||
send_email=send_email,
|
||||
send_sms=send_sms,
|
||||
git_diff=git_diff,
|
||||
)
|
||||
question_id = response["question_id"]
|
||||
|
||||
# Convert timeout from minutes to seconds
|
||||
# Send the message
|
||||
response = await self._make_request("POST", "/api/v1/messages/agent", json=data)
|
||||
response_agent_instance_id = response["agent_instance_id"]
|
||||
message_id = response["message_id"]
|
||||
|
||||
queued_contents = [
|
||||
msg["content"] if isinstance(msg, dict) else msg
|
||||
for msg in response.get("queued_user_messages", [])
|
||||
]
|
||||
|
||||
create_response = CreateMessageResponse(
|
||||
success=response["success"],
|
||||
agent_instance_id=response_agent_instance_id,
|
||||
message_id=message_id,
|
||||
queued_user_messages=queued_contents,
|
||||
)
|
||||
|
||||
# If it doesn't require user input, return immediately
|
||||
if not requires_user_input:
|
||||
return create_response
|
||||
|
||||
# Otherwise, poll for the answer
|
||||
# Use the message ID we just created as our starting point
|
||||
last_read_message_id = message_id
|
||||
timeout_seconds = timeout_minutes * 60
|
||||
|
||||
# Poll for the answer
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
while asyncio.get_event_loop().time() - start_time < timeout_seconds:
|
||||
status = await self.get_question_status(question_id)
|
||||
all_messages = []
|
||||
|
||||
if status.status == "answered" and status.answer:
|
||||
return QuestionResponse(answer=status.answer, question_id=question_id)
|
||||
while asyncio.get_event_loop().time() - start_time < timeout_seconds:
|
||||
# Poll for pending messages
|
||||
pending_response = await self.get_pending_messages(
|
||||
agent_instance_id_str, last_read_message_id
|
||||
)
|
||||
|
||||
# If status is "stale", another process has read the messages
|
||||
if pending_response.status == "stale":
|
||||
raise TimeoutError("Another process has read the messages")
|
||||
|
||||
# Check if we got any messages
|
||||
if pending_response.messages:
|
||||
# Collect all messages
|
||||
all_messages.extend(pending_response.messages)
|
||||
|
||||
# Return the response with all collected messages
|
||||
create_response.queued_user_messages = [
|
||||
msg.content for msg in all_messages
|
||||
]
|
||||
return create_response
|
||||
|
||||
await asyncio.sleep(poll_interval)
|
||||
|
||||
raise TimeoutError(f"Question timed out after {timeout_minutes} minutes")
|
||||
|
||||
async def get_question_status(self, question_id: str) -> QuestionStatus:
|
||||
"""Get the current status of a question.
|
||||
async def get_pending_messages(
|
||||
self,
|
||||
agent_instance_id: Union[str, uuid.UUID],
|
||||
last_read_message_id: Optional[str] = None,
|
||||
) -> PendingMessagesResponse:
|
||||
"""Get pending user messages for an agent instance.
|
||||
|
||||
Args:
|
||||
question_id: ID of the question to check
|
||||
agent_instance_id: Agent instance ID
|
||||
last_read_message_id: The last message ID that was read (optional)
|
||||
|
||||
Returns:
|
||||
QuestionStatus with current status and answer (if available)
|
||||
PendingMessagesResponse with messages and status
|
||||
"""
|
||||
response = await self._make_request("GET", f"/api/v1/questions/{question_id}")
|
||||
# Validate and convert agent_instance_id to string
|
||||
agent_instance_id_str = validate_agent_instance_id(agent_instance_id)
|
||||
|
||||
return QuestionStatus(
|
||||
question_id=response["question_id"],
|
||||
status=response["status"],
|
||||
answer=response.get("answer"),
|
||||
asked_at=response["asked_at"],
|
||||
answered_at=response.get("answered_at"),
|
||||
params = {"agent_instance_id": agent_instance_id_str}
|
||||
if last_read_message_id:
|
||||
params["last_read_message_id"] = last_read_message_id
|
||||
|
||||
response = await self._make_request(
|
||||
"GET", "/api/v1/messages/pending", params=params
|
||||
)
|
||||
|
||||
async def end_session(self, agent_instance_id: str) -> EndSessionResponse:
|
||||
return PendingMessagesResponse(
|
||||
agent_instance_id=response["agent_instance_id"],
|
||||
messages=[Message(**msg) for msg in response["messages"]],
|
||||
status=response["status"],
|
||||
)
|
||||
|
||||
async def send_user_message(
|
||||
self,
|
||||
agent_instance_id: Union[str, uuid.UUID],
|
||||
content: str,
|
||||
mark_as_read: bool = True,
|
||||
) -> Dict[str, Any]:
|
||||
"""Send a user message to an agent instance.
|
||||
|
||||
Args:
|
||||
agent_instance_id: The agent instance ID to send the message to
|
||||
content: Message content
|
||||
mark_as_read: Whether to mark as read (update last_read_message_id) (default: True)
|
||||
|
||||
Returns:
|
||||
Dict containing:
|
||||
- success: Whether the message was created
|
||||
- message_id: ID of the created message
|
||||
- marked_as_read: Whether the message was marked as read
|
||||
|
||||
Raises:
|
||||
ValueError: If agent instance not found or access denied
|
||||
APIError: If the API request fails
|
||||
"""
|
||||
# Validate and convert agent_instance_id
|
||||
agent_instance_id = validate_agent_instance_id(agent_instance_id)
|
||||
|
||||
data = {
|
||||
"agent_instance_id": str(agent_instance_id),
|
||||
"content": content,
|
||||
"mark_as_read": mark_as_read,
|
||||
}
|
||||
|
||||
return await self._make_request("POST", "/api/v1/messages/user", json=data)
|
||||
|
||||
async def request_user_input(
|
||||
self,
|
||||
message_id: Union[str, uuid.UUID],
|
||||
timeout_minutes: int = 1440,
|
||||
poll_interval: float = 10.0,
|
||||
) -> List[str]:
|
||||
"""Request user input for a previously sent agent message.
|
||||
|
||||
This method updates an agent message to require user input and polls for responses.
|
||||
It's useful when you initially send a message without requiring input, but later
|
||||
decide you need user feedback.
|
||||
|
||||
Args:
|
||||
message_id: The message ID to update (must be an agent message)
|
||||
timeout_minutes: Max time to wait for user response in minutes (default: 1440)
|
||||
poll_interval: Time between polls in seconds (default: 10.0)
|
||||
|
||||
Returns:
|
||||
List of user message contents received as responses
|
||||
|
||||
Raises:
|
||||
ValueError: If message not found, already requires input, or not an agent message
|
||||
TimeoutError: If no user response is received within timeout
|
||||
APIError: If the API request fails
|
||||
"""
|
||||
# Convert message_id to string if it's a UUID
|
||||
message_id_str = str(message_id)
|
||||
|
||||
# Call the endpoint to update the message
|
||||
response = await self._make_request(
|
||||
"PATCH", f"/api/v1/messages/{message_id_str}/request-input"
|
||||
)
|
||||
|
||||
agent_instance_id = response["agent_instance_id"]
|
||||
messages = response.get("messages", [])
|
||||
|
||||
if messages:
|
||||
return [msg["content"] for msg in messages]
|
||||
|
||||
# Otherwise, poll for user response
|
||||
timeout_seconds = timeout_minutes * 60
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
all_messages = []
|
||||
|
||||
while asyncio.get_event_loop().time() - start_time < timeout_seconds:
|
||||
# Poll for pending messages using the message_id as last_read
|
||||
pending_response = await self.get_pending_messages(
|
||||
agent_instance_id, message_id_str
|
||||
)
|
||||
|
||||
# If status is "stale", another process has read the messages
|
||||
if pending_response.status == "stale":
|
||||
raise TimeoutError("Another process has read the messages")
|
||||
|
||||
# Check if we got any messages
|
||||
if pending_response.messages:
|
||||
# Collect all message contents
|
||||
all_messages.extend([msg.content for msg in pending_response.messages])
|
||||
return all_messages
|
||||
|
||||
await asyncio.sleep(poll_interval)
|
||||
|
||||
raise TimeoutError(f"No user response received after {timeout_minutes} minutes")
|
||||
|
||||
async def end_session(
|
||||
self, agent_instance_id: Union[str, uuid.UUID]
|
||||
) -> EndSessionResponse:
|
||||
"""End an agent session and mark it as completed.
|
||||
|
||||
Args:
|
||||
@@ -268,7 +387,10 @@ class AsyncOmnaraClient:
|
||||
Returns:
|
||||
EndSessionResponse with success status and final details
|
||||
"""
|
||||
data: Dict[str, Any] = {"agent_instance_id": agent_instance_id}
|
||||
# Validate and convert agent_instance_id to string
|
||||
agent_instance_id_str = validate_agent_instance_id(agent_instance_id)
|
||||
|
||||
data: Dict[str, Any] = {"agent_instance_id": agent_instance_id_str}
|
||||
response = await self._make_request("POST", "/api/v1/sessions/end", json=data)
|
||||
|
||||
return EndSessionResponse(
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
"""Main client for interacting with the Omnara Agent Dashboard API."""
|
||||
|
||||
import time
|
||||
from typing import Optional, Dict, Any
|
||||
import uuid
|
||||
from typing import Optional, Dict, Any, Union, List
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import requests
|
||||
@@ -10,10 +11,14 @@ from urllib3.util.retry import Retry
|
||||
|
||||
from .exceptions import AuthenticationError, TimeoutError, APIError
|
||||
from .models import (
|
||||
LogStepResponse,
|
||||
QuestionResponse,
|
||||
QuestionStatus,
|
||||
EndSessionResponse,
|
||||
CreateMessageResponse,
|
||||
PendingMessagesResponse,
|
||||
Message,
|
||||
)
|
||||
from .utils import (
|
||||
validate_agent_instance_id,
|
||||
build_message_request_data,
|
||||
)
|
||||
|
||||
|
||||
@@ -63,6 +68,7 @@ class OmnaraClient:
|
||||
method: str,
|
||||
endpoint: str,
|
||||
json: Optional[Dict[str, Any]] = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
timeout: Optional[int] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Make an HTTP request to the API.
|
||||
@@ -71,6 +77,7 @@ class OmnaraClient:
|
||||
method: HTTP method (GET, POST, etc.)
|
||||
endpoint: API endpoint path
|
||||
json: JSON body for the request
|
||||
params: Query parameters for the request
|
||||
timeout: Request timeout in seconds
|
||||
|
||||
Returns:
|
||||
@@ -86,7 +93,7 @@ class OmnaraClient:
|
||||
|
||||
try:
|
||||
response = self.session.request(
|
||||
method=method, url=url, json=json, timeout=timeout
|
||||
method=method, url=url, json=json, params=params, timeout=timeout
|
||||
)
|
||||
|
||||
if response.status_code == 401:
|
||||
@@ -106,138 +113,245 @@ class OmnaraClient:
|
||||
except requests.exceptions.RequestException as e:
|
||||
raise APIError(0, f"Request failed: {str(e)}")
|
||||
|
||||
def log_step(
|
||||
def send_message(
|
||||
self,
|
||||
agent_type: str,
|
||||
step_description: str,
|
||||
agent_instance_id: Optional[str] = None,
|
||||
send_push: Optional[bool] = None,
|
||||
send_email: Optional[bool] = None,
|
||||
send_sms: Optional[bool] = None,
|
||||
git_diff: Optional[str] = None,
|
||||
) -> LogStepResponse:
|
||||
"""Log a high-level step the agent is performing.
|
||||
|
||||
Args:
|
||||
agent_type: Type of agent (e.g., 'Claude Code', 'Cursor')
|
||||
step_description: Clear description of what the agent is doing
|
||||
agent_instance_id: Existing agent instance ID (optional)
|
||||
send_push: Send push notification for this step (default: False)
|
||||
send_email: Send email notification for this step (default: False)
|
||||
send_sms: Send SMS notification for this step (default: False)
|
||||
git_diff: Git diff content to include with this step (optional)
|
||||
|
||||
Returns:
|
||||
LogStepResponse with success status, instance ID, and user feedback
|
||||
"""
|
||||
data: Dict[str, Any] = {
|
||||
"agent_type": agent_type,
|
||||
"step_description": step_description,
|
||||
}
|
||||
if agent_instance_id:
|
||||
data["agent_instance_id"] = agent_instance_id
|
||||
if send_push is not None:
|
||||
data["send_push"] = send_push
|
||||
if send_email is not None:
|
||||
data["send_email"] = send_email
|
||||
if send_sms is not None:
|
||||
data["send_sms"] = send_sms
|
||||
if git_diff is not None:
|
||||
data["git_diff"] = git_diff
|
||||
|
||||
response = self._make_request("POST", "/api/v1/steps", json=data)
|
||||
|
||||
return LogStepResponse(
|
||||
success=response["success"],
|
||||
agent_instance_id=response["agent_instance_id"],
|
||||
step_number=response["step_number"],
|
||||
user_feedback=response.get("user_feedback", []),
|
||||
)
|
||||
|
||||
def ask_question(
|
||||
self,
|
||||
agent_instance_id: str,
|
||||
question_text: str,
|
||||
content: str,
|
||||
agent_type: Optional[str] = None,
|
||||
agent_instance_id: Optional[Union[str, uuid.UUID]] = None,
|
||||
requires_user_input: bool = False,
|
||||
timeout_minutes: int = 1440,
|
||||
poll_interval: float = 10.0,
|
||||
send_push: Optional[bool] = None,
|
||||
send_email: Optional[bool] = None,
|
||||
send_sms: Optional[bool] = None,
|
||||
git_diff: Optional[str] = None,
|
||||
) -> QuestionResponse:
|
||||
"""Ask the user a question and wait for their response.
|
||||
|
||||
This method submits the question and then polls for the answer.
|
||||
) -> CreateMessageResponse:
|
||||
"""Send a message to the dashboard.
|
||||
|
||||
Args:
|
||||
agent_instance_id: Agent instance ID
|
||||
question_text: Question to ask the user
|
||||
timeout_minutes: Maximum time to wait for answer in minutes (default: 1440 = 24 hours)
|
||||
poll_interval: Time between polls in seconds (default: 10.0)
|
||||
send_push: Send push notification for this question (default: user preference)
|
||||
send_email: Send email notification for this question (default: user preference)
|
||||
send_sms: Send SMS notification for this question (default: user preference)
|
||||
git_diff: Git diff content to include with this question (optional)
|
||||
content: The message content (step description or question text)
|
||||
agent_type: Type of agent (required if agent_instance_id not provided)
|
||||
agent_instance_id: Existing agent instance ID (optional)
|
||||
requires_user_input: Whether this message requires user input (default: False)
|
||||
timeout_minutes: If requires_user_input, max time to wait in minutes (default: 1440)
|
||||
poll_interval: If requires_user_input, time between polls in seconds (default: 10.0)
|
||||
send_push: Send push notification (default: False for steps, user pref for questions)
|
||||
send_email: Send email notification (default: False for steps, user pref for questions)
|
||||
send_sms: Send SMS notification (default: False for steps, user pref for questions)
|
||||
git_diff: Git diff content to include (optional)
|
||||
|
||||
Returns:
|
||||
QuestionResponse with the user's answer
|
||||
CreateMessageResponse with any queued user messages
|
||||
|
||||
Raises:
|
||||
TimeoutError: If no answer is received within timeout
|
||||
ValueError: If neither agent_type nor agent_instance_id is provided
|
||||
TimeoutError: If requires_user_input and no answer is received within timeout
|
||||
"""
|
||||
# Submit the question
|
||||
data: Dict[str, Any] = {
|
||||
"agent_instance_id": agent_instance_id,
|
||||
"question_text": question_text,
|
||||
}
|
||||
if send_push is not None:
|
||||
data["send_push"] = send_push
|
||||
if send_email is not None:
|
||||
data["send_email"] = send_email
|
||||
if send_sms is not None:
|
||||
data["send_sms"] = send_sms
|
||||
if git_diff is not None:
|
||||
data["git_diff"] = git_diff
|
||||
# If no agent_instance_id provided, generate one client-side
|
||||
if not agent_instance_id:
|
||||
if not agent_type:
|
||||
raise ValueError("agent_type is required when creating a new instance")
|
||||
agent_instance_id = uuid.uuid4()
|
||||
|
||||
# First, try the non-blocking endpoint to create the question
|
||||
response = self._make_request("POST", "/api/v1/questions", json=data, timeout=5)
|
||||
question_id = response["question_id"]
|
||||
# Validate and convert agent_instance_id to string
|
||||
agent_instance_id_str = validate_agent_instance_id(agent_instance_id)
|
||||
|
||||
# Build request data using shared utility
|
||||
data = build_message_request_data(
|
||||
content=content,
|
||||
agent_instance_id=agent_instance_id_str,
|
||||
requires_user_input=requires_user_input,
|
||||
agent_type=agent_type,
|
||||
send_push=send_push,
|
||||
send_email=send_email,
|
||||
send_sms=send_sms,
|
||||
git_diff=git_diff,
|
||||
)
|
||||
|
||||
# Send the message
|
||||
response = self._make_request("POST", "/api/v1/messages/agent", json=data)
|
||||
response_agent_instance_id = response["agent_instance_id"]
|
||||
message_id = response["message_id"]
|
||||
|
||||
queued_contents = [
|
||||
msg["content"] if isinstance(msg, dict) else msg
|
||||
for msg in response.get("queued_user_messages", [])
|
||||
]
|
||||
|
||||
create_response = CreateMessageResponse(
|
||||
success=response["success"],
|
||||
agent_instance_id=response_agent_instance_id,
|
||||
message_id=message_id,
|
||||
queued_user_messages=queued_contents,
|
||||
)
|
||||
|
||||
# If it doesn't require user input, return immediately with any queued messages
|
||||
if not requires_user_input:
|
||||
return create_response
|
||||
|
||||
# Otherwise, we need to poll for user response
|
||||
# Use the message ID we just created as our starting point
|
||||
last_read_message_id = message_id
|
||||
|
||||
# Convert timeout from minutes to seconds
|
||||
timeout_seconds = timeout_minutes * 60
|
||||
|
||||
# Poll for the answer
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < timeout_seconds:
|
||||
status = self.get_question_status(question_id)
|
||||
all_messages = []
|
||||
|
||||
if status.status == "answered" and status.answer:
|
||||
return QuestionResponse(answer=status.answer, question_id=question_id)
|
||||
while time.time() - start_time < timeout_seconds:
|
||||
# Poll for pending messages
|
||||
pending_response = self.get_pending_messages(
|
||||
agent_instance_id_str, last_read_message_id
|
||||
)
|
||||
|
||||
# If status is "stale", another process has read the messages
|
||||
if pending_response.status == "stale":
|
||||
raise TimeoutError("Another process has read the messages")
|
||||
|
||||
# Check if we got any messages
|
||||
if pending_response.messages:
|
||||
# Collect all messages
|
||||
all_messages.extend(pending_response.messages)
|
||||
|
||||
# Return the response with all collected messages
|
||||
create_response.queued_user_messages = [
|
||||
msg.content for msg in all_messages
|
||||
]
|
||||
return create_response
|
||||
|
||||
time.sleep(poll_interval)
|
||||
|
||||
raise TimeoutError(f"Question timed out after {timeout_minutes} minutes")
|
||||
|
||||
def get_question_status(self, question_id: str) -> QuestionStatus:
|
||||
"""Get the current status of a question.
|
||||
def get_pending_messages(
|
||||
self,
|
||||
agent_instance_id: Union[str, uuid.UUID],
|
||||
last_read_message_id: Optional[str] = None,
|
||||
) -> PendingMessagesResponse:
|
||||
"""Get pending user messages for an agent instance.
|
||||
|
||||
Args:
|
||||
question_id: ID of the question to check
|
||||
agent_instance_id: Agent instance ID
|
||||
last_read_message_id: The last message ID that was read (optional)
|
||||
|
||||
Returns:
|
||||
QuestionStatus with current status and answer (if available)
|
||||
PendingMessagesResponse with messages and status
|
||||
"""
|
||||
response = self._make_request("GET", f"/api/v1/questions/{question_id}")
|
||||
# Validate and convert agent_instance_id to string
|
||||
agent_instance_id_str = validate_agent_instance_id(agent_instance_id)
|
||||
|
||||
return QuestionStatus(
|
||||
question_id=response["question_id"],
|
||||
params = {"agent_instance_id": agent_instance_id_str}
|
||||
if last_read_message_id:
|
||||
params["last_read_message_id"] = last_read_message_id
|
||||
|
||||
response = self._make_request("GET", "/api/v1/messages/pending", params=params)
|
||||
|
||||
return PendingMessagesResponse(
|
||||
agent_instance_id=response["agent_instance_id"],
|
||||
messages=[Message(**msg) for msg in response["messages"]],
|
||||
status=response["status"],
|
||||
answer=response.get("answer"),
|
||||
asked_at=response["asked_at"],
|
||||
answered_at=response.get("answered_at"),
|
||||
)
|
||||
|
||||
def end_session(self, agent_instance_id: str) -> EndSessionResponse:
|
||||
def send_user_message(
|
||||
self,
|
||||
agent_instance_id: Union[str, uuid.UUID],
|
||||
content: str,
|
||||
mark_as_read: bool = True,
|
||||
) -> Dict[str, Any]:
|
||||
"""Send a user message to an agent instance.
|
||||
|
||||
Args:
|
||||
agent_instance_id: The agent instance ID to send the message to
|
||||
content: Message content
|
||||
mark_as_read: Whether to mark as read (update last_read_message_id) (default: True)
|
||||
|
||||
Returns:
|
||||
Dict containing:
|
||||
- success: Whether the message was created
|
||||
- message_id: ID of the created message
|
||||
- marked_as_read: Whether the message was marked as read
|
||||
|
||||
Raises:
|
||||
ValueError: If agent instance not found or access denied
|
||||
APIError: If the API request fails
|
||||
"""
|
||||
# Validate and convert agent_instance_id
|
||||
agent_instance_id = validate_agent_instance_id(agent_instance_id)
|
||||
|
||||
data = {
|
||||
"agent_instance_id": str(agent_instance_id),
|
||||
"content": content,
|
||||
"mark_as_read": mark_as_read,
|
||||
}
|
||||
|
||||
return self._make_request("POST", "/api/v1/messages/user", json=data)
|
||||
|
||||
def request_user_input(
|
||||
self,
|
||||
message_id: Union[str, uuid.UUID],
|
||||
timeout_minutes: int = 1440,
|
||||
poll_interval: float = 10.0,
|
||||
) -> List[str]:
|
||||
"""Request user input for a previously sent agent message.
|
||||
|
||||
This method updates an agent message to require user input and polls for responses.
|
||||
It's useful when you initially send a message without requiring input, but later
|
||||
decide you need user feedback.
|
||||
|
||||
Args:
|
||||
message_id: The message ID to update (must be an agent message)
|
||||
timeout_minutes: Max time to wait for user response in minutes (default: 1440)
|
||||
poll_interval: Time between polls in seconds (default: 10.0)
|
||||
|
||||
Returns:
|
||||
List of user message contents received as responses
|
||||
|
||||
Raises:
|
||||
ValueError: If message not found, already requires input, or not an agent message
|
||||
TimeoutError: If no user response is received within timeout
|
||||
APIError: If the API request fails
|
||||
"""
|
||||
# Convert message_id to string if it's a UUID
|
||||
message_id_str = str(message_id)
|
||||
|
||||
# Call the endpoint to update the message
|
||||
response = self._make_request(
|
||||
"PATCH", f"/api/v1/messages/{message_id_str}/request-input"
|
||||
)
|
||||
|
||||
agent_instance_id = response["agent_instance_id"]
|
||||
messages = response.get("messages", [])
|
||||
|
||||
if messages:
|
||||
return [msg["content"] for msg in messages]
|
||||
|
||||
# Otherwise, poll for user response
|
||||
timeout_seconds = timeout_minutes * 60
|
||||
start_time = time.time()
|
||||
all_messages = []
|
||||
|
||||
while time.time() - start_time < timeout_seconds:
|
||||
# Poll for pending messages using the message_id as last_read
|
||||
pending_response = self.get_pending_messages(
|
||||
agent_instance_id, message_id_str
|
||||
)
|
||||
|
||||
# If status is "stale", another process has read the messages
|
||||
if pending_response.status == "stale":
|
||||
raise TimeoutError("Another process has read the messages")
|
||||
|
||||
# Check if we got any messages
|
||||
if pending_response.messages:
|
||||
# Collect all message contents
|
||||
all_messages.extend([msg.content for msg in pending_response.messages])
|
||||
return all_messages
|
||||
|
||||
time.sleep(poll_interval)
|
||||
|
||||
raise TimeoutError(f"No user response received after {timeout_minutes} minutes")
|
||||
|
||||
def end_session(
|
||||
self, agent_instance_id: Union[str, uuid.UUID]
|
||||
) -> EndSessionResponse:
|
||||
"""End an agent session and mark it as completed.
|
||||
|
||||
Args:
|
||||
@@ -246,7 +360,10 @@ class OmnaraClient:
|
||||
Returns:
|
||||
EndSessionResponse with success status and final details
|
||||
"""
|
||||
data: Dict[str, Any] = {"agent_instance_id": agent_instance_id}
|
||||
# Validate and convert agent_instance_id to string
|
||||
agent_instance_id_str = validate_agent_instance_id(agent_instance_id)
|
||||
|
||||
data: Dict[str, Any] = {"agent_instance_id": agent_instance_id_str}
|
||||
response = self._make_request("POST", "/api/v1/sessions/end", json=data)
|
||||
|
||||
return EndSessionResponse(
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Data models for the Omnara SDK."""
|
||||
|
||||
from typing import List, Optional
|
||||
from typing import List
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@@ -14,25 +14,6 @@ class LogStepResponse:
|
||||
user_feedback: List[str]
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuestionResponse:
|
||||
"""Response from asking a question."""
|
||||
|
||||
answer: str
|
||||
question_id: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuestionStatus:
|
||||
"""Status of a question."""
|
||||
|
||||
question_id: str
|
||||
status: str # 'pending' or 'answered'
|
||||
answer: Optional[str]
|
||||
asked_at: str
|
||||
answered_at: Optional[str]
|
||||
|
||||
|
||||
@dataclass
|
||||
class EndSessionResponse:
|
||||
"""Response from ending a session."""
|
||||
@@ -40,3 +21,33 @@ class EndSessionResponse:
|
||||
success: bool
|
||||
agent_instance_id: str
|
||||
final_status: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class CreateMessageResponse:
|
||||
"""Response from creating a message."""
|
||||
|
||||
success: bool
|
||||
agent_instance_id: str
|
||||
message_id: str
|
||||
queued_user_messages: List[str]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Message:
|
||||
"""A message in the conversation."""
|
||||
|
||||
id: str
|
||||
content: str
|
||||
sender_type: str # 'agent' or 'user'
|
||||
created_at: str
|
||||
requires_user_input: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
class PendingMessagesResponse:
|
||||
"""Response from getting pending messages."""
|
||||
|
||||
agent_instance_id: str
|
||||
messages: List[Message]
|
||||
status: str # 'ok' or 'stale'
|
||||
|
||||
78
omnara/sdk/utils.py
Normal file
78
omnara/sdk/utils.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""Utility functions for the Omnara SDK."""
|
||||
|
||||
import uuid
|
||||
from typing import Optional, Union, Dict, Any
|
||||
|
||||
|
||||
def validate_agent_instance_id(
|
||||
agent_instance_id: Optional[Union[str, uuid.UUID]],
|
||||
) -> str:
|
||||
"""Validate and convert agent_instance_id to string.
|
||||
|
||||
Args:
|
||||
agent_instance_id: UUID string, UUID object, or None
|
||||
|
||||
Returns:
|
||||
Validated UUID string
|
||||
|
||||
Raises:
|
||||
ValueError: If agent_instance_id is not a valid UUID
|
||||
"""
|
||||
if agent_instance_id is None:
|
||||
raise ValueError("agent_instance_id cannot be None")
|
||||
|
||||
if isinstance(agent_instance_id, str):
|
||||
try:
|
||||
# Validate it's a valid UUID
|
||||
uuid.UUID(agent_instance_id)
|
||||
return agent_instance_id
|
||||
except ValueError:
|
||||
raise ValueError("agent_instance_id must be a valid UUID string")
|
||||
elif isinstance(agent_instance_id, uuid.UUID):
|
||||
return str(agent_instance_id)
|
||||
else:
|
||||
raise ValueError("agent_instance_id must be a string or UUID object")
|
||||
|
||||
|
||||
def build_message_request_data(
|
||||
content: str,
|
||||
agent_instance_id: str,
|
||||
requires_user_input: bool,
|
||||
agent_type: Optional[str] = None,
|
||||
send_push: Optional[bool] = None,
|
||||
send_email: Optional[bool] = None,
|
||||
send_sms: Optional[bool] = None,
|
||||
git_diff: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Build request data for creating a message.
|
||||
|
||||
Args:
|
||||
content: Message content
|
||||
agent_instance_id: Agent instance ID (already validated)
|
||||
requires_user_input: Whether message requires user input
|
||||
agent_type: Optional agent type
|
||||
send_push: Optional push notification flag
|
||||
send_email: Optional email notification flag
|
||||
send_sms: Optional SMS notification flag
|
||||
git_diff: Optional git diff content
|
||||
|
||||
Returns:
|
||||
Dictionary of request data
|
||||
"""
|
||||
data: Dict[str, Any] = {
|
||||
"content": content,
|
||||
"requires_user_input": requires_user_input,
|
||||
"agent_instance_id": agent_instance_id,
|
||||
}
|
||||
if agent_type:
|
||||
data["agent_type"] = agent_type
|
||||
if send_push is not None:
|
||||
data["send_push"] = send_push
|
||||
if send_email is not None:
|
||||
data["send_email"] = send_email
|
||||
if send_sms is not None:
|
||||
data["send_sms"] = send_sms
|
||||
if git_diff is not None:
|
||||
data["git_diff"] = git_diff
|
||||
|
||||
return data
|
||||
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "omnara"
|
||||
version = "1.3.15"
|
||||
version = "1.3.16"
|
||||
description = "Omnara Agent Dashboard - MCP Server and Python SDK"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
|
||||
@@ -4,9 +4,9 @@ This directory contains the write operations server for the Agent Dashboard syst
|
||||
|
||||
## Overview
|
||||
|
||||
The servers directory implements all write operations that agents need:
|
||||
- Logging their progress and receiving user feedback
|
||||
- Asking questions to users
|
||||
The servers directory implements all write operations that agents need through a unified messaging system:
|
||||
- Sending messages (both informational steps and questions requiring user input)
|
||||
- Receiving user responses and feedback
|
||||
- Managing session lifecycle
|
||||
|
||||
All operations are authenticated and multi-tenant, ensuring data isolation between users.
|
||||
@@ -36,10 +36,10 @@ The servers use a separate authentication system from the main backend:
|
||||
|
||||
## Key Features
|
||||
|
||||
- **Write-only operations**: Designed for agent interactions, not data retrieval
|
||||
- **Unified messaging**: All agent interactions use the same message-based API
|
||||
- **Automatic session management**: Creates sessions on first interaction
|
||||
- **User feedback delivery**: Agents receive feedback when logging steps
|
||||
- **Non-blocking questions**: Async implementation for user interactions
|
||||
- **Message queuing**: Agents receive unread user messages when sending new messages
|
||||
- **Flexible interactions**: Messages can be informational or require user input
|
||||
- **Multi-protocol support**: Same functionality via MCP or REST API
|
||||
|
||||
## Running the Server
|
||||
|
||||
@@ -1,29 +1,58 @@
|
||||
"""Pydantic models for FastAPI request/response schemas."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from servers.shared.models import (
|
||||
BaseLogStepRequest,
|
||||
BaseLogStepResponse,
|
||||
BaseAskQuestionRequest,
|
||||
BaseEndSessionRequest,
|
||||
BaseEndSessionResponse,
|
||||
)
|
||||
|
||||
|
||||
# Request models
|
||||
class LogStepRequest(BaseLogStepRequest):
|
||||
"""FastAPI-specific request model for logging a step."""
|
||||
class CreateMessageRequest(BaseModel):
|
||||
"""Request model for creating a new message."""
|
||||
|
||||
pass
|
||||
agent_instance_id: str = Field(
|
||||
...,
|
||||
description="Existing agent instance ID. Creates a new agent instance if ID doesn't exist.",
|
||||
)
|
||||
agent_type: str | None = Field(
|
||||
None, description="Type of agent (e.g., 'claude_code', 'cursor')"
|
||||
)
|
||||
content: str = Field(
|
||||
..., description="Message content (step description or question text)"
|
||||
)
|
||||
requires_user_input: bool = Field(
|
||||
False, description="Whether this message requires user input (is a question)"
|
||||
)
|
||||
send_email: bool | None = Field(
|
||||
None,
|
||||
description="Whether to send email notification (overrides user preference)",
|
||||
)
|
||||
send_sms: bool | None = Field(
|
||||
None, description="Whether to send SMS notification (overrides user preference)"
|
||||
)
|
||||
send_push: bool | None = Field(
|
||||
None,
|
||||
description="Whether to send push notification (overrides user preference)",
|
||||
)
|
||||
git_diff: str | None = Field(
|
||||
None,
|
||||
description="Git diff content to store with the instance",
|
||||
)
|
||||
|
||||
|
||||
class AskQuestionRequest(BaseAskQuestionRequest):
|
||||
"""FastAPI-specific request model for asking a question."""
|
||||
class CreateUserMessageRequest(BaseModel):
|
||||
"""Request model for creating a user message."""
|
||||
|
||||
pass
|
||||
agent_instance_id: str = Field(
|
||||
..., description="Agent instance ID to send the message to"
|
||||
)
|
||||
content: str = Field(..., description="Message content")
|
||||
mark_as_read: bool = Field(
|
||||
True,
|
||||
description="Whether to mark this message as read (update last_read_message_id)",
|
||||
)
|
||||
|
||||
|
||||
class EndSessionRequest(BaseEndSessionRequest):
|
||||
@@ -33,35 +62,60 @@ class EndSessionRequest(BaseEndSessionRequest):
|
||||
|
||||
|
||||
# Response models
|
||||
class LogStepResponse(BaseLogStepResponse):
|
||||
"""FastAPI-specific response model for log step endpoint."""
|
||||
class MessageResponse(BaseModel):
|
||||
"""Response model for individual messages."""
|
||||
|
||||
pass
|
||||
id: str = Field(..., description="Message ID")
|
||||
content: str = Field(..., description="Message content")
|
||||
sender_type: str = Field(..., description="Sender type: 'agent' or 'user'")
|
||||
created_at: str = Field(..., description="ISO timestamp when message was created")
|
||||
requires_user_input: bool = Field(
|
||||
..., description="Whether this message requires user input"
|
||||
)
|
||||
|
||||
|
||||
# FastAPI-specific: Response only contains question ID (non-blocking)
|
||||
class AskQuestionResponse(BaseModel):
|
||||
"""FastAPI-specific response model for ask question endpoint."""
|
||||
class CreateMessageResponse(BaseModel):
|
||||
"""Response model for create message endpoint."""
|
||||
|
||||
question_id: str = Field(..., description="ID of the created question")
|
||||
success: bool = Field(
|
||||
..., description="Whether the message was created successfully"
|
||||
)
|
||||
agent_instance_id: str = Field(
|
||||
..., description="Agent instance ID (new or existing)"
|
||||
)
|
||||
message_id: str = Field(..., description="ID of the message that was created")
|
||||
queued_user_messages: list[MessageResponse] = Field(
|
||||
default_factory=list,
|
||||
description="List of queued user messages with full metadata",
|
||||
)
|
||||
|
||||
|
||||
class CreateUserMessageResponse(BaseModel):
|
||||
"""Response model for create user message endpoint."""
|
||||
|
||||
success: bool = Field(
|
||||
..., description="Whether the message was created successfully"
|
||||
)
|
||||
message_id: str = Field(..., description="ID of the created message")
|
||||
marked_as_read: bool = Field(
|
||||
..., description="Whether the message was marked as read"
|
||||
)
|
||||
|
||||
|
||||
class GetMessagesResponse(BaseModel):
|
||||
"""Response model for get messages endpoint."""
|
||||
|
||||
agent_instance_id: str = Field(..., description="Agent instance ID")
|
||||
messages: list[MessageResponse] = Field(
|
||||
default_factory=list, description="List of messages"
|
||||
)
|
||||
status: str = Field(
|
||||
"ok",
|
||||
description="Status: 'ok' if messages retrieved, 'stale' if last_read_message_id is outdated",
|
||||
)
|
||||
|
||||
|
||||
class EndSessionResponse(BaseEndSessionResponse):
|
||||
"""FastAPI-specific response model for end session endpoint."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
# FastAPI-specific: Additional model for polling question status
|
||||
class QuestionStatusResponse(BaseModel):
|
||||
"""Response model for question status endpoint."""
|
||||
|
||||
question_id: str
|
||||
status: str = Field(
|
||||
..., description="Status of the question: 'pending' or 'answered'"
|
||||
)
|
||||
answer: Optional[str] = Field(
|
||||
None, description="Answer text if status is 'answered'"
|
||||
)
|
||||
asked_at: str
|
||||
answered_at: Optional[str] = None
|
||||
|
||||
@@ -1,66 +1,93 @@
|
||||
"""API routes for agent operations."""
|
||||
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
|
||||
from shared.database.session import get_db
|
||||
from servers.shared.db import get_question, get_agent_instance
|
||||
from servers.shared.core import (
|
||||
process_log_step,
|
||||
create_agent_question,
|
||||
process_end_session,
|
||||
from shared.database import Message, AgentInstance, SenderType, AgentStatus
|
||||
from servers.shared.db import (
|
||||
send_agent_message,
|
||||
end_session,
|
||||
get_or_create_agent_instance,
|
||||
get_queued_user_messages,
|
||||
create_user_message,
|
||||
)
|
||||
from servers.shared.notification_utils import send_message_notifications
|
||||
from .auth import get_current_user_id
|
||||
from .models import (
|
||||
AskQuestionRequest,
|
||||
AskQuestionResponse,
|
||||
CreateMessageRequest,
|
||||
CreateMessageResponse,
|
||||
CreateUserMessageRequest,
|
||||
CreateUserMessageResponse,
|
||||
EndSessionRequest,
|
||||
EndSessionResponse,
|
||||
LogStepRequest,
|
||||
LogStepResponse,
|
||||
QuestionStatusResponse,
|
||||
GetMessagesResponse,
|
||||
MessageResponse,
|
||||
)
|
||||
|
||||
agent_router = APIRouter(tags=["agents"])
|
||||
|
||||
|
||||
@agent_router.post("/steps", response_model=LogStepResponse)
|
||||
def log_step(
|
||||
request: LogStepRequest, user_id: Annotated[str, Depends(get_current_user_id)]
|
||||
) -> LogStepResponse:
|
||||
"""Log a high-level step the agent is performing.
|
||||
@agent_router.post("/messages/agent", response_model=CreateMessageResponse)
|
||||
async def create_agent_message_endpoint(
|
||||
request: CreateMessageRequest, user_id: Annotated[str, Depends(get_current_user_id)]
|
||||
) -> CreateMessageResponse:
|
||||
"""Create a new agent message.
|
||||
|
||||
This endpoint:
|
||||
- Creates or retrieves an agent instance
|
||||
- Logs the step with a sequential number
|
||||
- Returns any unretrieved user feedback
|
||||
|
||||
User feedback is automatically marked as retrieved.
|
||||
- Creates a new message
|
||||
- Returns the message ID and any queued user messages
|
||||
- Sends notifications if requested
|
||||
"""
|
||||
db = next(get_db())
|
||||
|
||||
try:
|
||||
# Use shared business logic
|
||||
instance_id, step_number, user_feedback = process_log_step(
|
||||
# Use the unified send_agent_message function
|
||||
instance_id, message_id, queued_messages = await send_agent_message(
|
||||
db=db,
|
||||
agent_type=request.agent_type,
|
||||
step_description=request.step_description,
|
||||
user_id=user_id,
|
||||
agent_instance_id=request.agent_instance_id,
|
||||
send_email=request.send_email,
|
||||
send_sms=request.send_sms,
|
||||
send_push=request.send_push,
|
||||
content=request.content,
|
||||
user_id=user_id,
|
||||
agent_type=request.agent_type,
|
||||
requires_user_input=request.requires_user_input,
|
||||
git_diff=request.git_diff,
|
||||
)
|
||||
|
||||
return LogStepResponse(
|
||||
# Send notifications if requested
|
||||
await send_message_notifications(
|
||||
db=db,
|
||||
instance_id=UUID(instance_id),
|
||||
content=request.content,
|
||||
requires_user_input=request.requires_user_input,
|
||||
send_email=request.send_email,
|
||||
send_sms=request.send_sms,
|
||||
send_push=request.send_push,
|
||||
)
|
||||
|
||||
db.commit()
|
||||
|
||||
message_responses = [
|
||||
MessageResponse(
|
||||
id=str(msg.id),
|
||||
content=msg.content,
|
||||
sender_type=msg.sender_type.value,
|
||||
created_at=msg.created_at.isoformat(),
|
||||
requires_user_input=msg.requires_user_input,
|
||||
)
|
||||
for msg in queued_messages
|
||||
]
|
||||
|
||||
return CreateMessageResponse(
|
||||
success=True,
|
||||
agent_instance_id=instance_id,
|
||||
step_number=step_number,
|
||||
user_feedback=user_feedback,
|
||||
message_id=message_id,
|
||||
queued_user_messages=message_responses,
|
||||
)
|
||||
except ValueError as e:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
@@ -72,40 +99,202 @@ def log_step(
|
||||
db.close()
|
||||
|
||||
|
||||
@agent_router.post("/questions", response_model=AskQuestionResponse)
|
||||
async def ask_question(
|
||||
request: AskQuestionRequest, user_id: Annotated[str, Depends(get_current_user_id)]
|
||||
) -> AskQuestionResponse:
|
||||
"""Create a question for the user to answer.
|
||||
@agent_router.post("/messages/user", response_model=CreateUserMessageResponse)
|
||||
async def create_user_message_endpoint(
|
||||
request: CreateUserMessageRequest,
|
||||
user_id: Annotated[str, Depends(get_current_user_id)],
|
||||
) -> CreateUserMessageResponse:
|
||||
"""Create a user message.
|
||||
|
||||
This endpoint:
|
||||
- Creates a question record in the database
|
||||
- Returns immediately with the question ID
|
||||
- Client should poll GET /questions/{question_id} for the answer
|
||||
|
||||
Note: This endpoint is non-blocking.
|
||||
- Creates a user message for an existing agent instance
|
||||
- Optionally marks it as read (updates last_read_message_id)
|
||||
- Returns the message ID
|
||||
"""
|
||||
db = next(get_db())
|
||||
|
||||
try:
|
||||
# Use shared business logic to create question
|
||||
question = await create_agent_question(
|
||||
# Create the user message
|
||||
message_id, marked_as_read = create_user_message(
|
||||
db=db,
|
||||
agent_instance_id=request.agent_instance_id,
|
||||
question_text=request.question_text,
|
||||
content=request.content,
|
||||
user_id=user_id,
|
||||
send_email=request.send_email,
|
||||
send_sms=request.send_sms,
|
||||
send_push=request.send_push,
|
||||
git_diff=request.git_diff,
|
||||
mark_as_read=request.mark_as_read,
|
||||
)
|
||||
|
||||
# Commit the transaction
|
||||
db.commit()
|
||||
|
||||
# FastAPI-specific: Return immediately with question ID (non-blocking)
|
||||
return AskQuestionResponse(
|
||||
question_id=str(question.id),
|
||||
return CreateUserMessageResponse(
|
||||
success=True,
|
||||
message_id=message_id,
|
||||
marked_as_read=marked_as_read,
|
||||
)
|
||||
except ValueError as e:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Internal server error: {str(e)}",
|
||||
)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@agent_router.get("/messages/pending", response_model=GetMessagesResponse)
|
||||
async def get_pending_messages(
|
||||
agent_instance_id: str,
|
||||
last_read_message_id: str | None,
|
||||
user_id: Annotated[str, Depends(get_current_user_id)],
|
||||
) -> GetMessagesResponse:
|
||||
"""Get pending user messages for an agent instance.
|
||||
|
||||
This endpoint:
|
||||
- Returns all user messages since the provided last_read_message_id
|
||||
- Updates the last_read_message_id to the latest message
|
||||
- Returns None status if another process has already read the messages
|
||||
"""
|
||||
db = next(get_db())
|
||||
|
||||
try:
|
||||
# Validate access (agent_instance_id is required here)
|
||||
instance = get_or_create_agent_instance(db, agent_instance_id, user_id)
|
||||
|
||||
# Parse last_read_message_id if provided
|
||||
last_read_uuid = UUID(last_read_message_id) if last_read_message_id else None
|
||||
|
||||
# Get queued messages
|
||||
messages = get_queued_user_messages(db, instance.id, last_read_uuid)
|
||||
|
||||
# If messages is None, another process has read the messages
|
||||
if messages is None:
|
||||
return GetMessagesResponse(
|
||||
agent_instance_id=agent_instance_id,
|
||||
messages=[],
|
||||
status="stale", # Indicate that the last_read_message_id is stale
|
||||
)
|
||||
|
||||
db.commit()
|
||||
|
||||
# Convert to response format
|
||||
message_responses = [
|
||||
MessageResponse(
|
||||
id=str(msg.id),
|
||||
content=msg.content,
|
||||
sender_type=msg.sender_type.value,
|
||||
created_at=msg.created_at.isoformat(),
|
||||
requires_user_input=msg.requires_user_input,
|
||||
)
|
||||
for msg in messages
|
||||
]
|
||||
|
||||
return GetMessagesResponse(
|
||||
agent_instance_id=agent_instance_id,
|
||||
messages=message_responses,
|
||||
status="ok",
|
||||
)
|
||||
except ValueError as e:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Internal server error: {str(e)}",
|
||||
)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@agent_router.patch("/messages/{message_id}/request-input")
|
||||
async def request_user_input_endpoint(
|
||||
message_id: UUID,
|
||||
user_id: Annotated[str, Depends(get_current_user_id)],
|
||||
) -> dict:
|
||||
"""Update an agent message to request user input.
|
||||
|
||||
This endpoint:
|
||||
- Updates the requires_user_input field from false to true
|
||||
- Only works on agent messages that don't already require input
|
||||
- Returns any queued user messages since this message
|
||||
- Triggers a notification via the database trigger
|
||||
"""
|
||||
db = next(get_db())
|
||||
|
||||
try:
|
||||
# Find the message and verify it's an agent message belonging to the user
|
||||
message = (
|
||||
db.query(Message)
|
||||
.join(AgentInstance, Message.agent_instance_id == AgentInstance.id)
|
||||
.filter(
|
||||
Message.id == message_id,
|
||||
Message.sender_type == SenderType.AGENT,
|
||||
AgentInstance.user_id == user_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not message:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Agent message not found or access denied",
|
||||
)
|
||||
|
||||
# Check if it already requires user input
|
||||
if message.requires_user_input:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Message already requires user input",
|
||||
)
|
||||
|
||||
# Update the field
|
||||
message.requires_user_input = True
|
||||
|
||||
queued_messages = get_queued_user_messages(
|
||||
db, message.agent_instance_id, message_id
|
||||
)
|
||||
|
||||
if not queued_messages:
|
||||
agent_instance = (
|
||||
db.query(AgentInstance)
|
||||
.filter(AgentInstance.id == message.agent_instance_id)
|
||||
.first()
|
||||
)
|
||||
if agent_instance:
|
||||
agent_instance.status = AgentStatus.AWAITING_INPUT
|
||||
|
||||
await send_message_notifications(
|
||||
db=db,
|
||||
instance_id=message.agent_instance_id,
|
||||
content=message.content,
|
||||
requires_user_input=True,
|
||||
)
|
||||
|
||||
db.commit()
|
||||
|
||||
message_responses = [
|
||||
MessageResponse(
|
||||
id=str(msg.id),
|
||||
content=msg.content,
|
||||
sender_type=msg.sender_type.value,
|
||||
created_at=msg.created_at.isoformat(),
|
||||
requires_user_input=msg.requires_user_input,
|
||||
)
|
||||
for msg in (queued_messages or [])
|
||||
]
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message_id": str(message_id),
|
||||
"agent_instance_id": str(message.agent_instance_id),
|
||||
"messages": message_responses,
|
||||
"status": "ok" if queued_messages is not None else "stale",
|
||||
}
|
||||
except HTTPException:
|
||||
db.rollback()
|
||||
raise
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
@@ -117,48 +306,8 @@ async def ask_question(
|
||||
db.close()
|
||||
|
||||
|
||||
@agent_router.get("/questions/{question_id}", response_model=QuestionStatusResponse)
|
||||
async def get_question_status(
|
||||
question_id: str, user_id: Annotated[str, Depends(get_current_user_id)]
|
||||
) -> QuestionStatusResponse:
|
||||
"""Get the status of a question.
|
||||
|
||||
This endpoint allows polling for question answers without blocking.
|
||||
Returns the current status and answer (if available).
|
||||
"""
|
||||
db = next(get_db())
|
||||
|
||||
try:
|
||||
# Get the question
|
||||
question = get_question(db, question_id)
|
||||
if not question:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Question not found"
|
||||
)
|
||||
|
||||
# Verify the question belongs to the authenticated user
|
||||
instance = get_agent_instance(db, str(question.agent_instance_id))
|
||||
if not instance or str(instance.user_id) != user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="Access denied"
|
||||
)
|
||||
|
||||
# Return question status
|
||||
return QuestionStatusResponse(
|
||||
question_id=str(question.id),
|
||||
status="answered" if question.answer_text else "pending",
|
||||
answer=question.answer_text,
|
||||
asked_at=question.asked_at.isoformat(),
|
||||
answered_at=question.answered_at.isoformat()
|
||||
if question.answered_at
|
||||
else None,
|
||||
)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@agent_router.post("/sessions/end", response_model=EndSessionResponse)
|
||||
async def end_session(
|
||||
async def end_session_endpoint(
|
||||
request: EndSessionRequest, user_id: Annotated[str, Depends(get_current_user_id)]
|
||||
) -> EndSessionResponse:
|
||||
"""End an agent session and mark it as completed.
|
||||
@@ -166,24 +315,26 @@ async def end_session(
|
||||
This endpoint:
|
||||
- Marks the agent instance as COMPLETED
|
||||
- Sets the session end time
|
||||
- Deactivates any pending questions
|
||||
"""
|
||||
db = next(get_db())
|
||||
|
||||
try:
|
||||
# Use shared business logic
|
||||
instance_id, final_status = process_end_session(
|
||||
# Use the end_session function from queries
|
||||
instance_id, final_status = end_session(
|
||||
db=db,
|
||||
agent_instance_id=request.agent_instance_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
db.commit()
|
||||
|
||||
return EndSessionResponse(
|
||||
success=True,
|
||||
agent_instance_id=instance_id,
|
||||
final_status=final_status,
|
||||
)
|
||||
except ValueError as e:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
|
||||
@@ -73,13 +73,13 @@ mcp = FastMCP("Agent Dashboard MCP Server", auth=auth)
|
||||
|
||||
@mcp.tool(name="log_step", description=LOG_STEP_DESCRIPTION)
|
||||
@require_auth
|
||||
def log_step_tool(
|
||||
async def log_step_tool(
|
||||
agent_instance_id: str | None = None,
|
||||
step_description: str = "",
|
||||
_user_id: str = "", # Injected by decorator
|
||||
) -> LogStepResponse:
|
||||
agent_type = detect_agent_type_from_headers()
|
||||
return log_step_impl(
|
||||
return await log_step_impl(
|
||||
agent_instance_id=agent_instance_id,
|
||||
agent_type=agent_type,
|
||||
step_description=step_description,
|
||||
|
||||
@@ -211,21 +211,23 @@ async def log_step_tool(
|
||||
# Get git diff if enabled
|
||||
git_diff = get_git_diff()
|
||||
|
||||
response = await client.log_step(
|
||||
response = await client.send_message(
|
||||
agent_type=agent_type,
|
||||
step_description=step_description,
|
||||
content=step_description,
|
||||
agent_instance_id=agent_instance_id,
|
||||
requires_user_input=False, # Log steps don't require user input
|
||||
git_diff=git_diff,
|
||||
)
|
||||
|
||||
if current_agent_instance_id is None:
|
||||
current_agent_instance_id = response.agent_instance_id
|
||||
|
||||
# Return LogStepResponse with queued messages as user_feedback
|
||||
return LogStepResponse(
|
||||
success=response.success,
|
||||
agent_instance_id=response.agent_instance_id,
|
||||
step_number=response.step_number,
|
||||
user_feedback=response.user_feedback,
|
||||
step_number=1, # We don't track step numbers anymore
|
||||
user_feedback=response.queued_user_messages,
|
||||
)
|
||||
|
||||
|
||||
@@ -244,6 +246,7 @@ async def ask_question_tool(
|
||||
if not question_text:
|
||||
raise ValueError("question_text is required")
|
||||
|
||||
agent_type = detect_agent_type_from_environment()
|
||||
client = get_client()
|
||||
|
||||
# Get git diff if enabled
|
||||
@@ -253,17 +256,24 @@ async def ask_question_tool(
|
||||
current_agent_instance_id = agent_instance_id
|
||||
|
||||
try:
|
||||
response = await client.ask_question(
|
||||
response = await client.send_message(
|
||||
agent_type=agent_type,
|
||||
agent_instance_id=agent_instance_id,
|
||||
question_text=question_text,
|
||||
content=question_text,
|
||||
requires_user_input=True, # Questions require user input
|
||||
timeout_minutes=1440, # 24 hours default
|
||||
poll_interval=10.0,
|
||||
git_diff=git_diff,
|
||||
)
|
||||
|
||||
# Get the answer from queued_user_messages
|
||||
answer = (
|
||||
response.queued_user_messages[0] if response.queued_user_messages else ""
|
||||
)
|
||||
|
||||
return AskQuestionResponse(
|
||||
answer=response.answer,
|
||||
question_id=response.question_id,
|
||||
answer=answer,
|
||||
question_id=response.message_id,
|
||||
)
|
||||
except OmnaraTimeoutError:
|
||||
raise TimeoutError("Question timed out waiting for user response")
|
||||
@@ -316,10 +326,11 @@ async def approve_tool(
|
||||
instance_id = current_agent_instance_id
|
||||
else:
|
||||
# Only create a new instance if we don't have one
|
||||
response = await client.log_step(
|
||||
response = await client.send_message(
|
||||
agent_type="Claude Code",
|
||||
step_description="Permission request",
|
||||
content="Permission request",
|
||||
agent_instance_id=None,
|
||||
requires_user_input=False,
|
||||
)
|
||||
instance_id = response.agent_instance_id
|
||||
current_agent_instance_id = instance_id
|
||||
@@ -384,15 +395,21 @@ async def approve_tool(
|
||||
|
||||
try:
|
||||
# Ask the permission question
|
||||
answer_response = await client.ask_question(
|
||||
response = await client.send_message(
|
||||
agent_instance_id=instance_id,
|
||||
question_text=question_text,
|
||||
content=question_text,
|
||||
requires_user_input=True,
|
||||
timeout_minutes=1440,
|
||||
poll_interval=10.0,
|
||||
)
|
||||
|
||||
# Parse the answer to determine approval
|
||||
answer = answer_response.answer.strip()
|
||||
# Get the answer from queued_user_messages
|
||||
answer = (
|
||||
response.queued_user_messages[0].strip()
|
||||
if response.queued_user_messages
|
||||
else ""
|
||||
)
|
||||
|
||||
# Handle option selections by comparing with actual option text
|
||||
if answer == option_yes:
|
||||
@@ -432,7 +449,7 @@ async def approve_tool(
|
||||
# Custom text response - treat as denial with message
|
||||
return {
|
||||
"behavior": "deny",
|
||||
"message": f"Permission denied by user: {answer_response.answer}",
|
||||
"message": f"Permission denied by user: {answer}",
|
||||
}
|
||||
|
||||
except OmnaraTimeoutError:
|
||||
|
||||
@@ -5,21 +5,23 @@ the hosted server and stdio server. The authentication logic is handled
|
||||
by the individual servers.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from uuid import UUID
|
||||
|
||||
from fastmcp import Context
|
||||
from shared.database.session import get_db
|
||||
|
||||
from servers.shared.db import wait_for_answer
|
||||
from servers.shared.core import (
|
||||
process_log_step,
|
||||
create_agent_question,
|
||||
process_end_session,
|
||||
from servers.shared.db import (
|
||||
send_agent_message,
|
||||
end_session,
|
||||
wait_for_answer,
|
||||
create_agent_message,
|
||||
get_or_create_agent_instance,
|
||||
)
|
||||
from .models import AskQuestionResponse, EndSessionResponse, LogStepResponse
|
||||
|
||||
|
||||
def log_step_impl(
|
||||
async def log_step_impl(
|
||||
agent_instance_id: str | None = None,
|
||||
agent_type: str = "",
|
||||
step_description: str = "",
|
||||
@@ -29,20 +31,25 @@ def log_step_impl(
|
||||
|
||||
Args:
|
||||
agent_instance_id: Existing agent instance ID (optional)
|
||||
agent_type: Name of the agent (e.g., 'Claude Code', 'Cursor')
|
||||
agent_type: Name of the agent (e.g., 'claude_code', 'cursor')
|
||||
step_description: High-level description of the current step
|
||||
user_id: Authenticated user ID
|
||||
|
||||
Returns:
|
||||
LogStepResponse with success status, instance details, and user feedback
|
||||
"""
|
||||
if agent_instance_id:
|
||||
# Generate a new UUID if agent_instance_id is not provided
|
||||
if not agent_instance_id:
|
||||
agent_instance_id = str(uuid.uuid4())
|
||||
else:
|
||||
# Validate the provided UUID
|
||||
try:
|
||||
UUID(agent_instance_id)
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
f"Invalid agent_instance_id format: must be a valid UUID, got '{agent_instance_id}'"
|
||||
)
|
||||
|
||||
if not agent_type:
|
||||
raise ValueError("agent_type is required")
|
||||
if not step_description:
|
||||
@@ -53,19 +60,37 @@ def log_step_impl(
|
||||
db = next(get_db())
|
||||
|
||||
try:
|
||||
instance_id, step_number, user_feedback = process_log_step(
|
||||
# Use send_agent_message for steps (requires_user_input=False)
|
||||
instance_id, message_id, queued_messages = await send_agent_message(
|
||||
db=db,
|
||||
agent_type=agent_type,
|
||||
step_description=step_description,
|
||||
user_id=user_id,
|
||||
agent_instance_id=agent_instance_id,
|
||||
content=step_description,
|
||||
user_id=user_id,
|
||||
agent_type=agent_type,
|
||||
requires_user_input=False,
|
||||
)
|
||||
|
||||
# For backward compatibility, we need to return a step number
|
||||
# Count the number of agent messages (steps) for this instance
|
||||
from shared.database import Message, SenderType
|
||||
|
||||
step_count = (
|
||||
db.query(Message)
|
||||
.filter(
|
||||
Message.agent_instance_id == UUID(instance_id),
|
||||
Message.sender_type == SenderType.AGENT,
|
||||
Message.requires_user_input.is_(False),
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
db.commit()
|
||||
|
||||
return LogStepResponse(
|
||||
success=True,
|
||||
agent_instance_id=instance_id,
|
||||
step_number=step_number,
|
||||
user_feedback=user_feedback,
|
||||
step_number=step_count,
|
||||
user_feedback=[msg.content for msg in queued_messages],
|
||||
)
|
||||
|
||||
except Exception:
|
||||
@@ -108,13 +133,21 @@ async def ask_question_impl(
|
||||
db = next(get_db())
|
||||
|
||||
try:
|
||||
question = await create_agent_question(
|
||||
# Validate access first (agent_instance_id is required here)
|
||||
instance = get_or_create_agent_instance(db, agent_instance_id, user_id)
|
||||
|
||||
# Create question message (requires_user_input=True)
|
||||
question = create_agent_message(
|
||||
db=db,
|
||||
agent_instance_id=agent_instance_id,
|
||||
question_text=question_text,
|
||||
user_id=user_id,
|
||||
instance_id=instance.id,
|
||||
content=question_text,
|
||||
requires_user_input=True,
|
||||
)
|
||||
|
||||
# Commit to make the question visible
|
||||
db.commit()
|
||||
|
||||
# Wait for answer
|
||||
answer = await wait_for_answer(db, question.id, tool_context=tool_context)
|
||||
|
||||
if answer is None:
|
||||
@@ -156,12 +189,15 @@ def end_session_impl(
|
||||
db = next(get_db())
|
||||
|
||||
try:
|
||||
instance_id, final_status = process_end_session(
|
||||
instance_id, final_status = end_session(
|
||||
db=db,
|
||||
agent_instance_id=agent_instance_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Commit the transaction
|
||||
db.commit()
|
||||
|
||||
return EndSessionResponse(
|
||||
success=True,
|
||||
agent_instance_id=instance_id,
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
"""Core business logic shared between servers."""
|
||||
|
||||
from .agents import (
|
||||
validate_agent_access,
|
||||
process_log_step,
|
||||
create_agent_question,
|
||||
process_end_session,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"validate_agent_access",
|
||||
"process_log_step",
|
||||
"create_agent_question",
|
||||
"process_end_session",
|
||||
]
|
||||
@@ -1,167 +0,0 @@
|
||||
"""Shared business logic for agent operations.
|
||||
|
||||
This module contains the common logic used by both MCP and FastAPI servers,
|
||||
avoiding code duplication while allowing protocol-specific implementations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from servers.shared.db import (
|
||||
get_agent_instance,
|
||||
create_agent_instance,
|
||||
log_step,
|
||||
create_question,
|
||||
get_and_mark_unretrieved_feedback,
|
||||
create_or_get_user_agent,
|
||||
end_session,
|
||||
)
|
||||
from shared.database.utils import sanitize_git_diff
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def validate_agent_access(db: Session, agent_instance_id: str, user_id: str):
|
||||
"""Validate that a user has access to an agent instance.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
agent_instance_id: Agent instance ID to validate
|
||||
user_id: User ID requesting access
|
||||
|
||||
Returns:
|
||||
The agent instance if validation passes
|
||||
|
||||
Raises:
|
||||
ValueError: If instance not found or user doesn't have access
|
||||
"""
|
||||
instance = get_agent_instance(db, agent_instance_id)
|
||||
if not instance:
|
||||
raise ValueError(f"Agent instance {agent_instance_id} not found")
|
||||
if str(instance.user_id) != user_id:
|
||||
raise ValueError(
|
||||
"Access denied. Agent instance does not belong to authenticated user."
|
||||
)
|
||||
return instance
|
||||
|
||||
|
||||
def process_log_step(
|
||||
db: Session,
|
||||
agent_type: str,
|
||||
step_description: str,
|
||||
user_id: str,
|
||||
agent_instance_id: str | None = None,
|
||||
send_email: bool | None = None,
|
||||
send_sms: bool | None = None,
|
||||
send_push: bool | None = None,
|
||||
git_diff: str | None = None,
|
||||
) -> tuple[str, int, list[str]]:
|
||||
"""Process a log step operation with all common logic.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
agent_type: Type of agent
|
||||
step_description: Description of the step
|
||||
user_id: Authenticated user ID
|
||||
agent_instance_id: Optional existing instance ID
|
||||
|
||||
Returns:
|
||||
Tuple of (agent_instance_id, step_number, user_feedback)
|
||||
"""
|
||||
# Get or create user agent type
|
||||
agent_type_obj = create_or_get_user_agent(db, agent_type, user_id)
|
||||
|
||||
# Get or create instance
|
||||
if agent_instance_id:
|
||||
instance = validate_agent_access(db, agent_instance_id, user_id)
|
||||
else:
|
||||
instance = create_agent_instance(db, agent_type_obj.id, user_id)
|
||||
|
||||
# Create step with notification preferences
|
||||
step = log_step(db, instance.id, step_description, send_email, send_sms, send_push)
|
||||
|
||||
# Update git diff if provided
|
||||
if git_diff is not None:
|
||||
# Validate and sanitize git diff
|
||||
sanitized_diff = sanitize_git_diff(git_diff)
|
||||
if sanitized_diff is not None: # Allow empty string (cleared diff)
|
||||
instance.git_diff = sanitized_diff
|
||||
db.commit()
|
||||
else:
|
||||
logger.warning(
|
||||
f"Invalid git diff format for instance {instance.id}, skipping git diff update"
|
||||
)
|
||||
|
||||
# Get unretrieved feedback
|
||||
feedback = get_and_mark_unretrieved_feedback(db, instance.id)
|
||||
|
||||
return str(instance.id), step.step_number, feedback
|
||||
|
||||
|
||||
async def create_agent_question(
|
||||
db: Session,
|
||||
agent_instance_id: str,
|
||||
question_text: str,
|
||||
user_id: str,
|
||||
send_email: bool | None = None,
|
||||
send_sms: bool | None = None,
|
||||
send_push: bool | None = None,
|
||||
git_diff: str | None = None,
|
||||
):
|
||||
"""Create a question with validation and send push notification.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
agent_instance_id: Agent instance ID
|
||||
question_text: Question to ask
|
||||
user_id: Authenticated user ID
|
||||
|
||||
Returns:
|
||||
The created question object
|
||||
"""
|
||||
# Validate access
|
||||
instance = validate_agent_access(db, agent_instance_id, user_id)
|
||||
|
||||
# Update git diff if provided
|
||||
if git_diff is not None:
|
||||
# Validate and sanitize git diff
|
||||
sanitized_diff = sanitize_git_diff(git_diff)
|
||||
if sanitized_diff is not None: # Allow empty string (cleared diff)
|
||||
instance.git_diff = sanitized_diff
|
||||
db.commit()
|
||||
else:
|
||||
logger.warning(
|
||||
f"Invalid git diff format for instance {instance.id}, skipping git diff update"
|
||||
)
|
||||
|
||||
# Create question
|
||||
# Note: Notifications sent by create_question() function based on parameters
|
||||
question = await create_question(
|
||||
db, instance.id, question_text, send_email, send_sms, send_push
|
||||
)
|
||||
|
||||
return question
|
||||
|
||||
|
||||
def process_end_session(
|
||||
db: Session,
|
||||
agent_instance_id: str,
|
||||
user_id: str,
|
||||
) -> tuple[str, str]:
|
||||
"""Process ending a session with validation.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
agent_instance_id: Agent instance ID to end
|
||||
user_id: Authenticated user ID
|
||||
|
||||
Returns:
|
||||
Tuple of (agent_instance_id, final_status)
|
||||
"""
|
||||
# Validate access
|
||||
instance = validate_agent_access(db, agent_instance_id, user_id)
|
||||
|
||||
# End the session
|
||||
updated_instance = end_session(db, instance.id)
|
||||
|
||||
return str(updated_instance.id), updated_instance.status.value
|
||||
@@ -1,25 +1,31 @@
|
||||
"""Database queries and operations for servers."""
|
||||
|
||||
from .queries import (
|
||||
# Low-level functions
|
||||
create_agent_instance,
|
||||
create_or_get_user_agent,
|
||||
create_question,
|
||||
create_agent_message,
|
||||
create_user_message,
|
||||
end_session,
|
||||
get_agent_instance,
|
||||
get_and_mark_unretrieved_feedback,
|
||||
get_question,
|
||||
log_step,
|
||||
get_queued_user_messages,
|
||||
get_or_create_agent_instance,
|
||||
wait_for_answer,
|
||||
# High-level functions
|
||||
send_agent_message,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Low-level functions
|
||||
"create_agent_instance",
|
||||
"create_or_get_user_agent",
|
||||
"create_question",
|
||||
"create_agent_message",
|
||||
"create_user_message",
|
||||
"end_session",
|
||||
"get_agent_instance",
|
||||
"get_and_mark_unretrieved_feedback",
|
||||
"get_question",
|
||||
"log_step",
|
||||
"get_queued_user_messages",
|
||||
"get_or_create_agent_instance",
|
||||
"wait_for_answer",
|
||||
# High-level functions
|
||||
"send_agent_message",
|
||||
]
|
||||
|
||||
@@ -6,19 +6,15 @@ from uuid import UUID
|
||||
|
||||
from shared.database import (
|
||||
AgentInstance,
|
||||
AgentQuestion,
|
||||
AgentStatus,
|
||||
AgentStep,
|
||||
AgentUserFeedback,
|
||||
Message,
|
||||
SenderType,
|
||||
UserAgent,
|
||||
User,
|
||||
)
|
||||
from shared.database.billing_operations import check_agent_limit
|
||||
from sqlalchemy import func
|
||||
from shared.database.utils import sanitize_git_diff
|
||||
from sqlalchemy.orm import Session
|
||||
from fastmcp import Context
|
||||
from servers.shared.notifications import push_service
|
||||
from servers.shared.twilio_service import twilio_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -40,8 +36,7 @@ def create_or_get_user_agent(db: Session, name: str, user_id: str) -> UserAgent:
|
||||
is_active=True,
|
||||
)
|
||||
db.add(user_agent)
|
||||
db.commit()
|
||||
db.refresh(user_agent)
|
||||
db.flush() # Flush to get the user_agent ID
|
||||
return user_agent
|
||||
|
||||
|
||||
@@ -56,8 +51,6 @@ def create_agent_instance(
|
||||
user_agent_id=user_agent_id, user_id=UUID(user_id), status=AgentStatus.ACTIVE
|
||||
)
|
||||
db.add(instance)
|
||||
db.commit()
|
||||
db.refresh(instance)
|
||||
return instance
|
||||
|
||||
|
||||
@@ -66,215 +59,153 @@ def get_agent_instance(db: Session, instance_id: str) -> AgentInstance | None:
|
||||
return db.query(AgentInstance).filter(AgentInstance.id == instance_id).first()
|
||||
|
||||
|
||||
def log_step(
|
||||
db: Session,
|
||||
instance_id: UUID,
|
||||
description: str,
|
||||
send_email: bool | None = None,
|
||||
send_sms: bool | None = None,
|
||||
send_push: bool | None = None,
|
||||
) -> AgentStep:
|
||||
"""Log a new step for an agent instance"""
|
||||
# Get the next step number
|
||||
max_step = (
|
||||
db.query(func.max(AgentStep.step_number))
|
||||
.filter(AgentStep.agent_instance_id == instance_id)
|
||||
.scalar()
|
||||
)
|
||||
next_step_number = (max_step or 0) + 1
|
||||
def get_or_create_agent_instance(
|
||||
db: Session, agent_instance_id: str, user_id: str, agent_type: str | None = None
|
||||
) -> AgentInstance:
|
||||
"""Get an existing agent instance or create a new one.
|
||||
|
||||
# Create the step
|
||||
step = AgentStep(
|
||||
agent_instance_id=instance_id,
|
||||
step_number=next_step_number,
|
||||
description=description,
|
||||
)
|
||||
db.add(step)
|
||||
db.commit()
|
||||
db.refresh(step)
|
||||
Args:
|
||||
db: Database session
|
||||
agent_instance_id: Agent instance ID (always required)
|
||||
user_id: User ID requesting access
|
||||
agent_type: Agent type name (required only when creating new instance)
|
||||
|
||||
# Send notifications if requested (all default to False for log steps)
|
||||
if send_email or send_sms or send_push:
|
||||
# Get instance details for notifications
|
||||
instance = (
|
||||
db.query(AgentInstance).filter(AgentInstance.id == instance_id).first()
|
||||
)
|
||||
if instance:
|
||||
user = db.query(User).filter(User.id == instance.user_id).first()
|
||||
Returns:
|
||||
The agent instance (existing or newly created)
|
||||
|
||||
if user:
|
||||
agent_name = (
|
||||
instance.user_agent.name if instance.user_agent else "Agent"
|
||||
)
|
||||
Raises:
|
||||
ValueError: If instance not found, user doesn't have access, or agent_type missing when creating
|
||||
"""
|
||||
# Try to get existing instance
|
||||
instance = get_agent_instance(db, agent_instance_id)
|
||||
|
||||
# Override defaults - for log steps, all notifications default to False
|
||||
should_send_push = send_push if send_push is not None else False
|
||||
should_send_email = send_email if send_email is not None else False
|
||||
should_send_sms = send_sms if send_sms is not None else False
|
||||
|
||||
# Send push notification if explicitly enabled
|
||||
if should_send_push:
|
||||
try:
|
||||
asyncio.create_task(
|
||||
push_service.send_step_notification(
|
||||
db=db,
|
||||
user_id=instance.user_id,
|
||||
instance_id=str(instance.id),
|
||||
step_number=step.step_number,
|
||||
agent_name=agent_name,
|
||||
step_description=description,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to send push notification for step {step.id}: {e}"
|
||||
)
|
||||
|
||||
# Send Twilio notifications if explicitly enabled
|
||||
if should_send_email or should_send_sms:
|
||||
try:
|
||||
asyncio.create_task(
|
||||
twilio_service.send_step_notification(
|
||||
db=db,
|
||||
user_id=instance.user_id,
|
||||
instance_id=str(instance.id),
|
||||
step_number=step.step_number,
|
||||
agent_name=agent_name,
|
||||
step_description=description,
|
||||
send_email=should_send_email,
|
||||
send_sms=should_send_sms,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to send Twilio notification for step {step.id}: {e}"
|
||||
)
|
||||
|
||||
return step
|
||||
|
||||
|
||||
async def create_question(
|
||||
db: Session,
|
||||
instance_id: UUID,
|
||||
question_text: str,
|
||||
send_email: bool | None = None,
|
||||
send_sms: bool | None = None,
|
||||
send_push: bool | None = None,
|
||||
) -> AgentQuestion:
|
||||
"""Create a new question for an agent instance"""
|
||||
# Mark any existing active questions as inactive
|
||||
db.query(AgentQuestion).filter(
|
||||
AgentQuestion.agent_instance_id == instance_id, AgentQuestion.is_active
|
||||
).update({"is_active": False})
|
||||
|
||||
# Update agent instance status to awaiting_input
|
||||
instance = db.query(AgentInstance).filter(AgentInstance.id == instance_id).first()
|
||||
if instance and instance.status == AgentStatus.ACTIVE:
|
||||
instance.status = AgentStatus.AWAITING_INPUT
|
||||
|
||||
# Create new question
|
||||
question = AgentQuestion(
|
||||
agent_instance_id=instance_id, question_text=question_text, is_active=True
|
||||
)
|
||||
db.add(question)
|
||||
db.commit()
|
||||
db.refresh(question)
|
||||
|
||||
# Send notifications based on user preferences
|
||||
if instance:
|
||||
# Get user for checking preferences
|
||||
user = db.query(User).filter(User.id == instance.user_id).first()
|
||||
|
||||
if user:
|
||||
agent_name = instance.user_agent.name if instance.user_agent else "Agent"
|
||||
|
||||
# Determine notification preferences
|
||||
# For questions: push defaults to True (or user preference), email/SMS default to False
|
||||
should_send_push = (
|
||||
send_push if send_push is not None else user.push_notifications_enabled
|
||||
)
|
||||
should_send_email = (
|
||||
send_email
|
||||
if send_email is not None
|
||||
else user.email_notifications_enabled
|
||||
)
|
||||
should_send_sms = (
|
||||
send_sms if send_sms is not None else user.sms_notifications_enabled
|
||||
# Validate access to existing instance
|
||||
if str(instance.user_id) != user_id:
|
||||
raise ValueError(
|
||||
"Access denied. Agent instance does not belong to authenticated user."
|
||||
)
|
||||
return instance
|
||||
else:
|
||||
# Create new instance with the provided ID
|
||||
if not agent_type:
|
||||
raise ValueError("agent_type is required when creating new instance")
|
||||
|
||||
# Send push notification if enabled
|
||||
if should_send_push:
|
||||
try:
|
||||
await push_service.send_question_notification(
|
||||
db=db,
|
||||
user_id=instance.user_id,
|
||||
instance_id=str(instance.id),
|
||||
question_id=str(question.id),
|
||||
agent_name=agent_name,
|
||||
question_text=question_text,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to send push notification for question {question.id}: {e}"
|
||||
)
|
||||
agent_type_obj = create_or_get_user_agent(db, agent_type, user_id)
|
||||
|
||||
# Send Twilio notification if enabled (email and/or SMS)
|
||||
if should_send_email or should_send_sms:
|
||||
try:
|
||||
twilio_service.send_question_notification(
|
||||
db=db,
|
||||
user_id=instance.user_id,
|
||||
instance_id=str(instance.id),
|
||||
question_id=str(question.id),
|
||||
agent_name=agent_name,
|
||||
question_text=question_text,
|
||||
send_email=should_send_email,
|
||||
send_sms=should_send_sms,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to send Twilio notification for question {question.id}: {e}"
|
||||
)
|
||||
# Check usage limits if billing is enabled
|
||||
check_agent_limit(UUID(user_id), db)
|
||||
|
||||
return question
|
||||
# Create instance with the specific ID
|
||||
instance = AgentInstance(
|
||||
id=UUID(agent_instance_id),
|
||||
user_agent_id=agent_type_obj.id,
|
||||
user_id=UUID(user_id),
|
||||
status=AgentStatus.ACTIVE,
|
||||
)
|
||||
db.add(instance)
|
||||
db.flush() # Flush to ensure the instance is in the session with its ID
|
||||
return instance
|
||||
|
||||
|
||||
def end_session(db: Session, agent_instance_id: str, user_id: str) -> tuple[str, str]:
|
||||
"""End an agent session by marking it as completed.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
agent_instance_id: Agent instance ID to end
|
||||
user_id: Authenticated user ID
|
||||
|
||||
Returns:
|
||||
Tuple of (agent_instance_id, final_status)
|
||||
"""
|
||||
instance = get_or_create_agent_instance(db, agent_instance_id, user_id)
|
||||
|
||||
instance.status = AgentStatus.COMPLETED
|
||||
instance.ended_at = datetime.now(timezone.utc)
|
||||
|
||||
return str(instance.id), instance.status.value
|
||||
|
||||
|
||||
def create_agent_message(
|
||||
db: Session,
|
||||
instance_id: UUID,
|
||||
content: str,
|
||||
requires_user_input: bool = False,
|
||||
) -> Message:
|
||||
"""Create a new agent message without committing"""
|
||||
instance = db.query(AgentInstance).filter(AgentInstance.id == instance_id).first()
|
||||
if instance and instance.status != AgentStatus.COMPLETED:
|
||||
if requires_user_input:
|
||||
instance.status = AgentStatus.AWAITING_INPUT
|
||||
else:
|
||||
instance.status = AgentStatus.ACTIVE
|
||||
|
||||
message = Message(
|
||||
agent_instance_id=instance_id,
|
||||
sender_type=SenderType.AGENT,
|
||||
content=content,
|
||||
requires_user_input=requires_user_input,
|
||||
)
|
||||
db.add(message)
|
||||
db.flush() # Flush to get the message ID
|
||||
|
||||
# Update last read message
|
||||
if instance:
|
||||
instance.last_read_message_id = message.id
|
||||
|
||||
return message
|
||||
|
||||
|
||||
async def wait_for_answer(
|
||||
db: Session,
|
||||
question_id: UUID,
|
||||
timeout: int = 86400,
|
||||
timeout_seconds: int = 86400, # 24 hours default
|
||||
tool_context: Context | None = None,
|
||||
) -> str | None:
|
||||
"""
|
||||
Wait for an answer to a question (async non-blocking)
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
question_id: Question ID to wait for
|
||||
timeout: Maximum time to wait in seconds (default 24 hours)
|
||||
|
||||
Returns:
|
||||
Answer text if received, None if timeout
|
||||
"""
|
||||
"""Wait for an answer to a question using polling"""
|
||||
start_time = time.time()
|
||||
last_progress_report = start_time
|
||||
total_minutes = int(timeout / 60)
|
||||
total_minutes = timeout_seconds // 60
|
||||
|
||||
# Report initial progress (0 minutes elapsed)
|
||||
if tool_context:
|
||||
await tool_context.report_progress(0, total_minutes)
|
||||
# Get the question message
|
||||
question = db.query(Message).filter(Message.id == question_id).first()
|
||||
if not question or not question.requires_user_input:
|
||||
return None
|
||||
|
||||
while time.time() - start_time < timeout:
|
||||
# Check for answer
|
||||
db.commit() # Ensure we see latest data
|
||||
question = (
|
||||
db.query(AgentQuestion).filter(AgentQuestion.id == question_id).first()
|
||||
while time.time() - start_time < timeout_seconds:
|
||||
# Check if agent has moved on (last read message changed)
|
||||
instance = (
|
||||
db.query(AgentInstance)
|
||||
.filter(AgentInstance.id == question.agent_instance_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if question and question.answer_text is not None:
|
||||
# If last_read_message_id has changed from our question, agent has moved on
|
||||
if instance and instance.last_read_message_id != question_id:
|
||||
return None
|
||||
|
||||
# Check for a user message after this question
|
||||
answer = (
|
||||
db.query(Message)
|
||||
.filter(
|
||||
Message.agent_instance_id == question.agent_instance_id,
|
||||
Message.sender_type == SenderType.USER,
|
||||
Message.created_at > question.created_at,
|
||||
)
|
||||
.order_by(Message.created_at)
|
||||
.first()
|
||||
)
|
||||
|
||||
if answer:
|
||||
# Update last read message to this answer
|
||||
if instance:
|
||||
instance.last_read_message_id = answer.id
|
||||
|
||||
if tool_context:
|
||||
await tool_context.report_progress(total_minutes, total_minutes)
|
||||
return question.answer_text
|
||||
|
||||
return answer.content
|
||||
|
||||
# Report progress every minute if tool_context is provided
|
||||
current_time = time.time()
|
||||
@@ -285,59 +216,178 @@ async def wait_for_answer(
|
||||
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Timeout - mark question as inactive
|
||||
db.query(AgentQuestion).filter(AgentQuestion.id == question_id).update(
|
||||
{"is_active": False}
|
||||
)
|
||||
db.commit()
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_question(db: Session, question_id: str) -> AgentQuestion | None:
|
||||
"""Get a question by ID"""
|
||||
return db.query(AgentQuestion).filter(AgentQuestion.id == question_id).first()
|
||||
def get_queued_user_messages(
|
||||
db: Session, instance_id: UUID, last_read_message_id: UUID | None = None
|
||||
) -> list[Message] | None:
|
||||
"""Get all user messages since the agent last read them.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
instance_id: Agent instance ID
|
||||
last_read_message_id: The message ID the agent last read (optional)
|
||||
|
||||
Returns:
|
||||
- None if last_read_message_id doesn't match the instance's current last_read_message_id
|
||||
- Empty list if no new messages
|
||||
- List of messages if there are new user messages
|
||||
"""
|
||||
# Get the agent instance to check last read message
|
||||
instance = db.query(AgentInstance).filter(AgentInstance.id == instance_id).first()
|
||||
if not instance:
|
||||
return []
|
||||
|
||||
if (
|
||||
last_read_message_id is not None
|
||||
and instance.last_read_message_id != last_read_message_id
|
||||
):
|
||||
return None
|
||||
|
||||
# If no last read message, get all user messages
|
||||
if not instance.last_read_message_id:
|
||||
messages = (
|
||||
db.query(Message)
|
||||
.filter(
|
||||
Message.agent_instance_id == instance_id,
|
||||
Message.sender_type == SenderType.USER,
|
||||
)
|
||||
.order_by(Message.created_at)
|
||||
.all()
|
||||
)
|
||||
else:
|
||||
last_read_message = (
|
||||
db.query(Message)
|
||||
.filter(Message.id == instance.last_read_message_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not last_read_message:
|
||||
return []
|
||||
|
||||
# Get all user messages after the last read message
|
||||
messages = (
|
||||
db.query(Message)
|
||||
.filter(
|
||||
Message.agent_instance_id == instance_id,
|
||||
Message.sender_type == SenderType.USER,
|
||||
Message.created_at > last_read_message.created_at,
|
||||
)
|
||||
.order_by(Message.created_at)
|
||||
.all()
|
||||
)
|
||||
|
||||
if messages and last_read_message_id is not None:
|
||||
instance.last_read_message_id = messages[-1].id
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
def get_and_mark_unretrieved_feedback(
|
||||
db: Session, instance_id: UUID, since_time: datetime | None = None
|
||||
) -> list[str]:
|
||||
"""Get unretrieved user feedback for an agent instance and mark as retrieved"""
|
||||
async def send_agent_message(
|
||||
db: Session,
|
||||
agent_instance_id: str,
|
||||
content: str,
|
||||
user_id: str,
|
||||
agent_type: str | None = None,
|
||||
requires_user_input: bool = False,
|
||||
git_diff: str | None = None,
|
||||
) -> tuple[str, str, list[Message]]:
|
||||
"""High-level function to send an agent message and get queued user messages.
|
||||
|
||||
query = db.query(AgentUserFeedback).filter(
|
||||
AgentUserFeedback.agent_instance_id == instance_id,
|
||||
AgentUserFeedback.retrieved_at.is_(None),
|
||||
This combines the common pattern of:
|
||||
1. Getting or creating an agent instance
|
||||
2. Validating access (if existing instance)
|
||||
3. Creating a message
|
||||
4. Updating git diff if provided
|
||||
5. Getting any queued user messages
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
agent_instance_id: Agent instance ID (pass None to create new)
|
||||
content: Message content
|
||||
user_id: Authenticated user ID
|
||||
agent_type: Type of agent (required if creating new instance)
|
||||
requires_user_input: Whether this is a question requiring response
|
||||
git_diff: Optional git diff to update on the instance
|
||||
|
||||
Returns:
|
||||
Tuple of (agent_instance_id, message_id, list of queued user message contents)
|
||||
"""
|
||||
# Get or create instance using the unified function
|
||||
instance = get_or_create_agent_instance(db, agent_instance_id, user_id, agent_type)
|
||||
|
||||
# Update git diff if provided (but don't commit yet)
|
||||
if git_diff is not None:
|
||||
sanitized_diff = sanitize_git_diff(git_diff)
|
||||
if sanitized_diff is not None: # Allow empty string (cleared diff)
|
||||
instance.git_diff = sanitized_diff
|
||||
else:
|
||||
logger.warning(
|
||||
f"Invalid git diff format for instance {instance.id}, skipping git diff update"
|
||||
)
|
||||
|
||||
queued_messages = get_queued_user_messages(db, instance.id, None)
|
||||
|
||||
# Create the message (this will update last_read_message_id)
|
||||
message = create_agent_message(
|
||||
db=db,
|
||||
instance_id=instance.id,
|
||||
content=content,
|
||||
requires_user_input=requires_user_input,
|
||||
)
|
||||
|
||||
if since_time:
|
||||
query = query.filter(AgentUserFeedback.created_at > since_time)
|
||||
# Handle the None case (shouldn't happen here since we just created the message)
|
||||
if queued_messages is None:
|
||||
queued_messages = []
|
||||
|
||||
feedback_list = query.order_by(AgentUserFeedback.created_at).all()
|
||||
|
||||
# Mark all feedback as retrieved
|
||||
for feedback in feedback_list:
|
||||
feedback.retrieved_at = datetime.now(timezone.utc)
|
||||
db.commit()
|
||||
|
||||
return [feedback.feedback_text for feedback in feedback_list]
|
||||
return str(instance.id), str(message.id), queued_messages
|
||||
|
||||
|
||||
def end_session(db: Session, instance_id: UUID) -> AgentInstance:
|
||||
"""End an agent session by marking it as completed"""
|
||||
instance = db.query(AgentInstance).filter(AgentInstance.id == instance_id).first()
|
||||
def create_user_message(
|
||||
db: Session,
|
||||
agent_instance_id: str,
|
||||
content: str,
|
||||
user_id: str,
|
||||
mark_as_read: bool = True,
|
||||
) -> tuple[str, bool]:
|
||||
"""Create a user message for an agent instance.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
agent_instance_id: Agent instance ID to send the message to
|
||||
content: Message content
|
||||
user_id: Authenticated user ID
|
||||
mark_as_read: Whether to update last_read_message_id (default: True)
|
||||
|
||||
Returns:
|
||||
Tuple of (message_id, marked_as_read)
|
||||
|
||||
Raises:
|
||||
ValueError: If instance not found or user doesn't have access
|
||||
"""
|
||||
# Get the instance and validate access
|
||||
instance = get_agent_instance(db, agent_instance_id)
|
||||
if not instance:
|
||||
raise ValueError(f"Agent instance {instance_id} not found")
|
||||
raise ValueError("Agent instance not found")
|
||||
|
||||
# Update status to completed
|
||||
instance.status = AgentStatus.COMPLETED
|
||||
instance.ended_at = datetime.now(timezone.utc)
|
||||
if str(instance.user_id) != user_id:
|
||||
raise ValueError(
|
||||
"Access denied. Agent instance does not belong to authenticated user."
|
||||
)
|
||||
|
||||
# Mark any active questions as inactive
|
||||
db.query(AgentQuestion).filter(
|
||||
AgentQuestion.agent_instance_id == instance_id, AgentQuestion.is_active
|
||||
).update({"is_active": False})
|
||||
# Create the user message
|
||||
message = Message(
|
||||
agent_instance_id=UUID(agent_instance_id),
|
||||
sender_type=SenderType.USER,
|
||||
content=content,
|
||||
requires_user_input=False,
|
||||
)
|
||||
db.add(message)
|
||||
db.flush() # Get the message ID
|
||||
|
||||
db.commit()
|
||||
db.refresh(instance)
|
||||
return instance
|
||||
# Update last_read_message_id if requested
|
||||
if mark_as_read:
|
||||
instance.last_read_message_id = message.id
|
||||
|
||||
return str(message.id), mark_as_read
|
||||
|
||||
@@ -27,7 +27,6 @@ class NotificationServiceBase(ABC):
|
||||
db: Session,
|
||||
user_id: UUID,
|
||||
instance_id: str,
|
||||
question_id: str,
|
||||
agent_name: str,
|
||||
question_text: str,
|
||||
**kwargs,
|
||||
@@ -41,7 +40,6 @@ class NotificationServiceBase(ABC):
|
||||
db: Session,
|
||||
user_id: UUID,
|
||||
instance_id: str,
|
||||
step_number: int,
|
||||
agent_name: str,
|
||||
step_description: str,
|
||||
**kwargs,
|
||||
|
||||
111
servers/shared/notification_utils.py
Normal file
111
servers/shared/notification_utils.py
Normal file
@@ -0,0 +1,111 @@
|
||||
"""Notification utilities for sending push, email, and SMS notifications."""
|
||||
|
||||
import logging
|
||||
from uuid import UUID
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from shared.database import User, AgentInstance
|
||||
from .notifications import push_service
|
||||
from .twilio_service import twilio_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def send_message_notifications(
|
||||
db: Session,
|
||||
instance_id: UUID,
|
||||
content: str,
|
||||
requires_user_input: bool,
|
||||
send_email: bool | None = None,
|
||||
send_sms: bool | None = None,
|
||||
send_push: bool | None = None,
|
||||
) -> None:
|
||||
"""Send notifications for a message (either step or question).
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
instance_id: Agent instance ID
|
||||
content: Message content
|
||||
requires_user_input: Whether this message requires user input
|
||||
send_email: Override email notification preference
|
||||
send_sms: Override SMS notification preference
|
||||
send_push: Override push notification preference
|
||||
"""
|
||||
# Get instance and user
|
||||
instance = db.query(AgentInstance).filter(AgentInstance.id == instance_id).first()
|
||||
if not instance:
|
||||
logger.warning(f"Instance {instance_id} not found for notifications")
|
||||
return
|
||||
|
||||
user = db.query(User).filter(User.id == instance.user_id).first()
|
||||
if not user:
|
||||
logger.warning(f"User {instance.user_id} not found for notifications")
|
||||
return
|
||||
|
||||
agent_name = instance.user_agent.name if instance.user_agent else "Agent"
|
||||
|
||||
# Determine notification preferences based on message type
|
||||
if requires_user_input:
|
||||
# For questions: respect user preferences
|
||||
should_send_push = (
|
||||
send_push if send_push is not None else user.push_notifications_enabled
|
||||
)
|
||||
should_send_email = (
|
||||
send_email if send_email is not None else user.email_notifications_enabled
|
||||
)
|
||||
should_send_sms = (
|
||||
send_sms if send_sms is not None else user.sms_notifications_enabled
|
||||
)
|
||||
else:
|
||||
# For steps: notifications default to False unless explicitly enabled
|
||||
should_send_push = send_push if send_push is not None else False
|
||||
should_send_email = send_email if send_email is not None else False
|
||||
should_send_sms = send_sms if send_sms is not None else False
|
||||
|
||||
# Send push notification if enabled
|
||||
if should_send_push:
|
||||
try:
|
||||
if requires_user_input:
|
||||
await push_service.send_question_notification(
|
||||
db=db,
|
||||
user_id=instance.user_id,
|
||||
instance_id=str(instance.id),
|
||||
agent_name=agent_name,
|
||||
question_text=content,
|
||||
)
|
||||
else:
|
||||
await push_service.send_step_notification(
|
||||
db=db,
|
||||
user_id=instance.user_id,
|
||||
instance_id=str(instance.id),
|
||||
agent_name=agent_name,
|
||||
step_description=content,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send push notification: {e}")
|
||||
|
||||
# Send Twilio notifications if enabled
|
||||
if should_send_email or should_send_sms:
|
||||
try:
|
||||
if requires_user_input:
|
||||
await twilio_service.send_question_notification(
|
||||
db=db,
|
||||
user_id=instance.user_id,
|
||||
instance_id=str(instance.id),
|
||||
agent_name=agent_name,
|
||||
question_text=content,
|
||||
send_email=should_send_email,
|
||||
send_sms=should_send_sms,
|
||||
)
|
||||
else:
|
||||
await twilio_service.send_step_notification(
|
||||
db=db,
|
||||
user_id=instance.user_id,
|
||||
instance_id=str(instance.id),
|
||||
agent_name=agent_name,
|
||||
step_description=content,
|
||||
send_email=should_send_email,
|
||||
send_sms=should_send_sms,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send Twilio notification: {e}")
|
||||
@@ -189,7 +189,6 @@ class PushNotificationService(NotificationServiceBase):
|
||||
db: Session,
|
||||
user_id: UUID,
|
||||
instance_id: str,
|
||||
question_id: str,
|
||||
agent_name: str,
|
||||
question_text: str,
|
||||
) -> bool:
|
||||
@@ -206,7 +205,6 @@ class PushNotificationService(NotificationServiceBase):
|
||||
data = {
|
||||
"type": "new_question",
|
||||
"instanceId": instance_id,
|
||||
"questionId": question_id,
|
||||
}
|
||||
|
||||
return await self.send_notification(
|
||||
@@ -222,14 +220,13 @@ class PushNotificationService(NotificationServiceBase):
|
||||
db: Session,
|
||||
user_id: UUID,
|
||||
instance_id: str,
|
||||
step_number: int,
|
||||
agent_name: str,
|
||||
step_description: str,
|
||||
) -> bool:
|
||||
"""Send notification for new agent step"""
|
||||
# Format agent name for display
|
||||
display_name = agent_name.replace("_", " ").title()
|
||||
title = f"{display_name} - Step {step_number}"
|
||||
title = f"{display_name} - New Step"
|
||||
|
||||
# Truncate step description for notification
|
||||
body = step_description
|
||||
@@ -239,7 +236,6 @@ class PushNotificationService(NotificationServiceBase):
|
||||
data = {
|
||||
"type": "new_step",
|
||||
"instanceId": instance_id,
|
||||
"stepNumber": step_number,
|
||||
}
|
||||
|
||||
return await self.send_notification(
|
||||
|
||||
@@ -172,12 +172,11 @@ class TwilioNotificationService(NotificationServiceBase):
|
||||
logger.error(f"Error sending Twilio notifications: {e}")
|
||||
return results
|
||||
|
||||
def send_question_notification(
|
||||
async def send_question_notification(
|
||||
self,
|
||||
db: Session,
|
||||
user_id: UUID,
|
||||
instance_id: str,
|
||||
question_id: str,
|
||||
agent_name: str,
|
||||
question_text: str,
|
||||
send_email: bool | None = None,
|
||||
@@ -218,7 +217,6 @@ The Omnara Team
|
||||
db: Session,
|
||||
user_id: UUID,
|
||||
instance_id: str,
|
||||
step_number: int,
|
||||
agent_name: str,
|
||||
step_description: str,
|
||||
send_email: bool | None = None,
|
||||
@@ -227,13 +225,13 @@ The Omnara Team
|
||||
"""Send notification for new agent step"""
|
||||
# Format agent name for display
|
||||
display_name = agent_name.replace("_", " ").title()
|
||||
title = f"{display_name} - Step {step_number}"
|
||||
title = f"{display_name} - New Step"
|
||||
|
||||
# Email body with more detail
|
||||
email_body = f"""
|
||||
Your agent {display_name} has logged a new step:
|
||||
|
||||
Step {step_number}: {step_description}
|
||||
{step_description}
|
||||
|
||||
You can view the full session at: https://omnara.com/dashboard/instances/{instance_id}
|
||||
|
||||
@@ -242,9 +240,9 @@ The Omnara Team
|
||||
"""
|
||||
|
||||
# SMS body (shorter)
|
||||
sms_body = f"{display_name} Step {step_number}: {step_description}"
|
||||
sms_body = f"{display_name}: {step_description}"
|
||||
if len(sms_body) > 160:
|
||||
sms_body = f"{display_name} Step {step_number}: {step_description[:140]}..."
|
||||
sms_body = f"{display_name}: {step_description[:140]}..."
|
||||
|
||||
return self.send_notification(
|
||||
db=db,
|
||||
|
||||
1204
servers/tests/test_db_queries.py
Normal file
1204
servers/tests/test_db_queries.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,274 +0,0 @@
|
||||
"""Integration tests using PostgreSQL."""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
from uuid import uuid4
|
||||
|
||||
# Database fixtures come from conftest.py
|
||||
|
||||
# Import the real models
|
||||
from shared.database.models import (
|
||||
User,
|
||||
UserAgent,
|
||||
AgentInstance,
|
||||
AgentStep,
|
||||
AgentQuestion,
|
||||
AgentUserFeedback,
|
||||
)
|
||||
from shared.database.enums import AgentStatus
|
||||
|
||||
# Import the core functions we want to test
|
||||
from servers.shared.core import (
|
||||
process_log_step,
|
||||
create_agent_question,
|
||||
process_end_session,
|
||||
)
|
||||
|
||||
|
||||
# Using test_db fixture from conftest.py which provides PostgreSQL via testcontainers
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_user(test_db):
|
||||
"""Create a test user."""
|
||||
user = User(
|
||||
id=uuid4(),
|
||||
email="integration@test.com",
|
||||
display_name="Integration Test User",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
test_db.add(user)
|
||||
test_db.commit()
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_user_agent(test_db, test_user):
|
||||
"""Create a test user agent."""
|
||||
user_agent = UserAgent(
|
||||
id=uuid4(),
|
||||
user_id=test_user.id,
|
||||
name="Claude Code Test",
|
||||
is_active=True,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
test_db.add(user_agent)
|
||||
test_db.commit()
|
||||
return user_agent
|
||||
|
||||
|
||||
class TestIntegrationFlow:
|
||||
"""Test the complete integration flow with PostgreSQL."""
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.asyncio
|
||||
async def test_complete_agent_session_flow(
|
||||
self, test_db, test_user, test_user_agent
|
||||
):
|
||||
"""Test a complete agent session from start to finish."""
|
||||
# Step 1: Create new agent instance
|
||||
instance_id, step_number, user_feedback = process_log_step(
|
||||
db=test_db,
|
||||
user_id=str(test_user.id),
|
||||
agent_instance_id=None,
|
||||
agent_type="Claude Code Test",
|
||||
step_description="Starting integration test task",
|
||||
)
|
||||
|
||||
assert instance_id is not None
|
||||
assert step_number == 1
|
||||
assert user_feedback == []
|
||||
|
||||
# Verify instance was created in database
|
||||
instance = test_db.query(AgentInstance).filter_by(id=instance_id).first()
|
||||
assert instance is not None
|
||||
assert instance.status == AgentStatus.ACTIVE
|
||||
assert instance.user_id == test_user.id
|
||||
|
||||
# Step 2: Log another step
|
||||
_, step_number2, _ = process_log_step(
|
||||
db=test_db,
|
||||
user_id=str(test_user.id),
|
||||
agent_instance_id=instance_id,
|
||||
agent_type="Claude Code Test",
|
||||
step_description="Processing files",
|
||||
)
|
||||
|
||||
assert step_number2 == 2
|
||||
|
||||
# Step 3: Create a question
|
||||
question = await create_agent_question(
|
||||
db=test_db,
|
||||
agent_instance_id=instance_id,
|
||||
question_text="Should I refactor this module?",
|
||||
user_id=str(test_user.id),
|
||||
)
|
||||
|
||||
assert question is not None
|
||||
question_id = question.id
|
||||
|
||||
# Verify question in database
|
||||
question = test_db.query(AgentQuestion).filter_by(id=question_id).first()
|
||||
assert question is not None
|
||||
assert question.question_text == "Should I refactor this module?"
|
||||
assert question.is_active is True
|
||||
|
||||
# Step 4: Add user feedback
|
||||
feedback = AgentUserFeedback(
|
||||
id=uuid4(),
|
||||
agent_instance_id=instance_id,
|
||||
created_by_user_id=test_user.id,
|
||||
feedback_text="Please use async/await pattern",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
test_db.add(feedback)
|
||||
test_db.commit()
|
||||
|
||||
# Step 5: Next log_step should retrieve feedback
|
||||
_, step_number3, feedback_list = process_log_step(
|
||||
db=test_db,
|
||||
user_id=str(test_user.id),
|
||||
agent_instance_id=instance_id,
|
||||
agent_type="Claude Code Test",
|
||||
step_description="Implementing async pattern",
|
||||
)
|
||||
|
||||
assert step_number3 == 3
|
||||
assert len(feedback_list) == 1
|
||||
assert feedback_list[0] == "Please use async/await pattern"
|
||||
|
||||
# Verify feedback was marked as retrieved
|
||||
test_db.refresh(feedback)
|
||||
assert feedback.retrieved_at is not None
|
||||
|
||||
# Step 6: End the session
|
||||
ended_instance_id, final_status = process_end_session(
|
||||
db=test_db, agent_instance_id=instance_id, user_id=str(test_user.id)
|
||||
)
|
||||
|
||||
assert ended_instance_id == instance_id
|
||||
assert final_status == "completed"
|
||||
|
||||
# Verify final state
|
||||
test_db.refresh(instance)
|
||||
assert instance.status == AgentStatus.COMPLETED
|
||||
assert instance.ended_at is not None
|
||||
|
||||
# Verify questions were deactivated
|
||||
test_db.refresh(question)
|
||||
assert question.is_active is False
|
||||
|
||||
# Verify all steps were logged
|
||||
steps = (
|
||||
test_db.query(AgentStep)
|
||||
.filter_by(agent_instance_id=instance_id)
|
||||
.order_by(AgentStep.step_number)
|
||||
.all()
|
||||
)
|
||||
|
||||
assert len(steps) == 3
|
||||
assert steps[0].description == "Starting integration test task"
|
||||
assert steps[1].description == "Processing files"
|
||||
assert steps[2].description == "Implementing async pattern"
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_multiple_feedback_handling(self, test_db, test_user, test_user_agent):
|
||||
"""Test handling multiple feedback items."""
|
||||
# Create instance
|
||||
instance_id, _, _ = process_log_step(
|
||||
db=test_db,
|
||||
user_id=str(test_user.id),
|
||||
agent_instance_id=None,
|
||||
agent_type="Claude Code Test",
|
||||
step_description="Starting task",
|
||||
)
|
||||
|
||||
# Add multiple feedback items
|
||||
feedback_items = []
|
||||
for i in range(3):
|
||||
feedback = AgentUserFeedback(
|
||||
id=uuid4(),
|
||||
agent_instance_id=instance_id,
|
||||
created_by_user_id=test_user.id,
|
||||
feedback_text=f"Feedback {i + 1}",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
feedback_items.append(feedback)
|
||||
test_db.add(feedback)
|
||||
|
||||
test_db.commit()
|
||||
|
||||
# Next log_step should retrieve all feedback
|
||||
_, _, feedback_list = process_log_step(
|
||||
db=test_db,
|
||||
user_id=str(test_user.id),
|
||||
agent_instance_id=instance_id,
|
||||
agent_type="Claude Code Test",
|
||||
step_description="Processing feedback",
|
||||
)
|
||||
|
||||
assert len(feedback_list) == 3
|
||||
assert set(feedback_list) == {"Feedback 1", "Feedback 2", "Feedback 3"}
|
||||
|
||||
# All feedback should be marked as retrieved
|
||||
for feedback in feedback_items:
|
||||
test_db.refresh(feedback)
|
||||
assert feedback.retrieved_at is not None
|
||||
|
||||
# Subsequent log_step should not retrieve same feedback
|
||||
_, _, feedback_list2 = process_log_step(
|
||||
db=test_db,
|
||||
user_id=str(test_user.id),
|
||||
agent_instance_id=instance_id,
|
||||
agent_type="Claude Code Test",
|
||||
step_description="Continuing work",
|
||||
)
|
||||
|
||||
assert len(feedback_list2) == 0
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_user_agent_creation_and_reuse(self, test_db, test_user):
|
||||
"""Test that user agents are created and reused correctly."""
|
||||
# First call should create a new user agent
|
||||
instance1_id, _, _ = process_log_step(
|
||||
db=test_db,
|
||||
user_id=str(test_user.id),
|
||||
agent_instance_id=None,
|
||||
agent_type="New Agent Type",
|
||||
step_description="First task",
|
||||
)
|
||||
|
||||
# Check user agent was created (name is stored in lowercase)
|
||||
user_agents = (
|
||||
test_db.query(UserAgent)
|
||||
.filter_by(user_id=test_user.id, name="new agent type")
|
||||
.all()
|
||||
)
|
||||
assert len(user_agents) == 1
|
||||
|
||||
# Second call with same agent type should reuse the user agent
|
||||
instance2_id, _, _ = process_log_step(
|
||||
db=test_db,
|
||||
user_id=str(test_user.id),
|
||||
agent_instance_id=None,
|
||||
agent_type="New Agent Type",
|
||||
step_description="Second task",
|
||||
)
|
||||
|
||||
# Should still only have one user agent (name is stored in lowercase)
|
||||
user_agents = (
|
||||
test_db.query(UserAgent)
|
||||
.filter_by(user_id=test_user.id, name="new agent type")
|
||||
.all()
|
||||
)
|
||||
assert len(user_agents) == 1
|
||||
|
||||
# But two different instances
|
||||
assert instance1_id != instance2_id
|
||||
|
||||
# Both instances should reference the same user agent
|
||||
instance1 = test_db.query(AgentInstance).filter_by(id=instance1_id).first()
|
||||
instance2 = test_db.query(AgentInstance).filter_by(id=instance2_id).first()
|
||||
assert instance1.user_agent_id == instance2.user_agent_id
|
||||
@@ -3,12 +3,8 @@
|
||||
from datetime import datetime, timezone
|
||||
from uuid import uuid4
|
||||
|
||||
from shared.database.models import (
|
||||
AgentStep,
|
||||
AgentQuestion,
|
||||
AgentUserFeedback,
|
||||
)
|
||||
from shared.database.enums import AgentStatus
|
||||
from shared.database.models import Message
|
||||
from shared.database.enums import AgentStatus, SenderType
|
||||
|
||||
|
||||
class TestDatabaseModels:
|
||||
@@ -21,117 +17,125 @@ class TestDatabaseModels:
|
||||
assert test_agent_instance.started_at is not None
|
||||
assert test_agent_instance.ended_at is None
|
||||
|
||||
def test_create_agent_step(self, test_db, test_agent_instance):
|
||||
"""Test creating agent steps."""
|
||||
step1 = AgentStep(
|
||||
def test_create_agent_messages(self, test_db, test_agent_instance):
|
||||
"""Test creating agent messages."""
|
||||
message1 = Message(
|
||||
id=uuid4(),
|
||||
agent_instance_id=test_agent_instance.id,
|
||||
step_number=1,
|
||||
description="First step",
|
||||
sender_type=SenderType.AGENT,
|
||||
content="First step",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
requires_user_input=False,
|
||||
)
|
||||
|
||||
step2 = AgentStep(
|
||||
message2 = Message(
|
||||
id=uuid4(),
|
||||
agent_instance_id=test_agent_instance.id,
|
||||
step_number=2,
|
||||
description="Second step",
|
||||
sender_type=SenderType.AGENT,
|
||||
content="Second step",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
requires_user_input=False,
|
||||
)
|
||||
|
||||
test_db.add_all([step1, step2])
|
||||
test_db.add_all([message1, message2])
|
||||
test_db.commit()
|
||||
|
||||
# Query steps
|
||||
steps = (
|
||||
test_db.query(AgentStep)
|
||||
# Query messages
|
||||
messages = (
|
||||
test_db.query(Message)
|
||||
.filter_by(agent_instance_id=test_agent_instance.id)
|
||||
.order_by(AgentStep.step_number)
|
||||
.order_by(Message.created_at)
|
||||
.all()
|
||||
)
|
||||
|
||||
assert len(steps) == 2
|
||||
assert steps[0].step_number == 1
|
||||
assert steps[0].description == "First step"
|
||||
assert steps[1].step_number == 2
|
||||
assert steps[1].description == "Second step"
|
||||
assert len(messages) == 2
|
||||
assert messages[0].content == "First step"
|
||||
assert messages[1].content == "Second step"
|
||||
assert all(msg.sender_type == SenderType.AGENT for msg in messages)
|
||||
|
||||
def test_create_agent_question(self, test_db, test_agent_instance):
|
||||
"""Test creating agent questions."""
|
||||
question = AgentQuestion(
|
||||
def test_create_agent_question_message(self, test_db, test_agent_instance):
|
||||
"""Test creating agent question as a message."""
|
||||
question = Message(
|
||||
id=uuid4(),
|
||||
agent_instance_id=test_agent_instance.id,
|
||||
question_text="Should I continue?",
|
||||
asked_at=datetime.now(timezone.utc),
|
||||
is_active=True,
|
||||
sender_type=SenderType.AGENT,
|
||||
content="Should I continue?",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
requires_user_input=True,
|
||||
)
|
||||
|
||||
test_db.add(question)
|
||||
test_db.commit()
|
||||
|
||||
# Query question
|
||||
saved_question = test_db.query(AgentQuestion).filter_by(id=question.id).first()
|
||||
saved_question = test_db.query(Message).filter_by(id=question.id).first()
|
||||
|
||||
assert saved_question is not None
|
||||
assert saved_question.question_text == "Should I continue?"
|
||||
assert saved_question.is_active is True
|
||||
assert saved_question.answer_text is None
|
||||
assert saved_question.answered_at is None
|
||||
assert saved_question.content == "Should I continue?"
|
||||
assert saved_question.requires_user_input is True
|
||||
assert saved_question.sender_type == SenderType.AGENT
|
||||
|
||||
def test_create_user_feedback(self, test_db, test_agent_instance):
|
||||
"""Test creating user feedback."""
|
||||
feedback = AgentUserFeedback(
|
||||
def test_create_user_feedback_message(self, test_db, test_agent_instance):
|
||||
"""Test creating user feedback as a message."""
|
||||
feedback = Message(
|
||||
id=uuid4(),
|
||||
agent_instance_id=test_agent_instance.id,
|
||||
created_by_user_id=test_agent_instance.user_id,
|
||||
feedback_text="Please use TypeScript",
|
||||
sender_type=SenderType.USER,
|
||||
content="Please use TypeScript",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
requires_user_input=False,
|
||||
message_metadata={
|
||||
"source": "user_feedback",
|
||||
"created_by_user_id": str(test_agent_instance.user_id),
|
||||
},
|
||||
)
|
||||
|
||||
test_db.add(feedback)
|
||||
test_db.commit()
|
||||
|
||||
# Query feedback
|
||||
saved_feedback = (
|
||||
test_db.query(AgentUserFeedback).filter_by(id=feedback.id).first()
|
||||
)
|
||||
saved_feedback = test_db.query(Message).filter_by(id=feedback.id).first()
|
||||
|
||||
assert saved_feedback is not None
|
||||
assert saved_feedback.feedback_text == "Please use TypeScript"
|
||||
assert saved_feedback.retrieved_at is None
|
||||
assert saved_feedback.content == "Please use TypeScript"
|
||||
assert saved_feedback.sender_type == SenderType.USER
|
||||
assert saved_feedback.message_metadata["created_by_user_id"] == str(
|
||||
test_agent_instance.user_id
|
||||
)
|
||||
|
||||
def test_agent_instance_relationships(self, test_db, test_agent_instance):
|
||||
"""Test agent instance relationships."""
|
||||
# Add a step
|
||||
step = AgentStep(
|
||||
def test_agent_instance_message_relationships(self, test_db, test_agent_instance):
|
||||
"""Test agent instance message relationships."""
|
||||
# Add an agent message
|
||||
agent_msg = Message(
|
||||
id=uuid4(),
|
||||
agent_instance_id=test_agent_instance.id,
|
||||
step_number=1,
|
||||
description="Test step",
|
||||
sender_type=SenderType.AGENT,
|
||||
content="Test step",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
requires_user_input=False,
|
||||
)
|
||||
|
||||
# Add a question
|
||||
question = AgentQuestion(
|
||||
# Add a question message
|
||||
question_msg = Message(
|
||||
id=uuid4(),
|
||||
agent_instance_id=test_agent_instance.id,
|
||||
question_text="Test question?",
|
||||
asked_at=datetime.now(timezone.utc),
|
||||
is_active=True,
|
||||
sender_type=SenderType.AGENT,
|
||||
content="Test question?",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
requires_user_input=True,
|
||||
)
|
||||
|
||||
test_db.add_all([step, question])
|
||||
test_db.add_all([agent_msg, question_msg])
|
||||
test_db.commit()
|
||||
|
||||
# Refresh instance to load relationships
|
||||
test_db.refresh(test_agent_instance)
|
||||
|
||||
# Test relationships
|
||||
assert len(test_agent_instance.steps) == 1
|
||||
assert test_agent_instance.steps[0].description == "Test step"
|
||||
|
||||
assert len(test_agent_instance.questions) == 1
|
||||
assert test_agent_instance.questions[0].question_text == "Test question?"
|
||||
assert len(test_agent_instance.messages) == 2
|
||||
assert test_agent_instance.messages[0].content == "Test step"
|
||||
assert test_agent_instance.messages[1].content == "Test question?"
|
||||
assert test_agent_instance.messages[1].requires_user_input is True
|
||||
|
||||
|
||||
class TestAgentStatusTransitions:
|
||||
|
||||
293
shared/alembic/versions/40d4252deb5b_add_messages_table.py
Normal file
293
shared/alembic/versions/40d4252deb5b_add_messages_table.py
Normal file
@@ -0,0 +1,293 @@
|
||||
"""Add messages table
|
||||
|
||||
Revision ID: 40d4252deb5b
|
||||
Revises: dc285eabea90
|
||||
Create Date: 2025-07-31 11:25:30.567076
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "40d4252deb5b"
|
||||
down_revision: Union[str, None] = "dc285eabea90"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"messages",
|
||||
sa.Column("id", sa.UUID(), nullable=False),
|
||||
sa.Column("agent_instance_id", sa.UUID(), nullable=False),
|
||||
sa.Column(
|
||||
"sender_type", sa.Enum("AGENT", "USER", name="sendertype"), nullable=False
|
||||
),
|
||||
sa.Column("content", sa.Text(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(), nullable=False),
|
||||
sa.Column("requires_user_input", sa.Boolean(), nullable=False),
|
||||
sa.Column(
|
||||
"message_metadata", postgresql.JSONB(astext_type=sa.Text()), nullable=True
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["agent_instance_id"],
|
||||
["agent_instances.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"idx_messages_instance_created",
|
||||
"messages",
|
||||
["agent_instance_id", "created_at"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
# Migrate data from old tables to messages table
|
||||
# This combines AgentStep, AgentQuestion, and AgentUserFeedback into Messages
|
||||
# ordered by timestamp
|
||||
op.execute("""
|
||||
-- Insert agent steps as agent messages
|
||||
INSERT INTO messages (id, agent_instance_id, sender_type, content, created_at, requires_user_input, message_metadata)
|
||||
SELECT
|
||||
id,
|
||||
agent_instance_id,
|
||||
'AGENT',
|
||||
description,
|
||||
created_at,
|
||||
false,
|
||||
jsonb_build_object('source', 'agent_step', 'step_number', step_number)
|
||||
FROM agent_steps;
|
||||
|
||||
-- Insert agent questions as agent messages
|
||||
INSERT INTO messages (id, agent_instance_id, sender_type, content, created_at, requires_user_input, message_metadata)
|
||||
SELECT
|
||||
id,
|
||||
agent_instance_id,
|
||||
'AGENT',
|
||||
question_text,
|
||||
asked_at,
|
||||
CASE WHEN answer_text IS NULL THEN true ELSE false END,
|
||||
jsonb_build_object('source', 'agent_question')
|
||||
FROM agent_questions;
|
||||
|
||||
-- Insert question answers as user messages (only for answered questions)
|
||||
INSERT INTO messages (id, agent_instance_id, sender_type, content, created_at, requires_user_input, message_metadata)
|
||||
SELECT
|
||||
gen_random_uuid(),
|
||||
agent_instance_id,
|
||||
'USER',
|
||||
answer_text,
|
||||
answered_at,
|
||||
false,
|
||||
jsonb_build_object('source', 'question_answer', 'question_id', id::text, 'answered_by_user_id', answered_by_user_id::text)
|
||||
FROM agent_questions
|
||||
WHERE answer_text IS NOT NULL AND answered_at IS NOT NULL;
|
||||
|
||||
-- Insert user feedback as user messages
|
||||
INSERT INTO messages (id, agent_instance_id, sender_type, content, created_at, requires_user_input, message_metadata)
|
||||
SELECT
|
||||
id,
|
||||
agent_instance_id,
|
||||
'USER',
|
||||
feedback_text,
|
||||
created_at,
|
||||
false,
|
||||
jsonb_build_object('source', 'user_feedback', 'created_by_user_id', created_by_user_id::text)
|
||||
FROM agent_user_feedback;
|
||||
|
||||
-- Mark all agent instances as completed
|
||||
UPDATE agent_instances SET status = 'COMPLETED' WHERE status != 'COMPLETED';
|
||||
""")
|
||||
|
||||
# Add last_read_message_id to agent_instances
|
||||
op.add_column(
|
||||
"agent_instances",
|
||||
sa.Column("last_read_message_id", postgresql.UUID(as_uuid=True), nullable=True),
|
||||
)
|
||||
# Create the foreign key constraint after both tables exist to avoid circular dependency
|
||||
op.create_foreign_key(
|
||||
"fk_agent_instances_last_read_message",
|
||||
"agent_instances",
|
||||
"messages",
|
||||
["last_read_message_id"],
|
||||
["id"],
|
||||
ondelete="SET NULL",
|
||||
)
|
||||
|
||||
# Create combined function for message notifications
|
||||
op.execute("""
|
||||
CREATE OR REPLACE FUNCTION notify_message_change() RETURNS trigger AS $$
|
||||
DECLARE
|
||||
channel_name text;
|
||||
payload text;
|
||||
event_type text;
|
||||
BEGIN
|
||||
-- Create channel name based on instance ID
|
||||
channel_name := 'message_channel_' || NEW.agent_instance_id::text;
|
||||
|
||||
-- Determine event type
|
||||
IF TG_OP = 'INSERT' THEN
|
||||
event_type := 'message_insert';
|
||||
ELSIF TG_OP = 'UPDATE' THEN
|
||||
event_type := 'message_update';
|
||||
END IF;
|
||||
|
||||
-- Create JSON payload with message data
|
||||
payload := json_build_object(
|
||||
'event_type', event_type,
|
||||
'id', NEW.id,
|
||||
'agent_instance_id', NEW.agent_instance_id,
|
||||
'sender_type', NEW.sender_type,
|
||||
'content', NEW.content,
|
||||
'created_at', to_char(NEW.created_at AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.MS"Z"'),
|
||||
'requires_user_input', NEW.requires_user_input,
|
||||
'message_metadata', NEW.message_metadata,
|
||||
'old_requires_user_input', CASE
|
||||
WHEN TG_OP = 'UPDATE' THEN OLD.requires_user_input
|
||||
ELSE NULL
|
||||
END
|
||||
)::text;
|
||||
|
||||
-- Send notification (quote channel name for UUIDs with hyphens)
|
||||
EXECUTE format('NOTIFY %I, %L', channel_name, payload);
|
||||
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
""")
|
||||
|
||||
# Create single trigger for both INSERT and UPDATE on messages table
|
||||
op.execute("""
|
||||
CREATE TRIGGER message_change_notify
|
||||
AFTER INSERT OR UPDATE ON messages
|
||||
FOR EACH ROW
|
||||
EXECUTE FUNCTION notify_message_change();
|
||||
""")
|
||||
|
||||
# Create function for status change notifications
|
||||
op.execute("""
|
||||
CREATE OR REPLACE FUNCTION notify_status_change() RETURNS trigger AS $$
|
||||
DECLARE
|
||||
channel_name text;
|
||||
payload text;
|
||||
BEGIN
|
||||
-- Only notify if status actually changed
|
||||
IF OLD.status IS DISTINCT FROM NEW.status THEN
|
||||
-- Create channel name based on instance ID
|
||||
channel_name := 'message_channel_' || NEW.id::text;
|
||||
|
||||
-- Create JSON payload with status update data
|
||||
payload := json_build_object(
|
||||
'event_type', 'status_update',
|
||||
'instance_id', NEW.id,
|
||||
'status', NEW.status,
|
||||
'timestamp', to_char(NOW() AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.MS"Z"')
|
||||
)::text;
|
||||
|
||||
-- Send notification (quote channel name for UUIDs with hyphens)
|
||||
EXECUTE format('NOTIFY %I, %L', channel_name, payload);
|
||||
END IF;
|
||||
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
""")
|
||||
|
||||
# Create trigger on agent_instances table for status updates
|
||||
op.execute("""
|
||||
CREATE TRIGGER agent_instance_status_notify
|
||||
AFTER UPDATE OF status ON agent_instances
|
||||
FOR EACH ROW
|
||||
EXECUTE FUNCTION notify_status_change();
|
||||
""")
|
||||
|
||||
# Drop the old tables after migration
|
||||
op.drop_table("agent_user_feedback")
|
||||
op.drop_table("agent_questions")
|
||||
op.drop_table("agent_steps")
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
# Recreate the old tables first
|
||||
op.create_table(
|
||||
"agent_steps",
|
||||
sa.Column("id", sa.UUID(), nullable=False),
|
||||
sa.Column("agent_instance_id", sa.UUID(), nullable=False),
|
||||
sa.Column("step_number", sa.Integer(), nullable=False),
|
||||
sa.Column("description", sa.Text(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["agent_instance_id"],
|
||||
["agent_instances.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
op.create_table(
|
||||
"agent_questions",
|
||||
sa.Column("id", sa.UUID(), nullable=False),
|
||||
sa.Column("agent_instance_id", sa.UUID(), nullable=False),
|
||||
sa.Column("question_text", sa.Text(), nullable=False),
|
||||
sa.Column("answer_text", sa.Text(), nullable=True),
|
||||
sa.Column("answered_by_user_id", sa.UUID(), nullable=True),
|
||||
sa.Column("asked_at", sa.DateTime(), nullable=False),
|
||||
sa.Column("answered_at", sa.DateTime(), nullable=True),
|
||||
sa.Column("is_active", sa.Boolean(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["agent_instance_id"],
|
||||
["agent_instances.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["answered_by_user_id"],
|
||||
["users.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
op.create_table(
|
||||
"agent_user_feedback",
|
||||
sa.Column("id", sa.UUID(), nullable=False),
|
||||
sa.Column("agent_instance_id", sa.UUID(), nullable=False),
|
||||
sa.Column("created_by_user_id", sa.UUID(), nullable=False),
|
||||
sa.Column("feedback_text", sa.Text(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(), nullable=False),
|
||||
sa.Column("retrieved_at", sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["agent_instance_id"],
|
||||
["agent_instances.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["created_by_user_id"],
|
||||
["users.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
# Note: We cannot fully restore the original data as some information is lost
|
||||
# (e.g., step_number ordering, which user created feedback vs answered questions)
|
||||
# This is a best-effort restoration
|
||||
|
||||
# Drop status change trigger and function
|
||||
op.execute(
|
||||
"DROP TRIGGER IF EXISTS agent_instance_status_notify ON agent_instances;"
|
||||
)
|
||||
op.execute("DROP FUNCTION IF EXISTS notify_status_change();")
|
||||
|
||||
# Drop message trigger and function
|
||||
op.execute("DROP TRIGGER IF EXISTS message_change_notify ON messages;")
|
||||
op.execute("DROP FUNCTION IF EXISTS notify_message_change();")
|
||||
|
||||
op.drop_constraint(
|
||||
"fk_agent_instances_last_read_message", "agent_instances", type_="foreignkey"
|
||||
)
|
||||
op.drop_column("agent_instances", "last_read_message_id")
|
||||
op.drop_index("idx_messages_instance_created", table_name="messages")
|
||||
op.drop_table("messages")
|
||||
# ### end Alembic commands ###
|
||||
@@ -1,11 +1,9 @@
|
||||
from .enums import AgentStatus
|
||||
from .enums import AgentStatus, SenderType
|
||||
from .models import (
|
||||
AgentInstance,
|
||||
AgentQuestion,
|
||||
AgentStep,
|
||||
AgentUserFeedback,
|
||||
APIKey,
|
||||
Base,
|
||||
Message,
|
||||
PushToken,
|
||||
User,
|
||||
UserAgent,
|
||||
@@ -20,12 +18,11 @@ __all__ = [
|
||||
"User",
|
||||
"UserAgent",
|
||||
"AgentInstance",
|
||||
"AgentStep",
|
||||
"AgentQuestion",
|
||||
"AgentStatus",
|
||||
"AgentUserFeedback",
|
||||
"APIKey",
|
||||
"Message",
|
||||
"PushToken",
|
||||
"SenderType",
|
||||
"Subscription",
|
||||
"BillingEvent",
|
||||
]
|
||||
|
||||
@@ -2,11 +2,16 @@ from enum import Enum
|
||||
|
||||
|
||||
class AgentStatus(str, Enum):
|
||||
ACTIVE = "active"
|
||||
AWAITING_INPUT = "awaiting_input"
|
||||
PAUSED = "paused"
|
||||
STALE = "stale"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
KILLED = "killed"
|
||||
DISCONNECTED = "disconnected"
|
||||
ACTIVE = "ACTIVE"
|
||||
AWAITING_INPUT = "AWAITING_INPUT"
|
||||
PAUSED = "PAUSED"
|
||||
STALE = "STALE"
|
||||
COMPLETED = "COMPLETED"
|
||||
FAILED = "FAILED"
|
||||
KILLED = "KILLED"
|
||||
DISCONNECTED = "DISCONNECTED"
|
||||
|
||||
|
||||
class SenderType(str, Enum):
|
||||
AGENT = "AGENT"
|
||||
USER = "USER"
|
||||
|
||||
@@ -3,7 +3,7 @@ from uuid import UUID, uuid4
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import ForeignKey, Index, String, Text, UniqueConstraint
|
||||
from sqlalchemy.dialects.postgresql import UUID as PostgresUUID
|
||||
from sqlalchemy.dialects.postgresql import UUID as PostgresUUID, JSONB
|
||||
from sqlalchemy.orm import (
|
||||
DeclarativeBase, # type: ignore[attr-defined]
|
||||
Mapped, # type: ignore[attr-defined]
|
||||
@@ -12,7 +12,7 @@ from sqlalchemy.orm import (
|
||||
validates,
|
||||
)
|
||||
|
||||
from .enums import AgentStatus
|
||||
from .enums import AgentStatus, SenderType
|
||||
from .utils import is_valid_git_diff
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -57,12 +57,6 @@ class User(Base):
|
||||
agent_instances: Mapped[list["AgentInstance"]] = relationship(
|
||||
"AgentInstance", back_populates="user"
|
||||
)
|
||||
answered_questions: Mapped[list["AgentQuestion"]] = relationship(
|
||||
"AgentQuestion", back_populates="answered_by_user"
|
||||
)
|
||||
feedback: Mapped[list["AgentUserFeedback"]] = relationship(
|
||||
"AgentUserFeedback", back_populates="created_by_user"
|
||||
)
|
||||
api_keys: Mapped[list["APIKey"]] = relationship("APIKey", back_populates="user")
|
||||
user_agents: Mapped[list["UserAgent"]] = relationship(
|
||||
"UserAgent", back_populates="user"
|
||||
@@ -130,22 +124,33 @@ class AgentInstance(Base):
|
||||
)
|
||||
ended_at: Mapped[datetime | None] = mapped_column(default=None)
|
||||
git_diff: Mapped[str | None] = mapped_column(Text, default=None)
|
||||
last_read_message_id: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey(
|
||||
"messages.id",
|
||||
use_alter=True,
|
||||
name="fk_agent_instances_last_read_message",
|
||||
ondelete="SET NULL",
|
||||
),
|
||||
type_=PostgresUUID(as_uuid=True),
|
||||
default=None,
|
||||
)
|
||||
|
||||
# Relationships
|
||||
user_agent: Mapped["UserAgent"] = relationship(
|
||||
"UserAgent", back_populates="instances"
|
||||
)
|
||||
user: Mapped["User"] = relationship("User", back_populates="agent_instances")
|
||||
steps: Mapped[list["AgentStep"]] = relationship(
|
||||
"AgentStep", back_populates="instance", order_by="AgentStep.created_at"
|
||||
)
|
||||
questions: Mapped[list["AgentQuestion"]] = relationship(
|
||||
"AgentQuestion", back_populates="instance", order_by="AgentQuestion.asked_at"
|
||||
)
|
||||
user_feedback: Mapped[list["AgentUserFeedback"]] = relationship(
|
||||
"AgentUserFeedback",
|
||||
messages: Mapped[list["Message"]] = relationship(
|
||||
"Message",
|
||||
back_populates="instance",
|
||||
order_by="AgentUserFeedback.created_at",
|
||||
order_by="Message.created_at",
|
||||
foreign_keys="Message.agent_instance_id",
|
||||
)
|
||||
last_read_message: Mapped["Message | None"] = relationship(
|
||||
"Message",
|
||||
foreign_keys=[last_read_message_id],
|
||||
post_update=True,
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
@validates("git_diff")
|
||||
@@ -154,7 +159,7 @@ class AgentInstance(Base):
|
||||
|
||||
Raises ValueError if the git diff is invalid.
|
||||
"""
|
||||
if value is None:
|
||||
if value is None or value == "":
|
||||
return value
|
||||
|
||||
if not is_valid_git_diff(value):
|
||||
@@ -163,81 +168,6 @@ class AgentInstance(Base):
|
||||
return value
|
||||
|
||||
|
||||
class AgentStep(Base):
|
||||
__tablename__ = "agent_steps"
|
||||
|
||||
id: Mapped[UUID] = mapped_column(
|
||||
PostgresUUID(as_uuid=True), primary_key=True, default=uuid4
|
||||
)
|
||||
agent_instance_id: Mapped[UUID] = mapped_column(
|
||||
ForeignKey("agent_instances.id"), type_=PostgresUUID(as_uuid=True)
|
||||
)
|
||||
step_number: Mapped[int] = mapped_column()
|
||||
description: Mapped[str] = mapped_column(Text)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
default=lambda: datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
# Relationships
|
||||
instance: Mapped["AgentInstance"] = relationship(
|
||||
"AgentInstance", back_populates="steps"
|
||||
)
|
||||
|
||||
|
||||
class AgentQuestion(Base):
|
||||
__tablename__ = "agent_questions"
|
||||
|
||||
id: Mapped[UUID] = mapped_column(
|
||||
PostgresUUID(as_uuid=True), primary_key=True, default=uuid4
|
||||
)
|
||||
agent_instance_id: Mapped[UUID] = mapped_column(
|
||||
ForeignKey("agent_instances.id"), type_=PostgresUUID(as_uuid=True)
|
||||
)
|
||||
question_text: Mapped[str] = mapped_column(Text)
|
||||
answer_text: Mapped[str | None] = mapped_column(Text, default=None)
|
||||
answered_by_user_id: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey("users.id"), type_=PostgresUUID(as_uuid=True), default=None
|
||||
)
|
||||
asked_at: Mapped[datetime] = mapped_column(
|
||||
default=lambda: datetime.now(timezone.utc)
|
||||
)
|
||||
answered_at: Mapped[datetime | None] = mapped_column(default=None)
|
||||
is_active: Mapped[bool] = mapped_column(default=True)
|
||||
|
||||
# Relationships
|
||||
instance: Mapped["AgentInstance"] = relationship(
|
||||
"AgentInstance", back_populates="questions"
|
||||
)
|
||||
answered_by_user: Mapped["User | None"] = relationship(
|
||||
"User", back_populates="answered_questions"
|
||||
)
|
||||
|
||||
|
||||
class AgentUserFeedback(Base):
|
||||
__tablename__ = "agent_user_feedback"
|
||||
|
||||
id: Mapped[UUID] = mapped_column(
|
||||
PostgresUUID(as_uuid=True), primary_key=True, default=uuid4
|
||||
)
|
||||
agent_instance_id: Mapped[UUID] = mapped_column(
|
||||
ForeignKey("agent_instances.id"), type_=PostgresUUID(as_uuid=True)
|
||||
)
|
||||
created_by_user_id: Mapped[UUID] = mapped_column(
|
||||
ForeignKey("users.id"), type_=PostgresUUID(as_uuid=True)
|
||||
)
|
||||
feedback_text: Mapped[str] = mapped_column(Text)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
default=lambda: datetime.now(timezone.utc)
|
||||
)
|
||||
retrieved_at: Mapped[datetime | None] = mapped_column(default=None)
|
||||
|
||||
# Relationships
|
||||
instance: Mapped["AgentInstance"] = relationship(
|
||||
"AgentInstance", back_populates="user_feedback"
|
||||
)
|
||||
created_by_user: Mapped["User"] = relationship("User", back_populates="feedback")
|
||||
|
||||
|
||||
class APIKey(Base):
|
||||
__tablename__ = "api_keys"
|
||||
|
||||
@@ -286,3 +216,31 @@ class PushToken(Base):
|
||||
|
||||
# Relationships
|
||||
user: Mapped["User"] = relationship("User", back_populates="push_tokens")
|
||||
|
||||
|
||||
class Message(Base):
|
||||
__tablename__ = "messages"
|
||||
__table_args__ = (
|
||||
Index("idx_messages_instance_created", "agent_instance_id", "created_at"),
|
||||
)
|
||||
|
||||
id: Mapped[UUID] = mapped_column(
|
||||
PostgresUUID(as_uuid=True), primary_key=True, default=uuid4
|
||||
)
|
||||
agent_instance_id: Mapped[UUID] = mapped_column(
|
||||
ForeignKey("agent_instances.id"), type_=PostgresUUID(as_uuid=True)
|
||||
)
|
||||
sender_type: Mapped[SenderType] = mapped_column()
|
||||
content: Mapped[str] = mapped_column(Text)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
default=lambda: datetime.now(timezone.utc)
|
||||
)
|
||||
requires_user_input: Mapped[bool] = mapped_column(default=False)
|
||||
message_metadata: Mapped[dict | None] = mapped_column(JSONB, default=None)
|
||||
|
||||
# Relationships
|
||||
instance: Mapped["AgentInstance"] = relationship(
|
||||
"AgentInstance",
|
||||
back_populates="messages",
|
||||
foreign_keys=[agent_instance_id],
|
||||
)
|
||||
|
||||
@@ -339,6 +339,7 @@ class WebhookRequest(BaseModel):
|
||||
prompt: str
|
||||
name: str | None = None # Branch name
|
||||
worktree_name: str | None = None
|
||||
agent_type: str | None = None # Agent type name
|
||||
|
||||
@field_validator("agent_instance_id")
|
||||
def validate_instance_id(cls, v):
|
||||
@@ -589,6 +590,7 @@ async def start_claude(
|
||||
prompt = webhook_data.prompt
|
||||
worktree_name = webhook_data.worktree_name
|
||||
branch_name = webhook_data.name
|
||||
agent_type = webhook_data.agent_type
|
||||
|
||||
print("\n[INFO] Received webhook request:")
|
||||
print(f" - Instance ID: {agent_instance_id}")
|
||||
@@ -806,21 +808,26 @@ async def start_claude(
|
||||
mcp_config = {
|
||||
"mcpServers": {
|
||||
"omnara": {
|
||||
"command": "pipx",
|
||||
"command": "omnara",
|
||||
"args": [
|
||||
"run",
|
||||
"--no-cache",
|
||||
"omnara",
|
||||
"--api-key",
|
||||
omnara_api_key,
|
||||
"--claude-code-permission-tool",
|
||||
"--git-diff",
|
||||
"--agent-instance-id",
|
||||
agent_instance_id,
|
||||
"--base-url",
|
||||
"http://localhost:8080",
|
||||
],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Add environment variable for agent type if provided
|
||||
if agent_type:
|
||||
mcp_config["mcpServers"]["omnara"]["env"] = {
|
||||
"OMNARA_CLIENT_TYPE": agent_type
|
||||
}
|
||||
mcp_config_str = json.dumps(mcp_config)
|
||||
|
||||
# Build claude command with MCP config as string
|
||||
|
||||
1247
webhooks/claude_wrapper_v3.py
Executable file
1247
webhooks/claude_wrapper_v3.py
Executable file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user