slightly better error handling

This commit is contained in:
Kyle Corbitt
2023-06-27 10:48:09 -07:00
parent f6f93a1161
commit ab32995eb9
5 changed files with 79 additions and 27 deletions

View File

@@ -76,8 +76,10 @@ model TemplateVariable {
model ModelOutput {
id String @id @default(uuid()) @db.Uuid
inputHash String
output Json
inputHash String
output Json
statusCode Int
errorMessage String?
promptVariantId String @db.Uuid
promptVariant PromptVariant @relation(fields: [promptVariantId], references: [id], onDelete: Cascade)

View File

@@ -55,7 +55,20 @@ export default function OutputCell({
</CellShell>
);
if (!output.data) return <CellShell>No output</CellShell>;
if (!output.data)
return (
<CellShell>
<Text color="gray.500">Error retrieving output</Text>
</CellShell>
);
if (output.data.errorMessage) {
return (
<CellShell>
<Text color="red.600">Error: {output.data.errorMessage}</Text>
</CellShell>
);
}
return (
// @ts-expect-error TODO proper typing and error checks

View File

@@ -3,7 +3,7 @@ import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import fillTemplate, { type VariableMap } from "~/server/utils/fillTemplate";
import { type JSONSerializable } from "~/server/types";
import { getChatCompletion } from "~/server/utils/openai";
import { getChatCompletion } from "~/server/utils/getChatCompletion";
import crypto from "crypto";
import type { Prisma } from "@prisma/client";
import { env } from "~/env.mjs";
@@ -51,13 +51,17 @@ export const modelOutputsRouter = createTRPCRouter({
// TODO: we should probably only use this if temperature=0
const existingResponse = await prisma.modelOutput.findFirst({
where: { inputHash },
where: { inputHash, errorMessage: null },
});
let modelResponse: JSONSerializable;
let modelResponse: Awaited<ReturnType<typeof getChatCompletion>>;
if (existingResponse) {
modelResponse = existingResponse.output as JSONSerializable;
modelResponse = {
output: existingResponse.output as Prisma.InputJsonValue,
statusCode: existingResponse.statusCode,
errorMessage: existingResponse.errorMessage,
};
} else {
modelResponse = await getChatCompletion(filledTemplate, env.OPENAI_API_KEY);
}
@@ -66,8 +70,8 @@ export const modelOutputsRouter = createTRPCRouter({
data: {
promptVariantId: input.variantId,
testScenarioId: input.scenarioId,
output: modelResponse as Prisma.InputJsonObject,
inputHash,
...modelResponse,
},
});

View File

@@ -0,0 +1,52 @@
import { isObject } from "lodash";
import { type JSONSerializable } from "../types";
import { Prisma } from "@prisma/client";
type CompletionResponse = {
output: Prisma.InputJsonValue | typeof Prisma.JsonNull;
statusCode: number;
errorMessage: string | null;
};
export async function getChatCompletion(
payload: JSONSerializable,
apiKey: string
): Promise<CompletionResponse> {
const response = await fetch("https://api.openai.com/v1/chat/completions", {
method: "POST",
headers: {
"Content-Type": "application/json",
Authorization: `Bearer ${apiKey}`,
},
body: JSON.stringify(payload),
});
const resp: CompletionResponse = {
output: Prisma.JsonNull,
errorMessage: null,
statusCode: response.status,
};
try {
resp.output = await response.json();
if (!response.ok) {
// If it's an object, try to get the error message
if (
isObject(resp.output) &&
"error" in resp.output &&
isObject(resp.output.error) &&
"message" in resp.output.error
) {
resp.errorMessage = resp.output.error.message?.toString() ?? "Unknown error";
}
}
} catch (e) {
console.error(e);
if (response.ok) {
resp.errorMessage = "Failed to parse response";
}
}
return resp;
}

View File

@@ -1,19 +0,0 @@
import { type JSONSerializable } from "../types";
export async function getChatCompletion(payload: JSONSerializable, apiKey: string) {
const response = await fetch("https://api.openai.com/v1/chat/completions", {
method: "POST",
headers: {
"Content-Type": "application/json",
Authorization: `Bearer ${apiKey}`,
},
body: JSON.stringify(payload),
});
if (!response.ok) {
throw new Error(`OpenAI API request failed with status ${response.status}`);
}
const data = (await response.json()) as JSONSerializable;
return data;
}