Compare commits
28 Commits
fix-pretti
...
model-prov
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
847753c32b | ||
|
|
332a2101c0 | ||
|
|
ded6678e97 | ||
|
|
2c8c8d07cf | ||
|
|
e885bdd365 | ||
|
|
86dc36a656 | ||
|
|
55c077d604 | ||
|
|
e598e454d0 | ||
|
|
6e3f90cd2f | ||
|
|
eec894e101 | ||
|
|
f797fc3fa4 | ||
|
|
335dc0357f | ||
|
|
e6e2c706c2 | ||
|
|
7d2166b305 | ||
|
|
60765e51ac | ||
|
|
2c4ba6eb9b | ||
|
|
4c97b9f147 | ||
|
|
58892d8b63 | ||
|
|
4fa2dffbcb | ||
|
|
654f8c7cf2 | ||
|
|
d02482468d | ||
|
|
5c6ed22f1d | ||
|
|
2cb623f332 | ||
|
|
1c1cefe286 | ||
|
|
b4aa95edca | ||
|
|
1dcdba04a6 | ||
|
|
e0e64c4207 | ||
|
|
fa5b1ab1c5 |
@@ -18,3 +18,11 @@ DATABASE_URL="postgresql://postgres:postgres@localhost:5432/openpipe?schema=publ
|
|||||||
OPENAI_API_KEY=""
|
OPENAI_API_KEY=""
|
||||||
|
|
||||||
NEXT_PUBLIC_SOCKET_URL="http://localhost:3318"
|
NEXT_PUBLIC_SOCKET_URL="http://localhost:3318"
|
||||||
|
|
||||||
|
# Next Auth
|
||||||
|
NEXTAUTH_SECRET="your_secret"
|
||||||
|
NEXTAUTH_URL="http://localhost:3000"
|
||||||
|
|
||||||
|
# Next Auth Github Provider
|
||||||
|
GITHUB_CLIENT_ID="your_client_id"
|
||||||
|
GITHUB_CLIENT_SECRET="your_secret"
|
||||||
|
|||||||
1
@types/nextjs-routes.d.ts
vendored
1
@types/nextjs-routes.d.ts
vendored
@@ -11,6 +11,7 @@ declare module "nextjs-routes" {
|
|||||||
} from "next";
|
} from "next";
|
||||||
|
|
||||||
export type Route =
|
export type Route =
|
||||||
|
| StaticRoute<"/account/signin">
|
||||||
| DynamicRoute<"/api/auth/[...nextauth]", { "nextauth": string[] }>
|
| DynamicRoute<"/api/auth/[...nextauth]", { "nextauth": string[] }>
|
||||||
| DynamicRoute<"/api/trpc/[trpc]", { "trpc": string }>
|
| DynamicRoute<"/api/trpc/[trpc]", { "trpc": string }>
|
||||||
| DynamicRoute<"/experiments/[id]", { "id": string }>
|
| DynamicRoute<"/experiments/[id]", { "id": string }>
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ FROM base as builder
|
|||||||
|
|
||||||
# Include all NEXT_PUBLIC_* env vars here
|
# Include all NEXT_PUBLIC_* env vars here
|
||||||
ARG NEXT_PUBLIC_POSTHOG_KEY
|
ARG NEXT_PUBLIC_POSTHOG_KEY
|
||||||
ARG NEXT_PUBLIC_IS_PUBLIC_PLAYGROUND
|
|
||||||
ARG NEXT_PUBLIC_SOCKET_URL
|
ARG NEXT_PUBLIC_SOCKET_URL
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|||||||
16
README.md
16
README.md
@@ -4,11 +4,18 @@
|
|||||||
|
|
||||||
OpenPipe is a flexible playground for comparing and optimizing LLM prompts. It lets you quickly generate, test and compare candidate prompts with realistic sample data.
|
OpenPipe is a flexible playground for comparing and optimizing LLM prompts. It lets you quickly generate, test and compare candidate prompts with realistic sample data.
|
||||||
|
|
||||||
**Live Demo:** https://openpipe.ai
|
## Sample Experiments
|
||||||
|
|
||||||
|
These are simple experiments users have created that show how OpenPipe works.
|
||||||
|
|
||||||
|
- [Country Capitals](https://app.openpipe.ai/experiments/11111111-1111-1111-1111-111111111111)
|
||||||
|
- [Reddit User Needs](https://app.openpipe.ai/experiments/22222222-2222-2222-2222-222222222222)
|
||||||
|
- [OpenAI Function Calls](https://app.openpipe.ai/experiments/2ebbdcb3-ed51-456e-87dc-91f72eaf3e2b)
|
||||||
|
- [Activity Classification](https://app.openpipe.ai/experiments/3950940f-ab6b-4b74-841d-7e9dbc4e4ff8)
|
||||||
|
|
||||||
<img src="https://github.com/openpipe/openpipe/assets/176426/fc7624c6-5b65-4d4d-82b7-4a816f3e5678" alt="demo" height="400px">
|
<img src="https://github.com/openpipe/openpipe/assets/176426/fc7624c6-5b65-4d4d-82b7-4a816f3e5678" alt="demo" height="400px">
|
||||||
|
|
||||||
Currently there's a public playground available at [https://openpipe.ai/](https://openpipe.ai/), but the recommended approach is to [run locally](#running-locally).
|
You can use our hosted version of OpenPipe at [https://openpipe.ai]. You can also clone this repository and [run it locally](#running-locally).
|
||||||
|
|
||||||
## High-Level Features
|
## High-Level Features
|
||||||
|
|
||||||
@@ -47,5 +54,6 @@ OpenPipe currently supports GPT-3.5 and GPT-4. Wider model support is planned.
|
|||||||
5. Install the dependencies: `cd openpipe && pnpm install`
|
5. Install the dependencies: `cd openpipe && pnpm install`
|
||||||
6. Create a `.env` file (`cp .env.example .env`) and enter your `OPENAI_API_KEY`.
|
6. Create a `.env` file (`cp .env.example .env`) and enter your `OPENAI_API_KEY`.
|
||||||
7. Update `DATABASE_URL` if necessary to point to your Postgres instance and run `pnpm prisma db push` to create the database.
|
7. Update `DATABASE_URL` if necessary to point to your Postgres instance and run `pnpm prisma db push` to create the database.
|
||||||
8. Start the app: `pnpm dev`.
|
8. Create a [GitHub OAuth App](https://docs.github.com/en/apps/oauth-apps/building-oauth-apps/creating-an-oauth-app) and update the `GITHUB_CLIENT_ID` and `GITHUB_CLIENT_SECRET` values. (Note: a PR to make auth optional when running locally would be a great contribution!)
|
||||||
9. Navigate to [http://localhost:3000](http://localhost:3000)
|
9. Start the app: `pnpm dev`.
|
||||||
|
10. Navigate to [http://localhost:3000](http://localhost:3000)
|
||||||
|
|||||||
17
package.json
17
package.json
@@ -17,9 +17,11 @@
|
|||||||
"lint": "next lint",
|
"lint": "next lint",
|
||||||
"start": "next start",
|
"start": "next start",
|
||||||
"codegen": "tsx src/codegen/export-openai-types.ts",
|
"codegen": "tsx src/codegen/export-openai-types.ts",
|
||||||
"seed": "tsx prisma/seed.ts"
|
"seed": "tsx prisma/seed.ts",
|
||||||
|
"check": "concurrently 'pnpm lint' 'pnpm tsc' 'pnpm prettier . --check'"
|
||||||
},
|
},
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
|
"@apidevtools/json-schema-ref-parser": "^10.1.0",
|
||||||
"@babel/preset-typescript": "^7.22.5",
|
"@babel/preset-typescript": "^7.22.5",
|
||||||
"@babel/standalone": "^7.22.9",
|
"@babel/standalone": "^7.22.9",
|
||||||
"@chakra-ui/next-js": "^2.1.4",
|
"@chakra-ui/next-js": "^2.1.4",
|
||||||
@@ -27,6 +29,7 @@
|
|||||||
"@emotion/react": "^11.11.1",
|
"@emotion/react": "^11.11.1",
|
||||||
"@emotion/server": "^11.11.0",
|
"@emotion/server": "^11.11.0",
|
||||||
"@emotion/styled": "^11.11.0",
|
"@emotion/styled": "^11.11.0",
|
||||||
|
"@fontsource/inconsolata": "^5.0.5",
|
||||||
"@monaco-editor/loader": "^1.3.3",
|
"@monaco-editor/loader": "^1.3.3",
|
||||||
"@next-auth/prisma-adapter": "^1.0.5",
|
"@next-auth/prisma-adapter": "^1.0.5",
|
||||||
"@prisma/client": "^4.14.0",
|
"@prisma/client": "^4.14.0",
|
||||||
@@ -37,6 +40,7 @@
|
|||||||
"@trpc/next": "^10.26.0",
|
"@trpc/next": "^10.26.0",
|
||||||
"@trpc/react-query": "^10.26.0",
|
"@trpc/react-query": "^10.26.0",
|
||||||
"@trpc/server": "^10.26.0",
|
"@trpc/server": "^10.26.0",
|
||||||
|
"ast-types": "^0.14.2",
|
||||||
"chroma-js": "^2.4.2",
|
"chroma-js": "^2.4.2",
|
||||||
"concurrently": "^8.2.0",
|
"concurrently": "^8.2.0",
|
||||||
"cors": "^2.8.5",
|
"cors": "^2.8.5",
|
||||||
@@ -49,7 +53,9 @@
|
|||||||
"graphile-worker": "^0.13.0",
|
"graphile-worker": "^0.13.0",
|
||||||
"immer": "^10.0.2",
|
"immer": "^10.0.2",
|
||||||
"isolated-vm": "^4.5.0",
|
"isolated-vm": "^4.5.0",
|
||||||
|
"json-schema-to-typescript": "^13.0.2",
|
||||||
"json-stringify-pretty-compact": "^4.0.0",
|
"json-stringify-pretty-compact": "^4.0.0",
|
||||||
|
"jsonschema": "^1.4.1",
|
||||||
"lodash-es": "^4.17.21",
|
"lodash-es": "^4.17.21",
|
||||||
"next": "^13.4.2",
|
"next": "^13.4.2",
|
||||||
"next-auth": "^4.22.1",
|
"next-auth": "^4.22.1",
|
||||||
@@ -58,15 +64,22 @@
|
|||||||
"pluralize": "^8.0.0",
|
"pluralize": "^8.0.0",
|
||||||
"posthog-js": "^1.68.4",
|
"posthog-js": "^1.68.4",
|
||||||
"prettier": "^3.0.0",
|
"prettier": "^3.0.0",
|
||||||
|
"prismjs": "^1.29.0",
|
||||||
"react": "18.2.0",
|
"react": "18.2.0",
|
||||||
|
"react-diff-viewer": "^3.1.1",
|
||||||
"react-dom": "18.2.0",
|
"react-dom": "18.2.0",
|
||||||
"react-icons": "^4.10.1",
|
"react-icons": "^4.10.1",
|
||||||
|
"react-select": "^5.7.4",
|
||||||
"react-syntax-highlighter": "^15.5.0",
|
"react-syntax-highlighter": "^15.5.0",
|
||||||
"react-textarea-autosize": "^8.5.0",
|
"react-textarea-autosize": "^8.5.0",
|
||||||
|
"recast": "^0.23.3",
|
||||||
|
"replicate": "^0.12.3",
|
||||||
"socket.io": "^4.7.1",
|
"socket.io": "^4.7.1",
|
||||||
"socket.io-client": "^4.7.1",
|
"socket.io-client": "^4.7.1",
|
||||||
"superjson": "1.12.2",
|
"superjson": "1.12.2",
|
||||||
"tsx": "^3.12.7",
|
"tsx": "^3.12.7",
|
||||||
|
"type-fest": "^4.0.0",
|
||||||
|
"vite-tsconfig-paths": "^4.2.0",
|
||||||
"zod": "^3.21.4",
|
"zod": "^3.21.4",
|
||||||
"zustand": "^4.3.9"
|
"zustand": "^4.3.9"
|
||||||
},
|
},
|
||||||
@@ -78,9 +91,11 @@
|
|||||||
"@types/cors": "^2.8.13",
|
"@types/cors": "^2.8.13",
|
||||||
"@types/eslint": "^8.37.0",
|
"@types/eslint": "^8.37.0",
|
||||||
"@types/express": "^4.17.17",
|
"@types/express": "^4.17.17",
|
||||||
|
"@types/json-schema": "^7.0.12",
|
||||||
"@types/lodash-es": "^4.17.8",
|
"@types/lodash-es": "^4.17.8",
|
||||||
"@types/node": "^18.16.0",
|
"@types/node": "^18.16.0",
|
||||||
"@types/pluralize": "^0.0.30",
|
"@types/pluralize": "^0.0.30",
|
||||||
|
"@types/prismjs": "^1.26.0",
|
||||||
"@types/react": "^18.2.6",
|
"@types/react": "^18.2.6",
|
||||||
"@types/react-dom": "^18.2.4",
|
"@types/react-dom": "^18.2.4",
|
||||||
"@types/react-syntax-highlighter": "^15.5.7",
|
"@types/react-syntax-highlighter": "^15.5.7",
|
||||||
|
|||||||
909
pnpm-lock.yaml
generated
909
pnpm-lock.yaml
generated
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,124 @@
|
|||||||
|
DROP TABLE "Account";
|
||||||
|
DROP TABLE "Session";
|
||||||
|
DROP TABLE "User";
|
||||||
|
DROP TABLE "VerificationToken";
|
||||||
|
|
||||||
|
CREATE TYPE "OrganizationUserRole" AS ENUM ('ADMIN', 'MEMBER', 'VIEWER');
|
||||||
|
|
||||||
|
-- CreateTable
|
||||||
|
CREATE TABLE "Organization" (
|
||||||
|
"id" UUID NOT NULL,
|
||||||
|
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||||
|
"personalOrgUserId" UUID,
|
||||||
|
|
||||||
|
CONSTRAINT "Organization_pkey" PRIMARY KEY ("id")
|
||||||
|
);
|
||||||
|
|
||||||
|
-- CreateTable
|
||||||
|
CREATE TABLE "OrganizationUser" (
|
||||||
|
"id" UUID NOT NULL,
|
||||||
|
"role" "OrganizationUserRole" NOT NULL,
|
||||||
|
"organizationId" UUID NOT NULL,
|
||||||
|
"userId" UUID NOT NULL,
|
||||||
|
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||||
|
|
||||||
|
CONSTRAINT "OrganizationUser_pkey" PRIMARY KEY ("id")
|
||||||
|
);
|
||||||
|
|
||||||
|
-- CreateTable
|
||||||
|
CREATE TABLE "Account" (
|
||||||
|
"id" UUID NOT NULL,
|
||||||
|
"userId" UUID NOT NULL,
|
||||||
|
"type" TEXT NOT NULL,
|
||||||
|
"provider" TEXT NOT NULL,
|
||||||
|
"providerAccountId" TEXT NOT NULL,
|
||||||
|
"refresh_token" TEXT,
|
||||||
|
"refresh_token_expires_in" INTEGER,
|
||||||
|
"access_token" TEXT,
|
||||||
|
"expires_at" INTEGER,
|
||||||
|
"token_type" TEXT,
|
||||||
|
"scope" TEXT,
|
||||||
|
"id_token" TEXT,
|
||||||
|
"session_state" TEXT,
|
||||||
|
|
||||||
|
CONSTRAINT "Account_pkey" PRIMARY KEY ("id")
|
||||||
|
);
|
||||||
|
|
||||||
|
-- CreateTable
|
||||||
|
CREATE TABLE "Session" (
|
||||||
|
"id" UUID NOT NULL,
|
||||||
|
"sessionToken" TEXT NOT NULL,
|
||||||
|
"userId" UUID NOT NULL,
|
||||||
|
"expires" TIMESTAMP(3) NOT NULL,
|
||||||
|
|
||||||
|
CONSTRAINT "Session_pkey" PRIMARY KEY ("id")
|
||||||
|
);
|
||||||
|
|
||||||
|
-- CreateTable
|
||||||
|
CREATE TABLE "User" (
|
||||||
|
"id" UUID NOT NULL,
|
||||||
|
"name" TEXT,
|
||||||
|
"email" TEXT,
|
||||||
|
"emailVerified" TIMESTAMP(3),
|
||||||
|
"image" TEXT,
|
||||||
|
|
||||||
|
CONSTRAINT "User_pkey" PRIMARY KEY ("id")
|
||||||
|
);
|
||||||
|
|
||||||
|
-- CreateTable
|
||||||
|
CREATE TABLE "VerificationToken" (
|
||||||
|
"identifier" TEXT NOT NULL,
|
||||||
|
"token" TEXT NOT NULL,
|
||||||
|
"expires" TIMESTAMP(3) NOT NULL
|
||||||
|
);
|
||||||
|
|
||||||
|
INSERT INTO "Organization" ("id", "updatedAt") VALUES ('11111111-1111-1111-1111-111111111111', CURRENT_TIMESTAMP);
|
||||||
|
|
||||||
|
-- AlterTable add organizationId as NULLABLE
|
||||||
|
ALTER TABLE "Experiment" ADD COLUMN "organizationId" UUID;
|
||||||
|
|
||||||
|
-- Set default organization for existing experiments
|
||||||
|
UPDATE "Experiment" SET "organizationId" = '11111111-1111-1111-1111-111111111111';
|
||||||
|
|
||||||
|
-- AlterTable set organizationId as NOT NULL
|
||||||
|
ALTER TABLE "Experiment" ALTER COLUMN "organizationId" SET NOT NULL;
|
||||||
|
|
||||||
|
|
||||||
|
-- CreateIndex
|
||||||
|
CREATE UNIQUE INDEX "OrganizationUser_organizationId_userId_key" ON "OrganizationUser"("organizationId", "userId");
|
||||||
|
|
||||||
|
-- CreateIndex
|
||||||
|
CREATE UNIQUE INDEX "Account_provider_providerAccountId_key" ON "Account"("provider", "providerAccountId");
|
||||||
|
|
||||||
|
-- CreateIndex
|
||||||
|
CREATE UNIQUE INDEX "Session_sessionToken_key" ON "Session"("sessionToken");
|
||||||
|
|
||||||
|
-- CreateIndex
|
||||||
|
CREATE UNIQUE INDEX "User_email_key" ON "User"("email");
|
||||||
|
|
||||||
|
-- CreateIndex
|
||||||
|
CREATE UNIQUE INDEX "VerificationToken_token_key" ON "VerificationToken"("token");
|
||||||
|
|
||||||
|
-- CreateIndex
|
||||||
|
CREATE UNIQUE INDEX "VerificationToken_identifier_token_key" ON "VerificationToken"("identifier", "token");
|
||||||
|
|
||||||
|
-- AddForeignKey
|
||||||
|
ALTER TABLE "Experiment" ADD CONSTRAINT "Experiment_organizationId_fkey" FOREIGN KEY ("organizationId") REFERENCES "Organization"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||||
|
|
||||||
|
-- AddForeignKey
|
||||||
|
ALTER TABLE "OrganizationUser" ADD CONSTRAINT "OrganizationUser_organizationId_fkey" FOREIGN KEY ("organizationId") REFERENCES "Organization"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||||
|
|
||||||
|
-- AddForeignKey
|
||||||
|
ALTER TABLE "OrganizationUser" ADD CONSTRAINT "OrganizationUser_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||||
|
|
||||||
|
-- AddForeignKey
|
||||||
|
ALTER TABLE "Account" ADD CONSTRAINT "Account_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||||
|
|
||||||
|
-- AddForeignKey
|
||||||
|
ALTER TABLE "Session" ADD CONSTRAINT "Session_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||||
|
|
||||||
|
CREATE UNIQUE INDEX "Organization_personalOrgUserId_key" ON "Organization"("personalOrgUserId");
|
||||||
|
|
||||||
|
ALTER TABLE "Organization" ADD CONSTRAINT "Organization_personalOrgUserId_fkey" FOREIGN KEY ("personalOrgUserId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
/*
|
||||||
|
Warnings:
|
||||||
|
|
||||||
|
- You are about to drop the column `completionTokens` on the `ScenarioVariantCell` table. All the data in the column will be lost.
|
||||||
|
- You are about to drop the column `inputHash` on the `ScenarioVariantCell` table. All the data in the column will be lost.
|
||||||
|
- You are about to drop the column `output` on the `ScenarioVariantCell` table. All the data in the column will be lost.
|
||||||
|
- You are about to drop the column `promptTokens` on the `ScenarioVariantCell` table. All the data in the column will be lost.
|
||||||
|
- You are about to drop the column `timeToComplete` on the `ScenarioVariantCell` table. All the data in the column will be lost.
|
||||||
|
|
||||||
|
*/
|
||||||
|
-- AlterTable
|
||||||
|
ALTER TABLE "ScenarioVariantCell" DROP COLUMN "completionTokens",
|
||||||
|
DROP COLUMN "inputHash",
|
||||||
|
DROP COLUMN "output",
|
||||||
|
DROP COLUMN "promptTokens",
|
||||||
|
DROP COLUMN "timeToComplete";
|
||||||
@@ -0,0 +1,8 @@
|
|||||||
|
/*
|
||||||
|
Warnings:
|
||||||
|
|
||||||
|
- You are about to drop the column `model` on the `PromptVariant` table. All the data in the column will be lost.
|
||||||
|
|
||||||
|
*/
|
||||||
|
-- AlterTable
|
||||||
|
ALTER TABLE "ModelOutput" ADD COLUMN "cost" DOUBLE PRECISION;
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
-- Add new columns allowing NULL values
|
||||||
|
ALTER TABLE "PromptVariant"
|
||||||
|
ADD COLUMN "constructFnVersion" INTEGER,
|
||||||
|
ADD COLUMN "modelProvider" TEXT;
|
||||||
|
|
||||||
|
-- Update existing records to have the default values
|
||||||
|
UPDATE "PromptVariant"
|
||||||
|
SET "constructFnVersion" = 1,
|
||||||
|
"modelProvider" = 'openai/ChatCompletion'
|
||||||
|
WHERE "constructFnVersion" IS NULL OR "modelProvider" IS NULL;
|
||||||
|
|
||||||
|
-- Alter table to set NOT NULL constraint
|
||||||
|
ALTER TABLE "PromptVariant"
|
||||||
|
ALTER COLUMN "constructFnVersion" SET NOT NULL,
|
||||||
|
ALTER COLUMN "modelProvider" SET NOT NULL;
|
||||||
|
|
||||||
|
ALTER TABLE "ScenarioVariantCell" ADD COLUMN "prompt" JSONB;
|
||||||
@@ -16,8 +16,12 @@ model Experiment {
|
|||||||
|
|
||||||
sortIndex Int @default(0)
|
sortIndex Int @default(0)
|
||||||
|
|
||||||
|
organizationId String @db.Uuid
|
||||||
|
organization Organization? @relation(fields: [organizationId], references: [id], onDelete: Cascade)
|
||||||
|
|
||||||
createdAt DateTime @default(now())
|
createdAt DateTime @default(now())
|
||||||
updatedAt DateTime @updatedAt
|
updatedAt DateTime @updatedAt
|
||||||
|
|
||||||
TemplateVariable TemplateVariable[]
|
TemplateVariable TemplateVariable[]
|
||||||
PromptVariant PromptVariant[]
|
PromptVariant PromptVariant[]
|
||||||
TestScenario TestScenario[]
|
TestScenario TestScenario[]
|
||||||
@@ -29,7 +33,9 @@ model PromptVariant {
|
|||||||
|
|
||||||
label String
|
label String
|
||||||
constructFn String
|
constructFn String
|
||||||
|
constructFnVersion Int
|
||||||
model String
|
model String
|
||||||
|
modelProvider String
|
||||||
|
|
||||||
uiId String @default(uuid()) @db.Uuid
|
uiId String @default(uuid()) @db.Uuid
|
||||||
visible Boolean @default(true)
|
visible Boolean @default(true)
|
||||||
@@ -84,21 +90,17 @@ enum CellRetrievalStatus {
|
|||||||
model ScenarioVariantCell {
|
model ScenarioVariantCell {
|
||||||
id String @id @default(uuid()) @db.Uuid
|
id String @id @default(uuid()) @db.Uuid
|
||||||
|
|
||||||
inputHash String? // TODO: Remove once migration is complete
|
|
||||||
output Json? // TODO: Remove once migration is complete
|
|
||||||
statusCode Int?
|
statusCode Int?
|
||||||
errorMessage String?
|
errorMessage String?
|
||||||
timeToComplete Int? @default(0) // TODO: Remove once migration is complete
|
|
||||||
retryTime DateTime?
|
retryTime DateTime?
|
||||||
streamingChannel String?
|
streamingChannel String?
|
||||||
retrievalStatus CellRetrievalStatus @default(COMPLETE)
|
retrievalStatus CellRetrievalStatus @default(COMPLETE)
|
||||||
|
|
||||||
promptTokens Int? // TODO: Remove once migration is complete
|
|
||||||
completionTokens Int? // TODO: Remove once migration is complete
|
|
||||||
modelOutput ModelOutput?
|
modelOutput ModelOutput?
|
||||||
|
|
||||||
promptVariantId String @db.Uuid
|
promptVariantId String @db.Uuid
|
||||||
promptVariant PromptVariant @relation(fields: [promptVariantId], references: [id], onDelete: Cascade)
|
promptVariant PromptVariant @relation(fields: [promptVariantId], references: [id], onDelete: Cascade)
|
||||||
|
prompt Json?
|
||||||
|
|
||||||
testScenarioId String @db.Uuid
|
testScenarioId String @db.Uuid
|
||||||
testScenario TestScenario @relation(fields: [testScenarioId], references: [id], onDelete: Cascade)
|
testScenario TestScenario @relation(fields: [testScenarioId], references: [id], onDelete: Cascade)
|
||||||
@@ -115,6 +117,7 @@ model ModelOutput {
|
|||||||
inputHash String
|
inputHash String
|
||||||
output Json
|
output Json
|
||||||
timeToComplete Int @default(0)
|
timeToComplete Int @default(0)
|
||||||
|
cost Float?
|
||||||
promptTokens Int?
|
promptTokens Int?
|
||||||
completionTokens Int?
|
completionTokens Int?
|
||||||
|
|
||||||
@@ -169,19 +172,53 @@ model OutputEvaluation {
|
|||||||
@@unique([modelOutputId, evaluationId])
|
@@unique([modelOutputId, evaluationId])
|
||||||
}
|
}
|
||||||
|
|
||||||
// Necessary for Next auth
|
model Organization {
|
||||||
|
id String @id @default(uuid()) @db.Uuid
|
||||||
|
personalOrgUserId String? @unique @db.Uuid
|
||||||
|
PersonalOrgUser User? @relation(fields: [personalOrgUserId], references: [id], onDelete: Cascade)
|
||||||
|
|
||||||
|
createdAt DateTime @default(now())
|
||||||
|
updatedAt DateTime @updatedAt
|
||||||
|
OrganizationUser OrganizationUser[]
|
||||||
|
Experiment Experiment[]
|
||||||
|
}
|
||||||
|
|
||||||
|
enum OrganizationUserRole {
|
||||||
|
ADMIN
|
||||||
|
MEMBER
|
||||||
|
VIEWER
|
||||||
|
}
|
||||||
|
|
||||||
|
model OrganizationUser {
|
||||||
|
id String @id @default(uuid()) @db.Uuid
|
||||||
|
|
||||||
|
role OrganizationUserRole
|
||||||
|
|
||||||
|
organizationId String @db.Uuid
|
||||||
|
organization Organization? @relation(fields: [organizationId], references: [id], onDelete: Cascade)
|
||||||
|
|
||||||
|
userId String @db.Uuid
|
||||||
|
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||||
|
|
||||||
|
createdAt DateTime @default(now())
|
||||||
|
updatedAt DateTime @updatedAt
|
||||||
|
|
||||||
|
@@unique([organizationId, userId])
|
||||||
|
}
|
||||||
|
|
||||||
model Account {
|
model Account {
|
||||||
id String @id @default(cuid())
|
id String @id @default(uuid()) @db.Uuid
|
||||||
userId String
|
userId String @db.Uuid
|
||||||
type String
|
type String
|
||||||
provider String
|
provider String
|
||||||
providerAccountId String
|
providerAccountId String
|
||||||
refresh_token String? // @db.Text
|
refresh_token String? @db.Text
|
||||||
access_token String? // @db.Text
|
refresh_token_expires_in Int?
|
||||||
|
access_token String? @db.Text
|
||||||
expires_at Int?
|
expires_at Int?
|
||||||
token_type String?
|
token_type String?
|
||||||
scope String?
|
scope String?
|
||||||
id_token String? // @db.Text
|
id_token String? @db.Text
|
||||||
session_state String?
|
session_state String?
|
||||||
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||||
|
|
||||||
@@ -189,21 +226,23 @@ model Account {
|
|||||||
}
|
}
|
||||||
|
|
||||||
model Session {
|
model Session {
|
||||||
id String @id @default(cuid())
|
id String @id @default(uuid()) @db.Uuid
|
||||||
sessionToken String @unique
|
sessionToken String @unique
|
||||||
userId String
|
userId String @db.Uuid
|
||||||
expires DateTime
|
expires DateTime
|
||||||
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||||
}
|
}
|
||||||
|
|
||||||
model User {
|
model User {
|
||||||
id String @id @default(cuid())
|
id String @id @default(uuid()) @db.Uuid
|
||||||
name String?
|
name String?
|
||||||
email String? @unique
|
email String? @unique
|
||||||
emailVerified DateTime?
|
emailVerified DateTime?
|
||||||
image String?
|
image String?
|
||||||
accounts Account[]
|
accounts Account[]
|
||||||
sessions Session[]
|
sessions Session[]
|
||||||
|
OrganizationUser OrganizationUser[]
|
||||||
|
Organization Organization[]
|
||||||
}
|
}
|
||||||
|
|
||||||
model VerificationToken {
|
model VerificationToken {
|
||||||
|
|||||||
@@ -2,45 +2,54 @@ import { prisma } from "~/server/db";
|
|||||||
import dedent from "dedent";
|
import dedent from "dedent";
|
||||||
import { generateNewCell } from "~/server/utils/generateNewCell";
|
import { generateNewCell } from "~/server/utils/generateNewCell";
|
||||||
|
|
||||||
const experimentId = "11111111-1111-1111-1111-111111111111";
|
const defaultId = "11111111-1111-1111-1111-111111111111";
|
||||||
|
|
||||||
|
await prisma.organization.deleteMany({
|
||||||
|
where: { id: defaultId },
|
||||||
|
});
|
||||||
|
await prisma.organization.create({
|
||||||
|
data: { id: defaultId },
|
||||||
|
});
|
||||||
|
|
||||||
// Delete the existing experiment
|
|
||||||
await prisma.experiment.deleteMany({
|
await prisma.experiment.deleteMany({
|
||||||
where: {
|
where: {
|
||||||
id: experimentId,
|
id: defaultId,
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
await prisma.experiment.create({
|
await prisma.experiment.create({
|
||||||
data: {
|
data: {
|
||||||
id: experimentId,
|
id: defaultId,
|
||||||
label: "Country Capitals Example",
|
label: "Country Capitals Example",
|
||||||
|
organizationId: defaultId,
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
await prisma.scenarioVariantCell.deleteMany({
|
await prisma.scenarioVariantCell.deleteMany({
|
||||||
where: {
|
where: {
|
||||||
promptVariant: {
|
promptVariant: {
|
||||||
experimentId,
|
experimentId: defaultId,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
await prisma.promptVariant.deleteMany({
|
await prisma.promptVariant.deleteMany({
|
||||||
where: {
|
where: {
|
||||||
experimentId,
|
experimentId: defaultId,
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
await prisma.promptVariant.createMany({
|
await prisma.promptVariant.createMany({
|
||||||
data: [
|
data: [
|
||||||
{
|
{
|
||||||
experimentId,
|
experimentId: defaultId,
|
||||||
label: "Prompt Variant 1",
|
label: "Prompt Variant 1",
|
||||||
sortIndex: 0,
|
sortIndex: 0,
|
||||||
model: "gpt-3.5-turbo-0613",
|
model: "gpt-3.5-turbo-0613",
|
||||||
|
modelProvider: "openai/ChatCompletion",
|
||||||
|
constructFnVersion: 1,
|
||||||
constructFn: dedent`
|
constructFn: dedent`
|
||||||
prompt = {
|
definePrompt("openai/ChatCompletion", {
|
||||||
model: "gpt-3.5-turbo-0613",
|
model: "gpt-3.5-turbo-0613",
|
||||||
messages: [
|
messages: [
|
||||||
{
|
{
|
||||||
@@ -49,15 +58,17 @@ await prisma.promptVariant.createMany({
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
temperature: 0,
|
temperature: 0,
|
||||||
}`,
|
})`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
experimentId,
|
experimentId: defaultId,
|
||||||
label: "Prompt Variant 2",
|
label: "Prompt Variant 2",
|
||||||
sortIndex: 1,
|
sortIndex: 1,
|
||||||
model: "gpt-3.5-turbo-0613",
|
model: "gpt-3.5-turbo-0613",
|
||||||
|
modelProvider: "openai/ChatCompletion",
|
||||||
|
constructFnVersion: 1,
|
||||||
constructFn: dedent`
|
constructFn: dedent`
|
||||||
prompt = {
|
definePrompt("openai/ChatCompletion", {
|
||||||
model: "gpt-3.5-turbo-0613",
|
model: "gpt-3.5-turbo-0613",
|
||||||
messages: [
|
messages: [
|
||||||
{
|
{
|
||||||
@@ -66,21 +77,21 @@ await prisma.promptVariant.createMany({
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
temperature: 0,
|
temperature: 0,
|
||||||
}`,
|
})`,
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
});
|
});
|
||||||
|
|
||||||
await prisma.templateVariable.deleteMany({
|
await prisma.templateVariable.deleteMany({
|
||||||
where: {
|
where: {
|
||||||
experimentId,
|
experimentId: defaultId,
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
await prisma.templateVariable.createMany({
|
await prisma.templateVariable.createMany({
|
||||||
data: [
|
data: [
|
||||||
{
|
{
|
||||||
experimentId,
|
experimentId: defaultId,
|
||||||
label: "country",
|
label: "country",
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
@@ -88,28 +99,28 @@ await prisma.templateVariable.createMany({
|
|||||||
|
|
||||||
await prisma.testScenario.deleteMany({
|
await prisma.testScenario.deleteMany({
|
||||||
where: {
|
where: {
|
||||||
experimentId,
|
experimentId: defaultId,
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
await prisma.testScenario.createMany({
|
await prisma.testScenario.createMany({
|
||||||
data: [
|
data: [
|
||||||
{
|
{
|
||||||
experimentId,
|
experimentId: defaultId,
|
||||||
sortIndex: 0,
|
sortIndex: 0,
|
||||||
variableValues: {
|
variableValues: {
|
||||||
country: "Spain",
|
country: "Spain",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
experimentId,
|
experimentId: defaultId,
|
||||||
sortIndex: 1,
|
sortIndex: 1,
|
||||||
variableValues: {
|
variableValues: {
|
||||||
country: "USA",
|
country: "USA",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
experimentId,
|
experimentId: defaultId,
|
||||||
sortIndex: 2,
|
sortIndex: 2,
|
||||||
variableValues: {
|
variableValues: {
|
||||||
country: "Chile",
|
country: "Chile",
|
||||||
@@ -120,13 +131,13 @@ await prisma.testScenario.createMany({
|
|||||||
|
|
||||||
const variants = await prisma.promptVariant.findMany({
|
const variants = await prisma.promptVariant.findMany({
|
||||||
where: {
|
where: {
|
||||||
experimentId,
|
experimentId: defaultId,
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
const scenarios = await prisma.testScenario.findMany({
|
const scenarios = await prisma.testScenario.findMany({
|
||||||
where: {
|
where: {
|
||||||
experimentId,
|
experimentId: defaultId,
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|||||||
1126
prisma/seedDemo.ts
1126
prisma/seedDemo.ts
File diff suppressed because one or more lines are too long
@@ -12,7 +12,7 @@ services:
|
|||||||
dockerContext: .
|
dockerContext: .
|
||||||
plan: standard
|
plan: standard
|
||||||
domains:
|
domains:
|
||||||
- openpipe.ai
|
- app.openpipe.ai
|
||||||
envVars:
|
envVars:
|
||||||
- key: NODE_ENV
|
- key: NODE_ENV
|
||||||
value: production
|
value: production
|
||||||
|
|||||||
@@ -1,48 +0,0 @@
|
|||||||
/* eslint-disable @typescript-eslint/no-var-requires */
|
|
||||||
|
|
||||||
import YAML from "yaml";
|
|
||||||
import fs from "fs";
|
|
||||||
import path from "path";
|
|
||||||
import { openapiSchemaToJsonSchema } from "@openapi-contrib/openapi-schema-to-json-schema";
|
|
||||||
import assert from "assert";
|
|
||||||
import { type AcceptibleInputSchema } from "@openapi-contrib/openapi-schema-to-json-schema/dist/mjs/openapi-schema-types";
|
|
||||||
|
|
||||||
const OPENAPI_URL =
|
|
||||||
"https://raw.githubusercontent.com/openai/openai-openapi/0c432eb66fd0c758fd8b9bd69db41c1096e5f4db/openapi.yaml";
|
|
||||||
|
|
||||||
const convertOpenApiToJsonSchema = async (url: string) => {
|
|
||||||
// Fetch the openapi document
|
|
||||||
const response = await fetch(url);
|
|
||||||
const openApiYaml = await response.text();
|
|
||||||
|
|
||||||
// Parse the yaml document
|
|
||||||
const openApiDocument = YAML.parse(openApiYaml) as AcceptibleInputSchema;
|
|
||||||
|
|
||||||
// Convert the openapi schema to json schema
|
|
||||||
const jsonSchema = openapiSchemaToJsonSchema(openApiDocument);
|
|
||||||
|
|
||||||
const modelProperty = jsonSchema.components.schemas.CreateChatCompletionRequest.properties.model;
|
|
||||||
|
|
||||||
assert(modelProperty.oneOf.length === 2, "Expected model to have oneOf length of 2");
|
|
||||||
|
|
||||||
// We need to do a bit of surgery here since the Monaco editor doesn't like
|
|
||||||
// the fact that the schema says `model` can be either a string or an enum,
|
|
||||||
// and displays a warning in the editor. Let's stick with just an enum for
|
|
||||||
// now and drop the string option.
|
|
||||||
modelProperty.type = "string";
|
|
||||||
modelProperty.enum = modelProperty.oneOf[1].enum;
|
|
||||||
modelProperty.oneOf = undefined;
|
|
||||||
|
|
||||||
// Get the directory of the current script
|
|
||||||
const currentDirectory = path.dirname(import.meta.url).replace("file://", "");
|
|
||||||
|
|
||||||
// Write the JSON schema to a file in the current directory
|
|
||||||
fs.writeFileSync(
|
|
||||||
path.join(currentDirectory, "openai.schema.json"),
|
|
||||||
JSON.stringify(jsonSchema, null, 2),
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
convertOpenApiToJsonSchema(OPENAPI_URL)
|
|
||||||
.then(() => console.log("JSON schema has been written successfully."))
|
|
||||||
.catch((err) => console.error(err));
|
|
||||||
@@ -1,52 +0,0 @@
|
|||||||
import fs from "fs";
|
|
||||||
import path from "path";
|
|
||||||
import openapiTS, { type OpenAPI3 } from "openapi-typescript";
|
|
||||||
import YAML from "yaml";
|
|
||||||
import { pick } from "lodash-es";
|
|
||||||
import assert from "assert";
|
|
||||||
|
|
||||||
const OPENAPI_URL =
|
|
||||||
"https://raw.githubusercontent.com/openai/openai-openapi/0c432eb66fd0c758fd8b9bd69db41c1096e5f4db/openapi.yaml";
|
|
||||||
|
|
||||||
// Generate TypeScript types from OpenAPI
|
|
||||||
|
|
||||||
const schema = await fetch(OPENAPI_URL)
|
|
||||||
.then((res) => res.text())
|
|
||||||
.then((txt) => YAML.parse(txt) as OpenAPI3);
|
|
||||||
|
|
||||||
console.log(schema.components?.schemas?.CreateChatCompletionRequest);
|
|
||||||
|
|
||||||
// @ts-expect-error just assume this works, the assert will catch it if it doesn't
|
|
||||||
const modelProperty = schema.components?.schemas?.CreateChatCompletionRequest?.properties?.model;
|
|
||||||
|
|
||||||
assert(modelProperty.oneOf.length === 2, "Expected model to have oneOf length of 2");
|
|
||||||
|
|
||||||
// We need to do a bit of surgery here since the Monaco editor doesn't like
|
|
||||||
// the fact that the schema says `model` can be either a string or an enum,
|
|
||||||
// and displays a warning in the editor. Let's stick with just an enum for
|
|
||||||
// now and drop the string option.
|
|
||||||
modelProperty.type = "string";
|
|
||||||
modelProperty.enum = modelProperty.oneOf[1].enum;
|
|
||||||
modelProperty.oneOf = undefined;
|
|
||||||
|
|
||||||
delete schema["paths"];
|
|
||||||
assert(schema.components?.schemas);
|
|
||||||
schema.components.schemas = pick(schema.components?.schemas, [
|
|
||||||
"CreateChatCompletionRequest",
|
|
||||||
"ChatCompletionRequestMessage",
|
|
||||||
"ChatCompletionFunctions",
|
|
||||||
"ChatCompletionFunctionParameters",
|
|
||||||
]);
|
|
||||||
console.log(schema);
|
|
||||||
|
|
||||||
let openApiTypes = await openapiTS(schema);
|
|
||||||
|
|
||||||
// Remove the `export` from any line that starts with `export`
|
|
||||||
openApiTypes = openApiTypes.replaceAll("\nexport ", "\n");
|
|
||||||
|
|
||||||
// Get the directory of the current script
|
|
||||||
const currentDirectory = path.dirname(import.meta.url).replace("file://", "");
|
|
||||||
|
|
||||||
// Write the TypeScript types. We only want to use this in our in-app editor, so
|
|
||||||
// save as a .txt so VS Code doesn't try to auto-import definitions from it.
|
|
||||||
fs.writeFileSync(path.join(currentDirectory, "openai.types.ts.txt"), openApiTypes);
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,148 +0,0 @@
|
|||||||
/**
|
|
||||||
* This file was auto-generated by openapi-typescript.
|
|
||||||
* Do not make direct changes to the file.
|
|
||||||
*/
|
|
||||||
|
|
||||||
|
|
||||||
/** OneOf type helpers */
|
|
||||||
type Without<T, U> = { [P in Exclude<keyof T, keyof U>]?: never };
|
|
||||||
type XOR<T, U> = (T | U) extends object ? (Without<T, U> & U) | (Without<U, T> & T) : T | U;
|
|
||||||
type OneOf<T extends any[]> = T extends [infer Only] ? Only : T extends [infer A, infer B, ...infer Rest] ? OneOf<[XOR<A, B>, ...Rest]> : never;
|
|
||||||
|
|
||||||
type paths = Record<string, never>;
|
|
||||||
|
|
||||||
type webhooks = Record<string, never>;
|
|
||||||
|
|
||||||
interface components {
|
|
||||||
schemas: {
|
|
||||||
CreateChatCompletionRequest: {
|
|
||||||
/**
|
|
||||||
* @description ID of the model to use. See the [model endpoint compatibility](/docs/models/model-endpoint-compatibility) table for details on which models work with the Chat API.
|
|
||||||
* @example gpt-3.5-turbo
|
|
||||||
* @enum {string}
|
|
||||||
*/
|
|
||||||
model: "gpt-4" | "gpt-4-0613" | "gpt-4-32k" | "gpt-4-32k-0613" | "gpt-3.5-turbo" | "gpt-3.5-turbo-16k" | "gpt-3.5-turbo-0613" | "gpt-3.5-turbo-16k-0613";
|
|
||||||
/** @description A list of messages comprising the conversation so far. [Example Python code](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb). */
|
|
||||||
messages: (components["schemas"]["ChatCompletionRequestMessage"])[];
|
|
||||||
/** @description A list of functions the model may generate JSON inputs for. */
|
|
||||||
functions?: (components["schemas"]["ChatCompletionFunctions"])[];
|
|
||||||
/** @description Controls how the model responds to function calls. "none" means the model does not call a function, and responds to the end-user. "auto" means the model can pick between an end-user or calling a function. Specifying a particular function via `{"name":\ "my_function"}` forces the model to call that function. "none" is the default when no functions are present. "auto" is the default if functions are present. */
|
|
||||||
function_call?: OneOf<["none" | "auto", {
|
|
||||||
/** @description The name of the function to call. */
|
|
||||||
name: string;
|
|
||||||
}]>;
|
|
||||||
/**
|
|
||||||
* @description What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.
|
|
||||||
*
|
|
||||||
* We generally recommend altering this or `top_p` but not both.
|
|
||||||
*
|
|
||||||
* @default 1
|
|
||||||
* @example 1
|
|
||||||
*/
|
|
||||||
temperature?: number | null;
|
|
||||||
/**
|
|
||||||
* @description An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.
|
|
||||||
*
|
|
||||||
* We generally recommend altering this or `temperature` but not both.
|
|
||||||
*
|
|
||||||
* @default 1
|
|
||||||
* @example 1
|
|
||||||
*/
|
|
||||||
top_p?: number | null;
|
|
||||||
/**
|
|
||||||
* @description How many chat completion choices to generate for each input message.
|
|
||||||
* @default 1
|
|
||||||
* @example 1
|
|
||||||
*/
|
|
||||||
n?: number | null;
|
|
||||||
/**
|
|
||||||
* @description If set, partial message deltas will be sent, like in ChatGPT. Tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream terminated by a `data: [DONE]` message. [Example Python code](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_stream_completions.ipynb).
|
|
||||||
*
|
|
||||||
* @default false
|
|
||||||
*/
|
|
||||||
stream?: boolean | null;
|
|
||||||
/**
|
|
||||||
* @description Up to 4 sequences where the API will stop generating further tokens.
|
|
||||||
*
|
|
||||||
* @default null
|
|
||||||
*/
|
|
||||||
stop?: (string | null) | (string)[];
|
|
||||||
/**
|
|
||||||
* @description The maximum number of [tokens](/tokenizer) to generate in the chat completion.
|
|
||||||
*
|
|
||||||
* The total length of input tokens and generated tokens is limited by the model's context length. [Example Python code](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb) for counting tokens.
|
|
||||||
*
|
|
||||||
* @default inf
|
|
||||||
*/
|
|
||||||
max_tokens?: number;
|
|
||||||
/**
|
|
||||||
* @description Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
|
|
||||||
*
|
|
||||||
* [See more information about frequency and presence penalties.](/docs/api-reference/parameter-details)
|
|
||||||
*
|
|
||||||
* @default 0
|
|
||||||
*/
|
|
||||||
presence_penalty?: number | null;
|
|
||||||
/**
|
|
||||||
* @description Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.
|
|
||||||
*
|
|
||||||
* [See more information about frequency and presence penalties.](/docs/api-reference/parameter-details)
|
|
||||||
*
|
|
||||||
* @default 0
|
|
||||||
*/
|
|
||||||
frequency_penalty?: number | null;
|
|
||||||
/**
|
|
||||||
* @description Modify the likelihood of specified tokens appearing in the completion.
|
|
||||||
*
|
|
||||||
* Accepts a json object that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token.
|
|
||||||
*
|
|
||||||
* @default null
|
|
||||||
*/
|
|
||||||
logit_bias?: Record<string, unknown> | null;
|
|
||||||
/**
|
|
||||||
* @description A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).
|
|
||||||
*
|
|
||||||
* @example user-1234
|
|
||||||
*/
|
|
||||||
user?: string;
|
|
||||||
};
|
|
||||||
ChatCompletionRequestMessage: {
|
|
||||||
/**
|
|
||||||
* @description The role of the messages author. One of `system`, `user`, `assistant`, or `function`.
|
|
||||||
* @enum {string}
|
|
||||||
*/
|
|
||||||
role: "system" | "user" | "assistant" | "function";
|
|
||||||
/** @description The contents of the message. `content` is required for all messages except assistant messages with function calls. */
|
|
||||||
content?: string;
|
|
||||||
/** @description The name of the author of this message. `name` is required if role is `function`, and it should be the name of the function whose response is in the `content`. May contain a-z, A-Z, 0-9, and underscores, with a maximum length of 64 characters. */
|
|
||||||
name?: string;
|
|
||||||
/** @description The name and arguments of a function that should be called, as generated by the model. */
|
|
||||||
function_call?: {
|
|
||||||
/** @description The name of the function to call. */
|
|
||||||
name?: string;
|
|
||||||
/** @description The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function. */
|
|
||||||
arguments?: string;
|
|
||||||
};
|
|
||||||
};
|
|
||||||
ChatCompletionFunctions: {
|
|
||||||
/** @description The name of the function to be called. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64. */
|
|
||||||
name: string;
|
|
||||||
/** @description The description of what the function does. */
|
|
||||||
description?: string;
|
|
||||||
parameters?: components["schemas"]["ChatCompletionFunctionParameters"];
|
|
||||||
};
|
|
||||||
/** @description The parameters the functions accepts, described as a JSON Schema object. See the [guide](/docs/guides/gpt/function-calling) for examples, and the [JSON Schema reference](https://json-schema.org/understanding-json-schema/) for documentation about the format. */
|
|
||||||
ChatCompletionFunctionParameters: {
|
|
||||||
[key: string]: unknown;
|
|
||||||
};
|
|
||||||
};
|
|
||||||
responses: never;
|
|
||||||
parameters: never;
|
|
||||||
requestBodies: never;
|
|
||||||
headers: never;
|
|
||||||
pathItems: never;
|
|
||||||
}
|
|
||||||
|
|
||||||
type external = Record<string, never>;
|
|
||||||
|
|
||||||
type operations = Record<string, never>;
|
|
||||||
@@ -1,6 +0,0 @@
|
|||||||
{
|
|
||||||
"compilerOptions": {
|
|
||||||
"target": "esnext",
|
|
||||||
"moduleResolution": "nodenext"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
49
src/components/OutputsTable/FloatingLabelInput.tsx
Normal file
49
src/components/OutputsTable/FloatingLabelInput.tsx
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
import { FormLabel, FormControl, type TextareaProps } from "@chakra-ui/react";
|
||||||
|
import { useState } from "react";
|
||||||
|
import AutoResizeTextArea from "../AutoResizeTextArea";
|
||||||
|
|
||||||
|
export const FloatingLabelInput = ({
|
||||||
|
label,
|
||||||
|
value,
|
||||||
|
...props
|
||||||
|
}: { label: string; value: string } & TextareaProps) => {
|
||||||
|
const [isFocused, setIsFocused] = useState(false);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<FormControl position="relative">
|
||||||
|
<FormLabel
|
||||||
|
position="absolute"
|
||||||
|
left="10px"
|
||||||
|
top={isFocused || !!value ? 0 : 3}
|
||||||
|
transform={isFocused || !!value ? "translateY(-50%)" : "translateY(0)"}
|
||||||
|
fontSize={isFocused || !!value ? "12px" : "16px"}
|
||||||
|
transition="all 0.15s"
|
||||||
|
zIndex="100"
|
||||||
|
bg="white"
|
||||||
|
px={1}
|
||||||
|
mt={0}
|
||||||
|
mb={2}
|
||||||
|
lineHeight="1"
|
||||||
|
pointerEvents="none"
|
||||||
|
color={isFocused ? "blue.500" : "gray.500"}
|
||||||
|
>
|
||||||
|
{label}
|
||||||
|
</FormLabel>
|
||||||
|
<AutoResizeTextArea
|
||||||
|
px={3}
|
||||||
|
pt={3}
|
||||||
|
pb={2}
|
||||||
|
onFocus={() => setIsFocused(true)}
|
||||||
|
onBlur={() => setIsFocused(false)}
|
||||||
|
borderRadius="md"
|
||||||
|
borderColor={isFocused ? "blue.500" : "gray.400"}
|
||||||
|
autoComplete="off"
|
||||||
|
value={value}
|
||||||
|
maxHeight={32}
|
||||||
|
overflowY="auto"
|
||||||
|
overflowX="hidden"
|
||||||
|
{...props}
|
||||||
|
/>
|
||||||
|
</FormControl>
|
||||||
|
);
|
||||||
|
};
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
import { Button, type ButtonProps, HStack, Spinner, Icon } from "@chakra-ui/react";
|
import { Button, type ButtonProps, HStack, Spinner, Icon } from "@chakra-ui/react";
|
||||||
import { BsPlus } from "react-icons/bs";
|
import { BsPlus } from "react-icons/bs";
|
||||||
import { api } from "~/utils/api";
|
import { api } from "~/utils/api";
|
||||||
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
|
import { useExperiment, useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
|
||||||
|
|
||||||
// Extracted Button styling into reusable component
|
// Extracted Button styling into reusable component
|
||||||
const StyledButton = ({ children, onClick }: ButtonProps) => (
|
const StyledButton = ({ children, onClick }: ButtonProps) => (
|
||||||
@@ -17,6 +17,8 @@ const StyledButton = ({ children, onClick }: ButtonProps) => (
|
|||||||
);
|
);
|
||||||
|
|
||||||
export default function NewScenarioButton() {
|
export default function NewScenarioButton() {
|
||||||
|
const { canModify } = useExperimentAccess();
|
||||||
|
|
||||||
const experiment = useExperiment();
|
const experiment = useExperiment();
|
||||||
const mutation = api.scenarios.create.useMutation();
|
const mutation = api.scenarios.create.useMutation();
|
||||||
const utils = api.useContext();
|
const utils = api.useContext();
|
||||||
@@ -38,6 +40,8 @@ export default function NewScenarioButton() {
|
|||||||
await utils.scenarios.list.invalidate();
|
await utils.scenarios.list.invalidate();
|
||||||
}, [mutation]);
|
}, [mutation]);
|
||||||
|
|
||||||
|
if (!canModify) return null;
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<HStack spacing={2}>
|
<HStack spacing={2}>
|
||||||
<StyledButton onClick={onClick}>
|
<StyledButton onClick={onClick}>
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import { Button, Icon, Spinner } from "@chakra-ui/react";
|
import { Box, Button, Icon, Spinner, Text } from "@chakra-ui/react";
|
||||||
import { BsPlus } from "react-icons/bs";
|
import { BsPlus } from "react-icons/bs";
|
||||||
import { api } from "~/utils/api";
|
import { api } from "~/utils/api";
|
||||||
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
|
import { useExperiment, useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
|
||||||
import { cellPadding, headerMinHeight } from "../constants";
|
import { cellPadding, headerMinHeight } from "../constants";
|
||||||
|
|
||||||
export default function NewVariantButton() {
|
export default function NewVariantButton() {
|
||||||
@@ -17,6 +17,9 @@ export default function NewVariantButton() {
|
|||||||
await utils.promptVariants.list.invalidate();
|
await utils.promptVariants.list.invalidate();
|
||||||
}, [mutation]);
|
}, [mutation]);
|
||||||
|
|
||||||
|
const { canModify } = useExperimentAccess();
|
||||||
|
if (!canModify) return <Box w={cellPadding.x} />;
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Button
|
<Button
|
||||||
w="100%"
|
w="100%"
|
||||||
@@ -31,7 +34,7 @@ export default function NewVariantButton() {
|
|||||||
minH={headerMinHeight}
|
minH={headerMinHeight}
|
||||||
>
|
>
|
||||||
<Icon as={loading ? Spinner : BsPlus} boxSize={6} mr={loading ? 1 : 0} />
|
<Icon as={loading ? Spinner : BsPlus} boxSize={6} mr={loading ? 1 : 0} />
|
||||||
Add Variant
|
<Text display={{ base: "none", md: "flex" }}>Add Variant</Text>
|
||||||
</Button>
|
</Button>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import { Button, HStack, Icon } from "@chakra-ui/react";
|
import { Button, HStack, Icon, Tooltip } from "@chakra-ui/react";
|
||||||
import { BsArrowClockwise } from "react-icons/bs";
|
import { BsArrowClockwise } from "react-icons/bs";
|
||||||
|
import { useExperimentAccess } from "~/utils/hooks";
|
||||||
|
|
||||||
export const CellOptions = ({
|
export const CellOptions = ({
|
||||||
refetchingOutput,
|
refetchingOutput,
|
||||||
@@ -8,9 +9,11 @@ export const CellOptions = ({
|
|||||||
refetchingOutput: boolean;
|
refetchingOutput: boolean;
|
||||||
refetchOutput: () => void;
|
refetchOutput: () => void;
|
||||||
}) => {
|
}) => {
|
||||||
|
const { canModify } = useExperimentAccess();
|
||||||
return (
|
return (
|
||||||
<HStack justifyContent="flex-end" w="full">
|
<HStack justifyContent="flex-end" w="full">
|
||||||
{!refetchingOutput && (
|
{!refetchingOutput && canModify && (
|
||||||
|
<Tooltip label="Refetch output" aria-label="refetch output">
|
||||||
<Button
|
<Button
|
||||||
size="xs"
|
size="xs"
|
||||||
w={4}
|
w={4}
|
||||||
@@ -27,6 +30,7 @@ export const CellOptions = ({
|
|||||||
>
|
>
|
||||||
<Icon as={BsArrowClockwise} boxSize={4} />
|
<Icon as={BsArrowClockwise} boxSize={4} />
|
||||||
</Button>
|
</Button>
|
||||||
|
</Tooltip>
|
||||||
)}
|
)}
|
||||||
</HStack>
|
</HStack>
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -6,11 +6,11 @@ import SyntaxHighlighter from "react-syntax-highlighter";
|
|||||||
import { docco } from "react-syntax-highlighter/dist/cjs/styles/hljs";
|
import { docco } from "react-syntax-highlighter/dist/cjs/styles/hljs";
|
||||||
import stringify from "json-stringify-pretty-compact";
|
import stringify from "json-stringify-pretty-compact";
|
||||||
import { type ReactElement, useState, useEffect } from "react";
|
import { type ReactElement, useState, useEffect } from "react";
|
||||||
import { type ChatCompletion } from "openai/resources/chat";
|
|
||||||
import useSocket from "~/utils/useSocket";
|
import useSocket from "~/utils/useSocket";
|
||||||
import { OutputStats } from "./OutputStats";
|
import { OutputStats } from "./OutputStats";
|
||||||
import { ErrorHandler } from "./ErrorHandler";
|
import { ErrorHandler } from "./ErrorHandler";
|
||||||
import { CellOptions } from "./CellOptions";
|
import { CellOptions } from "./CellOptions";
|
||||||
|
import modelProvidersFrontend from "~/modelProviders/modelProvidersFrontend";
|
||||||
|
|
||||||
export default function OutputCell({
|
export default function OutputCell({
|
||||||
scenario,
|
scenario,
|
||||||
@@ -33,18 +33,19 @@ export default function OutputCell({
|
|||||||
|
|
||||||
if (!templateHasVariables) disabledReason = "Add a value to the scenario variables to see output";
|
if (!templateHasVariables) disabledReason = "Add a value to the scenario variables to see output";
|
||||||
|
|
||||||
// if (variant.config === null || Object.keys(variant.config).length === 0)
|
|
||||||
// disabledReason = "Save your prompt variant to see output";
|
|
||||||
|
|
||||||
const [refetchInterval, setRefetchInterval] = useState(0);
|
const [refetchInterval, setRefetchInterval] = useState(0);
|
||||||
const { data: cell, isLoading: queryLoading } = api.scenarioVariantCells.get.useQuery(
|
const { data: cell, isLoading: queryLoading } = api.scenarioVariantCells.get.useQuery(
|
||||||
{ scenarioId: scenario.id, variantId: variant.id },
|
{ scenarioId: scenario.id, variantId: variant.id },
|
||||||
{ refetchInterval },
|
{ refetchInterval },
|
||||||
);
|
);
|
||||||
|
|
||||||
const { mutateAsync: hardRefetchMutate, isLoading: refetchingOutput } =
|
const provider =
|
||||||
api.scenarioVariantCells.forceRefetch.useMutation();
|
modelProvidersFrontend[variant.modelProvider as keyof typeof modelProvidersFrontend];
|
||||||
const [hardRefetch] = useHandledAsyncCallback(async () => {
|
|
||||||
|
type OutputSchema = Parameters<typeof provider.normalizeOutput>[0];
|
||||||
|
|
||||||
|
const { mutateAsync: hardRefetchMutate } = api.scenarioVariantCells.forceRefetch.useMutation();
|
||||||
|
const [hardRefetch, hardRefetching] = useHandledAsyncCallback(async () => {
|
||||||
await hardRefetchMutate({ scenarioId: scenario.id, variantId: variant.id });
|
await hardRefetchMutate({ scenarioId: scenario.id, variantId: variant.id });
|
||||||
await utils.scenarioVariantCells.get.invalidate({
|
await utils.scenarioVariantCells.get.invalidate({
|
||||||
scenarioId: scenario.id,
|
scenarioId: scenario.id,
|
||||||
@@ -55,20 +56,19 @@ export default function OutputCell({
|
|||||||
});
|
});
|
||||||
}, [hardRefetchMutate, scenario.id, variant.id]);
|
}, [hardRefetchMutate, scenario.id, variant.id]);
|
||||||
|
|
||||||
const fetchingOutput = queryLoading || refetchingOutput;
|
const fetchingOutput = queryLoading || hardRefetching;
|
||||||
|
|
||||||
const awaitingOutput =
|
const awaitingOutput =
|
||||||
!cell ||
|
!cell ||
|
||||||
cell.retrievalStatus === "PENDING" ||
|
cell.retrievalStatus === "PENDING" ||
|
||||||
cell.retrievalStatus === "IN_PROGRESS" ||
|
cell.retrievalStatus === "IN_PROGRESS" ||
|
||||||
refetchingOutput;
|
hardRefetching;
|
||||||
useEffect(() => setRefetchInterval(awaitingOutput ? 1000 : 0), [awaitingOutput]);
|
useEffect(() => setRefetchInterval(awaitingOutput ? 1000 : 0), [awaitingOutput]);
|
||||||
|
|
||||||
const modelOutput = cell?.modelOutput;
|
const modelOutput = cell?.modelOutput;
|
||||||
|
|
||||||
// Disconnect from socket if we're not streaming anymore
|
// Disconnect from socket if we're not streaming anymore
|
||||||
const streamedMessage = useSocket(cell?.streamingChannel);
|
const streamedMessage = useSocket<OutputSchema>(cell?.streamingChannel);
|
||||||
const streamedContent = streamedMessage?.choices?.[0]?.message?.content;
|
|
||||||
|
|
||||||
if (!vars) return null;
|
if (!vars) return null;
|
||||||
|
|
||||||
@@ -87,30 +87,26 @@ export default function OutputCell({
|
|||||||
return <ErrorHandler cell={cell} refetchOutput={hardRefetch} />;
|
return <ErrorHandler cell={cell} refetchOutput={hardRefetch} />;
|
||||||
}
|
}
|
||||||
|
|
||||||
const response = modelOutput?.output as unknown as ChatCompletion;
|
const normalizedOutput = modelOutput
|
||||||
const message = response?.choices?.[0]?.message;
|
? // @ts-expect-error TODO FIX ASAP
|
||||||
|
provider.normalizeOutput(modelOutput.output as unknown as OutputSchema)
|
||||||
if (modelOutput && message?.function_call) {
|
: streamedMessage
|
||||||
const rawArgs = message.function_call.arguments ?? "null";
|
? // @ts-expect-error TODO FIX ASAP
|
||||||
let parsedArgs: string;
|
provider.normalizeOutput(streamedMessage)
|
||||||
try {
|
: null;
|
||||||
parsedArgs = JSON.parse(rawArgs);
|
|
||||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
|
||||||
} catch (e: any) {
|
|
||||||
parsedArgs = `Failed to parse arguments as JSON: '${rawArgs}' ERROR: ${e.message as string}`;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
if (modelOutput && normalizedOutput?.type === "json") {
|
||||||
return (
|
return (
|
||||||
<VStack
|
<VStack
|
||||||
w="100%"
|
w="100%"
|
||||||
h="100%"
|
h="100%"
|
||||||
fontSize="xs"
|
fontSize="xs"
|
||||||
flexWrap="wrap"
|
flexWrap="wrap"
|
||||||
overflowX="auto"
|
overflowX="hidden"
|
||||||
justifyContent="space-between"
|
justifyContent="space-between"
|
||||||
>
|
>
|
||||||
<VStack w="full" flex={1} spacing={0}>
|
<VStack w="full" flex={1} spacing={0}>
|
||||||
<CellOptions refetchingOutput={refetchingOutput} refetchOutput={hardRefetch} />
|
<CellOptions refetchingOutput={hardRefetching} refetchOutput={hardRefetch} />
|
||||||
<SyntaxHighlighter
|
<SyntaxHighlighter
|
||||||
customStyle={{ overflowX: "unset", width: "100%", flex: 1 }}
|
customStyle={{ overflowX: "unset", width: "100%", flex: 1 }}
|
||||||
language="json"
|
language="json"
|
||||||
@@ -120,32 +116,23 @@ export default function OutputCell({
|
|||||||
}}
|
}}
|
||||||
wrapLines
|
wrapLines
|
||||||
>
|
>
|
||||||
{stringify(
|
{stringify(normalizedOutput.value, { maxLength: 40 })}
|
||||||
{
|
|
||||||
function: message.function_call.name,
|
|
||||||
args: parsedArgs,
|
|
||||||
},
|
|
||||||
{ maxLength: 40 },
|
|
||||||
)}
|
|
||||||
</SyntaxHighlighter>
|
</SyntaxHighlighter>
|
||||||
</VStack>
|
</VStack>
|
||||||
<OutputStats model={variant.model} modelOutput={modelOutput} scenario={scenario} />
|
<OutputStats modelOutput={modelOutput} scenario={scenario} />
|
||||||
</VStack>
|
</VStack>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
const contentToDisplay =
|
const contentToDisplay = (normalizedOutput?.type === "text" && normalizedOutput.value) || "";
|
||||||
message?.content ?? streamedContent ?? JSON.stringify(modelOutput?.output);
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<VStack w="100%" h="100%" justifyContent="space-between" whiteSpace="pre-wrap">
|
<VStack w="100%" h="100%" justifyContent="space-between" whiteSpace="pre-wrap">
|
||||||
<VStack w="full" alignItems="flex-start" spacing={0}>
|
<VStack w="full" alignItems="flex-start" spacing={0}>
|
||||||
<CellOptions refetchingOutput={refetchingOutput} refetchOutput={hardRefetch} />
|
<CellOptions refetchingOutput={hardRefetching} refetchOutput={hardRefetch} />
|
||||||
<Text>{contentToDisplay}</Text>
|
<Text>{contentToDisplay}</Text>
|
||||||
</VStack>
|
</VStack>
|
||||||
{modelOutput && (
|
{modelOutput && <OutputStats modelOutput={modelOutput} scenario={scenario} />}
|
||||||
<OutputStats model={variant.model} modelOutput={modelOutput} scenario={scenario} />
|
|
||||||
)}
|
|
||||||
</VStack>
|
</VStack>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,19 +1,14 @@
|
|||||||
import { type SupportedModel } from "~/server/types";
|
|
||||||
import { type Scenario } from "../types";
|
import { type Scenario } from "../types";
|
||||||
import { type RouterOutputs } from "~/utils/api";
|
import { type RouterOutputs } from "~/utils/api";
|
||||||
import { calculateTokenCost } from "~/utils/calculateTokenCost";
|
|
||||||
import { HStack, Icon, Text, Tooltip } from "@chakra-ui/react";
|
import { HStack, Icon, Text, Tooltip } from "@chakra-ui/react";
|
||||||
import { BsCheck, BsClock, BsCurrencyDollar, BsX } from "react-icons/bs";
|
import { BsCheck, BsClock, BsCurrencyDollar, BsX } from "react-icons/bs";
|
||||||
import { CostTooltip } from "~/components/tooltip/CostTooltip";
|
import { CostTooltip } from "~/components/tooltip/CostTooltip";
|
||||||
|
|
||||||
const SHOW_COST = true;
|
|
||||||
const SHOW_TIME = true;
|
const SHOW_TIME = true;
|
||||||
|
|
||||||
export const OutputStats = ({
|
export const OutputStats = ({
|
||||||
model,
|
|
||||||
modelOutput,
|
modelOutput,
|
||||||
}: {
|
}: {
|
||||||
model: SupportedModel | string | null;
|
|
||||||
modelOutput: NonNullable<
|
modelOutput: NonNullable<
|
||||||
NonNullable<RouterOutputs["scenarioVariantCells"]["get"]>["modelOutput"]
|
NonNullable<RouterOutputs["scenarioVariantCells"]["get"]>["modelOutput"]
|
||||||
>;
|
>;
|
||||||
@@ -24,12 +19,6 @@ export const OutputStats = ({
|
|||||||
const promptTokens = modelOutput.promptTokens;
|
const promptTokens = modelOutput.promptTokens;
|
||||||
const completionTokens = modelOutput.completionTokens;
|
const completionTokens = modelOutput.completionTokens;
|
||||||
|
|
||||||
const promptCost = promptTokens && model ? calculateTokenCost(model, promptTokens) : 0;
|
|
||||||
const completionCost =
|
|
||||||
completionTokens && model ? calculateTokenCost(model, completionTokens, true) : 0;
|
|
||||||
|
|
||||||
const cost = promptCost + completionCost;
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<HStack w="full" align="center" color="gray.500" fontSize="2xs" mt={{ base: 0, md: 1 }}>
|
<HStack w="full" align="center" color="gray.500" fontSize="2xs" mt={{ base: 0, md: 1 }}>
|
||||||
<HStack flex={1}>
|
<HStack flex={1}>
|
||||||
@@ -53,11 +42,15 @@ export const OutputStats = ({
|
|||||||
);
|
);
|
||||||
})}
|
})}
|
||||||
</HStack>
|
</HStack>
|
||||||
{SHOW_COST && (
|
{modelOutput.cost && (
|
||||||
<CostTooltip promptTokens={promptTokens} completionTokens={completionTokens} cost={cost}>
|
<CostTooltip
|
||||||
|
promptTokens={promptTokens}
|
||||||
|
completionTokens={completionTokens}
|
||||||
|
cost={modelOutput.cost}
|
||||||
|
>
|
||||||
<HStack spacing={0}>
|
<HStack spacing={0}>
|
||||||
<Icon as={BsCurrencyDollar} />
|
<Icon as={BsCurrencyDollar} />
|
||||||
<Text mr={1}>{cost.toFixed(3)}</Text>
|
<Text mr={1}>{modelOutput.cost.toFixed(3)}</Text>
|
||||||
</HStack>
|
</HStack>
|
||||||
</CostTooltip>
|
</CostTooltip>
|
||||||
)}
|
)}
|
||||||
|
|||||||
@@ -2,14 +2,14 @@ import { type DragEvent } from "react";
|
|||||||
import { api } from "~/utils/api";
|
import { api } from "~/utils/api";
|
||||||
import { isEqual } from "lodash-es";
|
import { isEqual } from "lodash-es";
|
||||||
import { type Scenario } from "./types";
|
import { type Scenario } from "./types";
|
||||||
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
|
import { useExperiment, useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
|
||||||
import { useState } from "react";
|
import { useState } from "react";
|
||||||
|
|
||||||
import { Box, Button, Flex, HStack, Icon, Spinner, Stack, Tooltip, VStack } from "@chakra-ui/react";
|
import { Box, Button, Flex, HStack, Icon, Spinner, Stack, Tooltip, VStack } from "@chakra-ui/react";
|
||||||
import { cellPadding } from "../constants";
|
import { cellPadding } from "../constants";
|
||||||
import { BsX } from "react-icons/bs";
|
import { BsX } from "react-icons/bs";
|
||||||
import { RiDraggable } from "react-icons/ri";
|
import { RiDraggable } from "react-icons/ri";
|
||||||
import AutoResizeTextArea from "../AutoResizeTextArea";
|
import { FloatingLabelInput } from "./FloatingLabelInput";
|
||||||
|
|
||||||
export default function ScenarioEditor({
|
export default function ScenarioEditor({
|
||||||
scenario,
|
scenario,
|
||||||
@@ -19,6 +19,8 @@ export default function ScenarioEditor({
|
|||||||
hovered: boolean;
|
hovered: boolean;
|
||||||
canHide: boolean;
|
canHide: boolean;
|
||||||
}) {
|
}) {
|
||||||
|
const { canModify } = useExperimentAccess();
|
||||||
|
|
||||||
const savedValues = scenario.variableValues as Record<string, string>;
|
const savedValues = scenario.variableValues as Record<string, string>;
|
||||||
const utils = api.useContext();
|
const utils = api.useContext();
|
||||||
const [isDragTarget, setIsDragTarget] = useState(false);
|
const [isDragTarget, setIsDragTarget] = useState(false);
|
||||||
@@ -72,8 +74,9 @@ export default function ScenarioEditor({
|
|||||||
return (
|
return (
|
||||||
<HStack
|
<HStack
|
||||||
alignItems="flex-start"
|
alignItems="flex-start"
|
||||||
pr={cellPadding.x}
|
px={cellPadding.x}
|
||||||
py={cellPadding.y}
|
py={cellPadding.y}
|
||||||
|
spacing={0}
|
||||||
height="100%"
|
height="100%"
|
||||||
draggable={!variableInputHovered}
|
draggable={!variableInputHovered}
|
||||||
onDragStart={(e) => {
|
onDragStart={(e) => {
|
||||||
@@ -93,9 +96,13 @@ export default function ScenarioEditor({
|
|||||||
onDrop={onReorder}
|
onDrop={onReorder}
|
||||||
backgroundColor={isDragTarget ? "gray.100" : "transparent"}
|
backgroundColor={isDragTarget ? "gray.100" : "transparent"}
|
||||||
>
|
>
|
||||||
<Stack alignSelf="flex-start" opacity={props.hovered ? 1 : 0} spacing={0}>
|
{canModify && props.canHide && (
|
||||||
{props.canHide && (
|
<Stack
|
||||||
<>
|
alignSelf="flex-start"
|
||||||
|
opacity={props.hovered ? 1 : 0}
|
||||||
|
spacing={0}
|
||||||
|
ml={-cellPadding.x}
|
||||||
|
>
|
||||||
<Tooltip label="Hide scenario" hasArrow>
|
<Tooltip label="Hide scenario" hasArrow>
|
||||||
{/* for some reason the tooltip can't position itself properly relative to the icon without the wrapping box */}
|
{/* for some reason the tooltip can't position itself properly relative to the icon without the wrapping box */}
|
||||||
<Button
|
<Button
|
||||||
@@ -110,7 +117,7 @@ export default function ScenarioEditor({
|
|||||||
cursor: "pointer",
|
cursor: "pointer",
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<Icon as={hidingInProgress ? Spinner : BsX} boxSize={6} />
|
<Icon as={hidingInProgress ? Spinner : BsX} boxSize={hidingInProgress ? 4 : 6} />
|
||||||
</Button>
|
</Button>
|
||||||
</Tooltip>
|
</Tooltip>
|
||||||
<Icon
|
<Icon
|
||||||
@@ -119,13 +126,13 @@ export default function ScenarioEditor({
|
|||||||
color="gray.400"
|
color="gray.400"
|
||||||
_hover={{ color: "gray.800", cursor: "pointer" }}
|
_hover={{ color: "gray.800", cursor: "pointer" }}
|
||||||
/>
|
/>
|
||||||
</>
|
|
||||||
)}
|
|
||||||
</Stack>
|
</Stack>
|
||||||
|
)}
|
||||||
|
|
||||||
{variableLabels.length === 0 ? (
|
{variableLabels.length === 0 ? (
|
||||||
<Box color="gray.500">{vars.data ? "No scenario variables configured" : "Loading..."}</Box>
|
<Box color="gray.500">{vars.data ? "No scenario variables configured" : "Loading..."}</Box>
|
||||||
) : (
|
) : (
|
||||||
<VStack spacing={1}>
|
<VStack spacing={4} flex={1} py={2}>
|
||||||
{variableLabels.map((key) => {
|
{variableLabels.map((key) => {
|
||||||
const value = values[key] ?? "";
|
const value = values[key] ?? "";
|
||||||
const layoutDirection = value.length > 20 ? "column" : "row";
|
const layoutDirection = value.length > 20 ? "column" : "row";
|
||||||
@@ -137,29 +144,14 @@ export default function ScenarioEditor({
|
|||||||
flexWrap="wrap"
|
flexWrap="wrap"
|
||||||
width="full"
|
width="full"
|
||||||
>
|
>
|
||||||
<Box
|
<FloatingLabelInput
|
||||||
bgColor="blue.100"
|
label={key}
|
||||||
color="blue.600"
|
isDisabled={!canModify}
|
||||||
px={1}
|
style={{ width: "100%" }}
|
||||||
my="3px"
|
|
||||||
fontSize="xs"
|
|
||||||
fontWeight="bold"
|
|
||||||
>
|
|
||||||
{key}
|
|
||||||
</Box>
|
|
||||||
<AutoResizeTextArea
|
|
||||||
px={2}
|
|
||||||
py={1}
|
|
||||||
placeholder="empty"
|
|
||||||
borderRadius="sm"
|
|
||||||
fontSize="sm"
|
|
||||||
lineHeight={1.2}
|
|
||||||
value={value}
|
value={value}
|
||||||
onChange={(e) => {
|
onChange={(e) => {
|
||||||
setValues((prev) => ({ ...prev, [key]: e.target.value }));
|
setValues((prev) => ({ ...prev, [key]: e.target.value }));
|
||||||
}}
|
}}
|
||||||
maxH="32"
|
|
||||||
overflowY="auto"
|
|
||||||
onKeyDown={(e) => {
|
onKeyDown={(e) => {
|
||||||
if (e.key === "Enter" && (e.metaKey || e.ctrlKey)) {
|
if (e.key === "Enter" && (e.metaKey || e.ctrlKey)) {
|
||||||
e.preventDefault();
|
e.preventDefault();
|
||||||
@@ -167,12 +159,6 @@ export default function ScenarioEditor({
|
|||||||
onSave();
|
onSave();
|
||||||
}
|
}
|
||||||
}}
|
}}
|
||||||
resize="none"
|
|
||||||
overflow="hidden"
|
|
||||||
flex={layoutDirection === "row" ? 1 : undefined}
|
|
||||||
borderColor={hasChanged ? "blue.300" : "transparent"}
|
|
||||||
_hover={{ borderColor: "gray.300" }}
|
|
||||||
_focus={{ borderColor: "blue.500", outline: "none", bg: "white" }}
|
|
||||||
onMouseEnter={() => setVariableInputHovered(true)}
|
onMouseEnter={() => setVariableInputHovered(true)}
|
||||||
onMouseLeave={() => setVariableInputHovered(false)}
|
onMouseLeave={() => setVariableInputHovered(false)}
|
||||||
/>
|
/>
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import { Button, GridItem, HStack, Heading } from "@chakra-ui/react";
|
import { Button, GridItem, HStack, Heading } from "@chakra-ui/react";
|
||||||
import { cellPadding } from "../constants";
|
import { cellPadding } from "../constants";
|
||||||
import { useElementDimensions } from "~/utils/hooks";
|
import { useElementDimensions, useExperimentAccess } from "~/utils/hooks";
|
||||||
import { stickyHeaderStyle } from "./styles";
|
import { stickyHeaderStyle } from "./styles";
|
||||||
import { BsPencil } from "react-icons/bs";
|
import { BsPencil } from "react-icons/bs";
|
||||||
import { useAppStore } from "~/state/store";
|
import { useAppStore } from "~/state/store";
|
||||||
@@ -13,6 +13,7 @@ export const ScenariosHeader = ({
|
|||||||
numScenarios: number;
|
numScenarios: number;
|
||||||
}) => {
|
}) => {
|
||||||
const openDrawer = useAppStore((s) => s.openDrawer);
|
const openDrawer = useAppStore((s) => s.openDrawer);
|
||||||
|
const { canModify } = useExperimentAccess();
|
||||||
|
|
||||||
const [ref, dimensions] = useElementDimensions();
|
const [ref, dimensions] = useElementDimensions();
|
||||||
const topValue = dimensions ? `-${dimensions.height - 24}px` : "-455px";
|
const topValue = dimensions ? `-${dimensions.height - 24}px` : "-455px";
|
||||||
@@ -33,6 +34,7 @@ export const ScenariosHeader = ({
|
|||||||
<Heading size="xs" fontWeight="bold" flex={1}>
|
<Heading size="xs" fontWeight="bold" flex={1}>
|
||||||
Scenarios ({numScenarios})
|
Scenarios ({numScenarios})
|
||||||
</Heading>
|
</Heading>
|
||||||
|
{canModify && (
|
||||||
<Button
|
<Button
|
||||||
size="xs"
|
size="xs"
|
||||||
variant="ghost"
|
variant="ghost"
|
||||||
@@ -43,6 +45,7 @@ export const ScenariosHeader = ({
|
|||||||
>
|
>
|
||||||
Edit Vars
|
Edit Vars
|
||||||
</Button>
|
</Button>
|
||||||
|
)}
|
||||||
</HStack>
|
</HStack>
|
||||||
</GridItem>
|
</GridItem>
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -1,11 +1,12 @@
|
|||||||
import { Box, Button, HStack, Spinner, Tooltip, useToast, Text } from "@chakra-ui/react";
|
import { Box, Button, HStack, Spinner, Tooltip, useToast, Text } from "@chakra-ui/react";
|
||||||
import { useRef, useEffect, useState, useCallback } from "react";
|
import { useRef, useEffect, useState, useCallback } from "react";
|
||||||
import { useHandledAsyncCallback, useModifierKeyLabel } from "~/utils/hooks";
|
import { useExperimentAccess, useHandledAsyncCallback, useModifierKeyLabel } from "~/utils/hooks";
|
||||||
import { type PromptVariant } from "./types";
|
import { type PromptVariant } from "./types";
|
||||||
import { api } from "~/utils/api";
|
import { api } from "~/utils/api";
|
||||||
import { useAppStore } from "~/state/store";
|
import { useAppStore } from "~/state/store";
|
||||||
|
|
||||||
export default function VariantEditor(props: { variant: PromptVariant }) {
|
export default function VariantEditor(props: { variant: PromptVariant }) {
|
||||||
|
const { canModify } = useExperimentAccess();
|
||||||
const monaco = useAppStore.use.sharedVariantEditor.monaco();
|
const monaco = useAppStore.use.sharedVariantEditor.monaco();
|
||||||
const editorRef = useRef<ReturnType<NonNullable<typeof monaco>["editor"]["create"]> | null>(null);
|
const editorRef = useRef<ReturnType<NonNullable<typeof monaco>["editor"]["create"]> | null>(null);
|
||||||
const [editorId] = useState(() => `editor_${Math.random().toString(36).substring(7)}`);
|
const [editorId] = useState(() => `editor_${Math.random().toString(36).substring(7)}`);
|
||||||
@@ -21,7 +22,13 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
|
|||||||
setIsChanged(currentFn.length > 0 && currentFn !== lastSavedFn);
|
setIsChanged(currentFn.length > 0 && currentFn !== lastSavedFn);
|
||||||
}, [lastSavedFn]);
|
}, [lastSavedFn]);
|
||||||
|
|
||||||
useEffect(checkForChanges, [checkForChanges, lastSavedFn]);
|
const matchUpdatedSavedFn = useCallback(() => {
|
||||||
|
if (!editorRef.current) return;
|
||||||
|
editorRef.current.setValue(lastSavedFn);
|
||||||
|
setIsChanged(false);
|
||||||
|
}, [lastSavedFn]);
|
||||||
|
|
||||||
|
useEffect(matchUpdatedSavedFn, [matchUpdatedSavedFn, lastSavedFn]);
|
||||||
|
|
||||||
const replaceVariant = api.promptVariants.replaceVariant.useMutation();
|
const replaceVariant = api.promptVariants.replaceVariant.useMutation();
|
||||||
const utils = api.useContext();
|
const utils = api.useContext();
|
||||||
@@ -40,26 +47,12 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
|
|||||||
const model = editorRef.current.getModel();
|
const model = editorRef.current.getModel();
|
||||||
if (!model) return;
|
if (!model) return;
|
||||||
|
|
||||||
const markers = monaco?.editor.getModelMarkers({ resource: model.uri });
|
|
||||||
const hasErrors = markers?.some((m) => m.severity === monaco?.MarkerSeverity.Error);
|
|
||||||
|
|
||||||
if (hasErrors) {
|
|
||||||
toast({
|
|
||||||
title: "Invalid TypeScript",
|
|
||||||
description: "Please fix the TypeScript errors before saving.",
|
|
||||||
status: "error",
|
|
||||||
});
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Make sure the user defined the prompt with the string "prompt\w*=" somewhere
|
// Make sure the user defined the prompt with the string "prompt\w*=" somewhere
|
||||||
const promptRegex = /prompt\s*=/;
|
const promptRegex = /definePrompt\(/;
|
||||||
if (!promptRegex.test(currentFn)) {
|
if (!promptRegex.test(currentFn)) {
|
||||||
console.log("no prompt");
|
|
||||||
console.log(currentFn);
|
|
||||||
toast({
|
toast({
|
||||||
title: "Missing prompt",
|
title: "Missing prompt",
|
||||||
description: "Please define the prompt (eg. `prompt = { ...`).",
|
description: "Please define the prompt (eg. `definePrompt(...`",
|
||||||
status: "error",
|
status: "error",
|
||||||
});
|
});
|
||||||
return;
|
return;
|
||||||
@@ -103,6 +96,7 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
|
|||||||
wordWrapBreakAfterCharacters: "",
|
wordWrapBreakAfterCharacters: "",
|
||||||
wordWrapBreakBeforeCharacters: "",
|
wordWrapBreakBeforeCharacters: "",
|
||||||
quickSuggestions: true,
|
quickSuggestions: true,
|
||||||
|
readOnly: !canModify,
|
||||||
});
|
});
|
||||||
|
|
||||||
editorRef.current.onDidFocusEditorText(() => {
|
editorRef.current.onDidFocusEditorText(() => {
|
||||||
@@ -130,6 +124,13 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
|
|||||||
/* eslint-disable-next-line react-hooks/exhaustive-deps */
|
/* eslint-disable-next-line react-hooks/exhaustive-deps */
|
||||||
}, [monaco, editorId]);
|
}, [monaco, editorId]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (!editorRef.current) return;
|
||||||
|
editorRef.current.updateOptions({
|
||||||
|
readOnly: !canModify,
|
||||||
|
});
|
||||||
|
}, [canModify]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Box w="100%" pos="relative">
|
<Box w="100%" pos="relative">
|
||||||
<div id={editorId} style={{ height: "400px", width: "100%" }}></div>
|
<div id={editorId} style={{ height: "400px", width: "100%" }}></div>
|
||||||
|
|||||||
@@ -1,107 +0,0 @@
|
|||||||
import { useState, type DragEvent } from "react";
|
|
||||||
import { type PromptVariant } from "./types";
|
|
||||||
import { api } from "~/utils/api";
|
|
||||||
import { useHandledAsyncCallback } from "~/utils/hooks";
|
|
||||||
import { Button, HStack, Icon, Tooltip } from "@chakra-ui/react"; // Changed here
|
|
||||||
import { BsX } from "react-icons/bs";
|
|
||||||
import { RiDraggable } from "react-icons/ri";
|
|
||||||
import { cellPadding, headerMinHeight } from "../constants";
|
|
||||||
import AutoResizeTextArea from "../AutoResizeTextArea";
|
|
||||||
|
|
||||||
export default function VariantHeader(props: { variant: PromptVariant; canHide: boolean }) {
|
|
||||||
const utils = api.useContext();
|
|
||||||
const [isDragTarget, setIsDragTarget] = useState(false);
|
|
||||||
const [isInputHovered, setIsInputHovered] = useState(false);
|
|
||||||
const [label, setLabel] = useState(props.variant.label);
|
|
||||||
|
|
||||||
const updateMutation = api.promptVariants.update.useMutation();
|
|
||||||
const [onSaveLabel] = useHandledAsyncCallback(async () => {
|
|
||||||
if (label && label !== props.variant.label) {
|
|
||||||
await updateMutation.mutateAsync({
|
|
||||||
id: props.variant.id,
|
|
||||||
updates: { label: label },
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}, [updateMutation, props.variant.id, props.variant.label, label]);
|
|
||||||
|
|
||||||
const hideMutation = api.promptVariants.hide.useMutation();
|
|
||||||
const [onHide] = useHandledAsyncCallback(async () => {
|
|
||||||
await hideMutation.mutateAsync({
|
|
||||||
id: props.variant.id,
|
|
||||||
});
|
|
||||||
await utils.promptVariants.list.invalidate();
|
|
||||||
}, [hideMutation, props.variant.id]);
|
|
||||||
|
|
||||||
const reorderMutation = api.promptVariants.reorder.useMutation();
|
|
||||||
const [onReorder] = useHandledAsyncCallback(
|
|
||||||
async (e: DragEvent<HTMLDivElement>) => {
|
|
||||||
e.preventDefault();
|
|
||||||
setIsDragTarget(false);
|
|
||||||
const draggedId = e.dataTransfer.getData("text/plain");
|
|
||||||
const droppedId = props.variant.id;
|
|
||||||
if (!draggedId || !droppedId || draggedId === droppedId) return;
|
|
||||||
await reorderMutation.mutateAsync({
|
|
||||||
draggedId,
|
|
||||||
droppedId,
|
|
||||||
});
|
|
||||||
await utils.promptVariants.list.invalidate();
|
|
||||||
},
|
|
||||||
[reorderMutation, props.variant.id],
|
|
||||||
);
|
|
||||||
|
|
||||||
return (
|
|
||||||
<HStack
|
|
||||||
spacing={4}
|
|
||||||
alignItems="center"
|
|
||||||
minH={headerMinHeight}
|
|
||||||
draggable={!isInputHovered}
|
|
||||||
onDragStart={(e) => {
|
|
||||||
e.dataTransfer.setData("text/plain", props.variant.id);
|
|
||||||
e.currentTarget.style.opacity = "0.4";
|
|
||||||
}}
|
|
||||||
onDragEnd={(e) => {
|
|
||||||
e.currentTarget.style.opacity = "1";
|
|
||||||
}}
|
|
||||||
onDragOver={(e) => {
|
|
||||||
e.preventDefault();
|
|
||||||
setIsDragTarget(true);
|
|
||||||
}}
|
|
||||||
onDragLeave={() => {
|
|
||||||
setIsDragTarget(false);
|
|
||||||
}}
|
|
||||||
onDrop={onReorder}
|
|
||||||
backgroundColor={isDragTarget ? "gray.100" : "transparent"}
|
|
||||||
>
|
|
||||||
<Icon
|
|
||||||
as={RiDraggable}
|
|
||||||
boxSize={6}
|
|
||||||
color="gray.400"
|
|
||||||
_hover={{ color: "gray.800", cursor: "pointer" }}
|
|
||||||
/>
|
|
||||||
<AutoResizeTextArea // Changed to Input
|
|
||||||
size="sm"
|
|
||||||
value={label}
|
|
||||||
onChange={(e) => setLabel(e.target.value)}
|
|
||||||
onBlur={onSaveLabel}
|
|
||||||
placeholder="Variant Name"
|
|
||||||
borderWidth={1}
|
|
||||||
borderColor="transparent"
|
|
||||||
fontWeight="bold"
|
|
||||||
fontSize={16}
|
|
||||||
_hover={{ borderColor: "gray.300" }}
|
|
||||||
_focus={{ borderColor: "blue.500", outline: "none" }}
|
|
||||||
flex={1}
|
|
||||||
px={cellPadding.x}
|
|
||||||
onMouseEnter={() => setIsInputHovered(true)}
|
|
||||||
onMouseLeave={() => setIsInputHovered(false)}
|
|
||||||
/>
|
|
||||||
{props.canHide && (
|
|
||||||
<Tooltip label="Remove Variant" hasArrow>
|
|
||||||
<Button variant="ghost" colorScheme="gray" size="sm" onClick={onHide}>
|
|
||||||
<Icon as={BsX} boxSize={6} />
|
|
||||||
</Button>
|
|
||||||
</Tooltip>
|
|
||||||
)}
|
|
||||||
</HStack>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
import { HStack, Icon, Skeleton, Text, useToken } from "@chakra-ui/react";
|
import { HStack, Icon, Text, useToken } from "@chakra-ui/react";
|
||||||
import { type PromptVariant } from "./types";
|
import { type PromptVariant } from "./types";
|
||||||
import { cellPadding } from "../constants";
|
import { cellPadding } from "../constants";
|
||||||
import { api } from "~/utils/api";
|
import { api } from "~/utils/api";
|
||||||
@@ -69,7 +69,7 @@ export default function VariantStats(props: { variant: PromptVariant }) {
|
|||||||
);
|
);
|
||||||
})}
|
})}
|
||||||
</HStack>
|
</HStack>
|
||||||
{data.overallCost && !data.awaitingRetrievals ? (
|
{data.overallCost && !data.awaitingRetrievals && (
|
||||||
<CostTooltip
|
<CostTooltip
|
||||||
promptTokens={data.promptTokens}
|
promptTokens={data.promptTokens}
|
||||||
completionTokens={data.completionTokens}
|
completionTokens={data.completionTokens}
|
||||||
@@ -80,8 +80,6 @@ export default function VariantStats(props: { variant: PromptVariant }) {
|
|||||||
<Text mr={1}>{data.overallCost.toFixed(3)}</Text>
|
<Text mr={1}>{data.overallCost.toFixed(3)}</Text>
|
||||||
</HStack>
|
</HStack>
|
||||||
</CostTooltip>
|
</CostTooltip>
|
||||||
) : (
|
|
||||||
<Skeleton height={4} width={12} mr={1} />
|
|
||||||
)}
|
)}
|
||||||
</HStack>
|
</HStack>
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import NewScenarioButton from "./NewScenarioButton";
|
|||||||
import NewVariantButton from "./NewVariantButton";
|
import NewVariantButton from "./NewVariantButton";
|
||||||
import ScenarioRow from "./ScenarioRow";
|
import ScenarioRow from "./ScenarioRow";
|
||||||
import VariantEditor from "./VariantEditor";
|
import VariantEditor from "./VariantEditor";
|
||||||
import VariantHeader from "./VariantHeader";
|
import VariantHeader from "../VariantHeader/VariantHeader";
|
||||||
import VariantStats from "./VariantStats";
|
import VariantStats from "./VariantStats";
|
||||||
import { ScenariosHeader } from "./ScenariosHeader";
|
import { ScenariosHeader } from "./ScenariosHeader";
|
||||||
import { stickyHeaderStyle } from "./styles";
|
import { stickyHeaderStyle } from "./styles";
|
||||||
@@ -43,9 +43,7 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
|
|||||||
<ScenariosHeader headerRows={headerRows} numScenarios={scenarios.data.length} />
|
<ScenariosHeader headerRows={headerRows} numScenarios={scenarios.data.length} />
|
||||||
|
|
||||||
{variants.data.map((variant) => (
|
{variants.data.map((variant) => (
|
||||||
<GridItem key={variant.uiId} padding={0} sx={stickyHeaderStyle} borderTopWidth={1}>
|
<VariantHeader key={variant.uiId} variant={variant} canHide={variants.data.length > 1} />
|
||||||
<VariantHeader variant={variant} canHide={variants.data.length > 1} />
|
|
||||||
</GridItem>
|
|
||||||
))}
|
))}
|
||||||
<GridItem
|
<GridItem
|
||||||
rowSpan={scenarios.data.length + headerRows}
|
rowSpan={scenarios.data.length + headerRows}
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import { type SystemStyleObject } from "@chakra-ui/react";
|
|||||||
|
|
||||||
export const stickyHeaderStyle: SystemStyleObject = {
|
export const stickyHeaderStyle: SystemStyleObject = {
|
||||||
position: "sticky",
|
position: "sticky",
|
||||||
top: "-1px",
|
top: "0",
|
||||||
backgroundColor: "#fff",
|
backgroundColor: "#fff",
|
||||||
zIndex: 1,
|
zIndex: 1,
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -1,21 +0,0 @@
|
|||||||
import { Flex, Icon, Link, Text } from "@chakra-ui/react";
|
|
||||||
import { BsExclamationTriangleFill } from "react-icons/bs";
|
|
||||||
import { env } from "~/env.mjs";
|
|
||||||
|
|
||||||
export default function PublicPlaygroundWarning() {
|
|
||||||
if (!env.NEXT_PUBLIC_IS_PUBLIC_PLAYGROUND) return null;
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Flex bgColor="red.600" color="whiteAlpha.900" p={2} align="center">
|
|
||||||
<Icon boxSize={4} mr={2} as={BsExclamationTriangleFill} />
|
|
||||||
<Text>
|
|
||||||
Warning: this is a public playground. Anyone can see, edit or delete your experiments. For
|
|
||||||
private use,{" "}
|
|
||||||
<Link textDecor="underline" href="https://github.com/openpipe/openpipe" target="_blank">
|
|
||||||
run a local copy
|
|
||||||
</Link>
|
|
||||||
.
|
|
||||||
</Text>
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
56
src/components/RefinePromptModal/CompareFunctions.tsx
Normal file
56
src/components/RefinePromptModal/CompareFunctions.tsx
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
import { HStack, VStack, useBreakpointValue } from "@chakra-ui/react";
|
||||||
|
import React from "react";
|
||||||
|
import DiffViewer, { DiffMethod } from "react-diff-viewer";
|
||||||
|
import Prism from "prismjs";
|
||||||
|
import "prismjs/components/prism-javascript";
|
||||||
|
import "prismjs/themes/prism.css"; // choose a theme you like
|
||||||
|
|
||||||
|
const highlightSyntax = (str: string) => {
|
||||||
|
let highlighted;
|
||||||
|
try {
|
||||||
|
highlighted = Prism.highlight(str, Prism.languages.javascript as Prism.Grammar, "javascript");
|
||||||
|
} catch (e) {
|
||||||
|
console.error("Error highlighting:", e);
|
||||||
|
highlighted = str;
|
||||||
|
}
|
||||||
|
return <pre style={{ display: "inline" }} dangerouslySetInnerHTML={{ __html: highlighted }} />;
|
||||||
|
};
|
||||||
|
|
||||||
|
const CompareFunctions = ({
|
||||||
|
originalFunction,
|
||||||
|
newFunction = "",
|
||||||
|
}: {
|
||||||
|
originalFunction: string;
|
||||||
|
newFunction?: string;
|
||||||
|
}) => {
|
||||||
|
const showSplitView = useBreakpointValue(
|
||||||
|
{
|
||||||
|
base: false,
|
||||||
|
md: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
fallback: "base",
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<HStack w="full" spacing={5}>
|
||||||
|
<VStack w="full" spacing={4} maxH="40vh" fontSize={12} lineHeight={1} overflowY="auto">
|
||||||
|
<DiffViewer
|
||||||
|
oldValue={originalFunction}
|
||||||
|
newValue={newFunction || originalFunction}
|
||||||
|
splitView={showSplitView}
|
||||||
|
hideLineNumbers={!showSplitView}
|
||||||
|
leftTitle="Original"
|
||||||
|
rightTitle={newFunction ? "Modified" : "Unmodified"}
|
||||||
|
disableWordDiff={true}
|
||||||
|
compareMethod={DiffMethod.CHARS}
|
||||||
|
renderContent={highlightSyntax}
|
||||||
|
showDiffOnly={false}
|
||||||
|
/>
|
||||||
|
</VStack>
|
||||||
|
</HStack>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default CompareFunctions;
|
||||||
75
src/components/RefinePromptModal/CustomInstructionsInput.tsx
Normal file
75
src/components/RefinePromptModal/CustomInstructionsInput.tsx
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
import { Button, Spinner, InputGroup, InputRightElement, Icon, HStack } from "@chakra-ui/react";
|
||||||
|
import { IoMdSend } from "react-icons/io";
|
||||||
|
import AutoResizeTextArea from "../AutoResizeTextArea";
|
||||||
|
|
||||||
|
export const CustomInstructionsInput = ({
|
||||||
|
instructions,
|
||||||
|
setInstructions,
|
||||||
|
loading,
|
||||||
|
onSubmit,
|
||||||
|
}: {
|
||||||
|
instructions: string;
|
||||||
|
setInstructions: (instructions: string) => void;
|
||||||
|
loading: boolean;
|
||||||
|
onSubmit: () => void;
|
||||||
|
}) => {
|
||||||
|
return (
|
||||||
|
<InputGroup
|
||||||
|
size="md"
|
||||||
|
w="full"
|
||||||
|
maxW="600"
|
||||||
|
boxShadow="0 0 40px 4px rgba(0, 0, 0, 0.1);"
|
||||||
|
borderRadius={8}
|
||||||
|
alignItems="center"
|
||||||
|
colorScheme="orange"
|
||||||
|
>
|
||||||
|
<AutoResizeTextArea
|
||||||
|
value={instructions}
|
||||||
|
onChange={(e) => setInstructions(e.target.value)}
|
||||||
|
onKeyDown={(e) => {
|
||||||
|
if (e.key === "Enter" && !e.metaKey && !e.ctrlKey && !e.shiftKey) {
|
||||||
|
e.preventDefault();
|
||||||
|
e.currentTarget.blur();
|
||||||
|
onSubmit();
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
placeholder="Send custom instructions"
|
||||||
|
py={4}
|
||||||
|
pl={4}
|
||||||
|
pr={12}
|
||||||
|
colorScheme="orange"
|
||||||
|
borderColor="gray.300"
|
||||||
|
borderWidth={1}
|
||||||
|
_hover={{
|
||||||
|
borderColor: "gray.300",
|
||||||
|
}}
|
||||||
|
_focus={{
|
||||||
|
borderColor: "gray.300",
|
||||||
|
}}
|
||||||
|
isDisabled={loading}
|
||||||
|
/>
|
||||||
|
<HStack></HStack>
|
||||||
|
<InputRightElement width="8" height="full">
|
||||||
|
<Button
|
||||||
|
h="8"
|
||||||
|
w="8"
|
||||||
|
minW="unset"
|
||||||
|
size="sm"
|
||||||
|
onClick={() => onSubmit()}
|
||||||
|
disabled={!instructions}
|
||||||
|
variant={instructions ? "solid" : "ghost"}
|
||||||
|
mr={4}
|
||||||
|
borderRadius="8"
|
||||||
|
bgColor={instructions ? "orange.400" : "transparent"}
|
||||||
|
colorScheme="orange"
|
||||||
|
>
|
||||||
|
{loading ? (
|
||||||
|
<Spinner boxSize={4} />
|
||||||
|
) : (
|
||||||
|
<Icon as={IoMdSend} color={instructions ? "white" : "gray.500"} boxSize={5} />
|
||||||
|
)}
|
||||||
|
</Button>
|
||||||
|
</InputRightElement>
|
||||||
|
</InputGroup>
|
||||||
|
);
|
||||||
|
};
|
||||||
64
src/components/RefinePromptModal/RefineOption.tsx
Normal file
64
src/components/RefinePromptModal/RefineOption.tsx
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
import { HStack, Icon, Heading, Text, VStack, GridItem } from "@chakra-ui/react";
|
||||||
|
import { type IconType } from "react-icons";
|
||||||
|
import { refineOptions, type RefineOptionLabel } from "./refineOptions";
|
||||||
|
|
||||||
|
export const RefineOption = ({
|
||||||
|
label,
|
||||||
|
activeLabel,
|
||||||
|
icon,
|
||||||
|
onClick,
|
||||||
|
loading,
|
||||||
|
}: {
|
||||||
|
label: RefineOptionLabel;
|
||||||
|
activeLabel: RefineOptionLabel | undefined;
|
||||||
|
icon: IconType;
|
||||||
|
onClick: (label: RefineOptionLabel) => void;
|
||||||
|
loading: boolean;
|
||||||
|
}) => {
|
||||||
|
const isActive = activeLabel === label;
|
||||||
|
const desciption = refineOptions[label].description;
|
||||||
|
|
||||||
|
return (
|
||||||
|
<GridItem w="80" h="44">
|
||||||
|
<VStack
|
||||||
|
w="full"
|
||||||
|
h="full"
|
||||||
|
onClick={() => {
|
||||||
|
!loading && onClick(label);
|
||||||
|
}}
|
||||||
|
borderColor={isActive ? "blue.500" : "gray.200"}
|
||||||
|
borderWidth={2}
|
||||||
|
borderRadius={16}
|
||||||
|
padding={6}
|
||||||
|
backgroundColor="gray.50"
|
||||||
|
_hover={
|
||||||
|
loading
|
||||||
|
? undefined
|
||||||
|
: {
|
||||||
|
backgroundColor: "gray.100",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
spacing={8}
|
||||||
|
boxShadow="0 0 40px 4px rgba(0, 0, 0, 0.1);"
|
||||||
|
cursor="pointer"
|
||||||
|
opacity={loading ? 0.5 : 1}
|
||||||
|
>
|
||||||
|
<HStack cursor="pointer" spacing={6} fontSize="sm" fontWeight="medium" color="gray.500">
|
||||||
|
<Icon as={icon} boxSize={12} />
|
||||||
|
<Heading size="md" fontFamily="inconsolata, monospace">
|
||||||
|
{label}
|
||||||
|
</Heading>
|
||||||
|
</HStack>
|
||||||
|
<Text
|
||||||
|
fontSize="sm"
|
||||||
|
color="gray.500"
|
||||||
|
flexWrap="wrap"
|
||||||
|
wordBreak="break-word"
|
||||||
|
overflowWrap="break-word"
|
||||||
|
>
|
||||||
|
{desciption}
|
||||||
|
</Text>
|
||||||
|
</VStack>
|
||||||
|
</GridItem>
|
||||||
|
);
|
||||||
|
};
|
||||||
141
src/components/RefinePromptModal/RefinePromptModal.tsx
Normal file
141
src/components/RefinePromptModal/RefinePromptModal.tsx
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
import {
|
||||||
|
Button,
|
||||||
|
Modal,
|
||||||
|
ModalBody,
|
||||||
|
ModalCloseButton,
|
||||||
|
ModalContent,
|
||||||
|
ModalFooter,
|
||||||
|
ModalHeader,
|
||||||
|
ModalOverlay,
|
||||||
|
VStack,
|
||||||
|
Text,
|
||||||
|
Spinner,
|
||||||
|
HStack,
|
||||||
|
Icon,
|
||||||
|
SimpleGrid,
|
||||||
|
} from "@chakra-ui/react";
|
||||||
|
import { BsStars } from "react-icons/bs";
|
||||||
|
import { VscJson } from "react-icons/vsc";
|
||||||
|
import { TfiThought } from "react-icons/tfi";
|
||||||
|
import { api } from "~/utils/api";
|
||||||
|
import { useHandledAsyncCallback } from "~/utils/hooks";
|
||||||
|
import { type PromptVariant } from "@prisma/client";
|
||||||
|
import { useState } from "react";
|
||||||
|
import CompareFunctions from "./CompareFunctions";
|
||||||
|
import { CustomInstructionsInput } from "./CustomInstructionsInput";
|
||||||
|
import { type RefineOptionLabel, refineOptions } from "./refineOptions";
|
||||||
|
import { RefineOption } from "./RefineOption";
|
||||||
|
import { isObject, isString } from "lodash-es";
|
||||||
|
|
||||||
|
export const RefinePromptModal = ({
|
||||||
|
variant,
|
||||||
|
onClose,
|
||||||
|
}: {
|
||||||
|
variant: PromptVariant;
|
||||||
|
onClose: () => void;
|
||||||
|
}) => {
|
||||||
|
const utils = api.useContext();
|
||||||
|
|
||||||
|
const { mutateAsync: getRefinedPromptMutateAsync, data: refinedPromptFn } =
|
||||||
|
api.promptVariants.getRefinedPromptFn.useMutation();
|
||||||
|
const [instructions, setInstructions] = useState<string>("");
|
||||||
|
|
||||||
|
const [activeRefineOptionLabel, setActiveRefineOptionLabel] = useState<
|
||||||
|
RefineOptionLabel | undefined
|
||||||
|
>(undefined);
|
||||||
|
|
||||||
|
const [getRefinedPromptFn, refiningInProgress] = useHandledAsyncCallback(
|
||||||
|
async (label?: RefineOptionLabel) => {
|
||||||
|
if (!variant.experimentId) return;
|
||||||
|
const updatedInstructions = label ? refineOptions[label].instructions : instructions;
|
||||||
|
setActiveRefineOptionLabel(label);
|
||||||
|
await getRefinedPromptMutateAsync({
|
||||||
|
id: variant.id,
|
||||||
|
instructions: updatedInstructions,
|
||||||
|
});
|
||||||
|
},
|
||||||
|
[getRefinedPromptMutateAsync, onClose, variant, instructions, setActiveRefineOptionLabel],
|
||||||
|
);
|
||||||
|
|
||||||
|
const replaceVariantMutation = api.promptVariants.replaceVariant.useMutation();
|
||||||
|
|
||||||
|
const [replaceVariant, replacementInProgress] = useHandledAsyncCallback(async () => {
|
||||||
|
if (
|
||||||
|
!variant.experimentId ||
|
||||||
|
!refinedPromptFn ||
|
||||||
|
(isObject(refinedPromptFn) && "status" in refinedPromptFn)
|
||||||
|
)
|
||||||
|
return;
|
||||||
|
await replaceVariantMutation.mutateAsync({
|
||||||
|
id: variant.id,
|
||||||
|
constructFn: refinedPromptFn,
|
||||||
|
});
|
||||||
|
await utils.promptVariants.list.invalidate();
|
||||||
|
onClose();
|
||||||
|
}, [replaceVariantMutation, variant, onClose, refinedPromptFn]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Modal isOpen onClose={onClose} size={{ base: "xl", sm: "2xl", md: "7xl" }}>
|
||||||
|
<ModalOverlay />
|
||||||
|
<ModalContent w={1200}>
|
||||||
|
<ModalHeader>
|
||||||
|
<HStack>
|
||||||
|
<Icon as={BsStars} />
|
||||||
|
<Text>Refine with GPT-4</Text>
|
||||||
|
</HStack>
|
||||||
|
</ModalHeader>
|
||||||
|
<ModalCloseButton />
|
||||||
|
<ModalBody maxW="unset">
|
||||||
|
<VStack spacing={8}>
|
||||||
|
<VStack spacing={4}>
|
||||||
|
<SimpleGrid columns={{ base: 1, md: 2 }} spacing={8}>
|
||||||
|
<RefineOption
|
||||||
|
label="Convert to function call"
|
||||||
|
activeLabel={activeRefineOptionLabel}
|
||||||
|
icon={VscJson}
|
||||||
|
onClick={getRefinedPromptFn}
|
||||||
|
loading={refiningInProgress}
|
||||||
|
/>
|
||||||
|
<RefineOption
|
||||||
|
label="Add chain of thought"
|
||||||
|
activeLabel={activeRefineOptionLabel}
|
||||||
|
icon={TfiThought}
|
||||||
|
onClick={getRefinedPromptFn}
|
||||||
|
loading={refiningInProgress}
|
||||||
|
/>
|
||||||
|
</SimpleGrid>
|
||||||
|
<HStack>
|
||||||
|
<Text color="gray.500">or</Text>
|
||||||
|
</HStack>
|
||||||
|
<CustomInstructionsInput
|
||||||
|
instructions={instructions}
|
||||||
|
setInstructions={setInstructions}
|
||||||
|
loading={refiningInProgress}
|
||||||
|
onSubmit={getRefinedPromptFn}
|
||||||
|
/>
|
||||||
|
</VStack>
|
||||||
|
<CompareFunctions
|
||||||
|
originalFunction={variant.constructFn}
|
||||||
|
newFunction={isString(refinedPromptFn) ? refinedPromptFn : undefined}
|
||||||
|
/>
|
||||||
|
</VStack>
|
||||||
|
</ModalBody>
|
||||||
|
|
||||||
|
<ModalFooter>
|
||||||
|
<HStack spacing={4}>
|
||||||
|
<Button
|
||||||
|
onClick={replaceVariant}
|
||||||
|
minW={24}
|
||||||
|
disabled={replacementInProgress || !refinedPromptFn}
|
||||||
|
_disabled={{
|
||||||
|
bgColor: "blue.500",
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{replacementInProgress ? <Spinner boxSize={4} /> : <Text>Accept</Text>}
|
||||||
|
</Button>
|
||||||
|
</HStack>
|
||||||
|
</ModalFooter>
|
||||||
|
</ModalContent>
|
||||||
|
</Modal>
|
||||||
|
);
|
||||||
|
};
|
||||||
237
src/components/RefinePromptModal/refineOptions.ts
Normal file
237
src/components/RefinePromptModal/refineOptions.ts
Normal file
@@ -0,0 +1,237 @@
|
|||||||
|
// Super hacky, but we'll redo the organization when we have more models
|
||||||
|
|
||||||
|
export type RefineOptionLabel = "Add chain of thought" | "Convert to function call";
|
||||||
|
|
||||||
|
export const refineOptions: Record<
|
||||||
|
RefineOptionLabel,
|
||||||
|
{ description: string; instructions: string }
|
||||||
|
> = {
|
||||||
|
"Add chain of thought": {
|
||||||
|
description: "Asking the model to plan its answer can increase accuracy.",
|
||||||
|
instructions: `Adding chain of thought means asking the model to think about its answer before it gives it to you. This is useful for getting more accurate answers. Do not add an assistant message.
|
||||||
|
|
||||||
|
This is what a prompt looks like before adding chain of thought:
|
||||||
|
|
||||||
|
prompt = {
|
||||||
|
model: "gpt-4",
|
||||||
|
stream: true,
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
role: "system",
|
||||||
|
content: \`Evaluate sentiment.\`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
role: "user",
|
||||||
|
content: \`This is the user's message: \${scenario.user_message}. Return "positive" or "negative" or "neutral"\`,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
};
|
||||||
|
|
||||||
|
This is what one looks like after adding chain of thought:
|
||||||
|
|
||||||
|
prompt = {
|
||||||
|
model: "gpt-4",
|
||||||
|
stream: true,
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
role: "system",
|
||||||
|
content: \`Evaluate sentiment.\`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
role: "user",
|
||||||
|
content: \`This is the user's message: \${scenario.user_message}. Return "positive" or "negative" or "neutral". Explain your answer before you give a score, then return the score on a new line.\`,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
};
|
||||||
|
|
||||||
|
Here's another example:
|
||||||
|
|
||||||
|
Before:
|
||||||
|
|
||||||
|
prompt = {
|
||||||
|
model: "gpt-3.5-turbo",
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
role: "user",
|
||||||
|
content: \`Title: \${scenario.title}
|
||||||
|
Body: \${scenario.body}
|
||||||
|
|
||||||
|
Need: \${scenario.need}
|
||||||
|
|
||||||
|
Rate likelihood on 1-3 scale.\`,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
temperature: 0,
|
||||||
|
functions: [
|
||||||
|
{
|
||||||
|
name: "score_post",
|
||||||
|
parameters: {
|
||||||
|
type: "object",
|
||||||
|
properties: {
|
||||||
|
score: {
|
||||||
|
type: "number",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
function_call: {
|
||||||
|
name: "score_post",
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
After:
|
||||||
|
|
||||||
|
prompt = {
|
||||||
|
model: "gpt-3.5-turbo",
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
role: "user",
|
||||||
|
content: \`Title: \${scenario.title}
|
||||||
|
Body: \${scenario.body}
|
||||||
|
|
||||||
|
Need: \${scenario.need}
|
||||||
|
|
||||||
|
Rate likelihood on 1-3 scale. Provide an explanation, but always provide a score afterward.\`,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
temperature: 0,
|
||||||
|
functions: [
|
||||||
|
{
|
||||||
|
name: "score_post",
|
||||||
|
parameters: {
|
||||||
|
type: "object",
|
||||||
|
properties: {
|
||||||
|
explanation: {
|
||||||
|
type: "string",
|
||||||
|
}
|
||||||
|
score: {
|
||||||
|
type: "number",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
function_call: {
|
||||||
|
name: "score_post",
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
Add chain of thought to the original prompt.`,
|
||||||
|
},
|
||||||
|
"Convert to function call": {
|
||||||
|
description: "Use function calls to get output from the model in a more structured way.",
|
||||||
|
instructions: `OpenAI functions are a specialized way for an LLM to return output.
|
||||||
|
|
||||||
|
This is what a prompt looks like before adding a function:
|
||||||
|
|
||||||
|
prompt = {
|
||||||
|
model: "gpt-4",
|
||||||
|
stream: true,
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
role: "system",
|
||||||
|
content: \`Evaluate sentiment.\`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
role: "user",
|
||||||
|
content: \`This is the user's message: \${scenario.user_message}. Return "positive" or "negative" or "neutral"\`,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
};
|
||||||
|
|
||||||
|
This is what one looks like after adding a function:
|
||||||
|
|
||||||
|
prompt = {
|
||||||
|
model: "gpt-4",
|
||||||
|
stream: true,
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
role: "system",
|
||||||
|
content: "Evaluate sentiment.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
role: "user",
|
||||||
|
content: scenario.user_message,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
functions: [
|
||||||
|
{
|
||||||
|
name: "extract_sentiment",
|
||||||
|
parameters: {
|
||||||
|
type: "object", // parameters must always be an object with a properties key
|
||||||
|
properties: { // properties key is required
|
||||||
|
sentiment: {
|
||||||
|
type: "string",
|
||||||
|
description: "one of positive/negative/neutral",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
function_call: {
|
||||||
|
name: "extract_sentiment",
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
Here's another example of adding a function:
|
||||||
|
|
||||||
|
Before:
|
||||||
|
|
||||||
|
prompt = {
|
||||||
|
model: "gpt-3.5-turbo",
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
role: "user",
|
||||||
|
content: \`Here is the title and body of a reddit post I am interested in:
|
||||||
|
|
||||||
|
title: \${scenario.title}
|
||||||
|
body: \${scenario.body}
|
||||||
|
|
||||||
|
On a scale from 1 to 3, how likely is it that the person writing this post has the following need? If you are not sure, make your best guess, or answer 1.
|
||||||
|
|
||||||
|
Need: \${scenario.need}
|
||||||
|
|
||||||
|
Answer one integer between 1 and 3.\`,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
temperature: 0,
|
||||||
|
};
|
||||||
|
|
||||||
|
After:
|
||||||
|
|
||||||
|
prompt = {
|
||||||
|
model: "gpt-3.5-turbo",
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
role: "user",
|
||||||
|
content: \`Title: \${scenario.title}
|
||||||
|
Body: \${scenario.body}
|
||||||
|
|
||||||
|
Need: \${scenario.need}
|
||||||
|
|
||||||
|
Rate likelihood on 1-3 scale.\`,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
temperature: 0,
|
||||||
|
functions: [
|
||||||
|
{
|
||||||
|
name: "score_post",
|
||||||
|
parameters: {
|
||||||
|
type: "object",
|
||||||
|
properties: {
|
||||||
|
score: {
|
||||||
|
type: "number",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
function_call: {
|
||||||
|
name: "score_post",
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
Add an OpenAI function that takes one or more nested parameters that match the expected output from this prompt.`,
|
||||||
|
},
|
||||||
|
};
|
||||||
89
src/components/SelectModelModal/ModelStatsCard.tsx
Normal file
89
src/components/SelectModelModal/ModelStatsCard.tsx
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
import {
|
||||||
|
VStack,
|
||||||
|
Text,
|
||||||
|
HStack,
|
||||||
|
type StackProps,
|
||||||
|
GridItem,
|
||||||
|
SimpleGrid,
|
||||||
|
Link,
|
||||||
|
} from "@chakra-ui/react";
|
||||||
|
import { modelStats } from "~/modelProviders/modelStats";
|
||||||
|
import { type SupportedModel } from "~/server/types";
|
||||||
|
|
||||||
|
export const ModelStatsCard = ({ label, model }: { label: string; model: SupportedModel }) => {
|
||||||
|
const stats = modelStats[model];
|
||||||
|
return (
|
||||||
|
<VStack w="full" align="start">
|
||||||
|
<Text fontWeight="bold" fontSize="sm" textTransform="uppercase">
|
||||||
|
{label}
|
||||||
|
</Text>
|
||||||
|
|
||||||
|
<VStack w="full" spacing={6} bgColor="gray.100" p={4} borderRadius={4}>
|
||||||
|
<HStack w="full" align="flex-start">
|
||||||
|
<Text flex={1} fontSize="lg">
|
||||||
|
<Text as="span" color="gray.600">
|
||||||
|
{stats.provider} /{" "}
|
||||||
|
</Text>
|
||||||
|
<Text as="span" fontWeight="bold" color="gray.900">
|
||||||
|
{model}
|
||||||
|
</Text>
|
||||||
|
</Text>
|
||||||
|
<Link
|
||||||
|
href={stats.learnMoreUrl}
|
||||||
|
isExternal
|
||||||
|
color="blue.500"
|
||||||
|
fontWeight="bold"
|
||||||
|
fontSize="sm"
|
||||||
|
ml={2}
|
||||||
|
>
|
||||||
|
Learn More
|
||||||
|
</Link>
|
||||||
|
</HStack>
|
||||||
|
<SimpleGrid
|
||||||
|
w="full"
|
||||||
|
justifyContent="space-between"
|
||||||
|
alignItems="flex-start"
|
||||||
|
fontSize="sm"
|
||||||
|
columns={{ base: 2, md: 4 }}
|
||||||
|
>
|
||||||
|
<SelectedModelLabeledInfo label="Context" info={stats.contextLength} />
|
||||||
|
<SelectedModelLabeledInfo
|
||||||
|
label="Input"
|
||||||
|
info={
|
||||||
|
<Text>
|
||||||
|
${(stats.promptTokenPrice * 1000).toFixed(3)}
|
||||||
|
<Text color="gray.500"> / 1K tokens</Text>
|
||||||
|
</Text>
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
<SelectedModelLabeledInfo
|
||||||
|
label="Output"
|
||||||
|
info={
|
||||||
|
<Text>
|
||||||
|
${(stats.promptTokenPrice * 1000).toFixed(3)}
|
||||||
|
<Text color="gray.500"> / 1K tokens</Text>
|
||||||
|
</Text>
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
<SelectedModelLabeledInfo label="Speed" info={<Text>{stats.speed}</Text>} />
|
||||||
|
</SimpleGrid>
|
||||||
|
</VStack>
|
||||||
|
</VStack>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
const SelectedModelLabeledInfo = ({
|
||||||
|
label,
|
||||||
|
info,
|
||||||
|
...props
|
||||||
|
}: {
|
||||||
|
label: string;
|
||||||
|
info: string | number | React.ReactElement;
|
||||||
|
} & StackProps) => (
|
||||||
|
<GridItem>
|
||||||
|
<VStack alignItems="flex-start" {...props}>
|
||||||
|
<Text fontWeight="bold">{label}</Text>
|
||||||
|
<Text>{info}</Text>
|
||||||
|
</VStack>
|
||||||
|
</GridItem>
|
||||||
|
);
|
||||||
85
src/components/SelectModelModal/SelectModelModal.tsx
Normal file
85
src/components/SelectModelModal/SelectModelModal.tsx
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
import {
|
||||||
|
Button,
|
||||||
|
Modal,
|
||||||
|
ModalBody,
|
||||||
|
ModalCloseButton,
|
||||||
|
ModalContent,
|
||||||
|
ModalFooter,
|
||||||
|
ModalHeader,
|
||||||
|
ModalOverlay,
|
||||||
|
VStack,
|
||||||
|
Text,
|
||||||
|
Spinner,
|
||||||
|
HStack,
|
||||||
|
Icon,
|
||||||
|
} from "@chakra-ui/react";
|
||||||
|
import { RiExchangeFundsFill } from "react-icons/ri";
|
||||||
|
import { useState } from "react";
|
||||||
|
import { type SupportedModel } from "~/server/types";
|
||||||
|
import { ModelStatsCard } from "./ModelStatsCard";
|
||||||
|
import { SelectModelSearch } from "./SelectModelSearch";
|
||||||
|
import { api } from "~/utils/api";
|
||||||
|
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
|
||||||
|
|
||||||
|
export const SelectModelModal = ({
|
||||||
|
originalModel,
|
||||||
|
variantId,
|
||||||
|
onClose,
|
||||||
|
}: {
|
||||||
|
originalModel: SupportedModel;
|
||||||
|
variantId: string;
|
||||||
|
onClose: () => void;
|
||||||
|
}) => {
|
||||||
|
const [selectedModel, setSelectedModel] = useState<SupportedModel>(originalModel);
|
||||||
|
const utils = api.useContext();
|
||||||
|
|
||||||
|
const experiment = useExperiment();
|
||||||
|
|
||||||
|
const createMutation = api.promptVariants.create.useMutation();
|
||||||
|
|
||||||
|
const [createNewVariant, creationInProgress] = useHandledAsyncCallback(async () => {
|
||||||
|
if (!experiment?.data?.id) return;
|
||||||
|
await createMutation.mutateAsync({
|
||||||
|
experimentId: experiment?.data?.id,
|
||||||
|
variantId,
|
||||||
|
newModel: selectedModel,
|
||||||
|
});
|
||||||
|
await utils.promptVariants.list.invalidate();
|
||||||
|
onClose();
|
||||||
|
}, [createMutation, experiment?.data?.id, variantId, onClose]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Modal isOpen onClose={onClose} size={{ base: "xl", sm: "2xl", md: "3xl" }}>
|
||||||
|
<ModalOverlay />
|
||||||
|
<ModalContent w={1200}>
|
||||||
|
<ModalHeader>
|
||||||
|
<HStack>
|
||||||
|
<Icon as={RiExchangeFundsFill} />
|
||||||
|
<Text>Change Model</Text>
|
||||||
|
</HStack>
|
||||||
|
</ModalHeader>
|
||||||
|
<ModalCloseButton />
|
||||||
|
<ModalBody maxW="unset">
|
||||||
|
<VStack spacing={8}>
|
||||||
|
<ModelStatsCard label="Original Model" model={originalModel} />
|
||||||
|
{originalModel !== selectedModel && (
|
||||||
|
<ModelStatsCard label="New Model" model={selectedModel} />
|
||||||
|
)}
|
||||||
|
<SelectModelSearch selectedModel={selectedModel} setSelectedModel={setSelectedModel} />
|
||||||
|
</VStack>
|
||||||
|
</ModalBody>
|
||||||
|
|
||||||
|
<ModalFooter>
|
||||||
|
<Button
|
||||||
|
colorScheme="blue"
|
||||||
|
onClick={createNewVariant}
|
||||||
|
minW={24}
|
||||||
|
disabled={originalModel === selectedModel}
|
||||||
|
>
|
||||||
|
{creationInProgress ? <Spinner boxSize={4} /> : <Text>Continue</Text>}
|
||||||
|
</Button>
|
||||||
|
</ModalFooter>
|
||||||
|
</ModalContent>
|
||||||
|
</Modal>
|
||||||
|
);
|
||||||
|
};
|
||||||
47
src/components/SelectModelModal/SelectModelSearch.tsx
Normal file
47
src/components/SelectModelModal/SelectModelSearch.tsx
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
import { VStack, Text } from "@chakra-ui/react";
|
||||||
|
import { type LegacyRef, useCallback } from "react";
|
||||||
|
import Select, { type SingleValue } from "react-select";
|
||||||
|
import { type SupportedModel } from "~/server/types";
|
||||||
|
import { useElementDimensions } from "~/utils/hooks";
|
||||||
|
|
||||||
|
const modelOptions: { value: SupportedModel; label: string }[] = [
|
||||||
|
{ value: "gpt-3.5-turbo", label: "gpt-3.5-turbo" },
|
||||||
|
{ value: "gpt-3.5-turbo-0613", label: "gpt-3.5-turbo-0613" },
|
||||||
|
{ value: "gpt-3.5-turbo-16k", label: "gpt-3.5-turbo-16k" },
|
||||||
|
{ value: "gpt-3.5-turbo-16k-0613", label: "gpt-3.5-turbo-16k-0613" },
|
||||||
|
{ value: "gpt-4", label: "gpt-4" },
|
||||||
|
{ value: "gpt-4-0613", label: "gpt-4-0613" },
|
||||||
|
{ value: "gpt-4-32k", label: "gpt-4-32k" },
|
||||||
|
{ value: "gpt-4-32k-0613", label: "gpt-4-32k-0613" },
|
||||||
|
];
|
||||||
|
|
||||||
|
export const SelectModelSearch = ({
|
||||||
|
selectedModel,
|
||||||
|
setSelectedModel,
|
||||||
|
}: {
|
||||||
|
selectedModel: SupportedModel;
|
||||||
|
setSelectedModel: (model: SupportedModel) => void;
|
||||||
|
}) => {
|
||||||
|
const handleSelection = useCallback(
|
||||||
|
(option: SingleValue<{ value: SupportedModel; label: string }>) => {
|
||||||
|
if (!option) return;
|
||||||
|
setSelectedModel(option.value);
|
||||||
|
},
|
||||||
|
[setSelectedModel],
|
||||||
|
);
|
||||||
|
const selectedOption = modelOptions.find((option) => option.value === selectedModel);
|
||||||
|
|
||||||
|
const [containerRef, containerDimensions] = useElementDimensions();
|
||||||
|
|
||||||
|
return (
|
||||||
|
<VStack ref={containerRef as LegacyRef<HTMLDivElement>} w="full">
|
||||||
|
<Text>Browse Models</Text>
|
||||||
|
<Select
|
||||||
|
styles={{ control: (provided) => ({ ...provided, width: containerDimensions?.width }) }}
|
||||||
|
value={selectedOption}
|
||||||
|
options={modelOptions}
|
||||||
|
onChange={handleSelection}
|
||||||
|
/>
|
||||||
|
</VStack>
|
||||||
|
);
|
||||||
|
};
|
||||||
123
src/components/VariantHeader/VariantHeader.tsx
Normal file
123
src/components/VariantHeader/VariantHeader.tsx
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
import { useState, type DragEvent } from "react";
|
||||||
|
import { type PromptVariant } from "../OutputsTable/types";
|
||||||
|
import { api } from "~/utils/api";
|
||||||
|
import { RiDraggable } from "react-icons/ri";
|
||||||
|
import { useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
|
||||||
|
import { HStack, Icon, Text, GridItem } from "@chakra-ui/react"; // Changed here
|
||||||
|
import { cellPadding, headerMinHeight } from "../constants";
|
||||||
|
import AutoResizeTextArea from "../AutoResizeTextArea";
|
||||||
|
import { stickyHeaderStyle } from "../OutputsTable/styles";
|
||||||
|
import VariantHeaderMenuButton from "./VariantHeaderMenuButton";
|
||||||
|
|
||||||
|
export default function VariantHeader(props: { variant: PromptVariant; canHide: boolean }) {
|
||||||
|
const { canModify } = useExperimentAccess();
|
||||||
|
const utils = api.useContext();
|
||||||
|
const [isDragTarget, setIsDragTarget] = useState(false);
|
||||||
|
const [isInputHovered, setIsInputHovered] = useState(false);
|
||||||
|
const [label, setLabel] = useState(props.variant.label);
|
||||||
|
|
||||||
|
const updateMutation = api.promptVariants.update.useMutation();
|
||||||
|
const [onSaveLabel] = useHandledAsyncCallback(async () => {
|
||||||
|
if (label && label !== props.variant.label) {
|
||||||
|
await updateMutation.mutateAsync({
|
||||||
|
id: props.variant.id,
|
||||||
|
updates: { label: label },
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}, [updateMutation, props.variant.id, props.variant.label, label]);
|
||||||
|
|
||||||
|
const reorderMutation = api.promptVariants.reorder.useMutation();
|
||||||
|
const [onReorder] = useHandledAsyncCallback(
|
||||||
|
async (e: DragEvent<HTMLDivElement>) => {
|
||||||
|
e.preventDefault();
|
||||||
|
setIsDragTarget(false);
|
||||||
|
const draggedId = e.dataTransfer.getData("text/plain");
|
||||||
|
const droppedId = props.variant.id;
|
||||||
|
if (!draggedId || !droppedId || draggedId === droppedId) return;
|
||||||
|
await reorderMutation.mutateAsync({
|
||||||
|
draggedId,
|
||||||
|
droppedId,
|
||||||
|
});
|
||||||
|
await utils.promptVariants.list.invalidate();
|
||||||
|
},
|
||||||
|
[reorderMutation, props.variant.id],
|
||||||
|
);
|
||||||
|
|
||||||
|
const [menuOpen, setMenuOpen] = useState(false);
|
||||||
|
|
||||||
|
if (!canModify) {
|
||||||
|
return (
|
||||||
|
<GridItem padding={0} sx={stickyHeaderStyle} borderTopWidth={1}>
|
||||||
|
<Text fontSize={16} fontWeight="bold" px={cellPadding.x} py={cellPadding.y}>
|
||||||
|
{props.variant.label}
|
||||||
|
</Text>
|
||||||
|
</GridItem>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<GridItem
|
||||||
|
padding={0}
|
||||||
|
sx={{
|
||||||
|
...stickyHeaderStyle,
|
||||||
|
// Ensure that the menu always appears above the sticky header of other variants
|
||||||
|
zIndex: menuOpen ? "dropdown" : stickyHeaderStyle.zIndex,
|
||||||
|
}}
|
||||||
|
borderTopWidth={1}
|
||||||
|
>
|
||||||
|
<HStack
|
||||||
|
spacing={4}
|
||||||
|
alignItems="flex-start"
|
||||||
|
minH={headerMinHeight}
|
||||||
|
draggable={!isInputHovered}
|
||||||
|
onDragStart={(e) => {
|
||||||
|
e.dataTransfer.setData("text/plain", props.variant.id);
|
||||||
|
e.currentTarget.style.opacity = "0.4";
|
||||||
|
}}
|
||||||
|
onDragEnd={(e) => {
|
||||||
|
e.currentTarget.style.opacity = "1";
|
||||||
|
}}
|
||||||
|
onDragOver={(e) => {
|
||||||
|
e.preventDefault();
|
||||||
|
setIsDragTarget(true);
|
||||||
|
}}
|
||||||
|
onDragLeave={() => {
|
||||||
|
setIsDragTarget(false);
|
||||||
|
}}
|
||||||
|
onDrop={onReorder}
|
||||||
|
backgroundColor={isDragTarget ? "gray.100" : "transparent"}
|
||||||
|
>
|
||||||
|
<Icon
|
||||||
|
as={RiDraggable}
|
||||||
|
boxSize={6}
|
||||||
|
mt={2}
|
||||||
|
color="gray.400"
|
||||||
|
_hover={{ color: "gray.800", cursor: "pointer" }}
|
||||||
|
/>
|
||||||
|
<AutoResizeTextArea
|
||||||
|
size="sm"
|
||||||
|
value={label}
|
||||||
|
onChange={(e) => setLabel(e.target.value)}
|
||||||
|
onBlur={onSaveLabel}
|
||||||
|
placeholder="Variant Name"
|
||||||
|
borderWidth={1}
|
||||||
|
borderColor="transparent"
|
||||||
|
fontWeight="bold"
|
||||||
|
fontSize={16}
|
||||||
|
_hover={{ borderColor: "gray.300" }}
|
||||||
|
_focus={{ borderColor: "blue.500", outline: "none" }}
|
||||||
|
flex={1}
|
||||||
|
px={cellPadding.x}
|
||||||
|
onMouseEnter={() => setIsInputHovered(true)}
|
||||||
|
onMouseLeave={() => setIsInputHovered(false)}
|
||||||
|
/>
|
||||||
|
<VariantHeaderMenuButton
|
||||||
|
variant={props.variant}
|
||||||
|
canHide={props.canHide}
|
||||||
|
menuOpen={menuOpen}
|
||||||
|
setMenuOpen={setMenuOpen}
|
||||||
|
/>
|
||||||
|
</HStack>
|
||||||
|
</GridItem>
|
||||||
|
);
|
||||||
|
}
|
||||||
113
src/components/VariantHeader/VariantHeaderMenuButton.tsx
Normal file
113
src/components/VariantHeader/VariantHeaderMenuButton.tsx
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
import { type PromptVariant } from "../OutputsTable/types";
|
||||||
|
import { api } from "~/utils/api";
|
||||||
|
import { useHandledAsyncCallback } from "~/utils/hooks";
|
||||||
|
import {
|
||||||
|
Button,
|
||||||
|
Icon,
|
||||||
|
Menu,
|
||||||
|
MenuButton,
|
||||||
|
MenuItem,
|
||||||
|
MenuList,
|
||||||
|
MenuDivider,
|
||||||
|
Text,
|
||||||
|
Spinner,
|
||||||
|
} from "@chakra-ui/react";
|
||||||
|
import { BsFillTrashFill, BsGear, BsStars } from "react-icons/bs";
|
||||||
|
import { FaRegClone } from "react-icons/fa";
|
||||||
|
import { useState } from "react";
|
||||||
|
import { RefinePromptModal } from "../RefinePromptModal/RefinePromptModal";
|
||||||
|
import { RiExchangeFundsFill } from "react-icons/ri";
|
||||||
|
import { SelectModelModal } from "../SelectModelModal/SelectModelModal";
|
||||||
|
import { type SupportedModel } from "~/server/types";
|
||||||
|
|
||||||
|
export default function VariantHeaderMenuButton({
|
||||||
|
variant,
|
||||||
|
canHide,
|
||||||
|
menuOpen,
|
||||||
|
setMenuOpen,
|
||||||
|
}: {
|
||||||
|
variant: PromptVariant;
|
||||||
|
canHide: boolean;
|
||||||
|
menuOpen: boolean;
|
||||||
|
setMenuOpen: (open: boolean) => void;
|
||||||
|
}) {
|
||||||
|
const utils = api.useContext();
|
||||||
|
|
||||||
|
const duplicateMutation = api.promptVariants.create.useMutation();
|
||||||
|
|
||||||
|
const [duplicateVariant, duplicationInProgress] = useHandledAsyncCallback(async () => {
|
||||||
|
await duplicateMutation.mutateAsync({
|
||||||
|
experimentId: variant.experimentId,
|
||||||
|
variantId: variant.id,
|
||||||
|
});
|
||||||
|
await utils.promptVariants.list.invalidate();
|
||||||
|
}, [duplicateMutation, variant.experimentId, variant.id]);
|
||||||
|
|
||||||
|
const hideMutation = api.promptVariants.hide.useMutation();
|
||||||
|
const [onHide] = useHandledAsyncCallback(async () => {
|
||||||
|
await hideMutation.mutateAsync({
|
||||||
|
id: variant.id,
|
||||||
|
});
|
||||||
|
await utils.promptVariants.list.invalidate();
|
||||||
|
}, [hideMutation, variant.id]);
|
||||||
|
|
||||||
|
const [selectModelModalOpen, setSelectModelModalOpen] = useState(false);
|
||||||
|
const [refinePromptModalOpen, setRefinePromptModalOpen] = useState(false);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
<Menu isOpen={menuOpen} onOpen={() => setMenuOpen(true)} onClose={() => setMenuOpen(false)}>
|
||||||
|
{duplicationInProgress ? (
|
||||||
|
<Spinner boxSize={4} mx={3} my={3} />
|
||||||
|
) : (
|
||||||
|
<MenuButton>
|
||||||
|
<Button variant="ghost">
|
||||||
|
<Icon as={BsGear} />
|
||||||
|
</Button>
|
||||||
|
</MenuButton>
|
||||||
|
)}
|
||||||
|
|
||||||
|
<MenuList mt={-3} fontSize="md">
|
||||||
|
<MenuItem icon={<Icon as={FaRegClone} boxSize={4} w={5} />} onClick={duplicateVariant}>
|
||||||
|
Duplicate
|
||||||
|
</MenuItem>
|
||||||
|
<MenuItem
|
||||||
|
icon={<Icon as={RiExchangeFundsFill} boxSize={5} />}
|
||||||
|
onClick={() => setSelectModelModalOpen(true)}
|
||||||
|
>
|
||||||
|
Change Model
|
||||||
|
</MenuItem>
|
||||||
|
<MenuItem
|
||||||
|
icon={<Icon as={BsStars} boxSize={5} />}
|
||||||
|
onClick={() => setRefinePromptModalOpen(true)}
|
||||||
|
>
|
||||||
|
Refine
|
||||||
|
</MenuItem>
|
||||||
|
{canHide && (
|
||||||
|
<>
|
||||||
|
<MenuDivider />
|
||||||
|
<MenuItem
|
||||||
|
onClick={onHide}
|
||||||
|
icon={<Icon as={BsFillTrashFill} boxSize={5} />}
|
||||||
|
color="red.600"
|
||||||
|
_hover={{ backgroundColor: "red.50" }}
|
||||||
|
>
|
||||||
|
<Text>Hide</Text>
|
||||||
|
</MenuItem>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
</MenuList>
|
||||||
|
</Menu>
|
||||||
|
{selectModelModalOpen && (
|
||||||
|
<SelectModelModal
|
||||||
|
originalModel={variant.model as SupportedModel}
|
||||||
|
variantId={variant.id}
|
||||||
|
onClose={() => setSelectModelModalOpen(false)}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{refinePromptModalOpen && (
|
||||||
|
<RefinePromptModal variant={variant} onClose={() => setRefinePromptModalOpen(false)} />
|
||||||
|
)}
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -1,17 +1,11 @@
|
|||||||
import {
|
import { HStack, Icon, VStack, Text, Divider, Spinner, AspectRatio } from "@chakra-ui/react";
|
||||||
Card,
|
|
||||||
CardBody,
|
|
||||||
HStack,
|
|
||||||
Icon,
|
|
||||||
VStack,
|
|
||||||
Text,
|
|
||||||
CardHeader,
|
|
||||||
Divider,
|
|
||||||
Box,
|
|
||||||
} from "@chakra-ui/react";
|
|
||||||
import { RiFlaskLine } from "react-icons/ri";
|
import { RiFlaskLine } from "react-icons/ri";
|
||||||
import { formatTimePast } from "~/utils/dayjs";
|
import { formatTimePast } from "~/utils/dayjs";
|
||||||
|
import Link from "next/link";
|
||||||
import { useRouter } from "next/router";
|
import { useRouter } from "next/router";
|
||||||
|
import { BsPlusSquare } from "react-icons/bs";
|
||||||
|
import { api } from "~/utils/api";
|
||||||
|
import { useHandledAsyncCallback } from "~/utils/hooks";
|
||||||
|
|
||||||
type ExperimentData = {
|
type ExperimentData = {
|
||||||
testScenarioCount: number;
|
testScenarioCount: number;
|
||||||
@@ -24,47 +18,42 @@ type ExperimentData = {
|
|||||||
};
|
};
|
||||||
|
|
||||||
export const ExperimentCard = ({ exp }: { exp: ExperimentData }) => {
|
export const ExperimentCard = ({ exp }: { exp: ExperimentData }) => {
|
||||||
const router = useRouter();
|
|
||||||
return (
|
return (
|
||||||
<Box
|
<AspectRatio ratio={1.2} w="full">
|
||||||
as={Card}
|
<VStack
|
||||||
variant="elevated"
|
as={Link}
|
||||||
|
href={{ pathname: "/experiments/[id]", query: { id: exp.id } }}
|
||||||
bg="gray.50"
|
bg="gray.50"
|
||||||
_hover={{ bg: "gray.100" }}
|
_hover={{ bg: "gray.100" }}
|
||||||
transition="background 0.2s"
|
transition="background 0.2s"
|
||||||
cursor="pointer"
|
cursor="pointer"
|
||||||
onClick={(e) => {
|
borderColor="gray.200"
|
||||||
e.preventDefault();
|
borderWidth={1}
|
||||||
void router.push({ pathname: "/experiments/[id]", query: { id: exp.id } }, undefined, {
|
p={4}
|
||||||
shallow: true,
|
justify="space-between"
|
||||||
});
|
|
||||||
}}
|
|
||||||
>
|
>
|
||||||
<CardHeader>
|
<HStack w="full" color="gray.700" justify="center">
|
||||||
<HStack w="full" color="gray.700">
|
|
||||||
<Icon as={RiFlaskLine} boxSize={4} />
|
<Icon as={RiFlaskLine} boxSize={4} />
|
||||||
<Text fontWeight="bold">{exp.label}</Text>
|
<Text fontWeight="bold">{exp.label}</Text>
|
||||||
</HStack>
|
</HStack>
|
||||||
</CardHeader>
|
<HStack h="full" spacing={4} flex={1} align="center">
|
||||||
<CardBody>
|
|
||||||
<HStack w="full" mb={8} spacing={4}>
|
|
||||||
<CountLabel label="Variants" count={exp.promptVariantCount} />
|
<CountLabel label="Variants" count={exp.promptVariantCount} />
|
||||||
<Divider h={12} orientation="vertical" />
|
<Divider h={12} orientation="vertical" />
|
||||||
<CountLabel label="Scenarios" count={exp.testScenarioCount} />
|
<CountLabel label="Scenarios" count={exp.testScenarioCount} />
|
||||||
</HStack>
|
</HStack>
|
||||||
<HStack w="full" color="gray.500" fontSize="xs">
|
<HStack w="full" color="gray.500" fontSize="xs" textAlign="center">
|
||||||
<Text>Created {formatTimePast(exp.createdAt)}</Text>
|
<Text flex={1}>Created {formatTimePast(exp.createdAt)}</Text>
|
||||||
<Divider h={4} orientation="vertical" />
|
<Divider h={4} orientation="vertical" />
|
||||||
<Text>Updated {formatTimePast(exp.updatedAt)}</Text>
|
<Text flex={1}>Updated {formatTimePast(exp.updatedAt)}</Text>
|
||||||
</HStack>
|
</HStack>
|
||||||
</CardBody>
|
</VStack>
|
||||||
</Box>
|
</AspectRatio>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
const CountLabel = ({ label, count }: { label: string; count: number }) => {
|
const CountLabel = ({ label, count }: { label: string; count: number }) => {
|
||||||
return (
|
return (
|
||||||
<VStack alignItems="flex-start">
|
<VStack alignItems="center" flex={1}>
|
||||||
<Text color="gray.500" fontWeight="bold">
|
<Text color="gray.500" fontWeight="bold">
|
||||||
{label}
|
{label}
|
||||||
</Text>
|
</Text>
|
||||||
@@ -74,3 +63,33 @@ const CountLabel = ({ label, count }: { label: string; count: number }) => {
|
|||||||
</VStack>
|
</VStack>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export const NewExperimentCard = () => {
|
||||||
|
const router = useRouter();
|
||||||
|
const createMutation = api.experiments.create.useMutation();
|
||||||
|
const [createExperiment, isLoading] = useHandledAsyncCallback(async () => {
|
||||||
|
const newExperiment = await createMutation.mutateAsync({ label: "New Experiment" });
|
||||||
|
await router.push({ pathname: "/experiments/[id]", query: { id: newExperiment.id } });
|
||||||
|
}, [createMutation, router]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<AspectRatio ratio={1.2} w="full">
|
||||||
|
<VStack
|
||||||
|
align="center"
|
||||||
|
justify="center"
|
||||||
|
_hover={{ cursor: "pointer", bg: "gray.50" }}
|
||||||
|
transition="background 0.2s"
|
||||||
|
cursor="pointer"
|
||||||
|
borderColor="gray.200"
|
||||||
|
borderWidth={1}
|
||||||
|
p={4}
|
||||||
|
onClick={createExperiment}
|
||||||
|
>
|
||||||
|
<Icon as={isLoading ? Spinner : BsPlusSquare} boxSize={8} />
|
||||||
|
<Text display={{ base: "none", md: "block" }} ml={2}>
|
||||||
|
New Experiment
|
||||||
|
</Text>
|
||||||
|
</VStack>
|
||||||
|
</AspectRatio>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|||||||
@@ -1,31 +0,0 @@
|
|||||||
import { Icon, Button, Spinner, Text, type ButtonProps } from "@chakra-ui/react";
|
|
||||||
import { api } from "~/utils/api";
|
|
||||||
import { useRouter } from "next/router";
|
|
||||||
import { BsPlusSquare } from "react-icons/bs";
|
|
||||||
import { useHandledAsyncCallback } from "~/utils/hooks";
|
|
||||||
|
|
||||||
export const NewExperimentButton = (props: ButtonProps) => {
|
|
||||||
const router = useRouter();
|
|
||||||
const utils = api.useContext();
|
|
||||||
const createMutation = api.experiments.create.useMutation();
|
|
||||||
const [createExperiment, isLoading] = useHandledAsyncCallback(async () => {
|
|
||||||
const newExperiment = await createMutation.mutateAsync({ label: "New Experiment" });
|
|
||||||
await utils.experiments.list.invalidate();
|
|
||||||
await router.push({ pathname: "/experiments/[id]", query: { id: newExperiment.id } });
|
|
||||||
}, [createMutation, router]);
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Button
|
|
||||||
onClick={createExperiment}
|
|
||||||
display="flex"
|
|
||||||
alignItems="center"
|
|
||||||
variant={{ base: "solid", md: "ghost" }}
|
|
||||||
{...props}
|
|
||||||
>
|
|
||||||
<Icon as={isLoading ? Spinner : BsPlusSquare} boxSize={4} />
|
|
||||||
<Text display={{ base: "none", md: "block" }} ml={2}>
|
|
||||||
New Experiment
|
|
||||||
</Text>
|
|
||||||
</Button>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
@@ -1,84 +1,100 @@
|
|||||||
|
import { useState, useEffect } from "react";
|
||||||
import {
|
import {
|
||||||
Heading,
|
Heading,
|
||||||
VStack,
|
VStack,
|
||||||
Icon,
|
Icon,
|
||||||
HStack,
|
HStack,
|
||||||
Image,
|
Image,
|
||||||
Grid,
|
|
||||||
GridItem,
|
|
||||||
Divider,
|
|
||||||
Text,
|
Text,
|
||||||
Box,
|
Box,
|
||||||
type BoxProps,
|
type BoxProps,
|
||||||
type LinkProps,
|
type LinkProps,
|
||||||
Link,
|
Link,
|
||||||
|
Flex,
|
||||||
} from "@chakra-ui/react";
|
} from "@chakra-ui/react";
|
||||||
import Head from "next/head";
|
import Head from "next/head";
|
||||||
import { BsGithub, BsTwitter } from "react-icons/bs";
|
import { BsGithub, BsPersonCircle } from "react-icons/bs";
|
||||||
import { useRouter } from "next/router";
|
import { useRouter } from "next/router";
|
||||||
import PublicPlaygroundWarning from "../PublicPlaygroundWarning";
|
|
||||||
import { type IconType } from "react-icons";
|
import { type IconType } from "react-icons";
|
||||||
import { RiFlaskLine } from "react-icons/ri";
|
import { RiFlaskLine } from "react-icons/ri";
|
||||||
import { useState, useEffect } from "react";
|
import { signIn, useSession } from "next-auth/react";
|
||||||
|
import UserMenu from "./UserMenu";
|
||||||
|
|
||||||
type IconLinkProps = BoxProps & LinkProps & { label: string; icon: IconType; href: string };
|
type IconLinkProps = BoxProps & LinkProps & { label?: string; icon: IconType };
|
||||||
|
|
||||||
const IconLink = ({ icon, label, href, target, color, ...props }: IconLinkProps) => {
|
const IconLink = ({ icon, label, href, target, color, ...props }: IconLinkProps) => {
|
||||||
const isActive = useRouter().pathname.startsWith(href);
|
const router = useRouter();
|
||||||
|
const isActive = href && router.pathname.startsWith(href);
|
||||||
return (
|
return (
|
||||||
<Box
|
<HStack
|
||||||
|
w="full"
|
||||||
|
p={4}
|
||||||
|
color={color}
|
||||||
as={Link}
|
as={Link}
|
||||||
href={href}
|
href={href}
|
||||||
target={target}
|
target={target}
|
||||||
w="full"
|
bgColor={isActive ? "gray.200" : "transparent"}
|
||||||
bgColor={isActive ? "gray.300" : "transparent"}
|
_hover={{ bgColor: "gray.200", textDecoration: "none" }}
|
||||||
_hover={{ bgColor: "gray.300" }}
|
|
||||||
py={4}
|
|
||||||
justifyContent="start"
|
justifyContent="start"
|
||||||
cursor="pointer"
|
cursor="pointer"
|
||||||
{...props}
|
{...props}
|
||||||
>
|
>
|
||||||
<HStack w="full" px={4} color={color}>
|
|
||||||
<Icon as={icon} boxSize={6} mr={2} />
|
<Icon as={icon} boxSize={6} mr={2} />
|
||||||
<Text fontWeight="bold">{label}</Text>
|
<Text fontWeight="bold" fontSize="sm">
|
||||||
|
{label}
|
||||||
|
</Text>
|
||||||
</HStack>
|
</HStack>
|
||||||
</Box>
|
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const Divider = () => <Box h="1px" bgColor="gray.200" />;
|
||||||
|
|
||||||
const NavSidebar = () => {
|
const NavSidebar = () => {
|
||||||
|
const user = useSession().data;
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<VStack align="stretch" bgColor="gray.100" py={2} pb={0} height="100%">
|
<VStack
|
||||||
<Link href="/" w="full" _hover={{ textDecoration: "none" }}>
|
align="stretch"
|
||||||
<HStack spacing={0} pl="3">
|
bgColor="gray.100"
|
||||||
<Image src="/logo.svg" alt="" w={8} h={8} />
|
py={2}
|
||||||
<Heading size="md" p={2} pl={{ base: 16, md: 2 }}>
|
pb={0}
|
||||||
|
height="100%"
|
||||||
|
w={{ base: "56px", md: "200px" }}
|
||||||
|
overflow="hidden"
|
||||||
|
>
|
||||||
|
<HStack as={Link} href="/" _hover={{ textDecoration: "none" }} spacing={0} px={4} py={2}>
|
||||||
|
<Image src="/logo.svg" alt="" boxSize={6} mr={4} />
|
||||||
|
<Heading size="md" fontFamily="inconsolata, monospace">
|
||||||
OpenPipe
|
OpenPipe
|
||||||
</Heading>
|
</Heading>
|
||||||
</HStack>
|
</HStack>
|
||||||
</Link>
|
|
||||||
<Divider />
|
|
||||||
<VStack spacing={0} align="flex-start" overflowY="auto" overflowX="hidden" flex={1}>
|
<VStack spacing={0} align="flex-start" overflowY="auto" overflowX="hidden" flex={1}>
|
||||||
|
{user != null && (
|
||||||
|
<>
|
||||||
<IconLink icon={RiFlaskLine} label="Experiments" href="/experiments" />
|
<IconLink icon={RiFlaskLine} label="Experiments" href="/experiments" />
|
||||||
</VStack>
|
</>
|
||||||
<Divider />
|
)}
|
||||||
<VStack w="full" spacing={0} pb={2}>
|
{user === null && (
|
||||||
<IconLink
|
<IconLink
|
||||||
icon={BsGithub}
|
icon={BsPersonCircle}
|
||||||
label="GitHub"
|
label="Sign In"
|
||||||
|
onClick={() => {
|
||||||
|
signIn("github").catch(console.error);
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
</VStack>
|
||||||
|
{user ? <UserMenu user={user} /> : <Divider />}
|
||||||
|
<VStack spacing={0} align="center">
|
||||||
|
<Link
|
||||||
href="https://github.com/openpipe/openpipe"
|
href="https://github.com/openpipe/openpipe"
|
||||||
target="_blank"
|
target="_blank"
|
||||||
color="gray.500"
|
color="gray.500"
|
||||||
_hover={{ color: "gray.800" }}
|
_hover={{ color: "gray.800" }}
|
||||||
/>
|
p={2}
|
||||||
<IconLink
|
>
|
||||||
icon={BsTwitter}
|
<Icon as={BsGithub} boxSize={6} />
|
||||||
label="Twitter"
|
</Link>
|
||||||
href="https://twitter.com/corbtt"
|
|
||||||
target="_blank"
|
|
||||||
color="gray.500"
|
|
||||||
_hover={{ color: "gray.800" }}
|
|
||||||
/>
|
|
||||||
</VStack>
|
</VStack>
|
||||||
</VStack>
|
</VStack>
|
||||||
);
|
);
|
||||||
@@ -105,25 +121,14 @@ export default function AppShell(props: { children: React.ReactNode; title?: str
|
|||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Grid
|
<Flex h={vh} w="100vw">
|
||||||
h={vh}
|
|
||||||
w="100vw"
|
|
||||||
templateColumns={{ base: "56px minmax(0, 1fr)", md: "200px minmax(0, 1fr)" }}
|
|
||||||
templateRows="max-content 1fr"
|
|
||||||
templateAreas={'"warning warning"\n"sidebar main"'}
|
|
||||||
>
|
|
||||||
<Head>
|
<Head>
|
||||||
<title>{props.title ? `${props.title} | OpenPipe` : "OpenPipe"}</title>
|
<title>{props.title ? `${props.title} | OpenPipe` : "OpenPipe"}</title>
|
||||||
</Head>
|
</Head>
|
||||||
<GridItem area="warning">
|
|
||||||
<PublicPlaygroundWarning />
|
|
||||||
</GridItem>
|
|
||||||
<GridItem area="sidebar" overflow="hidden">
|
|
||||||
<NavSidebar />
|
<NavSidebar />
|
||||||
</GridItem>
|
<Box h="100%" flex={1} overflowY="auto">
|
||||||
<GridItem area="main" overflowY="auto">
|
|
||||||
{props.children}
|
{props.children}
|
||||||
</GridItem>
|
</Box>
|
||||||
</Grid>
|
</Flex>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
74
src/components/nav/UserMenu.tsx
Normal file
74
src/components/nav/UserMenu.tsx
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
import {
|
||||||
|
HStack,
|
||||||
|
Icon,
|
||||||
|
Image,
|
||||||
|
VStack,
|
||||||
|
Text,
|
||||||
|
Popover,
|
||||||
|
PopoverTrigger,
|
||||||
|
PopoverContent,
|
||||||
|
Link,
|
||||||
|
} from "@chakra-ui/react";
|
||||||
|
import { type Session } from "next-auth";
|
||||||
|
import { signOut } from "next-auth/react";
|
||||||
|
import { BsBoxArrowRight, BsChevronRight, BsPersonCircle } from "react-icons/bs";
|
||||||
|
|
||||||
|
export default function UserMenu({ user }: { user: Session }) {
|
||||||
|
const profileImage = user.user.image ? (
|
||||||
|
<Image src={user.user.image} alt="profile picture" boxSize={8} borderRadius="50%" />
|
||||||
|
) : (
|
||||||
|
<Icon as={BsPersonCircle} boxSize={6} />
|
||||||
|
);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
<Popover placement="right">
|
||||||
|
<PopoverTrigger>
|
||||||
|
<HStack
|
||||||
|
// Weird values to make mobile look right; can clean up when we make the sidebar disappear on mobile
|
||||||
|
px={3}
|
||||||
|
spacing={3}
|
||||||
|
py={2}
|
||||||
|
borderColor={"gray.200"}
|
||||||
|
borderTopWidth={1}
|
||||||
|
borderBottomWidth={1}
|
||||||
|
cursor="pointer"
|
||||||
|
_hover={{
|
||||||
|
bgColor: "gray.200",
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{profileImage}
|
||||||
|
<VStack spacing={0} align="start" flex={1} flexShrink={1}>
|
||||||
|
<Text fontWeight="bold" fontSize="sm">
|
||||||
|
{user.user.name}
|
||||||
|
</Text>
|
||||||
|
<Text color="gray.500" fontSize="xs">
|
||||||
|
{user.user.email}
|
||||||
|
</Text>
|
||||||
|
</VStack>
|
||||||
|
<Icon as={BsChevronRight} boxSize={4} color="gray.500" />
|
||||||
|
</HStack>
|
||||||
|
</PopoverTrigger>
|
||||||
|
<PopoverContent _focusVisible={{ boxShadow: "unset", outline: "unset" }} maxW="200px">
|
||||||
|
<VStack align="stretch" spacing={0}>
|
||||||
|
{/* sign out */}
|
||||||
|
<HStack
|
||||||
|
as={Link}
|
||||||
|
onClick={() => {
|
||||||
|
signOut().catch(console.error);
|
||||||
|
}}
|
||||||
|
px={4}
|
||||||
|
py={2}
|
||||||
|
spacing={4}
|
||||||
|
color="gray.500"
|
||||||
|
fontSize="sm"
|
||||||
|
>
|
||||||
|
<Icon as={BsBoxArrowRight} boxSize={6} />
|
||||||
|
<Text>Sign out</Text>
|
||||||
|
</HStack>
|
||||||
|
</VStack>
|
||||||
|
</PopoverContent>
|
||||||
|
</Popover>
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
}
|
||||||
18
src/env.mjs
18
src/env.mjs
@@ -10,6 +10,14 @@ export const env = createEnv({
|
|||||||
DATABASE_URL: z.string().url(),
|
DATABASE_URL: z.string().url(),
|
||||||
NODE_ENV: z.enum(["development", "test", "production"]).default("development"),
|
NODE_ENV: z.enum(["development", "test", "production"]).default("development"),
|
||||||
OPENAI_API_KEY: z.string().min(1),
|
OPENAI_API_KEY: z.string().min(1),
|
||||||
|
RESTRICT_PRISMA_LOGS: z
|
||||||
|
.string()
|
||||||
|
.optional()
|
||||||
|
.default("false")
|
||||||
|
.transform((val) => val.toLowerCase() === "true"),
|
||||||
|
GITHUB_CLIENT_ID: z.string().min(1),
|
||||||
|
GITHUB_CLIENT_SECRET: z.string().min(1),
|
||||||
|
REPLICATE_API_TOKEN: z.string().min(1),
|
||||||
},
|
},
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -19,11 +27,6 @@ export const env = createEnv({
|
|||||||
*/
|
*/
|
||||||
client: {
|
client: {
|
||||||
NEXT_PUBLIC_POSTHOG_KEY: z.string().optional(),
|
NEXT_PUBLIC_POSTHOG_KEY: z.string().optional(),
|
||||||
NEXT_PUBLIC_IS_PUBLIC_PLAYGROUND: z
|
|
||||||
.string()
|
|
||||||
.optional()
|
|
||||||
.default("false")
|
|
||||||
.transform((val) => val.toLowerCase() === "true"),
|
|
||||||
NEXT_PUBLIC_SOCKET_URL: z.string().url().default("http://localhost:3318"),
|
NEXT_PUBLIC_SOCKET_URL: z.string().url().default("http://localhost:3318"),
|
||||||
},
|
},
|
||||||
|
|
||||||
@@ -35,9 +38,12 @@ export const env = createEnv({
|
|||||||
DATABASE_URL: process.env.DATABASE_URL,
|
DATABASE_URL: process.env.DATABASE_URL,
|
||||||
NODE_ENV: process.env.NODE_ENV,
|
NODE_ENV: process.env.NODE_ENV,
|
||||||
OPENAI_API_KEY: process.env.OPENAI_API_KEY,
|
OPENAI_API_KEY: process.env.OPENAI_API_KEY,
|
||||||
|
RESTRICT_PRISMA_LOGS: process.env.RESTRICT_PRISMA_LOGS,
|
||||||
NEXT_PUBLIC_POSTHOG_KEY: process.env.NEXT_PUBLIC_POSTHOG_KEY,
|
NEXT_PUBLIC_POSTHOG_KEY: process.env.NEXT_PUBLIC_POSTHOG_KEY,
|
||||||
NEXT_PUBLIC_IS_PUBLIC_PLAYGROUND: process.env.NEXT_PUBLIC_IS_PUBLIC_PLAYGROUND,
|
|
||||||
NEXT_PUBLIC_SOCKET_URL: process.env.NEXT_PUBLIC_SOCKET_URL,
|
NEXT_PUBLIC_SOCKET_URL: process.env.NEXT_PUBLIC_SOCKET_URL,
|
||||||
|
GITHUB_CLIENT_ID: process.env.GITHUB_CLIENT_ID,
|
||||||
|
GITHUB_CLIENT_SECRET: process.env.GITHUB_CLIENT_SECRET,
|
||||||
|
REPLICATE_API_TOKEN: process.env.REPLICATE_API_TOKEN,
|
||||||
},
|
},
|
||||||
/**
|
/**
|
||||||
* Run `build` or `dev` with `SKIP_ENV_VALIDATION` to skip env validation.
|
* Run `build` or `dev` with `SKIP_ENV_VALIDATION` to skip env validation.
|
||||||
|
|||||||
36
src/modelProviders/generateTypes.ts
Normal file
36
src/modelProviders/generateTypes.ts
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
import { type JSONSchema4Object } from "json-schema";
|
||||||
|
import modelProviders from "./modelProviders";
|
||||||
|
import { compile } from "json-schema-to-typescript";
|
||||||
|
import dedent from "dedent";
|
||||||
|
|
||||||
|
export default async function generateTypes() {
|
||||||
|
const combinedSchema = {
|
||||||
|
type: "object",
|
||||||
|
properties: {} as Record<string, JSONSchema4Object>,
|
||||||
|
};
|
||||||
|
|
||||||
|
Object.entries(modelProviders).forEach(([id, provider]) => {
|
||||||
|
combinedSchema.properties[id] = provider.inputSchema;
|
||||||
|
});
|
||||||
|
|
||||||
|
Object.entries(modelProviders).forEach(([id, provider]) => {
|
||||||
|
combinedSchema.properties[id] = provider.inputSchema;
|
||||||
|
});
|
||||||
|
|
||||||
|
const promptTypes = (
|
||||||
|
await compile(combinedSchema as JSONSchema4Object, "PromptTypes", {
|
||||||
|
additionalProperties: false,
|
||||||
|
bannerComment: dedent`
|
||||||
|
/**
|
||||||
|
* This type map defines the input types for each model provider.
|
||||||
|
*/
|
||||||
|
`,
|
||||||
|
})
|
||||||
|
).replace(/export interface PromptTypes/g, "interface PromptTypes");
|
||||||
|
|
||||||
|
return dedent`
|
||||||
|
${promptTypes}
|
||||||
|
|
||||||
|
declare function definePrompt<T extends keyof PromptTypes>(modelProvider: T, input: PromptTypes[T])
|
||||||
|
`;
|
||||||
|
}
|
||||||
9
src/modelProviders/modelProviders.ts
Normal file
9
src/modelProviders/modelProviders.ts
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
import openaiChatCompletion from "./openai-ChatCompletion";
|
||||||
|
import replicateLlama2 from "./replicate-llama2";
|
||||||
|
|
||||||
|
const modelProviders = {
|
||||||
|
"openai/ChatCompletion": openaiChatCompletion,
|
||||||
|
"replicate/llama2": replicateLlama2,
|
||||||
|
} as const;
|
||||||
|
|
||||||
|
export default modelProviders;
|
||||||
14
src/modelProviders/modelProvidersFrontend.ts
Normal file
14
src/modelProviders/modelProvidersFrontend.ts
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
import openaiChatCompletionFrontend from "./openai-ChatCompletion/frontend";
|
||||||
|
import replicateLlama2Frontend from "./replicate-llama2/frontend";
|
||||||
|
|
||||||
|
// TODO: make sure we get a typescript error if you forget to add a provider here
|
||||||
|
|
||||||
|
// Keep attributes here that need to be accessible from the frontend. We can't
|
||||||
|
// just include them in the default `modelProviders` object because it has some
|
||||||
|
// transient dependencies that can only be imported on the server.
|
||||||
|
const modelProvidersFrontend = {
|
||||||
|
"openai/ChatCompletion": openaiChatCompletionFrontend,
|
||||||
|
"replicate/llama2": replicateLlama2Frontend,
|
||||||
|
} as const;
|
||||||
|
|
||||||
|
export default modelProvidersFrontend;
|
||||||
77
src/modelProviders/modelStats.ts
Normal file
77
src/modelProviders/modelStats.ts
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
import { type SupportedModel } from "../server/types";
|
||||||
|
|
||||||
|
interface ModelStats {
|
||||||
|
contextLength: number;
|
||||||
|
promptTokenPrice: number;
|
||||||
|
completionTokenPrice: number;
|
||||||
|
speed: "fast" | "medium" | "slow";
|
||||||
|
provider: "OpenAI";
|
||||||
|
learnMoreUrl: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export const modelStats: Record<SupportedModel, ModelStats> = {
|
||||||
|
"gpt-4": {
|
||||||
|
contextLength: 8192,
|
||||||
|
promptTokenPrice: 0.00003,
|
||||||
|
completionTokenPrice: 0.00006,
|
||||||
|
speed: "medium",
|
||||||
|
provider: "OpenAI",
|
||||||
|
learnMoreUrl: "https://openai.com/gpt-4",
|
||||||
|
},
|
||||||
|
"gpt-4-0613": {
|
||||||
|
contextLength: 8192,
|
||||||
|
promptTokenPrice: 0.00003,
|
||||||
|
completionTokenPrice: 0.00006,
|
||||||
|
speed: "medium",
|
||||||
|
provider: "OpenAI",
|
||||||
|
learnMoreUrl: "https://openai.com/gpt-4",
|
||||||
|
},
|
||||||
|
"gpt-4-32k": {
|
||||||
|
contextLength: 32768,
|
||||||
|
promptTokenPrice: 0.00006,
|
||||||
|
completionTokenPrice: 0.00012,
|
||||||
|
speed: "medium",
|
||||||
|
provider: "OpenAI",
|
||||||
|
learnMoreUrl: "https://openai.com/gpt-4",
|
||||||
|
},
|
||||||
|
"gpt-4-32k-0613": {
|
||||||
|
contextLength: 32768,
|
||||||
|
promptTokenPrice: 0.00006,
|
||||||
|
completionTokenPrice: 0.00012,
|
||||||
|
speed: "medium",
|
||||||
|
provider: "OpenAI",
|
||||||
|
learnMoreUrl: "https://openai.com/gpt-4",
|
||||||
|
},
|
||||||
|
"gpt-3.5-turbo": {
|
||||||
|
contextLength: 4096,
|
||||||
|
promptTokenPrice: 0.0000015,
|
||||||
|
completionTokenPrice: 0.000002,
|
||||||
|
speed: "fast",
|
||||||
|
provider: "OpenAI",
|
||||||
|
learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
|
||||||
|
},
|
||||||
|
"gpt-3.5-turbo-0613": {
|
||||||
|
contextLength: 4096,
|
||||||
|
promptTokenPrice: 0.0000015,
|
||||||
|
completionTokenPrice: 0.000002,
|
||||||
|
speed: "fast",
|
||||||
|
provider: "OpenAI",
|
||||||
|
learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
|
||||||
|
},
|
||||||
|
"gpt-3.5-turbo-16k": {
|
||||||
|
contextLength: 16384,
|
||||||
|
promptTokenPrice: 0.000003,
|
||||||
|
completionTokenPrice: 0.000004,
|
||||||
|
speed: "fast",
|
||||||
|
provider: "OpenAI",
|
||||||
|
learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
|
||||||
|
},
|
||||||
|
"gpt-3.5-turbo-16k-0613": {
|
||||||
|
contextLength: 16384,
|
||||||
|
promptTokenPrice: 0.000003,
|
||||||
|
completionTokenPrice: 0.000004,
|
||||||
|
speed: "fast",
|
||||||
|
provider: "OpenAI",
|
||||||
|
learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
|
||||||
|
},
|
||||||
|
};
|
||||||
69
src/modelProviders/openai-ChatCompletion/codegen/codegen.ts
Normal file
69
src/modelProviders/openai-ChatCompletion/codegen/codegen.ts
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
/* eslint-disable @typescript-eslint/no-var-requires */
|
||||||
|
|
||||||
|
import YAML from "yaml";
|
||||||
|
import fs from "fs";
|
||||||
|
import path from "path";
|
||||||
|
import { openapiSchemaToJsonSchema } from "@openapi-contrib/openapi-schema-to-json-schema";
|
||||||
|
import $RefParser from "@apidevtools/json-schema-ref-parser";
|
||||||
|
import { type JSONObject } from "superjson/dist/types";
|
||||||
|
import assert from "assert";
|
||||||
|
import { type JSONSchema4Object } from "json-schema";
|
||||||
|
import { isObject } from "lodash-es";
|
||||||
|
|
||||||
|
// @ts-expect-error for some reason missing from types
|
||||||
|
import parserEstree from "prettier/plugins/estree";
|
||||||
|
import parserBabel from "prettier/plugins/babel";
|
||||||
|
import prettier from "prettier/standalone";
|
||||||
|
|
||||||
|
const OPENAPI_URL =
|
||||||
|
"https://raw.githubusercontent.com/openai/openai-openapi/0c432eb66fd0c758fd8b9bd69db41c1096e5f4db/openapi.yaml";
|
||||||
|
|
||||||
|
// Fetch the openapi document
|
||||||
|
const response = await fetch(OPENAPI_URL);
|
||||||
|
const openApiYaml = await response.text();
|
||||||
|
|
||||||
|
// Parse the yaml document
|
||||||
|
let schema = YAML.parse(openApiYaml) as JSONObject;
|
||||||
|
schema = openapiSchemaToJsonSchema(schema);
|
||||||
|
|
||||||
|
const jsonSchema = await $RefParser.dereference(schema);
|
||||||
|
|
||||||
|
assert("components" in jsonSchema);
|
||||||
|
const completionRequestSchema = jsonSchema.components.schemas
|
||||||
|
.CreateChatCompletionRequest as JSONSchema4Object;
|
||||||
|
|
||||||
|
// We need to do a bit of surgery here since the Monaco editor doesn't like
|
||||||
|
// the fact that the schema says `model` can be either a string or an enum,
|
||||||
|
// and displays a warning in the editor. Let's stick with just an enum for
|
||||||
|
// now and drop the string option.
|
||||||
|
assert(
|
||||||
|
"properties" in completionRequestSchema &&
|
||||||
|
isObject(completionRequestSchema.properties) &&
|
||||||
|
"model" in completionRequestSchema.properties &&
|
||||||
|
isObject(completionRequestSchema.properties.model),
|
||||||
|
);
|
||||||
|
|
||||||
|
const modelProperty = completionRequestSchema.properties.model;
|
||||||
|
assert(
|
||||||
|
"oneOf" in modelProperty &&
|
||||||
|
Array.isArray(modelProperty.oneOf) &&
|
||||||
|
modelProperty.oneOf.length === 2 &&
|
||||||
|
isObject(modelProperty.oneOf[1]) &&
|
||||||
|
"enum" in modelProperty.oneOf[1],
|
||||||
|
"Expected model to have oneOf length of 2",
|
||||||
|
);
|
||||||
|
modelProperty.type = "string";
|
||||||
|
modelProperty.enum = modelProperty.oneOf[1].enum;
|
||||||
|
delete modelProperty["oneOf"];
|
||||||
|
|
||||||
|
// Get the directory of the current script
|
||||||
|
const currentDirectory = path.dirname(import.meta.url).replace("file://", "");
|
||||||
|
|
||||||
|
// Write the JSON schema to a file in the current directory
|
||||||
|
fs.writeFileSync(
|
||||||
|
path.join(currentDirectory, "input.schema.json"),
|
||||||
|
await prettier.format(JSON.stringify(completionRequestSchema, null, 2), {
|
||||||
|
parser: "json",
|
||||||
|
plugins: [parserBabel, parserEstree],
|
||||||
|
}),
|
||||||
|
);
|
||||||
@@ -0,0 +1,186 @@
|
|||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"model": {
|
||||||
|
"description": "ID of the model to use. See the [model endpoint compatibility](/docs/models/model-endpoint-compatibility) table for details on which models work with the Chat API.",
|
||||||
|
"example": "gpt-3.5-turbo",
|
||||||
|
"type": "string",
|
||||||
|
"enum": [
|
||||||
|
"gpt-4",
|
||||||
|
"gpt-4-0613",
|
||||||
|
"gpt-4-32k",
|
||||||
|
"gpt-4-32k-0613",
|
||||||
|
"gpt-3.5-turbo",
|
||||||
|
"gpt-3.5-turbo-16k",
|
||||||
|
"gpt-3.5-turbo-0613",
|
||||||
|
"gpt-3.5-turbo-16k-0613"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"messages": {
|
||||||
|
"description": "A list of messages comprising the conversation so far. [Example Python code](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb).",
|
||||||
|
"type": "array",
|
||||||
|
"minItems": 1,
|
||||||
|
"items": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"role": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["system", "user", "assistant", "function"],
|
||||||
|
"description": "The role of the messages author. One of `system`, `user`, `assistant`, or `function`."
|
||||||
|
},
|
||||||
|
"content": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The contents of the message. `content` is required for all messages except assistant messages with function calls."
|
||||||
|
},
|
||||||
|
"name": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The name of the author of this message. `name` is required if role is `function`, and it should be the name of the function whose response is in the `content`. May contain a-z, A-Z, 0-9, and underscores, with a maximum length of 64 characters."
|
||||||
|
},
|
||||||
|
"function_call": {
|
||||||
|
"type": "object",
|
||||||
|
"description": "The name and arguments of a function that should be called, as generated by the model.",
|
||||||
|
"properties": {
|
||||||
|
"name": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The name of the function to call."
|
||||||
|
},
|
||||||
|
"arguments": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function."
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["role"]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"functions": {
|
||||||
|
"description": "A list of functions the model may generate JSON inputs for.",
|
||||||
|
"type": "array",
|
||||||
|
"minItems": 1,
|
||||||
|
"items": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"name": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The name of the function to be called. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64."
|
||||||
|
},
|
||||||
|
"description": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The description of what the function does."
|
||||||
|
},
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"description": "The parameters the functions accepts, described as a JSON Schema object. See the [guide](/docs/guides/gpt/function-calling) for examples, and the [JSON Schema reference](https://json-schema.org/understanding-json-schema/) for documentation about the format.",
|
||||||
|
"additionalProperties": true
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["name"]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"function_call": {
|
||||||
|
"description": "Controls how the model responds to function calls. \"none\" means the model does not call a function, and responds to the end-user. \"auto\" means the model can pick between an end-user or calling a function. Specifying a particular function via `{\"name\":\\ \"my_function\"}` forces the model to call that function. \"none\" is the default when no functions are present. \"auto\" is the default if functions are present.",
|
||||||
|
"oneOf": [
|
||||||
|
{
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["none", "auto"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"name": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The name of the function to call."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["name"]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"temperature": {
|
||||||
|
"type": "number",
|
||||||
|
"minimum": 0,
|
||||||
|
"maximum": 2,
|
||||||
|
"default": 1,
|
||||||
|
"example": 1,
|
||||||
|
"nullable": true,
|
||||||
|
"description": "What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.\n\nWe generally recommend altering this or `top_p` but not both.\n"
|
||||||
|
},
|
||||||
|
"top_p": {
|
||||||
|
"type": "number",
|
||||||
|
"minimum": 0,
|
||||||
|
"maximum": 1,
|
||||||
|
"default": 1,
|
||||||
|
"example": 1,
|
||||||
|
"nullable": true,
|
||||||
|
"description": "An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.\n\nWe generally recommend altering this or `temperature` but not both.\n"
|
||||||
|
},
|
||||||
|
"n": {
|
||||||
|
"type": "integer",
|
||||||
|
"minimum": 1,
|
||||||
|
"maximum": 128,
|
||||||
|
"default": 1,
|
||||||
|
"example": 1,
|
||||||
|
"nullable": true,
|
||||||
|
"description": "How many chat completion choices to generate for each input message."
|
||||||
|
},
|
||||||
|
"stream": {
|
||||||
|
"description": "If set, partial message deltas will be sent, like in ChatGPT. Tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream terminated by a `data: [DONE]` message. [Example Python code](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_stream_completions.ipynb).\n",
|
||||||
|
"type": "boolean",
|
||||||
|
"nullable": true,
|
||||||
|
"default": false
|
||||||
|
},
|
||||||
|
"stop": {
|
||||||
|
"description": "Up to 4 sequences where the API will stop generating further tokens.\n",
|
||||||
|
"default": null,
|
||||||
|
"oneOf": [
|
||||||
|
{
|
||||||
|
"type": "string",
|
||||||
|
"nullable": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "array",
|
||||||
|
"minItems": 1,
|
||||||
|
"maxItems": 4,
|
||||||
|
"items": {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"max_tokens": {
|
||||||
|
"description": "The maximum number of [tokens](/tokenizer) to generate in the chat completion.\n\nThe total length of input tokens and generated tokens is limited by the model's context length. [Example Python code](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb) for counting tokens.\n",
|
||||||
|
"default": "inf",
|
||||||
|
"type": "integer"
|
||||||
|
},
|
||||||
|
"presence_penalty": {
|
||||||
|
"type": "number",
|
||||||
|
"default": 0,
|
||||||
|
"minimum": -2,
|
||||||
|
"maximum": 2,
|
||||||
|
"nullable": true,
|
||||||
|
"description": "Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.\n\n[See more information about frequency and presence penalties.](/docs/api-reference/parameter-details)\n"
|
||||||
|
},
|
||||||
|
"frequency_penalty": {
|
||||||
|
"type": "number",
|
||||||
|
"default": 0,
|
||||||
|
"minimum": -2,
|
||||||
|
"maximum": 2,
|
||||||
|
"nullable": true,
|
||||||
|
"description": "Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.\n\n[See more information about frequency and presence penalties.](/docs/api-reference/parameter-details)\n"
|
||||||
|
},
|
||||||
|
"logit_bias": {
|
||||||
|
"type": "object",
|
||||||
|
"x-oaiTypeLabel": "map",
|
||||||
|
"default": null,
|
||||||
|
"nullable": true,
|
||||||
|
"description": "Modify the likelihood of specified tokens appearing in the completion.\n\nAccepts a json object that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token.\n"
|
||||||
|
},
|
||||||
|
"user": {
|
||||||
|
"type": "string",
|
||||||
|
"example": "user-1234",
|
||||||
|
"description": "A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).\n"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["model", "messages"]
|
||||||
|
}
|
||||||
42
src/modelProviders/openai-ChatCompletion/frontend.ts
Normal file
42
src/modelProviders/openai-ChatCompletion/frontend.ts
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
import { type JsonValue } from "type-fest";
|
||||||
|
import { type OpenaiChatModelProvider } from ".";
|
||||||
|
import { type ModelProviderFrontend } from "../types";
|
||||||
|
|
||||||
|
const modelProviderFrontend: ModelProviderFrontend<OpenaiChatModelProvider> = {
|
||||||
|
normalizeOutput: (output) => {
|
||||||
|
const message = output.choices[0]?.message;
|
||||||
|
if (!message)
|
||||||
|
return {
|
||||||
|
type: "json",
|
||||||
|
value: output as unknown as JsonValue,
|
||||||
|
};
|
||||||
|
|
||||||
|
if (message.content) {
|
||||||
|
return {
|
||||||
|
type: "text",
|
||||||
|
value: message.content,
|
||||||
|
};
|
||||||
|
} else if (message.function_call) {
|
||||||
|
let args = message.function_call.arguments ?? "";
|
||||||
|
try {
|
||||||
|
args = JSON.parse(args);
|
||||||
|
} catch (e) {
|
||||||
|
// Ignore
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
type: "json",
|
||||||
|
value: {
|
||||||
|
...message.function_call,
|
||||||
|
arguments: args,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
} else {
|
||||||
|
return {
|
||||||
|
type: "json",
|
||||||
|
value: message as unknown as JsonValue,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
export default modelProviderFrontend;
|
||||||
142
src/modelProviders/openai-ChatCompletion/getCompletion.ts
Normal file
142
src/modelProviders/openai-ChatCompletion/getCompletion.ts
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
/* eslint-disable @typescript-eslint/no-unsafe-call */
|
||||||
|
import {
|
||||||
|
type ChatCompletionChunk,
|
||||||
|
type ChatCompletion,
|
||||||
|
type CompletionCreateParams,
|
||||||
|
} from "openai/resources/chat";
|
||||||
|
import { countOpenAIChatTokens } from "~/utils/countTokens";
|
||||||
|
import { type CompletionResponse } from "../types";
|
||||||
|
import { omit } from "lodash-es";
|
||||||
|
import { openai } from "~/server/utils/openai";
|
||||||
|
import { type OpenAIChatModel } from "~/server/types";
|
||||||
|
import { truthyFilter } from "~/utils/utils";
|
||||||
|
import { APIError } from "openai";
|
||||||
|
import { modelStats } from "../modelStats";
|
||||||
|
|
||||||
|
const mergeStreamedChunks = (
|
||||||
|
base: ChatCompletion | null,
|
||||||
|
chunk: ChatCompletionChunk,
|
||||||
|
): ChatCompletion => {
|
||||||
|
if (base === null) {
|
||||||
|
return mergeStreamedChunks({ ...chunk, choices: [] }, chunk);
|
||||||
|
}
|
||||||
|
|
||||||
|
const choices = [...base.choices];
|
||||||
|
for (const choice of chunk.choices) {
|
||||||
|
const baseChoice = choices.find((c) => c.index === choice.index);
|
||||||
|
if (baseChoice) {
|
||||||
|
baseChoice.finish_reason = choice.finish_reason ?? baseChoice.finish_reason;
|
||||||
|
baseChoice.message = baseChoice.message ?? { role: "assistant" };
|
||||||
|
|
||||||
|
if (choice.delta?.content)
|
||||||
|
baseChoice.message.content =
|
||||||
|
((baseChoice.message.content as string) ?? "") + (choice.delta.content ?? "");
|
||||||
|
if (choice.delta?.function_call) {
|
||||||
|
const fnCall = baseChoice.message.function_call ?? {};
|
||||||
|
fnCall.name =
|
||||||
|
((fnCall.name as string) ?? "") + ((choice.delta.function_call.name as string) ?? "");
|
||||||
|
fnCall.arguments =
|
||||||
|
((fnCall.arguments as string) ?? "") +
|
||||||
|
((choice.delta.function_call.arguments as string) ?? "");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
choices.push({ ...omit(choice, "delta"), message: { role: "assistant", ...choice.delta } });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const merged: ChatCompletion = {
|
||||||
|
...base,
|
||||||
|
choices,
|
||||||
|
};
|
||||||
|
|
||||||
|
return merged;
|
||||||
|
};
|
||||||
|
|
||||||
|
export async function getCompletion(
|
||||||
|
input: CompletionCreateParams,
|
||||||
|
onStream: ((partialOutput: ChatCompletion) => void) | null,
|
||||||
|
): Promise<CompletionResponse<ChatCompletion>> {
|
||||||
|
const start = Date.now();
|
||||||
|
let finalCompletion: ChatCompletion | null = null;
|
||||||
|
let promptTokens: number | undefined = undefined;
|
||||||
|
let completionTokens: number | undefined = undefined;
|
||||||
|
|
||||||
|
try {
|
||||||
|
if (onStream) {
|
||||||
|
const resp = await openai.chat.completions.create(
|
||||||
|
{ ...input, stream: true },
|
||||||
|
{
|
||||||
|
maxRetries: 0,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
for await (const part of resp) {
|
||||||
|
finalCompletion = mergeStreamedChunks(finalCompletion, part);
|
||||||
|
onStream(finalCompletion);
|
||||||
|
}
|
||||||
|
if (!finalCompletion) {
|
||||||
|
return {
|
||||||
|
type: "error",
|
||||||
|
message: "Streaming failed to return a completion",
|
||||||
|
autoRetry: false,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
promptTokens = countOpenAIChatTokens(
|
||||||
|
input.model as keyof typeof OpenAIChatModel,
|
||||||
|
input.messages,
|
||||||
|
);
|
||||||
|
completionTokens = countOpenAIChatTokens(
|
||||||
|
input.model as keyof typeof OpenAIChatModel,
|
||||||
|
finalCompletion.choices.map((c) => c.message).filter(truthyFilter),
|
||||||
|
);
|
||||||
|
} catch (err) {
|
||||||
|
// TODO handle this, library seems like maybe it doesn't work with function calls?
|
||||||
|
console.error(err);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
const resp = await openai.chat.completions.create(
|
||||||
|
{ ...input, stream: false },
|
||||||
|
{
|
||||||
|
maxRetries: 0,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
finalCompletion = resp;
|
||||||
|
promptTokens = resp.usage?.prompt_tokens ?? 0;
|
||||||
|
completionTokens = resp.usage?.completion_tokens ?? 0;
|
||||||
|
}
|
||||||
|
const timeToComplete = Date.now() - start;
|
||||||
|
|
||||||
|
const stats = modelStats[input.model as keyof typeof OpenAIChatModel];
|
||||||
|
let cost = undefined;
|
||||||
|
if (stats && promptTokens && completionTokens) {
|
||||||
|
cost = promptTokens * stats.promptTokenPrice + completionTokens * stats.completionTokenPrice;
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
type: "success",
|
||||||
|
statusCode: 200,
|
||||||
|
value: finalCompletion,
|
||||||
|
timeToComplete,
|
||||||
|
promptTokens,
|
||||||
|
completionTokens,
|
||||||
|
cost,
|
||||||
|
};
|
||||||
|
} catch (error: unknown) {
|
||||||
|
console.error("ERROR IS", error);
|
||||||
|
if (error instanceof APIError) {
|
||||||
|
return {
|
||||||
|
type: "error",
|
||||||
|
message: error.message,
|
||||||
|
autoRetry: error.status === 429 || error.status === 503,
|
||||||
|
statusCode: error.status,
|
||||||
|
};
|
||||||
|
} else {
|
||||||
|
console.error(error);
|
||||||
|
return {
|
||||||
|
type: "error",
|
||||||
|
message: (error as Error).message,
|
||||||
|
autoRetry: true,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
62
src/modelProviders/openai-ChatCompletion/index.ts
Normal file
62
src/modelProviders/openai-ChatCompletion/index.ts
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
import { type JSONSchema4 } from "json-schema";
|
||||||
|
import { type ModelProvider } from "../types";
|
||||||
|
import inputSchema from "./codegen/input.schema.json";
|
||||||
|
import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat";
|
||||||
|
import { getCompletion } from "./getCompletion";
|
||||||
|
|
||||||
|
const supportedModels = [
|
||||||
|
"gpt-4-0613",
|
||||||
|
"gpt-4-32k-0613",
|
||||||
|
"gpt-3.5-turbo-0613",
|
||||||
|
"gpt-3.5-turbo-16k-0613",
|
||||||
|
] as const;
|
||||||
|
|
||||||
|
type SupportedModel = (typeof supportedModels)[number];
|
||||||
|
|
||||||
|
export type OpenaiChatModelProvider = ModelProvider<
|
||||||
|
SupportedModel,
|
||||||
|
CompletionCreateParams,
|
||||||
|
ChatCompletion
|
||||||
|
>;
|
||||||
|
|
||||||
|
const modelProvider: OpenaiChatModelProvider = {
|
||||||
|
name: "OpenAI ChatCompletion",
|
||||||
|
models: {
|
||||||
|
"gpt-4-0613": {
|
||||||
|
name: "GPT-4",
|
||||||
|
learnMore: "https://openai.com/gpt-4",
|
||||||
|
},
|
||||||
|
"gpt-4-32k-0613": {
|
||||||
|
name: "GPT-4 32k",
|
||||||
|
learnMore: "https://openai.com/gpt-4",
|
||||||
|
},
|
||||||
|
"gpt-3.5-turbo-0613": {
|
||||||
|
name: "GPT-3.5 Turbo",
|
||||||
|
learnMore: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
|
||||||
|
},
|
||||||
|
"gpt-3.5-turbo-16k-0613": {
|
||||||
|
name: "GPT-3.5 Turbo 16k",
|
||||||
|
learnMore: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
getModel: (input) => {
|
||||||
|
if (supportedModels.includes(input.model as SupportedModel))
|
||||||
|
return input.model as SupportedModel;
|
||||||
|
|
||||||
|
const modelMaps: Record<string, SupportedModel> = {
|
||||||
|
"gpt-4": "gpt-4-0613",
|
||||||
|
"gpt-4-32k": "gpt-4-32k-0613",
|
||||||
|
"gpt-3.5-turbo": "gpt-3.5-turbo-0613",
|
||||||
|
"gpt-3.5-turbo-16k": "gpt-3.5-turbo-16k-0613",
|
||||||
|
};
|
||||||
|
|
||||||
|
if (input.model in modelMaps) return modelMaps[input.model] as SupportedModel;
|
||||||
|
|
||||||
|
return null;
|
||||||
|
},
|
||||||
|
inputSchema: inputSchema as JSONSchema4,
|
||||||
|
shouldStream: (input) => input.stream ?? false,
|
||||||
|
getCompletion,
|
||||||
|
};
|
||||||
|
|
||||||
|
export default modelProvider;
|
||||||
13
src/modelProviders/replicate-llama2/frontend.ts
Normal file
13
src/modelProviders/replicate-llama2/frontend.ts
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
import { type ReplicateLlama2Provider } from ".";
|
||||||
|
import { type ModelProviderFrontend } from "../types";
|
||||||
|
|
||||||
|
const modelProviderFrontend: ModelProviderFrontend<ReplicateLlama2Provider> = {
|
||||||
|
normalizeOutput: (output) => {
|
||||||
|
return {
|
||||||
|
type: "text",
|
||||||
|
value: output.join(""),
|
||||||
|
};
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
export default modelProviderFrontend;
|
||||||
62
src/modelProviders/replicate-llama2/getCompletion.ts
Normal file
62
src/modelProviders/replicate-llama2/getCompletion.ts
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
import { env } from "~/env.mjs";
|
||||||
|
import { type ReplicateLlama2Input, type ReplicateLlama2Output } from ".";
|
||||||
|
import { type CompletionResponse } from "../types";
|
||||||
|
import Replicate from "replicate";
|
||||||
|
|
||||||
|
const replicate = new Replicate({
|
||||||
|
auth: env.REPLICATE_API_TOKEN || "",
|
||||||
|
});
|
||||||
|
|
||||||
|
const modelIds: Record<ReplicateLlama2Input["model"], string> = {
|
||||||
|
"7b-chat": "3725a659b5afff1a0ba9bead5fac3899d998feaad00e07032ca2b0e35eb14f8a",
|
||||||
|
"13b-chat": "5c785d117c5bcdd1928d5a9acb1ffa6272d6cf13fcb722e90886a0196633f9d3",
|
||||||
|
"70b-chat": "e951f18578850b652510200860fc4ea62b3b16fac280f83ff32282f87bbd2e48",
|
||||||
|
};
|
||||||
|
|
||||||
|
export async function getCompletion(
|
||||||
|
input: ReplicateLlama2Input,
|
||||||
|
onStream: ((partialOutput: string[]) => void) | null,
|
||||||
|
): Promise<CompletionResponse<ReplicateLlama2Output>> {
|
||||||
|
const start = Date.now();
|
||||||
|
|
||||||
|
const { model, stream, ...rest } = input;
|
||||||
|
|
||||||
|
try {
|
||||||
|
const prediction = await replicate.predictions.create({
|
||||||
|
version: modelIds[model],
|
||||||
|
input: rest,
|
||||||
|
});
|
||||||
|
|
||||||
|
console.log("stream?", onStream);
|
||||||
|
|
||||||
|
const interval = onStream
|
||||||
|
? // eslint-disable-next-line @typescript-eslint/no-misused-promises
|
||||||
|
setInterval(async () => {
|
||||||
|
const partialPrediction = await replicate.predictions.get(prediction.id);
|
||||||
|
|
||||||
|
if (partialPrediction.output) onStream(partialPrediction.output as ReplicateLlama2Output);
|
||||||
|
}, 500)
|
||||||
|
: null;
|
||||||
|
|
||||||
|
const resp = await replicate.wait(prediction, {});
|
||||||
|
if (interval) clearInterval(interval);
|
||||||
|
|
||||||
|
const timeToComplete = Date.now() - start;
|
||||||
|
|
||||||
|
if (resp.error) throw new Error(resp.error as string);
|
||||||
|
|
||||||
|
return {
|
||||||
|
type: "success",
|
||||||
|
statusCode: 200,
|
||||||
|
value: resp.output as ReplicateLlama2Output,
|
||||||
|
timeToComplete,
|
||||||
|
};
|
||||||
|
} catch (error: unknown) {
|
||||||
|
console.error("ERROR IS", error);
|
||||||
|
return {
|
||||||
|
type: "error",
|
||||||
|
message: (error as Error).message,
|
||||||
|
autoRetry: true,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
74
src/modelProviders/replicate-llama2/index.ts
Normal file
74
src/modelProviders/replicate-llama2/index.ts
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
import { type ModelProvider } from "../types";
|
||||||
|
import { getCompletion } from "./getCompletion";
|
||||||
|
|
||||||
|
const supportedModels = ["7b-chat", "13b-chat", "70b-chat"] as const;
|
||||||
|
|
||||||
|
type SupportedModel = (typeof supportedModels)[number];
|
||||||
|
|
||||||
|
export type ReplicateLlama2Input = {
|
||||||
|
model: SupportedModel;
|
||||||
|
prompt: string;
|
||||||
|
stream?: boolean;
|
||||||
|
max_length?: number;
|
||||||
|
temperature?: number;
|
||||||
|
top_p?: number;
|
||||||
|
repetition_penalty?: number;
|
||||||
|
debug?: boolean;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type ReplicateLlama2Output = string[];
|
||||||
|
|
||||||
|
export type ReplicateLlama2Provider = ModelProvider<
|
||||||
|
SupportedModel,
|
||||||
|
ReplicateLlama2Input,
|
||||||
|
ReplicateLlama2Output
|
||||||
|
>;
|
||||||
|
|
||||||
|
const modelProvider: ReplicateLlama2Provider = {
|
||||||
|
name: "OpenAI ChatCompletion",
|
||||||
|
models: {
|
||||||
|
"7b-chat": {},
|
||||||
|
"13b-chat": {},
|
||||||
|
"70b-chat": {},
|
||||||
|
},
|
||||||
|
getModel: (input) => {
|
||||||
|
if (supportedModels.includes(input.model)) return input.model;
|
||||||
|
|
||||||
|
return null;
|
||||||
|
},
|
||||||
|
inputSchema: {
|
||||||
|
type: "object",
|
||||||
|
properties: {
|
||||||
|
model: {
|
||||||
|
type: "string",
|
||||||
|
enum: supportedModels as unknown as string[],
|
||||||
|
},
|
||||||
|
prompt: {
|
||||||
|
type: "string",
|
||||||
|
},
|
||||||
|
stream: {
|
||||||
|
type: "boolean",
|
||||||
|
},
|
||||||
|
max_length: {
|
||||||
|
type: "number",
|
||||||
|
},
|
||||||
|
temperature: {
|
||||||
|
type: "number",
|
||||||
|
},
|
||||||
|
top_p: {
|
||||||
|
type: "number",
|
||||||
|
},
|
||||||
|
repetition_penalty: {
|
||||||
|
type: "number",
|
||||||
|
},
|
||||||
|
debug: {
|
||||||
|
type: "boolean",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
required: ["model", "prompt"],
|
||||||
|
},
|
||||||
|
shouldStream: (input) => input.stream ?? false,
|
||||||
|
getCompletion,
|
||||||
|
};
|
||||||
|
|
||||||
|
export default modelProvider;
|
||||||
48
src/modelProviders/types.ts
Normal file
48
src/modelProviders/types.ts
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
import { type JSONSchema4 } from "json-schema";
|
||||||
|
import { type JsonValue } from "type-fest";
|
||||||
|
|
||||||
|
type ModelProviderModel = {
|
||||||
|
name?: string;
|
||||||
|
learnMore?: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type CompletionResponse<T> =
|
||||||
|
| { type: "error"; message: string; autoRetry: boolean; statusCode?: number }
|
||||||
|
| {
|
||||||
|
type: "success";
|
||||||
|
value: T;
|
||||||
|
timeToComplete: number;
|
||||||
|
statusCode: number;
|
||||||
|
promptTokens?: number;
|
||||||
|
completionTokens?: number;
|
||||||
|
cost?: number;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type ModelProvider<SupportedModels extends string, InputSchema, OutputSchema> = {
|
||||||
|
name: string;
|
||||||
|
models: Record<SupportedModels, ModelProviderModel>;
|
||||||
|
getModel: (input: InputSchema) => SupportedModels | null;
|
||||||
|
shouldStream: (input: InputSchema) => boolean;
|
||||||
|
inputSchema: JSONSchema4;
|
||||||
|
getCompletion: (
|
||||||
|
input: InputSchema,
|
||||||
|
onStream: ((partialOutput: OutputSchema) => void) | null,
|
||||||
|
) => Promise<CompletionResponse<OutputSchema>>;
|
||||||
|
|
||||||
|
// This is just a convenience for type inference, don't use it at runtime
|
||||||
|
_outputSchema?: OutputSchema | null;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type NormalizedOutput =
|
||||||
|
| {
|
||||||
|
type: "text";
|
||||||
|
value: string;
|
||||||
|
}
|
||||||
|
| {
|
||||||
|
type: "json";
|
||||||
|
value: JsonValue;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type ModelProviderFrontend<ModelProviderT extends ModelProvider<any, any, any>> = {
|
||||||
|
normalizeOutput: (output: NonNullable<ModelProviderT["_outputSchema"]>) => NormalizedOutput;
|
||||||
|
};
|
||||||
@@ -2,11 +2,11 @@ import { type Session } from "next-auth";
|
|||||||
import { SessionProvider } from "next-auth/react";
|
import { SessionProvider } from "next-auth/react";
|
||||||
import { type AppType } from "next/app";
|
import { type AppType } from "next/app";
|
||||||
import { api } from "~/utils/api";
|
import { api } from "~/utils/api";
|
||||||
import { ChakraProvider } from "@chakra-ui/react";
|
|
||||||
import theme from "~/utils/theme";
|
|
||||||
import Favicon from "~/components/Favicon";
|
import Favicon from "~/components/Favicon";
|
||||||
import "~/utils/analytics";
|
import "~/utils/analytics";
|
||||||
import Head from "next/head";
|
import Head from "next/head";
|
||||||
|
import { ChakraThemeProvider } from "~/theme/ChakraThemeProvider";
|
||||||
|
import { SyncAppStore } from "~/state/sync";
|
||||||
|
|
||||||
const MyApp: AppType<{ session: Session | null }> = ({
|
const MyApp: AppType<{ session: Session | null }> = ({
|
||||||
Component,
|
Component,
|
||||||
@@ -21,10 +21,11 @@ const MyApp: AppType<{ session: Session | null }> = ({
|
|||||||
/>
|
/>
|
||||||
</Head>
|
</Head>
|
||||||
<SessionProvider session={session}>
|
<SessionProvider session={session}>
|
||||||
|
<SyncAppStore />
|
||||||
<Favicon />
|
<Favicon />
|
||||||
<ChakraProvider theme={theme}>
|
<ChakraThemeProvider>
|
||||||
<Component {...pageProps} />
|
<Component {...pageProps} />
|
||||||
</ChakraProvider>
|
</ChakraThemeProvider>
|
||||||
</SessionProvider>
|
</SessionProvider>
|
||||||
</>
|
</>
|
||||||
);
|
);
|
||||||
|
|||||||
23
src/pages/account/signin.tsx
Normal file
23
src/pages/account/signin.tsx
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
import { signIn, useSession } from "next-auth/react";
|
||||||
|
import { useRouter } from "next/router";
|
||||||
|
import { useEffect } from "react";
|
||||||
|
import AppShell from "~/components/nav/AppShell";
|
||||||
|
|
||||||
|
export default function SignIn() {
|
||||||
|
const session = useSession().data;
|
||||||
|
const router = useRouter();
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (session) {
|
||||||
|
router.push("/experiments").catch(console.error);
|
||||||
|
} else if (session === null) {
|
||||||
|
signIn("github").catch(console.error);
|
||||||
|
}
|
||||||
|
}, [session, router]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<AppShell>
|
||||||
|
<div />
|
||||||
|
</AppShell>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -49,6 +49,10 @@ const DeleteButton = () => {
|
|||||||
onClose();
|
onClose();
|
||||||
}, [mutation, experiment.data?.id, router]);
|
}, [mutation, experiment.data?.id, router]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
useAppStore.getState().sharedVariantEditor.loadMonaco().catch(console.error);
|
||||||
|
});
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
<Button
|
<Button
|
||||||
@@ -124,6 +128,8 @@ export default function Experiment() {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const canModify = experiment.data?.access.canModify ?? false;
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<AppShell title={experiment.data?.label}>
|
<AppShell title={experiment.data?.label}>
|
||||||
<VStack h="full">
|
<VStack h="full">
|
||||||
@@ -143,6 +149,7 @@ export default function Experiment() {
|
|||||||
</Link>
|
</Link>
|
||||||
</BreadcrumbItem>
|
</BreadcrumbItem>
|
||||||
<BreadcrumbItem isCurrentPage>
|
<BreadcrumbItem isCurrentPage>
|
||||||
|
{canModify ? (
|
||||||
<Input
|
<Input
|
||||||
size="sm"
|
size="sm"
|
||||||
value={label}
|
value={label}
|
||||||
@@ -157,8 +164,14 @@ export default function Experiment() {
|
|||||||
_hover={{ borderColor: "gray.300" }}
|
_hover={{ borderColor: "gray.300" }}
|
||||||
_focus={{ borderColor: "blue.500", outline: "none" }}
|
_focus={{ borderColor: "blue.500", outline: "none" }}
|
||||||
/>
|
/>
|
||||||
|
) : (
|
||||||
|
<Text fontSize={16} px={0} minW={{ base: 100, lg: 300 }} flex={1}>
|
||||||
|
{experiment.data?.label}
|
||||||
|
</Text>
|
||||||
|
)}
|
||||||
</BreadcrumbItem>
|
</BreadcrumbItem>
|
||||||
</Breadcrumb>
|
</Breadcrumb>
|
||||||
|
{canModify && (
|
||||||
<HStack>
|
<HStack>
|
||||||
<Button
|
<Button
|
||||||
size="sm"
|
size="sm"
|
||||||
@@ -174,6 +187,7 @@ export default function Experiment() {
|
|||||||
</Button>
|
</Button>
|
||||||
<DeleteButton />
|
<DeleteButton />
|
||||||
</HStack>
|
</HStack>
|
||||||
|
)}
|
||||||
</Flex>
|
</Flex>
|
||||||
<SettingsDrawer />
|
<SettingsDrawer />
|
||||||
<Box w="100%" overflowX="auto" flex={1}>
|
<Box w="100%" overflowX="auto" flex={1}>
|
||||||
|
|||||||
@@ -1,25 +1,50 @@
|
|||||||
import {
|
import {
|
||||||
SimpleGrid,
|
SimpleGrid,
|
||||||
HStack,
|
|
||||||
Icon,
|
Icon,
|
||||||
VStack,
|
VStack,
|
||||||
Breadcrumb,
|
Breadcrumb,
|
||||||
BreadcrumbItem,
|
BreadcrumbItem,
|
||||||
Flex,
|
Flex,
|
||||||
|
Center,
|
||||||
|
Text,
|
||||||
|
Link,
|
||||||
|
HStack,
|
||||||
} from "@chakra-ui/react";
|
} from "@chakra-ui/react";
|
||||||
import { RiFlaskLine } from "react-icons/ri";
|
import { RiFlaskLine } from "react-icons/ri";
|
||||||
import AppShell from "~/components/nav/AppShell";
|
import AppShell from "~/components/nav/AppShell";
|
||||||
import { api } from "~/utils/api";
|
import { api } from "~/utils/api";
|
||||||
import { NewExperimentButton } from "~/components/experiments/NewExperimentButton";
|
import { ExperimentCard, NewExperimentCard } from "~/components/experiments/ExperimentCard";
|
||||||
import { ExperimentCard } from "~/components/experiments/ExperimentCard";
|
import { signIn, useSession } from "next-auth/react";
|
||||||
|
|
||||||
export default function ExperimentsPage() {
|
export default function ExperimentsPage() {
|
||||||
const experiments = api.experiments.list.useQuery();
|
const experiments = api.experiments.list.useQuery();
|
||||||
|
|
||||||
|
const user = useSession().data;
|
||||||
|
|
||||||
|
if (user === null) {
|
||||||
return (
|
return (
|
||||||
<AppShell>
|
<AppShell title="Experiments">
|
||||||
<VStack alignItems={"flex-start"} m={4} mt={1}>
|
<Center h="100%">
|
||||||
<HStack w="full" justifyContent="space-between" mb={4}>
|
<Text>
|
||||||
|
<Link
|
||||||
|
onClick={() => {
|
||||||
|
signIn("github").catch(console.error);
|
||||||
|
}}
|
||||||
|
textDecor="underline"
|
||||||
|
>
|
||||||
|
Sign in
|
||||||
|
</Link>{" "}
|
||||||
|
to view or create new experiments!
|
||||||
|
</Text>
|
||||||
|
</Center>
|
||||||
|
</AppShell>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<AppShell title="Experiments">
|
||||||
|
<VStack alignItems={"flex-start"} px={4} py={2}>
|
||||||
|
<HStack minH={8} align="center">
|
||||||
<Breadcrumb flex={1}>
|
<Breadcrumb flex={1}>
|
||||||
<BreadcrumbItem>
|
<BreadcrumbItem>
|
||||||
<Flex alignItems="center">
|
<Flex alignItems="center">
|
||||||
@@ -27,9 +52,9 @@ export default function ExperimentsPage() {
|
|||||||
</Flex>
|
</Flex>
|
||||||
</BreadcrumbItem>
|
</BreadcrumbItem>
|
||||||
</Breadcrumb>
|
</Breadcrumb>
|
||||||
<NewExperimentButton mr={4} borderRadius={8} />
|
|
||||||
</HStack>
|
</HStack>
|
||||||
<SimpleGrid w="full" columns={{ base: 1, md: 2, lg: 3, xl: 4 }} spacing={8} p="4">
|
<SimpleGrid w="full" columns={{ base: 1, md: 2, lg: 3, xl: 4 }} spacing={8} p="4">
|
||||||
|
<NewExperimentCard />
|
||||||
{experiments?.data?.map((exp) => <ExperimentCard key={exp.id} exp={exp} />)}
|
{experiments?.data?.map((exp) => <ExperimentCard key={exp.id} exp={exp} />)}
|
||||||
</SimpleGrid>
|
</SimpleGrid>
|
||||||
</VStack>
|
</VStack>
|
||||||
|
|||||||
@@ -1,11 +1,16 @@
|
|||||||
import { EvalType } from "@prisma/client";
|
import { EvalType } from "@prisma/client";
|
||||||
import { z } from "zod";
|
import { z } from "zod";
|
||||||
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
|
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
|
||||||
import { prisma } from "~/server/db";
|
import { prisma } from "~/server/db";
|
||||||
import { runAllEvals } from "~/server/utils/evaluations";
|
import { runAllEvals } from "~/server/utils/evaluations";
|
||||||
|
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
|
||||||
|
|
||||||
export const evaluationsRouter = createTRPCRouter({
|
export const evaluationsRouter = createTRPCRouter({
|
||||||
list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
|
list: publicProcedure
|
||||||
|
.input(z.object({ experimentId: z.string() }))
|
||||||
|
.query(async ({ input, ctx }) => {
|
||||||
|
await requireCanViewExperiment(input.experimentId, ctx);
|
||||||
|
|
||||||
return await prisma.evaluation.findMany({
|
return await prisma.evaluation.findMany({
|
||||||
where: {
|
where: {
|
||||||
experimentId: input.experimentId,
|
experimentId: input.experimentId,
|
||||||
@@ -14,7 +19,7 @@ export const evaluationsRouter = createTRPCRouter({
|
|||||||
});
|
});
|
||||||
}),
|
}),
|
||||||
|
|
||||||
create: publicProcedure
|
create: protectedProcedure
|
||||||
.input(
|
.input(
|
||||||
z.object({
|
z.object({
|
||||||
experimentId: z.string(),
|
experimentId: z.string(),
|
||||||
@@ -23,7 +28,9 @@ export const evaluationsRouter = createTRPCRouter({
|
|||||||
evalType: z.nativeEnum(EvalType),
|
evalType: z.nativeEnum(EvalType),
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
.mutation(async ({ input }) => {
|
.mutation(async ({ input, ctx }) => {
|
||||||
|
await requireCanModifyExperiment(input.experimentId, ctx);
|
||||||
|
|
||||||
await prisma.evaluation.create({
|
await prisma.evaluation.create({
|
||||||
data: {
|
data: {
|
||||||
experimentId: input.experimentId,
|
experimentId: input.experimentId,
|
||||||
@@ -38,7 +45,7 @@ export const evaluationsRouter = createTRPCRouter({
|
|||||||
await runAllEvals(input.experimentId);
|
await runAllEvals(input.experimentId);
|
||||||
}),
|
}),
|
||||||
|
|
||||||
update: publicProcedure
|
update: protectedProcedure
|
||||||
.input(
|
.input(
|
||||||
z.object({
|
z.object({
|
||||||
id: z.string(),
|
id: z.string(),
|
||||||
@@ -49,7 +56,12 @@ export const evaluationsRouter = createTRPCRouter({
|
|||||||
}),
|
}),
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
.mutation(async ({ input }) => {
|
.mutation(async ({ input, ctx }) => {
|
||||||
|
const { experimentId } = await prisma.evaluation.findUniqueOrThrow({
|
||||||
|
where: { id: input.id },
|
||||||
|
});
|
||||||
|
await requireCanModifyExperiment(experimentId, ctx);
|
||||||
|
|
||||||
const evaluation = await prisma.evaluation.update({
|
const evaluation = await prisma.evaluation.update({
|
||||||
where: { id: input.id },
|
where: { id: input.id },
|
||||||
data: {
|
data: {
|
||||||
@@ -69,7 +81,14 @@ export const evaluationsRouter = createTRPCRouter({
|
|||||||
await runAllEvals(evaluation.experimentId);
|
await runAllEvals(evaluation.experimentId);
|
||||||
}),
|
}),
|
||||||
|
|
||||||
delete: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => {
|
delete: protectedProcedure
|
||||||
|
.input(z.object({ id: z.string() }))
|
||||||
|
.mutation(async ({ input, ctx }) => {
|
||||||
|
const { experimentId } = await prisma.evaluation.findUniqueOrThrow({
|
||||||
|
where: { id: input.id },
|
||||||
|
});
|
||||||
|
await requireCanModifyExperiment(experimentId, ctx);
|
||||||
|
|
||||||
await prisma.evaluation.delete({
|
await prisma.evaluation.delete({
|
||||||
where: { id: input.id },
|
where: { id: input.id },
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -1,14 +1,32 @@
|
|||||||
import { z } from "zod";
|
import { z } from "zod";
|
||||||
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
|
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
|
||||||
import { prisma } from "~/server/db";
|
import { prisma } from "~/server/db";
|
||||||
import dedent from "dedent";
|
import dedent from "dedent";
|
||||||
import { generateNewCell } from "~/server/utils/generateNewCell";
|
import { generateNewCell } from "~/server/utils/generateNewCell";
|
||||||
|
import {
|
||||||
|
canModifyExperiment,
|
||||||
|
requireCanModifyExperiment,
|
||||||
|
requireCanViewExperiment,
|
||||||
|
requireNothing,
|
||||||
|
} from "~/utils/accessControl";
|
||||||
|
import userOrg from "~/server/utils/userOrg";
|
||||||
|
import generateTypes from "~/modelProviders/generateTypes";
|
||||||
|
|
||||||
export const experimentsRouter = createTRPCRouter({
|
export const experimentsRouter = createTRPCRouter({
|
||||||
list: publicProcedure.query(async () => {
|
list: protectedProcedure.query(async ({ ctx }) => {
|
||||||
|
// Anyone can list experiments
|
||||||
|
requireNothing(ctx);
|
||||||
|
|
||||||
const experiments = await prisma.experiment.findMany({
|
const experiments = await prisma.experiment.findMany({
|
||||||
|
where: {
|
||||||
|
organization: {
|
||||||
|
OrganizationUser: {
|
||||||
|
some: { userId: ctx.session.user.id },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
orderBy: {
|
orderBy: {
|
||||||
sortIndex: "asc",
|
sortIndex: "desc",
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -40,15 +58,29 @@ export const experimentsRouter = createTRPCRouter({
|
|||||||
return experimentsWithCounts;
|
return experimentsWithCounts;
|
||||||
}),
|
}),
|
||||||
|
|
||||||
get: publicProcedure.input(z.object({ id: z.string() })).query(async ({ input }) => {
|
get: publicProcedure.input(z.object({ id: z.string() })).query(async ({ input, ctx }) => {
|
||||||
return await prisma.experiment.findFirst({
|
await requireCanViewExperiment(input.id, ctx);
|
||||||
where: {
|
const experiment = await prisma.experiment.findFirstOrThrow({
|
||||||
id: input.id,
|
where: { id: input.id },
|
||||||
},
|
|
||||||
});
|
});
|
||||||
|
|
||||||
|
const canModify = ctx.session?.user.id
|
||||||
|
? await canModifyExperiment(experiment.id, ctx.session?.user.id)
|
||||||
|
: false;
|
||||||
|
|
||||||
|
return {
|
||||||
|
...experiment,
|
||||||
|
access: {
|
||||||
|
canView: true,
|
||||||
|
canModify,
|
||||||
|
},
|
||||||
|
};
|
||||||
}),
|
}),
|
||||||
|
|
||||||
create: publicProcedure.input(z.object({})).mutation(async () => {
|
create: protectedProcedure.input(z.object({})).mutation(async ({ ctx }) => {
|
||||||
|
// Anyone can create an experiment
|
||||||
|
requireNothing(ctx);
|
||||||
|
|
||||||
const maxSortIndex =
|
const maxSortIndex =
|
||||||
(
|
(
|
||||||
await prisma.experiment.aggregate({
|
await prisma.experiment.aggregate({
|
||||||
@@ -62,6 +94,7 @@ export const experimentsRouter = createTRPCRouter({
|
|||||||
data: {
|
data: {
|
||||||
sortIndex: maxSortIndex + 1,
|
sortIndex: maxSortIndex + 1,
|
||||||
label: `Experiment ${maxSortIndex + 1}`,
|
label: `Experiment ${maxSortIndex + 1}`,
|
||||||
|
organizationId: (await userOrg(ctx.session.user.id)).id,
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -76,14 +109,13 @@ export const experimentsRouter = createTRPCRouter({
|
|||||||
constructFn: dedent`
|
constructFn: dedent`
|
||||||
/**
|
/**
|
||||||
* Use Javascript to define an OpenAI chat completion
|
* Use Javascript to define an OpenAI chat completion
|
||||||
* (https://platform.openai.com/docs/api-reference/chat/create) and
|
* (https://platform.openai.com/docs/api-reference/chat/create).
|
||||||
* assign it to the \`prompt\` variable.
|
|
||||||
*
|
*
|
||||||
* You have access to the current scenario in the \`scenario\`
|
* You have access to the current scenario in the \`scenario\`
|
||||||
* variable.
|
* variable.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
prompt = {
|
definePrompt("openai/ChatCompletion", {
|
||||||
model: "gpt-3.5-turbo-0613",
|
model: "gpt-3.5-turbo-0613",
|
||||||
stream: true,
|
stream: true,
|
||||||
messages: [
|
messages: [
|
||||||
@@ -92,8 +124,10 @@ export const experimentsRouter = createTRPCRouter({
|
|||||||
content: \`"Return 'this is output for the scenario "${"$"}{scenario.text}"'\`,
|
content: \`"Return 'this is output for the scenario "${"$"}{scenario.text}"'\`,
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
};`,
|
});`,
|
||||||
model: "gpt-3.5-turbo-0613",
|
model: "gpt-3.5-turbo-0613",
|
||||||
|
modelProvider: "openai/ChatCompletion",
|
||||||
|
constructFnVersion: 2,
|
||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
prisma.templateVariable.create({
|
prisma.templateVariable.create({
|
||||||
@@ -117,9 +151,10 @@ export const experimentsRouter = createTRPCRouter({
|
|||||||
return exp;
|
return exp;
|
||||||
}),
|
}),
|
||||||
|
|
||||||
update: publicProcedure
|
update: protectedProcedure
|
||||||
.input(z.object({ id: z.string(), updates: z.object({ label: z.string() }) }))
|
.input(z.object({ id: z.string(), updates: z.object({ label: z.string() }) }))
|
||||||
.mutation(async ({ input }) => {
|
.mutation(async ({ input, ctx }) => {
|
||||||
|
await requireCanModifyExperiment(input.id, ctx);
|
||||||
return await prisma.experiment.update({
|
return await prisma.experiment.update({
|
||||||
where: {
|
where: {
|
||||||
id: input.id,
|
id: input.id,
|
||||||
@@ -130,11 +165,21 @@ export const experimentsRouter = createTRPCRouter({
|
|||||||
});
|
});
|
||||||
}),
|
}),
|
||||||
|
|
||||||
delete: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => {
|
delete: protectedProcedure
|
||||||
|
.input(z.object({ id: z.string() }))
|
||||||
|
.mutation(async ({ input, ctx }) => {
|
||||||
|
await requireCanModifyExperiment(input.id, ctx);
|
||||||
|
|
||||||
await prisma.experiment.delete({
|
await prisma.experiment.delete({
|
||||||
where: {
|
where: {
|
||||||
id: input.id,
|
id: input.id,
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
}),
|
}),
|
||||||
|
|
||||||
|
// Keeping these on `experiment` for now because we might want to limit the
|
||||||
|
// providers based on your account/experiment
|
||||||
|
promptTypes: publicProcedure.query(async () => {
|
||||||
|
return await generateTypes();
|
||||||
|
}),
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -1,17 +1,22 @@
|
|||||||
import dedent from "dedent";
|
|
||||||
import { isObject } from "lodash-es";
|
|
||||||
import { z } from "zod";
|
import { z } from "zod";
|
||||||
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
|
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
|
||||||
import { prisma } from "~/server/db";
|
import { prisma } from "~/server/db";
|
||||||
import { generateNewCell } from "~/server/utils/generateNewCell";
|
import { generateNewCell } from "~/server/utils/generateNewCell";
|
||||||
import { OpenAIChatModel } from "~/server/types";
|
import { type SupportedModel } from "~/server/types";
|
||||||
import { constructPrompt } from "~/server/utils/constructPrompt";
|
|
||||||
import userError from "~/server/utils/error";
|
import userError from "~/server/utils/error";
|
||||||
import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
|
import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
|
||||||
import { calculateTokenCost } from "~/utils/calculateTokenCost";
|
import { reorderPromptVariants } from "~/server/utils/reorderPromptVariants";
|
||||||
|
import { type PromptVariant } from "@prisma/client";
|
||||||
|
import { deriveNewConstructFn } from "~/server/utils/deriveNewContructFn";
|
||||||
|
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
|
||||||
|
import parseConstructFn from "~/server/utils/parseConstructFn";
|
||||||
|
|
||||||
export const promptVariantsRouter = createTRPCRouter({
|
export const promptVariantsRouter = createTRPCRouter({
|
||||||
list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
|
list: publicProcedure
|
||||||
|
.input(z.object({ experimentId: z.string() }))
|
||||||
|
.query(async ({ input, ctx }) => {
|
||||||
|
await requireCanViewExperiment(input.experimentId, ctx);
|
||||||
|
|
||||||
return await prisma.promptVariant.findMany({
|
return await prisma.promptVariant.findMany({
|
||||||
where: {
|
where: {
|
||||||
experimentId: input.experimentId,
|
experimentId: input.experimentId,
|
||||||
@@ -21,7 +26,9 @@ export const promptVariantsRouter = createTRPCRouter({
|
|||||||
});
|
});
|
||||||
}),
|
}),
|
||||||
|
|
||||||
stats: publicProcedure.input(z.object({ variantId: z.string() })).query(async ({ input }) => {
|
stats: publicProcedure
|
||||||
|
.input(z.object({ variantId: z.string() }))
|
||||||
|
.query(async ({ input, ctx }) => {
|
||||||
const variant = await prisma.promptVariant.findUnique({
|
const variant = await prisma.promptVariant.findUnique({
|
||||||
where: {
|
where: {
|
||||||
id: input.variantId,
|
id: input.variantId,
|
||||||
@@ -32,6 +39,8 @@ export const promptVariantsRouter = createTRPCRouter({
|
|||||||
throw new Error(`Prompt Variant with id ${input.variantId} does not exist`);
|
throw new Error(`Prompt Variant with id ${input.variantId} does not exist`);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
await requireCanViewExperiment(variant.experimentId, ctx);
|
||||||
|
|
||||||
const outputEvals = await prisma.outputEvaluation.groupBy({
|
const outputEvals = await prisma.outputEvaluation.groupBy({
|
||||||
by: ["evaluationId"],
|
by: ["evaluationId"],
|
||||||
_sum: {
|
_sum: {
|
||||||
@@ -62,7 +71,9 @@ export const promptVariantsRouter = createTRPCRouter({
|
|||||||
});
|
});
|
||||||
|
|
||||||
const evalResults = evals.map((evalItem) => {
|
const evalResults = evals.map((evalItem) => {
|
||||||
const evalResult = outputEvals.find((outputEval) => outputEval.evaluationId === evalItem.id);
|
const evalResult = outputEvals.find(
|
||||||
|
(outputEval) => outputEval.evaluationId === evalItem.id,
|
||||||
|
);
|
||||||
return {
|
return {
|
||||||
id: evalItem.id,
|
id: evalItem.id,
|
||||||
label: evalItem.label,
|
label: evalItem.label,
|
||||||
@@ -97,17 +108,14 @@ export const promptVariantsRouter = createTRPCRouter({
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
_sum: {
|
_sum: {
|
||||||
|
cost: true,
|
||||||
promptTokens: true,
|
promptTokens: true,
|
||||||
completionTokens: true,
|
completionTokens: true,
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
const promptTokens = overallTokens._sum?.promptTokens ?? 0;
|
const promptTokens = overallTokens._sum?.promptTokens ?? 0;
|
||||||
const overallPromptCost = calculateTokenCost(variant.model, promptTokens);
|
|
||||||
const completionTokens = overallTokens._sum?.completionTokens ?? 0;
|
const completionTokens = overallTokens._sum?.completionTokens ?? 0;
|
||||||
const overallCompletionCost = calculateTokenCost(variant.model, completionTokens, true);
|
|
||||||
|
|
||||||
const overallCost = overallPromptCost + overallCompletionCost;
|
|
||||||
|
|
||||||
const awaitingRetrievals = !!(await prisma.scenarioVariantCell.findFirst({
|
const awaitingRetrievals = !!(await prisma.scenarioVariantCell.findFirst({
|
||||||
where: {
|
where: {
|
||||||
@@ -124,21 +132,33 @@ export const promptVariantsRouter = createTRPCRouter({
|
|||||||
evalResults,
|
evalResults,
|
||||||
promptTokens,
|
promptTokens,
|
||||||
completionTokens,
|
completionTokens,
|
||||||
overallCost,
|
overallCost: overallTokens._sum?.cost ?? 0,
|
||||||
scenarioCount,
|
scenarioCount,
|
||||||
outputCount,
|
outputCount,
|
||||||
awaitingRetrievals,
|
awaitingRetrievals,
|
||||||
};
|
};
|
||||||
}),
|
}),
|
||||||
|
|
||||||
create: publicProcedure
|
create: protectedProcedure
|
||||||
.input(
|
.input(
|
||||||
z.object({
|
z.object({
|
||||||
experimentId: z.string(),
|
experimentId: z.string(),
|
||||||
|
variantId: z.string().optional(),
|
||||||
|
newModel: z.string().optional(),
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
.mutation(async ({ input }) => {
|
.mutation(async ({ input, ctx }) => {
|
||||||
const lastVariant = await prisma.promptVariant.findFirst({
|
await requireCanViewExperiment(input.experimentId, ctx);
|
||||||
|
|
||||||
|
let originalVariant: PromptVariant | null = null;
|
||||||
|
if (input.variantId) {
|
||||||
|
originalVariant = await prisma.promptVariant.findUnique({
|
||||||
|
where: {
|
||||||
|
id: input.variantId,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
originalVariant = await prisma.promptVariant.findFirst({
|
||||||
where: {
|
where: {
|
||||||
experimentId: input.experimentId,
|
experimentId: input.experimentId,
|
||||||
visible: true,
|
visible: true,
|
||||||
@@ -147,6 +167,7 @@ export const promptVariantsRouter = createTRPCRouter({
|
|||||||
sortIndex: "desc",
|
sortIndex: "desc",
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
}
|
||||||
|
|
||||||
const largestSortIndex =
|
const largestSortIndex =
|
||||||
(
|
(
|
||||||
@@ -160,24 +181,25 @@ export const promptVariantsRouter = createTRPCRouter({
|
|||||||
})
|
})
|
||||||
)._max?.sortIndex ?? 0;
|
)._max?.sortIndex ?? 0;
|
||||||
|
|
||||||
|
const newVariantLabel =
|
||||||
|
input.variantId && originalVariant
|
||||||
|
? `${originalVariant?.label} Copy`
|
||||||
|
: `Prompt Variant ${largestSortIndex + 2}`;
|
||||||
|
|
||||||
|
const newConstructFn = await deriveNewConstructFn(
|
||||||
|
originalVariant,
|
||||||
|
input.newModel as SupportedModel,
|
||||||
|
);
|
||||||
|
|
||||||
const createNewVariantAction = prisma.promptVariant.create({
|
const createNewVariantAction = prisma.promptVariant.create({
|
||||||
data: {
|
data: {
|
||||||
experimentId: input.experimentId,
|
experimentId: input.experimentId,
|
||||||
label: `Prompt Variant ${largestSortIndex + 2}`,
|
label: newVariantLabel,
|
||||||
sortIndex: (lastVariant?.sortIndex ?? 0) + 1,
|
sortIndex: (originalVariant?.sortIndex ?? 0) + 1,
|
||||||
constructFn:
|
constructFn: newConstructFn,
|
||||||
lastVariant?.constructFn ??
|
constructFnVersion: 2,
|
||||||
dedent`
|
model: originalVariant?.model ?? "gpt-3.5-turbo",
|
||||||
prompt = {
|
modelProvider: originalVariant?.modelProvider ?? "openai/ChatCompletion",
|
||||||
model: "gpt-3.5-turbo",
|
|
||||||
messages: [
|
|
||||||
{
|
|
||||||
role: "system",
|
|
||||||
content: "Return 'Hello, world!'",
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}`,
|
|
||||||
model: lastVariant?.model ?? "gpt-3.5-turbo",
|
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -186,6 +208,11 @@ export const promptVariantsRouter = createTRPCRouter({
|
|||||||
recordExperimentUpdated(input.experimentId),
|
recordExperimentUpdated(input.experimentId),
|
||||||
]);
|
]);
|
||||||
|
|
||||||
|
if (originalVariant) {
|
||||||
|
// Insert new variant to right of original variant
|
||||||
|
await reorderPromptVariants(newVariant.id, originalVariant.id, true);
|
||||||
|
}
|
||||||
|
|
||||||
const scenarios = await prisma.testScenario.findMany({
|
const scenarios = await prisma.testScenario.findMany({
|
||||||
where: {
|
where: {
|
||||||
experimentId: input.experimentId,
|
experimentId: input.experimentId,
|
||||||
@@ -200,7 +227,7 @@ export const promptVariantsRouter = createTRPCRouter({
|
|||||||
return newVariant;
|
return newVariant;
|
||||||
}),
|
}),
|
||||||
|
|
||||||
update: publicProcedure
|
update: protectedProcedure
|
||||||
.input(
|
.input(
|
||||||
z.object({
|
z.object({
|
||||||
id: z.string(),
|
id: z.string(),
|
||||||
@@ -209,7 +236,7 @@ export const promptVariantsRouter = createTRPCRouter({
|
|||||||
}),
|
}),
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
.mutation(async ({ input }) => {
|
.mutation(async ({ input, ctx }) => {
|
||||||
const existing = await prisma.promptVariant.findUnique({
|
const existing = await prisma.promptVariant.findUnique({
|
||||||
where: {
|
where: {
|
||||||
id: input.id,
|
id: input.id,
|
||||||
@@ -220,6 +247,8 @@ export const promptVariantsRouter = createTRPCRouter({
|
|||||||
throw new Error(`Prompt Variant with id ${input.id} does not exist`);
|
throw new Error(`Prompt Variant with id ${input.id} does not exist`);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
await requireCanModifyExperiment(existing.experimentId, ctx);
|
||||||
|
|
||||||
const updatePromptVariantAction = prisma.promptVariant.update({
|
const updatePromptVariantAction = prisma.promptVariant.update({
|
||||||
where: {
|
where: {
|
||||||
id: input.id,
|
id: input.id,
|
||||||
@@ -235,13 +264,18 @@ export const promptVariantsRouter = createTRPCRouter({
|
|||||||
return updatedPromptVariant;
|
return updatedPromptVariant;
|
||||||
}),
|
}),
|
||||||
|
|
||||||
hide: publicProcedure
|
hide: protectedProcedure
|
||||||
.input(
|
.input(
|
||||||
z.object({
|
z.object({
|
||||||
id: z.string(),
|
id: z.string(),
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
.mutation(async ({ input }) => {
|
.mutation(async ({ input, ctx }) => {
|
||||||
|
const { experimentId } = await prisma.promptVariant.findUniqueOrThrow({
|
||||||
|
where: { id: input.id },
|
||||||
|
});
|
||||||
|
await requireCanModifyExperiment(experimentId, ctx);
|
||||||
|
|
||||||
const updatedPromptVariant = await prisma.promptVariant.update({
|
const updatedPromptVariant = await prisma.promptVariant.update({
|
||||||
where: { id: input.id },
|
where: { id: input.id },
|
||||||
data: { visible: false, experiment: { update: { updatedAt: new Date() } } },
|
data: { visible: false, experiment: { update: { updatedAt: new Date() } } },
|
||||||
@@ -250,43 +284,62 @@ export const promptVariantsRouter = createTRPCRouter({
|
|||||||
return updatedPromptVariant;
|
return updatedPromptVariant;
|
||||||
}),
|
}),
|
||||||
|
|
||||||
replaceVariant: publicProcedure
|
getRefinedPromptFn: protectedProcedure
|
||||||
|
.input(
|
||||||
|
z.object({
|
||||||
|
id: z.string(),
|
||||||
|
instructions: z.string(),
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.mutation(async ({ input, ctx }) => {
|
||||||
|
const existing = await prisma.promptVariant.findUniqueOrThrow({
|
||||||
|
where: {
|
||||||
|
id: input.id,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
await requireCanModifyExperiment(existing.experimentId, ctx);
|
||||||
|
|
||||||
|
const constructedPrompt = await parseConstructFn(existing.constructFn);
|
||||||
|
|
||||||
|
if ("error" in constructedPrompt) {
|
||||||
|
return userError(constructedPrompt.error);
|
||||||
|
}
|
||||||
|
|
||||||
|
const promptConstructionFn = await deriveNewConstructFn(
|
||||||
|
existing,
|
||||||
|
constructedPrompt.model as SupportedModel,
|
||||||
|
input.instructions,
|
||||||
|
);
|
||||||
|
|
||||||
|
// TODO: Validate promptConstructionFn
|
||||||
|
// TODO: Record in some sort of history
|
||||||
|
|
||||||
|
return promptConstructionFn;
|
||||||
|
}),
|
||||||
|
|
||||||
|
replaceVariant: protectedProcedure
|
||||||
.input(
|
.input(
|
||||||
z.object({
|
z.object({
|
||||||
id: z.string(),
|
id: z.string(),
|
||||||
constructFn: z.string(),
|
constructFn: z.string(),
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
.mutation(async ({ input }) => {
|
.mutation(async ({ input, ctx }) => {
|
||||||
const existing = await prisma.promptVariant.findUnique({
|
const existing = await prisma.promptVariant.findUniqueOrThrow({
|
||||||
where: {
|
where: {
|
||||||
id: input.id,
|
id: input.id,
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
await requireCanModifyExperiment(existing.experimentId, ctx);
|
||||||
|
|
||||||
if (!existing) {
|
if (!existing) {
|
||||||
throw new Error(`Prompt Variant with id ${input.id} does not exist`);
|
throw new Error(`Prompt Variant with id ${input.id} does not exist`);
|
||||||
}
|
}
|
||||||
|
|
||||||
let model = existing.model;
|
const parsedPrompt = await parseConstructFn(input.constructFn);
|
||||||
try {
|
|
||||||
const contructedPrompt = await constructPrompt({ constructFn: input.constructFn }, null);
|
|
||||||
|
|
||||||
if (!isObject(contructedPrompt)) {
|
if ("error" in parsedPrompt) {
|
||||||
return userError("Prompt is not an object");
|
return userError(parsedPrompt.error);
|
||||||
}
|
|
||||||
if (!("model" in contructedPrompt)) {
|
|
||||||
return userError("Prompt does not define a model");
|
|
||||||
}
|
|
||||||
if (
|
|
||||||
typeof contructedPrompt.model !== "string" ||
|
|
||||||
!(contructedPrompt.model in OpenAIChatModel)
|
|
||||||
) {
|
|
||||||
return userError("Prompt defines an invalid model");
|
|
||||||
}
|
|
||||||
model = contructedPrompt.model;
|
|
||||||
} catch (e) {
|
|
||||||
return userError((e as Error).message);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a duplicate with only the config changed
|
// Create a duplicate with only the config changed
|
||||||
@@ -297,7 +350,9 @@ export const promptVariantsRouter = createTRPCRouter({
|
|||||||
sortIndex: existing.sortIndex,
|
sortIndex: existing.sortIndex,
|
||||||
uiId: existing.uiId,
|
uiId: existing.uiId,
|
||||||
constructFn: input.constructFn,
|
constructFn: input.constructFn,
|
||||||
model,
|
constructFnVersion: 2,
|
||||||
|
modelProvider: parsedPrompt.modelProvider,
|
||||||
|
model: parsedPrompt.model,
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -330,72 +385,19 @@ export const promptVariantsRouter = createTRPCRouter({
|
|||||||
return { status: "ok" } as const;
|
return { status: "ok" } as const;
|
||||||
}),
|
}),
|
||||||
|
|
||||||
reorder: publicProcedure
|
reorder: protectedProcedure
|
||||||
.input(
|
.input(
|
||||||
z.object({
|
z.object({
|
||||||
draggedId: z.string(),
|
draggedId: z.string(),
|
||||||
droppedId: z.string(),
|
droppedId: z.string(),
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
.mutation(async ({ input }) => {
|
.mutation(async ({ input, ctx }) => {
|
||||||
const dragged = await prisma.promptVariant.findUnique({
|
const { experimentId } = await prisma.promptVariant.findUniqueOrThrow({
|
||||||
where: {
|
where: { id: input.draggedId },
|
||||||
id: input.draggedId,
|
|
||||||
},
|
|
||||||
});
|
});
|
||||||
|
await requireCanModifyExperiment(experimentId, ctx);
|
||||||
|
|
||||||
const dropped = await prisma.promptVariant.findUnique({
|
await reorderPromptVariants(input.draggedId, input.droppedId);
|
||||||
where: {
|
|
||||||
id: input.droppedId,
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
if (!dragged || !dropped || dragged.experimentId !== dropped.experimentId) {
|
|
||||||
throw new Error(
|
|
||||||
`Prompt Variant with id ${input.draggedId} or ${input.droppedId} does not exist`,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
const visibleItems = await prisma.promptVariant.findMany({
|
|
||||||
where: {
|
|
||||||
experimentId: dragged.experimentId,
|
|
||||||
visible: true,
|
|
||||||
},
|
|
||||||
orderBy: {
|
|
||||||
sortIndex: "asc",
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
// Remove the dragged item from its current position
|
|
||||||
const orderedItems = visibleItems.filter((item) => item.id !== dragged.id);
|
|
||||||
|
|
||||||
// Find the index of the dragged item and the dropped item
|
|
||||||
const dragIndex = visibleItems.findIndex((item) => item.id === dragged.id);
|
|
||||||
const dropIndex = visibleItems.findIndex((item) => item.id === dropped.id);
|
|
||||||
|
|
||||||
// Determine the new index for the dragged item
|
|
||||||
let newIndex;
|
|
||||||
if (dragIndex < dropIndex) {
|
|
||||||
newIndex = dropIndex + 1; // Insert after the dropped item
|
|
||||||
} else {
|
|
||||||
newIndex = dropIndex; // Insert before the dropped item
|
|
||||||
}
|
|
||||||
|
|
||||||
// Insert the dragged item at the new position
|
|
||||||
orderedItems.splice(newIndex, 0, dragged);
|
|
||||||
|
|
||||||
// Now, we need to update all the items with their new sortIndex
|
|
||||||
await prisma.$transaction(
|
|
||||||
orderedItems.map((item, index) => {
|
|
||||||
return prisma.promptVariant.update({
|
|
||||||
where: {
|
|
||||||
id: item.id,
|
|
||||||
},
|
|
||||||
data: {
|
|
||||||
sortIndex: index,
|
|
||||||
},
|
|
||||||
});
|
|
||||||
}),
|
|
||||||
);
|
|
||||||
}),
|
}),
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
import { z } from "zod";
|
import { z } from "zod";
|
||||||
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
|
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
|
||||||
import { prisma } from "~/server/db";
|
import { prisma } from "~/server/db";
|
||||||
import { generateNewCell } from "~/server/utils/generateNewCell";
|
import { generateNewCell } from "~/server/utils/generateNewCell";
|
||||||
import { queueLLMRetrievalTask } from "~/server/utils/queueLLMRetrievalTask";
|
import { queueLLMRetrievalTask } from "~/server/utils/queueLLMRetrievalTask";
|
||||||
|
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
|
||||||
|
|
||||||
export const scenarioVariantCellsRouter = createTRPCRouter({
|
export const scenarioVariantCellsRouter = createTRPCRouter({
|
||||||
get: publicProcedure
|
get: publicProcedure
|
||||||
@@ -12,7 +13,12 @@ export const scenarioVariantCellsRouter = createTRPCRouter({
|
|||||||
variantId: z.string(),
|
variantId: z.string(),
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
.query(async ({ input }) => {
|
.query(async ({ input, ctx }) => {
|
||||||
|
const { experimentId } = await prisma.testScenario.findUniqueOrThrow({
|
||||||
|
where: { id: input.scenarioId },
|
||||||
|
});
|
||||||
|
await requireCanViewExperiment(experimentId, ctx);
|
||||||
|
|
||||||
return await prisma.scenarioVariantCell.findUnique({
|
return await prisma.scenarioVariantCell.findUnique({
|
||||||
where: {
|
where: {
|
||||||
promptVariantId_testScenarioId: {
|
promptVariantId_testScenarioId: {
|
||||||
@@ -35,14 +41,20 @@ export const scenarioVariantCellsRouter = createTRPCRouter({
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
}),
|
}),
|
||||||
forceRefetch: publicProcedure
|
forceRefetch: protectedProcedure
|
||||||
.input(
|
.input(
|
||||||
z.object({
|
z.object({
|
||||||
scenarioId: z.string(),
|
scenarioId: z.string(),
|
||||||
variantId: z.string(),
|
variantId: z.string(),
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
.mutation(async ({ input }) => {
|
.mutation(async ({ input, ctx }) => {
|
||||||
|
const { experimentId } = await prisma.testScenario.findUniqueOrThrow({
|
||||||
|
where: { id: input.scenarioId },
|
||||||
|
});
|
||||||
|
|
||||||
|
await requireCanModifyExperiment(experimentId, ctx);
|
||||||
|
|
||||||
const cell = await prisma.scenarioVariantCell.findUnique({
|
const cell = await prisma.scenarioVariantCell.findUnique({
|
||||||
where: {
|
where: {
|
||||||
promptVariantId_testScenarioId: {
|
promptVariantId_testScenarioId: {
|
||||||
|
|||||||
@@ -1,13 +1,18 @@
|
|||||||
import { z } from "zod";
|
import { z } from "zod";
|
||||||
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
|
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
|
||||||
import { prisma } from "~/server/db";
|
import { prisma } from "~/server/db";
|
||||||
import { autogenerateScenarioValues } from "../autogen";
|
import { autogenerateScenarioValues } from "../autogen";
|
||||||
import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
|
import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
|
||||||
import { runAllEvals } from "~/server/utils/evaluations";
|
import { runAllEvals } from "~/server/utils/evaluations";
|
||||||
import { generateNewCell } from "~/server/utils/generateNewCell";
|
import { generateNewCell } from "~/server/utils/generateNewCell";
|
||||||
|
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
|
||||||
|
|
||||||
export const scenariosRouter = createTRPCRouter({
|
export const scenariosRouter = createTRPCRouter({
|
||||||
list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
|
list: publicProcedure
|
||||||
|
.input(z.object({ experimentId: z.string() }))
|
||||||
|
.query(async ({ input, ctx }) => {
|
||||||
|
await requireCanViewExperiment(input.experimentId, ctx);
|
||||||
|
|
||||||
return await prisma.testScenario.findMany({
|
return await prisma.testScenario.findMany({
|
||||||
where: {
|
where: {
|
||||||
experimentId: input.experimentId,
|
experimentId: input.experimentId,
|
||||||
@@ -19,14 +24,16 @@ export const scenariosRouter = createTRPCRouter({
|
|||||||
});
|
});
|
||||||
}),
|
}),
|
||||||
|
|
||||||
create: publicProcedure
|
create: protectedProcedure
|
||||||
.input(
|
.input(
|
||||||
z.object({
|
z.object({
|
||||||
experimentId: z.string(),
|
experimentId: z.string(),
|
||||||
autogenerate: z.boolean().optional(),
|
autogenerate: z.boolean().optional(),
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
.mutation(async ({ input }) => {
|
.mutation(async ({ input, ctx }) => {
|
||||||
|
await requireCanModifyExperiment(input.experimentId, ctx);
|
||||||
|
|
||||||
const maxSortIndex =
|
const maxSortIndex =
|
||||||
(
|
(
|
||||||
await prisma.testScenario.aggregate({
|
await prisma.testScenario.aggregate({
|
||||||
@@ -66,7 +73,14 @@ export const scenariosRouter = createTRPCRouter({
|
|||||||
}
|
}
|
||||||
}),
|
}),
|
||||||
|
|
||||||
hide: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => {
|
hide: protectedProcedure.input(z.object({ id: z.string() })).mutation(async ({ input, ctx }) => {
|
||||||
|
const experimentId = (
|
||||||
|
await prisma.testScenario.findUniqueOrThrow({
|
||||||
|
where: { id: input.id },
|
||||||
|
})
|
||||||
|
).experimentId;
|
||||||
|
|
||||||
|
await requireCanModifyExperiment(experimentId, ctx);
|
||||||
const hiddenScenario = await prisma.testScenario.update({
|
const hiddenScenario = await prisma.testScenario.update({
|
||||||
where: { id: input.id },
|
where: { id: input.id },
|
||||||
data: { visible: false, experiment: { update: { updatedAt: new Date() } } },
|
data: { visible: false, experiment: { update: { updatedAt: new Date() } } },
|
||||||
@@ -78,14 +92,14 @@ export const scenariosRouter = createTRPCRouter({
|
|||||||
return hiddenScenario;
|
return hiddenScenario;
|
||||||
}),
|
}),
|
||||||
|
|
||||||
reorder: publicProcedure
|
reorder: protectedProcedure
|
||||||
.input(
|
.input(
|
||||||
z.object({
|
z.object({
|
||||||
draggedId: z.string(),
|
draggedId: z.string(),
|
||||||
droppedId: z.string(),
|
droppedId: z.string(),
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
.mutation(async ({ input }) => {
|
.mutation(async ({ input, ctx }) => {
|
||||||
const dragged = await prisma.testScenario.findUnique({
|
const dragged = await prisma.testScenario.findUnique({
|
||||||
where: {
|
where: {
|
||||||
id: input.draggedId,
|
id: input.draggedId,
|
||||||
@@ -104,6 +118,8 @@ export const scenariosRouter = createTRPCRouter({
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
await requireCanModifyExperiment(dragged.experimentId, ctx);
|
||||||
|
|
||||||
const visibleItems = await prisma.testScenario.findMany({
|
const visibleItems = await prisma.testScenario.findMany({
|
||||||
where: {
|
where: {
|
||||||
experimentId: dragged.experimentId,
|
experimentId: dragged.experimentId,
|
||||||
@@ -147,14 +163,14 @@ export const scenariosRouter = createTRPCRouter({
|
|||||||
);
|
);
|
||||||
}),
|
}),
|
||||||
|
|
||||||
replaceWithValues: publicProcedure
|
replaceWithValues: protectedProcedure
|
||||||
.input(
|
.input(
|
||||||
z.object({
|
z.object({
|
||||||
id: z.string(),
|
id: z.string(),
|
||||||
values: z.record(z.string()),
|
values: z.record(z.string()),
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
.mutation(async ({ input }) => {
|
.mutation(async ({ input, ctx }) => {
|
||||||
const existing = await prisma.testScenario.findUnique({
|
const existing = await prisma.testScenario.findUnique({
|
||||||
where: {
|
where: {
|
||||||
id: input.id,
|
id: input.id,
|
||||||
@@ -165,6 +181,8 @@ export const scenariosRouter = createTRPCRouter({
|
|||||||
throw new Error(`Scenario with id ${input.id} does not exist`);
|
throw new Error(`Scenario with id ${input.id} does not exist`);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
await requireCanModifyExperiment(existing.experimentId, ctx);
|
||||||
|
|
||||||
const newScenario = await prisma.testScenario.create({
|
const newScenario = await prisma.testScenario.create({
|
||||||
data: {
|
data: {
|
||||||
experimentId: existing.experimentId,
|
experimentId: existing.experimentId,
|
||||||
|
|||||||
@@ -1,11 +1,14 @@
|
|||||||
import { z } from "zod";
|
import { z } from "zod";
|
||||||
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
|
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
|
||||||
import { prisma } from "~/server/db";
|
import { prisma } from "~/server/db";
|
||||||
|
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
|
||||||
|
|
||||||
export const templateVarsRouter = createTRPCRouter({
|
export const templateVarsRouter = createTRPCRouter({
|
||||||
create: publicProcedure
|
create: protectedProcedure
|
||||||
.input(z.object({ experimentId: z.string(), label: z.string() }))
|
.input(z.object({ experimentId: z.string(), label: z.string() }))
|
||||||
.mutation(async ({ input }) => {
|
.mutation(async ({ input, ctx }) => {
|
||||||
|
await requireCanModifyExperiment(input.experimentId, ctx);
|
||||||
|
|
||||||
await prisma.templateVariable.create({
|
await prisma.templateVariable.create({
|
||||||
data: {
|
data: {
|
||||||
experimentId: input.experimentId,
|
experimentId: input.experimentId,
|
||||||
@@ -14,11 +17,22 @@ export const templateVarsRouter = createTRPCRouter({
|
|||||||
});
|
});
|
||||||
}),
|
}),
|
||||||
|
|
||||||
delete: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => {
|
delete: protectedProcedure
|
||||||
|
.input(z.object({ id: z.string() }))
|
||||||
|
.mutation(async ({ input, ctx }) => {
|
||||||
|
const { experimentId } = await prisma.templateVariable.findUniqueOrThrow({
|
||||||
|
where: { id: input.id },
|
||||||
|
});
|
||||||
|
|
||||||
|
await requireCanModifyExperiment(experimentId, ctx);
|
||||||
|
|
||||||
await prisma.templateVariable.delete({ where: { id: input.id } });
|
await prisma.templateVariable.delete({ where: { id: input.id } });
|
||||||
}),
|
}),
|
||||||
|
|
||||||
list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
|
list: publicProcedure
|
||||||
|
.input(z.object({ experimentId: z.string() }))
|
||||||
|
.query(async ({ input, ctx }) => {
|
||||||
|
await requireCanViewExperiment(input.experimentId, ctx);
|
||||||
return await prisma.templateVariable.findMany({
|
return await prisma.templateVariable.findMany({
|
||||||
where: {
|
where: {
|
||||||
experimentId: input.experimentId,
|
experimentId: input.experimentId,
|
||||||
|
|||||||
@@ -27,6 +27,9 @@ type CreateContextOptions = {
|
|||||||
session: Session | null;
|
session: Session | null;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// eslint-disable-next-line @typescript-eslint/no-empty-function
|
||||||
|
const noOp = () => {};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This helper generates the "internals" for a tRPC context. If you need to use it, you can export
|
* This helper generates the "internals" for a tRPC context. If you need to use it, you can export
|
||||||
* it from here.
|
* it from here.
|
||||||
@@ -41,6 +44,7 @@ const createInnerTRPCContext = (opts: CreateContextOptions) => {
|
|||||||
return {
|
return {
|
||||||
session: opts.session,
|
session: opts.session,
|
||||||
prisma,
|
prisma,
|
||||||
|
markAccessControlRun: noOp,
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -69,6 +73,8 @@ export const createTRPCContext = async (opts: CreateNextContextOptions) => {
|
|||||||
* errors on the backend.
|
* errors on the backend.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
export type TRPCContext = Awaited<ReturnType<typeof createTRPCContext>>;
|
||||||
|
|
||||||
const t = initTRPC.context<typeof createTRPCContext>().create({
|
const t = initTRPC.context<typeof createTRPCContext>().create({
|
||||||
transformer: superjson,
|
transformer: superjson,
|
||||||
errorFormatter({ shape, error }) {
|
errorFormatter({ shape, error }) {
|
||||||
@@ -106,16 +112,29 @@ export const createTRPCRouter = t.router;
|
|||||||
export const publicProcedure = t.procedure;
|
export const publicProcedure = t.procedure;
|
||||||
|
|
||||||
/** Reusable middleware that enforces users are logged in before running the procedure. */
|
/** Reusable middleware that enforces users are logged in before running the procedure. */
|
||||||
const enforceUserIsAuthed = t.middleware(({ ctx, next }) => {
|
const enforceUserIsAuthed = t.middleware(async ({ ctx, next }) => {
|
||||||
if (!ctx.session || !ctx.session.user) {
|
if (!ctx.session || !ctx.session.user) {
|
||||||
throw new TRPCError({ code: "UNAUTHORIZED" });
|
throw new TRPCError({ code: "UNAUTHORIZED" });
|
||||||
}
|
}
|
||||||
return next({
|
|
||||||
|
let accessControlRun = false;
|
||||||
|
const resp = await next({
|
||||||
ctx: {
|
ctx: {
|
||||||
// infers the `session` as non-nullable
|
// infers the `session` as non-nullable
|
||||||
session: { ...ctx.session, user: ctx.session.user },
|
session: { ...ctx.session, user: ctx.session.user },
|
||||||
|
markAccessControlRun: () => {
|
||||||
|
accessControlRun = true;
|
||||||
|
},
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
if (!accessControlRun)
|
||||||
|
throw new TRPCError({
|
||||||
|
code: "INTERNAL_SERVER_ERROR",
|
||||||
|
message:
|
||||||
|
"Protected routes must perform access control checks then explicitly invoke the `ctx.markAccessControlRun()` function to ensure we don't forget access control on a route.",
|
||||||
|
});
|
||||||
|
|
||||||
|
return resp;
|
||||||
});
|
});
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@@ -2,6 +2,8 @@ import { PrismaAdapter } from "@next-auth/prisma-adapter";
|
|||||||
import { type GetServerSidePropsContext } from "next";
|
import { type GetServerSidePropsContext } from "next";
|
||||||
import { getServerSession, type NextAuthOptions, type DefaultSession } from "next-auth";
|
import { getServerSession, type NextAuthOptions, type DefaultSession } from "next-auth";
|
||||||
import { prisma } from "~/server/db";
|
import { prisma } from "~/server/db";
|
||||||
|
import GitHubProvider from "next-auth/providers/github";
|
||||||
|
import { env } from "~/env.mjs";
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Module augmentation for `next-auth` types. Allows us to add custom properties to the `session`
|
* Module augmentation for `next-auth` types. Allows us to add custom properties to the `session`
|
||||||
@@ -41,20 +43,15 @@ export const authOptions: NextAuthOptions = {
|
|||||||
},
|
},
|
||||||
adapter: PrismaAdapter(prisma),
|
adapter: PrismaAdapter(prisma),
|
||||||
providers: [
|
providers: [
|
||||||
// DiscordProvider({
|
GitHubProvider({
|
||||||
// clientId: env.DISCORD_CLIENT_ID,
|
clientId: env.GITHUB_CLIENT_ID,
|
||||||
// clientSecret: env.DISCORD_CLIENT_SECRET,
|
clientSecret: env.GITHUB_CLIENT_SECRET,
|
||||||
// }),
|
}),
|
||||||
/**
|
|
||||||
* ...add more providers here.
|
|
||||||
*
|
|
||||||
* Most other providers require a bit more work than the Discord provider. For example, the
|
|
||||||
* GitHub provider requires you to add the `refresh_token_expires_in` field to the Account
|
|
||||||
* model. Refer to the NextAuth.js docs for the provider you want to use. Example:
|
|
||||||
*
|
|
||||||
* @see https://next-auth.js.org/providers/github
|
|
||||||
*/
|
|
||||||
],
|
],
|
||||||
|
theme: {
|
||||||
|
logo: "/logo.svg",
|
||||||
|
brandColor: "#ff5733",
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@@ -8,7 +8,10 @@ const globalForPrisma = globalThis as unknown as {
|
|||||||
export const prisma =
|
export const prisma =
|
||||||
globalForPrisma.prisma ??
|
globalForPrisma.prisma ??
|
||||||
new PrismaClient({
|
new PrismaClient({
|
||||||
log: env.NODE_ENV === "development" ? ["query", "error", "warn"] : ["error"],
|
log:
|
||||||
|
env.NODE_ENV === "development" && !env.RESTRICT_PRISMA_LOGS
|
||||||
|
? ["query", "error", "warn"]
|
||||||
|
: ["error"],
|
||||||
});
|
});
|
||||||
|
|
||||||
if (env.NODE_ENV !== "production") globalForPrisma.prisma = prisma;
|
if (env.NODE_ENV !== "production") globalForPrisma.prisma = prisma;
|
||||||
|
|||||||
45
src/server/scripts/migrateConstructFns.test.ts
Normal file
45
src/server/scripts/migrateConstructFns.test.ts
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
import "dotenv/config";
|
||||||
|
import dedent from "dedent";
|
||||||
|
import { expect, test } from "vitest";
|
||||||
|
import { migrate1to2 } from "./migrateConstructFns";
|
||||||
|
|
||||||
|
test("migrate1to2", () => {
|
||||||
|
const constructFn = dedent`
|
||||||
|
// Test comment
|
||||||
|
|
||||||
|
prompt = {
|
||||||
|
model: "gpt-3.5-turbo-0613",
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
role: "user",
|
||||||
|
content: "What is the capital of China?"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
`;
|
||||||
|
|
||||||
|
const migrated = migrate1to2(constructFn);
|
||||||
|
expect(migrated).toBe(dedent`
|
||||||
|
// Test comment
|
||||||
|
|
||||||
|
definePrompt("openai/ChatCompletion", {
|
||||||
|
model: "gpt-3.5-turbo-0613",
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
role: "user",
|
||||||
|
content: "What is the capital of China?"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
})
|
||||||
|
`);
|
||||||
|
|
||||||
|
// console.log(
|
||||||
|
// migrateConstructFn(dedent`definePrompt(
|
||||||
|
// "openai/ChatCompletion",
|
||||||
|
// {
|
||||||
|
// model: 'gpt-3.5-turbo-0613',
|
||||||
|
// messages: []
|
||||||
|
// }
|
||||||
|
// )`),
|
||||||
|
// );
|
||||||
|
});
|
||||||
58
src/server/scripts/migrateConstructFns.ts
Normal file
58
src/server/scripts/migrateConstructFns.ts
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
import * as recast from "recast";
|
||||||
|
import { type ASTNode } from "ast-types";
|
||||||
|
import { prisma } from "../db";
|
||||||
|
import { fileURLToPath } from "url";
|
||||||
|
const { builders: b } = recast.types;
|
||||||
|
|
||||||
|
export const migrate1to2 = (fnBody: string): string => {
|
||||||
|
const ast: ASTNode = recast.parse(fnBody);
|
||||||
|
|
||||||
|
recast.visit(ast, {
|
||||||
|
visitAssignmentExpression(path) {
|
||||||
|
const node = path.node;
|
||||||
|
if ("name" in node.left && node.left.name === "prompt") {
|
||||||
|
const functionCall = b.callExpression(b.identifier("definePrompt"), [
|
||||||
|
b.literal("openai/ChatCompletion"),
|
||||||
|
node.right,
|
||||||
|
]);
|
||||||
|
path.replace(functionCall);
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
return recast.print(ast).code;
|
||||||
|
};
|
||||||
|
|
||||||
|
export default async function migrateConstructFns() {
|
||||||
|
const v1Prompts = await prisma.promptVariant.findMany({
|
||||||
|
where: {
|
||||||
|
constructFnVersion: 1,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
console.log(`Migrating ${v1Prompts.length} prompts 1->2`);
|
||||||
|
await Promise.all(
|
||||||
|
v1Prompts.map(async (variant) => {
|
||||||
|
try {
|
||||||
|
await prisma.promptVariant.update({
|
||||||
|
where: {
|
||||||
|
id: variant.id,
|
||||||
|
},
|
||||||
|
data: {
|
||||||
|
constructFn: migrate1to2(variant.constructFn),
|
||||||
|
constructFnVersion: 2,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
} catch (e) {
|
||||||
|
console.error("Error migrating constructFn for variant", variant.id, e);
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we're running this file directly, run the migration
|
||||||
|
if (process.argv.at(-1) === fileURLToPath(import.meta.url)) {
|
||||||
|
console.log("Running migration");
|
||||||
|
await migrateConstructFns();
|
||||||
|
console.log("Done");
|
||||||
|
}
|
||||||
@@ -1,47 +0,0 @@
|
|||||||
import { type Prisma } from "@prisma/client";
|
|
||||||
import { prisma } from "../db";
|
|
||||||
|
|
||||||
async function migrateScenarioVariantOutputData() {
|
|
||||||
// Get all ScenarioVariantCells
|
|
||||||
const cells = await prisma.scenarioVariantCell.findMany({ include: { modelOutput: true } });
|
|
||||||
console.log(`Found ${cells.length} records`);
|
|
||||||
|
|
||||||
let updatedCount = 0;
|
|
||||||
|
|
||||||
// Loop through all scenarioVariants
|
|
||||||
for (const cell of cells) {
|
|
||||||
// Create a new ModelOutput for each ScenarioVariant with an existing output
|
|
||||||
if (cell.output && !cell.modelOutput) {
|
|
||||||
updatedCount++;
|
|
||||||
await prisma.modelOutput.create({
|
|
||||||
data: {
|
|
||||||
scenarioVariantCellId: cell.id,
|
|
||||||
inputHash: cell.inputHash || "",
|
|
||||||
output: cell.output as Prisma.InputJsonValue,
|
|
||||||
timeToComplete: cell.timeToComplete ?? undefined,
|
|
||||||
promptTokens: cell.promptTokens,
|
|
||||||
completionTokens: cell.completionTokens,
|
|
||||||
createdAt: cell.createdAt,
|
|
||||||
updatedAt: cell.updatedAt,
|
|
||||||
},
|
|
||||||
});
|
|
||||||
} else if (cell.errorMessage && cell.retrievalStatus === "COMPLETE") {
|
|
||||||
updatedCount++;
|
|
||||||
await prisma.scenarioVariantCell.update({
|
|
||||||
where: { id: cell.id },
|
|
||||||
data: {
|
|
||||||
retrievalStatus: "ERROR",
|
|
||||||
},
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
console.log("Data migration completed");
|
|
||||||
console.log(`Updated ${updatedCount} records`);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Execute the function
|
|
||||||
migrateScenarioVariantOutputData().catch((error) => {
|
|
||||||
console.error("An error occurred while migrating data: ", error);
|
|
||||||
process.exit(1);
|
|
||||||
});
|
|
||||||
19
src/server/scripts/openai-test.ts
Normal file
19
src/server/scripts/openai-test.ts
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
import "dotenv/config";
|
||||||
|
import { openai } from "../utils/openai";
|
||||||
|
|
||||||
|
const resp = await openai.chat.completions.create({
|
||||||
|
model: "gpt-3.5-turbo-0613",
|
||||||
|
stream: true,
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
role: "user",
|
||||||
|
content: "count to 20",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
});
|
||||||
|
|
||||||
|
for await (const part of resp) {
|
||||||
|
console.log("part", part);
|
||||||
|
}
|
||||||
|
|
||||||
|
console.log("final resp", resp);
|
||||||
26
src/server/scripts/replicate-test.ts
Normal file
26
src/server/scripts/replicate-test.ts
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
/* eslint-disable */
|
||||||
|
|
||||||
|
import "dotenv/config";
|
||||||
|
import Replicate from "replicate";
|
||||||
|
|
||||||
|
const replicate = new Replicate({
|
||||||
|
auth: process.env.REPLICATE_API_TOKEN || "",
|
||||||
|
});
|
||||||
|
|
||||||
|
console.log("going to run");
|
||||||
|
const prediction = await replicate.predictions.create({
|
||||||
|
version: "3725a659b5afff1a0ba9bead5fac3899d998feaad00e07032ca2b0e35eb14f8a",
|
||||||
|
input: {
|
||||||
|
prompt: "...",
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
console.log("waiting");
|
||||||
|
setInterval(() => {
|
||||||
|
replicate.predictions.get(prediction.id).then((prediction) => {
|
||||||
|
console.log(prediction);
|
||||||
|
});
|
||||||
|
}, 500);
|
||||||
|
// const output = await replicate.wait(prediction, {});
|
||||||
|
|
||||||
|
// console.log(output);
|
||||||
@@ -1,15 +1,18 @@
|
|||||||
import crypto from "crypto";
|
|
||||||
import { prisma } from "~/server/db";
|
import { prisma } from "~/server/db";
|
||||||
import defineTask from "./defineTask";
|
import defineTask from "./defineTask";
|
||||||
import { type CompletionResponse, getCompletion } from "../utils/getCompletion";
|
|
||||||
import { type JSONSerializable } from "../types";
|
|
||||||
import { sleep } from "../utils/sleep";
|
import { sleep } from "../utils/sleep";
|
||||||
import { shouldStream } from "../utils/shouldStream";
|
|
||||||
import { generateChannel } from "~/utils/generateChannel";
|
import { generateChannel } from "~/utils/generateChannel";
|
||||||
import { runEvalsForOutput } from "../utils/evaluations";
|
import { runEvalsForOutput } from "../utils/evaluations";
|
||||||
import { constructPrompt } from "../utils/constructPrompt";
|
|
||||||
import { type CompletionCreateParams } from "openai/resources/chat";
|
|
||||||
import { type Prisma } from "@prisma/client";
|
import { type Prisma } from "@prisma/client";
|
||||||
|
import parseConstructFn from "../utils/parseConstructFn";
|
||||||
|
import hashPrompt from "../utils/hashPrompt";
|
||||||
|
import { type JsonObject } from "type-fest";
|
||||||
|
import modelProviders from "~/modelProviders/modelProviders";
|
||||||
|
import { wsConnection } from "~/utils/wsConnection";
|
||||||
|
|
||||||
|
export type queryLLMJob = {
|
||||||
|
scenarioVariantCellId: string;
|
||||||
|
};
|
||||||
|
|
||||||
const MAX_AUTO_RETRIES = 10;
|
const MAX_AUTO_RETRIES = 10;
|
||||||
const MIN_DELAY = 500; // milliseconds
|
const MIN_DELAY = 500; // milliseconds
|
||||||
@@ -21,45 +24,6 @@ function calculateDelay(numPreviousTries: number): number {
|
|||||||
return baseDelay + jitter;
|
return baseDelay + jitter;
|
||||||
}
|
}
|
||||||
|
|
||||||
const getCompletionWithRetries = async (
|
|
||||||
cellId: string,
|
|
||||||
payload: JSONSerializable,
|
|
||||||
channel?: string,
|
|
||||||
): Promise<CompletionResponse> => {
|
|
||||||
let modelResponse: CompletionResponse | null = null;
|
|
||||||
try {
|
|
||||||
for (let i = 0; i < MAX_AUTO_RETRIES; i++) {
|
|
||||||
modelResponse = await getCompletion(payload as unknown as CompletionCreateParams, channel);
|
|
||||||
if (modelResponse.statusCode !== 429 || i === MAX_AUTO_RETRIES - 1) {
|
|
||||||
return modelResponse;
|
|
||||||
}
|
|
||||||
const delay = calculateDelay(i);
|
|
||||||
await prisma.scenarioVariantCell.update({
|
|
||||||
where: { id: cellId },
|
|
||||||
data: {
|
|
||||||
errorMessage: "Rate limit exceeded",
|
|
||||||
statusCode: 429,
|
|
||||||
retryTime: new Date(Date.now() + delay),
|
|
||||||
},
|
|
||||||
});
|
|
||||||
// TODO: Maybe requeue the job so other jobs can run in the future?
|
|
||||||
await sleep(delay);
|
|
||||||
}
|
|
||||||
throw new Error("Max retries limit reached");
|
|
||||||
} catch (error: unknown) {
|
|
||||||
return {
|
|
||||||
statusCode: modelResponse?.statusCode ?? 500,
|
|
||||||
errorMessage: modelResponse?.errorMessage ?? (error as Error).message,
|
|
||||||
output: null as unknown as Prisma.InputJsonValue,
|
|
||||||
timeToComplete: 0,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
export type queryLLMJob = {
|
|
||||||
scenarioVariantCellId: string;
|
|
||||||
};
|
|
||||||
|
|
||||||
export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
|
export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
|
||||||
const { scenarioVariantCellId } = task;
|
const { scenarioVariantCellId } = task;
|
||||||
const cell = await prisma.scenarioVariantCell.findUnique({
|
const cell = await prisma.scenarioVariantCell.findUnique({
|
||||||
@@ -67,6 +31,14 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
|
|||||||
include: { modelOutput: true },
|
include: { modelOutput: true },
|
||||||
});
|
});
|
||||||
if (!cell) {
|
if (!cell) {
|
||||||
|
await prisma.scenarioVariantCell.update({
|
||||||
|
where: { id: scenarioVariantCellId },
|
||||||
|
data: {
|
||||||
|
statusCode: 404,
|
||||||
|
errorMessage: "Cell not found",
|
||||||
|
retrievalStatus: "ERROR",
|
||||||
|
},
|
||||||
|
});
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -85,6 +57,14 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
|
|||||||
where: { id: cell.promptVariantId },
|
where: { id: cell.promptVariantId },
|
||||||
});
|
});
|
||||||
if (!variant) {
|
if (!variant) {
|
||||||
|
await prisma.scenarioVariantCell.update({
|
||||||
|
where: { id: scenarioVariantCellId },
|
||||||
|
data: {
|
||||||
|
statusCode: 404,
|
||||||
|
errorMessage: "Prompt Variant not found",
|
||||||
|
retrievalStatus: "ERROR",
|
||||||
|
},
|
||||||
|
});
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -92,63 +72,97 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
|
|||||||
where: { id: cell.testScenarioId },
|
where: { id: cell.testScenarioId },
|
||||||
});
|
});
|
||||||
if (!scenario) {
|
if (!scenario) {
|
||||||
|
await prisma.scenarioVariantCell.update({
|
||||||
|
where: { id: scenarioVariantCellId },
|
||||||
|
data: {
|
||||||
|
statusCode: 404,
|
||||||
|
errorMessage: "Scenario not found",
|
||||||
|
retrievalStatus: "ERROR",
|
||||||
|
},
|
||||||
|
});
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const prompt = await constructPrompt(variant, scenario.variableValues);
|
const prompt = await parseConstructFn(variant.constructFn, scenario.variableValues as JsonObject);
|
||||||
|
|
||||||
const streamingEnabled = shouldStream(prompt);
|
if ("error" in prompt) {
|
||||||
let streamingChannel;
|
await prisma.scenarioVariantCell.update({
|
||||||
|
where: { id: scenarioVariantCellId },
|
||||||
|
data: {
|
||||||
|
statusCode: 400,
|
||||||
|
errorMessage: prompt.error,
|
||||||
|
retrievalStatus: "ERROR",
|
||||||
|
},
|
||||||
|
});
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
if (streamingEnabled) {
|
const provider = modelProviders[prompt.modelProvider];
|
||||||
streamingChannel = generateChannel();
|
|
||||||
|
// @ts-expect-error TODO FIX ASAP
|
||||||
|
const streamingChannel = provider.shouldStream(prompt.modelInput) ? generateChannel() : null;
|
||||||
|
|
||||||
|
if (streamingChannel) {
|
||||||
// Save streaming channel so that UI can connect to it
|
// Save streaming channel so that UI can connect to it
|
||||||
await prisma.scenarioVariantCell.update({
|
await prisma.scenarioVariantCell.update({
|
||||||
where: { id: scenarioVariantCellId },
|
where: { id: scenarioVariantCellId },
|
||||||
data: {
|
data: { streamingChannel },
|
||||||
streamingChannel,
|
|
||||||
},
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
const onStream = streamingChannel
|
||||||
|
? (partialOutput: (typeof provider)["_outputSchema"]) => {
|
||||||
|
wsConnection.emit("message", { channel: streamingChannel, payload: partialOutput });
|
||||||
|
}
|
||||||
|
: null;
|
||||||
|
|
||||||
const modelResponse = await getCompletionWithRetries(
|
for (let i = 0; true; i++) {
|
||||||
scenarioVariantCellId,
|
// @ts-expect-error TODO FIX ASAP
|
||||||
prompt,
|
|
||||||
streamingChannel,
|
|
||||||
);
|
|
||||||
|
|
||||||
let modelOutput = null;
|
const response = await provider.getCompletion(prompt.modelInput, onStream);
|
||||||
if (modelResponse.statusCode === 200) {
|
if (response.type === "success") {
|
||||||
const inputHash = crypto.createHash("sha256").update(JSON.stringify(prompt)).digest("hex");
|
const inputHash = hashPrompt(prompt);
|
||||||
|
|
||||||
modelOutput = await prisma.modelOutput.create({
|
const modelOutput = await prisma.modelOutput.create({
|
||||||
data: {
|
data: {
|
||||||
scenarioVariantCellId,
|
scenarioVariantCellId,
|
||||||
inputHash,
|
inputHash,
|
||||||
output: modelResponse.output,
|
output: response.value as unknown as Prisma.InputJsonObject,
|
||||||
timeToComplete: modelResponse.timeToComplete,
|
timeToComplete: response.timeToComplete,
|
||||||
promptTokens: modelResponse.promptTokens,
|
promptTokens: response.promptTokens,
|
||||||
completionTokens: modelResponse.completionTokens,
|
completionTokens: response.completionTokens,
|
||||||
|
cost: response.cost,
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
}
|
|
||||||
|
|
||||||
await prisma.scenarioVariantCell.update({
|
await prisma.scenarioVariantCell.update({
|
||||||
where: { id: scenarioVariantCellId },
|
where: { id: scenarioVariantCellId },
|
||||||
data: {
|
data: {
|
||||||
statusCode: modelResponse.statusCode,
|
statusCode: response.statusCode,
|
||||||
errorMessage: modelResponse.errorMessage,
|
retrievalStatus: "COMPLETE",
|
||||||
streamingChannel: null,
|
|
||||||
retrievalStatus: modelOutput ? "COMPLETE" : "ERROR",
|
|
||||||
modelOutput: {
|
|
||||||
connect: {
|
|
||||||
id: modelOutput?.id,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
if (modelOutput) {
|
|
||||||
await runEvalsForOutput(variant.experimentId, scenario, modelOutput);
|
await runEvalsForOutput(variant.experimentId, scenario, modelOutput);
|
||||||
|
break;
|
||||||
|
} else {
|
||||||
|
const shouldRetry = response.autoRetry && i < MAX_AUTO_RETRIES;
|
||||||
|
const delay = calculateDelay(i);
|
||||||
|
|
||||||
|
await prisma.scenarioVariantCell.update({
|
||||||
|
where: { id: scenarioVariantCellId },
|
||||||
|
data: {
|
||||||
|
errorMessage: response.message,
|
||||||
|
statusCode: response.statusCode,
|
||||||
|
retryTime: shouldRetry ? new Date(Date.now() + delay) : null,
|
||||||
|
retrievalStatus: shouldRetry ? "PENDING" : "ERROR",
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
if (shouldRetry) {
|
||||||
|
await sleep(delay);
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -1,14 +1,3 @@
|
|||||||
export type JSONSerializable =
|
|
||||||
| string
|
|
||||||
| number
|
|
||||||
| boolean
|
|
||||||
| null
|
|
||||||
| JSONSerializable[]
|
|
||||||
| { [key: string]: JSONSerializable };
|
|
||||||
|
|
||||||
// Placeholder for now
|
|
||||||
export type OpenAIChatConfig = NonNullable<JSONSerializable>;
|
|
||||||
|
|
||||||
export enum OpenAIChatModel {
|
export enum OpenAIChatModel {
|
||||||
"gpt-4" = "gpt-4",
|
"gpt-4" = "gpt-4",
|
||||||
"gpt-4-0613" = "gpt-4-0613",
|
"gpt-4-0613" = "gpt-4-0613",
|
||||||
|
|||||||
@@ -1,15 +0,0 @@
|
|||||||
import { test } from "vitest";
|
|
||||||
import { constructPrompt } from "./constructPrompt";
|
|
||||||
|
|
||||||
test.skip("constructPrompt", async () => {
|
|
||||||
const constructed = await constructPrompt(
|
|
||||||
{
|
|
||||||
constructFn: `prompt = { "fooz": "bar" }`,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
foo: "bar",
|
|
||||||
},
|
|
||||||
);
|
|
||||||
|
|
||||||
console.log(constructed);
|
|
||||||
});
|
|
||||||
@@ -1,35 +0,0 @@
|
|||||||
import { type PromptVariant, type TestScenario } from "@prisma/client";
|
|
||||||
import ivm from "isolated-vm";
|
|
||||||
import { type JSONSerializable } from "../types";
|
|
||||||
|
|
||||||
const isolate = new ivm.Isolate({ memoryLimit: 128 });
|
|
||||||
|
|
||||||
export async function constructPrompt(
|
|
||||||
variant: Pick<PromptVariant, "constructFn">,
|
|
||||||
scenario: TestScenario["variableValues"],
|
|
||||||
): Promise<JSONSerializable> {
|
|
||||||
const code = `
|
|
||||||
const scenario = ${JSON.stringify(scenario ?? {}, null, 2)};
|
|
||||||
let prompt
|
|
||||||
|
|
||||||
${variant.constructFn}
|
|
||||||
|
|
||||||
global.prompt = prompt;
|
|
||||||
`;
|
|
||||||
|
|
||||||
console.log("code is", code);
|
|
||||||
|
|
||||||
const context = await isolate.createContext();
|
|
||||||
|
|
||||||
const jail = context.global;
|
|
||||||
await jail.set("global", jail.derefInto());
|
|
||||||
|
|
||||||
const script = await isolate.compileScript(code);
|
|
||||||
|
|
||||||
await script.run(context);
|
|
||||||
const promptReference = (await context.global.get("prompt")) as ivm.Reference;
|
|
||||||
|
|
||||||
const prompt = await promptReference.copy(); // Get the actual value from the isolate
|
|
||||||
|
|
||||||
return prompt as JSONSerializable;
|
|
||||||
}
|
|
||||||
123
src/server/utils/deriveNewContructFn.ts
Normal file
123
src/server/utils/deriveNewContructFn.ts
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
import { type PromptVariant } from "@prisma/client";
|
||||||
|
import { type SupportedModel } from "../types";
|
||||||
|
import ivm from "isolated-vm";
|
||||||
|
import dedent from "dedent";
|
||||||
|
import { openai } from "./openai";
|
||||||
|
import { getApiShapeForModel } from "./getTypesForModel";
|
||||||
|
import { isObject } from "lodash-es";
|
||||||
|
import { type CompletionCreateParams } from "openai/resources/chat/completions";
|
||||||
|
import formatPromptConstructor from "~/utils/formatPromptConstructor";
|
||||||
|
|
||||||
|
const isolate = new ivm.Isolate({ memoryLimit: 128 });
|
||||||
|
|
||||||
|
export async function deriveNewConstructFn(
|
||||||
|
originalVariant: PromptVariant | null,
|
||||||
|
newModel?: SupportedModel,
|
||||||
|
instructions?: string,
|
||||||
|
) {
|
||||||
|
if (originalVariant && !newModel && !instructions) {
|
||||||
|
return originalVariant.constructFn;
|
||||||
|
}
|
||||||
|
if (originalVariant && (newModel || instructions)) {
|
||||||
|
return await requestUpdatedPromptFunction(originalVariant, newModel, instructions);
|
||||||
|
}
|
||||||
|
return dedent`
|
||||||
|
prompt = {
|
||||||
|
model: "gpt-3.5-turbo",
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
role: "system",
|
||||||
|
content: "Return 'Hello, world!'",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`;
|
||||||
|
}
|
||||||
|
|
||||||
|
const NUM_RETRIES = 5;
|
||||||
|
const requestUpdatedPromptFunction = async (
|
||||||
|
originalVariant: PromptVariant,
|
||||||
|
newModel?: SupportedModel,
|
||||||
|
instructions?: string,
|
||||||
|
) => {
|
||||||
|
const originalModel = originalVariant.model as SupportedModel;
|
||||||
|
let newContructionFn = "";
|
||||||
|
for (let i = 0; i < NUM_RETRIES; i++) {
|
||||||
|
try {
|
||||||
|
const messages: CompletionCreateParams.CreateChatCompletionRequestNonStreaming.Message[] = [
|
||||||
|
{
|
||||||
|
role: "system",
|
||||||
|
content: `Your job is to update prompt constructor functions. Here is the api shape for the current model:\n---\n${JSON.stringify(
|
||||||
|
getApiShapeForModel(originalModel),
|
||||||
|
null,
|
||||||
|
2,
|
||||||
|
)}\n\nDo not add any assistant messages.`,
|
||||||
|
},
|
||||||
|
];
|
||||||
|
if (newModel) {
|
||||||
|
messages.push({
|
||||||
|
role: "user",
|
||||||
|
content: `Return the prompt constructor function for ${newModel} given the following prompt constructor function for ${originalModel}:\n---\n${originalVariant.constructFn}`,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
if (instructions) {
|
||||||
|
messages.push({
|
||||||
|
role: "user",
|
||||||
|
content: instructions,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
messages.push({
|
||||||
|
role: "system",
|
||||||
|
content: "The prompt variable has already been declared, so do not declare it again.",
|
||||||
|
});
|
||||||
|
const completion = await openai.chat.completions.create({
|
||||||
|
model: "gpt-4",
|
||||||
|
messages,
|
||||||
|
functions: [
|
||||||
|
{
|
||||||
|
name: "update_prompt_constructor_function",
|
||||||
|
parameters: {
|
||||||
|
type: "object",
|
||||||
|
properties: {
|
||||||
|
new_prompt_function: {
|
||||||
|
type: "string",
|
||||||
|
description: "The new prompt function, runnable in typescript",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
function_call: {
|
||||||
|
name: "update_prompt_constructor_function",
|
||||||
|
},
|
||||||
|
});
|
||||||
|
const argString = completion.choices[0]?.message?.function_call?.arguments || "{}";
|
||||||
|
|
||||||
|
const code = `
|
||||||
|
global.contructPromptFunctionArgs = ${argString};
|
||||||
|
`;
|
||||||
|
|
||||||
|
const context = await isolate.createContext();
|
||||||
|
|
||||||
|
const jail = context.global;
|
||||||
|
await jail.set("global", jail.derefInto());
|
||||||
|
|
||||||
|
const script = await isolate.compileScript(code);
|
||||||
|
|
||||||
|
await script.run(context);
|
||||||
|
const contructPromptFunctionArgs = (await context.global.get(
|
||||||
|
"contructPromptFunctionArgs",
|
||||||
|
)) as ivm.Reference;
|
||||||
|
|
||||||
|
const args = await contructPromptFunctionArgs.copy(); // Get the actual value from the isolate
|
||||||
|
|
||||||
|
if (args && isObject(args) && "new_prompt_function" in args) {
|
||||||
|
newContructionFn = await formatPromptConstructor(args.new_prompt_function as string);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
} catch (e) {
|
||||||
|
console.error(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return newContructionFn;
|
||||||
|
};
|
||||||
@@ -1,5 +1,3 @@
|
|||||||
import { type JSONSerializable } from "../types";
|
|
||||||
|
|
||||||
export type VariableMap = Record<string, string>;
|
export type VariableMap = Record<string, string>;
|
||||||
|
|
||||||
// Escape quotes to match the way we encode JSON
|
// Escape quotes to match the way we encode JSON
|
||||||
@@ -15,24 +13,3 @@ export function escapeRegExp(str: string) {
|
|||||||
export function fillTemplate(template: string, variables: VariableMap): string {
|
export function fillTemplate(template: string, variables: VariableMap): string {
|
||||||
return template.replace(/{{\s*(\w+)\s*}}/g, (_, key: string) => variables[key] || "");
|
return template.replace(/{{\s*(\w+)\s*}}/g, (_, key: string) => variables[key] || "");
|
||||||
}
|
}
|
||||||
|
|
||||||
export function fillTemplateJson<T extends JSONSerializable>(
|
|
||||||
template: T,
|
|
||||||
variables: VariableMap,
|
|
||||||
): T {
|
|
||||||
if (typeof template === "string") {
|
|
||||||
return fillTemplate(template, variables) as T;
|
|
||||||
} else if (Array.isArray(template)) {
|
|
||||||
return template.map((item) => fillTemplateJson(item, variables)) as T;
|
|
||||||
} else if (typeof template === "object" && template !== null) {
|
|
||||||
return Object.keys(template).reduce(
|
|
||||||
(acc, key) => {
|
|
||||||
acc[key] = fillTemplateJson(template[key] as JSONSerializable, variables);
|
|
||||||
return acc;
|
|
||||||
},
|
|
||||||
{} as { [key: string]: JSONSerializable } & T,
|
|
||||||
);
|
|
||||||
} else {
|
|
||||||
return template;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
import crypto from "crypto";
|
|
||||||
import { type Prisma } from "@prisma/client";
|
import { type Prisma } from "@prisma/client";
|
||||||
import { prisma } from "../db";
|
import { prisma } from "../db";
|
||||||
import { queueLLMRetrievalTask } from "./queueLLMRetrievalTask";
|
import { queueLLMRetrievalTask } from "./queueLLMRetrievalTask";
|
||||||
import { constructPrompt } from "./constructPrompt";
|
import parseConstructFn from "./parseConstructFn";
|
||||||
|
import { type JsonObject } from "type-fest";
|
||||||
|
import hashPrompt from "./hashPrompt";
|
||||||
|
|
||||||
export const generateNewCell = async (variantId: string, scenarioId: string) => {
|
export const generateNewCell = async (variantId: string, scenarioId: string) => {
|
||||||
const variant = await prisma.promptVariant.findUnique({
|
const variant = await prisma.promptVariant.findUnique({
|
||||||
@@ -19,10 +20,6 @@ export const generateNewCell = async (variantId: string, scenarioId: string) =>
|
|||||||
|
|
||||||
if (!variant || !scenario) return null;
|
if (!variant || !scenario) return null;
|
||||||
|
|
||||||
const prompt = await constructPrompt(variant, scenario.variableValues);
|
|
||||||
|
|
||||||
const inputHash = crypto.createHash("sha256").update(JSON.stringify(prompt)).digest("hex");
|
|
||||||
|
|
||||||
let cell = await prisma.scenarioVariantCell.findUnique({
|
let cell = await prisma.scenarioVariantCell.findUnique({
|
||||||
where: {
|
where: {
|
||||||
promptVariantId_testScenarioId: {
|
promptVariantId_testScenarioId: {
|
||||||
@@ -37,10 +34,31 @@ export const generateNewCell = async (variantId: string, scenarioId: string) =>
|
|||||||
|
|
||||||
if (cell) return cell;
|
if (cell) return cell;
|
||||||
|
|
||||||
|
const parsedConstructFn = await parseConstructFn(
|
||||||
|
variant.constructFn,
|
||||||
|
scenario.variableValues as JsonObject,
|
||||||
|
);
|
||||||
|
|
||||||
|
if ("error" in parsedConstructFn) {
|
||||||
|
return await prisma.scenarioVariantCell.create({
|
||||||
|
data: {
|
||||||
|
promptVariantId: variantId,
|
||||||
|
testScenarioId: scenarioId,
|
||||||
|
statusCode: 400,
|
||||||
|
errorMessage: parsedConstructFn.error,
|
||||||
|
retrievalStatus: "ERROR",
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
const inputHash = hashPrompt(parsedConstructFn);
|
||||||
|
|
||||||
cell = await prisma.scenarioVariantCell.create({
|
cell = await prisma.scenarioVariantCell.create({
|
||||||
data: {
|
data: {
|
||||||
promptVariantId: variantId,
|
promptVariantId: variantId,
|
||||||
testScenarioId: scenarioId,
|
testScenarioId: scenarioId,
|
||||||
|
prompt: parsedConstructFn.modelInput as unknown as Prisma.InputJsonValue,
|
||||||
|
retrievalStatus: "PENDING",
|
||||||
},
|
},
|
||||||
include: {
|
include: {
|
||||||
modelOutput: true,
|
modelOutput: true,
|
||||||
@@ -48,9 +66,7 @@ export const generateNewCell = async (variantId: string, scenarioId: string) =>
|
|||||||
});
|
});
|
||||||
|
|
||||||
const matchingModelOutput = await prisma.modelOutput.findFirst({
|
const matchingModelOutput = await prisma.modelOutput.findFirst({
|
||||||
where: {
|
where: { inputHash },
|
||||||
inputHash,
|
|
||||||
},
|
|
||||||
});
|
});
|
||||||
|
|
||||||
let newModelOutput;
|
let newModelOutput;
|
||||||
@@ -62,12 +78,17 @@ export const generateNewCell = async (variantId: string, scenarioId: string) =>
|
|||||||
inputHash,
|
inputHash,
|
||||||
output: matchingModelOutput.output as Prisma.InputJsonValue,
|
output: matchingModelOutput.output as Prisma.InputJsonValue,
|
||||||
timeToComplete: matchingModelOutput.timeToComplete,
|
timeToComplete: matchingModelOutput.timeToComplete,
|
||||||
|
cost: matchingModelOutput.cost,
|
||||||
promptTokens: matchingModelOutput.promptTokens,
|
promptTokens: matchingModelOutput.promptTokens,
|
||||||
completionTokens: matchingModelOutput.completionTokens,
|
completionTokens: matchingModelOutput.completionTokens,
|
||||||
createdAt: matchingModelOutput.createdAt,
|
createdAt: matchingModelOutput.createdAt,
|
||||||
updatedAt: matchingModelOutput.updatedAt,
|
updatedAt: matchingModelOutput.updatedAt,
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
await prisma.scenarioVariantCell.update({
|
||||||
|
where: { id: cell.id },
|
||||||
|
data: { retrievalStatus: "COMPLETE" },
|
||||||
|
});
|
||||||
} else {
|
} else {
|
||||||
cell = await queueLLMRetrievalTask(cell.id);
|
cell = await queueLLMRetrievalTask(cell.id);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,99 +0,0 @@
|
|||||||
/* eslint-disable @typescript-eslint/no-unsafe-call */
|
|
||||||
import { isObject } from "lodash-es";
|
|
||||||
import { Prisma } from "@prisma/client";
|
|
||||||
import { streamChatCompletion } from "./openai";
|
|
||||||
import { wsConnection } from "~/utils/wsConnection";
|
|
||||||
import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat";
|
|
||||||
import { type OpenAIChatModel } from "../types";
|
|
||||||
import { env } from "~/env.mjs";
|
|
||||||
import { countOpenAIChatTokens } from "~/utils/countTokens";
|
|
||||||
import { rateLimitErrorMessage } from "~/sharedStrings";
|
|
||||||
|
|
||||||
export type CompletionResponse = {
|
|
||||||
output: Prisma.InputJsonValue | typeof Prisma.JsonNull;
|
|
||||||
statusCode: number;
|
|
||||||
errorMessage: string | null;
|
|
||||||
timeToComplete: number;
|
|
||||||
promptTokens?: number;
|
|
||||||
completionTokens?: number;
|
|
||||||
};
|
|
||||||
|
|
||||||
export async function getCompletion(
|
|
||||||
payload: CompletionCreateParams,
|
|
||||||
channel?: string,
|
|
||||||
): Promise<CompletionResponse> {
|
|
||||||
// If functions are enabled, disable streaming so that we get the full response with token counts
|
|
||||||
if (payload.functions?.length) payload.stream = false;
|
|
||||||
const start = Date.now();
|
|
||||||
const response = await fetch("https://api.openai.com/v1/chat/completions", {
|
|
||||||
method: "POST",
|
|
||||||
headers: {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
Authorization: `Bearer ${env.OPENAI_API_KEY}`,
|
|
||||||
},
|
|
||||||
body: JSON.stringify(payload),
|
|
||||||
});
|
|
||||||
|
|
||||||
const resp: CompletionResponse = {
|
|
||||||
output: Prisma.JsonNull,
|
|
||||||
errorMessage: null,
|
|
||||||
statusCode: response.status,
|
|
||||||
timeToComplete: 0,
|
|
||||||
};
|
|
||||||
|
|
||||||
try {
|
|
||||||
if (payload.stream) {
|
|
||||||
const completion = streamChatCompletion(payload as unknown as CompletionCreateParams);
|
|
||||||
let finalOutput: ChatCompletion | null = null;
|
|
||||||
await (async () => {
|
|
||||||
for await (const partialCompletion of completion) {
|
|
||||||
finalOutput = partialCompletion;
|
|
||||||
wsConnection.emit("message", { channel, payload: partialCompletion });
|
|
||||||
}
|
|
||||||
})().catch((err) => console.error(err));
|
|
||||||
if (finalOutput) {
|
|
||||||
resp.output = finalOutput as unknown as Prisma.InputJsonValue;
|
|
||||||
resp.timeToComplete = Date.now() - start;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
resp.timeToComplete = Date.now() - start;
|
|
||||||
resp.output = await response.json();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!response.ok) {
|
|
||||||
if (response.status === 429) {
|
|
||||||
resp.errorMessage = rateLimitErrorMessage;
|
|
||||||
} else if (
|
|
||||||
isObject(resp.output) &&
|
|
||||||
"error" in resp.output &&
|
|
||||||
isObject(resp.output.error) &&
|
|
||||||
"message" in resp.output.error
|
|
||||||
) {
|
|
||||||
// If it's an object, try to get the error message
|
|
||||||
resp.errorMessage = resp.output.error.message?.toString() ?? "Unknown error";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (isObject(resp.output) && "usage" in resp.output) {
|
|
||||||
const usage = resp.output.usage as unknown as ChatCompletion.Usage;
|
|
||||||
resp.promptTokens = usage.prompt_tokens;
|
|
||||||
resp.completionTokens = usage.completion_tokens;
|
|
||||||
} else if (isObject(resp.output) && "choices" in resp.output) {
|
|
||||||
const model = payload.model as unknown as OpenAIChatModel;
|
|
||||||
resp.promptTokens = countOpenAIChatTokens(model, payload.messages);
|
|
||||||
const choices = resp.output.choices as unknown as ChatCompletion.Choice[];
|
|
||||||
const message = choices[0]?.message;
|
|
||||||
if (message) {
|
|
||||||
const messages = [message];
|
|
||||||
resp.completionTokens = countOpenAIChatTokens(model, messages);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} catch (e) {
|
|
||||||
console.error(e);
|
|
||||||
if (response.ok) {
|
|
||||||
resp.errorMessage = "Failed to parse response";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return resp;
|
|
||||||
}
|
|
||||||
6
src/server/utils/getTypesForModel.ts
Normal file
6
src/server/utils/getTypesForModel.ts
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
import { type SupportedModel } from "../types";
|
||||||
|
|
||||||
|
export const getApiShapeForModel = (model: SupportedModel) => {
|
||||||
|
// if (model in OpenAIChatModel) return openAIChatApiShape;
|
||||||
|
return "";
|
||||||
|
};
|
||||||
37
src/server/utils/hashPrompt.ts
Normal file
37
src/server/utils/hashPrompt.ts
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
import crypto from "crypto";
|
||||||
|
import { type JsonValue } from "type-fest";
|
||||||
|
import { type ParsedConstructFn } from "./parseConstructFn";
|
||||||
|
|
||||||
|
function sortKeys(obj: JsonValue): JsonValue {
|
||||||
|
if (typeof obj !== "object" || obj === null) {
|
||||||
|
// Not an object or array, return as is
|
||||||
|
return obj;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (Array.isArray(obj)) {
|
||||||
|
return obj.map(sortKeys);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get keys and sort them
|
||||||
|
const keys = Object.keys(obj).sort();
|
||||||
|
const sortedObj = {};
|
||||||
|
|
||||||
|
for (const key of keys) {
|
||||||
|
// @ts-expect-error not worth fixing types
|
||||||
|
// eslint-disable-next-line @typescript-eslint/no-unsafe-argument
|
||||||
|
sortedObj[key] = sortKeys(obj[key]);
|
||||||
|
}
|
||||||
|
|
||||||
|
return sortedObj;
|
||||||
|
}
|
||||||
|
|
||||||
|
export default function hashPrompt(prompt: ParsedConstructFn<any>): string {
|
||||||
|
// Sort object keys recursively
|
||||||
|
const sortedObj = sortKeys(prompt as unknown as JsonValue);
|
||||||
|
|
||||||
|
// Convert to JSON and hash it
|
||||||
|
const str = JSON.stringify(sortedObj);
|
||||||
|
const hash = crypto.createHash("sha256");
|
||||||
|
hash.update(str);
|
||||||
|
return hash.digest("hex");
|
||||||
|
}
|
||||||
@@ -1,64 +1,5 @@
|
|||||||
import { omit } from "lodash-es";
|
|
||||||
import { env } from "~/env.mjs";
|
import { env } from "~/env.mjs";
|
||||||
|
|
||||||
import OpenAI from "openai";
|
import OpenAI from "openai";
|
||||||
import {
|
|
||||||
type ChatCompletion,
|
|
||||||
type ChatCompletionChunk,
|
|
||||||
type CompletionCreateParams,
|
|
||||||
} from "openai/resources/chat";
|
|
||||||
|
|
||||||
export const openai = new OpenAI({ apiKey: env.OPENAI_API_KEY });
|
export const openai = new OpenAI({ apiKey: env.OPENAI_API_KEY });
|
||||||
|
|
||||||
export const mergeStreamedChunks = (
|
|
||||||
base: ChatCompletion | null,
|
|
||||||
chunk: ChatCompletionChunk,
|
|
||||||
): ChatCompletion => {
|
|
||||||
if (base === null) {
|
|
||||||
return mergeStreamedChunks({ ...chunk, choices: [] }, chunk);
|
|
||||||
}
|
|
||||||
|
|
||||||
const choices = [...base.choices];
|
|
||||||
for (const choice of chunk.choices) {
|
|
||||||
const baseChoice = choices.find((c) => c.index === choice.index);
|
|
||||||
if (baseChoice) {
|
|
||||||
baseChoice.finish_reason = choice.finish_reason ?? baseChoice.finish_reason;
|
|
||||||
baseChoice.message = baseChoice.message ?? { role: "assistant" };
|
|
||||||
|
|
||||||
if (choice.delta?.content)
|
|
||||||
baseChoice.message.content =
|
|
||||||
((baseChoice.message.content as string) ?? "") + (choice.delta.content ?? "");
|
|
||||||
if (choice.delta?.function_call) {
|
|
||||||
const fnCall = baseChoice.message.function_call ?? {};
|
|
||||||
fnCall.name =
|
|
||||||
((fnCall.name as string) ?? "") + ((choice.delta.function_call.name as string) ?? "");
|
|
||||||
fnCall.arguments =
|
|
||||||
((fnCall.arguments as string) ?? "") +
|
|
||||||
((choice.delta.function_call.arguments as string) ?? "");
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
choices.push({ ...omit(choice, "delta"), message: { role: "assistant", ...choice.delta } });
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const merged: ChatCompletion = {
|
|
||||||
...base,
|
|
||||||
choices,
|
|
||||||
};
|
|
||||||
|
|
||||||
return merged;
|
|
||||||
};
|
|
||||||
|
|
||||||
export const streamChatCompletion = async function* (body: CompletionCreateParams) {
|
|
||||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-call
|
|
||||||
const resp = await openai.chat.completions.create({
|
|
||||||
...body,
|
|
||||||
stream: true,
|
|
||||||
});
|
|
||||||
|
|
||||||
let mergedChunks: ChatCompletion | null = null;
|
|
||||||
for await (const part of resp) {
|
|
||||||
mergedChunks = mergeStreamedChunks(mergedChunks, part);
|
|
||||||
yield mergedChunks;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|||||||
45
src/server/utils/parseConstructFn.test.ts
Normal file
45
src/server/utils/parseConstructFn.test.ts
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
import { expect, test } from "vitest";
|
||||||
|
import parseConstructFn from "./parseConstructFn";
|
||||||
|
import assert from "assert";
|
||||||
|
|
||||||
|
// Note: this has to be run with `vitest --no-threads` option or else
|
||||||
|
// isolated-vm seems to throw errors
|
||||||
|
test("parseConstructFn", async () => {
|
||||||
|
const constructed = await parseConstructFn(
|
||||||
|
`
|
||||||
|
// These sometimes have a comment
|
||||||
|
|
||||||
|
definePrompt("openai/ChatCompletion", {
|
||||||
|
model: "gpt-3.5-turbo-0613",
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
role: "user",
|
||||||
|
content: \`What is the capital of \${scenario.country}?\`
|
||||||
|
}
|
||||||
|
]
|
||||||
|
})
|
||||||
|
`,
|
||||||
|
{ country: "Bolivia" },
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(constructed).toEqual({
|
||||||
|
modelProvider: "openai/ChatCompletion",
|
||||||
|
model: "gpt-3.5-turbo-0613",
|
||||||
|
modelInput: {
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
content: "What is the capital of Bolivia?",
|
||||||
|
role: "user",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
model: "gpt-3.5-turbo-0613",
|
||||||
|
},
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
test("bad syntax", async () => {
|
||||||
|
const parsed = await parseConstructFn(`definePrompt("openai/ChatCompletion", {`);
|
||||||
|
|
||||||
|
assert("error" in parsed);
|
||||||
|
expect(parsed.error).toContain("Unexpected end of input");
|
||||||
|
});
|
||||||
95
src/server/utils/parseConstructFn.ts
Normal file
95
src/server/utils/parseConstructFn.ts
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
import modelProviders from "~/modelProviders/modelProviders";
|
||||||
|
import ivm from "isolated-vm";
|
||||||
|
import { isObject, isString } from "lodash-es";
|
||||||
|
import { type JsonObject } from "type-fest";
|
||||||
|
import { validate } from "jsonschema";
|
||||||
|
|
||||||
|
export type ParsedConstructFn<T extends keyof typeof modelProviders> = {
|
||||||
|
modelProvider: T;
|
||||||
|
model: keyof (typeof modelProviders)[T]["models"];
|
||||||
|
modelInput: Parameters<(typeof modelProviders)[T]["getModel"]>[0];
|
||||||
|
};
|
||||||
|
|
||||||
|
const isolate = new ivm.Isolate({ memoryLimit: 128 });
|
||||||
|
|
||||||
|
export default async function parseConstructFn(
|
||||||
|
constructFn: string,
|
||||||
|
scenario: JsonObject | undefined = {},
|
||||||
|
): Promise<ParsedConstructFn<keyof typeof modelProviders> | { error: string }> {
|
||||||
|
try {
|
||||||
|
const modifiedConstructFn = constructFn.replace(
|
||||||
|
"definePrompt(",
|
||||||
|
"global.prompt = definePrompt(",
|
||||||
|
);
|
||||||
|
|
||||||
|
const code = `
|
||||||
|
const scenario = ${JSON.stringify(scenario ?? {}, null, 2)};
|
||||||
|
|
||||||
|
const definePrompt = (modelProvider, input) => ({
|
||||||
|
modelProvider,
|
||||||
|
input
|
||||||
|
})
|
||||||
|
|
||||||
|
${modifiedConstructFn}
|
||||||
|
`;
|
||||||
|
|
||||||
|
const context = await isolate.createContext();
|
||||||
|
const jail = context.global;
|
||||||
|
await jail.set("global", jail.derefInto());
|
||||||
|
|
||||||
|
const script = await isolate.compileScript(code);
|
||||||
|
await script.run(context);
|
||||||
|
const promptReference = (await context.global.get("prompt")) as ivm.Reference;
|
||||||
|
const prompt = await promptReference.copy();
|
||||||
|
|
||||||
|
if (!isObject(prompt)) {
|
||||||
|
return { error: "definePrompt did not return an object" };
|
||||||
|
}
|
||||||
|
if (!("modelProvider" in prompt) || !isString(prompt.modelProvider)) {
|
||||||
|
return { error: "definePrompt did not return a valid modelProvider" };
|
||||||
|
}
|
||||||
|
|
||||||
|
const provider =
|
||||||
|
prompt.modelProvider in modelProviders &&
|
||||||
|
modelProviders[prompt.modelProvider as keyof typeof modelProviders];
|
||||||
|
if (!provider) {
|
||||||
|
return { error: "definePrompt did not return a known modelProvider" };
|
||||||
|
}
|
||||||
|
if (!("input" in prompt) || !isObject(prompt.input)) {
|
||||||
|
return { error: "definePrompt did not return an input" };
|
||||||
|
}
|
||||||
|
|
||||||
|
const validationResult = validate(prompt.input, provider.inputSchema);
|
||||||
|
if (!validationResult.valid)
|
||||||
|
return {
|
||||||
|
error: `definePrompt did not return a valid input: ${validationResult.errors
|
||||||
|
.map((e) => e.stack)
|
||||||
|
.join(", ")}`,
|
||||||
|
};
|
||||||
|
|
||||||
|
// We've validated the JSON schema so this should be safe
|
||||||
|
const input = prompt.input as Parameters<(typeof provider)["getModel"]>[0];
|
||||||
|
|
||||||
|
// @ts-expect-error TODO FIX ASAP
|
||||||
|
const model = provider.getModel(input);
|
||||||
|
if (!model) {
|
||||||
|
return {
|
||||||
|
error: `definePrompt did not return a known model for the provider ${prompt.modelProvider}`,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
modelProvider: prompt.modelProvider as keyof typeof modelProviders,
|
||||||
|
// @ts-expect-error TODO FIX ASAP
|
||||||
|
|
||||||
|
model,
|
||||||
|
modelInput: input,
|
||||||
|
};
|
||||||
|
} catch (e) {
|
||||||
|
const msg =
|
||||||
|
isObject(e) && "message" in e && isString(e.message)
|
||||||
|
? e.message
|
||||||
|
: "unknown error parsing definePrompt script";
|
||||||
|
return { error: msg };
|
||||||
|
}
|
||||||
|
}
|
||||||
65
src/server/utils/reorderPromptVariants.ts
Normal file
65
src/server/utils/reorderPromptVariants.ts
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
import { prisma } from "~/server/db";
|
||||||
|
|
||||||
|
export const reorderPromptVariants = async (
|
||||||
|
movedId: string,
|
||||||
|
stationaryTargetId: string,
|
||||||
|
alwaysInsertRight?: boolean,
|
||||||
|
) => {
|
||||||
|
const moved = await prisma.promptVariant.findUnique({
|
||||||
|
where: {
|
||||||
|
id: movedId,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
const target = await prisma.promptVariant.findUnique({
|
||||||
|
where: {
|
||||||
|
id: stationaryTargetId,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!moved || !target || moved.experimentId !== target.experimentId) {
|
||||||
|
throw new Error(`Prompt Variant with id ${movedId} or ${stationaryTargetId} does not exist`);
|
||||||
|
}
|
||||||
|
|
||||||
|
const visibleItems = await prisma.promptVariant.findMany({
|
||||||
|
where: {
|
||||||
|
experimentId: moved.experimentId,
|
||||||
|
visible: true,
|
||||||
|
},
|
||||||
|
orderBy: {
|
||||||
|
sortIndex: "asc",
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
// Remove the moved item from its current position
|
||||||
|
const orderedItems = visibleItems.filter((item) => item.id !== moved.id);
|
||||||
|
|
||||||
|
// Find the index of the moved item and the target item
|
||||||
|
const movedIndex = visibleItems.findIndex((item) => item.id === moved.id);
|
||||||
|
const targetIndex = visibleItems.findIndex((item) => item.id === target.id);
|
||||||
|
|
||||||
|
// Determine the new index for the moved item
|
||||||
|
let newIndex;
|
||||||
|
if (movedIndex < targetIndex || alwaysInsertRight) {
|
||||||
|
newIndex = targetIndex + 1; // Insert after the target item
|
||||||
|
} else {
|
||||||
|
newIndex = targetIndex; // Insert before the target item
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insert the moved item at the new position
|
||||||
|
orderedItems.splice(newIndex, 0, moved);
|
||||||
|
|
||||||
|
// Now, we need to update all the items with their new sortIndex
|
||||||
|
await prisma.$transaction(
|
||||||
|
orderedItems.map((item, index) => {
|
||||||
|
return prisma.promptVariant.update({
|
||||||
|
where: {
|
||||||
|
id: item.id,
|
||||||
|
},
|
||||||
|
data: {
|
||||||
|
sortIndex: index,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
};
|
||||||
@@ -1,7 +0,0 @@
|
|||||||
import { isObject } from "lodash-es";
|
|
||||||
import { type JSONSerializable } from "../types";
|
|
||||||
|
|
||||||
export const shouldStream = (config: JSONSerializable): boolean => {
|
|
||||||
const shouldStream = isObject(config) && "stream" in config && config.stream === true;
|
|
||||||
return shouldStream;
|
|
||||||
};
|
|
||||||
19
src/server/utils/userOrg.ts
Normal file
19
src/server/utils/userOrg.ts
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
import { prisma } from "~/server/db";
|
||||||
|
|
||||||
|
export default async function userOrg(userId: string) {
|
||||||
|
return await prisma.organization.upsert({
|
||||||
|
where: {
|
||||||
|
personalOrgUserId: userId,
|
||||||
|
},
|
||||||
|
update: {},
|
||||||
|
create: {
|
||||||
|
personalOrgUserId: userId,
|
||||||
|
OrganizationUser: {
|
||||||
|
create: {
|
||||||
|
userId: userId,
|
||||||
|
role: "ADMIN",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
@@ -1,7 +1,6 @@
|
|||||||
import { type RouterOutputs } from "~/utils/api";
|
import { type RouterOutputs } from "~/utils/api";
|
||||||
import { type SliceCreator } from "./store";
|
import { type SliceCreator } from "./store";
|
||||||
import loader from "@monaco-editor/loader";
|
import loader from "@monaco-editor/loader";
|
||||||
import openAITypes from "~/codegen/openai.types.ts.txt";
|
|
||||||
import formatPromptConstructor from "~/utils/formatPromptConstructor";
|
import formatPromptConstructor from "~/utils/formatPromptConstructor";
|
||||||
|
|
||||||
export const editorBackground = "#fafafa";
|
export const editorBackground = "#fafafa";
|
||||||
@@ -20,7 +19,10 @@ export const createVariantEditorSlice: SliceCreator<SharedVariantEditorSlice> =
|
|||||||
// We only want to run this client-side
|
// We only want to run this client-side
|
||||||
if (typeof window === "undefined") return;
|
if (typeof window === "undefined") return;
|
||||||
|
|
||||||
const monaco = await loader.init();
|
const [monaco, promptTypes] = await Promise.all([
|
||||||
|
loader.init(),
|
||||||
|
get().api?.client.experiments.promptTypes.query(),
|
||||||
|
]);
|
||||||
|
|
||||||
monaco.editor.defineTheme("customTheme", {
|
monaco.editor.defineTheme("customTheme", {
|
||||||
base: "vs",
|
base: "vs",
|
||||||
@@ -37,14 +39,9 @@ export const createVariantEditorSlice: SliceCreator<SharedVariantEditorSlice> =
|
|||||||
lib: ["esnext"],
|
lib: ["esnext"],
|
||||||
});
|
});
|
||||||
|
|
||||||
monaco.editor.createModel(
|
monaco.languages.typescript.typescriptDefaults.addExtraLib(
|
||||||
`
|
promptTypes ?? "",
|
||||||
${openAITypes}
|
"file:///PromptTypes.d.ts",
|
||||||
|
|
||||||
declare var prompt: components["schemas"]["CreateChatCompletionRequest"];
|
|
||||||
`,
|
|
||||||
"typescript",
|
|
||||||
monaco.Uri.parse("file:///openai.types.ts"),
|
|
||||||
);
|
);
|
||||||
|
|
||||||
monaco.languages.registerDocumentFormattingEditProvider("typescript", {
|
monaco.languages.registerDocumentFormattingEditProvider("typescript", {
|
||||||
@@ -64,7 +61,6 @@ export const createVariantEditorSlice: SliceCreator<SharedVariantEditorSlice> =
|
|||||||
get().sharedVariantEditor.updateScenariosModel();
|
get().sharedVariantEditor.updateScenariosModel();
|
||||||
},
|
},
|
||||||
scenarios: [],
|
scenarios: [],
|
||||||
// scenariosModel: null,
|
|
||||||
setScenarios: (scenarios) => {
|
setScenarios: (scenarios) => {
|
||||||
set((state) => {
|
set((state) => {
|
||||||
state.sharedVariantEditor.scenarios = scenarios;
|
state.sharedVariantEditor.scenarios = scenarios;
|
||||||
|
|||||||
@@ -5,11 +5,14 @@ import {
|
|||||||
type SharedVariantEditorSlice,
|
type SharedVariantEditorSlice,
|
||||||
createVariantEditorSlice,
|
createVariantEditorSlice,
|
||||||
} from "./sharedVariantEditor.slice";
|
} from "./sharedVariantEditor.slice";
|
||||||
|
import { type APIClient } from "~/utils/api";
|
||||||
|
|
||||||
export type State = {
|
export type State = {
|
||||||
drawerOpen: boolean;
|
drawerOpen: boolean;
|
||||||
openDrawer: () => void;
|
openDrawer: () => void;
|
||||||
closeDrawer: () => void;
|
closeDrawer: () => void;
|
||||||
|
api: APIClient | null;
|
||||||
|
setApi: (api: APIClient) => void;
|
||||||
sharedVariantEditor: SharedVariantEditorSlice;
|
sharedVariantEditor: SharedVariantEditorSlice;
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -20,6 +23,12 @@ export type GetFn = Parameters<SliceCreator<unknown>>[1];
|
|||||||
|
|
||||||
const useBaseStore = create<State, [["zustand/immer", never]]>(
|
const useBaseStore = create<State, [["zustand/immer", never]]>(
|
||||||
immer((set, get, ...rest) => ({
|
immer((set, get, ...rest) => ({
|
||||||
|
api: null,
|
||||||
|
setApi: (api) =>
|
||||||
|
set((state) => {
|
||||||
|
state.api = api;
|
||||||
|
}),
|
||||||
|
|
||||||
drawerOpen: false,
|
drawerOpen: false,
|
||||||
openDrawer: () =>
|
openDrawer: () =>
|
||||||
set((state) => {
|
set((state) => {
|
||||||
@@ -34,5 +43,3 @@ const useBaseStore = create<State, [["zustand/immer", never]]>(
|
|||||||
);
|
);
|
||||||
|
|
||||||
export const useAppStore = createSelectors(useBaseStore);
|
export const useAppStore = createSelectors(useBaseStore);
|
||||||
|
|
||||||
useAppStore.getState().sharedVariantEditor.loadMonaco().catch(console.error);
|
|
||||||
|
|||||||
@@ -15,3 +15,15 @@ export function useSyncVariantEditor() {
|
|||||||
}
|
}
|
||||||
}, [scenarios.data]);
|
}, [scenarios.data]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export function SyncAppStore() {
|
||||||
|
const utils = api.useContext();
|
||||||
|
|
||||||
|
const setApi = useAppStore((state) => state.setApi);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
setApi(utils);
|
||||||
|
}, [utils, setApi]);
|
||||||
|
|
||||||
|
return null;
|
||||||
|
}
|
||||||
@@ -1,4 +1,6 @@
|
|||||||
import { extendTheme } from "@chakra-ui/react";
|
import { extendTheme } from "@chakra-ui/react";
|
||||||
|
import "@fontsource/inconsolata";
|
||||||
|
import { ChakraProvider } from "@chakra-ui/react";
|
||||||
|
|
||||||
const systemFont =
|
const systemFont =
|
||||||
'ui-sans-serif, -apple-system, "system-ui", "Segoe UI", Helvetica, "Apple Color Emoji", Arial, sans-serif, "Segoe UI Emoji", "Segoe UI Symbol"';
|
'ui-sans-serif, -apple-system, "system-ui", "Segoe UI", Helvetica, "Apple Color Emoji", Arial, sans-serif, "Segoe UI Emoji", "Segoe UI Symbol"';
|
||||||
@@ -33,4 +35,6 @@ const theme = extendTheme({
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
export default theme;
|
export const ChakraThemeProvider = ({ children }: { children: JSX.Element }) => {
|
||||||
|
return <ChakraProvider theme={theme}>{children}</ChakraProvider>;
|
||||||
|
};
|
||||||
49
src/utils/accessControl.ts
Normal file
49
src/utils/accessControl.ts
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
import { OrganizationUserRole } from "@prisma/client";
|
||||||
|
import { TRPCError } from "@trpc/server";
|
||||||
|
import { type TRPCContext } from "~/server/api/trpc";
|
||||||
|
import { prisma } from "~/server/db";
|
||||||
|
|
||||||
|
// No-op method for protected routes that really should be accessible to anyone.
|
||||||
|
export const requireNothing = (ctx: TRPCContext) => {
|
||||||
|
ctx.markAccessControlRun();
|
||||||
|
};
|
||||||
|
|
||||||
|
export const requireCanViewExperiment = async (experimentId: string, ctx: TRPCContext) => {
|
||||||
|
await prisma.experiment.findFirst({
|
||||||
|
where: { id: experimentId },
|
||||||
|
});
|
||||||
|
|
||||||
|
// Right now all experiments are publicly viewable, so this is a no-op.
|
||||||
|
ctx.markAccessControlRun();
|
||||||
|
};
|
||||||
|
|
||||||
|
export const canModifyExperiment = async (experimentId: string, userId: string) => {
|
||||||
|
const experiment = await prisma.experiment.findFirst({
|
||||||
|
where: {
|
||||||
|
id: experimentId,
|
||||||
|
organization: {
|
||||||
|
OrganizationUser: {
|
||||||
|
some: {
|
||||||
|
role: { in: [OrganizationUserRole.ADMIN, OrganizationUserRole.MEMBER] },
|
||||||
|
userId,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
return !!experiment;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const requireCanModifyExperiment = async (experimentId: string, ctx: TRPCContext) => {
|
||||||
|
const userId = ctx.session?.user.id;
|
||||||
|
if (!userId) {
|
||||||
|
throw new TRPCError({ code: "UNAUTHORIZED" });
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!(await canModifyExperiment(experimentId, userId))) {
|
||||||
|
throw new TRPCError({ code: "UNAUTHORIZED" });
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.markAccessControlRun();
|
||||||
|
};
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user