diff --git a/package.json b/package.json index a4ce6de..1855f60 100644 --- a/package.json +++ b/package.json @@ -18,7 +18,8 @@ "start": "next start", "codegen": "tsx src/codegen/export-openai-types.ts", "seed": "tsx prisma/seed.ts", - "check": "concurrently 'pnpm lint' 'pnpm tsc' 'pnpm prettier . --check'" + "check": "concurrently 'pnpm lint' 'pnpm tsc' 'pnpm prettier . --check'", + "test": "pnpm vitest --no-threads" }, "dependencies": { "@anthropic-ai/sdk": "^0.5.8", diff --git a/run-prod.sh b/run-prod.sh index 30ec09c..c10759f 100755 --- a/run-prod.sh +++ b/run-prod.sh @@ -5,6 +5,9 @@ set -e echo "Migrating the database" pnpm prisma migrate deploy +echo "Migrating constructFns" +pnpm tsx src/server/migratePrompts/index.ts + echo "Starting the server" pnpm concurrently --kill-others \ diff --git a/src/modelProviders/anthropic/codegen/codegen.ts b/src/modelProviders/anthropic-completion/codegen/codegen.ts similarity index 100% rename from src/modelProviders/anthropic/codegen/codegen.ts rename to src/modelProviders/anthropic-completion/codegen/codegen.ts diff --git a/src/modelProviders/anthropic/codegen/input.schema.json b/src/modelProviders/anthropic-completion/codegen/input.schema.json similarity index 100% rename from src/modelProviders/anthropic/codegen/input.schema.json rename to src/modelProviders/anthropic-completion/codegen/input.schema.json diff --git a/src/modelProviders/anthropic/frontend.ts b/src/modelProviders/anthropic-completion/frontend.ts similarity index 93% rename from src/modelProviders/anthropic/frontend.ts rename to src/modelProviders/anthropic-completion/frontend.ts index a28aa83..ffd6f95 100644 --- a/src/modelProviders/anthropic/frontend.ts +++ b/src/modelProviders/anthropic-completion/frontend.ts @@ -13,7 +13,7 @@ const frontendModelProvider: FrontendModelProvider = promptTokenPrice: 11.02 / 1000000, completionTokenPrice: 32.68 / 1000000, speed: "medium", - provider: "anthropic", + provider: "anthropic/completion", learnMoreUrl: "https://www.anthropic.com/product", apiDocsUrl: "https://docs.anthropic.com/claude/reference/complete_post", }, @@ -23,7 +23,7 @@ const frontendModelProvider: FrontendModelProvider = promptTokenPrice: 1.63 / 1000000, completionTokenPrice: 5.51 / 1000000, speed: "fast", - provider: "anthropic", + provider: "anthropic/completion", learnMoreUrl: "https://www.anthropic.com/product", apiDocsUrl: "https://docs.anthropic.com/claude/reference/complete_post", }, diff --git a/src/modelProviders/anthropic/getCompletion.ts b/src/modelProviders/anthropic-completion/getCompletion.ts similarity index 100% rename from src/modelProviders/anthropic/getCompletion.ts rename to src/modelProviders/anthropic-completion/getCompletion.ts diff --git a/src/modelProviders/anthropic/index.ts b/src/modelProviders/anthropic-completion/index.ts similarity index 100% rename from src/modelProviders/anthropic/index.ts rename to src/modelProviders/anthropic-completion/index.ts diff --git a/src/modelProviders/anthropic/refinementActions.ts b/src/modelProviders/anthropic-completion/refinementActions.ts similarity index 100% rename from src/modelProviders/anthropic/refinementActions.ts rename to src/modelProviders/anthropic-completion/refinementActions.ts diff --git a/src/modelProviders/frontendModelProviders.ts b/src/modelProviders/frontendModelProviders.ts index fd8ba57..9950e36 100644 --- a/src/modelProviders/frontendModelProviders.ts +++ b/src/modelProviders/frontendModelProviders.ts @@ -1,6 +1,6 @@ import openaiChatCompletionFrontend from "./openai-ChatCompletion/frontend"; import replicateLlama2Frontend from "./replicate-llama2/frontend"; -import anthropicFrontend from "./anthropic/frontend"; +import anthropicFrontend from "./anthropic-completion/frontend"; import { type SupportedProvider, type FrontendModelProvider } from "./types"; // Keep attributes here that need to be accessible from the frontend. We can't @@ -9,7 +9,7 @@ import { type SupportedProvider, type FrontendModelProvider } from "./types"; const frontendModelProviders: Record> = { "openai/ChatCompletion": openaiChatCompletionFrontend, "replicate/llama2": replicateLlama2Frontend, - anthropic: anthropicFrontend, + "anthropic/completion": anthropicFrontend, }; export default frontendModelProviders; diff --git a/src/modelProviders/modelProviders.ts b/src/modelProviders/modelProviders.ts index 044bc98..ae5bac4 100644 --- a/src/modelProviders/modelProviders.ts +++ b/src/modelProviders/modelProviders.ts @@ -1,12 +1,12 @@ import openaiChatCompletion from "./openai-ChatCompletion"; import replicateLlama2 from "./replicate-llama2"; -import anthropic from "./anthropic"; +import anthropicCompletion from "./anthropic-completion"; import { type SupportedProvider, type ModelProvider } from "./types"; const modelProviders: Record> = { "openai/ChatCompletion": openaiChatCompletion, "replicate/llama2": replicateLlama2, - anthropic, + "anthropic/completion": anthropicCompletion, }; export default modelProviders; diff --git a/src/modelProviders/types.ts b/src/modelProviders/types.ts index 842d278..5e5bf26 100644 --- a/src/modelProviders/types.ts +++ b/src/modelProviders/types.ts @@ -6,7 +6,7 @@ import { z } from "zod"; export const ZodSupportedProvider = z.union([ z.literal("openai/ChatCompletion"), z.literal("replicate/llama2"), - z.literal("anthropic"), + z.literal("anthropic/completion"), ]); export type SupportedProvider = z.infer; diff --git a/src/server/scripts/migrateConstructFns.test.ts b/src/server/migratePrompts/index.test.ts similarity index 58% rename from src/server/scripts/migrateConstructFns.test.ts rename to src/server/migratePrompts/index.test.ts index 01e9f2e..1a2baeb 100644 --- a/src/server/scripts/migrateConstructFns.test.ts +++ b/src/server/migratePrompts/index.test.ts @@ -1,7 +1,7 @@ import "dotenv/config"; import dedent from "dedent"; import { expect, test } from "vitest"; -import { migrate1to2 } from "./migrateConstructFns"; +import { migrate1to2, migrate2to3 } from "."; test("migrate1to2", () => { const constructFn = dedent` @@ -32,14 +32,25 @@ test("migrate1to2", () => { ] }) `); - - // console.log( - // migrateConstructFn(dedent`definePrompt( - // "openai/ChatCompletion", - // { - // model: 'gpt-3.5-turbo-0613', - // messages: [] - // } - // )`), - // ); +}); + +test("migrate2to3", () => { + const constructFn = dedent` + // Test comment + + definePrompt("anthropic", { + model: "claude-2.0", + prompt: "What is the capital of China?" + }) + `; + + const migrated = migrate2to3(constructFn); + expect(migrated).toBe(dedent` + // Test comment + + definePrompt("anthropic/completion", { + model: "claude-2.0", + prompt: "What is the capital of China?" + }) + `); }); diff --git a/src/server/migratePrompts/index.ts b/src/server/migratePrompts/index.ts new file mode 100644 index 0000000..e48d01d --- /dev/null +++ b/src/server/migratePrompts/index.ts @@ -0,0 +1,111 @@ +import * as recast from "recast"; +import { type ASTNode } from "ast-types"; +import { prisma } from "../db"; +import { fileURLToPath } from "url"; +import parseConstructFn from "../utils/parseConstructFn"; +const { builders: b } = recast.types; + +export const migrate1to2 = (fnBody: string): string => { + const ast: ASTNode = recast.parse(fnBody); + + recast.visit(ast, { + visitAssignmentExpression(path) { + const node = path.node; + if ("name" in node.left && node.left.name === "prompt") { + const functionCall = b.callExpression(b.identifier("definePrompt"), [ + b.literal("openai/ChatCompletion"), + node.right, + ]); + path.replace(functionCall); + } + return false; + }, + }); + + return recast.print(ast).code; +}; + +export const migrate2to3 = (fnBody: string): string => { + const ast: ASTNode = recast.parse(fnBody); + + recast.visit(ast, { + visitCallExpression(path) { + const node = path.node; + + // Check if the function being called is 'definePrompt' + if ( + recast.types.namedTypes.Identifier.check(node.callee) && + node.callee.name === "definePrompt" && + node.arguments.length > 0 && + recast.types.namedTypes.Literal.check(node.arguments[0]) && + node.arguments[0].value === "anthropic" + ) { + console.log('Migrating "anthropic" to "anthropic/completion"'); + node.arguments[0].value = "anthropic/completion"; + } + + return false; + }, + }); + + return recast.print(ast).code; +}; + +const migrations: Record string> = { + 2: migrate1to2, + 3: migrate2to3, +}; + +const applyMigrations = (constructFn: string, currentVersion: number, targetVersion: number) => { + let migratedFn = constructFn; + + for (let v = currentVersion + 1; v <= targetVersion; v++) { + const migrationFn = migrations[v]; + if (migrationFn) { + migratedFn = migrationFn(migratedFn); + } + } + + return migratedFn; +}; + +export default async function migrateConstructFns(targetVersion: number) { + const prompts = await prisma.promptVariant.findMany({ + where: { constructFnVersion: { lt: targetVersion } }, + }); + await Promise.all( + prompts.map(async (variant) => { + const currentVersion = variant.constructFnVersion; + + try { + const migratedFn = applyMigrations(variant.constructFn, currentVersion, targetVersion); + + const parsedFn = await parseConstructFn(migratedFn); + if ("error" in parsedFn) { + throw new Error(parsedFn.error); + } + await prisma.promptVariant.update({ + where: { + id: variant.id, + }, + data: { + constructFn: migratedFn, + constructFnVersion: targetVersion, + modelProvider: parsedFn.modelProvider, + model: parsedFn.model, + }, + }); + } catch (e) { + console.error("Error migrating constructFn for variant", variant.id, e); + } + }), + ); +} + +// If we're running this file directly, run the migration to the latest version +if (process.argv.at(-1) === fileURLToPath(import.meta.url)) { + console.log("Running migration"); + const latestVersion = Math.max(...Object.keys(migrations).map(Number)); + await migrateConstructFns(latestVersion); + console.log("Done"); +} diff --git a/src/server/scripts/migrateConstructFns.ts b/src/server/scripts/migrateConstructFns.ts deleted file mode 100644 index 8c25051..0000000 --- a/src/server/scripts/migrateConstructFns.ts +++ /dev/null @@ -1,58 +0,0 @@ -import * as recast from "recast"; -import { type ASTNode } from "ast-types"; -import { prisma } from "../db"; -import { fileURLToPath } from "url"; -const { builders: b } = recast.types; - -export const migrate1to2 = (fnBody: string): string => { - const ast: ASTNode = recast.parse(fnBody); - - recast.visit(ast, { - visitAssignmentExpression(path) { - const node = path.node; - if ("name" in node.left && node.left.name === "prompt") { - const functionCall = b.callExpression(b.identifier("definePrompt"), [ - b.literal("openai/ChatCompletion"), - node.right, - ]); - path.replace(functionCall); - } - return false; - }, - }); - - return recast.print(ast).code; -}; - -export default async function migrateConstructFns() { - const v1Prompts = await prisma.promptVariant.findMany({ - where: { - constructFnVersion: 1, - }, - }); - console.log(`Migrating ${v1Prompts.length} prompts 1->2`); - await Promise.all( - v1Prompts.map(async (variant) => { - try { - await prisma.promptVariant.update({ - where: { - id: variant.id, - }, - data: { - constructFn: migrate1to2(variant.constructFn), - constructFnVersion: 2, - }, - }); - } catch (e) { - console.error("Error migrating constructFn for variant", variant.id, e); - } - }), - ); -} - -// If we're running this file directly, run the migration -if (process.argv.at(-1) === fileURLToPath(import.meta.url)) { - console.log("Running migration"); - await migrateConstructFns(); - console.log("Done"); -}