do some completion caching

Author: Kyle Corbitt
Date: 2023-06-26 14:44:47 -07:00
Parent: 3f850bbd7f
Commit: ce783f6279
2 changed files with 22 additions and 2 deletions


@@ -74,7 +74,8 @@ model TemplateVariable {
 model ModelOutput {
   id String @id @default(uuid()) @db.Uuid
+  inputHash String
   output Json
   promptVariantId String @db.Uuid
   promptVariant PromptVariant @relation(fields: [promptVariantId], references: [id])
@@ -86,6 +87,7 @@ model ModelOutput {
   updatedAt DateTime @updatedAt

   @@unique([promptVariantId, testScenarioId])
+  @@index([inputHash])
 }

 // Necessary for Next auth
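
The new @@index([inputHash]) is what keeps cache lookups by hash fast as the ModelOutput table grows. A minimal sketch of such a lookup, assuming the generated Prisma client (the router change below does the same thing inline):

import { PrismaClient } from "@prisma/client";

const prisma = new PrismaClient();

// Look up a previously stored completion by the hash of its input.
// Returns the cached row, or null on a cache miss.
async function findCachedOutput(inputHash: string) {
  return prisma.modelOutput.findFirst({ where: { inputHash } });
}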


@@ -4,6 +4,7 @@ import { prisma } from "~/server/db";
 import fillTemplate, { VariableMap } from "~/server/utils/fillTemplate";
 import { JSONSerializable } from "~/server/types";
 import { getChatCompletion } from "~/server/utils/openai";
+import crypto from "crypto";

 export const modelOutputsRouter = createTRPCRouter({
   get: publicProcedure
@@ -39,13 +40,30 @@ export const modelOutputsRouter = createTRPCRouter({
         scenario.variableValues as VariableMap
       );

-      const modelResponse = await getChatCompletion(filledTemplate, process.env.OPENAI_API_KEY!);
+      const inputHash = crypto
+        .createHash("sha256")
+        .update(JSON.stringify(filledTemplate))
+        .digest("hex");
+
+      // TODO: we should probably only use this if temperature=0
+      const existingResponse = await prisma.modelOutput.findFirst({
+        where: { inputHash },
+      });
+
+      let modelResponse: JSONSerializable;
+      if (existingResponse) {
+        modelResponse = existingResponse.output as JSONSerializable;
+      } else {
+        modelResponse = await getChatCompletion(filledTemplate, process.env.OPENAI_API_KEY!);
+      }

       const modelOutput = await prisma.modelOutput.create({
         data: {
           promptVariantId: input.variantId,
           testScenarioId: input.scenarioId,
           output: modelResponse,
+          inputHash,
         },
       });
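
One caveat with hashing JSON.stringify(filledTemplate): JSON.stringify is sensitive to object key order, so two logically identical templates can serialize differently and miss the cache. A hedged sketch of a key-order-independent alternative; stableStringify and stableHash are hypothetical helpers, not part of this commit:

import crypto from "crypto";

// Serialize with object keys sorted recursively, so logically
// equal values always produce the same string.
function stableStringify(value: unknown): string {
  if (Array.isArray(value)) {
    return "[" + value.map(stableStringify).join(",") + "]";
  }
  if (value !== null && typeof value === "object") {
    const obj = value as Record<string, unknown>;
    return (
      "{" +
      Object.keys(obj)
        .sort()
        .map((k) => JSON.stringify(k) + ":" + stableStringify(obj[k]))
        .join(",") +
      "}"
    );
  }
  return JSON.stringify(value);
}

// Hypothetical drop-in for the inline hashing step above.
function stableHash(input: unknown): string {
  return crypto.createHash("sha256").update(stableStringify(input)).digest("hex");
}

As the TODO above notes, the cached row should probably also only be reused when temperature is 0, since sampled completions are not deterministic.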