Compare commits
47 Commits
autoformat
...
no-model
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
60765e51ac | ||
|
|
4c97b9f147 | ||
|
|
58892d8b63 | ||
|
|
4fa2dffbcb | ||
|
|
654f8c7cf2 | ||
|
|
d02482468d | ||
|
|
5c6ed22f1d | ||
|
|
2cb623f332 | ||
|
|
1c1cefe286 | ||
|
|
b4aa95edca | ||
|
|
1dcdba04a6 | ||
|
|
e0e64c4207 | ||
|
|
fa5b1ab1c5 | ||
|
|
999a4c08fa | ||
|
|
374d0237ee | ||
|
|
b1f873623d | ||
|
|
4131aa67d0 | ||
|
|
8e7a6d3ae2 | ||
|
|
7d41e94ca2 | ||
|
|
011b12abb9 | ||
|
|
1ba18015bc | ||
|
|
54369dba54 | ||
|
|
6b84a59372 | ||
|
|
8db8aeacd3 | ||
|
|
64bd71e370 | ||
|
|
ca21a7af06 | ||
|
|
3b99b7bd2b | ||
|
|
0c3bdbe4f2 | ||
|
|
74c201d3a8 | ||
|
|
ab9c721d09 | ||
|
|
0a2578a1d8 | ||
|
|
1bebaff386 | ||
|
|
3bf5eaf4a2 | ||
|
|
ded97f8bb9 | ||
|
|
26ee8698be | ||
|
|
b98eb9b729 | ||
|
|
032c07ec65 | ||
|
|
80c0d13bb9 | ||
|
|
f7c94be3f6 | ||
|
|
c3e85607e0 | ||
|
|
cd5927b8f5 | ||
|
|
731406d1f4 | ||
|
|
3c59e4b774 | ||
|
|
972b1f2333 | ||
|
|
7321f3deda | ||
|
|
2bd41fdfbf | ||
|
|
a5378b106b |
@@ -18,3 +18,11 @@ DATABASE_URL="postgresql://postgres:postgres@localhost:5432/openpipe?schema=publ
|
||||
OPENAI_API_KEY=""
|
||||
|
||||
NEXT_PUBLIC_SOCKET_URL="http://localhost:3318"
|
||||
|
||||
# Next Auth
|
||||
NEXTAUTH_SECRET="your_secret"
|
||||
NEXTAUTH_URL="http://localhost:3000"
|
||||
|
||||
# Next Auth Github Provider
|
||||
GITHUB_CLIENT_ID="your_client_id"
|
||||
GITHUB_CLIENT_SECRET="your_secret"
|
||||
|
||||
53
.github/workflows/ci.yaml
vendored
Normal file
53
.github/workflows/ci.yaml
vendored
Normal file
@@ -0,0 +1,53 @@
|
||||
name: CI checks
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches: [main]
|
||||
push:
|
||||
branches: [main]
|
||||
|
||||
jobs:
|
||||
run-checks:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v2
|
||||
with:
|
||||
node-version: "20"
|
||||
|
||||
- uses: pnpm/action-setup@v2
|
||||
name: Install pnpm
|
||||
id: pnpm-install
|
||||
with:
|
||||
version: 8.6.1
|
||||
run_install: false
|
||||
|
||||
- name: Get pnpm store directory
|
||||
id: pnpm-cache
|
||||
shell: bash
|
||||
run: |
|
||||
echo "STORE_PATH=$(pnpm store path)" >> $GITHUB_OUTPUT
|
||||
|
||||
- uses: actions/cache@v3
|
||||
name: Setup pnpm cache
|
||||
with:
|
||||
path: ${{ steps.pnpm-cache.outputs.STORE_PATH }}
|
||||
key: ${{ runner.os }}-pnpm-store-${{ hashFiles('**/pnpm-lock.yaml') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-store-
|
||||
|
||||
- name: Install Dependencies
|
||||
run: pnpm install
|
||||
|
||||
- name: Check types
|
||||
run: pnpm tsc
|
||||
|
||||
- name: Lint
|
||||
run: SKIP_ENV_VALIDATION=1 pnpm lint
|
||||
|
||||
- name: Check prettier
|
||||
run: pnpm prettier . --check
|
||||
1
.tool-versions
Normal file
1
.tool-versions
Normal file
@@ -0,0 +1 @@
|
||||
nodejs 20.2.0
|
||||
1
@types/nextjs-routes.d.ts
vendored
1
@types/nextjs-routes.d.ts
vendored
@@ -11,6 +11,7 @@ declare module "nextjs-routes" {
|
||||
} from "next";
|
||||
|
||||
export type Route =
|
||||
| StaticRoute<"/account/signin">
|
||||
| DynamicRoute<"/api/auth/[...nextauth]", { "nextauth": string[] }>
|
||||
| DynamicRoute<"/api/trpc/[trpc]", { "trpc": string }>
|
||||
| DynamicRoute<"/experiments/[id]", { "id": string }>
|
||||
|
||||
16
README.md
16
README.md
@@ -4,11 +4,18 @@
|
||||
|
||||
OpenPipe is a flexible playground for comparing and optimizing LLM prompts. It lets you quickly generate, test and compare candidate prompts with realistic sample data.
|
||||
|
||||
**Live Demo:** https://openpipe.ai
|
||||
## Sample Experiments
|
||||
|
||||
These are simple experiments users have created that show how OpenPipe works.
|
||||
|
||||
- [Country Capitals](https://openpipe.ai/experiments/11111111-1111-1111-1111-111111111111)
|
||||
- [Reddit User Needs](https://openpipe.ai/experiments/22222222-2222-2222-2222-222222222222)
|
||||
- [OpenAI Function Calls](https://openpipe.ai/experiments/2ebbdcb3-ed51-456e-87dc-91f72eaf3e2b)
|
||||
- [Activity Classification](https://openpipe.ai/experiments/3950940f-ab6b-4b74-841d-7e9dbc4e4ff8)
|
||||
|
||||
<img src="https://github.com/openpipe/openpipe/assets/176426/fc7624c6-5b65-4d4d-82b7-4a816f3e5678" alt="demo" height="400px">
|
||||
|
||||
Currently there's a public playground available at [https://openpipe.ai/](https://openpipe.ai/), but the recommended approach is to [run locally](#running-locally).
|
||||
You can use our hosted version of OpenPipe at [https://openpipe.ai]. You can also clone this repository and [run it locally](#running-locally).
|
||||
|
||||
## High-Level Features
|
||||
|
||||
@@ -47,5 +54,6 @@ OpenPipe currently supports GPT-3.5 and GPT-4. Wider model support is planned.
|
||||
5. Install the dependencies: `cd openpipe && pnpm install`
|
||||
6. Create a `.env` file (`cp .env.example .env`) and enter your `OPENAI_API_KEY`.
|
||||
7. Update `DATABASE_URL` if necessary to point to your Postgres instance and run `pnpm prisma db push` to create the database.
|
||||
8. Start the app: `pnpm dev`.
|
||||
9. Navigate to [http://localhost:3000](http://localhost:3000)
|
||||
8. Create a [GitHub OAuth App](https://docs.github.com/en/apps/oauth-apps/building-oauth-apps/creating-an-oauth-app) and update the `GITHUB_CLIENT_ID` and `GITHUB_CLIENT_SECRET` values. (Note: a PR to make auth optional when running locally would be a great contribution!)
|
||||
9. Start the app: `pnpm dev`.
|
||||
10. Navigate to [http://localhost:3000](http://localhost:3000)
|
||||
|
||||
25
package.json
25
package.json
@@ -3,22 +3,32 @@
|
||||
"type": "module",
|
||||
"version": "0.1.0",
|
||||
"license": "Apache-2.0",
|
||||
"engines": {
|
||||
"node": ">=20.0.0",
|
||||
"pnpm": ">=8.6.1"
|
||||
},
|
||||
"scripts": {
|
||||
"build": "next build",
|
||||
"dev:next": "next dev",
|
||||
"dev:wss": "pnpm tsx --watch src/wss-server.ts",
|
||||
"dev:worker": "NODE_ENV='development' pnpm tsx --watch src/server/tasks/worker.ts",
|
||||
"dev": "concurrently --kill-others 'pnpm dev:next' 'pnpm dev:wss'",
|
||||
"postinstall": "prisma generate",
|
||||
"lint": "next lint",
|
||||
"start": "next start",
|
||||
"codegen": "tsx src/codegen/export-openai-types.ts"
|
||||
"codegen": "tsx src/codegen/export-openai-types.ts",
|
||||
"seed": "tsx prisma/seed.ts",
|
||||
"check": "concurrently 'pnpm lint' 'pnpm tsc' 'pnpm prettier . --check'"
|
||||
},
|
||||
"dependencies": {
|
||||
"@babel/preset-typescript": "^7.22.5",
|
||||
"@babel/standalone": "^7.22.9",
|
||||
"@chakra-ui/next-js": "^2.1.4",
|
||||
"@chakra-ui/react": "^2.7.1",
|
||||
"@emotion/react": "^11.11.1",
|
||||
"@emotion/server": "^11.11.0",
|
||||
"@emotion/styled": "^11.11.0",
|
||||
"@fontsource/inconsolata": "^5.0.5",
|
||||
"@monaco-editor/loader": "^1.3.3",
|
||||
"@next-auth/prisma-adapter": "^1.0.5",
|
||||
"@prisma/client": "^4.14.0",
|
||||
@@ -38,10 +48,11 @@
|
||||
"express": "^4.18.2",
|
||||
"framer-motion": "^10.12.17",
|
||||
"gpt-tokens": "^1.0.10",
|
||||
"graphile-worker": "^0.13.0",
|
||||
"immer": "^10.0.2",
|
||||
"isolated-vm": "^4.5.0",
|
||||
"json-stringify-pretty-compact": "^4.0.0",
|
||||
"lodash": "^4.17.21",
|
||||
"lodash-es": "^4.17.21",
|
||||
"next": "^13.4.2",
|
||||
"next-auth": "^4.22.1",
|
||||
"nextjs-routes": "^2.0.1",
|
||||
@@ -49,9 +60,12 @@
|
||||
"pluralize": "^8.0.0",
|
||||
"posthog-js": "^1.68.4",
|
||||
"prettier": "^3.0.0",
|
||||
"prismjs": "^1.29.0",
|
||||
"react": "18.2.0",
|
||||
"react-diff-viewer": "^3.1.1",
|
||||
"react-dom": "18.2.0",
|
||||
"react-icons": "^4.10.1",
|
||||
"react-select": "^5.7.4",
|
||||
"react-syntax-highlighter": "^15.5.0",
|
||||
"react-textarea-autosize": "^8.5.0",
|
||||
"socket.io": "^4.7.1",
|
||||
@@ -63,13 +77,16 @@
|
||||
},
|
||||
"devDependencies": {
|
||||
"@openapi-contrib/openapi-schema-to-json-schema": "^4.0.5",
|
||||
"@types/babel__core": "^7.20.1",
|
||||
"@types/babel__standalone": "^7.1.4",
|
||||
"@types/chroma-js": "^2.4.0",
|
||||
"@types/cors": "^2.8.13",
|
||||
"@types/eslint": "^8.37.0",
|
||||
"@types/express": "^4.17.17",
|
||||
"@types/lodash": "^4.14.195",
|
||||
"@types/lodash-es": "^4.17.8",
|
||||
"@types/node": "^18.16.0",
|
||||
"@types/pluralize": "^0.0.30",
|
||||
"@types/prismjs": "^1.26.0",
|
||||
"@types/react": "^18.2.6",
|
||||
"@types/react-dom": "^18.2.4",
|
||||
"@types/react-syntax-highlighter": "^15.5.7",
|
||||
@@ -90,6 +107,6 @@
|
||||
"initVersion": "7.14.0"
|
||||
},
|
||||
"prisma": {
|
||||
"seed": "tsx prisma/seed.ts"
|
||||
"seed": "pnpm seed"
|
||||
}
|
||||
}
|
||||
|
||||
1361
pnpm-lock.yaml
generated
1361
pnpm-lock.yaml
generated
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,49 @@
|
||||
-- Drop the foreign key constraints on the original ModelOutput
|
||||
ALTER TABLE "ModelOutput" DROP CONSTRAINT "ModelOutput_promptVariantId_fkey";
|
||||
ALTER TABLE "ModelOutput" DROP CONSTRAINT "ModelOutput_testScenarioId_fkey";
|
||||
|
||||
-- Rename the old table
|
||||
ALTER TABLE "ModelOutput" RENAME TO "ScenarioVariantCell";
|
||||
ALTER TABLE "ScenarioVariantCell" RENAME CONSTRAINT "ModelOutput_pkey" TO "ScenarioVariantCell_pkey";
|
||||
ALTER INDEX "ModelOutput_inputHash_idx" RENAME TO "ScenarioVariantCell_inputHash_idx";
|
||||
ALTER INDEX "ModelOutput_promptVariantId_testScenarioId_key" RENAME TO "ScenarioVariantCell_promptVariantId_testScenarioId_key";
|
||||
|
||||
-- Add the new fields to the renamed table
|
||||
ALTER TABLE "ScenarioVariantCell" ADD COLUMN "retryTime" TIMESTAMP(3);
|
||||
ALTER TABLE "ScenarioVariantCell" ADD COLUMN "streamingChannel" TEXT;
|
||||
ALTER TABLE "ScenarioVariantCell" ALTER COLUMN "inputHash" DROP NOT NULL;
|
||||
ALTER TABLE "ScenarioVariantCell" ALTER COLUMN "output" DROP NOT NULL,
|
||||
ALTER COLUMN "statusCode" DROP NOT NULL,
|
||||
ALTER COLUMN "timeToComplete" DROP NOT NULL;
|
||||
|
||||
-- Create the new table
|
||||
CREATE TABLE "ModelOutput" (
|
||||
"id" UUID NOT NULL,
|
||||
"inputHash" TEXT NOT NULL,
|
||||
"output" JSONB NOT NULL,
|
||||
"timeToComplete" INTEGER NOT NULL DEFAULT 0,
|
||||
"promptTokens" INTEGER,
|
||||
"completionTokens" INTEGER,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||
"scenarioVariantCellId" UUID
|
||||
);
|
||||
|
||||
-- Move inputHash index
|
||||
DROP INDEX "ScenarioVariantCell_inputHash_idx";
|
||||
CREATE INDEX "ModelOutput_inputHash_idx" ON "ModelOutput"("inputHash");
|
||||
|
||||
CREATE UNIQUE INDEX "ModelOutput_scenarioVariantCellId_key" ON "ModelOutput"("scenarioVariantCellId");
|
||||
ALTER TABLE "ModelOutput" ADD CONSTRAINT "ModelOutput_scenarioVariantCellId_fkey" FOREIGN KEY ("scenarioVariantCellId") REFERENCES "ScenarioVariantCell"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
ALTER TABLE "ModelOutput" ALTER COLUMN "scenarioVariantCellId" SET NOT NULL,
|
||||
ADD CONSTRAINT "ModelOutput_pkey" PRIMARY KEY ("id");
|
||||
|
||||
ALTER TABLE "ScenarioVariantCell" ADD CONSTRAINT "ScenarioVariantCell_promptVariantId_fkey" FOREIGN KEY ("promptVariantId") REFERENCES "PromptVariant"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
ALTER TABLE "ScenarioVariantCell" ADD CONSTRAINT "ScenarioVariantCell_testScenarioId_fkey" FOREIGN KEY ("testScenarioId") REFERENCES "TestScenario"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- CreateEnum
|
||||
CREATE TYPE "CellRetrievalStatus" AS ENUM ('PENDING', 'IN_PROGRESS', 'COMPLETE', 'ERROR');
|
||||
|
||||
-- AlterTable
|
||||
ALTER TABLE "ScenarioVariantCell" ADD COLUMN "retrievalStatus" "CellRetrievalStatus" NOT NULL DEFAULT 'COMPLETE';
|
||||
@@ -0,0 +1,2 @@
|
||||
-- AlterTable
|
||||
ALTER TABLE "PromptVariant" ADD COLUMN "model" TEXT NOT NULL DEFAULT 'gpt-3.5-turbo';
|
||||
@@ -0,0 +1,2 @@
|
||||
-- AlterTable
|
||||
ALTER TABLE "PromptVariant" ALTER COLUMN "model" DROP DEFAULT;
|
||||
24
prisma/migrations/20230717203031_add_gpt4_eval/migration.sql
Normal file
24
prisma/migrations/20230717203031_add_gpt4_eval/migration.sql
Normal file
@@ -0,0 +1,24 @@
|
||||
/*
|
||||
Warnings:
|
||||
|
||||
- You are about to rename the column `matchString` on the `Evaluation` table. If there is any code or views referring to the old name, they will break.
|
||||
- You are about to rename the column `matchType` on the `Evaluation` table. If there is any code or views referring to the old name, they will break.
|
||||
- You are about to rename the column `name` on the `Evaluation` table. If there is any code or views referring to the old name, they will break.
|
||||
- You are about to rename the enum `EvaluationMatchType` to `EvalType`. If there is any code or views referring to the old name, they will break.
|
||||
*/
|
||||
|
||||
-- RenameEnum
|
||||
ALTER TYPE "EvaluationMatchType" RENAME TO "EvalType";
|
||||
|
||||
-- AlterTable
|
||||
ALTER TABLE "Evaluation" RENAME COLUMN "matchString" TO "value";
|
||||
ALTER TABLE "Evaluation" RENAME COLUMN "matchType" TO "evalType";
|
||||
ALTER TABLE "Evaluation" RENAME COLUMN "name" TO "label";
|
||||
|
||||
-- AlterColumnType
|
||||
ALTER TABLE "Evaluation" ALTER COLUMN "evalType" TYPE "EvalType" USING "evalType"::text::"EvalType";
|
||||
|
||||
-- SetNotNullConstraint
|
||||
ALTER TABLE "Evaluation" ALTER COLUMN "evalType" SET NOT NULL;
|
||||
ALTER TABLE "Evaluation" ALTER COLUMN "label" SET NOT NULL;
|
||||
ALTER TABLE "Evaluation" ALTER COLUMN "value" SET NOT NULL;
|
||||
@@ -0,0 +1,39 @@
|
||||
/*
|
||||
Warnings:
|
||||
|
||||
- You are about to drop the `EvaluationResult` table. If the table is not empty, all the data it contains will be lost.
|
||||
|
||||
*/
|
||||
-- AlterEnum
|
||||
ALTER TYPE "EvalType" ADD VALUE 'GPT4_EVAL';
|
||||
|
||||
-- DropForeignKey
|
||||
ALTER TABLE "EvaluationResult" DROP CONSTRAINT "EvaluationResult_evaluationId_fkey";
|
||||
|
||||
-- DropForeignKey
|
||||
ALTER TABLE "EvaluationResult" DROP CONSTRAINT "EvaluationResult_promptVariantId_fkey";
|
||||
|
||||
-- DropTable
|
||||
DROP TABLE "EvaluationResult";
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "OutputEvaluation" (
|
||||
"id" UUID NOT NULL,
|
||||
"result" DOUBLE PRECISION NOT NULL,
|
||||
"details" TEXT,
|
||||
"modelOutputId" UUID NOT NULL,
|
||||
"evaluationId" UUID NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||
|
||||
CONSTRAINT "OutputEvaluation_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "OutputEvaluation_modelOutputId_evaluationId_key" ON "OutputEvaluation"("modelOutputId", "evaluationId");
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "OutputEvaluation" ADD CONSTRAINT "OutputEvaluation_modelOutputId_fkey" FOREIGN KEY ("modelOutputId") REFERENCES "ModelOutput"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "OutputEvaluation" ADD CONSTRAINT "OutputEvaluation_evaluationId_fkey" FOREIGN KEY ("evaluationId") REFERENCES "Evaluation"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
@@ -0,0 +1,124 @@
|
||||
DROP TABLE "Account";
|
||||
DROP TABLE "Session";
|
||||
DROP TABLE "User";
|
||||
DROP TABLE "VerificationToken";
|
||||
|
||||
CREATE TYPE "OrganizationUserRole" AS ENUM ('ADMIN', 'MEMBER', 'VIEWER');
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "Organization" (
|
||||
"id" UUID NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||
"personalOrgUserId" UUID,
|
||||
|
||||
CONSTRAINT "Organization_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "OrganizationUser" (
|
||||
"id" UUID NOT NULL,
|
||||
"role" "OrganizationUserRole" NOT NULL,
|
||||
"organizationId" UUID NOT NULL,
|
||||
"userId" UUID NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||
|
||||
CONSTRAINT "OrganizationUser_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "Account" (
|
||||
"id" UUID NOT NULL,
|
||||
"userId" UUID NOT NULL,
|
||||
"type" TEXT NOT NULL,
|
||||
"provider" TEXT NOT NULL,
|
||||
"providerAccountId" TEXT NOT NULL,
|
||||
"refresh_token" TEXT,
|
||||
"refresh_token_expires_in" INTEGER,
|
||||
"access_token" TEXT,
|
||||
"expires_at" INTEGER,
|
||||
"token_type" TEXT,
|
||||
"scope" TEXT,
|
||||
"id_token" TEXT,
|
||||
"session_state" TEXT,
|
||||
|
||||
CONSTRAINT "Account_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "Session" (
|
||||
"id" UUID NOT NULL,
|
||||
"sessionToken" TEXT NOT NULL,
|
||||
"userId" UUID NOT NULL,
|
||||
"expires" TIMESTAMP(3) NOT NULL,
|
||||
|
||||
CONSTRAINT "Session_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "User" (
|
||||
"id" UUID NOT NULL,
|
||||
"name" TEXT,
|
||||
"email" TEXT,
|
||||
"emailVerified" TIMESTAMP(3),
|
||||
"image" TEXT,
|
||||
|
||||
CONSTRAINT "User_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "VerificationToken" (
|
||||
"identifier" TEXT NOT NULL,
|
||||
"token" TEXT NOT NULL,
|
||||
"expires" TIMESTAMP(3) NOT NULL
|
||||
);
|
||||
|
||||
INSERT INTO "Organization" ("id", "updatedAt") VALUES ('11111111-1111-1111-1111-111111111111', CURRENT_TIMESTAMP);
|
||||
|
||||
-- AlterTable add organizationId as NULLABLE
|
||||
ALTER TABLE "Experiment" ADD COLUMN "organizationId" UUID;
|
||||
|
||||
-- Set default organization for existing experiments
|
||||
UPDATE "Experiment" SET "organizationId" = '11111111-1111-1111-1111-111111111111';
|
||||
|
||||
-- AlterTable set organizationId as NOT NULL
|
||||
ALTER TABLE "Experiment" ALTER COLUMN "organizationId" SET NOT NULL;
|
||||
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "OrganizationUser_organizationId_userId_key" ON "OrganizationUser"("organizationId", "userId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "Account_provider_providerAccountId_key" ON "Account"("provider", "providerAccountId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "Session_sessionToken_key" ON "Session"("sessionToken");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "User_email_key" ON "User"("email");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "VerificationToken_token_key" ON "VerificationToken"("token");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "VerificationToken_identifier_token_key" ON "VerificationToken"("identifier", "token");
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "Experiment" ADD CONSTRAINT "Experiment_organizationId_fkey" FOREIGN KEY ("organizationId") REFERENCES "Organization"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "OrganizationUser" ADD CONSTRAINT "OrganizationUser_organizationId_fkey" FOREIGN KEY ("organizationId") REFERENCES "Organization"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "OrganizationUser" ADD CONSTRAINT "OrganizationUser_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "Account" ADD CONSTRAINT "Account_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "Session" ADD CONSTRAINT "Session_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
CREATE UNIQUE INDEX "Organization_personalOrgUserId_key" ON "Organization"("personalOrgUserId");
|
||||
|
||||
ALTER TABLE "Organization" ADD CONSTRAINT "Organization_personalOrgUserId_fkey" FOREIGN KEY ("personalOrgUserId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
@@ -0,0 +1,16 @@
|
||||
/*
|
||||
Warnings:
|
||||
|
||||
- You are about to drop the column `completionTokens` on the `ScenarioVariantCell` table. All the data in the column will be lost.
|
||||
- You are about to drop the column `inputHash` on the `ScenarioVariantCell` table. All the data in the column will be lost.
|
||||
- You are about to drop the column `output` on the `ScenarioVariantCell` table. All the data in the column will be lost.
|
||||
- You are about to drop the column `promptTokens` on the `ScenarioVariantCell` table. All the data in the column will be lost.
|
||||
- You are about to drop the column `timeToComplete` on the `ScenarioVariantCell` table. All the data in the column will be lost.
|
||||
|
||||
*/
|
||||
-- AlterTable
|
||||
ALTER TABLE "ScenarioVariantCell" DROP COLUMN "completionTokens",
|
||||
DROP COLUMN "inputHash",
|
||||
DROP COLUMN "output",
|
||||
DROP COLUMN "promptTokens",
|
||||
DROP COLUMN "timeToComplete";
|
||||
@@ -0,0 +1,8 @@
|
||||
/*
|
||||
Warnings:
|
||||
|
||||
- You are about to drop the column `model` on the `PromptVariant` table. All the data in the column will be lost.
|
||||
|
||||
*/
|
||||
-- AlterTable
|
||||
ALTER TABLE "ModelOutput" ADD COLUMN "cost" DOUBLE PRECISION;
|
||||
@@ -3,7 +3,6 @@
|
||||
|
||||
generator client {
|
||||
provider = "prisma-client-js"
|
||||
previewFeatures = ["jsonProtocol"]
|
||||
}
|
||||
|
||||
datasource db {
|
||||
@@ -17,8 +16,12 @@ model Experiment {
|
||||
|
||||
sortIndex Int @default(0)
|
||||
|
||||
organizationId String @db.Uuid
|
||||
organization Organization? @relation(fields: [organizationId], references: [id], onDelete: Cascade)
|
||||
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
TemplateVariable TemplateVariable[]
|
||||
PromptVariant PromptVariant[]
|
||||
TestScenario TestScenario[]
|
||||
@@ -27,9 +30,10 @@ model Experiment {
|
||||
|
||||
model PromptVariant {
|
||||
id String @id @default(uuid()) @db.Uuid
|
||||
label String
|
||||
|
||||
label String
|
||||
constructFn String
|
||||
model String
|
||||
|
||||
uiId String @default(uuid()) @db.Uuid
|
||||
visible Boolean @default(true)
|
||||
@@ -40,8 +44,7 @@ model PromptVariant {
|
||||
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
ModelOutput ModelOutput[]
|
||||
EvaluationResult EvaluationResult[]
|
||||
scenarioVariantCells ScenarioVariantCell[]
|
||||
|
||||
@@index([uiId])
|
||||
}
|
||||
@@ -60,7 +63,7 @@ model TestScenario {
|
||||
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
ModelOutput ModelOutput[]
|
||||
scenarioVariantCells ScenarioVariantCell[]
|
||||
}
|
||||
|
||||
model TemplateVariable {
|
||||
@@ -75,17 +78,23 @@ model TemplateVariable {
|
||||
updatedAt DateTime @updatedAt
|
||||
}
|
||||
|
||||
model ModelOutput {
|
||||
enum CellRetrievalStatus {
|
||||
PENDING
|
||||
IN_PROGRESS
|
||||
COMPLETE
|
||||
ERROR
|
||||
}
|
||||
|
||||
model ScenarioVariantCell {
|
||||
id String @id @default(uuid()) @db.Uuid
|
||||
|
||||
inputHash String
|
||||
output Json
|
||||
statusCode Int
|
||||
statusCode Int?
|
||||
errorMessage String?
|
||||
timeToComplete Int @default(0)
|
||||
retryTime DateTime?
|
||||
streamingChannel String?
|
||||
retrievalStatus CellRetrievalStatus @default(COMPLETE)
|
||||
|
||||
promptTokens Int? // Added promptTokens field
|
||||
completionTokens Int? // Added completionTokens field
|
||||
modelOutput ModelOutput?
|
||||
|
||||
promptVariantId String @db.Uuid
|
||||
promptVariant PromptVariant @relation(fields: [promptVariantId], references: [id], onDelete: Cascade)
|
||||
@@ -97,60 +106,116 @@ model ModelOutput {
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
@@unique([promptVariantId, testScenarioId])
|
||||
}
|
||||
|
||||
model ModelOutput {
|
||||
id String @id @default(uuid()) @db.Uuid
|
||||
|
||||
inputHash String
|
||||
output Json
|
||||
timeToComplete Int @default(0)
|
||||
cost Float?
|
||||
promptTokens Int?
|
||||
completionTokens Int?
|
||||
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
scenarioVariantCellId String @db.Uuid
|
||||
scenarioVariantCell ScenarioVariantCell @relation(fields: [scenarioVariantCellId], references: [id], onDelete: Cascade)
|
||||
outputEvaluation OutputEvaluation[]
|
||||
|
||||
@@unique([scenarioVariantCellId])
|
||||
@@index([inputHash])
|
||||
}
|
||||
|
||||
enum EvaluationMatchType {
|
||||
enum EvalType {
|
||||
CONTAINS
|
||||
DOES_NOT_CONTAIN
|
||||
GPT4_EVAL
|
||||
}
|
||||
|
||||
model Evaluation {
|
||||
id String @id @default(uuid()) @db.Uuid
|
||||
|
||||
name String
|
||||
matchString String
|
||||
matchType EvaluationMatchType
|
||||
label String
|
||||
evalType EvalType
|
||||
value String
|
||||
|
||||
experimentId String @db.Uuid
|
||||
experiment Experiment @relation(fields: [experimentId], references: [id], onDelete: Cascade)
|
||||
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
EvaluationResult EvaluationResult[]
|
||||
OutputEvaluation OutputEvaluation[]
|
||||
}
|
||||
|
||||
model EvaluationResult {
|
||||
model OutputEvaluation {
|
||||
id String @id @default(uuid()) @db.Uuid
|
||||
|
||||
passCount Int
|
||||
failCount Int
|
||||
// Number between 0 (fail) and 1 (pass)
|
||||
result Float
|
||||
details String?
|
||||
|
||||
modelOutputId String @db.Uuid
|
||||
modelOutput ModelOutput @relation(fields: [modelOutputId], references: [id], onDelete: Cascade)
|
||||
|
||||
evaluationId String @db.Uuid
|
||||
evaluation Evaluation @relation(fields: [evaluationId], references: [id], onDelete: Cascade)
|
||||
|
||||
promptVariantId String @db.Uuid
|
||||
promptVariant PromptVariant @relation(fields: [promptVariantId], references: [id], onDelete: Cascade)
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
@@unique([modelOutputId, evaluationId])
|
||||
}
|
||||
|
||||
model Organization {
|
||||
id String @id @default(uuid()) @db.Uuid
|
||||
personalOrgUserId String? @unique @db.Uuid
|
||||
PersonalOrgUser User? @relation(fields: [personalOrgUserId], references: [id], onDelete: Cascade)
|
||||
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
OrganizationUser OrganizationUser[]
|
||||
Experiment Experiment[]
|
||||
}
|
||||
|
||||
enum OrganizationUserRole {
|
||||
ADMIN
|
||||
MEMBER
|
||||
VIEWER
|
||||
}
|
||||
|
||||
model OrganizationUser {
|
||||
id String @id @default(uuid()) @db.Uuid
|
||||
|
||||
role OrganizationUserRole
|
||||
|
||||
organizationId String @db.Uuid
|
||||
organization Organization? @relation(fields: [organizationId], references: [id], onDelete: Cascade)
|
||||
|
||||
userId String @db.Uuid
|
||||
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
@@unique([evaluationId, promptVariantId])
|
||||
@@unique([organizationId, userId])
|
||||
}
|
||||
|
||||
// Necessary for Next auth
|
||||
model Account {
|
||||
id String @id @default(cuid())
|
||||
userId String
|
||||
id String @id @default(uuid()) @db.Uuid
|
||||
userId String @db.Uuid
|
||||
type String
|
||||
provider String
|
||||
providerAccountId String
|
||||
refresh_token String? // @db.Text
|
||||
access_token String? // @db.Text
|
||||
refresh_token String? @db.Text
|
||||
refresh_token_expires_in Int?
|
||||
access_token String? @db.Text
|
||||
expires_at Int?
|
||||
token_type String?
|
||||
scope String?
|
||||
id_token String? // @db.Text
|
||||
id_token String? @db.Text
|
||||
session_state String?
|
||||
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
|
||||
@@ -158,21 +223,23 @@ model Account {
|
||||
}
|
||||
|
||||
model Session {
|
||||
id String @id @default(cuid())
|
||||
id String @id @default(uuid()) @db.Uuid
|
||||
sessionToken String @unique
|
||||
userId String
|
||||
userId String @db.Uuid
|
||||
expires DateTime
|
||||
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
}
|
||||
|
||||
model User {
|
||||
id String @id @default(cuid())
|
||||
id String @id @default(uuid()) @db.Uuid
|
||||
name String?
|
||||
email String? @unique
|
||||
emailVerified DateTime?
|
||||
image String?
|
||||
accounts Account[]
|
||||
sessions Session[]
|
||||
OrganizationUser OrganizationUser[]
|
||||
Organization Organization[]
|
||||
}
|
||||
|
||||
model VerificationToken {
|
||||
|
||||
@@ -1,59 +1,76 @@
|
||||
import { prisma } from "~/server/db";
|
||||
import dedent from "dedent";
|
||||
import { generateNewCell } from "~/server/utils/generateNewCell";
|
||||
|
||||
const experimentId = "11111111-1111-1111-1111-111111111111";
|
||||
const defaultId = "11111111-1111-1111-1111-111111111111";
|
||||
|
||||
await prisma.organization.deleteMany({
|
||||
where: { id: defaultId },
|
||||
});
|
||||
await prisma.organization.create({
|
||||
data: { id: defaultId },
|
||||
});
|
||||
|
||||
// Delete the existing experiment
|
||||
await prisma.experiment.deleteMany({
|
||||
where: {
|
||||
id: experimentId,
|
||||
id: defaultId,
|
||||
},
|
||||
});
|
||||
|
||||
const experiment = await prisma.experiment.create({
|
||||
await prisma.experiment.create({
|
||||
data: {
|
||||
id: experimentId,
|
||||
id: defaultId,
|
||||
label: "Country Capitals Example",
|
||||
organizationId: defaultId,
|
||||
},
|
||||
});
|
||||
|
||||
await prisma.modelOutput.deleteMany({
|
||||
await prisma.scenarioVariantCell.deleteMany({
|
||||
where: {
|
||||
promptVariant: {
|
||||
experimentId,
|
||||
experimentId: defaultId,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
await prisma.promptVariant.deleteMany({
|
||||
where: {
|
||||
experimentId,
|
||||
experimentId: defaultId,
|
||||
},
|
||||
});
|
||||
|
||||
await prisma.promptVariant.createMany({
|
||||
data: [
|
||||
{
|
||||
experimentId,
|
||||
experimentId: defaultId,
|
||||
label: "Prompt Variant 1",
|
||||
sortIndex: 0,
|
||||
constructFn: `prompt = {
|
||||
model: "gpt-3.5-turbo-0613",
|
||||
messages: [{ role: "user", content: "What is the capital of {{country}}?" }],
|
||||
temperature: 0,
|
||||
}`,
|
||||
},
|
||||
{
|
||||
experimentId,
|
||||
label: "Prompt Variant 2",
|
||||
sortIndex: 1,
|
||||
constructFn: `prompt = {
|
||||
constructFn: dedent`
|
||||
prompt = {
|
||||
model: "gpt-3.5-turbo-0613",
|
||||
messages: [
|
||||
{
|
||||
role: "user",
|
||||
content:
|
||||
"What is the capital of {{country}}? Return just the city name and nothing else.",
|
||||
content: \`What is the capital of ${"$"}{scenario.country}?\`
|
||||
}
|
||||
],
|
||||
temperature: 0,
|
||||
}`,
|
||||
},
|
||||
{
|
||||
experimentId: defaultId,
|
||||
label: "Prompt Variant 2",
|
||||
sortIndex: 1,
|
||||
model: "gpt-3.5-turbo-0613",
|
||||
constructFn: dedent`
|
||||
prompt = {
|
||||
model: "gpt-3.5-turbo-0613",
|
||||
messages: [
|
||||
{
|
||||
role: "user",
|
||||
content: \`What is the capital of ${"$"}{scenario.country}? Return just the city name and nothing else.\`
|
||||
}
|
||||
],
|
||||
temperature: 0,
|
||||
}`,
|
||||
@@ -63,14 +80,14 @@ await prisma.promptVariant.createMany({
|
||||
|
||||
await prisma.templateVariable.deleteMany({
|
||||
where: {
|
||||
experimentId,
|
||||
experimentId: defaultId,
|
||||
},
|
||||
});
|
||||
|
||||
await prisma.templateVariable.createMany({
|
||||
data: [
|
||||
{
|
||||
experimentId,
|
||||
experimentId: defaultId,
|
||||
label: "country",
|
||||
},
|
||||
],
|
||||
@@ -78,28 +95,28 @@ await prisma.templateVariable.createMany({
|
||||
|
||||
await prisma.testScenario.deleteMany({
|
||||
where: {
|
||||
experimentId,
|
||||
experimentId: defaultId,
|
||||
},
|
||||
});
|
||||
|
||||
await prisma.testScenario.createMany({
|
||||
data: [
|
||||
{
|
||||
experimentId,
|
||||
experimentId: defaultId,
|
||||
sortIndex: 0,
|
||||
variableValues: {
|
||||
country: "Spain",
|
||||
},
|
||||
},
|
||||
{
|
||||
experimentId,
|
||||
experimentId: defaultId,
|
||||
sortIndex: 1,
|
||||
variableValues: {
|
||||
country: "USA",
|
||||
},
|
||||
},
|
||||
{
|
||||
experimentId,
|
||||
experimentId: defaultId,
|
||||
sortIndex: 2,
|
||||
variableValues: {
|
||||
country: "Chile",
|
||||
@@ -107,3 +124,26 @@ await prisma.testScenario.createMany({
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
const variants = await prisma.promptVariant.findMany({
|
||||
where: {
|
||||
experimentId: defaultId,
|
||||
},
|
||||
});
|
||||
|
||||
const scenarios = await prisma.testScenario.findMany({
|
||||
where: {
|
||||
experimentId: defaultId,
|
||||
},
|
||||
});
|
||||
|
||||
await Promise.all(
|
||||
variants
|
||||
.flatMap((variant) =>
|
||||
scenarios.map((scenario) => ({
|
||||
promptVariantId: variant.id,
|
||||
testScenarioId: scenario.id,
|
||||
})),
|
||||
)
|
||||
.map((cell) => generateNewCell(cell.promptVariantId, cell.testScenarioId)),
|
||||
);
|
||||
|
||||
1130
prisma/seedDemo.ts
1130
prisma/seedDemo.ts
File diff suppressed because one or more lines are too long
@@ -12,7 +12,7 @@ services:
|
||||
dockerContext: .
|
||||
plan: standard
|
||||
domains:
|
||||
- openpipe.ai
|
||||
- app.openpipe.ai
|
||||
envVars:
|
||||
- key: NODE_ENV
|
||||
value: production
|
||||
|
||||
@@ -2,7 +2,7 @@ import fs from "fs";
|
||||
import path from "path";
|
||||
import openapiTS, { type OpenAPI3 } from "openapi-typescript";
|
||||
import YAML from "yaml";
|
||||
import _ from "lodash";
|
||||
import { pick } from "lodash-es";
|
||||
import assert from "assert";
|
||||
|
||||
const OPENAPI_URL =
|
||||
@@ -31,7 +31,7 @@ modelProperty.oneOf = undefined;
|
||||
|
||||
delete schema["paths"];
|
||||
assert(schema.components?.schemas);
|
||||
schema.components.schemas = _.pick(schema.components?.schemas, [
|
||||
schema.components.schemas = pick(schema.components?.schemas, [
|
||||
"CreateChatCompletionRequest",
|
||||
"ChatCompletionRequestMessage",
|
||||
"ChatCompletionFunctions",
|
||||
|
||||
@@ -4,7 +4,7 @@ import React from "react";
|
||||
|
||||
export const AutoResizeTextarea: React.ForwardRefRenderFunction<
|
||||
HTMLTextAreaElement,
|
||||
TextareaProps
|
||||
TextareaProps & { minRows?: number }
|
||||
> = (props, ref) => {
|
||||
return (
|
||||
<Textarea
|
||||
|
||||
@@ -11,14 +11,16 @@ import {
|
||||
FormLabel,
|
||||
Select,
|
||||
FormHelperText,
|
||||
Code,
|
||||
} from "@chakra-ui/react";
|
||||
import { type Evaluation, EvaluationMatchType } from "@prisma/client";
|
||||
import { type Evaluation, EvalType } from "@prisma/client";
|
||||
import { useCallback, useState } from "react";
|
||||
import { BsPencil, BsX } from "react-icons/bs";
|
||||
import { api } from "~/utils/api";
|
||||
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
|
||||
import AutoResizeTextArea from "../AutoResizeTextArea";
|
||||
|
||||
type EvalValues = Pick<Evaluation, "name" | "matchString" | "matchType">;
|
||||
type EvalValues = Pick<Evaluation, "label" | "value" | "evalType">;
|
||||
|
||||
export function EvaluationEditor(props: {
|
||||
evaluation: Evaluation | null;
|
||||
@@ -27,35 +29,35 @@ export function EvaluationEditor(props: {
|
||||
onCancel: () => void;
|
||||
}) {
|
||||
const [values, setValues] = useState<EvalValues>({
|
||||
name: props.evaluation?.name ?? props.defaultName ?? "",
|
||||
matchString: props.evaluation?.matchString ?? "",
|
||||
matchType: props.evaluation?.matchType ?? "CONTAINS",
|
||||
label: props.evaluation?.label ?? props.defaultName ?? "",
|
||||
value: props.evaluation?.value ?? "",
|
||||
evalType: props.evaluation?.evalType ?? "CONTAINS",
|
||||
});
|
||||
|
||||
return (
|
||||
<VStack borderTopWidth={1} borderColor="gray.200" py={4}>
|
||||
<HStack w="100%">
|
||||
<FormControl flex={1}>
|
||||
<FormLabel fontSize="sm">Evaluation Name</FormLabel>
|
||||
<FormLabel fontSize="sm">Eval Name</FormLabel>
|
||||
<Input
|
||||
size="sm"
|
||||
value={values.name}
|
||||
onChange={(e) => setValues((values) => ({ ...values, name: e.target.value }))}
|
||||
value={values.label}
|
||||
onChange={(e) => setValues((values) => ({ ...values, label: e.target.value }))}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormControl flex={1}>
|
||||
<FormLabel fontSize="sm">Match Type</FormLabel>
|
||||
<FormLabel fontSize="sm">Eval Type</FormLabel>
|
||||
<Select
|
||||
size="sm"
|
||||
value={values.matchType}
|
||||
value={values.evalType}
|
||||
onChange={(e) =>
|
||||
setValues((values) => ({
|
||||
...values,
|
||||
matchType: e.target.value as EvaluationMatchType,
|
||||
evalType: e.target.value as EvalType,
|
||||
}))
|
||||
}
|
||||
>
|
||||
{Object.values(EvaluationMatchType).map((type) => (
|
||||
{Object.values(EvalType).map((type) => (
|
||||
<option key={type} value={type}>
|
||||
{type}
|
||||
</option>
|
||||
@@ -63,17 +65,37 @@ export function EvaluationEditor(props: {
|
||||
</Select>
|
||||
</FormControl>
|
||||
</HStack>
|
||||
{["CONTAINS", "DOES_NOT_CONTAIN"].includes(values.evalType) && (
|
||||
<FormControl>
|
||||
<FormLabel fontSize="sm">Match String</FormLabel>
|
||||
<Input
|
||||
size="sm"
|
||||
value={values.matchString}
|
||||
onChange={(e) => setValues((values) => ({ ...values, matchString: e.target.value }))}
|
||||
value={values.value}
|
||||
onChange={(e) => setValues((values) => ({ ...values, value: e.target.value }))}
|
||||
/>
|
||||
<FormHelperText>
|
||||
This string will be interpreted as a regex and checked against each model output.
|
||||
This string will be interpreted as a regex and checked against each model output. You
|
||||
can include scenario variables using <Code>{"{{curly_braces}}"}</Code>
|
||||
</FormHelperText>
|
||||
</FormControl>
|
||||
)}
|
||||
{values.evalType === "GPT4_EVAL" && (
|
||||
<FormControl pt={2}>
|
||||
<FormLabel fontSize="sm">GPT4 Instructions</FormLabel>
|
||||
<AutoResizeTextArea
|
||||
size="sm"
|
||||
value={values.value}
|
||||
onChange={(e) => setValues((values) => ({ ...values, value: e.target.value }))}
|
||||
minRows={3}
|
||||
/>
|
||||
<FormHelperText>
|
||||
Give instructions to GPT-4 for how to evaluate your prompt. It will have access to the
|
||||
full scenario as well as the output it is evaluating. It will <strong>not</strong> have
|
||||
access to the specific prompt variant, so be sure to be clear about the task you want it
|
||||
to perform.
|
||||
</FormHelperText>
|
||||
</FormControl>
|
||||
)}
|
||||
<HStack alignSelf="flex-end">
|
||||
<Button size="sm" onClick={props.onCancel} colorScheme="gray">
|
||||
Cancel
|
||||
@@ -125,6 +147,7 @@ export default function EditEvaluations() {
|
||||
}
|
||||
await utils.evaluations.list.invalidate();
|
||||
await utils.promptVariants.stats.invalidate();
|
||||
await utils.scenarioVariantCells.get.invalidate();
|
||||
}, []);
|
||||
|
||||
const onCancel = useCallback(() => {
|
||||
@@ -156,9 +179,9 @@ export default function EditEvaluations() {
|
||||
align="center"
|
||||
key={evaluation.id}
|
||||
>
|
||||
<Text fontWeight="bold">{evaluation.name}</Text>
|
||||
<Text fontWeight="bold">{evaluation.label}</Text>
|
||||
<Text flex={1}>
|
||||
{evaluation.matchType}: "{evaluation.matchString}"
|
||||
{evaluation.evalType}: "{evaluation.value}"
|
||||
</Text>
|
||||
<Button
|
||||
variant="unstyled"
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { Text, Button, HStack, Heading, Icon, Input, Stack, Code } from "@chakra-ui/react";
|
||||
import { Text, Button, HStack, Heading, Icon, Input, Stack } from "@chakra-ui/react";
|
||||
import { useState } from "react";
|
||||
import { BsCheck, BsX } from "react-icons/bs";
|
||||
import { api } from "~/utils/api";
|
||||
@@ -36,8 +36,7 @@ export default function EditScenarioVars() {
|
||||
<Heading size="sm">Scenario Variables</Heading>
|
||||
<Stack spacing={2}>
|
||||
<Text fontSize="sm">
|
||||
Scenario variables can be used in your prompt variants as well as evaluations. Reference
|
||||
them using <Code>{"{{curly_braces}}"}</Code>.
|
||||
Scenario variables can be used in your prompt variants as well as evaluations.
|
||||
</Text>
|
||||
<HStack spacing={0}>
|
||||
<Input
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { Button, type ButtonProps, HStack, Spinner, Icon } from "@chakra-ui/react";
|
||||
import { BsPlus } from "react-icons/bs";
|
||||
import { api } from "~/utils/api";
|
||||
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
|
||||
import { useExperiment, useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
|
||||
|
||||
// Extracted Button styling into reusable component
|
||||
const StyledButton = ({ children, onClick }: ButtonProps) => (
|
||||
@@ -17,6 +17,8 @@ const StyledButton = ({ children, onClick }: ButtonProps) => (
|
||||
);
|
||||
|
||||
export default function NewScenarioButton() {
|
||||
const { canModify } = useExperimentAccess();
|
||||
|
||||
const experiment = useExperiment();
|
||||
const mutation = api.scenarios.create.useMutation();
|
||||
const utils = api.useContext();
|
||||
@@ -38,6 +40,8 @@ export default function NewScenarioButton() {
|
||||
await utils.scenarios.list.invalidate();
|
||||
}, [mutation]);
|
||||
|
||||
if (!canModify) return null;
|
||||
|
||||
return (
|
||||
<HStack spacing={2}>
|
||||
<StyledButton onClick={onClick}>
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { Button, Icon, Spinner } from "@chakra-ui/react";
|
||||
import { Box, Button, Icon, Spinner, Text } from "@chakra-ui/react";
|
||||
import { BsPlus } from "react-icons/bs";
|
||||
import { api } from "~/utils/api";
|
||||
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
|
||||
import { useExperiment, useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
|
||||
import { cellPadding, headerMinHeight } from "../constants";
|
||||
|
||||
export default function NewVariantButton() {
|
||||
@@ -17,6 +17,9 @@ export default function NewVariantButton() {
|
||||
await utils.promptVariants.list.invalidate();
|
||||
}, [mutation]);
|
||||
|
||||
const { canModify } = useExperimentAccess();
|
||||
if (!canModify) return <Box w={cellPadding.x} />;
|
||||
|
||||
return (
|
||||
<Button
|
||||
w="100%"
|
||||
@@ -31,7 +34,7 @@ export default function NewVariantButton() {
|
||||
minH={headerMinHeight}
|
||||
>
|
||||
<Icon as={loading ? Spinner : BsPlus} boxSize={6} mr={loading ? 1 : 0} />
|
||||
Add Variant
|
||||
<Text display={{ base: "none", md: "flex" }}>Add Variant</Text>
|
||||
</Button>
|
||||
);
|
||||
}
|
||||
|
||||
37
src/components/OutputsTable/OutputCell/CellOptions.tsx
Normal file
37
src/components/OutputsTable/OutputCell/CellOptions.tsx
Normal file
@@ -0,0 +1,37 @@
|
||||
import { Button, HStack, Icon, Tooltip } from "@chakra-ui/react";
|
||||
import { BsArrowClockwise } from "react-icons/bs";
|
||||
import { useExperimentAccess } from "~/utils/hooks";
|
||||
|
||||
export const CellOptions = ({
|
||||
refetchingOutput,
|
||||
refetchOutput,
|
||||
}: {
|
||||
refetchingOutput: boolean;
|
||||
refetchOutput: () => void;
|
||||
}) => {
|
||||
const { canModify } = useExperimentAccess();
|
||||
return (
|
||||
<HStack justifyContent="flex-end" w="full">
|
||||
{!refetchingOutput && canModify && (
|
||||
<Tooltip label="Refetch output" aria-label="refetch output">
|
||||
<Button
|
||||
size="xs"
|
||||
w={4}
|
||||
h={4}
|
||||
py={4}
|
||||
px={4}
|
||||
minW={0}
|
||||
borderRadius={8}
|
||||
color="gray.500"
|
||||
variant="ghost"
|
||||
cursor="pointer"
|
||||
onClick={refetchOutput}
|
||||
aria-label="refetch output"
|
||||
>
|
||||
<Icon as={BsArrowClockwise} boxSize={4} />
|
||||
</Button>
|
||||
</Tooltip>
|
||||
)}
|
||||
</HStack>
|
||||
);
|
||||
};
|
||||
@@ -1,29 +1,21 @@
|
||||
import { type ModelOutput } from "@prisma/client";
|
||||
import { HStack, VStack, Text, Button, Icon } from "@chakra-ui/react";
|
||||
import { type ScenarioVariantCell } from "@prisma/client";
|
||||
import { VStack, Text } from "@chakra-ui/react";
|
||||
import { useEffect, useState } from "react";
|
||||
import { BsArrowClockwise } from "react-icons/bs";
|
||||
import { rateLimitErrorMessage } from "~/sharedStrings";
|
||||
import pluralize from "pluralize";
|
||||
|
||||
const MAX_AUTO_RETRIES = 3;
|
||||
|
||||
export const ErrorHandler = ({
|
||||
output,
|
||||
cell,
|
||||
refetchOutput,
|
||||
numPreviousTries,
|
||||
}: {
|
||||
output: ModelOutput;
|
||||
cell: ScenarioVariantCell;
|
||||
refetchOutput: () => void;
|
||||
numPreviousTries: number;
|
||||
}) => {
|
||||
const [msToWait, setMsToWait] = useState(0);
|
||||
const shouldAutoRetry =
|
||||
output.errorMessage === rateLimitErrorMessage && numPreviousTries < MAX_AUTO_RETRIES;
|
||||
|
||||
useEffect(() => {
|
||||
if (!shouldAutoRetry) return;
|
||||
if (!cell.retryTime) return;
|
||||
|
||||
const initialWaitTime = calculateDelay(numPreviousTries);
|
||||
const initialWaitTime = cell.retryTime.getTime() - Date.now();
|
||||
const msModuloOneSecond = initialWaitTime % 1000;
|
||||
let remainingTime = initialWaitTime - msModuloOneSecond;
|
||||
setMsToWait(remainingTime);
|
||||
@@ -35,7 +27,6 @@ export const ErrorHandler = ({
|
||||
setMsToWait(remainingTime);
|
||||
|
||||
if (remainingTime <= 0) {
|
||||
refetchOutput();
|
||||
clearInterval(interval);
|
||||
}
|
||||
}, 1000);
|
||||
@@ -45,32 +36,12 @@ export const ErrorHandler = ({
|
||||
clearInterval(interval);
|
||||
clearTimeout(timeout);
|
||||
};
|
||||
}, [shouldAutoRetry, setMsToWait, refetchOutput, numPreviousTries]);
|
||||
}, [cell.retryTime, cell.statusCode, setMsToWait, refetchOutput]);
|
||||
|
||||
return (
|
||||
<VStack w="full">
|
||||
<HStack w="full" alignItems="flex-start" justifyContent="space-between">
|
||||
<Text color="red.600" fontWeight="bold">
|
||||
Error
|
||||
</Text>
|
||||
<Button
|
||||
size="xs"
|
||||
w={4}
|
||||
h={4}
|
||||
px={4}
|
||||
py={4}
|
||||
minW={0}
|
||||
borderRadius={8}
|
||||
variant="ghost"
|
||||
cursor="pointer"
|
||||
onClick={refetchOutput}
|
||||
aria-label="refetch output"
|
||||
>
|
||||
<Icon as={BsArrowClockwise} boxSize={6} />
|
||||
</Button>
|
||||
</HStack>
|
||||
<Text color="red.600" wordBreak="break-word">
|
||||
{output.errorMessage}
|
||||
{cell.errorMessage}
|
||||
</Text>
|
||||
{msToWait > 0 && (
|
||||
<Text color="red.600" fontSize="sm">
|
||||
@@ -80,12 +51,3 @@ export const ErrorHandler = ({
|
||||
</VStack>
|
||||
);
|
||||
};
|
||||
|
||||
const MIN_DELAY = 500; // milliseconds
|
||||
const MAX_DELAY = 5000; // milliseconds
|
||||
|
||||
function calculateDelay(numPreviousTries: number): number {
|
||||
const baseDelay = Math.min(MAX_DELAY, MIN_DELAY * Math.pow(2, numPreviousTries));
|
||||
const jitter = Math.random() * baseDelay;
|
||||
return baseDelay + jitter;
|
||||
}
|
||||
|
||||
@@ -1,17 +1,16 @@
|
||||
import { type RouterOutputs, api } from "~/utils/api";
|
||||
import { api } from "~/utils/api";
|
||||
import { type PromptVariant, type Scenario } from "../types";
|
||||
import { Spinner, Text, Box, Center, Flex } from "@chakra-ui/react";
|
||||
import { Spinner, Text, Center, VStack } from "@chakra-ui/react";
|
||||
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
|
||||
import SyntaxHighlighter from "react-syntax-highlighter";
|
||||
import { docco } from "react-syntax-highlighter/dist/cjs/styles/hljs";
|
||||
import stringify from "json-stringify-pretty-compact";
|
||||
import { type ReactElement, useState, useEffect, useRef, useCallback } from "react";
|
||||
import { type ReactElement, useState, useEffect } from "react";
|
||||
import { type ChatCompletion } from "openai/resources/chat";
|
||||
import { generateChannel } from "~/utils/generateChannel";
|
||||
import { isObject } from "lodash";
|
||||
import useSocket from "~/utils/useSocket";
|
||||
import { OutputStats } from "./OutputStats";
|
||||
import { ErrorHandler } from "./ErrorHandler";
|
||||
import { CellOptions } from "./CellOptions";
|
||||
|
||||
export default function OutputCell({
|
||||
scenario,
|
||||
@@ -37,94 +36,83 @@ export default function OutputCell({
|
||||
// if (variant.config === null || Object.keys(variant.config).length === 0)
|
||||
// disabledReason = "Save your prompt variant to see output";
|
||||
|
||||
// const model = getModelName(variant.config as JSONSerializable);
|
||||
// TODO: Temporarily hardcoding this while we get other stuff working
|
||||
const model = "gpt-3.5-turbo";
|
||||
const [refetchInterval, setRefetchInterval] = useState(0);
|
||||
const { data: cell, isLoading: queryLoading } = api.scenarioVariantCells.get.useQuery(
|
||||
{ scenarioId: scenario.id, variantId: variant.id },
|
||||
{ refetchInterval },
|
||||
);
|
||||
|
||||
const outputMutation = api.outputs.get.useMutation();
|
||||
|
||||
const [output, setOutput] = useState<RouterOutputs["outputs"]["get"]>(null);
|
||||
const [channel, setChannel] = useState<string | undefined>(undefined);
|
||||
const [numPreviousTries, setNumPreviousTries] = useState(0);
|
||||
|
||||
const fetchMutex = useRef(false);
|
||||
const [fetchOutput, fetchingOutput] = useHandledAsyncCallback(
|
||||
async (forceRefetch?: boolean) => {
|
||||
if (fetchMutex.current) return;
|
||||
setNumPreviousTries((prev) => prev + 1);
|
||||
|
||||
fetchMutex.current = true;
|
||||
setOutput(null);
|
||||
|
||||
const shouldStream =
|
||||
isObject(variant) &&
|
||||
"config" in variant &&
|
||||
isObject(variant.config) &&
|
||||
"stream" in variant.config &&
|
||||
variant.config.stream === true;
|
||||
|
||||
const channel = shouldStream ? generateChannel() : undefined;
|
||||
setChannel(channel);
|
||||
|
||||
const output = await outputMutation.mutateAsync({
|
||||
const { mutateAsync: hardRefetchMutate, isLoading: refetchingOutput } =
|
||||
api.scenarioVariantCells.forceRefetch.useMutation();
|
||||
const [hardRefetch] = useHandledAsyncCallback(async () => {
|
||||
await hardRefetchMutate({ scenarioId: scenario.id, variantId: variant.id });
|
||||
await utils.scenarioVariantCells.get.invalidate({
|
||||
scenarioId: scenario.id,
|
||||
variantId: variant.id,
|
||||
channel,
|
||||
forceRefetch,
|
||||
});
|
||||
setOutput(output);
|
||||
await utils.promptVariants.stats.invalidate();
|
||||
fetchMutex.current = false;
|
||||
},
|
||||
[outputMutation, scenario.id, variant.id],
|
||||
);
|
||||
const hardRefetch = useCallback(() => fetchOutput(true), [fetchOutput]);
|
||||
await utils.promptVariants.stats.invalidate({
|
||||
variantId: variant.id,
|
||||
});
|
||||
}, [hardRefetchMutate, scenario.id, variant.id]);
|
||||
|
||||
useEffect(fetchOutput, [scenario.id, variant.id]);
|
||||
const fetchingOutput = queryLoading || refetchingOutput;
|
||||
|
||||
const awaitingOutput =
|
||||
!cell ||
|
||||
cell.retrievalStatus === "PENDING" ||
|
||||
cell.retrievalStatus === "IN_PROGRESS" ||
|
||||
refetchingOutput;
|
||||
useEffect(() => setRefetchInterval(awaitingOutput ? 1000 : 0), [awaitingOutput]);
|
||||
|
||||
const modelOutput = cell?.modelOutput;
|
||||
|
||||
// Disconnect from socket if we're not streaming anymore
|
||||
const streamedMessage = useSocket(fetchingOutput ? channel : undefined);
|
||||
const streamedMessage = useSocket(cell?.streamingChannel);
|
||||
const streamedContent = streamedMessage?.choices?.[0]?.message?.content;
|
||||
|
||||
if (!vars) return null;
|
||||
|
||||
if (disabledReason) return <Text color="gray.500">{disabledReason}</Text>;
|
||||
|
||||
if (fetchingOutput && !streamedMessage)
|
||||
if (awaitingOutput && !streamedMessage)
|
||||
return (
|
||||
<Center h="100%" w="100%">
|
||||
<Spinner />
|
||||
</Center>
|
||||
);
|
||||
|
||||
if (!output && !fetchingOutput) return <Text color="gray.500">Error retrieving output</Text>;
|
||||
if (!cell && !fetchingOutput) return <Text color="gray.500">Error retrieving output</Text>;
|
||||
|
||||
if (output && output.errorMessage) {
|
||||
return (
|
||||
<ErrorHandler
|
||||
output={output}
|
||||
refetchOutput={hardRefetch}
|
||||
numPreviousTries={numPreviousTries}
|
||||
/>
|
||||
);
|
||||
if (cell && cell.errorMessage) {
|
||||
return <ErrorHandler cell={cell} refetchOutput={hardRefetch} />;
|
||||
}
|
||||
|
||||
const response = output?.output as unknown as ChatCompletion;
|
||||
const response = modelOutput?.output as unknown as ChatCompletion;
|
||||
const message = response?.choices?.[0]?.message;
|
||||
|
||||
if (output && message?.function_call) {
|
||||
if (modelOutput && message?.function_call) {
|
||||
const rawArgs = message.function_call.arguments ?? "null";
|
||||
let parsedArgs: string;
|
||||
try {
|
||||
parsedArgs = JSON.parse(rawArgs);
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
} catch (e: any) {
|
||||
parsedArgs = `Failed to parse arguments as JSON: '${rawArgs}' ERROR: ${e.message as string}`;
|
||||
}
|
||||
|
||||
return (
|
||||
<Box fontSize="xs" width="100%" flexWrap="wrap" overflowX="auto">
|
||||
<VStack
|
||||
w="100%"
|
||||
h="100%"
|
||||
fontSize="xs"
|
||||
flexWrap="wrap"
|
||||
overflowX="hidden"
|
||||
justifyContent="space-between"
|
||||
>
|
||||
<VStack w="full" flex={1} spacing={0}>
|
||||
<CellOptions refetchingOutput={refetchingOutput} refetchOutput={hardRefetch} />
|
||||
<SyntaxHighlighter
|
||||
customStyle={{ overflowX: "unset" }}
|
||||
customStyle={{ overflowX: "unset", width: "100%", flex: 1 }}
|
||||
language="json"
|
||||
style={docco}
|
||||
lineProps={{
|
||||
@@ -140,17 +128,22 @@ export default function OutputCell({
|
||||
{ maxLength: 40 },
|
||||
)}
|
||||
</SyntaxHighlighter>
|
||||
<OutputStats model={model} modelOutput={output} scenario={scenario} />
|
||||
</Box>
|
||||
</VStack>
|
||||
<OutputStats modelOutput={modelOutput} scenario={scenario} />
|
||||
</VStack>
|
||||
);
|
||||
}
|
||||
|
||||
const contentToDisplay = message?.content ?? streamedContent ?? JSON.stringify(output?.output);
|
||||
const contentToDisplay =
|
||||
message?.content ?? streamedContent ?? JSON.stringify(modelOutput?.output);
|
||||
|
||||
return (
|
||||
<Flex w="100%" h="100%" direction="column" justifyContent="space-between" whiteSpace="pre-wrap">
|
||||
{contentToDisplay}
|
||||
{output && <OutputStats model={model} modelOutput={output} scenario={scenario} />}
|
||||
</Flex>
|
||||
<VStack w="100%" h="100%" justifyContent="space-between" whiteSpace="pre-wrap">
|
||||
<VStack w="full" alignItems="flex-start" spacing={0}>
|
||||
<CellOptions refetchingOutput={refetchingOutput} refetchOutput={hardRefetch} />
|
||||
<Text>{contentToDisplay}</Text>
|
||||
</VStack>
|
||||
{modelOutput && <OutputStats modelOutput={modelOutput} scenario={scenario} />}
|
||||
</VStack>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,64 +1,56 @@
|
||||
import { type ModelOutput } from "@prisma/client";
|
||||
import { type SupportedModel } from "~/server/types";
|
||||
import { type Scenario } from "../types";
|
||||
import { useExperiment } from "~/utils/hooks";
|
||||
import { api } from "~/utils/api";
|
||||
import { calculateTokenCost } from "~/utils/calculateTokenCost";
|
||||
import { evaluateOutput } from "~/server/utils/evaluateOutput";
|
||||
import { HStack, Icon, Text } from "@chakra-ui/react";
|
||||
import { type RouterOutputs } from "~/utils/api";
|
||||
import { HStack, Icon, Text, Tooltip } from "@chakra-ui/react";
|
||||
import { BsCheck, BsClock, BsCurrencyDollar, BsX } from "react-icons/bs";
|
||||
import { CostTooltip } from "~/components/tooltip/CostTooltip";
|
||||
|
||||
const SHOW_COST = false;
|
||||
const SHOW_TIME = false;
|
||||
const SHOW_TIME = true;
|
||||
|
||||
export const OutputStats = ({
|
||||
model,
|
||||
modelOutput,
|
||||
scenario,
|
||||
}: {
|
||||
model: SupportedModel | null;
|
||||
modelOutput: ModelOutput;
|
||||
modelOutput: NonNullable<
|
||||
NonNullable<RouterOutputs["scenarioVariantCells"]["get"]>["modelOutput"]
|
||||
>;
|
||||
scenario: Scenario;
|
||||
}) => {
|
||||
const timeToComplete = modelOutput.timeToComplete;
|
||||
const experiment = useExperiment();
|
||||
const evals =
|
||||
api.evaluations.list.useQuery({ experimentId: experiment.data?.id ?? "" }).data ?? [];
|
||||
|
||||
const promptTokens = modelOutput.promptTokens;
|
||||
const completionTokens = modelOutput.completionTokens;
|
||||
|
||||
const promptCost = promptTokens && model ? calculateTokenCost(model, promptTokens) : 0;
|
||||
const completionCost =
|
||||
completionTokens && model ? calculateTokenCost(model, completionTokens, true) : 0;
|
||||
|
||||
const cost = promptCost + completionCost;
|
||||
|
||||
if (!evals.length) return null;
|
||||
|
||||
return (
|
||||
<HStack align="center" color="gray.500" fontSize="2xs" mt={{ base: 0, md: 1 }}>
|
||||
<HStack w="full" align="center" color="gray.500" fontSize="2xs" mt={{ base: 0, md: 1 }}>
|
||||
<HStack flex={1}>
|
||||
{evals.map((evaluation) => {
|
||||
const passed = evaluateOutput(modelOutput, scenario, evaluation);
|
||||
{modelOutput.outputEvaluation.map((evaluation) => {
|
||||
const passed = evaluation.result > 0.5;
|
||||
return (
|
||||
<HStack spacing={0} key={evaluation.id}>
|
||||
<Text>{evaluation.name}</Text>
|
||||
<Tooltip
|
||||
isDisabled={!evaluation.details}
|
||||
label={evaluation.details}
|
||||
key={evaluation.id}
|
||||
>
|
||||
<HStack spacing={0}>
|
||||
<Text>{evaluation.evaluation.label}</Text>
|
||||
<Icon
|
||||
as={passed ? BsCheck : BsX}
|
||||
color={passed ? "green.500" : "red.500"}
|
||||
boxSize={6}
|
||||
/>
|
||||
</HStack>
|
||||
</Tooltip>
|
||||
);
|
||||
})}
|
||||
</HStack>
|
||||
{SHOW_COST && (
|
||||
<CostTooltip promptTokens={promptTokens} completionTokens={completionTokens} cost={cost}>
|
||||
{modelOutput.cost && (
|
||||
<CostTooltip
|
||||
promptTokens={promptTokens}
|
||||
completionTokens={completionTokens}
|
||||
cost={modelOutput.cost}
|
||||
>
|
||||
<HStack spacing={0}>
|
||||
<Icon as={BsCurrencyDollar} />
|
||||
<Text mr={1}>{cost.toFixed(3)}</Text>
|
||||
<Text mr={1}>{modelOutput.cost.toFixed(3)}</Text>
|
||||
</HStack>
|
||||
</CostTooltip>
|
||||
)}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import { type DragEvent } from "react";
|
||||
import { api } from "~/utils/api";
|
||||
import { isEqual } from "lodash";
|
||||
import { isEqual } from "lodash-es";
|
||||
import { type Scenario } from "./types";
|
||||
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
|
||||
import { useExperiment, useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
|
||||
import { useState } from "react";
|
||||
|
||||
import { Box, Button, Flex, HStack, Icon, Spinner, Stack, Tooltip, VStack } from "@chakra-ui/react";
|
||||
@@ -13,11 +13,14 @@ import AutoResizeTextArea from "../AutoResizeTextArea";
|
||||
|
||||
export default function ScenarioEditor({
|
||||
scenario,
|
||||
hovered,
|
||||
...props
|
||||
}: {
|
||||
scenario: Scenario;
|
||||
hovered: boolean;
|
||||
canHide: boolean;
|
||||
}) {
|
||||
const { canModify } = useExperimentAccess();
|
||||
|
||||
const savedValues = scenario.variableValues as Record<string, string>;
|
||||
const utils = api.useContext();
|
||||
const [isDragTarget, setIsDragTarget] = useState(false);
|
||||
@@ -73,6 +76,7 @@ export default function ScenarioEditor({
|
||||
alignItems="flex-start"
|
||||
pr={cellPadding.x}
|
||||
py={cellPadding.y}
|
||||
pl={canModify ? 0 : cellPadding.x}
|
||||
height="100%"
|
||||
draggable={!variableInputHovered}
|
||||
onDragStart={(e) => {
|
||||
@@ -92,7 +96,10 @@ export default function ScenarioEditor({
|
||||
onDrop={onReorder}
|
||||
backgroundColor={isDragTarget ? "gray.100" : "transparent"}
|
||||
>
|
||||
<Stack alignSelf="flex-start" opacity={hovered ? 1 : 0} spacing={0}>
|
||||
{canModify && (
|
||||
<Stack alignSelf="flex-start" opacity={props.hovered ? 1 : 0} spacing={0}>
|
||||
{props.canHide && (
|
||||
<>
|
||||
<Tooltip label="Hide scenario" hasArrow>
|
||||
{/* for some reason the tooltip can't position itself properly relative to the icon without the wrapping box */}
|
||||
<Button
|
||||
@@ -116,7 +123,11 @@ export default function ScenarioEditor({
|
||||
color="gray.400"
|
||||
_hover={{ color: "gray.800", cursor: "pointer" }}
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
</Stack>
|
||||
)}
|
||||
|
||||
{variableLabels.length === 0 ? (
|
||||
<Box color="gray.500">{vars.data ? "No scenario variables configured" : "Loading..."}</Box>
|
||||
) : (
|
||||
@@ -150,6 +161,8 @@ export default function ScenarioEditor({
|
||||
fontSize="sm"
|
||||
lineHeight={1.2}
|
||||
value={value}
|
||||
isDisabled={!canModify}
|
||||
_disabled={{ opacity: 1, cursor: "default" }}
|
||||
onChange={(e) => {
|
||||
setValues((prev) => ({ ...prev, [key]: e.target.value }));
|
||||
}}
|
||||
|
||||
@@ -5,7 +5,11 @@ import OutputCell from "./OutputCell/OutputCell";
|
||||
import ScenarioEditor from "./ScenarioEditor";
|
||||
import type { PromptVariant, Scenario } from "./types";
|
||||
|
||||
const ScenarioRow = (props: { scenario: Scenario; variants: PromptVariant[] }) => {
|
||||
const ScenarioRow = (props: {
|
||||
scenario: Scenario;
|
||||
variants: PromptVariant[];
|
||||
canHide: boolean;
|
||||
}) => {
|
||||
const [isHovered, setIsHovered] = useState(false);
|
||||
|
||||
const highlightStyle = { backgroundColor: "gray.50" };
|
||||
@@ -18,7 +22,7 @@ const ScenarioRow = (props: { scenario: Scenario; variants: PromptVariant[] }) =
|
||||
sx={isHovered ? highlightStyle : undefined}
|
||||
borderLeftWidth={1}
|
||||
>
|
||||
<ScenarioEditor scenario={props.scenario} hovered={isHovered} />
|
||||
<ScenarioEditor scenario={props.scenario} hovered={isHovered} canHide={props.canHide} />
|
||||
</GridItem>
|
||||
{props.variants.map((variant) => (
|
||||
<GridItem
|
||||
|
||||
52
src/components/OutputsTable/ScenariosHeader.tsx
Normal file
52
src/components/OutputsTable/ScenariosHeader.tsx
Normal file
@@ -0,0 +1,52 @@
|
||||
import { Button, GridItem, HStack, Heading } from "@chakra-ui/react";
|
||||
import { cellPadding } from "../constants";
|
||||
import { useElementDimensions, useExperimentAccess } from "~/utils/hooks";
|
||||
import { stickyHeaderStyle } from "./styles";
|
||||
import { BsPencil } from "react-icons/bs";
|
||||
import { useAppStore } from "~/state/store";
|
||||
|
||||
export const ScenariosHeader = ({
|
||||
headerRows,
|
||||
numScenarios,
|
||||
}: {
|
||||
headerRows: number;
|
||||
numScenarios: number;
|
||||
}) => {
|
||||
const openDrawer = useAppStore((s) => s.openDrawer);
|
||||
const { canModify } = useExperimentAccess();
|
||||
|
||||
const [ref, dimensions] = useElementDimensions();
|
||||
const topValue = dimensions ? `-${dimensions.height - 24}px` : "-455px";
|
||||
|
||||
return (
|
||||
<GridItem
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
ref={ref as any}
|
||||
display="flex"
|
||||
alignItems="flex-end"
|
||||
rowSpan={headerRows}
|
||||
px={cellPadding.x}
|
||||
py={cellPadding.y}
|
||||
// Only display the part of the grid item that has content
|
||||
sx={{ ...stickyHeaderStyle, top: topValue }}
|
||||
>
|
||||
<HStack w="100%">
|
||||
<Heading size="xs" fontWeight="bold" flex={1}>
|
||||
Scenarios ({numScenarios})
|
||||
</Heading>
|
||||
{canModify && (
|
||||
<Button
|
||||
size="xs"
|
||||
variant="ghost"
|
||||
color="gray.500"
|
||||
aria-label="Edit"
|
||||
leftIcon={<BsPencil />}
|
||||
onClick={openDrawer}
|
||||
>
|
||||
Edit Vars
|
||||
</Button>
|
||||
)}
|
||||
</HStack>
|
||||
</GridItem>
|
||||
);
|
||||
};
|
||||
@@ -1,12 +1,12 @@
|
||||
import { Box, Button, HStack, Tooltip, useToast } from "@chakra-ui/react";
|
||||
import { Box, Button, HStack, Spinner, Tooltip, useToast, Text } from "@chakra-ui/react";
|
||||
import { useRef, useEffect, useState, useCallback } from "react";
|
||||
import { useHandledAsyncCallback, useModifierKeyLabel } from "~/utils/hooks";
|
||||
import { useExperimentAccess, useHandledAsyncCallback, useModifierKeyLabel } from "~/utils/hooks";
|
||||
import { type PromptVariant } from "./types";
|
||||
import { api } from "~/utils/api";
|
||||
import { useAppStore } from "~/state/store";
|
||||
// import openAITypes from "~/codegen/openai.types.ts.txt";
|
||||
|
||||
export default function VariantConfigEditor(props: { variant: PromptVariant }) {
|
||||
export default function VariantEditor(props: { variant: PromptVariant }) {
|
||||
const { canModify } = useExperimentAccess();
|
||||
const monaco = useAppStore.use.sharedVariantEditor.monaco();
|
||||
const editorRef = useRef<ReturnType<NonNullable<typeof monaco>["editor"]["create"]> | null>(null);
|
||||
const [editorId] = useState(() => `editor_${Math.random().toString(36).substring(7)}`);
|
||||
@@ -18,15 +18,23 @@ export default function VariantConfigEditor(props: { variant: PromptVariant }) {
|
||||
|
||||
const checkForChanges = useCallback(() => {
|
||||
if (!editorRef.current) return;
|
||||
const currentConfig = editorRef.current.getValue();
|
||||
setIsChanged(currentConfig !== lastSavedFn);
|
||||
const currentFn = editorRef.current.getValue();
|
||||
setIsChanged(currentFn.length > 0 && currentFn !== lastSavedFn);
|
||||
}, [lastSavedFn]);
|
||||
|
||||
const matchUpdatedSavedFn = useCallback(() => {
|
||||
if (!editorRef.current) return;
|
||||
editorRef.current.setValue(lastSavedFn);
|
||||
setIsChanged(false);
|
||||
}, [lastSavedFn]);
|
||||
|
||||
useEffect(matchUpdatedSavedFn, [matchUpdatedSavedFn, lastSavedFn]);
|
||||
|
||||
const replaceVariant = api.promptVariants.replaceVariant.useMutation();
|
||||
const utils = api.useContext();
|
||||
const toast = useToast();
|
||||
|
||||
const [onSave] = useHandledAsyncCallback(async () => {
|
||||
const [onSave, saveInProgress] = useHandledAsyncCallback(async () => {
|
||||
if (!editorRef.current) return;
|
||||
|
||||
await editorRef.current.getAction("editor.action.formatDocument")?.run();
|
||||
@@ -39,18 +47,6 @@ export default function VariantConfigEditor(props: { variant: PromptVariant }) {
|
||||
const model = editorRef.current.getModel();
|
||||
if (!model) return;
|
||||
|
||||
const markers = monaco?.editor.getModelMarkers({ resource: model.uri });
|
||||
const hasErrors = markers?.some((m) => m.severity === monaco?.MarkerSeverity.Error);
|
||||
|
||||
if (hasErrors) {
|
||||
toast({
|
||||
title: "Invalid TypeScript",
|
||||
description: "Please fix the TypeScript errors before saving.",
|
||||
status: "error",
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
// Make sure the user defined the prompt with the string "prompt\w*=" somewhere
|
||||
const promptRegex = /prompt\s*=/;
|
||||
if (!promptRegex.test(currentFn)) {
|
||||
@@ -64,14 +60,21 @@ export default function VariantConfigEditor(props: { variant: PromptVariant }) {
|
||||
return;
|
||||
}
|
||||
|
||||
await replaceVariant.mutateAsync({
|
||||
const resp = await replaceVariant.mutateAsync({
|
||||
id: props.variant.id,
|
||||
constructFn: currentFn,
|
||||
});
|
||||
if (resp.status === "error") {
|
||||
return toast({
|
||||
title: "Error saving variant",
|
||||
description: resp.message,
|
||||
status: "error",
|
||||
});
|
||||
}
|
||||
|
||||
setIsChanged(false);
|
||||
|
||||
await utils.promptVariants.list.invalidate();
|
||||
|
||||
checkForChanges();
|
||||
}, [checkForChanges]);
|
||||
|
||||
useEffect(() => {
|
||||
@@ -95,6 +98,7 @@ export default function VariantConfigEditor(props: { variant: PromptVariant }) {
|
||||
wordWrapBreakAfterCharacters: "",
|
||||
wordWrapBreakBeforeCharacters: "",
|
||||
quickSuggestions: true,
|
||||
readOnly: !canModify,
|
||||
});
|
||||
|
||||
editorRef.current.onDidFocusEditorText(() => {
|
||||
@@ -122,21 +126,16 @@ export default function VariantConfigEditor(props: { variant: PromptVariant }) {
|
||||
/* eslint-disable-next-line react-hooks/exhaustive-deps */
|
||||
}, [monaco, editorId]);
|
||||
|
||||
// useEffect(() => {
|
||||
// const savedConfigChanged = lastSavedFn !== savedConfig;
|
||||
|
||||
// lastSavedFn = savedConfig;
|
||||
|
||||
// if (savedConfigChanged && editorRef.current?.getValue() !== savedConfig) {
|
||||
// editorRef.current?.setValue(savedConfig);
|
||||
// }
|
||||
|
||||
// checkForChanges();
|
||||
// }, [savedConfig, checkForChanges]);
|
||||
useEffect(() => {
|
||||
if (!editorRef.current) return;
|
||||
editorRef.current.updateOptions({
|
||||
readOnly: !canModify,
|
||||
});
|
||||
}, [canModify]);
|
||||
|
||||
return (
|
||||
<Box w="100%" pos="relative">
|
||||
<div id={editorId} style={{ height: "300px", width: "100%" }}></div>
|
||||
<div id={editorId} style={{ height: "400px", width: "100%" }}></div>
|
||||
{isChanged && (
|
||||
<HStack pos="absolute" bottom={2} right={2}>
|
||||
<Button
|
||||
@@ -150,8 +149,8 @@ export default function VariantConfigEditor(props: { variant: PromptVariant }) {
|
||||
Reset
|
||||
</Button>
|
||||
<Tooltip label={`${modifierKey} + Enter`}>
|
||||
<Button size="sm" onClick={onSave} colorScheme="blue">
|
||||
Save
|
||||
<Button size="sm" onClick={onSave} colorScheme="blue" w={16} disabled={saveInProgress}>
|
||||
{saveInProgress ? <Spinner boxSize={4} /> : <Text>Save</Text>}
|
||||
</Button>
|
||||
</Tooltip>
|
||||
</HStack>
|
||||
|
||||
@@ -1,105 +0,0 @@
|
||||
import { useState, type DragEvent } from "react";
|
||||
import { type PromptVariant } from "./types";
|
||||
import { api } from "~/utils/api";
|
||||
import { useHandledAsyncCallback } from "~/utils/hooks";
|
||||
import { Button, HStack, Icon, Tooltip } from "@chakra-ui/react"; // Changed here
|
||||
import { BsX } from "react-icons/bs";
|
||||
import { RiDraggable } from "react-icons/ri";
|
||||
import { cellPadding, headerMinHeight } from "../constants";
|
||||
import AutoResizeTextArea from "../AutoResizeTextArea";
|
||||
|
||||
export default function VariantHeader(props: { variant: PromptVariant }) {
|
||||
const utils = api.useContext();
|
||||
const [isDragTarget, setIsDragTarget] = useState(false);
|
||||
const [isInputHovered, setIsInputHovered] = useState(false);
|
||||
const [label, setLabel] = useState(props.variant.label);
|
||||
|
||||
const updateMutation = api.promptVariants.update.useMutation();
|
||||
const [onSaveLabel] = useHandledAsyncCallback(async () => {
|
||||
if (label && label !== props.variant.label) {
|
||||
await updateMutation.mutateAsync({
|
||||
id: props.variant.id,
|
||||
updates: { label: label },
|
||||
});
|
||||
}
|
||||
}, [updateMutation, props.variant.id, props.variant.label, label]);
|
||||
|
||||
const hideMutation = api.promptVariants.hide.useMutation();
|
||||
const [onHide] = useHandledAsyncCallback(async () => {
|
||||
await hideMutation.mutateAsync({
|
||||
id: props.variant.id,
|
||||
});
|
||||
await utils.promptVariants.list.invalidate();
|
||||
}, [hideMutation, props.variant.id]);
|
||||
|
||||
const reorderMutation = api.promptVariants.reorder.useMutation();
|
||||
const [onReorder] = useHandledAsyncCallback(
|
||||
async (e: DragEvent<HTMLDivElement>) => {
|
||||
e.preventDefault();
|
||||
setIsDragTarget(false);
|
||||
const draggedId = e.dataTransfer.getData("text/plain");
|
||||
const droppedId = props.variant.id;
|
||||
if (!draggedId || !droppedId || draggedId === droppedId) return;
|
||||
await reorderMutation.mutateAsync({
|
||||
draggedId,
|
||||
droppedId,
|
||||
});
|
||||
await utils.promptVariants.list.invalidate();
|
||||
},
|
||||
[reorderMutation, props.variant.id],
|
||||
);
|
||||
|
||||
return (
|
||||
<HStack
|
||||
spacing={4}
|
||||
alignItems="center"
|
||||
minH={headerMinHeight}
|
||||
draggable={!isInputHovered}
|
||||
onDragStart={(e) => {
|
||||
e.dataTransfer.setData("text/plain", props.variant.id);
|
||||
e.currentTarget.style.opacity = "0.4";
|
||||
}}
|
||||
onDragEnd={(e) => {
|
||||
e.currentTarget.style.opacity = "1";
|
||||
}}
|
||||
onDragOver={(e) => {
|
||||
e.preventDefault();
|
||||
setIsDragTarget(true);
|
||||
}}
|
||||
onDragLeave={() => {
|
||||
setIsDragTarget(false);
|
||||
}}
|
||||
onDrop={onReorder}
|
||||
backgroundColor={isDragTarget ? "gray.100" : "transparent"}
|
||||
>
|
||||
<Icon
|
||||
as={RiDraggable}
|
||||
boxSize={6}
|
||||
color="gray.400"
|
||||
_hover={{ color: "gray.800", cursor: "pointer" }}
|
||||
/>
|
||||
<AutoResizeTextArea // Changed to Input
|
||||
size="sm"
|
||||
value={label}
|
||||
onChange={(e) => setLabel(e.target.value)}
|
||||
onBlur={onSaveLabel}
|
||||
placeholder="Variant Name"
|
||||
borderWidth={1}
|
||||
borderColor="transparent"
|
||||
fontWeight="bold"
|
||||
fontSize={16}
|
||||
_hover={{ borderColor: "gray.300" }}
|
||||
_focus={{ borderColor: "blue.500", outline: "none" }}
|
||||
flex={1}
|
||||
px={cellPadding.x}
|
||||
onMouseEnter={() => setIsInputHovered(true)}
|
||||
onMouseLeave={() => setIsInputHovered(false)}
|
||||
/>
|
||||
<Tooltip label="Hide Variant" hasArrow>
|
||||
<Button variant="ghost" colorScheme="gray" size="sm" onClick={onHide}>
|
||||
<Icon as={BsX} boxSize={6} />
|
||||
</Button>
|
||||
</Tooltip>
|
||||
</HStack>
|
||||
);
|
||||
}
|
||||
@@ -5,8 +5,10 @@ import { api } from "~/utils/api";
|
||||
import chroma from "chroma-js";
|
||||
import { BsCurrencyDollar } from "react-icons/bs";
|
||||
import { CostTooltip } from "../tooltip/CostTooltip";
|
||||
import { useEffect, useState } from "react";
|
||||
|
||||
export default function VariantStats(props: { variant: PromptVariant }) {
|
||||
const [refetchInterval, setRefetchInterval] = useState(0);
|
||||
const { data } = api.promptVariants.stats.useQuery(
|
||||
{
|
||||
variantId: props.variant.id,
|
||||
@@ -19,10 +21,18 @@ export default function VariantStats(props: { variant: PromptVariant }) {
|
||||
completionTokens: 0,
|
||||
scenarioCount: 0,
|
||||
outputCount: 0,
|
||||
awaitingRetrievals: false,
|
||||
},
|
||||
refetchInterval,
|
||||
},
|
||||
);
|
||||
|
||||
// Poll every two seconds while we are waiting for LLM retrievals to finish
|
||||
useEffect(
|
||||
() => setRefetchInterval(data.awaitingRetrievals ? 2000 : 0),
|
||||
[data.awaitingRetrievals],
|
||||
);
|
||||
|
||||
const [passColor, neutralColor, failColor] = useToken("colors", [
|
||||
"green.500",
|
||||
"gray.500",
|
||||
@@ -33,21 +43,25 @@ export default function VariantStats(props: { variant: PromptVariant }) {
|
||||
|
||||
const showNumFinished = data.scenarioCount > 0 && data.scenarioCount !== data.outputCount;
|
||||
|
||||
if (!(data.evalResults.length > 0) && !data.overallCost) return null;
|
||||
|
||||
return (
|
||||
<HStack justifyContent="space-between" alignItems="center" mx="2" fontSize="xs">
|
||||
<HStack
|
||||
justifyContent="space-between"
|
||||
alignItems="center"
|
||||
mx="2"
|
||||
fontSize="xs"
|
||||
py={cellPadding.y}
|
||||
>
|
||||
{showNumFinished && (
|
||||
<Text>
|
||||
{data.outputCount} / {data.scenarioCount}
|
||||
</Text>
|
||||
)}
|
||||
<HStack px={cellPadding.x} py={cellPadding.y}>
|
||||
<HStack px={cellPadding.x}>
|
||||
{data.evalResults.map((result) => {
|
||||
const passedFrac = result.passCount / (result.passCount + result.failCount);
|
||||
const passedFrac = result.passCount / result.totalCount;
|
||||
return (
|
||||
<HStack key={result.id}>
|
||||
<Text>{result.evaluation.name}</Text>
|
||||
<Text>{result.label}</Text>
|
||||
<Text color={scale(passedFrac).hex()} fontWeight="bold">
|
||||
{(passedFrac * 100).toFixed(1)}%
|
||||
</Text>
|
||||
@@ -55,13 +69,13 @@ export default function VariantStats(props: { variant: PromptVariant }) {
|
||||
);
|
||||
})}
|
||||
</HStack>
|
||||
{data.overallCost && (
|
||||
{data.overallCost && !data.awaitingRetrievals && (
|
||||
<CostTooltip
|
||||
promptTokens={data.promptTokens}
|
||||
completionTokens={data.completionTokens}
|
||||
cost={data.overallCost}
|
||||
>
|
||||
<HStack spacing={0} align="center" color="gray.500" my="2">
|
||||
<HStack spacing={0} align="center" color="gray.500">
|
||||
<Icon as={BsCurrencyDollar} />
|
||||
<Text mr={1}>{data.overallCost.toFixed(3)}</Text>
|
||||
</HStack>
|
||||
|
||||
@@ -1,28 +1,19 @@
|
||||
import { Button, Grid, GridItem, HStack, Heading, type SystemStyleObject } from "@chakra-ui/react";
|
||||
import { Grid, GridItem } from "@chakra-ui/react";
|
||||
import { api } from "~/utils/api";
|
||||
import NewScenarioButton from "./NewScenarioButton";
|
||||
import NewVariantButton from "./NewVariantButton";
|
||||
import ScenarioRow from "./ScenarioRow";
|
||||
import VariantConfigEditor from "./VariantEditor";
|
||||
import VariantHeader from "./VariantHeader";
|
||||
import { cellPadding } from "../constants";
|
||||
import { BsPencil } from "react-icons/bs";
|
||||
import VariantEditor from "./VariantEditor";
|
||||
import VariantHeader from "../VariantHeader/VariantHeader";
|
||||
import VariantStats from "./VariantStats";
|
||||
import { useAppStore } from "~/state/store";
|
||||
|
||||
const stickyHeaderStyle: SystemStyleObject = {
|
||||
position: "sticky",
|
||||
top: "-1px",
|
||||
backgroundColor: "#fff",
|
||||
zIndex: 1,
|
||||
};
|
||||
import { ScenariosHeader } from "./ScenariosHeader";
|
||||
import { stickyHeaderStyle } from "./styles";
|
||||
|
||||
export default function OutputsTable({ experimentId }: { experimentId: string | undefined }) {
|
||||
const variants = api.promptVariants.list.useQuery(
|
||||
{ experimentId: experimentId as string },
|
||||
{ enabled: !!experimentId },
|
||||
);
|
||||
const openDrawer = useAppStore((s) => s.openDrawer);
|
||||
|
||||
const scenarios = api.scenarios.list.useQuery(
|
||||
{ experimentId: experimentId as string },
|
||||
@@ -49,37 +40,10 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
|
||||
}}
|
||||
fontSize="sm"
|
||||
>
|
||||
<GridItem
|
||||
display="flex"
|
||||
alignItems="flex-end"
|
||||
rowSpan={headerRows}
|
||||
px={cellPadding.x}
|
||||
py={cellPadding.y}
|
||||
// TODO: This is a hack to get the sticky header to work. It's not ideal because it's not responsive to the height of the header,
|
||||
// so if the header height changes, this will need to be updated.
|
||||
sx={{ ...stickyHeaderStyle, top: "-337px" }}
|
||||
>
|
||||
<HStack w="100%">
|
||||
<Heading size="xs" fontWeight="bold" flex={1}>
|
||||
Scenarios ({scenarios.data.length})
|
||||
</Heading>
|
||||
<Button
|
||||
size="xs"
|
||||
variant="ghost"
|
||||
color="gray.500"
|
||||
aria-label="Edit"
|
||||
leftIcon={<BsPencil />}
|
||||
onClick={openDrawer}
|
||||
>
|
||||
Edit Vars
|
||||
</Button>
|
||||
</HStack>
|
||||
</GridItem>
|
||||
<ScenariosHeader headerRows={headerRows} numScenarios={scenarios.data.length} />
|
||||
|
||||
{variants.data.map((variant) => (
|
||||
<GridItem key={variant.uiId} padding={0} sx={stickyHeaderStyle} borderTopWidth={1}>
|
||||
<VariantHeader variant={variant} />
|
||||
</GridItem>
|
||||
<VariantHeader key={variant.uiId} variant={variant} canHide={variants.data.length > 1} />
|
||||
))}
|
||||
<GridItem
|
||||
rowSpan={scenarios.data.length + headerRows}
|
||||
@@ -94,7 +58,7 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
|
||||
|
||||
{variants.data.map((variant) => (
|
||||
<GridItem key={variant.uiId}>
|
||||
<VariantConfigEditor variant={variant} />
|
||||
<VariantEditor variant={variant} />
|
||||
</GridItem>
|
||||
))}
|
||||
{variants.data.map((variant) => (
|
||||
@@ -103,7 +67,12 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
|
||||
</GridItem>
|
||||
))}
|
||||
{scenarios.data.map((scenario) => (
|
||||
<ScenarioRow key={scenario.uiId} scenario={scenario} variants={variants.data} />
|
||||
<ScenarioRow
|
||||
key={scenario.uiId}
|
||||
scenario={scenario}
|
||||
variants={variants.data}
|
||||
canHide={scenarios.data.length > 1}
|
||||
/>
|
||||
))}
|
||||
<GridItem borderBottomWidth={0} borderRightWidth={0} w="100%" colSpan={allCols} padding={0}>
|
||||
<NewScenarioButton />
|
||||
|
||||
8
src/components/OutputsTable/styles.ts
Normal file
8
src/components/OutputsTable/styles.ts
Normal file
@@ -0,0 +1,8 @@
|
||||
import { type SystemStyleObject } from "@chakra-ui/react";
|
||||
|
||||
export const stickyHeaderStyle: SystemStyleObject = {
|
||||
position: "sticky",
|
||||
top: "0",
|
||||
backgroundColor: "#fff",
|
||||
zIndex: 1,
|
||||
};
|
||||
@@ -1,21 +0,0 @@
|
||||
import { Flex, Icon, Link, Text } from "@chakra-ui/react";
|
||||
import { BsExclamationTriangleFill } from "react-icons/bs";
|
||||
import { env } from "~/env.mjs";
|
||||
|
||||
export default function PublicPlaygroundWarning() {
|
||||
if (!env.NEXT_PUBLIC_IS_PUBLIC_PLAYGROUND) return null;
|
||||
|
||||
return (
|
||||
<Flex bgColor="red.600" color="whiteAlpha.900" p={2} align="center">
|
||||
<Icon boxSize={4} mr={2} as={BsExclamationTriangleFill} />
|
||||
<Text>
|
||||
Warning: this is a public playground. Anyone can see, edit or delete your experiments. For
|
||||
private use,{" "}
|
||||
<Link textDecor="underline" href="https://github.com/openpipe/openpipe" target="_blank">
|
||||
run a local copy
|
||||
</Link>
|
||||
.
|
||||
</Text>
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
45
src/components/RefinePromptModal/CompareFunctions.tsx
Normal file
45
src/components/RefinePromptModal/CompareFunctions.tsx
Normal file
@@ -0,0 +1,45 @@
|
||||
import { HStack, VStack } from "@chakra-ui/react";
|
||||
import React from "react";
|
||||
import DiffViewer, { DiffMethod } from "react-diff-viewer";
|
||||
import Prism from "prismjs";
|
||||
import "prismjs/components/prism-javascript";
|
||||
import "prismjs/themes/prism.css"; // choose a theme you like
|
||||
|
||||
const CompareFunctions = ({
|
||||
originalFunction,
|
||||
newFunction = "",
|
||||
}: {
|
||||
originalFunction: string;
|
||||
newFunction?: string;
|
||||
}) => {
|
||||
console.log("newFunction", newFunction);
|
||||
const highlightSyntax = (str: string) => {
|
||||
let highlighted;
|
||||
try {
|
||||
highlighted = Prism.highlight(str, Prism.languages.javascript as Prism.Grammar, "javascript");
|
||||
} catch (e) {
|
||||
console.error("Error highlighting:", e);
|
||||
highlighted = str;
|
||||
}
|
||||
return <pre style={{ display: "inline" }} dangerouslySetInnerHTML={{ __html: highlighted }} />;
|
||||
};
|
||||
return (
|
||||
<HStack w="full" spacing={5}>
|
||||
<VStack w="full" spacing={4} maxH="65vh" fontSize={12} lineHeight={1} overflowY="auto">
|
||||
<DiffViewer
|
||||
oldValue={originalFunction}
|
||||
newValue={newFunction || originalFunction}
|
||||
splitView={true}
|
||||
hideLineNumbers={true}
|
||||
leftTitle="Original"
|
||||
rightTitle={newFunction ? "Modified" : "Unmodified"}
|
||||
disableWordDiff={true}
|
||||
compareMethod={DiffMethod.CHARS}
|
||||
renderContent={highlightSyntax}
|
||||
/>
|
||||
</VStack>
|
||||
</HStack>
|
||||
);
|
||||
};
|
||||
|
||||
export default CompareFunctions;
|
||||
103
src/components/RefinePromptModal/RefinePromptModal.tsx
Normal file
103
src/components/RefinePromptModal/RefinePromptModal.tsx
Normal file
@@ -0,0 +1,103 @@
|
||||
import {
|
||||
Button,
|
||||
Modal,
|
||||
ModalBody,
|
||||
ModalCloseButton,
|
||||
ModalContent,
|
||||
ModalFooter,
|
||||
ModalHeader,
|
||||
ModalOverlay,
|
||||
VStack,
|
||||
Text,
|
||||
Spinner,
|
||||
HStack,
|
||||
} from "@chakra-ui/react";
|
||||
import { api } from "~/utils/api";
|
||||
import { useHandledAsyncCallback } from "~/utils/hooks";
|
||||
import { type PromptVariant } from "@prisma/client";
|
||||
import { useState } from "react";
|
||||
import AutoResizeTextArea from "../AutoResizeTextArea";
|
||||
import CompareFunctions from "./CompareFunctions";
|
||||
|
||||
export const RefinePromptModal = ({
|
||||
variant,
|
||||
onClose,
|
||||
}: {
|
||||
variant: PromptVariant;
|
||||
onClose: () => void;
|
||||
}) => {
|
||||
const utils = api.useContext();
|
||||
|
||||
const { mutateAsync: getRefinedPromptMutateAsync, data: refinedPromptFn } =
|
||||
api.promptVariants.getRefinedPromptFn.useMutation();
|
||||
const [instructions, setInstructions] = useState<string>("");
|
||||
|
||||
const [getRefinedPromptFn, refiningInProgress] = useHandledAsyncCallback(async () => {
|
||||
if (!variant.experimentId) return;
|
||||
await getRefinedPromptMutateAsync({
|
||||
id: variant.id,
|
||||
instructions,
|
||||
});
|
||||
}, [getRefinedPromptMutateAsync, onClose, variant, instructions]);
|
||||
|
||||
const replaceVariantMutation = api.promptVariants.replaceVariant.useMutation();
|
||||
|
||||
const [replaceVariant, replacementInProgress] = useHandledAsyncCallback(async () => {
|
||||
if (!variant.experimentId || !refinedPromptFn) return;
|
||||
await replaceVariantMutation.mutateAsync({
|
||||
id: variant.id,
|
||||
constructFn: refinedPromptFn,
|
||||
});
|
||||
await utils.promptVariants.list.invalidate();
|
||||
onClose();
|
||||
}, [replaceVariantMutation, variant, onClose, refinedPromptFn]);
|
||||
|
||||
return (
|
||||
<Modal isOpen onClose={onClose} size={{ base: "xl", sm: "2xl", md: "7xl" }}>
|
||||
<ModalOverlay />
|
||||
<ModalContent w={1200}>
|
||||
<ModalHeader>Refine Your Prompt</ModalHeader>
|
||||
<ModalCloseButton />
|
||||
<ModalBody maxW="unset">
|
||||
<VStack spacing={8}>
|
||||
<HStack w="full">
|
||||
<AutoResizeTextArea
|
||||
value={instructions}
|
||||
onChange={(e) => setInstructions(e.target.value)}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === "Enter" && !e.metaKey && !e.ctrlKey && !e.shiftKey) {
|
||||
e.preventDefault();
|
||||
e.currentTarget.blur();
|
||||
getRefinedPromptFn();
|
||||
}
|
||||
}}
|
||||
placeholder="Use chain of thought"
|
||||
/>
|
||||
<Button onClick={getRefinedPromptFn}>
|
||||
{refiningInProgress ? <Spinner boxSize={4} /> : <Text>Submit</Text>}
|
||||
</Button>
|
||||
</HStack>
|
||||
<CompareFunctions
|
||||
originalFunction={variant.constructFn}
|
||||
newFunction={refinedPromptFn}
|
||||
/>
|
||||
</VStack>
|
||||
</ModalBody>
|
||||
|
||||
<ModalFooter>
|
||||
<HStack spacing={4}>
|
||||
<Button onClick={onClose}>Cancel</Button>
|
||||
<Button
|
||||
colorScheme="blue"
|
||||
onClick={replaceVariant}
|
||||
minW={24}
|
||||
disabled={!refinedPromptFn}
|
||||
>
|
||||
{replacementInProgress ? <Spinner boxSize={4} /> : <Text>Accept</Text>}
|
||||
</Button>
|
||||
</HStack>
|
||||
</ModalFooter>
|
||||
</ModalContent>
|
||||
</Modal>
|
||||
);
|
||||
};
|
||||
89
src/components/SelectModelModal/ModelStatsCard.tsx
Normal file
89
src/components/SelectModelModal/ModelStatsCard.tsx
Normal file
@@ -0,0 +1,89 @@
|
||||
import {
|
||||
VStack,
|
||||
Text,
|
||||
HStack,
|
||||
type StackProps,
|
||||
GridItem,
|
||||
SimpleGrid,
|
||||
Link,
|
||||
} from "@chakra-ui/react";
|
||||
import { modelStats } from "~/server/modelStats";
|
||||
import { type SupportedModel } from "~/server/types";
|
||||
|
||||
export const ModelStatsCard = ({ label, model }: { label: string; model: SupportedModel }) => {
|
||||
const stats = modelStats[model];
|
||||
return (
|
||||
<VStack w="full" align="start">
|
||||
<Text fontWeight="bold" fontSize="sm" textTransform="uppercase">
|
||||
{label}
|
||||
</Text>
|
||||
|
||||
<VStack w="full" spacing={6} bgColor="gray.100" p={4} borderRadius={4}>
|
||||
<HStack w="full" align="flex-start">
|
||||
<Text flex={1} fontSize="lg">
|
||||
<Text as="span" color="gray.600">
|
||||
{stats.provider} /{" "}
|
||||
</Text>
|
||||
<Text as="span" fontWeight="bold" color="gray.900">
|
||||
{model}
|
||||
</Text>
|
||||
</Text>
|
||||
<Link
|
||||
href={stats.learnMoreUrl}
|
||||
isExternal
|
||||
color="blue.500"
|
||||
fontWeight="bold"
|
||||
fontSize="sm"
|
||||
ml={2}
|
||||
>
|
||||
Learn More
|
||||
</Link>
|
||||
</HStack>
|
||||
<SimpleGrid
|
||||
w="full"
|
||||
justifyContent="space-between"
|
||||
alignItems="flex-start"
|
||||
fontSize="sm"
|
||||
columns={{ base: 2, md: 4 }}
|
||||
>
|
||||
<SelectedModelLabeledInfo label="Context" info={stats.contextLength} />
|
||||
<SelectedModelLabeledInfo
|
||||
label="Input"
|
||||
info={
|
||||
<Text>
|
||||
${(stats.promptTokenPrice * 1000).toFixed(3)}
|
||||
<Text color="gray.500"> / 1K tokens</Text>
|
||||
</Text>
|
||||
}
|
||||
/>
|
||||
<SelectedModelLabeledInfo
|
||||
label="Output"
|
||||
info={
|
||||
<Text>
|
||||
${(stats.promptTokenPrice * 1000).toFixed(3)}
|
||||
<Text color="gray.500"> / 1K tokens</Text>
|
||||
</Text>
|
||||
}
|
||||
/>
|
||||
<SelectedModelLabeledInfo label="Speed" info={<Text>{stats.speed}</Text>} />
|
||||
</SimpleGrid>
|
||||
</VStack>
|
||||
</VStack>
|
||||
);
|
||||
};
|
||||
|
||||
const SelectedModelLabeledInfo = ({
|
||||
label,
|
||||
info,
|
||||
...props
|
||||
}: {
|
||||
label: string;
|
||||
info: string | number | React.ReactElement;
|
||||
} & StackProps) => (
|
||||
<GridItem>
|
||||
<VStack alignItems="flex-start" {...props}>
|
||||
<Text fontWeight="bold">{label}</Text>
|
||||
<Text>{info}</Text>
|
||||
</VStack>
|
||||
</GridItem>
|
||||
);
|
||||
77
src/components/SelectModelModal/SelectModelModal.tsx
Normal file
77
src/components/SelectModelModal/SelectModelModal.tsx
Normal file
@@ -0,0 +1,77 @@
|
||||
import {
|
||||
Button,
|
||||
Modal,
|
||||
ModalBody,
|
||||
ModalCloseButton,
|
||||
ModalContent,
|
||||
ModalFooter,
|
||||
ModalHeader,
|
||||
ModalOverlay,
|
||||
VStack,
|
||||
Text,
|
||||
Spinner,
|
||||
} from "@chakra-ui/react";
|
||||
import { useState } from "react";
|
||||
import { type SupportedModel } from "~/server/types";
|
||||
import { ModelStatsCard } from "./ModelStatsCard";
|
||||
import { SelectModelSearch } from "./SelectModelSearch";
|
||||
import { api } from "~/utils/api";
|
||||
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
|
||||
|
||||
export const SelectModelModal = ({
|
||||
originalModel,
|
||||
variantId,
|
||||
onClose,
|
||||
}: {
|
||||
originalModel: SupportedModel;
|
||||
variantId: string;
|
||||
onClose: () => void;
|
||||
}) => {
|
||||
const [selectedModel, setSelectedModel] = useState<SupportedModel>(originalModel);
|
||||
const utils = api.useContext();
|
||||
|
||||
const experiment = useExperiment();
|
||||
|
||||
const createMutation = api.promptVariants.create.useMutation();
|
||||
|
||||
const [createNewVariant, creationInProgress] = useHandledAsyncCallback(async () => {
|
||||
if (!experiment?.data?.id) return;
|
||||
await createMutation.mutateAsync({
|
||||
experimentId: experiment?.data?.id,
|
||||
variantId,
|
||||
newModel: selectedModel,
|
||||
});
|
||||
await utils.promptVariants.list.invalidate();
|
||||
onClose();
|
||||
}, [createMutation, experiment?.data?.id, variantId, onClose]);
|
||||
|
||||
return (
|
||||
<Modal isOpen onClose={onClose} size={{ base: "xl", sm: "2xl", md: "3xl" }}>
|
||||
<ModalOverlay />
|
||||
<ModalContent w={1200}>
|
||||
<ModalHeader>Select a New Model</ModalHeader>
|
||||
<ModalCloseButton />
|
||||
<ModalBody maxW="unset">
|
||||
<VStack spacing={8}>
|
||||
<ModelStatsCard label="Original Model" model={originalModel} />
|
||||
{originalModel !== selectedModel && (
|
||||
<ModelStatsCard label="New Model" model={selectedModel} />
|
||||
)}
|
||||
<SelectModelSearch selectedModel={selectedModel} setSelectedModel={setSelectedModel} />
|
||||
</VStack>
|
||||
</ModalBody>
|
||||
|
||||
<ModalFooter>
|
||||
<Button
|
||||
colorScheme="blue"
|
||||
onClick={createNewVariant}
|
||||
minW={24}
|
||||
disabled={originalModel === selectedModel}
|
||||
>
|
||||
{creationInProgress ? <Spinner boxSize={4} /> : <Text>Continue</Text>}
|
||||
</Button>
|
||||
</ModalFooter>
|
||||
</ModalContent>
|
||||
</Modal>
|
||||
);
|
||||
};
|
||||
47
src/components/SelectModelModal/SelectModelSearch.tsx
Normal file
47
src/components/SelectModelModal/SelectModelSearch.tsx
Normal file
@@ -0,0 +1,47 @@
|
||||
import { VStack, Text } from "@chakra-ui/react";
|
||||
import { type LegacyRef, useCallback } from "react";
|
||||
import Select, { type SingleValue } from "react-select";
|
||||
import { type SupportedModel } from "~/server/types";
|
||||
import { useElementDimensions } from "~/utils/hooks";
|
||||
|
||||
const modelOptions: { value: SupportedModel; label: string }[] = [
|
||||
{ value: "gpt-3.5-turbo", label: "gpt-3.5-turbo" },
|
||||
{ value: "gpt-3.5-turbo-0613", label: "gpt-3.5-turbo-0613" },
|
||||
{ value: "gpt-3.5-turbo-16k", label: "gpt-3.5-turbo-16k" },
|
||||
{ value: "gpt-3.5-turbo-16k-0613", label: "gpt-3.5-turbo-16k-0613" },
|
||||
{ value: "gpt-4", label: "gpt-4" },
|
||||
{ value: "gpt-4-0613", label: "gpt-4-0613" },
|
||||
{ value: "gpt-4-32k", label: "gpt-4-32k" },
|
||||
{ value: "gpt-4-32k-0613", label: "gpt-4-32k-0613" },
|
||||
];
|
||||
|
||||
export const SelectModelSearch = ({
|
||||
selectedModel,
|
||||
setSelectedModel,
|
||||
}: {
|
||||
selectedModel: SupportedModel;
|
||||
setSelectedModel: (model: SupportedModel) => void;
|
||||
}) => {
|
||||
const handleSelection = useCallback(
|
||||
(option: SingleValue<{ value: SupportedModel; label: string }>) => {
|
||||
if (!option) return;
|
||||
setSelectedModel(option.value);
|
||||
},
|
||||
[setSelectedModel],
|
||||
);
|
||||
const selectedOption = modelOptions.find((option) => option.value === selectedModel);
|
||||
|
||||
const [containerRef, containerDimensions] = useElementDimensions();
|
||||
|
||||
return (
|
||||
<VStack ref={containerRef as LegacyRef<HTMLDivElement>} w="full">
|
||||
<Text>Browse Models</Text>
|
||||
<Select
|
||||
styles={{ control: (provided) => ({ ...provided, width: containerDimensions?.width }) }}
|
||||
value={selectedOption}
|
||||
options={modelOptions}
|
||||
onChange={handleSelection}
|
||||
/>
|
||||
</VStack>
|
||||
);
|
||||
};
|
||||
123
src/components/VariantHeader/VariantHeader.tsx
Normal file
123
src/components/VariantHeader/VariantHeader.tsx
Normal file
@@ -0,0 +1,123 @@
|
||||
import { useState, type DragEvent } from "react";
|
||||
import { type PromptVariant } from "../OutputsTable/types";
|
||||
import { api } from "~/utils/api";
|
||||
import { RiDraggable } from "react-icons/ri";
|
||||
import { useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
|
||||
import { HStack, Icon, Text, GridItem } from "@chakra-ui/react"; // Changed here
|
||||
import { cellPadding, headerMinHeight } from "../constants";
|
||||
import AutoResizeTextArea from "../AutoResizeTextArea";
|
||||
import { stickyHeaderStyle } from "../OutputsTable/styles";
|
||||
import VariantHeaderMenuButton from "./VariantHeaderMenuButton";
|
||||
|
||||
export default function VariantHeader(props: { variant: PromptVariant; canHide: boolean }) {
|
||||
const { canModify } = useExperimentAccess();
|
||||
const utils = api.useContext();
|
||||
const [isDragTarget, setIsDragTarget] = useState(false);
|
||||
const [isInputHovered, setIsInputHovered] = useState(false);
|
||||
const [label, setLabel] = useState(props.variant.label);
|
||||
|
||||
const updateMutation = api.promptVariants.update.useMutation();
|
||||
const [onSaveLabel] = useHandledAsyncCallback(async () => {
|
||||
if (label && label !== props.variant.label) {
|
||||
await updateMutation.mutateAsync({
|
||||
id: props.variant.id,
|
||||
updates: { label: label },
|
||||
});
|
||||
}
|
||||
}, [updateMutation, props.variant.id, props.variant.label, label]);
|
||||
|
||||
const reorderMutation = api.promptVariants.reorder.useMutation();
|
||||
const [onReorder] = useHandledAsyncCallback(
|
||||
async (e: DragEvent<HTMLDivElement>) => {
|
||||
e.preventDefault();
|
||||
setIsDragTarget(false);
|
||||
const draggedId = e.dataTransfer.getData("text/plain");
|
||||
const droppedId = props.variant.id;
|
||||
if (!draggedId || !droppedId || draggedId === droppedId) return;
|
||||
await reorderMutation.mutateAsync({
|
||||
draggedId,
|
||||
droppedId,
|
||||
});
|
||||
await utils.promptVariants.list.invalidate();
|
||||
},
|
||||
[reorderMutation, props.variant.id],
|
||||
);
|
||||
|
||||
const [menuOpen, setMenuOpen] = useState(false);
|
||||
|
||||
if (!canModify) {
|
||||
return (
|
||||
<GridItem padding={0} sx={stickyHeaderStyle} borderTopWidth={1}>
|
||||
<Text fontSize={16} fontWeight="bold" px={cellPadding.x} py={cellPadding.y}>
|
||||
{props.variant.label}
|
||||
</Text>
|
||||
</GridItem>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<GridItem
|
||||
padding={0}
|
||||
sx={{
|
||||
...stickyHeaderStyle,
|
||||
// Ensure that the menu always appears above the sticky header of other variants
|
||||
zIndex: menuOpen ? "dropdown" : stickyHeaderStyle.zIndex,
|
||||
}}
|
||||
borderTopWidth={1}
|
||||
>
|
||||
<HStack
|
||||
spacing={4}
|
||||
alignItems="flex-start"
|
||||
minH={headerMinHeight}
|
||||
draggable={!isInputHovered}
|
||||
onDragStart={(e) => {
|
||||
e.dataTransfer.setData("text/plain", props.variant.id);
|
||||
e.currentTarget.style.opacity = "0.4";
|
||||
}}
|
||||
onDragEnd={(e) => {
|
||||
e.currentTarget.style.opacity = "1";
|
||||
}}
|
||||
onDragOver={(e) => {
|
||||
e.preventDefault();
|
||||
setIsDragTarget(true);
|
||||
}}
|
||||
onDragLeave={() => {
|
||||
setIsDragTarget(false);
|
||||
}}
|
||||
onDrop={onReorder}
|
||||
backgroundColor={isDragTarget ? "gray.100" : "transparent"}
|
||||
>
|
||||
<Icon
|
||||
as={RiDraggable}
|
||||
boxSize={6}
|
||||
mt={2}
|
||||
color="gray.400"
|
||||
_hover={{ color: "gray.800", cursor: "pointer" }}
|
||||
/>
|
||||
<AutoResizeTextArea
|
||||
size="sm"
|
||||
value={label}
|
||||
onChange={(e) => setLabel(e.target.value)}
|
||||
onBlur={onSaveLabel}
|
||||
placeholder="Variant Name"
|
||||
borderWidth={1}
|
||||
borderColor="transparent"
|
||||
fontWeight="bold"
|
||||
fontSize={16}
|
||||
_hover={{ borderColor: "gray.300" }}
|
||||
_focus={{ borderColor: "blue.500", outline: "none" }}
|
||||
flex={1}
|
||||
px={cellPadding.x}
|
||||
onMouseEnter={() => setIsInputHovered(true)}
|
||||
onMouseLeave={() => setIsInputHovered(false)}
|
||||
/>
|
||||
<VariantHeaderMenuButton
|
||||
variant={props.variant}
|
||||
canHide={props.canHide}
|
||||
menuOpen={menuOpen}
|
||||
setMenuOpen={setMenuOpen}
|
||||
/>
|
||||
</HStack>
|
||||
</GridItem>
|
||||
);
|
||||
}
|
||||
114
src/components/VariantHeader/VariantHeaderMenuButton.tsx
Normal file
114
src/components/VariantHeader/VariantHeaderMenuButton.tsx
Normal file
@@ -0,0 +1,114 @@
|
||||
import { type PromptVariant } from "../OutputsTable/types";
|
||||
import { api } from "~/utils/api";
|
||||
import { useHandledAsyncCallback } from "~/utils/hooks";
|
||||
import {
|
||||
Button,
|
||||
Icon,
|
||||
Menu,
|
||||
MenuButton,
|
||||
MenuItem,
|
||||
MenuList,
|
||||
MenuDivider,
|
||||
Text,
|
||||
Spinner,
|
||||
} from "@chakra-ui/react";
|
||||
import { BsFillTrashFill, BsGear } from "react-icons/bs";
|
||||
import { FaRegClone } from "react-icons/fa";
|
||||
import { AiOutlineDiff } from "react-icons/ai";
|
||||
import { useState } from "react";
|
||||
import { RefinePromptModal } from "../RefinePromptModal/RefinePromptModal";
|
||||
import { RiExchangeFundsFill } from "react-icons/ri";
|
||||
import { SelectModelModal } from "../SelectModelModal/SelectModelModal";
|
||||
import { type SupportedModel } from "~/server/types";
|
||||
|
||||
export default function VariantHeaderMenuButton({
|
||||
variant,
|
||||
canHide,
|
||||
menuOpen,
|
||||
setMenuOpen,
|
||||
}: {
|
||||
variant: PromptVariant;
|
||||
canHide: boolean;
|
||||
menuOpen: boolean;
|
||||
setMenuOpen: (open: boolean) => void;
|
||||
}) {
|
||||
const utils = api.useContext();
|
||||
|
||||
const duplicateMutation = api.promptVariants.create.useMutation();
|
||||
|
||||
const [duplicateVariant, duplicationInProgress] = useHandledAsyncCallback(async () => {
|
||||
await duplicateMutation.mutateAsync({
|
||||
experimentId: variant.experimentId,
|
||||
variantId: variant.id,
|
||||
});
|
||||
await utils.promptVariants.list.invalidate();
|
||||
}, [duplicateMutation, variant.experimentId, variant.id]);
|
||||
|
||||
const hideMutation = api.promptVariants.hide.useMutation();
|
||||
const [onHide] = useHandledAsyncCallback(async () => {
|
||||
await hideMutation.mutateAsync({
|
||||
id: variant.id,
|
||||
});
|
||||
await utils.promptVariants.list.invalidate();
|
||||
}, [hideMutation, variant.id]);
|
||||
|
||||
const [selectModelModalOpen, setSelectModelModalOpen] = useState(false);
|
||||
const [refinePromptModalOpen, setRefinePromptModalOpen] = useState(false);
|
||||
|
||||
return (
|
||||
<>
|
||||
<Menu isOpen={menuOpen} onOpen={() => setMenuOpen(true)} onClose={() => setMenuOpen(false)}>
|
||||
{duplicationInProgress ? (
|
||||
<Spinner boxSize={4} mx={3} my={3} />
|
||||
) : (
|
||||
<MenuButton>
|
||||
<Button variant="ghost">
|
||||
<Icon as={BsGear} />
|
||||
</Button>
|
||||
</MenuButton>
|
||||
)}
|
||||
|
||||
<MenuList mt={-3} fontSize="md">
|
||||
<MenuItem icon={<Icon as={FaRegClone} boxSize={4} w={5} />} onClick={duplicateVariant}>
|
||||
Duplicate
|
||||
</MenuItem>
|
||||
<MenuItem
|
||||
icon={<Icon as={RiExchangeFundsFill} boxSize={5} />}
|
||||
onClick={() => setSelectModelModalOpen(true)}
|
||||
>
|
||||
Change Model
|
||||
</MenuItem>
|
||||
<MenuItem
|
||||
icon={<Icon as={AiOutlineDiff} boxSize={5} />}
|
||||
onClick={() => setRefinePromptModalOpen(true)}
|
||||
>
|
||||
Refine
|
||||
</MenuItem>
|
||||
{canHide && (
|
||||
<>
|
||||
<MenuDivider />
|
||||
<MenuItem
|
||||
onClick={onHide}
|
||||
icon={<Icon as={BsFillTrashFill} boxSize={5} />}
|
||||
color="red.600"
|
||||
_hover={{ backgroundColor: "red.50" }}
|
||||
>
|
||||
<Text>Hide</Text>
|
||||
</MenuItem>
|
||||
</>
|
||||
)}
|
||||
</MenuList>
|
||||
</Menu>
|
||||
{selectModelModalOpen && (
|
||||
<SelectModelModal
|
||||
originalModel={variant.model as SupportedModel}
|
||||
variantId={variant.id}
|
||||
onClose={() => setSelectModelModalOpen(false)}
|
||||
/>
|
||||
)}
|
||||
{refinePromptModalOpen && (
|
||||
<RefinePromptModal variant={variant} onClose={() => setRefinePromptModalOpen(false)} />
|
||||
)}
|
||||
</>
|
||||
);
|
||||
}
|
||||
@@ -1,17 +1,11 @@
|
||||
import {
|
||||
Card,
|
||||
CardBody,
|
||||
HStack,
|
||||
Icon,
|
||||
VStack,
|
||||
Text,
|
||||
CardHeader,
|
||||
Divider,
|
||||
Box,
|
||||
} from "@chakra-ui/react";
|
||||
import { HStack, Icon, VStack, Text, Divider, Spinner, AspectRatio } from "@chakra-ui/react";
|
||||
import { RiFlaskLine } from "react-icons/ri";
|
||||
import { formatTimePast } from "~/utils/dayjs";
|
||||
import Link from "next/link";
|
||||
import { useRouter } from "next/router";
|
||||
import { BsPlusSquare } from "react-icons/bs";
|
||||
import { api } from "~/utils/api";
|
||||
import { useHandledAsyncCallback } from "~/utils/hooks";
|
||||
|
||||
type ExperimentData = {
|
||||
testScenarioCount: number;
|
||||
@@ -24,47 +18,42 @@ type ExperimentData = {
|
||||
};
|
||||
|
||||
export const ExperimentCard = ({ exp }: { exp: ExperimentData }) => {
|
||||
const router = useRouter();
|
||||
return (
|
||||
<Box
|
||||
as={Card}
|
||||
variant="elevated"
|
||||
<AspectRatio ratio={1.2} w="full">
|
||||
<VStack
|
||||
as={Link}
|
||||
href={{ pathname: "/experiments/[id]", query: { id: exp.id } }}
|
||||
bg="gray.50"
|
||||
_hover={{ bg: "gray.100" }}
|
||||
transition="background 0.2s"
|
||||
cursor="pointer"
|
||||
onClick={(e) => {
|
||||
e.preventDefault();
|
||||
void router.push({ pathname: "/experiments/[id]", query: { id: exp.id } }, undefined, {
|
||||
shallow: true,
|
||||
});
|
||||
}}
|
||||
borderColor="gray.200"
|
||||
borderWidth={1}
|
||||
p={4}
|
||||
justify="space-between"
|
||||
>
|
||||
<CardHeader>
|
||||
<HStack w="full" color="gray.700">
|
||||
<HStack w="full" color="gray.700" justify="center">
|
||||
<Icon as={RiFlaskLine} boxSize={4} />
|
||||
<Text fontWeight="bold">{exp.label}</Text>
|
||||
</HStack>
|
||||
</CardHeader>
|
||||
<CardBody>
|
||||
<HStack w="full" mb={8} spacing={4}>
|
||||
<HStack h="full" spacing={4} flex={1} align="center">
|
||||
<CountLabel label="Variants" count={exp.promptVariantCount} />
|
||||
<Divider h={12} orientation="vertical" />
|
||||
<CountLabel label="Scenarios" count={exp.testScenarioCount} />
|
||||
</HStack>
|
||||
<HStack w="full" color="gray.500" fontSize="xs">
|
||||
<Text>Created {formatTimePast(exp.createdAt)}</Text>
|
||||
<HStack w="full" color="gray.500" fontSize="xs" textAlign="center">
|
||||
<Text flex={1}>Created {formatTimePast(exp.createdAt)}</Text>
|
||||
<Divider h={4} orientation="vertical" />
|
||||
<Text>Updated {formatTimePast(exp.updatedAt)}</Text>
|
||||
<Text flex={1}>Updated {formatTimePast(exp.updatedAt)}</Text>
|
||||
</HStack>
|
||||
</CardBody>
|
||||
</Box>
|
||||
</VStack>
|
||||
</AspectRatio>
|
||||
);
|
||||
};
|
||||
|
||||
const CountLabel = ({ label, count }: { label: string; count: number }) => {
|
||||
return (
|
||||
<VStack alignItems="flex-start">
|
||||
<VStack alignItems="center" flex={1}>
|
||||
<Text color="gray.500" fontWeight="bold">
|
||||
{label}
|
||||
</Text>
|
||||
@@ -74,3 +63,33 @@ const CountLabel = ({ label, count }: { label: string; count: number }) => {
|
||||
</VStack>
|
||||
);
|
||||
};
|
||||
|
||||
export const NewExperimentCard = () => {
|
||||
const router = useRouter();
|
||||
const createMutation = api.experiments.create.useMutation();
|
||||
const [createExperiment, isLoading] = useHandledAsyncCallback(async () => {
|
||||
const newExperiment = await createMutation.mutateAsync({ label: "New Experiment" });
|
||||
await router.push({ pathname: "/experiments/[id]", query: { id: newExperiment.id } });
|
||||
}, [createMutation, router]);
|
||||
|
||||
return (
|
||||
<AspectRatio ratio={1.2} w="full">
|
||||
<VStack
|
||||
align="center"
|
||||
justify="center"
|
||||
_hover={{ cursor: "pointer", bg: "gray.50" }}
|
||||
transition="background 0.2s"
|
||||
cursor="pointer"
|
||||
borderColor="gray.200"
|
||||
borderWidth={1}
|
||||
p={4}
|
||||
onClick={createExperiment}
|
||||
>
|
||||
<Icon as={isLoading ? Spinner : BsPlusSquare} boxSize={8} />
|
||||
<Text display={{ base: "none", md: "block" }} ml={2}>
|
||||
New Experiment
|
||||
</Text>
|
||||
</VStack>
|
||||
</AspectRatio>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
import { Icon, Button, Spinner, Text, type ButtonProps } from "@chakra-ui/react";
|
||||
import { api } from "~/utils/api";
|
||||
import { useRouter } from "next/router";
|
||||
import { BsPlusSquare } from "react-icons/bs";
|
||||
import { useHandledAsyncCallback } from "~/utils/hooks";
|
||||
|
||||
export const NewExperimentButton = (props: ButtonProps) => {
|
||||
const router = useRouter();
|
||||
const utils = api.useContext();
|
||||
const createMutation = api.experiments.create.useMutation();
|
||||
const [createExperiment, isLoading] = useHandledAsyncCallback(async () => {
|
||||
const newExperiment = await createMutation.mutateAsync({ label: "New Experiment" });
|
||||
await utils.experiments.list.invalidate();
|
||||
await router.push({ pathname: "/experiments/[id]", query: { id: newExperiment.id } });
|
||||
}, [createMutation, router]);
|
||||
|
||||
return (
|
||||
<Button
|
||||
onClick={createExperiment}
|
||||
display="flex"
|
||||
alignItems="center"
|
||||
variant={{ base: "solid", md: "ghost" }}
|
||||
{...props}
|
||||
>
|
||||
<Icon as={isLoading ? Spinner : BsPlusSquare} boxSize={4} />
|
||||
<Text display={{ base: "none", md: "block" }} ml={2}>
|
||||
New Experiment
|
||||
</Text>
|
||||
</Button>
|
||||
);
|
||||
};
|
||||
@@ -1,84 +1,100 @@
|
||||
import { useState, useEffect } from "react";
|
||||
import {
|
||||
Heading,
|
||||
VStack,
|
||||
Icon,
|
||||
HStack,
|
||||
Image,
|
||||
Grid,
|
||||
GridItem,
|
||||
Divider,
|
||||
Text,
|
||||
Box,
|
||||
type BoxProps,
|
||||
type LinkProps,
|
||||
Link,
|
||||
Flex,
|
||||
} from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import { BsGithub, BsTwitter } from "react-icons/bs";
|
||||
import { BsGithub, BsPersonCircle } from "react-icons/bs";
|
||||
import { useRouter } from "next/router";
|
||||
import PublicPlaygroundWarning from "../PublicPlaygroundWarning";
|
||||
import { type IconType } from "react-icons";
|
||||
import { RiFlaskLine } from "react-icons/ri";
|
||||
import { useState, useEffect } from "react";
|
||||
import { signIn, useSession } from "next-auth/react";
|
||||
import UserMenu from "./UserMenu";
|
||||
|
||||
type IconLinkProps = BoxProps & LinkProps & { label: string; icon: IconType; href: string };
|
||||
type IconLinkProps = BoxProps & LinkProps & { label?: string; icon: IconType };
|
||||
|
||||
const IconLink = ({ icon, label, href, target, color, ...props }: IconLinkProps) => {
|
||||
const isActive = useRouter().pathname.startsWith(href);
|
||||
const router = useRouter();
|
||||
const isActive = href && router.pathname.startsWith(href);
|
||||
return (
|
||||
<Box
|
||||
<HStack
|
||||
w="full"
|
||||
p={4}
|
||||
color={color}
|
||||
as={Link}
|
||||
href={href}
|
||||
target={target}
|
||||
w="full"
|
||||
bgColor={isActive ? "gray.300" : "transparent"}
|
||||
_hover={{ bgColor: "gray.300" }}
|
||||
py={4}
|
||||
bgColor={isActive ? "gray.200" : "transparent"}
|
||||
_hover={{ bgColor: "gray.200", textDecoration: "none" }}
|
||||
justifyContent="start"
|
||||
cursor="pointer"
|
||||
{...props}
|
||||
>
|
||||
<HStack w="full" px={4} color={color}>
|
||||
<Icon as={icon} boxSize={6} mr={2} />
|
||||
<Text fontWeight="bold">{label}</Text>
|
||||
<Text fontWeight="bold" fontSize="sm">
|
||||
{label}
|
||||
</Text>
|
||||
</HStack>
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
||||
const Divider = () => <Box h="1px" bgColor="gray.200" />;
|
||||
|
||||
const NavSidebar = () => {
|
||||
const user = useSession().data;
|
||||
|
||||
return (
|
||||
<VStack align="stretch" bgColor="gray.100" py={2} pb={0} height="100%">
|
||||
<Link href="/" w="full" _hover={{ textDecoration: "none" }}>
|
||||
<HStack spacing={0} pl="3">
|
||||
<Image src="/logo.svg" alt="" w={8} h={8} />
|
||||
<Heading size="md" p={2} pl={{ base: 16, md: 2 }}>
|
||||
<VStack
|
||||
align="stretch"
|
||||
bgColor="gray.100"
|
||||
py={2}
|
||||
pb={0}
|
||||
height="100%"
|
||||
w={{ base: "56px", md: "200px" }}
|
||||
overflow="hidden"
|
||||
>
|
||||
<HStack as={Link} href="/" _hover={{ textDecoration: "none" }} spacing={0} px={4} py={2}>
|
||||
<Image src="/logo.svg" alt="" boxSize={6} mr={4} />
|
||||
<Heading size="md" fontFamily="inconsolata, monospace">
|
||||
OpenPipe
|
||||
</Heading>
|
||||
</HStack>
|
||||
</Link>
|
||||
<Divider />
|
||||
<VStack spacing={0} align="flex-start" overflowY="auto" overflowX="hidden" flex={1}>
|
||||
{user != null && (
|
||||
<>
|
||||
<IconLink icon={RiFlaskLine} label="Experiments" href="/experiments" />
|
||||
</VStack>
|
||||
<Divider />
|
||||
<VStack w="full" spacing={0} pb={2}>
|
||||
</>
|
||||
)}
|
||||
{user === null && (
|
||||
<IconLink
|
||||
icon={BsGithub}
|
||||
label="GitHub"
|
||||
icon={BsPersonCircle}
|
||||
label="Sign In"
|
||||
onClick={() => {
|
||||
signIn("github").catch(console.error);
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
</VStack>
|
||||
{user ? <UserMenu user={user} /> : <Divider />}
|
||||
<VStack spacing={0} align="center">
|
||||
<Link
|
||||
href="https://github.com/openpipe/openpipe"
|
||||
target="_blank"
|
||||
color="gray.500"
|
||||
_hover={{ color: "gray.800" }}
|
||||
/>
|
||||
<IconLink
|
||||
icon={BsTwitter}
|
||||
label="Twitter"
|
||||
href="https://twitter.com/corbtt"
|
||||
target="_blank"
|
||||
color="gray.500"
|
||||
_hover={{ color: "gray.800" }}
|
||||
/>
|
||||
p={2}
|
||||
>
|
||||
<Icon as={BsGithub} boxSize={6} />
|
||||
</Link>
|
||||
</VStack>
|
||||
</VStack>
|
||||
);
|
||||
@@ -105,25 +121,14 @@ export default function AppShell(props: { children: React.ReactNode; title?: str
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<Grid
|
||||
h={vh}
|
||||
w="100vw"
|
||||
templateColumns={{ base: "56px minmax(0, 1fr)", md: "200px minmax(0, 1fr)" }}
|
||||
templateRows="max-content 1fr"
|
||||
templateAreas={'"warning warning"\n"sidebar main"'}
|
||||
>
|
||||
<Flex h={vh} w="100vw">
|
||||
<Head>
|
||||
<title>{props.title ? `${props.title} | OpenPipe` : "OpenPipe"}</title>
|
||||
</Head>
|
||||
<GridItem area="warning">
|
||||
<PublicPlaygroundWarning />
|
||||
</GridItem>
|
||||
<GridItem area="sidebar" overflow="hidden">
|
||||
<NavSidebar />
|
||||
</GridItem>
|
||||
<GridItem area="main" overflowY="auto">
|
||||
<Box h="100%" flex={1} overflowY="auto">
|
||||
{props.children}
|
||||
</GridItem>
|
||||
</Grid>
|
||||
</Box>
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
|
||||
74
src/components/nav/UserMenu.tsx
Normal file
74
src/components/nav/UserMenu.tsx
Normal file
@@ -0,0 +1,74 @@
|
||||
import {
|
||||
HStack,
|
||||
Icon,
|
||||
Image,
|
||||
VStack,
|
||||
Text,
|
||||
Popover,
|
||||
PopoverTrigger,
|
||||
PopoverContent,
|
||||
Link,
|
||||
} from "@chakra-ui/react";
|
||||
import { type Session } from "next-auth";
|
||||
import { signOut } from "next-auth/react";
|
||||
import { BsBoxArrowRight, BsChevronRight, BsPersonCircle } from "react-icons/bs";
|
||||
|
||||
export default function UserMenu({ user }: { user: Session }) {
|
||||
const profileImage = user.user.image ? (
|
||||
<Image src={user.user.image} alt="profile picture" boxSize={8} borderRadius="50%" />
|
||||
) : (
|
||||
<Icon as={BsPersonCircle} boxSize={6} />
|
||||
);
|
||||
|
||||
return (
|
||||
<>
|
||||
<Popover placement="right">
|
||||
<PopoverTrigger>
|
||||
<HStack
|
||||
// Weird values to make mobile look right; can clean up when we make the sidebar disappear on mobile
|
||||
px={3}
|
||||
spacing={3}
|
||||
py={2}
|
||||
borderColor={"gray.200"}
|
||||
borderTopWidth={1}
|
||||
borderBottomWidth={1}
|
||||
cursor="pointer"
|
||||
_hover={{
|
||||
bgColor: "gray.200",
|
||||
}}
|
||||
>
|
||||
{profileImage}
|
||||
<VStack spacing={0} align="start" flex={1} flexShrink={1}>
|
||||
<Text fontWeight="bold" fontSize="sm">
|
||||
{user.user.name}
|
||||
</Text>
|
||||
<Text color="gray.500" fontSize="xs">
|
||||
{user.user.email}
|
||||
</Text>
|
||||
</VStack>
|
||||
<Icon as={BsChevronRight} boxSize={4} color="gray.500" />
|
||||
</HStack>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent _focusVisible={{ boxShadow: "unset", outline: "unset" }} maxW="200px">
|
||||
<VStack align="stretch" spacing={0}>
|
||||
{/* sign out */}
|
||||
<HStack
|
||||
as={Link}
|
||||
onClick={() => {
|
||||
signOut().catch(console.error);
|
||||
}}
|
||||
px={4}
|
||||
py={2}
|
||||
spacing={4}
|
||||
color="gray.500"
|
||||
fontSize="sm"
|
||||
>
|
||||
<Icon as={BsBoxArrowRight} boxSize={6} />
|
||||
<Text>Sign out</Text>
|
||||
</HStack>
|
||||
</VStack>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
</>
|
||||
);
|
||||
}
|
||||
@@ -20,7 +20,6 @@ export const CostTooltip = ({
|
||||
color="gray.800"
|
||||
bgColor="gray.50"
|
||||
borderWidth={1}
|
||||
py={2}
|
||||
hasArrow
|
||||
shouldWrapChildren
|
||||
label={
|
||||
|
||||
16
src/env.mjs
16
src/env.mjs
@@ -10,6 +10,13 @@ export const env = createEnv({
|
||||
DATABASE_URL: z.string().url(),
|
||||
NODE_ENV: z.enum(["development", "test", "production"]).default("development"),
|
||||
OPENAI_API_KEY: z.string().min(1),
|
||||
RESTRICT_PRISMA_LOGS: z
|
||||
.string()
|
||||
.optional()
|
||||
.default("false")
|
||||
.transform((val) => val.toLowerCase() === "true"),
|
||||
GITHUB_CLIENT_ID: z.string().min(1),
|
||||
GITHUB_CLIENT_SECRET: z.string().min(1),
|
||||
},
|
||||
|
||||
/**
|
||||
@@ -19,11 +26,6 @@ export const env = createEnv({
|
||||
*/
|
||||
client: {
|
||||
NEXT_PUBLIC_POSTHOG_KEY: z.string().optional(),
|
||||
NEXT_PUBLIC_IS_PUBLIC_PLAYGROUND: z
|
||||
.string()
|
||||
.optional()
|
||||
.default("false")
|
||||
.transform((val) => val.toLowerCase() === "true"),
|
||||
NEXT_PUBLIC_SOCKET_URL: z.string().url().default("http://localhost:3318"),
|
||||
},
|
||||
|
||||
@@ -35,9 +37,11 @@ export const env = createEnv({
|
||||
DATABASE_URL: process.env.DATABASE_URL,
|
||||
NODE_ENV: process.env.NODE_ENV,
|
||||
OPENAI_API_KEY: process.env.OPENAI_API_KEY,
|
||||
RESTRICT_PRISMA_LOGS: process.env.RESTRICT_PRISMA_LOGS,
|
||||
NEXT_PUBLIC_POSTHOG_KEY: process.env.NEXT_PUBLIC_POSTHOG_KEY,
|
||||
NEXT_PUBLIC_IS_PUBLIC_PLAYGROUND: process.env.NEXT_PUBLIC_IS_PUBLIC_PLAYGROUND,
|
||||
NEXT_PUBLIC_SOCKET_URL: process.env.NEXT_PUBLIC_SOCKET_URL,
|
||||
GITHUB_CLIENT_ID: process.env.GITHUB_CLIENT_ID,
|
||||
GITHUB_CLIENT_SECRET: process.env.GITHUB_CLIENT_SECRET,
|
||||
},
|
||||
/**
|
||||
* Run `build` or `dev` with `SKIP_ENV_VALIDATION` to skip env validation.
|
||||
|
||||
@@ -6,18 +6,27 @@ import { ChakraProvider } from "@chakra-ui/react";
|
||||
import theme from "~/utils/theme";
|
||||
import Favicon from "~/components/Favicon";
|
||||
import "~/utils/analytics";
|
||||
import Head from "next/head";
|
||||
|
||||
const MyApp: AppType<{ session: Session | null }> = ({
|
||||
Component,
|
||||
pageProps: { session, ...pageProps },
|
||||
}) => {
|
||||
return (
|
||||
<>
|
||||
<Head>
|
||||
<meta
|
||||
name="viewport"
|
||||
content="width=device-width, initial-scale=1, maximum-scale=1, user-scalable=0"
|
||||
/>
|
||||
</Head>
|
||||
<SessionProvider session={session}>
|
||||
<Favicon />
|
||||
<ChakraProvider theme={theme}>
|
||||
<Component {...pageProps} />
|
||||
</ChakraProvider>
|
||||
</SessionProvider>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
|
||||
23
src/pages/account/signin.tsx
Normal file
23
src/pages/account/signin.tsx
Normal file
@@ -0,0 +1,23 @@
|
||||
import { signIn, useSession } from "next-auth/react";
|
||||
import { useRouter } from "next/router";
|
||||
import { useEffect } from "react";
|
||||
import AppShell from "~/components/nav/AppShell";
|
||||
|
||||
export default function SignIn() {
|
||||
const session = useSession().data;
|
||||
const router = useRouter();
|
||||
|
||||
useEffect(() => {
|
||||
if (session) {
|
||||
router.push("/experiments").catch(console.error);
|
||||
} else if (session === null) {
|
||||
signIn("github").catch(console.error);
|
||||
}
|
||||
}, [session, router]);
|
||||
|
||||
return (
|
||||
<AppShell>
|
||||
<div />
|
||||
</AppShell>
|
||||
);
|
||||
}
|
||||
@@ -124,6 +124,8 @@ export default function Experiment() {
|
||||
);
|
||||
}
|
||||
|
||||
const canModify = experiment.data?.access.canModify ?? false;
|
||||
|
||||
return (
|
||||
<AppShell title={experiment.data?.label}>
|
||||
<VStack h="full">
|
||||
@@ -143,6 +145,7 @@ export default function Experiment() {
|
||||
</Link>
|
||||
</BreadcrumbItem>
|
||||
<BreadcrumbItem isCurrentPage>
|
||||
{canModify ? (
|
||||
<Input
|
||||
size="sm"
|
||||
value={label}
|
||||
@@ -157,8 +160,14 @@ export default function Experiment() {
|
||||
_hover={{ borderColor: "gray.300" }}
|
||||
_focus={{ borderColor: "blue.500", outline: "none" }}
|
||||
/>
|
||||
) : (
|
||||
<Text fontSize={16} px={0} minW={{ base: 100, lg: 300 }} flex={1}>
|
||||
{experiment.data?.label}
|
||||
</Text>
|
||||
)}
|
||||
</BreadcrumbItem>
|
||||
</Breadcrumb>
|
||||
{canModify && (
|
||||
<HStack>
|
||||
<Button
|
||||
size="sm"
|
||||
@@ -174,6 +183,7 @@ export default function Experiment() {
|
||||
</Button>
|
||||
<DeleteButton />
|
||||
</HStack>
|
||||
)}
|
||||
</Flex>
|
||||
<SettingsDrawer />
|
||||
<Box w="100%" overflowX="auto" flex={1}>
|
||||
|
||||
@@ -1,25 +1,50 @@
|
||||
import {
|
||||
SimpleGrid,
|
||||
HStack,
|
||||
Icon,
|
||||
VStack,
|
||||
Breadcrumb,
|
||||
BreadcrumbItem,
|
||||
Flex,
|
||||
Center,
|
||||
Text,
|
||||
Link,
|
||||
HStack,
|
||||
} from "@chakra-ui/react";
|
||||
import { RiFlaskLine } from "react-icons/ri";
|
||||
import AppShell from "~/components/nav/AppShell";
|
||||
import { api } from "~/utils/api";
|
||||
import { NewExperimentButton } from "~/components/experiments/NewExperimentButton";
|
||||
import { ExperimentCard } from "~/components/experiments/ExperimentCard";
|
||||
import { ExperimentCard, NewExperimentCard } from "~/components/experiments/ExperimentCard";
|
||||
import { signIn, useSession } from "next-auth/react";
|
||||
|
||||
export default function ExperimentsPage() {
|
||||
const experiments = api.experiments.list.useQuery();
|
||||
|
||||
const user = useSession().data;
|
||||
|
||||
if (user === null) {
|
||||
return (
|
||||
<AppShell>
|
||||
<VStack alignItems={"flex-start"} m={4} mt={1}>
|
||||
<HStack w="full" justifyContent="space-between" mb={4}>
|
||||
<AppShell title="Experiments">
|
||||
<Center h="100%">
|
||||
<Text>
|
||||
<Link
|
||||
onClick={() => {
|
||||
signIn("github").catch(console.error);
|
||||
}}
|
||||
textDecor="underline"
|
||||
>
|
||||
Sign in
|
||||
</Link>{" "}
|
||||
to view or create new experiments!
|
||||
</Text>
|
||||
</Center>
|
||||
</AppShell>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<AppShell title="Experiments">
|
||||
<VStack alignItems={"flex-start"} px={4} py={2}>
|
||||
<HStack minH={8} align="center">
|
||||
<Breadcrumb flex={1}>
|
||||
<BreadcrumbItem>
|
||||
<Flex alignItems="center">
|
||||
@@ -27,9 +52,9 @@ export default function ExperimentsPage() {
|
||||
</Flex>
|
||||
</BreadcrumbItem>
|
||||
</Breadcrumb>
|
||||
<NewExperimentButton mr={4} borderRadius={8} />
|
||||
</HStack>
|
||||
<SimpleGrid w="full" columns={{ base: 1, md: 2, lg: 3, xl: 4 }} spacing={8} p="4">
|
||||
<NewExperimentCard />
|
||||
{experiments?.data?.map((exp) => <ExperimentCard key={exp.id} exp={exp} />)}
|
||||
</SimpleGrid>
|
||||
</VStack>
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { type GetServerSideProps } from "next";
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/require-await
|
||||
export const getServerSideProps: GetServerSideProps = async (context) => {
|
||||
export const getServerSideProps: GetServerSideProps = async () => {
|
||||
return {
|
||||
redirect: {
|
||||
destination: "/experiments",
|
||||
|
||||
@@ -1,11 +1,7 @@
|
||||
import { type CompletionCreateParams } from "openai/resources/chat";
|
||||
import { prisma } from "../db";
|
||||
import { openai } from "../utils/openai";
|
||||
import { pick } from "lodash";
|
||||
|
||||
function promptHasVariable(prompt: string, variableName: string) {
|
||||
return prompt.includes(`{{${variableName}}}`);
|
||||
}
|
||||
import { pick } from "lodash-es";
|
||||
|
||||
type AxiosError = {
|
||||
response?: {
|
||||
|
||||
@@ -2,7 +2,7 @@ import { promptVariantsRouter } from "~/server/api/routers/promptVariants.router
|
||||
import { createTRPCRouter } from "~/server/api/trpc";
|
||||
import { experimentsRouter } from "./routers/experiments.router";
|
||||
import { scenariosRouter } from "./routers/scenarios.router";
|
||||
import { modelOutputsRouter } from "./routers/modelOutputs.router";
|
||||
import { scenarioVariantCellsRouter } from "./routers/scenarioVariantCells.router";
|
||||
import { templateVarsRouter } from "./routers/templateVariables.router";
|
||||
import { evaluationsRouter } from "./routers/evaluations.router";
|
||||
|
||||
@@ -15,7 +15,7 @@ export const appRouter = createTRPCRouter({
|
||||
promptVariants: promptVariantsRouter,
|
||||
experiments: experimentsRouter,
|
||||
scenarios: scenariosRouter,
|
||||
outputs: modelOutputsRouter,
|
||||
scenarioVariantCells: scenarioVariantCellsRouter,
|
||||
templateVars: templateVarsRouter,
|
||||
evaluations: evaluationsRouter,
|
||||
});
|
||||
|
||||
@@ -1,11 +1,16 @@
|
||||
import { EvaluationMatchType } from "@prisma/client";
|
||||
import { EvalType } from "@prisma/client";
|
||||
import { z } from "zod";
|
||||
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
|
||||
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
|
||||
import { prisma } from "~/server/db";
|
||||
import { reevaluateEvaluation } from "~/server/utils/evaluations";
|
||||
import { runAllEvals } from "~/server/utils/evaluations";
|
||||
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
|
||||
|
||||
export const evaluationsRouter = createTRPCRouter({
|
||||
list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
|
||||
list: publicProcedure
|
||||
.input(z.object({ experimentId: z.string() }))
|
||||
.query(async ({ input, ctx }) => {
|
||||
await requireCanViewExperiment(input.experimentId, ctx);
|
||||
|
||||
return await prisma.evaluation.findMany({
|
||||
where: {
|
||||
experimentId: input.experimentId,
|
||||
@@ -14,53 +19,76 @@ export const evaluationsRouter = createTRPCRouter({
|
||||
});
|
||||
}),
|
||||
|
||||
create: publicProcedure
|
||||
create: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
experimentId: z.string(),
|
||||
name: z.string(),
|
||||
matchString: z.string(),
|
||||
matchType: z.nativeEnum(EvaluationMatchType),
|
||||
label: z.string(),
|
||||
value: z.string(),
|
||||
evalType: z.nativeEnum(EvalType),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input }) => {
|
||||
const evaluation = await prisma.evaluation.create({
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
await requireCanModifyExperiment(input.experimentId, ctx);
|
||||
|
||||
await prisma.evaluation.create({
|
||||
data: {
|
||||
experimentId: input.experimentId,
|
||||
name: input.name,
|
||||
matchString: input.matchString,
|
||||
matchType: input.matchType,
|
||||
label: input.label,
|
||||
value: input.value,
|
||||
evalType: input.evalType,
|
||||
},
|
||||
});
|
||||
await reevaluateEvaluation(evaluation);
|
||||
|
||||
// TODO: this may be a bad UX for slow evals (eg. GPT-4 evals) Maybe need
|
||||
// to kick off a background job or something instead
|
||||
await runAllEvals(input.experimentId);
|
||||
}),
|
||||
|
||||
update: publicProcedure
|
||||
update: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
id: z.string(),
|
||||
updates: z.object({
|
||||
name: z.string().optional(),
|
||||
matchString: z.string().optional(),
|
||||
matchType: z.nativeEnum(EvaluationMatchType).optional(),
|
||||
label: z.string().optional(),
|
||||
value: z.string().optional(),
|
||||
evalType: z.nativeEnum(EvalType).optional(),
|
||||
}),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input }) => {
|
||||
await prisma.evaluation.update({
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const { experimentId } = await prisma.evaluation.findUniqueOrThrow({
|
||||
where: { id: input.id },
|
||||
});
|
||||
await requireCanModifyExperiment(experimentId, ctx);
|
||||
|
||||
const evaluation = await prisma.evaluation.update({
|
||||
where: { id: input.id },
|
||||
data: {
|
||||
name: input.updates.name,
|
||||
matchString: input.updates.matchString,
|
||||
matchType: input.updates.matchType,
|
||||
label: input.updates.label,
|
||||
value: input.updates.value,
|
||||
evalType: input.updates.evalType,
|
||||
},
|
||||
});
|
||||
await reevaluateEvaluation(
|
||||
await prisma.evaluation.findUniqueOrThrow({ where: { id: input.id } }),
|
||||
);
|
||||
|
||||
await prisma.outputEvaluation.deleteMany({
|
||||
where: {
|
||||
evaluationId: evaluation.id,
|
||||
},
|
||||
});
|
||||
// Re-run all evals. Other eval results will already be cached, so this
|
||||
// should only re-run the updated one.
|
||||
await runAllEvals(evaluation.experimentId);
|
||||
}),
|
||||
|
||||
delete: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => {
|
||||
delete: protectedProcedure
|
||||
.input(z.object({ id: z.string() }))
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const { experimentId } = await prisma.evaluation.findUniqueOrThrow({
|
||||
where: { id: input.id },
|
||||
});
|
||||
await requireCanModifyExperiment(experimentId, ctx);
|
||||
|
||||
await prisma.evaluation.delete({
|
||||
where: { id: input.id },
|
||||
});
|
||||
|
||||
@@ -1,13 +1,31 @@
|
||||
import { z } from "zod";
|
||||
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
|
||||
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
|
||||
import { prisma } from "~/server/db";
|
||||
import dedent from "dedent";
|
||||
import { generateNewCell } from "~/server/utils/generateNewCell";
|
||||
import {
|
||||
canModifyExperiment,
|
||||
requireCanModifyExperiment,
|
||||
requireCanViewExperiment,
|
||||
requireNothing,
|
||||
} from "~/utils/accessControl";
|
||||
import userOrg from "~/server/utils/userOrg";
|
||||
|
||||
export const experimentsRouter = createTRPCRouter({
|
||||
list: publicProcedure.query(async () => {
|
||||
list: protectedProcedure.query(async ({ ctx }) => {
|
||||
// Anyone can list experiments
|
||||
requireNothing(ctx);
|
||||
|
||||
const experiments = await prisma.experiment.findMany({
|
||||
where: {
|
||||
organization: {
|
||||
OrganizationUser: {
|
||||
some: { userId: ctx.session.user.id },
|
||||
},
|
||||
},
|
||||
},
|
||||
orderBy: {
|
||||
sortIndex: "asc",
|
||||
sortIndex: "desc",
|
||||
},
|
||||
});
|
||||
|
||||
@@ -39,15 +57,29 @@ export const experimentsRouter = createTRPCRouter({
|
||||
return experimentsWithCounts;
|
||||
}),
|
||||
|
||||
get: publicProcedure.input(z.object({ id: z.string() })).query(async ({ input }) => {
|
||||
return await prisma.experiment.findFirst({
|
||||
where: {
|
||||
id: input.id,
|
||||
},
|
||||
get: publicProcedure.input(z.object({ id: z.string() })).query(async ({ input, ctx }) => {
|
||||
await requireCanViewExperiment(input.id, ctx);
|
||||
const experiment = await prisma.experiment.findFirstOrThrow({
|
||||
where: { id: input.id },
|
||||
});
|
||||
|
||||
const canModify = ctx.session?.user.id
|
||||
? await canModifyExperiment(experiment.id, ctx.session?.user.id)
|
||||
: false;
|
||||
|
||||
return {
|
||||
...experiment,
|
||||
access: {
|
||||
canView: true,
|
||||
canModify,
|
||||
},
|
||||
};
|
||||
}),
|
||||
|
||||
create: publicProcedure.input(z.object({})).mutation(async () => {
|
||||
create: protectedProcedure.input(z.object({})).mutation(async ({ ctx }) => {
|
||||
// Anyone can create an experiment
|
||||
requireNothing(ctx);
|
||||
|
||||
const maxSortIndex =
|
||||
(
|
||||
await prisma.experiment.aggregate({
|
||||
@@ -61,36 +93,66 @@ export const experimentsRouter = createTRPCRouter({
|
||||
data: {
|
||||
sortIndex: maxSortIndex + 1,
|
||||
label: `Experiment ${maxSortIndex + 1}`,
|
||||
organizationId: (await userOrg(ctx.session.user.id)).id,
|
||||
},
|
||||
});
|
||||
|
||||
await prisma.$transaction([
|
||||
const [variant, _, scenario] = await prisma.$transaction([
|
||||
prisma.promptVariant.create({
|
||||
data: {
|
||||
experimentId: exp.id,
|
||||
label: "Prompt Variant 1",
|
||||
sortIndex: 0,
|
||||
constructFn: dedent`prompt = {
|
||||
// The interpolated $ is necessary until dedent incorporates
|
||||
// https://github.com/dmnd/dedent/pull/46
|
||||
constructFn: dedent`
|
||||
/**
|
||||
* Use Javascript to define an OpenAI chat completion
|
||||
* (https://platform.openai.com/docs/api-reference/chat/create) and
|
||||
* assign it to the \`prompt\` variable.
|
||||
*
|
||||
* You have access to the current scenario in the \`scenario\`
|
||||
* variable.
|
||||
*/
|
||||
|
||||
prompt = {
|
||||
model: "gpt-3.5-turbo-0613",
|
||||
stream: true,
|
||||
messages: [{ role: "system", content: "Return 'Ready to go!'" }],
|
||||
}`,
|
||||
messages: [
|
||||
{
|
||||
role: "system",
|
||||
content: \`"Return 'this is output for the scenario "${"$"}{scenario.text}"'\`,
|
||||
},
|
||||
],
|
||||
};`,
|
||||
model: "gpt-3.5-turbo-0613",
|
||||
},
|
||||
}),
|
||||
prisma.templateVariable.create({
|
||||
data: {
|
||||
experimentId: exp.id,
|
||||
label: "text",
|
||||
},
|
||||
}),
|
||||
prisma.testScenario.create({
|
||||
data: {
|
||||
experimentId: exp.id,
|
||||
variableValues: {},
|
||||
variableValues: {
|
||||
text: "This is a test scenario.",
|
||||
},
|
||||
},
|
||||
}),
|
||||
]);
|
||||
|
||||
await generateNewCell(variant.id, scenario.id);
|
||||
|
||||
return exp;
|
||||
}),
|
||||
|
||||
update: publicProcedure
|
||||
update: protectedProcedure
|
||||
.input(z.object({ id: z.string(), updates: z.object({ label: z.string() }) }))
|
||||
.mutation(async ({ input }) => {
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
await requireCanModifyExperiment(input.id, ctx);
|
||||
return await prisma.experiment.update({
|
||||
where: {
|
||||
id: input.id,
|
||||
@@ -101,7 +163,11 @@ export const experimentsRouter = createTRPCRouter({
|
||||
});
|
||||
}),
|
||||
|
||||
delete: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => {
|
||||
delete: protectedProcedure
|
||||
.input(z.object({ id: z.string() }))
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
await requireCanModifyExperiment(input.id, ctx);
|
||||
|
||||
await prisma.experiment.delete({
|
||||
where: {
|
||||
id: input.id,
|
||||
|
||||
@@ -1,97 +0,0 @@
|
||||
import { z } from "zod";
|
||||
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
|
||||
import { prisma } from "~/server/db";
|
||||
import crypto from "crypto";
|
||||
import type { Prisma } from "@prisma/client";
|
||||
import { reevaluateVariant } from "~/server/utils/evaluations";
|
||||
import { getCompletion } from "~/server/utils/getCompletion";
|
||||
import { constructPrompt } from "~/server/utils/constructPrompt";
|
||||
|
||||
export const modelOutputsRouter = createTRPCRouter({
|
||||
get: publicProcedure
|
||||
.input(
|
||||
z.object({
|
||||
scenarioId: z.string(),
|
||||
variantId: z.string(),
|
||||
channel: z.string().optional(),
|
||||
forceRefetch: z.boolean().optional(),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input }) => {
|
||||
const existing = await prisma.modelOutput.findUnique({
|
||||
where: {
|
||||
promptVariantId_testScenarioId: {
|
||||
promptVariantId: input.variantId,
|
||||
testScenarioId: input.scenarioId,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
if (existing && !input.forceRefetch) return existing;
|
||||
|
||||
const variant = await prisma.promptVariant.findUnique({
|
||||
where: {
|
||||
id: input.variantId,
|
||||
},
|
||||
});
|
||||
|
||||
const scenario = await prisma.testScenario.findUnique({
|
||||
where: {
|
||||
id: input.scenarioId,
|
||||
},
|
||||
});
|
||||
|
||||
if (!variant || !scenario) return null;
|
||||
|
||||
const prompt = await constructPrompt(variant, scenario);
|
||||
|
||||
const inputHash = crypto.createHash("sha256").update(JSON.stringify(prompt)).digest("hex");
|
||||
|
||||
// TODO: we should probably only use this if temperature=0
|
||||
const existingResponse = await prisma.modelOutput.findFirst({
|
||||
where: { inputHash, errorMessage: null },
|
||||
});
|
||||
|
||||
let modelResponse: Awaited<ReturnType<typeof getCompletion>>;
|
||||
|
||||
if (existingResponse) {
|
||||
modelResponse = {
|
||||
output: existingResponse.output as Prisma.InputJsonValue,
|
||||
statusCode: existingResponse.statusCode,
|
||||
errorMessage: existingResponse.errorMessage,
|
||||
timeToComplete: existingResponse.timeToComplete,
|
||||
promptTokens: existingResponse.promptTokens ?? undefined,
|
||||
completionTokens: existingResponse.completionTokens ?? undefined,
|
||||
};
|
||||
} else {
|
||||
try {
|
||||
modelResponse = await getCompletion(prompt, input.channel);
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
const modelOutput = await prisma.modelOutput.upsert({
|
||||
where: {
|
||||
promptVariantId_testScenarioId: {
|
||||
promptVariantId: input.variantId,
|
||||
testScenarioId: input.scenarioId,
|
||||
},
|
||||
},
|
||||
create: {
|
||||
promptVariantId: input.variantId,
|
||||
testScenarioId: input.scenarioId,
|
||||
inputHash,
|
||||
...modelResponse,
|
||||
},
|
||||
update: {
|
||||
...modelResponse,
|
||||
},
|
||||
});
|
||||
|
||||
await reevaluateVariant(input.variantId);
|
||||
|
||||
return modelOutput;
|
||||
}),
|
||||
});
|
||||
@@ -1,11 +1,23 @@
|
||||
import { isObject } from "lodash-es";
|
||||
import { z } from "zod";
|
||||
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
|
||||
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
|
||||
import { prisma } from "~/server/db";
|
||||
import { generateNewCell } from "~/server/utils/generateNewCell";
|
||||
import { OpenAIChatModel, type SupportedModel } from "~/server/types";
|
||||
import { constructPrompt } from "~/server/utils/constructPrompt";
|
||||
import userError from "~/server/utils/error";
|
||||
import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
|
||||
import { calculateTokenCost } from "~/utils/calculateTokenCost";
|
||||
import { reorderPromptVariants } from "~/server/utils/reorderPromptVariants";
|
||||
import { type PromptVariant } from "@prisma/client";
|
||||
import { deriveNewConstructFn } from "~/server/utils/deriveNewContructFn";
|
||||
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
|
||||
|
||||
export const promptVariantsRouter = createTRPCRouter({
|
||||
list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
|
||||
list: publicProcedure
|
||||
.input(z.object({ experimentId: z.string() }))
|
||||
.query(async ({ input, ctx }) => {
|
||||
await requireCanViewExperiment(input.experimentId, ctx);
|
||||
|
||||
return await prisma.promptVariant.findMany({
|
||||
where: {
|
||||
experimentId: input.experimentId,
|
||||
@@ -15,7 +27,9 @@ export const promptVariantsRouter = createTRPCRouter({
|
||||
});
|
||||
}),
|
||||
|
||||
stats: publicProcedure.input(z.object({ variantId: z.string() })).query(async ({ input }) => {
|
||||
stats: publicProcedure
|
||||
.input(z.object({ variantId: z.string() }))
|
||||
.query(async ({ input, ctx }) => {
|
||||
const variant = await prisma.promptVariant.findUnique({
|
||||
where: {
|
||||
id: input.variantId,
|
||||
@@ -26,11 +40,47 @@ export const promptVariantsRouter = createTRPCRouter({
|
||||
throw new Error(`Prompt Variant with id ${input.variantId} does not exist`);
|
||||
}
|
||||
|
||||
const evalResults = await prisma.evaluationResult.findMany({
|
||||
where: {
|
||||
promptVariantId: input.variantId,
|
||||
await requireCanViewExperiment(variant.experimentId, ctx);
|
||||
|
||||
const outputEvals = await prisma.outputEvaluation.groupBy({
|
||||
by: ["evaluationId"],
|
||||
_sum: {
|
||||
result: true,
|
||||
},
|
||||
include: { evaluation: true },
|
||||
_count: {
|
||||
id: true,
|
||||
},
|
||||
where: {
|
||||
modelOutput: {
|
||||
scenarioVariantCell: {
|
||||
promptVariant: {
|
||||
id: input.variantId,
|
||||
visible: true,
|
||||
},
|
||||
testScenario: {
|
||||
visible: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const evals = await prisma.evaluation.findMany({
|
||||
where: {
|
||||
experimentId: variant.experimentId,
|
||||
},
|
||||
});
|
||||
|
||||
const evalResults = evals.map((evalItem) => {
|
||||
const evalResult = outputEvals.find(
|
||||
(outputEval) => outputEval.evaluationId === evalItem.id,
|
||||
);
|
||||
return {
|
||||
id: evalItem.id,
|
||||
label: evalItem.label,
|
||||
passCount: evalResult?._sum?.result ?? 0,
|
||||
totalCount: evalResult?._count?.id ?? 1,
|
||||
};
|
||||
});
|
||||
|
||||
const scenarioCount = await prisma.testScenario.count({
|
||||
@@ -39,46 +89,77 @@ export const promptVariantsRouter = createTRPCRouter({
|
||||
visible: true,
|
||||
},
|
||||
});
|
||||
const outputCount = await prisma.modelOutput.count({
|
||||
const outputCount = await prisma.scenarioVariantCell.count({
|
||||
where: {
|
||||
promptVariantId: input.variantId,
|
||||
testScenario: { visible: true },
|
||||
modelOutput: {
|
||||
is: {},
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const overallTokens = await prisma.modelOutput.aggregate({
|
||||
where: {
|
||||
scenarioVariantCell: {
|
||||
promptVariantId: input.variantId,
|
||||
testScenario: { visible: true },
|
||||
testScenario: {
|
||||
visible: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
_sum: {
|
||||
cost: true,
|
||||
promptTokens: true,
|
||||
completionTokens: true,
|
||||
},
|
||||
});
|
||||
|
||||
// TODO: fix this
|
||||
const model = "gpt-3.5-turbo-0613";
|
||||
// const model = getModelName(variant.config);
|
||||
|
||||
const promptTokens = overallTokens._sum?.promptTokens ?? 0;
|
||||
const overallPromptCost = calculateTokenCost(model, promptTokens);
|
||||
const completionTokens = overallTokens._sum?.completionTokens ?? 0;
|
||||
const overallCompletionCost = calculateTokenCost(model, completionTokens, true);
|
||||
|
||||
const overallCost = overallPromptCost + overallCompletionCost;
|
||||
const awaitingRetrievals = !!(await prisma.scenarioVariantCell.findFirst({
|
||||
where: {
|
||||
promptVariantId: input.variantId,
|
||||
testScenario: { visible: true },
|
||||
// Check if is PENDING or IN_PROGRESS
|
||||
retrievalStatus: {
|
||||
in: ["PENDING", "IN_PROGRESS"],
|
||||
},
|
||||
},
|
||||
}));
|
||||
|
||||
return { evalResults, promptTokens, completionTokens, overallCost, scenarioCount, outputCount };
|
||||
return {
|
||||
evalResults,
|
||||
promptTokens,
|
||||
completionTokens,
|
||||
overallCost: overallTokens._sum?.cost ?? 0,
|
||||
scenarioCount,
|
||||
outputCount,
|
||||
awaitingRetrievals,
|
||||
};
|
||||
}),
|
||||
|
||||
create: publicProcedure
|
||||
create: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
experimentId: z.string(),
|
||||
variantId: z.string().optional(),
|
||||
newModel: z.string().optional(),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input }) => {
|
||||
const lastVariant = await prisma.promptVariant.findFirst({
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
await requireCanViewExperiment(input.experimentId, ctx);
|
||||
|
||||
let originalVariant: PromptVariant | null = null;
|
||||
if (input.variantId) {
|
||||
originalVariant = await prisma.promptVariant.findUnique({
|
||||
where: {
|
||||
id: input.variantId,
|
||||
},
|
||||
});
|
||||
} else {
|
||||
originalVariant = await prisma.promptVariant.findFirst({
|
||||
where: {
|
||||
experimentId: input.experimentId,
|
||||
visible: true,
|
||||
@@ -87,6 +168,7 @@ export const promptVariantsRouter = createTRPCRouter({
|
||||
sortIndex: "desc",
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
const largestSortIndex =
|
||||
(
|
||||
@@ -100,12 +182,23 @@ export const promptVariantsRouter = createTRPCRouter({
|
||||
})
|
||||
)._max?.sortIndex ?? 0;
|
||||
|
||||
const newVariantLabel =
|
||||
input.variantId && originalVariant
|
||||
? `${originalVariant?.label} Copy`
|
||||
: `Prompt Variant ${largestSortIndex + 2}`;
|
||||
|
||||
const newConstructFn = await deriveNewConstructFn(
|
||||
originalVariant,
|
||||
input.newModel as SupportedModel,
|
||||
);
|
||||
|
||||
const createNewVariantAction = prisma.promptVariant.create({
|
||||
data: {
|
||||
experimentId: input.experimentId,
|
||||
label: `Prompt Variant ${largestSortIndex + 2}`,
|
||||
sortIndex: (lastVariant?.sortIndex ?? 0) + 1,
|
||||
constructFn: lastVariant?.constructFn ?? "",
|
||||
label: newVariantLabel,
|
||||
sortIndex: (originalVariant?.sortIndex ?? 0) + 1,
|
||||
constructFn: newConstructFn,
|
||||
model: originalVariant?.model ?? "gpt-3.5-turbo",
|
||||
},
|
||||
});
|
||||
|
||||
@@ -114,10 +207,26 @@ export const promptVariantsRouter = createTRPCRouter({
|
||||
recordExperimentUpdated(input.experimentId),
|
||||
]);
|
||||
|
||||
if (originalVariant) {
|
||||
// Insert new variant to right of original variant
|
||||
await reorderPromptVariants(newVariant.id, originalVariant.id, true);
|
||||
}
|
||||
|
||||
const scenarios = await prisma.testScenario.findMany({
|
||||
where: {
|
||||
experimentId: input.experimentId,
|
||||
visible: true,
|
||||
},
|
||||
});
|
||||
|
||||
for (const scenario of scenarios) {
|
||||
await generateNewCell(newVariant.id, scenario.id);
|
||||
}
|
||||
|
||||
return newVariant;
|
||||
}),
|
||||
|
||||
update: publicProcedure
|
||||
update: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
id: z.string(),
|
||||
@@ -126,7 +235,7 @@ export const promptVariantsRouter = createTRPCRouter({
|
||||
}),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input }) => {
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const existing = await prisma.promptVariant.findUnique({
|
||||
where: {
|
||||
id: input.id,
|
||||
@@ -137,6 +246,8 @@ export const promptVariantsRouter = createTRPCRouter({
|
||||
throw new Error(`Prompt Variant with id ${input.id} does not exist`);
|
||||
}
|
||||
|
||||
await requireCanModifyExperiment(existing.experimentId, ctx);
|
||||
|
||||
const updatePromptVariantAction = prisma.promptVariant.update({
|
||||
where: {
|
||||
id: input.id,
|
||||
@@ -152,13 +263,18 @@ export const promptVariantsRouter = createTRPCRouter({
|
||||
return updatedPromptVariant;
|
||||
}),
|
||||
|
||||
hide: publicProcedure
|
||||
hide: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
id: z.string(),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input }) => {
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const { experimentId } = await prisma.promptVariant.findUniqueOrThrow({
|
||||
where: { id: input.id },
|
||||
});
|
||||
await requireCanModifyExperiment(experimentId, ctx);
|
||||
|
||||
const updatedPromptVariant = await prisma.promptVariant.update({
|
||||
where: { id: input.id },
|
||||
data: { visible: false, experiment: { update: { updatedAt: new Date() } } },
|
||||
@@ -167,24 +283,76 @@ export const promptVariantsRouter = createTRPCRouter({
|
||||
return updatedPromptVariant;
|
||||
}),
|
||||
|
||||
replaceVariant: publicProcedure
|
||||
getRefinedPromptFn: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
id: z.string(),
|
||||
instructions: z.string(),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const existing = await prisma.promptVariant.findUniqueOrThrow({
|
||||
where: {
|
||||
id: input.id,
|
||||
},
|
||||
});
|
||||
await requireCanModifyExperiment(existing.experimentId, ctx);
|
||||
|
||||
const constructedPrompt = await constructPrompt({ constructFn: existing.constructFn }, null);
|
||||
|
||||
const promptConstructionFn = await deriveNewConstructFn(
|
||||
existing,
|
||||
// @ts-expect-error TODO clean this up
|
||||
constructedPrompt?.model as SupportedModel,
|
||||
input.instructions,
|
||||
);
|
||||
|
||||
// TODO: Validate promptConstructionFn
|
||||
// TODO: Record in some sort of history
|
||||
|
||||
return promptConstructionFn;
|
||||
}),
|
||||
|
||||
replaceVariant: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
id: z.string(),
|
||||
constructFn: z.string(),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input }) => {
|
||||
const existing = await prisma.promptVariant.findUnique({
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const existing = await prisma.promptVariant.findUniqueOrThrow({
|
||||
where: {
|
||||
id: input.id,
|
||||
},
|
||||
});
|
||||
await requireCanModifyExperiment(existing.experimentId, ctx);
|
||||
|
||||
if (!existing) {
|
||||
throw new Error(`Prompt Variant with id ${input.id} does not exist`);
|
||||
}
|
||||
|
||||
let model = existing.model;
|
||||
try {
|
||||
const contructedPrompt = await constructPrompt({ constructFn: input.constructFn }, null);
|
||||
|
||||
if (!isObject(contructedPrompt)) {
|
||||
return userError("Prompt is not an object");
|
||||
}
|
||||
if (!("model" in contructedPrompt)) {
|
||||
return userError("Prompt does not define a model");
|
||||
}
|
||||
if (
|
||||
typeof contructedPrompt.model !== "string" ||
|
||||
!(contructedPrompt.model in OpenAIChatModel)
|
||||
) {
|
||||
return userError("Prompt defines an invalid model");
|
||||
}
|
||||
model = contructedPrompt.model;
|
||||
} catch (e) {
|
||||
return userError((e as Error).message);
|
||||
}
|
||||
|
||||
// Create a duplicate with only the config changed
|
||||
const newVariant = await prisma.promptVariant.create({
|
||||
data: {
|
||||
@@ -193,11 +361,12 @@ export const promptVariantsRouter = createTRPCRouter({
|
||||
sortIndex: existing.sortIndex,
|
||||
uiId: existing.uiId,
|
||||
constructFn: input.constructFn,
|
||||
model,
|
||||
},
|
||||
});
|
||||
|
||||
// Hide anything with the same uiId besides the new one
|
||||
const hideOldVariantsAction = prisma.promptVariant.updateMany({
|
||||
const hideOldVariants = prisma.promptVariant.updateMany({
|
||||
where: {
|
||||
uiId: existing.uiId,
|
||||
id: {
|
||||
@@ -209,80 +378,35 @@ export const promptVariantsRouter = createTRPCRouter({
|
||||
},
|
||||
});
|
||||
|
||||
await prisma.$transaction([
|
||||
hideOldVariantsAction,
|
||||
recordExperimentUpdated(existing.experimentId),
|
||||
]);
|
||||
await prisma.$transaction([hideOldVariants, recordExperimentUpdated(existing.experimentId)]);
|
||||
|
||||
return newVariant;
|
||||
const scenarios = await prisma.testScenario.findMany({
|
||||
where: {
|
||||
experimentId: newVariant.experimentId,
|
||||
visible: true,
|
||||
},
|
||||
});
|
||||
|
||||
for (const scenario of scenarios) {
|
||||
await generateNewCell(newVariant.id, scenario.id);
|
||||
}
|
||||
|
||||
return { status: "ok" } as const;
|
||||
}),
|
||||
|
||||
reorder: publicProcedure
|
||||
reorder: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
draggedId: z.string(),
|
||||
droppedId: z.string(),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input }) => {
|
||||
const dragged = await prisma.promptVariant.findUnique({
|
||||
where: {
|
||||
id: input.draggedId,
|
||||
},
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const { experimentId } = await prisma.promptVariant.findUniqueOrThrow({
|
||||
where: { id: input.draggedId },
|
||||
});
|
||||
await requireCanModifyExperiment(experimentId, ctx);
|
||||
|
||||
const dropped = await prisma.promptVariant.findUnique({
|
||||
where: {
|
||||
id: input.droppedId,
|
||||
},
|
||||
});
|
||||
|
||||
if (!dragged || !dropped || dragged.experimentId !== dropped.experimentId) {
|
||||
throw new Error(
|
||||
`Prompt Variant with id ${input.draggedId} or ${input.droppedId} does not exist`,
|
||||
);
|
||||
}
|
||||
|
||||
const visibleItems = await prisma.promptVariant.findMany({
|
||||
where: {
|
||||
experimentId: dragged.experimentId,
|
||||
visible: true,
|
||||
},
|
||||
orderBy: {
|
||||
sortIndex: "asc",
|
||||
},
|
||||
});
|
||||
|
||||
// Remove the dragged item from its current position
|
||||
const orderedItems = visibleItems.filter((item) => item.id !== dragged.id);
|
||||
|
||||
// Find the index of the dragged item and the dropped item
|
||||
const dragIndex = visibleItems.findIndex((item) => item.id === dragged.id);
|
||||
const dropIndex = visibleItems.findIndex((item) => item.id === dropped.id);
|
||||
|
||||
// Determine the new index for the dragged item
|
||||
let newIndex;
|
||||
if (dragIndex < dropIndex) {
|
||||
newIndex = dropIndex + 1; // Insert after the dropped item
|
||||
} else {
|
||||
newIndex = dropIndex; // Insert before the dropped item
|
||||
}
|
||||
|
||||
// Insert the dragged item at the new position
|
||||
orderedItems.splice(newIndex, 0, dragged);
|
||||
|
||||
// Now, we need to update all the items with their new sortIndex
|
||||
await prisma.$transaction(
|
||||
orderedItems.map((item, index) => {
|
||||
return prisma.promptVariant.update({
|
||||
where: {
|
||||
id: item.id,
|
||||
},
|
||||
data: {
|
||||
sortIndex: index,
|
||||
},
|
||||
});
|
||||
}),
|
||||
);
|
||||
await reorderPromptVariants(input.draggedId, input.droppedId);
|
||||
}),
|
||||
});
|
||||
|
||||
90
src/server/api/routers/scenarioVariantCells.router.ts
Normal file
90
src/server/api/routers/scenarioVariantCells.router.ts
Normal file
@@ -0,0 +1,90 @@
|
||||
import { z } from "zod";
|
||||
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
|
||||
import { prisma } from "~/server/db";
|
||||
import { generateNewCell } from "~/server/utils/generateNewCell";
|
||||
import { queueLLMRetrievalTask } from "~/server/utils/queueLLMRetrievalTask";
|
||||
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
|
||||
|
||||
export const scenarioVariantCellsRouter = createTRPCRouter({
|
||||
get: publicProcedure
|
||||
.input(
|
||||
z.object({
|
||||
scenarioId: z.string(),
|
||||
variantId: z.string(),
|
||||
}),
|
||||
)
|
||||
.query(async ({ input, ctx }) => {
|
||||
const { experimentId } = await prisma.testScenario.findUniqueOrThrow({
|
||||
where: { id: input.scenarioId },
|
||||
});
|
||||
await requireCanViewExperiment(experimentId, ctx);
|
||||
|
||||
return await prisma.scenarioVariantCell.findUnique({
|
||||
where: {
|
||||
promptVariantId_testScenarioId: {
|
||||
promptVariantId: input.variantId,
|
||||
testScenarioId: input.scenarioId,
|
||||
},
|
||||
},
|
||||
include: {
|
||||
modelOutput: {
|
||||
include: {
|
||||
outputEvaluation: {
|
||||
include: {
|
||||
evaluation: {
|
||||
select: { label: true },
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
}),
|
||||
forceRefetch: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
scenarioId: z.string(),
|
||||
variantId: z.string(),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const { experimentId } = await prisma.testScenario.findUniqueOrThrow({
|
||||
where: { id: input.scenarioId },
|
||||
});
|
||||
|
||||
await requireCanModifyExperiment(experimentId, ctx);
|
||||
|
||||
const cell = await prisma.scenarioVariantCell.findUnique({
|
||||
where: {
|
||||
promptVariantId_testScenarioId: {
|
||||
promptVariantId: input.variantId,
|
||||
testScenarioId: input.scenarioId,
|
||||
},
|
||||
},
|
||||
include: {
|
||||
modelOutput: true,
|
||||
},
|
||||
});
|
||||
|
||||
if (!cell) {
|
||||
await generateNewCell(input.variantId, input.scenarioId);
|
||||
return true;
|
||||
}
|
||||
|
||||
if (cell.modelOutput) {
|
||||
// TODO: Maybe keep these around to show previous generations?
|
||||
await prisma.modelOutput.delete({
|
||||
where: { id: cell.modelOutput.id },
|
||||
});
|
||||
}
|
||||
|
||||
await prisma.scenarioVariantCell.update({
|
||||
where: { id: cell.id },
|
||||
data: { retrievalStatus: "PENDING" },
|
||||
});
|
||||
|
||||
await queueLLMRetrievalTask(cell.id);
|
||||
return true;
|
||||
}),
|
||||
});
|
||||
@@ -1,12 +1,18 @@
|
||||
import { z } from "zod";
|
||||
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
|
||||
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
|
||||
import { prisma } from "~/server/db";
|
||||
import { autogenerateScenarioValues } from "../autogen";
|
||||
import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
|
||||
import { reevaluateAll } from "~/server/utils/evaluations";
|
||||
import { runAllEvals } from "~/server/utils/evaluations";
|
||||
import { generateNewCell } from "~/server/utils/generateNewCell";
|
||||
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
|
||||
|
||||
export const scenariosRouter = createTRPCRouter({
|
||||
list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
|
||||
list: publicProcedure
|
||||
.input(z.object({ experimentId: z.string() }))
|
||||
.query(async ({ input, ctx }) => {
|
||||
await requireCanViewExperiment(input.experimentId, ctx);
|
||||
|
||||
return await prisma.testScenario.findMany({
|
||||
where: {
|
||||
experimentId: input.experimentId,
|
||||
@@ -18,14 +24,16 @@ export const scenariosRouter = createTRPCRouter({
|
||||
});
|
||||
}),
|
||||
|
||||
create: publicProcedure
|
||||
create: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
experimentId: z.string(),
|
||||
autogenerate: z.boolean().optional(),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input }) => {
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
await requireCanModifyExperiment(input.experimentId, ctx);
|
||||
|
||||
const maxSortIndex =
|
||||
(
|
||||
await prisma.testScenario.aggregate({
|
||||
@@ -48,32 +56,50 @@ export const scenariosRouter = createTRPCRouter({
|
||||
},
|
||||
});
|
||||
|
||||
await prisma.$transaction([
|
||||
const [scenario] = await prisma.$transaction([
|
||||
createNewScenarioAction,
|
||||
recordExperimentUpdated(input.experimentId),
|
||||
]);
|
||||
|
||||
const promptVariants = await prisma.promptVariant.findMany({
|
||||
where: {
|
||||
experimentId: input.experimentId,
|
||||
visible: true,
|
||||
},
|
||||
});
|
||||
|
||||
for (const variant of promptVariants) {
|
||||
await generateNewCell(variant.id, scenario.id);
|
||||
}
|
||||
}),
|
||||
|
||||
hide: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => {
|
||||
hide: protectedProcedure.input(z.object({ id: z.string() })).mutation(async ({ input, ctx }) => {
|
||||
const experimentId = (
|
||||
await prisma.testScenario.findUniqueOrThrow({
|
||||
where: { id: input.id },
|
||||
})
|
||||
).experimentId;
|
||||
|
||||
await requireCanModifyExperiment(experimentId, ctx);
|
||||
const hiddenScenario = await prisma.testScenario.update({
|
||||
where: { id: input.id },
|
||||
data: { visible: false, experiment: { update: { updatedAt: new Date() } } },
|
||||
});
|
||||
|
||||
// Reevaluate all evaluations now that this scenario is hidden
|
||||
await reevaluateAll(hiddenScenario.experimentId);
|
||||
await runAllEvals(hiddenScenario.experimentId);
|
||||
|
||||
return hiddenScenario;
|
||||
}),
|
||||
|
||||
reorder: publicProcedure
|
||||
reorder: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
draggedId: z.string(),
|
||||
droppedId: z.string(),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input }) => {
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const dragged = await prisma.testScenario.findUnique({
|
||||
where: {
|
||||
id: input.draggedId,
|
||||
@@ -92,6 +118,8 @@ export const scenariosRouter = createTRPCRouter({
|
||||
);
|
||||
}
|
||||
|
||||
await requireCanModifyExperiment(dragged.experimentId, ctx);
|
||||
|
||||
const visibleItems = await prisma.testScenario.findMany({
|
||||
where: {
|
||||
experimentId: dragged.experimentId,
|
||||
@@ -135,14 +163,14 @@ export const scenariosRouter = createTRPCRouter({
|
||||
);
|
||||
}),
|
||||
|
||||
replaceWithValues: publicProcedure
|
||||
replaceWithValues: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
id: z.string(),
|
||||
values: z.record(z.string()),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input }) => {
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const existing = await prisma.testScenario.findUnique({
|
||||
where: {
|
||||
id: input.id,
|
||||
@@ -153,6 +181,8 @@ export const scenariosRouter = createTRPCRouter({
|
||||
throw new Error(`Scenario with id ${input.id} does not exist`);
|
||||
}
|
||||
|
||||
await requireCanModifyExperiment(existing.experimentId, ctx);
|
||||
|
||||
const newScenario = await prisma.testScenario.create({
|
||||
data: {
|
||||
experimentId: existing.experimentId,
|
||||
@@ -175,6 +205,17 @@ export const scenariosRouter = createTRPCRouter({
|
||||
},
|
||||
});
|
||||
|
||||
const promptVariants = await prisma.promptVariant.findMany({
|
||||
where: {
|
||||
experimentId: newScenario.experimentId,
|
||||
visible: true,
|
||||
},
|
||||
});
|
||||
|
||||
for (const variant of promptVariants) {
|
||||
await generateNewCell(variant.id, newScenario.id);
|
||||
}
|
||||
|
||||
return newScenario;
|
||||
}),
|
||||
});
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
import { z } from "zod";
|
||||
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
|
||||
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
|
||||
import { prisma } from "~/server/db";
|
||||
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
|
||||
|
||||
export const templateVarsRouter = createTRPCRouter({
|
||||
create: publicProcedure
|
||||
create: protectedProcedure
|
||||
.input(z.object({ experimentId: z.string(), label: z.string() }))
|
||||
.mutation(async ({ input }) => {
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
await requireCanModifyExperiment(input.experimentId, ctx);
|
||||
|
||||
await prisma.templateVariable.create({
|
||||
data: {
|
||||
experimentId: input.experimentId,
|
||||
@@ -14,11 +17,22 @@ export const templateVarsRouter = createTRPCRouter({
|
||||
});
|
||||
}),
|
||||
|
||||
delete: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => {
|
||||
delete: protectedProcedure
|
||||
.input(z.object({ id: z.string() }))
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const { experimentId } = await prisma.templateVariable.findUniqueOrThrow({
|
||||
where: { id: input.id },
|
||||
});
|
||||
|
||||
await requireCanModifyExperiment(experimentId, ctx);
|
||||
|
||||
await prisma.templateVariable.delete({ where: { id: input.id } });
|
||||
}),
|
||||
|
||||
list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
|
||||
list: publicProcedure
|
||||
.input(z.object({ experimentId: z.string() }))
|
||||
.query(async ({ input, ctx }) => {
|
||||
await requireCanViewExperiment(input.experimentId, ctx);
|
||||
return await prisma.templateVariable.findMany({
|
||||
where: {
|
||||
experimentId: input.experimentId,
|
||||
|
||||
@@ -27,6 +27,9 @@ type CreateContextOptions = {
|
||||
session: Session | null;
|
||||
};
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-empty-function
|
||||
const noOp = () => {};
|
||||
|
||||
/**
|
||||
* This helper generates the "internals" for a tRPC context. If you need to use it, you can export
|
||||
* it from here.
|
||||
@@ -41,6 +44,7 @@ const createInnerTRPCContext = (opts: CreateContextOptions) => {
|
||||
return {
|
||||
session: opts.session,
|
||||
prisma,
|
||||
markAccessControlRun: noOp,
|
||||
};
|
||||
};
|
||||
|
||||
@@ -69,6 +73,8 @@ export const createTRPCContext = async (opts: CreateNextContextOptions) => {
|
||||
* errors on the backend.
|
||||
*/
|
||||
|
||||
export type TRPCContext = Awaited<ReturnType<typeof createTRPCContext>>;
|
||||
|
||||
const t = initTRPC.context<typeof createTRPCContext>().create({
|
||||
transformer: superjson,
|
||||
errorFormatter({ shape, error }) {
|
||||
@@ -106,16 +112,29 @@ export const createTRPCRouter = t.router;
|
||||
export const publicProcedure = t.procedure;
|
||||
|
||||
/** Reusable middleware that enforces users are logged in before running the procedure. */
|
||||
const enforceUserIsAuthed = t.middleware(({ ctx, next }) => {
|
||||
const enforceUserIsAuthed = t.middleware(async ({ ctx, next }) => {
|
||||
if (!ctx.session || !ctx.session.user) {
|
||||
throw new TRPCError({ code: "UNAUTHORIZED" });
|
||||
}
|
||||
return next({
|
||||
|
||||
let accessControlRun = false;
|
||||
const resp = await next({
|
||||
ctx: {
|
||||
// infers the `session` as non-nullable
|
||||
session: { ...ctx.session, user: ctx.session.user },
|
||||
markAccessControlRun: () => {
|
||||
accessControlRun = true;
|
||||
},
|
||||
},
|
||||
});
|
||||
if (!accessControlRun)
|
||||
throw new TRPCError({
|
||||
code: "INTERNAL_SERVER_ERROR",
|
||||
message:
|
||||
"Protected routes must perform access control checks then explicitly invoke the `ctx.markAccessControlRun()` function to ensure we don't forget access control on a route.",
|
||||
});
|
||||
|
||||
return resp;
|
||||
});
|
||||
|
||||
/**
|
||||
|
||||
@@ -2,6 +2,8 @@ import { PrismaAdapter } from "@next-auth/prisma-adapter";
|
||||
import { type GetServerSidePropsContext } from "next";
|
||||
import { getServerSession, type NextAuthOptions, type DefaultSession } from "next-auth";
|
||||
import { prisma } from "~/server/db";
|
||||
import GitHubProvider from "next-auth/providers/github";
|
||||
import { env } from "~/env.mjs";
|
||||
|
||||
/**
|
||||
* Module augmentation for `next-auth` types. Allows us to add custom properties to the `session`
|
||||
@@ -41,20 +43,15 @@ export const authOptions: NextAuthOptions = {
|
||||
},
|
||||
adapter: PrismaAdapter(prisma),
|
||||
providers: [
|
||||
// DiscordProvider({
|
||||
// clientId: env.DISCORD_CLIENT_ID,
|
||||
// clientSecret: env.DISCORD_CLIENT_SECRET,
|
||||
// }),
|
||||
/**
|
||||
* ...add more providers here.
|
||||
*
|
||||
* Most other providers require a bit more work than the Discord provider. For example, the
|
||||
* GitHub provider requires you to add the `refresh_token_expires_in` field to the Account
|
||||
* model. Refer to the NextAuth.js docs for the provider you want to use. Example:
|
||||
*
|
||||
* @see https://next-auth.js.org/providers/github
|
||||
*/
|
||||
GitHubProvider({
|
||||
clientId: env.GITHUB_CLIENT_ID,
|
||||
clientSecret: env.GITHUB_CLIENT_SECRET,
|
||||
}),
|
||||
],
|
||||
theme: {
|
||||
logo: "/logo.svg",
|
||||
brandColor: "#ff5733",
|
||||
},
|
||||
};
|
||||
|
||||
/**
|
||||
|
||||
@@ -8,7 +8,10 @@ const globalForPrisma = globalThis as unknown as {
|
||||
export const prisma =
|
||||
globalForPrisma.prisma ??
|
||||
new PrismaClient({
|
||||
log: env.NODE_ENV === "development" ? ["query", "error", "warn"] : ["error"],
|
||||
log:
|
||||
env.NODE_ENV === "development" && !env.RESTRICT_PRISMA_LOGS
|
||||
? ["query", "error", "warn"]
|
||||
: ["error"],
|
||||
});
|
||||
|
||||
if (env.NODE_ENV !== "production") globalForPrisma.prisma = prisma;
|
||||
|
||||
77
src/server/modelStats.ts
Normal file
77
src/server/modelStats.ts
Normal file
@@ -0,0 +1,77 @@
|
||||
import { type SupportedModel } from "./types";
|
||||
|
||||
interface ModelStats {
|
||||
contextLength: number;
|
||||
promptTokenPrice: number;
|
||||
completionTokenPrice: number;
|
||||
speed: "fast" | "medium" | "slow";
|
||||
provider: "OpenAI";
|
||||
learnMoreUrl: string;
|
||||
}
|
||||
|
||||
export const modelStats: Record<SupportedModel, ModelStats> = {
|
||||
"gpt-4": {
|
||||
contextLength: 8192,
|
||||
promptTokenPrice: 0.00003,
|
||||
completionTokenPrice: 0.00006,
|
||||
speed: "medium",
|
||||
provider: "OpenAI",
|
||||
learnMoreUrl: "https://openai.com/gpt-4",
|
||||
},
|
||||
"gpt-4-0613": {
|
||||
contextLength: 8192,
|
||||
promptTokenPrice: 0.00003,
|
||||
completionTokenPrice: 0.00006,
|
||||
speed: "medium",
|
||||
provider: "OpenAI",
|
||||
learnMoreUrl: "https://openai.com/gpt-4",
|
||||
},
|
||||
"gpt-4-32k": {
|
||||
contextLength: 32768,
|
||||
promptTokenPrice: 0.00006,
|
||||
completionTokenPrice: 0.00012,
|
||||
speed: "medium",
|
||||
provider: "OpenAI",
|
||||
learnMoreUrl: "https://openai.com/gpt-4",
|
||||
},
|
||||
"gpt-4-32k-0613": {
|
||||
contextLength: 32768,
|
||||
promptTokenPrice: 0.00006,
|
||||
completionTokenPrice: 0.00012,
|
||||
speed: "medium",
|
||||
provider: "OpenAI",
|
||||
learnMoreUrl: "https://openai.com/gpt-4",
|
||||
},
|
||||
"gpt-3.5-turbo": {
|
||||
contextLength: 4096,
|
||||
promptTokenPrice: 0.0000015,
|
||||
completionTokenPrice: 0.000002,
|
||||
speed: "fast",
|
||||
provider: "OpenAI",
|
||||
learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
|
||||
},
|
||||
"gpt-3.5-turbo-0613": {
|
||||
contextLength: 4096,
|
||||
promptTokenPrice: 0.0000015,
|
||||
completionTokenPrice: 0.000002,
|
||||
speed: "fast",
|
||||
provider: "OpenAI",
|
||||
learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
|
||||
},
|
||||
"gpt-3.5-turbo-16k": {
|
||||
contextLength: 16384,
|
||||
promptTokenPrice: 0.000003,
|
||||
completionTokenPrice: 0.000004,
|
||||
speed: "fast",
|
||||
provider: "OpenAI",
|
||||
learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
|
||||
},
|
||||
"gpt-3.5-turbo-16k-0613": {
|
||||
contextLength: 16384,
|
||||
promptTokenPrice: 0.000003,
|
||||
completionTokenPrice: 0.000004,
|
||||
speed: "fast",
|
||||
provider: "OpenAI",
|
||||
learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
|
||||
},
|
||||
};
|
||||
26
src/server/scripts/replicate-test.ts
Normal file
26
src/server/scripts/replicate-test.ts
Normal file
@@ -0,0 +1,26 @@
|
||||
// /* eslint-disable */
|
||||
|
||||
// import "dotenv/config";
|
||||
// import Replicate from "replicate";
|
||||
|
||||
// const replicate = new Replicate({
|
||||
// auth: process.env.REPLICATE_API_TOKEN || "",
|
||||
// });
|
||||
|
||||
// console.log("going to run");
|
||||
// const prediction = await replicate.predictions.create({
|
||||
// version: "e951f18578850b652510200860fc4ea62b3b16fac280f83ff32282f87bbd2e48",
|
||||
// input: {
|
||||
// prompt: "...",
|
||||
// },
|
||||
// });
|
||||
|
||||
// console.log("waiting");
|
||||
// setInterval(() => {
|
||||
// replicate.predictions.get(prediction.id).then((prediction) => {
|
||||
// console.log(prediction.output);
|
||||
// });
|
||||
// }, 500);
|
||||
// // const output = await replicate.wait(prediction, {});
|
||||
|
||||
// // console.log(output);
|
||||
31
src/server/tasks/defineTask.ts
Normal file
31
src/server/tasks/defineTask.ts
Normal file
@@ -0,0 +1,31 @@
|
||||
// Import necessary dependencies
|
||||
import { quickAddJob, type Helpers, type Task } from "graphile-worker";
|
||||
import { env } from "~/env.mjs";
|
||||
|
||||
// Define the defineTask function
|
||||
function defineTask<TPayload>(
|
||||
taskIdentifier: string,
|
||||
taskHandler: (payload: TPayload, helpers: Helpers) => Promise<void>,
|
||||
) {
|
||||
const enqueue = async (payload: TPayload) => {
|
||||
console.log("Enqueuing task", taskIdentifier, payload);
|
||||
await quickAddJob({ connectionString: env.DATABASE_URL }, taskIdentifier, payload);
|
||||
};
|
||||
|
||||
const handler = (payload: TPayload, helpers: Helpers) => {
|
||||
helpers.logger.info(`Running task ${taskIdentifier} with payload: ${JSON.stringify(payload)}`);
|
||||
return taskHandler(payload, helpers);
|
||||
};
|
||||
|
||||
const task = {
|
||||
identifier: taskIdentifier,
|
||||
handler: handler as Task,
|
||||
};
|
||||
|
||||
return {
|
||||
enqueue,
|
||||
task,
|
||||
};
|
||||
}
|
||||
|
||||
export default defineTask;
|
||||
182
src/server/tasks/queryLLM.task.ts
Normal file
182
src/server/tasks/queryLLM.task.ts
Normal file
@@ -0,0 +1,182 @@
|
||||
import crypto from "crypto";
|
||||
import { prisma } from "~/server/db";
|
||||
import defineTask from "./defineTask";
|
||||
import { type CompletionResponse, getOpenAIChatCompletion } from "../utils/getCompletion";
|
||||
import { type JSONSerializable } from "../types";
|
||||
import { sleep } from "../utils/sleep";
|
||||
import { shouldStream } from "../utils/shouldStream";
|
||||
import { generateChannel } from "~/utils/generateChannel";
|
||||
import { runEvalsForOutput } from "../utils/evaluations";
|
||||
import { constructPrompt } from "../utils/constructPrompt";
|
||||
import { type CompletionCreateParams } from "openai/resources/chat";
|
||||
import { type Prisma } from "@prisma/client";
|
||||
|
||||
const MAX_AUTO_RETRIES = 10;
|
||||
const MIN_DELAY = 500; // milliseconds
|
||||
const MAX_DELAY = 15000; // milliseconds
|
||||
|
||||
function calculateDelay(numPreviousTries: number): number {
|
||||
const baseDelay = Math.min(MAX_DELAY, MIN_DELAY * Math.pow(2, numPreviousTries));
|
||||
const jitter = Math.random() * baseDelay;
|
||||
return baseDelay + jitter;
|
||||
}
|
||||
|
||||
const getCompletionWithRetries = async (
|
||||
cellId: string,
|
||||
payload: JSONSerializable,
|
||||
channel?: string,
|
||||
): Promise<CompletionResponse> => {
|
||||
let modelResponse: CompletionResponse | null = null;
|
||||
try {
|
||||
for (let i = 0; i < MAX_AUTO_RETRIES; i++) {
|
||||
modelResponse = await getOpenAIChatCompletion(
|
||||
payload as unknown as CompletionCreateParams,
|
||||
channel,
|
||||
);
|
||||
if (modelResponse.statusCode !== 429 || i === MAX_AUTO_RETRIES - 1) {
|
||||
return modelResponse;
|
||||
}
|
||||
const delay = calculateDelay(i);
|
||||
await prisma.scenarioVariantCell.update({
|
||||
where: { id: cellId },
|
||||
data: {
|
||||
errorMessage: "Rate limit exceeded",
|
||||
statusCode: 429,
|
||||
retryTime: new Date(Date.now() + delay),
|
||||
},
|
||||
});
|
||||
// TODO: Maybe requeue the job so other jobs can run in the future?
|
||||
await sleep(delay);
|
||||
}
|
||||
throw new Error("Max retries limit reached");
|
||||
} catch (error: unknown) {
|
||||
return {
|
||||
statusCode: modelResponse?.statusCode ?? 500,
|
||||
errorMessage: modelResponse?.errorMessage ?? (error as Error).message,
|
||||
output: null,
|
||||
timeToComplete: 0,
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
export type queryLLMJob = {
|
||||
scenarioVariantCellId: string;
|
||||
};
|
||||
|
||||
export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
|
||||
const { scenarioVariantCellId } = task;
|
||||
const cell = await prisma.scenarioVariantCell.findUnique({
|
||||
where: { id: scenarioVariantCellId },
|
||||
include: { modelOutput: true },
|
||||
});
|
||||
if (!cell) {
|
||||
await prisma.scenarioVariantCell.update({
|
||||
where: { id: scenarioVariantCellId },
|
||||
data: {
|
||||
statusCode: 404,
|
||||
errorMessage: "Cell not found",
|
||||
retrievalStatus: "ERROR",
|
||||
},
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
// If cell is not pending, then some other job is already processing it
|
||||
if (cell.retrievalStatus !== "PENDING") {
|
||||
return;
|
||||
}
|
||||
await prisma.scenarioVariantCell.update({
|
||||
where: { id: scenarioVariantCellId },
|
||||
data: {
|
||||
retrievalStatus: "IN_PROGRESS",
|
||||
},
|
||||
});
|
||||
|
||||
const variant = await prisma.promptVariant.findUnique({
|
||||
where: { id: cell.promptVariantId },
|
||||
});
|
||||
if (!variant) {
|
||||
await prisma.scenarioVariantCell.update({
|
||||
where: { id: scenarioVariantCellId },
|
||||
data: {
|
||||
statusCode: 404,
|
||||
errorMessage: "Prompt Variant not found",
|
||||
retrievalStatus: "ERROR",
|
||||
},
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
const scenario = await prisma.testScenario.findUnique({
|
||||
where: { id: cell.testScenarioId },
|
||||
});
|
||||
if (!scenario) {
|
||||
await prisma.scenarioVariantCell.update({
|
||||
where: { id: scenarioVariantCellId },
|
||||
data: {
|
||||
statusCode: 404,
|
||||
errorMessage: "Scenario not found",
|
||||
retrievalStatus: "ERROR",
|
||||
},
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
const prompt = await constructPrompt(variant, scenario.variableValues);
|
||||
|
||||
const streamingEnabled = shouldStream(prompt);
|
||||
let streamingChannel;
|
||||
|
||||
if (streamingEnabled) {
|
||||
streamingChannel = generateChannel();
|
||||
// Save streaming channel so that UI can connect to it
|
||||
await prisma.scenarioVariantCell.update({
|
||||
where: { id: scenarioVariantCellId },
|
||||
data: {
|
||||
streamingChannel,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
const modelResponse = await getCompletionWithRetries(
|
||||
scenarioVariantCellId,
|
||||
prompt,
|
||||
streamingChannel,
|
||||
);
|
||||
|
||||
let modelOutput = null;
|
||||
if (modelResponse.statusCode === 200) {
|
||||
const inputHash = crypto.createHash("sha256").update(JSON.stringify(prompt)).digest("hex");
|
||||
|
||||
modelOutput = await prisma.modelOutput.create({
|
||||
data: {
|
||||
scenarioVariantCellId,
|
||||
inputHash,
|
||||
output: modelResponse.output as unknown as Prisma.InputJsonObject,
|
||||
timeToComplete: modelResponse.timeToComplete,
|
||||
promptTokens: modelResponse.promptTokens,
|
||||
completionTokens: modelResponse.completionTokens,
|
||||
cost: modelResponse.cost,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
await prisma.scenarioVariantCell.update({
|
||||
where: { id: scenarioVariantCellId },
|
||||
data: {
|
||||
statusCode: modelResponse.statusCode,
|
||||
errorMessage: modelResponse.errorMessage,
|
||||
streamingChannel: null,
|
||||
retrievalStatus: modelOutput ? "COMPLETE" : "ERROR",
|
||||
modelOutput: {
|
||||
connect: {
|
||||
id: modelOutput?.id,
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
if (modelOutput) {
|
||||
await runEvalsForOutput(variant.experimentId, scenario, modelOutput);
|
||||
}
|
||||
});
|
||||
40
src/server/tasks/worker.ts
Normal file
40
src/server/tasks/worker.ts
Normal file
@@ -0,0 +1,40 @@
|
||||
import { type TaskList, run } from "graphile-worker";
|
||||
import "dotenv/config";
|
||||
|
||||
import { env } from "~/env.mjs";
|
||||
import { queryLLM } from "./queryLLM.task";
|
||||
|
||||
const registeredTasks = [queryLLM];
|
||||
|
||||
const taskList = registeredTasks.reduce((acc, task) => {
|
||||
acc[task.task.identifier] = task.task.handler;
|
||||
return acc;
|
||||
}, {} as TaskList);
|
||||
|
||||
async function main() {
|
||||
// Run a worker to execute jobs:
|
||||
const runner = await run({
|
||||
connectionString: env.DATABASE_URL,
|
||||
concurrency: 20,
|
||||
// Install signal handlers for graceful shutdown on SIGINT, SIGTERM, etc
|
||||
noHandleSignals: false,
|
||||
pollInterval: 1000,
|
||||
// you can set the taskList or taskDirectory but not both
|
||||
taskList,
|
||||
// or:
|
||||
// taskDirectory: `${__dirname}/tasks`,
|
||||
});
|
||||
|
||||
// Immediately await (or otherwise handled) the resulting promise, to avoid
|
||||
// "unhandled rejection" errors causing a process crash in the event of
|
||||
// something going wrong.
|
||||
await runner.promise;
|
||||
|
||||
// If the worker exits (whether through fatal error or otherwise), the above
|
||||
// promise will resolve/reject.
|
||||
}
|
||||
|
||||
main().catch((err) => {
|
||||
console.error("Unhandled error occurred running worker: ", err);
|
||||
process.exit(1);
|
||||
});
|
||||
@@ -7,10 +7,8 @@ test.skip("constructPrompt", async () => {
|
||||
constructFn: `prompt = { "fooz": "bar" }`,
|
||||
},
|
||||
{
|
||||
variableValues: {
|
||||
foo: "bar",
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
console.log(constructed);
|
||||
|
||||
@@ -6,12 +6,10 @@ const isolate = new ivm.Isolate({ memoryLimit: 128 });
|
||||
|
||||
export async function constructPrompt(
|
||||
variant: Pick<PromptVariant, "constructFn">,
|
||||
testScenario: Pick<TestScenario, "variableValues">,
|
||||
scenario: TestScenario["variableValues"],
|
||||
): Promise<JSONSerializable> {
|
||||
const scenario = testScenario.variableValues as JSONSerializable;
|
||||
|
||||
const code = `
|
||||
const scenario = ${JSON.stringify(scenario, null, 2)};
|
||||
const scenario = ${JSON.stringify(scenario ?? {}, null, 2)};
|
||||
let prompt
|
||||
|
||||
${variant.constructFn}
|
||||
|
||||
123
src/server/utils/deriveNewContructFn.ts
Normal file
123
src/server/utils/deriveNewContructFn.ts
Normal file
@@ -0,0 +1,123 @@
|
||||
import { type PromptVariant } from "@prisma/client";
|
||||
import { type SupportedModel } from "../types";
|
||||
import ivm from "isolated-vm";
|
||||
import dedent from "dedent";
|
||||
import { openai } from "./openai";
|
||||
import { getApiShapeForModel } from "./getTypesForModel";
|
||||
import { isObject } from "lodash-es";
|
||||
import { type CompletionCreateParams } from "openai/resources/chat/completions";
|
||||
|
||||
const isolate = new ivm.Isolate({ memoryLimit: 128 });
|
||||
|
||||
export async function deriveNewConstructFn(
|
||||
originalVariant: PromptVariant | null,
|
||||
newModel?: SupportedModel,
|
||||
instructions?: string,
|
||||
) {
|
||||
if (originalVariant && !newModel && !instructions) {
|
||||
return originalVariant.constructFn;
|
||||
}
|
||||
if (originalVariant && (newModel || instructions)) {
|
||||
return await requestUpdatedPromptFunction(originalVariant, newModel, instructions);
|
||||
}
|
||||
return dedent`
|
||||
prompt = {
|
||||
model: "gpt-3.5-turbo",
|
||||
messages: [
|
||||
{
|
||||
role: "system",
|
||||
content: "Return 'Hello, world!'",
|
||||
}
|
||||
]
|
||||
}`;
|
||||
}
|
||||
|
||||
const NUM_RETRIES = 5;
|
||||
const requestUpdatedPromptFunction = async (
|
||||
originalVariant: PromptVariant,
|
||||
newModel?: SupportedModel,
|
||||
instructions?: string,
|
||||
) => {
|
||||
const originalModel = originalVariant.model as SupportedModel;
|
||||
let newContructionFn = "";
|
||||
for (let i = 0; i < NUM_RETRIES; i++) {
|
||||
try {
|
||||
const messages: CompletionCreateParams.CreateChatCompletionRequestNonStreaming.Message[] = [
|
||||
{
|
||||
role: "system",
|
||||
content: `Your job is to update prompt constructor functions. Here is the api shape for the current model:\n---\n${JSON.stringify(
|
||||
getApiShapeForModel(originalModel),
|
||||
null,
|
||||
2,
|
||||
)}`,
|
||||
},
|
||||
];
|
||||
if (newModel) {
|
||||
messages.push({
|
||||
role: "user",
|
||||
content: `Return the prompt constructor function for ${newModel} given the following prompt constructor function for ${originalModel}:\n---\n${originalVariant.constructFn}`,
|
||||
});
|
||||
}
|
||||
if (instructions) {
|
||||
messages.push({
|
||||
role: "user",
|
||||
content: `Follow these instructions: ${instructions}`,
|
||||
});
|
||||
}
|
||||
messages.push({
|
||||
role: "user",
|
||||
content:
|
||||
"The prompt variable has already been declared, so do not declare it again. Rewrite the entire prompt constructor function.",
|
||||
});
|
||||
const completion = await openai.chat.completions.create({
|
||||
model: "gpt-4",
|
||||
messages,
|
||||
functions: [
|
||||
{
|
||||
name: "update_prompt_constructor_function",
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: {
|
||||
new_prompt_function: {
|
||||
type: "string",
|
||||
description: "The new prompt function, runnable in typescript",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
function_call: {
|
||||
name: "update_prompt_constructor_function",
|
||||
},
|
||||
});
|
||||
const argString = completion.choices[0]?.message?.function_call?.arguments || "{}";
|
||||
|
||||
const code = `
|
||||
global.contructPromptFunctionArgs = ${argString};
|
||||
`;
|
||||
|
||||
const context = await isolate.createContext();
|
||||
|
||||
const jail = context.global;
|
||||
await jail.set("global", jail.derefInto());
|
||||
|
||||
const script = await isolate.compileScript(code);
|
||||
|
||||
await script.run(context);
|
||||
const contructPromptFunctionArgs = (await context.global.get(
|
||||
"contructPromptFunctionArgs",
|
||||
)) as ivm.Reference;
|
||||
|
||||
const args = await contructPromptFunctionArgs.copy(); // Get the actual value from the isolate
|
||||
|
||||
if (args && isObject(args) && "new_prompt_function" in args) {
|
||||
newContructionFn = args.new_prompt_function as string;
|
||||
break;
|
||||
}
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
}
|
||||
}
|
||||
|
||||
return newContructionFn;
|
||||
};
|
||||
6
src/server/utils/error.ts
Normal file
6
src/server/utils/error.ts
Normal file
@@ -0,0 +1,6 @@
|
||||
export default function userError(message: string): { status: "error"; message: string } {
|
||||
return {
|
||||
status: "error",
|
||||
message,
|
||||
};
|
||||
}
|
||||
@@ -1,31 +0,0 @@
|
||||
import { type Evaluation, type ModelOutput, type TestScenario } from "@prisma/client";
|
||||
import { type ChatCompletion } from "openai/resources/chat";
|
||||
import { type VariableMap, fillTemplate } from "./fillTemplate";
|
||||
|
||||
export const evaluateOutput = (
|
||||
modelOutput: ModelOutput,
|
||||
scenario: TestScenario,
|
||||
evaluation: Evaluation,
|
||||
): boolean => {
|
||||
const output = modelOutput.output as unknown as ChatCompletion;
|
||||
const message = output?.choices?.[0]?.message;
|
||||
|
||||
if (!message) return false;
|
||||
|
||||
const stringifiedMessage = message.content ?? JSON.stringify(message.function_call);
|
||||
|
||||
const matchRegex = fillTemplate(evaluation.matchString, scenario.variableValues as VariableMap);
|
||||
|
||||
let match;
|
||||
|
||||
switch (evaluation.matchType) {
|
||||
case "CONTAINS":
|
||||
match = stringifiedMessage.match(matchRegex) !== null;
|
||||
break;
|
||||
case "DOES_NOT_CONTAIN":
|
||||
match = stringifiedMessage.match(matchRegex) === null;
|
||||
break;
|
||||
}
|
||||
|
||||
return match;
|
||||
};
|
||||
@@ -1,103 +1,79 @@
|
||||
import { type Evaluation } from "@prisma/client";
|
||||
import { type ModelOutput, type Evaluation } from "@prisma/client";
|
||||
import { prisma } from "../db";
|
||||
import { evaluateOutput } from "./evaluateOutput";
|
||||
import { runOneEval } from "./runOneEval";
|
||||
import { type Scenario } from "~/components/OutputsTable/types";
|
||||
|
||||
export const reevaluateVariant = async (variantId: string) => {
|
||||
const variant = await prisma.promptVariant.findUnique({
|
||||
where: { id: variantId },
|
||||
});
|
||||
if (!variant) return;
|
||||
|
||||
const evaluations = await prisma.evaluation.findMany({
|
||||
where: { experimentId: variant.experimentId },
|
||||
});
|
||||
|
||||
const modelOutputs = await prisma.modelOutput.findMany({
|
||||
const saveResult = async (evaluation: Evaluation, scenario: Scenario, modelOutput: ModelOutput) => {
|
||||
const result = await runOneEval(evaluation, scenario, modelOutput);
|
||||
return await prisma.outputEvaluation.upsert({
|
||||
where: {
|
||||
promptVariantId: variantId,
|
||||
statusCode: { notIn: [429] },
|
||||
testScenario: { visible: true },
|
||||
},
|
||||
include: { testScenario: true },
|
||||
});
|
||||
|
||||
await Promise.all(
|
||||
evaluations.map(async (evaluation) => {
|
||||
const passCount = modelOutputs.filter((output) =>
|
||||
evaluateOutput(output, output.testScenario, evaluation),
|
||||
).length;
|
||||
const failCount = modelOutputs.length - passCount;
|
||||
|
||||
await prisma.evaluationResult.upsert({
|
||||
where: {
|
||||
evaluationId_promptVariantId: {
|
||||
modelOutputId_evaluationId: {
|
||||
modelOutputId: modelOutput.id,
|
||||
evaluationId: evaluation.id,
|
||||
promptVariantId: variantId,
|
||||
},
|
||||
},
|
||||
create: {
|
||||
modelOutputId: modelOutput.id,
|
||||
evaluationId: evaluation.id,
|
||||
promptVariantId: variantId,
|
||||
passCount,
|
||||
failCount,
|
||||
...result,
|
||||
},
|
||||
update: {
|
||||
passCount,
|
||||
failCount,
|
||||
...result,
|
||||
},
|
||||
});
|
||||
}),
|
||||
);
|
||||
};
|
||||
|
||||
export const reevaluateEvaluation = async (evaluation: Evaluation) => {
|
||||
const variants = await prisma.promptVariant.findMany({
|
||||
where: { experimentId: evaluation.experimentId, visible: true },
|
||||
});
|
||||
|
||||
const modelOutputs = await prisma.modelOutput.findMany({
|
||||
where: {
|
||||
promptVariantId: { in: variants.map((v) => v.id) },
|
||||
testScenario: { visible: true },
|
||||
statusCode: { notIn: [429] },
|
||||
},
|
||||
include: { testScenario: true },
|
||||
});
|
||||
|
||||
await Promise.all(
|
||||
variants.map(async (variant) => {
|
||||
const outputs = modelOutputs.filter((output) => output.promptVariantId === variant.id);
|
||||
const passCount = outputs.filter((output) =>
|
||||
evaluateOutput(output, output.testScenario, evaluation),
|
||||
).length;
|
||||
const failCount = outputs.length - passCount;
|
||||
|
||||
await prisma.evaluationResult.upsert({
|
||||
where: {
|
||||
evaluationId_promptVariantId: {
|
||||
evaluationId: evaluation.id,
|
||||
promptVariantId: variant.id,
|
||||
},
|
||||
},
|
||||
create: {
|
||||
evaluationId: evaluation.id,
|
||||
promptVariantId: variant.id,
|
||||
passCount,
|
||||
failCount,
|
||||
},
|
||||
update: {
|
||||
passCount,
|
||||
failCount,
|
||||
},
|
||||
});
|
||||
}),
|
||||
);
|
||||
};
|
||||
|
||||
export const reevaluateAll = async (experimentId: string) => {
|
||||
export const runEvalsForOutput = async (
|
||||
experimentId: string,
|
||||
scenario: Scenario,
|
||||
modelOutput: ModelOutput,
|
||||
) => {
|
||||
const evaluations = await prisma.evaluation.findMany({
|
||||
where: { experimentId },
|
||||
});
|
||||
|
||||
await Promise.all(evaluations.map(reevaluateEvaluation));
|
||||
await Promise.all(
|
||||
evaluations.map(async (evaluation) => await saveResult(evaluation, scenario, modelOutput)),
|
||||
);
|
||||
};
|
||||
|
||||
export const runAllEvals = async (experimentId: string) => {
|
||||
const outputs = await prisma.modelOutput.findMany({
|
||||
where: {
|
||||
scenarioVariantCell: {
|
||||
promptVariant: {
|
||||
experimentId,
|
||||
visible: true,
|
||||
},
|
||||
testScenario: {
|
||||
visible: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
include: {
|
||||
scenarioVariantCell: {
|
||||
include: {
|
||||
testScenario: true,
|
||||
},
|
||||
},
|
||||
outputEvaluation: true,
|
||||
},
|
||||
});
|
||||
const evals = await prisma.evaluation.findMany({
|
||||
where: { experimentId },
|
||||
});
|
||||
|
||||
await Promise.all(
|
||||
outputs.map(async (output) => {
|
||||
const unrunEvals = evals.filter(
|
||||
(evaluation) => !output.outputEvaluation.find((e) => e.evaluationId === evaluation.id),
|
||||
);
|
||||
|
||||
await Promise.all(
|
||||
unrunEvals.map(async (evaluation) => {
|
||||
await saveResult(evaluation, output.scenarioVariantCell.testScenario, output);
|
||||
}),
|
||||
);
|
||||
}),
|
||||
);
|
||||
};
|
||||
|
||||
@@ -2,6 +2,16 @@ import { type JSONSerializable } from "../types";
|
||||
|
||||
export type VariableMap = Record<string, string>;
|
||||
|
||||
// Escape quotes to match the way we encode JSON
|
||||
export function escapeQuotes(str: string) {
|
||||
return str.replace(/(\\")|"/g, (match, p1) => (p1 ? match : '\\"'));
|
||||
}
|
||||
|
||||
// Escape regex special characters
|
||||
export function escapeRegExp(str: string) {
|
||||
return str.replace(/[.*+\-?^${}()|[\]\\]/g, "\\$&"); // $& means the whole matched string
|
||||
}
|
||||
|
||||
export function fillTemplate(template: string, variables: VariableMap): string {
|
||||
return template.replace(/{{\s*(\w+)\s*}}/g, (_, key: string) => variables[key] || "");
|
||||
}
|
||||
|
||||
77
src/server/utils/generateNewCell.ts
Normal file
77
src/server/utils/generateNewCell.ts
Normal file
@@ -0,0 +1,77 @@
|
||||
import crypto from "crypto";
|
||||
import { type Prisma } from "@prisma/client";
|
||||
import { prisma } from "../db";
|
||||
import { queueLLMRetrievalTask } from "./queueLLMRetrievalTask";
|
||||
import { constructPrompt } from "./constructPrompt";
|
||||
|
||||
export const generateNewCell = async (variantId: string, scenarioId: string) => {
|
||||
const variant = await prisma.promptVariant.findUnique({
|
||||
where: {
|
||||
id: variantId,
|
||||
},
|
||||
});
|
||||
|
||||
const scenario = await prisma.testScenario.findUnique({
|
||||
where: {
|
||||
id: scenarioId,
|
||||
},
|
||||
});
|
||||
|
||||
if (!variant || !scenario) return null;
|
||||
|
||||
const prompt = await constructPrompt(variant, scenario.variableValues);
|
||||
|
||||
const inputHash = crypto.createHash("sha256").update(JSON.stringify(prompt)).digest("hex");
|
||||
|
||||
let cell = await prisma.scenarioVariantCell.findUnique({
|
||||
where: {
|
||||
promptVariantId_testScenarioId: {
|
||||
promptVariantId: variantId,
|
||||
testScenarioId: scenarioId,
|
||||
},
|
||||
},
|
||||
include: {
|
||||
modelOutput: true,
|
||||
},
|
||||
});
|
||||
|
||||
if (cell) return cell;
|
||||
|
||||
cell = await prisma.scenarioVariantCell.create({
|
||||
data: {
|
||||
promptVariantId: variantId,
|
||||
testScenarioId: scenarioId,
|
||||
},
|
||||
include: {
|
||||
modelOutput: true,
|
||||
},
|
||||
});
|
||||
|
||||
const matchingModelOutput = await prisma.modelOutput.findFirst({
|
||||
where: {
|
||||
inputHash,
|
||||
},
|
||||
});
|
||||
|
||||
let newModelOutput;
|
||||
|
||||
if (matchingModelOutput) {
|
||||
newModelOutput = await prisma.modelOutput.create({
|
||||
data: {
|
||||
scenarioVariantCellId: cell.id,
|
||||
inputHash,
|
||||
output: matchingModelOutput.output as Prisma.InputJsonValue,
|
||||
timeToComplete: matchingModelOutput.timeToComplete,
|
||||
cost: matchingModelOutput.cost,
|
||||
promptTokens: matchingModelOutput.promptTokens,
|
||||
completionTokens: matchingModelOutput.completionTokens,
|
||||
createdAt: matchingModelOutput.createdAt,
|
||||
updatedAt: matchingModelOutput.updatedAt,
|
||||
},
|
||||
});
|
||||
} else {
|
||||
cell = await queueLLMRetrievalTask(cell.id);
|
||||
}
|
||||
|
||||
return { ...cell, modelOutput: newModelOutput };
|
||||
};
|
||||
@@ -1,56 +1,26 @@
|
||||
/* eslint-disable @typescript-eslint/no-unsafe-call */
|
||||
import { isObject } from "lodash";
|
||||
import { Prisma } from "@prisma/client";
|
||||
import { isObject } from "lodash-es";
|
||||
import { streamChatCompletion } from "./openai";
|
||||
import { wsConnection } from "~/utils/wsConnection";
|
||||
import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat";
|
||||
import { type JSONSerializable, OpenAIChatModel } from "../types";
|
||||
import { type SupportedModel, type OpenAIChatModel } from "../types";
|
||||
import { env } from "~/env.mjs";
|
||||
import { countOpenAIChatTokens } from "~/utils/countTokens";
|
||||
import { getModelName } from "./getModelName";
|
||||
import { rateLimitErrorMessage } from "~/sharedStrings";
|
||||
import { modelStats } from "../modelStats";
|
||||
|
||||
env;
|
||||
|
||||
type CompletionResponse = {
|
||||
output: Prisma.InputJsonValue | typeof Prisma.JsonNull;
|
||||
export type CompletionResponse = {
|
||||
output: ChatCompletion | null;
|
||||
statusCode: number;
|
||||
errorMessage: string | null;
|
||||
timeToComplete: number;
|
||||
promptTokens?: number;
|
||||
completionTokens?: number;
|
||||
cost?: number;
|
||||
};
|
||||
|
||||
export async function getCompletion(
|
||||
payload: JSONSerializable,
|
||||
channel?: string,
|
||||
): Promise<CompletionResponse> {
|
||||
const modelName = getModelName(payload);
|
||||
if (!modelName)
|
||||
return {
|
||||
output: Prisma.JsonNull,
|
||||
statusCode: 400,
|
||||
errorMessage: "Invalid payload provided",
|
||||
timeToComplete: 0,
|
||||
};
|
||||
if (modelName in OpenAIChatModel) {
|
||||
return getOpenAIChatCompletion(
|
||||
payload as unknown as CompletionCreateParams,
|
||||
env.OPENAI_API_KEY,
|
||||
channel,
|
||||
);
|
||||
}
|
||||
return {
|
||||
output: Prisma.JsonNull,
|
||||
statusCode: 400,
|
||||
errorMessage: "Invalid model provided",
|
||||
timeToComplete: 0,
|
||||
};
|
||||
}
|
||||
|
||||
export async function getOpenAIChatCompletion(
|
||||
payload: CompletionCreateParams,
|
||||
apiKey: string,
|
||||
channel?: string,
|
||||
): Promise<CompletionResponse> {
|
||||
// If functions are enabled, disable streaming so that we get the full response with token counts
|
||||
@@ -60,13 +30,13 @@ export async function getOpenAIChatCompletion(
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
Authorization: `Bearer ${apiKey}`,
|
||||
Authorization: `Bearer ${env.OPENAI_API_KEY}`,
|
||||
},
|
||||
body: JSON.stringify(payload),
|
||||
});
|
||||
|
||||
const resp: CompletionResponse = {
|
||||
output: Prisma.JsonNull,
|
||||
output: null,
|
||||
errorMessage: null,
|
||||
statusCode: response.status,
|
||||
timeToComplete: 0,
|
||||
@@ -83,7 +53,7 @@ export async function getOpenAIChatCompletion(
|
||||
}
|
||||
})().catch((err) => console.error(err));
|
||||
if (finalOutput) {
|
||||
resp.output = finalOutput as unknown as Prisma.InputJsonValue;
|
||||
resp.output = finalOutput;
|
||||
resp.timeToComplete = Date.now() - start;
|
||||
}
|
||||
} else {
|
||||
@@ -119,6 +89,13 @@ export async function getOpenAIChatCompletion(
|
||||
resp.completionTokens = countOpenAIChatTokens(model, messages);
|
||||
}
|
||||
}
|
||||
|
||||
const stats = modelStats[resp.output?.model as SupportedModel];
|
||||
if (stats && resp.promptTokens && resp.completionTokens) {
|
||||
resp.cost =
|
||||
resp.promptTokens * stats.promptTokenPrice +
|
||||
resp.completionTokens * stats.completionTokenPrice;
|
||||
}
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
if (response.ok) {
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
import { isObject } from "lodash";
|
||||
import { type JSONSerializable, type SupportedModel } from "../types";
|
||||
import { type Prisma } from "@prisma/client";
|
||||
|
||||
export function getModelName(config: JSONSerializable | Prisma.JsonValue): SupportedModel | null {
|
||||
if (!isObject(config)) return null;
|
||||
if ("model" in config && typeof config.model === "string") return config.model as SupportedModel;
|
||||
return null;
|
||||
}
|
||||
7
src/server/utils/getTypesForModel.ts
Normal file
7
src/server/utils/getTypesForModel.ts
Normal file
@@ -0,0 +1,7 @@
|
||||
import { OpenAIChatModel, type SupportedModel } from "../types";
|
||||
import openAIChatApiShape from "~/codegen/openai.types.ts.txt";
|
||||
|
||||
export const getApiShapeForModel = (model: SupportedModel) => {
|
||||
if (model in OpenAIChatModel) return openAIChatApiShape;
|
||||
return "";
|
||||
};
|
||||
@@ -1,4 +1,4 @@
|
||||
import { omit } from "lodash";
|
||||
import { omit } from "lodash-es";
|
||||
import { env } from "~/env.mjs";
|
||||
|
||||
import OpenAI from "openai";
|
||||
|
||||
22
src/server/utils/queueLLMRetrievalTask.ts
Normal file
22
src/server/utils/queueLLMRetrievalTask.ts
Normal file
@@ -0,0 +1,22 @@
|
||||
import { prisma } from "../db";
|
||||
import { queryLLM } from "../tasks/queryLLM.task";
|
||||
|
||||
export const queueLLMRetrievalTask = async (cellId: string) => {
|
||||
const updatedCell = await prisma.scenarioVariantCell.update({
|
||||
where: {
|
||||
id: cellId,
|
||||
},
|
||||
data: {
|
||||
retrievalStatus: "PENDING",
|
||||
errorMessage: null,
|
||||
},
|
||||
include: {
|
||||
modelOutput: true,
|
||||
},
|
||||
});
|
||||
|
||||
// @ts-expect-error we aren't passing the helpers but that's ok
|
||||
void queryLLM.task.handler({ scenarioVariantCellId: cellId }, { logger: console });
|
||||
|
||||
return updatedCell;
|
||||
};
|
||||
65
src/server/utils/reorderPromptVariants.ts
Normal file
65
src/server/utils/reorderPromptVariants.ts
Normal file
@@ -0,0 +1,65 @@
|
||||
import { prisma } from "~/server/db";
|
||||
|
||||
export const reorderPromptVariants = async (
|
||||
movedId: string,
|
||||
stationaryTargetId: string,
|
||||
alwaysInsertRight?: boolean,
|
||||
) => {
|
||||
const moved = await prisma.promptVariant.findUnique({
|
||||
where: {
|
||||
id: movedId,
|
||||
},
|
||||
});
|
||||
|
||||
const target = await prisma.promptVariant.findUnique({
|
||||
where: {
|
||||
id: stationaryTargetId,
|
||||
},
|
||||
});
|
||||
|
||||
if (!moved || !target || moved.experimentId !== target.experimentId) {
|
||||
throw new Error(`Prompt Variant with id ${movedId} or ${stationaryTargetId} does not exist`);
|
||||
}
|
||||
|
||||
const visibleItems = await prisma.promptVariant.findMany({
|
||||
where: {
|
||||
experimentId: moved.experimentId,
|
||||
visible: true,
|
||||
},
|
||||
orderBy: {
|
||||
sortIndex: "asc",
|
||||
},
|
||||
});
|
||||
|
||||
// Remove the moved item from its current position
|
||||
const orderedItems = visibleItems.filter((item) => item.id !== moved.id);
|
||||
|
||||
// Find the index of the moved item and the target item
|
||||
const movedIndex = visibleItems.findIndex((item) => item.id === moved.id);
|
||||
const targetIndex = visibleItems.findIndex((item) => item.id === target.id);
|
||||
|
||||
// Determine the new index for the moved item
|
||||
let newIndex;
|
||||
if (movedIndex < targetIndex || alwaysInsertRight) {
|
||||
newIndex = targetIndex + 1; // Insert after the target item
|
||||
} else {
|
||||
newIndex = targetIndex; // Insert before the target item
|
||||
}
|
||||
|
||||
// Insert the moved item at the new position
|
||||
orderedItems.splice(newIndex, 0, moved);
|
||||
|
||||
// Now, we need to update all the items with their new sortIndex
|
||||
await prisma.$transaction(
|
||||
orderedItems.map((item, index) => {
|
||||
return prisma.promptVariant.update({
|
||||
where: {
|
||||
id: item.id,
|
||||
},
|
||||
data: {
|
||||
sortIndex: index,
|
||||
},
|
||||
});
|
||||
}),
|
||||
);
|
||||
};
|
||||
95
src/server/utils/runOneEval.ts
Normal file
95
src/server/utils/runOneEval.ts
Normal file
@@ -0,0 +1,95 @@
|
||||
import { type Evaluation, type ModelOutput, type TestScenario } from "@prisma/client";
|
||||
import { type ChatCompletion } from "openai/resources/chat";
|
||||
import { type VariableMap, fillTemplate, escapeRegExp, escapeQuotes } from "./fillTemplate";
|
||||
import { openai } from "./openai";
|
||||
import dedent from "dedent";
|
||||
|
||||
export const runGpt4Eval = async (
|
||||
evaluation: Evaluation,
|
||||
scenario: TestScenario,
|
||||
message: ChatCompletion.Choice.Message,
|
||||
): Promise<{ result: number; details: string }> => {
|
||||
const output = await openai.chat.completions.create({
|
||||
model: "gpt-4-0613",
|
||||
messages: [
|
||||
{
|
||||
role: "system",
|
||||
content: dedent`
|
||||
You are a highly intelligent AI model and have been tasked with evaluating the quality of a simpler model. Your objective is to determine whether the simpler model has produced a successful and correct output. You should return "true" if the output was successful and "false" if it was not. Pay more attention to the semantics of the output than the formatting. Success is defined in the following terms:
|
||||
---
|
||||
${evaluation.value}
|
||||
`,
|
||||
},
|
||||
{
|
||||
role: "user",
|
||||
content: `Scenario:\n---\n${JSON.stringify(scenario.variableValues, null, 2)}`,
|
||||
},
|
||||
{
|
||||
role: "user",
|
||||
content: `The full output of the simpler message:\n---\n${JSON.stringify(
|
||||
message.content ?? message.function_call,
|
||||
null,
|
||||
2,
|
||||
)}`,
|
||||
},
|
||||
],
|
||||
function_call: {
|
||||
name: "report_success",
|
||||
},
|
||||
functions: [
|
||||
{
|
||||
name: "report_success",
|
||||
parameters: {
|
||||
type: "object",
|
||||
required: ["thoughts", "success"],
|
||||
properties: {
|
||||
thoughts: {
|
||||
type: "string",
|
||||
description: "Explain your reasoning for considering this a pass or fail",
|
||||
},
|
||||
success: {
|
||||
type: "boolean",
|
||||
description:
|
||||
"Whether the simpler model successfully completed the task for this scenario",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
try {
|
||||
const out = JSON.parse(output.choices[0]?.message?.function_call?.arguments ?? "");
|
||||
return { result: out.success ? 1 : 0, details: out.thoughts ?? JSON.stringify(out) };
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
return { result: 0, details: "Error parsing GPT-4 output" };
|
||||
}
|
||||
};
|
||||
|
||||
export const runOneEval = async (
|
||||
evaluation: Evaluation,
|
||||
scenario: TestScenario,
|
||||
modelOutput: ModelOutput,
|
||||
): Promise<{ result: number; details?: string }> => {
|
||||
const output = modelOutput.output as unknown as ChatCompletion;
|
||||
|
||||
const message = output?.choices?.[0]?.message;
|
||||
|
||||
if (!message) return { result: 0 };
|
||||
|
||||
const stringifiedMessage = message.content ?? JSON.stringify(message.function_call);
|
||||
|
||||
const matchRegex = escapeRegExp(
|
||||
fillTemplate(escapeQuotes(evaluation.value), scenario.variableValues as VariableMap),
|
||||
);
|
||||
|
||||
switch (evaluation.evalType) {
|
||||
case "CONTAINS":
|
||||
return { result: stringifiedMessage.match(matchRegex) !== null ? 1 : 0 };
|
||||
case "DOES_NOT_CONTAIN":
|
||||
return { result: stringifiedMessage.match(matchRegex) === null ? 1 : 0 };
|
||||
case "GPT4_EVAL":
|
||||
return await runGpt4Eval(evaluation, scenario, message);
|
||||
}
|
||||
};
|
||||
7
src/server/utils/shouldStream.ts
Normal file
7
src/server/utils/shouldStream.ts
Normal file
@@ -0,0 +1,7 @@
|
||||
import { isObject } from "lodash-es";
|
||||
import { type JSONSerializable } from "../types";
|
||||
|
||||
export const shouldStream = (config: JSONSerializable): boolean => {
|
||||
const shouldStream = isObject(config) && "stream" in config && config.stream === true;
|
||||
return shouldStream;
|
||||
};
|
||||
1
src/server/utils/sleep.ts
Normal file
1
src/server/utils/sleep.ts
Normal file
@@ -0,0 +1 @@
|
||||
export const sleep = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms));
|
||||
19
src/server/utils/userOrg.ts
Normal file
19
src/server/utils/userOrg.ts
Normal file
@@ -0,0 +1,19 @@
|
||||
import { prisma } from "~/server/db";
|
||||
|
||||
export default async function userOrg(userId: string) {
|
||||
return await prisma.organization.upsert({
|
||||
where: {
|
||||
personalOrgUserId: userId,
|
||||
},
|
||||
update: {},
|
||||
create: {
|
||||
personalOrgUserId: userId,
|
||||
OrganizationUser: {
|
||||
create: {
|
||||
userId: userId,
|
||||
role: "ADMIN",
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
}
|
||||
@@ -2,12 +2,9 @@ import { type RouterOutputs } from "~/utils/api";
|
||||
import { type SliceCreator } from "./store";
|
||||
import loader from "@monaco-editor/loader";
|
||||
import openAITypes from "~/codegen/openai.types.ts.txt";
|
||||
import prettier from "prettier/standalone";
|
||||
import parserTypescript from "prettier/plugins/typescript";
|
||||
import formatPromptConstructor from "~/utils/formatPromptConstructor";
|
||||
|
||||
// @ts-expect-error for some reason missing from types
|
||||
import parserEstree from "prettier/plugins/estree";
|
||||
import { type languages } from "monaco-editor/esm/vs/editor/editor.api";
|
||||
export const editorBackground = "#fafafa";
|
||||
|
||||
export type SharedVariantEditorSlice = {
|
||||
monaco: null | ReturnType<typeof loader.__getMonacoInstance>;
|
||||
@@ -17,29 +14,12 @@ export type SharedVariantEditorSlice = {
|
||||
setScenarios: (scenarios: RouterOutputs["scenarios"]["list"]) => void;
|
||||
};
|
||||
|
||||
const customFormatter: languages.DocumentFormattingEditProvider = {
|
||||
provideDocumentFormattingEdits: async (model) => {
|
||||
const val = model.getValue();
|
||||
console.log("going to format!", val);
|
||||
const text = await prettier.format(val, {
|
||||
parser: "typescript",
|
||||
plugins: [parserTypescript, parserEstree],
|
||||
// We're showing these in pretty narrow panes so let's keep the print width low
|
||||
printWidth: 60,
|
||||
});
|
||||
|
||||
return [
|
||||
{
|
||||
range: model.getFullModelRange(),
|
||||
text,
|
||||
},
|
||||
];
|
||||
},
|
||||
};
|
||||
|
||||
export const createVariantEditorSlice: SliceCreator<SharedVariantEditorSlice> = (set, get) => ({
|
||||
monaco: loader.__getMonacoInstance(),
|
||||
loadMonaco: async () => {
|
||||
// We only want to run this client-side
|
||||
if (typeof window === "undefined") return;
|
||||
|
||||
const monaco = await loader.init();
|
||||
|
||||
monaco.editor.defineTheme("customTheme", {
|
||||
@@ -47,12 +27,13 @@ export const createVariantEditorSlice: SliceCreator<SharedVariantEditorSlice> =
|
||||
inherit: true,
|
||||
rules: [],
|
||||
colors: {
|
||||
"editor.background": "#fafafa",
|
||||
"editor.background": editorBackground,
|
||||
},
|
||||
});
|
||||
|
||||
monaco.languages.typescript.typescriptDefaults.setCompilerOptions({
|
||||
allowNonTsExtensions: true,
|
||||
strictNullChecks: true,
|
||||
lib: ["esnext"],
|
||||
});
|
||||
|
||||
@@ -66,7 +47,16 @@ export const createVariantEditorSlice: SliceCreator<SharedVariantEditorSlice> =
|
||||
monaco.Uri.parse("file:///openai.types.ts"),
|
||||
);
|
||||
|
||||
monaco.languages.registerDocumentFormattingEditProvider("typescript", customFormatter);
|
||||
monaco.languages.registerDocumentFormattingEditProvider("typescript", {
|
||||
provideDocumentFormattingEdits: async (model) => {
|
||||
return [
|
||||
{
|
||||
range: model.getFullModelRange(),
|
||||
text: await formatPromptConstructor(model.getValue()),
|
||||
},
|
||||
];
|
||||
},
|
||||
});
|
||||
|
||||
set((state) => {
|
||||
state.sharedVariantEditor.monaco = monaco;
|
||||
@@ -95,7 +85,7 @@ export const createVariantEditorSlice: SliceCreator<SharedVariantEditorSlice> =
|
||||
)} as const;
|
||||
|
||||
type Scenario = typeof scenarios[number];
|
||||
declare var scenario: Scenario | null;
|
||||
declare var scenario: Scenario | { [key: string]: string };
|
||||
`;
|
||||
|
||||
const scenariosModel = monaco.editor.getModel(monaco.Uri.parse("file:///scenarios.ts"));
|
||||
|
||||
49
src/utils/accessControl.ts
Normal file
49
src/utils/accessControl.ts
Normal file
@@ -0,0 +1,49 @@
|
||||
import { OrganizationUserRole } from "@prisma/client";
|
||||
import { TRPCError } from "@trpc/server";
|
||||
import { type TRPCContext } from "~/server/api/trpc";
|
||||
import { prisma } from "~/server/db";
|
||||
|
||||
// No-op method for protected routes that really should be accessible to anyone.
|
||||
export const requireNothing = (ctx: TRPCContext) => {
|
||||
ctx.markAccessControlRun();
|
||||
};
|
||||
|
||||
export const requireCanViewExperiment = async (experimentId: string, ctx: TRPCContext) => {
|
||||
await prisma.experiment.findFirst({
|
||||
where: { id: experimentId },
|
||||
});
|
||||
|
||||
// Right now all experiments are publicly viewable, so this is a no-op.
|
||||
ctx.markAccessControlRun();
|
||||
};
|
||||
|
||||
export const canModifyExperiment = async (experimentId: string, userId: string) => {
|
||||
const experiment = await prisma.experiment.findFirst({
|
||||
where: {
|
||||
id: experimentId,
|
||||
organization: {
|
||||
OrganizationUser: {
|
||||
some: {
|
||||
role: { in: [OrganizationUserRole.ADMIN, OrganizationUserRole.MEMBER] },
|
||||
userId,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
return !!experiment;
|
||||
};
|
||||
|
||||
export const requireCanModifyExperiment = async (experimentId: string, ctx: TRPCContext) => {
|
||||
const userId = ctx.session?.user.id;
|
||||
if (!userId) {
|
||||
throw new TRPCError({ code: "UNAUTHORIZED" });
|
||||
}
|
||||
|
||||
if (!(await canModifyExperiment(experimentId, userId))) {
|
||||
throw new TRPCError({ code: "UNAUTHORIZED" });
|
||||
}
|
||||
|
||||
ctx.markAccessControlRun();
|
||||
};
|
||||
@@ -1,46 +0,0 @@
|
||||
import { type SupportedModel, OpenAIChatModel } from "~/server/types";
|
||||
|
||||
const openAIPromptTokensToDollars: { [key in OpenAIChatModel]: number } = {
|
||||
"gpt-4": 0.00003,
|
||||
"gpt-4-0613": 0.00003,
|
||||
"gpt-4-32k": 0.00006,
|
||||
"gpt-4-32k-0613": 0.00006,
|
||||
"gpt-3.5-turbo": 0.0000015,
|
||||
"gpt-3.5-turbo-0613": 0.0000015,
|
||||
"gpt-3.5-turbo-16k": 0.000003,
|
||||
"gpt-3.5-turbo-16k-0613": 0.000003,
|
||||
};
|
||||
|
||||
const openAICompletionTokensToDollars: { [key in OpenAIChatModel]: number } = {
|
||||
"gpt-4": 0.00006,
|
||||
"gpt-4-0613": 0.00006,
|
||||
"gpt-4-32k": 0.00012,
|
||||
"gpt-4-32k-0613": 0.00012,
|
||||
"gpt-3.5-turbo": 0.000002,
|
||||
"gpt-3.5-turbo-0613": 0.000002,
|
||||
"gpt-3.5-turbo-16k": 0.000004,
|
||||
"gpt-3.5-turbo-16k-0613": 0.000004,
|
||||
};
|
||||
|
||||
export const calculateTokenCost = (
|
||||
model: SupportedModel | null,
|
||||
numTokens: number,
|
||||
isCompletion = false,
|
||||
) => {
|
||||
if (!model) return 0;
|
||||
if (model in OpenAIChatModel) {
|
||||
return calculateOpenAIChatTokenCost(model as OpenAIChatModel, numTokens, isCompletion);
|
||||
}
|
||||
return 0;
|
||||
};
|
||||
|
||||
const calculateOpenAIChatTokenCost = (
|
||||
model: OpenAIChatModel,
|
||||
numTokens: number,
|
||||
isCompletion: boolean,
|
||||
) => {
|
||||
const tokensToDollars = isCompletion
|
||||
? openAICompletionTokensToDollars[model]
|
||||
: openAIPromptTokensToDollars[model];
|
||||
return tokensToDollars * numTokens;
|
||||
};
|
||||
@@ -5,16 +5,5 @@ import relativeTime from "dayjs/plugin/relativeTime";
|
||||
dayjs.extend(duration);
|
||||
dayjs.extend(relativeTime);
|
||||
|
||||
export const formatTimePast = (date: Date) => {
|
||||
const now = dayjs();
|
||||
const dayDiff = Math.floor(now.diff(date, "day"));
|
||||
if (dayDiff > 0) return dayjs.duration(-dayDiff, "days").humanize(true);
|
||||
|
||||
const hourDiff = Math.floor(now.diff(date, "hour"));
|
||||
if (hourDiff > 0) return dayjs.duration(-hourDiff, "hours").humanize(true);
|
||||
|
||||
const minuteDiff = Math.floor(now.diff(date, "minute"));
|
||||
if (minuteDiff > 0) return dayjs.duration(-minuteDiff, "minutes").humanize(true);
|
||||
|
||||
return "a few seconds ago";
|
||||
};
|
||||
export const formatTimePast = (date: Date) =>
|
||||
dayjs.duration(dayjs(date).diff(dayjs())).humanize(true);
|
||||
|
||||
10
src/utils/formatPromptConstructor.test.ts
Normal file
10
src/utils/formatPromptConstructor.test.ts
Normal file
@@ -0,0 +1,10 @@
|
||||
import { expect, test } from "vitest";
|
||||
import { stripTypes } from "./formatPromptConstructor";
|
||||
|
||||
test("stripTypes", () => {
|
||||
expect(stripTypes(`const foo: string = "bar";`)).toBe(`const foo = "bar";`);
|
||||
});
|
||||
|
||||
test("stripTypes with invalid syntax", () => {
|
||||
expect(stripTypes(`asdf foo: string = "bar"`)).toBe(`asdf foo: string = "bar"`);
|
||||
});
|
||||
31
src/utils/formatPromptConstructor.ts
Normal file
31
src/utils/formatPromptConstructor.ts
Normal file
@@ -0,0 +1,31 @@
|
||||
import prettier from "prettier/standalone";
|
||||
import parserTypescript from "prettier/plugins/typescript";
|
||||
|
||||
// @ts-expect-error for some reason missing from types
|
||||
import parserEstree from "prettier/plugins/estree";
|
||||
|
||||
import * as babel from "@babel/standalone";
|
||||
|
||||
export function stripTypes(tsCode: string): string {
|
||||
const options = {
|
||||
presets: ["typescript"],
|
||||
filename: "file.ts",
|
||||
};
|
||||
|
||||
try {
|
||||
const result = babel.transform(tsCode, options);
|
||||
return result.code ?? tsCode;
|
||||
} catch (error) {
|
||||
console.error("Error stripping types", error);
|
||||
return tsCode;
|
||||
}
|
||||
}
|
||||
|
||||
export default async function formatPromptConstructor(code: string): Promise<string> {
|
||||
return await prettier.format(stripTypes(code), {
|
||||
parser: "typescript",
|
||||
plugins: [parserTypescript, parserEstree],
|
||||
// We're showing these in pretty narrow panes so let's keep the print width low
|
||||
printWidth: 60,
|
||||
});
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
import { useRouter } from "next/router";
|
||||
import { useCallback, useEffect, useState } from "react";
|
||||
import { type RefObject, useCallback, useEffect, useRef, useState } from "react";
|
||||
import { api } from "~/utils/api";
|
||||
|
||||
export const useExperiment = () => {
|
||||
@@ -12,6 +12,10 @@ export const useExperiment = () => {
|
||||
return experiment;
|
||||
};
|
||||
|
||||
export const useExperimentAccess = () => {
|
||||
return useExperiment().data?.access ?? { canView: false, canModify: false };
|
||||
};
|
||||
|
||||
type AsyncFunction<T extends unknown[], U> = (...args: T) => Promise<U>;
|
||||
|
||||
export function useHandledAsyncCallback<T extends unknown[], U>(
|
||||
@@ -49,3 +53,43 @@ export const useModifierKeyLabel = () => {
|
||||
}, []);
|
||||
return label;
|
||||
};
|
||||
|
||||
interface Dimensions {
|
||||
left: number;
|
||||
top: number;
|
||||
right: number;
|
||||
bottom: number;
|
||||
width: number;
|
||||
height: number;
|
||||
x: number;
|
||||
y: number;
|
||||
}
|
||||
|
||||
// get dimensions of an element
|
||||
export const useElementDimensions = (): [RefObject<HTMLElement>, Dimensions | undefined] => {
|
||||
const ref = useRef<HTMLElement>(null);
|
||||
const [dimensions, setDimensions] = useState<Dimensions | undefined>();
|
||||
|
||||
useEffect(() => {
|
||||
if (ref.current) {
|
||||
const observer = new ResizeObserver((entries) => {
|
||||
entries.forEach((entry) => {
|
||||
setDimensions(entry.contentRect);
|
||||
});
|
||||
});
|
||||
|
||||
const observedRef = ref.current;
|
||||
|
||||
observer.observe(observedRef);
|
||||
|
||||
// Cleanup the observer on component unmount
|
||||
return () => {
|
||||
if (observedRef) {
|
||||
observer.unobserve(observedRef);
|
||||
}
|
||||
};
|
||||
}
|
||||
}, []);
|
||||
|
||||
return [ref, dimensions];
|
||||
};
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { extendTheme } from "@chakra-ui/react";
|
||||
import "@fontsource/inconsolata";
|
||||
|
||||
const systemFont =
|
||||
'ui-sans-serif, -apple-system, "system-ui", "Segoe UI", Helvetica, "Apple Color Emoji", Arial, sans-serif, "Segoe UI Emoji", "Segoe UI Symbol"';
|
||||
|
||||
@@ -5,7 +5,7 @@ import { env } from "~/env.mjs";
|
||||
|
||||
const url = env.NEXT_PUBLIC_SOCKET_URL;
|
||||
|
||||
export default function useSocket(channel?: string) {
|
||||
export default function useSocket(channel?: string | null) {
|
||||
const socketRef = useRef<Socket>();
|
||||
const [message, setMessage] = useState<ChatCompletion | null>(null);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user