From ce783f6279668df7a5949c9476f6583b6db59eb9 Mon Sep 17 00:00:00 2001
From: Kyle Corbitt
Date: Mon, 26 Jun 2023 14:44:47 -0700
Subject: [PATCH] do some completion caching

---
 prisma/schema.prisma                          |  4 +++-
 src/server/api/routers/modelOutputs.router.ts | 20 ++++++++++++++++++-
 2 files changed, 22 insertions(+), 2 deletions(-)

diff --git a/prisma/schema.prisma b/prisma/schema.prisma
index 6cf0c5b..e1956f1 100644
--- a/prisma/schema.prisma
+++ b/prisma/schema.prisma
@@ -74,7 +74,8 @@ model TemplateVariable {
 
 model ModelOutput {
   id String @id @default(uuid()) @db.Uuid
-  output Json
+  inputHash String
+  output Json
   promptVariantId String @db.Uuid
   promptVariant PromptVariant @relation(fields: [promptVariantId], references: [id])
 
@@ -86,6 +87,7 @@ model ModelOutput {
   updatedAt DateTime @updatedAt
 
   @@unique([promptVariantId, testScenarioId])
+  @@index([inputHash])
 }
 
 // Necessary for Next auth
diff --git a/src/server/api/routers/modelOutputs.router.ts b/src/server/api/routers/modelOutputs.router.ts
index 75249cd..cb695e5 100644
--- a/src/server/api/routers/modelOutputs.router.ts
+++ b/src/server/api/routers/modelOutputs.router.ts
@@ -4,6 +4,7 @@ import { prisma } from "~/server/db";
 import fillTemplate, { VariableMap } from "~/server/utils/fillTemplate";
 import { JSONSerializable } from "~/server/types";
 import { getChatCompletion } from "~/server/utils/openai";
+import crypto from "crypto";
 
 export const modelOutputsRouter = createTRPCRouter({
   get: publicProcedure
@@ -39,13 +40,30 @@
         scenario.variableValues as VariableMap
       );
 
-      const modelResponse = await getChatCompletion(filledTemplate, process.env.OPENAI_API_KEY!);
+      const inputHash = crypto
+        .createHash("sha256")
+        .update(JSON.stringify(filledTemplate))
+        .digest("hex");
+
+      // TODO: we should probably only use this if temperature=0
+      const existingResponse = await prisma.modelOutput.findFirst({
+        where: { inputHash },
+      });
+
+      let modelResponse: JSONSerializable;
+
+      if (existingResponse) {
+        modelResponse = existingResponse.output as JSONSerializable;
+      } else {
+        modelResponse = await getChatCompletion(filledTemplate, process.env.OPENAI_API_KEY!);
+      }
 
       const modelOutput = await prisma.modelOutput.create({
         data: {
           promptVariantId: input.variantId,
           testScenarioId: input.scenarioId,
           output: modelResponse,
+          inputHash,
         },
       });