diff --git a/app/src/components/dashboard/LoggedCallTable.tsx b/app/src/components/dashboard/LoggedCallTable.tsx
index 70aa7a2..10ed790 100644
--- a/app/src/components/dashboard/LoggedCallTable.tsx
+++ b/app/src/components/dashboard/LoggedCallTable.tsx
@@ -90,9 +90,9 @@ function TableRow({
isExpanded: boolean;
onToggle: () => void;
}) {
- const isError = loggedCall.modelResponse?.respStatus !== 200;
- const timeAgo = dayjs(loggedCall.startTime).fromNow();
- const fullTime = dayjs(loggedCall.startTime).toString();
+ const isError = loggedCall.modelResponse?.statusCode !== 200;
+ const timeAgo = dayjs(loggedCall.requestedAt).fromNow();
+ const fullTime = dayjs(loggedCall.requestedAt).toString();
const model = useMemo(
() => loggedCall.tags.find((tag) => tag.name.startsWith("$model"))?.value,
@@ -124,7 +124,7 @@ function TableRow({
| {loggedCall.modelResponse?.inputTokens} |
{loggedCall.modelResponse?.outputTokens} |
- {loggedCall.modelResponse?.respStatus ?? "No response"}
+ {loggedCall.modelResponse?.statusCode ?? "No response"}
|
diff --git a/app/src/components/dashboard/UsageGraph.tsx b/app/src/components/dashboard/UsageGraph.tsx
new file mode 100644
index 0000000..dc3faa4
--- /dev/null
+++ b/app/src/components/dashboard/UsageGraph.tsx
@@ -0,0 +1,61 @@
+import {
+ ResponsiveContainer,
+ LineChart,
+ Line,
+ XAxis,
+ YAxis,
+ CartesianGrid,
+ Tooltip,
+ Legend,
+} from "recharts";
+import { useMemo } from "react";
+
+import { useSelectedProject } from "~/utils/hooks";
+import dayjs from "~/utils/dayjs";
+import { api } from "~/utils/api";
+
+export default function UsageGraph() {
+ const { data: selectedProject } = useSelectedProject();
+
+ const stats = api.dashboard.stats.useQuery(
+ { projectId: selectedProject?.id ?? "" },
+ { enabled: !!selectedProject },
+ );
+
+ const data = useMemo(() => {
+ return (
+ stats.data?.periods.map(({ period, numQueries, cost }) => ({
+ period,
+ Requests: numQueries,
+ "Total Spent (USD)": parseFloat(cost.toString()),
+ })) || []
+ );
+ }, [stats.data]);
+
+ return (
+
+
+ dayjs(str).format("MMM D")} />
+
+
+
+
+
+
+
+
+
+ );
+}
diff --git a/app/src/components/tooltip/CostTooltip.tsx b/app/src/components/tooltip/CostTooltip.tsx
index 68cf3ea..0e2cd17 100644
--- a/app/src/components/tooltip/CostTooltip.tsx
+++ b/app/src/components/tooltip/CostTooltip.tsx
@@ -2,14 +2,14 @@ import { HStack, Icon, Text, Tooltip, type TooltipProps, VStack, Divider } from
import { BsCurrencyDollar } from "react-icons/bs";
type CostTooltipProps = {
- promptTokens: number | null;
- completionTokens: number | null;
+ inputTokens: number | null;
+ outputTokens: number | null;
cost: number;
} & TooltipProps;
export const CostTooltip = ({
- promptTokens,
- completionTokens,
+ inputTokens,
+ outputTokens,
cost,
children,
...props
@@ -36,12 +36,12 @@ export const CostTooltip = ({
Prompt
- {promptTokens ?? 0}
+ {inputTokens ?? 0}
Completion
- {completionTokens ?? 0}
+ {outputTokens ?? 0}
diff --git a/app/src/modelProviders/anthropic-completion/index.ts b/app/src/modelProviders/anthropic-completion/index.ts
index 3b2d670..314d2ab 100644
--- a/app/src/modelProviders/anthropic-completion/index.ts
+++ b/app/src/modelProviders/anthropic-completion/index.ts
@@ -28,6 +28,10 @@ const modelProvider: AnthropicProvider = {
inputSchema: inputSchema as JSONSchema4,
canStream: true,
getCompletion,
+ getUsage: (input, output) => {
+ // TODO: add usage logic
+ return null;
+ },
...frontendModelProvider,
};
diff --git a/app/src/modelProviders/openai-ChatCompletion/getCompletion.ts b/app/src/modelProviders/openai-ChatCompletion/getCompletion.ts
index 462f2fa..e6f5123 100644
--- a/app/src/modelProviders/openai-ChatCompletion/getCompletion.ts
+++ b/app/src/modelProviders/openai-ChatCompletion/getCompletion.ts
@@ -4,14 +4,10 @@ import {
type ChatCompletion,
type CompletionCreateParams,
} from "openai/resources/chat";
-import { countOpenAIChatTokens } from "~/utils/countTokens";
import { type CompletionResponse } from "../types";
import { isArray, isString, omit } from "lodash-es";
import { openai } from "~/server/utils/openai";
-import { truthyFilter } from "~/utils/utils";
import { APIError } from "openai";
-import frontendModelProvider from "./frontend";
-import modelProvider, { type SupportedModel } from ".";
const mergeStreamedChunks = (
base: ChatCompletion | null,
@@ -60,9 +56,6 @@ export async function getCompletion(
): Promise> {
const start = Date.now();
let finalCompletion: ChatCompletion | null = null;
- let promptTokens: number | undefined = undefined;
- let completionTokens: number | undefined = undefined;
- const modelName = modelProvider.getModel(input) as SupportedModel;
try {
if (onStream) {
@@ -86,16 +79,6 @@ export async function getCompletion(
autoRetry: false,
};
}
- try {
- promptTokens = countOpenAIChatTokens(modelName, input.messages);
- completionTokens = countOpenAIChatTokens(
- modelName,
- finalCompletion.choices.map((c) => c.message).filter(truthyFilter),
- );
- } catch (err) {
- // TODO handle this, library seems like maybe it doesn't work with function calls?
- console.error(err);
- }
} else {
const resp = await openai.chat.completions.create(
{ ...input, stream: false },
@@ -104,25 +87,14 @@ export async function getCompletion(
},
);
finalCompletion = resp;
- promptTokens = resp.usage?.prompt_tokens ?? 0;
- completionTokens = resp.usage?.completion_tokens ?? 0;
}
const timeToComplete = Date.now() - start;
- const { promptTokenPrice, completionTokenPrice } = frontendModelProvider.models[modelName];
- let cost = undefined;
- if (promptTokenPrice && completionTokenPrice && promptTokens && completionTokens) {
- cost = promptTokens * promptTokenPrice + completionTokens * completionTokenPrice;
- }
-
return {
type: "success",
statusCode: 200,
value: finalCompletion,
timeToComplete,
- promptTokens,
- completionTokens,
- cost,
};
} catch (error: unknown) {
if (error instanceof APIError) {
diff --git a/app/src/modelProviders/openai-ChatCompletion/index.ts b/app/src/modelProviders/openai-ChatCompletion/index.ts
index 2b4e90c..b5dc00f 100644
--- a/app/src/modelProviders/openai-ChatCompletion/index.ts
+++ b/app/src/modelProviders/openai-ChatCompletion/index.ts
@@ -4,6 +4,8 @@ import inputSchema from "./codegen/input.schema.json";
import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat";
import { getCompletion } from "./getCompletion";
import frontendModelProvider from "./frontend";
+import { countOpenAIChatTokens } from "~/utils/countTokens";
+import { truthyFilter } from "~/utils/utils";
const supportedModels = [
"gpt-4-0613",
@@ -39,6 +41,41 @@ const modelProvider: OpenaiChatModelProvider = {
inputSchema: inputSchema as JSONSchema4,
canStream: true,
getCompletion,
+ getUsage: (input, output) => {
+ if (output.choices.length === 0) return null;
+
+ const model = modelProvider.getModel(input);
+ if (!model) return null;
+
+ let inputTokens: number;
+ let outputTokens: number;
+
+ if (output.usage) {
+ inputTokens = output.usage.prompt_tokens;
+ outputTokens = output.usage.completion_tokens;
+ } else {
+ try {
+ inputTokens = countOpenAIChatTokens(model, input.messages);
+ outputTokens = countOpenAIChatTokens(
+ model,
+ output.choices.map((c) => c.message).filter(truthyFilter),
+ );
+ } catch (err) {
+ inputTokens = 0;
+ outputTokens = 0;
+ // TODO handle this, library seems like maybe it doesn't work with function calls?
+ console.error(err);
+ }
+ }
+
+ const { promptTokenPrice, completionTokenPrice } = frontendModelProvider.models[model];
+ let cost = undefined;
+ if (promptTokenPrice && completionTokenPrice && inputTokens && outputTokens) {
+ cost = inputTokens * promptTokenPrice + outputTokens * completionTokenPrice;
+ }
+
+ return { inputTokens: inputTokens, outputTokens: outputTokens, cost };
+ },
...frontendModelProvider,
};
diff --git a/app/src/modelProviders/replicate-llama2/index.ts b/app/src/modelProviders/replicate-llama2/index.ts
index b2397ce..9e68eca 100644
--- a/app/src/modelProviders/replicate-llama2/index.ts
+++ b/app/src/modelProviders/replicate-llama2/index.ts
@@ -75,6 +75,10 @@ const modelProvider: ReplicateLlama2Provider = {
},
canStream: true,
getCompletion,
+ getUsage: (input, output) => {
+ // TODO: add usage logic
+ return null;
+ },
...frontendModelProvider,
};
diff --git a/app/src/modelProviders/types.ts b/app/src/modelProviders/types.ts
index 5e5bf26..6b5e09e 100644
--- a/app/src/modelProviders/types.ts
+++ b/app/src/modelProviders/types.ts
@@ -43,9 +43,6 @@ export type CompletionResponse =
value: T;
timeToComplete: number;
statusCode: number;
- promptTokens?: number;
- completionTokens?: number;
- cost?: number;
};
export type ModelProvider = {
@@ -56,6 +53,10 @@ export type ModelProvider void) | null,
) => Promise>;
+ getUsage: (
+ input: InputSchema,
+ output: OutputSchema,
+ ) => { gpuRuntime?: number; inputTokens?: number; outputTokens?: number; cost?: number } | null;
// This is just a convenience for type inference, don't use it at runtime
_outputSchema?: OutputSchema | null;
diff --git a/app/src/pages/logged-calls/index.tsx b/app/src/pages/logged-calls/index.tsx
index 8acb8ea..f19f90b 100644
--- a/app/src/pages/logged-calls/index.tsx
+++ b/app/src/pages/logged-calls/index.tsx
@@ -18,26 +18,15 @@ import {
Breadcrumb,
BreadcrumbItem,
} from "@chakra-ui/react";
-import {
- LineChart,
- Line,
- XAxis,
- YAxis,
- CartesianGrid,
- Tooltip,
- Legend,
- ResponsiveContainer,
-} from "recharts";
import { Ban, DollarSign, Hash } from "lucide-react";
-import { useMemo } from "react";
import AppShell from "~/components/nav/AppShell";
import PageHeaderContainer from "~/components/nav/PageHeaderContainer";
import ProjectBreadcrumbContents from "~/components/nav/ProjectBreadcrumbContents";
import { useSelectedProject } from "~/utils/hooks";
-import dayjs from "~/utils/dayjs";
import { api } from "~/utils/api";
import LoggedCallTable from "~/components/dashboard/LoggedCallTable";
+import UsageGraph from "~/components/dashboard/UsageGraph";
export default function LoggedCalls() {
const { data: selectedProject } = useSelectedProject();
@@ -47,16 +36,6 @@ export default function LoggedCalls() {
{ enabled: !!selectedProject },
);
- const data = useMemo(() => {
- return (
- stats.data?.periods.map(({ period, numQueries, totalCost }) => ({
- period,
- Requests: numQueries,
- "Total Spent (USD)": parseFloat(totalCost.toString()),
- })) || []
- );
- }, [stats.data]);
-
return (
@@ -83,39 +62,7 @@ export default function LoggedCalls() {
-
-
- dayjs(str).format("MMM D")}
- />
-
-
-
-
-
-
-
-
-
+
@@ -127,7 +74,7 @@ export default function LoggedCalls() {
- ${parseFloat(stats.data?.totals?.totalCost?.toString() ?? "0").toFixed(2)}
+ ${parseFloat(stats.data?.totals?.cost?.toString() ?? "0").toFixed(3)}
diff --git a/app/src/pages/project/settings/index.tsx b/app/src/pages/project/settings/index.tsx
index 0896f32..7fbaf4d 100644
--- a/app/src/pages/project/settings/index.tsx
+++ b/app/src/pages/project/settings/index.tsx
@@ -38,7 +38,10 @@ export default function Settings() {
id: selectedProject.id,
updates: { name },
});
- await Promise.all([utils.projects.get.invalidate({ id: selectedProject.id })]);
+ await Promise.all([
+ utils.projects.get.invalidate({ id: selectedProject.id }),
+ utils.projects.list.invalidate(),
+ ]);
}
}, [updateMutation, selectedProject]);
diff --git a/app/src/server/api/routers/dashboard.router.ts b/app/src/server/api/routers/dashboard.router.ts
index 6c9be03..8508e27 100644
--- a/app/src/server/api/routers/dashboard.router.ts
+++ b/app/src/server/api/routers/dashboard.router.ts
@@ -24,9 +24,9 @@ export const dashboardRouter = createTRPCRouter({
)
.where("projectId", "=", input.projectId)
.select(({ fn }) => [
- sql`date_trunc('day', "LoggedCallModelResponse"."startTime")`.as("period"),
+ sql`date_trunc('day', "LoggedCallModelResponse"."requestedAt")`.as("period"),
sql`count("LoggedCall"."id")::int`.as("numQueries"),
- fn.sum(fn.coalesce("LoggedCallModelResponse.totalCost", sql`0`)).as("totalCost"),
+ fn.sum(fn.coalesce("LoggedCallModelResponse.cost", sql`0`)).as("cost"),
])
.groupBy("period")
.orderBy("period")
@@ -57,7 +57,7 @@ export const dashboardRouter = createTRPCRouter({
backfilledPeriods.unshift({
period: dayjs(dayToMatch).toDate(),
numQueries: 0,
- totalCost: 0,
+ cost: 0,
});
}
dayToMatch = dayToMatch.subtract(1, "day");
@@ -72,7 +72,7 @@ export const dashboardRouter = createTRPCRouter({
)
.where("projectId", "=", input.projectId)
.select(({ fn }) => [
- fn.sum(fn.coalesce("LoggedCallModelResponse.totalCost", sql`0`)).as("totalCost"),
+ fn.sum(fn.coalesce("LoggedCallModelResponse.cost", sql`0`)).as("cost"),
fn.count("LoggedCall.id").as("numQueries"),
])
.executeTakeFirst();
@@ -85,8 +85,8 @@ export const dashboardRouter = createTRPCRouter({
"LoggedCall.id",
"LoggedCallModelResponse.originalLoggedCallId",
)
- .select(({ fn }) => [fn.count("LoggedCall.id").as("count"), "respStatus as code"])
- .where("respStatus", ">", 200)
+ .select(({ fn }) => [fn.count("LoggedCall.id").as("count"), "statusCode as code"])
+ .where("statusCode", ">", 200)
.groupBy("code")
.orderBy("count", "desc")
.execute();
@@ -108,7 +108,7 @@ export const dashboardRouter = createTRPCRouter({
// https://discord.com/channels/966627436387266600/1122258443886153758/1122258443886153758
loggedCalls: publicProcedure.input(z.object({})).query(async ({ input }) => {
const loggedCalls = await prisma.loggedCall.findMany({
- orderBy: { startTime: "desc" },
+ orderBy: { requestedAt: "desc" },
include: { tags: true, modelResponse: true },
take: 20,
});
diff --git a/app/src/server/api/routers/experiments.router.ts b/app/src/server/api/routers/experiments.router.ts
index 17dd991..f707a9b 100644
--- a/app/src/server/api/routers/experiments.router.ts
+++ b/app/src/server/api/routers/experiments.router.ts
@@ -227,7 +227,7 @@ export const experimentsRouter = createTRPCRouter({
...modelResponseData,
id: newModelResponseId,
scenarioVariantCellId: newCellId,
- output: (modelResponse.output as Prisma.InputJsonValue) ?? undefined,
+ respPayload: (modelResponse.respPayload as Prisma.InputJsonValue) ?? undefined,
});
for (const evaluation of outputEvaluations) {
outputEvaluationsToCreate.push({
diff --git a/app/src/server/api/routers/externalApi.router.ts b/app/src/server/api/routers/externalApi.router.ts
index 5eff9dd..aecbbbc 100644
--- a/app/src/server/api/routers/externalApi.router.ts
+++ b/app/src/server/api/routers/externalApi.router.ts
@@ -7,6 +7,11 @@ import { TRPCError } from "@trpc/server";
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import { hashRequest } from "~/server/utils/hashObject";
+import modelProvider from "~/modelProviders/openai-ChatCompletion";
+import {
+ type ChatCompletion,
+ type CompletionCreateParams,
+} from "openai/resources/chat/completions";
const reqValidator = z.object({
model: z.string(),
@@ -16,11 +21,6 @@ const reqValidator = z.object({
const respValidator = z.object({
id: z.string(),
model: z.string(),
- usage: z.object({
- total_tokens: z.number(),
- prompt_tokens: z.number(),
- completion_tokens: z.number(),
- }),
choices: z.array(
z.object({
finish_reason: z.string(),
@@ -76,7 +76,7 @@ export const externalApiRouter = createTRPCRouter({
originalLoggedCall: true,
},
orderBy: {
- startTime: "desc",
+ requestedAt: "desc",
},
});
@@ -85,7 +85,7 @@ export const externalApiRouter = createTRPCRouter({
await prisma.loggedCall.create({
data: {
projectId: key.projectId,
- startTime: new Date(input.startTime),
+ requestedAt: new Date(input.startTime),
cacheHit: true,
modelResponseId: existingResponse.id,
},
@@ -140,14 +140,20 @@ export const externalApiRouter = createTRPCRouter({
const newLoggedCallId = uuidv4();
const newModelResponseId = uuidv4();
- const usage = respPayload.success ? respPayload.data.usage : undefined;
+ let usage;
+ if (reqPayload.success && respPayload.success) {
+ usage = modelProvider.getUsage(
+ input.reqPayload as CompletionCreateParams,
+ input.respPayload as ChatCompletion,
+ );
+ }
await prisma.$transaction([
prisma.loggedCall.create({
data: {
id: newLoggedCallId,
projectId: key.projectId,
- startTime: new Date(input.startTime),
+ requestedAt: new Date(input.startTime),
cacheHit: false,
},
}),
@@ -155,20 +161,17 @@ export const externalApiRouter = createTRPCRouter({
data: {
id: newModelResponseId,
originalLoggedCallId: newLoggedCallId,
- startTime: new Date(input.startTime),
- endTime: new Date(input.endTime),
+ requestedAt: new Date(input.startTime),
+ receivedAt: new Date(input.endTime),
reqPayload: input.reqPayload as Prisma.InputJsonValue,
respPayload: input.respPayload as Prisma.InputJsonValue,
- respStatus: input.respStatus,
- error: input.error,
+ statusCode: input.respStatus,
+ errorMessage: input.error,
durationMs: input.endTime - input.startTime,
- ...(respPayload.success
- ? {
- cacheKey: requestHash,
- inputTokens: usage ? usage.prompt_tokens : undefined,
- outputTokens: usage ? usage.completion_tokens : undefined,
- }
- : null),
+ cacheKey: respPayload.success ? requestHash : null,
+ inputTokens: usage?.inputTokens,
+ outputTokens: usage?.outputTokens,
+ cost: usage?.cost,
},
}),
// Avoid foreign key constraint error by updating the logged call after the model response is created
@@ -182,24 +185,22 @@ export const externalApiRouter = createTRPCRouter({
}),
]);
- if (input.tags) {
- const tagsToCreate = Object.entries(input.tags).map(([name, value]) => ({
- loggedCallId: newLoggedCallId,
- // sanitize tags
- name: name.replaceAll(/[^a-zA-Z0-9_]/g, "_"),
- value,
- }));
+ const tagsToCreate = Object.entries(input.tags ?? {}).map(([name, value]) => ({
+ loggedCallId: newLoggedCallId,
+ // sanitize tags
+ name: name.replaceAll(/[^a-zA-Z0-9_]/g, "_"),
+ value,
+ }));
- if (reqPayload.success) {
- tagsToCreate.push({
- loggedCallId: newLoggedCallId,
- name: "$model",
- value: reqPayload.data.model,
- });
- }
- await prisma.loggedCallTag.createMany({
- data: tagsToCreate,
+ if (reqPayload.success) {
+ tagsToCreate.push({
+ loggedCallId: newLoggedCallId,
+ name: "$model",
+ value: reqPayload.data.model,
});
}
+ await prisma.loggedCallTag.createMany({
+ data: tagsToCreate,
+ });
}),
});
diff --git a/app/src/server/api/routers/promptVariants.router.ts b/app/src/server/api/routers/promptVariants.router.ts
index f19bbd0..7a01da1 100644
--- a/app/src/server/api/routers/promptVariants.router.ts
+++ b/app/src/server/api/routers/promptVariants.router.ts
@@ -55,7 +55,7 @@ export const promptVariantsRouter = createTRPCRouter({
where: {
modelResponse: {
outdated: false,
- output: { not: Prisma.AnyNull },
+ respPayload: { not: Prisma.AnyNull },
scenarioVariantCell: {
promptVariant: {
id: input.variantId,
@@ -100,7 +100,7 @@ export const promptVariantsRouter = createTRPCRouter({
modelResponses: {
some: {
outdated: false,
- output: {
+ respPayload: {
not: Prisma.AnyNull,
},
},
@@ -111,7 +111,7 @@ export const promptVariantsRouter = createTRPCRouter({
const overallTokens = await prisma.modelResponse.aggregate({
where: {
outdated: false,
- output: {
+ respPayload: {
not: Prisma.AnyNull,
},
scenarioVariantCell: {
@@ -123,13 +123,13 @@ export const promptVariantsRouter = createTRPCRouter({
},
_sum: {
cost: true,
- promptTokens: true,
- completionTokens: true,
+ inputTokens: true,
+ outputTokens: true,
},
});
- const promptTokens = overallTokens._sum?.promptTokens ?? 0;
- const completionTokens = overallTokens._sum?.completionTokens ?? 0;
+ const inputTokens = overallTokens._sum?.inputTokens ?? 0;
+ const outputTokens = overallTokens._sum?.outputTokens ?? 0;
const awaitingEvals = !!evalResults.find(
(result) => result.totalCount < scenarioCount * evals.length,
@@ -137,8 +137,8 @@ export const promptVariantsRouter = createTRPCRouter({
return {
evalResults,
- promptTokens,
- completionTokens,
+ inputTokens,
+ outputTokens,
overallCost: overallTokens._sum?.cost ?? 0,
scenarioCount,
outputCount,
diff --git a/app/src/server/tasks/queryModel.task.ts b/app/src/server/tasks/queryModel.task.ts
index d7a5dc8..4580f2c 100644
--- a/app/src/server/tasks/queryModel.task.ts
+++ b/app/src/server/tasks/queryModel.task.ts
@@ -99,26 +99,27 @@ export const queryModel = defineTask("queryModel", async (task) =
}
: null;
- const inputHash = hashObject(prompt as JsonValue);
+ const cacheKey = hashObject(prompt as JsonValue);
let modelResponse = await prisma.modelResponse.create({
data: {
- inputHash,
+ cacheKey,
scenarioVariantCellId: cellId,
requestedAt: new Date(),
},
});
const response = await provider.getCompletion(prompt.modelInput, onStream);
if (response.type === "success") {
+ const usage = provider.getUsage(prompt.modelInput, response.value);
modelResponse = await prisma.modelResponse.update({
where: { id: modelResponse.id },
data: {
- output: response.value as Prisma.InputJsonObject,
+ respPayload: response.value as Prisma.InputJsonObject,
statusCode: response.statusCode,
receivedAt: new Date(),
- promptTokens: response.promptTokens,
- completionTokens: response.completionTokens,
- cost: response.cost,
+ inputTokens: usage?.inputTokens,
+ outputTokens: usage?.outputTokens,
+ cost: usage?.cost,
},
});
diff --git a/app/src/server/utils/evaluations.ts b/app/src/server/utils/evaluations.ts
index 9259d91..f3039f0 100644
--- a/app/src/server/utils/evaluations.ts
+++ b/app/src/server/utils/evaluations.ts
@@ -51,7 +51,7 @@ export const runAllEvals = async (experimentId: string) => {
const outputs = await prisma.modelResponse.findMany({
where: {
outdated: false,
- output: {
+ respPayload: {
not: Prisma.AnyNull,
},
scenarioVariantCell: {
diff --git a/app/src/server/utils/generateNewCell.ts b/app/src/server/utils/generateNewCell.ts
index 858781e..678740d 100644
--- a/app/src/server/utils/generateNewCell.ts
+++ b/app/src/server/utils/generateNewCell.ts
@@ -57,7 +57,7 @@ export const generateNewCell = async (
return;
}
- const inputHash = hashObject(parsedConstructFn);
+ const cacheKey = hashObject(parsedConstructFn);
cell = await prisma.scenarioVariantCell.create({
data: {
@@ -73,8 +73,8 @@ export const generateNewCell = async (
const matchingModelResponse = await prisma.modelResponse.findFirst({
where: {
- inputHash,
- output: {
+ cacheKey,
+ respPayload: {
not: Prisma.AnyNull,
},
},
@@ -92,7 +92,7 @@ export const generateNewCell = async (
data: {
...omit(matchingModelResponse, ["id", "scenarioVariantCell"]),
scenarioVariantCellId: cell.id,
- output: matchingModelResponse.output as Prisma.InputJsonValue,
+ respPayload: matchingModelResponse.respPayload as Prisma.InputJsonValue,
},
});
diff --git a/app/src/server/utils/runOneEval.ts b/app/src/server/utils/runOneEval.ts
index a65f417..87f4664 100644
--- a/app/src/server/utils/runOneEval.ts
+++ b/app/src/server/utils/runOneEval.ts
@@ -71,7 +71,7 @@ export const runOneEval = async (
provider: SupportedProvider,
): Promise<{ result: number; details?: string }> => {
const modelProvider = modelProviders[provider];
- const message = modelProvider.normalizeOutput(modelResponse.output);
+ const message = modelProvider.normalizeOutput(modelResponse.respPayload);
if (!message) return { result: 0 };