Compare commits
31 Commits
model-prov
...
variant-le
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
05a932ea74 | ||
|
|
85d42a014b | ||
|
|
7d1ded3b18 | ||
|
|
b00f6dd04b | ||
|
|
2e395e4d39 | ||
|
|
4b06d05908 | ||
|
|
aabf355b81 | ||
|
|
61e5f0775d | ||
|
|
cc1d1178da | ||
|
|
7466db63df | ||
|
|
79a0b03bf8 | ||
|
|
6fb7a82d72 | ||
|
|
4ea30a3ba3 | ||
|
|
52d1d5c7ee | ||
|
|
46036a44d2 | ||
|
|
3753fe5c16 | ||
|
|
213a00a8e6 | ||
|
|
af9943eefc | ||
|
|
741128e0f4 | ||
|
|
aff14539d8 | ||
|
|
1af81a50a9 | ||
|
|
7e1fbb3767 | ||
|
|
a5d972005e | ||
|
|
a180b5bef2 | ||
|
|
55c697223e | ||
|
|
9978075867 | ||
|
|
372c2512c9 | ||
|
|
1822fe198e | ||
|
|
f06e1db3db | ||
|
|
9314a86857 | ||
|
|
54dcb4a567 |
@@ -17,6 +17,9 @@ DATABASE_URL="postgresql://postgres:postgres@localhost:5432/openpipe?schema=publ
|
|||||||
# https://help.openai.com/en/articles/4936850-where-do-i-find-my-secret-api-key
|
# https://help.openai.com/en/articles/4936850-where-do-i-find-my-secret-api-key
|
||||||
OPENAI_API_KEY=""
|
OPENAI_API_KEY=""
|
||||||
|
|
||||||
|
# Replicate API token. Create a token here: https://replicate.com/account/api-tokens
|
||||||
|
REPLICATE_API_TOKEN=""
|
||||||
|
|
||||||
NEXT_PUBLIC_SOCKET_URL="http://localhost:3318"
|
NEXT_PUBLIC_SOCKET_URL="http://localhost:3318"
|
||||||
|
|
||||||
# Next Auth
|
# Next Auth
|
||||||
|
|||||||
5
.vscode/settings.json
vendored
5
.vscode/settings.json
vendored
@@ -1,6 +1,3 @@
|
|||||||
{
|
{
|
||||||
"eslint.format.enable": true,
|
"eslint.format.enable": true
|
||||||
"editor.codeActionsOnSave": {
|
|
||||||
"source.fixAll.eslint": true
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -43,7 +43,8 @@ Natively supports [OpenAI function calls](https://openai.com/blog/function-calli
|
|||||||
|
|
||||||
## Supported Models
|
## Supported Models
|
||||||
|
|
||||||
OpenPipe currently supports GPT-3.5 and GPT-4. Wider model support is planned.
|
- All models available through the OpenAI [chat completion API](https://platform.openai.com/docs/guides/gpt/chat-completions-api)
|
||||||
|
- Llama2 [7b chat](https://replicate.com/a16z-infra/llama7b-v2-chat), [13b chat](https://replicate.com/a16z-infra/llama13b-v2-chat), [70b chat](https://replicate.com/replicate/llama70b-v2-chat).
|
||||||
|
|
||||||
## Running Locally
|
## Running Locally
|
||||||
|
|
||||||
|
|||||||
@@ -59,6 +59,7 @@
|
|||||||
"lodash-es": "^4.17.21",
|
"lodash-es": "^4.17.21",
|
||||||
"next": "^13.4.2",
|
"next": "^13.4.2",
|
||||||
"next-auth": "^4.22.1",
|
"next-auth": "^4.22.1",
|
||||||
|
"next-query-params": "^4.2.3",
|
||||||
"nextjs-routes": "^2.0.1",
|
"nextjs-routes": "^2.0.1",
|
||||||
"openai": "4.0.0-beta.2",
|
"openai": "4.0.0-beta.2",
|
||||||
"pluralize": "^8.0.0",
|
"pluralize": "^8.0.0",
|
||||||
@@ -79,6 +80,7 @@
|
|||||||
"superjson": "1.12.2",
|
"superjson": "1.12.2",
|
||||||
"tsx": "^3.12.7",
|
"tsx": "^3.12.7",
|
||||||
"type-fest": "^4.0.0",
|
"type-fest": "^4.0.0",
|
||||||
|
"use-query-params": "^2.2.1",
|
||||||
"vite-tsconfig-paths": "^4.2.0",
|
"vite-tsconfig-paths": "^4.2.0",
|
||||||
"zod": "^3.21.4",
|
"zod": "^3.21.4",
|
||||||
"zustand": "^4.3.9"
|
"zustand": "^4.3.9"
|
||||||
|
|||||||
41
pnpm-lock.yaml
generated
41
pnpm-lock.yaml
generated
@@ -119,6 +119,9 @@ dependencies:
|
|||||||
next-auth:
|
next-auth:
|
||||||
specifier: ^4.22.1
|
specifier: ^4.22.1
|
||||||
version: 4.22.1(next@13.4.2)(react-dom@18.2.0)(react@18.2.0)
|
version: 4.22.1(next@13.4.2)(react-dom@18.2.0)(react@18.2.0)
|
||||||
|
next-query-params:
|
||||||
|
specifier: ^4.2.3
|
||||||
|
version: 4.2.3(next@13.4.2)(react@18.2.0)(use-query-params@2.2.1)
|
||||||
nextjs-routes:
|
nextjs-routes:
|
||||||
specifier: ^2.0.1
|
specifier: ^2.0.1
|
||||||
version: 2.0.1(next@13.4.2)
|
version: 2.0.1(next@13.4.2)
|
||||||
@@ -179,6 +182,9 @@ dependencies:
|
|||||||
type-fest:
|
type-fest:
|
||||||
specifier: ^4.0.0
|
specifier: ^4.0.0
|
||||||
version: 4.0.0
|
version: 4.0.0
|
||||||
|
use-query-params:
|
||||||
|
specifier: ^2.2.1
|
||||||
|
version: 2.2.1(react-dom@18.2.0)(react@18.2.0)
|
||||||
vite-tsconfig-paths:
|
vite-tsconfig-paths:
|
||||||
specifier: ^4.2.0
|
specifier: ^4.2.0
|
||||||
version: 4.2.0(typescript@5.0.4)
|
version: 4.2.0(typescript@5.0.4)
|
||||||
@@ -6037,6 +6043,19 @@ packages:
|
|||||||
uuid: 8.3.2
|
uuid: 8.3.2
|
||||||
dev: false
|
dev: false
|
||||||
|
|
||||||
|
/next-query-params@4.2.3(next@13.4.2)(react@18.2.0)(use-query-params@2.2.1):
|
||||||
|
resolution: {integrity: sha512-hGNCYRH8YyA5ItiBGSKrtMl21b2MAqfPkdI1mvwloNVqSU142IaGzqHN+OTovyeLIpQfonY01y7BAHb/UH4POg==}
|
||||||
|
peerDependencies:
|
||||||
|
next: ^10.0.0 || ^11.0.0 || ^12.0.0 || ^13.0.0
|
||||||
|
react: ^16.8.0 || ^17.0.0 || ^18.0.0
|
||||||
|
use-query-params: ^2.0.0
|
||||||
|
dependencies:
|
||||||
|
next: 13.4.2(@babel/core@7.22.9)(react-dom@18.2.0)(react@18.2.0)
|
||||||
|
react: 18.2.0
|
||||||
|
tslib: 2.6.0
|
||||||
|
use-query-params: 2.2.1(react-dom@18.2.0)(react@18.2.0)
|
||||||
|
dev: false
|
||||||
|
|
||||||
/next-tick@1.1.0:
|
/next-tick@1.1.0:
|
||||||
resolution: {integrity: sha512-CXdUiJembsNjuToQvxayPZF9Vqht7hewsvy2sOWafLvi2awflj9mOC6bHIg50orX8IJvWKY9wYQ/zB2kogPslQ==}
|
resolution: {integrity: sha512-CXdUiJembsNjuToQvxayPZF9Vqht7hewsvy2sOWafLvi2awflj9mOC6bHIg50orX8IJvWKY9wYQ/zB2kogPslQ==}
|
||||||
dev: false
|
dev: false
|
||||||
@@ -7147,6 +7166,10 @@ packages:
|
|||||||
randombytes: 2.1.0
|
randombytes: 2.1.0
|
||||||
dev: true
|
dev: true
|
||||||
|
|
||||||
|
/serialize-query-params@2.0.2:
|
||||||
|
resolution: {integrity: sha512-1chMo1dST4pFA9RDXAtF0Rbjaut4is7bzFbI1Z26IuMub68pNCILku85aYmeFhvnY//BXUPUhoRMjYcsT93J/Q==}
|
||||||
|
dev: false
|
||||||
|
|
||||||
/serve-static@1.15.0:
|
/serve-static@1.15.0:
|
||||||
resolution: {integrity: sha512-XGuRDNjXUijsUL0vl6nSD7cwURuzEgglbOaFuZM9g3kwDXOWVTck0jLzjPzGD+TazWbboZYu52/9/XPdUgne9g==}
|
resolution: {integrity: sha512-XGuRDNjXUijsUL0vl6nSD7cwURuzEgglbOaFuZM9g3kwDXOWVTck0jLzjPzGD+TazWbboZYu52/9/XPdUgne9g==}
|
||||||
engines: {node: '>= 0.8.0'}
|
engines: {node: '>= 0.8.0'}
|
||||||
@@ -7824,6 +7847,24 @@ packages:
|
|||||||
use-isomorphic-layout-effect: 1.1.2(@types/react@18.2.6)(react@18.2.0)
|
use-isomorphic-layout-effect: 1.1.2(@types/react@18.2.6)(react@18.2.0)
|
||||||
dev: false
|
dev: false
|
||||||
|
|
||||||
|
/use-query-params@2.2.1(react-dom@18.2.0)(react@18.2.0):
|
||||||
|
resolution: {integrity: sha512-i6alcyLB8w9i3ZK3caNftdb+UnbfBRNPDnc89CNQWkGRmDrm/gfydHvMBfVsQJRq3NoHOM2dt/ceBWG2397v1Q==}
|
||||||
|
peerDependencies:
|
||||||
|
'@reach/router': ^1.2.1
|
||||||
|
react: '>=16.8.0'
|
||||||
|
react-dom: '>=16.8.0'
|
||||||
|
react-router-dom: '>=5'
|
||||||
|
peerDependenciesMeta:
|
||||||
|
'@reach/router':
|
||||||
|
optional: true
|
||||||
|
react-router-dom:
|
||||||
|
optional: true
|
||||||
|
dependencies:
|
||||||
|
react: 18.2.0
|
||||||
|
react-dom: 18.2.0(react@18.2.0)
|
||||||
|
serialize-query-params: 2.0.2
|
||||||
|
dev: false
|
||||||
|
|
||||||
/use-sidecar@1.1.2(@types/react@18.2.6)(react@18.2.0):
|
/use-sidecar@1.1.2(@types/react@18.2.6)(react@18.2.0):
|
||||||
resolution: {integrity: sha512-epTbsLuzZ7lPClpz2TyryBfztm7m+28DlEv2ZCQ3MDr5ssiwyOwGH/e5F9CkfWjJ1t4clvI58yF822/GUkjjhw==}
|
resolution: {integrity: sha512-epTbsLuzZ7lPClpz2TyryBfztm7m+28DlEv2ZCQ3MDr5ssiwyOwGH/e5F9CkfWjJ1t4clvI58yF822/GUkjjhw==}
|
||||||
engines: {node: '>=10'}
|
engines: {node: '>=10'}
|
||||||
|
|||||||
@@ -7,9 +7,13 @@ const defaultId = "11111111-1111-1111-1111-111111111111";
|
|||||||
await prisma.organization.deleteMany({
|
await prisma.organization.deleteMany({
|
||||||
where: { id: defaultId },
|
where: { id: defaultId },
|
||||||
});
|
});
|
||||||
await prisma.organization.create({
|
|
||||||
data: { id: defaultId },
|
// If there's an existing org, just seed into it
|
||||||
});
|
const org =
|
||||||
|
(await prisma.organization.findFirst({})) ??
|
||||||
|
(await prisma.organization.create({
|
||||||
|
data: { id: defaultId },
|
||||||
|
}));
|
||||||
|
|
||||||
await prisma.experiment.deleteMany({
|
await prisma.experiment.deleteMany({
|
||||||
where: {
|
where: {
|
||||||
@@ -21,7 +25,7 @@ await prisma.experiment.create({
|
|||||||
data: {
|
data: {
|
||||||
id: defaultId,
|
id: defaultId,
|
||||||
label: "Country Capitals Example",
|
label: "Country Capitals Example",
|
||||||
organizationId: defaultId,
|
organizationId: org.id,
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -103,30 +107,41 @@ await prisma.testScenario.deleteMany({
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
|
const countries = [
|
||||||
|
"Afghanistan",
|
||||||
|
"Albania",
|
||||||
|
"Algeria",
|
||||||
|
"Andorra",
|
||||||
|
"Angola",
|
||||||
|
"Antigua and Barbuda",
|
||||||
|
"Argentina",
|
||||||
|
"Armenia",
|
||||||
|
"Australia",
|
||||||
|
"Austria",
|
||||||
|
"Austrian Empire",
|
||||||
|
"Azerbaijan",
|
||||||
|
"Baden",
|
||||||
|
"Bahamas, The",
|
||||||
|
"Bahrain",
|
||||||
|
"Bangladesh",
|
||||||
|
"Barbados",
|
||||||
|
"Bavaria",
|
||||||
|
"Belarus",
|
||||||
|
"Belgium",
|
||||||
|
"Belize",
|
||||||
|
"Benin (Dahomey)",
|
||||||
|
"Bolivia",
|
||||||
|
"Bosnia and Herzegovina",
|
||||||
|
"Botswana",
|
||||||
|
];
|
||||||
await prisma.testScenario.createMany({
|
await prisma.testScenario.createMany({
|
||||||
data: [
|
data: countries.map((country, i) => ({
|
||||||
{
|
experimentId: defaultId,
|
||||||
experimentId: defaultId,
|
sortIndex: i,
|
||||||
sortIndex: 0,
|
variableValues: {
|
||||||
variableValues: {
|
country: country,
|
||||||
country: "Spain",
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
{
|
})),
|
||||||
experimentId: defaultId,
|
|
||||||
sortIndex: 1,
|
|
||||||
variableValues: {
|
|
||||||
country: "USA",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
experimentId: defaultId,
|
|
||||||
sortIndex: 2,
|
|
||||||
variableValues: {
|
|
||||||
country: "Chile",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
],
|
|
||||||
});
|
});
|
||||||
|
|
||||||
const variants = await prisma.promptVariant.findMany({
|
const variants = await prisma.promptVariant.findMany({
|
||||||
|
|||||||
@@ -1,19 +1,22 @@
|
|||||||
import { Textarea, type TextareaProps } from "@chakra-ui/react";
|
import { Textarea, type TextareaProps } from "@chakra-ui/react";
|
||||||
import ResizeTextarea from "react-textarea-autosize";
|
import ResizeTextarea from "react-textarea-autosize";
|
||||||
import React from "react";
|
import React, { useLayoutEffect, useState } from "react";
|
||||||
|
|
||||||
export const AutoResizeTextarea: React.ForwardRefRenderFunction<
|
export const AutoResizeTextarea: React.ForwardRefRenderFunction<
|
||||||
HTMLTextAreaElement,
|
HTMLTextAreaElement,
|
||||||
TextareaProps & { minRows?: number }
|
TextareaProps & { minRows?: number }
|
||||||
> = (props, ref) => {
|
> = ({ minRows = 1, overflowY = "hidden", ...props }, ref) => {
|
||||||
|
const [isRerendered, setIsRerendered] = useState(false);
|
||||||
|
useLayoutEffect(() => setIsRerendered(true), []);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Textarea
|
<Textarea
|
||||||
minH="unset"
|
minH="unset"
|
||||||
overflow="hidden"
|
minRows={minRows}
|
||||||
|
overflowY={isRerendered ? overflowY : "hidden"}
|
||||||
w="100%"
|
w="100%"
|
||||||
resize="none"
|
resize="none"
|
||||||
ref={ref}
|
ref={ref}
|
||||||
minRows={1}
|
|
||||||
transition="height none"
|
transition="height none"
|
||||||
as={ResizeTextarea}
|
as={ResizeTextarea}
|
||||||
{...props}
|
{...props}
|
||||||
|
|||||||
135
src/components/ChangeModelModal/ChangeModelModal.tsx
Normal file
135
src/components/ChangeModelModal/ChangeModelModal.tsx
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
import {
|
||||||
|
Button,
|
||||||
|
Modal,
|
||||||
|
ModalBody,
|
||||||
|
ModalCloseButton,
|
||||||
|
ModalContent,
|
||||||
|
ModalFooter,
|
||||||
|
ModalHeader,
|
||||||
|
ModalOverlay,
|
||||||
|
VStack,
|
||||||
|
Text,
|
||||||
|
Spinner,
|
||||||
|
HStack,
|
||||||
|
Icon,
|
||||||
|
} from "@chakra-ui/react";
|
||||||
|
import { RiExchangeFundsFill } from "react-icons/ri";
|
||||||
|
import { useState } from "react";
|
||||||
|
import { ModelStatsCard } from "./ModelStatsCard";
|
||||||
|
import { ModelSearch } from "./ModelSearch";
|
||||||
|
import { api } from "~/utils/api";
|
||||||
|
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
|
||||||
|
import CompareFunctions from "../RefinePromptModal/CompareFunctions";
|
||||||
|
import { type PromptVariant } from "@prisma/client";
|
||||||
|
import { isObject, isString } from "lodash-es";
|
||||||
|
import { type Model, type SupportedProvider } from "~/modelProviders/types";
|
||||||
|
import frontendModelProviders from "~/modelProviders/frontendModelProviders";
|
||||||
|
import { keyForModel } from "~/utils/utils";
|
||||||
|
|
||||||
|
export const ChangeModelModal = ({
|
||||||
|
variant,
|
||||||
|
onClose,
|
||||||
|
}: {
|
||||||
|
variant: PromptVariant;
|
||||||
|
onClose: () => void;
|
||||||
|
}) => {
|
||||||
|
const originalModelProviderName = variant.modelProvider as SupportedProvider;
|
||||||
|
const originalModelProvider = frontendModelProviders[originalModelProviderName];
|
||||||
|
const originalModel = originalModelProvider.models[variant.model] as Model;
|
||||||
|
const [selectedModel, setSelectedModel] = useState<Model>(originalModel);
|
||||||
|
const [convertedModel, setConvertedModel] = useState<Model | undefined>(undefined);
|
||||||
|
const utils = api.useContext();
|
||||||
|
|
||||||
|
const experiment = useExperiment();
|
||||||
|
|
||||||
|
const { mutateAsync: getModifiedPromptMutateAsync, data: modifiedPromptFn } =
|
||||||
|
api.promptVariants.getModifiedPromptFn.useMutation();
|
||||||
|
|
||||||
|
const [getModifiedPromptFn, modificationInProgress] = useHandledAsyncCallback(async () => {
|
||||||
|
if (!experiment) return;
|
||||||
|
|
||||||
|
await getModifiedPromptMutateAsync({
|
||||||
|
id: variant.id,
|
||||||
|
newModel: selectedModel,
|
||||||
|
});
|
||||||
|
setConvertedModel(selectedModel);
|
||||||
|
}, [getModifiedPromptMutateAsync, onClose, experiment, variant, selectedModel]);
|
||||||
|
|
||||||
|
const replaceVariantMutation = api.promptVariants.replaceVariant.useMutation();
|
||||||
|
|
||||||
|
const [replaceVariant, replacementInProgress] = useHandledAsyncCallback(async () => {
|
||||||
|
if (
|
||||||
|
!variant.experimentId ||
|
||||||
|
!modifiedPromptFn ||
|
||||||
|
(isObject(modifiedPromptFn) && "status" in modifiedPromptFn)
|
||||||
|
)
|
||||||
|
return;
|
||||||
|
await replaceVariantMutation.mutateAsync({
|
||||||
|
id: variant.id,
|
||||||
|
constructFn: modifiedPromptFn,
|
||||||
|
});
|
||||||
|
await utils.promptVariants.list.invalidate();
|
||||||
|
onClose();
|
||||||
|
}, [replaceVariantMutation, variant, onClose, modifiedPromptFn]);
|
||||||
|
|
||||||
|
const originalModelLabel = keyForModel(originalModel);
|
||||||
|
const selectedModelLabel = keyForModel(selectedModel);
|
||||||
|
const convertedModelLabel = convertedModel ? keyForModel(convertedModel) : undefined;
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Modal
|
||||||
|
isOpen
|
||||||
|
onClose={onClose}
|
||||||
|
size={{ base: "xl", sm: "2xl", md: "3xl", lg: "5xl", xl: "7xl" }}
|
||||||
|
>
|
||||||
|
<ModalOverlay />
|
||||||
|
<ModalContent w={1200}>
|
||||||
|
<ModalHeader>
|
||||||
|
<HStack>
|
||||||
|
<Icon as={RiExchangeFundsFill} />
|
||||||
|
<Text>Change Model</Text>
|
||||||
|
</HStack>
|
||||||
|
</ModalHeader>
|
||||||
|
<ModalCloseButton />
|
||||||
|
<ModalBody maxW="unset">
|
||||||
|
<VStack spacing={8}>
|
||||||
|
<ModelStatsCard label="Original Model" model={originalModel} />
|
||||||
|
{originalModelLabel !== selectedModelLabel && (
|
||||||
|
<ModelStatsCard label="New Model" model={selectedModel} />
|
||||||
|
)}
|
||||||
|
<ModelSearch selectedModel={selectedModel} setSelectedModel={setSelectedModel} />
|
||||||
|
{isString(modifiedPromptFn) && (
|
||||||
|
<CompareFunctions
|
||||||
|
originalFunction={variant.constructFn}
|
||||||
|
newFunction={modifiedPromptFn}
|
||||||
|
leftTitle={originalModelLabel}
|
||||||
|
rightTitle={convertedModelLabel}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
</VStack>
|
||||||
|
</ModalBody>
|
||||||
|
|
||||||
|
<ModalFooter>
|
||||||
|
<HStack>
|
||||||
|
<Button
|
||||||
|
colorScheme="gray"
|
||||||
|
onClick={getModifiedPromptFn}
|
||||||
|
minW={24}
|
||||||
|
isDisabled={originalModel === selectedModel || modificationInProgress}
|
||||||
|
>
|
||||||
|
{modificationInProgress ? <Spinner boxSize={4} /> : <Text>Convert</Text>}
|
||||||
|
</Button>
|
||||||
|
<Button
|
||||||
|
colorScheme="blue"
|
||||||
|
onClick={replaceVariant}
|
||||||
|
minW={24}
|
||||||
|
isDisabled={!convertedModel || modificationInProgress || replacementInProgress}
|
||||||
|
>
|
||||||
|
{replacementInProgress ? <Spinner boxSize={4} /> : <Text>Accept</Text>}
|
||||||
|
</Button>
|
||||||
|
</HStack>
|
||||||
|
</ModalFooter>
|
||||||
|
</ModalContent>
|
||||||
|
</Modal>
|
||||||
|
);
|
||||||
|
};
|
||||||
50
src/components/ChangeModelModal/ModelSearch.tsx
Normal file
50
src/components/ChangeModelModal/ModelSearch.tsx
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
import { VStack, Text } from "@chakra-ui/react";
|
||||||
|
import { type LegacyRef, useCallback } from "react";
|
||||||
|
import Select, { type SingleValue } from "react-select";
|
||||||
|
import { useElementDimensions } from "~/utils/hooks";
|
||||||
|
|
||||||
|
import frontendModelProviders from "~/modelProviders/frontendModelProviders";
|
||||||
|
import { type Model } from "~/modelProviders/types";
|
||||||
|
import { keyForModel } from "~/utils/utils";
|
||||||
|
|
||||||
|
const modelOptions: { label: string; value: Model }[] = [];
|
||||||
|
|
||||||
|
for (const [_, providerValue] of Object.entries(frontendModelProviders)) {
|
||||||
|
for (const [_, modelValue] of Object.entries(providerValue.models)) {
|
||||||
|
modelOptions.push({
|
||||||
|
label: keyForModel(modelValue),
|
||||||
|
value: modelValue,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export const ModelSearch = ({
|
||||||
|
selectedModel,
|
||||||
|
setSelectedModel,
|
||||||
|
}: {
|
||||||
|
selectedModel: Model;
|
||||||
|
setSelectedModel: (model: Model) => void;
|
||||||
|
}) => {
|
||||||
|
const handleSelection = useCallback(
|
||||||
|
(option: SingleValue<{ label: string; value: Model }>) => {
|
||||||
|
if (!option) return;
|
||||||
|
setSelectedModel(option.value);
|
||||||
|
},
|
||||||
|
[setSelectedModel],
|
||||||
|
);
|
||||||
|
const selectedOption = modelOptions.find((option) => option.label === keyForModel(selectedModel));
|
||||||
|
|
||||||
|
const [containerRef, containerDimensions] = useElementDimensions();
|
||||||
|
|
||||||
|
return (
|
||||||
|
<VStack ref={containerRef as LegacyRef<HTMLDivElement>} w="full">
|
||||||
|
<Text>Browse Models</Text>
|
||||||
|
<Select
|
||||||
|
styles={{ control: (provided) => ({ ...provided, width: containerDimensions?.width }) }}
|
||||||
|
value={selectedOption}
|
||||||
|
options={modelOptions}
|
||||||
|
onChange={handleSelection}
|
||||||
|
/>
|
||||||
|
</VStack>
|
||||||
|
);
|
||||||
|
};
|
||||||
@@ -7,11 +7,9 @@ import {
|
|||||||
SimpleGrid,
|
SimpleGrid,
|
||||||
Link,
|
Link,
|
||||||
} from "@chakra-ui/react";
|
} from "@chakra-ui/react";
|
||||||
import { modelStats } from "~/modelProviders/modelStats";
|
import { type Model } from "~/modelProviders/types";
|
||||||
import { type SupportedModel } from "~/server/types";
|
|
||||||
|
|
||||||
export const ModelStatsCard = ({ label, model }: { label: string; model: SupportedModel }) => {
|
export const ModelStatsCard = ({ label, model }: { label: string; model: Model }) => {
|
||||||
const stats = modelStats[model];
|
|
||||||
return (
|
return (
|
||||||
<VStack w="full" align="start">
|
<VStack w="full" align="start">
|
||||||
<Text fontWeight="bold" fontSize="sm" textTransform="uppercase">
|
<Text fontWeight="bold" fontSize="sm" textTransform="uppercase">
|
||||||
@@ -22,14 +20,14 @@ export const ModelStatsCard = ({ label, model }: { label: string; model: Support
|
|||||||
<HStack w="full" align="flex-start">
|
<HStack w="full" align="flex-start">
|
||||||
<Text flex={1} fontSize="lg">
|
<Text flex={1} fontSize="lg">
|
||||||
<Text as="span" color="gray.600">
|
<Text as="span" color="gray.600">
|
||||||
{stats.provider} /{" "}
|
{model.provider} /{" "}
|
||||||
</Text>
|
</Text>
|
||||||
<Text as="span" fontWeight="bold" color="gray.900">
|
<Text as="span" fontWeight="bold" color="gray.900">
|
||||||
{model}
|
{model.name}
|
||||||
</Text>
|
</Text>
|
||||||
</Text>
|
</Text>
|
||||||
<Link
|
<Link
|
||||||
href={stats.learnMoreUrl}
|
href={model.learnMoreUrl}
|
||||||
isExternal
|
isExternal
|
||||||
color="blue.500"
|
color="blue.500"
|
||||||
fontWeight="bold"
|
fontWeight="bold"
|
||||||
@@ -46,26 +44,41 @@ export const ModelStatsCard = ({ label, model }: { label: string; model: Support
|
|||||||
fontSize="sm"
|
fontSize="sm"
|
||||||
columns={{ base: 2, md: 4 }}
|
columns={{ base: 2, md: 4 }}
|
||||||
>
|
>
|
||||||
<SelectedModelLabeledInfo label="Context" info={stats.contextLength} />
|
<SelectedModelLabeledInfo label="Context Window" info={model.contextWindow} />
|
||||||
<SelectedModelLabeledInfo
|
{model.promptTokenPrice && (
|
||||||
label="Input"
|
<SelectedModelLabeledInfo
|
||||||
info={
|
label="Input"
|
||||||
<Text>
|
info={
|
||||||
${(stats.promptTokenPrice * 1000).toFixed(3)}
|
<Text>
|
||||||
<Text color="gray.500"> / 1K tokens</Text>
|
${(model.promptTokenPrice * 1000).toFixed(3)}
|
||||||
</Text>
|
<Text color="gray.500"> / 1K tokens</Text>
|
||||||
}
|
</Text>
|
||||||
/>
|
}
|
||||||
<SelectedModelLabeledInfo
|
/>
|
||||||
label="Output"
|
)}
|
||||||
info={
|
{model.completionTokenPrice && (
|
||||||
<Text>
|
<SelectedModelLabeledInfo
|
||||||
${(stats.promptTokenPrice * 1000).toFixed(3)}
|
label="Output"
|
||||||
<Text color="gray.500"> / 1K tokens</Text>
|
info={
|
||||||
</Text>
|
<Text>
|
||||||
}
|
${(model.completionTokenPrice * 1000).toFixed(3)}
|
||||||
/>
|
<Text color="gray.500"> / 1K tokens</Text>
|
||||||
<SelectedModelLabeledInfo label="Speed" info={<Text>{stats.speed}</Text>} />
|
</Text>
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{model.pricePerSecond && (
|
||||||
|
<SelectedModelLabeledInfo
|
||||||
|
label="Price"
|
||||||
|
info={
|
||||||
|
<Text>
|
||||||
|
${model.pricePerSecond.toFixed(3)}
|
||||||
|
<Text color="gray.500"> / second</Text>
|
||||||
|
</Text>
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
<SelectedModelLabeledInfo label="Speed" info={<Text>{model.speed}</Text>} />
|
||||||
</SimpleGrid>
|
</SimpleGrid>
|
||||||
</VStack>
|
</VStack>
|
||||||
</VStack>
|
</VStack>
|
||||||
50
src/components/OutputsTable/AddVariantButton.tsx
Normal file
50
src/components/OutputsTable/AddVariantButton.tsx
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
import { Box, Flex, Icon, Spinner } from "@chakra-ui/react";
|
||||||
|
import { BsPlus } from "react-icons/bs";
|
||||||
|
import { Text } from "@chakra-ui/react";
|
||||||
|
import { api } from "~/utils/api";
|
||||||
|
import { useExperiment, useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
|
||||||
|
import { cellPadding } from "../constants";
|
||||||
|
import { ActionButton } from "./ScenariosHeader";
|
||||||
|
|
||||||
|
export default function AddVariantButton() {
|
||||||
|
const experiment = useExperiment();
|
||||||
|
const mutation = api.promptVariants.create.useMutation();
|
||||||
|
const utils = api.useContext();
|
||||||
|
|
||||||
|
const [onClick, loading] = useHandledAsyncCallback(async () => {
|
||||||
|
if (!experiment.data) return;
|
||||||
|
await mutation.mutateAsync({
|
||||||
|
experimentId: experiment.data.id,
|
||||||
|
});
|
||||||
|
await utils.promptVariants.list.invalidate();
|
||||||
|
}, [mutation]);
|
||||||
|
|
||||||
|
const { canModify } = useExperimentAccess();
|
||||||
|
if (!canModify) return <Box w={cellPadding.x} />;
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex w="100%" justifyContent="flex-end">
|
||||||
|
<ActionButton
|
||||||
|
onClick={onClick}
|
||||||
|
py={5}
|
||||||
|
leftIcon={<Icon as={loading ? Spinner : BsPlus} boxSize={6} mr={loading ? 1 : 0} />}
|
||||||
|
>
|
||||||
|
<Text display={{ base: "none", md: "flex" }}>Add Variant</Text>
|
||||||
|
</ActionButton>
|
||||||
|
{/* <Button
|
||||||
|
alignItems="center"
|
||||||
|
justifyContent="center"
|
||||||
|
fontWeight="normal"
|
||||||
|
bgColor="transparent"
|
||||||
|
_hover={{ bgColor: "gray.100" }}
|
||||||
|
px={cellPadding.x}
|
||||||
|
onClick={onClick}
|
||||||
|
height="unset"
|
||||||
|
minH={headerMinHeight}
|
||||||
|
>
|
||||||
|
<Icon as={loading ? Spinner : BsPlus} boxSize={6} mr={loading ? 1 : 0} />
|
||||||
|
<Text display={{ base: "none", md: "flex" }}>Add Variant</Text>
|
||||||
|
</Button> */}
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -18,11 +18,9 @@ export const FloatingLabelInput = ({
|
|||||||
transform={isFocused || !!value ? "translateY(-50%)" : "translateY(0)"}
|
transform={isFocused || !!value ? "translateY(-50%)" : "translateY(0)"}
|
||||||
fontSize={isFocused || !!value ? "12px" : "16px"}
|
fontSize={isFocused || !!value ? "12px" : "16px"}
|
||||||
transition="all 0.15s"
|
transition="all 0.15s"
|
||||||
zIndex="100"
|
zIndex="5"
|
||||||
bg="white"
|
bg="white"
|
||||||
px={1}
|
px={1}
|
||||||
mt={0}
|
|
||||||
mb={2}
|
|
||||||
lineHeight="1"
|
lineHeight="1"
|
||||||
pointerEvents="none"
|
pointerEvents="none"
|
||||||
color={isFocused ? "blue.500" : "gray.500"}
|
color={isFocused ? "blue.500" : "gray.500"}
|
||||||
|
|||||||
@@ -1,57 +0,0 @@
|
|||||||
import { Button, type ButtonProps, HStack, Spinner, Icon } from "@chakra-ui/react";
|
|
||||||
import { BsPlus } from "react-icons/bs";
|
|
||||||
import { api } from "~/utils/api";
|
|
||||||
import { useExperiment, useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
|
|
||||||
|
|
||||||
// Extracted Button styling into reusable component
|
|
||||||
const StyledButton = ({ children, onClick }: ButtonProps) => (
|
|
||||||
<Button
|
|
||||||
fontWeight="normal"
|
|
||||||
bgColor="transparent"
|
|
||||||
_hover={{ bgColor: "gray.100" }}
|
|
||||||
px={2}
|
|
||||||
onClick={onClick}
|
|
||||||
>
|
|
||||||
{children}
|
|
||||||
</Button>
|
|
||||||
);
|
|
||||||
|
|
||||||
export default function NewScenarioButton() {
|
|
||||||
const { canModify } = useExperimentAccess();
|
|
||||||
|
|
||||||
const experiment = useExperiment();
|
|
||||||
const mutation = api.scenarios.create.useMutation();
|
|
||||||
const utils = api.useContext();
|
|
||||||
|
|
||||||
const [onClick] = useHandledAsyncCallback(async () => {
|
|
||||||
if (!experiment.data) return;
|
|
||||||
await mutation.mutateAsync({
|
|
||||||
experimentId: experiment.data.id,
|
|
||||||
});
|
|
||||||
await utils.scenarios.list.invalidate();
|
|
||||||
}, [mutation]);
|
|
||||||
|
|
||||||
const [onAutogenerate, autogenerating] = useHandledAsyncCallback(async () => {
|
|
||||||
if (!experiment.data) return;
|
|
||||||
await mutation.mutateAsync({
|
|
||||||
experimentId: experiment.data.id,
|
|
||||||
autogenerate: true,
|
|
||||||
});
|
|
||||||
await utils.scenarios.list.invalidate();
|
|
||||||
}, [mutation]);
|
|
||||||
|
|
||||||
if (!canModify) return null;
|
|
||||||
|
|
||||||
return (
|
|
||||||
<HStack spacing={2}>
|
|
||||||
<StyledButton onClick={onClick}>
|
|
||||||
<Icon as={BsPlus} boxSize={6} />
|
|
||||||
Add Scenario
|
|
||||||
</StyledButton>
|
|
||||||
<StyledButton onClick={onAutogenerate}>
|
|
||||||
<Icon as={autogenerating ? Spinner : BsPlus} boxSize={6} mr={autogenerating ? 1 : 0} />
|
|
||||||
Autogenerate Scenario
|
|
||||||
</StyledButton>
|
|
||||||
</HStack>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
@@ -1,40 +0,0 @@
|
|||||||
import { Box, Button, Icon, Spinner, Text } from "@chakra-ui/react";
|
|
||||||
import { BsPlus } from "react-icons/bs";
|
|
||||||
import { api } from "~/utils/api";
|
|
||||||
import { useExperiment, useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
|
|
||||||
import { cellPadding, headerMinHeight } from "../constants";
|
|
||||||
|
|
||||||
export default function NewVariantButton() {
|
|
||||||
const experiment = useExperiment();
|
|
||||||
const mutation = api.promptVariants.create.useMutation();
|
|
||||||
const utils = api.useContext();
|
|
||||||
|
|
||||||
const [onClick, loading] = useHandledAsyncCallback(async () => {
|
|
||||||
if (!experiment.data) return;
|
|
||||||
await mutation.mutateAsync({
|
|
||||||
experimentId: experiment.data.id,
|
|
||||||
});
|
|
||||||
await utils.promptVariants.list.invalidate();
|
|
||||||
}, [mutation]);
|
|
||||||
|
|
||||||
const { canModify } = useExperimentAccess();
|
|
||||||
if (!canModify) return <Box w={cellPadding.x} />;
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Button
|
|
||||||
w="100%"
|
|
||||||
alignItems="center"
|
|
||||||
justifyContent="center"
|
|
||||||
fontWeight="normal"
|
|
||||||
bgColor="transparent"
|
|
||||||
_hover={{ bgColor: "gray.100" }}
|
|
||||||
px={cellPadding.x}
|
|
||||||
onClick={onClick}
|
|
||||||
height="unset"
|
|
||||||
minH={headerMinHeight}
|
|
||||||
>
|
|
||||||
<Icon as={loading ? Spinner : BsPlus} boxSize={6} mr={loading ? 1 : 0} />
|
|
||||||
<Text display={{ base: "none", md: "flex" }}>Add Variant</Text>
|
|
||||||
</Button>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
@@ -10,7 +10,7 @@ import useSocket from "~/utils/useSocket";
|
|||||||
import { OutputStats } from "./OutputStats";
|
import { OutputStats } from "./OutputStats";
|
||||||
import { ErrorHandler } from "./ErrorHandler";
|
import { ErrorHandler } from "./ErrorHandler";
|
||||||
import { CellOptions } from "./CellOptions";
|
import { CellOptions } from "./CellOptions";
|
||||||
import modelProvidersFrontend from "~/modelProviders/modelProvidersFrontend";
|
import frontendModelProviders from "~/modelProviders/frontendModelProviders";
|
||||||
|
|
||||||
export default function OutputCell({
|
export default function OutputCell({
|
||||||
scenario,
|
scenario,
|
||||||
@@ -40,7 +40,7 @@ export default function OutputCell({
|
|||||||
);
|
);
|
||||||
|
|
||||||
const provider =
|
const provider =
|
||||||
modelProvidersFrontend[variant.modelProvider as keyof typeof modelProvidersFrontend];
|
frontendModelProviders[variant.modelProvider as keyof typeof frontendModelProviders];
|
||||||
|
|
||||||
type OutputSchema = Parameters<typeof provider.normalizeOutput>[0];
|
type OutputSchema = Parameters<typeof provider.normalizeOutput>[0];
|
||||||
|
|
||||||
@@ -88,11 +88,9 @@ export default function OutputCell({
|
|||||||
}
|
}
|
||||||
|
|
||||||
const normalizedOutput = modelOutput
|
const normalizedOutput = modelOutput
|
||||||
? // @ts-expect-error TODO FIX ASAP
|
? provider.normalizeOutput(modelOutput.output)
|
||||||
provider.normalizeOutput(modelOutput.output as unknown as OutputSchema)
|
|
||||||
: streamedMessage
|
: streamedMessage
|
||||||
? // @ts-expect-error TODO FIX ASAP
|
? provider.normalizeOutput(streamedMessage)
|
||||||
provider.normalizeOutput(streamedMessage)
|
|
||||||
: null;
|
: null;
|
||||||
|
|
||||||
if (modelOutput && normalizedOutput?.type === "json") {
|
if (modelOutput && normalizedOutput?.type === "json") {
|
||||||
|
|||||||
74
src/components/OutputsTable/ScenarioPaginator.tsx
Normal file
74
src/components/OutputsTable/ScenarioPaginator.tsx
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
import { Box, HStack, IconButton } from "@chakra-ui/react";
|
||||||
|
import {
|
||||||
|
BsChevronDoubleLeft,
|
||||||
|
BsChevronDoubleRight,
|
||||||
|
BsChevronLeft,
|
||||||
|
BsChevronRight,
|
||||||
|
} from "react-icons/bs";
|
||||||
|
import { usePage, useScenarios } from "~/utils/hooks";
|
||||||
|
|
||||||
|
const ScenarioPaginator = () => {
|
||||||
|
const [page, setPage] = usePage();
|
||||||
|
const { data } = useScenarios();
|
||||||
|
|
||||||
|
if (!data) return null;
|
||||||
|
|
||||||
|
const { scenarios, startIndex, lastPage, count } = data;
|
||||||
|
|
||||||
|
const nextPage = () => {
|
||||||
|
if (page < lastPage) {
|
||||||
|
setPage(page + 1, "replace");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const prevPage = () => {
|
||||||
|
if (page > 1) {
|
||||||
|
setPage(page - 1, "replace");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const goToLastPage = () => setPage(lastPage, "replace");
|
||||||
|
const goToFirstPage = () => setPage(1, "replace");
|
||||||
|
|
||||||
|
return (
|
||||||
|
<HStack pt={4}>
|
||||||
|
<IconButton
|
||||||
|
variant="ghost"
|
||||||
|
size="sm"
|
||||||
|
onClick={goToFirstPage}
|
||||||
|
isDisabled={page === 1}
|
||||||
|
aria-label="Go to first page"
|
||||||
|
icon={<BsChevronDoubleLeft />}
|
||||||
|
/>
|
||||||
|
<IconButton
|
||||||
|
variant="ghost"
|
||||||
|
size="sm"
|
||||||
|
onClick={prevPage}
|
||||||
|
isDisabled={page === 1}
|
||||||
|
aria-label="Previous page"
|
||||||
|
icon={<BsChevronLeft />}
|
||||||
|
/>
|
||||||
|
<Box>
|
||||||
|
{startIndex}-{startIndex + scenarios.length - 1} / {count}
|
||||||
|
</Box>
|
||||||
|
<IconButton
|
||||||
|
variant="ghost"
|
||||||
|
size="sm"
|
||||||
|
onClick={nextPage}
|
||||||
|
isDisabled={page === lastPage}
|
||||||
|
aria-label="Next page"
|
||||||
|
icon={<BsChevronRight />}
|
||||||
|
/>
|
||||||
|
<IconButton
|
||||||
|
variant="ghost"
|
||||||
|
size="sm"
|
||||||
|
onClick={goToLastPage}
|
||||||
|
isDisabled={page === lastPage}
|
||||||
|
aria-label="Go to last page"
|
||||||
|
icon={<BsChevronDoubleRight />}
|
||||||
|
/>
|
||||||
|
</HStack>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default ScenarioPaginator;
|
||||||
@@ -4,11 +4,13 @@ import { cellPadding } from "../constants";
|
|||||||
import OutputCell from "./OutputCell/OutputCell";
|
import OutputCell from "./OutputCell/OutputCell";
|
||||||
import ScenarioEditor from "./ScenarioEditor";
|
import ScenarioEditor from "./ScenarioEditor";
|
||||||
import type { PromptVariant, Scenario } from "./types";
|
import type { PromptVariant, Scenario } from "./types";
|
||||||
|
import { borders } from "./styles";
|
||||||
|
|
||||||
const ScenarioRow = (props: {
|
const ScenarioRow = (props: {
|
||||||
scenario: Scenario;
|
scenario: Scenario;
|
||||||
variants: PromptVariant[];
|
variants: PromptVariant[];
|
||||||
canHide: boolean;
|
canHide: boolean;
|
||||||
|
rowStart: number;
|
||||||
}) => {
|
}) => {
|
||||||
const [isHovered, setIsHovered] = useState(false);
|
const [isHovered, setIsHovered] = useState(false);
|
||||||
|
|
||||||
@@ -21,15 +23,21 @@ const ScenarioRow = (props: {
|
|||||||
onMouseLeave={() => setIsHovered(false)}
|
onMouseLeave={() => setIsHovered(false)}
|
||||||
sx={isHovered ? highlightStyle : undefined}
|
sx={isHovered ? highlightStyle : undefined}
|
||||||
borderLeftWidth={1}
|
borderLeftWidth={1}
|
||||||
|
{...borders}
|
||||||
|
rowStart={props.rowStart}
|
||||||
|
colStart={1}
|
||||||
>
|
>
|
||||||
<ScenarioEditor scenario={props.scenario} hovered={isHovered} canHide={props.canHide} />
|
<ScenarioEditor scenario={props.scenario} hovered={isHovered} canHide={props.canHide} />
|
||||||
</GridItem>
|
</GridItem>
|
||||||
{props.variants.map((variant) => (
|
{props.variants.map((variant, i) => (
|
||||||
<GridItem
|
<GridItem
|
||||||
key={variant.id}
|
key={variant.id}
|
||||||
onMouseEnter={() => setIsHovered(true)}
|
onMouseEnter={() => setIsHovered(true)}
|
||||||
onMouseLeave={() => setIsHovered(false)}
|
onMouseLeave={() => setIsHovered(false)}
|
||||||
sx={isHovered ? highlightStyle : undefined}
|
sx={isHovered ? highlightStyle : undefined}
|
||||||
|
rowStart={props.rowStart}
|
||||||
|
colStart={i + 2}
|
||||||
|
{...borders}
|
||||||
>
|
>
|
||||||
<Box h="100%" w="100%" px={cellPadding.x} py={cellPadding.y}>
|
<Box h="100%" w="100%" px={cellPadding.x} py={cellPadding.y}>
|
||||||
<OutputCell key={variant.id} scenario={props.scenario} variant={variant} />
|
<OutputCell key={variant.id} scenario={props.scenario} variant={variant} />
|
||||||
|
|||||||
@@ -1,52 +1,82 @@
|
|||||||
import { Button, GridItem, HStack, Heading } from "@chakra-ui/react";
|
import {
|
||||||
|
Button,
|
||||||
|
type ButtonProps,
|
||||||
|
HStack,
|
||||||
|
Text,
|
||||||
|
Icon,
|
||||||
|
Menu,
|
||||||
|
MenuButton,
|
||||||
|
MenuList,
|
||||||
|
MenuItem,
|
||||||
|
IconButton,
|
||||||
|
Spinner,
|
||||||
|
} from "@chakra-ui/react";
|
||||||
import { cellPadding } from "../constants";
|
import { cellPadding } from "../constants";
|
||||||
import { useElementDimensions, useExperimentAccess } from "~/utils/hooks";
|
import {
|
||||||
import { stickyHeaderStyle } from "./styles";
|
useExperiment,
|
||||||
import { BsPencil } from "react-icons/bs";
|
useExperimentAccess,
|
||||||
|
useHandledAsyncCallback,
|
||||||
|
useScenarios,
|
||||||
|
} from "~/utils/hooks";
|
||||||
|
import { BsGear, BsPencil, BsPlus, BsStars } from "react-icons/bs";
|
||||||
import { useAppStore } from "~/state/store";
|
import { useAppStore } from "~/state/store";
|
||||||
|
import { api } from "~/utils/api";
|
||||||
|
|
||||||
export const ScenariosHeader = ({
|
export const ActionButton = (props: ButtonProps) => (
|
||||||
headerRows,
|
<Button size="sm" variant="ghost" color="gray.600" {...props} />
|
||||||
numScenarios,
|
);
|
||||||
}: {
|
|
||||||
headerRows: number;
|
export const ScenariosHeader = () => {
|
||||||
numScenarios: number;
|
|
||||||
}) => {
|
|
||||||
const openDrawer = useAppStore((s) => s.openDrawer);
|
const openDrawer = useAppStore((s) => s.openDrawer);
|
||||||
const { canModify } = useExperimentAccess();
|
const { canModify } = useExperimentAccess();
|
||||||
|
const scenarios = useScenarios();
|
||||||
|
|
||||||
const [ref, dimensions] = useElementDimensions();
|
const experiment = useExperiment();
|
||||||
const topValue = dimensions ? `-${dimensions.height - 24}px` : "-455px";
|
const createScenarioMutation = api.scenarios.create.useMutation();
|
||||||
|
const utils = api.useContext();
|
||||||
|
|
||||||
|
const [onAddScenario, loading] = useHandledAsyncCallback(
|
||||||
|
async (autogenerate: boolean) => {
|
||||||
|
if (!experiment.data) return;
|
||||||
|
await createScenarioMutation.mutateAsync({
|
||||||
|
experimentId: experiment.data.id,
|
||||||
|
autogenerate,
|
||||||
|
});
|
||||||
|
await utils.scenarios.list.invalidate();
|
||||||
|
},
|
||||||
|
[createScenarioMutation],
|
||||||
|
);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<GridItem
|
<HStack w="100%" pb={cellPadding.y} pt={0} align="center" spacing={0}>
|
||||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
<Text fontSize={16} fontWeight="bold">
|
||||||
ref={ref as any}
|
Scenarios ({scenarios.data?.count})
|
||||||
display="flex"
|
</Text>
|
||||||
alignItems="flex-end"
|
{canModify && (
|
||||||
rowSpan={headerRows}
|
<Menu>
|
||||||
px={cellPadding.x}
|
<MenuButton mt={1}>
|
||||||
py={cellPadding.y}
|
<IconButton
|
||||||
// Only display the part of the grid item that has content
|
variant="ghost"
|
||||||
sx={{ ...stickyHeaderStyle, top: topValue }}
|
aria-label="Edit Scenarios"
|
||||||
>
|
icon={<Icon as={loading ? Spinner : BsGear} />}
|
||||||
<HStack w="100%">
|
/>
|
||||||
<Heading size="xs" fontWeight="bold" flex={1}>
|
</MenuButton>
|
||||||
Scenarios ({numScenarios})
|
<MenuList fontSize="md" zIndex="dropdown" mt={-3}>
|
||||||
</Heading>
|
<MenuItem
|
||||||
{canModify && (
|
icon={<Icon as={BsPlus} boxSize={6} mx={-1} />}
|
||||||
<Button
|
onClick={() => onAddScenario(false)}
|
||||||
size="xs"
|
>
|
||||||
variant="ghost"
|
Add Scenario
|
||||||
color="gray.500"
|
</MenuItem>
|
||||||
aria-label="Edit"
|
<MenuItem icon={<BsStars />} onClick={() => onAddScenario(true)}>
|
||||||
leftIcon={<BsPencil />}
|
Autogenerate Scenario
|
||||||
onClick={openDrawer}
|
</MenuItem>
|
||||||
>
|
<MenuItem icon={<BsPencil />} onClick={openDrawer}>
|
||||||
Edit Vars
|
Edit Vars
|
||||||
</Button>
|
</MenuItem>
|
||||||
)}
|
</MenuList>
|
||||||
</HStack>
|
</Menu>
|
||||||
</GridItem>
|
)}
|
||||||
|
</HStack>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -1,17 +1,47 @@
|
|||||||
import { Box, Button, HStack, Spinner, Tooltip, useToast, Text } from "@chakra-ui/react";
|
import {
|
||||||
|
Box,
|
||||||
|
Button,
|
||||||
|
HStack,
|
||||||
|
Spinner,
|
||||||
|
Tooltip,
|
||||||
|
useToast,
|
||||||
|
Text,
|
||||||
|
IconButton,
|
||||||
|
} from "@chakra-ui/react";
|
||||||
import { useRef, useEffect, useState, useCallback } from "react";
|
import { useRef, useEffect, useState, useCallback } from "react";
|
||||||
import { useExperimentAccess, useHandledAsyncCallback, useModifierKeyLabel } from "~/utils/hooks";
|
import { useExperimentAccess, useHandledAsyncCallback, useModifierKeyLabel } from "~/utils/hooks";
|
||||||
import { type PromptVariant } from "./types";
|
import { type PromptVariant } from "./types";
|
||||||
import { api } from "~/utils/api";
|
import { api } from "~/utils/api";
|
||||||
import { useAppStore } from "~/state/store";
|
import { useAppStore } from "~/state/store";
|
||||||
|
import { FiMaximize, FiMinimize } from "react-icons/fi";
|
||||||
|
import { editorBackground } from "~/state/sharedVariantEditor.slice";
|
||||||
|
|
||||||
export default function VariantEditor(props: { variant: PromptVariant }) {
|
export default function VariantEditor(props: { variant: PromptVariant }) {
|
||||||
const { canModify } = useExperimentAccess();
|
const { canModify } = useExperimentAccess();
|
||||||
const monaco = useAppStore.use.sharedVariantEditor.monaco();
|
const monaco = useAppStore.use.sharedVariantEditor.monaco();
|
||||||
const editorRef = useRef<ReturnType<NonNullable<typeof monaco>["editor"]["create"]> | null>(null);
|
const editorRef = useRef<ReturnType<NonNullable<typeof monaco>["editor"]["create"]> | null>(null);
|
||||||
|
const containerRef = useRef<HTMLDivElement | null>(null);
|
||||||
const [editorId] = useState(() => `editor_${Math.random().toString(36).substring(7)}`);
|
const [editorId] = useState(() => `editor_${Math.random().toString(36).substring(7)}`);
|
||||||
const [isChanged, setIsChanged] = useState(false);
|
const [isChanged, setIsChanged] = useState(false);
|
||||||
|
|
||||||
|
const [isFullscreen, setIsFullscreen] = useState(false);
|
||||||
|
|
||||||
|
const toggleFullscreen = useCallback(() => {
|
||||||
|
setIsFullscreen((prev) => !prev);
|
||||||
|
editorRef.current?.focus();
|
||||||
|
}, [setIsFullscreen]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
const handleEsc = (event: KeyboardEvent) => {
|
||||||
|
if (event.key === "Escape" && isFullscreen) {
|
||||||
|
toggleFullscreen();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
window.addEventListener("keydown", handleEsc);
|
||||||
|
return () => window.removeEventListener("keydown", handleEsc);
|
||||||
|
}, [isFullscreen, toggleFullscreen]);
|
||||||
|
|
||||||
const lastSavedFn = props.variant.constructFn;
|
const lastSavedFn = props.variant.constructFn;
|
||||||
|
|
||||||
const modifierKey = useModifierKeyLabel();
|
const modifierKey = useModifierKeyLabel();
|
||||||
@@ -99,11 +129,23 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
|
|||||||
readOnly: !canModify,
|
readOnly: !canModify,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Workaround because otherwise the commands only work on whatever
|
||||||
|
// editor was loaded on the page last.
|
||||||
|
// https://github.com/microsoft/monaco-editor/issues/2947#issuecomment-1422265201
|
||||||
editorRef.current.onDidFocusEditorText(() => {
|
editorRef.current.onDidFocusEditorText(() => {
|
||||||
// Workaround because otherwise the command only works on whatever
|
editorRef.current?.addCommand(monaco.KeyMod.CtrlCmd | monaco.KeyCode.KeyS, onSave);
|
||||||
// editor was loaded on the page last.
|
|
||||||
// https://github.com/microsoft/monaco-editor/issues/2947#issuecomment-1422265201
|
editorRef.current?.addCommand(
|
||||||
editorRef.current?.addCommand(monaco.KeyMod.CtrlCmd | monaco.KeyCode.Enter, onSave);
|
monaco.KeyMod.CtrlCmd | monaco.KeyMod.Shift | monaco.KeyCode.KeyF,
|
||||||
|
toggleFullscreen,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Exit fullscreen with escape
|
||||||
|
editorRef.current?.addCommand(monaco.KeyCode.Escape, () => {
|
||||||
|
if (isFullscreen) {
|
||||||
|
toggleFullscreen();
|
||||||
|
}
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
editorRef.current.onDidChangeModelContent(checkForChanges);
|
editorRef.current.onDidChangeModelContent(checkForChanges);
|
||||||
@@ -132,8 +174,40 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
|
|||||||
}, [canModify]);
|
}, [canModify]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Box w="100%" pos="relative">
|
<Box
|
||||||
<div id={editorId} style={{ height: "400px", width: "100%" }}></div>
|
w="100%"
|
||||||
|
ref={containerRef}
|
||||||
|
sx={
|
||||||
|
isFullscreen
|
||||||
|
? {
|
||||||
|
position: "fixed",
|
||||||
|
top: 0,
|
||||||
|
left: 0,
|
||||||
|
right: 0,
|
||||||
|
bottom: 0,
|
||||||
|
}
|
||||||
|
: { h: "400px", w: "100%" }
|
||||||
|
}
|
||||||
|
bgColor={editorBackground}
|
||||||
|
zIndex={isFullscreen ? 1000 : "unset"}
|
||||||
|
pos="relative"
|
||||||
|
_hover={{ ".fullscreen-toggle": { opacity: 1 } }}
|
||||||
|
>
|
||||||
|
<Box id={editorId} w="100%" h="100%" />
|
||||||
|
<Tooltip label={`${modifierKey} + ⇧ + F`}>
|
||||||
|
<IconButton
|
||||||
|
className="fullscreen-toggle"
|
||||||
|
aria-label="Minimize"
|
||||||
|
icon={isFullscreen ? <FiMinimize /> : <FiMaximize />}
|
||||||
|
position="absolute"
|
||||||
|
top={2}
|
||||||
|
right={2}
|
||||||
|
onClick={toggleFullscreen}
|
||||||
|
opacity={0}
|
||||||
|
transition="opacity 0.2s"
|
||||||
|
/>
|
||||||
|
</Tooltip>
|
||||||
|
|
||||||
{isChanged && (
|
{isChanged && (
|
||||||
<HStack pos="absolute" bottom={2} right={2}>
|
<HStack pos="absolute" bottom={2} right={2}>
|
||||||
<Button
|
<Button
|
||||||
@@ -146,7 +220,7 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
|
|||||||
>
|
>
|
||||||
Reset
|
Reset
|
||||||
</Button>
|
</Button>
|
||||||
<Tooltip label={`${modifierKey} + Enter`}>
|
<Tooltip label={`${modifierKey} + S`}>
|
||||||
<Button size="sm" onClick={onSave} colorScheme="blue" w={16} disabled={saveInProgress}>
|
<Button size="sm" onClick={onSave} colorScheme="blue" w={16} disabled={saveInProgress}>
|
||||||
{saveInProgress ? <Spinner boxSize={4} /> : <Text>Save</Text>}
|
{saveInProgress ? <Spinner boxSize={4} /> : <Text>Save</Text>}
|
||||||
</Button>
|
</Button>
|
||||||
|
|||||||
@@ -1,13 +1,14 @@
|
|||||||
import { Grid, GridItem } from "@chakra-ui/react";
|
import { Grid, GridItem, type GridItemProps } from "@chakra-ui/react";
|
||||||
import { api } from "~/utils/api";
|
import { api } from "~/utils/api";
|
||||||
import NewScenarioButton from "./NewScenarioButton";
|
import AddVariantButton from "./AddVariantButton";
|
||||||
import NewVariantButton from "./NewVariantButton";
|
|
||||||
import ScenarioRow from "./ScenarioRow";
|
import ScenarioRow from "./ScenarioRow";
|
||||||
import VariantEditor from "./VariantEditor";
|
import VariantEditor from "./VariantEditor";
|
||||||
import VariantHeader from "../VariantHeader/VariantHeader";
|
import VariantHeader from "../VariantHeader/VariantHeader";
|
||||||
import VariantStats from "./VariantStats";
|
import VariantStats from "./VariantStats";
|
||||||
import { ScenariosHeader } from "./ScenariosHeader";
|
import { ScenariosHeader } from "./ScenariosHeader";
|
||||||
import { stickyHeaderStyle } from "./styles";
|
import { borders } from "./styles";
|
||||||
|
import { useScenarios } from "~/utils/hooks";
|
||||||
|
import ScenarioPaginator from "./ScenarioPaginator";
|
||||||
|
|
||||||
export default function OutputsTable({ experimentId }: { experimentId: string | undefined }) {
|
export default function OutputsTable({ experimentId }: { experimentId: string | undefined }) {
|
||||||
const variants = api.promptVariants.list.useQuery(
|
const variants = api.promptVariants.list.useQuery(
|
||||||
@@ -15,68 +16,91 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
|
|||||||
{ enabled: !!experimentId },
|
{ enabled: !!experimentId },
|
||||||
);
|
);
|
||||||
|
|
||||||
const scenarios = api.scenarios.list.useQuery(
|
const scenarios = useScenarios();
|
||||||
{ experimentId: experimentId as string },
|
|
||||||
{ enabled: !!experimentId },
|
|
||||||
);
|
|
||||||
|
|
||||||
if (!variants.data || !scenarios.data) return null;
|
if (!variants.data || !scenarios.data) return null;
|
||||||
|
|
||||||
const allCols = variants.data.length + 1;
|
const allCols = variants.data.length + 2;
|
||||||
const headerRows = 3;
|
const variantHeaderRows = 3;
|
||||||
|
const scenarioHeaderRows = 1;
|
||||||
|
const scenarioFooterRows = 1;
|
||||||
|
const visibleScenariosCount = scenarios.data.scenarios.length;
|
||||||
|
const allRows =
|
||||||
|
variantHeaderRows + scenarioHeaderRows + visibleScenariosCount + scenarioFooterRows;
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Grid
|
<Grid
|
||||||
p={4}
|
pt={4}
|
||||||
pb={24}
|
pb={24}
|
||||||
|
pl={4}
|
||||||
display="grid"
|
display="grid"
|
||||||
gridTemplateColumns={`250px repeat(${variants.data.length}, minmax(300px, 1fr)) auto`}
|
gridTemplateColumns={`250px repeat(${variants.data.length}, minmax(300px, 1fr)) auto`}
|
||||||
sx={{
|
sx={{
|
||||||
"> *": {
|
"> *": {
|
||||||
borderColor: "gray.300",
|
borderColor: "gray.300",
|
||||||
borderBottomWidth: 1,
|
|
||||||
borderRightWidth: 1,
|
|
||||||
},
|
},
|
||||||
}}
|
}}
|
||||||
fontSize="sm"
|
fontSize="sm"
|
||||||
>
|
>
|
||||||
<ScenariosHeader headerRows={headerRows} numScenarios={scenarios.data.length} />
|
<GridItem rowSpan={variantHeaderRows}>
|
||||||
|
<AddVariantButton />
|
||||||
{variants.data.map((variant) => (
|
|
||||||
<VariantHeader key={variant.uiId} variant={variant} canHide={variants.data.length > 1} />
|
|
||||||
))}
|
|
||||||
<GridItem
|
|
||||||
rowSpan={scenarios.data.length + headerRows}
|
|
||||||
padding={0}
|
|
||||||
// Have to use `style` instead of emotion style props to work around css specificity issues conflicting with the "> *" selector on Grid
|
|
||||||
style={{ borderRightWidth: 0, borderBottomWidth: 0 }}
|
|
||||||
h={8}
|
|
||||||
sx={stickyHeaderStyle}
|
|
||||||
>
|
|
||||||
<NewVariantButton />
|
|
||||||
</GridItem>
|
</GridItem>
|
||||||
|
|
||||||
{variants.data.map((variant) => (
|
{variants.data.map((variant, i) => {
|
||||||
<GridItem key={variant.uiId}>
|
const sharedProps: GridItemProps = {
|
||||||
<VariantEditor variant={variant} />
|
...borders,
|
||||||
</GridItem>
|
colStart: i + 2,
|
||||||
))}
|
borderLeftWidth: i === 0 ? 1 : 0,
|
||||||
{variants.data.map((variant) => (
|
marginLeft: i === 0 ? "-1px" : 0,
|
||||||
<GridItem key={variant.uiId}>
|
};
|
||||||
<VariantStats variant={variant} />
|
return (
|
||||||
</GridItem>
|
<>
|
||||||
))}
|
<VariantHeader
|
||||||
{scenarios.data.map((scenario) => (
|
key={variant.uiId}
|
||||||
|
variant={variant}
|
||||||
|
canHide={variants.data.length > 1}
|
||||||
|
rowStart={1}
|
||||||
|
{...sharedProps}
|
||||||
|
/>
|
||||||
|
<GridItem rowStart={2} {...sharedProps}>
|
||||||
|
<VariantEditor variant={variant} />
|
||||||
|
</GridItem>
|
||||||
|
<GridItem rowStart={3} {...sharedProps}>
|
||||||
|
<VariantStats variant={variant} />
|
||||||
|
</GridItem>
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
})}
|
||||||
|
|
||||||
|
<GridItem
|
||||||
|
colSpan={allCols - 1}
|
||||||
|
rowStart={variantHeaderRows + 1}
|
||||||
|
colStart={1}
|
||||||
|
{...borders}
|
||||||
|
borderRightWidth={0}
|
||||||
|
>
|
||||||
|
<ScenariosHeader />
|
||||||
|
</GridItem>
|
||||||
|
|
||||||
|
{scenarios.data.scenarios.map((scenario, i) => (
|
||||||
<ScenarioRow
|
<ScenarioRow
|
||||||
|
rowStart={i + variantHeaderRows + scenarioHeaderRows + 2}
|
||||||
key={scenario.uiId}
|
key={scenario.uiId}
|
||||||
scenario={scenario}
|
scenario={scenario}
|
||||||
variants={variants.data}
|
variants={variants.data}
|
||||||
canHide={scenarios.data.length > 1}
|
canHide={visibleScenariosCount > 1}
|
||||||
/>
|
/>
|
||||||
))}
|
))}
|
||||||
<GridItem borderBottomWidth={0} borderRightWidth={0} w="100%" colSpan={allCols} padding={0}>
|
<GridItem
|
||||||
<NewScenarioButton />
|
rowStart={variantHeaderRows + scenarioHeaderRows + visibleScenariosCount + 2}
|
||||||
|
colStart={1}
|
||||||
|
colSpan={allCols}
|
||||||
|
>
|
||||||
|
<ScenarioPaginator />
|
||||||
</GridItem>
|
</GridItem>
|
||||||
|
|
||||||
|
{/* Add some extra padding on the right, because when the table is too wide to fit in the viewport `pr` on the Grid isn't respected. */}
|
||||||
|
<GridItem rowStart={1} colStart={allCols} rowSpan={allRows} w={4} borderBottomWidth={0} />
|
||||||
</Grid>
|
</Grid>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,8 +1,13 @@
|
|||||||
import { type SystemStyleObject } from "@chakra-ui/react";
|
import { type GridItemProps, type SystemStyleObject } from "@chakra-ui/react";
|
||||||
|
|
||||||
export const stickyHeaderStyle: SystemStyleObject = {
|
export const stickyHeaderStyle: SystemStyleObject = {
|
||||||
position: "sticky",
|
position: "sticky",
|
||||||
top: "0",
|
top: "0",
|
||||||
backgroundColor: "#fff",
|
backgroundColor: "#fff",
|
||||||
zIndex: 1,
|
zIndex: 10,
|
||||||
|
};
|
||||||
|
|
||||||
|
export const borders: GridItemProps = {
|
||||||
|
borderRightWidth: 1,
|
||||||
|
borderBottomWidth: 1,
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -2,4 +2,4 @@ import { type RouterOutputs } from "~/utils/api";
|
|||||||
|
|
||||||
export type PromptVariant = NonNullable<RouterOutputs["promptVariants"]["list"]>[0];
|
export type PromptVariant = NonNullable<RouterOutputs["promptVariants"]["list"]>[0];
|
||||||
|
|
||||||
export type Scenario = NonNullable<RouterOutputs["scenarios"]["list"]>[0];
|
export type Scenario = NonNullable<RouterOutputs["scenarios"]["list"]>["scenarios"][0];
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
import { HStack, VStack, useBreakpointValue } from "@chakra-ui/react";
|
import { type StackProps, VStack, useBreakpointValue } from "@chakra-ui/react";
|
||||||
import React from "react";
|
import React from "react";
|
||||||
import DiffViewer, { DiffMethod } from "react-diff-viewer";
|
import DiffViewer, { DiffMethod } from "react-diff-viewer";
|
||||||
import Prism from "prismjs";
|
import Prism from "prismjs";
|
||||||
@@ -19,10 +19,15 @@ const highlightSyntax = (str: string) => {
|
|||||||
const CompareFunctions = ({
|
const CompareFunctions = ({
|
||||||
originalFunction,
|
originalFunction,
|
||||||
newFunction = "",
|
newFunction = "",
|
||||||
|
leftTitle = "Original",
|
||||||
|
rightTitle = "Modified",
|
||||||
|
...props
|
||||||
}: {
|
}: {
|
||||||
originalFunction: string;
|
originalFunction: string;
|
||||||
newFunction?: string;
|
newFunction?: string;
|
||||||
}) => {
|
leftTitle?: string;
|
||||||
|
rightTitle?: string;
|
||||||
|
} & StackProps) => {
|
||||||
const showSplitView = useBreakpointValue(
|
const showSplitView = useBreakpointValue(
|
||||||
{
|
{
|
||||||
base: false,
|
base: false,
|
||||||
@@ -34,22 +39,20 @@ const CompareFunctions = ({
|
|||||||
);
|
);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<HStack w="full" spacing={5}>
|
<VStack w="full" spacing={4} fontSize={12} lineHeight={1} overflowY="auto" {...props}>
|
||||||
<VStack w="full" spacing={4} maxH="40vh" fontSize={12} lineHeight={1} overflowY="auto">
|
<DiffViewer
|
||||||
<DiffViewer
|
oldValue={originalFunction}
|
||||||
oldValue={originalFunction}
|
newValue={newFunction || originalFunction}
|
||||||
newValue={newFunction || originalFunction}
|
splitView={showSplitView}
|
||||||
splitView={showSplitView}
|
hideLineNumbers={!showSplitView}
|
||||||
hideLineNumbers={!showSplitView}
|
leftTitle={leftTitle}
|
||||||
leftTitle="Original"
|
rightTitle={rightTitle}
|
||||||
rightTitle={newFunction ? "Modified" : "Unmodified"}
|
disableWordDiff={true}
|
||||||
disableWordDiff={true}
|
compareMethod={DiffMethod.CHARS}
|
||||||
compareMethod={DiffMethod.CHARS}
|
renderContent={highlightSyntax}
|
||||||
renderContent={highlightSyntax}
|
showDiffOnly={false}
|
||||||
showDiffOnly={false}
|
/>
|
||||||
/>
|
</VStack>
|
||||||
</VStack>
|
|
||||||
</HStack>
|
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -56,7 +56,6 @@ export const CustomInstructionsInput = ({
|
|||||||
minW="unset"
|
minW="unset"
|
||||||
size="sm"
|
size="sm"
|
||||||
onClick={() => onSubmit()}
|
onClick={() => onSubmit()}
|
||||||
disabled={!instructions}
|
|
||||||
variant={instructions ? "solid" : "ghost"}
|
variant={instructions ? "solid" : "ghost"}
|
||||||
mr={4}
|
mr={4}
|
||||||
borderRadius="8"
|
borderRadius="8"
|
||||||
|
|||||||
@@ -1,22 +1,22 @@
|
|||||||
import { HStack, Icon, Heading, Text, VStack, GridItem } from "@chakra-ui/react";
|
import { HStack, Icon, Heading, Text, VStack, GridItem } from "@chakra-ui/react";
|
||||||
import { type IconType } from "react-icons";
|
import { type IconType } from "react-icons";
|
||||||
import { refineOptions, type RefineOptionLabel } from "./refineOptions";
|
|
||||||
|
|
||||||
export const RefineOption = ({
|
export const RefineOption = ({
|
||||||
label,
|
label,
|
||||||
activeLabel,
|
|
||||||
icon,
|
icon,
|
||||||
|
desciption,
|
||||||
|
activeLabel,
|
||||||
onClick,
|
onClick,
|
||||||
loading,
|
loading,
|
||||||
}: {
|
}: {
|
||||||
label: RefineOptionLabel;
|
label: string;
|
||||||
activeLabel: RefineOptionLabel | undefined;
|
|
||||||
icon: IconType;
|
icon: IconType;
|
||||||
onClick: (label: RefineOptionLabel) => void;
|
desciption: string;
|
||||||
|
activeLabel: string | undefined;
|
||||||
|
onClick: (label: string) => void;
|
||||||
loading: boolean;
|
loading: boolean;
|
||||||
}) => {
|
}) => {
|
||||||
const isActive = activeLabel === label;
|
const isActive = activeLabel === label;
|
||||||
const desciption = refineOptions[label].description;
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<GridItem w="80" h="44">
|
<GridItem w="80" h="44">
|
||||||
|
|||||||
@@ -15,17 +15,16 @@ import {
|
|||||||
SimpleGrid,
|
SimpleGrid,
|
||||||
} from "@chakra-ui/react";
|
} from "@chakra-ui/react";
|
||||||
import { BsStars } from "react-icons/bs";
|
import { BsStars } from "react-icons/bs";
|
||||||
import { VscJson } from "react-icons/vsc";
|
|
||||||
import { TfiThought } from "react-icons/tfi";
|
|
||||||
import { api } from "~/utils/api";
|
import { api } from "~/utils/api";
|
||||||
import { useHandledAsyncCallback } from "~/utils/hooks";
|
import { useHandledAsyncCallback } from "~/utils/hooks";
|
||||||
import { type PromptVariant } from "@prisma/client";
|
import { type PromptVariant } from "@prisma/client";
|
||||||
import { useState } from "react";
|
import { useState } from "react";
|
||||||
import CompareFunctions from "./CompareFunctions";
|
import CompareFunctions from "./CompareFunctions";
|
||||||
import { CustomInstructionsInput } from "./CustomInstructionsInput";
|
import { CustomInstructionsInput } from "./CustomInstructionsInput";
|
||||||
import { type RefineOptionLabel, refineOptions } from "./refineOptions";
|
import { type RefineOptionInfo, refineOptions } from "./refineOptions";
|
||||||
import { RefineOption } from "./RefineOption";
|
import { RefineOption } from "./RefineOption";
|
||||||
import { isObject, isString } from "lodash-es";
|
import { isObject, isString } from "lodash-es";
|
||||||
|
import { type SupportedProvider } from "~/modelProviders/types";
|
||||||
|
|
||||||
export const RefinePromptModal = ({
|
export const RefinePromptModal = ({
|
||||||
variant,
|
variant,
|
||||||
@@ -36,25 +35,29 @@ export const RefinePromptModal = ({
|
|||||||
}) => {
|
}) => {
|
||||||
const utils = api.useContext();
|
const utils = api.useContext();
|
||||||
|
|
||||||
const { mutateAsync: getRefinedPromptMutateAsync, data: refinedPromptFn } =
|
const providerRefineOptions = refineOptions[variant.modelProvider as SupportedProvider];
|
||||||
api.promptVariants.getRefinedPromptFn.useMutation();
|
|
||||||
|
const { mutateAsync: getModifiedPromptMutateAsync, data: refinedPromptFn } =
|
||||||
|
api.promptVariants.getModifiedPromptFn.useMutation();
|
||||||
const [instructions, setInstructions] = useState<string>("");
|
const [instructions, setInstructions] = useState<string>("");
|
||||||
|
|
||||||
const [activeRefineOptionLabel, setActiveRefineOptionLabel] = useState<
|
const [activeRefineOptionLabel, setActiveRefineOptionLabel] = useState<string | undefined>(
|
||||||
RefineOptionLabel | undefined
|
undefined,
|
||||||
>(undefined);
|
);
|
||||||
|
|
||||||
const [getRefinedPromptFn, refiningInProgress] = useHandledAsyncCallback(
|
const [getModifiedPromptFn, modificationInProgress] = useHandledAsyncCallback(
|
||||||
async (label?: RefineOptionLabel) => {
|
async (label?: string) => {
|
||||||
if (!variant.experimentId) return;
|
if (!variant.experimentId) return;
|
||||||
const updatedInstructions = label ? refineOptions[label].instructions : instructions;
|
const updatedInstructions = label
|
||||||
|
? (providerRefineOptions[label] as RefineOptionInfo).instructions
|
||||||
|
: instructions;
|
||||||
setActiveRefineOptionLabel(label);
|
setActiveRefineOptionLabel(label);
|
||||||
await getRefinedPromptMutateAsync({
|
await getModifiedPromptMutateAsync({
|
||||||
id: variant.id,
|
id: variant.id,
|
||||||
instructions: updatedInstructions,
|
instructions: updatedInstructions,
|
||||||
});
|
});
|
||||||
},
|
},
|
||||||
[getRefinedPromptMutateAsync, onClose, variant, instructions, setActiveRefineOptionLabel],
|
[getModifiedPromptMutateAsync, onClose, variant, instructions, setActiveRefineOptionLabel],
|
||||||
);
|
);
|
||||||
|
|
||||||
const replaceVariantMutation = api.promptVariants.replaceVariant.useMutation();
|
const replaceVariantMutation = api.promptVariants.replaceVariant.useMutation();
|
||||||
@@ -75,7 +78,11 @@ export const RefinePromptModal = ({
|
|||||||
}, [replaceVariantMutation, variant, onClose, refinedPromptFn]);
|
}, [replaceVariantMutation, variant, onClose, refinedPromptFn]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Modal isOpen onClose={onClose} size={{ base: "xl", sm: "2xl", md: "7xl" }}>
|
<Modal
|
||||||
|
isOpen
|
||||||
|
onClose={onClose}
|
||||||
|
size={{ base: "xl", sm: "2xl", md: "3xl", lg: "5xl", xl: "7xl" }}
|
||||||
|
>
|
||||||
<ModalOverlay />
|
<ModalOverlay />
|
||||||
<ModalContent w={1200}>
|
<ModalContent w={1200}>
|
||||||
<ModalHeader>
|
<ModalHeader>
|
||||||
@@ -88,35 +95,37 @@ export const RefinePromptModal = ({
|
|||||||
<ModalBody maxW="unset">
|
<ModalBody maxW="unset">
|
||||||
<VStack spacing={8}>
|
<VStack spacing={8}>
|
||||||
<VStack spacing={4}>
|
<VStack spacing={4}>
|
||||||
<SimpleGrid columns={{ base: 1, md: 2 }} spacing={8}>
|
{Object.keys(providerRefineOptions).length && (
|
||||||
<RefineOption
|
<>
|
||||||
label="Convert to function call"
|
<SimpleGrid columns={{ base: 1, md: 2 }} spacing={8}>
|
||||||
activeLabel={activeRefineOptionLabel}
|
{Object.keys(providerRefineOptions).map((label) => (
|
||||||
icon={VscJson}
|
<RefineOption
|
||||||
onClick={getRefinedPromptFn}
|
key={label}
|
||||||
loading={refiningInProgress}
|
label={label}
|
||||||
/>
|
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
||||||
<RefineOption
|
icon={providerRefineOptions[label]!.icon}
|
||||||
label="Add chain of thought"
|
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
||||||
activeLabel={activeRefineOptionLabel}
|
desciption={providerRefineOptions[label]!.description}
|
||||||
icon={TfiThought}
|
activeLabel={activeRefineOptionLabel}
|
||||||
onClick={getRefinedPromptFn}
|
onClick={getModifiedPromptFn}
|
||||||
loading={refiningInProgress}
|
loading={modificationInProgress}
|
||||||
/>
|
/>
|
||||||
</SimpleGrid>
|
))}
|
||||||
<HStack>
|
</SimpleGrid>
|
||||||
<Text color="gray.500">or</Text>
|
<Text color="gray.500">or</Text>
|
||||||
</HStack>
|
</>
|
||||||
|
)}
|
||||||
<CustomInstructionsInput
|
<CustomInstructionsInput
|
||||||
instructions={instructions}
|
instructions={instructions}
|
||||||
setInstructions={setInstructions}
|
setInstructions={setInstructions}
|
||||||
loading={refiningInProgress}
|
loading={modificationInProgress}
|
||||||
onSubmit={getRefinedPromptFn}
|
onSubmit={getModifiedPromptFn}
|
||||||
/>
|
/>
|
||||||
</VStack>
|
</VStack>
|
||||||
<CompareFunctions
|
<CompareFunctions
|
||||||
originalFunction={variant.constructFn}
|
originalFunction={variant.constructFn}
|
||||||
newFunction={isString(refinedPromptFn) ? refinedPromptFn : undefined}
|
newFunction={isString(refinedPromptFn) ? refinedPromptFn : undefined}
|
||||||
|
maxH="40vh"
|
||||||
/>
|
/>
|
||||||
</VStack>
|
</VStack>
|
||||||
</ModalBody>
|
</ModalBody>
|
||||||
@@ -124,12 +133,10 @@ export const RefinePromptModal = ({
|
|||||||
<ModalFooter>
|
<ModalFooter>
|
||||||
<HStack spacing={4}>
|
<HStack spacing={4}>
|
||||||
<Button
|
<Button
|
||||||
|
colorScheme="blue"
|
||||||
onClick={replaceVariant}
|
onClick={replaceVariant}
|
||||||
minW={24}
|
minW={24}
|
||||||
disabled={replacementInProgress || !refinedPromptFn}
|
isDisabled={replacementInProgress || !refinedPromptFn}
|
||||||
_disabled={{
|
|
||||||
bgColor: "blue.500",
|
|
||||||
}}
|
|
||||||
>
|
>
|
||||||
{replacementInProgress ? <Spinner boxSize={4} /> : <Text>Accept</Text>}
|
{replacementInProgress ? <Spinner boxSize={4} /> : <Text>Accept</Text>}
|
||||||
</Button>
|
</Button>
|
||||||
|
|||||||
@@ -1,18 +1,22 @@
|
|||||||
// Super hacky, but we'll redo the organization when we have more models
|
// Super hacky, but we'll redo the organization when we have more models
|
||||||
|
|
||||||
export type RefineOptionLabel = "Add chain of thought" | "Convert to function call";
|
import { type SupportedProvider } from "~/modelProviders/types";
|
||||||
|
import { VscJson } from "react-icons/vsc";
|
||||||
|
import { TfiThought } from "react-icons/tfi";
|
||||||
|
import { type IconType } from "react-icons";
|
||||||
|
|
||||||
export const refineOptions: Record<
|
export type RefineOptionInfo = { icon: IconType; description: string; instructions: string };
|
||||||
RefineOptionLabel,
|
|
||||||
{ description: string; instructions: string }
|
export const refineOptions: Record<SupportedProvider, { [key: string]: RefineOptionInfo }> = {
|
||||||
> = {
|
"openai/ChatCompletion": {
|
||||||
"Add chain of thought": {
|
"Add chain of thought": {
|
||||||
description: "Asking the model to plan its answer can increase accuracy.",
|
icon: VscJson,
|
||||||
instructions: `Adding chain of thought means asking the model to think about its answer before it gives it to you. This is useful for getting more accurate answers. Do not add an assistant message.
|
description: "Asking the model to plan its answer can increase accuracy.",
|
||||||
|
instructions: `Adding chain of thought means asking the model to think about its answer before it gives it to you. This is useful for getting more accurate answers. Do not add an assistant message.
|
||||||
|
|
||||||
This is what a prompt looks like before adding chain of thought:
|
This is what a prompt looks like before adding chain of thought:
|
||||||
|
|
||||||
prompt = {
|
definePrompt("openai/ChatCompletion", {
|
||||||
model: "gpt-4",
|
model: "gpt-4",
|
||||||
stream: true,
|
stream: true,
|
||||||
messages: [
|
messages: [
|
||||||
@@ -25,11 +29,11 @@ export const refineOptions: Record<
|
|||||||
content: \`This is the user's message: \${scenario.user_message}. Return "positive" or "negative" or "neutral"\`,
|
content: \`This is the user's message: \${scenario.user_message}. Return "positive" or "negative" or "neutral"\`,
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
};
|
});
|
||||||
|
|
||||||
This is what one looks like after adding chain of thought:
|
This is what one looks like after adding chain of thought:
|
||||||
|
|
||||||
prompt = {
|
definePrompt("openai/ChatCompletion", {
|
||||||
model: "gpt-4",
|
model: "gpt-4",
|
||||||
stream: true,
|
stream: true,
|
||||||
messages: [
|
messages: [
|
||||||
@@ -42,13 +46,13 @@ export const refineOptions: Record<
|
|||||||
content: \`This is the user's message: \${scenario.user_message}. Return "positive" or "negative" or "neutral". Explain your answer before you give a score, then return the score on a new line.\`,
|
content: \`This is the user's message: \${scenario.user_message}. Return "positive" or "negative" or "neutral". Explain your answer before you give a score, then return the score on a new line.\`,
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
};
|
});
|
||||||
|
|
||||||
Here's another example:
|
Here's another example:
|
||||||
|
|
||||||
Before:
|
Before:
|
||||||
|
|
||||||
prompt = {
|
definePrompt("openai/ChatCompletion", {
|
||||||
model: "gpt-3.5-turbo",
|
model: "gpt-3.5-turbo",
|
||||||
messages: [
|
messages: [
|
||||||
{
|
{
|
||||||
@@ -78,11 +82,11 @@ export const refineOptions: Record<
|
|||||||
function_call: {
|
function_call: {
|
||||||
name: "score_post",
|
name: "score_post",
|
||||||
},
|
},
|
||||||
};
|
});
|
||||||
|
|
||||||
After:
|
After:
|
||||||
|
|
||||||
prompt = {
|
definePrompt("openai/ChatCompletion", {
|
||||||
model: "gpt-3.5-turbo",
|
model: "gpt-3.5-turbo",
|
||||||
messages: [
|
messages: [
|
||||||
{
|
{
|
||||||
@@ -115,17 +119,18 @@ export const refineOptions: Record<
|
|||||||
function_call: {
|
function_call: {
|
||||||
name: "score_post",
|
name: "score_post",
|
||||||
},
|
},
|
||||||
};
|
});
|
||||||
|
|
||||||
Add chain of thought to the original prompt.`,
|
Add chain of thought to the original prompt.`,
|
||||||
},
|
},
|
||||||
"Convert to function call": {
|
"Convert to function call": {
|
||||||
description: "Use function calls to get output from the model in a more structured way.",
|
icon: TfiThought,
|
||||||
instructions: `OpenAI functions are a specialized way for an LLM to return output.
|
description: "Use function calls to get output from the model in a more structured way.",
|
||||||
|
instructions: `OpenAI functions are a specialized way for an LLM to return output.
|
||||||
|
|
||||||
This is what a prompt looks like before adding a function:
|
This is what a prompt looks like before adding a function:
|
||||||
|
|
||||||
prompt = {
|
definePrompt("openai/ChatCompletion", {
|
||||||
model: "gpt-4",
|
model: "gpt-4",
|
||||||
stream: true,
|
stream: true,
|
||||||
messages: [
|
messages: [
|
||||||
@@ -138,11 +143,11 @@ export const refineOptions: Record<
|
|||||||
content: \`This is the user's message: \${scenario.user_message}. Return "positive" or "negative" or "neutral"\`,
|
content: \`This is the user's message: \${scenario.user_message}. Return "positive" or "negative" or "neutral"\`,
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
};
|
});
|
||||||
|
|
||||||
This is what one looks like after adding a function:
|
This is what one looks like after adding a function:
|
||||||
|
|
||||||
prompt = {
|
definePrompt("openai/ChatCompletion", {
|
||||||
model: "gpt-4",
|
model: "gpt-4",
|
||||||
stream: true,
|
stream: true,
|
||||||
messages: [
|
messages: [
|
||||||
@@ -172,13 +177,13 @@ export const refineOptions: Record<
|
|||||||
function_call: {
|
function_call: {
|
||||||
name: "extract_sentiment",
|
name: "extract_sentiment",
|
||||||
},
|
},
|
||||||
};
|
});
|
||||||
|
|
||||||
Here's another example of adding a function:
|
Here's another example of adding a function:
|
||||||
|
|
||||||
Before:
|
Before:
|
||||||
|
|
||||||
prompt = {
|
definePrompt("openai/ChatCompletion", {
|
||||||
model: "gpt-3.5-turbo",
|
model: "gpt-3.5-turbo",
|
||||||
messages: [
|
messages: [
|
||||||
{
|
{
|
||||||
@@ -196,11 +201,11 @@ export const refineOptions: Record<
|
|||||||
},
|
},
|
||||||
],
|
],
|
||||||
temperature: 0,
|
temperature: 0,
|
||||||
};
|
});
|
||||||
|
|
||||||
After:
|
After:
|
||||||
|
|
||||||
prompt = {
|
definePrompt("openai/ChatCompletion", {
|
||||||
model: "gpt-3.5-turbo",
|
model: "gpt-3.5-turbo",
|
||||||
messages: [
|
messages: [
|
||||||
{
|
{
|
||||||
@@ -230,8 +235,53 @@ export const refineOptions: Record<
|
|||||||
function_call: {
|
function_call: {
|
||||||
name: "score_post",
|
name: "score_post",
|
||||||
},
|
},
|
||||||
};
|
});
|
||||||
|
|
||||||
|
Another example
|
||||||
|
|
||||||
|
Before:
|
||||||
|
|
||||||
|
definePrompt("openai/ChatCompletion", {
|
||||||
|
model: "gpt-3.5-turbo",
|
||||||
|
stream: true,
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
role: "system",
|
||||||
|
content: \`Write 'Start experimenting!' in \${scenario.language}\`,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
});
|
||||||
|
|
||||||
|
After:
|
||||||
|
|
||||||
|
definePrompt("openai/ChatCompletion", {
|
||||||
|
model: "gpt-3.5-turbo",
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
role: "system",
|
||||||
|
content: \`Write 'Start experimenting!' in \${scenario.language}\`,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
functions: [
|
||||||
|
{
|
||||||
|
name: "write_in_language",
|
||||||
|
parameters: {
|
||||||
|
type: "object",
|
||||||
|
properties: {
|
||||||
|
text: {
|
||||||
|
type: "string",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
function_call: {
|
||||||
|
name: "write_in_language",
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
Add an OpenAI function that takes one or more nested parameters that match the expected output from this prompt.`,
|
Add an OpenAI function that takes one or more nested parameters that match the expected output from this prompt.`,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
|
"replicate/llama2": {},
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -1,85 +0,0 @@
|
|||||||
import {
|
|
||||||
Button,
|
|
||||||
Modal,
|
|
||||||
ModalBody,
|
|
||||||
ModalCloseButton,
|
|
||||||
ModalContent,
|
|
||||||
ModalFooter,
|
|
||||||
ModalHeader,
|
|
||||||
ModalOverlay,
|
|
||||||
VStack,
|
|
||||||
Text,
|
|
||||||
Spinner,
|
|
||||||
HStack,
|
|
||||||
Icon,
|
|
||||||
} from "@chakra-ui/react";
|
|
||||||
import { RiExchangeFundsFill } from "react-icons/ri";
|
|
||||||
import { useState } from "react";
|
|
||||||
import { type SupportedModel } from "~/server/types";
|
|
||||||
import { ModelStatsCard } from "./ModelStatsCard";
|
|
||||||
import { SelectModelSearch } from "./SelectModelSearch";
|
|
||||||
import { api } from "~/utils/api";
|
|
||||||
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
|
|
||||||
|
|
||||||
export const SelectModelModal = ({
|
|
||||||
originalModel,
|
|
||||||
variantId,
|
|
||||||
onClose,
|
|
||||||
}: {
|
|
||||||
originalModel: SupportedModel;
|
|
||||||
variantId: string;
|
|
||||||
onClose: () => void;
|
|
||||||
}) => {
|
|
||||||
const [selectedModel, setSelectedModel] = useState<SupportedModel>(originalModel);
|
|
||||||
const utils = api.useContext();
|
|
||||||
|
|
||||||
const experiment = useExperiment();
|
|
||||||
|
|
||||||
const createMutation = api.promptVariants.create.useMutation();
|
|
||||||
|
|
||||||
const [createNewVariant, creationInProgress] = useHandledAsyncCallback(async () => {
|
|
||||||
if (!experiment?.data?.id) return;
|
|
||||||
await createMutation.mutateAsync({
|
|
||||||
experimentId: experiment?.data?.id,
|
|
||||||
variantId,
|
|
||||||
newModel: selectedModel,
|
|
||||||
});
|
|
||||||
await utils.promptVariants.list.invalidate();
|
|
||||||
onClose();
|
|
||||||
}, [createMutation, experiment?.data?.id, variantId, onClose]);
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Modal isOpen onClose={onClose} size={{ base: "xl", sm: "2xl", md: "3xl" }}>
|
|
||||||
<ModalOverlay />
|
|
||||||
<ModalContent w={1200}>
|
|
||||||
<ModalHeader>
|
|
||||||
<HStack>
|
|
||||||
<Icon as={RiExchangeFundsFill} />
|
|
||||||
<Text>Change Model</Text>
|
|
||||||
</HStack>
|
|
||||||
</ModalHeader>
|
|
||||||
<ModalCloseButton />
|
|
||||||
<ModalBody maxW="unset">
|
|
||||||
<VStack spacing={8}>
|
|
||||||
<ModelStatsCard label="Original Model" model={originalModel} />
|
|
||||||
{originalModel !== selectedModel && (
|
|
||||||
<ModelStatsCard label="New Model" model={selectedModel} />
|
|
||||||
)}
|
|
||||||
<SelectModelSearch selectedModel={selectedModel} setSelectedModel={setSelectedModel} />
|
|
||||||
</VStack>
|
|
||||||
</ModalBody>
|
|
||||||
|
|
||||||
<ModalFooter>
|
|
||||||
<Button
|
|
||||||
colorScheme="blue"
|
|
||||||
onClick={createNewVariant}
|
|
||||||
minW={24}
|
|
||||||
disabled={originalModel === selectedModel}
|
|
||||||
>
|
|
||||||
{creationInProgress ? <Spinner boxSize={4} /> : <Text>Continue</Text>}
|
|
||||||
</Button>
|
|
||||||
</ModalFooter>
|
|
||||||
</ModalContent>
|
|
||||||
</Modal>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
@@ -1,47 +0,0 @@
|
|||||||
import { VStack, Text } from "@chakra-ui/react";
|
|
||||||
import { type LegacyRef, useCallback } from "react";
|
|
||||||
import Select, { type SingleValue } from "react-select";
|
|
||||||
import { type SupportedModel } from "~/server/types";
|
|
||||||
import { useElementDimensions } from "~/utils/hooks";
|
|
||||||
|
|
||||||
const modelOptions: { value: SupportedModel; label: string }[] = [
|
|
||||||
{ value: "gpt-3.5-turbo", label: "gpt-3.5-turbo" },
|
|
||||||
{ value: "gpt-3.5-turbo-0613", label: "gpt-3.5-turbo-0613" },
|
|
||||||
{ value: "gpt-3.5-turbo-16k", label: "gpt-3.5-turbo-16k" },
|
|
||||||
{ value: "gpt-3.5-turbo-16k-0613", label: "gpt-3.5-turbo-16k-0613" },
|
|
||||||
{ value: "gpt-4", label: "gpt-4" },
|
|
||||||
{ value: "gpt-4-0613", label: "gpt-4-0613" },
|
|
||||||
{ value: "gpt-4-32k", label: "gpt-4-32k" },
|
|
||||||
{ value: "gpt-4-32k-0613", label: "gpt-4-32k-0613" },
|
|
||||||
];
|
|
||||||
|
|
||||||
export const SelectModelSearch = ({
|
|
||||||
selectedModel,
|
|
||||||
setSelectedModel,
|
|
||||||
}: {
|
|
||||||
selectedModel: SupportedModel;
|
|
||||||
setSelectedModel: (model: SupportedModel) => void;
|
|
||||||
}) => {
|
|
||||||
const handleSelection = useCallback(
|
|
||||||
(option: SingleValue<{ value: SupportedModel; label: string }>) => {
|
|
||||||
if (!option) return;
|
|
||||||
setSelectedModel(option.value);
|
|
||||||
},
|
|
||||||
[setSelectedModel],
|
|
||||||
);
|
|
||||||
const selectedOption = modelOptions.find((option) => option.value === selectedModel);
|
|
||||||
|
|
||||||
const [containerRef, containerDimensions] = useElementDimensions();
|
|
||||||
|
|
||||||
return (
|
|
||||||
<VStack ref={containerRef as LegacyRef<HTMLDivElement>} w="full">
|
|
||||||
<Text>Browse Models</Text>
|
|
||||||
<Select
|
|
||||||
styles={{ control: (provided) => ({ ...provided, width: containerDimensions?.width }) }}
|
|
||||||
value={selectedOption}
|
|
||||||
options={modelOptions}
|
|
||||||
onChange={handleSelection}
|
|
||||||
/>
|
|
||||||
</VStack>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
@@ -3,28 +3,34 @@ import { type PromptVariant } from "../OutputsTable/types";
|
|||||||
import { api } from "~/utils/api";
|
import { api } from "~/utils/api";
|
||||||
import { RiDraggable } from "react-icons/ri";
|
import { RiDraggable } from "react-icons/ri";
|
||||||
import { useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
|
import { useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
|
||||||
import { HStack, Icon, Text, GridItem } from "@chakra-ui/react"; // Changed here
|
import { HStack, Icon, Text, GridItem, type GridItemProps } from "@chakra-ui/react"; // Changed here
|
||||||
import { cellPadding, headerMinHeight } from "../constants";
|
import { cellPadding, headerMinHeight } from "../constants";
|
||||||
import AutoResizeTextArea from "../AutoResizeTextArea";
|
import AutoResizeTextArea from "../AutoResizeTextArea";
|
||||||
import { stickyHeaderStyle } from "../OutputsTable/styles";
|
import { stickyHeaderStyle } from "../OutputsTable/styles";
|
||||||
import VariantHeaderMenuButton from "./VariantHeaderMenuButton";
|
import VariantHeaderMenuButton from "./VariantHeaderMenuButton";
|
||||||
|
|
||||||
export default function VariantHeader(props: { variant: PromptVariant; canHide: boolean }) {
|
export default function VariantHeader(
|
||||||
|
allProps: {
|
||||||
|
variant: PromptVariant;
|
||||||
|
canHide: boolean;
|
||||||
|
} & GridItemProps,
|
||||||
|
) {
|
||||||
|
const { variant, canHide, ...gridItemProps } = allProps;
|
||||||
const { canModify } = useExperimentAccess();
|
const { canModify } = useExperimentAccess();
|
||||||
const utils = api.useContext();
|
const utils = api.useContext();
|
||||||
const [isDragTarget, setIsDragTarget] = useState(false);
|
const [isDragTarget, setIsDragTarget] = useState(false);
|
||||||
const [isInputHovered, setIsInputHovered] = useState(false);
|
const [isInputHovered, setIsInputHovered] = useState(false);
|
||||||
const [label, setLabel] = useState(props.variant.label);
|
const [label, setLabel] = useState(variant.label);
|
||||||
|
|
||||||
const updateMutation = api.promptVariants.update.useMutation();
|
const updateMutation = api.promptVariants.update.useMutation();
|
||||||
const [onSaveLabel] = useHandledAsyncCallback(async () => {
|
const [onSaveLabel] = useHandledAsyncCallback(async () => {
|
||||||
if (label && label !== props.variant.label) {
|
if (label && label !== variant.label) {
|
||||||
await updateMutation.mutateAsync({
|
await updateMutation.mutateAsync({
|
||||||
id: props.variant.id,
|
id: variant.id,
|
||||||
updates: { label: label },
|
updates: { label: label },
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}, [updateMutation, props.variant.id, props.variant.label, label]);
|
}, [updateMutation, variant.id, variant.label, label]);
|
||||||
|
|
||||||
const reorderMutation = api.promptVariants.reorder.useMutation();
|
const reorderMutation = api.promptVariants.reorder.useMutation();
|
||||||
const [onReorder] = useHandledAsyncCallback(
|
const [onReorder] = useHandledAsyncCallback(
|
||||||
@@ -32,7 +38,7 @@ export default function VariantHeader(props: { variant: PromptVariant; canHide:
|
|||||||
e.preventDefault();
|
e.preventDefault();
|
||||||
setIsDragTarget(false);
|
setIsDragTarget(false);
|
||||||
const draggedId = e.dataTransfer.getData("text/plain");
|
const draggedId = e.dataTransfer.getData("text/plain");
|
||||||
const droppedId = props.variant.id;
|
const droppedId = variant.id;
|
||||||
if (!draggedId || !droppedId || draggedId === droppedId) return;
|
if (!draggedId || !droppedId || draggedId === droppedId) return;
|
||||||
await reorderMutation.mutateAsync({
|
await reorderMutation.mutateAsync({
|
||||||
draggedId,
|
draggedId,
|
||||||
@@ -40,16 +46,16 @@ export default function VariantHeader(props: { variant: PromptVariant; canHide:
|
|||||||
});
|
});
|
||||||
await utils.promptVariants.list.invalidate();
|
await utils.promptVariants.list.invalidate();
|
||||||
},
|
},
|
||||||
[reorderMutation, props.variant.id],
|
[reorderMutation, variant.id],
|
||||||
);
|
);
|
||||||
|
|
||||||
const [menuOpen, setMenuOpen] = useState(false);
|
const [menuOpen, setMenuOpen] = useState(false);
|
||||||
|
|
||||||
if (!canModify) {
|
if (!canModify) {
|
||||||
return (
|
return (
|
||||||
<GridItem padding={0} sx={stickyHeaderStyle} borderTopWidth={1}>
|
<GridItem padding={0} sx={stickyHeaderStyle} borderTopWidth={1} {...gridItemProps}>
|
||||||
<Text fontSize={16} fontWeight="bold" px={cellPadding.x} py={cellPadding.y}>
|
<Text fontSize={16} fontWeight="bold" px={cellPadding.x} py={cellPadding.y}>
|
||||||
{props.variant.label}
|
{variant.label}
|
||||||
</Text>
|
</Text>
|
||||||
</GridItem>
|
</GridItem>
|
||||||
);
|
);
|
||||||
@@ -64,6 +70,7 @@ export default function VariantHeader(props: { variant: PromptVariant; canHide:
|
|||||||
zIndex: menuOpen ? "dropdown" : stickyHeaderStyle.zIndex,
|
zIndex: menuOpen ? "dropdown" : stickyHeaderStyle.zIndex,
|
||||||
}}
|
}}
|
||||||
borderTopWidth={1}
|
borderTopWidth={1}
|
||||||
|
{...gridItemProps}
|
||||||
>
|
>
|
||||||
<HStack
|
<HStack
|
||||||
spacing={4}
|
spacing={4}
|
||||||
@@ -71,7 +78,7 @@ export default function VariantHeader(props: { variant: PromptVariant; canHide:
|
|||||||
minH={headerMinHeight}
|
minH={headerMinHeight}
|
||||||
draggable={!isInputHovered}
|
draggable={!isInputHovered}
|
||||||
onDragStart={(e) => {
|
onDragStart={(e) => {
|
||||||
e.dataTransfer.setData("text/plain", props.variant.id);
|
e.dataTransfer.setData("text/plain", variant.id);
|
||||||
e.currentTarget.style.opacity = "0.4";
|
e.currentTarget.style.opacity = "0.4";
|
||||||
}}
|
}}
|
||||||
onDragEnd={(e) => {
|
onDragEnd={(e) => {
|
||||||
@@ -112,8 +119,8 @@ export default function VariantHeader(props: { variant: PromptVariant; canHide:
|
|||||||
onMouseLeave={() => setIsInputHovered(false)}
|
onMouseLeave={() => setIsInputHovered(false)}
|
||||||
/>
|
/>
|
||||||
<VariantHeaderMenuButton
|
<VariantHeaderMenuButton
|
||||||
variant={props.variant}
|
variant={variant}
|
||||||
canHide={props.canHide}
|
canHide={canHide}
|
||||||
menuOpen={menuOpen}
|
menuOpen={menuOpen}
|
||||||
setMenuOpen={setMenuOpen}
|
setMenuOpen={setMenuOpen}
|
||||||
/>
|
/>
|
||||||
|
|||||||
@@ -17,8 +17,7 @@ import { FaRegClone } from "react-icons/fa";
|
|||||||
import { useState } from "react";
|
import { useState } from "react";
|
||||||
import { RefinePromptModal } from "../RefinePromptModal/RefinePromptModal";
|
import { RefinePromptModal } from "../RefinePromptModal/RefinePromptModal";
|
||||||
import { RiExchangeFundsFill } from "react-icons/ri";
|
import { RiExchangeFundsFill } from "react-icons/ri";
|
||||||
import { SelectModelModal } from "../SelectModelModal/SelectModelModal";
|
import { ChangeModelModal } from "../ChangeModelModal/ChangeModelModal";
|
||||||
import { type SupportedModel } from "~/server/types";
|
|
||||||
|
|
||||||
export default function VariantHeaderMenuButton({
|
export default function VariantHeaderMenuButton({
|
||||||
variant,
|
variant,
|
||||||
@@ -51,7 +50,7 @@ export default function VariantHeaderMenuButton({
|
|||||||
await utils.promptVariants.list.invalidate();
|
await utils.promptVariants.list.invalidate();
|
||||||
}, [hideMutation, variant.id]);
|
}, [hideMutation, variant.id]);
|
||||||
|
|
||||||
const [selectModelModalOpen, setSelectModelModalOpen] = useState(false);
|
const [changeModelModalOpen, setChangeModelModalOpen] = useState(false);
|
||||||
const [refinePromptModalOpen, setRefinePromptModalOpen] = useState(false);
|
const [refinePromptModalOpen, setRefinePromptModalOpen] = useState(false);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
@@ -73,7 +72,7 @@ export default function VariantHeaderMenuButton({
|
|||||||
</MenuItem>
|
</MenuItem>
|
||||||
<MenuItem
|
<MenuItem
|
||||||
icon={<Icon as={RiExchangeFundsFill} boxSize={5} />}
|
icon={<Icon as={RiExchangeFundsFill} boxSize={5} />}
|
||||||
onClick={() => setSelectModelModalOpen(true)}
|
onClick={() => setChangeModelModalOpen(true)}
|
||||||
>
|
>
|
||||||
Change Model
|
Change Model
|
||||||
</MenuItem>
|
</MenuItem>
|
||||||
@@ -98,12 +97,8 @@ export default function VariantHeaderMenuButton({
|
|||||||
)}
|
)}
|
||||||
</MenuList>
|
</MenuList>
|
||||||
</Menu>
|
</Menu>
|
||||||
{selectModelModalOpen && (
|
{changeModelModalOpen && (
|
||||||
<SelectModelModal
|
<ChangeModelModal variant={variant} onClose={() => setChangeModelModalOpen(false)} />
|
||||||
originalModel={variant.model as SupportedModel}
|
|
||||||
variantId={variant.id}
|
|
||||||
onClose={() => setSelectModelModalOpen(false)}
|
|
||||||
/>
|
|
||||||
)}
|
)}
|
||||||
{refinePromptModalOpen && (
|
{refinePromptModalOpen && (
|
||||||
<RefinePromptModal variant={variant} onClose={() => setRefinePromptModalOpen(false)} />
|
<RefinePromptModal variant={variant} onClose={() => setRefinePromptModalOpen(false)} />
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ export const env = createEnv({
|
|||||||
server: {
|
server: {
|
||||||
DATABASE_URL: z.string().url(),
|
DATABASE_URL: z.string().url(),
|
||||||
NODE_ENV: z.enum(["development", "test", "production"]).default("development"),
|
NODE_ENV: z.enum(["development", "test", "production"]).default("development"),
|
||||||
OPENAI_API_KEY: z.string().min(1),
|
|
||||||
RESTRICT_PRISMA_LOGS: z
|
RESTRICT_PRISMA_LOGS: z
|
||||||
.string()
|
.string()
|
||||||
.optional()
|
.optional()
|
||||||
@@ -17,7 +16,8 @@ export const env = createEnv({
|
|||||||
.transform((val) => val.toLowerCase() === "true"),
|
.transform((val) => val.toLowerCase() === "true"),
|
||||||
GITHUB_CLIENT_ID: z.string().min(1),
|
GITHUB_CLIENT_ID: z.string().min(1),
|
||||||
GITHUB_CLIENT_SECRET: z.string().min(1),
|
GITHUB_CLIENT_SECRET: z.string().min(1),
|
||||||
REPLICATE_API_TOKEN: z.string().min(1),
|
OPENAI_API_KEY: z.string().min(1),
|
||||||
|
REPLICATE_API_TOKEN: z.string().default("placeholder"),
|
||||||
},
|
},
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@@ -1,14 +1,15 @@
|
|||||||
import openaiChatCompletionFrontend from "./openai-ChatCompletion/frontend";
|
import openaiChatCompletionFrontend from "./openai-ChatCompletion/frontend";
|
||||||
import replicateLlama2Frontend from "./replicate-llama2/frontend";
|
import replicateLlama2Frontend from "./replicate-llama2/frontend";
|
||||||
|
import { type SupportedProvider, type FrontendModelProvider } from "./types";
|
||||||
|
|
||||||
// TODO: make sure we get a typescript error if you forget to add a provider here
|
// TODO: make sure we get a typescript error if you forget to add a provider here
|
||||||
|
|
||||||
// Keep attributes here that need to be accessible from the frontend. We can't
|
// Keep attributes here that need to be accessible from the frontend. We can't
|
||||||
// just include them in the default `modelProviders` object because it has some
|
// just include them in the default `modelProviders` object because it has some
|
||||||
// transient dependencies that can only be imported on the server.
|
// transient dependencies that can only be imported on the server.
|
||||||
const modelProvidersFrontend = {
|
const frontendModelProviders: Record<SupportedProvider, FrontendModelProvider<any, any>> = {
|
||||||
"openai/ChatCompletion": openaiChatCompletionFrontend,
|
"openai/ChatCompletion": openaiChatCompletionFrontend,
|
||||||
"replicate/llama2": replicateLlama2Frontend,
|
"replicate/llama2": replicateLlama2Frontend,
|
||||||
} as const;
|
};
|
||||||
|
|
||||||
export default modelProvidersFrontend;
|
export default frontendModelProviders;
|
||||||
@@ -1,9 +1,10 @@
|
|||||||
import openaiChatCompletion from "./openai-ChatCompletion";
|
import openaiChatCompletion from "./openai-ChatCompletion";
|
||||||
import replicateLlama2 from "./replicate-llama2";
|
import replicateLlama2 from "./replicate-llama2";
|
||||||
|
import { type SupportedProvider, type ModelProvider } from "./types";
|
||||||
|
|
||||||
const modelProviders = {
|
const modelProviders: Record<SupportedProvider, ModelProvider<any, any, any>> = {
|
||||||
"openai/ChatCompletion": openaiChatCompletion,
|
"openai/ChatCompletion": openaiChatCompletion,
|
||||||
"replicate/llama2": replicateLlama2,
|
"replicate/llama2": replicateLlama2,
|
||||||
} as const;
|
};
|
||||||
|
|
||||||
export default modelProviders;
|
export default modelProviders;
|
||||||
|
|||||||
@@ -1,77 +0,0 @@
|
|||||||
import { type SupportedModel } from "../server/types";
|
|
||||||
|
|
||||||
interface ModelStats {
|
|
||||||
contextLength: number;
|
|
||||||
promptTokenPrice: number;
|
|
||||||
completionTokenPrice: number;
|
|
||||||
speed: "fast" | "medium" | "slow";
|
|
||||||
provider: "OpenAI";
|
|
||||||
learnMoreUrl: string;
|
|
||||||
}
|
|
||||||
|
|
||||||
export const modelStats: Record<SupportedModel, ModelStats> = {
|
|
||||||
"gpt-4": {
|
|
||||||
contextLength: 8192,
|
|
||||||
promptTokenPrice: 0.00003,
|
|
||||||
completionTokenPrice: 0.00006,
|
|
||||||
speed: "medium",
|
|
||||||
provider: "OpenAI",
|
|
||||||
learnMoreUrl: "https://openai.com/gpt-4",
|
|
||||||
},
|
|
||||||
"gpt-4-0613": {
|
|
||||||
contextLength: 8192,
|
|
||||||
promptTokenPrice: 0.00003,
|
|
||||||
completionTokenPrice: 0.00006,
|
|
||||||
speed: "medium",
|
|
||||||
provider: "OpenAI",
|
|
||||||
learnMoreUrl: "https://openai.com/gpt-4",
|
|
||||||
},
|
|
||||||
"gpt-4-32k": {
|
|
||||||
contextLength: 32768,
|
|
||||||
promptTokenPrice: 0.00006,
|
|
||||||
completionTokenPrice: 0.00012,
|
|
||||||
speed: "medium",
|
|
||||||
provider: "OpenAI",
|
|
||||||
learnMoreUrl: "https://openai.com/gpt-4",
|
|
||||||
},
|
|
||||||
"gpt-4-32k-0613": {
|
|
||||||
contextLength: 32768,
|
|
||||||
promptTokenPrice: 0.00006,
|
|
||||||
completionTokenPrice: 0.00012,
|
|
||||||
speed: "medium",
|
|
||||||
provider: "OpenAI",
|
|
||||||
learnMoreUrl: "https://openai.com/gpt-4",
|
|
||||||
},
|
|
||||||
"gpt-3.5-turbo": {
|
|
||||||
contextLength: 4096,
|
|
||||||
promptTokenPrice: 0.0000015,
|
|
||||||
completionTokenPrice: 0.000002,
|
|
||||||
speed: "fast",
|
|
||||||
provider: "OpenAI",
|
|
||||||
learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
|
|
||||||
},
|
|
||||||
"gpt-3.5-turbo-0613": {
|
|
||||||
contextLength: 4096,
|
|
||||||
promptTokenPrice: 0.0000015,
|
|
||||||
completionTokenPrice: 0.000002,
|
|
||||||
speed: "fast",
|
|
||||||
provider: "OpenAI",
|
|
||||||
learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
|
|
||||||
},
|
|
||||||
"gpt-3.5-turbo-16k": {
|
|
||||||
contextLength: 16384,
|
|
||||||
promptTokenPrice: 0.000003,
|
|
||||||
completionTokenPrice: 0.000004,
|
|
||||||
speed: "fast",
|
|
||||||
provider: "OpenAI",
|
|
||||||
learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
|
|
||||||
},
|
|
||||||
"gpt-3.5-turbo-16k-0613": {
|
|
||||||
contextLength: 16384,
|
|
||||||
promptTokenPrice: 0.000003,
|
|
||||||
completionTokenPrice: 0.000004,
|
|
||||||
speed: "fast",
|
|
||||||
provider: "OpenAI",
|
|
||||||
learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
|
|
||||||
},
|
|
||||||
};
|
|
||||||
@@ -56,6 +56,14 @@ modelProperty.type = "string";
|
|||||||
modelProperty.enum = modelProperty.oneOf[1].enum;
|
modelProperty.enum = modelProperty.oneOf[1].enum;
|
||||||
delete modelProperty["oneOf"];
|
delete modelProperty["oneOf"];
|
||||||
|
|
||||||
|
// The default of "inf" confuses the Typescript generator, so can just remove it
|
||||||
|
assert(
|
||||||
|
"max_tokens" in completionRequestSchema.properties &&
|
||||||
|
isObject(completionRequestSchema.properties.max_tokens) &&
|
||||||
|
"default" in completionRequestSchema.properties.max_tokens,
|
||||||
|
);
|
||||||
|
delete completionRequestSchema.properties.max_tokens["default"];
|
||||||
|
|
||||||
// Get the directory of the current script
|
// Get the directory of the current script
|
||||||
const currentDirectory = path.dirname(import.meta.url).replace("file://", "");
|
const currentDirectory = path.dirname(import.meta.url).replace("file://", "");
|
||||||
|
|
||||||
|
|||||||
@@ -150,7 +150,6 @@
|
|||||||
},
|
},
|
||||||
"max_tokens": {
|
"max_tokens": {
|
||||||
"description": "The maximum number of [tokens](/tokenizer) to generate in the chat completion.\n\nThe total length of input tokens and generated tokens is limited by the model's context length. [Example Python code](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb) for counting tokens.\n",
|
"description": "The maximum number of [tokens](/tokenizer) to generate in the chat completion.\n\nThe total length of input tokens and generated tokens is limited by the model's context length. [Example Python code](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb) for counting tokens.\n",
|
||||||
"default": "inf",
|
|
||||||
"type": "integer"
|
"type": "integer"
|
||||||
},
|
},
|
||||||
"presence_penalty": {
|
"presence_penalty": {
|
||||||
|
|||||||
@@ -1,8 +1,50 @@
|
|||||||
import { type JsonValue } from "type-fest";
|
import { type JsonValue } from "type-fest";
|
||||||
import { type OpenaiChatModelProvider } from ".";
|
import { type SupportedModel } from ".";
|
||||||
import { type ModelProviderFrontend } from "../types";
|
import { type FrontendModelProvider } from "../types";
|
||||||
|
import { type ChatCompletion } from "openai/resources/chat";
|
||||||
|
|
||||||
|
const frontendModelProvider: FrontendModelProvider<SupportedModel, ChatCompletion> = {
|
||||||
|
name: "OpenAI ChatCompletion",
|
||||||
|
|
||||||
|
models: {
|
||||||
|
"gpt-4-0613": {
|
||||||
|
name: "GPT-4",
|
||||||
|
contextWindow: 8192,
|
||||||
|
promptTokenPrice: 0.00003,
|
||||||
|
completionTokenPrice: 0.00006,
|
||||||
|
speed: "medium",
|
||||||
|
provider: "openai/ChatCompletion",
|
||||||
|
learnMoreUrl: "https://openai.com/gpt-4",
|
||||||
|
},
|
||||||
|
"gpt-4-32k-0613": {
|
||||||
|
name: "GPT-4 32k",
|
||||||
|
contextWindow: 32768,
|
||||||
|
promptTokenPrice: 0.00006,
|
||||||
|
completionTokenPrice: 0.00012,
|
||||||
|
speed: "medium",
|
||||||
|
provider: "openai/ChatCompletion",
|
||||||
|
learnMoreUrl: "https://openai.com/gpt-4",
|
||||||
|
},
|
||||||
|
"gpt-3.5-turbo-0613": {
|
||||||
|
name: "GPT-3.5 Turbo",
|
||||||
|
contextWindow: 4096,
|
||||||
|
promptTokenPrice: 0.0000015,
|
||||||
|
completionTokenPrice: 0.000002,
|
||||||
|
speed: "fast",
|
||||||
|
provider: "openai/ChatCompletion",
|
||||||
|
learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
|
||||||
|
},
|
||||||
|
"gpt-3.5-turbo-16k-0613": {
|
||||||
|
name: "GPT-3.5 Turbo 16k",
|
||||||
|
contextWindow: 16384,
|
||||||
|
promptTokenPrice: 0.000003,
|
||||||
|
completionTokenPrice: 0.000004,
|
||||||
|
speed: "fast",
|
||||||
|
provider: "openai/ChatCompletion",
|
||||||
|
learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
const modelProviderFrontend: ModelProviderFrontend<OpenaiChatModelProvider> = {
|
|
||||||
normalizeOutput: (output) => {
|
normalizeOutput: (output) => {
|
||||||
const message = output.choices[0]?.message;
|
const message = output.choices[0]?.message;
|
||||||
if (!message)
|
if (!message)
|
||||||
@@ -39,4 +81,4 @@ const modelProviderFrontend: ModelProviderFrontend<OpenaiChatModelProvider> = {
|
|||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
export default modelProviderFrontend;
|
export default frontendModelProvider;
|
||||||
|
|||||||
@@ -8,10 +8,10 @@ import { countOpenAIChatTokens } from "~/utils/countTokens";
|
|||||||
import { type CompletionResponse } from "../types";
|
import { type CompletionResponse } from "../types";
|
||||||
import { omit } from "lodash-es";
|
import { omit } from "lodash-es";
|
||||||
import { openai } from "~/server/utils/openai";
|
import { openai } from "~/server/utils/openai";
|
||||||
import { type OpenAIChatModel } from "~/server/types";
|
|
||||||
import { truthyFilter } from "~/utils/utils";
|
import { truthyFilter } from "~/utils/utils";
|
||||||
import { APIError } from "openai";
|
import { APIError } from "openai";
|
||||||
import { modelStats } from "../modelStats";
|
import frontendModelProvider from "./frontend";
|
||||||
|
import modelProvider, { type SupportedModel } from ".";
|
||||||
|
|
||||||
const mergeStreamedChunks = (
|
const mergeStreamedChunks = (
|
||||||
base: ChatCompletion | null,
|
base: ChatCompletion | null,
|
||||||
@@ -60,6 +60,7 @@ export async function getCompletion(
|
|||||||
let finalCompletion: ChatCompletion | null = null;
|
let finalCompletion: ChatCompletion | null = null;
|
||||||
let promptTokens: number | undefined = undefined;
|
let promptTokens: number | undefined = undefined;
|
||||||
let completionTokens: number | undefined = undefined;
|
let completionTokens: number | undefined = undefined;
|
||||||
|
const modelName = modelProvider.getModel(input) as SupportedModel;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
if (onStream) {
|
if (onStream) {
|
||||||
@@ -81,12 +82,9 @@ export async function getCompletion(
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
try {
|
try {
|
||||||
promptTokens = countOpenAIChatTokens(
|
promptTokens = countOpenAIChatTokens(modelName, input.messages);
|
||||||
input.model as keyof typeof OpenAIChatModel,
|
|
||||||
input.messages,
|
|
||||||
);
|
|
||||||
completionTokens = countOpenAIChatTokens(
|
completionTokens = countOpenAIChatTokens(
|
||||||
input.model as keyof typeof OpenAIChatModel,
|
modelName,
|
||||||
finalCompletion.choices.map((c) => c.message).filter(truthyFilter),
|
finalCompletion.choices.map((c) => c.message).filter(truthyFilter),
|
||||||
);
|
);
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
@@ -106,10 +104,10 @@ export async function getCompletion(
|
|||||||
}
|
}
|
||||||
const timeToComplete = Date.now() - start;
|
const timeToComplete = Date.now() - start;
|
||||||
|
|
||||||
const stats = modelStats[input.model as keyof typeof OpenAIChatModel];
|
const { promptTokenPrice, completionTokenPrice } = frontendModelProvider.models[modelName];
|
||||||
let cost = undefined;
|
let cost = undefined;
|
||||||
if (stats && promptTokens && completionTokens) {
|
if (promptTokenPrice && completionTokenPrice && promptTokens && completionTokens) {
|
||||||
cost = promptTokens * stats.promptTokenPrice + completionTokens * stats.completionTokenPrice;
|
cost = promptTokens * promptTokenPrice + completionTokens * completionTokenPrice;
|
||||||
}
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import { type ModelProvider } from "../types";
|
|||||||
import inputSchema from "./codegen/input.schema.json";
|
import inputSchema from "./codegen/input.schema.json";
|
||||||
import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat";
|
import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat";
|
||||||
import { getCompletion } from "./getCompletion";
|
import { getCompletion } from "./getCompletion";
|
||||||
|
import frontendModelProvider from "./frontend";
|
||||||
|
|
||||||
const supportedModels = [
|
const supportedModels = [
|
||||||
"gpt-4-0613",
|
"gpt-4-0613",
|
||||||
@@ -11,7 +12,7 @@ const supportedModels = [
|
|||||||
"gpt-3.5-turbo-16k-0613",
|
"gpt-3.5-turbo-16k-0613",
|
||||||
] as const;
|
] as const;
|
||||||
|
|
||||||
type SupportedModel = (typeof supportedModels)[number];
|
export type SupportedModel = (typeof supportedModels)[number];
|
||||||
|
|
||||||
export type OpenaiChatModelProvider = ModelProvider<
|
export type OpenaiChatModelProvider = ModelProvider<
|
||||||
SupportedModel,
|
SupportedModel,
|
||||||
@@ -20,25 +21,6 @@ export type OpenaiChatModelProvider = ModelProvider<
|
|||||||
>;
|
>;
|
||||||
|
|
||||||
const modelProvider: OpenaiChatModelProvider = {
|
const modelProvider: OpenaiChatModelProvider = {
|
||||||
name: "OpenAI ChatCompletion",
|
|
||||||
models: {
|
|
||||||
"gpt-4-0613": {
|
|
||||||
name: "GPT-4",
|
|
||||||
learnMore: "https://openai.com/gpt-4",
|
|
||||||
},
|
|
||||||
"gpt-4-32k-0613": {
|
|
||||||
name: "GPT-4 32k",
|
|
||||||
learnMore: "https://openai.com/gpt-4",
|
|
||||||
},
|
|
||||||
"gpt-3.5-turbo-0613": {
|
|
||||||
name: "GPT-3.5 Turbo",
|
|
||||||
learnMore: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
|
|
||||||
},
|
|
||||||
"gpt-3.5-turbo-16k-0613": {
|
|
||||||
name: "GPT-3.5 Turbo 16k",
|
|
||||||
learnMore: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
getModel: (input) => {
|
getModel: (input) => {
|
||||||
if (supportedModels.includes(input.model as SupportedModel))
|
if (supportedModels.includes(input.model as SupportedModel))
|
||||||
return input.model as SupportedModel;
|
return input.model as SupportedModel;
|
||||||
@@ -57,6 +39,7 @@ const modelProvider: OpenaiChatModelProvider = {
|
|||||||
inputSchema: inputSchema as JSONSchema4,
|
inputSchema: inputSchema as JSONSchema4,
|
||||||
shouldStream: (input) => input.stream ?? false,
|
shouldStream: (input) => input.stream ?? false,
|
||||||
getCompletion,
|
getCompletion,
|
||||||
|
...frontendModelProvider,
|
||||||
};
|
};
|
||||||
|
|
||||||
export default modelProvider;
|
export default modelProvider;
|
||||||
|
|||||||
@@ -1,7 +1,36 @@
|
|||||||
import { type ReplicateLlama2Provider } from ".";
|
import { type SupportedModel, type ReplicateLlama2Output } from ".";
|
||||||
import { type ModelProviderFrontend } from "../types";
|
import { type FrontendModelProvider } from "../types";
|
||||||
|
|
||||||
|
const frontendModelProvider: FrontendModelProvider<SupportedModel, ReplicateLlama2Output> = {
|
||||||
|
name: "Replicate Llama2",
|
||||||
|
|
||||||
|
models: {
|
||||||
|
"7b-chat": {
|
||||||
|
name: "LLama 2 7B Chat",
|
||||||
|
contextWindow: 4096,
|
||||||
|
pricePerSecond: 0.0023,
|
||||||
|
speed: "fast",
|
||||||
|
provider: "replicate/llama2",
|
||||||
|
learnMoreUrl: "https://replicate.com/a16z-infra/llama7b-v2-chat",
|
||||||
|
},
|
||||||
|
"13b-chat": {
|
||||||
|
name: "LLama 2 13B Chat",
|
||||||
|
contextWindow: 4096,
|
||||||
|
pricePerSecond: 0.0023,
|
||||||
|
speed: "medium",
|
||||||
|
provider: "replicate/llama2",
|
||||||
|
learnMoreUrl: "https://replicate.com/a16z-infra/llama13b-v2-chat",
|
||||||
|
},
|
||||||
|
"70b-chat": {
|
||||||
|
name: "LLama 2 70B Chat",
|
||||||
|
contextWindow: 4096,
|
||||||
|
pricePerSecond: 0.0032,
|
||||||
|
speed: "slow",
|
||||||
|
provider: "replicate/llama2",
|
||||||
|
learnMoreUrl: "https://replicate.com/replicate/llama70b-v2-chat",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
const modelProviderFrontend: ModelProviderFrontend<ReplicateLlama2Provider> = {
|
|
||||||
normalizeOutput: (output) => {
|
normalizeOutput: (output) => {
|
||||||
return {
|
return {
|
||||||
type: "text",
|
type: "text",
|
||||||
@@ -10,4 +39,4 @@ const modelProviderFrontend: ModelProviderFrontend<ReplicateLlama2Provider> = {
|
|||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
export default modelProviderFrontend;
|
export default frontendModelProvider;
|
||||||
|
|||||||
@@ -27,8 +27,6 @@ export async function getCompletion(
|
|||||||
input: rest,
|
input: rest,
|
||||||
});
|
});
|
||||||
|
|
||||||
console.log("stream?", onStream);
|
|
||||||
|
|
||||||
const interval = onStream
|
const interval = onStream
|
||||||
? // eslint-disable-next-line @typescript-eslint/no-misused-promises
|
? // eslint-disable-next-line @typescript-eslint/no-misused-promises
|
||||||
setInterval(async () => {
|
setInterval(async () => {
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
import { type ModelProvider } from "../types";
|
import { type ModelProvider } from "../types";
|
||||||
|
import frontendModelProvider from "./frontend";
|
||||||
import { getCompletion } from "./getCompletion";
|
import { getCompletion } from "./getCompletion";
|
||||||
|
|
||||||
const supportedModels = ["7b-chat", "13b-chat", "70b-chat"] as const;
|
const supportedModels = ["7b-chat", "13b-chat", "70b-chat"] as const;
|
||||||
|
|
||||||
type SupportedModel = (typeof supportedModels)[number];
|
export type SupportedModel = (typeof supportedModels)[number];
|
||||||
|
|
||||||
export type ReplicateLlama2Input = {
|
export type ReplicateLlama2Input = {
|
||||||
model: SupportedModel;
|
model: SupportedModel;
|
||||||
@@ -25,12 +26,6 @@ export type ReplicateLlama2Provider = ModelProvider<
|
|||||||
>;
|
>;
|
||||||
|
|
||||||
const modelProvider: ReplicateLlama2Provider = {
|
const modelProvider: ReplicateLlama2Provider = {
|
||||||
name: "OpenAI ChatCompletion",
|
|
||||||
models: {
|
|
||||||
"7b-chat": {},
|
|
||||||
"13b-chat": {},
|
|
||||||
"70b-chat": {},
|
|
||||||
},
|
|
||||||
getModel: (input) => {
|
getModel: (input) => {
|
||||||
if (supportedModels.includes(input.model)) return input.model;
|
if (supportedModels.includes(input.model)) return input.model;
|
||||||
|
|
||||||
@@ -69,6 +64,7 @@ const modelProvider: ReplicateLlama2Provider = {
|
|||||||
},
|
},
|
||||||
shouldStream: (input) => input.stream ?? false,
|
shouldStream: (input) => input.stream ?? false,
|
||||||
getCompletion,
|
getCompletion,
|
||||||
|
...frontendModelProvider,
|
||||||
};
|
};
|
||||||
|
|
||||||
export default modelProvider;
|
export default modelProvider;
|
||||||
|
|||||||
@@ -1,9 +1,33 @@
|
|||||||
import { type JSONSchema4 } from "json-schema";
|
import { type JSONSchema4 } from "json-schema";
|
||||||
import { type JsonValue } from "type-fest";
|
import { type JsonValue } from "type-fest";
|
||||||
|
import { z } from "zod";
|
||||||
|
|
||||||
type ModelProviderModel = {
|
const ZodSupportedProvider = z.union([
|
||||||
name?: string;
|
z.literal("openai/ChatCompletion"),
|
||||||
learnMore?: string;
|
z.literal("replicate/llama2"),
|
||||||
|
]);
|
||||||
|
|
||||||
|
export type SupportedProvider = z.infer<typeof ZodSupportedProvider>;
|
||||||
|
|
||||||
|
export const ZodModel = z.object({
|
||||||
|
name: z.string(),
|
||||||
|
contextWindow: z.number(),
|
||||||
|
promptTokenPrice: z.number().optional(),
|
||||||
|
completionTokenPrice: z.number().optional(),
|
||||||
|
pricePerSecond: z.number().optional(),
|
||||||
|
speed: z.union([z.literal("fast"), z.literal("medium"), z.literal("slow")]),
|
||||||
|
provider: ZodSupportedProvider,
|
||||||
|
description: z.string().optional(),
|
||||||
|
learnMoreUrl: z.string().optional(),
|
||||||
|
});
|
||||||
|
|
||||||
|
export type Model = z.infer<typeof ZodModel>;
|
||||||
|
|
||||||
|
export type FrontendModelProvider<SupportedModels extends string, OutputSchema> = {
|
||||||
|
name: string;
|
||||||
|
models: Record<SupportedModels, Model>;
|
||||||
|
|
||||||
|
normalizeOutput: (output: OutputSchema) => NormalizedOutput;
|
||||||
};
|
};
|
||||||
|
|
||||||
export type CompletionResponse<T> =
|
export type CompletionResponse<T> =
|
||||||
@@ -19,8 +43,6 @@ export type CompletionResponse<T> =
|
|||||||
};
|
};
|
||||||
|
|
||||||
export type ModelProvider<SupportedModels extends string, InputSchema, OutputSchema> = {
|
export type ModelProvider<SupportedModels extends string, InputSchema, OutputSchema> = {
|
||||||
name: string;
|
|
||||||
models: Record<SupportedModels, ModelProviderModel>;
|
|
||||||
getModel: (input: InputSchema) => SupportedModels | null;
|
getModel: (input: InputSchema) => SupportedModels | null;
|
||||||
shouldStream: (input: InputSchema) => boolean;
|
shouldStream: (input: InputSchema) => boolean;
|
||||||
inputSchema: JSONSchema4;
|
inputSchema: JSONSchema4;
|
||||||
@@ -31,7 +53,7 @@ export type ModelProvider<SupportedModels extends string, InputSchema, OutputSch
|
|||||||
|
|
||||||
// This is just a convenience for type inference, don't use it at runtime
|
// This is just a convenience for type inference, don't use it at runtime
|
||||||
_outputSchema?: OutputSchema | null;
|
_outputSchema?: OutputSchema | null;
|
||||||
};
|
} & FrontendModelProvider<SupportedModels, OutputSchema>;
|
||||||
|
|
||||||
export type NormalizedOutput =
|
export type NormalizedOutput =
|
||||||
| {
|
| {
|
||||||
@@ -42,7 +64,3 @@ export type NormalizedOutput =
|
|||||||
type: "json";
|
type: "json";
|
||||||
value: JsonValue;
|
value: JsonValue;
|
||||||
};
|
};
|
||||||
|
|
||||||
export type ModelProviderFrontend<ModelProviderT extends ModelProvider<any, any, any>> = {
|
|
||||||
normalizeOutput: (output: NonNullable<ModelProviderT["_outputSchema"]>) => NormalizedOutput;
|
|
||||||
};
|
|
||||||
|
|||||||
@@ -7,6 +7,8 @@ import "~/utils/analytics";
|
|||||||
import Head from "next/head";
|
import Head from "next/head";
|
||||||
import { ChakraThemeProvider } from "~/theme/ChakraThemeProvider";
|
import { ChakraThemeProvider } from "~/theme/ChakraThemeProvider";
|
||||||
import { SyncAppStore } from "~/state/sync";
|
import { SyncAppStore } from "~/state/sync";
|
||||||
|
import NextAdapterApp from "next-query-params/app";
|
||||||
|
import { QueryParamProvider } from "use-query-params";
|
||||||
|
|
||||||
const MyApp: AppType<{ session: Session | null }> = ({
|
const MyApp: AppType<{ session: Session | null }> = ({
|
||||||
Component,
|
Component,
|
||||||
@@ -24,7 +26,9 @@ const MyApp: AppType<{ session: Session | null }> = ({
|
|||||||
<SyncAppStore />
|
<SyncAppStore />
|
||||||
<Favicon />
|
<Favicon />
|
||||||
<ChakraThemeProvider>
|
<ChakraThemeProvider>
|
||||||
<Component {...pageProps} />
|
<QueryParamProvider adapter={NextAdapterApp}>
|
||||||
|
<Component {...pageProps} />
|
||||||
|
</QueryParamProvider>
|
||||||
</ChakraThemeProvider>
|
</ChakraThemeProvider>
|
||||||
</SessionProvider>
|
</SessionProvider>
|
||||||
</>
|
</>
|
||||||
|
|||||||
@@ -20,22 +20,25 @@ export default function ExperimentsPage() {
|
|||||||
const experiments = api.experiments.list.useQuery();
|
const experiments = api.experiments.list.useQuery();
|
||||||
|
|
||||||
const user = useSession().data;
|
const user = useSession().data;
|
||||||
|
const authLoading = useSession().status === "loading";
|
||||||
|
|
||||||
if (user === null) {
|
if (user === null || authLoading) {
|
||||||
return (
|
return (
|
||||||
<AppShell title="Experiments">
|
<AppShell title="Experiments">
|
||||||
<Center h="100%">
|
<Center h="100%">
|
||||||
<Text>
|
{!authLoading && (
|
||||||
<Link
|
<Text>
|
||||||
onClick={() => {
|
<Link
|
||||||
signIn("github").catch(console.error);
|
onClick={() => {
|
||||||
}}
|
signIn("github").catch(console.error);
|
||||||
textDecor="underline"
|
}}
|
||||||
>
|
textDecor="underline"
|
||||||
Sign in
|
>
|
||||||
</Link>{" "}
|
Sign in
|
||||||
to view or create new experiments!
|
</Link>{" "}
|
||||||
</Text>
|
to view or create new experiments!
|
||||||
|
</Text>
|
||||||
|
)}
|
||||||
</Center>
|
</Center>
|
||||||
</AppShell>
|
</AppShell>
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -98,7 +98,7 @@ export const experimentsRouter = createTRPCRouter({
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
const [variant, _, scenario] = await prisma.$transaction([
|
const [variant, _, scenario1, scenario2, scenario3] = await prisma.$transaction([
|
||||||
prisma.promptVariant.create({
|
prisma.promptVariant.create({
|
||||||
data: {
|
data: {
|
||||||
experimentId: exp.id,
|
experimentId: exp.id,
|
||||||
@@ -121,7 +121,7 @@ export const experimentsRouter = createTRPCRouter({
|
|||||||
messages: [
|
messages: [
|
||||||
{
|
{
|
||||||
role: "system",
|
role: "system",
|
||||||
content: \`"Return 'this is output for the scenario "${"$"}{scenario.text}"'\`,
|
content: \`Write 'Start experimenting!' in ${"$"}{scenario.language}\`,
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
});`,
|
});`,
|
||||||
@@ -133,20 +133,38 @@ export const experimentsRouter = createTRPCRouter({
|
|||||||
prisma.templateVariable.create({
|
prisma.templateVariable.create({
|
||||||
data: {
|
data: {
|
||||||
experimentId: exp.id,
|
experimentId: exp.id,
|
||||||
label: "text",
|
label: "language",
|
||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
prisma.testScenario.create({
|
prisma.testScenario.create({
|
||||||
data: {
|
data: {
|
||||||
experimentId: exp.id,
|
experimentId: exp.id,
|
||||||
variableValues: {
|
variableValues: {
|
||||||
text: "This is a test scenario.",
|
language: "English",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
prisma.testScenario.create({
|
||||||
|
data: {
|
||||||
|
experimentId: exp.id,
|
||||||
|
variableValues: {
|
||||||
|
language: "Spanish",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
prisma.testScenario.create({
|
||||||
|
data: {
|
||||||
|
experimentId: exp.id,
|
||||||
|
variableValues: {
|
||||||
|
language: "German",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
]);
|
]);
|
||||||
|
|
||||||
await generateNewCell(variant.id, scenario.id);
|
await generateNewCell(variant.id, scenario1.id);
|
||||||
|
await generateNewCell(variant.id, scenario2.id);
|
||||||
|
await generateNewCell(variant.id, scenario3.id);
|
||||||
|
|
||||||
return exp;
|
return exp;
|
||||||
}),
|
}),
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ import { z } from "zod";
|
|||||||
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
|
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
|
||||||
import { prisma } from "~/server/db";
|
import { prisma } from "~/server/db";
|
||||||
import { generateNewCell } from "~/server/utils/generateNewCell";
|
import { generateNewCell } from "~/server/utils/generateNewCell";
|
||||||
import { type SupportedModel } from "~/server/types";
|
|
||||||
import userError from "~/server/utils/error";
|
import userError from "~/server/utils/error";
|
||||||
import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
|
import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
|
||||||
import { reorderPromptVariants } from "~/server/utils/reorderPromptVariants";
|
import { reorderPromptVariants } from "~/server/utils/reorderPromptVariants";
|
||||||
@@ -10,6 +9,7 @@ import { type PromptVariant } from "@prisma/client";
|
|||||||
import { deriveNewConstructFn } from "~/server/utils/deriveNewContructFn";
|
import { deriveNewConstructFn } from "~/server/utils/deriveNewContructFn";
|
||||||
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
|
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
|
||||||
import parseConstructFn from "~/server/utils/parseConstructFn";
|
import parseConstructFn from "~/server/utils/parseConstructFn";
|
||||||
|
import { ZodModel } from "~/modelProviders/types";
|
||||||
|
|
||||||
export const promptVariantsRouter = createTRPCRouter({
|
export const promptVariantsRouter = createTRPCRouter({
|
||||||
list: publicProcedure
|
list: publicProcedure
|
||||||
@@ -144,7 +144,7 @@ export const promptVariantsRouter = createTRPCRouter({
|
|||||||
z.object({
|
z.object({
|
||||||
experimentId: z.string(),
|
experimentId: z.string(),
|
||||||
variantId: z.string().optional(),
|
variantId: z.string().optional(),
|
||||||
newModel: z.string().optional(),
|
newModel: ZodModel.optional(),
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
.mutation(async ({ input, ctx }) => {
|
.mutation(async ({ input, ctx }) => {
|
||||||
@@ -186,10 +186,7 @@ export const promptVariantsRouter = createTRPCRouter({
|
|||||||
? `${originalVariant?.label} Copy`
|
? `${originalVariant?.label} Copy`
|
||||||
: `Prompt Variant ${largestSortIndex + 2}`;
|
: `Prompt Variant ${largestSortIndex + 2}`;
|
||||||
|
|
||||||
const newConstructFn = await deriveNewConstructFn(
|
const newConstructFn = await deriveNewConstructFn(originalVariant, input.newModel);
|
||||||
originalVariant,
|
|
||||||
input.newModel as SupportedModel,
|
|
||||||
);
|
|
||||||
|
|
||||||
const createNewVariantAction = prisma.promptVariant.create({
|
const createNewVariantAction = prisma.promptVariant.create({
|
||||||
data: {
|
data: {
|
||||||
@@ -284,11 +281,12 @@ export const promptVariantsRouter = createTRPCRouter({
|
|||||||
return updatedPromptVariant;
|
return updatedPromptVariant;
|
||||||
}),
|
}),
|
||||||
|
|
||||||
getRefinedPromptFn: protectedProcedure
|
getModifiedPromptFn: protectedProcedure
|
||||||
.input(
|
.input(
|
||||||
z.object({
|
z.object({
|
||||||
id: z.string(),
|
id: z.string(),
|
||||||
instructions: z.string(),
|
instructions: z.string().optional(),
|
||||||
|
newModel: ZodModel.optional(),
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
.mutation(async ({ input, ctx }) => {
|
.mutation(async ({ input, ctx }) => {
|
||||||
@@ -307,7 +305,7 @@ export const promptVariantsRouter = createTRPCRouter({
|
|||||||
|
|
||||||
const promptConstructionFn = await deriveNewConstructFn(
|
const promptConstructionFn = await deriveNewConstructFn(
|
||||||
existing,
|
existing,
|
||||||
constructedPrompt.model as SupportedModel,
|
input.newModel,
|
||||||
input.instructions,
|
input.instructions,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|||||||
@@ -7,21 +7,39 @@ import { runAllEvals } from "~/server/utils/evaluations";
|
|||||||
import { generateNewCell } from "~/server/utils/generateNewCell";
|
import { generateNewCell } from "~/server/utils/generateNewCell";
|
||||||
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
|
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
|
||||||
|
|
||||||
|
const PAGE_SIZE = 10;
|
||||||
|
|
||||||
export const scenariosRouter = createTRPCRouter({
|
export const scenariosRouter = createTRPCRouter({
|
||||||
list: publicProcedure
|
list: publicProcedure
|
||||||
.input(z.object({ experimentId: z.string() }))
|
.input(z.object({ experimentId: z.string(), page: z.number() }))
|
||||||
.query(async ({ input, ctx }) => {
|
.query(async ({ input, ctx }) => {
|
||||||
await requireCanViewExperiment(input.experimentId, ctx);
|
await requireCanViewExperiment(input.experimentId, ctx);
|
||||||
|
|
||||||
return await prisma.testScenario.findMany({
|
const { experimentId, page } = input;
|
||||||
|
|
||||||
|
const scenarios = await prisma.testScenario.findMany({
|
||||||
where: {
|
where: {
|
||||||
experimentId: input.experimentId,
|
experimentId,
|
||||||
visible: true,
|
visible: true,
|
||||||
},
|
},
|
||||||
orderBy: {
|
orderBy: { sortIndex: "asc" },
|
||||||
sortIndex: "asc",
|
skip: (page - 1) * PAGE_SIZE,
|
||||||
|
take: PAGE_SIZE,
|
||||||
|
});
|
||||||
|
|
||||||
|
const count = await prisma.testScenario.count({
|
||||||
|
where: {
|
||||||
|
experimentId,
|
||||||
|
visible: true,
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
|
return {
|
||||||
|
scenarios,
|
||||||
|
startIndex: (page - 1) * PAGE_SIZE + 1,
|
||||||
|
lastPage: Math.ceil(count / PAGE_SIZE),
|
||||||
|
count,
|
||||||
|
};
|
||||||
}),
|
}),
|
||||||
|
|
||||||
create: protectedProcedure
|
create: protectedProcedure
|
||||||
@@ -34,22 +52,21 @@ export const scenariosRouter = createTRPCRouter({
|
|||||||
.mutation(async ({ input, ctx }) => {
|
.mutation(async ({ input, ctx }) => {
|
||||||
await requireCanModifyExperiment(input.experimentId, ctx);
|
await requireCanModifyExperiment(input.experimentId, ctx);
|
||||||
|
|
||||||
const maxSortIndex =
|
await prisma.testScenario.updateMany({
|
||||||
(
|
where: {
|
||||||
await prisma.testScenario.aggregate({
|
experimentId: input.experimentId,
|
||||||
where: {
|
},
|
||||||
experimentId: input.experimentId,
|
data: {
|
||||||
},
|
sortIndex: {
|
||||||
_max: {
|
increment: 1,
|
||||||
sortIndex: true,
|
},
|
||||||
},
|
},
|
||||||
})
|
});
|
||||||
)._max.sortIndex ?? 0;
|
|
||||||
|
|
||||||
const createNewScenarioAction = prisma.testScenario.create({
|
const createNewScenarioAction = prisma.testScenario.create({
|
||||||
data: {
|
data: {
|
||||||
experimentId: input.experimentId,
|
experimentId: input.experimentId,
|
||||||
sortIndex: maxSortIndex + 1,
|
sortIndex: 0,
|
||||||
variableValues: input.autogenerate
|
variableValues: input.autogenerate
|
||||||
? await autogenerateScenarioValues(input.experimentId)
|
? await autogenerateScenarioValues(input.experimentId)
|
||||||
: {},
|
: {},
|
||||||
|
|||||||
@@ -99,7 +99,6 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
|
|||||||
|
|
||||||
const provider = modelProviders[prompt.modelProvider];
|
const provider = modelProviders[prompt.modelProvider];
|
||||||
|
|
||||||
// @ts-expect-error TODO FIX ASAP
|
|
||||||
const streamingChannel = provider.shouldStream(prompt.modelInput) ? generateChannel() : null;
|
const streamingChannel = provider.shouldStream(prompt.modelInput) ? generateChannel() : null;
|
||||||
|
|
||||||
if (streamingChannel) {
|
if (streamingChannel) {
|
||||||
@@ -116,8 +115,6 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
|
|||||||
: null;
|
: null;
|
||||||
|
|
||||||
for (let i = 0; true; i++) {
|
for (let i = 0; true; i++) {
|
||||||
// @ts-expect-error TODO FIX ASAP
|
|
||||||
|
|
||||||
const response = await provider.getCompletion(prompt.modelInput, onStream);
|
const response = await provider.getCompletion(prompt.modelInput, onStream);
|
||||||
if (response.type === "success") {
|
if (response.type === "success") {
|
||||||
const inputHash = hashPrompt(prompt);
|
const inputHash = hashPrompt(prompt);
|
||||||
@@ -126,7 +123,7 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
|
|||||||
data: {
|
data: {
|
||||||
scenarioVariantCellId,
|
scenarioVariantCellId,
|
||||||
inputHash,
|
inputHash,
|
||||||
output: response.value as unknown as Prisma.InputJsonObject,
|
output: response.value as Prisma.InputJsonObject,
|
||||||
timeToComplete: response.timeToComplete,
|
timeToComplete: response.timeToComplete,
|
||||||
promptTokens: response.promptTokens,
|
promptTokens: response.promptTokens,
|
||||||
completionTokens: response.completionTokens,
|
completionTokens: response.completionTokens,
|
||||||
@@ -154,7 +151,7 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
|
|||||||
errorMessage: response.message,
|
errorMessage: response.message,
|
||||||
statusCode: response.statusCode,
|
statusCode: response.statusCode,
|
||||||
retryTime: shouldRetry ? new Date(Date.now() + delay) : null,
|
retryTime: shouldRetry ? new Date(Date.now() + delay) : null,
|
||||||
retrievalStatus: shouldRetry ? "PENDING" : "ERROR",
|
retrievalStatus: "ERROR",
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|||||||
@@ -1,12 +0,0 @@
|
|||||||
export enum OpenAIChatModel {
|
|
||||||
"gpt-4" = "gpt-4",
|
|
||||||
"gpt-4-0613" = "gpt-4-0613",
|
|
||||||
"gpt-4-32k" = "gpt-4-32k",
|
|
||||||
"gpt-4-32k-0613" = "gpt-4-32k-0613",
|
|
||||||
"gpt-3.5-turbo" = "gpt-3.5-turbo",
|
|
||||||
"gpt-3.5-turbo-0613" = "gpt-3.5-turbo-0613",
|
|
||||||
"gpt-3.5-turbo-16k" = "gpt-3.5-turbo-16k",
|
|
||||||
"gpt-3.5-turbo-16k-0613" = "gpt-3.5-turbo-16k-0613",
|
|
||||||
}
|
|
||||||
|
|
||||||
export type SupportedModel = keyof typeof OpenAIChatModel;
|
|
||||||
@@ -1,18 +1,18 @@
|
|||||||
import { type PromptVariant } from "@prisma/client";
|
import { type PromptVariant } from "@prisma/client";
|
||||||
import { type SupportedModel } from "../types";
|
|
||||||
import ivm from "isolated-vm";
|
import ivm from "isolated-vm";
|
||||||
import dedent from "dedent";
|
import dedent from "dedent";
|
||||||
import { openai } from "./openai";
|
import { openai } from "./openai";
|
||||||
import { getApiShapeForModel } from "./getTypesForModel";
|
|
||||||
import { isObject } from "lodash-es";
|
import { isObject } from "lodash-es";
|
||||||
import { type CompletionCreateParams } from "openai/resources/chat/completions";
|
import { type CompletionCreateParams } from "openai/resources/chat/completions";
|
||||||
import formatPromptConstructor from "~/utils/formatPromptConstructor";
|
import formatPromptConstructor from "~/utils/formatPromptConstructor";
|
||||||
|
import { type SupportedProvider, type Model } from "~/modelProviders/types";
|
||||||
|
import modelProviders from "~/modelProviders/modelProviders";
|
||||||
|
|
||||||
const isolate = new ivm.Isolate({ memoryLimit: 128 });
|
const isolate = new ivm.Isolate({ memoryLimit: 128 });
|
||||||
|
|
||||||
export async function deriveNewConstructFn(
|
export async function deriveNewConstructFn(
|
||||||
originalVariant: PromptVariant | null,
|
originalVariant: PromptVariant | null,
|
||||||
newModel?: SupportedModel,
|
newModel?: Model,
|
||||||
instructions?: string,
|
instructions?: string,
|
||||||
) {
|
) {
|
||||||
if (originalVariant && !newModel && !instructions) {
|
if (originalVariant && !newModel && !instructions) {
|
||||||
@@ -36,10 +36,11 @@ export async function deriveNewConstructFn(
|
|||||||
const NUM_RETRIES = 5;
|
const NUM_RETRIES = 5;
|
||||||
const requestUpdatedPromptFunction = async (
|
const requestUpdatedPromptFunction = async (
|
||||||
originalVariant: PromptVariant,
|
originalVariant: PromptVariant,
|
||||||
newModel?: SupportedModel,
|
newModel?: Model,
|
||||||
instructions?: string,
|
instructions?: string,
|
||||||
) => {
|
) => {
|
||||||
const originalModel = originalVariant.model as SupportedModel;
|
const originalModelProvider = modelProviders[originalVariant.modelProvider as SupportedProvider];
|
||||||
|
const originalModel = originalModelProvider.models[originalVariant.model] as Model;
|
||||||
let newContructionFn = "";
|
let newContructionFn = "";
|
||||||
for (let i = 0; i < NUM_RETRIES; i++) {
|
for (let i = 0; i < NUM_RETRIES; i++) {
|
||||||
try {
|
try {
|
||||||
@@ -47,17 +48,33 @@ const requestUpdatedPromptFunction = async (
|
|||||||
{
|
{
|
||||||
role: "system",
|
role: "system",
|
||||||
content: `Your job is to update prompt constructor functions. Here is the api shape for the current model:\n---\n${JSON.stringify(
|
content: `Your job is to update prompt constructor functions. Here is the api shape for the current model:\n---\n${JSON.stringify(
|
||||||
getApiShapeForModel(originalModel),
|
originalModelProvider.inputSchema,
|
||||||
null,
|
null,
|
||||||
2,
|
2,
|
||||||
)}\n\nDo not add any assistant messages.`,
|
)}\n\nDo not add any assistant messages.`,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
role: "user",
|
||||||
|
content: `This is the current prompt constructor function:\n---\n${originalVariant.constructFn}`,
|
||||||
|
},
|
||||||
];
|
];
|
||||||
if (newModel) {
|
if (newModel) {
|
||||||
messages.push({
|
messages.push({
|
||||||
role: "user",
|
role: "user",
|
||||||
content: `Return the prompt constructor function for ${newModel} given the following prompt constructor function for ${originalModel}:\n---\n${originalVariant.constructFn}`,
|
content: `Return the prompt constructor function for ${newModel.name} given the existing prompt constructor function for ${originalModel.name}`,
|
||||||
});
|
});
|
||||||
|
if (newModel.provider !== originalModel.provider) {
|
||||||
|
messages.push({
|
||||||
|
role: "user",
|
||||||
|
content: `The old provider was ${originalModel.provider}. The new provider is ${
|
||||||
|
newModel.provider
|
||||||
|
}. Here is the schema for the new model:\n---\n${JSON.stringify(
|
||||||
|
modelProviders[newModel.provider].inputSchema,
|
||||||
|
null,
|
||||||
|
2,
|
||||||
|
)}`,
|
||||||
|
});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if (instructions) {
|
if (instructions) {
|
||||||
messages.push({
|
messages.push({
|
||||||
@@ -65,10 +82,6 @@ const requestUpdatedPromptFunction = async (
|
|||||||
content: instructions,
|
content: instructions,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
messages.push({
|
|
||||||
role: "system",
|
|
||||||
content: "The prompt variable has already been declared, so do not declare it again.",
|
|
||||||
});
|
|
||||||
const completion = await openai.chat.completions.create({
|
const completion = await openai.chat.completions.create({
|
||||||
model: "gpt-4",
|
model: "gpt-4",
|
||||||
messages,
|
messages,
|
||||||
|
|||||||
@@ -4,8 +4,9 @@ import { queueLLMRetrievalTask } from "./queueLLMRetrievalTask";
|
|||||||
import parseConstructFn from "./parseConstructFn";
|
import parseConstructFn from "./parseConstructFn";
|
||||||
import { type JsonObject } from "type-fest";
|
import { type JsonObject } from "type-fest";
|
||||||
import hashPrompt from "./hashPrompt";
|
import hashPrompt from "./hashPrompt";
|
||||||
|
import { omit } from "lodash-es";
|
||||||
|
|
||||||
export const generateNewCell = async (variantId: string, scenarioId: string) => {
|
export const generateNewCell = async (variantId: string, scenarioId: string): Promise<void> => {
|
||||||
const variant = await prisma.promptVariant.findUnique({
|
const variant = await prisma.promptVariant.findUnique({
|
||||||
where: {
|
where: {
|
||||||
id: variantId,
|
id: variantId,
|
||||||
@@ -18,7 +19,7 @@ export const generateNewCell = async (variantId: string, scenarioId: string) =>
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
if (!variant || !scenario) return null;
|
if (!variant || !scenario) return;
|
||||||
|
|
||||||
let cell = await prisma.scenarioVariantCell.findUnique({
|
let cell = await prisma.scenarioVariantCell.findUnique({
|
||||||
where: {
|
where: {
|
||||||
@@ -32,7 +33,7 @@ export const generateNewCell = async (variantId: string, scenarioId: string) =>
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
if (cell) return cell;
|
if (cell) return;
|
||||||
|
|
||||||
const parsedConstructFn = await parseConstructFn(
|
const parsedConstructFn = await parseConstructFn(
|
||||||
variant.constructFn,
|
variant.constructFn,
|
||||||
@@ -40,7 +41,7 @@ export const generateNewCell = async (variantId: string, scenarioId: string) =>
|
|||||||
);
|
);
|
||||||
|
|
||||||
if ("error" in parsedConstructFn) {
|
if ("error" in parsedConstructFn) {
|
||||||
return await prisma.scenarioVariantCell.create({
|
await prisma.scenarioVariantCell.create({
|
||||||
data: {
|
data: {
|
||||||
promptVariantId: variantId,
|
promptVariantId: variantId,
|
||||||
testScenarioId: scenarioId,
|
testScenarioId: scenarioId,
|
||||||
@@ -49,6 +50,7 @@ export const generateNewCell = async (variantId: string, scenarioId: string) =>
|
|||||||
retrievalStatus: "ERROR",
|
retrievalStatus: "ERROR",
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const inputHash = hashPrompt(parsedConstructFn);
|
const inputHash = hashPrompt(parsedConstructFn);
|
||||||
@@ -69,29 +71,33 @@ export const generateNewCell = async (variantId: string, scenarioId: string) =>
|
|||||||
where: { inputHash },
|
where: { inputHash },
|
||||||
});
|
});
|
||||||
|
|
||||||
let newModelOutput;
|
|
||||||
|
|
||||||
if (matchingModelOutput) {
|
if (matchingModelOutput) {
|
||||||
newModelOutput = await prisma.modelOutput.create({
|
const newModelOutput = await prisma.modelOutput.create({
|
||||||
data: {
|
data: {
|
||||||
|
...omit(matchingModelOutput, ["id"]),
|
||||||
scenarioVariantCellId: cell.id,
|
scenarioVariantCellId: cell.id,
|
||||||
inputHash,
|
|
||||||
output: matchingModelOutput.output as Prisma.InputJsonValue,
|
output: matchingModelOutput.output as Prisma.InputJsonValue,
|
||||||
timeToComplete: matchingModelOutput.timeToComplete,
|
|
||||||
cost: matchingModelOutput.cost,
|
|
||||||
promptTokens: matchingModelOutput.promptTokens,
|
|
||||||
completionTokens: matchingModelOutput.completionTokens,
|
|
||||||
createdAt: matchingModelOutput.createdAt,
|
|
||||||
updatedAt: matchingModelOutput.updatedAt,
|
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
await prisma.scenarioVariantCell.update({
|
await prisma.scenarioVariantCell.update({
|
||||||
where: { id: cell.id },
|
where: { id: cell.id },
|
||||||
data: { retrievalStatus: "COMPLETE" },
|
data: { retrievalStatus: "COMPLETE" },
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Copy over all eval results as well
|
||||||
|
await Promise.all(
|
||||||
|
(
|
||||||
|
await prisma.outputEvaluation.findMany({ where: { modelOutputId: matchingModelOutput.id } })
|
||||||
|
).map(async (evaluation) => {
|
||||||
|
await prisma.outputEvaluation.create({
|
||||||
|
data: {
|
||||||
|
...omit(evaluation, ["id"]),
|
||||||
|
modelOutputId: newModelOutput.id,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}),
|
||||||
|
);
|
||||||
} else {
|
} else {
|
||||||
cell = await queueLLMRetrievalTask(cell.id);
|
cell = await queueLLMRetrievalTask(cell.id);
|
||||||
}
|
}
|
||||||
|
|
||||||
return { ...cell, modelOutput: newModelOutput };
|
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -1,6 +0,0 @@
|
|||||||
import { type SupportedModel } from "../types";
|
|
||||||
|
|
||||||
export const getApiShapeForModel = (model: SupportedModel) => {
|
|
||||||
// if (model in OpenAIChatModel) return openAIChatApiShape;
|
|
||||||
return "";
|
|
||||||
};
|
|
||||||
@@ -70,7 +70,6 @@ export default async function parseConstructFn(
|
|||||||
// We've validated the JSON schema so this should be safe
|
// We've validated the JSON schema so this should be safe
|
||||||
const input = prompt.input as Parameters<(typeof provider)["getModel"]>[0];
|
const input = prompt.input as Parameters<(typeof provider)["getModel"]>[0];
|
||||||
|
|
||||||
// @ts-expect-error TODO FIX ASAP
|
|
||||||
const model = provider.getModel(input);
|
const model = provider.getModel(input);
|
||||||
if (!model) {
|
if (!model) {
|
||||||
return {
|
return {
|
||||||
@@ -80,8 +79,6 @@ export default async function parseConstructFn(
|
|||||||
|
|
||||||
return {
|
return {
|
||||||
modelProvider: prompt.modelProvider as keyof typeof modelProviders,
|
modelProvider: prompt.modelProvider as keyof typeof modelProviders,
|
||||||
// @ts-expect-error TODO FIX ASAP
|
|
||||||
|
|
||||||
model,
|
model,
|
||||||
modelInput: input,
|
modelInput: input,
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -8,9 +8,9 @@ export const editorBackground = "#fafafa";
|
|||||||
export type SharedVariantEditorSlice = {
|
export type SharedVariantEditorSlice = {
|
||||||
monaco: null | ReturnType<typeof loader.__getMonacoInstance>;
|
monaco: null | ReturnType<typeof loader.__getMonacoInstance>;
|
||||||
loadMonaco: () => Promise<void>;
|
loadMonaco: () => Promise<void>;
|
||||||
scenarios: RouterOutputs["scenarios"]["list"];
|
scenarios: RouterOutputs["scenarios"]["list"]["scenarios"];
|
||||||
updateScenariosModel: () => void;
|
updateScenariosModel: () => void;
|
||||||
setScenarios: (scenarios: RouterOutputs["scenarios"]["list"]) => void;
|
setScenarios: (scenarios: RouterOutputs["scenarios"]["list"]["scenarios"]) => void;
|
||||||
};
|
};
|
||||||
|
|
||||||
export const createVariantEditorSlice: SliceCreator<SharedVariantEditorSlice> = (set, get) => ({
|
export const createVariantEditorSlice: SliceCreator<SharedVariantEditorSlice> = (set, get) => ({
|
||||||
|
|||||||
@@ -1,17 +1,14 @@
|
|||||||
import { useEffect } from "react";
|
import { useEffect } from "react";
|
||||||
import { api } from "~/utils/api";
|
import { api } from "~/utils/api";
|
||||||
import { useExperiment } from "~/utils/hooks";
|
import { useScenarios } from "~/utils/hooks";
|
||||||
import { useAppStore } from "./store";
|
import { useAppStore } from "./store";
|
||||||
|
|
||||||
export function useSyncVariantEditor() {
|
export function useSyncVariantEditor() {
|
||||||
const experiment = useExperiment();
|
const scenarios = useScenarios();
|
||||||
const scenarios = api.scenarios.list.useQuery(
|
|
||||||
{ experimentId: experiment.data?.id ?? "" },
|
|
||||||
{ enabled: !!experiment.data?.id },
|
|
||||||
);
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (scenarios.data) {
|
if (scenarios.data) {
|
||||||
useAppStore.getState().sharedVariantEditor.setScenarios(scenarios.data);
|
useAppStore.getState().sharedVariantEditor.setScenarios(scenarios.data.scenarios);
|
||||||
}
|
}
|
||||||
}, [scenarios.data]);
|
}, [scenarios.data]);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import { type ChatCompletion } from "openai/resources/chat";
|
import { type ChatCompletion } from "openai/resources/chat";
|
||||||
import { GPTTokens } from "gpt-tokens";
|
import { GPTTokens } from "gpt-tokens";
|
||||||
import { type OpenAIChatModel } from "~/server/types";
|
import { type SupportedModel } from "~/modelProviders/openai-ChatCompletion";
|
||||||
|
|
||||||
interface GPTTokensMessageItem {
|
interface GPTTokensMessageItem {
|
||||||
name?: string;
|
name?: string;
|
||||||
@@ -9,7 +9,7 @@ interface GPTTokensMessageItem {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export const countOpenAIChatTokens = (
|
export const countOpenAIChatTokens = (
|
||||||
model: keyof typeof OpenAIChatModel,
|
model: SupportedModel,
|
||||||
messages: ChatCompletion.Choice.Message[],
|
messages: ChatCompletion.Choice.Message[],
|
||||||
) => {
|
) => {
|
||||||
return new GPTTokens({ model, messages: messages as unknown as GPTTokensMessageItem[] })
|
return new GPTTokens({ model, messages: messages as unknown as GPTTokensMessageItem[] })
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import { useRouter } from "next/router";
|
import { useRouter } from "next/router";
|
||||||
import { type RefObject, useCallback, useEffect, useRef, useState } from "react";
|
import { type RefObject, useCallback, useEffect, useRef, useState } from "react";
|
||||||
import { api } from "~/utils/api";
|
import { api } from "~/utils/api";
|
||||||
|
import { NumberParam, useQueryParam, withDefault } from "use-query-params";
|
||||||
|
|
||||||
export const useExperiment = () => {
|
export const useExperiment = () => {
|
||||||
const router = useRouter();
|
const router = useRouter();
|
||||||
@@ -93,3 +94,15 @@ export const useElementDimensions = (): [RefObject<HTMLElement>, Dimensions | un
|
|||||||
|
|
||||||
return [ref, dimensions];
|
return [ref, dimensions];
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export const usePage = () => useQueryParam("page", withDefault(NumberParam, 1));
|
||||||
|
|
||||||
|
export const useScenarios = () => {
|
||||||
|
const experiment = useExperiment();
|
||||||
|
const [page] = usePage();
|
||||||
|
|
||||||
|
return api.scenarios.list.useQuery(
|
||||||
|
{ experimentId: experiment.data?.id ?? "", page },
|
||||||
|
{ enabled: experiment.data?.id != null },
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|||||||
@@ -1 +1,5 @@
|
|||||||
|
import { type Model } from "~/modelProviders/types";
|
||||||
|
|
||||||
export const truthyFilter = <T>(x: T | null | undefined): x is T => Boolean(x);
|
export const truthyFilter = <T>(x: T | null | undefined): x is T => Boolean(x);
|
||||||
|
|
||||||
|
export const keyForModel = (model: Model) => `${model.provider}/${model.name}`;
|
||||||
|
|||||||
Reference in New Issue
Block a user