Compare commits

26 Commits

Author SHA1 Message Date
David Corbitt
999a4c08fa Fix lint and prettier 2023-07-18 11:11:20 -07:00
arcticfly
374d0237ee Escape characters in Regex evaluations, minor UI fixes (#56)
* Fix ScenariosHeader stickiness

* Move meta tag from _app.tsx to _document.tsx

* Show spinner when saving variant

* Escape quotes and regex in evaluations
2023-07-18 11:07:04 -07:00
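
A hedged sketch of what "escape quotes and regex in evaluations" likely means in practice (the helper names and exact behavior are assumptions, not the repo's actual code): scenario values substituted into a CONTAINS / DOES_NOT_CONTAIN match string get their regex metacharacters escaped so they match literally inside the pattern.

```ts
// Hedged sketch, not actual repo code: escape regex metacharacters in
// substituted scenario values before testing the match string against output.
function escapeRegExp(s: string): string {
  return s.replace(/[.*+?^${}()|[\]\\]/g, "\\$&");
}

function evalContains(
  matchString: string, // e.g. '"{{relevance}}"'
  vars: Record<string, string>, // scenario variable values
  output: string, // model output text
): boolean {
  const pattern = matchString.replace(/{{\s*(\w+)\s*}}/g, (_m, name: string) =>
    escapeRegExp(vars[name] ?? ""),
  );
  return new RegExp(pattern).test(output);
}

// evalContains('"{{relevance}}"', { relevance: "7/10" }, 'score: "7/10"') === true
```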
David Corbitt
b1f873623d Invalidate prompt stats after cell refetch 2023-07-18 09:45:11 -07:00
arcticfly
4131aa67d0 Continue polling VariantStats while LLM retrieval in progress, minor UI fixes (#54)
* Prevent zoom in on iOS

* Expand function return code background to fill cell

* Keep OutputStats on far right of cells

* Continue polling prompt stats while cells are retrieving from LLM

* Add comment to _document.tsx

* Fix prettier
2023-07-17 18:04:38 -07:00
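
A distilled sketch of the polling approach introduced here (written against generic @tanstack/react-query rather than the repo's tRPC hooks; the hook name is illustrative): poll every second while the cell is still being retrieved, then stop.

```ts
// Illustrative only — the real code lives in OutputCell.tsx / VariantStats.tsx.
import { useQuery } from "@tanstack/react-query";
import { useEffect, useState } from "react";

type Cell = { retrievalStatus: "PENDING" | "IN_PROGRESS" | "COMPLETE" | "ERROR" };

export function useCellWithPolling(fetchCell: () => Promise<Cell>) {
  const [refetchInterval, setRefetchInterval] = useState(0);
  const { data: cell } = useQuery({ queryKey: ["cell"], queryFn: fetchCell, refetchInterval });

  const awaitingOutput =
    !cell || cell.retrievalStatus === "PENDING" || cell.retrievalStatus === "IN_PROGRESS";

  // Poll while the worker is still retrieving; stop (interval 0) once done.
  useEffect(() => setRefetchInterval(awaitingOutput ? 1000 : 0), [awaitingOutput]);

  return cell;
}
```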
Kyle Corbitt
8e7a6d3ae2 Merge pull request #55 from OpenPipe/more-eval
Add GPT4 Evals
2023-07-17 18:01:47 -07:00
Kyle Corbitt
7d41e94ca2 cache eval outputs and add gpt4 eval 2023-07-17 17:55:36 -07:00
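
A rough sketch of what a GPT4_EVAL could look like (the prompt wording and the PASS/FAIL convention are assumptions; the repo's actual logic lives in its evaluation utilities and is cached per (modelOutput, evaluation) pair). The answer maps onto OutputEvaluation's 0 (fail) to 1 (pass) `result` field, with the raw reply kept as `details`.

```ts
// Hedged sketch using the openai v4 client (assumed); not repo code.
import OpenAI from "openai";

const openai = new OpenAI(); // assumes OPENAI_API_KEY in the environment

export async function runGpt4Eval(
  instructions: string, // Evaluation.value for a GPT4_EVAL
  scenario: Record<string, string>, // scenario variable values
  output: string, // model output being judged
): Promise<{ result: number; details: string }> {
  const completion = await openai.chat.completions.create({
    model: "gpt-4-0613",
    temperature: 0,
    messages: [
      { role: "system", content: instructions },
      {
        role: "user",
        content: `Scenario:\n${JSON.stringify(scenario, null, 2)}\n\nOutput to evaluate:\n${output}\n\nReply with PASS or FAIL and a one-sentence reason.`,
      },
    ],
  });
  const details = completion.choices[0]?.message?.content ?? "";
  return { result: /\bPASS\b/i.test(details) ? 1 : 0, details };
}
```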
Kyle Corbitt
011b12abb9 cache output evals 2023-07-17 17:52:30 -07:00
Kyle Corbitt
1ba18015bc Merge pull request #53 from OpenPipe/more-eval
Fix seeds and update eval field names
2023-07-17 14:26:29 -07:00
Kyle Corbitt
54369dba54 Fix seeds and update eval field names 2023-07-17 14:14:20 -07:00
arcticfly
6b84a59372 Properly catch completion errors (#51) 2023-07-17 10:50:25 -07:00
Kyle Corbitt
8db8aeacd3 Replace function chrome with comment
Use a block comment to explain the expected prompt formatting instead of function chrome. The advantage is that once a user builds a mental model of how OpenPipe works, they can simply delete the comment, rather than having the function chrome sit around taking up space in the UI forever.
2023-07-17 10:30:22 -07:00
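
Concretely (an illustrative constructFn based on the comment added in the diff below — `prompt` and `scenario` are free variables supplied by the editor's sandbox, not standalone code): instead of fixed `function constructPrompt(scenario) { … return prompt; }` wrapper lines rendered around the editor, the editable code itself now opens with a comment the user can delete.

```ts
/**
 * Use Javascript to define an OpenAI chat completion
 * (https://platform.openai.com/docs/api-reference/chat/create) and
 * assign it to the `prompt` variable.
 *
 * You have access to the current scenario in the `scenario` variable.
 */
prompt = {
  model: "gpt-3.5-turbo-0613",
  messages: [{ role: "user", content: `What is the capital of ${scenario.country}?` }],
  temperature: 0,
};
```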
Kyle Corbitt
64bd71e370 Merge pull request #50 from OpenPipe/remove-default
remove the default value for PromptVariant.model
2023-07-14 17:55:38 -07:00
Kyle Corbitt
ca21a7af06 Run checks on main
This will (1) make sure that anything we push directly passes CI, and (2) cache the pnpm store on the main branch, which will make it available to PR runs as well and hopefully speed up CI a bit (see https://stackoverflow.com/a/75250061).
2023-07-14 17:49:20 -07:00
Kyle Corbitt
3b99b7bd2b remove the default value for PromptVariant.model
We should be explicit about setting the appropriate model so it always matches the constructFn.
2023-07-14 17:43:52 -07:00
Kyle Corbitt
0c3bdbe4f2 Merge pull request #49 from OpenPipe/save-button
Make save button disappear on save
2023-07-14 17:39:06 -07:00
Kyle Corbitt
74c201d3a8 Make save button disappear on save
Fixes a bug where the "Save" button often wouldn't disappear as expected the first time you clicked it.
2023-07-14 17:35:57 -07:00
David Corbitt
ab9c721d09 Revert change to scenarios header 2023-07-14 17:51:12 -06:00
David Corbitt
0a2578a1d8 Update scenarios header negative margin 2023-07-14 17:41:51 -06:00
David Corbitt
1bebaff386 Merge branch 'main' of github.com:corbt/prompt-lab 2023-07-14 16:55:12 -06:00
David Corbitt
3bf5eaf4a2 Properly extract scenario id in new experiment creation 2023-07-14 16:55:09 -06:00
Kyle Corbitt
ded97f8bb9 fix lockfile 2023-07-14 15:55:01 -07:00
Kyle Corbitt
26ee8698be Make it so you can't delete the last prompt or scenario
There's no reason for an experiment to have 0 prompts or 0 scenarios, and it makes the UI look bad.
2023-07-14 15:49:42 -07:00
arcticfly
b98eb9b729 Trigger llm output retrieval on server (#39)
* Rename tables, add graphile workers, update types

* Add dev:worker command

* Update pnpm-lock.yaml

* Remove sentry config import from worker.ts

* Stop generating new cells in cell router get query

* Generate new cells for new scenarios, variants, and experiments

* Remove most error throwing from queryLLM.task.ts

* Remove promptVariantId and testScenarioId from ModelOutput

* Remove duplicate index from ModelOutput

* Move inputHash from cell to output

* Add TODO

* Add todo

* Show cost and time for each cell

* Always show output stats if there is output

* Trigger LLM outputs when scenario variables are updated

* Add newlines to ends of files

* Add another newline

* Cascade ModelOutput deletion

* Fix linting and prettier

* Return instead of throwing for non-pending cell

* Remove pnpm dev:worker from pnpm:dev

* Update pnpm-lock.yaml
2023-07-14 16:38:46 -06:00
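
A minimal sketch of the new server-side retrieval flow (the task name, payload shape, and wiring are assumptions based on the bullets above — the real task is queryLLM.task.ts and the worker entry point is src/server/tasks/worker.ts):

```ts
// Hedged sketch, not the repo's actual worker code.
import { run, quickAddJob, type Task } from "graphile-worker";

type QueryLLMPayload = { cellId: string };

// Task run by the worker process: fetch the completion for a pending cell.
const queryLLM: Task = async (payload, helpers) => {
  const { cellId } = payload as QueryLLMPayload;
  helpers.logger.info(`Retrieving LLM output for cell ${cellId}`);
  // 1. mark the ScenarioVariantCell IN_PROGRESS
  // 2. construct the prompt and call OpenAI (streaming to cell.streamingChannel if set)
  // 3. store a ModelOutput row and mark the cell COMPLETE
  // 4. on a rate-limit error, set retryTime and leave the cell for a later attempt
};

// Worker process — roughly what `pnpm dev:worker` starts:
await run({
  connectionString: process.env.DATABASE_URL,
  concurrency: 5,
  taskList: { queryLLM },
});

// Web process — enqueue a job whenever a new pending cell is generated:
await quickAddJob(
  { connectionString: process.env.DATABASE_URL },
  "queryLLM",
  { cellId: "00000000-0000-0000-0000-000000000000" } satisfies QueryLLMPayload,
);
```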
Kyle Corbitt
032c07ec65 Merge pull request #45 from OpenPipe/node-version
warn folks if they use a lower node version
2023-07-14 15:03:49 -07:00
Kyle Corbitt
80c0d13bb9 warn folks if they use a lower node version 2023-07-14 14:59:33 -07:00
Kyle Corbitt
f7c94be3f6 Merge pull request #44 from OpenPipe/strip-types
Strip types from prompt variants
2023-07-14 14:07:07 -07:00
57 changed files with 1872 additions and 776 deletions

View File

@@ -3,6 +3,8 @@ name: CI checks
on:
pull_request:
branches: [main]
push:
branches: [main]
jobs:
run-checks:

.tool-versions (new file, 1 line)
View File

@@ -0,0 +1 @@
nodejs 20.2.0

View File

@@ -3,15 +3,21 @@
"type": "module",
"version": "0.1.0",
"license": "Apache-2.0",
"engines": {
"node": ">=20.0.0",
"pnpm": ">=8.6.1"
},
"scripts": {
"build": "next build",
"dev:next": "next dev",
"dev:wss": "pnpm tsx --watch src/wss-server.ts",
"dev:worker": "NODE_ENV='development' pnpm tsx --watch src/server/tasks/worker.ts",
"dev": "concurrently --kill-others 'pnpm dev:next' 'pnpm dev:wss'",
"postinstall": "prisma generate",
"lint": "next lint",
"start": "next start",
"codegen": "tsx src/codegen/export-openai-types.ts"
"codegen": "tsx src/codegen/export-openai-types.ts",
"seed": "tsx prisma/seed.ts"
},
"dependencies": {
"@babel/preset-typescript": "^7.22.5",
@@ -40,10 +46,11 @@
"express": "^4.18.2",
"framer-motion": "^10.12.17",
"gpt-tokens": "^1.0.10",
"graphile-worker": "^0.13.0",
"immer": "^10.0.2",
"isolated-vm": "^4.5.0",
"json-stringify-pretty-compact": "^4.0.0",
"lodash": "^4.17.21",
"lodash-es": "^4.17.21",
"next": "^13.4.2",
"next-auth": "^4.22.1",
"nextjs-routes": "^2.0.1",
@@ -71,7 +78,7 @@
"@types/cors": "^2.8.13",
"@types/eslint": "^8.37.0",
"@types/express": "^4.17.17",
"@types/lodash": "^4.14.195",
"@types/lodash-es": "^4.17.8",
"@types/node": "^18.16.0",
"@types/pluralize": "^0.0.30",
"@types/react": "^18.2.6",
@@ -94,6 +101,6 @@
"initVersion": "7.14.0"
},
"prisma": {
"seed": "tsx prisma/seed.ts"
"seed": "pnpm seed"
}
}

pnpm-lock.yaml (generated, 524 changed lines)

File diff suppressed because it is too large.

View File

@@ -0,0 +1,49 @@
-- Drop the foreign key constraints on the original ModelOutput
ALTER TABLE "ModelOutput" DROP CONSTRAINT "ModelOutput_promptVariantId_fkey";
ALTER TABLE "ModelOutput" DROP CONSTRAINT "ModelOutput_testScenarioId_fkey";
-- Rename the old table
ALTER TABLE "ModelOutput" RENAME TO "ScenarioVariantCell";
ALTER TABLE "ScenarioVariantCell" RENAME CONSTRAINT "ModelOutput_pkey" TO "ScenarioVariantCell_pkey";
ALTER INDEX "ModelOutput_inputHash_idx" RENAME TO "ScenarioVariantCell_inputHash_idx";
ALTER INDEX "ModelOutput_promptVariantId_testScenarioId_key" RENAME TO "ScenarioVariantCell_promptVariantId_testScenarioId_key";
-- Add the new fields to the renamed table
ALTER TABLE "ScenarioVariantCell" ADD COLUMN "retryTime" TIMESTAMP(3);
ALTER TABLE "ScenarioVariantCell" ADD COLUMN "streamingChannel" TEXT;
ALTER TABLE "ScenarioVariantCell" ALTER COLUMN "inputHash" DROP NOT NULL;
ALTER TABLE "ScenarioVariantCell" ALTER COLUMN "output" DROP NOT NULL,
ALTER COLUMN "statusCode" DROP NOT NULL,
ALTER COLUMN "timeToComplete" DROP NOT NULL;
-- Create the new table
CREATE TABLE "ModelOutput" (
"id" UUID NOT NULL,
"inputHash" TEXT NOT NULL,
"output" JSONB NOT NULL,
"timeToComplete" INTEGER NOT NULL DEFAULT 0,
"promptTokens" INTEGER,
"completionTokens" INTEGER,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
"scenarioVariantCellId" UUID
);
-- Move inputHash index
DROP INDEX "ScenarioVariantCell_inputHash_idx";
CREATE INDEX "ModelOutput_inputHash_idx" ON "ModelOutput"("inputHash");
CREATE UNIQUE INDEX "ModelOutput_scenarioVariantCellId_key" ON "ModelOutput"("scenarioVariantCellId");
ALTER TABLE "ModelOutput" ADD CONSTRAINT "ModelOutput_scenarioVariantCellId_fkey" FOREIGN KEY ("scenarioVariantCellId") REFERENCES "ScenarioVariantCell"("id") ON DELETE CASCADE ON UPDATE CASCADE;
ALTER TABLE "ModelOutput" ALTER COLUMN "scenarioVariantCellId" SET NOT NULL,
ADD CONSTRAINT "ModelOutput_pkey" PRIMARY KEY ("id");
ALTER TABLE "ScenarioVariantCell" ADD CONSTRAINT "ScenarioVariantCell_promptVariantId_fkey" FOREIGN KEY ("promptVariantId") REFERENCES "PromptVariant"("id") ON DELETE CASCADE ON UPDATE CASCADE;
ALTER TABLE "ScenarioVariantCell" ADD CONSTRAINT "ScenarioVariantCell_testScenarioId_fkey" FOREIGN KEY ("testScenarioId") REFERENCES "TestScenario"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- CreateEnum
CREATE TYPE "CellRetrievalStatus" AS ENUM ('PENDING', 'IN_PROGRESS', 'COMPLETE', 'ERROR');
-- AlterTable
ALTER TABLE "ScenarioVariantCell" ADD COLUMN "retrievalStatus" "CellRetrievalStatus" NOT NULL DEFAULT 'COMPLETE';

View File

@@ -0,0 +1,2 @@
-- AlterTable
ALTER TABLE "PromptVariant" ADD COLUMN "model" TEXT NOT NULL DEFAULT 'gpt-3.5-turbo';

View File

@@ -0,0 +1,2 @@
-- AlterTable
ALTER TABLE "PromptVariant" ALTER COLUMN "model" DROP DEFAULT;

View File

@@ -0,0 +1,24 @@
/*
Warnings:
- You are about to rename the column `matchString` on the `Evaluation` table. If there is any code or views referring to the old name, they will break.
- You are about to rename the column `matchType` on the `Evaluation` table. If there is any code or views referring to the old name, they will break.
- You are about to rename the column `name` on the `Evaluation` table. If there is any code or views referring to the old name, they will break.
- You are about to rename the enum `EvaluationMatchType` to `EvalType`. If there is any code or views referring to the old name, they will break.
*/
-- RenameEnum
ALTER TYPE "EvaluationMatchType" RENAME TO "EvalType";
-- AlterTable
ALTER TABLE "Evaluation" RENAME COLUMN "matchString" TO "value";
ALTER TABLE "Evaluation" RENAME COLUMN "matchType" TO "evalType";
ALTER TABLE "Evaluation" RENAME COLUMN "name" TO "label";
-- AlterColumnType
ALTER TABLE "Evaluation" ALTER COLUMN "evalType" TYPE "EvalType" USING "evalType"::text::"EvalType";
-- SetNotNullConstraint
ALTER TABLE "Evaluation" ALTER COLUMN "evalType" SET NOT NULL;
ALTER TABLE "Evaluation" ALTER COLUMN "label" SET NOT NULL;
ALTER TABLE "Evaluation" ALTER COLUMN "value" SET NOT NULL;

View File

@@ -0,0 +1,39 @@
/*
Warnings:
- You are about to drop the `EvaluationResult` table. If the table is not empty, all the data it contains will be lost.
*/
-- AlterEnum
ALTER TYPE "EvalType" ADD VALUE 'GPT4_EVAL';
-- DropForeignKey
ALTER TABLE "EvaluationResult" DROP CONSTRAINT "EvaluationResult_evaluationId_fkey";
-- DropForeignKey
ALTER TABLE "EvaluationResult" DROP CONSTRAINT "EvaluationResult_promptVariantId_fkey";
-- DropTable
DROP TABLE "EvaluationResult";
-- CreateTable
CREATE TABLE "OutputEvaluation" (
"id" UUID NOT NULL,
"result" DOUBLE PRECISION NOT NULL,
"details" TEXT,
"modelOutputId" UUID NOT NULL,
"evaluationId" UUID NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
CONSTRAINT "OutputEvaluation_pkey" PRIMARY KEY ("id")
);
-- CreateIndex
CREATE UNIQUE INDEX "OutputEvaluation_modelOutputId_evaluationId_key" ON "OutputEvaluation"("modelOutputId", "evaluationId");
-- AddForeignKey
ALTER TABLE "OutputEvaluation" ADD CONSTRAINT "OutputEvaluation_modelOutputId_fkey" FOREIGN KEY ("modelOutputId") REFERENCES "ModelOutput"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "OutputEvaluation" ADD CONSTRAINT "OutputEvaluation_evaluationId_fkey" FOREIGN KEY ("evaluationId") REFERENCES "Evaluation"("id") ON DELETE CASCADE ON UPDATE CASCADE;

View File

@@ -2,8 +2,7 @@
// learn more about it in the docs: https://pris.ly/d/prisma-schema
generator client {
provider = "prisma-client-js"
previewFeatures = ["jsonProtocol"]
provider = "prisma-client-js"
}
datasource db {
@@ -30,7 +29,7 @@ model PromptVariant {
label String
constructFn String
model String @default("gpt-3.5-turbo")
model String
uiId String @default(uuid()) @db.Uuid
visible Boolean @default(true)
@@ -39,10 +38,9 @@ model PromptVariant {
experimentId String @db.Uuid
experiment Experiment @relation(fields: [experimentId], references: [id], onDelete: Cascade)
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
ModelOutput ModelOutput[]
EvaluationResult EvaluationResult[]
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
scenarioVariantCells ScenarioVariantCell[]
@@index([uiId])
}
@@ -59,9 +57,9 @@ model TestScenario {
experimentId String @db.Uuid
experiment Experiment @relation(fields: [experimentId], references: [id], onDelete: Cascade)
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
ModelOutput ModelOutput[]
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
scenarioVariantCells ScenarioVariantCell[]
}
model TemplateVariable {
@@ -76,17 +74,28 @@ model TemplateVariable {
updatedAt DateTime @updatedAt
}
model ModelOutput {
enum CellRetrievalStatus {
PENDING
IN_PROGRESS
COMPLETE
ERROR
}
model ScenarioVariantCell {
id String @id @default(uuid()) @db.Uuid
inputHash String
output Json
statusCode Int
errorMessage String?
timeToComplete Int @default(0)
inputHash String? // TODO: Remove once migration is complete
output Json? // TODO: Remove once migration is complete
statusCode Int?
errorMessage String?
timeToComplete Int? @default(0) // TODO: Remove once migration is complete
retryTime DateTime?
streamingChannel String?
retrievalStatus CellRetrievalStatus @default(COMPLETE)
promptTokens Int? // Added promptTokens field
completionTokens Int? // Added completionTokens field
promptTokens Int? // TODO: Remove once migration is complete
completionTokens Int? // TODO: Remove once migration is complete
modelOutput ModelOutput?
promptVariantId String @db.Uuid
promptVariant PromptVariant @relation(fields: [promptVariantId], references: [id], onDelete: Cascade)
@@ -98,45 +107,66 @@ model ModelOutput {
updatedAt DateTime @updatedAt
@@unique([promptVariantId, testScenarioId])
}
model ModelOutput {
id String @id @default(uuid()) @db.Uuid
inputHash String
output Json
timeToComplete Int @default(0)
promptTokens Int?
completionTokens Int?
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
scenarioVariantCellId String @db.Uuid
scenarioVariantCell ScenarioVariantCell @relation(fields: [scenarioVariantCellId], references: [id], onDelete: Cascade)
outputEvaluation OutputEvaluation[]
@@unique([scenarioVariantCellId])
@@index([inputHash])
}
enum EvaluationMatchType {
enum EvalType {
CONTAINS
DOES_NOT_CONTAIN
GPT4_EVAL
}
model Evaluation {
id String @id @default(uuid()) @db.Uuid
name String
matchString String
matchType EvaluationMatchType
label String
evalType EvalType
value String
experimentId String @db.Uuid
experiment Experiment @relation(fields: [experimentId], references: [id], onDelete: Cascade)
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
EvaluationResult EvaluationResult[]
OutputEvaluation OutputEvaluation[]
}
model EvaluationResult {
model OutputEvaluation {
id String @id @default(uuid()) @db.Uuid
passCount Int
failCount Int
// Number between 0 (fail) and 1 (pass)
result Float
details String?
modelOutputId String @db.Uuid
modelOutput ModelOutput @relation(fields: [modelOutputId], references: [id], onDelete: Cascade)
evaluationId String @db.Uuid
evaluation Evaluation @relation(fields: [evaluationId], references: [id], onDelete: Cascade)
promptVariantId String @db.Uuid
promptVariant PromptVariant @relation(fields: [promptVariantId], references: [id], onDelete: Cascade)
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
@@unique([evaluationId, promptVariantId])
@@unique([modelOutputId, evaluationId])
}
// Necessary for Next auth
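
The net effect of this schema change, sketched with the Prisma client (illustrative IDs and values, not repo code): a ScenarioVariantCell tracks retrieval state per (variant, scenario) pair, and the cached ModelOutput hangs off it.

```ts
// Hedged sketch of the new data model in use.
import { PrismaClient, type Prisma } from "@prisma/client";
import { createHash } from "crypto";

const prisma = new PrismaClient();

// Assumed to exist already (e.g. seeded variant and scenario IDs).
const promptVariantId = "11111111-1111-1111-1111-111111111111";
const testScenarioId = "22222222-2222-2222-2222-222222222222";

// 1. A cell is created in PENDING state when the pair first appears.
const cell = await prisma.scenarioVariantCell.create({
  data: { promptVariantId, testScenarioId, retrievalStatus: "PENDING" },
});

// 2. Once the worker has the completion, it stores a ModelOutput on the cell.
const prompt = { model: "gpt-3.5-turbo-0613", messages: [{ role: "user", content: "Hi" }] };
const completion = { choices: [{ message: { role: "assistant", content: "Hello!" } }] };

await prisma.modelOutput.create({
  data: {
    scenarioVariantCellId: cell.id,
    inputHash: createHash("sha256").update(JSON.stringify(prompt)).digest("hex"),
    output: completion as unknown as Prisma.InputJsonValue,
    timeToComplete: 1234,
  },
});

await prisma.scenarioVariantCell.update({
  where: { id: cell.id },
  data: { retrievalStatus: "COMPLETE" },
});
```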

View File

@@ -1,4 +1,6 @@
import { prisma } from "~/server/db";
import dedent from "dedent";
import { generateNewCell } from "~/server/utils/generateNewCell";
const experimentId = "11111111-1111-1111-1111-111111111111";
@@ -9,14 +11,14 @@ await prisma.experiment.deleteMany({
},
});
const experiment = await prisma.experiment.create({
await prisma.experiment.create({
data: {
id: experimentId,
label: "Country Capitals Example",
},
});
await prisma.modelOutput.deleteMany({
await prisma.scenarioVariantCell.deleteMany({
where: {
promptVariant: {
experimentId,
@@ -36,27 +38,35 @@ await prisma.promptVariant.createMany({
experimentId,
label: "Prompt Variant 1",
sortIndex: 0,
constructFn: `prompt = {
model: "gpt-3.5-turbo-0613",
messages: [{ role: "user", content: "What is the capital of {{country}}?" }],
temperature: 0,
}`,
model: "gpt-3.5-turbo-0613",
constructFn: dedent`
prompt = {
model: "gpt-3.5-turbo-0613",
messages: [
{
role: "user",
content: \`What is the capital of ${"$"}{scenario.country}?\`
}
],
temperature: 0,
}`,
},
{
experimentId,
label: "Prompt Variant 2",
sortIndex: 1,
constructFn: `prompt = {
model: "gpt-3.5-turbo-0613",
messages: [
{
role: "user",
content:
"What is the capital of {{country}}? Return just the city name and nothing else.",
},
],
temperature: 0,
}`,
model: "gpt-3.5-turbo-0613",
constructFn: dedent`
prompt = {
model: "gpt-3.5-turbo-0613",
messages: [
{
role: "user",
content: \`What is the capital of ${"$"}{scenario.country}? Return just the city name and nothing else.\`
}
],
temperature: 0,
}`,
},
],
});
@@ -107,3 +117,26 @@ await prisma.testScenario.createMany({
},
],
});
const variants = await prisma.promptVariant.findMany({
where: {
experimentId,
},
});
const scenarios = await prisma.testScenario.findMany({
where: {
experimentId,
},
});
await Promise.all(
variants
.flatMap((variant) =>
scenarios.map((scenario) => ({
promptVariantId: variant.id,
testScenarioId: scenario.id,
})),
)
.map((cell) => generateNewCell(cell.promptVariantId, cell.testScenarioId)),
);

View File

@@ -12,6 +12,7 @@ await prisma.promptVariant.createMany({
{
experimentId: functionCallsExperiment.id,
label: "No Fn Calls",
model: "gpt-3.5-turbo-0613",
constructFn: `prompt = {
model: "gpt-3.5-turbo-0613",
messages: [
@@ -30,6 +31,7 @@ await prisma.promptVariant.createMany({
{
experimentId: functionCallsExperiment.id,
label: "Fn Calls",
model: "gpt-3.5-turbo-0613",
constructFn: `prompt = {
model: "gpt-3.5-turbo-0613",
messages: [
@@ -92,6 +94,7 @@ await prisma.promptVariant.createMany({
experimentId: redditExperiment.id,
label: "3.5 Base",
sortIndex: 0,
model: "gpt-3.5-turbo-0613",
constructFn: `prompt = {
model: "gpt-3.5-turbo-0613",
messages: [
@@ -107,6 +110,7 @@ await prisma.promptVariant.createMany({
experimentId: redditExperiment.id,
label: "4 Base",
sortIndex: 1,
model: "gpt-3.5-turbo-0613",
constructFn: `prompt = {
model: "gpt-4-0613",
messages: [
@@ -122,6 +126,7 @@ await prisma.promptVariant.createMany({
experimentId: redditExperiment.id,
label: "3.5 CoT + Functions",
sortIndex: 2,
model: "gpt-3.5-turbo-0613",
constructFn: `prompt = {
model: "gpt-3.5-turbo-0613",
messages: [
@@ -178,9 +183,9 @@ await prisma.templateVariable.createMany({
await prisma.evaluation.create({
data: {
experimentId: redditExperiment.id,
name: "Relevance Accuracy",
matchType: "CONTAINS",
matchString: '"{{relevance}}"',
label: "Relevance Accuracy",
evalType: "CONTAINS",
value: '"{{relevance}}"',
},
});
@@ -1119,12 +1124,3 @@ await prisma.testScenario.createMany({
variableValues: vars,
})),
});
// await prisma.evaluation.create({
// data: {
// experimentId: redditExperiment.id,
// name: "Scores Match",
// matchType: "CONTAINS",
// matchString: "{{score}}",
// },
// });

View File

@@ -2,7 +2,7 @@ import fs from "fs";
import path from "path";
import openapiTS, { type OpenAPI3 } from "openapi-typescript";
import YAML from "yaml";
import _ from "lodash";
import { pick } from "lodash-es";
import assert from "assert";
const OPENAPI_URL =
@@ -31,7 +31,7 @@ modelProperty.oneOf = undefined;
delete schema["paths"];
assert(schema.components?.schemas);
schema.components.schemas = _.pick(schema.components?.schemas, [
schema.components.schemas = pick(schema.components?.schemas, [
"CreateChatCompletionRequest",
"ChatCompletionRequestMessage",
"ChatCompletionFunctions",

View File

@@ -4,7 +4,7 @@ import React from "react";
export const AutoResizeTextarea: React.ForwardRefRenderFunction<
HTMLTextAreaElement,
TextareaProps
TextareaProps & { minRows?: number }
> = (props, ref) => {
return (
<Textarea

View File

@@ -11,14 +11,16 @@ import {
FormLabel,
Select,
FormHelperText,
Code,
} from "@chakra-ui/react";
import { type Evaluation, EvaluationMatchType } from "@prisma/client";
import { type Evaluation, EvalType } from "@prisma/client";
import { useCallback, useState } from "react";
import { BsPencil, BsX } from "react-icons/bs";
import { api } from "~/utils/api";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
import AutoResizeTextArea from "../AutoResizeTextArea";
type EvalValues = Pick<Evaluation, "name" | "matchString" | "matchType">;
type EvalValues = Pick<Evaluation, "label" | "value" | "evalType">;
export function EvaluationEditor(props: {
evaluation: Evaluation | null;
@@ -27,35 +29,35 @@ export function EvaluationEditor(props: {
onCancel: () => void;
}) {
const [values, setValues] = useState<EvalValues>({
name: props.evaluation?.name ?? props.defaultName ?? "",
matchString: props.evaluation?.matchString ?? "",
matchType: props.evaluation?.matchType ?? "CONTAINS",
label: props.evaluation?.label ?? props.defaultName ?? "",
value: props.evaluation?.value ?? "",
evalType: props.evaluation?.evalType ?? "CONTAINS",
});
return (
<VStack borderTopWidth={1} borderColor="gray.200" py={4}>
<HStack w="100%">
<FormControl flex={1}>
<FormLabel fontSize="sm">Evaluation Name</FormLabel>
<FormLabel fontSize="sm">Eval Name</FormLabel>
<Input
size="sm"
value={values.name}
onChange={(e) => setValues((values) => ({ ...values, name: e.target.value }))}
value={values.label}
onChange={(e) => setValues((values) => ({ ...values, label: e.target.value }))}
/>
</FormControl>
<FormControl flex={1}>
<FormLabel fontSize="sm">Match Type</FormLabel>
<FormLabel fontSize="sm">Eval Type</FormLabel>
<Select
size="sm"
value={values.matchType}
value={values.evalType}
onChange={(e) =>
setValues((values) => ({
...values,
matchType: e.target.value as EvaluationMatchType,
evalType: e.target.value as EvalType,
}))
}
>
{Object.values(EvaluationMatchType).map((type) => (
{Object.values(EvalType).map((type) => (
<option key={type} value={type}>
{type}
</option>
@@ -63,17 +65,37 @@ export function EvaluationEditor(props: {
</Select>
</FormControl>
</HStack>
<FormControl>
<FormLabel fontSize="sm">Match String</FormLabel>
<Input
size="sm"
value={values.matchString}
onChange={(e) => setValues((values) => ({ ...values, matchString: e.target.value }))}
/>
<FormHelperText>
This string will be interpreted as a regex and checked against each model output.
</FormHelperText>
</FormControl>
{["CONTAINS", "DOES_NOT_CONTAIN"].includes(values.evalType) && (
<FormControl>
<FormLabel fontSize="sm">Match String</FormLabel>
<Input
size="sm"
value={values.value}
onChange={(e) => setValues((values) => ({ ...values, value: e.target.value }))}
/>
<FormHelperText>
This string will be interpreted as a regex and checked against each model output. You
can include scenario variables using <Code>{"{{curly_braces}}"}</Code>
</FormHelperText>
</FormControl>
)}
{values.evalType === "GPT4_EVAL" && (
<FormControl pt={2}>
<FormLabel fontSize="sm">GPT4 Instructions</FormLabel>
<AutoResizeTextArea
size="sm"
value={values.value}
onChange={(e) => setValues((values) => ({ ...values, value: e.target.value }))}
minRows={3}
/>
<FormHelperText>
Give instructions to GPT-4 for how to evaluate your prompt. It will have access to the
full scenario as well as the output it is evaluating. It will <strong>not</strong> have
access to the specific prompt variant, so be sure to be clear about the task you want it
to perform.
</FormHelperText>
</FormControl>
)}
<HStack alignSelf="flex-end">
<Button size="sm" onClick={props.onCancel} colorScheme="gray">
Cancel
@@ -125,6 +147,7 @@ export default function EditEvaluations() {
}
await utils.evaluations.list.invalidate();
await utils.promptVariants.stats.invalidate();
await utils.scenarioVariantCells.get.invalidate();
}, []);
const onCancel = useCallback(() => {
@@ -156,9 +179,9 @@ export default function EditEvaluations() {
align="center"
key={evaluation.id}
>
<Text fontWeight="bold">{evaluation.name}</Text>
<Text fontWeight="bold">{evaluation.label}</Text>
<Text flex={1}>
{evaluation.matchType}: &quot;{evaluation.matchString}&quot;
{evaluation.evalType}: &quot;{evaluation.value}&quot;
</Text>
<Button
variant="unstyled"

View File

@@ -1,4 +1,4 @@
import { Text, Button, HStack, Heading, Icon, Input, Stack, Code } from "@chakra-ui/react";
import { Text, Button, HStack, Heading, Icon, Input, Stack } from "@chakra-ui/react";
import { useState } from "react";
import { BsCheck, BsX } from "react-icons/bs";
import { api } from "~/utils/api";
@@ -36,8 +36,7 @@ export default function EditScenarioVars() {
<Heading size="sm">Scenario Variables</Heading>
<Stack spacing={2}>
<Text fontSize="sm">
Scenario variables can be used in your prompt variants as well as evaluations. Reference
them using <Code>{"{{curly_braces}}"}</Code>.
Scenario variables can be used in your prompt variants as well as evaluations.
</Text>
<HStack spacing={0}>
<Input

View File

@@ -0,0 +1,33 @@
import { Button, HStack, Icon } from "@chakra-ui/react";
import { BsArrowClockwise } from "react-icons/bs";
export const CellOptions = ({
refetchingOutput,
refetchOutput,
}: {
refetchingOutput: boolean;
refetchOutput: () => void;
}) => {
return (
<HStack justifyContent="flex-end" w="full">
{!refetchingOutput && (
<Button
size="xs"
w={4}
h={4}
py={4}
px={4}
minW={0}
borderRadius={8}
color="gray.500"
variant="ghost"
cursor="pointer"
onClick={refetchOutput}
aria-label="refetch output"
>
<Icon as={BsArrowClockwise} boxSize={4} />
</Button>
)}
</HStack>
);
};

View File

@@ -1,29 +1,21 @@
import { type ModelOutput } from "@prisma/client";
import { HStack, VStack, Text, Button, Icon } from "@chakra-ui/react";
import { type ScenarioVariantCell } from "@prisma/client";
import { VStack, Text } from "@chakra-ui/react";
import { useEffect, useState } from "react";
import { BsArrowClockwise } from "react-icons/bs";
import { rateLimitErrorMessage } from "~/sharedStrings";
import pluralize from "pluralize";
const MAX_AUTO_RETRIES = 3;
export const ErrorHandler = ({
output,
cell,
refetchOutput,
numPreviousTries,
}: {
output: ModelOutput;
cell: ScenarioVariantCell;
refetchOutput: () => void;
numPreviousTries: number;
}) => {
const [msToWait, setMsToWait] = useState(0);
const shouldAutoRetry =
output.errorMessage === rateLimitErrorMessage && numPreviousTries < MAX_AUTO_RETRIES;
useEffect(() => {
if (!shouldAutoRetry) return;
if (!cell.retryTime) return;
const initialWaitTime = calculateDelay(numPreviousTries);
const initialWaitTime = cell.retryTime.getTime() - Date.now();
const msModuloOneSecond = initialWaitTime % 1000;
let remainingTime = initialWaitTime - msModuloOneSecond;
setMsToWait(remainingTime);
@@ -35,7 +27,6 @@ export const ErrorHandler = ({
setMsToWait(remainingTime);
if (remainingTime <= 0) {
refetchOutput();
clearInterval(interval);
}
}, 1000);
@@ -45,32 +36,12 @@ export const ErrorHandler = ({
clearInterval(interval);
clearTimeout(timeout);
};
}, [shouldAutoRetry, setMsToWait, refetchOutput, numPreviousTries]);
}, [cell.retryTime, cell.statusCode, setMsToWait, refetchOutput]);
return (
<VStack w="full">
<HStack w="full" alignItems="flex-start" justifyContent="space-between">
<Text color="red.600" fontWeight="bold">
Error
</Text>
<Button
size="xs"
w={4}
h={4}
px={4}
py={4}
minW={0}
borderRadius={8}
variant="ghost"
cursor="pointer"
onClick={refetchOutput}
aria-label="refetch output"
>
<Icon as={BsArrowClockwise} boxSize={6} />
</Button>
</HStack>
<Text color="red.600" wordBreak="break-word">
{output.errorMessage}
{cell.errorMessage}
</Text>
{msToWait > 0 && (
<Text color="red.600" fontSize="sm">
@@ -80,12 +51,3 @@ export const ErrorHandler = ({
</VStack>
);
};
const MIN_DELAY = 500; // milliseconds
const MAX_DELAY = 5000; // milliseconds
function calculateDelay(numPreviousTries: number): number {
const baseDelay = Math.min(MAX_DELAY, MIN_DELAY * Math.pow(2, numPreviousTries));
const jitter = Math.random() * baseDelay;
return baseDelay + jitter;
}
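
With retry scheduling moved to the server, the removed client-side calculateDelay above suggests the worker now writes a retryTime onto the cell using the same exponential-backoff-with-jitter shape. A hedged sketch (the server-side details are an assumption):

```ts
// Hedged sketch mirroring the removed client helper (500 ms base, 5 s cap, jitter).
const MIN_DELAY = 500; // milliseconds
const MAX_DELAY = 5000; // milliseconds

export function nextRetryTime(numPreviousTries: number, now: number = Date.now()): Date {
  const baseDelay = Math.min(MAX_DELAY, MIN_DELAY * Math.pow(2, numPreviousTries));
  const jitter = Math.random() * baseDelay;
  return new Date(now + baseDelay + jitter);
}

// e.g. on a rate-limit error the worker might do:
// await prisma.scenarioVariantCell.update({
//   where: { id: cellId },
//   data: { retryTime: nextRetryTime(numPreviousTries) },
// });
```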

View File

@@ -1,17 +1,16 @@
import { type RouterOutputs, api } from "~/utils/api";
import { api } from "~/utils/api";
import { type PromptVariant, type Scenario } from "../types";
import { Spinner, Text, Box, Center, Flex } from "@chakra-ui/react";
import { Spinner, Text, Center, VStack } from "@chakra-ui/react";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
import SyntaxHighlighter from "react-syntax-highlighter";
import { docco } from "react-syntax-highlighter/dist/cjs/styles/hljs";
import stringify from "json-stringify-pretty-compact";
import { type ReactElement, useState, useEffect, useRef, useCallback } from "react";
import { type ReactElement, useState, useEffect } from "react";
import { type ChatCompletion } from "openai/resources/chat";
import { generateChannel } from "~/utils/generateChannel";
import { isObject } from "lodash";
import useSocket from "~/utils/useSocket";
import { OutputStats } from "./OutputStats";
import { ErrorHandler } from "./ErrorHandler";
import { CellOptions } from "./CellOptions";
export default function OutputCell({
scenario,
@@ -37,116 +36,116 @@ export default function OutputCell({
// if (variant.config === null || Object.keys(variant.config).length === 0)
// disabledReason = "Save your prompt variant to see output";
const outputMutation = api.outputs.get.useMutation();
const [output, setOutput] = useState<RouterOutputs["outputs"]["get"]>(null);
const [channel, setChannel] = useState<string | undefined>(undefined);
const [numPreviousTries, setNumPreviousTries] = useState(0);
const fetchMutex = useRef(false);
const [fetchOutput, fetchingOutput] = useHandledAsyncCallback(
async (forceRefetch?: boolean) => {
if (fetchMutex.current) return;
setNumPreviousTries((prev) => prev + 1);
fetchMutex.current = true;
setOutput(null);
const shouldStream =
isObject(variant) &&
"config" in variant &&
isObject(variant.config) &&
"stream" in variant.config &&
variant.config.stream === true;
const channel = shouldStream ? generateChannel() : undefined;
setChannel(channel);
const output = await outputMutation.mutateAsync({
scenarioId: scenario.id,
variantId: variant.id,
channel,
forceRefetch,
});
setOutput(output);
await utils.promptVariants.stats.invalidate();
fetchMutex.current = false;
},
[outputMutation, scenario.id, variant.id],
const [refetchInterval, setRefetchInterval] = useState(0);
const { data: cell, isLoading: queryLoading } = api.scenarioVariantCells.get.useQuery(
{ scenarioId: scenario.id, variantId: variant.id },
{ refetchInterval },
);
const hardRefetch = useCallback(() => fetchOutput(true), [fetchOutput]);
useEffect(fetchOutput, [scenario.id, variant.id]);
const { mutateAsync: hardRefetchMutate, isLoading: refetchingOutput } =
api.scenarioVariantCells.forceRefetch.useMutation();
const [hardRefetch] = useHandledAsyncCallback(async () => {
await hardRefetchMutate({ scenarioId: scenario.id, variantId: variant.id });
await utils.scenarioVariantCells.get.invalidate({
scenarioId: scenario.id,
variantId: variant.id,
});
await utils.promptVariants.stats.invalidate({
variantId: variant.id,
});
}, [hardRefetchMutate, scenario.id, variant.id]);
const fetchingOutput = queryLoading || refetchingOutput;
const awaitingOutput =
!cell ||
cell.retrievalStatus === "PENDING" ||
cell.retrievalStatus === "IN_PROGRESS" ||
refetchingOutput;
useEffect(() => setRefetchInterval(awaitingOutput ? 1000 : 0), [awaitingOutput]);
const modelOutput = cell?.modelOutput;
// Disconnect from socket if we're not streaming anymore
const streamedMessage = useSocket(fetchingOutput ? channel : undefined);
const streamedMessage = useSocket(cell?.streamingChannel);
const streamedContent = streamedMessage?.choices?.[0]?.message?.content;
if (!vars) return null;
if (disabledReason) return <Text color="gray.500">{disabledReason}</Text>;
if (fetchingOutput && !streamedMessage)
if (awaitingOutput && !streamedMessage)
return (
<Center h="100%" w="100%">
<Spinner />
</Center>
);
if (!output && !fetchingOutput) return <Text color="gray.500">Error retrieving output</Text>;
if (!cell && !fetchingOutput) return <Text color="gray.500">Error retrieving output</Text>;
if (output && output.errorMessage) {
return (
<ErrorHandler
output={output}
refetchOutput={hardRefetch}
numPreviousTries={numPreviousTries}
/>
);
if (cell && cell.errorMessage) {
return <ErrorHandler cell={cell} refetchOutput={hardRefetch} />;
}
const response = output?.output as unknown as ChatCompletion;
const response = modelOutput?.output as unknown as ChatCompletion;
const message = response?.choices?.[0]?.message;
if (output && message?.function_call) {
if (modelOutput && message?.function_call) {
const rawArgs = message.function_call.arguments ?? "null";
let parsedArgs: string;
try {
parsedArgs = JSON.parse(rawArgs);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} catch (e: any) {
parsedArgs = `Failed to parse arguments as JSON: '${rawArgs}' ERROR: ${e.message as string}`;
}
return (
<Box fontSize="xs" width="100%" flexWrap="wrap" overflowX="auto">
<SyntaxHighlighter
customStyle={{ overflowX: "unset" }}
language="json"
style={docco}
lineProps={{
style: { wordBreak: "break-all", whiteSpace: "pre-wrap" },
}}
wrapLines
>
{stringify(
{
function: message.function_call.name,
args: parsedArgs,
},
{ maxLength: 40 },
)}
</SyntaxHighlighter>
<OutputStats model={variant.model} modelOutput={output} scenario={scenario} />
</Box>
<VStack
w="100%"
h="100%"
fontSize="xs"
flexWrap="wrap"
overflowX="auto"
justifyContent="space-between"
>
<VStack w="full" flex={1} spacing={0}>
<CellOptions refetchingOutput={refetchingOutput} refetchOutput={hardRefetch} />
<SyntaxHighlighter
customStyle={{ overflowX: "unset", width: "100%", flex: 1 }}
language="json"
style={docco}
lineProps={{
style: { wordBreak: "break-all", whiteSpace: "pre-wrap" },
}}
wrapLines
>
{stringify(
{
function: message.function_call.name,
args: parsedArgs,
},
{ maxLength: 40 },
)}
</SyntaxHighlighter>
</VStack>
<OutputStats model={variant.model} modelOutput={modelOutput} scenario={scenario} />
</VStack>
);
}
const contentToDisplay = message?.content ?? streamedContent ?? JSON.stringify(output?.output);
const contentToDisplay =
message?.content ?? streamedContent ?? JSON.stringify(modelOutput?.output);
return (
<Flex w="100%" h="100%" direction="column" justifyContent="space-between" whiteSpace="pre-wrap">
{contentToDisplay}
{output && <OutputStats model={variant.model} modelOutput={output} scenario={scenario} />}
</Flex>
<VStack w="100%" h="100%" justifyContent="space-between" whiteSpace="pre-wrap">
<VStack w="full" alignItems="flex-start" spacing={0}>
<CellOptions refetchingOutput={refetchingOutput} refetchOutput={hardRefetch} />
<Text>{contentToDisplay}</Text>
</VStack>
{modelOutput && (
<OutputStats model={variant.model} modelOutput={modelOutput} scenario={scenario} />
)}
</VStack>
);
}

View File

@@ -1,30 +1,25 @@
import { type ModelOutput } from "@prisma/client";
import { type SupportedModel } from "~/server/types";
import { type Scenario } from "../types";
import { useExperiment } from "~/utils/hooks";
import { api } from "~/utils/api";
import { type RouterOutputs } from "~/utils/api";
import { calculateTokenCost } from "~/utils/calculateTokenCost";
import { evaluateOutput } from "~/server/utils/evaluateOutput";
import { HStack, Icon, Text } from "@chakra-ui/react";
import { HStack, Icon, Text, Tooltip } from "@chakra-ui/react";
import { BsCheck, BsClock, BsCurrencyDollar, BsX } from "react-icons/bs";
import { CostTooltip } from "~/components/tooltip/CostTooltip";
const SHOW_COST = false;
const SHOW_TIME = false;
const SHOW_COST = true;
const SHOW_TIME = true;
export const OutputStats = ({
model,
modelOutput,
scenario,
}: {
model: SupportedModel | string | null;
modelOutput: ModelOutput;
modelOutput: NonNullable<
NonNullable<RouterOutputs["scenarioVariantCells"]["get"]>["modelOutput"]
>;
scenario: Scenario;
}) => {
const timeToComplete = modelOutput.timeToComplete;
const experiment = useExperiment();
const evals =
api.evaluations.list.useQuery({ experimentId: experiment.data?.id ?? "" }).data ?? [];
const promptTokens = modelOutput.promptTokens;
const completionTokens = modelOutput.completionTokens;
@@ -35,22 +30,26 @@ export const OutputStats = ({
const cost = promptCost + completionCost;
if (!evals.length) return null;
return (
<HStack align="center" color="gray.500" fontSize="2xs" mt={{ base: 0, md: 1 }}>
<HStack w="full" align="center" color="gray.500" fontSize="2xs" mt={{ base: 0, md: 1 }}>
<HStack flex={1}>
{evals.map((evaluation) => {
const passed = evaluateOutput(modelOutput, scenario, evaluation);
{modelOutput.outputEvaluation.map((evaluation) => {
const passed = evaluation.result > 0.5;
return (
<HStack spacing={0} key={evaluation.id}>
<Text>{evaluation.name}</Text>
<Icon
as={passed ? BsCheck : BsX}
color={passed ? "green.500" : "red.500"}
boxSize={6}
/>
</HStack>
<Tooltip
isDisabled={!evaluation.details}
label={evaluation.details}
key={evaluation.id}
>
<HStack spacing={0}>
<Text>{evaluation.evaluation.label}</Text>
<Icon
as={passed ? BsCheck : BsX}
color={passed ? "green.500" : "red.500"}
boxSize={6}
/>
</HStack>
</Tooltip>
);
})}
</HStack>

View File

@@ -1,6 +1,6 @@
import { type DragEvent } from "react";
import { api } from "~/utils/api";
import { isEqual } from "lodash";
import { isEqual } from "lodash-es";
import { type Scenario } from "./types";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
import { useState } from "react";
@@ -13,10 +13,11 @@ import AutoResizeTextArea from "../AutoResizeTextArea";
export default function ScenarioEditor({
scenario,
hovered,
...props
}: {
scenario: Scenario;
hovered: boolean;
canHide: boolean;
}) {
const savedValues = scenario.variableValues as Record<string, string>;
const utils = api.useContext();
@@ -92,30 +93,34 @@ export default function ScenarioEditor({
onDrop={onReorder}
backgroundColor={isDragTarget ? "gray.100" : "transparent"}
>
<Stack alignSelf="flex-start" opacity={hovered ? 1 : 0} spacing={0}>
<Tooltip label="Hide scenario" hasArrow>
{/* for some reason the tooltip can't position itself properly relative to the icon without the wrapping box */}
<Button
variant="unstyled"
color="gray.400"
height="unset"
width="unset"
minW="unset"
onClick={onHide}
_hover={{
color: "gray.800",
cursor: "pointer",
}}
>
<Icon as={hidingInProgress ? Spinner : BsX} boxSize={6} />
</Button>
</Tooltip>
<Icon
as={RiDraggable}
boxSize={6}
color="gray.400"
_hover={{ color: "gray.800", cursor: "pointer" }}
/>
<Stack alignSelf="flex-start" opacity={props.hovered ? 1 : 0} spacing={0}>
{props.canHide && (
<>
<Tooltip label="Hide scenario" hasArrow>
{/* for some reason the tooltip can't position itself properly relative to the icon without the wrapping box */}
<Button
variant="unstyled"
color="gray.400"
height="unset"
width="unset"
minW="unset"
onClick={onHide}
_hover={{
color: "gray.800",
cursor: "pointer",
}}
>
<Icon as={hidingInProgress ? Spinner : BsX} boxSize={6} />
</Button>
</Tooltip>
<Icon
as={RiDraggable}
boxSize={6}
color="gray.400"
_hover={{ color: "gray.800", cursor: "pointer" }}
/>
</>
)}
</Stack>
{variableLabels.length === 0 ? (
<Box color="gray.500">{vars.data ? "No scenario variables configured" : "Loading..."}</Box>

View File

@@ -5,7 +5,11 @@ import OutputCell from "./OutputCell/OutputCell";
import ScenarioEditor from "./ScenarioEditor";
import type { PromptVariant, Scenario } from "./types";
const ScenarioRow = (props: { scenario: Scenario; variants: PromptVariant[] }) => {
const ScenarioRow = (props: {
scenario: Scenario;
variants: PromptVariant[];
canHide: boolean;
}) => {
const [isHovered, setIsHovered] = useState(false);
const highlightStyle = { backgroundColor: "gray.50" };
@@ -18,7 +22,7 @@ const ScenarioRow = (props: { scenario: Scenario; variants: PromptVariant[] }) =
sx={isHovered ? highlightStyle : undefined}
borderLeftWidth={1}
>
<ScenarioEditor scenario={props.scenario} hovered={isHovered} />
<ScenarioEditor scenario={props.scenario} hovered={isHovered} canHide={props.canHide} />
</GridItem>
{props.variants.map((variant) => (
<GridItem

View File

@@ -0,0 +1,49 @@
import { Button, GridItem, HStack, Heading } from "@chakra-ui/react";
import { cellPadding } from "../constants";
import { useElementDimensions } from "~/utils/hooks";
import { stickyHeaderStyle } from "./styles";
import { BsPencil } from "react-icons/bs";
import { useAppStore } from "~/state/store";
export const ScenariosHeader = ({
headerRows,
numScenarios,
}: {
headerRows: number;
numScenarios: number;
}) => {
const openDrawer = useAppStore((s) => s.openDrawer);
const [ref, dimensions] = useElementDimensions();
const topValue = dimensions ? `-${dimensions.height - 24}px` : "-455px";
return (
<GridItem
// eslint-disable-next-line @typescript-eslint/no-explicit-any
ref={ref as any}
display="flex"
alignItems="flex-end"
rowSpan={headerRows}
px={cellPadding.x}
py={cellPadding.y}
// Only display the part of the grid item that has content
sx={{ ...stickyHeaderStyle, top: topValue }}
>
<HStack w="100%">
<Heading size="xs" fontWeight="bold" flex={1}>
Scenarios ({numScenarios})
</Heading>
<Button
size="xs"
variant="ghost"
color="gray.500"
aria-label="Edit"
leftIcon={<BsPencil />}
onClick={openDrawer}
>
Edit Vars
</Button>
</HStack>
</GridItem>
);
};

View File

@@ -1,11 +1,11 @@
import { Box, Button, HStack, Tooltip, VStack, useToast } from "@chakra-ui/react";
import { Box, Button, HStack, Spinner, Tooltip, useToast, Text } from "@chakra-ui/react";
import { useRef, useEffect, useState, useCallback } from "react";
import { useHandledAsyncCallback, useModifierKeyLabel } from "~/utils/hooks";
import { type PromptVariant } from "./types";
import { api } from "~/utils/api";
import { useAppStore } from "~/state/store";
import { editorBackground } from "~/state/sharedVariantEditor.slice";
export default function VariantConfigEditor(props: { variant: PromptVariant }) {
export default function VariantEditor(props: { variant: PromptVariant }) {
const monaco = useAppStore.use.sharedVariantEditor.monaco();
const editorRef = useRef<ReturnType<NonNullable<typeof monaco>["editor"]["create"]> | null>(null);
const [editorId] = useState(() => `editor_${Math.random().toString(36).substring(7)}`);
@@ -17,15 +17,17 @@ export default function VariantConfigEditor(props: { variant: PromptVariant }) {
const checkForChanges = useCallback(() => {
if (!editorRef.current) return;
const currentConfig = editorRef.current.getValue();
setIsChanged(currentConfig !== lastSavedFn);
const currentFn = editorRef.current.getValue();
setIsChanged(currentFn.length > 0 && currentFn !== lastSavedFn);
}, [lastSavedFn]);
useEffect(checkForChanges, [checkForChanges, lastSavedFn]);
const replaceVariant = api.promptVariants.replaceVariant.useMutation();
const utils = api.useContext();
const toast = useToast();
const [onSave] = useHandledAsyncCallback(async () => {
const [onSave, saveInProgress] = useHandledAsyncCallback(async () => {
if (!editorRef.current) return;
await editorRef.current.getAction("editor.action.formatDocument")?.run();
@@ -75,9 +77,9 @@ export default function VariantConfigEditor(props: { variant: PromptVariant }) {
});
}
await utils.promptVariants.list.invalidate();
setIsChanged(false);
checkForChanges();
await utils.promptVariants.list.invalidate();
}, [checkForChanges]);
useEffect(() => {
@@ -130,19 +132,7 @@ export default function VariantConfigEditor(props: { variant: PromptVariant }) {
return (
<Box w="100%" pos="relative">
<VStack
spacing={0}
align="stretch"
fontSize="xs"
fontWeight="bold"
color="gray.600"
py={2}
bgColor={editorBackground}
>
<code>{`function constructPrompt(scenario: Scenario): Prompt {`}</code>
<div id={editorId} style={{ height: "300px", width: "100%" }}></div>
<code>{`return prompt; }`}</code>
</VStack>
<div id={editorId} style={{ height: "400px", width: "100%" }}></div>
{isChanged && (
<HStack pos="absolute" bottom={2} right={2}>
<Button
@@ -156,8 +146,8 @@ export default function VariantConfigEditor(props: { variant: PromptVariant }) {
Reset
</Button>
<Tooltip label={`${modifierKey} + Enter`}>
<Button size="sm" onClick={onSave} colorScheme="blue">
Save
<Button size="sm" onClick={onSave} colorScheme="blue" w={16} disabled={saveInProgress}>
{saveInProgress ? <Spinner boxSize={4} /> : <Text>Save</Text>}
</Button>
</Tooltip>
</HStack>

View File

@@ -8,7 +8,7 @@ import { RiDraggable } from "react-icons/ri";
import { cellPadding, headerMinHeight } from "../constants";
import AutoResizeTextArea from "../AutoResizeTextArea";
export default function VariantHeader(props: { variant: PromptVariant }) {
export default function VariantHeader(props: { variant: PromptVariant; canHide: boolean }) {
const utils = api.useContext();
const [isDragTarget, setIsDragTarget] = useState(false);
const [isInputHovered, setIsInputHovered] = useState(false);
@@ -95,11 +95,13 @@ export default function VariantHeader(props: { variant: PromptVariant }) {
onMouseEnter={() => setIsInputHovered(true)}
onMouseLeave={() => setIsInputHovered(false)}
/>
<Tooltip label="Hide Variant" hasArrow>
<Button variant="ghost" colorScheme="gray" size="sm" onClick={onHide}>
<Icon as={BsX} boxSize={6} />
</Button>
</Tooltip>
{props.canHide && (
<Tooltip label="Remove Variant" hasArrow>
<Button variant="ghost" colorScheme="gray" size="sm" onClick={onHide}>
<Icon as={BsX} boxSize={6} />
</Button>
</Tooltip>
)}
</HStack>
);
}

View File

@@ -1,12 +1,14 @@
import { HStack, Icon, Text, useToken } from "@chakra-ui/react";
import { HStack, Icon, Skeleton, Text, useToken } from "@chakra-ui/react";
import { type PromptVariant } from "./types";
import { cellPadding } from "../constants";
import { api } from "~/utils/api";
import chroma from "chroma-js";
import { BsCurrencyDollar } from "react-icons/bs";
import { CostTooltip } from "../tooltip/CostTooltip";
import { useEffect, useState } from "react";
export default function VariantStats(props: { variant: PromptVariant }) {
const [refetchInterval, setRefetchInterval] = useState(0);
const { data } = api.promptVariants.stats.useQuery(
{
variantId: props.variant.id,
@@ -19,10 +21,18 @@ export default function VariantStats(props: { variant: PromptVariant }) {
completionTokens: 0,
scenarioCount: 0,
outputCount: 0,
awaitingRetrievals: false,
},
refetchInterval,
},
);
// Poll every two seconds while we are waiting for LLM retrievals to finish
useEffect(
() => setRefetchInterval(data.awaitingRetrievals ? 2000 : 0),
[data.awaitingRetrievals],
);
const [passColor, neutralColor, failColor] = useToken("colors", [
"green.500",
"gray.500",
@@ -33,21 +43,25 @@ export default function VariantStats(props: { variant: PromptVariant }) {
const showNumFinished = data.scenarioCount > 0 && data.scenarioCount !== data.outputCount;
if (!(data.evalResults.length > 0) && !data.overallCost) return null;
return (
<HStack justifyContent="space-between" alignItems="center" mx="2" fontSize="xs">
<HStack
justifyContent="space-between"
alignItems="center"
mx="2"
fontSize="xs"
py={cellPadding.y}
>
{showNumFinished && (
<Text>
{data.outputCount} / {data.scenarioCount}
</Text>
)}
<HStack px={cellPadding.x} py={cellPadding.y}>
<HStack px={cellPadding.x}>
{data.evalResults.map((result) => {
const passedFrac = result.passCount / (result.passCount + result.failCount);
const passedFrac = result.passCount / result.totalCount;
return (
<HStack key={result.id}>
<Text>{result.evaluation.name}</Text>
<Text>{result.label}</Text>
<Text color={scale(passedFrac).hex()} fontWeight="bold">
{(passedFrac * 100).toFixed(1)}%
</Text>
@@ -55,17 +69,19 @@ export default function VariantStats(props: { variant: PromptVariant }) {
);
})}
</HStack>
{data.overallCost && (
{data.overallCost && !data.awaitingRetrievals ? (
<CostTooltip
promptTokens={data.promptTokens}
completionTokens={data.completionTokens}
cost={data.overallCost}
>
<HStack spacing={0} align="center" color="gray.500" my="2">
<HStack spacing={0} align="center" color="gray.500">
<Icon as={BsCurrencyDollar} />
<Text mr={1}>{data.overallCost.toFixed(3)}</Text>
</HStack>
</CostTooltip>
) : (
<Skeleton height={4} width={12} mr={1} />
)}
</HStack>
);

View File

@@ -1,28 +1,19 @@
import { Button, Grid, GridItem, HStack, Heading, type SystemStyleObject } from "@chakra-ui/react";
import { Grid, GridItem } from "@chakra-ui/react";
import { api } from "~/utils/api";
import NewScenarioButton from "./NewScenarioButton";
import NewVariantButton from "./NewVariantButton";
import ScenarioRow from "./ScenarioRow";
import VariantConfigEditor from "./VariantEditor";
import VariantEditor from "./VariantEditor";
import VariantHeader from "./VariantHeader";
import { cellPadding } from "../constants";
import { BsPencil } from "react-icons/bs";
import VariantStats from "./VariantStats";
import { useAppStore } from "~/state/store";
const stickyHeaderStyle: SystemStyleObject = {
position: "sticky",
top: "-1px",
backgroundColor: "#fff",
zIndex: 1,
};
import { ScenariosHeader } from "./ScenariosHeader";
import { stickyHeaderStyle } from "./styles";
export default function OutputsTable({ experimentId }: { experimentId: string | undefined }) {
const variants = api.promptVariants.list.useQuery(
{ experimentId: experimentId as string },
{ enabled: !!experimentId },
);
const openDrawer = useAppStore((s) => s.openDrawer);
const scenarios = api.scenarios.list.useQuery(
{ experimentId: experimentId as string },
@@ -49,36 +40,11 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
}}
fontSize="sm"
>
<GridItem
display="flex"
alignItems="flex-end"
rowSpan={headerRows}
px={cellPadding.x}
py={cellPadding.y}
// TODO: This is a hack to get the sticky header to work. It's not ideal because it's not responsive to the height of the header,
// so if the header height changes, this will need to be updated.
sx={{ ...stickyHeaderStyle, top: "-337px" }}
>
<HStack w="100%">
<Heading size="xs" fontWeight="bold" flex={1}>
Scenarios ({scenarios.data.length})
</Heading>
<Button
size="xs"
variant="ghost"
color="gray.500"
aria-label="Edit"
leftIcon={<BsPencil />}
onClick={openDrawer}
>
Edit Vars
</Button>
</HStack>
</GridItem>
<ScenariosHeader headerRows={headerRows} numScenarios={scenarios.data.length} />
{variants.data.map((variant) => (
<GridItem key={variant.uiId} padding={0} sx={stickyHeaderStyle} borderTopWidth={1}>
<VariantHeader variant={variant} />
<VariantHeader variant={variant} canHide={variants.data.length > 1} />
</GridItem>
))}
<GridItem
@@ -94,7 +60,7 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
{variants.data.map((variant) => (
<GridItem key={variant.uiId}>
<VariantConfigEditor variant={variant} />
<VariantEditor variant={variant} />
</GridItem>
))}
{variants.data.map((variant) => (
@@ -103,7 +69,12 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
</GridItem>
))}
{scenarios.data.map((scenario) => (
<ScenarioRow key={scenario.uiId} scenario={scenario} variants={variants.data} />
<ScenarioRow
key={scenario.uiId}
scenario={scenario}
variants={variants.data}
canHide={scenarios.data.length > 1}
/>
))}
<GridItem borderBottomWidth={0} borderRightWidth={0} w="100%" colSpan={allCols} padding={0}>
<NewScenarioButton />

View File

@@ -0,0 +1,8 @@
import { type SystemStyleObject } from "@chakra-ui/react";
export const stickyHeaderStyle: SystemStyleObject = {
position: "sticky",
top: "-1px",
backgroundColor: "#fff",
zIndex: 1,
};

View File

@@ -20,7 +20,6 @@ export const CostTooltip = ({
color="gray.800"
bgColor="gray.50"
borderWidth={1}
py={2}
hasArrow
shouldWrapChildren
label={

View File

@@ -6,18 +6,27 @@ import { ChakraProvider } from "@chakra-ui/react";
import theme from "~/utils/theme";
import Favicon from "~/components/Favicon";
import "~/utils/analytics";
import Head from "next/head";
const MyApp: AppType<{ session: Session | null }> = ({
Component,
pageProps: { session, ...pageProps },
}) => {
return (
<SessionProvider session={session}>
<Favicon />
<ChakraProvider theme={theme}>
<Component {...pageProps} />
</ChakraProvider>
</SessionProvider>
<>
<Head>
<meta
name="viewport"
content="width=device-width, initial-scale=1, maximum-scale=1, user-scalable=0"
/>
</Head>
<SessionProvider session={session}>
<Favicon />
<ChakraProvider theme={theme}>
<Component {...pageProps} />
</ChakraProvider>
</SessionProvider>
</>
);
};

View File

@@ -1,7 +1,7 @@
import { type GetServerSideProps } from "next";
// eslint-disable-next-line @typescript-eslint/require-await
export const getServerSideProps: GetServerSideProps = async (context) => {
export const getServerSideProps: GetServerSideProps = async () => {
return {
redirect: {
destination: "/experiments",

View File

@@ -1,11 +1,7 @@
import { type CompletionCreateParams } from "openai/resources/chat";
import { prisma } from "../db";
import { openai } from "../utils/openai";
import { pick } from "lodash";
function promptHasVariable(prompt: string, variableName: string) {
return prompt.includes(`{{${variableName}}}`);
}
import { pick } from "lodash-es";
type AxiosError = {
response?: {

View File

@@ -2,7 +2,7 @@ import { promptVariantsRouter } from "~/server/api/routers/promptVariants.router
import { createTRPCRouter } from "~/server/api/trpc";
import { experimentsRouter } from "./routers/experiments.router";
import { scenariosRouter } from "./routers/scenarios.router";
import { modelOutputsRouter } from "./routers/modelOutputs.router";
import { scenarioVariantCellsRouter } from "./routers/scenarioVariantCells.router";
import { templateVarsRouter } from "./routers/templateVariables.router";
import { evaluationsRouter } from "./routers/evaluations.router";
@@ -15,7 +15,7 @@ export const appRouter = createTRPCRouter({
promptVariants: promptVariantsRouter,
experiments: experimentsRouter,
scenarios: scenariosRouter,
outputs: modelOutputsRouter,
scenarioVariantCells: scenarioVariantCellsRouter,
templateVars: templateVarsRouter,
evaluations: evaluationsRouter,
});

View File

@@ -1,8 +1,8 @@
import { EvaluationMatchType } from "@prisma/client";
import { EvalType } from "@prisma/client";
import { z } from "zod";
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import { reevaluateEvaluation } from "~/server/utils/evaluations";
import { runAllEvals } from "~/server/utils/evaluations";
export const evaluationsRouter = createTRPCRouter({
list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
@@ -18,21 +18,24 @@ export const evaluationsRouter = createTRPCRouter({
.input(
z.object({
experimentId: z.string(),
name: z.string(),
matchString: z.string(),
matchType: z.nativeEnum(EvaluationMatchType),
label: z.string(),
value: z.string(),
evalType: z.nativeEnum(EvalType),
}),
)
.mutation(async ({ input }) => {
const evaluation = await prisma.evaluation.create({
await prisma.evaluation.create({
data: {
experimentId: input.experimentId,
name: input.name,
matchString: input.matchString,
matchType: input.matchType,
label: input.label,
value: input.value,
evalType: input.evalType,
},
});
await reevaluateEvaluation(evaluation);
// TODO: this may be a bad UX for slow evals (eg. GPT-4 evals) Maybe need
// to kick off a background job or something instead
await runAllEvals(input.experimentId);
}),
update: publicProcedure
@@ -40,24 +43,30 @@ export const evaluationsRouter = createTRPCRouter({
z.object({
id: z.string(),
updates: z.object({
name: z.string().optional(),
matchString: z.string().optional(),
matchType: z.nativeEnum(EvaluationMatchType).optional(),
label: z.string().optional(),
value: z.string().optional(),
evalType: z.nativeEnum(EvalType).optional(),
}),
}),
)
.mutation(async ({ input }) => {
await prisma.evaluation.update({
const evaluation = await prisma.evaluation.update({
where: { id: input.id },
data: {
name: input.updates.name,
matchString: input.updates.matchString,
matchType: input.updates.matchType,
label: input.updates.label,
value: input.updates.value,
evalType: input.updates.evalType,
},
});
await reevaluateEvaluation(
await prisma.evaluation.findUniqueOrThrow({ where: { id: input.id } }),
);
await prisma.outputEvaluation.deleteMany({
where: {
evaluationId: evaluation.id,
},
});
// Re-run all evals. Other eval results will already be cached, so this
// should only re-run the updated one.
await runAllEvals(evaluation.experimentId);
}),
delete: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => {

View File

@@ -2,6 +2,7 @@ import { z } from "zod";
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import dedent from "dedent";
import { generateNewCell } from "~/server/utils/generateNewCell";
export const experimentsRouter = createTRPCRouter({
list: publicProcedure.query(async () => {
@@ -64,28 +65,55 @@ export const experimentsRouter = createTRPCRouter({
},
});
await prisma.$transaction([
const [variant, _, scenario] = await prisma.$transaction([
prisma.promptVariant.create({
data: {
experimentId: exp.id,
label: "Prompt Variant 1",
sortIndex: 0,
constructFn: dedent`prompt = {
// The interpolated $ is necessary until dedent incorporates
// https://github.com/dmnd/dedent/pull/46
constructFn: dedent`
/**
* Use Javascript to define an OpenAI chat completion
* (https://platform.openai.com/docs/api-reference/chat/create) and
* assign it to the \`prompt\` variable.
*
* You have access to the current scenario in the \`scenario\`
* variable.
*/
prompt = {
model: "gpt-3.5-turbo-0613",
stream: true,
messages: [{ role: "system", content: "Return 'Ready to go!'" }],
}`,
messages: [
{
role: "system",
content: \`"Return 'this is output for the scenario "${"$"}{scenario.text}"'\`,
},
],
};`,
model: "gpt-3.5-turbo-0613",
},
}),
prisma.templateVariable.create({
data: {
experimentId: exp.id,
label: "text",
},
}),
prisma.testScenario.create({
data: {
experimentId: exp.id,
variableValues: {},
variableValues: {
text: "This is a test scenario.",
},
},
}),
]);
await generateNewCell(variant.id, scenario.id);
return exp;
}),
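
The seeded constructFn is stored as a plain string and only evaluated later by constructPrompt (see its diff further down), which injects a scenario binding and a let prompt declaration before running the user code. A rough sketch of what the seeded variant evaluates to for the default scenario, once the ${"$"} dedent workaround collapses into a normal template interpolation:

// Sketch only: the value `prompt` resolves to for the seeded scenario
// { text: "This is a test scenario." } after constructPrompt injects `scenario`.
const scenario = { text: "This is a test scenario." };

const prompt = {
  model: "gpt-3.5-turbo-0613",
  stream: true,
  messages: [
    {
      role: "system",
      content: `"Return 'this is output for the scenario "${scenario.text}"'`,
    },
  ],
};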

View File

@@ -1,101 +0,0 @@
import { z } from "zod";
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import crypto from "crypto";
import type { Prisma } from "@prisma/client";
import { reevaluateVariant } from "~/server/utils/evaluations";
import { getCompletion } from "~/server/utils/getCompletion";
import { constructPrompt } from "~/server/utils/constructPrompt";
import { type CompletionCreateParams } from "openai/resources/chat";
export const modelOutputsRouter = createTRPCRouter({
get: publicProcedure
.input(
z.object({
scenarioId: z.string(),
variantId: z.string(),
channel: z.string().optional(),
forceRefetch: z.boolean().optional(),
}),
)
.mutation(async ({ input }) => {
const existing = await prisma.modelOutput.findUnique({
where: {
promptVariantId_testScenarioId: {
promptVariantId: input.variantId,
testScenarioId: input.scenarioId,
},
},
});
if (existing && !input.forceRefetch) return existing;
const variant = await prisma.promptVariant.findUnique({
where: {
id: input.variantId,
},
});
const scenario = await prisma.testScenario.findUnique({
where: {
id: input.scenarioId,
},
});
if (!variant || !scenario) return null;
const prompt = await constructPrompt(variant, scenario.variableValues);
const inputHash = crypto.createHash("sha256").update(JSON.stringify(prompt)).digest("hex");
// TODO: we should probably only use this if temperature=0
const existingResponse = await prisma.modelOutput.findFirst({
where: { inputHash, errorMessage: null },
});
let modelResponse: Awaited<ReturnType<typeof getCompletion>>;
if (existingResponse) {
modelResponse = {
output: existingResponse.output as Prisma.InputJsonValue,
statusCode: existingResponse.statusCode,
errorMessage: existingResponse.errorMessage,
timeToComplete: existingResponse.timeToComplete,
promptTokens: existingResponse.promptTokens ?? undefined,
completionTokens: existingResponse.completionTokens ?? undefined,
};
} else {
try {
modelResponse = await getCompletion(
prompt as unknown as CompletionCreateParams,
input.channel,
);
} catch (e) {
console.error(e);
throw e;
}
}
const modelOutput = await prisma.modelOutput.upsert({
where: {
promptVariantId_testScenarioId: {
promptVariantId: input.variantId,
testScenarioId: input.scenarioId,
},
},
create: {
promptVariantId: input.variantId,
testScenarioId: input.scenarioId,
inputHash,
...modelResponse,
},
update: {
...modelResponse,
},
});
await reevaluateVariant(input.variantId);
return modelOutput;
}),
});

View File

@@ -1,7 +1,9 @@
import { isObject } from "lodash";
import dedent from "dedent";
import { isObject } from "lodash-es";
import { z } from "zod";
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import { generateNewCell } from "~/server/utils/generateNewCell";
import { OpenAIChatModel } from "~/server/types";
import { constructPrompt } from "~/server/utils/constructPrompt";
import userError from "~/server/utils/error";
@@ -30,11 +32,43 @@ export const promptVariantsRouter = createTRPCRouter({
throw new Error(`Prompt Variant with id ${input.variantId} does not exist`);
}
const evalResults = await prisma.evaluationResult.findMany({
where: {
promptVariantId: input.variantId,
const outputEvals = await prisma.outputEvaluation.groupBy({
by: ["evaluationId"],
_sum: {
result: true,
},
include: { evaluation: true },
_count: {
id: true,
},
where: {
modelOutput: {
scenarioVariantCell: {
promptVariant: {
id: input.variantId,
visible: true,
},
testScenario: {
visible: true,
},
},
},
},
});
const evals = await prisma.evaluation.findMany({
where: {
experimentId: variant.experimentId,
},
});
const evalResults = evals.map((evalItem) => {
const evalResult = outputEvals.find((outputEval) => outputEval.evaluationId === evalItem.id);
return {
id: evalItem.id,
label: evalItem.label,
passCount: evalResult?._sum?.result ?? 0,
totalCount: evalResult?._count?.id ?? 1,
};
});
const scenarioCount = await prisma.testScenario.count({
@@ -43,17 +77,24 @@ export const promptVariantsRouter = createTRPCRouter({
visible: true,
},
});
const outputCount = await prisma.modelOutput.count({
const outputCount = await prisma.scenarioVariantCell.count({
where: {
promptVariantId: input.variantId,
testScenario: { visible: true },
modelOutput: {
is: {},
},
},
});
const overallTokens = await prisma.modelOutput.aggregate({
where: {
promptVariantId: input.variantId,
testScenario: { visible: true },
scenarioVariantCell: {
promptVariantId: input.variantId,
testScenario: {
visible: true,
},
},
},
_sum: {
promptTokens: true,
@@ -68,7 +109,26 @@ export const promptVariantsRouter = createTRPCRouter({
const overallCost = overallPromptCost + overallCompletionCost;
return { evalResults, promptTokens, completionTokens, overallCost, scenarioCount, outputCount };
const awaitingRetrievals = !!(await prisma.scenarioVariantCell.findFirst({
where: {
promptVariantId: input.variantId,
testScenario: { visible: true },
// Check whether retrieval is PENDING or IN_PROGRESS
retrievalStatus: {
in: ["PENDING", "IN_PROGRESS"],
},
},
}));
return {
evalResults,
promptTokens,
completionTokens,
overallCost,
scenarioCount,
outputCount,
awaitingRetrievals,
};
}),
create: publicProcedure
@@ -105,7 +165,18 @@ export const promptVariantsRouter = createTRPCRouter({
experimentId: input.experimentId,
label: `Prompt Variant ${largestSortIndex + 2}`,
sortIndex: (lastVariant?.sortIndex ?? 0) + 1,
constructFn: lastVariant?.constructFn ?? "",
constructFn:
lastVariant?.constructFn ??
dedent`
prompt = {
model: "gpt-3.5-turbo",
messages: [
{
role: "system",
content: "Return 'Hello, world!'",
}
]
}`,
model: lastVariant?.model ?? "gpt-3.5-turbo",
},
});
@@ -115,6 +186,17 @@ export const promptVariantsRouter = createTRPCRouter({
recordExperimentUpdated(input.experimentId),
]);
const scenarios = await prisma.testScenario.findMany({
where: {
experimentId: input.experimentId,
visible: true,
},
});
for (const scenario of scenarios) {
await generateNewCell(newVariant.id, scenario.id);
}
return newVariant;
}),
@@ -234,6 +316,17 @@ export const promptVariantsRouter = createTRPCRouter({
await prisma.$transaction([hideOldVariants, recordExperimentUpdated(existing.experimentId)]);
const scenarios = await prisma.testScenario.findMany({
where: {
experimentId: newVariant.experimentId,
visible: true,
},
});
for (const scenario of scenarios) {
await generateNewCell(newVariant.id, scenario.id);
}
return { status: "ok" } as const;
}),
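
The stats endpoint now folds Prisma's groupBy aggregates into per-evaluation pass counts instead of reading a dedicated results table. A small sketch of that mapping with hypothetical rows, to make the passCount/totalCount fallbacks visible:

// Hypothetical grouped rows, shaped like prisma.outputEvaluation.groupBy output.
type GroupedRow = { evaluationId: string; _sum: { result: number | null }; _count: { id: number } };

const outputEvals: GroupedRow[] = [
  { evaluationId: "eval-1", _sum: { result: 7 }, _count: { id: 10 } },
];
const evals = [
  { id: "eval-1", label: "Contains greeting" },
  { id: "eval-2", label: "GPT-4 judged" }, // nothing evaluated yet
];

const evalResults = evals.map((evalItem) => {
  const row = outputEvals.find((outputEval) => outputEval.evaluationId === evalItem.id);
  return {
    id: evalItem.id,
    label: evalItem.label,
    passCount: row?._sum?.result ?? 0,
    totalCount: row?._count?.id ?? 1,
  };
});
// => eval-1 reports 7/10; eval-2 falls back to 0/1 until its outputs are evaluated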

View File

@@ -0,0 +1,78 @@
import { z } from "zod";
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import { generateNewCell } from "~/server/utils/generateNewCell";
import { queueLLMRetrievalTask } from "~/server/utils/queueLLMRetrievalTask";
export const scenarioVariantCellsRouter = createTRPCRouter({
get: publicProcedure
.input(
z.object({
scenarioId: z.string(),
variantId: z.string(),
}),
)
.query(async ({ input }) => {
return await prisma.scenarioVariantCell.findUnique({
where: {
promptVariantId_testScenarioId: {
promptVariantId: input.variantId,
testScenarioId: input.scenarioId,
},
},
include: {
modelOutput: {
include: {
outputEvaluation: {
include: {
evaluation: {
select: { label: true },
},
},
},
},
},
},
});
}),
forceRefetch: publicProcedure
.input(
z.object({
scenarioId: z.string(),
variantId: z.string(),
}),
)
.mutation(async ({ input }) => {
const cell = await prisma.scenarioVariantCell.findUnique({
where: {
promptVariantId_testScenarioId: {
promptVariantId: input.variantId,
testScenarioId: input.scenarioId,
},
},
include: {
modelOutput: true,
},
});
if (!cell) {
await generateNewCell(input.variantId, input.scenarioId);
return true;
}
if (cell.modelOutput) {
// TODO: Maybe keep these around to show previous generations?
await prisma.modelOutput.delete({
where: { id: cell.modelOutput.id },
});
}
await prisma.scenarioVariantCell.update({
where: { id: cell.id },
data: { retrievalStatus: "PENDING" },
});
await queueLLMRetrievalTask(cell.id);
return true;
}),
});
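
On the client, this router pairs with the polling described in the commit messages: get is re-queried while the cell is still retrieving, and forceRefetch resets it to PENDING and queues a new retrieval. A minimal sketch of wiring that up with tRPC React hooks; the api helper path appears elsewhere in this diff, but the exact hook usage here is assumed:

// Sketch only; assumes the tRPC React client exported from ~/utils/api.
import { api } from "~/utils/api";

export function useCellWithRefetch(scenarioId: string, variantId: string) {
  const cell = api.scenarioVariantCells.get.useQuery(
    { scenarioId, variantId },
    {
      // Keep polling while the server reports retrieval as PENDING or IN_PROGRESS.
      refetchInterval: (data) =>
        data && ["PENDING", "IN_PROGRESS"].includes(data.retrievalStatus) ? 1000 : false,
    },
  );
  const forceRefetch = api.scenarioVariantCells.forceRefetch.useMutation({
    onSuccess: () => void cell.refetch(),
  });
  return { cell, forceRefetch };
}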

View File

@@ -3,7 +3,8 @@ import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import { autogenerateScenarioValues } from "../autogen";
import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
import { reevaluateAll } from "~/server/utils/evaluations";
import { runAllEvals } from "~/server/utils/evaluations";
import { generateNewCell } from "~/server/utils/generateNewCell";
export const scenariosRouter = createTRPCRouter({
list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
@@ -48,10 +49,21 @@ export const scenariosRouter = createTRPCRouter({
},
});
await prisma.$transaction([
const [scenario] = await prisma.$transaction([
createNewScenarioAction,
recordExperimentUpdated(input.experimentId),
]);
const promptVariants = await prisma.promptVariant.findMany({
where: {
experimentId: input.experimentId,
visible: true,
},
});
for (const variant of promptVariants) {
await generateNewCell(variant.id, scenario.id);
}
}),
hide: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => {
@@ -61,7 +73,7 @@ export const scenariosRouter = createTRPCRouter({
});
// Reevaluate all evaluations now that this scenario is hidden
await reevaluateAll(hiddenScenario.experimentId);
await runAllEvals(hiddenScenario.experimentId);
return hiddenScenario;
}),
@@ -175,6 +187,17 @@ export const scenariosRouter = createTRPCRouter({
},
});
const promptVariants = await prisma.promptVariant.findMany({
where: {
experimentId: newScenario.experimentId,
visible: true,
},
});
for (const variant of promptVariants) {
await generateNewCell(variant.id, newScenario.id);
}
return newScenario;
}),
});

View File

@@ -0,0 +1,47 @@
import { type Prisma } from "@prisma/client";
import { prisma } from "../db";
async function migrateScenarioVariantOutputData() {
// Get all ScenarioVariantCells
const cells = await prisma.scenarioVariantCell.findMany({ include: { modelOutput: true } });
console.log(`Found ${cells.length} records`);
let updatedCount = 0;
// Loop through all cells
for (const cell of cells) {
// Create a new ModelOutput for each ScenarioVariantCell with an existing output
if (cell.output && !cell.modelOutput) {
updatedCount++;
await prisma.modelOutput.create({
data: {
scenarioVariantCellId: cell.id,
inputHash: cell.inputHash || "",
output: cell.output as Prisma.InputJsonValue,
timeToComplete: cell.timeToComplete ?? undefined,
promptTokens: cell.promptTokens,
completionTokens: cell.completionTokens,
createdAt: cell.createdAt,
updatedAt: cell.updatedAt,
},
});
} else if (cell.errorMessage && cell.retrievalStatus === "COMPLETE") {
updatedCount++;
await prisma.scenarioVariantCell.update({
where: { id: cell.id },
data: {
retrievalStatus: "ERROR",
},
});
}
}
console.log("Data migration completed");
console.log(`Updated ${updatedCount} records`);
}
// Execute the function
migrateScenarioVariantOutputData().catch((error) => {
console.error("An error occurred while migrating data: ", error);
process.exit(1);
});

View File

@@ -0,0 +1,31 @@
// Import necessary dependencies
import { quickAddJob, type Helpers, type Task } from "graphile-worker";
import { env } from "~/env.mjs";
// Define the defineTask function
function defineTask<TPayload>(
taskIdentifier: string,
taskHandler: (payload: TPayload, helpers: Helpers) => Promise<void>,
) {
const enqueue = async (payload: TPayload) => {
console.log("Enqueuing task", taskIdentifier, payload);
await quickAddJob({ connectionString: env.DATABASE_URL }, taskIdentifier, payload);
};
const handler = (payload: TPayload, helpers: Helpers) => {
helpers.logger.info(`Running task ${taskIdentifier} with payload: ${JSON.stringify(payload)}`);
return taskHandler(payload, helpers);
};
const task = {
identifier: taskIdentifier,
handler: handler as Task,
};
return {
enqueue,
task,
};
}
export default defineTask;
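
defineTask gives each task file a typed enqueue helper plus the graphile-worker task entry that the worker registers. A minimal usage sketch with a hypothetical task (the task name and payload here are illustrative, not from the repo):

// Hypothetical task built with defineTask; only the defineTask API itself is real.
import defineTask from "./defineTask";

type SendReportJob = { experimentId: string };

export const sendReport = defineTask<SendReportJob>("sendReport", async (payload, helpers) => {
  helpers.logger.info(`Generating report for experiment ${payload.experimentId}`);
  // ...do the actual work here...
});

// Elsewhere, inside an async handler: persist a job in Postgres for the worker to pick up.
// await sendReport.enqueue({ experimentId: "abc123" });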

View File

@@ -0,0 +1,154 @@
import crypto from "crypto";
import { prisma } from "~/server/db";
import defineTask from "./defineTask";
import { type CompletionResponse, getCompletion } from "../utils/getCompletion";
import { type JSONSerializable } from "../types";
import { sleep } from "../utils/sleep";
import { shouldStream } from "../utils/shouldStream";
import { generateChannel } from "~/utils/generateChannel";
import { runEvalsForOutput } from "../utils/evaluations";
import { constructPrompt } from "../utils/constructPrompt";
import { type CompletionCreateParams } from "openai/resources/chat";
import { type Prisma } from "@prisma/client";
const MAX_AUTO_RETRIES = 10;
const MIN_DELAY = 500; // milliseconds
const MAX_DELAY = 15000; // milliseconds
function calculateDelay(numPreviousTries: number): number {
const baseDelay = Math.min(MAX_DELAY, MIN_DELAY * Math.pow(2, numPreviousTries));
const jitter = Math.random() * baseDelay;
return baseDelay + jitter;
}
const getCompletionWithRetries = async (
cellId: string,
payload: JSONSerializable,
channel?: string,
): Promise<CompletionResponse> => {
let modelResponse: CompletionResponse | null = null;
try {
for (let i = 0; i < MAX_AUTO_RETRIES; i++) {
modelResponse = await getCompletion(payload as unknown as CompletionCreateParams, channel);
if (modelResponse.statusCode !== 429 || i === MAX_AUTO_RETRIES - 1) {
return modelResponse;
}
const delay = calculateDelay(i);
await prisma.scenarioVariantCell.update({
where: { id: cellId },
data: {
errorMessage: "Rate limit exceeded",
statusCode: 429,
retryTime: new Date(Date.now() + delay),
},
});
// TODO: Maybe requeue the job so other jobs can run in the future?
await sleep(delay);
}
throw new Error("Max retries limit reached");
} catch (error: unknown) {
return {
statusCode: modelResponse?.statusCode ?? 500,
errorMessage: modelResponse?.errorMessage ?? (error as Error).message,
output: null as unknown as Prisma.InputJsonValue,
timeToComplete: 0,
};
}
};
export type queryLLMJob = {
scenarioVariantCellId: string;
};
export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
const { scenarioVariantCellId } = task;
const cell = await prisma.scenarioVariantCell.findUnique({
where: { id: scenarioVariantCellId },
include: { modelOutput: true },
});
if (!cell) {
return;
}
// If cell is not pending, then some other job is already processing it
if (cell.retrievalStatus !== "PENDING") {
return;
}
await prisma.scenarioVariantCell.update({
where: { id: scenarioVariantCellId },
data: {
retrievalStatus: "IN_PROGRESS",
},
});
const variant = await prisma.promptVariant.findUnique({
where: { id: cell.promptVariantId },
});
if (!variant) {
return;
}
const scenario = await prisma.testScenario.findUnique({
where: { id: cell.testScenarioId },
});
if (!scenario) {
return;
}
const prompt = await constructPrompt(variant, scenario.variableValues);
const streamingEnabled = shouldStream(prompt);
let streamingChannel;
if (streamingEnabled) {
streamingChannel = generateChannel();
// Save streaming channel so that UI can connect to it
await prisma.scenarioVariantCell.update({
where: { id: scenarioVariantCellId },
data: {
streamingChannel,
},
});
}
const modelResponse = await getCompletionWithRetries(
scenarioVariantCellId,
prompt,
streamingChannel,
);
let modelOutput = null;
if (modelResponse.statusCode === 200) {
const inputHash = crypto.createHash("sha256").update(JSON.stringify(prompt)).digest("hex");
modelOutput = await prisma.modelOutput.create({
data: {
scenarioVariantCellId,
inputHash,
output: modelResponse.output,
timeToComplete: modelResponse.timeToComplete,
promptTokens: modelResponse.promptTokens,
completionTokens: modelResponse.completionTokens,
},
});
}
await prisma.scenarioVariantCell.update({
where: { id: scenarioVariantCellId },
data: {
statusCode: modelResponse.statusCode,
errorMessage: modelResponse.errorMessage,
streamingChannel: null,
retrievalStatus: modelOutput ? "COMPLETE" : "ERROR",
modelOutput: {
connect: {
id: modelOutput?.id,
},
},
},
});
if (modelOutput) {
await runEvalsForOutput(variant.experimentId, scenario, modelOutput);
}
});
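
The retry loop uses capped exponential backoff with full jitter: the base doubles per attempt up to MAX_DELAY, and up to one extra base-length of random delay is added so simultaneous 429s spread out. A worked sketch of the schedule calculateDelay produces:

// Worked example of the retry schedule (MIN_DELAY = 500 ms, MAX_DELAY = 15 000 ms,
// jitter uniformly in [0, base)). Mirrors calculateDelay above.
const MIN_DELAY = 500;
const MAX_DELAY = 15_000;
const calculateDelay = (numPreviousTries: number) => {
  const baseDelay = Math.min(MAX_DELAY, MIN_DELAY * Math.pow(2, numPreviousTries));
  return baseDelay + Math.random() * baseDelay;
};

for (let attempt = 0; attempt < 6; attempt++) {
  // attempt 0: 0.5–1 s, attempt 1: 1–2 s, attempt 2: 2–4 s, ... attempt 5: 15–30 s (base capped)
  console.log(attempt, `${Math.round(calculateDelay(attempt))} ms`);
}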

View File

@@ -0,0 +1,40 @@
import { type TaskList, run } from "graphile-worker";
import "dotenv/config";
import { env } from "~/env.mjs";
import { queryLLM } from "./queryLLM.task";
const registeredTasks = [queryLLM];
const taskList = registeredTasks.reduce((acc, task) => {
acc[task.task.identifier] = task.task.handler;
return acc;
}, {} as TaskList);
async function main() {
// Run a worker to execute jobs:
const runner = await run({
connectionString: env.DATABASE_URL,
concurrency: 20,
// Install signal handlers for graceful shutdown on SIGINT, SIGTERM, etc
noHandleSignals: false,
pollInterval: 1000,
// you can set the taskList or taskDirectory but not both
taskList,
// or:
// taskDirectory: `${__dirname}/tasks`,
});
// Immediately await (or otherwise handle) the resulting promise, to avoid
// "unhandled rejection" errors causing a process crash in the event of
// something going wrong.
await runner.promise;
// If the worker exits (whether through fatal error or otherwise), the above
// promise will resolve/reject.
}
main().catch((err) => {
console.error("Unhandled error occurred running worker: ", err);
process.exit(1);
});

View File

@@ -9,7 +9,7 @@ export async function constructPrompt(
scenario: TestScenario["variableValues"],
): Promise<JSONSerializable> {
const code = `
const scenario = ${JSON.stringify(scenario, null, 2)};
const scenario = ${JSON.stringify(scenario ?? {}, null, 2)};
let prompt
${variant.constructFn}

View File

@@ -1,31 +0,0 @@
import { type Evaluation, type ModelOutput, type TestScenario } from "@prisma/client";
import { type ChatCompletion } from "openai/resources/chat";
import { type VariableMap, fillTemplate } from "./fillTemplate";
export const evaluateOutput = (
modelOutput: ModelOutput,
scenario: TestScenario,
evaluation: Evaluation,
): boolean => {
const output = modelOutput.output as unknown as ChatCompletion;
const message = output?.choices?.[0]?.message;
if (!message) return false;
const stringifiedMessage = message.content ?? JSON.stringify(message.function_call);
const matchRegex = fillTemplate(evaluation.matchString, scenario.variableValues as VariableMap);
let match;
switch (evaluation.matchType) {
case "CONTAINS":
match = stringifiedMessage.match(matchRegex) !== null;
break;
case "DOES_NOT_CONTAIN":
match = stringifiedMessage.match(matchRegex) === null;
break;
}
return match;
};

View File

@@ -1,103 +1,79 @@
import { type Evaluation } from "@prisma/client";
import { type ModelOutput, type Evaluation } from "@prisma/client";
import { prisma } from "../db";
import { evaluateOutput } from "./evaluateOutput";
import { runOneEval } from "./runOneEval";
import { type Scenario } from "~/components/OutputsTable/types";
export const reevaluateVariant = async (variantId: string) => {
const variant = await prisma.promptVariant.findUnique({
where: { id: variantId },
});
if (!variant) return;
const evaluations = await prisma.evaluation.findMany({
where: { experimentId: variant.experimentId },
});
const modelOutputs = await prisma.modelOutput.findMany({
const saveResult = async (evaluation: Evaluation, scenario: Scenario, modelOutput: ModelOutput) => {
const result = await runOneEval(evaluation, scenario, modelOutput);
return await prisma.outputEvaluation.upsert({
where: {
promptVariantId: variantId,
statusCode: { notIn: [429] },
testScenario: { visible: true },
modelOutputId_evaluationId: {
modelOutputId: modelOutput.id,
evaluationId: evaluation.id,
},
},
create: {
modelOutputId: modelOutput.id,
evaluationId: evaluation.id,
...result,
},
update: {
...result,
},
include: { testScenario: true },
});
await Promise.all(
evaluations.map(async (evaluation) => {
const passCount = modelOutputs.filter((output) =>
evaluateOutput(output, output.testScenario, evaluation),
).length;
const failCount = modelOutputs.length - passCount;
await prisma.evaluationResult.upsert({
where: {
evaluationId_promptVariantId: {
evaluationId: evaluation.id,
promptVariantId: variantId,
},
},
create: {
evaluationId: evaluation.id,
promptVariantId: variantId,
passCount,
failCount,
},
update: {
passCount,
failCount,
},
});
}),
);
};
export const reevaluateEvaluation = async (evaluation: Evaluation) => {
const variants = await prisma.promptVariant.findMany({
where: { experimentId: evaluation.experimentId, visible: true },
});
const modelOutputs = await prisma.modelOutput.findMany({
where: {
promptVariantId: { in: variants.map((v) => v.id) },
testScenario: { visible: true },
statusCode: { notIn: [429] },
},
include: { testScenario: true },
});
await Promise.all(
variants.map(async (variant) => {
const outputs = modelOutputs.filter((output) => output.promptVariantId === variant.id);
const passCount = outputs.filter((output) =>
evaluateOutput(output, output.testScenario, evaluation),
).length;
const failCount = outputs.length - passCount;
await prisma.evaluationResult.upsert({
where: {
evaluationId_promptVariantId: {
evaluationId: evaluation.id,
promptVariantId: variant.id,
},
},
create: {
evaluationId: evaluation.id,
promptVariantId: variant.id,
passCount,
failCount,
},
update: {
passCount,
failCount,
},
});
}),
);
};
export const reevaluateAll = async (experimentId: string) => {
export const runEvalsForOutput = async (
experimentId: string,
scenario: Scenario,
modelOutput: ModelOutput,
) => {
const evaluations = await prisma.evaluation.findMany({
where: { experimentId },
});
await Promise.all(evaluations.map(reevaluateEvaluation));
await Promise.all(
evaluations.map(async (evaluation) => await saveResult(evaluation, scenario, modelOutput)),
);
};
export const runAllEvals = async (experimentId: string) => {
const outputs = await prisma.modelOutput.findMany({
where: {
scenarioVariantCell: {
promptVariant: {
experimentId,
visible: true,
},
testScenario: {
visible: true,
},
},
},
include: {
scenarioVariantCell: {
include: {
testScenario: true,
},
},
outputEvaluation: true,
},
});
const evals = await prisma.evaluation.findMany({
where: { experimentId },
});
await Promise.all(
outputs.map(async (output) => {
const unrunEvals = evals.filter(
(evaluation) => !output.outputEvaluation.find((e) => e.evaluationId === evaluation.id),
);
await Promise.all(
unrunEvals.map(async (evaluation) => {
await saveResult(evaluation, output.scenarioVariantCell.testScenario, output);
}),
);
}),
);
};
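
Because saveResult stores one row per (modelOutput, evaluation) pair, runAllEvals only executes evaluations that have no stored result for a given output. That is what makes re-running after an evaluation edit cheap: only the rows deleted in the update mutation get recomputed. A small sketch of the filter with hypothetical data:

// Hypothetical data illustrating the "unrun evals" filter inside runAllEvals.
const evals = [{ id: "contains-greeting" }, { id: "gpt4-judge" }];
const output = {
  id: "output-1",
  outputEvaluation: [{ evaluationId: "contains-greeting", result: 1 }],
};

const unrunEvals = evals.filter(
  (evaluation) => !output.outputEvaluation.find((e) => e.evaluationId === evaluation.id),
);
console.log(unrunEvals.map((e) => e.id)); // => ["gpt4-judge"]; the cached result is reused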

View File

@@ -2,6 +2,16 @@ import { type JSONSerializable } from "../types";
export type VariableMap = Record<string, string>;
// Escape quotes to match the way we encode JSON
export function escapeQuotes(str: string) {
return str.replace(/(\\")|"/g, (match, p1) => (p1 ? match : '\\"'));
}
// Escape regex special characters
export function escapeRegExp(str: string) {
return str.replace(/[.*+\-?^${}()|[\]\\]/g, "\\$&"); // $& means the whole matched string
}
export function fillTemplate(template: string, variables: VariableMap): string {
return template.replace(/{{\s*(\w+)\s*}}/g, (_, key: string) => variables[key] || "");
}
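
These three helpers are composed further down in runOneEval (escapeQuotes, then fillTemplate, then escapeRegExp). A quick sketch of each step in isolation; the values are illustrative, and the import path matches the one used elsewhere in this diff:

// Sketch of the helpers on illustrative inputs.
import { escapeQuotes, escapeRegExp, fillTemplate, type VariableMap } from "./fillTemplate";

const variables: VariableMap = { animal: "cat (tabby)" };

const quoted = escapeQuotes('say "hi"');                     // say \"hi\"  (matches JSON-encoded output)
const filled = fillTemplate("Hello {{animal}}!", variables); // Hello cat (tabby)!
const literal = escapeRegExp("cat (tabby)?");                // cat \(tabby\)\?  (metacharacters neutralised)

console.log(quoted, filled, literal);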

View File

@@ -0,0 +1,76 @@
import crypto from "crypto";
import { type Prisma } from "@prisma/client";
import { prisma } from "../db";
import { queueLLMRetrievalTask } from "./queueLLMRetrievalTask";
import { constructPrompt } from "./constructPrompt";
export const generateNewCell = async (variantId: string, scenarioId: string) => {
const variant = await prisma.promptVariant.findUnique({
where: {
id: variantId,
},
});
const scenario = await prisma.testScenario.findUnique({
where: {
id: scenarioId,
},
});
if (!variant || !scenario) return null;
const prompt = await constructPrompt(variant, scenario.variableValues);
const inputHash = crypto.createHash("sha256").update(JSON.stringify(prompt)).digest("hex");
let cell = await prisma.scenarioVariantCell.findUnique({
where: {
promptVariantId_testScenarioId: {
promptVariantId: variantId,
testScenarioId: scenarioId,
},
},
include: {
modelOutput: true,
},
});
if (cell) return cell;
cell = await prisma.scenarioVariantCell.create({
data: {
promptVariantId: variantId,
testScenarioId: scenarioId,
},
include: {
modelOutput: true,
},
});
const matchingModelOutput = await prisma.modelOutput.findFirst({
where: {
inputHash,
},
});
let newModelOutput;
if (matchingModelOutput) {
newModelOutput = await prisma.modelOutput.create({
data: {
scenarioVariantCellId: cell.id,
inputHash,
output: matchingModelOutput.output as Prisma.InputJsonValue,
timeToComplete: matchingModelOutput.timeToComplete,
promptTokens: matchingModelOutput.promptTokens,
completionTokens: matchingModelOutput.completionTokens,
createdAt: matchingModelOutput.createdAt,
updatedAt: matchingModelOutput.updatedAt,
},
});
} else {
cell = await queueLLMRetrievalTask(cell.id);
}
return { ...cell, modelOutput: newModelOutput };
};
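
generateNewCell keys the output cache on a SHA-256 hash of the fully constructed prompt, so cells whose variant/scenario combinations produce byte-identical prompts reuse an existing ModelOutput instead of triggering another LLM call. A minimal sketch of the hashing step on its own:

// Sketch of the inputHash computation used to look up cached ModelOutputs.
import crypto from "crypto";

const prompt = {
  model: "gpt-3.5-turbo-0613",
  messages: [{ role: "system", content: "Return 'Ready to go!'" }],
};

// Prompts built by the same constructFn serialise identically, so equal prompts
// always map to the same hash.
const inputHash = crypto.createHash("sha256").update(JSON.stringify(prompt)).digest("hex");
console.log(inputHash);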

View File

@@ -1,5 +1,5 @@
/* eslint-disable @typescript-eslint/no-unsafe-call */
import { isObject } from "lodash";
import { isObject } from "lodash-es";
import { Prisma } from "@prisma/client";
import { streamChatCompletion } from "./openai";
import { wsConnection } from "~/utils/wsConnection";
@@ -9,7 +9,7 @@ import { env } from "~/env.mjs";
import { countOpenAIChatTokens } from "~/utils/countTokens";
import { rateLimitErrorMessage } from "~/sharedStrings";
type CompletionResponse = {
export type CompletionResponse = {
output: Prisma.InputJsonValue | typeof Prisma.JsonNull;
statusCode: number;
errorMessage: string | null;

View File

@@ -1,4 +1,4 @@
import { omit } from "lodash";
import { omit } from "lodash-es";
import { env } from "~/env.mjs";
import OpenAI from "openai";

View File

@@ -0,0 +1,22 @@
import { prisma } from "../db";
import { queryLLM } from "../tasks/queryLLM.task";
export const queueLLMRetrievalTask = async (cellId: string) => {
const updatedCell = await prisma.scenarioVariantCell.update({
where: {
id: cellId,
},
data: {
retrievalStatus: "PENDING",
errorMessage: null,
},
include: {
modelOutput: true,
},
});
// @ts-expect-error we aren't passing the helpers but that's ok
void queryLLM.task.handler({ scenarioVariantCellId: cellId }, { logger: console });
return updatedCell;
};
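
Note that this helper marks the cell PENDING and then calls the task handler directly in-process, with console standing in for the worker's logger, presumably to keep local development simple. A sketch of the alternative that goes through the queue instead, using only the enqueue helper that defineTask already returns:

// Sketch: queue the retrieval through graphile-worker instead of running it inline.
import { prisma } from "../db";
import { queryLLM } from "../tasks/queryLLM.task";

export const queueLLMRetrievalTaskViaWorker = async (cellId: string) => {
  const updatedCell = await prisma.scenarioVariantCell.update({
    where: { id: cellId },
    data: { retrievalStatus: "PENDING", errorMessage: null },
    include: { modelOutput: true },
  });
  // Persist the job in Postgres; the worker started in the worker entrypoint picks it up.
  await queryLLM.enqueue({ scenarioVariantCellId: cellId });
  return updatedCell;
};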

View File

@@ -0,0 +1,95 @@
import { type Evaluation, type ModelOutput, type TestScenario } from "@prisma/client";
import { type ChatCompletion } from "openai/resources/chat";
import { type VariableMap, fillTemplate, escapeRegExp, escapeQuotes } from "./fillTemplate";
import { openai } from "./openai";
import dedent from "dedent";
export const runGpt4Eval = async (
evaluation: Evaluation,
scenario: TestScenario,
message: ChatCompletion.Choice.Message,
): Promise<{ result: number; details: string }> => {
const output = await openai.chat.completions.create({
model: "gpt-4-0613",
messages: [
{
role: "system",
content: dedent`
You are a highly intelligent AI model and have been tasked with evaluating the quality of a simpler model. Your objective is to determine whether the simpler model has produced a successful and correct output. You should return "true" if the output was successful and "false" if it was not. Pay more attention to the semantics of the output than the formatting. Success is defined in the following terms:
---
${evaluation.value}
`,
},
{
role: "user",
content: `Scenario:\n---\n${JSON.stringify(scenario.variableValues, null, 2)}`,
},
{
role: "user",
content: `The full output of the simpler message:\n---\n${JSON.stringify(
message.content ?? message.function_call,
null,
2,
)}`,
},
],
function_call: {
name: "report_success",
},
functions: [
{
name: "report_success",
parameters: {
type: "object",
required: ["thoughts", "success"],
properties: {
thoughts: {
type: "string",
description: "Explain your reasoning for considering this a pass or fail",
},
success: {
type: "boolean",
description:
"Whether the simpler model successfully completed the task for this scenario",
},
},
},
},
],
});
try {
const out = JSON.parse(output.choices[0]?.message?.function_call?.arguments ?? "");
return { result: out.success ? 1 : 0, details: out.thoughts ?? JSON.stringify(out) };
} catch (e) {
console.error(e);
return { result: 0, details: "Error parsing GPT-4 output" };
}
};
export const runOneEval = async (
evaluation: Evaluation,
scenario: TestScenario,
modelOutput: ModelOutput,
): Promise<{ result: number; details?: string }> => {
const output = modelOutput.output as unknown as ChatCompletion;
const message = output?.choices?.[0]?.message;
if (!message) return { result: 0 };
const stringifiedMessage = message.content ?? JSON.stringify(message.function_call);
const matchRegex = escapeRegExp(
fillTemplate(escapeQuotes(evaluation.value), scenario.variableValues as VariableMap),
);
switch (evaluation.evalType) {
case "CONTAINS":
return { result: stringifiedMessage.match(matchRegex) !== null ? 1 : 0 };
case "DOES_NOT_CONTAIN":
return { result: stringifiedMessage.match(matchRegex) === null ? 1 : 0 };
case "GPT4_EVAL":
return await runGpt4Eval(evaluation, scenario, message);
}
};
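
The GPT-4 judge is forced to call report_success, so the useful part of its response is the function_call arguments string. A sketch of the payload shape that the JSON.parse step above expects, with hypothetical content:

// Hypothetical arguments returned by the gpt-4-0613 judge's report_success call;
// this is the string runGpt4Eval parses into { result, details }.
const args = JSON.stringify({
  thoughts: "The output restates the scenario text as instructed, so this passes.",
  success: true,
});

const out = JSON.parse(args) as { thoughts?: string; success?: boolean };
const evalResult = { result: out.success ? 1 : 0, details: out.thoughts ?? JSON.stringify(out) };
console.log(evalResult); // => { result: 1, details: "The output restates ..." }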

View File

@@ -0,0 +1,7 @@
import { isObject } from "lodash-es";
import { type JSONSerializable } from "../types";
export const shouldStream = (config: JSONSerializable): boolean => {
const shouldStream = isObject(config) && "stream" in config && config.stream === true;
return shouldStream;
};

View File

@@ -0,0 +1 @@
export const sleep = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms));

View File

@@ -33,6 +33,7 @@ export const createVariantEditorSlice: SliceCreator<SharedVariantEditorSlice> =
monaco.languages.typescript.typescriptDefaults.setCompilerOptions({
allowNonTsExtensions: true,
strictNullChecks: true,
lib: ["esnext"],
});
@@ -84,7 +85,7 @@ export const createVariantEditorSlice: SliceCreator<SharedVariantEditorSlice> =
)} as const;
type Scenario = typeof scenarios[number];
declare var scenario: Scenario | null;
declare var scenario: Scenario | { [key: string]: string };
`;
const scenariosModel = monaco.editor.getModel(monaco.Uri.parse("file:///scenarios.ts"));

View File

@@ -1,5 +1,5 @@
import { useRouter } from "next/router";
import { useCallback, useEffect, useState } from "react";
import { type RefObject, useCallback, useEffect, useRef, useState } from "react";
import { api } from "~/utils/api";
export const useExperiment = () => {
@@ -49,3 +49,43 @@ export const useModifierKeyLabel = () => {
}, []);
return label;
};
interface Dimensions {
left: number;
top: number;
right: number;
bottom: number;
width: number;
height: number;
x: number;
y: number;
}
// get dimensions of an element
export const useElementDimensions = (): [RefObject<HTMLElement>, Dimensions | undefined] => {
const ref = useRef<HTMLElement>(null);
const [dimensions, setDimensions] = useState<Dimensions | undefined>();
useEffect(() => {
if (ref.current) {
const observer = new ResizeObserver((entries) => {
entries.forEach((entry) => {
setDimensions(entry.contentRect);
});
});
const observedRef = ref.current;
observer.observe(observedRef);
// Cleanup the observer on component unmount
return () => {
if (observedRef) {
observer.unobserve(observedRef);
}
};
}
}, []);
return [ref, dimensions];
};
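
A minimal usage sketch for the new useElementDimensions hook; the import path and the cast to a div ref are assumptions, since the hook's ref is typed as RefObject<HTMLElement> rather than being generic over the element type:

// Sketch only; the hooks module path is assumed.
import { type RefObject } from "react";
import { useElementDimensions } from "~/utils/hooks";

export function WidthLabel() {
  const [ref, dimensions] = useElementDimensions();
  return (
    <div ref={ref as RefObject<HTMLDivElement>}>
      {dimensions ? `${Math.round(dimensions.width)}px wide` : "measuring…"}
    </div>
  );
}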

View File

@@ -5,7 +5,7 @@ import { env } from "~/env.mjs";
const url = env.NEXT_PUBLIC_SOCKET_URL;
export default function useSocket(channel?: string) {
export default function useSocket(channel?: string | null) {
const socketRef = useRef<Socket>();
const [message, setMessage] = useState<ChatCompletion | null>(null);