do some completion caching
@@ -74,7 +74,8 @@ model TemplateVariable {
 model ModelOutput {
   id              String        @id @default(uuid()) @db.Uuid
 
-  output          Json
+  inputHash       String
+  output          Json
 
   promptVariantId String        @db.Uuid
   promptVariant   PromptVariant @relation(fields: [promptVariantId], references: [id])
@@ -86,6 +87,7 @@ model ModelOutput {
   updatedAt DateTime @updatedAt
 
   @@unique([promptVariantId, testScenarioId])
+  @@index([inputHash])
 }
 
 // Necessary for Next auth
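The new inputHash column and @@index([inputHash]) exist to keep the cache lookup added in the router change below cheap. A minimal sketch of the query the index serves (the same findFirst call as in the diff, assuming inputHash already holds the SHA-256 hex digest of the rendered prompt):

// Cache lookup served by @@index([inputHash]): find a stored output whose
// rendered-prompt hash matches; null means the completion must be generated.
const cached = await prisma.modelOutput.findFirst({
  where: { inputHash },
});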
@@ -4,6 +4,7 @@ import { prisma } from "~/server/db";
 import fillTemplate, { VariableMap } from "~/server/utils/fillTemplate";
 import { JSONSerializable } from "~/server/types";
 import { getChatCompletion } from "~/server/utils/openai";
+import crypto from "crypto";
 
 export const modelOutputsRouter = createTRPCRouter({
   get: publicProcedure
@@ -39,13 +40,30 @@ export const modelOutputsRouter = createTRPCRouter({
         scenario.variableValues as VariableMap
       );
 
-      const modelResponse = await getChatCompletion(filledTemplate, process.env.OPENAI_API_KEY!);
+      const inputHash = crypto
+        .createHash("sha256")
+        .update(JSON.stringify(filledTemplate))
+        .digest("hex");
+
+      // TODO: we should probably only use this if temperature=0
+      const existingResponse = await prisma.modelOutput.findFirst({
+        where: { inputHash },
+      });
+
+      let modelResponse: JSONSerializable;
+
+      if (existingResponse) {
+        modelResponse = existingResponse.output as JSONSerializable;
+      } else {
+        modelResponse = await getChatCompletion(filledTemplate, process.env.OPENAI_API_KEY!);
+      }
 
       const modelOutput = await prisma.modelOutput.create({
         data: {
           promptVariantId: input.variantId,
           testScenarioId: input.scenarioId,
           output: modelResponse,
+          inputHash,
         },
       });
 
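For reference, the caching flow introduced above, pulled out into a standalone helper. This is a minimal sketch only: getCachedCompletion is a hypothetical name, not part of this commit, and it assumes the same prisma client, getChatCompletion, and JSONSerializable type that the router already imports.

import crypto from "crypto";

// Hypothetical helper illustrating the commit's caching strategy: hash the
// fully rendered prompt, reuse a stored output when the hash matches,
// otherwise call the model; the caller persists the result with its hash.
async function getCachedCompletion(
  filledTemplate: JSONSerializable,
): Promise<{ output: JSONSerializable; inputHash: string }> {
  const inputHash = crypto
    .createHash("sha256")
    .update(JSON.stringify(filledTemplate))
    .digest("hex");

  const existing = await prisma.modelOutput.findFirst({ where: { inputHash } });
  if (existing) {
    return { output: existing.output as JSONSerializable, inputHash };
  }

  const output = await getChatCompletion(filledTemplate, process.env.OPENAI_API_KEY!);
  return { output, inputHash };
}

As the TODO in the diff notes, reusing a cached output is only strictly safe for deterministic requests (temperature=0); for sampled requests, a cache hit hides the variance you would otherwise see between runs.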