Prep for more model providers

Adds a `modelProvider` field to `promptVariants`, currently just set to "openai/ChatCompletion" for all variants for now.

Adds a `modelProviders/` directory where we can define and store pluggable model providers. Currently just OpenAI. Not everything is pluggable yet -- notably the code to actually generate completions hasn't been migrated to this setup yet.

Does a lot of work to get the types working. Prompts are now defined with a function `definePrompt(modelProvider, config)` instead of `prompt = config`. Added a script to migrate old prompt definitions.

This is still partial work, but the diff is large enough that I want to get it in. I don't think anything is broken but I haven't tested thoroughly.
This commit is contained in:
Kyle Corbitt
2023-07-20 14:47:39 -07:00
parent 2c8c8d07cf
commit ded6678e97
43 changed files with 1195 additions and 3023 deletions

View File

@@ -1,15 +0,0 @@
import { test } from "vitest";
import { constructPrompt } from "./constructPrompt";
test.skip("constructPrompt", async () => {
const constructed = await constructPrompt(
{
constructFn: `prompt = { "fooz": "bar" }`,
},
{
foo: "bar",
},
);
console.log(constructed);
});

View File

@@ -1,35 +0,0 @@
import { type PromptVariant, type TestScenario } from "@prisma/client";
import ivm from "isolated-vm";
import { type JSONSerializable } from "../types";
const isolate = new ivm.Isolate({ memoryLimit: 128 });
export async function constructPrompt(
variant: Pick<PromptVariant, "constructFn">,
scenario: TestScenario["variableValues"],
): Promise<JSONSerializable> {
const code = `
const scenario = ${JSON.stringify(scenario ?? {}, null, 2)};
let prompt = {};
${variant.constructFn}
global.prompt = prompt;
`;
console.log("code is", code);
const context = await isolate.createContext();
const jail = context.global;
await jail.set("global", jail.derefInto());
const script = await isolate.compileScript(code);
await script.run(context);
const promptReference = (await context.global.get("prompt")) as ivm.Reference;
const prompt = await promptReference.copy(); // Get the actual value from the isolate
return prompt as JSONSerializable;
}

View File

@@ -1,8 +1,9 @@
import crypto from "crypto";
import { type Prisma } from "@prisma/client";
import { prisma } from "../db";
import { queueLLMRetrievalTask } from "./queueLLMRetrievalTask";
import { constructPrompt } from "./constructPrompt";
import parseConstructFn from "./parseConstructFn";
import { type JsonObject } from "type-fest";
import hashPrompt from "./hashPrompt";
export const generateNewCell = async (variantId: string, scenarioId: string) => {
const variant = await prisma.promptVariant.findUnique({
@@ -19,10 +20,6 @@ export const generateNewCell = async (variantId: string, scenarioId: string) =>
if (!variant || !scenario) return null;
const prompt = await constructPrompt(variant, scenario.variableValues);
const inputHash = crypto.createHash("sha256").update(JSON.stringify(prompt)).digest("hex");
let cell = await prisma.scenarioVariantCell.findUnique({
where: {
promptVariantId_testScenarioId: {
@@ -37,10 +34,29 @@ export const generateNewCell = async (variantId: string, scenarioId: string) =>
if (cell) return cell;
const parsedConstructFn = await parseConstructFn(
variant.constructFn,
scenario.variableValues as JsonObject,
);
if ("error" in parsedConstructFn) {
return await prisma.scenarioVariantCell.create({
data: {
promptVariantId: variantId,
testScenarioId: scenarioId,
statusCode: 400,
errorMessage: parsedConstructFn.error,
},
});
}
const inputHash = hashPrompt(parsedConstructFn);
cell = await prisma.scenarioVariantCell.create({
data: {
promptVariantId: variantId,
testScenarioId: scenarioId,
prompt: parsedConstructFn.modelInput as unknown as Prisma.InputJsonValue,
},
include: {
modelOutput: true,
@@ -48,9 +64,7 @@ export const generateNewCell = async (variantId: string, scenarioId: string) =>
});
const matchingModelOutput = await prisma.modelOutput.findFirst({
where: {
inputHash,
},
where: { inputHash },
});
let newModelOutput;

View File

@@ -7,7 +7,7 @@ import { type SupportedModel, type OpenAIChatModel } from "../types";
import { env } from "~/env.mjs";
import { countOpenAIChatTokens } from "~/utils/countTokens";
import { rateLimitErrorMessage } from "~/sharedStrings";
import { modelStats } from "../modelStats";
import { modelStats } from "../../modelProviders/modelStats";
export type CompletionResponse = {
output: ChatCompletion | null;

View File

@@ -1,7 +1,6 @@
import { OpenAIChatModel, type SupportedModel } from "../types";
import openAIChatApiShape from "~/codegen/openai.types.ts.txt";
import { type SupportedModel } from "../types";
export const getApiShapeForModel = (model: SupportedModel) => {
if (model in OpenAIChatModel) return openAIChatApiShape;
// if (model in OpenAIChatModel) return openAIChatApiShape;
return "";
};

View File

@@ -0,0 +1,37 @@
import crypto from "crypto";
import { type JsonValue } from "type-fest";
import { type ParsedConstructFn } from "./parseConstructFn";
function sortKeys(obj: JsonValue): JsonValue {
if (typeof obj !== "object" || obj === null) {
// Not an object or array, return as is
return obj;
}
if (Array.isArray(obj)) {
return obj.map(sortKeys);
}
// Get keys and sort them
const keys = Object.keys(obj).sort();
const sortedObj = {};
for (const key of keys) {
// @ts-expect-error not worth fixing types
// eslint-disable-next-line @typescript-eslint/no-unsafe-argument
sortedObj[key] = sortKeys(obj[key]);
}
return sortedObj;
}
export default function hashPrompt(prompt: ParsedConstructFn<any>): string {
// Sort object keys recursively
const sortedObj = sortKeys(prompt as unknown as JsonValue);
// Convert to JSON and hash it
const str = JSON.stringify(sortedObj);
const hash = crypto.createHash("sha256");
hash.update(str);
return hash.digest("hex");
}

View File

@@ -0,0 +1,45 @@
import { expect, test } from "vitest";
import parseConstructFn from "./parseConstructFn";
import assert from "assert";
// Note: this has to be run with `vitest --no-threads` option or else
// isolated-vm seems to throw errors
test("parseConstructFn", async () => {
const constructed = await parseConstructFn(
`
// These sometimes have a comment
definePrompt("openai/ChatCompletion", {
model: "gpt-3.5-turbo-0613",
messages: [
{
role: "user",
content: \`What is the capital of \${scenario.country}?\`
}
]
})
`,
{ country: "Bolivia" },
);
expect(constructed).toEqual({
modelProvider: "openai/ChatCompletion",
model: "gpt-3.5-turbo-0613",
modelInput: {
messages: [
{
content: "What is the capital of Bolivia?",
role: "user",
},
],
model: "gpt-3.5-turbo-0613",
},
});
});
test("bad syntax", async () => {
const parsed = await parseConstructFn(`definePrompt("openai/ChatCompletion", {`);
assert("error" in parsed);
expect(parsed.error).toContain("Unexpected end of input");
});

View File

@@ -0,0 +1,92 @@
import modelProviders from "~/modelProviders";
import ivm from "isolated-vm";
import { isObject, isString } from "lodash-es";
import { type JsonObject } from "type-fest";
import { validate } from "jsonschema";
export type ParsedConstructFn<T extends keyof typeof modelProviders> = {
modelProvider: T;
model: keyof (typeof modelProviders)[T]["models"];
modelInput: Parameters<(typeof modelProviders)[T]["getModel"]>[0];
};
const isolate = new ivm.Isolate({ memoryLimit: 128 });
export default async function parseConstructFn(
constructFn: string,
scenario: JsonObject | undefined = {},
): Promise<ParsedConstructFn<keyof typeof modelProviders> | { error: string }> {
try {
const modifiedConstructFn = constructFn.replace(
"definePrompt(",
"global.prompt = definePrompt(",
);
const code = `
const scenario = ${JSON.stringify(scenario ?? {}, null, 2)};
const definePrompt = (modelProvider, input) => ({
modelProvider,
input
})
${modifiedConstructFn}
`;
const context = await isolate.createContext();
const jail = context.global;
await jail.set("global", jail.derefInto());
const script = await isolate.compileScript(code);
await script.run(context);
const promptReference = (await context.global.get("prompt")) as ivm.Reference;
const prompt = await promptReference.copy();
if (!isObject(prompt)) {
return { error: "definePrompt did not return an object" };
}
if (!("modelProvider" in prompt) || !isString(prompt.modelProvider)) {
return { error: "definePrompt did not return a valid modelProvider" };
}
const provider =
prompt.modelProvider in modelProviders &&
modelProviders[prompt.modelProvider as keyof typeof modelProviders];
if (!provider) {
return { error: "definePrompt did not return a known modelProvider" };
}
if (!("input" in prompt) || !isObject(prompt.input)) {
return { error: "definePrompt did not return an input" };
}
const validationResult = validate(prompt.input, provider.inputSchema);
if (!validationResult.valid)
return {
error: `definePrompt did not return a valid input: ${validationResult.errors
.map((e) => e.stack)
.join(", ")}`,
};
// We've validated the JSON schema so this should be safe
const input = prompt.input as Parameters<(typeof provider)["getModel"]>[0];
const model = provider.getModel(input);
if (!model) {
return {
error: `definePrompt did not return a known model for the provider ${prompt.modelProvider}`,
};
}
return {
modelProvider: prompt.modelProvider as keyof typeof modelProviders,
model,
modelInput: input,
};
} catch (e) {
const msg =
isObject(e) && "message" in e && isString(e.message)
? e.message
: "unknown error parsing definePrompt script";
return { error: msg };
}
}

View File

@@ -1,7 +0,0 @@
import { isObject } from "lodash-es";
import { type JSONSerializable } from "../types";
export const shouldStream = (config: JSONSerializable): boolean => {
const shouldStream = isObject(config) && "stream" in config && config.stream === true;
return shouldStream;
};