we're actually calling openai
This commit is contained in:
@@ -18,6 +18,7 @@
|
|||||||
"@mantine/form": "^6.0.14",
|
"@mantine/form": "^6.0.14",
|
||||||
"@mantine/hooks": "^6.0.14",
|
"@mantine/hooks": "^6.0.14",
|
||||||
"@mantine/next": "^6.0.14",
|
"@mantine/next": "^6.0.14",
|
||||||
|
"@monaco-editor/react": "^4.5.1",
|
||||||
"@next-auth/prisma-adapter": "^1.0.5",
|
"@next-auth/prisma-adapter": "^1.0.5",
|
||||||
"@prisma/client": "^4.14.0",
|
"@prisma/client": "^4.14.0",
|
||||||
"@t3-oss/env-nextjs": "^0.3.1",
|
"@t3-oss/env-nextjs": "^0.3.1",
|
||||||
|
|||||||
33
pnpm-lock.yaml
generated
33
pnpm-lock.yaml
generated
@@ -26,6 +26,9 @@ dependencies:
|
|||||||
'@mantine/next':
|
'@mantine/next':
|
||||||
specifier: ^6.0.14
|
specifier: ^6.0.14
|
||||||
version: 6.0.14(@emotion/react@11.11.1)(@emotion/server@11.11.0)(next@13.4.2)(react-dom@18.2.0)(react@18.2.0)
|
version: 6.0.14(@emotion/react@11.11.1)(@emotion/server@11.11.0)(next@13.4.2)(react-dom@18.2.0)(react@18.2.0)
|
||||||
|
'@monaco-editor/react':
|
||||||
|
specifier: ^4.5.1
|
||||||
|
version: 4.5.1(monaco-editor@0.39.0)(react-dom@18.2.0)(react@18.2.0)
|
||||||
'@next-auth/prisma-adapter':
|
'@next-auth/prisma-adapter':
|
||||||
specifier: ^1.0.5
|
specifier: ^1.0.5
|
||||||
version: 1.0.5(@prisma/client@4.14.0)(next-auth@4.22.1)
|
version: 1.0.5(@prisma/client@4.14.0)(next-auth@4.22.1)
|
||||||
@@ -688,6 +691,28 @@ packages:
|
|||||||
react: 18.2.0
|
react: 18.2.0
|
||||||
dev: false
|
dev: false
|
||||||
|
|
||||||
|
/@monaco-editor/loader@1.3.3(monaco-editor@0.39.0):
|
||||||
|
resolution: {integrity: sha512-6KKF4CTzcJiS8BJwtxtfyYt9shBiEv32ateQ9T4UVogwn4HM/uPo9iJd2Dmbkpz8CM6Y0PDUpjnZzCwC+eYo2Q==}
|
||||||
|
peerDependencies:
|
||||||
|
monaco-editor: '>= 0.21.0 < 1'
|
||||||
|
dependencies:
|
||||||
|
monaco-editor: 0.39.0
|
||||||
|
state-local: 1.0.7
|
||||||
|
dev: false
|
||||||
|
|
||||||
|
/@monaco-editor/react@4.5.1(monaco-editor@0.39.0)(react-dom@18.2.0)(react@18.2.0):
|
||||||
|
resolution: {integrity: sha512-NNDFdP+2HojtNhCkRfE6/D6ro6pBNihaOzMbGK84lNWzRu+CfBjwzGt4jmnqimLuqp5yE5viHS2vi+QOAnD5FQ==}
|
||||||
|
peerDependencies:
|
||||||
|
monaco-editor: '>= 0.25.0 < 1'
|
||||||
|
react: ^16.8.0 || ^17.0.0 || ^18.0.0
|
||||||
|
react-dom: ^16.8.0 || ^17.0.0 || ^18.0.0
|
||||||
|
dependencies:
|
||||||
|
'@monaco-editor/loader': 1.3.3(monaco-editor@0.39.0)
|
||||||
|
monaco-editor: 0.39.0
|
||||||
|
react: 18.2.0
|
||||||
|
react-dom: 18.2.0(react@18.2.0)
|
||||||
|
dev: false
|
||||||
|
|
||||||
/@next-auth/prisma-adapter@1.0.5(@prisma/client@4.14.0)(next-auth@4.22.1):
|
/@next-auth/prisma-adapter@1.0.5(@prisma/client@4.14.0)(next-auth@4.22.1):
|
||||||
resolution: {integrity: sha512-VqMS11IxPXrPGXw6Oul6jcyS/n8GLOWzRMrPr3EMdtD6eOalM6zz05j08PcNiis8QzkfuYnCv49OvufTuaEwYQ==}
|
resolution: {integrity: sha512-VqMS11IxPXrPGXw6Oul6jcyS/n8GLOWzRMrPr3EMdtD6eOalM6zz05j08PcNiis8QzkfuYnCv49OvufTuaEwYQ==}
|
||||||
peerDependencies:
|
peerDependencies:
|
||||||
@@ -2882,6 +2907,10 @@ packages:
|
|||||||
/minimist@1.2.8:
|
/minimist@1.2.8:
|
||||||
resolution: {integrity: sha512-2yyAR8qBkN3YuheJanUpWC5U3bb5osDywNB8RzDVlDwDHbocAJveqqj1u8+SVD7jkWT4yvsHCpWqqWqAxb0zCA==}
|
resolution: {integrity: sha512-2yyAR8qBkN3YuheJanUpWC5U3bb5osDywNB8RzDVlDwDHbocAJveqqj1u8+SVD7jkWT4yvsHCpWqqWqAxb0zCA==}
|
||||||
|
|
||||||
|
/monaco-editor@0.39.0:
|
||||||
|
resolution: {integrity: sha512-zhbZ2Nx93tLR8aJmL2zI1mhJpsl87HMebNBM6R8z4pLfs8pj604pIVIVwyF1TivcfNtIPpMXL+nb3DsBmE/x6Q==}
|
||||||
|
dev: false
|
||||||
|
|
||||||
/ms@2.1.2:
|
/ms@2.1.2:
|
||||||
resolution: {integrity: sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w==}
|
resolution: {integrity: sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w==}
|
||||||
dev: true
|
dev: true
|
||||||
@@ -3516,6 +3545,10 @@ packages:
|
|||||||
engines: {node: '>=0.10.0'}
|
engines: {node: '>=0.10.0'}
|
||||||
dev: false
|
dev: false
|
||||||
|
|
||||||
|
/state-local@1.0.7:
|
||||||
|
resolution: {integrity: sha512-HTEHMNieakEnoe33shBYcZ7NX83ACUjCu8c40iOGEZsngj9zRnkqS9j1pqQPXwobB0ZcVTk27REb7COQ0UR59w==}
|
||||||
|
dev: false
|
||||||
|
|
||||||
/streamsearch@1.1.0:
|
/streamsearch@1.1.0:
|
||||||
resolution: {integrity: sha512-Mcc5wHehp9aXz1ax6bZUyY5afg9u2rv5cqQI3mRrYkGC8rW2hM02jWuwjtL++LS5qinSyhj2QfLyNsuc+VsExg==}
|
resolution: {integrity: sha512-Mcc5wHehp9aXz1ax6bZUyY5afg9u2rv5cqQI3mRrYkGC8rW2hM02jWuwjtL++LS5qinSyhj2QfLyNsuc+VsExg==}
|
||||||
engines: {node: '>=10.0.0'}
|
engines: {node: '>=10.0.0'}
|
||||||
|
|||||||
@@ -70,15 +70,14 @@ model TemplateVariable {
|
|||||||
model ModelOutput {
|
model ModelOutput {
|
||||||
id String @id @default(uuid()) @db.Uuid
|
id String @id @default(uuid()) @db.Uuid
|
||||||
|
|
||||||
|
output Json
|
||||||
|
|
||||||
promptVariantId String @db.Uuid
|
promptVariantId String @db.Uuid
|
||||||
promptVariant PromptVariant @relation(fields: [promptVariantId], references: [id])
|
promptVariant PromptVariant @relation(fields: [promptVariantId], references: [id])
|
||||||
|
|
||||||
testScenarioId String @db.Uuid
|
testScenarioId String @db.Uuid
|
||||||
testScenario TestScenario @relation(fields: [testScenarioId], references: [id])
|
testScenario TestScenario @relation(fields: [testScenarioId], references: [id])
|
||||||
|
|
||||||
variableValues Json
|
|
||||||
inputsHash String @unique
|
|
||||||
|
|
||||||
createdAt DateTime @default(now())
|
createdAt DateTime @default(now())
|
||||||
updatedAt DateTime @updatedAt
|
updatedAt DateTime @updatedAt
|
||||||
|
|
||||||
|
|||||||
@@ -22,21 +22,21 @@ const resp = await prisma.promptVariant.createMany({
|
|||||||
data: [
|
data: [
|
||||||
{
|
{
|
||||||
experimentId,
|
experimentId,
|
||||||
label: "Variant 1",
|
label: "Prompt Variant 1",
|
||||||
sortIndex: 0,
|
sortIndex: 0,
|
||||||
config: {
|
config: {
|
||||||
model: "gpt-3.5-turbo",
|
model: "gpt-3.5-turbo",
|
||||||
messages: [{ role: "user", content: "What is the capitol of {{input}}?" }],
|
messages: [{ role: "user", content: "What is the capitol of {{state}}?" }],
|
||||||
temperature: 0,
|
temperature: 0,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
experimentId,
|
experimentId,
|
||||||
label: "Variant 2",
|
label: "Prompt Variant 2",
|
||||||
sortIndex: 1,
|
sortIndex: 1,
|
||||||
config: {
|
config: {
|
||||||
model: "gpt-3.5-turbo",
|
model: "gpt-3.5-turbo",
|
||||||
messages: [{ role: "user", content: "What is the capitol of the US state {{input}}?" }],
|
messages: [{ role: "user", content: "What is the capitol of the US state {{state}}?" }],
|
||||||
temperature: 0,
|
temperature: 0,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -69,13 +69,13 @@ await prisma.testScenario.createMany({
|
|||||||
{
|
{
|
||||||
experimentId,
|
experimentId,
|
||||||
variableValues: {
|
variableValues: {
|
||||||
input: "Washington",
|
state: "Washington",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
experimentId,
|
experimentId,
|
||||||
variableValues: {
|
variableValues: {
|
||||||
input: "Georgia",
|
state: "Georgia",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
|
|||||||
19
src/components/OutputsTable/OutputCell.tsx
Normal file
19
src/components/OutputsTable/OutputCell.tsx
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
import { api } from "~/utils/api";
|
||||||
|
import { PromptVariant, Scenario } from "./types";
|
||||||
|
|
||||||
|
export default function OutputCell({
|
||||||
|
scenario,
|
||||||
|
variant,
|
||||||
|
}: {
|
||||||
|
scenario: Scenario;
|
||||||
|
variant: PromptVariant;
|
||||||
|
}) {
|
||||||
|
const output = api.outputs.get.useQuery({
|
||||||
|
scenarioId: scenario.id,
|
||||||
|
variantId: variant.id,
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!output.data) return null;
|
||||||
|
|
||||||
|
return <div>{JSON.stringify(output.data.output.choices[0].message.content, null, 2)}</div>;
|
||||||
|
}
|
||||||
59
src/components/OutputsTable/VariantHeader.tsx
Normal file
59
src/components/OutputsTable/VariantHeader.tsx
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
import { Header, Stack, Title } from "@mantine/core";
|
||||||
|
import { PromptVariant } from "@prisma/client";
|
||||||
|
import { useMonaco } from "@monaco-editor/react";
|
||||||
|
import { useEffect, useRef, useState } from "react";
|
||||||
|
|
||||||
|
let isThemeDefined = false;
|
||||||
|
|
||||||
|
export default function VariantHeader({ variant }: { variant: PromptVariant }) {
|
||||||
|
const monaco = useMonaco();
|
||||||
|
const editorRef = useRef(null);
|
||||||
|
const [editorId] = useState(() => `editor_${Math.random().toString(36).substring(7)}`);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (monaco && !isThemeDefined) {
|
||||||
|
monaco.editor.defineTheme("customTheme", {
|
||||||
|
base: "vs",
|
||||||
|
inherit: true,
|
||||||
|
rules: [],
|
||||||
|
colors: {
|
||||||
|
"editor.background": "#fafafa",
|
||||||
|
},
|
||||||
|
});
|
||||||
|
isThemeDefined = true;
|
||||||
|
}
|
||||||
|
}, [monaco]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (monaco) {
|
||||||
|
editorRef.current = monaco.editor.create(document.getElementById(editorId), {
|
||||||
|
value: JSON.stringify(variant.config, null, 2),
|
||||||
|
language: "json",
|
||||||
|
theme: "customTheme",
|
||||||
|
lineNumbers: "off",
|
||||||
|
minimap: { enabled: false },
|
||||||
|
wrappingIndent: "indent",
|
||||||
|
wrappingStrategy: "advanced",
|
||||||
|
wordWrap: "on",
|
||||||
|
folding: false,
|
||||||
|
scrollbar: { vertical: "hidden" },
|
||||||
|
wordWrapBreakAfterCharacters: "",
|
||||||
|
wordWrapBreakBeforeCharacters: "",
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clean up the editor instance on unmount
|
||||||
|
return () => {
|
||||||
|
if (editorRef.current) {
|
||||||
|
editorRef.current.dispose();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}, [monaco, variant, editorId]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Stack w="100%">
|
||||||
|
<Title order={4}>{variant.label}</Title>
|
||||||
|
<div id={editorId} style={{ height: "300px", width: "100%" }}></div>
|
||||||
|
</Stack>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -1,9 +1,12 @@
|
|||||||
import { MRT_ColumnDef, MantineReactTable } from "mantine-react-table";
|
import { MRT_ColumnDef, MantineReactTable } from "mantine-react-table";
|
||||||
import { useMemo } from "react";
|
import { useMemo } from "react";
|
||||||
import { RouterOutputs, api } from "~/utils/api";
|
import { RouterOutputs, api } from "~/utils/api";
|
||||||
|
import { PromptVariant } from "./types";
|
||||||
|
import VariantHeader from "./VariantHeader";
|
||||||
|
import OutputCell from "./OutputCell";
|
||||||
|
|
||||||
type CellData = {
|
type CellData = {
|
||||||
variant: NonNullable<RouterOutputs["promptVariants"]["list"]>[0];
|
variant: PromptVariant;
|
||||||
scenario: NonNullable<RouterOutputs["scenarios"]["list"]>[0];
|
scenario: NonNullable<RouterOutputs["scenarios"]["list"]>[0];
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -42,15 +45,9 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
|
|||||||
(variant): MRT_ColumnDef<TableRow> => ({
|
(variant): MRT_ColumnDef<TableRow> => ({
|
||||||
id: variant.id,
|
id: variant.id,
|
||||||
header: variant.label,
|
header: variant.label,
|
||||||
// size: 300,
|
Header: <VariantHeader variant={variant} />,
|
||||||
Cell: ({ row }) => {
|
size: 400,
|
||||||
const cellData = row.original[variant.id];
|
Cell: ({ row }) => <OutputCell scenario={row.original.scenario} variant={variant} />,
|
||||||
return (
|
|
||||||
<div>
|
|
||||||
{row.original.scenario.id} | {variant.id}
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
},
|
|
||||||
})
|
})
|
||||||
) ?? []),
|
) ?? []),
|
||||||
],
|
],
|
||||||
@@ -64,9 +61,11 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
|
|||||||
scenario,
|
scenario,
|
||||||
} as TableRow;
|
} as TableRow;
|
||||||
}) ?? [],
|
}) ?? [],
|
||||||
[variants.data, scenarios.data]
|
[scenarios.data]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
if (!variants.data || !scenarios.data) return null;
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<MantineReactTable
|
<MantineReactTable
|
||||||
mantinePaperProps={{
|
mantinePaperProps={{
|
||||||
@@ -83,8 +82,34 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
|
|||||||
enableDensityToggle={false}
|
enableDensityToggle={false}
|
||||||
enableFullScreenToggle={false}
|
enableFullScreenToggle={false}
|
||||||
enableHiding={false}
|
enableHiding={false}
|
||||||
enableRowDragging
|
enableColumnActions={false}
|
||||||
|
enableColumnResizing
|
||||||
|
mantineTableProps={{
|
||||||
|
sx: {
|
||||||
|
th: {
|
||||||
|
verticalAlign: "bottom",
|
||||||
|
},
|
||||||
|
"& .mantine-TableHeadCell-Content": {
|
||||||
|
width: "100%",
|
||||||
|
height: "100%",
|
||||||
|
// display: "flex",
|
||||||
|
|
||||||
|
"& .mantine-TableHeadCell-Content-Actions": {
|
||||||
|
alignSelf: "flex-start",
|
||||||
|
},
|
||||||
|
|
||||||
|
"& > .mantine-TableHeadCell-Content-Labels": {
|
||||||
|
width: "100%",
|
||||||
|
height: "100%",
|
||||||
|
|
||||||
|
"& > .mantine-TableHeadCell-Content-Wrapper": {
|
||||||
|
width: "100%",
|
||||||
|
height: "100%",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
// return <div>OutputsTable</div>;
|
|
||||||
}
|
}
|
||||||
|
|||||||
5
src/components/OutputsTable/types.ts
Normal file
5
src/components/OutputsTable/types.ts
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
import { RouterOutputs } from "~/utils/api";
|
||||||
|
|
||||||
|
export type PromptVariant = NonNullable<RouterOutputs["promptVariants"]["list"]>[0];
|
||||||
|
|
||||||
|
export type Scenario = NonNullable<RouterOutputs["scenarios"]["list"]>[0];
|
||||||
@@ -19,7 +19,7 @@ const useStyles = createStyles((theme) => ({
|
|||||||
paddingBottom: theme.spacing.md,
|
paddingBottom: theme.spacing.md,
|
||||||
marginBottom: `calc(${theme.spacing.md} * 1.5)`,
|
marginBottom: `calc(${theme.spacing.md} * 1.5)`,
|
||||||
borderBottom: `${rem(1)} solid ${
|
borderBottom: `${rem(1)} solid ${
|
||||||
theme.colorScheme === "dark" ? theme.colors.dark[4] : theme.colors.gray[2]
|
theme.colorScheme === "dark" ? theme.colors.dark[4] : theme.colors.gray[4]
|
||||||
}`,
|
}`,
|
||||||
},
|
},
|
||||||
|
|
||||||
@@ -27,12 +27,12 @@ const useStyles = createStyles((theme) => ({
|
|||||||
paddingTop: theme.spacing.md,
|
paddingTop: theme.spacing.md,
|
||||||
marginTop: theme.spacing.md,
|
marginTop: theme.spacing.md,
|
||||||
borderTop: `${rem(1)} solid ${
|
borderTop: `${rem(1)} solid ${
|
||||||
theme.colorScheme === "dark" ? theme.colors.dark[4] : theme.colors.gray[2]
|
theme.colorScheme === "dark" ? theme.colors.dark[4] : theme.colors.gray[4]
|
||||||
}`,
|
}`,
|
||||||
},
|
},
|
||||||
|
|
||||||
link: {
|
link: {
|
||||||
...theme.fn.focusStyles(),
|
...(theme.fn.focusStyles() as Record<string, any>),
|
||||||
display: "flex",
|
display: "flex",
|
||||||
alignItems: "center",
|
alignItems: "center",
|
||||||
textDecoration: "none",
|
textDecoration: "none",
|
||||||
@@ -101,7 +101,7 @@ export default function AppNav(props: { children: React.ReactNode; title?: strin
|
|||||||
return (
|
return (
|
||||||
<Box mih="100vh" sx={{ display: "flex" }}>
|
<Box mih="100vh" sx={{ display: "flex" }}>
|
||||||
<Head>
|
<Head>
|
||||||
<title>{props.title && `${props.title} | `}Prompt Bench</title>
|
<title>{props.title ? `${props.title} | Prompt Bench` : "Prompt Bench"}</title>
|
||||||
</Head>
|
</Head>
|
||||||
<Navbar height="100vh" width={{ sm: 250 }} p="md" bg="gray.1">
|
<Navbar height="100vh" width={{ sm: 250 }} p="md" bg="gray.1">
|
||||||
<Navbar.Section grow>
|
<Navbar.Section grow>
|
||||||
|
|||||||
@@ -7,16 +7,15 @@ import { api } from "~/utils/api";
|
|||||||
export default function Experiment() {
|
export default function Experiment() {
|
||||||
const router = useRouter();
|
const router = useRouter();
|
||||||
|
|
||||||
console.log(router.query.id);
|
|
||||||
const experiment = api.experiments.get.useQuery(
|
const experiment = api.experiments.get.useQuery(
|
||||||
{ id: router.query.id as string },
|
{ id: router.query.id as string },
|
||||||
{ enabled: !!router.query.id }
|
{ enabled: !!router.query.id }
|
||||||
);
|
);
|
||||||
|
|
||||||
if (!experiment.data) {
|
if (!experiment.isLoading && !experiment.data) {
|
||||||
return (
|
return (
|
||||||
<AppNav title="Experiment not found">
|
<AppNav title="Experiment not found">
|
||||||
<Center>
|
<Center h="100vh">
|
||||||
<div>Experiment not found 😕</div>
|
<div>Experiment not found 😕</div>
|
||||||
</Center>
|
</Center>
|
||||||
</AppNav>
|
</AppNav>
|
||||||
@@ -24,7 +23,7 @@ export default function Experiment() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<AppNav title={experiment.data.label}>
|
<AppNav title={experiment.data?.label}>
|
||||||
<Box sx={{ minHeight: "100vh" }}>
|
<Box sx={{ minHeight: "100vh" }}>
|
||||||
<OutputsTable experimentId={router.query.id as string | undefined} />
|
<OutputsTable experimentId={router.query.id as string | undefined} />
|
||||||
</Box>
|
</Box>
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import { promptVariantsRouter } from "~/server/api/routers/promptVariants.router
|
|||||||
import { createTRPCRouter } from "~/server/api/trpc";
|
import { createTRPCRouter } from "~/server/api/trpc";
|
||||||
import { experimentsRouter } from "./routers/experiments.router";
|
import { experimentsRouter } from "./routers/experiments.router";
|
||||||
import { scenariosRouter } from "./routers/scenarios.router";
|
import { scenariosRouter } from "./routers/scenarios.router";
|
||||||
|
import { modelOutputsRouter } from "./routers/modelOutputs.router";
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This is the primary router for your server.
|
* This is the primary router for your server.
|
||||||
@@ -12,6 +13,7 @@ export const appRouter = createTRPCRouter({
|
|||||||
promptVariants: promptVariantsRouter,
|
promptVariants: promptVariantsRouter,
|
||||||
experiments: experimentsRouter,
|
experiments: experimentsRouter,
|
||||||
scenarios: scenariosRouter,
|
scenarios: scenariosRouter,
|
||||||
|
outputs: modelOutputsRouter,
|
||||||
});
|
});
|
||||||
|
|
||||||
// export type definition of API
|
// export type definition of API
|
||||||
|
|||||||
53
src/server/api/routers/modelOutputs.router.ts
Normal file
53
src/server/api/routers/modelOutputs.router.ts
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
import { z } from "zod";
|
||||||
|
import { createTRPCRouter, publicProcedure, protectedProcedure } from "~/server/api/trpc";
|
||||||
|
import { prisma } from "~/server/db";
|
||||||
|
import fillTemplate, { JSONSerializable, VariableMap } from "~/server/utils/fillTemplate";
|
||||||
|
import { getChatCompletion } from "~/server/utils/openai";
|
||||||
|
|
||||||
|
export const modelOutputsRouter = createTRPCRouter({
|
||||||
|
get: publicProcedure
|
||||||
|
.input(z.object({ scenarioId: z.string(), variantId: z.string() }))
|
||||||
|
.query(async ({ input }) => {
|
||||||
|
const existing = await prisma.modelOutput.findUnique({
|
||||||
|
where: {
|
||||||
|
promptVariantId_testScenarioId: {
|
||||||
|
promptVariantId: input.variantId,
|
||||||
|
testScenarioId: input.scenarioId,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
if (existing) return existing;
|
||||||
|
|
||||||
|
const variant = await prisma.promptVariant.findUnique({
|
||||||
|
where: {
|
||||||
|
id: input.variantId,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
const scenario = await prisma.testScenario.findUnique({
|
||||||
|
where: {
|
||||||
|
id: input.scenarioId,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!variant || !scenario) return null;
|
||||||
|
|
||||||
|
const filledTemplate = fillTemplate(
|
||||||
|
variant.config as JSONSerializable,
|
||||||
|
scenario.variableValues as VariableMap
|
||||||
|
);
|
||||||
|
|
||||||
|
const modelResponse = await getChatCompletion(filledTemplate, process.env.OPENAI_API_KEY!);
|
||||||
|
|
||||||
|
const modelOutput = await prisma.modelOutput.create({
|
||||||
|
data: {
|
||||||
|
promptVariantId: input.variantId,
|
||||||
|
testScenarioId: input.scenarioId,
|
||||||
|
output: modelResponse,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
return modelOutput;
|
||||||
|
}),
|
||||||
|
});
|
||||||
27
src/server/utils/fillTemplate.ts
Normal file
27
src/server/utils/fillTemplate.ts
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
export type JSONSerializable =
|
||||||
|
| string
|
||||||
|
| number
|
||||||
|
| boolean
|
||||||
|
| null
|
||||||
|
| JSONSerializable[]
|
||||||
|
| { [key: string]: JSONSerializable };
|
||||||
|
|
||||||
|
export type VariableMap = Record<string, string>;
|
||||||
|
|
||||||
|
export default function fillTemplate<T extends JSONSerializable>(
|
||||||
|
template: T,
|
||||||
|
variables: VariableMap
|
||||||
|
): T {
|
||||||
|
if (typeof template === "string") {
|
||||||
|
return template.replace(/{{\s*(\w+)\s*}}/g, (_, key: string) => variables[key] || "") as T;
|
||||||
|
} else if (Array.isArray(template)) {
|
||||||
|
return template.map((item) => fillTemplate(item, variables)) as T;
|
||||||
|
} else if (typeof template === "object" && template !== null) {
|
||||||
|
return Object.keys(template).reduce((acc, key) => {
|
||||||
|
acc[key] = fillTemplate(template[key] as JSONSerializable, variables);
|
||||||
|
return acc;
|
||||||
|
}, {} as { [key: string]: JSONSerializable } & T);
|
||||||
|
} else {
|
||||||
|
return template;
|
||||||
|
}
|
||||||
|
}
|
||||||
19
src/server/utils/openai.ts
Normal file
19
src/server/utils/openai.ts
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
import { JSONSerializable } from "./fillTemplate";
|
||||||
|
|
||||||
|
export async function getChatCompletion(payload: JSONSerializable, apiKey: string) {
|
||||||
|
const response = await fetch("https://api.openai.com/v1/chat/completions", {
|
||||||
|
method: "POST",
|
||||||
|
headers: {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
Authorization: `Bearer ${apiKey}`,
|
||||||
|
},
|
||||||
|
body: JSON.stringify(payload),
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!response.ok) {
|
||||||
|
throw new Error(`OpenAI API request failed with status ${response.status}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
const data = (await response.json()) as JSONSerializable;
|
||||||
|
return data;
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user