Prep for more model providers

Adds a `modelProvider` field to `promptVariants`, currently just set to "openai/ChatCompletion" for all variants for now.

Adds a `modelProviders/` directory where we can define and store pluggable model providers. Currently just OpenAI. Not everything is pluggable yet -- notably the code to actually generate completions hasn't been migrated to this setup yet.

Does a lot of work to get the types working. Prompts are now defined with a function `definePrompt(modelProvider, config)` instead of `prompt = config`. Added a script to migrate old prompt definitions.

This is still partial work, but the diff is large enough that I want to get it in. I don't think anything is broken but I haven't tested thoroughly.
This commit is contained in:
Kyle Corbitt
2023-07-20 14:47:39 -07:00
parent 2c8c8d07cf
commit ded6678e97
43 changed files with 1195 additions and 3023 deletions

View File

@@ -1,8 +1,9 @@
import crypto from "crypto";
import { type Prisma } from "@prisma/client";
import { prisma } from "../db";
import { queueLLMRetrievalTask } from "./queueLLMRetrievalTask";
import { constructPrompt } from "./constructPrompt";
import parseConstructFn from "./parseConstructFn";
import { type JsonObject } from "type-fest";
import hashPrompt from "./hashPrompt";
export const generateNewCell = async (variantId: string, scenarioId: string) => {
const variant = await prisma.promptVariant.findUnique({
@@ -19,10 +20,6 @@ export const generateNewCell = async (variantId: string, scenarioId: string) =>
if (!variant || !scenario) return null;
const prompt = await constructPrompt(variant, scenario.variableValues);
const inputHash = crypto.createHash("sha256").update(JSON.stringify(prompt)).digest("hex");
let cell = await prisma.scenarioVariantCell.findUnique({
where: {
promptVariantId_testScenarioId: {
@@ -37,10 +34,29 @@ export const generateNewCell = async (variantId: string, scenarioId: string) =>
if (cell) return cell;
const parsedConstructFn = await parseConstructFn(
variant.constructFn,
scenario.variableValues as JsonObject,
);
if ("error" in parsedConstructFn) {
return await prisma.scenarioVariantCell.create({
data: {
promptVariantId: variantId,
testScenarioId: scenarioId,
statusCode: 400,
errorMessage: parsedConstructFn.error,
},
});
}
const inputHash = hashPrompt(parsedConstructFn);
cell = await prisma.scenarioVariantCell.create({
data: {
promptVariantId: variantId,
testScenarioId: scenarioId,
prompt: parsedConstructFn.modelInput as unknown as Prisma.InputJsonValue,
},
include: {
modelOutput: true,
@@ -48,9 +64,7 @@ export const generateNewCell = async (variantId: string, scenarioId: string) =>
});
const matchingModelOutput = await prisma.modelOutput.findFirst({
where: {
inputHash,
},
where: { inputHash },
});
let newModelOutput;