autogen scenarios
This commit is contained in:
@@ -1,12 +1,27 @@
|
||||
import { Button } from "@chakra-ui/react";
|
||||
import { Button, type ButtonProps, Fade, HStack } from "@chakra-ui/react";
|
||||
import { useState } from "react";
|
||||
import { BsPlus } from "react-icons/bs";
|
||||
import { api } from "~/utils/api";
|
||||
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
|
||||
|
||||
// Extracted Button styling into reusable component
|
||||
const StyledButton = ({ children, onClick }: ButtonProps) => (
|
||||
<Button
|
||||
fontWeight="normal"
|
||||
bgColor="transparent"
|
||||
_hover={{ bgColor: "gray.100" }}
|
||||
px={2}
|
||||
onClick={onClick}
|
||||
>
|
||||
{children}
|
||||
</Button>
|
||||
);
|
||||
|
||||
export default function NewScenarioButton() {
|
||||
const experiment = useExperiment();
|
||||
const mutation = api.scenarios.create.useMutation();
|
||||
const utils = api.useContext();
|
||||
const [hovering, setHovering] = useState(false);
|
||||
|
||||
const [onClick] = useHandledAsyncCallback(async () => {
|
||||
if (!experiment.data) return;
|
||||
@@ -16,19 +31,31 @@ export default function NewScenarioButton() {
|
||||
await utils.scenarios.list.invalidate();
|
||||
}, [mutation]);
|
||||
|
||||
const [onAutogenerate] = useHandledAsyncCallback(async () => {
|
||||
if (!experiment.data) return;
|
||||
await mutation.mutateAsync({
|
||||
experimentId: experiment.data.id,
|
||||
autogenerate: true,
|
||||
});
|
||||
await utils.scenarios.list.invalidate();
|
||||
}, [mutation]);
|
||||
|
||||
return (
|
||||
<Button
|
||||
w="100%"
|
||||
alignItems="center"
|
||||
justifyContent="flex-start"
|
||||
fontWeight="normal"
|
||||
bgColor="transparent"
|
||||
_hover={{ bgColor: "gray.100" }}
|
||||
px={2}
|
||||
onClick={onClick}
|
||||
<HStack
|
||||
spacing={2}
|
||||
onMouseEnter={() => setHovering(true)}
|
||||
onMouseLeave={() => setHovering(false)}
|
||||
>
|
||||
<BsPlus size={24} />
|
||||
Add Scenario
|
||||
</Button>
|
||||
<StyledButton onClick={onClick}>
|
||||
<BsPlus size={24} />
|
||||
Add Scenario
|
||||
</StyledButton>
|
||||
<Fade in={hovering}>
|
||||
<StyledButton onClick={onAutogenerate}>
|
||||
<BsPlus size={24} />
|
||||
Autogenerate Scenario
|
||||
</StyledButton>
|
||||
</Fade>
|
||||
</HStack>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
import { api } from "~/utils/api";
|
||||
import { type PromptVariant, type Scenario } from "./types";
|
||||
import { Center, Spinner, Text } from "@chakra-ui/react";
|
||||
import { Spinner, Text, Box } from "@chakra-ui/react";
|
||||
import { useExperiment } from "~/utils/hooks";
|
||||
import { cellPadding } from "../constants";
|
||||
|
||||
const CellShell = ({ children }: { children: React.ReactNode }) => (
|
||||
<Center h="100%" w="100%" px={cellPadding.x} py={cellPadding.y}>
|
||||
{children}
|
||||
</Center>
|
||||
);
|
||||
import { type CreateChatCompletionResponse } from "openai";
|
||||
import SyntaxHighlighter from "react-syntax-highlighter";
|
||||
import { docco } from "react-syntax-highlighter/dist/cjs/styles/hljs";
|
||||
import stringify from "json-stringify-pretty-compact";
|
||||
import { type ReactElement } from "react";
|
||||
|
||||
export default function OutputCell({
|
||||
scenario,
|
||||
@@ -16,7 +14,7 @@ export default function OutputCell({
|
||||
}: {
|
||||
scenario: Scenario;
|
||||
variant: PromptVariant;
|
||||
}) {
|
||||
}): ReactElement | null {
|
||||
const experiment = useExperiment();
|
||||
const vars = api.templateVars.list.useQuery({ experimentId: experiment.data?.id ?? "" }).data;
|
||||
|
||||
@@ -41,37 +39,34 @@ export default function OutputCell({
|
||||
|
||||
if (!vars) return null;
|
||||
|
||||
if (disabledReason)
|
||||
return (
|
||||
<CellShell>
|
||||
<Text color="gray.500">{disabledReason}</Text>
|
||||
</CellShell>
|
||||
);
|
||||
if (disabledReason) return <Text color="gray.500">{disabledReason}</Text>;
|
||||
|
||||
if (output.isLoading)
|
||||
return (
|
||||
<CellShell>
|
||||
<Spinner />
|
||||
</CellShell>
|
||||
);
|
||||
if (output.isLoading) return <Spinner />;
|
||||
|
||||
if (!output.data)
|
||||
return (
|
||||
<CellShell>
|
||||
<Text color="gray.500">Error retrieving output</Text>
|
||||
</CellShell>
|
||||
);
|
||||
if (!output.data) return <Text color="gray.500">Error retrieving output</Text>;
|
||||
|
||||
if (output.data.errorMessage) {
|
||||
return <Text color="red.600">Error: {output.data.errorMessage}</Text>;
|
||||
}
|
||||
|
||||
const response = output.data?.output as unknown as CreateChatCompletionResponse;
|
||||
const message = response?.choices?.[0]?.message;
|
||||
|
||||
if (message?.function_call) {
|
||||
return (
|
||||
<CellShell>
|
||||
<Text color="red.600">Error: {output.data.errorMessage}</Text>
|
||||
</CellShell>
|
||||
<Box fontSize="xs">
|
||||
<SyntaxHighlighter language="json" style={docco}>
|
||||
{stringify(
|
||||
{
|
||||
function: message.function_call.name,
|
||||
args: JSON.parse(message.function_call.arguments ?? "null"),
|
||||
},
|
||||
{ maxLength: 40 }
|
||||
)}
|
||||
</SyntaxHighlighter>
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
// @ts-expect-error TODO proper typing and error checks
|
||||
<CellShell>{output.data.output.choices[0].message.content}</CellShell>
|
||||
);
|
||||
return <Box>{message?.content ?? JSON.stringify(output.data.output)}</Box>;
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { Grid, GridItem, type SystemStyleObject } from "@chakra-ui/react";
|
||||
import { Center, Grid, GridItem, type SystemStyleObject } from "@chakra-ui/react";
|
||||
import React, { useState } from "react";
|
||||
import { api } from "~/utils/api";
|
||||
import NewScenarioButton from "./NewScenarioButton";
|
||||
@@ -9,6 +9,7 @@ import VariantConfigEditor from "./VariantConfigEditor";
|
||||
import VariantHeader from "./VariantHeader";
|
||||
import type { Scenario, PromptVariant } from "./types";
|
||||
import ScenarioHeader from "~/server/ScenarioHeader";
|
||||
import { cellPadding } from "../constants";
|
||||
|
||||
const stickyHeaderStyle: SystemStyleObject = {
|
||||
position: "sticky",
|
||||
@@ -41,7 +42,9 @@ const ScenarioRow = (props: { scenario: Scenario; variants: PromptVariant[] }) =
|
||||
onMouseLeave={() => setIsHovered(false)}
|
||||
sx={isHovered ? highlightStyle : undefined}
|
||||
>
|
||||
<OutputCell key={variant.id} scenario={props.scenario} variant={variant} />
|
||||
<Center h="100%" w="100%" px={cellPadding.x} py={cellPadding.y}>
|
||||
<OutputCell key={variant.id} scenario={props.scenario} variant={variant} />
|
||||
</Center>
|
||||
</GridItem>
|
||||
))}
|
||||
</React.Fragment>
|
||||
|
||||
157
src/server/api/autogen.ts
Normal file
157
src/server/api/autogen.ts
Normal file
@@ -0,0 +1,157 @@
|
||||
import { type CreateChatCompletionRequest } from "openai";
|
||||
import { prisma } from "../db";
|
||||
import { openai } from "../utils/openai";
|
||||
import { pick } from "lodash";
|
||||
|
||||
function promptHasVariable(prompt: string, variableName: string) {
|
||||
return prompt.includes(`{{${variableName}}}`);
|
||||
}
|
||||
|
||||
type AxiosError = {
|
||||
response?: {
|
||||
data?: {
|
||||
error?: {
|
||||
message?: string;
|
||||
};
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
function isAxiosError(error: unknown): error is AxiosError {
|
||||
if (typeof error === "object" && error !== null) {
|
||||
// Initial check
|
||||
const err = error as AxiosError;
|
||||
return err.response?.data?.error?.message !== undefined; // Check structure
|
||||
}
|
||||
return false;
|
||||
}
|
||||
export const autogenerateScenarioValues = async (
|
||||
experimentId: string
|
||||
): Promise<Record<string, string>> => {
|
||||
const [experiment, variables, existingScenarios, prompt] = await Promise.all([
|
||||
prisma.experiment.findUnique({
|
||||
where: {
|
||||
id: experimentId,
|
||||
},
|
||||
}),
|
||||
prisma.templateVariable.findMany({
|
||||
where: {
|
||||
experimentId,
|
||||
},
|
||||
}),
|
||||
prisma.testScenario.findMany({
|
||||
where: {
|
||||
experimentId,
|
||||
visible: true,
|
||||
},
|
||||
orderBy: {
|
||||
sortIndex: "asc",
|
||||
},
|
||||
take: 10,
|
||||
}),
|
||||
prisma.promptVariant.findFirst({
|
||||
where: {
|
||||
experimentId,
|
||||
visible: true,
|
||||
},
|
||||
orderBy: {
|
||||
sortIndex: "asc",
|
||||
},
|
||||
}),
|
||||
]);
|
||||
|
||||
if (!experiment || !(variables?.length > 0) || !prompt) return {};
|
||||
|
||||
const messages: CreateChatCompletionRequest["messages"] = [
|
||||
{
|
||||
role: "system",
|
||||
content:
|
||||
"The user is testing multiple scenarios against the same prompt. Attempt to generate a new scenario that is different from the others.",
|
||||
},
|
||||
];
|
||||
|
||||
const promptText = JSON.stringify(prompt.config);
|
||||
if (variables.some((variable) => promptHasVariable(promptText, variable.label))) {
|
||||
messages.push({
|
||||
role: "user",
|
||||
content: `Prompt template:\n---\n${promptText}`,
|
||||
});
|
||||
}
|
||||
|
||||
existingScenarios
|
||||
.map(
|
||||
(scenario) =>
|
||||
pick(
|
||||
scenario.variableValues,
|
||||
variables.map((variable) => variable.label)
|
||||
) as Record<string, string>
|
||||
)
|
||||
.filter((vals) => Object.keys(vals ?? {}).length > 0)
|
||||
.forEach((vals) => {
|
||||
messages.push({
|
||||
role: "assistant",
|
||||
// @ts-expect-error the openai type definition is wrong, the content field is required
|
||||
content: null,
|
||||
function_call: {
|
||||
name: "add_scenario",
|
||||
arguments: JSON.stringify(vals),
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
const variableProperties = variables.reduce((acc, variable) => {
|
||||
acc[variable.label] = { type: "string" };
|
||||
return acc;
|
||||
}, {} as Record<string, { type: "string" }>);
|
||||
|
||||
console.log({
|
||||
model: "gpt-3.5-turbo-0613",
|
||||
messages,
|
||||
functions: [
|
||||
{
|
||||
name: "add_scenario",
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: variableProperties,
|
||||
},
|
||||
},
|
||||
],
|
||||
|
||||
function_call: { name: "add_scenario" },
|
||||
temperature: 0.5,
|
||||
});
|
||||
|
||||
try {
|
||||
const completion = await openai.createChatCompletion({
|
||||
model: "gpt-3.5-turbo-0613",
|
||||
messages,
|
||||
functions: [
|
||||
{
|
||||
name: "add_scenario",
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: variableProperties,
|
||||
},
|
||||
},
|
||||
],
|
||||
|
||||
function_call: { name: "add_scenario" },
|
||||
temperature: 0.5,
|
||||
});
|
||||
|
||||
const parsed = JSON.parse(
|
||||
completion.data.choices[0]?.message?.function_call?.arguments ?? "{}"
|
||||
) as Record<string, string>;
|
||||
return parsed;
|
||||
} catch (e) {
|
||||
// If it's an axios error, try to get the error message
|
||||
if (isAxiosError(e)) {
|
||||
console.error(e?.response?.data?.error?.message);
|
||||
} else {
|
||||
console.error(e);
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
return {};
|
||||
};
|
||||
@@ -1,6 +1,7 @@
|
||||
import { z } from "zod";
|
||||
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
|
||||
import { prisma } from "~/server/db";
|
||||
import { autogenerateScenarioValues } from "../autogen";
|
||||
|
||||
export const scenariosRouter = createTRPCRouter({
|
||||
list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
|
||||
@@ -19,6 +20,7 @@ export const scenariosRouter = createTRPCRouter({
|
||||
.input(
|
||||
z.object({
|
||||
experimentId: z.string(),
|
||||
autogenerate: z.boolean().optional(),
|
||||
})
|
||||
)
|
||||
.mutation(async ({ input }) => {
|
||||
@@ -38,7 +40,9 @@ export const scenariosRouter = createTRPCRouter({
|
||||
data: {
|
||||
experimentId: input.experimentId,
|
||||
sortIndex: maxSortIndex + 1,
|
||||
variableValues: {},
|
||||
variableValues: input.autogenerate
|
||||
? await autogenerateScenarioValues(input.experimentId)
|
||||
: {},
|
||||
},
|
||||
});
|
||||
}),
|
||||
|
||||
8
src/server/utils/openai.ts
Normal file
8
src/server/utils/openai.ts
Normal file
@@ -0,0 +1,8 @@
|
||||
import { Configuration, OpenAIApi } from "openai";
|
||||
import { env } from "~/env.mjs";
|
||||
|
||||
const configuration = new Configuration({
|
||||
apiKey: env.OPENAI_API_KEY,
|
||||
});
|
||||
|
||||
export const openai = new OpenAIApi(configuration);
|
||||
Reference in New Issue
Block a user