import { env } from "~/env.mjs";
import { type ReplicateLlama2Input, type ReplicateLlama2Output } from ".";
import { type CompletionResponse } from "../types";
import Replicate from "replicate";

const replicate = new Replicate({
  auth: env.REPLICATE_API_TOKEN || "",
});

/** Replicate model version ids, keyed by the `model` field of the input. */
const modelIds: Record<ReplicateLlama2Input["model"], string> = {
  "7b-chat": "5ec5fdadd80ace49f5a2b2178cceeb9f2f77c493b85b1131002c26e6b2b13184",
  "13b-chat": "6b4da803a2382c08868c5af10a523892f38e2de1aafb2ee55b020d9efef2fdb8",
  "70b-chat": "2d19859030ff705a87c746f7e96eea03aefb71f166725aee39692f1476566d48",
};

/**
 * Periodically fetches a prediction and forwards any partial output to `onStream`.
 *
 * Best-effort: a failed poll is logged and skipped rather than rejected, so it can
 * never surface as an unhandled promise rejection. At most one poll is in flight
 * at a time, and no callback fires once the returned stop function has been called.
 *
 * @param predictionId - id of the prediction to poll
 * @param onStream - receives each non-empty partial output
 * @param intervalMs - delay between poll attempts
 * @returns a function that stops polling; safe to call more than once
 */
function pollPartialOutput(
  predictionId: string,
  onStream: (partialOutput: string[]) => void,
  intervalMs = 500,
): () => void {
  let stopped = false;
  let inFlight = false;

  const tick = async (): Promise<void> => {
    // Skip this tick if a previous request hasn't returned yet, or we were stopped.
    if (stopped || inFlight) return;
    inFlight = true;
    try {
      const partialPrediction = await replicate.predictions.get(predictionId);
      // Re-check: the final result may have been returned while this request was pending.
      if (!stopped && partialPrediction.output) {
        onStream(partialPrediction.output as ReplicateLlama2Output);
      }
    } catch (pollError: unknown) {
      console.error("Failed to poll partial Replicate output", pollError);
    } finally {
      inFlight = false;
    }
  };

  const timer = setInterval(() => void tick(), intervalMs);
  return () => {
    stopped = true;
    clearInterval(timer);
  };
}

/**
 * Runs a Llama 2 chat completion on Replicate.
 *
 * @param input - `model` selects the version from `modelIds`; all other fields are
 *   forwarded unchanged as the prediction input
 * @param onStream - optional callback that receives partial output while the
 *   prediction runs (polled every 500 ms); pass `null` to disable streaming
 * @returns a success response with the final output and elapsed milliseconds, or an
 *   error response (always `autoRetry: true`). This function does not throw.
 */
export async function getCompletion(
  input: ReplicateLlama2Input,
  onStream: ((partialOutput: string[]) => void) | null,
): Promise<CompletionResponse<ReplicateLlama2Output>> {
  const start = Date.now();
  const { model, ...rest } = input;
  let stopPolling: (() => void) | null = null;

  try {
    const prediction = await replicate.predictions.create({
      version: modelIds[model],
      input: rest,
    });

    if (onStream) stopPolling = pollPartialOutput(prediction.id, onStream);

    const resp = await replicate.wait(prediction, {});

    const timeToComplete = Date.now() - start;

    if (resp.error) {
      throw new Error(
        typeof resp.error === "string" ? resp.error : JSON.stringify(resp.error),
      );
    }
    // wait() also returns for predictions that ended "canceled" (or "failed")
    // without an error payload; their output must not be reported as success.
    if (resp.status !== "succeeded") {
      throw new Error(`Replicate prediction ended with status "${resp.status}"`);
    }

    return {
      type: "success",
      statusCode: 200,
      value: resp.output as ReplicateLlama2Output,
      timeToComplete,
    };
  } catch (error: unknown) {
    console.error("ERROR IS", error);
    return {
      type: "error",
      message: error instanceof Error ? error.message : String(error),
      autoRetry: true,
    };
  } finally {
    // Always stop the poller, even if create/wait threw, so no timer is leaked.
    stopPolling?.();
  }
}