Compare commits

39 Commits

Author SHA1 Message Date
Kyle Corbitt
60765e51ac Remove model from promptVariant and add cost
Storing the model on promptVariant is problematic because it isn't always in sync with the actual prompt definition. I'm removing it for now to see if we can get away with that -- might have to add it back in later if this causes trouble.

Added `cost` to modelOutput as well so we can cache that, which is important given that the cost calculations won't be the same between different API providers.
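To illustrate why a cached `cost` is useful, here is a minimal sketch of a per-provider cost calculation. The provider names, the pricing numbers, and the `calculateCost` helper are all hypothetical and are not taken from the OpenPipe codebase. The only grounded detail is that `cost` is stored as a nullable number on the model output.

```typescript
// Hypothetical pricing table in USD per 1,000 tokens.
// The numbers are placeholders for illustration, not real prices.
type Pricing = { promptPer1k: number; completionPer1k: number };

const pricingByProvider: Record<string, Pricing> = {
  providerA: { promptPer1k: 0.002, completionPer1k: 0.004 },
  providerB: { promptPer1k: 0.01, completionPer1k: 0.03 },
};

// Returns the cost of one completion, or null when the provider is unknown,
// mirroring a nullable `cost` column.
function calculateCost(
  provider: string,
  promptTokens: number,
  completionTokens: number,
): number | null {
  const pricing = pricingByProvider[provider];
  if (!pricing) return null;
  return (
    (promptTokens / 1000) * pricing.promptPer1k +
    (completionTokens / 1000) * pricing.completionPer1k
  );
}
```

Because the same token counts yield different amounts under different pricing tables, storing the computed value next to the output avoids having to know later which provider's rates applied.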
2023-07-19 16:20:53 -07:00
arcticfly
4c97b9f147 Refine prompt (#63)
* Remove unused ScenarioVariantCell fields

* Refine deriveNewConstructFn

* Fix prettier

* Remove migration script

* Add refine modal

* Fix prettier

* Fix diff checker overflow

* Decrease diff height
2023-07-19 15:31:40 -07:00
arcticfly
58892d8b63 Remove unused fields, refine model translation (#62)
* Remove unused ScenarioVariantCell fields

* Refine deriveNewConstructFn

* Fix prettier
2023-07-19 13:59:11 -07:00
Kyle Corbitt
4fa2dffbcb styling tweaks for SelectModelModal 2023-07-19 07:17:56 -07:00
Kyle Corbitt
654f8c7cf2 Merge pull request #61 from OpenPipe/experiment-page
More visual tweaks
2023-07-19 06:56:58 -07:00
Kyle Corbitt
d02482468d more visual tweaks 2023-07-19 06:54:07 -07:00
Kyle Corbitt
5c6ed22f1d Merge pull request #60 from OpenPipe/experiment-page
experiment page visual tweaks
2023-07-18 22:26:05 -07:00
Kyle Corbitt
2cb623f332 experiment page visual tweaks 2023-07-18 22:22:58 -07:00
Kyle Corbitt
1c1cefe286 Merge pull request #59 from OpenPipe/auth
User accounts
2023-07-18 21:21:46 -07:00
Kyle Corbitt
b4aa95edca sidebar mobile styles 2023-07-18 21:19:06 -07:00
Kyle Corbitt
1dcdba04a6 User accounts
Allows for the creation of user accounts. A few notes on the specifics:

 - Experiments are the main access control objects. If you can view an experiment, you can view all its prompts/scenarios/evals. If you can edit it, you can edit or delete all of those as well.
 - Experiments are owned by Organizations in the database. Organizations can have multiple members and members can have roles of ADMIN, MEMBER or VIEWER.
 - Organizations can either be "personal" or general. Each user has a "personal" organization created as soon as they try to create an experiment. There's currently no UI support for creating general orgs or adding users to them; they're just in the database to future-proof all the ACL logic.
 - You can require that a user is signed-in to see a route using the `protectedProcedure` helper. When you use `protectedProcedure`, you also have to call `ctx.markAccessControlRun()` (or delegate to a function that does it for you; see accessControl.ts). This is to remind us to actually check for access control when we define a new endpoint.
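The "remember to check access control" idea in the last note can be sketched without any framework. Everything below is a hypothetical, self-contained illustration: `runProtected` stands in for `protectedProcedure`, `requireCanViewExperiment` is an assumed helper in the spirit of accessControl.ts, and the in-memory role table is invented for the example. It is not the project's actual implementation.

```typescript
type Role = "ADMIN" | "MEMBER" | "VIEWER";

type Ctx = {
  userId: string;
  markAccessControlRun: () => void;
};

// Hypothetical in-memory membership table: experimentId -> userId -> role.
const experimentRoles: Record<string, Record<string, Role>> = {
  "exp-1": { alice: "ADMIN", bob: "VIEWER" },
};

// Assumed helper: performs the check and marks access control as run
// on behalf of the caller.
function requireCanViewExperiment(experimentId: string, ctx: Ctx): void {
  ctx.markAccessControlRun();
  const role = experimentRoles[experimentId]?.[ctx.userId];
  if (role === undefined) throw new Error("FORBIDDEN");
}

// Stand-in for `protectedProcedure`: requires a signed-in user and fails
// loudly if the handler never marked access control as run.
function runProtected<T>(userId: string | null, handler: (ctx: Ctx) => T): T {
  if (userId === null) throw new Error("UNAUTHORIZED");
  const state = { accessControlRun: false };
  const ctx: Ctx = {
    userId,
    markAccessControlRun: () => {
      state.accessControlRun = true;
    },
  };
  const result = handler(ctx);
  if (!state.accessControlRun) {
    throw new Error("Endpoint returned without running access control");
  }
  return result;
}
```

A handler that calls `requireCanViewExperiment("exp-1", ctx)` passes for a member of that experiment, while a handler that forgets the check throws even if its own logic succeeded. That failure is the reminder the commit message describes.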
2023-07-18 21:19:03 -07:00
arcticfly
e0e64c4207 Allow user to create a version of their current prompt with a new model (#58)
* Add dropdown header for model switching

* Allow variant duplication

* Fix prettier

* Use env variable to restrict prisma logs

* Fix env.mjs

* Remove unnecessary scroll bar from function call output

* Properly record when 404 error occurs in queryLLM task

* Add SelectedModelInfo in SelectModelModal

* Add react-select

* Calculate new prompt after switching model

* Send newly selected model with creation request

* Get new prompt construction function back from GPT-4

* Fix prettier

* Fix prettier
2023-07-18 18:24:04 -07:00
arcticfly
fa5b1ab1c5 Allow user to duplicate prompt (#57)
* Add dropdown header for model switching

* Allow variant duplication

* Fix prettier
2023-07-18 13:49:33 -07:00
David Corbitt
999a4c08fa Fix lint and prettier 2023-07-18 11:11:20 -07:00
arcticfly
374d0237ee Escape characters in Regex evaluations, minor UI fixes (#56)
* Fix ScenariosHeader stickiness

* Move meta tag from _app.tsx to _document.tsx

* Show spinner when saving variant

* Escape quotes and regex in evaluations
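For context on the last item, a common way to treat a user-supplied match string as literal text is to escape regex metacharacters before building the pattern. This is a minimal sketch under that assumption. The function names are hypothetical, and the commit's actual implementation may differ.

```typescript
// Escape every character that has special meaning in a regular expression
// so the input is matched literally.
function escapeRegExp(literal: string): string {
  return literal.replace(/[.*+?^${}()|[\]\\]/g, "\\$&");
}

// Hypothetical "contains" check that matches the string literally.
function outputContains(output: string, matchString: string): boolean {
  return new RegExp(escapeRegExp(matchString)).test(output);
}
```

Without the escaping step, a match string such as `a.c` would also match `abc`, and a string containing an unbalanced `(` would throw when the pattern is constructed.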
2023-07-18 11:07:04 -07:00
David Corbitt
b1f873623d Invalidate prompt stats after cell refetch 2023-07-18 09:45:11 -07:00
arcticfly
4131aa67d0 Continue polling VariantStats while LLM retrieval in progress, minor UI fixes (#54)
* Prevent zoom in on iOS

* Expand function return code background to fill cell

* Keep OutputStats on far right of cells

* Continue polling prompt stats while cells are retrieving from LLM

* Add comment to _document.tsx

* Fix prettier
2023-07-17 18:04:38 -07:00
Kyle Corbitt
8e7a6d3ae2 Merge pull request #55 from OpenPipe/more-eval
Add GPT4 Evals
2023-07-17 18:01:47 -07:00
Kyle Corbitt
7d41e94ca2 cache eval outputs and add gpt4 eval 2023-07-17 17:55:36 -07:00
Kyle Corbitt
011b12abb9 cache output evals 2023-07-17 17:52:30 -07:00
Kyle Corbitt
1ba18015bc Merge pull request #53 from OpenPipe/more-eval
Fix seeds and update eval field names
2023-07-17 14:26:29 -07:00
Kyle Corbitt
54369dba54 Fix seeds and update eval field names 2023-07-17 14:14:20 -07:00
arcticfly
6b84a59372 Properly catch completion errors (#51) 2023-07-17 10:50:25 -07:00
Kyle Corbitt
8db8aeacd3 Replace function chrome with comment
Use a block comment to explain the expected prompt formatting instead of function chrome. The advantage here is that once a user builds a mental model of how OpenPipe works they can just delete the comment, instead of the function chrome sitting around and taking up space in the UI forever.
2023-07-17 10:30:22 -07:00
Kyle Corbitt
64bd71e370 Merge pull request #50 from OpenPipe/remove-default
remove the default value for PromptVariant.model
2023-07-14 17:55:38 -07:00
Kyle Corbitt
ca21a7af06 Run checks on main
This will (1) make sure that anything we push directly passes CI, and also (2) cache the pnpm store on the main branch, which will make it available to PR runs as well and hopefully speed up CI a bit (see https://stackoverflow.com/a/75250061).
2023-07-14 17:49:20 -07:00
Kyle Corbitt
3b99b7bd2b remove the default value for PromptVariant.model
We should be explicit about setting the appropriate model so it always matches the constructFn.
2023-07-14 17:43:52 -07:00
Kyle Corbitt
0c3bdbe4f2 Merge pull request #49 from OpenPipe/save-button
Make save button disappear on save
2023-07-14 17:39:06 -07:00
Kyle Corbitt
74c201d3a8 Make save button disappear on save
Fixes a bug where often the "Save" button wouldn't disappear as expected the first time you clicked it.
2023-07-14 17:35:57 -07:00
David Corbitt
ab9c721d09 Revert change to scenarios header 2023-07-14 17:51:12 -06:00
David Corbitt
0a2578a1d8 Update scenarios header negative margin 2023-07-14 17:41:51 -06:00
David Corbitt
1bebaff386 Merge branch 'main' of github.com:corbt/prompt-lab 2023-07-14 16:55:12 -06:00
David Corbitt
3bf5eaf4a2 Properly extract scenario id in new experiment creation 2023-07-14 16:55:09 -06:00
Kyle Corbitt
ded97f8bb9 fix lockfile 2023-07-14 15:55:01 -07:00
Kyle Corbitt
26ee8698be Make it so you can't delete the last prompt or scenario
No reason for an experiment to have 0 prompts or 0 scenarios and it makes the UI look bad.
2023-07-14 15:49:42 -07:00
arcticfly
b98eb9b729 Trigger llm output retrieval on server (#39)
* Rename tables, add graphile workers, update types

* Add dev:worker command

* Update pnpm-lock.yaml

* Remove sentry config import from worker.ts

* Stop generating new cells in cell router get query

* Generate new cells for new scenarios, variants, and experiments

* Remove most error throwing from queryLLM.task.ts

* Remove promptVariantId and testScenarioId from ModelOutput

* Remove duplicate index from ModelOutput

* Move inputHash from cell to output

* Add TODO

* Add todo

* Show cost and time for each cell

* Always show output stats if there is output

* Trigger LLM outputs when scenario variables are updated

* Add newlines to ends of files

* Add another newline

* Cascade ModelOutput deletion

* Fix linting and prettier

* Return instead of throwing for non-pending cell

* Remove pnpm dev:worker from pnpm:dev

* Update pnpm-lock.yaml
2023-07-14 16:38:46 -06:00
Kyle Corbitt
032c07ec65 Merge pull request #45 from OpenPipe/node-version
warn folks if they use a lower node version
2023-07-14 15:03:49 -07:00
Kyle Corbitt
80c0d13bb9 warn folks if they use a lower node version 2023-07-14 14:59:33 -07:00
Kyle Corbitt
f7c94be3f6 Merge pull request #44 from OpenPipe/strip-types
Strip types from prompt variants
2023-07-14 14:07:07 -07:00
95 changed files with 4134 additions and 2646 deletions

@@ -18,3 +18,11 @@ DATABASE_URL="postgresql://postgres:postgres@localhost:5432/openpipe?schema=publ
OPENAI_API_KEY=""
NEXT_PUBLIC_SOCKET_URL="http://localhost:3318"
# Next Auth
NEXTAUTH_SECRET="your_secret"
NEXTAUTH_URL="http://localhost:3000"
# Next Auth Github Provider
GITHUB_CLIENT_ID="your_client_id"
GITHUB_CLIENT_SECRET="your_secret"

@@ -3,6 +3,8 @@ name: CI checks
on:
pull_request:
branches: [main]
push:
branches: [main]
jobs:
run-checks:

.tool-versions Normal file

@@ -0,0 +1 @@
nodejs 20.2.0

@@ -11,6 +11,7 @@ declare module "nextjs-routes" {
} from "next";
export type Route =
| StaticRoute<"/account/signin">
| DynamicRoute<"/api/auth/[...nextauth]", { "nextauth": string[] }>
| DynamicRoute<"/api/trpc/[trpc]", { "trpc": string }>
| DynamicRoute<"/experiments/[id]", { "id": string }>

@@ -4,11 +4,18 @@
OpenPipe is a flexible playground for comparing and optimizing LLM prompts. It lets you quickly generate, test and compare candidate prompts with realistic sample data.
**Live Demo:** https://openpipe.ai
## Sample Experiments
These are simple experiments users have created that show how OpenPipe works.
- [Country Capitals](https://openpipe.ai/experiments/11111111-1111-1111-1111-111111111111)
- [Reddit User Needs](https://openpipe.ai/experiments/22222222-2222-2222-2222-222222222222)
- [OpenAI Function Calls](https://openpipe.ai/experiments/2ebbdcb3-ed51-456e-87dc-91f72eaf3e2b)
- [Activity Classification](https://openpipe.ai/experiments/3950940f-ab6b-4b74-841d-7e9dbc4e4ff8)
<img src="https://github.com/openpipe/openpipe/assets/176426/fc7624c6-5b65-4d4d-82b7-4a816f3e5678" alt="demo" height="400px">
Currently there's a public playground available at [https://openpipe.ai/](https://openpipe.ai/), but the recommended approach is to [run locally](#running-locally).
You can use our hosted version of OpenPipe at [https://openpipe.ai]. You can also clone this repository and [run it locally](#running-locally).
## High-Level Features
@@ -47,5 +54,6 @@ OpenPipe currently supports GPT-3.5 and GPT-4. Wider model support is planned.
5. Install the dependencies: `cd openpipe && pnpm install`
6. Create a `.env` file (`cp .env.example .env`) and enter your `OPENAI_API_KEY`.
7. Update `DATABASE_URL` if necessary to point to your Postgres instance and run `pnpm prisma db push` to create the database.
8. Start the app: `pnpm dev`.
9. Navigate to [http://localhost:3000](http://localhost:3000)
8. Create a [GitHub OAuth App](https://docs.github.com/en/apps/oauth-apps/building-oauth-apps/creating-an-oauth-app) and update the `GITHUB_CLIENT_ID` and `GITHUB_CLIENT_SECRET` values. (Note: a PR to make auth optional when running locally would be a great contribution!)
9. Start the app: `pnpm dev`.
10. Navigate to [http://localhost:3000](http://localhost:3000)

@@ -3,15 +3,22 @@
"type": "module",
"version": "0.1.0",
"license": "Apache-2.0",
"engines": {
"node": ">=20.0.0",
"pnpm": ">=8.6.1"
},
"scripts": {
"build": "next build",
"dev:next": "next dev",
"dev:wss": "pnpm tsx --watch src/wss-server.ts",
"dev:worker": "NODE_ENV='development' pnpm tsx --watch src/server/tasks/worker.ts",
"dev": "concurrently --kill-others 'pnpm dev:next' 'pnpm dev:wss'",
"postinstall": "prisma generate",
"lint": "next lint",
"start": "next start",
"codegen": "tsx src/codegen/export-openai-types.ts"
"codegen": "tsx src/codegen/export-openai-types.ts",
"seed": "tsx prisma/seed.ts",
"check": "concurrently 'pnpm lint' 'pnpm tsc' 'pnpm prettier . --check'"
},
"dependencies": {
"@babel/preset-typescript": "^7.22.5",
@@ -21,6 +28,7 @@
"@emotion/react": "^11.11.1",
"@emotion/server": "^11.11.0",
"@emotion/styled": "^11.11.0",
"@fontsource/inconsolata": "^5.0.5",
"@monaco-editor/loader": "^1.3.3",
"@next-auth/prisma-adapter": "^1.0.5",
"@prisma/client": "^4.14.0",
@@ -40,10 +48,11 @@
"express": "^4.18.2",
"framer-motion": "^10.12.17",
"gpt-tokens": "^1.0.10",
"graphile-worker": "^0.13.0",
"immer": "^10.0.2",
"isolated-vm": "^4.5.0",
"json-stringify-pretty-compact": "^4.0.0",
"lodash": "^4.17.21",
"lodash-es": "^4.17.21",
"next": "^13.4.2",
"next-auth": "^4.22.1",
"nextjs-routes": "^2.0.1",
@@ -51,9 +60,12 @@
"pluralize": "^8.0.0",
"posthog-js": "^1.68.4",
"prettier": "^3.0.0",
"prismjs": "^1.29.0",
"react": "18.2.0",
"react-diff-viewer": "^3.1.1",
"react-dom": "18.2.0",
"react-icons": "^4.10.1",
"react-select": "^5.7.4",
"react-syntax-highlighter": "^15.5.0",
"react-textarea-autosize": "^8.5.0",
"socket.io": "^4.7.1",
@@ -71,9 +83,10 @@
"@types/cors": "^2.8.13",
"@types/eslint": "^8.37.0",
"@types/express": "^4.17.17",
"@types/lodash": "^4.14.195",
"@types/lodash-es": "^4.17.8",
"@types/node": "^18.16.0",
"@types/pluralize": "^0.0.30",
"@types/prismjs": "^1.26.0",
"@types/react": "^18.2.6",
"@types/react-dom": "^18.2.4",
"@types/react-syntax-highlighter": "^15.5.7",
@@ -94,6 +107,6 @@
"initVersion": "7.14.0"
},
"prisma": {
"seed": "tsx prisma/seed.ts"
"seed": "pnpm seed"
}
}

pnpm-lock.yaml generated

File diff suppressed because it is too large

@@ -0,0 +1,49 @@
-- Drop the foreign key constraints on the original ModelOutput
ALTER TABLE "ModelOutput" DROP CONSTRAINT "ModelOutput_promptVariantId_fkey";
ALTER TABLE "ModelOutput" DROP CONSTRAINT "ModelOutput_testScenarioId_fkey";
-- Rename the old table
ALTER TABLE "ModelOutput" RENAME TO "ScenarioVariantCell";
ALTER TABLE "ScenarioVariantCell" RENAME CONSTRAINT "ModelOutput_pkey" TO "ScenarioVariantCell_pkey";
ALTER INDEX "ModelOutput_inputHash_idx" RENAME TO "ScenarioVariantCell_inputHash_idx";
ALTER INDEX "ModelOutput_promptVariantId_testScenarioId_key" RENAME TO "ScenarioVariantCell_promptVariantId_testScenarioId_key";
-- Add the new fields to the renamed table
ALTER TABLE "ScenarioVariantCell" ADD COLUMN "retryTime" TIMESTAMP(3);
ALTER TABLE "ScenarioVariantCell" ADD COLUMN "streamingChannel" TEXT;
ALTER TABLE "ScenarioVariantCell" ALTER COLUMN "inputHash" DROP NOT NULL;
ALTER TABLE "ScenarioVariantCell" ALTER COLUMN "output" DROP NOT NULL,
ALTER COLUMN "statusCode" DROP NOT NULL,
ALTER COLUMN "timeToComplete" DROP NOT NULL;
-- Create the new table
CREATE TABLE "ModelOutput" (
"id" UUID NOT NULL,
"inputHash" TEXT NOT NULL,
"output" JSONB NOT NULL,
"timeToComplete" INTEGER NOT NULL DEFAULT 0,
"promptTokens" INTEGER,
"completionTokens" INTEGER,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
"scenarioVariantCellId" UUID
);
-- Move inputHash index
DROP INDEX "ScenarioVariantCell_inputHash_idx";
CREATE INDEX "ModelOutput_inputHash_idx" ON "ModelOutput"("inputHash");
CREATE UNIQUE INDEX "ModelOutput_scenarioVariantCellId_key" ON "ModelOutput"("scenarioVariantCellId");
ALTER TABLE "ModelOutput" ADD CONSTRAINT "ModelOutput_scenarioVariantCellId_fkey" FOREIGN KEY ("scenarioVariantCellId") REFERENCES "ScenarioVariantCell"("id") ON DELETE CASCADE ON UPDATE CASCADE;
ALTER TABLE "ModelOutput" ALTER COLUMN "scenarioVariantCellId" SET NOT NULL,
ADD CONSTRAINT "ModelOutput_pkey" PRIMARY KEY ("id");
ALTER TABLE "ScenarioVariantCell" ADD CONSTRAINT "ScenarioVariantCell_promptVariantId_fkey" FOREIGN KEY ("promptVariantId") REFERENCES "PromptVariant"("id") ON DELETE CASCADE ON UPDATE CASCADE;
ALTER TABLE "ScenarioVariantCell" ADD CONSTRAINT "ScenarioVariantCell_testScenarioId_fkey" FOREIGN KEY ("testScenarioId") REFERENCES "TestScenario"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- CreateEnum
CREATE TYPE "CellRetrievalStatus" AS ENUM ('PENDING', 'IN_PROGRESS', 'COMPLETE', 'ERROR');
-- AlterTable
ALTER TABLE "ScenarioVariantCell" ADD COLUMN "retrievalStatus" "CellRetrievalStatus" NOT NULL DEFAULT 'COMPLETE';

@@ -0,0 +1,2 @@
-- AlterTable
ALTER TABLE "PromptVariant" ADD COLUMN "model" TEXT NOT NULL DEFAULT 'gpt-3.5-turbo';

@@ -0,0 +1,2 @@
-- AlterTable
ALTER TABLE "PromptVariant" ALTER COLUMN "model" DROP DEFAULT;

@@ -0,0 +1,24 @@
/*
Warnings:
- You are about to rename the column `matchString` on the `Evaluation` table. If there is any code or views referring to the old name, they will break.
- You are about to rename the column `matchType` on the `Evaluation` table. If there is any code or views referring to the old name, they will break.
- You are about to rename the column `name` on the `Evaluation` table. If there is any code or views referring to the old name, they will break.
- You are about to rename the enum `EvaluationMatchType` to `EvalType`. If there is any code or views referring to the old name, they will break.
*/
-- RenameEnum
ALTER TYPE "EvaluationMatchType" RENAME TO "EvalType";
-- AlterTable
ALTER TABLE "Evaluation" RENAME COLUMN "matchString" TO "value";
ALTER TABLE "Evaluation" RENAME COLUMN "matchType" TO "evalType";
ALTER TABLE "Evaluation" RENAME COLUMN "name" TO "label";
-- AlterColumnType
ALTER TABLE "Evaluation" ALTER COLUMN "evalType" TYPE "EvalType" USING "evalType"::text::"EvalType";
-- SetNotNullConstraint
ALTER TABLE "Evaluation" ALTER COLUMN "evalType" SET NOT NULL;
ALTER TABLE "Evaluation" ALTER COLUMN "label" SET NOT NULL;
ALTER TABLE "Evaluation" ALTER COLUMN "value" SET NOT NULL;

@@ -0,0 +1,39 @@
/*
Warnings:
- You are about to drop the `EvaluationResult` table. If the table is not empty, all the data it contains will be lost.
*/
-- AlterEnum
ALTER TYPE "EvalType" ADD VALUE 'GPT4_EVAL';
-- DropForeignKey
ALTER TABLE "EvaluationResult" DROP CONSTRAINT "EvaluationResult_evaluationId_fkey";
-- DropForeignKey
ALTER TABLE "EvaluationResult" DROP CONSTRAINT "EvaluationResult_promptVariantId_fkey";
-- DropTable
DROP TABLE "EvaluationResult";
-- CreateTable
CREATE TABLE "OutputEvaluation" (
"id" UUID NOT NULL,
"result" DOUBLE PRECISION NOT NULL,
"details" TEXT,
"modelOutputId" UUID NOT NULL,
"evaluationId" UUID NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
CONSTRAINT "OutputEvaluation_pkey" PRIMARY KEY ("id")
);
-- CreateIndex
CREATE UNIQUE INDEX "OutputEvaluation_modelOutputId_evaluationId_key" ON "OutputEvaluation"("modelOutputId", "evaluationId");
-- AddForeignKey
ALTER TABLE "OutputEvaluation" ADD CONSTRAINT "OutputEvaluation_modelOutputId_fkey" FOREIGN KEY ("modelOutputId") REFERENCES "ModelOutput"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "OutputEvaluation" ADD CONSTRAINT "OutputEvaluation_evaluationId_fkey" FOREIGN KEY ("evaluationId") REFERENCES "Evaluation"("id") ON DELETE CASCADE ON UPDATE CASCADE;

@@ -0,0 +1,124 @@
DROP TABLE "Account";
DROP TABLE "Session";
DROP TABLE "User";
DROP TABLE "VerificationToken";
CREATE TYPE "OrganizationUserRole" AS ENUM ('ADMIN', 'MEMBER', 'VIEWER');
-- CreateTable
CREATE TABLE "Organization" (
"id" UUID NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
"personalOrgUserId" UUID,
CONSTRAINT "Organization_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "OrganizationUser" (
"id" UUID NOT NULL,
"role" "OrganizationUserRole" NOT NULL,
"organizationId" UUID NOT NULL,
"userId" UUID NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
CONSTRAINT "OrganizationUser_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "Account" (
"id" UUID NOT NULL,
"userId" UUID NOT NULL,
"type" TEXT NOT NULL,
"provider" TEXT NOT NULL,
"providerAccountId" TEXT NOT NULL,
"refresh_token" TEXT,
"refresh_token_expires_in" INTEGER,
"access_token" TEXT,
"expires_at" INTEGER,
"token_type" TEXT,
"scope" TEXT,
"id_token" TEXT,
"session_state" TEXT,
CONSTRAINT "Account_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "Session" (
"id" UUID NOT NULL,
"sessionToken" TEXT NOT NULL,
"userId" UUID NOT NULL,
"expires" TIMESTAMP(3) NOT NULL,
CONSTRAINT "Session_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "User" (
"id" UUID NOT NULL,
"name" TEXT,
"email" TEXT,
"emailVerified" TIMESTAMP(3),
"image" TEXT,
CONSTRAINT "User_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "VerificationToken" (
"identifier" TEXT NOT NULL,
"token" TEXT NOT NULL,
"expires" TIMESTAMP(3) NOT NULL
);
INSERT INTO "Organization" ("id", "updatedAt") VALUES ('11111111-1111-1111-1111-111111111111', CURRENT_TIMESTAMP);
-- AlterTable add organizationId as NULLABLE
ALTER TABLE "Experiment" ADD COLUMN "organizationId" UUID;
-- Set default organization for existing experiments
UPDATE "Experiment" SET "organizationId" = '11111111-1111-1111-1111-111111111111';
-- AlterTable set organizationId as NOT NULL
ALTER TABLE "Experiment" ALTER COLUMN "organizationId" SET NOT NULL;
-- CreateIndex
CREATE UNIQUE INDEX "OrganizationUser_organizationId_userId_key" ON "OrganizationUser"("organizationId", "userId");
-- CreateIndex
CREATE UNIQUE INDEX "Account_provider_providerAccountId_key" ON "Account"("provider", "providerAccountId");
-- CreateIndex
CREATE UNIQUE INDEX "Session_sessionToken_key" ON "Session"("sessionToken");
-- CreateIndex
CREATE UNIQUE INDEX "User_email_key" ON "User"("email");
-- CreateIndex
CREATE UNIQUE INDEX "VerificationToken_token_key" ON "VerificationToken"("token");
-- CreateIndex
CREATE UNIQUE INDEX "VerificationToken_identifier_token_key" ON "VerificationToken"("identifier", "token");
-- AddForeignKey
ALTER TABLE "Experiment" ADD CONSTRAINT "Experiment_organizationId_fkey" FOREIGN KEY ("organizationId") REFERENCES "Organization"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "OrganizationUser" ADD CONSTRAINT "OrganizationUser_organizationId_fkey" FOREIGN KEY ("organizationId") REFERENCES "Organization"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "OrganizationUser" ADD CONSTRAINT "OrganizationUser_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "Account" ADD CONSTRAINT "Account_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "Session" ADD CONSTRAINT "Session_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
CREATE UNIQUE INDEX "Organization_personalOrgUserId_key" ON "Organization"("personalOrgUserId");
ALTER TABLE "Organization" ADD CONSTRAINT "Organization_personalOrgUserId_fkey" FOREIGN KEY ("personalOrgUserId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;

@@ -0,0 +1,16 @@
/*
Warnings:
- You are about to drop the column `completionTokens` on the `ScenarioVariantCell` table. All the data in the column will be lost.
- You are about to drop the column `inputHash` on the `ScenarioVariantCell` table. All the data in the column will be lost.
- You are about to drop the column `output` on the `ScenarioVariantCell` table. All the data in the column will be lost.
- You are about to drop the column `promptTokens` on the `ScenarioVariantCell` table. All the data in the column will be lost.
- You are about to drop the column `timeToComplete` on the `ScenarioVariantCell` table. All the data in the column will be lost.
*/
-- AlterTable
ALTER TABLE "ScenarioVariantCell" DROP COLUMN "completionTokens",
DROP COLUMN "inputHash",
DROP COLUMN "output",
DROP COLUMN "promptTokens",
DROP COLUMN "timeToComplete";

@@ -0,0 +1,8 @@
/*
Warnings:
- You are about to drop the column `model` on the `PromptVariant` table. All the data in the column will be lost.
*/
-- AlterTable
ALTER TABLE "ModelOutput" ADD COLUMN "cost" DOUBLE PRECISION;

@@ -2,8 +2,7 @@
// learn more about it in the docs: https://pris.ly/d/prisma-schema
generator client {
provider = "prisma-client-js"
previewFeatures = ["jsonProtocol"]
provider = "prisma-client-js"
}
datasource db {
@@ -17,8 +16,12 @@ model Experiment {
sortIndex Int @default(0)
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
organizationId String @db.Uuid
organization Organization? @relation(fields: [organizationId], references: [id], onDelete: Cascade)
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
TemplateVariable TemplateVariable[]
PromptVariant PromptVariant[]
TestScenario TestScenario[]
@@ -30,7 +33,7 @@ model PromptVariant {
label String
constructFn String
model String @default("gpt-3.5-turbo")
model String
uiId String @default(uuid()) @db.Uuid
visible Boolean @default(true)
@@ -39,10 +42,9 @@ model PromptVariant {
experimentId String @db.Uuid
experiment Experiment @relation(fields: [experimentId], references: [id], onDelete: Cascade)
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
ModelOutput ModelOutput[]
EvaluationResult EvaluationResult[]
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
scenarioVariantCells ScenarioVariantCell[]
@@index([uiId])
}
@@ -59,9 +61,9 @@ model TestScenario {
experimentId String @db.Uuid
experiment Experiment @relation(fields: [experimentId], references: [id], onDelete: Cascade)
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
ModelOutput ModelOutput[]
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
scenarioVariantCells ScenarioVariantCell[]
}
model TemplateVariable {
@@ -76,17 +78,23 @@ model TemplateVariable {
updatedAt DateTime @updatedAt
}
model ModelOutput {
enum CellRetrievalStatus {
PENDING
IN_PROGRESS
COMPLETE
ERROR
}
model ScenarioVariantCell {
id String @id @default(uuid()) @db.Uuid
inputHash String
output Json
statusCode Int
errorMessage String?
timeToComplete Int @default(0)
statusCode Int?
errorMessage String?
retryTime DateTime?
streamingChannel String?
retrievalStatus CellRetrievalStatus @default(COMPLETE)
promptTokens Int? // Added promptTokens field
completionTokens Int? // Added completionTokens field
modelOutput ModelOutput?
promptVariantId String @db.Uuid
promptVariant PromptVariant @relation(fields: [promptVariantId], references: [id], onDelete: Cascade)
@@ -98,82 +106,140 @@ model ModelOutput {
updatedAt DateTime @updatedAt
@@unique([promptVariantId, testScenarioId])
}
model ModelOutput {
id String @id @default(uuid()) @db.Uuid
inputHash String
output Json
timeToComplete Int @default(0)
cost Float?
promptTokens Int?
completionTokens Int?
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
scenarioVariantCellId String @db.Uuid
scenarioVariantCell ScenarioVariantCell @relation(fields: [scenarioVariantCellId], references: [id], onDelete: Cascade)
outputEvaluation OutputEvaluation[]
@@unique([scenarioVariantCellId])
@@index([inputHash])
}
enum EvaluationMatchType {
enum EvalType {
CONTAINS
DOES_NOT_CONTAIN
GPT4_EVAL
}
model Evaluation {
id String @id @default(uuid()) @db.Uuid
name String
matchString String
matchType EvaluationMatchType
label String
evalType EvalType
value String
experimentId String @db.Uuid
experiment Experiment @relation(fields: [experimentId], references: [id], onDelete: Cascade)
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
EvaluationResult EvaluationResult[]
OutputEvaluation OutputEvaluation[]
}
model EvaluationResult {
model OutputEvaluation {
id String @id @default(uuid()) @db.Uuid
passCount Int
failCount Int
// Number between 0 (fail) and 1 (pass)
result Float
details String?
modelOutputId String @db.Uuid
modelOutput ModelOutput @relation(fields: [modelOutputId], references: [id], onDelete: Cascade)
evaluationId String @db.Uuid
evaluation Evaluation @relation(fields: [evaluationId], references: [id], onDelete: Cascade)
promptVariantId String @db.Uuid
promptVariant PromptVariant @relation(fields: [promptVariantId], references: [id], onDelete: Cascade)
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
@@unique([modelOutputId, evaluationId])
}
model Organization {
id String @id @default(uuid()) @db.Uuid
personalOrgUserId String? @unique @db.Uuid
PersonalOrgUser User? @relation(fields: [personalOrgUserId], references: [id], onDelete: Cascade)
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
OrganizationUser OrganizationUser[]
Experiment Experiment[]
}
enum OrganizationUserRole {
ADMIN
MEMBER
VIEWER
}
model OrganizationUser {
id String @id @default(uuid()) @db.Uuid
role OrganizationUserRole
organizationId String @db.Uuid
organization Organization? @relation(fields: [organizationId], references: [id], onDelete: Cascade)
userId String @db.Uuid
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
@@unique([evaluationId, promptVariantId])
@@unique([organizationId, userId])
}
// Necessary for Next auth
model Account {
id String @id @default(cuid())
userId String
type String
provider String
providerAccountId String
refresh_token String? // @db.Text
access_token String? // @db.Text
expires_at Int?
token_type String?
scope String?
id_token String? // @db.Text
session_state String?
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
id String @id @default(uuid()) @db.Uuid
userId String @db.Uuid
type String
provider String
providerAccountId String
refresh_token String? @db.Text
refresh_token_expires_in Int?
access_token String? @db.Text
expires_at Int?
token_type String?
scope String?
id_token String? @db.Text
session_state String?
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
@@unique([provider, providerAccountId])
}
model Session {
id String @id @default(cuid())
id String @id @default(uuid()) @db.Uuid
sessionToken String @unique
userId String
userId String @db.Uuid
expires DateTime
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
}
model User {
id String @id @default(cuid())
name String?
email String? @unique
emailVerified DateTime?
image String?
accounts Account[]
sessions Session[]
id String @id @default(uuid()) @db.Uuid
name String?
email String? @unique
emailVerified DateTime?
image String?
accounts Account[]
sessions Session[]
OrganizationUser OrganizationUser[]
Organization Organization[]
}
model VerificationToken {

@@ -1,76 +1,93 @@
import { prisma } from "~/server/db";
import dedent from "dedent";
import { generateNewCell } from "~/server/utils/generateNewCell";
const experimentId = "11111111-1111-1111-1111-111111111111";
const defaultId = "11111111-1111-1111-1111-111111111111";
await prisma.organization.deleteMany({
where: { id: defaultId },
});
await prisma.organization.create({
data: { id: defaultId },
});
// Delete the existing experiment
await prisma.experiment.deleteMany({
where: {
id: experimentId,
id: defaultId,
},
});
const experiment = await prisma.experiment.create({
await prisma.experiment.create({
data: {
id: experimentId,
id: defaultId,
label: "Country Capitals Example",
organizationId: defaultId,
},
});
await prisma.modelOutput.deleteMany({
await prisma.scenarioVariantCell.deleteMany({
where: {
promptVariant: {
experimentId,
experimentId: defaultId,
},
},
});
await prisma.promptVariant.deleteMany({
where: {
experimentId,
experimentId: defaultId,
},
});
await prisma.promptVariant.createMany({
data: [
{
experimentId,
experimentId: defaultId,
label: "Prompt Variant 1",
sortIndex: 0,
constructFn: `prompt = {
model: "gpt-3.5-turbo-0613",
messages: [{ role: "user", content: "What is the capital of {{country}}?" }],
temperature: 0,
}`,
model: "gpt-3.5-turbo-0613",
constructFn: dedent`
prompt = {
model: "gpt-3.5-turbo-0613",
messages: [
{
role: "user",
content: \`What is the capital of ${"$"}{scenario.country}?\`
}
],
temperature: 0,
}`,
},
{
experimentId,
experimentId: defaultId,
label: "Prompt Variant 2",
sortIndex: 1,
constructFn: `prompt = {
model: "gpt-3.5-turbo-0613",
messages: [
{
role: "user",
content:
"What is the capital of {{country}}? Return just the city name and nothing else.",
},
],
temperature: 0,
}`,
model: "gpt-3.5-turbo-0613",
constructFn: dedent`
prompt = {
model: "gpt-3.5-turbo-0613",
messages: [
{
role: "user",
content: \`What is the capital of ${"$"}{scenario.country}? Return just the city name and nothing else.\`
}
],
temperature: 0,
}`,
},
],
});
await prisma.templateVariable.deleteMany({
where: {
experimentId,
experimentId: defaultId,
},
});
await prisma.templateVariable.createMany({
data: [
{
experimentId,
experimentId: defaultId,
label: "country",
},
],
@@ -78,28 +95,28 @@ await prisma.templateVariable.createMany({
await prisma.testScenario.deleteMany({
where: {
experimentId,
experimentId: defaultId,
},
});
await prisma.testScenario.createMany({
data: [
{
experimentId,
experimentId: defaultId,
sortIndex: 0,
variableValues: {
country: "Spain",
},
},
{
experimentId,
experimentId: defaultId,
sortIndex: 1,
variableValues: {
country: "USA",
},
},
{
experimentId,
experimentId: defaultId,
sortIndex: 2,
variableValues: {
country: "Chile",
@@ -107,3 +124,26 @@ await prisma.testScenario.createMany({
},
],
});
const variants = await prisma.promptVariant.findMany({
where: {
experimentId: defaultId,
},
});
const scenarios = await prisma.testScenario.findMany({
where: {
experimentId: defaultId,
},
});
await Promise.all(
variants
.flatMap((variant) =>
scenarios.map((scenario) => ({
promptVariantId: variant.id,
testScenarioId: scenario.id,
})),
)
.map((cell) => generateNewCell(cell.promptVariantId, cell.testScenarioId)),
);
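
Note on the seed change above: the final step creates one cell per (variant, scenario) pair. A minimal sketch of that pairing, assuming simplified `Variant` and `Scenario` shapes (hypothetical; the real Prisma models carry more fields) and leaving out the call to `generateNewCell`:

```typescript
// Hypothetical minimal shapes for illustration only.
type Variant = { id: string };
type Scenario = { id: string };
type CellKey = { promptVariantId: string; testScenarioId: string };

// Build the cross product: one key per (variant, scenario) grid cell.
function cellKeys(variants: Variant[], scenarios: Scenario[]): CellKey[] {
  return variants.flatMap((variant) =>
    scenarios.map((scenario) => ({
      promptVariantId: variant.id,
      testScenarioId: scenario.id,
    })),
  );
}

// Two variants and three scenarios give six cells, ordered variant by variant.
const keys = cellKeys([{ id: "v1" }, { id: "v2" }], [{ id: "s1" }, { id: "s2" }, { id: "s3" }]);
console.log(keys.length);
```

With the seed data in this diff (two prompt variants, three test scenarios) the same pairing would request six cells.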

File diff suppressed because one or more lines are too long

View File

@@ -12,7 +12,7 @@ services:
dockerContext: .
plan: standard
domains:
- openpipe.ai
- app.openpipe.ai
envVars:
- key: NODE_ENV
value: production

View File

@@ -2,7 +2,7 @@ import fs from "fs";
import path from "path";
import openapiTS, { type OpenAPI3 } from "openapi-typescript";
import YAML from "yaml";
import _ from "lodash";
import { pick } from "lodash-es";
import assert from "assert";
const OPENAPI_URL =
@@ -31,7 +31,7 @@ modelProperty.oneOf = undefined;
delete schema["paths"];
assert(schema.components?.schemas);
schema.components.schemas = _.pick(schema.components?.schemas, [
schema.components.schemas = pick(schema.components?.schemas, [
"CreateChatCompletionRequest",
"ChatCompletionRequestMessage",
"ChatCompletionFunctions",

View File

@@ -4,7 +4,7 @@ import React from "react";
export const AutoResizeTextarea: React.ForwardRefRenderFunction<
HTMLTextAreaElement,
TextareaProps
TextareaProps & { minRows?: number }
> = (props, ref) => {
return (
<Textarea

View File

@@ -11,14 +11,16 @@ import {
FormLabel,
Select,
FormHelperText,
Code,
} from "@chakra-ui/react";
import { type Evaluation, EvaluationMatchType } from "@prisma/client";
import { type Evaluation, EvalType } from "@prisma/client";
import { useCallback, useState } from "react";
import { BsPencil, BsX } from "react-icons/bs";
import { api } from "~/utils/api";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
import AutoResizeTextArea from "../AutoResizeTextArea";
type EvalValues = Pick<Evaluation, "name" | "matchString" | "matchType">;
type EvalValues = Pick<Evaluation, "label" | "value" | "evalType">;
export function EvaluationEditor(props: {
evaluation: Evaluation | null;
@@ -27,35 +29,35 @@ export function EvaluationEditor(props: {
onCancel: () => void;
}) {
const [values, setValues] = useState<EvalValues>({
name: props.evaluation?.name ?? props.defaultName ?? "",
matchString: props.evaluation?.matchString ?? "",
matchType: props.evaluation?.matchType ?? "CONTAINS",
label: props.evaluation?.label ?? props.defaultName ?? "",
value: props.evaluation?.value ?? "",
evalType: props.evaluation?.evalType ?? "CONTAINS",
});
return (
<VStack borderTopWidth={1} borderColor="gray.200" py={4}>
<HStack w="100%">
<FormControl flex={1}>
<FormLabel fontSize="sm">Evaluation Name</FormLabel>
<FormLabel fontSize="sm">Eval Name</FormLabel>
<Input
size="sm"
value={values.name}
onChange={(e) => setValues((values) => ({ ...values, name: e.target.value }))}
value={values.label}
onChange={(e) => setValues((values) => ({ ...values, label: e.target.value }))}
/>
</FormControl>
<FormControl flex={1}>
<FormLabel fontSize="sm">Match Type</FormLabel>
<FormLabel fontSize="sm">Eval Type</FormLabel>
<Select
size="sm"
value={values.matchType}
value={values.evalType}
onChange={(e) =>
setValues((values) => ({
...values,
matchType: e.target.value as EvaluationMatchType,
evalType: e.target.value as EvalType,
}))
}
>
{Object.values(EvaluationMatchType).map((type) => (
{Object.values(EvalType).map((type) => (
<option key={type} value={type}>
{type}
</option>
@@ -63,17 +65,37 @@ export function EvaluationEditor(props: {
</Select>
</FormControl>
</HStack>
<FormControl>
<FormLabel fontSize="sm">Match String</FormLabel>
<Input
size="sm"
value={values.matchString}
onChange={(e) => setValues((values) => ({ ...values, matchString: e.target.value }))}
/>
<FormHelperText>
This string will be interpreted as a regex and checked against each model output.
</FormHelperText>
</FormControl>
{["CONTAINS", "DOES_NOT_CONTAIN"].includes(values.evalType) && (
<FormControl>
<FormLabel fontSize="sm">Match String</FormLabel>
<Input
size="sm"
value={values.value}
onChange={(e) => setValues((values) => ({ ...values, value: e.target.value }))}
/>
<FormHelperText>
This string will be interpreted as a regex and checked against each model output. You
can include scenario variables using <Code>{"{{curly_braces}}"}</Code>
</FormHelperText>
</FormControl>
)}
{values.evalType === "GPT4_EVAL" && (
<FormControl pt={2}>
<FormLabel fontSize="sm">GPT4 Instructions</FormLabel>
<AutoResizeTextArea
size="sm"
value={values.value}
onChange={(e) => setValues((values) => ({ ...values, value: e.target.value }))}
minRows={3}
/>
<FormHelperText>
Give instructions to GPT-4 for how to evaluate your prompt. It will have access to the
full scenario as well as the output it is evaluating. It will <strong>not</strong> have
access to the specific prompt variant, so be sure to be clear about the task you want it
to perform.
</FormHelperText>
</FormControl>
)}
<HStack alignSelf="flex-end">
<Button size="sm" onClick={props.onCancel} colorScheme="gray">
Cancel
@@ -125,6 +147,7 @@ export default function EditEvaluations() {
}
await utils.evaluations.list.invalidate();
await utils.promptVariants.stats.invalidate();
await utils.scenarioVariantCells.get.invalidate();
}, []);
const onCancel = useCallback(() => {
@@ -156,9 +179,9 @@ export default function EditEvaluations() {
align="center"
key={evaluation.id}
>
<Text fontWeight="bold">{evaluation.name}</Text>
<Text fontWeight="bold">{evaluation.label}</Text>
<Text flex={1}>
{evaluation.matchType}: &quot;{evaluation.matchString}&quot;
{evaluation.evalType}: &quot;{evaluation.value}&quot;
</Text>
<Button
variant="unstyled"
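
Note on the helper text above: it says the match string is treated as a regex and may reference scenario variables in `{{curly_braces}}`. The sketch below shows one way such a check could work. It is not the project's `evaluateOutput` implementation; `fillTemplate` and `passesEval` are hypothetical names, and variable values are assumed to be inserted without regex escaping.

```typescript
type MatchEvalType = "CONTAINS" | "DOES_NOT_CONTAIN";

// Replace each {{name}} with the scenario's value (assumption: unknown names become "").
function fillTemplate(template: string, vars: Record<string, string>): string {
  return template.replace(/\{\{\s*(\w+)\s*\}\}/g, (_match, key: string) => vars[key] ?? "");
}

// Interpret the filled-in string as a regex and test it against the model output.
function passesEval(
  evalType: MatchEvalType,
  matchString: string,
  vars: Record<string, string>,
  output: string,
): boolean {
  const found = new RegExp(fillTemplate(matchString, vars)).test(output);
  return evalType === "CONTAINS" ? found : !found;
}

console.log(passesEval("CONTAINS", "{{capital}}", { capital: "Madrid" }, "The capital is Madrid."));
```

Because the value is not escaped in this sketch, a scenario value containing regex metacharacters would change the pattern; whether the real code escapes values is not shown in this diff.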

View File

@@ -1,4 +1,4 @@
import { Text, Button, HStack, Heading, Icon, Input, Stack, Code } from "@chakra-ui/react";
import { Text, Button, HStack, Heading, Icon, Input, Stack } from "@chakra-ui/react";
import { useState } from "react";
import { BsCheck, BsX } from "react-icons/bs";
import { api } from "~/utils/api";
@@ -36,8 +36,7 @@ export default function EditScenarioVars() {
<Heading size="sm">Scenario Variables</Heading>
<Stack spacing={2}>
<Text fontSize="sm">
Scenario variables can be used in your prompt variants as well as evaluations. Reference
them using <Code>{"{{curly_braces}}"}</Code>.
Scenario variables can be used in your prompt variants as well as evaluations.
</Text>
<HStack spacing={0}>
<Input

View File

@@ -1,7 +1,7 @@
import { Button, type ButtonProps, HStack, Spinner, Icon } from "@chakra-ui/react";
import { BsPlus } from "react-icons/bs";
import { api } from "~/utils/api";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
import { useExperiment, useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
// Extracted Button styling into reusable component
const StyledButton = ({ children, onClick }: ButtonProps) => (
@@ -17,6 +17,8 @@ const StyledButton = ({ children, onClick }: ButtonProps) => (
);
export default function NewScenarioButton() {
const { canModify } = useExperimentAccess();
const experiment = useExperiment();
const mutation = api.scenarios.create.useMutation();
const utils = api.useContext();
@@ -38,6 +40,8 @@ export default function NewScenarioButton() {
await utils.scenarios.list.invalidate();
}, [mutation]);
if (!canModify) return null;
return (
<HStack spacing={2}>
<StyledButton onClick={onClick}>

View File

@@ -1,7 +1,7 @@
import { Button, Icon, Spinner } from "@chakra-ui/react";
import { Box, Button, Icon, Spinner, Text } from "@chakra-ui/react";
import { BsPlus } from "react-icons/bs";
import { api } from "~/utils/api";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
import { useExperiment, useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
import { cellPadding, headerMinHeight } from "../constants";
export default function NewVariantButton() {
@@ -17,6 +17,9 @@ export default function NewVariantButton() {
await utils.promptVariants.list.invalidate();
}, [mutation]);
const { canModify } = useExperimentAccess();
if (!canModify) return <Box w={cellPadding.x} />;
return (
<Button
w="100%"
@@ -31,7 +34,7 @@ export default function NewVariantButton() {
minH={headerMinHeight}
>
<Icon as={loading ? Spinner : BsPlus} boxSize={6} mr={loading ? 1 : 0} />
Add Variant
<Text display={{ base: "none", md: "flex" }}>Add Variant</Text>
</Button>
);
}

View File

@@ -0,0 +1,37 @@
import { Button, HStack, Icon, Tooltip } from "@chakra-ui/react";
import { BsArrowClockwise } from "react-icons/bs";
import { useExperimentAccess } from "~/utils/hooks";
export const CellOptions = ({
refetchingOutput,
refetchOutput,
}: {
refetchingOutput: boolean;
refetchOutput: () => void;
}) => {
const { canModify } = useExperimentAccess();
return (
<HStack justifyContent="flex-end" w="full">
{!refetchingOutput && canModify && (
<Tooltip label="Refetch output" aria-label="refetch output">
<Button
size="xs"
w={4}
h={4}
py={4}
px={4}
minW={0}
borderRadius={8}
color="gray.500"
variant="ghost"
cursor="pointer"
onClick={refetchOutput}
aria-label="refetch output"
>
<Icon as={BsArrowClockwise} boxSize={4} />
</Button>
</Tooltip>
)}
</HStack>
);
};

View File

@@ -1,29 +1,21 @@
import { type ModelOutput } from "@prisma/client";
import { HStack, VStack, Text, Button, Icon } from "@chakra-ui/react";
import { type ScenarioVariantCell } from "@prisma/client";
import { VStack, Text } from "@chakra-ui/react";
import { useEffect, useState } from "react";
import { BsArrowClockwise } from "react-icons/bs";
import { rateLimitErrorMessage } from "~/sharedStrings";
import pluralize from "pluralize";
const MAX_AUTO_RETRIES = 3;
export const ErrorHandler = ({
output,
cell,
refetchOutput,
numPreviousTries,
}: {
output: ModelOutput;
cell: ScenarioVariantCell;
refetchOutput: () => void;
numPreviousTries: number;
}) => {
const [msToWait, setMsToWait] = useState(0);
const shouldAutoRetry =
output.errorMessage === rateLimitErrorMessage && numPreviousTries < MAX_AUTO_RETRIES;
useEffect(() => {
if (!shouldAutoRetry) return;
if (!cell.retryTime) return;
const initialWaitTime = calculateDelay(numPreviousTries);
const initialWaitTime = cell.retryTime.getTime() - Date.now();
const msModuloOneSecond = initialWaitTime % 1000;
let remainingTime = initialWaitTime - msModuloOneSecond;
setMsToWait(remainingTime);
@@ -35,7 +27,6 @@ export const ErrorHandler = ({
setMsToWait(remainingTime);
if (remainingTime <= 0) {
refetchOutput();
clearInterval(interval);
}
}, 1000);
@@ -45,32 +36,12 @@ export const ErrorHandler = ({
clearInterval(interval);
clearTimeout(timeout);
};
}, [shouldAutoRetry, setMsToWait, refetchOutput, numPreviousTries]);
}, [cell.retryTime, cell.statusCode, setMsToWait, refetchOutput]);
return (
<VStack w="full">
<HStack w="full" alignItems="flex-start" justifyContent="space-between">
<Text color="red.600" fontWeight="bold">
Error
</Text>
<Button
size="xs"
w={4}
h={4}
px={4}
py={4}
minW={0}
borderRadius={8}
variant="ghost"
cursor="pointer"
onClick={refetchOutput}
aria-label="refetch output"
>
<Icon as={BsArrowClockwise} boxSize={6} />
</Button>
</HStack>
<Text color="red.600" wordBreak="break-word">
{output.errorMessage}
{cell.errorMessage}
</Text>
{msToWait > 0 && (
<Text color="red.600" fontSize="sm">
@@ -80,12 +51,3 @@ export const ErrorHandler = ({
</VStack>
);
};
const MIN_DELAY = 500; // milliseconds
const MAX_DELAY = 5000; // milliseconds
function calculateDelay(numPreviousTries: number): number {
const baseDelay = Math.min(MAX_DELAY, MIN_DELAY * Math.pow(2, numPreviousTries));
const jitter = Math.random() * baseDelay;
return baseDelay + jitter;
}
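
Note on the ErrorHandler change above: the client-side backoff (`calculateDelay`) is removed and the countdown is now derived from `cell.retryTime`. A minimal sketch contrasting the two, assuming an injectable `random` source for determinism and a clamp at zero (both are additions for illustration; the diff does neither):

```typescript
const MIN_DELAY = 500; // milliseconds
const MAX_DELAY = 5000; // milliseconds

// Old policy: exponential backoff capped at MAX_DELAY, plus jitter of up to 100% of the base.
function calculateDelay(numPreviousTries: number, random: () => number = Math.random): number {
  const baseDelay = Math.min(MAX_DELAY, MIN_DELAY * Math.pow(2, numPreviousTries));
  return baseDelay + random() * baseDelay;
}

// New policy: the wait is whatever remains until the stored retry time.
// The Math.max clamp is an assumption added here so past times yield 0.
function msUntilRetry(retryTime: Date, now: number = Date.now()): number {
  return Math.max(0, retryTime.getTime() - now);
}

// Both versions round the wait down to a whole second before starting the 1s countdown.
function roundDownToSecond(ms: number): number {
  return ms - (ms % 1000);
}

console.log(calculateDelay(3), msUntilRetry(new Date(Date.now() + 2750)));
```

Reading the delay from a stored retry time means every client viewing the cell counts down to the same moment rather than each computing its own jittered delay.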

View File

@@ -1,17 +1,16 @@
import { type RouterOutputs, api } from "~/utils/api";
import { api } from "~/utils/api";
import { type PromptVariant, type Scenario } from "../types";
import { Spinner, Text, Box, Center, Flex } from "@chakra-ui/react";
import { Spinner, Text, Center, VStack } from "@chakra-ui/react";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
import SyntaxHighlighter from "react-syntax-highlighter";
import { docco } from "react-syntax-highlighter/dist/cjs/styles/hljs";
import stringify from "json-stringify-pretty-compact";
import { type ReactElement, useState, useEffect, useRef, useCallback } from "react";
import { type ReactElement, useState, useEffect } from "react";
import { type ChatCompletion } from "openai/resources/chat";
import { generateChannel } from "~/utils/generateChannel";
import { isObject } from "lodash";
import useSocket from "~/utils/useSocket";
import { OutputStats } from "./OutputStats";
import { ErrorHandler } from "./ErrorHandler";
import { CellOptions } from "./CellOptions";
export default function OutputCell({
scenario,
@@ -37,116 +36,114 @@ export default function OutputCell({
// if (variant.config === null || Object.keys(variant.config).length === 0)
// disabledReason = "Save your prompt variant to see output";
const outputMutation = api.outputs.get.useMutation();
const [output, setOutput] = useState<RouterOutputs["outputs"]["get"]>(null);
const [channel, setChannel] = useState<string | undefined>(undefined);
const [numPreviousTries, setNumPreviousTries] = useState(0);
const fetchMutex = useRef(false);
const [fetchOutput, fetchingOutput] = useHandledAsyncCallback(
async (forceRefetch?: boolean) => {
if (fetchMutex.current) return;
setNumPreviousTries((prev) => prev + 1);
fetchMutex.current = true;
setOutput(null);
const shouldStream =
isObject(variant) &&
"config" in variant &&
isObject(variant.config) &&
"stream" in variant.config &&
variant.config.stream === true;
const channel = shouldStream ? generateChannel() : undefined;
setChannel(channel);
const output = await outputMutation.mutateAsync({
scenarioId: scenario.id,
variantId: variant.id,
channel,
forceRefetch,
});
setOutput(output);
await utils.promptVariants.stats.invalidate();
fetchMutex.current = false;
},
[outputMutation, scenario.id, variant.id],
const [refetchInterval, setRefetchInterval] = useState(0);
const { data: cell, isLoading: queryLoading } = api.scenarioVariantCells.get.useQuery(
{ scenarioId: scenario.id, variantId: variant.id },
{ refetchInterval },
);
const hardRefetch = useCallback(() => fetchOutput(true), [fetchOutput]);
useEffect(fetchOutput, [scenario.id, variant.id]);
const { mutateAsync: hardRefetchMutate, isLoading: refetchingOutput } =
api.scenarioVariantCells.forceRefetch.useMutation();
const [hardRefetch] = useHandledAsyncCallback(async () => {
await hardRefetchMutate({ scenarioId: scenario.id, variantId: variant.id });
await utils.scenarioVariantCells.get.invalidate({
scenarioId: scenario.id,
variantId: variant.id,
});
await utils.promptVariants.stats.invalidate({
variantId: variant.id,
});
}, [hardRefetchMutate, scenario.id, variant.id]);
const fetchingOutput = queryLoading || refetchingOutput;
const awaitingOutput =
!cell ||
cell.retrievalStatus === "PENDING" ||
cell.retrievalStatus === "IN_PROGRESS" ||
refetchingOutput;
useEffect(() => setRefetchInterval(awaitingOutput ? 1000 : 0), [awaitingOutput]);
const modelOutput = cell?.modelOutput;
// Disconnect from socket if we're not streaming anymore
const streamedMessage = useSocket(fetchingOutput ? channel : undefined);
const streamedMessage = useSocket(cell?.streamingChannel);
const streamedContent = streamedMessage?.choices?.[0]?.message?.content;
if (!vars) return null;
if (disabledReason) return <Text color="gray.500">{disabledReason}</Text>;
if (fetchingOutput && !streamedMessage)
if (awaitingOutput && !streamedMessage)
return (
<Center h="100%" w="100%">
<Spinner />
</Center>
);
if (!output && !fetchingOutput) return <Text color="gray.500">Error retrieving output</Text>;
if (!cell && !fetchingOutput) return <Text color="gray.500">Error retrieving output</Text>;
if (output && output.errorMessage) {
return (
<ErrorHandler
output={output}
refetchOutput={hardRefetch}
numPreviousTries={numPreviousTries}
/>
);
if (cell && cell.errorMessage) {
return <ErrorHandler cell={cell} refetchOutput={hardRefetch} />;
}
const response = output?.output as unknown as ChatCompletion;
const response = modelOutput?.output as unknown as ChatCompletion;
const message = response?.choices?.[0]?.message;
if (output && message?.function_call) {
if (modelOutput && message?.function_call) {
const rawArgs = message.function_call.arguments ?? "null";
let parsedArgs: string;
try {
parsedArgs = JSON.parse(rawArgs);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} catch (e: any) {
parsedArgs = `Failed to parse arguments as JSON: '${rawArgs}' ERROR: ${e.message as string}`;
}
return (
<Box fontSize="xs" width="100%" flexWrap="wrap" overflowX="auto">
<SyntaxHighlighter
customStyle={{ overflowX: "unset" }}
language="json"
style={docco}
lineProps={{
style: { wordBreak: "break-all", whiteSpace: "pre-wrap" },
}}
wrapLines
>
{stringify(
{
function: message.function_call.name,
args: parsedArgs,
},
{ maxLength: 40 },
)}
</SyntaxHighlighter>
<OutputStats model={variant.model} modelOutput={output} scenario={scenario} />
</Box>
<VStack
w="100%"
h="100%"
fontSize="xs"
flexWrap="wrap"
overflowX="hidden"
justifyContent="space-between"
>
<VStack w="full" flex={1} spacing={0}>
<CellOptions refetchingOutput={refetchingOutput} refetchOutput={hardRefetch} />
<SyntaxHighlighter
customStyle={{ overflowX: "unset", width: "100%", flex: 1 }}
language="json"
style={docco}
lineProps={{
style: { wordBreak: "break-all", whiteSpace: "pre-wrap" },
}}
wrapLines
>
{stringify(
{
function: message.function_call.name,
args: parsedArgs,
},
{ maxLength: 40 },
)}
</SyntaxHighlighter>
</VStack>
<OutputStats modelOutput={modelOutput} scenario={scenario} />
</VStack>
);
}
const contentToDisplay = message?.content ?? streamedContent ?? JSON.stringify(output?.output);
const contentToDisplay =
message?.content ?? streamedContent ?? JSON.stringify(modelOutput?.output);
return (
<Flex w="100%" h="100%" direction="column" justifyContent="space-between" whiteSpace="pre-wrap">
{contentToDisplay}
{output && <OutputStats model={variant.model} modelOutput={output} scenario={scenario} />}
</Flex>
<VStack w="100%" h="100%" justifyContent="space-between" whiteSpace="pre-wrap">
<VStack w="full" alignItems="flex-start" spacing={0}>
<CellOptions refetchingOutput={refetchingOutput} refetchOutput={hardRefetch} />
<Text>{contentToDisplay}</Text>
</VStack>
{modelOutput && <OutputStats modelOutput={modelOutput} scenario={scenario} />}
</VStack>
);
}
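
Note on the OutputCell change above: the component now polls the cell query once per second while output is pending and sets the interval to 0 otherwise. A minimal sketch of that decision as pure functions, assuming a simplified cell shape; the `COMPLETE` and `ERROR` status names are hypothetical, since only `PENDING` and `IN_PROGRESS` appear in this diff:

```typescript
// Hypothetical status union; only the first two values are confirmed by the diff.
type RetrievalStatus = "PENDING" | "IN_PROGRESS" | "COMPLETE" | "ERROR";
type CellLike = { retrievalStatus: RetrievalStatus } | null | undefined;

// Mirrors the awaitingOutput expression: no cell yet, work outstanding, or a forced refetch.
function isAwaitingOutput(cell: CellLike, refetchingOutput: boolean): boolean {
  return (
    !cell ||
    cell.retrievalStatus === "PENDING" ||
    cell.retrievalStatus === "IN_PROGRESS" ||
    refetchingOutput
  );
}

// Interval passed as refetchInterval: 1000ms while waiting, 0 once the cell has settled.
function pollIntervalMs(awaiting: boolean): number {
  return awaiting ? 1000 : 0;
}

console.log(pollIntervalMs(isAwaitingOutput({ retrievalStatus: "IN_PROGRESS" }, false)));
```

Keeping the decision in a pure function like this would make the polling rule testable without rendering the component.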

View File

@@ -1,64 +1,56 @@
import { type ModelOutput } from "@prisma/client";
import { type SupportedModel } from "~/server/types";
import { type Scenario } from "../types";
import { useExperiment } from "~/utils/hooks";
import { api } from "~/utils/api";
import { calculateTokenCost } from "~/utils/calculateTokenCost";
import { evaluateOutput } from "~/server/utils/evaluateOutput";
import { HStack, Icon, Text } from "@chakra-ui/react";
import { type RouterOutputs } from "~/utils/api";
import { HStack, Icon, Text, Tooltip } from "@chakra-ui/react";
import { BsCheck, BsClock, BsCurrencyDollar, BsX } from "react-icons/bs";
import { CostTooltip } from "~/components/tooltip/CostTooltip";
const SHOW_COST = false;
const SHOW_TIME = false;
const SHOW_TIME = true;
export const OutputStats = ({
model,
modelOutput,
scenario,
}: {
model: SupportedModel | string | null;
modelOutput: ModelOutput;
modelOutput: NonNullable<
NonNullable<RouterOutputs["scenarioVariantCells"]["get"]>["modelOutput"]
>;
scenario: Scenario;
}) => {
const timeToComplete = modelOutput.timeToComplete;
const experiment = useExperiment();
const evals =
api.evaluations.list.useQuery({ experimentId: experiment.data?.id ?? "" }).data ?? [];
const promptTokens = modelOutput.promptTokens;
const completionTokens = modelOutput.completionTokens;
const promptCost = promptTokens && model ? calculateTokenCost(model, promptTokens) : 0;
const completionCost =
completionTokens && model ? calculateTokenCost(model, completionTokens, true) : 0;
const cost = promptCost + completionCost;
if (!evals.length) return null;
return (
<HStack align="center" color="gray.500" fontSize="2xs" mt={{ base: 0, md: 1 }}>
<HStack w="full" align="center" color="gray.500" fontSize="2xs" mt={{ base: 0, md: 1 }}>
<HStack flex={1}>
{evals.map((evaluation) => {
const passed = evaluateOutput(modelOutput, scenario, evaluation);
{modelOutput.outputEvaluation.map((evaluation) => {
const passed = evaluation.result > 0.5;
return (
<HStack spacing={0} key={evaluation.id}>
<Text>{evaluation.name}</Text>
<Icon
as={passed ? BsCheck : BsX}
color={passed ? "green.500" : "red.500"}
boxSize={6}
/>
</HStack>
<Tooltip
isDisabled={!evaluation.details}
label={evaluation.details}
key={evaluation.id}
>
<HStack spacing={0}>
<Text>{evaluation.evaluation.label}</Text>
<Icon
as={passed ? BsCheck : BsX}
color={passed ? "green.500" : "red.500"}
boxSize={6}
/>
</HStack>
</Tooltip>
);
})}
</HStack>
{SHOW_COST && (
<CostTooltip promptTokens={promptTokens} completionTokens={completionTokens} cost={cost}>
{modelOutput.cost && (
<CostTooltip
promptTokens={promptTokens}
completionTokens={completionTokens}
cost={modelOutput.cost}
>
<HStack spacing={0}>
<Icon as={BsCurrencyDollar} />
<Text mr={1}>{cost.toFixed(3)}</Text>
<Text mr={1}>{modelOutput.cost.toFixed(3)}</Text>
</HStack>
</CostTooltip>
)}

View File

@@ -1,8 +1,8 @@
import { type DragEvent } from "react";
import { api } from "~/utils/api";
import { isEqual } from "lodash";
import { isEqual } from "lodash-es";
import { type Scenario } from "./types";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
import { useExperiment, useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
import { useState } from "react";
import { Box, Button, Flex, HStack, Icon, Spinner, Stack, Tooltip, VStack } from "@chakra-ui/react";
@@ -13,11 +13,14 @@ import AutoResizeTextArea from "../AutoResizeTextArea";
export default function ScenarioEditor({
scenario,
hovered,
...props
}: {
scenario: Scenario;
hovered: boolean;
canHide: boolean;
}) {
const { canModify } = useExperimentAccess();
const savedValues = scenario.variableValues as Record<string, string>;
const utils = api.useContext();
const [isDragTarget, setIsDragTarget] = useState(false);
@@ -73,6 +76,7 @@ export default function ScenarioEditor({
alignItems="flex-start"
pr={cellPadding.x}
py={cellPadding.y}
pl={canModify ? 0 : cellPadding.x}
height="100%"
draggable={!variableInputHovered}
onDragStart={(e) => {
@@ -92,31 +96,38 @@ export default function ScenarioEditor({
onDrop={onReorder}
backgroundColor={isDragTarget ? "gray.100" : "transparent"}
>
<Stack alignSelf="flex-start" opacity={hovered ? 1 : 0} spacing={0}>
<Tooltip label="Hide scenario" hasArrow>
{/* for some reason the tooltip can't position itself properly relative to the icon without the wrapping box */}
<Button
variant="unstyled"
color="gray.400"
height="unset"
width="unset"
minW="unset"
onClick={onHide}
_hover={{
color: "gray.800",
cursor: "pointer",
}}
>
<Icon as={hidingInProgress ? Spinner : BsX} boxSize={6} />
</Button>
</Tooltip>
<Icon
as={RiDraggable}
boxSize={6}
color="gray.400"
_hover={{ color: "gray.800", cursor: "pointer" }}
/>
</Stack>
{canModify && (
<Stack alignSelf="flex-start" opacity={props.hovered ? 1 : 0} spacing={0}>
{props.canHide && (
<>
<Tooltip label="Hide scenario" hasArrow>
{/* for some reason the tooltip can't position itself properly relative to the icon without the wrapping box */}
<Button
variant="unstyled"
color="gray.400"
height="unset"
width="unset"
minW="unset"
onClick={onHide}
_hover={{
color: "gray.800",
cursor: "pointer",
}}
>
<Icon as={hidingInProgress ? Spinner : BsX} boxSize={6} />
</Button>
</Tooltip>
<Icon
as={RiDraggable}
boxSize={6}
color="gray.400"
_hover={{ color: "gray.800", cursor: "pointer" }}
/>
</>
)}
</Stack>
)}
{variableLabels.length === 0 ? (
<Box color="gray.500">{vars.data ? "No scenario variables configured" : "Loading..."}</Box>
) : (
@@ -150,6 +161,8 @@ export default function ScenarioEditor({
fontSize="sm"
lineHeight={1.2}
value={value}
isDisabled={!canModify}
_disabled={{ opacity: 1, cursor: "default" }}
onChange={(e) => {
setValues((prev) => ({ ...prev, [key]: e.target.value }));
}}

View File

@@ -5,7 +5,11 @@ import OutputCell from "./OutputCell/OutputCell";
import ScenarioEditor from "./ScenarioEditor";
import type { PromptVariant, Scenario } from "./types";
const ScenarioRow = (props: { scenario: Scenario; variants: PromptVariant[] }) => {
const ScenarioRow = (props: {
scenario: Scenario;
variants: PromptVariant[];
canHide: boolean;
}) => {
const [isHovered, setIsHovered] = useState(false);
const highlightStyle = { backgroundColor: "gray.50" };
@@ -18,7 +22,7 @@ const ScenarioRow = (props: { scenario: Scenario; variants: PromptVariant[] }) =
sx={isHovered ? highlightStyle : undefined}
borderLeftWidth={1}
>
<ScenarioEditor scenario={props.scenario} hovered={isHovered} />
<ScenarioEditor scenario={props.scenario} hovered={isHovered} canHide={props.canHide} />
</GridItem>
{props.variants.map((variant) => (
<GridItem

View File

@@ -0,0 +1,52 @@
import { Button, GridItem, HStack, Heading } from "@chakra-ui/react";
import { cellPadding } from "../constants";
import { useElementDimensions, useExperimentAccess } from "~/utils/hooks";
import { stickyHeaderStyle } from "./styles";
import { BsPencil } from "react-icons/bs";
import { useAppStore } from "~/state/store";
export const ScenariosHeader = ({
headerRows,
numScenarios,
}: {
headerRows: number;
numScenarios: number;
}) => {
const openDrawer = useAppStore((s) => s.openDrawer);
const { canModify } = useExperimentAccess();
const [ref, dimensions] = useElementDimensions();
const topValue = dimensions ? `-${dimensions.height - 24}px` : "-455px";
return (
<GridItem
// eslint-disable-next-line @typescript-eslint/no-explicit-any
ref={ref as any}
display="flex"
alignItems="flex-end"
rowSpan={headerRows}
px={cellPadding.x}
py={cellPadding.y}
// Only display the part of the grid item that has content
sx={{ ...stickyHeaderStyle, top: topValue }}
>
<HStack w="100%">
<Heading size="xs" fontWeight="bold" flex={1}>
Scenarios ({numScenarios})
</Heading>
{canModify && (
<Button
size="xs"
variant="ghost"
color="gray.500"
aria-label="Edit"
leftIcon={<BsPencil />}
onClick={openDrawer}
>
Edit Vars
</Button>
)}
</HStack>
</GridItem>
);
};

View File

@@ -1,11 +1,12 @@
import { Box, Button, HStack, Tooltip, VStack, useToast } from "@chakra-ui/react";
import { Box, Button, HStack, Spinner, Tooltip, useToast, Text } from "@chakra-ui/react";
import { useRef, useEffect, useState, useCallback } from "react";
import { useHandledAsyncCallback, useModifierKeyLabel } from "~/utils/hooks";
import { useExperimentAccess, useHandledAsyncCallback, useModifierKeyLabel } from "~/utils/hooks";
import { type PromptVariant } from "./types";
import { api } from "~/utils/api";
import { useAppStore } from "~/state/store";
import { editorBackground } from "~/state/sharedVariantEditor.slice";
export default function VariantConfigEditor(props: { variant: PromptVariant }) {
export default function VariantEditor(props: { variant: PromptVariant }) {
const { canModify } = useExperimentAccess();
const monaco = useAppStore.use.sharedVariantEditor.monaco();
const editorRef = useRef<ReturnType<NonNullable<typeof monaco>["editor"]["create"]> | null>(null);
const [editorId] = useState(() => `editor_${Math.random().toString(36).substring(7)}`);
@@ -17,15 +18,23 @@ export default function VariantConfigEditor(props: { variant: PromptVariant }) {
const checkForChanges = useCallback(() => {
if (!editorRef.current) return;
const currentConfig = editorRef.current.getValue();
setIsChanged(currentConfig !== lastSavedFn);
const currentFn = editorRef.current.getValue();
setIsChanged(currentFn.length > 0 && currentFn !== lastSavedFn);
}, [lastSavedFn]);
const matchUpdatedSavedFn = useCallback(() => {
if (!editorRef.current) return;
editorRef.current.setValue(lastSavedFn);
setIsChanged(false);
}, [lastSavedFn]);
useEffect(matchUpdatedSavedFn, [matchUpdatedSavedFn, lastSavedFn]);
const replaceVariant = api.promptVariants.replaceVariant.useMutation();
const utils = api.useContext();
const toast = useToast();
const [onSave] = useHandledAsyncCallback(async () => {
const [onSave, saveInProgress] = useHandledAsyncCallback(async () => {
if (!editorRef.current) return;
await editorRef.current.getAction("editor.action.formatDocument")?.run();
@@ -38,18 +47,6 @@ export default function VariantConfigEditor(props: { variant: PromptVariant }) {
const model = editorRef.current.getModel();
if (!model) return;
const markers = monaco?.editor.getModelMarkers({ resource: model.uri });
const hasErrors = markers?.some((m) => m.severity === monaco?.MarkerSeverity.Error);
if (hasErrors) {
toast({
title: "Invalid TypeScript",
description: "Please fix the TypeScript errors before saving.",
status: "error",
});
return;
}
// Make sure the user assigned the prompt, i.e. the code matches /prompt\s*=/ somewhere
const promptRegex = /prompt\s*=/;
if (!promptRegex.test(currentFn)) {
@@ -75,9 +72,9 @@ export default function VariantConfigEditor(props: { variant: PromptVariant }) {
});
}
await utils.promptVariants.list.invalidate();
setIsChanged(false);
checkForChanges();
await utils.promptVariants.list.invalidate();
}, [checkForChanges]);
useEffect(() => {
@@ -101,6 +98,7 @@ export default function VariantConfigEditor(props: { variant: PromptVariant }) {
wordWrapBreakAfterCharacters: "",
wordWrapBreakBeforeCharacters: "",
quickSuggestions: true,
readOnly: !canModify,
});
editorRef.current.onDidFocusEditorText(() => {
@@ -128,21 +126,16 @@ export default function VariantConfigEditor(props: { variant: PromptVariant }) {
/* eslint-disable-next-line react-hooks/exhaustive-deps */
}, [monaco, editorId]);
useEffect(() => {
if (!editorRef.current) return;
editorRef.current.updateOptions({
readOnly: !canModify,
});
}, [canModify]);
return (
<Box w="100%" pos="relative">
<VStack
spacing={0}
align="stretch"
fontSize="xs"
fontWeight="bold"
color="gray.600"
py={2}
bgColor={editorBackground}
>
<code>{`function constructPrompt(scenario: Scenario): Prompt {`}</code>
<div id={editorId} style={{ height: "300px", width: "100%" }}></div>
<code>{`return prompt; }`}</code>
</VStack>
<div id={editorId} style={{ height: "400px", width: "100%" }}></div>
{isChanged && (
<HStack pos="absolute" bottom={2} right={2}>
<Button
@@ -156,8 +149,8 @@ export default function VariantConfigEditor(props: { variant: PromptVariant }) {
Reset
</Button>
<Tooltip label={`${modifierKey} + Enter`}>
<Button size="sm" onClick={onSave} colorScheme="blue">
Save
<Button size="sm" onClick={onSave} colorScheme="blue" w={16} disabled={saveInProgress}>
{saveInProgress ? <Spinner boxSize={4} /> : <Text>Save</Text>}
</Button>
</Tooltip>
</HStack>

View File

@@ -1,105 +0,0 @@
import { useState, type DragEvent } from "react";
import { type PromptVariant } from "./types";
import { api } from "~/utils/api";
import { useHandledAsyncCallback } from "~/utils/hooks";
import { Button, HStack, Icon, Tooltip } from "@chakra-ui/react"; // Changed here
import { BsX } from "react-icons/bs";
import { RiDraggable } from "react-icons/ri";
import { cellPadding, headerMinHeight } from "../constants";
import AutoResizeTextArea from "../AutoResizeTextArea";
export default function VariantHeader(props: { variant: PromptVariant }) {
const utils = api.useContext();
const [isDragTarget, setIsDragTarget] = useState(false);
const [isInputHovered, setIsInputHovered] = useState(false);
const [label, setLabel] = useState(props.variant.label);
const updateMutation = api.promptVariants.update.useMutation();
const [onSaveLabel] = useHandledAsyncCallback(async () => {
if (label && label !== props.variant.label) {
await updateMutation.mutateAsync({
id: props.variant.id,
updates: { label: label },
});
}
}, [updateMutation, props.variant.id, props.variant.label, label]);
const hideMutation = api.promptVariants.hide.useMutation();
const [onHide] = useHandledAsyncCallback(async () => {
await hideMutation.mutateAsync({
id: props.variant.id,
});
await utils.promptVariants.list.invalidate();
}, [hideMutation, props.variant.id]);
const reorderMutation = api.promptVariants.reorder.useMutation();
const [onReorder] = useHandledAsyncCallback(
async (e: DragEvent<HTMLDivElement>) => {
e.preventDefault();
setIsDragTarget(false);
const draggedId = e.dataTransfer.getData("text/plain");
const droppedId = props.variant.id;
if (!draggedId || !droppedId || draggedId === droppedId) return;
await reorderMutation.mutateAsync({
draggedId,
droppedId,
});
await utils.promptVariants.list.invalidate();
},
[reorderMutation, props.variant.id],
);
return (
<HStack
spacing={4}
alignItems="center"
minH={headerMinHeight}
draggable={!isInputHovered}
onDragStart={(e) => {
e.dataTransfer.setData("text/plain", props.variant.id);
e.currentTarget.style.opacity = "0.4";
}}
onDragEnd={(e) => {
e.currentTarget.style.opacity = "1";
}}
onDragOver={(e) => {
e.preventDefault();
setIsDragTarget(true);
}}
onDragLeave={() => {
setIsDragTarget(false);
}}
onDrop={onReorder}
backgroundColor={isDragTarget ? "gray.100" : "transparent"}
>
<Icon
as={RiDraggable}
boxSize={6}
color="gray.400"
_hover={{ color: "gray.800", cursor: "pointer" }}
/>
<AutoResizeTextArea // Changed to Input
size="sm"
value={label}
onChange={(e) => setLabel(e.target.value)}
onBlur={onSaveLabel}
placeholder="Variant Name"
borderWidth={1}
borderColor="transparent"
fontWeight="bold"
fontSize={16}
_hover={{ borderColor: "gray.300" }}
_focus={{ borderColor: "blue.500", outline: "none" }}
flex={1}
px={cellPadding.x}
onMouseEnter={() => setIsInputHovered(true)}
onMouseLeave={() => setIsInputHovered(false)}
/>
<Tooltip label="Hide Variant" hasArrow>
<Button variant="ghost" colorScheme="gray" size="sm" onClick={onHide}>
<Icon as={BsX} boxSize={6} />
</Button>
</Tooltip>
</HStack>
);
}

@@ -5,8 +5,10 @@ import { api } from "~/utils/api";
import chroma from "chroma-js";
import { BsCurrencyDollar } from "react-icons/bs";
import { CostTooltip } from "../tooltip/CostTooltip";
import { useEffect, useState } from "react";
export default function VariantStats(props: { variant: PromptVariant }) {
const [refetchInterval, setRefetchInterval] = useState(0);
const { data } = api.promptVariants.stats.useQuery(
{
variantId: props.variant.id,
@@ -19,10 +21,18 @@ export default function VariantStats(props: { variant: PromptVariant }) {
completionTokens: 0,
scenarioCount: 0,
outputCount: 0,
awaitingRetrievals: false,
},
refetchInterval,
},
);
// Poll every two seconds while we are waiting for LLM retrievals to finish
useEffect(
() => setRefetchInterval(data.awaitingRetrievals ? 2000 : 0),
[data.awaitingRetrievals],
);
const [passColor, neutralColor, failColor] = useToken("colors", [
"green.500",
"gray.500",
@@ -33,21 +43,25 @@ export default function VariantStats(props: { variant: PromptVariant }) {
const showNumFinished = data.scenarioCount > 0 && data.scenarioCount !== data.outputCount;
if (!(data.evalResults.length > 0) && !data.overallCost) return null;
return (
<HStack justifyContent="space-between" alignItems="center" mx="2" fontSize="xs">
<HStack
justifyContent="space-between"
alignItems="center"
mx="2"
fontSize="xs"
py={cellPadding.y}
>
{showNumFinished && (
<Text>
{data.outputCount} / {data.scenarioCount}
</Text>
)}
<HStack px={cellPadding.x} py={cellPadding.y}>
<HStack px={cellPadding.x}>
{data.evalResults.map((result) => {
const passedFrac = result.passCount / (result.passCount + result.failCount);
const passedFrac = result.passCount / result.totalCount;
return (
<HStack key={result.id}>
<Text>{result.evaluation.name}</Text>
<Text>{result.label}</Text>
<Text color={scale(passedFrac).hex()} fontWeight="bold">
{(passedFrac * 100).toFixed(1)}%
</Text>
@@ -55,13 +69,13 @@ export default function VariantStats(props: { variant: PromptVariant }) {
);
})}
</HStack>
{data.overallCost && (
{data.overallCost && !data.awaitingRetrievals && (
<CostTooltip
promptTokens={data.promptTokens}
completionTokens={data.completionTokens}
cost={data.overallCost}
>
<HStack spacing={0} align="center" color="gray.500" my="2">
<HStack spacing={0} align="center" color="gray.500">
<Icon as={BsCurrencyDollar} />
<Text mr={1}>{data.overallCost.toFixed(3)}</Text>
</HStack>
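The two calculations this hunk introduces (polling only while retrievals are pending, and the pass rate as `passCount / totalCount`) can be sketched as plain functions. The function names and the zero-total guard are assumptions added for illustration; the component itself does not include that guard.

```typescript
// Hypothetical helper mirroring the effect in VariantStats: poll every
// `pollMs` while retrievals are pending, otherwise return 0, which the
// component uses to mean "no polling".
function nextRefetchInterval(awaitingRetrievals: boolean, pollMs = 2000): number {
  return awaitingRetrievals ? pollMs : 0;
}

// Pass rate as computed in the updated hunk: passCount / totalCount.
// The zero guard is an addition for illustration, not part of the diff.
function passedFraction(passCount: number, totalCount: number): number {
  return totalCount > 0 ? passCount / totalCount : 0;
}

console.log(nextRefetchInterval(true)); // 2000
console.log((passedFraction(3, 4) * 100).toFixed(1) + "%"); // 75.0%
```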

@@ -1,28 +1,19 @@
import { Button, Grid, GridItem, HStack, Heading, type SystemStyleObject } from "@chakra-ui/react";
import { Grid, GridItem } from "@chakra-ui/react";
import { api } from "~/utils/api";
import NewScenarioButton from "./NewScenarioButton";
import NewVariantButton from "./NewVariantButton";
import ScenarioRow from "./ScenarioRow";
import VariantConfigEditor from "./VariantEditor";
import VariantHeader from "./VariantHeader";
import { cellPadding } from "../constants";
import { BsPencil } from "react-icons/bs";
import VariantEditor from "./VariantEditor";
import VariantHeader from "../VariantHeader/VariantHeader";
import VariantStats from "./VariantStats";
import { useAppStore } from "~/state/store";
const stickyHeaderStyle: SystemStyleObject = {
position: "sticky",
top: "-1px",
backgroundColor: "#fff",
zIndex: 1,
};
import { ScenariosHeader } from "./ScenariosHeader";
import { stickyHeaderStyle } from "./styles";
export default function OutputsTable({ experimentId }: { experimentId: string | undefined }) {
const variants = api.promptVariants.list.useQuery(
{ experimentId: experimentId as string },
{ enabled: !!experimentId },
);
const openDrawer = useAppStore((s) => s.openDrawer);
const scenarios = api.scenarios.list.useQuery(
{ experimentId: experimentId as string },
@@ -49,37 +40,10 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
}}
fontSize="sm"
>
<GridItem
display="flex"
alignItems="flex-end"
rowSpan={headerRows}
px={cellPadding.x}
py={cellPadding.y}
// TODO: This is a hack to get the sticky header to work. It's not ideal because it's not responsive to the height of the header,
// so if the header height changes, this will need to be updated.
sx={{ ...stickyHeaderStyle, top: "-337px" }}
>
<HStack w="100%">
<Heading size="xs" fontWeight="bold" flex={1}>
Scenarios ({scenarios.data.length})
</Heading>
<Button
size="xs"
variant="ghost"
color="gray.500"
aria-label="Edit"
leftIcon={<BsPencil />}
onClick={openDrawer}
>
Edit Vars
</Button>
</HStack>
</GridItem>
<ScenariosHeader headerRows={headerRows} numScenarios={scenarios.data.length} />
{variants.data.map((variant) => (
<GridItem key={variant.uiId} padding={0} sx={stickyHeaderStyle} borderTopWidth={1}>
<VariantHeader variant={variant} />
</GridItem>
<VariantHeader key={variant.uiId} variant={variant} canHide={variants.data.length > 1} />
))}
<GridItem
rowSpan={scenarios.data.length + headerRows}
@@ -94,7 +58,7 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
{variants.data.map((variant) => (
<GridItem key={variant.uiId}>
<VariantConfigEditor variant={variant} />
<VariantEditor variant={variant} />
</GridItem>
))}
{variants.data.map((variant) => (
@@ -103,7 +67,12 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
</GridItem>
))}
{scenarios.data.map((scenario) => (
<ScenarioRow key={scenario.uiId} scenario={scenario} variants={variants.data} />
<ScenarioRow
key={scenario.uiId}
scenario={scenario}
variants={variants.data}
canHide={scenarios.data.length > 1}
/>
))}
<GridItem borderBottomWidth={0} borderRightWidth={0} w="100%" colSpan={allCols} padding={0}>
<NewScenarioButton />

@@ -0,0 +1,8 @@
import { type SystemStyleObject } from "@chakra-ui/react";
export const stickyHeaderStyle: SystemStyleObject = {
position: "sticky",
top: "0",
backgroundColor: "#fff",
zIndex: 1,
};

@@ -1,21 +0,0 @@
import { Flex, Icon, Link, Text } from "@chakra-ui/react";
import { BsExclamationTriangleFill } from "react-icons/bs";
import { env } from "~/env.mjs";
export default function PublicPlaygroundWarning() {
if (!env.NEXT_PUBLIC_IS_PUBLIC_PLAYGROUND) return null;
return (
<Flex bgColor="red.600" color="whiteAlpha.900" p={2} align="center">
<Icon boxSize={4} mr={2} as={BsExclamationTriangleFill} />
<Text>
Warning: this is a public playground. Anyone can see, edit or delete your experiments. For
private use,{" "}
<Link textDecor="underline" href="https://github.com/openpipe/openpipe" target="_blank">
run a local copy
</Link>
.
</Text>
</Flex>
);
}

@@ -0,0 +1,45 @@
import { HStack, VStack } from "@chakra-ui/react";
import React from "react";
import DiffViewer, { DiffMethod } from "react-diff-viewer";
import Prism from "prismjs";
import "prismjs/components/prism-javascript";
import "prismjs/themes/prism.css"; // choose a theme you like
const CompareFunctions = ({
originalFunction,
newFunction = "",
}: {
originalFunction: string;
newFunction?: string;
}) => {
console.log("newFunction", newFunction);
const highlightSyntax = (str: string) => {
let highlighted;
try {
highlighted = Prism.highlight(str, Prism.languages.javascript as Prism.Grammar, "javascript");
} catch (e) {
console.error("Error highlighting:", e);
highlighted = str;
}
return <pre style={{ display: "inline" }} dangerouslySetInnerHTML={{ __html: highlighted }} />;
};
return (
<HStack w="full" spacing={5}>
<VStack w="full" spacing={4} maxH="65vh" fontSize={12} lineHeight={1} overflowY="auto">
<DiffViewer
oldValue={originalFunction}
newValue={newFunction || originalFunction}
splitView={true}
hideLineNumbers={true}
leftTitle="Original"
rightTitle={newFunction ? "Modified" : "Unmodified"}
disableWordDiff={true}
compareMethod={DiffMethod.CHARS}
renderContent={highlightSyntax}
/>
</VStack>
</HStack>
);
};
export default CompareFunctions;
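In `highlightSyntax`, the fallback path passes the raw string to `dangerouslySetInnerHTML`. One possible hardening, sketched here under the assumption that escaping the fallback is wanted, is to HTML-escape the text when highlighting fails. Both helper names are hypothetical, and the highlighter is passed in as a parameter so the example runs without Prism.

```typescript
// Hypothetical helper: escapes text before it is used as raw HTML.
function escapeHtml(text: string): string {
  return text
    .replace(/&/g, "&amp;")
    .replace(/</g, "&lt;")
    .replace(/>/g, "&gt;")
    .replace(/"/g, "&quot;")
    .replace(/'/g, "&#39;");
}

// Sketch of the try/catch in highlightSyntax, with the highlighter injected.
// On failure it falls back to escaped text instead of the raw string.
function highlightOrEscape(source: string, highlight: (s: string) => string): string {
  try {
    return highlight(source);
  } catch {
    return escapeHtml(source);
  }
}

const failing = (): string => {
  throw new Error("no grammar loaded");
};
console.log(highlightOrEscape("a < b && c", failing)); // a &lt; b &amp;&amp; c
```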

@@ -0,0 +1,103 @@
import {
Button,
Modal,
ModalBody,
ModalCloseButton,
ModalContent,
ModalFooter,
ModalHeader,
ModalOverlay,
VStack,
Text,
Spinner,
HStack,
} from "@chakra-ui/react";
import { api } from "~/utils/api";
import { useHandledAsyncCallback } from "~/utils/hooks";
import { type PromptVariant } from "@prisma/client";
import { useState } from "react";
import AutoResizeTextArea from "../AutoResizeTextArea";
import CompareFunctions from "./CompareFunctions";
export const RefinePromptModal = ({
variant,
onClose,
}: {
variant: PromptVariant;
onClose: () => void;
}) => {
const utils = api.useContext();
const { mutateAsync: getRefinedPromptMutateAsync, data: refinedPromptFn } =
api.promptVariants.getRefinedPromptFn.useMutation();
const [instructions, setInstructions] = useState<string>("");
const [getRefinedPromptFn, refiningInProgress] = useHandledAsyncCallback(async () => {
if (!variant.experimentId) return;
await getRefinedPromptMutateAsync({
id: variant.id,
instructions,
});
}, [getRefinedPromptMutateAsync, onClose, variant, instructions]);
const replaceVariantMutation = api.promptVariants.replaceVariant.useMutation();
const [replaceVariant, replacementInProgress] = useHandledAsyncCallback(async () => {
if (!variant.experimentId || !refinedPromptFn) return;
await replaceVariantMutation.mutateAsync({
id: variant.id,
constructFn: refinedPromptFn,
});
await utils.promptVariants.list.invalidate();
onClose();
}, [replaceVariantMutation, variant, onClose, refinedPromptFn]);
return (
<Modal isOpen onClose={onClose} size={{ base: "xl", sm: "2xl", md: "7xl" }}>
<ModalOverlay />
<ModalContent w={1200}>
<ModalHeader>Refine Your Prompt</ModalHeader>
<ModalCloseButton />
<ModalBody maxW="unset">
<VStack spacing={8}>
<HStack w="full">
<AutoResizeTextArea
value={instructions}
onChange={(e) => setInstructions(e.target.value)}
onKeyDown={(e) => {
if (e.key === "Enter" && !e.metaKey && !e.ctrlKey && !e.shiftKey) {
e.preventDefault();
e.currentTarget.blur();
getRefinedPromptFn();
}
}}
placeholder="Use chain of thought"
/>
<Button onClick={getRefinedPromptFn}>
{refiningInProgress ? <Spinner boxSize={4} /> : <Text>Submit</Text>}
</Button>
</HStack>
<CompareFunctions
originalFunction={variant.constructFn}
newFunction={refinedPromptFn}
/>
</VStack>
</ModalBody>
<ModalFooter>
<HStack spacing={4}>
<Button onClick={onClose}>Cancel</Button>
<Button
colorScheme="blue"
onClick={replaceVariant}
minW={24}
disabled={!refinedPromptFn}
>
{replacementInProgress ? <Spinner boxSize={4} /> : <Text>Accept</Text>}
</Button>
</HStack>
</ModalFooter>
</ModalContent>
</Modal>
);
};

@@ -0,0 +1,89 @@
import {
VStack,
Text,
HStack,
type StackProps,
GridItem,
SimpleGrid,
Link,
} from "@chakra-ui/react";
import { modelStats } from "~/server/modelStats";
import { type SupportedModel } from "~/server/types";
export const ModelStatsCard = ({ label, model }: { label: string; model: SupportedModel }) => {
const stats = modelStats[model];
return (
<VStack w="full" align="start">
<Text fontWeight="bold" fontSize="sm" textTransform="uppercase">
{label}
</Text>
<VStack w="full" spacing={6} bgColor="gray.100" p={4} borderRadius={4}>
<HStack w="full" align="flex-start">
<Text flex={1} fontSize="lg">
<Text as="span" color="gray.600">
{stats.provider} /{" "}
</Text>
<Text as="span" fontWeight="bold" color="gray.900">
{model}
</Text>
</Text>
<Link
href={stats.learnMoreUrl}
isExternal
color="blue.500"
fontWeight="bold"
fontSize="sm"
ml={2}
>
Learn More
</Link>
</HStack>
<SimpleGrid
w="full"
justifyContent="space-between"
alignItems="flex-start"
fontSize="sm"
columns={{ base: 2, md: 4 }}
>
<SelectedModelLabeledInfo label="Context" info={stats.contextLength} />
<SelectedModelLabeledInfo
label="Input"
info={
<Text>
${(stats.promptTokenPrice * 1000).toFixed(3)}
<Text color="gray.500"> / 1K tokens</Text>
</Text>
}
/>
<SelectedModelLabeledInfo
label="Output"
info={
<Text>
{/* Output price; assumes modelStats exposes completionTokenPrice (not shown in this diff) */}
${(stats.completionTokenPrice * 1000).toFixed(3)}
<Text color="gray.500"> / 1K tokens</Text>
</Text>
}
/>
<SelectedModelLabeledInfo label="Speed" info={<Text>{stats.speed}</Text>} />
</SimpleGrid>
</VStack>
</VStack>
);
};
const SelectedModelLabeledInfo = ({
label,
info,
...props
}: {
label: string;
info: string | number | React.ReactElement;
} & StackProps) => (
<GridItem>
<VStack alignItems="flex-start" {...props}>
<Text fontWeight="bold">{label}</Text>
<Text>{info}</Text>
</VStack>
</GridItem>
);
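The price cells in `ModelStatsCard` convert a per-token price into a per-1K-token figure with three decimals. A minimal sketch of that conversion follows; the helper name `formatPricePer1K` and the sample prices are illustrative, not values from `modelStats`.

```typescript
// Hypothetical helper mirroring `(price * 1000).toFixed(3)` from the card.
function formatPricePer1K(pricePerToken: number): string {
  return `$${(pricePerToken * 1000).toFixed(3)}`;
}

// Sample per-token prices, chosen only for illustration.
console.log(formatPricePer1K(0.000002)); // $0.002
console.log(formatPricePer1K(0.00003)); // $0.030
```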

@@ -0,0 +1,77 @@
import {
Button,
Modal,
ModalBody,
ModalCloseButton,
ModalContent,
ModalFooter,
ModalHeader,
ModalOverlay,
VStack,
Text,
Spinner,
} from "@chakra-ui/react";
import { useState } from "react";
import { type SupportedModel } from "~/server/types";
import { ModelStatsCard } from "./ModelStatsCard";
import { SelectModelSearch } from "./SelectModelSearch";
import { api } from "~/utils/api";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
export const SelectModelModal = ({
originalModel,
variantId,
onClose,
}: {
originalModel: SupportedModel;
variantId: string;
onClose: () => void;
}) => {
const [selectedModel, setSelectedModel] = useState<SupportedModel>(originalModel);
const utils = api.useContext();
const experiment = useExperiment();
const createMutation = api.promptVariants.create.useMutation();
const [createNewVariant, creationInProgress] = useHandledAsyncCallback(async () => {
if (!experiment?.data?.id) return;
await createMutation.mutateAsync({
experimentId: experiment?.data?.id,
variantId,
newModel: selectedModel,
});
await utils.promptVariants.list.invalidate();
onClose();
}, [createMutation, experiment?.data?.id, variantId, selectedModel, onClose]);
return (
<Modal isOpen onClose={onClose} size={{ base: "xl", sm: "2xl", md: "3xl" }}>
<ModalOverlay />
<ModalContent w={1200}>
<ModalHeader>Select a New Model</ModalHeader>
<ModalCloseButton />
<ModalBody maxW="unset">
<VStack spacing={8}>
<ModelStatsCard label="Original Model" model={originalModel} />
{originalModel !== selectedModel && (
<ModelStatsCard label="New Model" model={selectedModel} />
)}
<SelectModelSearch selectedModel={selectedModel} setSelectedModel={setSelectedModel} />
</VStack>
</ModalBody>
<ModalFooter>
<Button
colorScheme="blue"
onClick={createNewVariant}
minW={24}
disabled={originalModel === selectedModel}
>
{creationInProgress ? <Spinner boxSize={4} /> : <Text>Continue</Text>}
</Button>
</ModalFooter>
</ModalContent>
</Modal>
);
};

@@ -0,0 +1,47 @@
import { VStack, Text } from "@chakra-ui/react";
import { type LegacyRef, useCallback } from "react";
import Select, { type SingleValue } from "react-select";
import { type SupportedModel } from "~/server/types";
import { useElementDimensions } from "~/utils/hooks";
const modelOptions: { value: SupportedModel; label: string }[] = [
{ value: "gpt-3.5-turbo", label: "gpt-3.5-turbo" },
{ value: "gpt-3.5-turbo-0613", label: "gpt-3.5-turbo-0613" },
{ value: "gpt-3.5-turbo-16k", label: "gpt-3.5-turbo-16k" },
{ value: "gpt-3.5-turbo-16k-0613", label: "gpt-3.5-turbo-16k-0613" },
{ value: "gpt-4", label: "gpt-4" },
{ value: "gpt-4-0613", label: "gpt-4-0613" },
{ value: "gpt-4-32k", label: "gpt-4-32k" },
{ value: "gpt-4-32k-0613", label: "gpt-4-32k-0613" },
];
export const SelectModelSearch = ({
selectedModel,
setSelectedModel,
}: {
selectedModel: SupportedModel;
setSelectedModel: (model: SupportedModel) => void;
}) => {
const handleSelection = useCallback(
(option: SingleValue<{ value: SupportedModel; label: string }>) => {
if (!option) return;
setSelectedModel(option.value);
},
[setSelectedModel],
);
const selectedOption = modelOptions.find((option) => option.value === selectedModel);
const [containerRef, containerDimensions] = useElementDimensions();
return (
<VStack ref={containerRef as LegacyRef<HTMLDivElement>} w="full">
<Text>Browse Models</Text>
<Select
styles={{ control: (provided) => ({ ...provided, width: containerDimensions?.width }) }}
value={selectedOption}
options={modelOptions}
onChange={handleSelection}
/>
</VStack>
);
};
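Because every entry in `modelOptions` uses the model name as both `value` and `label`, the list could be derived from a single array. This is a sketch of that alternative, not what the commit does. The local `SupportedModel` type is an assumption standing in for the one in `~/server/types`, and `findOption` is a hypothetical helper.

```typescript
// Assumption: SupportedModel is a union of these literals; the real type
// lives in ~/server/types and is not shown in this diff.
const supportedModels = [
  "gpt-3.5-turbo",
  "gpt-3.5-turbo-0613",
  "gpt-3.5-turbo-16k",
  "gpt-3.5-turbo-16k-0613",
  "gpt-4",
  "gpt-4-0613",
  "gpt-4-32k",
  "gpt-4-32k-0613",
] as const;
type SupportedModel = (typeof supportedModels)[number];

// Each option reuses the model name as its label, as in the diff.
const modelOptions: { value: SupportedModel; label: string }[] = supportedModels.map(
  (model) => ({ value: model, label: model }),
);

// Same lookup the component performs to find the selected option.
function findOption(model: SupportedModel) {
  return modelOptions.find((option) => option.value === model);
}

console.log(findOption("gpt-4")?.label); // gpt-4
```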

@@ -0,0 +1,123 @@
import { useState, type DragEvent } from "react";
import { type PromptVariant } from "../OutputsTable/types";
import { api } from "~/utils/api";
import { RiDraggable } from "react-icons/ri";
import { useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
import { HStack, Icon, Text, GridItem } from "@chakra-ui/react"; // Changed here
import { cellPadding, headerMinHeight } from "../constants";
import AutoResizeTextArea from "../AutoResizeTextArea";
import { stickyHeaderStyle } from "../OutputsTable/styles";
import VariantHeaderMenuButton from "./VariantHeaderMenuButton";
export default function VariantHeader(props: { variant: PromptVariant; canHide: boolean }) {
const { canModify } = useExperimentAccess();
const utils = api.useContext();
const [isDragTarget, setIsDragTarget] = useState(false);
const [isInputHovered, setIsInputHovered] = useState(false);
const [label, setLabel] = useState(props.variant.label);
const updateMutation = api.promptVariants.update.useMutation();
const [onSaveLabel] = useHandledAsyncCallback(async () => {
if (label && label !== props.variant.label) {
await updateMutation.mutateAsync({
id: props.variant.id,
updates: { label: label },
});
}
}, [updateMutation, props.variant.id, props.variant.label, label]);
const reorderMutation = api.promptVariants.reorder.useMutation();
const [onReorder] = useHandledAsyncCallback(
async (e: DragEvent<HTMLDivElement>) => {
e.preventDefault();
setIsDragTarget(false);
const draggedId = e.dataTransfer.getData("text/plain");
const droppedId = props.variant.id;
if (!draggedId || !droppedId || draggedId === droppedId) return;
await reorderMutation.mutateAsync({
draggedId,
droppedId,
});
await utils.promptVariants.list.invalidate();
},
[reorderMutation, props.variant.id],
);
const [menuOpen, setMenuOpen] = useState(false);
if (!canModify) {
return (
<GridItem padding={0} sx={stickyHeaderStyle} borderTopWidth={1}>
<Text fontSize={16} fontWeight="bold" px={cellPadding.x} py={cellPadding.y}>
{props.variant.label}
</Text>
</GridItem>
);
}
return (
<GridItem
padding={0}
sx={{
...stickyHeaderStyle,
// Ensure that the menu always appears above the sticky header of other variants
zIndex: menuOpen ? "dropdown" : stickyHeaderStyle.zIndex,
}}
borderTopWidth={1}
>
<HStack
spacing={4}
alignItems="flex-start"
minH={headerMinHeight}
draggable={!isInputHovered}
onDragStart={(e) => {
e.dataTransfer.setData("text/plain", props.variant.id);
e.currentTarget.style.opacity = "0.4";
}}
onDragEnd={(e) => {
e.currentTarget.style.opacity = "1";
}}
onDragOver={(e) => {
e.preventDefault();
setIsDragTarget(true);
}}
onDragLeave={() => {
setIsDragTarget(false);
}}
onDrop={onReorder}
backgroundColor={isDragTarget ? "gray.100" : "transparent"}
>
<Icon
as={RiDraggable}
boxSize={6}
mt={2}
color="gray.400"
_hover={{ color: "gray.800", cursor: "pointer" }}
/>
<AutoResizeTextArea
size="sm"
value={label}
onChange={(e) => setLabel(e.target.value)}
onBlur={onSaveLabel}
placeholder="Variant Name"
borderWidth={1}
borderColor="transparent"
fontWeight="bold"
fontSize={16}
_hover={{ borderColor: "gray.300" }}
_focus={{ borderColor: "blue.500", outline: "none" }}
flex={1}
px={cellPadding.x}
onMouseEnter={() => setIsInputHovered(true)}
onMouseLeave={() => setIsInputHovered(false)}
/>
<VariantHeaderMenuButton
variant={props.variant}
canHide={props.canHide}
menuOpen={menuOpen}
setMenuOpen={setMenuOpen}
/>
</HStack>
</GridItem>
);
}

@@ -0,0 +1,114 @@
import { type PromptVariant } from "../OutputsTable/types";
import { api } from "~/utils/api";
import { useHandledAsyncCallback } from "~/utils/hooks";
import {
Button,
Icon,
Menu,
MenuButton,
MenuItem,
MenuList,
MenuDivider,
Text,
Spinner,
} from "@chakra-ui/react";
import { BsFillTrashFill, BsGear } from "react-icons/bs";
import { FaRegClone } from "react-icons/fa";
import { AiOutlineDiff } from "react-icons/ai";
import { useState } from "react";
import { RefinePromptModal } from "../RefinePromptModal/RefinePromptModal";
import { RiExchangeFundsFill } from "react-icons/ri";
import { SelectModelModal } from "../SelectModelModal/SelectModelModal";
import { type SupportedModel } from "~/server/types";
export default function VariantHeaderMenuButton({
variant,
canHide,
menuOpen,
setMenuOpen,
}: {
variant: PromptVariant;
canHide: boolean;
menuOpen: boolean;
setMenuOpen: (open: boolean) => void;
}) {
const utils = api.useContext();
const duplicateMutation = api.promptVariants.create.useMutation();
const [duplicateVariant, duplicationInProgress] = useHandledAsyncCallback(async () => {
await duplicateMutation.mutateAsync({
experimentId: variant.experimentId,
variantId: variant.id,
});
await utils.promptVariants.list.invalidate();
}, [duplicateMutation, variant.experimentId, variant.id]);
const hideMutation = api.promptVariants.hide.useMutation();
const [onHide] = useHandledAsyncCallback(async () => {
await hideMutation.mutateAsync({
id: variant.id,
});
await utils.promptVariants.list.invalidate();
}, [hideMutation, variant.id]);
const [selectModelModalOpen, setSelectModelModalOpen] = useState(false);
const [refinePromptModalOpen, setRefinePromptModalOpen] = useState(false);
return (
<>
<Menu isOpen={menuOpen} onOpen={() => setMenuOpen(true)} onClose={() => setMenuOpen(false)}>
{duplicationInProgress ? (
<Spinner boxSize={4} mx={3} my={3} />
) : (
<MenuButton>
<Button variant="ghost">
<Icon as={BsGear} />
</Button>
</MenuButton>
)}
<MenuList mt={-3} fontSize="md">
<MenuItem icon={<Icon as={FaRegClone} boxSize={4} w={5} />} onClick={duplicateVariant}>
Duplicate
</MenuItem>
<MenuItem
icon={<Icon as={RiExchangeFundsFill} boxSize={5} />}
onClick={() => setSelectModelModalOpen(true)}
>
Change Model
</MenuItem>
<MenuItem
icon={<Icon as={AiOutlineDiff} boxSize={5} />}
onClick={() => setRefinePromptModalOpen(true)}
>
Refine
</MenuItem>
{canHide && (
<>
<MenuDivider />
<MenuItem
onClick={onHide}
icon={<Icon as={BsFillTrashFill} boxSize={5} />}
color="red.600"
_hover={{ backgroundColor: "red.50" }}
>
<Text>Hide</Text>
</MenuItem>
</>
)}
</MenuList>
</Menu>
{selectModelModalOpen && (
<SelectModelModal
originalModel={variant.model as SupportedModel}
variantId={variant.id}
onClose={() => setSelectModelModalOpen(false)}
/>
)}
{refinePromptModalOpen && (
<RefinePromptModal variant={variant} onClose={() => setRefinePromptModalOpen(false)} />
)}
</>
);
}

@@ -1,17 +1,11 @@
import {
Card,
CardBody,
HStack,
Icon,
VStack,
Text,
CardHeader,
Divider,
Box,
} from "@chakra-ui/react";
import { HStack, Icon, VStack, Text, Divider, Spinner, AspectRatio } from "@chakra-ui/react";
import { RiFlaskLine } from "react-icons/ri";
import { formatTimePast } from "~/utils/dayjs";
import Link from "next/link";
import { useRouter } from "next/router";
import { BsPlusSquare } from "react-icons/bs";
import { api } from "~/utils/api";
import { useHandledAsyncCallback } from "~/utils/hooks";
type ExperimentData = {
testScenarioCount: number;
@@ -24,47 +18,42 @@ type ExperimentData = {
};
export const ExperimentCard = ({ exp }: { exp: ExperimentData }) => {
const router = useRouter();
return (
<Box
as={Card}
variant="elevated"
bg="gray.50"
_hover={{ bg: "gray.100" }}
transition="background 0.2s"
cursor="pointer"
onClick={(e) => {
e.preventDefault();
void router.push({ pathname: "/experiments/[id]", query: { id: exp.id } }, undefined, {
shallow: true,
});
}}
>
<CardHeader>
<HStack w="full" color="gray.700">
<AspectRatio ratio={1.2} w="full">
<VStack
as={Link}
href={{ pathname: "/experiments/[id]", query: { id: exp.id } }}
bg="gray.50"
_hover={{ bg: "gray.100" }}
transition="background 0.2s"
cursor="pointer"
borderColor="gray.200"
borderWidth={1}
p={4}
justify="space-between"
>
<HStack w="full" color="gray.700" justify="center">
<Icon as={RiFlaskLine} boxSize={4} />
<Text fontWeight="bold">{exp.label}</Text>
</HStack>
</CardHeader>
<CardBody>
<HStack w="full" mb={8} spacing={4}>
<HStack h="full" spacing={4} flex={1} align="center">
<CountLabel label="Variants" count={exp.promptVariantCount} />
<Divider h={12} orientation="vertical" />
<CountLabel label="Scenarios" count={exp.testScenarioCount} />
</HStack>
<HStack w="full" color="gray.500" fontSize="xs">
<Text>Created {formatTimePast(exp.createdAt)}</Text>
<HStack w="full" color="gray.500" fontSize="xs" textAlign="center">
<Text flex={1}>Created {formatTimePast(exp.createdAt)}</Text>
<Divider h={4} orientation="vertical" />
<Text>Updated {formatTimePast(exp.updatedAt)}</Text>
<Text flex={1}>Updated {formatTimePast(exp.updatedAt)}</Text>
</HStack>
</CardBody>
</Box>
</VStack>
</AspectRatio>
);
};
const CountLabel = ({ label, count }: { label: string; count: number }) => {
return (
<VStack alignItems="flex-start">
<VStack alignItems="center" flex={1}>
<Text color="gray.500" fontWeight="bold">
{label}
</Text>
@@ -74,3 +63,33 @@ const CountLabel = ({ label, count }: { label: string; count: number }) => {
</VStack>
);
};
export const NewExperimentCard = () => {
const router = useRouter();
const createMutation = api.experiments.create.useMutation();
const [createExperiment, isLoading] = useHandledAsyncCallback(async () => {
const newExperiment = await createMutation.mutateAsync({ label: "New Experiment" });
await router.push({ pathname: "/experiments/[id]", query: { id: newExperiment.id } });
}, [createMutation, router]);
return (
<AspectRatio ratio={1.2} w="full">
<VStack
align="center"
justify="center"
_hover={{ cursor: "pointer", bg: "gray.50" }}
transition="background 0.2s"
cursor="pointer"
borderColor="gray.200"
borderWidth={1}
p={4}
onClick={createExperiment}
>
<Icon as={isLoading ? Spinner : BsPlusSquare} boxSize={8} />
<Text display={{ base: "none", md: "block" }} ml={2}>
New Experiment
</Text>
</VStack>
</AspectRatio>
);
};

@@ -1,31 +0,0 @@
import { Icon, Button, Spinner, Text, type ButtonProps } from "@chakra-ui/react";
import { api } from "~/utils/api";
import { useRouter } from "next/router";
import { BsPlusSquare } from "react-icons/bs";
import { useHandledAsyncCallback } from "~/utils/hooks";
export const NewExperimentButton = (props: ButtonProps) => {
const router = useRouter();
const utils = api.useContext();
const createMutation = api.experiments.create.useMutation();
const [createExperiment, isLoading] = useHandledAsyncCallback(async () => {
const newExperiment = await createMutation.mutateAsync({ label: "New Experiment" });
await utils.experiments.list.invalidate();
await router.push({ pathname: "/experiments/[id]", query: { id: newExperiment.id } });
}, [createMutation, router]);
return (
<Button
onClick={createExperiment}
display="flex"
alignItems="center"
variant={{ base: "solid", md: "ghost" }}
{...props}
>
<Icon as={isLoading ? Spinner : BsPlusSquare} boxSize={4} />
<Text display={{ base: "none", md: "block" }} ml={2}>
New Experiment
</Text>
</Button>
);
};

@@ -1,84 +1,100 @@
import { useState, useEffect } from "react";
import {
Heading,
VStack,
Icon,
HStack,
Image,
Grid,
GridItem,
Divider,
Text,
Box,
type BoxProps,
type LinkProps,
Link,
Flex,
} from "@chakra-ui/react";
import Head from "next/head";
import { BsGithub, BsTwitter } from "react-icons/bs";
import { BsGithub, BsPersonCircle } from "react-icons/bs";
import { useRouter } from "next/router";
import PublicPlaygroundWarning from "../PublicPlaygroundWarning";
import { type IconType } from "react-icons";
import { RiFlaskLine } from "react-icons/ri";
import { useState, useEffect } from "react";
import { signIn, useSession } from "next-auth/react";
import UserMenu from "./UserMenu";
type IconLinkProps = BoxProps & LinkProps & { label: string; icon: IconType; href: string };
type IconLinkProps = BoxProps & LinkProps & { label?: string; icon: IconType };
const IconLink = ({ icon, label, href, target, color, ...props }: IconLinkProps) => {
const isActive = useRouter().pathname.startsWith(href);
const router = useRouter();
const isActive = href && router.pathname.startsWith(href);
return (
<Box
<HStack
w="full"
p={4}
color={color}
as={Link}
href={href}
target={target}
w="full"
bgColor={isActive ? "gray.300" : "transparent"}
_hover={{ bgColor: "gray.300" }}
py={4}
bgColor={isActive ? "gray.200" : "transparent"}
_hover={{ bgColor: "gray.200", textDecoration: "none" }}
justifyContent="start"
cursor="pointer"
{...props}
>
<HStack w="full" px={4} color={color}>
<Icon as={icon} boxSize={6} mr={2} />
<Text fontWeight="bold">{label}</Text>
</HStack>
</Box>
<Icon as={icon} boxSize={6} mr={2} />
<Text fontWeight="bold" fontSize="sm">
{label}
</Text>
</HStack>
);
};
const Divider = () => <Box h="1px" bgColor="gray.200" />;
const NavSidebar = () => {
const user = useSession().data;
return (
<VStack align="stretch" bgColor="gray.100" py={2} pb={0} height="100%">
<Link href="/" w="full" _hover={{ textDecoration: "none" }}>
<HStack spacing={0} pl="3">
<Image src="/logo.svg" alt="" w={8} h={8} />
<Heading size="md" p={2} pl={{ base: 16, md: 2 }}>
OpenPipe
</Heading>
</HStack>
</Link>
<Divider />
<VStack
align="stretch"
bgColor="gray.100"
py={2}
pb={0}
height="100%"
w={{ base: "56px", md: "200px" }}
overflow="hidden"
>
<HStack as={Link} href="/" _hover={{ textDecoration: "none" }} spacing={0} px={4} py={2}>
<Image src="/logo.svg" alt="" boxSize={6} mr={4} />
<Heading size="md" fontFamily="inconsolata, monospace">
OpenPipe
</Heading>
</HStack>
<VStack spacing={0} align="flex-start" overflowY="auto" overflowX="hidden" flex={1}>
<IconLink icon={RiFlaskLine} label="Experiments" href="/experiments" />
{user != null && (
<>
<IconLink icon={RiFlaskLine} label="Experiments" href="/experiments" />
</>
)}
{user === null && (
<IconLink
icon={BsPersonCircle}
label="Sign In"
onClick={() => {
signIn("github").catch(console.error);
}}
/>
)}
</VStack>
<Divider />
<VStack w="full" spacing={0} pb={2}>
<IconLink
icon={BsGithub}
label="GitHub"
{user ? <UserMenu user={user} /> : <Divider />}
<VStack spacing={0} align="center">
<Link
href="https://github.com/openpipe/openpipe"
target="_blank"
color="gray.500"
_hover={{ color: "gray.800" }}
/>
<IconLink
icon={BsTwitter}
label="Twitter"
href="https://twitter.com/corbtt"
target="_blank"
color="gray.500"
_hover={{ color: "gray.800" }}
/>
p={2}
>
<Icon as={BsGithub} boxSize={6} />
</Link>
</VStack>
</VStack>
);
@@ -105,25 +121,14 @@ export default function AppShell(props: { children: React.ReactNode; title?: str
}, []);
return (
<Grid
h={vh}
w="100vw"
templateColumns={{ base: "56px minmax(0, 1fr)", md: "200px minmax(0, 1fr)" }}
templateRows="max-content 1fr"
templateAreas={'"warning warning"\n"sidebar main"'}
>
<Flex h={vh} w="100vw">
<Head>
<title>{props.title ? `${props.title} | OpenPipe` : "OpenPipe"}</title>
</Head>
<GridItem area="warning">
<PublicPlaygroundWarning />
</GridItem>
<GridItem area="sidebar" overflow="hidden">
<NavSidebar />
</GridItem>
<GridItem area="main" overflowY="auto">
<NavSidebar />
<Box h="100%" flex={1} overflowY="auto">
{props.children}
</GridItem>
</Grid>
</Box>
</Flex>
);
}

@@ -0,0 +1,74 @@
import {
HStack,
Icon,
Image,
VStack,
Text,
Popover,
PopoverTrigger,
PopoverContent,
Link,
} from "@chakra-ui/react";
import { type Session } from "next-auth";
import { signOut } from "next-auth/react";
import { BsBoxArrowRight, BsChevronRight, BsPersonCircle } from "react-icons/bs";
export default function UserMenu({ user }: { user: Session }) {
const profileImage = user.user.image ? (
<Image src={user.user.image} alt="profile picture" boxSize={8} borderRadius="50%" />
) : (
<Icon as={BsPersonCircle} boxSize={6} />
);
return (
<>
<Popover placement="right">
<PopoverTrigger>
<HStack
// Weird values to make mobile look right; can clean up when we make the sidebar disappear on mobile
px={3}
spacing={3}
py={2}
borderColor={"gray.200"}
borderTopWidth={1}
borderBottomWidth={1}
cursor="pointer"
_hover={{
bgColor: "gray.200",
}}
>
{profileImage}
<VStack spacing={0} align="start" flex={1} flexShrink={1}>
<Text fontWeight="bold" fontSize="sm">
{user.user.name}
</Text>
<Text color="gray.500" fontSize="xs">
{user.user.email}
</Text>
</VStack>
<Icon as={BsChevronRight} boxSize={4} color="gray.500" />
</HStack>
</PopoverTrigger>
<PopoverContent _focusVisible={{ boxShadow: "unset", outline: "unset" }} maxW="200px">
<VStack align="stretch" spacing={0}>
{/* sign out */}
<HStack
as={Link}
onClick={() => {
signOut().catch(console.error);
}}
px={4}
py={2}
spacing={4}
color="gray.500"
fontSize="sm"
>
<Icon as={BsBoxArrowRight} boxSize={6} />
<Text>Sign out</Text>
</HStack>
</VStack>
</PopoverContent>
</Popover>
</>
);
}

@@ -20,7 +20,6 @@ export const CostTooltip = ({
color="gray.800"
bgColor="gray.50"
borderWidth={1}
py={2}
hasArrow
shouldWrapChildren
label={

@@ -10,6 +10,13 @@ export const env = createEnv({
DATABASE_URL: z.string().url(),
NODE_ENV: z.enum(["development", "test", "production"]).default("development"),
OPENAI_API_KEY: z.string().min(1),
RESTRICT_PRISMA_LOGS: z
.string()
.optional()
.default("false")
.transform((val) => val.toLowerCase() === "true"),
GITHUB_CLIENT_ID: z.string().min(1),
GITHUB_CLIENT_SECRET: z.string().min(1),
},
/**
@@ -19,11 +26,6 @@ export const env = createEnv({
*/
client: {
NEXT_PUBLIC_POSTHOG_KEY: z.string().optional(),
NEXT_PUBLIC_IS_PUBLIC_PLAYGROUND: z
.string()
.optional()
.default("false")
.transform((val) => val.toLowerCase() === "true"),
NEXT_PUBLIC_SOCKET_URL: z.string().url().default("http://localhost:3318"),
},
@@ -35,9 +37,11 @@ export const env = createEnv({
DATABASE_URL: process.env.DATABASE_URL,
NODE_ENV: process.env.NODE_ENV,
OPENAI_API_KEY: process.env.OPENAI_API_KEY,
RESTRICT_PRISMA_LOGS: process.env.RESTRICT_PRISMA_LOGS,
NEXT_PUBLIC_POSTHOG_KEY: process.env.NEXT_PUBLIC_POSTHOG_KEY,
NEXT_PUBLIC_IS_PUBLIC_PLAYGROUND: process.env.NEXT_PUBLIC_IS_PUBLIC_PLAYGROUND,
NEXT_PUBLIC_SOCKET_URL: process.env.NEXT_PUBLIC_SOCKET_URL,
GITHUB_CLIENT_ID: process.env.GITHUB_CLIENT_ID,
GITHUB_CLIENT_SECRET: process.env.GITHUB_CLIENT_SECRET,
},
/**
* Run `build` or `dev` with `SKIP_ENV_VALIDATION` to skip env validation.
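Review note: the `RESTRICT_PRISMA_LOGS` entry above turns a string environment variable into a boolean. The sketch below restates that transform without zod, to make the edge cases visible. The function name `parseBooleanFlag` is an assumption made for illustration and does not appear in the repository.

```typescript
// Hypothetical helper mirroring the schema chain
// .optional().default("false").transform((val) => val.toLowerCase() === "true").
function parseBooleanFlag(raw: string | undefined): boolean {
  // A missing variable falls back to the string "false".
  const value = raw ?? "false";
  // Only the word "true" (in any letter case) enables the flag.
  return value.toLowerCase() === "true";
}
```

Under this reading, values such as `"1"` or `"yes"` leave the flag off, and the explicit comparison avoids the pitfall that any non-empty string, including `"false"`, is truthy in JavaScript.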

@@ -6,18 +6,27 @@ import { ChakraProvider } from "@chakra-ui/react";
import theme from "~/utils/theme";
import Favicon from "~/components/Favicon";
import "~/utils/analytics";
import Head from "next/head";
const MyApp: AppType<{ session: Session | null }> = ({
Component,
pageProps: { session, ...pageProps },
}) => {
return (
<SessionProvider session={session}>
<Favicon />
<ChakraProvider theme={theme}>
<Component {...pageProps} />
</ChakraProvider>
</SessionProvider>
<>
<Head>
<meta
name="viewport"
content="width=device-width, initial-scale=1, maximum-scale=1, user-scalable=0"
/>
</Head>
<SessionProvider session={session}>
<Favicon />
<ChakraProvider theme={theme}>
<Component {...pageProps} />
</ChakraProvider>
</SessionProvider>
</>
);
};

@@ -0,0 +1,23 @@
import { signIn, useSession } from "next-auth/react";
import { useRouter } from "next/router";
import { useEffect } from "react";
import AppShell from "~/components/nav/AppShell";
export default function SignIn() {
const session = useSession().data;
const router = useRouter();
useEffect(() => {
if (session) {
router.push("/experiments").catch(console.error);
} else if (session === null) {
signIn("github").catch(console.error);
}
}, [session, router]);
return (
<AppShell>
<div />
</AppShell>
);
}
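Review note: the effect in `SignIn` above relies on the session value having three states. A minimal sketch of that decision as a pure function follows, assuming (as in NextAuth's `useSession`) that `undefined` means the session is still loading and `null` means no session was found. `SessionLike`, `SignInAction`, and `decideSignInAction` are hypothetical names used only for illustration.

```typescript
// Hypothetical stand-in for the session object; only its presence matters here.
type SessionLike = { user: { id: string } };

type SignInAction = "redirect-to-experiments" | "start-github-sign-in" | "wait";

// undefined: still loading; null: signed out; object: signed in.
function decideSignInAction(session: SessionLike | null | undefined): SignInAction {
  if (session) return "redirect-to-experiments";
  if (session === null) return "start-github-sign-in";
  // Still loading: do nothing until the session resolves.
  return "wait";
}
```

The strict `=== null` check is what keeps the page from starting a GitHub sign-in while the session is still being fetched.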

@@ -124,6 +124,8 @@ export default function Experiment() {
);
}
const canModify = experiment.data?.access.canModify ?? false;
return (
<AppShell title={experiment.data?.label}>
<VStack h="full">
@@ -143,37 +145,45 @@ export default function Experiment() {
</Link>
</BreadcrumbItem>
<BreadcrumbItem isCurrentPage>
<Input
size="sm"
value={label}
onChange={(e) => setLabel(e.target.value)}
onBlur={onSaveLabel}
borderWidth={1}
borderColor="transparent"
fontSize={16}
px={0}
minW={{ base: 100, lg: 300 }}
flex={1}
_hover={{ borderColor: "gray.300" }}
_focus={{ borderColor: "blue.500", outline: "none" }}
/>
{canModify ? (
<Input
size="sm"
value={label}
onChange={(e) => setLabel(e.target.value)}
onBlur={onSaveLabel}
borderWidth={1}
borderColor="transparent"
fontSize={16}
px={0}
minW={{ base: 100, lg: 300 }}
flex={1}
_hover={{ borderColor: "gray.300" }}
_focus={{ borderColor: "blue.500", outline: "none" }}
/>
) : (
<Text fontSize={16} px={0} minW={{ base: 100, lg: 300 }} flex={1}>
{experiment.data?.label}
</Text>
)}
</BreadcrumbItem>
</Breadcrumb>
<HStack>
<Button
size="sm"
variant={{ base: "outline", lg: "ghost" }}
colorScheme="gray"
fontWeight="normal"
onClick={openDrawer}
>
<Icon as={BsGearFill} boxSize={4} color="gray.600" />
<Text display={{ base: "none", lg: "block" }} ml={2}>
Edit Vars & Evals
</Text>
</Button>
<DeleteButton />
</HStack>
{canModify && (
<HStack>
<Button
size="sm"
variant={{ base: "outline", lg: "ghost" }}
colorScheme="gray"
fontWeight="normal"
onClick={openDrawer}
>
<Icon as={BsGearFill} boxSize={4} color="gray.600" />
<Text display={{ base: "none", lg: "block" }} ml={2}>
Edit Vars & Evals
</Text>
</Button>
<DeleteButton />
</HStack>
)}
</Flex>
<SettingsDrawer />
<Box w="100%" overflowX="auto" flex={1}>

@@ -1,25 +1,50 @@
import {
SimpleGrid,
HStack,
Icon,
VStack,
Breadcrumb,
BreadcrumbItem,
Flex,
Center,
Text,
Link,
HStack,
} from "@chakra-ui/react";
import { RiFlaskLine } from "react-icons/ri";
import AppShell from "~/components/nav/AppShell";
import { api } from "~/utils/api";
import { NewExperimentButton } from "~/components/experiments/NewExperimentButton";
import { ExperimentCard } from "~/components/experiments/ExperimentCard";
import { ExperimentCard, NewExperimentCard } from "~/components/experiments/ExperimentCard";
import { signIn, useSession } from "next-auth/react";
export default function ExperimentsPage() {
const experiments = api.experiments.list.useQuery();
const user = useSession().data;
if (user === null) {
return (
<AppShell title="Experiments">
<Center h="100%">
<Text>
<Link
onClick={() => {
signIn("github").catch(console.error);
}}
textDecor="underline"
>
Sign in
</Link>{" "}
to view or create new experiments!
</Text>
</Center>
</AppShell>
);
}
return (
<AppShell>
<VStack alignItems={"flex-start"} m={4} mt={1}>
<HStack w="full" justifyContent="space-between" mb={4}>
<AppShell title="Experiments">
<VStack alignItems={"flex-start"} px={4} py={2}>
<HStack minH={8} align="center">
<Breadcrumb flex={1}>
<BreadcrumbItem>
<Flex alignItems="center">
@@ -27,9 +52,9 @@ export default function ExperimentsPage() {
</Flex>
</BreadcrumbItem>
</Breadcrumb>
<NewExperimentButton mr={4} borderRadius={8} />
</HStack>
<SimpleGrid w="full" columns={{ base: 1, md: 2, lg: 3, xl: 4 }} spacing={8} p="4">
<NewExperimentCard />
{experiments?.data?.map((exp) => <ExperimentCard key={exp.id} exp={exp} />)}
</SimpleGrid>
</VStack>

@@ -1,7 +1,7 @@
import { type GetServerSideProps } from "next";
// eslint-disable-next-line @typescript-eslint/require-await
export const getServerSideProps: GetServerSideProps = async (context) => {
export const getServerSideProps: GetServerSideProps = async () => {
return {
redirect: {
destination: "/experiments",

@@ -1,11 +1,7 @@
import { type CompletionCreateParams } from "openai/resources/chat";
import { prisma } from "../db";
import { openai } from "../utils/openai";
import { pick } from "lodash";
function promptHasVariable(prompt: string, variableName: string) {
return prompt.includes(`{{${variableName}}}`);
}
import { pick } from "lodash-es";
type AxiosError = {
response?: {

@@ -2,7 +2,7 @@ import { promptVariantsRouter } from "~/server/api/routers/promptVariants.router
import { createTRPCRouter } from "~/server/api/trpc";
import { experimentsRouter } from "./routers/experiments.router";
import { scenariosRouter } from "./routers/scenarios.router";
import { modelOutputsRouter } from "./routers/modelOutputs.router";
import { scenarioVariantCellsRouter } from "./routers/scenarioVariantCells.router";
import { templateVarsRouter } from "./routers/templateVariables.router";
import { evaluationsRouter } from "./routers/evaluations.router";
@@ -15,7 +15,7 @@ export const appRouter = createTRPCRouter({
promptVariants: promptVariantsRouter,
experiments: experimentsRouter,
scenarios: scenariosRouter,
outputs: modelOutputsRouter,
scenarioVariantCells: scenarioVariantCellsRouter,
templateVars: templateVarsRouter,
evaluations: evaluationsRouter,
});

@@ -1,68 +1,96 @@
import { EvaluationMatchType } from "@prisma/client";
import { EvalType } from "@prisma/client";
import { z } from "zod";
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import { reevaluateEvaluation } from "~/server/utils/evaluations";
import { runAllEvals } from "~/server/utils/evaluations";
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
export const evaluationsRouter = createTRPCRouter({
list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
return await prisma.evaluation.findMany({
where: {
experimentId: input.experimentId,
},
orderBy: { createdAt: "asc" },
});
}),
list: publicProcedure
.input(z.object({ experimentId: z.string() }))
.query(async ({ input, ctx }) => {
await requireCanViewExperiment(input.experimentId, ctx);
create: publicProcedure
return await prisma.evaluation.findMany({
where: {
experimentId: input.experimentId,
},
orderBy: { createdAt: "asc" },
});
}),
create: protectedProcedure
.input(
z.object({
experimentId: z.string(),
name: z.string(),
matchString: z.string(),
matchType: z.nativeEnum(EvaluationMatchType),
label: z.string(),
value: z.string(),
evalType: z.nativeEnum(EvalType),
}),
)
.mutation(async ({ input }) => {
const evaluation = await prisma.evaluation.create({
.mutation(async ({ input, ctx }) => {
await requireCanModifyExperiment(input.experimentId, ctx);
await prisma.evaluation.create({
data: {
experimentId: input.experimentId,
name: input.name,
matchString: input.matchString,
matchType: input.matchType,
label: input.label,
value: input.value,
evalType: input.evalType,
},
});
await reevaluateEvaluation(evaluation);
// TODO: this may be bad UX for slow evals (e.g. GPT-4 evals). We may need
// to kick off a background job or something similar instead.
await runAllEvals(input.experimentId);
}),
update: publicProcedure
update: protectedProcedure
.input(
z.object({
id: z.string(),
updates: z.object({
name: z.string().optional(),
matchString: z.string().optional(),
matchType: z.nativeEnum(EvaluationMatchType).optional(),
label: z.string().optional(),
value: z.string().optional(),
evalType: z.nativeEnum(EvalType).optional(),
}),
}),
)
.mutation(async ({ input }) => {
await prisma.evaluation.update({
.mutation(async ({ input, ctx }) => {
const { experimentId } = await prisma.evaluation.findUniqueOrThrow({
where: { id: input.id },
});
await requireCanModifyExperiment(experimentId, ctx);
const evaluation = await prisma.evaluation.update({
where: { id: input.id },
data: {
name: input.updates.name,
matchString: input.updates.matchString,
matchType: input.updates.matchType,
label: input.updates.label,
value: input.updates.value,
evalType: input.updates.evalType,
},
});
await reevaluateEvaluation(
await prisma.evaluation.findUniqueOrThrow({ where: { id: input.id } }),
);
await prisma.outputEvaluation.deleteMany({
where: {
evaluationId: evaluation.id,
},
});
// Re-run all evals. Other eval results will already be cached, so this
// should only re-run the updated one.
await runAllEvals(evaluation.experimentId);
}),
delete: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => {
await prisma.evaluation.delete({
where: { id: input.id },
});
}),
delete: protectedProcedure
.input(z.object({ id: z.string() }))
.mutation(async ({ input, ctx }) => {
const { experimentId } = await prisma.evaluation.findUniqueOrThrow({
where: { id: input.id },
});
await requireCanModifyExperiment(experimentId, ctx);
await prisma.evaluation.delete({
where: { id: input.id },
});
}),
});
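Review note: `update` and `delete` above receive only an evaluation id, so they first look up the owning experiment and then run the permission check against it. The sketch below shows that lookup-then-authorize order with in-memory data. Everything in it (`Ctx`, the two lookup tables, `requireCanModify`, `deleteEvaluation`) is a hypothetical stand-in and not the real `~/utils/accessControl` implementation, which is not shown in this diff.

```typescript
type Ctx = { userId: string | null };

// Hypothetical in-memory stand-ins for the database tables.
const evaluationToExperiment: Record<string, string | undefined> = { "eval-1": "exp-1" };
const experimentEditors: Record<string, string[] | undefined> = { "exp-1": ["alice"] };

// Throws unless the caller is a known editor of the experiment.
function requireCanModify(experimentId: string, ctx: Ctx): void {
  const editors = experimentEditors[experimentId] ?? [];
  if (ctx.userId === null || editors.indexOf(ctx.userId) === -1) {
    throw new Error("UNAUTHORIZED");
  }
}

function deleteEvaluation(evaluationId: string, ctx: Ctx): void {
  // Step 1: resolve the parent experiment from the child id.
  const experimentId = evaluationToExperiment[evaluationId];
  if (experimentId === undefined) throw new Error("NOT_FOUND");
  // Step 2: authorize against the experiment before touching any data.
  requireCanModify(experimentId, ctx);
  // Step 3: only now perform the mutation.
  delete evaluationToExperiment[evaluationId];
}
```

With this data, a call with `{ userId: "bob" }` throws and leaves the record in place, while `{ userId: "alice" }` removes it. This matches the commit note that experiments are the main access control objects.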

@@ -1,13 +1,31 @@
import { z } from "zod";
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import dedent from "dedent";
import { generateNewCell } from "~/server/utils/generateNewCell";
import {
canModifyExperiment,
requireCanModifyExperiment,
requireCanViewExperiment,
requireNothing,
} from "~/utils/accessControl";
import userOrg from "~/server/utils/userOrg";
export const experimentsRouter = createTRPCRouter({
list: publicProcedure.query(async () => {
list: protectedProcedure.query(async ({ ctx }) => {
// Anyone can list experiments
requireNothing(ctx);
const experiments = await prisma.experiment.findMany({
where: {
organization: {
OrganizationUser: {
some: { userId: ctx.session.user.id },
},
},
},
orderBy: {
sortIndex: "asc",
sortIndex: "desc",
},
});
@@ -39,15 +57,29 @@ export const experimentsRouter = createTRPCRouter({
return experimentsWithCounts;
}),
get: publicProcedure.input(z.object({ id: z.string() })).query(async ({ input }) => {
return await prisma.experiment.findFirst({
where: {
id: input.id,
},
get: publicProcedure.input(z.object({ id: z.string() })).query(async ({ input, ctx }) => {
await requireCanViewExperiment(input.id, ctx);
const experiment = await prisma.experiment.findFirstOrThrow({
where: { id: input.id },
});
const canModify = ctx.session?.user.id
? await canModifyExperiment(experiment.id, ctx.session?.user.id)
: false;
return {
...experiment,
access: {
canView: true,
canModify,
},
};
}),
create: publicProcedure.input(z.object({})).mutation(async () => {
create: protectedProcedure.input(z.object({})).mutation(async ({ ctx }) => {
// Anyone can create an experiment
requireNothing(ctx);
const maxSortIndex =
(
await prisma.experiment.aggregate({
@@ -61,37 +93,66 @@ export const experimentsRouter = createTRPCRouter({
data: {
sortIndex: maxSortIndex + 1,
label: `Experiment ${maxSortIndex + 1}`,
organizationId: (await userOrg(ctx.session.user.id)).id,
},
});
await prisma.$transaction([
const [variant, _, scenario] = await prisma.$transaction([
prisma.promptVariant.create({
data: {
experimentId: exp.id,
label: "Prompt Variant 1",
sortIndex: 0,
constructFn: dedent`prompt = {
// The interpolated $ is necessary until dedent incorporates
// https://github.com/dmnd/dedent/pull/46
constructFn: dedent`
/**
* Use Javascript to define an OpenAI chat completion
* (https://platform.openai.com/docs/api-reference/chat/create) and
* assign it to the \`prompt\` variable.
*
* You have access to the current scenario in the \`scenario\`
* variable.
*/
prompt = {
model: "gpt-3.5-turbo-0613",
stream: true,
messages: [{ role: "system", content: "Return 'Ready to go!'" }],
}`,
messages: [
{
role: "system",
content: \`"Return 'this is output for the scenario "${"$"}{scenario.text}"'\`,
},
],
};`,
model: "gpt-3.5-turbo-0613",
},
}),
prisma.templateVariable.create({
data: {
experimentId: exp.id,
label: "text",
},
}),
prisma.testScenario.create({
data: {
experimentId: exp.id,
variableValues: {},
variableValues: {
text: "This is a test scenario.",
},
},
}),
]);
await generateNewCell(variant.id, scenario.id);
return exp;
}),
update: publicProcedure
update: protectedProcedure
.input(z.object({ id: z.string(), updates: z.object({ label: z.string() }) }))
.mutation(async ({ input }) => {
.mutation(async ({ input, ctx }) => {
await requireCanModifyExperiment(input.id, ctx);
return await prisma.experiment.update({
where: {
id: input.id,
@@ -102,11 +163,15 @@ export const experimentsRouter = createTRPCRouter({
});
}),
delete: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => {
await prisma.experiment.delete({
where: {
id: input.id,
},
});
}),
delete: protectedProcedure
.input(z.object({ id: z.string() }))
.mutation(async ({ input, ctx }) => {
await requireCanModifyExperiment(input.id, ctx);
await prisma.experiment.delete({
where: {
id: input.id,
},
});
}),
});

@@ -1,101 +0,0 @@
import { z } from "zod";
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import crypto from "crypto";
import type { Prisma } from "@prisma/client";
import { reevaluateVariant } from "~/server/utils/evaluations";
import { getCompletion } from "~/server/utils/getCompletion";
import { constructPrompt } from "~/server/utils/constructPrompt";
import { type CompletionCreateParams } from "openai/resources/chat";
export const modelOutputsRouter = createTRPCRouter({
get: publicProcedure
.input(
z.object({
scenarioId: z.string(),
variantId: z.string(),
channel: z.string().optional(),
forceRefetch: z.boolean().optional(),
}),
)
.mutation(async ({ input }) => {
const existing = await prisma.modelOutput.findUnique({
where: {
promptVariantId_testScenarioId: {
promptVariantId: input.variantId,
testScenarioId: input.scenarioId,
},
},
});
if (existing && !input.forceRefetch) return existing;
const variant = await prisma.promptVariant.findUnique({
where: {
id: input.variantId,
},
});
const scenario = await prisma.testScenario.findUnique({
where: {
id: input.scenarioId,
},
});
if (!variant || !scenario) return null;
const prompt = await constructPrompt(variant, scenario.variableValues);
const inputHash = crypto.createHash("sha256").update(JSON.stringify(prompt)).digest("hex");
// TODO: we should probably only use this if temperature=0
const existingResponse = await prisma.modelOutput.findFirst({
where: { inputHash, errorMessage: null },
});
let modelResponse: Awaited<ReturnType<typeof getCompletion>>;
if (existingResponse) {
modelResponse = {
output: existingResponse.output as Prisma.InputJsonValue,
statusCode: existingResponse.statusCode,
errorMessage: existingResponse.errorMessage,
timeToComplete: existingResponse.timeToComplete,
promptTokens: existingResponse.promptTokens ?? undefined,
completionTokens: existingResponse.completionTokens ?? undefined,
};
} else {
try {
modelResponse = await getCompletion(
prompt as unknown as CompletionCreateParams,
input.channel,
);
} catch (e) {
console.error(e);
throw e;
}
}
const modelOutput = await prisma.modelOutput.upsert({
where: {
promptVariantId_testScenarioId: {
promptVariantId: input.variantId,
testScenarioId: input.scenarioId,
},
},
create: {
promptVariantId: input.variantId,
testScenarioId: input.scenarioId,
inputHash,
...modelResponse,
},
update: {
...modelResponse,
},
});
await reevaluateVariant(input.variantId);
return modelOutput;
}),
});

@@ -1,93 +1,175 @@
import { isObject } from "lodash";
import { isObject } from "lodash-es";
import { z } from "zod";
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import { OpenAIChatModel } from "~/server/types";
import { generateNewCell } from "~/server/utils/generateNewCell";
import { OpenAIChatModel, type SupportedModel } from "~/server/types";
import { constructPrompt } from "~/server/utils/constructPrompt";
import userError from "~/server/utils/error";
import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
import { calculateTokenCost } from "~/utils/calculateTokenCost";
import { reorderPromptVariants } from "~/server/utils/reorderPromptVariants";
import { type PromptVariant } from "@prisma/client";
import { deriveNewConstructFn } from "~/server/utils/deriveNewContructFn";
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
export const promptVariantsRouter = createTRPCRouter({
list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
return await prisma.promptVariant.findMany({
where: {
experimentId: input.experimentId,
visible: true,
},
orderBy: { sortIndex: "asc" },
});
}),
list: publicProcedure
.input(z.object({ experimentId: z.string() }))
.query(async ({ input, ctx }) => {
await requireCanViewExperiment(input.experimentId, ctx);
stats: publicProcedure.input(z.object({ variantId: z.string() })).query(async ({ input }) => {
const variant = await prisma.promptVariant.findUnique({
where: {
id: input.variantId,
},
});
if (!variant) {
throw new Error(`Prompt Variant with id ${input.variantId} does not exist`);
}
const evalResults = await prisma.evaluationResult.findMany({
where: {
promptVariantId: input.variantId,
},
include: { evaluation: true },
});
const scenarioCount = await prisma.testScenario.count({
where: {
experimentId: variant.experimentId,
visible: true,
},
});
const outputCount = await prisma.modelOutput.count({
where: {
promptVariantId: input.variantId,
testScenario: { visible: true },
},
});
const overallTokens = await prisma.modelOutput.aggregate({
where: {
promptVariantId: input.variantId,
testScenario: { visible: true },
},
_sum: {
promptTokens: true,
completionTokens: true,
},
});
const promptTokens = overallTokens._sum?.promptTokens ?? 0;
const overallPromptCost = calculateTokenCost(variant.model, promptTokens);
const completionTokens = overallTokens._sum?.completionTokens ?? 0;
const overallCompletionCost = calculateTokenCost(variant.model, completionTokens, true);
const overallCost = overallPromptCost + overallCompletionCost;
return { evalResults, promptTokens, completionTokens, overallCost, scenarioCount, outputCount };
}),
create: publicProcedure
.input(
z.object({
experimentId: z.string(),
}),
)
.mutation(async ({ input }) => {
const lastVariant = await prisma.promptVariant.findFirst({
return await prisma.promptVariant.findMany({
where: {
experimentId: input.experimentId,
visible: true,
},
orderBy: {
sortIndex: "desc",
orderBy: { sortIndex: "asc" },
});
}),
stats: publicProcedure
.input(z.object({ variantId: z.string() }))
.query(async ({ input, ctx }) => {
const variant = await prisma.promptVariant.findUnique({
where: {
id: input.variantId,
},
});
if (!variant) {
throw new Error(`Prompt Variant with id ${input.variantId} does not exist`);
}
await requireCanViewExperiment(variant.experimentId, ctx);
const outputEvals = await prisma.outputEvaluation.groupBy({
by: ["evaluationId"],
_sum: {
result: true,
},
_count: {
id: true,
},
where: {
modelOutput: {
scenarioVariantCell: {
promptVariant: {
id: input.variantId,
visible: true,
},
testScenario: {
visible: true,
},
},
},
},
});
const evals = await prisma.evaluation.findMany({
where: {
experimentId: variant.experimentId,
},
});
const evalResults = evals.map((evalItem) => {
const evalResult = outputEvals.find(
(outputEval) => outputEval.evaluationId === evalItem.id,
);
return {
id: evalItem.id,
label: evalItem.label,
passCount: evalResult?._sum?.result ?? 0,
totalCount: evalResult?._count?.id ?? 1,
};
});
const scenarioCount = await prisma.testScenario.count({
where: {
experimentId: variant.experimentId,
visible: true,
},
});
const outputCount = await prisma.scenarioVariantCell.count({
where: {
promptVariantId: input.variantId,
testScenario: { visible: true },
modelOutput: {
is: {},
},
},
});
const overallTokens = await prisma.modelOutput.aggregate({
where: {
scenarioVariantCell: {
promptVariantId: input.variantId,
testScenario: {
visible: true,
},
},
},
_sum: {
cost: true,
promptTokens: true,
completionTokens: true,
},
});
const promptTokens = overallTokens._sum?.promptTokens ?? 0;
const completionTokens = overallTokens._sum?.completionTokens ?? 0;
const awaitingRetrievals = !!(await prisma.scenarioVariantCell.findFirst({
where: {
promptVariantId: input.variantId,
testScenario: { visible: true },
// Check whether the cell is PENDING or IN_PROGRESS
retrievalStatus: {
in: ["PENDING", "IN_PROGRESS"],
},
},
}));
return {
evalResults,
promptTokens,
completionTokens,
overallCost: overallTokens._sum?.cost ?? 0,
scenarioCount,
outputCount,
awaitingRetrievals,
};
}),
create: protectedProcedure
.input(
z.object({
experimentId: z.string(),
variantId: z.string().optional(),
newModel: z.string().optional(),
}),
)
.mutation(async ({ input, ctx }) => {
await requireCanModifyExperiment(input.experimentId, ctx);
let originalVariant: PromptVariant | null = null;
if (input.variantId) {
originalVariant = await prisma.promptVariant.findUnique({
where: {
id: input.variantId,
},
});
} else {
originalVariant = await prisma.promptVariant.findFirst({
where: {
experimentId: input.experimentId,
visible: true,
},
orderBy: {
sortIndex: "desc",
},
});
}
const largestSortIndex =
(
await prisma.promptVariant.aggregate({
@@ -100,13 +182,23 @@ export const promptVariantsRouter = createTRPCRouter({
})
)._max?.sortIndex ?? 0;
const newVariantLabel =
input.variantId && originalVariant
? `${originalVariant?.label} Copy`
: `Prompt Variant ${largestSortIndex + 2}`;
const newConstructFn = await deriveNewConstructFn(
originalVariant,
input.newModel as SupportedModel,
);
const createNewVariantAction = prisma.promptVariant.create({
data: {
experimentId: input.experimentId,
label: `Prompt Variant ${largestSortIndex + 2}`,
sortIndex: (lastVariant?.sortIndex ?? 0) + 1,
constructFn: lastVariant?.constructFn ?? "",
model: lastVariant?.model ?? "gpt-3.5-turbo",
label: newVariantLabel,
sortIndex: (originalVariant?.sortIndex ?? 0) + 1,
constructFn: newConstructFn,
model: originalVariant?.model ?? "gpt-3.5-turbo",
},
});
@@ -115,10 +207,26 @@ export const promptVariantsRouter = createTRPCRouter({
recordExperimentUpdated(input.experimentId),
]);
if (originalVariant) {
// Insert new variant to right of original variant
await reorderPromptVariants(newVariant.id, originalVariant.id, true);
}
const scenarios = await prisma.testScenario.findMany({
where: {
experimentId: input.experimentId,
visible: true,
},
});
for (const scenario of scenarios) {
await generateNewCell(newVariant.id, scenario.id);
}
return newVariant;
}),
update: publicProcedure
update: protectedProcedure
.input(
z.object({
id: z.string(),
@@ -127,7 +235,7 @@ export const promptVariantsRouter = createTRPCRouter({
}),
}),
)
.mutation(async ({ input }) => {
.mutation(async ({ input, ctx }) => {
const existing = await prisma.promptVariant.findUnique({
where: {
id: input.id,
@@ -138,6 +246,8 @@ export const promptVariantsRouter = createTRPCRouter({
throw new Error(`Prompt Variant with id ${input.id} does not exist`);
}
await requireCanModifyExperiment(existing.experimentId, ctx);
const updatePromptVariantAction = prisma.promptVariant.update({
where: {
id: input.id,
@@ -153,13 +263,18 @@ export const promptVariantsRouter = createTRPCRouter({
return updatedPromptVariant;
}),
hide: publicProcedure
hide: protectedProcedure
.input(
z.object({
id: z.string(),
}),
)
.mutation(async ({ input }) => {
.mutation(async ({ input, ctx }) => {
const { experimentId } = await prisma.promptVariant.findUniqueOrThrow({
where: { id: input.id },
});
await requireCanModifyExperiment(experimentId, ctx);
const updatedPromptVariant = await prisma.promptVariant.update({
where: { id: input.id },
data: { visible: false, experiment: { update: { updatedAt: new Date() } } },
@@ -168,19 +283,50 @@ export const promptVariantsRouter = createTRPCRouter({
return updatedPromptVariant;
}),
replaceVariant: publicProcedure
getRefinedPromptFn: protectedProcedure
.input(
z.object({
id: z.string(),
instructions: z.string(),
}),
)
.mutation(async ({ input, ctx }) => {
const existing = await prisma.promptVariant.findUniqueOrThrow({
where: {
id: input.id,
},
});
await requireCanModifyExperiment(existing.experimentId, ctx);
const constructedPrompt = await constructPrompt({ constructFn: existing.constructFn }, null);
const promptConstructionFn = await deriveNewConstructFn(
existing,
// @ts-expect-error TODO clean this up
constructedPrompt?.model as SupportedModel,
input.instructions,
);
// TODO: Validate promptConstructionFn
// TODO: Record in some sort of history
return promptConstructionFn;
}),
replaceVariant: protectedProcedure
.input(
z.object({
id: z.string(),
constructFn: z.string(),
}),
)
.mutation(async ({ input }) => {
const existing = await prisma.promptVariant.findUnique({
.mutation(async ({ input, ctx }) => {
const existing = await prisma.promptVariant.findUniqueOrThrow({
where: {
id: input.id,
},
});
await requireCanModifyExperiment(existing.experimentId, ctx);
if (!existing) {
throw new Error(`Prompt Variant with id ${input.id} does not exist`);
@@ -234,75 +380,33 @@ export const promptVariantsRouter = createTRPCRouter({
await prisma.$transaction([hideOldVariants, recordExperimentUpdated(existing.experimentId)]);
const scenarios = await prisma.testScenario.findMany({
where: {
experimentId: newVariant.experimentId,
visible: true,
},
});
for (const scenario of scenarios) {
await generateNewCell(newVariant.id, scenario.id);
}
return { status: "ok" } as const;
}),
reorder: publicProcedure
reorder: protectedProcedure
.input(
z.object({
draggedId: z.string(),
droppedId: z.string(),
}),
)
.mutation(async ({ input }) => {
const dragged = await prisma.promptVariant.findUnique({
where: {
id: input.draggedId,
},
.mutation(async ({ input, ctx }) => {
const { experimentId } = await prisma.promptVariant.findUniqueOrThrow({
where: { id: input.draggedId },
});
await requireCanModifyExperiment(experimentId, ctx);
const dropped = await prisma.promptVariant.findUnique({
where: {
id: input.droppedId,
},
});
if (!dragged || !dropped || dragged.experimentId !== dropped.experimentId) {
throw new Error(
`Prompt Variant with id ${input.draggedId} or ${input.droppedId} does not exist`,
);
}
const visibleItems = await prisma.promptVariant.findMany({
where: {
experimentId: dragged.experimentId,
visible: true,
},
orderBy: {
sortIndex: "asc",
},
});
// Remove the dragged item from its current position
const orderedItems = visibleItems.filter((item) => item.id !== dragged.id);
// Find the index of the dragged item and the dropped item
const dragIndex = visibleItems.findIndex((item) => item.id === dragged.id);
const dropIndex = visibleItems.findIndex((item) => item.id === dropped.id);
// Determine the new index for the dragged item
let newIndex;
if (dragIndex < dropIndex) {
newIndex = dropIndex + 1; // Insert after the dropped item
} else {
newIndex = dropIndex; // Insert before the dropped item
}
// Insert the dragged item at the new position
orderedItems.splice(newIndex, 0, dragged);
// Now, we need to update all the items with their new sortIndex
await prisma.$transaction(
orderedItems.map((item, index) => {
return prisma.promptVariant.update({
where: {
id: item.id,
},
data: {
sortIndex: index,
},
});
}),
);
await reorderPromptVariants(input.draggedId, input.droppedId);
}),
});
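
The `reorder` mutation now delegates to `reorderPromptVariants`, whose body is not part of this hunk. As a rough sketch only, the drag-and-drop move can be written over a plain array. `Sortable` and `moveItem` are hypothetical names, and placing the dragged item at the dropped item's original index is an assumption, not necessarily what the real helper does.

```typescript
// Hypothetical shape: only the fields the move needs.
interface Sortable {
  id: string;
  sortIndex: number;
}

// Moves the dragged item to the index the dropped item currently holds,
// then renumbers sortIndex to match the new array order.
function moveItem<T extends Sortable>(items: T[], draggedId: string, droppedId: string): T[] {
  const dragged = items.find((item) => item.id === draggedId);
  const dropIndex = items.findIndex((item) => item.id === droppedId);
  if (!dragged || dropIndex === -1) {
    throw new Error(`Item ${draggedId} or ${droppedId} does not exist`);
  }
  // Remove the dragged item, then insert it at the dropped item's original index.
  // This lands after the dropped item when moving down and before it when moving up.
  const reordered = items.filter((item) => item.id !== draggedId);
  reordered.splice(dropIndex, 0, dragged);
  return reordered.map((item, index) => ({ ...item, sortIndex: index }));
}
```

In a database-backed version, each element of the returned array would map to one `update` inside a single transaction, as the removed inline code did.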


@@ -0,0 +1,90 @@
import { z } from "zod";
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import { generateNewCell } from "~/server/utils/generateNewCell";
import { queueLLMRetrievalTask } from "~/server/utils/queueLLMRetrievalTask";
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
export const scenarioVariantCellsRouter = createTRPCRouter({
get: publicProcedure
.input(
z.object({
scenarioId: z.string(),
variantId: z.string(),
}),
)
.query(async ({ input, ctx }) => {
const { experimentId } = await prisma.testScenario.findUniqueOrThrow({
where: { id: input.scenarioId },
});
await requireCanViewExperiment(experimentId, ctx);
return await prisma.scenarioVariantCell.findUnique({
where: {
promptVariantId_testScenarioId: {
promptVariantId: input.variantId,
testScenarioId: input.scenarioId,
},
},
include: {
modelOutput: {
include: {
outputEvaluation: {
include: {
evaluation: {
select: { label: true },
},
},
},
},
},
},
});
}),
forceRefetch: protectedProcedure
.input(
z.object({
scenarioId: z.string(),
variantId: z.string(),
}),
)
.mutation(async ({ input, ctx }) => {
const { experimentId } = await prisma.testScenario.findUniqueOrThrow({
where: { id: input.scenarioId },
});
await requireCanModifyExperiment(experimentId, ctx);
const cell = await prisma.scenarioVariantCell.findUnique({
where: {
promptVariantId_testScenarioId: {
promptVariantId: input.variantId,
testScenarioId: input.scenarioId,
},
},
include: {
modelOutput: true,
},
});
if (!cell) {
await generateNewCell(input.variantId, input.scenarioId);
return true;
}
if (cell.modelOutput) {
// TODO: Maybe keep these around to show previous generations?
await prisma.modelOutput.delete({
where: { id: cell.modelOutput.id },
});
}
await prisma.scenarioVariantCell.update({
where: { id: cell.id },
data: { retrievalStatus: "PENDING" },
});
await queueLLMRetrievalTask(cell.id);
return true;
}),
});


@@ -1,31 +1,39 @@
import { z } from "zod";
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import { autogenerateScenarioValues } from "../autogen";
import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
import { reevaluateAll } from "~/server/utils/evaluations";
import { runAllEvals } from "~/server/utils/evaluations";
import { generateNewCell } from "~/server/utils/generateNewCell";
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
export const scenariosRouter = createTRPCRouter({
list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
return await prisma.testScenario.findMany({
where: {
experimentId: input.experimentId,
visible: true,
},
orderBy: {
sortIndex: "asc",
},
});
}),
list: publicProcedure
.input(z.object({ experimentId: z.string() }))
.query(async ({ input, ctx }) => {
await requireCanViewExperiment(input.experimentId, ctx);
create: publicProcedure
return await prisma.testScenario.findMany({
where: {
experimentId: input.experimentId,
visible: true,
},
orderBy: {
sortIndex: "asc",
},
});
}),
create: protectedProcedure
.input(
z.object({
experimentId: z.string(),
autogenerate: z.boolean().optional(),
}),
)
.mutation(async ({ input }) => {
.mutation(async ({ input, ctx }) => {
await requireCanModifyExperiment(input.experimentId, ctx);
const maxSortIndex =
(
await prisma.testScenario.aggregate({
@@ -48,32 +56,50 @@ export const scenariosRouter = createTRPCRouter({
},
});
await prisma.$transaction([
const [scenario] = await prisma.$transaction([
createNewScenarioAction,
recordExperimentUpdated(input.experimentId),
]);
const promptVariants = await prisma.promptVariant.findMany({
where: {
experimentId: input.experimentId,
visible: true,
},
});
for (const variant of promptVariants) {
await generateNewCell(variant.id, scenario.id);
}
}),
hide: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => {
hide: protectedProcedure.input(z.object({ id: z.string() })).mutation(async ({ input, ctx }) => {
const experimentId = (
await prisma.testScenario.findUniqueOrThrow({
where: { id: input.id },
})
).experimentId;
await requireCanModifyExperiment(experimentId, ctx);
const hiddenScenario = await prisma.testScenario.update({
where: { id: input.id },
data: { visible: false, experiment: { update: { updatedAt: new Date() } } },
});
// Reevaluate all evaluations now that this scenario is hidden
await reevaluateAll(hiddenScenario.experimentId);
await runAllEvals(hiddenScenario.experimentId);
return hiddenScenario;
}),
reorder: publicProcedure
reorder: protectedProcedure
.input(
z.object({
draggedId: z.string(),
droppedId: z.string(),
}),
)
.mutation(async ({ input }) => {
.mutation(async ({ input, ctx }) => {
const dragged = await prisma.testScenario.findUnique({
where: {
id: input.draggedId,
@@ -92,6 +118,8 @@ export const scenariosRouter = createTRPCRouter({
);
}
await requireCanModifyExperiment(dragged.experimentId, ctx);
const visibleItems = await prisma.testScenario.findMany({
where: {
experimentId: dragged.experimentId,
@@ -135,14 +163,14 @@ export const scenariosRouter = createTRPCRouter({
);
}),
replaceWithValues: publicProcedure
replaceWithValues: protectedProcedure
.input(
z.object({
id: z.string(),
values: z.record(z.string()),
}),
)
.mutation(async ({ input }) => {
.mutation(async ({ input, ctx }) => {
const existing = await prisma.testScenario.findUnique({
where: {
id: input.id,
@@ -153,6 +181,8 @@ export const scenariosRouter = createTRPCRouter({
throw new Error(`Scenario with id ${input.id} does not exist`);
}
await requireCanModifyExperiment(existing.experimentId, ctx);
const newScenario = await prisma.testScenario.create({
data: {
experimentId: existing.experimentId,
@@ -175,6 +205,17 @@ export const scenariosRouter = createTRPCRouter({
},
});
const promptVariants = await prisma.promptVariant.findMany({
where: {
experimentId: newScenario.experimentId,
visible: true,
},
});
for (const variant of promptVariants) {
await generateNewCell(variant.id, newScenario.id);
}
return newScenario;
}),
});


@@ -1,11 +1,14 @@
import { z } from "zod";
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
export const templateVarsRouter = createTRPCRouter({
create: publicProcedure
create: protectedProcedure
.input(z.object({ experimentId: z.string(), label: z.string() }))
.mutation(async ({ input }) => {
.mutation(async ({ input, ctx }) => {
await requireCanModifyExperiment(input.experimentId, ctx);
await prisma.templateVariable.create({
data: {
experimentId: input.experimentId,
@@ -14,22 +17,33 @@ export const templateVarsRouter = createTRPCRouter({
});
}),
delete: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => {
await prisma.templateVariable.delete({ where: { id: input.id } });
}),
delete: protectedProcedure
.input(z.object({ id: z.string() }))
.mutation(async ({ input, ctx }) => {
const { experimentId } = await prisma.templateVariable.findUniqueOrThrow({
where: { id: input.id },
});
list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
return await prisma.templateVariable.findMany({
where: {
experimentId: input.experimentId,
},
orderBy: {
createdAt: "asc",
},
select: {
id: true,
label: true,
},
});
}),
await requireCanModifyExperiment(experimentId, ctx);
await prisma.templateVariable.delete({ where: { id: input.id } });
}),
list: publicProcedure
.input(z.object({ experimentId: z.string() }))
.query(async ({ input, ctx }) => {
await requireCanViewExperiment(input.experimentId, ctx);
return await prisma.templateVariable.findMany({
where: {
experimentId: input.experimentId,
},
orderBy: {
createdAt: "asc",
},
select: {
id: true,
label: true,
},
});
}),
});


@@ -27,6 +27,9 @@ type CreateContextOptions = {
session: Session | null;
};
// eslint-disable-next-line @typescript-eslint/no-empty-function
const noOp = () => {};
/**
* This helper generates the "internals" for a tRPC context. If you need to use it, you can export
* it from here.
@@ -41,6 +44,7 @@ const createInnerTRPCContext = (opts: CreateContextOptions) => {
return {
session: opts.session,
prisma,
markAccessControlRun: noOp,
};
};
@@ -69,6 +73,8 @@ export const createTRPCContext = async (opts: CreateNextContextOptions) => {
* errors on the backend.
*/
export type TRPCContext = Awaited<ReturnType<typeof createTRPCContext>>;
const t = initTRPC.context<typeof createTRPCContext>().create({
transformer: superjson,
errorFormatter({ shape, error }) {
@@ -106,16 +112,29 @@ export const createTRPCRouter = t.router;
export const publicProcedure = t.procedure;
/** Reusable middleware that enforces users are logged in before running the procedure. */
const enforceUserIsAuthed = t.middleware(({ ctx, next }) => {
const enforceUserIsAuthed = t.middleware(async ({ ctx, next }) => {
if (!ctx.session || !ctx.session.user) {
throw new TRPCError({ code: "UNAUTHORIZED" });
}
return next({
let accessControlRun = false;
const resp = await next({
ctx: {
// infers the `session` as non-nullable
session: { ...ctx.session, user: ctx.session.user },
markAccessControlRun: () => {
accessControlRun = true;
},
},
});
if (!accessControlRun)
throw new TRPCError({
code: "INTERNAL_SERVER_ERROR",
message:
"Protected routes must perform access control checks then explicitly invoke the `ctx.markAccessControlRun()` function to ensure we don't forget access control on a route.",
});
return resp;
});
/**
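
The middleware change above makes every protected route fail unless its handler reports that an access check ran. A minimal sketch of that pattern without tRPC follows. `GuardedContext`, `Handler`, and `runGuarded` are hypothetical names introduced for illustration and are not part of the codebase.

```typescript
// Hypothetical stand-ins for the tRPC context and procedure handler.
interface GuardedContext {
  markAccessControlRun: () => void;
}
type Handler<T> = (ctx: GuardedContext) => Promise<T>;

// Runs the handler, then rejects if it never reported an access check.
async function runGuarded<T>(handler: Handler<T>): Promise<T> {
  let accessControlRun = false;
  const result = await handler({
    markAccessControlRun: () => {
      accessControlRun = true;
    },
  });
  if (!accessControlRun) {
    throw new Error("Handler finished without marking an access control check");
  }
  return result;
}
```

Note that the check happens after the handler has run, so it catches a forgotten check during development rather than preventing the handler's side effects.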


@@ -2,6 +2,8 @@ import { PrismaAdapter } from "@next-auth/prisma-adapter";
import { type GetServerSidePropsContext } from "next";
import { getServerSession, type NextAuthOptions, type DefaultSession } from "next-auth";
import { prisma } from "~/server/db";
import GitHubProvider from "next-auth/providers/github";
import { env } from "~/env.mjs";
/**
* Module augmentation for `next-auth` types. Allows us to add custom properties to the `session`
@@ -41,20 +43,15 @@ export const authOptions: NextAuthOptions = {
},
adapter: PrismaAdapter(prisma),
providers: [
// DiscordProvider({
// clientId: env.DISCORD_CLIENT_ID,
// clientSecret: env.DISCORD_CLIENT_SECRET,
// }),
/**
* ...add more providers here.
*
* Most other providers require a bit more work than the Discord provider. For example, the
* GitHub provider requires you to add the `refresh_token_expires_in` field to the Account
* model. Refer to the NextAuth.js docs for the provider you want to use. Example:
*
* @see https://next-auth.js.org/providers/github
*/
GitHubProvider({
clientId: env.GITHUB_CLIENT_ID,
clientSecret: env.GITHUB_CLIENT_SECRET,
}),
],
theme: {
logo: "/logo.svg",
brandColor: "#ff5733",
},
};
/**


@@ -8,7 +8,10 @@ const globalForPrisma = globalThis as unknown as {
export const prisma =
globalForPrisma.prisma ??
new PrismaClient({
log: env.NODE_ENV === "development" ? ["query", "error", "warn"] : ["error"],
log:
env.NODE_ENV === "development" && !env.RESTRICT_PRISMA_LOGS
? ["query", "error", "warn"]
: ["error"],
});
if (env.NODE_ENV !== "production") globalForPrisma.prisma = prisma;

src/server/modelStats.ts

@@ -0,0 +1,77 @@
import { type SupportedModel } from "./types";
interface ModelStats {
contextLength: number;
promptTokenPrice: number;
completionTokenPrice: number;
speed: "fast" | "medium" | "slow";
provider: "OpenAI";
learnMoreUrl: string;
}
export const modelStats: Record<SupportedModel, ModelStats> = {
"gpt-4": {
contextLength: 8192,
promptTokenPrice: 0.00003,
completionTokenPrice: 0.00006,
speed: "medium",
provider: "OpenAI",
learnMoreUrl: "https://openai.com/gpt-4",
},
"gpt-4-0613": {
contextLength: 8192,
promptTokenPrice: 0.00003,
completionTokenPrice: 0.00006,
speed: "medium",
provider: "OpenAI",
learnMoreUrl: "https://openai.com/gpt-4",
},
"gpt-4-32k": {
contextLength: 32768,
promptTokenPrice: 0.00006,
completionTokenPrice: 0.00012,
speed: "medium",
provider: "OpenAI",
learnMoreUrl: "https://openai.com/gpt-4",
},
"gpt-4-32k-0613": {
contextLength: 32768,
promptTokenPrice: 0.00006,
completionTokenPrice: 0.00012,
speed: "medium",
provider: "OpenAI",
learnMoreUrl: "https://openai.com/gpt-4",
},
"gpt-3.5-turbo": {
contextLength: 4096,
promptTokenPrice: 0.0000015,
completionTokenPrice: 0.000002,
speed: "fast",
provider: "OpenAI",
learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
},
"gpt-3.5-turbo-0613": {
contextLength: 4096,
promptTokenPrice: 0.0000015,
completionTokenPrice: 0.000002,
speed: "fast",
provider: "OpenAI",
learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
},
"gpt-3.5-turbo-16k": {
contextLength: 16384,
promptTokenPrice: 0.000003,
completionTokenPrice: 0.000004,
speed: "fast",
provider: "OpenAI",
learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
},
"gpt-3.5-turbo-16k-0613": {
contextLength: 16384,
promptTokenPrice: 0.000003,
completionTokenPrice: 0.000004,
speed: "fast",
provider: "OpenAI",
learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
},
};
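
The two price fields read as USD per token (for example, 0.00003 for `gpt-4` prompt tokens corresponds to $0.03 per 1,000 tokens). Under that assumption, a request's cost could be estimated as below. `TokenPrices` and `estimateCost` are hypothetical names, not functions from this change.

```typescript
// Hypothetical subset of ModelStats: just the per-token prices (assumed USD).
interface TokenPrices {
  promptTokenPrice: number;
  completionTokenPrice: number;
}

// Cost = prompt tokens * prompt price + completion tokens * completion price.
function estimateCost(
  prices: TokenPrices,
  promptTokens: number,
  completionTokens: number,
): number {
  return promptTokens * prices.promptTokenPrice + completionTokens * prices.completionTokenPrice;
}

// Values copied from the "gpt-4" entry above.
const gpt4Prices: TokenPrices = { promptTokenPrice: 0.00003, completionTokenPrice: 0.00006 };
const exampleCost = estimateCost(gpt4Prices, 1000, 500); // about 0.06, subject to float rounding
```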


@@ -0,0 +1,26 @@
// /* eslint-disable */
// import "dotenv/config";
// import Replicate from "replicate";
// const replicate = new Replicate({
// auth: process.env.REPLICATE_API_TOKEN || "",
// });
// console.log("going to run");
// const prediction = await replicate.predictions.create({
// version: "e951f18578850b652510200860fc4ea62b3b16fac280f83ff32282f87bbd2e48",
// input: {
// prompt: "...",
// },
// });
// console.log("waiting");
// setInterval(() => {
// replicate.predictions.get(prediction.id).then((prediction) => {
// console.log(prediction.output);
// });
// }, 500);
// // const output = await replicate.wait(prediction, {});
// // console.log(output);


@@ -0,0 +1,31 @@
// Import necessary dependencies
import { quickAddJob, type Helpers, type Task } from "graphile-worker";
import { env } from "~/env.mjs";
// Wrap a task handler with logging and return an `enqueue` helper plus the graphile-worker task definition
function defineTask<TPayload>(
taskIdentifier: string,
taskHandler: (payload: TPayload, helpers: Helpers) => Promise<void>,
) {
const enqueue = async (payload: TPayload) => {
console.log("Enqueuing task", taskIdentifier, payload);
await quickAddJob({ connectionString: env.DATABASE_URL }, taskIdentifier, payload);
};
const handler = (payload: TPayload, helpers: Helpers) => {
helpers.logger.info(`Running task ${taskIdentifier} with payload: ${JSON.stringify(payload)}`);
return taskHandler(payload, helpers);
};
const task = {
identifier: taskIdentifier,
handler: handler as Task,
};
return {
enqueue,
task,
};
}
export default defineTask;


@@ -0,0 +1,182 @@
import crypto from "crypto";
import { prisma } from "~/server/db";
import defineTask from "./defineTask";
import { type CompletionResponse, getOpenAIChatCompletion } from "../utils/getCompletion";
import { type JSONSerializable } from "../types";
import { sleep } from "../utils/sleep";
import { shouldStream } from "../utils/shouldStream";
import { generateChannel } from "~/utils/generateChannel";
import { runEvalsForOutput } from "../utils/evaluations";
import { constructPrompt } from "../utils/constructPrompt";
import { type CompletionCreateParams } from "openai/resources/chat";
import { type Prisma } from "@prisma/client";
const MAX_AUTO_RETRIES = 10;
const MIN_DELAY = 500; // milliseconds
const MAX_DELAY = 15000; // milliseconds
// Exponential backoff capped at MAX_DELAY, plus random jitter of up to one extra base delay
function calculateDelay(numPreviousTries: number): number {
const baseDelay = Math.min(MAX_DELAY, MIN_DELAY * Math.pow(2, numPreviousTries));
const jitter = Math.random() * baseDelay;
return baseDelay + jitter;
}
const getCompletionWithRetries = async (
cellId: string,
payload: JSONSerializable,
channel?: string,
): Promise<CompletionResponse> => {
let modelResponse: CompletionResponse | null = null;
try {
for (let i = 0; i < MAX_AUTO_RETRIES; i++) {
modelResponse = await getOpenAIChatCompletion(
payload as unknown as CompletionCreateParams,
channel,
);
if (modelResponse.statusCode !== 429 || i === MAX_AUTO_RETRIES - 1) {
return modelResponse;
}
const delay = calculateDelay(i);
await prisma.scenarioVariantCell.update({
where: { id: cellId },
data: {
errorMessage: "Rate limit exceeded",
statusCode: 429,
retryTime: new Date(Date.now() + delay),
},
});
// TODO: Maybe requeue the job so other jobs can run in the future?
await sleep(delay);
}
throw new Error("Max retries limit reached");
} catch (error: unknown) {
return {
statusCode: modelResponse?.statusCode ?? 500,
errorMessage: modelResponse?.errorMessage ?? (error as Error).message,
output: null,
timeToComplete: 0,
};
}
};
export type queryLLMJob = {
scenarioVariantCellId: string;
};
export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
const { scenarioVariantCellId } = task;
const cell = await prisma.scenarioVariantCell.findUnique({
where: { id: scenarioVariantCellId },
include: { modelOutput: true },
});
if (!cell) {
// The cell does not exist, so there is no row to mark as errored; an update here would throw.
console.error(`queryLLM: cell ${scenarioVariantCellId} not found`);
return;
}
// If cell is not pending, then some other job is already processing it
if (cell.retrievalStatus !== "PENDING") {
return;
}
await prisma.scenarioVariantCell.update({
where: { id: scenarioVariantCellId },
data: {
retrievalStatus: "IN_PROGRESS",
},
});
const variant = await prisma.promptVariant.findUnique({
where: { id: cell.promptVariantId },
});
if (!variant) {
await prisma.scenarioVariantCell.update({
where: { id: scenarioVariantCellId },
data: {
statusCode: 404,
errorMessage: "Prompt Variant not found",
retrievalStatus: "ERROR",
},
});
return;
}
const scenario = await prisma.testScenario.findUnique({
where: { id: cell.testScenarioId },
});
if (!scenario) {
await prisma.scenarioVariantCell.update({
where: { id: scenarioVariantCellId },
data: {
statusCode: 404,
errorMessage: "Scenario not found",
retrievalStatus: "ERROR",
},
});
return;
}
const prompt = await constructPrompt(variant, scenario.variableValues);
const streamingEnabled = shouldStream(prompt);
let streamingChannel;
if (streamingEnabled) {
streamingChannel = generateChannel();
// Save streaming channel so that UI can connect to it
await prisma.scenarioVariantCell.update({
where: { id: scenarioVariantCellId },
data: {
streamingChannel,
},
});
}
const modelResponse = await getCompletionWithRetries(
scenarioVariantCellId,
prompt,
streamingChannel,
);
let modelOutput = null;
if (modelResponse.statusCode === 200) {
const inputHash = crypto.createHash("sha256").update(JSON.stringify(prompt)).digest("hex");
modelOutput = await prisma.modelOutput.create({
data: {
scenarioVariantCellId,
inputHash,
output: modelResponse.output as unknown as Prisma.InputJsonObject,
timeToComplete: modelResponse.timeToComplete,
promptTokens: modelResponse.promptTokens,
completionTokens: modelResponse.completionTokens,
cost: modelResponse.cost,
},
});
}
await prisma.scenarioVariantCell.update({
where: { id: scenarioVariantCellId },
data: {
statusCode: modelResponse.statusCode,
errorMessage: modelResponse.errorMessage,
streamingChannel: null,
retrievalStatus: modelOutput ? "COMPLETE" : "ERROR",
// Only connect when an output was created; an undefined id is not a valid unique filter
...(modelOutput ? { modelOutput: { connect: { id: modelOutput.id } } } : {}),
},
});
if (modelOutput) {
await runEvalsForOutput(variant.experimentId, scenario, modelOutput);
}
});
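
To make the retry timing in `calculateDelay` concrete, here is a standalone sketch of the same formula. The injectable `random` parameter is an addition for illustration (so the bounds can be inspected deterministically) and is not in the original function.

```typescript
const MIN_DELAY = 500; // milliseconds
const MAX_DELAY = 15000; // milliseconds

// Same formula as the task file; `random` is a hypothetical parameter
// that defaults to Math.random.
function calculateDelay(numPreviousTries: number, random: () => number = Math.random): number {
  const baseDelay = Math.min(MAX_DELAY, MIN_DELAY * Math.pow(2, numPreviousTries));
  const jitter = random() * baseDelay;
  return baseDelay + jitter;
}

// Lower bound of the delay for attempts 0..5 (jitter forced to 0):
// 500, 1000, 2000, 4000, 8000, 15000
const lowerBounds = [0, 1, 2, 3, 4, 5].map((n) => calculateDelay(n, () => 0));
```

Because the jitter is added on top of the capped base, the actual wait can approach twice `MAX_DELAY` (just under 30 seconds) once the cap is reached.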


@@ -0,0 +1,40 @@
import { type TaskList, run } from "graphile-worker";
import "dotenv/config";
import { env } from "~/env.mjs";
import { queryLLM } from "./queryLLM.task";
const registeredTasks = [queryLLM];
const taskList = registeredTasks.reduce((acc, task) => {
acc[task.task.identifier] = task.task.handler;
return acc;
}, {} as TaskList);
async function main() {
// Run a worker to execute jobs:
const runner = await run({
connectionString: env.DATABASE_URL,
concurrency: 20,
// Install signal handlers for graceful shutdown on SIGINT, SIGTERM, etc
noHandleSignals: false,
pollInterval: 1000,
// you can set the taskList or taskDirectory but not both
taskList,
// or:
// taskDirectory: `${__dirname}/tasks`,
});
// Immediately await (or otherwise handle) the resulting promise, to avoid
// "unhandled rejection" errors causing a process crash in the event of
// something going wrong.
await runner.promise;
// If the worker exits (whether through fatal error or otherwise), the above
// promise will resolve/reject.
}
main().catch((err) => {
console.error("Unhandled error occurred running worker: ", err);
process.exit(1);
});


@@ -9,7 +9,7 @@ export async function constructPrompt(
scenario: TestScenario["variableValues"],
): Promise<JSONSerializable> {
const code = `
const scenario = ${JSON.stringify(scenario, null, 2)};
const scenario = ${JSON.stringify(scenario ?? {}, null, 2)};
let prompt
${variant.constructFn}


@@ -0,0 +1,123 @@
import { type PromptVariant } from "@prisma/client";
import { type SupportedModel } from "../types";
import ivm from "isolated-vm";
import dedent from "dedent";
import { openai } from "./openai";
import { getApiShapeForModel } from "./getTypesForModel";
import { isObject } from "lodash-es";
import { type CompletionCreateParams } from "openai/resources/chat/completions";
const isolate = new ivm.Isolate({ memoryLimit: 128 });
export async function deriveNewConstructFn(
originalVariant: PromptVariant | null,
newModel?: SupportedModel,
instructions?: string,
) {
if (originalVariant && !newModel && !instructions) {
return originalVariant.constructFn;
}
if (originalVariant && (newModel || instructions)) {
return await requestUpdatedPromptFunction(originalVariant, newModel, instructions);
}
return dedent`
prompt = {
model: "gpt-3.5-turbo",
messages: [
{
role: "system",
content: "Return 'Hello, world!'",
}
]
}`;
}
const NUM_RETRIES = 5;
const requestUpdatedPromptFunction = async (
originalVariant: PromptVariant,
newModel?: SupportedModel,
instructions?: string,
) => {
const originalModel = originalVariant.model as SupportedModel;
let newContructionFn = "";
for (let i = 0; i < NUM_RETRIES; i++) {
try {
const messages: CompletionCreateParams.CreateChatCompletionRequestNonStreaming.Message[] = [
{
role: "system",
content: `Your job is to update prompt constructor functions. Here is the api shape for the current model:\n---\n${JSON.stringify(
getApiShapeForModel(originalModel),
null,
2,
)}`,
},
];
if (newModel) {
messages.push({
role: "user",
content: `Return the prompt constructor function for ${newModel} given the following prompt constructor function for ${originalModel}:\n---\n${originalVariant.constructFn}`,
});
}
if (instructions) {
messages.push({
role: "user",
content: `Follow these instructions: ${instructions}`,
});
}
messages.push({
role: "user",
content:
"The prompt variable has already been declared, so do not declare it again. Rewrite the entire prompt constructor function.",
});
const completion = await openai.chat.completions.create({
model: "gpt-4",
messages,
functions: [
{
name: "update_prompt_constructor_function",
parameters: {
type: "object",
properties: {
new_prompt_function: {
type: "string",
description: "The new prompt function, runnable in typescript",
},
},
},
},
],
function_call: {
name: "update_prompt_constructor_function",
},
});
const argString = completion.choices[0]?.message?.function_call?.arguments || "{}";
const code = `
global.contructPromptFunctionArgs = ${argString};
`;
const context = await isolate.createContext();
const jail = context.global;
await jail.set("global", jail.derefInto());
const script = await isolate.compileScript(code);
await script.run(context);
const contructPromptFunctionArgs = (await context.global.get(
"contructPromptFunctionArgs",
)) as ivm.Reference;
const args = await contructPromptFunctionArgs.copy(); // Get the actual value from the isolate
if (args && isObject(args) && "new_prompt_function" in args) {
newContructionFn = args.new_prompt_function as string;
break;
}
} catch (e) {
console.error(e);
}
}
return newContructionFn;
};


@@ -1,31 +0,0 @@
import { type Evaluation, type ModelOutput, type TestScenario } from "@prisma/client";
import { type ChatCompletion } from "openai/resources/chat";
import { type VariableMap, fillTemplate } from "./fillTemplate";
export const evaluateOutput = (
modelOutput: ModelOutput,
scenario: TestScenario,
evaluation: Evaluation,
): boolean => {
const output = modelOutput.output as unknown as ChatCompletion;
const message = output?.choices?.[0]?.message;
if (!message) return false;
const stringifiedMessage = message.content ?? JSON.stringify(message.function_call);
const matchRegex = fillTemplate(evaluation.matchString, scenario.variableValues as VariableMap);
let match;
switch (evaluation.matchType) {
case "CONTAINS":
match = stringifiedMessage.match(matchRegex) !== null;
break;
case "DOES_NOT_CONTAIN":
match = stringifiedMessage.match(matchRegex) === null;
break;
}
return match;
};


@@ -1,103 +1,79 @@
import { type Evaluation } from "@prisma/client";
import { type ModelOutput, type Evaluation } from "@prisma/client";
import { prisma } from "../db";
import { evaluateOutput } from "./evaluateOutput";
import { runOneEval } from "./runOneEval";
import { type Scenario } from "~/components/OutputsTable/types";
export const reevaluateVariant = async (variantId: string) => {
const variant = await prisma.promptVariant.findUnique({
where: { id: variantId },
});
if (!variant) return;
const evaluations = await prisma.evaluation.findMany({
where: { experimentId: variant.experimentId },
});
const modelOutputs = await prisma.modelOutput.findMany({
const saveResult = async (evaluation: Evaluation, scenario: Scenario, modelOutput: ModelOutput) => {
const result = await runOneEval(evaluation, scenario, modelOutput);
return await prisma.outputEvaluation.upsert({
where: {
promptVariantId: variantId,
statusCode: { notIn: [429] },
testScenario: { visible: true },
modelOutputId_evaluationId: {
modelOutputId: modelOutput.id,
evaluationId: evaluation.id,
},
},
create: {
modelOutputId: modelOutput.id,
evaluationId: evaluation.id,
...result,
},
update: {
...result,
},
include: { testScenario: true },
});
await Promise.all(
evaluations.map(async (evaluation) => {
const passCount = modelOutputs.filter((output) =>
evaluateOutput(output, output.testScenario, evaluation),
).length;
const failCount = modelOutputs.length - passCount;
await prisma.evaluationResult.upsert({
where: {
evaluationId_promptVariantId: {
evaluationId: evaluation.id,
promptVariantId: variantId,
},
},
create: {
evaluationId: evaluation.id,
promptVariantId: variantId,
passCount,
failCount,
},
update: {
passCount,
failCount,
},
});
}),
);
};
export const reevaluateEvaluation = async (evaluation: Evaluation) => {
const variants = await prisma.promptVariant.findMany({
where: { experimentId: evaluation.experimentId, visible: true },
});
const modelOutputs = await prisma.modelOutput.findMany({
where: {
promptVariantId: { in: variants.map((v) => v.id) },
testScenario: { visible: true },
statusCode: { notIn: [429] },
},
include: { testScenario: true },
});
await Promise.all(
variants.map(async (variant) => {
const outputs = modelOutputs.filter((output) => output.promptVariantId === variant.id);
const passCount = outputs.filter((output) =>
evaluateOutput(output, output.testScenario, evaluation),
).length;
const failCount = outputs.length - passCount;
await prisma.evaluationResult.upsert({
where: {
evaluationId_promptVariantId: {
evaluationId: evaluation.id,
promptVariantId: variant.id,
},
},
create: {
evaluationId: evaluation.id,
promptVariantId: variant.id,
passCount,
failCount,
},
update: {
passCount,
failCount,
},
});
}),
);
};
export const reevaluateAll = async (experimentId: string) => {
export const runEvalsForOutput = async (
experimentId: string,
scenario: Scenario,
modelOutput: ModelOutput,
) => {
const evaluations = await prisma.evaluation.findMany({
where: { experimentId },
});
await Promise.all(evaluations.map(reevaluateEvaluation));
await Promise.all(
evaluations.map(async (evaluation) => await saveResult(evaluation, scenario, modelOutput)),
);
};
export const runAllEvals = async (experimentId: string) => {
const outputs = await prisma.modelOutput.findMany({
where: {
scenarioVariantCell: {
promptVariant: {
experimentId,
visible: true,
},
testScenario: {
visible: true,
},
},
},
include: {
scenarioVariantCell: {
include: {
testScenario: true,
},
},
outputEvaluation: true,
},
});
const evals = await prisma.evaluation.findMany({
where: { experimentId },
});
await Promise.all(
outputs.map(async (output) => {
const unrunEvals = evals.filter(
(evaluation) => !output.outputEvaluation.find((e) => e.evaluationId === evaluation.id),
);
await Promise.all(
unrunEvals.map(async (evaluation) => {
await saveResult(evaluation, output.scenarioVariantCell.testScenario, output);
}),
);
}),
);
};


@@ -2,6 +2,16 @@ import { type JSONSerializable } from "../types";
export type VariableMap = Record<string, string>;
// Escape quotes to match the way we encode JSON
export function escapeQuotes(str: string) {
return str.replace(/(\\")|"/g, (match, p1) => (p1 ? match : '\\"'));
}
// Escape regex special characters
export function escapeRegExp(str: string) {
return str.replace(/[.*+\-?^${}()|[\]\\]/g, "\\$&"); // $& means the whole matched string
}
export function fillTemplate(template: string, variables: VariableMap): string {
return template.replace(/{{\s*(\w+)\s*}}/g, (_, key: string) => variables[key] || "");
}
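
How the new escape helpers are consumed is not shown in this hunk. One plausible use, sketched under that assumption, is to fill a template and then match the result as literal text rather than as a regex pattern. `containsLiteral` is a hypothetical helper, and the template regex is written with escaped braces, which is equivalent to the original.

```typescript
type VariableMap = Record<string, string>;

// Escape regex special characters so the string matches literally.
function escapeRegExp(str: string): string {
  return str.replace(/[.*+\-?^${}()|[\]\\]/g, "\\$&");
}

// Replace {{ key }} placeholders; unknown keys become the empty string.
function fillTemplate(template: string, variables: VariableMap): string {
  return template.replace(/\{\{\s*(\w+)\s*\}\}/g, (_, key: string) => variables[key] || "");
}

// Hypothetical helper: treat the filled template as literal text, not a pattern.
function containsLiteral(output: string, template: string, variables: VariableMap): boolean {
  const literal = escapeRegExp(fillTemplate(template, variables));
  return new RegExp(literal).test(output);
}
```

Without the escaping step, a value such as `$5.00` would be read as a pattern (`$` anchors the end of input and `.` matches any character), so it would not match its own literal text.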


@@ -0,0 +1,77 @@
import crypto from "crypto";
import { type Prisma } from "@prisma/client";
import { prisma } from "../db";
import { queueLLMRetrievalTask } from "./queueLLMRetrievalTask";
import { constructPrompt } from "./constructPrompt";
export const generateNewCell = async (variantId: string, scenarioId: string) => {
const variant = await prisma.promptVariant.findUnique({
where: {
id: variantId,
},
});
const scenario = await prisma.testScenario.findUnique({
where: {
id: scenarioId,
},
});
if (!variant || !scenario) return null;
const prompt = await constructPrompt(variant, scenario.variableValues);
const inputHash = crypto.createHash("sha256").update(JSON.stringify(prompt)).digest("hex");
let cell = await prisma.scenarioVariantCell.findUnique({
where: {
promptVariantId_testScenarioId: {
promptVariantId: variantId,
testScenarioId: scenarioId,
},
},
include: {
modelOutput: true,
},
});
if (cell) return cell;
cell = await prisma.scenarioVariantCell.create({
data: {
promptVariantId: variantId,
testScenarioId: scenarioId,
},
include: {
modelOutput: true,
},
});
// Reuse a previous output if an identical prompt (same input hash) has already been run
const matchingModelOutput = await prisma.modelOutput.findFirst({
where: {
inputHash,
},
});
let newModelOutput;
if (matchingModelOutput) {
newModelOutput = await prisma.modelOutput.create({
data: {
scenarioVariantCellId: cell.id,
inputHash,
output: matchingModelOutput.output as Prisma.InputJsonValue,
timeToComplete: matchingModelOutput.timeToComplete,
cost: matchingModelOutput.cost,
promptTokens: matchingModelOutput.promptTokens,
completionTokens: matchingModelOutput.completionTokens,
createdAt: matchingModelOutput.createdAt,
updatedAt: matchingModelOutput.updatedAt,
},
});
} else {
cell = await queueLLMRetrievalTask(cell.id);
}
return { ...cell, modelOutput: newModelOutput };
};

@@ -1,24 +1,25 @@
/* eslint-disable @typescript-eslint/no-unsafe-call */
import { isObject } from "lodash";
import { Prisma } from "@prisma/client";
import { isObject } from "lodash-es";
import { streamChatCompletion } from "./openai";
import { wsConnection } from "~/utils/wsConnection";
import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat";
import { type OpenAIChatModel } from "../types";
import { type SupportedModel, type OpenAIChatModel } from "../types";
import { env } from "~/env.mjs";
import { countOpenAIChatTokens } from "~/utils/countTokens";
import { rateLimitErrorMessage } from "~/sharedStrings";
import { modelStats } from "../modelStats";
type CompletionResponse = {
output: Prisma.InputJsonValue | typeof Prisma.JsonNull;
export type CompletionResponse = {
output: ChatCompletion | null;
statusCode: number;
errorMessage: string | null;
timeToComplete: number;
promptTokens?: number;
completionTokens?: number;
cost?: number;
};
export async function getCompletion(
export async function getOpenAIChatCompletion(
payload: CompletionCreateParams,
channel?: string,
): Promise<CompletionResponse> {
@@ -35,7 +36,7 @@ export async function getCompletion(
});
const resp: CompletionResponse = {
output: Prisma.JsonNull,
output: null,
errorMessage: null,
statusCode: response.status,
timeToComplete: 0,
@@ -52,7 +53,7 @@ export async function getCompletion(
}
})().catch((err) => console.error(err));
if (finalOutput) {
resp.output = finalOutput as unknown as Prisma.InputJsonValue;
resp.output = finalOutput;
resp.timeToComplete = Date.now() - start;
}
} else {
@@ -88,6 +89,13 @@ export async function getCompletion(
resp.completionTokens = countOpenAIChatTokens(model, messages);
}
}
const stats = modelStats[resp.output?.model as SupportedModel];
if (stats && resp.promptTokens !== undefined && resp.completionTokens !== undefined) {
resp.cost =
resp.promptTokens * stats.promptTokenPrice +
resp.completionTokens * stats.completionTokenPrice;
}
} catch (e) {
console.error(e);
if (response.ok) {

@@ -0,0 +1,7 @@
import { OpenAIChatModel, type SupportedModel } from "../types";
import openAIChatApiShape from "~/codegen/openai.types.ts.txt";
export const getApiShapeForModel = (model: SupportedModel) => {
if (model in OpenAIChatModel) return openAIChatApiShape;
return "";
};

@@ -1,4 +1,4 @@
import { omit } from "lodash";
import { omit } from "lodash-es";
import { env } from "~/env.mjs";
import OpenAI from "openai";

@@ -0,0 +1,22 @@
import { prisma } from "../db";
import { queryLLM } from "../tasks/queryLLM.task";
export const queueLLMRetrievalTask = async (cellId: string) => {
const updatedCell = await prisma.scenarioVariantCell.update({
where: {
id: cellId,
},
data: {
retrievalStatus: "PENDING",
errorMessage: null,
},
include: {
modelOutput: true,
},
});
// @ts-expect-error we aren't passing the helpers but that's ok
void queryLLM.task.handler({ scenarioVariantCellId: cellId }, { logger: console });
return updatedCell;
};

@@ -0,0 +1,65 @@
import { prisma } from "~/server/db";
export const reorderPromptVariants = async (
movedId: string,
stationaryTargetId: string,
alwaysInsertRight?: boolean,
) => {
const moved = await prisma.promptVariant.findUnique({
where: {
id: movedId,
},
});
const target = await prisma.promptVariant.findUnique({
where: {
id: stationaryTargetId,
},
});
if (!moved || !target || moved.experimentId !== target.experimentId) {
throw new Error(
`Prompt Variant with id ${movedId} or ${stationaryTargetId} does not exist, or the two belong to different experiments`,
);
}
const visibleItems = await prisma.promptVariant.findMany({
where: {
experimentId: moved.experimentId,
visible: true,
},
orderBy: {
sortIndex: "asc",
},
});
// Remove the moved item from its current position
const orderedItems = visibleItems.filter((item) => item.id !== moved.id);
// Find the index of the moved item and the target item
const movedIndex = visibleItems.findIndex((item) => item.id === moved.id);
const targetIndex = visibleItems.findIndex((item) => item.id === target.id);
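// Determine the new index for the moved item. orderedItems no longer contains the
// moved item, so a target that sat to its right has shifted left by one.
const targetIndexAfterRemoval = movedIndex < targetIndex ? targetIndex - 1 : targetIndex;
let newIndex;
if (movedIndex < targetIndex || alwaysInsertRight) {
newIndex = targetIndexAfterRemoval + 1; // Insert after the target item
} else {
newIndex = targetIndexAfterRemoval; // Insert before the target item
}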
// Insert the moved item at the new position
orderedItems.splice(newIndex, 0, moved);
// Now, we need to update all the items with their new sortIndex
await prisma.$transaction(
orderedItems.map((item, index) => {
return prisma.promptVariant.update({
where: {
id: item.id,
},
data: {
sortIndex: index,
},
});
}),
);
};
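
A minimal sketch of the insert-before / insert-after arithmetic as a pure function over ids, so it can be checked without a database. `reorderIds` is a hypothetical helper, not part of this diff. It looks up the target's position after the moved id has been removed, so "after the target" always means immediately after it.

```typescript
// Hypothetical pure helper: returns the new id order after moving `movedId`
// next to `targetId`. The array index of each id plays the role of sortIndex.
function reorderIds(
  ids: string[],
  movedId: string,
  targetId: string,
  alwaysInsertRight = false,
): string[] {
  const movedIndex = ids.indexOf(movedId);
  const targetIndex = ids.indexOf(targetId);
  if (movedIndex === -1 || targetIndex === -1) {
    throw new Error(`Unknown id: ${movedId} or ${targetId}`);
  }
  // Moving an item onto itself is a no-op.
  if (movedId === targetId) return [...ids];

  // Remove the moved item, then find where the target sits in what is left.
  const remaining = ids.filter((id) => id !== movedId);
  const targetAfterRemoval = remaining.indexOf(targetId);

  // Moving rightwards (or when forced) lands after the target, otherwise before it.
  const insertAfter = movedIndex < targetIndex || alwaysInsertRight;
  remaining.splice(insertAfter ? targetAfterRemoval + 1 : targetAfterRemoval, 0, movedId);
  return remaining;
}

const movedRight = reorderIds(["a", "b", "c", "d"], "a", "c");
const movedLeft = reorderIds(["a", "b", "c", "d"], "d", "b");
const forcedRight = reorderIds(["a", "b", "c", "d"], "d", "b", true);
console.log(movedRight, movedLeft, forcedRight);
```

Keeping the ordering logic separate from the Prisma transaction makes the edge cases (adjacent items, forced insert-right) easy to unit test.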

@@ -0,0 +1,95 @@
import { type Evaluation, type ModelOutput, type TestScenario } from "@prisma/client";
import { type ChatCompletion } from "openai/resources/chat";
import { type VariableMap, fillTemplate, escapeRegExp, escapeQuotes } from "./fillTemplate";
import { openai } from "./openai";
import dedent from "dedent";
export const runGpt4Eval = async (
evaluation: Evaluation,
scenario: TestScenario,
message: ChatCompletion.Choice.Message,
): Promise<{ result: number; details: string }> => {
const output = await openai.chat.completions.create({
model: "gpt-4-0613",
messages: [
{
role: "system",
content: dedent`
You are a highly intelligent AI model and have been tasked with evaluating the quality of a simpler model. Your objective is to determine whether the simpler model has produced a successful and correct output. You should return "true" if the output was successful and "false" if it was not. Pay more attention to the semantics of the output than the formatting. Success is defined in the following terms:
---
${evaluation.value}
`,
},
{
role: "user",
content: `Scenario:\n---\n${JSON.stringify(scenario.variableValues, null, 2)}`,
},
{
role: "user",
content: `The full output of the simpler model:\n---\n${JSON.stringify(
message.content ?? message.function_call,
null,
2,
)}`,
},
],
function_call: {
name: "report_success",
},
functions: [
{
name: "report_success",
parameters: {
type: "object",
required: ["thoughts", "success"],
properties: {
thoughts: {
type: "string",
description: "Explain your reasoning for considering this a pass or fail",
},
success: {
type: "boolean",
description:
"Whether the simpler model successfully completed the task for this scenario",
},
},
},
},
],
});
try {
const out = JSON.parse(output.choices[0]?.message?.function_call?.arguments ?? "");
return { result: out.success ? 1 : 0, details: out.thoughts ?? JSON.stringify(out) };
} catch (e) {
console.error(e);
return { result: 0, details: "Error parsing GPT-4 output" };
}
};
export const runOneEval = async (
evaluation: Evaluation,
scenario: TestScenario,
modelOutput: ModelOutput,
): Promise<{ result: number; details?: string }> => {
const output = modelOutput.output as unknown as ChatCompletion;
const message = output?.choices?.[0]?.message;
if (!message) return { result: 0 };
const stringifiedMessage = message.content ?? JSON.stringify(message.function_call);
const matchRegex = escapeRegExp(
fillTemplate(escapeQuotes(evaluation.value), scenario.variableValues as VariableMap),
);
switch (evaluation.evalType) {
case "CONTAINS":
return { result: stringifiedMessage.match(matchRegex) !== null ? 1 : 0 };
case "DOES_NOT_CONTAIN":
return { result: stringifiedMessage.match(matchRegex) === null ? 1 : 0 };
case "GPT4_EVAL":
return await runGpt4Eval(evaluation, scenario, message);
}
};
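
A minimal sketch of the parsing step at the end of runGpt4Eval, pulled out into a standalone function so the fallback behaviour is easy to see. `parseEvalArguments` is a hypothetical name, and the sample argument strings are made up; only the `success` / `thoughts` fields and the error message come from the hunk above.

```typescript
type EvalResult = { result: number; details: string };

// Hypothetical helper mirroring the try/catch in runGpt4Eval: the function-call
// arguments arrive as a JSON string and may be missing or malformed.
function parseEvalArguments(args: string | undefined): EvalResult {
  try {
    const out = JSON.parse(args ?? "") as { success?: boolean; thoughts?: string };
    return { result: out.success ? 1 : 0, details: out.thoughts ?? JSON.stringify(out) };
  } catch {
    // JSON.parse("") throws, so a missing function call ends up here too.
    return { result: 0, details: "Error parsing GPT-4 output" };
  }
}

const passed = parseEvalArguments('{"thoughts":"Output matches the criteria","success":true}');
const missing = parseEvalArguments(undefined);
const withoutThoughts = parseEvalArguments('{"success":false}');
console.log(passed, missing, withoutThoughts);
```

When `thoughts` is absent, the whole parsed object is stringified into `details`, so the stored record still shows what the evaluator returned.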

@@ -0,0 +1,7 @@
import { isObject } from "lodash-es";
import { type JSONSerializable } from "../types";
export const shouldStream = (config: JSONSerializable): boolean => {
return isObject(config) && "stream" in config && config.stream === true;
};

@@ -0,0 +1 @@
export const sleep = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms));

@@ -0,0 +1,19 @@
import { prisma } from "~/server/db";
export default async function userOrg(userId: string) {
return await prisma.organization.upsert({
where: {
personalOrgUserId: userId,
},
update: {},
create: {
personalOrgUserId: userId,
OrganizationUser: {
create: {
userId: userId,
role: "ADMIN",
},
},
},
});
}

@@ -33,6 +33,7 @@ export const createVariantEditorSlice: SliceCreator<SharedVariantEditorSlice> =
monaco.languages.typescript.typescriptDefaults.setCompilerOptions({
allowNonTsExtensions: true,
strictNullChecks: true,
lib: ["esnext"],
});
@@ -84,7 +85,7 @@ export const createVariantEditorSlice: SliceCreator<SharedVariantEditorSlice> =
)} as const;
type Scenario = typeof scenarios[number];
declare var scenario: Scenario | null;
declare var scenario: Scenario | { [key: string]: string };
`;
const scenariosModel = monaco.editor.getModel(monaco.Uri.parse("file:///scenarios.ts"));

@@ -0,0 +1,49 @@
import { OrganizationUserRole } from "@prisma/client";
import { TRPCError } from "@trpc/server";
import { type TRPCContext } from "~/server/api/trpc";
import { prisma } from "~/server/db";
// No-op method for protected routes that really should be accessible to anyone.
export const requireNothing = (ctx: TRPCContext) => {
ctx.markAccessControlRun();
};
export const requireCanViewExperiment = async (experimentId: string, ctx: TRPCContext) => {
await prisma.experiment.findFirst({
where: { id: experimentId },
});
// Right now all experiments are publicly viewable, so this is a no-op.
ctx.markAccessControlRun();
};
export const canModifyExperiment = async (experimentId: string, userId: string) => {
const experiment = await prisma.experiment.findFirst({
where: {
id: experimentId,
organization: {
OrganizationUser: {
some: {
role: { in: [OrganizationUserRole.ADMIN, OrganizationUserRole.MEMBER] },
userId,
},
},
},
},
});
return !!experiment;
};
export const requireCanModifyExperiment = async (experimentId: string, ctx: TRPCContext) => {
const userId = ctx.session?.user.id;
if (!userId) {
throw new TRPCError({ code: "UNAUTHORIZED" });
}
if (!(await canModifyExperiment(experimentId, userId))) {
// The caller is signed in but lacks permission on this experiment.
throw new TRPCError({ code: "FORBIDDEN" });
}
ctx.markAccessControlRun();
};

@@ -1,46 +0,0 @@
import { type SupportedModel, OpenAIChatModel } from "~/server/types";
const openAIPromptTokensToDollars: { [key in OpenAIChatModel]: number } = {
"gpt-4": 0.00003,
"gpt-4-0613": 0.00003,
"gpt-4-32k": 0.00006,
"gpt-4-32k-0613": 0.00006,
"gpt-3.5-turbo": 0.0000015,
"gpt-3.5-turbo-0613": 0.0000015,
"gpt-3.5-turbo-16k": 0.000003,
"gpt-3.5-turbo-16k-0613": 0.000003,
};
const openAICompletionTokensToDollars: { [key in OpenAIChatModel]: number } = {
"gpt-4": 0.00006,
"gpt-4-0613": 0.00006,
"gpt-4-32k": 0.00012,
"gpt-4-32k-0613": 0.00012,
"gpt-3.5-turbo": 0.000002,
"gpt-3.5-turbo-0613": 0.000002,
"gpt-3.5-turbo-16k": 0.000004,
"gpt-3.5-turbo-16k-0613": 0.000004,
};
export const calculateTokenCost = (
model: SupportedModel | string | null,
numTokens: number,
isCompletion = false,
) => {
if (!model) return 0;
if (model in OpenAIChatModel) {
return calculateOpenAIChatTokenCost(model as OpenAIChatModel, numTokens, isCompletion);
}
return 0;
};
const calculateOpenAIChatTokenCost = (
model: OpenAIChatModel,
numTokens: number,
isCompletion: boolean,
) => {
const tokensToDollars = isCompletion
? openAICompletionTokensToDollars[model]
: openAIPromptTokensToDollars[model];
return tokensToDollars * numTokens;
};
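
Both this removed helper and the inline computation added to getOpenAIChatCompletion reduce to token counts multiplied by per-token prices. A minimal sketch of that arithmetic follows. The `prices` map and the `estimateCost` name are hypothetical (the `modelStats` module referenced elsewhere in this diff is not shown here); the per-token prices are copied from the table deleted above.

```typescript
type TokenPrices = { promptTokenPrice: number; completionTokenPrice: number };

// Per-token prices in dollars, copied from the deleted table for two models.
const prices: Record<string, TokenPrices> = {
  "gpt-4": { promptTokenPrice: 0.00003, completionTokenPrice: 0.00006 },
  "gpt-3.5-turbo": { promptTokenPrice: 0.0000015, completionTokenPrice: 0.000002 },
};

// Hypothetical helper: returns undefined for models without a price entry
// instead of silently reporting a cost of zero.
function estimateCost(
  model: string,
  promptTokens: number,
  completionTokens: number,
): number | undefined {
  const stats = prices[model];
  if (!stats) return undefined;
  return promptTokens * stats.promptTokenPrice + completionTokens * stats.completionTokenPrice;
}

// 1000 prompt tokens and 500 completion tokens on gpt-4 come to roughly $0.06.
const gpt4Cost = estimateCost("gpt-4", 1000, 500);
const unknownCost = estimateCost("some-other-model", 1000, 500);
console.log(gpt4Cost, unknownCost);
```

Computing the cost once and storing it alongside the output, as the commit message describes, avoids recomputing it from a price table that may differ between API providers.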

@@ -5,16 +5,5 @@ import relativeTime from "dayjs/plugin/relativeTime";
dayjs.extend(duration);
dayjs.extend(relativeTime);
export const formatTimePast = (date: Date) => {
const now = dayjs();
const dayDiff = Math.floor(now.diff(date, "day"));
if (dayDiff > 0) return dayjs.duration(-dayDiff, "days").humanize(true);
const hourDiff = Math.floor(now.diff(date, "hour"));
if (hourDiff > 0) return dayjs.duration(-hourDiff, "hours").humanize(true);
const minuteDiff = Math.floor(now.diff(date, "minute"));
if (minuteDiff > 0) return dayjs.duration(-minuteDiff, "minutes").humanize(true);
return "a few seconds ago";
};
export const formatTimePast = (date: Date) =>
dayjs.duration(dayjs(date).diff(dayjs())).humanize(true);

@@ -1,5 +1,5 @@
import { useRouter } from "next/router";
import { useCallback, useEffect, useState } from "react";
import { type RefObject, useCallback, useEffect, useRef, useState } from "react";
import { api } from "~/utils/api";
export const useExperiment = () => {
@@ -12,6 +12,10 @@ export const useExperiment = () => {
return experiment;
};
export const useExperimentAccess = () => {
return useExperiment().data?.access ?? { canView: false, canModify: false };
};
type AsyncFunction<T extends unknown[], U> = (...args: T) => Promise<U>;
export function useHandledAsyncCallback<T extends unknown[], U>(
@@ -49,3 +53,43 @@ export const useModifierKeyLabel = () => {
}, []);
return label;
};
interface Dimensions {
left: number;
top: number;
right: number;
bottom: number;
width: number;
height: number;
x: number;
y: number;
}
// get dimensions of an element
export const useElementDimensions = (): [RefObject<HTMLElement>, Dimensions | undefined] => {
const ref = useRef<HTMLElement>(null);
const [dimensions, setDimensions] = useState<Dimensions | undefined>();
useEffect(() => {
if (ref.current) {
const observer = new ResizeObserver((entries) => {
entries.forEach((entry) => {
setDimensions(entry.contentRect);
});
});
const observedRef = ref.current;
observer.observe(observedRef);
// Cleanup the observer on component unmount
return () => {
if (observedRef) {
observer.unobserve(observedRef);
}
};
}
}, []);
return [ref, dimensions];
};

@@ -1,4 +1,5 @@
import { extendTheme } from "@chakra-ui/react";
import "@fontsource/inconsolata";
const systemFont =
'ui-sans-serif, -apple-system, "system-ui", "Segoe UI", Helvetica, "Apple Color Emoji", Arial, sans-serif, "Segoe UI Emoji", "Segoe UI Symbol"';

@@ -5,7 +5,7 @@ import { env } from "~/env.mjs";
const url = env.NEXT_PUBLIC_SOCKET_URL;
export default function useSocket(channel?: string) {
export default function useSocket(channel?: string | null) {
const socketRef = useRef<Socket>();
const [message, setMessage] = useState<ChatCompletion | null>(null);