Co-authored-by: Kartik Sarangmath <kartiksarangmath@Kartiks-MacBook-Air.local>
This commit is contained in:
ksarangmath
2025-08-07 23:45:31 -07:00
committed by GitHub
parent fdf74a129f
commit 7b4c55b053
3 changed files with 93 additions and 42 deletions

View File

@@ -15,7 +15,7 @@ from shared.database import (
)
from shared.database.billing_operations import get_or_create_subscription
from shared.database.subscription_models import BillingEvent, Subscription
from sqlalchemy import and_, desc, func
from sqlalchemy import case, desc, func
from sqlalchemy.orm import Session, joinedload, subqueryload
# Import Pydantic models for type-safe returns
@@ -76,44 +76,48 @@ def get_all_agent_types_with_instances(
for instance in user_agent.instances:
all_instance_ids.append(instance.id)
# Get message stats for ALL instances in a single query
# Get message stats for ALL instances in a simpler, more efficient query
message_stats = {}
if all_instance_ids:
# Subquery to get latest message per instance
latest_messages_subq = (
# Use window functions to get both count and latest message in one query
# This leverages our new index on (agent_instance_id, created_at)
# Subquery with row_number to identify the latest message per instance
latest_msg_cte = (
db.query(
Message.agent_instance_id,
func.count(Message.id).label("msg_count"),
func.max(Message.created_at).label("latest_at"),
Message.content,
Message.created_at,
func.row_number()
.over(
partition_by=Message.agent_instance_id,
order_by=desc(Message.created_at),
)
.label("rn"),
func.count(Message.id)
.over(partition_by=Message.agent_instance_id)
.label("msg_count"),
)
.filter(Message.agent_instance_id.in_(all_instance_ids))
.group_by(Message.agent_instance_id)
.subquery()
)
# Join to get the actual latest message content
# Get only the latest message (rn=1) with counts
stats_results = (
db.query(
latest_messages_subq.c.agent_instance_id,
latest_messages_subq.c.msg_count,
latest_messages_subq.c.latest_at,
Message.content,
)
.outerjoin(
Message,
and_(
Message.agent_instance_id
== latest_messages_subq.c.agent_instance_id,
Message.created_at == latest_messages_subq.c.latest_at,
),
latest_msg_cte.c.agent_instance_id,
latest_msg_cte.c.content,
latest_msg_cte.c.created_at,
latest_msg_cte.c.msg_count,
)
.filter(latest_msg_cte.c.rn == 1)
.all()
)
for row in stats_results:
message_stats[row.agent_instance_id] = {
"count": row.msg_count or 0,
"latest_at": row.latest_at,
"latest_at": row.created_at,
"latest_content": row.content,
}
@@ -219,29 +223,30 @@ def get_all_agent_instances(
def get_agent_summary(db: Session, user_id: UUID) -> dict:
"""Get lightweight summary of agent counts without fetching detailed instance data"""
# Count total instances
total_instances = (
db.query(AgentInstance).filter(AgentInstance.user_id == user_id).count()
# Single query to get all counts using conditional aggregation
stats = (
db.query(
func.count(AgentInstance.id).label("total"),
func.count(case((AgentInstance.status == AgentStatus.ACTIVE, 1))).label(
"active"
),
func.count(case((AgentInstance.status == AgentStatus.COMPLETED, 1))).label(
"completed"
),
)
.filter(AgentInstance.user_id == user_id)
.first()
)
# Count active instances (only 'active' for now until DB enum is updated)
active_instances = (
db.query(AgentInstance)
.filter(
AgentInstance.user_id == user_id, AgentInstance.status == AgentStatus.ACTIVE
)
.count()
)
# Count completed instances
completed_instances = (
db.query(AgentInstance)
.filter(
AgentInstance.user_id == user_id,
AgentInstance.status == AgentStatus.COMPLETED,
)
.count()
)
# Handle the case where stats might be None (though COUNT queries always return a row)
if stats:
total_instances = stats.total or 0
active_instances = stats.active or 0
completed_instances = stats.completed or 0
else:
total_instances = 0
active_instances = 0
completed_instances = 0
# Count by user agent and status (for fleet overview)
# Get instances with their user agents

View File

@@ -0,0 +1,42 @@
"""Add performance indexes to agent_instances table
Revision ID: 20de0aa419ca
Revises: 9fe045ea7ad9
Create Date: 2025-08-07 23:20:32.457833
"""
from typing import Sequence, Union
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "20de0aa419ca"
down_revision: Union[str, None] = "9fe045ea7ad9"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_index(
"idx_agent_instances_user_agent_id",
"agent_instances",
["user_agent_id"],
unique=False,
)
op.create_index(
"idx_agent_instances_user_status",
"agent_instances",
["user_id", "status"],
unique=False,
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index("idx_agent_instances_user_status", table_name="agent_instances")
op.drop_index("idx_agent_instances_user_agent_id", table_name="agent_instances")
# ### end Alembic commands ###

View File

@@ -108,6 +108,10 @@ class UserAgent(Base):
class AgentInstance(Base):
__tablename__ = "agent_instances"
__table_args__ = (
Index("idx_agent_instances_user_agent_id", "user_agent_id"),
Index("idx_agent_instances_user_status", "user_id", "status"),
)
id: Mapped[UUID] = mapped_column(
PostgresUUID(as_uuid=True), primary_key=True, default=uuid4