Cache eval outputs and add GPT-4 eval

Kyle Corbitt
2023-07-17 17:55:36 -07:00
parent 011b12abb9
commit 7d41e94ca2
8 changed files with 168 additions and 115 deletions

View File

@@ -0,0 +1,39 @@
/*
Warnings:
- You are about to drop the `EvaluationResult` table. If the table is not empty, all the data it contains will be lost.
*/
-- AlterEnum
ALTER TYPE "EvalType" ADD VALUE 'GPT4_EVAL';
-- DropForeignKey
ALTER TABLE "EvaluationResult" DROP CONSTRAINT "EvaluationResult_evaluationId_fkey";
-- DropForeignKey
ALTER TABLE "EvaluationResult" DROP CONSTRAINT "EvaluationResult_promptVariantId_fkey";
-- DropTable
DROP TABLE "EvaluationResult";
-- CreateTable
CREATE TABLE "OutputEvaluation" (
"id" UUID NOT NULL,
"result" DOUBLE PRECISION NOT NULL,
"details" TEXT,
"modelOutputId" UUID NOT NULL,
"evaluationId" UUID NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
CONSTRAINT "OutputEvaluation_pkey" PRIMARY KEY ("id")
);
-- CreateIndex
CREATE UNIQUE INDEX "OutputEvaluation_modelOutputId_evaluationId_key" ON "OutputEvaluation"("modelOutputId", "evaluationId");
-- AddForeignKey
ALTER TABLE "OutputEvaluation" ADD CONSTRAINT "OutputEvaluation_modelOutputId_fkey" FOREIGN KEY ("modelOutputId") REFERENCES "ModelOutput"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "OutputEvaluation" ADD CONSTRAINT "OutputEvaluation_evaluationId_fkey" FOREIGN KEY ("evaluationId") REFERENCES "Evaluation"("id") ON DELETE CASCADE ON UPDATE CASCADE;
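For reference, the composite unique index above is what the caching in this commit keys on: runEvalsForOutput (later in this diff) upserts one OutputEvaluation row per (modelOutputId, evaluationId) pair, so re-running evals overwrites the cached row instead of duplicating it. A minimal sketch of that upsert through the generated Prisma client; the helper name and the prisma import path are illustrative, not part of this commit:

import { prisma } from "~/server/db"; // assumed location of the shared Prisma client instance

// Upsert keyed on the composite unique index OutputEvaluation_modelOutputId_evaluationId_key,
// so a re-run for the same (modelOutput, evaluation) pair replaces the cached result.
async function cacheEvaluation(
  modelOutputId: string,
  evaluationId: string,
  result: number,
  details?: string,
) {
  return prisma.outputEvaluation.upsert({
    where: {
      modelOutputId_evaluationId: { modelOutputId, evaluationId },
    },
    create: { modelOutputId, evaluationId, result, details },
    update: { result, details },
  });
}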

View File

@@ -132,6 +132,7 @@ model ModelOutput {
enum EvalType {
CONTAINS
DOES_NOT_CONTAIN
GPT4_EVAL
}
model Evaluation {

View File

@@ -4,7 +4,7 @@ import React from "react";
export const AutoResizeTextarea: React.ForwardRefRenderFunction<
HTMLTextAreaElement,
TextareaProps
TextareaProps & { minRows?: number }
> = (props, ref) => {
return (
<Textarea

View File

@@ -11,12 +11,14 @@ import {
FormLabel,
Select,
FormHelperText,
Code,
} from "@chakra-ui/react";
import { type Evaluation, EvalType } from "@prisma/client";
import { useCallback, useState } from "react";
import { BsPencil, BsX } from "react-icons/bs";
import { api } from "~/utils/api";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
import AutoResizeTextArea from "../AutoResizeTextArea";
type EvalValues = Pick<Evaluation, "label" | "value" | "evalType">;
@@ -36,7 +38,7 @@ export function EvaluationEditor(props: {
<VStack borderTopWidth={1} borderColor="gray.200" py={4}>
<HStack w="100%">
<FormControl flex={1}>
<FormLabel fontSize="sm">Evaluation Name</FormLabel>
<FormLabel fontSize="sm">Eval Name</FormLabel>
<Input
size="sm"
value={values.label}
@@ -44,7 +46,7 @@ export function EvaluationEditor(props: {
/>
</FormControl>
<FormControl flex={1}>
<FormLabel fontSize="sm">Match Type</FormLabel>
<FormLabel fontSize="sm">Eval Type</FormLabel>
<Select
size="sm"
value={values.evalType}
@@ -63,17 +65,37 @@ export function EvaluationEditor(props: {
</Select>
</FormControl>
</HStack>
<FormControl>
<FormLabel fontSize="sm">Match String</FormLabel>
<Input
size="sm"
value={values.value}
onChange={(e) => setValues((values) => ({ ...values, value: e.target.value }))}
/>
<FormHelperText>
This string will be interpreted as a regex and checked against each model output.
</FormHelperText>
</FormControl>
{["CONTAINS", "DOES_NOT_CONTAIN"].includes(values.evalType) && (
<FormControl>
<FormLabel fontSize="sm">Match String</FormLabel>
<Input
size="sm"
value={values.value}
onChange={(e) => setValues((values) => ({ ...values, value: e.target.value }))}
/>
<FormHelperText>
This string will be interpreted as a regex and checked against each model output. You
can include scenario variables using <Code>{"{{curly_braces}}"}</Code>.
</FormHelperText>
</FormControl>
)}
{values.evalType === "GPT4_EVAL" && (
<FormControl pt={2}>
<FormLabel fontSize="sm">GPT-4 Instructions</FormLabel>
<AutoResizeTextArea
size="sm"
value={values.value}
onChange={(e) => setValues((values) => ({ ...values, value: e.target.value }))}
minRows={3}
/>
<FormHelperText>
Give instructions to GPT-4 for how to evaluate your prompt. It will have access to the
full scenario as well as the output it is evaluating. It will <strong>not</strong> have
access to the specific prompt variant, so be sure to be clear about the task you want it
to perform.
</FormHelperText>
</FormControl>
)}
<HStack alignSelf="flex-end">
<Button size="sm" onClick={props.onCancel} colorScheme="gray">
Cancel
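To make the two eval modes configured above concrete: for CONTAINS / DOES_NOT_CONTAIN, the match string has its {{curly_braces}} variables filled in from the scenario and is then used as a regex against the model output (see runOneEval later in this diff). A standalone sketch of that flow; the variable values, the match string, and the inline substitution standing in for fillTemplate are all illustrative:

// Hypothetical scenario variables and match string, for illustration only.
const variableValues: Record<string, string> = { city: "Paris" };
const matchString = "{{city}}"; // what a user might type into the Match String field

// Stand-in for fillTemplate: substitute {{name}} placeholders before matching.
const matchRegex = matchString.replace(
  /\{\{\s*(\w+)\s*\}\}/g,
  (_, name: string) => variableValues[name] ?? "",
);

const modelOutputText = "The capital of France is Paris.";

// CONTAINS passes (1) when the filled-in regex matches the output;
// DOES_NOT_CONTAIN passes (1) when it does not.
const containsResult = modelOutputText.match(matchRegex) !== null ? 1 : 0; // 1
const doesNotContainResult = modelOutputText.match(matchRegex) === null ? 1 : 0; // 0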

View File

@@ -1,4 +1,4 @@
import { Text, Button, HStack, Heading, Icon, Input, Stack, Code } from "@chakra-ui/react";
import { Text, Button, HStack, Heading, Icon, Input, Stack } from "@chakra-ui/react";
import { useState } from "react";
import { BsCheck, BsX } from "react-icons/bs";
import { api } from "~/utils/api";
@@ -36,8 +36,7 @@ export default function EditScenarioVars() {
<Heading size="sm">Scenario Variables</Heading>
<Stack spacing={2}>
<Text fontSize="sm">
Scenario variables can be used in your prompt variants as well as evaluations. Reference
them using <Code>{"{{curly_braces}}"}</Code>.
Scenario variables can be used in your prompt variants as well as evaluations.
</Text>
<HStack spacing={0}>
<Input

View File

@@ -2,7 +2,7 @@ import { type SupportedModel } from "~/server/types";
import { type Scenario } from "../types";
import { type RouterOutputs } from "~/utils/api";
import { calculateTokenCost } from "~/utils/calculateTokenCost";
import { HStack, Icon, Text } from "@chakra-ui/react";
import { HStack, Icon, Text, Tooltip } from "@chakra-ui/react";
import { BsCheck, BsClock, BsCurrencyDollar, BsX } from "react-icons/bs";
import { CostTooltip } from "~/components/tooltip/CostTooltip";
@@ -36,14 +36,20 @@ export const OutputStats = ({
{modelOutput.outputEvaluation.map((evaluation) => {
const passed = evaluation.result > 0.5;
return (
<HStack spacing={0} key={evaluation.id}>
<Text>{evaluation.evaluation.label}</Text>
<Icon
as={passed ? BsCheck : BsX}
color={passed ? "green.500" : "red.500"}
boxSize={6}
/>
</HStack>
<Tooltip
isDisabled={!evaluation.details}
label={evaluation.details}
key={evaluation.id}
>
<HStack spacing={0}>
<Text>{evaluation.evaluation.label}</Text>
<Icon
as={passed ? BsCheck : BsX}
color={passed ? "green.500" : "red.500"}
boxSize={6}
/>
</HStack>
</Tooltip>
);
})}
</HStack>

View File

@@ -4,7 +4,7 @@ import { runOneEval } from "./runOneEval";
import { type Scenario } from "~/components/OutputsTable/types";
const saveResult = async (evaluation: Evaluation, scenario: Scenario, modelOutput: ModelOutput) => {
const result = runOneEval(evaluation, scenario, modelOutput);
const result = await runOneEval(evaluation, scenario, modelOutput);
return await prisma.outputEvaluation.upsert({
where: {
modelOutputId_evaluationId: {
@@ -15,10 +15,10 @@ const saveResult = async (evaluation: Evaluation, scenario: Scenario, modelOutpu
create: {
modelOutputId: modelOutput.id,
evaluationId: evaluation.id,
result,
...result,
},
update: {
result,
...result,
},
});
};
@@ -35,43 +35,6 @@ export const runEvalsForOutput = async (
await Promise.all(
evaluations.map(async (evaluation) => await saveResult(evaluation, scenario, modelOutput)),
);
// const cells = await prisma.scenarioVariantCell.findMany({
// where: {
// promptVariantId: variantId,
// retrievalStatus: "COMPLETE",
// testScenario: { visible: true },
// },
// include: { testScenario: true, modelOutput: { include: { OutputEvaluation: true } } },
// });
// await Promise.all(
// evaluations.map(async (evaluation) => {
// const passCount = cells.filter((cell) =>
// runOneEval(cell.modelOutput as ModelOutput, cell.testScenario, evaluation),
// ).length;
// const failCount = cells.length - passCount;
// await prisma.evaluationResult.upsert({
// where: {
// evaluationId_promptVariantId: {
// evaluationId: evaluation.id,
// promptVariantId: variantId,
// },
// },
// create: {
// evaluationId: evaluation.id,
// promptVariantId: variantId,
// passCount,
// failCount,
// },
// update: {
// passCount,
// failCount,
// },
// });
// }),
// );
};
export const runAllEvals = async (experimentId: string) => {
@@ -113,42 +76,4 @@ export const runAllEvals = async (experimentId: string) => {
);
}),
);
// const cells = await prisma.scenarioVariantCell.findMany({
// where: {
// promptVariantId: { in: variants.map((v) => v.id) },
// testScenario: { visible: true },
// statusCode: { notIn: [429] },
// },
// include: { testScenario: true, modelOutput: true },
// });
// await Promise.all(
// variants.map(async (variant) => {
// const variantCells = cells.filter((cell) => cell.promptVariantId === variant.id);
// const passCount = variantCells.filter((cell) =>
// runOneEval(cell.modelOutput as ModelOutput, cell.testScenario, evaluation),
// ).length;
// const failCount = variantCells.length - passCount;
// await prisma.evaluationResult.upsert({
// where: {
// evaluationId_promptVariantId: {
// evaluationId: evaluation.id,
// promptVariantId: variant.id,
// },
// },
// create: {
// evaluationId: evaluation.id,
// promptVariantId: variant.id,
// passCount,
// failCount,
// },
// update: {
// passCount,
// failCount,
// },
// });
// }),
// );
};
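Since the per-variant pass/fail aggregates (the commented-out EvaluationResult logic deleted above) are gone, counts like those can be derived on demand from the cached OutputEvaluation rows instead. One possible shape for that query; the relation names (modelOutput, scenarioVariantCell), the import path, and the helper itself are assumptions, not part of this commit:

import { prisma } from "~/server/db"; // assumed path of the shared Prisma client

// Hypothetical aggregation over cached rows for one evaluation + prompt variant.
// The 0.5 threshold mirrors the pass check used in OutputStats above.
async function passFailCounts(evaluationId: string, promptVariantId: string) {
  const rows = await prisma.outputEvaluation.findMany({
    where: {
      evaluationId,
      modelOutput: { scenarioVariantCell: { promptVariantId } }, // relation names assumed
    },
    select: { result: true },
  });
  const passCount = rows.filter((r) => r.result > 0.5).length;
  return { passCount, failCount: rows.length - passCount };
}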

View File

@@ -1,32 +1,93 @@
import { type Evaluation, type ModelOutput, type TestScenario } from "@prisma/client";
import { type ChatCompletion } from "openai/resources/chat";
import { type VariableMap, fillTemplate } from "./fillTemplate";
import { openai } from "./openai";
import dedent from "dedent";
export const runOneEval = (
export const runGpt4Eval = async (
evaluation: Evaluation,
scenario: TestScenario,
message: ChatCompletion.Choice.Message,
): Promise<{ result: number; details: string }> => {
const output = await openai.chat.completions.create({
model: "gpt-4-0613",
messages: [
{
role: "system",
content: dedent`
You are a highly intelligent AI model and have been tasked with evaluating the quality of a simpler model. Your objective is to determine whether the simpler model has produced a successful and correct output. You should return "true" if the output was successful and "false" if it was not. Pay more attention to the semantics of the output than the formatting. Success is defined in the following terms:
---
${evaluation.value}
`,
},
{
role: "user",
content: `Scenario:\n---\n${JSON.stringify(scenario.variableValues, null, 2)}`,
},
{
role: "user",
content: `The full output of the simpler model:\n---\n${JSON.stringify(
message.content ?? message.function_call,
null,
2,
)}`,
},
],
function_call: {
name: "report_success",
},
functions: [
{
name: "report_success",
parameters: {
type: "object",
required: ["thoughts", "success"],
properties: {
thoughts: {
type: "string",
description: "Explain your reasoning for considering this a pass or fail",
},
success: {
type: "boolean",
description:
"Whether the simpler model successfully completed the task for this scenario",
},
},
},
},
],
});
try {
const out = JSON.parse(output.choices[0]?.message?.function_call?.arguments ?? "");
return { result: out.success ? 1 : 0, details: out.thoughts ?? JSON.stringify(out) };
} catch (e) {
console.error(e);
return { result: 0, details: "Error parsing GPT-4 output" };
}
};
export const runOneEval = async (
evaluation: Evaluation,
scenario: TestScenario,
modelOutput: ModelOutput,
): number => {
): Promise<{ result: number; details?: string }> => {
const output = modelOutput.output as unknown as ChatCompletion;
const message = output?.choices?.[0]?.message;
if (!message) return 0;
if (!message) return { result: 0 };
const stringifiedMessage = message.content ?? JSON.stringify(message.function_call);
const matchRegex = fillTemplate(evaluation.value, scenario.variableValues as VariableMap);
let result;
switch (evaluation.evalType) {
case "CONTAINS":
result = stringifiedMessage.match(matchRegex) !== null ? 1 : 0;
break;
return { result: stringifiedMessage.match(matchRegex) !== null ? 1 : 0 };
case "DOES_NOT_CONTAIN":
result = stringifiedMessage.match(matchRegex) === null ? 1 : 0;
break;
return { result: stringifiedMessage.match(matchRegex) === null ? 1 : 0 };
case "GPT4_EVAL":
return await runGpt4Eval(evaluation, scenario, message);
}
return result;
};
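For reference, the arguments GPT-4 sends back through the report_success function call, and how runGpt4Eval folds them into the cached row, can be sketched as follows; the type name and example values are illustrative:

// Illustrative type for the report_success arguments, matching the JSON schema above.
type ReportSuccessArgs = {
  thoughts: string;
  success: boolean;
};

// Mirrors the mapping in runGpt4Eval: success becomes the numeric result,
// thoughts becomes the details shown in the OutputStats tooltip.
function toOutputEvaluation(args: ReportSuccessArgs): { result: number; details: string } {
  return {
    result: args.success ? 1 : 0,
    details: args.thoughts,
  };
}

// Example of a passing evaluation:
const example = toOutputEvaluation({
  thoughts: "The output names the correct capital city.",
  success: true,
}); // => { result: 1, details: "The output names the correct capital city." }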