Rename constructFn to promptConstructor

It's a clearer name. Also reorganize the filesystem so all the promptConstructor related files are colocated.
This commit is contained in:
Kyle Corbitt
2023-08-04 23:09:39 -07:00
parent 01dcbfc896
commit e10589abff
24 changed files with 111 additions and 74 deletions

View File

@@ -0,0 +1,13 @@
/*
Warnings:
- You are about to drop the column `constructFn` on the `PromptVariant` table. All the data in the column will be lost.
- You are about to drop the column `constructFnVersion` on the `PromptVariant` table. All the data in the column will be lost.
- Added the required column `promptConstructor` to the `PromptVariant` table without a default value. This is not possible if the table is not empty.
- Added the required column `promptConstructorVersion` to the `PromptVariant` table without a default value. This is not possible if the table is not empty.
*/
-- AlterTable
ALTER TABLE "PromptVariant" RENAME COLUMN "constructFn" TO "promptConstructor";
ALTER TABLE "PromptVariant" RENAME COLUMN "constructFnVersion" TO "promptConstructorVersion";

View File

@@ -32,8 +32,8 @@ model PromptVariant {
id String @id @default(uuid()) @db.Uuid
label String
constructFn String
constructFnVersion Int
promptConstructor String
promptConstructorVersion Int
model String
modelProvider String

View File

@@ -1,6 +1,7 @@
import { prisma } from "~/server/db";
import dedent from "dedent";
import { generateNewCell } from "~/server/utils/generateNewCell";
import { promptConstructorVersion } from "~/promptConstructor/version";
const defaultId = "11111111-1111-1111-1111-111111111111";
@@ -51,8 +52,8 @@ await prisma.promptVariant.createMany({
sortIndex: 0,
model: "gpt-3.5-turbo-0613",
modelProvider: "openai/ChatCompletion",
constructFnVersion: 1,
constructFn: dedent`
promptConstructorVersion,
promptConstructor: dedent`
definePrompt("openai/ChatCompletion", {
model: "gpt-3.5-turbo-0613",
messages: [
@@ -70,8 +71,8 @@ await prisma.promptVariant.createMany({
sortIndex: 1,
model: "gpt-3.5-turbo-0613",
modelProvider: "openai/ChatCompletion",
constructFnVersion: 1,
constructFn: dedent`
promptConstructorVersion,
promptConstructor: dedent`
definePrompt("openai/ChatCompletion", {
model: "gpt-3.5-turbo-0613",
messages: [

View File

@@ -3,6 +3,7 @@ import { generateNewCell } from "~/server/utils/generateNewCell";
import dedent from "dedent";
import { execSync } from "child_process";
import fs from "fs";
import { promptConstructorVersion } from "~/promptConstructor/version";
const defaultId = "11111111-1111-1111-1111-111111111112";
@@ -98,8 +99,8 @@ for (const dataset of datasets) {
sortIndex: 0,
model: "gpt-3.5-turbo-0613",
modelProvider: "openai/ChatCompletion",
constructFnVersion: 1,
constructFn: dedent`
promptConstructorVersion,
promptConstructor: dedent`
definePrompt("openai/ChatCompletion", {
model: "gpt-3.5-turbo-0613",
messages: [

View File

@@ -2,6 +2,7 @@ import { prisma } from "~/server/db";
import dedent from "dedent";
import fs from "fs";
import { parse } from "csv-parse/sync";
import { promptConstructorVersion } from "~/promptConstructor/version";
const defaultId = "11111111-1111-1111-1111-111111111112";
@@ -85,8 +86,8 @@ await prisma.promptVariant.createMany({
sortIndex: 0,
model: "gpt-3.5-turbo-0613",
modelProvider: "openai/ChatCompletion",
constructFnVersion: 1,
constructFn: dedent`
promptConstructorVersion,
promptConstructor: dedent`
definePrompt("openai/ChatCompletion", {
model: "gpt-3.5-turbo-0613",
messages: [

View File

@@ -5,8 +5,8 @@ set -e
echo "Migrating the database"
pnpm prisma migrate deploy
echo "Migrating constructFns"
pnpm tsx src/server/migratePrompts/index.ts
echo "Migrating promptConstructors"
pnpm tsx src/promptConstructor/migrate.ts
echo "Starting the server"

View File

@@ -68,7 +68,7 @@ export const ChangeModelModal = ({
return;
await replaceVariantMutation.mutateAsync({
id: variant.id,
constructFn: modifiedPromptFn,
promptConstructor: modifiedPromptFn,
streamScenarios: visibleScenarios,
});
await utils.promptVariants.list.invalidate();
@@ -107,7 +107,7 @@ export const ChangeModelModal = ({
<ModelSearch selectedModel={selectedModel} setSelectedModel={setSelectedModel} />
{isString(modifiedPromptFn) && (
<CompareFunctions
originalFunction={variant.constructFn}
originalFunction={variant.promptConstructor}
newFunction={modifiedPromptFn}
leftTitle={originalLabel}
rightTitle={convertedLabel}

View File

@@ -47,7 +47,7 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
return () => window.removeEventListener("keydown", handleEsc);
}, [isFullscreen, toggleFullscreen]);
const lastSavedFn = props.variant.constructFn;
const lastSavedFn = props.variant.promptConstructor;
const modifierKey = useModifierKeyLabel();
@@ -96,7 +96,7 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
const resp = await replaceVariant.mutateAsync({
id: props.variant.id,
constructFn: currentFn,
promptConstructor: currentFn,
streamScenarios: visibleScenarios,
});
if (resp.status === "error") {

View File

@@ -73,7 +73,7 @@ export const RefinePromptModal = ({
return;
await replaceVariantMutation.mutateAsync({
id: variant.id,
constructFn: refinedPromptFn,
promptConstructor: refinedPromptFn,
streamScenarios: visibleScenarios,
});
await utils.promptVariants.list.invalidate();
@@ -126,7 +126,7 @@ export const RefinePromptModal = ({
/>
</VStack>
<CompareFunctions
originalFunction={variant.constructFn}
originalFunction={variant.promptConstructor}
newFunction={isString(refinedPromptFn) ? refinedPromptFn : undefined}
maxH="40vh"
/>

View File

@@ -1,5 +1,5 @@
import { expect, test } from "vitest";
import { stripTypes } from "./formatPromptConstructor";
import { stripTypes } from "./format";
test("stripTypes", () => {
expect(stripTypes(`const foo: string = "bar";`)).toBe(`const foo = "bar";`);

View File

@@ -1,10 +1,10 @@
import "dotenv/config";
import dedent from "dedent";
import { expect, test } from "vitest";
import { migrate1to2, migrate2to3 } from ".";
import { migrate1to2, migrate2to3 } from "./migrate";
test("migrate1to2", () => {
const constructFn = dedent`
const promptConstructor = dedent`
// Test comment
prompt = {
@@ -18,7 +18,7 @@ test("migrate1to2", () => {
}
`;
const migrated = migrate1to2(constructFn);
const migrated = migrate1to2(promptConstructor);
expect(migrated).toBe(dedent`
// Test comment
@@ -35,7 +35,7 @@ test("migrate1to2", () => {
});
test("migrate2to3", () => {
const constructFn = dedent`
const promptConstructor = dedent`
// Test comment
definePrompt("anthropic", {
@@ -44,7 +44,7 @@ test("migrate2to3", () => {
})
`;
const migrated = migrate2to3(constructFn);
const migrated = migrate2to3(promptConstructor);
expect(migrated).toBe(dedent`
// Test comment

View File

@@ -1,8 +1,10 @@
import "dotenv/config";
import * as recast from "recast";
import { type ASTNode } from "ast-types";
import { prisma } from "../db";
import { fileURLToPath } from "url";
import parseConstructFn from "../utils/parseConstructFn";
import parsePromptConstructor from "./parse";
import { prisma } from "~/server/db";
import { promptConstructorVersion } from "./version";
const { builders: b } = recast.types;
export const migrate1to2 = (fnBody: string): string => {
@@ -40,7 +42,6 @@ export const migrate2to3 = (fnBody: string): string => {
recast.types.namedTypes.Literal.check(node.arguments[0]) &&
node.arguments[0].value === "anthropic"
) {
console.log('Migrating "anthropic" to "anthropic/completion"');
node.arguments[0].value = "anthropic/completion";
}
@@ -56,8 +57,12 @@ const migrations: Record<number, (fnBody: string) => string> = {
3: migrate2to3,
};
const applyMigrations = (constructFn: string, currentVersion: number, targetVersion: number) => {
let migratedFn = constructFn;
const applyMigrations = (
promptConstructor: string,
currentVersion: number,
targetVersion: number,
) => {
let migratedFn = promptConstructor;
for (let v = currentVersion + 1; v <= targetVersion; v++) {
const migrationFn = migrations[v];
@@ -71,16 +76,21 @@ const applyMigrations = (constructFn: string, currentVersion: number, targetVers
export default async function migrateConstructFns(targetVersion: number) {
const prompts = await prisma.promptVariant.findMany({
where: { constructFnVersion: { lt: targetVersion } },
where: { promptConstructorVersion: { lt: targetVersion } },
});
console.log(`Migrating ${prompts.length} prompts to version ${targetVersion}`);
await Promise.all(
prompts.map(async (variant) => {
const currentVersion = variant.constructFnVersion;
const currentVersion = variant.promptConstructorVersion;
try {
const migratedFn = applyMigrations(variant.constructFn, currentVersion, targetVersion);
const migratedFn = applyMigrations(
variant.promptConstructor,
currentVersion,
targetVersion,
);
const parsedFn = await parseConstructFn(migratedFn);
const parsedFn = await parsePromptConstructor(migratedFn);
if ("error" in parsedFn) {
throw new Error(parsedFn.error);
}
@@ -89,14 +99,14 @@ export default async function migrateConstructFns(targetVersion: number) {
id: variant.id,
},
data: {
constructFn: migratedFn,
constructFnVersion: targetVersion,
promptConstructor: migratedFn,
promptConstructorVersion: targetVersion,
modelProvider: parsedFn.modelProvider,
model: parsedFn.model,
},
});
} catch (e) {
console.error("Error migrating constructFn for variant", variant.id, e);
console.error("Error migrating promptConstructor for variant", variant.id, e);
}
}),
);
@@ -104,8 +114,12 @@ export default async function migrateConstructFns(targetVersion: number) {
// If we're running this file directly, run the migration to the latest version
if (process.argv.at(-1) === fileURLToPath(import.meta.url)) {
console.log("Running migration");
const latestVersion = Math.max(...Object.keys(migrations).map(Number));
await migrateConstructFns(latestVersion);
if (latestVersion !== promptConstructorVersion) {
throw new Error(
`The latest migration is ${latestVersion}, but the promptConstructorVersion is ${promptConstructorVersion}`,
);
}
await migrateConstructFns(promptConstructorVersion);
console.log("Done");
}

View File

@@ -1,11 +1,11 @@
import { expect, test } from "vitest";
import parseConstructFn from "./parseConstructFn";
import parsePromptConstructor from "./parse";
import assert from "assert";
// Note: this has to be run with `vitest --no-threads` option or else
// isolated-vm seems to throw errors
test("parseConstructFn", async () => {
const constructed = await parseConstructFn(
test("parsePromptConstructor", async () => {
const constructed = await parsePromptConstructor(
`
// These sometimes have a comment
@@ -38,7 +38,7 @@ test("parseConstructFn", async () => {
});
test("bad syntax", async () => {
const parsed = await parseConstructFn(`definePrompt("openai/ChatCompletion", {`);
const parsed = await parsePromptConstructor(`definePrompt("openai/ChatCompletion", {`);
assert("error" in parsed);
expect(parsed.error).toContain("Unexpected end of input");

View File

@@ -4,7 +4,7 @@ import { isObject, isString } from "lodash-es";
import { type JsonObject } from "type-fest";
import { validate } from "jsonschema";
export type ParsedConstructFn<T extends keyof typeof modelProviders> = {
export type ParsedPromptConstructor<T extends keyof typeof modelProviders> = {
modelProvider: T;
model: keyof (typeof modelProviders)[T]["models"];
modelInput: Parameters<(typeof modelProviders)[T]["getModel"]>[0];
@@ -12,12 +12,12 @@ export type ParsedConstructFn<T extends keyof typeof modelProviders> = {
const isolate = new ivm.Isolate({ memoryLimit: 128 });
export default async function parseConstructFn(
constructFn: string,
export default async function parsePromptConstructor(
promptConstructor: string,
scenario: JsonObject | undefined = {},
): Promise<ParsedConstructFn<keyof typeof modelProviders> | { error: string }> {
): Promise<ParsedPromptConstructor<keyof typeof modelProviders> | { error: string }> {
try {
const modifiedConstructFn = constructFn.replace(
const modifiedConstructFn = promptConstructor.replace(
"definePrompt(",
"global.prompt = definePrompt(",
);

View File

@@ -0,0 +1 @@
export const promptConstructorVersion = 3;

View File

@@ -51,7 +51,7 @@ export const autogenerateScenarioValues = async (
messages.push({
role: "user",
content: `Prompt constructor function:\n---\n${prompt.constructFn}`,
content: `Prompt constructor function:\n---\n${prompt.promptConstructor}`,
});
existingScenarios

View File

@@ -13,6 +13,7 @@ import {
} from "~/utils/accessControl";
import userOrg from "~/server/utils/userOrg";
import generateTypes from "~/modelProviders/generateTypes";
import { promptConstructorVersion } from "~/promptConstructor/version";
export const experimentsRouter = createTRPCRouter({
stats: publicProcedure.input(z.object({ id: z.string() })).query(async ({ input, ctx }) => {
@@ -318,7 +319,7 @@ export const experimentsRouter = createTRPCRouter({
sortIndex: 0,
// The interpolated $ is necessary until dedent incorporates
// https://github.com/dmnd/dedent/pull/46
constructFn: dedent`
promptConstructor: dedent`
/**
* Use Javascript to define an OpenAI chat completion
* (https://platform.openai.com/docs/api-reference/chat/create).
@@ -339,7 +340,7 @@ export const experimentsRouter = createTRPCRouter({
});`,
model: "gpt-3.5-turbo-0613",
modelProvider: "openai/ChatCompletion",
constructFnVersion: 2,
promptConstructorVersion,
},
}),
prisma.templateVariable.create({

View File

@@ -9,9 +9,10 @@ import { reorderPromptVariants } from "~/server/utils/reorderPromptVariants";
import { type PromptVariant } from "@prisma/client";
import { deriveNewConstructFn } from "~/server/utils/deriveNewContructFn";
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
import parseConstructFn from "~/server/utils/parseConstructFn";
import modelProviders from "~/modelProviders/modelProviders";
import { ZodSupportedProvider } from "~/modelProviders/types";
import parsePromptConstructor from "~/promptConstructor/parse";
import { promptConstructorVersion } from "~/promptConstructor/version";
export const promptVariantsRouter = createTRPCRouter({
list: publicProcedure
@@ -199,8 +200,9 @@ export const promptVariantsRouter = createTRPCRouter({
experimentId: input.experimentId,
label: newVariantLabel,
sortIndex: (originalVariant?.sortIndex ?? 0) + 1,
constructFn: newConstructFn,
constructFnVersion: 2,
promptConstructor: newConstructFn,
promptConstructorVersion:
originalVariant?.promptConstructorVersion ?? promptConstructorVersion,
model: originalVariant?.model ?? "gpt-3.5-turbo",
modelProvider: originalVariant?.modelProvider ?? "openai/ChatCompletion",
},
@@ -310,7 +312,7 @@ export const promptVariantsRouter = createTRPCRouter({
});
await requireCanModifyExperiment(existing.experimentId, ctx);
const constructedPrompt = await parseConstructFn(existing.constructFn);
const constructedPrompt = await parsePromptConstructor(existing.promptConstructor);
if ("error" in constructedPrompt) {
return userError(constructedPrompt.error);
@@ -332,7 +334,7 @@ export const promptVariantsRouter = createTRPCRouter({
.input(
z.object({
id: z.string(),
constructFn: z.string(),
promptConstructor: z.string(),
streamScenarios: z.array(z.string()),
}),
)
@@ -348,7 +350,7 @@ export const promptVariantsRouter = createTRPCRouter({
throw new Error(`Prompt Variant with id ${input.id} does not exist`);
}
const parsedPrompt = await parseConstructFn(input.constructFn);
const parsedPrompt = await parsePromptConstructor(input.promptConstructor);
if ("error" in parsedPrompt) {
return userError(parsedPrompt.error);
@@ -361,8 +363,8 @@ export const promptVariantsRouter = createTRPCRouter({
label: existing.label,
sortIndex: existing.sortIndex,
uiId: existing.uiId,
constructFn: input.constructFn,
constructFnVersion: 2,
promptConstructor: input.promptConstructor,
promptConstructorVersion: existing.promptConstructorVersion,
modelProvider: parsedPrompt.modelProvider,
model: parsedPrompt.model,
},

View File

@@ -5,8 +5,8 @@ import { prisma } from "~/server/db";
import { wsConnection } from "~/utils/wsConnection";
import { runEvalsForOutput } from "../utils/evaluations";
import hashPrompt from "../utils/hashPrompt";
import parseConstructFn from "../utils/parseConstructFn";
import defineTask from "./defineTask";
import parsePromptConstructor from "~/promptConstructor/parse";
export type QueryModelJob = {
cellId: string;
@@ -75,7 +75,10 @@ export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) =
return;
}
const prompt = await parseConstructFn(variant.constructFn, scenario.variableValues as JsonObject);
const prompt = await parsePromptConstructor(
variant.promptConstructor,
scenario.variableValues as JsonObject,
);
if ("error" in prompt) {
await prisma.scenarioVariantCell.update({

View File

@@ -4,7 +4,7 @@ import dedent from "dedent";
import { openai } from "./openai";
import { isObject } from "lodash-es";
import type { CreateChatCompletionRequestMessage } from "openai/resources/chat/completions";
import formatPromptConstructor from "~/utils/formatPromptConstructor";
import formatPromptConstructor from "~/promptConstructor/format";
import { type SupportedProvider, type Model } from "~/modelProviders/types";
import modelProviders from "~/modelProviders/modelProviders";
@@ -16,7 +16,7 @@ export async function deriveNewConstructFn(
instructions?: string,
) {
if (originalVariant && !newModel && !instructions) {
return originalVariant.constructFn;
return originalVariant.promptConstructor;
}
if (originalVariant && (newModel || instructions)) {
return await requestUpdatedPromptFunction(originalVariant, newModel, instructions);
@@ -55,7 +55,7 @@ const requestUpdatedPromptFunction = async (
},
{
role: "user",
content: `This is the current prompt constructor function:\n---\n${originalVariant.constructFn}`,
content: `This is the current prompt constructor function:\n---\n${originalVariant.promptConstructor}`,
},
];
if (newModel) {

View File

@@ -1,10 +1,10 @@
import { Prisma } from "@prisma/client";
import { prisma } from "../db";
import parseConstructFn from "./parseConstructFn";
import { type JsonObject } from "type-fest";
import hashPrompt from "./hashPrompt";
import { omit } from "lodash-es";
import { queueQueryModel } from "../tasks/queryModel.task";
import parsePromptConstructor from "~/promptConstructor/parse";
export const generateNewCell = async (
variantId: string,
@@ -41,8 +41,8 @@ export const generateNewCell = async (
if (cell) return;
const parsedConstructFn = await parseConstructFn(
variant.constructFn,
const parsedConstructFn = await parsePromptConstructor(
variant.promptConstructor,
scenario.variableValues as JsonObject,
);

View File

@@ -1,6 +1,6 @@
import crypto from "crypto";
import { type JsonValue } from "type-fest";
import { type ParsedConstructFn } from "./parseConstructFn";
import { ParsedPromptConstructor } from "~/promptConstructor/parse";
function sortKeys(obj: JsonValue): JsonValue {
if (typeof obj !== "object" || obj === null) {
@@ -25,7 +25,7 @@ function sortKeys(obj: JsonValue): JsonValue {
return sortedObj;
}
export default function hashPrompt(prompt: ParsedConstructFn<any>): string {
export default function hashPrompt(prompt: ParsedPromptConstructor<any>): string {
// Sort object keys recursively
const sortedObj = sortKeys(prompt as unknown as JsonValue);

View File

@@ -1,7 +1,7 @@
import { type RouterOutputs } from "~/utils/api";
import { type SliceCreator } from "./store";
import loader from "@monaco-editor/loader";
import formatPromptConstructor from "~/utils/formatPromptConstructor";
import formatPromptConstructor from "~/promptConstructor/format";
export const editorBackground = "#fafafa";