Merge remote-tracking branch 'refs/remotes/origin/main'

This commit is contained in:
ishaansehgal99
2025-07-14 10:58:57 -07:00
12 changed files with 119 additions and 43 deletions

View File

@@ -69,7 +69,7 @@ def log_step(
@agent_router.post("/questions", response_model=AskQuestionResponse)
def ask_question(
async def ask_question(
request: AskQuestionRequest, user_id: Annotated[str, Depends(get_current_user_id)]
) -> AskQuestionResponse:
"""Create a question for the user to answer.
@@ -85,7 +85,7 @@ def ask_question(
try:
# Use shared business logic to create question
question = create_agent_question(
question = await create_agent_question(
db=db,
agent_instance_id=request.agent_instance_id,
question_text=request.question_text,

View File

@@ -99,7 +99,7 @@ async def ask_question_impl(
try:
# Use shared business logic to create question
question = create_agent_question(
question = await create_agent_question(
db=db,
agent_instance_id=agent_instance_id,
question_text=question_text,

View File

@@ -81,7 +81,7 @@ def process_log_step(
return str(instance.id), step.step_number, feedback
def create_agent_question(
async def create_agent_question(
db: Session,
agent_instance_id: str,
question_text: str,
@@ -103,7 +103,7 @@ def create_agent_question(
# Create question
# Note: Push notification sent by create_question() function
question = create_question(db, instance.id, question_text)
question = await create_question(db, instance.id, question_text)
return question

View File

@@ -85,7 +85,7 @@ def log_step(db: Session, instance_id: UUID, description: str) -> AgentStep:
return step
def create_question(
async def create_question(
db: Session, instance_id: UUID, question_text: str
) -> AgentQuestion:
"""Create a new question for an agent instance"""
@@ -115,7 +115,7 @@ def create_question(
if instance:
agent_name = instance.user_agent.name if instance.user_agent else "Agent"
push_service.send_question_notification(
await push_service.send_question_notification(
db=db,
user_id=instance.user_id,
instance_id=str(instance.id),

View File

@@ -1,7 +1,7 @@
"""Push notification service using Expo Push API"""
import asyncio
import logging
import time
from datetime import datetime, timezone
from typing import Dict, Any, Optional
from uuid import UUID
@@ -13,6 +13,7 @@ from exponent_server_sdk import (
PushTicketError,
DeviceNotRegisteredError,
)
import requests.exceptions
from shared.database import PushToken
@@ -25,7 +26,7 @@ class PushNotificationService:
def __init__(self):
self.client = PushClient()
def send_notification(
async def send_notification(
self,
db: Session,
user_id: UUID,
@@ -89,6 +90,11 @@ class PushNotificationService:
max_retries = 3
for attempt in range(max_retries):
try:
if attempt > 0:
logger.info(
f"Push notification retry attempt {attempt + 1} of {max_retries}"
)
# Send messages in batches (Expo recommends max 100 per batch)
for chunk in self._chunks(messages, 100):
response = self.client.publish_multiple(chunk)
@@ -108,19 +114,6 @@ class PushNotificationService:
)
return True
except (PushServerError, ConnectionError) as e:
if attempt < max_retries - 1:
wait_time = 2**attempt # Exponential backoff: 1s, 2s, 4s
logger.warning(
f"Push notification attempt {attempt + 1} failed, retrying in {wait_time}s: {str(e)}"
)
time.sleep(wait_time)
continue
else:
logger.error(
f"Push server error after {max_retries} attempts: {str(e)}"
)
return False
except DeviceNotRegisteredError as e:
logger.error(f"Device not registered, deactivating token: {str(e)}")
# Mark token as inactive
@@ -133,15 +126,58 @@ class PushNotificationService:
except PushTicketError as e:
logger.error(f"Push ticket error: {str(e)}")
return False
except Exception as e:
# Check if this is a connection-related error that should be retried
# This includes ConnectionError, OSError, requests.exceptions.RequestException, etc.
# We check the exception type and its base classes
error_type = type(e)
is_connection_error = (
isinstance(
e,
(
ConnectionError,
OSError,
requests.exceptions.RequestException,
PushServerError,
),
)
or any(
issubclass(error_type, exc_type)
for exc_type in [ConnectionError, OSError]
)
or "connection" in str(e).lower()
or "reset" in str(e).lower()
)
if is_connection_error and attempt < max_retries - 1:
wait_time = 2**attempt # Exponential backoff: 1s, 2s, 4s
logger.warning(
f"Push notification attempt {attempt + 1} failed, retrying in {wait_time}s: {type(e).__name__}: {e}"
)
await asyncio.sleep(wait_time)
continue
elif is_connection_error:
logger.error(
f"Push server error after {max_retries} attempts: {type(e).__name__}: {e}"
)
return False
else:
# Non-connection error, don't retry
logger.error(
f"Unexpected error sending push notification: {type(e).__name__}: {e}"
)
return False
# If we get here, all retry attempts were exhausted
return False
except Exception as e:
logger.error(f"Error sending push notification: {str(e)}")
logger.error(
f"Unexpected error in send_notification: {type(e).__name__}: {e}"
)
return False
def send_question_notification(
async def send_question_notification(
self,
db: Session,
user_id: UUID,
@@ -166,7 +202,7 @@ class PushNotificationService:
"questionId": question_id,
}
return self.send_notification(
return await self.send_notification(
db=db,
user_id=user_id,
title=title,

View File

@@ -63,7 +63,10 @@ class TestIntegrationFlow:
"""Test the complete integration flow with PostgreSQL."""
@pytest.mark.integration
def test_complete_agent_session_flow(self, test_db, test_user, test_user_agent):
@pytest.mark.asyncio
async def test_complete_agent_session_flow(
self, test_db, test_user, test_user_agent
):
"""Test a complete agent session from start to finish."""
# Step 1: Create new agent instance
instance_id, step_number, user_feedback = process_log_step(
@@ -96,7 +99,7 @@ class TestIntegrationFlow:
assert step_number2 == 2
# Step 3: Create a question
question = create_agent_question(
question = await create_agent_question(
db=test_db,
agent_instance_id=instance_id,
question_text="Should I refactor this module?",