Compare commits
28 Commits
space-out-...function-p
| Author | SHA1 | Date |
|---|---|---|
|  | 328cd4f5e6 |  |
|  | 26b6fa4f0c |  |
|  | 807665fdc1 |  |
|  | d6597d2c8a |  |
|  | 566d67bf48 |  |
|  | d4fb8b689a |  |
|  | 98b231c8bd |  |
|  | 45afb1f1f4 |  |
|  | 2bffb03766 |  |
|  | 223b990005 |  |
|  | fa61c9c472 |  |
|  | 1309a6ec5d |  |
|  | 17a6fd31a5 |  |
|  | e1cbeccb90 |  |
|  | d6b97b29f7 |  |
|  | 09140f8b5f |  |
|  | 9952dd93d8 |  |
|  | e0b457c6c5 |  |
|  | 0c37506975 |  |
|  | 2b2e0ab8ee |  |
|  | 3dbb06ec00 |  |
|  | 85d42a014b |  |
|  | 7d1ded3b18 |  |
|  | b00f6dd04b |  |
|  | 2e395e4d39 |  |
|  | 4b06d05908 |  |
|  | aabf355b81 |  |
|  | cc1d1178da |  |
@@ -1,2 +1,2 @@
-src/codegen/openai.schema.json
+*.schema.json
 pnpm-lock.yaml

.vscode/settings.json (5 changes, vendored)

@@ -1,6 +1,3 @@
 {
-  "eslint.format.enable": true,
-  "editor.codeActionsOnSave": {
-    "source.fixAll.eslint": true
-  }
+  "eslint.format.enable": true
 }

@@ -45,6 +45,7 @@ Natively supports [OpenAI function calls](https://openai.com/blog/function-calli
 
 - All models available through the OpenAI [chat completion API](https://platform.openai.com/docs/guides/gpt/chat-completions-api)
 - Llama2 [7b chat](https://replicate.com/a16z-infra/llama7b-v2-chat), [13b chat](https://replicate.com/a16z-infra/llama13b-v2-chat), [70b chat](https://replicate.com/replicate/llama70b-v2-chat).
+- Anthropic's [Claude 1 Instant](https://www.anthropic.com/index/introducing-claude) and [Claude 2](https://www.anthropic.com/index/claude-2)
 
 ## Running Locally
 

@@ -12,7 +12,7 @@
     "dev:next": "next dev",
     "dev:wss": "pnpm tsx --watch src/wss-server.ts",
     "dev:worker": "NODE_ENV='development' pnpm tsx --watch src/server/tasks/worker.ts",
-    "dev": "concurrently --kill-others 'pnpm dev:next' 'pnpm dev:wss'",
+    "dev": "concurrently --kill-others 'pnpm dev:next' 'pnpm dev:wss' 'pnpm dev:worker'",
     "postinstall": "prisma generate",
     "lint": "next lint",
     "start": "next start",
@@ -21,6 +21,7 @@
     "check": "concurrently 'pnpm lint' 'pnpm tsc' 'pnpm prettier . --check'"
   },
   "dependencies": {
+    "@anthropic-ai/sdk": "^0.5.8",
     "@apidevtools/json-schema-ref-parser": "^10.1.0",
     "@babel/preset-typescript": "^7.22.5",
     "@babel/standalone": "^7.22.9",
@@ -59,6 +60,7 @@
     "lodash-es": "^4.17.21",
     "next": "^13.4.2",
     "next-auth": "^4.22.1",
+    "next-query-params": "^4.2.3",
     "nextjs-routes": "^2.0.1",
     "openai": "4.0.0-beta.2",
     "pluralize": "^8.0.0",
@@ -79,6 +81,8 @@
     "superjson": "1.12.2",
     "tsx": "^3.12.7",
     "type-fest": "^4.0.0",
+    "use-query-params": "^2.2.1",
+    "uuid": "^9.0.0",
     "vite-tsconfig-paths": "^4.2.0",
     "zod": "^3.21.4",
     "zustand": "^4.3.9"
@@ -99,6 +103,7 @@
     "@types/react": "^18.2.6",
     "@types/react-dom": "^18.2.4",
     "@types/react-syntax-highlighter": "^15.5.7",
+    "@types/uuid": "^9.0.2",
     "@typescript-eslint/eslint-plugin": "^5.59.6",
     "@typescript-eslint/parser": "^5.59.6",
     "eslint": "^8.40.0",

pnpm-lock.yaml (75 changes, generated)

@@ -5,6 +5,9 @@ settings:
   excludeLinksFromLockfile: false
 
 dependencies:
+  '@anthropic-ai/sdk':
+    specifier: ^0.5.8
+    version: 0.5.8
   '@apidevtools/json-schema-ref-parser':
     specifier: ^10.1.0
     version: 10.1.0
@@ -119,6 +122,9 @@ dependencies:
   next-auth:
     specifier: ^4.22.1
     version: 4.22.1(next@13.4.2)(react-dom@18.2.0)(react@18.2.0)
+  next-query-params:
+    specifier: ^4.2.3
+    version: 4.2.3(next@13.4.2)(react@18.2.0)(use-query-params@2.2.1)
   nextjs-routes:
     specifier: ^2.0.1
     version: 2.0.1(next@13.4.2)
@@ -179,6 +185,12 @@ dependencies:
   type-fest:
     specifier: ^4.0.0
     version: 4.0.0
+  use-query-params:
+    specifier: ^2.2.1
+    version: 2.2.1(react-dom@18.2.0)(react@18.2.0)
+  uuid:
+    specifier: ^9.0.0
+    version: 9.0.0
   vite-tsconfig-paths:
     specifier: ^4.2.0
     version: 4.2.0(typescript@5.0.4)
@@ -235,6 +247,9 @@ devDependencies:
   '@types/react-syntax-highlighter':
     specifier: ^15.5.7
     version: 15.5.7
+  '@types/uuid':
+    specifier: ^9.0.2
+    version: 9.0.2
   '@typescript-eslint/eslint-plugin':
     specifier: ^5.59.6
     version: 5.59.6(@typescript-eslint/parser@5.59.6)(eslint@8.40.0)(typescript@5.0.4)
@@ -286,6 +301,22 @@ packages:
       '@jridgewell/gen-mapping': 0.3.3
       '@jridgewell/trace-mapping': 0.3.18
 
+  /@anthropic-ai/sdk@0.5.8:
+    resolution: {integrity: sha512-iHenjcE2Q/az6VZiP1DueOSvKNRmxsly6Rx2yjJBoy7OBYVFGVjEdgs2mPQHtTX0ibKAR7tPq6F6MQbKDPWcKg==}
+    dependencies:
+      '@types/node': 18.16.0
+      '@types/node-fetch': 2.6.4
+      abort-controller: 3.0.0
+      agentkeepalive: 4.3.0
+      digest-fetch: 1.3.0
+      form-data-encoder: 1.7.2
+      formdata-node: 4.4.1
+      node-fetch: 2.6.12
+    transitivePeerDependencies:
+      - encoding
+      - supports-color
+    dev: false
+
   /@apidevtools/json-schema-ref-parser@10.1.0:
     resolution: {integrity: sha512-3e+viyMuXdrcK8v5pvP+SDoAQ77FH6OyRmuK48SZKmdHJRFm87RsSs8qm6kP39a/pOPURByJw+OXzQIqcfmKtA==}
     engines: {node: '>= 16'}
@@ -3018,6 +3049,10 @@ packages:
     resolution: {integrity: sha512-cputDpIbFgLUaGQn6Vqg3/YsJwxUwHLO13v3i5ouxT4lat0khip9AEWxtERujXV9wxIB1EyF97BSJFt6vpdI8g==}
     dev: false
 
+  /@types/uuid@9.0.2:
+    resolution: {integrity: sha512-kNnC1GFBLuhImSnV7w4njQkUiJi0ZXUycu1rUaouPqiKlXkh77JKgdRnTAp1x5eBwcIwbtI+3otwzuIDEuDoxQ==}
+    dev: true
+
   /@typescript-eslint/eslint-plugin@5.59.6(@typescript-eslint/parser@5.59.6)(eslint@8.40.0)(typescript@5.0.4):
     resolution: {integrity: sha512-sXtOgJNEuRU5RLwPUb1jxtToZbgvq3M6FPpY4QENxoOggK+UpTxUBpj6tD8+Qh2g46Pi9We87E+eHnUw8YcGsw==}
     engines: {node: ^12.22.0 || ^14.17.0 || >=16.0.0}
@@ -6037,6 +6072,19 @@ packages:
       uuid: 8.3.2
     dev: false
 
+  /next-query-params@4.2.3(next@13.4.2)(react@18.2.0)(use-query-params@2.2.1):
+    resolution: {integrity: sha512-hGNCYRH8YyA5ItiBGSKrtMl21b2MAqfPkdI1mvwloNVqSU142IaGzqHN+OTovyeLIpQfonY01y7BAHb/UH4POg==}
+    peerDependencies:
+      next: ^10.0.0 || ^11.0.0 || ^12.0.0 || ^13.0.0
+      react: ^16.8.0 || ^17.0.0 || ^18.0.0
+      use-query-params: ^2.0.0
+    dependencies:
+      next: 13.4.2(@babel/core@7.22.9)(react-dom@18.2.0)(react@18.2.0)
+      react: 18.2.0
+      tslib: 2.6.0
+      use-query-params: 2.2.1(react-dom@18.2.0)(react@18.2.0)
+    dev: false
+
   /next-tick@1.1.0:
     resolution: {integrity: sha512-CXdUiJembsNjuToQvxayPZF9Vqht7hewsvy2sOWafLvi2awflj9mOC6bHIg50orX8IJvWKY9wYQ/zB2kogPslQ==}
     dev: false
@@ -7147,6 +7195,10 @@ packages:
       randombytes: 2.1.0
     dev: true
 
+  /serialize-query-params@2.0.2:
+    resolution: {integrity: sha512-1chMo1dST4pFA9RDXAtF0Rbjaut4is7bzFbI1Z26IuMub68pNCILku85aYmeFhvnY//BXUPUhoRMjYcsT93J/Q==}
+    dev: false
+
   /serve-static@1.15.0:
     resolution: {integrity: sha512-XGuRDNjXUijsUL0vl6nSD7cwURuzEgglbOaFuZM9g3kwDXOWVTck0jLzjPzGD+TazWbboZYu52/9/XPdUgne9g==}
     engines: {node: '>= 0.8.0'}
@@ -7824,6 +7876,24 @@ packages:
       use-isomorphic-layout-effect: 1.1.2(@types/react@18.2.6)(react@18.2.0)
     dev: false
 
+  /use-query-params@2.2.1(react-dom@18.2.0)(react@18.2.0):
+    resolution: {integrity: sha512-i6alcyLB8w9i3ZK3caNftdb+UnbfBRNPDnc89CNQWkGRmDrm/gfydHvMBfVsQJRq3NoHOM2dt/ceBWG2397v1Q==}
+    peerDependencies:
+      '@reach/router': ^1.2.1
+      react: '>=16.8.0'
+      react-dom: '>=16.8.0'
+      react-router-dom: '>=5'
+    peerDependenciesMeta:
+      '@reach/router':
+        optional: true
+      react-router-dom:
+        optional: true
+    dependencies:
+      react: 18.2.0
+      react-dom: 18.2.0(react@18.2.0)
+      serialize-query-params: 2.0.2
+    dev: false
+
   /use-sidecar@1.1.2(@types/react@18.2.6)(react@18.2.0):
     resolution: {integrity: sha512-epTbsLuzZ7lPClpz2TyryBfztm7m+28DlEv2ZCQ3MDr5ssiwyOwGH/e5F9CkfWjJ1t4clvI58yF822/GUkjjhw==}
     engines: {node: '>=10'}
@@ -7872,6 +7942,11 @@ packages:
     hasBin: true
     dev: false
 
+  /uuid@9.0.0:
+    resolution: {integrity: sha512-MXcSTerfPa4uqyzStbRoTgt5XIe3x5+42+q1sDuy3R5MDk66URdLMOZe5aPX/SQd+kuYAh0FdP/pO28IkQyTeg==}
+    hasBin: true
+    dev: false
+
   /vary@1.1.2:
     resolution: {integrity: sha512-BNGbWLfd0eUPabhkXUVm0j8uuvREyTh5ovRa/dyow/BqAbZJyC+5fU+IzQOzmAKzYqYRAISoRhdQr3eIZ/PXqg==}
     engines: {node: '>= 0.8'}

@@ -0,0 +1,8 @@
+/*
+  Warnings:
+
+  - You are about to drop the column `streamingChannel` on the `ScenarioVariantCell` table. All the data in the column will be lost.
+
+*/
+-- AlterTable
+ALTER TABLE "ScenarioVariantCell" DROP COLUMN "streamingChannel";

@@ -0,0 +1,52 @@
+-- DropForeignKey
+ALTER TABLE "ModelOutput" DROP CONSTRAINT "ModelOutput_scenarioVariantCellId_fkey";
+
+-- DropForeignKey
+ALTER TABLE "OutputEvaluation" DROP CONSTRAINT "OutputEvaluation_modelOutputId_fkey";
+
+-- DropIndex
+DROP INDEX "OutputEvaluation_modelOutputId_evaluationId_key";
+
+-- AlterTable
+ALTER TABLE "OutputEvaluation" RENAME COLUMN "modelOutputId" TO "modelResponseId";
+
+-- AlterTable
+ALTER TABLE "ScenarioVariantCell" DROP COLUMN "retryTime",
+DROP COLUMN "statusCode",
+ADD COLUMN "jobQueuedAt" TIMESTAMP(3),
+ADD COLUMN "jobStartedAt" TIMESTAMP(3);
+
+ALTER TABLE "ModelOutput" RENAME TO "ModelResponse";
+
+ALTER TABLE "ModelResponse"
+ADD COLUMN "requestedAt" TIMESTAMP(3),
+ADD COLUMN "receivedAt" TIMESTAMP(3),
+ADD COLUMN "statusCode" INTEGER,
+ADD COLUMN "errorMessage" TEXT,
+ADD COLUMN "retryTime" TIMESTAMP(3),
+ADD COLUMN "outdated" BOOLEAN NOT NULL DEFAULT false;
+
+-- 3. Remove the unnecessary column
+ALTER TABLE "ModelResponse"
+DROP COLUMN "timeToComplete";
+
+-- AlterTable
+ALTER TABLE "ModelResponse" RENAME CONSTRAINT "ModelOutput_pkey" TO "ModelResponse_pkey";
+ALTER TABLE "ModelResponse" ALTER COLUMN "output" DROP NOT NULL;
+
+-- DropIndex
+DROP INDEX "ModelOutput_scenarioVariantCellId_key";
+
+-- AddForeignKey
+ALTER TABLE "ModelResponse" ADD CONSTRAINT "ModelResponse_scenarioVariantCellId_fkey" FOREIGN KEY ("scenarioVariantCellId") REFERENCES "ScenarioVariantCell"("id") ON DELETE CASCADE ON UPDATE CASCADE;
+
+-- RenameIndex
+ALTER INDEX "ModelOutput_inputHash_idx" RENAME TO "ModelResponse_inputHash_idx";
+
+-- CreateIndex
+CREATE UNIQUE INDEX "OutputEvaluation_modelResponseId_evaluationId_key" ON "OutputEvaluation"("modelResponseId", "evaluationId");
+
+-- AddForeignKey
+ALTER TABLE "OutputEvaluation" ADD CONSTRAINT "OutputEvaluation_modelResponseId_fkey" FOREIGN KEY ("modelResponseId") REFERENCES "ModelResponse"("id") ON DELETE CASCADE ON UPDATE CASCADE;

@@ -22,10 +22,10 @@ model Experiment {
   createdAt DateTime @default(now())
   updatedAt DateTime @updatedAt
 
-  TemplateVariable TemplateVariable[]
-  PromptVariant    PromptVariant[]
-  TestScenario     TestScenario[]
-  Evaluation       Evaluation[]
+  templateVariables TemplateVariable[]
+  promptVariants    PromptVariant[]
+  testScenarios     TestScenario[]
+  evaluations       Evaluation[]
 }
 
 model PromptVariant {
@@ -90,13 +90,11 @@ enum CellRetrievalStatus {
 model ScenarioVariantCell {
   id String @id @default(uuid()) @db.Uuid
 
-  statusCode       Int?
-  errorMessage     String?
-  retryTime        DateTime?
-  streamingChannel String?
-  retrievalStatus  CellRetrievalStatus @default(COMPLETE)
+  retrievalStatus CellRetrievalStatus @default(COMPLETE)
+  jobQueuedAt     DateTime?
+  jobStartedAt    DateTime?
+  modelResponses  ModelResponse[]
+  errorMessage    String? // Contains errors that occurred independently of model responses
 
-  modelOutput ModelOutput?
-
   promptVariantId String @db.Uuid
   promptVariant   PromptVariant @relation(fields: [promptVariantId], references: [id], onDelete: Cascade)
@@ -111,24 +109,28 @@ model ScenarioVariantCell {
   @@unique([promptVariantId, testScenarioId])
 }
 
-model ModelOutput {
+model ModelResponse {
   id String @id @default(uuid()) @db.Uuid
 
   inputHash String
-  output           Json
-  timeToComplete   Int @default(0)
-  cost             Float?
-  promptTokens     Int?
-  completionTokens Int?
+  requestedAt      DateTime?
+  receivedAt       DateTime?
+  output           Json?
+  cost             Float?
+  promptTokens     Int?
+  completionTokens Int?
+  statusCode       Int?
+  errorMessage     String?
+  retryTime        DateTime?
+  outdated         Boolean @default(false)
 
   createdAt DateTime @default(now())
   updatedAt DateTime @updatedAt
 
   scenarioVariantCellId String @db.Uuid
   scenarioVariantCell   ScenarioVariantCell @relation(fields: [scenarioVariantCellId], references: [id], onDelete: Cascade)
-  outputEvaluation  OutputEvaluation[]
+  outputEvaluations OutputEvaluation[]
 
-  @@unique([scenarioVariantCellId])
   @@index([inputHash])
 }
 
@@ -150,7 +152,7 @@ model Evaluation {
 
   createdAt DateTime @default(now())
   updatedAt DateTime @updatedAt
-  OutputEvaluation OutputEvaluation[]
+  outputEvaluations OutputEvaluation[]
 }
 
 model OutputEvaluation {
@@ -160,8 +162,8 @@ model OutputEvaluation {
   result  Float
   details String?
 
-  modelOutputId String @db.Uuid
-  modelOutput   ModelOutput @relation(fields: [modelOutputId], references: [id], onDelete: Cascade)
+  modelResponseId String @db.Uuid
+  modelResponse   ModelResponse @relation(fields: [modelResponseId], references: [id], onDelete: Cascade)
 
   evaluationId String @db.Uuid
   evaluation   Evaluation @relation(fields: [evaluationId], references: [id], onDelete: Cascade)
@@ -169,7 +171,7 @@ model OutputEvaluation {
   createdAt DateTime @default(now())
   updatedAt DateTime @updatedAt
 
-  @@unique([modelOutputId, evaluationId])
+  @@unique([modelResponseId, evaluationId])
 }
 
 model Organization {
@@ -179,8 +181,8 @@ model Organization {
 
   createdAt DateTime @default(now())
   updatedAt DateTime @updatedAt
-  OrganizationUser OrganizationUser[]
-  Experiment       Experiment[]
+  organizationUsers OrganizationUser[]
+  experiments       Experiment[]
 }
 
 enum OrganizationUserRole {
@@ -234,15 +236,15 @@ model Session {
 }
 
 model User {
   id            String    @id @default(uuid()) @db.Uuid
   name          String?
   email         String?   @unique
   emailVerified DateTime?
   image         String?
   accounts      Account[]
   sessions      Session[]
-  OrganizationUser OrganizationUser[]
-  Organization     Organization[]
+  organizationUsers OrganizationUser[]
+  organizations     Organization[]
 }
 
 model VerificationToken {

@@ -7,9 +7,13 @@ const defaultId = "11111111-1111-1111-1111-111111111111";
 await prisma.organization.deleteMany({
   where: { id: defaultId },
 });
-await prisma.organization.create({
-  data: { id: defaultId },
-});
+// If there's an existing org, just seed into it
+const org =
+  (await prisma.organization.findFirst({})) ??
+  (await prisma.organization.create({
+    data: { id: defaultId },
+  }));
 
 await prisma.experiment.deleteMany({
   where: {
@@ -21,7 +25,7 @@ await prisma.experiment.create({
   data: {
     id: defaultId,
     label: "Country Capitals Example",
-    organizationId: defaultId,
+    organizationId: org.id,
   },
 });
 
@@ -103,30 +107,41 @@ await prisma.testScenario.deleteMany({
   },
 });
 
+const countries = [
+  "Afghanistan",
+  "Albania",
+  "Algeria",
+  "Andorra",
+  "Angola",
+  "Antigua and Barbuda",
+  "Argentina",
+  "Armenia",
+  "Australia",
+  "Austria",
+  "Austrian Empire",
+  "Azerbaijan",
+  "Baden",
+  "Bahamas, The",
+  "Bahrain",
+  "Bangladesh",
+  "Barbados",
+  "Bavaria",
+  "Belarus",
+  "Belgium",
+  "Belize",
+  "Benin (Dahomey)",
+  "Bolivia",
+  "Bosnia and Herzegovina",
+  "Botswana",
+];
 await prisma.testScenario.createMany({
-  data: [
-    {
-      experimentId: defaultId,
-      sortIndex: 0,
-      variableValues: {
-        country: "Spain",
-      },
+  data: countries.map((country, i) => ({
+    experimentId: defaultId,
+    sortIndex: i,
+    variableValues: {
+      country: country,
     },
-    {
-      experimentId: defaultId,
-      sortIndex: 1,
-      variableValues: {
-        country: "USA",
-      },
-    },
-    {
-      experimentId: defaultId,
-      sortIndex: 2,
-      variableValues: {
-        country: "Chile",
-      },
-    },
-  ],
+  })),
 });
 
 const variants = await prisma.promptVariant.findMany({
@@ -149,5 +164,5 @@ await Promise.all(
       testScenarioId: scenario.id,
     })),
   )
-    .map((cell) => generateNewCell(cell.promptVariantId, cell.testScenarioId)),
+    .map((cell) => generateNewCell(cell.promptVariantId, cell.testScenarioId, { stream: false })),
 );

@@ -6,4 +6,7 @@ echo "Migrating the database"
 pnpm prisma migrate deploy
 
 echo "Starting the server"
-pnpm start
+pnpm concurrently --kill-others \
+  "pnpm start" \
+  "pnpm tsx src/server/tasks/worker.ts"

@@ -1,5 +1,7 @@
 import {
   Button,
+  HStack,
+  Icon,
   Modal,
   ModalBody,
   ModalCloseButton,
@@ -7,24 +9,21 @@ import {
   ModalFooter,
   ModalHeader,
   ModalOverlay,
-  VStack,
-  Text,
   Spinner,
-  HStack,
-  Icon,
+  Text,
+  VStack,
 } from "@chakra-ui/react";
-import { RiExchangeFundsFill } from "react-icons/ri";
-import { useState } from "react";
-import { ModelStatsCard } from "./ModelStatsCard";
-import { ModelSearch } from "./ModelSearch";
-import { api } from "~/utils/api";
-import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
-import CompareFunctions from "../RefinePromptModal/CompareFunctions";
 import { type PromptVariant } from "@prisma/client";
 import { isObject, isString } from "lodash-es";
-import { type Model, type SupportedProvider } from "~/modelProviders/types";
-import frontendModelProviders from "~/modelProviders/frontendModelProviders";
-import { keyForModel } from "~/utils/utils";
+import { useState } from "react";
+import { RiExchangeFundsFill } from "react-icons/ri";
+import { type ProviderModel } from "~/modelProviders/types";
+import { api } from "~/utils/api";
+import { useExperiment, useHandledAsyncCallback, useVisibleScenarioIds } from "~/utils/hooks";
+import { lookupModel, modelLabel } from "~/utils/utils";
+import CompareFunctions from "../RefinePromptModal/CompareFunctions";
+import { ModelSearch } from "./ModelSearch";
+import { ModelStatsCard } from "./ModelStatsCard";
 
 export const ChangeModelModal = ({
   variant,
@@ -33,11 +32,14 @@ export const ChangeModelModal = ({
   variant: PromptVariant;
   onClose: () => void;
 }) => {
-  const originalModelProviderName = variant.modelProvider as SupportedProvider;
-  const originalModelProvider = frontendModelProviders[originalModelProviderName];
-  const originalModel = originalModelProvider.models[variant.model] as Model;
-  const [selectedModel, setSelectedModel] = useState<Model>(originalModel);
-  const [convertedModel, setConvertedModel] = useState<Model | undefined>(undefined);
+  const originalModel = lookupModel(variant.modelProvider, variant.model);
+  const [selectedModel, setSelectedModel] = useState({
+    provider: variant.modelProvider,
+    model: variant.model,
+  } as ProviderModel);
+  const [convertedModel, setConvertedModel] = useState<ProviderModel | undefined>();
+  const visibleScenarios = useVisibleScenarioIds();
 
   const utils = api.useContext();
 
   const experiment = useExperiment();
@@ -67,14 +69,16 @@ export const ChangeModelModal = ({
     await replaceVariantMutation.mutateAsync({
       id: variant.id,
       constructFn: modifiedPromptFn,
+      streamScenarios: visibleScenarios,
     });
     await utils.promptVariants.list.invalidate();
     onClose();
   }, [replaceVariantMutation, variant, onClose, modifiedPromptFn]);
 
-  const originalModelLabel = keyForModel(originalModel);
-  const selectedModelLabel = keyForModel(selectedModel);
-  const convertedModelLabel = convertedModel ? keyForModel(convertedModel) : undefined;
+  const originalLabel = modelLabel(variant.modelProvider, variant.model);
+  const selectedLabel = modelLabel(selectedModel.provider, selectedModel.model);
+  const convertedLabel =
+    convertedModel && modelLabel(convertedModel.provider, convertedModel.model);
 
   return (
     <Modal
@@ -94,16 +98,19 @@ export const ChangeModelModal = ({
       <ModalBody maxW="unset">
         <VStack spacing={8}>
           <ModelStatsCard label="Original Model" model={originalModel} />
-          {originalModelLabel !== selectedModelLabel && (
-            <ModelStatsCard label="New Model" model={selectedModel} />
+          {originalLabel !== selectedLabel && (
+            <ModelStatsCard
+              label="New Model"
+              model={lookupModel(selectedModel.provider, selectedModel.model)}
+            />
           )}
           <ModelSearch selectedModel={selectedModel} setSelectedModel={setSelectedModel} />
           {isString(modifiedPromptFn) && (
             <CompareFunctions
               originalFunction={variant.constructFn}
               newFunction={modifiedPromptFn}
-              leftTitle={originalModelLabel}
-              rightTitle={convertedModelLabel}
+              leftTitle={originalLabel}
+              rightTitle={convertedLabel}
             />
           )}
         </VStack>
@@ -115,7 +122,7 @@ export const ChangeModelModal = ({
             colorScheme="gray"
             onClick={getModifiedPromptFn}
             minW={24}
-            isDisabled={originalModel === selectedModel || modificationInProgress}
+            isDisabled={originalLabel === selectedLabel || modificationInProgress}
           >
             {modificationInProgress ? <Spinner boxSize={4} /> : <Text>Convert</Text>}
           </Button>

@@ -1,49 +1,35 @@
-import { VStack, Text } from "@chakra-ui/react";
-import { type LegacyRef, useCallback } from "react";
-import Select, { type SingleValue } from "react-select";
+import { Text, VStack } from "@chakra-ui/react";
+import { type LegacyRef } from "react";
+import Select from "react-select";
 import { useElementDimensions } from "~/utils/hooks";
 
+import { flatMap } from "lodash-es";
 import frontendModelProviders from "~/modelProviders/frontendModelProviders";
-import { type Model } from "~/modelProviders/types";
-import { keyForModel } from "~/utils/utils";
+import { type ProviderModel } from "~/modelProviders/types";
+import { modelLabel } from "~/utils/utils";
 
-const modelOptions: { label: string; value: Model }[] = [];
+const modelOptions = flatMap(Object.entries(frontendModelProviders), ([providerId, provider]) =>
+  Object.entries(provider.models).map(([modelId]) => ({
+    provider: providerId,
+    model: modelId,
+  })),
+) as ProviderModel[];
 
-for (const [_, providerValue] of Object.entries(frontendModelProviders)) {
-  for (const [_, modelValue] of Object.entries(providerValue.models)) {
-    modelOptions.push({
-      label: keyForModel(modelValue),
-      value: modelValue,
-    });
-  }
-}
-
-export const ModelSearch = ({
-  selectedModel,
-  setSelectedModel,
-}: {
-  selectedModel: Model;
-  setSelectedModel: (model: Model) => void;
+export const ModelSearch = (props: {
+  selectedModel: ProviderModel;
+  setSelectedModel: (model: ProviderModel) => void;
 }) => {
-  const handleSelection = useCallback(
-    (option: SingleValue<{ label: string; value: Model }>) => {
-      if (!option) return;
-      setSelectedModel(option.value);
-    },
-    [setSelectedModel],
-  );
-  const selectedOption = modelOptions.find((option) => option.label === keyForModel(selectedModel));
-
   const [containerRef, containerDimensions] = useElementDimensions();
 
   return (
     <VStack ref={containerRef as LegacyRef<HTMLDivElement>} w="full">
       <Text>Browse Models</Text>
-      <Select
+      <Select<ProviderModel>
         styles={{ control: (provided) => ({ ...provided, width: containerDimensions?.width }) }}
-        value={selectedOption}
+        getOptionLabel={(data) => modelLabel(data.provider, data.model)}
+        getOptionValue={(data) => modelLabel(data.provider, data.model)}
         options={modelOptions}
-        onChange={handleSelection}
+        onChange={(option) => option && props.setSelectedModel(option)}
       />
     </VStack>
   );

@@ -1,15 +1,22 @@
 import {
-  VStack,
-  Text,
-  HStack,
-  type StackProps,
   GridItem,
-  SimpleGrid,
+  HStack,
   Link,
+  SimpleGrid,
+  Text,
+  VStack,
+  type StackProps,
 } from "@chakra-ui/react";
-import { type Model } from "~/modelProviders/types";
+import { type lookupModel } from "~/utils/utils";
 
-export const ModelStatsCard = ({ label, model }: { label: string; model: Model }) => {
+export const ModelStatsCard = ({
+  label,
+  model,
+}: {
+  label: string;
+  model: ReturnType<typeof lookupModel>;
+}) => {
+  if (!model) return null;
   return (
     <VStack w="full" align="start">
       <Text fontWeight="bold" fontSize="sm" textTransform="uppercase">

src/components/ExperimentSettingsDrawer/DeleteButton.tsx (69 changes, new file)

@@ -0,0 +1,69 @@
+import {
+  Button,
+  Icon,
+  AlertDialog,
+  AlertDialogBody,
+  AlertDialogFooter,
+  AlertDialogHeader,
+  AlertDialogContent,
+  AlertDialogOverlay,
+  useDisclosure,
+  Text,
+} from "@chakra-ui/react";
+
+import { useRouter } from "next/router";
+import { useRef } from "react";
+import { BsTrash } from "react-icons/bs";
+import { api } from "~/utils/api";
+import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
+
+export const DeleteButton = () => {
+  const experiment = useExperiment();
+  const mutation = api.experiments.delete.useMutation();
+  const utils = api.useContext();
+  const router = useRouter();
+
+  const { isOpen, onOpen, onClose } = useDisclosure();
+  const cancelRef = useRef<HTMLButtonElement>(null);
+
+  const [onDeleteConfirm] = useHandledAsyncCallback(async () => {
+    if (!experiment.data?.id) return;
+    await mutation.mutateAsync({ id: experiment.data.id });
+    await utils.experiments.list.invalidate();
+    await router.push({ pathname: "/experiments" });
+    onClose();
+  }, [mutation, experiment.data?.id, router]);
+
+  return (
+    <>
+      <Button size="sm" variant="ghost" colorScheme="red" fontWeight="normal" onClick={onOpen}>
+        <Icon as={BsTrash} boxSize={4} />
+        <Text ml={2}>Delete Experiment</Text>
+      </Button>
+
+      <AlertDialog isOpen={isOpen} leastDestructiveRef={cancelRef} onClose={onClose}>
+        <AlertDialogOverlay>
+          <AlertDialogContent>
+            <AlertDialogHeader fontSize="lg" fontWeight="bold">
+              Delete Experiment
+            </AlertDialogHeader>
+
+            <AlertDialogBody>
+              If you delete this experiment all the associated prompts and scenarios will be deleted
+              as well. Are you sure?
+            </AlertDialogBody>
+
+            <AlertDialogFooter>
+              <Button ref={cancelRef} onClick={onClose}>
+                Cancel
+              </Button>
+              <Button colorScheme="red" onClick={onDeleteConfirm} ml={3}>
+                Delete
+              </Button>
+            </AlertDialogFooter>
+          </AlertDialogContent>
+        </AlertDialogOverlay>
+      </AlertDialog>
+    </>
+  );
+};

@@ -6,13 +6,14 @@ import {
   DrawerHeader,
   DrawerOverlay,
   Heading,
-  Stack,
+  VStack,
 } from "@chakra-ui/react";
-import EditScenarioVars from "./EditScenarioVars";
-import EditEvaluations from "./EditEvaluations";
+import EditScenarioVars from "../OutputsTable/EditScenarioVars";
+import EditEvaluations from "../OutputsTable/EditEvaluations";
 import { useAppStore } from "~/state/store";
+import { DeleteButton } from "./DeleteButton";
 
-export default function SettingsDrawer() {
+export default function ExperimentSettingsDrawer() {
   const isOpen = useAppStore((state) => state.drawerOpen);
   const closeDrawer = useAppStore((state) => state.closeDrawer);
 
@@ -22,13 +23,16 @@ export default function SettingsDrawer() {
       <DrawerContent>
         <DrawerCloseButton />
         <DrawerHeader>
-          <Heading size="md">Settings</Heading>
+          <Heading size="md">Experiment Settings</Heading>
         </DrawerHeader>
-        <DrawerBody>
-          <Stack spacing={6}>
-            <EditScenarioVars />
-            <EditEvaluations />
-          </Stack>
+        <DrawerBody h="full" pb={4}>
+          <VStack h="full" justifyContent="space-between">
+            <VStack spacing={6}>
+              <EditScenarioVars />
+              <EditEvaluations />
+            </VStack>
+            <DeleteButton />
+          </VStack>
         </DrawerBody>
       </DrawerContent>
     </Drawer>

@@ -1,7 +1,13 @@
 import { Box, Flex, Icon, Spinner } from "@chakra-ui/react";
 import { BsPlus } from "react-icons/bs";
+import { Text } from "@chakra-ui/react";
 import { api } from "~/utils/api";
-import { useExperiment, useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
+import {
+  useExperiment,
+  useExperimentAccess,
+  useHandledAsyncCallback,
+  useVisibleScenarioIds,
+} from "~/utils/hooks";
 import { cellPadding } from "../constants";
 import { ActionButton } from "./ScenariosHeader";
 
@@ -9,11 +15,13 @@ export default function AddVariantButton() {
   const experiment = useExperiment();
   const mutation = api.promptVariants.create.useMutation();
   const utils = api.useContext();
+  const visibleScenarios = useVisibleScenarioIds();
 
   const [onClick, loading] = useHandledAsyncCallback(async () => {
     if (!experiment.data) return;
     await mutation.mutateAsync({
       experimentId: experiment.data.id,
+      streamScenarios: visibleScenarios,
     });
     await utils.promptVariants.list.invalidate();
   }, [mutation]);
@@ -25,9 +33,10 @@ export default function AddVariantButton() {
     <Flex w="100%" justifyContent="flex-end">
       <ActionButton
         onClick={onClick}
+        py={5}
         leftIcon={<Icon as={loading ? Spinner : BsPlus} boxSize={6} mr={loading ? 1 : 0} />}
       >
-        Add Variant
+        <Text display={{ base: "none", md: "flex" }}>Add Variant</Text>
       </ActionButton>
       {/* <Button
         alignItems="center"

src/components/OutputsTable/OutputCell/CellContent.tsx (19 changes, new file)

@@ -0,0 +1,19 @@
+import { type StackProps, VStack } from "@chakra-ui/react";
+import { CellOptions } from "./CellOptions";
+
+export const CellContent = ({
+  hardRefetch,
+  hardRefetching,
+  children,
+  ...props
+}: {
+  hardRefetch: () => void;
+  hardRefetching: boolean;
+} & StackProps) => (
+  <VStack w="full" alignItems="flex-start" {...props}>
+    <CellOptions refetchingOutput={hardRefetching} refetchOutput={hardRefetch} />
+    <VStack w="full" alignItems="flex-start" maxH={500} overflowY="auto">
+      {children}
+    </VStack>
+  </VStack>
+);

@@ -1,4 +1,4 @@
-import { Button, HStack, Icon, Tooltip } from "@chakra-ui/react";
+import { Button, HStack, Icon, Spinner, Tooltip } from "@chakra-ui/react";
 import { BsArrowClockwise } from "react-icons/bs";
 import { useExperimentAccess } from "~/utils/hooks";
 
@@ -12,7 +12,7 @@ export const CellOptions = ({
   const { canModify } = useExperimentAccess();
   return (
     <HStack justifyContent="flex-end" w="full">
-      {!refetchingOutput && canModify && (
+      {canModify && (
         <Tooltip label="Refetch output" aria-label="refetch output">
           <Button
             size="xs"
@@ -28,7 +28,7 @@ export const CellOptions = ({
             onClick={refetchOutput}
             aria-label="refetch output"
           >
-            <Icon as={BsArrowClockwise} boxSize={4} />
+            <Icon as={refetchingOutput ? Spinner : BsArrowClockwise} boxSize={4} />
           </Button>
         </Tooltip>
       )}

@@ -1,16 +1,19 @@
 import { api } from "~/utils/api";
 import { type PromptVariant, type Scenario } from "../types";
-import { Spinner, Text, Center, VStack } from "@chakra-ui/react";
+import { Text, VStack } from "@chakra-ui/react";
 import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
 import SyntaxHighlighter from "react-syntax-highlighter";
 import { docco } from "react-syntax-highlighter/dist/cjs/styles/hljs";
 import stringify from "json-stringify-pretty-compact";
-import { type ReactElement, useState, useEffect } from "react";
+import { type ReactElement, useState, useEffect, Fragment } from "react";
 import useSocket from "~/utils/useSocket";
 import { OutputStats } from "./OutputStats";
-import { ErrorHandler } from "./ErrorHandler";
-import { CellOptions } from "./CellOptions";
+import { RetryCountdown } from "./RetryCountdown";
 import frontendModelProviders from "~/modelProviders/frontendModelProviders";
+import { ResponseLog } from "./ResponseLog";
+import { CellContent } from "./CellContent";
+
+const WAITING_MESSAGE_INTERVAL = 20000;
 
 export default function OutputCell({
   scenario,
@@ -60,40 +63,97 @@ export default function OutputCell({
 
   const awaitingOutput =
     !cell ||
+    !cell.evalsComplete ||
     cell.retrievalStatus === "PENDING" ||
     cell.retrievalStatus === "IN_PROGRESS" ||
     hardRefetching;
   useEffect(() => setRefetchInterval(awaitingOutput ? 1000 : 0), [awaitingOutput]);
 
-  const modelOutput = cell?.modelOutput;
-  // Disconnect from socket if we're not streaming anymore
-  const streamedMessage = useSocket<OutputSchema>(cell?.streamingChannel);
+  // TODO: disconnect from socket if we're not streaming anymore
+  const streamedMessage = useSocket<OutputSchema>(cell?.id);
 
   if (!vars) return null;
 
-  if (disabledReason) return <Text color="gray.500">{disabledReason}</Text>;
-
-  if (awaitingOutput && !streamedMessage)
+  if (!cell && !fetchingOutput)
     return (
-      <Center h="100%" w="100%">
-        <Spinner />
-      </Center>
+      <CellContent hardRefetching={hardRefetching} hardRefetch={hardRefetch}>
+        <Text color="gray.500">Error retrieving output</Text>
+      </CellContent>
     );
 
-  if (!cell && !fetchingOutput) return <Text color="gray.500">Error retrieving output</Text>;
-
   if (cell && cell.errorMessage) {
-    return <ErrorHandler cell={cell} refetchOutput={hardRefetch} />;
+    return (
+      <CellContent hardRefetching={hardRefetching} hardRefetch={hardRefetch}>
+        <Text color="red.500">{cell.errorMessage}</Text>
+      </CellContent>
+    );
   }
 
-  const normalizedOutput = modelOutput
-    ? provider.normalizeOutput(modelOutput.output)
+  if (disabledReason) return <Text color="gray.500">{disabledReason}</Text>;
+
+  const mostRecentResponse = cell?.modelResponses[cell.modelResponses.length - 1];
+  const showLogs = !streamedMessage && !mostRecentResponse?.output;
+
+  if (showLogs)
+    return (
+      <CellContent
+        hardRefetching={hardRefetching}
+        hardRefetch={hardRefetch}
+        alignItems="flex-start"
+        fontFamily="inconsolata, monospace"
+        spacing={0}
+      >
+        {cell?.jobQueuedAt && <ResponseLog time={cell.jobQueuedAt} title="Job queued" />}
+        {cell?.jobStartedAt && <ResponseLog time={cell.jobStartedAt} title="Job started" />}
+        {cell?.modelResponses?.map((response) => {
+          let numWaitingMessages = 0;
+          const relativeWaitingTime = response.receivedAt
+            ? response.receivedAt.getTime()
+            : Date.now();
+          if (response.requestedAt) {
+            numWaitingMessages = Math.floor(
+              (relativeWaitingTime - response.requestedAt.getTime()) / WAITING_MESSAGE_INTERVAL,
+            );
+          }
+          return (
+            <Fragment key={response.id}>
+              {response.requestedAt && (
+                <ResponseLog time={response.requestedAt} title="Request sent to API" />
+              )}
+              {response.requestedAt &&
+                Array.from({ length: numWaitingMessages }, (_, i) => (
+                  <ResponseLog
+                    key={`waiting-${i}`}
+                    // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
+                    time={new Date(response.requestedAt!.getTime() + i * WAITING_MESSAGE_INTERVAL)}
+                    title="Waiting for response"
+                  />
+                ))}
+              {response.receivedAt && (
+                <ResponseLog
+                  time={response.receivedAt}
+                  title="Response received from API"
+                  message={`statusCode: ${response.statusCode ?? ""}\n ${
+                    response.errorMessage ?? ""
+                  }`}
+                />
+              )}
+            </Fragment>
+          );
+        }) ?? null}
+        {mostRecentResponse?.retryTime && (
+          <RetryCountdown retryTime={mostRecentResponse.retryTime} />
+        )}
+      </CellContent>
+    );
+
+  const normalizedOutput = mostRecentResponse?.output
+    ? provider.normalizeOutput(mostRecentResponse?.output)
     : streamedMessage
     ? provider.normalizeOutput(streamedMessage)
     : null;
 
-  if (modelOutput && normalizedOutput?.type === "json") {
+  if (mostRecentResponse?.output && normalizedOutput?.type === "json") {
     return (
       <VStack
         w="100%"
@@ -103,8 +163,13 @@ export default function OutputCell({
         overflowX="hidden"
         justifyContent="space-between"
       >
-        <VStack w="full" flex={1} spacing={0}>
-          <CellOptions refetchingOutput={hardRefetching} refetchOutput={hardRefetch} />
+        <CellContent
+          hardRefetching={hardRefetching}
+          hardRefetch={hardRefetch}
+          w="full"
+          flex={1}
+          spacing={0}
+        >
           <SyntaxHighlighter
             customStyle={{ overflowX: "unset", width: "100%", flex: 1 }}
             language="json"
@@ -116,8 +181,8 @@ export default function OutputCell({
           >
             {stringify(normalizedOutput.value, { maxLength: 40 })}
           </SyntaxHighlighter>
-        </VStack>
-        <OutputStats modelOutput={modelOutput} scenario={scenario} />
+        </CellContent>
+        <OutputStats modelResponse={mostRecentResponse} scenario={scenario} />
       </VStack>
     );
   }
@@ -127,10 +192,13 @@ export default function OutputCell({
   return (
     <VStack w="100%" h="100%" justifyContent="space-between" whiteSpace="pre-wrap">
       <VStack w="full" alignItems="flex-start" spacing={0}>
-        <CellOptions refetchingOutput={hardRefetching} refetchOutput={hardRefetch} />
+        <CellContent hardRefetching={hardRefetching} hardRefetch={hardRefetch}>
           <Text>{contentToDisplay}</Text>
+        </CellContent>
       </VStack>
-      {modelOutput && <OutputStats modelOutput={modelOutput} scenario={scenario} />}
+      {mostRecentResponse?.output && (
+        <OutputStats modelResponse={mostRecentResponse} scenario={scenario} />
+      )}
     </VStack>
   );
 }

@@ -7,28 +7,32 @@ import { CostTooltip } from "~/components/tooltip/CostTooltip";
 const SHOW_TIME = true;
 
 export const OutputStats = ({
-  modelOutput,
+  modelResponse,
 }: {
-  modelOutput: NonNullable<
-    NonNullable<RouterOutputs["scenarioVariantCells"]["get"]>["modelOutput"]
+  modelResponse: NonNullable<
+    NonNullable<RouterOutputs["scenarioVariantCells"]["get"]>["modelResponses"][0]
   >;
   scenario: Scenario;
 }) => {
-  const timeToComplete = modelOutput.timeToComplete;
+  const timeToComplete =
+    modelResponse.receivedAt && modelResponse.requestedAt
+      ? modelResponse.receivedAt.getTime() - modelResponse.requestedAt.getTime()
+      : 0;
 
-  const promptTokens = modelOutput.promptTokens;
-  const completionTokens = modelOutput.completionTokens;
+  const promptTokens = modelResponse.promptTokens;
+  const completionTokens = modelResponse.completionTokens;
 
   return (
     <HStack w="full" align="center" color="gray.500" fontSize="2xs" mt={{ base: 0, md: 1 }}>
       <HStack flex={1}>
-        {modelOutput.outputEvaluation.map((evaluation) => {
+        {modelResponse.outputEvaluations.map((evaluation) => {
           const passed = evaluation.result > 0.5;
           return (
             <Tooltip
               isDisabled={!evaluation.details}
               label={evaluation.details}
               key={evaluation.id}
+              shouldWrapChildren
             >
               <HStack spacing={0}>
                 <Text>{evaluation.evaluation.label}</Text>
@@ -42,15 +46,15 @@ export const OutputStats = ({
             );
           })}
         </HStack>
-        {modelOutput.cost && (
+        {modelResponse.cost && (
           <CostTooltip
             promptTokens={promptTokens}
             completionTokens={completionTokens}
-            cost={modelOutput.cost}
+            cost={modelResponse.cost}
           >
             <HStack spacing={0}>
               <Icon as={BsCurrencyDollar} />
-              <Text mr={1}>{modelOutput.cost.toFixed(3)}</Text>
+              <Text mr={1}>{modelResponse.cost.toFixed(3)}</Text>
             </HStack>
           </CostTooltip>
         )}
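Note on the latency change above: the time shown is now derived from two timestamps on the response row rather than read from a stored field. A minimal standalone sketch of that derivation (assuming `requestedAt` and `receivedAt` are nullable `Date` values, as the hunk suggests):

// Sketch only: elapsed milliseconds between request and response.
// Falls back to 0 when either timestamp is missing, matching the hunk above.
const elapsedMs = (requestedAt: Date | null, receivedAt: Date | null): number =>
  receivedAt && requestedAt ? receivedAt.getTime() - requestedAt.getTime() : 0;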
22 src/components/OutputsTable/OutputCell/ResponseLog.tsx Normal file
@@ -0,0 +1,22 @@
+import { HStack, VStack, Text } from "@chakra-ui/react";
+import dayjs from "dayjs";
+
+export const ResponseLog = ({
+  time,
+  title,
+  message,
+}: {
+  time: Date;
+  title: string;
+  message?: string;
+}) => {
+  return (
+    <VStack spacing={0} alignItems="flex-start">
+      <HStack>
+        <Text>{dayjs(time).format("HH:mm:ss")}</Text>
+        <Text>{title}</Text>
+      </HStack>
+      {message && <Text pl={4}>{message}</Text>}
+    </VStack>
+  );
+};
@@ -1,21 +1,12 @@
-import { type ScenarioVariantCell } from "@prisma/client";
-import { VStack, Text } from "@chakra-ui/react";
+import { Text } from "@chakra-ui/react";
 import { useEffect, useState } from "react";
 import pluralize from "pluralize";
 
-export const ErrorHandler = ({
-  cell,
-  refetchOutput,
-}: {
-  cell: ScenarioVariantCell;
-  refetchOutput: () => void;
-}) => {
+export const RetryCountdown = ({ retryTime }: { retryTime: Date }) => {
   const [msToWait, setMsToWait] = useState(0);
 
   useEffect(() => {
-    if (!cell.retryTime) return;
-
-    const initialWaitTime = cell.retryTime.getTime() - Date.now();
+    const initialWaitTime = retryTime.getTime() - Date.now();
     const msModuloOneSecond = initialWaitTime % 1000;
     let remainingTime = initialWaitTime - msModuloOneSecond;
     setMsToWait(remainingTime);
@@ -36,18 +27,13 @@ export const ErrorHandler = ({
       clearInterval(interval);
       clearTimeout(timeout);
     };
-  }, [cell.retryTime, cell.statusCode, setMsToWait, refetchOutput]);
+  }, [retryTime]);
 
+  if (msToWait <= 0) return null;
+
   return (
-    <VStack w="full">
-      <Text color="red.600" wordBreak="break-word">
-        {cell.errorMessage}
-      </Text>
-      {msToWait > 0 && (
-        <Text color="red.600" fontSize="sm">
-          Retrying in {pluralize("second", Math.ceil(msToWait / 1000), true)}...
-        </Text>
-      )}
-    </VStack>
+    <Text color="red.600" fontSize="sm">
+      Retrying in {pluralize("second", Math.ceil(msToWait / 1000), true)}...
+    </Text>
   );
 };
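The interval and timeout bookkeeping between these two hunks is not shown in the diff. Purely as an illustration of the countdown pattern the component appears to use (a sketch under that assumption, not the project's actual code), a self-contained hook could look like:

import { useEffect, useState } from "react";

// Hypothetical standalone hook illustrating the countdown pattern above:
// snap the remaining wait down to a whole second, then tick once per second.
export function useRetryCountdown(retryTime: Date): number {
  const [msToWait, setMsToWait] = useState(0);

  useEffect(() => {
    const initialWaitTime = retryTime.getTime() - Date.now();
    const msModuloOneSecond = initialWaitTime % 1000;
    let remainingTime = initialWaitTime - msModuloOneSecond;
    setMsToWait(remainingTime);

    let interval: ReturnType<typeof setInterval> | undefined;
    // Wait out the fractional second, then decrement in one-second steps.
    const timeout = setTimeout(() => {
      interval = setInterval(() => {
        remainingTime -= 1000;
        setMsToWait(remainingTime);
        if (remainingTime <= 0 && interval) clearInterval(interval);
      }, 1000);
    }, msModuloOneSecond);

    return () => {
      if (interval) clearInterval(interval);
      clearTimeout(timeout);
    };
  }, [retryTime]);

  return msToWait;
}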
74 src/components/OutputsTable/ScenarioPaginator.tsx Normal file
@@ -0,0 +1,74 @@
+import { Box, HStack, IconButton } from "@chakra-ui/react";
+import {
+  BsChevronDoubleLeft,
+  BsChevronDoubleRight,
+  BsChevronLeft,
+  BsChevronRight,
+} from "react-icons/bs";
+import { usePage, useScenarios } from "~/utils/hooks";
+
+const ScenarioPaginator = () => {
+  const [page, setPage] = usePage();
+  const { data } = useScenarios();
+
+  if (!data) return null;
+
+  const { scenarios, startIndex, lastPage, count } = data;
+
+  const nextPage = () => {
+    if (page < lastPage) {
+      setPage(page + 1, "replace");
+    }
+  };
+
+  const prevPage = () => {
+    if (page > 1) {
+      setPage(page - 1, "replace");
+    }
+  };
+
+  const goToLastPage = () => setPage(lastPage, "replace");
+  const goToFirstPage = () => setPage(1, "replace");
+
+  return (
+    <HStack pt={4}>
+      <IconButton
+        variant="ghost"
+        size="sm"
+        onClick={goToFirstPage}
+        isDisabled={page === 1}
+        aria-label="Go to first page"
+        icon={<BsChevronDoubleLeft />}
+      />
+      <IconButton
+        variant="ghost"
+        size="sm"
+        onClick={prevPage}
+        isDisabled={page === 1}
+        aria-label="Previous page"
+        icon={<BsChevronLeft />}
+      />
+      <Box>
+        {startIndex}-{startIndex + scenarios.length - 1} / {count}
+      </Box>
+      <IconButton
+        variant="ghost"
+        size="sm"
+        onClick={nextPage}
+        isDisabled={page === lastPage}
+        aria-label="Next page"
+        icon={<BsChevronRight />}
+      />
+      <IconButton
+        variant="ghost"
+        size="sm"
+        onClick={goToLastPage}
+        isDisabled={page === lastPage}
+        aria-label="Go to last page"
+        icon={<BsChevronDoubleRight />}
+      />
+    </HStack>
+  );
+};
+
+export default ScenarioPaginator;
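`usePage` and `useScenarios` are imported from `~/utils/hooks`, which is outside this diff. Purely as an illustration of the `setPage(page, "replace")` signature used above, a query-string-backed `usePage` might be sketched like this (hypothetical; the real hook may differ):

import { useRouter } from "next/router";
import { useCallback } from "react";

// Hypothetical sketch of query-param-backed page state; illustration only.
export function usePage() {
  const router = useRouter();
  const page = parseInt((router.query.page as string) ?? "1", 10) || 1;

  const setPage = useCallback(
    (nextPage: number, mode: "push" | "replace" = "push") => {
      const url = { query: { ...router.query, page: String(nextPage) } };
      if (mode === "replace") {
        void router.replace(url, undefined, { shallow: true });
      } else {
        void router.push(url, undefined, { shallow: true });
      }
    },
    [router],
  );

  return [page, setPage] as const;
}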
@@ -12,7 +12,12 @@ import {
   Spinner,
 } from "@chakra-ui/react";
 import { cellPadding } from "../constants";
-import { useExperiment, useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
+import {
+  useExperiment,
+  useExperimentAccess,
+  useHandledAsyncCallback,
+  useScenarios,
+} from "~/utils/hooks";
 import { BsGear, BsPencil, BsPlus, BsStars } from "react-icons/bs";
 import { useAppStore } from "~/state/store";
 import { api } from "~/utils/api";
@@ -21,9 +26,10 @@ export const ActionButton = (props: ButtonProps) => (
   <Button size="sm" variant="ghost" color="gray.600" {...props} />
 );
 
-export const ScenariosHeader = (props: { numScenarios: number }) => {
+export const ScenariosHeader = () => {
   const openDrawer = useAppStore((s) => s.openDrawer);
   const { canModify } = useExperimentAccess();
+  const scenarios = useScenarios();
+
   const experiment = useExperiment();
   const createScenarioMutation = api.scenarios.create.useMutation();
@@ -44,19 +50,22 @@ export const ScenariosHeader = (props: { numScenarios: number }) => {
   return (
     <HStack w="100%" pb={cellPadding.y} pt={0} align="center" spacing={0}>
       <Text fontSize={16} fontWeight="bold">
-        Scenarios ({props.numScenarios})
+        Scenarios ({scenarios.data?.count})
       </Text>
       {canModify && (
         <Menu>
-          <MenuButton mt={1}>
-            <IconButton
-              variant="ghost"
-              aria-label="Edit Scenarios"
-              icon={<Icon as={loading ? Spinner : BsGear} />}
-            />
-          </MenuButton>
-          <MenuList fontSize="md">
-            <MenuItem icon={<Icon as={BsPlus} boxSize={6} />} onClick={() => onAddScenario(false)}>
+          <MenuButton
+            as={IconButton}
+            mt={1}
+            variant="ghost"
+            aria-label="Edit Scenarios"
+            icon={<Icon as={loading ? Spinner : BsGear} />}
+          />
+          <MenuList fontSize="md" zIndex="dropdown" mt={-3}>
+            <MenuItem
+              icon={<Icon as={BsPlus} boxSize={6} mx="-5px" />}
+              onClick={() => onAddScenario(false)}
+            >
               Add Scenario
             </MenuItem>
             <MenuItem icon={<BsStars />} onClick={() => onAddScenario(true)}>
@@ -1,17 +1,52 @@
-import { Box, Button, HStack, Spinner, Tooltip, useToast, Text } from "@chakra-ui/react";
-import { useRef, useEffect, useState, useCallback } from "react";
-import { useExperimentAccess, useHandledAsyncCallback, useModifierKeyLabel } from "~/utils/hooks";
-import { type PromptVariant } from "./types";
-import { api } from "~/utils/api";
+import {
+  Box,
+  Button,
+  HStack,
+  IconButton,
+  Spinner,
+  Text,
+  Tooltip,
+  useToast,
+} from "@chakra-ui/react";
+import { useCallback, useEffect, useRef, useState } from "react";
+import { FiMaximize, FiMinimize } from "react-icons/fi";
+import { editorBackground } from "~/state/sharedVariantEditor.slice";
 import { useAppStore } from "~/state/store";
+import { api } from "~/utils/api";
+import {
+  useExperimentAccess,
+  useHandledAsyncCallback,
+  useModifierKeyLabel,
+  useVisibleScenarioIds,
+} from "~/utils/hooks";
+import { type PromptVariant } from "./types";
 
 export default function VariantEditor(props: { variant: PromptVariant }) {
   const { canModify } = useExperimentAccess();
   const monaco = useAppStore.use.sharedVariantEditor.monaco();
   const editorRef = useRef<ReturnType<NonNullable<typeof monaco>["editor"]["create"]> | null>(null);
+  const containerRef = useRef<HTMLDivElement | null>(null);
   const [editorId] = useState(() => `editor_${Math.random().toString(36).substring(7)}`);
   const [isChanged, setIsChanged] = useState(false);
 
+  const [isFullscreen, setIsFullscreen] = useState(false);
+
+  const toggleFullscreen = useCallback(() => {
+    setIsFullscreen((prev) => !prev);
+    editorRef.current?.focus();
+  }, [setIsFullscreen]);
+
+  useEffect(() => {
+    const handleEsc = (event: KeyboardEvent) => {
+      if (event.key === "Escape" && isFullscreen) {
+        toggleFullscreen();
+      }
+    };
+
+    window.addEventListener("keydown", handleEsc);
+    return () => window.removeEventListener("keydown", handleEsc);
+  }, [isFullscreen, toggleFullscreen]);
+
   const lastSavedFn = props.variant.constructFn;
 
   const modifierKey = useModifierKeyLabel();
@@ -33,6 +68,7 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
   const replaceVariant = api.promptVariants.replaceVariant.useMutation();
   const utils = api.useContext();
   const toast = useToast();
+  const visibleScenarios = useVisibleScenarioIds();
 
   const [onSave, saveInProgress] = useHandledAsyncCallback(async () => {
     if (!editorRef.current) return;
@@ -61,6 +97,7 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
     const resp = await replaceVariant.mutateAsync({
       id: props.variant.id,
       constructFn: currentFn,
+      streamScenarios: visibleScenarios,
     });
     if (resp.status === "error") {
       return toast({
@@ -99,11 +136,23 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
       readOnly: !canModify,
     });
 
+    // Workaround because otherwise the commands only work on whatever
+    // editor was loaded on the page last.
+    // https://github.com/microsoft/monaco-editor/issues/2947#issuecomment-1422265201
     editorRef.current.onDidFocusEditorText(() => {
-      // Workaround because otherwise the command only works on whatever
-      // editor was loaded on the page last.
-      // https://github.com/microsoft/monaco-editor/issues/2947#issuecomment-1422265201
-      editorRef.current?.addCommand(monaco.KeyMod.CtrlCmd | monaco.KeyCode.Enter, onSave);
+      editorRef.current?.addCommand(monaco.KeyMod.CtrlCmd | monaco.KeyCode.KeyS, onSave);
+
+      editorRef.current?.addCommand(
+        monaco.KeyMod.CtrlCmd | monaco.KeyMod.Shift | monaco.KeyCode.KeyF,
+        toggleFullscreen,
+      );
+
+      // Exit fullscreen with escape
+      editorRef.current?.addCommand(monaco.KeyCode.Escape, () => {
+        if (isFullscreen) {
+          toggleFullscreen();
+        }
+      });
     });
 
     editorRef.current.onDidChangeModelContent(checkForChanges);
@@ -132,8 +181,40 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
   }, [canModify]);
 
   return (
-    <Box w="100%" pos="relative">
-      <div id={editorId} style={{ height: "400px", width: "100%" }}></div>
+    <Box
+      w="100%"
+      ref={containerRef}
+      sx={
+        isFullscreen
+          ? {
+              position: "fixed",
+              top: 0,
+              left: 0,
+              right: 0,
+              bottom: 0,
+            }
+          : { h: "400px", w: "100%" }
+      }
+      bgColor={editorBackground}
+      zIndex={isFullscreen ? 1000 : "unset"}
+      pos="relative"
+      _hover={{ ".fullscreen-toggle": { opacity: 1 } }}
+    >
+      <Box id={editorId} w="100%" h="100%" />
+      <Tooltip label={`${modifierKey} + ⇧ + F`}>
+        <IconButton
+          className="fullscreen-toggle"
+          aria-label="Minimize"
+          icon={isFullscreen ? <FiMinimize /> : <FiMaximize />}
+          position="absolute"
+          top={2}
+          right={2}
+          onClick={toggleFullscreen}
+          opacity={0}
+          transition="opacity 0.2s"
+        />
+      </Tooltip>
+
       {isChanged && (
         <HStack pos="absolute" bottom={2} right={2}>
           <Button
@@ -146,7 +227,7 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
           >
             Reset
           </Button>
-          <Tooltip label={`${modifierKey} + Enter`}>
+          <Tooltip label={`${modifierKey} + S`}>
             <Button size="sm" onClick={onSave} colorScheme="blue" w={16} disabled={saveInProgress}>
               {saveInProgress ? <Spinner boxSize={4} /> : <Text>Save</Text>}
             </Button>
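The focus-scoped `addCommand` registration above works around a known Monaco behavior where commands registered via `addCommand` end up bound to whichever editor registered them last. A standalone sketch of the same pattern, detached from the app's state store (names here are illustrative, not code from this PR):

import * as monaco from "monaco-editor";

// Sketch: re-register editor commands whenever this editor gains focus, so
// Cmd/Ctrl+S and Cmd/Ctrl+Shift+F act on the editor the user is typing in.
export function bindEditorCommands(
  editor: monaco.editor.IStandaloneCodeEditor,
  onSave: () => void,
  toggleFullscreen: () => void,
) {
  editor.onDidFocusEditorText(() => {
    editor.addCommand(monaco.KeyMod.CtrlCmd | monaco.KeyCode.KeyS, onSave);
    editor.addCommand(
      monaco.KeyMod.CtrlCmd | monaco.KeyMod.Shift | monaco.KeyCode.KeyF,
      toggleFullscreen,
    );
  });
}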
@@ -21,17 +21,14 @@ export default function VariantStats(props: { variant: PromptVariant }) {
         completionTokens: 0,
         scenarioCount: 0,
         outputCount: 0,
-        awaitingRetrievals: false,
+        awaitingEvals: false,
       },
       refetchInterval,
     },
   );
 
   // Poll every two seconds while we are waiting for LLM retrievals to finish
-  useEffect(
-    () => setRefetchInterval(data.awaitingRetrievals ? 2000 : 0),
-    [data.awaitingRetrievals],
-  );
+  useEffect(() => setRefetchInterval(data.awaitingEvals ? 5000 : 0), [data.awaitingEvals]);
 
   const [passColor, neutralColor, failColor] = useToken("colors", [
     "green.500",
@@ -69,7 +66,7 @@ export default function VariantStats(props: { variant: PromptVariant }) {
         );
       })}
       </HStack>
-      {data.overallCost && !data.awaitingRetrievals && (
+      {data.overallCost && (
         <CostTooltip
           promptTokens={data.promptTokens}
           completionTokens={data.completionTokens}
@@ -7,30 +7,50 @@ import VariantHeader from "../VariantHeader/VariantHeader";
 import VariantStats from "./VariantStats";
 import { ScenariosHeader } from "./ScenariosHeader";
 import { borders } from "./styles";
+import { useScenarios } from "~/utils/hooks";
+import ScenarioPaginator from "./ScenarioPaginator";
+import { Fragment, useEffect, useState } from "react";
 
-export default function OutputsTable({ experimentId }: { experimentId: string | undefined }) {
+export default function OutputsTable({
+  experimentId,
+  func,
+}: {
+  experimentId: string | undefined;
+  func: () => void;
+}) {
   const variants = api.promptVariants.list.useQuery(
     { experimentId: experimentId as string },
     { enabled: !!experimentId },
   );
 
-  const scenarios = api.scenarios.list.useQuery(
-    { experimentId: experimentId as string },
-    { enabled: !!experimentId },
-  );
+  const scenarios = useScenarios();
+  const [newFunc, setNewFunc] = useState<() => void | null>();
+
+  useEffect(() => {
+    console.log('func', func)
+    if (func) {
+      setNewFunc(prev => {
+        console.log('Setting newFunc from', prev, 'to', func);
+        return func;
+      });
+    }
+  }, [func]);
 
   if (!variants.data || !scenarios.data) return null;
 
   const allCols = variants.data.length + 2;
   const variantHeaderRows = 3;
   const scenarioHeaderRows = 1;
-  const allRows = variantHeaderRows + scenarioHeaderRows + scenarios.data.length;
+  const scenarioFooterRows = 1;
+  const visibleScenariosCount = scenarios.data.scenarios.length;
+  const allRows =
+    variantHeaderRows + scenarioHeaderRows + visibleScenariosCount + scenarioFooterRows;
 
   return (
     <Grid
       pt={4}
       pb={24}
-      pl={4}
+      pl={8}
       display="grid"
       gridTemplateColumns={`250px repeat(${variants.data.length}, minmax(300px, 1fr)) auto`}
       sx={{
@@ -43,17 +63,18 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
       <GridItem rowSpan={variantHeaderRows}>
         <AddVariantButton />
       </GridItem>
+      {newFunc && newFunc.toString()}
       {variants.data.map((variant, i) => {
         const sharedProps: GridItemProps = {
           ...borders,
           colStart: i + 2,
           borderLeftWidth: i === 0 ? 1 : 0,
+          marginLeft: i === 0 ? "-1px" : 0,
+          backgroundColor: "gray.100",
         };
         return (
-          <>
+          <Fragment key={variant.uiId}>
             <VariantHeader
-              key={variant.uiId}
               variant={variant}
               canHide={variants.data.length > 1}
               rowStart={1}
@@ -65,7 +86,7 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
             <GridItem rowStart={3} {...sharedProps}>
               <VariantStats variant={variant} />
             </GridItem>
-          </>
+          </Fragment>
         );
       })}
 
@@ -76,18 +97,25 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
         {...borders}
         borderRightWidth={0}
       >
-        <ScenariosHeader numScenarios={scenarios.data.length} />
+        <ScenariosHeader />
       </GridItem>
 
-      {scenarios.data.map((scenario, i) => (
+      {scenarios.data.scenarios.map((scenario, i) => (
         <ScenarioRow
           rowStart={i + variantHeaderRows + scenarioHeaderRows + 2}
           key={scenario.uiId}
           scenario={scenario}
           variants={variants.data}
-          canHide={scenarios.data.length > 1}
+          canHide={visibleScenariosCount > 1}
         />
       ))}
+      <GridItem
+        rowStart={variantHeaderRows + scenarioHeaderRows + visibleScenariosCount + 2}
+        colStart={1}
+        colSpan={allCols}
+      >
+        <ScenarioPaginator />
+      </GridItem>
 
       {/* Add some extra padding on the right, because when the table is too wide to fit in the viewport `pr` on the Grid isn't respected. */}
       <GridItem rowStart={1} colStart={allCols} rowSpan={allRows} w={4} borderBottomWidth={0} />
@@ -1,11 +1,4 @@
-import { type GridItemProps, type SystemStyleObject } from "@chakra-ui/react";
+import { type GridItemProps } from "@chakra-ui/react";
 
-export const stickyHeaderStyle: SystemStyleObject = {
-  position: "sticky",
-  top: "0",
-  backgroundColor: "#fff",
-  zIndex: 10,
-};
-
 export const borders: GridItemProps = {
   borderRightWidth: 1,
@@ -2,4 +2,4 @@ import { type RouterOutputs } from "~/utils/api";
 
 export type PromptVariant = NonNullable<RouterOutputs["promptVariants"]["list"]>[0];
 
-export type Scenario = NonNullable<RouterOutputs["scenarios"]["list"]>[0];
+export type Scenario = NonNullable<RouterOutputs["scenarios"]["list"]>["scenarios"][0];
@@ -1,7 +1,8 @@
 import { HStack, Icon, Heading, Text, VStack, GridItem } from "@chakra-ui/react";
 import { type IconType } from "react-icons";
+import { BsStars } from "react-icons/bs";
 
-export const RefineOption = ({
+export const RefineAction = ({
   label,
   icon,
   desciption,
@@ -10,7 +11,7 @@ export const RefineOption = ({
   loading,
 }: {
   label: string;
-  icon: IconType;
+  icon?: IconType;
   desciption: string;
   activeLabel: string | undefined;
   onClick: (label: string) => void;
@@ -44,7 +45,7 @@ export const RefineOption = ({
       opacity={loading ? 0.5 : 1}
     >
       <HStack cursor="pointer" spacing={6} fontSize="sm" fontWeight="medium" color="gray.500">
-        <Icon as={icon} boxSize={12} />
+        <Icon as={icon || BsStars} boxSize={12} />
         <Heading size="md" fontFamily="inconsolata, monospace">
           {label}
         </Heading>
@@ -16,15 +16,15 @@
 } from "@chakra-ui/react";
 import { BsStars } from "react-icons/bs";
 import { api } from "~/utils/api";
-import { useHandledAsyncCallback } from "~/utils/hooks";
+import { useHandledAsyncCallback, useVisibleScenarioIds } from "~/utils/hooks";
 import { type PromptVariant } from "@prisma/client";
 import { useState } from "react";
 import CompareFunctions from "./CompareFunctions";
 import { CustomInstructionsInput } from "./CustomInstructionsInput";
-import { type RefineOptionInfo, refineOptions } from "./refineOptions";
-import { RefineOption } from "./RefineOption";
+import { RefineAction } from "./RefineAction";
 import { isObject, isString } from "lodash-es";
-import { type SupportedProvider } from "~/modelProviders/types";
+import { type RefinementAction, type SupportedProvider } from "~/modelProviders/types";
+import frontendModelProviders from "~/modelProviders/frontendModelProviders";
 
 export const RefinePromptModal = ({
   variant,
@@ -34,14 +34,16 @@ export const RefinePromptModal = ({
   onClose: () => void;
 }) => {
   const utils = api.useContext();
+  const visibleScenarios = useVisibleScenarioIds();
 
-  const providerRefineOptions = refineOptions[variant.modelProvider as SupportedProvider];
+  const refinementActions =
+    frontendModelProviders[variant.modelProvider as SupportedProvider].refinementActions || {};
 
   const { mutateAsync: getModifiedPromptMutateAsync, data: refinedPromptFn } =
     api.promptVariants.getModifiedPromptFn.useMutation();
   const [instructions, setInstructions] = useState<string>("");
 
-  const [activeRefineOptionLabel, setActiveRefineOptionLabel] = useState<string | undefined>(
+  const [activeRefineActionLabel, setActiveRefineActionLabel] = useState<string | undefined>(
     undefined,
   );
 
@@ -49,15 +51,15 @@ export const RefinePromptModal = ({
     async (label?: string) => {
       if (!variant.experimentId) return;
       const updatedInstructions = label
-        ? (providerRefineOptions[label] as RefineOptionInfo).instructions
+        ? (refinementActions[label] as RefinementAction).instructions
         : instructions;
-      setActiveRefineOptionLabel(label);
+      setActiveRefineActionLabel(label);
       await getModifiedPromptMutateAsync({
         id: variant.id,
         instructions: updatedInstructions,
       });
     },
-    [getModifiedPromptMutateAsync, onClose, variant, instructions, setActiveRefineOptionLabel],
+    [getModifiedPromptMutateAsync, onClose, variant, instructions, setActiveRefineActionLabel],
   );
 
   const replaceVariantMutation = api.promptVariants.replaceVariant.useMutation();
@@ -72,6 +74,7 @@ export const RefinePromptModal = ({
       await replaceVariantMutation.mutateAsync({
         id: variant.id,
         constructFn: refinedPromptFn,
+        streamScenarios: visibleScenarios,
       });
       await utils.promptVariants.list.invalidate();
       onClose();
@@ -95,18 +98,18 @@ export const RefinePromptModal = ({
         <ModalBody maxW="unset">
           <VStack spacing={8}>
             <VStack spacing={4}>
-              {Object.keys(providerRefineOptions).length && (
+              {Object.keys(refinementActions).length && (
                 <>
                   <SimpleGrid columns={{ base: 1, md: 2 }} spacing={8}>
-                    {Object.keys(providerRefineOptions).map((label) => (
-                      <RefineOption
+                    {Object.keys(refinementActions).map((label) => (
+                      <RefineAction
                         key={label}
                         label={label}
                         // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
-                        icon={providerRefineOptions[label]!.icon}
+                        icon={refinementActions[label]!.icon}
                         // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
-                        desciption={providerRefineOptions[label]!.description}
-                        activeLabel={activeRefineOptionLabel}
+                        desciption={refinementActions[label]!.description}
+                        activeLabel={activeRefineActionLabel}
                         onClick={getModifiedPromptFn}
                         loading={modificationInProgress}
                       />
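`RefinementAction` is imported from `~/modelProviders/types`, which is not included in this diff. Judging from the usages above, its shape is presumably close to the following (a guess for orientation only, not the actual definition):

import { type IconType } from "react-icons";

// Presumed shape, inferred from the usages above; the real definition lives
// in ~/modelProviders/types and may differ.
export type RefinementAction = {
  icon?: IconType; // optional: RefineAction falls back to BsStars when omitted
  description: string;
  instructions: string; // prompt text passed to getModifiedPromptFn
};

export type RefinementActions = Record<string, RefinementAction>;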
@@ -1,287 +0,0 @@
|
|||||||
// Super hacky, but we'll redo the organization when we have more models
|
|
||||||
|
|
||||||
import { type SupportedProvider } from "~/modelProviders/types";
|
|
||||||
import { VscJson } from "react-icons/vsc";
|
|
||||||
import { TfiThought } from "react-icons/tfi";
|
|
||||||
import { type IconType } from "react-icons";
|
|
||||||
|
|
||||||
export type RefineOptionInfo = { icon: IconType; description: string; instructions: string };
|
|
||||||
|
|
||||||
export const refineOptions: Record<SupportedProvider, { [key: string]: RefineOptionInfo }> = {
|
|
||||||
"openai/ChatCompletion": {
|
|
||||||
"Add chain of thought": {
|
|
||||||
icon: VscJson,
|
|
||||||
description: "Asking the model to plan its answer can increase accuracy.",
|
|
||||||
instructions: `Adding chain of thought means asking the model to think about its answer before it gives it to you. This is useful for getting more accurate answers. Do not add an assistant message.
|
|
||||||
|
|
||||||
This is what a prompt looks like before adding chain of thought:
|
|
||||||
|
|
||||||
definePrompt("openai/ChatCompletion", {
|
|
||||||
model: "gpt-4",
|
|
||||||
stream: true,
|
|
||||||
messages: [
|
|
||||||
{
|
|
||||||
role: "system",
|
|
||||||
content: \`Evaluate sentiment.\`,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
role: "user",
|
|
||||||
content: \`This is the user's message: \${scenario.user_message}. Return "positive" or "negative" or "neutral"\`,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
});
|
|
||||||
|
|
||||||
This is what one looks like after adding chain of thought:
|
|
||||||
|
|
||||||
definePrompt("openai/ChatCompletion", {
|
|
||||||
model: "gpt-4",
|
|
||||||
stream: true,
|
|
||||||
messages: [
|
|
||||||
{
|
|
||||||
role: "system",
|
|
||||||
content: \`Evaluate sentiment.\`,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
role: "user",
|
|
||||||
content: \`This is the user's message: \${scenario.user_message}. Return "positive" or "negative" or "neutral". Explain your answer before you give a score, then return the score on a new line.\`,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
});
|
|
||||||
|
|
||||||
Here's another example:
|
|
||||||
|
|
||||||
Before:
|
|
||||||
|
|
||||||
definePrompt("openai/ChatCompletion", {
|
|
||||||
model: "gpt-3.5-turbo",
|
|
||||||
messages: [
|
|
||||||
{
|
|
||||||
role: "user",
|
|
||||||
content: \`Title: \${scenario.title}
|
|
||||||
Body: \${scenario.body}
|
|
||||||
|
|
||||||
Need: \${scenario.need}
|
|
||||||
|
|
||||||
Rate likelihood on 1-3 scale.\`,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
temperature: 0,
|
|
||||||
functions: [
|
|
||||||
{
|
|
||||||
name: "score_post",
|
|
||||||
parameters: {
|
|
||||||
type: "object",
|
|
||||||
properties: {
|
|
||||||
score: {
|
|
||||||
type: "number",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
],
|
|
||||||
function_call: {
|
|
||||||
name: "score_post",
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
After:
|
|
||||||
|
|
||||||
definePrompt("openai/ChatCompletion", {
|
|
||||||
model: "gpt-3.5-turbo",
|
|
||||||
messages: [
|
|
||||||
{
|
|
||||||
role: "user",
|
|
||||||
content: \`Title: \${scenario.title}
|
|
||||||
Body: \${scenario.body}
|
|
||||||
|
|
||||||
Need: \${scenario.need}
|
|
||||||
|
|
||||||
Rate likelihood on 1-3 scale. Provide an explanation, but always provide a score afterward.\`,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
temperature: 0,
|
|
||||||
functions: [
|
|
||||||
{
|
|
||||||
name: "score_post",
|
|
||||||
parameters: {
|
|
||||||
type: "object",
|
|
||||||
properties: {
|
|
||||||
explanation: {
|
|
||||||
type: "string",
|
|
||||||
}
|
|
||||||
score: {
|
|
||||||
type: "number",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
],
|
|
||||||
function_call: {
|
|
||||||
name: "score_post",
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
Add chain of thought to the original prompt.`,
|
|
||||||
},
|
|
||||||
"Convert to function call": {
|
|
||||||
icon: TfiThought,
|
|
||||||
description: "Use function calls to get output from the model in a more structured way.",
|
|
||||||
instructions: `OpenAI functions are a specialized way for an LLM to return output.
|
|
||||||
|
|
||||||
This is what a prompt looks like before adding a function:
|
|
||||||
|
|
||||||
definePrompt("openai/ChatCompletion", {
|
|
||||||
model: "gpt-4",
|
|
||||||
stream: true,
|
|
||||||
messages: [
|
|
||||||
{
|
|
||||||
role: "system",
|
|
||||||
content: \`Evaluate sentiment.\`,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
role: "user",
|
|
||||||
content: \`This is the user's message: \${scenario.user_message}. Return "positive" or "negative" or "neutral"\`,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
});
|
|
||||||
|
|
||||||
This is what one looks like after adding a function:
|
|
||||||
|
|
||||||
definePrompt("openai/ChatCompletion", {
|
|
||||||
model: "gpt-4",
|
|
||||||
stream: true,
|
|
||||||
messages: [
|
|
||||||
{
|
|
||||||
role: "system",
|
|
||||||
content: "Evaluate sentiment.",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
role: "user",
|
|
||||||
content: scenario.user_message,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
functions: [
|
|
||||||
{
|
|
||||||
name: "extract_sentiment",
|
|
||||||
parameters: {
|
|
||||||
type: "object", // parameters must always be an object with a properties key
|
|
||||||
properties: { // properties key is required
|
|
||||||
sentiment: {
|
|
||||||
type: "string",
|
|
||||||
description: "one of positive/negative/neutral",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
],
|
|
||||||
function_call: {
|
|
||||||
name: "extract_sentiment",
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
Here's another example of adding a function:
|
|
||||||
|
|
||||||
Before:
|
|
||||||
|
|
||||||
definePrompt("openai/ChatCompletion", {
|
|
||||||
model: "gpt-3.5-turbo",
|
|
||||||
messages: [
|
|
||||||
{
|
|
||||||
role: "user",
|
|
||||||
content: \`Here is the title and body of a reddit post I am interested in:
|
|
||||||
|
|
||||||
title: \${scenario.title}
|
|
||||||
body: \${scenario.body}
|
|
||||||
|
|
||||||
On a scale from 1 to 3, how likely is it that the person writing this post has the following need? If you are not sure, make your best guess, or answer 1.
|
|
||||||
|
|
||||||
Need: \${scenario.need}
|
|
||||||
|
|
||||||
Answer one integer between 1 and 3.\`,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
temperature: 0,
|
|
||||||
});
|
|
||||||
|
|
||||||
After:
|
|
||||||
|
|
||||||
definePrompt("openai/ChatCompletion", {
|
|
||||||
model: "gpt-3.5-turbo",
|
|
||||||
messages: [
|
|
||||||
{
|
|
||||||
role: "user",
|
|
||||||
content: \`Title: \${scenario.title}
|
|
||||||
Body: \${scenario.body}
|
|
||||||
|
|
||||||
Need: \${scenario.need}
|
|
||||||
|
|
||||||
Rate likelihood on 1-3 scale.\`,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
temperature: 0,
|
|
||||||
functions: [
|
|
||||||
{
|
|
||||||
name: "score_post",
|
|
||||||
parameters: {
|
|
||||||
type: "object",
|
|
||||||
properties: {
|
|
||||||
score: {
|
|
||||||
type: "number",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
],
|
|
||||||
function_call: {
|
|
||||||
name: "score_post",
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
Another example
|
|
||||||
|
|
||||||
Before:
|
|
||||||
|
|
||||||
definePrompt("openai/ChatCompletion", {
|
|
||||||
model: "gpt-3.5-turbo",
|
|
||||||
stream: true,
|
|
||||||
messages: [
|
|
||||||
{
|
|
||||||
role: "system",
|
|
||||||
content: \`Write 'Start experimenting!' in \${scenario.language}\`,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
});
|
|
||||||
|
|
||||||
After:
|
|
||||||
|
|
||||||
definePrompt("openai/ChatCompletion", {
|
|
||||||
model: "gpt-3.5-turbo",
|
|
||||||
messages: [
|
|
||||||
{
|
|
||||||
role: "system",
|
|
||||||
content: \`Write 'Start experimenting!' in \${scenario.language}\`,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
functions: [
|
|
||||||
{
|
|
||||||
name: "write_in_language",
|
|
||||||
parameters: {
|
|
||||||
type: "object",
|
|
||||||
properties: {
|
|
||||||
text: {
|
|
||||||
type: "string",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
],
|
|
||||||
function_call: {
|
|
||||||
name: "write_in_language",
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
Add an OpenAI function that takes one or more nested parameters that match the expected output from this prompt.`,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"replicate/llama2": {},
|
|
||||||
};
|
|
||||||
@@ -6,7 +6,6 @@ import { useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
 import { HStack, Icon, Text, GridItem, type GridItemProps } from "@chakra-ui/react"; // Changed here
 import { cellPadding, headerMinHeight } from "../constants";
 import AutoResizeTextArea from "../AutoResizeTextArea";
-import { stickyHeaderStyle } from "../OutputsTable/styles";
 import VariantHeaderMenuButton from "./VariantHeaderMenuButton";
 
 export default function VariantHeader(
@@ -53,7 +52,17 @@ export default function VariantHeader(
 
   if (!canModify) {
     return (
-      <GridItem padding={0} sx={stickyHeaderStyle} borderTopWidth={1} {...gridItemProps}>
+      <GridItem
+        padding={0}
+        sx={{
+          position: "sticky",
+          top: "0",
+          // Ensure that the menu always appears above the sticky header of other variants
+          zIndex: menuOpen ? "dropdown" : 10,
+        }}
+        borderTopWidth={1}
+        {...gridItemProps}
+      >
         <Text fontSize={16} fontWeight="bold" px={cellPadding.x} py={cellPadding.y}>
           {variant.label}
         </Text>
@@ -65,15 +74,16 @@ export default function VariantHeader(
     <GridItem
       padding={0}
       sx={{
-        ...stickyHeaderStyle,
+        position: "sticky",
+        top: "0",
         // Ensure that the menu always appears above the sticky header of other variants
-        zIndex: menuOpen ? "dropdown" : stickyHeaderStyle.zIndex,
+        zIndex: menuOpen ? "dropdown" : 10,
       }}
       borderTopWidth={1}
       {...gridItemProps}
     >
       <HStack
-        spacing={4}
+        spacing={2}
         alignItems="flex-start"
         minH={headerMinHeight}
         draggable={!isInputHovered}
@@ -92,7 +102,8 @@ export default function VariantHeader(
           setIsDragTarget(false);
         }}
         onDrop={onReorder}
-        backgroundColor={isDragTarget ? "gray.100" : "transparent"}
+        backgroundColor={isDragTarget ? "gray.200" : "gray.100"}
+        h="full"
       >
         <Icon
           as={RiDraggable}
@@ -1,8 +1,7 @@
 import { type PromptVariant } from "../OutputsTable/types";
 import { api } from "~/utils/api";
-import { useHandledAsyncCallback } from "~/utils/hooks";
+import { useHandledAsyncCallback, useVisibleScenarioIds } from "~/utils/hooks";
 import {
-  Button,
   Icon,
   Menu,
   MenuButton,
@@ -11,6 +10,7 @@ import {
   MenuDivider,
   Text,
   Spinner,
+  IconButton,
 } from "@chakra-ui/react";
 import { BsFillTrashFill, BsGear, BsStars } from "react-icons/bs";
 import { FaRegClone } from "react-icons/fa";
@@ -33,11 +33,13 @@ export default function VariantHeaderMenuButton({
   const utils = api.useContext();
 
   const duplicateMutation = api.promptVariants.create.useMutation();
+  const visibleScenarios = useVisibleScenarioIds();
 
   const [duplicateVariant, duplicationInProgress] = useHandledAsyncCallback(async () => {
     await duplicateMutation.mutateAsync({
       experimentId: variant.experimentId,
       variantId: variant.id,
+      streamScenarios: visibleScenarios,
     });
     await utils.promptVariants.list.invalidate();
   }, [duplicateMutation, variant.experimentId, variant.id]);
@@ -56,15 +58,12 @@ export default function VariantHeaderMenuButton({
   return (
     <>
       <Menu isOpen={menuOpen} onOpen={() => setMenuOpen(true)} onClose={() => setMenuOpen(false)}>
-        {duplicationInProgress ? (
-          <Spinner boxSize={4} mx={3} my={3} />
-        ) : (
-          <MenuButton>
-            <Button variant="ghost">
-              <Icon as={BsGear} />
-            </Button>
-          </MenuButton>
-        )}
+        <MenuButton
+          as={IconButton}
+          variant="ghost"
+          aria-label="Edit Scenarios"
+          icon={<Icon as={duplicationInProgress ? Spinner : BsGear} />}
+        />
 
         <MenuList mt={-3} fontSize="md">
           <MenuItem icon={<Icon as={FaRegClone} boxSize={4} w={5} />} onClick={duplicateVariant}>
@@ -1,4 +1,13 @@
-import { HStack, Icon, VStack, Text, Divider, Spinner, AspectRatio } from "@chakra-ui/react";
+import {
+  HStack,
+  Icon,
+  VStack,
+  Text,
+  Divider,
+  Spinner,
+  AspectRatio,
+  SkeletonText,
+} from "@chakra-ui/react";
 import { RiFlaskLine } from "react-icons/ri";
 import { formatTimePast } from "~/utils/dayjs";
 import Link from "next/link";
@@ -93,3 +102,13 @@ export const NewExperimentCard = () => {
     </AspectRatio>
   );
 };
+
+export const ExperimentCardSkeleton = () => (
+  <AspectRatio ratio={1.2} w="full">
+    <VStack align="center" borderColor="gray.200" borderWidth={1} p={4} bg="gray.50">
+      <SkeletonText noOfLines={1} w="80%" />
+      <SkeletonText noOfLines={2} w="60%" />
+      <SkeletonText noOfLines={1} w="80%" />
+    </VStack>
+  </AspectRatio>
+);
57 src/components/experiments/HeaderButtons/DeleteDialog.tsx Normal file
@@ -0,0 +1,57 @@
+import {
+  Button,
+  AlertDialog,
+  AlertDialogBody,
+  AlertDialogFooter,
+  AlertDialogHeader,
+  AlertDialogContent,
+  AlertDialogOverlay,
+} from "@chakra-ui/react";
+
+import { useRouter } from "next/router";
+import { useRef } from "react";
+import { api } from "~/utils/api";
+import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
+
+export const DeleteDialog = ({ onClose }: { onClose: () => void }) => {
+  const experiment = useExperiment();
+  const deleteMutation = api.experiments.delete.useMutation();
+  const utils = api.useContext();
+  const router = useRouter();
+
+  const cancelRef = useRef<HTMLButtonElement>(null);
+
+  const [onDeleteConfirm] = useHandledAsyncCallback(async () => {
+    if (!experiment.data?.id) return;
+    await deleteMutation.mutateAsync({ id: experiment.data.id });
+    await utils.experiments.list.invalidate();
+    await router.push({ pathname: "/experiments" });
+    onClose();
+  }, [deleteMutation, experiment.data?.id, router]);
+
+  return (
+    <AlertDialog isOpen leastDestructiveRef={cancelRef} onClose={onClose}>
+      <AlertDialogOverlay>
+        <AlertDialogContent>
+          <AlertDialogHeader fontSize="lg" fontWeight="bold">
+            Delete Experiment
+          </AlertDialogHeader>
+
+          <AlertDialogBody>
+            If you delete this experiment all the associated prompts and scenarios will be deleted
+            as well. Are you sure?
+          </AlertDialogBody>
+
+          <AlertDialogFooter>
+            <Button ref={cancelRef} onClick={onClose}>
+              Cancel
+            </Button>
+            <Button colorScheme="red" onClick={onDeleteConfirm} ml={3}>
+              Delete
+            </Button>
+          </AlertDialogFooter>
+        </AlertDialogContent>
+      </AlertDialogOverlay>
+    </AlertDialog>
+  );
+};
42 src/components/experiments/HeaderButtons/HeaderButtons.tsx Normal file
@@ -0,0 +1,42 @@
+import { Button, HStack, Icon, Spinner, Text } from "@chakra-ui/react";
+import { useOnForkButtonPressed } from "./useOnForkButtonPressed";
+import { useExperiment } from "~/utils/hooks";
+import { BsGearFill } from "react-icons/bs";
+import { TbGitFork } from "react-icons/tb";
+import { useAppStore } from "~/state/store";
+
+export const HeaderButtons = () => {
+  const experiment = useExperiment();
+
+  const canModify = experiment.data?.access.canModify ?? false;
+
+  const { onForkButtonPressed, isForking } = useOnForkButtonPressed();
+
+  const openDrawer = useAppStore((s) => s.openDrawer);
+
+  if (experiment.isLoading) return null;
+
+  return (
+    <HStack spacing={0} mt={{ base: 2, md: 0 }}>
+      <Button
+        onClick={onForkButtonPressed}
+        mr={4}
+        colorScheme={canModify ? undefined : "orange"}
+        bgColor={canModify ? undefined : "orange.400"}
+        minW={0}
+        variant={{ base: "solid", md: canModify ? "ghost" : "solid" }}
+      >
+        {isForking ? <Spinner boxSize={5} /> : <Icon as={TbGitFork} boxSize={5} />}
+        <Text ml={2}>Fork</Text>
+      </Button>
+      {canModify && (
+        <Button variant={{ base: "solid", md: "ghost" }} onClick={openDrawer}>
+          <HStack>
+            <Icon as={BsGearFill} />
+            <Text>Settings</Text>
+          </HStack>
+        </Button>
+      )}
+    </HStack>
+  );
+};
@@ -0,0 +1,30 @@
+import { useCallback } from "react";
+import { api } from "~/utils/api";
+import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
+import { signIn, useSession } from "next-auth/react";
+import { useRouter } from "next/router";
+
+export const useOnForkButtonPressed = () => {
+  const router = useRouter();
+
+  const user = useSession().data;
+  const experiment = useExperiment();
+
+  const forkMutation = api.experiments.fork.useMutation();
+
+  const [onFork, isForking] = useHandledAsyncCallback(async () => {
+    if (!experiment.data?.id) return;
+    const forkedExperimentId = await forkMutation.mutateAsync({ id: experiment.data.id });
+    await router.push({ pathname: "/experiments/[id]", query: { id: forkedExperimentId } });
+  }, [forkMutation, experiment.data?.id, router]);
+
+  const onForkButtonPressed = useCallback(() => {
+    if (user === null) {
+      signIn("github").catch(console.error);
+    } else {
+      onFork();
+    }
+  }, [onFork, user]);
+
+  return { onForkButtonPressed, isForking };
+};
@@ -18,6 +18,7 @@ export const env = createEnv({
     GITHUB_CLIENT_SECRET: z.string().min(1),
     OPENAI_API_KEY: z.string().min(1),
     REPLICATE_API_TOKEN: z.string().default("placeholder"),
+    ANTHROPIC_API_KEY: z.string().default("placeholder"),
   },
 
   /**
@@ -44,6 +45,7 @@ export const env = createEnv({
     GITHUB_CLIENT_ID: process.env.GITHUB_CLIENT_ID,
     GITHUB_CLIENT_SECRET: process.env.GITHUB_CLIENT_SECRET,
     REPLICATE_API_TOKEN: process.env.REPLICATE_API_TOKEN,
+    ANTHROPIC_API_KEY: process.env.ANTHROPIC_API_KEY,
   },
   /**
    * Run `build` or `dev` with `SKIP_ENV_VALIDATION` to skip env validation.
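With `ANTHROPIC_API_KEY` validated here, provider code can construct a client from it. A hedged sketch of that wiring (it assumes the `@anthropic-ai/sdk` package; the actual Anthropic provider module is not shown in this diff):

import Anthropic from "@anthropic-ai/sdk";
import { env } from "~/env.mjs";

// Sketch: build a client from the validated env var. The real provider code
// under src/modelProviders/anthropic may configure this differently.
export const anthropic = new Anthropic({ apiKey: env.ANTHROPIC_API_KEY });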
new file: src/modelProviders/anthropic/codegen/codegen.ts (69 lines)
@@ -0,0 +1,69 @@
+/* eslint-disable @typescript-eslint/no-var-requires */
+
+import YAML from "yaml";
+import fs from "fs";
+import path from "path";
+import { openapiSchemaToJsonSchema } from "@openapi-contrib/openapi-schema-to-json-schema";
+import $RefParser from "@apidevtools/json-schema-ref-parser";
+import { type JSONObject } from "superjson/dist/types";
+import assert from "assert";
+import { type JSONSchema4Object } from "json-schema";
+import { isObject } from "lodash-es";
+
+// @ts-expect-error for some reason missing from types
+import parserEstree from "prettier/plugins/estree";
+import parserBabel from "prettier/plugins/babel";
+import prettier from "prettier/standalone";
+
+const OPENAPI_URL =
+  "https://raw.githubusercontent.com/tryAGI/Anthropic/1c0871e861de60a4c3a843cb90e17d63e86c234a/docs/openapi.yaml";
+
+// Fetch the openapi document
+const response = await fetch(OPENAPI_URL);
+const openApiYaml = await response.text();
+
+// Parse the yaml document
+let schema = YAML.parse(openApiYaml) as JSONObject;
+schema = openapiSchemaToJsonSchema(schema);
+
+const jsonSchema = await $RefParser.dereference(schema);
+
+assert("components" in jsonSchema);
+const completionRequestSchema = jsonSchema.components.schemas
+  .CreateCompletionRequest as JSONSchema4Object;
+
+// We need to do a bit of surgery here since the Monaco editor doesn't like
+// the fact that the schema says `model` can be either a string or an enum,
+// and displays a warning in the editor. Let's stick with just an enum for
+// now and drop the string option.
+assert(
+  "properties" in completionRequestSchema &&
+    isObject(completionRequestSchema.properties) &&
+    "model" in completionRequestSchema.properties &&
+    isObject(completionRequestSchema.properties.model),
+);
+
+const modelProperty = completionRequestSchema.properties.model;
+assert(
+  "oneOf" in modelProperty &&
+    Array.isArray(modelProperty.oneOf) &&
+    modelProperty.oneOf.length === 2 &&
+    isObject(modelProperty.oneOf[1]) &&
+    "enum" in modelProperty.oneOf[1],
+  "Expected model to have oneOf length of 2",
+);
+modelProperty.type = "string";
+modelProperty.enum = modelProperty.oneOf[1].enum;
+delete modelProperty["oneOf"];
+
+// Get the directory of the current script
+const currentDirectory = path.dirname(import.meta.url).replace("file://", "");
+
+// Write the JSON schema to a file in the current directory
+fs.writeFileSync(
+  path.join(currentDirectory, "input.schema.json"),
+  await prettier.format(JSON.stringify(completionRequestSchema, null, 2), {
+    parser: "json",
+    plugins: [parserBabel, parserEstree],
+  }),
+);
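The "surgery" step above collapses the `model` property's `oneOf` (a plain string branch plus an enum branch) into a single enum-typed property so Monaco stops warning. A minimal sketch of the same transformation on a standalone object, under the assumption that the input follows the two-branch shape the script asserts (this snippet is illustrative and not part of the commit):

```ts
import assert from "assert";

// Shape before: { oneOf: [{ type: "string" }, { type: "string", enum: [...] }] }
const modelProperty: Record<string, unknown> = {
  oneOf: [{ type: "string" }, { type: "string", enum: ["claude-2", "claude-instant-1"] }],
};

assert(Array.isArray(modelProperty.oneOf) && modelProperty.oneOf.length === 2);
const enumBranch = modelProperty.oneOf[1] as { enum: string[] };

// Shape after: { type: "string", enum: [...] } — the form the editor accepts without warnings.
modelProperty.type = "string";
modelProperty.enum = enumBranch.enum;
delete modelProperty.oneOf;
```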
new file: src/modelProviders/anthropic/codegen/input.schema.json (129 lines)
@@ -0,0 +1,129 @@
+{
+  "type": "object",
+  "properties": {
+    "model": {
+      "description": "The model that will complete your prompt.\nAs we improve Claude, we develop new versions of it that you can query.\nThis parameter controls which version of Claude answers your request.\nRight now we are offering two model families: Claude, and Claude Instant.\nYou can use them by setting model to \"claude-2\" or \"claude-instant-1\", respectively.\nSee models for additional details.\n",
+      "x-oaiTypeLabel": "string",
+      "type": "string",
+      "enum": [
+        "claude-2",
+        "claude-2.0",
+        "claude-instant-1",
+        "claude-instant-1.1"
+      ]
+    },
+    "prompt": {
+      "description": "The prompt that you want Claude to complete.\n\nFor proper response generation you will need to format your prompt as follows:\n\\n\\nHuman: ${userQuestion}\\n\\nAssistant:\nSee our comments on prompts for more context.\n",
+      "default": "<|endoftext|>",
+      "nullable": true,
+      "oneOf": [
+        {
+          "type": "string",
+          "default": "",
+          "example": "This is a test."
+        },
+        {
+          "type": "array",
+          "items": {
+            "type": "string",
+            "default": "",
+            "example": "This is a test."
+          }
+        },
+        {
+          "type": "array",
+          "minItems": 1,
+          "items": {
+            "type": "integer"
+          },
+          "example": "[1212, 318, 257, 1332, 13]"
+        },
+        {
+          "type": "array",
+          "minItems": 1,
+          "items": {
+            "type": "array",
+            "minItems": 1,
+            "items": {
+              "type": "integer"
+            }
+          },
+          "example": "[[1212, 318, 257, 1332, 13]]"
+        }
+      ]
+    },
+    "max_tokens_to_sample": {
+      "type": "integer",
+      "minimum": 1,
+      "default": 256,
+      "example": 256,
+      "nullable": true,
+      "description": "The maximum number of tokens to generate before stopping.\n\nNote that our models may stop before reaching this maximum. This parameter only specifies the absolute maximum number of tokens to generate.\n"
+    },
+    "temperature": {
+      "type": "number",
+      "minimum": 0,
+      "maximum": 1,
+      "default": 1,
+      "example": 1,
+      "nullable": true,
+      "description": "Amount of randomness injected into the response.\n\nDefaults to 1. Ranges from 0 to 1. Use temp closer to 0 for analytical / multiple choice, and closer to 1 for creative and generative tasks.\n"
+    },
+    "top_p": {
+      "type": "number",
+      "minimum": 0,
+      "maximum": 1,
+      "default": 1,
+      "example": 1,
+      "nullable": true,
+      "description": "Use nucleus sampling.\n\nIn nucleus sampling, we compute the cumulative distribution over all the options \nfor each subsequent token in decreasing probability order and cut it off once \nit reaches a particular probability specified by top_p. You should either alter temperature or top_p, but not both.\n"
+    },
+    "top_k": {
+      "type": "number",
+      "minimum": 0,
+      "default": 5,
+      "example": 5,
+      "nullable": true,
+      "description": "Only sample from the top K options for each subsequent token.\n\nUsed to remove \"long tail\" low probability responses. Learn more technical details here.\n"
+    },
+    "stream": {
+      "description": "Whether to incrementally stream the response using server-sent events.\nSee this guide to SSE events for details.type: boolean\n",
+      "nullable": true,
+      "default": false
+    },
+    "stop_sequences": {
+      "description": "Sequences that will cause the model to stop generating completion text.\nOur models stop on \"\\n\\nHuman:\", and may include additional built-in stop sequences in the future. By providing the stop_sequences parameter, you may include additional strings that will cause the model to stop generating.\n",
+      "default": null,
+      "nullable": true,
+      "oneOf": [
+        {
+          "type": "string",
+          "default": "<|endoftext|>",
+          "example": "\n",
+          "nullable": true
+        },
+        {
+          "type": "array",
+          "minItems": 1,
+          "maxItems": 4,
+          "items": {
+            "type": "string",
+            "example": "[\"\\n\"]"
+          }
+        }
+      ]
+    },
+    "metadata": {
+      "type": "object",
+      "properties": {
+        "user_id": {
+          "type": "string",
+          "example": "13803d75-b4b5-4c3e-b2a2-6f21399b021b",
+          "description": "An external identifier for the user who is associated with the request.\n\nThis should be a uuid, hash value, or other opaque identifier. Anthropic may use this id to help detect abuse. \nDo not include any identifying information such as name, email address, or phone number.\n"
+        }
+      },
+      "description": "An object describing metadata about the request.\n"
+    }
+  },
+  "required": ["model", "prompt", "max_tokens_to_sample"]
+}
new file: src/modelProviders/anthropic/frontend.ts (40 lines)
@@ -0,0 +1,40 @@
+import { type Completion } from "@anthropic-ai/sdk/resources";
+import { type SupportedModel } from ".";
+import { type FrontendModelProvider } from "../types";
+import { refinementActions } from "./refinementActions";
+
+const frontendModelProvider: FrontendModelProvider<SupportedModel, Completion> = {
+  name: "Replicate Llama2",
+
+  models: {
+    "claude-2.0": {
+      name: "Claude 2.0",
+      contextWindow: 100000,
+      promptTokenPrice: 11.02 / 1000000,
+      completionTokenPrice: 32.68 / 1000000,
+      speed: "medium",
+      provider: "anthropic",
+      learnMoreUrl: "https://www.anthropic.com/product",
+    },
+    "claude-instant-1.1": {
+      name: "Claude Instant 1.1",
+      contextWindow: 100000,
+      promptTokenPrice: 1.63 / 1000000,
+      completionTokenPrice: 5.51 / 1000000,
+      speed: "fast",
+      provider: "anthropic",
+      learnMoreUrl: "https://www.anthropic.com/product",
+    },
+  },
+
+  refinementActions,
+
+  normalizeOutput: (output) => {
+    return {
+      type: "text",
+      value: output.completion,
+    };
+  },
+};
+
+export default frontendModelProvider;
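The prices above are expressed per single token. A minimal sketch of how a per-request cost could be derived from those numbers (the helper name and usage are hypothetical; the repository's actual cost calculation is not shown in this diff):

```ts
// Hypothetical helper, not part of the commit: derive request cost from the
// per-token prices listed above.
const claude2 = { promptTokenPrice: 11.02 / 1000000, completionTokenPrice: 32.68 / 1000000 };

function estimateCost(promptTokens: number, completionTokens: number): number {
  return promptTokens * claude2.promptTokenPrice + completionTokens * claude2.completionTokenPrice;
}

// e.g. a 1,000-token prompt with a 500-token completion:
console.log(estimateCost(1000, 500).toFixed(4)); // ≈ 0.0274 (USD)
```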
new file: src/modelProviders/anthropic/getCompletion.ts (86 lines)
@@ -0,0 +1,86 @@
+import { env } from "~/env.mjs";
+import { type CompletionResponse } from "../types";
+
+import Anthropic, { APIError } from "@anthropic-ai/sdk";
+import { type Completion, type CompletionCreateParams } from "@anthropic-ai/sdk/resources";
+import { isObject, isString } from "lodash-es";
+
+const anthropic = new Anthropic({
+  apiKey: env.ANTHROPIC_API_KEY,
+});
+
+export async function getCompletion(
+  input: CompletionCreateParams,
+  onStream: ((partialOutput: Completion) => void) | null,
+): Promise<CompletionResponse<Completion>> {
+  const start = Date.now();
+  let finalCompletion: Completion | null = null;
+
+  try {
+    if (onStream) {
+      const resp = await anthropic.completions.create(
+        { ...input, stream: true },
+        {
+          maxRetries: 0,
+        },
+      );
+
+      for await (const part of resp) {
+        if (finalCompletion === null) {
+          finalCompletion = part;
+        } else {
+          finalCompletion = { ...part, completion: finalCompletion.completion + part.completion };
+        }
+        onStream(finalCompletion);
+      }
+      if (!finalCompletion) {
+        return {
+          type: "error",
+          message: "Streaming failed to return a completion",
+          autoRetry: false,
+        };
+      }
+    } else {
+      const resp = await anthropic.completions.create(
+        { ...input, stream: false },
+        {
+          maxRetries: 0,
+        },
+      );
+      finalCompletion = resp;
+    }
+    const timeToComplete = Date.now() - start;
+
+    return {
+      type: "success",
+      statusCode: 200,
+      value: finalCompletion,
+      timeToComplete,
+    };
+  } catch (error: unknown) {
+    console.log("CAUGHT ERROR", error);
+    if (error instanceof APIError) {
+      const message =
+        isObject(error.error) &&
+        "error" in error.error &&
+        isObject(error.error.error) &&
+        "message" in error.error.error &&
+        isString(error.error.error.message)
+          ? error.error.error.message
+          : error.message;
+
+      return {
+        type: "error",
+        message,
+        autoRetry: error.status === 429 || error.status === 503,
+        statusCode: error.status,
+      };
+    } else {
+      return {
+        type: "error",
+        message: (error as Error).message,
+        autoRetry: true,
+      };
+    }
+  }
+}
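A sketch of how this `getCompletion` could be called with streaming enabled, assuming the prompt is formatted the way the schema above describes ("\n\nHuman: ...\n\nAssistant:"). The call site is illustrative only and not part of the commit:

```ts
// Illustrative call: stream a completion and log partial output as it arrives.
import { getCompletion } from "./getCompletion";

const response = await getCompletion(
  {
    model: "claude-instant-1.1",
    prompt: "\n\nHuman: Say hello in French.\n\nAssistant:",
    max_tokens_to_sample: 256,
  },
  (partial) => console.log(partial.completion),
);

if (response.type === "success") {
  console.log("final:", response.value.completion, `(${response.timeToComplete}ms)`);
} else {
  console.error("error:", response.message);
}
```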
new file: src/modelProviders/anthropic/index.ts (34 lines)
@@ -0,0 +1,34 @@
+import { type JSONSchema4 } from "json-schema";
+import { type ModelProvider } from "../types";
+import inputSchema from "./codegen/input.schema.json";
+import { getCompletion } from "./getCompletion";
+import frontendModelProvider from "./frontend";
+import type { Completion, CompletionCreateParams } from "@anthropic-ai/sdk/resources";
+
+const supportedModels = ["claude-2.0", "claude-instant-1.1"] as const;
+
+export type SupportedModel = (typeof supportedModels)[number];
+
+export type AnthropicProvider = ModelProvider<SupportedModel, CompletionCreateParams, Completion>;
+
+const modelProvider: AnthropicProvider = {
+  getModel: (input) => {
+    if (supportedModels.includes(input.model as SupportedModel))
+      return input.model as SupportedModel;
+
+    const modelMaps: Record<string, SupportedModel> = {
+      "claude-2": "claude-2.0",
+      "claude-instant-1": "claude-instant-1.1",
+    };
+
+    if (input.model in modelMaps) return modelMaps[input.model] as SupportedModel;
+
+    return null;
+  },
+  inputSchema: inputSchema as JSONSchema4,
+  canStream: true,
+  getCompletion,
+  ...frontendModelProvider,
+};
+
+export default modelProvider;
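`getModel` resolves the bare aliases accepted by the Anthropic API to the pinned versions this provider supports. A quick illustration of the mapping (input objects are trimmed to the required schema fields; not part of the commit):

```ts
// Illustration of the alias mapping above.
import modelProvider from ".";

modelProvider.getModel({ model: "claude-2", prompt: "", max_tokens_to_sample: 1 });
// => "claude-2.0"
modelProvider.getModel({ model: "claude-instant-1", prompt: "", max_tokens_to_sample: 1 });
// => "claude-instant-1.1"
modelProvider.getModel({ model: "gpt-4", prompt: "", max_tokens_to_sample: 1 });
// => null (not an Anthropic model)
```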
new file: src/modelProviders/anthropic/refinementActions.ts (3 lines)
@@ -0,0 +1,3 @@
+import { type RefinementAction } from "../types";
+
+export const refinementActions: Record<string, RefinementAction> = {};

@@ -1,15 +1,15 @@
 import openaiChatCompletionFrontend from "./openai-ChatCompletion/frontend";
 import replicateLlama2Frontend from "./replicate-llama2/frontend";
+import anthropicFrontend from "./anthropic/frontend";
 import { type SupportedProvider, type FrontendModelProvider } from "./types";

-// TODO: make sure we get a typescript error if you forget to add a provider here
-
 // Keep attributes here that need to be accessible from the frontend. We can't
 // just include them in the default `modelProviders` object because it has some
 // transient dependencies that can only be imported on the server.
 const frontendModelProviders: Record<SupportedProvider, FrontendModelProvider<any, any>> = {
   "openai/ChatCompletion": openaiChatCompletionFrontend,
   "replicate/llama2": replicateLlama2Frontend,
+  anthropic: anthropicFrontend,
 };

 export default frontendModelProviders;
@@ -1,10 +1,12 @@
 import openaiChatCompletion from "./openai-ChatCompletion";
 import replicateLlama2 from "./replicate-llama2";
+import anthropic from "./anthropic";
 import { type SupportedProvider, type ModelProvider } from "./types";

 const modelProviders: Record<SupportedProvider, ModelProvider<any, any, any>> = {
   "openai/ChatCompletion": openaiChatCompletion,
   "replicate/llama2": replicateLlama2,
+  anthropic,
 };

 export default modelProviders;
@@ -2,6 +2,7 @@ import { type JsonValue } from "type-fest";
 import { type SupportedModel } from ".";
 import { type FrontendModelProvider } from "../types";
 import { type ChatCompletion } from "openai/resources/chat";
+import { refinementActions } from "./refinementActions";

 const frontendModelProvider: FrontendModelProvider<SupportedModel, ChatCompletion> = {
   name: "OpenAI ChatCompletion",
@@ -45,6 +46,8 @@ const frontendModelProvider: FrontendModelProvider<SupportedModel, ChatCompletio
     },
   },

+  refinementActions,
+
   normalizeOutput: (output) => {
     const message = output.choices[0]?.message;
     if (!message)
@@ -120,7 +120,6 @@ export async function getCompletion(
       cost,
     };
   } catch (error: unknown) {
-    console.error("ERROR IS", error);
     if (error instanceof APIError) {
       return {
         type: "error",
@@ -37,7 +37,7 @@ const modelProvider: OpenaiChatModelProvider = {
     return null;
   },
   inputSchema: inputSchema as JSONSchema4,
-  shouldStream: (input) => input.stream ?? false,
+  canStream: true,
   getCompletion,
   ...frontendModelProvider,
 };
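The provider contract changes here from a per-request `shouldStream(input)` function to a static `canStream` flag (the matching type change appears in the `types.ts` hunk later in this diff). A hedged sketch of how a caller might now combine the flag with a per-request preference; the helper below is hypothetical and not part of the commit:

```ts
// Hypothetical caller-side check: a provider advertises a static `canStream`
// flag, and the caller decides per request whether streaming is actually used.
import modelProviders from "~/modelProviders/modelProviders";

function resolveStreaming(provider: keyof typeof modelProviders, wantsStream: boolean): boolean {
  return wantsStream && modelProviders[provider].canStream;
}

resolveStreaming("anthropic", true); // => true, since the Anthropic provider sets canStream: true
```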
new file: src/modelProviders/openai-ChatCompletion/refinementActions.ts (279 lines)
@@ -0,0 +1,279 @@
+import { TfiThought } from "react-icons/tfi";
+import { type RefinementAction } from "../types";
+import { VscJson } from "react-icons/vsc";
+
+export const refinementActions: Record<string, RefinementAction> = {
+  "Add chain of thought": {
+    icon: VscJson,
+    description: "Asking the model to plan its answer can increase accuracy.",
+    instructions: `Adding chain of thought means asking the model to think about its answer before it gives it to you. This is useful for getting more accurate answers. Do not add an assistant message.
+
+This is what a prompt looks like before adding chain of thought:
+
+definePrompt("openai/ChatCompletion", {
+  model: "gpt-4",
+  stream: true,
+  messages: [
+    {
+      role: "system",
+      content: \`Evaluate sentiment.\`,
+    },
+    {
+      role: "user",
+      content: \`This is the user's message: \${scenario.user_message}. Return "positive" or "negative" or "neutral"\`,
+    },
+  ],
+});
+
+This is what one looks like after adding chain of thought:
+
+definePrompt("openai/ChatCompletion", {
+  model: "gpt-4",
+  stream: true,
+  messages: [
+    {
+      role: "system",
+      content: \`Evaluate sentiment.\`,
+    },
+    {
+      role: "user",
+      content: \`This is the user's message: \${scenario.user_message}. Return "positive" or "negative" or "neutral". Explain your answer before you give a score, then return the score on a new line.\`,
+    },
+  ],
+});
+
+Here's another example:
+
+Before:
+
+definePrompt("openai/ChatCompletion", {
+  model: "gpt-3.5-turbo",
+  messages: [
+    {
+      role: "user",
+      content: \`Title: \${scenario.title}
+Body: \${scenario.body}
+
+Need: \${scenario.need}
+
+Rate likelihood on 1-3 scale.\`,
+    },
+  ],
+  temperature: 0,
+  functions: [
+    {
+      name: "score_post",
+      parameters: {
+        type: "object",
+        properties: {
+          score: {
+            type: "number",
+          },
+        },
+      },
+    },
+  ],
+  function_call: {
+    name: "score_post",
+  },
+});
+
+After:
+
+definePrompt("openai/ChatCompletion", {
+  model: "gpt-3.5-turbo",
+  messages: [
+    {
+      role: "user",
+      content: \`Title: \${scenario.title}
+Body: \${scenario.body}
+
+Need: \${scenario.need}
+
+Rate likelihood on 1-3 scale. Provide an explanation, but always provide a score afterward.\`,
+    },
+  ],
+  temperature: 0,
+  functions: [
+    {
+      name: "score_post",
+      parameters: {
+        type: "object",
+        properties: {
+          explanation: {
+            type: "string",
+          }
+          score: {
+            type: "number",
+          },
+        },
+      },
+    },
+  ],
+  function_call: {
+    name: "score_post",
+  },
+});
+
+Add chain of thought to the original prompt.`,
+  },
+  "Convert to function call": {
+    icon: TfiThought,
+    description: "Use function calls to get output from the model in a more structured way.",
+    instructions: `OpenAI functions are a specialized way for an LLM to return output.
+
+This is what a prompt looks like before adding a function:
+
+definePrompt("openai/ChatCompletion", {
+  model: "gpt-4",
+  stream: true,
+  messages: [
+    {
+      role: "system",
+      content: \`Evaluate sentiment.\`,
+    },
+    {
+      role: "user",
+      content: \`This is the user's message: \${scenario.user_message}. Return "positive" or "negative" or "neutral"\`,
+    },
+  ],
+});
+
+This is what one looks like after adding a function:
+
+definePrompt("openai/ChatCompletion", {
+  model: "gpt-4",
+  stream: true,
+  messages: [
+    {
+      role: "system",
+      content: "Evaluate sentiment.",
+    },
+    {
+      role: "user",
+      content: scenario.user_message,
+    },
+  ],
+  functions: [
+    {
+      name: "extract_sentiment",
+      parameters: {
+        type: "object", // parameters must always be an object with a properties key
+        properties: { // properties key is required
+          sentiment: {
+            type: "string",
+            description: "one of positive/negative/neutral",
+          },
+        },
+      },
+    },
+  ],
+  function_call: {
+    name: "extract_sentiment",
+  },
+});
+
+Here's another example of adding a function:
+
+Before:
+
+definePrompt("openai/ChatCompletion", {
+  model: "gpt-3.5-turbo",
+  messages: [
+    {
+      role: "user",
+      content: \`Here is the title and body of a reddit post I am interested in:
+
+title: \${scenario.title}
+body: \${scenario.body}
+
+On a scale from 1 to 3, how likely is it that the person writing this post has the following need? If you are not sure, make your best guess, or answer 1.
+
+Need: \${scenario.need}
+
+Answer one integer between 1 and 3.\`,
+    },
+  ],
+  temperature: 0,
+});
+
+After:
+
+definePrompt("openai/ChatCompletion", {
+  model: "gpt-3.5-turbo",
+  messages: [
+    {
+      role: "user",
+      content: \`Title: \${scenario.title}
+Body: \${scenario.body}
+
+Need: \${scenario.need}
+
+Rate likelihood on 1-3 scale.\`,
+    },
+  ],
+  temperature: 0,
+  functions: [
+    {
+      name: "score_post",
+      parameters: {
+        type: "object",
+        properties: {
+          score: {
+            type: "number",
+          },
+        },
+      },
+    },
+  ],
+  function_call: {
+    name: "score_post",
+  },
+});
+
+Another example
+
+Before:
+
+definePrompt("openai/ChatCompletion", {
+  model: "gpt-3.5-turbo",
+  stream: true,
+  messages: [
+    {
+      role: "system",
+      content: \`Write 'Start experimenting!' in \${scenario.language}\`,
+    },
+  ],
+});
+
+After:
+
+definePrompt("openai/ChatCompletion", {
+  model: "gpt-3.5-turbo",
+  messages: [
+    {
+      role: "system",
+      content: \`Write 'Start experimenting!' in \${scenario.language}\`,
+    },
+  ],
+  functions: [
+    {
+      name: "write_in_language",
+      parameters: {
+        type: "object",
+        properties: {
+          text: {
+            type: "string",
+          },
+        },
+      },
+    },
+  ],
+  function_call: {
+    name: "write_in_language",
+  },
+});
+
+Add an OpenAI function that takes one or more nested parameters that match the expected output from this prompt.`,
+  },
+};
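Each refinement action is a keyed record carrying an icon, a short description, and LLM-facing instructions (see the `RefinementAction` type added in `types.ts` later in this diff). A sketch of how a UI might enumerate them, assuming only the fields shown in that type (illustrative, not part of the commit):

```ts
// Illustrative only: list a provider's refinement actions the way a UI menu might.
import frontendModelProvider from "./frontend";

for (const [label, action] of Object.entries(frontendModelProvider.refinementActions ?? {})) {
  console.log(label, "-", action.description);
}
// "Add chain of thought - Asking the model to plan its answer can increase accuracy."
// "Convert to function call - Use function calls to get output from the model in a more structured way."
```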
@@ -1,5 +1,6 @@
 import { type SupportedModel, type ReplicateLlama2Output } from ".";
 import { type FrontendModelProvider } from "../types";
+import { refinementActions } from "./refinementActions";

 const frontendModelProvider: FrontendModelProvider<SupportedModel, ReplicateLlama2Output> = {
   name: "Replicate Llama2",
@@ -31,6 +32,8 @@ const frontendModelProvider: FrontendModelProvider<SupportedModel, ReplicateLlam
     },
   },

+  refinementActions,
+
   normalizeOutput: (output) => {
     return {
       type: "text",
@@ -8,9 +8,9 @@ const replicate = new Replicate({
 });

 const modelIds: Record<ReplicateLlama2Input["model"], string> = {
-  "7b-chat": "3725a659b5afff1a0ba9bead5fac3899d998feaad00e07032ca2b0e35eb14f8a",
-  "13b-chat": "5c785d117c5bcdd1928d5a9acb1ffa6272d6cf13fcb722e90886a0196633f9d3",
-  "70b-chat": "e951f18578850b652510200860fc4ea62b3b16fac280f83ff32282f87bbd2e48",
+  "7b-chat": "5ec5fdadd80ace49f5a2b2178cceeb9f2f77c493b85b1131002c26e6b2b13184",
+  "13b-chat": "6b4da803a2382c08868c5af10a523892f38e2de1aafb2ee55b020d9efef2fdb8",
+  "70b-chat": "2d19859030ff705a87c746f7e96eea03aefb71f166725aee39692f1476566d48",
 };

 export async function getCompletion(
@@ -19,7 +19,7 @@ export async function getCompletion(
 ): Promise<CompletionResponse<ReplicateLlama2Output>> {
   const start = Date.now();

-  const { model, stream, ...rest } = input;
+  const { model, ...rest } = input;

   try {
     const prediction = await replicate.predictions.create({
@@ -27,8 +27,6 @@ export async function getCompletion(
       input: rest,
     });

-    console.log("stream?", onStream);
-
     const interval = onStream
       ? // eslint-disable-next-line @typescript-eslint/no-misused-promises
         setInterval(async () => {
@@ -9,7 +9,6 @@ export type SupportedModel = (typeof supportedModels)[number];
 export type ReplicateLlama2Input = {
   model: SupportedModel;
   prompt: string;
-  stream?: boolean;
   max_length?: number;
   temperature?: number;
   top_p?: number;
@@ -38,31 +37,43 @@ const modelProvider: ReplicateLlama2Provider = {
         type: "string",
         enum: supportedModels as unknown as string[],
       },
+      system_prompt: {
+        type: "string",
+        description:
+          "System prompt to send to Llama v2. This is prepended to the prompt and helps guide system behavior.",
+      },
       prompt: {
         type: "string",
+        description: "Prompt to send to Llama v2.",
       },
-      stream: {
-        type: "boolean",
-      },
-      max_length: {
+      max_new_tokens: {
         type: "number",
+        description:
+          "Maximum number of tokens to generate. A word is generally 2-3 tokens (minimum: 1)",
       },
       temperature: {
         type: "number",
+        description:
+          "Adjusts randomness of outputs, greater than 1 is random and 0 is deterministic, 0.75 is a good starting value. (minimum: 0.01; maximum: 5)",
       },
       top_p: {
         type: "number",
+        description:
+          "When decoding text, samples from the top p percentage of most likely tokens; lower to ignore less likely tokens (minimum: 0.01; maximum: 1)",
       },
       repetition_penalty: {
         type: "number",
+        description:
+          "Penalty for repeated words in generated text; 1 is no penalty, values greater than 1 discourage repetition, less than 1 encourage it. (minimum: 0.01; maximum: 5)",
       },
       debug: {
         type: "boolean",
+        description: "provide debugging output in logs",
       },
     },
     required: ["model", "prompt"],
   },
-  shouldStream: (input) => input.stream ?? false,
+  canStream: true,
   getCompletion,
   ...frontendModelProvider,
 };
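The Replicate Llama 2 input schema drops `stream` and `max_length` in favor of `max_new_tokens`, adds `system_prompt`, and documents each field. A minimal input object written against the `ReplicateLlama2Input` type shown in the hunk above (values are illustrative and not part of the commit):

```ts
// An input object shaped to the updated type; 0.75 is the suggested starting
// temperature per the description above.
import { type ReplicateLlama2Input } from ".";

const input: ReplicateLlama2Input = {
  model: "13b-chat",
  prompt: "What is the capital of France?",
  temperature: 0.75,
  top_p: 0.9,
};
```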
new file: src/modelProviders/replicate-llama2/refinementActions.ts (3 lines)
@@ -0,0 +1,3 @@
+import { type RefinementAction } from "../types";
+
+export const refinementActions: Record<string, RefinementAction> = {};
@@ -1,31 +1,36 @@
 import { type JSONSchema4 } from "json-schema";
+import { type IconType } from "react-icons";
 import { type JsonValue } from "type-fest";
 import { z } from "zod";

-const ZodSupportedProvider = z.union([
+export const ZodSupportedProvider = z.union([
   z.literal("openai/ChatCompletion"),
   z.literal("replicate/llama2"),
+  z.literal("anthropic"),
 ]);

 export type SupportedProvider = z.infer<typeof ZodSupportedProvider>;

-export const ZodModel = z.object({
-  name: z.string(),
-  contextWindow: z.number(),
-  promptTokenPrice: z.number().optional(),
-  completionTokenPrice: z.number().optional(),
-  pricePerSecond: z.number().optional(),
-  speed: z.union([z.literal("fast"), z.literal("medium"), z.literal("slow")]),
-  provider: ZodSupportedProvider,
-  description: z.string().optional(),
-  learnMoreUrl: z.string().optional(),
-});
+export type Model = {
+  name: string;
+  contextWindow: number;
+  promptTokenPrice?: number;
+  completionTokenPrice?: number;
+  pricePerSecond?: number;
+  speed: "fast" | "medium" | "slow";
+  provider: SupportedProvider;
+  description?: string;
+  learnMoreUrl?: string;
+};

-export type Model = z.infer<typeof ZodModel>;
+export type ProviderModel = { provider: z.infer<typeof ZodSupportedProvider>; model: string };
+
+export type RefinementAction = { icon?: IconType; description: string; instructions: string };

 export type FrontendModelProvider<SupportedModels extends string, OutputSchema> = {
   name: string;
   models: Record<SupportedModels, Model>;
+  refinementActions?: Record<string, RefinementAction>;

   normalizeOutput: (output: OutputSchema) => NormalizedOutput;
 };
@@ -44,7 +49,7 @@ export type CompletionResponse<T> =

 export type ModelProvider<SupportedModels extends string, InputSchema, OutputSchema> = {
   getModel: (input: InputSchema) => SupportedModels | null;
-  shouldStream: (input: InputSchema) => boolean;
+  canStream: boolean;
   inputSchema: JSONSchema4;
   getCompletion: (
     input: InputSchema,
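The runtime `ZodModel` schema is replaced by a plain `Model` type, alongside the new `ProviderModel` and `RefinementAction` types. A literal that type-checks against the new shape, using values already shown in the Anthropic frontend file earlier in this diff (illustrative, not part of the commit):

```ts
// A Model literal written against the new plain type.
import { type Model } from "~/modelProviders/types";

const claudeInstant: Model = {
  name: "Claude Instant 1.1",
  contextWindow: 100000,
  promptTokenPrice: 1.63 / 1000000,
  completionTokenPrice: 5.51 / 1000000,
  speed: "fast",
  provider: "anthropic",
  learnMoreUrl: "https://www.anthropic.com/product",
};
```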
@@ -7,6 +7,8 @@ import "~/utils/analytics";
 import Head from "next/head";
 import { ChakraThemeProvider } from "~/theme/ChakraThemeProvider";
 import { SyncAppStore } from "~/state/sync";
+import NextAdapterApp from "next-query-params/app";
+import { QueryParamProvider } from "use-query-params";

 const MyApp: AppType<{ session: Session | null }> = ({
   Component,
@@ -24,7 +26,9 @@ const MyApp: AppType<{ session: Session | null }> = ({
         <SyncAppStore />
         <Favicon />
         <ChakraThemeProvider>
-          <Component {...pageProps} />
+          <QueryParamProvider adapter={NextAdapterApp}>
+            <Component {...pageProps} />
+          </QueryParamProvider>
         </ChakraThemeProvider>
       </SessionProvider>
     </>
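Wrapping the app in `QueryParamProvider` lets any page read and write typed query-string parameters through `use-query-params`. A hedged page-level sketch; the hook and parameter name are hypothetical and not part of the commit:

```ts
// Hypothetical page-level usage enabled by the provider above.
import { useQueryParam, StringParam } from "use-query-params";

function useSelectedTab() {
  // Reads and writes the `tab` query-string parameter, e.g. /experiments?tab=results
  const [tab, setTab] = useQueryParam("tab", StringParam);
  return { tab: tab ?? "results", setTab };
}
```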
@@ -2,106 +2,37 @@ import {
   Box,
   Breadcrumb,
   BreadcrumbItem,
-  Button,
   Center,
   Flex,
   Icon,
   Input,
-  AlertDialog,
-  AlertDialogBody,
-  AlertDialogFooter,
-  AlertDialogHeader,
-  AlertDialogContent,
-  AlertDialogOverlay,
-  useDisclosure,
   Text,
-  HStack,
   VStack,
 } from "@chakra-ui/react";
 import Link from "next/link";

 import { useRouter } from "next/router";
-import { useState, useEffect, useRef } from "react";
+import { useState, useEffect } from "react";
-import { BsGearFill, BsTrash } from "react-icons/bs";
 import { RiFlaskLine } from "react-icons/ri";
 import OutputsTable from "~/components/OutputsTable";
-import SettingsDrawer from "~/components/OutputsTable/SettingsDrawer";
+import ExperimentSettingsDrawer from "~/components/ExperimentSettingsDrawer/ExperimentSettingsDrawer";
 import AppShell from "~/components/nav/AppShell";
 import { api } from "~/utils/api";
 import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
 import { useAppStore } from "~/state/store";
 import { useSyncVariantEditor } from "~/state/sync";
+import { HeaderButtons } from "~/components/experiments/HeaderButtons/HeaderButtons";
-const DeleteButton = () => {
-  const experiment = useExperiment();
-  const mutation = api.experiments.delete.useMutation();
-  const utils = api.useContext();
-  const router = useRouter();
-
-  const { isOpen, onOpen, onClose } = useDisclosure();
-  const cancelRef = useRef<HTMLButtonElement>(null);
-
-  const [onDeleteConfirm] = useHandledAsyncCallback(async () => {
-    if (!experiment.data?.id) return;
-    await mutation.mutateAsync({ id: experiment.data.id });
-    await utils.experiments.list.invalidate();
-    await router.push({ pathname: "/experiments" });
-    onClose();
-  }, [mutation, experiment.data?.id, router]);
-
-  useEffect(() => {
-    useAppStore.getState().sharedVariantEditor.loadMonaco().catch(console.error);
-  });
-
-  return (
-    <>
-      <Button
-        size="sm"
-        variant={{ base: "outline", lg: "ghost" }}
-        colorScheme="gray"
-        fontWeight="normal"
-        onClick={onOpen}
-      >
-        <Icon as={BsTrash} boxSize={4} color="gray.600" />
-        <Text display={{ base: "none", lg: "block" }} ml={2}>
-          Delete Experiment
-        </Text>
-      </Button>
-
-      <AlertDialog isOpen={isOpen} leastDestructiveRef={cancelRef} onClose={onClose}>
-        <AlertDialogOverlay>
-          <AlertDialogContent>
-            <AlertDialogHeader fontSize="lg" fontWeight="bold">
-              Delete Experiment
-            </AlertDialogHeader>
-
-            <AlertDialogBody>
-              If you delete this experiment all the associated prompts and scenarios will be deleted
-              as well. Are you sure?
-            </AlertDialogBody>
-
-            <AlertDialogFooter>
-              <Button ref={cancelRef} onClick={onClose}>
-                Cancel
-              </Button>
-              <Button colorScheme="red" onClick={onDeleteConfirm} ml={3}>
-                Delete
-              </Button>
-            </AlertDialogFooter>
-          </AlertDialogContent>
-        </AlertDialogOverlay>
-      </AlertDialog>
-    </>
-  );
-};

 export default function Experiment() {
   const router = useRouter();
   const experiment = useExperiment();
   const utils = api.useContext();
-  const openDrawer = useAppStore((s) => s.openDrawer);
   useSyncVariantEditor();

+  useEffect(() => {
+    useAppStore.getState().sharedVariantEditor.loadMonaco().catch(console.error);
+  });
+
   const [label, setLabel] = useState(experiment.data?.label || "");
   useEffect(() => {
     setLabel(experiment.data?.label || "");
@@ -130,6 +61,14 @@ export default function Experiment() {

   const canModify = experiment.data?.access.canModify ?? false;
+
+  const y = "5"
+  const z = {abc: "123"}
+
+  const func = () => {
+    const u = 12;
+    const m = `hello ${y} ${z.abc} ${u} world`;
+  }

   return (
     <AppShell title={experiment.data?.label}>
       <VStack h="full">
@@ -138,7 +77,7 @@ export default function Experiment() {
           py={2}
           w="full"
           direction={{ base: "column", sm: "row" }}
-          alignItems="flex-start"
+          alignItems={{ base: "flex-start", sm: "center" }}
         >
           <Breadcrumb flex={1}>
             <BreadcrumbItem>
@@ -171,27 +110,11 @@ export default function Experiment() {
             )}
             </BreadcrumbItem>
           </Breadcrumb>
-          {canModify && (
-            <HStack>
-              <Button
-                size="sm"
-                variant={{ base: "outline", lg: "ghost" }}
-                colorScheme="gray"
-                fontWeight="normal"
-                onClick={openDrawer}
-              >
-                <Icon as={BsGearFill} boxSize={4} color="gray.600" />
-                <Text display={{ base: "none", lg: "block" }} ml={2}>
-                  Edit Vars & Evals
-                </Text>
-              </Button>
-              <DeleteButton />
-            </HStack>
-          )}
+          <HeaderButtons />
         </Flex>
-        <SettingsDrawer />
+        <ExperimentSettingsDrawer />
         <Box w="100%" overflowX="auto" flex={1}>
-          <OutputsTable experimentId={router.query.id as string | undefined} />
+          <OutputsTable experimentId={router.query.id as string | undefined} func={func} />
         </Box>
       </VStack>
     </AppShell>
@@ -13,7 +13,11 @@ import {
 import { RiFlaskLine } from "react-icons/ri";
 import AppShell from "~/components/nav/AppShell";
 import { api } from "~/utils/api";
-import { ExperimentCard, NewExperimentCard } from "~/components/experiments/ExperimentCard";
+import {
+  ExperimentCard,
+  ExperimentCardSkeleton,
+  NewExperimentCard,
+} from "~/components/experiments/ExperimentCard";
 import { signIn, useSession } from "next-auth/react";

 export default function ExperimentsPage() {
@@ -47,7 +51,7 @@ export default function ExperimentsPage() {
   return (
     <AppShell title="Experiments">
       <VStack alignItems={"flex-start"} px={4} py={2}>
-        <HStack minH={8} align="center">
+        <HStack minH={8} align="center" pt={2}>
           <Breadcrumb flex={1}>
             <BreadcrumbItem>
               <Flex alignItems="center">
@@ -58,7 +62,15 @@ export default function ExperimentsPage() {
         </HStack>
         <SimpleGrid w="full" columns={{ base: 1, md: 2, lg: 3, xl: 4 }} spacing={8} p="4">
           <NewExperimentCard />
-          {experiments?.data?.map((exp) => <ExperimentCard key={exp.id} exp={exp} />)}
+          {experiments.data && !experiments.isLoading ? (
+            experiments?.data?.map((exp) => <ExperimentCard key={exp.id} exp={exp} />)
+          ) : (
+            <>
+              <ExperimentCardSkeleton />
+              <ExperimentCardSkeleton />
+              <ExperimentCardSkeleton />
+            </>
+          )}
         </SimpleGrid>
       </VStack>
     </AppShell>
@@ -2,7 +2,7 @@ import { EvalType } from "@prisma/client";
 import { z } from "zod";
 import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
 import { prisma } from "~/server/db";
-import { runAllEvals } from "~/server/utils/evaluations";
+import { queueRunNewEval } from "~/server/tasks/runNewEval.task";
 import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";

 export const evaluationsRouter = createTRPCRouter({
@@ -40,9 +40,7 @@ export const evaluationsRouter = createTRPCRouter({
         },
       });

-      // TODO: this may be a bad UX for slow evals (eg. GPT-4 evals) Maybe need
-      // to kick off a background job or something instead
-      await runAllEvals(input.experimentId);
+      await queueRunNewEval(input.experimentId);
     }),

   update: protectedProcedure
@@ -78,7 +76,7 @@ export const evaluationsRouter = createTRPCRouter({
       });
       // Re-run all evals. Other eval results will already be cached, so this
       // should only re-run the updated one.
-      await runAllEvals(evaluation.experimentId);
+      await queueRunNewEval(experimentId);
     }),

   delete: protectedProcedure
@@ -1,5 +1,7 @@
 import { z } from "zod";
+import { v4 as uuidv4 } from "uuid";
 import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
+import { type Prisma } from "@prisma/client";
 import { prisma } from "~/server/db";
 import dedent from "dedent";
 import { generateNewCell } from "~/server/utils/generateNewCell";
@@ -20,7 +22,7 @@ export const experimentsRouter = createTRPCRouter({
     const experiments = await prisma.experiment.findMany({
       where: {
         organization: {
-          OrganizationUser: {
+          organizationUsers: {
             some: { userId: ctx.session.user.id },
           },
         },
@@ -77,6 +79,189 @@ export const experimentsRouter = createTRPCRouter({
     };
   }),

+  fork: protectedProcedure.input(z.object({ id: z.string() })).mutation(async ({ input, ctx }) => {
+    await requireCanViewExperiment(input.id, ctx);
+
+    const [
+      existingExp,
+      existingVariants,
+      existingScenarios,
+      existingCells,
+      evaluations,
+      templateVariables,
+    ] = await prisma.$transaction([
+      prisma.experiment.findUniqueOrThrow({
+        where: {
+          id: input.id,
+        },
+      }),
+      prisma.promptVariant.findMany({
+        where: {
+          experimentId: input.id,
+          visible: true,
+        },
+      }),
+      prisma.testScenario.findMany({
+        where: {
+          experimentId: input.id,
+          visible: true,
+        },
+      }),
+      prisma.scenarioVariantCell.findMany({
+        where: {
+          testScenario: {
+            visible: true,
+          },
+          promptVariant: {
+            experimentId: input.id,
+            visible: true,
+          },
+        },
+        include: {
+          modelResponses: {
+            include: {
+              outputEvaluations: true,
+            },
+          },
+        },
+      }),
+      prisma.evaluation.findMany({
+        where: {
+          experimentId: input.id,
+        },
+      }),
+      prisma.templateVariable.findMany({
+        where: {
+          experimentId: input.id,
+        },
+      }),
+    ]);
+
+    const newExperimentId = uuidv4();
+
+    const existingToNewVariantIds = new Map<string, string>();
+    const variantsToCreate: Prisma.PromptVariantCreateManyInput[] = [];
+    for (const variant of existingVariants) {
+      const newVariantId = uuidv4();
+      existingToNewVariantIds.set(variant.id, newVariantId);
+      variantsToCreate.push({
+        ...variant,
+        id: newVariantId,
+        experimentId: newExperimentId,
+      });
+    }
+
+    const existingToNewScenarioIds = new Map<string, string>();
+    const scenariosToCreate: Prisma.TestScenarioCreateManyInput[] = [];
+    for (const scenario of existingScenarios) {
+      const newScenarioId = uuidv4();
+      existingToNewScenarioIds.set(scenario.id, newScenarioId);
+      scenariosToCreate.push({
+        ...scenario,
+        id: newScenarioId,
+        experimentId: newExperimentId,
+        variableValues: scenario.variableValues as Prisma.InputJsonValue,
+      });
+    }
+
+    const existingToNewEvaluationIds = new Map<string, string>();
+    const evaluationsToCreate: Prisma.EvaluationCreateManyInput[] = [];
+    for (const evaluation of evaluations) {
+      const newEvaluationId = uuidv4();
+      existingToNewEvaluationIds.set(evaluation.id, newEvaluationId);
+      evaluationsToCreate.push({
+        ...evaluation,
+        id: newEvaluationId,
+        experimentId: newExperimentId,
+      });
+    }
+
+    const cellsToCreate: Prisma.ScenarioVariantCellCreateManyInput[] = [];
+    const modelResponsesToCreate: Prisma.ModelResponseCreateManyInput[] = [];
+    const outputEvaluationsToCreate: Prisma.OutputEvaluationCreateManyInput[] = [];
+    for (const cell of existingCells) {
+      const newCellId = uuidv4();
+      const { modelResponses, ...cellData } = cell;
+      cellsToCreate.push({
+        ...cellData,
+        id: newCellId,
+        promptVariantId: existingToNewVariantIds.get(cell.promptVariantId) ?? "",
+        testScenarioId: existingToNewScenarioIds.get(cell.testScenarioId) ?? "",
+        prompt: (cell.prompt as Prisma.InputJsonValue) ?? undefined,
+      });
+      for (const modelResponse of modelResponses) {
+        const newModelResponseId = uuidv4();
+        const { outputEvaluations, ...modelResponseData } = modelResponse;
+        modelResponsesToCreate.push({
+          ...modelResponseData,
+          id: newModelResponseId,
+          scenarioVariantCellId: newCellId,
+          output: (modelResponse.output as Prisma.InputJsonValue) ?? undefined,
+        });
+        for (const evaluation of outputEvaluations) {
+          outputEvaluationsToCreate.push({
+            ...evaluation,
+            id: uuidv4(),
+            modelResponseId: newModelResponseId,
+            evaluationId: existingToNewEvaluationIds.get(evaluation.evaluationId) ?? "",
+          });
+        }
+      }
+    }
+
+    const templateVariablesToCreate: Prisma.TemplateVariableCreateManyInput[] = [];
+    for (const templateVariable of templateVariables) {
+      templateVariablesToCreate.push({
+        ...templateVariable,
+        id: uuidv4(),
+        experimentId: newExperimentId,
+      });
+    }
+
+    const maxSortIndex =
+      (
+        await prisma.experiment.aggregate({
+          _max: {
+            sortIndex: true,
+          },
+        })
+      )._max?.sortIndex ?? 0;
+
+    await prisma.$transaction([
+      prisma.experiment.create({
+        data: {
+          id: newExperimentId,
+          sortIndex: maxSortIndex + 1,
+          label: `${existingExp.label} (forked)`,
+          organizationId: (await userOrg(ctx.session.user.id)).id,
+        },
+      }),
+      prisma.promptVariant.createMany({
+        data: variantsToCreate,
+      }),
+      prisma.testScenario.createMany({
+        data: scenariosToCreate,
+      }),
+      prisma.scenarioVariantCell.createMany({
+        data: cellsToCreate,
+      }),
+      prisma.modelResponse.createMany({
+        data: modelResponsesToCreate,
+      }),
+      prisma.evaluation.createMany({
+        data: evaluationsToCreate,
+      }),
+      prisma.outputEvaluation.createMany({
+        data: outputEvaluationsToCreate,
+      }),
+      prisma.templateVariable.createMany({
+        data: templateVariablesToCreate,
+      }),
+    ]);
+
+    return newExperimentId;
+  }),
+
 create: protectedProcedure.input(z.object({})).mutation(async ({ ctx }) => {
     // Anyone can create an experiment
     requireNothing(ctx);
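The new `experiments.fork` mutation deep-copies an experiment's variants, scenarios, cells, responses, evaluations, and template variables, then returns the new experiment id. A hedged sketch of client-side wiring in the style of the `onForkButtonPressed`/`isForking` hook whose tail appears at the top of this diff; the hook below is a hypothetical illustration, not the committed component:

```ts
// Hypothetical client-side wiring: call the new `experiments.fork` mutation and
// navigate to the experiment id it returns.
import { useRouter } from "next/router";
import { api } from "~/utils/api";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";

export function useOnFork() {
  const router = useRouter();
  const experiment = useExperiment();
  const forkMutation = api.experiments.fork.useMutation();

  const [onFork] = useHandledAsyncCallback(async () => {
    if (!experiment.data?.id) return;
    const newExperimentId = await forkMutation.mutateAsync({ id: experiment.data.id });
    await router.push({ pathname: "/experiments/[id]", query: { id: newExperimentId } });
  }, [forkMutation, experiment.data?.id, router]);

  return onFork;
}
```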
@@ -1,6 +1,7 @@
 import { z } from "zod";
 import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
 import { prisma } from "~/server/db";
+import { Prisma } from "@prisma/client";
 import { generateNewCell } from "~/server/utils/generateNewCell";
 import userError from "~/server/utils/error";
 import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
@@ -9,7 +10,8 @@ import { type PromptVariant } from "@prisma/client";
 import { deriveNewConstructFn } from "~/server/utils/deriveNewContructFn";
 import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
 import parseConstructFn from "~/server/utils/parseConstructFn";
-import { ZodModel } from "~/modelProviders/types";
+import modelProviders from "~/modelProviders/modelProviders";
+import { ZodSupportedProvider } from "~/modelProviders/types";

 export const promptVariantsRouter = createTRPCRouter({
   list: publicProcedure
@@ -50,7 +52,9 @@ export const promptVariantsRouter = createTRPCRouter({
           id: true,
         },
         where: {
-          modelOutput: {
+          modelResponse: {
+            outdated: false,
+            output: { not: Prisma.AnyNull },
             scenarioVariantCell: {
               promptVariant: {
                 id: input.variantId,
@@ -92,14 +96,23 @@ export const promptVariantsRouter = createTRPCRouter({
        where: {
          promptVariantId: input.variantId,
          testScenario: { visible: true },
-          modelOutput: {
-            is: {},
+          modelResponses: {
+            some: {
+              outdated: false,
+              output: {
+                not: Prisma.AnyNull,
+              },
+            },
          },
        },
      });

-      const overallTokens = await prisma.modelOutput.aggregate({
+      const overallTokens = await prisma.modelResponse.aggregate({
        where: {
+          outdated: false,
+          output: {
+            not: Prisma.AnyNull,
+          },
          scenarioVariantCell: {
            promptVariantId: input.variantId,
            testScenario: {
@@ -117,16 +130,9 @@ export const promptVariantsRouter = createTRPCRouter({
      const promptTokens = overallTokens._sum?.promptTokens ?? 0;
      const completionTokens = overallTokens._sum?.completionTokens ?? 0;

-      const awaitingRetrievals = !!(await prisma.scenarioVariantCell.findFirst({
-        where: {
-          promptVariantId: input.variantId,
-          testScenario: { visible: true },
-          // Check if is PENDING or IN_PROGRESS
-          retrievalStatus: {
-            in: ["PENDING", "IN_PROGRESS"],
-          },
-        },
-      }));
+      const awaitingEvals = !!evalResults.find(
+        (result) => result.totalCount < scenarioCount * evals.length,
+      );

      return {
        evalResults,
@@ -135,7 +141,7 @@ export const promptVariantsRouter = createTRPCRouter({
        overallCost: overallTokens._sum?.cost ?? 0,
        scenarioCount,
        outputCount,
-        awaitingRetrievals,
+        awaitingEvals,
      };
    }),

@@ -144,7 +150,7 @@ export const promptVariantsRouter = createTRPCRouter({
      z.object({
        experimentId: z.string(),
        variantId: z.string().optional(),
-        newModel: ZodModel.optional(),
+        streamScenarios: z.array(z.string()),
      }),
    )
    .mutation(async ({ input, ctx }) => {
@@ -186,7 +192,7 @@ export const promptVariantsRouter = createTRPCRouter({
        ? `${originalVariant?.label} Copy`
        : `Prompt Variant ${largestSortIndex + 2}`;

-      const newConstructFn = await deriveNewConstructFn(originalVariant, input.newModel);
+      const newConstructFn = await deriveNewConstructFn(originalVariant);

      const createNewVariantAction = prisma.promptVariant.create({
        data: {
@@ -218,7 +224,9 @@ export const promptVariantsRouter = createTRPCRouter({
      });

      for (const scenario of scenarios) {
-        await generateNewCell(newVariant.id, scenario.id);
+        await generateNewCell(newVariant.id, scenario.id, {
+          stream: input.streamScenarios.includes(scenario.id),
+        });
      }

      return newVariant;
@@ -286,7 +294,12 @@ export const promptVariantsRouter = createTRPCRouter({
      z.object({
        id: z.string(),
        instructions: z.string().optional(),
-        newModel: ZodModel.optional(),
+        newModel: z
+          .object({
+            provider: ZodSupportedProvider,
+            model: z.string(),
+          })
+          .optional(),
      }),
    )
    .mutation(async ({ input, ctx }) => {
@@ -303,11 +316,11 @@ export const promptVariantsRouter = createTRPCRouter({
        return userError(constructedPrompt.error);
      }

-      const promptConstructionFn = await deriveNewConstructFn(
-        existing,
-        input.newModel,
-        input.instructions,
-      );
+      const model = input.newModel
+        ? modelProviders[input.newModel.provider].models[input.newModel.model]
+        : undefined;
+
+      const promptConstructionFn = await deriveNewConstructFn(existing, model, input.instructions);

      // TODO: Validate promptConstructionFn
      // TODO: Record in some sort of history
@@ -320,6 +333,7 @@ export const promptVariantsRouter = createTRPCRouter({
      z.object({
        id: z.string(),
        constructFn: z.string(),
+        streamScenarios: z.array(z.string()),
      }),
    )
    .mutation(async ({ input, ctx }) => {
@@ -377,7 +391,9 @@ export const promptVariantsRouter = createTRPCRouter({
      });

      for (const scenario of scenarios) {
-        await generateNewCell(newVariant.id, scenario.id);
+        await generateNewCell(newVariant.id, scenario.id, {
+          stream: input.streamScenarios.includes(scenario.id),
+        });
      }

      return { status: "ok" } as const;
@@ -1,8 +1,8 @@
 import { z } from "zod";
 import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
 import { prisma } from "~/server/db";
+import { queueQueryModel } from "~/server/tasks/queryModel.task";
 import { generateNewCell } from "~/server/utils/generateNewCell";
-import { queueLLMRetrievalTask } from "~/server/utils/queueLLMRetrievalTask";
 import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";

 export const scenarioVariantCellsRouter = createTRPCRouter({
@@ -19,27 +19,45 @@ export const scenarioVariantCellsRouter = createTRPCRouter({
      });
      await requireCanViewExperiment(experimentId, ctx);

-      return await prisma.scenarioVariantCell.findUnique({
-        where: {
-          promptVariantId_testScenarioId: {
-            promptVariantId: input.variantId,
-            testScenarioId: input.scenarioId,
-          },
-        },
-        include: {
-          modelOutput: {
-            include: {
-              outputEvaluation: {
-                include: {
-                  evaluation: {
-                    select: { label: true },
-                  },
-                },
-              },
-            },
-          },
-        },
-      });
+      const [cell, numTotalEvals] = await prisma.$transaction([
+        prisma.scenarioVariantCell.findUnique({
+          where: {
+            promptVariantId_testScenarioId: {
+              promptVariantId: input.variantId,
+              testScenarioId: input.scenarioId,
+            },
+          },
+          include: {
+            modelResponses: {
+              where: {
+                outdated: false,
+              },
+              include: {
+                outputEvaluations: {
+                  include: {
+                    evaluation: {
+                      select: { label: true },
+                    },
+                  },
+                },
+              },
+            },
+          },
+        }),
+        prisma.evaluation.count({
+          where: { experimentId },
+        }),
+      ]);
+
+      if (!cell) return null;
+
+      const lastResponse = cell.modelResponses?.[cell.modelResponses?.length - 1];
+      const evalsComplete = lastResponse?.outputEvaluations?.length === numTotalEvals;
+
+      return {
+        ...cell,
+        evalsComplete,
+      };
    }),
  forceRefetch: protectedProcedure
    .input(
@@ -62,29 +80,20 @@ export const scenarioVariantCellsRouter = createTRPCRouter({
          testScenarioId: input.scenarioId,
        },
      },
-        include: {
-          modelOutput: true,
-        },
      });

      if (!cell) {
-        await generateNewCell(input.variantId, input.scenarioId);
-        return true;
+        await generateNewCell(input.variantId, input.scenarioId, { stream: true });
+        return;
      }

-      if (cell.modelOutput) {
-        // TODO: Maybe keep these around to show previous generations?
-        await prisma.modelOutput.delete({
-          where: { id: cell.modelOutput.id },
-        });
-      }
-
-      await prisma.scenarioVariantCell.update({
-        where: { id: cell.id },
-        data: { retrievalStatus: "PENDING" },
-      });
+      await prisma.modelResponse.updateMany({
+        where: { scenarioVariantCellId: cell.id },
+        data: {
+          outdated: true,
+        },
+      });

-      await queueLLMRetrievalTask(cell.id);
-      return true;
+      await queueQueryModel(cell.id, true);
    }),
});
@@ -7,21 +7,39 @@ import { runAllEvals } from "~/server/utils/evaluations";
 import { generateNewCell } from "~/server/utils/generateNewCell";
 import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";

+const PAGE_SIZE = 10;
+
 export const scenariosRouter = createTRPCRouter({
   list: publicProcedure
-    .input(z.object({ experimentId: z.string() }))
+    .input(z.object({ experimentId: z.string(), page: z.number() }))
     .query(async ({ input, ctx }) => {
       await requireCanViewExperiment(input.experimentId, ctx);

-      return await prisma.testScenario.findMany({
+      const { experimentId, page } = input;
+
+      const scenarios = await prisma.testScenario.findMany({
        where: {
-          experimentId: input.experimentId,
+          experimentId,
          visible: true,
        },
-        orderBy: {
-          sortIndex: "asc",
+        orderBy: { sortIndex: "asc" },
+        skip: (page - 1) * PAGE_SIZE,
+        take: PAGE_SIZE,
+      });
+
+      const count = await prisma.testScenario.count({
+        where: {
+          experimentId,
+          visible: true,
        },
      });
+
+      return {
+        scenarios,
+        startIndex: (page - 1) * PAGE_SIZE + 1,
+        lastPage: Math.ceil(count / PAGE_SIZE),
+        count,
+      };
    }),

  create: protectedProcedure
@@ -68,7 +86,7 @@ export const scenariosRouter = createTRPCRouter({
      });

      for (const variant of promptVariants) {
-        await generateNewCell(variant.id, scenario.id);
+        await generateNewCell(variant.id, scenario.id, { stream: true });
      }
    }),

@@ -212,7 +230,7 @@ export const scenariosRouter = createTRPCRouter({
      });

      for (const variant of promptVariants) {
-        await generateNewCell(variant.id, newScenario.id);
+        await generateNewCell(variant.id, newScenario.id, { stream: true });
      }

      return newScenario;
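To make the new pagination concrete, this is the arithmetic the list query above performs for a hypothetical request (the page and count values here are illustrative):

// With PAGE_SIZE = 10 and page = 3:
//   skip = (3 - 1) * 10 = 20, take = 10   -> scenarios 21 through 30 are returned
//   startIndex = (3 - 1) * 10 + 1 = 21
// With count = 42 visible scenarios:
//   lastPage = Math.ceil(42 / 10) = 5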
@@ -7,9 +7,9 @@ function defineTask<TPayload>(
   taskIdentifier: string,
   taskHandler: (payload: TPayload, helpers: Helpers) => Promise<void>,
 ) {
-  const enqueue = async (payload: TPayload) => {
+  const enqueue = async (payload: TPayload, runAt?: Date) => {
     console.log("Enqueuing task", taskIdentifier, payload);
-    await quickAddJob({ connectionString: env.DATABASE_URL }, taskIdentifier, payload);
+    await quickAddJob({ connectionString: env.DATABASE_URL }, taskIdentifier, payload, { runAt });
   };

   const handler = (payload: TPayload, helpers: Helpers) => {
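The new optional runAt argument is what lets the queryModel task below schedule its retries for a future time instead of re-running immediately. A minimal usage sketch, assuming a task created with defineTask and a cellId already in scope (the delay value is illustrative):

// Run the job as soon as a worker picks it up (runAt omitted):
await queryModel.enqueue({ cellId, stream: false, numPreviousTries: 0 });
// Schedule a retry roughly five seconds from now:
await queryModel.enqueue({ cellId, stream: false, numPreviousTries: 1 }, new Date(Date.now() + 5000));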
@@ -1,165 +0,0 @@
-import { prisma } from "~/server/db";
-import defineTask from "./defineTask";
-import { sleep } from "../utils/sleep";
-import { generateChannel } from "~/utils/generateChannel";
-import { runEvalsForOutput } from "../utils/evaluations";
-import { type Prisma } from "@prisma/client";
-import parseConstructFn from "../utils/parseConstructFn";
-import hashPrompt from "../utils/hashPrompt";
-import { type JsonObject } from "type-fest";
-import modelProviders from "~/modelProviders/modelProviders";
-import { wsConnection } from "~/utils/wsConnection";
-
-export type queryLLMJob = {
-  scenarioVariantCellId: string;
-};
-
-const MAX_AUTO_RETRIES = 10;
-const MIN_DELAY = 500; // milliseconds
-const MAX_DELAY = 15000; // milliseconds
-
-function calculateDelay(numPreviousTries: number): number {
-  const baseDelay = Math.min(MAX_DELAY, MIN_DELAY * Math.pow(2, numPreviousTries));
-  const jitter = Math.random() * baseDelay;
-  return baseDelay + jitter;
-}
-
-export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
-  const { scenarioVariantCellId } = task;
-  const cell = await prisma.scenarioVariantCell.findUnique({
-    where: { id: scenarioVariantCellId },
-    include: { modelOutput: true },
-  });
-  if (!cell) {
-    await prisma.scenarioVariantCell.update({
-      where: { id: scenarioVariantCellId },
-      data: {
-        statusCode: 404,
-        errorMessage: "Cell not found",
-        retrievalStatus: "ERROR",
-      },
-    });
-    return;
-  }
-
-  // If cell is not pending, then some other job is already processing it
-  if (cell.retrievalStatus !== "PENDING") {
-    return;
-  }
-  await prisma.scenarioVariantCell.update({
-    where: { id: scenarioVariantCellId },
-    data: {
-      retrievalStatus: "IN_PROGRESS",
-    },
-  });
-
-  const variant = await prisma.promptVariant.findUnique({
-    where: { id: cell.promptVariantId },
-  });
-  if (!variant) {
-    await prisma.scenarioVariantCell.update({
-      where: { id: scenarioVariantCellId },
-      data: {
-        statusCode: 404,
-        errorMessage: "Prompt Variant not found",
-        retrievalStatus: "ERROR",
-      },
-    });
-    return;
-  }
-
-  const scenario = await prisma.testScenario.findUnique({
-    where: { id: cell.testScenarioId },
-  });
-  if (!scenario) {
-    await prisma.scenarioVariantCell.update({
-      where: { id: scenarioVariantCellId },
-      data: {
-        statusCode: 404,
-        errorMessage: "Scenario not found",
-        retrievalStatus: "ERROR",
-      },
-    });
-    return;
-  }
-
-  const prompt = await parseConstructFn(variant.constructFn, scenario.variableValues as JsonObject);
-
-  if ("error" in prompt) {
-    await prisma.scenarioVariantCell.update({
-      where: { id: scenarioVariantCellId },
-      data: {
-        statusCode: 400,
-        errorMessage: prompt.error,
-        retrievalStatus: "ERROR",
-      },
-    });
-    return;
-  }
-
-  const provider = modelProviders[prompt.modelProvider];
-
-  const streamingChannel = provider.shouldStream(prompt.modelInput) ? generateChannel() : null;
-
-  if (streamingChannel) {
-    // Save streaming channel so that UI can connect to it
-    await prisma.scenarioVariantCell.update({
-      where: { id: scenarioVariantCellId },
-      data: { streamingChannel },
-    });
-  }
-  const onStream = streamingChannel
-    ? (partialOutput: (typeof provider)["_outputSchema"]) => {
-        wsConnection.emit("message", { channel: streamingChannel, payload: partialOutput });
-      }
-    : null;
-
-  for (let i = 0; true; i++) {
-    const response = await provider.getCompletion(prompt.modelInput, onStream);
-    if (response.type === "success") {
-      const inputHash = hashPrompt(prompt);
-
-      const modelOutput = await prisma.modelOutput.create({
-        data: {
-          scenarioVariantCellId,
-          inputHash,
-          output: response.value as Prisma.InputJsonObject,
-          timeToComplete: response.timeToComplete,
-          promptTokens: response.promptTokens,
-          completionTokens: response.completionTokens,
-          cost: response.cost,
-        },
-      });
-
-      await prisma.scenarioVariantCell.update({
-        where: { id: scenarioVariantCellId },
-        data: {
-          statusCode: response.statusCode,
-          retrievalStatus: "COMPLETE",
-        },
-      });
-
-      await runEvalsForOutput(variant.experimentId, scenario, modelOutput);
-      break;
-    } else {
-      const shouldRetry = response.autoRetry && i < MAX_AUTO_RETRIES;
-      const delay = calculateDelay(i);
-
-      await prisma.scenarioVariantCell.update({
-        where: { id: scenarioVariantCellId },
-        data: {
-          errorMessage: response.message,
-          statusCode: response.statusCode,
-          retryTime: shouldRetry ? new Date(Date.now() + delay) : null,
-          retrievalStatus: "ERROR",
-        },
-      });
-
-      if (shouldRetry) {
-        await sleep(delay);
-      } else {
-        break;
-      }
-    }
-  }
-});
185 src/server/tasks/queryModel.task.ts Normal file
@@ -0,0 +1,185 @@
+import { type Prisma } from "@prisma/client";
+import { type JsonObject } from "type-fest";
+import modelProviders from "~/modelProviders/modelProviders";
+import { prisma } from "~/server/db";
+import { wsConnection } from "~/utils/wsConnection";
+import { runEvalsForOutput } from "../utils/evaluations";
+import hashPrompt from "../utils/hashPrompt";
+import parseConstructFn from "../utils/parseConstructFn";
+import defineTask from "./defineTask";
+
+export type QueryModelJob = {
+  cellId: string;
+  stream: boolean;
+  numPreviousTries: number;
+};
+
+const MAX_AUTO_RETRIES = 50;
+const MIN_DELAY = 500; // milliseconds
+const MAX_DELAY = 15000; // milliseconds
+
+function calculateDelay(numPreviousTries: number): number {
+  const baseDelay = Math.min(MAX_DELAY, MIN_DELAY * Math.pow(2, numPreviousTries));
+  const jitter = Math.random() * baseDelay;
+  return baseDelay + jitter;
+}
+
+export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) => {
+  console.log("RUNNING TASK", task);
+  const { cellId, stream, numPreviousTries } = task;
+  const cell = await prisma.scenarioVariantCell.findUnique({
+    where: { id: cellId },
+    include: { modelResponses: true },
+  });
+  if (!cell) {
+    return;
+  }
+
+  // If cell is not pending, then some other job is already processing it
+  if (cell.retrievalStatus !== "PENDING") {
+    return;
+  }
+  await prisma.scenarioVariantCell.update({
+    where: { id: cellId },
+    data: {
+      retrievalStatus: "IN_PROGRESS",
+      jobStartedAt: new Date(),
+    },
+  });
+
+  const variant = await prisma.promptVariant.findUnique({
+    where: { id: cell.promptVariantId },
+  });
+  if (!variant) {
+    await prisma.scenarioVariantCell.update({
+      where: { id: cellId },
+      data: {
+        errorMessage: "Prompt Variant not found",
+        retrievalStatus: "ERROR",
+      },
+    });
+    return;
+  }
+
+  const scenario = await prisma.testScenario.findUnique({
+    where: { id: cell.testScenarioId },
+  });
+  if (!scenario) {
+    await prisma.scenarioVariantCell.update({
+      where: { id: cellId },
+      data: {
+        errorMessage: "Scenario not found",
+        retrievalStatus: "ERROR",
+      },
+    });
+    return;
+  }
+
+  const prompt = await parseConstructFn(variant.constructFn, scenario.variableValues as JsonObject);
+
+  if ("error" in prompt) {
+    await prisma.scenarioVariantCell.update({
+      where: { id: cellId },
+      data: {
+        errorMessage: prompt.error,
+        retrievalStatus: "ERROR",
+      },
+    });
+    return;
+  }
+
+  const provider = modelProviders[prompt.modelProvider];
+
+  const onStream = stream
+    ? (partialOutput: (typeof provider)["_outputSchema"]) => {
+        wsConnection.emit("message", { channel: cell.id, payload: partialOutput });
+      }
+    : null;
+
+  const inputHash = hashPrompt(prompt);
+
+  let modelResponse = await prisma.modelResponse.create({
+    data: {
+      inputHash,
+      scenarioVariantCellId: cellId,
+      requestedAt: new Date(),
+    },
+  });
+  const response = await provider.getCompletion(prompt.modelInput, onStream);
+  if (response.type === "success") {
+    modelResponse = await prisma.modelResponse.update({
+      where: { id: modelResponse.id },
+      data: {
+        output: response.value as Prisma.InputJsonObject,
+        statusCode: response.statusCode,
+        receivedAt: new Date(),
+        promptTokens: response.promptTokens,
+        completionTokens: response.completionTokens,
+        cost: response.cost,
+      },
+    });
+
+    await prisma.scenarioVariantCell.update({
+      where: { id: cellId },
+      data: {
+        retrievalStatus: "COMPLETE",
+      },
+    });
+
+    await runEvalsForOutput(variant.experimentId, scenario, modelResponse, prompt.modelProvider);
+  } else {
+    const shouldRetry = response.autoRetry && numPreviousTries < MAX_AUTO_RETRIES;
+    const delay = calculateDelay(numPreviousTries);
+    const retryTime = new Date(Date.now() + delay);
+
+    await prisma.modelResponse.update({
+      where: { id: modelResponse.id },
+      data: {
+        statusCode: response.statusCode,
+        errorMessage: response.message,
+        receivedAt: new Date(),
+        retryTime: shouldRetry ? retryTime : null,
+      },
+    });
+
+    if (shouldRetry) {
+      await queryModel.enqueue(
+        {
+          cellId,
+          stream,
+          numPreviousTries: numPreviousTries + 1,
+        },
+        retryTime,
+      );
+      await prisma.scenarioVariantCell.update({
+        where: { id: cellId },
+        data: {
+          retrievalStatus: "PENDING",
+        },
+      });
+    } else {
+      await prisma.scenarioVariantCell.update({
+        where: { id: cellId },
+        data: {
+          retrievalStatus: "ERROR",
+        },
+      });
+    }
+  }
+});
+
+export const queueQueryModel = async (cellId: string, stream: boolean) => {
+  await Promise.all([
+    prisma.scenarioVariantCell.update({
+      where: {
+        id: cellId,
+      },
+      data: {
+        retrievalStatus: "PENDING",
+        errorMessage: null,
+        jobQueuedAt: new Date(),
+      },
+    }),
+    queryModel.enqueue({ cellId, stream, numPreviousTries: 0 }),
+  ]);
+};
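For reference, calculateDelay above is exponential backoff with full jitter: the base delay doubles with each attempt, is capped at MAX_DELAY, and then up to one extra base delay of random jitter is added. Evaluating the same formula by hand:

// numPreviousTries = 0 -> base 500 ms,   actual delay in [500, 1000) ms
// numPreviousTries = 1 -> base 1000 ms,  actual delay in [1000, 2000) ms
// numPreviousTries = 3 -> base 4000 ms,  actual delay in [4000, 8000) ms
// numPreviousTries = 5 -> base capped at 15000 ms, actual delay in [15000, 30000) ms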
17 src/server/tasks/runNewEval.task.ts Normal file
@@ -0,0 +1,17 @@
+import { runAllEvals } from "../utils/evaluations";
+import defineTask from "./defineTask";
+
+export type RunNewEvalJob = {
+  experimentId: string;
+};
+
+// When a new eval is created, we want to run it on all existing outputs, but return the new eval first
+export const runNewEval = defineTask<RunNewEvalJob>("runNewEval", async (task) => {
+  console.log("RUNNING TASK", task);
+  const { experimentId } = task;
+  await runAllEvals(experimentId);
+});
+
+export const queueRunNewEval = async (experimentId: string) => {
+  await runNewEval.enqueue({ experimentId });
+};
@@ -2,39 +2,28 @@ import { type TaskList, run } from "graphile-worker";
 import "dotenv/config";

 import { env } from "~/env.mjs";
-import { queryLLM } from "./queryLLM.task";
+import { queryModel } from "./queryModel.task";
+import { runNewEval } from "./runNewEval.task";

-const registeredTasks = [queryLLM];
+console.log("Starting worker");
+
+const registeredTasks = [queryModel, runNewEval];

 const taskList = registeredTasks.reduce((acc, task) => {
   acc[task.task.identifier] = task.task.handler;
   return acc;
 }, {} as TaskList);

-async function main() {
-  // Run a worker to execute jobs:
-  const runner = await run({
-    connectionString: env.DATABASE_URL,
-    concurrency: 20,
-    // Install signal handlers for graceful shutdown on SIGINT, SIGTERM, etc
-    noHandleSignals: false,
-    pollInterval: 1000,
-    // you can set the taskList or taskDirectory but not both
-    taskList,
-    // or:
-    // taskDirectory: `${__dirname}/tasks`,
-  });
-
-  // Immediately await (or otherwise handled) the resulting promise, to avoid
-  // "unhandled rejection" errors causing a process crash in the event of
-  // something going wrong.
-  await runner.promise;
-
-  // If the worker exits (whether through fatal error or otherwise), the above
-  // promise will resolve/reject.
-}
-
-main().catch((err) => {
-  console.error("Unhandled error occurred running worker: ", err);
-  process.exit(1);
-});
+// Run a worker to execute jobs:
+const runner = await run({
+  connectionString: env.DATABASE_URL,
+  concurrency: 50,
+  // Install signal handlers for graceful shutdown on SIGINT, SIGTERM, etc
+  noHandleSignals: false,
+  pollInterval: 1000,
+  taskList,
+});
+
+console.log("Worker successfully started");
+
+await runner.promise;
|||||||
@@ -74,6 +74,11 @@ const requestUpdatedPromptFunction = async (
|
|||||||
2,
|
2,
|
||||||
)}`,
|
)}`,
|
||||||
});
|
});
|
||||||
|
} else {
|
||||||
|
messages.push({
|
||||||
|
role: "user",
|
||||||
|
content: `The provider is the same as the old provider: ${originalModel.provider}`,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (instructions) {
|
if (instructions) {
|
||||||
|
|||||||
@@ -1,19 +1,25 @@
-import { type ModelOutput, type Evaluation } from "@prisma/client";
+import { type ModelResponse, type Evaluation, Prisma } from "@prisma/client";
 import { prisma } from "../db";
 import { runOneEval } from "./runOneEval";
 import { type Scenario } from "~/components/OutputsTable/types";
+import { type SupportedProvider } from "~/modelProviders/types";

-const saveResult = async (evaluation: Evaluation, scenario: Scenario, modelOutput: ModelOutput) => {
-  const result = await runOneEval(evaluation, scenario, modelOutput);
+const runAndSaveEval = async (
+  evaluation: Evaluation,
+  scenario: Scenario,
+  modelResponse: ModelResponse,
+  provider: SupportedProvider,
+) => {
+  const result = await runOneEval(evaluation, scenario, modelResponse, provider);
   return await prisma.outputEvaluation.upsert({
     where: {
-      modelOutputId_evaluationId: {
-        modelOutputId: modelOutput.id,
+      modelResponseId_evaluationId: {
+        modelResponseId: modelResponse.id,
         evaluationId: evaluation.id,
       },
     },
     create: {
-      modelOutputId: modelOutput.id,
+      modelResponseId: modelResponse.id,
       evaluationId: evaluation.id,
       ...result,
     },
@@ -26,20 +32,28 @@ const saveResult = async (evaluation: Evaluation, scenario: Scenario, modelOutpu
 export const runEvalsForOutput = async (
   experimentId: string,
   scenario: Scenario,
-  modelOutput: ModelOutput,
+  modelResponse: ModelResponse,
+  provider: SupportedProvider,
 ) => {
   const evaluations = await prisma.evaluation.findMany({
     where: { experimentId },
   });

   await Promise.all(
-    evaluations.map(async (evaluation) => await saveResult(evaluation, scenario, modelOutput)),
+    evaluations.map(
+      async (evaluation) => await runAndSaveEval(evaluation, scenario, modelResponse, provider),
+    ),
   );
 };

+// Will not run eval-output pairs that already exist in the database
 export const runAllEvals = async (experimentId: string) => {
-  const outputs = await prisma.modelOutput.findMany({
+  const outputs = await prisma.modelResponse.findMany({
     where: {
+      outdated: false,
+      output: {
+        not: Prisma.AnyNull,
+      },
       scenarioVariantCell: {
         promptVariant: {
           experimentId,
@@ -54,9 +68,10 @@ export const runAllEvals = async (experimentId: string) => {
       scenarioVariantCell: {
         include: {
           testScenario: true,
+          promptVariant: true,
         },
       },
-      outputEvaluation: true,
+      outputEvaluations: true,
     },
   });
   const evals = await prisma.evaluation.findMany({
@@ -65,13 +80,18 @@ export const runAllEvals = async (experimentId: string) => {

   await Promise.all(
     outputs.map(async (output) => {
-      const unrunEvals = evals.filter(
-        (evaluation) => !output.outputEvaluation.find((e) => e.evaluationId === evaluation.id),
+      const evalsToBeRun = evals.filter(
+        (evaluation) => !output.outputEvaluations.find((e) => e.evaluationId === evaluation.id),
       );

       await Promise.all(
-        unrunEvals.map(async (evaluation) => {
-          await saveResult(evaluation, output.scenarioVariantCell.testScenario, output);
+        evalsToBeRun.map(async (evaluation) => {
+          await runAndSaveEval(
+            evaluation,
+            output.scenarioVariantCell.testScenario,
+            output,
+            output.scenarioVariantCell.promptVariant.modelProvider as SupportedProvider,
+          );
         }),
       );
     }),
@@ -1,12 +1,18 @@
-import { type Prisma } from "@prisma/client";
+import { Prisma } from "@prisma/client";
 import { prisma } from "../db";
-import { queueLLMRetrievalTask } from "./queueLLMRetrievalTask";
 import parseConstructFn from "./parseConstructFn";
 import { type JsonObject } from "type-fest";
 import hashPrompt from "./hashPrompt";
 import { omit } from "lodash-es";
+import { queueQueryModel } from "../tasks/queryModel.task";

-export const generateNewCell = async (variantId: string, scenarioId: string): Promise<void> => {
+export const generateNewCell = async (
+  variantId: string,
+  scenarioId: string,
+  options?: { stream?: boolean },
+): Promise<void> => {
+  const stream = options?.stream ?? false;
+
   const variant = await prisma.promptVariant.findUnique({
     where: {
       id: variantId,
@@ -29,7 +35,7 @@ export const generateNewCell = async (variantId: string, scenarioId: string): Pr
     },
   },
   include: {
-      modelOutput: true,
+      modelResponses: true,
   },
 });

@@ -45,8 +51,6 @@ export const generateNewCell = async (variantId: string, scenarioId: string): Pr
       data: {
         promptVariantId: variantId,
         testScenarioId: scenarioId,
-        statusCode: 400,
-        errorMessage: parsedConstructFn.error,
         retrievalStatus: "ERROR",
       },
     });
@@ -63,41 +67,60 @@ export const generateNewCell = async (variantId: string, scenarioId: string): Pr
       retrievalStatus: "PENDING",
     },
     include: {
-      modelOutput: true,
+      modelResponses: true,
     },
   });

-  const matchingModelOutput = await prisma.modelOutput.findFirst({
-    where: { inputHash },
+  const matchingModelResponse = await prisma.modelResponse.findFirst({
+    where: {
+      inputHash,
+      output: {
+        not: Prisma.AnyNull,
+      },
+    },
+    orderBy: {
+      receivedAt: "desc",
+    },
+    include: {
+      scenarioVariantCell: true,
+    },
+    take: 1,
   });

-  if (matchingModelOutput) {
-    const newModelOutput = await prisma.modelOutput.create({
+  if (matchingModelResponse) {
+    const newModelResponse = await prisma.modelResponse.create({
       data: {
-        ...omit(matchingModelOutput, ["id"]),
+        ...omit(matchingModelResponse, ["id", "scenarioVariantCell"]),
         scenarioVariantCellId: cell.id,
-        output: matchingModelOutput.output as Prisma.InputJsonValue,
+        output: matchingModelResponse.output as Prisma.InputJsonValue,
       },
     });

     await prisma.scenarioVariantCell.update({
       where: { id: cell.id },
-      data: { retrievalStatus: "COMPLETE" },
+      data: {
+        retrievalStatus: "COMPLETE",
+        jobStartedAt: matchingModelResponse.scenarioVariantCell.jobStartedAt,
+        jobQueuedAt: matchingModelResponse.scenarioVariantCell.jobQueuedAt,
+      },
     });

     // Copy over all eval results as well
     await Promise.all(
       (
-        await prisma.outputEvaluation.findMany({ where: { modelOutputId: matchingModelOutput.id } })
+        await prisma.outputEvaluation.findMany({
+          where: { modelResponseId: matchingModelResponse.id },
+        })
       ).map(async (evaluation) => {
         await prisma.outputEvaluation.create({
           data: {
             ...omit(evaluation, ["id"]),
-            modelOutputId: newModelOutput.id,
+            modelResponseId: newModelResponse.id,
           },
         });
       }),
     );
   } else {
-    cell = await queueLLMRetrievalTask(cell.id);
+    await queueQueryModel(cell.id, stream);
   }
 };
@@ -1,22 +0,0 @@
-import { prisma } from "../db";
-import { queryLLM } from "../tasks/queryLLM.task";
-
-export const queueLLMRetrievalTask = async (cellId: string) => {
-  const updatedCell = await prisma.scenarioVariantCell.update({
-    where: {
-      id: cellId,
-    },
-    data: {
-      retrievalStatus: "PENDING",
-      errorMessage: null,
-    },
-    include: {
-      modelOutput: true,
-    },
-  });
-
-  // @ts-expect-error we aren't passing the helpers but that's ok
-  void queryLLM.task.handler({ scenarioVariantCellId: cellId }, { logger: console });
-
-  return updatedCell;
-};
@@ -1,13 +1,14 @@
-import { type Evaluation, type ModelOutput, type TestScenario } from "@prisma/client";
-import { type ChatCompletion } from "openai/resources/chat";
+import { type Evaluation, type ModelResponse, type TestScenario } from "@prisma/client";
 import { type VariableMap, fillTemplate, escapeRegExp, escapeQuotes } from "./fillTemplate";
 import { openai } from "./openai";
 import dedent from "dedent";
+import modelProviders from "~/modelProviders/modelProviders";
+import { type SupportedProvider } from "~/modelProviders/types";

 export const runGpt4Eval = async (
   evaluation: Evaluation,
   scenario: TestScenario,
-  message: ChatCompletion.Choice.Message,
+  stringifiedOutput: string,
 ): Promise<{ result: number; details: string }> => {
   const output = await openai.chat.completions.create({
     model: "gpt-4-0613",
@@ -26,11 +27,7 @@ export const runGpt4Eval = async (
       },
       {
         role: "user",
-        content: `The full output of the simpler message:\n---\n${JSON.stringify(
-          message.content ?? message.function_call,
-          null,
-          2,
-        )}`,
+        content: `The full output of the simpler message:\n---\n${stringifiedOutput}`,
       },
     ],
     function_call: {
@@ -70,15 +67,16 @@ export const runGpt4Eval = async (
 export const runOneEval = async (
   evaluation: Evaluation,
   scenario: TestScenario,
-  modelOutput: ModelOutput,
+  modelResponse: ModelResponse,
+  provider: SupportedProvider,
 ): Promise<{ result: number; details?: string }> => {
-  const output = modelOutput.output as unknown as ChatCompletion;
-  const message = output?.choices?.[0]?.message;
+  const modelProvider = modelProviders[provider];
+  const message = modelProvider.normalizeOutput(modelResponse.output);

   if (!message) return { result: 0 };

-  const stringifiedMessage = message.content ?? JSON.stringify(message.function_call);
+  const stringifiedOutput =
+    message.type === "json" ? JSON.stringify(message.value, null, 2) : message.value;

   const matchRegex = escapeRegExp(
     fillTemplate(escapeQuotes(evaluation.value), scenario.variableValues as VariableMap),
@@ -86,10 +84,10 @@ export const runOneEval = async (

   switch (evaluation.evalType) {
     case "CONTAINS":
-      return { result: stringifiedMessage.match(matchRegex) !== null ? 1 : 0 };
+      return { result: stringifiedOutput.match(matchRegex) !== null ? 1 : 0 };
     case "DOES_NOT_CONTAIN":
-      return { result: stringifiedMessage.match(matchRegex) === null ? 1 : 0 };
+      return { result: stringifiedOutput.match(matchRegex) === null ? 1 : 0 };
     case "GPT4_EVAL":
-      return await runGpt4Eval(evaluation, scenario, message);
+      return await runGpt4Eval(evaluation, scenario, stringifiedOutput);
   }
 };
@@ -8,7 +8,7 @@ export default async function userOrg(userId: string) {
     update: {},
     create: {
       personalOrgUserId: userId,
-      OrganizationUser: {
+      organizationUsers: {
         create: {
           userId: userId,
           role: "ADMIN",
@@ -8,9 +8,9 @@ export const editorBackground = "#fafafa";
 export type SharedVariantEditorSlice = {
   monaco: null | ReturnType<typeof loader.__getMonacoInstance>;
   loadMonaco: () => Promise<void>;
-  scenarios: RouterOutputs["scenarios"]["list"];
+  scenarios: RouterOutputs["scenarios"]["list"]["scenarios"];
   updateScenariosModel: () => void;
-  setScenarios: (scenarios: RouterOutputs["scenarios"]["list"]) => void;
+  setScenarios: (scenarios: RouterOutputs["scenarios"]["list"]["scenarios"]) => void;
 };

 export const createVariantEditorSlice: SliceCreator<SharedVariantEditorSlice> = (set, get) => ({
@@ -1,17 +1,14 @@
 import { useEffect } from "react";
 import { api } from "~/utils/api";
-import { useExperiment } from "~/utils/hooks";
+import { useScenarios } from "~/utils/hooks";
 import { useAppStore } from "./store";

 export function useSyncVariantEditor() {
-  const experiment = useExperiment();
-  const scenarios = api.scenarios.list.useQuery(
-    { experimentId: experiment.data?.id ?? "" },
-    { enabled: !!experiment.data?.id },
-  );
+  const scenarios = useScenarios();
+
   useEffect(() => {
     if (scenarios.data) {
-      useAppStore.getState().sharedVariantEditor.setScenarios(scenarios.data);
+      useAppStore.getState().sharedVariantEditor.setScenarios(scenarios.data.scenarios);
     }
   }, [scenarios.data]);
 }
@@ -22,7 +22,7 @@ export const canModifyExperiment = async (experimentId: string, userId: string)
     where: {
       id: experimentId,
       organization: {
-        OrganizationUser: {
+        organizationUsers: {
           some: {
             role: { in: [OrganizationUserRole.ADMIN, OrganizationUserRole.MEMBER] },
             userId,
@@ -1,5 +0,0 @@
-// generate random channel id
-
-export const generateChannel = () => {
-  return Math.random().toString(36).substring(2, 15) + Math.random().toString(36).substring(2, 15);
-};
@@ -1,6 +1,7 @@
 import { useRouter } from "next/router";
 import { type RefObject, useCallback, useEffect, useRef, useState } from "react";
 import { api } from "~/utils/api";
+import { NumberParam, useQueryParam, withDefault } from "use-query-params";

 export const useExperiment = () => {
   const router = useRouter();
@@ -93,3 +94,17 @@ export const useElementDimensions = (): [RefObject<HTMLElement>, Dimensions | un

   return [ref, dimensions];
 };
+
+export const usePage = () => useQueryParam("page", withDefault(NumberParam, 1));
+
+export const useScenarios = () => {
+  const experiment = useExperiment();
+  const [page] = usePage();
+
+  return api.scenarios.list.useQuery(
+    { experimentId: experiment.data?.id ?? "", page },
+    { enabled: experiment.data?.id != null },
+  );
+};
+
+export const useVisibleScenarioIds = () => useScenarios().data?.scenarios.map((s) => s.id) ?? [];
@@ -1,5 +1,12 @@
-import { type Model } from "~/modelProviders/types";
+import frontendModelProviders from "~/modelProviders/frontendModelProviders";
+import { type ProviderModel } from "~/modelProviders/types";

 export const truthyFilter = <T>(x: T | null | undefined): x is T => Boolean(x);

-export const keyForModel = (model: Model) => `${model.provider}/${model.name}`;
+export const lookupModel = (provider: string, model: string) => {
+  const modelObj = frontendModelProviders[provider as ProviderModel["provider"]]?.models[model];
+  return modelObj ? { ...modelObj, provider } : null;
+};
+
+export const modelLabel = (provider: string, model: string) =>
+  `${provider}/${lookupModel(provider, model)?.name ?? model}`;