Compare commits
15 Commits
save-butto
...
fix-pretti
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
999a4c08fa | ||
|
|
374d0237ee | ||
|
|
b1f873623d | ||
|
|
4131aa67d0 | ||
|
|
8e7a6d3ae2 | ||
|
|
7d41e94ca2 | ||
|
|
011b12abb9 | ||
|
|
1ba18015bc | ||
|
|
54369dba54 | ||
|
|
6b84a59372 | ||
|
|
8db8aeacd3 | ||
|
|
64bd71e370 | ||
|
|
ca21a7af06 | ||
|
|
3b99b7bd2b | ||
|
|
0c3bdbe4f2 |
2
.github/workflows/ci.yaml
vendored
2
.github/workflows/ci.yaml
vendored
@@ -3,6 +3,8 @@ name: CI checks
|
||||
on:
|
||||
pull_request:
|
||||
branches: [main]
|
||||
push:
|
||||
branches: [main]
|
||||
|
||||
jobs:
|
||||
run-checks:
|
||||
|
||||
@@ -16,7 +16,8 @@
|
||||
"postinstall": "prisma generate",
|
||||
"lint": "next lint",
|
||||
"start": "next start",
|
||||
"codegen": "tsx src/codegen/export-openai-types.ts"
|
||||
"codegen": "tsx src/codegen/export-openai-types.ts",
|
||||
"seed": "tsx prisma/seed.ts"
|
||||
},
|
||||
"dependencies": {
|
||||
"@babel/preset-typescript": "^7.22.5",
|
||||
@@ -49,7 +50,7 @@
|
||||
"immer": "^10.0.2",
|
||||
"isolated-vm": "^4.5.0",
|
||||
"json-stringify-pretty-compact": "^4.0.0",
|
||||
"lodash": "^4.17.21",
|
||||
"lodash-es": "^4.17.21",
|
||||
"next": "^13.4.2",
|
||||
"next-auth": "^4.22.1",
|
||||
"nextjs-routes": "^2.0.1",
|
||||
@@ -77,7 +78,7 @@
|
||||
"@types/cors": "^2.8.13",
|
||||
"@types/eslint": "^8.37.0",
|
||||
"@types/express": "^4.17.17",
|
||||
"@types/lodash": "^4.14.195",
|
||||
"@types/lodash-es": "^4.17.8",
|
||||
"@types/node": "^18.16.0",
|
||||
"@types/pluralize": "^0.0.30",
|
||||
"@types/react": "^18.2.6",
|
||||
@@ -100,6 +101,6 @@
|
||||
"initVersion": "7.14.0"
|
||||
},
|
||||
"prisma": {
|
||||
"seed": "tsx prisma/seed.ts"
|
||||
"seed": "pnpm seed"
|
||||
}
|
||||
}
|
||||
|
||||
18
pnpm-lock.yaml
generated
18
pnpm-lock.yaml
generated
@@ -95,7 +95,7 @@ dependencies:
|
||||
json-stringify-pretty-compact:
|
||||
specifier: ^4.0.0
|
||||
version: 4.0.0
|
||||
lodash:
|
||||
lodash-es:
|
||||
specifier: ^4.17.21
|
||||
version: 4.17.21
|
||||
next:
|
||||
@@ -175,9 +175,9 @@ devDependencies:
|
||||
'@types/express':
|
||||
specifier: ^4.17.17
|
||||
version: 4.17.17
|
||||
'@types/lodash':
|
||||
specifier: ^4.14.195
|
||||
version: 4.14.195
|
||||
'@types/lodash-es':
|
||||
specifier: ^4.17.8
|
||||
version: 4.17.8
|
||||
'@types/node':
|
||||
specifier: ^18.16.0
|
||||
version: 18.16.0
|
||||
@@ -2753,6 +2753,12 @@ packages:
|
||||
resolution: {integrity: sha512-dRLjCWHYg4oaA77cxO64oO+7JwCwnIzkZPdrrC71jQmQtlhM556pwKo5bUzqvZndkVbeFLIIi+9TC40JNF5hNQ==}
|
||||
dev: true
|
||||
|
||||
/@types/lodash-es@4.17.8:
|
||||
resolution: {integrity: sha512-euY3XQcZmIzSy7YH5+Unb3b2X12Wtk54YWINBvvGQ5SmMvwb11JQskGsfkH/5HXK77Kr8GF0wkVDIxzAisWtog==}
|
||||
dependencies:
|
||||
'@types/lodash': 4.14.195
|
||||
dev: true
|
||||
|
||||
/@types/lodash.mergewith@4.6.7:
|
||||
resolution: {integrity: sha512-3m+lkO5CLRRYU0fhGRp7zbsGi6+BZj0uTVSwvcKU+nSlhjA9/QRNfuSGnD2mX6hQA7ZbmcCkzk5h4ZYGOtk14A==}
|
||||
dependencies:
|
||||
@@ -5379,6 +5385,10 @@ packages:
|
||||
p-locate: 5.0.0
|
||||
dev: true
|
||||
|
||||
/lodash-es@4.17.21:
|
||||
resolution: {integrity: sha512-mKnC+QJ9pWVzv+C4/U3rRsHapFfHvQFoFB92e52xeyGMcX6/OlIl78je1u8vePzYZSkkogMPJ2yjxxsb89cxyw==}
|
||||
dev: false
|
||||
|
||||
/lodash.merge@4.6.2:
|
||||
resolution: {integrity: sha512-0KpjqXRVvrYyCsX1swR/XTK0va6VQkQM6MNo7PqW77ByjAhoARA8EfrP1N4+KlKj8YS0ZUCtRT/YUuhyYDujIQ==}
|
||||
dev: true
|
||||
|
||||
@@ -0,0 +1,2 @@
|
||||
-- AlterTable
|
||||
ALTER TABLE "PromptVariant" ALTER COLUMN "model" DROP DEFAULT;
|
||||
24
prisma/migrations/20230717203031_add_gpt4_eval/migration.sql
Normal file
24
prisma/migrations/20230717203031_add_gpt4_eval/migration.sql
Normal file
@@ -0,0 +1,24 @@
|
||||
/*
|
||||
Warnings:
|
||||
|
||||
- You are about to rename the column `matchString` on the `Evaluation` table. If there is any code or views referring to the old name, they will break.
|
||||
- You are about to rename the column `matchType` on the `Evaluation` table. If there is any code or views referring to the old name, they will break.
|
||||
- You are about to rename the column `name` on the `Evaluation` table. If there is any code or views referring to the old name, they will break.
|
||||
- You are about to rename the enum `EvaluationMatchType` to `EvalType`. If there is any code or views referring to the old name, they will break.
|
||||
*/
|
||||
|
||||
-- RenameEnum
|
||||
ALTER TYPE "EvaluationMatchType" RENAME TO "EvalType";
|
||||
|
||||
-- AlterTable
|
||||
ALTER TABLE "Evaluation" RENAME COLUMN "matchString" TO "value";
|
||||
ALTER TABLE "Evaluation" RENAME COLUMN "matchType" TO "evalType";
|
||||
ALTER TABLE "Evaluation" RENAME COLUMN "name" TO "label";
|
||||
|
||||
-- AlterColumnType
|
||||
ALTER TABLE "Evaluation" ALTER COLUMN "evalType" TYPE "EvalType" USING "evalType"::text::"EvalType";
|
||||
|
||||
-- SetNotNullConstraint
|
||||
ALTER TABLE "Evaluation" ALTER COLUMN "evalType" SET NOT NULL;
|
||||
ALTER TABLE "Evaluation" ALTER COLUMN "label" SET NOT NULL;
|
||||
ALTER TABLE "Evaluation" ALTER COLUMN "value" SET NOT NULL;
|
||||
@@ -0,0 +1,39 @@
|
||||
/*
|
||||
Warnings:
|
||||
|
||||
- You are about to drop the `EvaluationResult` table. If the table is not empty, all the data it contains will be lost.
|
||||
|
||||
*/
|
||||
-- AlterEnum
|
||||
ALTER TYPE "EvalType" ADD VALUE 'GPT4_EVAL';
|
||||
|
||||
-- DropForeignKey
|
||||
ALTER TABLE "EvaluationResult" DROP CONSTRAINT "EvaluationResult_evaluationId_fkey";
|
||||
|
||||
-- DropForeignKey
|
||||
ALTER TABLE "EvaluationResult" DROP CONSTRAINT "EvaluationResult_promptVariantId_fkey";
|
||||
|
||||
-- DropTable
|
||||
DROP TABLE "EvaluationResult";
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "OutputEvaluation" (
|
||||
"id" UUID NOT NULL,
|
||||
"result" DOUBLE PRECISION NOT NULL,
|
||||
"details" TEXT,
|
||||
"modelOutputId" UUID NOT NULL,
|
||||
"evaluationId" UUID NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||
|
||||
CONSTRAINT "OutputEvaluation_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "OutputEvaluation_modelOutputId_evaluationId_key" ON "OutputEvaluation"("modelOutputId", "evaluationId");
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "OutputEvaluation" ADD CONSTRAINT "OutputEvaluation_modelOutputId_fkey" FOREIGN KEY ("modelOutputId") REFERENCES "ModelOutput"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "OutputEvaluation" ADD CONSTRAINT "OutputEvaluation_evaluationId_fkey" FOREIGN KEY ("evaluationId") REFERENCES "Evaluation"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
@@ -3,7 +3,6 @@
|
||||
|
||||
generator client {
|
||||
provider = "prisma-client-js"
|
||||
previewFeatures = ["jsonProtocol"]
|
||||
}
|
||||
|
||||
datasource db {
|
||||
@@ -30,7 +29,7 @@ model PromptVariant {
|
||||
|
||||
label String
|
||||
constructFn String
|
||||
model String @default("gpt-3.5-turbo")
|
||||
model String
|
||||
|
||||
uiId String @default(uuid()) @db.Uuid
|
||||
visible Boolean @default(true)
|
||||
@@ -42,7 +41,6 @@ model PromptVariant {
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
scenarioVariantCells ScenarioVariantCell[]
|
||||
EvaluationResult EvaluationResult[]
|
||||
|
||||
@@index([uiId])
|
||||
}
|
||||
@@ -125,47 +123,50 @@ model ModelOutput {
|
||||
|
||||
scenarioVariantCellId String @db.Uuid
|
||||
scenarioVariantCell ScenarioVariantCell @relation(fields: [scenarioVariantCellId], references: [id], onDelete: Cascade)
|
||||
outputEvaluation OutputEvaluation[]
|
||||
|
||||
@@unique([scenarioVariantCellId])
|
||||
@@index([inputHash])
|
||||
}
|
||||
|
||||
enum EvaluationMatchType {
|
||||
enum EvalType {
|
||||
CONTAINS
|
||||
DOES_NOT_CONTAIN
|
||||
GPT4_EVAL
|
||||
}
|
||||
|
||||
model Evaluation {
|
||||
id String @id @default(uuid()) @db.Uuid
|
||||
|
||||
name String
|
||||
matchString String
|
||||
matchType EvaluationMatchType
|
||||
label String
|
||||
evalType EvalType
|
||||
value String
|
||||
|
||||
experimentId String @db.Uuid
|
||||
experiment Experiment @relation(fields: [experimentId], references: [id], onDelete: Cascade)
|
||||
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
EvaluationResult EvaluationResult[]
|
||||
OutputEvaluation OutputEvaluation[]
|
||||
}
|
||||
|
||||
model EvaluationResult {
|
||||
model OutputEvaluation {
|
||||
id String @id @default(uuid()) @db.Uuid
|
||||
|
||||
passCount Int
|
||||
failCount Int
|
||||
// Number between 0 (fail) and 1 (pass)
|
||||
result Float
|
||||
details String?
|
||||
|
||||
modelOutputId String @db.Uuid
|
||||
modelOutput ModelOutput @relation(fields: [modelOutputId], references: [id], onDelete: Cascade)
|
||||
|
||||
evaluationId String @db.Uuid
|
||||
evaluation Evaluation @relation(fields: [evaluationId], references: [id], onDelete: Cascade)
|
||||
|
||||
promptVariantId String @db.Uuid
|
||||
promptVariant PromptVariant @relation(fields: [promptVariantId], references: [id], onDelete: Cascade)
|
||||
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
@@unique([evaluationId, promptVariantId])
|
||||
@@unique([modelOutputId, evaluationId])
|
||||
}
|
||||
|
||||
// Necessary for Next auth
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import { prisma } from "~/server/db";
|
||||
import dedent from "dedent";
|
||||
import { generateNewCell } from "~/server/utils/generateNewCell";
|
||||
|
||||
const experimentId = "11111111-1111-1111-1111-111111111111";
|
||||
|
||||
@@ -9,7 +11,7 @@ await prisma.experiment.deleteMany({
|
||||
},
|
||||
});
|
||||
|
||||
const experiment = await prisma.experiment.create({
|
||||
await prisma.experiment.create({
|
||||
data: {
|
||||
id: experimentId,
|
||||
label: "Country Capitals Example",
|
||||
@@ -36,9 +38,16 @@ await prisma.promptVariant.createMany({
|
||||
experimentId,
|
||||
label: "Prompt Variant 1",
|
||||
sortIndex: 0,
|
||||
constructFn: `prompt = {
|
||||
model: "gpt-3.5-turbo-0613",
|
||||
messages: [{ role: "user", content: "What is the capital of {{country}}?" }],
|
||||
constructFn: dedent`
|
||||
prompt = {
|
||||
model: "gpt-3.5-turbo-0613",
|
||||
messages: [
|
||||
{
|
||||
role: "user",
|
||||
content: \`What is the capital of ${"$"}{scenario.country}?\`
|
||||
}
|
||||
],
|
||||
temperature: 0,
|
||||
}`,
|
||||
},
|
||||
@@ -46,14 +55,15 @@ await prisma.promptVariant.createMany({
|
||||
experimentId,
|
||||
label: "Prompt Variant 2",
|
||||
sortIndex: 1,
|
||||
constructFn: `prompt = {
|
||||
model: "gpt-3.5-turbo-0613",
|
||||
constructFn: dedent`
|
||||
prompt = {
|
||||
model: "gpt-3.5-turbo-0613",
|
||||
messages: [
|
||||
{
|
||||
role: "user",
|
||||
content:
|
||||
"What is the capital of {{country}}? Return just the city name and nothing else.",
|
||||
},
|
||||
content: \`What is the capital of ${"$"}{scenario.country}? Return just the city name and nothing else.\`
|
||||
}
|
||||
],
|
||||
temperature: 0,
|
||||
}`,
|
||||
@@ -107,3 +117,26 @@ await prisma.testScenario.createMany({
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
const variants = await prisma.promptVariant.findMany({
|
||||
where: {
|
||||
experimentId,
|
||||
},
|
||||
});
|
||||
|
||||
const scenarios = await prisma.testScenario.findMany({
|
||||
where: {
|
||||
experimentId,
|
||||
},
|
||||
});
|
||||
|
||||
await Promise.all(
|
||||
variants
|
||||
.flatMap((variant) =>
|
||||
scenarios.map((scenario) => ({
|
||||
promptVariantId: variant.id,
|
||||
testScenarioId: scenario.id,
|
||||
})),
|
||||
)
|
||||
.map((cell) => generateNewCell(cell.promptVariantId, cell.testScenarioId)),
|
||||
);
|
||||
|
||||
@@ -12,6 +12,7 @@ await prisma.promptVariant.createMany({
|
||||
{
|
||||
experimentId: functionCallsExperiment.id,
|
||||
label: "No Fn Calls",
|
||||
model: "gpt-3.5-turbo-0613",
|
||||
constructFn: `prompt = {
|
||||
model: "gpt-3.5-turbo-0613",
|
||||
messages: [
|
||||
@@ -30,6 +31,7 @@ await prisma.promptVariant.createMany({
|
||||
{
|
||||
experimentId: functionCallsExperiment.id,
|
||||
label: "Fn Calls",
|
||||
model: "gpt-3.5-turbo-0613",
|
||||
constructFn: `prompt = {
|
||||
model: "gpt-3.5-turbo-0613",
|
||||
messages: [
|
||||
@@ -92,6 +94,7 @@ await prisma.promptVariant.createMany({
|
||||
experimentId: redditExperiment.id,
|
||||
label: "3.5 Base",
|
||||
sortIndex: 0,
|
||||
model: "gpt-3.5-turbo-0613",
|
||||
constructFn: `prompt = {
|
||||
model: "gpt-3.5-turbo-0613",
|
||||
messages: [
|
||||
@@ -107,6 +110,7 @@ await prisma.promptVariant.createMany({
|
||||
experimentId: redditExperiment.id,
|
||||
label: "4 Base",
|
||||
sortIndex: 1,
|
||||
model: "gpt-3.5-turbo-0613",
|
||||
constructFn: `prompt = {
|
||||
model: "gpt-4-0613",
|
||||
messages: [
|
||||
@@ -122,6 +126,7 @@ await prisma.promptVariant.createMany({
|
||||
experimentId: redditExperiment.id,
|
||||
label: "3.5 CoT + Functions",
|
||||
sortIndex: 2,
|
||||
model: "gpt-3.5-turbo-0613",
|
||||
constructFn: `prompt = {
|
||||
model: "gpt-3.5-turbo-0613",
|
||||
messages: [
|
||||
@@ -178,9 +183,9 @@ await prisma.templateVariable.createMany({
|
||||
await prisma.evaluation.create({
|
||||
data: {
|
||||
experimentId: redditExperiment.id,
|
||||
name: "Relevance Accuracy",
|
||||
matchType: "CONTAINS",
|
||||
matchString: '"{{relevance}}"',
|
||||
label: "Relevance Accuracy",
|
||||
evalType: "CONTAINS",
|
||||
value: '"{{relevance}}"',
|
||||
},
|
||||
});
|
||||
|
||||
@@ -1119,12 +1124,3 @@ await prisma.testScenario.createMany({
|
||||
variableValues: vars,
|
||||
})),
|
||||
});
|
||||
|
||||
// await prisma.evaluation.create({
|
||||
// data: {
|
||||
// experimentId: redditExperiment.id,
|
||||
// name: "Scores Match",
|
||||
// matchType: "CONTAINS",
|
||||
// matchString: "{{score}}",
|
||||
// },
|
||||
// });
|
||||
|
||||
@@ -2,7 +2,7 @@ import fs from "fs";
|
||||
import path from "path";
|
||||
import openapiTS, { type OpenAPI3 } from "openapi-typescript";
|
||||
import YAML from "yaml";
|
||||
import _ from "lodash";
|
||||
import { pick } from "lodash-es";
|
||||
import assert from "assert";
|
||||
|
||||
const OPENAPI_URL =
|
||||
@@ -31,7 +31,7 @@ modelProperty.oneOf = undefined;
|
||||
|
||||
delete schema["paths"];
|
||||
assert(schema.components?.schemas);
|
||||
schema.components.schemas = _.pick(schema.components?.schemas, [
|
||||
schema.components.schemas = pick(schema.components?.schemas, [
|
||||
"CreateChatCompletionRequest",
|
||||
"ChatCompletionRequestMessage",
|
||||
"ChatCompletionFunctions",
|
||||
|
||||
@@ -4,7 +4,7 @@ import React from "react";
|
||||
|
||||
export const AutoResizeTextarea: React.ForwardRefRenderFunction<
|
||||
HTMLTextAreaElement,
|
||||
TextareaProps
|
||||
TextareaProps & { minRows?: number }
|
||||
> = (props, ref) => {
|
||||
return (
|
||||
<Textarea
|
||||
|
||||
@@ -11,14 +11,16 @@ import {
|
||||
FormLabel,
|
||||
Select,
|
||||
FormHelperText,
|
||||
Code,
|
||||
} from "@chakra-ui/react";
|
||||
import { type Evaluation, EvaluationMatchType } from "@prisma/client";
|
||||
import { type Evaluation, EvalType } from "@prisma/client";
|
||||
import { useCallback, useState } from "react";
|
||||
import { BsPencil, BsX } from "react-icons/bs";
|
||||
import { api } from "~/utils/api";
|
||||
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
|
||||
import AutoResizeTextArea from "../AutoResizeTextArea";
|
||||
|
||||
type EvalValues = Pick<Evaluation, "name" | "matchString" | "matchType">;
|
||||
type EvalValues = Pick<Evaluation, "label" | "value" | "evalType">;
|
||||
|
||||
export function EvaluationEditor(props: {
|
||||
evaluation: Evaluation | null;
|
||||
@@ -27,35 +29,35 @@ export function EvaluationEditor(props: {
|
||||
onCancel: () => void;
|
||||
}) {
|
||||
const [values, setValues] = useState<EvalValues>({
|
||||
name: props.evaluation?.name ?? props.defaultName ?? "",
|
||||
matchString: props.evaluation?.matchString ?? "",
|
||||
matchType: props.evaluation?.matchType ?? "CONTAINS",
|
||||
label: props.evaluation?.label ?? props.defaultName ?? "",
|
||||
value: props.evaluation?.value ?? "",
|
||||
evalType: props.evaluation?.evalType ?? "CONTAINS",
|
||||
});
|
||||
|
||||
return (
|
||||
<VStack borderTopWidth={1} borderColor="gray.200" py={4}>
|
||||
<HStack w="100%">
|
||||
<FormControl flex={1}>
|
||||
<FormLabel fontSize="sm">Evaluation Name</FormLabel>
|
||||
<FormLabel fontSize="sm">Eval Name</FormLabel>
|
||||
<Input
|
||||
size="sm"
|
||||
value={values.name}
|
||||
onChange={(e) => setValues((values) => ({ ...values, name: e.target.value }))}
|
||||
value={values.label}
|
||||
onChange={(e) => setValues((values) => ({ ...values, label: e.target.value }))}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormControl flex={1}>
|
||||
<FormLabel fontSize="sm">Match Type</FormLabel>
|
||||
<FormLabel fontSize="sm">Eval Type</FormLabel>
|
||||
<Select
|
||||
size="sm"
|
||||
value={values.matchType}
|
||||
value={values.evalType}
|
||||
onChange={(e) =>
|
||||
setValues((values) => ({
|
||||
...values,
|
||||
matchType: e.target.value as EvaluationMatchType,
|
||||
evalType: e.target.value as EvalType,
|
||||
}))
|
||||
}
|
||||
>
|
||||
{Object.values(EvaluationMatchType).map((type) => (
|
||||
{Object.values(EvalType).map((type) => (
|
||||
<option key={type} value={type}>
|
||||
{type}
|
||||
</option>
|
||||
@@ -63,17 +65,37 @@ export function EvaluationEditor(props: {
|
||||
</Select>
|
||||
</FormControl>
|
||||
</HStack>
|
||||
{["CONTAINS", "DOES_NOT_CONTAIN"].includes(values.evalType) && (
|
||||
<FormControl>
|
||||
<FormLabel fontSize="sm">Match String</FormLabel>
|
||||
<Input
|
||||
size="sm"
|
||||
value={values.matchString}
|
||||
onChange={(e) => setValues((values) => ({ ...values, matchString: e.target.value }))}
|
||||
value={values.value}
|
||||
onChange={(e) => setValues((values) => ({ ...values, value: e.target.value }))}
|
||||
/>
|
||||
<FormHelperText>
|
||||
This string will be interpreted as a regex and checked against each model output.
|
||||
This string will be interpreted as a regex and checked against each model output. You
|
||||
can include scenario variables using <Code>{"{{curly_braces}}"}</Code>
|
||||
</FormHelperText>
|
||||
</FormControl>
|
||||
)}
|
||||
{values.evalType === "GPT4_EVAL" && (
|
||||
<FormControl pt={2}>
|
||||
<FormLabel fontSize="sm">GPT4 Instructions</FormLabel>
|
||||
<AutoResizeTextArea
|
||||
size="sm"
|
||||
value={values.value}
|
||||
onChange={(e) => setValues((values) => ({ ...values, value: e.target.value }))}
|
||||
minRows={3}
|
||||
/>
|
||||
<FormHelperText>
|
||||
Give instructions to GPT-4 for how to evaluate your prompt. It will have access to the
|
||||
full scenario as well as the output it is evaluating. It will <strong>not</strong> have
|
||||
access to the specific prompt variant, so be sure to be clear about the task you want it
|
||||
to perform.
|
||||
</FormHelperText>
|
||||
</FormControl>
|
||||
)}
|
||||
<HStack alignSelf="flex-end">
|
||||
<Button size="sm" onClick={props.onCancel} colorScheme="gray">
|
||||
Cancel
|
||||
@@ -125,6 +147,7 @@ export default function EditEvaluations() {
|
||||
}
|
||||
await utils.evaluations.list.invalidate();
|
||||
await utils.promptVariants.stats.invalidate();
|
||||
await utils.scenarioVariantCells.get.invalidate();
|
||||
}, []);
|
||||
|
||||
const onCancel = useCallback(() => {
|
||||
@@ -156,9 +179,9 @@ export default function EditEvaluations() {
|
||||
align="center"
|
||||
key={evaluation.id}
|
||||
>
|
||||
<Text fontWeight="bold">{evaluation.name}</Text>
|
||||
<Text fontWeight="bold">{evaluation.label}</Text>
|
||||
<Text flex={1}>
|
||||
{evaluation.matchType}: "{evaluation.matchString}"
|
||||
{evaluation.evalType}: "{evaluation.value}"
|
||||
</Text>
|
||||
<Button
|
||||
variant="unstyled"
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { Text, Button, HStack, Heading, Icon, Input, Stack, Code } from "@chakra-ui/react";
|
||||
import { Text, Button, HStack, Heading, Icon, Input, Stack } from "@chakra-ui/react";
|
||||
import { useState } from "react";
|
||||
import { BsCheck, BsX } from "react-icons/bs";
|
||||
import { api } from "~/utils/api";
|
||||
@@ -36,8 +36,7 @@ export default function EditScenarioVars() {
|
||||
<Heading size="sm">Scenario Variables</Heading>
|
||||
<Stack spacing={2}>
|
||||
<Text fontSize="sm">
|
||||
Scenario variables can be used in your prompt variants as well as evaluations. Reference
|
||||
them using <Code>{"{{curly_braces}}"}</Code>.
|
||||
Scenario variables can be used in your prompt variants as well as evaluations.
|
||||
</Text>
|
||||
<HStack spacing={0}>
|
||||
<Input
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { api } from "~/utils/api";
|
||||
import { type PromptVariant, type Scenario } from "../types";
|
||||
import { Spinner, Text, Box, Center, Flex, VStack } from "@chakra-ui/react";
|
||||
import { Spinner, Text, Center, VStack } from "@chakra-ui/react";
|
||||
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
|
||||
import SyntaxHighlighter from "react-syntax-highlighter";
|
||||
import { docco } from "react-syntax-highlighter/dist/cjs/styles/hljs";
|
||||
@@ -50,12 +50,18 @@ export default function OutputCell({
|
||||
scenarioId: scenario.id,
|
||||
variantId: variant.id,
|
||||
});
|
||||
await utils.promptVariants.stats.invalidate({
|
||||
variantId: variant.id,
|
||||
});
|
||||
}, [hardRefetchMutate, scenario.id, variant.id]);
|
||||
|
||||
const fetchingOutput = queryLoading || refetchingOutput;
|
||||
|
||||
const awaitingOutput =
|
||||
!cell || cell.retrievalStatus === "PENDING" || cell.retrievalStatus === "IN_PROGRESS";
|
||||
!cell ||
|
||||
cell.retrievalStatus === "PENDING" ||
|
||||
cell.retrievalStatus === "IN_PROGRESS" ||
|
||||
refetchingOutput;
|
||||
useEffect(() => setRefetchInterval(awaitingOutput ? 1000 : 0), [awaitingOutput]);
|
||||
|
||||
const modelOutput = cell?.modelOutput;
|
||||
@@ -95,11 +101,18 @@ export default function OutputCell({
|
||||
}
|
||||
|
||||
return (
|
||||
<Box fontSize="xs" width="100%" flexWrap="wrap" overflowX="auto">
|
||||
<VStack w="full" spacing={0}>
|
||||
<VStack
|
||||
w="100%"
|
||||
h="100%"
|
||||
fontSize="xs"
|
||||
flexWrap="wrap"
|
||||
overflowX="auto"
|
||||
justifyContent="space-between"
|
||||
>
|
||||
<VStack w="full" flex={1} spacing={0}>
|
||||
<CellOptions refetchingOutput={refetchingOutput} refetchOutput={hardRefetch} />
|
||||
<SyntaxHighlighter
|
||||
customStyle={{ overflowX: "unset" }}
|
||||
customStyle={{ overflowX: "unset", width: "100%", flex: 1 }}
|
||||
language="json"
|
||||
style={docco}
|
||||
lineProps={{
|
||||
@@ -117,7 +130,7 @@ export default function OutputCell({
|
||||
</SyntaxHighlighter>
|
||||
</VStack>
|
||||
<OutputStats model={variant.model} modelOutput={modelOutput} scenario={scenario} />
|
||||
</Box>
|
||||
</VStack>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -125,7 +138,7 @@ export default function OutputCell({
|
||||
message?.content ?? streamedContent ?? JSON.stringify(modelOutput?.output);
|
||||
|
||||
return (
|
||||
<Flex w="100%" h="100%" direction="column" justifyContent="space-between" whiteSpace="pre-wrap">
|
||||
<VStack w="100%" h="100%" justifyContent="space-between" whiteSpace="pre-wrap">
|
||||
<VStack w="full" alignItems="flex-start" spacing={0}>
|
||||
<CellOptions refetchingOutput={refetchingOutput} refetchOutput={hardRefetch} />
|
||||
<Text>{contentToDisplay}</Text>
|
||||
@@ -133,6 +146,6 @@ export default function OutputCell({
|
||||
{modelOutput && (
|
||||
<OutputStats model={variant.model} modelOutput={modelOutput} scenario={scenario} />
|
||||
)}
|
||||
</Flex>
|
||||
</VStack>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,11 +1,8 @@
|
||||
import { type ModelOutput } from "@prisma/client";
|
||||
import { type SupportedModel } from "~/server/types";
|
||||
import { type Scenario } from "../types";
|
||||
import { useExperiment } from "~/utils/hooks";
|
||||
import { api } from "~/utils/api";
|
||||
import { type RouterOutputs } from "~/utils/api";
|
||||
import { calculateTokenCost } from "~/utils/calculateTokenCost";
|
||||
import { evaluateOutput } from "~/server/utils/evaluateOutput";
|
||||
import { HStack, Icon, Text } from "@chakra-ui/react";
|
||||
import { HStack, Icon, Text, Tooltip } from "@chakra-ui/react";
|
||||
import { BsCheck, BsClock, BsCurrencyDollar, BsX } from "react-icons/bs";
|
||||
import { CostTooltip } from "~/components/tooltip/CostTooltip";
|
||||
|
||||
@@ -15,16 +12,14 @@ const SHOW_TIME = true;
|
||||
export const OutputStats = ({
|
||||
model,
|
||||
modelOutput,
|
||||
scenario,
|
||||
}: {
|
||||
model: SupportedModel | string | null;
|
||||
modelOutput: ModelOutput;
|
||||
modelOutput: NonNullable<
|
||||
NonNullable<RouterOutputs["scenarioVariantCells"]["get"]>["modelOutput"]
|
||||
>;
|
||||
scenario: Scenario;
|
||||
}) => {
|
||||
const timeToComplete = modelOutput.timeToComplete;
|
||||
const experiment = useExperiment();
|
||||
const evals =
|
||||
api.evaluations.list.useQuery({ experimentId: experiment.data?.id ?? "" }).data ?? [];
|
||||
|
||||
const promptTokens = modelOutput.promptTokens;
|
||||
const completionTokens = modelOutput.completionTokens;
|
||||
@@ -36,19 +31,25 @@ export const OutputStats = ({
|
||||
const cost = promptCost + completionCost;
|
||||
|
||||
return (
|
||||
<HStack align="center" color="gray.500" fontSize="2xs" mt={{ base: 0, md: 1 }}>
|
||||
<HStack w="full" align="center" color="gray.500" fontSize="2xs" mt={{ base: 0, md: 1 }}>
|
||||
<HStack flex={1}>
|
||||
{evals.map((evaluation) => {
|
||||
const passed = evaluateOutput(modelOutput, scenario, evaluation);
|
||||
{modelOutput.outputEvaluation.map((evaluation) => {
|
||||
const passed = evaluation.result > 0.5;
|
||||
return (
|
||||
<HStack spacing={0} key={evaluation.id}>
|
||||
<Text>{evaluation.name}</Text>
|
||||
<Tooltip
|
||||
isDisabled={!evaluation.details}
|
||||
label={evaluation.details}
|
||||
key={evaluation.id}
|
||||
>
|
||||
<HStack spacing={0}>
|
||||
<Text>{evaluation.evaluation.label}</Text>
|
||||
<Icon
|
||||
as={passed ? BsCheck : BsX}
|
||||
color={passed ? "green.500" : "red.500"}
|
||||
boxSize={6}
|
||||
/>
|
||||
</HStack>
|
||||
</Tooltip>
|
||||
);
|
||||
})}
|
||||
</HStack>
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { type DragEvent } from "react";
|
||||
import { api } from "~/utils/api";
|
||||
import { isEqual } from "lodash";
|
||||
import { isEqual } from "lodash-es";
|
||||
import { type Scenario } from "./types";
|
||||
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
|
||||
import { useState } from "react";
|
||||
|
||||
49
src/components/OutputsTable/ScenariosHeader.tsx
Normal file
49
src/components/OutputsTable/ScenariosHeader.tsx
Normal file
@@ -0,0 +1,49 @@
|
||||
import { Button, GridItem, HStack, Heading } from "@chakra-ui/react";
|
||||
import { cellPadding } from "../constants";
|
||||
import { useElementDimensions } from "~/utils/hooks";
|
||||
import { stickyHeaderStyle } from "./styles";
|
||||
import { BsPencil } from "react-icons/bs";
|
||||
import { useAppStore } from "~/state/store";
|
||||
|
||||
export const ScenariosHeader = ({
|
||||
headerRows,
|
||||
numScenarios,
|
||||
}: {
|
||||
headerRows: number;
|
||||
numScenarios: number;
|
||||
}) => {
|
||||
const openDrawer = useAppStore((s) => s.openDrawer);
|
||||
|
||||
const [ref, dimensions] = useElementDimensions();
|
||||
const topValue = dimensions ? `-${dimensions.height - 24}px` : "-455px";
|
||||
|
||||
return (
|
||||
<GridItem
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
ref={ref as any}
|
||||
display="flex"
|
||||
alignItems="flex-end"
|
||||
rowSpan={headerRows}
|
||||
px={cellPadding.x}
|
||||
py={cellPadding.y}
|
||||
// Only display the part of the grid item that has content
|
||||
sx={{ ...stickyHeaderStyle, top: topValue }}
|
||||
>
|
||||
<HStack w="100%">
|
||||
<Heading size="xs" fontWeight="bold" flex={1}>
|
||||
Scenarios ({numScenarios})
|
||||
</Heading>
|
||||
<Button
|
||||
size="xs"
|
||||
variant="ghost"
|
||||
color="gray.500"
|
||||
aria-label="Edit"
|
||||
leftIcon={<BsPencil />}
|
||||
onClick={openDrawer}
|
||||
>
|
||||
Edit Vars
|
||||
</Button>
|
||||
</HStack>
|
||||
</GridItem>
|
||||
);
|
||||
};
|
||||
@@ -1,10 +1,9 @@
|
||||
import { Box, Button, HStack, Tooltip, VStack, useToast } from "@chakra-ui/react";
|
||||
import { Box, Button, HStack, Spinner, Tooltip, useToast, Text } from "@chakra-ui/react";
|
||||
import { useRef, useEffect, useState, useCallback } from "react";
|
||||
import { useHandledAsyncCallback, useModifierKeyLabel } from "~/utils/hooks";
|
||||
import { type PromptVariant } from "./types";
|
||||
import { api } from "~/utils/api";
|
||||
import { useAppStore } from "~/state/store";
|
||||
import { editorBackground } from "~/state/sharedVariantEditor.slice";
|
||||
|
||||
export default function VariantEditor(props: { variant: PromptVariant }) {
|
||||
const monaco = useAppStore.use.sharedVariantEditor.monaco();
|
||||
@@ -28,7 +27,7 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
|
||||
const utils = api.useContext();
|
||||
const toast = useToast();
|
||||
|
||||
const [onSave] = useHandledAsyncCallback(async () => {
|
||||
const [onSave, saveInProgress] = useHandledAsyncCallback(async () => {
|
||||
if (!editorRef.current) return;
|
||||
|
||||
await editorRef.current.getAction("editor.action.formatDocument")?.run();
|
||||
@@ -133,19 +132,7 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
|
||||
|
||||
return (
|
||||
<Box w="100%" pos="relative">
|
||||
<VStack
|
||||
spacing={0}
|
||||
align="stretch"
|
||||
fontSize="xs"
|
||||
fontWeight="bold"
|
||||
color="gray.600"
|
||||
py={2}
|
||||
bgColor={editorBackground}
|
||||
>
|
||||
<code>{`function constructPrompt(scenario: Scenario): Prompt {`}</code>
|
||||
<div id={editorId} style={{ height: "300px", width: "100%" }}></div>
|
||||
<code>{`return prompt; }`}</code>
|
||||
</VStack>
|
||||
<div id={editorId} style={{ height: "400px", width: "100%" }}></div>
|
||||
{isChanged && (
|
||||
<HStack pos="absolute" bottom={2} right={2}>
|
||||
<Button
|
||||
@@ -159,8 +146,8 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
|
||||
Reset
|
||||
</Button>
|
||||
<Tooltip label={`${modifierKey} + Enter`}>
|
||||
<Button size="sm" onClick={onSave} colorScheme="blue">
|
||||
Save
|
||||
<Button size="sm" onClick={onSave} colorScheme="blue" w={16} disabled={saveInProgress}>
|
||||
{saveInProgress ? <Spinner boxSize={4} /> : <Text>Save</Text>}
|
||||
</Button>
|
||||
</Tooltip>
|
||||
</HStack>
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
import { HStack, Icon, Text, useToken } from "@chakra-ui/react";
|
||||
import { HStack, Icon, Skeleton, Text, useToken } from "@chakra-ui/react";
|
||||
import { type PromptVariant } from "./types";
|
||||
import { cellPadding } from "../constants";
|
||||
import { api } from "~/utils/api";
|
||||
import chroma from "chroma-js";
|
||||
import { BsCurrencyDollar } from "react-icons/bs";
|
||||
import { CostTooltip } from "../tooltip/CostTooltip";
|
||||
import { useEffect, useState } from "react";
|
||||
|
||||
export default function VariantStats(props: { variant: PromptVariant }) {
|
||||
const [refetchInterval, setRefetchInterval] = useState(0);
|
||||
const { data } = api.promptVariants.stats.useQuery(
|
||||
{
|
||||
variantId: props.variant.id,
|
||||
@@ -19,10 +21,18 @@ export default function VariantStats(props: { variant: PromptVariant }) {
|
||||
completionTokens: 0,
|
||||
scenarioCount: 0,
|
||||
outputCount: 0,
|
||||
awaitingRetrievals: false,
|
||||
},
|
||||
refetchInterval,
|
||||
},
|
||||
);
|
||||
|
||||
// Poll every two seconds while we are waiting for LLM retrievals to finish
|
||||
useEffect(
|
||||
() => setRefetchInterval(data.awaitingRetrievals ? 2000 : 0),
|
||||
[data.awaitingRetrievals],
|
||||
);
|
||||
|
||||
const [passColor, neutralColor, failColor] = useToken("colors", [
|
||||
"green.500",
|
||||
"gray.500",
|
||||
@@ -33,21 +43,25 @@ export default function VariantStats(props: { variant: PromptVariant }) {
|
||||
|
||||
const showNumFinished = data.scenarioCount > 0 && data.scenarioCount !== data.outputCount;
|
||||
|
||||
if (!(data.evalResults.length > 0) && !data.overallCost) return null;
|
||||
|
||||
return (
|
||||
<HStack justifyContent="space-between" alignItems="center" mx="2" fontSize="xs">
|
||||
<HStack
|
||||
justifyContent="space-between"
|
||||
alignItems="center"
|
||||
mx="2"
|
||||
fontSize="xs"
|
||||
py={cellPadding.y}
|
||||
>
|
||||
{showNumFinished && (
|
||||
<Text>
|
||||
{data.outputCount} / {data.scenarioCount}
|
||||
</Text>
|
||||
)}
|
||||
<HStack px={cellPadding.x} py={cellPadding.y}>
|
||||
<HStack px={cellPadding.x}>
|
||||
{data.evalResults.map((result) => {
|
||||
const passedFrac = result.passCount / (result.passCount + result.failCount);
|
||||
const passedFrac = result.passCount / result.totalCount;
|
||||
return (
|
||||
<HStack key={result.id}>
|
||||
<Text>{result.evaluation.name}</Text>
|
||||
<Text>{result.label}</Text>
|
||||
<Text color={scale(passedFrac).hex()} fontWeight="bold">
|
||||
{(passedFrac * 100).toFixed(1)}%
|
||||
</Text>
|
||||
@@ -55,17 +69,19 @@ export default function VariantStats(props: { variant: PromptVariant }) {
|
||||
);
|
||||
})}
|
||||
</HStack>
|
||||
{data.overallCost && (
|
||||
{data.overallCost && !data.awaitingRetrievals ? (
|
||||
<CostTooltip
|
||||
promptTokens={data.promptTokens}
|
||||
completionTokens={data.completionTokens}
|
||||
cost={data.overallCost}
|
||||
>
|
||||
<HStack spacing={0} align="center" color="gray.500" my="2">
|
||||
<HStack spacing={0} align="center" color="gray.500">
|
||||
<Icon as={BsCurrencyDollar} />
|
||||
<Text mr={1}>{data.overallCost.toFixed(3)}</Text>
|
||||
</HStack>
|
||||
</CostTooltip>
|
||||
) : (
|
||||
<Skeleton height={4} width={12} mr={1} />
|
||||
)}
|
||||
</HStack>
|
||||
);
|
||||
|
||||
@@ -1,28 +1,19 @@
|
||||
import { Button, Grid, GridItem, HStack, Heading, type SystemStyleObject } from "@chakra-ui/react";
|
||||
import { Grid, GridItem } from "@chakra-ui/react";
|
||||
import { api } from "~/utils/api";
|
||||
import NewScenarioButton from "./NewScenarioButton";
|
||||
import NewVariantButton from "./NewVariantButton";
|
||||
import ScenarioRow from "./ScenarioRow";
|
||||
import VariantEditor from "./VariantEditor";
|
||||
import VariantHeader from "./VariantHeader";
|
||||
import { cellPadding } from "../constants";
|
||||
import { BsPencil } from "react-icons/bs";
|
||||
import VariantStats from "./VariantStats";
|
||||
import { useAppStore } from "~/state/store";
|
||||
|
||||
const stickyHeaderStyle: SystemStyleObject = {
|
||||
position: "sticky",
|
||||
top: "-1px",
|
||||
backgroundColor: "#fff",
|
||||
zIndex: 1,
|
||||
};
|
||||
import { ScenariosHeader } from "./ScenariosHeader";
|
||||
import { stickyHeaderStyle } from "./styles";
|
||||
|
||||
export default function OutputsTable({ experimentId }: { experimentId: string | undefined }) {
|
||||
const variants = api.promptVariants.list.useQuery(
|
||||
{ experimentId: experimentId as string },
|
||||
{ enabled: !!experimentId },
|
||||
);
|
||||
const openDrawer = useAppStore((s) => s.openDrawer);
|
||||
|
||||
const scenarios = api.scenarios.list.useQuery(
|
||||
{ experimentId: experimentId as string },
|
||||
@@ -49,32 +40,7 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
|
||||
}}
|
||||
fontSize="sm"
|
||||
>
|
||||
<GridItem
|
||||
display="flex"
|
||||
alignItems="flex-end"
|
||||
rowSpan={headerRows}
|
||||
px={cellPadding.x}
|
||||
py={cellPadding.y}
|
||||
// TODO: This is a hack to get the sticky header to work. It's not ideal because it's not responsive to the height of the header,
|
||||
// so if the header height changes, this will need to be updated.
|
||||
sx={{ ...stickyHeaderStyle, top: "-337px" }}
|
||||
>
|
||||
<HStack w="100%">
|
||||
<Heading size="xs" fontWeight="bold" flex={1}>
|
||||
Scenarios ({scenarios.data.length})
|
||||
</Heading>
|
||||
<Button
|
||||
size="xs"
|
||||
variant="ghost"
|
||||
color="gray.500"
|
||||
aria-label="Edit"
|
||||
leftIcon={<BsPencil />}
|
||||
onClick={openDrawer}
|
||||
>
|
||||
Edit Vars
|
||||
</Button>
|
||||
</HStack>
|
||||
</GridItem>
|
||||
<ScenariosHeader headerRows={headerRows} numScenarios={scenarios.data.length} />
|
||||
|
||||
{variants.data.map((variant) => (
|
||||
<GridItem key={variant.uiId} padding={0} sx={stickyHeaderStyle} borderTopWidth={1}>
|
||||
|
||||
8
src/components/OutputsTable/styles.ts
Normal file
8
src/components/OutputsTable/styles.ts
Normal file
@@ -0,0 +1,8 @@
|
||||
import { type SystemStyleObject } from "@chakra-ui/react";
|
||||
|
||||
export const stickyHeaderStyle: SystemStyleObject = {
|
||||
position: "sticky",
|
||||
top: "-1px",
|
||||
backgroundColor: "#fff",
|
||||
zIndex: 1,
|
||||
};
|
||||
@@ -20,7 +20,6 @@ export const CostTooltip = ({
|
||||
color="gray.800"
|
||||
bgColor="gray.50"
|
||||
borderWidth={1}
|
||||
py={2}
|
||||
hasArrow
|
||||
shouldWrapChildren
|
||||
label={
|
||||
|
||||
@@ -6,18 +6,27 @@ import { ChakraProvider } from "@chakra-ui/react";
|
||||
import theme from "~/utils/theme";
|
||||
import Favicon from "~/components/Favicon";
|
||||
import "~/utils/analytics";
|
||||
import Head from "next/head";
|
||||
|
||||
const MyApp: AppType<{ session: Session | null }> = ({
|
||||
Component,
|
||||
pageProps: { session, ...pageProps },
|
||||
}) => {
|
||||
return (
|
||||
<>
|
||||
<Head>
|
||||
<meta
|
||||
name="viewport"
|
||||
content="width=device-width, initial-scale=1, maximum-scale=1, user-scalable=0"
|
||||
/>
|
||||
</Head>
|
||||
<SessionProvider session={session}>
|
||||
<Favicon />
|
||||
<ChakraProvider theme={theme}>
|
||||
<Component {...pageProps} />
|
||||
</ChakraProvider>
|
||||
</SessionProvider>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { type CompletionCreateParams } from "openai/resources/chat";
|
||||
import { prisma } from "../db";
|
||||
import { openai } from "../utils/openai";
|
||||
import { pick } from "lodash";
|
||||
import { pick } from "lodash-es";
|
||||
|
||||
type AxiosError = {
|
||||
response?: {
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import { EvaluationMatchType } from "@prisma/client";
|
||||
import { EvalType } from "@prisma/client";
|
||||
import { z } from "zod";
|
||||
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
|
||||
import { prisma } from "~/server/db";
|
||||
import { reevaluateEvaluation } from "~/server/utils/evaluations";
|
||||
import { runAllEvals } from "~/server/utils/evaluations";
|
||||
|
||||
export const evaluationsRouter = createTRPCRouter({
|
||||
list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
|
||||
@@ -18,21 +18,24 @@ export const evaluationsRouter = createTRPCRouter({
|
||||
.input(
|
||||
z.object({
|
||||
experimentId: z.string(),
|
||||
name: z.string(),
|
||||
matchString: z.string(),
|
||||
matchType: z.nativeEnum(EvaluationMatchType),
|
||||
label: z.string(),
|
||||
value: z.string(),
|
||||
evalType: z.nativeEnum(EvalType),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input }) => {
|
||||
const evaluation = await prisma.evaluation.create({
|
||||
await prisma.evaluation.create({
|
||||
data: {
|
||||
experimentId: input.experimentId,
|
||||
name: input.name,
|
||||
matchString: input.matchString,
|
||||
matchType: input.matchType,
|
||||
label: input.label,
|
||||
value: input.value,
|
||||
evalType: input.evalType,
|
||||
},
|
||||
});
|
||||
await reevaluateEvaluation(evaluation);
|
||||
|
||||
// TODO: this may be a bad UX for slow evals (eg. GPT-4 evals) Maybe need
|
||||
// to kick off a background job or something instead
|
||||
await runAllEvals(input.experimentId);
|
||||
}),
|
||||
|
||||
update: publicProcedure
|
||||
@@ -40,24 +43,30 @@ export const evaluationsRouter = createTRPCRouter({
|
||||
z.object({
|
||||
id: z.string(),
|
||||
updates: z.object({
|
||||
name: z.string().optional(),
|
||||
matchString: z.string().optional(),
|
||||
matchType: z.nativeEnum(EvaluationMatchType).optional(),
|
||||
label: z.string().optional(),
|
||||
value: z.string().optional(),
|
||||
evalType: z.nativeEnum(EvalType).optional(),
|
||||
}),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input }) => {
|
||||
await prisma.evaluation.update({
|
||||
const evaluation = await prisma.evaluation.update({
|
||||
where: { id: input.id },
|
||||
data: {
|
||||
name: input.updates.name,
|
||||
matchString: input.updates.matchString,
|
||||
matchType: input.updates.matchType,
|
||||
label: input.updates.label,
|
||||
value: input.updates.value,
|
||||
evalType: input.updates.evalType,
|
||||
},
|
||||
});
|
||||
await reevaluateEvaluation(
|
||||
await prisma.evaluation.findUniqueOrThrow({ where: { id: input.id } }),
|
||||
);
|
||||
|
||||
await prisma.outputEvaluation.deleteMany({
|
||||
where: {
|
||||
evaluationId: evaluation.id,
|
||||
},
|
||||
});
|
||||
// Re-run all evals. Other eval results will already be cached, so this
|
||||
// should only re-run the updated one.
|
||||
await runAllEvals(evaluation.experimentId);
|
||||
}),
|
||||
|
||||
delete: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => {
|
||||
|
||||
@@ -71,11 +71,28 @@ export const experimentsRouter = createTRPCRouter({
|
||||
experimentId: exp.id,
|
||||
label: "Prompt Variant 1",
|
||||
sortIndex: 0,
|
||||
constructFn: dedent`prompt = {
|
||||
// The interpolated $ is necessary until dedent incorporates
|
||||
// https://github.com/dmnd/dedent/pull/46
|
||||
constructFn: dedent`
|
||||
/**
|
||||
* Use Javascript to define an OpenAI chat completion
|
||||
* (https://platform.openai.com/docs/api-reference/chat/create) and
|
||||
* assign it to the \`prompt\` variable.
|
||||
*
|
||||
* You have access to the current scenario in the \`scenario\`
|
||||
* variable.
|
||||
*/
|
||||
|
||||
prompt = {
|
||||
model: "gpt-3.5-turbo-0613",
|
||||
stream: true,
|
||||
messages: [{ role: "system", content: ${"`Return '${scenario.text}'`"} }],
|
||||
}`,
|
||||
messages: [
|
||||
{
|
||||
role: "system",
|
||||
content: \`"Return 'this is output for the scenario "${"$"}{scenario.text}"'\`,
|
||||
},
|
||||
],
|
||||
};`,
|
||||
model: "gpt-3.5-turbo-0613",
|
||||
},
|
||||
}),
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import dedent from "dedent";
|
||||
import { isObject } from "lodash";
|
||||
import { isObject } from "lodash-es";
|
||||
import { z } from "zod";
|
||||
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
|
||||
import { prisma } from "~/server/db";
|
||||
@@ -32,11 +32,43 @@ export const promptVariantsRouter = createTRPCRouter({
|
||||
throw new Error(`Prompt Variant with id ${input.variantId} does not exist`);
|
||||
}
|
||||
|
||||
const evalResults = await prisma.evaluationResult.findMany({
|
||||
where: {
|
||||
promptVariantId: input.variantId,
|
||||
const outputEvals = await prisma.outputEvaluation.groupBy({
|
||||
by: ["evaluationId"],
|
||||
_sum: {
|
||||
result: true,
|
||||
},
|
||||
include: { evaluation: true },
|
||||
_count: {
|
||||
id: true,
|
||||
},
|
||||
where: {
|
||||
modelOutput: {
|
||||
scenarioVariantCell: {
|
||||
promptVariant: {
|
||||
id: input.variantId,
|
||||
visible: true,
|
||||
},
|
||||
testScenario: {
|
||||
visible: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const evals = await prisma.evaluation.findMany({
|
||||
where: {
|
||||
experimentId: variant.experimentId,
|
||||
},
|
||||
});
|
||||
|
||||
const evalResults = evals.map((evalItem) => {
|
||||
const evalResult = outputEvals.find((outputEval) => outputEval.evaluationId === evalItem.id);
|
||||
return {
|
||||
id: evalItem.id,
|
||||
label: evalItem.label,
|
||||
passCount: evalResult?._sum?.result ?? 0,
|
||||
totalCount: evalResult?._count?.id ?? 1,
|
||||
};
|
||||
});
|
||||
|
||||
const scenarioCount = await prisma.testScenario.count({
|
||||
@@ -50,7 +82,7 @@ export const promptVariantsRouter = createTRPCRouter({
|
||||
promptVariantId: input.variantId,
|
||||
testScenario: { visible: true },
|
||||
modelOutput: {
|
||||
isNot: null,
|
||||
is: {},
|
||||
},
|
||||
},
|
||||
});
|
||||
@@ -77,7 +109,26 @@ export const promptVariantsRouter = createTRPCRouter({
|
||||
|
||||
const overallCost = overallPromptCost + overallCompletionCost;
|
||||
|
||||
return { evalResults, promptTokens, completionTokens, overallCost, scenarioCount, outputCount };
|
||||
const awaitingRetrievals = !!(await prisma.scenarioVariantCell.findFirst({
|
||||
where: {
|
||||
promptVariantId: input.variantId,
|
||||
testScenario: { visible: true },
|
||||
// Check if is PENDING or IN_PROGRESS
|
||||
retrievalStatus: {
|
||||
in: ["PENDING", "IN_PROGRESS"],
|
||||
},
|
||||
},
|
||||
}));
|
||||
|
||||
return {
|
||||
evalResults,
|
||||
promptTokens,
|
||||
completionTokens,
|
||||
overallCost,
|
||||
scenarioCount,
|
||||
outputCount,
|
||||
awaitingRetrievals,
|
||||
};
|
||||
}),
|
||||
|
||||
create: publicProcedure
|
||||
|
||||
@@ -21,7 +21,17 @@ export const scenarioVariantCellsRouter = createTRPCRouter({
|
||||
},
|
||||
},
|
||||
include: {
|
||||
modelOutput: true,
|
||||
modelOutput: {
|
||||
include: {
|
||||
outputEvaluation: {
|
||||
include: {
|
||||
evaluation: {
|
||||
select: { label: true },
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
}),
|
||||
|
||||
@@ -3,7 +3,7 @@ import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
|
||||
import { prisma } from "~/server/db";
|
||||
import { autogenerateScenarioValues } from "../autogen";
|
||||
import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
|
||||
import { reevaluateAll } from "~/server/utils/evaluations";
|
||||
import { runAllEvals } from "~/server/utils/evaluations";
|
||||
import { generateNewCell } from "~/server/utils/generateNewCell";
|
||||
|
||||
export const scenariosRouter = createTRPCRouter({
|
||||
@@ -73,7 +73,7 @@ export const scenariosRouter = createTRPCRouter({
|
||||
});
|
||||
|
||||
// Reevaluate all evaluations now that this scenario is hidden
|
||||
await reevaluateAll(hiddenScenario.experimentId);
|
||||
await runAllEvals(hiddenScenario.experimentId);
|
||||
|
||||
return hiddenScenario;
|
||||
}),
|
||||
|
||||
@@ -6,9 +6,10 @@ import { type JSONSerializable } from "../types";
|
||||
import { sleep } from "../utils/sleep";
|
||||
import { shouldStream } from "../utils/shouldStream";
|
||||
import { generateChannel } from "~/utils/generateChannel";
|
||||
import { reevaluateVariant } from "../utils/evaluations";
|
||||
import { runEvalsForOutput } from "../utils/evaluations";
|
||||
import { constructPrompt } from "../utils/constructPrompt";
|
||||
import { type CompletionCreateParams } from "openai/resources/chat";
|
||||
import { type Prisma } from "@prisma/client";
|
||||
|
||||
const MAX_AUTO_RETRIES = 10;
|
||||
const MIN_DELAY = 500; // milliseconds
|
||||
@@ -25,11 +26,10 @@ const getCompletionWithRetries = async (
|
||||
payload: JSONSerializable,
|
||||
channel?: string,
|
||||
): Promise<CompletionResponse> => {
|
||||
let modelResponse: CompletionResponse | null = null;
|
||||
try {
|
||||
for (let i = 0; i < MAX_AUTO_RETRIES; i++) {
|
||||
const modelResponse = await getCompletion(
|
||||
payload as unknown as CompletionCreateParams,
|
||||
channel,
|
||||
);
|
||||
modelResponse = await getCompletion(payload as unknown as CompletionCreateParams, channel);
|
||||
if (modelResponse.statusCode !== 429 || i === MAX_AUTO_RETRIES - 1) {
|
||||
return modelResponse;
|
||||
}
|
||||
@@ -46,6 +46,14 @@ const getCompletionWithRetries = async (
|
||||
await sleep(delay);
|
||||
}
|
||||
throw new Error("Max retries limit reached");
|
||||
} catch (error: unknown) {
|
||||
return {
|
||||
statusCode: modelResponse?.statusCode ?? 500,
|
||||
errorMessage: modelResponse?.errorMessage ?? (error as Error).message,
|
||||
output: null as unknown as Prisma.InputJsonValue,
|
||||
timeToComplete: 0,
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
export type queryLLMJob = {
|
||||
@@ -140,5 +148,7 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
|
||||
},
|
||||
});
|
||||
|
||||
await reevaluateVariant(cell.promptVariantId);
|
||||
if (modelOutput) {
|
||||
await runEvalsForOutput(variant.experimentId, scenario, modelOutput);
|
||||
}
|
||||
});
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
import { type Evaluation, type ModelOutput, type TestScenario } from "@prisma/client";
|
||||
import { type ChatCompletion } from "openai/resources/chat";
|
||||
import { type VariableMap, fillTemplate } from "./fillTemplate";
|
||||
|
||||
export const evaluateOutput = (
|
||||
modelOutput: ModelOutput,
|
||||
scenario: TestScenario,
|
||||
evaluation: Evaluation,
|
||||
): boolean => {
|
||||
const output = modelOutput.output as unknown as ChatCompletion;
|
||||
const message = output?.choices?.[0]?.message;
|
||||
|
||||
if (!message) return false;
|
||||
|
||||
const stringifiedMessage = message.content ?? JSON.stringify(message.function_call);
|
||||
|
||||
const matchRegex = fillTemplate(evaluation.matchString, scenario.variableValues as VariableMap);
|
||||
|
||||
let match;
|
||||
|
||||
switch (evaluation.matchType) {
|
||||
case "CONTAINS":
|
||||
match = stringifiedMessage.match(matchRegex) !== null;
|
||||
break;
|
||||
case "DOES_NOT_CONTAIN":
|
||||
match = stringifiedMessage.match(matchRegex) === null;
|
||||
break;
|
||||
}
|
||||
|
||||
return match;
|
||||
};
|
||||
@@ -1,105 +1,79 @@
|
||||
import { type ModelOutput, type Evaluation } from "@prisma/client";
|
||||
import { prisma } from "../db";
|
||||
import { evaluateOutput } from "./evaluateOutput";
|
||||
import { runOneEval } from "./runOneEval";
|
||||
import { type Scenario } from "~/components/OutputsTable/types";
|
||||
|
||||
export const reevaluateVariant = async (variantId: string) => {
|
||||
const variant = await prisma.promptVariant.findUnique({
|
||||
where: { id: variantId },
|
||||
});
|
||||
if (!variant) return;
|
||||
|
||||
const evaluations = await prisma.evaluation.findMany({
|
||||
where: { experimentId: variant.experimentId },
|
||||
});
|
||||
|
||||
const cells = await prisma.scenarioVariantCell.findMany({
|
||||
const saveResult = async (evaluation: Evaluation, scenario: Scenario, modelOutput: ModelOutput) => {
|
||||
const result = await runOneEval(evaluation, scenario, modelOutput);
|
||||
return await prisma.outputEvaluation.upsert({
|
||||
where: {
|
||||
promptVariantId: variantId,
|
||||
retrievalStatus: "COMPLETE",
|
||||
testScenario: { visible: true },
|
||||
modelOutput: { isNot: null },
|
||||
},
|
||||
include: { testScenario: true, modelOutput: true },
|
||||
});
|
||||
|
||||
await Promise.all(
|
||||
evaluations.map(async (evaluation) => {
|
||||
const passCount = cells.filter((cell) =>
|
||||
evaluateOutput(cell.modelOutput as ModelOutput, cell.testScenario, evaluation),
|
||||
).length;
|
||||
const failCount = cells.length - passCount;
|
||||
|
||||
await prisma.evaluationResult.upsert({
|
||||
where: {
|
||||
evaluationId_promptVariantId: {
|
||||
modelOutputId_evaluationId: {
|
||||
modelOutputId: modelOutput.id,
|
||||
evaluationId: evaluation.id,
|
||||
promptVariantId: variantId,
|
||||
},
|
||||
},
|
||||
create: {
|
||||
modelOutputId: modelOutput.id,
|
||||
evaluationId: evaluation.id,
|
||||
promptVariantId: variantId,
|
||||
passCount,
|
||||
failCount,
|
||||
...result,
|
||||
},
|
||||
update: {
|
||||
passCount,
|
||||
failCount,
|
||||
...result,
|
||||
},
|
||||
});
|
||||
}),
|
||||
);
|
||||
};
|
||||
|
||||
export const reevaluateEvaluation = async (evaluation: Evaluation) => {
|
||||
const variants = await prisma.promptVariant.findMany({
|
||||
where: { experimentId: evaluation.experimentId, visible: true },
|
||||
});
|
||||
|
||||
const cells = await prisma.scenarioVariantCell.findMany({
|
||||
where: {
|
||||
promptVariantId: { in: variants.map((v) => v.id) },
|
||||
testScenario: { visible: true },
|
||||
statusCode: { notIn: [429] },
|
||||
modelOutput: { isNot: null },
|
||||
},
|
||||
include: { testScenario: true, modelOutput: true },
|
||||
});
|
||||
|
||||
await Promise.all(
|
||||
variants.map(async (variant) => {
|
||||
const variantCells = cells.filter((cell) => cell.promptVariantId === variant.id);
|
||||
const passCount = variantCells.filter((cell) =>
|
||||
evaluateOutput(cell.modelOutput as ModelOutput, cell.testScenario, evaluation),
|
||||
).length;
|
||||
const failCount = variantCells.length - passCount;
|
||||
|
||||
await prisma.evaluationResult.upsert({
|
||||
where: {
|
||||
evaluationId_promptVariantId: {
|
||||
evaluationId: evaluation.id,
|
||||
promptVariantId: variant.id,
|
||||
},
|
||||
},
|
||||
create: {
|
||||
evaluationId: evaluation.id,
|
||||
promptVariantId: variant.id,
|
||||
passCount,
|
||||
failCount,
|
||||
},
|
||||
update: {
|
||||
passCount,
|
||||
failCount,
|
||||
},
|
||||
});
|
||||
}),
|
||||
);
|
||||
};
|
||||
|
||||
export const reevaluateAll = async (experimentId: string) => {
|
||||
export const runEvalsForOutput = async (
|
||||
experimentId: string,
|
||||
scenario: Scenario,
|
||||
modelOutput: ModelOutput,
|
||||
) => {
|
||||
const evaluations = await prisma.evaluation.findMany({
|
||||
where: { experimentId },
|
||||
});
|
||||
|
||||
await Promise.all(evaluations.map(reevaluateEvaluation));
|
||||
await Promise.all(
|
||||
evaluations.map(async (evaluation) => await saveResult(evaluation, scenario, modelOutput)),
|
||||
);
|
||||
};
|
||||
|
||||
export const runAllEvals = async (experimentId: string) => {
|
||||
const outputs = await prisma.modelOutput.findMany({
|
||||
where: {
|
||||
scenarioVariantCell: {
|
||||
promptVariant: {
|
||||
experimentId,
|
||||
visible: true,
|
||||
},
|
||||
testScenario: {
|
||||
visible: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
include: {
|
||||
scenarioVariantCell: {
|
||||
include: {
|
||||
testScenario: true,
|
||||
},
|
||||
},
|
||||
outputEvaluation: true,
|
||||
},
|
||||
});
|
||||
const evals = await prisma.evaluation.findMany({
|
||||
where: { experimentId },
|
||||
});
|
||||
|
||||
await Promise.all(
|
||||
outputs.map(async (output) => {
|
||||
const unrunEvals = evals.filter(
|
||||
(evaluation) => !output.outputEvaluation.find((e) => e.evaluationId === evaluation.id),
|
||||
);
|
||||
|
||||
await Promise.all(
|
||||
unrunEvals.map(async (evaluation) => {
|
||||
await saveResult(evaluation, output.scenarioVariantCell.testScenario, output);
|
||||
}),
|
||||
);
|
||||
}),
|
||||
);
|
||||
};
|
||||
|
||||
@@ -2,6 +2,16 @@ import { type JSONSerializable } from "../types";
|
||||
|
||||
export type VariableMap = Record<string, string>;
|
||||
|
||||
// Escape quotes to match the way we encode JSON
|
||||
export function escapeQuotes(str: string) {
|
||||
return str.replace(/(\\")|"/g, (match, p1) => (p1 ? match : '\\"'));
|
||||
}
|
||||
|
||||
// Escape regex special characters
|
||||
export function escapeRegExp(str: string) {
|
||||
return str.replace(/[.*+\-?^${}()|[\]\\]/g, "\\$&"); // $& means the whole matched string
|
||||
}
|
||||
|
||||
export function fillTemplate(template: string, variables: VariableMap): string {
|
||||
return template.replace(/{{\s*(\w+)\s*}}/g, (_, key: string) => variables[key] || "");
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/* eslint-disable @typescript-eslint/no-unsafe-call */
|
||||
import { isObject } from "lodash";
|
||||
import { isObject } from "lodash-es";
|
||||
import { Prisma } from "@prisma/client";
|
||||
import { streamChatCompletion } from "./openai";
|
||||
import { wsConnection } from "~/utils/wsConnection";
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { omit } from "lodash";
|
||||
import { omit } from "lodash-es";
|
||||
import { env } from "~/env.mjs";
|
||||
|
||||
import OpenAI from "openai";
|
||||
|
||||
95
src/server/utils/runOneEval.ts
Normal file
95
src/server/utils/runOneEval.ts
Normal file
@@ -0,0 +1,95 @@
|
||||
import { type Evaluation, type ModelOutput, type TestScenario } from "@prisma/client";
|
||||
import { type ChatCompletion } from "openai/resources/chat";
|
||||
import { type VariableMap, fillTemplate, escapeRegExp, escapeQuotes } from "./fillTemplate";
|
||||
import { openai } from "./openai";
|
||||
import dedent from "dedent";
|
||||
|
||||
export const runGpt4Eval = async (
|
||||
evaluation: Evaluation,
|
||||
scenario: TestScenario,
|
||||
message: ChatCompletion.Choice.Message,
|
||||
): Promise<{ result: number; details: string }> => {
|
||||
const output = await openai.chat.completions.create({
|
||||
model: "gpt-4-0613",
|
||||
messages: [
|
||||
{
|
||||
role: "system",
|
||||
content: dedent`
|
||||
You are a highly intelligent AI model and have been tasked with evaluating the quality of a simpler model. Your objective is to determine whether the simpler model has produced a successful and correct output. You should return "true" if the output was successful and "false" if it was not. Pay more attention to the semantics of the output than the formatting. Success is defined in the following terms:
|
||||
---
|
||||
${evaluation.value}
|
||||
`,
|
||||
},
|
||||
{
|
||||
role: "user",
|
||||
content: `Scenario:\n---\n${JSON.stringify(scenario.variableValues, null, 2)}`,
|
||||
},
|
||||
{
|
||||
role: "user",
|
||||
content: `The full output of the simpler message:\n---\n${JSON.stringify(
|
||||
message.content ?? message.function_call,
|
||||
null,
|
||||
2,
|
||||
)}`,
|
||||
},
|
||||
],
|
||||
function_call: {
|
||||
name: "report_success",
|
||||
},
|
||||
functions: [
|
||||
{
|
||||
name: "report_success",
|
||||
parameters: {
|
||||
type: "object",
|
||||
required: ["thoughts", "success"],
|
||||
properties: {
|
||||
thoughts: {
|
||||
type: "string",
|
||||
description: "Explain your reasoning for considering this a pass or fail",
|
||||
},
|
||||
success: {
|
||||
type: "boolean",
|
||||
description:
|
||||
"Whether the simpler model successfully completed the task for this scenario",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
try {
|
||||
const out = JSON.parse(output.choices[0]?.message?.function_call?.arguments ?? "");
|
||||
return { result: out.success ? 1 : 0, details: out.thoughts ?? JSON.stringify(out) };
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
return { result: 0, details: "Error parsing GPT-4 output" };
|
||||
}
|
||||
};
|
||||
|
||||
export const runOneEval = async (
|
||||
evaluation: Evaluation,
|
||||
scenario: TestScenario,
|
||||
modelOutput: ModelOutput,
|
||||
): Promise<{ result: number; details?: string }> => {
|
||||
const output = modelOutput.output as unknown as ChatCompletion;
|
||||
|
||||
const message = output?.choices?.[0]?.message;
|
||||
|
||||
if (!message) return { result: 0 };
|
||||
|
||||
const stringifiedMessage = message.content ?? JSON.stringify(message.function_call);
|
||||
|
||||
const matchRegex = escapeRegExp(
|
||||
fillTemplate(escapeQuotes(evaluation.value), scenario.variableValues as VariableMap),
|
||||
);
|
||||
|
||||
switch (evaluation.evalType) {
|
||||
case "CONTAINS":
|
||||
return { result: stringifiedMessage.match(matchRegex) !== null ? 1 : 0 };
|
||||
case "DOES_NOT_CONTAIN":
|
||||
return { result: stringifiedMessage.match(matchRegex) === null ? 1 : 0 };
|
||||
case "GPT4_EVAL":
|
||||
return await runGpt4Eval(evaluation, scenario, message);
|
||||
}
|
||||
};
|
||||
@@ -1,4 +1,4 @@
|
||||
import { isObject } from "lodash";
|
||||
import { isObject } from "lodash-es";
|
||||
import { type JSONSerializable } from "../types";
|
||||
|
||||
export const shouldStream = (config: JSONSerializable): boolean => {
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { useRouter } from "next/router";
|
||||
import { useCallback, useEffect, useState } from "react";
|
||||
import { type RefObject, useCallback, useEffect, useRef, useState } from "react";
|
||||
import { api } from "~/utils/api";
|
||||
|
||||
export const useExperiment = () => {
|
||||
@@ -49,3 +49,43 @@ export const useModifierKeyLabel = () => {
|
||||
}, []);
|
||||
return label;
|
||||
};
|
||||
|
||||
interface Dimensions {
|
||||
left: number;
|
||||
top: number;
|
||||
right: number;
|
||||
bottom: number;
|
||||
width: number;
|
||||
height: number;
|
||||
x: number;
|
||||
y: number;
|
||||
}
|
||||
|
||||
// get dimensions of an element
|
||||
export const useElementDimensions = (): [RefObject<HTMLElement>, Dimensions | undefined] => {
|
||||
const ref = useRef<HTMLElement>(null);
|
||||
const [dimensions, setDimensions] = useState<Dimensions | undefined>();
|
||||
|
||||
useEffect(() => {
|
||||
if (ref.current) {
|
||||
const observer = new ResizeObserver((entries) => {
|
||||
entries.forEach((entry) => {
|
||||
setDimensions(entry.contentRect);
|
||||
});
|
||||
});
|
||||
|
||||
const observedRef = ref.current;
|
||||
|
||||
observer.observe(observedRef);
|
||||
|
||||
// Cleanup the observer on component unmount
|
||||
return () => {
|
||||
if (observedRef) {
|
||||
observer.unobserve(observedRef);
|
||||
}
|
||||
};
|
||||
}
|
||||
}, []);
|
||||
|
||||
return [ref, dimensions];
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user