Rename 'anthropic' to 'anthropic/completion' (#120)

More consistency in the way we name our model providers.
This commit is contained in:
Kyle Corbitt
2023-08-04 22:07:23 -07:00
committed by GitHub
parent 50e0b34d30
commit 01dcbfc896
14 changed files with 145 additions and 77 deletions

View File

@@ -18,7 +18,8 @@
"start": "next start",
"codegen": "tsx src/codegen/export-openai-types.ts",
"seed": "tsx prisma/seed.ts",
"check": "concurrently 'pnpm lint' 'pnpm tsc' 'pnpm prettier . --check'"
"check": "concurrently 'pnpm lint' 'pnpm tsc' 'pnpm prettier . --check'",
"test": "pnpm vitest --no-threads"
},
"dependencies": {
"@anthropic-ai/sdk": "^0.5.8",

View File

@@ -5,6 +5,9 @@ set -e
echo "Migrating the database"
pnpm prisma migrate deploy
echo "Migrating constructFns"
pnpm tsx src/server/migratePrompts/index.ts
echo "Starting the server"
pnpm concurrently --kill-others \

View File

@@ -13,7 +13,7 @@ const frontendModelProvider: FrontendModelProvider<SupportedModel, Completion> =
promptTokenPrice: 11.02 / 1000000,
completionTokenPrice: 32.68 / 1000000,
speed: "medium",
provider: "anthropic",
provider: "anthropic/completion",
learnMoreUrl: "https://www.anthropic.com/product",
apiDocsUrl: "https://docs.anthropic.com/claude/reference/complete_post",
},
@@ -23,7 +23,7 @@ const frontendModelProvider: FrontendModelProvider<SupportedModel, Completion> =
promptTokenPrice: 1.63 / 1000000,
completionTokenPrice: 5.51 / 1000000,
speed: "fast",
provider: "anthropic",
provider: "anthropic/completion",
learnMoreUrl: "https://www.anthropic.com/product",
apiDocsUrl: "https://docs.anthropic.com/claude/reference/complete_post",
},

View File

@@ -1,6 +1,6 @@
import openaiChatCompletionFrontend from "./openai-ChatCompletion/frontend";
import replicateLlama2Frontend from "./replicate-llama2/frontend";
import anthropicFrontend from "./anthropic/frontend";
import anthropicFrontend from "./anthropic-completion/frontend";
import { type SupportedProvider, type FrontendModelProvider } from "./types";
// Keep attributes here that need to be accessible from the frontend. We can't
@@ -9,7 +9,7 @@ import { type SupportedProvider, type FrontendModelProvider } from "./types";
const frontendModelProviders: Record<SupportedProvider, FrontendModelProvider<any, any>> = {
"openai/ChatCompletion": openaiChatCompletionFrontend,
"replicate/llama2": replicateLlama2Frontend,
anthropic: anthropicFrontend,
"anthropic/completion": anthropicFrontend,
};
export default frontendModelProviders;

View File

@@ -1,12 +1,12 @@
import openaiChatCompletion from "./openai-ChatCompletion";
import replicateLlama2 from "./replicate-llama2";
import anthropic from "./anthropic";
import anthropicCompletion from "./anthropic-completion";
import { type SupportedProvider, type ModelProvider } from "./types";
const modelProviders: Record<SupportedProvider, ModelProvider<any, any, any>> = {
"openai/ChatCompletion": openaiChatCompletion,
"replicate/llama2": replicateLlama2,
anthropic,
"anthropic/completion": anthropicCompletion,
};
export default modelProviders;

View File

@@ -6,7 +6,7 @@ import { z } from "zod";
export const ZodSupportedProvider = z.union([
z.literal("openai/ChatCompletion"),
z.literal("replicate/llama2"),
z.literal("anthropic"),
z.literal("anthropic/completion"),
]);
export type SupportedProvider = z.infer<typeof ZodSupportedProvider>;

View File

@@ -1,7 +1,7 @@
import "dotenv/config";
import dedent from "dedent";
import { expect, test } from "vitest";
import { migrate1to2 } from "./migrateConstructFns";
import { migrate1to2, migrate2to3 } from ".";
test("migrate1to2", () => {
const constructFn = dedent`
@@ -32,14 +32,25 @@ test("migrate1to2", () => {
]
})
`);
// console.log(
// migrateConstructFn(dedent`definePrompt(
// "openai/ChatCompletion",
// {
// model: 'gpt-3.5-turbo-0613',
// messages: []
// }
// )`),
// );
});
test("migrate2to3", () => {
const constructFn = dedent`
// Test comment
definePrompt("anthropic", {
model: "claude-2.0",
prompt: "What is the capital of China?"
})
`;
const migrated = migrate2to3(constructFn);
expect(migrated).toBe(dedent`
// Test comment
definePrompt("anthropic/completion", {
model: "claude-2.0",
prompt: "What is the capital of China?"
})
`);
});

View File

@@ -0,0 +1,111 @@
import * as recast from "recast";
import { type ASTNode } from "ast-types";
import { prisma } from "../db";
import { fileURLToPath } from "url";
import parseConstructFn from "../utils/parseConstructFn";
const { builders: b } = recast.types;
export const migrate1to2 = (fnBody: string): string => {
const ast: ASTNode = recast.parse(fnBody);
recast.visit(ast, {
visitAssignmentExpression(path) {
const node = path.node;
if ("name" in node.left && node.left.name === "prompt") {
const functionCall = b.callExpression(b.identifier("definePrompt"), [
b.literal("openai/ChatCompletion"),
node.right,
]);
path.replace(functionCall);
}
return false;
},
});
return recast.print(ast).code;
};
export const migrate2to3 = (fnBody: string): string => {
const ast: ASTNode = recast.parse(fnBody);
recast.visit(ast, {
visitCallExpression(path) {
const node = path.node;
// Check if the function being called is 'definePrompt'
if (
recast.types.namedTypes.Identifier.check(node.callee) &&
node.callee.name === "definePrompt" &&
node.arguments.length > 0 &&
recast.types.namedTypes.Literal.check(node.arguments[0]) &&
node.arguments[0].value === "anthropic"
) {
console.log('Migrating "anthropic" to "anthropic/completion"');
node.arguments[0].value = "anthropic/completion";
}
return false;
},
});
return recast.print(ast).code;
};
const migrations: Record<number, (fnBody: string) => string> = {
2: migrate1to2,
3: migrate2to3,
};
const applyMigrations = (constructFn: string, currentVersion: number, targetVersion: number) => {
let migratedFn = constructFn;
for (let v = currentVersion + 1; v <= targetVersion; v++) {
const migrationFn = migrations[v];
if (migrationFn) {
migratedFn = migrationFn(migratedFn);
}
}
return migratedFn;
};
export default async function migrateConstructFns(targetVersion: number) {
const prompts = await prisma.promptVariant.findMany({
where: { constructFnVersion: { lt: targetVersion } },
});
await Promise.all(
prompts.map(async (variant) => {
const currentVersion = variant.constructFnVersion;
try {
const migratedFn = applyMigrations(variant.constructFn, currentVersion, targetVersion);
const parsedFn = await parseConstructFn(migratedFn);
if ("error" in parsedFn) {
throw new Error(parsedFn.error);
}
await prisma.promptVariant.update({
where: {
id: variant.id,
},
data: {
constructFn: migratedFn,
constructFnVersion: targetVersion,
modelProvider: parsedFn.modelProvider,
model: parsedFn.model,
},
});
} catch (e) {
console.error("Error migrating constructFn for variant", variant.id, e);
}
}),
);
}
// If we're running this file directly, run the migration to the latest version
if (process.argv.at(-1) === fileURLToPath(import.meta.url)) {
console.log("Running migration");
const latestVersion = Math.max(...Object.keys(migrations).map(Number));
await migrateConstructFns(latestVersion);
console.log("Done");
}

View File

@@ -1,58 +0,0 @@
import * as recast from "recast";
import { type ASTNode } from "ast-types";
import { prisma } from "../db";
import { fileURLToPath } from "url";
const { builders: b } = recast.types;
export const migrate1to2 = (fnBody: string): string => {
const ast: ASTNode = recast.parse(fnBody);
recast.visit(ast, {
visitAssignmentExpression(path) {
const node = path.node;
if ("name" in node.left && node.left.name === "prompt") {
const functionCall = b.callExpression(b.identifier("definePrompt"), [
b.literal("openai/ChatCompletion"),
node.right,
]);
path.replace(functionCall);
}
return false;
},
});
return recast.print(ast).code;
};
export default async function migrateConstructFns() {
const v1Prompts = await prisma.promptVariant.findMany({
where: {
constructFnVersion: 1,
},
});
console.log(`Migrating ${v1Prompts.length} prompts 1->2`);
await Promise.all(
v1Prompts.map(async (variant) => {
try {
await prisma.promptVariant.update({
where: {
id: variant.id,
},
data: {
constructFn: migrate1to2(variant.constructFn),
constructFnVersion: 2,
},
});
} catch (e) {
console.error("Error migrating constructFn for variant", variant.id, e);
}
}),
);
}
// If we're running this file directly, run the migration
if (process.argv.at(-1) === fileURLToPath(import.meta.url)) {
console.log("Running migration");
await migrateConstructFns();
console.log("Done");
}