This commit is contained in:
Kyle Corbitt
2023-07-06 15:20:45 -07:00
parent 0844fb0da7
commit 1fa0d7bc62
5 changed files with 34 additions and 28 deletions

View File

@@ -65,14 +65,14 @@ export function EvaluationEditor(props: {
</HStack> </HStack>
<FormControl> <FormControl>
<FormLabel fontSize="sm">Match String</FormLabel> <FormLabel fontSize="sm">Match String</FormLabel>
<FormHelperText>
This string will be interpreted as a regex and checked against each model output.
</FormHelperText>
<Input <Input
size="sm" size="sm"
value={values.matchString} value={values.matchString}
onChange={(e) => setValues((values) => ({ ...values, matchString: e.target.value }))} onChange={(e) => setValues((values) => ({ ...values, matchString: e.target.value }))}
/> />
<FormHelperText>
This string will be interpreted as a regex and checked against each model output.
</FormHelperText>
</FormControl> </FormControl>
<HStack alignSelf="flex-end"> <HStack alignSelf="flex-end">
<Button size="sm" onClick={props.onCancel} colorScheme="gray"> <Button size="sm" onClick={props.onCancel} colorScheme="gray">

View File

@@ -1,11 +1,11 @@
import { api } from "~/utils/api"; import { type RouterOutputs, api } from "~/utils/api";
import { type PromptVariant, type Scenario } from "./types"; import { type PromptVariant, type Scenario } from "./types";
import { Spinner, Text, Box, Center, Flex, Icon, HStack } from "@chakra-ui/react"; import { Spinner, Text, Box, Center, Flex, Icon, HStack } from "@chakra-ui/react";
import { useExperiment } from "~/utils/hooks"; import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
import SyntaxHighlighter from "react-syntax-highlighter"; import SyntaxHighlighter from "react-syntax-highlighter";
import { docco } from "react-syntax-highlighter/dist/cjs/styles/hljs"; import { docco } from "react-syntax-highlighter/dist/cjs/styles/hljs";
import stringify from "json-stringify-pretty-compact"; import stringify from "json-stringify-pretty-compact";
import { useMemo, type ReactElement } from "react"; import { useMemo, type ReactElement, useState, useEffect } from "react";
import { BsCheck, BsClock, BsX, BsCurrencyDollar } from "react-icons/bs"; import { BsCheck, BsClock, BsX, BsCurrencyDollar } from "react-icons/bs";
import { type ModelOutput } from "@prisma/client"; import { type ModelOutput } from "@prisma/client";
import { type ChatCompletion } from "openai/resources/chat"; import { type ChatCompletion } from "openai/resources/chat";
@@ -24,6 +24,7 @@ export default function OutputCell({
scenario: Scenario; scenario: Scenario;
variant: PromptVariant; variant: PromptVariant;
}): ReactElement | null { }): ReactElement | null {
const utils = api.useContext();
const experiment = useExperiment(); const experiment = useExperiment();
const vars = api.templateVars.list.useQuery({ const vars = api.templateVars.list.useQuery({
experimentId: experiment.data?.id ?? "", experimentId: experiment.data?.id ?? "",
@@ -53,41 +54,47 @@ export default function OutputCell({
return generateChannel(); return generateChannel();
}, [shouldStream]); }, [shouldStream]);
const output = api.outputs.get.useQuery( const outputMutation = api.outputs.get.useMutation();
{
const [output, setOutput] = useState<RouterOutputs["outputs"]["get"]>(null);
const [fetchOutput, fetchingOutput] = useHandledAsyncCallback(async () => {
const output = await outputMutation.mutateAsync({
scenarioId: scenario.id, scenarioId: scenario.id,
variantId: variant.id, variantId: variant.id,
channel, channel,
}, });
{ enabled: disabledReason === null } setOutput(output);
); await utils.evaluations.results.invalidate();
}, [outputMutation, scenario.id, variant.id, channel]);
useEffect(fetchOutput, []);
// Disconnect from socket if we're not streaming anymore // Disconnect from socket if we're not streaming anymore
const streamedMessage = useSocket(output.isLoading ? channel : undefined); const streamedMessage = useSocket(fetchingOutput ? channel : undefined);
const streamedContent = streamedMessage?.choices?.[0]?.message?.content; const streamedContent = streamedMessage?.choices?.[0]?.message?.content;
if (!vars) return null; if (!vars) return null;
if (disabledReason) return <Text color="gray.500">{disabledReason}</Text>; if (disabledReason) return <Text color="gray.500">{disabledReason}</Text>;
if (output.isLoading && !streamedMessage) if (fetchingOutput && !streamedMessage)
return ( return (
<Center h="100%" w="100%"> <Center h="100%" w="100%">
<Spinner /> <Spinner />
</Center> </Center>
); );
if (!output.data && !output.isLoading) if (!output && !fetchingOutput) return <Text color="gray.500">Error retrieving output</Text>;
return <Text color="gray.500">Error retrieving output</Text>;
if (output.data && output.data.errorMessage) { if (output && output.errorMessage) {
return <Text color="red.600">Error: {output.data.errorMessage}</Text>; return <Text color="red.600">Error: {output.errorMessage}</Text>;
} }
const response = output.data?.output as unknown as ChatCompletion; const response = output?.output as unknown as ChatCompletion;
const message = response?.choices?.[0]?.message; const message = response?.choices?.[0]?.message;
if (output.data && message?.function_call) { if (output && message?.function_call) {
const rawArgs = message.function_call.arguments ?? "null"; const rawArgs = message.function_call.arguments ?? "null";
let parsedArgs: string; let parsedArgs: string;
try { try {
@@ -115,18 +122,17 @@ export default function OutputCell({
{ maxLength: 40 } { maxLength: 40 }
)} )}
</SyntaxHighlighter> </SyntaxHighlighter>
<OutputStats model={model} modelOutput={output.data} scenario={scenario} /> <OutputStats model={model} modelOutput={output} scenario={scenario} />
</Box> </Box>
); );
} }
const contentToDisplay = const contentToDisplay = message?.content ?? streamedContent ?? JSON.stringify(output?.output);
message?.content ?? streamedContent ?? JSON.stringify(output.data?.output);
return ( return (
<Flex w="100%" h="100%" direction="column" justifyContent="space-between" whiteSpace="pre-wrap"> <Flex w="100%" h="100%" direction="column" justifyContent="space-between" whiteSpace="pre-wrap">
{contentToDisplay} {contentToDisplay}
{output.data && <OutputStats model={model} modelOutput={output.data} scenario={scenario} />} {output && <OutputStats model={model} modelOutput={output} scenario={scenario} />}
</Flex> </Flex>
); );
} }

View File

@@ -26,9 +26,9 @@ export default function ScenarioEditor({
const [values, setValues] = useState<Record<string, string>>(savedValues); const [values, setValues] = useState<Record<string, string>>(savedValues);
const experiment = useExperiment(); const experiment = useExperiment();
const vars = api.templateVars.list.useQuery({ experimentId: experiment.data?.id ?? "" }).data; const vars = api.templateVars.list.useQuery({ experimentId: experiment.data?.id ?? "" });
const variableLabels = vars?.map((v) => v.label) ?? []; const variableLabels = vars.data?.map((v) => v.label) ?? [];
const hasChanged = !isEqual(savedValues, values); const hasChanged = !isEqual(savedValues, values);
@@ -117,7 +117,7 @@ export default function ScenarioEditor({
/> />
</Stack> </Stack>
{variableLabels.length === 0 ? ( {variableLabels.length === 0 ? (
<Box color="gray.500">No scenario variables configured</Box> <Box color="gray.500">{vars.data ? "No scenario variables configured" : "Loading..."}</Box>
) : ( ) : (
<Stack> <Stack>
{variableLabels.map((key) => { {variableLabels.map((key) => {

View File

@@ -13,7 +13,7 @@ export const modelOutputsRouter = createTRPCRouter({
.input( .input(
z.object({ scenarioId: z.string(), variantId: z.string(), channel: z.string().optional() }) z.object({ scenarioId: z.string(), variantId: z.string(), channel: z.string().optional() })
) )
.query(async ({ input }) => { .mutation(async ({ input }) => {
const existing = await prisma.modelOutput.findUnique({ const existing = await prisma.modelOutput.findUnique({
where: { where: {
promptVariantId_testScenarioId: { promptVariantId_testScenarioId: {

View File

@@ -8,7 +8,7 @@ export const evaluateOutput = (
evaluation: Evaluation evaluation: Evaluation
): boolean => { ): boolean => {
const output = modelOutput.output as unknown as ChatCompletion; const output = modelOutput.output as unknown as ChatCompletion;
const message = output.choices?.[0]?.message; const message = output?.choices?.[0]?.message;
if (!message) return false; if (!message) return false;