replicate/llama2 provider
Still need to fix the types, but it runs.
@@ -73,6 +73,7 @@
     "react-syntax-highlighter": "^15.5.0",
     "react-textarea-autosize": "^8.5.0",
     "recast": "^0.23.3",
+    "replicate": "^0.12.3",
     "socket.io": "^4.7.1",
     "socket.io-client": "^4.7.1",
     "superjson": "1.12.2",
pnpm-lock.yaml (generated, 8 changed lines)

@@ -161,6 +161,9 @@ dependencies:
   recast:
     specifier: ^0.23.3
     version: 0.23.3
+  replicate:
+    specifier: ^0.12.3
+    version: 0.12.3
   socket.io:
     specifier: ^4.7.1
     version: 4.7.1
@@ -6988,6 +6991,11 @@ packages:
       functions-have-names: 1.2.3
     dev: true

+  /replicate@0.12.3:
+    resolution: {integrity: sha512-HVWKPoVhWVTONlWk+lUXmq9Vy2J8MxBJMtDBQq3dA5uq71ZzKTh0xvJfvzW4+VLBjhBeL7tkdua6hZJmKfzAPQ==}
+    engines: {git: '>=2.11.0', node: '>=16.6.0', npm: '>=7.19.0', yarn: '>=1.7.0'}
+    dev: false
+
   /require-directory@2.1.1:
     resolution: {integrity: sha512-fGxEI7+wsG9xrvdjsrlmL22OMTTiHRwAMroiEeMgq8gzoLC/PQr7RsRDSTLUg/bZAZtF+TVIkHc6/4RIKrui+Q==}
     engines: {node: '>=0.10.0'}
@@ -88,9 +88,11 @@ export default function OutputCell({
   }

   const normalizedOutput = modelOutput
-    ? provider.normalizeOutput(modelOutput.output as unknown as OutputSchema)
+    ? // @ts-expect-error TODO FIX ASAP
+      provider.normalizeOutput(modelOutput.output as unknown as OutputSchema)
     : streamedMessage
-    ? provider.normalizeOutput(streamedMessage)
+    ? // @ts-expect-error TODO FIX ASAP
+      provider.normalizeOutput(streamedMessage)
     : null;

   if (modelOutput && normalizedOutput?.type === "json") {
@@ -17,6 +17,7 @@ export const env = createEnv({
       .transform((val) => val.toLowerCase() === "true"),
     GITHUB_CLIENT_ID: z.string().min(1),
     GITHUB_CLIENT_SECRET: z.string().min(1),
+    REPLICATE_API_TOKEN: z.string().min(1),
   },

   /**
@@ -42,6 +43,7 @@ export const env = createEnv({
     NEXT_PUBLIC_SOCKET_URL: process.env.NEXT_PUBLIC_SOCKET_URL,
     GITHUB_CLIENT_ID: process.env.GITHUB_CLIENT_ID,
     GITHUB_CLIENT_SECRET: process.env.GITHUB_CLIENT_SECRET,
+    REPLICATE_API_TOKEN: process.env.REPLICATE_API_TOKEN,
   },
   /**
    * Run `build` or `dev` with `SKIP_ENV_VALIDATION` to skip env validation.
@@ -1,7 +1,9 @@
 import openaiChatCompletion from "./openai-ChatCompletion";
+import replicateLlama2 from "./replicate-llama2";

 const modelProviders = {
   "openai/ChatCompletion": openaiChatCompletion,
+  "replicate/llama2": replicateLlama2,
 } as const;

 export default modelProviders;
@@ -1,10 +1,14 @@
-import modelProviderFrontend from "./openai-ChatCompletion/frontend";
+import openaiChatCompletionFrontend from "./openai-ChatCompletion/frontend";
+import replicateLlama2Frontend from "./replicate-llama2/frontend";

+// TODO: make sure we get a typescript error if you forget to add a provider here
+
 // Keep attributes here that need to be accessible from the frontend. We can't
 // just include them in the default `modelProviders` object because it has some
 // transient dependencies that can only be imported on the server.
 const modelProvidersFrontend = {
-  "openai/ChatCompletion": modelProviderFrontend,
+  "openai/ChatCompletion": openaiChatCompletionFrontend,
+  "replicate/llama2": replicateLlama2Frontend,
 } as const;

 export default modelProvidersFrontend;
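Note: the reason for the separate frontend map is in the comment above: the default `modelProviders` object transitively imports server-only SDKs such as `replicate`, so client components must go through this lighter map instead. A minimal sketch, assuming the map lives at `~/modelProviders/frontend` (the commit does not show this file's path):

    // Hypothetical client-side lookup; the import path is assumed.
    import modelProvidersFrontend from "~/modelProviders/frontend";

    // The map's keys form the union "openai/ChatCompletion" | "replicate/llama2",
    // so a mistyped provider key fails at compile time rather than at runtime.
    const frontend = modelProvidersFrontend["replicate/llama2"];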
src/modelProviders/replicate-llama2/frontend.ts (new file, 13 lines)

@@ -0,0 +1,13 @@
+import { type ReplicateLlama2Provider } from ".";
+import { type ModelProviderFrontend } from "../types";
+
+const modelProviderFrontend: ModelProviderFrontend<ReplicateLlama2Provider> = {
+  normalizeOutput: (output) => {
+    return {
+      type: "text",
+      value: output.join(""),
+    };
+  },
+};
+
+export default modelProviderFrontend;
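Note: Replicate's Llama 2 endpoints return the completion as an array of string fragments rather than a single string, which is why `normalizeOutput` concatenates with `join("")`. A quick illustration (the fragments are made up):

    import modelProviderFrontend from "~/modelProviders/replicate-llama2/frontend";

    // Fragments in the shape Replicate streams them.
    const output = ["The", " answer", " is", " 42."];

    modelProviderFrontend.normalizeOutput(output);
    // => { type: "text", value: "The answer is 42." }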
src/modelProviders/replicate-llama2/getCompletion.ts (new file, 62 lines)

@@ -0,0 +1,62 @@
+import { env } from "~/env.mjs";
+import { type ReplicateLlama2Input, type ReplicateLlama2Output } from ".";
+import { type CompletionResponse } from "../types";
+import Replicate from "replicate";
+
+const replicate = new Replicate({
+  auth: env.REPLICATE_API_TOKEN || "",
+});
+
+const modelIds: Record<ReplicateLlama2Input["model"], string> = {
+  "7b-chat": "3725a659b5afff1a0ba9bead5fac3899d998feaad00e07032ca2b0e35eb14f8a",
+  "13b-chat": "5c785d117c5bcdd1928d5a9acb1ffa6272d6cf13fcb722e90886a0196633f9d3",
+  "70b-chat": "e951f18578850b652510200860fc4ea62b3b16fac280f83ff32282f87bbd2e48",
+};
+
+export async function getCompletion(
+  input: ReplicateLlama2Input,
+  onStream: ((partialOutput: string[]) => void) | null,
+): Promise<CompletionResponse<ReplicateLlama2Output>> {
+  const start = Date.now();
+
+  const { model, stream, ...rest } = input;
+
+  try {
+    const prediction = await replicate.predictions.create({
+      version: modelIds[model],
+      input: rest,
+    });
+
+    console.log("stream?", onStream);
+
+    const interval = onStream
+      ? // eslint-disable-next-line @typescript-eslint/no-misused-promises
+        setInterval(async () => {
+          const partialPrediction = await replicate.predictions.get(prediction.id);
+
+          if (partialPrediction.output) onStream(partialPrediction.output as ReplicateLlama2Output);
+        }, 500)
+      : null;
+
+    const resp = await replicate.wait(prediction, {});
+    if (interval) clearInterval(interval);
+
+    const timeToComplete = Date.now() - start;
+
+    if (resp.error) throw new Error(resp.error as string);
+
+    return {
+      type: "success",
+      statusCode: 200,
+      value: resp.output as ReplicateLlama2Output,
+      timeToComplete,
+    };
+  } catch (error: unknown) {
+    console.error("ERROR IS", error);
+    return {
+      type: "error",
+      message: (error as Error).message,
+      autoRetry: true,
+    };
+  }
+}
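Note: `getCompletion` pins each model tag to an explicit Replicate version hash via `modelIds`, polls `replicate.predictions.get` every 500ms for streaming, and lets `replicate.wait` block until the prediction settles. A hedged usage sketch (assumes `REPLICATE_API_TOKEN` is set; the prompt is illustrative):

    import { getCompletion } from "~/modelProviders/replicate-llama2/getCompletion";

    const response = await getCompletion(
      { model: "13b-chat", prompt: "Explain HTTP caching in one paragraph.", stream: true },
      // Called roughly every 500ms with all fragments received so far.
      (partialOutput) => console.log(`${partialOutput.length} fragments so far`),
    );

    if (response.type === "success") {
      console.log(response.value.join(""));        // full completion text
      console.log(`${response.timeToComplete}ms`); // wall-clock latency
    } else {
      console.error(response.message);             // autoRetry hints the task queue to retry
    }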
src/modelProviders/replicate-llama2/index.ts (new file, 74 lines)

@@ -0,0 +1,74 @@
+import { type ModelProvider } from "../types";
+import { getCompletion } from "./getCompletion";
+
+const supportedModels = ["7b-chat", "13b-chat", "70b-chat"] as const;
+
+type SupportedModel = (typeof supportedModels)[number];
+
+export type ReplicateLlama2Input = {
+  model: SupportedModel;
+  prompt: string;
+  stream?: boolean;
+  max_length?: number;
+  temperature?: number;
+  top_p?: number;
+  repetition_penalty?: number;
+  debug?: boolean;
+};
+
+export type ReplicateLlama2Output = string[];
+
+export type ReplicateLlama2Provider = ModelProvider<
+  SupportedModel,
+  ReplicateLlama2Input,
+  ReplicateLlama2Output
+>;
+
+const modelProvider: ReplicateLlama2Provider = {
+  name: "Replicate Llama2",
+  models: {
+    "7b-chat": {},
+    "13b-chat": {},
+    "70b-chat": {},
+  },
+  getModel: (input) => {
+    if (supportedModels.includes(input.model)) return input.model;

+    return null;
+  },
+  inputSchema: {
+    type: "object",
+    properties: {
+      model: {
+        type: "string",
+        enum: supportedModels as unknown as string[],
+      },
+      prompt: {
+        type: "string",
+      },
+      stream: {
+        type: "boolean",
+      },
+      max_length: {
+        type: "number",
+      },
+      temperature: {
+        type: "number",
+      },
+      top_p: {
+        type: "number",
+      },
+      repetition_penalty: {
+        type: "number",
+      },
+      debug: {
+        type: "boolean",
+      },
+    },
+    required: ["model", "prompt"],
+  },
+  shouldStream: (input) => input.stream ?? false,
+  getCompletion,
+};
+
+export default modelProvider;
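Note: a couple of behaviors of the provider object, shown as a sketch (inputs illustrative):

    import replicateLlama2 from "~/modelProviders/replicate-llama2";

    // getModel echoes the model tag when it is one of the three supported ones.
    replicateLlama2.getModel({ model: "70b-chat", prompt: "hi" }); // => "70b-chat"

    // shouldStream defaults to false when `stream` is omitted from the input.
    replicateLlama2.shouldStream({ model: "7b-chat", prompt: "hi" });               // => false
    replicateLlama2.shouldStream({ model: "7b-chat", prompt: "hi", stream: true }); // => true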
@@ -2,8 +2,8 @@ import { type JSONSchema4 } from "json-schema";
 import { type JsonValue } from "type-fest";

 type ModelProviderModel = {
-  name: string;
-  learnMore: string;
+  name?: string;
+  learnMore?: string;
 };

 export type CompletionResponse<T> =
@@ -109,8 +109,7 @@ export const experimentsRouter = createTRPCRouter({
       constructFn: dedent`
         /**
          * Use Javascript to define an OpenAI chat completion
-         * (https://platform.openai.com/docs/api-reference/chat/create) and
-         * assign it to the \`prompt\` variable.
+         * (https://platform.openai.com/docs/api-reference/chat/create).
          *
          * You have access to the current scenario in the \`scenario\`
          * variable.
@@ -1,26 +1,26 @@
-// /* eslint-disable */
+/* eslint-disable */

-// import "dotenv/config";
-// import Replicate from "replicate";
+import "dotenv/config";
+import Replicate from "replicate";

-// const replicate = new Replicate({
-//   auth: process.env.REPLICATE_API_TOKEN || "",
-// });
+const replicate = new Replicate({
+  auth: process.env.REPLICATE_API_TOKEN || "",
+});

-// console.log("going to run");
-// const prediction = await replicate.predictions.create({
-//   version: "e951f18578850b652510200860fc4ea62b3b16fac280f83ff32282f87bbd2e48",
-//   input: {
-//     prompt: "...",
-//   },
-// });
+console.log("going to run");
+const prediction = await replicate.predictions.create({
+  version: "3725a659b5afff1a0ba9bead5fac3899d998feaad00e07032ca2b0e35eb14f8a",
+  input: {
+    prompt: "...",
+  },
+});

-// console.log("waiting");
-// setInterval(() => {
-//   replicate.predictions.get(prediction.id).then((prediction) => {
-//     console.log(prediction.output);
-//   });
-// }, 500);
-// // const output = await replicate.wait(prediction, {});
+console.log("waiting");
+setInterval(() => {
+  replicate.predictions.get(prediction.id).then((prediction) => {
+    console.log(prediction);
+  });
+}, 500);
+// const output = await replicate.wait(prediction, {});

-// // console.log(output);
+// console.log(output);
@@ -99,6 +99,7 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {

   const provider = modelProviders[prompt.modelProvider];

+  // @ts-expect-error TODO FIX ASAP
   const streamingChannel = provider.shouldStream(prompt.modelInput) ? generateChannel() : null;

   if (streamingChannel) {
@@ -115,6 +116,8 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
     : null;

   for (let i = 0; true; i++) {
+    // @ts-expect-error TODO FIX ASAP
+
     const response = await provider.getCompletion(prompt.modelInput, onStream);
     if (response.type === "success") {
       const inputHash = hashPrompt(prompt);
@@ -70,6 +70,7 @@ export default async function parseConstructFn(
   // We've validated the JSON schema so this should be safe
   const input = prompt.input as Parameters<(typeof provider)["getModel"]>[0];

+  // @ts-expect-error TODO FIX ASAP
   const model = provider.getModel(input);
   if (!model) {
     return {
@@ -79,6 +80,8 @@ export default async function parseConstructFn(

   return {
     modelProvider: prompt.modelProvider as keyof typeof modelProviders,
+    // @ts-expect-error TODO FIX ASAP
+
     model,
     modelInput: input,
   };
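Note: the commit message's caveat ("still need to fix the types") surfaces as the `// @ts-expect-error TODO FIX ASAP` suppressions above: indexing `modelProviders` with a runtime key produces a union of provider types whose `getCompletion` and `shouldStream` signatures do not unify. One possible direction, sketched here as an assumption rather than part of this commit, is to keep the provider key and its input type correlated through a generic helper; it only pays off once the key has been narrowed to a literal (for example by switching on `prompt.modelProvider`):

    import modelProviders from "~/modelProviders";

    // Hypothetical helper, not in this commit: preserving the literal key K means
    // provider.getCompletion is typed against that one provider's input type
    // instead of an uncallable union of signatures.
    function getProvider<K extends keyof typeof modelProviders>(key: K) {
      return modelProviders[key];
    }

    const provider = getProvider("replicate/llama2");
    // provider.getCompletion now expects ReplicateLlama2Input; no suppression needed.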