Compare commits
34 Commits: autoformat...fix-pretti
| Author | SHA1 | Date |
|---|---|---|
|  | 999a4c08fa |  |
|  | 374d0237ee |  |
|  | b1f873623d |  |
|  | 4131aa67d0 |  |
|  | 8e7a6d3ae2 |  |
|  | 7d41e94ca2 |  |
|  | 011b12abb9 |  |
|  | 1ba18015bc |  |
|  | 54369dba54 |  |
|  | 6b84a59372 |  |
|  | 8db8aeacd3 |  |
|  | 64bd71e370 |  |
|  | ca21a7af06 |  |
|  | 3b99b7bd2b |  |
|  | 0c3bdbe4f2 |  |
|  | 74c201d3a8 |  |
|  | ab9c721d09 |  |
|  | 0a2578a1d8 |  |
|  | 1bebaff386 |  |
|  | 3bf5eaf4a2 |  |
|  | ded97f8bb9 |  |
|  | 26ee8698be |  |
|  | b98eb9b729 |  |
|  | 032c07ec65 |  |
|  | 80c0d13bb9 |  |
|  | f7c94be3f6 |  |
|  | c3e85607e0 |  |
|  | cd5927b8f5 |  |
|  | 731406d1f4 |  |
|  | 3c59e4b774 |  |
|  | 972b1f2333 |  |
|  | 7321f3deda |  |
|  | 2bd41fdfbf |  |
|  | a5378b106b |  |

53  .github/workflows/ci.yaml  (vendored, new file)
@@ -0,0 +1,53 @@
+name: CI checks
+
+on:
+  pull_request:
+    branches: [main]
+  push:
+    branches: [main]
+
+jobs:
+  run-checks:
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Check out code
+        uses: actions/checkout@v2
+
+      - name: Set up Node.js
+        uses: actions/setup-node@v2
+        with:
+          node-version: "20"
+
+      - uses: pnpm/action-setup@v2
+        name: Install pnpm
+        id: pnpm-install
+        with:
+          version: 8.6.1
+          run_install: false
+
+      - name: Get pnpm store directory
+        id: pnpm-cache
+        shell: bash
+        run: |
+          echo "STORE_PATH=$(pnpm store path)" >> $GITHUB_OUTPUT
+
+      - uses: actions/cache@v3
+        name: Setup pnpm cache
+        with:
+          path: ${{ steps.pnpm-cache.outputs.STORE_PATH }}
+          key: ${{ runner.os }}-pnpm-store-${{ hashFiles('**/pnpm-lock.yaml') }}
+          restore-keys: |
+            ${{ runner.os }}-pnpm-store-
+
+      - name: Install Dependencies
+        run: pnpm install
+
+      - name: Check types
+        run: pnpm tsc
+
+      - name: Lint
+        run: SKIP_ENV_VALIDATION=1 pnpm lint
+
+      - name: Check prettier
+        run: pnpm prettier . --check

1  .tool-versions  (new file)
@@ -0,0 +1 @@
+nodejs 20.2.0

19  package.json
@@ -3,17 +3,25 @@
   "type": "module",
   "version": "0.1.0",
   "license": "Apache-2.0",
+  "engines": {
+    "node": ">=20.0.0",
+    "pnpm": ">=8.6.1"
+  },
   "scripts": {
     "build": "next build",
     "dev:next": "next dev",
     "dev:wss": "pnpm tsx --watch src/wss-server.ts",
+    "dev:worker": "NODE_ENV='development' pnpm tsx --watch src/server/tasks/worker.ts",
     "dev": "concurrently --kill-others 'pnpm dev:next' 'pnpm dev:wss'",
     "postinstall": "prisma generate",
     "lint": "next lint",
     "start": "next start",
-    "codegen": "tsx src/codegen/export-openai-types.ts"
+    "codegen": "tsx src/codegen/export-openai-types.ts",
+    "seed": "tsx prisma/seed.ts"
   },
   "dependencies": {
+    "@babel/preset-typescript": "^7.22.5",
+    "@babel/standalone": "^7.22.9",
     "@chakra-ui/next-js": "^2.1.4",
     "@chakra-ui/react": "^2.7.1",
     "@emotion/react": "^11.11.1",
@@ -38,10 +46,11 @@
     "express": "^4.18.2",
     "framer-motion": "^10.12.17",
     "gpt-tokens": "^1.0.10",
+    "graphile-worker": "^0.13.0",
     "immer": "^10.0.2",
     "isolated-vm": "^4.5.0",
     "json-stringify-pretty-compact": "^4.0.0",
-    "lodash": "^4.17.21",
+    "lodash-es": "^4.17.21",
     "next": "^13.4.2",
     "next-auth": "^4.22.1",
     "nextjs-routes": "^2.0.1",
@@ -63,11 +72,13 @@
   },
   "devDependencies": {
     "@openapi-contrib/openapi-schema-to-json-schema": "^4.0.5",
+    "@types/babel__core": "^7.20.1",
+    "@types/babel__standalone": "^7.1.4",
     "@types/chroma-js": "^2.4.0",
     "@types/cors": "^2.8.13",
     "@types/eslint": "^8.37.0",
     "@types/express": "^4.17.17",
-    "@types/lodash": "^4.14.195",
+    "@types/lodash-es": "^4.17.8",
     "@types/node": "^18.16.0",
     "@types/pluralize": "^0.0.30",
     "@types/react": "^18.2.6",
@@ -90,6 +101,6 @@
     "initVersion": "7.14.0"
   },
   "prisma": {
-    "seed": "tsx prisma/seed.ts"
+    "seed": "pnpm seed"
   }
 }

912  pnpm-lock.yaml  (generated)
File diff suppressed because it is too large

@@ -0,0 +1,49 @@
+-- Drop the foreign key constraints on the original ModelOutput
+ALTER TABLE "ModelOutput" DROP CONSTRAINT "ModelOutput_promptVariantId_fkey";
+ALTER TABLE "ModelOutput" DROP CONSTRAINT "ModelOutput_testScenarioId_fkey";
+
+-- Rename the old table
+ALTER TABLE "ModelOutput" RENAME TO "ScenarioVariantCell";
+ALTER TABLE "ScenarioVariantCell" RENAME CONSTRAINT "ModelOutput_pkey" TO "ScenarioVariantCell_pkey";
+ALTER INDEX "ModelOutput_inputHash_idx" RENAME TO "ScenarioVariantCell_inputHash_idx";
+ALTER INDEX "ModelOutput_promptVariantId_testScenarioId_key" RENAME TO "ScenarioVariantCell_promptVariantId_testScenarioId_key";
+
+-- Add the new fields to the renamed table
+ALTER TABLE "ScenarioVariantCell" ADD COLUMN "retryTime" TIMESTAMP(3);
+ALTER TABLE "ScenarioVariantCell" ADD COLUMN "streamingChannel" TEXT;
+ALTER TABLE "ScenarioVariantCell" ALTER COLUMN "inputHash" DROP NOT NULL;
+ALTER TABLE "ScenarioVariantCell" ALTER COLUMN "output" DROP NOT NULL,
+  ALTER COLUMN "statusCode" DROP NOT NULL,
+  ALTER COLUMN "timeToComplete" DROP NOT NULL;
+
+-- Create the new table
+CREATE TABLE "ModelOutput" (
+  "id" UUID NOT NULL,
+  "inputHash" TEXT NOT NULL,
+  "output" JSONB NOT NULL,
+  "timeToComplete" INTEGER NOT NULL DEFAULT 0,
+  "promptTokens" INTEGER,
+  "completionTokens" INTEGER,
+  "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
+  "updatedAt" TIMESTAMP(3) NOT NULL,
+  "scenarioVariantCellId" UUID
+);
+
+-- Move inputHash index
+DROP INDEX "ScenarioVariantCell_inputHash_idx";
+CREATE INDEX "ModelOutput_inputHash_idx" ON "ModelOutput"("inputHash");
+
+CREATE UNIQUE INDEX "ModelOutput_scenarioVariantCellId_key" ON "ModelOutput"("scenarioVariantCellId");
+ALTER TABLE "ModelOutput" ADD CONSTRAINT "ModelOutput_scenarioVariantCellId_fkey" FOREIGN KEY ("scenarioVariantCellId") REFERENCES "ScenarioVariantCell"("id") ON DELETE CASCADE ON UPDATE CASCADE;
+
+ALTER TABLE "ModelOutput" ALTER COLUMN "scenarioVariantCellId" SET NOT NULL,
+  ADD CONSTRAINT "ModelOutput_pkey" PRIMARY KEY ("id");
+
+ALTER TABLE "ScenarioVariantCell" ADD CONSTRAINT "ScenarioVariantCell_promptVariantId_fkey" FOREIGN KEY ("promptVariantId") REFERENCES "PromptVariant"("id") ON DELETE CASCADE ON UPDATE CASCADE;
+ALTER TABLE "ScenarioVariantCell" ADD CONSTRAINT "ScenarioVariantCell_testScenarioId_fkey" FOREIGN KEY ("testScenarioId") REFERENCES "TestScenario"("id") ON DELETE CASCADE ON UPDATE CASCADE;
+
+-- CreateEnum
+CREATE TYPE "CellRetrievalStatus" AS ENUM ('PENDING', 'IN_PROGRESS', 'COMPLETE', 'ERROR');
+
+-- AlterTable
+ALTER TABLE "ScenarioVariantCell" ADD COLUMN "retrievalStatus" "CellRetrievalStatus" NOT NULL DEFAULT 'COMPLETE';

@@ -0,0 +1,2 @@
+-- AlterTable
+ALTER TABLE "PromptVariant" ADD COLUMN "model" TEXT NOT NULL DEFAULT 'gpt-3.5-turbo';

@@ -0,0 +1,2 @@
+-- AlterTable
+ALTER TABLE "PromptVariant" ALTER COLUMN "model" DROP DEFAULT;

24  prisma/migrations/20230717203031_add_gpt4_eval/migration.sql  (new file)
@@ -0,0 +1,24 @@
+/*
+  Warnings:
+
+  - You are about to rename the column `matchString` on the `Evaluation` table. If there is any code or views referring to the old name, they will break.
+  - You are about to rename the column `matchType` on the `Evaluation` table. If there is any code or views referring to the old name, they will break.
+  - You are about to rename the column `name` on the `Evaluation` table. If there is any code or views referring to the old name, they will break.
+  - You are about to rename the enum `EvaluationMatchType` to `EvalType`. If there is any code or views referring to the old name, they will break.
+*/
+
+-- RenameEnum
+ALTER TYPE "EvaluationMatchType" RENAME TO "EvalType";
+
+-- AlterTable
+ALTER TABLE "Evaluation" RENAME COLUMN "matchString" TO "value";
+ALTER TABLE "Evaluation" RENAME COLUMN "matchType" TO "evalType";
+ALTER TABLE "Evaluation" RENAME COLUMN "name" TO "label";
+
+-- AlterColumnType
+ALTER TABLE "Evaluation" ALTER COLUMN "evalType" TYPE "EvalType" USING "evalType"::text::"EvalType";
+
+-- SetNotNullConstraint
+ALTER TABLE "Evaluation" ALTER COLUMN "evalType" SET NOT NULL;
+ALTER TABLE "Evaluation" ALTER COLUMN "label" SET NOT NULL;
+ALTER TABLE "Evaluation" ALTER COLUMN "value" SET NOT NULL;

@@ -0,0 +1,39 @@
+/*
+  Warnings:
+
+  - You are about to drop the `EvaluationResult` table. If the table is not empty, all the data it contains will be lost.
+
+*/
+-- AlterEnum
+ALTER TYPE "EvalType" ADD VALUE 'GPT4_EVAL';
+
+-- DropForeignKey
+ALTER TABLE "EvaluationResult" DROP CONSTRAINT "EvaluationResult_evaluationId_fkey";
+
+-- DropForeignKey
+ALTER TABLE "EvaluationResult" DROP CONSTRAINT "EvaluationResult_promptVariantId_fkey";
+
+-- DropTable
+DROP TABLE "EvaluationResult";
+
+-- CreateTable
+CREATE TABLE "OutputEvaluation" (
+  "id" UUID NOT NULL,
+  "result" DOUBLE PRECISION NOT NULL,
+  "details" TEXT,
+  "modelOutputId" UUID NOT NULL,
+  "evaluationId" UUID NOT NULL,
+  "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
+  "updatedAt" TIMESTAMP(3) NOT NULL,
+
+  CONSTRAINT "OutputEvaluation_pkey" PRIMARY KEY ("id")
+);
+
+-- CreateIndex
+CREATE UNIQUE INDEX "OutputEvaluation_modelOutputId_evaluationId_key" ON "OutputEvaluation"("modelOutputId", "evaluationId");
+
+-- AddForeignKey
+ALTER TABLE "OutputEvaluation" ADD CONSTRAINT "OutputEvaluation_modelOutputId_fkey" FOREIGN KEY ("modelOutputId") REFERENCES "ModelOutput"("id") ON DELETE CASCADE ON UPDATE CASCADE;
+
+-- AddForeignKey
+ALTER TABLE "OutputEvaluation" ADD CONSTRAINT "OutputEvaluation_evaluationId_fkey" FOREIGN KEY ("evaluationId") REFERENCES "Evaluation"("id") ON DELETE CASCADE ON UPDATE CASCADE;

@@ -2,8 +2,7 @@
 // learn more about it in the docs: https://pris.ly/d/prisma-schema
 
 generator client {
   provider = "prisma-client-js"
-  previewFeatures = ["jsonProtocol"]
 }
 
 datasource db {
@@ -26,10 +25,11 @@ model Experiment {
 }
 
 model PromptVariant {
   id String @id @default(uuid()) @db.Uuid
-  label String
 
+  label String
   constructFn String
+  model String
 
   uiId String @default(uuid()) @db.Uuid
   visible Boolean @default(true)
@@ -38,10 +38,9 @@ model PromptVariant {
   experimentId String @db.Uuid
   experiment Experiment @relation(fields: [experimentId], references: [id], onDelete: Cascade)
 
   createdAt DateTime @default(now())
   updatedAt DateTime @updatedAt
-  ModelOutput ModelOutput[]
-  EvaluationResult EvaluationResult[]
+  scenarioVariantCells ScenarioVariantCell[]
 
   @@index([uiId])
 }
@@ -58,9 +57,9 @@ model TestScenario {
   experimentId String @db.Uuid
   experiment Experiment @relation(fields: [experimentId], references: [id], onDelete: Cascade)
 
   createdAt DateTime @default(now())
   updatedAt DateTime @updatedAt
-  ModelOutput ModelOutput[]
+  scenarioVariantCells ScenarioVariantCell[]
 }
 
 model TemplateVariable {
@@ -75,17 +74,28 @@ model TemplateVariable {
   updatedAt DateTime @updatedAt
 }
 
-model ModelOutput {
+enum CellRetrievalStatus {
+  PENDING
+  IN_PROGRESS
+  COMPLETE
+  ERROR
+}
+
+model ScenarioVariantCell {
   id String @id @default(uuid()) @db.Uuid
 
-  inputHash String
-  output Json
-  statusCode Int
+  inputHash String? // TODO: Remove once migration is complete
+  output Json? // TODO: Remove once migration is complete
+  statusCode Int?
   errorMessage String?
-  timeToComplete Int @default(0)
+  timeToComplete Int? @default(0) // TODO: Remove once migration is complete
+  retryTime DateTime?
+  streamingChannel String?
+  retrievalStatus CellRetrievalStatus @default(COMPLETE)
 
-  promptTokens Int? // Added promptTokens field
-  completionTokens Int? // Added completionTokens field
+  promptTokens Int? // TODO: Remove once migration is complete
+  completionTokens Int? // TODO: Remove once migration is complete
+  modelOutput ModelOutput?
 
   promptVariantId String @db.Uuid
   promptVariant PromptVariant @relation(fields: [promptVariantId], references: [id], onDelete: Cascade)
@@ -97,45 +107,66 @@ model ModelOutput {
   updatedAt DateTime @updatedAt
 
   @@unique([promptVariantId, testScenarioId])
+}
+
+model ModelOutput {
+  id String @id @default(uuid()) @db.Uuid
+
+  inputHash String
+  output Json
+  timeToComplete Int @default(0)
+  promptTokens Int?
+  completionTokens Int?
+
+  createdAt DateTime @default(now())
+  updatedAt DateTime @updatedAt
+
+  scenarioVariantCellId String @db.Uuid
+  scenarioVariantCell ScenarioVariantCell @relation(fields: [scenarioVariantCellId], references: [id], onDelete: Cascade)
+  outputEvaluation OutputEvaluation[]
+
+  @@unique([scenarioVariantCellId])
   @@index([inputHash])
 }
 
-enum EvaluationMatchType {
+enum EvalType {
   CONTAINS
   DOES_NOT_CONTAIN
+  GPT4_EVAL
 }
 
 model Evaluation {
   id String @id @default(uuid()) @db.Uuid
 
-  name String
-  matchString String
-  matchType EvaluationMatchType
+  label String
+  evalType EvalType
+  value String
 
   experimentId String @db.Uuid
   experiment Experiment @relation(fields: [experimentId], references: [id], onDelete: Cascade)
 
   createdAt DateTime @default(now())
   updatedAt DateTime @updatedAt
-  EvaluationResult EvaluationResult[]
+  OutputEvaluation OutputEvaluation[]
 }
 
-model EvaluationResult {
+model OutputEvaluation {
   id String @id @default(uuid()) @db.Uuid
 
-  passCount Int
-  failCount Int
+  // Number between 0 (fail) and 1 (pass)
+  result Float
+  details String?
+
+  modelOutputId String @db.Uuid
+  modelOutput ModelOutput @relation(fields: [modelOutputId], references: [id], onDelete: Cascade)
 
   evaluationId String @db.Uuid
   evaluation Evaluation @relation(fields: [evaluationId], references: [id], onDelete: Cascade)
 
-  promptVariantId String @db.Uuid
-  promptVariant PromptVariant @relation(fields: [promptVariantId], references: [id], onDelete: Cascade)
-
   createdAt DateTime @default(now())
   updatedAt DateTime @updatedAt
 
-  @@unique([evaluationId, promptVariantId])
+  @@unique([modelOutputId, evaluationId])
 }
 
 // Necessary for Next auth

@@ -1,4 +1,6 @@
 import { prisma } from "~/server/db";
+import dedent from "dedent";
+import { generateNewCell } from "~/server/utils/generateNewCell";
 
 const experimentId = "11111111-1111-1111-1111-111111111111";
 
@@ -9,14 +11,14 @@ await prisma.experiment.deleteMany({
   },
 });
 
-const experiment = await prisma.experiment.create({
+await prisma.experiment.create({
   data: {
     id: experimentId,
     label: "Country Capitals Example",
   },
 });
 
-await prisma.modelOutput.deleteMany({
+await prisma.scenarioVariantCell.deleteMany({
   where: {
     promptVariant: {
       experimentId,
@@ -36,27 +38,35 @@ await prisma.promptVariant.createMany({
       experimentId,
       label: "Prompt Variant 1",
       sortIndex: 0,
-      constructFn: `prompt = {
-  model: "gpt-3.5-turbo-0613",
-  messages: [{ role: "user", content: "What is the capital of {{country}}?" }],
-  temperature: 0,
-}`,
+      model: "gpt-3.5-turbo-0613",
+      constructFn: dedent`
+        prompt = {
+          model: "gpt-3.5-turbo-0613",
+          messages: [
+            {
+              role: "user",
+              content: \`What is the capital of \${scenario.country}?\`
+            }
+          ],
+          temperature: 0,
+        }`,
     },
     {
       experimentId,
       label: "Prompt Variant 2",
       sortIndex: 1,
-      constructFn: `prompt = {
-  model: "gpt-3.5-turbo-0613",
-  messages: [
-    {
-      role: "user",
-      content:
-        "What is the capital of {{country}}? Return just the city name and nothing else.",
-    },
-  ],
-  temperature: 0,
-}`,
+      model: "gpt-3.5-turbo-0613",
+      constructFn: dedent`
+        prompt = {
+          model: "gpt-3.5-turbo-0613",
+          messages: [
+            {
+              role: "user",
+              content: \`What is the capital of \${scenario.country}? Return just the city name and nothing else.\`
+            }
+          ],
+          temperature: 0,
+        }`,
     },
   ],
 });
@@ -107,3 +117,26 @@ await prisma.testScenario.createMany({
     },
   ],
 });
+
+const variants = await prisma.promptVariant.findMany({
+  where: {
+    experimentId,
+  },
+});
+
+const scenarios = await prisma.testScenario.findMany({
+  where: {
+    experimentId,
+  },
+});
+
+await Promise.all(
+  variants
+    .flatMap((variant) =>
+      scenarios.map((scenario) => ({
+        promptVariantId: variant.id,
+        testScenarioId: scenario.id,
+      })),
+    )
+    .map((cell) => generateNewCell(cell.promptVariantId, cell.testScenarioId)),
+);

@@ -12,6 +12,7 @@ await prisma.promptVariant.createMany({
     {
       experimentId: functionCallsExperiment.id,
       label: "No Fn Calls",
+      model: "gpt-3.5-turbo-0613",
       constructFn: `prompt = {
   model: "gpt-3.5-turbo-0613",
   messages: [
@@ -30,6 +31,7 @@ await prisma.promptVariant.createMany({
     {
       experimentId: functionCallsExperiment.id,
       label: "Fn Calls",
+      model: "gpt-3.5-turbo-0613",
       constructFn: `prompt = {
   model: "gpt-3.5-turbo-0613",
   messages: [
@@ -92,6 +94,7 @@ await prisma.promptVariant.createMany({
       experimentId: redditExperiment.id,
       label: "3.5 Base",
       sortIndex: 0,
+      model: "gpt-3.5-turbo-0613",
       constructFn: `prompt = {
   model: "gpt-3.5-turbo-0613",
   messages: [
@@ -107,6 +110,7 @@ await prisma.promptVariant.createMany({
       experimentId: redditExperiment.id,
       label: "4 Base",
       sortIndex: 1,
+      model: "gpt-3.5-turbo-0613",
       constructFn: `prompt = {
   model: "gpt-4-0613",
   messages: [
@@ -122,6 +126,7 @@ await prisma.promptVariant.createMany({
       experimentId: redditExperiment.id,
       label: "3.5 CoT + Functions",
       sortIndex: 2,
+      model: "gpt-3.5-turbo-0613",
       constructFn: `prompt = {
   model: "gpt-3.5-turbo-0613",
   messages: [
@@ -178,9 +183,9 @@ await prisma.templateVariable.createMany({
 await prisma.evaluation.create({
   data: {
     experimentId: redditExperiment.id,
-    name: "Relevance Accuracy",
-    matchType: "CONTAINS",
-    matchString: '"{{relevance}}"',
+    label: "Relevance Accuracy",
+    evalType: "CONTAINS",
+    value: '"{{relevance}}"',
   },
 });
 
@@ -1119,12 +1124,3 @@ await prisma.testScenario.createMany({
     variableValues: vars,
   })),
 });
-
-// await prisma.evaluation.create({
-//   data: {
-//     experimentId: redditExperiment.id,
-//     name: "Scores Match",
-//     matchType: "CONTAINS",
-//     matchString: "{{score}}",
-//   },
-// });

@@ -2,7 +2,7 @@ import fs from "fs";
 import path from "path";
 import openapiTS, { type OpenAPI3 } from "openapi-typescript";
 import YAML from "yaml";
-import _ from "lodash";
+import { pick } from "lodash-es";
 import assert from "assert";
 
 const OPENAPI_URL =
@@ -31,7 +31,7 @@ modelProperty.oneOf = undefined;
 
 delete schema["paths"];
 assert(schema.components?.schemas);
-schema.components.schemas = _.pick(schema.components?.schemas, [
+schema.components.schemas = pick(schema.components?.schemas, [
   "CreateChatCompletionRequest",
   "ChatCompletionRequestMessage",
   "ChatCompletionFunctions",

@@ -4,7 +4,7 @@ import React from "react";
 
 export const AutoResizeTextarea: React.ForwardRefRenderFunction<
   HTMLTextAreaElement,
-  TextareaProps
+  TextareaProps & { minRows?: number }
 > = (props, ref) => {
   return (
     <Textarea

@@ -11,14 +11,16 @@ import {
   FormLabel,
   Select,
   FormHelperText,
+  Code,
 } from "@chakra-ui/react";
-import { type Evaluation, EvaluationMatchType } from "@prisma/client";
+import { type Evaluation, EvalType } from "@prisma/client";
 import { useCallback, useState } from "react";
 import { BsPencil, BsX } from "react-icons/bs";
 import { api } from "~/utils/api";
 import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
+import AutoResizeTextArea from "../AutoResizeTextArea";
 
-type EvalValues = Pick<Evaluation, "name" | "matchString" | "matchType">;
+type EvalValues = Pick<Evaluation, "label" | "value" | "evalType">;
 
 export function EvaluationEditor(props: {
   evaluation: Evaluation | null;
@@ -27,35 +29,35 @@ export function EvaluationEditor(props: {
   onCancel: () => void;
 }) {
   const [values, setValues] = useState<EvalValues>({
-    name: props.evaluation?.name ?? props.defaultName ?? "",
-    matchString: props.evaluation?.matchString ?? "",
-    matchType: props.evaluation?.matchType ?? "CONTAINS",
+    label: props.evaluation?.label ?? props.defaultName ?? "",
+    value: props.evaluation?.value ?? "",
+    evalType: props.evaluation?.evalType ?? "CONTAINS",
   });
 
   return (
     <VStack borderTopWidth={1} borderColor="gray.200" py={4}>
       <HStack w="100%">
         <FormControl flex={1}>
-          <FormLabel fontSize="sm">Evaluation Name</FormLabel>
+          <FormLabel fontSize="sm">Eval Name</FormLabel>
           <Input
             size="sm"
-            value={values.name}
-            onChange={(e) => setValues((values) => ({ ...values, name: e.target.value }))}
+            value={values.label}
+            onChange={(e) => setValues((values) => ({ ...values, label: e.target.value }))}
           />
         </FormControl>
         <FormControl flex={1}>
-          <FormLabel fontSize="sm">Match Type</FormLabel>
+          <FormLabel fontSize="sm">Eval Type</FormLabel>
           <Select
             size="sm"
-            value={values.matchType}
+            value={values.evalType}
             onChange={(e) =>
               setValues((values) => ({
                 ...values,
-                matchType: e.target.value as EvaluationMatchType,
+                evalType: e.target.value as EvalType,
              }))
            }
          >
-            {Object.values(EvaluationMatchType).map((type) => (
+            {Object.values(EvalType).map((type) => (
              <option key={type} value={type}>
                {type}
              </option>
@@ -63,17 +65,37 @@ export function EvaluationEditor(props: {
           </Select>
         </FormControl>
       </HStack>
-      <FormControl>
-        <FormLabel fontSize="sm">Match String</FormLabel>
-        <Input
-          size="sm"
-          value={values.matchString}
-          onChange={(e) => setValues((values) => ({ ...values, matchString: e.target.value }))}
-        />
-        <FormHelperText>
-          This string will be interpreted as a regex and checked against each model output.
-        </FormHelperText>
-      </FormControl>
+      {["CONTAINS", "DOES_NOT_CONTAIN"].includes(values.evalType) && (
+        <FormControl>
+          <FormLabel fontSize="sm">Match String</FormLabel>
+          <Input
+            size="sm"
+            value={values.value}
+            onChange={(e) => setValues((values) => ({ ...values, value: e.target.value }))}
+          />
+          <FormHelperText>
+            This string will be interpreted as a regex and checked against each model output. You
+            can include scenario variables using <Code>{"{{curly_braces}}"}</Code>
+          </FormHelperText>
+        </FormControl>
+      )}
+      {values.evalType === "GPT4_EVAL" && (
+        <FormControl pt={2}>
+          <FormLabel fontSize="sm">GPT4 Instructions</FormLabel>
+          <AutoResizeTextArea
+            size="sm"
+            value={values.value}
+            onChange={(e) => setValues((values) => ({ ...values, value: e.target.value }))}
+            minRows={3}
+          />
+          <FormHelperText>
+            Give instructions to GPT-4 for how to evaluate your prompt. It will have access to the
+            full scenario as well as the output it is evaluating. It will <strong>not</strong> have
+            access to the specific prompt variant, so be sure to be clear about the task you want it
+            to perform.
+          </FormHelperText>
+        </FormControl>
+      )}
       <HStack alignSelf="flex-end">
         <Button size="sm" onClick={props.onCancel} colorScheme="gray">
           Cancel
@@ -125,6 +147,7 @@ export default function EditEvaluations() {
     }
     await utils.evaluations.list.invalidate();
     await utils.promptVariants.stats.invalidate();
+    await utils.scenarioVariantCells.get.invalidate();
   }, []);
 
   const onCancel = useCallback(() => {
@@ -156,9 +179,9 @@ export default function EditEvaluations() {
           align="center"
           key={evaluation.id}
         >
-          <Text fontWeight="bold">{evaluation.name}</Text>
+          <Text fontWeight="bold">{evaluation.label}</Text>
           <Text flex={1}>
-            {evaluation.matchType}: "{evaluation.matchString}"
+            {evaluation.evalType}: "{evaluation.value}"
           </Text>
           <Button
             variant="unstyled"

@@ -1,4 +1,4 @@
-import { Text, Button, HStack, Heading, Icon, Input, Stack, Code } from "@chakra-ui/react";
+import { Text, Button, HStack, Heading, Icon, Input, Stack } from "@chakra-ui/react";
 import { useState } from "react";
 import { BsCheck, BsX } from "react-icons/bs";
 import { api } from "~/utils/api";
@@ -36,8 +36,7 @@ export default function EditScenarioVars() {
       <Heading size="sm">Scenario Variables</Heading>
       <Stack spacing={2}>
         <Text fontSize="sm">
-          Scenario variables can be used in your prompt variants as well as evaluations. Reference
-          them using <Code>{"{{curly_braces}}"}</Code>.
+          Scenario variables can be used in your prompt variants as well as evaluations.
         </Text>
         <HStack spacing={0}>
           <Input

33  src/components/OutputsTable/OutputCell/CellOptions.tsx  (new file)
@@ -0,0 +1,33 @@
+import { Button, HStack, Icon } from "@chakra-ui/react";
+import { BsArrowClockwise } from "react-icons/bs";
+
+export const CellOptions = ({
+  refetchingOutput,
+  refetchOutput,
+}: {
+  refetchingOutput: boolean;
+  refetchOutput: () => void;
+}) => {
+  return (
+    <HStack justifyContent="flex-end" w="full">
+      {!refetchingOutput && (
+        <Button
+          size="xs"
+          w={4}
+          h={4}
+          py={4}
+          px={4}
+          minW={0}
+          borderRadius={8}
+          color="gray.500"
+          variant="ghost"
+          cursor="pointer"
+          onClick={refetchOutput}
+          aria-label="refetch output"
+        >
+          <Icon as={BsArrowClockwise} boxSize={4} />
+        </Button>
+      )}
+    </HStack>
+  );
+};

@@ -1,29 +1,21 @@
-import { type ModelOutput } from "@prisma/client";
-import { HStack, VStack, Text, Button, Icon } from "@chakra-ui/react";
+import { type ScenarioVariantCell } from "@prisma/client";
+import { VStack, Text } from "@chakra-ui/react";
 import { useEffect, useState } from "react";
-import { BsArrowClockwise } from "react-icons/bs";
-import { rateLimitErrorMessage } from "~/sharedStrings";
 import pluralize from "pluralize";
 
-const MAX_AUTO_RETRIES = 3;
-
 export const ErrorHandler = ({
-  output,
+  cell,
   refetchOutput,
-  numPreviousTries,
 }: {
-  output: ModelOutput;
+  cell: ScenarioVariantCell;
   refetchOutput: () => void;
-  numPreviousTries: number;
 }) => {
   const [msToWait, setMsToWait] = useState(0);
-  const shouldAutoRetry =
-    output.errorMessage === rateLimitErrorMessage && numPreviousTries < MAX_AUTO_RETRIES;
 
   useEffect(() => {
-    if (!shouldAutoRetry) return;
+    if (!cell.retryTime) return;
 
-    const initialWaitTime = calculateDelay(numPreviousTries);
+    const initialWaitTime = cell.retryTime.getTime() - Date.now();
     const msModuloOneSecond = initialWaitTime % 1000;
     let remainingTime = initialWaitTime - msModuloOneSecond;
     setMsToWait(remainingTime);
@@ -35,7 +27,6 @@ export const ErrorHandler = ({
         setMsToWait(remainingTime);
 
         if (remainingTime <= 0) {
-          refetchOutput();
           clearInterval(interval);
         }
       }, 1000);
@@ -45,32 +36,12 @@ export const ErrorHandler = ({
       clearInterval(interval);
       clearTimeout(timeout);
     };
-  }, [shouldAutoRetry, setMsToWait, refetchOutput, numPreviousTries]);
+  }, [cell.retryTime, cell.statusCode, setMsToWait, refetchOutput]);
 
   return (
     <VStack w="full">
-      <HStack w="full" alignItems="flex-start" justifyContent="space-between">
-        <Text color="red.600" fontWeight="bold">
-          Error
-        </Text>
-        <Button
-          size="xs"
-          w={4}
-          h={4}
-          px={4}
-          py={4}
-          minW={0}
-          borderRadius={8}
-          variant="ghost"
-          cursor="pointer"
-          onClick={refetchOutput}
-          aria-label="refetch output"
-        >
-          <Icon as={BsArrowClockwise} boxSize={6} />
-        </Button>
-      </HStack>
       <Text color="red.600" wordBreak="break-word">
-        {output.errorMessage}
+        {cell.errorMessage}
       </Text>
       {msToWait > 0 && (
         <Text color="red.600" fontSize="sm">
@@ -80,12 +51,3 @@ export const ErrorHandler = ({
     </VStack>
   );
 };
-
-const MIN_DELAY = 500; // milliseconds
-const MAX_DELAY = 5000; // milliseconds
-
-function calculateDelay(numPreviousTries: number): number {
-  const baseDelay = Math.min(MAX_DELAY, MIN_DELAY * Math.pow(2, numPreviousTries));
-  const jitter = Math.random() * baseDelay;
-  return baseDelay + jitter;
-}

@@ -1,17 +1,16 @@
-import { type RouterOutputs, api } from "~/utils/api";
+import { api } from "~/utils/api";
 import { type PromptVariant, type Scenario } from "../types";
-import { Spinner, Text, Box, Center, Flex } from "@chakra-ui/react";
+import { Spinner, Text, Center, VStack } from "@chakra-ui/react";
 import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
 import SyntaxHighlighter from "react-syntax-highlighter";
 import { docco } from "react-syntax-highlighter/dist/cjs/styles/hljs";
 import stringify from "json-stringify-pretty-compact";
-import { type ReactElement, useState, useEffect, useRef, useCallback } from "react";
+import { type ReactElement, useState, useEffect } from "react";
 import { type ChatCompletion } from "openai/resources/chat";
-import { generateChannel } from "~/utils/generateChannel";
-import { isObject } from "lodash";
 import useSocket from "~/utils/useSocket";
 import { OutputStats } from "./OutputStats";
 import { ErrorHandler } from "./ErrorHandler";
+import { CellOptions } from "./CellOptions";
 
 export default function OutputCell({
   scenario,
@@ -37,120 +36,116 @@ export default function OutputCell({
   // if (variant.config === null || Object.keys(variant.config).length === 0)
   //   disabledReason = "Save your prompt variant to see output";
 
-  // const model = getModelName(variant.config as JSONSerializable);
-  // TODO: Temporarily hardcoding this while we get other stuff working
-  const model = "gpt-3.5-turbo";
-
-  const outputMutation = api.outputs.get.useMutation();
-
-  const [output, setOutput] = useState<RouterOutputs["outputs"]["get"]>(null);
-  const [channel, setChannel] = useState<string | undefined>(undefined);
-  const [numPreviousTries, setNumPreviousTries] = useState(0);
-
-  const fetchMutex = useRef(false);
-  const [fetchOutput, fetchingOutput] = useHandledAsyncCallback(
-    async (forceRefetch?: boolean) => {
-      if (fetchMutex.current) return;
-      setNumPreviousTries((prev) => prev + 1);
-
-      fetchMutex.current = true;
-      setOutput(null);
-
-      const shouldStream =
-        isObject(variant) &&
-        "config" in variant &&
-        isObject(variant.config) &&
-        "stream" in variant.config &&
-        variant.config.stream === true;
-
-      const channel = shouldStream ? generateChannel() : undefined;
-      setChannel(channel);
-
-      const output = await outputMutation.mutateAsync({
-        scenarioId: scenario.id,
-        variantId: variant.id,
-        channel,
-        forceRefetch,
-      });
-      setOutput(output);
-      await utils.promptVariants.stats.invalidate();
-      fetchMutex.current = false;
-    },
-    [outputMutation, scenario.id, variant.id],
+  const [refetchInterval, setRefetchInterval] = useState(0);
+  const { data: cell, isLoading: queryLoading } = api.scenarioVariantCells.get.useQuery(
+    { scenarioId: scenario.id, variantId: variant.id },
+    { refetchInterval },
   );
-  const hardRefetch = useCallback(() => fetchOutput(true), [fetchOutput]);
 
-  useEffect(fetchOutput, [scenario.id, variant.id]);
+  const { mutateAsync: hardRefetchMutate, isLoading: refetchingOutput } =
+    api.scenarioVariantCells.forceRefetch.useMutation();
+  const [hardRefetch] = useHandledAsyncCallback(async () => {
+    await hardRefetchMutate({ scenarioId: scenario.id, variantId: variant.id });
+    await utils.scenarioVariantCells.get.invalidate({
+      scenarioId: scenario.id,
+      variantId: variant.id,
+    });
+    await utils.promptVariants.stats.invalidate({
+      variantId: variant.id,
+    });
+  }, [hardRefetchMutate, scenario.id, variant.id]);
+
+  const fetchingOutput = queryLoading || refetchingOutput;
+
+  const awaitingOutput =
+    !cell ||
+    cell.retrievalStatus === "PENDING" ||
+    cell.retrievalStatus === "IN_PROGRESS" ||
+    refetchingOutput;
+  useEffect(() => setRefetchInterval(awaitingOutput ? 1000 : 0), [awaitingOutput]);
+
+  const modelOutput = cell?.modelOutput;
 
   // Disconnect from socket if we're not streaming anymore
-  const streamedMessage = useSocket(fetchingOutput ? channel : undefined);
+  const streamedMessage = useSocket(cell?.streamingChannel);
   const streamedContent = streamedMessage?.choices?.[0]?.message?.content;
 
   if (!vars) return null;
 
   if (disabledReason) return <Text color="gray.500">{disabledReason}</Text>;
 
-  if (fetchingOutput && !streamedMessage)
+  if (awaitingOutput && !streamedMessage)
     return (
       <Center h="100%" w="100%">
         <Spinner />
       </Center>
     );
 
-  if (!output && !fetchingOutput) return <Text color="gray.500">Error retrieving output</Text>;
+  if (!cell && !fetchingOutput) return <Text color="gray.500">Error retrieving output</Text>;
 
-  if (output && output.errorMessage) {
-    return (
-      <ErrorHandler
-        output={output}
-        refetchOutput={hardRefetch}
-        numPreviousTries={numPreviousTries}
-      />
-    );
+  if (cell && cell.errorMessage) {
+    return <ErrorHandler cell={cell} refetchOutput={hardRefetch} />;
   }
 
-  const response = output?.output as unknown as ChatCompletion;
+  const response = modelOutput?.output as unknown as ChatCompletion;
   const message = response?.choices?.[0]?.message;
 
-  if (output && message?.function_call) {
+  if (modelOutput && message?.function_call) {
     const rawArgs = message.function_call.arguments ?? "null";
     let parsedArgs: string;
     try {
       parsedArgs = JSON.parse(rawArgs);
+      // eslint-disable-next-line @typescript-eslint/no-explicit-any
     } catch (e: any) {
       parsedArgs = `Failed to parse arguments as JSON: '${rawArgs}' ERROR: ${e.message as string}`;
     }
 
     return (
-      <Box fontSize="xs" width="100%" flexWrap="wrap" overflowX="auto">
-        <SyntaxHighlighter
-          customStyle={{ overflowX: "unset" }}
-          language="json"
-          style={docco}
-          lineProps={{
-            style: { wordBreak: "break-all", whiteSpace: "pre-wrap" },
-          }}
-          wrapLines
-        >
-          {stringify(
-            {
-              function: message.function_call.name,
-              args: parsedArgs,
-            },
-            { maxLength: 40 },
-          )}
-        </SyntaxHighlighter>
-        <OutputStats model={model} modelOutput={output} scenario={scenario} />
-      </Box>
+      <VStack
+        w="100%"
+        h="100%"
+        fontSize="xs"
+        flexWrap="wrap"
+        overflowX="auto"
+        justifyContent="space-between"
+      >
+        <VStack w="full" flex={1} spacing={0}>
+          <CellOptions refetchingOutput={refetchingOutput} refetchOutput={hardRefetch} />
+          <SyntaxHighlighter
+            customStyle={{ overflowX: "unset", width: "100%", flex: 1 }}
+            language="json"
+            style={docco}
+            lineProps={{
+              style: { wordBreak: "break-all", whiteSpace: "pre-wrap" },
+            }}
+            wrapLines
+          >
+            {stringify(
+              {
+                function: message.function_call.name,
+                args: parsedArgs,
+              },
+              { maxLength: 40 },
+            )}
+          </SyntaxHighlighter>
+        </VStack>
+        <OutputStats model={variant.model} modelOutput={modelOutput} scenario={scenario} />
+      </VStack>
     );
   }
 
-  const contentToDisplay = message?.content ?? streamedContent ?? JSON.stringify(output?.output);
+  const contentToDisplay =
+    message?.content ?? streamedContent ?? JSON.stringify(modelOutput?.output);
 
   return (
-    <Flex w="100%" h="100%" direction="column" justifyContent="space-between" whiteSpace="pre-wrap">
-      {contentToDisplay}
-      {output && <OutputStats model={model} modelOutput={output} scenario={scenario} />}
-    </Flex>
+    <VStack w="100%" h="100%" justifyContent="space-between" whiteSpace="pre-wrap">
+      <VStack w="full" alignItems="flex-start" spacing={0}>
+        <CellOptions refetchingOutput={refetchingOutput} refetchOutput={hardRefetch} />
+        <Text>{contentToDisplay}</Text>
+      </VStack>
+      {modelOutput && (
+        <OutputStats model={variant.model} modelOutput={modelOutput} scenario={scenario} />
+      )}
+    </VStack>
   );
 }

@@ -1,30 +1,25 @@
-import { type ModelOutput } from "@prisma/client";
 import { type SupportedModel } from "~/server/types";
 import { type Scenario } from "../types";
-import { useExperiment } from "~/utils/hooks";
-import { api } from "~/utils/api";
+import { type RouterOutputs } from "~/utils/api";
 import { calculateTokenCost } from "~/utils/calculateTokenCost";
-import { evaluateOutput } from "~/server/utils/evaluateOutput";
-import { HStack, Icon, Text } from "@chakra-ui/react";
+import { HStack, Icon, Text, Tooltip } from "@chakra-ui/react";
 import { BsCheck, BsClock, BsCurrencyDollar, BsX } from "react-icons/bs";
 import { CostTooltip } from "~/components/tooltip/CostTooltip";
 
-const SHOW_COST = false;
-const SHOW_TIME = false;
+const SHOW_COST = true;
+const SHOW_TIME = true;
 
 export const OutputStats = ({
   model,
   modelOutput,
-  scenario,
 }: {
-  model: SupportedModel | null;
-  modelOutput: ModelOutput;
+  model: SupportedModel | string | null;
+  modelOutput: NonNullable<
+    NonNullable<RouterOutputs["scenarioVariantCells"]["get"]>["modelOutput"]
+  >;
   scenario: Scenario;
 }) => {
   const timeToComplete = modelOutput.timeToComplete;
-  const experiment = useExperiment();
-  const evals =
-    api.evaluations.list.useQuery({ experimentId: experiment.data?.id ?? "" }).data ?? [];
 
   const promptTokens = modelOutput.promptTokens;
   const completionTokens = modelOutput.completionTokens;
@@ -35,22 +30,26 @@ export const OutputStats = ({
 
   const cost = promptCost + completionCost;
 
-  if (!evals.length) return null;
-
   return (
-    <HStack align="center" color="gray.500" fontSize="2xs" mt={{ base: 0, md: 1 }}>
+    <HStack w="full" align="center" color="gray.500" fontSize="2xs" mt={{ base: 0, md: 1 }}>
       <HStack flex={1}>
-        {evals.map((evaluation) => {
-          const passed = evaluateOutput(modelOutput, scenario, evaluation);
+        {modelOutput.outputEvaluation.map((evaluation) => {
+          const passed = evaluation.result > 0.5;
           return (
-            <HStack spacing={0} key={evaluation.id}>
-              <Text>{evaluation.name}</Text>
-              <Icon
-                as={passed ? BsCheck : BsX}
-                color={passed ? "green.500" : "red.500"}
-                boxSize={6}
-              />
-            </HStack>
+            <Tooltip
+              isDisabled={!evaluation.details}
+              label={evaluation.details}
+              key={evaluation.id}
+            >
+              <HStack spacing={0}>
+                <Text>{evaluation.evaluation.label}</Text>
+                <Icon
+                  as={passed ? BsCheck : BsX}
+                  color={passed ? "green.500" : "red.500"}
+                  boxSize={6}
+                />
+              </HStack>
+            </Tooltip>
           );
         })}
       </HStack>
@@ -1,6 +1,6 @@
|
|||||||
import { type DragEvent } from "react";
|
import { type DragEvent } from "react";
|
||||||
import { api } from "~/utils/api";
|
import { api } from "~/utils/api";
|
||||||
import { isEqual } from "lodash";
|
import { isEqual } from "lodash-es";
|
||||||
import { type Scenario } from "./types";
|
import { type Scenario } from "./types";
|
||||||
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
|
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
|
||||||
import { useState } from "react";
|
import { useState } from "react";
|
||||||
@@ -13,10 +13,11 @@ import AutoResizeTextArea from "../AutoResizeTextArea";
|
|||||||
|
|
||||||
export default function ScenarioEditor({
|
export default function ScenarioEditor({
|
||||||
scenario,
|
scenario,
|
||||||
hovered,
|
...props
|
||||||
}: {
|
}: {
|
||||||
scenario: Scenario;
|
scenario: Scenario;
|
||||||
hovered: boolean;
|
hovered: boolean;
|
||||||
|
canHide: boolean;
|
||||||
}) {
|
}) {
|
||||||
const savedValues = scenario.variableValues as Record<string, string>;
|
const savedValues = scenario.variableValues as Record<string, string>;
|
||||||
const utils = api.useContext();
|
const utils = api.useContext();
|
||||||
@@ -92,30 +93,34 @@ export default function ScenarioEditor({
|
|||||||
onDrop={onReorder}
|
onDrop={onReorder}
|
||||||
backgroundColor={isDragTarget ? "gray.100" : "transparent"}
|
backgroundColor={isDragTarget ? "gray.100" : "transparent"}
|
||||||
>
|
>
|
||||||
<Stack alignSelf="flex-start" opacity={hovered ? 1 : 0} spacing={0}>
|
<Stack alignSelf="flex-start" opacity={props.hovered ? 1 : 0} spacing={0}>
|
||||||
<Tooltip label="Hide scenario" hasArrow>
|
{props.canHide && (
|
||||||
{/* for some reason the tooltip can't position itself properly relative to the icon without the wrapping box */}
|
<>
|
||||||
<Button
|
<Tooltip label="Hide scenario" hasArrow>
|
||||||
variant="unstyled"
|
{/* for some reason the tooltip can't position itself properly relative to the icon without the wrapping box */}
|
||||||
color="gray.400"
|
<Button
|
||||||
height="unset"
|
variant="unstyled"
|
||||||
width="unset"
|
color="gray.400"
|
||||||
minW="unset"
|
height="unset"
|
||||||
onClick={onHide}
|
width="unset"
|
||||||
_hover={{
|
minW="unset"
|
||||||
color: "gray.800",
|
onClick={onHide}
|
||||||
cursor: "pointer",
|
_hover={{
|
||||||
}}
|
color: "gray.800",
|
||||||
>
|
cursor: "pointer",
|
||||||
<Icon as={hidingInProgress ? Spinner : BsX} boxSize={6} />
|
}}
|
||||||
</Button>
|
>
|
||||||
</Tooltip>
|
<Icon as={hidingInProgress ? Spinner : BsX} boxSize={6} />
|
||||||
<Icon
|
</Button>
|
||||||
as={RiDraggable}
|
</Tooltip>
|
||||||
boxSize={6}
|
<Icon
|
||||||
color="gray.400"
|
as={RiDraggable}
|
||||||
_hover={{ color: "gray.800", cursor: "pointer" }}
|
boxSize={6}
|
||||||
/>
|
color="gray.400"
|
||||||
|
_hover={{ color: "gray.800", cursor: "pointer" }}
|
||||||
|
/>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
</Stack>
|
</Stack>
|
||||||
{variableLabels.length === 0 ? (
|
{variableLabels.length === 0 ? (
|
||||||
<Box color="gray.500">{vars.data ? "No scenario variables configured" : "Loading..."}</Box>
|
<Box color="gray.500">{vars.data ? "No scenario variables configured" : "Loading..."}</Box>
|
||||||

@@ -5,7 +5,11 @@ import OutputCell from "./OutputCell/OutputCell";
 import ScenarioEditor from "./ScenarioEditor";
 import type { PromptVariant, Scenario } from "./types";
 
-const ScenarioRow = (props: { scenario: Scenario; variants: PromptVariant[] }) => {
+const ScenarioRow = (props: {
+scenario: Scenario;
+variants: PromptVariant[];
+canHide: boolean;
+}) => {
 const [isHovered, setIsHovered] = useState(false);
 
 const highlightStyle = { backgroundColor: "gray.50" };
@@ -18,7 +22,7 @@ const ScenarioRow = (props: { scenario: Scenario; variants: PromptVariant[] }) =
 sx={isHovered ? highlightStyle : undefined}
 borderLeftWidth={1}
 >
-<ScenarioEditor scenario={props.scenario} hovered={isHovered} />
+<ScenarioEditor scenario={props.scenario} hovered={isHovered} canHide={props.canHide} />
 </GridItem>
 {props.variants.map((variant) => (
 <GridItem

src/components/OutputsTable/ScenariosHeader.tsx (new file, 49 lines)
@@ -0,0 +1,49 @@
+import { Button, GridItem, HStack, Heading } from "@chakra-ui/react";
+import { cellPadding } from "../constants";
+import { useElementDimensions } from "~/utils/hooks";
+import { stickyHeaderStyle } from "./styles";
+import { BsPencil } from "react-icons/bs";
+import { useAppStore } from "~/state/store";
+
+export const ScenariosHeader = ({
+headerRows,
+numScenarios,
+}: {
+headerRows: number;
+numScenarios: number;
+}) => {
+const openDrawer = useAppStore((s) => s.openDrawer);
+
+const [ref, dimensions] = useElementDimensions();
+const topValue = dimensions ? `-${dimensions.height - 24}px` : "-455px";
+
+return (
+<GridItem
+// eslint-disable-next-line @typescript-eslint/no-explicit-any
+ref={ref as any}
+display="flex"
+alignItems="flex-end"
+rowSpan={headerRows}
+px={cellPadding.x}
+py={cellPadding.y}
+// Only display the part of the grid item that has content
+sx={{ ...stickyHeaderStyle, top: topValue }}
+>
+<HStack w="100%">
+<Heading size="xs" fontWeight="bold" flex={1}>
+Scenarios ({numScenarios})
+</Heading>
+<Button
+size="xs"
+variant="ghost"
+color="gray.500"
+aria-label="Edit"
+leftIcon={<BsPencil />}
+onClick={openDrawer}
+>
+Edit Vars
+</Button>
+</HStack>
+</GridItem>
+);
+};

@@ -1,12 +1,11 @@
-import { Box, Button, HStack, Tooltip, useToast } from "@chakra-ui/react";
+import { Box, Button, HStack, Spinner, Tooltip, useToast, Text } from "@chakra-ui/react";
 import { useRef, useEffect, useState, useCallback } from "react";
 import { useHandledAsyncCallback, useModifierKeyLabel } from "~/utils/hooks";
 import { type PromptVariant } from "./types";
 import { api } from "~/utils/api";
 import { useAppStore } from "~/state/store";
-// import openAITypes from "~/codegen/openai.types.ts.txt";
 
-export default function VariantConfigEditor(props: { variant: PromptVariant }) {
+export default function VariantEditor(props: { variant: PromptVariant }) {
 const monaco = useAppStore.use.sharedVariantEditor.monaco();
 const editorRef = useRef<ReturnType<NonNullable<typeof monaco>["editor"]["create"]> | null>(null);
 const [editorId] = useState(() => `editor_${Math.random().toString(36).substring(7)}`);
@@ -18,15 +17,17 @@ export default function VariantConfigEditor(props: { variant: PromptVariant }) {
 
 const checkForChanges = useCallback(() => {
 if (!editorRef.current) return;
-const currentConfig = editorRef.current.getValue();
-setIsChanged(currentConfig !== lastSavedFn);
+const currentFn = editorRef.current.getValue();
+setIsChanged(currentFn.length > 0 && currentFn !== lastSavedFn);
 }, [lastSavedFn]);
 
+useEffect(checkForChanges, [checkForChanges, lastSavedFn]);
+
 const replaceVariant = api.promptVariants.replaceVariant.useMutation();
 const utils = api.useContext();
 const toast = useToast();
 
-const [onSave] = useHandledAsyncCallback(async () => {
+const [onSave, saveInProgress] = useHandledAsyncCallback(async () => {
 if (!editorRef.current) return;
 
 await editorRef.current.getAction("editor.action.formatDocument")?.run();
@@ -64,14 +65,21 @@ export default function VariantConfigEditor(props: { variant: PromptVariant }) {
 return;
 }
 
-await replaceVariant.mutateAsync({
+const resp = await replaceVariant.mutateAsync({
 id: props.variant.id,
 constructFn: currentFn,
 });
+if (resp.status === "error") {
+return toast({
+title: "Error saving variant",
+description: resp.message,
+status: "error",
+});
+}
+
+setIsChanged(false);
 
 await utils.promptVariants.list.invalidate();
 
-checkForChanges();
 }, [checkForChanges]);
 
 useEffect(() => {
@@ -122,21 +130,9 @@ export default function VariantConfigEditor(props: { variant: PromptVariant }) {
 /* eslint-disable-next-line react-hooks/exhaustive-deps */
 }, [monaco, editorId]);
 
-// useEffect(() => {
-// const savedConfigChanged = lastSavedFn !== savedConfig;
-
-// lastSavedFn = savedConfig;
-
-// if (savedConfigChanged && editorRef.current?.getValue() !== savedConfig) {
-// editorRef.current?.setValue(savedConfig);
-// }
-
-// checkForChanges();
-// }, [savedConfig, checkForChanges]);
-
 return (
 <Box w="100%" pos="relative">
-<div id={editorId} style={{ height: "300px", width: "100%" }}></div>
+<div id={editorId} style={{ height: "400px", width: "100%" }}></div>
 {isChanged && (
 <HStack pos="absolute" bottom={2} right={2}>
 <Button
@@ -150,8 +146,8 @@ export default function VariantConfigEditor(props: { variant: PromptVariant }) {
 Reset
 </Button>
 <Tooltip label={`${modifierKey} + Enter`}>
-<Button size="sm" onClick={onSave} colorScheme="blue">
-Save
+<Button size="sm" onClick={onSave} colorScheme="blue" w={16} disabled={saveInProgress}>
+{saveInProgress ? <Spinner boxSize={4} /> : <Text>Save</Text>}
 </Button>
 </Tooltip>
 </HStack>

@@ -8,7 +8,7 @@ import { RiDraggable } from "react-icons/ri";
 import { cellPadding, headerMinHeight } from "../constants";
 import AutoResizeTextArea from "../AutoResizeTextArea";
 
-export default function VariantHeader(props: { variant: PromptVariant }) {
+export default function VariantHeader(props: { variant: PromptVariant; canHide: boolean }) {
 const utils = api.useContext();
 const [isDragTarget, setIsDragTarget] = useState(false);
 const [isInputHovered, setIsInputHovered] = useState(false);
@@ -95,11 +95,13 @@ export default function VariantHeader(props: { variant: PromptVariant }) {
 onMouseEnter={() => setIsInputHovered(true)}
 onMouseLeave={() => setIsInputHovered(false)}
 />
-<Tooltip label="Hide Variant" hasArrow>
-<Button variant="ghost" colorScheme="gray" size="sm" onClick={onHide}>
-<Icon as={BsX} boxSize={6} />
-</Button>
-</Tooltip>
+{props.canHide && (
+<Tooltip label="Remove Variant" hasArrow>
+<Button variant="ghost" colorScheme="gray" size="sm" onClick={onHide}>
+<Icon as={BsX} boxSize={6} />
+</Button>
+</Tooltip>
+)}
 </HStack>
 );
 }

@@ -1,12 +1,14 @@
-import { HStack, Icon, Text, useToken } from "@chakra-ui/react";
+import { HStack, Icon, Skeleton, Text, useToken } from "@chakra-ui/react";
 import { type PromptVariant } from "./types";
 import { cellPadding } from "../constants";
 import { api } from "~/utils/api";
 import chroma from "chroma-js";
 import { BsCurrencyDollar } from "react-icons/bs";
 import { CostTooltip } from "../tooltip/CostTooltip";
+import { useEffect, useState } from "react";
 
 export default function VariantStats(props: { variant: PromptVariant }) {
+const [refetchInterval, setRefetchInterval] = useState(0);
 const { data } = api.promptVariants.stats.useQuery(
 {
 variantId: props.variant.id,
@@ -19,10 +21,18 @@ export default function VariantStats(props: { variant: PromptVariant }) {
 completionTokens: 0,
 scenarioCount: 0,
 outputCount: 0,
+awaitingRetrievals: false,
 },
+refetchInterval,
 },
 );
 
+// Poll every two seconds while we are waiting for LLM retrievals to finish
+useEffect(
+() => setRefetchInterval(data.awaitingRetrievals ? 2000 : 0),
+[data.awaitingRetrievals],
+);
+
 const [passColor, neutralColor, failColor] = useToken("colors", [
 "green.500",
 "gray.500",
@@ -33,21 +43,25 @@ export default function VariantStats(props: { variant: PromptVariant }) {
 
 const showNumFinished = data.scenarioCount > 0 && data.scenarioCount !== data.outputCount;
 
-if (!(data.evalResults.length > 0) && !data.overallCost) return null;
-
 return (
-<HStack justifyContent="space-between" alignItems="center" mx="2" fontSize="xs">
+<HStack
+justifyContent="space-between"
+alignItems="center"
+mx="2"
+fontSize="xs"
+py={cellPadding.y}
+>
 {showNumFinished && (
 <Text>
 {data.outputCount} / {data.scenarioCount}
 </Text>
 )}
-<HStack px={cellPadding.x} py={cellPadding.y}>
+<HStack px={cellPadding.x}>
 {data.evalResults.map((result) => {
-const passedFrac = result.passCount / (result.passCount + result.failCount);
+const passedFrac = result.passCount / result.totalCount;
 return (
 <HStack key={result.id}>
-<Text>{result.evaluation.name}</Text>
+<Text>{result.label}</Text>
 <Text color={scale(passedFrac).hex()} fontWeight="bold">
 {(passedFrac * 100).toFixed(1)}%
 </Text>
@@ -55,17 +69,19 @@ export default function VariantStats(props: { variant: PromptVariant }) {
 );
 })}
 </HStack>
-{data.overallCost && (
+{data.overallCost && !data.awaitingRetrievals ? (
 <CostTooltip
 promptTokens={data.promptTokens}
 completionTokens={data.completionTokens}
 cost={data.overallCost}
 >
-<HStack spacing={0} align="center" color="gray.500" my="2">
+<HStack spacing={0} align="center" color="gray.500">
 <Icon as={BsCurrencyDollar} />
 <Text mr={1}>{data.overallCost.toFixed(3)}</Text>
 </HStack>
 </CostTooltip>
+) : (
+<Skeleton height={4} width={12} mr={1} />
 )}
 </HStack>
 );

@@ -1,28 +1,19 @@
-import { Button, Grid, GridItem, HStack, Heading, type SystemStyleObject } from "@chakra-ui/react";
+import { Grid, GridItem } from "@chakra-ui/react";
 import { api } from "~/utils/api";
 import NewScenarioButton from "./NewScenarioButton";
 import NewVariantButton from "./NewVariantButton";
 import ScenarioRow from "./ScenarioRow";
-import VariantConfigEditor from "./VariantEditor";
+import VariantEditor from "./VariantEditor";
 import VariantHeader from "./VariantHeader";
-import { cellPadding } from "../constants";
-import { BsPencil } from "react-icons/bs";
 import VariantStats from "./VariantStats";
-import { useAppStore } from "~/state/store";
-
-const stickyHeaderStyle: SystemStyleObject = {
-position: "sticky",
-top: "-1px",
-backgroundColor: "#fff",
-zIndex: 1,
-};
+import { ScenariosHeader } from "./ScenariosHeader";
+import { stickyHeaderStyle } from "./styles";
 
 export default function OutputsTable({ experimentId }: { experimentId: string | undefined }) {
 const variants = api.promptVariants.list.useQuery(
 { experimentId: experimentId as string },
 { enabled: !!experimentId },
 );
-const openDrawer = useAppStore((s) => s.openDrawer);
 
 const scenarios = api.scenarios.list.useQuery(
 { experimentId: experimentId as string },
@@ -49,36 +40,11 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
 }}
 fontSize="sm"
 >
-<GridItem
-display="flex"
-alignItems="flex-end"
-rowSpan={headerRows}
-px={cellPadding.x}
-py={cellPadding.y}
-// TODO: This is a hack to get the sticky header to work. It's not ideal because it's not responsive to the height of the header,
-// so if the header height changes, this will need to be updated.
-sx={{ ...stickyHeaderStyle, top: "-337px" }}
->
-<HStack w="100%">
-<Heading size="xs" fontWeight="bold" flex={1}>
-Scenarios ({scenarios.data.length})
-</Heading>
-<Button
-size="xs"
-variant="ghost"
-color="gray.500"
-aria-label="Edit"
-leftIcon={<BsPencil />}
-onClick={openDrawer}
->
-Edit Vars
-</Button>
-</HStack>
-</GridItem>
+<ScenariosHeader headerRows={headerRows} numScenarios={scenarios.data.length} />
 
 {variants.data.map((variant) => (
 <GridItem key={variant.uiId} padding={0} sx={stickyHeaderStyle} borderTopWidth={1}>
-<VariantHeader variant={variant} />
+<VariantHeader variant={variant} canHide={variants.data.length > 1} />
 </GridItem>
 ))}
 <GridItem
@@ -94,7 +60,7 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
 
 {variants.data.map((variant) => (
 <GridItem key={variant.uiId}>
-<VariantConfigEditor variant={variant} />
+<VariantEditor variant={variant} />
 </GridItem>
 ))}
 {variants.data.map((variant) => (
@@ -103,7 +69,12 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
 </GridItem>
 ))}
 {scenarios.data.map((scenario) => (
-<ScenarioRow key={scenario.uiId} scenario={scenario} variants={variants.data} />
+<ScenarioRow
+key={scenario.uiId}
+scenario={scenario}
+variants={variants.data}
+canHide={scenarios.data.length > 1}
+/>
 ))}
 <GridItem borderBottomWidth={0} borderRightWidth={0} w="100%" colSpan={allCols} padding={0}>
 <NewScenarioButton />

src/components/OutputsTable/styles.ts (new file, 8 lines)
@@ -0,0 +1,8 @@
+import { type SystemStyleObject } from "@chakra-ui/react";
+
+export const stickyHeaderStyle: SystemStyleObject = {
+position: "sticky",
+top: "-1px",
+backgroundColor: "#fff",
+zIndex: 1,
+};

@@ -20,7 +20,6 @@ export const CostTooltip = ({
 color="gray.800"
 bgColor="gray.50"
 borderWidth={1}
-py={2}
 hasArrow
 shouldWrapChildren
 label={

@@ -6,18 +6,27 @@ import { ChakraProvider } from "@chakra-ui/react";
 import theme from "~/utils/theme";
 import Favicon from "~/components/Favicon";
 import "~/utils/analytics";
+import Head from "next/head";
 
 const MyApp: AppType<{ session: Session | null }> = ({
 Component,
 pageProps: { session, ...pageProps },
 }) => {
 return (
-<SessionProvider session={session}>
-<Favicon />
-<ChakraProvider theme={theme}>
-<Component {...pageProps} />
-</ChakraProvider>
-</SessionProvider>
+<>
+<Head>
+<meta
+name="viewport"
+content="width=device-width, initial-scale=1, maximum-scale=1, user-scalable=0"
+/>
+</Head>
+<SessionProvider session={session}>
+<Favicon />
+<ChakraProvider theme={theme}>
+<Component {...pageProps} />
+</ChakraProvider>
+</SessionProvider>
+</>
 );
 };

@@ -1,7 +1,7 @@
 import { type GetServerSideProps } from "next";
 
 // eslint-disable-next-line @typescript-eslint/require-await
-export const getServerSideProps: GetServerSideProps = async (context) => {
+export const getServerSideProps: GetServerSideProps = async () => {
 return {
 redirect: {
 destination: "/experiments",

@@ -1,11 +1,7 @@
 import { type CompletionCreateParams } from "openai/resources/chat";
 import { prisma } from "../db";
 import { openai } from "../utils/openai";
-import { pick } from "lodash";
+import { pick } from "lodash-es";
 
-function promptHasVariable(prompt: string, variableName: string) {
-return prompt.includes(`{{${variableName}}}`);
-}
-
 type AxiosError = {
 response?: {

@@ -2,7 +2,7 @@ import { promptVariantsRouter } from "~/server/api/routers/promptVariants.router
 import { createTRPCRouter } from "~/server/api/trpc";
 import { experimentsRouter } from "./routers/experiments.router";
 import { scenariosRouter } from "./routers/scenarios.router";
-import { modelOutputsRouter } from "./routers/modelOutputs.router";
+import { scenarioVariantCellsRouter } from "./routers/scenarioVariantCells.router";
 import { templateVarsRouter } from "./routers/templateVariables.router";
 import { evaluationsRouter } from "./routers/evaluations.router";
 
@@ -15,7 +15,7 @@ export const appRouter = createTRPCRouter({
 promptVariants: promptVariantsRouter,
 experiments: experimentsRouter,
 scenarios: scenariosRouter,
-outputs: modelOutputsRouter,
+scenarioVariantCells: scenarioVariantCellsRouter,
 templateVars: templateVarsRouter,
 evaluations: evaluationsRouter,
 });

@@ -1,8 +1,8 @@
-import { EvaluationMatchType } from "@prisma/client";
+import { EvalType } from "@prisma/client";
 import { z } from "zod";
 import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
 import { prisma } from "~/server/db";
-import { reevaluateEvaluation } from "~/server/utils/evaluations";
+import { runAllEvals } from "~/server/utils/evaluations";
 
 export const evaluationsRouter = createTRPCRouter({
 list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
@@ -18,21 +18,24 @@ export const evaluationsRouter = createTRPCRouter({
 .input(
 z.object({
 experimentId: z.string(),
-name: z.string(),
-matchString: z.string(),
-matchType: z.nativeEnum(EvaluationMatchType),
+label: z.string(),
+value: z.string(),
+evalType: z.nativeEnum(EvalType),
 }),
 )
 .mutation(async ({ input }) => {
-const evaluation = await prisma.evaluation.create({
+await prisma.evaluation.create({
 data: {
 experimentId: input.experimentId,
-name: input.name,
-matchString: input.matchString,
-matchType: input.matchType,
+label: input.label,
+value: input.value,
+evalType: input.evalType,
 },
 });
-await reevaluateEvaluation(evaluation);
+
+// TODO: this may be a bad UX for slow evals (eg. GPT-4 evals) Maybe need
+// to kick off a background job or something instead
+await runAllEvals(input.experimentId);
 }),
 
 update: publicProcedure
@@ -40,24 +43,30 @@ export const evaluationsRouter = createTRPCRouter({
 z.object({
 id: z.string(),
 updates: z.object({
-name: z.string().optional(),
-matchString: z.string().optional(),
-matchType: z.nativeEnum(EvaluationMatchType).optional(),
+label: z.string().optional(),
+value: z.string().optional(),
+evalType: z.nativeEnum(EvalType).optional(),
 }),
 }),
 )
 .mutation(async ({ input }) => {
-await prisma.evaluation.update({
+const evaluation = await prisma.evaluation.update({
 where: { id: input.id },
 data: {
-name: input.updates.name,
-matchString: input.updates.matchString,
-matchType: input.updates.matchType,
+label: input.updates.label,
+value: input.updates.value,
+evalType: input.updates.evalType,
 },
 });
-await reevaluateEvaluation(
-await prisma.evaluation.findUniqueOrThrow({ where: { id: input.id } }),
-);
+
+await prisma.outputEvaluation.deleteMany({
+where: {
+evaluationId: evaluation.id,
+},
+});
+// Re-run all evals. Other eval results will already be cached, so this
+// should only re-run the updated one.
+await runAllEvals(evaluation.experimentId);
 }),
 
 delete: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => {

@@ -2,6 +2,7 @@ import { z } from "zod";
 import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
 import { prisma } from "~/server/db";
 import dedent from "dedent";
+import { generateNewCell } from "~/server/utils/generateNewCell";
 
 export const experimentsRouter = createTRPCRouter({
 list: publicProcedure.query(async () => {
@@ -64,27 +65,55 @@ export const experimentsRouter = createTRPCRouter({
 },
 });
 
-await prisma.$transaction([
+const [variant, _, scenario] = await prisma.$transaction([
 prisma.promptVariant.create({
 data: {
 experimentId: exp.id,
 label: "Prompt Variant 1",
 sortIndex: 0,
-constructFn: dedent`prompt = {
+// The interpolated $ is necessary until dedent incorporates
+// https://github.com/dmnd/dedent/pull/46
+constructFn: dedent`
+/**
+* Use Javascript to define an OpenAI chat completion
+* (https://platform.openai.com/docs/api-reference/chat/create) and
+* assign it to the \`prompt\` variable.
+*
+* You have access to the current scenario in the \`scenario\`
+* variable.
+*/
+
+prompt = {
 model: "gpt-3.5-turbo-0613",
 stream: true,
-messages: [{ role: "system", content: "Return 'Ready to go!'" }],
-}`,
+messages: [
+{
+role: "system",
+content: \`"Return 'this is output for the scenario "${"$"}{scenario.text}"'\`,
+},
+],
+};`,
+model: "gpt-3.5-turbo-0613",
+},
+}),
+prisma.templateVariable.create({
+data: {
+experimentId: exp.id,
+label: "text",
 },
 }),
 prisma.testScenario.create({
 data: {
 experimentId: exp.id,
-variableValues: {},
+variableValues: {
+text: "This is a test scenario.",
+},
 },
 }),
 ]);
 
+await generateNewCell(variant.id, scenario.id);
+
 return exp;
 }),
 

@@ -1,97 +0,0 @@
-import { z } from "zod";
-import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
-import { prisma } from "~/server/db";
-import crypto from "crypto";
-import type { Prisma } from "@prisma/client";
-import { reevaluateVariant } from "~/server/utils/evaluations";
-import { getCompletion } from "~/server/utils/getCompletion";
-import { constructPrompt } from "~/server/utils/constructPrompt";
-
-export const modelOutputsRouter = createTRPCRouter({
-get: publicProcedure
-.input(
-z.object({
-scenarioId: z.string(),
-variantId: z.string(),
-channel: z.string().optional(),
-forceRefetch: z.boolean().optional(),
-}),
-)
-.mutation(async ({ input }) => {
-const existing = await prisma.modelOutput.findUnique({
-where: {
-promptVariantId_testScenarioId: {
-promptVariantId: input.variantId,
-testScenarioId: input.scenarioId,
-},
-},
-});
-
-if (existing && !input.forceRefetch) return existing;
-
-const variant = await prisma.promptVariant.findUnique({
-where: {
-id: input.variantId,
-},
-});
-
-const scenario = await prisma.testScenario.findUnique({
-where: {
-id: input.scenarioId,
-},
-});
-
-if (!variant || !scenario) return null;
-
-const prompt = await constructPrompt(variant, scenario);
-
-const inputHash = crypto.createHash("sha256").update(JSON.stringify(prompt)).digest("hex");
-
-// TODO: we should probably only use this if temperature=0
-const existingResponse = await prisma.modelOutput.findFirst({
-where: { inputHash, errorMessage: null },
-});
-
-let modelResponse: Awaited<ReturnType<typeof getCompletion>>;
-
-if (existingResponse) {
-modelResponse = {
-output: existingResponse.output as Prisma.InputJsonValue,
-statusCode: existingResponse.statusCode,
-errorMessage: existingResponse.errorMessage,
-timeToComplete: existingResponse.timeToComplete,
-promptTokens: existingResponse.promptTokens ?? undefined,
-completionTokens: existingResponse.completionTokens ?? undefined,
-};
-} else {
-try {
-modelResponse = await getCompletion(prompt, input.channel);
-} catch (e) {
-console.error(e);
-throw e;
-}
-}
-
-const modelOutput = await prisma.modelOutput.upsert({
-where: {
-promptVariantId_testScenarioId: {
-promptVariantId: input.variantId,
-testScenarioId: input.scenarioId,
-},
-},
-create: {
-promptVariantId: input.variantId,
-testScenarioId: input.scenarioId,
-inputHash,
-...modelResponse,
-},
-update: {
-...modelResponse,
-},
-});
-
-await reevaluateVariant(input.variantId);
-
-return modelOutput;
-}),
-});

@@ -1,6 +1,12 @@
+import dedent from "dedent";
+import { isObject } from "lodash-es";
 import { z } from "zod";
 import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
 import { prisma } from "~/server/db";
+import { generateNewCell } from "~/server/utils/generateNewCell";
+import { OpenAIChatModel } from "~/server/types";
+import { constructPrompt } from "~/server/utils/constructPrompt";
+import userError from "~/server/utils/error";
 import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
 import { calculateTokenCost } from "~/utils/calculateTokenCost";
 
@@ -26,11 +32,43 @@ export const promptVariantsRouter = createTRPCRouter({
 throw new Error(`Prompt Variant with id ${input.variantId} does not exist`);
 }
 
-const evalResults = await prisma.evaluationResult.findMany({
-where: {
-promptVariantId: input.variantId,
+const outputEvals = await prisma.outputEvaluation.groupBy({
+by: ["evaluationId"],
+_sum: {
+result: true,
 },
-include: { evaluation: true },
+_count: {
+id: true,
+},
+where: {
+modelOutput: {
+scenarioVariantCell: {
+promptVariant: {
+id: input.variantId,
+visible: true,
+},
+testScenario: {
+visible: true,
+},
+},
+},
+},
+});
+
+const evals = await prisma.evaluation.findMany({
+where: {
+experimentId: variant.experimentId,
+},
+});
+
+const evalResults = evals.map((evalItem) => {
+const evalResult = outputEvals.find((outputEval) => outputEval.evaluationId === evalItem.id);
+return {
+id: evalItem.id,
+label: evalItem.label,
+passCount: evalResult?._sum?.result ?? 0,
+totalCount: evalResult?._count?.id ?? 1,
+};
 });
 
 const scenarioCount = await prisma.testScenario.count({
@@ -39,17 +77,24 @@
 visible: true,
 },
 });
-const outputCount = await prisma.modelOutput.count({
+const outputCount = await prisma.scenarioVariantCell.count({
 where: {
 promptVariantId: input.variantId,
 testScenario: { visible: true },
+modelOutput: {
+is: {},
+},
 },
 });
 
 const overallTokens = await prisma.modelOutput.aggregate({
 where: {
-promptVariantId: input.variantId,
-testScenario: { visible: true },
+scenarioVariantCell: {
+promptVariantId: input.variantId,
+testScenario: {
+visible: true,
+},
+},
 },
 _sum: {
 promptTokens: true,
@@ -57,18 +102,33 @@
 },
 });
 
-// TODO: fix this
-const model = "gpt-3.5-turbo-0613";
-// const model = getModelName(variant.config);
-
 const promptTokens = overallTokens._sum?.promptTokens ?? 0;
-const overallPromptCost = calculateTokenCost(model, promptTokens);
+const overallPromptCost = calculateTokenCost(variant.model, promptTokens);
 const completionTokens = overallTokens._sum?.completionTokens ?? 0;
-const overallCompletionCost = calculateTokenCost(model, completionTokens, true);
+const overallCompletionCost = calculateTokenCost(variant.model, completionTokens, true);
 
 const overallCost = overallPromptCost + overallCompletionCost;
 
-return { evalResults, promptTokens, completionTokens, overallCost, scenarioCount, outputCount };
+const awaitingRetrievals = !!(await prisma.scenarioVariantCell.findFirst({
+where: {
+promptVariantId: input.variantId,
+testScenario: { visible: true },
+// Check if is PENDING or IN_PROGRESS
+retrievalStatus: {
+in: ["PENDING", "IN_PROGRESS"],
+},
+},
+}));
+
+return {
+evalResults,
+promptTokens,
+completionTokens,
+overallCost,
+scenarioCount,
+outputCount,
+awaitingRetrievals,
+};
 }),
 
 create: publicProcedure
@@ -105,7 +165,19 @@ export const promptVariantsRouter = createTRPCRouter({
 experimentId: input.experimentId,
 label: `Prompt Variant ${largestSortIndex + 2}`,
 sortIndex: (lastVariant?.sortIndex ?? 0) + 1,
-constructFn: lastVariant?.constructFn ?? "",
+constructFn:
+lastVariant?.constructFn ??
+dedent`
+prompt = {
+model: "gpt-3.5-turbo",
+messages: [
+{
+role: "system",
+content: "Return 'Hello, world!'",
+}
+]
+}`,
+model: lastVariant?.model ?? "gpt-3.5-turbo",
 },
 });
 
@@ -114,6 +186,17 @@ export const promptVariantsRouter = createTRPCRouter({
 recordExperimentUpdated(input.experimentId),
 ]);
 
+const scenarios = await prisma.testScenario.findMany({
+where: {
+experimentId: input.experimentId,
+visible: true,
+},
+});
+
+for (const scenario of scenarios) {
+await generateNewCell(newVariant.id, scenario.id);
+}
+
 return newVariant;
 }),
 
@@ -185,6 +268,27 @@ export const promptVariantsRouter = createTRPCRouter({
 throw new Error(`Prompt Variant with id ${input.id} does not exist`);
 }
 
+let model = existing.model;
+try {
+const contructedPrompt = await constructPrompt({ constructFn: input.constructFn }, null);
+
+if (!isObject(contructedPrompt)) {
+return userError("Prompt is not an object");
+}
+if (!("model" in contructedPrompt)) {
+return userError("Prompt does not define a model");
+}
+if (
+typeof contructedPrompt.model !== "string" ||
+!(contructedPrompt.model in OpenAIChatModel)
+) {
+return userError("Prompt defines an invalid model");
+}
+model = contructedPrompt.model;
+} catch (e) {
+return userError((e as Error).message);
+}
+
 // Create a duplicate with only the config changed
 const newVariant = await prisma.promptVariant.create({
 data: {
@@ -193,11 +297,12 @@ export const promptVariantsRouter = createTRPCRouter({
 sortIndex: existing.sortIndex,
 uiId: existing.uiId,
 constructFn: input.constructFn,
+model,
 },
 });
 
 // Hide anything with the same uiId besides the new one
-const hideOldVariantsAction = prisma.promptVariant.updateMany({
+const hideOldVariants = prisma.promptVariant.updateMany({
 where: {
 uiId: existing.uiId,
 id: {
@@ -209,12 +314,20 @@ export const promptVariantsRouter = createTRPCRouter({
 },
 });
 
-await prisma.$transaction([
-hideOldVariantsAction,
-recordExperimentUpdated(existing.experimentId),
-]);
+await prisma.$transaction([hideOldVariants, recordExperimentUpdated(existing.experimentId)]);
 
-return newVariant;
+const scenarios = await prisma.testScenario.findMany({
+where: {
+experimentId: newVariant.experimentId,
+visible: true,
+},
+});
+
+for (const scenario of scenarios) {
+await generateNewCell(newVariant.id, scenario.id);
+}
+
+return { status: "ok" } as const;
 }),
 
 reorder: publicProcedure

src/server/api/routers/scenarioVariantCells.router.ts (new file, 78 lines)
@@ -0,0 +1,78 @@
+import { z } from "zod";
+import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
+import { prisma } from "~/server/db";
+import { generateNewCell } from "~/server/utils/generateNewCell";
+import { queueLLMRetrievalTask } from "~/server/utils/queueLLMRetrievalTask";
+
+export const scenarioVariantCellsRouter = createTRPCRouter({
+get: publicProcedure
+.input(
+z.object({
+scenarioId: z.string(),
+variantId: z.string(),
+}),
+)
+.query(async ({ input }) => {
+return await prisma.scenarioVariantCell.findUnique({
+where: {
+promptVariantId_testScenarioId: {
+promptVariantId: input.variantId,
+testScenarioId: input.scenarioId,
+},
+},
+include: {
+modelOutput: {
+include: {
+outputEvaluation: {
+include: {
+evaluation: {
+select: { label: true },
+},
+},
+},
+},
+},
+},
+});
+}),
+forceRefetch: publicProcedure
+.input(
+z.object({
+scenarioId: z.string(),
+variantId: z.string(),
+}),
+)
+.mutation(async ({ input }) => {
+const cell = await prisma.scenarioVariantCell.findUnique({
+where: {
+promptVariantId_testScenarioId: {
+promptVariantId: input.variantId,
+testScenarioId: input.scenarioId,
+},
+},
+include: {
+modelOutput: true,
+},
+});
+
+if (!cell) {
+await generateNewCell(input.variantId, input.scenarioId);
+return true;
+}
+
+if (cell.modelOutput) {
+// TODO: Maybe keep these around to show previous generations?
+await prisma.modelOutput.delete({
+where: { id: cell.modelOutput.id },
+});
+}
+
+await prisma.scenarioVariantCell.update({
+where: { id: cell.id },
+data: { retrievalStatus: "PENDING" },
+});
+
+await queueLLMRetrievalTask(cell.id);
+return true;
+}),
+});

@@ -3,7 +3,8 @@ import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
 import { prisma } from "~/server/db";
 import { autogenerateScenarioValues } from "../autogen";
 import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
-import { reevaluateAll } from "~/server/utils/evaluations";
+import { runAllEvals } from "~/server/utils/evaluations";
+import { generateNewCell } from "~/server/utils/generateNewCell";
 
 export const scenariosRouter = createTRPCRouter({
 list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
@@ -48,10 +49,21 @@ export const scenariosRouter = createTRPCRouter({
 },
 });
 
-await prisma.$transaction([
+const [scenario] = await prisma.$transaction([
 createNewScenarioAction,
 recordExperimentUpdated(input.experimentId),
 ]);
+
+const promptVariants = await prisma.promptVariant.findMany({
+where: {
+experimentId: input.experimentId,
+visible: true,
+},
+});
+
+for (const variant of promptVariants) {
+await generateNewCell(variant.id, scenario.id);
+}
 }),
 
 hide: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => {
@@ -61,7 +73,7 @@ export const scenariosRouter = createTRPCRouter({
 });
 
 // Reevaluate all evaluations now that this scenario is hidden
-await reevaluateAll(hiddenScenario.experimentId);
+await runAllEvals(hiddenScenario.experimentId);
 
 return hiddenScenario;
 }),
@@ -175,6 +187,17 @@ export const scenariosRouter = createTRPCRouter({
 },
 });
 
+const promptVariants = await prisma.promptVariant.findMany({
+where: {
+experimentId: newScenario.experimentId,
+visible: true,
+},
+});
+
+for (const variant of promptVariants) {
+await generateNewCell(variant.id, newScenario.id);
+}
+
 return newScenario;
 }),
 });

src/server/scripts/migrateScenarioVariantOutputData.ts (new file, 47 lines)
@@ -0,0 +1,47 @@
+import { type Prisma } from "@prisma/client";
+import { prisma } from "../db";
+
+async function migrateScenarioVariantOutputData() {
+// Get all ScenarioVariantCells
+const cells = await prisma.scenarioVariantCell.findMany({ include: { modelOutput: true } });
+console.log(`Found ${cells.length} records`);
+
+let updatedCount = 0;
+
+// Loop through all scenarioVariants
+for (const cell of cells) {
+// Create a new ModelOutput for each ScenarioVariant with an existing output
+if (cell.output && !cell.modelOutput) {
+updatedCount++;
+await prisma.modelOutput.create({
+data: {
+scenarioVariantCellId: cell.id,
+inputHash: cell.inputHash || "",
+output: cell.output as Prisma.InputJsonValue,
+timeToComplete: cell.timeToComplete ?? undefined,
+promptTokens: cell.promptTokens,
+completionTokens: cell.completionTokens,
+createdAt: cell.createdAt,
+updatedAt: cell.updatedAt,
+},
+});
+} else if (cell.errorMessage && cell.retrievalStatus === "COMPLETE") {
+updatedCount++;
+await prisma.scenarioVariantCell.update({
+where: { id: cell.id },
+data: {
+retrievalStatus: "ERROR",
+},
+});
+}
+}
+
+console.log("Data migration completed");
+console.log(`Updated ${updatedCount} records`);
+}
+
+// Execute the function
+migrateScenarioVariantOutputData().catch((error) => {
+console.error("An error occurred while migrating data: ", error);
+process.exit(1);
+});

src/server/tasks/defineTask.ts (new file, 31 lines)
@@ -0,0 +1,31 @@
+// Import necessary dependencies
+import { quickAddJob, type Helpers, type Task } from "graphile-worker";
+import { env } from "~/env.mjs";
+
+// Define the defineTask function
+function defineTask<TPayload>(
+taskIdentifier: string,
+taskHandler: (payload: TPayload, helpers: Helpers) => Promise<void>,
+) {
+const enqueue = async (payload: TPayload) => {
+console.log("Enqueuing task", taskIdentifier, payload);
+await quickAddJob({ connectionString: env.DATABASE_URL }, taskIdentifier, payload);
+};
+
+const handler = (payload: TPayload, helpers: Helpers) => {
+helpers.logger.info(`Running task ${taskIdentifier} with payload: ${JSON.stringify(payload)}`);
+return taskHandler(payload, helpers);
+};
+
+const task = {
+identifier: taskIdentifier,
+handler: handler as Task,
+};
+
+return {
+enqueue,
+task,
+};
+}
+
+export default defineTask;
154  src/server/tasks/queryLLM.task.ts  Normal file
@@ -0,0 +1,154 @@
import crypto from "crypto";
import { prisma } from "~/server/db";
import defineTask from "./defineTask";
import { type CompletionResponse, getCompletion } from "../utils/getCompletion";
import { type JSONSerializable } from "../types";
import { sleep } from "../utils/sleep";
import { shouldStream } from "../utils/shouldStream";
import { generateChannel } from "~/utils/generateChannel";
import { runEvalsForOutput } from "../utils/evaluations";
import { constructPrompt } from "../utils/constructPrompt";
import { type CompletionCreateParams } from "openai/resources/chat";
import { type Prisma } from "@prisma/client";

const MAX_AUTO_RETRIES = 10;
const MIN_DELAY = 500; // milliseconds
const MAX_DELAY = 15000; // milliseconds

function calculateDelay(numPreviousTries: number): number {
  const baseDelay = Math.min(MAX_DELAY, MIN_DELAY * Math.pow(2, numPreviousTries));
  const jitter = Math.random() * baseDelay;
  return baseDelay + jitter;
}

const getCompletionWithRetries = async (
  cellId: string,
  payload: JSONSerializable,
  channel?: string,
): Promise<CompletionResponse> => {
  let modelResponse: CompletionResponse | null = null;
  try {
    for (let i = 0; i < MAX_AUTO_RETRIES; i++) {
      modelResponse = await getCompletion(payload as unknown as CompletionCreateParams, channel);
      if (modelResponse.statusCode !== 429 || i === MAX_AUTO_RETRIES - 1) {
        return modelResponse;
      }
      const delay = calculateDelay(i);
      await prisma.scenarioVariantCell.update({
        where: { id: cellId },
        data: {
          errorMessage: "Rate limit exceeded",
          statusCode: 429,
          retryTime: new Date(Date.now() + delay),
        },
      });
      // TODO: Maybe requeue the job so other jobs can run in the future?
      await sleep(delay);
    }
    throw new Error("Max retries limit reached");
  } catch (error: unknown) {
    return {
      statusCode: modelResponse?.statusCode ?? 500,
      errorMessage: modelResponse?.errorMessage ?? (error as Error).message,
      output: null as unknown as Prisma.InputJsonValue,
      timeToComplete: 0,
    };
  }
};

export type queryLLMJob = {
  scenarioVariantCellId: string;
};

export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
  const { scenarioVariantCellId } = task;
  const cell = await prisma.scenarioVariantCell.findUnique({
    where: { id: scenarioVariantCellId },
    include: { modelOutput: true },
  });
  if (!cell) {
    return;
  }

  // If cell is not pending, then some other job is already processing it
  if (cell.retrievalStatus !== "PENDING") {
    return;
  }
  await prisma.scenarioVariantCell.update({
    where: { id: scenarioVariantCellId },
    data: {
      retrievalStatus: "IN_PROGRESS",
    },
  });

  const variant = await prisma.promptVariant.findUnique({
    where: { id: cell.promptVariantId },
  });
  if (!variant) {
    return;
  }

  const scenario = await prisma.testScenario.findUnique({
    where: { id: cell.testScenarioId },
  });
  if (!scenario) {
    return;
  }

  const prompt = await constructPrompt(variant, scenario.variableValues);

  const streamingEnabled = shouldStream(prompt);
  let streamingChannel;

  if (streamingEnabled) {
    streamingChannel = generateChannel();
    // Save streaming channel so that UI can connect to it
    await prisma.scenarioVariantCell.update({
      where: { id: scenarioVariantCellId },
      data: {
        streamingChannel,
      },
    });
  }

  const modelResponse = await getCompletionWithRetries(
    scenarioVariantCellId,
    prompt,
    streamingChannel,
  );

  let modelOutput = null;
  if (modelResponse.statusCode === 200) {
    const inputHash = crypto.createHash("sha256").update(JSON.stringify(prompt)).digest("hex");

    modelOutput = await prisma.modelOutput.create({
      data: {
        scenarioVariantCellId,
        inputHash,
        output: modelResponse.output,
        timeToComplete: modelResponse.timeToComplete,
        promptTokens: modelResponse.promptTokens,
        completionTokens: modelResponse.completionTokens,
      },
    });
  }

  await prisma.scenarioVariantCell.update({
    where: { id: scenarioVariantCellId },
    data: {
      statusCode: modelResponse.statusCode,
      errorMessage: modelResponse.errorMessage,
      streamingChannel: null,
      retrievalStatus: modelOutput ? "COMPLETE" : "ERROR",
      modelOutput: {
        connect: {
          id: modelOutput?.id,
        },
      },
    },
  });

  if (modelOutput) {
    await runEvalsForOutput(variant.experimentId, scenario, modelOutput);
  }
});
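As a rough illustration of the retry schedule implemented by calculateDelay above (derived from the constants in the file, not stated in the diff): the base delay doubles per attempt and is capped at MAX_DELAY, and up to one extra base delay of random jitter is added, so attempt i waits somewhere in [base, 2 * base) milliseconds.

// Illustrative only: base delays (before jitter) for MIN_DELAY = 500 and MAX_DELAY = 15000.
const baseDelays = [0, 1, 2, 3, 4, 5].map((i) => Math.min(15000, 500 * Math.pow(2, i)));
// => [500, 1000, 2000, 4000, 8000, 15000]; later attempts stay capped at 15000 ms.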
40  src/server/tasks/worker.ts  Normal file
@@ -0,0 +1,40 @@
import { type TaskList, run } from "graphile-worker";
import "dotenv/config";

import { env } from "~/env.mjs";
import { queryLLM } from "./queryLLM.task";

const registeredTasks = [queryLLM];

const taskList = registeredTasks.reduce((acc, task) => {
  acc[task.task.identifier] = task.task.handler;
  return acc;
}, {} as TaskList);

async function main() {
  // Run a worker to execute jobs:
  const runner = await run({
    connectionString: env.DATABASE_URL,
    concurrency: 20,
    // Install signal handlers for graceful shutdown on SIGINT, SIGTERM, etc
    noHandleSignals: false,
    pollInterval: 1000,
    // you can set the taskList or taskDirectory but not both
    taskList,
    // or:
    // taskDirectory: `${__dirname}/tasks`,
  });

  // Immediately await (or otherwise handled) the resulting promise, to avoid
  // "unhandled rejection" errors causing a process crash in the event of
  // something going wrong.
  await runner.promise;

  // If the worker exits (whether through fatal error or otherwise), the above
  // promise will resolve/reject.
}

main().catch((err) => {
  console.error("Unhandled error occurred running worker: ", err);
  process.exit(1);
});
@@ -7,9 +7,7 @@ test.skip("constructPrompt", async () => {
       constructFn: `prompt = { "fooz": "bar" }`,
     },
     {
-      variableValues: {
-        foo: "bar",
-      },
+      foo: "bar",
     },
   );
 
@@ -6,12 +6,10 @@ const isolate = new ivm.Isolate({ memoryLimit: 128 });
 
 export async function constructPrompt(
   variant: Pick<PromptVariant, "constructFn">,
-  testScenario: Pick<TestScenario, "variableValues">,
+  scenario: TestScenario["variableValues"],
 ): Promise<JSONSerializable> {
-  const scenario = testScenario.variableValues as JSONSerializable;
-
   const code = `
-    const scenario = ${JSON.stringify(scenario, null, 2)};
+    const scenario = ${JSON.stringify(scenario ?? {}, null, 2)};
     let prompt
 
     ${variant.constructFn}
6  src/server/utils/error.ts  Normal file
@@ -0,0 +1,6 @@
export default function userError(message: string): { status: "error"; message: string } {
  return {
    status: "error",
    message,
  };
}
@@ -1,31 +0,0 @@
-import { type Evaluation, type ModelOutput, type TestScenario } from "@prisma/client";
-import { type ChatCompletion } from "openai/resources/chat";
-import { type VariableMap, fillTemplate } from "./fillTemplate";
-
-export const evaluateOutput = (
-  modelOutput: ModelOutput,
-  scenario: TestScenario,
-  evaluation: Evaluation,
-): boolean => {
-  const output = modelOutput.output as unknown as ChatCompletion;
-  const message = output?.choices?.[0]?.message;
-
-  if (!message) return false;
-
-  const stringifiedMessage = message.content ?? JSON.stringify(message.function_call);
-
-  const matchRegex = fillTemplate(evaluation.matchString, scenario.variableValues as VariableMap);
-
-  let match;
-
-  switch (evaluation.matchType) {
-    case "CONTAINS":
-      match = stringifiedMessage.match(matchRegex) !== null;
-      break;
-    case "DOES_NOT_CONTAIN":
-      match = stringifiedMessage.match(matchRegex) === null;
-      break;
-  }
-
-  return match;
-};
@@ -1,103 +1,79 @@
-import { type Evaluation } from "@prisma/client";
+import { type ModelOutput, type Evaluation } from "@prisma/client";
 import { prisma } from "../db";
-import { evaluateOutput } from "./evaluateOutput";
+import { runOneEval } from "./runOneEval";
+import { type Scenario } from "~/components/OutputsTable/types";
 
-export const reevaluateVariant = async (variantId: string) => {
-  const variant = await prisma.promptVariant.findUnique({
-    where: { id: variantId },
-  });
-  if (!variant) return;
-
-  const evaluations = await prisma.evaluation.findMany({
-    where: { experimentId: variant.experimentId },
-  });
-
-  const modelOutputs = await prisma.modelOutput.findMany({
+const saveResult = async (evaluation: Evaluation, scenario: Scenario, modelOutput: ModelOutput) => {
+  const result = await runOneEval(evaluation, scenario, modelOutput);
+  return await prisma.outputEvaluation.upsert({
     where: {
-      promptVariantId: variantId,
-      statusCode: { notIn: [429] },
-      testScenario: { visible: true },
+      modelOutputId_evaluationId: {
+        modelOutputId: modelOutput.id,
+        evaluationId: evaluation.id,
+      },
+    },
+    create: {
+      modelOutputId: modelOutput.id,
+      evaluationId: evaluation.id,
+      ...result,
+    },
+    update: {
+      ...result,
     },
-    include: { testScenario: true },
   });
-
-  await Promise.all(
-    evaluations.map(async (evaluation) => {
-      const passCount = modelOutputs.filter((output) =>
-        evaluateOutput(output, output.testScenario, evaluation),
-      ).length;
-      const failCount = modelOutputs.length - passCount;
-
-      await prisma.evaluationResult.upsert({
-        where: {
-          evaluationId_promptVariantId: {
-            evaluationId: evaluation.id,
-            promptVariantId: variantId,
-          },
-        },
-        create: {
-          evaluationId: evaluation.id,
-          promptVariantId: variantId,
-          passCount,
-          failCount,
-        },
-        update: {
-          passCount,
-          failCount,
-        },
-      });
-    }),
-  );
 };
 
-export const reevaluateEvaluation = async (evaluation: Evaluation) => {
-  const variants = await prisma.promptVariant.findMany({
-    where: { experimentId: evaluation.experimentId, visible: true },
-  });
-
-  const modelOutputs = await prisma.modelOutput.findMany({
-    where: {
-      promptVariantId: { in: variants.map((v) => v.id) },
-      testScenario: { visible: true },
-      statusCode: { notIn: [429] },
-    },
-    include: { testScenario: true },
-  });
-
-  await Promise.all(
-    variants.map(async (variant) => {
-      const outputs = modelOutputs.filter((output) => output.promptVariantId === variant.id);
-      const passCount = outputs.filter((output) =>
-        evaluateOutput(output, output.testScenario, evaluation),
-      ).length;
-      const failCount = outputs.length - passCount;
-
-      await prisma.evaluationResult.upsert({
-        where: {
-          evaluationId_promptVariantId: {
-            evaluationId: evaluation.id,
-            promptVariantId: variant.id,
-          },
-        },
-        create: {
-          evaluationId: evaluation.id,
-          promptVariantId: variant.id,
-          passCount,
-          failCount,
-        },
-        update: {
-          passCount,
-          failCount,
-        },
-      });
-    }),
-  );
-};
-
-export const reevaluateAll = async (experimentId: string) => {
+export const runEvalsForOutput = async (
+  experimentId: string,
+  scenario: Scenario,
+  modelOutput: ModelOutput,
+) => {
   const evaluations = await prisma.evaluation.findMany({
     where: { experimentId },
   });
 
-  await Promise.all(evaluations.map(reevaluateEvaluation));
+  await Promise.all(
+    evaluations.map(async (evaluation) => await saveResult(evaluation, scenario, modelOutput)),
+  );
+};
+
+export const runAllEvals = async (experimentId: string) => {
+  const outputs = await prisma.modelOutput.findMany({
+    where: {
+      scenarioVariantCell: {
+        promptVariant: {
+          experimentId,
+          visible: true,
+        },
+        testScenario: {
+          visible: true,
+        },
+      },
+    },
+    include: {
+      scenarioVariantCell: {
+        include: {
+          testScenario: true,
+        },
+      },
+      outputEvaluation: true,
+    },
+  });
+  const evals = await prisma.evaluation.findMany({
+    where: { experimentId },
+  });
+
+  await Promise.all(
+    outputs.map(async (output) => {
+      const unrunEvals = evals.filter(
+        (evaluation) => !output.outputEvaluation.find((e) => e.evaluationId === evaluation.id),
+      );

+      await Promise.all(
+        unrunEvals.map(async (evaluation) => {
+          await saveResult(evaluation, output.scenarioVariantCell.testScenario, output);
+        }),
+      );
+    }),
+  );
 };
@@ -2,6 +2,16 @@ import { type JSONSerializable } from "../types";
 
 export type VariableMap = Record<string, string>;
 
+// Escape quotes to match the way we encode JSON
+export function escapeQuotes(str: string) {
+  return str.replace(/(\\")|"/g, (match, p1) => (p1 ? match : '\\"'));
+}
+
+// Escape regex special characters
+export function escapeRegExp(str: string) {
+  return str.replace(/[.*+\-?^${}()|[\]\\]/g, "\\$&"); // $& means the whole matched string
+}
+
 export function fillTemplate(template: string, variables: VariableMap): string {
   return template.replace(/{{\s*(\w+)\s*}}/g, (_, key: string) => variables[key] || "");
 }
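To make the template helpers concrete, a small usage sketch; the import path and sample values are assumptions for illustration only:

import { fillTemplate, escapeRegExp } from "./fillTemplate"; // path assumed relative to server/utils

// {{placeholders}} are substituted from the variable map; unknown keys collapse to "".
fillTemplate("Hello {{name}}, welcome to {{city}}!", { name: "Ada" });
// => "Hello Ada, welcome to !"

// escapeRegExp makes a literal string safe to embed in a RegExp, as the CONTAINS evals below do.
new RegExp(escapeRegExp("total: $5 (USD)")).test("total: $5 (USD)"); // => true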
76  src/server/utils/generateNewCell.ts  Normal file
@@ -0,0 +1,76 @@
import crypto from "crypto";
import { type Prisma } from "@prisma/client";
import { prisma } from "../db";
import { queueLLMRetrievalTask } from "./queueLLMRetrievalTask";
import { constructPrompt } from "./constructPrompt";

export const generateNewCell = async (variantId: string, scenarioId: string) => {
  const variant = await prisma.promptVariant.findUnique({
    where: {
      id: variantId,
    },
  });

  const scenario = await prisma.testScenario.findUnique({
    where: {
      id: scenarioId,
    },
  });

  if (!variant || !scenario) return null;

  const prompt = await constructPrompt(variant, scenario.variableValues);

  const inputHash = crypto.createHash("sha256").update(JSON.stringify(prompt)).digest("hex");

  let cell = await prisma.scenarioVariantCell.findUnique({
    where: {
      promptVariantId_testScenarioId: {
        promptVariantId: variantId,
        testScenarioId: scenarioId,
      },
    },
    include: {
      modelOutput: true,
    },
  });

  if (cell) return cell;

  cell = await prisma.scenarioVariantCell.create({
    data: {
      promptVariantId: variantId,
      testScenarioId: scenarioId,
    },
    include: {
      modelOutput: true,
    },
  });

  const matchingModelOutput = await prisma.modelOutput.findFirst({
    where: {
      inputHash,
    },
  });

  let newModelOutput;

  if (matchingModelOutput) {
    newModelOutput = await prisma.modelOutput.create({
      data: {
        scenarioVariantCellId: cell.id,
        inputHash,
        output: matchingModelOutput.output as Prisma.InputJsonValue,
        timeToComplete: matchingModelOutput.timeToComplete,
        promptTokens: matchingModelOutput.promptTokens,
        completionTokens: matchingModelOutput.completionTokens,
        createdAt: matchingModelOutput.createdAt,
        updatedAt: matchingModelOutput.updatedAt,
      },
    });
  } else {
    cell = await queueLLMRetrievalTask(cell.id);
  }

  return { ...cell, modelOutput: newModelOutput };
};
@@ -1,18 +1,15 @@
 /* eslint-disable @typescript-eslint/no-unsafe-call */
-import { isObject } from "lodash";
+import { isObject } from "lodash-es";
 import { Prisma } from "@prisma/client";
 import { streamChatCompletion } from "./openai";
 import { wsConnection } from "~/utils/wsConnection";
 import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat";
-import { type JSONSerializable, OpenAIChatModel } from "../types";
+import { type OpenAIChatModel } from "../types";
 import { env } from "~/env.mjs";
 import { countOpenAIChatTokens } from "~/utils/countTokens";
-import { getModelName } from "./getModelName";
 import { rateLimitErrorMessage } from "~/sharedStrings";
 
-env;
-
-type CompletionResponse = {
+export type CompletionResponse = {
   output: Prisma.InputJsonValue | typeof Prisma.JsonNull;
   statusCode: number;
   errorMessage: string | null;
@@ -22,35 +19,7 @@ type CompletionResponse = {
 };
 
 export async function getCompletion(
-  payload: JSONSerializable,
-  channel?: string,
-): Promise<CompletionResponse> {
-  const modelName = getModelName(payload);
-  if (!modelName)
-    return {
-      output: Prisma.JsonNull,
-      statusCode: 400,
-      errorMessage: "Invalid payload provided",
-      timeToComplete: 0,
-    };
-  if (modelName in OpenAIChatModel) {
-    return getOpenAIChatCompletion(
-      payload as unknown as CompletionCreateParams,
-      env.OPENAI_API_KEY,
-      channel,
-    );
-  }
-  return {
-    output: Prisma.JsonNull,
-    statusCode: 400,
-    errorMessage: "Invalid model provided",
-    timeToComplete: 0,
-  };
-}
-
-export async function getOpenAIChatCompletion(
   payload: CompletionCreateParams,
-  apiKey: string,
   channel?: string,
 ): Promise<CompletionResponse> {
   // If functions are enabled, disable streaming so that we get the full response with token counts
@@ -60,7 +29,7 @@ export async function getOpenAIChatCompletion(
     method: "POST",
     headers: {
       "Content-Type": "application/json",
-      Authorization: `Bearer ${apiKey}`,
+      Authorization: `Bearer ${env.OPENAI_API_KEY}`,
     },
    body: JSON.stringify(payload),
   });
@@ -1,9 +0,0 @@
-import { isObject } from "lodash";
-import { type JSONSerializable, type SupportedModel } from "../types";
-import { type Prisma } from "@prisma/client";
-
-export function getModelName(config: JSONSerializable | Prisma.JsonValue): SupportedModel | null {
-  if (!isObject(config)) return null;
-  if ("model" in config && typeof config.model === "string") return config.model as SupportedModel;
-  return null;
-}
@@ -1,4 +1,4 @@
-import { omit } from "lodash";
+import { omit } from "lodash-es";
 import { env } from "~/env.mjs";
 
 import OpenAI from "openai";
22  src/server/utils/queueLLMRetrievalTask.ts  Normal file
@@ -0,0 +1,22 @@
import { prisma } from "../db";
import { queryLLM } from "../tasks/queryLLM.task";

export const queueLLMRetrievalTask = async (cellId: string) => {
  const updatedCell = await prisma.scenarioVariantCell.update({
    where: {
      id: cellId,
    },
    data: {
      retrievalStatus: "PENDING",
      errorMessage: null,
    },
    include: {
      modelOutput: true,
    },
  });

  // @ts-expect-error we aren't passing the helpers but that's ok
  void queryLLM.task.handler({ scenarioVariantCellId: cellId }, { logger: console });

  return updatedCell;
};
95  src/server/utils/runOneEval.ts  Normal file
@@ -0,0 +1,95 @@
import { type Evaluation, type ModelOutput, type TestScenario } from "@prisma/client";
import { type ChatCompletion } from "openai/resources/chat";
import { type VariableMap, fillTemplate, escapeRegExp, escapeQuotes } from "./fillTemplate";
import { openai } from "./openai";
import dedent from "dedent";

export const runGpt4Eval = async (
  evaluation: Evaluation,
  scenario: TestScenario,
  message: ChatCompletion.Choice.Message,
): Promise<{ result: number; details: string }> => {
  const output = await openai.chat.completions.create({
    model: "gpt-4-0613",
    messages: [
      {
        role: "system",
        content: dedent`
          You are a highly intelligent AI model and have been tasked with evaluating the quality of a simpler model. Your objective is to determine whether the simpler model has produced a successful and correct output. You should return "true" if the output was successful and "false" if it was not. Pay more attention to the semantics of the output than the formatting. Success is defined in the following terms:
          ---
          ${evaluation.value}
        `,
      },
      {
        role: "user",
        content: `Scenario:\n---\n${JSON.stringify(scenario.variableValues, null, 2)}`,
      },
      {
        role: "user",
        content: `The full output of the simpler message:\n---\n${JSON.stringify(
          message.content ?? message.function_call,
          null,
          2,
        )}`,
      },
    ],
    function_call: {
      name: "report_success",
    },
    functions: [
      {
        name: "report_success",
        parameters: {
          type: "object",
          required: ["thoughts", "success"],
          properties: {
            thoughts: {
              type: "string",
              description: "Explain your reasoning for considering this a pass or fail",
            },
            success: {
              type: "boolean",
              description:
                "Whether the simpler model successfully completed the task for this scenario",
            },
          },
        },
      },
    ],
  });

  try {
    const out = JSON.parse(output.choices[0]?.message?.function_call?.arguments ?? "");
    return { result: out.success ? 1 : 0, details: out.thoughts ?? JSON.stringify(out) };
  } catch (e) {
    console.error(e);
    return { result: 0, details: "Error parsing GPT-4 output" };
  }
};

export const runOneEval = async (
  evaluation: Evaluation,
  scenario: TestScenario,
  modelOutput: ModelOutput,
): Promise<{ result: number; details?: string }> => {
  const output = modelOutput.output as unknown as ChatCompletion;

  const message = output?.choices?.[0]?.message;

  if (!message) return { result: 0 };

  const stringifiedMessage = message.content ?? JSON.stringify(message.function_call);

  const matchRegex = escapeRegExp(
    fillTemplate(escapeQuotes(evaluation.value), scenario.variableValues as VariableMap),
  );

  switch (evaluation.evalType) {
    case "CONTAINS":
      return { result: stringifiedMessage.match(matchRegex) !== null ? 1 : 0 };
    case "DOES_NOT_CONTAIN":
      return { result: stringifiedMessage.match(matchRegex) === null ? 1 : 0 };
    case "GPT4_EVAL":
      return await runGpt4Eval(evaluation, scenario, message);
  }
};
7  src/server/utils/shouldStream.ts  Normal file
@@ -0,0 +1,7 @@
import { isObject } from "lodash-es";
import { type JSONSerializable } from "../types";

export const shouldStream = (config: JSONSerializable): boolean => {
  const shouldStream = isObject(config) && "stream" in config && config.stream === true;
  return shouldStream;
};
1  src/server/utils/sleep.ts  Normal file
@@ -0,0 +1 @@
export const sleep = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms));
@@ -2,12 +2,9 @@ import { type RouterOutputs } from "~/utils/api";
 import { type SliceCreator } from "./store";
 import loader from "@monaco-editor/loader";
 import openAITypes from "~/codegen/openai.types.ts.txt";
-import prettier from "prettier/standalone";
-import parserTypescript from "prettier/plugins/typescript";
+import formatPromptConstructor from "~/utils/formatPromptConstructor";
 
-// @ts-expect-error for some reason missing from types
-import parserEstree from "prettier/plugins/estree";
-import { type languages } from "monaco-editor/esm/vs/editor/editor.api";
+export const editorBackground = "#fafafa";
 
 export type SharedVariantEditorSlice = {
   monaco: null | ReturnType<typeof loader.__getMonacoInstance>;
@@ -17,29 +14,12 @@ export type SharedVariantEditorSlice = {
   setScenarios: (scenarios: RouterOutputs["scenarios"]["list"]) => void;
 };
 
-const customFormatter: languages.DocumentFormattingEditProvider = {
-  provideDocumentFormattingEdits: async (model) => {
-    const val = model.getValue();
-    console.log("going to format!", val);
-    const text = await prettier.format(val, {
-      parser: "typescript",
-      plugins: [parserTypescript, parserEstree],
-      // We're showing these in pretty narrow panes so let's keep the print width low
-      printWidth: 60,
-    });
-
-    return [
-      {
-        range: model.getFullModelRange(),
-        text,
-      },
-    ];
-  },
-};
-
 export const createVariantEditorSlice: SliceCreator<SharedVariantEditorSlice> = (set, get) => ({
   monaco: loader.__getMonacoInstance(),
   loadMonaco: async () => {
+    // We only want to run this client-side
+    if (typeof window === "undefined") return;
+
     const monaco = await loader.init();
 
     monaco.editor.defineTheme("customTheme", {
@@ -47,12 +27,13 @@ export const createVariantEditorSlice: SliceCreator<SharedVariantEditorSlice> =
       inherit: true,
       rules: [],
      colors: {
-        "editor.background": "#fafafa",
+        "editor.background": editorBackground,
       },
     });
 
     monaco.languages.typescript.typescriptDefaults.setCompilerOptions({
       allowNonTsExtensions: true,
+      strictNullChecks: true,
       lib: ["esnext"],
     });
 
@@ -66,7 +47,16 @@ export const createVariantEditorSlice: SliceCreator<SharedVariantEditorSlice> =
       monaco.Uri.parse("file:///openai.types.ts"),
     );
 
-    monaco.languages.registerDocumentFormattingEditProvider("typescript", customFormatter);
+    monaco.languages.registerDocumentFormattingEditProvider("typescript", {
+      provideDocumentFormattingEdits: async (model) => {
+        return [
+          {
+            range: model.getFullModelRange(),
+            text: await formatPromptConstructor(model.getValue()),
+          },
+        ];
+      },
+    });
 
     set((state) => {
       state.sharedVariantEditor.monaco = monaco;
@@ -95,7 +85,7 @@ export const createVariantEditorSlice: SliceCreator<SharedVariantEditorSlice> =
     )} as const;
 
     type Scenario = typeof scenarios[number];
-    declare var scenario: Scenario | null;
+    declare var scenario: Scenario | { [key: string]: string };
     `;
 
     const scenariosModel = monaco.editor.getModel(monaco.Uri.parse("file:///scenarios.ts"));
@@ -23,7 +23,7 @@ const openAICompletionTokensToDollars: { [key in OpenAIChatModel]: number } = {
 };
 
 export const calculateTokenCost = (
-  model: SupportedModel | null,
+  model: SupportedModel | string | null,
   numTokens: number,
   isCompletion = false,
 ) => {
10  src/utils/formatPromptConstructor.test.ts  Normal file
@@ -0,0 +1,10 @@
import { expect, test } from "vitest";
import { stripTypes } from "./formatPromptConstructor";

test("stripTypes", () => {
  expect(stripTypes(`const foo: string = "bar";`)).toBe(`const foo = "bar";`);
});

test("stripTypes with invalid syntax", () => {
  expect(stripTypes(`asdf foo: string = "bar"`)).toBe(`asdf foo: string = "bar"`);
});
31  src/utils/formatPromptConstructor.ts  Normal file
@@ -0,0 +1,31 @@
import prettier from "prettier/standalone";
import parserTypescript from "prettier/plugins/typescript";

// @ts-expect-error for some reason missing from types
import parserEstree from "prettier/plugins/estree";

import * as babel from "@babel/standalone";

export function stripTypes(tsCode: string): string {
  const options = {
    presets: ["typescript"],
    filename: "file.ts",
  };

  try {
    const result = babel.transform(tsCode, options);
    return result.code ?? tsCode;
  } catch (error) {
    console.error("Error stripping types", error);
    return tsCode;
  }
}

export default async function formatPromptConstructor(code: string): Promise<string> {
  return await prettier.format(stripTypes(code), {
    parser: "typescript",
    plugins: [parserTypescript, parserEstree],
    // We're showing these in pretty narrow panes so let's keep the print width low
    printWidth: 60,
  });
}
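A quick sketch of how the formatter behaves, based on the functions above; the sample input and the outputs noted in comments are indicative only, since the exact whitespace depends on prettier:

import formatPromptConstructor, { stripTypes } from "~/utils/formatPromptConstructor";

// Babel's TypeScript preset removes annotations without type-checking the code.
stripTypes(`const n: number = 1;`); // => roughly `const n = 1;`

// stripTypes runs first, then prettier reflows the result at printWidth 60.
const example = async () =>
  formatPromptConstructor(`prompt = {model: "gpt-3.5-turbo"   , temperature:0}`);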
@@ -1,5 +1,5 @@
 import { useRouter } from "next/router";
-import { useCallback, useEffect, useState } from "react";
+import { type RefObject, useCallback, useEffect, useRef, useState } from "react";
 import { api } from "~/utils/api";
 
 export const useExperiment = () => {
@@ -49,3 +49,43 @@ export const useModifierKeyLabel = () => {
   }, []);
   return label;
 };
+
+interface Dimensions {
+  left: number;
+  top: number;
+  right: number;
+  bottom: number;
+  width: number;
+  height: number;
+  x: number;
+  y: number;
+}
+
+// get dimensions of an element
+export const useElementDimensions = (): [RefObject<HTMLElement>, Dimensions | undefined] => {
+  const ref = useRef<HTMLElement>(null);
+  const [dimensions, setDimensions] = useState<Dimensions | undefined>();
+
+  useEffect(() => {
+    if (ref.current) {
+      const observer = new ResizeObserver((entries) => {
+        entries.forEach((entry) => {
+          setDimensions(entry.contentRect);
+        });
+      });
+
+      const observedRef = ref.current;
+
+      observer.observe(observedRef);
+
+      // Cleanup the observer on component unmount
+      return () => {
+        if (observedRef) {
+          observer.unobserve(observedRef);
+        }
+      };
+    }
+  }, []);
+
+  return [ref, dimensions];
+};
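A minimal usage sketch for the new hook; the component name, markup, and import path are illustrative assumptions, not part of the diff:

import { type RefObject } from "react";
import { useElementDimensions } from "~/utils/hooks"; // assumed location of the hook shown above

function PanelSizeLabel() {
  const [ref, dimensions] = useElementDimensions();

  return (
    // The hook types its ref as RefObject<HTMLElement>, so narrow it for the element it is attached to.
    <div ref={ref as RefObject<HTMLDivElement>}>
      {dimensions ? `${Math.round(dimensions.width)} x ${Math.round(dimensions.height)} px` : "measuring..."}
    </div>
  );
}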
@@ -5,7 +5,7 @@ import { env } from "~/env.mjs";
 
 const url = env.NEXT_PUBLIC_SOCKET_URL;
 
-export default function useSocket(channel?: string) {
+export default function useSocket(channel?: string | null) {
   const socketRef = useRef<Socket>();
   const [message, setMessage] = useState<ChatCompletion | null>(null);
 