61 lines
1.8 KiB
TypeScript
61 lines
1.8 KiB
TypeScript
import { env } from "~/env.mjs";
|
|
import { type ReplicateLlama2Input, type ReplicateLlama2Output } from ".";
|
|
import { type CompletionResponse } from "../types";
|
|
import Replicate from "replicate";
|
|
|
|
// API token for Replicate; an empty string is used when REPLICATE_API_TOKEN is
// unset (or otherwise falsy) so the client can still be constructed.
const replicateAuthToken = env.REPLICATE_API_TOKEN || "";

// Module-wide Replicate client shared by every call in this file.
const replicate = new Replicate({ auth: replicateAuthToken });
|
|
|
|
/**
 * Lookup table from the `model` field of a `ReplicateLlama2Input` to the
 * identifier that is sent to Replicate as the prediction `version`.
 */
const modelIds: Record<ReplicateLlama2Input["model"], string> = {
  "7b-chat": "4f0b260b6a13eb53a6b1891f089d57c08f41003ae79458be5011303d81a394dc",
  "13b-chat": "2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52",
  "70b-chat": "2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1",
};
|
|
|
|
/**
 * Creates a Replicate prediction for the selected model and waits for it to
 * finish.
 *
 * When `onStream` is provided, the prediction is polled every 500 ms while it
 * is running and any output available so far is passed to `onStream`. Polling
 * always stops once this function settles, whether it succeeded or failed.
 *
 * Errors raised while creating or waiting on the prediction are logged and
 * returned as an error response rather than thrown.
 *
 * @param input - `model` selects the Replicate version; every other field is
 *   forwarded unchanged as the prediction input.
 * @param onStream - Callback for partial output, or `null` to disable polling.
 * @returns On success, the final output plus the elapsed time in milliseconds
 *   (`statusCode` 200); on failure, an error response with `autoRetry: true`.
 */
export async function getCompletion(
  input: ReplicateLlama2Input,
  onStream: ((partialOutput: string[]) => void) | null,
): Promise<CompletionResponse<ReplicateLlama2Output>> {
  const start = Date.now();

  const { model, ...rest } = input;

  let interval: ReturnType<typeof setInterval> | null = null;
  // Set once the prediction has settled so that a poll still in flight cannot
  // deliver stale partial output after the final result has been produced.
  let finished = false;

  try {
    const prediction = await replicate.predictions.create({
      version: modelIds[model],
      input: rest,
    });

    if (onStream) {
      // eslint-disable-next-line @typescript-eslint/no-misused-promises
      interval = setInterval(async () => {
        // setInterval ignores the returned promise, so failures must be handled
        // here or they surface as unhandled rejections.
        try {
          const partialPrediction = await replicate.predictions.get(prediction.id);

          if (!finished && partialPrediction.output) {
            onStream(partialPrediction.output as ReplicateLlama2Output);
          }
        } catch (pollError: unknown) {
          console.error("Failed to poll Replicate prediction", pollError);
        }
      }, 500);
    }

    const resp = await replicate.wait(prediction, {});

    const timeToComplete = Date.now() - start;

    if (resp.error) {
      // The error payload is not guaranteed to be a string; serialize anything
      // else so the message is not "[object Object]".
      throw new Error(
        typeof resp.error === "string" ? resp.error : JSON.stringify(resp.error),
      );
    }

    return {
      type: "success",
      statusCode: 200,
      value: resp.output as ReplicateLlama2Output,
      timeToComplete,
    };
  } catch (error: unknown) {
    console.error("ERROR IS", error);
    return {
      type: "error",
      message: error instanceof Error ? error.message : String(error),
      autoRetry: true,
    };
  } finally {
    // Runs on both the success and failure paths so the poller never leaks.
    finished = true;
    if (interval !== null) clearInterval(interval);
  }
}
|