From ab32995eb93ea0dd6aa22ab5c9ee043945619e8b Mon Sep 17 00:00:00 2001 From: Kyle Corbitt Date: Tue, 27 Jun 2023 10:48:09 -0700 Subject: [PATCH] slightly better error handling --- prisma/schema.prisma | 6 ++- src/components/OutputsTable/OutputCell.tsx | 15 +++++- src/server/api/routers/modelOutputs.router.ts | 14 +++-- src/server/utils/getChatCompletion.ts | 52 +++++++++++++++++++ src/server/utils/openai.ts | 19 ------- 5 files changed, 79 insertions(+), 27 deletions(-) create mode 100644 src/server/utils/getChatCompletion.ts delete mode 100644 src/server/utils/openai.ts diff --git a/prisma/schema.prisma b/prisma/schema.prisma index a466394..8be4861 100644 --- a/prisma/schema.prisma +++ b/prisma/schema.prisma @@ -76,8 +76,10 @@ model TemplateVariable { model ModelOutput { id String @id @default(uuid()) @db.Uuid - inputHash String - output Json + inputHash String + output Json + statusCode Int + errorMessage String? promptVariantId String @db.Uuid promptVariant PromptVariant @relation(fields: [promptVariantId], references: [id], onDelete: Cascade) diff --git a/src/components/OutputsTable/OutputCell.tsx b/src/components/OutputsTable/OutputCell.tsx index f4d4865..bee52a7 100644 --- a/src/components/OutputsTable/OutputCell.tsx +++ b/src/components/OutputsTable/OutputCell.tsx @@ -55,7 +55,20 @@ export default function OutputCell({ ); - if (!output.data) return No output; + if (!output.data) + return ( + + Error retrieving output + + ); + + if (output.data.errorMessage) { + return ( + + Error: {output.data.errorMessage} + + ); + } return ( // @ts-expect-error TODO proper typing and error checks diff --git a/src/server/api/routers/modelOutputs.router.ts b/src/server/api/routers/modelOutputs.router.ts index 588d785..8ce3c8a 100644 --- a/src/server/api/routers/modelOutputs.router.ts +++ b/src/server/api/routers/modelOutputs.router.ts @@ -3,7 +3,7 @@ import { createTRPCRouter, publicProcedure } from "~/server/api/trpc"; import { prisma } from "~/server/db"; import fillTemplate, { type VariableMap } from "~/server/utils/fillTemplate"; import { type JSONSerializable } from "~/server/types"; -import { getChatCompletion } from "~/server/utils/openai"; +import { getChatCompletion } from "~/server/utils/getChatCompletion"; import crypto from "crypto"; import type { Prisma } from "@prisma/client"; import { env } from "~/env.mjs"; @@ -51,13 +51,17 @@ export const modelOutputsRouter = createTRPCRouter({ // TODO: we should probably only use this if temperature=0 const existingResponse = await prisma.modelOutput.findFirst({ - where: { inputHash }, + where: { inputHash, errorMessage: null }, }); - let modelResponse: JSONSerializable; + let modelResponse: Awaited>; if (existingResponse) { - modelResponse = existingResponse.output as JSONSerializable; + modelResponse = { + output: existingResponse.output as Prisma.InputJsonValue, + statusCode: existingResponse.statusCode, + errorMessage: existingResponse.errorMessage, + }; } else { modelResponse = await getChatCompletion(filledTemplate, env.OPENAI_API_KEY); } @@ -66,8 +70,8 @@ export const modelOutputsRouter = createTRPCRouter({ data: { promptVariantId: input.variantId, testScenarioId: input.scenarioId, - output: modelResponse as Prisma.InputJsonObject, inputHash, + ...modelResponse, }, }); diff --git a/src/server/utils/getChatCompletion.ts b/src/server/utils/getChatCompletion.ts new file mode 100644 index 0000000..bd97943 --- /dev/null +++ b/src/server/utils/getChatCompletion.ts @@ -0,0 +1,52 @@ +import { isObject } from "lodash"; +import { type JSONSerializable } from "../types"; +import { Prisma } from "@prisma/client"; + +type CompletionResponse = { + output: Prisma.InputJsonValue | typeof Prisma.JsonNull; + statusCode: number; + errorMessage: string | null; +}; + +export async function getChatCompletion( + payload: JSONSerializable, + apiKey: string +): Promise { + const response = await fetch("https://api.openai.com/v1/chat/completions", { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${apiKey}`, + }, + body: JSON.stringify(payload), + }); + + const resp: CompletionResponse = { + output: Prisma.JsonNull, + errorMessage: null, + statusCode: response.status, + }; + + try { + resp.output = await response.json(); + + if (!response.ok) { + // If it's an object, try to get the error message + if ( + isObject(resp.output) && + "error" in resp.output && + isObject(resp.output.error) && + "message" in resp.output.error + ) { + resp.errorMessage = resp.output.error.message?.toString() ?? "Unknown error"; + } + } + } catch (e) { + console.error(e); + if (response.ok) { + resp.errorMessage = "Failed to parse response"; + } + } + + return resp; +} diff --git a/src/server/utils/openai.ts b/src/server/utils/openai.ts deleted file mode 100644 index 7696905..0000000 --- a/src/server/utils/openai.ts +++ /dev/null @@ -1,19 +0,0 @@ -import { type JSONSerializable } from "../types"; - -export async function getChatCompletion(payload: JSONSerializable, apiKey: string) { - const response = await fetch("https://api.openai.com/v1/chat/completions", { - method: "POST", - headers: { - "Content-Type": "application/json", - Authorization: `Bearer ${apiKey}`, - }, - body: JSON.stringify(payload), - }); - - if (!response.ok) { - throw new Error(`OpenAI API request failed with status ${response.status}`); - } - - const data = (await response.json()) as JSONSerializable; - return data; -}