Compare commits


66 Commits

Author SHA1 Message Date
Kyle Corbitt
61e5f0775d separate scenarios from prompts in outputs table 2023-07-22 07:38:19 -07:00
David Corbitt
7466db63df Make REPLICATE_API_TOKEN optional 2023-07-21 20:23:38 -07:00
David Corbitt
79a0b03bf8 Add another function call example 2023-07-21 20:16:36 -07:00
arcticfly
6fb7a82d72 Add support for switching to Llama models (#80)
* Add support for switching to Llama models

* Fix prettier
2023-07-21 20:10:59 -07:00
Kyle Corbitt
4ea30a3ba3 Merge pull request #79 from OpenPipe/copy-evals
Copy over evals when new cell created
2023-07-21 18:43:44 -07:00
Kyle Corbitt
52d1d5c7ee Copy over evals when new cell created
Fixes a bug where new cells generated as clones of existing cells didn't get the eval results cloned as well.
2023-07-21 18:40:40 -07:00
Kyle Corbitt
46036a44d2 small README update 2023-07-21 14:32:07 -07:00
Kyle Corbitt
3753fe5c16 Merge pull request #78 from OpenPipe/bugfix-max-tokens
Fix typescript hints for max_tokens
2023-07-21 12:10:00 -07:00
Kyle Corbitt
213a00a8e6 Fix typescript hints for max_tokens 2023-07-21 12:04:58 -07:00
Kyle Corbitt
af9943eefc Merge pull request #77 from OpenPipe/provider-types
Slightly better typings for ModelProviders
2023-07-21 11:51:25 -07:00
Kyle Corbitt
741128e0f4 Better division of labor between frontend and backend model providers
A bit more thought about which types go where.
2023-07-21 11:49:35 -07:00
David Corbitt
aff14539d8 Add comment to .env.example 2023-07-21 11:29:21 -07:00
David Corbitt
1af81a50a9 Add REPLICATE_API_TOKEN to .env.example 2023-07-21 11:28:14 -07:00
Kyle Corbitt
7e1fbb3767 Slightly better typings for ModelProviders
Still not great because the `any`s loosen some call sites up more than I'd like, but better than the broken types before.
2023-07-21 06:50:05 -07:00
David Corbitt
a5d972005e Add user's current prompt to prompt derivation 2023-07-21 00:43:39 -07:00
arcticfly
a180b5bef2 Show prompt diff when changing models (#76)
* Make CompareFunctions more configurable

* Change RefinePromptModal styles

* Accept newModel in getModifiedPromptFn

* Show prompt comparison in SelectModelModal

* Pass variant to SelectModelModal

* Update instructions

* Properly use isDisabled
2023-07-20 23:26:49 -07:00
Kyle Corbitt
55c697223e Merge pull request #74 from OpenPipe/model-providers
replicate/llama2 provider
2023-07-20 23:21:42 -07:00
arcticfly
9978075867 Fix auth flicker (#75)
* Remove experiments flicker for unauthenticated users

* Decrease size of NewScenarioButton spinner
2023-07-20 20:46:31 -07:00
Kyle Corbitt
847753c32b replicate/llama2 provider
Still need to fix the types but it runs
2023-07-20 19:55:03 -07:00
Kyle Corbitt
372c2512c9 Merge pull request #73 from OpenPipe/model-providers
More work on modelProviders
2023-07-20 18:56:14 -07:00
Kyle Corbitt
332a2101c0 More work on modelProviders
I think everything that's OpenAI-specific is inside modelProviders at this point, so we can get started adding more providers.
2023-07-20 18:54:26 -07:00
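The pluggable shape this commit moves toward might look roughly like the sketch below. This is illustrative only: the diff doesn't show the real interface from `modelProviders/`, so the names and fields here are assumptions (only the `"openai/ChatCompletion"` key and the per-provider cost remark are grounded in this changeset).

```typescript
// Illustrative sketch, not the repo's actual types.
import { type JSONSchema4 } from "json-schema";

interface ModelProviderSketch<TInput, TOutput> {
  // e.g. "openai/ChatCompletion" -- matches the modelProvider field on PromptVariant
  name: string;
  // Models this provider can serve, keyed by model id
  models: Record<string, { promptTokenPrice?: number; completionTokenPrice?: number }>;
  // JSON schema the in-app editor could use to validate prompt configs
  inputSchema?: JSONSchema4;
  // The provider-specific completion call
  getCompletion(input: TInput): Promise<TOutput>;
}
```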
arcticfly
1822fe198e Initially render AutoResizeTextArea without overflow (#72)
* Rerender resized text area with scroll

* Remove default hidden overflow
2023-07-20 15:00:09 -07:00
Kyle Corbitt
f06e1db3db Merge pull request #71 from OpenPipe/model-providers
Prep for more model providers
2023-07-20 14:55:31 -07:00
Kyle Corbitt
ded6678e97 Prep for more model providers
Adds a `modelProvider` field to `promptVariants`, currently set to "openai/ChatCompletion" for all variants.

Adds a `modelProviders/` directory where we can define and store pluggable model providers. Currently just OpenAI. Not everything is pluggable yet -- notably, the code that actually generates completions hasn't been migrated to this setup.

Does a lot of work to get the types working. Prompts are now defined with a function `definePrompt(modelProvider, config)` instead of `prompt = config`. Added a script to migrate old prompt definitions.

This is still partial work, but the diff is large enough that I want to get it in. I don't think anything is broken but I haven't tested thoroughly.
2023-07-20 14:49:22 -07:00
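Concretely (as the updated seed script later in this diff shows), a variant definition now looks like this. Note that `definePrompt` and `scenario` are names available inside the prompt editor's sandbox, not module imports:

```typescript
// From the seed data below: definePrompt pairs a provider key
// with that provider's own config object.
definePrompt("openai/ChatCompletion", {
  model: "gpt-3.5-turbo-0613",
  messages: [
    { role: "user", content: `What is the capital of ${scenario.country}?` },
  ],
  temperature: 0,
});
```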
arcticfly
9314a86857 Use translation in initial scenarios (#70) 2023-07-20 14:28:48 -07:00
David Corbitt
54dcb4a567 Prevent text input labels from overlaying scenarios header 2023-07-20 14:28:36 -07:00
David Corbitt
2c8c8d07cf Merge branch 'main' of github.com:corbt/prompt-lab 2023-07-20 13:38:58 -07:00
David Corbitt
e885bdd365 Fix ScenarioEditor padding 2023-07-20 13:38:46 -07:00
arcticfly
86dc36a656 Improve refinement (#69)
* Format construction function on return

* Add more refinement examples

* Treat 503 like 429

* Define prompt as object

* Fix prettier
2023-07-20 13:05:27 -07:00
arcticfly
55c077d604 Create FloatingLabelInput for scenario variables (#68)
* Create FloatingLabelInput

* Fix prettier

* Simplify changes
2023-07-20 12:20:12 -07:00
arcticfly
e598e454d0 Add new predefined refinement options (#67)
* Add new predefined refinement options

* Fix prettier

* Add icon to SelectModelModal title
2023-07-19 20:10:08 -07:00
David Corbitt
6e3f90cd2f Add more info to refinement 2023-07-19 18:10:23 -07:00
David Corbitt
eec894e101 Allow multiline instructions 2023-07-19 18:10:04 -07:00
David Corbitt
f797fc3fa4 Eliminate spinner flicker in OutputCell 2023-07-19 18:09:47 -07:00
David Corbitt
335dc0357f Fix CompareFunctions for mobile 2023-07-19 17:24:19 -07:00
arcticfly
e6e2c706c2 Change up refinement UI (#66)
* Remove unused ScenarioVariantCell fields

* Refine deriveNewConstructFn

* Fix prettier

* Remove migration script

* Add refine modal

* Fix prettier

* Fix diff checker overflow

* Decrease diff height

* Add more context to prompt refining

* Auto-expand prompt when refining
2023-07-19 17:19:45 -07:00
Kyle Corbitt
7d2166b305 Merge pull request #65 from OpenPipe/no-model
Cache cost on ModelOutput
2023-07-19 16:22:35 -07:00
Kyle Corbitt
60765e51ac Remove model from promptVariant and add cost
Storing the model on promptVariant is problematic because it isn't always in sync with the actual prompt definition. I'm removing it for now to see if we can get away with that -- might have to add it back in later if this causes trouble.

Added `cost` to modelOutput as well so we can cache that, which is important given that the cost calculations won't be the same between different API providers.
2023-07-19 16:20:53 -07:00
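As a back-of-the-envelope illustration of why the cached `cost` matters, a per-provider calculation might look like the sketch below. The rates and function are hypothetical; the diff only establishes that `cost`, `promptTokens`, and `completionTokens` live on `modelOutput`.

```typescript
// Hypothetical rates -- real prices differ per provider and per model.
const PRICES: Record<string, { prompt: number; completion: number }> = {
  "gpt-3.5-turbo-0613": { prompt: 1.5e-6, completion: 2e-6 }, // $/token
};

export function estimateCost(
  model: string,
  promptTokens = 0,
  completionTokens = 0,
): number | undefined {
  const price = PRICES[model];
  if (!price) return undefined; // unknown model: leave cost NULL
  return promptTokens * price.prompt + completionTokens * price.completion;
}
```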
arcticfly
2c4ba6eb9b Update README.md (#64) 2023-07-19 15:39:21 -07:00
arcticfly
4c97b9f147 Refine prompt (#63)
* Remove unused ScenarioVariantCell fields

* Refine deriveNewConstructFn

* Fix prettier

* Remove migration script

* Add refine modal

* Fix prettier

* Fix diff checker overflow

* Decrease diff height
2023-07-19 15:31:40 -07:00
arcticfly
58892d8b63 Remove unused fields, refine model translation (#62)
* Remove unused ScenarioVariantCell fields

* Refine deriveNewConstructFn

* Fix prettier
2023-07-19 13:59:11 -07:00
Kyle Corbitt
4fa2dffbcb styling tweaks for SelectModelModal 2023-07-19 07:17:56 -07:00
Kyle Corbitt
654f8c7cf2 Merge pull request #61 from OpenPipe/experiment-page
More visual tweaks
2023-07-19 06:56:58 -07:00
Kyle Corbitt
d02482468d more visual tweaks 2023-07-19 06:54:07 -07:00
Kyle Corbitt
5c6ed22f1d Merge pull request #60 from OpenPipe/experiment-page
experiment page visual tweaks
2023-07-18 22:26:05 -07:00
Kyle Corbitt
2cb623f332 experiment page visual tweaks 2023-07-18 22:22:58 -07:00
Kyle Corbitt
1c1cefe286 Merge pull request #59 from OpenPipe/auth
User accounts
2023-07-18 21:21:46 -07:00
Kyle Corbitt
b4aa95edca sidebar mobile styles 2023-07-18 21:19:06 -07:00
Kyle Corbitt
1dcdba04a6 User accounts
Allows for the creation of user accounts. A few notes on the specifics:

 - Experiments are the main access control objects. If you can view an experiment, you can view all its prompts/scenarios/evals. If you can edit it, you can edit or delete all of those as well.
 - Experiments are owned by Organizations in the database. Organizations can have multiple members and members can have roles of ADMIN, MEMBER or VIEWER.
 - Organizations can either be "personal" or general. Each user has a "personal" organization created as soon as they try to create an experiment. There's currently no UI support for creating general orgs or adding users to them; they're just in the database to future-proof all the ACL logic.
 - You can require that a user is signed in to see a route using the `protectedProcedure` helper. When you use `protectedProcedure`, you also have to call `ctx.markAccessControlRun()` (or delegate to a function that does it for you; see accessControl.ts). This is to remind us to actually check for access control when we define a new endpoint. A sketch of this pattern follows the commit entry below.
2023-07-18 21:19:03 -07:00
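A minimal sketch of the `protectedProcedure` pattern described in the last bullet, assuming tRPC middleware; the real implementation lives in accessControl.ts and may differ:

```typescript
import { initTRPC, TRPCError } from "@trpc/server";

type Ctx = {
  session: { user: { id: string } } | null;
  markAccessControlRun?: () => void;
};
const t = initTRPC.context<Ctx>().create();

export const protectedProcedure = t.procedure.use(
  t.middleware(async ({ ctx, next }) => {
    if (!ctx.session?.user) throw new TRPCError({ code: "UNAUTHORIZED" });
    let ranAccessControl = false;
    const result = await next({
      ctx: { ...ctx, markAccessControlRun: () => (ranAccessControl = true) },
    });
    // Fail loudly if the resolver never checked permissions.
    if (!ranAccessControl) {
      throw new Error("Access control was never checked for this endpoint");
    }
    return result;
  }),
);
```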
arcticfly
e0e64c4207 Allow user to create a version of their current prompt with a new model (#58)
* Add dropdown header for model switching

* Allow variant duplication

* Fix prettier

* Use env variable to restrict prisma logs

* Fix env.mjs

* Remove unnecessary scroll bar from function call output

* Properly record when 404 error occurs in queryLLM task

* Add SelectedModelInfo in SelectModelModal

* Add react-select

* Calculate new prompt after switching model

* Send newly selected model with creation request

* Get new prompt construction function back from GPT-4

* Fix prettier

* Fix prettier
2023-07-18 18:24:04 -07:00
arcticfly
fa5b1ab1c5 Allow user to duplicate prompt (#57)
* Add dropdown header for model switching

* Allow variant duplication

* Fix prettier
2023-07-18 13:49:33 -07:00
David Corbitt
999a4c08fa Fix lint and prettier 2023-07-18 11:11:20 -07:00
arcticfly
374d0237ee Escape characters in Regex evaluations, minor UI fixes (#56)
* Fix ScenariosHeader stickiness

* Move meta tag from _app.tsx to _document.tsx

* Show spinner when saving variant

* Escape quotes and regex in evaluations
2023-07-18 11:07:04 -07:00
David Corbitt
b1f873623d Invalidate prompt stats after cell refetch 2023-07-18 09:45:11 -07:00
arcticfly
4131aa67d0 Continue polling VariantStats while LLM retrieval in progress, minor UI fixes (#54)
* Prevent zoom in on iOS

* Expand function return code background to fill cell

* Keep OutputStats on far right of cells

* Continue polling prompt stats while cells are retrieving from LLM

* Add comment to _document.tsx

* Fix prettier
2023-07-17 18:04:38 -07:00
Kyle Corbitt
8e7a6d3ae2 Merge pull request #55 from OpenPipe/more-eval
Add GPT4 Evals
2023-07-17 18:01:47 -07:00
Kyle Corbitt
7d41e94ca2 cache eval outputs and add gpt4 eval 2023-07-17 17:55:36 -07:00
Kyle Corbitt
011b12abb9 cache output evals 2023-07-17 17:52:30 -07:00
Kyle Corbitt
1ba18015bc Merge pull request #53 from OpenPipe/more-eval
Fix seeds and update eval field names
2023-07-17 14:26:29 -07:00
Kyle Corbitt
54369dba54 Fix seeds and update eval field names 2023-07-17 14:14:20 -07:00
arcticfly
6b84a59372 Properly catch completion errors (#51) 2023-07-17 10:50:25 -07:00
Kyle Corbitt
8db8aeacd3 Replace function chrome with comment
Use a block comment to explain the expected prompt formatting instead of function chrome. The advantage here is that once a user builds a mental model of how OpenPipe works they can just delete the comment, instead of the function chrome sitting around and taking up space in the UI forever.
2023-07-17 10:30:22 -07:00
Kyle Corbitt
64bd71e370 Merge pull request #50 from OpenPipe/remove-default
remove the default value for PromptVariant.model
2023-07-14 17:55:38 -07:00
Kyle Corbitt
ca21a7af06 Run checks on main
This will (1) make sure that anything we push directly passes CI, and also (2) cache the pnpm store on the main branch, which will make it available to PR runs as well and hopefully speed up CI a bit (see https://stackoverflow.com/a/75250061).
2023-07-14 17:49:20 -07:00
Kyle Corbitt
3b99b7bd2b remove the default value for PromptVariant.model
We should be explicit about setting the appropriate model so it always matches the constructFn.
2023-07-14 17:43:52 -07:00
Kyle Corbitt
0c3bdbe4f2 Merge pull request #49 from OpenPipe/save-button
Make save button disappear on save
2023-07-14 17:39:06 -07:00
121 changed files with 5592 additions and 5877 deletions

View File

@@ -17,4 +17,15 @@ DATABASE_URL="postgresql://postgres:postgres@localhost:5432/openpipe?schema=publ
 # https://help.openai.com/en/articles/4936850-where-do-i-find-my-secret-api-key
 OPENAI_API_KEY=""
+
+# Replicate API token. Create a token here: https://replicate.com/account/api-tokens
+REPLICATE_API_TOKEN=""
 
 NEXT_PUBLIC_SOCKET_URL="http://localhost:3318"
+
+# Next Auth
+NEXTAUTH_SECRET="your_secret"
+NEXTAUTH_URL="http://localhost:3000"
+
+# Next Auth Github Provider
+GITHUB_CLIENT_ID="your_client_id"
+GITHUB_CLIENT_SECRET="your_secret"

View File

@@ -3,6 +3,8 @@ name: CI checks
 on:
   pull_request:
     branches: [main]
+  push:
+    branches: [main]
 
 jobs:
   run-checks:

View File

@@ -11,6 +11,7 @@ declare module "nextjs-routes" {
   } from "next";
 
   export type Route =
+    | StaticRoute<"/account/signin">
     | DynamicRoute<"/api/auth/[...nextauth]", { "nextauth": string[] }>
     | DynamicRoute<"/api/trpc/[trpc]", { "trpc": string }>
     | DynamicRoute<"/experiments/[id]", { "id": string }>

View File

@@ -19,7 +19,6 @@ FROM base as builder
 # Include all NEXT_PUBLIC_* env vars here
 ARG NEXT_PUBLIC_POSTHOG_KEY
-ARG NEXT_PUBLIC_IS_PUBLIC_PLAYGROUND
 ARG NEXT_PUBLIC_SOCKET_URL
 
 WORKDIR /app

View File

@@ -4,11 +4,18 @@
 OpenPipe is a flexible playground for comparing and optimizing LLM prompts. It lets you quickly generate, test and compare candidate prompts with realistic sample data.
 
-**Live Demo:** https://openpipe.ai
+## Sample Experiments
+
+These are simple experiments users have created that show how OpenPipe works.
+
+- [Country Capitals](https://app.openpipe.ai/experiments/11111111-1111-1111-1111-111111111111)
+- [Reddit User Needs](https://app.openpipe.ai/experiments/22222222-2222-2222-2222-222222222222)
+- [OpenAI Function Calls](https://app.openpipe.ai/experiments/2ebbdcb3-ed51-456e-87dc-91f72eaf3e2b)
+- [Activity Classification](https://app.openpipe.ai/experiments/3950940f-ab6b-4b74-841d-7e9dbc4e4ff8)
 
 <img src="https://github.com/openpipe/openpipe/assets/176426/fc7624c6-5b65-4d4d-82b7-4a816f3e5678" alt="demo" height="400px">
 
-Currently there's a public playground available at [https://openpipe.ai/](https://openpipe.ai/), but the recommended approach is to [run locally](#running-locally).
+You can use our hosted version of OpenPipe at [https://openpipe.ai]. You can also clone this repository and [run it locally](#running-locally).
 
 ## High-Level Features
 
@@ -36,7 +43,8 @@ Natively supports [OpenAI function calls](https://openai.com/blog/function-calli
 ## Supported Models
 
-OpenPipe currently supports GPT-3.5 and GPT-4. Wider model support is planned.
+- All models available through the OpenAI [chat completion API](https://platform.openai.com/docs/guides/gpt/chat-completions-api)
+- Llama2 [7b chat](https://replicate.com/a16z-infra/llama7b-v2-chat), [13b chat](https://replicate.com/a16z-infra/llama13b-v2-chat), [70b chat](https://replicate.com/replicate/llama70b-v2-chat).
 
 ## Running Locally
 
@@ -47,5 +55,6 @@ OpenPipe currently supports GPT-3.5 and GPT-4. Wider model support is planned.
 5. Install the dependencies: `cd openpipe && pnpm install`
 6. Create a `.env` file (`cp .env.example .env`) and enter your `OPENAI_API_KEY`.
 7. Update `DATABASE_URL` if necessary to point to your Postgres instance and run `pnpm prisma db push` to create the database.
-8. Start the app: `pnpm dev`.
-9. Navigate to [http://localhost:3000](http://localhost:3000)
+8. Create a [GitHub OAuth App](https://docs.github.com/en/apps/oauth-apps/building-oauth-apps/creating-an-oauth-app) and update the `GITHUB_CLIENT_ID` and `GITHUB_CLIENT_SECRET` values. (Note: a PR to make auth optional when running locally would be a great contribution!)
+9. Start the app: `pnpm dev`.
+10. Navigate to [http://localhost:3000](http://localhost:3000)

View File

@@ -16,9 +16,12 @@
     "postinstall": "prisma generate",
     "lint": "next lint",
     "start": "next start",
-    "codegen": "tsx src/codegen/export-openai-types.ts"
+    "codegen": "tsx src/codegen/export-openai-types.ts",
+    "seed": "tsx prisma/seed.ts",
+    "check": "concurrently 'pnpm lint' 'pnpm tsc' 'pnpm prettier . --check'"
   },
   "dependencies": {
+    "@apidevtools/json-schema-ref-parser": "^10.1.0",
     "@babel/preset-typescript": "^7.22.5",
     "@babel/standalone": "^7.22.9",
     "@chakra-ui/next-js": "^2.1.4",
@@ -26,6 +29,7 @@
     "@emotion/react": "^11.11.1",
     "@emotion/server": "^11.11.0",
     "@emotion/styled": "^11.11.0",
+    "@fontsource/inconsolata": "^5.0.5",
     "@monaco-editor/loader": "^1.3.3",
     "@next-auth/prisma-adapter": "^1.0.5",
     "@prisma/client": "^4.14.0",
@@ -36,6 +40,7 @@
     "@trpc/next": "^10.26.0",
     "@trpc/react-query": "^10.26.0",
     "@trpc/server": "^10.26.0",
+    "ast-types": "^0.14.2",
     "chroma-js": "^2.4.2",
     "concurrently": "^8.2.0",
     "cors": "^2.8.5",
@@ -48,8 +53,10 @@
     "graphile-worker": "^0.13.0",
     "immer": "^10.0.2",
     "isolated-vm": "^4.5.0",
+    "json-schema-to-typescript": "^13.0.2",
     "json-stringify-pretty-compact": "^4.0.0",
-    "lodash": "^4.17.21",
+    "jsonschema": "^1.4.1",
+    "lodash-es": "^4.17.21",
     "next": "^13.4.2",
     "next-auth": "^4.22.1",
     "nextjs-routes": "^2.0.1",
@@ -57,15 +64,22 @@
     "pluralize": "^8.0.0",
     "posthog-js": "^1.68.4",
     "prettier": "^3.0.0",
+    "prismjs": "^1.29.0",
     "react": "18.2.0",
+    "react-diff-viewer": "^3.1.1",
     "react-dom": "18.2.0",
     "react-icons": "^4.10.1",
+    "react-select": "^5.7.4",
     "react-syntax-highlighter": "^15.5.0",
     "react-textarea-autosize": "^8.5.0",
+    "recast": "^0.23.3",
+    "replicate": "^0.12.3",
     "socket.io": "^4.7.1",
     "socket.io-client": "^4.7.1",
     "superjson": "1.12.2",
     "tsx": "^3.12.7",
+    "type-fest": "^4.0.0",
+    "vite-tsconfig-paths": "^4.2.0",
     "zod": "^3.21.4",
     "zustand": "^4.3.9"
   },
@@ -77,9 +91,11 @@
     "@types/cors": "^2.8.13",
     "@types/eslint": "^8.37.0",
     "@types/express": "^4.17.17",
-    "@types/lodash": "^4.14.195",
+    "@types/json-schema": "^7.0.12",
+    "@types/lodash-es": "^4.17.8",
     "@types/node": "^18.16.0",
     "@types/pluralize": "^0.0.30",
+    "@types/prismjs": "^1.26.0",
     "@types/react": "^18.2.6",
     "@types/react-dom": "^18.2.4",
     "@types/react-syntax-highlighter": "^15.5.7",
@@ -100,6 +116,6 @@
     "initVersion": "7.14.0"
   },
   "prisma": {
-    "seed": "tsx prisma/seed.ts"
+    "seed": "pnpm seed"
   }
 }

pnpm-lock.yaml (generated)

File diff suppressed because it is too large.

View File

@@ -0,0 +1,2 @@
-- AlterTable
ALTER TABLE "PromptVariant" ALTER COLUMN "model" DROP DEFAULT;

View File

@@ -0,0 +1,24 @@
/*
Warnings:
- You are about to rename the column `matchString` on the `Evaluation` table. If there is any code or views referring to the old name, they will break.
- You are about to rename the column `matchType` on the `Evaluation` table. If there is any code or views referring to the old name, they will break.
- You are about to rename the column `name` on the `Evaluation` table. If there is any code or views referring to the old name, they will break.
- You are about to rename the enum `EvaluationMatchType` to `EvalType`. If there is any code or views referring to the old name, they will break.
*/
-- RenameEnum
ALTER TYPE "EvaluationMatchType" RENAME TO "EvalType";
-- AlterTable
ALTER TABLE "Evaluation" RENAME COLUMN "matchString" TO "value";
ALTER TABLE "Evaluation" RENAME COLUMN "matchType" TO "evalType";
ALTER TABLE "Evaluation" RENAME COLUMN "name" TO "label";
-- AlterColumnType
ALTER TABLE "Evaluation" ALTER COLUMN "evalType" TYPE "EvalType" USING "evalType"::text::"EvalType";
-- SetNotNullConstraint
ALTER TABLE "Evaluation" ALTER COLUMN "evalType" SET NOT NULL;
ALTER TABLE "Evaluation" ALTER COLUMN "label" SET NOT NULL;
ALTER TABLE "Evaluation" ALTER COLUMN "value" SET NOT NULL;

View File

@@ -0,0 +1,39 @@
/*
Warnings:
- You are about to drop the `EvaluationResult` table. If the table is not empty, all the data it contains will be lost.
*/
-- AlterEnum
ALTER TYPE "EvalType" ADD VALUE 'GPT4_EVAL';
-- DropForeignKey
ALTER TABLE "EvaluationResult" DROP CONSTRAINT "EvaluationResult_evaluationId_fkey";
-- DropForeignKey
ALTER TABLE "EvaluationResult" DROP CONSTRAINT "EvaluationResult_promptVariantId_fkey";
-- DropTable
DROP TABLE "EvaluationResult";
-- CreateTable
CREATE TABLE "OutputEvaluation" (
"id" UUID NOT NULL,
"result" DOUBLE PRECISION NOT NULL,
"details" TEXT,
"modelOutputId" UUID NOT NULL,
"evaluationId" UUID NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
CONSTRAINT "OutputEvaluation_pkey" PRIMARY KEY ("id")
);
-- CreateIndex
CREATE UNIQUE INDEX "OutputEvaluation_modelOutputId_evaluationId_key" ON "OutputEvaluation"("modelOutputId", "evaluationId");
-- AddForeignKey
ALTER TABLE "OutputEvaluation" ADD CONSTRAINT "OutputEvaluation_modelOutputId_fkey" FOREIGN KEY ("modelOutputId") REFERENCES "ModelOutput"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "OutputEvaluation" ADD CONSTRAINT "OutputEvaluation_evaluationId_fkey" FOREIGN KEY ("evaluationId") REFERENCES "Evaluation"("id") ON DELETE CASCADE ON UPDATE CASCADE;

View File

@@ -0,0 +1,124 @@
DROP TABLE "Account";
DROP TABLE "Session";
DROP TABLE "User";
DROP TABLE "VerificationToken";
CREATE TYPE "OrganizationUserRole" AS ENUM ('ADMIN', 'MEMBER', 'VIEWER');
-- CreateTable
CREATE TABLE "Organization" (
"id" UUID NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
"personalOrgUserId" UUID,
CONSTRAINT "Organization_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "OrganizationUser" (
"id" UUID NOT NULL,
"role" "OrganizationUserRole" NOT NULL,
"organizationId" UUID NOT NULL,
"userId" UUID NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
CONSTRAINT "OrganizationUser_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "Account" (
"id" UUID NOT NULL,
"userId" UUID NOT NULL,
"type" TEXT NOT NULL,
"provider" TEXT NOT NULL,
"providerAccountId" TEXT NOT NULL,
"refresh_token" TEXT,
"refresh_token_expires_in" INTEGER,
"access_token" TEXT,
"expires_at" INTEGER,
"token_type" TEXT,
"scope" TEXT,
"id_token" TEXT,
"session_state" TEXT,
CONSTRAINT "Account_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "Session" (
"id" UUID NOT NULL,
"sessionToken" TEXT NOT NULL,
"userId" UUID NOT NULL,
"expires" TIMESTAMP(3) NOT NULL,
CONSTRAINT "Session_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "User" (
"id" UUID NOT NULL,
"name" TEXT,
"email" TEXT,
"emailVerified" TIMESTAMP(3),
"image" TEXT,
CONSTRAINT "User_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "VerificationToken" (
"identifier" TEXT NOT NULL,
"token" TEXT NOT NULL,
"expires" TIMESTAMP(3) NOT NULL
);
INSERT INTO "Organization" ("id", "updatedAt") VALUES ('11111111-1111-1111-1111-111111111111', CURRENT_TIMESTAMP);
-- AlterTable add organizationId as NULLABLE
ALTER TABLE "Experiment" ADD COLUMN "organizationId" UUID;
-- Set default organization for existing experiments
UPDATE "Experiment" SET "organizationId" = '11111111-1111-1111-1111-111111111111';
-- AlterTable set organizationId as NOT NULL
ALTER TABLE "Experiment" ALTER COLUMN "organizationId" SET NOT NULL;
-- CreateIndex
CREATE UNIQUE INDEX "OrganizationUser_organizationId_userId_key" ON "OrganizationUser"("organizationId", "userId");
-- CreateIndex
CREATE UNIQUE INDEX "Account_provider_providerAccountId_key" ON "Account"("provider", "providerAccountId");
-- CreateIndex
CREATE UNIQUE INDEX "Session_sessionToken_key" ON "Session"("sessionToken");
-- CreateIndex
CREATE UNIQUE INDEX "User_email_key" ON "User"("email");
-- CreateIndex
CREATE UNIQUE INDEX "VerificationToken_token_key" ON "VerificationToken"("token");
-- CreateIndex
CREATE UNIQUE INDEX "VerificationToken_identifier_token_key" ON "VerificationToken"("identifier", "token");
-- AddForeignKey
ALTER TABLE "Experiment" ADD CONSTRAINT "Experiment_organizationId_fkey" FOREIGN KEY ("organizationId") REFERENCES "Organization"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "OrganizationUser" ADD CONSTRAINT "OrganizationUser_organizationId_fkey" FOREIGN KEY ("organizationId") REFERENCES "Organization"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "OrganizationUser" ADD CONSTRAINT "OrganizationUser_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "Account" ADD CONSTRAINT "Account_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "Session" ADD CONSTRAINT "Session_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
CREATE UNIQUE INDEX "Organization_personalOrgUserId_key" ON "Organization"("personalOrgUserId");
ALTER TABLE "Organization" ADD CONSTRAINT "Organization_personalOrgUserId_fkey" FOREIGN KEY ("personalOrgUserId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;

View File

@@ -0,0 +1,16 @@
/*
Warnings:
- You are about to drop the column `completionTokens` on the `ScenarioVariantCell` table. All the data in the column will be lost.
- You are about to drop the column `inputHash` on the `ScenarioVariantCell` table. All the data in the column will be lost.
- You are about to drop the column `output` on the `ScenarioVariantCell` table. All the data in the column will be lost.
- You are about to drop the column `promptTokens` on the `ScenarioVariantCell` table. All the data in the column will be lost.
- You are about to drop the column `timeToComplete` on the `ScenarioVariantCell` table. All the data in the column will be lost.
*/
-- AlterTable
ALTER TABLE "ScenarioVariantCell" DROP COLUMN "completionTokens",
DROP COLUMN "inputHash",
DROP COLUMN "output",
DROP COLUMN "promptTokens",
DROP COLUMN "timeToComplete";

View File

@@ -0,0 +1,8 @@
/*
Warnings:
- You are about to drop the column `model` on the `PromptVariant` table. All the data in the column will be lost.
*/
-- AlterTable
ALTER TABLE "ModelOutput" ADD COLUMN "cost" DOUBLE PRECISION;

View File

@@ -0,0 +1,17 @@
-- Add new columns allowing NULL values
ALTER TABLE "PromptVariant"
ADD COLUMN "constructFnVersion" INTEGER,
ADD COLUMN "modelProvider" TEXT;
-- Update existing records to have the default values
UPDATE "PromptVariant"
SET "constructFnVersion" = 1,
"modelProvider" = 'openai/ChatCompletion'
WHERE "constructFnVersion" IS NULL OR "modelProvider" IS NULL;
-- Alter table to set NOT NULL constraint
ALTER TABLE "PromptVariant"
ALTER COLUMN "constructFnVersion" SET NOT NULL,
ALTER COLUMN "modelProvider" SET NOT NULL;
ALTER TABLE "ScenarioVariantCell" ADD COLUMN "prompt" JSONB;

View File

@@ -2,8 +2,7 @@
 // learn more about it in the docs: https://pris.ly/d/prisma-schema
 
 generator client {
-  provider        = "prisma-client-js"
-  previewFeatures = ["jsonProtocol"]
+  provider = "prisma-client-js"
 }
 
 datasource db {
@@ -17,8 +16,12 @@ model Experiment {
   sortIndex Int @default(0)
 
-  createdAt DateTime @default(now())
-  updatedAt DateTime @updatedAt
+  organizationId String        @db.Uuid
+  organization   Organization? @relation(fields: [organizationId], references: [id], onDelete: Cascade)
+
+  createdAt DateTime @default(now())
+  updatedAt DateTime @updatedAt
 
   TemplateVariable TemplateVariable[]
   PromptVariant    PromptVariant[]
   TestScenario     TestScenario[]
@@ -28,9 +31,11 @@ model Experiment {
 model PromptVariant {
   id          String @id @default(uuid()) @db.Uuid
   label       String
-  constructFn String
-  model       String @default("gpt-3.5-turbo")
+  constructFn        String
+  constructFnVersion Int
+  model              String
+  modelProvider      String
 
   uiId    String  @default(uuid()) @db.Uuid
   visible Boolean @default(true)
@@ -39,10 +44,9 @@ model PromptVariant {
   experimentId String     @db.Uuid
   experiment   Experiment @relation(fields: [experimentId], references: [id], onDelete: Cascade)
 
   createdAt DateTime @default(now())
   updatedAt DateTime @updatedAt
 
   scenarioVariantCells ScenarioVariantCell[]
-  EvaluationResult     EvaluationResult[]
 
   @@index([uiId])
 }
@@ -59,8 +63,8 @@ model TestScenario {
   experimentId String     @db.Uuid
   experiment   Experiment @relation(fields: [experimentId], references: [id], onDelete: Cascade)
 
   createdAt DateTime @default(now())
   updatedAt DateTime @updatedAt
 
   scenarioVariantCells ScenarioVariantCell[]
 }
@@ -86,21 +90,17 @@ enum CellRetrievalStatus {
 model ScenarioVariantCell {
   id String @id @default(uuid()) @db.Uuid
 
-  inputHash String? // TODO: Remove once migration is complete
-  output    Json?   // TODO: Remove once migration is complete
   statusCode   Int?
   errorMessage String?
 
-  timeToComplete   Int?      @default(0) // TODO: Remove once migration is complete
   retryTime        DateTime?
   streamingChannel String?
   retrievalStatus  CellRetrievalStatus @default(COMPLETE)
 
-  promptTokens     Int? // TODO: Remove once migration is complete
-  completionTokens Int? // TODO: Remove once migration is complete
-  modelOutput      ModelOutput?
+  modelOutput ModelOutput?
 
   promptVariantId String        @db.Uuid
   promptVariant   PromptVariant @relation(fields: [promptVariantId], references: [id], onDelete: Cascade)
+  prompt          Json?
 
   testScenarioId String       @db.Uuid
   testScenario   TestScenario @relation(fields: [testScenarioId], references: [id], onDelete: Cascade)
@@ -116,93 +116,133 @@ model ModelOutput {
   inputHash      String
   output         Json
   timeToComplete Int    @default(0)
+  cost           Float?
 
   promptTokens     Int?
   completionTokens Int?
 
   createdAt DateTime @default(now())
   updatedAt DateTime @updatedAt
 
   scenarioVariantCellId String              @db.Uuid
   scenarioVariantCell   ScenarioVariantCell @relation(fields: [scenarioVariantCellId], references: [id], onDelete: Cascade)
+  outputEvaluation      OutputEvaluation[]
 
   @@unique([scenarioVariantCellId])
   @@index([inputHash])
 }
 
-enum EvaluationMatchType {
+enum EvalType {
   CONTAINS
   DOES_NOT_CONTAIN
+  GPT4_EVAL
 }
 
 model Evaluation {
   id String @id @default(uuid()) @db.Uuid
 
-  name        String
-  matchString String
-  matchType   EvaluationMatchType
+  label    String
+  evalType EvalType
+  value    String
 
   experimentId String     @db.Uuid
   experiment   Experiment @relation(fields: [experimentId], references: [id], onDelete: Cascade)
 
   createdAt DateTime @default(now())
   updatedAt DateTime @updatedAt
 
-  EvaluationResult EvaluationResult[]
+  OutputEvaluation OutputEvaluation[]
 }
 
-model EvaluationResult {
+model OutputEvaluation {
   id String @id @default(uuid()) @db.Uuid
 
-  passCount Int
-  failCount Int
+  // Number between 0 (fail) and 1 (pass)
+  result  Float
+  details String?
 
+  modelOutputId String      @db.Uuid
+  modelOutput   ModelOutput @relation(fields: [modelOutputId], references: [id], onDelete: Cascade)
   evaluationId String     @db.Uuid
   evaluation   Evaluation @relation(fields: [evaluationId], references: [id], onDelete: Cascade)
-  promptVariantId String        @db.Uuid
-  promptVariant   PromptVariant @relation(fields: [promptVariantId], references: [id], onDelete: Cascade)
 
   createdAt DateTime @default(now())
   updatedAt DateTime @updatedAt
 
-  @@unique([evaluationId, promptVariantId])
+  @@unique([modelOutputId, evaluationId])
 }
 
+model Organization {
+  id String @id @default(uuid()) @db.Uuid
+
+  personalOrgUserId String? @unique @db.Uuid
+  PersonalOrgUser   User?   @relation(fields: [personalOrgUserId], references: [id], onDelete: Cascade)
+
+  createdAt DateTime @default(now())
+  updatedAt DateTime @updatedAt
+
+  OrganizationUser OrganizationUser[]
+  Experiment       Experiment[]
+}
+
+enum OrganizationUserRole {
+  ADMIN
+  MEMBER
+  VIEWER
+}
+
+model OrganizationUser {
+  id String @id @default(uuid()) @db.Uuid
+
+  role OrganizationUserRole
+
+  organizationId String        @db.Uuid
+  organization   Organization? @relation(fields: [organizationId], references: [id], onDelete: Cascade)
+  userId         String        @db.Uuid
+  user           User          @relation(fields: [userId], references: [id], onDelete: Cascade)
+
+  createdAt DateTime @default(now())
+  updatedAt DateTime @updatedAt
+
+  @@unique([organizationId, userId])
+}
+
+// Necessary for Next auth
 model Account {
-  id                String  @id @default(cuid())
-  userId            String
-  type              String
-  provider          String
-  providerAccountId String
-  refresh_token     String? // @db.Text
-  access_token      String? // @db.Text
-  expires_at        Int?
-  token_type        String?
-  scope             String?
-  id_token          String? // @db.Text
-  session_state     String?
-  user              User    @relation(fields: [userId], references: [id], onDelete: Cascade)
+  id                       String  @id @default(uuid()) @db.Uuid
+  userId                   String  @db.Uuid
+  type                     String
+  provider                 String
+  providerAccountId        String
+  refresh_token            String? @db.Text
+  refresh_token_expires_in Int?
+  access_token             String? @db.Text
+  expires_at               Int?
+  token_type               String?
+  scope                    String?
+  id_token                 String? @db.Text
+  session_state            String?
+  user                     User    @relation(fields: [userId], references: [id], onDelete: Cascade)
 
   @@unique([provider, providerAccountId])
 }
 
 model Session {
-  id           String   @id @default(cuid())
+  id           String   @id @default(uuid()) @db.Uuid
   sessionToken String   @unique
-  userId       String
+  userId       String   @db.Uuid
   expires      DateTime
   user         User     @relation(fields: [userId], references: [id], onDelete: Cascade)
 }
 
 model User {
-  id            String    @id @default(cuid())
+  id            String    @id @default(uuid()) @db.Uuid
   name          String?
   email         String?   @unique
   emailVerified DateTime?
   image         String?
   accounts      Account[]
   sessions      Session[]
+
+  OrganizationUser OrganizationUser[]
+  Organization     Organization[]
 }
 
 model VerificationToken {

View File

@@ -1,76 +1,97 @@
 import { prisma } from "~/server/db";
+import dedent from "dedent";
+import { generateNewCell } from "~/server/utils/generateNewCell";
 
-const experimentId = "11111111-1111-1111-1111-111111111111";
+const defaultId = "11111111-1111-1111-1111-111111111111";
+
+await prisma.organization.deleteMany({
+  where: { id: defaultId },
+});
+
+await prisma.organization.create({
+  data: { id: defaultId },
+});
 
+// Delete the existing experiment
 await prisma.experiment.deleteMany({
   where: {
-    id: experimentId,
+    id: defaultId,
   },
 });
 
-const experiment = await prisma.experiment.create({
+await prisma.experiment.create({
   data: {
-    id: experimentId,
+    id: defaultId,
     label: "Country Capitals Example",
+    organizationId: defaultId,
   },
 });
 
 await prisma.scenarioVariantCell.deleteMany({
   where: {
     promptVariant: {
-      experimentId,
+      experimentId: defaultId,
     },
   },
 });
 
 await prisma.promptVariant.deleteMany({
   where: {
-    experimentId,
+    experimentId: defaultId,
   },
 });
 
 await prisma.promptVariant.createMany({
   data: [
     {
-      experimentId,
+      experimentId: defaultId,
       label: "Prompt Variant 1",
       sortIndex: 0,
-      constructFn: `prompt = {
-        model: "gpt-3.5-turbo-0613",
-        messages: [{ role: "user", content: "What is the capital of {{country}}?" }],
-        temperature: 0,
-      }`,
+      model: "gpt-3.5-turbo-0613",
+      modelProvider: "openai/ChatCompletion",
+      constructFnVersion: 1,
+      constructFn: dedent`
+        definePrompt("openai/ChatCompletion", {
+          model: "gpt-3.5-turbo-0613",
+          messages: [
+            {
+              role: "user",
+              content: \`What is the capital of ${"$"}{scenario.country}?\`
+            }
+          ],
+          temperature: 0,
+        })`,
     },
     {
-      experimentId,
+      experimentId: defaultId,
       label: "Prompt Variant 2",
       sortIndex: 1,
-      constructFn: `prompt = {
-        model: "gpt-3.5-turbo-0613",
-        messages: [
-          {
-            role: "user",
-            content:
-              "What is the capital of {{country}}? Return just the city name and nothing else.",
-          },
-        ],
-        temperature: 0,
-      }`,
+      model: "gpt-3.5-turbo-0613",
+      modelProvider: "openai/ChatCompletion",
+      constructFnVersion: 1,
+      constructFn: dedent`
+        definePrompt("openai/ChatCompletion", {
+          model: "gpt-3.5-turbo-0613",
+          messages: [
+            {
+              role: "user",
+              content: \`What is the capital of ${"$"}{scenario.country}? Return just the city name and nothing else.\`
+            }
+          ],
+          temperature: 0,
+        })`,
     },
   ],
 });
 
 await prisma.templateVariable.deleteMany({
   where: {
-    experimentId,
+    experimentId: defaultId,
   },
 });
 
 await prisma.templateVariable.createMany({
   data: [
     {
-      experimentId,
+      experimentId: defaultId,
       label: "country",
     },
   ],
@@ -78,28 +99,28 @@ await prisma.templateVariable.createMany({
 await prisma.testScenario.deleteMany({
   where: {
-    experimentId,
+    experimentId: defaultId,
   },
 });
 
 await prisma.testScenario.createMany({
   data: [
     {
-      experimentId,
+      experimentId: defaultId,
       sortIndex: 0,
       variableValues: {
         country: "Spain",
       },
     },
     {
-      experimentId,
+      experimentId: defaultId,
       sortIndex: 1,
       variableValues: {
        country: "USA",
      },
     },
     {
-      experimentId,
+      experimentId: defaultId,
      sortIndex: 2,
      variableValues: {
        country: "Chile",
@@ -107,3 +128,26 @@ await prisma.testScenario.createMany({
     },
   ],
 });
+
+const variants = await prisma.promptVariant.findMany({
+  where: {
+    experimentId: defaultId,
+  },
+});
+
+const scenarios = await prisma.testScenario.findMany({
+  where: {
+    experimentId: defaultId,
+  },
+});
+
+await Promise.all(
+  variants
+    .flatMap((variant) =>
+      scenarios.map((scenario) => ({
+        promptVariantId: variant.id,
+        testScenarioId: scenario.id,
+      })),
+    )
+    .map((cell) => generateNewCell(cell.promptVariantId, cell.testScenarioId)),
+);

File diff suppressed because one or more lines are too long

View File

@@ -12,7 +12,7 @@ services:
     dockerContext: .
     plan: standard
     domains:
-      - openpipe.ai
+      - app.openpipe.ai
    envVars:
      - key: NODE_ENV
        value: production

View File

@@ -1,48 +0,0 @@
/* eslint-disable @typescript-eslint/no-var-requires */
import YAML from "yaml";
import fs from "fs";
import path from "path";
import { openapiSchemaToJsonSchema } from "@openapi-contrib/openapi-schema-to-json-schema";
import assert from "assert";
import { type AcceptibleInputSchema } from "@openapi-contrib/openapi-schema-to-json-schema/dist/mjs/openapi-schema-types";
const OPENAPI_URL =
"https://raw.githubusercontent.com/openai/openai-openapi/0c432eb66fd0c758fd8b9bd69db41c1096e5f4db/openapi.yaml";
const convertOpenApiToJsonSchema = async (url: string) => {
// Fetch the openapi document
const response = await fetch(url);
const openApiYaml = await response.text();
// Parse the yaml document
const openApiDocument = YAML.parse(openApiYaml) as AcceptibleInputSchema;
// Convert the openapi schema to json schema
const jsonSchema = openapiSchemaToJsonSchema(openApiDocument);
const modelProperty = jsonSchema.components.schemas.CreateChatCompletionRequest.properties.model;
assert(modelProperty.oneOf.length === 2, "Expected model to have oneOf length of 2");
// We need to do a bit of surgery here since the Monaco editor doesn't like
// the fact that the schema says `model` can be either a string or an enum,
// and displays a warning in the editor. Let's stick with just an enum for
// now and drop the string option.
modelProperty.type = "string";
modelProperty.enum = modelProperty.oneOf[1].enum;
modelProperty.oneOf = undefined;
// Get the directory of the current script
const currentDirectory = path.dirname(import.meta.url).replace("file://", "");
// Write the JSON schema to a file in the current directory
fs.writeFileSync(
path.join(currentDirectory, "openai.schema.json"),
JSON.stringify(jsonSchema, null, 2),
);
};
convertOpenApiToJsonSchema(OPENAPI_URL)
.then(() => console.log("JSON schema has been written successfully."))
.catch((err) => console.error(err));

View File

@@ -1,52 +0,0 @@
import fs from "fs";
import path from "path";
import openapiTS, { type OpenAPI3 } from "openapi-typescript";
import YAML from "yaml";
import _ from "lodash";
import assert from "assert";
const OPENAPI_URL =
"https://raw.githubusercontent.com/openai/openai-openapi/0c432eb66fd0c758fd8b9bd69db41c1096e5f4db/openapi.yaml";
// Generate TypeScript types from OpenAPI
const schema = await fetch(OPENAPI_URL)
.then((res) => res.text())
.then((txt) => YAML.parse(txt) as OpenAPI3);
console.log(schema.components?.schemas?.CreateChatCompletionRequest);
// @ts-expect-error just assume this works, the assert will catch it if it doesn't
const modelProperty = schema.components?.schemas?.CreateChatCompletionRequest?.properties?.model;
assert(modelProperty.oneOf.length === 2, "Expected model to have oneOf length of 2");
// We need to do a bit of surgery here since the Monaco editor doesn't like
// the fact that the schema says `model` can be either a string or an enum,
// and displays a warning in the editor. Let's stick with just an enum for
// now and drop the string option.
modelProperty.type = "string";
modelProperty.enum = modelProperty.oneOf[1].enum;
modelProperty.oneOf = undefined;
delete schema["paths"];
assert(schema.components?.schemas);
schema.components.schemas = _.pick(schema.components?.schemas, [
"CreateChatCompletionRequest",
"ChatCompletionRequestMessage",
"ChatCompletionFunctions",
"ChatCompletionFunctionParameters",
]);
console.log(schema);
let openApiTypes = await openapiTS(schema);
// Remove the `export` from any line that starts with `export`
openApiTypes = openApiTypes.replaceAll("\nexport ", "\n");
// Get the directory of the current script
const currentDirectory = path.dirname(import.meta.url).replace("file://", "");
// Write the TypeScript types. We only want to use this in our in-app editor, so
// save as a .txt so VS Code doesn't try to auto-import definitions from it.
fs.writeFileSync(path.join(currentDirectory, "openai.types.ts.txt"), openApiTypes);

File diff suppressed because it is too large.

View File

@@ -1,148 +0,0 @@
/**
* This file was auto-generated by openapi-typescript.
* Do not make direct changes to the file.
*/
/** OneOf type helpers */
type Without<T, U> = { [P in Exclude<keyof T, keyof U>]?: never };
type XOR<T, U> = (T | U) extends object ? (Without<T, U> & U) | (Without<U, T> & T) : T | U;
type OneOf<T extends any[]> = T extends [infer Only] ? Only : T extends [infer A, infer B, ...infer Rest] ? OneOf<[XOR<A, B>, ...Rest]> : never;
type paths = Record<string, never>;
type webhooks = Record<string, never>;
interface components {
schemas: {
CreateChatCompletionRequest: {
/**
* @description ID of the model to use. See the [model endpoint compatibility](/docs/models/model-endpoint-compatibility) table for details on which models work with the Chat API.
* @example gpt-3.5-turbo
* @enum {string}
*/
model: "gpt-4" | "gpt-4-0613" | "gpt-4-32k" | "gpt-4-32k-0613" | "gpt-3.5-turbo" | "gpt-3.5-turbo-16k" | "gpt-3.5-turbo-0613" | "gpt-3.5-turbo-16k-0613";
/** @description A list of messages comprising the conversation so far. [Example Python code](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb). */
messages: (components["schemas"]["ChatCompletionRequestMessage"])[];
/** @description A list of functions the model may generate JSON inputs for. */
functions?: (components["schemas"]["ChatCompletionFunctions"])[];
/** @description Controls how the model responds to function calls. "none" means the model does not call a function, and responds to the end-user. "auto" means the model can pick between an end-user or calling a function. Specifying a particular function via `{"name":\ "my_function"}` forces the model to call that function. "none" is the default when no functions are present. "auto" is the default if functions are present. */
function_call?: OneOf<["none" | "auto", {
/** @description The name of the function to call. */
name: string;
}]>;
/**
* @description What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.
*
* We generally recommend altering this or `top_p` but not both.
*
* @default 1
* @example 1
*/
temperature?: number | null;
/**
* @description An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.
*
* We generally recommend altering this or `temperature` but not both.
*
* @default 1
* @example 1
*/
top_p?: number | null;
/**
* @description How many chat completion choices to generate for each input message.
* @default 1
* @example 1
*/
n?: number | null;
/**
* @description If set, partial message deltas will be sent, like in ChatGPT. Tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream terminated by a `data: [DONE]` message. [Example Python code](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_stream_completions.ipynb).
*
* @default false
*/
stream?: boolean | null;
/**
* @description Up to 4 sequences where the API will stop generating further tokens.
*
* @default null
*/
stop?: (string | null) | (string)[];
/**
* @description The maximum number of [tokens](/tokenizer) to generate in the chat completion.
*
* The total length of input tokens and generated tokens is limited by the model's context length. [Example Python code](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb) for counting tokens.
*
* @default inf
*/
max_tokens?: number;
/**
* @description Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
*
* [See more information about frequency and presence penalties.](/docs/api-reference/parameter-details)
*
* @default 0
*/
presence_penalty?: number | null;
/**
* @description Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.
*
* [See more information about frequency and presence penalties.](/docs/api-reference/parameter-details)
*
* @default 0
*/
frequency_penalty?: number | null;
/**
* @description Modify the likelihood of specified tokens appearing in the completion.
*
* Accepts a json object that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token.
*
* @default null
*/
logit_bias?: Record<string, unknown> | null;
/**
* @description A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).
*
* @example user-1234
*/
user?: string;
};
ChatCompletionRequestMessage: {
/**
* @description The role of the messages author. One of `system`, `user`, `assistant`, or `function`.
* @enum {string}
*/
role: "system" | "user" | "assistant" | "function";
/** @description The contents of the message. `content` is required for all messages except assistant messages with function calls. */
content?: string;
/** @description The name of the author of this message. `name` is required if role is `function`, and it should be the name of the function whose response is in the `content`. May contain a-z, A-Z, 0-9, and underscores, with a maximum length of 64 characters. */
name?: string;
/** @description The name and arguments of a function that should be called, as generated by the model. */
function_call?: {
/** @description The name of the function to call. */
name?: string;
/** @description The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function. */
arguments?: string;
};
};
ChatCompletionFunctions: {
/** @description The name of the function to be called. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64. */
name: string;
/** @description The description of what the function does. */
description?: string;
parameters?: components["schemas"]["ChatCompletionFunctionParameters"];
};
/** @description The parameters the functions accepts, described as a JSON Schema object. See the [guide](/docs/guides/gpt/function-calling) for examples, and the [JSON Schema reference](https://json-schema.org/understanding-json-schema/) for documentation about the format. */
ChatCompletionFunctionParameters: {
[key: string]: unknown;
};
};
responses: never;
parameters: never;
requestBodies: never;
headers: never;
pathItems: never;
}
type external = Record<string, never>;
type operations = Record<string, never>;

View File

@@ -1,6 +0,0 @@
{
"compilerOptions": {
"target": "esnext",
"moduleResolution": "nodenext"
}
}

View File

@@ -1,19 +1,22 @@
 import { Textarea, type TextareaProps } from "@chakra-ui/react";
 import ResizeTextarea from "react-textarea-autosize";
-import React from "react";
+import React, { useLayoutEffect, useState } from "react";
 
 export const AutoResizeTextarea: React.ForwardRefRenderFunction<
   HTMLTextAreaElement,
-  TextareaProps
-> = (props, ref) => {
+  TextareaProps & { minRows?: number }
+> = ({ minRows = 1, overflowY = "hidden", ...props }, ref) => {
+  const [isRerendered, setIsRerendered] = useState(false);
+  useLayoutEffect(() => setIsRerendered(true), []);
   return (
     <Textarea
       minH="unset"
-      overflow="hidden"
+      minRows={minRows}
+      overflowY={isRerendered ? overflowY : "hidden"}
       w="100%"
       resize="none"
       ref={ref}
-      minRows={1}
       transition="height none"
       as={ResizeTextarea}
       {...props}
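In other words: the first client render always paints with overflow hidden so react-textarea-autosize can measure the height without a scrollbar, and the layout effect then applies whatever overflow the caller asked for. A hypothetical call site, just to show the two props this diff adds:

```typescript
// Hypothetical usage -- the props are the ones added in this diff.
import { AutoResizeTextarea } from "./AutoResizeTextarea";

export const ScenarioField = () => (
  <AutoResizeTextarea
    minRows={2}        // reserve two rows even when empty
    overflowY="auto"   // only takes effect after the first client render
    placeholder="Scenario value"
  />
);
```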

View File

@@ -0,0 +1,135 @@
import {
Button,
Modal,
ModalBody,
ModalCloseButton,
ModalContent,
ModalFooter,
ModalHeader,
ModalOverlay,
VStack,
Text,
Spinner,
HStack,
Icon,
} from "@chakra-ui/react";
import { RiExchangeFundsFill } from "react-icons/ri";
import { useState } from "react";
import { ModelStatsCard } from "./ModelStatsCard";
import { ModelSearch } from "./ModelSearch";
import { api } from "~/utils/api";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
import CompareFunctions from "../RefinePromptModal/CompareFunctions";
import { type PromptVariant } from "@prisma/client";
import { isObject, isString } from "lodash-es";
import { type Model, type SupportedProvider } from "~/modelProviders/types";
import frontendModelProviders from "~/modelProviders/frontendModelProviders";
import { keyForModel } from "~/utils/utils";
export const ChangeModelModal = ({
variant,
onClose,
}: {
variant: PromptVariant;
onClose: () => void;
}) => {
const originalModelProviderName = variant.modelProvider as SupportedProvider;
const originalModelProvider = frontendModelProviders[originalModelProviderName];
const originalModel = originalModelProvider.models[variant.model] as Model;
const [selectedModel, setSelectedModel] = useState<Model>(originalModel);
const [convertedModel, setConvertedModel] = useState<Model | undefined>(undefined);
const utils = api.useContext();
const experiment = useExperiment();
const { mutateAsync: getModifiedPromptMutateAsync, data: modifiedPromptFn } =
api.promptVariants.getModifiedPromptFn.useMutation();
const [getModifiedPromptFn, modificationInProgress] = useHandledAsyncCallback(async () => {
if (!experiment) return;
await getModifiedPromptMutateAsync({
id: variant.id,
newModel: selectedModel,
});
setConvertedModel(selectedModel);
}, [getModifiedPromptMutateAsync, onClose, experiment, variant, selectedModel]);
const replaceVariantMutation = api.promptVariants.replaceVariant.useMutation();
const [replaceVariant, replacementInProgress] = useHandledAsyncCallback(async () => {
if (
!variant.experimentId ||
!modifiedPromptFn ||
(isObject(modifiedPromptFn) && "status" in modifiedPromptFn)
)
return;
await replaceVariantMutation.mutateAsync({
id: variant.id,
constructFn: modifiedPromptFn,
});
await utils.promptVariants.list.invalidate();
onClose();
}, [replaceVariantMutation, variant, onClose, modifiedPromptFn]);
const originalModelLabel = keyForModel(originalModel);
const selectedModelLabel = keyForModel(selectedModel);
const convertedModelLabel = convertedModel ? keyForModel(convertedModel) : undefined;
return (
<Modal
isOpen
onClose={onClose}
size={{ base: "xl", sm: "2xl", md: "3xl", lg: "5xl", xl: "7xl" }}
>
<ModalOverlay />
<ModalContent w={1200}>
<ModalHeader>
<HStack>
<Icon as={RiExchangeFundsFill} />
<Text>Change Model</Text>
</HStack>
</ModalHeader>
<ModalCloseButton />
<ModalBody maxW="unset">
<VStack spacing={8}>
<ModelStatsCard label="Original Model" model={originalModel} />
{originalModelLabel !== selectedModelLabel && (
<ModelStatsCard label="New Model" model={selectedModel} />
)}
<ModelSearch selectedModel={selectedModel} setSelectedModel={setSelectedModel} />
{isString(modifiedPromptFn) && (
<CompareFunctions
originalFunction={variant.constructFn}
newFunction={modifiedPromptFn}
leftTitle={originalModelLabel}
rightTitle={convertedModelLabel}
/>
)}
</VStack>
</ModalBody>
<ModalFooter>
<HStack>
<Button
colorScheme="gray"
onClick={getModifiedPromptFn}
minW={24}
isDisabled={originalModel === selectedModel || modificationInProgress}
>
{modificationInProgress ? <Spinner boxSize={4} /> : <Text>Convert</Text>}
</Button>
<Button
colorScheme="blue"
onClick={replaceVariant}
minW={24}
isDisabled={!convertedModel || modificationInProgress || replacementInProgress}
>
{replacementInProgress ? <Spinner boxSize={4} /> : <Text>Accept</Text>}
</Button>
</HStack>
</ModalFooter>
</ModalContent>
</Modal>
);
};

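ChangeModelModal labels models with `keyForModel` from `~/utils/utils`, which isn't shown in this diff. A plausible sketch, assuming the `provider/name` format that ModelStatsCard renders:

// Hypothetical sketch — the real helper lives in ~/utils/utils and may differ.
import { type Model } from "~/modelProviders/types";

export const keyForModel = (model: Model): string => `${model.provider}/${model.name}`;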
View File

@@ -0,0 +1,50 @@
import { VStack, Text } from "@chakra-ui/react";
import { type LegacyRef, useCallback } from "react";
import Select, { type SingleValue } from "react-select";
import { useElementDimensions } from "~/utils/hooks";
import frontendModelProviders from "~/modelProviders/frontendModelProviders";
import { type Model } from "~/modelProviders/types";
import { keyForModel } from "~/utils/utils";
const modelOptions: { label: string; value: Model }[] = [];
for (const [_, providerValue] of Object.entries(frontendModelProviders)) {
for (const [_, modelValue] of Object.entries(providerValue.models)) {
modelOptions.push({
label: keyForModel(modelValue),
value: modelValue,
});
}
}
export const ModelSearch = ({
selectedModel,
setSelectedModel,
}: {
selectedModel: Model;
setSelectedModel: (model: Model) => void;
}) => {
const handleSelection = useCallback(
(option: SingleValue<{ label: string; value: Model }>) => {
if (!option) return;
setSelectedModel(option.value);
},
[setSelectedModel],
);
const selectedOption = modelOptions.find((option) => option.label === keyForModel(selectedModel));
const [containerRef, containerDimensions] = useElementDimensions();
return (
<VStack ref={containerRef as LegacyRef<HTMLDivElement>} w="full">
<Text>Browse Models</Text>
<Select
styles={{ control: (provided) => ({ ...provided, width: containerDimensions?.width }) }}
value={selectedOption}
options={modelOptions}
onChange={handleSelection}
/>
</VStack>
);
};

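ModelSearch flattens every provider's models into one option list. A rough sketch of the registry shape those nested `Object.entries` loops imply (assumed — the real definitions live in `~/modelProviders` and aren't part of this diff):

// Hypothetical shape inferred from the lookups in this diff, not the actual source.
import { type Model } from "~/modelProviders/types";

type FrontendModelProvider = {
  models: Record<string, Model>; // keyed by model name, e.g. "gpt-4"
};

// Keyed by provider id, e.g. "openai/ChatCompletion" or "replicate/llama2".
type FrontendModelProviders = Record<string, FrontendModelProvider>;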
View File

@@ -0,0 +1,102 @@
import {
VStack,
Text,
HStack,
type StackProps,
GridItem,
SimpleGrid,
Link,
} from "@chakra-ui/react";
import { type Model } from "~/modelProviders/types";
export const ModelStatsCard = ({ label, model }: { label: string; model: Model }) => {
return (
<VStack w="full" align="start">
<Text fontWeight="bold" fontSize="sm" textTransform="uppercase">
{label}
</Text>
<VStack w="full" spacing={6} bgColor="gray.100" p={4} borderRadius={4}>
<HStack w="full" align="flex-start">
<Text flex={1} fontSize="lg">
<Text as="span" color="gray.600">
{model.provider} /{" "}
</Text>
<Text as="span" fontWeight="bold" color="gray.900">
{model.name}
</Text>
</Text>
<Link
href={model.learnMoreUrl}
isExternal
color="blue.500"
fontWeight="bold"
fontSize="sm"
ml={2}
>
Learn More
</Link>
</HStack>
<SimpleGrid
w="full"
justifyContent="space-between"
alignItems="flex-start"
fontSize="sm"
columns={{ base: 2, md: 4 }}
>
<SelectedModelLabeledInfo label="Context Window" info={model.contextWindow} />
{model.promptTokenPrice && (
<SelectedModelLabeledInfo
label="Input"
info={
<Text>
${(model.promptTokenPrice * 1000).toFixed(3)}
<Text color="gray.500"> / 1K tokens</Text>
</Text>
}
/>
)}
{model.completionTokenPrice && (
<SelectedModelLabeledInfo
label="Output"
info={
<Text>
${(model.completionTokenPrice * 1000).toFixed(3)}
<Text color="gray.500"> / 1K tokens</Text>
</Text>
}
/>
)}
{model.pricePerSecond && (
<SelectedModelLabeledInfo
label="Price"
info={
<Text>
${model.pricePerSecond.toFixed(3)}
<Text color="gray.500"> / second</Text>
</Text>
}
/>
)}
<SelectedModelLabeledInfo label="Speed" info={<Text>{model.speed}</Text>} />
</SimpleGrid>
</VStack>
</VStack>
);
};
const SelectedModelLabeledInfo = ({
label,
info,
...props
}: {
label: string;
info: string | number | React.ReactElement;
} & StackProps) => (
<GridItem>
<VStack alignItems="flex-start" {...props}>
<Text fontWeight="bold">{label}</Text>
<Text>{info}</Text>
</VStack>
</GridItem>
);

View File

@@ -0,0 +1,48 @@
import { Box, Flex, Icon, Spinner } from "@chakra-ui/react";
import { BsPlus } from "react-icons/bs";
import { api } from "~/utils/api";
import { useExperiment, useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
import { cellPadding } from "../constants";
import { ActionButton } from "./ScenariosHeader";
export default function AddVariantButton() {
const experiment = useExperiment();
const mutation = api.promptVariants.create.useMutation();
const utils = api.useContext();
const [onClick, loading] = useHandledAsyncCallback(async () => {
if (!experiment.data) return;
await mutation.mutateAsync({
experimentId: experiment.data.id,
});
await utils.promptVariants.list.invalidate();
}, [mutation]);
const { canModify } = useExperimentAccess();
if (!canModify) return <Box w={cellPadding.x} />;
return (
<Flex w="100%" justifyContent="flex-end">
<ActionButton
onClick={onClick}
leftIcon={<Icon as={loading ? Spinner : BsPlus} boxSize={6} mr={loading ? 1 : 0} />}
>
Add Variant
</ActionButton>
{/* <Button
alignItems="center"
justifyContent="center"
fontWeight="normal"
bgColor="transparent"
_hover={{ bgColor: "gray.100" }}
px={cellPadding.x}
onClick={onClick}
height="unset"
minH={headerMinHeight}
>
<Icon as={loading ? Spinner : BsPlus} boxSize={6} mr={loading ? 1 : 0} />
<Text display={{ base: "none", md: "flex" }}>Add Variant</Text>
</Button> */}
</Flex>
);
}

View File

@@ -11,14 +11,16 @@ import {
   FormLabel,
   Select,
   FormHelperText,
+  Code,
 } from "@chakra-ui/react";
-import { type Evaluation, EvaluationMatchType } from "@prisma/client";
+import { type Evaluation, EvalType } from "@prisma/client";
 import { useCallback, useState } from "react";
 import { BsPencil, BsX } from "react-icons/bs";
 import { api } from "~/utils/api";
 import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
+import AutoResizeTextArea from "../AutoResizeTextArea";
-type EvalValues = Pick<Evaluation, "name" | "matchString" | "matchType">;
+type EvalValues = Pick<Evaluation, "label" | "value" | "evalType">;
 export function EvaluationEditor(props: {
   evaluation: Evaluation | null;
@@ -27,35 +29,35 @@ export function EvaluationEditor(props: {
   onCancel: () => void;
 }) {
   const [values, setValues] = useState<EvalValues>({
-    name: props.evaluation?.name ?? props.defaultName ?? "",
-    matchString: props.evaluation?.matchString ?? "",
-    matchType: props.evaluation?.matchType ?? "CONTAINS",
+    label: props.evaluation?.label ?? props.defaultName ?? "",
+    value: props.evaluation?.value ?? "",
+    evalType: props.evaluation?.evalType ?? "CONTAINS",
   });
   return (
     <VStack borderTopWidth={1} borderColor="gray.200" py={4}>
       <HStack w="100%">
         <FormControl flex={1}>
-          <FormLabel fontSize="sm">Evaluation Name</FormLabel>
+          <FormLabel fontSize="sm">Eval Name</FormLabel>
           <Input
             size="sm"
-            value={values.name}
-            onChange={(e) => setValues((values) => ({ ...values, name: e.target.value }))}
+            value={values.label}
+            onChange={(e) => setValues((values) => ({ ...values, label: e.target.value }))}
           />
         </FormControl>
         <FormControl flex={1}>
-          <FormLabel fontSize="sm">Match Type</FormLabel>
+          <FormLabel fontSize="sm">Eval Type</FormLabel>
           <Select
             size="sm"
-            value={values.matchType}
+            value={values.evalType}
             onChange={(e) =>
               setValues((values) => ({
                 ...values,
-                matchType: e.target.value as EvaluationMatchType,
+                evalType: e.target.value as EvalType,
               }))
             }
           >
-            {Object.values(EvaluationMatchType).map((type) => (
+            {Object.values(EvalType).map((type) => (
               <option key={type} value={type}>
                 {type}
               </option>
@@ -63,17 +65,37 @@ export function EvaluationEditor(props: {
           </Select>
         </FormControl>
       </HStack>
-      <FormControl>
-        <FormLabel fontSize="sm">Match String</FormLabel>
-        <Input
-          size="sm"
-          value={values.matchString}
-          onChange={(e) => setValues((values) => ({ ...values, matchString: e.target.value }))}
-        />
-        <FormHelperText>
-          This string will be interpreted as a regex and checked against each model output.
-        </FormHelperText>
-      </FormControl>
+      {["CONTAINS", "DOES_NOT_CONTAIN"].includes(values.evalType) && (
+        <FormControl>
+          <FormLabel fontSize="sm">Match String</FormLabel>
+          <Input
+            size="sm"
+            value={values.value}
+            onChange={(e) => setValues((values) => ({ ...values, value: e.target.value }))}
+          />
+          <FormHelperText>
+            This string will be interpreted as a regex and checked against each model output. You
+            can include scenario variables using <Code>{"{{curly_braces}}"}</Code>
+          </FormHelperText>
+        </FormControl>
+      )}
+      {values.evalType === "GPT4_EVAL" && (
+        <FormControl pt={2}>
+          <FormLabel fontSize="sm">GPT4 Instructions</FormLabel>
+          <AutoResizeTextArea
+            size="sm"
+            value={values.value}
+            onChange={(e) => setValues((values) => ({ ...values, value: e.target.value }))}
+            minRows={3}
+          />
+          <FormHelperText>
+            Give instructions to GPT-4 for how to evaluate your prompt. It will have access to the
+            full scenario as well as the output it is evaluating. It will <strong>not</strong> have
+            access to the specific prompt variant, so be sure to be clear about the task you want it
+            to perform.
+          </FormHelperText>
+        </FormControl>
+      )}
       <HStack alignSelf="flex-end">
         <Button size="sm" onClick={props.onCancel} colorScheme="gray">
           Cancel
@@ -125,6 +147,7 @@ export default function EditEvaluations() {
     }
     await utils.evaluations.list.invalidate();
     await utils.promptVariants.stats.invalidate();
+    await utils.scenarioVariantCells.get.invalidate();
   }, []);
   const onCancel = useCallback(() => {
@@ -156,9 +179,9 @@ export default function EditEvaluations() {
           align="center"
           key={evaluation.id}
         >
-          <Text fontWeight="bold">{evaluation.name}</Text>
+          <Text fontWeight="bold">{evaluation.label}</Text>
           <Text flex={1}>
-            {evaluation.matchType}: &quot;{evaluation.matchString}&quot;
+            {evaluation.evalType}: &quot;{evaluation.value}&quot;
           </Text>
           <Button
             variant="unstyled"

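Per the helper text above, CONTAINS and DOES_NOT_CONTAIN match strings are treated as regexes and may reference scenario variables with `{{curly_braces}}`. A rough sketch of how such an eval could be applied (assumed logic — the real evaluation runs server-side and isn't part of this diff):

// Hypothetical sketch; function and parameter names are illustrative.
const applyMatchEval = (
  evalType: "CONTAINS" | "DOES_NOT_CONTAIN",
  matchValue: string, // e.g. "Hello, {{name}}"
  scenarioVars: Record<string, string>,
  output: string,
): boolean => {
  // Substitute {{variable}} references with the scenario's values.
  const pattern = matchValue.replace(/{{\s*(\w+)\s*}}/g, (_, key: string) => scenarioVars[key] ?? "");
  const matched = new RegExp(pattern).test(output);
  return evalType === "CONTAINS" ? matched : !matched;
};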
View File

@@ -1,4 +1,4 @@
-import { Text, Button, HStack, Heading, Icon, Input, Stack, Code } from "@chakra-ui/react";
+import { Text, Button, HStack, Heading, Icon, Input, Stack } from "@chakra-ui/react";
 import { useState } from "react";
 import { BsCheck, BsX } from "react-icons/bs";
 import { api } from "~/utils/api";
@@ -36,8 +36,7 @@ export default function EditScenarioVars() {
       <Heading size="sm">Scenario Variables</Heading>
       <Stack spacing={2}>
         <Text fontSize="sm">
-          Scenario variables can be used in your prompt variants as well as evaluations. Reference
-          them using <Code>{"{{curly_braces}}"}</Code>.
+          Scenario variables can be used in your prompt variants as well as evaluations.
         </Text>
         <HStack spacing={0}>
           <Input

View File

@@ -0,0 +1,47 @@
import { FormLabel, FormControl, type TextareaProps } from "@chakra-ui/react";
import { useState } from "react";
import AutoResizeTextArea from "../AutoResizeTextArea";
export const FloatingLabelInput = ({
label,
value,
...props
}: { label: string; value: string } & TextareaProps) => {
const [isFocused, setIsFocused] = useState(false);
return (
<FormControl position="relative">
<FormLabel
position="absolute"
left="10px"
top={isFocused || !!value ? 0 : 3}
transform={isFocused || !!value ? "translateY(-50%)" : "translateY(0)"}
fontSize={isFocused || !!value ? "12px" : "16px"}
transition="all 0.15s"
zIndex="5"
bg="white"
px={1}
lineHeight="1"
pointerEvents="none"
color={isFocused ? "blue.500" : "gray.500"}
>
{label}
</FormLabel>
<AutoResizeTextArea
px={3}
pt={3}
pb={2}
onFocus={() => setIsFocused(true)}
onBlur={() => setIsFocused(false)}
borderRadius="md"
borderColor={isFocused ? "blue.500" : "gray.400"}
autoComplete="off"
value={value}
maxHeight={32}
overflowY="auto"
overflowX="hidden"
{...props}
/>
</FormControl>
);
};

View File

@@ -1,53 +0,0 @@
import { Button, type ButtonProps, HStack, Spinner, Icon } from "@chakra-ui/react";
import { BsPlus } from "react-icons/bs";
import { api } from "~/utils/api";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
// Extracted Button styling into reusable component
const StyledButton = ({ children, onClick }: ButtonProps) => (
<Button
fontWeight="normal"
bgColor="transparent"
_hover={{ bgColor: "gray.100" }}
px={2}
onClick={onClick}
>
{children}
</Button>
);
export default function NewScenarioButton() {
const experiment = useExperiment();
const mutation = api.scenarios.create.useMutation();
const utils = api.useContext();
const [onClick] = useHandledAsyncCallback(async () => {
if (!experiment.data) return;
await mutation.mutateAsync({
experimentId: experiment.data.id,
});
await utils.scenarios.list.invalidate();
}, [mutation]);
const [onAutogenerate, autogenerating] = useHandledAsyncCallback(async () => {
if (!experiment.data) return;
await mutation.mutateAsync({
experimentId: experiment.data.id,
autogenerate: true,
});
await utils.scenarios.list.invalidate();
}, [mutation]);
return (
<HStack spacing={2}>
<StyledButton onClick={onClick}>
<Icon as={BsPlus} boxSize={6} />
Add Scenario
</StyledButton>
<StyledButton onClick={onAutogenerate}>
<Icon as={autogenerating ? Spinner : BsPlus} boxSize={6} mr={autogenerating ? 1 : 0} />
Autogenerate Scenario
</StyledButton>
</HStack>
);
}

View File

@@ -1,37 +0,0 @@
import { Button, Icon, Spinner } from "@chakra-ui/react";
import { BsPlus } from "react-icons/bs";
import { api } from "~/utils/api";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
import { cellPadding, headerMinHeight } from "../constants";
export default function NewVariantButton() {
const experiment = useExperiment();
const mutation = api.promptVariants.create.useMutation();
const utils = api.useContext();
const [onClick, loading] = useHandledAsyncCallback(async () => {
if (!experiment.data) return;
await mutation.mutateAsync({
experimentId: experiment.data.id,
});
await utils.promptVariants.list.invalidate();
}, [mutation]);
return (
<Button
w="100%"
alignItems="center"
justifyContent="center"
fontWeight="normal"
bgColor="transparent"
_hover={{ bgColor: "gray.100" }}
px={cellPadding.x}
onClick={onClick}
height="unset"
minH={headerMinHeight}
>
<Icon as={loading ? Spinner : BsPlus} boxSize={6} mr={loading ? 1 : 0} />
Add Variant
</Button>
);
}

View File

@@ -1,5 +1,6 @@
-import { Button, HStack, Icon } from "@chakra-ui/react";
+import { Button, HStack, Icon, Tooltip } from "@chakra-ui/react";
 import { BsArrowClockwise } from "react-icons/bs";
+import { useExperimentAccess } from "~/utils/hooks";
 export const CellOptions = ({
   refetchingOutput,
@@ -8,25 +9,28 @@ export const CellOptions = ({
   refetchingOutput: boolean;
   refetchOutput: () => void;
 }) => {
+  const { canModify } = useExperimentAccess();
   return (
     <HStack justifyContent="flex-end" w="full">
-      {!refetchingOutput && (
-        <Button
-          size="xs"
-          w={4}
-          h={4}
-          py={4}
-          px={4}
-          minW={0}
-          borderRadius={8}
-          color="gray.500"
-          variant="ghost"
-          cursor="pointer"
-          onClick={refetchOutput}
-          aria-label="refetch output"
-        >
-          <Icon as={BsArrowClockwise} boxSize={4} />
-        </Button>
+      {!refetchingOutput && canModify && (
+        <Tooltip label="Refetch output" aria-label="refetch output">
+          <Button
+            size="xs"
+            w={4}
+            h={4}
+            py={4}
+            px={4}
+            minW={0}
+            borderRadius={8}
+            color="gray.500"
+            variant="ghost"
+            cursor="pointer"
+            onClick={refetchOutput}
+            aria-label="refetch output"
+          >
+            <Icon as={BsArrowClockwise} boxSize={4} />
+          </Button>
+        </Tooltip>
       )}
     </HStack>
   );

View File

@@ -1,16 +1,16 @@
 import { api } from "~/utils/api";
 import { type PromptVariant, type Scenario } from "../types";
-import { Spinner, Text, Box, Center, Flex, VStack } from "@chakra-ui/react";
+import { Spinner, Text, Center, VStack } from "@chakra-ui/react";
 import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
 import SyntaxHighlighter from "react-syntax-highlighter";
 import { docco } from "react-syntax-highlighter/dist/cjs/styles/hljs";
 import stringify from "json-stringify-pretty-compact";
 import { type ReactElement, useState, useEffect } from "react";
-import { type ChatCompletion } from "openai/resources/chat";
 import useSocket from "~/utils/useSocket";
 import { OutputStats } from "./OutputStats";
 import { ErrorHandler } from "./ErrorHandler";
 import { CellOptions } from "./CellOptions";
+import frontendModelProviders from "~/modelProviders/frontendModelProviders";
 export default function OutputCell({
   scenario,
@@ -33,36 +33,42 @@ export default function OutputCell({
   if (!templateHasVariables) disabledReason = "Add a value to the scenario variables to see output";
-  // if (variant.config === null || Object.keys(variant.config).length === 0)
-  //   disabledReason = "Save your prompt variant to see output";
   const [refetchInterval, setRefetchInterval] = useState(0);
   const { data: cell, isLoading: queryLoading } = api.scenarioVariantCells.get.useQuery(
     { scenarioId: scenario.id, variantId: variant.id },
     { refetchInterval },
   );
-  const { mutateAsync: hardRefetchMutate, isLoading: refetchingOutput } =
-    api.scenarioVariantCells.forceRefetch.useMutation();
-  const [hardRefetch] = useHandledAsyncCallback(async () => {
+  const provider =
+    frontendModelProviders[variant.modelProvider as keyof typeof frontendModelProviders];
+  type OutputSchema = Parameters<typeof provider.normalizeOutput>[0];
+  const { mutateAsync: hardRefetchMutate } = api.scenarioVariantCells.forceRefetch.useMutation();
+  const [hardRefetch, hardRefetching] = useHandledAsyncCallback(async () => {
     await hardRefetchMutate({ scenarioId: scenario.id, variantId: variant.id });
     await utils.scenarioVariantCells.get.invalidate({
       scenarioId: scenario.id,
       variantId: variant.id,
     });
+    await utils.promptVariants.stats.invalidate({
+      variantId: variant.id,
+    });
   }, [hardRefetchMutate, scenario.id, variant.id]);
-  const fetchingOutput = queryLoading || refetchingOutput;
+  const fetchingOutput = queryLoading || hardRefetching;
   const awaitingOutput =
-    !cell || cell.retrievalStatus === "PENDING" || cell.retrievalStatus === "IN_PROGRESS";
+    !cell ||
+    cell.retrievalStatus === "PENDING" ||
+    cell.retrievalStatus === "IN_PROGRESS" ||
+    hardRefetching;
   useEffect(() => setRefetchInterval(awaitingOutput ? 1000 : 0), [awaitingOutput]);
   const modelOutput = cell?.modelOutput;
   // Disconnect from socket if we're not streaming anymore
-  const streamedMessage = useSocket(cell?.streamingChannel);
-  const streamedContent = streamedMessage?.choices?.[0]?.message?.content;
+  const streamedMessage = useSocket<OutputSchema>(cell?.streamingChannel);
   if (!vars) return null;
@@ -81,25 +87,26 @@ export default function OutputCell({
     return <ErrorHandler cell={cell} refetchOutput={hardRefetch} />;
   }
-  const response = modelOutput?.output as unknown as ChatCompletion;
-  const message = response?.choices?.[0]?.message;
-  if (modelOutput && message?.function_call) {
-    const rawArgs = message.function_call.arguments ?? "null";
-    let parsedArgs: string;
-    try {
-      parsedArgs = JSON.parse(rawArgs);
-      // eslint-disable-next-line @typescript-eslint/no-explicit-any
-    } catch (e: any) {
-      parsedArgs = `Failed to parse arguments as JSON: '${rawArgs}' ERROR: ${e.message as string}`;
-    }
+  const normalizedOutput = modelOutput
+    ? provider.normalizeOutput(modelOutput.output)
+    : streamedMessage
+    ? provider.normalizeOutput(streamedMessage)
+    : null;
+  if (modelOutput && normalizedOutput?.type === "json") {
     return (
-      <Box fontSize="xs" width="100%" flexWrap="wrap" overflowX="auto">
-        <VStack w="full" spacing={0}>
-          <CellOptions refetchingOutput={refetchingOutput} refetchOutput={hardRefetch} />
+      <VStack
+        w="100%"
+        h="100%"
+        fontSize="xs"
+        flexWrap="wrap"
+        overflowX="hidden"
+        justifyContent="space-between"
+      >
+        <VStack w="full" flex={1} spacing={0}>
+          <CellOptions refetchingOutput={hardRefetching} refetchOutput={hardRefetch} />
           <SyntaxHighlighter
-            customStyle={{ overflowX: "unset" }}
+            customStyle={{ overflowX: "unset", width: "100%", flex: 1 }}
             language="json"
             style={docco}
             lineProps={{
@@ -107,32 +114,23 @@ export default function OutputCell({
            }}
            wrapLines
          >
-            {stringify(
-              {
-                function: message.function_call.name,
-                args: parsedArgs,
-              },
-              { maxLength: 40 },
-            )}
+            {stringify(normalizedOutput.value, { maxLength: 40 })}
          </SyntaxHighlighter>
        </VStack>
-        <OutputStats model={variant.model} modelOutput={modelOutput} scenario={scenario} />
-      </Box>
+        <OutputStats modelOutput={modelOutput} scenario={scenario} />
+      </VStack>
     );
   }
-  const contentToDisplay =
-    message?.content ?? streamedContent ?? JSON.stringify(modelOutput?.output);
+  const contentToDisplay = (normalizedOutput?.type === "text" && normalizedOutput.value) || "";
   return (
-    <Flex w="100%" h="100%" direction="column" justifyContent="space-between" whiteSpace="pre-wrap">
+    <VStack w="100%" h="100%" justifyContent="space-between" whiteSpace="pre-wrap">
       <VStack w="full" alignItems="flex-start" spacing={0}>
-        <CellOptions refetchingOutput={refetchingOutput} refetchOutput={hardRefetch} />
+        <CellOptions refetchingOutput={hardRefetching} refetchOutput={hardRefetch} />
         <Text>{contentToDisplay}</Text>
       </VStack>
-      {modelOutput && (
-        <OutputStats model={variant.model} modelOutput={modelOutput} scenario={scenario} />
-      )}
-    </Flex>
+      {modelOutput && <OutputStats modelOutput={modelOutput} scenario={scenario} />}
+    </VStack>
   );
 }

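OutputCell now routes all output parsing through the provider's `normalizeOutput`, and the branches above imply a tagged union with `json` and `text` variants. A sketch of what the OpenAI provider's implementation might look like (assumed shape — the provider source isn't in this diff):

// Hypothetical sketch inferred from the `normalizedOutput?.type` branches above.
import { type ChatCompletion } from "openai/resources/chat";

type NormalizedOutput =
  | { type: "json"; value: unknown } // rendered with SyntaxHighlighter
  | { type: "text"; value: string };

const normalizeOutput = (output: ChatCompletion): NormalizedOutput => {
  const message = output?.choices?.[0]?.message;
  if (message?.function_call) {
    let args: unknown = message.function_call.arguments;
    try {
      args = JSON.parse(message.function_call.arguments ?? "null");
    } catch {
      // Keep the raw string if the model emitted invalid JSON.
    }
    return { type: "json", value: { function: message.function_call.name, args } };
  }
  return { type: "text", value: message?.content ?? "" };
};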
View File

@@ -1,62 +1,56 @@
-import { type ModelOutput } from "@prisma/client";
-import { type SupportedModel } from "~/server/types";
 import { type Scenario } from "../types";
-import { useExperiment } from "~/utils/hooks";
-import { api } from "~/utils/api";
-import { calculateTokenCost } from "~/utils/calculateTokenCost";
-import { evaluateOutput } from "~/server/utils/evaluateOutput";
-import { HStack, Icon, Text } from "@chakra-ui/react";
+import { type RouterOutputs } from "~/utils/api";
+import { HStack, Icon, Text, Tooltip } from "@chakra-ui/react";
 import { BsCheck, BsClock, BsCurrencyDollar, BsX } from "react-icons/bs";
 import { CostTooltip } from "~/components/tooltip/CostTooltip";
-const SHOW_COST = true;
 const SHOW_TIME = true;
 export const OutputStats = ({
-  model,
   modelOutput,
   scenario,
 }: {
-  model: SupportedModel | string | null;
-  modelOutput: ModelOutput;
+  modelOutput: NonNullable<
+    NonNullable<RouterOutputs["scenarioVariantCells"]["get"]>["modelOutput"]
+  >;
   scenario: Scenario;
 }) => {
   const timeToComplete = modelOutput.timeToComplete;
-  const experiment = useExperiment();
-  const evals =
-    api.evaluations.list.useQuery({ experimentId: experiment.data?.id ?? "" }).data ?? [];
   const promptTokens = modelOutput.promptTokens;
   const completionTokens = modelOutput.completionTokens;
-  const promptCost = promptTokens && model ? calculateTokenCost(model, promptTokens) : 0;
-  const completionCost =
-    completionTokens && model ? calculateTokenCost(model, completionTokens, true) : 0;
-  const cost = promptCost + completionCost;
   return (
-    <HStack align="center" color="gray.500" fontSize="2xs" mt={{ base: 0, md: 1 }}>
+    <HStack w="full" align="center" color="gray.500" fontSize="2xs" mt={{ base: 0, md: 1 }}>
       <HStack flex={1}>
-        {evals.map((evaluation) => {
-          const passed = evaluateOutput(modelOutput, scenario, evaluation);
+        {modelOutput.outputEvaluation.map((evaluation) => {
+          const passed = evaluation.result > 0.5;
           return (
-            <HStack spacing={0} key={evaluation.id}>
-              <Text>{evaluation.name}</Text>
-              <Icon
-                as={passed ? BsCheck : BsX}
-                color={passed ? "green.500" : "red.500"}
-                boxSize={6}
-              />
-            </HStack>
+            <Tooltip
+              isDisabled={!evaluation.details}
+              label={evaluation.details}
+              key={evaluation.id}
+            >
+              <HStack spacing={0}>
+                <Text>{evaluation.evaluation.label}</Text>
+                <Icon
+                  as={passed ? BsCheck : BsX}
+                  color={passed ? "green.500" : "red.500"}
+                  boxSize={6}
+                />
+              </HStack>
+            </Tooltip>
           );
         })}
       </HStack>
-      {SHOW_COST && (
-        <CostTooltip promptTokens={promptTokens} completionTokens={completionTokens} cost={cost}>
+      {modelOutput.cost && (
+        <CostTooltip
+          promptTokens={promptTokens}
+          completionTokens={completionTokens}
+          cost={modelOutput.cost}
+        >
           <HStack spacing={0}>
             <Icon as={BsCurrencyDollar} />
-            <Text mr={1}>{cost.toFixed(3)}</Text>
+            <Text mr={1}>{modelOutput.cost.toFixed(3)}</Text>
           </HStack>
         </CostTooltip>
       )}

View File

@@ -1,15 +1,15 @@
 import { type DragEvent } from "react";
 import { api } from "~/utils/api";
-import { isEqual } from "lodash";
+import { isEqual } from "lodash-es";
 import { type Scenario } from "./types";
-import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
+import { useExperiment, useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
 import { useState } from "react";
 import { Box, Button, Flex, HStack, Icon, Spinner, Stack, Tooltip, VStack } from "@chakra-ui/react";
 import { cellPadding } from "../constants";
 import { BsX } from "react-icons/bs";
 import { RiDraggable } from "react-icons/ri";
-import AutoResizeTextArea from "../AutoResizeTextArea";
+import { FloatingLabelInput } from "./FloatingLabelInput";
 export default function ScenarioEditor({
   scenario,
@@ -19,6 +19,8 @@ export default function ScenarioEditor({
   hovered: boolean;
   canHide: boolean;
 }) {
+  const { canModify } = useExperimentAccess();
   const savedValues = scenario.variableValues as Record<string, string>;
   const utils = api.useContext();
   const [isDragTarget, setIsDragTarget] = useState(false);
@@ -72,8 +74,9 @@ export default function ScenarioEditor({
   return (
     <HStack
       alignItems="flex-start"
-      pr={cellPadding.x}
+      px={cellPadding.x}
       py={cellPadding.y}
+      spacing={0}
       height="100%"
       draggable={!variableInputHovered}
       onDragStart={(e) => {
@@ -93,39 +96,43 @@ export default function ScenarioEditor({
       onDrop={onReorder}
       backgroundColor={isDragTarget ? "gray.100" : "transparent"}
     >
-      <Stack alignSelf="flex-start" opacity={props.hovered ? 1 : 0} spacing={0}>
-        {props.canHide && (
-          <>
-            <Tooltip label="Hide scenario" hasArrow>
-              {/* for some reason the tooltip can't position itself properly relative to the icon without the wrapping box */}
-              <Button
-                variant="unstyled"
-                color="gray.400"
-                height="unset"
-                width="unset"
-                minW="unset"
-                onClick={onHide}
-                _hover={{
-                  color: "gray.800",
-                  cursor: "pointer",
-                }}
-              >
-                <Icon as={hidingInProgress ? Spinner : BsX} boxSize={6} />
-              </Button>
-            </Tooltip>
-            <Icon
-              as={RiDraggable}
-              boxSize={6}
-              color="gray.400"
-              _hover={{ color: "gray.800", cursor: "pointer" }}
-            />
-          </>
-        )}
-      </Stack>
+      {canModify && props.canHide && (
+        <Stack
+          alignSelf="flex-start"
+          opacity={props.hovered ? 1 : 0}
+          spacing={0}
+          ml={-cellPadding.x}
+        >
+          <Tooltip label="Hide scenario" hasArrow>
+            {/* for some reason the tooltip can't position itself properly relative to the icon without the wrapping box */}
+            <Button
+              variant="unstyled"
+              color="gray.400"
+              height="unset"
+              width="unset"
+              minW="unset"
+              onClick={onHide}
+              _hover={{
+                color: "gray.800",
+                cursor: "pointer",
+              }}
+            >
+              <Icon as={hidingInProgress ? Spinner : BsX} boxSize={hidingInProgress ? 4 : 6} />
+            </Button>
+          </Tooltip>
+          <Icon
+            as={RiDraggable}
+            boxSize={6}
+            color="gray.400"
+            _hover={{ color: "gray.800", cursor: "pointer" }}
+          />
+        </Stack>
+      )}
       {variableLabels.length === 0 ? (
         <Box color="gray.500">{vars.data ? "No scenario variables configured" : "Loading..."}</Box>
       ) : (
-        <VStack spacing={1}>
+        <VStack spacing={4} flex={1} py={2}>
           {variableLabels.map((key) => {
             const value = values[key] ?? "";
             const layoutDirection = value.length > 20 ? "column" : "row";
@@ -137,29 +144,14 @@ export default function ScenarioEditor({
               flexWrap="wrap"
               width="full"
             >
-              <Box
-                bgColor="blue.100"
-                color="blue.600"
-                px={1}
-                my="3px"
-                fontSize="xs"
-                fontWeight="bold"
-              >
-                {key}
-              </Box>
-              <AutoResizeTextArea
-                px={2}
-                py={1}
-                placeholder="empty"
-                borderRadius="sm"
-                fontSize="sm"
-                lineHeight={1.2}
+              <FloatingLabelInput
+                label={key}
+                isDisabled={!canModify}
+                style={{ width: "100%" }}
                 value={value}
                 onChange={(e) => {
                   setValues((prev) => ({ ...prev, [key]: e.target.value }));
                 }}
-                maxH="32"
-                overflowY="auto"
                 onKeyDown={(e) => {
                   if (e.key === "Enter" && (e.metaKey || e.ctrlKey)) {
                     e.preventDefault();
@@ -167,12 +159,6 @@ export default function ScenarioEditor({
                     onSave();
                   }
                 }}
-                resize="none"
-                overflow="hidden"
-                flex={layoutDirection === "row" ? 1 : undefined}
-                borderColor={hasChanged ? "blue.300" : "transparent"}
-                _hover={{ borderColor: "gray.300" }}
-                _focus={{ borderColor: "blue.500", outline: "none", bg: "white" }}
                 onMouseEnter={() => setVariableInputHovered(true)}
                 onMouseLeave={() => setVariableInputHovered(false)}
               />

View File

@@ -4,11 +4,13 @@ import { cellPadding } from "../constants";
 import OutputCell from "./OutputCell/OutputCell";
 import ScenarioEditor from "./ScenarioEditor";
 import type { PromptVariant, Scenario } from "./types";
+import { borders } from "./styles";
 const ScenarioRow = (props: {
   scenario: Scenario;
   variants: PromptVariant[];
   canHide: boolean;
+  rowStart: number;
 }) => {
   const [isHovered, setIsHovered] = useState(false);
@@ -21,15 +23,21 @@ const ScenarioRow = (props: {
         onMouseLeave={() => setIsHovered(false)}
         sx={isHovered ? highlightStyle : undefined}
         borderLeftWidth={1}
+        {...borders}
+        rowStart={props.rowStart}
+        colStart={1}
       >
         <ScenarioEditor scenario={props.scenario} hovered={isHovered} canHide={props.canHide} />
       </GridItem>
-      {props.variants.map((variant) => (
+      {props.variants.map((variant, i) => (
         <GridItem
           key={variant.id}
           onMouseEnter={() => setIsHovered(true)}
           onMouseLeave={() => setIsHovered(false)}
           sx={isHovered ? highlightStyle : undefined}
+          rowStart={props.rowStart}
+          colStart={i + 2}
+          {...borders}
         >
           <Box h="100%" w="100%" px={cellPadding.x} py={cellPadding.y}>
             <OutputCell key={variant.id} scenario={props.scenario} variant={variant} />

View File

@@ -0,0 +1,73 @@
import {
Button,
type ButtonProps,
HStack,
Text,
Icon,
Menu,
MenuButton,
MenuList,
MenuItem,
IconButton,
Spinner,
} from "@chakra-ui/react";
import { cellPadding } from "../constants";
import { useExperiment, useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
import { BsGear, BsPencil, BsPlus, BsStars } from "react-icons/bs";
import { useAppStore } from "~/state/store";
import { api } from "~/utils/api";
export const ActionButton = (props: ButtonProps) => (
<Button size="sm" variant="ghost" color="gray.600" {...props} />
);
export const ScenariosHeader = (props: { numScenarios: number }) => {
const openDrawer = useAppStore((s) => s.openDrawer);
const { canModify } = useExperimentAccess();
const experiment = useExperiment();
const createScenarioMutation = api.scenarios.create.useMutation();
const utils = api.useContext();
const [onAddScenario, loading] = useHandledAsyncCallback(
async (autogenerate: boolean) => {
if (!experiment.data) return;
await createScenarioMutation.mutateAsync({
experimentId: experiment.data.id,
autogenerate,
});
await utils.scenarios.list.invalidate();
},
[createScenarioMutation],
);
return (
<HStack w="100%" pb={cellPadding.y} pt={0} align="center" spacing={0}>
<Text fontSize={16} fontWeight="bold">
Scenarios ({props.numScenarios})
</Text>
{canModify && (
<Menu>
<MenuButton mt={1}>
<IconButton
variant="ghost"
aria-label="Edit Scenarios"
icon={<Icon as={loading ? Spinner : BsGear} />}
/>
</MenuButton>
<MenuList fontSize="md">
<MenuItem icon={<Icon as={BsPlus} boxSize={6} />} onClick={() => onAddScenario(false)}>
Add Scenario
</MenuItem>
<MenuItem icon={<BsStars />} onClick={() => onAddScenario(true)}>
Autogenerate Scenario
</MenuItem>
<MenuItem icon={<BsPencil />} onClick={openDrawer}>
Edit Vars
</MenuItem>
</MenuList>
</Menu>
)}
</HStack>
);
};

View File

@@ -1,12 +1,12 @@
-import { Box, Button, HStack, Tooltip, VStack, useToast } from "@chakra-ui/react";
+import { Box, Button, HStack, Spinner, Tooltip, useToast, Text } from "@chakra-ui/react";
 import { useRef, useEffect, useState, useCallback } from "react";
-import { useHandledAsyncCallback, useModifierKeyLabel } from "~/utils/hooks";
+import { useExperimentAccess, useHandledAsyncCallback, useModifierKeyLabel } from "~/utils/hooks";
 import { type PromptVariant } from "./types";
 import { api } from "~/utils/api";
 import { useAppStore } from "~/state/store";
-import { editorBackground } from "~/state/sharedVariantEditor.slice";
 export default function VariantEditor(props: { variant: PromptVariant }) {
+  const { canModify } = useExperimentAccess();
   const monaco = useAppStore.use.sharedVariantEditor.monaco();
   const editorRef = useRef<ReturnType<NonNullable<typeof monaco>["editor"]["create"]> | null>(null);
   const [editorId] = useState(() => `editor_${Math.random().toString(36).substring(7)}`);
@@ -22,13 +22,19 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
     setIsChanged(currentFn.length > 0 && currentFn !== lastSavedFn);
   }, [lastSavedFn]);
-  useEffect(checkForChanges, [checkForChanges, lastSavedFn]);
+  const matchUpdatedSavedFn = useCallback(() => {
+    if (!editorRef.current) return;
+    editorRef.current.setValue(lastSavedFn);
+    setIsChanged(false);
+  }, [lastSavedFn]);
+  useEffect(matchUpdatedSavedFn, [matchUpdatedSavedFn, lastSavedFn]);
   const replaceVariant = api.promptVariants.replaceVariant.useMutation();
   const utils = api.useContext();
   const toast = useToast();
-  const [onSave] = useHandledAsyncCallback(async () => {
+  const [onSave, saveInProgress] = useHandledAsyncCallback(async () => {
     if (!editorRef.current) return;
     await editorRef.current.getAction("editor.action.formatDocument")?.run();
@@ -41,26 +47,12 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
     const model = editorRef.current.getModel();
     if (!model) return;
-    const markers = monaco?.editor.getModelMarkers({ resource: model.uri });
-    const hasErrors = markers?.some((m) => m.severity === monaco?.MarkerSeverity.Error);
-    if (hasErrors) {
-      toast({
-        title: "Invalid TypeScript",
-        description: "Please fix the TypeScript errors before saving.",
-        status: "error",
-      });
-      return;
-    }
     // Make sure the user defined the prompt with the string "prompt\w*=" somewhere
-    const promptRegex = /prompt\s*=/;
+    const promptRegex = /definePrompt\(/;
     if (!promptRegex.test(currentFn)) {
-      console.log("no prompt");
-      console.log(currentFn);
       toast({
         title: "Missing prompt",
-        description: "Please define the prompt (eg. `prompt = { ...`).",
+        description: "Please define the prompt (eg. `definePrompt(...`",
         status: "error",
       });
       return;
@@ -104,6 +96,7 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
       wordWrapBreakAfterCharacters: "",
       wordWrapBreakBeforeCharacters: "",
       quickSuggestions: true,
+      readOnly: !canModify,
     });
     editorRef.current.onDidFocusEditorText(() => {
@@ -131,21 +124,16 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
     /* eslint-disable-next-line react-hooks/exhaustive-deps */
   }, [monaco, editorId]);
+  useEffect(() => {
+    if (!editorRef.current) return;
+    editorRef.current.updateOptions({
+      readOnly: !canModify,
+    });
+  }, [canModify]);
   return (
     <Box w="100%" pos="relative">
-      <VStack
-        spacing={0}
-        align="stretch"
-        fontSize="xs"
-        fontWeight="bold"
-        color="gray.600"
-        py={2}
-        bgColor={editorBackground}
-      >
-        <code>{`function constructPrompt(scenario: Scenario): Prompt {`}</code>
-        <div id={editorId} style={{ height: "300px", width: "100%" }}></div>
-        <code>{`return prompt; }`}</code>
-      </VStack>
+      <div id={editorId} style={{ height: "400px", width: "100%" }}></div>
       {isChanged && (
         <HStack pos="absolute" bottom={2} right={2}>
           <Button
@@ -159,8 +147,8 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
             Reset
           </Button>
           <Tooltip label={`${modifierKey} + Enter`}>
-            <Button size="sm" onClick={onSave} colorScheme="blue">
-              Save
+            <Button size="sm" onClick={onSave} colorScheme="blue" w={16} disabled={saveInProgress}>
+              {saveInProgress ? <Spinner boxSize={4} /> : <Text>Save</Text>}
             </Button>
           </Tooltip>
         </HStack>

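The save-time validation now requires a `definePrompt(` call instead of a bare `prompt =` assignment. A sketch of a variant body that would pass the new check (assumed usage — the `definePrompt` helper's exact signature isn't shown in this diff):

// Hypothetical variant body; the provider id, options, and scenario binding are illustrative.
definePrompt("openai/ChatCompletion", {
  model: "gpt-3.5-turbo",
  messages: [
    {
      role: "user",
      content: `Write a haiku about ${scenario.topic}`,
    },
  ],
});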
View File

@@ -1,107 +0,0 @@
import { useState, type DragEvent } from "react";
import { type PromptVariant } from "./types";
import { api } from "~/utils/api";
import { useHandledAsyncCallback } from "~/utils/hooks";
import { Button, HStack, Icon, Tooltip } from "@chakra-ui/react"; // Changed here
import { BsX } from "react-icons/bs";
import { RiDraggable } from "react-icons/ri";
import { cellPadding, headerMinHeight } from "../constants";
import AutoResizeTextArea from "../AutoResizeTextArea";
export default function VariantHeader(props: { variant: PromptVariant; canHide: boolean }) {
const utils = api.useContext();
const [isDragTarget, setIsDragTarget] = useState(false);
const [isInputHovered, setIsInputHovered] = useState(false);
const [label, setLabel] = useState(props.variant.label);
const updateMutation = api.promptVariants.update.useMutation();
const [onSaveLabel] = useHandledAsyncCallback(async () => {
if (label && label !== props.variant.label) {
await updateMutation.mutateAsync({
id: props.variant.id,
updates: { label: label },
});
}
}, [updateMutation, props.variant.id, props.variant.label, label]);
const hideMutation = api.promptVariants.hide.useMutation();
const [onHide] = useHandledAsyncCallback(async () => {
await hideMutation.mutateAsync({
id: props.variant.id,
});
await utils.promptVariants.list.invalidate();
}, [hideMutation, props.variant.id]);
const reorderMutation = api.promptVariants.reorder.useMutation();
const [onReorder] = useHandledAsyncCallback(
async (e: DragEvent<HTMLDivElement>) => {
e.preventDefault();
setIsDragTarget(false);
const draggedId = e.dataTransfer.getData("text/plain");
const droppedId = props.variant.id;
if (!draggedId || !droppedId || draggedId === droppedId) return;
await reorderMutation.mutateAsync({
draggedId,
droppedId,
});
await utils.promptVariants.list.invalidate();
},
[reorderMutation, props.variant.id],
);
return (
<HStack
spacing={4}
alignItems="center"
minH={headerMinHeight}
draggable={!isInputHovered}
onDragStart={(e) => {
e.dataTransfer.setData("text/plain", props.variant.id);
e.currentTarget.style.opacity = "0.4";
}}
onDragEnd={(e) => {
e.currentTarget.style.opacity = "1";
}}
onDragOver={(e) => {
e.preventDefault();
setIsDragTarget(true);
}}
onDragLeave={() => {
setIsDragTarget(false);
}}
onDrop={onReorder}
backgroundColor={isDragTarget ? "gray.100" : "transparent"}
>
<Icon
as={RiDraggable}
boxSize={6}
color="gray.400"
_hover={{ color: "gray.800", cursor: "pointer" }}
/>
<AutoResizeTextArea // Changed to Input
size="sm"
value={label}
onChange={(e) => setLabel(e.target.value)}
onBlur={onSaveLabel}
placeholder="Variant Name"
borderWidth={1}
borderColor="transparent"
fontWeight="bold"
fontSize={16}
_hover={{ borderColor: "gray.300" }}
_focus={{ borderColor: "blue.500", outline: "none" }}
flex={1}
px={cellPadding.x}
onMouseEnter={() => setIsInputHovered(true)}
onMouseLeave={() => setIsInputHovered(false)}
/>
{props.canHide && (
<Tooltip label="Remove Variant" hasArrow>
<Button variant="ghost" colorScheme="gray" size="sm" onClick={onHide}>
<Icon as={BsX} boxSize={6} />
</Button>
</Tooltip>
)}
</HStack>
);
}

View File

@@ -5,8 +5,10 @@ import { api } from "~/utils/api";
 import chroma from "chroma-js";
 import { BsCurrencyDollar } from "react-icons/bs";
 import { CostTooltip } from "../tooltip/CostTooltip";
+import { useEffect, useState } from "react";
 export default function VariantStats(props: { variant: PromptVariant }) {
+  const [refetchInterval, setRefetchInterval] = useState(0);
   const { data } = api.promptVariants.stats.useQuery(
     {
       variantId: props.variant.id,
@@ -19,10 +21,18 @@ export default function VariantStats(props: { variant: PromptVariant }) {
         completionTokens: 0,
         scenarioCount: 0,
         outputCount: 0,
+        awaitingRetrievals: false,
       },
+      refetchInterval,
     },
   );
+  // Poll every two seconds while we are waiting for LLM retrievals to finish
+  useEffect(
+    () => setRefetchInterval(data.awaitingRetrievals ? 2000 : 0),
+    [data.awaitingRetrievals],
+  );
   const [passColor, neutralColor, failColor] = useToken("colors", [
     "green.500",
     "gray.500",
@@ -33,21 +43,25 @@ export default function VariantStats(props: { variant: PromptVariant }) {
   const showNumFinished = data.scenarioCount > 0 && data.scenarioCount !== data.outputCount;
+  if (!(data.evalResults.length > 0) && !data.overallCost) return null;
   return (
-    <HStack justifyContent="space-between" alignItems="center" mx="2" fontSize="xs">
+    <HStack
+      justifyContent="space-between"
+      alignItems="center"
+      mx="2"
+      fontSize="xs"
+      py={cellPadding.y}
+    >
       {showNumFinished && (
         <Text>
           {data.outputCount} / {data.scenarioCount}
         </Text>
       )}
-      <HStack px={cellPadding.x} py={cellPadding.y}>
+      <HStack px={cellPadding.x}>
         {data.evalResults.map((result) => {
-          const passedFrac = result.passCount / (result.passCount + result.failCount);
+          const passedFrac = result.passCount / result.totalCount;
           return (
             <HStack key={result.id}>
-              <Text>{result.evaluation.name}</Text>
+              <Text>{result.label}</Text>
               <Text color={scale(passedFrac).hex()} fontWeight="bold">
                 {(passedFrac * 100).toFixed(1)}%
               </Text>
@@ -55,13 +69,13 @@ export default function VariantStats(props: { variant: PromptVariant }) {
           );
         })}
       </HStack>
-      {data.overallCost && (
+      {data.overallCost && !data.awaitingRetrievals && (
         <CostTooltip
           promptTokens={data.promptTokens}
           completionTokens={data.completionTokens}
           cost={data.overallCost}
         >
-          <HStack spacing={0} align="center" color="gray.500" my="2">
+          <HStack spacing={0} align="center" color="gray.500">
             <Icon as={BsCurrencyDollar} />
             <Text mr={1}>{data.overallCost.toFixed(3)}</Text>
           </HStack>

View File

@@ -1,28 +1,18 @@
-import { Button, Grid, GridItem, HStack, Heading, type SystemStyleObject } from "@chakra-ui/react";
+import { Grid, GridItem, type GridItemProps } from "@chakra-ui/react";
 import { api } from "~/utils/api";
-import NewScenarioButton from "./NewScenarioButton";
-import NewVariantButton from "./NewVariantButton";
+import AddVariantButton from "./AddVariantButton";
 import ScenarioRow from "./ScenarioRow";
 import VariantEditor from "./VariantEditor";
-import VariantHeader from "./VariantHeader";
-import { cellPadding } from "../constants";
-import { BsPencil } from "react-icons/bs";
+import VariantHeader from "../VariantHeader/VariantHeader";
 import VariantStats from "./VariantStats";
-import { useAppStore } from "~/state/store";
-const stickyHeaderStyle: SystemStyleObject = {
-  position: "sticky",
-  top: "-1px",
-  backgroundColor: "#fff",
-  zIndex: 1,
-};
+import { ScenariosHeader } from "./ScenariosHeader";
+import { borders } from "./styles";
 export default function OutputsTable({ experimentId }: { experimentId: string | undefined }) {
   const variants = api.promptVariants.list.useQuery(
     { experimentId: experimentId as string },
     { enabled: !!experimentId },
   );
-  const openDrawer = useAppStore((s) => s.openDrawer);
   const scenarios = api.scenarios.list.useQuery(
     { experimentId: experimentId as string },
@@ -31,88 +21,76 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
   if (!variants.data || !scenarios.data) return null;
-  const allCols = variants.data.length + 1;
-  const headerRows = 3;
+  const allCols = variants.data.length + 2;
+  const variantHeaderRows = 3;
+  const scenarioHeaderRows = 1;
+  const allRows = variantHeaderRows + scenarioHeaderRows + scenarios.data.length;
   return (
     <Grid
-      p={4}
+      pt={4}
       pb={24}
+      pl={4}
       display="grid"
       gridTemplateColumns={`250px repeat(${variants.data.length}, minmax(300px, 1fr)) auto`}
       sx={{
         "> *": {
           borderColor: "gray.300",
-          borderBottomWidth: 1,
-          borderRightWidth: 1,
         },
       }}
       fontSize="sm"
     >
-      <GridItem
-        display="flex"
-        alignItems="flex-end"
-        rowSpan={headerRows}
-        px={cellPadding.x}
-        py={cellPadding.y}
-        // TODO: This is a hack to get the sticky header to work. It's not ideal because it's not responsive to the height of the header,
-        // so if the header height changes, this will need to be updated.
-        sx={{ ...stickyHeaderStyle, top: "-337px" }}
-      >
-        <HStack w="100%">
-          <Heading size="xs" fontWeight="bold" flex={1}>
-            Scenarios ({scenarios.data.length})
-          </Heading>
-          <Button
-            size="xs"
-            variant="ghost"
-            color="gray.500"
-            aria-label="Edit"
-            leftIcon={<BsPencil />}
-            onClick={openDrawer}
-          >
-            Edit Vars
-          </Button>
-        </HStack>
+      <GridItem rowSpan={variantHeaderRows}>
+        <AddVariantButton />
       </GridItem>
-      {variants.data.map((variant) => (
-        <GridItem key={variant.uiId} padding={0} sx={stickyHeaderStyle} borderTopWidth={1}>
-          <VariantHeader variant={variant} canHide={variants.data.length > 1} />
-        </GridItem>
-      ))}
+      {variants.data.map((variant, i) => {
+        const sharedProps: GridItemProps = {
+          ...borders,
+          colStart: i + 2,
+          borderLeftWidth: i === 0 ? 1 : 0,
+        };
+        return (
+          <>
+            <VariantHeader
+              key={variant.uiId}
+              variant={variant}
+              canHide={variants.data.length > 1}
+              rowStart={1}
+              {...sharedProps}
+            />
+            <GridItem rowStart={2} {...sharedProps}>
+              <VariantEditor variant={variant} />
+            </GridItem>
+            <GridItem rowStart={3} {...sharedProps}>
+              <VariantStats variant={variant} />
+            </GridItem>
+          </>
+        );
+      })}
       <GridItem
-        rowSpan={scenarios.data.length + headerRows}
-        padding={0}
-        // Have to use `style` instead of emotion style props to work around css specificity issues conflicting with the "> *" selector on Grid
-        style={{ borderRightWidth: 0, borderBottomWidth: 0 }}
-        h={8}
-        sx={stickyHeaderStyle}
+        colSpan={allCols - 1}
+        rowStart={variantHeaderRows + 1}
+        colStart={1}
+        {...borders}
+        borderRightWidth={0}
       >
-        <NewVariantButton />
+        <ScenariosHeader numScenarios={scenarios.data.length} />
       </GridItem>
-      {variants.data.map((variant) => (
-        <GridItem key={variant.uiId}>
-          <VariantEditor variant={variant} />
-        </GridItem>
-      ))}
-      {variants.data.map((variant) => (
-        <GridItem key={variant.uiId}>
-          <VariantStats variant={variant} />
-        </GridItem>
-      ))}
-      {scenarios.data.map((scenario) => (
+      {scenarios.data.map((scenario, i) => (
         <ScenarioRow
+          rowStart={i + variantHeaderRows + scenarioHeaderRows + 2}
           key={scenario.uiId}
           scenario={scenario}
           variants={variants.data}
           canHide={scenarios.data.length > 1}
         />
       ))}
-      <GridItem borderBottomWidth={0} borderRightWidth={0} w="100%" colSpan={allCols} padding={0}>
-        <NewScenarioButton />
-      </GridItem>
+      {/* Add some extra padding on the right, because when the table is too wide to fit in the viewport `pr` on the Grid isn't respected. */}
+      <GridItem rowStart={1} colStart={allCols} rowSpan={allRows} w={4} borderBottomWidth={0} />
     </Grid>
   );
 }

View File

@@ -0,0 +1,13 @@
import { type GridItemProps, type SystemStyleObject } from "@chakra-ui/react";
export const stickyHeaderStyle: SystemStyleObject = {
position: "sticky",
top: "0",
backgroundColor: "#fff",
zIndex: 10,
};
export const borders: GridItemProps = {
borderRightWidth: 1,
borderBottomWidth: 1,
};

View File

@@ -1,21 +0,0 @@
import { Flex, Icon, Link, Text } from "@chakra-ui/react";
import { BsExclamationTriangleFill } from "react-icons/bs";
import { env } from "~/env.mjs";
export default function PublicPlaygroundWarning() {
if (!env.NEXT_PUBLIC_IS_PUBLIC_PLAYGROUND) return null;
return (
<Flex bgColor="red.600" color="whiteAlpha.900" p={2} align="center">
<Icon boxSize={4} mr={2} as={BsExclamationTriangleFill} />
<Text>
Warning: this is a public playground. Anyone can see, edit or delete your experiments. For
private use,{" "}
<Link textDecor="underline" href="https://github.com/openpipe/openpipe" target="_blank">
run a local copy
</Link>
.
</Text>
</Flex>
);
}

View File

@@ -0,0 +1,59 @@
import { type StackProps, VStack, useBreakpointValue } from "@chakra-ui/react";
import React from "react";
import DiffViewer, { DiffMethod } from "react-diff-viewer";
import Prism from "prismjs";
import "prismjs/components/prism-javascript";
import "prismjs/themes/prism.css"; // choose a theme you like
const highlightSyntax = (str: string) => {
let highlighted;
try {
highlighted = Prism.highlight(str, Prism.languages.javascript as Prism.Grammar, "javascript");
} catch (e) {
console.error("Error highlighting:", e);
highlighted = str;
}
return <pre style={{ display: "inline" }} dangerouslySetInnerHTML={{ __html: highlighted }} />;
};
const CompareFunctions = ({
originalFunction,
newFunction = "",
leftTitle = "Original",
rightTitle = "Modified",
...props
}: {
originalFunction: string;
newFunction?: string;
leftTitle?: string;
rightTitle?: string;
} & StackProps) => {
const showSplitView = useBreakpointValue(
{
base: false,
md: true,
},
{
fallback: "base",
},
);
return (
<VStack w="full" spacing={4} fontSize={12} lineHeight={1} overflowY="auto" {...props}>
<DiffViewer
oldValue={originalFunction}
newValue={newFunction || originalFunction}
splitView={showSplitView}
hideLineNumbers={!showSplitView}
leftTitle={leftTitle}
rightTitle={rightTitle}
disableWordDiff={true}
compareMethod={DiffMethod.CHARS}
renderContent={highlightSyntax}
showDiffOnly={false}
/>
</VStack>
);
};
export default CompareFunctions;

View File

@@ -0,0 +1,74 @@
import { Button, Spinner, InputGroup, InputRightElement, Icon, HStack } from "@chakra-ui/react";
import { IoMdSend } from "react-icons/io";
import AutoResizeTextArea from "../AutoResizeTextArea";
export const CustomInstructionsInput = ({
instructions,
setInstructions,
loading,
onSubmit,
}: {
instructions: string;
setInstructions: (instructions: string) => void;
loading: boolean;
onSubmit: () => void;
}) => {
return (
<InputGroup
size="md"
w="full"
maxW="600"
boxShadow="0 0 40px 4px rgba(0, 0, 0, 0.1);"
borderRadius={8}
alignItems="center"
colorScheme="orange"
>
<AutoResizeTextArea
value={instructions}
onChange={(e) => setInstructions(e.target.value)}
onKeyDown={(e) => {
if (e.key === "Enter" && !e.metaKey && !e.ctrlKey && !e.shiftKey) {
e.preventDefault();
e.currentTarget.blur();
onSubmit();
}
}}
placeholder="Send custom instructions"
py={4}
pl={4}
pr={12}
colorScheme="orange"
borderColor="gray.300"
borderWidth={1}
_hover={{
borderColor: "gray.300",
}}
_focus={{
borderColor: "gray.300",
}}
isDisabled={loading}
/>
<InputRightElement width="8" height="full">
<Button
h="8"
w="8"
minW="unset"
size="sm"
onClick={() => onSubmit()}
variant={instructions ? "solid" : "ghost"}
mr={4}
borderRadius="8"
bgColor={instructions ? "orange.400" : "transparent"}
colorScheme="orange"
>
{loading ? (
<Spinner boxSize={4} />
) : (
<Icon as={IoMdSend} color={instructions ? "white" : "gray.500"} boxSize={5} />
)}
</Button>
</InputRightElement>
</InputGroup>
);
};

View File

@@ -0,0 +1,64 @@
import { HStack, Icon, Heading, Text, VStack, GridItem } from "@chakra-ui/react";
import { type IconType } from "react-icons";
export const RefineOption = ({
label,
icon,
description,
activeLabel,
onClick,
loading,
}: {
label: string;
icon: IconType;
description: string;
activeLabel: string | undefined;
onClick: (label: string) => void;
loading: boolean;
}) => {
const isActive = activeLabel === label;
return (
<GridItem w="80" h="44">
<VStack
w="full"
h="full"
onClick={() => {
!loading && onClick(label);
}}
borderColor={isActive ? "blue.500" : "gray.200"}
borderWidth={2}
borderRadius={16}
padding={6}
backgroundColor="gray.50"
_hover={
loading
? undefined
: {
backgroundColor: "gray.100",
}
}
spacing={8}
boxShadow="0 0 40px 4px rgba(0, 0, 0, 0.1);"
cursor="pointer"
opacity={loading ? 0.5 : 1}
>
<HStack cursor="pointer" spacing={6} fontSize="sm" fontWeight="medium" color="gray.500">
<Icon as={icon} boxSize={12} />
<Heading size="md" fontFamily="inconsolata, monospace">
{label}
</Heading>
</HStack>
<Text
fontSize="sm"
color="gray.500"
flexWrap="wrap"
wordBreak="break-word"
overflowWrap="break-word"
>
{description}
</Text>
</VStack>
</GridItem>
);
};

View File

@@ -0,0 +1,148 @@
import {
Button,
Modal,
ModalBody,
ModalCloseButton,
ModalContent,
ModalFooter,
ModalHeader,
ModalOverlay,
VStack,
Text,
Spinner,
HStack,
Icon,
SimpleGrid,
} from "@chakra-ui/react";
import { BsStars } from "react-icons/bs";
import { api } from "~/utils/api";
import { useHandledAsyncCallback } from "~/utils/hooks";
import { type PromptVariant } from "@prisma/client";
import { useState } from "react";
import CompareFunctions from "./CompareFunctions";
import { CustomInstructionsInput } from "./CustomInstructionsInput";
import { type RefineOptionInfo, refineOptions } from "./refineOptions";
import { RefineOption } from "./RefineOption";
import { isObject, isString } from "lodash-es";
import { type SupportedProvider } from "~/modelProviders/types";
export const RefinePromptModal = ({
variant,
onClose,
}: {
variant: PromptVariant;
onClose: () => void;
}) => {
const utils = api.useContext();
const providerRefineOptions = refineOptions[variant.modelProvider as SupportedProvider];
const { mutateAsync: getModifiedPromptMutateAsync, data: refinedPromptFn } =
api.promptVariants.getModifiedPromptFn.useMutation();
const [instructions, setInstructions] = useState<string>("");
const [activeRefineOptionLabel, setActiveRefineOptionLabel] = useState<string | undefined>(
undefined,
);
const [getModifiedPromptFn, modificationInProgress] = useHandledAsyncCallback(
async (label?: string) => {
if (!variant.experimentId) return;
const updatedInstructions = label
? (providerRefineOptions[label] as RefineOptionInfo).instructions
: instructions;
setActiveRefineOptionLabel(label);
await getModifiedPromptMutateAsync({
id: variant.id,
instructions: updatedInstructions,
});
},
[getModifiedPromptMutateAsync, onClose, variant, instructions, setActiveRefineOptionLabel],
);
const replaceVariantMutation = api.promptVariants.replaceVariant.useMutation();
const [replaceVariant, replacementInProgress] = useHandledAsyncCallback(async () => {
if (
!variant.experimentId ||
!refinedPromptFn ||
(isObject(refinedPromptFn) && "status" in refinedPromptFn)
)
return;
await replaceVariantMutation.mutateAsync({
id: variant.id,
constructFn: refinedPromptFn,
});
await utils.promptVariants.list.invalidate();
onClose();
}, [replaceVariantMutation, variant, onClose, refinedPromptFn]);
return (
<Modal
isOpen
onClose={onClose}
size={{ base: "xl", sm: "2xl", md: "3xl", lg: "5xl", xl: "7xl" }}
>
<ModalOverlay />
<ModalContent w={1200}>
<ModalHeader>
<HStack>
<Icon as={BsStars} />
<Text>Refine with GPT-4</Text>
</HStack>
</ModalHeader>
<ModalCloseButton />
<ModalBody maxW="unset">
<VStack spacing={8}>
<VStack spacing={4}>
{Object.keys(providerRefineOptions).length > 0 && (
<>
<SimpleGrid columns={{ base: 1, md: 2 }} spacing={8}>
{Object.keys(providerRefineOptions).map((label) => (
<RefineOption
key={label}
label={label}
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
icon={providerRefineOptions[label]!.icon}
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
description={providerRefineOptions[label]!.description}
activeLabel={activeRefineOptionLabel}
onClick={getModifiedPromptFn}
loading={modificationInProgress}
/>
))}
</SimpleGrid>
<Text color="gray.500">or</Text>
</>
)}
<CustomInstructionsInput
instructions={instructions}
setInstructions={setInstructions}
loading={modificationInProgress}
onSubmit={getModifiedPromptFn}
/>
</VStack>
<CompareFunctions
originalFunction={variant.constructFn}
newFunction={isString(refinedPromptFn) ? refinedPromptFn : undefined}
maxH="40vh"
/>
</VStack>
</ModalBody>
<ModalFooter>
<HStack spacing={4}>
<Button
colorScheme="blue"
onClick={replaceVariant}
minW={24}
isDisabled={replacementInProgress || !refinedPromptFn}
>
{replacementInProgress ? <Spinner boxSize={4} /> : <Text>Accept</Text>}
</Button>
</HStack>
</ModalFooter>
</ModalContent>
</Modal>
);
};

View File

@@ -0,0 +1,287 @@
// Super hacky, but we'll redo the organization when we have more models
import { type SupportedProvider } from "~/modelProviders/types";
import { VscJson } from "react-icons/vsc";
import { TfiThought } from "react-icons/tfi";
import { type IconType } from "react-icons";
export type RefineOptionInfo = { icon: IconType; description: string; instructions: string };
export const refineOptions: Record<SupportedProvider, { [key: string]: RefineOptionInfo }> = {
"openai/ChatCompletion": {
"Add chain of thought": {
icon: TfiThought,
description: "Asking the model to plan its answer can increase accuracy.",
instructions: `Adding chain of thought means asking the model to think about its answer before it gives it to you. This is useful for getting more accurate answers. Do not add an assistant message.
This is what a prompt looks like before adding chain of thought:
definePrompt("openai/ChatCompletion", {
model: "gpt-4",
stream: true,
messages: [
{
role: "system",
content: \`Evaluate sentiment.\`,
},
{
role: "user",
content: \`This is the user's message: \${scenario.user_message}. Return "positive" or "negative" or "neutral"\`,
},
],
});
This is what one looks like after adding chain of thought:
definePrompt("openai/ChatCompletion", {
model: "gpt-4",
stream: true,
messages: [
{
role: "system",
content: \`Evaluate sentiment.\`,
},
{
role: "user",
content: \`This is the user's message: \${scenario.user_message}. Return "positive" or "negative" or "neutral". Explain your answer before you give a score, then return the score on a new line.\`,
},
],
});
Here's another example:
Before:
definePrompt("openai/ChatCompletion", {
model: "gpt-3.5-turbo",
messages: [
{
role: "user",
content: \`Title: \${scenario.title}
Body: \${scenario.body}
Need: \${scenario.need}
Rate likelihood on 1-3 scale.\`,
},
],
temperature: 0,
functions: [
{
name: "score_post",
parameters: {
type: "object",
properties: {
score: {
type: "number",
},
},
},
},
],
function_call: {
name: "score_post",
},
});
After:
definePrompt("openai/ChatCompletion", {
model: "gpt-3.5-turbo",
messages: [
{
role: "user",
content: \`Title: \${scenario.title}
Body: \${scenario.body}
Need: \${scenario.need}
Rate likelihood on 1-3 scale. Provide an explanation, but always provide a score afterward.\`,
},
],
temperature: 0,
functions: [
{
name: "score_post",
parameters: {
type: "object",
properties: {
explanation: {
type: "string",
},
score: {
type: "number",
},
},
},
},
],
function_call: {
name: "score_post",
},
});
Add chain of thought to the original prompt.`,
},
"Convert to function call": {
icon: VscJson,
description: "Use function calls to get output from the model in a more structured way.",
instructions: `OpenAI functions are a specialized way for an LLM to return output.
This is what a prompt looks like before adding a function:
definePrompt("openai/ChatCompletion", {
model: "gpt-4",
stream: true,
messages: [
{
role: "system",
content: \`Evaluate sentiment.\`,
},
{
role: "user",
content: \`This is the user's message: \${scenario.user_message}. Return "positive" or "negative" or "neutral"\`,
},
],
});
This is what one looks like after adding a function:
definePrompt("openai/ChatCompletion", {
model: "gpt-4",
stream: true,
messages: [
{
role: "system",
content: "Evaluate sentiment.",
},
{
role: "user",
content: scenario.user_message,
},
],
functions: [
{
name: "extract_sentiment",
parameters: {
type: "object", // parameters must always be an object with a properties key
properties: { // properties key is required
sentiment: {
type: "string",
description: "one of positive/negative/neutral",
},
},
},
},
],
function_call: {
name: "extract_sentiment",
},
});
Here's another example of adding a function:
Before:
definePrompt("openai/ChatCompletion", {
model: "gpt-3.5-turbo",
messages: [
{
role: "user",
content: \`Here is the title and body of a reddit post I am interested in:
title: \${scenario.title}
body: \${scenario.body}
On a scale from 1 to 3, how likely is it that the person writing this post has the following need? If you are not sure, make your best guess, or answer 1.
Need: \${scenario.need}
Answer one integer between 1 and 3.\`,
},
],
temperature: 0,
});
After:
definePrompt("openai/ChatCompletion", {
model: "gpt-3.5-turbo",
messages: [
{
role: "user",
content: \`Title: \${scenario.title}
Body: \${scenario.body}
Need: \${scenario.need}
Rate likelihood on 1-3 scale.\`,
},
],
temperature: 0,
functions: [
{
name: "score_post",
parameters: {
type: "object",
properties: {
score: {
type: "number",
},
},
},
},
],
function_call: {
name: "score_post",
},
});
Another example
Before:
definePrompt("openai/ChatCompletion", {
model: "gpt-3.5-turbo",
stream: true,
messages: [
{
role: "system",
content: \`Write 'Start experimenting!' in \${scenario.language}\`,
},
],
});
After:
definePrompt("openai/ChatCompletion", {
model: "gpt-3.5-turbo",
messages: [
{
role: "system",
content: \`Write 'Start experimenting!' in \${scenario.language}\`,
},
],
functions: [
{
name: "write_in_language",
parameters: {
type: "object",
properties: {
text: {
type: "string",
},
},
},
},
],
function_call: {
name: "write_in_language",
},
});
Add an OpenAI function that takes one or more nested parameters that match the expected output from this prompt.`,
},
},
"replicate/llama2": {},
};

View File

@@ -0,0 +1,130 @@
import { useState, type DragEvent } from "react";
import { type PromptVariant } from "../OutputsTable/types";
import { api } from "~/utils/api";
import { RiDraggable } from "react-icons/ri";
import { useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
import { HStack, Icon, Text, GridItem, type GridItemProps } from "@chakra-ui/react";
import { cellPadding, headerMinHeight } from "../constants";
import AutoResizeTextArea from "../AutoResizeTextArea";
import { stickyHeaderStyle } from "../OutputsTable/styles";
import VariantHeaderMenuButton from "./VariantHeaderMenuButton";
export default function VariantHeader(
allProps: {
variant: PromptVariant;
canHide: boolean;
} & GridItemProps,
) {
const { variant, canHide, ...gridItemProps } = allProps;
const { canModify } = useExperimentAccess();
const utils = api.useContext();
const [isDragTarget, setIsDragTarget] = useState(false);
const [isInputHovered, setIsInputHovered] = useState(false);
const [label, setLabel] = useState(variant.label);
const updateMutation = api.promptVariants.update.useMutation();
const [onSaveLabel] = useHandledAsyncCallback(async () => {
if (label && label !== variant.label) {
await updateMutation.mutateAsync({
id: variant.id,
updates: { label: label },
});
}
}, [updateMutation, variant.id, variant.label, label]);
const reorderMutation = api.promptVariants.reorder.useMutation();
const [onReorder] = useHandledAsyncCallback(
async (e: DragEvent<HTMLDivElement>) => {
e.preventDefault();
setIsDragTarget(false);
const draggedId = e.dataTransfer.getData("text/plain");
const droppedId = variant.id;
if (!draggedId || !droppedId || draggedId === droppedId) return;
await reorderMutation.mutateAsync({
draggedId,
droppedId,
});
await utils.promptVariants.list.invalidate();
},
[reorderMutation, variant.id],
);
const [menuOpen, setMenuOpen] = useState(false);
if (!canModify) {
return (
<GridItem padding={0} sx={stickyHeaderStyle} borderTopWidth={1} {...gridItemProps}>
<Text fontSize={16} fontWeight="bold" px={cellPadding.x} py={cellPadding.y}>
{variant.label}
</Text>
</GridItem>
);
}
return (
<GridItem
padding={0}
sx={{
...stickyHeaderStyle,
// Ensure that the menu always appears above the sticky header of other variants
zIndex: menuOpen ? "dropdown" : stickyHeaderStyle.zIndex,
}}
borderTopWidth={1}
{...gridItemProps}
>
<HStack
spacing={4}
alignItems="flex-start"
minH={headerMinHeight}
draggable={!isInputHovered}
onDragStart={(e) => {
e.dataTransfer.setData("text/plain", variant.id);
e.currentTarget.style.opacity = "0.4";
}}
onDragEnd={(e) => {
e.currentTarget.style.opacity = "1";
}}
onDragOver={(e) => {
e.preventDefault();
setIsDragTarget(true);
}}
onDragLeave={() => {
setIsDragTarget(false);
}}
onDrop={onReorder}
backgroundColor={isDragTarget ? "gray.100" : "transparent"}
>
<Icon
as={RiDraggable}
boxSize={6}
mt={2}
color="gray.400"
_hover={{ color: "gray.800", cursor: "pointer" }}
/>
<AutoResizeTextArea
size="sm"
value={label}
onChange={(e) => setLabel(e.target.value)}
onBlur={onSaveLabel}
placeholder="Variant Name"
borderWidth={1}
borderColor="transparent"
fontWeight="bold"
fontSize={16}
_hover={{ borderColor: "gray.300" }}
_focus={{ borderColor: "blue.500", outline: "none" }}
flex={1}
px={cellPadding.x}
onMouseEnter={() => setIsInputHovered(true)}
onMouseLeave={() => setIsInputHovered(false)}
/>
<VariantHeaderMenuButton
variant={variant}
canHide={canHide}
menuOpen={menuOpen}
setMenuOpen={setMenuOpen}
/>
</HStack>
</GridItem>
);
}

View File

@@ -0,0 +1,108 @@
import { type PromptVariant } from "../OutputsTable/types";
import { api } from "~/utils/api";
import { useHandledAsyncCallback } from "~/utils/hooks";
import {
Button,
Icon,
Menu,
MenuButton,
MenuItem,
MenuList,
MenuDivider,
Text,
Spinner,
} from "@chakra-ui/react";
import { BsFillTrashFill, BsGear, BsStars } from "react-icons/bs";
import { FaRegClone } from "react-icons/fa";
import { useState } from "react";
import { RefinePromptModal } from "../RefinePromptModal/RefinePromptModal";
import { RiExchangeFundsFill } from "react-icons/ri";
import { ChangeModelModal } from "../ChangeModelModal/ChangeModelModal";
export default function VariantHeaderMenuButton({
variant,
canHide,
menuOpen,
setMenuOpen,
}: {
variant: PromptVariant;
canHide: boolean;
menuOpen: boolean;
setMenuOpen: (open: boolean) => void;
}) {
const utils = api.useContext();
const duplicateMutation = api.promptVariants.create.useMutation();
const [duplicateVariant, duplicationInProgress] = useHandledAsyncCallback(async () => {
await duplicateMutation.mutateAsync({
experimentId: variant.experimentId,
variantId: variant.id,
});
await utils.promptVariants.list.invalidate();
}, [duplicateMutation, variant.experimentId, variant.id]);
const hideMutation = api.promptVariants.hide.useMutation();
const [onHide] = useHandledAsyncCallback(async () => {
await hideMutation.mutateAsync({
id: variant.id,
});
await utils.promptVariants.list.invalidate();
}, [hideMutation, variant.id]);
const [changeModelModalOpen, setChangeModelModalOpen] = useState(false);
const [refinePromptModalOpen, setRefinePromptModalOpen] = useState(false);
return (
<>
<Menu isOpen={menuOpen} onOpen={() => setMenuOpen(true)} onClose={() => setMenuOpen(false)}>
{duplicationInProgress ? (
<Spinner boxSize={4} mx={3} my={3} />
) : (
<MenuButton>
<Button variant="ghost">
<Icon as={BsGear} />
</Button>
</MenuButton>
)}
<MenuList mt={-3} fontSize="md">
<MenuItem icon={<Icon as={FaRegClone} boxSize={4} w={5} />} onClick={duplicateVariant}>
Duplicate
</MenuItem>
<MenuItem
icon={<Icon as={RiExchangeFundsFill} boxSize={5} />}
onClick={() => setChangeModelModalOpen(true)}
>
Change Model
</MenuItem>
<MenuItem
icon={<Icon as={BsStars} boxSize={5} />}
onClick={() => setRefinePromptModalOpen(true)}
>
Refine
</MenuItem>
{canHide && (
<>
<MenuDivider />
<MenuItem
onClick={onHide}
icon={<Icon as={BsFillTrashFill} boxSize={5} />}
color="red.600"
_hover={{ backgroundColor: "red.50" }}
>
<Text>Hide</Text>
</MenuItem>
</>
)}
</MenuList>
</Menu>
{changeModelModalOpen && (
<ChangeModelModal variant={variant} onClose={() => setChangeModelModalOpen(false)} />
)}
{refinePromptModalOpen && (
<RefinePromptModal variant={variant} onClose={() => setRefinePromptModalOpen(false)} />
)}
</>
);
}

View File

@@ -1,17 +1,11 @@
-import {
-  Card,
-  CardBody,
-  HStack,
-  Icon,
-  VStack,
-  Text,
-  CardHeader,
-  Divider,
-  Box,
-} from "@chakra-ui/react";
+import { HStack, Icon, VStack, Text, Divider, Spinner, AspectRatio } from "@chakra-ui/react";
import { RiFlaskLine } from "react-icons/ri";
import { formatTimePast } from "~/utils/dayjs";
+import Link from "next/link";
import { useRouter } from "next/router";
+import { BsPlusSquare } from "react-icons/bs";
+import { api } from "~/utils/api";
+import { useHandledAsyncCallback } from "~/utils/hooks";
type ExperimentData = {
  testScenarioCount: number;
@@ -24,47 +18,42 @@ type ExperimentData = {
};
export const ExperimentCard = ({ exp }: { exp: ExperimentData }) => {
-  const router = useRouter();
  return (
-    <Box
-      as={Card}
-      variant="elevated"
-      bg="gray.50"
-      _hover={{ bg: "gray.100" }}
-      transition="background 0.2s"
-      cursor="pointer"
-      onClick={(e) => {
-        e.preventDefault();
-        void router.push({ pathname: "/experiments/[id]", query: { id: exp.id } }, undefined, {
-          shallow: true,
-        });
-      }}
-    >
-      <CardHeader>
-        <HStack w="full" color="gray.700">
+    <AspectRatio ratio={1.2} w="full">
+      <VStack
+        as={Link}
+        href={{ pathname: "/experiments/[id]", query: { id: exp.id } }}
+        bg="gray.50"
+        _hover={{ bg: "gray.100" }}
+        transition="background 0.2s"
+        cursor="pointer"
+        borderColor="gray.200"
+        borderWidth={1}
+        p={4}
+        justify="space-between"
+      >
+        <HStack w="full" color="gray.700" justify="center">
          <Icon as={RiFlaskLine} boxSize={4} />
          <Text fontWeight="bold">{exp.label}</Text>
        </HStack>
-      </CardHeader>
-      <CardBody>
-        <HStack w="full" mb={8} spacing={4}>
+        <HStack h="full" spacing={4} flex={1} align="center">
          <CountLabel label="Variants" count={exp.promptVariantCount} />
          <Divider h={12} orientation="vertical" />
          <CountLabel label="Scenarios" count={exp.testScenarioCount} />
        </HStack>
-        <HStack w="full" color="gray.500" fontSize="xs">
-          <Text>Created {formatTimePast(exp.createdAt)}</Text>
+        <HStack w="full" color="gray.500" fontSize="xs" textAlign="center">
+          <Text flex={1}>Created {formatTimePast(exp.createdAt)}</Text>
          <Divider h={4} orientation="vertical" />
-          <Text>Updated {formatTimePast(exp.updatedAt)}</Text>
+          <Text flex={1}>Updated {formatTimePast(exp.updatedAt)}</Text>
        </HStack>
-      </CardBody>
-    </Box>
+      </VStack>
+    </AspectRatio>
  );
};
const CountLabel = ({ label, count }: { label: string; count: number }) => {
  return (
-    <VStack alignItems="flex-start">
+    <VStack alignItems="center" flex={1}>
      <Text color="gray.500" fontWeight="bold">
        {label}
      </Text>
@@ -74,3 +63,33 @@ const CountLabel = ({ label, count }: { label: string; count: number }) => {
    </VStack>
  );
};
+export const NewExperimentCard = () => {
+  const router = useRouter();
+  const createMutation = api.experiments.create.useMutation();
+  const [createExperiment, isLoading] = useHandledAsyncCallback(async () => {
+    const newExperiment = await createMutation.mutateAsync({ label: "New Experiment" });
+    await router.push({ pathname: "/experiments/[id]", query: { id: newExperiment.id } });
+  }, [createMutation, router]);
+  return (
+    <AspectRatio ratio={1.2} w="full">
+      <VStack
+        align="center"
+        justify="center"
+        _hover={{ cursor: "pointer", bg: "gray.50" }}
+        transition="background 0.2s"
+        cursor="pointer"
+        borderColor="gray.200"
+        borderWidth={1}
+        p={4}
+        onClick={createExperiment}
+      >
+        <Icon as={isLoading ? Spinner : BsPlusSquare} boxSize={8} />
+        <Text display={{ base: "none", md: "block" }} ml={2}>
+          New Experiment
+        </Text>
+      </VStack>
+    </AspectRatio>
+  );
+};

View File

@@ -1,31 +0,0 @@
import { Icon, Button, Spinner, Text, type ButtonProps } from "@chakra-ui/react";
import { api } from "~/utils/api";
import { useRouter } from "next/router";
import { BsPlusSquare } from "react-icons/bs";
import { useHandledAsyncCallback } from "~/utils/hooks";
export const NewExperimentButton = (props: ButtonProps) => {
const router = useRouter();
const utils = api.useContext();
const createMutation = api.experiments.create.useMutation();
const [createExperiment, isLoading] = useHandledAsyncCallback(async () => {
const newExperiment = await createMutation.mutateAsync({ label: "New Experiment" });
await utils.experiments.list.invalidate();
await router.push({ pathname: "/experiments/[id]", query: { id: newExperiment.id } });
}, [createMutation, router]);
return (
<Button
onClick={createExperiment}
display="flex"
alignItems="center"
variant={{ base: "solid", md: "ghost" }}
{...props}
>
<Icon as={isLoading ? Spinner : BsPlusSquare} boxSize={4} />
<Text display={{ base: "none", md: "block" }} ml={2}>
New Experiment
</Text>
</Button>
);
};

View File

@@ -1,84 +1,100 @@
+import { useState, useEffect } from "react";
import {
  Heading,
  VStack,
  Icon,
  HStack,
  Image,
-  Grid,
-  GridItem,
-  Divider,
  Text,
  Box,
  type BoxProps,
  type LinkProps,
  Link,
+  Flex,
} from "@chakra-ui/react";
import Head from "next/head";
-import { BsGithub, BsTwitter } from "react-icons/bs";
+import { BsGithub, BsPersonCircle } from "react-icons/bs";
import { useRouter } from "next/router";
-import PublicPlaygroundWarning from "../PublicPlaygroundWarning";
import { type IconType } from "react-icons";
import { RiFlaskLine } from "react-icons/ri";
-import { useState, useEffect } from "react";
+import { signIn, useSession } from "next-auth/react";
+import UserMenu from "./UserMenu";
-type IconLinkProps = BoxProps & LinkProps & { label: string; icon: IconType; href: string };
+type IconLinkProps = BoxProps & LinkProps & { label?: string; icon: IconType };
const IconLink = ({ icon, label, href, target, color, ...props }: IconLinkProps) => {
-  const isActive = useRouter().pathname.startsWith(href);
+  const router = useRouter();
+  const isActive = href && router.pathname.startsWith(href);
  return (
-    <Box
+    <HStack
+      w="full"
+      p={4}
+      color={color}
      as={Link}
      href={href}
      target={target}
-      w="full"
-      bgColor={isActive ? "gray.300" : "transparent"}
-      _hover={{ bgColor: "gray.300" }}
-      py={4}
+      bgColor={isActive ? "gray.200" : "transparent"}
+      _hover={{ bgColor: "gray.200", textDecoration: "none" }}
      justifyContent="start"
      cursor="pointer"
      {...props}
    >
-      <HStack w="full" px={4} color={color}>
-        <Icon as={icon} boxSize={6} mr={2} />
-        <Text fontWeight="bold">{label}</Text>
-      </HStack>
-    </Box>
+      <Icon as={icon} boxSize={6} mr={2} />
+      <Text fontWeight="bold" fontSize="sm">
+        {label}
+      </Text>
+    </HStack>
  );
};
+const Divider = () => <Box h="1px" bgColor="gray.200" />;
const NavSidebar = () => {
+  const user = useSession().data;
  return (
-    <VStack align="stretch" bgColor="gray.100" py={2} pb={0} height="100%">
-      <Link href="/" w="full" _hover={{ textDecoration: "none" }}>
-        <HStack spacing={0} pl="3">
-          <Image src="/logo.svg" alt="" w={8} h={8} />
-          <Heading size="md" p={2} pl={{ base: 16, md: 2 }}>
-            OpenPipe
-          </Heading>
-        </HStack>
-      </Link>
-      <Divider />
+    <VStack
+      align="stretch"
+      bgColor="gray.100"
+      py={2}
+      pb={0}
+      height="100%"
+      w={{ base: "56px", md: "200px" }}
+      overflow="hidden"
+    >
+      <HStack as={Link} href="/" _hover={{ textDecoration: "none" }} spacing={0} px={4} py={2}>
+        <Image src="/logo.svg" alt="" boxSize={6} mr={4} />
+        <Heading size="md" fontFamily="inconsolata, monospace">
+          OpenPipe
+        </Heading>
+      </HStack>
      <VStack spacing={0} align="flex-start" overflowY="auto" overflowX="hidden" flex={1}>
-        <IconLink icon={RiFlaskLine} label="Experiments" href="/experiments" />
+        {user != null && (
+          <>
+            <IconLink icon={RiFlaskLine} label="Experiments" href="/experiments" />
+          </>
+        )}
+        {user === null && (
+          <IconLink
+            icon={BsPersonCircle}
+            label="Sign In"
+            onClick={() => {
+              signIn("github").catch(console.error);
+            }}
+          />
+        )}
      </VStack>
-      <Divider />
-      <VStack w="full" spacing={0} pb={2}>
-        <IconLink
-          icon={BsGithub}
-          label="GitHub"
-          href="https://github.com/openpipe/openpipe"
-          target="_blank"
-          color="gray.500"
-          _hover={{ color: "gray.800" }}
-        />
-        <IconLink
-          icon={BsTwitter}
-          label="Twitter"
-          href="https://twitter.com/corbtt"
-          target="_blank"
-          color="gray.500"
-          _hover={{ color: "gray.800" }}
-        />
+      {user ? <UserMenu user={user} /> : <Divider />}
+      <VStack spacing={0} align="center">
+        <Link
+          href="https://github.com/openpipe/openpipe"
+          target="_blank"
+          color="gray.500"
+          _hover={{ color: "gray.800" }}
+          p={2}
+        >
+          <Icon as={BsGithub} boxSize={6} />
+        </Link>
      </VStack>
    </VStack>
  );
@@ -105,25 +121,14 @@ export default function AppShell(props: { children: React.ReactNode; title?: str
  }, []);
  return (
-    <Grid
-      h={vh}
-      w="100vw"
-      templateColumns={{ base: "56px minmax(0, 1fr)", md: "200px minmax(0, 1fr)" }}
-      templateRows="max-content 1fr"
-      templateAreas={'"warning warning"\n"sidebar main"'}
-    >
+    <Flex h={vh} w="100vw">
      <Head>
        <title>{props.title ? `${props.title} | OpenPipe` : "OpenPipe"}</title>
      </Head>
-      <GridItem area="warning">
-        <PublicPlaygroundWarning />
-      </GridItem>
-      <GridItem area="sidebar" overflow="hidden">
-        <NavSidebar />
-      </GridItem>
-      <GridItem area="main" overflowY="auto">
+      <NavSidebar />
+      <Box h="100%" flex={1} overflowY="auto">
        {props.children}
-      </GridItem>
-    </Grid>
+      </Box>
+    </Flex>
  );
}

View File

@@ -0,0 +1,74 @@
import {
HStack,
Icon,
Image,
VStack,
Text,
Popover,
PopoverTrigger,
PopoverContent,
Link,
} from "@chakra-ui/react";
import { type Session } from "next-auth";
import { signOut } from "next-auth/react";
import { BsBoxArrowRight, BsChevronRight, BsPersonCircle } from "react-icons/bs";
export default function UserMenu({ user }: { user: Session }) {
const profileImage = user.user.image ? (
<Image src={user.user.image} alt="profile picture" boxSize={8} borderRadius="50%" />
) : (
<Icon as={BsPersonCircle} boxSize={6} />
);
return (
<>
<Popover placement="right">
<PopoverTrigger>
<HStack
// Weird values to make mobile look right; can clean up when we make the sidebar disappear on mobile
px={3}
spacing={3}
py={2}
borderColor={"gray.200"}
borderTopWidth={1}
borderBottomWidth={1}
cursor="pointer"
_hover={{
bgColor: "gray.200",
}}
>
{profileImage}
<VStack spacing={0} align="start" flex={1} flexShrink={1}>
<Text fontWeight="bold" fontSize="sm">
{user.user.name}
</Text>
<Text color="gray.500" fontSize="xs">
{user.user.email}
</Text>
</VStack>
<Icon as={BsChevronRight} boxSize={4} color="gray.500" />
</HStack>
</PopoverTrigger>
<PopoverContent _focusVisible={{ boxShadow: "unset", outline: "unset" }} maxW="200px">
<VStack align="stretch" spacing={0}>
{/* sign out */}
<HStack
as={Link}
onClick={() => {
signOut().catch(console.error);
}}
px={4}
py={2}
spacing={4}
color="gray.500"
fontSize="sm"
>
<Icon as={BsBoxArrowRight} boxSize={6} />
<Text>Sign out</Text>
</HStack>
</VStack>
</PopoverContent>
</Popover>
</>
);
}

View File

@@ -20,7 +20,6 @@ export const CostTooltip = ({
color="gray.800" color="gray.800"
bgColor="gray.50" bgColor="gray.50"
borderWidth={1} borderWidth={1}
py={2}
hasArrow hasArrow
shouldWrapChildren shouldWrapChildren
label={ label={

View File

@@ -9,7 +9,15 @@ export const env = createEnv({
server: {
  DATABASE_URL: z.string().url(),
  NODE_ENV: z.enum(["development", "test", "production"]).default("development"),
+  RESTRICT_PRISMA_LOGS: z
+    .string()
+    .optional()
+    .default("false")
+    .transform((val) => val.toLowerCase() === "true"),
+  GITHUB_CLIENT_ID: z.string().min(1),
+  GITHUB_CLIENT_SECRET: z.string().min(1),
  OPENAI_API_KEY: z.string().min(1),
+  REPLICATE_API_TOKEN: z.string().default("placeholder"),
},
/**
@@ -19,11 +27,6 @@
 */
client: {
  NEXT_PUBLIC_POSTHOG_KEY: z.string().optional(),
-  NEXT_PUBLIC_IS_PUBLIC_PLAYGROUND: z
-    .string()
-    .optional()
-    .default("false")
-    .transform((val) => val.toLowerCase() === "true"),
  NEXT_PUBLIC_SOCKET_URL: z.string().url().default("http://localhost:3318"),
},
@@ -35,9 +38,12 @@
  DATABASE_URL: process.env.DATABASE_URL,
  NODE_ENV: process.env.NODE_ENV,
  OPENAI_API_KEY: process.env.OPENAI_API_KEY,
+  RESTRICT_PRISMA_LOGS: process.env.RESTRICT_PRISMA_LOGS,
  NEXT_PUBLIC_POSTHOG_KEY: process.env.NEXT_PUBLIC_POSTHOG_KEY,
-  NEXT_PUBLIC_IS_PUBLIC_PLAYGROUND: process.env.NEXT_PUBLIC_IS_PUBLIC_PLAYGROUND,
  NEXT_PUBLIC_SOCKET_URL: process.env.NEXT_PUBLIC_SOCKET_URL,
+  GITHUB_CLIENT_ID: process.env.GITHUB_CLIENT_ID,
+  GITHUB_CLIENT_SECRET: process.env.GITHUB_CLIENT_SECRET,
+  REPLICATE_API_TOKEN: process.env.REPLICATE_API_TOKEN,
},
/**
 * Run `build` or `dev` with `SKIP_ENV_VALIDATION` to skip env validation.

View File

@@ -0,0 +1,15 @@
import openaiChatCompletionFrontend from "./openai-ChatCompletion/frontend";
import replicateLlama2Frontend from "./replicate-llama2/frontend";
import { type SupportedProvider, type FrontendModelProvider } from "./types";
// TODO: make sure we get a typescript error if you forget to add a provider here
// Keep attributes here that need to be accessible from the frontend. We can't
// just include them in the default `modelProviders` object because it has some
// transient dependencies that can only be imported on the server.
const frontendModelProviders: Record<SupportedProvider, FrontendModelProvider<any, any>> = {
"openai/ChatCompletion": openaiChatCompletionFrontend,
"replicate/llama2": replicateLlama2Frontend,
};
export default frontendModelProviders;
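// One way to get the compile-time check the TODO above asks for, as a sketch
// reusing the imports in this file (not part of this commit): the `satisfies`
// operator (TypeScript 4.9+) verifies every SupportedProvider key is present
// without widening the entries' types.
const frontendModelProvidersChecked = {
  "openai/ChatCompletion": openaiChatCompletionFrontend,
  "replicate/llama2": replicateLlama2Frontend,
  // Omitting any provider named in SupportedProvider is now a type error.
} satisfies Record<SupportedProvider, FrontendModelProvider<any, any>>;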

View File

@@ -0,0 +1,36 @@
import { type JSONSchema4Object } from "json-schema";
import modelProviders from "./modelProviders";
import { compile } from "json-schema-to-typescript";
import dedent from "dedent";
export default async function generateTypes() {
const combinedSchema = {
type: "object",
properties: {} as Record<string, JSONSchema4Object>,
};
Object.entries(modelProviders).forEach(([id, provider]) => {
combinedSchema.properties[id] = provider.inputSchema;
});
const promptTypes = (
await compile(combinedSchema as JSONSchema4Object, "PromptTypes", {
additionalProperties: false,
bannerComment: dedent`
/**
* This type map defines the input types for each model provider.
*/
`,
})
).replace(/export interface PromptTypes/g, "interface PromptTypes");
return dedent`
${promptTypes}
declare function definePrompt<T extends keyof PromptTypes>(modelProvider: T, input: PromptTypes[T])
`;
}
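// For context, an illustrative sketch of roughly what generateTypes emits;
// the real output is compiled from each provider's JSON schema, so the exact
// member types differ.
interface PromptTypes {
  "openai/ChatCompletion": { model: string; messages: unknown[] /* ... */ };
  "replicate/llama2": { model: string; prompt: string /* ... */ };
}
// The prompt editor can then type-check constructor calls per provider:
declare function definePrompt<T extends keyof PromptTypes>(
  modelProvider: T,
  input: PromptTypes[T],
): void;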

View File

@@ -0,0 +1,10 @@
import openaiChatCompletion from "./openai-ChatCompletion";
import replicateLlama2 from "./replicate-llama2";
import { type SupportedProvider, type ModelProvider } from "./types";
const modelProviders: Record<SupportedProvider, ModelProvider<any, any, any>> = {
"openai/ChatCompletion": openaiChatCompletion,
"replicate/llama2": replicateLlama2,
};
export default modelProviders;

View File

@@ -0,0 +1,77 @@
/* eslint-disable @typescript-eslint/no-var-requires */
import YAML from "yaml";
import fs from "fs";
import path from "path";
import { openapiSchemaToJsonSchema } from "@openapi-contrib/openapi-schema-to-json-schema";
import $RefParser from "@apidevtools/json-schema-ref-parser";
import { type JSONObject } from "superjson/dist/types";
import assert from "assert";
import { type JSONSchema4Object } from "json-schema";
import { isObject } from "lodash-es";
// @ts-expect-error for some reason missing from types
import parserEstree from "prettier/plugins/estree";
import parserBabel from "prettier/plugins/babel";
import prettier from "prettier/standalone";
const OPENAPI_URL =
"https://raw.githubusercontent.com/openai/openai-openapi/0c432eb66fd0c758fd8b9bd69db41c1096e5f4db/openapi.yaml";
// Fetch the openapi document
const response = await fetch(OPENAPI_URL);
const openApiYaml = await response.text();
// Parse the yaml document
let schema = YAML.parse(openApiYaml) as JSONObject;
schema = openapiSchemaToJsonSchema(schema);
const jsonSchema = await $RefParser.dereference(schema);
assert("components" in jsonSchema);
const completionRequestSchema = jsonSchema.components.schemas
.CreateChatCompletionRequest as JSONSchema4Object;
// We need to do a bit of surgery here since the Monaco editor doesn't like
// the fact that the schema says `model` can be either a string or an enum,
// and displays a warning in the editor. Let's stick with just an enum for
// now and drop the string option.
assert(
"properties" in completionRequestSchema &&
isObject(completionRequestSchema.properties) &&
"model" in completionRequestSchema.properties &&
isObject(completionRequestSchema.properties.model),
);
const modelProperty = completionRequestSchema.properties.model;
assert(
"oneOf" in modelProperty &&
Array.isArray(modelProperty.oneOf) &&
modelProperty.oneOf.length === 2 &&
isObject(modelProperty.oneOf[1]) &&
"enum" in modelProperty.oneOf[1],
"Expected model to have oneOf length of 2",
);
modelProperty.type = "string";
modelProperty.enum = modelProperty.oneOf[1].enum;
delete modelProperty["oneOf"];
// The default of "inf" confuses the Typescript generator, so can just remove it
assert(
"max_tokens" in completionRequestSchema.properties &&
isObject(completionRequestSchema.properties.max_tokens) &&
"default" in completionRequestSchema.properties.max_tokens,
);
delete completionRequestSchema.properties.max_tokens["default"];
// Get the directory of the current script
const currentDirectory = path.dirname(import.meta.url).replace("file://", "");
// Write the JSON schema to a file in the current directory
fs.writeFileSync(
path.join(currentDirectory, "input.schema.json"),
await prettier.format(JSON.stringify(completionRequestSchema, null, 2), {
parser: "json",
plugins: [parserBabel, parserEstree],
}),
);

View File

@@ -0,0 +1,185 @@
{
"type": "object",
"properties": {
"model": {
"description": "ID of the model to use. See the [model endpoint compatibility](/docs/models/model-endpoint-compatibility) table for details on which models work with the Chat API.",
"example": "gpt-3.5-turbo",
"type": "string",
"enum": [
"gpt-4",
"gpt-4-0613",
"gpt-4-32k",
"gpt-4-32k-0613",
"gpt-3.5-turbo",
"gpt-3.5-turbo-16k",
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k-0613"
]
},
"messages": {
"description": "A list of messages comprising the conversation so far. [Example Python code](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb).",
"type": "array",
"minItems": 1,
"items": {
"type": "object",
"properties": {
"role": {
"type": "string",
"enum": ["system", "user", "assistant", "function"],
"description": "The role of the messages author. One of `system`, `user`, `assistant`, or `function`."
},
"content": {
"type": "string",
"description": "The contents of the message. `content` is required for all messages except assistant messages with function calls."
},
"name": {
"type": "string",
"description": "The name of the author of this message. `name` is required if role is `function`, and it should be the name of the function whose response is in the `content`. May contain a-z, A-Z, 0-9, and underscores, with a maximum length of 64 characters."
},
"function_call": {
"type": "object",
"description": "The name and arguments of a function that should be called, as generated by the model.",
"properties": {
"name": {
"type": "string",
"description": "The name of the function to call."
},
"arguments": {
"type": "string",
"description": "The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function."
}
}
}
},
"required": ["role"]
}
},
"functions": {
"description": "A list of functions the model may generate JSON inputs for.",
"type": "array",
"minItems": 1,
"items": {
"type": "object",
"properties": {
"name": {
"type": "string",
"description": "The name of the function to be called. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64."
},
"description": {
"type": "string",
"description": "The description of what the function does."
},
"parameters": {
"type": "object",
"description": "The parameters the functions accepts, described as a JSON Schema object. See the [guide](/docs/guides/gpt/function-calling) for examples, and the [JSON Schema reference](https://json-schema.org/understanding-json-schema/) for documentation about the format.",
"additionalProperties": true
}
},
"required": ["name"]
}
},
"function_call": {
"description": "Controls how the model responds to function calls. \"none\" means the model does not call a function, and responds to the end-user. \"auto\" means the model can pick between an end-user or calling a function. Specifying a particular function via `{\"name\":\\ \"my_function\"}` forces the model to call that function. \"none\" is the default when no functions are present. \"auto\" is the default if functions are present.",
"oneOf": [
{
"type": "string",
"enum": ["none", "auto"]
},
{
"type": "object",
"properties": {
"name": {
"type": "string",
"description": "The name of the function to call."
}
},
"required": ["name"]
}
]
},
"temperature": {
"type": "number",
"minimum": 0,
"maximum": 2,
"default": 1,
"example": 1,
"nullable": true,
"description": "What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.\n\nWe generally recommend altering this or `top_p` but not both.\n"
},
"top_p": {
"type": "number",
"minimum": 0,
"maximum": 1,
"default": 1,
"example": 1,
"nullable": true,
"description": "An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.\n\nWe generally recommend altering this or `temperature` but not both.\n"
},
"n": {
"type": "integer",
"minimum": 1,
"maximum": 128,
"default": 1,
"example": 1,
"nullable": true,
"description": "How many chat completion choices to generate for each input message."
},
"stream": {
"description": "If set, partial message deltas will be sent, like in ChatGPT. Tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream terminated by a `data: [DONE]` message. [Example Python code](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_stream_completions.ipynb).\n",
"type": "boolean",
"nullable": true,
"default": false
},
"stop": {
"description": "Up to 4 sequences where the API will stop generating further tokens.\n",
"default": null,
"oneOf": [
{
"type": "string",
"nullable": true
},
{
"type": "array",
"minItems": 1,
"maxItems": 4,
"items": {
"type": "string"
}
}
]
},
"max_tokens": {
"description": "The maximum number of [tokens](/tokenizer) to generate in the chat completion.\n\nThe total length of input tokens and generated tokens is limited by the model's context length. [Example Python code](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb) for counting tokens.\n",
"type": "integer"
},
"presence_penalty": {
"type": "number",
"default": 0,
"minimum": -2,
"maximum": 2,
"nullable": true,
"description": "Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.\n\n[See more information about frequency and presence penalties.](/docs/api-reference/parameter-details)\n"
},
"frequency_penalty": {
"type": "number",
"default": 0,
"minimum": -2,
"maximum": 2,
"nullable": true,
"description": "Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.\n\n[See more information about frequency and presence penalties.](/docs/api-reference/parameter-details)\n"
},
"logit_bias": {
"type": "object",
"x-oaiTypeLabel": "map",
"default": null,
"nullable": true,
"description": "Modify the likelihood of specified tokens appearing in the completion.\n\nAccepts a json object that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token.\n"
},
"user": {
"type": "string",
"example": "user-1234",
"description": "A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).\n"
}
},
"required": ["model", "messages"]
}

View File

@@ -0,0 +1,84 @@
import { type JsonValue } from "type-fest";
import { type SupportedModel } from ".";
import { type FrontendModelProvider } from "../types";
import { type ChatCompletion } from "openai/resources/chat";
const frontendModelProvider: FrontendModelProvider<SupportedModel, ChatCompletion> = {
name: "OpenAI ChatCompletion",
models: {
"gpt-4-0613": {
name: "GPT-4",
contextWindow: 8192,
promptTokenPrice: 0.00003,
completionTokenPrice: 0.00006,
speed: "medium",
provider: "openai/ChatCompletion",
learnMoreUrl: "https://openai.com/gpt-4",
},
"gpt-4-32k-0613": {
name: "GPT-4 32k",
contextWindow: 32768,
promptTokenPrice: 0.00006,
completionTokenPrice: 0.00012,
speed: "medium",
provider: "openai/ChatCompletion",
learnMoreUrl: "https://openai.com/gpt-4",
},
"gpt-3.5-turbo-0613": {
name: "GPT-3.5 Turbo",
contextWindow: 4096,
promptTokenPrice: 0.0000015,
completionTokenPrice: 0.000002,
speed: "fast",
provider: "openai/ChatCompletion",
learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
},
"gpt-3.5-turbo-16k-0613": {
name: "GPT-3.5 Turbo 16k",
contextWindow: 16384,
promptTokenPrice: 0.000003,
completionTokenPrice: 0.000004,
speed: "fast",
provider: "openai/ChatCompletion",
learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
},
},
normalizeOutput: (output) => {
const message = output.choices[0]?.message;
if (!message)
return {
type: "json",
value: output as unknown as JsonValue,
};
if (message.content) {
return {
type: "text",
value: message.content,
};
} else if (message.function_call) {
let args = message.function_call.arguments ?? "";
try {
args = JSON.parse(args);
} catch (e) {
// Ignore
}
return {
type: "json",
value: {
...message.function_call,
arguments: args,
},
};
} else {
return {
type: "json",
value: message as unknown as JsonValue,
};
}
},
};
export default frontendModelProvider;
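// A usage sketch with hypothetical data (not from this commit): a completion
// whose first choice is a function call normalizes to a JSON output, and the
// arguments string is parsed when it is valid JSON.
const normalized = frontendModelProvider.normalizeOutput({
  choices: [
    {
      index: 0,
      finish_reason: "function_call",
      message: {
        role: "assistant",
        function_call: { name: "extract_sentiment", arguments: '{"sentiment":"positive"}' },
      },
    },
  ],
} as unknown as ChatCompletion);
// normalized === { type: "json",
//   value: { name: "extract_sentiment", arguments: { sentiment: "positive" } } }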

View File

@@ -0,0 +1,140 @@
/* eslint-disable @typescript-eslint/no-unsafe-call */
import {
type ChatCompletionChunk,
type ChatCompletion,
type CompletionCreateParams,
} from "openai/resources/chat";
import { countOpenAIChatTokens } from "~/utils/countTokens";
import { type CompletionResponse } from "../types";
import { omit } from "lodash-es";
import { openai } from "~/server/utils/openai";
import { truthyFilter } from "~/utils/utils";
import { APIError } from "openai";
import frontendModelProvider from "./frontend";
import modelProvider, { type SupportedModel } from ".";
const mergeStreamedChunks = (
base: ChatCompletion | null,
chunk: ChatCompletionChunk,
): ChatCompletion => {
if (base === null) {
return mergeStreamedChunks({ ...chunk, choices: [] }, chunk);
}
const choices = [...base.choices];
for (const choice of chunk.choices) {
const baseChoice = choices.find((c) => c.index === choice.index);
if (baseChoice) {
baseChoice.finish_reason = choice.finish_reason ?? baseChoice.finish_reason;
baseChoice.message = baseChoice.message ?? { role: "assistant" };
if (choice.delta?.content)
baseChoice.message.content =
((baseChoice.message.content as string) ?? "") + (choice.delta.content ?? "");
if (choice.delta?.function_call) {
const fnCall = baseChoice.message.function_call ?? {};
fnCall.name =
((fnCall.name as string) ?? "") + ((choice.delta.function_call.name as string) ?? "");
fnCall.arguments =
((fnCall.arguments as string) ?? "") +
((choice.delta.function_call.arguments as string) ?? "");
}
} else {
choices.push({ ...omit(choice, "delta"), message: { role: "assistant", ...choice.delta } });
}
}
const merged: ChatCompletion = {
...base,
choices,
};
return merged;
};
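// An illustration of the merge behavior with hypothetical chunks (not from
// this commit): successive deltas for the same choice index are concatenated
// into a single assistant message.
const chunk1 = { id: "c1", choices: [{ index: 0, delta: { role: "assistant", content: "Hel" } }] };
const chunk2 = { id: "c1", choices: [{ index: 0, delta: { content: "lo" }, finish_reason: "stop" }] };
let acc: ChatCompletion | null = null;
acc = mergeStreamedChunks(acc, chunk1 as unknown as ChatCompletionChunk);
acc = mergeStreamedChunks(acc, chunk2 as unknown as ChatCompletionChunk);
// acc.choices[0].message is now { role: "assistant", content: "Hello" }
// and acc.choices[0].finish_reason is "stop".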
export async function getCompletion(
input: CompletionCreateParams,
onStream: ((partialOutput: ChatCompletion) => void) | null,
): Promise<CompletionResponse<ChatCompletion>> {
const start = Date.now();
let finalCompletion: ChatCompletion | null = null;
let promptTokens: number | undefined = undefined;
let completionTokens: number | undefined = undefined;
const modelName = modelProvider.getModel(input) as SupportedModel;
try {
if (onStream) {
const resp = await openai.chat.completions.create(
{ ...input, stream: true },
{
maxRetries: 0,
},
);
for await (const part of resp) {
finalCompletion = mergeStreamedChunks(finalCompletion, part);
onStream(finalCompletion);
}
if (!finalCompletion) {
return {
type: "error",
message: "Streaming failed to return a completion",
autoRetry: false,
};
}
try {
promptTokens = countOpenAIChatTokens(modelName, input.messages);
completionTokens = countOpenAIChatTokens(
modelName,
finalCompletion.choices.map((c) => c.message).filter(truthyFilter),
);
} catch (err) {
// TODO: handle this properly; the token-counting library may not support function calls.
console.error(err);
}
} else {
const resp = await openai.chat.completions.create(
{ ...input, stream: false },
{
maxRetries: 0,
},
);
finalCompletion = resp;
promptTokens = resp.usage?.prompt_tokens ?? 0;
completionTokens = resp.usage?.completion_tokens ?? 0;
}
const timeToComplete = Date.now() - start;
const { promptTokenPrice, completionTokenPrice } = frontendModelProvider.models[modelName];
let cost = undefined;
if (promptTokenPrice && completionTokenPrice && promptTokens && completionTokens) {
cost = promptTokens * promptTokenPrice + completionTokens * completionTokenPrice;
}
return {
type: "success",
statusCode: 200,
value: finalCompletion,
timeToComplete,
promptTokens,
completionTokens,
cost,
};
} catch (error: unknown) {
console.error("ERROR IS", error);
if (error instanceof APIError) {
return {
type: "error",
message: error.message,
autoRetry: error.status === 429 || error.status === 503,
statusCode: error.status,
};
} else {
console.error(error);
return {
type: "error",
message: (error as Error).message,
autoRetry: true,
};
}
}
}
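// A usage sketch for getCompletion (the real call sites elsewhere in this
// codebase may differ): stream partial output to the console, then inspect
// the discriminated-union response.
const response = await getCompletion(
  {
    model: "gpt-3.5-turbo",
    stream: true,
    messages: [{ role: "user", content: "Say hello" }],
  },
  (partial) => console.log(partial.choices[0]?.message?.content),
);
if (response.type === "success") {
  // When streaming, token counts and cost are estimated client-side.
  console.log(response.value, response.cost);
} else {
  console.error(response.message, response.autoRetry);
}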

View File

@@ -0,0 +1,45 @@
import { type JSONSchema4 } from "json-schema";
import { type ModelProvider } from "../types";
import inputSchema from "./codegen/input.schema.json";
import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat";
import { getCompletion } from "./getCompletion";
import frontendModelProvider from "./frontend";
const supportedModels = [
"gpt-4-0613",
"gpt-4-32k-0613",
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k-0613",
] as const;
export type SupportedModel = (typeof supportedModels)[number];
export type OpenaiChatModelProvider = ModelProvider<
SupportedModel,
CompletionCreateParams,
ChatCompletion
>;
const modelProvider: OpenaiChatModelProvider = {
getModel: (input) => {
if (supportedModels.includes(input.model as SupportedModel))
return input.model as SupportedModel;
const modelMaps: Record<string, SupportedModel> = {
"gpt-4": "gpt-4-0613",
"gpt-4-32k": "gpt-4-32k-0613",
"gpt-3.5-turbo": "gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k": "gpt-3.5-turbo-16k-0613",
};
if (input.model in modelMaps) return modelMaps[input.model] as SupportedModel;
return null;
},
inputSchema: inputSchema as JSONSchema4,
shouldStream: (input) => input.stream ?? false,
getCompletion,
...frontendModelProvider,
};
export default modelProvider;

View File

@@ -0,0 +1,42 @@
import { type SupportedModel, type ReplicateLlama2Output } from ".";
import { type FrontendModelProvider } from "../types";
const frontendModelProvider: FrontendModelProvider<SupportedModel, ReplicateLlama2Output> = {
name: "Replicate Llama2",
models: {
"7b-chat": {
name: "LLama 2 7B Chat",
contextWindow: 4096,
pricePerSecond: 0.0023,
speed: "fast",
provider: "replicate/llama2",
learnMoreUrl: "https://replicate.com/a16z-infra/llama7b-v2-chat",
},
"13b-chat": {
name: "LLama 2 13B Chat",
contextWindow: 4096,
pricePerSecond: 0.0023,
speed: "medium",
provider: "replicate/llama2",
learnMoreUrl: "https://replicate.com/a16z-infra/llama13b-v2-chat",
},
"70b-chat": {
name: "LLama 2 70B Chat",
contextWindow: 4096,
pricePerSecond: 0.0032,
speed: "slow",
provider: "replicate/llama2",
learnMoreUrl: "https://replicate.com/replicate/llama70b-v2-chat",
},
},
normalizeOutput: (output) => {
return {
type: "text",
value: output.join(""),
};
},
};
export default frontendModelProvider;

View File

@@ -0,0 +1,62 @@
import { env } from "~/env.mjs";
import { type ReplicateLlama2Input, type ReplicateLlama2Output } from ".";
import { type CompletionResponse } from "../types";
import Replicate from "replicate";
const replicate = new Replicate({
auth: env.REPLICATE_API_TOKEN || "",
});
const modelIds: Record<ReplicateLlama2Input["model"], string> = {
"7b-chat": "3725a659b5afff1a0ba9bead5fac3899d998feaad00e07032ca2b0e35eb14f8a",
"13b-chat": "5c785d117c5bcdd1928d5a9acb1ffa6272d6cf13fcb722e90886a0196633f9d3",
"70b-chat": "e951f18578850b652510200860fc4ea62b3b16fac280f83ff32282f87bbd2e48",
};
export async function getCompletion(
input: ReplicateLlama2Input,
onStream: ((partialOutput: string[]) => void) | null,
): Promise<CompletionResponse<ReplicateLlama2Output>> {
const start = Date.now();
const { model, stream, ...rest } = input;
try {
const prediction = await replicate.predictions.create({
version: modelIds[model],
input: rest,
});
console.log("stream?", onStream);
const interval = onStream
? // eslint-disable-next-line @typescript-eslint/no-misused-promises
setInterval(async () => {
const partialPrediction = await replicate.predictions.get(prediction.id);
if (partialPrediction.output) onStream(partialPrediction.output as ReplicateLlama2Output);
}, 500)
: null;
const resp = await replicate.wait(prediction, {});
if (interval) clearInterval(interval);
const timeToComplete = Date.now() - start;
if (resp.error) throw new Error(resp.error as string);
return {
type: "success",
statusCode: 200,
value: resp.output as ReplicateLlama2Output,
timeToComplete,
};
} catch (error: unknown) {
console.error("ERROR IS", error);
return {
type: "error",
message: (error as Error).message,
autoRetry: true,
};
}
}

View File

@@ -0,0 +1,70 @@
import { type ModelProvider } from "../types";
import frontendModelProvider from "./frontend";
import { getCompletion } from "./getCompletion";
const supportedModels = ["7b-chat", "13b-chat", "70b-chat"] as const;
export type SupportedModel = (typeof supportedModels)[number];
export type ReplicateLlama2Input = {
model: SupportedModel;
prompt: string;
stream?: boolean;
max_length?: number;
temperature?: number;
top_p?: number;
repetition_penalty?: number;
debug?: boolean;
};
export type ReplicateLlama2Output = string[];
export type ReplicateLlama2Provider = ModelProvider<
SupportedModel,
ReplicateLlama2Input,
ReplicateLlama2Output
>;
const modelProvider: ReplicateLlama2Provider = {
getModel: (input) => {
if (supportedModels.includes(input.model)) return input.model;
return null;
},
inputSchema: {
type: "object",
properties: {
model: {
type: "string",
enum: supportedModels as unknown as string[],
},
prompt: {
type: "string",
},
stream: {
type: "boolean",
},
max_length: {
type: "number",
},
temperature: {
type: "number",
},
top_p: {
type: "number",
},
repetition_penalty: {
type: "number",
},
debug: {
type: "boolean",
},
},
required: ["model", "prompt"],
},
shouldStream: (input) => input.stream ?? false,
getCompletion,
...frontendModelProvider,
};
export default modelProvider;
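Tying the schema and helpers together, a minimal input this provider accepts might look like the following (values are illustrative):

const input: ReplicateLlama2Input = {
  model: "13b-chat",
  prompt: "Summarize the plot of Hamlet in two sentences.",
  temperature: 0.7,
};

modelProvider.getModel(input); // => "13b-chat", since it appears in supportedModels
modelProvider.shouldStream(input); // => false, because `stream` is omitted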

View File

@@ -0,0 +1,66 @@
import { type JSONSchema4 } from "json-schema";
import { type JsonValue } from "type-fest";
import { z } from "zod";
const ZodSupportedProvider = z.union([
z.literal("openai/ChatCompletion"),
z.literal("replicate/llama2"),
]);
export type SupportedProvider = z.infer<typeof ZodSupportedProvider>;
export const ZodModel = z.object({
name: z.string(),
contextWindow: z.number(),
promptTokenPrice: z.number().optional(),
completionTokenPrice: z.number().optional(),
pricePerSecond: z.number().optional(),
speed: z.union([z.literal("fast"), z.literal("medium"), z.literal("slow")]),
provider: ZodSupportedProvider,
description: z.string().optional(),
learnMoreUrl: z.string().optional(),
});
export type Model = z.infer<typeof ZodModel>;
export type FrontendModelProvider<SupportedModels extends string, OutputSchema> = {
name: string;
models: Record<SupportedModels, Model>;
normalizeOutput: (output: OutputSchema) => NormalizedOutput;
};
export type CompletionResponse<T> =
| { type: "error"; message: string; autoRetry: boolean; statusCode?: number }
| {
type: "success";
value: T;
timeToComplete: number;
statusCode: number;
promptTokens?: number;
completionTokens?: number;
cost?: number;
};
export type ModelProvider<SupportedModels extends string, InputSchema, OutputSchema> = {
getModel: (input: InputSchema) => SupportedModels | null;
shouldStream: (input: InputSchema) => boolean;
inputSchema: JSONSchema4;
getCompletion: (
input: InputSchema,
onStream: ((partialOutput: OutputSchema) => void) | null,
) => Promise<CompletionResponse<OutputSchema>>;
// This is just a convenience for type inference, don't use it at runtime
_outputSchema?: OutputSchema | null;
} & FrontendModelProvider<SupportedModels, OutputSchema>;
export type NormalizedOutput =
| {
type: "text";
value: string;
}
| {
type: "json";
value: JsonValue;
};
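Because CompletionResponse is a discriminated union on type, callers can narrow it with a plain conditional; a small usage sketch (handleResponse is an illustrative name):

function handleResponse(resp: CompletionResponse<string[]>): string {
  if (resp.type === "error") {
    // autoRetry signals whether the caller should schedule another attempt.
    return `failed: ${resp.message} (retryable: ${resp.autoRetry})`;
  }
  // In the success branch, value and timeToComplete are guaranteed present.
  return `ok in ${resp.timeToComplete}ms: ${resp.value.join("")}`;
}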

View File

@@ -2,22 +2,32 @@ import { type Session } from "next-auth";
 import { SessionProvider } from "next-auth/react";
 import { type AppType } from "next/app";
 import { api } from "~/utils/api";
-import { ChakraProvider } from "@chakra-ui/react";
-import theme from "~/utils/theme";
 import Favicon from "~/components/Favicon";
 import "~/utils/analytics";
+import Head from "next/head";
+import { ChakraThemeProvider } from "~/theme/ChakraThemeProvider";
+import { SyncAppStore } from "~/state/sync";

 const MyApp: AppType<{ session: Session | null }> = ({
   Component,
   pageProps: { session, ...pageProps },
 }) => {
   return (
-    <SessionProvider session={session}>
-      <Favicon />
-      <ChakraProvider theme={theme}>
-        <Component {...pageProps} />
-      </ChakraProvider>
-    </SessionProvider>
+    <>
+      <Head>
+        <meta
+          name="viewport"
+          content="width=device-width, initial-scale=1, maximum-scale=1, user-scalable=0"
+        />
+      </Head>
+      <SessionProvider session={session}>
+        <SyncAppStore />
+        <Favicon />
+        <ChakraThemeProvider>
+          <Component {...pageProps} />
+        </ChakraThemeProvider>
+      </SessionProvider>
+    </>
   );
 };

View File

@@ -0,0 +1,23 @@
import { signIn, useSession } from "next-auth/react";
import { useRouter } from "next/router";
import { useEffect } from "react";
import AppShell from "~/components/nav/AppShell";
export default function SignIn() {
const session = useSession().data;
const router = useRouter();
useEffect(() => {
if (session) {
router.push("/experiments").catch(console.error);
} else if (session === null) {
signIn("github").catch(console.error);
}
}, [session, router]);
return (
<AppShell>
<div />
</AppShell>
);
}

View File

@@ -49,6 +49,10 @@ const DeleteButton = () => {
     onClose();
   }, [mutation, experiment.data?.id, router]);

+  useEffect(() => {
+    useAppStore.getState().sharedVariantEditor.loadMonaco().catch(console.error);
+  });
+
   return (
     <>
       <Button
@@ -124,6 +128,8 @@ export default function Experiment() {
     );
   }

+  const canModify = experiment.data?.access.canModify ?? false;
+
   return (
     <AppShell title={experiment.data?.label}>
       <VStack h="full">
@@ -143,37 +149,45 @@
             </Link>
           </BreadcrumbItem>
           <BreadcrumbItem isCurrentPage>
-            <Input
-              size="sm"
-              value={label}
-              onChange={(e) => setLabel(e.target.value)}
-              onBlur={onSaveLabel}
-              borderWidth={1}
-              borderColor="transparent"
-              fontSize={16}
-              px={0}
-              minW={{ base: 100, lg: 300 }}
-              flex={1}
-              _hover={{ borderColor: "gray.300" }}
-              _focus={{ borderColor: "blue.500", outline: "none" }}
-            />
+            {canModify ? (
+              <Input
+                size="sm"
+                value={label}
+                onChange={(e) => setLabel(e.target.value)}
+                onBlur={onSaveLabel}
+                borderWidth={1}
+                borderColor="transparent"
+                fontSize={16}
+                px={0}
+                minW={{ base: 100, lg: 300 }}
+                flex={1}
+                _hover={{ borderColor: "gray.300" }}
+                _focus={{ borderColor: "blue.500", outline: "none" }}
+              />
+            ) : (
+              <Text fontSize={16} px={0} minW={{ base: 100, lg: 300 }} flex={1}>
+                {experiment.data?.label}
+              </Text>
+            )}
           </BreadcrumbItem>
         </Breadcrumb>
-        <HStack>
-          <Button
-            size="sm"
-            variant={{ base: "outline", lg: "ghost" }}
-            colorScheme="gray"
-            fontWeight="normal"
-            onClick={openDrawer}
-          >
-            <Icon as={BsGearFill} boxSize={4} color="gray.600" />
-            <Text display={{ base: "none", lg: "block" }} ml={2}>
-              Edit Vars & Evals
-            </Text>
-          </Button>
-          <DeleteButton />
-        </HStack>
+        {canModify && (
+          <HStack>
+            <Button
+              size="sm"
+              variant={{ base: "outline", lg: "ghost" }}
+              colorScheme="gray"
+              fontWeight="normal"
+              onClick={openDrawer}
+            >
+              <Icon as={BsGearFill} boxSize={4} color="gray.600" />
+              <Text display={{ base: "none", lg: "block" }} ml={2}>
+                Edit Vars & Evals
+              </Text>
+            </Button>
+            <DeleteButton />
+          </HStack>
+        )}
       </Flex>
       <SettingsDrawer />
       <Box w="100%" overflowX="auto" flex={1}>

View File

@@ -1,25 +1,53 @@
 import {
   SimpleGrid,
-  HStack,
   Icon,
   VStack,
   Breadcrumb,
   BreadcrumbItem,
   Flex,
+  Center,
+  Text,
+  Link,
+  HStack,
 } from "@chakra-ui/react";
 import { RiFlaskLine } from "react-icons/ri";
 import AppShell from "~/components/nav/AppShell";
 import { api } from "~/utils/api";
-import { NewExperimentButton } from "~/components/experiments/NewExperimentButton";
-import { ExperimentCard } from "~/components/experiments/ExperimentCard";
+import { ExperimentCard, NewExperimentCard } from "~/components/experiments/ExperimentCard";
+import { signIn, useSession } from "next-auth/react";

 export default function ExperimentsPage() {
   const experiments = api.experiments.list.useQuery();
+  const user = useSession().data;
+  const authLoading = useSession().status === "loading";
+
+  if (user === null || authLoading) {
+    return (
+      <AppShell title="Experiments">
+        <Center h="100%">
+          {!authLoading && (
+            <Text>
+              <Link
+                onClick={() => {
+                  signIn("github").catch(console.error);
+                }}
+                textDecor="underline"
+              >
+                Sign in
+              </Link>{" "}
+              to view or create new experiments!
+            </Text>
+          )}
+        </Center>
+      </AppShell>
+    );
+  }

   return (
-    <AppShell>
-      <VStack alignItems={"flex-start"} m={4} mt={1}>
-        <HStack w="full" justifyContent="space-between" mb={4}>
+    <AppShell title="Experiments">
+      <VStack alignItems={"flex-start"} px={4} py={2}>
+        <HStack minH={8} align="center">
           <Breadcrumb flex={1}>
             <BreadcrumbItem>
               <Flex alignItems="center">
@@ -27,9 +55,9 @@ export default function ExperimentsPage() {
               </Flex>
             </BreadcrumbItem>
           </Breadcrumb>
-          <NewExperimentButton mr={4} borderRadius={8} />
         </HStack>
         <SimpleGrid w="full" columns={{ base: 1, md: 2, lg: 3, xl: 4 }} spacing={8} p="4">
+          <NewExperimentCard />
           {experiments?.data?.map((exp) => <ExperimentCard key={exp.id} exp={exp} />)}
         </SimpleGrid>
       </VStack>

View File

@@ -1,7 +1,7 @@
 import { type CompletionCreateParams } from "openai/resources/chat";
 import { prisma } from "../db";
 import { openai } from "../utils/openai";
-import { pick } from "lodash";
+import { pick } from "lodash-es";

 type AxiosError = {
   response?: {

View File

@@ -1,68 +1,96 @@
-import { EvaluationMatchType } from "@prisma/client";
+import { EvalType } from "@prisma/client";
 import { z } from "zod";
-import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
+import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
 import { prisma } from "~/server/db";
-import { reevaluateEvaluation } from "~/server/utils/evaluations";
+import { runAllEvals } from "~/server/utils/evaluations";
+import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";

 export const evaluationsRouter = createTRPCRouter({
-  list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
-    return await prisma.evaluation.findMany({
-      where: {
-        experimentId: input.experimentId,
-      },
-      orderBy: { createdAt: "asc" },
-    });
-  }),
+  list: publicProcedure
+    .input(z.object({ experimentId: z.string() }))
+    .query(async ({ input, ctx }) => {
+      await requireCanViewExperiment(input.experimentId, ctx);
+
+      return await prisma.evaluation.findMany({
+        where: {
+          experimentId: input.experimentId,
+        },
+        orderBy: { createdAt: "asc" },
+      });
+    }),

-  create: publicProcedure
+  create: protectedProcedure
     .input(
       z.object({
         experimentId: z.string(),
-        name: z.string(),
-        matchString: z.string(),
-        matchType: z.nativeEnum(EvaluationMatchType),
+        label: z.string(),
+        value: z.string(),
+        evalType: z.nativeEnum(EvalType),
       }),
     )
-    .mutation(async ({ input }) => {
-      const evaluation = await prisma.evaluation.create({
+    .mutation(async ({ input, ctx }) => {
+      await requireCanModifyExperiment(input.experimentId, ctx);
+
+      await prisma.evaluation.create({
         data: {
           experimentId: input.experimentId,
-          name: input.name,
-          matchString: input.matchString,
-          matchType: input.matchType,
+          label: input.label,
+          value: input.value,
+          evalType: input.evalType,
         },
       });
-      await reevaluateEvaluation(evaluation);
+
+      // TODO: this may be a bad UX for slow evals (eg. GPT-4 evals) Maybe need
+      // to kick off a background job or something instead
+      await runAllEvals(input.experimentId);
     }),

-  update: publicProcedure
+  update: protectedProcedure
     .input(
       z.object({
         id: z.string(),
         updates: z.object({
-          name: z.string().optional(),
-          matchString: z.string().optional(),
-          matchType: z.nativeEnum(EvaluationMatchType).optional(),
+          label: z.string().optional(),
+          value: z.string().optional(),
+          evalType: z.nativeEnum(EvalType).optional(),
         }),
       }),
     )
-    .mutation(async ({ input }) => {
-      await prisma.evaluation.update({
+    .mutation(async ({ input, ctx }) => {
+      const { experimentId } = await prisma.evaluation.findUniqueOrThrow({
+        where: { id: input.id },
+      });
+      await requireCanModifyExperiment(experimentId, ctx);
+
+      const evaluation = await prisma.evaluation.update({
         where: { id: input.id },
         data: {
-          name: input.updates.name,
-          matchString: input.updates.matchString,
-          matchType: input.updates.matchType,
+          label: input.updates.label,
+          value: input.updates.value,
+          evalType: input.updates.evalType,
         },
       });
-      await reevaluateEvaluation(
-        await prisma.evaluation.findUniqueOrThrow({ where: { id: input.id } }),
-      );
+
+      await prisma.outputEvaluation.deleteMany({
+        where: {
+          evaluationId: evaluation.id,
+        },
+      });
+
+      // Re-run all evals. Other eval results will already be cached, so this
+      // should only re-run the updated one.
+      await runAllEvals(evaluation.experimentId);
     }),

-  delete: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => {
-    await prisma.evaluation.delete({
-      where: { id: input.id },
-    });
-  }),
+  delete: protectedProcedure
+    .input(z.object({ id: z.string() }))
+    .mutation(async ({ input, ctx }) => {
+      const { experimentId } = await prisma.evaluation.findUniqueOrThrow({
+        where: { id: input.id },
+      });
+      await requireCanModifyExperiment(experimentId, ctx);
+
+      await prisma.evaluation.delete({
+        where: { id: input.id },
+      });
+    }),
 });

View File

@@ -1,14 +1,32 @@
 import { z } from "zod";
-import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
+import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
 import { prisma } from "~/server/db";
 import dedent from "dedent";
 import { generateNewCell } from "~/server/utils/generateNewCell";
+import {
+  canModifyExperiment,
+  requireCanModifyExperiment,
+  requireCanViewExperiment,
+  requireNothing,
+} from "~/utils/accessControl";
+import userOrg from "~/server/utils/userOrg";
+import generateTypes from "~/modelProviders/generateTypes";

 export const experimentsRouter = createTRPCRouter({
-  list: publicProcedure.query(async () => {
+  list: protectedProcedure.query(async ({ ctx }) => {
+    // Anyone can list experiments
+    requireNothing(ctx);
+
     const experiments = await prisma.experiment.findMany({
+      where: {
+        organization: {
+          OrganizationUser: {
+            some: { userId: ctx.session.user.id },
+          },
+        },
+      },
       orderBy: {
-        sortIndex: "asc",
+        sortIndex: "desc",
       },
     });
@@ -40,15 +58,29 @@ export const experimentsRouter = createTRPCRouter({
     return experimentsWithCounts;
   }),

-  get: publicProcedure.input(z.object({ id: z.string() })).query(async ({ input }) => {
-    return await prisma.experiment.findFirst({
-      where: {
-        id: input.id,
-      },
-    });
+  get: publicProcedure.input(z.object({ id: z.string() })).query(async ({ input, ctx }) => {
+    await requireCanViewExperiment(input.id, ctx);
+    const experiment = await prisma.experiment.findFirstOrThrow({
+      where: { id: input.id },
+    });
+
+    const canModify = ctx.session?.user.id
+      ? await canModifyExperiment(experiment.id, ctx.session?.user.id)
+      : false;
+
+    return {
+      ...experiment,
+      access: {
+        canView: true,
+        canModify,
+      },
+    };
   }),

-  create: publicProcedure.input(z.object({})).mutation(async () => {
+  create: protectedProcedure.input(z.object({})).mutation(async ({ ctx }) => {
+    // Anyone can create an experiment
+    requireNothing(ctx);
+
     const maxSortIndex =
       (
         await prisma.experiment.aggregate({
@@ -62,47 +94,85 @@ export const experimentsRouter = createTRPCRouter({
       data: {
         sortIndex: maxSortIndex + 1,
         label: `Experiment ${maxSortIndex + 1}`,
+        organizationId: (await userOrg(ctx.session.user.id)).id,
       },
     });

-    const [variant, _, scenario] = await prisma.$transaction([
+    const [variant, _, scenario1, scenario2, scenario3] = await prisma.$transaction([
       prisma.promptVariant.create({
         data: {
           experimentId: exp.id,
           label: "Prompt Variant 1",
           sortIndex: 0,
-          constructFn: dedent`prompt = {
-            model: "gpt-3.5-turbo-0613",
-            stream: true,
-            messages: [{ role: "system", content: ${"`Return '${scenario.text}'`"} }],
-          }`,
+          // The interpolated $ is necessary until dedent incorporates
+          // https://github.com/dmnd/dedent/pull/46
+          constructFn: dedent`
+            /**
+             * Use Javascript to define an OpenAI chat completion
+             * (https://platform.openai.com/docs/api-reference/chat/create).
+             *
+             * You have access to the current scenario in the \`scenario\`
+             * variable.
+             */

+            definePrompt("openai/ChatCompletion", {
+              model: "gpt-3.5-turbo-0613",
+              stream: true,
+              messages: [
+                {
+                  role: "system",
+                  content: \`Write 'Start experimenting!' in ${"$"}{scenario.language}\`,
+                },
+              ],
+            });`,
           model: "gpt-3.5-turbo-0613",
+          modelProvider: "openai/ChatCompletion",
+          constructFnVersion: 2,
         },
       }),
       prisma.templateVariable.create({
         data: {
           experimentId: exp.id,
-          label: "text",
+          label: "language",
         },
       }),
       prisma.testScenario.create({
         data: {
           experimentId: exp.id,
           variableValues: {
-            text: "This is a test scenario.",
+            language: "English",
           },
         },
       }),
+      prisma.testScenario.create({
+        data: {
+          experimentId: exp.id,
+          variableValues: {
+            language: "Spanish",
+          },
+        },
+      }),
+      prisma.testScenario.create({
+        data: {
+          experimentId: exp.id,
+          variableValues: {
+            language: "German",
+          },
+        },
+      }),
     ]);

-    await generateNewCell(variant.id, scenario.id);
+    await generateNewCell(variant.id, scenario1.id);
+    await generateNewCell(variant.id, scenario2.id);
+    await generateNewCell(variant.id, scenario3.id);

     return exp;
   }),

-  update: publicProcedure
+  update: protectedProcedure
     .input(z.object({ id: z.string(), updates: z.object({ label: z.string() }) }))
-    .mutation(async ({ input }) => {
+    .mutation(async ({ input, ctx }) => {
+      await requireCanModifyExperiment(input.id, ctx);
+
       return await prisma.experiment.update({
         where: {
           id: input.id,
@@ -113,11 +183,21 @@ export const experimentsRouter = createTRPCRouter({
     });
   }),

-  delete: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => {
-    await prisma.experiment.delete({
-      where: {
-        id: input.id,
-      },
-    });
-  }),
+  delete: protectedProcedure
+    .input(z.object({ id: z.string() }))
+    .mutation(async ({ input, ctx }) => {
+      await requireCanModifyExperiment(input.id, ctx);
+
+      await prisma.experiment.delete({
+        where: {
+          id: input.id,
+        },
+      });
+    }),
+
+  // Keeping these on `experiment` for now because we might want to limit the
+  // providers based on your account/experiment
+  promptTypes: publicProcedure.query(async () => {
+    return await generateTypes();
+  }),
 });
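The ${"$"} interpolation in the seed constructFn is a deliberate escape: interpolating the string "$" emits a literal ${scenario.language} placeholder into the stored function text instead of evaluating it at seed time, per the linked dedent PR. A minimal illustration of the trick:

import dedent from "dedent";

// Interpolating "$" leaves a literal ${...} in the output string, so the
// placeholder survives into the stored prompt-construction function.
const stored = dedent`content: \`Hello in ${"$"}{scenario.language}\``;
// stored === "content: `Hello in ${scenario.language}`"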

View File

@@ -1,102 +1,174 @@
-import dedent from "dedent";
-import { isObject } from "lodash";
 import { z } from "zod";
-import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
+import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
 import { prisma } from "~/server/db";
 import { generateNewCell } from "~/server/utils/generateNewCell";
-import { OpenAIChatModel } from "~/server/types";
-import { constructPrompt } from "~/server/utils/constructPrompt";
 import userError from "~/server/utils/error";
 import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
-import { calculateTokenCost } from "~/utils/calculateTokenCost";
+import { reorderPromptVariants } from "~/server/utils/reorderPromptVariants";
+import { type PromptVariant } from "@prisma/client";
+import { deriveNewConstructFn } from "~/server/utils/deriveNewContructFn";
+import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
+import parseConstructFn from "~/server/utils/parseConstructFn";
+import { ZodModel } from "~/modelProviders/types";

 export const promptVariantsRouter = createTRPCRouter({
-  list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
-    return await prisma.promptVariant.findMany({
-      where: {
-        experimentId: input.experimentId,
-        visible: true,
-      },
-      orderBy: { sortIndex: "asc" },
-    });
-  }),
+  list: publicProcedure
+    .input(z.object({ experimentId: z.string() }))
+    .query(async ({ input, ctx }) => {
+      await requireCanViewExperiment(input.experimentId, ctx);
+
+      return await prisma.promptVariant.findMany({
+        where: {
+          experimentId: input.experimentId,
+          visible: true,
+        },
+        orderBy: { sortIndex: "asc" },
+      });
+    }),

-  stats: publicProcedure.input(z.object({ variantId: z.string() })).query(async ({ input }) => {
-    const variant = await prisma.promptVariant.findUnique({
-      where: {
-        id: input.variantId,
-      },
-    });
-    if (!variant) {
-      throw new Error(`Prompt Variant with id ${input.variantId} does not exist`);
-    }
-
-    const evalResults = await prisma.evaluationResult.findMany({
-      where: {
-        promptVariantId: input.variantId,
-      },
-      include: { evaluation: true },
-    });
-
-    const scenarioCount = await prisma.testScenario.count({
-      where: {
-        experimentId: variant.experimentId,
-        visible: true,
-      },
-    });
-    const outputCount = await prisma.scenarioVariantCell.count({
-      where: {
-        promptVariantId: input.variantId,
-        testScenario: { visible: true },
-        modelOutput: {
-          isNot: null,
-        },
-      },
-    });
-
-    const overallTokens = await prisma.modelOutput.aggregate({
-      where: {
-        scenarioVariantCell: {
-          promptVariantId: input.variantId,
-          testScenario: {
-            visible: true,
-          },
-        },
-      },
-      _sum: {
-        promptTokens: true,
-        completionTokens: true,
-      },
-    });
-    const promptTokens = overallTokens._sum?.promptTokens ?? 0;
-    const overallPromptCost = calculateTokenCost(variant.model, promptTokens);
-    const completionTokens = overallTokens._sum?.completionTokens ?? 0;
-    const overallCompletionCost = calculateTokenCost(variant.model, completionTokens, true);
-    const overallCost = overallPromptCost + overallCompletionCost;
-
-    return { evalResults, promptTokens, completionTokens, overallCost, scenarioCount, outputCount };
-  }),
+  stats: publicProcedure
+    .input(z.object({ variantId: z.string() }))
+    .query(async ({ input, ctx }) => {
+      const variant = await prisma.promptVariant.findUnique({
+        where: {
+          id: input.variantId,
+        },
+      });
+      if (!variant) {
+        throw new Error(`Prompt Variant with id ${input.variantId} does not exist`);
+      }
+      await requireCanViewExperiment(variant.experimentId, ctx);
+
+      const outputEvals = await prisma.outputEvaluation.groupBy({
+        by: ["evaluationId"],
+        _sum: {
+          result: true,
+        },
+        _count: {
+          id: true,
+        },
+        where: {
+          modelOutput: {
+            scenarioVariantCell: {
+              promptVariant: {
+                id: input.variantId,
+                visible: true,
+              },
+              testScenario: {
+                visible: true,
+              },
+            },
+          },
+        },
+      });
+
+      const evals = await prisma.evaluation.findMany({
+        where: {
+          experimentId: variant.experimentId,
+        },
+      });
+
+      const evalResults = evals.map((evalItem) => {
+        const evalResult = outputEvals.find(
+          (outputEval) => outputEval.evaluationId === evalItem.id,
+        );
+        return {
+          id: evalItem.id,
+          label: evalItem.label,
+          passCount: evalResult?._sum?.result ?? 0,
+          totalCount: evalResult?._count?.id ?? 1,
+        };
+      });
+
+      const scenarioCount = await prisma.testScenario.count({
+        where: {
+          experimentId: variant.experimentId,
+          visible: true,
+        },
+      });
+      const outputCount = await prisma.scenarioVariantCell.count({
+        where: {
+          promptVariantId: input.variantId,
+          testScenario: { visible: true },
+          modelOutput: {
+            is: {},
+          },
+        },
+      });
+
+      const overallTokens = await prisma.modelOutput.aggregate({
+        where: {
+          scenarioVariantCell: {
+            promptVariantId: input.variantId,
+            testScenario: {
+              visible: true,
+            },
+          },
+        },
+        _sum: {
+          cost: true,
+          promptTokens: true,
+          completionTokens: true,
+        },
+      });
+      const promptTokens = overallTokens._sum?.promptTokens ?? 0;
+      const completionTokens = overallTokens._sum?.completionTokens ?? 0;
+
+      const awaitingRetrievals = !!(await prisma.scenarioVariantCell.findFirst({
+        where: {
+          promptVariantId: input.variantId,
+          testScenario: { visible: true },
+          // Check if is PENDING or IN_PROGRESS
+          retrievalStatus: {
+            in: ["PENDING", "IN_PROGRESS"],
+          },
+        },
+      }));
+
+      return {
+        evalResults,
+        promptTokens,
+        completionTokens,
+        overallCost: overallTokens._sum?.cost ?? 0,
+        scenarioCount,
+        outputCount,
+        awaitingRetrievals,
+      };
+    }),

-  create: publicProcedure
+  create: protectedProcedure
     .input(
       z.object({
         experimentId: z.string(),
+        variantId: z.string().optional(),
+        newModel: ZodModel.optional(),
       }),
     )
-    .mutation(async ({ input }) => {
-      const lastVariant = await prisma.promptVariant.findFirst({
-        where: {
-          experimentId: input.experimentId,
-          visible: true,
-        },
-        orderBy: {
-          sortIndex: "desc",
-        },
-      });
+    .mutation(async ({ input, ctx }) => {
+      await requireCanViewExperiment(input.experimentId, ctx);
+
+      let originalVariant: PromptVariant | null = null;
+      if (input.variantId) {
+        originalVariant = await prisma.promptVariant.findUnique({
+          where: {
+            id: input.variantId,
+          },
+        });
+      } else {
+        originalVariant = await prisma.promptVariant.findFirst({
+          where: {
+            experimentId: input.experimentId,
+            visible: true,
+          },
+          orderBy: {
+            sortIndex: "desc",
+          },
+        });
+      }

       const largestSortIndex =
         (
           await prisma.promptVariant.aggregate({
@@ -109,24 +181,22 @@ export const promptVariantsRouter = createTRPCRouter({
         })
       )._max?.sortIndex ?? 0;

+      const newVariantLabel =
+        input.variantId && originalVariant
+          ? `${originalVariant?.label} Copy`
+          : `Prompt Variant ${largestSortIndex + 2}`;
+
+      const newConstructFn = await deriveNewConstructFn(originalVariant, input.newModel);
+
       const createNewVariantAction = prisma.promptVariant.create({
         data: {
           experimentId: input.experimentId,
-          label: `Prompt Variant ${largestSortIndex + 2}`,
-          sortIndex: (lastVariant?.sortIndex ?? 0) + 1,
-          constructFn:
-            lastVariant?.constructFn ??
-            dedent`
-              prompt = {
-                model: "gpt-3.5-turbo",
-                messages: [
-                  {
-                    role: "system",
-                    content: "Return 'Hello, world!'",
-                  }
-                ]
-              }`,
-          model: lastVariant?.model ?? "gpt-3.5-turbo",
+          label: newVariantLabel,
+          sortIndex: (originalVariant?.sortIndex ?? 0) + 1,
+          constructFn: newConstructFn,
+          constructFnVersion: 2,
+          model: originalVariant?.model ?? "gpt-3.5-turbo",
+          modelProvider: originalVariant?.modelProvider ?? "openai/ChatCompletion",
         },
       });
@@ -135,6 +205,11 @@ export const promptVariantsRouter = createTRPCRouter({
       recordExperimentUpdated(input.experimentId),
     ]);

+      if (originalVariant) {
+        // Insert new variant to right of original variant
+        await reorderPromptVariants(newVariant.id, originalVariant.id, true);
+      }
+
       const scenarios = await prisma.testScenario.findMany({
         where: {
           experimentId: input.experimentId,
@@ -149,7 +224,7 @@ export const promptVariantsRouter = createTRPCRouter({
       return newVariant;
     }),

-  update: publicProcedure
+  update: protectedProcedure
     .input(
       z.object({
         id: z.string(),
@@ -158,7 +233,7 @@ export const promptVariantsRouter = createTRPCRouter({
         }),
       }),
     )
-    .mutation(async ({ input }) => {
+    .mutation(async ({ input, ctx }) => {
       const existing = await prisma.promptVariant.findUnique({
         where: {
           id: input.id,
@@ -169,6 +244,8 @@ export const promptVariantsRouter = createTRPCRouter({
         throw new Error(`Prompt Variant with id ${input.id} does not exist`);
       }

+      await requireCanModifyExperiment(existing.experimentId, ctx);
+
       const updatePromptVariantAction = prisma.promptVariant.update({
         where: {
           id: input.id,
@@ -184,13 +261,18 @@ export const promptVariantsRouter = createTRPCRouter({
       return updatedPromptVariant;
     }),

-  hide: publicProcedure
+  hide: protectedProcedure
     .input(
       z.object({
         id: z.string(),
       }),
     )
-    .mutation(async ({ input }) => {
+    .mutation(async ({ input, ctx }) => {
+      const { experimentId } = await prisma.promptVariant.findUniqueOrThrow({
+        where: { id: input.id },
+      });
+      await requireCanModifyExperiment(experimentId, ctx);
+
       const updatedPromptVariant = await prisma.promptVariant.update({
         where: { id: input.id },
         data: { visible: false, experiment: { update: { updatedAt: new Date() } } },
@@ -199,43 +281,63 @@ export const promptVariantsRouter = createTRPCRouter({
       return updatedPromptVariant;
     }),

-  replaceVariant: publicProcedure
+  getModifiedPromptFn: protectedProcedure
+    .input(
+      z.object({
+        id: z.string(),
+        instructions: z.string().optional(),
+        newModel: ZodModel.optional(),
+      }),
+    )
+    .mutation(async ({ input, ctx }) => {
+      const existing = await prisma.promptVariant.findUniqueOrThrow({
+        where: {
+          id: input.id,
+        },
+      });
+
+      await requireCanModifyExperiment(existing.experimentId, ctx);
+
+      const constructedPrompt = await parseConstructFn(existing.constructFn);
+
+      if ("error" in constructedPrompt) {
+        return userError(constructedPrompt.error);
+      }
+
+      const promptConstructionFn = await deriveNewConstructFn(
+        existing,
+        input.newModel,
+        input.instructions,
+      );
+
+      // TODO: Validate promptConstructionFn
+      // TODO: Record in some sort of history
+
+      return promptConstructionFn;
+    }),
+
+  replaceVariant: protectedProcedure
     .input(
       z.object({
         id: z.string(),
         constructFn: z.string(),
       }),
     )
-    .mutation(async ({ input }) => {
-      const existing = await prisma.promptVariant.findUnique({
+    .mutation(async ({ input, ctx }) => {
+      const existing = await prisma.promptVariant.findUniqueOrThrow({
         where: {
           id: input.id,
         },
       });
+      await requireCanModifyExperiment(existing.experimentId, ctx);

       if (!existing) {
         throw new Error(`Prompt Variant with id ${input.id} does not exist`);
       }

-      let model = existing.model;
-      try {
-        const contructedPrompt = await constructPrompt({ constructFn: input.constructFn }, null);
-
-        if (!isObject(contructedPrompt)) {
-          return userError("Prompt is not an object");
-        }
-
-        if (!("model" in contructedPrompt)) {
-          return userError("Prompt does not define a model");
-        }
-
-        if (
-          typeof contructedPrompt.model !== "string" ||
-          !(contructedPrompt.model in OpenAIChatModel)
-        ) {
-          return userError("Prompt defines an invalid model");
-        }
-
-        model = contructedPrompt.model;
-      } catch (e) {
-        return userError((e as Error).message);
-      }
+      const parsedPrompt = await parseConstructFn(input.constructFn);
+
+      if ("error" in parsedPrompt) {
+        return userError(parsedPrompt.error);
+      }

       // Create a duplicate with only the config changed
@@ -246,7 +348,9 @@ export const promptVariantsRouter = createTRPCRouter({
           sortIndex: existing.sortIndex,
           uiId: existing.uiId,
           constructFn: input.constructFn,
-          model,
+          constructFnVersion: 2,
+          modelProvider: parsedPrompt.modelProvider,
+          model: parsedPrompt.model,
         },
       });
@@ -279,72 +383,19 @@ export const promptVariantsRouter = createTRPCRouter({
       return { status: "ok" } as const;
     }),

-  reorder: publicProcedure
+  reorder: protectedProcedure
     .input(
       z.object({
         draggedId: z.string(),
         droppedId: z.string(),
       }),
     )
-    .mutation(async ({ input }) => {
-      const dragged = await prisma.promptVariant.findUnique({
-        where: {
-          id: input.draggedId,
-        },
-      });
-
-      const dropped = await prisma.promptVariant.findUnique({
-        where: {
-          id: input.droppedId,
-        },
-      });
-
-      if (!dragged || !dropped || dragged.experimentId !== dropped.experimentId) {
-        throw new Error(
-          `Prompt Variant with id ${input.draggedId} or ${input.droppedId} does not exist`,
-        );
-      }
-
-      const visibleItems = await prisma.promptVariant.findMany({
-        where: {
-          experimentId: dragged.experimentId,
-          visible: true,
-        },
-        orderBy: {
-          sortIndex: "asc",
-        },
-      });
-
-      // Remove the dragged item from its current position
-      const orderedItems = visibleItems.filter((item) => item.id !== dragged.id);
-
-      // Find the index of the dragged item and the dropped item
-      const dragIndex = visibleItems.findIndex((item) => item.id === dragged.id);
-      const dropIndex = visibleItems.findIndex((item) => item.id === dropped.id);
-
-      // Determine the new index for the dragged item
-      let newIndex;
-      if (dragIndex < dropIndex) {
-        newIndex = dropIndex + 1; // Insert after the dropped item
-      } else {
-        newIndex = dropIndex; // Insert before the dropped item
-      }
-
-      // Insert the dragged item at the new position
-      orderedItems.splice(newIndex, 0, dragged);
-
-      // Now, we need to update all the items with their new sortIndex
-      await prisma.$transaction(
-        orderedItems.map((item, index) => {
-          return prisma.promptVariant.update({
-            where: {
-              id: item.id,
-            },
-            data: {
-              sortIndex: index,
-            },
-          });
-        }),
-      );
+    .mutation(async ({ input, ctx }) => {
+      const { experimentId } = await prisma.promptVariant.findUniqueOrThrow({
+        where: { id: input.draggedId },
+      });
+      await requireCanModifyExperiment(experimentId, ctx);
+
+      await reorderPromptVariants(input.draggedId, input.droppedId);
     }),
 });

View File

@@ -1,8 +1,9 @@
 import { z } from "zod";
-import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
+import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
 import { prisma } from "~/server/db";
 import { generateNewCell } from "~/server/utils/generateNewCell";
 import { queueLLMRetrievalTask } from "~/server/utils/queueLLMRetrievalTask";
+import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";

 export const scenarioVariantCellsRouter = createTRPCRouter({
   get: publicProcedure
@@ -12,7 +13,12 @@ export const scenarioVariantCellsRouter = createTRPCRouter({
         variantId: z.string(),
       }),
     )
-    .query(async ({ input }) => {
+    .query(async ({ input, ctx }) => {
+      const { experimentId } = await prisma.testScenario.findUniqueOrThrow({
+        where: { id: input.scenarioId },
+      });
+      await requireCanViewExperiment(experimentId, ctx);
+
       return await prisma.scenarioVariantCell.findUnique({
         where: {
           promptVariantId_testScenarioId: {
@@ -21,18 +27,34 @@ export const scenarioVariantCellsRouter = createTRPCRouter({
           },
         },
         include: {
-          modelOutput: true,
+          modelOutput: {
+            include: {
+              outputEvaluation: {
+                include: {
+                  evaluation: {
+                    select: { label: true },
+                  },
+                },
+              },
+            },
+          },
         },
       });
     }),

-  forceRefetch: publicProcedure
+  forceRefetch: protectedProcedure
     .input(
       z.object({
         scenarioId: z.string(),
         variantId: z.string(),
       }),
     )
-    .mutation(async ({ input }) => {
+    .mutation(async ({ input, ctx }) => {
+      const { experimentId } = await prisma.testScenario.findUniqueOrThrow({
+        where: { id: input.scenarioId },
+      });
+      await requireCanModifyExperiment(experimentId, ctx);
+
       const cell = await prisma.scenarioVariantCell.findUnique({
         where: {
           promptVariantId_testScenarioId: {

View File

@@ -1,48 +1,54 @@
 import { z } from "zod";
-import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
+import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
 import { prisma } from "~/server/db";
 import { autogenerateScenarioValues } from "../autogen";
 import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
-import { reevaluateAll } from "~/server/utils/evaluations";
+import { runAllEvals } from "~/server/utils/evaluations";
 import { generateNewCell } from "~/server/utils/generateNewCell";
+import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";

 export const scenariosRouter = createTRPCRouter({
-  list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
-    return await prisma.testScenario.findMany({
-      where: {
-        experimentId: input.experimentId,
-        visible: true,
-      },
-      orderBy: {
-        sortIndex: "asc",
-      },
-    });
-  }),
+  list: publicProcedure
+    .input(z.object({ experimentId: z.string() }))
+    .query(async ({ input, ctx }) => {
+      await requireCanViewExperiment(input.experimentId, ctx);
+
+      return await prisma.testScenario.findMany({
+        where: {
+          experimentId: input.experimentId,
+          visible: true,
+        },
+        orderBy: {
+          sortIndex: "asc",
+        },
+      });
+    }),

-  create: publicProcedure
+  create: protectedProcedure
     .input(
       z.object({
         experimentId: z.string(),
         autogenerate: z.boolean().optional(),
       }),
     )
-    .mutation(async ({ input }) => {
-      const maxSortIndex =
-        (
-          await prisma.testScenario.aggregate({
-            where: {
-              experimentId: input.experimentId,
-            },
-            _max: {
-              sortIndex: true,
-            },
-          })
-        )._max.sortIndex ?? 0;
+    .mutation(async ({ input, ctx }) => {
+      await requireCanModifyExperiment(input.experimentId, ctx);
+
+      await prisma.testScenario.updateMany({
+        where: {
+          experimentId: input.experimentId,
+        },
+        data: {
+          sortIndex: {
+            increment: 1,
+          },
+        },
+      });

       const createNewScenarioAction = prisma.testScenario.create({
         data: {
           experimentId: input.experimentId,
-          sortIndex: maxSortIndex + 1,
+          sortIndex: 0,
           variableValues: input.autogenerate
             ? await autogenerateScenarioValues(input.experimentId)
             : {},
@@ -66,26 +72,33 @@ export const scenariosRouter = createTRPCRouter({
     }
   }),

-  hide: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => {
+  hide: protectedProcedure.input(z.object({ id: z.string() })).mutation(async ({ input, ctx }) => {
+    const experimentId = (
+      await prisma.testScenario.findUniqueOrThrow({
+        where: { id: input.id },
+      })
+    ).experimentId;
+    await requireCanModifyExperiment(experimentId, ctx);
+
     const hiddenScenario = await prisma.testScenario.update({
       where: { id: input.id },
       data: { visible: false, experiment: { update: { updatedAt: new Date() } } },
     });

     // Reevaluate all evaluations now that this scenario is hidden
-    await reevaluateAll(hiddenScenario.experimentId);
+    await runAllEvals(hiddenScenario.experimentId);

     return hiddenScenario;
   }),

-  reorder: publicProcedure
+  reorder: protectedProcedure
     .input(
       z.object({
         draggedId: z.string(),
         droppedId: z.string(),
       }),
     )
-    .mutation(async ({ input }) => {
+    .mutation(async ({ input, ctx }) => {
       const dragged = await prisma.testScenario.findUnique({
         where: {
           id: input.draggedId,
@@ -104,6 +117,8 @@ export const scenariosRouter = createTRPCRouter({
       );
     }

+    await requireCanModifyExperiment(dragged.experimentId, ctx);
+
     const visibleItems = await prisma.testScenario.findMany({
       where: {
         experimentId: dragged.experimentId,
@@ -147,14 +162,14 @@ export const scenariosRouter = createTRPCRouter({
     );
   }),

-  replaceWithValues: publicProcedure
+  replaceWithValues: protectedProcedure
    .input(
      z.object({
        id: z.string(),
        values: z.record(z.string()),
      }),
    )
-    .mutation(async ({ input }) => {
+    .mutation(async ({ input, ctx }) => {
      const existing = await prisma.testScenario.findUnique({
        where: {
          id: input.id,
@@ -165,6 +180,8 @@ export const scenariosRouter = createTRPCRouter({
        throw new Error(`Scenario with id ${input.id} does not exist`);
      }

+      await requireCanModifyExperiment(existing.experimentId, ctx);
+
      const newScenario = await prisma.testScenario.create({
        data: {
          experimentId: existing.experimentId,

View File

@@ -1,11 +1,14 @@
 import { z } from "zod";
-import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
+import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
 import { prisma } from "~/server/db";
+import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";

 export const templateVarsRouter = createTRPCRouter({
-  create: publicProcedure
+  create: protectedProcedure
     .input(z.object({ experimentId: z.string(), label: z.string() }))
-    .mutation(async ({ input }) => {
+    .mutation(async ({ input, ctx }) => {
+      await requireCanModifyExperiment(input.experimentId, ctx);
+
       await prisma.templateVariable.create({
         data: {
           experimentId: input.experimentId,
@@ -14,22 +17,33 @@ export const templateVarsRouter = createTRPCRouter({
     });
   }),

-  delete: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => {
-    await prisma.templateVariable.delete({ where: { id: input.id } });
-  }),
+  delete: protectedProcedure
+    .input(z.object({ id: z.string() }))
+    .mutation(async ({ input, ctx }) => {
+      const { experimentId } = await prisma.templateVariable.findUniqueOrThrow({
+        where: { id: input.id },
+      });
+      await requireCanModifyExperiment(experimentId, ctx);
+
+      await prisma.templateVariable.delete({ where: { id: input.id } });
+    }),

-  list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
-    return await prisma.templateVariable.findMany({
-      where: {
-        experimentId: input.experimentId,
-      },
-      orderBy: {
-        createdAt: "asc",
-      },
-      select: {
-        id: true,
-        label: true,
-      },
-    });
-  }),
+  list: publicProcedure
+    .input(z.object({ experimentId: z.string() }))
+    .query(async ({ input, ctx }) => {
+      await requireCanViewExperiment(input.experimentId, ctx);
+      return await prisma.templateVariable.findMany({
+        where: {
+          experimentId: input.experimentId,
+        },
+        orderBy: {
+          createdAt: "asc",
+        },
+        select: {
+          id: true,
+          label: true,
+        },
+      });
+    }),
 });

View File

@@ -27,6 +27,9 @@ type CreateContextOptions = {
   session: Session | null;
 };

+// eslint-disable-next-line @typescript-eslint/no-empty-function
+const noOp = () => {};
+
 /**
  * This helper generates the "internals" for a tRPC context. If you need to use it, you can export
  * it from here.
@@ -41,6 +44,7 @@ const createInnerTRPCContext = (opts: CreateContextOptions) => {
   return {
     session: opts.session,
     prisma,
+    markAccessControlRun: noOp,
   };
 };
@@ -69,6 +73,8 @@ export const createTRPCContext = async (opts: CreateNextContextOptions) => {
  * errors on the backend.
  */

+export type TRPCContext = Awaited<ReturnType<typeof createTRPCContext>>;
+
 const t = initTRPC.context<typeof createTRPCContext>().create({
   transformer: superjson,
   errorFormatter({ shape, error }) {
@@ -106,16 +112,29 @@ export const createTRPCRouter = t.router;
 export const publicProcedure = t.procedure;

 /** Reusable middleware that enforces users are logged in before running the procedure. */
-const enforceUserIsAuthed = t.middleware(({ ctx, next }) => {
+const enforceUserIsAuthed = t.middleware(async ({ ctx, next }) => {
   if (!ctx.session || !ctx.session.user) {
     throw new TRPCError({ code: "UNAUTHORIZED" });
   }
-  return next({
+
+  let accessControlRun = false;
+
+  const resp = await next({
     ctx: {
       // infers the `session` as non-nullable
       session: { ...ctx.session, user: ctx.session.user },
+      markAccessControlRun: () => {
+        accessControlRun = true;
+      },
     },
   });
+
+  if (!accessControlRun)
+    throw new TRPCError({
+      code: "INTERNAL_SERVER_ERROR",
+      message:
+        "Protected routes must perform access control checks then explicitly invoke the `ctx.markAccessControlRun()` function to ensure we don't forget access control on a route.",
+    });
+
+  return resp;
 });

 /**
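The markAccessControlRun handshake makes a forgotten permission check a loud 500 rather than a silent hole: the middleware only lets a response through if some access-control helper flipped the flag. A sketch of what a compatible helper could look like (the real implementations live in ~/utils/accessControl and may differ in their details):

// Hypothetical helper: verify org membership, then mark the check as run.
import { TRPCError } from "@trpc/server";
import { prisma } from "~/server/db";
import { type TRPCContext } from "~/server/api/trpc";

export async function requireCanViewExperiment(experimentId: string, ctx: TRPCContext) {
  const experiment = await prisma.experiment.findFirst({
    where: {
      id: experimentId,
      organization: {
        OrganizationUser: { some: { userId: ctx.session?.user.id ?? "" } },
      },
    },
  });
  if (!experiment) throw new TRPCError({ code: "UNAUTHORIZED" });

  // Without this call, the middleware above rejects the request after the
  // handler finishes.
  ctx.markAccessControlRun();
}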

View File

@@ -2,6 +2,8 @@ import { PrismaAdapter } from "@next-auth/prisma-adapter";
 import { type GetServerSidePropsContext } from "next";
 import { getServerSession, type NextAuthOptions, type DefaultSession } from "next-auth";
 import { prisma } from "~/server/db";
+import GitHubProvider from "next-auth/providers/github";
+import { env } from "~/env.mjs";

 /**
  * Module augmentation for `next-auth` types. Allows us to add custom properties to the `session`
@@ -41,20 +43,15 @@ export const authOptions: NextAuthOptions = {
   },
   adapter: PrismaAdapter(prisma),
   providers: [
-    // DiscordProvider({
-    //   clientId: env.DISCORD_CLIENT_ID,
-    //   clientSecret: env.DISCORD_CLIENT_SECRET,
-    // }),
-    /**
-     * ...add more providers here.
-     *
-     * Most other providers require a bit more work than the Discord provider. For example, the
-     * GitHub provider requires you to add the `refresh_token_expires_in` field to the Account
-     * model. Refer to the NextAuth.js docs for the provider you want to use. Example:
-     *
-     * @see https://next-auth.js.org/providers/github
-     */
+    GitHubProvider({
+      clientId: env.GITHUB_CLIENT_ID,
+      clientSecret: env.GITHUB_CLIENT_SECRET,
+    }),
   ],
+  theme: {
+    logo: "/logo.svg",
+    brandColor: "#ff5733",
+  },
 };

 /**

View File

@@ -8,7 +8,10 @@ const globalForPrisma = globalThis as unknown as {
 export const prisma =
   globalForPrisma.prisma ??
   new PrismaClient({
-    log: env.NODE_ENV === "development" ? ["query", "error", "warn"] : ["error"],
+    log:
+      env.NODE_ENV === "development" && !env.RESTRICT_PRISMA_LOGS
+        ? ["query", "error", "warn"]
+        : ["error"],
   });

 if (env.NODE_ENV !== "production") globalForPrisma.prisma = prisma;

View File

@@ -0,0 +1,45 @@
import "dotenv/config";
import dedent from "dedent";
import { expect, test } from "vitest";
import { migrate1to2 } from "./migrateConstructFns";
test("migrate1to2", () => {
const constructFn = dedent`
// Test comment
prompt = {
model: "gpt-3.5-turbo-0613",
messages: [
{
role: "user",
content: "What is the capital of China?"
}
]
}
`;
const migrated = migrate1to2(constructFn);
expect(migrated).toBe(dedent`
// Test comment
definePrompt("openai/ChatCompletion", {
model: "gpt-3.5-turbo-0613",
messages: [
{
role: "user",
content: "What is the capital of China?"
}
]
})
`);
// console.log(
// migrateConstructFn(dedent`definePrompt(
// "openai/ChatCompletion",
// {
// model: 'gpt-3.5-turbo-0613',
// messages: []
// }
// )`),
// );
});

View File

@@ -0,0 +1,58 @@
import * as recast from "recast";
import { type ASTNode } from "ast-types";
import { prisma } from "../db";
import { fileURLToPath } from "url";
const { builders: b } = recast.types;
export const migrate1to2 = (fnBody: string): string => {
const ast: ASTNode = recast.parse(fnBody);
recast.visit(ast, {
visitAssignmentExpression(path) {
const node = path.node;
if ("name" in node.left && node.left.name === "prompt") {
const functionCall = b.callExpression(b.identifier("definePrompt"), [
b.literal("openai/ChatCompletion"),
node.right,
]);
path.replace(functionCall);
}
return false;
},
});
return recast.print(ast).code;
};
export default async function migrateConstructFns() {
const v1Prompts = await prisma.promptVariant.findMany({
where: {
constructFnVersion: 1,
},
});
console.log(`Migrating ${v1Prompts.length} prompts 1->2`);
await Promise.all(
v1Prompts.map(async (variant) => {
try {
await prisma.promptVariant.update({
where: {
id: variant.id,
},
data: {
constructFn: migrate1to2(variant.constructFn),
constructFnVersion: 2,
},
});
} catch (e) {
console.error("Error migrating constructFn for variant", variant.id, e);
}
}),
);
}
// If we're running this file directly, run the migration
if (process.argv.at(-1) === fileURLToPath(import.meta.url)) {
console.log("Running migration");
await migrateConstructFns();
console.log("Done");
}
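Because the visitor swaps out only the prompt = ... assignment node, recast re-prints just that subtree and preserves the rest of the function body byte-for-byte, comments included. For instance (illustrative input and output):

migrate1to2(`prompt = { model: "gpt-4" }`);
// => 'definePrompt("openai/ChatCompletion", { model: "gpt-4" })'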

View File

@@ -1,47 +0,0 @@
import { type Prisma } from "@prisma/client";
import { prisma } from "../db";
async function migrateScenarioVariantOutputData() {
// Get all ScenarioVariantCells
const cells = await prisma.scenarioVariantCell.findMany({ include: { modelOutput: true } });
console.log(`Found ${cells.length} records`);
let updatedCount = 0;
// Loop through all scenarioVariants
for (const cell of cells) {
// Create a new ModelOutput for each ScenarioVariant with an existing output
if (cell.output && !cell.modelOutput) {
updatedCount++;
await prisma.modelOutput.create({
data: {
scenarioVariantCellId: cell.id,
inputHash: cell.inputHash || "",
output: cell.output as Prisma.InputJsonValue,
timeToComplete: cell.timeToComplete ?? undefined,
promptTokens: cell.promptTokens,
completionTokens: cell.completionTokens,
createdAt: cell.createdAt,
updatedAt: cell.updatedAt,
},
});
} else if (cell.errorMessage && cell.retrievalStatus === "COMPLETE") {
updatedCount++;
await prisma.scenarioVariantCell.update({
where: { id: cell.id },
data: {
retrievalStatus: "ERROR",
},
});
}
}
console.log("Data migration completed");
console.log(`Updated ${updatedCount} records`);
}
// Execute the function
migrateScenarioVariantOutputData().catch((error) => {
console.error("An error occurred while migrating data: ", error);
process.exit(1);
});

View File

@@ -0,0 +1,19 @@
import "dotenv/config";
import { openai } from "../utils/openai";
const resp = await openai.chat.completions.create({
model: "gpt-3.5-turbo-0613",
stream: true,
messages: [
{
role: "user",
content: "count to 20",
},
],
});
for await (const part of resp) {
console.log("part", part);
}
console.log("final resp", resp);

View File

@@ -0,0 +1,26 @@
/* eslint-disable */
import "dotenv/config";
import Replicate from "replicate";
const replicate = new Replicate({
auth: process.env.REPLICATE_API_TOKEN || "",
});
console.log("going to run");
const prediction = await replicate.predictions.create({
version: "3725a659b5afff1a0ba9bead5fac3899d998feaad00e07032ca2b0e35eb14f8a",
input: {
prompt: "...",
},
});
console.log("waiting");
setInterval(() => {
replicate.predictions.get(prediction.id).then((prediction) => {
console.log(prediction);
});
}, 500);
// const output = await replicate.wait(prediction, {});
// console.log(output);

View File

@@ -1,14 +1,18 @@
import crypto from "crypto";
import { prisma } from "~/server/db"; import { prisma } from "~/server/db";
import defineTask from "./defineTask"; import defineTask from "./defineTask";
import { type CompletionResponse, getCompletion } from "../utils/getCompletion";
import { type JSONSerializable } from "../types";
import { sleep } from "../utils/sleep"; import { sleep } from "../utils/sleep";
import { shouldStream } from "../utils/shouldStream";
import { generateChannel } from "~/utils/generateChannel"; import { generateChannel } from "~/utils/generateChannel";
import { reevaluateVariant } from "../utils/evaluations"; import { runEvalsForOutput } from "../utils/evaluations";
-import { constructPrompt } from "../utils/constructPrompt";
-import { type CompletionCreateParams } from "openai/resources/chat";
+import { type Prisma } from "@prisma/client";
+import parseConstructFn from "../utils/parseConstructFn";
+import hashPrompt from "../utils/hashPrompt";
+import { type JsonObject } from "type-fest";
+import modelProviders from "~/modelProviders/modelProviders";
+import { wsConnection } from "~/utils/wsConnection";
+
+export type queryLLMJob = {
+  scenarioVariantCellId: string;
+};
 
 const MAX_AUTO_RETRIES = 10;
 const MIN_DELAY = 500; // milliseconds
@@ -20,38 +24,6 @@ function calculateDelay(numPreviousTries: number): number {
   return baseDelay + jitter;
 }
 
-const getCompletionWithRetries = async (
-  cellId: string,
-  payload: JSONSerializable,
-  channel?: string,
-): Promise<CompletionResponse> => {
-  for (let i = 0; i < MAX_AUTO_RETRIES; i++) {
-    const modelResponse = await getCompletion(
-      payload as unknown as CompletionCreateParams,
-      channel,
-    );
-    if (modelResponse.statusCode !== 429 || i === MAX_AUTO_RETRIES - 1) {
-      return modelResponse;
-    }
-    const delay = calculateDelay(i);
-    await prisma.scenarioVariantCell.update({
-      where: { id: cellId },
-      data: {
-        errorMessage: "Rate limit exceeded",
-        statusCode: 429,
-        retryTime: new Date(Date.now() + delay),
-      },
-    });
-    // TODO: Maybe requeue the job so other jobs can run in the future?
-    await sleep(delay);
-  }
-  throw new Error("Max retries limit reached");
-};
-
-export type queryLLMJob = {
-  scenarioVariantCellId: string;
-};
 
 export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
   const { scenarioVariantCellId } = task;
   const cell = await prisma.scenarioVariantCell.findUnique({
@@ -59,6 +31,14 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
     include: { modelOutput: true },
   });
   if (!cell) {
+    await prisma.scenarioVariantCell.update({
+      where: { id: scenarioVariantCellId },
+      data: {
+        statusCode: 404,
+        errorMessage: "Cell not found",
+        retrievalStatus: "ERROR",
+      },
+    });
     return;
   }
@@ -77,6 +57,14 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
     where: { id: cell.promptVariantId },
   });
   if (!variant) {
+    await prisma.scenarioVariantCell.update({
+      where: { id: scenarioVariantCellId },
+      data: {
+        statusCode: 404,
+        errorMessage: "Prompt Variant not found",
+        retrievalStatus: "ERROR",
+      },
+    });
     return;
   }
@@ -84,61 +72,94 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
     where: { id: cell.testScenarioId },
   });
   if (!scenario) {
+    await prisma.scenarioVariantCell.update({
+      where: { id: scenarioVariantCellId },
+      data: {
+        statusCode: 404,
+        errorMessage: "Scenario not found",
+        retrievalStatus: "ERROR",
+      },
+    });
     return;
   }
 
-  const prompt = await constructPrompt(variant, scenario.variableValues);
-
-  const streamingEnabled = shouldStream(prompt);
-  let streamingChannel;
-  if (streamingEnabled) {
-    streamingChannel = generateChannel();
-    // Save streaming channel so that UI can connect to it
-    await prisma.scenarioVariantCell.update({
-      where: { id: scenarioVariantCellId },
-      data: {
-        streamingChannel,
-      },
-    });
-  }
-
-  const modelResponse = await getCompletionWithRetries(
-    scenarioVariantCellId,
-    prompt,
-    streamingChannel,
-  );
-
-  let modelOutput = null;
-  if (modelResponse.statusCode === 200) {
-    const inputHash = crypto.createHash("sha256").update(JSON.stringify(prompt)).digest("hex");
-    modelOutput = await prisma.modelOutput.create({
-      data: {
-        scenarioVariantCellId,
-        inputHash,
-        output: modelResponse.output,
-        timeToComplete: modelResponse.timeToComplete,
-        promptTokens: modelResponse.promptTokens,
-        completionTokens: modelResponse.completionTokens,
-      },
-    });
-  }
-
-  await prisma.scenarioVariantCell.update({
-    where: { id: scenarioVariantCellId },
-    data: {
-      statusCode: modelResponse.statusCode,
-      errorMessage: modelResponse.errorMessage,
-      streamingChannel: null,
-      retrievalStatus: modelOutput ? "COMPLETE" : "ERROR",
-      modelOutput: {
-        connect: {
-          id: modelOutput?.id,
-        },
-      },
-    },
-  });
-
-  await reevaluateVariant(cell.promptVariantId);
+  const prompt = await parseConstructFn(variant.constructFn, scenario.variableValues as JsonObject);
+
+  if ("error" in prompt) {
+    await prisma.scenarioVariantCell.update({
+      where: { id: scenarioVariantCellId },
+      data: {
+        statusCode: 400,
+        errorMessage: prompt.error,
+        retrievalStatus: "ERROR",
+      },
+    });
+    return;
+  }
+
+  const provider = modelProviders[prompt.modelProvider];
+
+  const streamingChannel = provider.shouldStream(prompt.modelInput) ? generateChannel() : null;
+  if (streamingChannel) {
+    // Save streaming channel so that UI can connect to it
+    await prisma.scenarioVariantCell.update({
+      where: { id: scenarioVariantCellId },
+      data: { streamingChannel },
+    });
+  }
+
+  const onStream = streamingChannel
+    ? (partialOutput: (typeof provider)["_outputSchema"]) => {
+        wsConnection.emit("message", { channel: streamingChannel, payload: partialOutput });
+      }
+    : null;
+
+  for (let i = 0; true; i++) {
+    const response = await provider.getCompletion(prompt.modelInput, onStream);
+    if (response.type === "success") {
+      const inputHash = hashPrompt(prompt);
+      const modelOutput = await prisma.modelOutput.create({
+        data: {
+          scenarioVariantCellId,
+          inputHash,
+          output: response.value as Prisma.InputJsonObject,
+          timeToComplete: response.timeToComplete,
+          promptTokens: response.promptTokens,
+          completionTokens: response.completionTokens,
+          cost: response.cost,
+        },
+      });
+      await prisma.scenarioVariantCell.update({
+        where: { id: scenarioVariantCellId },
+        data: {
+          statusCode: response.statusCode,
+          retrievalStatus: "COMPLETE",
+        },
+      });
+      await runEvalsForOutput(variant.experimentId, scenario, modelOutput);
+      break;
+    } else {
+      const shouldRetry = response.autoRetry && i < MAX_AUTO_RETRIES;
+      const delay = calculateDelay(i);
+      await prisma.scenarioVariantCell.update({
+        where: { id: scenarioVariantCellId },
+        data: {
+          errorMessage: response.message,
+          statusCode: response.statusCode,
+          retryTime: shouldRetry ? new Date(Date.now() + delay) : null,
+          retrievalStatus: "ERROR",
+        },
+      });
+      if (shouldRetry) {
+        await sleep(delay);
+      } else {
+        break;
+      }
+    }
+  }
 });
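
Note: the hunk above elides the body of `calculateDelay`; only its final `return baseDelay + jitter;` survives as context. A minimal sketch of what such a helper could look like, assuming standard exponential backoff seeded by the `MIN_DELAY` constant (the actual exponent and jitter range are not shown in this diff):

// Hypothetical reconstruction -- not the committed code.
function calculateDelay(numPreviousTries: number): number {
  const baseDelay = MIN_DELAY * Math.pow(2, numPreviousTries); // assumed doubling per retry
  const jitter = Math.random() * baseDelay; // assumed jitter of up to one full baseDelay
  return baseDelay + jitter;
}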

View File

@@ -1,23 +0,0 @@
-export type JSONSerializable =
-  | string
-  | number
-  | boolean
-  | null
-  | JSONSerializable[]
-  | { [key: string]: JSONSerializable };
-
-// Placeholder for now
-export type OpenAIChatConfig = NonNullable<JSONSerializable>;
-
-export enum OpenAIChatModel {
-  "gpt-4" = "gpt-4",
-  "gpt-4-0613" = "gpt-4-0613",
-  "gpt-4-32k" = "gpt-4-32k",
-  "gpt-4-32k-0613" = "gpt-4-32k-0613",
-  "gpt-3.5-turbo" = "gpt-3.5-turbo",
-  "gpt-3.5-turbo-0613" = "gpt-3.5-turbo-0613",
-  "gpt-3.5-turbo-16k" = "gpt-3.5-turbo-16k",
-  "gpt-3.5-turbo-16k-0613" = "gpt-3.5-turbo-16k-0613",
-}
-
-export type SupportedModel = keyof typeof OpenAIChatModel;

View File

@@ -1,15 +0,0 @@
-import { test } from "vitest";
-
-import { constructPrompt } from "./constructPrompt";
-
-test.skip("constructPrompt", async () => {
-  const constructed = await constructPrompt(
-    {
-      constructFn: `prompt = { "fooz": "bar" }`,
-    },
-    {
-      foo: "bar",
-    },
-  );
-  console.log(constructed);
-});

View File

@@ -1,35 +0,0 @@
-import { type PromptVariant, type TestScenario } from "@prisma/client";
-import ivm from "isolated-vm";
-
-import { type JSONSerializable } from "../types";
-
-const isolate = new ivm.Isolate({ memoryLimit: 128 });
-
-export async function constructPrompt(
-  variant: Pick<PromptVariant, "constructFn">,
-  scenario: TestScenario["variableValues"],
-): Promise<JSONSerializable> {
-  const code = `
-    const scenario = ${JSON.stringify(scenario ?? {}, null, 2)};
-
-    let prompt
-
-    ${variant.constructFn}
-
-    global.prompt = prompt;
-  `;
-
-  console.log("code is", code);
-
-  const context = await isolate.createContext();
-
-  const jail = context.global;
-  await jail.set("global", jail.derefInto());
-
-  const script = await isolate.compileScript(code);
-  await script.run(context);
-
-  const promptReference = (await context.global.get("prompt")) as ivm.Reference;
-  const prompt = await promptReference.copy(); // Get the actual value from the isolate
-
-  return prompt as JSONSerializable;
-}

View File

@@ -0,0 +1,136 @@
+import { type PromptVariant } from "@prisma/client";
+import ivm from "isolated-vm";
+import dedent from "dedent";
+
+import { openai } from "./openai";
+import { isObject } from "lodash-es";
+import { type CompletionCreateParams } from "openai/resources/chat/completions";
+import formatPromptConstructor from "~/utils/formatPromptConstructor";
+import { type SupportedProvider, type Model } from "~/modelProviders/types";
+import modelProviders from "~/modelProviders/modelProviders";
+
+const isolate = new ivm.Isolate({ memoryLimit: 128 });
+
+export async function deriveNewConstructFn(
+  originalVariant: PromptVariant | null,
+  newModel?: Model,
+  instructions?: string,
+) {
+  if (originalVariant && !newModel && !instructions) {
+    return originalVariant.constructFn;
+  }
+  if (originalVariant && (newModel || instructions)) {
+    return await requestUpdatedPromptFunction(originalVariant, newModel, instructions);
+  }
+  return dedent`
+    prompt = {
+      model: "gpt-3.5-turbo",
+      messages: [
+        {
+          role: "system",
+          content: "Return 'Hello, world!'",
+        }
+      ]
+    }`;
+}
+
+const NUM_RETRIES = 5;
+
+const requestUpdatedPromptFunction = async (
+  originalVariant: PromptVariant,
+  newModel?: Model,
+  instructions?: string,
+) => {
+  const originalModelProvider = modelProviders[originalVariant.modelProvider as SupportedProvider];
+  const originalModel = originalModelProvider.models[originalVariant.model] as Model;
+  let newConstructionFn = "";
+  for (let i = 0; i < NUM_RETRIES; i++) {
+    try {
+      const messages: CompletionCreateParams.CreateChatCompletionRequestNonStreaming.Message[] = [
+        {
+          role: "system",
+          content: `Your job is to update prompt constructor functions. Here is the api shape for the current model:\n---\n${JSON.stringify(
+            originalModelProvider.inputSchema,
+            null,
+            2,
+          )}\n\nDo not add any assistant messages.`,
+        },
+        {
+          role: "user",
+          content: `This is the current prompt constructor function:\n---\n${originalVariant.constructFn}`,
+        },
+      ];
+      if (newModel) {
+        messages.push({
+          role: "user",
+          content: `Return the prompt constructor function for ${newModel.name} given the existing prompt constructor function for ${originalModel.name}`,
+        });
+        if (newModel.provider !== originalModel.provider) {
+          messages.push({
+            role: "user",
+            content: `The old provider was ${originalModel.provider}. The new provider is ${
+              newModel.provider
+            }. Here is the schema for the new model:\n---\n${JSON.stringify(
+              modelProviders[newModel.provider].inputSchema,
+              null,
+              2,
+            )}`,
+          });
+        }
+      }
+      if (instructions) {
+        messages.push({
+          role: "user",
+          content: instructions,
+        });
+      }
+      const completion = await openai.chat.completions.create({
+        model: "gpt-4",
+        messages,
+        functions: [
+          {
+            name: "update_prompt_constructor_function",
+            parameters: {
+              type: "object",
+              properties: {
+                new_prompt_function: {
+                  type: "string",
+                  description: "The new prompt function, runnable in typescript",
+                },
+              },
+            },
+          },
+        ],
+        function_call: {
+          name: "update_prompt_constructor_function",
+        },
+      });
+
+      const argString = completion.choices[0]?.message?.function_call?.arguments || "{}";
+
+      const code = `
+        global.constructPromptFunctionArgs = ${argString};
+      `;
+
+      const context = await isolate.createContext();
+      const jail = context.global;
+      await jail.set("global", jail.derefInto());
+
+      const script = await isolate.compileScript(code);
+      await script.run(context);
+      const constructPromptFunctionArgs = (await context.global.get(
+        "constructPromptFunctionArgs",
+      )) as ivm.Reference;
+
+      const args = await constructPromptFunctionArgs.copy(); // Get the actual value from the isolate
+
+      if (args && isObject(args) && "new_prompt_function" in args) {
+        newConstructionFn = await formatPromptConstructor(args.new_prompt_function as string);
+        break;
+      }
+    } catch (e) {
+      console.error(e);
+    }
+  }
+
+  return newConstructionFn;
+};
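
Note: a hypothetical call site for `deriveNewConstructFn`, to show how the three arguments interact (the variant value and instruction string here are invented; only the signature comes from the diff above):

// Rewrite an existing variant's constructor for a new model, per user instructions.
const updatedConstructFn = await deriveNewConstructFn(
  existingVariant, // PromptVariant row; its constructFn, model, and modelProvider are read
  targetModel, // optional Model -- omit to keep the current model
  "Keep the same messages but ask for JSON output", // optional free-form instructions
);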

View File

@@ -1,31 +0,0 @@
-import { type Evaluation, type ModelOutput, type TestScenario } from "@prisma/client";
-import { type ChatCompletion } from "openai/resources/chat";
-
-import { type VariableMap, fillTemplate } from "./fillTemplate";
-
-export const evaluateOutput = (
-  modelOutput: ModelOutput,
-  scenario: TestScenario,
-  evaluation: Evaluation,
-): boolean => {
-  const output = modelOutput.output as unknown as ChatCompletion;
-  const message = output?.choices?.[0]?.message;
-
-  if (!message) return false;
-
-  const stringifiedMessage = message.content ?? JSON.stringify(message.function_call);
-
-  const matchRegex = fillTemplate(evaluation.matchString, scenario.variableValues as VariableMap);
-
-  let match;
-  switch (evaluation.matchType) {
-    case "CONTAINS":
-      match = stringifiedMessage.match(matchRegex) !== null;
-      break;
-    case "DOES_NOT_CONTAIN":
-      match = stringifiedMessage.match(matchRegex) === null;
-      break;
-  }
-
-  return match;
-};

View File

@@ -1,105 +1,79 @@
 import { type ModelOutput, type Evaluation } from "@prisma/client";
 import { prisma } from "../db";
-import { evaluateOutput } from "./evaluateOutput";
+import { runOneEval } from "./runOneEval";
+import { type Scenario } from "~/components/OutputsTable/types";
 
-export const reevaluateVariant = async (variantId: string) => {
-  const variant = await prisma.promptVariant.findUnique({
-    where: { id: variantId },
-  });
-  if (!variant) return;
-
-  const evaluations = await prisma.evaluation.findMany({
-    where: { experimentId: variant.experimentId },
-  });
-
-  const cells = await prisma.scenarioVariantCell.findMany({
-    where: {
-      promptVariantId: variantId,
-      retrievalStatus: "COMPLETE",
-      testScenario: { visible: true },
-      modelOutput: { isNot: null },
-    },
-    include: { testScenario: true, modelOutput: true },
-  });
-
-  await Promise.all(
-    evaluations.map(async (evaluation) => {
-      const passCount = cells.filter((cell) =>
-        evaluateOutput(cell.modelOutput as ModelOutput, cell.testScenario, evaluation),
-      ).length;
-      const failCount = cells.length - passCount;
-      await prisma.evaluationResult.upsert({
-        where: {
-          evaluationId_promptVariantId: {
-            evaluationId: evaluation.id,
-            promptVariantId: variantId,
-          },
-        },
-        create: {
-          evaluationId: evaluation.id,
-          promptVariantId: variantId,
-          passCount,
-          failCount,
-        },
-        update: {
-          passCount,
-          failCount,
-        },
-      });
-    }),
-  );
-};
-
-export const reevaluateEvaluation = async (evaluation: Evaluation) => {
-  const variants = await prisma.promptVariant.findMany({
-    where: { experimentId: evaluation.experimentId, visible: true },
-  });
-
-  const cells = await prisma.scenarioVariantCell.findMany({
-    where: {
-      promptVariantId: { in: variants.map((v) => v.id) },
-      testScenario: { visible: true },
-      statusCode: { notIn: [429] },
-      modelOutput: { isNot: null },
-    },
-    include: { testScenario: true, modelOutput: true },
-  });
-
-  await Promise.all(
-    variants.map(async (variant) => {
-      const variantCells = cells.filter((cell) => cell.promptVariantId === variant.id);
-      const passCount = variantCells.filter((cell) =>
-        evaluateOutput(cell.modelOutput as ModelOutput, cell.testScenario, evaluation),
-      ).length;
-      const failCount = variantCells.length - passCount;
-      await prisma.evaluationResult.upsert({
-        where: {
-          evaluationId_promptVariantId: {
-            evaluationId: evaluation.id,
-            promptVariantId: variant.id,
-          },
-        },
-        create: {
-          evaluationId: evaluation.id,
-          promptVariantId: variant.id,
-          passCount,
-          failCount,
-        },
-        update: {
-          passCount,
-          failCount,
-        },
-      });
-    }),
-  );
-};
-
-export const reevaluateAll = async (experimentId: string) => {
-  const evaluations = await prisma.evaluation.findMany({
-    where: { experimentId },
-  });
-  await Promise.all(evaluations.map(reevaluateEvaluation));
-};
+const saveResult = async (evaluation: Evaluation, scenario: Scenario, modelOutput: ModelOutput) => {
+  const result = await runOneEval(evaluation, scenario, modelOutput);
+  return await prisma.outputEvaluation.upsert({
+    where: {
+      modelOutputId_evaluationId: {
+        modelOutputId: modelOutput.id,
+        evaluationId: evaluation.id,
+      },
+    },
+    create: {
+      modelOutputId: modelOutput.id,
+      evaluationId: evaluation.id,
+      ...result,
+    },
+    update: {
+      ...result,
+    },
+  });
+};
+
+export const runEvalsForOutput = async (
+  experimentId: string,
+  scenario: Scenario,
+  modelOutput: ModelOutput,
+) => {
+  const evaluations = await prisma.evaluation.findMany({
+    where: { experimentId },
+  });
+  await Promise.all(
+    evaluations.map(async (evaluation) => await saveResult(evaluation, scenario, modelOutput)),
+  );
+};
+
+export const runAllEvals = async (experimentId: string) => {
+  const outputs = await prisma.modelOutput.findMany({
+    where: {
+      scenarioVariantCell: {
+        promptVariant: {
+          experimentId,
+          visible: true,
+        },
+        testScenario: {
+          visible: true,
+        },
+      },
+    },
+    include: {
+      scenarioVariantCell: {
+        include: {
+          testScenario: true,
+        },
+      },
+      outputEvaluation: true,
+    },
+  });
+  const evals = await prisma.evaluation.findMany({
+    where: { experimentId },
+  });
+  await Promise.all(
+    outputs.map(async (output) => {
+      const unrunEvals = evals.filter(
+        (evaluation) => !output.outputEvaluation.find((e) => e.evaluationId === evaluation.id),
+      );
+      await Promise.all(
+        unrunEvals.map(async (evaluation) => {
+          await saveResult(evaluation, output.scenarioVariantCell.testScenario, output);
+        }),
+      );
+    }),
+  );
+};
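
Note: `runEvalsForOutput` is the per-output entry point the queryLLM task now calls after persisting a successful completion, as in the first file of this diff:

await runEvalsForOutput(variant.experimentId, scenario, modelOutput);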

View File

@@ -1,28 +1,15 @@
-import { type JSONSerializable } from "../types";
-
 export type VariableMap = Record<string, string>;
 
+// Escape quotes to match the way we encode JSON
+export function escapeQuotes(str: string) {
+  return str.replace(/(\\")|"/g, (match, p1) => (p1 ? match : '\\"'));
+}
+
+// Escape regex special characters
+export function escapeRegExp(str: string) {
+  return str.replace(/[.*+\-?^${}()|[\]\\]/g, "\\$&"); // $& means the whole matched string
+}
+
 export function fillTemplate(template: string, variables: VariableMap): string {
   return template.replace(/{{\s*(\w+)\s*}}/g, (_, key: string) => variables[key] || "");
 }
-
-export function fillTemplateJson<T extends JSONSerializable>(
-  template: T,
-  variables: VariableMap,
-): T {
-  if (typeof template === "string") {
-    return fillTemplate(template, variables) as T;
-  } else if (Array.isArray(template)) {
-    return template.map((item) => fillTemplateJson(item, variables)) as T;
-  } else if (typeof template === "object" && template !== null) {
-    return Object.keys(template).reduce(
-      (acc, key) => {
-        acc[key] = fillTemplateJson(template[key] as JSONSerializable, variables);
-        return acc;
-      },
-      {} as { [key: string]: JSONSerializable } & T,
-    );
-  } else {
-    return template;
-  }
-}
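
Note: for reference, the surviving `fillTemplate` helper substitutes `{{ variable }}` placeholders and maps unknown keys to the empty string:

fillTemplate("Hello, {{ name }}!", { name: "world" }); // => "Hello, world!"
fillTemplate("Hi {{ missing }}!", {}); // => "Hi !"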

View File

@@ -1,10 +1,12 @@
-import crypto from "crypto";
 import { type Prisma } from "@prisma/client";
 import { prisma } from "../db";
 import { queueLLMRetrievalTask } from "./queueLLMRetrievalTask";
-import { constructPrompt } from "./constructPrompt";
+import parseConstructFn from "./parseConstructFn";
+import { type JsonObject } from "type-fest";
+import hashPrompt from "./hashPrompt";
+import { omit } from "lodash-es";
 
-export const generateNewCell = async (variantId: string, scenarioId: string) => {
+export const generateNewCell = async (variantId: string, scenarioId: string): Promise<void> => {
   const variant = await prisma.promptVariant.findUnique({
     where: {
       id: variantId,
@@ -17,11 +19,7 @@ export const generateNewCell = async (variantId: string, scenarioId: string) =>
     },
   });
 
-  if (!variant || !scenario) return null;
+  if (!variant || !scenario) return;
 
-  const prompt = await constructPrompt(variant, scenario.variableValues);
-  const inputHash = crypto.createHash("sha256").update(JSON.stringify(prompt)).digest("hex");
-
   let cell = await prisma.scenarioVariantCell.findUnique({
     where: {
@@ -35,12 +33,34 @@ export const generateNewCell = async (variantId: string, scenarioId: string) =>
     },
   });
 
-  if (cell) return cell;
+  if (cell) return;
+
+  const parsedConstructFn = await parseConstructFn(
+    variant.constructFn,
+    scenario.variableValues as JsonObject,
+  );
+
+  if ("error" in parsedConstructFn) {
+    await prisma.scenarioVariantCell.create({
+      data: {
+        promptVariantId: variantId,
+        testScenarioId: scenarioId,
+        statusCode: 400,
+        errorMessage: parsedConstructFn.error,
+        retrievalStatus: "ERROR",
+      },
+    });
+    return;
+  }
+
+  const inputHash = hashPrompt(parsedConstructFn);
 
   cell = await prisma.scenarioVariantCell.create({
     data: {
       promptVariantId: variantId,
       testScenarioId: scenarioId,
+      prompt: parsedConstructFn.modelInput as unknown as Prisma.InputJsonValue,
+      retrievalStatus: "PENDING",
     },
     include: {
       modelOutput: true,
@@ -48,29 +68,36 @@ export const generateNewCell = async (variantId: string, scenarioId: string) =>
   });
 
   const matchingModelOutput = await prisma.modelOutput.findFirst({
-    where: {
-      inputHash,
-    },
+    where: { inputHash },
   });
 
-  let newModelOutput;
   if (matchingModelOutput) {
-    newModelOutput = await prisma.modelOutput.create({
+    const newModelOutput = await prisma.modelOutput.create({
       data: {
+        ...omit(matchingModelOutput, ["id"]),
         scenarioVariantCellId: cell.id,
-        inputHash,
         output: matchingModelOutput.output as Prisma.InputJsonValue,
-        timeToComplete: matchingModelOutput.timeToComplete,
-        promptTokens: matchingModelOutput.promptTokens,
-        completionTokens: matchingModelOutput.completionTokens,
-        createdAt: matchingModelOutput.createdAt,
-        updatedAt: matchingModelOutput.updatedAt,
       },
     });
+    await prisma.scenarioVariantCell.update({
+      where: { id: cell.id },
+      data: { retrievalStatus: "COMPLETE" },
+    });
+    // Copy over all eval results as well
+    await Promise.all(
+      (
+        await prisma.outputEvaluation.findMany({ where: { modelOutputId: matchingModelOutput.id } })
+      ).map(async (evaluation) => {
+        await prisma.outputEvaluation.create({
+          data: {
+            ...omit(evaluation, ["id"]),
+            modelOutputId: newModelOutput.id,
+          },
+        });
+      }),
+    );
   } else {
     cell = await queueLLMRetrievalTask(cell.id);
   }
-
-  return { ...cell, modelOutput: newModelOutput };
 };

View File

@@ -1,99 +0,0 @@
-/* eslint-disable @typescript-eslint/no-unsafe-call */
-import { isObject } from "lodash";
-import { Prisma } from "@prisma/client";
-import { streamChatCompletion } from "./openai";
-import { wsConnection } from "~/utils/wsConnection";
-import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat";
-import { type OpenAIChatModel } from "../types";
-import { env } from "~/env.mjs";
-import { countOpenAIChatTokens } from "~/utils/countTokens";
-import { rateLimitErrorMessage } from "~/sharedStrings";
-
-export type CompletionResponse = {
-  output: Prisma.InputJsonValue | typeof Prisma.JsonNull;
-  statusCode: number;
-  errorMessage: string | null;
-  timeToComplete: number;
-  promptTokens?: number;
-  completionTokens?: number;
-};
-
-export async function getCompletion(
-  payload: CompletionCreateParams,
-  channel?: string,
-): Promise<CompletionResponse> {
-  // If functions are enabled, disable streaming so that we get the full response with token counts
-  if (payload.functions?.length) payload.stream = false;
-
-  const start = Date.now();
-  const response = await fetch("https://api.openai.com/v1/chat/completions", {
-    method: "POST",
-    headers: {
-      "Content-Type": "application/json",
-      Authorization: `Bearer ${env.OPENAI_API_KEY}`,
-    },
-    body: JSON.stringify(payload),
-  });
-
-  const resp: CompletionResponse = {
-    output: Prisma.JsonNull,
-    errorMessage: null,
-    statusCode: response.status,
-    timeToComplete: 0,
-  };
-
-  try {
-    if (payload.stream) {
-      const completion = streamChatCompletion(payload as unknown as CompletionCreateParams);
-      let finalOutput: ChatCompletion | null = null;
-      await (async () => {
-        for await (const partialCompletion of completion) {
-          finalOutput = partialCompletion;
-          wsConnection.emit("message", { channel, payload: partialCompletion });
-        }
-      })().catch((err) => console.error(err));
-      if (finalOutput) {
-        resp.output = finalOutput as unknown as Prisma.InputJsonValue;
-        resp.timeToComplete = Date.now() - start;
-      }
-    } else {
-      resp.timeToComplete = Date.now() - start;
-      resp.output = await response.json();
-    }
-
-    if (!response.ok) {
-      if (response.status === 429) {
-        resp.errorMessage = rateLimitErrorMessage;
-      } else if (
-        isObject(resp.output) &&
-        "error" in resp.output &&
-        isObject(resp.output.error) &&
-        "message" in resp.output.error
-      ) {
-        // If it's an object, try to get the error message
-        resp.errorMessage = resp.output.error.message?.toString() ?? "Unknown error";
-      }
-    }
-
-    if (isObject(resp.output) && "usage" in resp.output) {
-      const usage = resp.output.usage as unknown as ChatCompletion.Usage;
-      resp.promptTokens = usage.prompt_tokens;
-      resp.completionTokens = usage.completion_tokens;
-    } else if (isObject(resp.output) && "choices" in resp.output) {
-      const model = payload.model as unknown as OpenAIChatModel;
-      resp.promptTokens = countOpenAIChatTokens(model, payload.messages);
-      const choices = resp.output.choices as unknown as ChatCompletion.Choice[];
-      const message = choices[0]?.message;
-      if (message) {
-        const messages = [message];
-        resp.completionTokens = countOpenAIChatTokens(model, messages);
-      }
-    }
-  } catch (e) {
-    console.error(e);
-    if (response.ok) {
-      resp.errorMessage = "Failed to parse response";
-    }
-  }
-
-  return resp;
-}

View File

@@ -0,0 +1,37 @@
+import crypto from "crypto";
+import { type JsonValue } from "type-fest";
+
+import { type ParsedConstructFn } from "./parseConstructFn";
+
+function sortKeys(obj: JsonValue): JsonValue {
+  if (typeof obj !== "object" || obj === null) {
+    // Not an object or array, return as is
+    return obj;
+  }
+
+  if (Array.isArray(obj)) {
+    return obj.map(sortKeys);
+  }
+
+  // Get keys and sort them
+  const keys = Object.keys(obj).sort();
+
+  const sortedObj = {};
+  for (const key of keys) {
+    // @ts-expect-error not worth fixing types
+    // eslint-disable-next-line @typescript-eslint/no-unsafe-argument
+    sortedObj[key] = sortKeys(obj[key]);
+  }
+
+  return sortedObj;
+}
+
+export default function hashPrompt(prompt: ParsedConstructFn<any>): string {
+  // Sort object keys recursively
+  const sortedObj = sortKeys(prompt as unknown as JsonValue);
+
+  // Convert to JSON and hash it
+  const str = JSON.stringify(sortedObj);
+  const hash = crypto.createHash("sha256");
+  hash.update(str);
+  return hash.digest("hex");
+}
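
Note: because `sortKeys` canonicalizes key order before hashing, two logically identical prompts always hash the same, which is what makes the `inputHash` cache lookup in `generateNewCell` reliable. A hypothetical check (the field names follow `ParsedConstructFn` as used in queryLLM above; the values are invented):

const a = hashPrompt({ modelProvider: "openai/ChatCompletion", modelInput: { model: "gpt-4", n: 1 } });
const b = hashPrompt({ modelInput: { n: 1, model: "gpt-4" }, modelProvider: "openai/ChatCompletion" });
a === b; // => true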

Some files were not shown because too many files have changed in this diff.