editing scenarios is kinda working

This commit is contained in:
Kyle Corbitt
2023-06-23 16:18:28 -07:00
parent bf41069442
commit 2b0c2ad603
12 changed files with 221 additions and 35 deletions

View File

@@ -31,6 +31,7 @@
"@trpc/server": "^10.26.0", "@trpc/server": "^10.26.0",
"dayjs": "^1.11.8", "dayjs": "^1.11.8",
"dotenv": "^16.3.1", "dotenv": "^16.3.1",
"lodash": "^4.17.21",
"mantine-react-table": "1.0.0-beta.13", "mantine-react-table": "1.0.0-beta.13",
"next": "^13.4.2", "next": "^13.4.2",
"next-auth": "^4.22.1", "next-auth": "^4.22.1",
@@ -42,6 +43,7 @@
}, },
"devDependencies": { "devDependencies": {
"@types/eslint": "^8.37.0", "@types/eslint": "^8.37.0",
"@types/lodash": "^4.14.195",
"@types/node": "^18.16.0", "@types/node": "^18.16.0",
"@types/react": "^18.2.6", "@types/react": "^18.2.6",
"@types/react-dom": "^18.2.4", "@types/react-dom": "^18.2.4",

14
pnpm-lock.yaml generated
View File

@@ -65,6 +65,9 @@ dependencies:
dotenv: dotenv:
specifier: ^16.3.1 specifier: ^16.3.1
version: 16.3.1 version: 16.3.1
lodash:
specifier: ^4.17.21
version: 4.17.21
mantine-react-table: mantine-react-table:
specifier: 1.0.0-beta.13 specifier: 1.0.0-beta.13
version: 1.0.0-beta.13(@emotion/react@11.11.1)(@mantine/core@6.0.14)(@mantine/dates@6.0.14)(@mantine/hooks@6.0.14)(@tabler/icons-react@2.22.0)(react-dom@18.2.0)(react@18.2.0) version: 1.0.0-beta.13(@emotion/react@11.11.1)(@mantine/core@6.0.14)(@mantine/dates@6.0.14)(@mantine/hooks@6.0.14)(@tabler/icons-react@2.22.0)(react-dom@18.2.0)(react@18.2.0)
@@ -94,6 +97,9 @@ devDependencies:
'@types/eslint': '@types/eslint':
specifier: ^8.37.0 specifier: ^8.37.0
version: 8.37.0 version: 8.37.0
'@types/lodash':
specifier: ^4.14.195
version: 4.14.195
'@types/node': '@types/node':
specifier: ^18.16.0 specifier: ^18.16.0
version: 18.16.0 version: 18.16.0
@@ -1176,6 +1182,10 @@ packages:
resolution: {integrity: sha512-dRLjCWHYg4oaA77cxO64oO+7JwCwnIzkZPdrrC71jQmQtlhM556pwKo5bUzqvZndkVbeFLIIi+9TC40JNF5hNQ==} resolution: {integrity: sha512-dRLjCWHYg4oaA77cxO64oO+7JwCwnIzkZPdrrC71jQmQtlhM556pwKo5bUzqvZndkVbeFLIIi+9TC40JNF5hNQ==}
dev: true dev: true
/@types/lodash@4.14.195:
resolution: {integrity: sha512-Hwx9EUgdwf2GLarOjQp5ZH8ZmblzcbTBC2wtQWNKARBSxM9ezRIAUpeDTgoQRAFB0+8CNWXVA9+MaSOzOF3nPg==}
dev: true
/@types/node@18.16.0: /@types/node@18.16.0:
resolution: {integrity: sha512-BsAaKhB+7X+H4GnSjGhJG9Qi8Tw+inU9nJDwmD5CgOmBLEI6ArdhikpLX7DjbjDRDTbqZzU2LSQNZg8WGPiSZQ==} resolution: {integrity: sha512-BsAaKhB+7X+H4GnSjGhJG9Qi8Tw+inU9nJDwmD5CgOmBLEI6ArdhikpLX7DjbjDRDTbqZzU2LSQNZg8WGPiSZQ==}
dev: true dev: true
@@ -2860,6 +2870,10 @@ packages:
resolution: {integrity: sha512-0KpjqXRVvrYyCsX1swR/XTK0va6VQkQM6MNo7PqW77ByjAhoARA8EfrP1N4+KlKj8YS0ZUCtRT/YUuhyYDujIQ==} resolution: {integrity: sha512-0KpjqXRVvrYyCsX1swR/XTK0va6VQkQM6MNo7PqW77ByjAhoARA8EfrP1N4+KlKj8YS0ZUCtRT/YUuhyYDujIQ==}
dev: true dev: true
/lodash@4.17.21:
resolution: {integrity: sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg==}
dev: false
/loose-envify@1.4.0: /loose-envify@1.4.0:
resolution: {integrity: sha512-lyuxPGr/Wfhrlem2CL/UcnUc1zcqKAImBDzukY7Y5F/yQiNdko6+fRLevlw1HgMySw7f611UIY408EtxRSoK3Q==} resolution: {integrity: sha512-lyuxPGr/Wfhrlem2CL/UcnUc1zcqKAImBDzukY7Y5F/yQiNdko6+fRLevlw1HgMySw7f611UIY408EtxRSoK3Q==}
hasBin: true hasBin: true

View File

@@ -26,13 +26,12 @@ model PromptVariant {
id String @id @default(uuid()) @db.Uuid id String @id @default(uuid()) @db.Uuid
label String label String
uiId String @default(uuid()) @db.Uuid config Json
uiId String @default(uuid()) @db.Uuid
visible Boolean @default(true) visible Boolean @default(true)
sortIndex Int @default(0) sortIndex Int @default(0)
config Json
experimentId String @db.Uuid experimentId String @db.Uuid
experiment Experiment @relation(fields: [experimentId], references: [id]) experiment Experiment @relation(fields: [experimentId], references: [id])
@@ -46,11 +45,12 @@ model PromptVariant {
model TestScenario { model TestScenario {
id String @id @default(uuid()) @db.Uuid id String @id @default(uuid()) @db.Uuid
variableValues Json
uiId String @default(uuid()) @db.Uuid
visible Boolean @default(true) visible Boolean @default(true)
sortIndex Int @default(0) sortIndex Int @default(0)
variableValues Json
experimentId String @db.Uuid experimentId String @db.Uuid
experiment Experiment @relation(fields: [experimentId], references: [id]) experiment Experiment @relation(fields: [experimentId], references: [id])

View File

@@ -61,7 +61,7 @@ await prisma.templateVariable.createMany({
data: [ data: [
{ {
experimentId, experimentId,
label: "input", label: "state",
}, },
], ],
}); });
@@ -83,7 +83,13 @@ await prisma.testScenario.createMany({
{ {
experimentId, experimentId,
variableValues: { variableValues: {
state: "Georgia", state: "California",
},
},
{
experimentId,
variableValues: {
state: "Utah",
}, },
}, },
], ],

View File

@@ -0,0 +1,27 @@
import { useRef } from "react";
import { Title } from "@mantine/core";
import { type PromptVariant } from "./types";
import { api } from "~/utils/api";
import { useHandledAsyncCallback } from "~/utils/hooks";
export default function EditableVariantLabel(props: { variant: PromptVariant }) {
const labelRef = useRef<HTMLHeadingElement | null>(null);
const mutation = api.promptVariants.update.useMutation();
const [onBlur] = useHandledAsyncCallback(async () => {
const newLabel = labelRef.current?.innerText;
if (newLabel && newLabel !== props.variant.label) {
await mutation.mutateAsync({
id: props.variant.id,
updates: { label: newLabel },
});
}
}, [mutation, props.variant.id, props.variant.label]);
return (
<Title order={4} ref={labelRef} contentEditable suppressContentEditableWarning onBlur={onBlur}>
{props.variant.label}
</Title>
);
}

View File

@@ -0,0 +1,65 @@
import { api } from "~/utils/api";
import { isEqual } from "lodash";
import { PromptVariant, Scenario } from "./types";
import { Badge, Button, Group, Stack, TextInput, Textarea, Tooltip } from "@mantine/core";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
import { useState } from "react";
export default function ScenarioHeader({ scenario }: { scenario: Scenario }) {
const savedValues = scenario.variableValues as Record<string, string>;
const utils = api.useContext();
const [values, setValues] = useState<Record<string, string>>(savedValues);
const experiment = useExperiment();
const variableLabels = experiment.data?.TemplateVariable.map((v) => v.label) ?? [];
const hasChanged = !isEqual(savedValues, values);
const mutation = api.scenarios.replaceWithValues.useMutation();
const [onSave] = useHandledAsyncCallback(async () => {
await mutation.mutateAsync({
id: scenario.id,
values,
});
await utils.scenarios.list.invalidate();
}, [mutation, values]);
return (
<Stack>
{variableLabels.map((key) => {
return (
<Textarea
key={key}
label={key}
value={values[key] ?? ""}
onChange={(e) => {
setValues((prev) => ({ ...prev, [key]: e.target.value }));
}}
autosize
rows={1}
maxRows={20}
/>
);
})}
{hasChanged && (
<Group spacing={4} position="right">
<Button
size="xs"
onClick={() => {
setValues(savedValues);
}}
color="gray"
>
Reset
</Button>
<Button size="xs" onClick={onSave}>
Save
</Button>
</Group>
)}
</Stack>
);
}

View File

@@ -1,8 +1,8 @@
import { Box, Button, Group, Stack, Title } from "@mantine/core"; import { Box, Button, Group, Stack, Title, Tooltip } from "@mantine/core";
import { useMonaco } from "@monaco-editor/react"; import { useMonaco } from "@monaco-editor/react";
import { useRef, useEffect, useState, useCallback } from "react"; import { useRef, useEffect, useState, useCallback } from "react";
import { set } from "zod"; import { set } from "zod";
import { useHandledAsyncCallback } from "~/utils/hooks"; import { useHandledAsyncCallback, useModifierKeyLabel } from "~/utils/hooks";
let isThemeDefined = false; let isThemeDefined = false;
@@ -16,6 +16,8 @@ export default function VariantConfigEditor(props: {
const [isChanged, setIsChanged] = useState(false); const [isChanged, setIsChanged] = useState(false);
const savedConfigRef = useRef(props.savedConfig); const savedConfigRef = useRef(props.savedConfig);
const modifierKey = useModifierKeyLabel();
const checkForChanges = useCallback(() => { const checkForChanges = useCallback(() => {
if (!editorRef.current) return; if (!editorRef.current) return;
const currentConfig = editorRef.current.getValue(); const currentConfig = editorRef.current.getValue();
@@ -104,9 +106,11 @@ export default function VariantConfigEditor(props: {
> >
Reset Reset
</Button> </Button>
<Button size="xs" onClick={onSave}> <Tooltip label={`${modifierKey} + Enter`} withArrow>
Save <Button size="xs" onClick={onSave}>
</Button> Save
</Button>
</Tooltip>
</Group> </Group>
)} )}
</Box> </Box>

View File

@@ -5,6 +5,7 @@ import { api } from "~/utils/api";
import { notifications } from "@mantine/notifications"; import { notifications } from "@mantine/notifications";
import { type JSONSerializable } from "~/server/types"; import { type JSONSerializable } from "~/server/types";
import VariantConfigEditor from "./VariantConfigEditor"; import VariantConfigEditor from "./VariantConfigEditor";
import EditableVariantLabel from "./EditableVariantLabel";
export default function VariantHeader({ variant }: { variant: PromptVariant }) { export default function VariantHeader({ variant }: { variant: PromptVariant }) {
const replaceWithConfig = api.promptVariants.replaceWithConfig.useMutation(); const replaceWithConfig = api.promptVariants.replaceWithConfig.useMutation();
@@ -39,15 +40,14 @@ export default function VariantHeader({ variant }: { variant: PromptVariant }) {
}); });
await utils.promptVariants.list.invalidate(); await utils.promptVariants.list.invalidate();
// TODO: invalidate the variants query
}, },
[variant.id, replaceWithConfig, utils.promptVariants.list] [variant.id, replaceWithConfig, utils.promptVariants.list]
); );
return ( return (
<Stack w="100%"> // title="" to hide the title text that mantine-react-table likes to add
<Title order={4}>{variant.label}</Title> <Stack w="100%" title="">
<EditableVariantLabel variant={variant} />
<VariantConfigEditor savedConfig={JSON.stringify(variant.config, null, 2)} onSave={onSave} /> <VariantConfigEditor savedConfig={JSON.stringify(variant.config, null, 2)} onSave={onSave} />
</Stack> </Stack>
); );

View File

@@ -4,6 +4,7 @@ import { RouterOutputs, api } from "~/utils/api";
import { PromptVariant } from "./types"; import { PromptVariant } from "./types";
import VariantHeader from "./VariantHeader"; import VariantHeader from "./VariantHeader";
import OutputCell from "./OutputCell"; import OutputCell from "./OutputCell";
import ScenarioHeader from "./ScenarioHeader";
type CellData = { type CellData = {
variant: PromptVariant; variant: PromptVariant;
@@ -15,11 +16,6 @@ type TableRow = {
} & Record<string, CellData>; } & Record<string, CellData>;
export default function OutputsTable({ experimentId }: { experimentId: string | undefined }) { export default function OutputsTable({ experimentId }: { experimentId: string | undefined }) {
const experiment = api.experiments.get.useQuery(
{ id: experimentId as string },
{ enabled: !!experimentId }
);
const variants = api.promptVariants.list.useQuery( const variants = api.promptVariants.list.useQuery(
{ experimentId: experimentId as string }, { experimentId: experimentId as string },
{ enabled: !!experimentId } { enabled: !!experimentId }
@@ -37,9 +33,7 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
header: "Scenario", header: "Scenario",
enableColumnDragging: false, enableColumnDragging: false,
size: 200, size: 200,
Cell: ({ row }) => { Cell: ({ row }) => <ScenarioHeader scenario={row.original.scenario} />,
return <div>{JSON.stringify(row.original.scenario.variableValues)}</div>;
},
}, },
...(variants.data?.map( ...(variants.data?.map(
(variant): MRT_ColumnDef<TableRow> => ({ (variant): MRT_ColumnDef<TableRow> => ({
@@ -91,7 +85,6 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
"& .mantine-TableHeadCell-Content": { "& .mantine-TableHeadCell-Content": {
width: "100%", width: "100%",
height: "100%", height: "100%",
// display: "flex",
"& .mantine-TableHeadCell-Content-Actions": { "& .mantine-TableHeadCell-Content-Actions": {
alignSelf: "flex-start", alignSelf: "flex-start",

View File

@@ -8,6 +8,17 @@ export const experimentsRouter = createTRPCRouter({
where: { where: {
id: input.id, id: input.id,
}, },
include: {
TemplateVariable: {
orderBy: {
createdAt: "asc",
},
select: {
id: true,
label: true,
},
},
},
}); });
}), }),
}); });

View File

@@ -1,7 +1,7 @@
import { z } from "zod"; import { z } from "zod";
import { createTRPCRouter, publicProcedure, protectedProcedure } from "~/server/api/trpc"; import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db"; import { prisma } from "~/server/db";
import { JSONSerializable, OpenAIChatConfig } from "~/server/types"; import { OpenAIChatConfig } from "~/server/types";
export const promptVariantsRouter = createTRPCRouter({ export const promptVariantsRouter = createTRPCRouter({
list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => { list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
@@ -16,6 +16,34 @@ export const promptVariantsRouter = createTRPCRouter({
}); });
}), }),
update: publicProcedure
.input(
z.object({
id: z.string(),
updates: z.object({
label: z.string().optional(),
}),
})
)
.mutation(async ({ input }) => {
const existing = await prisma.promptVariant.findUnique({
where: {
id: input.id,
},
});
if (!existing) {
throw new Error(`Prompt Variant with id ${input.id} does not exist`);
}
return await prisma.promptVariant.update({
where: {
id: input.id,
},
data: input.updates,
});
}),
replaceWithConfig: publicProcedure replaceWithConfig: publicProcedure
.input( .input(
z.object({ z.object({
@@ -41,14 +69,6 @@ export const promptVariantsRouter = createTRPCRouter({
throw new Error(`Prompt Variant with id ${input.id} does not exist`); throw new Error(`Prompt Variant with id ${input.id} does not exist`);
} }
console.log("new config", {
experimentId: existing.experimentId,
label: existing.label,
sortIndex: existing.sortIndex,
uiId: existing.uiId,
config: parsedConfig,
});
// Create a duplicate with only the config changed // Create a duplicate with only the config changed
const newVariant = await prisma.promptVariant.create({ const newVariant = await prisma.promptVariant.create({
data: { data: {

View File

@@ -7,7 +7,51 @@ export const scenariosRouter = createTRPCRouter({
return await prisma.testScenario.findMany({ return await prisma.testScenario.findMany({
where: { where: {
experimentId: input.experimentId, experimentId: input.experimentId,
visible: true,
}, },
}); });
}), }),
replaceWithValues: publicProcedure
.input(
z.object({
id: z.string(),
values: z.record(z.string()),
})
)
.mutation(async ({ input }) => {
const existing = await prisma.testScenario.findUnique({
where: {
id: input.id,
},
});
if (!existing) {
throw new Error(`Scenario with id ${input.id} does not exist`);
}
const newScenario = await prisma.testScenario.create({
data: {
experimentId: existing.experimentId,
sortIndex: existing.sortIndex,
variableValues: input.values,
uiId: existing.uiId,
},
});
// Hide the old scenario
await prisma.testScenario.updateMany({
where: {
uiId: existing.uiId,
id: {
not: newScenario.id,
},
},
data: {
visible: false,
},
});
return newScenario;
}),
}); });