Files
OpenPipe-llm/app/src/modelProviders/replicate-llama2/getCompletion.ts
2023-07-29 15:53:01 -07:00

61 lines
1.8 KiB
TypeScript

import { env } from "~/env.mjs";
import { type ReplicateLlama2Input, type ReplicateLlama2Output } from ".";
import { type CompletionResponse } from "../types";
import Replicate from "replicate";
/**
 * Module-wide Replicate API client.
 *
 * Authenticates with `REPLICATE_API_TOKEN` from the environment; an empty
 * string is passed when the token is unset or otherwise falsy.
 */
const replicate = new Replicate({ auth: env.REPLICATE_API_TOKEN || "" });
/**
 * Lookup table from each selectable model name to the identifier that is sent
 * as `version` when a prediction is created.
 */
const modelIds: Record<ReplicateLlama2Input["model"], string> = {
  "7b-chat": "5ec5fdadd80ace49f5a2b2178cceeb9f2f77c493b85b1131002c26e6b2b13184",
  "13b-chat": "6b4da803a2382c08868c5af10a523892f38e2de1aafb2ee55b020d9efef2fdb8",
  "70b-chat": "2d19859030ff705a87c746f7e96eea03aefb71f166725aee39692f1476566d48",
};
/**
 * Runs a completion on Replicate and waits for it to finish.
 *
 * Creates a prediction for the version mapped from `input.model`, optionally
 * polls that prediction every 500ms to forward partial output to `onStream`,
 * and resolves once `replicate.wait` settles.
 *
 * @param input - The `model` field selects the version; every other field is
 *   forwarded unchanged as the prediction input.
 * @param onStream - Receives the prediction's partial output while it is
 *   running, or `null` to disable polling. It is never invoked after the
 *   prediction has settled.
 * @returns A success response carrying the final output and the elapsed
 *   wall-clock milliseconds, or an error response flagged `autoRetry`.
 *   Failures are reported through the return value; this function does not
 *   reject.
 */
export async function getCompletion(
  input: ReplicateLlama2Input,
  onStream: ((partialOutput: string[]) => void) | null,
): Promise<CompletionResponse<ReplicateLlama2Output>> {
  const start = Date.now();
  const { model, ...rest } = input;

  let interval: ReturnType<typeof setInterval> | null = null;
  // Set once the prediction settles so a poll that is still in flight cannot
  // deliver stale partial output after the final result has been produced.
  let finished = false;
  // Prevents overlapping polls when a request takes longer than the interval,
  // which could otherwise deliver partial outputs out of order.
  let pollInFlight = false;

  try {
    const prediction = await replicate.predictions.create({
      version: modelIds[model],
      input: rest,
    });

    if (onStream) {
      const poll = async (): Promise<void> => {
        if (finished || pollInFlight) return;
        pollInFlight = true;
        try {
          const partialPrediction = await replicate.predictions.get(prediction.id);
          if (!finished && partialPrediction.output) {
            onStream(partialPrediction.output as ReplicateLlama2Output);
          }
        } catch (pollError: unknown) {
          // Streaming is best-effort: a failed poll must not surface as an
          // unhandled rejection or abort the completion itself.
          console.error("Failed to poll partial prediction", pollError);
        } finally {
          pollInFlight = false;
        }
      };
      interval = setInterval(() => void poll(), 500);
    }

    const resp = await replicate.wait(prediction, {});
    finished = true;
    const timeToComplete = Date.now() - start;

    if (resp.error) {
      // The error payload is not guaranteed to be a string; stringify anything
      // else so the message stays readable.
      throw new Error(typeof resp.error === "string" ? resp.error : JSON.stringify(resp.error));
    }

    return {
      type: "success",
      statusCode: 200,
      value: resp.output as ReplicateLlama2Output,
      timeToComplete,
    };
  } catch (error: unknown) {
    console.error("ERROR IS", error);
    return {
      type: "error",
      message: error instanceof Error ? error.message : String(error),
      autoRetry: true,
    };
  } finally {
    // Always stop polling, including when create/wait rejects; otherwise the
    // timer would keep hitting the API for the lifetime of the process.
    finished = true;
    if (interval) clearInterval(interval);
  }
}