Prep for more model providers
Adds a `modelProvider` field to `promptVariants`, currently set to "openai/ChatCompletion" for all variants. Adds a `modelProviders/` directory where we can define and store pluggable model providers; currently just OpenAI. Not everything is pluggable yet — notably, the code that actually generates completions hasn't been migrated to this setup. Does a lot of work to get the types working. Prompts are now defined with a function call `definePrompt(modelProvider, config)` instead of an assignment `prompt = config`. Added a script to migrate old prompt definitions. This is still partial work, but the diff is large enough that I want to get it in. I don't think anything is broken, but I haven't tested thoroughly.
This commit is contained in:
@@ -1,15 +0,0 @@
|
||||
import { test } from "vitest";
import { constructPrompt } from "./constructPrompt";

// Smoke test for constructPrompt: runs a trivial constructFn that assigns a
// literal object to `prompt` inside the sandbox.
// NOTE(review): marked test.skip — presumably because isolated-vm conflicts
// with vitest's worker threads (see the `--no-threads` note in
// parseConstructFn.test.ts); confirm before re-enabling.
test.skip("constructPrompt", async () => {
  const constructed = await constructPrompt(
    {
      // Raw JS source evaluated inside the isolate.
      constructFn: `prompt = { "fooz": "bar" }`,
    },
    {
      // Scenario variables exposed to the script as `scenario`.
      foo: "bar",
    },
  );

  // NOTE(review): no assertion — only logs the result; add an expect() when
  // the test is re-enabled.
  console.log(constructed);
});
|
||||
@@ -1,35 +0,0 @@
|
||||
import { type PromptVariant, type TestScenario } from "@prisma/client";
|
||||
import ivm from "isolated-vm";
|
||||
import { type JSONSerializable } from "../types";
|
||||
|
||||
const isolate = new ivm.Isolate({ memoryLimit: 128 });
|
||||
|
||||
export async function constructPrompt(
|
||||
variant: Pick<PromptVariant, "constructFn">,
|
||||
scenario: TestScenario["variableValues"],
|
||||
): Promise<JSONSerializable> {
|
||||
const code = `
|
||||
const scenario = ${JSON.stringify(scenario ?? {}, null, 2)};
|
||||
let prompt = {};
|
||||
|
||||
${variant.constructFn}
|
||||
|
||||
global.prompt = prompt;
|
||||
`;
|
||||
|
||||
console.log("code is", code);
|
||||
|
||||
const context = await isolate.createContext();
|
||||
|
||||
const jail = context.global;
|
||||
await jail.set("global", jail.derefInto());
|
||||
|
||||
const script = await isolate.compileScript(code);
|
||||
|
||||
await script.run(context);
|
||||
const promptReference = (await context.global.get("prompt")) as ivm.Reference;
|
||||
|
||||
const prompt = await promptReference.copy(); // Get the actual value from the isolate
|
||||
|
||||
return prompt as JSONSerializable;
|
||||
}
|
||||
@@ -1,8 +1,9 @@
|
||||
import crypto from "crypto";
|
||||
import { type Prisma } from "@prisma/client";
|
||||
import { prisma } from "../db";
|
||||
import { queueLLMRetrievalTask } from "./queueLLMRetrievalTask";
|
||||
import { constructPrompt } from "./constructPrompt";
|
||||
import parseConstructFn from "./parseConstructFn";
|
||||
import { type JsonObject } from "type-fest";
|
||||
import hashPrompt from "./hashPrompt";
|
||||
|
||||
export const generateNewCell = async (variantId: string, scenarioId: string) => {
|
||||
const variant = await prisma.promptVariant.findUnique({
|
||||
@@ -19,10 +20,6 @@ export const generateNewCell = async (variantId: string, scenarioId: string) =>
|
||||
|
||||
if (!variant || !scenario) return null;
|
||||
|
||||
const prompt = await constructPrompt(variant, scenario.variableValues);
|
||||
|
||||
const inputHash = crypto.createHash("sha256").update(JSON.stringify(prompt)).digest("hex");
|
||||
|
||||
let cell = await prisma.scenarioVariantCell.findUnique({
|
||||
where: {
|
||||
promptVariantId_testScenarioId: {
|
||||
@@ -37,10 +34,29 @@ export const generateNewCell = async (variantId: string, scenarioId: string) =>
|
||||
|
||||
if (cell) return cell;
|
||||
|
||||
const parsedConstructFn = await parseConstructFn(
|
||||
variant.constructFn,
|
||||
scenario.variableValues as JsonObject,
|
||||
);
|
||||
|
||||
if ("error" in parsedConstructFn) {
|
||||
return await prisma.scenarioVariantCell.create({
|
||||
data: {
|
||||
promptVariantId: variantId,
|
||||
testScenarioId: scenarioId,
|
||||
statusCode: 400,
|
||||
errorMessage: parsedConstructFn.error,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
const inputHash = hashPrompt(parsedConstructFn);
|
||||
|
||||
cell = await prisma.scenarioVariantCell.create({
|
||||
data: {
|
||||
promptVariantId: variantId,
|
||||
testScenarioId: scenarioId,
|
||||
prompt: parsedConstructFn.modelInput as unknown as Prisma.InputJsonValue,
|
||||
},
|
||||
include: {
|
||||
modelOutput: true,
|
||||
@@ -48,9 +64,7 @@ export const generateNewCell = async (variantId: string, scenarioId: string) =>
|
||||
});
|
||||
|
||||
const matchingModelOutput = await prisma.modelOutput.findFirst({
|
||||
where: {
|
||||
inputHash,
|
||||
},
|
||||
where: { inputHash },
|
||||
});
|
||||
|
||||
let newModelOutput;
|
||||
|
||||
@@ -7,7 +7,7 @@ import { type SupportedModel, type OpenAIChatModel } from "../types";
|
||||
import { env } from "~/env.mjs";
|
||||
import { countOpenAIChatTokens } from "~/utils/countTokens";
|
||||
import { rateLimitErrorMessage } from "~/sharedStrings";
|
||||
import { modelStats } from "../modelStats";
|
||||
import { modelStats } from "../../modelProviders/modelStats";
|
||||
|
||||
export type CompletionResponse = {
|
||||
output: ChatCompletion | null;
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import { OpenAIChatModel, type SupportedModel } from "../types";
|
||||
import openAIChatApiShape from "~/codegen/openai.types.ts.txt";
|
||||
import { type SupportedModel } from "../types";
|
||||
|
||||
export const getApiShapeForModel = (model: SupportedModel) => {
|
||||
if (model in OpenAIChatModel) return openAIChatApiShape;
|
||||
// if (model in OpenAIChatModel) return openAIChatApiShape;
|
||||
return "";
|
||||
};
|
||||
|
||||
37
src/server/utils/hashPrompt.ts
Normal file
37
src/server/utils/hashPrompt.ts
Normal file
@@ -0,0 +1,37 @@
|
||||
import crypto from "crypto";
|
||||
import { type JsonValue } from "type-fest";
|
||||
import { type ParsedConstructFn } from "./parseConstructFn";
|
||||
|
||||
function sortKeys(obj: JsonValue): JsonValue {
|
||||
if (typeof obj !== "object" || obj === null) {
|
||||
// Not an object or array, return as is
|
||||
return obj;
|
||||
}
|
||||
|
||||
if (Array.isArray(obj)) {
|
||||
return obj.map(sortKeys);
|
||||
}
|
||||
|
||||
// Get keys and sort them
|
||||
const keys = Object.keys(obj).sort();
|
||||
const sortedObj = {};
|
||||
|
||||
for (const key of keys) {
|
||||
// @ts-expect-error not worth fixing types
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-argument
|
||||
sortedObj[key] = sortKeys(obj[key]);
|
||||
}
|
||||
|
||||
return sortedObj;
|
||||
}
|
||||
|
||||
export default function hashPrompt(prompt: ParsedConstructFn<any>): string {
|
||||
// Sort object keys recursively
|
||||
const sortedObj = sortKeys(prompt as unknown as JsonValue);
|
||||
|
||||
// Convert to JSON and hash it
|
||||
const str = JSON.stringify(sortedObj);
|
||||
const hash = crypto.createHash("sha256");
|
||||
hash.update(str);
|
||||
return hash.digest("hex");
|
||||
}
|
||||
45
src/server/utils/parseConstructFn.test.ts
Normal file
45
src/server/utils/parseConstructFn.test.ts
Normal file
@@ -0,0 +1,45 @@
|
||||
import { expect, test } from "vitest";
import parseConstructFn from "./parseConstructFn";
import assert from "assert";

// Note: this has to be run with `vitest --no-threads` option or else
// isolated-vm seems to throw errors
test("parseConstructFn", async () => {
  // Happy path: a definePrompt call that interpolates a scenario variable
  // into a chat message.
  const constructed = await parseConstructFn(
    `
// These sometimes have a comment

definePrompt("openai/ChatCompletion", {
  model: "gpt-3.5-turbo-0613",
  messages: [
    {
      role: "user",
      content: \`What is the capital of \${scenario.country}?\`
    }
  ]
})
`,
    { country: "Bolivia" },
  );

  // The parsed result carries the provider, the resolved model, and the
  // schema-validated input with the scenario variable substituted.
  expect(constructed).toEqual({
    modelProvider: "openai/ChatCompletion",
    model: "gpt-3.5-turbo-0613",
    modelInput: {
      messages: [
        {
          content: "What is the capital of Bolivia?",
          role: "user",
        },
      ],
      model: "gpt-3.5-turbo-0613",
    },
  });
});

test("bad syntax", async () => {
  // A truncated script should surface the isolate's SyntaxError message as a
  // { error } result rather than throwing.
  const parsed = await parseConstructFn(`definePrompt("openai/ChatCompletion", {`);

  assert("error" in parsed);
  expect(parsed.error).toContain("Unexpected end of input");
});
|
||||
92
src/server/utils/parseConstructFn.ts
Normal file
92
src/server/utils/parseConstructFn.ts
Normal file
@@ -0,0 +1,92 @@
|
||||
import modelProviders from "~/modelProviders";
|
||||
import ivm from "isolated-vm";
|
||||
import { isObject, isString } from "lodash-es";
|
||||
import { type JsonObject } from "type-fest";
|
||||
import { validate } from "jsonschema";
|
||||
|
||||
export type ParsedConstructFn<T extends keyof typeof modelProviders> = {
|
||||
modelProvider: T;
|
||||
model: keyof (typeof modelProviders)[T]["models"];
|
||||
modelInput: Parameters<(typeof modelProviders)[T]["getModel"]>[0];
|
||||
};
|
||||
|
||||
const isolate = new ivm.Isolate({ memoryLimit: 128 });
|
||||
|
||||
export default async function parseConstructFn(
|
||||
constructFn: string,
|
||||
scenario: JsonObject | undefined = {},
|
||||
): Promise<ParsedConstructFn<keyof typeof modelProviders> | { error: string }> {
|
||||
try {
|
||||
const modifiedConstructFn = constructFn.replace(
|
||||
"definePrompt(",
|
||||
"global.prompt = definePrompt(",
|
||||
);
|
||||
|
||||
const code = `
|
||||
const scenario = ${JSON.stringify(scenario ?? {}, null, 2)};
|
||||
|
||||
const definePrompt = (modelProvider, input) => ({
|
||||
modelProvider,
|
||||
input
|
||||
})
|
||||
|
||||
${modifiedConstructFn}
|
||||
`;
|
||||
|
||||
const context = await isolate.createContext();
|
||||
const jail = context.global;
|
||||
await jail.set("global", jail.derefInto());
|
||||
|
||||
const script = await isolate.compileScript(code);
|
||||
await script.run(context);
|
||||
const promptReference = (await context.global.get("prompt")) as ivm.Reference;
|
||||
const prompt = await promptReference.copy();
|
||||
|
||||
if (!isObject(prompt)) {
|
||||
return { error: "definePrompt did not return an object" };
|
||||
}
|
||||
if (!("modelProvider" in prompt) || !isString(prompt.modelProvider)) {
|
||||
return { error: "definePrompt did not return a valid modelProvider" };
|
||||
}
|
||||
|
||||
const provider =
|
||||
prompt.modelProvider in modelProviders &&
|
||||
modelProviders[prompt.modelProvider as keyof typeof modelProviders];
|
||||
if (!provider) {
|
||||
return { error: "definePrompt did not return a known modelProvider" };
|
||||
}
|
||||
if (!("input" in prompt) || !isObject(prompt.input)) {
|
||||
return { error: "definePrompt did not return an input" };
|
||||
}
|
||||
|
||||
const validationResult = validate(prompt.input, provider.inputSchema);
|
||||
if (!validationResult.valid)
|
||||
return {
|
||||
error: `definePrompt did not return a valid input: ${validationResult.errors
|
||||
.map((e) => e.stack)
|
||||
.join(", ")}`,
|
||||
};
|
||||
|
||||
// We've validated the JSON schema so this should be safe
|
||||
const input = prompt.input as Parameters<(typeof provider)["getModel"]>[0];
|
||||
|
||||
const model = provider.getModel(input);
|
||||
if (!model) {
|
||||
return {
|
||||
error: `definePrompt did not return a known model for the provider ${prompt.modelProvider}`,
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
modelProvider: prompt.modelProvider as keyof typeof modelProviders,
|
||||
model,
|
||||
modelInput: input,
|
||||
};
|
||||
} catch (e) {
|
||||
const msg =
|
||||
isObject(e) && "message" in e && isString(e.message)
|
||||
? e.message
|
||||
: "unknown error parsing definePrompt script";
|
||||
return { error: msg };
|
||||
}
|
||||
}
|
||||
@@ -1,7 +0,0 @@
|
||||
import { isObject } from "lodash-es";
|
||||
import { type JSONSerializable } from "../types";
|
||||
|
||||
export const shouldStream = (config: JSONSerializable): boolean => {
|
||||
const shouldStream = isObject(config) && "stream" in config && config.stream === true;
|
||||
return shouldStream;
|
||||
};
|
||||
Reference in New Issue
Block a user