Compare commits
42 Commits
model-prov
...
scenario-s
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2bffb03766 | ||
|
|
fa61c9c472 | ||
|
|
1309a6ec5d | ||
|
|
17a6fd31a5 | ||
|
|
e1cbeccb90 | ||
|
|
d6b97b29f7 | ||
|
|
09140f8b5f | ||
|
|
9952dd93d8 | ||
|
|
e0b457c6c5 | ||
|
|
0c37506975 | ||
|
|
2b2e0ab8ee | ||
|
|
3dbb06ec00 | ||
|
|
85d42a014b | ||
|
|
7d1ded3b18 | ||
|
|
b00f6dd04b | ||
|
|
2e395e4d39 | ||
|
|
4b06d05908 | ||
|
|
aabf355b81 | ||
|
|
61e5f0775d | ||
|
|
cc1d1178da | ||
|
|
7466db63df | ||
|
|
79a0b03bf8 | ||
|
|
6fb7a82d72 | ||
|
|
4ea30a3ba3 | ||
|
|
52d1d5c7ee | ||
|
|
46036a44d2 | ||
|
|
3753fe5c16 | ||
|
|
213a00a8e6 | ||
|
|
af9943eefc | ||
|
|
741128e0f4 | ||
|
|
aff14539d8 | ||
|
|
1af81a50a9 | ||
|
|
7e1fbb3767 | ||
|
|
a5d972005e | ||
|
|
a180b5bef2 | ||
|
|
55c697223e | ||
|
|
9978075867 | ||
|
|
372c2512c9 | ||
|
|
1822fe198e | ||
|
|
f06e1db3db | ||
|
|
9314a86857 | ||
|
|
54dcb4a567 |
@@ -17,6 +17,9 @@ DATABASE_URL="postgresql://postgres:postgres@localhost:5432/openpipe?schema=publ
|
||||
# https://help.openai.com/en/articles/4936850-where-do-i-find-my-secret-api-key
|
||||
OPENAI_API_KEY=""
|
||||
|
||||
# Replicate API token. Create a token here: https://replicate.com/account/api-tokens
|
||||
REPLICATE_API_TOKEN=""
|
||||
|
||||
NEXT_PUBLIC_SOCKET_URL="http://localhost:3318"
|
||||
|
||||
# Next Auth
|
||||
|
||||
5
.vscode/settings.json
vendored
5
.vscode/settings.json
vendored
@@ -1,6 +1,3 @@
|
||||
{
|
||||
"eslint.format.enable": true,
|
||||
"editor.codeActionsOnSave": {
|
||||
"source.fixAll.eslint": true
|
||||
}
|
||||
"eslint.format.enable": true
|
||||
}
|
||||
|
||||
@@ -43,7 +43,8 @@ Natively supports [OpenAI function calls](https://openai.com/blog/function-calli
|
||||
|
||||
## Supported Models
|
||||
|
||||
OpenPipe currently supports GPT-3.5 and GPT-4. Wider model support is planned.
|
||||
- All models available through the OpenAI [chat completion API](https://platform.openai.com/docs/guides/gpt/chat-completions-api)
|
||||
- Llama2 [7b chat](https://replicate.com/a16z-infra/llama7b-v2-chat), [13b chat](https://replicate.com/a16z-infra/llama13b-v2-chat), [70b chat](https://replicate.com/replicate/llama70b-v2-chat).
|
||||
|
||||
## Running Locally
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
"dev:next": "next dev",
|
||||
"dev:wss": "pnpm tsx --watch src/wss-server.ts",
|
||||
"dev:worker": "NODE_ENV='development' pnpm tsx --watch src/server/tasks/worker.ts",
|
||||
"dev": "concurrently --kill-others 'pnpm dev:next' 'pnpm dev:wss'",
|
||||
"dev": "concurrently --kill-others 'pnpm dev:next' 'pnpm dev:wss' 'pnpm dev:worker'",
|
||||
"postinstall": "prisma generate",
|
||||
"lint": "next lint",
|
||||
"start": "next start",
|
||||
@@ -59,6 +59,7 @@
|
||||
"lodash-es": "^4.17.21",
|
||||
"next": "^13.4.2",
|
||||
"next-auth": "^4.22.1",
|
||||
"next-query-params": "^4.2.3",
|
||||
"nextjs-routes": "^2.0.1",
|
||||
"openai": "4.0.0-beta.2",
|
||||
"pluralize": "^8.0.0",
|
||||
@@ -79,6 +80,8 @@
|
||||
"superjson": "1.12.2",
|
||||
"tsx": "^3.12.7",
|
||||
"type-fest": "^4.0.0",
|
||||
"use-query-params": "^2.2.1",
|
||||
"uuid": "^9.0.0",
|
||||
"vite-tsconfig-paths": "^4.2.0",
|
||||
"zod": "^3.21.4",
|
||||
"zustand": "^4.3.9"
|
||||
@@ -99,6 +102,7 @@
|
||||
"@types/react": "^18.2.6",
|
||||
"@types/react-dom": "^18.2.4",
|
||||
"@types/react-syntax-highlighter": "^15.5.7",
|
||||
"@types/uuid": "^9.0.2",
|
||||
"@typescript-eslint/eslint-plugin": "^5.59.6",
|
||||
"@typescript-eslint/parser": "^5.59.6",
|
||||
"eslint": "^8.40.0",
|
||||
|
||||
58
pnpm-lock.yaml
generated
58
pnpm-lock.yaml
generated
@@ -1,4 +1,4 @@
|
||||
lockfileVersion: '6.1'
|
||||
lockfileVersion: '6.0'
|
||||
|
||||
settings:
|
||||
autoInstallPeers: true
|
||||
@@ -119,6 +119,9 @@ dependencies:
|
||||
next-auth:
|
||||
specifier: ^4.22.1
|
||||
version: 4.22.1(next@13.4.2)(react-dom@18.2.0)(react@18.2.0)
|
||||
next-query-params:
|
||||
specifier: ^4.2.3
|
||||
version: 4.2.3(next@13.4.2)(react@18.2.0)(use-query-params@2.2.1)
|
||||
nextjs-routes:
|
||||
specifier: ^2.0.1
|
||||
version: 2.0.1(next@13.4.2)
|
||||
@@ -179,6 +182,12 @@ dependencies:
|
||||
type-fest:
|
||||
specifier: ^4.0.0
|
||||
version: 4.0.0
|
||||
use-query-params:
|
||||
specifier: ^2.2.1
|
||||
version: 2.2.1(react-dom@18.2.0)(react@18.2.0)
|
||||
uuid:
|
||||
specifier: ^9.0.0
|
||||
version: 9.0.0
|
||||
vite-tsconfig-paths:
|
||||
specifier: ^4.2.0
|
||||
version: 4.2.0(typescript@5.0.4)
|
||||
@@ -235,6 +244,9 @@ devDependencies:
|
||||
'@types/react-syntax-highlighter':
|
||||
specifier: ^15.5.7
|
||||
version: 15.5.7
|
||||
'@types/uuid':
|
||||
specifier: ^9.0.2
|
||||
version: 9.0.2
|
||||
'@typescript-eslint/eslint-plugin':
|
||||
specifier: ^5.59.6
|
||||
version: 5.59.6(@typescript-eslint/parser@5.59.6)(eslint@8.40.0)(typescript@5.0.4)
|
||||
@@ -3018,6 +3030,10 @@ packages:
|
||||
resolution: {integrity: sha512-cputDpIbFgLUaGQn6Vqg3/YsJwxUwHLO13v3i5ouxT4lat0khip9AEWxtERujXV9wxIB1EyF97BSJFt6vpdI8g==}
|
||||
dev: false
|
||||
|
||||
/@types/uuid@9.0.2:
|
||||
resolution: {integrity: sha512-kNnC1GFBLuhImSnV7w4njQkUiJi0ZXUycu1rUaouPqiKlXkh77JKgdRnTAp1x5eBwcIwbtI+3otwzuIDEuDoxQ==}
|
||||
dev: true
|
||||
|
||||
/@typescript-eslint/eslint-plugin@5.59.6(@typescript-eslint/parser@5.59.6)(eslint@8.40.0)(typescript@5.0.4):
|
||||
resolution: {integrity: sha512-sXtOgJNEuRU5RLwPUb1jxtToZbgvq3M6FPpY4QENxoOggK+UpTxUBpj6tD8+Qh2g46Pi9We87E+eHnUw8YcGsw==}
|
||||
engines: {node: ^12.22.0 || ^14.17.0 || >=16.0.0}
|
||||
@@ -6037,6 +6053,19 @@ packages:
|
||||
uuid: 8.3.2
|
||||
dev: false
|
||||
|
||||
/next-query-params@4.2.3(next@13.4.2)(react@18.2.0)(use-query-params@2.2.1):
|
||||
resolution: {integrity: sha512-hGNCYRH8YyA5ItiBGSKrtMl21b2MAqfPkdI1mvwloNVqSU142IaGzqHN+OTovyeLIpQfonY01y7BAHb/UH4POg==}
|
||||
peerDependencies:
|
||||
next: ^10.0.0 || ^11.0.0 || ^12.0.0 || ^13.0.0
|
||||
react: ^16.8.0 || ^17.0.0 || ^18.0.0
|
||||
use-query-params: ^2.0.0
|
||||
dependencies:
|
||||
next: 13.4.2(@babel/core@7.22.9)(react-dom@18.2.0)(react@18.2.0)
|
||||
react: 18.2.0
|
||||
tslib: 2.6.0
|
||||
use-query-params: 2.2.1(react-dom@18.2.0)(react@18.2.0)
|
||||
dev: false
|
||||
|
||||
/next-tick@1.1.0:
|
||||
resolution: {integrity: sha512-CXdUiJembsNjuToQvxayPZF9Vqht7hewsvy2sOWafLvi2awflj9mOC6bHIg50orX8IJvWKY9wYQ/zB2kogPslQ==}
|
||||
dev: false
|
||||
@@ -7147,6 +7176,10 @@ packages:
|
||||
randombytes: 2.1.0
|
||||
dev: true
|
||||
|
||||
/serialize-query-params@2.0.2:
|
||||
resolution: {integrity: sha512-1chMo1dST4pFA9RDXAtF0Rbjaut4is7bzFbI1Z26IuMub68pNCILku85aYmeFhvnY//BXUPUhoRMjYcsT93J/Q==}
|
||||
dev: false
|
||||
|
||||
/serve-static@1.15.0:
|
||||
resolution: {integrity: sha512-XGuRDNjXUijsUL0vl6nSD7cwURuzEgglbOaFuZM9g3kwDXOWVTck0jLzjPzGD+TazWbboZYu52/9/XPdUgne9g==}
|
||||
engines: {node: '>= 0.8.0'}
|
||||
@@ -7824,6 +7857,24 @@ packages:
|
||||
use-isomorphic-layout-effect: 1.1.2(@types/react@18.2.6)(react@18.2.0)
|
||||
dev: false
|
||||
|
||||
/use-query-params@2.2.1(react-dom@18.2.0)(react@18.2.0):
|
||||
resolution: {integrity: sha512-i6alcyLB8w9i3ZK3caNftdb+UnbfBRNPDnc89CNQWkGRmDrm/gfydHvMBfVsQJRq3NoHOM2dt/ceBWG2397v1Q==}
|
||||
peerDependencies:
|
||||
'@reach/router': ^1.2.1
|
||||
react: '>=16.8.0'
|
||||
react-dom: '>=16.8.0'
|
||||
react-router-dom: '>=5'
|
||||
peerDependenciesMeta:
|
||||
'@reach/router':
|
||||
optional: true
|
||||
react-router-dom:
|
||||
optional: true
|
||||
dependencies:
|
||||
react: 18.2.0
|
||||
react-dom: 18.2.0(react@18.2.0)
|
||||
serialize-query-params: 2.0.2
|
||||
dev: false
|
||||
|
||||
/use-sidecar@1.1.2(@types/react@18.2.6)(react@18.2.0):
|
||||
resolution: {integrity: sha512-epTbsLuzZ7lPClpz2TyryBfztm7m+28DlEv2ZCQ3MDr5ssiwyOwGH/e5F9CkfWjJ1t4clvI58yF822/GUkjjhw==}
|
||||
engines: {node: '>=10'}
|
||||
@@ -7872,6 +7923,11 @@ packages:
|
||||
hasBin: true
|
||||
dev: false
|
||||
|
||||
/uuid@9.0.0:
|
||||
resolution: {integrity: sha512-MXcSTerfPa4uqyzStbRoTgt5XIe3x5+42+q1sDuy3R5MDk66URdLMOZe5aPX/SQd+kuYAh0FdP/pO28IkQyTeg==}
|
||||
hasBin: true
|
||||
dev: false
|
||||
|
||||
/vary@1.1.2:
|
||||
resolution: {integrity: sha512-BNGbWLfd0eUPabhkXUVm0j8uuvREyTh5ovRa/dyow/BqAbZJyC+5fU+IzQOzmAKzYqYRAISoRhdQr3eIZ/PXqg==}
|
||||
engines: {node: '>= 0.8'}
|
||||
|
||||
@@ -0,0 +1,8 @@
|
||||
/*
|
||||
Warnings:
|
||||
|
||||
- You are about to drop the column `streamingChannel` on the `ScenarioVariantCell` table. All the data in the column will be lost.
|
||||
|
||||
*/
|
||||
-- AlterTable
|
||||
ALTER TABLE "ScenarioVariantCell" DROP COLUMN "streamingChannel";
|
||||
@@ -22,10 +22,10 @@ model Experiment {
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
TemplateVariable TemplateVariable[]
|
||||
PromptVariant PromptVariant[]
|
||||
TestScenario TestScenario[]
|
||||
Evaluation Evaluation[]
|
||||
templateVariables TemplateVariable[]
|
||||
promptVariants PromptVariant[]
|
||||
testScenarios TestScenario[]
|
||||
evaluations Evaluation[]
|
||||
}
|
||||
|
||||
model PromptVariant {
|
||||
@@ -90,11 +90,10 @@ enum CellRetrievalStatus {
|
||||
model ScenarioVariantCell {
|
||||
id String @id @default(uuid()) @db.Uuid
|
||||
|
||||
statusCode Int?
|
||||
errorMessage String?
|
||||
retryTime DateTime?
|
||||
streamingChannel String?
|
||||
retrievalStatus CellRetrievalStatus @default(COMPLETE)
|
||||
statusCode Int?
|
||||
errorMessage String?
|
||||
retryTime DateTime?
|
||||
retrievalStatus CellRetrievalStatus @default(COMPLETE)
|
||||
|
||||
modelOutput ModelOutput?
|
||||
|
||||
@@ -126,7 +125,7 @@ model ModelOutput {
|
||||
|
||||
scenarioVariantCellId String @db.Uuid
|
||||
scenarioVariantCell ScenarioVariantCell @relation(fields: [scenarioVariantCellId], references: [id], onDelete: Cascade)
|
||||
outputEvaluation OutputEvaluation[]
|
||||
outputEvaluations OutputEvaluation[]
|
||||
|
||||
@@unique([scenarioVariantCellId])
|
||||
@@index([inputHash])
|
||||
@@ -150,7 +149,7 @@ model Evaluation {
|
||||
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
OutputEvaluation OutputEvaluation[]
|
||||
outputEvaluations OutputEvaluation[]
|
||||
}
|
||||
|
||||
model OutputEvaluation {
|
||||
@@ -179,8 +178,8 @@ model Organization {
|
||||
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
OrganizationUser OrganizationUser[]
|
||||
Experiment Experiment[]
|
||||
organizationUsers OrganizationUser[]
|
||||
experiments Experiment[]
|
||||
}
|
||||
|
||||
enum OrganizationUserRole {
|
||||
@@ -234,15 +233,15 @@ model Session {
|
||||
}
|
||||
|
||||
model User {
|
||||
id String @id @default(uuid()) @db.Uuid
|
||||
name String?
|
||||
email String? @unique
|
||||
emailVerified DateTime?
|
||||
image String?
|
||||
accounts Account[]
|
||||
sessions Session[]
|
||||
OrganizationUser OrganizationUser[]
|
||||
Organization Organization[]
|
||||
id String @id @default(uuid()) @db.Uuid
|
||||
name String?
|
||||
email String? @unique
|
||||
emailVerified DateTime?
|
||||
image String?
|
||||
accounts Account[]
|
||||
sessions Session[]
|
||||
organizationUsers OrganizationUser[]
|
||||
organizations Organization[]
|
||||
}
|
||||
|
||||
model VerificationToken {
|
||||
|
||||
@@ -7,9 +7,13 @@ const defaultId = "11111111-1111-1111-1111-111111111111";
|
||||
await prisma.organization.deleteMany({
|
||||
where: { id: defaultId },
|
||||
});
|
||||
await prisma.organization.create({
|
||||
data: { id: defaultId },
|
||||
});
|
||||
|
||||
// If there's an existing org, just seed into it
|
||||
const org =
|
||||
(await prisma.organization.findFirst({})) ??
|
||||
(await prisma.organization.create({
|
||||
data: { id: defaultId },
|
||||
}));
|
||||
|
||||
await prisma.experiment.deleteMany({
|
||||
where: {
|
||||
@@ -21,7 +25,7 @@ await prisma.experiment.create({
|
||||
data: {
|
||||
id: defaultId,
|
||||
label: "Country Capitals Example",
|
||||
organizationId: defaultId,
|
||||
organizationId: org.id,
|
||||
},
|
||||
});
|
||||
|
||||
@@ -103,30 +107,41 @@ await prisma.testScenario.deleteMany({
|
||||
},
|
||||
});
|
||||
|
||||
const countries = [
|
||||
"Afghanistan",
|
||||
"Albania",
|
||||
"Algeria",
|
||||
"Andorra",
|
||||
"Angola",
|
||||
"Antigua and Barbuda",
|
||||
"Argentina",
|
||||
"Armenia",
|
||||
"Australia",
|
||||
"Austria",
|
||||
"Austrian Empire",
|
||||
"Azerbaijan",
|
||||
"Baden",
|
||||
"Bahamas, The",
|
||||
"Bahrain",
|
||||
"Bangladesh",
|
||||
"Barbados",
|
||||
"Bavaria",
|
||||
"Belarus",
|
||||
"Belgium",
|
||||
"Belize",
|
||||
"Benin (Dahomey)",
|
||||
"Bolivia",
|
||||
"Bosnia and Herzegovina",
|
||||
"Botswana",
|
||||
];
|
||||
await prisma.testScenario.createMany({
|
||||
data: [
|
||||
{
|
||||
experimentId: defaultId,
|
||||
sortIndex: 0,
|
||||
variableValues: {
|
||||
country: "Spain",
|
||||
},
|
||||
data: countries.map((country, i) => ({
|
||||
experimentId: defaultId,
|
||||
sortIndex: i,
|
||||
variableValues: {
|
||||
country: country,
|
||||
},
|
||||
{
|
||||
experimentId: defaultId,
|
||||
sortIndex: 1,
|
||||
variableValues: {
|
||||
country: "USA",
|
||||
},
|
||||
},
|
||||
{
|
||||
experimentId: defaultId,
|
||||
sortIndex: 2,
|
||||
variableValues: {
|
||||
country: "Chile",
|
||||
},
|
||||
},
|
||||
],
|
||||
})),
|
||||
});
|
||||
|
||||
const variants = await prisma.promptVariant.findMany({
|
||||
@@ -149,5 +164,5 @@ await Promise.all(
|
||||
testScenarioId: scenario.id,
|
||||
})),
|
||||
)
|
||||
.map((cell) => generateNewCell(cell.promptVariantId, cell.testScenarioId)),
|
||||
.map((cell) => generateNewCell(cell.promptVariantId, cell.testScenarioId, { stream: false })),
|
||||
);
|
||||
|
||||
@@ -6,4 +6,7 @@ echo "Migrating the database"
|
||||
pnpm prisma migrate deploy
|
||||
|
||||
echo "Starting the server"
|
||||
pnpm start
|
||||
|
||||
pnpm concurrently --kill-others \
|
||||
"pnpm start" \
|
||||
"pnpm tsx src/server/tasks/worker.ts"
|
||||
@@ -1,19 +1,22 @@
|
||||
import { Textarea, type TextareaProps } from "@chakra-ui/react";
|
||||
import ResizeTextarea from "react-textarea-autosize";
|
||||
import React from "react";
|
||||
import React, { useLayoutEffect, useState } from "react";
|
||||
|
||||
export const AutoResizeTextarea: React.ForwardRefRenderFunction<
|
||||
HTMLTextAreaElement,
|
||||
TextareaProps & { minRows?: number }
|
||||
> = (props, ref) => {
|
||||
> = ({ minRows = 1, overflowY = "hidden", ...props }, ref) => {
|
||||
const [isRerendered, setIsRerendered] = useState(false);
|
||||
useLayoutEffect(() => setIsRerendered(true), []);
|
||||
|
||||
return (
|
||||
<Textarea
|
||||
minH="unset"
|
||||
overflow="hidden"
|
||||
minRows={minRows}
|
||||
overflowY={isRerendered ? overflowY : "hidden"}
|
||||
w="100%"
|
||||
resize="none"
|
||||
ref={ref}
|
||||
minRows={1}
|
||||
transition="height none"
|
||||
as={ResizeTextarea}
|
||||
{...props}
|
||||
|
||||
142
src/components/ChangeModelModal/ChangeModelModal.tsx
Normal file
142
src/components/ChangeModelModal/ChangeModelModal.tsx
Normal file
@@ -0,0 +1,142 @@
|
||||
import {
|
||||
Button,
|
||||
HStack,
|
||||
Icon,
|
||||
Modal,
|
||||
ModalBody,
|
||||
ModalCloseButton,
|
||||
ModalContent,
|
||||
ModalFooter,
|
||||
ModalHeader,
|
||||
ModalOverlay,
|
||||
Spinner,
|
||||
Text,
|
||||
VStack,
|
||||
} from "@chakra-ui/react";
|
||||
import { type PromptVariant } from "@prisma/client";
|
||||
import { isObject, isString } from "lodash-es";
|
||||
import { useState } from "react";
|
||||
import { RiExchangeFundsFill } from "react-icons/ri";
|
||||
import { type ProviderModel } from "~/modelProviders/types";
|
||||
import { api } from "~/utils/api";
|
||||
import { useExperiment, useHandledAsyncCallback, useVisibleScenarioIds } from "~/utils/hooks";
|
||||
import { lookupModel, modelLabel } from "~/utils/utils";
|
||||
import CompareFunctions from "../RefinePromptModal/CompareFunctions";
|
||||
import { ModelSearch } from "./ModelSearch";
|
||||
import { ModelStatsCard } from "./ModelStatsCard";
|
||||
|
||||
export const ChangeModelModal = ({
|
||||
variant,
|
||||
onClose,
|
||||
}: {
|
||||
variant: PromptVariant;
|
||||
onClose: () => void;
|
||||
}) => {
|
||||
const originalModel = lookupModel(variant.modelProvider, variant.model);
|
||||
const [selectedModel, setSelectedModel] = useState({
|
||||
provider: variant.modelProvider,
|
||||
model: variant.model,
|
||||
} as ProviderModel);
|
||||
const [convertedModel, setConvertedModel] = useState<ProviderModel | undefined>();
|
||||
const visibleScenarios = useVisibleScenarioIds();
|
||||
|
||||
const utils = api.useContext();
|
||||
|
||||
const experiment = useExperiment();
|
||||
|
||||
const { mutateAsync: getModifiedPromptMutateAsync, data: modifiedPromptFn } =
|
||||
api.promptVariants.getModifiedPromptFn.useMutation();
|
||||
|
||||
const [getModifiedPromptFn, modificationInProgress] = useHandledAsyncCallback(async () => {
|
||||
if (!experiment) return;
|
||||
|
||||
await getModifiedPromptMutateAsync({
|
||||
id: variant.id,
|
||||
newModel: selectedModel,
|
||||
});
|
||||
setConvertedModel(selectedModel);
|
||||
}, [getModifiedPromptMutateAsync, onClose, experiment, variant, selectedModel]);
|
||||
|
||||
const replaceVariantMutation = api.promptVariants.replaceVariant.useMutation();
|
||||
|
||||
const [replaceVariant, replacementInProgress] = useHandledAsyncCallback(async () => {
|
||||
if (
|
||||
!variant.experimentId ||
|
||||
!modifiedPromptFn ||
|
||||
(isObject(modifiedPromptFn) && "status" in modifiedPromptFn)
|
||||
)
|
||||
return;
|
||||
await replaceVariantMutation.mutateAsync({
|
||||
id: variant.id,
|
||||
constructFn: modifiedPromptFn,
|
||||
streamScenarios: visibleScenarios,
|
||||
});
|
||||
await utils.promptVariants.list.invalidate();
|
||||
onClose();
|
||||
}, [replaceVariantMutation, variant, onClose, modifiedPromptFn]);
|
||||
|
||||
const originalLabel = modelLabel(variant.modelProvider, variant.model);
|
||||
const selectedLabel = modelLabel(selectedModel.provider, selectedModel.model);
|
||||
const convertedLabel =
|
||||
convertedModel && modelLabel(convertedModel.provider, convertedModel.model);
|
||||
|
||||
return (
|
||||
<Modal
|
||||
isOpen
|
||||
onClose={onClose}
|
||||
size={{ base: "xl", sm: "2xl", md: "3xl", lg: "5xl", xl: "7xl" }}
|
||||
>
|
||||
<ModalOverlay />
|
||||
<ModalContent w={1200}>
|
||||
<ModalHeader>
|
||||
<HStack>
|
||||
<Icon as={RiExchangeFundsFill} />
|
||||
<Text>Change Model</Text>
|
||||
</HStack>
|
||||
</ModalHeader>
|
||||
<ModalCloseButton />
|
||||
<ModalBody maxW="unset">
|
||||
<VStack spacing={8}>
|
||||
<ModelStatsCard label="Original Model" model={originalModel} />
|
||||
{originalLabel !== selectedLabel && (
|
||||
<ModelStatsCard
|
||||
label="New Model"
|
||||
model={lookupModel(selectedModel.provider, selectedModel.model)}
|
||||
/>
|
||||
)}
|
||||
<ModelSearch selectedModel={selectedModel} setSelectedModel={setSelectedModel} />
|
||||
{isString(modifiedPromptFn) && (
|
||||
<CompareFunctions
|
||||
originalFunction={variant.constructFn}
|
||||
newFunction={modifiedPromptFn}
|
||||
leftTitle={originalLabel}
|
||||
rightTitle={convertedLabel}
|
||||
/>
|
||||
)}
|
||||
</VStack>
|
||||
</ModalBody>
|
||||
|
||||
<ModalFooter>
|
||||
<HStack>
|
||||
<Button
|
||||
colorScheme="gray"
|
||||
onClick={getModifiedPromptFn}
|
||||
minW={24}
|
||||
isDisabled={originalLabel === selectedLabel || modificationInProgress}
|
||||
>
|
||||
{modificationInProgress ? <Spinner boxSize={4} /> : <Text>Convert</Text>}
|
||||
</Button>
|
||||
<Button
|
||||
colorScheme="blue"
|
||||
onClick={replaceVariant}
|
||||
minW={24}
|
||||
isDisabled={!convertedModel || modificationInProgress || replacementInProgress}
|
||||
>
|
||||
{replacementInProgress ? <Spinner boxSize={4} /> : <Text>Accept</Text>}
|
||||
</Button>
|
||||
</HStack>
|
||||
</ModalFooter>
|
||||
</ModalContent>
|
||||
</Modal>
|
||||
);
|
||||
};
|
||||
36
src/components/ChangeModelModal/ModelSearch.tsx
Normal file
36
src/components/ChangeModelModal/ModelSearch.tsx
Normal file
@@ -0,0 +1,36 @@
|
||||
import { Text, VStack } from "@chakra-ui/react";
|
||||
import { type LegacyRef } from "react";
|
||||
import Select from "react-select";
|
||||
import { useElementDimensions } from "~/utils/hooks";
|
||||
|
||||
import { flatMap } from "lodash-es";
|
||||
import frontendModelProviders from "~/modelProviders/frontendModelProviders";
|
||||
import { type ProviderModel } from "~/modelProviders/types";
|
||||
import { modelLabel } from "~/utils/utils";
|
||||
|
||||
const modelOptions = flatMap(Object.entries(frontendModelProviders), ([providerId, provider]) =>
|
||||
Object.entries(provider.models).map(([modelId]) => ({
|
||||
provider: providerId,
|
||||
model: modelId,
|
||||
})),
|
||||
) as ProviderModel[];
|
||||
|
||||
export const ModelSearch = (props: {
|
||||
selectedModel: ProviderModel;
|
||||
setSelectedModel: (model: ProviderModel) => void;
|
||||
}) => {
|
||||
const [containerRef, containerDimensions] = useElementDimensions();
|
||||
|
||||
return (
|
||||
<VStack ref={containerRef as LegacyRef<HTMLDivElement>} w="full">
|
||||
<Text>Browse Models</Text>
|
||||
<Select<ProviderModel>
|
||||
styles={{ control: (provided) => ({ ...provided, width: containerDimensions?.width }) }}
|
||||
getOptionLabel={(data) => modelLabel(data.provider, data.model)}
|
||||
getOptionValue={(data) => modelLabel(data.provider, data.model)}
|
||||
options={modelOptions}
|
||||
onChange={(option) => option && props.setSelectedModel(option)}
|
||||
/>
|
||||
</VStack>
|
||||
);
|
||||
};
|
||||
109
src/components/ChangeModelModal/ModelStatsCard.tsx
Normal file
109
src/components/ChangeModelModal/ModelStatsCard.tsx
Normal file
@@ -0,0 +1,109 @@
|
||||
import {
|
||||
GridItem,
|
||||
HStack,
|
||||
Link,
|
||||
SimpleGrid,
|
||||
Text,
|
||||
VStack,
|
||||
type StackProps,
|
||||
} from "@chakra-ui/react";
|
||||
import { type lookupModel } from "~/utils/utils";
|
||||
|
||||
export const ModelStatsCard = ({
|
||||
label,
|
||||
model,
|
||||
}: {
|
||||
label: string;
|
||||
model: ReturnType<typeof lookupModel>;
|
||||
}) => {
|
||||
if (!model) return null;
|
||||
return (
|
||||
<VStack w="full" align="start">
|
||||
<Text fontWeight="bold" fontSize="sm" textTransform="uppercase">
|
||||
{label}
|
||||
</Text>
|
||||
|
||||
<VStack w="full" spacing={6} bgColor="gray.100" p={4} borderRadius={4}>
|
||||
<HStack w="full" align="flex-start">
|
||||
<Text flex={1} fontSize="lg">
|
||||
<Text as="span" color="gray.600">
|
||||
{model.provider} /{" "}
|
||||
</Text>
|
||||
<Text as="span" fontWeight="bold" color="gray.900">
|
||||
{model.name}
|
||||
</Text>
|
||||
</Text>
|
||||
<Link
|
||||
href={model.learnMoreUrl}
|
||||
isExternal
|
||||
color="blue.500"
|
||||
fontWeight="bold"
|
||||
fontSize="sm"
|
||||
ml={2}
|
||||
>
|
||||
Learn More
|
||||
</Link>
|
||||
</HStack>
|
||||
<SimpleGrid
|
||||
w="full"
|
||||
justifyContent="space-between"
|
||||
alignItems="flex-start"
|
||||
fontSize="sm"
|
||||
columns={{ base: 2, md: 4 }}
|
||||
>
|
||||
<SelectedModelLabeledInfo label="Context Window" info={model.contextWindow} />
|
||||
{model.promptTokenPrice && (
|
||||
<SelectedModelLabeledInfo
|
||||
label="Input"
|
||||
info={
|
||||
<Text>
|
||||
${(model.promptTokenPrice * 1000).toFixed(3)}
|
||||
<Text color="gray.500"> / 1K tokens</Text>
|
||||
</Text>
|
||||
}
|
||||
/>
|
||||
)}
|
||||
{model.completionTokenPrice && (
|
||||
<SelectedModelLabeledInfo
|
||||
label="Output"
|
||||
info={
|
||||
<Text>
|
||||
${(model.completionTokenPrice * 1000).toFixed(3)}
|
||||
<Text color="gray.500"> / 1K tokens</Text>
|
||||
</Text>
|
||||
}
|
||||
/>
|
||||
)}
|
||||
{model.pricePerSecond && (
|
||||
<SelectedModelLabeledInfo
|
||||
label="Price"
|
||||
info={
|
||||
<Text>
|
||||
${model.pricePerSecond.toFixed(3)}
|
||||
<Text color="gray.500"> / second</Text>
|
||||
</Text>
|
||||
}
|
||||
/>
|
||||
)}
|
||||
<SelectedModelLabeledInfo label="Speed" info={<Text>{model.speed}</Text>} />
|
||||
</SimpleGrid>
|
||||
</VStack>
|
||||
</VStack>
|
||||
);
|
||||
};
|
||||
|
||||
const SelectedModelLabeledInfo = ({
|
||||
label,
|
||||
info,
|
||||
...props
|
||||
}: {
|
||||
label: string;
|
||||
info: string | number | React.ReactElement;
|
||||
} & StackProps) => (
|
||||
<GridItem>
|
||||
<VStack alignItems="flex-start" {...props}>
|
||||
<Text fontWeight="bold">{label}</Text>
|
||||
<Text>{info}</Text>
|
||||
</VStack>
|
||||
</GridItem>
|
||||
);
|
||||
77
src/components/ExperimentSettingsDrawer/DeleteButton.tsx
Normal file
77
src/components/ExperimentSettingsDrawer/DeleteButton.tsx
Normal file
@@ -0,0 +1,77 @@
|
||||
import {
|
||||
Button,
|
||||
Icon,
|
||||
AlertDialog,
|
||||
AlertDialogBody,
|
||||
AlertDialogFooter,
|
||||
AlertDialogHeader,
|
||||
AlertDialogContent,
|
||||
AlertDialogOverlay,
|
||||
useDisclosure,
|
||||
Text,
|
||||
} from "@chakra-ui/react";
|
||||
|
||||
import { useRouter } from "next/router";
|
||||
import { useRef } from "react";
|
||||
import { BsTrash } from "react-icons/bs";
|
||||
import { api } from "~/utils/api";
|
||||
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
|
||||
|
||||
export const DeleteButton = () => {
|
||||
const experiment = useExperiment();
|
||||
const mutation = api.experiments.delete.useMutation();
|
||||
const utils = api.useContext();
|
||||
const router = useRouter();
|
||||
|
||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||
const cancelRef = useRef<HTMLButtonElement>(null);
|
||||
|
||||
const [onDeleteConfirm] = useHandledAsyncCallback(async () => {
|
||||
if (!experiment.data?.id) return;
|
||||
await mutation.mutateAsync({ id: experiment.data.id });
|
||||
await utils.experiments.list.invalidate();
|
||||
await router.push({ pathname: "/experiments" });
|
||||
onClose();
|
||||
}, [mutation, experiment.data?.id, router]);
|
||||
|
||||
return (
|
||||
<>
|
||||
<Button
|
||||
size="sm"
|
||||
variant={{ base: "outline", lg: "ghost" }}
|
||||
colorScheme="red"
|
||||
fontWeight="normal"
|
||||
onClick={onOpen}
|
||||
>
|
||||
<Icon as={BsTrash} boxSize={4} />
|
||||
<Text display={{ base: "none", lg: "block" }} ml={2}>
|
||||
Delete Experiment
|
||||
</Text>
|
||||
</Button>
|
||||
|
||||
<AlertDialog isOpen={isOpen} leastDestructiveRef={cancelRef} onClose={onClose}>
|
||||
<AlertDialogOverlay>
|
||||
<AlertDialogContent>
|
||||
<AlertDialogHeader fontSize="lg" fontWeight="bold">
|
||||
Delete Experiment
|
||||
</AlertDialogHeader>
|
||||
|
||||
<AlertDialogBody>
|
||||
If you delete this experiment all the associated prompts and scenarios will be deleted
|
||||
as well. Are you sure?
|
||||
</AlertDialogBody>
|
||||
|
||||
<AlertDialogFooter>
|
||||
<Button ref={cancelRef} onClick={onClose}>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button colorScheme="red" onClick={onDeleteConfirm} ml={3}>
|
||||
Delete
|
||||
</Button>
|
||||
</AlertDialogFooter>
|
||||
</AlertDialogContent>
|
||||
</AlertDialogOverlay>
|
||||
</AlertDialog>
|
||||
</>
|
||||
);
|
||||
};
|
||||
@@ -6,13 +6,14 @@ import {
|
||||
DrawerHeader,
|
||||
DrawerOverlay,
|
||||
Heading,
|
||||
Stack,
|
||||
VStack,
|
||||
} from "@chakra-ui/react";
|
||||
import EditScenarioVars from "./EditScenarioVars";
|
||||
import EditEvaluations from "./EditEvaluations";
|
||||
import EditScenarioVars from "../OutputsTable/EditScenarioVars";
|
||||
import EditEvaluations from "../OutputsTable/EditEvaluations";
|
||||
import { useAppStore } from "~/state/store";
|
||||
import { DeleteButton } from "./DeleteButton";
|
||||
|
||||
export default function SettingsDrawer() {
|
||||
export default function ExperimentSettingsDrawer() {
|
||||
const isOpen = useAppStore((state) => state.drawerOpen);
|
||||
const closeDrawer = useAppStore((state) => state.closeDrawer);
|
||||
|
||||
@@ -22,13 +23,16 @@ export default function SettingsDrawer() {
|
||||
<DrawerContent>
|
||||
<DrawerCloseButton />
|
||||
<DrawerHeader>
|
||||
<Heading size="md">Settings</Heading>
|
||||
<Heading size="md">Experiment Settings</Heading>
|
||||
</DrawerHeader>
|
||||
<DrawerBody>
|
||||
<Stack spacing={6}>
|
||||
<EditScenarioVars />
|
||||
<EditEvaluations />
|
||||
</Stack>
|
||||
<DrawerBody h="full" pb={4}>
|
||||
<VStack h="full" justifyContent="space-between">
|
||||
<VStack spacing={6}>
|
||||
<EditScenarioVars />
|
||||
<EditEvaluations />
|
||||
</VStack>
|
||||
<DeleteButton />
|
||||
</VStack>
|
||||
</DrawerBody>
|
||||
</DrawerContent>
|
||||
</Drawer>
|
||||
57
src/components/OutputsTable/AddVariantButton.tsx
Normal file
57
src/components/OutputsTable/AddVariantButton.tsx
Normal file
@@ -0,0 +1,57 @@
|
||||
import { Box, Flex, Icon, Spinner } from "@chakra-ui/react";
|
||||
import { BsPlus } from "react-icons/bs";
|
||||
import { Text } from "@chakra-ui/react";
|
||||
import { api } from "~/utils/api";
|
||||
import {
|
||||
useExperiment,
|
||||
useExperimentAccess,
|
||||
useHandledAsyncCallback,
|
||||
useVisibleScenarioIds,
|
||||
} from "~/utils/hooks";
|
||||
import { cellPadding } from "../constants";
|
||||
import { ActionButton } from "./ScenariosHeader";
|
||||
|
||||
export default function AddVariantButton() {
|
||||
const experiment = useExperiment();
|
||||
const mutation = api.promptVariants.create.useMutation();
|
||||
const utils = api.useContext();
|
||||
const visibleScenarios = useVisibleScenarioIds();
|
||||
|
||||
const [onClick, loading] = useHandledAsyncCallback(async () => {
|
||||
if (!experiment.data) return;
|
||||
await mutation.mutateAsync({
|
||||
experimentId: experiment.data.id,
|
||||
streamScenarios: visibleScenarios,
|
||||
});
|
||||
await utils.promptVariants.list.invalidate();
|
||||
}, [mutation]);
|
||||
|
||||
const { canModify } = useExperimentAccess();
|
||||
if (!canModify) return <Box w={cellPadding.x} />;
|
||||
|
||||
return (
|
||||
<Flex w="100%" justifyContent="flex-end">
|
||||
<ActionButton
|
||||
onClick={onClick}
|
||||
py={5}
|
||||
leftIcon={<Icon as={loading ? Spinner : BsPlus} boxSize={6} mr={loading ? 1 : 0} />}
|
||||
>
|
||||
<Text display={{ base: "none", md: "flex" }}>Add Variant</Text>
|
||||
</ActionButton>
|
||||
{/* <Button
|
||||
alignItems="center"
|
||||
justifyContent="center"
|
||||
fontWeight="normal"
|
||||
bgColor="transparent"
|
||||
_hover={{ bgColor: "gray.100" }}
|
||||
px={cellPadding.x}
|
||||
onClick={onClick}
|
||||
height="unset"
|
||||
minH={headerMinHeight}
|
||||
>
|
||||
<Icon as={loading ? Spinner : BsPlus} boxSize={6} mr={loading ? 1 : 0} />
|
||||
<Text display={{ base: "none", md: "flex" }}>Add Variant</Text>
|
||||
</Button> */}
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
@@ -18,11 +18,9 @@ export const FloatingLabelInput = ({
|
||||
transform={isFocused || !!value ? "translateY(-50%)" : "translateY(0)"}
|
||||
fontSize={isFocused || !!value ? "12px" : "16px"}
|
||||
transition="all 0.15s"
|
||||
zIndex="100"
|
||||
zIndex="5"
|
||||
bg="white"
|
||||
px={1}
|
||||
mt={0}
|
||||
mb={2}
|
||||
lineHeight="1"
|
||||
pointerEvents="none"
|
||||
color={isFocused ? "blue.500" : "gray.500"}
|
||||
|
||||
@@ -1,57 +0,0 @@
|
||||
import { Button, type ButtonProps, HStack, Spinner, Icon } from "@chakra-ui/react";
|
||||
import { BsPlus } from "react-icons/bs";
|
||||
import { api } from "~/utils/api";
|
||||
import { useExperiment, useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
|
||||
|
||||
// Extracted Button styling into reusable component
|
||||
const StyledButton = ({ children, onClick }: ButtonProps) => (
|
||||
<Button
|
||||
fontWeight="normal"
|
||||
bgColor="transparent"
|
||||
_hover={{ bgColor: "gray.100" }}
|
||||
px={2}
|
||||
onClick={onClick}
|
||||
>
|
||||
{children}
|
||||
</Button>
|
||||
);
|
||||
|
||||
export default function NewScenarioButton() {
|
||||
const { canModify } = useExperimentAccess();
|
||||
|
||||
const experiment = useExperiment();
|
||||
const mutation = api.scenarios.create.useMutation();
|
||||
const utils = api.useContext();
|
||||
|
||||
const [onClick] = useHandledAsyncCallback(async () => {
|
||||
if (!experiment.data) return;
|
||||
await mutation.mutateAsync({
|
||||
experimentId: experiment.data.id,
|
||||
});
|
||||
await utils.scenarios.list.invalidate();
|
||||
}, [mutation]);
|
||||
|
||||
const [onAutogenerate, autogenerating] = useHandledAsyncCallback(async () => {
|
||||
if (!experiment.data) return;
|
||||
await mutation.mutateAsync({
|
||||
experimentId: experiment.data.id,
|
||||
autogenerate: true,
|
||||
});
|
||||
await utils.scenarios.list.invalidate();
|
||||
}, [mutation]);
|
||||
|
||||
if (!canModify) return null;
|
||||
|
||||
return (
|
||||
<HStack spacing={2}>
|
||||
<StyledButton onClick={onClick}>
|
||||
<Icon as={BsPlus} boxSize={6} />
|
||||
Add Scenario
|
||||
</StyledButton>
|
||||
<StyledButton onClick={onAutogenerate}>
|
||||
<Icon as={autogenerating ? Spinner : BsPlus} boxSize={6} mr={autogenerating ? 1 : 0} />
|
||||
Autogenerate Scenario
|
||||
</StyledButton>
|
||||
</HStack>
|
||||
);
|
||||
}
|
||||
@@ -1,40 +0,0 @@
|
||||
import { Box, Button, Icon, Spinner, Text } from "@chakra-ui/react";
|
||||
import { BsPlus } from "react-icons/bs";
|
||||
import { api } from "~/utils/api";
|
||||
import { useExperiment, useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
|
||||
import { cellPadding, headerMinHeight } from "../constants";
|
||||
|
||||
export default function NewVariantButton() {
|
||||
const experiment = useExperiment();
|
||||
const mutation = api.promptVariants.create.useMutation();
|
||||
const utils = api.useContext();
|
||||
|
||||
const [onClick, loading] = useHandledAsyncCallback(async () => {
|
||||
if (!experiment.data) return;
|
||||
await mutation.mutateAsync({
|
||||
experimentId: experiment.data.id,
|
||||
});
|
||||
await utils.promptVariants.list.invalidate();
|
||||
}, [mutation]);
|
||||
|
||||
const { canModify } = useExperimentAccess();
|
||||
if (!canModify) return <Box w={cellPadding.x} />;
|
||||
|
||||
return (
|
||||
<Button
|
||||
w="100%"
|
||||
alignItems="center"
|
||||
justifyContent="center"
|
||||
fontWeight="normal"
|
||||
bgColor="transparent"
|
||||
_hover={{ bgColor: "gray.100" }}
|
||||
px={cellPadding.x}
|
||||
onClick={onClick}
|
||||
height="unset"
|
||||
minH={headerMinHeight}
|
||||
>
|
||||
<Icon as={loading ? Spinner : BsPlus} boxSize={6} mr={loading ? 1 : 0} />
|
||||
<Text display={{ base: "none", md: "flex" }}>Add Variant</Text>
|
||||
</Button>
|
||||
);
|
||||
}
|
||||
@@ -10,7 +10,7 @@ import useSocket from "~/utils/useSocket";
|
||||
import { OutputStats } from "./OutputStats";
|
||||
import { ErrorHandler } from "./ErrorHandler";
|
||||
import { CellOptions } from "./CellOptions";
|
||||
import modelProvidersFrontend from "~/modelProviders/modelProvidersFrontend";
|
||||
import frontendModelProviders from "~/modelProviders/frontendModelProviders";
|
||||
|
||||
export default function OutputCell({
|
||||
scenario,
|
||||
@@ -40,7 +40,7 @@ export default function OutputCell({
|
||||
);
|
||||
|
||||
const provider =
|
||||
modelProvidersFrontend[variant.modelProvider as keyof typeof modelProvidersFrontend];
|
||||
frontendModelProviders[variant.modelProvider as keyof typeof frontendModelProviders];
|
||||
|
||||
type OutputSchema = Parameters<typeof provider.normalizeOutput>[0];
|
||||
|
||||
@@ -67,8 +67,8 @@ export default function OutputCell({
|
||||
|
||||
const modelOutput = cell?.modelOutput;
|
||||
|
||||
// Disconnect from socket if we're not streaming anymore
|
||||
const streamedMessage = useSocket<OutputSchema>(cell?.streamingChannel);
|
||||
// TODO: disconnect from socket if we're not streaming anymore
|
||||
const streamedMessage = useSocket<OutputSchema>(cell?.id);
|
||||
|
||||
if (!vars) return null;
|
||||
|
||||
@@ -81,18 +81,27 @@ export default function OutputCell({
|
||||
</Center>
|
||||
);
|
||||
|
||||
if (!cell && !fetchingOutput) return <Text color="gray.500">Error retrieving output</Text>;
|
||||
if (!cell && !fetchingOutput)
|
||||
return (
|
||||
<VStack>
|
||||
<CellOptions refetchingOutput={hardRefetching} refetchOutput={hardRefetch} />
|
||||
<Text color="gray.500">Error retrieving output</Text>
|
||||
</VStack>
|
||||
);
|
||||
|
||||
if (cell && cell.errorMessage) {
|
||||
return <ErrorHandler cell={cell} refetchOutput={hardRefetch} />;
|
||||
return (
|
||||
<VStack>
|
||||
<CellOptions refetchingOutput={hardRefetching} refetchOutput={hardRefetch} />
|
||||
<ErrorHandler cell={cell} refetchOutput={hardRefetch} />
|
||||
</VStack>
|
||||
);
|
||||
}
|
||||
|
||||
const normalizedOutput = modelOutput
|
||||
? // @ts-expect-error TODO FIX ASAP
|
||||
provider.normalizeOutput(modelOutput.output as unknown as OutputSchema)
|
||||
? provider.normalizeOutput(modelOutput.output)
|
||||
: streamedMessage
|
||||
? // @ts-expect-error TODO FIX ASAP
|
||||
provider.normalizeOutput(streamedMessage)
|
||||
? provider.normalizeOutput(streamedMessage)
|
||||
: null;
|
||||
|
||||
if (modelOutput && normalizedOutput?.type === "json") {
|
||||
|
||||
@@ -22,7 +22,7 @@ export const OutputStats = ({
|
||||
return (
|
||||
<HStack w="full" align="center" color="gray.500" fontSize="2xs" mt={{ base: 0, md: 1 }}>
|
||||
<HStack flex={1}>
|
||||
{modelOutput.outputEvaluation.map((evaluation) => {
|
||||
{modelOutput.outputEvaluations.map((evaluation) => {
|
||||
const passed = evaluation.result > 0.5;
|
||||
return (
|
||||
<Tooltip
|
||||
|
||||
74
src/components/OutputsTable/ScenarioPaginator.tsx
Normal file
74
src/components/OutputsTable/ScenarioPaginator.tsx
Normal file
@@ -0,0 +1,74 @@
|
||||
import { Box, HStack, IconButton } from "@chakra-ui/react";
|
||||
import {
|
||||
BsChevronDoubleLeft,
|
||||
BsChevronDoubleRight,
|
||||
BsChevronLeft,
|
||||
BsChevronRight,
|
||||
} from "react-icons/bs";
|
||||
import { usePage, useScenarios } from "~/utils/hooks";
|
||||
|
||||
const ScenarioPaginator = () => {
|
||||
const [page, setPage] = usePage();
|
||||
const { data } = useScenarios();
|
||||
|
||||
if (!data) return null;
|
||||
|
||||
const { scenarios, startIndex, lastPage, count } = data;
|
||||
|
||||
const nextPage = () => {
|
||||
if (page < lastPage) {
|
||||
setPage(page + 1, "replace");
|
||||
}
|
||||
};
|
||||
|
||||
const prevPage = () => {
|
||||
if (page > 1) {
|
||||
setPage(page - 1, "replace");
|
||||
}
|
||||
};
|
||||
|
||||
const goToLastPage = () => setPage(lastPage, "replace");
|
||||
const goToFirstPage = () => setPage(1, "replace");
|
||||
|
||||
return (
|
||||
<HStack pt={4}>
|
||||
<IconButton
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
onClick={goToFirstPage}
|
||||
isDisabled={page === 1}
|
||||
aria-label="Go to first page"
|
||||
icon={<BsChevronDoubleLeft />}
|
||||
/>
|
||||
<IconButton
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
onClick={prevPage}
|
||||
isDisabled={page === 1}
|
||||
aria-label="Previous page"
|
||||
icon={<BsChevronLeft />}
|
||||
/>
|
||||
<Box>
|
||||
{startIndex}-{startIndex + scenarios.length - 1} / {count}
|
||||
</Box>
|
||||
<IconButton
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
onClick={nextPage}
|
||||
isDisabled={page === lastPage}
|
||||
aria-label="Next page"
|
||||
icon={<BsChevronRight />}
|
||||
/>
|
||||
<IconButton
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
onClick={goToLastPage}
|
||||
isDisabled={page === lastPage}
|
||||
aria-label="Go to last page"
|
||||
icon={<BsChevronDoubleRight />}
|
||||
/>
|
||||
</HStack>
|
||||
);
|
||||
};
|
||||
|
||||
export default ScenarioPaginator;
|
||||
@@ -4,11 +4,13 @@ import { cellPadding } from "../constants";
|
||||
import OutputCell from "./OutputCell/OutputCell";
|
||||
import ScenarioEditor from "./ScenarioEditor";
|
||||
import type { PromptVariant, Scenario } from "./types";
|
||||
import { borders } from "./styles";
|
||||
|
||||
const ScenarioRow = (props: {
|
||||
scenario: Scenario;
|
||||
variants: PromptVariant[];
|
||||
canHide: boolean;
|
||||
rowStart: number;
|
||||
}) => {
|
||||
const [isHovered, setIsHovered] = useState(false);
|
||||
|
||||
@@ -21,15 +23,21 @@ const ScenarioRow = (props: {
|
||||
onMouseLeave={() => setIsHovered(false)}
|
||||
sx={isHovered ? highlightStyle : undefined}
|
||||
borderLeftWidth={1}
|
||||
{...borders}
|
||||
rowStart={props.rowStart}
|
||||
colStart={1}
|
||||
>
|
||||
<ScenarioEditor scenario={props.scenario} hovered={isHovered} canHide={props.canHide} />
|
||||
</GridItem>
|
||||
{props.variants.map((variant) => (
|
||||
{props.variants.map((variant, i) => (
|
||||
<GridItem
|
||||
key={variant.id}
|
||||
onMouseEnter={() => setIsHovered(true)}
|
||||
onMouseLeave={() => setIsHovered(false)}
|
||||
sx={isHovered ? highlightStyle : undefined}
|
||||
rowStart={props.rowStart}
|
||||
colStart={i + 2}
|
||||
{...borders}
|
||||
>
|
||||
<Box h="100%" w="100%" px={cellPadding.x} py={cellPadding.y}>
|
||||
<OutputCell key={variant.id} scenario={props.scenario} variant={variant} />
|
||||
|
||||
@@ -1,52 +1,82 @@
|
||||
import { Button, GridItem, HStack, Heading } from "@chakra-ui/react";
|
||||
import {
|
||||
Button,
|
||||
type ButtonProps,
|
||||
HStack,
|
||||
Text,
|
||||
Icon,
|
||||
Menu,
|
||||
MenuButton,
|
||||
MenuList,
|
||||
MenuItem,
|
||||
IconButton,
|
||||
Spinner,
|
||||
} from "@chakra-ui/react";
|
||||
import { cellPadding } from "../constants";
|
||||
import { useElementDimensions, useExperimentAccess } from "~/utils/hooks";
|
||||
import { stickyHeaderStyle } from "./styles";
|
||||
import { BsPencil } from "react-icons/bs";
|
||||
import {
|
||||
useExperiment,
|
||||
useExperimentAccess,
|
||||
useHandledAsyncCallback,
|
||||
useScenarios,
|
||||
} from "~/utils/hooks";
|
||||
import { BsGear, BsPencil, BsPlus, BsStars } from "react-icons/bs";
|
||||
import { useAppStore } from "~/state/store";
|
||||
import { api } from "~/utils/api";
|
||||
|
||||
export const ScenariosHeader = ({
|
||||
headerRows,
|
||||
numScenarios,
|
||||
}: {
|
||||
headerRows: number;
|
||||
numScenarios: number;
|
||||
}) => {
|
||||
export const ActionButton = (props: ButtonProps) => (
|
||||
<Button size="sm" variant="ghost" color="gray.600" {...props} />
|
||||
);
|
||||
|
||||
export const ScenariosHeader = () => {
|
||||
const openDrawer = useAppStore((s) => s.openDrawer);
|
||||
const { canModify } = useExperimentAccess();
|
||||
const scenarios = useScenarios();
|
||||
|
||||
const [ref, dimensions] = useElementDimensions();
|
||||
const topValue = dimensions ? `-${dimensions.height - 24}px` : "-455px";
|
||||
const experiment = useExperiment();
|
||||
const createScenarioMutation = api.scenarios.create.useMutation();
|
||||
const utils = api.useContext();
|
||||
|
||||
const [onAddScenario, loading] = useHandledAsyncCallback(
|
||||
async (autogenerate: boolean) => {
|
||||
if (!experiment.data) return;
|
||||
await createScenarioMutation.mutateAsync({
|
||||
experimentId: experiment.data.id,
|
||||
autogenerate,
|
||||
});
|
||||
await utils.scenarios.list.invalidate();
|
||||
},
|
||||
[createScenarioMutation],
|
||||
);
|
||||
|
||||
return (
|
||||
<GridItem
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
ref={ref as any}
|
||||
display="flex"
|
||||
alignItems="flex-end"
|
||||
rowSpan={headerRows}
|
||||
px={cellPadding.x}
|
||||
py={cellPadding.y}
|
||||
// Only display the part of the grid item that has content
|
||||
sx={{ ...stickyHeaderStyle, top: topValue }}
|
||||
>
|
||||
<HStack w="100%">
|
||||
<Heading size="xs" fontWeight="bold" flex={1}>
|
||||
Scenarios ({numScenarios})
|
||||
</Heading>
|
||||
{canModify && (
|
||||
<Button
|
||||
size="xs"
|
||||
<HStack w="100%" pb={cellPadding.y} pt={0} align="center" spacing={0}>
|
||||
<Text fontSize={16} fontWeight="bold">
|
||||
Scenarios ({scenarios.data?.count})
|
||||
</Text>
|
||||
{canModify && (
|
||||
<Menu>
|
||||
<MenuButton
|
||||
as={IconButton}
|
||||
mt={1}
|
||||
variant="ghost"
|
||||
color="gray.500"
|
||||
aria-label="Edit"
|
||||
leftIcon={<BsPencil />}
|
||||
onClick={openDrawer}
|
||||
>
|
||||
Edit Vars
|
||||
</Button>
|
||||
)}
|
||||
</HStack>
|
||||
</GridItem>
|
||||
aria-label="Edit Scenarios"
|
||||
icon={<Icon as={loading ? Spinner : BsGear} />}
|
||||
/>
|
||||
<MenuList fontSize="md" zIndex="dropdown" mt={-3}>
|
||||
<MenuItem
|
||||
icon={<Icon as={BsPlus} boxSize={6} mx="-5px" />}
|
||||
onClick={() => onAddScenario(false)}
|
||||
>
|
||||
Add Scenario
|
||||
</MenuItem>
|
||||
<MenuItem icon={<BsStars />} onClick={() => onAddScenario(true)}>
|
||||
Autogenerate Scenario
|
||||
</MenuItem>
|
||||
<MenuItem icon={<BsPencil />} onClick={openDrawer}>
|
||||
Edit Vars
|
||||
</MenuItem>
|
||||
</MenuList>
|
||||
</Menu>
|
||||
)}
|
||||
</HStack>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,17 +1,52 @@
|
||||
import { Box, Button, HStack, Spinner, Tooltip, useToast, Text } from "@chakra-ui/react";
|
||||
import { useRef, useEffect, useState, useCallback } from "react";
|
||||
import { useExperimentAccess, useHandledAsyncCallback, useModifierKeyLabel } from "~/utils/hooks";
|
||||
import { type PromptVariant } from "./types";
|
||||
import { api } from "~/utils/api";
|
||||
import {
|
||||
Box,
|
||||
Button,
|
||||
HStack,
|
||||
IconButton,
|
||||
Spinner,
|
||||
Text,
|
||||
Tooltip,
|
||||
useToast,
|
||||
} from "@chakra-ui/react";
|
||||
import { useCallback, useEffect, useRef, useState } from "react";
|
||||
import { FiMaximize, FiMinimize } from "react-icons/fi";
|
||||
import { editorBackground } from "~/state/sharedVariantEditor.slice";
|
||||
import { useAppStore } from "~/state/store";
|
||||
import { api } from "~/utils/api";
|
||||
import {
|
||||
useExperimentAccess,
|
||||
useHandledAsyncCallback,
|
||||
useModifierKeyLabel,
|
||||
useVisibleScenarioIds,
|
||||
} from "~/utils/hooks";
|
||||
import { type PromptVariant } from "./types";
|
||||
|
||||
export default function VariantEditor(props: { variant: PromptVariant }) {
|
||||
const { canModify } = useExperimentAccess();
|
||||
const monaco = useAppStore.use.sharedVariantEditor.monaco();
|
||||
const editorRef = useRef<ReturnType<NonNullable<typeof monaco>["editor"]["create"]> | null>(null);
|
||||
const containerRef = useRef<HTMLDivElement | null>(null);
|
||||
const [editorId] = useState(() => `editor_${Math.random().toString(36).substring(7)}`);
|
||||
const [isChanged, setIsChanged] = useState(false);
|
||||
|
||||
const [isFullscreen, setIsFullscreen] = useState(false);
|
||||
|
||||
const toggleFullscreen = useCallback(() => {
|
||||
setIsFullscreen((prev) => !prev);
|
||||
editorRef.current?.focus();
|
||||
}, [setIsFullscreen]);
|
||||
|
||||
useEffect(() => {
|
||||
const handleEsc = (event: KeyboardEvent) => {
|
||||
if (event.key === "Escape" && isFullscreen) {
|
||||
toggleFullscreen();
|
||||
}
|
||||
};
|
||||
|
||||
window.addEventListener("keydown", handleEsc);
|
||||
return () => window.removeEventListener("keydown", handleEsc);
|
||||
}, [isFullscreen, toggleFullscreen]);
|
||||
|
||||
const lastSavedFn = props.variant.constructFn;
|
||||
|
||||
const modifierKey = useModifierKeyLabel();
|
||||
@@ -33,6 +68,7 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
|
||||
const replaceVariant = api.promptVariants.replaceVariant.useMutation();
|
||||
const utils = api.useContext();
|
||||
const toast = useToast();
|
||||
const visibleScenarios = useVisibleScenarioIds();
|
||||
|
||||
const [onSave, saveInProgress] = useHandledAsyncCallback(async () => {
|
||||
if (!editorRef.current) return;
|
||||
@@ -61,6 +97,7 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
|
||||
const resp = await replaceVariant.mutateAsync({
|
||||
id: props.variant.id,
|
||||
constructFn: currentFn,
|
||||
streamScenarios: visibleScenarios,
|
||||
});
|
||||
if (resp.status === "error") {
|
||||
return toast({
|
||||
@@ -99,11 +136,23 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
|
||||
readOnly: !canModify,
|
||||
});
|
||||
|
||||
// Workaround because otherwise the commands only work on whatever
|
||||
// editor was loaded on the page last.
|
||||
// https://github.com/microsoft/monaco-editor/issues/2947#issuecomment-1422265201
|
||||
editorRef.current.onDidFocusEditorText(() => {
|
||||
// Workaround because otherwise the command only works on whatever
|
||||
// editor was loaded on the page last.
|
||||
// https://github.com/microsoft/monaco-editor/issues/2947#issuecomment-1422265201
|
||||
editorRef.current?.addCommand(monaco.KeyMod.CtrlCmd | monaco.KeyCode.Enter, onSave);
|
||||
editorRef.current?.addCommand(monaco.KeyMod.CtrlCmd | monaco.KeyCode.KeyS, onSave);
|
||||
|
||||
editorRef.current?.addCommand(
|
||||
monaco.KeyMod.CtrlCmd | monaco.KeyMod.Shift | monaco.KeyCode.KeyF,
|
||||
toggleFullscreen,
|
||||
);
|
||||
|
||||
// Exit fullscreen with escape
|
||||
editorRef.current?.addCommand(monaco.KeyCode.Escape, () => {
|
||||
if (isFullscreen) {
|
||||
toggleFullscreen();
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
editorRef.current.onDidChangeModelContent(checkForChanges);
|
||||
@@ -132,8 +181,40 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
|
||||
}, [canModify]);
|
||||
|
||||
return (
|
||||
<Box w="100%" pos="relative">
|
||||
<div id={editorId} style={{ height: "400px", width: "100%" }}></div>
|
||||
<Box
|
||||
w="100%"
|
||||
ref={containerRef}
|
||||
sx={
|
||||
isFullscreen
|
||||
? {
|
||||
position: "fixed",
|
||||
top: 0,
|
||||
left: 0,
|
||||
right: 0,
|
||||
bottom: 0,
|
||||
}
|
||||
: { h: "400px", w: "100%" }
|
||||
}
|
||||
bgColor={editorBackground}
|
||||
zIndex={isFullscreen ? 1000 : "unset"}
|
||||
pos="relative"
|
||||
_hover={{ ".fullscreen-toggle": { opacity: 1 } }}
|
||||
>
|
||||
<Box id={editorId} w="100%" h="100%" />
|
||||
<Tooltip label={`${modifierKey} + ⇧ + F`}>
|
||||
<IconButton
|
||||
className="fullscreen-toggle"
|
||||
aria-label="Minimize"
|
||||
icon={isFullscreen ? <FiMinimize /> : <FiMaximize />}
|
||||
position="absolute"
|
||||
top={2}
|
||||
right={2}
|
||||
onClick={toggleFullscreen}
|
||||
opacity={0}
|
||||
transition="opacity 0.2s"
|
||||
/>
|
||||
</Tooltip>
|
||||
|
||||
{isChanged && (
|
||||
<HStack pos="absolute" bottom={2} right={2}>
|
||||
<Button
|
||||
@@ -146,7 +227,7 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
|
||||
>
|
||||
Reset
|
||||
</Button>
|
||||
<Tooltip label={`${modifierKey} + Enter`}>
|
||||
<Tooltip label={`${modifierKey} + S`}>
|
||||
<Button size="sm" onClick={onSave} colorScheme="blue" w={16} disabled={saveInProgress}>
|
||||
{saveInProgress ? <Spinner boxSize={4} /> : <Text>Save</Text>}
|
||||
</Button>
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
import { Grid, GridItem } from "@chakra-ui/react";
|
||||
import { Grid, GridItem, type GridItemProps } from "@chakra-ui/react";
|
||||
import { api } from "~/utils/api";
|
||||
import NewScenarioButton from "./NewScenarioButton";
|
||||
import NewVariantButton from "./NewVariantButton";
|
||||
import AddVariantButton from "./AddVariantButton";
|
||||
import ScenarioRow from "./ScenarioRow";
|
||||
import VariantEditor from "./VariantEditor";
|
||||
import VariantHeader from "../VariantHeader/VariantHeader";
|
||||
import VariantStats from "./VariantStats";
|
||||
import { ScenariosHeader } from "./ScenariosHeader";
|
||||
import { stickyHeaderStyle } from "./styles";
|
||||
import { borders } from "./styles";
|
||||
import { useScenarios } from "~/utils/hooks";
|
||||
import ScenarioPaginator from "./ScenarioPaginator";
|
||||
import { Fragment } from "react";
|
||||
|
||||
export default function OutputsTable({ experimentId }: { experimentId: string | undefined }) {
|
||||
const variants = api.promptVariants.list.useQuery(
|
||||
@@ -15,68 +17,90 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
|
||||
{ enabled: !!experimentId },
|
||||
);
|
||||
|
||||
const scenarios = api.scenarios.list.useQuery(
|
||||
{ experimentId: experimentId as string },
|
||||
{ enabled: !!experimentId },
|
||||
);
|
||||
const scenarios = useScenarios();
|
||||
|
||||
if (!variants.data || !scenarios.data) return null;
|
||||
|
||||
const allCols = variants.data.length + 1;
|
||||
const headerRows = 3;
|
||||
const allCols = variants.data.length + 2;
|
||||
const variantHeaderRows = 3;
|
||||
const scenarioHeaderRows = 1;
|
||||
const scenarioFooterRows = 1;
|
||||
const visibleScenariosCount = scenarios.data.scenarios.length;
|
||||
const allRows =
|
||||
variantHeaderRows + scenarioHeaderRows + visibleScenariosCount + scenarioFooterRows;
|
||||
|
||||
return (
|
||||
<Grid
|
||||
p={4}
|
||||
pt={4}
|
||||
pb={24}
|
||||
pl={4}
|
||||
display="grid"
|
||||
gridTemplateColumns={`250px repeat(${variants.data.length}, minmax(300px, 1fr)) auto`}
|
||||
sx={{
|
||||
"> *": {
|
||||
borderColor: "gray.300",
|
||||
borderBottomWidth: 1,
|
||||
borderRightWidth: 1,
|
||||
},
|
||||
}}
|
||||
fontSize="sm"
|
||||
>
|
||||
<ScenariosHeader headerRows={headerRows} numScenarios={scenarios.data.length} />
|
||||
|
||||
{variants.data.map((variant) => (
|
||||
<VariantHeader key={variant.uiId} variant={variant} canHide={variants.data.length > 1} />
|
||||
))}
|
||||
<GridItem
|
||||
rowSpan={scenarios.data.length + headerRows}
|
||||
padding={0}
|
||||
// Have to use `style` instead of emotion style props to work around css specificity issues conflicting with the "> *" selector on Grid
|
||||
style={{ borderRightWidth: 0, borderBottomWidth: 0 }}
|
||||
h={8}
|
||||
sx={stickyHeaderStyle}
|
||||
>
|
||||
<NewVariantButton />
|
||||
<GridItem rowSpan={variantHeaderRows}>
|
||||
<AddVariantButton />
|
||||
</GridItem>
|
||||
|
||||
{variants.data.map((variant) => (
|
||||
<GridItem key={variant.uiId}>
|
||||
<VariantEditor variant={variant} />
|
||||
</GridItem>
|
||||
))}
|
||||
{variants.data.map((variant) => (
|
||||
<GridItem key={variant.uiId}>
|
||||
<VariantStats variant={variant} />
|
||||
</GridItem>
|
||||
))}
|
||||
{scenarios.data.map((scenario) => (
|
||||
{variants.data.map((variant, i) => {
|
||||
const sharedProps: GridItemProps = {
|
||||
...borders,
|
||||
colStart: i + 2,
|
||||
borderLeftWidth: i === 0 ? 1 : 0,
|
||||
marginLeft: i === 0 ? "-1px" : 0,
|
||||
};
|
||||
return (
|
||||
<Fragment key={variant.uiId}>
|
||||
<VariantHeader
|
||||
variant={variant}
|
||||
canHide={variants.data.length > 1}
|
||||
rowStart={1}
|
||||
{...sharedProps}
|
||||
/>
|
||||
<GridItem rowStart={2} {...sharedProps}>
|
||||
<VariantEditor variant={variant} />
|
||||
</GridItem>
|
||||
<GridItem rowStart={3} {...sharedProps}>
|
||||
<VariantStats variant={variant} />
|
||||
</GridItem>
|
||||
</Fragment>
|
||||
);
|
||||
})}
|
||||
|
||||
<GridItem
|
||||
colSpan={allCols - 1}
|
||||
rowStart={variantHeaderRows + 1}
|
||||
colStart={1}
|
||||
{...borders}
|
||||
borderRightWidth={0}
|
||||
>
|
||||
<ScenariosHeader />
|
||||
</GridItem>
|
||||
|
||||
{scenarios.data.scenarios.map((scenario, i) => (
|
||||
<ScenarioRow
|
||||
rowStart={i + variantHeaderRows + scenarioHeaderRows + 2}
|
||||
key={scenario.uiId}
|
||||
scenario={scenario}
|
||||
variants={variants.data}
|
||||
canHide={scenarios.data.length > 1}
|
||||
canHide={visibleScenariosCount > 1}
|
||||
/>
|
||||
))}
|
||||
<GridItem borderBottomWidth={0} borderRightWidth={0} w="100%" colSpan={allCols} padding={0}>
|
||||
<NewScenarioButton />
|
||||
<GridItem
|
||||
rowStart={variantHeaderRows + scenarioHeaderRows + visibleScenariosCount + 2}
|
||||
colStart={1}
|
||||
colSpan={allCols}
|
||||
>
|
||||
<ScenarioPaginator />
|
||||
</GridItem>
|
||||
|
||||
{/* Add some extra padding on the right, because when the table is too wide to fit in the viewport `pr` on the Grid isn't respected. */}
|
||||
<GridItem rowStart={1} colStart={allCols} rowSpan={allRows} w={4} borderBottomWidth={0} />
|
||||
</Grid>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,8 +1,13 @@
|
||||
import { type SystemStyleObject } from "@chakra-ui/react";
|
||||
import { type GridItemProps, type SystemStyleObject } from "@chakra-ui/react";
|
||||
|
||||
export const stickyHeaderStyle: SystemStyleObject = {
|
||||
position: "sticky",
|
||||
top: "0",
|
||||
backgroundColor: "#fff",
|
||||
zIndex: 1,
|
||||
zIndex: 10,
|
||||
};
|
||||
|
||||
export const borders: GridItemProps = {
|
||||
borderRightWidth: 1,
|
||||
borderBottomWidth: 1,
|
||||
};
|
||||
|
||||
@@ -2,4 +2,4 @@ import { type RouterOutputs } from "~/utils/api";
|
||||
|
||||
export type PromptVariant = NonNullable<RouterOutputs["promptVariants"]["list"]>[0];
|
||||
|
||||
export type Scenario = NonNullable<RouterOutputs["scenarios"]["list"]>[0];
|
||||
export type Scenario = NonNullable<RouterOutputs["scenarios"]["list"]>["scenarios"][0];
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { HStack, VStack, useBreakpointValue } from "@chakra-ui/react";
|
||||
import { type StackProps, VStack, useBreakpointValue } from "@chakra-ui/react";
|
||||
import React from "react";
|
||||
import DiffViewer, { DiffMethod } from "react-diff-viewer";
|
||||
import Prism from "prismjs";
|
||||
@@ -19,10 +19,15 @@ const highlightSyntax = (str: string) => {
|
||||
const CompareFunctions = ({
|
||||
originalFunction,
|
||||
newFunction = "",
|
||||
leftTitle = "Original",
|
||||
rightTitle = "Modified",
|
||||
...props
|
||||
}: {
|
||||
originalFunction: string;
|
||||
newFunction?: string;
|
||||
}) => {
|
||||
leftTitle?: string;
|
||||
rightTitle?: string;
|
||||
} & StackProps) => {
|
||||
const showSplitView = useBreakpointValue(
|
||||
{
|
||||
base: false,
|
||||
@@ -34,22 +39,20 @@ const CompareFunctions = ({
|
||||
);
|
||||
|
||||
return (
|
||||
<HStack w="full" spacing={5}>
|
||||
<VStack w="full" spacing={4} maxH="40vh" fontSize={12} lineHeight={1} overflowY="auto">
|
||||
<DiffViewer
|
||||
oldValue={originalFunction}
|
||||
newValue={newFunction || originalFunction}
|
||||
splitView={showSplitView}
|
||||
hideLineNumbers={!showSplitView}
|
||||
leftTitle="Original"
|
||||
rightTitle={newFunction ? "Modified" : "Unmodified"}
|
||||
disableWordDiff={true}
|
||||
compareMethod={DiffMethod.CHARS}
|
||||
renderContent={highlightSyntax}
|
||||
showDiffOnly={false}
|
||||
/>
|
||||
</VStack>
|
||||
</HStack>
|
||||
<VStack w="full" spacing={4} fontSize={12} lineHeight={1} overflowY="auto" {...props}>
|
||||
<DiffViewer
|
||||
oldValue={originalFunction}
|
||||
newValue={newFunction || originalFunction}
|
||||
splitView={showSplitView}
|
||||
hideLineNumbers={!showSplitView}
|
||||
leftTitle={leftTitle}
|
||||
rightTitle={rightTitle}
|
||||
disableWordDiff={true}
|
||||
compareMethod={DiffMethod.CHARS}
|
||||
renderContent={highlightSyntax}
|
||||
showDiffOnly={false}
|
||||
/>
|
||||
</VStack>
|
||||
);
|
||||
};
|
||||
|
||||
|
||||
@@ -56,7 +56,6 @@ export const CustomInstructionsInput = ({
|
||||
minW="unset"
|
||||
size="sm"
|
||||
onClick={() => onSubmit()}
|
||||
disabled={!instructions}
|
||||
variant={instructions ? "solid" : "ghost"}
|
||||
mr={4}
|
||||
borderRadius="8"
|
||||
|
||||
@@ -1,22 +1,23 @@
|
||||
import { HStack, Icon, Heading, Text, VStack, GridItem } from "@chakra-ui/react";
|
||||
import { type IconType } from "react-icons";
|
||||
import { refineOptions, type RefineOptionLabel } from "./refineOptions";
|
||||
import { BsStars } from "react-icons/bs";
|
||||
|
||||
export const RefineOption = ({
|
||||
export const RefineAction = ({
|
||||
label,
|
||||
activeLabel,
|
||||
icon,
|
||||
desciption,
|
||||
activeLabel,
|
||||
onClick,
|
||||
loading,
|
||||
}: {
|
||||
label: RefineOptionLabel;
|
||||
activeLabel: RefineOptionLabel | undefined;
|
||||
icon: IconType;
|
||||
onClick: (label: RefineOptionLabel) => void;
|
||||
label: string;
|
||||
icon?: IconType;
|
||||
desciption: string;
|
||||
activeLabel: string | undefined;
|
||||
onClick: (label: string) => void;
|
||||
loading: boolean;
|
||||
}) => {
|
||||
const isActive = activeLabel === label;
|
||||
const desciption = refineOptions[label].description;
|
||||
|
||||
return (
|
||||
<GridItem w="80" h="44">
|
||||
@@ -44,7 +45,7 @@ export const RefineOption = ({
|
||||
opacity={loading ? 0.5 : 1}
|
||||
>
|
||||
<HStack cursor="pointer" spacing={6} fontSize="sm" fontWeight="medium" color="gray.500">
|
||||
<Icon as={icon} boxSize={12} />
|
||||
<Icon as={icon || BsStars} boxSize={12} />
|
||||
<Heading size="md" fontFamily="inconsolata, monospace">
|
||||
{label}
|
||||
</Heading>
|
||||
@@ -15,17 +15,16 @@ import {
|
||||
SimpleGrid,
|
||||
} from "@chakra-ui/react";
|
||||
import { BsStars } from "react-icons/bs";
|
||||
import { VscJson } from "react-icons/vsc";
|
||||
import { TfiThought } from "react-icons/tfi";
|
||||
import { api } from "~/utils/api";
|
||||
import { useHandledAsyncCallback } from "~/utils/hooks";
|
||||
import { useHandledAsyncCallback, useVisibleScenarioIds } from "~/utils/hooks";
|
||||
import { type PromptVariant } from "@prisma/client";
|
||||
import { useState } from "react";
|
||||
import CompareFunctions from "./CompareFunctions";
|
||||
import { CustomInstructionsInput } from "./CustomInstructionsInput";
|
||||
import { type RefineOptionLabel, refineOptions } from "./refineOptions";
|
||||
import { RefineOption } from "./RefineOption";
|
||||
import { RefineAction } from "./RefineAction";
|
||||
import { isObject, isString } from "lodash-es";
|
||||
import { type RefinementAction, type SupportedProvider } from "~/modelProviders/types";
|
||||
import frontendModelProviders from "~/modelProviders/frontendModelProviders";
|
||||
|
||||
export const RefinePromptModal = ({
|
||||
variant,
|
||||
@@ -35,26 +34,32 @@ export const RefinePromptModal = ({
|
||||
onClose: () => void;
|
||||
}) => {
|
||||
const utils = api.useContext();
|
||||
const visibleScenarios = useVisibleScenarioIds();
|
||||
|
||||
const { mutateAsync: getRefinedPromptMutateAsync, data: refinedPromptFn } =
|
||||
api.promptVariants.getRefinedPromptFn.useMutation();
|
||||
const refinementActions =
|
||||
frontendModelProviders[variant.modelProvider as SupportedProvider].refinementActions || {};
|
||||
|
||||
const { mutateAsync: getModifiedPromptMutateAsync, data: refinedPromptFn } =
|
||||
api.promptVariants.getModifiedPromptFn.useMutation();
|
||||
const [instructions, setInstructions] = useState<string>("");
|
||||
|
||||
const [activeRefineOptionLabel, setActiveRefineOptionLabel] = useState<
|
||||
RefineOptionLabel | undefined
|
||||
>(undefined);
|
||||
const [activeRefineActionLabel, setActiveRefineActionLabel] = useState<string | undefined>(
|
||||
undefined,
|
||||
);
|
||||
|
||||
const [getRefinedPromptFn, refiningInProgress] = useHandledAsyncCallback(
|
||||
async (label?: RefineOptionLabel) => {
|
||||
const [getModifiedPromptFn, modificationInProgress] = useHandledAsyncCallback(
|
||||
async (label?: string) => {
|
||||
if (!variant.experimentId) return;
|
||||
const updatedInstructions = label ? refineOptions[label].instructions : instructions;
|
||||
setActiveRefineOptionLabel(label);
|
||||
await getRefinedPromptMutateAsync({
|
||||
const updatedInstructions = label
|
||||
? (refinementActions[label] as RefinementAction).instructions
|
||||
: instructions;
|
||||
setActiveRefineActionLabel(label);
|
||||
await getModifiedPromptMutateAsync({
|
||||
id: variant.id,
|
||||
instructions: updatedInstructions,
|
||||
});
|
||||
},
|
||||
[getRefinedPromptMutateAsync, onClose, variant, instructions, setActiveRefineOptionLabel],
|
||||
[getModifiedPromptMutateAsync, onClose, variant, instructions, setActiveRefineActionLabel],
|
||||
);
|
||||
|
||||
const replaceVariantMutation = api.promptVariants.replaceVariant.useMutation();
|
||||
@@ -69,13 +74,18 @@ export const RefinePromptModal = ({
|
||||
await replaceVariantMutation.mutateAsync({
|
||||
id: variant.id,
|
||||
constructFn: refinedPromptFn,
|
||||
streamScenarios: visibleScenarios,
|
||||
});
|
||||
await utils.promptVariants.list.invalidate();
|
||||
onClose();
|
||||
}, [replaceVariantMutation, variant, onClose, refinedPromptFn]);
|
||||
|
||||
return (
|
||||
<Modal isOpen onClose={onClose} size={{ base: "xl", sm: "2xl", md: "7xl" }}>
|
||||
<Modal
|
||||
isOpen
|
||||
onClose={onClose}
|
||||
size={{ base: "xl", sm: "2xl", md: "3xl", lg: "5xl", xl: "7xl" }}
|
||||
>
|
||||
<ModalOverlay />
|
||||
<ModalContent w={1200}>
|
||||
<ModalHeader>
|
||||
@@ -88,35 +98,37 @@ export const RefinePromptModal = ({
|
||||
<ModalBody maxW="unset">
|
||||
<VStack spacing={8}>
|
||||
<VStack spacing={4}>
|
||||
<SimpleGrid columns={{ base: 1, md: 2 }} spacing={8}>
|
||||
<RefineOption
|
||||
label="Convert to function call"
|
||||
activeLabel={activeRefineOptionLabel}
|
||||
icon={VscJson}
|
||||
onClick={getRefinedPromptFn}
|
||||
loading={refiningInProgress}
|
||||
/>
|
||||
<RefineOption
|
||||
label="Add chain of thought"
|
||||
activeLabel={activeRefineOptionLabel}
|
||||
icon={TfiThought}
|
||||
onClick={getRefinedPromptFn}
|
||||
loading={refiningInProgress}
|
||||
/>
|
||||
</SimpleGrid>
|
||||
<HStack>
|
||||
<Text color="gray.500">or</Text>
|
||||
</HStack>
|
||||
{Object.keys(refinementActions).length && (
|
||||
<>
|
||||
<SimpleGrid columns={{ base: 1, md: 2 }} spacing={8}>
|
||||
{Object.keys(refinementActions).map((label) => (
|
||||
<RefineAction
|
||||
key={label}
|
||||
label={label}
|
||||
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
||||
icon={refinementActions[label]!.icon}
|
||||
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
||||
desciption={refinementActions[label]!.description}
|
||||
activeLabel={activeRefineActionLabel}
|
||||
onClick={getModifiedPromptFn}
|
||||
loading={modificationInProgress}
|
||||
/>
|
||||
))}
|
||||
</SimpleGrid>
|
||||
<Text color="gray.500">or</Text>
|
||||
</>
|
||||
)}
|
||||
<CustomInstructionsInput
|
||||
instructions={instructions}
|
||||
setInstructions={setInstructions}
|
||||
loading={refiningInProgress}
|
||||
onSubmit={getRefinedPromptFn}
|
||||
loading={modificationInProgress}
|
||||
onSubmit={getModifiedPromptFn}
|
||||
/>
|
||||
</VStack>
|
||||
<CompareFunctions
|
||||
originalFunction={variant.constructFn}
|
||||
newFunction={isString(refinedPromptFn) ? refinedPromptFn : undefined}
|
||||
maxH="40vh"
|
||||
/>
|
||||
</VStack>
|
||||
</ModalBody>
|
||||
@@ -124,12 +136,10 @@ export const RefinePromptModal = ({
|
||||
<ModalFooter>
|
||||
<HStack spacing={4}>
|
||||
<Button
|
||||
colorScheme="blue"
|
||||
onClick={replaceVariant}
|
||||
minW={24}
|
||||
disabled={replacementInProgress || !refinedPromptFn}
|
||||
_disabled={{
|
||||
bgColor: "blue.500",
|
||||
}}
|
||||
isDisabled={replacementInProgress || !refinedPromptFn}
|
||||
>
|
||||
{replacementInProgress ? <Spinner boxSize={4} /> : <Text>Accept</Text>}
|
||||
</Button>
|
||||
|
||||
@@ -1,237 +0,0 @@
|
||||
// Super hacky, but we'll redo the organization when we have more models
|
||||
|
||||
export type RefineOptionLabel = "Add chain of thought" | "Convert to function call";
|
||||
|
||||
export const refineOptions: Record<
|
||||
RefineOptionLabel,
|
||||
{ description: string; instructions: string }
|
||||
> = {
|
||||
"Add chain of thought": {
|
||||
description: "Asking the model to plan its answer can increase accuracy.",
|
||||
instructions: `Adding chain of thought means asking the model to think about its answer before it gives it to you. This is useful for getting more accurate answers. Do not add an assistant message.
|
||||
|
||||
This is what a prompt looks like before adding chain of thought:
|
||||
|
||||
prompt = {
|
||||
model: "gpt-4",
|
||||
stream: true,
|
||||
messages: [
|
||||
{
|
||||
role: "system",
|
||||
content: \`Evaluate sentiment.\`,
|
||||
},
|
||||
{
|
||||
role: "user",
|
||||
content: \`This is the user's message: \${scenario.user_message}. Return "positive" or "negative" or "neutral"\`,
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
This is what one looks like after adding chain of thought:
|
||||
|
||||
prompt = {
|
||||
model: "gpt-4",
|
||||
stream: true,
|
||||
messages: [
|
||||
{
|
||||
role: "system",
|
||||
content: \`Evaluate sentiment.\`,
|
||||
},
|
||||
{
|
||||
role: "user",
|
||||
content: \`This is the user's message: \${scenario.user_message}. Return "positive" or "negative" or "neutral". Explain your answer before you give a score, then return the score on a new line.\`,
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
Here's another example:
|
||||
|
||||
Before:
|
||||
|
||||
prompt = {
|
||||
model: "gpt-3.5-turbo",
|
||||
messages: [
|
||||
{
|
||||
role: "user",
|
||||
content: \`Title: \${scenario.title}
|
||||
Body: \${scenario.body}
|
||||
|
||||
Need: \${scenario.need}
|
||||
|
||||
Rate likelihood on 1-3 scale.\`,
|
||||
},
|
||||
],
|
||||
temperature: 0,
|
||||
functions: [
|
||||
{
|
||||
name: "score_post",
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: {
|
||||
score: {
|
||||
type: "number",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
function_call: {
|
||||
name: "score_post",
|
||||
},
|
||||
};
|
||||
|
||||
After:
|
||||
|
||||
prompt = {
|
||||
model: "gpt-3.5-turbo",
|
||||
messages: [
|
||||
{
|
||||
role: "user",
|
||||
content: \`Title: \${scenario.title}
|
||||
Body: \${scenario.body}
|
||||
|
||||
Need: \${scenario.need}
|
||||
|
||||
Rate likelihood on 1-3 scale. Provide an explanation, but always provide a score afterward.\`,
|
||||
},
|
||||
],
|
||||
temperature: 0,
|
||||
functions: [
|
||||
{
|
||||
name: "score_post",
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: {
|
||||
explanation: {
|
||||
type: "string",
|
||||
}
|
||||
score: {
|
||||
type: "number",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
function_call: {
|
||||
name: "score_post",
|
||||
},
|
||||
};
|
||||
|
||||
Add chain of thought to the original prompt.`,
|
||||
},
|
||||
"Convert to function call": {
|
||||
description: "Use function calls to get output from the model in a more structured way.",
|
||||
instructions: `OpenAI functions are a specialized way for an LLM to return output.
|
||||
|
||||
This is what a prompt looks like before adding a function:
|
||||
|
||||
prompt = {
|
||||
model: "gpt-4",
|
||||
stream: true,
|
||||
messages: [
|
||||
{
|
||||
role: "system",
|
||||
content: \`Evaluate sentiment.\`,
|
||||
},
|
||||
{
|
||||
role: "user",
|
||||
content: \`This is the user's message: \${scenario.user_message}. Return "positive" or "negative" or "neutral"\`,
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
This is what one looks like after adding a function:
|
||||
|
||||
prompt = {
|
||||
model: "gpt-4",
|
||||
stream: true,
|
||||
messages: [
|
||||
{
|
||||
role: "system",
|
||||
content: "Evaluate sentiment.",
|
||||
},
|
||||
{
|
||||
role: "user",
|
||||
content: scenario.user_message,
|
||||
},
|
||||
],
|
||||
functions: [
|
||||
{
|
||||
name: "extract_sentiment",
|
||||
parameters: {
|
||||
type: "object", // parameters must always be an object with a properties key
|
||||
properties: { // properties key is required
|
||||
sentiment: {
|
||||
type: "string",
|
||||
description: "one of positive/negative/neutral",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
function_call: {
|
||||
name: "extract_sentiment",
|
||||
},
|
||||
};
|
||||
|
||||
Here's another example of adding a function:
|
||||
|
||||
Before:
|
||||
|
||||
prompt = {
|
||||
model: "gpt-3.5-turbo",
|
||||
messages: [
|
||||
{
|
||||
role: "user",
|
||||
content: \`Here is the title and body of a reddit post I am interested in:
|
||||
|
||||
title: \${scenario.title}
|
||||
body: \${scenario.body}
|
||||
|
||||
On a scale from 1 to 3, how likely is it that the person writing this post has the following need? If you are not sure, make your best guess, or answer 1.
|
||||
|
||||
Need: \${scenario.need}
|
||||
|
||||
Answer one integer between 1 and 3.\`,
|
||||
},
|
||||
],
|
||||
temperature: 0,
|
||||
};
|
||||
|
||||
After:
|
||||
|
||||
prompt = {
|
||||
model: "gpt-3.5-turbo",
|
||||
messages: [
|
||||
{
|
||||
role: "user",
|
||||
content: \`Title: \${scenario.title}
|
||||
Body: \${scenario.body}
|
||||
|
||||
Need: \${scenario.need}
|
||||
|
||||
Rate likelihood on 1-3 scale.\`,
|
||||
},
|
||||
],
|
||||
temperature: 0,
|
||||
functions: [
|
||||
{
|
||||
name: "score_post",
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: {
|
||||
score: {
|
||||
type: "number",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
function_call: {
|
||||
name: "score_post",
|
||||
},
|
||||
};
|
||||
|
||||
Add an OpenAI function that takes one or more nested parameters that match the expected output from this prompt.`,
|
||||
},
|
||||
};
|
||||
@@ -1,89 +0,0 @@
|
||||
import {
|
||||
VStack,
|
||||
Text,
|
||||
HStack,
|
||||
type StackProps,
|
||||
GridItem,
|
||||
SimpleGrid,
|
||||
Link,
|
||||
} from "@chakra-ui/react";
|
||||
import { modelStats } from "~/modelProviders/modelStats";
|
||||
import { type SupportedModel } from "~/server/types";
|
||||
|
||||
export const ModelStatsCard = ({ label, model }: { label: string; model: SupportedModel }) => {
|
||||
const stats = modelStats[model];
|
||||
return (
|
||||
<VStack w="full" align="start">
|
||||
<Text fontWeight="bold" fontSize="sm" textTransform="uppercase">
|
||||
{label}
|
||||
</Text>
|
||||
|
||||
<VStack w="full" spacing={6} bgColor="gray.100" p={4} borderRadius={4}>
|
||||
<HStack w="full" align="flex-start">
|
||||
<Text flex={1} fontSize="lg">
|
||||
<Text as="span" color="gray.600">
|
||||
{stats.provider} /{" "}
|
||||
</Text>
|
||||
<Text as="span" fontWeight="bold" color="gray.900">
|
||||
{model}
|
||||
</Text>
|
||||
</Text>
|
||||
<Link
|
||||
href={stats.learnMoreUrl}
|
||||
isExternal
|
||||
color="blue.500"
|
||||
fontWeight="bold"
|
||||
fontSize="sm"
|
||||
ml={2}
|
||||
>
|
||||
Learn More
|
||||
</Link>
|
||||
</HStack>
|
||||
<SimpleGrid
|
||||
w="full"
|
||||
justifyContent="space-between"
|
||||
alignItems="flex-start"
|
||||
fontSize="sm"
|
||||
columns={{ base: 2, md: 4 }}
|
||||
>
|
||||
<SelectedModelLabeledInfo label="Context" info={stats.contextLength} />
|
||||
<SelectedModelLabeledInfo
|
||||
label="Input"
|
||||
info={
|
||||
<Text>
|
||||
${(stats.promptTokenPrice * 1000).toFixed(3)}
|
||||
<Text color="gray.500"> / 1K tokens</Text>
|
||||
</Text>
|
||||
}
|
||||
/>
|
||||
<SelectedModelLabeledInfo
|
||||
label="Output"
|
||||
info={
|
||||
<Text>
|
||||
${(stats.promptTokenPrice * 1000).toFixed(3)}
|
||||
<Text color="gray.500"> / 1K tokens</Text>
|
||||
</Text>
|
||||
}
|
||||
/>
|
||||
<SelectedModelLabeledInfo label="Speed" info={<Text>{stats.speed}</Text>} />
|
||||
</SimpleGrid>
|
||||
</VStack>
|
||||
</VStack>
|
||||
);
|
||||
};
|
||||
|
||||
const SelectedModelLabeledInfo = ({
|
||||
label,
|
||||
info,
|
||||
...props
|
||||
}: {
|
||||
label: string;
|
||||
info: string | number | React.ReactElement;
|
||||
} & StackProps) => (
|
||||
<GridItem>
|
||||
<VStack alignItems="flex-start" {...props}>
|
||||
<Text fontWeight="bold">{label}</Text>
|
||||
<Text>{info}</Text>
|
||||
</VStack>
|
||||
</GridItem>
|
||||
);
|
||||
@@ -1,85 +0,0 @@
|
||||
import {
|
||||
Button,
|
||||
Modal,
|
||||
ModalBody,
|
||||
ModalCloseButton,
|
||||
ModalContent,
|
||||
ModalFooter,
|
||||
ModalHeader,
|
||||
ModalOverlay,
|
||||
VStack,
|
||||
Text,
|
||||
Spinner,
|
||||
HStack,
|
||||
Icon,
|
||||
} from "@chakra-ui/react";
|
||||
import { RiExchangeFundsFill } from "react-icons/ri";
|
||||
import { useState } from "react";
|
||||
import { type SupportedModel } from "~/server/types";
|
||||
import { ModelStatsCard } from "./ModelStatsCard";
|
||||
import { SelectModelSearch } from "./SelectModelSearch";
|
||||
import { api } from "~/utils/api";
|
||||
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
|
||||
|
||||
export const SelectModelModal = ({
|
||||
originalModel,
|
||||
variantId,
|
||||
onClose,
|
||||
}: {
|
||||
originalModel: SupportedModel;
|
||||
variantId: string;
|
||||
onClose: () => void;
|
||||
}) => {
|
||||
const [selectedModel, setSelectedModel] = useState<SupportedModel>(originalModel);
|
||||
const utils = api.useContext();
|
||||
|
||||
const experiment = useExperiment();
|
||||
|
||||
const createMutation = api.promptVariants.create.useMutation();
|
||||
|
||||
const [createNewVariant, creationInProgress] = useHandledAsyncCallback(async () => {
|
||||
if (!experiment?.data?.id) return;
|
||||
await createMutation.mutateAsync({
|
||||
experimentId: experiment?.data?.id,
|
||||
variantId,
|
||||
newModel: selectedModel,
|
||||
});
|
||||
await utils.promptVariants.list.invalidate();
|
||||
onClose();
|
||||
}, [createMutation, experiment?.data?.id, variantId, onClose]);
|
||||
|
||||
return (
|
||||
<Modal isOpen onClose={onClose} size={{ base: "xl", sm: "2xl", md: "3xl" }}>
|
||||
<ModalOverlay />
|
||||
<ModalContent w={1200}>
|
||||
<ModalHeader>
|
||||
<HStack>
|
||||
<Icon as={RiExchangeFundsFill} />
|
||||
<Text>Change Model</Text>
|
||||
</HStack>
|
||||
</ModalHeader>
|
||||
<ModalCloseButton />
|
||||
<ModalBody maxW="unset">
|
||||
<VStack spacing={8}>
|
||||
<ModelStatsCard label="Original Model" model={originalModel} />
|
||||
{originalModel !== selectedModel && (
|
||||
<ModelStatsCard label="New Model" model={selectedModel} />
|
||||
)}
|
||||
<SelectModelSearch selectedModel={selectedModel} setSelectedModel={setSelectedModel} />
|
||||
</VStack>
|
||||
</ModalBody>
|
||||
|
||||
<ModalFooter>
|
||||
<Button
|
||||
colorScheme="blue"
|
||||
onClick={createNewVariant}
|
||||
minW={24}
|
||||
disabled={originalModel === selectedModel}
|
||||
>
|
||||
{creationInProgress ? <Spinner boxSize={4} /> : <Text>Continue</Text>}
|
||||
</Button>
|
||||
</ModalFooter>
|
||||
</ModalContent>
|
||||
</Modal>
|
||||
);
|
||||
};
|
||||
@@ -1,47 +0,0 @@
|
||||
import { VStack, Text } from "@chakra-ui/react";
|
||||
import { type LegacyRef, useCallback } from "react";
|
||||
import Select, { type SingleValue } from "react-select";
|
||||
import { type SupportedModel } from "~/server/types";
|
||||
import { useElementDimensions } from "~/utils/hooks";
|
||||
|
||||
const modelOptions: { value: SupportedModel; label: string }[] = [
|
||||
{ value: "gpt-3.5-turbo", label: "gpt-3.5-turbo" },
|
||||
{ value: "gpt-3.5-turbo-0613", label: "gpt-3.5-turbo-0613" },
|
||||
{ value: "gpt-3.5-turbo-16k", label: "gpt-3.5-turbo-16k" },
|
||||
{ value: "gpt-3.5-turbo-16k-0613", label: "gpt-3.5-turbo-16k-0613" },
|
||||
{ value: "gpt-4", label: "gpt-4" },
|
||||
{ value: "gpt-4-0613", label: "gpt-4-0613" },
|
||||
{ value: "gpt-4-32k", label: "gpt-4-32k" },
|
||||
{ value: "gpt-4-32k-0613", label: "gpt-4-32k-0613" },
|
||||
];
|
||||
|
||||
export const SelectModelSearch = ({
|
||||
selectedModel,
|
||||
setSelectedModel,
|
||||
}: {
|
||||
selectedModel: SupportedModel;
|
||||
setSelectedModel: (model: SupportedModel) => void;
|
||||
}) => {
|
||||
const handleSelection = useCallback(
|
||||
(option: SingleValue<{ value: SupportedModel; label: string }>) => {
|
||||
if (!option) return;
|
||||
setSelectedModel(option.value);
|
||||
},
|
||||
[setSelectedModel],
|
||||
);
|
||||
const selectedOption = modelOptions.find((option) => option.value === selectedModel);
|
||||
|
||||
const [containerRef, containerDimensions] = useElementDimensions();
|
||||
|
||||
return (
|
||||
<VStack ref={containerRef as LegacyRef<HTMLDivElement>} w="full">
|
||||
<Text>Browse Models</Text>
|
||||
<Select
|
||||
styles={{ control: (provided) => ({ ...provided, width: containerDimensions?.width }) }}
|
||||
value={selectedOption}
|
||||
options={modelOptions}
|
||||
onChange={handleSelection}
|
||||
/>
|
||||
</VStack>
|
||||
);
|
||||
};
|
||||
@@ -3,28 +3,34 @@ import { type PromptVariant } from "../OutputsTable/types";
|
||||
import { api } from "~/utils/api";
|
||||
import { RiDraggable } from "react-icons/ri";
|
||||
import { useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
|
||||
import { HStack, Icon, Text, GridItem } from "@chakra-ui/react"; // Changed here
|
||||
import { HStack, Icon, Text, GridItem, type GridItemProps } from "@chakra-ui/react"; // Changed here
|
||||
import { cellPadding, headerMinHeight } from "../constants";
|
||||
import AutoResizeTextArea from "../AutoResizeTextArea";
|
||||
import { stickyHeaderStyle } from "../OutputsTable/styles";
|
||||
import VariantHeaderMenuButton from "./VariantHeaderMenuButton";
|
||||
|
||||
export default function VariantHeader(props: { variant: PromptVariant; canHide: boolean }) {
|
||||
export default function VariantHeader(
|
||||
allProps: {
|
||||
variant: PromptVariant;
|
||||
canHide: boolean;
|
||||
} & GridItemProps,
|
||||
) {
|
||||
const { variant, canHide, ...gridItemProps } = allProps;
|
||||
const { canModify } = useExperimentAccess();
|
||||
const utils = api.useContext();
|
||||
const [isDragTarget, setIsDragTarget] = useState(false);
|
||||
const [isInputHovered, setIsInputHovered] = useState(false);
|
||||
const [label, setLabel] = useState(props.variant.label);
|
||||
const [label, setLabel] = useState(variant.label);
|
||||
|
||||
const updateMutation = api.promptVariants.update.useMutation();
|
||||
const [onSaveLabel] = useHandledAsyncCallback(async () => {
|
||||
if (label && label !== props.variant.label) {
|
||||
if (label && label !== variant.label) {
|
||||
await updateMutation.mutateAsync({
|
||||
id: props.variant.id,
|
||||
id: variant.id,
|
||||
updates: { label: label },
|
||||
});
|
||||
}
|
||||
}, [updateMutation, props.variant.id, props.variant.label, label]);
|
||||
}, [updateMutation, variant.id, variant.label, label]);
|
||||
|
||||
const reorderMutation = api.promptVariants.reorder.useMutation();
|
||||
const [onReorder] = useHandledAsyncCallback(
|
||||
@@ -32,7 +38,7 @@ export default function VariantHeader(props: { variant: PromptVariant; canHide:
|
||||
e.preventDefault();
|
||||
setIsDragTarget(false);
|
||||
const draggedId = e.dataTransfer.getData("text/plain");
|
||||
const droppedId = props.variant.id;
|
||||
const droppedId = variant.id;
|
||||
if (!draggedId || !droppedId || draggedId === droppedId) return;
|
||||
await reorderMutation.mutateAsync({
|
||||
draggedId,
|
||||
@@ -40,16 +46,16 @@ export default function VariantHeader(props: { variant: PromptVariant; canHide:
|
||||
});
|
||||
await utils.promptVariants.list.invalidate();
|
||||
},
|
||||
[reorderMutation, props.variant.id],
|
||||
[reorderMutation, variant.id],
|
||||
);
|
||||
|
||||
const [menuOpen, setMenuOpen] = useState(false);
|
||||
|
||||
if (!canModify) {
|
||||
return (
|
||||
<GridItem padding={0} sx={stickyHeaderStyle} borderTopWidth={1}>
|
||||
<GridItem padding={0} sx={stickyHeaderStyle} borderTopWidth={1} {...gridItemProps}>
|
||||
<Text fontSize={16} fontWeight="bold" px={cellPadding.x} py={cellPadding.y}>
|
||||
{props.variant.label}
|
||||
{variant.label}
|
||||
</Text>
|
||||
</GridItem>
|
||||
);
|
||||
@@ -64,6 +70,7 @@ export default function VariantHeader(props: { variant: PromptVariant; canHide:
|
||||
zIndex: menuOpen ? "dropdown" : stickyHeaderStyle.zIndex,
|
||||
}}
|
||||
borderTopWidth={1}
|
||||
{...gridItemProps}
|
||||
>
|
||||
<HStack
|
||||
spacing={4}
|
||||
@@ -71,7 +78,7 @@ export default function VariantHeader(props: { variant: PromptVariant; canHide:
|
||||
minH={headerMinHeight}
|
||||
draggable={!isInputHovered}
|
||||
onDragStart={(e) => {
|
||||
e.dataTransfer.setData("text/plain", props.variant.id);
|
||||
e.dataTransfer.setData("text/plain", variant.id);
|
||||
e.currentTarget.style.opacity = "0.4";
|
||||
}}
|
||||
onDragEnd={(e) => {
|
||||
@@ -112,8 +119,8 @@ export default function VariantHeader(props: { variant: PromptVariant; canHide:
|
||||
onMouseLeave={() => setIsInputHovered(false)}
|
||||
/>
|
||||
<VariantHeaderMenuButton
|
||||
variant={props.variant}
|
||||
canHide={props.canHide}
|
||||
variant={variant}
|
||||
canHide={canHide}
|
||||
menuOpen={menuOpen}
|
||||
setMenuOpen={setMenuOpen}
|
||||
/>
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import { type PromptVariant } from "../OutputsTable/types";
|
||||
import { api } from "~/utils/api";
|
||||
import { useHandledAsyncCallback } from "~/utils/hooks";
|
||||
import { useHandledAsyncCallback, useVisibleScenarioIds } from "~/utils/hooks";
|
||||
import {
|
||||
Button,
|
||||
Icon,
|
||||
Menu,
|
||||
MenuButton,
|
||||
@@ -11,14 +10,14 @@ import {
|
||||
MenuDivider,
|
||||
Text,
|
||||
Spinner,
|
||||
IconButton,
|
||||
} from "@chakra-ui/react";
|
||||
import { BsFillTrashFill, BsGear, BsStars } from "react-icons/bs";
|
||||
import { FaRegClone } from "react-icons/fa";
|
||||
import { useState } from "react";
|
||||
import { RefinePromptModal } from "../RefinePromptModal/RefinePromptModal";
|
||||
import { RiExchangeFundsFill } from "react-icons/ri";
|
||||
import { SelectModelModal } from "../SelectModelModal/SelectModelModal";
|
||||
import { type SupportedModel } from "~/server/types";
|
||||
import { ChangeModelModal } from "../ChangeModelModal/ChangeModelModal";
|
||||
|
||||
export default function VariantHeaderMenuButton({
|
||||
variant,
|
||||
@@ -34,11 +33,13 @@ export default function VariantHeaderMenuButton({
|
||||
const utils = api.useContext();
|
||||
|
||||
const duplicateMutation = api.promptVariants.create.useMutation();
|
||||
const visibleScenarios = useVisibleScenarioIds();
|
||||
|
||||
const [duplicateVariant, duplicationInProgress] = useHandledAsyncCallback(async () => {
|
||||
await duplicateMutation.mutateAsync({
|
||||
experimentId: variant.experimentId,
|
||||
variantId: variant.id,
|
||||
streamScenarios: visibleScenarios,
|
||||
});
|
||||
await utils.promptVariants.list.invalidate();
|
||||
}, [duplicateMutation, variant.experimentId, variant.id]);
|
||||
@@ -51,21 +52,18 @@ export default function VariantHeaderMenuButton({
|
||||
await utils.promptVariants.list.invalidate();
|
||||
}, [hideMutation, variant.id]);
|
||||
|
||||
const [selectModelModalOpen, setSelectModelModalOpen] = useState(false);
|
||||
const [changeModelModalOpen, setChangeModelModalOpen] = useState(false);
|
||||
const [refinePromptModalOpen, setRefinePromptModalOpen] = useState(false);
|
||||
|
||||
return (
|
||||
<>
|
||||
<Menu isOpen={menuOpen} onOpen={() => setMenuOpen(true)} onClose={() => setMenuOpen(false)}>
|
||||
{duplicationInProgress ? (
|
||||
<Spinner boxSize={4} mx={3} my={3} />
|
||||
) : (
|
||||
<MenuButton>
|
||||
<Button variant="ghost">
|
||||
<Icon as={BsGear} />
|
||||
</Button>
|
||||
</MenuButton>
|
||||
)}
|
||||
<MenuButton
|
||||
as={IconButton}
|
||||
variant="ghost"
|
||||
aria-label="Edit Scenarios"
|
||||
icon={<Icon as={duplicationInProgress ? Spinner : BsGear} />}
|
||||
/>
|
||||
|
||||
<MenuList mt={-3} fontSize="md">
|
||||
<MenuItem icon={<Icon as={FaRegClone} boxSize={4} w={5} />} onClick={duplicateVariant}>
|
||||
@@ -73,7 +71,7 @@ export default function VariantHeaderMenuButton({
|
||||
</MenuItem>
|
||||
<MenuItem
|
||||
icon={<Icon as={RiExchangeFundsFill} boxSize={5} />}
|
||||
onClick={() => setSelectModelModalOpen(true)}
|
||||
onClick={() => setChangeModelModalOpen(true)}
|
||||
>
|
||||
Change Model
|
||||
</MenuItem>
|
||||
@@ -98,12 +96,8 @@ export default function VariantHeaderMenuButton({
|
||||
)}
|
||||
</MenuList>
|
||||
</Menu>
|
||||
{selectModelModalOpen && (
|
||||
<SelectModelModal
|
||||
originalModel={variant.model as SupportedModel}
|
||||
variantId={variant.id}
|
||||
onClose={() => setSelectModelModalOpen(false)}
|
||||
/>
|
||||
{changeModelModalOpen && (
|
||||
<ChangeModelModal variant={variant} onClose={() => setChangeModelModalOpen(false)} />
|
||||
)}
|
||||
{refinePromptModalOpen && (
|
||||
<RefinePromptModal variant={variant} onClose={() => setRefinePromptModalOpen(false)} />
|
||||
|
||||
@@ -1,4 +1,13 @@
|
||||
import { HStack, Icon, VStack, Text, Divider, Spinner, AspectRatio } from "@chakra-ui/react";
|
||||
import {
|
||||
HStack,
|
||||
Icon,
|
||||
VStack,
|
||||
Text,
|
||||
Divider,
|
||||
Spinner,
|
||||
AspectRatio,
|
||||
SkeletonText,
|
||||
} from "@chakra-ui/react";
|
||||
import { RiFlaskLine } from "react-icons/ri";
|
||||
import { formatTimePast } from "~/utils/dayjs";
|
||||
import Link from "next/link";
|
||||
@@ -93,3 +102,13 @@ export const NewExperimentCard = () => {
|
||||
</AspectRatio>
|
||||
);
|
||||
};
|
||||
|
||||
export const ExperimentCardSkeleton = () => (
|
||||
<AspectRatio ratio={1.2} w="full">
|
||||
<VStack align="center" borderColor="gray.200" borderWidth={1} p={4} bg="gray.50">
|
||||
<SkeletonText noOfLines={1} w="80%" />
|
||||
<SkeletonText noOfLines={2} w="60%" />
|
||||
<SkeletonText noOfLines={1} w="80%" />
|
||||
</VStack>
|
||||
</AspectRatio>
|
||||
);
|
||||
|
||||
57
src/components/experiments/HeaderButtons/DeleteDialog.tsx
Normal file
57
src/components/experiments/HeaderButtons/DeleteDialog.tsx
Normal file
@@ -0,0 +1,57 @@
|
||||
import {
|
||||
Button,
|
||||
AlertDialog,
|
||||
AlertDialogBody,
|
||||
AlertDialogFooter,
|
||||
AlertDialogHeader,
|
||||
AlertDialogContent,
|
||||
AlertDialogOverlay,
|
||||
} from "@chakra-ui/react";
|
||||
|
||||
import { useRouter } from "next/router";
|
||||
import { useRef } from "react";
|
||||
import { api } from "~/utils/api";
|
||||
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
|
||||
|
||||
export const DeleteDialog = ({ onClose }: { onClose: () => void }) => {
|
||||
const experiment = useExperiment();
|
||||
const deleteMutation = api.experiments.delete.useMutation();
|
||||
const utils = api.useContext();
|
||||
const router = useRouter();
|
||||
|
||||
const cancelRef = useRef<HTMLButtonElement>(null);
|
||||
|
||||
const [onDeleteConfirm] = useHandledAsyncCallback(async () => {
|
||||
if (!experiment.data?.id) return;
|
||||
await deleteMutation.mutateAsync({ id: experiment.data.id });
|
||||
await utils.experiments.list.invalidate();
|
||||
await router.push({ pathname: "/experiments" });
|
||||
onClose();
|
||||
}, [deleteMutation, experiment.data?.id, router]);
|
||||
|
||||
return (
|
||||
<AlertDialog isOpen leastDestructiveRef={cancelRef} onClose={onClose}>
|
||||
<AlertDialogOverlay>
|
||||
<AlertDialogContent>
|
||||
<AlertDialogHeader fontSize="lg" fontWeight="bold">
|
||||
Delete Experiment
|
||||
</AlertDialogHeader>
|
||||
|
||||
<AlertDialogBody>
|
||||
If you delete this experiment all the associated prompts and scenarios will be deleted
|
||||
as well. Are you sure?
|
||||
</AlertDialogBody>
|
||||
|
||||
<AlertDialogFooter>
|
||||
<Button ref={cancelRef} onClick={onClose}>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button colorScheme="red" onClick={onDeleteConfirm} ml={3}>
|
||||
Delete
|
||||
</Button>
|
||||
</AlertDialogFooter>
|
||||
</AlertDialogContent>
|
||||
</AlertDialogOverlay>
|
||||
</AlertDialog>
|
||||
);
|
||||
};
|
||||
42
src/components/experiments/HeaderButtons/HeaderButtons.tsx
Normal file
42
src/components/experiments/HeaderButtons/HeaderButtons.tsx
Normal file
@@ -0,0 +1,42 @@
|
||||
import { Button, HStack, Icon, Spinner, Text } from "@chakra-ui/react";
|
||||
import { useOnForkButtonPressed } from "./useOnForkButtonPressed";
|
||||
import { useExperiment } from "~/utils/hooks";
|
||||
import { BsGearFill } from "react-icons/bs";
|
||||
import { TbGitFork } from "react-icons/tb";
|
||||
import { useAppStore } from "~/state/store";
|
||||
|
||||
export const HeaderButtons = () => {
|
||||
const experiment = useExperiment();
|
||||
|
||||
const canModify = experiment.data?.access.canModify ?? false;
|
||||
|
||||
const { onForkButtonPressed, isForking } = useOnForkButtonPressed();
|
||||
|
||||
const openDrawer = useAppStore((s) => s.openDrawer);
|
||||
|
||||
if (experiment.isLoading) return null;
|
||||
|
||||
return (
|
||||
<HStack spacing={0} mt={{ base: 2, md: 0 }}>
|
||||
<Button
|
||||
onClick={onForkButtonPressed}
|
||||
mr={4}
|
||||
colorScheme={canModify ? undefined : "orange"}
|
||||
bgColor={canModify ? undefined : "orange.400"}
|
||||
minW={0}
|
||||
variant={canModify ? "ghost" : "solid"}
|
||||
>
|
||||
{isForking ? <Spinner boxSize={5} /> : <Icon as={TbGitFork} boxSize={5} />}
|
||||
<Text ml={2}>Fork</Text>
|
||||
</Button>
|
||||
{canModify && (
|
||||
<Button variant={{ base: "solid", md: "ghost" }} onClick={openDrawer}>
|
||||
<HStack>
|
||||
<Icon as={BsGearFill} />
|
||||
<Text>Settings</Text>
|
||||
</HStack>
|
||||
</Button>
|
||||
)}
|
||||
</HStack>
|
||||
);
|
||||
};
|
||||
@@ -0,0 +1,30 @@
|
||||
import { useCallback } from "react";
|
||||
import { api } from "~/utils/api";
|
||||
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
|
||||
import { signIn, useSession } from "next-auth/react";
|
||||
import { useRouter } from "next/router";
|
||||
|
||||
export const useOnForkButtonPressed = () => {
|
||||
const router = useRouter();
|
||||
|
||||
const user = useSession().data;
|
||||
const experiment = useExperiment();
|
||||
|
||||
const forkMutation = api.experiments.fork.useMutation();
|
||||
|
||||
const [onFork, isForking] = useHandledAsyncCallback(async () => {
|
||||
if (!experiment.data?.id) return;
|
||||
const forkedExperimentId = await forkMutation.mutateAsync({ id: experiment.data.id });
|
||||
await router.push({ pathname: "/experiments/[id]", query: { id: forkedExperimentId } });
|
||||
}, [forkMutation, experiment.data?.id, router]);
|
||||
|
||||
const onForkButtonPressed = useCallback(() => {
|
||||
if (user === null) {
|
||||
signIn("github").catch(console.error);
|
||||
} else {
|
||||
onFork();
|
||||
}
|
||||
}, [onFork, user]);
|
||||
|
||||
return { onForkButtonPressed, isForking };
|
||||
};
|
||||
@@ -9,7 +9,6 @@ export const env = createEnv({
|
||||
server: {
|
||||
DATABASE_URL: z.string().url(),
|
||||
NODE_ENV: z.enum(["development", "test", "production"]).default("development"),
|
||||
OPENAI_API_KEY: z.string().min(1),
|
||||
RESTRICT_PRISMA_LOGS: z
|
||||
.string()
|
||||
.optional()
|
||||
@@ -17,7 +16,8 @@ export const env = createEnv({
|
||||
.transform((val) => val.toLowerCase() === "true"),
|
||||
GITHUB_CLIENT_ID: z.string().min(1),
|
||||
GITHUB_CLIENT_SECRET: z.string().min(1),
|
||||
REPLICATE_API_TOKEN: z.string().min(1),
|
||||
OPENAI_API_KEY: z.string().min(1),
|
||||
REPLICATE_API_TOKEN: z.string().default("placeholder"),
|
||||
},
|
||||
|
||||
/**
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
import openaiChatCompletionFrontend from "./openai-ChatCompletion/frontend";
|
||||
import replicateLlama2Frontend from "./replicate-llama2/frontend";
|
||||
import { type SupportedProvider, type FrontendModelProvider } from "./types";
|
||||
|
||||
// TODO: make sure we get a typescript error if you forget to add a provider here
|
||||
|
||||
// Keep attributes here that need to be accessible from the frontend. We can't
|
||||
// just include them in the default `modelProviders` object because it has some
|
||||
// transient dependencies that can only be imported on the server.
|
||||
const modelProvidersFrontend = {
|
||||
const frontendModelProviders: Record<SupportedProvider, FrontendModelProvider<any, any>> = {
|
||||
"openai/ChatCompletion": openaiChatCompletionFrontend,
|
||||
"replicate/llama2": replicateLlama2Frontend,
|
||||
} as const;
|
||||
};
|
||||
|
||||
export default modelProvidersFrontend;
|
||||
export default frontendModelProviders;
|
||||
@@ -1,9 +1,10 @@
|
||||
import openaiChatCompletion from "./openai-ChatCompletion";
|
||||
import replicateLlama2 from "./replicate-llama2";
|
||||
import { type SupportedProvider, type ModelProvider } from "./types";
|
||||
|
||||
const modelProviders = {
|
||||
const modelProviders: Record<SupportedProvider, ModelProvider<any, any, any>> = {
|
||||
"openai/ChatCompletion": openaiChatCompletion,
|
||||
"replicate/llama2": replicateLlama2,
|
||||
} as const;
|
||||
};
|
||||
|
||||
export default modelProviders;
|
||||
|
||||
@@ -1,77 +0,0 @@
|
||||
import { type SupportedModel } from "../server/types";
|
||||
|
||||
interface ModelStats {
|
||||
contextLength: number;
|
||||
promptTokenPrice: number;
|
||||
completionTokenPrice: number;
|
||||
speed: "fast" | "medium" | "slow";
|
||||
provider: "OpenAI";
|
||||
learnMoreUrl: string;
|
||||
}
|
||||
|
||||
export const modelStats: Record<SupportedModel, ModelStats> = {
|
||||
"gpt-4": {
|
||||
contextLength: 8192,
|
||||
promptTokenPrice: 0.00003,
|
||||
completionTokenPrice: 0.00006,
|
||||
speed: "medium",
|
||||
provider: "OpenAI",
|
||||
learnMoreUrl: "https://openai.com/gpt-4",
|
||||
},
|
||||
"gpt-4-0613": {
|
||||
contextLength: 8192,
|
||||
promptTokenPrice: 0.00003,
|
||||
completionTokenPrice: 0.00006,
|
||||
speed: "medium",
|
||||
provider: "OpenAI",
|
||||
learnMoreUrl: "https://openai.com/gpt-4",
|
||||
},
|
||||
"gpt-4-32k": {
|
||||
contextLength: 32768,
|
||||
promptTokenPrice: 0.00006,
|
||||
completionTokenPrice: 0.00012,
|
||||
speed: "medium",
|
||||
provider: "OpenAI",
|
||||
learnMoreUrl: "https://openai.com/gpt-4",
|
||||
},
|
||||
"gpt-4-32k-0613": {
|
||||
contextLength: 32768,
|
||||
promptTokenPrice: 0.00006,
|
||||
completionTokenPrice: 0.00012,
|
||||
speed: "medium",
|
||||
provider: "OpenAI",
|
||||
learnMoreUrl: "https://openai.com/gpt-4",
|
||||
},
|
||||
"gpt-3.5-turbo": {
|
||||
contextLength: 4096,
|
||||
promptTokenPrice: 0.0000015,
|
||||
completionTokenPrice: 0.000002,
|
||||
speed: "fast",
|
||||
provider: "OpenAI",
|
||||
learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
|
||||
},
|
||||
"gpt-3.5-turbo-0613": {
|
||||
contextLength: 4096,
|
||||
promptTokenPrice: 0.0000015,
|
||||
completionTokenPrice: 0.000002,
|
||||
speed: "fast",
|
||||
provider: "OpenAI",
|
||||
learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
|
||||
},
|
||||
"gpt-3.5-turbo-16k": {
|
||||
contextLength: 16384,
|
||||
promptTokenPrice: 0.000003,
|
||||
completionTokenPrice: 0.000004,
|
||||
speed: "fast",
|
||||
provider: "OpenAI",
|
||||
learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
|
||||
},
|
||||
"gpt-3.5-turbo-16k-0613": {
|
||||
contextLength: 16384,
|
||||
promptTokenPrice: 0.000003,
|
||||
completionTokenPrice: 0.000004,
|
||||
speed: "fast",
|
||||
provider: "OpenAI",
|
||||
learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
|
||||
},
|
||||
};
|
||||
@@ -56,6 +56,14 @@ modelProperty.type = "string";
|
||||
modelProperty.enum = modelProperty.oneOf[1].enum;
|
||||
delete modelProperty["oneOf"];
|
||||
|
||||
// The default of "inf" confuses the Typescript generator, so can just remove it
|
||||
assert(
|
||||
"max_tokens" in completionRequestSchema.properties &&
|
||||
isObject(completionRequestSchema.properties.max_tokens) &&
|
||||
"default" in completionRequestSchema.properties.max_tokens,
|
||||
);
|
||||
delete completionRequestSchema.properties.max_tokens["default"];
|
||||
|
||||
// Get the directory of the current script
|
||||
const currentDirectory = path.dirname(import.meta.url).replace("file://", "");
|
||||
|
||||
|
||||
@@ -150,7 +150,6 @@
|
||||
},
|
||||
"max_tokens": {
|
||||
"description": "The maximum number of [tokens](/tokenizer) to generate in the chat completion.\n\nThe total length of input tokens and generated tokens is limited by the model's context length. [Example Python code](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb) for counting tokens.\n",
|
||||
"default": "inf",
|
||||
"type": "integer"
|
||||
},
|
||||
"presence_penalty": {
|
||||
|
||||
@@ -1,8 +1,53 @@
|
||||
import { type JsonValue } from "type-fest";
|
||||
import { type OpenaiChatModelProvider } from ".";
|
||||
import { type ModelProviderFrontend } from "../types";
|
||||
import { type SupportedModel } from ".";
|
||||
import { type FrontendModelProvider } from "../types";
|
||||
import { type ChatCompletion } from "openai/resources/chat";
|
||||
import { refinementActions } from "./refinementActions";
|
||||
|
||||
const frontendModelProvider: FrontendModelProvider<SupportedModel, ChatCompletion> = {
|
||||
name: "OpenAI ChatCompletion",
|
||||
|
||||
models: {
|
||||
"gpt-4-0613": {
|
||||
name: "GPT-4",
|
||||
contextWindow: 8192,
|
||||
promptTokenPrice: 0.00003,
|
||||
completionTokenPrice: 0.00006,
|
||||
speed: "medium",
|
||||
provider: "openai/ChatCompletion",
|
||||
learnMoreUrl: "https://openai.com/gpt-4",
|
||||
},
|
||||
"gpt-4-32k-0613": {
|
||||
name: "GPT-4 32k",
|
||||
contextWindow: 32768,
|
||||
promptTokenPrice: 0.00006,
|
||||
completionTokenPrice: 0.00012,
|
||||
speed: "medium",
|
||||
provider: "openai/ChatCompletion",
|
||||
learnMoreUrl: "https://openai.com/gpt-4",
|
||||
},
|
||||
"gpt-3.5-turbo-0613": {
|
||||
name: "GPT-3.5 Turbo",
|
||||
contextWindow: 4096,
|
||||
promptTokenPrice: 0.0000015,
|
||||
completionTokenPrice: 0.000002,
|
||||
speed: "fast",
|
||||
provider: "openai/ChatCompletion",
|
||||
learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
|
||||
},
|
||||
"gpt-3.5-turbo-16k-0613": {
|
||||
name: "GPT-3.5 Turbo 16k",
|
||||
contextWindow: 16384,
|
||||
promptTokenPrice: 0.000003,
|
||||
completionTokenPrice: 0.000004,
|
||||
speed: "fast",
|
||||
provider: "openai/ChatCompletion",
|
||||
learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
|
||||
},
|
||||
},
|
||||
|
||||
refinementActions,
|
||||
|
||||
const modelProviderFrontend: ModelProviderFrontend<OpenaiChatModelProvider> = {
|
||||
normalizeOutput: (output) => {
|
||||
const message = output.choices[0]?.message;
|
||||
if (!message)
|
||||
@@ -39,4 +84,4 @@ const modelProviderFrontend: ModelProviderFrontend<OpenaiChatModelProvider> = {
|
||||
},
|
||||
};
|
||||
|
||||
export default modelProviderFrontend;
|
||||
export default frontendModelProvider;
|
||||
|
||||
@@ -8,10 +8,10 @@ import { countOpenAIChatTokens } from "~/utils/countTokens";
|
||||
import { type CompletionResponse } from "../types";
|
||||
import { omit } from "lodash-es";
|
||||
import { openai } from "~/server/utils/openai";
|
||||
import { type OpenAIChatModel } from "~/server/types";
|
||||
import { truthyFilter } from "~/utils/utils";
|
||||
import { APIError } from "openai";
|
||||
import { modelStats } from "../modelStats";
|
||||
import frontendModelProvider from "./frontend";
|
||||
import modelProvider, { type SupportedModel } from ".";
|
||||
|
||||
const mergeStreamedChunks = (
|
||||
base: ChatCompletion | null,
|
||||
@@ -60,6 +60,7 @@ export async function getCompletion(
|
||||
let finalCompletion: ChatCompletion | null = null;
|
||||
let promptTokens: number | undefined = undefined;
|
||||
let completionTokens: number | undefined = undefined;
|
||||
const modelName = modelProvider.getModel(input) as SupportedModel;
|
||||
|
||||
try {
|
||||
if (onStream) {
|
||||
@@ -81,12 +82,9 @@ export async function getCompletion(
|
||||
};
|
||||
}
|
||||
try {
|
||||
promptTokens = countOpenAIChatTokens(
|
||||
input.model as keyof typeof OpenAIChatModel,
|
||||
input.messages,
|
||||
);
|
||||
promptTokens = countOpenAIChatTokens(modelName, input.messages);
|
||||
completionTokens = countOpenAIChatTokens(
|
||||
input.model as keyof typeof OpenAIChatModel,
|
||||
modelName,
|
||||
finalCompletion.choices.map((c) => c.message).filter(truthyFilter),
|
||||
);
|
||||
} catch (err) {
|
||||
@@ -106,10 +104,10 @@ export async function getCompletion(
|
||||
}
|
||||
const timeToComplete = Date.now() - start;
|
||||
|
||||
const stats = modelStats[input.model as keyof typeof OpenAIChatModel];
|
||||
const { promptTokenPrice, completionTokenPrice } = frontendModelProvider.models[modelName];
|
||||
let cost = undefined;
|
||||
if (stats && promptTokens && completionTokens) {
|
||||
cost = promptTokens * stats.promptTokenPrice + completionTokens * stats.completionTokenPrice;
|
||||
if (promptTokenPrice && completionTokenPrice && promptTokens && completionTokens) {
|
||||
cost = promptTokens * promptTokenPrice + completionTokens * completionTokenPrice;
|
||||
}
|
||||
|
||||
return {
|
||||
|
||||
@@ -3,6 +3,7 @@ import { type ModelProvider } from "../types";
|
||||
import inputSchema from "./codegen/input.schema.json";
|
||||
import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat";
|
||||
import { getCompletion } from "./getCompletion";
|
||||
import frontendModelProvider from "./frontend";
|
||||
|
||||
const supportedModels = [
|
||||
"gpt-4-0613",
|
||||
@@ -11,7 +12,7 @@ const supportedModels = [
|
||||
"gpt-3.5-turbo-16k-0613",
|
||||
] as const;
|
||||
|
||||
type SupportedModel = (typeof supportedModels)[number];
|
||||
export type SupportedModel = (typeof supportedModels)[number];
|
||||
|
||||
export type OpenaiChatModelProvider = ModelProvider<
|
||||
SupportedModel,
|
||||
@@ -20,25 +21,6 @@ export type OpenaiChatModelProvider = ModelProvider<
|
||||
>;
|
||||
|
||||
const modelProvider: OpenaiChatModelProvider = {
|
||||
name: "OpenAI ChatCompletion",
|
||||
models: {
|
||||
"gpt-4-0613": {
|
||||
name: "GPT-4",
|
||||
learnMore: "https://openai.com/gpt-4",
|
||||
},
|
||||
"gpt-4-32k-0613": {
|
||||
name: "GPT-4 32k",
|
||||
learnMore: "https://openai.com/gpt-4",
|
||||
},
|
||||
"gpt-3.5-turbo-0613": {
|
||||
name: "GPT-3.5 Turbo",
|
||||
learnMore: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
|
||||
},
|
||||
"gpt-3.5-turbo-16k-0613": {
|
||||
name: "GPT-3.5 Turbo 16k",
|
||||
learnMore: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
|
||||
},
|
||||
},
|
||||
getModel: (input) => {
|
||||
if (supportedModels.includes(input.model as SupportedModel))
|
||||
return input.model as SupportedModel;
|
||||
@@ -55,8 +37,9 @@ const modelProvider: OpenaiChatModelProvider = {
|
||||
return null;
|
||||
},
|
||||
inputSchema: inputSchema as JSONSchema4,
|
||||
shouldStream: (input) => input.stream ?? false,
|
||||
canStream: true,
|
||||
getCompletion,
|
||||
...frontendModelProvider,
|
||||
};
|
||||
|
||||
export default modelProvider;
|
||||
|
||||
279
src/modelProviders/openai-ChatCompletion/refinementActions.ts
Normal file
279
src/modelProviders/openai-ChatCompletion/refinementActions.ts
Normal file
@@ -0,0 +1,279 @@
|
||||
import { TfiThought } from "react-icons/tfi";
|
||||
import { type RefinementAction } from "../types";
|
||||
import { VscJson } from "react-icons/vsc";
|
||||
|
||||
export const refinementActions: Record<string, RefinementAction> = {
|
||||
"Add chain of thought": {
|
||||
icon: VscJson,
|
||||
description: "Asking the model to plan its answer can increase accuracy.",
|
||||
instructions: `Adding chain of thought means asking the model to think about its answer before it gives it to you. This is useful for getting more accurate answers. Do not add an assistant message.
|
||||
|
||||
This is what a prompt looks like before adding chain of thought:
|
||||
|
||||
definePrompt("openai/ChatCompletion", {
|
||||
model: "gpt-4",
|
||||
stream: true,
|
||||
messages: [
|
||||
{
|
||||
role: "system",
|
||||
content: \`Evaluate sentiment.\`,
|
||||
},
|
||||
{
|
||||
role: "user",
|
||||
content: \`This is the user's message: \${scenario.user_message}. Return "positive" or "negative" or "neutral"\`,
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
This is what one looks like after adding chain of thought:
|
||||
|
||||
definePrompt("openai/ChatCompletion", {
|
||||
model: "gpt-4",
|
||||
stream: true,
|
||||
messages: [
|
||||
{
|
||||
role: "system",
|
||||
content: \`Evaluate sentiment.\`,
|
||||
},
|
||||
{
|
||||
role: "user",
|
||||
content: \`This is the user's message: \${scenario.user_message}. Return "positive" or "negative" or "neutral". Explain your answer before you give a score, then return the score on a new line.\`,
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
Here's another example:
|
||||
|
||||
Before:
|
||||
|
||||
definePrompt("openai/ChatCompletion", {
|
||||
model: "gpt-3.5-turbo",
|
||||
messages: [
|
||||
{
|
||||
role: "user",
|
||||
content: \`Title: \${scenario.title}
|
||||
Body: \${scenario.body}
|
||||
|
||||
Need: \${scenario.need}
|
||||
|
||||
Rate likelihood on 1-3 scale.\`,
|
||||
},
|
||||
],
|
||||
temperature: 0,
|
||||
functions: [
|
||||
{
|
||||
name: "score_post",
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: {
|
||||
score: {
|
||||
type: "number",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
function_call: {
|
||||
name: "score_post",
|
||||
},
|
||||
});
|
||||
|
||||
After:
|
||||
|
||||
definePrompt("openai/ChatCompletion", {
|
||||
model: "gpt-3.5-turbo",
|
||||
messages: [
|
||||
{
|
||||
role: "user",
|
||||
content: \`Title: \${scenario.title}
|
||||
Body: \${scenario.body}
|
||||
|
||||
Need: \${scenario.need}
|
||||
|
||||
Rate likelihood on 1-3 scale. Provide an explanation, but always provide a score afterward.\`,
|
||||
},
|
||||
],
|
||||
temperature: 0,
|
||||
functions: [
|
||||
{
|
||||
name: "score_post",
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: {
|
||||
explanation: {
|
||||
type: "string",
|
||||
}
|
||||
score: {
|
||||
type: "number",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
function_call: {
|
||||
name: "score_post",
|
||||
},
|
||||
});
|
||||
|
||||
Add chain of thought to the original prompt.`,
|
||||
},
|
||||
"Convert to function call": {
|
||||
icon: TfiThought,
|
||||
description: "Use function calls to get output from the model in a more structured way.",
|
||||
instructions: `OpenAI functions are a specialized way for an LLM to return output.
|
||||
|
||||
This is what a prompt looks like before adding a function:
|
||||
|
||||
definePrompt("openai/ChatCompletion", {
|
||||
model: "gpt-4",
|
||||
stream: true,
|
||||
messages: [
|
||||
{
|
||||
role: "system",
|
||||
content: \`Evaluate sentiment.\`,
|
||||
},
|
||||
{
|
||||
role: "user",
|
||||
content: \`This is the user's message: \${scenario.user_message}. Return "positive" or "negative" or "neutral"\`,
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
This is what one looks like after adding a function:
|
||||
|
||||
definePrompt("openai/ChatCompletion", {
|
||||
model: "gpt-4",
|
||||
stream: true,
|
||||
messages: [
|
||||
{
|
||||
role: "system",
|
||||
content: "Evaluate sentiment.",
|
||||
},
|
||||
{
|
||||
role: "user",
|
||||
content: scenario.user_message,
|
||||
},
|
||||
],
|
||||
functions: [
|
||||
{
|
||||
name: "extract_sentiment",
|
||||
parameters: {
|
||||
type: "object", // parameters must always be an object with a properties key
|
||||
properties: { // properties key is required
|
||||
sentiment: {
|
||||
type: "string",
|
||||
description: "one of positive/negative/neutral",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
function_call: {
|
||||
name: "extract_sentiment",
|
||||
},
|
||||
});
|
||||
|
||||
Here's another example of adding a function:
|
||||
|
||||
Before:
|
||||
|
||||
definePrompt("openai/ChatCompletion", {
|
||||
model: "gpt-3.5-turbo",
|
||||
messages: [
|
||||
{
|
||||
role: "user",
|
||||
content: \`Here is the title and body of a reddit post I am interested in:
|
||||
|
||||
title: \${scenario.title}
|
||||
body: \${scenario.body}
|
||||
|
||||
On a scale from 1 to 3, how likely is it that the person writing this post has the following need? If you are not sure, make your best guess, or answer 1.
|
||||
|
||||
Need: \${scenario.need}
|
||||
|
||||
Answer one integer between 1 and 3.\`,
|
||||
},
|
||||
],
|
||||
temperature: 0,
|
||||
});
|
||||
|
||||
After:
|
||||
|
||||
definePrompt("openai/ChatCompletion", {
|
||||
model: "gpt-3.5-turbo",
|
||||
messages: [
|
||||
{
|
||||
role: "user",
|
||||
content: \`Title: \${scenario.title}
|
||||
Body: \${scenario.body}
|
||||
|
||||
Need: \${scenario.need}
|
||||
|
||||
Rate likelihood on 1-3 scale.\`,
|
||||
},
|
||||
],
|
||||
temperature: 0,
|
||||
functions: [
|
||||
{
|
||||
name: "score_post",
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: {
|
||||
score: {
|
||||
type: "number",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
function_call: {
|
||||
name: "score_post",
|
||||
},
|
||||
});
|
||||
|
||||
Another example
|
||||
|
||||
Before:
|
||||
|
||||
definePrompt("openai/ChatCompletion", {
|
||||
model: "gpt-3.5-turbo",
|
||||
stream: true,
|
||||
messages: [
|
||||
{
|
||||
role: "system",
|
||||
content: \`Write 'Start experimenting!' in \${scenario.language}\`,
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
After:
|
||||
|
||||
definePrompt("openai/ChatCompletion", {
|
||||
model: "gpt-3.5-turbo",
|
||||
messages: [
|
||||
{
|
||||
role: "system",
|
||||
content: \`Write 'Start experimenting!' in \${scenario.language}\`,
|
||||
},
|
||||
],
|
||||
functions: [
|
||||
{
|
||||
name: "write_in_language",
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: {
|
||||
text: {
|
||||
type: "string",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
function_call: {
|
||||
name: "write_in_language",
|
||||
},
|
||||
});
|
||||
|
||||
Add an OpenAI function that takes one or more nested parameters that match the expected output from this prompt.`,
|
||||
},
|
||||
};
|
||||
@@ -1,7 +1,39 @@
|
||||
import { type ReplicateLlama2Provider } from ".";
|
||||
import { type ModelProviderFrontend } from "../types";
|
||||
import { type SupportedModel, type ReplicateLlama2Output } from ".";
|
||||
import { type FrontendModelProvider } from "../types";
|
||||
import { refinementActions } from "./refinementActions";
|
||||
|
||||
const frontendModelProvider: FrontendModelProvider<SupportedModel, ReplicateLlama2Output> = {
|
||||
name: "Replicate Llama2",
|
||||
|
||||
models: {
|
||||
"7b-chat": {
|
||||
name: "LLama 2 7B Chat",
|
||||
contextWindow: 4096,
|
||||
pricePerSecond: 0.0023,
|
||||
speed: "fast",
|
||||
provider: "replicate/llama2",
|
||||
learnMoreUrl: "https://replicate.com/a16z-infra/llama7b-v2-chat",
|
||||
},
|
||||
"13b-chat": {
|
||||
name: "LLama 2 13B Chat",
|
||||
contextWindow: 4096,
|
||||
pricePerSecond: 0.0023,
|
||||
speed: "medium",
|
||||
provider: "replicate/llama2",
|
||||
learnMoreUrl: "https://replicate.com/a16z-infra/llama13b-v2-chat",
|
||||
},
|
||||
"70b-chat": {
|
||||
name: "LLama 2 70B Chat",
|
||||
contextWindow: 4096,
|
||||
pricePerSecond: 0.0032,
|
||||
speed: "slow",
|
||||
provider: "replicate/llama2",
|
||||
learnMoreUrl: "https://replicate.com/replicate/llama70b-v2-chat",
|
||||
},
|
||||
},
|
||||
|
||||
refinementActions,
|
||||
|
||||
const modelProviderFrontend: ModelProviderFrontend<ReplicateLlama2Provider> = {
|
||||
normalizeOutput: (output) => {
|
||||
return {
|
||||
type: "text",
|
||||
@@ -10,4 +42,4 @@ const modelProviderFrontend: ModelProviderFrontend<ReplicateLlama2Provider> = {
|
||||
},
|
||||
};
|
||||
|
||||
export default modelProviderFrontend;
|
||||
export default frontendModelProvider;
|
||||
|
||||
@@ -8,9 +8,9 @@ const replicate = new Replicate({
|
||||
});
|
||||
|
||||
const modelIds: Record<ReplicateLlama2Input["model"], string> = {
|
||||
"7b-chat": "3725a659b5afff1a0ba9bead5fac3899d998feaad00e07032ca2b0e35eb14f8a",
|
||||
"13b-chat": "5c785d117c5bcdd1928d5a9acb1ffa6272d6cf13fcb722e90886a0196633f9d3",
|
||||
"70b-chat": "e951f18578850b652510200860fc4ea62b3b16fac280f83ff32282f87bbd2e48",
|
||||
"7b-chat": "5ec5fdadd80ace49f5a2b2178cceeb9f2f77c493b85b1131002c26e6b2b13184",
|
||||
"13b-chat": "6b4da803a2382c08868c5af10a523892f38e2de1aafb2ee55b020d9efef2fdb8",
|
||||
"70b-chat": "2d19859030ff705a87c746f7e96eea03aefb71f166725aee39692f1476566d48",
|
||||
};
|
||||
|
||||
export async function getCompletion(
|
||||
@@ -19,7 +19,7 @@ export async function getCompletion(
|
||||
): Promise<CompletionResponse<ReplicateLlama2Output>> {
|
||||
const start = Date.now();
|
||||
|
||||
const { model, stream, ...rest } = input;
|
||||
const { model, ...rest } = input;
|
||||
|
||||
try {
|
||||
const prediction = await replicate.predictions.create({
|
||||
@@ -27,8 +27,6 @@ export async function getCompletion(
|
||||
input: rest,
|
||||
});
|
||||
|
||||
console.log("stream?", onStream);
|
||||
|
||||
const interval = onStream
|
||||
? // eslint-disable-next-line @typescript-eslint/no-misused-promises
|
||||
setInterval(async () => {
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
import { type ModelProvider } from "../types";
|
||||
import frontendModelProvider from "./frontend";
|
||||
import { getCompletion } from "./getCompletion";
|
||||
|
||||
const supportedModels = ["7b-chat", "13b-chat", "70b-chat"] as const;
|
||||
|
||||
type SupportedModel = (typeof supportedModels)[number];
|
||||
export type SupportedModel = (typeof supportedModels)[number];
|
||||
|
||||
export type ReplicateLlama2Input = {
|
||||
model: SupportedModel;
|
||||
prompt: string;
|
||||
stream?: boolean;
|
||||
max_length?: number;
|
||||
temperature?: number;
|
||||
top_p?: number;
|
||||
@@ -25,12 +25,6 @@ export type ReplicateLlama2Provider = ModelProvider<
|
||||
>;
|
||||
|
||||
const modelProvider: ReplicateLlama2Provider = {
|
||||
name: "OpenAI ChatCompletion",
|
||||
models: {
|
||||
"7b-chat": {},
|
||||
"13b-chat": {},
|
||||
"70b-chat": {},
|
||||
},
|
||||
getModel: (input) => {
|
||||
if (supportedModels.includes(input.model)) return input.model;
|
||||
|
||||
@@ -43,32 +37,45 @@ const modelProvider: ReplicateLlama2Provider = {
|
||||
type: "string",
|
||||
enum: supportedModels as unknown as string[],
|
||||
},
|
||||
system_prompt: {
|
||||
type: "string",
|
||||
description:
|
||||
"System prompt to send to Llama v2. This is prepended to the prompt and helps guide system behavior.",
|
||||
},
|
||||
prompt: {
|
||||
type: "string",
|
||||
description: "Prompt to send to Llama v2.",
|
||||
},
|
||||
stream: {
|
||||
type: "boolean",
|
||||
},
|
||||
max_length: {
|
||||
max_new_tokens: {
|
||||
type: "number",
|
||||
description:
|
||||
"Maximum number of tokens to generate. A word is generally 2-3 tokens (minimum: 1)",
|
||||
},
|
||||
temperature: {
|
||||
type: "number",
|
||||
description:
|
||||
"Adjusts randomness of outputs, greater than 1 is random and 0 is deterministic, 0.75 is a good starting value. (minimum: 0.01; maximum: 5)",
|
||||
},
|
||||
top_p: {
|
||||
type: "number",
|
||||
description:
|
||||
"When decoding text, samples from the top p percentage of most likely tokens; lower to ignore less likely tokens (minimum: 0.01; maximum: 1)",
|
||||
},
|
||||
repetition_penalty: {
|
||||
type: "number",
|
||||
description:
|
||||
"Penalty for repeated words in generated text; 1 is no penalty, values greater than 1 discourage repetition, less than 1 encourage it. (minimum: 0.01; maximum: 5)",
|
||||
},
|
||||
debug: {
|
||||
type: "boolean",
|
||||
description: "provide debugging output in logs",
|
||||
},
|
||||
},
|
||||
required: ["model", "prompt"],
|
||||
},
|
||||
shouldStream: (input) => input.stream ?? false,
|
||||
canStream: true,
|
||||
getCompletion,
|
||||
...frontendModelProvider,
|
||||
};
|
||||
|
||||
export default modelProvider;
|
||||
|
||||
3
src/modelProviders/replicate-llama2/refinementActions.ts
Normal file
3
src/modelProviders/replicate-llama2/refinementActions.ts
Normal file
@@ -0,0 +1,3 @@
|
||||
import { type RefinementAction } from "../types";
|
||||
|
||||
export const refinementActions: Record<string, RefinementAction> = {};
|
||||
@@ -1,9 +1,37 @@
|
||||
import { type JSONSchema4 } from "json-schema";
|
||||
import { type IconType } from "react-icons";
|
||||
import { type JsonValue } from "type-fest";
|
||||
import { z } from "zod";
|
||||
|
||||
type ModelProviderModel = {
|
||||
name?: string;
|
||||
learnMore?: string;
|
||||
export const ZodSupportedProvider = z.union([
|
||||
z.literal("openai/ChatCompletion"),
|
||||
z.literal("replicate/llama2"),
|
||||
]);
|
||||
|
||||
export type SupportedProvider = z.infer<typeof ZodSupportedProvider>;
|
||||
|
||||
export type Model = {
|
||||
name: string;
|
||||
contextWindow: number;
|
||||
promptTokenPrice?: number;
|
||||
completionTokenPrice?: number;
|
||||
pricePerSecond?: number;
|
||||
speed: "fast" | "medium" | "slow";
|
||||
provider: SupportedProvider;
|
||||
description?: string;
|
||||
learnMoreUrl?: string;
|
||||
};
|
||||
|
||||
export type ProviderModel = { provider: z.infer<typeof ZodSupportedProvider>; model: string };
|
||||
|
||||
export type RefinementAction = { icon?: IconType; description: string; instructions: string };
|
||||
|
||||
export type FrontendModelProvider<SupportedModels extends string, OutputSchema> = {
|
||||
name: string;
|
||||
models: Record<SupportedModels, Model>;
|
||||
refinementActions?: Record<string, RefinementAction>;
|
||||
|
||||
normalizeOutput: (output: OutputSchema) => NormalizedOutput;
|
||||
};
|
||||
|
||||
export type CompletionResponse<T> =
|
||||
@@ -19,10 +47,8 @@ export type CompletionResponse<T> =
|
||||
};
|
||||
|
||||
export type ModelProvider<SupportedModels extends string, InputSchema, OutputSchema> = {
|
||||
name: string;
|
||||
models: Record<SupportedModels, ModelProviderModel>;
|
||||
getModel: (input: InputSchema) => SupportedModels | null;
|
||||
shouldStream: (input: InputSchema) => boolean;
|
||||
canStream: boolean;
|
||||
inputSchema: JSONSchema4;
|
||||
getCompletion: (
|
||||
input: InputSchema,
|
||||
@@ -31,7 +57,7 @@ export type ModelProvider<SupportedModels extends string, InputSchema, OutputSch
|
||||
|
||||
// This is just a convenience for type inference, don't use it at runtime
|
||||
_outputSchema?: OutputSchema | null;
|
||||
};
|
||||
} & FrontendModelProvider<SupportedModels, OutputSchema>;
|
||||
|
||||
export type NormalizedOutput =
|
||||
| {
|
||||
@@ -42,7 +68,3 @@ export type NormalizedOutput =
|
||||
type: "json";
|
||||
value: JsonValue;
|
||||
};
|
||||
|
||||
export type ModelProviderFrontend<ModelProviderT extends ModelProvider<any, any, any>> = {
|
||||
normalizeOutput: (output: NonNullable<ModelProviderT["_outputSchema"]>) => NormalizedOutput;
|
||||
};
|
||||
|
||||
@@ -7,6 +7,8 @@ import "~/utils/analytics";
|
||||
import Head from "next/head";
|
||||
import { ChakraThemeProvider } from "~/theme/ChakraThemeProvider";
|
||||
import { SyncAppStore } from "~/state/sync";
|
||||
import NextAdapterApp from "next-query-params/app";
|
||||
import { QueryParamProvider } from "use-query-params";
|
||||
|
||||
const MyApp: AppType<{ session: Session | null }> = ({
|
||||
Component,
|
||||
@@ -24,7 +26,9 @@ const MyApp: AppType<{ session: Session | null }> = ({
|
||||
<SyncAppStore />
|
||||
<Favicon />
|
||||
<ChakraThemeProvider>
|
||||
<Component {...pageProps} />
|
||||
<QueryParamProvider adapter={NextAdapterApp}>
|
||||
<Component {...pageProps} />
|
||||
</QueryParamProvider>
|
||||
</ChakraThemeProvider>
|
||||
</SessionProvider>
|
||||
</>
|
||||
|
||||
@@ -2,106 +2,37 @@ import {
|
||||
Box,
|
||||
Breadcrumb,
|
||||
BreadcrumbItem,
|
||||
Button,
|
||||
Center,
|
||||
Flex,
|
||||
Icon,
|
||||
Input,
|
||||
AlertDialog,
|
||||
AlertDialogBody,
|
||||
AlertDialogFooter,
|
||||
AlertDialogHeader,
|
||||
AlertDialogContent,
|
||||
AlertDialogOverlay,
|
||||
useDisclosure,
|
||||
Text,
|
||||
HStack,
|
||||
VStack,
|
||||
} from "@chakra-ui/react";
|
||||
import Link from "next/link";
|
||||
|
||||
import { useRouter } from "next/router";
|
||||
import { useState, useEffect, useRef } from "react";
|
||||
import { BsGearFill, BsTrash } from "react-icons/bs";
|
||||
import { useState, useEffect } from "react";
|
||||
import { RiFlaskLine } from "react-icons/ri";
|
||||
import OutputsTable from "~/components/OutputsTable";
|
||||
import SettingsDrawer from "~/components/OutputsTable/SettingsDrawer";
|
||||
import ExperimentSettingsDrawer from "~/components/ExperimentSettingsDrawer/ExperimentSettingsDrawer";
|
||||
import AppShell from "~/components/nav/AppShell";
|
||||
import { api } from "~/utils/api";
|
||||
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
|
||||
import { useAppStore } from "~/state/store";
|
||||
import { useSyncVariantEditor } from "~/state/sync";
|
||||
|
||||
const DeleteButton = () => {
|
||||
const experiment = useExperiment();
|
||||
const mutation = api.experiments.delete.useMutation();
|
||||
const utils = api.useContext();
|
||||
const router = useRouter();
|
||||
|
||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||
const cancelRef = useRef<HTMLButtonElement>(null);
|
||||
|
||||
const [onDeleteConfirm] = useHandledAsyncCallback(async () => {
|
||||
if (!experiment.data?.id) return;
|
||||
await mutation.mutateAsync({ id: experiment.data.id });
|
||||
await utils.experiments.list.invalidate();
|
||||
await router.push({ pathname: "/experiments" });
|
||||
onClose();
|
||||
}, [mutation, experiment.data?.id, router]);
|
||||
|
||||
useEffect(() => {
|
||||
useAppStore.getState().sharedVariantEditor.loadMonaco().catch(console.error);
|
||||
});
|
||||
|
||||
return (
|
||||
<>
|
||||
<Button
|
||||
size="sm"
|
||||
variant={{ base: "outline", lg: "ghost" }}
|
||||
colorScheme="gray"
|
||||
fontWeight="normal"
|
||||
onClick={onOpen}
|
||||
>
|
||||
<Icon as={BsTrash} boxSize={4} color="gray.600" />
|
||||
<Text display={{ base: "none", lg: "block" }} ml={2}>
|
||||
Delete Experiment
|
||||
</Text>
|
||||
</Button>
|
||||
|
||||
<AlertDialog isOpen={isOpen} leastDestructiveRef={cancelRef} onClose={onClose}>
|
||||
<AlertDialogOverlay>
|
||||
<AlertDialogContent>
|
||||
<AlertDialogHeader fontSize="lg" fontWeight="bold">
|
||||
Delete Experiment
|
||||
</AlertDialogHeader>
|
||||
|
||||
<AlertDialogBody>
|
||||
If you delete this experiment all the associated prompts and scenarios will be deleted
|
||||
as well. Are you sure?
|
||||
</AlertDialogBody>
|
||||
|
||||
<AlertDialogFooter>
|
||||
<Button ref={cancelRef} onClick={onClose}>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button colorScheme="red" onClick={onDeleteConfirm} ml={3}>
|
||||
Delete
|
||||
</Button>
|
||||
</AlertDialogFooter>
|
||||
</AlertDialogContent>
|
||||
</AlertDialogOverlay>
|
||||
</AlertDialog>
|
||||
</>
|
||||
);
|
||||
};
|
||||
import { HeaderButtons } from "~/components/experiments/HeaderButtons/HeaderButtons";
|
||||
|
||||
export default function Experiment() {
|
||||
const router = useRouter();
|
||||
const experiment = useExperiment();
|
||||
const utils = api.useContext();
|
||||
const openDrawer = useAppStore((s) => s.openDrawer);
|
||||
useSyncVariantEditor();
|
||||
|
||||
useEffect(() => {
|
||||
useAppStore.getState().sharedVariantEditor.loadMonaco().catch(console.error);
|
||||
});
|
||||
|
||||
const [label, setLabel] = useState(experiment.data?.label || "");
|
||||
useEffect(() => {
|
||||
setLabel(experiment.data?.label || "");
|
||||
@@ -138,7 +69,7 @@ export default function Experiment() {
|
||||
py={2}
|
||||
w="full"
|
||||
direction={{ base: "column", sm: "row" }}
|
||||
alignItems="flex-start"
|
||||
alignItems={{ base: "flex-start", sm: "center" }}
|
||||
>
|
||||
<Breadcrumb flex={1}>
|
||||
<BreadcrumbItem>
|
||||
@@ -171,25 +102,9 @@ export default function Experiment() {
|
||||
)}
|
||||
</BreadcrumbItem>
|
||||
</Breadcrumb>
|
||||
{canModify && (
|
||||
<HStack>
|
||||
<Button
|
||||
size="sm"
|
||||
variant={{ base: "outline", lg: "ghost" }}
|
||||
colorScheme="gray"
|
||||
fontWeight="normal"
|
||||
onClick={openDrawer}
|
||||
>
|
||||
<Icon as={BsGearFill} boxSize={4} color="gray.600" />
|
||||
<Text display={{ base: "none", lg: "block" }} ml={2}>
|
||||
Edit Vars & Evals
|
||||
</Text>
|
||||
</Button>
|
||||
<DeleteButton />
|
||||
</HStack>
|
||||
)}
|
||||
<HeaderButtons />
|
||||
</Flex>
|
||||
<SettingsDrawer />
|
||||
<ExperimentSettingsDrawer />
|
||||
<Box w="100%" overflowX="auto" flex={1}>
|
||||
<OutputsTable experimentId={router.query.id as string | undefined} />
|
||||
</Box>
|
||||
|
||||
@@ -13,29 +13,36 @@ import {
|
||||
import { RiFlaskLine } from "react-icons/ri";
|
||||
import AppShell from "~/components/nav/AppShell";
|
||||
import { api } from "~/utils/api";
|
||||
import { ExperimentCard, NewExperimentCard } from "~/components/experiments/ExperimentCard";
|
||||
import {
|
||||
ExperimentCard,
|
||||
ExperimentCardSkeleton,
|
||||
NewExperimentCard,
|
||||
} from "~/components/experiments/ExperimentCard";
|
||||
import { signIn, useSession } from "next-auth/react";
|
||||
|
||||
export default function ExperimentsPage() {
|
||||
const experiments = api.experiments.list.useQuery();
|
||||
|
||||
const user = useSession().data;
|
||||
const authLoading = useSession().status === "loading";
|
||||
|
||||
if (user === null) {
|
||||
if (user === null || authLoading) {
|
||||
return (
|
||||
<AppShell title="Experiments">
|
||||
<Center h="100%">
|
||||
<Text>
|
||||
<Link
|
||||
onClick={() => {
|
||||
signIn("github").catch(console.error);
|
||||
}}
|
||||
textDecor="underline"
|
||||
>
|
||||
Sign in
|
||||
</Link>{" "}
|
||||
to view or create new experiments!
|
||||
</Text>
|
||||
{!authLoading && (
|
||||
<Text>
|
||||
<Link
|
||||
onClick={() => {
|
||||
signIn("github").catch(console.error);
|
||||
}}
|
||||
textDecor="underline"
|
||||
>
|
||||
Sign in
|
||||
</Link>{" "}
|
||||
to view or create new experiments!
|
||||
</Text>
|
||||
)}
|
||||
</Center>
|
||||
</AppShell>
|
||||
);
|
||||
@@ -44,7 +51,7 @@ export default function ExperimentsPage() {
|
||||
return (
|
||||
<AppShell title="Experiments">
|
||||
<VStack alignItems={"flex-start"} px={4} py={2}>
|
||||
<HStack minH={8} align="center">
|
||||
<HStack minH={8} align="center" pt={2}>
|
||||
<Breadcrumb flex={1}>
|
||||
<BreadcrumbItem>
|
||||
<Flex alignItems="center">
|
||||
@@ -55,7 +62,15 @@ export default function ExperimentsPage() {
|
||||
</HStack>
|
||||
<SimpleGrid w="full" columns={{ base: 1, md: 2, lg: 3, xl: 4 }} spacing={8} p="4">
|
||||
<NewExperimentCard />
|
||||
{experiments?.data?.map((exp) => <ExperimentCard key={exp.id} exp={exp} />)}
|
||||
{experiments.data && !experiments.isLoading ? (
|
||||
experiments?.data?.map((exp) => <ExperimentCard key={exp.id} exp={exp} />)
|
||||
) : (
|
||||
<>
|
||||
<ExperimentCardSkeleton />
|
||||
<ExperimentCardSkeleton />
|
||||
<ExperimentCardSkeleton />
|
||||
</>
|
||||
)}
|
||||
</SimpleGrid>
|
||||
</VStack>
|
||||
</AppShell>
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import { z } from "zod";
|
||||
import { v4 as uuidv4 } from "uuid";
|
||||
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
|
||||
import { type Prisma } from "@prisma/client";
|
||||
import { prisma } from "~/server/db";
|
||||
import dedent from "dedent";
|
||||
import { generateNewCell } from "~/server/utils/generateNewCell";
|
||||
@@ -20,7 +22,7 @@ export const experimentsRouter = createTRPCRouter({
|
||||
const experiments = await prisma.experiment.findMany({
|
||||
where: {
|
||||
organization: {
|
||||
OrganizationUser: {
|
||||
organizationUsers: {
|
||||
some: { userId: ctx.session.user.id },
|
||||
},
|
||||
},
|
||||
@@ -77,6 +79,189 @@ export const experimentsRouter = createTRPCRouter({
|
||||
};
|
||||
}),
|
||||
|
||||
fork: protectedProcedure.input(z.object({ id: z.string() })).mutation(async ({ input, ctx }) => {
|
||||
await requireCanViewExperiment(input.id, ctx);
|
||||
|
||||
const [
|
||||
existingExp,
|
||||
existingVariants,
|
||||
existingScenarios,
|
||||
existingCells,
|
||||
evaluations,
|
||||
templateVariables,
|
||||
] = await prisma.$transaction([
|
||||
prisma.experiment.findUniqueOrThrow({
|
||||
where: {
|
||||
id: input.id,
|
||||
},
|
||||
}),
|
||||
prisma.promptVariant.findMany({
|
||||
where: {
|
||||
experimentId: input.id,
|
||||
visible: true,
|
||||
},
|
||||
}),
|
||||
prisma.testScenario.findMany({
|
||||
where: {
|
||||
experimentId: input.id,
|
||||
visible: true,
|
||||
},
|
||||
}),
|
||||
prisma.scenarioVariantCell.findMany({
|
||||
where: {
|
||||
testScenario: {
|
||||
visible: true,
|
||||
},
|
||||
promptVariant: {
|
||||
experimentId: input.id,
|
||||
visible: true,
|
||||
},
|
||||
},
|
||||
include: {
|
||||
modelOutput: {
|
||||
include: {
|
||||
outputEvaluations: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}),
|
||||
prisma.evaluation.findMany({
|
||||
where: {
|
||||
experimentId: input.id,
|
||||
},
|
||||
}),
|
||||
prisma.templateVariable.findMany({
|
||||
where: {
|
||||
experimentId: input.id,
|
||||
},
|
||||
}),
|
||||
]);
|
||||
|
||||
const newExperimentId = uuidv4();
|
||||
|
||||
const existingToNewVariantIds = new Map<string, string>();
|
||||
const variantsToCreate: Prisma.PromptVariantCreateManyInput[] = [];
|
||||
for (const variant of existingVariants) {
|
||||
const newVariantId = uuidv4();
|
||||
existingToNewVariantIds.set(variant.id, newVariantId);
|
||||
variantsToCreate.push({
|
||||
...variant,
|
||||
id: newVariantId,
|
||||
experimentId: newExperimentId,
|
||||
});
|
||||
}
|
||||
|
||||
const existingToNewScenarioIds = new Map<string, string>();
|
||||
const scenariosToCreate: Prisma.TestScenarioCreateManyInput[] = [];
|
||||
for (const scenario of existingScenarios) {
|
||||
const newScenarioId = uuidv4();
|
||||
existingToNewScenarioIds.set(scenario.id, newScenarioId);
|
||||
scenariosToCreate.push({
|
||||
...scenario,
|
||||
id: newScenarioId,
|
||||
experimentId: newExperimentId,
|
||||
variableValues: scenario.variableValues as Prisma.InputJsonValue,
|
||||
});
|
||||
}
|
||||
|
||||
const existingToNewEvaluationIds = new Map<string, string>();
|
||||
const evaluationsToCreate: Prisma.EvaluationCreateManyInput[] = [];
|
||||
for (const evaluation of evaluations) {
|
||||
const newEvaluationId = uuidv4();
|
||||
existingToNewEvaluationIds.set(evaluation.id, newEvaluationId);
|
||||
evaluationsToCreate.push({
|
||||
...evaluation,
|
||||
id: newEvaluationId,
|
||||
experimentId: newExperimentId,
|
||||
});
|
||||
}
|
||||
|
||||
const cellsToCreate: Prisma.ScenarioVariantCellCreateManyInput[] = [];
|
||||
const modelOutputsToCreate: Prisma.ModelOutputCreateManyInput[] = [];
|
||||
const outputEvaluationsToCreate: Prisma.OutputEvaluationCreateManyInput[] = [];
|
||||
for (const cell of existingCells) {
|
||||
const newCellId = uuidv4();
|
||||
const { modelOutput, ...cellData } = cell;
|
||||
cellsToCreate.push({
|
||||
...cellData,
|
||||
id: newCellId,
|
||||
promptVariantId: existingToNewVariantIds.get(cell.promptVariantId) ?? "",
|
||||
testScenarioId: existingToNewScenarioIds.get(cell.testScenarioId) ?? "",
|
||||
prompt: (cell.prompt as Prisma.InputJsonValue) ?? undefined,
|
||||
});
|
||||
if (modelOutput) {
|
||||
const newModelOutputId = uuidv4();
|
||||
const { outputEvaluations, ...modelOutputData } = modelOutput;
|
||||
modelOutputsToCreate.push({
|
||||
...modelOutputData,
|
||||
id: newModelOutputId,
|
||||
scenarioVariantCellId: newCellId,
|
||||
output: (modelOutput.output as Prisma.InputJsonValue) ?? undefined,
|
||||
});
|
||||
for (const evaluation of outputEvaluations) {
|
||||
outputEvaluationsToCreate.push({
|
||||
...evaluation,
|
||||
id: uuidv4(),
|
||||
modelOutputId: newModelOutputId,
|
||||
evaluationId: existingToNewEvaluationIds.get(evaluation.evaluationId) ?? "",
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const templateVariablesToCreate: Prisma.TemplateVariableCreateManyInput[] = [];
|
||||
for (const templateVariable of templateVariables) {
|
||||
templateVariablesToCreate.push({
|
||||
...templateVariable,
|
||||
id: uuidv4(),
|
||||
experimentId: newExperimentId,
|
||||
});
|
||||
}
|
||||
|
||||
const maxSortIndex =
|
||||
(
|
||||
await prisma.experiment.aggregate({
|
||||
_max: {
|
||||
sortIndex: true,
|
||||
},
|
||||
})
|
||||
)._max?.sortIndex ?? 0;
|
||||
|
||||
await prisma.$transaction([
|
||||
prisma.experiment.create({
|
||||
data: {
|
||||
id: newExperimentId,
|
||||
sortIndex: maxSortIndex + 1,
|
||||
label: `${existingExp.label} (forked)`,
|
||||
organizationId: (await userOrg(ctx.session.user.id)).id,
|
||||
},
|
||||
}),
|
||||
prisma.promptVariant.createMany({
|
||||
data: variantsToCreate,
|
||||
}),
|
||||
prisma.testScenario.createMany({
|
||||
data: scenariosToCreate,
|
||||
}),
|
||||
prisma.scenarioVariantCell.createMany({
|
||||
data: cellsToCreate,
|
||||
}),
|
||||
prisma.modelOutput.createMany({
|
||||
data: modelOutputsToCreate,
|
||||
}),
|
||||
prisma.evaluation.createMany({
|
||||
data: evaluationsToCreate,
|
||||
}),
|
||||
prisma.outputEvaluation.createMany({
|
||||
data: outputEvaluationsToCreate,
|
||||
}),
|
||||
prisma.templateVariable.createMany({
|
||||
data: templateVariablesToCreate,
|
||||
}),
|
||||
]);
|
||||
|
||||
return newExperimentId;
|
||||
}),
|
||||
|
||||
create: protectedProcedure.input(z.object({})).mutation(async ({ ctx }) => {
|
||||
// Anyone can create an experiment
|
||||
requireNothing(ctx);
|
||||
@@ -98,7 +283,7 @@ export const experimentsRouter = createTRPCRouter({
|
||||
},
|
||||
});
|
||||
|
||||
const [variant, _, scenario] = await prisma.$transaction([
|
||||
const [variant, _, scenario1, scenario2, scenario3] = await prisma.$transaction([
|
||||
prisma.promptVariant.create({
|
||||
data: {
|
||||
experimentId: exp.id,
|
||||
@@ -121,7 +306,7 @@ export const experimentsRouter = createTRPCRouter({
|
||||
messages: [
|
||||
{
|
||||
role: "system",
|
||||
content: \`"Return 'this is output for the scenario "${"$"}{scenario.text}"'\`,
|
||||
content: \`Write 'Start experimenting!' in ${"$"}{scenario.language}\`,
|
||||
},
|
||||
],
|
||||
});`,
|
||||
@@ -133,20 +318,38 @@ export const experimentsRouter = createTRPCRouter({
|
||||
prisma.templateVariable.create({
|
||||
data: {
|
||||
experimentId: exp.id,
|
||||
label: "text",
|
||||
label: "language",
|
||||
},
|
||||
}),
|
||||
prisma.testScenario.create({
|
||||
data: {
|
||||
experimentId: exp.id,
|
||||
variableValues: {
|
||||
text: "This is a test scenario.",
|
||||
language: "English",
|
||||
},
|
||||
},
|
||||
}),
|
||||
prisma.testScenario.create({
|
||||
data: {
|
||||
experimentId: exp.id,
|
||||
variableValues: {
|
||||
language: "Spanish",
|
||||
},
|
||||
},
|
||||
}),
|
||||
prisma.testScenario.create({
|
||||
data: {
|
||||
experimentId: exp.id,
|
||||
variableValues: {
|
||||
language: "German",
|
||||
},
|
||||
},
|
||||
}),
|
||||
]);
|
||||
|
||||
await generateNewCell(variant.id, scenario.id);
|
||||
await generateNewCell(variant.id, scenario1.id);
|
||||
await generateNewCell(variant.id, scenario2.id);
|
||||
await generateNewCell(variant.id, scenario3.id);
|
||||
|
||||
return exp;
|
||||
}),
|
||||
|
||||
@@ -2,7 +2,6 @@ import { z } from "zod";
|
||||
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
|
||||
import { prisma } from "~/server/db";
|
||||
import { generateNewCell } from "~/server/utils/generateNewCell";
|
||||
import { type SupportedModel } from "~/server/types";
|
||||
import userError from "~/server/utils/error";
|
||||
import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
|
||||
import { reorderPromptVariants } from "~/server/utils/reorderPromptVariants";
|
||||
@@ -10,6 +9,8 @@ import { type PromptVariant } from "@prisma/client";
|
||||
import { deriveNewConstructFn } from "~/server/utils/deriveNewContructFn";
|
||||
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
|
||||
import parseConstructFn from "~/server/utils/parseConstructFn";
|
||||
import modelProviders from "~/modelProviders/modelProviders";
|
||||
import { ZodSupportedProvider } from "~/modelProviders/types";
|
||||
|
||||
export const promptVariantsRouter = createTRPCRouter({
|
||||
list: publicProcedure
|
||||
@@ -144,7 +145,7 @@ export const promptVariantsRouter = createTRPCRouter({
|
||||
z.object({
|
||||
experimentId: z.string(),
|
||||
variantId: z.string().optional(),
|
||||
newModel: z.string().optional(),
|
||||
streamScenarios: z.array(z.string()),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
@@ -186,10 +187,7 @@ export const promptVariantsRouter = createTRPCRouter({
|
||||
? `${originalVariant?.label} Copy`
|
||||
: `Prompt Variant ${largestSortIndex + 2}`;
|
||||
|
||||
const newConstructFn = await deriveNewConstructFn(
|
||||
originalVariant,
|
||||
input.newModel as SupportedModel,
|
||||
);
|
||||
const newConstructFn = await deriveNewConstructFn(originalVariant);
|
||||
|
||||
const createNewVariantAction = prisma.promptVariant.create({
|
||||
data: {
|
||||
@@ -221,7 +219,9 @@ export const promptVariantsRouter = createTRPCRouter({
|
||||
});
|
||||
|
||||
for (const scenario of scenarios) {
|
||||
await generateNewCell(newVariant.id, scenario.id);
|
||||
await generateNewCell(newVariant.id, scenario.id, {
|
||||
stream: input.streamScenarios.includes(scenario.id),
|
||||
});
|
||||
}
|
||||
|
||||
return newVariant;
|
||||
@@ -284,11 +284,17 @@ export const promptVariantsRouter = createTRPCRouter({
|
||||
return updatedPromptVariant;
|
||||
}),
|
||||
|
||||
getRefinedPromptFn: protectedProcedure
|
||||
getModifiedPromptFn: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
id: z.string(),
|
||||
instructions: z.string(),
|
||||
instructions: z.string().optional(),
|
||||
newModel: z
|
||||
.object({
|
||||
provider: ZodSupportedProvider,
|
||||
model: z.string(),
|
||||
})
|
||||
.optional(),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
@@ -305,11 +311,11 @@ export const promptVariantsRouter = createTRPCRouter({
|
||||
return userError(constructedPrompt.error);
|
||||
}
|
||||
|
||||
const promptConstructionFn = await deriveNewConstructFn(
|
||||
existing,
|
||||
constructedPrompt.model as SupportedModel,
|
||||
input.instructions,
|
||||
);
|
||||
const model = input.newModel
|
||||
? modelProviders[input.newModel.provider].models[input.newModel.model]
|
||||
: undefined;
|
||||
|
||||
const promptConstructionFn = await deriveNewConstructFn(existing, model, input.instructions);
|
||||
|
||||
// TODO: Validate promptConstructionFn
|
||||
// TODO: Record in some sort of history
|
||||
@@ -322,6 +328,7 @@ export const promptVariantsRouter = createTRPCRouter({
|
||||
z.object({
|
||||
id: z.string(),
|
||||
constructFn: z.string(),
|
||||
streamScenarios: z.array(z.string()),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
@@ -379,7 +386,9 @@ export const promptVariantsRouter = createTRPCRouter({
|
||||
});
|
||||
|
||||
for (const scenario of scenarios) {
|
||||
await generateNewCell(newVariant.id, scenario.id);
|
||||
await generateNewCell(newVariant.id, scenario.id, {
|
||||
stream: input.streamScenarios.includes(scenario.id),
|
||||
});
|
||||
}
|
||||
|
||||
return { status: "ok" } as const;
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import { z } from "zod";
|
||||
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
|
||||
import { prisma } from "~/server/db";
|
||||
import { queueQueryModel } from "~/server/tasks/queryModel.task";
|
||||
import { generateNewCell } from "~/server/utils/generateNewCell";
|
||||
import { queueLLMRetrievalTask } from "~/server/utils/queueLLMRetrievalTask";
|
||||
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
|
||||
|
||||
export const scenarioVariantCellsRouter = createTRPCRouter({
|
||||
@@ -29,7 +29,7 @@ export const scenarioVariantCellsRouter = createTRPCRouter({
|
||||
include: {
|
||||
modelOutput: {
|
||||
include: {
|
||||
outputEvaluation: {
|
||||
outputEvaluations: {
|
||||
include: {
|
||||
evaluation: {
|
||||
select: { label: true },
|
||||
@@ -62,14 +62,12 @@ export const scenarioVariantCellsRouter = createTRPCRouter({
|
||||
testScenarioId: input.scenarioId,
|
||||
},
|
||||
},
|
||||
include: {
|
||||
modelOutput: true,
|
||||
},
|
||||
include: { modelOutput: true },
|
||||
});
|
||||
|
||||
if (!cell) {
|
||||
await generateNewCell(input.variantId, input.scenarioId);
|
||||
return true;
|
||||
await generateNewCell(input.variantId, input.scenarioId, { stream: true });
|
||||
return;
|
||||
}
|
||||
|
||||
if (cell.modelOutput) {
|
||||
@@ -79,12 +77,6 @@ export const scenarioVariantCellsRouter = createTRPCRouter({
|
||||
});
|
||||
}
|
||||
|
||||
await prisma.scenarioVariantCell.update({
|
||||
where: { id: cell.id },
|
||||
data: { retrievalStatus: "PENDING" },
|
||||
});
|
||||
|
||||
await queueLLMRetrievalTask(cell.id);
|
||||
return true;
|
||||
await queueQueryModel(cell.id, true);
|
||||
}),
|
||||
});
|
||||
|
||||
@@ -7,21 +7,39 @@ import { runAllEvals } from "~/server/utils/evaluations";
|
||||
import { generateNewCell } from "~/server/utils/generateNewCell";
|
||||
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
|
||||
|
||||
const PAGE_SIZE = 10;
|
||||
|
||||
export const scenariosRouter = createTRPCRouter({
|
||||
list: publicProcedure
|
||||
.input(z.object({ experimentId: z.string() }))
|
||||
.input(z.object({ experimentId: z.string(), page: z.number() }))
|
||||
.query(async ({ input, ctx }) => {
|
||||
await requireCanViewExperiment(input.experimentId, ctx);
|
||||
|
||||
return await prisma.testScenario.findMany({
|
||||
const { experimentId, page } = input;
|
||||
|
||||
const scenarios = await prisma.testScenario.findMany({
|
||||
where: {
|
||||
experimentId: input.experimentId,
|
||||
experimentId,
|
||||
visible: true,
|
||||
},
|
||||
orderBy: {
|
||||
sortIndex: "asc",
|
||||
orderBy: { sortIndex: "asc" },
|
||||
skip: (page - 1) * PAGE_SIZE,
|
||||
take: PAGE_SIZE,
|
||||
});
|
||||
|
||||
const count = await prisma.testScenario.count({
|
||||
where: {
|
||||
experimentId,
|
||||
visible: true,
|
||||
},
|
||||
});
|
||||
|
||||
return {
|
||||
scenarios,
|
||||
startIndex: (page - 1) * PAGE_SIZE + 1,
|
||||
lastPage: Math.ceil(count / PAGE_SIZE),
|
||||
count,
|
||||
};
|
||||
}),
|
||||
|
||||
create: protectedProcedure
|
||||
@@ -34,22 +52,21 @@ export const scenariosRouter = createTRPCRouter({
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
await requireCanModifyExperiment(input.experimentId, ctx);
|
||||
|
||||
const maxSortIndex =
|
||||
(
|
||||
await prisma.testScenario.aggregate({
|
||||
where: {
|
||||
experimentId: input.experimentId,
|
||||
},
|
||||
_max: {
|
||||
sortIndex: true,
|
||||
},
|
||||
})
|
||||
)._max.sortIndex ?? 0;
|
||||
await prisma.testScenario.updateMany({
|
||||
where: {
|
||||
experimentId: input.experimentId,
|
||||
},
|
||||
data: {
|
||||
sortIndex: {
|
||||
increment: 1,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const createNewScenarioAction = prisma.testScenario.create({
|
||||
data: {
|
||||
experimentId: input.experimentId,
|
||||
sortIndex: maxSortIndex + 1,
|
||||
sortIndex: 0,
|
||||
variableValues: input.autogenerate
|
||||
? await autogenerateScenarioValues(input.experimentId)
|
||||
: {},
|
||||
@@ -69,7 +86,7 @@ export const scenariosRouter = createTRPCRouter({
|
||||
});
|
||||
|
||||
for (const variant of promptVariants) {
|
||||
await generateNewCell(variant.id, scenario.id);
|
||||
await generateNewCell(variant.id, scenario.id, { stream: true });
|
||||
}
|
||||
}),
|
||||
|
||||
@@ -213,7 +230,7 @@ export const scenariosRouter = createTRPCRouter({
|
||||
});
|
||||
|
||||
for (const variant of promptVariants) {
|
||||
await generateNewCell(variant.id, newScenario.id);
|
||||
await generateNewCell(variant.id, newScenario.id, { stream: true });
|
||||
}
|
||||
|
||||
return newScenario;
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
import { prisma } from "~/server/db";
|
||||
import defineTask from "./defineTask";
|
||||
import { sleep } from "../utils/sleep";
|
||||
import { generateChannel } from "~/utils/generateChannel";
|
||||
import { runEvalsForOutput } from "../utils/evaluations";
|
||||
import { type Prisma } from "@prisma/client";
|
||||
import parseConstructFn from "../utils/parseConstructFn";
|
||||
import hashPrompt from "../utils/hashPrompt";
|
||||
import { type JsonObject } from "type-fest";
|
||||
import modelProviders from "~/modelProviders/modelProviders";
|
||||
import { prisma } from "~/server/db";
|
||||
import { wsConnection } from "~/utils/wsConnection";
|
||||
import { runEvalsForOutput } from "../utils/evaluations";
|
||||
import hashPrompt from "../utils/hashPrompt";
|
||||
import parseConstructFn from "../utils/parseConstructFn";
|
||||
import { sleep } from "../utils/sleep";
|
||||
import defineTask from "./defineTask";
|
||||
|
||||
export type queryLLMJob = {
|
||||
scenarioVariantCellId: string;
|
||||
export type QueryModelJob = {
|
||||
cellId: string;
|
||||
stream: boolean;
|
||||
};
|
||||
|
||||
const MAX_AUTO_RETRIES = 10;
|
||||
@@ -24,15 +24,16 @@ function calculateDelay(numPreviousTries: number): number {
|
||||
return baseDelay + jitter;
|
||||
}
|
||||
|
||||
export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
|
||||
const { scenarioVariantCellId } = task;
|
||||
export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) => {
|
||||
console.log("RUNNING TASK", task);
|
||||
const { cellId, stream } = task;
|
||||
const cell = await prisma.scenarioVariantCell.findUnique({
|
||||
where: { id: scenarioVariantCellId },
|
||||
where: { id: cellId },
|
||||
include: { modelOutput: true },
|
||||
});
|
||||
if (!cell) {
|
||||
await prisma.scenarioVariantCell.update({
|
||||
where: { id: scenarioVariantCellId },
|
||||
where: { id: cellId },
|
||||
data: {
|
||||
statusCode: 404,
|
||||
errorMessage: "Cell not found",
|
||||
@@ -47,7 +48,7 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
|
||||
return;
|
||||
}
|
||||
await prisma.scenarioVariantCell.update({
|
||||
where: { id: scenarioVariantCellId },
|
||||
where: { id: cellId },
|
||||
data: {
|
||||
retrievalStatus: "IN_PROGRESS",
|
||||
},
|
||||
@@ -58,7 +59,7 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
|
||||
});
|
||||
if (!variant) {
|
||||
await prisma.scenarioVariantCell.update({
|
||||
where: { id: scenarioVariantCellId },
|
||||
where: { id: cellId },
|
||||
data: {
|
||||
statusCode: 404,
|
||||
errorMessage: "Prompt Variant not found",
|
||||
@@ -73,7 +74,7 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
|
||||
});
|
||||
if (!scenario) {
|
||||
await prisma.scenarioVariantCell.update({
|
||||
where: { id: scenarioVariantCellId },
|
||||
where: { id: cellId },
|
||||
data: {
|
||||
statusCode: 404,
|
||||
errorMessage: "Scenario not found",
|
||||
@@ -87,7 +88,7 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
|
||||
|
||||
if ("error" in prompt) {
|
||||
await prisma.scenarioVariantCell.update({
|
||||
where: { id: scenarioVariantCellId },
|
||||
where: { id: cellId },
|
||||
data: {
|
||||
statusCode: 400,
|
||||
errorMessage: prompt.error,
|
||||
@@ -99,34 +100,22 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
|
||||
|
||||
const provider = modelProviders[prompt.modelProvider];
|
||||
|
||||
// @ts-expect-error TODO FIX ASAP
|
||||
const streamingChannel = provider.shouldStream(prompt.modelInput) ? generateChannel() : null;
|
||||
|
||||
if (streamingChannel) {
|
||||
// Save streaming channel so that UI can connect to it
|
||||
await prisma.scenarioVariantCell.update({
|
||||
where: { id: scenarioVariantCellId },
|
||||
data: { streamingChannel },
|
||||
});
|
||||
}
|
||||
const onStream = streamingChannel
|
||||
const onStream = stream
|
||||
? (partialOutput: (typeof provider)["_outputSchema"]) => {
|
||||
wsConnection.emit("message", { channel: streamingChannel, payload: partialOutput });
|
||||
wsConnection.emit("message", { channel: cell.id, payload: partialOutput });
|
||||
}
|
||||
: null;
|
||||
|
||||
for (let i = 0; true; i++) {
|
||||
// @ts-expect-error TODO FIX ASAP
|
||||
|
||||
const response = await provider.getCompletion(prompt.modelInput, onStream);
|
||||
if (response.type === "success") {
|
||||
const inputHash = hashPrompt(prompt);
|
||||
|
||||
const modelOutput = await prisma.modelOutput.create({
|
||||
data: {
|
||||
scenarioVariantCellId,
|
||||
scenarioVariantCellId: cellId,
|
||||
inputHash,
|
||||
output: response.value as unknown as Prisma.InputJsonObject,
|
||||
output: response.value as Prisma.InputJsonObject,
|
||||
timeToComplete: response.timeToComplete,
|
||||
promptTokens: response.promptTokens,
|
||||
completionTokens: response.completionTokens,
|
||||
@@ -135,7 +124,7 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
|
||||
});
|
||||
|
||||
await prisma.scenarioVariantCell.update({
|
||||
where: { id: scenarioVariantCellId },
|
||||
where: { id: cellId },
|
||||
data: {
|
||||
statusCode: response.statusCode,
|
||||
retrievalStatus: "COMPLETE",
|
||||
@@ -149,12 +138,12 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
|
||||
const delay = calculateDelay(i);
|
||||
|
||||
await prisma.scenarioVariantCell.update({
|
||||
where: { id: scenarioVariantCellId },
|
||||
where: { id: cellId },
|
||||
data: {
|
||||
errorMessage: response.message,
|
||||
statusCode: response.statusCode,
|
||||
retryTime: shouldRetry ? new Date(Date.now() + delay) : null,
|
||||
retrievalStatus: shouldRetry ? "PENDING" : "ERROR",
|
||||
retrievalStatus: "ERROR",
|
||||
},
|
||||
});
|
||||
|
||||
@@ -166,3 +155,21 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
export const queueQueryModel = async (cellId: string, stream: boolean) => {
|
||||
console.log("queueQueryModel", cellId, stream);
|
||||
await Promise.all([
|
||||
prisma.scenarioVariantCell.update({
|
||||
where: {
|
||||
id: cellId,
|
||||
},
|
||||
data: {
|
||||
retrievalStatus: "PENDING",
|
||||
errorMessage: null,
|
||||
},
|
||||
}),
|
||||
|
||||
await queryModel.enqueue({ cellId, stream }),
|
||||
console.log("queued"),
|
||||
]);
|
||||
};
|
||||
@@ -2,39 +2,27 @@ import { type TaskList, run } from "graphile-worker";
|
||||
import "dotenv/config";
|
||||
|
||||
import { env } from "~/env.mjs";
|
||||
import { queryLLM } from "./queryLLM.task";
|
||||
import { queryModel } from "./queryModel.task";
|
||||
|
||||
const registeredTasks = [queryLLM];
|
||||
console.log("Starting worker");
|
||||
|
||||
const registeredTasks = [queryModel];
|
||||
|
||||
const taskList = registeredTasks.reduce((acc, task) => {
|
||||
acc[task.task.identifier] = task.task.handler;
|
||||
return acc;
|
||||
}, {} as TaskList);
|
||||
|
||||
async function main() {
|
||||
// Run a worker to execute jobs:
|
||||
const runner = await run({
|
||||
connectionString: env.DATABASE_URL,
|
||||
concurrency: 20,
|
||||
// Install signal handlers for graceful shutdown on SIGINT, SIGTERM, etc
|
||||
noHandleSignals: false,
|
||||
pollInterval: 1000,
|
||||
// you can set the taskList or taskDirectory but not both
|
||||
taskList,
|
||||
// or:
|
||||
// taskDirectory: `${__dirname}/tasks`,
|
||||
});
|
||||
|
||||
// Immediately await (or otherwise handled) the resulting promise, to avoid
|
||||
// "unhandled rejection" errors causing a process crash in the event of
|
||||
// something going wrong.
|
||||
await runner.promise;
|
||||
|
||||
// If the worker exits (whether through fatal error or otherwise), the above
|
||||
// promise will resolve/reject.
|
||||
}
|
||||
|
||||
main().catch((err) => {
|
||||
console.error("Unhandled error occurred running worker: ", err);
|
||||
process.exit(1);
|
||||
// Run a worker to execute jobs:
|
||||
const runner = await run({
|
||||
connectionString: env.DATABASE_URL,
|
||||
concurrency: 20,
|
||||
// Install signal handlers for graceful shutdown on SIGINT, SIGTERM, etc
|
||||
noHandleSignals: false,
|
||||
pollInterval: 1000,
|
||||
taskList,
|
||||
});
|
||||
|
||||
console.log("Worker successfully started");
|
||||
|
||||
await runner.promise;
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
export enum OpenAIChatModel {
|
||||
"gpt-4" = "gpt-4",
|
||||
"gpt-4-0613" = "gpt-4-0613",
|
||||
"gpt-4-32k" = "gpt-4-32k",
|
||||
"gpt-4-32k-0613" = "gpt-4-32k-0613",
|
||||
"gpt-3.5-turbo" = "gpt-3.5-turbo",
|
||||
"gpt-3.5-turbo-0613" = "gpt-3.5-turbo-0613",
|
||||
"gpt-3.5-turbo-16k" = "gpt-3.5-turbo-16k",
|
||||
"gpt-3.5-turbo-16k-0613" = "gpt-3.5-turbo-16k-0613",
|
||||
}
|
||||
|
||||
export type SupportedModel = keyof typeof OpenAIChatModel;
|
||||
@@ -1,18 +1,18 @@
|
||||
import { type PromptVariant } from "@prisma/client";
|
||||
import { type SupportedModel } from "../types";
|
||||
import ivm from "isolated-vm";
|
||||
import dedent from "dedent";
|
||||
import { openai } from "./openai";
|
||||
import { getApiShapeForModel } from "./getTypesForModel";
|
||||
import { isObject } from "lodash-es";
|
||||
import { type CompletionCreateParams } from "openai/resources/chat/completions";
|
||||
import formatPromptConstructor from "~/utils/formatPromptConstructor";
|
||||
import { type SupportedProvider, type Model } from "~/modelProviders/types";
|
||||
import modelProviders from "~/modelProviders/modelProviders";
|
||||
|
||||
const isolate = new ivm.Isolate({ memoryLimit: 128 });
|
||||
|
||||
export async function deriveNewConstructFn(
|
||||
originalVariant: PromptVariant | null,
|
||||
newModel?: SupportedModel,
|
||||
newModel?: Model,
|
||||
instructions?: string,
|
||||
) {
|
||||
if (originalVariant && !newModel && !instructions) {
|
||||
@@ -36,10 +36,11 @@ export async function deriveNewConstructFn(
|
||||
const NUM_RETRIES = 5;
|
||||
const requestUpdatedPromptFunction = async (
|
||||
originalVariant: PromptVariant,
|
||||
newModel?: SupportedModel,
|
||||
newModel?: Model,
|
||||
instructions?: string,
|
||||
) => {
|
||||
const originalModel = originalVariant.model as SupportedModel;
|
||||
const originalModelProvider = modelProviders[originalVariant.modelProvider as SupportedProvider];
|
||||
const originalModel = originalModelProvider.models[originalVariant.model] as Model;
|
||||
let newContructionFn = "";
|
||||
for (let i = 0; i < NUM_RETRIES; i++) {
|
||||
try {
|
||||
@@ -47,17 +48,38 @@ const requestUpdatedPromptFunction = async (
|
||||
{
|
||||
role: "system",
|
||||
content: `Your job is to update prompt constructor functions. Here is the api shape for the current model:\n---\n${JSON.stringify(
|
||||
getApiShapeForModel(originalModel),
|
||||
originalModelProvider.inputSchema,
|
||||
null,
|
||||
2,
|
||||
)}\n\nDo not add any assistant messages.`,
|
||||
},
|
||||
{
|
||||
role: "user",
|
||||
content: `This is the current prompt constructor function:\n---\n${originalVariant.constructFn}`,
|
||||
},
|
||||
];
|
||||
if (newModel) {
|
||||
messages.push({
|
||||
role: "user",
|
||||
content: `Return the prompt constructor function for ${newModel} given the following prompt constructor function for ${originalModel}:\n---\n${originalVariant.constructFn}`,
|
||||
content: `Return the prompt constructor function for ${newModel.name} given the existing prompt constructor function for ${originalModel.name}`,
|
||||
});
|
||||
if (newModel.provider !== originalModel.provider) {
|
||||
messages.push({
|
||||
role: "user",
|
||||
content: `The old provider was ${originalModel.provider}. The new provider is ${
|
||||
newModel.provider
|
||||
}. Here is the schema for the new model:\n---\n${JSON.stringify(
|
||||
modelProviders[newModel.provider].inputSchema,
|
||||
null,
|
||||
2,
|
||||
)}`,
|
||||
});
|
||||
} else {
|
||||
messages.push({
|
||||
role: "user",
|
||||
content: `The provider is the same as the old provider: ${originalModel.provider}`,
|
||||
});
|
||||
}
|
||||
}
|
||||
if (instructions) {
|
||||
messages.push({
|
||||
@@ -65,10 +87,6 @@ const requestUpdatedPromptFunction = async (
|
||||
content: instructions,
|
||||
});
|
||||
}
|
||||
messages.push({
|
||||
role: "system",
|
||||
content: "The prompt variable has already been declared, so do not declare it again.",
|
||||
});
|
||||
const completion = await openai.chat.completions.create({
|
||||
model: "gpt-4",
|
||||
messages,
|
||||
|
||||
@@ -56,7 +56,7 @@ export const runAllEvals = async (experimentId: string) => {
|
||||
testScenario: true,
|
||||
},
|
||||
},
|
||||
outputEvaluation: true,
|
||||
outputEvaluations: true,
|
||||
},
|
||||
});
|
||||
const evals = await prisma.evaluation.findMany({
|
||||
@@ -66,7 +66,7 @@ export const runAllEvals = async (experimentId: string) => {
|
||||
await Promise.all(
|
||||
outputs.map(async (output) => {
|
||||
const unrunEvals = evals.filter(
|
||||
(evaluation) => !output.outputEvaluation.find((e) => e.evaluationId === evaluation.id),
|
||||
(evaluation) => !output.outputEvaluations.find((e) => e.evaluationId === evaluation.id),
|
||||
);
|
||||
|
||||
await Promise.all(
|
||||
|
||||
@@ -1,11 +1,18 @@
|
||||
import { type Prisma } from "@prisma/client";
|
||||
import { prisma } from "../db";
|
||||
import { queueLLMRetrievalTask } from "./queueLLMRetrievalTask";
|
||||
import parseConstructFn from "./parseConstructFn";
|
||||
import { type JsonObject } from "type-fest";
|
||||
import hashPrompt from "./hashPrompt";
|
||||
import { omit } from "lodash-es";
|
||||
import { queueQueryModel } from "../tasks/queryModel.task";
|
||||
|
||||
export const generateNewCell = async (
|
||||
variantId: string,
|
||||
scenarioId: string,
|
||||
options?: { stream?: boolean },
|
||||
): Promise<void> => {
|
||||
const stream = options?.stream ?? false;
|
||||
|
||||
export const generateNewCell = async (variantId: string, scenarioId: string) => {
|
||||
const variant = await prisma.promptVariant.findUnique({
|
||||
where: {
|
||||
id: variantId,
|
||||
@@ -18,7 +25,7 @@ export const generateNewCell = async (variantId: string, scenarioId: string) =>
|
||||
},
|
||||
});
|
||||
|
||||
if (!variant || !scenario) return null;
|
||||
if (!variant || !scenario) return;
|
||||
|
||||
let cell = await prisma.scenarioVariantCell.findUnique({
|
||||
where: {
|
||||
@@ -32,7 +39,7 @@ export const generateNewCell = async (variantId: string, scenarioId: string) =>
|
||||
},
|
||||
});
|
||||
|
||||
if (cell) return cell;
|
||||
if (cell) return;
|
||||
|
||||
const parsedConstructFn = await parseConstructFn(
|
||||
variant.constructFn,
|
||||
@@ -40,7 +47,7 @@ export const generateNewCell = async (variantId: string, scenarioId: string) =>
|
||||
);
|
||||
|
||||
if ("error" in parsedConstructFn) {
|
||||
return await prisma.scenarioVariantCell.create({
|
||||
await prisma.scenarioVariantCell.create({
|
||||
data: {
|
||||
promptVariantId: variantId,
|
||||
testScenarioId: scenarioId,
|
||||
@@ -49,6 +56,7 @@ export const generateNewCell = async (variantId: string, scenarioId: string) =>
|
||||
retrievalStatus: "ERROR",
|
||||
},
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
const inputHash = hashPrompt(parsedConstructFn);
|
||||
@@ -69,29 +77,33 @@ export const generateNewCell = async (variantId: string, scenarioId: string) =>
|
||||
where: { inputHash },
|
||||
});
|
||||
|
||||
let newModelOutput;
|
||||
|
||||
if (matchingModelOutput) {
|
||||
newModelOutput = await prisma.modelOutput.create({
|
||||
const newModelOutput = await prisma.modelOutput.create({
|
||||
data: {
|
||||
...omit(matchingModelOutput, ["id"]),
|
||||
scenarioVariantCellId: cell.id,
|
||||
inputHash,
|
||||
output: matchingModelOutput.output as Prisma.InputJsonValue,
|
||||
timeToComplete: matchingModelOutput.timeToComplete,
|
||||
cost: matchingModelOutput.cost,
|
||||
promptTokens: matchingModelOutput.promptTokens,
|
||||
completionTokens: matchingModelOutput.completionTokens,
|
||||
createdAt: matchingModelOutput.createdAt,
|
||||
updatedAt: matchingModelOutput.updatedAt,
|
||||
},
|
||||
});
|
||||
await prisma.scenarioVariantCell.update({
|
||||
where: { id: cell.id },
|
||||
data: { retrievalStatus: "COMPLETE" },
|
||||
});
|
||||
} else {
|
||||
cell = await queueLLMRetrievalTask(cell.id);
|
||||
}
|
||||
|
||||
return { ...cell, modelOutput: newModelOutput };
|
||||
// Copy over all eval results as well
|
||||
await Promise.all(
|
||||
(
|
||||
await prisma.outputEvaluation.findMany({ where: { modelOutputId: matchingModelOutput.id } })
|
||||
).map(async (evaluation) => {
|
||||
await prisma.outputEvaluation.create({
|
||||
data: {
|
||||
...omit(evaluation, ["id"]),
|
||||
modelOutputId: newModelOutput.id,
|
||||
},
|
||||
});
|
||||
}),
|
||||
);
|
||||
} else {
|
||||
await queueQueryModel(cell.id, stream);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1,6 +0,0 @@
|
||||
import { type SupportedModel } from "../types";
|
||||
|
||||
export const getApiShapeForModel = (model: SupportedModel) => {
|
||||
// if (model in OpenAIChatModel) return openAIChatApiShape;
|
||||
return "";
|
||||
};
|
||||
@@ -70,7 +70,6 @@ export default async function parseConstructFn(
|
||||
// We've validated the JSON schema so this should be safe
|
||||
const input = prompt.input as Parameters<(typeof provider)["getModel"]>[0];
|
||||
|
||||
// @ts-expect-error TODO FIX ASAP
|
||||
const model = provider.getModel(input);
|
||||
if (!model) {
|
||||
return {
|
||||
@@ -80,8 +79,6 @@ export default async function parseConstructFn(
|
||||
|
||||
return {
|
||||
modelProvider: prompt.modelProvider as keyof typeof modelProviders,
|
||||
// @ts-expect-error TODO FIX ASAP
|
||||
|
||||
model,
|
||||
modelInput: input,
|
||||
};
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
import { prisma } from "../db";
|
||||
import { queryLLM } from "../tasks/queryLLM.task";
|
||||
|
||||
export const queueLLMRetrievalTask = async (cellId: string) => {
|
||||
const updatedCell = await prisma.scenarioVariantCell.update({
|
||||
where: {
|
||||
id: cellId,
|
||||
},
|
||||
data: {
|
||||
retrievalStatus: "PENDING",
|
||||
errorMessage: null,
|
||||
},
|
||||
include: {
|
||||
modelOutput: true,
|
||||
},
|
||||
});
|
||||
|
||||
// @ts-expect-error we aren't passing the helpers but that's ok
|
||||
void queryLLM.task.handler({ scenarioVariantCellId: cellId }, { logger: console });
|
||||
|
||||
return updatedCell;
|
||||
};
|
||||
@@ -8,7 +8,7 @@ export default async function userOrg(userId: string) {
|
||||
update: {},
|
||||
create: {
|
||||
personalOrgUserId: userId,
|
||||
OrganizationUser: {
|
||||
organizationUsers: {
|
||||
create: {
|
||||
userId: userId,
|
||||
role: "ADMIN",
|
||||
|
||||
@@ -8,9 +8,9 @@ export const editorBackground = "#fafafa";
|
||||
export type SharedVariantEditorSlice = {
|
||||
monaco: null | ReturnType<typeof loader.__getMonacoInstance>;
|
||||
loadMonaco: () => Promise<void>;
|
||||
scenarios: RouterOutputs["scenarios"]["list"];
|
||||
scenarios: RouterOutputs["scenarios"]["list"]["scenarios"];
|
||||
updateScenariosModel: () => void;
|
||||
setScenarios: (scenarios: RouterOutputs["scenarios"]["list"]) => void;
|
||||
setScenarios: (scenarios: RouterOutputs["scenarios"]["list"]["scenarios"]) => void;
|
||||
};
|
||||
|
||||
export const createVariantEditorSlice: SliceCreator<SharedVariantEditorSlice> = (set, get) => ({
|
||||
|
||||
@@ -1,17 +1,14 @@
|
||||
import { useEffect } from "react";
|
||||
import { api } from "~/utils/api";
|
||||
import { useExperiment } from "~/utils/hooks";
|
||||
import { useScenarios } from "~/utils/hooks";
|
||||
import { useAppStore } from "./store";
|
||||
|
||||
export function useSyncVariantEditor() {
|
||||
const experiment = useExperiment();
|
||||
const scenarios = api.scenarios.list.useQuery(
|
||||
{ experimentId: experiment.data?.id ?? "" },
|
||||
{ enabled: !!experiment.data?.id },
|
||||
);
|
||||
const scenarios = useScenarios();
|
||||
|
||||
useEffect(() => {
|
||||
if (scenarios.data) {
|
||||
useAppStore.getState().sharedVariantEditor.setScenarios(scenarios.data);
|
||||
useAppStore.getState().sharedVariantEditor.setScenarios(scenarios.data.scenarios);
|
||||
}
|
||||
}, [scenarios.data]);
|
||||
}
|
||||
|
||||
@@ -22,7 +22,7 @@ export const canModifyExperiment = async (experimentId: string, userId: string)
|
||||
where: {
|
||||
id: experimentId,
|
||||
organization: {
|
||||
OrganizationUser: {
|
||||
organizationUsers: {
|
||||
some: {
|
||||
role: { in: [OrganizationUserRole.ADMIN, OrganizationUserRole.MEMBER] },
|
||||
userId,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { type ChatCompletion } from "openai/resources/chat";
|
||||
import { GPTTokens } from "gpt-tokens";
|
||||
import { type OpenAIChatModel } from "~/server/types";
|
||||
import { type SupportedModel } from "~/modelProviders/openai-ChatCompletion";
|
||||
|
||||
interface GPTTokensMessageItem {
|
||||
name?: string;
|
||||
@@ -9,7 +9,7 @@ interface GPTTokensMessageItem {
|
||||
}
|
||||
|
||||
export const countOpenAIChatTokens = (
|
||||
model: keyof typeof OpenAIChatModel,
|
||||
model: SupportedModel,
|
||||
messages: ChatCompletion.Choice.Message[],
|
||||
) => {
|
||||
return new GPTTokens({ model, messages: messages as unknown as GPTTokensMessageItem[] })
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
// generate random channel id
|
||||
|
||||
export const generateChannel = () => {
|
||||
return Math.random().toString(36).substring(2, 15) + Math.random().toString(36).substring(2, 15);
|
||||
};
|
||||
@@ -1,6 +1,7 @@
|
||||
import { useRouter } from "next/router";
|
||||
import { type RefObject, useCallback, useEffect, useRef, useState } from "react";
|
||||
import { api } from "~/utils/api";
|
||||
import { NumberParam, useQueryParam, withDefault } from "use-query-params";
|
||||
|
||||
export const useExperiment = () => {
|
||||
const router = useRouter();
|
||||
@@ -93,3 +94,17 @@ export const useElementDimensions = (): [RefObject<HTMLElement>, Dimensions | un
|
||||
|
||||
return [ref, dimensions];
|
||||
};
|
||||
|
||||
export const usePage = () => useQueryParam("page", withDefault(NumberParam, 1));
|
||||
|
||||
export const useScenarios = () => {
|
||||
const experiment = useExperiment();
|
||||
const [page] = usePage();
|
||||
|
||||
return api.scenarios.list.useQuery(
|
||||
{ experimentId: experiment.data?.id ?? "", page },
|
||||
{ enabled: experiment.data?.id != null },
|
||||
);
|
||||
};
|
||||
|
||||
export const useVisibleScenarioIds = () => useScenarios().data?.scenarios.map((s) => s.id) ?? [];
|
||||
|
||||
@@ -1 +1,12 @@
|
||||
import frontendModelProviders from "~/modelProviders/frontendModelProviders";
|
||||
import { type ProviderModel } from "~/modelProviders/types";
|
||||
|
||||
export const truthyFilter = <T>(x: T | null | undefined): x is T => Boolean(x);
|
||||
|
||||
export const lookupModel = (provider: string, model: string) => {
|
||||
const modelObj = frontendModelProviders[provider as ProviderModel["provider"]]?.models[model];
|
||||
return modelObj ? { ...modelObj, provider } : null;
|
||||
};
|
||||
|
||||
export const modelLabel = (provider: string, model: string) =>
|
||||
`${provider}/${lookupModel(provider, model)?.name ?? model}`;
|
||||
|
||||
Reference in New Issue
Block a user