diff --git a/app/src/server/api/routers/loggedCalls.router.ts b/app/src/server/api/routers/loggedCalls.router.ts index f9602a9..d23f3c1 100644 --- a/app/src/server/api/routers/loggedCalls.router.ts +++ b/app/src/server/api/routers/loggedCalls.router.ts @@ -7,13 +7,19 @@ import { kysely, prisma } from "~/server/db"; import { comparators } from "~/state/logFiltersSlice"; import { requireCanViewProject } from "~/utils/accessControl"; -const defaultFilterableFields = ["Request", "Response", "Model"]; +const defaultFilterableFields = ["Request", "Response", "Model", "Status Code"]; -const comparatorToSql = { - "=": "=", - "!=": "!=", - CONTAINS: "like", -} as const; +// create comparator type based off of comparators +const comparatorToSqlValue = (comparator: (typeof comparators)[number], value: string) => { + switch (comparator) { + case "=": + return `= '${value}'`; + case "!=": + return `!= '${value}'`; + case "CONTAINS": + return `like '%${value}%'`; + } +}; export const loggedCallsRouter = createTRPCRouter({ list: protectedProcedure @@ -45,7 +51,39 @@ export const loggedCallsRouter = createTRPCRouter({ for (const filter of input.filters) { if (!filter.value) continue; if (filter.field === "Request") { - wheres.push(sql.raw(`lcmr."reqPayload"::text like '%${filter.value}%'`)); + wheres.push( + sql.raw( + `lcmr."reqPayload"::text ${comparatorToSqlValue( + filter.comparator, + filter.value, + )}`, + ), + ); + } + if (filter.field === "Response") { + wheres.push( + sql.raw( + `lcmr."respPayload"::text ${comparatorToSqlValue( + filter.comparator, + filter.value, + )}`, + ), + ); + } + if (filter.field === "Model") { + wheres.push( + sql.raw(`lc."model" ${comparatorToSqlValue(filter.comparator, filter.value)}`), + ); + } + if (filter.field === "Status Code") { + wheres.push( + sql.raw( + `lcmr."statusCode"::text ${comparatorToSqlValue( + filter.comparator, + filter.value, + )}`, + ), + ); } } @@ -53,7 +91,7 @@ export const loggedCallsRouter = createTRPCRouter({ }); const rawCalls = await baseQuery - .select((eb) => [ + .select([ "lc.id as id", "lc.requestedAt as requestedAt", "model", @@ -96,43 +134,6 @@ export const loggedCallsRouter = createTRPCRouter({ const count = matchingLogIds.length; return { calls, count, matchingLogIds: matchingLogIds.map((log) => log.id) }; - - // const whereClauses: Prisma.LoggedCallWhereInput[] = [{ projectId }]; - - // for (const filter of input.filters) { - // if (!filter.value) continue; - // if (filter.field === "Request") { - // console.log("filter.value is", filter.value); - // whereClauses.push({ - // modelResponse: { - // is: { - // reqPayload: { - // string_contains: filter.value, - // }, - // }, - // }, - // }); - // } - // } - - // const calls = await prisma.loggedCall.findMany({ - // where: { AND: whereClauses }, - // orderBy: { requestedAt: "desc" }, - // include: { tags: true, modelResponse: true }, - // skip: (page - 1) * pageSize, - // take: pageSize, - // }); - - // const matchingLogs = await prisma.loggedCall.findMany({ - // where: { AND: whereClauses }, - // select: { id: true }, - // }); - - // const count = await prisma.loggedCall.count({ - // where: { AND: whereClauses }, - // }); - - // return { calls, count, matchingLogIds: matchingLogs.map((log) => log.id) }; }), getFilterableFields: protectedProcedure .input(z.object({ projectId: z.string() }))