diff --git a/package.json b/package.json
index 9ae22fa..ceec72b 100644
--- a/package.json
+++ b/package.json
@@ -73,6 +73,7 @@
     "react-syntax-highlighter": "^15.5.0",
     "react-textarea-autosize": "^8.5.0",
     "recast": "^0.23.3",
+    "replicate": "^0.12.3",
     "socket.io": "^4.7.1",
     "socket.io-client": "^4.7.1",
     "superjson": "1.12.2",
diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml
index a563e66..d2d34b7 100644
--- a/pnpm-lock.yaml
+++ b/pnpm-lock.yaml
@@ -161,6 +161,9 @@ dependencies:
   recast:
     specifier: ^0.23.3
     version: 0.23.3
+  replicate:
+    specifier: ^0.12.3
+    version: 0.12.3
   socket.io:
     specifier: ^4.7.1
     version: 4.7.1
@@ -6988,6 +6991,11 @@ packages:
     functions-have-names: 1.2.3
     dev: true

+  /replicate@0.12.3:
+    resolution: {integrity: sha512-HVWKPoVhWVTONlWk+lUXmq9Vy2J8MxBJMtDBQq3dA5uq71ZzKTh0xvJfvzW4+VLBjhBeL7tkdua6hZJmKfzAPQ==}
+    engines: {git: '>=2.11.0', node: '>=16.6.0', npm: '>=7.19.0', yarn: '>=1.7.0'}
+    dev: false
+
   /require-directory@2.1.1:
     resolution: {integrity: sha512-fGxEI7+wsG9xrvdjsrlmL22OMTTiHRwAMroiEeMgq8gzoLC/PQr7RsRDSTLUg/bZAZtF+TVIkHc6/4RIKrui+Q==}
     engines: {node: '>=0.10.0'}
diff --git a/src/components/OutputsTable/OutputCell/OutputCell.tsx b/src/components/OutputsTable/OutputCell/OutputCell.tsx
index c9b2c00..d9b9d81 100644
--- a/src/components/OutputsTable/OutputCell/OutputCell.tsx
+++ b/src/components/OutputsTable/OutputCell/OutputCell.tsx
@@ -88,9 +88,11 @@ export default function OutputCell({
   }

   const normalizedOutput = modelOutput
-    ? provider.normalizeOutput(modelOutput.output as unknown as OutputSchema)
+    ? // @ts-expect-error TODO FIX ASAP
+      provider.normalizeOutput(modelOutput.output as unknown as OutputSchema)
     : streamedMessage
-    ? provider.normalizeOutput(streamedMessage)
+    ? // @ts-expect-error TODO FIX ASAP
+      provider.normalizeOutput(streamedMessage)
     : null;

   if (modelOutput && normalizedOutput?.type === "json") {
diff --git a/src/env.mjs b/src/env.mjs
index 2032c08..8458917 100644
--- a/src/env.mjs
+++ b/src/env.mjs
@@ -17,6 +17,7 @@ export const env = createEnv({
       .transform((val) => val.toLowerCase() === "true"),
     GITHUB_CLIENT_ID: z.string().min(1),
     GITHUB_CLIENT_SECRET: z.string().min(1),
+    REPLICATE_API_TOKEN: z.string().min(1),
   },

   /**
@@ -42,6 +43,7 @@ export const env = createEnv({
     NEXT_PUBLIC_SOCKET_URL: process.env.NEXT_PUBLIC_SOCKET_URL,
     GITHUB_CLIENT_ID: process.env.GITHUB_CLIENT_ID,
     GITHUB_CLIENT_SECRET: process.env.GITHUB_CLIENT_SECRET,
+    REPLICATE_API_TOKEN: process.env.REPLICATE_API_TOKEN,
   },
   /**
    * Run `build` or `dev` with `SKIP_ENV_VALIDATION` to skip env validation.
diff --git a/src/modelProviders/modelProviders.ts b/src/modelProviders/modelProviders.ts
index 326acb8..0013943 100644
--- a/src/modelProviders/modelProviders.ts
+++ b/src/modelProviders/modelProviders.ts
@@ -1,7 +1,9 @@
 import openaiChatCompletion from "./openai-ChatCompletion";
+import replicateLlama2 from "./replicate-llama2";

 const modelProviders = {
   "openai/ChatCompletion": openaiChatCompletion,
+  "replicate/llama2": replicateLlama2,
 } as const;

 export default modelProviders;
diff --git a/src/modelProviders/modelProvidersFrontend.ts b/src/modelProviders/modelProvidersFrontend.ts
index 42d6d7d..e1ef03c 100644
--- a/src/modelProviders/modelProvidersFrontend.ts
+++ b/src/modelProviders/modelProvidersFrontend.ts
@@ -1,10 +1,14 @@
-import modelProviderFrontend from "./openai-ChatCompletion/frontend";
+import openaiChatCompletionFrontend from "./openai-ChatCompletion/frontend";
+import replicateLlama2Frontend from "./replicate-llama2/frontend";
+
+// TODO: make sure we get a typescript error if you forget to add a provider here

 // Keep attributes here that need to be accessible from the frontend. We can't
 // just include them in the default `modelProviders` object because it has some
 // transient dependencies that can only be imported on the server.
 const modelProvidersFrontend = {
-  "openai/ChatCompletion": modelProviderFrontend,
+  "openai/ChatCompletion": openaiChatCompletionFrontend,
+  "replicate/llama2": replicateLlama2Frontend,
 } as const;

 export default modelProvidersFrontend;
diff --git a/src/modelProviders/replicate-llama2/frontend.ts b/src/modelProviders/replicate-llama2/frontend.ts
new file mode 100644
index 0000000..e7f44eb
--- /dev/null
+++ b/src/modelProviders/replicate-llama2/frontend.ts
@@ -0,0 +1,13 @@
+import { type ReplicateLlama2Provider } from ".";
+import { type ModelProviderFrontend } from "../types";
+
+const modelProviderFrontend: ModelProviderFrontend<ReplicateLlama2Provider> = {
+  normalizeOutput: (output) => {
+    return {
+      type: "text",
+      value: output.join(""),
+    };
+  },
+};
+
+export default modelProviderFrontend;
diff --git a/src/modelProviders/replicate-llama2/getCompletion.ts b/src/modelProviders/replicate-llama2/getCompletion.ts
new file mode 100644
index 0000000..4431e41
--- /dev/null
+++ b/src/modelProviders/replicate-llama2/getCompletion.ts
@@ -0,0 +1,62 @@
+import { env } from "~/env.mjs";
+import { type ReplicateLlama2Input, type ReplicateLlama2Output } from ".";
+import { type CompletionResponse } from "../types";
+import Replicate from "replicate";
+
+const replicate = new Replicate({
+  auth: env.REPLICATE_API_TOKEN || "",
+});
+
+const modelIds: Record<ReplicateLlama2Input["model"], string> = {
+  "7b-chat": "3725a659b5afff1a0ba9bead5fac3899d998feaad00e07032ca2b0e35eb14f8a",
+  "13b-chat": "5c785d117c5bcdd1928d5a9acb1ffa6272d6cf13fcb722e90886a0196633f9d3",
+  "70b-chat": "e951f18578850b652510200860fc4ea62b3b16fac280f83ff32282f87bbd2e48",
+};
+
+export async function getCompletion(
+  input: ReplicateLlama2Input,
+  onStream: ((partialOutput: string[]) => void) | null,
+): Promise<CompletionResponse<ReplicateLlama2Output>> {
+  const start = Date.now();
+
+  const { model, stream, ...rest } = input;
+
+  try {
+    const prediction = await replicate.predictions.create({
+      version: modelIds[model],
+      input: rest,
+    });
+
+    console.log("stream?", onStream);
+
+    const interval = onStream
+      ? // eslint-disable-next-line @typescript-eslint/no-misused-promises
+        setInterval(async () => {
+          const partialPrediction = await replicate.predictions.get(prediction.id);
+
+          if (partialPrediction.output) onStream(partialPrediction.output as ReplicateLlama2Output);
+        }, 500)
+      : null;
+
+    const resp = await replicate.wait(prediction, {});
+    if (interval) clearInterval(interval);
+
+    const timeToComplete = Date.now() - start;
+
+    if (resp.error) throw new Error(resp.error as string);
+
+    return {
+      type: "success",
+      statusCode: 200,
+      value: resp.output as ReplicateLlama2Output,
+      timeToComplete,
+    };
+  } catch (error: unknown) {
+    console.error("ERROR IS", error);
+    return {
+      type: "error",
+      message: (error as Error).message,
+      autoRetry: true,
+    };
+  }
+}
diff --git a/src/modelProviders/replicate-llama2/index.ts b/src/modelProviders/replicate-llama2/index.ts
new file mode 100644
index 0000000..49e1d1e
--- /dev/null
+++ b/src/modelProviders/replicate-llama2/index.ts
@@ -0,0 +1,74 @@
+import { type ModelProvider } from "../types";
+import { getCompletion } from "./getCompletion";
+
+const supportedModels = ["7b-chat", "13b-chat", "70b-chat"] as const;
+
+type SupportedModel = (typeof supportedModels)[number];
+
+export type ReplicateLlama2Input = {
+  model: SupportedModel;
+  prompt: string;
+  stream?: boolean;
+  max_length?: number;
+  temperature?: number;
+  top_p?: number;
+  repetition_penalty?: number;
+  debug?: boolean;
+};
+
+export type ReplicateLlama2Output = string[];
+
+export type ReplicateLlama2Provider = ModelProvider<
+  SupportedModel,
+  ReplicateLlama2Input,
+  ReplicateLlama2Output
+>;
+
+const modelProvider: ReplicateLlama2Provider = {
+  name: "Replicate Llama2",
+  models: {
+    "7b-chat": {},
+    "13b-chat": {},
+    "70b-chat": {},
+  },
+  getModel: (input) => {
+    if (supportedModels.includes(input.model)) return input.model;
+
+    return null;
+  },
+  inputSchema: {
+    type: "object",
+    properties: {
+      model: {
+        type: "string",
+        enum: supportedModels as unknown as string[],
+      },
+      prompt: {
+        type: "string",
+      },
+      stream: {
+        type: "boolean",
+      },
+      max_length: {
+        type: "number",
+      },
+      temperature: {
+        type: "number",
+      },
+      top_p: {
+        type: "number",
+      },
+      repetition_penalty: {
+        type: "number",
+      },
+      debug: {
+        type: "boolean",
+      },
+    },
+    required: ["model", "prompt"],
+  },
+  shouldStream: (input) => input.stream ?? false,
+  getCompletion,
+};
+
+export default modelProvider;
diff --git a/src/modelProviders/types.ts b/src/modelProviders/types.ts
index edcbe85..03cd846 100644
--- a/src/modelProviders/types.ts
+++ b/src/modelProviders/types.ts
@@ -2,8 +2,8 @@ import { type JSONSchema4 } from "json-schema";
 import { type JsonValue } from "type-fest";

 type ModelProviderModel = {
-  name: string;
-  learnMore: string;
+  name?: string;
+  learnMore?: string;
 };

 export type CompletionResponse<T> =
diff --git a/src/server/api/routers/experiments.router.ts b/src/server/api/routers/experiments.router.ts
index f3f2a88..e21f34c 100644
--- a/src/server/api/routers/experiments.router.ts
+++ b/src/server/api/routers/experiments.router.ts
@@ -109,8 +109,7 @@ export const experimentsRouter = createTRPCRouter({
       constructFn: dedent`
         /**
          * Use Javascript to define an OpenAI chat completion
-         * (https://platform.openai.com/docs/api-reference/chat/create) and
-         * assign it to the \`prompt\` variable.
+         * (https://platform.openai.com/docs/api-reference/chat/create).
          *
          * You have access to the current scenario in the \`scenario\`
          * variable.
diff --git a/src/server/scripts/replicate-test.ts b/src/server/scripts/replicate-test.ts
index 2126f7e..320ffb6 100644
--- a/src/server/scripts/replicate-test.ts
+++ b/src/server/scripts/replicate-test.ts
@@ -1,26 +1,26 @@
-// /* eslint-disable */
+/* eslint-disable */

-// import "dotenv/config";
-// import Replicate from "replicate";
+import "dotenv/config";
+import Replicate from "replicate";

-// const replicate = new Replicate({
-//   auth: process.env.REPLICATE_API_TOKEN || "",
-// });
+const replicate = new Replicate({
+  auth: process.env.REPLICATE_API_TOKEN || "",
+});

-// console.log("going to run");
-// const prediction = await replicate.predictions.create({
-//   version: "e951f18578850b652510200860fc4ea62b3b16fac280f83ff32282f87bbd2e48",
-//   input: {
-//     prompt: "...",
-//   },
-// });
+console.log("going to run");
+const prediction = await replicate.predictions.create({
+  version: "3725a659b5afff1a0ba9bead5fac3899d998feaad00e07032ca2b0e35eb14f8a",
+  input: {
+    prompt: "...",
+  },
+});

-// console.log("waiting");
-// setInterval(() => {
-//   replicate.predictions.get(prediction.id).then((prediction) => {
-//     console.log(prediction.output);
-//   });
-// }, 500);
-// // const output = await replicate.wait(prediction, {});
+console.log("waiting");
+setInterval(() => {
+  replicate.predictions.get(prediction.id).then((prediction) => {
+    console.log(prediction);
+  });
+}, 500);
+// const output = await replicate.wait(prediction, {});

-// // console.log(output);
+// console.log(output);
diff --git a/src/server/tasks/queryLLM.task.ts b/src/server/tasks/queryLLM.task.ts
index b2b86cb..29affe7 100644
--- a/src/server/tasks/queryLLM.task.ts
+++ b/src/server/tasks/queryLLM.task.ts
@@ -99,6 +99,7 @@ export const queryLLM = defineTask("queryLLM", async (task) => {
   const provider = modelProviders[prompt.modelProvider];

+  // @ts-expect-error TODO FIX ASAP
   const streamingChannel = provider.shouldStream(prompt.modelInput) ? generateChannel() : null;

   if (streamingChannel) {
@@ -115,6 +116,8 @@
       : null;

   for (let i = 0; true; i++) {
+    // @ts-expect-error TODO FIX ASAP
+
     const response = await provider.getCompletion(prompt.modelInput, onStream);
     if (response.type === "success") {
       const inputHash = hashPrompt(prompt);
diff --git a/src/server/utils/parseConstructFn.ts b/src/server/utils/parseConstructFn.ts
index 8bfd667..1b0d8eb 100644
--- a/src/server/utils/parseConstructFn.ts
+++ b/src/server/utils/parseConstructFn.ts
@@ -70,6 +70,7 @@ export default async function parseConstructFn(
   // We've validated the JSON schema so this should be safe
   const input = prompt.input as Parameters<(typeof provider)["getModel"]>[0];

+  // @ts-expect-error TODO FIX ASAP
   const model = provider.getModel(input);
   if (!model) {
     return {
@@ -79,6 +80,8 @@
   return {
     modelProvider: prompt.modelProvider as keyof typeof modelProviders,
+    // @ts-expect-error TODO FIX ASAP
+    model,
     modelInput: input,
   };
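
For reference, a minimal sketch of how the new replicate/llama2 provider could be exercised directly once this patch is applied. It is not part of the diff: the script itself, the "~/modelProviders/modelProviders" import alias, and the prompt text are assumptions, and it presumes REPLICATE_API_TOKEN is set (e.g. via dotenv) so that ~/env.mjs validation passes.

// Hypothetical scratch script (not included in this patch) that calls the new
// provider through the same getCompletion entry point queryLLM.task.ts uses.
import "dotenv/config";
import modelProviders from "~/modelProviders/modelProviders";

const provider = modelProviders["replicate/llama2"];

// Print partial output each time getCompletion's 500ms polling interval fires.
const onStream = (partialOutput: string[]) => console.log(partialOutput.join(""));

const response = await provider.getCompletion(
  { model: "13b-chat", prompt: "Say hello in one sentence.", stream: true },
  onStream,
);

if (response.type === "success") {
  console.log(response.value.join(""));
  console.log(`completed in ${response.timeToComplete}ms`);
} else {
  console.error(response.message);
}

Because shouldStream returns input.stream ?? false, passing stream: true is also what makes queryLLM.task.ts open a streaming channel and hand getCompletion a non-null onStream callback.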