Display tag values

This commit is contained in:
David Corbitt
2023-08-15 02:32:05 -07:00
parent 9636fa033e
commit 3547c85c86
7 changed files with 70 additions and 58 deletions

View File

@@ -1,14 +1,12 @@
import { z } from "zod";
import { type Expression, type SqlBool } from "kysely";
import { sql } from "kysely";
import { type Expression, type SqlBool, sql } from "kysely";
import { jsonArrayFrom } from "kysely/helpers/postgres";
import { createTRPCRouter, protectedProcedure } from "~/server/api/trpc";
import { kysely, prisma } from "~/server/db";
import { comparators } from "~/state/logFiltersSlice";
import { comparators, defaultFilterableFields } from "~/state/logFiltersSlice";
import { requireCanViewProject } from "~/utils/accessControl";
const defaultFilterableFields = ["Request", "Response", "Model", "Status Code"];
// create comparator type based off of comparators
const comparatorToSqlValue = (comparator: (typeof comparators)[number], value: string) => {
switch (comparator) {
@@ -91,7 +89,10 @@ export const loggedCallsRouter = createTRPCRouter({
});
const tagFilters = input.filters.filter(
(filter) => !defaultFilterableFields.includes(filter.field),
(filter) =>
!defaultFilterableFields.includes(
filter.field as (typeof defaultFilterableFields)[number],
),
);
let updatedBaseQuery = baseQuery;
@@ -112,7 +113,7 @@ export const loggedCallsRouter = createTRPCRouter({
}
const rawCalls = await updatedBaseQuery
.select([
.select((eb) => [
"lc.id as id",
"lc.requestedAt as requestedAt",
"model",
@@ -127,28 +128,45 @@ export const loggedCallsRouter = createTRPCRouter({
"cost",
"statusCode",
"durationMs",
jsonArrayFrom(
eb
.selectFrom("LoggedCallTag")
.select(["name", "value"])
.whereRef("loggedCallId", "=", "lc.id"),
).as("tags"),
])
.orderBy("lc.requestedAt", "desc")
.limit(pageSize)
.offset((page - 1) * pageSize)
.execute();
const calls = rawCalls.map((rawCall) => ({
id: rawCall.id,
requestedAt: rawCall.requestedAt,
model: rawCall.model,
cacheHit: rawCall.cacheHit,
modelResponse: {
receivedAt: rawCall.receivedAt,
reqPayload: rawCall.reqPayload,
respPayload: rawCall.respPayload,
inputTokens: rawCall.inputTokens,
outputTokens: rawCall.outputTokens,
cost: rawCall.cost,
statusCode: rawCall.statusCode,
durationMs: rawCall.durationMs,
},
}));
const calls = rawCalls.map((rawCall) => {
const tagsObject = rawCall.tags.reduce(
(acc, tag) => {
acc[tag.name] = tag.value;
return acc;
},
{} as Record<string, string | null>,
);
return {
id: rawCall.id,
requestedAt: rawCall.requestedAt,
model: rawCall.model,
cacheHit: rawCall.cacheHit,
modelResponse: {
receivedAt: rawCall.receivedAt,
reqPayload: rawCall.reqPayload,
respPayload: rawCall.respPayload,
inputTokens: rawCall.inputTokens,
outputTokens: rawCall.outputTokens,
cost: rawCall.cost,
statusCode: rawCall.statusCode,
durationMs: rawCall.durationMs,
},
tags: tagsObject,
};
});
const matchingLogIds = await updatedBaseQuery.select(["lc.id"]).execute();
@@ -156,7 +174,7 @@ export const loggedCallsRouter = createTRPCRouter({
return { calls, count, matchingLogIds: matchingLogIds.map((log) => log.id) };
}),
getFilterableFields: protectedProcedure
getTagNames: protectedProcedure
.input(z.object({ projectId: z.string() }))
.query(async ({ input, ctx }) => {
await requireCanViewProject(input.projectId, ctx);
@@ -169,8 +187,11 @@ export const loggedCallsRouter = createTRPCRouter({
select: {
name: true,
},
orderBy: {
name: "asc",
},
});
return [...defaultFilterableFields, ...tags.map((tag) => tag.name)];
return tags.map((tag) => tag.name);
}),
});