Compare commits

..

13 Commits

Author SHA1 Message Date
David Corbitt
999a4c08fa Fix lint and prettier 2023-07-18 11:11:20 -07:00
arcticfly
374d0237ee Escape characters in Regex evaluations, minor UI fixes (#56)
* Fix ScenariosHeader stickiness

* Move meta tag from _app.tsx to _document.tsx

* Show spinner when saving variant

* Escape quotes and regex in evaluations
2023-07-18 11:07:04 -07:00
David Corbitt
b1f873623d Invalidate prompt stats after cell refetch 2023-07-18 09:45:11 -07:00
arcticfly
4131aa67d0 Continue polling VariantStats while LLM retrieval in progress, minor UI fixes (#54)
* Prevent zoom in on iOS

* Expand function return code background to fill cell

* Keep OutputStats on far right of cells

* Continue polling prompt stats while cells are retrieving from LLM

* Add comment to _document.tsx

* Fix prettier
2023-07-17 18:04:38 -07:00
Kyle Corbitt
8e7a6d3ae2 Merge pull request #55 from OpenPipe/more-eval
Add GPT4 Evals
2023-07-17 18:01:47 -07:00
Kyle Corbitt
7d41e94ca2 cache eval outputs and add gpt4 eval 2023-07-17 17:55:36 -07:00
Kyle Corbitt
011b12abb9 cache output evals 2023-07-17 17:52:30 -07:00
Kyle Corbitt
1ba18015bc Merge pull request #53 from OpenPipe/more-eval
Fix seeds and update eval field names
2023-07-17 14:26:29 -07:00
Kyle Corbitt
54369dba54 Fix seeds and update eval field names 2023-07-17 14:14:20 -07:00
arcticfly
6b84a59372 Properly catch completion errors (#51) 2023-07-17 10:50:25 -07:00
Kyle Corbitt
8db8aeacd3 Replace function chrome with comment
Use a block comment to explain the expected prompt formatting instead of function chrome. The advantage here is that once a user builds a mental model of how OpenPipe works they can just delete the comment, instead of the function chrome sitting around and taking up space in the UI forever.
2023-07-17 10:30:22 -07:00
Kyle Corbitt
64bd71e370 Merge pull request #50 from OpenPipe/remove-default
remove the default value for PromptVariant.model
2023-07-14 17:55:38 -07:00
Kyle Corbitt
3b99b7bd2b remove the default value for PromptVariant.model
We should be explicit about setting the appropriate model so it always matches the constructFn.
2023-07-14 17:43:52 -07:00
37 changed files with 737 additions and 376 deletions

View File

@@ -16,7 +16,8 @@
"postinstall": "prisma generate", "postinstall": "prisma generate",
"lint": "next lint", "lint": "next lint",
"start": "next start", "start": "next start",
"codegen": "tsx src/codegen/export-openai-types.ts" "codegen": "tsx src/codegen/export-openai-types.ts",
"seed": "tsx prisma/seed.ts"
}, },
"dependencies": { "dependencies": {
"@babel/preset-typescript": "^7.22.5", "@babel/preset-typescript": "^7.22.5",
@@ -49,7 +50,7 @@
"immer": "^10.0.2", "immer": "^10.0.2",
"isolated-vm": "^4.5.0", "isolated-vm": "^4.5.0",
"json-stringify-pretty-compact": "^4.0.0", "json-stringify-pretty-compact": "^4.0.0",
"lodash": "^4.17.21", "lodash-es": "^4.17.21",
"next": "^13.4.2", "next": "^13.4.2",
"next-auth": "^4.22.1", "next-auth": "^4.22.1",
"nextjs-routes": "^2.0.1", "nextjs-routes": "^2.0.1",
@@ -77,7 +78,7 @@
"@types/cors": "^2.8.13", "@types/cors": "^2.8.13",
"@types/eslint": "^8.37.0", "@types/eslint": "^8.37.0",
"@types/express": "^4.17.17", "@types/express": "^4.17.17",
"@types/lodash": "^4.14.195", "@types/lodash-es": "^4.17.8",
"@types/node": "^18.16.0", "@types/node": "^18.16.0",
"@types/pluralize": "^0.0.30", "@types/pluralize": "^0.0.30",
"@types/react": "^18.2.6", "@types/react": "^18.2.6",
@@ -100,6 +101,6 @@
"initVersion": "7.14.0" "initVersion": "7.14.0"
}, },
"prisma": { "prisma": {
"seed": "tsx prisma/seed.ts" "seed": "pnpm seed"
} }
} }

18
pnpm-lock.yaml generated
View File

@@ -95,7 +95,7 @@ dependencies:
json-stringify-pretty-compact: json-stringify-pretty-compact:
specifier: ^4.0.0 specifier: ^4.0.0
version: 4.0.0 version: 4.0.0
lodash: lodash-es:
specifier: ^4.17.21 specifier: ^4.17.21
version: 4.17.21 version: 4.17.21
next: next:
@@ -175,9 +175,9 @@ devDependencies:
'@types/express': '@types/express':
specifier: ^4.17.17 specifier: ^4.17.17
version: 4.17.17 version: 4.17.17
'@types/lodash': '@types/lodash-es':
specifier: ^4.14.195 specifier: ^4.17.8
version: 4.14.195 version: 4.17.8
'@types/node': '@types/node':
specifier: ^18.16.0 specifier: ^18.16.0
version: 18.16.0 version: 18.16.0
@@ -2753,6 +2753,12 @@ packages:
resolution: {integrity: sha512-dRLjCWHYg4oaA77cxO64oO+7JwCwnIzkZPdrrC71jQmQtlhM556pwKo5bUzqvZndkVbeFLIIi+9TC40JNF5hNQ==} resolution: {integrity: sha512-dRLjCWHYg4oaA77cxO64oO+7JwCwnIzkZPdrrC71jQmQtlhM556pwKo5bUzqvZndkVbeFLIIi+9TC40JNF5hNQ==}
dev: true dev: true
/@types/lodash-es@4.17.8:
resolution: {integrity: sha512-euY3XQcZmIzSy7YH5+Unb3b2X12Wtk54YWINBvvGQ5SmMvwb11JQskGsfkH/5HXK77Kr8GF0wkVDIxzAisWtog==}
dependencies:
'@types/lodash': 4.14.195
dev: true
/@types/lodash.mergewith@4.6.7: /@types/lodash.mergewith@4.6.7:
resolution: {integrity: sha512-3m+lkO5CLRRYU0fhGRp7zbsGi6+BZj0uTVSwvcKU+nSlhjA9/QRNfuSGnD2mX6hQA7ZbmcCkzk5h4ZYGOtk14A==} resolution: {integrity: sha512-3m+lkO5CLRRYU0fhGRp7zbsGi6+BZj0uTVSwvcKU+nSlhjA9/QRNfuSGnD2mX6hQA7ZbmcCkzk5h4ZYGOtk14A==}
dependencies: dependencies:
@@ -5379,6 +5385,10 @@ packages:
p-locate: 5.0.0 p-locate: 5.0.0
dev: true dev: true
/lodash-es@4.17.21:
resolution: {integrity: sha512-mKnC+QJ9pWVzv+C4/U3rRsHapFfHvQFoFB92e52xeyGMcX6/OlIl78je1u8vePzYZSkkogMPJ2yjxxsb89cxyw==}
dev: false
/lodash.merge@4.6.2: /lodash.merge@4.6.2:
resolution: {integrity: sha512-0KpjqXRVvrYyCsX1swR/XTK0va6VQkQM6MNo7PqW77ByjAhoARA8EfrP1N4+KlKj8YS0ZUCtRT/YUuhyYDujIQ==} resolution: {integrity: sha512-0KpjqXRVvrYyCsX1swR/XTK0va6VQkQM6MNo7PqW77ByjAhoARA8EfrP1N4+KlKj8YS0ZUCtRT/YUuhyYDujIQ==}
dev: true dev: true

View File

@@ -0,0 +1,2 @@
-- AlterTable
-- Remove the column-level default from "PromptVariant"."model": after this
-- migration, inserts no longer receive an implicit model value and must name
-- the model themselves.
ALTER TABLE "PromptVariant" ALTER COLUMN "model" DROP DEFAULT;

View File

@@ -0,0 +1,24 @@
/*
Warnings:
- You are about to rename the column `matchString` on the `Evaluation` table. If there is any code or views referring to the old name, they will break.
- You are about to rename the column `matchType` on the `Evaluation` table. If there is any code or views referring to the old name, they will break.
- You are about to rename the column `name` on the `Evaluation` table. If there is any code or views referring to the old name, they will break.
- You are about to rename the enum `EvaluationMatchType` to `EvalType`. If there is any code or views referring to the old name, they will break.
*/
-- RenameEnum
-- The enum type is renamed in place, so its existing values are kept as-is.
ALTER TYPE "EvaluationMatchType" RENAME TO "EvalType";
-- AlterTable
-- Columns are renamed (not dropped and re-created), so existing row data is preserved.
ALTER TABLE "Evaluation" RENAME COLUMN "matchString" TO "value";
ALTER TABLE "Evaluation" RENAME COLUMN "matchType" TO "evalType";
ALTER TABLE "Evaluation" RENAME COLUMN "name" TO "label";
-- AlterColumnType
-- NOTE(review): the enum was already renamed above, so "evalType" should already
-- be of type "EvalType" here and this cast looks like a no-op -- confirm.
ALTER TABLE "Evaluation" ALTER COLUMN "evalType" TYPE "EvalType" USING "evalType"::text::"EvalType";
-- SetNotNullConstraint
-- NOTE(review): presumably these columns were already NOT NULL under their old
-- names, in which case the statements below are redundant but harmless -- confirm.
ALTER TABLE "Evaluation" ALTER COLUMN "evalType" SET NOT NULL;
ALTER TABLE "Evaluation" ALTER COLUMN "label" SET NOT NULL;
ALTER TABLE "Evaluation" ALTER COLUMN "value" SET NOT NULL;

View File

@@ -0,0 +1,39 @@
/*
Warnings:
- You are about to drop the `EvaluationResult` table. If the table is not empty, all the data it contains will be lost.
*/
-- AlterEnum
-- Adds a new evaluation kind to the existing "EvalType" enum.
ALTER TYPE "EvalType" ADD VALUE 'GPT4_EVAL';
-- DropForeignKey
ALTER TABLE "EvaluationResult" DROP CONSTRAINT "EvaluationResult_evaluationId_fkey";
-- DropForeignKey
ALTER TABLE "EvaluationResult" DROP CONSTRAINT "EvaluationResult_promptVariantId_fkey";
-- DropTable
-- Existing "EvaluationResult" rows are discarded here; nothing in this migration
-- copies them into the new table.
DROP TABLE "EvaluationResult";
-- CreateTable
-- Stores one evaluation outcome for a single model output.
CREATE TABLE "OutputEvaluation" (
"id" UUID NOT NULL,
-- Numeric outcome of the evaluation for this output.
"result" DOUBLE PRECISION NOT NULL,
-- Optional free-text explanation accompanying the result.
"details" TEXT,
"modelOutputId" UUID NOT NULL,
"evaluationId" UUID NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
-- No database default: whoever writes the row must supply this value.
"updatedAt" TIMESTAMP(3) NOT NULL,
CONSTRAINT "OutputEvaluation_pkey" PRIMARY KEY ("id")
);
-- CreateIndex
-- At most one stored result per (model output, evaluation) pair.
CREATE UNIQUE INDEX "OutputEvaluation_modelOutputId_evaluationId_key" ON "OutputEvaluation"("modelOutputId", "evaluationId");
-- AddForeignKey
-- Deleting a model output also deletes its stored evaluation results.
ALTER TABLE "OutputEvaluation" ADD CONSTRAINT "OutputEvaluation_modelOutputId_fkey" FOREIGN KEY ("modelOutputId") REFERENCES "ModelOutput"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
-- Deleting an evaluation also deletes every result recorded for it.
ALTER TABLE "OutputEvaluation" ADD CONSTRAINT "OutputEvaluation_evaluationId_fkey" FOREIGN KEY ("evaluationId") REFERENCES "Evaluation"("id") ON DELETE CASCADE ON UPDATE CASCADE;

View File

@@ -3,7 +3,6 @@
generator client { generator client {
provider = "prisma-client-js" provider = "prisma-client-js"
previewFeatures = ["jsonProtocol"]
} }
datasource db { datasource db {
@@ -30,7 +29,7 @@ model PromptVariant {
label String label String
constructFn String constructFn String
model String @default("gpt-3.5-turbo") model String
uiId String @default(uuid()) @db.Uuid uiId String @default(uuid()) @db.Uuid
visible Boolean @default(true) visible Boolean @default(true)
@@ -42,7 +41,6 @@ model PromptVariant {
createdAt DateTime @default(now()) createdAt DateTime @default(now())
updatedAt DateTime @updatedAt updatedAt DateTime @updatedAt
scenarioVariantCells ScenarioVariantCell[] scenarioVariantCells ScenarioVariantCell[]
EvaluationResult EvaluationResult[]
@@index([uiId]) @@index([uiId])
} }
@@ -125,47 +123,50 @@ model ModelOutput {
scenarioVariantCellId String @db.Uuid scenarioVariantCellId String @db.Uuid
scenarioVariantCell ScenarioVariantCell @relation(fields: [scenarioVariantCellId], references: [id], onDelete: Cascade) scenarioVariantCell ScenarioVariantCell @relation(fields: [scenarioVariantCellId], references: [id], onDelete: Cascade)
outputEvaluation OutputEvaluation[]
@@unique([scenarioVariantCellId]) @@unique([scenarioVariantCellId])
@@index([inputHash]) @@index([inputHash])
} }
enum EvaluationMatchType { enum EvalType {
CONTAINS CONTAINS
DOES_NOT_CONTAIN DOES_NOT_CONTAIN
GPT4_EVAL
} }
model Evaluation { model Evaluation {
id String @id @default(uuid()) @db.Uuid id String @id @default(uuid()) @db.Uuid
name String label String
matchString String evalType EvalType
matchType EvaluationMatchType value String
experimentId String @db.Uuid experimentId String @db.Uuid
experiment Experiment @relation(fields: [experimentId], references: [id], onDelete: Cascade) experiment Experiment @relation(fields: [experimentId], references: [id], onDelete: Cascade)
createdAt DateTime @default(now()) createdAt DateTime @default(now())
updatedAt DateTime @updatedAt updatedAt DateTime @updatedAt
EvaluationResult EvaluationResult[] OutputEvaluation OutputEvaluation[]
} }
model EvaluationResult { model OutputEvaluation {
id String @id @default(uuid()) @db.Uuid id String @id @default(uuid()) @db.Uuid
passCount Int // Number between 0 (fail) and 1 (pass)
failCount Int result Float
details String?
modelOutputId String @db.Uuid
modelOutput ModelOutput @relation(fields: [modelOutputId], references: [id], onDelete: Cascade)
evaluationId String @db.Uuid evaluationId String @db.Uuid
evaluation Evaluation @relation(fields: [evaluationId], references: [id], onDelete: Cascade) evaluation Evaluation @relation(fields: [evaluationId], references: [id], onDelete: Cascade)
promptVariantId String @db.Uuid
promptVariant PromptVariant @relation(fields: [promptVariantId], references: [id], onDelete: Cascade)
createdAt DateTime @default(now()) createdAt DateTime @default(now())
updatedAt DateTime @updatedAt updatedAt DateTime @updatedAt
@@unique([evaluationId, promptVariantId]) @@unique([modelOutputId, evaluationId])
} }
// Necessary for Next auth // Necessary for Next auth

View File

@@ -1,4 +1,6 @@
import { prisma } from "~/server/db"; import { prisma } from "~/server/db";
import dedent from "dedent";
import { generateNewCell } from "~/server/utils/generateNewCell";
const experimentId = "11111111-1111-1111-1111-111111111111"; const experimentId = "11111111-1111-1111-1111-111111111111";
@@ -9,7 +11,7 @@ await prisma.experiment.deleteMany({
}, },
}); });
const experiment = await prisma.experiment.create({ await prisma.experiment.create({
data: { data: {
id: experimentId, id: experimentId,
label: "Country Capitals Example", label: "Country Capitals Example",
@@ -36,9 +38,16 @@ await prisma.promptVariant.createMany({
experimentId, experimentId,
label: "Prompt Variant 1", label: "Prompt Variant 1",
sortIndex: 0, sortIndex: 0,
constructFn: `prompt = {
model: "gpt-3.5-turbo-0613", model: "gpt-3.5-turbo-0613",
messages: [{ role: "user", content: "What is the capital of {{country}}?" }], constructFn: dedent`
prompt = {
model: "gpt-3.5-turbo-0613",
messages: [
{
role: "user",
content: \`What is the capital of ${"$"}{scenario.country}?\`
}
],
temperature: 0, temperature: 0,
}`, }`,
}, },
@@ -46,14 +55,15 @@ await prisma.promptVariant.createMany({
experimentId, experimentId,
label: "Prompt Variant 2", label: "Prompt Variant 2",
sortIndex: 1, sortIndex: 1,
constructFn: `prompt = { model: "gpt-3.5-turbo-0613",
constructFn: dedent`
prompt = {
model: "gpt-3.5-turbo-0613", model: "gpt-3.5-turbo-0613",
messages: [ messages: [
{ {
role: "user", role: "user",
content: content: \`What is the capital of ${"$"}{scenario.country}? Return just the city name and nothing else.\`
"What is the capital of {{country}}? Return just the city name and nothing else.", }
},
], ],
temperature: 0, temperature: 0,
}`, }`,
@@ -107,3 +117,26 @@ await prisma.testScenario.createMany({
}, },
], ],
}); });
// Load every prompt variant and test scenario that belongs to the seeded experiment.
const variants = await prisma.promptVariant.findMany({ where: { experimentId } });
const scenarios = await prisma.testScenario.findMany({ where: { experimentId } });

// Start generateNewCell for each variant/scenario pair (variants in the outer
// loop, scenarios in the inner loop), then wait for all of them to settle.
const pendingCells: ReturnType<typeof generateNewCell>[] = [];
for (const variant of variants) {
  for (const scenario of scenarios) {
    pendingCells.push(generateNewCell(variant.id, scenario.id));
  }
}
await Promise.all(pendingCells);

View File

@@ -12,6 +12,7 @@ await prisma.promptVariant.createMany({
{ {
experimentId: functionCallsExperiment.id, experimentId: functionCallsExperiment.id,
label: "No Fn Calls", label: "No Fn Calls",
model: "gpt-3.5-turbo-0613",
constructFn: `prompt = { constructFn: `prompt = {
model: "gpt-3.5-turbo-0613", model: "gpt-3.5-turbo-0613",
messages: [ messages: [
@@ -30,6 +31,7 @@ await prisma.promptVariant.createMany({
{ {
experimentId: functionCallsExperiment.id, experimentId: functionCallsExperiment.id,
label: "Fn Calls", label: "Fn Calls",
model: "gpt-3.5-turbo-0613",
constructFn: `prompt = { constructFn: `prompt = {
model: "gpt-3.5-turbo-0613", model: "gpt-3.5-turbo-0613",
messages: [ messages: [
@@ -92,6 +94,7 @@ await prisma.promptVariant.createMany({
experimentId: redditExperiment.id, experimentId: redditExperiment.id,
label: "3.5 Base", label: "3.5 Base",
sortIndex: 0, sortIndex: 0,
model: "gpt-3.5-turbo-0613",
constructFn: `prompt = { constructFn: `prompt = {
model: "gpt-3.5-turbo-0613", model: "gpt-3.5-turbo-0613",
messages: [ messages: [
@@ -107,6 +110,7 @@ await prisma.promptVariant.createMany({
experimentId: redditExperiment.id, experimentId: redditExperiment.id,
label: "4 Base", label: "4 Base",
sortIndex: 1, sortIndex: 1,
// Must name the same model as the constructFn that follows ("4 Base" runs on gpt-4-0613).
model: "gpt-4-0613",
constructFn: `prompt = { constructFn: `prompt = {
model: "gpt-4-0613", model: "gpt-4-0613",
messages: [ messages: [
@@ -122,6 +126,7 @@ await prisma.promptVariant.createMany({
experimentId: redditExperiment.id, experimentId: redditExperiment.id,
label: "3.5 CoT + Functions", label: "3.5 CoT + Functions",
sortIndex: 2, sortIndex: 2,
model: "gpt-3.5-turbo-0613",
constructFn: `prompt = { constructFn: `prompt = {
model: "gpt-3.5-turbo-0613", model: "gpt-3.5-turbo-0613",
messages: [ messages: [
@@ -178,9 +183,9 @@ await prisma.templateVariable.createMany({
await prisma.evaluation.create({ await prisma.evaluation.create({
data: { data: {
experimentId: redditExperiment.id, experimentId: redditExperiment.id,
name: "Relevance Accuracy", label: "Relevance Accuracy",
matchType: "CONTAINS", evalType: "CONTAINS",
matchString: '"{{relevance}}"', value: '"{{relevance}}"',
}, },
}); });
@@ -1119,12 +1124,3 @@ await prisma.testScenario.createMany({
variableValues: vars, variableValues: vars,
})), })),
}); });
// await prisma.evaluation.create({
// data: {
// experimentId: redditExperiment.id,
// name: "Scores Match",
// matchType: "CONTAINS",
// matchString: "{{score}}",
// },
// });

View File

@@ -2,7 +2,7 @@ import fs from "fs";
import path from "path"; import path from "path";
import openapiTS, { type OpenAPI3 } from "openapi-typescript"; import openapiTS, { type OpenAPI3 } from "openapi-typescript";
import YAML from "yaml"; import YAML from "yaml";
import _ from "lodash"; import { pick } from "lodash-es";
import assert from "assert"; import assert from "assert";
const OPENAPI_URL = const OPENAPI_URL =
@@ -31,7 +31,7 @@ modelProperty.oneOf = undefined;
delete schema["paths"]; delete schema["paths"];
assert(schema.components?.schemas); assert(schema.components?.schemas);
schema.components.schemas = _.pick(schema.components?.schemas, [ schema.components.schemas = pick(schema.components?.schemas, [
"CreateChatCompletionRequest", "CreateChatCompletionRequest",
"ChatCompletionRequestMessage", "ChatCompletionRequestMessage",
"ChatCompletionFunctions", "ChatCompletionFunctions",

View File

@@ -4,7 +4,7 @@ import React from "react";
export const AutoResizeTextarea: React.ForwardRefRenderFunction< export const AutoResizeTextarea: React.ForwardRefRenderFunction<
HTMLTextAreaElement, HTMLTextAreaElement,
TextareaProps TextareaProps & { minRows?: number }
> = (props, ref) => { > = (props, ref) => {
return ( return (
<Textarea <Textarea

View File

@@ -11,14 +11,16 @@ import {
FormLabel, FormLabel,
Select, Select,
FormHelperText, FormHelperText,
Code,
} from "@chakra-ui/react"; } from "@chakra-ui/react";
import { type Evaluation, EvaluationMatchType } from "@prisma/client"; import { type Evaluation, EvalType } from "@prisma/client";
import { useCallback, useState } from "react"; import { useCallback, useState } from "react";
import { BsPencil, BsX } from "react-icons/bs"; import { BsPencil, BsX } from "react-icons/bs";
import { api } from "~/utils/api"; import { api } from "~/utils/api";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks"; import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
import AutoResizeTextArea from "../AutoResizeTextArea";
type EvalValues = Pick<Evaluation, "name" | "matchString" | "matchType">; type EvalValues = Pick<Evaluation, "label" | "value" | "evalType">;
export function EvaluationEditor(props: { export function EvaluationEditor(props: {
evaluation: Evaluation | null; evaluation: Evaluation | null;
@@ -27,35 +29,35 @@ export function EvaluationEditor(props: {
onCancel: () => void; onCancel: () => void;
}) { }) {
const [values, setValues] = useState<EvalValues>({ const [values, setValues] = useState<EvalValues>({
name: props.evaluation?.name ?? props.defaultName ?? "", label: props.evaluation?.label ?? props.defaultName ?? "",
matchString: props.evaluation?.matchString ?? "", value: props.evaluation?.value ?? "",
matchType: props.evaluation?.matchType ?? "CONTAINS", evalType: props.evaluation?.evalType ?? "CONTAINS",
}); });
return ( return (
<VStack borderTopWidth={1} borderColor="gray.200" py={4}> <VStack borderTopWidth={1} borderColor="gray.200" py={4}>
<HStack w="100%"> <HStack w="100%">
<FormControl flex={1}> <FormControl flex={1}>
<FormLabel fontSize="sm">Evaluation Name</FormLabel> <FormLabel fontSize="sm">Eval Name</FormLabel>
<Input <Input
size="sm" size="sm"
value={values.name} value={values.label}
onChange={(e) => setValues((values) => ({ ...values, name: e.target.value }))} onChange={(e) => setValues((values) => ({ ...values, label: e.target.value }))}
/> />
</FormControl> </FormControl>
<FormControl flex={1}> <FormControl flex={1}>
<FormLabel fontSize="sm">Match Type</FormLabel> <FormLabel fontSize="sm">Eval Type</FormLabel>
<Select <Select
size="sm" size="sm"
value={values.matchType} value={values.evalType}
onChange={(e) => onChange={(e) =>
setValues((values) => ({ setValues((values) => ({
...values, ...values,
matchType: e.target.value as EvaluationMatchType, evalType: e.target.value as EvalType,
})) }))
} }
> >
{Object.values(EvaluationMatchType).map((type) => ( {Object.values(EvalType).map((type) => (
<option key={type} value={type}> <option key={type} value={type}>
{type} {type}
</option> </option>
@@ -63,17 +65,37 @@ export function EvaluationEditor(props: {
</Select> </Select>
</FormControl> </FormControl>
</HStack> </HStack>
{["CONTAINS", "DOES_NOT_CONTAIN"].includes(values.evalType) && (
<FormControl> <FormControl>
<FormLabel fontSize="sm">Match String</FormLabel> <FormLabel fontSize="sm">Match String</FormLabel>
<Input <Input
size="sm" size="sm"
value={values.matchString} value={values.value}
onChange={(e) => setValues((values) => ({ ...values, matchString: e.target.value }))} onChange={(e) => setValues((values) => ({ ...values, value: e.target.value }))}
/> />
<FormHelperText> <FormHelperText>
This string will be interpreted as a regex and checked against each model output. This string will be interpreted as a regex and checked against each model output. You
can include scenario variables using <Code>{"{{curly_braces}}"}</Code>
</FormHelperText> </FormHelperText>
</FormControl> </FormControl>
)}
{values.evalType === "GPT4_EVAL" && (
<FormControl pt={2}>
<FormLabel fontSize="sm">GPT4 Instructions</FormLabel>
<AutoResizeTextArea
size="sm"
value={values.value}
onChange={(e) => setValues((values) => ({ ...values, value: e.target.value }))}
minRows={3}
/>
<FormHelperText>
Give instructions to GPT-4 for how to evaluate your prompt. It will have access to the
full scenario as well as the output it is evaluating. It will <strong>not</strong> have
access to the specific prompt variant, so be sure to be clear about the task you want it
to perform.
</FormHelperText>
</FormControl>
)}
<HStack alignSelf="flex-end"> <HStack alignSelf="flex-end">
<Button size="sm" onClick={props.onCancel} colorScheme="gray"> <Button size="sm" onClick={props.onCancel} colorScheme="gray">
Cancel Cancel
@@ -125,6 +147,7 @@ export default function EditEvaluations() {
} }
await utils.evaluations.list.invalidate(); await utils.evaluations.list.invalidate();
await utils.promptVariants.stats.invalidate(); await utils.promptVariants.stats.invalidate();
await utils.scenarioVariantCells.get.invalidate();
}, []); }, []);
const onCancel = useCallback(() => { const onCancel = useCallback(() => {
@@ -156,9 +179,9 @@ export default function EditEvaluations() {
align="center" align="center"
key={evaluation.id} key={evaluation.id}
> >
<Text fontWeight="bold">{evaluation.name}</Text> <Text fontWeight="bold">{evaluation.label}</Text>
<Text flex={1}> <Text flex={1}>
{evaluation.matchType}: &quot;{evaluation.matchString}&quot; {evaluation.evalType}: &quot;{evaluation.value}&quot;
</Text> </Text>
<Button <Button
variant="unstyled" variant="unstyled"

View File

@@ -1,4 +1,4 @@
import { Text, Button, HStack, Heading, Icon, Input, Stack, Code } from "@chakra-ui/react"; import { Text, Button, HStack, Heading, Icon, Input, Stack } from "@chakra-ui/react";
import { useState } from "react"; import { useState } from "react";
import { BsCheck, BsX } from "react-icons/bs"; import { BsCheck, BsX } from "react-icons/bs";
import { api } from "~/utils/api"; import { api } from "~/utils/api";
@@ -36,8 +36,7 @@ export default function EditScenarioVars() {
<Heading size="sm">Scenario Variables</Heading> <Heading size="sm">Scenario Variables</Heading>
<Stack spacing={2}> <Stack spacing={2}>
<Text fontSize="sm"> <Text fontSize="sm">
Scenario variables can be used in your prompt variants as well as evaluations. Reference Scenario variables can be used in your prompt variants as well as evaluations.
them using <Code>{"{{curly_braces}}"}</Code>.
</Text> </Text>
<HStack spacing={0}> <HStack spacing={0}>
<Input <Input

View File

@@ -1,6 +1,6 @@
import { api } from "~/utils/api"; import { api } from "~/utils/api";
import { type PromptVariant, type Scenario } from "../types"; import { type PromptVariant, type Scenario } from "../types";
import { Spinner, Text, Box, Center, Flex, VStack } from "@chakra-ui/react"; import { Spinner, Text, Center, VStack } from "@chakra-ui/react";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks"; import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
import SyntaxHighlighter from "react-syntax-highlighter"; import SyntaxHighlighter from "react-syntax-highlighter";
import { docco } from "react-syntax-highlighter/dist/cjs/styles/hljs"; import { docco } from "react-syntax-highlighter/dist/cjs/styles/hljs";
@@ -50,12 +50,18 @@ export default function OutputCell({
scenarioId: scenario.id, scenarioId: scenario.id,
variantId: variant.id, variantId: variant.id,
}); });
await utils.promptVariants.stats.invalidate({
variantId: variant.id,
});
}, [hardRefetchMutate, scenario.id, variant.id]); }, [hardRefetchMutate, scenario.id, variant.id]);
const fetchingOutput = queryLoading || refetchingOutput; const fetchingOutput = queryLoading || refetchingOutput;
const awaitingOutput = const awaitingOutput =
!cell || cell.retrievalStatus === "PENDING" || cell.retrievalStatus === "IN_PROGRESS"; !cell ||
cell.retrievalStatus === "PENDING" ||
cell.retrievalStatus === "IN_PROGRESS" ||
refetchingOutput;
useEffect(() => setRefetchInterval(awaitingOutput ? 1000 : 0), [awaitingOutput]); useEffect(() => setRefetchInterval(awaitingOutput ? 1000 : 0), [awaitingOutput]);
const modelOutput = cell?.modelOutput; const modelOutput = cell?.modelOutput;
@@ -95,11 +101,18 @@ export default function OutputCell({
} }
return ( return (
<Box fontSize="xs" width="100%" flexWrap="wrap" overflowX="auto"> <VStack
<VStack w="full" spacing={0}> w="100%"
h="100%"
fontSize="xs"
flexWrap="wrap"
overflowX="auto"
justifyContent="space-between"
>
<VStack w="full" flex={1} spacing={0}>
<CellOptions refetchingOutput={refetchingOutput} refetchOutput={hardRefetch} /> <CellOptions refetchingOutput={refetchingOutput} refetchOutput={hardRefetch} />
<SyntaxHighlighter <SyntaxHighlighter
customStyle={{ overflowX: "unset" }} customStyle={{ overflowX: "unset", width: "100%", flex: 1 }}
language="json" language="json"
style={docco} style={docco}
lineProps={{ lineProps={{
@@ -117,7 +130,7 @@ export default function OutputCell({
</SyntaxHighlighter> </SyntaxHighlighter>
</VStack> </VStack>
<OutputStats model={variant.model} modelOutput={modelOutput} scenario={scenario} /> <OutputStats model={variant.model} modelOutput={modelOutput} scenario={scenario} />
</Box> </VStack>
); );
} }
@@ -125,7 +138,7 @@ export default function OutputCell({
message?.content ?? streamedContent ?? JSON.stringify(modelOutput?.output); message?.content ?? streamedContent ?? JSON.stringify(modelOutput?.output);
return ( return (
<Flex w="100%" h="100%" direction="column" justifyContent="space-between" whiteSpace="pre-wrap"> <VStack w="100%" h="100%" justifyContent="space-between" whiteSpace="pre-wrap">
<VStack w="full" alignItems="flex-start" spacing={0}> <VStack w="full" alignItems="flex-start" spacing={0}>
<CellOptions refetchingOutput={refetchingOutput} refetchOutput={hardRefetch} /> <CellOptions refetchingOutput={refetchingOutput} refetchOutput={hardRefetch} />
<Text>{contentToDisplay}</Text> <Text>{contentToDisplay}</Text>
@@ -133,6 +146,6 @@ export default function OutputCell({
{modelOutput && ( {modelOutput && (
<OutputStats model={variant.model} modelOutput={modelOutput} scenario={scenario} /> <OutputStats model={variant.model} modelOutput={modelOutput} scenario={scenario} />
)} )}
</Flex> </VStack>
); );
} }

View File

@@ -1,11 +1,8 @@
import { type ModelOutput } from "@prisma/client";
import { type SupportedModel } from "~/server/types"; import { type SupportedModel } from "~/server/types";
import { type Scenario } from "../types"; import { type Scenario } from "../types";
import { useExperiment } from "~/utils/hooks"; import { type RouterOutputs } from "~/utils/api";
import { api } from "~/utils/api";
import { calculateTokenCost } from "~/utils/calculateTokenCost"; import { calculateTokenCost } from "~/utils/calculateTokenCost";
import { evaluateOutput } from "~/server/utils/evaluateOutput"; import { HStack, Icon, Text, Tooltip } from "@chakra-ui/react";
import { HStack, Icon, Text } from "@chakra-ui/react";
import { BsCheck, BsClock, BsCurrencyDollar, BsX } from "react-icons/bs"; import { BsCheck, BsClock, BsCurrencyDollar, BsX } from "react-icons/bs";
import { CostTooltip } from "~/components/tooltip/CostTooltip"; import { CostTooltip } from "~/components/tooltip/CostTooltip";
@@ -15,16 +12,14 @@ const SHOW_TIME = true;
export const OutputStats = ({ export const OutputStats = ({
model, model,
modelOutput, modelOutput,
scenario,
}: { }: {
model: SupportedModel | string | null; model: SupportedModel | string | null;
modelOutput: ModelOutput; modelOutput: NonNullable<
NonNullable<RouterOutputs["scenarioVariantCells"]["get"]>["modelOutput"]
>;
scenario: Scenario; scenario: Scenario;
}) => { }) => {
const timeToComplete = modelOutput.timeToComplete; const timeToComplete = modelOutput.timeToComplete;
const experiment = useExperiment();
const evals =
api.evaluations.list.useQuery({ experimentId: experiment.data?.id ?? "" }).data ?? [];
const promptTokens = modelOutput.promptTokens; const promptTokens = modelOutput.promptTokens;
const completionTokens = modelOutput.completionTokens; const completionTokens = modelOutput.completionTokens;
@@ -36,19 +31,25 @@ export const OutputStats = ({
const cost = promptCost + completionCost; const cost = promptCost + completionCost;
return ( return (
<HStack align="center" color="gray.500" fontSize="2xs" mt={{ base: 0, md: 1 }}> <HStack w="full" align="center" color="gray.500" fontSize="2xs" mt={{ base: 0, md: 1 }}>
<HStack flex={1}> <HStack flex={1}>
{evals.map((evaluation) => { {modelOutput.outputEvaluation.map((evaluation) => {
const passed = evaluateOutput(modelOutput, scenario, evaluation); const passed = evaluation.result > 0.5;
return ( return (
<HStack spacing={0} key={evaluation.id}> <Tooltip
<Text>{evaluation.name}</Text> isDisabled={!evaluation.details}
label={evaluation.details}
key={evaluation.id}
>
<HStack spacing={0}>
<Text>{evaluation.evaluation.label}</Text>
<Icon <Icon
as={passed ? BsCheck : BsX} as={passed ? BsCheck : BsX}
color={passed ? "green.500" : "red.500"} color={passed ? "green.500" : "red.500"}
boxSize={6} boxSize={6}
/> />
</HStack> </HStack>
</Tooltip>
); );
})} })}
</HStack> </HStack>

View File

@@ -1,6 +1,6 @@
import { type DragEvent } from "react"; import { type DragEvent } from "react";
import { api } from "~/utils/api"; import { api } from "~/utils/api";
import { isEqual } from "lodash"; import { isEqual } from "lodash-es";
import { type Scenario } from "./types"; import { type Scenario } from "./types";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks"; import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
import { useState } from "react"; import { useState } from "react";

View File

@@ -0,0 +1,49 @@
import { Button, GridItem, HStack, Heading } from "@chakra-ui/react";
import { cellPadding } from "../constants";
import { useElementDimensions } from "~/utils/hooks";
import { stickyHeaderStyle } from "./styles";
import { BsPencil } from "react-icons/bs";
import { useAppStore } from "~/state/store";
/**
 * Header cell for the scenarios column.
 *
 * Renders a grid item spanning `headerRows` rows that shows the heading
 * "Scenarios (<numScenarios>)" and an "Edit Vars" button wired to the app
 * store's `openDrawer` action.
 *
 * @param props.headerRows - number of grid rows the header cell spans
 * @param props.numScenarios - count displayed next to the heading
 */
export const ScenariosHeader = (props: { headerRows: number; numScenarios: number }) => {
  const { headerRows, numScenarios } = props;

  const onEditVars = useAppStore((state) => state.openDrawer);
  const [headerRef, headerSize] = useElementDimensions();

  // Negative `top` offset: everything except the last 24px of the measured
  // height is pushed out of view. Until a measurement is available, fall back
  // to a fixed offset.
  // NOTE(review): presumably stickyHeaderStyle sets `position: sticky` — confirm.
  let stickyTop = "-455px";
  if (headerSize) {
    stickyTop = `-${headerSize.height - 24}px`;
  }

  return (
    <GridItem
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      ref={headerRef as any}
      display="flex"
      alignItems="flex-end"
      rowSpan={headerRows}
      px={cellPadding.x}
      py={cellPadding.y}
      // Only display the part of the grid item that has content
      sx={{ ...stickyHeaderStyle, top: stickyTop }}
    >
      <HStack w="100%">
        <Heading size="xs" fontWeight="bold" flex={1}>
          Scenarios ({numScenarios})
        </Heading>
        <Button
          size="xs"
          variant="ghost"
          color="gray.500"
          aria-label="Edit"
          leftIcon={<BsPencil />}
          onClick={onEditVars}
        >
          Edit Vars
        </Button>
      </HStack>
    </GridItem>
  );
};

View File

@@ -1,10 +1,9 @@
import { Box, Button, HStack, Tooltip, VStack, useToast } from "@chakra-ui/react"; import { Box, Button, HStack, Spinner, Tooltip, useToast, Text } from "@chakra-ui/react";
import { useRef, useEffect, useState, useCallback } from "react"; import { useRef, useEffect, useState, useCallback } from "react";
import { useHandledAsyncCallback, useModifierKeyLabel } from "~/utils/hooks"; import { useHandledAsyncCallback, useModifierKeyLabel } from "~/utils/hooks";
import { type PromptVariant } from "./types"; import { type PromptVariant } from "./types";
import { api } from "~/utils/api"; import { api } from "~/utils/api";
import { useAppStore } from "~/state/store"; import { useAppStore } from "~/state/store";
import { editorBackground } from "~/state/sharedVariantEditor.slice";
export default function VariantEditor(props: { variant: PromptVariant }) { export default function VariantEditor(props: { variant: PromptVariant }) {
const monaco = useAppStore.use.sharedVariantEditor.monaco(); const monaco = useAppStore.use.sharedVariantEditor.monaco();
@@ -28,7 +27,7 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
const utils = api.useContext(); const utils = api.useContext();
const toast = useToast(); const toast = useToast();
const [onSave] = useHandledAsyncCallback(async () => { const [onSave, saveInProgress] = useHandledAsyncCallback(async () => {
if (!editorRef.current) return; if (!editorRef.current) return;
await editorRef.current.getAction("editor.action.formatDocument")?.run(); await editorRef.current.getAction("editor.action.formatDocument")?.run();
@@ -133,19 +132,7 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
return ( return (
<Box w="100%" pos="relative"> <Box w="100%" pos="relative">
<VStack <div id={editorId} style={{ height: "400px", width: "100%" }}></div>
spacing={0}
align="stretch"
fontSize="xs"
fontWeight="bold"
color="gray.600"
py={2}
bgColor={editorBackground}
>
<code>{`function constructPrompt(scenario: Scenario): Prompt {`}</code>
<div id={editorId} style={{ height: "300px", width: "100%" }}></div>
<code>{`return prompt; }`}</code>
</VStack>
{isChanged && ( {isChanged && (
<HStack pos="absolute" bottom={2} right={2}> <HStack pos="absolute" bottom={2} right={2}>
<Button <Button
@@ -159,8 +146,8 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
Reset Reset
</Button> </Button>
<Tooltip label={`${modifierKey} + Enter`}> <Tooltip label={`${modifierKey} + Enter`}>
<Button size="sm" onClick={onSave} colorScheme="blue"> <Button size="sm" onClick={onSave} colorScheme="blue" w={16} disabled={saveInProgress}>
Save {saveInProgress ? <Spinner boxSize={4} /> : <Text>Save</Text>}
</Button> </Button>
</Tooltip> </Tooltip>
</HStack> </HStack>

View File

@@ -1,12 +1,14 @@
import { HStack, Icon, Text, useToken } from "@chakra-ui/react"; import { HStack, Icon, Skeleton, Text, useToken } from "@chakra-ui/react";
import { type PromptVariant } from "./types"; import { type PromptVariant } from "./types";
import { cellPadding } from "../constants"; import { cellPadding } from "../constants";
import { api } from "~/utils/api"; import { api } from "~/utils/api";
import chroma from "chroma-js"; import chroma from "chroma-js";
import { BsCurrencyDollar } from "react-icons/bs"; import { BsCurrencyDollar } from "react-icons/bs";
import { CostTooltip } from "../tooltip/CostTooltip"; import { CostTooltip } from "../tooltip/CostTooltip";
import { useEffect, useState } from "react";
export default function VariantStats(props: { variant: PromptVariant }) { export default function VariantStats(props: { variant: PromptVariant }) {
const [refetchInterval, setRefetchInterval] = useState(0);
const { data } = api.promptVariants.stats.useQuery( const { data } = api.promptVariants.stats.useQuery(
{ {
variantId: props.variant.id, variantId: props.variant.id,
@@ -19,10 +21,18 @@ export default function VariantStats(props: { variant: PromptVariant }) {
completionTokens: 0, completionTokens: 0,
scenarioCount: 0, scenarioCount: 0,
outputCount: 0, outputCount: 0,
awaitingRetrievals: false,
}, },
refetchInterval,
}, },
); );
// Poll every two seconds while we are waiting for LLM retrievals to finish
useEffect(
() => setRefetchInterval(data.awaitingRetrievals ? 2000 : 0),
[data.awaitingRetrievals],
);
const [passColor, neutralColor, failColor] = useToken("colors", [ const [passColor, neutralColor, failColor] = useToken("colors", [
"green.500", "green.500",
"gray.500", "gray.500",
@@ -33,21 +43,25 @@ export default function VariantStats(props: { variant: PromptVariant }) {
const showNumFinished = data.scenarioCount > 0 && data.scenarioCount !== data.outputCount; const showNumFinished = data.scenarioCount > 0 && data.scenarioCount !== data.outputCount;
if (!(data.evalResults.length > 0) && !data.overallCost) return null;
return ( return (
<HStack justifyContent="space-between" alignItems="center" mx="2" fontSize="xs"> <HStack
justifyContent="space-between"
alignItems="center"
mx="2"
fontSize="xs"
py={cellPadding.y}
>
{showNumFinished && ( {showNumFinished && (
<Text> <Text>
{data.outputCount} / {data.scenarioCount} {data.outputCount} / {data.scenarioCount}
</Text> </Text>
)} )}
<HStack px={cellPadding.x} py={cellPadding.y}> <HStack px={cellPadding.x}>
{data.evalResults.map((result) => { {data.evalResults.map((result) => {
const passedFrac = result.passCount / (result.passCount + result.failCount); const passedFrac = result.passCount / result.totalCount;
return ( return (
<HStack key={result.id}> <HStack key={result.id}>
<Text>{result.evaluation.name}</Text> <Text>{result.label}</Text>
<Text color={scale(passedFrac).hex()} fontWeight="bold"> <Text color={scale(passedFrac).hex()} fontWeight="bold">
{(passedFrac * 100).toFixed(1)}% {(passedFrac * 100).toFixed(1)}%
</Text> </Text>
@@ -55,17 +69,19 @@ export default function VariantStats(props: { variant: PromptVariant }) {
); );
})} })}
</HStack> </HStack>
{data.overallCost && ( {data.overallCost && !data.awaitingRetrievals ? (
<CostTooltip <CostTooltip
promptTokens={data.promptTokens} promptTokens={data.promptTokens}
completionTokens={data.completionTokens} completionTokens={data.completionTokens}
cost={data.overallCost} cost={data.overallCost}
> >
<HStack spacing={0} align="center" color="gray.500" my="2"> <HStack spacing={0} align="center" color="gray.500">
<Icon as={BsCurrencyDollar} /> <Icon as={BsCurrencyDollar} />
<Text mr={1}>{data.overallCost.toFixed(3)}</Text> <Text mr={1}>{data.overallCost.toFixed(3)}</Text>
</HStack> </HStack>
</CostTooltip> </CostTooltip>
) : (
<Skeleton height={4} width={12} mr={1} />
)} )}
</HStack> </HStack>
); );

View File

@@ -1,28 +1,19 @@
import { Button, Grid, GridItem, HStack, Heading, type SystemStyleObject } from "@chakra-ui/react"; import { Grid, GridItem } from "@chakra-ui/react";
import { api } from "~/utils/api"; import { api } from "~/utils/api";
import NewScenarioButton from "./NewScenarioButton"; import NewScenarioButton from "./NewScenarioButton";
import NewVariantButton from "./NewVariantButton"; import NewVariantButton from "./NewVariantButton";
import ScenarioRow from "./ScenarioRow"; import ScenarioRow from "./ScenarioRow";
import VariantEditor from "./VariantEditor"; import VariantEditor from "./VariantEditor";
import VariantHeader from "./VariantHeader"; import VariantHeader from "./VariantHeader";
import { cellPadding } from "../constants";
import { BsPencil } from "react-icons/bs";
import VariantStats from "./VariantStats"; import VariantStats from "./VariantStats";
import { useAppStore } from "~/state/store"; import { ScenariosHeader } from "./ScenariosHeader";
import { stickyHeaderStyle } from "./styles";
const stickyHeaderStyle: SystemStyleObject = {
position: "sticky",
top: "-1px",
backgroundColor: "#fff",
zIndex: 1,
};
export default function OutputsTable({ experimentId }: { experimentId: string | undefined }) { export default function OutputsTable({ experimentId }: { experimentId: string | undefined }) {
const variants = api.promptVariants.list.useQuery( const variants = api.promptVariants.list.useQuery(
{ experimentId: experimentId as string }, { experimentId: experimentId as string },
{ enabled: !!experimentId }, { enabled: !!experimentId },
); );
const openDrawer = useAppStore((s) => s.openDrawer);
const scenarios = api.scenarios.list.useQuery( const scenarios = api.scenarios.list.useQuery(
{ experimentId: experimentId as string }, { experimentId: experimentId as string },
@@ -49,32 +40,7 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
}} }}
fontSize="sm" fontSize="sm"
> >
<GridItem <ScenariosHeader headerRows={headerRows} numScenarios={scenarios.data.length} />
display="flex"
alignItems="flex-end"
rowSpan={headerRows}
px={cellPadding.x}
py={cellPadding.y}
// TODO: This is a hack to get the sticky header to work. It's not ideal because it's not responsive to the height of the header,
// so if the header height changes, this will need to be updated.
sx={{ ...stickyHeaderStyle, top: "-337px" }}
>
<HStack w="100%">
<Heading size="xs" fontWeight="bold" flex={1}>
Scenarios ({scenarios.data.length})
</Heading>
<Button
size="xs"
variant="ghost"
color="gray.500"
aria-label="Edit"
leftIcon={<BsPencil />}
onClick={openDrawer}
>
Edit Vars
</Button>
</HStack>
</GridItem>
{variants.data.map((variant) => ( {variants.data.map((variant) => (
<GridItem key={variant.uiId} padding={0} sx={stickyHeaderStyle} borderTopWidth={1}> <GridItem key={variant.uiId} padding={0} sx={stickyHeaderStyle} borderTopWidth={1}>

View File

@@ -0,0 +1,8 @@
import { type SystemStyleObject } from "@chakra-ui/react";
/**
 * Shared Chakra `sx` style for header cells that should stay pinned while the
 * surrounding content scrolls: sticky positioning, an opaque white background
 * and a z-index of 1.
 */
export const stickyHeaderStyle: SystemStyleObject = {
position: "sticky",
// NOTE(review): presumably -1px hides the cell's top border once stuck — confirm.
top: "-1px",
backgroundColor: "#fff",
zIndex: 1,
};

View File

@@ -20,7 +20,6 @@ export const CostTooltip = ({
color="gray.800" color="gray.800"
bgColor="gray.50" bgColor="gray.50"
borderWidth={1} borderWidth={1}
py={2}
hasArrow hasArrow
shouldWrapChildren shouldWrapChildren
label={ label={

View File

@@ -6,18 +6,27 @@ import { ChakraProvider } from "@chakra-ui/react";
import theme from "~/utils/theme"; import theme from "~/utils/theme";
import Favicon from "~/components/Favicon"; import Favicon from "~/components/Favicon";
import "~/utils/analytics"; import "~/utils/analytics";
import Head from "next/head";
const MyApp: AppType<{ session: Session | null }> = ({ const MyApp: AppType<{ session: Session | null }> = ({
Component, Component,
pageProps: { session, ...pageProps }, pageProps: { session, ...pageProps },
}) => { }) => {
return ( return (
<>
<Head>
<meta
name="viewport"
content="width=device-width, initial-scale=1, maximum-scale=1, user-scalable=0"
/>
</Head>
<SessionProvider session={session}> <SessionProvider session={session}>
<Favicon /> <Favicon />
<ChakraProvider theme={theme}> <ChakraProvider theme={theme}>
<Component {...pageProps} /> <Component {...pageProps} />
</ChakraProvider> </ChakraProvider>
</SessionProvider> </SessionProvider>
</>
); );
}; };

View File

@@ -1,7 +1,7 @@
import { type CompletionCreateParams } from "openai/resources/chat"; import { type CompletionCreateParams } from "openai/resources/chat";
import { prisma } from "../db"; import { prisma } from "../db";
import { openai } from "../utils/openai"; import { openai } from "../utils/openai";
import { pick } from "lodash"; import { pick } from "lodash-es";
type AxiosError = { type AxiosError = {
response?: { response?: {

View File

@@ -1,8 +1,8 @@
import { EvaluationMatchType } from "@prisma/client"; import { EvalType } from "@prisma/client";
import { z } from "zod"; import { z } from "zod";
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc"; import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db"; import { prisma } from "~/server/db";
import { reevaluateEvaluation } from "~/server/utils/evaluations"; import { runAllEvals } from "~/server/utils/evaluations";
export const evaluationsRouter = createTRPCRouter({ export const evaluationsRouter = createTRPCRouter({
list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => { list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
@@ -18,21 +18,24 @@ export const evaluationsRouter = createTRPCRouter({
.input( .input(
z.object({ z.object({
experimentId: z.string(), experimentId: z.string(),
name: z.string(), label: z.string(),
matchString: z.string(), value: z.string(),
matchType: z.nativeEnum(EvaluationMatchType), evalType: z.nativeEnum(EvalType),
}), }),
) )
.mutation(async ({ input }) => { .mutation(async ({ input }) => {
const evaluation = await prisma.evaluation.create({ await prisma.evaluation.create({
data: { data: {
experimentId: input.experimentId, experimentId: input.experimentId,
name: input.name, label: input.label,
matchString: input.matchString, value: input.value,
matchType: input.matchType, evalType: input.evalType,
}, },
}); });
await reevaluateEvaluation(evaluation);
// TODO: this may be a bad UX for slow evals (eg. GPT-4 evals) Maybe need
// to kick off a background job or something instead
await runAllEvals(input.experimentId);
}), }),
update: publicProcedure update: publicProcedure
@@ -40,24 +43,30 @@ export const evaluationsRouter = createTRPCRouter({
z.object({ z.object({
id: z.string(), id: z.string(),
updates: z.object({ updates: z.object({
name: z.string().optional(), label: z.string().optional(),
matchString: z.string().optional(), value: z.string().optional(),
matchType: z.nativeEnum(EvaluationMatchType).optional(), evalType: z.nativeEnum(EvalType).optional(),
}), }),
}), }),
) )
.mutation(async ({ input }) => { .mutation(async ({ input }) => {
await prisma.evaluation.update({ const evaluation = await prisma.evaluation.update({
where: { id: input.id }, where: { id: input.id },
data: { data: {
name: input.updates.name, label: input.updates.label,
matchString: input.updates.matchString, value: input.updates.value,
matchType: input.updates.matchType, evalType: input.updates.evalType,
}, },
}); });
await reevaluateEvaluation(
await prisma.evaluation.findUniqueOrThrow({ where: { id: input.id } }), await prisma.outputEvaluation.deleteMany({
); where: {
evaluationId: evaluation.id,
},
});
// Re-run all evals. Other eval results will already be cached, so this
// should only re-run the updated one.
await runAllEvals(evaluation.experimentId);
}), }),
delete: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => { delete: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => {

View File

@@ -71,11 +71,28 @@ export const experimentsRouter = createTRPCRouter({
experimentId: exp.id, experimentId: exp.id,
label: "Prompt Variant 1", label: "Prompt Variant 1",
sortIndex: 0, sortIndex: 0,
constructFn: dedent`prompt = { // The interpolated $ is necessary until dedent incorporates
// https://github.com/dmnd/dedent/pull/46
constructFn: dedent`
/**
* Use Javascript to define an OpenAI chat completion
* (https://platform.openai.com/docs/api-reference/chat/create) and
* assign it to the \`prompt\` variable.
*
* You have access to the current scenario in the \`scenario\`
* variable.
*/
prompt = {
model: "gpt-3.5-turbo-0613", model: "gpt-3.5-turbo-0613",
stream: true, stream: true,
messages: [{ role: "system", content: ${"`Return '${scenario.text}'`"} }], messages: [
}`, {
role: "system",
content: \`"Return 'this is output for the scenario "${"$"}{scenario.text}"'\`,
},
],
};`,
model: "gpt-3.5-turbo-0613", model: "gpt-3.5-turbo-0613",
}, },
}), }),

View File

@@ -1,5 +1,5 @@
import dedent from "dedent"; import dedent from "dedent";
import { isObject } from "lodash"; import { isObject } from "lodash-es";
import { z } from "zod"; import { z } from "zod";
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc"; import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db"; import { prisma } from "~/server/db";
@@ -32,11 +32,43 @@ export const promptVariantsRouter = createTRPCRouter({
throw new Error(`Prompt Variant with id ${input.variantId} does not exist`); throw new Error(`Prompt Variant with id ${input.variantId} does not exist`);
} }
const evalResults = await prisma.evaluationResult.findMany({ const outputEvals = await prisma.outputEvaluation.groupBy({
where: { by: ["evaluationId"],
promptVariantId: input.variantId, _sum: {
result: true,
}, },
include: { evaluation: true }, _count: {
id: true,
},
where: {
modelOutput: {
scenarioVariantCell: {
promptVariant: {
id: input.variantId,
visible: true,
},
testScenario: {
visible: true,
},
},
},
},
});
const evals = await prisma.evaluation.findMany({
where: {
experimentId: variant.experimentId,
},
});
const evalResults = evals.map((evalItem) => {
const evalResult = outputEvals.find((outputEval) => outputEval.evaluationId === evalItem.id);
return {
id: evalItem.id,
label: evalItem.label,
passCount: evalResult?._sum?.result ?? 0,
totalCount: evalResult?._count?.id ?? 1,
};
}); });
const scenarioCount = await prisma.testScenario.count({ const scenarioCount = await prisma.testScenario.count({
@@ -50,7 +82,7 @@ export const promptVariantsRouter = createTRPCRouter({
promptVariantId: input.variantId, promptVariantId: input.variantId,
testScenario: { visible: true }, testScenario: { visible: true },
modelOutput: { modelOutput: {
isNot: null, is: {},
}, },
}, },
}); });
@@ -77,7 +109,26 @@ export const promptVariantsRouter = createTRPCRouter({
const overallCost = overallPromptCost + overallCompletionCost; const overallCost = overallPromptCost + overallCompletionCost;
return { evalResults, promptTokens, completionTokens, overallCost, scenarioCount, outputCount }; const awaitingRetrievals = !!(await prisma.scenarioVariantCell.findFirst({
where: {
promptVariantId: input.variantId,
testScenario: { visible: true },
// Check if is PENDING or IN_PROGRESS
retrievalStatus: {
in: ["PENDING", "IN_PROGRESS"],
},
},
}));
return {
evalResults,
promptTokens,
completionTokens,
overallCost,
scenarioCount,
outputCount,
awaitingRetrievals,
};
}), }),
create: publicProcedure create: publicProcedure

View File

@@ -21,7 +21,17 @@ export const scenarioVariantCellsRouter = createTRPCRouter({
}, },
}, },
include: { include: {
modelOutput: true, modelOutput: {
include: {
outputEvaluation: {
include: {
evaluation: {
select: { label: true },
},
},
},
},
},
}, },
}); });
}), }),

View File

@@ -3,7 +3,7 @@ import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db"; import { prisma } from "~/server/db";
import { autogenerateScenarioValues } from "../autogen"; import { autogenerateScenarioValues } from "../autogen";
import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated"; import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
import { reevaluateAll } from "~/server/utils/evaluations"; import { runAllEvals } from "~/server/utils/evaluations";
import { generateNewCell } from "~/server/utils/generateNewCell"; import { generateNewCell } from "~/server/utils/generateNewCell";
export const scenariosRouter = createTRPCRouter({ export const scenariosRouter = createTRPCRouter({
@@ -73,7 +73,7 @@ export const scenariosRouter = createTRPCRouter({
}); });
// Reevaluate all evaluations now that this scenario is hidden // Reevaluate all evaluations now that this scenario is hidden
await reevaluateAll(hiddenScenario.experimentId); await runAllEvals(hiddenScenario.experimentId);
return hiddenScenario; return hiddenScenario;
}), }),

View File

@@ -6,9 +6,10 @@ import { type JSONSerializable } from "../types";
import { sleep } from "../utils/sleep"; import { sleep } from "../utils/sleep";
import { shouldStream } from "../utils/shouldStream"; import { shouldStream } from "../utils/shouldStream";
import { generateChannel } from "~/utils/generateChannel"; import { generateChannel } from "~/utils/generateChannel";
import { reevaluateVariant } from "../utils/evaluations"; import { runEvalsForOutput } from "../utils/evaluations";
import { constructPrompt } from "../utils/constructPrompt"; import { constructPrompt } from "../utils/constructPrompt";
import { type CompletionCreateParams } from "openai/resources/chat"; import { type CompletionCreateParams } from "openai/resources/chat";
import { type Prisma } from "@prisma/client";
const MAX_AUTO_RETRIES = 10; const MAX_AUTO_RETRIES = 10;
const MIN_DELAY = 500; // milliseconds const MIN_DELAY = 500; // milliseconds
@@ -25,11 +26,10 @@ const getCompletionWithRetries = async (
payload: JSONSerializable, payload: JSONSerializable,
channel?: string, channel?: string,
): Promise<CompletionResponse> => { ): Promise<CompletionResponse> => {
let modelResponse: CompletionResponse | null = null;
try {
for (let i = 0; i < MAX_AUTO_RETRIES; i++) { for (let i = 0; i < MAX_AUTO_RETRIES; i++) {
const modelResponse = await getCompletion( modelResponse = await getCompletion(payload as unknown as CompletionCreateParams, channel);
payload as unknown as CompletionCreateParams,
channel,
);
if (modelResponse.statusCode !== 429 || i === MAX_AUTO_RETRIES - 1) { if (modelResponse.statusCode !== 429 || i === MAX_AUTO_RETRIES - 1) {
return modelResponse; return modelResponse;
} }
@@ -46,6 +46,14 @@ const getCompletionWithRetries = async (
await sleep(delay); await sleep(delay);
} }
throw new Error("Max retries limit reached"); throw new Error("Max retries limit reached");
} catch (error: unknown) {
return {
statusCode: modelResponse?.statusCode ?? 500,
errorMessage: modelResponse?.errorMessage ?? (error as Error).message,
output: null as unknown as Prisma.InputJsonValue,
timeToComplete: 0,
};
}
}; };
export type queryLLMJob = { export type queryLLMJob = {
@@ -140,5 +148,7 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
}, },
}); });
await reevaluateVariant(cell.promptVariantId); if (modelOutput) {
await runEvalsForOutput(variant.experimentId, scenario, modelOutput);
}
}); });

View File

@@ -1,31 +0,0 @@
import { type Evaluation, type ModelOutput, type TestScenario } from "@prisma/client";
import { type ChatCompletion } from "openai/resources/chat";
import { type VariableMap, fillTemplate } from "./fillTemplate";
/**
 * Checks a single model output against one evaluation rule.
 *
 * The evaluation's `matchString` is first filled in with the scenario's
 * variable values, then matched against the first choice's message content
 * (or, when there is no content, the JSON-serialised function call).
 *
 * @returns `false` when the output has no message; otherwise whether the
 * CONTAINS / DOES_NOT_CONTAIN rule is satisfied.
 */
export const evaluateOutput = (
  modelOutput: ModelOutput,
  scenario: TestScenario,
  evaluation: Evaluation,
): boolean => {
  const completion = modelOutput.output as unknown as ChatCompletion;
  const firstMessage = completion?.choices?.[0]?.message;
  if (!firstMessage) return false;

  const text = firstMessage.content ?? JSON.stringify(firstMessage.function_call);
  const pattern = fillTemplate(evaluation.matchString, scenario.variableValues as VariableMap);

  switch (evaluation.matchType) {
    case "CONTAINS":
      return text.match(pattern) !== null;
    case "DOES_NOT_CONTAIN":
      return text.match(pattern) === null;
  }
};

View File

@@ -1,105 +1,79 @@
import { type ModelOutput, type Evaluation } from "@prisma/client"; import { type ModelOutput, type Evaluation } from "@prisma/client";
import { prisma } from "../db"; import { prisma } from "../db";
import { evaluateOutput } from "./evaluateOutput"; import { runOneEval } from "./runOneEval";
import { type Scenario } from "~/components/OutputsTable/types";
export const reevaluateVariant = async (variantId: string) => { const saveResult = async (evaluation: Evaluation, scenario: Scenario, modelOutput: ModelOutput) => {
const variant = await prisma.promptVariant.findUnique({ const result = await runOneEval(evaluation, scenario, modelOutput);
where: { id: variantId }, return await prisma.outputEvaluation.upsert({
});
if (!variant) return;
const evaluations = await prisma.evaluation.findMany({
where: { experimentId: variant.experimentId },
});
const cells = await prisma.scenarioVariantCell.findMany({
where: { where: {
promptVariantId: variantId, modelOutputId_evaluationId: {
retrievalStatus: "COMPLETE", modelOutputId: modelOutput.id,
testScenario: { visible: true },
modelOutput: { isNot: null },
},
include: { testScenario: true, modelOutput: true },
});
await Promise.all(
evaluations.map(async (evaluation) => {
const passCount = cells.filter((cell) =>
evaluateOutput(cell.modelOutput as ModelOutput, cell.testScenario, evaluation),
).length;
const failCount = cells.length - passCount;
await prisma.evaluationResult.upsert({
where: {
evaluationId_promptVariantId: {
evaluationId: evaluation.id, evaluationId: evaluation.id,
promptVariantId: variantId,
}, },
}, },
create: { create: {
modelOutputId: modelOutput.id,
evaluationId: evaluation.id, evaluationId: evaluation.id,
promptVariantId: variantId, ...result,
passCount,
failCount,
}, },
update: { update: {
passCount, ...result,
failCount,
}, },
}); });
}),
);
}; };
export const reevaluateEvaluation = async (evaluation: Evaluation) => { export const runEvalsForOutput = async (
const variants = await prisma.promptVariant.findMany({ experimentId: string,
where: { experimentId: evaluation.experimentId, visible: true }, scenario: Scenario,
}); modelOutput: ModelOutput,
) => {
const cells = await prisma.scenarioVariantCell.findMany({
where: {
promptVariantId: { in: variants.map((v) => v.id) },
testScenario: { visible: true },
statusCode: { notIn: [429] },
modelOutput: { isNot: null },
},
include: { testScenario: true, modelOutput: true },
});
await Promise.all(
variants.map(async (variant) => {
const variantCells = cells.filter((cell) => cell.promptVariantId === variant.id);
const passCount = variantCells.filter((cell) =>
evaluateOutput(cell.modelOutput as ModelOutput, cell.testScenario, evaluation),
).length;
const failCount = variantCells.length - passCount;
await prisma.evaluationResult.upsert({
where: {
evaluationId_promptVariantId: {
evaluationId: evaluation.id,
promptVariantId: variant.id,
},
},
create: {
evaluationId: evaluation.id,
promptVariantId: variant.id,
passCount,
failCount,
},
update: {
passCount,
failCount,
},
});
}),
);
};
export const reevaluateAll = async (experimentId: string) => {
const evaluations = await prisma.evaluation.findMany({ const evaluations = await prisma.evaluation.findMany({
where: { experimentId }, where: { experimentId },
}); });
await Promise.all(evaluations.map(reevaluateEvaluation)); await Promise.all(
evaluations.map(async (evaluation) => await saveResult(evaluation, scenario, modelOutput)),
);
};
export const runAllEvals = async (experimentId: string) => {
const outputs = await prisma.modelOutput.findMany({
where: {
scenarioVariantCell: {
promptVariant: {
experimentId,
visible: true,
},
testScenario: {
visible: true,
},
},
},
include: {
scenarioVariantCell: {
include: {
testScenario: true,
},
},
outputEvaluation: true,
},
});
const evals = await prisma.evaluation.findMany({
where: { experimentId },
});
await Promise.all(
outputs.map(async (output) => {
const unrunEvals = evals.filter(
(evaluation) => !output.outputEvaluation.find((e) => e.evaluationId === evaluation.id),
);
await Promise.all(
unrunEvals.map(async (evaluation) => {
await saveResult(evaluation, output.scenarioVariantCell.testScenario, output);
}),
);
}),
);
}; };

View File

@@ -2,6 +2,16 @@ import { type JSONSerializable } from "../types";
export type VariableMap = Record<string, string>; export type VariableMap = Record<string, string>;
// Backslash-escape every double quote that is not already preceded by a
// backslash, to match the way we encode JSON. Quotes that are already
// escaped (\") are left untouched.
export function escapeQuotes(str: string) {
  const quotePattern = /(\\")|"/g;
  return str.replace(quotePattern, (whole: string, alreadyEscaped?: string) => {
    if (alreadyEscaped) {
      return whole;
    }
    return '\\"';
  });
}
// Prefix every regex metacharacter in `str` with a backslash so the result
// can be used as a pattern that matches `str` literally.
export function escapeRegExp(str: string) {
  const specialChars = /[.*+\-?^${}()|[\]\\]/g;
  return str.replace(specialChars, (ch: string) => "\\" + ch);
}
export function fillTemplate(template: string, variables: VariableMap): string { export function fillTemplate(template: string, variables: VariableMap): string {
return template.replace(/{{\s*(\w+)\s*}}/g, (_, key: string) => variables[key] || ""); return template.replace(/{{\s*(\w+)\s*}}/g, (_, key: string) => variables[key] || "");
} }

View File

@@ -1,5 +1,5 @@
/* eslint-disable @typescript-eslint/no-unsafe-call */ /* eslint-disable @typescript-eslint/no-unsafe-call */
import { isObject } from "lodash"; import { isObject } from "lodash-es";
import { Prisma } from "@prisma/client"; import { Prisma } from "@prisma/client";
import { streamChatCompletion } from "./openai"; import { streamChatCompletion } from "./openai";
import { wsConnection } from "~/utils/wsConnection"; import { wsConnection } from "~/utils/wsConnection";

View File

@@ -1,4 +1,4 @@
import { omit } from "lodash"; import { omit } from "lodash-es";
import { env } from "~/env.mjs"; import { env } from "~/env.mjs";
import OpenAI from "openai"; import OpenAI from "openai";

View File

@@ -0,0 +1,95 @@
import { type Evaluation, type ModelOutput, type TestScenario } from "@prisma/client";
import { type ChatCompletion } from "openai/resources/chat";
import { type VariableMap, fillTemplate, escapeRegExp, escapeQuotes } from "./fillTemplate";
import { openai } from "./openai";
import dedent from "dedent";
/**
 * Interprets the JSON arguments of the forced `report_success` function call.
 *
 * The arguments are parsed as `unknown` and narrowed explicitly rather than
 * trusted as `any`:
 * - only a literal boolean `true` in `success` counts as a pass, so a
 *   malformed value such as the string "false" is not treated as truthy;
 * - `thoughts` is used as the details only when it is a string; otherwise the
 *   whole parsed payload is serialised so `details` is always a string.
 *
 * @throws When the arguments are missing, not valid JSON, or not a JSON object.
 */
const parseGpt4Verdict = (
  rawArguments: string | undefined,
): { result: number; details: string } => {
  // An absent function call becomes "", which JSON.parse rejects.
  const parsed: unknown = JSON.parse(rawArguments ?? "");
  if (typeof parsed !== "object" || parsed === null) {
    throw new Error("GPT-4 function call arguments are not a JSON object");
  }

  const { success, thoughts } = parsed as { success?: unknown; thoughts?: unknown };
  return {
    result: success === true ? 1 : 0,
    details: typeof thoughts === "string" ? thoughts : JSON.stringify(parsed),
  };
};

/**
 * Asks GPT-4 to judge whether a model's message satisfies an evaluation.
 *
 * The evaluation's `value` is embedded in the system prompt as the success
 * criteria; the scenario's variable values and the message (its content, or
 * its function call when there is no content) are sent as user messages.
 * GPT-4 is forced to answer through the `report_success` function.
 *
 * @param evaluation Evaluation whose `value` describes what counts as success.
 * @param scenario Scenario whose `variableValues` are shown to GPT-4.
 * @param message The message being judged.
 * @returns `result` 1 for a pass and 0 for a fail, with GPT-4's reasoning in
 * `details`. If the response cannot be interpreted, the error is logged and
 * a failing result with a fixed details message is returned. Errors thrown by
 * the completion request itself are not caught here.
 */
export const runGpt4Eval = async (
  evaluation: Evaluation,
  scenario: TestScenario,
  message: ChatCompletion.Choice.Message,
): Promise<{ result: number; details: string }> => {
  const output = await openai.chat.completions.create({
    model: "gpt-4-0613",
    messages: [
      {
        role: "system",
        content: dedent`
          You are a highly intelligent AI model and have been tasked with evaluating the quality of a simpler model. Your objective is to determine whether the simpler model has produced a successful and correct output. You should return "true" if the output was successful and "false" if it was not. Pay more attention to the semantics of the output than the formatting. Success is defined in the following terms:
          ---
          ${evaluation.value}
        `,
      },
      {
        role: "user",
        content: `Scenario:\n---\n${JSON.stringify(scenario.variableValues, null, 2)}`,
      },
      {
        role: "user",
        content: `The full output of the simpler message:\n---\n${JSON.stringify(
          message.content ?? message.function_call,
          null,
          2,
        )}`,
      },
    ],
    // Force the model to answer via the structured function below.
    function_call: {
      name: "report_success",
    },
    functions: [
      {
        name: "report_success",
        parameters: {
          type: "object",
          required: ["thoughts", "success"],
          properties: {
            thoughts: {
              type: "string",
              description: "Explain your reasoning for considering this a pass or fail",
            },
            success: {
              type: "boolean",
              description:
                "Whether the simpler model successfully completed the task for this scenario",
            },
          },
        },
      },
    ],
  });

  try {
    return parseGpt4Verdict(output.choices[0]?.message?.function_call?.arguments);
  } catch (e) {
    console.error(e);
    return { result: 0, details: "Error parsing GPT-4 output" };
  }
};
/**
 * Runs a single evaluation against a single model output.
 *
 * For CONTAINS / DOES_NOT_CONTAIN the evaluation's `value` is quote-escaped,
 * filled in with the scenario's variable values, regex-escaped, and then
 * matched against the first choice's message content (or, when there is no
 * content, the JSON-serialised function call). GPT4_EVAL is delegated to
 * `runGpt4Eval`.
 *
 * @returns `result` 1 on pass and 0 on fail (0 when the output has no
 * message); `details` is only present for GPT-4 evaluations.
 */
export const runOneEval = async (
  evaluation: Evaluation,
  scenario: TestScenario,
  modelOutput: ModelOutput,
): Promise<{ result: number; details?: string }> => {
  const completion = modelOutput.output as unknown as ChatCompletion;
  const firstMessage = completion?.choices?.[0]?.message;
  if (!firstMessage) {
    return { result: 0 };
  }

  const text = firstMessage.content ?? JSON.stringify(firstMessage.function_call);

  const quotedTemplate = escapeQuotes(evaluation.value);
  const filledTemplate = fillTemplate(quotedTemplate, scenario.variableValues as VariableMap);
  const pattern = escapeRegExp(filledTemplate);

  switch (evaluation.evalType) {
    case "CONTAINS": {
      const found = text.match(pattern) !== null;
      return { result: found ? 1 : 0 };
    }
    case "DOES_NOT_CONTAIN": {
      const absent = text.match(pattern) === null;
      return { result: absent ? 1 : 0 };
    }
    case "GPT4_EVAL":
      return await runGpt4Eval(evaluation, scenario, firstMessage);
  }
};

View File

@@ -1,4 +1,4 @@
import { isObject } from "lodash"; import { isObject } from "lodash-es";
import { type JSONSerializable } from "../types"; import { type JSONSerializable } from "../types";
export const shouldStream = (config: JSONSerializable): boolean => { export const shouldStream = (config: JSONSerializable): boolean => {

View File

@@ -1,5 +1,5 @@
import { useRouter } from "next/router"; import { useRouter } from "next/router";
import { useCallback, useEffect, useState } from "react"; import { type RefObject, useCallback, useEffect, useRef, useState } from "react";
import { api } from "~/utils/api"; import { api } from "~/utils/api";
export const useExperiment = () => { export const useExperiment = () => {
@@ -49,3 +49,43 @@ export const useModifierKeyLabel = () => {
}, []); }, []);
return label; return label;
}; };
/**
 * Rectangle measurements reported by `useElementDimensions`: the four edge
 * positions, the size, and the x/y origin of the measured box.
 */
interface Dimensions {
left: number;
top: number;
right: number;
bottom: number;
width: number;
height: number;
x: number;
y: number;
}
// Get the dimensions of an element. Attach the returned ref to the element;
// the second tuple entry is undefined until the ResizeObserver reports a
// size, and is then updated with the content rect on every resize.
export const useElementDimensions = (): [RefObject<HTMLElement>, Dimensions | undefined] => {
  const ref = useRef<HTMLElement>(null);
  const [dimensions, setDimensions] = useState<Dimensions | undefined>();

  useEffect(() => {
    const target = ref.current;
    // Nothing to observe if the ref was not attached when the effect ran.
    if (!target) return;

    const observer = new ResizeObserver((entries) => {
      for (const entry of entries) {
        setDimensions(entry.contentRect);
      }
    });
    observer.observe(target);

    // Stop observing when the component unmounts.
    return () => {
      observer.unobserve(target);
    };
  }, []);

  return [ref, dimensions];
};