Compare commits
91 commits: `prompt-tem…` → `paginated-…`
| Author | SHA1 | Date |
|---|---|---|
| | 2e395e4d39 | |
| | 4b06d05908 | |
| | aabf355b81 | |
| | 61e5f0775d | |
| | cc1d1178da | |
| | 7466db63df | |
| | 79a0b03bf8 | |
| | 6fb7a82d72 | |
| | 4ea30a3ba3 | |
| | 52d1d5c7ee | |
| | 46036a44d2 | |
| | 3753fe5c16 | |
| | 213a00a8e6 | |
| | af9943eefc | |
| | 741128e0f4 | |
| | aff14539d8 | |
| | 1af81a50a9 | |
| | 7e1fbb3767 | |
| | a5d972005e | |
| | a180b5bef2 | |
| | 55c697223e | |
| | 9978075867 | |
| | 847753c32b | |
| | 372c2512c9 | |
| | 332a2101c0 | |
| | 1822fe198e | |
| | f06e1db3db | |
| | ded6678e97 | |
| | 9314a86857 | |
| | 54dcb4a567 | |
| | 2c8c8d07cf | |
| | e885bdd365 | |
| | 86dc36a656 | |
| | 55c077d604 | |
| | e598e454d0 | |
| | 6e3f90cd2f | |
| | eec894e101 | |
| | f797fc3fa4 | |
| | 335dc0357f | |
| | e6e2c706c2 | |
| | 7d2166b305 | |
| | 60765e51ac | |
| | 2c4ba6eb9b | |
| | 4c97b9f147 | |
| | 58892d8b63 | |
| | 4fa2dffbcb | |
| | 654f8c7cf2 | |
| | d02482468d | |
| | 5c6ed22f1d | |
| | 2cb623f332 | |
| | 1c1cefe286 | |
| | b4aa95edca | |
| | 1dcdba04a6 | |
| | e0e64c4207 | |
| | fa5b1ab1c5 | |
| | 999a4c08fa | |
| | 374d0237ee | |
| | b1f873623d | |
| | 4131aa67d0 | |
| | 8e7a6d3ae2 | |
| | 7d41e94ca2 | |
| | 011b12abb9 | |
| | 1ba18015bc | |
| | 54369dba54 | |
| | 6b84a59372 | |
| | 8db8aeacd3 | |
| | 64bd71e370 | |
| | ca21a7af06 | |
| | 3b99b7bd2b | |
| | 0c3bdbe4f2 | |
| | 74c201d3a8 | |
| | ab9c721d09 | |
| | 0a2578a1d8 | |
| | 1bebaff386 | |
| | 3bf5eaf4a2 | |
| | ded97f8bb9 | |
| | 26ee8698be | |
| | b98eb9b729 | |
| | 032c07ec65 | |
| | 80c0d13bb9 | |
| | f7c94be3f6 | |
| | c3e85607e0 | |
| | cd5927b8f5 | |
| | 731406d1f4 | |
| | 3c59e4b774 | |
| | a20f81939d | |
| | 972b1f2333 | |
| | 7321f3deda | |
| | 2bd41fdfbf | |
| | a5378b106b | |
| | 0371dacfca | |
**.env.example** — 11 changed lines

```diff
@@ -17,4 +17,15 @@ DATABASE_URL="postgresql://postgres:postgres@localhost:5432/openpipe?schema=publ
 # https://help.openai.com/en/articles/4936850-where-do-i-find-my-secret-api-key
 OPENAI_API_KEY=""
 
+# Replicate API token. Create a token here: https://replicate.com/account/api-tokens
+REPLICATE_API_TOKEN=""
+
 NEXT_PUBLIC_SOCKET_URL="http://localhost:3318"
+
+# Next Auth
+NEXTAUTH_SECRET="your_secret"
+NEXTAUTH_URL="http://localhost:3000"
+
+# Next Auth Github Provider
+GITHUB_CLIENT_ID="your_client_id"
+GITHUB_CLIENT_SECRET="your_secret"
```
**.github/workflows/ci.yaml** — 53 lines (new file)

```yaml
name: CI checks

on:
  pull_request:
    branches: [main]
  push:
    branches: [main]

jobs:
  run-checks:
    runs-on: ubuntu-latest

    steps:
      - name: Check out code
        uses: actions/checkout@v2

      - name: Set up Node.js
        uses: actions/setup-node@v2
        with:
          node-version: "20"

      - uses: pnpm/action-setup@v2
        name: Install pnpm
        id: pnpm-install
        with:
          version: 8.6.1
          run_install: false

      - name: Get pnpm store directory
        id: pnpm-cache
        shell: bash
        run: |
          echo "STORE_PATH=$(pnpm store path)" >> $GITHUB_OUTPUT

      - uses: actions/cache@v3
        name: Setup pnpm cache
        with:
          path: ${{ steps.pnpm-cache.outputs.STORE_PATH }}
          key: ${{ runner.os }}-pnpm-store-${{ hashFiles('**/pnpm-lock.yaml') }}
          restore-keys: |
            ${{ runner.os }}-pnpm-store-

      - name: Install Dependencies
        run: pnpm install

      - name: Check types
        run: pnpm tsc

      - name: Lint
        run: SKIP_ENV_VALIDATION=1 pnpm lint

      - name: Check prettier
        run: pnpm prettier . --check
```
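The workflow above runs four checks after installing dependencies. A minimal sketch for reproducing them locally before pushing, assuming pnpm 8.x and Node 20 are already installed to match the versions pinned in the workflow:

```bash
pnpm install                      # same dependency install the workflow performs
pnpm tsc                          # type check
SKIP_ENV_VALIDATION=1 pnpm lint   # lint without requiring a fully populated .env
pnpm prettier . --check           # formatting check
```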
**.tool-versions** — 1 line (new file)

```
nodejs 20.2.0
```
**.vscode/settings.json** — 5 changed lines

```diff
@@ -1,6 +1,3 @@
 {
-  "eslint.format.enable": true,
-  "editor.codeActionsOnSave": {
-    "source.fixAll.eslint": true
-  }
+  "eslint.format.enable": true
 }
```
**@types/nextjs-routes.d.ts** — 1 changed line

```diff
@@ -11,6 +11,7 @@ declare module "nextjs-routes" {
   } from "next";
 
   export type Route =
+    | StaticRoute<"/account/signin">
     | DynamicRoute<"/api/auth/[...nextauth]", { "nextauth": string[] }>
     | DynamicRoute<"/api/trpc/[trpc]", { "trpc": string }>
     | DynamicRoute<"/experiments/[id]", { "id": string }>
```
**Dockerfile**

```diff
@@ -19,7 +19,6 @@ FROM base as builder
 
 # Include all NEXT_PUBLIC_* env vars here
 ARG NEXT_PUBLIC_POSTHOG_KEY
-ARG NEXT_PUBLIC_IS_PUBLIC_PLAYGROUND
 ARG NEXT_PUBLIC_SOCKET_URL
 
 WORKDIR /app
```
**README.md** — 19 changed lines

```diff
@@ -4,11 +4,18 @@
 
 OpenPipe is a flexible playground for comparing and optimizing LLM prompts. It lets you quickly generate, test and compare candidate prompts with realistic sample data.
 
-**Live Demo:** https://openpipe.ai
+## Sample Experiments
 
+These are simple experiments users have created that show how OpenPipe works.
+
+- [Country Capitals](https://app.openpipe.ai/experiments/11111111-1111-1111-1111-111111111111)
+- [Reddit User Needs](https://app.openpipe.ai/experiments/22222222-2222-2222-2222-222222222222)
+- [OpenAI Function Calls](https://app.openpipe.ai/experiments/2ebbdcb3-ed51-456e-87dc-91f72eaf3e2b)
+- [Activity Classification](https://app.openpipe.ai/experiments/3950940f-ab6b-4b74-841d-7e9dbc4e4ff8)
+
 <img src="https://github.com/openpipe/openpipe/assets/176426/fc7624c6-5b65-4d4d-82b7-4a816f3e5678" alt="demo" height="400px">
 
-Currently there's a public playground available at [https://openpipe.ai/](https://openpipe.ai/), but the recommended approach is to [run locally](#running-locally).
+You can use our hosted version of OpenPipe at [https://openpipe.ai]. You can also clone this repository and [run it locally](#running-locally).
 
 ## High-Level Features
 
@@ -36,7 +43,8 @@ Natively supports [OpenAI function calls](https://openai.com/blog/function-calli
 
 ## Supported Models
 
-OpenPipe currently supports GPT-3.5 and GPT-4. Wider model support is planned.
+- All models available through the OpenAI [chat completion API](https://platform.openai.com/docs/guides/gpt/chat-completions-api)
+- Llama2 [7b chat](https://replicate.com/a16z-infra/llama7b-v2-chat), [13b chat](https://replicate.com/a16z-infra/llama13b-v2-chat), [70b chat](https://replicate.com/replicate/llama70b-v2-chat).
 
 ## Running Locally
 
@@ -47,5 +55,6 @@ OpenPipe currently supports GPT-3.5 and GPT-4. Wider model support is planned.
 5. Install the dependencies: `cd openpipe && pnpm install`
 6. Create a `.env` file (`cp .env.example .env`) and enter your `OPENAI_API_KEY`.
 7. Update `DATABASE_URL` if necessary to point to your Postgres instance and run `pnpm prisma db push` to create the database.
-8. Start the app: `pnpm dev`.
-9. Navigate to [http://localhost:3000](http://localhost:3000)
+8. Create a [GitHub OAuth App](https://docs.github.com/en/apps/oauth-apps/building-oauth-apps/creating-an-oauth-app) and update the `GITHUB_CLIENT_ID` and `GITHUB_CLIENT_SECRET` values. (Note: a PR to make auth optional when running locally would be a great contribution!)
+9. Start the app: `pnpm dev`.
+10. Navigate to [http://localhost:3000](http://localhost:3000)
```
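Taken together, the updated README steps boil down to roughly this command sequence. This is only a sketch: it assumes Postgres is already reachable at the `DATABASE_URL` from `.env.example` and that `OPENAI_API_KEY` and the GitHub OAuth credentials have been filled in by hand.

```bash
cd openpipe && pnpm install   # step 5: install dependencies
cp .env.example .env          # step 6: create the env file, then edit it
pnpm prisma db push           # step 7: create the database schema
pnpm dev                      # step 9: start the app, then open http://localhost:3000
```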
**package.json** — 39 changed lines

```diff
@@ -3,22 +3,33 @@
   "type": "module",
   "version": "0.1.0",
   "license": "Apache-2.0",
+  "engines": {
+    "node": ">=20.0.0",
+    "pnpm": ">=8.6.1"
+  },
   "scripts": {
     "build": "next build",
     "dev:next": "next dev",
     "dev:wss": "pnpm tsx --watch src/wss-server.ts",
+    "dev:worker": "NODE_ENV='development' pnpm tsx --watch src/server/tasks/worker.ts",
     "dev": "concurrently --kill-others 'pnpm dev:next' 'pnpm dev:wss'",
     "postinstall": "prisma generate",
     "lint": "next lint",
     "start": "next start",
-    "codegen": "tsx src/codegen/export-openai-types.ts"
+    "codegen": "tsx src/codegen/export-openai-types.ts",
+    "seed": "tsx prisma/seed.ts",
+    "check": "concurrently 'pnpm lint' 'pnpm tsc' 'pnpm prettier . --check'"
   },
   "dependencies": {
+    "@apidevtools/json-schema-ref-parser": "^10.1.0",
+    "@babel/preset-typescript": "^7.22.5",
+    "@babel/standalone": "^7.22.9",
     "@chakra-ui/next-js": "^2.1.4",
     "@chakra-ui/react": "^2.7.1",
     "@emotion/react": "^11.11.1",
     "@emotion/server": "^11.11.0",
     "@emotion/styled": "^11.11.0",
+    "@fontsource/inconsolata": "^5.0.5",
     "@monaco-editor/loader": "^1.3.3",
     "@next-auth/prisma-adapter": "^1.0.5",
     "@prisma/client": "^4.14.0",
@@ -29,6 +40,7 @@
     "@trpc/next": "^10.26.0",
     "@trpc/react-query": "^10.26.0",
     "@trpc/server": "^10.26.0",
+    "ast-types": "^0.14.2",
     "chroma-js": "^2.4.2",
     "concurrently": "^8.2.0",
     "cors": "^2.8.5",
@@ -38,37 +50,54 @@
     "express": "^4.18.2",
     "framer-motion": "^10.12.17",
     "gpt-tokens": "^1.0.10",
+    "graphile-worker": "^0.13.0",
     "immer": "^10.0.2",
     "isolated-vm": "^4.5.0",
+    "json-schema-to-typescript": "^13.0.2",
     "json-stringify-pretty-compact": "^4.0.0",
-    "lodash": "^4.17.21",
+    "jsonschema": "^1.4.1",
+    "lodash-es": "^4.17.21",
     "next": "^13.4.2",
     "next-auth": "^4.22.1",
+    "next-query-params": "^4.2.3",
     "nextjs-routes": "^2.0.1",
     "openai": "4.0.0-beta.2",
     "pluralize": "^8.0.0",
     "posthog-js": "^1.68.4",
+    "prettier": "^3.0.0",
+    "prismjs": "^1.29.0",
     "react": "18.2.0",
+    "react-diff-viewer": "^3.1.1",
     "react-dom": "18.2.0",
     "react-icons": "^4.10.1",
+    "react-select": "^5.7.4",
     "react-syntax-highlighter": "^15.5.0",
     "react-textarea-autosize": "^8.5.0",
+    "recast": "^0.23.3",
+    "replicate": "^0.12.3",
     "socket.io": "^4.7.1",
     "socket.io-client": "^4.7.1",
     "superjson": "1.12.2",
     "tsx": "^3.12.7",
+    "type-fest": "^4.0.0",
+    "use-query-params": "^2.2.1",
+    "vite-tsconfig-paths": "^4.2.0",
     "zod": "^3.21.4",
     "zustand": "^4.3.9"
   },
   "devDependencies": {
     "@openapi-contrib/openapi-schema-to-json-schema": "^4.0.5",
+    "@types/babel__core": "^7.20.1",
+    "@types/babel__standalone": "^7.1.4",
     "@types/chroma-js": "^2.4.0",
     "@types/cors": "^2.8.13",
     "@types/eslint": "^8.37.0",
     "@types/express": "^4.17.17",
-    "@types/lodash": "^4.14.195",
+    "@types/json-schema": "^7.0.12",
+    "@types/lodash-es": "^4.17.8",
     "@types/node": "^18.16.0",
     "@types/pluralize": "^0.0.30",
+    "@types/prismjs": "^1.26.0",
     "@types/react": "^18.2.6",
     "@types/react-dom": "^18.2.4",
     "@types/react-syntax-highlighter": "^15.5.7",
@@ -77,8 +106,8 @@
     "eslint": "^8.40.0",
     "eslint-config-next": "^13.4.2",
     "eslint-plugin-unused-imports": "^2.0.0",
+    "monaco-editor": "^0.40.0",
     "openapi-typescript": "^6.3.4",
-    "prettier": "^3.0.0",
     "prisma": "^4.14.0",
     "raw-loader": "^4.0.2",
     "typescript": "^5.0.4",
@@ -89,6 +118,6 @@
     "initVersion": "7.14.0"
   },
   "prisma": {
-    "seed": "tsx prisma/seed.ts"
+    "seed": "pnpm seed"
   }
 }
```
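The new `check` and `seed` scripts give one-command equivalents for the CI checks and the Prisma seed hook. A hedged usage sketch, assuming dependencies are installed and a database is reachable:

```bash
pnpm check            # runs lint, tsc, and prettier --check concurrently
pnpm seed             # runs tsx prisma/seed.ts directly
pnpm prisma db seed   # the Prisma seed hook, which now just delegates to "pnpm seed"
```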
**pnpm-lock.yaml** — 1830 changed lines (generated). File diff suppressed because it is too large.
**prisma/migrations/…/migration.sql** — 49 lines (new file)

```sql
-- Drop the foreign key constraints on the original ModelOutput
ALTER TABLE "ModelOutput" DROP CONSTRAINT "ModelOutput_promptVariantId_fkey";
ALTER TABLE "ModelOutput" DROP CONSTRAINT "ModelOutput_testScenarioId_fkey";

-- Rename the old table
ALTER TABLE "ModelOutput" RENAME TO "ScenarioVariantCell";
ALTER TABLE "ScenarioVariantCell" RENAME CONSTRAINT "ModelOutput_pkey" TO "ScenarioVariantCell_pkey";
ALTER INDEX "ModelOutput_inputHash_idx" RENAME TO "ScenarioVariantCell_inputHash_idx";
ALTER INDEX "ModelOutput_promptVariantId_testScenarioId_key" RENAME TO "ScenarioVariantCell_promptVariantId_testScenarioId_key";

-- Add the new fields to the renamed table
ALTER TABLE "ScenarioVariantCell" ADD COLUMN "retryTime" TIMESTAMP(3);
ALTER TABLE "ScenarioVariantCell" ADD COLUMN "streamingChannel" TEXT;
ALTER TABLE "ScenarioVariantCell" ALTER COLUMN "inputHash" DROP NOT NULL;
ALTER TABLE "ScenarioVariantCell" ALTER COLUMN "output" DROP NOT NULL,
    ALTER COLUMN "statusCode" DROP NOT NULL,
    ALTER COLUMN "timeToComplete" DROP NOT NULL;

-- Create the new table
CREATE TABLE "ModelOutput" (
    "id" UUID NOT NULL,
    "inputHash" TEXT NOT NULL,
    "output" JSONB NOT NULL,
    "timeToComplete" INTEGER NOT NULL DEFAULT 0,
    "promptTokens" INTEGER,
    "completionTokens" INTEGER,
    "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
    "updatedAt" TIMESTAMP(3) NOT NULL,
    "scenarioVariantCellId" UUID
);

-- Move inputHash index
DROP INDEX "ScenarioVariantCell_inputHash_idx";
CREATE INDEX "ModelOutput_inputHash_idx" ON "ModelOutput"("inputHash");

CREATE UNIQUE INDEX "ModelOutput_scenarioVariantCellId_key" ON "ModelOutput"("scenarioVariantCellId");
ALTER TABLE "ModelOutput" ADD CONSTRAINT "ModelOutput_scenarioVariantCellId_fkey" FOREIGN KEY ("scenarioVariantCellId") REFERENCES "ScenarioVariantCell"("id") ON DELETE CASCADE ON UPDATE CASCADE;

ALTER TABLE "ModelOutput" ALTER COLUMN "scenarioVariantCellId" SET NOT NULL,
    ADD CONSTRAINT "ModelOutput_pkey" PRIMARY KEY ("id");

ALTER TABLE "ScenarioVariantCell" ADD CONSTRAINT "ScenarioVariantCell_promptVariantId_fkey" FOREIGN KEY ("promptVariantId") REFERENCES "PromptVariant"("id") ON DELETE CASCADE ON UPDATE CASCADE;
ALTER TABLE "ScenarioVariantCell" ADD CONSTRAINT "ScenarioVariantCell_testScenarioId_fkey" FOREIGN KEY ("testScenarioId") REFERENCES "TestScenario"("id") ON DELETE CASCADE ON UPDATE CASCADE;

-- CreateEnum
CREATE TYPE "CellRetrievalStatus" AS ENUM ('PENDING', 'IN_PROGRESS', 'COMPLETE', 'ERROR');

-- AlterTable
ALTER TABLE "ScenarioVariantCell" ADD COLUMN "retrievalStatus" "CellRetrievalStatus" NOT NULL DEFAULT 'COMPLETE';
```
**prisma/migrations/…/migration.sql** — 2 lines (new file)

```sql
-- AlterTable
ALTER TABLE "PromptVariant" ADD COLUMN "model" TEXT NOT NULL DEFAULT 'gpt-3.5-turbo';
```
**prisma/migrations/…/migration.sql** — 2 lines (new file)

```sql
-- AlterTable
ALTER TABLE "PromptVariant" ALTER COLUMN "model" DROP DEFAULT;
```
**prisma/migrations/20230717203031_add_gpt4_eval/migration.sql** — 24 lines (new file)

```sql
/*
  Warnings:

  - You are about to rename the column `matchString` on the `Evaluation` table. If there is any code or views referring to the old name, they will break.
  - You are about to rename the column `matchType` on the `Evaluation` table. If there is any code or views referring to the old name, they will break.
  - You are about to rename the column `name` on the `Evaluation` table. If there is any code or views referring to the old name, they will break.
  - You are about to rename the enum `EvaluationMatchType` to `EvalType`. If there is any code or views referring to the old name, they will break.
*/

-- RenameEnum
ALTER TYPE "EvaluationMatchType" RENAME TO "EvalType";

-- AlterTable
ALTER TABLE "Evaluation" RENAME COLUMN "matchString" TO "value";
ALTER TABLE "Evaluation" RENAME COLUMN "matchType" TO "evalType";
ALTER TABLE "Evaluation" RENAME COLUMN "name" TO "label";

-- AlterColumnType
ALTER TABLE "Evaluation" ALTER COLUMN "evalType" TYPE "EvalType" USING "evalType"::text::"EvalType";

-- SetNotNullConstraint
ALTER TABLE "Evaluation" ALTER COLUMN "evalType" SET NOT NULL;
ALTER TABLE "Evaluation" ALTER COLUMN "label" SET NOT NULL;
ALTER TABLE "Evaluation" ALTER COLUMN "value" SET NOT NULL;
```
**prisma/migrations/…/migration.sql** — 39 lines (new file)

```sql
/*
  Warnings:

  - You are about to drop the `EvaluationResult` table. If the table is not empty, all the data it contains will be lost.

*/
-- AlterEnum
ALTER TYPE "EvalType" ADD VALUE 'GPT4_EVAL';

-- DropForeignKey
ALTER TABLE "EvaluationResult" DROP CONSTRAINT "EvaluationResult_evaluationId_fkey";

-- DropForeignKey
ALTER TABLE "EvaluationResult" DROP CONSTRAINT "EvaluationResult_promptVariantId_fkey";

-- DropTable
DROP TABLE "EvaluationResult";

-- CreateTable
CREATE TABLE "OutputEvaluation" (
    "id" UUID NOT NULL,
    "result" DOUBLE PRECISION NOT NULL,
    "details" TEXT,
    "modelOutputId" UUID NOT NULL,
    "evaluationId" UUID NOT NULL,
    "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
    "updatedAt" TIMESTAMP(3) NOT NULL,

    CONSTRAINT "OutputEvaluation_pkey" PRIMARY KEY ("id")
);

-- CreateIndex
CREATE UNIQUE INDEX "OutputEvaluation_modelOutputId_evaluationId_key" ON "OutputEvaluation"("modelOutputId", "evaluationId");

-- AddForeignKey
ALTER TABLE "OutputEvaluation" ADD CONSTRAINT "OutputEvaluation_modelOutputId_fkey" FOREIGN KEY ("modelOutputId") REFERENCES "ModelOutput"("id") ON DELETE CASCADE ON UPDATE CASCADE;

-- AddForeignKey
ALTER TABLE "OutputEvaluation" ADD CONSTRAINT "OutputEvaluation_evaluationId_fkey" FOREIGN KEY ("evaluationId") REFERENCES "Evaluation"("id") ON DELETE CASCADE ON UPDATE CASCADE;
```
**prisma/migrations/…/migration.sql** — 124 lines (new file)

```sql
DROP TABLE "Account";
DROP TABLE "Session";
DROP TABLE "User";
DROP TABLE "VerificationToken";

CREATE TYPE "OrganizationUserRole" AS ENUM ('ADMIN', 'MEMBER', 'VIEWER');

-- CreateTable
CREATE TABLE "Organization" (
    "id" UUID NOT NULL,
    "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
    "updatedAt" TIMESTAMP(3) NOT NULL,
    "personalOrgUserId" UUID,

    CONSTRAINT "Organization_pkey" PRIMARY KEY ("id")
);

-- CreateTable
CREATE TABLE "OrganizationUser" (
    "id" UUID NOT NULL,
    "role" "OrganizationUserRole" NOT NULL,
    "organizationId" UUID NOT NULL,
    "userId" UUID NOT NULL,
    "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
    "updatedAt" TIMESTAMP(3) NOT NULL,

    CONSTRAINT "OrganizationUser_pkey" PRIMARY KEY ("id")
);

-- CreateTable
CREATE TABLE "Account" (
    "id" UUID NOT NULL,
    "userId" UUID NOT NULL,
    "type" TEXT NOT NULL,
    "provider" TEXT NOT NULL,
    "providerAccountId" TEXT NOT NULL,
    "refresh_token" TEXT,
    "refresh_token_expires_in" INTEGER,
    "access_token" TEXT,
    "expires_at" INTEGER,
    "token_type" TEXT,
    "scope" TEXT,
    "id_token" TEXT,
    "session_state" TEXT,

    CONSTRAINT "Account_pkey" PRIMARY KEY ("id")
);

-- CreateTable
CREATE TABLE "Session" (
    "id" UUID NOT NULL,
    "sessionToken" TEXT NOT NULL,
    "userId" UUID NOT NULL,
    "expires" TIMESTAMP(3) NOT NULL,

    CONSTRAINT "Session_pkey" PRIMARY KEY ("id")
);

-- CreateTable
CREATE TABLE "User" (
    "id" UUID NOT NULL,
    "name" TEXT,
    "email" TEXT,
    "emailVerified" TIMESTAMP(3),
    "image" TEXT,

    CONSTRAINT "User_pkey" PRIMARY KEY ("id")
);

-- CreateTable
CREATE TABLE "VerificationToken" (
    "identifier" TEXT NOT NULL,
    "token" TEXT NOT NULL,
    "expires" TIMESTAMP(3) NOT NULL
);

INSERT INTO "Organization" ("id", "updatedAt") VALUES ('11111111-1111-1111-1111-111111111111', CURRENT_TIMESTAMP);

-- AlterTable add organizationId as NULLABLE
ALTER TABLE "Experiment" ADD COLUMN "organizationId" UUID;

-- Set default organization for existing experiments
UPDATE "Experiment" SET "organizationId" = '11111111-1111-1111-1111-111111111111';

-- AlterTable set organizationId as NOT NULL
ALTER TABLE "Experiment" ALTER COLUMN "organizationId" SET NOT NULL;

-- CreateIndex
CREATE UNIQUE INDEX "OrganizationUser_organizationId_userId_key" ON "OrganizationUser"("organizationId", "userId");

-- CreateIndex
CREATE UNIQUE INDEX "Account_provider_providerAccountId_key" ON "Account"("provider", "providerAccountId");

-- CreateIndex
CREATE UNIQUE INDEX "Session_sessionToken_key" ON "Session"("sessionToken");

-- CreateIndex
CREATE UNIQUE INDEX "User_email_key" ON "User"("email");

-- CreateIndex
CREATE UNIQUE INDEX "VerificationToken_token_key" ON "VerificationToken"("token");

-- CreateIndex
CREATE UNIQUE INDEX "VerificationToken_identifier_token_key" ON "VerificationToken"("identifier", "token");

-- AddForeignKey
ALTER TABLE "Experiment" ADD CONSTRAINT "Experiment_organizationId_fkey" FOREIGN KEY ("organizationId") REFERENCES "Organization"("id") ON DELETE CASCADE ON UPDATE CASCADE;

-- AddForeignKey
ALTER TABLE "OrganizationUser" ADD CONSTRAINT "OrganizationUser_organizationId_fkey" FOREIGN KEY ("organizationId") REFERENCES "Organization"("id") ON DELETE CASCADE ON UPDATE CASCADE;

-- AddForeignKey
ALTER TABLE "OrganizationUser" ADD CONSTRAINT "OrganizationUser_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;

-- AddForeignKey
ALTER TABLE "Account" ADD CONSTRAINT "Account_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;

-- AddForeignKey
ALTER TABLE "Session" ADD CONSTRAINT "Session_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;

CREATE UNIQUE INDEX "Organization_personalOrgUserId_key" ON "Organization"("personalOrgUserId");

ALTER TABLE "Organization" ADD CONSTRAINT "Organization_personalOrgUserId_fkey" FOREIGN KEY ("personalOrgUserId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
```
**prisma/migrations/…/migration.sql** — 16 lines (new file)

```sql
/*
  Warnings:

  - You are about to drop the column `completionTokens` on the `ScenarioVariantCell` table. All the data in the column will be lost.
  - You are about to drop the column `inputHash` on the `ScenarioVariantCell` table. All the data in the column will be lost.
  - You are about to drop the column `output` on the `ScenarioVariantCell` table. All the data in the column will be lost.
  - You are about to drop the column `promptTokens` on the `ScenarioVariantCell` table. All the data in the column will be lost.
  - You are about to drop the column `timeToComplete` on the `ScenarioVariantCell` table. All the data in the column will be lost.

*/
-- AlterTable
ALTER TABLE "ScenarioVariantCell" DROP COLUMN "completionTokens",
    DROP COLUMN "inputHash",
    DROP COLUMN "output",
    DROP COLUMN "promptTokens",
    DROP COLUMN "timeToComplete";
```
**prisma/migrations/…/migration.sql** — 8 lines (new file)

```sql
/*
  Warnings:

  - You are about to drop the column `model` on the `PromptVariant` table. All the data in the column will be lost.

*/
-- AlterTable
ALTER TABLE "ModelOutput" ADD COLUMN "cost" DOUBLE PRECISION;
```
**prisma/migrations/…/migration.sql** — 17 lines (new file)

```sql
-- Add new columns allowing NULL values
ALTER TABLE "PromptVariant"
    ADD COLUMN "constructFnVersion" INTEGER,
    ADD COLUMN "modelProvider" TEXT;

-- Update existing records to have the default values
UPDATE "PromptVariant"
SET "constructFnVersion" = 1,
    "modelProvider" = 'openai/ChatCompletion'
WHERE "constructFnVersion" IS NULL OR "modelProvider" IS NULL;

-- Alter table to set NOT NULL constraint
ALTER TABLE "PromptVariant"
    ALTER COLUMN "constructFnVersion" SET NOT NULL,
    ALTER COLUMN "modelProvider" SET NOT NULL;

ALTER TABLE "ScenarioVariantCell" ADD COLUMN "prompt" JSONB;
```
**prisma/schema.prisma**

```diff
@@ -3,7 +3,6 @@
 
 generator client {
   provider        = "prisma-client-js"
-  previewFeatures = ["jsonProtocol"]
 }
 
 datasource db {
@@ -17,8 +16,12 @@ model Experiment {
   sortIndex Int @default(0)
 
+  organizationId String        @db.Uuid
+  organization   Organization? @relation(fields: [organizationId], references: [id], onDelete: Cascade)
+
   createdAt DateTime @default(now())
   updatedAt DateTime @updatedAt
 
   TemplateVariable TemplateVariable[]
   PromptVariant    PromptVariant[]
   TestScenario     TestScenario[]
@@ -27,9 +30,12 @@ model Experiment {
 
 model PromptVariant {
   id String @id @default(uuid()) @db.Uuid
-  label       String
+
+  label              String
   constructFn        String
+  constructFnVersion Int
+  model              String
+  modelProvider      String
+
   uiId    String  @default(uuid()) @db.Uuid
   visible Boolean @default(true)
@@ -40,8 +46,7 @@ model PromptVariant {
 
   createdAt DateTime @default(now())
   updatedAt DateTime @updatedAt
-  ModelOutput      ModelOutput[]
-  EvaluationResult EvaluationResult[]
+  scenarioVariantCells ScenarioVariantCell[]
 
   @@index([uiId])
 }
@@ -60,7 +65,7 @@ model TestScenario {
 
   createdAt DateTime @default(now())
   updatedAt DateTime @updatedAt
-  ModelOutput ModelOutput[]
+  scenarioVariantCells ScenarioVariantCell[]
 }
 
 model TemplateVariable {
@@ -75,20 +80,27 @@ model TemplateVariable {
   updatedAt DateTime @updatedAt
 }
 
-model ModelOutput {
+enum CellRetrievalStatus {
+  PENDING
+  IN_PROGRESS
+  COMPLETE
+  ERROR
+}
+
+model ScenarioVariantCell {
   id String @id @default(uuid()) @db.Uuid
 
-  inputHash      String
-  output         Json
-  statusCode     Int
+  statusCode   Int?
   errorMessage String?
-  timeToComplete Int @default(0)
+  retryTime        DateTime?
+  streamingChannel String?
+  retrievalStatus  CellRetrievalStatus @default(COMPLETE)
 
-  promptTokens     Int? // Added promptTokens field
-  completionTokens Int? // Added completionTokens field
+  modelOutput ModelOutput?
 
   promptVariantId String        @db.Uuid
   promptVariant   PromptVariant @relation(fields: [promptVariantId], references: [id], onDelete: Cascade)
+  prompt          Json?
 
   testScenarioId String       @db.Uuid
   testScenario   TestScenario @relation(fields: [testScenarioId], references: [id], onDelete: Cascade)
@@ -97,60 +109,116 @@ model ModelOutput {
   updatedAt DateTime @updatedAt
 
   @@unique([promptVariantId, testScenarioId])
+}
+
+model ModelOutput {
+  id String @id @default(uuid()) @db.Uuid
+
+  inputHash        String
+  output           Json
+  timeToComplete   Int    @default(0)
+  cost             Float?
+  promptTokens     Int?
+  completionTokens Int?
+
+  createdAt DateTime @default(now())
+  updatedAt DateTime @updatedAt
+
+  scenarioVariantCellId String              @db.Uuid
+  scenarioVariantCell   ScenarioVariantCell @relation(fields: [scenarioVariantCellId], references: [id], onDelete: Cascade)
+  outputEvaluation      OutputEvaluation[]
+
+  @@unique([scenarioVariantCellId])
   @@index([inputHash])
 }
 
-enum EvaluationMatchType {
+enum EvalType {
   CONTAINS
   DOES_NOT_CONTAIN
+  GPT4_EVAL
 }
 
 model Evaluation {
   id String @id @default(uuid()) @db.Uuid
 
-  name        String
-  matchString String
-  matchType   EvaluationMatchType
+  label    String
+  evalType EvalType
+  value    String
 
   experimentId String     @db.Uuid
   experiment   Experiment @relation(fields: [experimentId], references: [id], onDelete: Cascade)
 
   createdAt DateTime @default(now())
   updatedAt DateTime @updatedAt
-  EvaluationResult EvaluationResult[]
+  OutputEvaluation OutputEvaluation[]
 }
 
-model EvaluationResult {
+model OutputEvaluation {
   id String @id @default(uuid()) @db.Uuid
 
-  passCount Int
-  failCount Int
+  // Number between 0 (fail) and 1 (pass)
+  result  Float
+  details String?
+
+  modelOutputId String      @db.Uuid
+  modelOutput   ModelOutput @relation(fields: [modelOutputId], references: [id], onDelete: Cascade)
 
   evaluationId String     @db.Uuid
   evaluation   Evaluation @relation(fields: [evaluationId], references: [id], onDelete: Cascade)
 
-  promptVariantId String        @db.Uuid
-  promptVariant   PromptVariant @relation(fields: [promptVariantId], references: [id], onDelete: Cascade)
+  createdAt DateTime @default(now())
+  updatedAt DateTime @updatedAt
+
+  @@unique([modelOutputId, evaluationId])
+}
+
+model Organization {
+  id                String  @id @default(uuid()) @db.Uuid
+  personalOrgUserId String? @unique @db.Uuid
+  PersonalOrgUser   User?   @relation(fields: [personalOrgUserId], references: [id], onDelete: Cascade)
+
+  createdAt DateTime @default(now())
+  updatedAt DateTime @updatedAt
+  OrganizationUser OrganizationUser[]
+  Experiment       Experiment[]
+}
+
+enum OrganizationUserRole {
+  ADMIN
+  MEMBER
+  VIEWER
+}
+
+model OrganizationUser {
+  id String @id @default(uuid()) @db.Uuid
+
+  role OrganizationUserRole
+
+  organizationId String        @db.Uuid
+  organization   Organization? @relation(fields: [organizationId], references: [id], onDelete: Cascade)
+
+  userId String @db.Uuid
+  user   User   @relation(fields: [userId], references: [id], onDelete: Cascade)
 
   createdAt DateTime @default(now())
   updatedAt DateTime @updatedAt
 
-  @@unique([evaluationId, promptVariantId])
+  @@unique([organizationId, userId])
 }
 
-// Necessary for Next auth
 model Account {
-  id                String  @id @default(cuid())
-  userId            String
+  id                String  @id @default(uuid()) @db.Uuid
+  userId            String  @db.Uuid
   type              String
   provider          String
   providerAccountId String
-  refresh_token     String? // @db.Text
-  access_token      String? // @db.Text
+  refresh_token            String? @db.Text
+  refresh_token_expires_in Int?
+  access_token             String? @db.Text
   expires_at        Int?
   token_type        String?
   scope             String?
-  id_token          String? // @db.Text
+  id_token          String? @db.Text
   session_state     String?
   user              User    @relation(fields: [userId], references: [id], onDelete: Cascade)
 
@@ -158,21 +226,23 @@ model Account {
 }
 
 model Session {
-  id           String   @id @default(cuid())
+  id           String   @id @default(uuid()) @db.Uuid
   sessionToken String   @unique
-  userId       String
+  userId       String   @db.Uuid
   expires      DateTime
   user         User     @relation(fields: [userId], references: [id], onDelete: Cascade)
 }
 
 model User {
-  id            String    @id @default(cuid())
+  id            String    @id @default(uuid()) @db.Uuid
   name          String?
   email         String?   @unique
   emailVerified DateTime?
   image         String?
   accounts      Account[]
   sessions      Session[]
+  OrganizationUser OrganizationUser[]
+  Organization     Organization[]
 }
 
 model VerificationToken {
```
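After pulling these schema and migration changes, the local database and generated client need to be brought up to date. A minimal sketch, assuming a disposable local development database (several of the migrations above drop and recreate the auth tables):

```bash
pnpm prisma migrate dev   # apply the new migrations to the local database
pnpm prisma generate      # regenerate the Prisma client (also runs automatically via postinstall)
```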
**prisma/seed.ts** — 153 changed lines

```diff
@@ -1,76 +1,101 @@
 import { prisma } from "~/server/db";
+import dedent from "dedent";
+import { generateNewCell } from "~/server/utils/generateNewCell";
 
-const experimentId = "11111111-1111-1111-1111-111111111111";
+const defaultId = "11111111-1111-1111-1111-111111111111";
+
+await prisma.organization.deleteMany({
+  where: { id: defaultId },
+});
+
+// If there's an existing org, just seed into it
+const org =
+  (await prisma.organization.findFirst({})) ??
+  (await prisma.organization.create({
+    data: { id: defaultId },
+  }));
 
-// Delete the existing experiment
 await prisma.experiment.deleteMany({
   where: {
-    id: experimentId,
+    id: defaultId,
   },
 });
 
-const experiment = await prisma.experiment.create({
+await prisma.experiment.create({
   data: {
-    id: experimentId,
+    id: defaultId,
     label: "Country Capitals Example",
+    organizationId: org.id,
   },
 });
 
-await prisma.modelOutput.deleteMany({
+await prisma.scenarioVariantCell.deleteMany({
   where: {
     promptVariant: {
-      experimentId,
+      experimentId: defaultId,
     },
   },
 });
 
 await prisma.promptVariant.deleteMany({
   where: {
-    experimentId,
+    experimentId: defaultId,
   },
 });
 
 await prisma.promptVariant.createMany({
   data: [
     {
-      experimentId,
+      experimentId: defaultId,
       label: "Prompt Variant 1",
       sortIndex: 0,
-      constructFn: `prompt = {
-        model: "gpt-3.5-turbo-0613",
-        messages: [{ role: "user", content: "What is the capital of {{country}}?" }],
-        temperature: 0,
-      }`,
-    },
-    {
-      experimentId,
-      label: "Prompt Variant 2",
-      sortIndex: 1,
-      constructFn: `prompt = {
-        model: "gpt-3.5-turbo-0613",
-        messages: [
-          {
-            role: "user",
-            content:
-              "What is the capital of {{country}}? Return just the city name and nothing else.",
-          },
-        ],
-        temperature: 0,
-      }`,
+      model: "gpt-3.5-turbo-0613",
+      modelProvider: "openai/ChatCompletion",
+      constructFnVersion: 1,
+      constructFn: dedent`
+        definePrompt("openai/ChatCompletion", {
+          model: "gpt-3.5-turbo-0613",
+          messages: [
+            {
+              role: "user",
+              content: \`What is the capital of ${"$"}{scenario.country}?\`
+            }
+          ],
+          temperature: 0,
+        })`,
+    },
+    {
+      experimentId: defaultId,
+      label: "Prompt Variant 2",
+      sortIndex: 1,
+      model: "gpt-3.5-turbo-0613",
+      modelProvider: "openai/ChatCompletion",
+      constructFnVersion: 1,
+      constructFn: dedent`
+        definePrompt("openai/ChatCompletion", {
+          model: "gpt-3.5-turbo-0613",
+          messages: [
+            {
+              role: "user",
+              content: \`What is the capital of ${"$"}{scenario.country}? Return just the city name and nothing else.\`
+            }
+          ],
+          temperature: 0,
+        })`,
     },
   ],
 });
 
 await prisma.templateVariable.deleteMany({
   where: {
-    experimentId,
+    experimentId: defaultId,
   },
 });
 
 await prisma.templateVariable.createMany({
   data: [
     {
-      experimentId,
+      experimentId: defaultId,
       label: "country",
     },
   ],
@@ -78,32 +103,66 @@ await prisma.templateVariable.createMany({
 
 await prisma.testScenario.deleteMany({
   where: {
-    experimentId,
+    experimentId: defaultId,
   },
 });
 
+const countries = [
+  "Afghanistan",
+  "Albania",
+  "Algeria",
+  "Andorra",
+  "Angola",
+  "Antigua and Barbuda",
+  "Argentina",
+  "Armenia",
+  "Australia",
+  "Austria",
+  "Austrian Empire",
+  "Azerbaijan",
+  "Baden",
+  "Bahamas, The",
+  "Bahrain",
+  "Bangladesh",
+  "Barbados",
+  "Bavaria",
+  "Belarus",
+  "Belgium",
+  "Belize",
+  "Benin (Dahomey)",
+  "Bolivia",
+  "Bosnia and Herzegovina",
+  "Botswana",
+];
 await prisma.testScenario.createMany({
-  data: [
-    {
-      experimentId,
-      sortIndex: 0,
-      variableValues: {
-        country: "Spain",
-      },
-    },
-    {
-      experimentId,
-      sortIndex: 1,
-      variableValues: {
-        country: "USA",
-      },
-    },
-    {
-      experimentId,
-      sortIndex: 2,
-      variableValues: {
-        country: "Chile",
-      },
-    },
-  ],
+  data: countries.map((country, i) => ({
+    experimentId: defaultId,
+    sortIndex: i,
+    variableValues: {
+      country: country,
+    },
+  })),
 });
 
+const variants = await prisma.promptVariant.findMany({
+  where: {
+    experimentId: defaultId,
+  },
+});
+
+const scenarios = await prisma.testScenario.findMany({
+  where: {
+    experimentId: defaultId,
+  },
+});
+
+await Promise.all(
+  variants
+    .flatMap((variant) =>
+      scenarios.map((scenario) => ({
+        promptVariantId: variant.id,
+        testScenarioId: scenario.id,
+      })),
+    )
+    .map((cell) => generateNewCell(cell.promptVariantId, cell.testScenarioId)),
+);
```
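Because the seed now creates a default organization, 25 country scenarios, and a cell for every variant × scenario pair via `generateNewCell`, re-running it is the quickest way to see the new data shape locally. A sketch, assuming the migrations above have already been applied:

```bash
pnpm seed   # re-runs prisma/seed.ts; deletes and recreates the default experiment
```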
**prisma/seedDemo.ts** — 1130 changed lines. File diff suppressed because one or more lines are too long.
**render.yaml**

```diff
@@ -12,7 +12,7 @@ services:
     dockerContext: .
     plan: standard
     domains:
-      - openpipe.ai
+      - app.openpipe.ai
     envVars:
       - key: NODE_ENV
         value: production
```
**Deleted file** — 48 lines

```ts
/* eslint-disable @typescript-eslint/no-var-requires */

import YAML from "yaml";
import fs from "fs";
import path from "path";
import { openapiSchemaToJsonSchema } from "@openapi-contrib/openapi-schema-to-json-schema";
import assert from "assert";
import { type AcceptibleInputSchema } from "@openapi-contrib/openapi-schema-to-json-schema/dist/mjs/openapi-schema-types";

const OPENAPI_URL =
  "https://raw.githubusercontent.com/openai/openai-openapi/0c432eb66fd0c758fd8b9bd69db41c1096e5f4db/openapi.yaml";

const convertOpenApiToJsonSchema = async (url: string) => {
  // Fetch the openapi document
  const response = await fetch(url);
  const openApiYaml = await response.text();

  // Parse the yaml document
  const openApiDocument = YAML.parse(openApiYaml) as AcceptibleInputSchema;

  // Convert the openapi schema to json schema
  const jsonSchema = openapiSchemaToJsonSchema(openApiDocument);

  const modelProperty = jsonSchema.components.schemas.CreateChatCompletionRequest.properties.model;

  assert(modelProperty.oneOf.length === 2, "Expected model to have oneOf length of 2");

  // We need to do a bit of surgery here since the Monaco editor doesn't like
  // the fact that the schema says `model` can be either a string or an enum,
  // and displays a warning in the editor. Let's stick with just an enum for
  // now and drop the string option.
  modelProperty.type = "string";
  modelProperty.enum = modelProperty.oneOf[1].enum;
  modelProperty.oneOf = undefined;

  // Get the directory of the current script
  const currentDirectory = path.dirname(import.meta.url).replace("file://", "");

  // Write the JSON schema to a file in the current directory
  fs.writeFileSync(
    path.join(currentDirectory, "openai.schema.json"),
    JSON.stringify(jsonSchema, null, 2),
  );
};

convertOpenApiToJsonSchema(OPENAPI_URL)
  .then(() => console.log("JSON schema has been written successfully."))
  .catch((err) => console.error(err));
```
**Deleted file** — 52 lines

```ts
import fs from "fs";
import path from "path";
import openapiTS, { type OpenAPI3 } from "openapi-typescript";
import YAML from "yaml";
import _ from "lodash";
import assert from "assert";

const OPENAPI_URL =
  "https://raw.githubusercontent.com/openai/openai-openapi/0c432eb66fd0c758fd8b9bd69db41c1096e5f4db/openapi.yaml";

// Generate TypeScript types from OpenAPI

const schema = await fetch(OPENAPI_URL)
  .then((res) => res.text())
  .then((txt) => YAML.parse(txt) as OpenAPI3);

console.log(schema.components?.schemas?.CreateChatCompletionRequest);

// @ts-expect-error just assume this works, the assert will catch it if it doesn't
const modelProperty = schema.components?.schemas?.CreateChatCompletionRequest?.properties?.model;

assert(modelProperty.oneOf.length === 2, "Expected model to have oneOf length of 2");

// We need to do a bit of surgery here since the Monaco editor doesn't like
// the fact that the schema says `model` can be either a string or an enum,
// and displays a warning in the editor. Let's stick with just an enum for
// now and drop the string option.
modelProperty.type = "string";
modelProperty.enum = modelProperty.oneOf[1].enum;
modelProperty.oneOf = undefined;

delete schema["paths"];
assert(schema.components?.schemas);
schema.components.schemas = _.pick(schema.components?.schemas, [
  "CreateChatCompletionRequest",
  "ChatCompletionRequestMessage",
  "ChatCompletionFunctions",
  "ChatCompletionFunctionParameters",
]);
console.log(schema);

let openApiTypes = await openapiTS(schema);

// Remove the `export` from any line that starts with `export`
openApiTypes = openApiTypes.replaceAll("\nexport ", "\n");

// Get the directory of the current script
const currentDirectory = path.dirname(import.meta.url).replace("file://", "");

// Write the TypeScript types. We only want to use this in our in-app editor, so
// save as a .txt so VS Code doesn't try to auto-import definitions from it.
fs.writeFileSync(path.join(currentDirectory, "openai.types.ts.txt"), openApiTypes);
```
File diff suppressed because it is too large.
**Deleted file** — 148 lines

```ts
/**
 * This file was auto-generated by openapi-typescript.
 * Do not make direct changes to the file.
 */


/** OneOf type helpers */
type Without<T, U> = { [P in Exclude<keyof T, keyof U>]?: never };
type XOR<T, U> = (T | U) extends object ? (Without<T, U> & U) | (Without<U, T> & T) : T | U;
type OneOf<T extends any[]> = T extends [infer Only] ? Only : T extends [infer A, infer B, ...infer Rest] ? OneOf<[XOR<A, B>, ...Rest]> : never;

type paths = Record<string, never>;

type webhooks = Record<string, never>;

interface components {
  schemas: {
    CreateChatCompletionRequest: {
      /**
       * @description ID of the model to use. See the [model endpoint compatibility](/docs/models/model-endpoint-compatibility) table for details on which models work with the Chat API.
       * @example gpt-3.5-turbo
       * @enum {string}
       */
      model: "gpt-4" | "gpt-4-0613" | "gpt-4-32k" | "gpt-4-32k-0613" | "gpt-3.5-turbo" | "gpt-3.5-turbo-16k" | "gpt-3.5-turbo-0613" | "gpt-3.5-turbo-16k-0613";
      /** @description A list of messages comprising the conversation so far. [Example Python code](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb). */
      messages: (components["schemas"]["ChatCompletionRequestMessage"])[];
      /** @description A list of functions the model may generate JSON inputs for. */
      functions?: (components["schemas"]["ChatCompletionFunctions"])[];
      /** @description Controls how the model responds to function calls. "none" means the model does not call a function, and responds to the end-user. "auto" means the model can pick between an end-user or calling a function. Specifying a particular function via `{"name":\ "my_function"}` forces the model to call that function. "none" is the default when no functions are present. "auto" is the default if functions are present. */
      function_call?: OneOf<["none" | "auto", {
        /** @description The name of the function to call. */
        name: string;
      }]>;
      /**
       * @description What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.
       *
       * We generally recommend altering this or `top_p` but not both.
       *
       * @default 1
       * @example 1
       */
      temperature?: number | null;
      /**
       * @description An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.
       *
       * We generally recommend altering this or `temperature` but not both.
       *
       * @default 1
       * @example 1
       */
      top_p?: number | null;
      /**
       * @description How many chat completion choices to generate for each input message.
       * @default 1
       * @example 1
       */
      n?: number | null;
      /**
       * @description If set, partial message deltas will be sent, like in ChatGPT. Tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream terminated by a `data: [DONE]` message. [Example Python code](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_stream_completions.ipynb).
       *
       * @default false
       */
      stream?: boolean | null;
      /**
       * @description Up to 4 sequences where the API will stop generating further tokens.
       *
       * @default null
       */
      stop?: (string | null) | (string)[];
      /**
       * @description The maximum number of [tokens](/tokenizer) to generate in the chat completion.
       *
       * The total length of input tokens and generated tokens is limited by the model's context length. [Example Python code](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb) for counting tokens.
       *
       * @default inf
       */
      max_tokens?: number;
      /**
       * @description Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
       *
       * [See more information about frequency and presence penalties.](/docs/api-reference/parameter-details)
       *
       * @default 0
       */
      presence_penalty?: number | null;
      /**
       * @description Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.
       *
       * [See more information about frequency and presence penalties.](/docs/api-reference/parameter-details)
       *
       * @default 0
       */
      frequency_penalty?: number | null;
      /**
       * @description Modify the likelihood of specified tokens appearing in the completion.
       *
       * Accepts a json object that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token.
       *
       * @default null
       */
      logit_bias?: Record<string, unknown> | null;
      /**
       * @description A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).
       *
       * @example user-1234
       */
      user?: string;
    };
    ChatCompletionRequestMessage: {
      /**
       * @description The role of the messages author. One of `system`, `user`, `assistant`, or `function`.
       * @enum {string}
       */
      role: "system" | "user" | "assistant" | "function";
      /** @description The contents of the message. `content` is required for all messages except assistant messages with function calls. */
      content?: string;
      /** @description The name of the author of this message. `name` is required if role is `function`, and it should be the name of the function whose response is in the `content`. May contain a-z, A-Z, 0-9, and underscores, with a maximum length of 64 characters. */
      name?: string;
      /** @description The name and arguments of a function that should be called, as generated by the model. */
      function_call?: {
        /** @description The name of the function to call. */
        name?: string;
```
/** @description The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function. */
|
|
||||||
arguments?: string;
|
|
||||||
};
|
|
||||||
};
|
|
||||||
ChatCompletionFunctions: {
|
|
||||||
/** @description The name of the function to be called. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64. */
|
|
||||||
name: string;
|
|
||||||
/** @description The description of what the function does. */
|
|
||||||
description?: string;
|
|
||||||
parameters?: components["schemas"]["ChatCompletionFunctionParameters"];
|
|
||||||
};
|
|
||||||
/** @description The parameters the functions accepts, described as a JSON Schema object. See the [guide](/docs/guides/gpt/function-calling) for examples, and the [JSON Schema reference](https://json-schema.org/understanding-json-schema/) for documentation about the format. */
|
|
||||||
ChatCompletionFunctionParameters: {
|
|
||||||
[key: string]: unknown;
|
|
||||||
};
|
|
||||||
};
|
|
||||||
responses: never;
|
|
||||||
parameters: never;
|
|
||||||
requestBodies: never;
|
|
||||||
headers: never;
|
|
||||||
pathItems: never;
|
|
||||||
}
|
|
||||||
|
|
||||||
type external = Record<string, never>;
|
|
||||||
|
|
||||||
type operations = Record<string, never>;
|
|
||||||
@@ -1,6 +0,0 @@
{
  "compilerOptions": {
    "target": "esnext",
    "moduleResolution": "nodenext"
  }
}
@@ -1,19 +1,22 @@
 import { Textarea, type TextareaProps } from "@chakra-ui/react";
 import ResizeTextarea from "react-textarea-autosize";
-import React from "react";
+import React, { useLayoutEffect, useState } from "react";

 export const AutoResizeTextarea: React.ForwardRefRenderFunction<
   HTMLTextAreaElement,
-  TextareaProps
+  TextareaProps & { minRows?: number }
-> = (props, ref) => {
+> = ({ minRows = 1, overflowY = "hidden", ...props }, ref) => {
+  const [isRerendered, setIsRerendered] = useState(false);
+  useLayoutEffect(() => setIsRerendered(true), []);

   return (
     <Textarea
       minH="unset"
-      overflow="hidden"
+      minRows={minRows}
+      overflowY={isRerendered ? overflowY : "hidden"}
       w="100%"
       resize="none"
       ref={ref}
-      minRows={1}
       transition="height none"
       as={ResizeTextarea}
       {...props}
135 src/components/ChangeModelModal/ChangeModelModal.tsx Normal file
@@ -0,0 +1,135 @@
import {
  Button,
  Modal,
  ModalBody,
  ModalCloseButton,
  ModalContent,
  ModalFooter,
  ModalHeader,
  ModalOverlay,
  VStack,
  Text,
  Spinner,
  HStack,
  Icon,
} from "@chakra-ui/react";
import { RiExchangeFundsFill } from "react-icons/ri";
import { useState } from "react";
import { ModelStatsCard } from "./ModelStatsCard";
import { ModelSearch } from "./ModelSearch";
import { api } from "~/utils/api";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
import CompareFunctions from "../RefinePromptModal/CompareFunctions";
import { type PromptVariant } from "@prisma/client";
import { isObject, isString } from "lodash-es";
import { type Model, type SupportedProvider } from "~/modelProviders/types";
import frontendModelProviders from "~/modelProviders/frontendModelProviders";
import { keyForModel } from "~/utils/utils";

export const ChangeModelModal = ({
  variant,
  onClose,
}: {
  variant: PromptVariant;
  onClose: () => void;
}) => {
  const originalModelProviderName = variant.modelProvider as SupportedProvider;
  const originalModelProvider = frontendModelProviders[originalModelProviderName];
  const originalModel = originalModelProvider.models[variant.model] as Model;
  const [selectedModel, setSelectedModel] = useState<Model>(originalModel);
  const [convertedModel, setConvertedModel] = useState<Model | undefined>(undefined);
  const utils = api.useContext();

  const experiment = useExperiment();

  const { mutateAsync: getModifiedPromptMutateAsync, data: modifiedPromptFn } =
    api.promptVariants.getModifiedPromptFn.useMutation();

  const [getModifiedPromptFn, modificationInProgress] = useHandledAsyncCallback(async () => {
    if (!experiment) return;

    await getModifiedPromptMutateAsync({
      id: variant.id,
      newModel: selectedModel,
    });
    setConvertedModel(selectedModel);
  }, [getModifiedPromptMutateAsync, onClose, experiment, variant, selectedModel]);

  const replaceVariantMutation = api.promptVariants.replaceVariant.useMutation();

  const [replaceVariant, replacementInProgress] = useHandledAsyncCallback(async () => {
    if (
      !variant.experimentId ||
      !modifiedPromptFn ||
      (isObject(modifiedPromptFn) && "status" in modifiedPromptFn)
    )
      return;
    await replaceVariantMutation.mutateAsync({
      id: variant.id,
      constructFn: modifiedPromptFn,
    });
    await utils.promptVariants.list.invalidate();
    onClose();
  }, [replaceVariantMutation, variant, onClose, modifiedPromptFn]);

  const originalModelLabel = keyForModel(originalModel);
  const selectedModelLabel = keyForModel(selectedModel);
  const convertedModelLabel = convertedModel ? keyForModel(convertedModel) : undefined;

  return (
    <Modal
      isOpen
      onClose={onClose}
      size={{ base: "xl", sm: "2xl", md: "3xl", lg: "5xl", xl: "7xl" }}
    >
      <ModalOverlay />
      <ModalContent w={1200}>
        <ModalHeader>
          <HStack>
            <Icon as={RiExchangeFundsFill} />
            <Text>Change Model</Text>
          </HStack>
        </ModalHeader>
        <ModalCloseButton />
        <ModalBody maxW="unset">
          <VStack spacing={8}>
            <ModelStatsCard label="Original Model" model={originalModel} />
            {originalModelLabel !== selectedModelLabel && (
              <ModelStatsCard label="New Model" model={selectedModel} />
            )}
            <ModelSearch selectedModel={selectedModel} setSelectedModel={setSelectedModel} />
            {isString(modifiedPromptFn) && (
              <CompareFunctions
                originalFunction={variant.constructFn}
                newFunction={modifiedPromptFn}
                leftTitle={originalModelLabel}
                rightTitle={convertedModelLabel}
              />
            )}
          </VStack>
        </ModalBody>

        <ModalFooter>
          <HStack>
            <Button
              colorScheme="gray"
              onClick={getModifiedPromptFn}
              minW={24}
              isDisabled={originalModel === selectedModel || modificationInProgress}
            >
              {modificationInProgress ? <Spinner boxSize={4} /> : <Text>Convert</Text>}
            </Button>
            <Button
              colorScheme="blue"
              onClick={replaceVariant}
              minW={24}
              isDisabled={!convertedModel || modificationInProgress || replacementInProgress}
            >
              {replacementInProgress ? <Spinner boxSize={4} /> : <Text>Accept</Text>}
            </Button>
          </HStack>
        </ModalFooter>
      </ModalContent>
    </Modal>
  );
};
50 src/components/ChangeModelModal/ModelSearch.tsx Normal file
@@ -0,0 +1,50 @@
import { VStack, Text } from "@chakra-ui/react";
import { type LegacyRef, useCallback } from "react";
import Select, { type SingleValue } from "react-select";
import { useElementDimensions } from "~/utils/hooks";

import frontendModelProviders from "~/modelProviders/frontendModelProviders";
import { type Model } from "~/modelProviders/types";
import { keyForModel } from "~/utils/utils";

const modelOptions: { label: string; value: Model }[] = [];

for (const [_, providerValue] of Object.entries(frontendModelProviders)) {
  for (const [_, modelValue] of Object.entries(providerValue.models)) {
    modelOptions.push({
      label: keyForModel(modelValue),
      value: modelValue,
    });
  }
}

export const ModelSearch = ({
  selectedModel,
  setSelectedModel,
}: {
  selectedModel: Model;
  setSelectedModel: (model: Model) => void;
}) => {
  const handleSelection = useCallback(
    (option: SingleValue<{ label: string; value: Model }>) => {
      if (!option) return;
      setSelectedModel(option.value);
    },
    [setSelectedModel],
  );
  const selectedOption = modelOptions.find((option) => option.label === keyForModel(selectedModel));

  const [containerRef, containerDimensions] = useElementDimensions();

  return (
    <VStack ref={containerRef as LegacyRef<HTMLDivElement>} w="full">
      <Text>Browse Models</Text>
      <Select
        styles={{ control: (provided) => ({ ...provided, width: containerDimensions?.width }) }}
        value={selectedOption}
        options={modelOptions}
        onChange={handleSelection}
      />
    </VStack>
  );
};
102 src/components/ChangeModelModal/ModelStatsCard.tsx Normal file
@@ -0,0 +1,102 @@
import {
  VStack,
  Text,
  HStack,
  type StackProps,
  GridItem,
  SimpleGrid,
  Link,
} from "@chakra-ui/react";
import { type Model } from "~/modelProviders/types";

export const ModelStatsCard = ({ label, model }: { label: string; model: Model }) => {
  return (
    <VStack w="full" align="start">
      <Text fontWeight="bold" fontSize="sm" textTransform="uppercase">
        {label}
      </Text>

      <VStack w="full" spacing={6} bgColor="gray.100" p={4} borderRadius={4}>
        <HStack w="full" align="flex-start">
          <Text flex={1} fontSize="lg">
            <Text as="span" color="gray.600">
              {model.provider} /{" "}
            </Text>
            <Text as="span" fontWeight="bold" color="gray.900">
              {model.name}
            </Text>
          </Text>
          <Link
            href={model.learnMoreUrl}
            isExternal
            color="blue.500"
            fontWeight="bold"
            fontSize="sm"
            ml={2}
          >
            Learn More
          </Link>
        </HStack>
        <SimpleGrid
          w="full"
          justifyContent="space-between"
          alignItems="flex-start"
          fontSize="sm"
          columns={{ base: 2, md: 4 }}
        >
          <SelectedModelLabeledInfo label="Context Window" info={model.contextWindow} />
          {model.promptTokenPrice && (
            <SelectedModelLabeledInfo
              label="Input"
              info={
                <Text>
                  ${(model.promptTokenPrice * 1000).toFixed(3)}
                  <Text color="gray.500"> / 1K tokens</Text>
                </Text>
              }
            />
          )}
          {model.completionTokenPrice && (
            <SelectedModelLabeledInfo
              label="Output"
              info={
                <Text>
                  ${(model.completionTokenPrice * 1000).toFixed(3)}
                  <Text color="gray.500"> / 1K tokens</Text>
                </Text>
              }
            />
          )}
          {model.pricePerSecond && (
            <SelectedModelLabeledInfo
              label="Price"
              info={
                <Text>
                  ${model.pricePerSecond.toFixed(3)}
                  <Text color="gray.500"> / second</Text>
                </Text>
              }
            />
          )}
          <SelectedModelLabeledInfo label="Speed" info={<Text>{model.speed}</Text>} />
        </SimpleGrid>
      </VStack>
    </VStack>
  );
};

const SelectedModelLabeledInfo = ({
  label,
  info,
  ...props
}: {
  label: string;
  info: string | number | React.ReactElement;
} & StackProps) => (
  <GridItem>
    <VStack alignItems="flex-start" {...props}>
      <Text fontWeight="bold">{label}</Text>
      <Text>{info}</Text>
    </VStack>
  </GridItem>
);
48 src/components/OutputsTable/AddVariantButton.tsx Normal file
@@ -0,0 +1,48 @@
import { Box, Flex, Icon, Spinner } from "@chakra-ui/react";
import { BsPlus } from "react-icons/bs";
import { api } from "~/utils/api";
import { useExperiment, useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
import { cellPadding } from "../constants";
import { ActionButton } from "./ScenariosHeader";

export default function AddVariantButton() {
  const experiment = useExperiment();
  const mutation = api.promptVariants.create.useMutation();
  const utils = api.useContext();

  const [onClick, loading] = useHandledAsyncCallback(async () => {
    if (!experiment.data) return;
    await mutation.mutateAsync({
      experimentId: experiment.data.id,
    });
    await utils.promptVariants.list.invalidate();
  }, [mutation]);

  const { canModify } = useExperimentAccess();
  if (!canModify) return <Box w={cellPadding.x} />;

  return (
    <Flex w="100%" justifyContent="flex-end">
      <ActionButton
        onClick={onClick}
        leftIcon={<Icon as={loading ? Spinner : BsPlus} boxSize={6} mr={loading ? 1 : 0} />}
      >
        Add Variant
      </ActionButton>
      {/* <Button
        alignItems="center"
        justifyContent="center"
        fontWeight="normal"
        bgColor="transparent"
        _hover={{ bgColor: "gray.100" }}
        px={cellPadding.x}
        onClick={onClick}
        height="unset"
        minH={headerMinHeight}
      >
        <Icon as={loading ? Spinner : BsPlus} boxSize={6} mr={loading ? 1 : 0} />
        <Text display={{ base: "none", md: "flex" }}>Add Variant</Text>
      </Button> */}
    </Flex>
  );
}
@@ -11,14 +11,16 @@ import {
   FormLabel,
   Select,
   FormHelperText,
+  Code,
 } from "@chakra-ui/react";
-import { type Evaluation, EvaluationMatchType } from "@prisma/client";
+import { type Evaluation, EvalType } from "@prisma/client";
 import { useCallback, useState } from "react";
 import { BsPencil, BsX } from "react-icons/bs";
 import { api } from "~/utils/api";
 import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
+import AutoResizeTextArea from "../AutoResizeTextArea";

-type EvalValues = Pick<Evaluation, "name" | "matchString" | "matchType">;
+type EvalValues = Pick<Evaluation, "label" | "value" | "evalType">;

 export function EvaluationEditor(props: {
   evaluation: Evaluation | null;
@@ -27,35 +29,35 @@ export function EvaluationEditor(props: {
   onCancel: () => void;
 }) {
   const [values, setValues] = useState<EvalValues>({
-    name: props.evaluation?.name ?? props.defaultName ?? "",
+    label: props.evaluation?.label ?? props.defaultName ?? "",
-    matchString: props.evaluation?.matchString ?? "",
+    value: props.evaluation?.value ?? "",
-    matchType: props.evaluation?.matchType ?? "CONTAINS",
+    evalType: props.evaluation?.evalType ?? "CONTAINS",
   });

   return (
     <VStack borderTopWidth={1} borderColor="gray.200" py={4}>
       <HStack w="100%">
         <FormControl flex={1}>
-          <FormLabel fontSize="sm">Evaluation Name</FormLabel>
+          <FormLabel fontSize="sm">Eval Name</FormLabel>
           <Input
             size="sm"
-            value={values.name}
+            value={values.label}
-            onChange={(e) => setValues((values) => ({ ...values, name: e.target.value }))}
+            onChange={(e) => setValues((values) => ({ ...values, label: e.target.value }))}
           />
         </FormControl>
         <FormControl flex={1}>
-          <FormLabel fontSize="sm">Match Type</FormLabel>
+          <FormLabel fontSize="sm">Eval Type</FormLabel>
           <Select
             size="sm"
-            value={values.matchType}
+            value={values.evalType}
             onChange={(e) =>
               setValues((values) => ({
                 ...values,
-                matchType: e.target.value as EvaluationMatchType,
+                evalType: e.target.value as EvalType,
               }))
             }
           >
-            {Object.values(EvaluationMatchType).map((type) => (
+            {Object.values(EvalType).map((type) => (
               <option key={type} value={type}>
                 {type}
               </option>
@@ -63,17 +65,37 @@ export function EvaluationEditor(props: {
           </Select>
         </FormControl>
       </HStack>
+      {["CONTAINS", "DOES_NOT_CONTAIN"].includes(values.evalType) && (
       <FormControl>
         <FormLabel fontSize="sm">Match String</FormLabel>
         <Input
           size="sm"
-          value={values.matchString}
+          value={values.value}
-          onChange={(e) => setValues((values) => ({ ...values, matchString: e.target.value }))}
+          onChange={(e) => setValues((values) => ({ ...values, value: e.target.value }))}
         />
         <FormHelperText>
-          This string will be interpreted as a regex and checked against each model output.
+          This string will be interpreted as a regex and checked against each model output. You
+          can include scenario variables using <Code>{"{{curly_braces}}"}</Code>
         </FormHelperText>
       </FormControl>
+      )}
+      {values.evalType === "GPT4_EVAL" && (
+        <FormControl pt={2}>
+          <FormLabel fontSize="sm">GPT4 Instructions</FormLabel>
+          <AutoResizeTextArea
+            size="sm"
+            value={values.value}
+            onChange={(e) => setValues((values) => ({ ...values, value: e.target.value }))}
+            minRows={3}
+          />
+          <FormHelperText>
+            Give instructions to GPT-4 for how to evaluate your prompt. It will have access to the
+            full scenario as well as the output it is evaluating. It will <strong>not</strong> have
+            access to the specific prompt variant, so be sure to be clear about the task you want it
+            to perform.
+          </FormHelperText>
+        </FormControl>
+      )}
       <HStack alignSelf="flex-end">
         <Button size="sm" onClick={props.onCancel} colorScheme="gray">
           Cancel
@@ -125,6 +147,7 @@ export default function EditEvaluations() {
     }
     await utils.evaluations.list.invalidate();
     await utils.promptVariants.stats.invalidate();
+    await utils.scenarioVariantCells.get.invalidate();
   }, []);

   const onCancel = useCallback(() => {
@@ -156,9 +179,9 @@ export default function EditEvaluations() {
             align="center"
             key={evaluation.id}
           >
-            <Text fontWeight="bold">{evaluation.name}</Text>
+            <Text fontWeight="bold">{evaluation.label}</Text>
             <Text flex={1}>
-              {evaluation.matchType}: "{evaluation.matchString}"
+              {evaluation.evalType}: "{evaluation.value}"
             </Text>
             <Button
               variant="unstyled"
@@ -1,4 +1,4 @@
-import { Text, Button, HStack, Heading, Icon, Input, Stack, Code } from "@chakra-ui/react";
+import { Text, Button, HStack, Heading, Icon, Input, Stack } from "@chakra-ui/react";
 import { useState } from "react";
 import { BsCheck, BsX } from "react-icons/bs";
 import { api } from "~/utils/api";
@@ -36,8 +36,7 @@ export default function EditScenarioVars() {
       <Heading size="sm">Scenario Variables</Heading>
       <Stack spacing={2}>
         <Text fontSize="sm">
-          Scenario variables can be used in your prompt variants as well as evaluations. Reference
-          them using <Code>{"{{curly_braces}}"}</Code>.
+          Scenario variables can be used in your prompt variants as well as evaluations.
         </Text>
         <HStack spacing={0}>
           <Input
47 src/components/OutputsTable/FloatingLabelInput.tsx Normal file
@@ -0,0 +1,47 @@
import { FormLabel, FormControl, type TextareaProps } from "@chakra-ui/react";
import { useState } from "react";
import AutoResizeTextArea from "../AutoResizeTextArea";

export const FloatingLabelInput = ({
  label,
  value,
  ...props
}: { label: string; value: string } & TextareaProps) => {
  const [isFocused, setIsFocused] = useState(false);

  return (
    <FormControl position="relative">
      <FormLabel
        position="absolute"
        left="10px"
        top={isFocused || !!value ? 0 : 3}
        transform={isFocused || !!value ? "translateY(-50%)" : "translateY(0)"}
        fontSize={isFocused || !!value ? "12px" : "16px"}
        transition="all 0.15s"
        zIndex="5"
        bg="white"
        px={1}
        lineHeight="1"
        pointerEvents="none"
        color={isFocused ? "blue.500" : "gray.500"}
      >
        {label}
      </FormLabel>
      <AutoResizeTextArea
        px={3}
        pt={3}
        pb={2}
        onFocus={() => setIsFocused(true)}
        onBlur={() => setIsFocused(false)}
        borderRadius="md"
        borderColor={isFocused ? "blue.500" : "gray.400"}
        autoComplete="off"
        value={value}
        maxHeight={32}
        overflowY="auto"
        overflowX="hidden"
        {...props}
      />
    </FormControl>
  );
};
@@ -1,53 +0,0 @@
import { Button, type ButtonProps, HStack, Spinner, Icon } from "@chakra-ui/react";
import { BsPlus } from "react-icons/bs";
import { api } from "~/utils/api";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";

// Extracted Button styling into reusable component
const StyledButton = ({ children, onClick }: ButtonProps) => (
  <Button
    fontWeight="normal"
    bgColor="transparent"
    _hover={{ bgColor: "gray.100" }}
    px={2}
    onClick={onClick}
  >
    {children}
  </Button>
);

export default function NewScenarioButton() {
  const experiment = useExperiment();
  const mutation = api.scenarios.create.useMutation();
  const utils = api.useContext();

  const [onClick] = useHandledAsyncCallback(async () => {
    if (!experiment.data) return;
    await mutation.mutateAsync({
      experimentId: experiment.data.id,
    });
    await utils.scenarios.list.invalidate();
  }, [mutation]);

  const [onAutogenerate, autogenerating] = useHandledAsyncCallback(async () => {
    if (!experiment.data) return;
    await mutation.mutateAsync({
      experimentId: experiment.data.id,
      autogenerate: true,
    });
    await utils.scenarios.list.invalidate();
  }, [mutation]);

  return (
    <HStack spacing={2}>
      <StyledButton onClick={onClick}>
        <Icon as={BsPlus} boxSize={6} />
        Add Scenario
      </StyledButton>
      <StyledButton onClick={onAutogenerate}>
        <Icon as={autogenerating ? Spinner : BsPlus} boxSize={6} mr={autogenerating ? 1 : 0} />
        Autogenerate Scenario
      </StyledButton>
    </HStack>
  );
}
@@ -1,37 +0,0 @@
import { Button, Icon, Spinner } from "@chakra-ui/react";
import { BsPlus } from "react-icons/bs";
import { api } from "~/utils/api";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
import { cellPadding, headerMinHeight } from "../constants";

export default function NewVariantButton() {
  const experiment = useExperiment();
  const mutation = api.promptVariants.create.useMutation();
  const utils = api.useContext();

  const [onClick, loading] = useHandledAsyncCallback(async () => {
    if (!experiment.data) return;
    await mutation.mutateAsync({
      experimentId: experiment.data.id,
    });
    await utils.promptVariants.list.invalidate();
  }, [mutation]);

  return (
    <Button
      w="100%"
      alignItems="center"
      justifyContent="center"
      fontWeight="normal"
      bgColor="transparent"
      _hover={{ bgColor: "gray.100" }}
      px={cellPadding.x}
      onClick={onClick}
      height="unset"
      minH={headerMinHeight}
    >
      <Icon as={loading ? Spinner : BsPlus} boxSize={6} mr={loading ? 1 : 0} />
      Add Variant
    </Button>
  );
}
37 src/components/OutputsTable/OutputCell/CellOptions.tsx Normal file
@@ -0,0 +1,37 @@
import { Button, HStack, Icon, Tooltip } from "@chakra-ui/react";
import { BsArrowClockwise } from "react-icons/bs";
import { useExperimentAccess } from "~/utils/hooks";

export const CellOptions = ({
  refetchingOutput,
  refetchOutput,
}: {
  refetchingOutput: boolean;
  refetchOutput: () => void;
}) => {
  const { canModify } = useExperimentAccess();
  return (
    <HStack justifyContent="flex-end" w="full">
      {!refetchingOutput && canModify && (
        <Tooltip label="Refetch output" aria-label="refetch output">
          <Button
            size="xs"
            w={4}
            h={4}
            py={4}
            px={4}
            minW={0}
            borderRadius={8}
            color="gray.500"
            variant="ghost"
            cursor="pointer"
            onClick={refetchOutput}
            aria-label="refetch output"
          >
            <Icon as={BsArrowClockwise} boxSize={4} />
          </Button>
        </Tooltip>
      )}
    </HStack>
  );
};
@@ -1,29 +1,21 @@
-import { type ModelOutput } from "@prisma/client";
+import { type ScenarioVariantCell } from "@prisma/client";
-import { HStack, VStack, Text, Button, Icon } from "@chakra-ui/react";
+import { VStack, Text } from "@chakra-ui/react";
 import { useEffect, useState } from "react";
-import { BsArrowClockwise } from "react-icons/bs";
-import { rateLimitErrorMessage } from "~/sharedStrings";
 import pluralize from "pluralize";

-const MAX_AUTO_RETRIES = 3;

 export const ErrorHandler = ({
-  output,
+  cell,
   refetchOutput,
-  numPreviousTries,
 }: {
-  output: ModelOutput;
+  cell: ScenarioVariantCell;
   refetchOutput: () => void;
-  numPreviousTries: number;
 }) => {
   const [msToWait, setMsToWait] = useState(0);
-  const shouldAutoRetry =
-    output.errorMessage === rateLimitErrorMessage && numPreviousTries < MAX_AUTO_RETRIES;

   useEffect(() => {
-    if (!shouldAutoRetry) return;
+    if (!cell.retryTime) return;

-    const initialWaitTime = calculateDelay(numPreviousTries);
+    const initialWaitTime = cell.retryTime.getTime() - Date.now();
     const msModuloOneSecond = initialWaitTime % 1000;
     let remainingTime = initialWaitTime - msModuloOneSecond;
     setMsToWait(remainingTime);
@@ -35,7 +27,6 @@ export const ErrorHandler = ({
       setMsToWait(remainingTime);

       if (remainingTime <= 0) {
-        refetchOutput();
         clearInterval(interval);
       }
     }, 1000);
@@ -45,32 +36,12 @@ export const ErrorHandler = ({
       clearInterval(interval);
       clearTimeout(timeout);
     };
-  }, [shouldAutoRetry, setMsToWait, refetchOutput, numPreviousTries]);
+  }, [cell.retryTime, cell.statusCode, setMsToWait, refetchOutput]);

   return (
     <VStack w="full">
-      <HStack w="full" alignItems="flex-start" justifyContent="space-between">
-        <Text color="red.600" fontWeight="bold">
-          Error
-        </Text>
-        <Button
-          size="xs"
-          w={4}
-          h={4}
-          px={4}
-          py={4}
-          minW={0}
-          borderRadius={8}
-          variant="ghost"
-          cursor="pointer"
-          onClick={refetchOutput}
-          aria-label="refetch output"
-        >
-          <Icon as={BsArrowClockwise} boxSize={6} />
-        </Button>
-      </HStack>
       <Text color="red.600" wordBreak="break-word">
-        {output.errorMessage}
+        {cell.errorMessage}
       </Text>
       {msToWait > 0 && (
         <Text color="red.600" fontSize="sm">
@@ -80,12 +51,3 @@ export const ErrorHandler = ({
     </VStack>
   );
 };
-
-const MIN_DELAY = 500; // milliseconds
-const MAX_DELAY = 5000; // milliseconds
-
-function calculateDelay(numPreviousTries: number): number {
-  const baseDelay = Math.min(MAX_DELAY, MIN_DELAY * Math.pow(2, numPreviousTries));
-  const jitter = Math.random() * baseDelay;
-  return baseDelay + jitter;
-}
@@ -1,17 +1,16 @@
-import { type RouterOutputs, api } from "~/utils/api";
+import { api } from "~/utils/api";
 import { type PromptVariant, type Scenario } from "../types";
-import { Spinner, Text, Box, Center, Flex } from "@chakra-ui/react";
+import { Spinner, Text, Center, VStack } from "@chakra-ui/react";
 import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
 import SyntaxHighlighter from "react-syntax-highlighter";
 import { docco } from "react-syntax-highlighter/dist/cjs/styles/hljs";
 import stringify from "json-stringify-pretty-compact";
-import { type ReactElement, useState, useEffect, useRef, useCallback } from "react";
+import { type ReactElement, useState, useEffect } from "react";
-import { type ChatCompletion } from "openai/resources/chat";
-import { generateChannel } from "~/utils/generateChannel";
-import { isObject } from "lodash";
 import useSocket from "~/utils/useSocket";
 import { OutputStats } from "./OutputStats";
 import { ErrorHandler } from "./ErrorHandler";
+import { CellOptions } from "./CellOptions";
+import frontendModelProviders from "~/modelProviders/frontendModelProviders";

 export default function OutputCell({
   scenario,
@@ -34,97 +33,80 @@ export default function OutputCell({

   if (!templateHasVariables) disabledReason = "Add a value to the scenario variables to see output";

-  // if (variant.config === null || Object.keys(variant.config).length === 0)
+  const [refetchInterval, setRefetchInterval] = useState(0);
-  //   disabledReason = "Save your prompt variant to see output";
+  const { data: cell, isLoading: queryLoading } = api.scenarioVariantCells.get.useQuery(
+    { scenarioId: scenario.id, variantId: variant.id },
+    { refetchInterval },
+  );

-  // const model = getModelName(variant.config as JSONSerializable);
+  const provider =
-  // TODO: Temporarily hardcoding this while we get other stuff working
+    frontendModelProviders[variant.modelProvider as keyof typeof frontendModelProviders];
-  const model = "gpt-3.5-turbo";

-  const outputMutation = api.outputs.get.useMutation();
+  type OutputSchema = Parameters<typeof provider.normalizeOutput>[0];

-  const [output, setOutput] = useState<RouterOutputs["outputs"]["get"]>(null);
+  const { mutateAsync: hardRefetchMutate } = api.scenarioVariantCells.forceRefetch.useMutation();
-  const [channel, setChannel] = useState<string | undefined>(undefined);
+  const [hardRefetch, hardRefetching] = useHandledAsyncCallback(async () => {
-  const [numPreviousTries, setNumPreviousTries] = useState(0);
+    await hardRefetchMutate({ scenarioId: scenario.id, variantId: variant.id });
+    await utils.scenarioVariantCells.get.invalidate({
-  const fetchMutex = useRef(false);
-  const [fetchOutput, fetchingOutput] = useHandledAsyncCallback(
-    async (forceRefetch?: boolean) => {
-      if (fetchMutex.current) return;
-      setNumPreviousTries((prev) => prev + 1);
-
-      fetchMutex.current = true;
-      setOutput(null);
-
-      const shouldStream =
-        isObject(variant) &&
-        "config" in variant &&
-        isObject(variant.config) &&
-        "stream" in variant.config &&
-        variant.config.stream === true;
-
-      const channel = shouldStream ? generateChannel() : undefined;
-      setChannel(channel);
-
-      const output = await outputMutation.mutateAsync({
       scenarioId: scenario.id,
       variantId: variant.id,
-        channel,
-        forceRefetch,
     });
-      setOutput(output);
+    await utils.promptVariants.stats.invalidate({
-      await utils.promptVariants.stats.invalidate();
+      variantId: variant.id,
-      fetchMutex.current = false;
+    });
-    },
+  }, [hardRefetchMutate, scenario.id, variant.id]);
-    [outputMutation, scenario.id, variant.id],
-  );
-  const hardRefetch = useCallback(() => fetchOutput(true), [fetchOutput]);

-  useEffect(fetchOutput, [scenario.id, variant.id]);
+  const fetchingOutput = queryLoading || hardRefetching;

+  const awaitingOutput =
+    !cell ||
+    cell.retrievalStatus === "PENDING" ||
+    cell.retrievalStatus === "IN_PROGRESS" ||
+    hardRefetching;
+  useEffect(() => setRefetchInterval(awaitingOutput ? 1000 : 0), [awaitingOutput]);
+
+  const modelOutput = cell?.modelOutput;

   // Disconnect from socket if we're not streaming anymore
-  const streamedMessage = useSocket(fetchingOutput ? channel : undefined);
+  const streamedMessage = useSocket<OutputSchema>(cell?.streamingChannel);
-  const streamedContent = streamedMessage?.choices?.[0]?.message?.content;

   if (!vars) return null;

   if (disabledReason) return <Text color="gray.500">{disabledReason}</Text>;

-  if (fetchingOutput && !streamedMessage)
+  if (awaitingOutput && !streamedMessage)
     return (
       <Center h="100%" w="100%">
         <Spinner />
       </Center>
     );

-  if (!output && !fetchingOutput) return <Text color="gray.500">Error retrieving output</Text>;
+  if (!cell && !fetchingOutput) return <Text color="gray.500">Error retrieving output</Text>;

-  if (output && output.errorMessage) {
+  if (cell && cell.errorMessage) {
-    return (
+    return <ErrorHandler cell={cell} refetchOutput={hardRefetch} />;
-      <ErrorHandler
-        output={output}
-        refetchOutput={hardRefetch}
-        numPreviousTries={numPreviousTries}
-      />
-    );
   }

-  const response = output?.output as unknown as ChatCompletion;
+  const normalizedOutput = modelOutput
-  const message = response?.choices?.[0]?.message;
+    ? provider.normalizeOutput(modelOutput.output)
+    : streamedMessage
-  if (output && message?.function_call) {
+    ? provider.normalizeOutput(streamedMessage)
-    const rawArgs = message.function_call.arguments ?? "null";
+    : null;
-    let parsedArgs: string;
-    try {
-      parsedArgs = JSON.parse(rawArgs);
-    } catch (e: any) {
-      parsedArgs = `Failed to parse arguments as JSON: '${rawArgs}' ERROR: ${e.message as string}`;
-    }

+  if (modelOutput && normalizedOutput?.type === "json") {
     return (
-      <Box fontSize="xs" width="100%" flexWrap="wrap" overflowX="auto">
+      <VStack
+        w="100%"
+        h="100%"
+        fontSize="xs"
+        flexWrap="wrap"
+        overflowX="hidden"
+        justifyContent="space-between"
+      >
+        <VStack w="full" flex={1} spacing={0}>
+          <CellOptions refetchingOutput={hardRefetching} refetchOutput={hardRefetch} />
         <SyntaxHighlighter
-          customStyle={{ overflowX: "unset" }}
+          customStyle={{ overflowX: "unset", width: "100%", flex: 1 }}
           language="json"
           style={docco}
           lineProps={{
@@ -132,25 +114,23 @@ export default function OutputCell({
           }}
           wrapLines
         >
-          {stringify(
+          {stringify(normalizedOutput.value, { maxLength: 40 })}
-            {
-              function: message.function_call.name,
-              args: parsedArgs,
-            },
-            { maxLength: 40 },
-          )}
         </SyntaxHighlighter>
-        <OutputStats model={model} modelOutput={output} scenario={scenario} />
+        </VStack>
-      </Box>
+        <OutputStats modelOutput={modelOutput} scenario={scenario} />
+      </VStack>
     );
   }

-  const contentToDisplay = message?.content ?? streamedContent ?? JSON.stringify(output?.output);
+  const contentToDisplay = (normalizedOutput?.type === "text" && normalizedOutput.value) || "";

   return (
-    <Flex w="100%" h="100%" direction="column" justifyContent="space-between" whiteSpace="pre-wrap">
+    <VStack w="100%" h="100%" justifyContent="space-between" whiteSpace="pre-wrap">
-      {contentToDisplay}
+      <VStack w="full" alignItems="flex-start" spacing={0}>
-      {output && <OutputStats model={model} modelOutput={output} scenario={scenario} />}
+        <CellOptions refetchingOutput={hardRefetching} refetchOutput={hardRefetch} />
-    </Flex>
+        <Text>{contentToDisplay}</Text>
+      </VStack>
+      {modelOutput && <OutputStats modelOutput={modelOutput} scenario={scenario} />}
+    </VStack>
   );
 }
@@ -1,64 +1,56 @@
-import { type ModelOutput } from "@prisma/client";
-import { type SupportedModel } from "~/server/types";
 import { type Scenario } from "../types";
-import { useExperiment } from "~/utils/hooks";
+import { type RouterOutputs } from "~/utils/api";
-import { api } from "~/utils/api";
+import { HStack, Icon, Text, Tooltip } from "@chakra-ui/react";
-import { calculateTokenCost } from "~/utils/calculateTokenCost";
-import { evaluateOutput } from "~/server/utils/evaluateOutput";
-import { HStack, Icon, Text } from "@chakra-ui/react";
 import { BsCheck, BsClock, BsCurrencyDollar, BsX } from "react-icons/bs";
 import { CostTooltip } from "~/components/tooltip/CostTooltip";

-const SHOW_COST = false;
+const SHOW_TIME = true;
-const SHOW_TIME = false;

 export const OutputStats = ({
-  model,
   modelOutput,
-  scenario,
 }: {
-  model: SupportedModel | null;
+  modelOutput: NonNullable<
-  modelOutput: ModelOutput;
+    NonNullable<RouterOutputs["scenarioVariantCells"]["get"]>["modelOutput"]
+  >;
   scenario: Scenario;
 }) => {
   const timeToComplete = modelOutput.timeToComplete;
-  const experiment = useExperiment();
-  const evals =
-    api.evaluations.list.useQuery({ experimentId: experiment.data?.id ?? "" }).data ?? [];

   const promptTokens = modelOutput.promptTokens;
   const completionTokens = modelOutput.completionTokens;

-  const promptCost = promptTokens && model ? calculateTokenCost(model, promptTokens) : 0;
-  const completionCost =
-    completionTokens && model ? calculateTokenCost(model, completionTokens, true) : 0;
-
-  const cost = promptCost + completionCost;
-
-  if (!evals.length) return null;

   return (
-    <HStack align="center" color="gray.500" fontSize="2xs" mt={{ base: 0, md: 1 }}>
+    <HStack w="full" align="center" color="gray.500" fontSize="2xs" mt={{ base: 0, md: 1 }}>
       <HStack flex={1}>
-        {evals.map((evaluation) => {
+        {modelOutput.outputEvaluation.map((evaluation) => {
-          const passed = evaluateOutput(modelOutput, scenario, evaluation);
+          const passed = evaluation.result > 0.5;
           return (
-            <HStack spacing={0} key={evaluation.id}>
+            <Tooltip
-              <Text>{evaluation.name}</Text>
+              isDisabled={!evaluation.details}
+              label={evaluation.details}
+              key={evaluation.id}
+            >
+              <HStack spacing={0}>
+                <Text>{evaluation.evaluation.label}</Text>
                 <Icon
                   as={passed ? BsCheck : BsX}
                   color={passed ? "green.500" : "red.500"}
                   boxSize={6}
                 />
               </HStack>
+            </Tooltip>
           );
         })}
       </HStack>
-      {SHOW_COST && (
+      {modelOutput.cost && (
-        <CostTooltip promptTokens={promptTokens} completionTokens={completionTokens} cost={cost}>
+        <CostTooltip
+          promptTokens={promptTokens}
+          completionTokens={completionTokens}
+          cost={modelOutput.cost}
+        >
           <HStack spacing={0}>
             <Icon as={BsCurrencyDollar} />
-            <Text mr={1}>{cost.toFixed(3)}</Text>
+            <Text mr={1}>{modelOutput.cost.toFixed(3)}</Text>
           </HStack>
         </CostTooltip>
       )}
@@ -1,23 +1,26 @@
 import { type DragEvent } from "react";
 import { api } from "~/utils/api";
-import { isEqual } from "lodash";
+import { isEqual } from "lodash-es";
 import { type Scenario } from "./types";
-import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
+import { useExperiment, useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
 import { useState } from "react";

 import { Box, Button, Flex, HStack, Icon, Spinner, Stack, Tooltip, VStack } from "@chakra-ui/react";
 import { cellPadding } from "../constants";
 import { BsX } from "react-icons/bs";
 import { RiDraggable } from "react-icons/ri";
-import AutoResizeTextArea from "../AutoResizeTextArea";
+import { FloatingLabelInput } from "./FloatingLabelInput";

 export default function ScenarioEditor({
   scenario,
-  hovered,
+  ...props
 }: {
   scenario: Scenario;
   hovered: boolean;
+  canHide: boolean;
 }) {
+  const { canModify } = useExperimentAccess();

   const savedValues = scenario.variableValues as Record<string, string>;
   const utils = api.useContext();
   const [isDragTarget, setIsDragTarget] = useState(false);
@@ -71,8 +74,9 @@ export default function ScenarioEditor({
   return (
     <HStack
       alignItems="flex-start"
-      pr={cellPadding.x}
+      px={cellPadding.x}
       py={cellPadding.y}
+      spacing={0}
       height="100%"
       draggable={!variableInputHovered}
       onDragStart={(e) => {
@@ -92,7 +96,13 @@ export default function ScenarioEditor({
       onDrop={onReorder}
       backgroundColor={isDragTarget ? "gray.100" : "transparent"}
     >
-      <Stack alignSelf="flex-start" opacity={hovered ? 1 : 0} spacing={0}>
+      {canModify && props.canHide && (
+        <Stack
+          alignSelf="flex-start"
+          opacity={props.hovered ? 1 : 0}
+          spacing={0}
+          ml={-cellPadding.x}
+        >
         <Tooltip label="Hide scenario" hasArrow>
           {/* for some reason the tooltip can't position itself properly relative to the icon without the wrapping box */}
           <Button
@@ -107,7 +117,7 @@ export default function ScenarioEditor({
               cursor: "pointer",
             }}
           >
-            <Icon as={hidingInProgress ? Spinner : BsX} boxSize={6} />
+            <Icon as={hidingInProgress ? Spinner : BsX} boxSize={hidingInProgress ? 4 : 6} />
           </Button>
         </Tooltip>
         <Icon
@@ -117,10 +127,12 @@ export default function ScenarioEditor({
           _hover={{ color: "gray.800", cursor: "pointer" }}
         />
       </Stack>
+      )}

       {variableLabels.length === 0 ? (
         <Box color="gray.500">{vars.data ? "No scenario variables configured" : "Loading..."}</Box>
       ) : (
-        <VStack spacing={1}>
+        <VStack spacing={4} flex={1} py={2}>
           {variableLabels.map((key) => {
             const value = values[key] ?? "";
             const layoutDirection = value.length > 20 ? "column" : "row";
@@ -132,29 +144,14 @@ export default function ScenarioEditor({
                 flexWrap="wrap"
                 width="full"
               >
-                <Box
+                <FloatingLabelInput
-                  bgColor="blue.100"
+                  label={key}
-                  color="blue.600"
+                  isDisabled={!canModify}
-                  px={1}
+                  style={{ width: "100%" }}
-                  my="3px"
-                  fontSize="xs"
-                  fontWeight="bold"
-                >
-                  {key}
-                </Box>
-                <AutoResizeTextArea
-                  px={2}
-                  py={1}
-                  placeholder="empty"
-                  borderRadius="sm"
-                  fontSize="sm"
-                  lineHeight={1.2}
                   value={value}
                   onChange={(e) => {
                     setValues((prev) => ({ ...prev, [key]: e.target.value }));
                   }}
-                  maxH="32"
-                  overflowY="auto"
                   onKeyDown={(e) => {
                     if (e.key === "Enter" && (e.metaKey || e.ctrlKey)) {
                       e.preventDefault();
@@ -162,12 +159,6 @@ export default function ScenarioEditor({
                       onSave();
                     }
                   }}
-                  resize="none"
-                  overflow="hidden"
-                  flex={layoutDirection === "row" ? 1 : undefined}
-                  borderColor={hasChanged ? "blue.300" : "transparent"}
-                  _hover={{ borderColor: "gray.300" }}
-                  _focus={{ borderColor: "blue.500", outline: "none", bg: "white" }}
                   onMouseEnter={() => setVariableInputHovered(true)}
                   onMouseLeave={() => setVariableInputHovered(false)}
                 />
74
src/components/OutputsTable/ScenarioPaginator.tsx
Normal file
74
src/components/OutputsTable/ScenarioPaginator.tsx
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
import { Box, HStack, IconButton } from "@chakra-ui/react";
|
||||||
|
import {
|
||||||
|
BsChevronDoubleLeft,
|
||||||
|
BsChevronDoubleRight,
|
||||||
|
BsChevronLeft,
|
||||||
|
BsChevronRight,
|
||||||
|
} from "react-icons/bs";
|
||||||
|
import { usePage, useScenarios } from "~/utils/hooks";
|
||||||
|
|
||||||
|
const ScenarioPaginator = () => {
|
||||||
|
const [page, setPage] = usePage();
|
||||||
|
const { data } = useScenarios();
|
||||||
|
|
||||||
|
if (!data) return null;
|
||||||
|
|
||||||
|
const { scenarios, startIndex, lastPage, count } = data;
|
||||||
|
|
||||||
|
const nextPage = () => {
|
||||||
|
if (page < lastPage) {
|
||||||
|
setPage(page + 1, "replace");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const prevPage = () => {
|
||||||
|
if (page > 1) {
|
||||||
|
setPage(page - 1, "replace");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const goToLastPage = () => setPage(lastPage, "replace");
|
||||||
|
const goToFirstPage = () => setPage(1, "replace");
|
||||||
|
|
||||||
|
return (
|
||||||
|
<HStack pt={4}>
|
||||||
|
<IconButton
|
||||||
|
variant="ghost"
|
||||||
|
size="sm"
|
||||||
|
onClick={goToFirstPage}
|
||||||
|
isDisabled={page === 1}
|
||||||
|
aria-label="Go to first page"
|
||||||
|
icon={<BsChevronDoubleLeft />}
|
||||||
|
/>
|
||||||
|
<IconButton
|
||||||
|
variant="ghost"
|
||||||
|
size="sm"
|
||||||
|
onClick={prevPage}
|
||||||
|
isDisabled={page === 1}
|
||||||
|
aria-label="Previous page"
|
||||||
|
icon={<BsChevronLeft />}
|
||||||
|
/>
|
||||||
|
<Box>
|
||||||
|
{startIndex}-{startIndex + scenarios.length - 1} / {count}
|
||||||
|
</Box>
|
||||||
|
<IconButton
|
||||||
|
variant="ghost"
|
||||||
|
size="sm"
|
||||||
|
onClick={nextPage}
|
||||||
|
isDisabled={page === lastPage}
|
||||||
|
aria-label="Next page"
|
||||||
|
icon={<BsChevronRight />}
|
||||||
|
/>
|
||||||
|
<IconButton
|
||||||
|
variant="ghost"
|
||||||
|
size="sm"
|
||||||
|
onClick={goToLastPage}
|
||||||
|
isDisabled={page === lastPage}
|
||||||
|
aria-label="Go to last page"
|
||||||
|
icon={<BsChevronDoubleRight />}
|
||||||
|
/>
|
||||||
|
</HStack>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default ScenarioPaginator;
|
||||||
@@ -4,8 +4,14 @@ import { cellPadding } from "../constants";
|
|||||||
import OutputCell from "./OutputCell/OutputCell";
|
import OutputCell from "./OutputCell/OutputCell";
|
||||||
import ScenarioEditor from "./ScenarioEditor";
|
import ScenarioEditor from "./ScenarioEditor";
|
||||||
import type { PromptVariant, Scenario } from "./types";
|
import type { PromptVariant, Scenario } from "./types";
|
||||||
|
import { borders } from "./styles";
|
||||||
|
|
||||||
const ScenarioRow = (props: { scenario: Scenario; variants: PromptVariant[] }) => {
|
const ScenarioRow = (props: {
|
||||||
|
scenario: Scenario;
|
||||||
|
variants: PromptVariant[];
|
||||||
|
canHide: boolean;
|
||||||
|
rowStart: number;
|
||||||
|
}) => {
|
||||||
const [isHovered, setIsHovered] = useState(false);
|
const [isHovered, setIsHovered] = useState(false);
|
||||||
|
|
||||||
const highlightStyle = { backgroundColor: "gray.50" };
|
const highlightStyle = { backgroundColor: "gray.50" };
|
||||||
@@ -17,15 +23,21 @@ const ScenarioRow = (props: { scenario: Scenario; variants: PromptVariant[] }) =
|
|||||||
onMouseLeave={() => setIsHovered(false)}
|
onMouseLeave={() => setIsHovered(false)}
|
||||||
sx={isHovered ? highlightStyle : undefined}
|
sx={isHovered ? highlightStyle : undefined}
|
||||||
borderLeftWidth={1}
|
borderLeftWidth={1}
|
||||||
|
{...borders}
|
||||||
|
rowStart={props.rowStart}
|
||||||
|
colStart={1}
|
||||||
>
|
>
|
||||||
<ScenarioEditor scenario={props.scenario} hovered={isHovered} />
|
<ScenarioEditor scenario={props.scenario} hovered={isHovered} canHide={props.canHide} />
|
||||||
</GridItem>
|
</GridItem>
|
||||||
{props.variants.map((variant) => (
|
{props.variants.map((variant, i) => (
|
||||||
<GridItem
|
<GridItem
|
||||||
key={variant.id}
|
key={variant.id}
|
||||||
onMouseEnter={() => setIsHovered(true)}
|
onMouseEnter={() => setIsHovered(true)}
|
||||||
onMouseLeave={() => setIsHovered(false)}
|
onMouseLeave={() => setIsHovered(false)}
|
||||||
sx={isHovered ? highlightStyle : undefined}
|
sx={isHovered ? highlightStyle : undefined}
|
||||||
|
rowStart={props.rowStart}
|
||||||
|
colStart={i + 2}
|
||||||
|
{...borders}
|
||||||
>
|
>
|
||||||
<Box h="100%" w="100%" px={cellPadding.x} py={cellPadding.y}>
|
<Box h="100%" w="100%" px={cellPadding.x} py={cellPadding.y}>
|
||||||
<OutputCell key={variant.id} scenario={props.scenario} variant={variant} />
|
<OutputCell key={variant.id} scenario={props.scenario} variant={variant} />
|
||||||
|
|||||||
79
src/components/OutputsTable/ScenariosHeader.tsx
Normal file
79
src/components/OutputsTable/ScenariosHeader.tsx
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
import {
|
||||||
|
Button,
|
||||||
|
type ButtonProps,
|
||||||
|
HStack,
|
||||||
|
Text,
|
||||||
|
Icon,
|
||||||
|
Menu,
|
||||||
|
MenuButton,
|
||||||
|
MenuList,
|
||||||
|
MenuItem,
|
||||||
|
IconButton,
|
||||||
|
Spinner,
|
||||||
|
} from "@chakra-ui/react";
|
||||||
|
import { cellPadding } from "../constants";
|
||||||
|
import {
|
||||||
|
useExperiment,
|
||||||
|
useExperimentAccess,
|
||||||
|
useHandledAsyncCallback,
|
||||||
|
useScenarios,
|
||||||
|
} from "~/utils/hooks";
|
||||||
|
import { BsGear, BsPencil, BsPlus, BsStars } from "react-icons/bs";
|
||||||
|
import { useAppStore } from "~/state/store";
|
||||||
|
import { api } from "~/utils/api";
|
||||||
|
|
||||||
|
export const ActionButton = (props: ButtonProps) => (
|
||||||
|
<Button size="sm" variant="ghost" color="gray.600" {...props} />
|
||||||
|
);
|
||||||
|
|
||||||
|
export const ScenariosHeader = () => {
|
||||||
|
const openDrawer = useAppStore((s) => s.openDrawer);
|
||||||
|
const { canModify } = useExperimentAccess();
|
||||||
|
const scenarios = useScenarios();
|
||||||
|
|
||||||
|
const experiment = useExperiment();
|
||||||
|
const createScenarioMutation = api.scenarios.create.useMutation();
|
||||||
|
const utils = api.useContext();
|
||||||
|
|
||||||
|
const [onAddScenario, loading] = useHandledAsyncCallback(
|
||||||
|
async (autogenerate: boolean) => {
|
||||||
|
if (!experiment.data) return;
|
||||||
|
await createScenarioMutation.mutateAsync({
|
||||||
|
experimentId: experiment.data.id,
|
||||||
|
autogenerate,
|
||||||
|
});
|
||||||
|
await utils.scenarios.list.invalidate();
|
||||||
|
},
|
||||||
|
[createScenarioMutation],
|
||||||
|
);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<HStack w="100%" pb={cellPadding.y} pt={0} align="center" spacing={0}>
|
||||||
|
<Text fontSize={16} fontWeight="bold">
|
||||||
|
Scenarios ({scenarios.data?.count})
|
||||||
|
</Text>
|
||||||
|
{canModify && (
|
||||||
|
<Menu>
|
||||||
|
<MenuButton mt={1}>
|
||||||
|
<IconButton
|
||||||
|
variant="ghost"
|
||||||
|
aria-label="Edit Scenarios"
|
||||||
|
icon={<Icon as={loading ? Spinner : BsGear} />}
|
||||||
|
/>
|
||||||
|
</MenuButton>
|
||||||
|
<MenuList fontSize="md">
|
||||||
|
<MenuItem icon={<Icon as={BsPlus} boxSize={6} />} onClick={() => onAddScenario(false)}>
|
||||||
|
Add Scenario
|
||||||
|
</MenuItem>
|
||||||
|
<MenuItem icon={<BsStars />} onClick={() => onAddScenario(true)}>
|
||||||
|
Autogenerate Scenario
|
||||||
|
</MenuItem>
|
||||||
|
<MenuItem icon={<BsPencil />} onClick={openDrawer}>
|
||||||
|
Edit Vars
|
||||||
|
</MenuItem>
|
||||||
|
</MenuList>
|
||||||
|
</Menu>
|
||||||
|
)}
|
||||||
|
</HStack>
|
||||||
|
);
|
||||||
|
};
|
||||||
@@ -1,72 +1,108 @@
|
|||||||
import { Box, Button, HStack, Tooltip, useToast } from "@chakra-ui/react";
|
import {
|
||||||
|
Box,
|
||||||
|
Button,
|
||||||
|
HStack,
|
||||||
|
Spinner,
|
||||||
|
Tooltip,
|
||||||
|
useToast,
|
||||||
|
Text,
|
||||||
|
IconButton,
|
||||||
|
} from "@chakra-ui/react";
|
||||||
import { useRef, useEffect, useState, useCallback } from "react";
|
import { useRef, useEffect, useState, useCallback } from "react";
|
||||||
import { useHandledAsyncCallback, useModifierKeyLabel } from "~/utils/hooks";
|
import { useExperimentAccess, useHandledAsyncCallback, useModifierKeyLabel } from "~/utils/hooks";
|
||||||
import { type PromptVariant } from "./types";
|
import { type PromptVariant } from "./types";
|
||||||
import { api } from "~/utils/api";
|
import { api } from "~/utils/api";
|
||||||
import { useAppStore } from "~/state/store";
|
import { useAppStore } from "~/state/store";
|
||||||
// import openAITypes from "~/codegen/openai.types.ts.txt";
|
import { FiMaximize, FiMinimize } from "react-icons/fi";
|
||||||
|
import { editorBackground } from "~/state/sharedVariantEditor.slice";
|
||||||
|
|
||||||
export default function VariantConfigEditor(props: { variant: PromptVariant }) {
|
export default function VariantEditor(props: { variant: PromptVariant }) {
|
||||||
|
const { canModify } = useExperimentAccess();
|
||||||
const monaco = useAppStore.use.sharedVariantEditor.monaco();
|
const monaco = useAppStore.use.sharedVariantEditor.monaco();
|
||||||
const editorRef = useRef<ReturnType<NonNullable<typeof monaco>["editor"]["create"]> | null>(null);
|
const editorRef = useRef<ReturnType<NonNullable<typeof monaco>["editor"]["create"]> | null>(null);
|
||||||
|
const containerRef = useRef<HTMLDivElement | null>(null);
|
||||||
const [editorId] = useState(() => `editor_${Math.random().toString(36).substring(7)}`);
|
const [editorId] = useState(() => `editor_${Math.random().toString(36).substring(7)}`);
|
||||||
const [isChanged, setIsChanged] = useState(false);
|
const [isChanged, setIsChanged] = useState(false);
|
||||||
|
|
||||||
|
const [isFullscreen, setIsFullscreen] = useState(false);
|
||||||
|
|
||||||
|
const toggleFullscreen = useCallback(() => {
|
||||||
|
setIsFullscreen((prev) => !prev);
|
||||||
|
editorRef.current?.focus();
|
||||||
|
}, [setIsFullscreen]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
const handleEsc = (event: KeyboardEvent) => {
|
||||||
|
if (event.key === "Escape" && isFullscreen) {
|
||||||
|
toggleFullscreen();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
window.addEventListener("keydown", handleEsc);
|
||||||
|
return () => window.removeEventListener("keydown", handleEsc);
|
||||||
|
}, [isFullscreen, toggleFullscreen]);
|
||||||
|
|
||||||
const lastSavedFn = props.variant.constructFn;
|
const lastSavedFn = props.variant.constructFn;
|
||||||
|
|
||||||
const modifierKey = useModifierKeyLabel();
|
const modifierKey = useModifierKeyLabel();
|
||||||
|
|
||||||
const checkForChanges = useCallback(() => {
|
const checkForChanges = useCallback(() => {
|
||||||
if (!editorRef.current) return;
|
if (!editorRef.current) return;
|
||||||
const currentConfig = editorRef.current.getValue();
|
const currentFn = editorRef.current.getValue();
|
||||||
setIsChanged(currentConfig !== lastSavedFn);
|
setIsChanged(currentFn.length > 0 && currentFn !== lastSavedFn);
|
||||||
}, [lastSavedFn]);
|
}, [lastSavedFn]);
|
||||||
|
|
||||||
|
const matchUpdatedSavedFn = useCallback(() => {
|
||||||
|
if (!editorRef.current) return;
|
||||||
|
editorRef.current.setValue(lastSavedFn);
|
||||||
|
setIsChanged(false);
|
||||||
|
}, [lastSavedFn]);
|
||||||
|
|
||||||
|
useEffect(matchUpdatedSavedFn, [matchUpdatedSavedFn, lastSavedFn]);
|
||||||
|
|
||||||
const replaceVariant = api.promptVariants.replaceVariant.useMutation();
|
const replaceVariant = api.promptVariants.replaceVariant.useMutation();
|
||||||
const utils = api.useContext();
|
const utils = api.useContext();
|
||||||
const toast = useToast();
|
const toast = useToast();
|
||||||
|
|
||||||
const [onSave] = useHandledAsyncCallback(async () => {
|
const [onSave, saveInProgress] = useHandledAsyncCallback(async () => {
|
||||||
const currentFn = editorRef.current?.getValue();
|
if (!editorRef.current) return;
|
||||||
|
|
||||||
|
await editorRef.current.getAction("editor.action.formatDocument")?.run();
|
||||||
|
|
||||||
|
const currentFn = editorRef.current.getValue();
|
||||||
|
|
||||||
if (!currentFn) return;
|
if (!currentFn) return;
|
||||||
|
|
||||||
// Check if the editor has any typescript errors
|
// Check if the editor has any typescript errors
|
||||||
const model = editorRef.current?.getModel();
|
const model = editorRef.current.getModel();
|
||||||
if (!model) return;
|
if (!model) return;
|
||||||
|
|
||||||
const markers = monaco?.editor.getModelMarkers({ resource: model.uri });
|
|
||||||
const hasErrors = markers?.some((m) => m.severity === monaco?.MarkerSeverity.Error);
|
|
||||||
|
|
||||||
if (hasErrors) {
|
|
||||||
toast({
|
|
||||||
title: "Invalid TypeScript",
|
|
||||||
description: "Please fix the TypeScript errors before saving.",
|
|
||||||
status: "error",
|
|
||||||
});
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Make sure the user defined the prompt with the string "prompt\w*=" somewhere
|
// Make sure the user defined the prompt with the string "prompt\w*=" somewhere
|
||||||
const promptRegex = /prompt\s*=/;
|
const promptRegex = /definePrompt\(/;
|
||||||
if (!promptRegex.test(currentFn)) {
|
if (!promptRegex.test(currentFn)) {
|
||||||
console.log("no prompt");
|
|
||||||
console.log(currentFn);
|
|
||||||
toast({
|
toast({
|
||||||
title: "Missing prompt",
|
title: "Missing prompt",
|
||||||
description: "Please define the prompt (eg. `prompt = { ...`).",
|
description: "Please define the prompt (eg. `definePrompt(...`",
|
||||||
status: "error",
|
status: "error",
|
||||||
});
|
});
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
await replaceVariant.mutateAsync({
|
const resp = await replaceVariant.mutateAsync({
|
||||||
id: props.variant.id,
|
id: props.variant.id,
|
||||||
constructFn: currentFn,
|
constructFn: currentFn,
|
||||||
});
|
});
|
||||||
|
if (resp.status === "error") {
|
||||||
|
return toast({
|
||||||
|
title: "Error saving variant",
|
||||||
|
description: resp.message,
|
||||||
|
status: "error",
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
setIsChanged(false);
|
||||||
|
|
||||||
await utils.promptVariants.list.invalidate();
|
await utils.promptVariants.list.invalidate();
|
||||||
|
|
||||||
checkForChanges();
|
|
||||||
}, [checkForChanges]);
|
}, [checkForChanges]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
@@ -90,13 +126,26 @@ export default function VariantConfigEditor(props: { variant: PromptVariant }) {
|
|||||||
wordWrapBreakAfterCharacters: "",
|
wordWrapBreakAfterCharacters: "",
|
||||||
wordWrapBreakBeforeCharacters: "",
|
wordWrapBreakBeforeCharacters: "",
|
||||||
quickSuggestions: true,
|
quickSuggestions: true,
|
||||||
|
readOnly: !canModify,
|
||||||
});
|
});
|
||||||
|
|
||||||
editorRef.current.onDidFocusEditorText(() => {
|
// Workaround because otherwise the commands only work on whatever
|
||||||
// Workaround because otherwise the command only works on whatever
|
|
||||||
// editor was loaded on the page last.
|
// editor was loaded on the page last.
|
||||||
// https://github.com/microsoft/monaco-editor/issues/2947#issuecomment-1422265201
|
// https://github.com/microsoft/monaco-editor/issues/2947#issuecomment-1422265201
|
||||||
editorRef.current?.addCommand(monaco.KeyMod.CtrlCmd | monaco.KeyCode.Enter, onSave);
|
editorRef.current.onDidFocusEditorText(() => {
|
||||||
|
editorRef.current?.addCommand(monaco.KeyMod.CtrlCmd | monaco.KeyCode.KeyS, onSave);
|
||||||
|
|
||||||
|
editorRef.current?.addCommand(
|
||||||
|
monaco.KeyMod.CtrlCmd | monaco.KeyMod.Shift | monaco.KeyCode.KeyF,
|
||||||
|
toggleFullscreen,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Exit fullscreen with escape
|
||||||
|
editorRef.current?.addCommand(monaco.KeyCode.Escape, () => {
|
||||||
|
if (isFullscreen) {
|
||||||
|
toggleFullscreen();
|
||||||
|
}
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
editorRef.current.onDidChangeModelContent(checkForChanges);
|
editorRef.current.onDidChangeModelContent(checkForChanges);
|
||||||
@@ -117,21 +166,48 @@ export default function VariantConfigEditor(props: { variant: PromptVariant }) {
|
|||||||
/* eslint-disable-next-line react-hooks/exhaustive-deps */
|
/* eslint-disable-next-line react-hooks/exhaustive-deps */
|
||||||
}, [monaco, editorId]);
|
}, [monaco, editorId]);
|
||||||
|
|
||||||
// useEffect(() => {
|
useEffect(() => {
|
||||||
// const savedConfigChanged = lastSavedFn !== savedConfig;
|
if (!editorRef.current) return;
|
||||||
|
editorRef.current.updateOptions({
|
||||||
// lastSavedFn = savedConfig;
|
readOnly: !canModify,
|
||||||
|
});
|
||||||
// if (savedConfigChanged && editorRef.current?.getValue() !== savedConfig) {
|
}, [canModify]);
|
||||||
// editorRef.current?.setValue(savedConfig);
|
|
||||||
// }
|
|
||||||
|
|
||||||
// checkForChanges();
|
|
||||||
// }, [savedConfig, checkForChanges]);
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Box w="100%" pos="relative">
|
<Box
|
||||||
<div id={editorId} style={{ height: "300px", width: "100%" }}></div>
|
w="100%"
|
||||||
|
ref={containerRef}
|
||||||
|
sx={
|
||||||
|
isFullscreen
|
||||||
|
? {
|
||||||
|
position: "fixed",
|
||||||
|
top: 0,
|
||||||
|
left: 0,
|
||||||
|
right: 0,
|
||||||
|
bottom: 0,
|
||||||
|
}
|
||||||
|
: { h: "400px", w: "100%" }
|
||||||
|
}
|
||||||
|
bgColor={editorBackground}
|
||||||
|
zIndex={isFullscreen ? 1000 : "unset"}
|
||||||
|
pos="relative"
|
||||||
|
_hover={{ ".fullscreen-toggle": { opacity: 1 } }}
|
||||||
|
>
|
||||||
|
<Box id={editorId} w="100%" h="100%" />
|
||||||
|
<Tooltip label={`${modifierKey} + ⇧ + F`}>
|
||||||
|
<IconButton
|
||||||
|
className="fullscreen-toggle"
|
||||||
|
aria-label="Minimize"
|
||||||
|
icon={isFullscreen ? <FiMinimize /> : <FiMaximize />}
|
||||||
|
position="absolute"
|
||||||
|
top={2}
|
||||||
|
right={2}
|
||||||
|
onClick={toggleFullscreen}
|
||||||
|
opacity={0}
|
||||||
|
transition="opacity 0.2s"
|
||||||
|
/>
|
||||||
|
</Tooltip>
|
||||||
|
|
||||||
{isChanged && (
|
{isChanged && (
|
||||||
<HStack pos="absolute" bottom={2} right={2}>
|
<HStack pos="absolute" bottom={2} right={2}>
|
||||||
<Button
|
<Button
|
||||||
@@ -144,9 +220,9 @@ export default function VariantConfigEditor(props: { variant: PromptVariant }) {
|
|||||||
>
|
>
|
||||||
Reset
|
Reset
|
||||||
</Button>
|
</Button>
|
||||||
<Tooltip label={`${modifierKey} + Enter`}>
|
<Tooltip label={`${modifierKey} + S`}>
|
||||||
<Button size="sm" onClick={onSave} colorScheme="blue">
|
<Button size="sm" onClick={onSave} colorScheme="blue" w={16} disabled={saveInProgress}>
|
||||||
Save
|
{saveInProgress ? <Spinner boxSize={4} /> : <Text>Save</Text>}
|
||||||
</Button>
|
</Button>
|
||||||
</Tooltip>
|
</Tooltip>
|
||||||
</HStack>
|
</HStack>
|
||||||
|
|||||||
@@ -1,105 +0,0 @@
|
|||||||
import { useState, type DragEvent } from "react";
|
|
||||||
import { type PromptVariant } from "./types";
|
|
||||||
import { api } from "~/utils/api";
|
|
||||||
import { useHandledAsyncCallback } from "~/utils/hooks";
|
|
||||||
import { Button, HStack, Icon, Tooltip } from "@chakra-ui/react"; // Changed here
|
|
||||||
import { BsX } from "react-icons/bs";
|
|
||||||
import { RiDraggable } from "react-icons/ri";
|
|
||||||
import { cellPadding, headerMinHeight } from "../constants";
|
|
||||||
import AutoResizeTextArea from "../AutoResizeTextArea";
|
|
||||||
|
|
||||||
export default function VariantHeader(props: { variant: PromptVariant }) {
|
|
||||||
const utils = api.useContext();
|
|
||||||
const [isDragTarget, setIsDragTarget] = useState(false);
|
|
||||||
const [isInputHovered, setIsInputHovered] = useState(false);
|
|
||||||
const [label, setLabel] = useState(props.variant.label);
|
|
||||||
|
|
||||||
const updateMutation = api.promptVariants.update.useMutation();
|
|
||||||
const [onSaveLabel] = useHandledAsyncCallback(async () => {
|
|
||||||
if (label && label !== props.variant.label) {
|
|
||||||
await updateMutation.mutateAsync({
|
|
||||||
id: props.variant.id,
|
|
||||||
updates: { label: label },
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}, [updateMutation, props.variant.id, props.variant.label, label]);
|
|
||||||
|
|
||||||
const hideMutation = api.promptVariants.hide.useMutation();
|
|
||||||
const [onHide] = useHandledAsyncCallback(async () => {
|
|
||||||
await hideMutation.mutateAsync({
|
|
||||||
id: props.variant.id,
|
|
||||||
});
|
|
||||||
await utils.promptVariants.list.invalidate();
|
|
||||||
}, [hideMutation, props.variant.id]);
|
|
||||||
|
|
||||||
const reorderMutation = api.promptVariants.reorder.useMutation();
|
|
||||||
const [onReorder] = useHandledAsyncCallback(
|
|
||||||
async (e: DragEvent<HTMLDivElement>) => {
|
|
||||||
e.preventDefault();
|
|
||||||
setIsDragTarget(false);
|
|
||||||
const draggedId = e.dataTransfer.getData("text/plain");
|
|
||||||
const droppedId = props.variant.id;
|
|
||||||
if (!draggedId || !droppedId || draggedId === droppedId) return;
|
|
||||||
await reorderMutation.mutateAsync({
|
|
||||||
draggedId,
|
|
||||||
droppedId,
|
|
||||||
});
|
|
||||||
await utils.promptVariants.list.invalidate();
|
|
||||||
},
|
|
||||||
[reorderMutation, props.variant.id],
|
|
||||||
);
|
|
||||||
|
|
||||||
return (
|
|
||||||
<HStack
|
|
||||||
spacing={4}
|
|
||||||
alignItems="center"
|
|
||||||
minH={headerMinHeight}
|
|
||||||
draggable={!isInputHovered}
|
|
||||||
onDragStart={(e) => {
|
|
||||||
e.dataTransfer.setData("text/plain", props.variant.id);
|
|
||||||
e.currentTarget.style.opacity = "0.4";
|
|
||||||
}}
|
|
||||||
onDragEnd={(e) => {
|
|
||||||
e.currentTarget.style.opacity = "1";
|
|
||||||
}}
|
|
||||||
onDragOver={(e) => {
|
|
||||||
e.preventDefault();
|
|
||||||
setIsDragTarget(true);
|
|
||||||
}}
|
|
||||||
onDragLeave={() => {
|
|
||||||
setIsDragTarget(false);
|
|
||||||
}}
|
|
||||||
onDrop={onReorder}
|
|
||||||
backgroundColor={isDragTarget ? "gray.100" : "transparent"}
|
|
||||||
>
|
|
||||||
<Icon
|
|
||||||
as={RiDraggable}
|
|
||||||
boxSize={6}
|
|
||||||
color="gray.400"
|
|
||||||
_hover={{ color: "gray.800", cursor: "pointer" }}
|
|
||||||
/>
|
|
||||||
<AutoResizeTextArea // Changed to Input
|
|
||||||
size="sm"
|
|
||||||
value={label}
|
|
||||||
onChange={(e) => setLabel(e.target.value)}
|
|
||||||
onBlur={onSaveLabel}
|
|
||||||
placeholder="Variant Name"
|
|
||||||
borderWidth={1}
|
|
||||||
borderColor="transparent"
|
|
||||||
fontWeight="bold"
|
|
||||||
fontSize={16}
|
|
||||||
_hover={{ borderColor: "gray.300" }}
|
|
||||||
_focus={{ borderColor: "blue.500", outline: "none" }}
|
|
||||||
flex={1}
|
|
||||||
px={cellPadding.x}
|
|
||||||
onMouseEnter={() => setIsInputHovered(true)}
|
|
||||||
onMouseLeave={() => setIsInputHovered(false)}
|
|
||||||
/>
|
|
||||||
<Tooltip label="Hide Variant" hasArrow>
|
|
||||||
<Button variant="ghost" colorScheme="gray" size="sm" onClick={onHide}>
|
|
||||||
<Icon as={BsX} boxSize={6} />
|
|
||||||
</Button>
|
|
||||||
</Tooltip>
|
|
||||||
</HStack>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
@@ -5,8 +5,10 @@ import { api } from "~/utils/api";
|
|||||||
import chroma from "chroma-js";
|
import chroma from "chroma-js";
|
||||||
import { BsCurrencyDollar } from "react-icons/bs";
|
import { BsCurrencyDollar } from "react-icons/bs";
|
||||||
import { CostTooltip } from "../tooltip/CostTooltip";
|
import { CostTooltip } from "../tooltip/CostTooltip";
|
||||||
|
import { useEffect, useState } from "react";
|
||||||
|
|
||||||
export default function VariantStats(props: { variant: PromptVariant }) {
|
export default function VariantStats(props: { variant: PromptVariant }) {
|
||||||
|
const [refetchInterval, setRefetchInterval] = useState(0);
|
||||||
const { data } = api.promptVariants.stats.useQuery(
|
const { data } = api.promptVariants.stats.useQuery(
|
||||||
{
|
{
|
||||||
variantId: props.variant.id,
|
variantId: props.variant.id,
|
||||||
@@ -19,10 +21,18 @@ export default function VariantStats(props: { variant: PromptVariant }) {
|
|||||||
completionTokens: 0,
|
completionTokens: 0,
|
||||||
scenarioCount: 0,
|
scenarioCount: 0,
|
||||||
outputCount: 0,
|
outputCount: 0,
|
||||||
|
awaitingRetrievals: false,
|
||||||
},
|
},
|
||||||
|
refetchInterval,
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// Poll every two seconds while we are waiting for LLM retrievals to finish
|
||||||
|
useEffect(
|
||||||
|
() => setRefetchInterval(data.awaitingRetrievals ? 2000 : 0),
|
||||||
|
[data.awaitingRetrievals],
|
||||||
|
);
|
||||||
|
|
||||||
const [passColor, neutralColor, failColor] = useToken("colors", [
|
const [passColor, neutralColor, failColor] = useToken("colors", [
|
||||||
"green.500",
|
"green.500",
|
||||||
"gray.500",
|
"gray.500",
|
||||||
@@ -33,21 +43,25 @@ export default function VariantStats(props: { variant: PromptVariant }) {
|
|||||||
|
|
||||||
const showNumFinished = data.scenarioCount > 0 && data.scenarioCount !== data.outputCount;
|
const showNumFinished = data.scenarioCount > 0 && data.scenarioCount !== data.outputCount;
|
||||||
|
|
||||||
if (!(data.evalResults.length > 0) && !data.overallCost) return null;
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<HStack justifyContent="space-between" alignItems="center" mx="2" fontSize="xs">
|
<HStack
|
||||||
|
justifyContent="space-between"
|
||||||
|
alignItems="center"
|
||||||
|
mx="2"
|
||||||
|
fontSize="xs"
|
||||||
|
py={cellPadding.y}
|
||||||
|
>
|
||||||
{showNumFinished && (
|
{showNumFinished && (
|
||||||
<Text>
|
<Text>
|
||||||
{data.outputCount} / {data.scenarioCount}
|
{data.outputCount} / {data.scenarioCount}
|
||||||
</Text>
|
</Text>
|
||||||
)}
|
)}
|
||||||
<HStack px={cellPadding.x} py={cellPadding.y}>
|
<HStack px={cellPadding.x}>
|
||||||
{data.evalResults.map((result) => {
|
{data.evalResults.map((result) => {
|
||||||
const passedFrac = result.passCount / (result.passCount + result.failCount);
|
const passedFrac = result.passCount / result.totalCount;
|
||||||
return (
|
return (
|
||||||
<HStack key={result.id}>
|
<HStack key={result.id}>
|
||||||
<Text>{result.evaluation.name}</Text>
|
<Text>{result.label}</Text>
|
||||||
<Text color={scale(passedFrac).hex()} fontWeight="bold">
|
<Text color={scale(passedFrac).hex()} fontWeight="bold">
|
||||||
{(passedFrac * 100).toFixed(1)}%
|
{(passedFrac * 100).toFixed(1)}%
|
||||||
</Text>
|
</Text>
|
||||||
@@ -55,13 +69,13 @@ export default function VariantStats(props: { variant: PromptVariant }) {
|
|||||||
);
|
);
|
||||||
})}
|
})}
|
||||||
</HStack>
|
</HStack>
|
||||||
{data.overallCost && (
|
{data.overallCost && !data.awaitingRetrievals && (
|
||||||
<CostTooltip
|
<CostTooltip
|
||||||
promptTokens={data.promptTokens}
|
promptTokens={data.promptTokens}
|
||||||
completionTokens={data.completionTokens}
|
completionTokens={data.completionTokens}
|
||||||
cost={data.overallCost}
|
cost={data.overallCost}
|
||||||
>
|
>
|
||||||
<HStack spacing={0} align="center" color="gray.500" my="2">
|
<HStack spacing={0} align="center" color="gray.500">
|
||||||
<Icon as={BsCurrencyDollar} />
|
<Icon as={BsCurrencyDollar} />
|
||||||
<Text mr={1}>{data.overallCost.toFixed(3)}</Text>
|
<Text mr={1}>{data.overallCost.toFixed(3)}</Text>
|
||||||
</HStack>
|
</HStack>
|
||||||
|
|||||||
@@ -1,113 +1,105 @@
|
|||||||
import { Button, Grid, GridItem, HStack, Heading, type SystemStyleObject } from "@chakra-ui/react";
|
import { Grid, GridItem, type GridItemProps } from "@chakra-ui/react";
|
||||||
import { api } from "~/utils/api";
|
import { api } from "~/utils/api";
|
||||||
import NewScenarioButton from "./NewScenarioButton";
|
import AddVariantButton from "./AddVariantButton";
|
||||||
import NewVariantButton from "./NewVariantButton";
|
|
||||||
import ScenarioRow from "./ScenarioRow";
|
import ScenarioRow from "./ScenarioRow";
|
||||||
import VariantConfigEditor from "./VariantEditor";
|
import VariantEditor from "./VariantEditor";
|
||||||
import VariantHeader from "./VariantHeader";
|
import VariantHeader from "../VariantHeader/VariantHeader";
|
||||||
import { cellPadding } from "../constants";
|
|
||||||
import { BsPencil } from "react-icons/bs";
|
|
||||||
import VariantStats from "./VariantStats";
|
import VariantStats from "./VariantStats";
|
||||||
import { useAppStore } from "~/state/store";
|
import { ScenariosHeader } from "./ScenariosHeader";
|
||||||
|
import { borders } from "./styles";
|
||||||
const stickyHeaderStyle: SystemStyleObject = {
|
import { useScenarios } from "~/utils/hooks";
|
||||||
position: "sticky",
|
import ScenarioPaginator from "./ScenarioPaginator";
|
||||||
top: "-1px",
|
|
||||||
backgroundColor: "#fff",
|
|
||||||
zIndex: 1,
|
|
||||||
};
|
|
||||||
|
|
||||||
export default function OutputsTable({ experimentId }: { experimentId: string | undefined }) {
|
export default function OutputsTable({ experimentId }: { experimentId: string | undefined }) {
|
||||||
const variants = api.promptVariants.list.useQuery(
|
const variants = api.promptVariants.list.useQuery(
|
||||||
{ experimentId: experimentId as string },
|
{ experimentId: experimentId as string },
|
||||||
{ enabled: !!experimentId },
|
{ enabled: !!experimentId },
|
||||||
);
|
);
|
||||||
const openDrawer = useAppStore((s) => s.openDrawer);
|
|
||||||
|
|
||||||
const scenarios = api.scenarios.list.useQuery(
|
const scenarios = useScenarios();
|
||||||
{ experimentId: experimentId as string },
|
|
||||||
{ enabled: !!experimentId },
|
|
||||||
);
|
|
||||||
|
|
||||||
if (!variants.data || !scenarios.data) return null;
|
if (!variants.data || !scenarios.data) return null;
|
||||||
|
|
||||||
const allCols = variants.data.length + 1;
|
const allCols = variants.data.length + 2;
|
||||||
const headerRows = 3;
|
const variantHeaderRows = 3;
|
||||||
|
const scenarioHeaderRows = 1;
|
||||||
|
const scenarioFooterRows = 1;
|
||||||
|
const visibleScenariosCount = scenarios.data.scenarios.length;
|
||||||
|
const allRows =
|
||||||
|
variantHeaderRows + scenarioHeaderRows + visibleScenariosCount + scenarioFooterRows;
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Grid
|
<Grid
|
||||||
p={4}
|
pt={4}
|
||||||
pb={24}
|
pb={24}
|
||||||
|
pl={4}
|
||||||
display="grid"
|
display="grid"
|
||||||
gridTemplateColumns={`250px repeat(${variants.data.length}, minmax(300px, 1fr)) auto`}
|
gridTemplateColumns={`250px repeat(${variants.data.length}, minmax(300px, 1fr)) auto`}
|
||||||
sx={{
|
sx={{
|
||||||
"> *": {
|
"> *": {
|
||||||
borderColor: "gray.300",
|
borderColor: "gray.300",
|
||||||
borderBottomWidth: 1,
|
|
||||||
borderRightWidth: 1,
|
|
||||||
},
|
},
|
||||||
}}
|
}}
|
||||||
fontSize="sm"
|
fontSize="sm"
|
||||||
>
|
>
|
||||||
<GridItem
|
<GridItem rowSpan={variantHeaderRows}>
|
||||||
display="flex"
|
<AddVariantButton />
|
||||||
alignItems="flex-end"
|
|
||||||
rowSpan={headerRows}
|
|
||||||
px={cellPadding.x}
|
|
||||||
py={cellPadding.y}
|
|
||||||
// TODO: This is a hack to get the sticky header to work. It's not ideal because it's not responsive to the height of the header,
|
|
||||||
// so if the header height changes, this will need to be updated.
|
|
||||||
sx={{ ...stickyHeaderStyle, top: "-337px" }}
|
|
||||||
>
|
|
||||||
<HStack w="100%">
|
|
||||||
<Heading size="xs" fontWeight="bold" flex={1}>
|
|
||||||
Scenarios ({scenarios.data.length})
|
|
||||||
</Heading>
|
|
||||||
<Button
|
|
||||||
size="xs"
|
|
||||||
variant="ghost"
|
|
||||||
color="gray.500"
|
|
||||||
aria-label="Edit"
|
|
||||||
leftIcon={<BsPencil />}
|
|
||||||
onClick={openDrawer}
|
|
||||||
>
|
|
||||||
Edit Vars
|
|
||||||
</Button>
|
|
||||||
</HStack>
|
|
||||||
</GridItem>
|
</GridItem>
|
||||||
|
|
||||||
{variants.data.map((variant) => (
|
{variants.data.map((variant, i) => {
|
||||||
<GridItem key={variant.uiId} padding={0} sx={stickyHeaderStyle} borderTopWidth={1}>
|
const sharedProps: GridItemProps = {
|
||||||
<VariantHeader variant={variant} />
|
...borders,
|
||||||
|
colStart: i + 2,
|
||||||
|
borderLeftWidth: i === 0 ? 1 : 0,
|
||||||
|
};
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
<VariantHeader
|
||||||
|
key={variant.uiId}
|
||||||
|
variant={variant}
|
||||||
|
canHide={variants.data.length > 1}
|
||||||
|
rowStart={1}
|
||||||
|
{...sharedProps}
|
||||||
|
/>
|
||||||
|
<GridItem rowStart={2} {...sharedProps}>
|
||||||
|
<VariantEditor variant={variant} />
|
||||||
</GridItem>
|
</GridItem>
|
||||||
))}
|
<GridItem rowStart={3} {...sharedProps}>
|
||||||
<GridItem
|
|
||||||
rowSpan={scenarios.data.length + headerRows}
|
|
||||||
padding={0}
|
|
||||||
// Have to use `style` instead of emotion style props to work around css specificity issues conflicting with the "> *" selector on Grid
|
|
||||||
style={{ borderRightWidth: 0, borderBottomWidth: 0 }}
|
|
||||||
h={8}
|
|
||||||
sx={stickyHeaderStyle}
|
|
||||||
>
|
|
||||||
<NewVariantButton />
|
|
||||||
</GridItem>
|
|
||||||
|
|
||||||
{variants.data.map((variant) => (
|
|
||||||
<GridItem key={variant.uiId}>
|
|
||||||
<VariantConfigEditor variant={variant} />
|
|
||||||
</GridItem>
|
|
||||||
))}
|
|
||||||
{variants.data.map((variant) => (
|
|
||||||
<GridItem key={variant.uiId}>
|
|
||||||
<VariantStats variant={variant} />
|
<VariantStats variant={variant} />
|
||||||
</GridItem>
|
</GridItem>
|
||||||
))}
|
</>
|
||||||
{scenarios.data.map((scenario) => (
|
);
|
||||||
<ScenarioRow key={scenario.uiId} scenario={scenario} variants={variants.data} />
|
})}
|
||||||
))}
|
|
||||||
<GridItem borderBottomWidth={0} borderRightWidth={0} w="100%" colSpan={allCols} padding={0}>
|
<GridItem
|
||||||
<NewScenarioButton />
|
colSpan={allCols - 1}
|
||||||
|
rowStart={variantHeaderRows + 1}
|
||||||
|
colStart={1}
|
||||||
|
{...borders}
|
||||||
|
borderRightWidth={0}
|
||||||
|
>
|
||||||
|
<ScenariosHeader />
|
||||||
</GridItem>
|
</GridItem>
|
||||||
|
|
||||||
|
{scenarios.data.scenarios.map((scenario, i) => (
|
||||||
|
<ScenarioRow
|
||||||
|
rowStart={i + variantHeaderRows + scenarioHeaderRows + 2}
|
||||||
|
key={scenario.uiId}
|
||||||
|
scenario={scenario}
|
||||||
|
variants={variants.data}
|
||||||
|
canHide={visibleScenariosCount > 1}
|
||||||
|
/>
|
||||||
|
))}
|
||||||
|
<GridItem
|
||||||
|
rowStart={variantHeaderRows + scenarioHeaderRows + visibleScenariosCount + 2}
|
||||||
|
colStart={1}
|
||||||
|
colSpan={allCols}
|
||||||
|
>
|
||||||
|
<ScenarioPaginator />
|
||||||
|
</GridItem>
|
||||||
|
|
||||||
|
{/* Add some extra padding on the right, because when the table is too wide to fit in the viewport `pr` on the Grid isn't respected. */}
|
||||||
|
<GridItem rowStart={1} colStart={allCols} rowSpan={allRows} w={4} borderBottomWidth={0} />
|
||||||
</Grid>
|
</Grid>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
13
src/components/OutputsTable/styles.ts
Normal file
13
src/components/OutputsTable/styles.ts
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
import { type GridItemProps, type SystemStyleObject } from "@chakra-ui/react";
|
||||||
|
|
||||||
|
export const stickyHeaderStyle: SystemStyleObject = {
|
||||||
|
position: "sticky",
|
||||||
|
top: "0",
|
||||||
|
backgroundColor: "#fff",
|
||||||
|
zIndex: 10,
|
||||||
|
};
|
||||||
|
|
||||||
|
export const borders: GridItemProps = {
|
||||||
|
borderRightWidth: 1,
|
||||||
|
borderBottomWidth: 1,
|
||||||
|
};
|
||||||
@@ -2,4 +2,4 @@ import { type RouterOutputs } from "~/utils/api";
|
|||||||
|
|
||||||
export type PromptVariant = NonNullable<RouterOutputs["promptVariants"]["list"]>[0];
|
export type PromptVariant = NonNullable<RouterOutputs["promptVariants"]["list"]>[0];
|
||||||
|
|
||||||
export type Scenario = NonNullable<RouterOutputs["scenarios"]["list"]>[0];
|
export type Scenario = NonNullable<RouterOutputs["scenarios"]["list"]>["scenarios"][0];
|
||||||
|
|||||||
@@ -1,21 +0,0 @@
|
|||||||
import { Flex, Icon, Link, Text } from "@chakra-ui/react";
|
|
||||||
import { BsExclamationTriangleFill } from "react-icons/bs";
|
|
||||||
import { env } from "~/env.mjs";
|
|
||||||
|
|
||||||
export default function PublicPlaygroundWarning() {
|
|
||||||
if (!env.NEXT_PUBLIC_IS_PUBLIC_PLAYGROUND) return null;
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Flex bgColor="red.600" color="whiteAlpha.900" p={2} align="center">
|
|
||||||
<Icon boxSize={4} mr={2} as={BsExclamationTriangleFill} />
|
|
||||||
<Text>
|
|
||||||
Warning: this is a public playground. Anyone can see, edit or delete your experiments. For
|
|
||||||
private use,{" "}
|
|
||||||
<Link textDecor="underline" href="https://github.com/openpipe/openpipe" target="_blank">
|
|
||||||
run a local copy
|
|
||||||
</Link>
|
|
||||||
.
|
|
||||||
</Text>
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
59
src/components/RefinePromptModal/CompareFunctions.tsx
Normal file
59
src/components/RefinePromptModal/CompareFunctions.tsx
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
import { type StackProps, VStack, useBreakpointValue } from "@chakra-ui/react";
|
||||||
|
import React from "react";
|
||||||
|
import DiffViewer, { DiffMethod } from "react-diff-viewer";
|
||||||
|
import Prism from "prismjs";
|
||||||
|
import "prismjs/components/prism-javascript";
|
||||||
|
import "prismjs/themes/prism.css"; // choose a theme you like
|
||||||
|
|
||||||
|
const highlightSyntax = (str: string) => {
|
||||||
|
let highlighted;
|
||||||
|
try {
|
||||||
|
highlighted = Prism.highlight(str, Prism.languages.javascript as Prism.Grammar, "javascript");
|
||||||
|
} catch (e) {
|
||||||
|
console.error("Error highlighting:", e);
|
||||||
|
highlighted = str;
|
||||||
|
}
|
||||||
|
return <pre style={{ display: "inline" }} dangerouslySetInnerHTML={{ __html: highlighted }} />;
|
||||||
|
};
|
||||||
|
|
||||||
|
const CompareFunctions = ({
|
||||||
|
originalFunction,
|
||||||
|
newFunction = "",
|
||||||
|
leftTitle = "Original",
|
||||||
|
rightTitle = "Modified",
|
||||||
|
...props
|
||||||
|
}: {
|
||||||
|
originalFunction: string;
|
||||||
|
newFunction?: string;
|
||||||
|
leftTitle?: string;
|
||||||
|
rightTitle?: string;
|
||||||
|
} & StackProps) => {
|
||||||
|
const showSplitView = useBreakpointValue(
|
||||||
|
{
|
||||||
|
base: false,
|
||||||
|
md: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
fallback: "base",
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<VStack w="full" spacing={4} fontSize={12} lineHeight={1} overflowY="auto" {...props}>
|
||||||
|
<DiffViewer
|
||||||
|
oldValue={originalFunction}
|
||||||
|
newValue={newFunction || originalFunction}
|
||||||
|
splitView={showSplitView}
|
||||||
|
hideLineNumbers={!showSplitView}
|
||||||
|
leftTitle={leftTitle}
|
||||||
|
rightTitle={rightTitle}
|
||||||
|
disableWordDiff={true}
|
||||||
|
compareMethod={DiffMethod.CHARS}
|
||||||
|
renderContent={highlightSyntax}
|
||||||
|
showDiffOnly={false}
|
||||||
|
/>
|
||||||
|
</VStack>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default CompareFunctions;
|
||||||
74
src/components/RefinePromptModal/CustomInstructionsInput.tsx
Normal file
74
src/components/RefinePromptModal/CustomInstructionsInput.tsx
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
import { Button, Spinner, InputGroup, InputRightElement, Icon, HStack } from "@chakra-ui/react";
|
||||||
|
import { IoMdSend } from "react-icons/io";
|
||||||
|
import AutoResizeTextArea from "../AutoResizeTextArea";
|
||||||
|
|
||||||
|
export const CustomInstructionsInput = ({
|
||||||
|
instructions,
|
||||||
|
setInstructions,
|
||||||
|
loading,
|
||||||
|
onSubmit,
|
||||||
|
}: {
|
||||||
|
instructions: string;
|
||||||
|
setInstructions: (instructions: string) => void;
|
||||||
|
loading: boolean;
|
||||||
|
onSubmit: () => void;
|
||||||
|
}) => {
|
||||||
|
return (
|
||||||
|
<InputGroup
|
||||||
|
size="md"
|
||||||
|
w="full"
|
||||||
|
maxW="600"
|
||||||
|
boxShadow="0 0 40px 4px rgba(0, 0, 0, 0.1);"
|
||||||
|
borderRadius={8}
|
||||||
|
alignItems="center"
|
||||||
|
colorScheme="orange"
|
||||||
|
>
|
||||||
|
<AutoResizeTextArea
|
||||||
|
value={instructions}
|
||||||
|
onChange={(e) => setInstructions(e.target.value)}
|
||||||
|
onKeyDown={(e) => {
|
||||||
|
if (e.key === "Enter" && !e.metaKey && !e.ctrlKey && !e.shiftKey) {
|
||||||
|
e.preventDefault();
|
||||||
|
e.currentTarget.blur();
|
||||||
|
onSubmit();
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
placeholder="Send custom instructions"
|
||||||
|
py={4}
|
||||||
|
pl={4}
|
||||||
|
pr={12}
|
||||||
|
colorScheme="orange"
|
||||||
|
borderColor="gray.300"
|
||||||
|
borderWidth={1}
|
||||||
|
_hover={{
|
||||||
|
borderColor: "gray.300",
|
||||||
|
}}
|
||||||
|
_focus={{
|
||||||
|
borderColor: "gray.300",
|
||||||
|
}}
|
||||||
|
isDisabled={loading}
|
||||||
|
/>
|
||||||
|
<HStack></HStack>
|
||||||
|
<InputRightElement width="8" height="full">
|
||||||
|
<Button
|
||||||
|
h="8"
|
||||||
|
w="8"
|
||||||
|
minW="unset"
|
||||||
|
size="sm"
|
||||||
|
onClick={() => onSubmit()}
|
||||||
|
variant={instructions ? "solid" : "ghost"}
|
||||||
|
mr={4}
|
||||||
|
borderRadius="8"
|
||||||
|
bgColor={instructions ? "orange.400" : "transparent"}
|
||||||
|
colorScheme="orange"
|
||||||
|
>
|
||||||
|
{loading ? (
|
||||||
|
<Spinner boxSize={4} />
|
||||||
|
) : (
|
||||||
|
<Icon as={IoMdSend} color={instructions ? "white" : "gray.500"} boxSize={5} />
|
||||||
|
)}
|
||||||
|
</Button>
|
||||||
|
</InputRightElement>
|
||||||
|
</InputGroup>
|
||||||
|
);
|
||||||
|
};
|
||||||
64
src/components/RefinePromptModal/RefineOption.tsx
Normal file
64
src/components/RefinePromptModal/RefineOption.tsx
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
import { HStack, Icon, Heading, Text, VStack, GridItem } from "@chakra-ui/react";
|
||||||
|
import { type IconType } from "react-icons";
|
||||||
|
|
||||||
|
export const RefineOption = ({
|
||||||
|
label,
|
||||||
|
icon,
|
||||||
|
desciption,
|
||||||
|
activeLabel,
|
||||||
|
onClick,
|
||||||
|
loading,
|
||||||
|
}: {
|
||||||
|
label: string;
|
||||||
|
icon: IconType;
|
||||||
|
desciption: string;
|
||||||
|
activeLabel: string | undefined;
|
||||||
|
onClick: (label: string) => void;
|
||||||
|
loading: boolean;
|
||||||
|
}) => {
|
||||||
|
const isActive = activeLabel === label;
|
||||||
|
|
||||||
|
return (
|
||||||
|
<GridItem w="80" h="44">
|
||||||
|
<VStack
|
||||||
|
w="full"
|
||||||
|
h="full"
|
||||||
|
onClick={() => {
|
||||||
|
!loading && onClick(label);
|
||||||
|
}}
|
||||||
|
borderColor={isActive ? "blue.500" : "gray.200"}
|
||||||
|
borderWidth={2}
|
||||||
|
borderRadius={16}
|
||||||
|
padding={6}
|
||||||
|
backgroundColor="gray.50"
|
||||||
|
_hover={
|
||||||
|
loading
|
||||||
|
? undefined
|
||||||
|
: {
|
||||||
|
backgroundColor: "gray.100",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
spacing={8}
|
||||||
|
boxShadow="0 0 40px 4px rgba(0, 0, 0, 0.1);"
|
||||||
|
cursor="pointer"
|
||||||
|
opacity={loading ? 0.5 : 1}
|
||||||
|
>
|
||||||
|
<HStack cursor="pointer" spacing={6} fontSize="sm" fontWeight="medium" color="gray.500">
|
||||||
|
<Icon as={icon} boxSize={12} />
|
||||||
|
<Heading size="md" fontFamily="inconsolata, monospace">
|
||||||
|
{label}
|
||||||
|
</Heading>
|
||||||
|
</HStack>
|
||||||
|
<Text
|
||||||
|
fontSize="sm"
|
||||||
|
color="gray.500"
|
||||||
|
flexWrap="wrap"
|
||||||
|
wordBreak="break-word"
|
||||||
|
overflowWrap="break-word"
|
||||||
|
>
|
||||||
|
{desciption}
|
||||||
|
</Text>
|
||||||
|
</VStack>
|
||||||
|
</GridItem>
|
||||||
|
);
|
||||||
|
};
|
||||||
148
src/components/RefinePromptModal/RefinePromptModal.tsx
Normal file
148
src/components/RefinePromptModal/RefinePromptModal.tsx
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
import {
|
||||||
|
Button,
|
||||||
|
Modal,
|
||||||
|
ModalBody,
|
||||||
|
ModalCloseButton,
|
||||||
|
ModalContent,
|
||||||
|
ModalFooter,
|
||||||
|
ModalHeader,
|
||||||
|
ModalOverlay,
|
||||||
|
VStack,
|
||||||
|
Text,
|
||||||
|
Spinner,
|
||||||
|
HStack,
|
||||||
|
Icon,
|
||||||
|
SimpleGrid,
|
||||||
|
} from "@chakra-ui/react";
|
||||||
|
import { BsStars } from "react-icons/bs";
|
||||||
|
import { api } from "~/utils/api";
|
||||||
|
import { useHandledAsyncCallback } from "~/utils/hooks";
|
||||||
|
import { type PromptVariant } from "@prisma/client";
|
||||||
|
import { useState } from "react";
|
||||||
|
import CompareFunctions from "./CompareFunctions";
|
||||||
|
import { CustomInstructionsInput } from "./CustomInstructionsInput";
|
||||||
|
import { type RefineOptionInfo, refineOptions } from "./refineOptions";
|
||||||
|
import { RefineOption } from "./RefineOption";
|
||||||
|
import { isObject, isString } from "lodash-es";
|
||||||
|
import { type SupportedProvider } from "~/modelProviders/types";
|
||||||
|
|
||||||
|
export const RefinePromptModal = ({
|
||||||
|
variant,
|
||||||
|
onClose,
|
||||||
|
}: {
|
||||||
|
variant: PromptVariant;
|
||||||
|
onClose: () => void;
|
||||||
|
}) => {
|
||||||
|
const utils = api.useContext();
|
||||||
|
|
||||||
|
const providerRefineOptions = refineOptions[variant.modelProvider as SupportedProvider];
|
||||||
|
|
||||||
|
const { mutateAsync: getModifiedPromptMutateAsync, data: refinedPromptFn } =
|
||||||
|
api.promptVariants.getModifiedPromptFn.useMutation();
|
||||||
|
const [instructions, setInstructions] = useState<string>("");
|
||||||
|
|
||||||
|
const [activeRefineOptionLabel, setActiveRefineOptionLabel] = useState<string | undefined>(
|
||||||
|
undefined,
|
||||||
|
);
|
||||||
|
|
||||||
|
const [getModifiedPromptFn, modificationInProgress] = useHandledAsyncCallback(
|
||||||
|
async (label?: string) => {
|
||||||
|
if (!variant.experimentId) return;
|
||||||
|
const updatedInstructions = label
|
||||||
|
? (providerRefineOptions[label] as RefineOptionInfo).instructions
|
||||||
|
: instructions;
|
||||||
|
setActiveRefineOptionLabel(label);
|
||||||
|
await getModifiedPromptMutateAsync({
|
||||||
|
id: variant.id,
|
||||||
|
instructions: updatedInstructions,
|
||||||
|
});
|
||||||
|
},
|
||||||
|
[getModifiedPromptMutateAsync, onClose, variant, instructions, setActiveRefineOptionLabel],
|
||||||
|
);
|
||||||
|
|
||||||
|
const replaceVariantMutation = api.promptVariants.replaceVariant.useMutation();
|
||||||
|
|
||||||
|
const [replaceVariant, replacementInProgress] = useHandledAsyncCallback(async () => {
|
||||||
|
if (
|
||||||
|
!variant.experimentId ||
|
||||||
|
!refinedPromptFn ||
|
||||||
|
(isObject(refinedPromptFn) && "status" in refinedPromptFn)
|
||||||
|
)
|
||||||
|
return;
|
||||||
|
await replaceVariantMutation.mutateAsync({
|
||||||
|
id: variant.id,
|
||||||
|
constructFn: refinedPromptFn,
|
||||||
|
});
|
||||||
|
await utils.promptVariants.list.invalidate();
|
||||||
|
onClose();
|
||||||
|
}, [replaceVariantMutation, variant, onClose, refinedPromptFn]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Modal
|
||||||
|
isOpen
|
||||||
|
onClose={onClose}
|
||||||
|
size={{ base: "xl", sm: "2xl", md: "3xl", lg: "5xl", xl: "7xl" }}
|
||||||
|
>
|
||||||
|
<ModalOverlay />
|
||||||
|
<ModalContent w={1200}>
|
||||||
|
<ModalHeader>
|
||||||
|
<HStack>
|
||||||
|
<Icon as={BsStars} />
|
||||||
|
<Text>Refine with GPT-4</Text>
|
||||||
|
</HStack>
|
||||||
|
</ModalHeader>
|
||||||
|
<ModalCloseButton />
|
||||||
|
<ModalBody maxW="unset">
|
||||||
|
<VStack spacing={8}>
|
||||||
|
<VStack spacing={4}>
|
||||||
|
{Object.keys(providerRefineOptions).length && (
|
||||||
|
<>
|
||||||
|
<SimpleGrid columns={{ base: 1, md: 2 }} spacing={8}>
|
||||||
|
{Object.keys(providerRefineOptions).map((label) => (
|
||||||
|
<RefineOption
|
||||||
|
key={label}
|
||||||
|
label={label}
|
||||||
|
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
||||||
|
icon={providerRefineOptions[label]!.icon}
|
||||||
|
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
||||||
|
desciption={providerRefineOptions[label]!.description}
|
||||||
|
activeLabel={activeRefineOptionLabel}
|
||||||
|
onClick={getModifiedPromptFn}
|
||||||
|
loading={modificationInProgress}
|
||||||
|
/>
|
||||||
|
))}
|
||||||
|
</SimpleGrid>
|
||||||
|
<Text color="gray.500">or</Text>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
<CustomInstructionsInput
|
||||||
|
instructions={instructions}
|
||||||
|
setInstructions={setInstructions}
|
||||||
|
loading={modificationInProgress}
|
||||||
|
onSubmit={getModifiedPromptFn}
|
||||||
|
/>
|
||||||
|
</VStack>
|
||||||
|
<CompareFunctions
|
||||||
|
originalFunction={variant.constructFn}
|
||||||
|
newFunction={isString(refinedPromptFn) ? refinedPromptFn : undefined}
|
||||||
|
maxH="40vh"
|
||||||
|
/>
|
||||||
|
</VStack>
|
||||||
|
</ModalBody>
|
||||||
|
|
||||||
|
<ModalFooter>
|
||||||
|
<HStack spacing={4}>
|
||||||
|
<Button
|
||||||
|
colorScheme="blue"
|
||||||
|
onClick={replaceVariant}
|
||||||
|
minW={24}
|
||||||
|
isDisabled={replacementInProgress || !refinedPromptFn}
|
||||||
|
>
|
||||||
|
{replacementInProgress ? <Spinner boxSize={4} /> : <Text>Accept</Text>}
|
||||||
|
</Button>
|
||||||
|
</HStack>
|
||||||
|
</ModalFooter>
|
||||||
|
</ModalContent>
|
||||||
|
</Modal>
|
||||||
|
);
|
||||||
|
};
|
||||||
287
src/components/RefinePromptModal/refineOptions.ts
Normal file
287
src/components/RefinePromptModal/refineOptions.ts
Normal file
@@ -0,0 +1,287 @@
|
|||||||
|
// Super hacky, but we'll redo the organization when we have more models
|
||||||
|
|
||||||
|
import { type SupportedProvider } from "~/modelProviders/types";
|
||||||
|
import { VscJson } from "react-icons/vsc";
|
||||||
|
import { TfiThought } from "react-icons/tfi";
|
||||||
|
import { type IconType } from "react-icons";
|
||||||
|
|
||||||
|
export type RefineOptionInfo = { icon: IconType; description: string; instructions: string };
|
||||||
|
|
||||||
|
export const refineOptions: Record<SupportedProvider, { [key: string]: RefineOptionInfo }> = {
|
||||||
|
"openai/ChatCompletion": {
|
||||||
|
"Add chain of thought": {
|
||||||
|
icon: VscJson,
|
||||||
|
description: "Asking the model to plan its answer can increase accuracy.",
|
||||||
|
instructions: `Adding chain of thought means asking the model to think about its answer before it gives it to you. This is useful for getting more accurate answers. Do not add an assistant message.
|
||||||
|
|
||||||
|
This is what a prompt looks like before adding chain of thought:
|
||||||
|
|
||||||
|
definePrompt("openai/ChatCompletion", {
|
||||||
|
model: "gpt-4",
|
||||||
|
stream: true,
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
role: "system",
|
||||||
|
content: \`Evaluate sentiment.\`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
role: "user",
|
||||||
|
content: \`This is the user's message: \${scenario.user_message}. Return "positive" or "negative" or "neutral"\`,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
});
|
||||||
|
|
||||||
|
This is what one looks like after adding chain of thought:
|
||||||
|
|
||||||
|
definePrompt("openai/ChatCompletion", {
|
||||||
|
model: "gpt-4",
|
||||||
|
stream: true,
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
role: "system",
|
||||||
|
content: \`Evaluate sentiment.\`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
role: "user",
|
||||||
|
content: \`This is the user's message: \${scenario.user_message}. Return "positive" or "negative" or "neutral". Explain your answer before you give a score, then return the score on a new line.\`,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
});
|
||||||
|
|
||||||
|
Here's another example:
|
||||||
|
|
||||||
|
Before:
|
||||||
|
|
||||||
|
definePrompt("openai/ChatCompletion", {
|
||||||
|
model: "gpt-3.5-turbo",
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
role: "user",
|
||||||
|
content: \`Title: \${scenario.title}
|
||||||
|
Body: \${scenario.body}
|
||||||
|
|
||||||
|
Need: \${scenario.need}
|
||||||
|
|
||||||
|
Rate likelihood on 1-3 scale.\`,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
temperature: 0,
|
||||||
|
functions: [
|
||||||
|
{
|
||||||
|
name: "score_post",
|
||||||
|
parameters: {
|
||||||
|
type: "object",
|
||||||
|
properties: {
|
||||||
|
score: {
|
||||||
|
type: "number",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
function_call: {
|
||||||
|
name: "score_post",
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
After:
|
||||||
|
|
||||||
|
definePrompt("openai/ChatCompletion", {
|
||||||
|
model: "gpt-3.5-turbo",
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
role: "user",
|
||||||
|
content: \`Title: \${scenario.title}
|
||||||
|
Body: \${scenario.body}
|
||||||
|
|
||||||
|
Need: \${scenario.need}
|
||||||
|
|
||||||
|
Rate likelihood on 1-3 scale. Provide an explanation, but always provide a score afterward.\`,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
temperature: 0,
|
||||||
|
functions: [
|
||||||
|
{
|
||||||
|
name: "score_post",
|
||||||
|
parameters: {
|
||||||
|
type: "object",
|
||||||
|
properties: {
|
||||||
|
explanation: {
|
||||||
|
type: "string",
|
||||||
|
}
|
||||||
|
score: {
|
||||||
|
type: "number",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
function_call: {
|
||||||
|
name: "score_post",
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
Add chain of thought to the original prompt.`,
|
||||||
|
},
|
||||||
|
"Convert to function call": {
|
||||||
|
icon: TfiThought,
|
||||||
|
description: "Use function calls to get output from the model in a more structured way.",
|
||||||
|
instructions: `OpenAI functions are a specialized way for an LLM to return output.
|
||||||
|
|
||||||
|
This is what a prompt looks like before adding a function:
|
||||||
|
|
||||||
|
definePrompt("openai/ChatCompletion", {
|
||||||
|
model: "gpt-4",
|
||||||
|
stream: true,
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
role: "system",
|
||||||
|
content: \`Evaluate sentiment.\`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
role: "user",
|
||||||
|
content: \`This is the user's message: \${scenario.user_message}. Return "positive" or "negative" or "neutral"\`,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
});
|
||||||
|
|
||||||
|
This is what one looks like after adding a function:
|
||||||
|
|
||||||
|
definePrompt("openai/ChatCompletion", {
|
||||||
|
model: "gpt-4",
|
||||||
|
stream: true,
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
role: "system",
|
||||||
|
content: "Evaluate sentiment.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
role: "user",
|
||||||
|
content: scenario.user_message,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
functions: [
|
||||||
|
{
|
||||||
|
name: "extract_sentiment",
|
||||||
|
parameters: {
|
||||||
|
type: "object", // parameters must always be an object with a properties key
|
||||||
|
properties: { // properties key is required
|
||||||
|
sentiment: {
|
||||||
|
type: "string",
|
||||||
|
description: "one of positive/negative/neutral",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
function_call: {
|
||||||
|
name: "extract_sentiment",
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
Here's another example of adding a function:
|
||||||
|
|
||||||
|
Before:
|
||||||
|
|
||||||
|
definePrompt("openai/ChatCompletion", {
|
||||||
|
model: "gpt-3.5-turbo",
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
role: "user",
|
||||||
|
content: \`Here is the title and body of a reddit post I am interested in:
|
||||||
|
|
||||||
|
title: \${scenario.title}
|
||||||
|
body: \${scenario.body}
|
||||||
|
|
||||||
|
On a scale from 1 to 3, how likely is it that the person writing this post has the following need? If you are not sure, make your best guess, or answer 1.
|
||||||
|
|
||||||
|
Need: \${scenario.need}
|
||||||
|
|
||||||
|
Answer one integer between 1 and 3.\`,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
temperature: 0,
|
||||||
|
});
|
||||||
|
|
||||||
|
After:
|
||||||
|
|
||||||
|
definePrompt("openai/ChatCompletion", {
|
||||||
|
model: "gpt-3.5-turbo",
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
role: "user",
|
||||||
|
content: \`Title: \${scenario.title}
|
||||||
|
Body: \${scenario.body}
|
||||||
|
|
||||||
|
Need: \${scenario.need}
|
||||||
|
|
||||||
|
Rate likelihood on 1-3 scale.\`,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
temperature: 0,
|
||||||
|
functions: [
|
||||||
|
{
|
||||||
|
name: "score_post",
|
||||||
|
parameters: {
|
||||||
|
type: "object",
|
||||||
|
properties: {
|
||||||
|
score: {
|
||||||
|
type: "number",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
function_call: {
|
||||||
|
name: "score_post",
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
Another example
|
||||||
|
|
||||||
|
Before:
|
||||||
|
|
||||||
|
definePrompt("openai/ChatCompletion", {
|
||||||
|
model: "gpt-3.5-turbo",
|
||||||
|
stream: true,
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
role: "system",
|
||||||
|
content: \`Write 'Start experimenting!' in \${scenario.language}\`,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
});
|
||||||
|
|
||||||
|
After:
|
||||||
|
|
||||||
|
definePrompt("openai/ChatCompletion", {
|
||||||
|
model: "gpt-3.5-turbo",
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
role: "system",
|
||||||
|
content: \`Write 'Start experimenting!' in \${scenario.language}\`,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
functions: [
|
||||||
|
{
|
||||||
|
name: "write_in_language",
|
||||||
|
parameters: {
|
||||||
|
type: "object",
|
||||||
|
properties: {
|
||||||
|
text: {
|
||||||
|
type: "string",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
function_call: {
|
||||||
|
name: "write_in_language",
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
Add an OpenAI function that takes one or more nested parameters that match the expected output from this prompt.`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"replicate/llama2": {},
|
||||||
|
};
|
||||||
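Both refinement options above follow the same shape, so a minimal TypeScript sketch of that shape may help when adding new options (the RefinementAction type name is an assumption, not taken from this diff):

import { type IconType } from "react-icons";

// Assumed name for illustration: each refinement option pairs menu metadata
// with the instruction text sent to the model for that option.
type RefinementAction = {
  icon: IconType;        // react-icons component shown next to the option
  description: string;   // one-line summary displayed to the user
  instructions: string;  // before/after examples the model uses to rewrite the prompt
};

type RefinementActions = Record<string, RefinementAction>;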
130 src/components/VariantHeader/VariantHeader.tsx Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
import { useState, type DragEvent } from "react";
|
||||||
|
import { type PromptVariant } from "../OutputsTable/types";
|
||||||
|
import { api } from "~/utils/api";
|
||||||
|
import { RiDraggable } from "react-icons/ri";
|
||||||
|
import { useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
|
||||||
|
import { HStack, Icon, Text, GridItem, type GridItemProps } from "@chakra-ui/react";
|
||||||
|
import { cellPadding, headerMinHeight } from "../constants";
|
||||||
|
import AutoResizeTextArea from "../AutoResizeTextArea";
|
||||||
|
import { stickyHeaderStyle } from "../OutputsTable/styles";
|
||||||
|
import VariantHeaderMenuButton from "./VariantHeaderMenuButton";
|
||||||
|
|
||||||
|
export default function VariantHeader(
|
||||||
|
allProps: {
|
||||||
|
variant: PromptVariant;
|
||||||
|
canHide: boolean;
|
||||||
|
} & GridItemProps,
|
||||||
|
) {
|
||||||
|
const { variant, canHide, ...gridItemProps } = allProps;
|
||||||
|
const { canModify } = useExperimentAccess();
|
||||||
|
const utils = api.useContext();
|
||||||
|
const [isDragTarget, setIsDragTarget] = useState(false);
|
||||||
|
const [isInputHovered, setIsInputHovered] = useState(false);
|
||||||
|
const [label, setLabel] = useState(variant.label);
|
||||||
|
|
||||||
|
const updateMutation = api.promptVariants.update.useMutation();
|
||||||
|
const [onSaveLabel] = useHandledAsyncCallback(async () => {
|
||||||
|
if (label && label !== variant.label) {
|
||||||
|
await updateMutation.mutateAsync({
|
||||||
|
id: variant.id,
|
||||||
|
updates: { label: label },
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}, [updateMutation, variant.id, variant.label, label]);
|
||||||
|
|
||||||
|
const reorderMutation = api.promptVariants.reorder.useMutation();
|
||||||
|
const [onReorder] = useHandledAsyncCallback(
|
||||||
|
async (e: DragEvent<HTMLDivElement>) => {
|
||||||
|
e.preventDefault();
|
||||||
|
setIsDragTarget(false);
|
||||||
|
const draggedId = e.dataTransfer.getData("text/plain");
|
||||||
|
const droppedId = variant.id;
|
||||||
|
if (!draggedId || !droppedId || draggedId === droppedId) return;
|
||||||
|
await reorderMutation.mutateAsync({
|
||||||
|
draggedId,
|
||||||
|
droppedId,
|
||||||
|
});
|
||||||
|
await utils.promptVariants.list.invalidate();
|
||||||
|
},
|
||||||
|
[reorderMutation, variant.id],
|
||||||
|
);
|
||||||
|
|
||||||
|
const [menuOpen, setMenuOpen] = useState(false);
|
||||||
|
|
||||||
|
if (!canModify) {
|
||||||
|
return (
|
||||||
|
<GridItem padding={0} sx={stickyHeaderStyle} borderTopWidth={1} {...gridItemProps}>
|
||||||
|
<Text fontSize={16} fontWeight="bold" px={cellPadding.x} py={cellPadding.y}>
|
||||||
|
{variant.label}
|
||||||
|
</Text>
|
||||||
|
</GridItem>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<GridItem
|
||||||
|
padding={0}
|
||||||
|
sx={{
|
||||||
|
...stickyHeaderStyle,
|
||||||
|
// Ensure that the menu always appears above the sticky header of other variants
|
||||||
|
zIndex: menuOpen ? "dropdown" : stickyHeaderStyle.zIndex,
|
||||||
|
}}
|
||||||
|
borderTopWidth={1}
|
||||||
|
{...gridItemProps}
|
||||||
|
>
|
||||||
|
<HStack
|
||||||
|
spacing={4}
|
||||||
|
alignItems="flex-start"
|
||||||
|
minH={headerMinHeight}
|
||||||
|
draggable={!isInputHovered}
|
||||||
|
onDragStart={(e) => {
|
||||||
|
e.dataTransfer.setData("text/plain", variant.id);
|
||||||
|
e.currentTarget.style.opacity = "0.4";
|
||||||
|
}}
|
||||||
|
onDragEnd={(e) => {
|
||||||
|
e.currentTarget.style.opacity = "1";
|
||||||
|
}}
|
||||||
|
onDragOver={(e) => {
|
||||||
|
e.preventDefault();
|
||||||
|
setIsDragTarget(true);
|
||||||
|
}}
|
||||||
|
onDragLeave={() => {
|
||||||
|
setIsDragTarget(false);
|
||||||
|
}}
|
||||||
|
onDrop={onReorder}
|
||||||
|
backgroundColor={isDragTarget ? "gray.100" : "transparent"}
|
||||||
|
>
|
||||||
|
<Icon
|
||||||
|
as={RiDraggable}
|
||||||
|
boxSize={6}
|
||||||
|
mt={2}
|
||||||
|
color="gray.400"
|
||||||
|
_hover={{ color: "gray.800", cursor: "pointer" }}
|
||||||
|
/>
|
||||||
|
<AutoResizeTextArea
|
||||||
|
size="sm"
|
||||||
|
value={label}
|
||||||
|
onChange={(e) => setLabel(e.target.value)}
|
||||||
|
onBlur={onSaveLabel}
|
||||||
|
placeholder="Variant Name"
|
||||||
|
borderWidth={1}
|
||||||
|
borderColor="transparent"
|
||||||
|
fontWeight="bold"
|
||||||
|
fontSize={16}
|
||||||
|
_hover={{ borderColor: "gray.300" }}
|
||||||
|
_focus={{ borderColor: "blue.500", outline: "none" }}
|
||||||
|
flex={1}
|
||||||
|
px={cellPadding.x}
|
||||||
|
onMouseEnter={() => setIsInputHovered(true)}
|
||||||
|
onMouseLeave={() => setIsInputHovered(false)}
|
||||||
|
/>
|
||||||
|
<VariantHeaderMenuButton
|
||||||
|
variant={variant}
|
||||||
|
canHide={canHide}
|
||||||
|
menuOpen={menuOpen}
|
||||||
|
setMenuOpen={setMenuOpen}
|
||||||
|
/>
|
||||||
|
</HStack>
|
||||||
|
</GridItem>
|
||||||
|
);
|
||||||
|
}
|
||||||
108 src/components/VariantHeader/VariantHeaderMenuButton.tsx Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
import { type PromptVariant } from "../OutputsTable/types";
|
||||||
|
import { api } from "~/utils/api";
|
||||||
|
import { useHandledAsyncCallback } from "~/utils/hooks";
|
||||||
|
import {
|
||||||
|
Button,
|
||||||
|
Icon,
|
||||||
|
Menu,
|
||||||
|
MenuButton,
|
||||||
|
MenuItem,
|
||||||
|
MenuList,
|
||||||
|
MenuDivider,
|
||||||
|
Text,
|
||||||
|
Spinner,
|
||||||
|
} from "@chakra-ui/react";
|
||||||
|
import { BsFillTrashFill, BsGear, BsStars } from "react-icons/bs";
|
||||||
|
import { FaRegClone } from "react-icons/fa";
|
||||||
|
import { useState } from "react";
|
||||||
|
import { RefinePromptModal } from "../RefinePromptModal/RefinePromptModal";
|
||||||
|
import { RiExchangeFundsFill } from "react-icons/ri";
|
||||||
|
import { ChangeModelModal } from "../ChangeModelModal/ChangeModelModal";
|
||||||
|
|
||||||
|
export default function VariantHeaderMenuButton({
|
||||||
|
variant,
|
||||||
|
canHide,
|
||||||
|
menuOpen,
|
||||||
|
setMenuOpen,
|
||||||
|
}: {
|
||||||
|
variant: PromptVariant;
|
||||||
|
canHide: boolean;
|
||||||
|
menuOpen: boolean;
|
||||||
|
setMenuOpen: (open: boolean) => void;
|
||||||
|
}) {
|
||||||
|
const utils = api.useContext();
|
||||||
|
|
||||||
|
const duplicateMutation = api.promptVariants.create.useMutation();
|
||||||
|
|
||||||
|
const [duplicateVariant, duplicationInProgress] = useHandledAsyncCallback(async () => {
|
||||||
|
await duplicateMutation.mutateAsync({
|
||||||
|
experimentId: variant.experimentId,
|
||||||
|
variantId: variant.id,
|
||||||
|
});
|
||||||
|
await utils.promptVariants.list.invalidate();
|
||||||
|
}, [duplicateMutation, variant.experimentId, variant.id]);
|
||||||
|
|
||||||
|
const hideMutation = api.promptVariants.hide.useMutation();
|
||||||
|
const [onHide] = useHandledAsyncCallback(async () => {
|
||||||
|
await hideMutation.mutateAsync({
|
||||||
|
id: variant.id,
|
||||||
|
});
|
||||||
|
await utils.promptVariants.list.invalidate();
|
||||||
|
}, [hideMutation, variant.id]);
|
||||||
|
|
||||||
|
const [changeModelModalOpen, setChangeModelModalOpen] = useState(false);
|
||||||
|
const [refinePromptModalOpen, setRefinePromptModalOpen] = useState(false);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
<Menu isOpen={menuOpen} onOpen={() => setMenuOpen(true)} onClose={() => setMenuOpen(false)}>
|
||||||
|
{duplicationInProgress ? (
|
||||||
|
<Spinner boxSize={4} mx={3} my={3} />
|
||||||
|
) : (
|
||||||
|
<MenuButton>
|
||||||
|
<Button variant="ghost">
|
||||||
|
<Icon as={BsGear} />
|
||||||
|
</Button>
|
||||||
|
</MenuButton>
|
||||||
|
)}
|
||||||
|
|
||||||
|
<MenuList mt={-3} fontSize="md">
|
||||||
|
<MenuItem icon={<Icon as={FaRegClone} boxSize={4} w={5} />} onClick={duplicateVariant}>
|
||||||
|
Duplicate
|
||||||
|
</MenuItem>
|
||||||
|
<MenuItem
|
||||||
|
icon={<Icon as={RiExchangeFundsFill} boxSize={5} />}
|
||||||
|
onClick={() => setChangeModelModalOpen(true)}
|
||||||
|
>
|
||||||
|
Change Model
|
||||||
|
</MenuItem>
|
||||||
|
<MenuItem
|
||||||
|
icon={<Icon as={BsStars} boxSize={5} />}
|
||||||
|
onClick={() => setRefinePromptModalOpen(true)}
|
||||||
|
>
|
||||||
|
Refine
|
||||||
|
</MenuItem>
|
||||||
|
{canHide && (
|
||||||
|
<>
|
||||||
|
<MenuDivider />
|
||||||
|
<MenuItem
|
||||||
|
onClick={onHide}
|
||||||
|
icon={<Icon as={BsFillTrashFill} boxSize={5} />}
|
||||||
|
color="red.600"
|
||||||
|
_hover={{ backgroundColor: "red.50" }}
|
||||||
|
>
|
||||||
|
<Text>Hide</Text>
|
||||||
|
</MenuItem>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
</MenuList>
|
||||||
|
</Menu>
|
||||||
|
{changeModelModalOpen && (
|
||||||
|
<ChangeModelModal variant={variant} onClose={() => setChangeModelModalOpen(false)} />
|
||||||
|
)}
|
||||||
|
{refinePromptModalOpen && (
|
||||||
|
<RefinePromptModal variant={variant} onClose={() => setRefinePromptModalOpen(false)} />
|
||||||
|
)}
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -1,17 +1,11 @@
|
|||||||
import {
|
import { HStack, Icon, VStack, Text, Divider, Spinner, AspectRatio } from "@chakra-ui/react";
|
||||||
Card,
|
|
||||||
CardBody,
|
|
||||||
HStack,
|
|
||||||
Icon,
|
|
||||||
VStack,
|
|
||||||
Text,
|
|
||||||
CardHeader,
|
|
||||||
Divider,
|
|
||||||
Box,
|
|
||||||
} from "@chakra-ui/react";
|
|
||||||
import { RiFlaskLine } from "react-icons/ri";
|
import { RiFlaskLine } from "react-icons/ri";
|
||||||
import { formatTimePast } from "~/utils/dayjs";
|
import { formatTimePast } from "~/utils/dayjs";
|
||||||
|
import Link from "next/link";
|
||||||
import { useRouter } from "next/router";
|
import { useRouter } from "next/router";
|
||||||
|
import { BsPlusSquare } from "react-icons/bs";
|
||||||
|
import { api } from "~/utils/api";
|
||||||
|
import { useHandledAsyncCallback } from "~/utils/hooks";
|
||||||
|
|
||||||
type ExperimentData = {
|
type ExperimentData = {
|
||||||
testScenarioCount: number;
|
testScenarioCount: number;
|
||||||
@@ -24,47 +18,42 @@ type ExperimentData = {
|
|||||||
};
|
};
|
||||||
|
|
||||||
export const ExperimentCard = ({ exp }: { exp: ExperimentData }) => {
|
export const ExperimentCard = ({ exp }: { exp: ExperimentData }) => {
|
||||||
const router = useRouter();
|
|
||||||
return (
|
return (
|
||||||
<Box
|
<AspectRatio ratio={1.2} w="full">
|
||||||
as={Card}
|
<VStack
|
||||||
variant="elevated"
|
as={Link}
|
||||||
|
href={{ pathname: "/experiments/[id]", query: { id: exp.id } }}
|
||||||
bg="gray.50"
|
bg="gray.50"
|
||||||
_hover={{ bg: "gray.100" }}
|
_hover={{ bg: "gray.100" }}
|
||||||
transition="background 0.2s"
|
transition="background 0.2s"
|
||||||
cursor="pointer"
|
cursor="pointer"
|
||||||
onClick={(e) => {
|
borderColor="gray.200"
|
||||||
e.preventDefault();
|
borderWidth={1}
|
||||||
void router.push({ pathname: "/experiments/[id]", query: { id: exp.id } }, undefined, {
|
p={4}
|
||||||
shallow: true,
|
justify="space-between"
|
||||||
});
|
|
||||||
}}
|
|
||||||
>
|
>
|
||||||
<CardHeader>
|
<HStack w="full" color="gray.700" justify="center">
|
||||||
<HStack w="full" color="gray.700">
|
|
||||||
<Icon as={RiFlaskLine} boxSize={4} />
|
<Icon as={RiFlaskLine} boxSize={4} />
|
||||||
<Text fontWeight="bold">{exp.label}</Text>
|
<Text fontWeight="bold">{exp.label}</Text>
|
||||||
</HStack>
|
</HStack>
|
||||||
</CardHeader>
|
<HStack h="full" spacing={4} flex={1} align="center">
|
||||||
<CardBody>
|
|
||||||
<HStack w="full" mb={8} spacing={4}>
|
|
||||||
<CountLabel label="Variants" count={exp.promptVariantCount} />
|
<CountLabel label="Variants" count={exp.promptVariantCount} />
|
||||||
<Divider h={12} orientation="vertical" />
|
<Divider h={12} orientation="vertical" />
|
||||||
<CountLabel label="Scenarios" count={exp.testScenarioCount} />
|
<CountLabel label="Scenarios" count={exp.testScenarioCount} />
|
||||||
</HStack>
|
</HStack>
|
||||||
<HStack w="full" color="gray.500" fontSize="xs">
|
<HStack w="full" color="gray.500" fontSize="xs" textAlign="center">
|
||||||
<Text>Created {formatTimePast(exp.createdAt)}</Text>
|
<Text flex={1}>Created {formatTimePast(exp.createdAt)}</Text>
|
||||||
<Divider h={4} orientation="vertical" />
|
<Divider h={4} orientation="vertical" />
|
||||||
<Text>Updated {formatTimePast(exp.updatedAt)}</Text>
|
<Text flex={1}>Updated {formatTimePast(exp.updatedAt)}</Text>
|
||||||
</HStack>
|
</HStack>
|
||||||
</CardBody>
|
</VStack>
|
||||||
</Box>
|
</AspectRatio>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
const CountLabel = ({ label, count }: { label: string; count: number }) => {
|
const CountLabel = ({ label, count }: { label: string; count: number }) => {
|
||||||
return (
|
return (
|
||||||
<VStack alignItems="flex-start">
|
<VStack alignItems="center" flex={1}>
|
||||||
<Text color="gray.500" fontWeight="bold">
|
<Text color="gray.500" fontWeight="bold">
|
||||||
{label}
|
{label}
|
||||||
</Text>
|
</Text>
|
||||||
@@ -74,3 +63,33 @@ const CountLabel = ({ label, count }: { label: string; count: number }) => {
|
|||||||
</VStack>
|
</VStack>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export const NewExperimentCard = () => {
|
||||||
|
const router = useRouter();
|
||||||
|
const createMutation = api.experiments.create.useMutation();
|
||||||
|
const [createExperiment, isLoading] = useHandledAsyncCallback(async () => {
|
||||||
|
const newExperiment = await createMutation.mutateAsync({ label: "New Experiment" });
|
||||||
|
await router.push({ pathname: "/experiments/[id]", query: { id: newExperiment.id } });
|
||||||
|
}, [createMutation, router]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<AspectRatio ratio={1.2} w="full">
|
||||||
|
<VStack
|
||||||
|
align="center"
|
||||||
|
justify="center"
|
||||||
|
_hover={{ cursor: "pointer", bg: "gray.50" }}
|
||||||
|
transition="background 0.2s"
|
||||||
|
cursor="pointer"
|
||||||
|
borderColor="gray.200"
|
||||||
|
borderWidth={1}
|
||||||
|
p={4}
|
||||||
|
onClick={createExperiment}
|
||||||
|
>
|
||||||
|
<Icon as={isLoading ? Spinner : BsPlusSquare} boxSize={8} />
|
||||||
|
<Text display={{ base: "none", md: "block" }} ml={2}>
|
||||||
|
New Experiment
|
||||||
|
</Text>
|
||||||
|
</VStack>
|
||||||
|
</AspectRatio>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|||||||
@@ -1,31 +0,0 @@
|
|||||||
import { Icon, Button, Spinner, Text, type ButtonProps } from "@chakra-ui/react";
|
|
||||||
import { api } from "~/utils/api";
|
|
||||||
import { useRouter } from "next/router";
|
|
||||||
import { BsPlusSquare } from "react-icons/bs";
|
|
||||||
import { useHandledAsyncCallback } from "~/utils/hooks";
|
|
||||||
|
|
||||||
export const NewExperimentButton = (props: ButtonProps) => {
|
|
||||||
const router = useRouter();
|
|
||||||
const utils = api.useContext();
|
|
||||||
const createMutation = api.experiments.create.useMutation();
|
|
||||||
const [createExperiment, isLoading] = useHandledAsyncCallback(async () => {
|
|
||||||
const newExperiment = await createMutation.mutateAsync({ label: "New Experiment" });
|
|
||||||
await utils.experiments.list.invalidate();
|
|
||||||
await router.push({ pathname: "/experiments/[id]", query: { id: newExperiment.id } });
|
|
||||||
}, [createMutation, router]);
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Button
|
|
||||||
onClick={createExperiment}
|
|
||||||
display="flex"
|
|
||||||
alignItems="center"
|
|
||||||
variant={{ base: "solid", md: "ghost" }}
|
|
||||||
{...props}
|
|
||||||
>
|
|
||||||
<Icon as={isLoading ? Spinner : BsPlusSquare} boxSize={4} />
|
|
||||||
<Text display={{ base: "none", md: "block" }} ml={2}>
|
|
||||||
New Experiment
|
|
||||||
</Text>
|
|
||||||
</Button>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
@@ -1,84 +1,100 @@
|
|||||||
|
import { useState, useEffect } from "react";
|
||||||
import {
|
import {
|
||||||
Heading,
|
Heading,
|
||||||
VStack,
|
VStack,
|
||||||
Icon,
|
Icon,
|
||||||
HStack,
|
HStack,
|
||||||
Image,
|
Image,
|
||||||
Grid,
|
|
||||||
GridItem,
|
|
||||||
Divider,
|
|
||||||
Text,
|
Text,
|
||||||
Box,
|
Box,
|
||||||
type BoxProps,
|
type BoxProps,
|
||||||
type LinkProps,
|
type LinkProps,
|
||||||
Link,
|
Link,
|
||||||
|
Flex,
|
||||||
} from "@chakra-ui/react";
|
} from "@chakra-ui/react";
|
||||||
import Head from "next/head";
|
import Head from "next/head";
|
||||||
import { BsGithub, BsTwitter } from "react-icons/bs";
|
import { BsGithub, BsPersonCircle } from "react-icons/bs";
|
||||||
import { useRouter } from "next/router";
|
import { useRouter } from "next/router";
|
||||||
import PublicPlaygroundWarning from "../PublicPlaygroundWarning";
|
|
||||||
import { type IconType } from "react-icons";
|
import { type IconType } from "react-icons";
|
||||||
import { RiFlaskLine } from "react-icons/ri";
|
import { RiFlaskLine } from "react-icons/ri";
|
||||||
import { useState, useEffect } from "react";
|
import { signIn, useSession } from "next-auth/react";
|
||||||
|
import UserMenu from "./UserMenu";
|
||||||
|
|
||||||
type IconLinkProps = BoxProps & LinkProps & { label: string; icon: IconType; href: string };
|
type IconLinkProps = BoxProps & LinkProps & { label?: string; icon: IconType };
|
||||||
|
|
||||||
const IconLink = ({ icon, label, href, target, color, ...props }: IconLinkProps) => {
|
const IconLink = ({ icon, label, href, target, color, ...props }: IconLinkProps) => {
|
||||||
const isActive = useRouter().pathname.startsWith(href);
|
const router = useRouter();
|
||||||
|
const isActive = href && router.pathname.startsWith(href);
|
||||||
return (
|
return (
|
||||||
<Box
|
<HStack
|
||||||
|
w="full"
|
||||||
|
p={4}
|
||||||
|
color={color}
|
||||||
as={Link}
|
as={Link}
|
||||||
href={href}
|
href={href}
|
||||||
target={target}
|
target={target}
|
||||||
w="full"
|
bgColor={isActive ? "gray.200" : "transparent"}
|
||||||
bgColor={isActive ? "gray.300" : "transparent"}
|
_hover={{ bgColor: "gray.200", textDecoration: "none" }}
|
||||||
_hover={{ bgColor: "gray.300" }}
|
|
||||||
py={4}
|
|
||||||
justifyContent="start"
|
justifyContent="start"
|
||||||
cursor="pointer"
|
cursor="pointer"
|
||||||
{...props}
|
{...props}
|
||||||
>
|
>
|
||||||
<HStack w="full" px={4} color={color}>
|
|
||||||
<Icon as={icon} boxSize={6} mr={2} />
|
<Icon as={icon} boxSize={6} mr={2} />
|
||||||
<Text fontWeight="bold">{label}</Text>
|
<Text fontWeight="bold" fontSize="sm">
|
||||||
|
{label}
|
||||||
|
</Text>
|
||||||
</HStack>
|
</HStack>
|
||||||
</Box>
|
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const Divider = () => <Box h="1px" bgColor="gray.200" />;
|
||||||
|
|
||||||
const NavSidebar = () => {
|
const NavSidebar = () => {
|
||||||
|
const user = useSession().data;
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<VStack align="stretch" bgColor="gray.100" py={2} pb={0} height="100%">
|
<VStack
|
||||||
<Link href="/" w="full" _hover={{ textDecoration: "none" }}>
|
align="stretch"
|
||||||
<HStack spacing={0} pl="3">
|
bgColor="gray.100"
|
||||||
<Image src="/logo.svg" alt="" w={8} h={8} />
|
py={2}
|
||||||
<Heading size="md" p={2} pl={{ base: 16, md: 2 }}>
|
pb={0}
|
||||||
|
height="100%"
|
||||||
|
w={{ base: "56px", md: "200px" }}
|
||||||
|
overflow="hidden"
|
||||||
|
>
|
||||||
|
<HStack as={Link} href="/" _hover={{ textDecoration: "none" }} spacing={0} px={4} py={2}>
|
||||||
|
<Image src="/logo.svg" alt="" boxSize={6} mr={4} />
|
||||||
|
<Heading size="md" fontFamily="inconsolata, monospace">
|
||||||
OpenPipe
|
OpenPipe
|
||||||
</Heading>
|
</Heading>
|
||||||
</HStack>
|
</HStack>
|
||||||
</Link>
|
|
||||||
<Divider />
|
|
||||||
<VStack spacing={0} align="flex-start" overflowY="auto" overflowX="hidden" flex={1}>
|
<VStack spacing={0} align="flex-start" overflowY="auto" overflowX="hidden" flex={1}>
|
||||||
|
{user != null && (
|
||||||
|
<>
|
||||||
<IconLink icon={RiFlaskLine} label="Experiments" href="/experiments" />
|
<IconLink icon={RiFlaskLine} label="Experiments" href="/experiments" />
|
||||||
</VStack>
|
</>
|
||||||
<Divider />
|
)}
|
||||||
<VStack w="full" spacing={0} pb={2}>
|
{user === null && (
|
||||||
<IconLink
|
<IconLink
|
||||||
icon={BsGithub}
|
icon={BsPersonCircle}
|
||||||
label="GitHub"
|
label="Sign In"
|
||||||
|
onClick={() => {
|
||||||
|
signIn("github").catch(console.error);
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
</VStack>
|
||||||
|
{user ? <UserMenu user={user} /> : <Divider />}
|
||||||
|
<VStack spacing={0} align="center">
|
||||||
|
<Link
|
||||||
href="https://github.com/openpipe/openpipe"
|
href="https://github.com/openpipe/openpipe"
|
||||||
target="_blank"
|
target="_blank"
|
||||||
color="gray.500"
|
color="gray.500"
|
||||||
_hover={{ color: "gray.800" }}
|
_hover={{ color: "gray.800" }}
|
||||||
/>
|
p={2}
|
||||||
<IconLink
|
>
|
||||||
icon={BsTwitter}
|
<Icon as={BsGithub} boxSize={6} />
|
||||||
label="Twitter"
|
</Link>
|
||||||
href="https://twitter.com/corbtt"
|
|
||||||
target="_blank"
|
|
||||||
color="gray.500"
|
|
||||||
_hover={{ color: "gray.800" }}
|
|
||||||
/>
|
|
||||||
</VStack>
|
</VStack>
|
||||||
</VStack>
|
</VStack>
|
||||||
);
|
);
|
||||||
@@ -105,25 +121,14 @@ export default function AppShell(props: { children: React.ReactNode; title?: str
|
|||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Grid
|
<Flex h={vh} w="100vw">
|
||||||
h={vh}
|
|
||||||
w="100vw"
|
|
||||||
templateColumns={{ base: "56px minmax(0, 1fr)", md: "200px minmax(0, 1fr)" }}
|
|
||||||
templateRows="max-content 1fr"
|
|
||||||
templateAreas={'"warning warning"\n"sidebar main"'}
|
|
||||||
>
|
|
||||||
<Head>
|
<Head>
|
||||||
<title>{props.title ? `${props.title} | OpenPipe` : "OpenPipe"}</title>
|
<title>{props.title ? `${props.title} | OpenPipe` : "OpenPipe"}</title>
|
||||||
</Head>
|
</Head>
|
||||||
<GridItem area="warning">
|
|
||||||
<PublicPlaygroundWarning />
|
|
||||||
</GridItem>
|
|
||||||
<GridItem area="sidebar" overflow="hidden">
|
|
||||||
<NavSidebar />
|
<NavSidebar />
|
||||||
</GridItem>
|
<Box h="100%" flex={1} overflowY="auto">
|
||||||
<GridItem area="main" overflowY="auto">
|
|
||||||
{props.children}
|
{props.children}
|
||||||
</GridItem>
|
</Box>
|
||||||
</Grid>
|
</Flex>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
74 src/components/nav/UserMenu.tsx Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
import {
|
||||||
|
HStack,
|
||||||
|
Icon,
|
||||||
|
Image,
|
||||||
|
VStack,
|
||||||
|
Text,
|
||||||
|
Popover,
|
||||||
|
PopoverTrigger,
|
||||||
|
PopoverContent,
|
||||||
|
Link,
|
||||||
|
} from "@chakra-ui/react";
|
||||||
|
import { type Session } from "next-auth";
|
||||||
|
import { signOut } from "next-auth/react";
|
||||||
|
import { BsBoxArrowRight, BsChevronRight, BsPersonCircle } from "react-icons/bs";
|
||||||
|
|
||||||
|
export default function UserMenu({ user }: { user: Session }) {
|
||||||
|
const profileImage = user.user.image ? (
|
||||||
|
<Image src={user.user.image} alt="profile picture" boxSize={8} borderRadius="50%" />
|
||||||
|
) : (
|
||||||
|
<Icon as={BsPersonCircle} boxSize={6} />
|
||||||
|
);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
<Popover placement="right">
|
||||||
|
<PopoverTrigger>
|
||||||
|
<HStack
|
||||||
|
// Weird values to make mobile look right; can clean up when we make the sidebar disappear on mobile
|
||||||
|
px={3}
|
||||||
|
spacing={3}
|
||||||
|
py={2}
|
||||||
|
borderColor={"gray.200"}
|
||||||
|
borderTopWidth={1}
|
||||||
|
borderBottomWidth={1}
|
||||||
|
cursor="pointer"
|
||||||
|
_hover={{
|
||||||
|
bgColor: "gray.200",
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{profileImage}
|
||||||
|
<VStack spacing={0} align="start" flex={1} flexShrink={1}>
|
||||||
|
<Text fontWeight="bold" fontSize="sm">
|
||||||
|
{user.user.name}
|
||||||
|
</Text>
|
||||||
|
<Text color="gray.500" fontSize="xs">
|
||||||
|
{user.user.email}
|
||||||
|
</Text>
|
||||||
|
</VStack>
|
||||||
|
<Icon as={BsChevronRight} boxSize={4} color="gray.500" />
|
||||||
|
</HStack>
|
||||||
|
</PopoverTrigger>
|
||||||
|
<PopoverContent _focusVisible={{ boxShadow: "unset", outline: "unset" }} maxW="200px">
|
||||||
|
<VStack align="stretch" spacing={0}>
|
||||||
|
{/* sign out */}
|
||||||
|
<HStack
|
||||||
|
as={Link}
|
||||||
|
onClick={() => {
|
||||||
|
signOut().catch(console.error);
|
||||||
|
}}
|
||||||
|
px={4}
|
||||||
|
py={2}
|
||||||
|
spacing={4}
|
||||||
|
color="gray.500"
|
||||||
|
fontSize="sm"
|
||||||
|
>
|
||||||
|
<Icon as={BsBoxArrowRight} boxSize={6} />
|
||||||
|
<Text>Sign out</Text>
|
||||||
|
</HStack>
|
||||||
|
</VStack>
|
||||||
|
</PopoverContent>
|
||||||
|
</Popover>
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -20,7 +20,6 @@ export const CostTooltip = ({
|
|||||||
color="gray.800"
|
color="gray.800"
|
||||||
bgColor="gray.50"
|
bgColor="gray.50"
|
||||||
borderWidth={1}
|
borderWidth={1}
|
||||||
py={2}
|
|
||||||
hasArrow
|
hasArrow
|
||||||
shouldWrapChildren
|
shouldWrapChildren
|
||||||
label={
|
label={
|
||||||
|
|||||||
18 src/env.mjs
@@ -9,7 +9,15 @@ export const env = createEnv({
|
|||||||
server: {
|
server: {
|
||||||
DATABASE_URL: z.string().url(),
|
DATABASE_URL: z.string().url(),
|
||||||
NODE_ENV: z.enum(["development", "test", "production"]).default("development"),
|
NODE_ENV: z.enum(["development", "test", "production"]).default("development"),
|
||||||
|
RESTRICT_PRISMA_LOGS: z
|
||||||
|
.string()
|
||||||
|
.optional()
|
||||||
|
.default("false")
|
||||||
|
.transform((val) => val.toLowerCase() === "true"),
|
||||||
|
GITHUB_CLIENT_ID: z.string().min(1),
|
||||||
|
GITHUB_CLIENT_SECRET: z.string().min(1),
|
||||||
OPENAI_API_KEY: z.string().min(1),
|
OPENAI_API_KEY: z.string().min(1),
|
||||||
|
REPLICATE_API_TOKEN: z.string().default("placeholder"),
|
||||||
},
|
},
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -19,11 +27,6 @@ export const env = createEnv({
|
|||||||
*/
|
*/
|
||||||
client: {
|
client: {
|
||||||
NEXT_PUBLIC_POSTHOG_KEY: z.string().optional(),
|
NEXT_PUBLIC_POSTHOG_KEY: z.string().optional(),
|
||||||
NEXT_PUBLIC_IS_PUBLIC_PLAYGROUND: z
|
|
||||||
.string()
|
|
||||||
.optional()
|
|
||||||
.default("false")
|
|
||||||
.transform((val) => val.toLowerCase() === "true"),
|
|
||||||
NEXT_PUBLIC_SOCKET_URL: z.string().url().default("http://localhost:3318"),
|
NEXT_PUBLIC_SOCKET_URL: z.string().url().default("http://localhost:3318"),
|
||||||
},
|
},
|
||||||
|
|
||||||
@@ -35,9 +38,12 @@ export const env = createEnv({
|
|||||||
DATABASE_URL: process.env.DATABASE_URL,
|
DATABASE_URL: process.env.DATABASE_URL,
|
||||||
NODE_ENV: process.env.NODE_ENV,
|
NODE_ENV: process.env.NODE_ENV,
|
||||||
OPENAI_API_KEY: process.env.OPENAI_API_KEY,
|
OPENAI_API_KEY: process.env.OPENAI_API_KEY,
|
||||||
|
RESTRICT_PRISMA_LOGS: process.env.RESTRICT_PRISMA_LOGS,
|
||||||
NEXT_PUBLIC_POSTHOG_KEY: process.env.NEXT_PUBLIC_POSTHOG_KEY,
|
NEXT_PUBLIC_POSTHOG_KEY: process.env.NEXT_PUBLIC_POSTHOG_KEY,
|
||||||
NEXT_PUBLIC_IS_PUBLIC_PLAYGROUND: process.env.NEXT_PUBLIC_IS_PUBLIC_PLAYGROUND,
|
|
||||||
NEXT_PUBLIC_SOCKET_URL: process.env.NEXT_PUBLIC_SOCKET_URL,
|
NEXT_PUBLIC_SOCKET_URL: process.env.NEXT_PUBLIC_SOCKET_URL,
|
||||||
|
GITHUB_CLIENT_ID: process.env.GITHUB_CLIENT_ID,
|
||||||
|
GITHUB_CLIENT_SECRET: process.env.GITHUB_CLIENT_SECRET,
|
||||||
|
REPLICATE_API_TOKEN: process.env.REPLICATE_API_TOKEN,
|
||||||
},
|
},
|
||||||
/**
|
/**
|
||||||
* Run `build` or `dev` with `SKIP_ENV_VALIDATION` to skip env validation.
|
* Run `build` or `dev` with `SKIP_ENV_VALIDATION` to skip env validation.
|
||||||
|
|||||||
15 src/modelProviders/frontendModelProviders.ts Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
import openaiChatCompletionFrontend from "./openai-ChatCompletion/frontend";
|
||||||
|
import replicateLlama2Frontend from "./replicate-llama2/frontend";
|
||||||
|
import { type SupportedProvider, type FrontendModelProvider } from "./types";
|
||||||
|
|
||||||
|
// TODO: make sure we get a typescript error if you forget to add a provider here
|
||||||
|
|
||||||
|
// Keep attributes here that need to be accessible from the frontend. We can't
|
||||||
|
// just include them in the default `modelProviders` object because it has some
|
||||||
|
// transient dependencies that can only be imported on the server.
|
||||||
|
const frontendModelProviders: Record<SupportedProvider, FrontendModelProvider<any, any>> = {
|
||||||
|
"openai/ChatCompletion": openaiChatCompletionFrontend,
|
||||||
|
"replicate/llama2": replicateLlama2Frontend,
|
||||||
|
};
|
||||||
|
|
||||||
|
export default frontendModelProviders;
|
||||||
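Because frontendModelProviders only holds data that is safe to import in the browser, UI code can read model metadata from it directly. A minimal sketch, assuming a small lookup helper that is not part of this diff:

import frontendModelProviders from "~/modelProviders/frontendModelProviders";

// Hypothetical helper: resolve a model's display name without pulling in
// any server-only provider code.
export function getModelDisplayName(
  provider: keyof typeof frontendModelProviders,
  model: string,
): string {
  const models = frontendModelProviders[provider].models as Record<
    string,
    { name: string } | undefined
  >;
  return models[model]?.name ?? model;
}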
36 src/modelProviders/generateTypes.ts Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
import { type JSONSchema4Object } from "json-schema";
|
||||||
|
import modelProviders from "./modelProviders";
|
||||||
|
import { compile } from "json-schema-to-typescript";
|
||||||
|
import dedent from "dedent";
|
||||||
|
|
||||||
|
export default async function generateTypes() {
|
||||||
|
const combinedSchema = {
|
||||||
|
type: "object",
|
||||||
|
properties: {} as Record<string, JSONSchema4Object>,
|
||||||
|
};
|
||||||
|
|
||||||
|
Object.entries(modelProviders).forEach(([id, provider]) => {
|
||||||
|
combinedSchema.properties[id] = provider.inputSchema;
|
||||||
|
});
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
const promptTypes = (
|
||||||
|
await compile(combinedSchema as JSONSchema4Object, "PromptTypes", {
|
||||||
|
additionalProperties: false,
|
||||||
|
bannerComment: dedent`
|
||||||
|
/**
|
||||||
|
* This type map defines the input types for each model provider.
|
||||||
|
*/
|
||||||
|
`,
|
||||||
|
})
|
||||||
|
).replace(/export interface PromptTypes/g, "interface PromptTypes");
|
||||||
|
|
||||||
|
return dedent`
|
||||||
|
${promptTypes}
|
||||||
|
|
||||||
|
declare function definePrompt<T extends keyof PromptTypes>(modelProvider: T, input: PromptTypes[T])
|
||||||
|
`;
|
||||||
|
}
|
||||||
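A minimal sketch of how the generated declaration might be consumed, assuming a small build script (the script and output path are not part of this diff):

import fs from "fs";
import generateTypes from "./generateTypes";

// Hypothetical build step: write the generated PromptTypes interfaces and the
// definePrompt declaration to a .d.ts file the prompt editor can load.
const declarations = await generateTypes();
fs.writeFileSync("promptTypes.d.ts", declarations);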
10 src/modelProviders/modelProviders.ts Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
import openaiChatCompletion from "./openai-ChatCompletion";
|
||||||
|
import replicateLlama2 from "./replicate-llama2";
|
||||||
|
import { type SupportedProvider, type ModelProvider } from "./types";
|
||||||
|
|
||||||
|
const modelProviders: Record<SupportedProvider, ModelProvider<any, any, any>> = {
|
||||||
|
"openai/ChatCompletion": openaiChatCompletion,
|
||||||
|
"replicate/llama2": replicateLlama2,
|
||||||
|
};
|
||||||
|
|
||||||
|
export default modelProviders;
|
||||||
77 src/modelProviders/openai-ChatCompletion/codegen/codegen.ts Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
/* eslint-disable @typescript-eslint/no-var-requires */
|
||||||
|
|
||||||
|
import YAML from "yaml";
|
||||||
|
import fs from "fs";
|
||||||
|
import path from "path";
|
||||||
|
import { openapiSchemaToJsonSchema } from "@openapi-contrib/openapi-schema-to-json-schema";
|
||||||
|
import $RefParser from "@apidevtools/json-schema-ref-parser";
|
||||||
|
import { type JSONObject } from "superjson/dist/types";
|
||||||
|
import assert from "assert";
|
||||||
|
import { type JSONSchema4Object } from "json-schema";
|
||||||
|
import { isObject } from "lodash-es";
|
||||||
|
|
||||||
|
// @ts-expect-error for some reason missing from types
|
||||||
|
import parserEstree from "prettier/plugins/estree";
|
||||||
|
import parserBabel from "prettier/plugins/babel";
|
||||||
|
import prettier from "prettier/standalone";
|
||||||
|
|
||||||
|
const OPENAPI_URL =
|
||||||
|
"https://raw.githubusercontent.com/openai/openai-openapi/0c432eb66fd0c758fd8b9bd69db41c1096e5f4db/openapi.yaml";
|
||||||
|
|
||||||
|
// Fetch the openapi document
|
||||||
|
const response = await fetch(OPENAPI_URL);
|
||||||
|
const openApiYaml = await response.text();
|
||||||
|
|
||||||
|
// Parse the yaml document
|
||||||
|
let schema = YAML.parse(openApiYaml) as JSONObject;
|
||||||
|
schema = openapiSchemaToJsonSchema(schema);
|
||||||
|
|
||||||
|
const jsonSchema = await $RefParser.dereference(schema);
|
||||||
|
|
||||||
|
assert("components" in jsonSchema);
|
||||||
|
const completionRequestSchema = jsonSchema.components.schemas
|
||||||
|
.CreateChatCompletionRequest as JSONSchema4Object;
|
||||||
|
|
||||||
|
// We need to do a bit of surgery here since the Monaco editor doesn't like
|
||||||
|
// the fact that the schema says `model` can be either a string or an enum,
|
||||||
|
// and displays a warning in the editor. Let's stick with just an enum for
|
||||||
|
// now and drop the string option.
|
||||||
|
assert(
|
||||||
|
"properties" in completionRequestSchema &&
|
||||||
|
isObject(completionRequestSchema.properties) &&
|
||||||
|
"model" in completionRequestSchema.properties &&
|
||||||
|
isObject(completionRequestSchema.properties.model),
|
||||||
|
);
|
||||||
|
|
||||||
|
const modelProperty = completionRequestSchema.properties.model;
|
||||||
|
assert(
|
||||||
|
"oneOf" in modelProperty &&
|
||||||
|
Array.isArray(modelProperty.oneOf) &&
|
||||||
|
modelProperty.oneOf.length === 2 &&
|
||||||
|
isObject(modelProperty.oneOf[1]) &&
|
||||||
|
"enum" in modelProperty.oneOf[1],
|
||||||
|
"Expected model to have oneOf length of 2",
|
||||||
|
);
|
||||||
|
modelProperty.type = "string";
|
||||||
|
modelProperty.enum = modelProperty.oneOf[1].enum;
|
||||||
|
delete modelProperty["oneOf"];
|
||||||
|
|
||||||
|
// The default of "inf" confuses the Typescript generator, so can just remove it
|
||||||
|
assert(
|
||||||
|
"max_tokens" in completionRequestSchema.properties &&
|
||||||
|
isObject(completionRequestSchema.properties.max_tokens) &&
|
||||||
|
"default" in completionRequestSchema.properties.max_tokens,
|
||||||
|
);
|
||||||
|
delete completionRequestSchema.properties.max_tokens["default"];
|
||||||
|
|
||||||
|
// Get the directory of the current script
|
||||||
|
const currentDirectory = path.dirname(import.meta.url).replace("file://", "");
|
||||||
|
|
||||||
|
// Write the JSON schema to a file in the current directory
|
||||||
|
fs.writeFileSync(
|
||||||
|
path.join(currentDirectory, "input.schema.json"),
|
||||||
|
await prettier.format(JSON.stringify(completionRequestSchema, null, 2), {
|
||||||
|
parser: "json",
|
||||||
|
plugins: [parserBabel, parserEstree],
|
||||||
|
}),
|
||||||
|
);
|
||||||
@@ -0,0 +1,185 @@
|
|||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"model": {
|
||||||
|
"description": "ID of the model to use. See the [model endpoint compatibility](/docs/models/model-endpoint-compatibility) table for details on which models work with the Chat API.",
|
||||||
|
"example": "gpt-3.5-turbo",
|
||||||
|
"type": "string",
|
||||||
|
"enum": [
|
||||||
|
"gpt-4",
|
||||||
|
"gpt-4-0613",
|
||||||
|
"gpt-4-32k",
|
||||||
|
"gpt-4-32k-0613",
|
||||||
|
"gpt-3.5-turbo",
|
||||||
|
"gpt-3.5-turbo-16k",
|
||||||
|
"gpt-3.5-turbo-0613",
|
||||||
|
"gpt-3.5-turbo-16k-0613"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"messages": {
|
||||||
|
"description": "A list of messages comprising the conversation so far. [Example Python code](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb).",
|
||||||
|
"type": "array",
|
||||||
|
"minItems": 1,
|
||||||
|
"items": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"role": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["system", "user", "assistant", "function"],
|
||||||
|
"description": "The role of the messages author. One of `system`, `user`, `assistant`, or `function`."
|
||||||
|
},
|
||||||
|
"content": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The contents of the message. `content` is required for all messages except assistant messages with function calls."
|
||||||
|
},
|
||||||
|
"name": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The name of the author of this message. `name` is required if role is `function`, and it should be the name of the function whose response is in the `content`. May contain a-z, A-Z, 0-9, and underscores, with a maximum length of 64 characters."
|
||||||
|
},
|
||||||
|
"function_call": {
|
||||||
|
"type": "object",
|
||||||
|
"description": "The name and arguments of a function that should be called, as generated by the model.",
|
||||||
|
"properties": {
|
||||||
|
"name": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The name of the function to call."
|
||||||
|
},
|
||||||
|
"arguments": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function."
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["role"]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"functions": {
|
||||||
|
"description": "A list of functions the model may generate JSON inputs for.",
|
||||||
|
"type": "array",
|
||||||
|
"minItems": 1,
|
||||||
|
"items": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"name": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The name of the function to be called. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64."
|
||||||
|
},
|
||||||
|
"description": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The description of what the function does."
|
||||||
|
},
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"description": "The parameters the functions accepts, described as a JSON Schema object. See the [guide](/docs/guides/gpt/function-calling) for examples, and the [JSON Schema reference](https://json-schema.org/understanding-json-schema/) for documentation about the format.",
|
||||||
|
"additionalProperties": true
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["name"]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"function_call": {
|
||||||
|
"description": "Controls how the model responds to function calls. \"none\" means the model does not call a function, and responds to the end-user. \"auto\" means the model can pick between an end-user or calling a function. Specifying a particular function via `{\"name\":\\ \"my_function\"}` forces the model to call that function. \"none\" is the default when no functions are present. \"auto\" is the default if functions are present.",
|
||||||
|
"oneOf": [
|
||||||
|
{
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["none", "auto"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"name": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The name of the function to call."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["name"]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"temperature": {
|
||||||
|
"type": "number",
|
||||||
|
"minimum": 0,
|
||||||
|
"maximum": 2,
|
||||||
|
"default": 1,
|
||||||
|
"example": 1,
|
||||||
|
"nullable": true,
|
||||||
|
"description": "What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.\n\nWe generally recommend altering this or `top_p` but not both.\n"
|
||||||
|
},
|
||||||
|
"top_p": {
|
||||||
|
"type": "number",
|
||||||
|
"minimum": 0,
|
||||||
|
"maximum": 1,
|
||||||
|
"default": 1,
|
||||||
|
"example": 1,
|
||||||
|
"nullable": true,
|
||||||
|
"description": "An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.\n\nWe generally recommend altering this or `temperature` but not both.\n"
|
||||||
|
},
|
||||||
|
"n": {
|
||||||
|
"type": "integer",
|
||||||
|
"minimum": 1,
|
||||||
|
"maximum": 128,
|
||||||
|
"default": 1,
|
||||||
|
"example": 1,
|
||||||
|
"nullable": true,
|
||||||
|
"description": "How many chat completion choices to generate for each input message."
|
||||||
|
},
|
||||||
|
"stream": {
|
||||||
|
"description": "If set, partial message deltas will be sent, like in ChatGPT. Tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream terminated by a `data: [DONE]` message. [Example Python code](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_stream_completions.ipynb).\n",
|
||||||
|
"type": "boolean",
|
||||||
|
"nullable": true,
|
||||||
|
"default": false
|
||||||
|
},
|
||||||
|
"stop": {
|
||||||
|
"description": "Up to 4 sequences where the API will stop generating further tokens.\n",
|
||||||
|
"default": null,
|
||||||
|
"oneOf": [
|
||||||
|
{
|
||||||
|
"type": "string",
|
||||||
|
"nullable": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "array",
|
||||||
|
"minItems": 1,
|
||||||
|
"maxItems": 4,
|
||||||
|
"items": {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"max_tokens": {
|
||||||
|
"description": "The maximum number of [tokens](/tokenizer) to generate in the chat completion.\n\nThe total length of input tokens and generated tokens is limited by the model's context length. [Example Python code](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb) for counting tokens.\n",
|
||||||
|
"type": "integer"
|
||||||
|
},
|
||||||
|
"presence_penalty": {
|
||||||
|
"type": "number",
|
||||||
|
"default": 0,
|
||||||
|
"minimum": -2,
|
||||||
|
"maximum": 2,
|
||||||
|
"nullable": true,
|
||||||
|
"description": "Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.\n\n[See more information about frequency and presence penalties.](/docs/api-reference/parameter-details)\n"
|
||||||
|
},
|
||||||
|
"frequency_penalty": {
|
||||||
|
"type": "number",
|
||||||
|
"default": 0,
|
||||||
|
"minimum": -2,
|
||||||
|
"maximum": 2,
|
||||||
|
"nullable": true,
|
||||||
|
"description": "Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.\n\n[See more information about frequency and presence penalties.](/docs/api-reference/parameter-details)\n"
|
||||||
|
},
|
||||||
|
"logit_bias": {
|
||||||
|
"type": "object",
|
||||||
|
"x-oaiTypeLabel": "map",
|
||||||
|
"default": null,
|
||||||
|
"nullable": true,
|
||||||
|
"description": "Modify the likelihood of specified tokens appearing in the completion.\n\nAccepts a json object that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token.\n"
|
||||||
|
},
|
||||||
|
"user": {
|
||||||
|
"type": "string",
|
||||||
|
"example": "user-1234",
|
||||||
|
"description": "A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).\n"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["model", "messages"]
|
||||||
|
}
|
||||||
84 src/modelProviders/openai-ChatCompletion/frontend.ts Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
import { type JsonValue } from "type-fest";
|
||||||
|
import { type SupportedModel } from ".";
|
||||||
|
import { type FrontendModelProvider } from "../types";
|
||||||
|
import { type ChatCompletion } from "openai/resources/chat";
|
||||||
|
|
||||||
|
const frontendModelProvider: FrontendModelProvider<SupportedModel, ChatCompletion> = {
|
||||||
|
name: "OpenAI ChatCompletion",
|
||||||
|
|
||||||
|
models: {
|
||||||
|
"gpt-4-0613": {
|
||||||
|
name: "GPT-4",
|
||||||
|
contextWindow: 8192,
|
||||||
|
promptTokenPrice: 0.00003,
|
||||||
|
completionTokenPrice: 0.00006,
|
||||||
|
speed: "medium",
|
||||||
|
provider: "openai/ChatCompletion",
|
||||||
|
learnMoreUrl: "https://openai.com/gpt-4",
|
||||||
|
},
|
||||||
|
"gpt-4-32k-0613": {
|
||||||
|
name: "GPT-4 32k",
|
||||||
|
contextWindow: 32768,
|
||||||
|
promptTokenPrice: 0.00006,
|
||||||
|
completionTokenPrice: 0.00012,
|
||||||
|
speed: "medium",
|
||||||
|
provider: "openai/ChatCompletion",
|
||||||
|
learnMoreUrl: "https://openai.com/gpt-4",
|
||||||
|
},
|
||||||
|
"gpt-3.5-turbo-0613": {
|
||||||
|
name: "GPT-3.5 Turbo",
|
||||||
|
contextWindow: 4096,
|
||||||
|
promptTokenPrice: 0.0000015,
|
||||||
|
completionTokenPrice: 0.000002,
|
||||||
|
speed: "fast",
|
||||||
|
provider: "openai/ChatCompletion",
|
||||||
|
learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
|
||||||
|
},
|
||||||
|
"gpt-3.5-turbo-16k-0613": {
|
||||||
|
name: "GPT-3.5 Turbo 16k",
|
||||||
|
contextWindow: 16384,
|
||||||
|
promptTokenPrice: 0.000003,
|
||||||
|
completionTokenPrice: 0.000004,
|
||||||
|
speed: "fast",
|
||||||
|
provider: "openai/ChatCompletion",
|
||||||
|
learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
|
normalizeOutput: (output) => {
|
||||||
|
const message = output.choices[0]?.message;
|
||||||
|
if (!message)
|
||||||
|
return {
|
||||||
|
type: "json",
|
||||||
|
value: output as unknown as JsonValue,
|
||||||
|
};
|
||||||
|
|
||||||
|
if (message.content) {
|
||||||
|
return {
|
||||||
|
type: "text",
|
||||||
|
value: message.content,
|
||||||
|
};
|
||||||
|
} else if (message.function_call) {
|
||||||
|
let args = message.function_call.arguments ?? "";
|
||||||
|
try {
|
||||||
|
args = JSON.parse(args);
|
||||||
|
} catch (e) {
|
||||||
|
// Ignore
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
type: "json",
|
||||||
|
value: {
|
||||||
|
...message.function_call,
|
||||||
|
arguments: args,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
} else {
|
||||||
|
return {
|
||||||
|
type: "json",
|
||||||
|
value: message as unknown as JsonValue,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
export default frontendModelProvider;
|
||||||
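To illustrate normalizeOutput, here is a sketch with a made-up function-call completion; the arguments string is parsed so the UI can render structured JSON:

import frontendModelProvider from "./frontend";
import { type ChatCompletion } from "openai/resources/chat";

// Made-up completion object for illustration only.
const completion = {
  id: "chatcmpl-example",
  object: "chat.completion",
  created: 0,
  model: "gpt-3.5-turbo-0613",
  choices: [
    {
      index: 0,
      finish_reason: "function_call",
      message: {
        role: "assistant",
        function_call: { name: "score_post", arguments: '{"score": 2}' },
      },
    },
  ],
} as unknown as ChatCompletion;

console.log(frontendModelProvider.normalizeOutput(completion));
// -> { type: "json", value: { name: "score_post", arguments: { score: 2 } } }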
140 src/modelProviders/openai-ChatCompletion/getCompletion.ts Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
/* eslint-disable @typescript-eslint/no-unsafe-call */
|
||||||
|
import {
|
||||||
|
type ChatCompletionChunk,
|
||||||
|
type ChatCompletion,
|
||||||
|
type CompletionCreateParams,
|
||||||
|
} from "openai/resources/chat";
|
||||||
|
import { countOpenAIChatTokens } from "~/utils/countTokens";
|
||||||
|
import { type CompletionResponse } from "../types";
|
||||||
|
import { omit } from "lodash-es";
|
||||||
|
import { openai } from "~/server/utils/openai";
|
||||||
|
import { truthyFilter } from "~/utils/utils";
|
||||||
|
import { APIError } from "openai";
|
||||||
|
import frontendModelProvider from "./frontend";
|
||||||
|
import modelProvider, { type SupportedModel } from ".";
|
||||||
|
|
||||||
|
const mergeStreamedChunks = (
|
||||||
|
base: ChatCompletion | null,
|
||||||
|
chunk: ChatCompletionChunk,
|
||||||
|
): ChatCompletion => {
|
||||||
|
if (base === null) {
|
||||||
|
return mergeStreamedChunks({ ...chunk, choices: [] }, chunk);
|
||||||
|
}
|
||||||
|
|
||||||
|
const choices = [...base.choices];
|
||||||
|
for (const choice of chunk.choices) {
|
||||||
|
const baseChoice = choices.find((c) => c.index === choice.index);
|
||||||
|
if (baseChoice) {
|
||||||
|
baseChoice.finish_reason = choice.finish_reason ?? baseChoice.finish_reason;
|
||||||
|
baseChoice.message = baseChoice.message ?? { role: "assistant" };
|
||||||
|
|
||||||
|
if (choice.delta?.content)
|
||||||
|
baseChoice.message.content =
|
||||||
|
((baseChoice.message.content as string) ?? "") + (choice.delta.content ?? "");
|
||||||
|
if (choice.delta?.function_call) {
|
||||||
|
const fnCall = baseChoice.message.function_call ?? {};
|
||||||
|
fnCall.name =
|
||||||
|
((fnCall.name as string) ?? "") + ((choice.delta.function_call.name as string) ?? "");
|
||||||
|
fnCall.arguments =
|
||||||
|
((fnCall.arguments as string) ?? "") +
|
||||||
|
((choice.delta.function_call.arguments as string) ?? "");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
choices.push({ ...omit(choice, "delta"), message: { role: "assistant", ...choice.delta } });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const merged: ChatCompletion = {
|
||||||
|
...base,
|
||||||
|
choices,
|
||||||
|
};
|
||||||
|
|
||||||
|
return merged;
|
||||||
|
};
|
||||||
|
|
||||||
|
export async function getCompletion(
|
||||||
|
input: CompletionCreateParams,
|
||||||
|
onStream: ((partialOutput: ChatCompletion) => void) | null,
|
||||||
|
): Promise<CompletionResponse<ChatCompletion>> {
|
||||||
|
const start = Date.now();
|
||||||
|
let finalCompletion: ChatCompletion | null = null;
|
||||||
|
let promptTokens: number | undefined = undefined;
|
||||||
|
let completionTokens: number | undefined = undefined;
|
||||||
|
const modelName = modelProvider.getModel(input) as SupportedModel;
|
||||||
|
|
||||||
|
try {
|
||||||
|
if (onStream) {
|
||||||
|
const resp = await openai.chat.completions.create(
|
||||||
|
{ ...input, stream: true },
|
||||||
|
{
|
||||||
|
maxRetries: 0,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
for await (const part of resp) {
|
||||||
|
finalCompletion = mergeStreamedChunks(finalCompletion, part);
|
||||||
|
onStream(finalCompletion);
|
||||||
|
}
|
||||||
|
if (!finalCompletion) {
|
||||||
|
return {
|
||||||
|
type: "error",
|
||||||
|
message: "Streaming failed to return a completion",
|
||||||
|
autoRetry: false,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
promptTokens = countOpenAIChatTokens(modelName, input.messages);
|
||||||
|
completionTokens = countOpenAIChatTokens(
|
||||||
|
modelName,
|
||||||
|
finalCompletion.choices.map((c) => c.message).filter(truthyFilter),
|
||||||
|
);
|
||||||
|
} catch (err) {
|
||||||
|
// TODO handle this, library seems like maybe it doesn't work with function calls?
|
||||||
|
console.error(err);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
const resp = await openai.chat.completions.create(
|
||||||
|
{ ...input, stream: false },
|
||||||
|
{
|
||||||
|
maxRetries: 0,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
finalCompletion = resp;
|
||||||
|
promptTokens = resp.usage?.prompt_tokens ?? 0;
|
||||||
|
completionTokens = resp.usage?.completion_tokens ?? 0;
|
||||||
|
}
|
||||||
|
const timeToComplete = Date.now() - start;
|
||||||
|
|
||||||
|
const { promptTokenPrice, completionTokenPrice } = frontendModelProvider.models[modelName];
|
||||||
|
let cost = undefined;
|
||||||
|
if (promptTokenPrice && completionTokenPrice && promptTokens && completionTokens) {
|
||||||
|
cost = promptTokens * promptTokenPrice + completionTokens * completionTokenPrice;
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
type: "success",
|
||||||
|
statusCode: 200,
|
||||||
|
value: finalCompletion,
|
||||||
|
timeToComplete,
|
||||||
|
promptTokens,
|
||||||
|
completionTokens,
|
||||||
|
cost,
|
||||||
|
};
|
||||||
|
} catch (error: unknown) {
|
||||||
|
console.error("ERROR IS", error);
|
||||||
|
if (error instanceof APIError) {
|
||||||
|
return {
|
||||||
|
type: "error",
|
||||||
|
message: error.message,
|
||||||
|
autoRetry: error.status === 429 || error.status === 503,
|
||||||
|
statusCode: error.status,
|
||||||
|
};
|
||||||
|
} else {
|
||||||
|
console.error(error);
|
||||||
|
return {
|
||||||
|
type: "error",
|
||||||
|
message: (error as Error).message,
|
||||||
|
autoRetry: true,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
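A minimal sketch of calling getCompletion with streaming (the prompt content and logging are illustrative only):

import { getCompletion } from "./getCompletion";

// Streaming path: getCompletion merges each chunk into a full ChatCompletion
// and invokes onStream with the merged object.
const response = await getCompletion(
  {
    model: "gpt-3.5-turbo",
    stream: true,
    messages: [{ role: "user", content: "Say hello" }],
  },
  (partial) => console.log(partial.choices[0]?.message?.content),
);

if (response.type === "success") {
  console.log(response.value.choices[0]?.message, response.cost);
}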
45 src/modelProviders/openai-ChatCompletion/index.ts Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
import { type JSONSchema4 } from "json-schema";
|
||||||
|
import { type ModelProvider } from "../types";
|
||||||
|
import inputSchema from "./codegen/input.schema.json";
|
||||||
|
import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat";
|
||||||
|
import { getCompletion } from "./getCompletion";
|
||||||
|
import frontendModelProvider from "./frontend";
|
||||||
|
|
||||||
|
const supportedModels = [
|
||||||
|
"gpt-4-0613",
|
||||||
|
"gpt-4-32k-0613",
|
||||||
|
"gpt-3.5-turbo-0613",
|
||||||
|
"gpt-3.5-turbo-16k-0613",
|
||||||
|
] as const;
|
||||||
|
|
||||||
|
export type SupportedModel = (typeof supportedModels)[number];
|
||||||
|
|
||||||
|
export type OpenaiChatModelProvider = ModelProvider<
|
||||||
|
SupportedModel,
|
||||||
|
CompletionCreateParams,
|
||||||
|
ChatCompletion
|
||||||
|
>;
|
||||||
|
|
||||||
|
const modelProvider: OpenaiChatModelProvider = {
|
||||||
|
getModel: (input) => {
|
||||||
|
if (supportedModels.includes(input.model as SupportedModel))
|
||||||
|
return input.model as SupportedModel;
|
||||||
|
|
||||||
|
const modelMaps: Record<string, SupportedModel> = {
|
||||||
|
"gpt-4": "gpt-4-0613",
|
||||||
|
"gpt-4-32k": "gpt-4-32k-0613",
|
||||||
|
"gpt-3.5-turbo": "gpt-3.5-turbo-0613",
|
||||||
|
"gpt-3.5-turbo-16k": "gpt-3.5-turbo-16k-0613",
|
||||||
|
};
|
||||||
|
|
||||||
|
if (input.model in modelMaps) return modelMaps[input.model] as SupportedModel;
|
||||||
|
|
||||||
|
return null;
|
||||||
|
},
|
||||||
|
inputSchema: inputSchema as JSONSchema4,
|
||||||
|
shouldStream: (input) => input.stream ?? false,
|
||||||
|
getCompletion,
|
||||||
|
...frontendModelProvider,
|
||||||
|
};
|
||||||
|
|
||||||
|
export default modelProvider;
|
||||||
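The getModel mapping lets prompts reference base model names while resolving to the pinned snapshots; a small illustration (inputs are illustrative only):

import modelProvider from ".";

// Base names resolve to the pinned -0613 snapshots from modelMaps;
// unrecognized models resolve to null.
modelProvider.getModel({ model: "gpt-4", messages: [] });             // "gpt-4-0613"
modelProvider.getModel({ model: "gpt-3.5-turbo-16k", messages: [] }); // "gpt-3.5-turbo-16k-0613"
modelProvider.getModel({ model: "text-davinci-003", messages: [] });  // null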
42 src/modelProviders/replicate-llama2/frontend.ts Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
import { type SupportedModel, type ReplicateLlama2Output } from ".";
|
||||||
|
import { type FrontendModelProvider } from "../types";
|
||||||
|
|
||||||
|
const frontendModelProvider: FrontendModelProvider<SupportedModel, ReplicateLlama2Output> = {
|
||||||
|
name: "Replicate Llama2",
|
||||||
|
|
||||||
|
models: {
|
||||||
|
"7b-chat": {
|
||||||
|
name: "LLama 2 7B Chat",
|
||||||
|
contextWindow: 4096,
|
||||||
|
pricePerSecond: 0.0023,
|
||||||
|
speed: "fast",
|
||||||
|
provider: "replicate/llama2",
|
||||||
|
learnMoreUrl: "https://replicate.com/a16z-infra/llama7b-v2-chat",
|
||||||
|
},
|
||||||
|
"13b-chat": {
|
||||||
|
name: "LLama 2 13B Chat",
|
||||||
|
contextWindow: 4096,
|
||||||
|
pricePerSecond: 0.0023,
|
||||||
|
speed: "medium",
|
||||||
|
provider: "replicate/llama2",
|
||||||
|
learnMoreUrl: "https://replicate.com/a16z-infra/llama13b-v2-chat",
|
||||||
|
},
|
||||||
|
"70b-chat": {
|
||||||
|
name: "LLama 2 70B Chat",
|
||||||
|
contextWindow: 4096,
|
||||||
|
pricePerSecond: 0.0032,
|
||||||
|
speed: "slow",
|
||||||
|
provider: "replicate/llama2",
|
||||||
|
learnMoreUrl: "https://replicate.com/replicate/llama70b-v2-chat",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
|
normalizeOutput: (output) => {
|
||||||
|
return {
|
||||||
|
type: "text",
|
||||||
|
value: output.join(""),
|
||||||
|
};
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
export default frontendModelProvider;
|
||||||
60  src/modelProviders/replicate-llama2/getCompletion.ts  Normal file
@@ -0,0 +1,60 @@
import { env } from "~/env.mjs";
import { type ReplicateLlama2Input, type ReplicateLlama2Output } from ".";
import { type CompletionResponse } from "../types";
import Replicate from "replicate";

const replicate = new Replicate({
  auth: env.REPLICATE_API_TOKEN || "",
});

const modelIds: Record<ReplicateLlama2Input["model"], string> = {
  "7b-chat": "3725a659b5afff1a0ba9bead5fac3899d998feaad00e07032ca2b0e35eb14f8a",
  "13b-chat": "5c785d117c5bcdd1928d5a9acb1ffa6272d6cf13fcb722e90886a0196633f9d3",
  "70b-chat": "e951f18578850b652510200860fc4ea62b3b16fac280f83ff32282f87bbd2e48",
};

export async function getCompletion(
  input: ReplicateLlama2Input,
  onStream: ((partialOutput: string[]) => void) | null,
): Promise<CompletionResponse<ReplicateLlama2Output>> {
  const start = Date.now();

  const { model, stream, ...rest } = input;

  try {
    const prediction = await replicate.predictions.create({
      version: modelIds[model],
      input: rest,
    });

    const interval = onStream
      ? // eslint-disable-next-line @typescript-eslint/no-misused-promises
        setInterval(async () => {
          const partialPrediction = await replicate.predictions.get(prediction.id);

          if (partialPrediction.output) onStream(partialPrediction.output as ReplicateLlama2Output);
        }, 500)
      : null;

    const resp = await replicate.wait(prediction, {});
    if (interval) clearInterval(interval);

    const timeToComplete = Date.now() - start;

    if (resp.error) throw new Error(resp.error as string);

    return {
      type: "success",
      statusCode: 200,
      value: resp.output as ReplicateLlama2Output,
      timeToComplete,
    };
  } catch (error: unknown) {
    console.error("ERROR IS", error);
    return {
      type: "error",
      message: (error as Error).message,
      autoRetry: true,
    };
  }
}
70  src/modelProviders/replicate-llama2/index.ts  Normal file
@@ -0,0 +1,70 @@
import { type ModelProvider } from "../types";
import frontendModelProvider from "./frontend";
import { getCompletion } from "./getCompletion";

const supportedModels = ["7b-chat", "13b-chat", "70b-chat"] as const;

export type SupportedModel = (typeof supportedModels)[number];

export type ReplicateLlama2Input = {
  model: SupportedModel;
  prompt: string;
  stream?: boolean;
  max_length?: number;
  temperature?: number;
  top_p?: number;
  repetition_penalty?: number;
  debug?: boolean;
};

export type ReplicateLlama2Output = string[];

export type ReplicateLlama2Provider = ModelProvider<
  SupportedModel,
  ReplicateLlama2Input,
  ReplicateLlama2Output
>;

const modelProvider: ReplicateLlama2Provider = {
  getModel: (input) => {
    if (supportedModels.includes(input.model)) return input.model;

    return null;
  },
  inputSchema: {
    type: "object",
    properties: {
      model: {
        type: "string",
        enum: supportedModels as unknown as string[],
      },
      prompt: {
        type: "string",
      },
      stream: {
        type: "boolean",
      },
      max_length: {
        type: "number",
      },
      temperature: {
        type: "number",
      },
      top_p: {
        type: "number",
      },
      repetition_penalty: {
        type: "number",
      },
      debug: {
        type: "boolean",
      },
    },
    required: ["model", "prompt"],
  },
  shouldStream: (input) => input.stream ?? false,
  getCompletion,
  ...frontendModelProvider,
};

export default modelProvider;
66  src/modelProviders/types.ts  Normal file
@@ -0,0 +1,66 @@
import { type JSONSchema4 } from "json-schema";
import { type JsonValue } from "type-fest";
import { z } from "zod";

const ZodSupportedProvider = z.union([
  z.literal("openai/ChatCompletion"),
  z.literal("replicate/llama2"),
]);

export type SupportedProvider = z.infer<typeof ZodSupportedProvider>;

export const ZodModel = z.object({
  name: z.string(),
  contextWindow: z.number(),
  promptTokenPrice: z.number().optional(),
  completionTokenPrice: z.number().optional(),
  pricePerSecond: z.number().optional(),
  speed: z.union([z.literal("fast"), z.literal("medium"), z.literal("slow")]),
  provider: ZodSupportedProvider,
  description: z.string().optional(),
  learnMoreUrl: z.string().optional(),
});

export type Model = z.infer<typeof ZodModel>;

export type FrontendModelProvider<SupportedModels extends string, OutputSchema> = {
  name: string;
  models: Record<SupportedModels, Model>;

  normalizeOutput: (output: OutputSchema) => NormalizedOutput;
};

export type CompletionResponse<T> =
  | { type: "error"; message: string; autoRetry: boolean; statusCode?: number }
  | {
      type: "success";
      value: T;
      timeToComplete: number;
      statusCode: number;
      promptTokens?: number;
      completionTokens?: number;
      cost?: number;
    };

export type ModelProvider<SupportedModels extends string, InputSchema, OutputSchema> = {
  getModel: (input: InputSchema) => SupportedModels | null;
  shouldStream: (input: InputSchema) => boolean;
  inputSchema: JSONSchema4;
  getCompletion: (
    input: InputSchema,
    onStream: ((partialOutput: OutputSchema) => void) | null,
  ) => Promise<CompletionResponse<OutputSchema>>;

  // This is just a convenience for type inference, don't use it at runtime
  _outputSchema?: OutputSchema | null;
} & FrontendModelProvider<SupportedModels, OutputSchema>;

export type NormalizedOutput =
  | {
      type: "text";
      value: string;
    }
  | {
      type: "json";
      value: JsonValue;
    };
@@ -2,22 +2,36 @@ import { type Session } from "next-auth";
 import { SessionProvider } from "next-auth/react";
 import { type AppType } from "next/app";
 import { api } from "~/utils/api";
-import { ChakraProvider } from "@chakra-ui/react";
-import theme from "~/utils/theme";
 import Favicon from "~/components/Favicon";
 import "~/utils/analytics";
+import Head from "next/head";
+import { ChakraThemeProvider } from "~/theme/ChakraThemeProvider";
+import { SyncAppStore } from "~/state/sync";
+import NextAdapterApp from "next-query-params/app";
+import { QueryParamProvider } from "use-query-params";

 const MyApp: AppType<{ session: Session | null }> = ({
   Component,
   pageProps: { session, ...pageProps },
 }) => {
   return (
+    <>
+      <Head>
+        <meta
+          name="viewport"
+          content="width=device-width, initial-scale=1, maximum-scale=1, user-scalable=0"
+        />
+      </Head>
       <SessionProvider session={session}>
+        <SyncAppStore />
         <Favicon />
-        <ChakraProvider theme={theme}>
+        <ChakraThemeProvider>
+          <QueryParamProvider adapter={NextAdapterApp}>
             <Component {...pageProps} />
-        </ChakraProvider>
+          </QueryParamProvider>
+        </ChakraThemeProvider>
       </SessionProvider>
+    </>
   );
 };
23  src/pages/account/signin.tsx  Normal file
@@ -0,0 +1,23 @@
import { signIn, useSession } from "next-auth/react";
import { useRouter } from "next/router";
import { useEffect } from "react";
import AppShell from "~/components/nav/AppShell";

export default function SignIn() {
  const session = useSession().data;
  const router = useRouter();

  useEffect(() => {
    if (session) {
      router.push("/experiments").catch(console.error);
    } else if (session === null) {
      signIn("github").catch(console.error);
    }
  }, [session, router]);

  return (
    <AppShell>
      <div />
    </AppShell>
  );
}
@@ -49,6 +49,10 @@ const DeleteButton = () => {
     onClose();
   }, [mutation, experiment.data?.id, router]);

+  useEffect(() => {
+    useAppStore.getState().sharedVariantEditor.loadMonaco().catch(console.error);
+  });
+
   return (
     <>
       <Button
@@ -124,6 +128,8 @@ export default function Experiment() {
     );
   }

+  const canModify = experiment.data?.access.canModify ?? false;
+
   return (
     <AppShell title={experiment.data?.label}>
       <VStack h="full">
@@ -143,6 +149,7 @@ export default function Experiment() {
             </Link>
           </BreadcrumbItem>
           <BreadcrumbItem isCurrentPage>
+            {canModify ? (
             <Input
               size="sm"
               value={label}
@@ -157,8 +164,14 @@
               _hover={{ borderColor: "gray.300" }}
               _focus={{ borderColor: "blue.500", outline: "none" }}
             />
+            ) : (
+              <Text fontSize={16} px={0} minW={{ base: 100, lg: 300 }} flex={1}>
+                {experiment.data?.label}
+              </Text>
+            )}
           </BreadcrumbItem>
         </Breadcrumb>
+        {canModify && (
         <HStack>
           <Button
             size="sm"
@@ -174,6 +187,7 @@
           </Button>
           <DeleteButton />
         </HStack>
+        )}
       </Flex>
       <SettingsDrawer />
       <Box w="100%" overflowX="auto" flex={1}>
@@ -1,25 +1,53 @@
 import {
   SimpleGrid,
-  HStack,
   Icon,
   VStack,
   Breadcrumb,
   BreadcrumbItem,
   Flex,
+  Center,
+  Text,
+  Link,
+  HStack,
 } from "@chakra-ui/react";
 import { RiFlaskLine } from "react-icons/ri";
 import AppShell from "~/components/nav/AppShell";
 import { api } from "~/utils/api";
-import { NewExperimentButton } from "~/components/experiments/NewExperimentButton";
-import { ExperimentCard } from "~/components/experiments/ExperimentCard";
+import { ExperimentCard, NewExperimentCard } from "~/components/experiments/ExperimentCard";
+import { signIn, useSession } from "next-auth/react";

 export default function ExperimentsPage() {
   const experiments = api.experiments.list.useQuery();

+  const user = useSession().data;
+  const authLoading = useSession().status === "loading";
+
+  if (user === null || authLoading) {
     return (
-    <AppShell>
-      <VStack alignItems={"flex-start"} m={4} mt={1}>
-        <HStack w="full" justifyContent="space-between" mb={4}>
+      <AppShell title="Experiments">
+        <Center h="100%">
+          {!authLoading && (
+            <Text>
+              <Link
+                onClick={() => {
+                  signIn("github").catch(console.error);
+                }}
+                textDecor="underline"
+              >
+                Sign in
+              </Link>{" "}
+              to view or create new experiments!
+            </Text>
+          )}
+        </Center>
+      </AppShell>
+    );
+  }
+
+  return (
+    <AppShell title="Experiments">
+      <VStack alignItems={"flex-start"} px={4} py={2}>
+        <HStack minH={8} align="center">
         <Breadcrumb flex={1}>
           <BreadcrumbItem>
             <Flex alignItems="center">
@@ -27,9 +55,9 @@ export default function ExperimentsPage() {
             </Flex>
           </BreadcrumbItem>
         </Breadcrumb>
-        <NewExperimentButton mr={4} borderRadius={8} />
         </HStack>
         <SimpleGrid w="full" columns={{ base: 1, md: 2, lg: 3, xl: 4 }} spacing={8} p="4">
+          <NewExperimentCard />
           {experiments?.data?.map((exp) => <ExperimentCard key={exp.id} exp={exp} />)}
         </SimpleGrid>
       </VStack>
@@ -1,7 +1,7 @@
 import { type GetServerSideProps } from "next";

 // eslint-disable-next-line @typescript-eslint/require-await
-export const getServerSideProps: GetServerSideProps = async (context) => {
+export const getServerSideProps: GetServerSideProps = async () => {
   return {
     redirect: {
       destination: "/experiments",
@@ -1,11 +1,7 @@
 import { type CompletionCreateParams } from "openai/resources/chat";
 import { prisma } from "../db";
 import { openai } from "../utils/openai";
-import { pick } from "lodash";
+import { pick } from "lodash-es";

-function promptHasVariable(prompt: string, variableName: string) {
-  return prompt.includes(`{{${variableName}}}`);
-}
-
 type AxiosError = {
   response?: {
@@ -2,7 +2,7 @@ import { promptVariantsRouter } from "~/server/api/routers/promptVariants.router";
 import { createTRPCRouter } from "~/server/api/trpc";
 import { experimentsRouter } from "./routers/experiments.router";
 import { scenariosRouter } from "./routers/scenarios.router";
-import { modelOutputsRouter } from "./routers/modelOutputs.router";
+import { scenarioVariantCellsRouter } from "./routers/scenarioVariantCells.router";
 import { templateVarsRouter } from "./routers/templateVariables.router";
 import { evaluationsRouter } from "./routers/evaluations.router";

@@ -15,7 +15,7 @@ export const appRouter = createTRPCRouter({
   promptVariants: promptVariantsRouter,
   experiments: experimentsRouter,
   scenarios: scenariosRouter,
-  outputs: modelOutputsRouter,
+  scenarioVariantCells: scenarioVariantCellsRouter,
   templateVars: templateVarsRouter,
   evaluations: evaluationsRouter,
 });
@@ -1,11 +1,16 @@
-import { EvaluationMatchType } from "@prisma/client";
+import { EvalType } from "@prisma/client";
 import { z } from "zod";
-import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
+import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
 import { prisma } from "~/server/db";
-import { reevaluateEvaluation } from "~/server/utils/evaluations";
+import { runAllEvals } from "~/server/utils/evaluations";
+import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";

 export const evaluationsRouter = createTRPCRouter({
-  list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
+  list: publicProcedure
+    .input(z.object({ experimentId: z.string() }))
+    .query(async ({ input, ctx }) => {
+      await requireCanViewExperiment(input.experimentId, ctx);
+
       return await prisma.evaluation.findMany({
         where: {
           experimentId: input.experimentId,
@@ -14,53 +19,76 @@ export const evaluationsRouter = createTRPCRouter({
       });
     }),

-  create: publicProcedure
+  create: protectedProcedure
     .input(
       z.object({
         experimentId: z.string(),
-        name: z.string(),
-        matchString: z.string(),
-        matchType: z.nativeEnum(EvaluationMatchType),
+        label: z.string(),
+        value: z.string(),
+        evalType: z.nativeEnum(EvalType),
       }),
     )
-    .mutation(async ({ input }) => {
-      const evaluation = await prisma.evaluation.create({
+    .mutation(async ({ input, ctx }) => {
+      await requireCanModifyExperiment(input.experimentId, ctx);
+
+      await prisma.evaluation.create({
         data: {
           experimentId: input.experimentId,
-          name: input.name,
-          matchString: input.matchString,
-          matchType: input.matchType,
+          label: input.label,
+          value: input.value,
+          evalType: input.evalType,
         },
       });
-      await reevaluateEvaluation(evaluation);
+
+      // TODO: this may be a bad UX for slow evals (eg. GPT-4 evals) Maybe need
+      // to kick off a background job or something instead
+      await runAllEvals(input.experimentId);
     }),

-  update: publicProcedure
+  update: protectedProcedure
     .input(
       z.object({
         id: z.string(),
         updates: z.object({
-          name: z.string().optional(),
-          matchString: z.string().optional(),
-          matchType: z.nativeEnum(EvaluationMatchType).optional(),
+          label: z.string().optional(),
+          value: z.string().optional(),
+          evalType: z.nativeEnum(EvalType).optional(),
         }),
       }),
     )
-    .mutation(async ({ input }) => {
-      await prisma.evaluation.update({
+    .mutation(async ({ input, ctx }) => {
+      const { experimentId } = await prisma.evaluation.findUniqueOrThrow({
+        where: { id: input.id },
+      });
+      await requireCanModifyExperiment(experimentId, ctx);
+
+      const evaluation = await prisma.evaluation.update({
         where: { id: input.id },
         data: {
-          name: input.updates.name,
-          matchString: input.updates.matchString,
-          matchType: input.updates.matchType,
+          label: input.updates.label,
+          value: input.updates.value,
+          evalType: input.updates.evalType,
         },
       });
-      await reevaluateEvaluation(
-        await prisma.evaluation.findUniqueOrThrow({ where: { id: input.id } }),
-      );
+
+      await prisma.outputEvaluation.deleteMany({
+        where: {
+          evaluationId: evaluation.id,
+        },
+      });
+
+      // Re-run all evals. Other eval results will already be cached, so this
+      // should only re-run the updated one.
+      await runAllEvals(evaluation.experimentId);
     }),

-  delete: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => {
+  delete: protectedProcedure
+    .input(z.object({ id: z.string() }))
+    .mutation(async ({ input, ctx }) => {
+      const { experimentId } = await prisma.evaluation.findUniqueOrThrow({
+        where: { id: input.id },
+      });
+      await requireCanModifyExperiment(experimentId, ctx);
+
       await prisma.evaluation.delete({
         where: { id: input.id },
       });
@@ -1,13 +1,32 @@
 import { z } from "zod";
-import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
+import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
 import { prisma } from "~/server/db";
 import dedent from "dedent";
+import { generateNewCell } from "~/server/utils/generateNewCell";
+import {
+  canModifyExperiment,
+  requireCanModifyExperiment,
+  requireCanViewExperiment,
+  requireNothing,
+} from "~/utils/accessControl";
+import userOrg from "~/server/utils/userOrg";
+import generateTypes from "~/modelProviders/generateTypes";

 export const experimentsRouter = createTRPCRouter({
-  list: publicProcedure.query(async () => {
+  list: protectedProcedure.query(async ({ ctx }) => {
+    // Anyone can list experiments
+    requireNothing(ctx);
+
     const experiments = await prisma.experiment.findMany({
+      where: {
+        organization: {
+          OrganizationUser: {
+            some: { userId: ctx.session.user.id },
+          },
+        },
+      },
       orderBy: {
-        sortIndex: "asc",
+        sortIndex: "desc",
       },
     });

@@ -39,15 +58,29 @@ export const experimentsRouter = createTRPCRouter({
     return experimentsWithCounts;
   }),

-  get: publicProcedure.input(z.object({ id: z.string() })).query(async ({ input }) => {
-    return await prisma.experiment.findFirst({
-      where: {
-        id: input.id,
-      },
+  get: publicProcedure.input(z.object({ id: z.string() })).query(async ({ input, ctx }) => {
+    await requireCanViewExperiment(input.id, ctx);
+    const experiment = await prisma.experiment.findFirstOrThrow({
+      where: { id: input.id },
     });
+
+    const canModify = ctx.session?.user.id
+      ? await canModifyExperiment(experiment.id, ctx.session?.user.id)
+      : false;
+
+    return {
+      ...experiment,
+      access: {
+        canView: true,
+        canModify,
+      },
+    };
   }),

-  create: publicProcedure.input(z.object({})).mutation(async () => {
+  create: protectedProcedure.input(z.object({})).mutation(async ({ ctx }) => {
+    // Anyone can create an experiment
+    requireNothing(ctx);
+
     const maxSortIndex =
       (
         await prisma.experiment.aggregate({
@@ -61,36 +94,85 @@ export const experimentsRouter = createTRPCRouter({
       data: {
         sortIndex: maxSortIndex + 1,
         label: `Experiment ${maxSortIndex + 1}`,
+        organizationId: (await userOrg(ctx.session.user.id)).id,
       },
     });

-    await prisma.$transaction([
+    const [variant, _, scenario1, scenario2, scenario3] = await prisma.$transaction([
       prisma.promptVariant.create({
         data: {
           experimentId: exp.id,
           label: "Prompt Variant 1",
           sortIndex: 0,
-          constructFn: dedent`prompt = {
+          // The interpolated $ is necessary until dedent incorporates
+          // https://github.com/dmnd/dedent/pull/46
+          constructFn: dedent`
+            /**
+             * Use Javascript to define an OpenAI chat completion
+             * (https://platform.openai.com/docs/api-reference/chat/create).
+             *
+             * You have access to the current scenario in the \`scenario\`
+             * variable.
+             */
+
+            definePrompt("openai/ChatCompletion", {
               model: "gpt-3.5-turbo-0613",
               stream: true,
-            messages: [{ role: "system", content: "Return 'Ready to go!'" }],
-          }`,
+              messages: [
+                {
+                  role: "system",
+                  content: \`Write 'Start experimenting!' in ${"$"}{scenario.language}\`,
+                },
+              ],
+            });`,
+          model: "gpt-3.5-turbo-0613",
+          modelProvider: "openai/ChatCompletion",
+          constructFnVersion: 2,
+        },
+      }),
+      prisma.templateVariable.create({
+        data: {
+          experimentId: exp.id,
+          label: "language",
         },
       }),
       prisma.testScenario.create({
         data: {
           experimentId: exp.id,
-          variableValues: {},
+          variableValues: {
+            language: "English",
+          },
+        },
+      }),
+      prisma.testScenario.create({
+        data: {
+          experimentId: exp.id,
+          variableValues: {
+            language: "Spanish",
+          },
+        },
+      }),
+      prisma.testScenario.create({
+        data: {
+          experimentId: exp.id,
+          variableValues: {
+            language: "German",
+          },
         },
       }),
     ]);
+
+    await generateNewCell(variant.id, scenario1.id);
+    await generateNewCell(variant.id, scenario2.id);
+    await generateNewCell(variant.id, scenario3.id);
+
     return exp;
   }),

-  update: publicProcedure
+  update: protectedProcedure
     .input(z.object({ id: z.string(), updates: z.object({ label: z.string() }) }))
-    .mutation(async ({ input }) => {
+    .mutation(async ({ input, ctx }) => {
+      await requireCanModifyExperiment(input.id, ctx);
       return await prisma.experiment.update({
         where: {
           id: input.id,
@@ -101,11 +183,21 @@ export const experimentsRouter = createTRPCRouter({
       });
     }),

-  delete: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => {
+  delete: protectedProcedure
+    .input(z.object({ id: z.string() }))
+    .mutation(async ({ input, ctx }) => {
+      await requireCanModifyExperiment(input.id, ctx);
+
       await prisma.experiment.delete({
         where: {
           id: input.id,
         },
       });
     }),
+
+  // Keeping these on `experiment` for now because we might want to limit the
+  // providers based on your account/experiment
+  promptTypes: publicProcedure.query(async () => {
+    return await generateTypes();
+  }),
 });
@@ -1,97 +0,0 @@
-import { z } from "zod";
-import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
-import { prisma } from "~/server/db";
-import crypto from "crypto";
-import type { Prisma } from "@prisma/client";
-import { reevaluateVariant } from "~/server/utils/evaluations";
-import { getCompletion } from "~/server/utils/getCompletion";
-import { constructPrompt } from "~/server/utils/constructPrompt";
-
-export const modelOutputsRouter = createTRPCRouter({
-  get: publicProcedure
-    .input(
-      z.object({
-        scenarioId: z.string(),
-        variantId: z.string(),
-        channel: z.string().optional(),
-        forceRefetch: z.boolean().optional(),
-      }),
-    )
-    .mutation(async ({ input }) => {
-      const existing = await prisma.modelOutput.findUnique({
-        where: {
-          promptVariantId_testScenarioId: {
-            promptVariantId: input.variantId,
-            testScenarioId: input.scenarioId,
-          },
-        },
-      });
-
-      if (existing && !input.forceRefetch) return existing;
-
-      const variant = await prisma.promptVariant.findUnique({
-        where: {
-          id: input.variantId,
-        },
-      });
-
-      const scenario = await prisma.testScenario.findUnique({
-        where: {
-          id: input.scenarioId,
-        },
-      });
-
-      if (!variant || !scenario) return null;
-
-      const prompt = await constructPrompt(variant, scenario);
-
-      const inputHash = crypto.createHash("sha256").update(JSON.stringify(prompt)).digest("hex");
-
-      // TODO: we should probably only use this if temperature=0
-      const existingResponse = await prisma.modelOutput.findFirst({
-        where: { inputHash, errorMessage: null },
-      });
-
-      let modelResponse: Awaited<ReturnType<typeof getCompletion>>;
-
-      if (existingResponse) {
-        modelResponse = {
-          output: existingResponse.output as Prisma.InputJsonValue,
-          statusCode: existingResponse.statusCode,
-          errorMessage: existingResponse.errorMessage,
-          timeToComplete: existingResponse.timeToComplete,
-          promptTokens: existingResponse.promptTokens ?? undefined,
-          completionTokens: existingResponse.completionTokens ?? undefined,
-        };
-      } else {
-        try {
-          modelResponse = await getCompletion(prompt, input.channel);
-        } catch (e) {
-          console.error(e);
-          throw e;
-        }
-      }
-
-      const modelOutput = await prisma.modelOutput.upsert({
-        where: {
-          promptVariantId_testScenarioId: {
-            promptVariantId: input.variantId,
-            testScenarioId: input.scenarioId,
-          },
-        },
-        create: {
-          promptVariantId: input.variantId,
-          testScenarioId: input.scenarioId,
-          inputHash,
-          ...modelResponse,
-        },
-        update: {
-          ...modelResponse,
-        },
-      });
-
-      await reevaluateVariant(input.variantId);
-
-      return modelOutput;
-    }),
-});
@@ -1,11 +1,22 @@
 import { z } from "zod";
-import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
+import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
 import { prisma } from "~/server/db";
+import { generateNewCell } from "~/server/utils/generateNewCell";
+import userError from "~/server/utils/error";
 import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
-import { calculateTokenCost } from "~/utils/calculateTokenCost";
+import { reorderPromptVariants } from "~/server/utils/reorderPromptVariants";
+import { type PromptVariant } from "@prisma/client";
+import { deriveNewConstructFn } from "~/server/utils/deriveNewContructFn";
+import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
+import parseConstructFn from "~/server/utils/parseConstructFn";
+import { ZodModel } from "~/modelProviders/types";

 export const promptVariantsRouter = createTRPCRouter({
-  list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
+  list: publicProcedure
+    .input(z.object({ experimentId: z.string() }))
+    .query(async ({ input, ctx }) => {
+      await requireCanViewExperiment(input.experimentId, ctx);
+
       return await prisma.promptVariant.findMany({
         where: {
           experimentId: input.experimentId,
@@ -15,7 +26,9 @@ export const promptVariantsRouter = createTRPCRouter({
       });
     }),

-  stats: publicProcedure.input(z.object({ variantId: z.string() })).query(async ({ input }) => {
+  stats: publicProcedure
+    .input(z.object({ variantId: z.string() }))
+    .query(async ({ input, ctx }) => {
       const variant = await prisma.promptVariant.findUnique({
         where: {
           id: input.variantId,
@@ -26,11 +39,47 @@ export const promptVariantsRouter = createTRPCRouter({
         throw new Error(`Prompt Variant with id ${input.variantId} does not exist`);
       }

-      const evalResults = await prisma.evaluationResult.findMany({
-        where: {
-          promptVariantId: input.variantId,
+      await requireCanViewExperiment(variant.experimentId, ctx);
+
+      const outputEvals = await prisma.outputEvaluation.groupBy({
+        by: ["evaluationId"],
+        _sum: {
+          result: true,
         },
-        include: { evaluation: true },
+        _count: {
+          id: true,
+        },
+        where: {
+          modelOutput: {
+            scenarioVariantCell: {
+              promptVariant: {
+                id: input.variantId,
+                visible: true,
+              },
+              testScenario: {
+                visible: true,
+              },
+            },
+          },
+        },
+      });
+
+      const evals = await prisma.evaluation.findMany({
+        where: {
+          experimentId: variant.experimentId,
+        },
+      });
+
+      const evalResults = evals.map((evalItem) => {
+        const evalResult = outputEvals.find(
+          (outputEval) => outputEval.evaluationId === evalItem.id,
+        );
+        return {
+          id: evalItem.id,
+          label: evalItem.label,
+          passCount: evalResult?._sum?.result ?? 0,
+          totalCount: evalResult?._count?.id ?? 1,
+        };
       });

       const scenarioCount = await prisma.testScenario.count({
@@ -39,46 +88,77 @@ export const promptVariantsRouter = createTRPCRouter({
           visible: true,
         },
       });
-      const outputCount = await prisma.modelOutput.count({
+      const outputCount = await prisma.scenarioVariantCell.count({
         where: {
           promptVariantId: input.variantId,
           testScenario: { visible: true },
+          modelOutput: {
+            is: {},
+          },
         },
       });

       const overallTokens = await prisma.modelOutput.aggregate({
         where: {
+          scenarioVariantCell: {
             promptVariantId: input.variantId,
-          testScenario: { visible: true },
+            testScenario: {
+              visible: true,
+            },
+          },
         },
         _sum: {
+          cost: true,
           promptTokens: true,
           completionTokens: true,
         },
       });

-      // TODO: fix this
-      const model = "gpt-3.5-turbo-0613";
-      // const model = getModelName(variant.config);
-
       const promptTokens = overallTokens._sum?.promptTokens ?? 0;
-      const overallPromptCost = calculateTokenCost(model, promptTokens);
       const completionTokens = overallTokens._sum?.completionTokens ?? 0;
-      const overallCompletionCost = calculateTokenCost(model, completionTokens, true);
-
-      const overallCost = overallPromptCost + overallCompletionCost;
-
-      return { evalResults, promptTokens, completionTokens, overallCost, scenarioCount, outputCount };
+
+      const awaitingRetrievals = !!(await prisma.scenarioVariantCell.findFirst({
+        where: {
+          promptVariantId: input.variantId,
+          testScenario: { visible: true },
+          // Check if is PENDING or IN_PROGRESS
+          retrievalStatus: {
+            in: ["PENDING", "IN_PROGRESS"],
+          },
+        },
+      }));
+
+      return {
+        evalResults,
+        promptTokens,
+        completionTokens,
+        overallCost: overallTokens._sum?.cost ?? 0,
+        scenarioCount,
+        outputCount,
+        awaitingRetrievals,
+      };
     }),

-  create: publicProcedure
+  create: protectedProcedure
     .input(
       z.object({
         experimentId: z.string(),
+        variantId: z.string().optional(),
+        newModel: ZodModel.optional(),
       }),
     )
-    .mutation(async ({ input }) => {
-      const lastVariant = await prisma.promptVariant.findFirst({
+    .mutation(async ({ input, ctx }) => {
+      await requireCanViewExperiment(input.experimentId, ctx);
+
+      let originalVariant: PromptVariant | null = null;
+      if (input.variantId) {
+        originalVariant = await prisma.promptVariant.findUnique({
+          where: {
+            id: input.variantId,
+          },
+        });
+      } else {
+        originalVariant = await prisma.promptVariant.findFirst({
           where: {
             experimentId: input.experimentId,
             visible: true,
@@ -87,6 +167,7 @@ export const promptVariantsRouter = createTRPCRouter({
             sortIndex: "desc",
           },
         });
+      }

       const largestSortIndex =
         (
@@ -100,12 +181,22 @@ export const promptVariantsRouter = createTRPCRouter({
         })
       )._max?.sortIndex ?? 0;

+      const newVariantLabel =
+        input.variantId && originalVariant
+          ? `${originalVariant?.label} Copy`
+          : `Prompt Variant ${largestSortIndex + 2}`;
+
+      const newConstructFn = await deriveNewConstructFn(originalVariant, input.newModel);
+
       const createNewVariantAction = prisma.promptVariant.create({
         data: {
           experimentId: input.experimentId,
-          label: `Prompt Variant ${largestSortIndex + 2}`,
-          sortIndex: (lastVariant?.sortIndex ?? 0) + 1,
-          constructFn: lastVariant?.constructFn ?? "",
+          label: newVariantLabel,
+          sortIndex: (originalVariant?.sortIndex ?? 0) + 1,
+          constructFn: newConstructFn,
+          constructFnVersion: 2,
+          model: originalVariant?.model ?? "gpt-3.5-turbo",
+          modelProvider: originalVariant?.modelProvider ?? "openai/ChatCompletion",
         },
       });

@@ -114,10 +205,26 @@ export const promptVariantsRouter = createTRPCRouter({
         recordExperimentUpdated(input.experimentId),
       ]);

+      if (originalVariant) {
+        // Insert new variant to right of original variant
+        await reorderPromptVariants(newVariant.id, originalVariant.id, true);
+      }
+
+      const scenarios = await prisma.testScenario.findMany({
+        where: {
+          experimentId: input.experimentId,
+          visible: true,
+        },
+      });
+
+      for (const scenario of scenarios) {
+        await generateNewCell(newVariant.id, scenario.id);
+      }
+
       return newVariant;
     }),

-  update: publicProcedure
+  update: protectedProcedure
     .input(
       z.object({
         id: z.string(),
@@ -126,7 +233,7 @@ export const promptVariantsRouter = createTRPCRouter({
         }),
       }),
     )
-    .mutation(async ({ input }) => {
+    .mutation(async ({ input, ctx }) => {
       const existing = await prisma.promptVariant.findUnique({
         where: {
           id: input.id,
@@ -137,6 +244,8 @@ export const promptVariantsRouter = createTRPCRouter({
         throw new Error(`Prompt Variant with id ${input.id} does not exist`);
       }

+      await requireCanModifyExperiment(existing.experimentId, ctx);
+
       const updatePromptVariantAction = prisma.promptVariant.update({
         where: {
           id: input.id,
@@ -152,13 +261,18 @@ export const promptVariantsRouter = createTRPCRouter({
       return updatedPromptVariant;
     }),

-  hide: publicProcedure
+  hide: protectedProcedure
     .input(
       z.object({
         id: z.string(),
       }),
     )
-    .mutation(async ({ input }) => {
+    .mutation(async ({ input, ctx }) => {
+      const { experimentId } = await prisma.promptVariant.findUniqueOrThrow({
+        where: { id: input.id },
+      });
+      await requireCanModifyExperiment(experimentId, ctx);
+
       const updatedPromptVariant = await prisma.promptVariant.update({
         where: { id: input.id },
         data: { visible: false, experiment: { update: { updatedAt: new Date() } } },
@@ -167,24 +281,65 @@ export const promptVariantsRouter = createTRPCRouter({
       return updatedPromptVariant;
     }),

-  replaceVariant: publicProcedure
+  getModifiedPromptFn: protectedProcedure
+    .input(
+      z.object({
+        id: z.string(),
+        instructions: z.string().optional(),
+        newModel: ZodModel.optional(),
+      }),
+    )
+    .mutation(async ({ input, ctx }) => {
+      const existing = await prisma.promptVariant.findUniqueOrThrow({
+        where: {
+          id: input.id,
+        },
+      });
+      await requireCanModifyExperiment(existing.experimentId, ctx);
+
+      const constructedPrompt = await parseConstructFn(existing.constructFn);
+
+      if ("error" in constructedPrompt) {
+        return userError(constructedPrompt.error);
+      }
+
+      const promptConstructionFn = await deriveNewConstructFn(
+        existing,
+        input.newModel,
+        input.instructions,
+      );
+
+      // TODO: Validate promptConstructionFn
+      // TODO: Record in some sort of history
+
+      return promptConstructionFn;
+    }),
+
+  replaceVariant: protectedProcedure
     .input(
       z.object({
         id: z.string(),
         constructFn: z.string(),
       }),
     )
-    .mutation(async ({ input }) => {
-      const existing = await prisma.promptVariant.findUnique({
+    .mutation(async ({ input, ctx }) => {
+      const existing = await prisma.promptVariant.findUniqueOrThrow({
         where: {
           id: input.id,
         },
       });
+      await requireCanModifyExperiment(existing.experimentId, ctx);
+
       if (!existing) {
         throw new Error(`Prompt Variant with id ${input.id} does not exist`);
       }
+
+      const parsedPrompt = await parseConstructFn(input.constructFn);
+
+      if ("error" in parsedPrompt) {
+        return userError(parsedPrompt.error);
+      }
+
       // Create a duplicate with only the config changed
       const newVariant = await prisma.promptVariant.create({
         data: {
@@ -193,11 +348,14 @@ export const promptVariantsRouter = createTRPCRouter({
           sortIndex: existing.sortIndex,
           uiId: existing.uiId,
           constructFn: input.constructFn,
+          constructFnVersion: 2,
+          modelProvider: parsedPrompt.modelProvider,
+          model: parsedPrompt.model,
         },
       });

       // Hide anything with the same uiId besides the new one
-      const hideOldVariantsAction = prisma.promptVariant.updateMany({
+      const hideOldVariants = prisma.promptVariant.updateMany({
         where: {
           uiId: existing.uiId,
           id: {
@@ -209,80 +367,35 @@ export const promptVariantsRouter = createTRPCRouter({
         },
       });

-      await prisma.$transaction([
-        hideOldVariantsAction,
-        recordExperimentUpdated(existing.experimentId),
-      ]);
-
-      return newVariant;
+      await prisma.$transaction([hideOldVariants, recordExperimentUpdated(existing.experimentId)]);
+
+      const scenarios = await prisma.testScenario.findMany({
+        where: {
+          experimentId: newVariant.experimentId,
+          visible: true,
+        },
+      });
+
+      for (const scenario of scenarios) {
+        await generateNewCell(newVariant.id, scenario.id);
+      }
+
+      return { status: "ok" } as const;
     }),

-  reorder: publicProcedure
+  reorder: protectedProcedure
     .input(
       z.object({
         draggedId: z.string(),
         droppedId: z.string(),
       }),
     )
-    .mutation(async ({ input }) => {
-      const dragged = await prisma.promptVariant.findUnique({
-        where: {
-          id: input.draggedId,
-        },
-      });
-
-      const dropped = await prisma.promptVariant.findUnique({
-        where: {
-          id: input.droppedId,
-        },
-      });
-
-      if (!dragged || !dropped || dragged.experimentId !== dropped.experimentId) {
-        throw new Error(
-          `Prompt Variant with id ${input.draggedId} or ${input.droppedId} does not exist`,
-        );
-      }
-
-      const visibleItems = await prisma.promptVariant.findMany({
-        where: {
-          experimentId: dragged.experimentId,
-          visible: true,
-        },
-        orderBy: {
-          sortIndex: "asc",
-        },
-      });
-
-      // Remove the dragged item from its current position
-      const orderedItems = visibleItems.filter((item) => item.id !== dragged.id);
-
-      // Find the index of the dragged item and the dropped item
-      const dragIndex = visibleItems.findIndex((item) => item.id === dragged.id);
-      const dropIndex = visibleItems.findIndex((item) => item.id === dropped.id);
-
-      // Determine the new index for the dragged item
-      let newIndex;
-      if (dragIndex < dropIndex) {
-        newIndex = dropIndex + 1; // Insert after the dropped item
-      } else {
-        newIndex = dropIndex; // Insert before the dropped item
-      }
-
-      // Insert the dragged item at the new position
-      orderedItems.splice(newIndex, 0, dragged);
-
-      // Now, we need to update all the items with their new sortIndex
-      await prisma.$transaction(
-        orderedItems.map((item, index) => {
-          return prisma.promptVariant.update({
-            where: {
-              id: item.id,
-            },
-            data: {
-              sortIndex: index,
-            },
-          });
-        }),
-      );
+    .mutation(async ({ input, ctx }) => {
+      const { experimentId } = await prisma.promptVariant.findUniqueOrThrow({
+        where: { id: input.draggedId },
+      });
+      await requireCanModifyExperiment(experimentId, ctx);
+
+      await reorderPromptVariants(input.draggedId, input.droppedId);
     }),
 });
90  src/server/api/routers/scenarioVariantCells.router.ts  Normal file
@@ -0,0 +1,90 @@
import { z } from "zod";
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import { generateNewCell } from "~/server/utils/generateNewCell";
import { queueLLMRetrievalTask } from "~/server/utils/queueLLMRetrievalTask";
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";

export const scenarioVariantCellsRouter = createTRPCRouter({
  get: publicProcedure
    .input(
      z.object({
        scenarioId: z.string(),
        variantId: z.string(),
      }),
    )
    .query(async ({ input, ctx }) => {
      const { experimentId } = await prisma.testScenario.findUniqueOrThrow({
        where: { id: input.scenarioId },
      });
      await requireCanViewExperiment(experimentId, ctx);

      return await prisma.scenarioVariantCell.findUnique({
        where: {
          promptVariantId_testScenarioId: {
            promptVariantId: input.variantId,
            testScenarioId: input.scenarioId,
          },
        },
        include: {
          modelOutput: {
            include: {
              outputEvaluation: {
                include: {
                  evaluation: {
                    select: { label: true },
                  },
                },
              },
            },
          },
        },
      });
    }),
  forceRefetch: protectedProcedure
    .input(
      z.object({
        scenarioId: z.string(),
        variantId: z.string(),
      }),
    )
    .mutation(async ({ input, ctx }) => {
      const { experimentId } = await prisma.testScenario.findUniqueOrThrow({
        where: { id: input.scenarioId },
      });

      await requireCanModifyExperiment(experimentId, ctx);

      const cell = await prisma.scenarioVariantCell.findUnique({
        where: {
          promptVariantId_testScenarioId: {
            promptVariantId: input.variantId,
            testScenarioId: input.scenarioId,
          },
        },
        include: {
          modelOutput: true,
        },
      });

      if (!cell) {
        await generateNewCell(input.variantId, input.scenarioId);
        return true;
      }

      if (cell.modelOutput) {
        // TODO: Maybe keep these around to show previous generations?
        await prisma.modelOutput.delete({
          where: { id: cell.modelOutput.id },
        });
      }

      await prisma.scenarioVariantCell.update({
        where: { id: cell.id },
        data: { retrievalStatus: "PENDING" },
      });

      await queueLLMRetrievalTask(cell.id);
      return true;
    }),
});
@@ -1,79 +1,122 @@
 import { z } from "zod";
-import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
+import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
 import { prisma } from "~/server/db";
 import { autogenerateScenarioValues } from "../autogen";
 import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
-import { reevaluateAll } from "~/server/utils/evaluations";
+import { runAllEvals } from "~/server/utils/evaluations";
+import { generateNewCell } from "~/server/utils/generateNewCell";
+import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
+
+const PAGE_SIZE = 10;
+
 export const scenariosRouter = createTRPCRouter({
-  list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
-    return await prisma.testScenario.findMany({
+  list: publicProcedure
+    .input(z.object({ experimentId: z.string(), page: z.number() }))
+    .query(async ({ input, ctx }) => {
+      await requireCanViewExperiment(input.experimentId, ctx);
+
+      const { experimentId, page } = input;
+
+      const scenarios = await prisma.testScenario.findMany({
        where: {
-          experimentId: input.experimentId,
+          experimentId,
          visible: true,
        },
-        orderBy: {
-          sortIndex: "asc",
+        orderBy: { sortIndex: "asc" },
+        skip: (page - 1) * PAGE_SIZE,
+        take: PAGE_SIZE,
+      });
+
+      const count = await prisma.testScenario.count({
+        where: {
+          experimentId,
+          visible: true,
        },
      });
+
+      return {
+        scenarios,
+        startIndex: (page - 1) * PAGE_SIZE + 1,
+        lastPage: Math.ceil(count / PAGE_SIZE),
+        count,
+      };
    }),

-  create: publicProcedure
+  create: protectedProcedure
    .input(
      z.object({
        experimentId: z.string(),
        autogenerate: z.boolean().optional(),
      }),
    )
-    .mutation(async ({ input }) => {
-      const maxSortIndex =
-        (
-          await prisma.testScenario.aggregate({
+    .mutation(async ({ input, ctx }) => {
+      await requireCanModifyExperiment(input.experimentId, ctx);
+
+      await prisma.testScenario.updateMany({
        where: {
          experimentId: input.experimentId,
        },
-        _max: {
-          sortIndex: true,
+        data: {
+          sortIndex: {
+            increment: 1,
          },
-        })
-        )._max.sortIndex ?? 0;
+        },
+      });

      const createNewScenarioAction = prisma.testScenario.create({
        data: {
          experimentId: input.experimentId,
-          sortIndex: maxSortIndex + 1,
+          sortIndex: 0,
          variableValues: input.autogenerate
            ? await autogenerateScenarioValues(input.experimentId)
            : {},
        },
      });

-      await prisma.$transaction([
+      const [scenario] = await prisma.$transaction([
        createNewScenarioAction,
        recordExperimentUpdated(input.experimentId),
      ]);
+
+      const promptVariants = await prisma.promptVariant.findMany({
+        where: {
+          experimentId: input.experimentId,
+          visible: true,
+        },
+      });
+
+      for (const variant of promptVariants) {
+        await generateNewCell(variant.id, scenario.id);
+      }
    }),

-  hide: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => {
+  hide: protectedProcedure.input(z.object({ id: z.string() })).mutation(async ({ input, ctx }) => {
+    const experimentId = (
+      await prisma.testScenario.findUniqueOrThrow({
+        where: { id: input.id },
+      })
+    ).experimentId;
+
+    await requireCanModifyExperiment(experimentId, ctx);
    const hiddenScenario = await prisma.testScenario.update({
      where: { id: input.id },
      data: { visible: false, experiment: { update: { updatedAt: new Date() } } },
    });

    // Reevaluate all evaluations now that this scenario is hidden
-    await reevaluateAll(hiddenScenario.experimentId);
+    await runAllEvals(hiddenScenario.experimentId);

    return hiddenScenario;
  }),

-  reorder: publicProcedure
+  reorder: protectedProcedure
    .input(
      z.object({
        draggedId: z.string(),
        droppedId: z.string(),
      }),
    )
-    .mutation(async ({ input }) => {
+    .mutation(async ({ input, ctx }) => {
      const dragged = await prisma.testScenario.findUnique({
        where: {
          id: input.draggedId,
@@ -92,6 +135,8 @@ export const scenariosRouter = createTRPCRouter({
        );
      }

+      await requireCanModifyExperiment(dragged.experimentId, ctx);
+
      const visibleItems = await prisma.testScenario.findMany({
        where: {
          experimentId: dragged.experimentId,
@@ -135,14 +180,14 @@ export const scenariosRouter = createTRPCRouter({
      );
    }),

-  replaceWithValues: publicProcedure
+  replaceWithValues: protectedProcedure
    .input(
      z.object({
        id: z.string(),
        values: z.record(z.string()),
      }),
    )
-    .mutation(async ({ input }) => {
+    .mutation(async ({ input, ctx }) => {
      const existing = await prisma.testScenario.findUnique({
        where: {
          id: input.id,
@@ -153,6 +198,8 @@ export const scenariosRouter = createTRPCRouter({
        throw new Error(`Scenario with id ${input.id} does not exist`);
      }

+      await requireCanModifyExperiment(existing.experimentId, ctx);
+
      const newScenario = await prisma.testScenario.create({
        data: {
          experimentId: existing.experimentId,
@@ -175,6 +222,17 @@ export const scenariosRouter = createTRPCRouter({
        },
      });

+      const promptVariants = await prisma.promptVariant.findMany({
+        where: {
+          experimentId: newScenario.experimentId,
+          visible: true,
+        },
+      });
+
+      for (const variant of promptVariants) {
+        await generateNewCell(variant.id, newScenario.id);
+      }
+
      return newScenario;
    }),
});
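
As a quick, standalone sanity check on the pagination math introduced in the `list` procedure above (illustration only, not part of the diff; the scenario count of 43 is made up):

    const PAGE_SIZE = 10;
    const count = 43; // hypothetical number of visible scenarios
    for (const page of [1, 5]) {
      const skip = (page - 1) * PAGE_SIZE;
      const startIndex = skip + 1;
      const lastPage = Math.ceil(count / PAGE_SIZE);
      console.log({ page, skip, startIndex, lastPage });
    }
    // page 1 -> { skip: 0, startIndex: 1, lastPage: 5 }
    // page 5 -> { skip: 40, startIndex: 41, lastPage: 5 }
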
@@ -1,11 +1,14 @@
 import { z } from "zod";
-import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
+import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
 import { prisma } from "~/server/db";
+import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";

 export const templateVarsRouter = createTRPCRouter({
-  create: publicProcedure
+  create: protectedProcedure
    .input(z.object({ experimentId: z.string(), label: z.string() }))
-    .mutation(async ({ input }) => {
+    .mutation(async ({ input, ctx }) => {
+      await requireCanModifyExperiment(input.experimentId, ctx);
+
      await prisma.templateVariable.create({
        data: {
          experimentId: input.experimentId,
@@ -14,11 +17,22 @@ export const templateVarsRouter = createTRPCRouter({
      });
    }),

-  delete: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => {
+  delete: protectedProcedure
+    .input(z.object({ id: z.string() }))
+    .mutation(async ({ input, ctx }) => {
+      const { experimentId } = await prisma.templateVariable.findUniqueOrThrow({
+        where: { id: input.id },
+      });
+
+      await requireCanModifyExperiment(experimentId, ctx);
+
      await prisma.templateVariable.delete({ where: { id: input.id } });
    }),

-  list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
+  list: publicProcedure
+    .input(z.object({ experimentId: z.string() }))
+    .query(async ({ input, ctx }) => {
+      await requireCanViewExperiment(input.experimentId, ctx);
      return await prisma.templateVariable.findMany({
        where: {
          experimentId: input.experimentId,
@@ -27,6 +27,9 @@ type CreateContextOptions = {
   session: Session | null;
 };

+// eslint-disable-next-line @typescript-eslint/no-empty-function
+const noOp = () => {};
+
 /**
  * This helper generates the "internals" for a tRPC context. If you need to use it, you can export
  * it from here.
@@ -41,6 +44,7 @@ const createInnerTRPCContext = (opts: CreateContextOptions) => {
   return {
     session: opts.session,
     prisma,
+    markAccessControlRun: noOp,
   };
 };

@@ -69,6 +73,8 @@ export const createTRPCContext = async (opts: CreateNextContextOptions) => {
  * errors on the backend.
  */
+
+export type TRPCContext = Awaited<ReturnType<typeof createTRPCContext>>;

 const t = initTRPC.context<typeof createTRPCContext>().create({
   transformer: superjson,
   errorFormatter({ shape, error }) {
@@ -106,16 +112,29 @@ export const createTRPCRouter = t.router;
 export const publicProcedure = t.procedure;

 /** Reusable middleware that enforces users are logged in before running the procedure. */
-const enforceUserIsAuthed = t.middleware(({ ctx, next }) => {
+const enforceUserIsAuthed = t.middleware(async ({ ctx, next }) => {
   if (!ctx.session || !ctx.session.user) {
     throw new TRPCError({ code: "UNAUTHORIZED" });
   }
-  return next({
+
+  let accessControlRun = false;
+  const resp = await next({
     ctx: {
       // infers the `session` as non-nullable
       session: { ...ctx.session, user: ctx.session.user },
+      markAccessControlRun: () => {
+        accessControlRun = true;
+      },
     },
   });
+  if (!accessControlRun)
+    throw new TRPCError({
+      code: "INTERNAL_SERVER_ERROR",
+      message:
+        "Protected routes must perform access control checks then explicitly invoke the `ctx.markAccessControlRun()` function to ensure we don't forget access control on a route.",
+    });
+
+  return resp;
 });

 /**
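
The helpers in `~/utils/accessControl` are imported by the routers above but are not shown in this diff. To make the contract of the new middleware concrete, here is a minimal hypothetical sketch of how such a helper could satisfy it (the session shape, the `experiment` lookup, and the membership check are assumptions, not the repo's actual implementation):

    import { TRPCError } from "@trpc/server";
    import { prisma } from "~/server/db";
    import { type TRPCContext } from "~/server/api/trpc";

    // Hypothetical sketch: confirm the caller may modify the experiment, then record that an
    // access-control check ran so enforceUserIsAuthed does not reject the request afterwards.
    export const requireCanModifyExperiment = async (experimentId: string, ctx: TRPCContext) => {
      if (!ctx.session?.user) throw new TRPCError({ code: "UNAUTHORIZED" });

      // Assumed check: the experiment must exist (a real helper would also verify ownership/membership).
      const experiment = await prisma.experiment.findFirst({ where: { id: experimentId } });
      if (!experiment) throw new TRPCError({ code: "FORBIDDEN" });

      ctx.markAccessControlRun();
    };
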
@@ -2,6 +2,8 @@ import { PrismaAdapter } from "@next-auth/prisma-adapter";
 import { type GetServerSidePropsContext } from "next";
 import { getServerSession, type NextAuthOptions, type DefaultSession } from "next-auth";
 import { prisma } from "~/server/db";
+import GitHubProvider from "next-auth/providers/github";
+import { env } from "~/env.mjs";

 /**
  * Module augmentation for `next-auth` types. Allows us to add custom properties to the `session`
@@ -41,20 +43,15 @@ export const authOptions: NextAuthOptions = {
   },
   adapter: PrismaAdapter(prisma),
   providers: [
-    // DiscordProvider({
-    //   clientId: env.DISCORD_CLIENT_ID,
-    //   clientSecret: env.DISCORD_CLIENT_SECRET,
-    // }),
-    /**
-     * ...add more providers here.
-     *
-     * Most other providers require a bit more work than the Discord provider. For example, the
-     * GitHub provider requires you to add the `refresh_token_expires_in` field to the Account
-     * model. Refer to the NextAuth.js docs for the provider you want to use. Example:
-     *
-     * @see https://next-auth.js.org/providers/github
-     */
+    GitHubProvider({
+      clientId: env.GITHUB_CLIENT_ID,
+      clientSecret: env.GITHUB_CLIENT_SECRET,
+    }),
   ],
+  theme: {
+    logo: "/logo.svg",
+    brandColor: "#ff5733",
+  },
 };

 /**
@@ -8,7 +8,10 @@ const globalForPrisma = globalThis as unknown as {
 export const prisma =
   globalForPrisma.prisma ??
   new PrismaClient({
-    log: env.NODE_ENV === "development" ? ["query", "error", "warn"] : ["error"],
+    log:
+      env.NODE_ENV === "development" && !env.RESTRICT_PRISMA_LOGS
+        ? ["query", "error", "warn"]
+        : ["error"],
   });

 if (env.NODE_ENV !== "production") globalForPrisma.prisma = prisma;

45  src/server/scripts/migrateConstructFns.test.ts  (Normal file)
@@ -0,0 +1,45 @@
import "dotenv/config";
import dedent from "dedent";
import { expect, test } from "vitest";
import { migrate1to2 } from "./migrateConstructFns";

test("migrate1to2", () => {
  const constructFn = dedent`
    // Test comment

    prompt = {
      model: "gpt-3.5-turbo-0613",
      messages: [
        {
          role: "user",
          content: "What is the capital of China?"
        }
      ]
    }
  `;

  const migrated = migrate1to2(constructFn);
  expect(migrated).toBe(dedent`
    // Test comment

    definePrompt("openai/ChatCompletion", {
      model: "gpt-3.5-turbo-0613",
      messages: [
        {
          role: "user",
          content: "What is the capital of China?"
        }
      ]
    })
  `);

  // console.log(
  //   migrateConstructFn(dedent`definePrompt(
  //     "openai/ChatCompletion",
  //     {
  //       model: 'gpt-3.5-turbo-0613',
  //       messages: []
  //     }
  //   )`),
  // );
});

58  src/server/scripts/migrateConstructFns.ts  (Normal file)
@@ -0,0 +1,58 @@
import * as recast from "recast";
import { type ASTNode } from "ast-types";
import { prisma } from "../db";
import { fileURLToPath } from "url";
const { builders: b } = recast.types;

export const migrate1to2 = (fnBody: string): string => {
  const ast: ASTNode = recast.parse(fnBody);

  recast.visit(ast, {
    visitAssignmentExpression(path) {
      const node = path.node;
      if ("name" in node.left && node.left.name === "prompt") {
        const functionCall = b.callExpression(b.identifier("definePrompt"), [
          b.literal("openai/ChatCompletion"),
          node.right,
        ]);
        path.replace(functionCall);
      }
      return false;
    },
  });

  return recast.print(ast).code;
};

export default async function migrateConstructFns() {
  const v1Prompts = await prisma.promptVariant.findMany({
    where: {
      constructFnVersion: 1,
    },
  });
  console.log(`Migrating ${v1Prompts.length} prompts 1->2`);
  await Promise.all(
    v1Prompts.map(async (variant) => {
      try {
        await prisma.promptVariant.update({
          where: {
            id: variant.id,
          },
          data: {
            constructFn: migrate1to2(variant.constructFn),
            constructFnVersion: 2,
          },
        });
      } catch (e) {
        console.error("Error migrating constructFn for variant", variant.id, e);
      }
    }),
  );
}

// If we're running this file directly, run the migration
if (process.argv.at(-1) === fileURLToPath(import.meta.url)) {
  console.log("Running migration");
  await migrateConstructFns();
  console.log("Done");
}

19  src/server/scripts/openai-test.ts  (Normal file)
@@ -0,0 +1,19 @@
import "dotenv/config";
import { openai } from "../utils/openai";

const resp = await openai.chat.completions.create({
  model: "gpt-3.5-turbo-0613",
  stream: true,
  messages: [
    {
      role: "user",
      content: "count to 20",
    },
  ],
});

for await (const part of resp) {
  console.log("part", part);
}

console.log("final resp", resp);

26  src/server/scripts/replicate-test.ts  (Normal file)
@@ -0,0 +1,26 @@
/* eslint-disable */

import "dotenv/config";
import Replicate from "replicate";

const replicate = new Replicate({
  auth: process.env.REPLICATE_API_TOKEN || "",
});

console.log("going to run");
const prediction = await replicate.predictions.create({
  version: "3725a659b5afff1a0ba9bead5fac3899d998feaad00e07032ca2b0e35eb14f8a",
  input: {
    prompt: "...",
  },
});

console.log("waiting");
setInterval(() => {
  replicate.predictions.get(prediction.id).then((prediction) => {
    console.log(prediction);
  });
}, 500);
// const output = await replicate.wait(prediction, {});

// console.log(output);

31  src/server/tasks/defineTask.ts  (Normal file)
@@ -0,0 +1,31 @@
// Import necessary dependencies
import { quickAddJob, type Helpers, type Task } from "graphile-worker";
import { env } from "~/env.mjs";

// Define the defineTask function
function defineTask<TPayload>(
  taskIdentifier: string,
  taskHandler: (payload: TPayload, helpers: Helpers) => Promise<void>,
) {
  const enqueue = async (payload: TPayload) => {
    console.log("Enqueuing task", taskIdentifier, payload);
    await quickAddJob({ connectionString: env.DATABASE_URL }, taskIdentifier, payload);
  };

  const handler = (payload: TPayload, helpers: Helpers) => {
    helpers.logger.info(`Running task ${taskIdentifier} with payload: ${JSON.stringify(payload)}`);
    return taskHandler(payload, helpers);
  };

  const task = {
    identifier: taskIdentifier,
    handler: handler as Task,
  };

  return {
    enqueue,
    task,
  };
}

export default defineTask;
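
`defineTask` returns both an `enqueue` helper and a `task` descriptor, but the worker entry point that consumes the descriptor is not part of this diff. A minimal hypothetical sketch of how a task such as `queryLLM` (defined in the next file) could be registered with graphile-worker's standard `run` API (file location and option values are assumptions):

    // Hypothetical worker entry point registering tasks with graphile-worker.
    import { run } from "graphile-worker";
    import { env } from "~/env.mjs";
    import { queryLLM } from "./queryLLM.task";

    const runner = await run({
      connectionString: env.DATABASE_URL,
      concurrency: 5, // assumed value
      taskList: {
        [queryLLM.task.identifier]: queryLLM.task.handler,
      },
    });

    await runner.promise; // resolves when the runner is stopped
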

165  src/server/tasks/queryLLM.task.ts  (Normal file)
@@ -0,0 +1,165 @@
import { prisma } from "~/server/db";
import defineTask from "./defineTask";
import { sleep } from "../utils/sleep";
import { generateChannel } from "~/utils/generateChannel";
import { runEvalsForOutput } from "../utils/evaluations";
import { type Prisma } from "@prisma/client";
import parseConstructFn from "../utils/parseConstructFn";
import hashPrompt from "../utils/hashPrompt";
import { type JsonObject } from "type-fest";
import modelProviders from "~/modelProviders/modelProviders";
import { wsConnection } from "~/utils/wsConnection";

export type queryLLMJob = {
  scenarioVariantCellId: string;
};

const MAX_AUTO_RETRIES = 10;
const MIN_DELAY = 500; // milliseconds
const MAX_DELAY = 15000; // milliseconds

function calculateDelay(numPreviousTries: number): number {
  const baseDelay = Math.min(MAX_DELAY, MIN_DELAY * Math.pow(2, numPreviousTries));
  const jitter = Math.random() * baseDelay;
  return baseDelay + jitter;
}

export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
  const { scenarioVariantCellId } = task;
  const cell = await prisma.scenarioVariantCell.findUnique({
    where: { id: scenarioVariantCellId },
    include: { modelOutput: true },
  });
  if (!cell) {
    await prisma.scenarioVariantCell.update({
      where: { id: scenarioVariantCellId },
      data: {
        statusCode: 404,
        errorMessage: "Cell not found",
        retrievalStatus: "ERROR",
      },
    });
    return;
  }

  // If cell is not pending, then some other job is already processing it
  if (cell.retrievalStatus !== "PENDING") {
    return;
  }
  await prisma.scenarioVariantCell.update({
    where: { id: scenarioVariantCellId },
    data: {
      retrievalStatus: "IN_PROGRESS",
    },
  });

  const variant = await prisma.promptVariant.findUnique({
    where: { id: cell.promptVariantId },
  });
  if (!variant) {
    await prisma.scenarioVariantCell.update({
      where: { id: scenarioVariantCellId },
      data: {
        statusCode: 404,
        errorMessage: "Prompt Variant not found",
        retrievalStatus: "ERROR",
      },
    });
    return;
  }

  const scenario = await prisma.testScenario.findUnique({
    where: { id: cell.testScenarioId },
  });
  if (!scenario) {
    await prisma.scenarioVariantCell.update({
      where: { id: scenarioVariantCellId },
      data: {
        statusCode: 404,
        errorMessage: "Scenario not found",
        retrievalStatus: "ERROR",
      },
    });
    return;
  }

  const prompt = await parseConstructFn(variant.constructFn, scenario.variableValues as JsonObject);

  if ("error" in prompt) {
    await prisma.scenarioVariantCell.update({
      where: { id: scenarioVariantCellId },
      data: {
        statusCode: 400,
        errorMessage: prompt.error,
        retrievalStatus: "ERROR",
      },
    });
    return;
  }

  const provider = modelProviders[prompt.modelProvider];

  const streamingChannel = provider.shouldStream(prompt.modelInput) ? generateChannel() : null;

  if (streamingChannel) {
    // Save streaming channel so that UI can connect to it
    await prisma.scenarioVariantCell.update({
      where: { id: scenarioVariantCellId },
      data: { streamingChannel },
    });
  }
  const onStream = streamingChannel
    ? (partialOutput: (typeof provider)["_outputSchema"]) => {
        wsConnection.emit("message", { channel: streamingChannel, payload: partialOutput });
      }
    : null;

  for (let i = 0; true; i++) {
    const response = await provider.getCompletion(prompt.modelInput, onStream);
    if (response.type === "success") {
      const inputHash = hashPrompt(prompt);

      const modelOutput = await prisma.modelOutput.create({
        data: {
          scenarioVariantCellId,
          inputHash,
          output: response.value as Prisma.InputJsonObject,
          timeToComplete: response.timeToComplete,
          promptTokens: response.promptTokens,
          completionTokens: response.completionTokens,
          cost: response.cost,
        },
      });

      await prisma.scenarioVariantCell.update({
        where: { id: scenarioVariantCellId },
        data: {
          statusCode: response.statusCode,
          retrievalStatus: "COMPLETE",
        },
      });

      await runEvalsForOutput(variant.experimentId, scenario, modelOutput);
      break;
    } else {
      const shouldRetry = response.autoRetry && i < MAX_AUTO_RETRIES;
      const delay = calculateDelay(i);

      await prisma.scenarioVariantCell.update({
        where: { id: scenarioVariantCellId },
        data: {
          errorMessage: response.message,
          statusCode: response.statusCode,
          retryTime: shouldRetry ? new Date(Date.now() + delay) : null,
          retrievalStatus: "ERROR",
        },
      });

      if (shouldRetry) {
        await sleep(delay);
      } else {
        break;
      }
    }
  }
});
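
For reference, `calculateDelay` above doubles the base delay from MIN_DELAY on each retry, caps it at MAX_DELAY, and then adds up to 100% jitter, so the n-th retry waits between base and 2×base milliseconds. A standalone check of the base schedule (illustration only, not part of the diff):

    const MIN_DELAY = 500;
    const MAX_DELAY = 15000;
    const baseDelay = (n: number) => Math.min(MAX_DELAY, MIN_DELAY * Math.pow(2, n));
    console.log([0, 1, 2, 3, 4, 5, 6].map(baseDelay));
    // [500, 1000, 2000, 4000, 8000, 15000, 15000]

With MAX_AUTO_RETRIES = 10, a cell that keeps failing with a retryable error is attempted at most 11 times before it is left in the ERROR state.
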
Some files were not shown because too many files have changed in this diff.