diff --git a/prisma/schema.prisma b/prisma/schema.prisma
index a466394..8be4861 100644
--- a/prisma/schema.prisma
+++ b/prisma/schema.prisma
@@ -76,8 +76,10 @@ model TemplateVariable {
model ModelOutput {
id String @id @default(uuid()) @db.Uuid
- inputHash String
- output Json
+ inputHash String
+ output Json
+ statusCode Int
+ errorMessage String?
promptVariantId String @db.Uuid
promptVariant PromptVariant @relation(fields: [promptVariantId], references: [id], onDelete: Cascade)
diff --git a/src/components/OutputsTable/OutputCell.tsx b/src/components/OutputsTable/OutputCell.tsx
index f4d4865..bee52a7 100644
--- a/src/components/OutputsTable/OutputCell.tsx
+++ b/src/components/OutputsTable/OutputCell.tsx
@@ -55,7 +55,20 @@ export default function OutputCell({
);
-  if (!output.data) return <Text>No output</Text>;
+  if (!output.data)
+    return (
+      <Text color="red">
+        Error retrieving output
+      </Text>
+    );
+
+  if (output.data.errorMessage) {
+    return (
+      <Text color="red">
+        Error: {output.data.errorMessage}
+      </Text>
+    );
+  }
return (
// @ts-expect-error TODO proper typing and error checks
diff --git a/src/server/api/routers/modelOutputs.router.ts b/src/server/api/routers/modelOutputs.router.ts
index 588d785..8ce3c8a 100644
--- a/src/server/api/routers/modelOutputs.router.ts
+++ b/src/server/api/routers/modelOutputs.router.ts
@@ -3,7 +3,7 @@ import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import fillTemplate, { type VariableMap } from "~/server/utils/fillTemplate";
import { type JSONSerializable } from "~/server/types";
-import { getChatCompletion } from "~/server/utils/openai";
+import { getChatCompletion } from "~/server/utils/getChatCompletion";
import crypto from "crypto";
import type { Prisma } from "@prisma/client";
import { env } from "~/env.mjs";
@@ -51,13 +51,17 @@ export const modelOutputsRouter = createTRPCRouter({
// TODO: we should probably only use this if temperature=0
const existingResponse = await prisma.modelOutput.findFirst({
- where: { inputHash },
+ where: { inputHash, errorMessage: null },
});
- let modelResponse: JSONSerializable;
+  let modelResponse: Awaited<ReturnType<typeof getChatCompletion>>;
if (existingResponse) {
- modelResponse = existingResponse.output as JSONSerializable;
+ modelResponse = {
+ output: existingResponse.output as Prisma.InputJsonValue,
+ statusCode: existingResponse.statusCode,
+ errorMessage: existingResponse.errorMessage,
+ };
} else {
modelResponse = await getChatCompletion(filledTemplate, env.OPENAI_API_KEY);
}
@@ -66,8 +70,8 @@ export const modelOutputsRouter = createTRPCRouter({
data: {
promptVariantId: input.variantId,
testScenarioId: input.scenarioId,
- output: modelResponse as Prisma.InputJsonObject,
inputHash,
+ ...modelResponse,
},
});
diff --git a/src/server/utils/getChatCompletion.ts b/src/server/utils/getChatCompletion.ts
new file mode 100644
index 0000000..bd97943
--- /dev/null
+++ b/src/server/utils/getChatCompletion.ts
@@ -0,0 +1,52 @@
+import { isObject } from "lodash";
+import { type JSONSerializable } from "../types";
+import { Prisma } from "@prisma/client";
+
+type CompletionResponse = {
+ output: Prisma.InputJsonValue | typeof Prisma.JsonNull;
+ statusCode: number;
+ errorMessage: string | null;
+};
+
+export async function getChatCompletion(
+ payload: JSONSerializable,
+ apiKey: string
+): Promise<CompletionResponse> {
+ const response = await fetch("https://api.openai.com/v1/chat/completions", {
+ method: "POST",
+ headers: {
+ "Content-Type": "application/json",
+ Authorization: `Bearer ${apiKey}`,
+ },
+ body: JSON.stringify(payload),
+ });
+
+ const resp: CompletionResponse = {
+ output: Prisma.JsonNull,
+ errorMessage: null,
+ statusCode: response.status,
+ };
+
+ try {
+ resp.output = await response.json();
+
+ if (!response.ok) {
+ // If it's an object, try to get the error message
+ if (
+ isObject(resp.output) &&
+ "error" in resp.output &&
+ isObject(resp.output.error) &&
+ "message" in resp.output.error
+ ) {
+ resp.errorMessage = resp.output.error.message?.toString() ?? "Unknown error";
+ }
+ }
+ } catch (e) {
+ console.error(e);
+ if (response.ok) {
+ resp.errorMessage = "Failed to parse response";
+ }
+ }
+
+ return resp;
+}
diff --git a/src/server/utils/openai.ts b/src/server/utils/openai.ts
deleted file mode 100644
index 7696905..0000000
--- a/src/server/utils/openai.ts
+++ /dev/null
@@ -1,19 +0,0 @@
-import { type JSONSerializable } from "../types";
-
-export async function getChatCompletion(payload: JSONSerializable, apiKey: string) {
- const response = await fetch("https://api.openai.com/v1/chat/completions", {
- method: "POST",
- headers: {
- "Content-Type": "application/json",
- Authorization: `Bearer ${apiKey}`,
- },
- body: JSON.stringify(payload),
- });
-
- if (!response.ok) {
- throw new Error(`OpenAI API request failed with status ${response.status}`);
- }
-
- const data = (await response.json()) as JSONSerializable;
- return data;
-}