Compare commits
19 Commits
fullscreen
...
scenario-s
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2bffb03766 | ||
|
|
fa61c9c472 | ||
|
|
1309a6ec5d | ||
|
|
17a6fd31a5 | ||
|
|
e1cbeccb90 | ||
|
|
d6b97b29f7 | ||
|
|
09140f8b5f | ||
|
|
9952dd93d8 | ||
|
|
e0b457c6c5 | ||
|
|
0c37506975 | ||
|
|
2b2e0ab8ee | ||
|
|
3dbb06ec00 | ||
|
|
85d42a014b | ||
|
|
7d1ded3b18 | ||
|
|
b00f6dd04b | ||
|
|
2e395e4d39 | ||
|
|
4b06d05908 | ||
|
|
aabf355b81 | ||
|
|
61e5f0775d |
5
.vscode/settings.json
vendored
5
.vscode/settings.json
vendored
@@ -1,6 +1,3 @@
|
|||||||
{
|
{
|
||||||
"eslint.format.enable": true,
|
"eslint.format.enable": true
|
||||||
"editor.codeActionsOnSave": {
|
|
||||||
"source.fixAll.eslint": true
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,7 +12,7 @@
|
|||||||
"dev:next": "next dev",
|
"dev:next": "next dev",
|
||||||
"dev:wss": "pnpm tsx --watch src/wss-server.ts",
|
"dev:wss": "pnpm tsx --watch src/wss-server.ts",
|
||||||
"dev:worker": "NODE_ENV='development' pnpm tsx --watch src/server/tasks/worker.ts",
|
"dev:worker": "NODE_ENV='development' pnpm tsx --watch src/server/tasks/worker.ts",
|
||||||
"dev": "concurrently --kill-others 'pnpm dev:next' 'pnpm dev:wss'",
|
"dev": "concurrently --kill-others 'pnpm dev:next' 'pnpm dev:wss' 'pnpm dev:worker'",
|
||||||
"postinstall": "prisma generate",
|
"postinstall": "prisma generate",
|
||||||
"lint": "next lint",
|
"lint": "next lint",
|
||||||
"start": "next start",
|
"start": "next start",
|
||||||
@@ -59,6 +59,7 @@
|
|||||||
"lodash-es": "^4.17.21",
|
"lodash-es": "^4.17.21",
|
||||||
"next": "^13.4.2",
|
"next": "^13.4.2",
|
||||||
"next-auth": "^4.22.1",
|
"next-auth": "^4.22.1",
|
||||||
|
"next-query-params": "^4.2.3",
|
||||||
"nextjs-routes": "^2.0.1",
|
"nextjs-routes": "^2.0.1",
|
||||||
"openai": "4.0.0-beta.2",
|
"openai": "4.0.0-beta.2",
|
||||||
"pluralize": "^8.0.0",
|
"pluralize": "^8.0.0",
|
||||||
@@ -79,6 +80,8 @@
|
|||||||
"superjson": "1.12.2",
|
"superjson": "1.12.2",
|
||||||
"tsx": "^3.12.7",
|
"tsx": "^3.12.7",
|
||||||
"type-fest": "^4.0.0",
|
"type-fest": "^4.0.0",
|
||||||
|
"use-query-params": "^2.2.1",
|
||||||
|
"uuid": "^9.0.0",
|
||||||
"vite-tsconfig-paths": "^4.2.0",
|
"vite-tsconfig-paths": "^4.2.0",
|
||||||
"zod": "^3.21.4",
|
"zod": "^3.21.4",
|
||||||
"zustand": "^4.3.9"
|
"zustand": "^4.3.9"
|
||||||
@@ -99,6 +102,7 @@
|
|||||||
"@types/react": "^18.2.6",
|
"@types/react": "^18.2.6",
|
||||||
"@types/react-dom": "^18.2.4",
|
"@types/react-dom": "^18.2.4",
|
||||||
"@types/react-syntax-highlighter": "^15.5.7",
|
"@types/react-syntax-highlighter": "^15.5.7",
|
||||||
|
"@types/uuid": "^9.0.2",
|
||||||
"@typescript-eslint/eslint-plugin": "^5.59.6",
|
"@typescript-eslint/eslint-plugin": "^5.59.6",
|
||||||
"@typescript-eslint/parser": "^5.59.6",
|
"@typescript-eslint/parser": "^5.59.6",
|
||||||
"eslint": "^8.40.0",
|
"eslint": "^8.40.0",
|
||||||
|
|||||||
58
pnpm-lock.yaml
generated
58
pnpm-lock.yaml
generated
@@ -1,4 +1,4 @@
|
|||||||
lockfileVersion: '6.1'
|
lockfileVersion: '6.0'
|
||||||
|
|
||||||
settings:
|
settings:
|
||||||
autoInstallPeers: true
|
autoInstallPeers: true
|
||||||
@@ -119,6 +119,9 @@ dependencies:
|
|||||||
next-auth:
|
next-auth:
|
||||||
specifier: ^4.22.1
|
specifier: ^4.22.1
|
||||||
version: 4.22.1(next@13.4.2)(react-dom@18.2.0)(react@18.2.0)
|
version: 4.22.1(next@13.4.2)(react-dom@18.2.0)(react@18.2.0)
|
||||||
|
next-query-params:
|
||||||
|
specifier: ^4.2.3
|
||||||
|
version: 4.2.3(next@13.4.2)(react@18.2.0)(use-query-params@2.2.1)
|
||||||
nextjs-routes:
|
nextjs-routes:
|
||||||
specifier: ^2.0.1
|
specifier: ^2.0.1
|
||||||
version: 2.0.1(next@13.4.2)
|
version: 2.0.1(next@13.4.2)
|
||||||
@@ -179,6 +182,12 @@ dependencies:
|
|||||||
type-fest:
|
type-fest:
|
||||||
specifier: ^4.0.0
|
specifier: ^4.0.0
|
||||||
version: 4.0.0
|
version: 4.0.0
|
||||||
|
use-query-params:
|
||||||
|
specifier: ^2.2.1
|
||||||
|
version: 2.2.1(react-dom@18.2.0)(react@18.2.0)
|
||||||
|
uuid:
|
||||||
|
specifier: ^9.0.0
|
||||||
|
version: 9.0.0
|
||||||
vite-tsconfig-paths:
|
vite-tsconfig-paths:
|
||||||
specifier: ^4.2.0
|
specifier: ^4.2.0
|
||||||
version: 4.2.0(typescript@5.0.4)
|
version: 4.2.0(typescript@5.0.4)
|
||||||
@@ -235,6 +244,9 @@ devDependencies:
|
|||||||
'@types/react-syntax-highlighter':
|
'@types/react-syntax-highlighter':
|
||||||
specifier: ^15.5.7
|
specifier: ^15.5.7
|
||||||
version: 15.5.7
|
version: 15.5.7
|
||||||
|
'@types/uuid':
|
||||||
|
specifier: ^9.0.2
|
||||||
|
version: 9.0.2
|
||||||
'@typescript-eslint/eslint-plugin':
|
'@typescript-eslint/eslint-plugin':
|
||||||
specifier: ^5.59.6
|
specifier: ^5.59.6
|
||||||
version: 5.59.6(@typescript-eslint/parser@5.59.6)(eslint@8.40.0)(typescript@5.0.4)
|
version: 5.59.6(@typescript-eslint/parser@5.59.6)(eslint@8.40.0)(typescript@5.0.4)
|
||||||
@@ -3018,6 +3030,10 @@ packages:
|
|||||||
resolution: {integrity: sha512-cputDpIbFgLUaGQn6Vqg3/YsJwxUwHLO13v3i5ouxT4lat0khip9AEWxtERujXV9wxIB1EyF97BSJFt6vpdI8g==}
|
resolution: {integrity: sha512-cputDpIbFgLUaGQn6Vqg3/YsJwxUwHLO13v3i5ouxT4lat0khip9AEWxtERujXV9wxIB1EyF97BSJFt6vpdI8g==}
|
||||||
dev: false
|
dev: false
|
||||||
|
|
||||||
|
/@types/uuid@9.0.2:
|
||||||
|
resolution: {integrity: sha512-kNnC1GFBLuhImSnV7w4njQkUiJi0ZXUycu1rUaouPqiKlXkh77JKgdRnTAp1x5eBwcIwbtI+3otwzuIDEuDoxQ==}
|
||||||
|
dev: true
|
||||||
|
|
||||||
/@typescript-eslint/eslint-plugin@5.59.6(@typescript-eslint/parser@5.59.6)(eslint@8.40.0)(typescript@5.0.4):
|
/@typescript-eslint/eslint-plugin@5.59.6(@typescript-eslint/parser@5.59.6)(eslint@8.40.0)(typescript@5.0.4):
|
||||||
resolution: {integrity: sha512-sXtOgJNEuRU5RLwPUb1jxtToZbgvq3M6FPpY4QENxoOggK+UpTxUBpj6tD8+Qh2g46Pi9We87E+eHnUw8YcGsw==}
|
resolution: {integrity: sha512-sXtOgJNEuRU5RLwPUb1jxtToZbgvq3M6FPpY4QENxoOggK+UpTxUBpj6tD8+Qh2g46Pi9We87E+eHnUw8YcGsw==}
|
||||||
engines: {node: ^12.22.0 || ^14.17.0 || >=16.0.0}
|
engines: {node: ^12.22.0 || ^14.17.0 || >=16.0.0}
|
||||||
@@ -6037,6 +6053,19 @@ packages:
|
|||||||
uuid: 8.3.2
|
uuid: 8.3.2
|
||||||
dev: false
|
dev: false
|
||||||
|
|
||||||
|
/next-query-params@4.2.3(next@13.4.2)(react@18.2.0)(use-query-params@2.2.1):
|
||||||
|
resolution: {integrity: sha512-hGNCYRH8YyA5ItiBGSKrtMl21b2MAqfPkdI1mvwloNVqSU142IaGzqHN+OTovyeLIpQfonY01y7BAHb/UH4POg==}
|
||||||
|
peerDependencies:
|
||||||
|
next: ^10.0.0 || ^11.0.0 || ^12.0.0 || ^13.0.0
|
||||||
|
react: ^16.8.0 || ^17.0.0 || ^18.0.0
|
||||||
|
use-query-params: ^2.0.0
|
||||||
|
dependencies:
|
||||||
|
next: 13.4.2(@babel/core@7.22.9)(react-dom@18.2.0)(react@18.2.0)
|
||||||
|
react: 18.2.0
|
||||||
|
tslib: 2.6.0
|
||||||
|
use-query-params: 2.2.1(react-dom@18.2.0)(react@18.2.0)
|
||||||
|
dev: false
|
||||||
|
|
||||||
/next-tick@1.1.0:
|
/next-tick@1.1.0:
|
||||||
resolution: {integrity: sha512-CXdUiJembsNjuToQvxayPZF9Vqht7hewsvy2sOWafLvi2awflj9mOC6bHIg50orX8IJvWKY9wYQ/zB2kogPslQ==}
|
resolution: {integrity: sha512-CXdUiJembsNjuToQvxayPZF9Vqht7hewsvy2sOWafLvi2awflj9mOC6bHIg50orX8IJvWKY9wYQ/zB2kogPslQ==}
|
||||||
dev: false
|
dev: false
|
||||||
@@ -7147,6 +7176,10 @@ packages:
|
|||||||
randombytes: 2.1.0
|
randombytes: 2.1.0
|
||||||
dev: true
|
dev: true
|
||||||
|
|
||||||
|
/serialize-query-params@2.0.2:
|
||||||
|
resolution: {integrity: sha512-1chMo1dST4pFA9RDXAtF0Rbjaut4is7bzFbI1Z26IuMub68pNCILku85aYmeFhvnY//BXUPUhoRMjYcsT93J/Q==}
|
||||||
|
dev: false
|
||||||
|
|
||||||
/serve-static@1.15.0:
|
/serve-static@1.15.0:
|
||||||
resolution: {integrity: sha512-XGuRDNjXUijsUL0vl6nSD7cwURuzEgglbOaFuZM9g3kwDXOWVTck0jLzjPzGD+TazWbboZYu52/9/XPdUgne9g==}
|
resolution: {integrity: sha512-XGuRDNjXUijsUL0vl6nSD7cwURuzEgglbOaFuZM9g3kwDXOWVTck0jLzjPzGD+TazWbboZYu52/9/XPdUgne9g==}
|
||||||
engines: {node: '>= 0.8.0'}
|
engines: {node: '>= 0.8.0'}
|
||||||
@@ -7824,6 +7857,24 @@ packages:
|
|||||||
use-isomorphic-layout-effect: 1.1.2(@types/react@18.2.6)(react@18.2.0)
|
use-isomorphic-layout-effect: 1.1.2(@types/react@18.2.6)(react@18.2.0)
|
||||||
dev: false
|
dev: false
|
||||||
|
|
||||||
|
/use-query-params@2.2.1(react-dom@18.2.0)(react@18.2.0):
|
||||||
|
resolution: {integrity: sha512-i6alcyLB8w9i3ZK3caNftdb+UnbfBRNPDnc89CNQWkGRmDrm/gfydHvMBfVsQJRq3NoHOM2dt/ceBWG2397v1Q==}
|
||||||
|
peerDependencies:
|
||||||
|
'@reach/router': ^1.2.1
|
||||||
|
react: '>=16.8.0'
|
||||||
|
react-dom: '>=16.8.0'
|
||||||
|
react-router-dom: '>=5'
|
||||||
|
peerDependenciesMeta:
|
||||||
|
'@reach/router':
|
||||||
|
optional: true
|
||||||
|
react-router-dom:
|
||||||
|
optional: true
|
||||||
|
dependencies:
|
||||||
|
react: 18.2.0
|
||||||
|
react-dom: 18.2.0(react@18.2.0)
|
||||||
|
serialize-query-params: 2.0.2
|
||||||
|
dev: false
|
||||||
|
|
||||||
/use-sidecar@1.1.2(@types/react@18.2.6)(react@18.2.0):
|
/use-sidecar@1.1.2(@types/react@18.2.6)(react@18.2.0):
|
||||||
resolution: {integrity: sha512-epTbsLuzZ7lPClpz2TyryBfztm7m+28DlEv2ZCQ3MDr5ssiwyOwGH/e5F9CkfWjJ1t4clvI58yF822/GUkjjhw==}
|
resolution: {integrity: sha512-epTbsLuzZ7lPClpz2TyryBfztm7m+28DlEv2ZCQ3MDr5ssiwyOwGH/e5F9CkfWjJ1t4clvI58yF822/GUkjjhw==}
|
||||||
engines: {node: '>=10'}
|
engines: {node: '>=10'}
|
||||||
@@ -7872,6 +7923,11 @@ packages:
|
|||||||
hasBin: true
|
hasBin: true
|
||||||
dev: false
|
dev: false
|
||||||
|
|
||||||
|
/uuid@9.0.0:
|
||||||
|
resolution: {integrity: sha512-MXcSTerfPa4uqyzStbRoTgt5XIe3x5+42+q1sDuy3R5MDk66URdLMOZe5aPX/SQd+kuYAh0FdP/pO28IkQyTeg==}
|
||||||
|
hasBin: true
|
||||||
|
dev: false
|
||||||
|
|
||||||
/vary@1.1.2:
|
/vary@1.1.2:
|
||||||
resolution: {integrity: sha512-BNGbWLfd0eUPabhkXUVm0j8uuvREyTh5ovRa/dyow/BqAbZJyC+5fU+IzQOzmAKzYqYRAISoRhdQr3eIZ/PXqg==}
|
resolution: {integrity: sha512-BNGbWLfd0eUPabhkXUVm0j8uuvREyTh5ovRa/dyow/BqAbZJyC+5fU+IzQOzmAKzYqYRAISoRhdQr3eIZ/PXqg==}
|
||||||
engines: {node: '>= 0.8'}
|
engines: {node: '>= 0.8'}
|
||||||
|
|||||||
@@ -0,0 +1,8 @@
|
|||||||
|
/*
|
||||||
|
Warnings:
|
||||||
|
|
||||||
|
- You are about to drop the column `streamingChannel` on the `ScenarioVariantCell` table. All the data in the column will be lost.
|
||||||
|
|
||||||
|
*/
|
||||||
|
-- AlterTable
|
||||||
|
ALTER TABLE "ScenarioVariantCell" DROP COLUMN "streamingChannel";
|
||||||
@@ -22,10 +22,10 @@ model Experiment {
|
|||||||
createdAt DateTime @default(now())
|
createdAt DateTime @default(now())
|
||||||
updatedAt DateTime @updatedAt
|
updatedAt DateTime @updatedAt
|
||||||
|
|
||||||
TemplateVariable TemplateVariable[]
|
templateVariables TemplateVariable[]
|
||||||
PromptVariant PromptVariant[]
|
promptVariants PromptVariant[]
|
||||||
TestScenario TestScenario[]
|
testScenarios TestScenario[]
|
||||||
Evaluation Evaluation[]
|
evaluations Evaluation[]
|
||||||
}
|
}
|
||||||
|
|
||||||
model PromptVariant {
|
model PromptVariant {
|
||||||
@@ -90,11 +90,10 @@ enum CellRetrievalStatus {
|
|||||||
model ScenarioVariantCell {
|
model ScenarioVariantCell {
|
||||||
id String @id @default(uuid()) @db.Uuid
|
id String @id @default(uuid()) @db.Uuid
|
||||||
|
|
||||||
statusCode Int?
|
statusCode Int?
|
||||||
errorMessage String?
|
errorMessage String?
|
||||||
retryTime DateTime?
|
retryTime DateTime?
|
||||||
streamingChannel String?
|
retrievalStatus CellRetrievalStatus @default(COMPLETE)
|
||||||
retrievalStatus CellRetrievalStatus @default(COMPLETE)
|
|
||||||
|
|
||||||
modelOutput ModelOutput?
|
modelOutput ModelOutput?
|
||||||
|
|
||||||
@@ -126,7 +125,7 @@ model ModelOutput {
|
|||||||
|
|
||||||
scenarioVariantCellId String @db.Uuid
|
scenarioVariantCellId String @db.Uuid
|
||||||
scenarioVariantCell ScenarioVariantCell @relation(fields: [scenarioVariantCellId], references: [id], onDelete: Cascade)
|
scenarioVariantCell ScenarioVariantCell @relation(fields: [scenarioVariantCellId], references: [id], onDelete: Cascade)
|
||||||
outputEvaluation OutputEvaluation[]
|
outputEvaluations OutputEvaluation[]
|
||||||
|
|
||||||
@@unique([scenarioVariantCellId])
|
@@unique([scenarioVariantCellId])
|
||||||
@@index([inputHash])
|
@@index([inputHash])
|
||||||
@@ -150,7 +149,7 @@ model Evaluation {
|
|||||||
|
|
||||||
createdAt DateTime @default(now())
|
createdAt DateTime @default(now())
|
||||||
updatedAt DateTime @updatedAt
|
updatedAt DateTime @updatedAt
|
||||||
OutputEvaluation OutputEvaluation[]
|
outputEvaluations OutputEvaluation[]
|
||||||
}
|
}
|
||||||
|
|
||||||
model OutputEvaluation {
|
model OutputEvaluation {
|
||||||
@@ -179,8 +178,8 @@ model Organization {
|
|||||||
|
|
||||||
createdAt DateTime @default(now())
|
createdAt DateTime @default(now())
|
||||||
updatedAt DateTime @updatedAt
|
updatedAt DateTime @updatedAt
|
||||||
OrganizationUser OrganizationUser[]
|
organizationUsers OrganizationUser[]
|
||||||
Experiment Experiment[]
|
experiments Experiment[]
|
||||||
}
|
}
|
||||||
|
|
||||||
enum OrganizationUserRole {
|
enum OrganizationUserRole {
|
||||||
@@ -234,15 +233,15 @@ model Session {
|
|||||||
}
|
}
|
||||||
|
|
||||||
model User {
|
model User {
|
||||||
id String @id @default(uuid()) @db.Uuid
|
id String @id @default(uuid()) @db.Uuid
|
||||||
name String?
|
name String?
|
||||||
email String? @unique
|
email String? @unique
|
||||||
emailVerified DateTime?
|
emailVerified DateTime?
|
||||||
image String?
|
image String?
|
||||||
accounts Account[]
|
accounts Account[]
|
||||||
sessions Session[]
|
sessions Session[]
|
||||||
OrganizationUser OrganizationUser[]
|
organizationUsers OrganizationUser[]
|
||||||
Organization Organization[]
|
organizations Organization[]
|
||||||
}
|
}
|
||||||
|
|
||||||
model VerificationToken {
|
model VerificationToken {
|
||||||
|
|||||||
@@ -7,9 +7,13 @@ const defaultId = "11111111-1111-1111-1111-111111111111";
|
|||||||
await prisma.organization.deleteMany({
|
await prisma.organization.deleteMany({
|
||||||
where: { id: defaultId },
|
where: { id: defaultId },
|
||||||
});
|
});
|
||||||
await prisma.organization.create({
|
|
||||||
data: { id: defaultId },
|
// If there's an existing org, just seed into it
|
||||||
});
|
const org =
|
||||||
|
(await prisma.organization.findFirst({})) ??
|
||||||
|
(await prisma.organization.create({
|
||||||
|
data: { id: defaultId },
|
||||||
|
}));
|
||||||
|
|
||||||
await prisma.experiment.deleteMany({
|
await prisma.experiment.deleteMany({
|
||||||
where: {
|
where: {
|
||||||
@@ -21,7 +25,7 @@ await prisma.experiment.create({
|
|||||||
data: {
|
data: {
|
||||||
id: defaultId,
|
id: defaultId,
|
||||||
label: "Country Capitals Example",
|
label: "Country Capitals Example",
|
||||||
organizationId: defaultId,
|
organizationId: org.id,
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -103,30 +107,41 @@ await prisma.testScenario.deleteMany({
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
|
const countries = [
|
||||||
|
"Afghanistan",
|
||||||
|
"Albania",
|
||||||
|
"Algeria",
|
||||||
|
"Andorra",
|
||||||
|
"Angola",
|
||||||
|
"Antigua and Barbuda",
|
||||||
|
"Argentina",
|
||||||
|
"Armenia",
|
||||||
|
"Australia",
|
||||||
|
"Austria",
|
||||||
|
"Austrian Empire",
|
||||||
|
"Azerbaijan",
|
||||||
|
"Baden",
|
||||||
|
"Bahamas, The",
|
||||||
|
"Bahrain",
|
||||||
|
"Bangladesh",
|
||||||
|
"Barbados",
|
||||||
|
"Bavaria",
|
||||||
|
"Belarus",
|
||||||
|
"Belgium",
|
||||||
|
"Belize",
|
||||||
|
"Benin (Dahomey)",
|
||||||
|
"Bolivia",
|
||||||
|
"Bosnia and Herzegovina",
|
||||||
|
"Botswana",
|
||||||
|
];
|
||||||
await prisma.testScenario.createMany({
|
await prisma.testScenario.createMany({
|
||||||
data: [
|
data: countries.map((country, i) => ({
|
||||||
{
|
experimentId: defaultId,
|
||||||
experimentId: defaultId,
|
sortIndex: i,
|
||||||
sortIndex: 0,
|
variableValues: {
|
||||||
variableValues: {
|
country: country,
|
||||||
country: "Spain",
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
{
|
})),
|
||||||
experimentId: defaultId,
|
|
||||||
sortIndex: 1,
|
|
||||||
variableValues: {
|
|
||||||
country: "USA",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
experimentId: defaultId,
|
|
||||||
sortIndex: 2,
|
|
||||||
variableValues: {
|
|
||||||
country: "Chile",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
],
|
|
||||||
});
|
});
|
||||||
|
|
||||||
const variants = await prisma.promptVariant.findMany({
|
const variants = await prisma.promptVariant.findMany({
|
||||||
@@ -149,5 +164,5 @@ await Promise.all(
|
|||||||
testScenarioId: scenario.id,
|
testScenarioId: scenario.id,
|
||||||
})),
|
})),
|
||||||
)
|
)
|
||||||
.map((cell) => generateNewCell(cell.promptVariantId, cell.testScenarioId)),
|
.map((cell) => generateNewCell(cell.promptVariantId, cell.testScenarioId, { stream: false })),
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -6,4 +6,7 @@ echo "Migrating the database"
|
|||||||
pnpm prisma migrate deploy
|
pnpm prisma migrate deploy
|
||||||
|
|
||||||
echo "Starting the server"
|
echo "Starting the server"
|
||||||
pnpm start
|
|
||||||
|
pnpm concurrently --kill-others \
|
||||||
|
"pnpm start" \
|
||||||
|
"pnpm tsx src/server/tasks/worker.ts"
|
||||||
@@ -1,5 +1,7 @@
|
|||||||
import {
|
import {
|
||||||
Button,
|
Button,
|
||||||
|
HStack,
|
||||||
|
Icon,
|
||||||
Modal,
|
Modal,
|
||||||
ModalBody,
|
ModalBody,
|
||||||
ModalCloseButton,
|
ModalCloseButton,
|
||||||
@@ -7,24 +9,21 @@ import {
|
|||||||
ModalFooter,
|
ModalFooter,
|
||||||
ModalHeader,
|
ModalHeader,
|
||||||
ModalOverlay,
|
ModalOverlay,
|
||||||
VStack,
|
|
||||||
Text,
|
|
||||||
Spinner,
|
Spinner,
|
||||||
HStack,
|
Text,
|
||||||
Icon,
|
VStack,
|
||||||
} from "@chakra-ui/react";
|
} from "@chakra-ui/react";
|
||||||
import { RiExchangeFundsFill } from "react-icons/ri";
|
|
||||||
import { useState } from "react";
|
|
||||||
import { ModelStatsCard } from "./ModelStatsCard";
|
|
||||||
import { ModelSearch } from "./ModelSearch";
|
|
||||||
import { api } from "~/utils/api";
|
|
||||||
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
|
|
||||||
import CompareFunctions from "../RefinePromptModal/CompareFunctions";
|
|
||||||
import { type PromptVariant } from "@prisma/client";
|
import { type PromptVariant } from "@prisma/client";
|
||||||
import { isObject, isString } from "lodash-es";
|
import { isObject, isString } from "lodash-es";
|
||||||
import { type Model, type SupportedProvider } from "~/modelProviders/types";
|
import { useState } from "react";
|
||||||
import frontendModelProviders from "~/modelProviders/frontendModelProviders";
|
import { RiExchangeFundsFill } from "react-icons/ri";
|
||||||
import { keyForModel } from "~/utils/utils";
|
import { type ProviderModel } from "~/modelProviders/types";
|
||||||
|
import { api } from "~/utils/api";
|
||||||
|
import { useExperiment, useHandledAsyncCallback, useVisibleScenarioIds } from "~/utils/hooks";
|
||||||
|
import { lookupModel, modelLabel } from "~/utils/utils";
|
||||||
|
import CompareFunctions from "../RefinePromptModal/CompareFunctions";
|
||||||
|
import { ModelSearch } from "./ModelSearch";
|
||||||
|
import { ModelStatsCard } from "./ModelStatsCard";
|
||||||
|
|
||||||
export const ChangeModelModal = ({
|
export const ChangeModelModal = ({
|
||||||
variant,
|
variant,
|
||||||
@@ -33,11 +32,14 @@ export const ChangeModelModal = ({
|
|||||||
variant: PromptVariant;
|
variant: PromptVariant;
|
||||||
onClose: () => void;
|
onClose: () => void;
|
||||||
}) => {
|
}) => {
|
||||||
const originalModelProviderName = variant.modelProvider as SupportedProvider;
|
const originalModel = lookupModel(variant.modelProvider, variant.model);
|
||||||
const originalModelProvider = frontendModelProviders[originalModelProviderName];
|
const [selectedModel, setSelectedModel] = useState({
|
||||||
const originalModel = originalModelProvider.models[variant.model] as Model;
|
provider: variant.modelProvider,
|
||||||
const [selectedModel, setSelectedModel] = useState<Model>(originalModel);
|
model: variant.model,
|
||||||
const [convertedModel, setConvertedModel] = useState<Model | undefined>(undefined);
|
} as ProviderModel);
|
||||||
|
const [convertedModel, setConvertedModel] = useState<ProviderModel | undefined>();
|
||||||
|
const visibleScenarios = useVisibleScenarioIds();
|
||||||
|
|
||||||
const utils = api.useContext();
|
const utils = api.useContext();
|
||||||
|
|
||||||
const experiment = useExperiment();
|
const experiment = useExperiment();
|
||||||
@@ -67,14 +69,16 @@ export const ChangeModelModal = ({
|
|||||||
await replaceVariantMutation.mutateAsync({
|
await replaceVariantMutation.mutateAsync({
|
||||||
id: variant.id,
|
id: variant.id,
|
||||||
constructFn: modifiedPromptFn,
|
constructFn: modifiedPromptFn,
|
||||||
|
streamScenarios: visibleScenarios,
|
||||||
});
|
});
|
||||||
await utils.promptVariants.list.invalidate();
|
await utils.promptVariants.list.invalidate();
|
||||||
onClose();
|
onClose();
|
||||||
}, [replaceVariantMutation, variant, onClose, modifiedPromptFn]);
|
}, [replaceVariantMutation, variant, onClose, modifiedPromptFn]);
|
||||||
|
|
||||||
const originalModelLabel = keyForModel(originalModel);
|
const originalLabel = modelLabel(variant.modelProvider, variant.model);
|
||||||
const selectedModelLabel = keyForModel(selectedModel);
|
const selectedLabel = modelLabel(selectedModel.provider, selectedModel.model);
|
||||||
const convertedModelLabel = convertedModel ? keyForModel(convertedModel) : undefined;
|
const convertedLabel =
|
||||||
|
convertedModel && modelLabel(convertedModel.provider, convertedModel.model);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Modal
|
<Modal
|
||||||
@@ -94,16 +98,19 @@ export const ChangeModelModal = ({
|
|||||||
<ModalBody maxW="unset">
|
<ModalBody maxW="unset">
|
||||||
<VStack spacing={8}>
|
<VStack spacing={8}>
|
||||||
<ModelStatsCard label="Original Model" model={originalModel} />
|
<ModelStatsCard label="Original Model" model={originalModel} />
|
||||||
{originalModelLabel !== selectedModelLabel && (
|
{originalLabel !== selectedLabel && (
|
||||||
<ModelStatsCard label="New Model" model={selectedModel} />
|
<ModelStatsCard
|
||||||
|
label="New Model"
|
||||||
|
model={lookupModel(selectedModel.provider, selectedModel.model)}
|
||||||
|
/>
|
||||||
)}
|
)}
|
||||||
<ModelSearch selectedModel={selectedModel} setSelectedModel={setSelectedModel} />
|
<ModelSearch selectedModel={selectedModel} setSelectedModel={setSelectedModel} />
|
||||||
{isString(modifiedPromptFn) && (
|
{isString(modifiedPromptFn) && (
|
||||||
<CompareFunctions
|
<CompareFunctions
|
||||||
originalFunction={variant.constructFn}
|
originalFunction={variant.constructFn}
|
||||||
newFunction={modifiedPromptFn}
|
newFunction={modifiedPromptFn}
|
||||||
leftTitle={originalModelLabel}
|
leftTitle={originalLabel}
|
||||||
rightTitle={convertedModelLabel}
|
rightTitle={convertedLabel}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
</VStack>
|
</VStack>
|
||||||
@@ -115,7 +122,7 @@ export const ChangeModelModal = ({
|
|||||||
colorScheme="gray"
|
colorScheme="gray"
|
||||||
onClick={getModifiedPromptFn}
|
onClick={getModifiedPromptFn}
|
||||||
minW={24}
|
minW={24}
|
||||||
isDisabled={originalModel === selectedModel || modificationInProgress}
|
isDisabled={originalLabel === selectedLabel || modificationInProgress}
|
||||||
>
|
>
|
||||||
{modificationInProgress ? <Spinner boxSize={4} /> : <Text>Convert</Text>}
|
{modificationInProgress ? <Spinner boxSize={4} /> : <Text>Convert</Text>}
|
||||||
</Button>
|
</Button>
|
||||||
|
|||||||
@@ -1,49 +1,35 @@
|
|||||||
import { VStack, Text } from "@chakra-ui/react";
|
import { Text, VStack } from "@chakra-ui/react";
|
||||||
import { type LegacyRef, useCallback } from "react";
|
import { type LegacyRef } from "react";
|
||||||
import Select, { type SingleValue } from "react-select";
|
import Select from "react-select";
|
||||||
import { useElementDimensions } from "~/utils/hooks";
|
import { useElementDimensions } from "~/utils/hooks";
|
||||||
|
|
||||||
|
import { flatMap } from "lodash-es";
|
||||||
import frontendModelProviders from "~/modelProviders/frontendModelProviders";
|
import frontendModelProviders from "~/modelProviders/frontendModelProviders";
|
||||||
import { type Model } from "~/modelProviders/types";
|
import { type ProviderModel } from "~/modelProviders/types";
|
||||||
import { keyForModel } from "~/utils/utils";
|
import { modelLabel } from "~/utils/utils";
|
||||||
|
|
||||||
const modelOptions: { label: string; value: Model }[] = [];
|
const modelOptions = flatMap(Object.entries(frontendModelProviders), ([providerId, provider]) =>
|
||||||
|
Object.entries(provider.models).map(([modelId]) => ({
|
||||||
|
provider: providerId,
|
||||||
|
model: modelId,
|
||||||
|
})),
|
||||||
|
) as ProviderModel[];
|
||||||
|
|
||||||
for (const [_, providerValue] of Object.entries(frontendModelProviders)) {
|
export const ModelSearch = (props: {
|
||||||
for (const [_, modelValue] of Object.entries(providerValue.models)) {
|
selectedModel: ProviderModel;
|
||||||
modelOptions.push({
|
setSelectedModel: (model: ProviderModel) => void;
|
||||||
label: keyForModel(modelValue),
|
|
||||||
value: modelValue,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
export const ModelSearch = ({
|
|
||||||
selectedModel,
|
|
||||||
setSelectedModel,
|
|
||||||
}: {
|
|
||||||
selectedModel: Model;
|
|
||||||
setSelectedModel: (model: Model) => void;
|
|
||||||
}) => {
|
}) => {
|
||||||
const handleSelection = useCallback(
|
|
||||||
(option: SingleValue<{ label: string; value: Model }>) => {
|
|
||||||
if (!option) return;
|
|
||||||
setSelectedModel(option.value);
|
|
||||||
},
|
|
||||||
[setSelectedModel],
|
|
||||||
);
|
|
||||||
const selectedOption = modelOptions.find((option) => option.label === keyForModel(selectedModel));
|
|
||||||
|
|
||||||
const [containerRef, containerDimensions] = useElementDimensions();
|
const [containerRef, containerDimensions] = useElementDimensions();
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<VStack ref={containerRef as LegacyRef<HTMLDivElement>} w="full">
|
<VStack ref={containerRef as LegacyRef<HTMLDivElement>} w="full">
|
||||||
<Text>Browse Models</Text>
|
<Text>Browse Models</Text>
|
||||||
<Select
|
<Select<ProviderModel>
|
||||||
styles={{ control: (provided) => ({ ...provided, width: containerDimensions?.width }) }}
|
styles={{ control: (provided) => ({ ...provided, width: containerDimensions?.width }) }}
|
||||||
value={selectedOption}
|
getOptionLabel={(data) => modelLabel(data.provider, data.model)}
|
||||||
|
getOptionValue={(data) => modelLabel(data.provider, data.model)}
|
||||||
options={modelOptions}
|
options={modelOptions}
|
||||||
onChange={handleSelection}
|
onChange={(option) => option && props.setSelectedModel(option)}
|
||||||
/>
|
/>
|
||||||
</VStack>
|
</VStack>
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -1,15 +1,22 @@
|
|||||||
import {
|
import {
|
||||||
VStack,
|
|
||||||
Text,
|
|
||||||
HStack,
|
|
||||||
type StackProps,
|
|
||||||
GridItem,
|
GridItem,
|
||||||
SimpleGrid,
|
HStack,
|
||||||
Link,
|
Link,
|
||||||
|
SimpleGrid,
|
||||||
|
Text,
|
||||||
|
VStack,
|
||||||
|
type StackProps,
|
||||||
} from "@chakra-ui/react";
|
} from "@chakra-ui/react";
|
||||||
import { type Model } from "~/modelProviders/types";
|
import { type lookupModel } from "~/utils/utils";
|
||||||
|
|
||||||
export const ModelStatsCard = ({ label, model }: { label: string; model: Model }) => {
|
export const ModelStatsCard = ({
|
||||||
|
label,
|
||||||
|
model,
|
||||||
|
}: {
|
||||||
|
label: string;
|
||||||
|
model: ReturnType<typeof lookupModel>;
|
||||||
|
}) => {
|
||||||
|
if (!model) return null;
|
||||||
return (
|
return (
|
||||||
<VStack w="full" align="start">
|
<VStack w="full" align="start">
|
||||||
<Text fontWeight="bold" fontSize="sm" textTransform="uppercase">
|
<Text fontWeight="bold" fontSize="sm" textTransform="uppercase">
|
||||||
|
|||||||
77
src/components/ExperimentSettingsDrawer/DeleteButton.tsx
Normal file
77
src/components/ExperimentSettingsDrawer/DeleteButton.tsx
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
import {
|
||||||
|
Button,
|
||||||
|
Icon,
|
||||||
|
AlertDialog,
|
||||||
|
AlertDialogBody,
|
||||||
|
AlertDialogFooter,
|
||||||
|
AlertDialogHeader,
|
||||||
|
AlertDialogContent,
|
||||||
|
AlertDialogOverlay,
|
||||||
|
useDisclosure,
|
||||||
|
Text,
|
||||||
|
} from "@chakra-ui/react";
|
||||||
|
|
||||||
|
import { useRouter } from "next/router";
|
||||||
|
import { useRef } from "react";
|
||||||
|
import { BsTrash } from "react-icons/bs";
|
||||||
|
import { api } from "~/utils/api";
|
||||||
|
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
|
||||||
|
|
||||||
|
export const DeleteButton = () => {
|
||||||
|
const experiment = useExperiment();
|
||||||
|
const mutation = api.experiments.delete.useMutation();
|
||||||
|
const utils = api.useContext();
|
||||||
|
const router = useRouter();
|
||||||
|
|
||||||
|
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||||
|
const cancelRef = useRef<HTMLButtonElement>(null);
|
||||||
|
|
||||||
|
const [onDeleteConfirm] = useHandledAsyncCallback(async () => {
|
||||||
|
if (!experiment.data?.id) return;
|
||||||
|
await mutation.mutateAsync({ id: experiment.data.id });
|
||||||
|
await utils.experiments.list.invalidate();
|
||||||
|
await router.push({ pathname: "/experiments" });
|
||||||
|
onClose();
|
||||||
|
}, [mutation, experiment.data?.id, router]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
<Button
|
||||||
|
size="sm"
|
||||||
|
variant={{ base: "outline", lg: "ghost" }}
|
||||||
|
colorScheme="red"
|
||||||
|
fontWeight="normal"
|
||||||
|
onClick={onOpen}
|
||||||
|
>
|
||||||
|
<Icon as={BsTrash} boxSize={4} />
|
||||||
|
<Text display={{ base: "none", lg: "block" }} ml={2}>
|
||||||
|
Delete Experiment
|
||||||
|
</Text>
|
||||||
|
</Button>
|
||||||
|
|
||||||
|
<AlertDialog isOpen={isOpen} leastDestructiveRef={cancelRef} onClose={onClose}>
|
||||||
|
<AlertDialogOverlay>
|
||||||
|
<AlertDialogContent>
|
||||||
|
<AlertDialogHeader fontSize="lg" fontWeight="bold">
|
||||||
|
Delete Experiment
|
||||||
|
</AlertDialogHeader>
|
||||||
|
|
||||||
|
<AlertDialogBody>
|
||||||
|
If you delete this experiment all the associated prompts and scenarios will be deleted
|
||||||
|
as well. Are you sure?
|
||||||
|
</AlertDialogBody>
|
||||||
|
|
||||||
|
<AlertDialogFooter>
|
||||||
|
<Button ref={cancelRef} onClick={onClose}>
|
||||||
|
Cancel
|
||||||
|
</Button>
|
||||||
|
<Button colorScheme="red" onClick={onDeleteConfirm} ml={3}>
|
||||||
|
Delete
|
||||||
|
</Button>
|
||||||
|
</AlertDialogFooter>
|
||||||
|
</AlertDialogContent>
|
||||||
|
</AlertDialogOverlay>
|
||||||
|
</AlertDialog>
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
};
|
||||||
@@ -6,13 +6,14 @@ import {
|
|||||||
DrawerHeader,
|
DrawerHeader,
|
||||||
DrawerOverlay,
|
DrawerOverlay,
|
||||||
Heading,
|
Heading,
|
||||||
Stack,
|
VStack,
|
||||||
} from "@chakra-ui/react";
|
} from "@chakra-ui/react";
|
||||||
import EditScenarioVars from "./EditScenarioVars";
|
import EditScenarioVars from "../OutputsTable/EditScenarioVars";
|
||||||
import EditEvaluations from "./EditEvaluations";
|
import EditEvaluations from "../OutputsTable/EditEvaluations";
|
||||||
import { useAppStore } from "~/state/store";
|
import { useAppStore } from "~/state/store";
|
||||||
|
import { DeleteButton } from "./DeleteButton";
|
||||||
|
|
||||||
export default function SettingsDrawer() {
|
export default function ExperimentSettingsDrawer() {
|
||||||
const isOpen = useAppStore((state) => state.drawerOpen);
|
const isOpen = useAppStore((state) => state.drawerOpen);
|
||||||
const closeDrawer = useAppStore((state) => state.closeDrawer);
|
const closeDrawer = useAppStore((state) => state.closeDrawer);
|
||||||
|
|
||||||
@@ -22,13 +23,16 @@ export default function SettingsDrawer() {
|
|||||||
<DrawerContent>
|
<DrawerContent>
|
||||||
<DrawerCloseButton />
|
<DrawerCloseButton />
|
||||||
<DrawerHeader>
|
<DrawerHeader>
|
||||||
<Heading size="md">Settings</Heading>
|
<Heading size="md">Experiment Settings</Heading>
|
||||||
</DrawerHeader>
|
</DrawerHeader>
|
||||||
<DrawerBody>
|
<DrawerBody h="full" pb={4}>
|
||||||
<Stack spacing={6}>
|
<VStack h="full" justifyContent="space-between">
|
||||||
<EditScenarioVars />
|
<VStack spacing={6}>
|
||||||
<EditEvaluations />
|
<EditScenarioVars />
|
||||||
</Stack>
|
<EditEvaluations />
|
||||||
|
</VStack>
|
||||||
|
<DeleteButton />
|
||||||
|
</VStack>
|
||||||
</DrawerBody>
|
</DrawerBody>
|
||||||
</DrawerContent>
|
</DrawerContent>
|
||||||
</Drawer>
|
</Drawer>
|
||||||
57
src/components/OutputsTable/AddVariantButton.tsx
Normal file
57
src/components/OutputsTable/AddVariantButton.tsx
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
import { Box, Flex, Icon, Spinner } from "@chakra-ui/react";
|
||||||
|
import { BsPlus } from "react-icons/bs";
|
||||||
|
import { Text } from "@chakra-ui/react";
|
||||||
|
import { api } from "~/utils/api";
|
||||||
|
import {
|
||||||
|
useExperiment,
|
||||||
|
useExperimentAccess,
|
||||||
|
useHandledAsyncCallback,
|
||||||
|
useVisibleScenarioIds,
|
||||||
|
} from "~/utils/hooks";
|
||||||
|
import { cellPadding } from "../constants";
|
||||||
|
import { ActionButton } from "./ScenariosHeader";
|
||||||
|
|
||||||
|
export default function AddVariantButton() {
|
||||||
|
const experiment = useExperiment();
|
||||||
|
const mutation = api.promptVariants.create.useMutation();
|
||||||
|
const utils = api.useContext();
|
||||||
|
const visibleScenarios = useVisibleScenarioIds();
|
||||||
|
|
||||||
|
const [onClick, loading] = useHandledAsyncCallback(async () => {
|
||||||
|
if (!experiment.data) return;
|
||||||
|
await mutation.mutateAsync({
|
||||||
|
experimentId: experiment.data.id,
|
||||||
|
streamScenarios: visibleScenarios,
|
||||||
|
});
|
||||||
|
await utils.promptVariants.list.invalidate();
|
||||||
|
}, [mutation]);
|
||||||
|
|
||||||
|
const { canModify } = useExperimentAccess();
|
||||||
|
if (!canModify) return <Box w={cellPadding.x} />;
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex w="100%" justifyContent="flex-end">
|
||||||
|
<ActionButton
|
||||||
|
onClick={onClick}
|
||||||
|
py={5}
|
||||||
|
leftIcon={<Icon as={loading ? Spinner : BsPlus} boxSize={6} mr={loading ? 1 : 0} />}
|
||||||
|
>
|
||||||
|
<Text display={{ base: "none", md: "flex" }}>Add Variant</Text>
|
||||||
|
</ActionButton>
|
||||||
|
{/* <Button
|
||||||
|
alignItems="center"
|
||||||
|
justifyContent="center"
|
||||||
|
fontWeight="normal"
|
||||||
|
bgColor="transparent"
|
||||||
|
_hover={{ bgColor: "gray.100" }}
|
||||||
|
px={cellPadding.x}
|
||||||
|
onClick={onClick}
|
||||||
|
height="unset"
|
||||||
|
minH={headerMinHeight}
|
||||||
|
>
|
||||||
|
<Icon as={loading ? Spinner : BsPlus} boxSize={6} mr={loading ? 1 : 0} />
|
||||||
|
<Text display={{ base: "none", md: "flex" }}>Add Variant</Text>
|
||||||
|
</Button> */}
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -1,61 +0,0 @@
|
|||||||
import { Button, type ButtonProps, HStack, Spinner, Icon } from "@chakra-ui/react";
|
|
||||||
import { BsPlus } from "react-icons/bs";
|
|
||||||
import { api } from "~/utils/api";
|
|
||||||
import { useExperiment, useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
|
|
||||||
|
|
||||||
// Extracted Button styling into reusable component
|
|
||||||
const StyledButton = ({ children, onClick }: ButtonProps) => (
|
|
||||||
<Button
|
|
||||||
fontWeight="normal"
|
|
||||||
bgColor="transparent"
|
|
||||||
_hover={{ bgColor: "gray.100" }}
|
|
||||||
px={2}
|
|
||||||
onClick={onClick}
|
|
||||||
>
|
|
||||||
{children}
|
|
||||||
</Button>
|
|
||||||
);
|
|
||||||
|
|
||||||
export default function NewScenarioButton() {
|
|
||||||
const { canModify } = useExperimentAccess();
|
|
||||||
|
|
||||||
const experiment = useExperiment();
|
|
||||||
const mutation = api.scenarios.create.useMutation();
|
|
||||||
const utils = api.useContext();
|
|
||||||
|
|
||||||
const [onClick] = useHandledAsyncCallback(async () => {
|
|
||||||
if (!experiment.data) return;
|
|
||||||
await mutation.mutateAsync({
|
|
||||||
experimentId: experiment.data.id,
|
|
||||||
});
|
|
||||||
await utils.scenarios.list.invalidate();
|
|
||||||
}, [mutation]);
|
|
||||||
|
|
||||||
const [onAutogenerate, autogenerating] = useHandledAsyncCallback(async () => {
|
|
||||||
if (!experiment.data) return;
|
|
||||||
await mutation.mutateAsync({
|
|
||||||
experimentId: experiment.data.id,
|
|
||||||
autogenerate: true,
|
|
||||||
});
|
|
||||||
await utils.scenarios.list.invalidate();
|
|
||||||
}, [mutation]);
|
|
||||||
|
|
||||||
if (!canModify) return null;
|
|
||||||
|
|
||||||
return (
|
|
||||||
<HStack spacing={2}>
|
|
||||||
<StyledButton onClick={onClick}>
|
|
||||||
<Icon as={BsPlus} boxSize={6} />
|
|
||||||
Add Scenario
|
|
||||||
</StyledButton>
|
|
||||||
<StyledButton onClick={onAutogenerate}>
|
|
||||||
<Icon
|
|
||||||
as={autogenerating ? Spinner : BsPlus}
|
|
||||||
boxSize={autogenerating ? 4 : 6}
|
|
||||||
mr={autogenerating ? 2 : 0}
|
|
||||||
/>
|
|
||||||
Autogenerate Scenario
|
|
||||||
</StyledButton>
|
|
||||||
</HStack>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
@@ -1,40 +0,0 @@
|
|||||||
import { Box, Button, Icon, Spinner, Text } from "@chakra-ui/react";
|
|
||||||
import { BsPlus } from "react-icons/bs";
|
|
||||||
import { api } from "~/utils/api";
|
|
||||||
import { useExperiment, useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
|
|
||||||
import { cellPadding, headerMinHeight } from "../constants";
|
|
||||||
|
|
||||||
export default function NewVariantButton() {
|
|
||||||
const experiment = useExperiment();
|
|
||||||
const mutation = api.promptVariants.create.useMutation();
|
|
||||||
const utils = api.useContext();
|
|
||||||
|
|
||||||
const [onClick, loading] = useHandledAsyncCallback(async () => {
|
|
||||||
if (!experiment.data) return;
|
|
||||||
await mutation.mutateAsync({
|
|
||||||
experimentId: experiment.data.id,
|
|
||||||
});
|
|
||||||
await utils.promptVariants.list.invalidate();
|
|
||||||
}, [mutation]);
|
|
||||||
|
|
||||||
const { canModify } = useExperimentAccess();
|
|
||||||
if (!canModify) return <Box w={cellPadding.x} />;
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Button
|
|
||||||
w="100%"
|
|
||||||
alignItems="center"
|
|
||||||
justifyContent="center"
|
|
||||||
fontWeight="normal"
|
|
||||||
bgColor="transparent"
|
|
||||||
_hover={{ bgColor: "gray.100" }}
|
|
||||||
px={cellPadding.x}
|
|
||||||
onClick={onClick}
|
|
||||||
height="unset"
|
|
||||||
minH={headerMinHeight}
|
|
||||||
>
|
|
||||||
<Icon as={loading ? Spinner : BsPlus} boxSize={6} mr={loading ? 1 : 0} />
|
|
||||||
<Text display={{ base: "none", md: "flex" }}>Add Variant</Text>
|
|
||||||
</Button>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
@@ -67,8 +67,8 @@ export default function OutputCell({
|
|||||||
|
|
||||||
const modelOutput = cell?.modelOutput;
|
const modelOutput = cell?.modelOutput;
|
||||||
|
|
||||||
// Disconnect from socket if we're not streaming anymore
|
// TODO: disconnect from socket if we're not streaming anymore
|
||||||
const streamedMessage = useSocket<OutputSchema>(cell?.streamingChannel);
|
const streamedMessage = useSocket<OutputSchema>(cell?.id);
|
||||||
|
|
||||||
if (!vars) return null;
|
if (!vars) return null;
|
||||||
|
|
||||||
@@ -81,10 +81,21 @@ export default function OutputCell({
|
|||||||
</Center>
|
</Center>
|
||||||
);
|
);
|
||||||
|
|
||||||
if (!cell && !fetchingOutput) return <Text color="gray.500">Error retrieving output</Text>;
|
if (!cell && !fetchingOutput)
|
||||||
|
return (
|
||||||
|
<VStack>
|
||||||
|
<CellOptions refetchingOutput={hardRefetching} refetchOutput={hardRefetch} />
|
||||||
|
<Text color="gray.500">Error retrieving output</Text>
|
||||||
|
</VStack>
|
||||||
|
);
|
||||||
|
|
||||||
if (cell && cell.errorMessage) {
|
if (cell && cell.errorMessage) {
|
||||||
return <ErrorHandler cell={cell} refetchOutput={hardRefetch} />;
|
return (
|
||||||
|
<VStack>
|
||||||
|
<CellOptions refetchingOutput={hardRefetching} refetchOutput={hardRefetch} />
|
||||||
|
<ErrorHandler cell={cell} refetchOutput={hardRefetch} />
|
||||||
|
</VStack>
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
const normalizedOutput = modelOutput
|
const normalizedOutput = modelOutput
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ export const OutputStats = ({
|
|||||||
return (
|
return (
|
||||||
<HStack w="full" align="center" color="gray.500" fontSize="2xs" mt={{ base: 0, md: 1 }}>
|
<HStack w="full" align="center" color="gray.500" fontSize="2xs" mt={{ base: 0, md: 1 }}>
|
||||||
<HStack flex={1}>
|
<HStack flex={1}>
|
||||||
{modelOutput.outputEvaluation.map((evaluation) => {
|
{modelOutput.outputEvaluations.map((evaluation) => {
|
||||||
const passed = evaluation.result > 0.5;
|
const passed = evaluation.result > 0.5;
|
||||||
return (
|
return (
|
||||||
<Tooltip
|
<Tooltip
|
||||||
|
|||||||
74
src/components/OutputsTable/ScenarioPaginator.tsx
Normal file
74
src/components/OutputsTable/ScenarioPaginator.tsx
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
import { Box, HStack, IconButton } from "@chakra-ui/react";
|
||||||
|
import {
|
||||||
|
BsChevronDoubleLeft,
|
||||||
|
BsChevronDoubleRight,
|
||||||
|
BsChevronLeft,
|
||||||
|
BsChevronRight,
|
||||||
|
} from "react-icons/bs";
|
||||||
|
import { usePage, useScenarios } from "~/utils/hooks";
|
||||||
|
|
||||||
|
const ScenarioPaginator = () => {
|
||||||
|
const [page, setPage] = usePage();
|
||||||
|
const { data } = useScenarios();
|
||||||
|
|
||||||
|
if (!data) return null;
|
||||||
|
|
||||||
|
const { scenarios, startIndex, lastPage, count } = data;
|
||||||
|
|
||||||
|
const nextPage = () => {
|
||||||
|
if (page < lastPage) {
|
||||||
|
setPage(page + 1, "replace");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const prevPage = () => {
|
||||||
|
if (page > 1) {
|
||||||
|
setPage(page - 1, "replace");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const goToLastPage = () => setPage(lastPage, "replace");
|
||||||
|
const goToFirstPage = () => setPage(1, "replace");
|
||||||
|
|
||||||
|
return (
|
||||||
|
<HStack pt={4}>
|
||||||
|
<IconButton
|
||||||
|
variant="ghost"
|
||||||
|
size="sm"
|
||||||
|
onClick={goToFirstPage}
|
||||||
|
isDisabled={page === 1}
|
||||||
|
aria-label="Go to first page"
|
||||||
|
icon={<BsChevronDoubleLeft />}
|
||||||
|
/>
|
||||||
|
<IconButton
|
||||||
|
variant="ghost"
|
||||||
|
size="sm"
|
||||||
|
onClick={prevPage}
|
||||||
|
isDisabled={page === 1}
|
||||||
|
aria-label="Previous page"
|
||||||
|
icon={<BsChevronLeft />}
|
||||||
|
/>
|
||||||
|
<Box>
|
||||||
|
{startIndex}-{startIndex + scenarios.length - 1} / {count}
|
||||||
|
</Box>
|
||||||
|
<IconButton
|
||||||
|
variant="ghost"
|
||||||
|
size="sm"
|
||||||
|
onClick={nextPage}
|
||||||
|
isDisabled={page === lastPage}
|
||||||
|
aria-label="Next page"
|
||||||
|
icon={<BsChevronRight />}
|
||||||
|
/>
|
||||||
|
<IconButton
|
||||||
|
variant="ghost"
|
||||||
|
size="sm"
|
||||||
|
onClick={goToLastPage}
|
||||||
|
isDisabled={page === lastPage}
|
||||||
|
aria-label="Go to last page"
|
||||||
|
icon={<BsChevronDoubleRight />}
|
||||||
|
/>
|
||||||
|
</HStack>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default ScenarioPaginator;
|
||||||
@@ -4,11 +4,13 @@ import { cellPadding } from "../constants";
|
|||||||
import OutputCell from "./OutputCell/OutputCell";
|
import OutputCell from "./OutputCell/OutputCell";
|
||||||
import ScenarioEditor from "./ScenarioEditor";
|
import ScenarioEditor from "./ScenarioEditor";
|
||||||
import type { PromptVariant, Scenario } from "./types";
|
import type { PromptVariant, Scenario } from "./types";
|
||||||
|
import { borders } from "./styles";
|
||||||
|
|
||||||
const ScenarioRow = (props: {
|
const ScenarioRow = (props: {
|
||||||
scenario: Scenario;
|
scenario: Scenario;
|
||||||
variants: PromptVariant[];
|
variants: PromptVariant[];
|
||||||
canHide: boolean;
|
canHide: boolean;
|
||||||
|
rowStart: number;
|
||||||
}) => {
|
}) => {
|
||||||
const [isHovered, setIsHovered] = useState(false);
|
const [isHovered, setIsHovered] = useState(false);
|
||||||
|
|
||||||
@@ -21,15 +23,21 @@ const ScenarioRow = (props: {
|
|||||||
onMouseLeave={() => setIsHovered(false)}
|
onMouseLeave={() => setIsHovered(false)}
|
||||||
sx={isHovered ? highlightStyle : undefined}
|
sx={isHovered ? highlightStyle : undefined}
|
||||||
borderLeftWidth={1}
|
borderLeftWidth={1}
|
||||||
|
{...borders}
|
||||||
|
rowStart={props.rowStart}
|
||||||
|
colStart={1}
|
||||||
>
|
>
|
||||||
<ScenarioEditor scenario={props.scenario} hovered={isHovered} canHide={props.canHide} />
|
<ScenarioEditor scenario={props.scenario} hovered={isHovered} canHide={props.canHide} />
|
||||||
</GridItem>
|
</GridItem>
|
||||||
{props.variants.map((variant) => (
|
{props.variants.map((variant, i) => (
|
||||||
<GridItem
|
<GridItem
|
||||||
key={variant.id}
|
key={variant.id}
|
||||||
onMouseEnter={() => setIsHovered(true)}
|
onMouseEnter={() => setIsHovered(true)}
|
||||||
onMouseLeave={() => setIsHovered(false)}
|
onMouseLeave={() => setIsHovered(false)}
|
||||||
sx={isHovered ? highlightStyle : undefined}
|
sx={isHovered ? highlightStyle : undefined}
|
||||||
|
rowStart={props.rowStart}
|
||||||
|
colStart={i + 2}
|
||||||
|
{...borders}
|
||||||
>
|
>
|
||||||
<Box h="100%" w="100%" px={cellPadding.x} py={cellPadding.y}>
|
<Box h="100%" w="100%" px={cellPadding.x} py={cellPadding.y}>
|
||||||
<OutputCell key={variant.id} scenario={props.scenario} variant={variant} />
|
<OutputCell key={variant.id} scenario={props.scenario} variant={variant} />
|
||||||
|
|||||||
@@ -1,52 +1,82 @@
|
|||||||
import { Button, GridItem, HStack, Heading } from "@chakra-ui/react";
|
import {
|
||||||
|
Button,
|
||||||
|
type ButtonProps,
|
||||||
|
HStack,
|
||||||
|
Text,
|
||||||
|
Icon,
|
||||||
|
Menu,
|
||||||
|
MenuButton,
|
||||||
|
MenuList,
|
||||||
|
MenuItem,
|
||||||
|
IconButton,
|
||||||
|
Spinner,
|
||||||
|
} from "@chakra-ui/react";
|
||||||
import { cellPadding } from "../constants";
|
import { cellPadding } from "../constants";
|
||||||
import { useElementDimensions, useExperimentAccess } from "~/utils/hooks";
|
import {
|
||||||
import { stickyHeaderStyle } from "./styles";
|
useExperiment,
|
||||||
import { BsPencil } from "react-icons/bs";
|
useExperimentAccess,
|
||||||
|
useHandledAsyncCallback,
|
||||||
|
useScenarios,
|
||||||
|
} from "~/utils/hooks";
|
||||||
|
import { BsGear, BsPencil, BsPlus, BsStars } from "react-icons/bs";
|
||||||
import { useAppStore } from "~/state/store";
|
import { useAppStore } from "~/state/store";
|
||||||
|
import { api } from "~/utils/api";
|
||||||
|
|
||||||
export const ScenariosHeader = ({
|
export const ActionButton = (props: ButtonProps) => (
|
||||||
headerRows,
|
<Button size="sm" variant="ghost" color="gray.600" {...props} />
|
||||||
numScenarios,
|
);
|
||||||
}: {
|
|
||||||
headerRows: number;
|
export const ScenariosHeader = () => {
|
||||||
numScenarios: number;
|
|
||||||
}) => {
|
|
||||||
const openDrawer = useAppStore((s) => s.openDrawer);
|
const openDrawer = useAppStore((s) => s.openDrawer);
|
||||||
const { canModify } = useExperimentAccess();
|
const { canModify } = useExperimentAccess();
|
||||||
|
const scenarios = useScenarios();
|
||||||
|
|
||||||
const [ref, dimensions] = useElementDimensions();
|
const experiment = useExperiment();
|
||||||
const topValue = dimensions ? `-${dimensions.height - 24}px` : "-455px";
|
const createScenarioMutation = api.scenarios.create.useMutation();
|
||||||
|
const utils = api.useContext();
|
||||||
|
|
||||||
|
const [onAddScenario, loading] = useHandledAsyncCallback(
|
||||||
|
async (autogenerate: boolean) => {
|
||||||
|
if (!experiment.data) return;
|
||||||
|
await createScenarioMutation.mutateAsync({
|
||||||
|
experimentId: experiment.data.id,
|
||||||
|
autogenerate,
|
||||||
|
});
|
||||||
|
await utils.scenarios.list.invalidate();
|
||||||
|
},
|
||||||
|
[createScenarioMutation],
|
||||||
|
);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<GridItem
|
<HStack w="100%" pb={cellPadding.y} pt={0} align="center" spacing={0}>
|
||||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
<Text fontSize={16} fontWeight="bold">
|
||||||
ref={ref as any}
|
Scenarios ({scenarios.data?.count})
|
||||||
display="flex"
|
</Text>
|
||||||
alignItems="flex-end"
|
{canModify && (
|
||||||
rowSpan={headerRows}
|
<Menu>
|
||||||
px={cellPadding.x}
|
<MenuButton
|
||||||
py={cellPadding.y}
|
as={IconButton}
|
||||||
// Only display the part of the grid item that has content
|
mt={1}
|
||||||
sx={{ ...stickyHeaderStyle, top: topValue }}
|
|
||||||
>
|
|
||||||
<HStack w="100%">
|
|
||||||
<Heading size="xs" fontWeight="bold" flex={1}>
|
|
||||||
Scenarios ({numScenarios})
|
|
||||||
</Heading>
|
|
||||||
{canModify && (
|
|
||||||
<Button
|
|
||||||
size="xs"
|
|
||||||
variant="ghost"
|
variant="ghost"
|
||||||
color="gray.500"
|
aria-label="Edit Scenarios"
|
||||||
aria-label="Edit"
|
icon={<Icon as={loading ? Spinner : BsGear} />}
|
||||||
leftIcon={<BsPencil />}
|
/>
|
||||||
onClick={openDrawer}
|
<MenuList fontSize="md" zIndex="dropdown" mt={-3}>
|
||||||
>
|
<MenuItem
|
||||||
Edit Vars
|
icon={<Icon as={BsPlus} boxSize={6} mx="-5px" />}
|
||||||
</Button>
|
onClick={() => onAddScenario(false)}
|
||||||
)}
|
>
|
||||||
</HStack>
|
Add Scenario
|
||||||
</GridItem>
|
</MenuItem>
|
||||||
|
<MenuItem icon={<BsStars />} onClick={() => onAddScenario(true)}>
|
||||||
|
Autogenerate Scenario
|
||||||
|
</MenuItem>
|
||||||
|
<MenuItem icon={<BsPencil />} onClick={openDrawer}>
|
||||||
|
Edit Vars
|
||||||
|
</MenuItem>
|
||||||
|
</MenuList>
|
||||||
|
</Menu>
|
||||||
|
)}
|
||||||
|
</HStack>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -2,19 +2,24 @@ import {
|
|||||||
Box,
|
Box,
|
||||||
Button,
|
Button,
|
||||||
HStack,
|
HStack,
|
||||||
|
IconButton,
|
||||||
Spinner,
|
Spinner,
|
||||||
|
Text,
|
||||||
Tooltip,
|
Tooltip,
|
||||||
useToast,
|
useToast,
|
||||||
Text,
|
|
||||||
IconButton,
|
|
||||||
} from "@chakra-ui/react";
|
} from "@chakra-ui/react";
|
||||||
import { useRef, useEffect, useState, useCallback } from "react";
|
import { useCallback, useEffect, useRef, useState } from "react";
|
||||||
import { useExperimentAccess, useHandledAsyncCallback, useModifierKeyLabel } from "~/utils/hooks";
|
|
||||||
import { type PromptVariant } from "./types";
|
|
||||||
import { api } from "~/utils/api";
|
|
||||||
import { useAppStore } from "~/state/store";
|
|
||||||
import { FiMaximize, FiMinimize } from "react-icons/fi";
|
import { FiMaximize, FiMinimize } from "react-icons/fi";
|
||||||
import { editorBackground } from "~/state/sharedVariantEditor.slice";
|
import { editorBackground } from "~/state/sharedVariantEditor.slice";
|
||||||
|
import { useAppStore } from "~/state/store";
|
||||||
|
import { api } from "~/utils/api";
|
||||||
|
import {
|
||||||
|
useExperimentAccess,
|
||||||
|
useHandledAsyncCallback,
|
||||||
|
useModifierKeyLabel,
|
||||||
|
useVisibleScenarioIds,
|
||||||
|
} from "~/utils/hooks";
|
||||||
|
import { type PromptVariant } from "./types";
|
||||||
|
|
||||||
export default function VariantEditor(props: { variant: PromptVariant }) {
|
export default function VariantEditor(props: { variant: PromptVariant }) {
|
||||||
const { canModify } = useExperimentAccess();
|
const { canModify } = useExperimentAccess();
|
||||||
@@ -63,6 +68,7 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
|
|||||||
const replaceVariant = api.promptVariants.replaceVariant.useMutation();
|
const replaceVariant = api.promptVariants.replaceVariant.useMutation();
|
||||||
const utils = api.useContext();
|
const utils = api.useContext();
|
||||||
const toast = useToast();
|
const toast = useToast();
|
||||||
|
const visibleScenarios = useVisibleScenarioIds();
|
||||||
|
|
||||||
const [onSave, saveInProgress] = useHandledAsyncCallback(async () => {
|
const [onSave, saveInProgress] = useHandledAsyncCallback(async () => {
|
||||||
if (!editorRef.current) return;
|
if (!editorRef.current) return;
|
||||||
@@ -91,6 +97,7 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
|
|||||||
const resp = await replaceVariant.mutateAsync({
|
const resp = await replaceVariant.mutateAsync({
|
||||||
id: props.variant.id,
|
id: props.variant.id,
|
||||||
constructFn: currentFn,
|
constructFn: currentFn,
|
||||||
|
streamScenarios: visibleScenarios,
|
||||||
});
|
});
|
||||||
if (resp.status === "error") {
|
if (resp.status === "error") {
|
||||||
return toast({
|
return toast({
|
||||||
|
|||||||
@@ -1,13 +1,15 @@
|
|||||||
import { Grid, GridItem } from "@chakra-ui/react";
|
import { Grid, GridItem, type GridItemProps } from "@chakra-ui/react";
|
||||||
import { api } from "~/utils/api";
|
import { api } from "~/utils/api";
|
||||||
import NewScenarioButton from "./NewScenarioButton";
|
import AddVariantButton from "./AddVariantButton";
|
||||||
import NewVariantButton from "./NewVariantButton";
|
|
||||||
import ScenarioRow from "./ScenarioRow";
|
import ScenarioRow from "./ScenarioRow";
|
||||||
import VariantEditor from "./VariantEditor";
|
import VariantEditor from "./VariantEditor";
|
||||||
import VariantHeader from "../VariantHeader/VariantHeader";
|
import VariantHeader from "../VariantHeader/VariantHeader";
|
||||||
import VariantStats from "./VariantStats";
|
import VariantStats from "./VariantStats";
|
||||||
import { ScenariosHeader } from "./ScenariosHeader";
|
import { ScenariosHeader } from "./ScenariosHeader";
|
||||||
import { stickyHeaderStyle } from "./styles";
|
import { borders } from "./styles";
|
||||||
|
import { useScenarios } from "~/utils/hooks";
|
||||||
|
import ScenarioPaginator from "./ScenarioPaginator";
|
||||||
|
import { Fragment } from "react";
|
||||||
|
|
||||||
export default function OutputsTable({ experimentId }: { experimentId: string | undefined }) {
|
export default function OutputsTable({ experimentId }: { experimentId: string | undefined }) {
|
||||||
const variants = api.promptVariants.list.useQuery(
|
const variants = api.promptVariants.list.useQuery(
|
||||||
@@ -15,68 +17,90 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
|
|||||||
{ enabled: !!experimentId },
|
{ enabled: !!experimentId },
|
||||||
);
|
);
|
||||||
|
|
||||||
const scenarios = api.scenarios.list.useQuery(
|
const scenarios = useScenarios();
|
||||||
{ experimentId: experimentId as string },
|
|
||||||
{ enabled: !!experimentId },
|
|
||||||
);
|
|
||||||
|
|
||||||
if (!variants.data || !scenarios.data) return null;
|
if (!variants.data || !scenarios.data) return null;
|
||||||
|
|
||||||
const allCols = variants.data.length + 1;
|
const allCols = variants.data.length + 2;
|
||||||
const headerRows = 3;
|
const variantHeaderRows = 3;
|
||||||
|
const scenarioHeaderRows = 1;
|
||||||
|
const scenarioFooterRows = 1;
|
||||||
|
const visibleScenariosCount = scenarios.data.scenarios.length;
|
||||||
|
const allRows =
|
||||||
|
variantHeaderRows + scenarioHeaderRows + visibleScenariosCount + scenarioFooterRows;
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Grid
|
<Grid
|
||||||
p={4}
|
pt={4}
|
||||||
pb={24}
|
pb={24}
|
||||||
|
pl={4}
|
||||||
display="grid"
|
display="grid"
|
||||||
gridTemplateColumns={`250px repeat(${variants.data.length}, minmax(300px, 1fr)) auto`}
|
gridTemplateColumns={`250px repeat(${variants.data.length}, minmax(300px, 1fr)) auto`}
|
||||||
sx={{
|
sx={{
|
||||||
"> *": {
|
"> *": {
|
||||||
borderColor: "gray.300",
|
borderColor: "gray.300",
|
||||||
borderBottomWidth: 1,
|
|
||||||
borderRightWidth: 1,
|
|
||||||
},
|
},
|
||||||
}}
|
}}
|
||||||
fontSize="sm"
|
fontSize="sm"
|
||||||
>
|
>
|
||||||
<ScenariosHeader headerRows={headerRows} numScenarios={scenarios.data.length} />
|
<GridItem rowSpan={variantHeaderRows}>
|
||||||
|
<AddVariantButton />
|
||||||
{variants.data.map((variant) => (
|
|
||||||
<VariantHeader key={variant.uiId} variant={variant} canHide={variants.data.length > 1} />
|
|
||||||
))}
|
|
||||||
<GridItem
|
|
||||||
rowSpan={scenarios.data.length + headerRows}
|
|
||||||
padding={0}
|
|
||||||
// Have to use `style` instead of emotion style props to work around css specificity issues conflicting with the "> *" selector on Grid
|
|
||||||
style={{ borderRightWidth: 0, borderBottomWidth: 0 }}
|
|
||||||
h={8}
|
|
||||||
sx={stickyHeaderStyle}
|
|
||||||
>
|
|
||||||
<NewVariantButton />
|
|
||||||
</GridItem>
|
</GridItem>
|
||||||
|
|
||||||
{variants.data.map((variant) => (
|
{variants.data.map((variant, i) => {
|
||||||
<GridItem key={variant.uiId}>
|
const sharedProps: GridItemProps = {
|
||||||
<VariantEditor variant={variant} />
|
...borders,
|
||||||
</GridItem>
|
colStart: i + 2,
|
||||||
))}
|
borderLeftWidth: i === 0 ? 1 : 0,
|
||||||
{variants.data.map((variant) => (
|
marginLeft: i === 0 ? "-1px" : 0,
|
||||||
<GridItem key={variant.uiId}>
|
};
|
||||||
<VariantStats variant={variant} />
|
return (
|
||||||
</GridItem>
|
<Fragment key={variant.uiId}>
|
||||||
))}
|
<VariantHeader
|
||||||
{scenarios.data.map((scenario) => (
|
variant={variant}
|
||||||
|
canHide={variants.data.length > 1}
|
||||||
|
rowStart={1}
|
||||||
|
{...sharedProps}
|
||||||
|
/>
|
||||||
|
<GridItem rowStart={2} {...sharedProps}>
|
||||||
|
<VariantEditor variant={variant} />
|
||||||
|
</GridItem>
|
||||||
|
<GridItem rowStart={3} {...sharedProps}>
|
||||||
|
<VariantStats variant={variant} />
|
||||||
|
</GridItem>
|
||||||
|
</Fragment>
|
||||||
|
);
|
||||||
|
})}
|
||||||
|
|
||||||
|
<GridItem
|
||||||
|
colSpan={allCols - 1}
|
||||||
|
rowStart={variantHeaderRows + 1}
|
||||||
|
colStart={1}
|
||||||
|
{...borders}
|
||||||
|
borderRightWidth={0}
|
||||||
|
>
|
||||||
|
<ScenariosHeader />
|
||||||
|
</GridItem>
|
||||||
|
|
||||||
|
{scenarios.data.scenarios.map((scenario, i) => (
|
||||||
<ScenarioRow
|
<ScenarioRow
|
||||||
|
rowStart={i + variantHeaderRows + scenarioHeaderRows + 2}
|
||||||
key={scenario.uiId}
|
key={scenario.uiId}
|
||||||
scenario={scenario}
|
scenario={scenario}
|
||||||
variants={variants.data}
|
variants={variants.data}
|
||||||
canHide={scenarios.data.length > 1}
|
canHide={visibleScenariosCount > 1}
|
||||||
/>
|
/>
|
||||||
))}
|
))}
|
||||||
<GridItem borderBottomWidth={0} borderRightWidth={0} w="100%" colSpan={allCols} padding={0}>
|
<GridItem
|
||||||
<NewScenarioButton />
|
rowStart={variantHeaderRows + scenarioHeaderRows + visibleScenariosCount + 2}
|
||||||
|
colStart={1}
|
||||||
|
colSpan={allCols}
|
||||||
|
>
|
||||||
|
<ScenarioPaginator />
|
||||||
</GridItem>
|
</GridItem>
|
||||||
|
|
||||||
|
{/* Add some extra padding on the right, because when the table is too wide to fit in the viewport `pr` on the Grid isn't respected. */}
|
||||||
|
<GridItem rowStart={1} colStart={allCols} rowSpan={allRows} w={4} borderBottomWidth={0} />
|
||||||
</Grid>
|
</Grid>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
import { type SystemStyleObject } from "@chakra-ui/react";
|
import { type GridItemProps, type SystemStyleObject } from "@chakra-ui/react";
|
||||||
|
|
||||||
export const stickyHeaderStyle: SystemStyleObject = {
|
export const stickyHeaderStyle: SystemStyleObject = {
|
||||||
position: "sticky",
|
position: "sticky",
|
||||||
@@ -6,3 +6,8 @@ export const stickyHeaderStyle: SystemStyleObject = {
|
|||||||
backgroundColor: "#fff",
|
backgroundColor: "#fff",
|
||||||
zIndex: 10,
|
zIndex: 10,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export const borders: GridItemProps = {
|
||||||
|
borderRightWidth: 1,
|
||||||
|
borderBottomWidth: 1,
|
||||||
|
};
|
||||||
|
|||||||
@@ -2,4 +2,4 @@ import { type RouterOutputs } from "~/utils/api";
|
|||||||
|
|
||||||
export type PromptVariant = NonNullable<RouterOutputs["promptVariants"]["list"]>[0];
|
export type PromptVariant = NonNullable<RouterOutputs["promptVariants"]["list"]>[0];
|
||||||
|
|
||||||
export type Scenario = NonNullable<RouterOutputs["scenarios"]["list"]>[0];
|
export type Scenario = NonNullable<RouterOutputs["scenarios"]["list"]>["scenarios"][0];
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
import { HStack, Icon, Heading, Text, VStack, GridItem } from "@chakra-ui/react";
|
import { HStack, Icon, Heading, Text, VStack, GridItem } from "@chakra-ui/react";
|
||||||
import { type IconType } from "react-icons";
|
import { type IconType } from "react-icons";
|
||||||
|
import { BsStars } from "react-icons/bs";
|
||||||
|
|
||||||
export const RefineOption = ({
|
export const RefineAction = ({
|
||||||
label,
|
label,
|
||||||
icon,
|
icon,
|
||||||
desciption,
|
desciption,
|
||||||
@@ -10,7 +11,7 @@ export const RefineOption = ({
|
|||||||
loading,
|
loading,
|
||||||
}: {
|
}: {
|
||||||
label: string;
|
label: string;
|
||||||
icon: IconType;
|
icon?: IconType;
|
||||||
desciption: string;
|
desciption: string;
|
||||||
activeLabel: string | undefined;
|
activeLabel: string | undefined;
|
||||||
onClick: (label: string) => void;
|
onClick: (label: string) => void;
|
||||||
@@ -44,7 +45,7 @@ export const RefineOption = ({
|
|||||||
opacity={loading ? 0.5 : 1}
|
opacity={loading ? 0.5 : 1}
|
||||||
>
|
>
|
||||||
<HStack cursor="pointer" spacing={6} fontSize="sm" fontWeight="medium" color="gray.500">
|
<HStack cursor="pointer" spacing={6} fontSize="sm" fontWeight="medium" color="gray.500">
|
||||||
<Icon as={icon} boxSize={12} />
|
<Icon as={icon || BsStars} boxSize={12} />
|
||||||
<Heading size="md" fontFamily="inconsolata, monospace">
|
<Heading size="md" fontFamily="inconsolata, monospace">
|
||||||
{label}
|
{label}
|
||||||
</Heading>
|
</Heading>
|
||||||
@@ -16,15 +16,15 @@ import {
|
|||||||
} from "@chakra-ui/react";
|
} from "@chakra-ui/react";
|
||||||
import { BsStars } from "react-icons/bs";
|
import { BsStars } from "react-icons/bs";
|
||||||
import { api } from "~/utils/api";
|
import { api } from "~/utils/api";
|
||||||
import { useHandledAsyncCallback } from "~/utils/hooks";
|
import { useHandledAsyncCallback, useVisibleScenarioIds } from "~/utils/hooks";
|
||||||
import { type PromptVariant } from "@prisma/client";
|
import { type PromptVariant } from "@prisma/client";
|
||||||
import { useState } from "react";
|
import { useState } from "react";
|
||||||
import CompareFunctions from "./CompareFunctions";
|
import CompareFunctions from "./CompareFunctions";
|
||||||
import { CustomInstructionsInput } from "./CustomInstructionsInput";
|
import { CustomInstructionsInput } from "./CustomInstructionsInput";
|
||||||
import { type RefineOptionInfo, refineOptions } from "./refineOptions";
|
import { RefineAction } from "./RefineAction";
|
||||||
import { RefineOption } from "./RefineOption";
|
|
||||||
import { isObject, isString } from "lodash-es";
|
import { isObject, isString } from "lodash-es";
|
||||||
import { type SupportedProvider } from "~/modelProviders/types";
|
import { type RefinementAction, type SupportedProvider } from "~/modelProviders/types";
|
||||||
|
import frontendModelProviders from "~/modelProviders/frontendModelProviders";
|
||||||
|
|
||||||
export const RefinePromptModal = ({
|
export const RefinePromptModal = ({
|
||||||
variant,
|
variant,
|
||||||
@@ -34,14 +34,16 @@ export const RefinePromptModal = ({
|
|||||||
onClose: () => void;
|
onClose: () => void;
|
||||||
}) => {
|
}) => {
|
||||||
const utils = api.useContext();
|
const utils = api.useContext();
|
||||||
|
const visibleScenarios = useVisibleScenarioIds();
|
||||||
|
|
||||||
const providerRefineOptions = refineOptions[variant.modelProvider as SupportedProvider];
|
const refinementActions =
|
||||||
|
frontendModelProviders[variant.modelProvider as SupportedProvider].refinementActions || {};
|
||||||
|
|
||||||
const { mutateAsync: getModifiedPromptMutateAsync, data: refinedPromptFn } =
|
const { mutateAsync: getModifiedPromptMutateAsync, data: refinedPromptFn } =
|
||||||
api.promptVariants.getModifiedPromptFn.useMutation();
|
api.promptVariants.getModifiedPromptFn.useMutation();
|
||||||
const [instructions, setInstructions] = useState<string>("");
|
const [instructions, setInstructions] = useState<string>("");
|
||||||
|
|
||||||
const [activeRefineOptionLabel, setActiveRefineOptionLabel] = useState<string | undefined>(
|
const [activeRefineActionLabel, setActiveRefineActionLabel] = useState<string | undefined>(
|
||||||
undefined,
|
undefined,
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -49,15 +51,15 @@ export const RefinePromptModal = ({
|
|||||||
async (label?: string) => {
|
async (label?: string) => {
|
||||||
if (!variant.experimentId) return;
|
if (!variant.experimentId) return;
|
||||||
const updatedInstructions = label
|
const updatedInstructions = label
|
||||||
? (providerRefineOptions[label] as RefineOptionInfo).instructions
|
? (refinementActions[label] as RefinementAction).instructions
|
||||||
: instructions;
|
: instructions;
|
||||||
setActiveRefineOptionLabel(label);
|
setActiveRefineActionLabel(label);
|
||||||
await getModifiedPromptMutateAsync({
|
await getModifiedPromptMutateAsync({
|
||||||
id: variant.id,
|
id: variant.id,
|
||||||
instructions: updatedInstructions,
|
instructions: updatedInstructions,
|
||||||
});
|
});
|
||||||
},
|
},
|
||||||
[getModifiedPromptMutateAsync, onClose, variant, instructions, setActiveRefineOptionLabel],
|
[getModifiedPromptMutateAsync, onClose, variant, instructions, setActiveRefineActionLabel],
|
||||||
);
|
);
|
||||||
|
|
||||||
const replaceVariantMutation = api.promptVariants.replaceVariant.useMutation();
|
const replaceVariantMutation = api.promptVariants.replaceVariant.useMutation();
|
||||||
@@ -72,6 +74,7 @@ export const RefinePromptModal = ({
|
|||||||
await replaceVariantMutation.mutateAsync({
|
await replaceVariantMutation.mutateAsync({
|
||||||
id: variant.id,
|
id: variant.id,
|
||||||
constructFn: refinedPromptFn,
|
constructFn: refinedPromptFn,
|
||||||
|
streamScenarios: visibleScenarios,
|
||||||
});
|
});
|
||||||
await utils.promptVariants.list.invalidate();
|
await utils.promptVariants.list.invalidate();
|
||||||
onClose();
|
onClose();
|
||||||
@@ -95,18 +98,18 @@ export const RefinePromptModal = ({
|
|||||||
<ModalBody maxW="unset">
|
<ModalBody maxW="unset">
|
||||||
<VStack spacing={8}>
|
<VStack spacing={8}>
|
||||||
<VStack spacing={4}>
|
<VStack spacing={4}>
|
||||||
{Object.keys(providerRefineOptions).length && (
|
{Object.keys(refinementActions).length && (
|
||||||
<>
|
<>
|
||||||
<SimpleGrid columns={{ base: 1, md: 2 }} spacing={8}>
|
<SimpleGrid columns={{ base: 1, md: 2 }} spacing={8}>
|
||||||
{Object.keys(providerRefineOptions).map((label) => (
|
{Object.keys(refinementActions).map((label) => (
|
||||||
<RefineOption
|
<RefineAction
|
||||||
key={label}
|
key={label}
|
||||||
label={label}
|
label={label}
|
||||||
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
||||||
icon={providerRefineOptions[label]!.icon}
|
icon={refinementActions[label]!.icon}
|
||||||
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
||||||
desciption={providerRefineOptions[label]!.description}
|
desciption={refinementActions[label]!.description}
|
||||||
activeLabel={activeRefineOptionLabel}
|
activeLabel={activeRefineActionLabel}
|
||||||
onClick={getModifiedPromptFn}
|
onClick={getModifiedPromptFn}
|
||||||
loading={modificationInProgress}
|
loading={modificationInProgress}
|
||||||
/>
|
/>
|
||||||
|
|||||||
@@ -1,287 +0,0 @@
|
|||||||
// Super hacky, but we'll redo the organization when we have more models
|
|
||||||
|
|
||||||
import { type SupportedProvider } from "~/modelProviders/types";
|
|
||||||
import { VscJson } from "react-icons/vsc";
|
|
||||||
import { TfiThought } from "react-icons/tfi";
|
|
||||||
import { type IconType } from "react-icons";
|
|
||||||
|
|
||||||
export type RefineOptionInfo = { icon: IconType; description: string; instructions: string };
|
|
||||||
|
|
||||||
export const refineOptions: Record<SupportedProvider, { [key: string]: RefineOptionInfo }> = {
|
|
||||||
"openai/ChatCompletion": {
|
|
||||||
"Add chain of thought": {
|
|
||||||
icon: VscJson,
|
|
||||||
description: "Asking the model to plan its answer can increase accuracy.",
|
|
||||||
instructions: `Adding chain of thought means asking the model to think about its answer before it gives it to you. This is useful for getting more accurate answers. Do not add an assistant message.
|
|
||||||
|
|
||||||
This is what a prompt looks like before adding chain of thought:
|
|
||||||
|
|
||||||
definePrompt("openai/ChatCompletion", {
|
|
||||||
model: "gpt-4",
|
|
||||||
stream: true,
|
|
||||||
messages: [
|
|
||||||
{
|
|
||||||
role: "system",
|
|
||||||
content: \`Evaluate sentiment.\`,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
role: "user",
|
|
||||||
content: \`This is the user's message: \${scenario.user_message}. Return "positive" or "negative" or "neutral"\`,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
});
|
|
||||||
|
|
||||||
This is what one looks like after adding chain of thought:
|
|
||||||
|
|
||||||
definePrompt("openai/ChatCompletion", {
|
|
||||||
model: "gpt-4",
|
|
||||||
stream: true,
|
|
||||||
messages: [
|
|
||||||
{
|
|
||||||
role: "system",
|
|
||||||
content: \`Evaluate sentiment.\`,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
role: "user",
|
|
||||||
content: \`This is the user's message: \${scenario.user_message}. Return "positive" or "negative" or "neutral". Explain your answer before you give a score, then return the score on a new line.\`,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
});
|
|
||||||
|
|
||||||
Here's another example:
|
|
||||||
|
|
||||||
Before:
|
|
||||||
|
|
||||||
definePrompt("openai/ChatCompletion", {
|
|
||||||
model: "gpt-3.5-turbo",
|
|
||||||
messages: [
|
|
||||||
{
|
|
||||||
role: "user",
|
|
||||||
content: \`Title: \${scenario.title}
|
|
||||||
Body: \${scenario.body}
|
|
||||||
|
|
||||||
Need: \${scenario.need}
|
|
||||||
|
|
||||||
Rate likelihood on 1-3 scale.\`,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
temperature: 0,
|
|
||||||
functions: [
|
|
||||||
{
|
|
||||||
name: "score_post",
|
|
||||||
parameters: {
|
|
||||||
type: "object",
|
|
||||||
properties: {
|
|
||||||
score: {
|
|
||||||
type: "number",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
],
|
|
||||||
function_call: {
|
|
||||||
name: "score_post",
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
After:
|
|
||||||
|
|
||||||
definePrompt("openai/ChatCompletion", {
|
|
||||||
model: "gpt-3.5-turbo",
|
|
||||||
messages: [
|
|
||||||
{
|
|
||||||
role: "user",
|
|
||||||
content: \`Title: \${scenario.title}
|
|
||||||
Body: \${scenario.body}
|
|
||||||
|
|
||||||
Need: \${scenario.need}
|
|
||||||
|
|
||||||
Rate likelihood on 1-3 scale. Provide an explanation, but always provide a score afterward.\`,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
temperature: 0,
|
|
||||||
functions: [
|
|
||||||
{
|
|
||||||
name: "score_post",
|
|
||||||
parameters: {
|
|
||||||
type: "object",
|
|
||||||
properties: {
|
|
||||||
explanation: {
|
|
||||||
type: "string",
|
|
||||||
}
|
|
||||||
score: {
|
|
||||||
type: "number",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
],
|
|
||||||
function_call: {
|
|
||||||
name: "score_post",
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
Add chain of thought to the original prompt.`,
|
|
||||||
},
|
|
||||||
"Convert to function call": {
|
|
||||||
icon: TfiThought,
|
|
||||||
description: "Use function calls to get output from the model in a more structured way.",
|
|
||||||
instructions: `OpenAI functions are a specialized way for an LLM to return output.
|
|
||||||
|
|
||||||
This is what a prompt looks like before adding a function:
|
|
||||||
|
|
||||||
definePrompt("openai/ChatCompletion", {
|
|
||||||
model: "gpt-4",
|
|
||||||
stream: true,
|
|
||||||
messages: [
|
|
||||||
{
|
|
||||||
role: "system",
|
|
||||||
content: \`Evaluate sentiment.\`,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
role: "user",
|
|
||||||
content: \`This is the user's message: \${scenario.user_message}. Return "positive" or "negative" or "neutral"\`,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
});
|
|
||||||
|
|
||||||
This is what one looks like after adding a function:
|
|
||||||
|
|
||||||
definePrompt("openai/ChatCompletion", {
|
|
||||||
model: "gpt-4",
|
|
||||||
stream: true,
|
|
||||||
messages: [
|
|
||||||
{
|
|
||||||
role: "system",
|
|
||||||
content: "Evaluate sentiment.",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
role: "user",
|
|
||||||
content: scenario.user_message,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
functions: [
|
|
||||||
{
|
|
||||||
name: "extract_sentiment",
|
|
||||||
parameters: {
|
|
||||||
type: "object", // parameters must always be an object with a properties key
|
|
||||||
properties: { // properties key is required
|
|
||||||
sentiment: {
|
|
||||||
type: "string",
|
|
||||||
description: "one of positive/negative/neutral",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
],
|
|
||||||
function_call: {
|
|
||||||
name: "extract_sentiment",
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
Here's another example of adding a function:
|
|
||||||
|
|
||||||
Before:
|
|
||||||
|
|
||||||
definePrompt("openai/ChatCompletion", {
|
|
||||||
model: "gpt-3.5-turbo",
|
|
||||||
messages: [
|
|
||||||
{
|
|
||||||
role: "user",
|
|
||||||
content: \`Here is the title and body of a reddit post I am interested in:
|
|
||||||
|
|
||||||
title: \${scenario.title}
|
|
||||||
body: \${scenario.body}
|
|
||||||
|
|
||||||
On a scale from 1 to 3, how likely is it that the person writing this post has the following need? If you are not sure, make your best guess, or answer 1.
|
|
||||||
|
|
||||||
Need: \${scenario.need}
|
|
||||||
|
|
||||||
Answer one integer between 1 and 3.\`,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
temperature: 0,
|
|
||||||
});
|
|
||||||
|
|
||||||
After:
|
|
||||||
|
|
||||||
definePrompt("openai/ChatCompletion", {
|
|
||||||
model: "gpt-3.5-turbo",
|
|
||||||
messages: [
|
|
||||||
{
|
|
||||||
role: "user",
|
|
||||||
content: \`Title: \${scenario.title}
|
|
||||||
Body: \${scenario.body}
|
|
||||||
|
|
||||||
Need: \${scenario.need}
|
|
||||||
|
|
||||||
Rate likelihood on 1-3 scale.\`,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
temperature: 0,
|
|
||||||
functions: [
|
|
||||||
{
|
|
||||||
name: "score_post",
|
|
||||||
parameters: {
|
|
||||||
type: "object",
|
|
||||||
properties: {
|
|
||||||
score: {
|
|
||||||
type: "number",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
],
|
|
||||||
function_call: {
|
|
||||||
name: "score_post",
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
Another example
|
|
||||||
|
|
||||||
Before:
|
|
||||||
|
|
||||||
definePrompt("openai/ChatCompletion", {
|
|
||||||
model: "gpt-3.5-turbo",
|
|
||||||
stream: true,
|
|
||||||
messages: [
|
|
||||||
{
|
|
||||||
role: "system",
|
|
||||||
content: \`Write 'Start experimenting!' in \${scenario.language}\`,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
});
|
|
||||||
|
|
||||||
After:
|
|
||||||
|
|
||||||
definePrompt("openai/ChatCompletion", {
|
|
||||||
model: "gpt-3.5-turbo",
|
|
||||||
messages: [
|
|
||||||
{
|
|
||||||
role: "system",
|
|
||||||
content: \`Write 'Start experimenting!' in \${scenario.language}\`,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
functions: [
|
|
||||||
{
|
|
||||||
name: "write_in_language",
|
|
||||||
parameters: {
|
|
||||||
type: "object",
|
|
||||||
properties: {
|
|
||||||
text: {
|
|
||||||
type: "string",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
],
|
|
||||||
function_call: {
|
|
||||||
name: "write_in_language",
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
Add an OpenAI function that takes one or more nested parameters that match the expected output from this prompt.`,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"replicate/llama2": {},
|
|
||||||
};
|
|
||||||
@@ -3,28 +3,34 @@ import { type PromptVariant } from "../OutputsTable/types";
|
|||||||
import { api } from "~/utils/api";
|
import { api } from "~/utils/api";
|
||||||
import { RiDraggable } from "react-icons/ri";
|
import { RiDraggable } from "react-icons/ri";
|
||||||
import { useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
|
import { useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
|
||||||
import { HStack, Icon, Text, GridItem } from "@chakra-ui/react"; // Changed here
|
import { HStack, Icon, Text, GridItem, type GridItemProps } from "@chakra-ui/react"; // Changed here
|
||||||
import { cellPadding, headerMinHeight } from "../constants";
|
import { cellPadding, headerMinHeight } from "../constants";
|
||||||
import AutoResizeTextArea from "../AutoResizeTextArea";
|
import AutoResizeTextArea from "../AutoResizeTextArea";
|
||||||
import { stickyHeaderStyle } from "../OutputsTable/styles";
|
import { stickyHeaderStyle } from "../OutputsTable/styles";
|
||||||
import VariantHeaderMenuButton from "./VariantHeaderMenuButton";
|
import VariantHeaderMenuButton from "./VariantHeaderMenuButton";
|
||||||
|
|
||||||
export default function VariantHeader(props: { variant: PromptVariant; canHide: boolean }) {
|
export default function VariantHeader(
|
||||||
|
allProps: {
|
||||||
|
variant: PromptVariant;
|
||||||
|
canHide: boolean;
|
||||||
|
} & GridItemProps,
|
||||||
|
) {
|
||||||
|
const { variant, canHide, ...gridItemProps } = allProps;
|
||||||
const { canModify } = useExperimentAccess();
|
const { canModify } = useExperimentAccess();
|
||||||
const utils = api.useContext();
|
const utils = api.useContext();
|
||||||
const [isDragTarget, setIsDragTarget] = useState(false);
|
const [isDragTarget, setIsDragTarget] = useState(false);
|
||||||
const [isInputHovered, setIsInputHovered] = useState(false);
|
const [isInputHovered, setIsInputHovered] = useState(false);
|
||||||
const [label, setLabel] = useState(props.variant.label);
|
const [label, setLabel] = useState(variant.label);
|
||||||
|
|
||||||
const updateMutation = api.promptVariants.update.useMutation();
|
const updateMutation = api.promptVariants.update.useMutation();
|
||||||
const [onSaveLabel] = useHandledAsyncCallback(async () => {
|
const [onSaveLabel] = useHandledAsyncCallback(async () => {
|
||||||
if (label && label !== props.variant.label) {
|
if (label && label !== variant.label) {
|
||||||
await updateMutation.mutateAsync({
|
await updateMutation.mutateAsync({
|
||||||
id: props.variant.id,
|
id: variant.id,
|
||||||
updates: { label: label },
|
updates: { label: label },
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}, [updateMutation, props.variant.id, props.variant.label, label]);
|
}, [updateMutation, variant.id, variant.label, label]);
|
||||||
|
|
||||||
const reorderMutation = api.promptVariants.reorder.useMutation();
|
const reorderMutation = api.promptVariants.reorder.useMutation();
|
||||||
const [onReorder] = useHandledAsyncCallback(
|
const [onReorder] = useHandledAsyncCallback(
|
||||||
@@ -32,7 +38,7 @@ export default function VariantHeader(props: { variant: PromptVariant; canHide:
|
|||||||
e.preventDefault();
|
e.preventDefault();
|
||||||
setIsDragTarget(false);
|
setIsDragTarget(false);
|
||||||
const draggedId = e.dataTransfer.getData("text/plain");
|
const draggedId = e.dataTransfer.getData("text/plain");
|
||||||
const droppedId = props.variant.id;
|
const droppedId = variant.id;
|
||||||
if (!draggedId || !droppedId || draggedId === droppedId) return;
|
if (!draggedId || !droppedId || draggedId === droppedId) return;
|
||||||
await reorderMutation.mutateAsync({
|
await reorderMutation.mutateAsync({
|
||||||
draggedId,
|
draggedId,
|
||||||
@@ -40,16 +46,16 @@ export default function VariantHeader(props: { variant: PromptVariant; canHide:
|
|||||||
});
|
});
|
||||||
await utils.promptVariants.list.invalidate();
|
await utils.promptVariants.list.invalidate();
|
||||||
},
|
},
|
||||||
[reorderMutation, props.variant.id],
|
[reorderMutation, variant.id],
|
||||||
);
|
);
|
||||||
|
|
||||||
const [menuOpen, setMenuOpen] = useState(false);
|
const [menuOpen, setMenuOpen] = useState(false);
|
||||||
|
|
||||||
if (!canModify) {
|
if (!canModify) {
|
||||||
return (
|
return (
|
||||||
<GridItem padding={0} sx={stickyHeaderStyle} borderTopWidth={1}>
|
<GridItem padding={0} sx={stickyHeaderStyle} borderTopWidth={1} {...gridItemProps}>
|
||||||
<Text fontSize={16} fontWeight="bold" px={cellPadding.x} py={cellPadding.y}>
|
<Text fontSize={16} fontWeight="bold" px={cellPadding.x} py={cellPadding.y}>
|
||||||
{props.variant.label}
|
{variant.label}
|
||||||
</Text>
|
</Text>
|
||||||
</GridItem>
|
</GridItem>
|
||||||
);
|
);
|
||||||
@@ -64,6 +70,7 @@ export default function VariantHeader(props: { variant: PromptVariant; canHide:
|
|||||||
zIndex: menuOpen ? "dropdown" : stickyHeaderStyle.zIndex,
|
zIndex: menuOpen ? "dropdown" : stickyHeaderStyle.zIndex,
|
||||||
}}
|
}}
|
||||||
borderTopWidth={1}
|
borderTopWidth={1}
|
||||||
|
{...gridItemProps}
|
||||||
>
|
>
|
||||||
<HStack
|
<HStack
|
||||||
spacing={4}
|
spacing={4}
|
||||||
@@ -71,7 +78,7 @@ export default function VariantHeader(props: { variant: PromptVariant; canHide:
|
|||||||
minH={headerMinHeight}
|
minH={headerMinHeight}
|
||||||
draggable={!isInputHovered}
|
draggable={!isInputHovered}
|
||||||
onDragStart={(e) => {
|
onDragStart={(e) => {
|
||||||
e.dataTransfer.setData("text/plain", props.variant.id);
|
e.dataTransfer.setData("text/plain", variant.id);
|
||||||
e.currentTarget.style.opacity = "0.4";
|
e.currentTarget.style.opacity = "0.4";
|
||||||
}}
|
}}
|
||||||
onDragEnd={(e) => {
|
onDragEnd={(e) => {
|
||||||
@@ -112,8 +119,8 @@ export default function VariantHeader(props: { variant: PromptVariant; canHide:
|
|||||||
onMouseLeave={() => setIsInputHovered(false)}
|
onMouseLeave={() => setIsInputHovered(false)}
|
||||||
/>
|
/>
|
||||||
<VariantHeaderMenuButton
|
<VariantHeaderMenuButton
|
||||||
variant={props.variant}
|
variant={variant}
|
||||||
canHide={props.canHide}
|
canHide={canHide}
|
||||||
menuOpen={menuOpen}
|
menuOpen={menuOpen}
|
||||||
setMenuOpen={setMenuOpen}
|
setMenuOpen={setMenuOpen}
|
||||||
/>
|
/>
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
import { type PromptVariant } from "../OutputsTable/types";
|
import { type PromptVariant } from "../OutputsTable/types";
|
||||||
import { api } from "~/utils/api";
|
import { api } from "~/utils/api";
|
||||||
import { useHandledAsyncCallback } from "~/utils/hooks";
|
import { useHandledAsyncCallback, useVisibleScenarioIds } from "~/utils/hooks";
|
||||||
import {
|
import {
|
||||||
Button,
|
|
||||||
Icon,
|
Icon,
|
||||||
Menu,
|
Menu,
|
||||||
MenuButton,
|
MenuButton,
|
||||||
@@ -11,6 +10,7 @@ import {
|
|||||||
MenuDivider,
|
MenuDivider,
|
||||||
Text,
|
Text,
|
||||||
Spinner,
|
Spinner,
|
||||||
|
IconButton,
|
||||||
} from "@chakra-ui/react";
|
} from "@chakra-ui/react";
|
||||||
import { BsFillTrashFill, BsGear, BsStars } from "react-icons/bs";
|
import { BsFillTrashFill, BsGear, BsStars } from "react-icons/bs";
|
||||||
import { FaRegClone } from "react-icons/fa";
|
import { FaRegClone } from "react-icons/fa";
|
||||||
@@ -33,11 +33,13 @@ export default function VariantHeaderMenuButton({
|
|||||||
const utils = api.useContext();
|
const utils = api.useContext();
|
||||||
|
|
||||||
const duplicateMutation = api.promptVariants.create.useMutation();
|
const duplicateMutation = api.promptVariants.create.useMutation();
|
||||||
|
const visibleScenarios = useVisibleScenarioIds();
|
||||||
|
|
||||||
const [duplicateVariant, duplicationInProgress] = useHandledAsyncCallback(async () => {
|
const [duplicateVariant, duplicationInProgress] = useHandledAsyncCallback(async () => {
|
||||||
await duplicateMutation.mutateAsync({
|
await duplicateMutation.mutateAsync({
|
||||||
experimentId: variant.experimentId,
|
experimentId: variant.experimentId,
|
||||||
variantId: variant.id,
|
variantId: variant.id,
|
||||||
|
streamScenarios: visibleScenarios,
|
||||||
});
|
});
|
||||||
await utils.promptVariants.list.invalidate();
|
await utils.promptVariants.list.invalidate();
|
||||||
}, [duplicateMutation, variant.experimentId, variant.id]);
|
}, [duplicateMutation, variant.experimentId, variant.id]);
|
||||||
@@ -56,15 +58,12 @@ export default function VariantHeaderMenuButton({
|
|||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
<Menu isOpen={menuOpen} onOpen={() => setMenuOpen(true)} onClose={() => setMenuOpen(false)}>
|
<Menu isOpen={menuOpen} onOpen={() => setMenuOpen(true)} onClose={() => setMenuOpen(false)}>
|
||||||
{duplicationInProgress ? (
|
<MenuButton
|
||||||
<Spinner boxSize={4} mx={3} my={3} />
|
as={IconButton}
|
||||||
) : (
|
variant="ghost"
|
||||||
<MenuButton>
|
aria-label="Edit Scenarios"
|
||||||
<Button variant="ghost">
|
icon={<Icon as={duplicationInProgress ? Spinner : BsGear} />}
|
||||||
<Icon as={BsGear} />
|
/>
|
||||||
</Button>
|
|
||||||
</MenuButton>
|
|
||||||
)}
|
|
||||||
|
|
||||||
<MenuList mt={-3} fontSize="md">
|
<MenuList mt={-3} fontSize="md">
|
||||||
<MenuItem icon={<Icon as={FaRegClone} boxSize={4} w={5} />} onClick={duplicateVariant}>
|
<MenuItem icon={<Icon as={FaRegClone} boxSize={4} w={5} />} onClick={duplicateVariant}>
|
||||||
|
|||||||
@@ -1,4 +1,13 @@
|
|||||||
import { HStack, Icon, VStack, Text, Divider, Spinner, AspectRatio } from "@chakra-ui/react";
|
import {
|
||||||
|
HStack,
|
||||||
|
Icon,
|
||||||
|
VStack,
|
||||||
|
Text,
|
||||||
|
Divider,
|
||||||
|
Spinner,
|
||||||
|
AspectRatio,
|
||||||
|
SkeletonText,
|
||||||
|
} from "@chakra-ui/react";
|
||||||
import { RiFlaskLine } from "react-icons/ri";
|
import { RiFlaskLine } from "react-icons/ri";
|
||||||
import { formatTimePast } from "~/utils/dayjs";
|
import { formatTimePast } from "~/utils/dayjs";
|
||||||
import Link from "next/link";
|
import Link from "next/link";
|
||||||
@@ -93,3 +102,13 @@ export const NewExperimentCard = () => {
|
|||||||
</AspectRatio>
|
</AspectRatio>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export const ExperimentCardSkeleton = () => (
|
||||||
|
<AspectRatio ratio={1.2} w="full">
|
||||||
|
<VStack align="center" borderColor="gray.200" borderWidth={1} p={4} bg="gray.50">
|
||||||
|
<SkeletonText noOfLines={1} w="80%" />
|
||||||
|
<SkeletonText noOfLines={2} w="60%" />
|
||||||
|
<SkeletonText noOfLines={1} w="80%" />
|
||||||
|
</VStack>
|
||||||
|
</AspectRatio>
|
||||||
|
);
|
||||||
|
|||||||
57
src/components/experiments/HeaderButtons/DeleteDialog.tsx
Normal file
57
src/components/experiments/HeaderButtons/DeleteDialog.tsx
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
import {
|
||||||
|
Button,
|
||||||
|
AlertDialog,
|
||||||
|
AlertDialogBody,
|
||||||
|
AlertDialogFooter,
|
||||||
|
AlertDialogHeader,
|
||||||
|
AlertDialogContent,
|
||||||
|
AlertDialogOverlay,
|
||||||
|
} from "@chakra-ui/react";
|
||||||
|
|
||||||
|
import { useRouter } from "next/router";
|
||||||
|
import { useRef } from "react";
|
||||||
|
import { api } from "~/utils/api";
|
||||||
|
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
|
||||||
|
|
||||||
|
export const DeleteDialog = ({ onClose }: { onClose: () => void }) => {
|
||||||
|
const experiment = useExperiment();
|
||||||
|
const deleteMutation = api.experiments.delete.useMutation();
|
||||||
|
const utils = api.useContext();
|
||||||
|
const router = useRouter();
|
||||||
|
|
||||||
|
const cancelRef = useRef<HTMLButtonElement>(null);
|
||||||
|
|
||||||
|
const [onDeleteConfirm] = useHandledAsyncCallback(async () => {
|
||||||
|
if (!experiment.data?.id) return;
|
||||||
|
await deleteMutation.mutateAsync({ id: experiment.data.id });
|
||||||
|
await utils.experiments.list.invalidate();
|
||||||
|
await router.push({ pathname: "/experiments" });
|
||||||
|
onClose();
|
||||||
|
}, [deleteMutation, experiment.data?.id, router]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<AlertDialog isOpen leastDestructiveRef={cancelRef} onClose={onClose}>
|
||||||
|
<AlertDialogOverlay>
|
||||||
|
<AlertDialogContent>
|
||||||
|
<AlertDialogHeader fontSize="lg" fontWeight="bold">
|
||||||
|
Delete Experiment
|
||||||
|
</AlertDialogHeader>
|
||||||
|
|
||||||
|
<AlertDialogBody>
|
||||||
|
If you delete this experiment all the associated prompts and scenarios will be deleted
|
||||||
|
as well. Are you sure?
|
||||||
|
</AlertDialogBody>
|
||||||
|
|
||||||
|
<AlertDialogFooter>
|
||||||
|
<Button ref={cancelRef} onClick={onClose}>
|
||||||
|
Cancel
|
||||||
|
</Button>
|
||||||
|
<Button colorScheme="red" onClick={onDeleteConfirm} ml={3}>
|
||||||
|
Delete
|
||||||
|
</Button>
|
||||||
|
</AlertDialogFooter>
|
||||||
|
</AlertDialogContent>
|
||||||
|
</AlertDialogOverlay>
|
||||||
|
</AlertDialog>
|
||||||
|
);
|
||||||
|
};
|
||||||
42
src/components/experiments/HeaderButtons/HeaderButtons.tsx
Normal file
42
src/components/experiments/HeaderButtons/HeaderButtons.tsx
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
import { Button, HStack, Icon, Spinner, Text } from "@chakra-ui/react";
|
||||||
|
import { useOnForkButtonPressed } from "./useOnForkButtonPressed";
|
||||||
|
import { useExperiment } from "~/utils/hooks";
|
||||||
|
import { BsGearFill } from "react-icons/bs";
|
||||||
|
import { TbGitFork } from "react-icons/tb";
|
||||||
|
import { useAppStore } from "~/state/store";
|
||||||
|
|
||||||
|
export const HeaderButtons = () => {
|
||||||
|
const experiment = useExperiment();
|
||||||
|
|
||||||
|
const canModify = experiment.data?.access.canModify ?? false;
|
||||||
|
|
||||||
|
const { onForkButtonPressed, isForking } = useOnForkButtonPressed();
|
||||||
|
|
||||||
|
const openDrawer = useAppStore((s) => s.openDrawer);
|
||||||
|
|
||||||
|
if (experiment.isLoading) return null;
|
||||||
|
|
||||||
|
return (
|
||||||
|
<HStack spacing={0} mt={{ base: 2, md: 0 }}>
|
||||||
|
<Button
|
||||||
|
onClick={onForkButtonPressed}
|
||||||
|
mr={4}
|
||||||
|
colorScheme={canModify ? undefined : "orange"}
|
||||||
|
bgColor={canModify ? undefined : "orange.400"}
|
||||||
|
minW={0}
|
||||||
|
variant={canModify ? "ghost" : "solid"}
|
||||||
|
>
|
||||||
|
{isForking ? <Spinner boxSize={5} /> : <Icon as={TbGitFork} boxSize={5} />}
|
||||||
|
<Text ml={2}>Fork</Text>
|
||||||
|
</Button>
|
||||||
|
{canModify && (
|
||||||
|
<Button variant={{ base: "solid", md: "ghost" }} onClick={openDrawer}>
|
||||||
|
<HStack>
|
||||||
|
<Icon as={BsGearFill} />
|
||||||
|
<Text>Settings</Text>
|
||||||
|
</HStack>
|
||||||
|
</Button>
|
||||||
|
)}
|
||||||
|
</HStack>
|
||||||
|
);
|
||||||
|
};
|
||||||
@@ -0,0 +1,30 @@
|
|||||||
|
import { useCallback } from "react";
|
||||||
|
import { api } from "~/utils/api";
|
||||||
|
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
|
||||||
|
import { signIn, useSession } from "next-auth/react";
|
||||||
|
import { useRouter } from "next/router";
|
||||||
|
|
||||||
|
export const useOnForkButtonPressed = () => {
|
||||||
|
const router = useRouter();
|
||||||
|
|
||||||
|
const user = useSession().data;
|
||||||
|
const experiment = useExperiment();
|
||||||
|
|
||||||
|
const forkMutation = api.experiments.fork.useMutation();
|
||||||
|
|
||||||
|
const [onFork, isForking] = useHandledAsyncCallback(async () => {
|
||||||
|
if (!experiment.data?.id) return;
|
||||||
|
const forkedExperimentId = await forkMutation.mutateAsync({ id: experiment.data.id });
|
||||||
|
await router.push({ pathname: "/experiments/[id]", query: { id: forkedExperimentId } });
|
||||||
|
}, [forkMutation, experiment.data?.id, router]);
|
||||||
|
|
||||||
|
const onForkButtonPressed = useCallback(() => {
|
||||||
|
if (user === null) {
|
||||||
|
signIn("github").catch(console.error);
|
||||||
|
} else {
|
||||||
|
onFork();
|
||||||
|
}
|
||||||
|
}, [onFork, user]);
|
||||||
|
|
||||||
|
return { onForkButtonPressed, isForking };
|
||||||
|
};
|
||||||
@@ -2,6 +2,7 @@ import { type JsonValue } from "type-fest";
|
|||||||
import { type SupportedModel } from ".";
|
import { type SupportedModel } from ".";
|
||||||
import { type FrontendModelProvider } from "../types";
|
import { type FrontendModelProvider } from "../types";
|
||||||
import { type ChatCompletion } from "openai/resources/chat";
|
import { type ChatCompletion } from "openai/resources/chat";
|
||||||
|
import { refinementActions } from "./refinementActions";
|
||||||
|
|
||||||
const frontendModelProvider: FrontendModelProvider<SupportedModel, ChatCompletion> = {
|
const frontendModelProvider: FrontendModelProvider<SupportedModel, ChatCompletion> = {
|
||||||
name: "OpenAI ChatCompletion",
|
name: "OpenAI ChatCompletion",
|
||||||
@@ -45,6 +46,8 @@ const frontendModelProvider: FrontendModelProvider<SupportedModel, ChatCompletio
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
||||||
|
refinementActions,
|
||||||
|
|
||||||
normalizeOutput: (output) => {
|
normalizeOutput: (output) => {
|
||||||
const message = output.choices[0]?.message;
|
const message = output.choices[0]?.message;
|
||||||
if (!message)
|
if (!message)
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ const modelProvider: OpenaiChatModelProvider = {
|
|||||||
return null;
|
return null;
|
||||||
},
|
},
|
||||||
inputSchema: inputSchema as JSONSchema4,
|
inputSchema: inputSchema as JSONSchema4,
|
||||||
shouldStream: (input) => input.stream ?? false,
|
canStream: true,
|
||||||
getCompletion,
|
getCompletion,
|
||||||
...frontendModelProvider,
|
...frontendModelProvider,
|
||||||
};
|
};
|
||||||
|
|||||||
279
src/modelProviders/openai-ChatCompletion/refinementActions.ts
Normal file
279
src/modelProviders/openai-ChatCompletion/refinementActions.ts
Normal file
@@ -0,0 +1,279 @@
|
|||||||
|
import { TfiThought } from "react-icons/tfi";
|
||||||
|
import { type RefinementAction } from "../types";
|
||||||
|
import { VscJson } from "react-icons/vsc";
|
||||||
|
|
||||||
|
export const refinementActions: Record<string, RefinementAction> = {
|
||||||
|
"Add chain of thought": {
|
||||||
|
icon: VscJson,
|
||||||
|
description: "Asking the model to plan its answer can increase accuracy.",
|
||||||
|
instructions: `Adding chain of thought means asking the model to think about its answer before it gives it to you. This is useful for getting more accurate answers. Do not add an assistant message.
|
||||||
|
|
||||||
|
This is what a prompt looks like before adding chain of thought:
|
||||||
|
|
||||||
|
definePrompt("openai/ChatCompletion", {
|
||||||
|
model: "gpt-4",
|
||||||
|
stream: true,
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
role: "system",
|
||||||
|
content: \`Evaluate sentiment.\`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
role: "user",
|
||||||
|
content: \`This is the user's message: \${scenario.user_message}. Return "positive" or "negative" or "neutral"\`,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
});
|
||||||
|
|
||||||
|
This is what one looks like after adding chain of thought:
|
||||||
|
|
||||||
|
definePrompt("openai/ChatCompletion", {
|
||||||
|
model: "gpt-4",
|
||||||
|
stream: true,
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
role: "system",
|
||||||
|
content: \`Evaluate sentiment.\`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
role: "user",
|
||||||
|
content: \`This is the user's message: \${scenario.user_message}. Return "positive" or "negative" or "neutral". Explain your answer before you give a score, then return the score on a new line.\`,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
});
|
||||||
|
|
||||||
|
Here's another example:
|
||||||
|
|
||||||
|
Before:
|
||||||
|
|
||||||
|
definePrompt("openai/ChatCompletion", {
|
||||||
|
model: "gpt-3.5-turbo",
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
role: "user",
|
||||||
|
content: \`Title: \${scenario.title}
|
||||||
|
Body: \${scenario.body}
|
||||||
|
|
||||||
|
Need: \${scenario.need}
|
||||||
|
|
||||||
|
Rate likelihood on 1-3 scale.\`,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
temperature: 0,
|
||||||
|
functions: [
|
||||||
|
{
|
||||||
|
name: "score_post",
|
||||||
|
parameters: {
|
||||||
|
type: "object",
|
||||||
|
properties: {
|
||||||
|
score: {
|
||||||
|
type: "number",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
function_call: {
|
||||||
|
name: "score_post",
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
After:
|
||||||
|
|
||||||
|
definePrompt("openai/ChatCompletion", {
|
||||||
|
model: "gpt-3.5-turbo",
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
role: "user",
|
||||||
|
content: \`Title: \${scenario.title}
|
||||||
|
Body: \${scenario.body}
|
||||||
|
|
||||||
|
Need: \${scenario.need}
|
||||||
|
|
||||||
|
Rate likelihood on 1-3 scale. Provide an explanation, but always provide a score afterward.\`,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
temperature: 0,
|
||||||
|
functions: [
|
||||||
|
{
|
||||||
|
name: "score_post",
|
||||||
|
parameters: {
|
||||||
|
type: "object",
|
||||||
|
properties: {
|
||||||
|
explanation: {
|
||||||
|
type: "string",
|
||||||
|
}
|
||||||
|
score: {
|
||||||
|
type: "number",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
function_call: {
|
||||||
|
name: "score_post",
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
Add chain of thought to the original prompt.`,
|
||||||
|
},
|
||||||
|
"Convert to function call": {
|
||||||
|
icon: TfiThought,
|
||||||
|
description: "Use function calls to get output from the model in a more structured way.",
|
||||||
|
instructions: `OpenAI functions are a specialized way for an LLM to return output.
|
||||||
|
|
||||||
|
This is what a prompt looks like before adding a function:
|
||||||
|
|
||||||
|
definePrompt("openai/ChatCompletion", {
|
||||||
|
model: "gpt-4",
|
||||||
|
stream: true,
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
role: "system",
|
||||||
|
content: \`Evaluate sentiment.\`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
role: "user",
|
||||||
|
content: \`This is the user's message: \${scenario.user_message}. Return "positive" or "negative" or "neutral"\`,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
});
|
||||||
|
|
||||||
|
This is what one looks like after adding a function:
|
||||||
|
|
||||||
|
definePrompt("openai/ChatCompletion", {
|
||||||
|
model: "gpt-4",
|
||||||
|
stream: true,
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
role: "system",
|
||||||
|
content: "Evaluate sentiment.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
role: "user",
|
||||||
|
content: scenario.user_message,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
functions: [
|
||||||
|
{
|
||||||
|
name: "extract_sentiment",
|
||||||
|
parameters: {
|
||||||
|
type: "object", // parameters must always be an object with a properties key
|
||||||
|
properties: { // properties key is required
|
||||||
|
sentiment: {
|
||||||
|
type: "string",
|
||||||
|
description: "one of positive/negative/neutral",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
function_call: {
|
||||||
|
name: "extract_sentiment",
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
Here's another example of adding a function:
|
||||||
|
|
||||||
|
Before:
|
||||||
|
|
||||||
|
definePrompt("openai/ChatCompletion", {
|
||||||
|
model: "gpt-3.5-turbo",
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
role: "user",
|
||||||
|
content: \`Here is the title and body of a reddit post I am interested in:
|
||||||
|
|
||||||
|
title: \${scenario.title}
|
||||||
|
body: \${scenario.body}
|
||||||
|
|
||||||
|
On a scale from 1 to 3, how likely is it that the person writing this post has the following need? If you are not sure, make your best guess, or answer 1.
|
||||||
|
|
||||||
|
Need: \${scenario.need}
|
||||||
|
|
||||||
|
Answer one integer between 1 and 3.\`,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
temperature: 0,
|
||||||
|
});
|
||||||
|
|
||||||
|
After:
|
||||||
|
|
||||||
|
definePrompt("openai/ChatCompletion", {
|
||||||
|
model: "gpt-3.5-turbo",
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
role: "user",
|
||||||
|
content: \`Title: \${scenario.title}
|
||||||
|
Body: \${scenario.body}
|
||||||
|
|
||||||
|
Need: \${scenario.need}
|
||||||
|
|
||||||
|
Rate likelihood on 1-3 scale.\`,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
temperature: 0,
|
||||||
|
functions: [
|
||||||
|
{
|
||||||
|
name: "score_post",
|
||||||
|
parameters: {
|
||||||
|
type: "object",
|
||||||
|
properties: {
|
||||||
|
score: {
|
||||||
|
type: "number",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
function_call: {
|
||||||
|
name: "score_post",
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
Another example
|
||||||
|
|
||||||
|
Before:
|
||||||
|
|
||||||
|
definePrompt("openai/ChatCompletion", {
|
||||||
|
model: "gpt-3.5-turbo",
|
||||||
|
stream: true,
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
role: "system",
|
||||||
|
content: \`Write 'Start experimenting!' in \${scenario.language}\`,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
});
|
||||||
|
|
||||||
|
After:
|
||||||
|
|
||||||
|
definePrompt("openai/ChatCompletion", {
|
||||||
|
model: "gpt-3.5-turbo",
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
role: "system",
|
||||||
|
content: \`Write 'Start experimenting!' in \${scenario.language}\`,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
functions: [
|
||||||
|
{
|
||||||
|
name: "write_in_language",
|
||||||
|
parameters: {
|
||||||
|
type: "object",
|
||||||
|
properties: {
|
||||||
|
text: {
|
||||||
|
type: "string",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
function_call: {
|
||||||
|
name: "write_in_language",
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
Add an OpenAI function that takes one or more nested parameters that match the expected output from this prompt.`,
|
||||||
|
},
|
||||||
|
};
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
import { type SupportedModel, type ReplicateLlama2Output } from ".";
|
import { type SupportedModel, type ReplicateLlama2Output } from ".";
|
||||||
import { type FrontendModelProvider } from "../types";
|
import { type FrontendModelProvider } from "../types";
|
||||||
|
import { refinementActions } from "./refinementActions";
|
||||||
|
|
||||||
const frontendModelProvider: FrontendModelProvider<SupportedModel, ReplicateLlama2Output> = {
|
const frontendModelProvider: FrontendModelProvider<SupportedModel, ReplicateLlama2Output> = {
|
||||||
name: "Replicate Llama2",
|
name: "Replicate Llama2",
|
||||||
@@ -31,6 +32,8 @@ const frontendModelProvider: FrontendModelProvider<SupportedModel, ReplicateLlam
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
||||||
|
refinementActions,
|
||||||
|
|
||||||
normalizeOutput: (output) => {
|
normalizeOutput: (output) => {
|
||||||
return {
|
return {
|
||||||
type: "text",
|
type: "text",
|
||||||
|
|||||||
@@ -8,9 +8,9 @@ const replicate = new Replicate({
|
|||||||
});
|
});
|
||||||
|
|
||||||
const modelIds: Record<ReplicateLlama2Input["model"], string> = {
|
const modelIds: Record<ReplicateLlama2Input["model"], string> = {
|
||||||
"7b-chat": "3725a659b5afff1a0ba9bead5fac3899d998feaad00e07032ca2b0e35eb14f8a",
|
"7b-chat": "5ec5fdadd80ace49f5a2b2178cceeb9f2f77c493b85b1131002c26e6b2b13184",
|
||||||
"13b-chat": "5c785d117c5bcdd1928d5a9acb1ffa6272d6cf13fcb722e90886a0196633f9d3",
|
"13b-chat": "6b4da803a2382c08868c5af10a523892f38e2de1aafb2ee55b020d9efef2fdb8",
|
||||||
"70b-chat": "e951f18578850b652510200860fc4ea62b3b16fac280f83ff32282f87bbd2e48",
|
"70b-chat": "2d19859030ff705a87c746f7e96eea03aefb71f166725aee39692f1476566d48",
|
||||||
};
|
};
|
||||||
|
|
||||||
export async function getCompletion(
|
export async function getCompletion(
|
||||||
@@ -19,7 +19,7 @@ export async function getCompletion(
|
|||||||
): Promise<CompletionResponse<ReplicateLlama2Output>> {
|
): Promise<CompletionResponse<ReplicateLlama2Output>> {
|
||||||
const start = Date.now();
|
const start = Date.now();
|
||||||
|
|
||||||
const { model, stream, ...rest } = input;
|
const { model, ...rest } = input;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const prediction = await replicate.predictions.create({
|
const prediction = await replicate.predictions.create({
|
||||||
@@ -27,8 +27,6 @@ export async function getCompletion(
|
|||||||
input: rest,
|
input: rest,
|
||||||
});
|
});
|
||||||
|
|
||||||
console.log("stream?", onStream);
|
|
||||||
|
|
||||||
const interval = onStream
|
const interval = onStream
|
||||||
? // eslint-disable-next-line @typescript-eslint/no-misused-promises
|
? // eslint-disable-next-line @typescript-eslint/no-misused-promises
|
||||||
setInterval(async () => {
|
setInterval(async () => {
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ export type SupportedModel = (typeof supportedModels)[number];
|
|||||||
export type ReplicateLlama2Input = {
|
export type ReplicateLlama2Input = {
|
||||||
model: SupportedModel;
|
model: SupportedModel;
|
||||||
prompt: string;
|
prompt: string;
|
||||||
stream?: boolean;
|
|
||||||
max_length?: number;
|
max_length?: number;
|
||||||
temperature?: number;
|
temperature?: number;
|
||||||
top_p?: number;
|
top_p?: number;
|
||||||
@@ -38,31 +37,43 @@ const modelProvider: ReplicateLlama2Provider = {
|
|||||||
type: "string",
|
type: "string",
|
||||||
enum: supportedModels as unknown as string[],
|
enum: supportedModels as unknown as string[],
|
||||||
},
|
},
|
||||||
|
system_prompt: {
|
||||||
|
type: "string",
|
||||||
|
description:
|
||||||
|
"System prompt to send to Llama v2. This is prepended to the prompt and helps guide system behavior.",
|
||||||
|
},
|
||||||
prompt: {
|
prompt: {
|
||||||
type: "string",
|
type: "string",
|
||||||
|
description: "Prompt to send to Llama v2.",
|
||||||
},
|
},
|
||||||
stream: {
|
max_new_tokens: {
|
||||||
type: "boolean",
|
|
||||||
},
|
|
||||||
max_length: {
|
|
||||||
type: "number",
|
type: "number",
|
||||||
|
description:
|
||||||
|
"Maximum number of tokens to generate. A word is generally 2-3 tokens (minimum: 1)",
|
||||||
},
|
},
|
||||||
temperature: {
|
temperature: {
|
||||||
type: "number",
|
type: "number",
|
||||||
|
description:
|
||||||
|
"Adjusts randomness of outputs, greater than 1 is random and 0 is deterministic, 0.75 is a good starting value. (minimum: 0.01; maximum: 5)",
|
||||||
},
|
},
|
||||||
top_p: {
|
top_p: {
|
||||||
type: "number",
|
type: "number",
|
||||||
|
description:
|
||||||
|
"When decoding text, samples from the top p percentage of most likely tokens; lower to ignore less likely tokens (minimum: 0.01; maximum: 1)",
|
||||||
},
|
},
|
||||||
repetition_penalty: {
|
repetition_penalty: {
|
||||||
type: "number",
|
type: "number",
|
||||||
|
description:
|
||||||
|
"Penalty for repeated words in generated text; 1 is no penalty, values greater than 1 discourage repetition, less than 1 encourage it. (minimum: 0.01; maximum: 5)",
|
||||||
},
|
},
|
||||||
debug: {
|
debug: {
|
||||||
type: "boolean",
|
type: "boolean",
|
||||||
|
description: "provide debugging output in logs",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
required: ["model", "prompt"],
|
required: ["model", "prompt"],
|
||||||
},
|
},
|
||||||
shouldStream: (input) => input.stream ?? false,
|
canStream: true,
|
||||||
getCompletion,
|
getCompletion,
|
||||||
...frontendModelProvider,
|
...frontendModelProvider,
|
||||||
};
|
};
|
||||||
|
|||||||
3
src/modelProviders/replicate-llama2/refinementActions.ts
Normal file
3
src/modelProviders/replicate-llama2/refinementActions.ts
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
import { type RefinementAction } from "../types";
|
||||||
|
|
||||||
|
export const refinementActions: Record<string, RefinementAction> = {};
|
||||||
@@ -1,31 +1,35 @@
|
|||||||
import { type JSONSchema4 } from "json-schema";
|
import { type JSONSchema4 } from "json-schema";
|
||||||
|
import { type IconType } from "react-icons";
|
||||||
import { type JsonValue } from "type-fest";
|
import { type JsonValue } from "type-fest";
|
||||||
import { z } from "zod";
|
import { z } from "zod";
|
||||||
|
|
||||||
const ZodSupportedProvider = z.union([
|
export const ZodSupportedProvider = z.union([
|
||||||
z.literal("openai/ChatCompletion"),
|
z.literal("openai/ChatCompletion"),
|
||||||
z.literal("replicate/llama2"),
|
z.literal("replicate/llama2"),
|
||||||
]);
|
]);
|
||||||
|
|
||||||
export type SupportedProvider = z.infer<typeof ZodSupportedProvider>;
|
export type SupportedProvider = z.infer<typeof ZodSupportedProvider>;
|
||||||
|
|
||||||
export const ZodModel = z.object({
|
export type Model = {
|
||||||
name: z.string(),
|
name: string;
|
||||||
contextWindow: z.number(),
|
contextWindow: number;
|
||||||
promptTokenPrice: z.number().optional(),
|
promptTokenPrice?: number;
|
||||||
completionTokenPrice: z.number().optional(),
|
completionTokenPrice?: number;
|
||||||
pricePerSecond: z.number().optional(),
|
pricePerSecond?: number;
|
||||||
speed: z.union([z.literal("fast"), z.literal("medium"), z.literal("slow")]),
|
speed: "fast" | "medium" | "slow";
|
||||||
provider: ZodSupportedProvider,
|
provider: SupportedProvider;
|
||||||
description: z.string().optional(),
|
description?: string;
|
||||||
learnMoreUrl: z.string().optional(),
|
learnMoreUrl?: string;
|
||||||
});
|
};
|
||||||
|
|
||||||
export type Model = z.infer<typeof ZodModel>;
|
export type ProviderModel = { provider: z.infer<typeof ZodSupportedProvider>; model: string };
|
||||||
|
|
||||||
|
export type RefinementAction = { icon?: IconType; description: string; instructions: string };
|
||||||
|
|
||||||
export type FrontendModelProvider<SupportedModels extends string, OutputSchema> = {
|
export type FrontendModelProvider<SupportedModels extends string, OutputSchema> = {
|
||||||
name: string;
|
name: string;
|
||||||
models: Record<SupportedModels, Model>;
|
models: Record<SupportedModels, Model>;
|
||||||
|
refinementActions?: Record<string, RefinementAction>;
|
||||||
|
|
||||||
normalizeOutput: (output: OutputSchema) => NormalizedOutput;
|
normalizeOutput: (output: OutputSchema) => NormalizedOutput;
|
||||||
};
|
};
|
||||||
@@ -44,7 +48,7 @@ export type CompletionResponse<T> =
|
|||||||
|
|
||||||
export type ModelProvider<SupportedModels extends string, InputSchema, OutputSchema> = {
|
export type ModelProvider<SupportedModels extends string, InputSchema, OutputSchema> = {
|
||||||
getModel: (input: InputSchema) => SupportedModels | null;
|
getModel: (input: InputSchema) => SupportedModels | null;
|
||||||
shouldStream: (input: InputSchema) => boolean;
|
canStream: boolean;
|
||||||
inputSchema: JSONSchema4;
|
inputSchema: JSONSchema4;
|
||||||
getCompletion: (
|
getCompletion: (
|
||||||
input: InputSchema,
|
input: InputSchema,
|
||||||
|
|||||||
@@ -7,6 +7,8 @@ import "~/utils/analytics";
|
|||||||
import Head from "next/head";
|
import Head from "next/head";
|
||||||
import { ChakraThemeProvider } from "~/theme/ChakraThemeProvider";
|
import { ChakraThemeProvider } from "~/theme/ChakraThemeProvider";
|
||||||
import { SyncAppStore } from "~/state/sync";
|
import { SyncAppStore } from "~/state/sync";
|
||||||
|
import NextAdapterApp from "next-query-params/app";
|
||||||
|
import { QueryParamProvider } from "use-query-params";
|
||||||
|
|
||||||
const MyApp: AppType<{ session: Session | null }> = ({
|
const MyApp: AppType<{ session: Session | null }> = ({
|
||||||
Component,
|
Component,
|
||||||
@@ -24,7 +26,9 @@ const MyApp: AppType<{ session: Session | null }> = ({
|
|||||||
<SyncAppStore />
|
<SyncAppStore />
|
||||||
<Favicon />
|
<Favicon />
|
||||||
<ChakraThemeProvider>
|
<ChakraThemeProvider>
|
||||||
<Component {...pageProps} />
|
<QueryParamProvider adapter={NextAdapterApp}>
|
||||||
|
<Component {...pageProps} />
|
||||||
|
</QueryParamProvider>
|
||||||
</ChakraThemeProvider>
|
</ChakraThemeProvider>
|
||||||
</SessionProvider>
|
</SessionProvider>
|
||||||
</>
|
</>
|
||||||
|
|||||||
@@ -2,106 +2,37 @@ import {
|
|||||||
Box,
|
Box,
|
||||||
Breadcrumb,
|
Breadcrumb,
|
||||||
BreadcrumbItem,
|
BreadcrumbItem,
|
||||||
Button,
|
|
||||||
Center,
|
Center,
|
||||||
Flex,
|
Flex,
|
||||||
Icon,
|
Icon,
|
||||||
Input,
|
Input,
|
||||||
AlertDialog,
|
|
||||||
AlertDialogBody,
|
|
||||||
AlertDialogFooter,
|
|
||||||
AlertDialogHeader,
|
|
||||||
AlertDialogContent,
|
|
||||||
AlertDialogOverlay,
|
|
||||||
useDisclosure,
|
|
||||||
Text,
|
Text,
|
||||||
HStack,
|
|
||||||
VStack,
|
VStack,
|
||||||
} from "@chakra-ui/react";
|
} from "@chakra-ui/react";
|
||||||
import Link from "next/link";
|
import Link from "next/link";
|
||||||
|
|
||||||
import { useRouter } from "next/router";
|
import { useRouter } from "next/router";
|
||||||
import { useState, useEffect, useRef } from "react";
|
import { useState, useEffect } from "react";
|
||||||
import { BsGearFill, BsTrash } from "react-icons/bs";
|
|
||||||
import { RiFlaskLine } from "react-icons/ri";
|
import { RiFlaskLine } from "react-icons/ri";
|
||||||
import OutputsTable from "~/components/OutputsTable";
|
import OutputsTable from "~/components/OutputsTable";
|
||||||
import SettingsDrawer from "~/components/OutputsTable/SettingsDrawer";
|
import ExperimentSettingsDrawer from "~/components/ExperimentSettingsDrawer/ExperimentSettingsDrawer";
|
||||||
import AppShell from "~/components/nav/AppShell";
|
import AppShell from "~/components/nav/AppShell";
|
||||||
import { api } from "~/utils/api";
|
import { api } from "~/utils/api";
|
||||||
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
|
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
|
||||||
import { useAppStore } from "~/state/store";
|
import { useAppStore } from "~/state/store";
|
||||||
import { useSyncVariantEditor } from "~/state/sync";
|
import { useSyncVariantEditor } from "~/state/sync";
|
||||||
|
import { HeaderButtons } from "~/components/experiments/HeaderButtons/HeaderButtons";
|
||||||
const DeleteButton = () => {
|
|
||||||
const experiment = useExperiment();
|
|
||||||
const mutation = api.experiments.delete.useMutation();
|
|
||||||
const utils = api.useContext();
|
|
||||||
const router = useRouter();
|
|
||||||
|
|
||||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
|
||||||
const cancelRef = useRef<HTMLButtonElement>(null);
|
|
||||||
|
|
||||||
const [onDeleteConfirm] = useHandledAsyncCallback(async () => {
|
|
||||||
if (!experiment.data?.id) return;
|
|
||||||
await mutation.mutateAsync({ id: experiment.data.id });
|
|
||||||
await utils.experiments.list.invalidate();
|
|
||||||
await router.push({ pathname: "/experiments" });
|
|
||||||
onClose();
|
|
||||||
}, [mutation, experiment.data?.id, router]);
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
useAppStore.getState().sharedVariantEditor.loadMonaco().catch(console.error);
|
|
||||||
});
|
|
||||||
|
|
||||||
return (
|
|
||||||
<>
|
|
||||||
<Button
|
|
||||||
size="sm"
|
|
||||||
variant={{ base: "outline", lg: "ghost" }}
|
|
||||||
colorScheme="gray"
|
|
||||||
fontWeight="normal"
|
|
||||||
onClick={onOpen}
|
|
||||||
>
|
|
||||||
<Icon as={BsTrash} boxSize={4} color="gray.600" />
|
|
||||||
<Text display={{ base: "none", lg: "block" }} ml={2}>
|
|
||||||
Delete Experiment
|
|
||||||
</Text>
|
|
||||||
</Button>
|
|
||||||
|
|
||||||
<AlertDialog isOpen={isOpen} leastDestructiveRef={cancelRef} onClose={onClose}>
|
|
||||||
<AlertDialogOverlay>
|
|
||||||
<AlertDialogContent>
|
|
||||||
<AlertDialogHeader fontSize="lg" fontWeight="bold">
|
|
||||||
Delete Experiment
|
|
||||||
</AlertDialogHeader>
|
|
||||||
|
|
||||||
<AlertDialogBody>
|
|
||||||
If you delete this experiment all the associated prompts and scenarios will be deleted
|
|
||||||
as well. Are you sure?
|
|
||||||
</AlertDialogBody>
|
|
||||||
|
|
||||||
<AlertDialogFooter>
|
|
||||||
<Button ref={cancelRef} onClick={onClose}>
|
|
||||||
Cancel
|
|
||||||
</Button>
|
|
||||||
<Button colorScheme="red" onClick={onDeleteConfirm} ml={3}>
|
|
||||||
Delete
|
|
||||||
</Button>
|
|
||||||
</AlertDialogFooter>
|
|
||||||
</AlertDialogContent>
|
|
||||||
</AlertDialogOverlay>
|
|
||||||
</AlertDialog>
|
|
||||||
</>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default function Experiment() {
|
export default function Experiment() {
|
||||||
const router = useRouter();
|
const router = useRouter();
|
||||||
const experiment = useExperiment();
|
const experiment = useExperiment();
|
||||||
const utils = api.useContext();
|
const utils = api.useContext();
|
||||||
const openDrawer = useAppStore((s) => s.openDrawer);
|
|
||||||
useSyncVariantEditor();
|
useSyncVariantEditor();
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
useAppStore.getState().sharedVariantEditor.loadMonaco().catch(console.error);
|
||||||
|
});
|
||||||
|
|
||||||
const [label, setLabel] = useState(experiment.data?.label || "");
|
const [label, setLabel] = useState(experiment.data?.label || "");
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
setLabel(experiment.data?.label || "");
|
setLabel(experiment.data?.label || "");
|
||||||
@@ -138,7 +69,7 @@ export default function Experiment() {
|
|||||||
py={2}
|
py={2}
|
||||||
w="full"
|
w="full"
|
||||||
direction={{ base: "column", sm: "row" }}
|
direction={{ base: "column", sm: "row" }}
|
||||||
alignItems="flex-start"
|
alignItems={{ base: "flex-start", sm: "center" }}
|
||||||
>
|
>
|
||||||
<Breadcrumb flex={1}>
|
<Breadcrumb flex={1}>
|
||||||
<BreadcrumbItem>
|
<BreadcrumbItem>
|
||||||
@@ -171,25 +102,9 @@ export default function Experiment() {
|
|||||||
)}
|
)}
|
||||||
</BreadcrumbItem>
|
</BreadcrumbItem>
|
||||||
</Breadcrumb>
|
</Breadcrumb>
|
||||||
{canModify && (
|
<HeaderButtons />
|
||||||
<HStack>
|
|
||||||
<Button
|
|
||||||
size="sm"
|
|
||||||
variant={{ base: "outline", lg: "ghost" }}
|
|
||||||
colorScheme="gray"
|
|
||||||
fontWeight="normal"
|
|
||||||
onClick={openDrawer}
|
|
||||||
>
|
|
||||||
<Icon as={BsGearFill} boxSize={4} color="gray.600" />
|
|
||||||
<Text display={{ base: "none", lg: "block" }} ml={2}>
|
|
||||||
Edit Vars & Evals
|
|
||||||
</Text>
|
|
||||||
</Button>
|
|
||||||
<DeleteButton />
|
|
||||||
</HStack>
|
|
||||||
)}
|
|
||||||
</Flex>
|
</Flex>
|
||||||
<SettingsDrawer />
|
<ExperimentSettingsDrawer />
|
||||||
<Box w="100%" overflowX="auto" flex={1}>
|
<Box w="100%" overflowX="auto" flex={1}>
|
||||||
<OutputsTable experimentId={router.query.id as string | undefined} />
|
<OutputsTable experimentId={router.query.id as string | undefined} />
|
||||||
</Box>
|
</Box>
|
||||||
|
|||||||
@@ -13,7 +13,11 @@ import {
|
|||||||
import { RiFlaskLine } from "react-icons/ri";
|
import { RiFlaskLine } from "react-icons/ri";
|
||||||
import AppShell from "~/components/nav/AppShell";
|
import AppShell from "~/components/nav/AppShell";
|
||||||
import { api } from "~/utils/api";
|
import { api } from "~/utils/api";
|
||||||
import { ExperimentCard, NewExperimentCard } from "~/components/experiments/ExperimentCard";
|
import {
|
||||||
|
ExperimentCard,
|
||||||
|
ExperimentCardSkeleton,
|
||||||
|
NewExperimentCard,
|
||||||
|
} from "~/components/experiments/ExperimentCard";
|
||||||
import { signIn, useSession } from "next-auth/react";
|
import { signIn, useSession } from "next-auth/react";
|
||||||
|
|
||||||
export default function ExperimentsPage() {
|
export default function ExperimentsPage() {
|
||||||
@@ -47,7 +51,7 @@ export default function ExperimentsPage() {
|
|||||||
return (
|
return (
|
||||||
<AppShell title="Experiments">
|
<AppShell title="Experiments">
|
||||||
<VStack alignItems={"flex-start"} px={4} py={2}>
|
<VStack alignItems={"flex-start"} px={4} py={2}>
|
||||||
<HStack minH={8} align="center">
|
<HStack minH={8} align="center" pt={2}>
|
||||||
<Breadcrumb flex={1}>
|
<Breadcrumb flex={1}>
|
||||||
<BreadcrumbItem>
|
<BreadcrumbItem>
|
||||||
<Flex alignItems="center">
|
<Flex alignItems="center">
|
||||||
@@ -58,7 +62,15 @@ export default function ExperimentsPage() {
|
|||||||
</HStack>
|
</HStack>
|
||||||
<SimpleGrid w="full" columns={{ base: 1, md: 2, lg: 3, xl: 4 }} spacing={8} p="4">
|
<SimpleGrid w="full" columns={{ base: 1, md: 2, lg: 3, xl: 4 }} spacing={8} p="4">
|
||||||
<NewExperimentCard />
|
<NewExperimentCard />
|
||||||
{experiments?.data?.map((exp) => <ExperimentCard key={exp.id} exp={exp} />)}
|
{experiments.data && !experiments.isLoading ? (
|
||||||
|
experiments?.data?.map((exp) => <ExperimentCard key={exp.id} exp={exp} />)
|
||||||
|
) : (
|
||||||
|
<>
|
||||||
|
<ExperimentCardSkeleton />
|
||||||
|
<ExperimentCardSkeleton />
|
||||||
|
<ExperimentCardSkeleton />
|
||||||
|
</>
|
||||||
|
)}
|
||||||
</SimpleGrid>
|
</SimpleGrid>
|
||||||
</VStack>
|
</VStack>
|
||||||
</AppShell>
|
</AppShell>
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
import { z } from "zod";
|
import { z } from "zod";
|
||||||
|
import { v4 as uuidv4 } from "uuid";
|
||||||
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
|
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
|
||||||
|
import { type Prisma } from "@prisma/client";
|
||||||
import { prisma } from "~/server/db";
|
import { prisma } from "~/server/db";
|
||||||
import dedent from "dedent";
|
import dedent from "dedent";
|
||||||
import { generateNewCell } from "~/server/utils/generateNewCell";
|
import { generateNewCell } from "~/server/utils/generateNewCell";
|
||||||
@@ -20,7 +22,7 @@ export const experimentsRouter = createTRPCRouter({
|
|||||||
const experiments = await prisma.experiment.findMany({
|
const experiments = await prisma.experiment.findMany({
|
||||||
where: {
|
where: {
|
||||||
organization: {
|
organization: {
|
||||||
OrganizationUser: {
|
organizationUsers: {
|
||||||
some: { userId: ctx.session.user.id },
|
some: { userId: ctx.session.user.id },
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -77,6 +79,189 @@ export const experimentsRouter = createTRPCRouter({
|
|||||||
};
|
};
|
||||||
}),
|
}),
|
||||||
|
|
||||||
|
fork: protectedProcedure.input(z.object({ id: z.string() })).mutation(async ({ input, ctx }) => {
|
||||||
|
await requireCanViewExperiment(input.id, ctx);
|
||||||
|
|
||||||
|
const [
|
||||||
|
existingExp,
|
||||||
|
existingVariants,
|
||||||
|
existingScenarios,
|
||||||
|
existingCells,
|
||||||
|
evaluations,
|
||||||
|
templateVariables,
|
||||||
|
] = await prisma.$transaction([
|
||||||
|
prisma.experiment.findUniqueOrThrow({
|
||||||
|
where: {
|
||||||
|
id: input.id,
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
prisma.promptVariant.findMany({
|
||||||
|
where: {
|
||||||
|
experimentId: input.id,
|
||||||
|
visible: true,
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
prisma.testScenario.findMany({
|
||||||
|
where: {
|
||||||
|
experimentId: input.id,
|
||||||
|
visible: true,
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
prisma.scenarioVariantCell.findMany({
|
||||||
|
where: {
|
||||||
|
testScenario: {
|
||||||
|
visible: true,
|
||||||
|
},
|
||||||
|
promptVariant: {
|
||||||
|
experimentId: input.id,
|
||||||
|
visible: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
include: {
|
||||||
|
modelOutput: {
|
||||||
|
include: {
|
||||||
|
outputEvaluations: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
prisma.evaluation.findMany({
|
||||||
|
where: {
|
||||||
|
experimentId: input.id,
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
prisma.templateVariable.findMany({
|
||||||
|
where: {
|
||||||
|
experimentId: input.id,
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
]);
|
||||||
|
|
||||||
|
const newExperimentId = uuidv4();
|
||||||
|
|
||||||
|
const existingToNewVariantIds = new Map<string, string>();
|
||||||
|
const variantsToCreate: Prisma.PromptVariantCreateManyInput[] = [];
|
||||||
|
for (const variant of existingVariants) {
|
||||||
|
const newVariantId = uuidv4();
|
||||||
|
existingToNewVariantIds.set(variant.id, newVariantId);
|
||||||
|
variantsToCreate.push({
|
||||||
|
...variant,
|
||||||
|
id: newVariantId,
|
||||||
|
experimentId: newExperimentId,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
const existingToNewScenarioIds = new Map<string, string>();
|
||||||
|
const scenariosToCreate: Prisma.TestScenarioCreateManyInput[] = [];
|
||||||
|
for (const scenario of existingScenarios) {
|
||||||
|
const newScenarioId = uuidv4();
|
||||||
|
existingToNewScenarioIds.set(scenario.id, newScenarioId);
|
||||||
|
scenariosToCreate.push({
|
||||||
|
...scenario,
|
||||||
|
id: newScenarioId,
|
||||||
|
experimentId: newExperimentId,
|
||||||
|
variableValues: scenario.variableValues as Prisma.InputJsonValue,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
const existingToNewEvaluationIds = new Map<string, string>();
|
||||||
|
const evaluationsToCreate: Prisma.EvaluationCreateManyInput[] = [];
|
||||||
|
for (const evaluation of evaluations) {
|
||||||
|
const newEvaluationId = uuidv4();
|
||||||
|
existingToNewEvaluationIds.set(evaluation.id, newEvaluationId);
|
||||||
|
evaluationsToCreate.push({
|
||||||
|
...evaluation,
|
||||||
|
id: newEvaluationId,
|
||||||
|
experimentId: newExperimentId,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
const cellsToCreate: Prisma.ScenarioVariantCellCreateManyInput[] = [];
|
||||||
|
const modelOutputsToCreate: Prisma.ModelOutputCreateManyInput[] = [];
|
||||||
|
const outputEvaluationsToCreate: Prisma.OutputEvaluationCreateManyInput[] = [];
|
||||||
|
for (const cell of existingCells) {
|
||||||
|
const newCellId = uuidv4();
|
||||||
|
const { modelOutput, ...cellData } = cell;
|
||||||
|
cellsToCreate.push({
|
||||||
|
...cellData,
|
||||||
|
id: newCellId,
|
||||||
|
promptVariantId: existingToNewVariantIds.get(cell.promptVariantId) ?? "",
|
||||||
|
testScenarioId: existingToNewScenarioIds.get(cell.testScenarioId) ?? "",
|
||||||
|
prompt: (cell.prompt as Prisma.InputJsonValue) ?? undefined,
|
||||||
|
});
|
||||||
|
if (modelOutput) {
|
||||||
|
const newModelOutputId = uuidv4();
|
||||||
|
const { outputEvaluations, ...modelOutputData } = modelOutput;
|
||||||
|
modelOutputsToCreate.push({
|
||||||
|
...modelOutputData,
|
||||||
|
id: newModelOutputId,
|
||||||
|
scenarioVariantCellId: newCellId,
|
||||||
|
output: (modelOutput.output as Prisma.InputJsonValue) ?? undefined,
|
||||||
|
});
|
||||||
|
for (const evaluation of outputEvaluations) {
|
||||||
|
outputEvaluationsToCreate.push({
|
||||||
|
...evaluation,
|
||||||
|
id: uuidv4(),
|
||||||
|
modelOutputId: newModelOutputId,
|
||||||
|
evaluationId: existingToNewEvaluationIds.get(evaluation.evaluationId) ?? "",
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const templateVariablesToCreate: Prisma.TemplateVariableCreateManyInput[] = [];
|
||||||
|
for (const templateVariable of templateVariables) {
|
||||||
|
templateVariablesToCreate.push({
|
||||||
|
...templateVariable,
|
||||||
|
id: uuidv4(),
|
||||||
|
experimentId: newExperimentId,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
const maxSortIndex =
|
||||||
|
(
|
||||||
|
await prisma.experiment.aggregate({
|
||||||
|
_max: {
|
||||||
|
sortIndex: true,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
)._max?.sortIndex ?? 0;
|
||||||
|
|
||||||
|
await prisma.$transaction([
|
||||||
|
prisma.experiment.create({
|
||||||
|
data: {
|
||||||
|
id: newExperimentId,
|
||||||
|
sortIndex: maxSortIndex + 1,
|
||||||
|
label: `${existingExp.label} (forked)`,
|
||||||
|
organizationId: (await userOrg(ctx.session.user.id)).id,
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
prisma.promptVariant.createMany({
|
||||||
|
data: variantsToCreate,
|
||||||
|
}),
|
||||||
|
prisma.testScenario.createMany({
|
||||||
|
data: scenariosToCreate,
|
||||||
|
}),
|
||||||
|
prisma.scenarioVariantCell.createMany({
|
||||||
|
data: cellsToCreate,
|
||||||
|
}),
|
||||||
|
prisma.modelOutput.createMany({
|
||||||
|
data: modelOutputsToCreate,
|
||||||
|
}),
|
||||||
|
prisma.evaluation.createMany({
|
||||||
|
data: evaluationsToCreate,
|
||||||
|
}),
|
||||||
|
prisma.outputEvaluation.createMany({
|
||||||
|
data: outputEvaluationsToCreate,
|
||||||
|
}),
|
||||||
|
prisma.templateVariable.createMany({
|
||||||
|
data: templateVariablesToCreate,
|
||||||
|
}),
|
||||||
|
]);
|
||||||
|
|
||||||
|
return newExperimentId;
|
||||||
|
}),
|
||||||
|
|
||||||
create: protectedProcedure.input(z.object({})).mutation(async ({ ctx }) => {
|
create: protectedProcedure.input(z.object({})).mutation(async ({ ctx }) => {
|
||||||
// Anyone can create an experiment
|
// Anyone can create an experiment
|
||||||
requireNothing(ctx);
|
requireNothing(ctx);
|
||||||
|
|||||||
@@ -9,7 +9,8 @@ import { type PromptVariant } from "@prisma/client";
|
|||||||
import { deriveNewConstructFn } from "~/server/utils/deriveNewContructFn";
|
import { deriveNewConstructFn } from "~/server/utils/deriveNewContructFn";
|
||||||
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
|
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
|
||||||
import parseConstructFn from "~/server/utils/parseConstructFn";
|
import parseConstructFn from "~/server/utils/parseConstructFn";
|
||||||
import { ZodModel } from "~/modelProviders/types";
|
import modelProviders from "~/modelProviders/modelProviders";
|
||||||
|
import { ZodSupportedProvider } from "~/modelProviders/types";
|
||||||
|
|
||||||
export const promptVariantsRouter = createTRPCRouter({
|
export const promptVariantsRouter = createTRPCRouter({
|
||||||
list: publicProcedure
|
list: publicProcedure
|
||||||
@@ -144,7 +145,7 @@ export const promptVariantsRouter = createTRPCRouter({
|
|||||||
z.object({
|
z.object({
|
||||||
experimentId: z.string(),
|
experimentId: z.string(),
|
||||||
variantId: z.string().optional(),
|
variantId: z.string().optional(),
|
||||||
newModel: ZodModel.optional(),
|
streamScenarios: z.array(z.string()),
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
.mutation(async ({ input, ctx }) => {
|
.mutation(async ({ input, ctx }) => {
|
||||||
@@ -186,7 +187,7 @@ export const promptVariantsRouter = createTRPCRouter({
|
|||||||
? `${originalVariant?.label} Copy`
|
? `${originalVariant?.label} Copy`
|
||||||
: `Prompt Variant ${largestSortIndex + 2}`;
|
: `Prompt Variant ${largestSortIndex + 2}`;
|
||||||
|
|
||||||
const newConstructFn = await deriveNewConstructFn(originalVariant, input.newModel);
|
const newConstructFn = await deriveNewConstructFn(originalVariant);
|
||||||
|
|
||||||
const createNewVariantAction = prisma.promptVariant.create({
|
const createNewVariantAction = prisma.promptVariant.create({
|
||||||
data: {
|
data: {
|
||||||
@@ -218,7 +219,9 @@ export const promptVariantsRouter = createTRPCRouter({
|
|||||||
});
|
});
|
||||||
|
|
||||||
for (const scenario of scenarios) {
|
for (const scenario of scenarios) {
|
||||||
await generateNewCell(newVariant.id, scenario.id);
|
await generateNewCell(newVariant.id, scenario.id, {
|
||||||
|
stream: input.streamScenarios.includes(scenario.id),
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
return newVariant;
|
return newVariant;
|
||||||
@@ -286,7 +289,12 @@ export const promptVariantsRouter = createTRPCRouter({
|
|||||||
z.object({
|
z.object({
|
||||||
id: z.string(),
|
id: z.string(),
|
||||||
instructions: z.string().optional(),
|
instructions: z.string().optional(),
|
||||||
newModel: ZodModel.optional(),
|
newModel: z
|
||||||
|
.object({
|
||||||
|
provider: ZodSupportedProvider,
|
||||||
|
model: z.string(),
|
||||||
|
})
|
||||||
|
.optional(),
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
.mutation(async ({ input, ctx }) => {
|
.mutation(async ({ input, ctx }) => {
|
||||||
@@ -303,11 +311,11 @@ export const promptVariantsRouter = createTRPCRouter({
|
|||||||
return userError(constructedPrompt.error);
|
return userError(constructedPrompt.error);
|
||||||
}
|
}
|
||||||
|
|
||||||
const promptConstructionFn = await deriveNewConstructFn(
|
const model = input.newModel
|
||||||
existing,
|
? modelProviders[input.newModel.provider].models[input.newModel.model]
|
||||||
input.newModel,
|
: undefined;
|
||||||
input.instructions,
|
|
||||||
);
|
const promptConstructionFn = await deriveNewConstructFn(existing, model, input.instructions);
|
||||||
|
|
||||||
// TODO: Validate promptConstructionFn
|
// TODO: Validate promptConstructionFn
|
||||||
// TODO: Record in some sort of history
|
// TODO: Record in some sort of history
|
||||||
@@ -320,6 +328,7 @@ export const promptVariantsRouter = createTRPCRouter({
|
|||||||
z.object({
|
z.object({
|
||||||
id: z.string(),
|
id: z.string(),
|
||||||
constructFn: z.string(),
|
constructFn: z.string(),
|
||||||
|
streamScenarios: z.array(z.string()),
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
.mutation(async ({ input, ctx }) => {
|
.mutation(async ({ input, ctx }) => {
|
||||||
@@ -377,7 +386,9 @@ export const promptVariantsRouter = createTRPCRouter({
|
|||||||
});
|
});
|
||||||
|
|
||||||
for (const scenario of scenarios) {
|
for (const scenario of scenarios) {
|
||||||
await generateNewCell(newVariant.id, scenario.id);
|
await generateNewCell(newVariant.id, scenario.id, {
|
||||||
|
stream: input.streamScenarios.includes(scenario.id),
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
return { status: "ok" } as const;
|
return { status: "ok" } as const;
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
import { z } from "zod";
|
import { z } from "zod";
|
||||||
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
|
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
|
||||||
import { prisma } from "~/server/db";
|
import { prisma } from "~/server/db";
|
||||||
|
import { queueQueryModel } from "~/server/tasks/queryModel.task";
|
||||||
import { generateNewCell } from "~/server/utils/generateNewCell";
|
import { generateNewCell } from "~/server/utils/generateNewCell";
|
||||||
import { queueLLMRetrievalTask } from "~/server/utils/queueLLMRetrievalTask";
|
|
||||||
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
|
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
|
||||||
|
|
||||||
export const scenarioVariantCellsRouter = createTRPCRouter({
|
export const scenarioVariantCellsRouter = createTRPCRouter({
|
||||||
@@ -29,7 +29,7 @@ export const scenarioVariantCellsRouter = createTRPCRouter({
|
|||||||
include: {
|
include: {
|
||||||
modelOutput: {
|
modelOutput: {
|
||||||
include: {
|
include: {
|
||||||
outputEvaluation: {
|
outputEvaluations: {
|
||||||
include: {
|
include: {
|
||||||
evaluation: {
|
evaluation: {
|
||||||
select: { label: true },
|
select: { label: true },
|
||||||
@@ -62,14 +62,12 @@ export const scenarioVariantCellsRouter = createTRPCRouter({
|
|||||||
testScenarioId: input.scenarioId,
|
testScenarioId: input.scenarioId,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
include: {
|
include: { modelOutput: true },
|
||||||
modelOutput: true,
|
|
||||||
},
|
|
||||||
});
|
});
|
||||||
|
|
||||||
if (!cell) {
|
if (!cell) {
|
||||||
await generateNewCell(input.variantId, input.scenarioId);
|
await generateNewCell(input.variantId, input.scenarioId, { stream: true });
|
||||||
return true;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cell.modelOutput) {
|
if (cell.modelOutput) {
|
||||||
@@ -79,12 +77,6 @@ export const scenarioVariantCellsRouter = createTRPCRouter({
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
await prisma.scenarioVariantCell.update({
|
await queueQueryModel(cell.id, true);
|
||||||
where: { id: cell.id },
|
|
||||||
data: { retrievalStatus: "PENDING" },
|
|
||||||
});
|
|
||||||
|
|
||||||
await queueLLMRetrievalTask(cell.id);
|
|
||||||
return true;
|
|
||||||
}),
|
}),
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -7,21 +7,39 @@ import { runAllEvals } from "~/server/utils/evaluations";
|
|||||||
import { generateNewCell } from "~/server/utils/generateNewCell";
|
import { generateNewCell } from "~/server/utils/generateNewCell";
|
||||||
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
|
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
|
||||||
|
|
||||||
|
const PAGE_SIZE = 10;
|
||||||
|
|
||||||
export const scenariosRouter = createTRPCRouter({
|
export const scenariosRouter = createTRPCRouter({
|
||||||
list: publicProcedure
|
list: publicProcedure
|
||||||
.input(z.object({ experimentId: z.string() }))
|
.input(z.object({ experimentId: z.string(), page: z.number() }))
|
||||||
.query(async ({ input, ctx }) => {
|
.query(async ({ input, ctx }) => {
|
||||||
await requireCanViewExperiment(input.experimentId, ctx);
|
await requireCanViewExperiment(input.experimentId, ctx);
|
||||||
|
|
||||||
return await prisma.testScenario.findMany({
|
const { experimentId, page } = input;
|
||||||
|
|
||||||
|
const scenarios = await prisma.testScenario.findMany({
|
||||||
where: {
|
where: {
|
||||||
experimentId: input.experimentId,
|
experimentId,
|
||||||
visible: true,
|
visible: true,
|
||||||
},
|
},
|
||||||
orderBy: {
|
orderBy: { sortIndex: "asc" },
|
||||||
sortIndex: "asc",
|
skip: (page - 1) * PAGE_SIZE,
|
||||||
|
take: PAGE_SIZE,
|
||||||
|
});
|
||||||
|
|
||||||
|
const count = await prisma.testScenario.count({
|
||||||
|
where: {
|
||||||
|
experimentId,
|
||||||
|
visible: true,
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
|
return {
|
||||||
|
scenarios,
|
||||||
|
startIndex: (page - 1) * PAGE_SIZE + 1,
|
||||||
|
lastPage: Math.ceil(count / PAGE_SIZE),
|
||||||
|
count,
|
||||||
|
};
|
||||||
}),
|
}),
|
||||||
|
|
||||||
create: protectedProcedure
|
create: protectedProcedure
|
||||||
@@ -34,22 +52,21 @@ export const scenariosRouter = createTRPCRouter({
|
|||||||
.mutation(async ({ input, ctx }) => {
|
.mutation(async ({ input, ctx }) => {
|
||||||
await requireCanModifyExperiment(input.experimentId, ctx);
|
await requireCanModifyExperiment(input.experimentId, ctx);
|
||||||
|
|
||||||
const maxSortIndex =
|
await prisma.testScenario.updateMany({
|
||||||
(
|
where: {
|
||||||
await prisma.testScenario.aggregate({
|
experimentId: input.experimentId,
|
||||||
where: {
|
},
|
||||||
experimentId: input.experimentId,
|
data: {
|
||||||
},
|
sortIndex: {
|
||||||
_max: {
|
increment: 1,
|
||||||
sortIndex: true,
|
},
|
||||||
},
|
},
|
||||||
})
|
});
|
||||||
)._max.sortIndex ?? 0;
|
|
||||||
|
|
||||||
const createNewScenarioAction = prisma.testScenario.create({
|
const createNewScenarioAction = prisma.testScenario.create({
|
||||||
data: {
|
data: {
|
||||||
experimentId: input.experimentId,
|
experimentId: input.experimentId,
|
||||||
sortIndex: maxSortIndex + 1,
|
sortIndex: 0,
|
||||||
variableValues: input.autogenerate
|
variableValues: input.autogenerate
|
||||||
? await autogenerateScenarioValues(input.experimentId)
|
? await autogenerateScenarioValues(input.experimentId)
|
||||||
: {},
|
: {},
|
||||||
@@ -69,7 +86,7 @@ export const scenariosRouter = createTRPCRouter({
|
|||||||
});
|
});
|
||||||
|
|
||||||
for (const variant of promptVariants) {
|
for (const variant of promptVariants) {
|
||||||
await generateNewCell(variant.id, scenario.id);
|
await generateNewCell(variant.id, scenario.id, { stream: true });
|
||||||
}
|
}
|
||||||
}),
|
}),
|
||||||
|
|
||||||
@@ -213,7 +230,7 @@ export const scenariosRouter = createTRPCRouter({
|
|||||||
});
|
});
|
||||||
|
|
||||||
for (const variant of promptVariants) {
|
for (const variant of promptVariants) {
|
||||||
await generateNewCell(variant.id, newScenario.id);
|
await generateNewCell(variant.id, newScenario.id, { stream: true });
|
||||||
}
|
}
|
||||||
|
|
||||||
return newScenario;
|
return newScenario;
|
||||||
|
|||||||
@@ -1,17 +1,17 @@
|
|||||||
import { prisma } from "~/server/db";
|
|
||||||
import defineTask from "./defineTask";
|
|
||||||
import { sleep } from "../utils/sleep";
|
|
||||||
import { generateChannel } from "~/utils/generateChannel";
|
|
||||||
import { runEvalsForOutput } from "../utils/evaluations";
|
|
||||||
import { type Prisma } from "@prisma/client";
|
import { type Prisma } from "@prisma/client";
|
||||||
import parseConstructFn from "../utils/parseConstructFn";
|
|
||||||
import hashPrompt from "../utils/hashPrompt";
|
|
||||||
import { type JsonObject } from "type-fest";
|
import { type JsonObject } from "type-fest";
|
||||||
import modelProviders from "~/modelProviders/modelProviders";
|
import modelProviders from "~/modelProviders/modelProviders";
|
||||||
|
import { prisma } from "~/server/db";
|
||||||
import { wsConnection } from "~/utils/wsConnection";
|
import { wsConnection } from "~/utils/wsConnection";
|
||||||
|
import { runEvalsForOutput } from "../utils/evaluations";
|
||||||
|
import hashPrompt from "../utils/hashPrompt";
|
||||||
|
import parseConstructFn from "../utils/parseConstructFn";
|
||||||
|
import { sleep } from "../utils/sleep";
|
||||||
|
import defineTask from "./defineTask";
|
||||||
|
|
||||||
export type queryLLMJob = {
|
export type QueryModelJob = {
|
||||||
scenarioVariantCellId: string;
|
cellId: string;
|
||||||
|
stream: boolean;
|
||||||
};
|
};
|
||||||
|
|
||||||
const MAX_AUTO_RETRIES = 10;
|
const MAX_AUTO_RETRIES = 10;
|
||||||
@@ -24,15 +24,16 @@ function calculateDelay(numPreviousTries: number): number {
|
|||||||
return baseDelay + jitter;
|
return baseDelay + jitter;
|
||||||
}
|
}
|
||||||
|
|
||||||
export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
|
export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) => {
|
||||||
const { scenarioVariantCellId } = task;
|
console.log("RUNNING TASK", task);
|
||||||
|
const { cellId, stream } = task;
|
||||||
const cell = await prisma.scenarioVariantCell.findUnique({
|
const cell = await prisma.scenarioVariantCell.findUnique({
|
||||||
where: { id: scenarioVariantCellId },
|
where: { id: cellId },
|
||||||
include: { modelOutput: true },
|
include: { modelOutput: true },
|
||||||
});
|
});
|
||||||
if (!cell) {
|
if (!cell) {
|
||||||
await prisma.scenarioVariantCell.update({
|
await prisma.scenarioVariantCell.update({
|
||||||
where: { id: scenarioVariantCellId },
|
where: { id: cellId },
|
||||||
data: {
|
data: {
|
||||||
statusCode: 404,
|
statusCode: 404,
|
||||||
errorMessage: "Cell not found",
|
errorMessage: "Cell not found",
|
||||||
@@ -47,7 +48,7 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
await prisma.scenarioVariantCell.update({
|
await prisma.scenarioVariantCell.update({
|
||||||
where: { id: scenarioVariantCellId },
|
where: { id: cellId },
|
||||||
data: {
|
data: {
|
||||||
retrievalStatus: "IN_PROGRESS",
|
retrievalStatus: "IN_PROGRESS",
|
||||||
},
|
},
|
||||||
@@ -58,7 +59,7 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
|
|||||||
});
|
});
|
||||||
if (!variant) {
|
if (!variant) {
|
||||||
await prisma.scenarioVariantCell.update({
|
await prisma.scenarioVariantCell.update({
|
||||||
where: { id: scenarioVariantCellId },
|
where: { id: cellId },
|
||||||
data: {
|
data: {
|
||||||
statusCode: 404,
|
statusCode: 404,
|
||||||
errorMessage: "Prompt Variant not found",
|
errorMessage: "Prompt Variant not found",
|
||||||
@@ -73,7 +74,7 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
|
|||||||
});
|
});
|
||||||
if (!scenario) {
|
if (!scenario) {
|
||||||
await prisma.scenarioVariantCell.update({
|
await prisma.scenarioVariantCell.update({
|
||||||
where: { id: scenarioVariantCellId },
|
where: { id: cellId },
|
||||||
data: {
|
data: {
|
||||||
statusCode: 404,
|
statusCode: 404,
|
||||||
errorMessage: "Scenario not found",
|
errorMessage: "Scenario not found",
|
||||||
@@ -87,7 +88,7 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
|
|||||||
|
|
||||||
if ("error" in prompt) {
|
if ("error" in prompt) {
|
||||||
await prisma.scenarioVariantCell.update({
|
await prisma.scenarioVariantCell.update({
|
||||||
where: { id: scenarioVariantCellId },
|
where: { id: cellId },
|
||||||
data: {
|
data: {
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
errorMessage: prompt.error,
|
errorMessage: prompt.error,
|
||||||
@@ -99,18 +100,9 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
|
|||||||
|
|
||||||
const provider = modelProviders[prompt.modelProvider];
|
const provider = modelProviders[prompt.modelProvider];
|
||||||
|
|
||||||
const streamingChannel = provider.shouldStream(prompt.modelInput) ? generateChannel() : null;
|
const onStream = stream
|
||||||
|
|
||||||
if (streamingChannel) {
|
|
||||||
// Save streaming channel so that UI can connect to it
|
|
||||||
await prisma.scenarioVariantCell.update({
|
|
||||||
where: { id: scenarioVariantCellId },
|
|
||||||
data: { streamingChannel },
|
|
||||||
});
|
|
||||||
}
|
|
||||||
const onStream = streamingChannel
|
|
||||||
? (partialOutput: (typeof provider)["_outputSchema"]) => {
|
? (partialOutput: (typeof provider)["_outputSchema"]) => {
|
||||||
wsConnection.emit("message", { channel: streamingChannel, payload: partialOutput });
|
wsConnection.emit("message", { channel: cell.id, payload: partialOutput });
|
||||||
}
|
}
|
||||||
: null;
|
: null;
|
||||||
|
|
||||||
@@ -121,7 +113,7 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
|
|||||||
|
|
||||||
const modelOutput = await prisma.modelOutput.create({
|
const modelOutput = await prisma.modelOutput.create({
|
||||||
data: {
|
data: {
|
||||||
scenarioVariantCellId,
|
scenarioVariantCellId: cellId,
|
||||||
inputHash,
|
inputHash,
|
||||||
output: response.value as Prisma.InputJsonObject,
|
output: response.value as Prisma.InputJsonObject,
|
||||||
timeToComplete: response.timeToComplete,
|
timeToComplete: response.timeToComplete,
|
||||||
@@ -132,7 +124,7 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
await prisma.scenarioVariantCell.update({
|
await prisma.scenarioVariantCell.update({
|
||||||
where: { id: scenarioVariantCellId },
|
where: { id: cellId },
|
||||||
data: {
|
data: {
|
||||||
statusCode: response.statusCode,
|
statusCode: response.statusCode,
|
||||||
retrievalStatus: "COMPLETE",
|
retrievalStatus: "COMPLETE",
|
||||||
@@ -146,7 +138,7 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
|
|||||||
const delay = calculateDelay(i);
|
const delay = calculateDelay(i);
|
||||||
|
|
||||||
await prisma.scenarioVariantCell.update({
|
await prisma.scenarioVariantCell.update({
|
||||||
where: { id: scenarioVariantCellId },
|
where: { id: cellId },
|
||||||
data: {
|
data: {
|
||||||
errorMessage: response.message,
|
errorMessage: response.message,
|
||||||
statusCode: response.statusCode,
|
statusCode: response.statusCode,
|
||||||
@@ -163,3 +155,21 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
export const queueQueryModel = async (cellId: string, stream: boolean) => {
|
||||||
|
console.log("queueQueryModel", cellId, stream);
|
||||||
|
await Promise.all([
|
||||||
|
prisma.scenarioVariantCell.update({
|
||||||
|
where: {
|
||||||
|
id: cellId,
|
||||||
|
},
|
||||||
|
data: {
|
||||||
|
retrievalStatus: "PENDING",
|
||||||
|
errorMessage: null,
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
|
||||||
|
await queryModel.enqueue({ cellId, stream }),
|
||||||
|
console.log("queued"),
|
||||||
|
]);
|
||||||
|
};
|
||||||
@@ -2,39 +2,27 @@ import { type TaskList, run } from "graphile-worker";
|
|||||||
import "dotenv/config";
|
import "dotenv/config";
|
||||||
|
|
||||||
import { env } from "~/env.mjs";
|
import { env } from "~/env.mjs";
|
||||||
import { queryLLM } from "./queryLLM.task";
|
import { queryModel } from "./queryModel.task";
|
||||||
|
|
||||||
const registeredTasks = [queryLLM];
|
console.log("Starting worker");
|
||||||
|
|
||||||
|
const registeredTasks = [queryModel];
|
||||||
|
|
||||||
const taskList = registeredTasks.reduce((acc, task) => {
|
const taskList = registeredTasks.reduce((acc, task) => {
|
||||||
acc[task.task.identifier] = task.task.handler;
|
acc[task.task.identifier] = task.task.handler;
|
||||||
return acc;
|
return acc;
|
||||||
}, {} as TaskList);
|
}, {} as TaskList);
|
||||||
|
|
||||||
async function main() {
|
// Run a worker to execute jobs:
|
||||||
// Run a worker to execute jobs:
|
const runner = await run({
|
||||||
const runner = await run({
|
connectionString: env.DATABASE_URL,
|
||||||
connectionString: env.DATABASE_URL,
|
concurrency: 20,
|
||||||
concurrency: 20,
|
// Install signal handlers for graceful shutdown on SIGINT, SIGTERM, etc
|
||||||
// Install signal handlers for graceful shutdown on SIGINT, SIGTERM, etc
|
noHandleSignals: false,
|
||||||
noHandleSignals: false,
|
pollInterval: 1000,
|
||||||
pollInterval: 1000,
|
taskList,
|
||||||
// you can set the taskList or taskDirectory but not both
|
|
||||||
taskList,
|
|
||||||
// or:
|
|
||||||
// taskDirectory: `${__dirname}/tasks`,
|
|
||||||
});
|
|
||||||
|
|
||||||
// Immediately await (or otherwise handled) the resulting promise, to avoid
|
|
||||||
// "unhandled rejection" errors causing a process crash in the event of
|
|
||||||
// something going wrong.
|
|
||||||
await runner.promise;
|
|
||||||
|
|
||||||
// If the worker exits (whether through fatal error or otherwise), the above
|
|
||||||
// promise will resolve/reject.
|
|
||||||
}
|
|
||||||
|
|
||||||
main().catch((err) => {
|
|
||||||
console.error("Unhandled error occurred running worker: ", err);
|
|
||||||
process.exit(1);
|
|
||||||
});
|
});
|
||||||
|
|
||||||
|
console.log("Worker successfully started");
|
||||||
|
|
||||||
|
await runner.promise;
|
||||||
|
|||||||
@@ -74,6 +74,11 @@ const requestUpdatedPromptFunction = async (
|
|||||||
2,
|
2,
|
||||||
)}`,
|
)}`,
|
||||||
});
|
});
|
||||||
|
} else {
|
||||||
|
messages.push({
|
||||||
|
role: "user",
|
||||||
|
content: `The provider is the same as the old provider: ${originalModel.provider}`,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (instructions) {
|
if (instructions) {
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ export const runAllEvals = async (experimentId: string) => {
|
|||||||
testScenario: true,
|
testScenario: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
outputEvaluation: true,
|
outputEvaluations: true,
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
const evals = await prisma.evaluation.findMany({
|
const evals = await prisma.evaluation.findMany({
|
||||||
@@ -66,7 +66,7 @@ export const runAllEvals = async (experimentId: string) => {
|
|||||||
await Promise.all(
|
await Promise.all(
|
||||||
outputs.map(async (output) => {
|
outputs.map(async (output) => {
|
||||||
const unrunEvals = evals.filter(
|
const unrunEvals = evals.filter(
|
||||||
(evaluation) => !output.outputEvaluation.find((e) => e.evaluationId === evaluation.id),
|
(evaluation) => !output.outputEvaluations.find((e) => e.evaluationId === evaluation.id),
|
||||||
);
|
);
|
||||||
|
|
||||||
await Promise.all(
|
await Promise.all(
|
||||||
|
|||||||
@@ -1,12 +1,18 @@
|
|||||||
import { type Prisma } from "@prisma/client";
|
import { type Prisma } from "@prisma/client";
|
||||||
import { prisma } from "../db";
|
import { prisma } from "../db";
|
||||||
import { queueLLMRetrievalTask } from "./queueLLMRetrievalTask";
|
|
||||||
import parseConstructFn from "./parseConstructFn";
|
import parseConstructFn from "./parseConstructFn";
|
||||||
import { type JsonObject } from "type-fest";
|
import { type JsonObject } from "type-fest";
|
||||||
import hashPrompt from "./hashPrompt";
|
import hashPrompt from "./hashPrompt";
|
||||||
import { omit } from "lodash-es";
|
import { omit } from "lodash-es";
|
||||||
|
import { queueQueryModel } from "../tasks/queryModel.task";
|
||||||
|
|
||||||
|
export const generateNewCell = async (
|
||||||
|
variantId: string,
|
||||||
|
scenarioId: string,
|
||||||
|
options?: { stream?: boolean },
|
||||||
|
): Promise<void> => {
|
||||||
|
const stream = options?.stream ?? false;
|
||||||
|
|
||||||
export const generateNewCell = async (variantId: string, scenarioId: string): Promise<void> => {
|
|
||||||
const variant = await prisma.promptVariant.findUnique({
|
const variant = await prisma.promptVariant.findUnique({
|
||||||
where: {
|
where: {
|
||||||
id: variantId,
|
id: variantId,
|
||||||
@@ -98,6 +104,6 @@ export const generateNewCell = async (variantId: string, scenarioId: string): Pr
|
|||||||
}),
|
}),
|
||||||
);
|
);
|
||||||
} else {
|
} else {
|
||||||
cell = await queueLLMRetrievalTask(cell.id);
|
await queueQueryModel(cell.id, stream);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -1,22 +0,0 @@
|
|||||||
import { prisma } from "../db";
|
|
||||||
import { queryLLM } from "../tasks/queryLLM.task";
|
|
||||||
|
|
||||||
export const queueLLMRetrievalTask = async (cellId: string) => {
|
|
||||||
const updatedCell = await prisma.scenarioVariantCell.update({
|
|
||||||
where: {
|
|
||||||
id: cellId,
|
|
||||||
},
|
|
||||||
data: {
|
|
||||||
retrievalStatus: "PENDING",
|
|
||||||
errorMessage: null,
|
|
||||||
},
|
|
||||||
include: {
|
|
||||||
modelOutput: true,
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
// @ts-expect-error we aren't passing the helpers but that's ok
|
|
||||||
void queryLLM.task.handler({ scenarioVariantCellId: cellId }, { logger: console });
|
|
||||||
|
|
||||||
return updatedCell;
|
|
||||||
};
|
|
||||||
@@ -8,7 +8,7 @@ export default async function userOrg(userId: string) {
|
|||||||
update: {},
|
update: {},
|
||||||
create: {
|
create: {
|
||||||
personalOrgUserId: userId,
|
personalOrgUserId: userId,
|
||||||
OrganizationUser: {
|
organizationUsers: {
|
||||||
create: {
|
create: {
|
||||||
userId: userId,
|
userId: userId,
|
||||||
role: "ADMIN",
|
role: "ADMIN",
|
||||||
|
|||||||
@@ -8,9 +8,9 @@ export const editorBackground = "#fafafa";
|
|||||||
export type SharedVariantEditorSlice = {
|
export type SharedVariantEditorSlice = {
|
||||||
monaco: null | ReturnType<typeof loader.__getMonacoInstance>;
|
monaco: null | ReturnType<typeof loader.__getMonacoInstance>;
|
||||||
loadMonaco: () => Promise<void>;
|
loadMonaco: () => Promise<void>;
|
||||||
scenarios: RouterOutputs["scenarios"]["list"];
|
scenarios: RouterOutputs["scenarios"]["list"]["scenarios"];
|
||||||
updateScenariosModel: () => void;
|
updateScenariosModel: () => void;
|
||||||
setScenarios: (scenarios: RouterOutputs["scenarios"]["list"]) => void;
|
setScenarios: (scenarios: RouterOutputs["scenarios"]["list"]["scenarios"]) => void;
|
||||||
};
|
};
|
||||||
|
|
||||||
export const createVariantEditorSlice: SliceCreator<SharedVariantEditorSlice> = (set, get) => ({
|
export const createVariantEditorSlice: SliceCreator<SharedVariantEditorSlice> = (set, get) => ({
|
||||||
|
|||||||
@@ -1,17 +1,14 @@
|
|||||||
import { useEffect } from "react";
|
import { useEffect } from "react";
|
||||||
import { api } from "~/utils/api";
|
import { api } from "~/utils/api";
|
||||||
import { useExperiment } from "~/utils/hooks";
|
import { useScenarios } from "~/utils/hooks";
|
||||||
import { useAppStore } from "./store";
|
import { useAppStore } from "./store";
|
||||||
|
|
||||||
export function useSyncVariantEditor() {
|
export function useSyncVariantEditor() {
|
||||||
const experiment = useExperiment();
|
const scenarios = useScenarios();
|
||||||
const scenarios = api.scenarios.list.useQuery(
|
|
||||||
{ experimentId: experiment.data?.id ?? "" },
|
|
||||||
{ enabled: !!experiment.data?.id },
|
|
||||||
);
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (scenarios.data) {
|
if (scenarios.data) {
|
||||||
useAppStore.getState().sharedVariantEditor.setScenarios(scenarios.data);
|
useAppStore.getState().sharedVariantEditor.setScenarios(scenarios.data.scenarios);
|
||||||
}
|
}
|
||||||
}, [scenarios.data]);
|
}, [scenarios.data]);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ export const canModifyExperiment = async (experimentId: string, userId: string)
|
|||||||
where: {
|
where: {
|
||||||
id: experimentId,
|
id: experimentId,
|
||||||
organization: {
|
organization: {
|
||||||
OrganizationUser: {
|
organizationUsers: {
|
||||||
some: {
|
some: {
|
||||||
role: { in: [OrganizationUserRole.ADMIN, OrganizationUserRole.MEMBER] },
|
role: { in: [OrganizationUserRole.ADMIN, OrganizationUserRole.MEMBER] },
|
||||||
userId,
|
userId,
|
||||||
|
|||||||
@@ -1,5 +0,0 @@
|
|||||||
// generate random channel id
|
|
||||||
|
|
||||||
export const generateChannel = () => {
|
|
||||||
return Math.random().toString(36).substring(2, 15) + Math.random().toString(36).substring(2, 15);
|
|
||||||
};
|
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
import { useRouter } from "next/router";
|
import { useRouter } from "next/router";
|
||||||
import { type RefObject, useCallback, useEffect, useRef, useState } from "react";
|
import { type RefObject, useCallback, useEffect, useRef, useState } from "react";
|
||||||
import { api } from "~/utils/api";
|
import { api } from "~/utils/api";
|
||||||
|
import { NumberParam, useQueryParam, withDefault } from "use-query-params";
|
||||||
|
|
||||||
export const useExperiment = () => {
|
export const useExperiment = () => {
|
||||||
const router = useRouter();
|
const router = useRouter();
|
||||||
@@ -93,3 +94,17 @@ export const useElementDimensions = (): [RefObject<HTMLElement>, Dimensions | un
|
|||||||
|
|
||||||
return [ref, dimensions];
|
return [ref, dimensions];
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export const usePage = () => useQueryParam("page", withDefault(NumberParam, 1));
|
||||||
|
|
||||||
|
export const useScenarios = () => {
|
||||||
|
const experiment = useExperiment();
|
||||||
|
const [page] = usePage();
|
||||||
|
|
||||||
|
return api.scenarios.list.useQuery(
|
||||||
|
{ experimentId: experiment.data?.id ?? "", page },
|
||||||
|
{ enabled: experiment.data?.id != null },
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export const useVisibleScenarioIds = () => useScenarios().data?.scenarios.map((s) => s.id) ?? [];
|
||||||
|
|||||||
@@ -1,5 +1,12 @@
|
|||||||
import { type Model } from "~/modelProviders/types";
|
import frontendModelProviders from "~/modelProviders/frontendModelProviders";
|
||||||
|
import { type ProviderModel } from "~/modelProviders/types";
|
||||||
|
|
||||||
export const truthyFilter = <T>(x: T | null | undefined): x is T => Boolean(x);
|
export const truthyFilter = <T>(x: T | null | undefined): x is T => Boolean(x);
|
||||||
|
|
||||||
export const keyForModel = (model: Model) => `${model.provider}/${model.name}`;
|
export const lookupModel = (provider: string, model: string) => {
|
||||||
|
const modelObj = frontendModelProviders[provider as ProviderModel["provider"]]?.models[model];
|
||||||
|
return modelObj ? { ...modelObj, provider } : null;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const modelLabel = (provider: string, model: string) =>
|
||||||
|
`${provider}/${lookupModel(provider, model)?.name ?? model}`;
|
||||||
|
|||||||
Reference in New Issue
Block a user