import { env } from "~/env.mjs";
import { type ReplicateLlama2Input, type ReplicateLlama2Output } from ".";
import { type CompletionResponse } from "../types";
import Replicate from "replicate";

/** Shared Replicate API client; the auth token falls back to "" when the env var is unset. */
const replicate = new Replicate({
  auth: env.REPLICATE_API_TOKEN || "",
});

/**
 * Maps each supported model name to the Replicate version hash used to run it.
 *
 * NOTE(review): the type arguments of this Record were missing in the reviewed
 * source; string keys are assumed here. Narrow the key type to the model-name
 * union if one is declared alongside ReplicateLlama2Input.
 */
const modelIds: Record<string, string> = {
  "7b-chat": "4f0b260b6a13eb53a6b1891f089d57c08f41003ae79458be5011303d81a394dc",
  "13b-chat": "2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52",
  "70b-chat": "2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1",
};

/** How often (in milliseconds) partial output is polled while a prediction is running. */
const STREAM_POLL_INTERVAL_MS = 500;

/**
 * Creates a Replicate prediction for the requested model and waits for it to finish.
 *
 * While waiting, if `onStream` is provided, the prediction is polled every
 * `STREAM_POLL_INTERVAL_MS` and any available partial output is forwarded to it.
 *
 * @param input - Request parameters. `model` selects the version from `modelIds`;
 *   every other field is passed through unchanged as the prediction input.
 * @param onStream - Callback receiving partial output, or `null` to disable polling.
 * @returns A success response carrying the final output and the elapsed time in
 *   milliseconds, or an error response (with `autoRetry: true`) when creating or
 *   waiting for the prediction fails or the prediction itself reports an error.
 */
export async function getCompletion(
  input: ReplicateLlama2Input,
  onStream: ((partialOutput: string[]) => void) | null,
): Promise<CompletionResponse<ReplicateLlama2Output>> {
  const start = Date.now();
  const { model, ...rest } = input;

  let interval: ReturnType<typeof setInterval> | null = null;
  // Set once the prediction has finished (or failed) so that a poll still in
  // flight does not deliver stale partial output after the final result.
  let settled = false;

  try {
    const version = modelIds[model];
    if (!version) {
      // Fail fast rather than sending `version: undefined` to the API.
      throw new Error(`Unknown Replicate Llama 2 model: ${String(model)}`);
    }

    const prediction = await replicate.predictions.create({
      version,
      input: rest,
    });

    if (onStream) {
      // Polling is best-effort: a failed poll must neither reject unhandled nor
      // abort the completion, so every error is caught and logged here.
      const pollOnce = async (): Promise<void> => {
        try {
          const partialPrediction = await replicate.predictions.get(prediction.id);
          if (!settled && partialPrediction.output) {
            onStream(partialPrediction.output as ReplicateLlama2Output);
          }
        } catch (pollError: unknown) {
          console.error("Failed to poll partial prediction output", pollError);
        }
      };
      interval = setInterval(() => {
        void pollOnce();
      }, STREAM_POLL_INTERVAL_MS);
    }

    const resp = await replicate.wait(prediction, {});
    const timeToComplete = Date.now() - start;

    if (resp.error) {
      const reason: unknown = resp.error;
      throw new Error(typeof reason === "string" ? reason : JSON.stringify(reason));
    }

    return {
      type: "success",
      statusCode: 200,
      value: resp.output as ReplicateLlama2Output,
      timeToComplete,
    };
  } catch (error: unknown) {
    console.error("ERROR IS", error);
    return {
      type: "error",
      message: error instanceof Error ? error.message : String(error),
      autoRetry: true,
    };
  } finally {
    // Always stop polling, including when create/wait throws.
    settled = true;
    if (interval) clearInterval(interval);
  }
}