Prep for more model providers
Adds a `modelProvider` field to `promptVariants`, currently just set to "openai/ChatCompletion" for all variants for now. Adds a `modelProviders/` directory where we can define and store pluggable model providers. Currently just OpenAI. Not everything is pluggable yet -- notably the code to actually generate completions hasn't been migrated to this setup yet. Does a lot of work to get the types working. Prompts are now defined with a function `definePrompt(modelProvider, config)` instead of `prompt = config`. Added a script to migrate old prompt definitions. This is still partial work, but the diff is large enough that I want to get it in. I don't think anything is broken but I haven't tested thoroughly.
This commit is contained in:
@@ -1,8 +1,9 @@
|
||||
import crypto from "crypto";
|
||||
import { type Prisma } from "@prisma/client";
|
||||
import { prisma } from "../db";
|
||||
import { queueLLMRetrievalTask } from "./queueLLMRetrievalTask";
|
||||
import { constructPrompt } from "./constructPrompt";
|
||||
import parseConstructFn from "./parseConstructFn";
|
||||
import { type JsonObject } from "type-fest";
|
||||
import hashPrompt from "./hashPrompt";
|
||||
|
||||
export const generateNewCell = async (variantId: string, scenarioId: string) => {
|
||||
const variant = await prisma.promptVariant.findUnique({
|
||||
@@ -19,10 +20,6 @@ export const generateNewCell = async (variantId: string, scenarioId: string) =>
|
||||
|
||||
if (!variant || !scenario) return null;
|
||||
|
||||
const prompt = await constructPrompt(variant, scenario.variableValues);
|
||||
|
||||
const inputHash = crypto.createHash("sha256").update(JSON.stringify(prompt)).digest("hex");
|
||||
|
||||
let cell = await prisma.scenarioVariantCell.findUnique({
|
||||
where: {
|
||||
promptVariantId_testScenarioId: {
|
||||
@@ -37,10 +34,29 @@ export const generateNewCell = async (variantId: string, scenarioId: string) =>
|
||||
|
||||
if (cell) return cell;
|
||||
|
||||
const parsedConstructFn = await parseConstructFn(
|
||||
variant.constructFn,
|
||||
scenario.variableValues as JsonObject,
|
||||
);
|
||||
|
||||
if ("error" in parsedConstructFn) {
|
||||
return await prisma.scenarioVariantCell.create({
|
||||
data: {
|
||||
promptVariantId: variantId,
|
||||
testScenarioId: scenarioId,
|
||||
statusCode: 400,
|
||||
errorMessage: parsedConstructFn.error,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
const inputHash = hashPrompt(parsedConstructFn);
|
||||
|
||||
cell = await prisma.scenarioVariantCell.create({
|
||||
data: {
|
||||
promptVariantId: variantId,
|
||||
testScenarioId: scenarioId,
|
||||
prompt: parsedConstructFn.modelInput as unknown as Prisma.InputJsonValue,
|
||||
},
|
||||
include: {
|
||||
modelOutput: true,
|
||||
@@ -48,9 +64,7 @@ export const generateNewCell = async (variantId: string, scenarioId: string) =>
|
||||
});
|
||||
|
||||
const matchingModelOutput = await prisma.modelOutput.findFirst({
|
||||
where: {
|
||||
inputHash,
|
||||
},
|
||||
where: { inputHash },
|
||||
});
|
||||
|
||||
let newModelOutput;
|
||||
|
||||
Reference in New Issue
Block a user