we're actually calling openai

This commit is contained in:
Kyle Corbitt
2023-06-22 17:57:21 -07:00
parent 0fa3af4e9f
commit a31c112745
14 changed files with 271 additions and 30 deletions

View File

@@ -0,0 +1,19 @@
import { api } from "~/utils/api";
import { PromptVariant, Scenario } from "./types";
export default function OutputCell({
scenario,
variant,
}: {
scenario: Scenario;
variant: PromptVariant;
}) {
const output = api.outputs.get.useQuery({
scenarioId: scenario.id,
variantId: variant.id,
});
if (!output.data) return null;
return <div>{JSON.stringify(output.data.output.choices[0].message.content, null, 2)}</div>;
}

View File

@@ -0,0 +1,59 @@
import { Header, Stack, Title } from "@mantine/core";
import { PromptVariant } from "@prisma/client";
import { useMonaco } from "@monaco-editor/react";
import { useEffect, useRef, useState } from "react";
let isThemeDefined = false;
export default function VariantHeader({ variant }: { variant: PromptVariant }) {
const monaco = useMonaco();
const editorRef = useRef(null);
const [editorId] = useState(() => `editor_${Math.random().toString(36).substring(7)}`);
useEffect(() => {
if (monaco && !isThemeDefined) {
monaco.editor.defineTheme("customTheme", {
base: "vs",
inherit: true,
rules: [],
colors: {
"editor.background": "#fafafa",
},
});
isThemeDefined = true;
}
}, [monaco]);
useEffect(() => {
if (monaco) {
editorRef.current = monaco.editor.create(document.getElementById(editorId), {
value: JSON.stringify(variant.config, null, 2),
language: "json",
theme: "customTheme",
lineNumbers: "off",
minimap: { enabled: false },
wrappingIndent: "indent",
wrappingStrategy: "advanced",
wordWrap: "on",
folding: false,
scrollbar: { vertical: "hidden" },
wordWrapBreakAfterCharacters: "",
wordWrapBreakBeforeCharacters: "",
});
}
// Clean up the editor instance on unmount
return () => {
if (editorRef.current) {
editorRef.current.dispose();
}
};
}, [monaco, variant, editorId]);
return (
<Stack w="100%">
<Title order={4}>{variant.label}</Title>
<div id={editorId} style={{ height: "300px", width: "100%" }}></div>
</Stack>
);
}

View File

@@ -1,9 +1,12 @@
import { MRT_ColumnDef, MantineReactTable } from "mantine-react-table";
import { useMemo } from "react";
import { RouterOutputs, api } from "~/utils/api";
import { PromptVariant } from "./types";
import VariantHeader from "./VariantHeader";
import OutputCell from "./OutputCell";
type CellData = {
variant: NonNullable<RouterOutputs["promptVariants"]["list"]>[0];
variant: PromptVariant;
scenario: NonNullable<RouterOutputs["scenarios"]["list"]>[0];
};
@@ -42,15 +45,9 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
(variant): MRT_ColumnDef<TableRow> => ({
id: variant.id,
header: variant.label,
// size: 300,
Cell: ({ row }) => {
const cellData = row.original[variant.id];
return (
<div>
{row.original.scenario.id} | {variant.id}
</div>
);
},
Header: <VariantHeader variant={variant} />,
size: 400,
Cell: ({ row }) => <OutputCell scenario={row.original.scenario} variant={variant} />,
})
) ?? []),
],
@@ -64,9 +61,11 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
scenario,
} as TableRow;
}) ?? [],
[variants.data, scenarios.data]
[scenarios.data]
);
if (!variants.data || !scenarios.data) return null;
return (
<MantineReactTable
mantinePaperProps={{
@@ -83,8 +82,34 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
enableDensityToggle={false}
enableFullScreenToggle={false}
enableHiding={false}
enableRowDragging
enableColumnActions={false}
enableColumnResizing
mantineTableProps={{
sx: {
th: {
verticalAlign: "bottom",
},
"& .mantine-TableHeadCell-Content": {
width: "100%",
height: "100%",
// display: "flex",
"& .mantine-TableHeadCell-Content-Actions": {
alignSelf: "flex-start",
},
"& > .mantine-TableHeadCell-Content-Labels": {
width: "100%",
height: "100%",
"& > .mantine-TableHeadCell-Content-Wrapper": {
width: "100%",
height: "100%",
},
},
},
},
}}
/>
);
// return <div>OutputsTable</div>;
}

View File

@@ -0,0 +1,5 @@
import { RouterOutputs } from "~/utils/api";
export type PromptVariant = NonNullable<RouterOutputs["promptVariants"]["list"]>[0];
export type Scenario = NonNullable<RouterOutputs["scenarios"]["list"]>[0];

View File

@@ -19,7 +19,7 @@ const useStyles = createStyles((theme) => ({
paddingBottom: theme.spacing.md,
marginBottom: `calc(${theme.spacing.md} * 1.5)`,
borderBottom: `${rem(1)} solid ${
theme.colorScheme === "dark" ? theme.colors.dark[4] : theme.colors.gray[2]
theme.colorScheme === "dark" ? theme.colors.dark[4] : theme.colors.gray[4]
}`,
},
@@ -27,12 +27,12 @@ const useStyles = createStyles((theme) => ({
paddingTop: theme.spacing.md,
marginTop: theme.spacing.md,
borderTop: `${rem(1)} solid ${
theme.colorScheme === "dark" ? theme.colors.dark[4] : theme.colors.gray[2]
theme.colorScheme === "dark" ? theme.colors.dark[4] : theme.colors.gray[4]
}`,
},
link: {
...theme.fn.focusStyles(),
...(theme.fn.focusStyles() as Record<string, any>),
display: "flex",
alignItems: "center",
textDecoration: "none",
@@ -101,7 +101,7 @@ export default function AppNav(props: { children: React.ReactNode; title?: strin
return (
<Box mih="100vh" sx={{ display: "flex" }}>
<Head>
<title>{props.title && `${props.title} | `}Prompt Bench</title>
<title>{props.title ? `${props.title} | Prompt Bench` : "Prompt Bench"}</title>
</Head>
<Navbar height="100vh" width={{ sm: 250 }} p="md" bg="gray.1">
<Navbar.Section grow>

View File

@@ -7,16 +7,15 @@ import { api } from "~/utils/api";
export default function Experiment() {
const router = useRouter();
console.log(router.query.id);
const experiment = api.experiments.get.useQuery(
{ id: router.query.id as string },
{ enabled: !!router.query.id }
);
if (!experiment.data) {
if (!experiment.isLoading && !experiment.data) {
return (
<AppNav title="Experiment not found">
<Center>
<Center h="100vh">
<div>Experiment not found 😕</div>
</Center>
</AppNav>
@@ -24,7 +23,7 @@ export default function Experiment() {
}
return (
<AppNav title={experiment.data.label}>
<AppNav title={experiment.data?.label}>
<Box sx={{ minHeight: "100vh" }}>
<OutputsTable experimentId={router.query.id as string | undefined} />
</Box>

View File

@@ -2,6 +2,7 @@ import { promptVariantsRouter } from "~/server/api/routers/promptVariants.router
import { createTRPCRouter } from "~/server/api/trpc";
import { experimentsRouter } from "./routers/experiments.router";
import { scenariosRouter } from "./routers/scenarios.router";
import { modelOutputsRouter } from "./routers/modelOutputs.router";
/**
* This is the primary router for your server.
@@ -12,6 +13,7 @@ export const appRouter = createTRPCRouter({
promptVariants: promptVariantsRouter,
experiments: experimentsRouter,
scenarios: scenariosRouter,
outputs: modelOutputsRouter,
});
// export type definition of API

View File

@@ -0,0 +1,53 @@
import { z } from "zod";
import { createTRPCRouter, publicProcedure, protectedProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import fillTemplate, { JSONSerializable, VariableMap } from "~/server/utils/fillTemplate";
import { getChatCompletion } from "~/server/utils/openai";
export const modelOutputsRouter = createTRPCRouter({
get: publicProcedure
.input(z.object({ scenarioId: z.string(), variantId: z.string() }))
.query(async ({ input }) => {
const existing = await prisma.modelOutput.findUnique({
where: {
promptVariantId_testScenarioId: {
promptVariantId: input.variantId,
testScenarioId: input.scenarioId,
},
},
});
if (existing) return existing;
const variant = await prisma.promptVariant.findUnique({
where: {
id: input.variantId,
},
});
const scenario = await prisma.testScenario.findUnique({
where: {
id: input.scenarioId,
},
});
if (!variant || !scenario) return null;
const filledTemplate = fillTemplate(
variant.config as JSONSerializable,
scenario.variableValues as VariableMap
);
const modelResponse = await getChatCompletion(filledTemplate, process.env.OPENAI_API_KEY!);
const modelOutput = await prisma.modelOutput.create({
data: {
promptVariantId: input.variantId,
testScenarioId: input.scenarioId,
output: modelResponse,
},
});
return modelOutput;
}),
});

View File

@@ -0,0 +1,27 @@
export type JSONSerializable =
| string
| number
| boolean
| null
| JSONSerializable[]
| { [key: string]: JSONSerializable };
export type VariableMap = Record<string, string>;
export default function fillTemplate<T extends JSONSerializable>(
template: T,
variables: VariableMap
): T {
if (typeof template === "string") {
return template.replace(/{{\s*(\w+)\s*}}/g, (_, key: string) => variables[key] || "") as T;
} else if (Array.isArray(template)) {
return template.map((item) => fillTemplate(item, variables)) as T;
} else if (typeof template === "object" && template !== null) {
return Object.keys(template).reduce((acc, key) => {
acc[key] = fillTemplate(template[key] as JSONSerializable, variables);
return acc;
}, {} as { [key: string]: JSONSerializable } & T);
} else {
return template;
}
}

View File

@@ -0,0 +1,19 @@
import { JSONSerializable } from "./fillTemplate";
export async function getChatCompletion(payload: JSONSerializable, apiKey: string) {
const response = await fetch("https://api.openai.com/v1/chat/completions", {
method: "POST",
headers: {
"Content-Type": "application/json",
Authorization: `Bearer ${apiKey}`,
},
body: JSON.stringify(payload),
});
if (!response.ok) {
throw new Error(`OpenAI API request failed with status ${response.status}`);
}
const data = (await response.json()) as JSONSerializable;
return data;
}