// File viewer metadata: TypeScript, 77 lines, 2.2 KiB
import { z } from "zod";
|
|
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
|
|
import { prisma } from "~/server/db";
|
|
import fillTemplate, { type VariableMap } from "~/server/utils/fillTemplate";
|
|
import { type JSONSerializable } from "~/server/types";
|
|
import { getChatCompletion } from "~/server/utils/openai";
|
|
import crypto from "crypto";
|
|
import type { Prisma } from "@prisma/client";
|
|
import { env } from "~/env.mjs";
|
|
|
|
/**
 * tRPC router for model outputs.
 *
 * `get` returns the stored output for a (prompt variant, test scenario) pair,
 * generating and persisting it on first request.
 */
export const modelOutputsRouter = createTRPCRouter({
  /**
   * Fetch — or lazily generate — the model output for a variant/scenario pair.
   *
   * Flow:
   *  1. Return the stored `modelOutput` row for this pair if one exists.
   *  2. Otherwise load the variant and scenario; return `null` if either is missing.
   *  3. Fill the variant's config template with the scenario's variable values
   *     and hash the result (sha256 of its JSON serialization).
   *  4. Reuse the output of any existing row with the same input hash; only call
   *     the chat-completion API when no such row exists.
   *  5. Persist and return the row for this pair.
   *
   * NOTE(review): although exposed as a `query`, this procedure has side effects
   * (an external chat-completion call and a DB write) on a cache miss.
   */
  get: publicProcedure
    .input(z.object({ scenarioId: z.string(), variantId: z.string() }))
    .query(async ({ input }) => {
      // Compound unique key identifying the output for this variant/scenario pair.
      const pairKey = {
        promptVariantId: input.variantId,
        testScenarioId: input.scenarioId,
      };

      const existing = await prisma.modelOutput.findUnique({
        where: { promptVariantId_testScenarioId: pairKey },
      });
      if (existing) return existing;

      // The two lookups are independent, so run them concurrently.
      const [variant, scenario] = await Promise.all([
        prisma.promptVariant.findUnique({ where: { id: input.variantId } }),
        prisma.testScenario.findUnique({ where: { id: input.scenarioId } }),
      ]);
      if (!variant || !scenario) return null;

      const filledTemplate = fillTemplate(
        variant.config as JSONSerializable,
        scenario.variableValues as VariableMap
      );

      const inputHash = crypto
        .createHash("sha256")
        .update(JSON.stringify(filledTemplate))
        .digest("hex");

      // TODO: we should probably only use this if temperature=0
      const existingResponse = await prisma.modelOutput.findFirst({
        where: { inputHash },
      });

      // Reuse a previously computed response for an identical filled template;
      // otherwise call the model.
      const modelResponse: JSONSerializable = existingResponse
        ? (existingResponse.output as JSONSerializable)
        : await getChatCompletion(filledTemplate, env.OPENAI_API_KEY);

      // Upsert instead of create: a concurrent request for the same pair may have
      // inserted the row since the lookup above, and a plain `create` would then
      // violate the (promptVariantId, testScenarioId) unique constraint.
      // `update: {}` leaves an already-stored row untouched and returns it.
      // NOTE(review): Prisma only guarantees atomicity when it can run this as a
      // native database upsert — confirm for the deployed Prisma version/DB.
      return prisma.modelOutput.upsert({
        where: { promptVariantId_testScenarioId: pairKey },
        create: {
          ...pairKey,
          output: modelResponse as Prisma.InputJsonObject,
          inputHash,
        },
        update: {},
      });
    }),
});
|