This commit is contained in:
Kyle Corbitt
2023-07-06 15:20:45 -07:00
parent 0844fb0da7
commit 1fa0d7bc62
5 changed files with 34 additions and 28 deletions

View File

@@ -65,14 +65,14 @@ export function EvaluationEditor(props: {
</HStack>
<FormControl>
<FormLabel fontSize="sm">Match String</FormLabel>
<FormHelperText>
This string will be interpreted as a regex and checked against each model output.
</FormHelperText>
<Input
size="sm"
value={values.matchString}
onChange={(e) => setValues((values) => ({ ...values, matchString: e.target.value }))}
/>
<FormHelperText>
This string will be interpreted as a regex and checked against each model output.
</FormHelperText>
</FormControl>
<HStack alignSelf="flex-end">
<Button size="sm" onClick={props.onCancel} colorScheme="gray">

View File

@@ -1,11 +1,11 @@
import { api } from "~/utils/api";
import { type RouterOutputs, api } from "~/utils/api";
import { type PromptVariant, type Scenario } from "./types";
import { Spinner, Text, Box, Center, Flex, Icon, HStack } from "@chakra-ui/react";
import { useExperiment } from "~/utils/hooks";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
import SyntaxHighlighter from "react-syntax-highlighter";
import { docco } from "react-syntax-highlighter/dist/cjs/styles/hljs";
import stringify from "json-stringify-pretty-compact";
import { useMemo, type ReactElement } from "react";
import { useMemo, type ReactElement, useState, useEffect } from "react";
import { BsCheck, BsClock, BsX, BsCurrencyDollar } from "react-icons/bs";
import { type ModelOutput } from "@prisma/client";
import { type ChatCompletion } from "openai/resources/chat";
@@ -24,6 +24,7 @@ export default function OutputCell({
scenario: Scenario;
variant: PromptVariant;
}): ReactElement | null {
const utils = api.useContext();
const experiment = useExperiment();
const vars = api.templateVars.list.useQuery({
experimentId: experiment.data?.id ?? "",
@@ -53,41 +54,47 @@ export default function OutputCell({
return generateChannel();
}, [shouldStream]);
const output = api.outputs.get.useQuery(
{
const outputMutation = api.outputs.get.useMutation();
const [output, setOutput] = useState<RouterOutputs["outputs"]["get"]>(null);
const [fetchOutput, fetchingOutput] = useHandledAsyncCallback(async () => {
const output = await outputMutation.mutateAsync({
scenarioId: scenario.id,
variantId: variant.id,
channel,
},
{ enabled: disabledReason === null }
);
});
setOutput(output);
await utils.evaluations.results.invalidate();
}, [outputMutation, scenario.id, variant.id, channel]);
useEffect(fetchOutput, []);
// Disconnect from socket if we're not streaming anymore
const streamedMessage = useSocket(output.isLoading ? channel : undefined);
const streamedMessage = useSocket(fetchingOutput ? channel : undefined);
const streamedContent = streamedMessage?.choices?.[0]?.message?.content;
if (!vars) return null;
if (disabledReason) return <Text color="gray.500">{disabledReason}</Text>;
if (output.isLoading && !streamedMessage)
if (fetchingOutput && !streamedMessage)
return (
<Center h="100%" w="100%">
<Spinner />
</Center>
);
if (!output.data && !output.isLoading)
return <Text color="gray.500">Error retrieving output</Text>;
if (!output && !fetchingOutput) return <Text color="gray.500">Error retrieving output</Text>;
if (output.data && output.data.errorMessage) {
return <Text color="red.600">Error: {output.data.errorMessage}</Text>;
if (output && output.errorMessage) {
return <Text color="red.600">Error: {output.errorMessage}</Text>;
}
const response = output.data?.output as unknown as ChatCompletion;
const response = output?.output as unknown as ChatCompletion;
const message = response?.choices?.[0]?.message;
if (output.data && message?.function_call) {
if (output && message?.function_call) {
const rawArgs = message.function_call.arguments ?? "null";
let parsedArgs: string;
try {
@@ -115,18 +122,17 @@ export default function OutputCell({
{ maxLength: 40 }
)}
</SyntaxHighlighter>
<OutputStats model={model} modelOutput={output.data} scenario={scenario} />
<OutputStats model={model} modelOutput={output} scenario={scenario} />
</Box>
);
}
const contentToDisplay =
message?.content ?? streamedContent ?? JSON.stringify(output.data?.output);
const contentToDisplay = message?.content ?? streamedContent ?? JSON.stringify(output?.output);
return (
<Flex w="100%" h="100%" direction="column" justifyContent="space-between" whiteSpace="pre-wrap">
{contentToDisplay}
{output.data && <OutputStats model={model} modelOutput={output.data} scenario={scenario} />}
{output && <OutputStats model={model} modelOutput={output} scenario={scenario} />}
</Flex>
);
}

View File

@@ -26,9 +26,9 @@ export default function ScenarioEditor({
const [values, setValues] = useState<Record<string, string>>(savedValues);
const experiment = useExperiment();
const vars = api.templateVars.list.useQuery({ experimentId: experiment.data?.id ?? "" }).data;
const vars = api.templateVars.list.useQuery({ experimentId: experiment.data?.id ?? "" });
const variableLabels = vars?.map((v) => v.label) ?? [];
const variableLabels = vars.data?.map((v) => v.label) ?? [];
const hasChanged = !isEqual(savedValues, values);
@@ -117,7 +117,7 @@ export default function ScenarioEditor({
/>
</Stack>
{variableLabels.length === 0 ? (
<Box color="gray.500">No scenario variables configured</Box>
<Box color="gray.500">{vars.data ? "No scenario variables configured" : "Loading..."}</Box>
) : (
<Stack>
{variableLabels.map((key) => {

View File

@@ -13,7 +13,7 @@ export const modelOutputsRouter = createTRPCRouter({
.input(
z.object({ scenarioId: z.string(), variantId: z.string(), channel: z.string().optional() })
)
.query(async ({ input }) => {
.mutation(async ({ input }) => {
const existing = await prisma.modelOutput.findUnique({
where: {
promptVariantId_testScenarioId: {

View File

@@ -8,7 +8,7 @@ export const evaluateOutput = (
evaluation: Evaluation
): boolean => {
const output = modelOutput.output as unknown as ChatCompletion;
const message = output.choices?.[0]?.message;
const message = output?.choices?.[0]?.message;
if (!message) return false;