Compare commits


31 Commits

Author SHA1 Message Date
David Corbitt
05a932ea74 Give negative margin to account for border 2023-07-23 16:46:34 -07:00
David Corbitt
85d42a014b Fix padding on AddVariant button 2023-07-22 16:27:59 -07:00
arcticfly
7d1ded3b18 Improve menu styling (#85) 2023-07-22 16:22:00 -07:00
Kyle Corbitt
b00f6dd04b Merge pull request #84 from OpenPipe/paginated-scenarios
Paginate scenarios
2023-07-22 16:12:02 -07:00
Kyle Corbitt
2e395e4d39 Paginate scenarios
Show 10 scenarios at a time and let the user paginate through them, keeping the interface responsive even with thousands of scenarios.
2023-07-22 16:10:16 -07:00
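As a rough illustration of what such a paginated list query might look like (the procedure name, page size, and input shape are assumptions, not the actual implementation, though the returned fields match what ScenarioPaginator consumes below):

// Hypothetical tRPC-style procedure; `prisma` and `publicProcedure` are assumed project helpers.
import { z } from "zod";

const PAGE_SIZE = 10;

export const list = publicProcedure
  .input(z.object({ experimentId: z.string(), page: z.number().int().min(1).default(1) }))
  .query(async ({ input: { experimentId, page } }) => {
    const [scenarios, count] = await Promise.all([
      prisma.testScenario.findMany({
        where: { experimentId },
        orderBy: { sortIndex: "asc" },
        skip: (page - 1) * PAGE_SIZE, // only fetch the current page
        take: PAGE_SIZE,
      }),
      prisma.testScenario.count({ where: { experimentId } }),
    ]);
    return {
      scenarios,
      startIndex: (page - 1) * PAGE_SIZE + 1, // 1-based index of the first visible row
      lastPage: Math.max(1, Math.ceil(count / PAGE_SIZE)),
      count,
    };
  });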
Kyle Corbitt
4b06d05908 Merge pull request #82 from OpenPipe/space-out-scenarios
Separate scenarios from prompts in outputs table
2023-07-22 14:44:51 -07:00
Kyle Corbitt
aabf355b81 Merge pull request #81 from OpenPipe/fullscreen-editor
Fullscreen editor
2023-07-22 14:44:42 -07:00
Kyle Corbitt
61e5f0775d separate scenarios from prompts in outputs table 2023-07-22 07:38:19 -07:00
Kyle Corbitt
cc1d1178da Fullscreen editor 2023-07-21 22:19:38 -07:00
David Corbitt
7466db63df Make REPLICATE_API_TOKEN optional 2023-07-21 20:23:38 -07:00
David Corbitt
79a0b03bf8 Add another function call example 2023-07-21 20:16:36 -07:00
arcticfly
6fb7a82d72 Add support for switching to Llama models (#80)
* Add support for switching to Llama models

* Fix prettier
2023-07-21 20:10:59 -07:00
Kyle Corbitt
4ea30a3ba3 Merge pull request #79 from OpenPipe/copy-evals
Copy over evals when new cell created
2023-07-21 18:43:44 -07:00
Kyle Corbitt
52d1d5c7ee Copy over evals when new cell created
Fixes a bug where new cells generated as clones of existing cells didn't get the eval results cloned as well.
2023-07-21 18:40:40 -07:00
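A loose sketch of the shape of such a fix, copying eval results along when a cell is cloned; the Prisma model and relation names here are assumptions made for illustration:

import { omit } from "lodash-es";
import { prisma } from "~/server/db"; // assumed project export

// Hypothetical clone helper; `scenarioVariantCell` / `evaluations` naming is assumed.
async function cloneCell(cellId: string, newVariantId: string) {
  const original = await prisma.scenarioVariantCell.findUniqueOrThrow({
    where: { id: cellId },
    include: { evaluations: true },
  });
  return prisma.scenarioVariantCell.create({
    data: {
      ...omit(original, ["id", "evaluations"]),
      promptVariantId: newVariantId,
      evaluations: {
        // Previously the clone started with no evals; copy them over too.
        create: original.evaluations.map((e) => omit(e, ["id", "scenarioVariantCellId"])),
      },
    },
  });
}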
Kyle Corbitt
46036a44d2 small README update 2023-07-21 14:32:07 -07:00
Kyle Corbitt
3753fe5c16 Merge pull request #78 from OpenPipe/bugfix-max-tokens
Fix typescript hints for max_tokens
2023-07-21 12:10:00 -07:00
Kyle Corbitt
213a00a8e6 Fix typescript hints for max_tokens 2023-07-21 12:04:58 -07:00
Kyle Corbitt
af9943eefc Merge pull request #77 from OpenPipe/provider-types
Slightly better typings for ModelProviders
2023-07-21 11:51:25 -07:00
Kyle Corbitt
741128e0f4 Better division of labor between frontend and backend model providers
A bit better thinking on which types go where.
2023-07-21 11:49:35 -07:00
David Corbitt
aff14539d8 Add comment to .env.example 2023-07-21 11:29:21 -07:00
David Corbitt
1af81a50a9 Add REPLICATE_API_TOKEN to .env.example 2023-07-21 11:28:14 -07:00
Kyle Corbitt
7e1fbb3767 Slightly better typings for ModelProviders
Still not great because the `any`s loosen some call sites up more than I'd like, but better than the broken types before.
2023-07-21 06:50:05 -07:00
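For context, the improved-but-still-loose typing pattern being described might look roughly like this; the generics and helper types are a hedged reconstruction based on the diffs below, not the PR's exact code:

// Minimal placeholder types for the sketch.
type NormalizedOutput = { type: "text" | "json"; value: unknown };
interface Model {
  name: string;
  provider: string;
}

// Each provider is generic over its model keys and raw output schema...
interface FrontendModelProvider<SupportedModels extends string, OutputSchema> {
  models: Record<SupportedModels, Model>;
  normalizeOutput: (output: OutputSchema) => NormalizedOutput;
}

// ...but registering them all in one record erases those generics with `any`,
// which is what loosens dynamic call sites like
// frontendModelProviders[variant.modelProvider].normalizeOutput(...).
type SupportedProvider = "openai/ChatCompletion" | "replicate/llama2";
declare const frontendModelProviders: Record<SupportedProvider, FrontendModelProvider<any, any>>;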
David Corbitt
a5d972005e Add user's current prompt to prompt derivation 2023-07-21 00:43:39 -07:00
arcticfly
a180b5bef2 Show prompt diff when changing models (#76)
* Make CompareFunctions more configurable

* Change RefinePromptModal styles

* Accept newModel in getModifiedPromptFn

* Show prompt comparison in SelectModelModal

* Pass variant to SelectModelModal

* Update instructions

* Properly use isDisabled
2023-07-20 23:26:49 -07:00
Kyle Corbitt
55c697223e Merge pull request #74 from OpenPipe/model-providers
replicate/llama2 provider
2023-07-20 23:21:42 -07:00
arcticfly
9978075867 Fix auth flicker (#75)
* Remove experiments flicker for unauthenticated users

* Decrease size of NewScenarioButton spinner
2023-07-20 20:46:31 -07:00
Kyle Corbitt
372c2512c9 Merge pull request #73 from OpenPipe/model-providers
More work on modelProviders
2023-07-20 18:56:14 -07:00
arcticfly
1822fe198e Initially render AutoResizeTextArea without overflow (#72)
* Rerender resized text area with scroll

* Remove default hidden overflow
2023-07-20 15:00:09 -07:00
Kyle Corbitt
f06e1db3db Merge pull request #71 from OpenPipe/model-providers
Prep for more model providers
2023-07-20 14:55:31 -07:00
arcticfly
9314a86857 Use translation in initial scenarios (#70) 2023-07-20 14:28:48 -07:00
David Corbitt
54dcb4a567 Prevent text input labels from overlaying scenarios header 2023-07-20 14:28:36 -07:00
60 changed files with 1174 additions and 776 deletions

View File

@@ -17,6 +17,9 @@ DATABASE_URL="postgresql://postgres:postgres@localhost:5432/openpipe?schema=publ
# https://help.openai.com/en/articles/4936850-where-do-i-find-my-secret-api-key
OPENAI_API_KEY=""
# Replicate API token. Create a token here: https://replicate.com/account/api-tokens
REPLICATE_API_TOKEN=""
NEXT_PUBLIC_SOCKET_URL="http://localhost:3318"
# Next Auth

View File

@@ -1,6 +1,3 @@
{
"eslint.format.enable": true,
"editor.codeActionsOnSave": {
"source.fixAll.eslint": true
}
"eslint.format.enable": true
}

View File

@@ -43,7 +43,8 @@ Natively supports [OpenAI function calls](https://openai.com/blog/function-calli
## Supported Models
OpenPipe currently supports GPT-3.5 and GPT-4. Wider model support is planned.
- All models available through the OpenAI [chat completion API](https://platform.openai.com/docs/guides/gpt/chat-completions-api)
- Llama2 [7b chat](https://replicate.com/a16z-infra/llama7b-v2-chat), [13b chat](https://replicate.com/a16z-infra/llama13b-v2-chat), [70b chat](https://replicate.com/replicate/llama70b-v2-chat).
## Running Locally

View File

@@ -59,6 +59,7 @@
"lodash-es": "^4.17.21",
"next": "^13.4.2",
"next-auth": "^4.22.1",
"next-query-params": "^4.2.3",
"nextjs-routes": "^2.0.1",
"openai": "4.0.0-beta.2",
"pluralize": "^8.0.0",
@@ -79,6 +80,7 @@
"superjson": "1.12.2",
"tsx": "^3.12.7",
"type-fest": "^4.0.0",
"use-query-params": "^2.2.1",
"vite-tsconfig-paths": "^4.2.0",
"zod": "^3.21.4",
"zustand": "^4.3.9"

pnpm-lock.yaml generated (41 lines changed)
View File

@@ -119,6 +119,9 @@ dependencies:
next-auth:
specifier: ^4.22.1
version: 4.22.1(next@13.4.2)(react-dom@18.2.0)(react@18.2.0)
next-query-params:
specifier: ^4.2.3
version: 4.2.3(next@13.4.2)(react@18.2.0)(use-query-params@2.2.1)
nextjs-routes:
specifier: ^2.0.1
version: 2.0.1(next@13.4.2)
@@ -179,6 +182,9 @@ dependencies:
type-fest:
specifier: ^4.0.0
version: 4.0.0
use-query-params:
specifier: ^2.2.1
version: 2.2.1(react-dom@18.2.0)(react@18.2.0)
vite-tsconfig-paths:
specifier: ^4.2.0
version: 4.2.0(typescript@5.0.4)
@@ -6037,6 +6043,19 @@ packages:
uuid: 8.3.2
dev: false
/next-query-params@4.2.3(next@13.4.2)(react@18.2.0)(use-query-params@2.2.1):
resolution: {integrity: sha512-hGNCYRH8YyA5ItiBGSKrtMl21b2MAqfPkdI1mvwloNVqSU142IaGzqHN+OTovyeLIpQfonY01y7BAHb/UH4POg==}
peerDependencies:
next: ^10.0.0 || ^11.0.0 || ^12.0.0 || ^13.0.0
react: ^16.8.0 || ^17.0.0 || ^18.0.0
use-query-params: ^2.0.0
dependencies:
next: 13.4.2(@babel/core@7.22.9)(react-dom@18.2.0)(react@18.2.0)
react: 18.2.0
tslib: 2.6.0
use-query-params: 2.2.1(react-dom@18.2.0)(react@18.2.0)
dev: false
/next-tick@1.1.0:
resolution: {integrity: sha512-CXdUiJembsNjuToQvxayPZF9Vqht7hewsvy2sOWafLvi2awflj9mOC6bHIg50orX8IJvWKY9wYQ/zB2kogPslQ==}
dev: false
@@ -7147,6 +7166,10 @@ packages:
randombytes: 2.1.0
dev: true
/serialize-query-params@2.0.2:
resolution: {integrity: sha512-1chMo1dST4pFA9RDXAtF0Rbjaut4is7bzFbI1Z26IuMub68pNCILku85aYmeFhvnY//BXUPUhoRMjYcsT93J/Q==}
dev: false
/serve-static@1.15.0:
resolution: {integrity: sha512-XGuRDNjXUijsUL0vl6nSD7cwURuzEgglbOaFuZM9g3kwDXOWVTck0jLzjPzGD+TazWbboZYu52/9/XPdUgne9g==}
engines: {node: '>= 0.8.0'}
@@ -7824,6 +7847,24 @@ packages:
use-isomorphic-layout-effect: 1.1.2(@types/react@18.2.6)(react@18.2.0)
dev: false
/use-query-params@2.2.1(react-dom@18.2.0)(react@18.2.0):
resolution: {integrity: sha512-i6alcyLB8w9i3ZK3caNftdb+UnbfBRNPDnc89CNQWkGRmDrm/gfydHvMBfVsQJRq3NoHOM2dt/ceBWG2397v1Q==}
peerDependencies:
'@reach/router': ^1.2.1
react: '>=16.8.0'
react-dom: '>=16.8.0'
react-router-dom: '>=5'
peerDependenciesMeta:
'@reach/router':
optional: true
react-router-dom:
optional: true
dependencies:
react: 18.2.0
react-dom: 18.2.0(react@18.2.0)
serialize-query-params: 2.0.2
dev: false
/use-sidecar@1.1.2(@types/react@18.2.6)(react@18.2.0):
resolution: {integrity: sha512-epTbsLuzZ7lPClpz2TyryBfztm7m+28DlEv2ZCQ3MDr5ssiwyOwGH/e5F9CkfWjJ1t4clvI58yF822/GUkjjhw==}
engines: {node: '>=10'}

View File

@@ -7,9 +7,13 @@ const defaultId = "11111111-1111-1111-1111-111111111111";
await prisma.organization.deleteMany({
where: { id: defaultId },
});
await prisma.organization.create({
data: { id: defaultId },
});
// If there's an existing org, just seed into it
const org =
(await prisma.organization.findFirst({})) ??
(await prisma.organization.create({
data: { id: defaultId },
}));
await prisma.experiment.deleteMany({
where: {
@@ -21,7 +25,7 @@ await prisma.experiment.create({
data: {
id: defaultId,
label: "Country Capitals Example",
organizationId: defaultId,
organizationId: org.id,
},
});
@@ -103,30 +107,41 @@ await prisma.testScenario.deleteMany({
},
});
const countries = [
"Afghanistan",
"Albania",
"Algeria",
"Andorra",
"Angola",
"Antigua and Barbuda",
"Argentina",
"Armenia",
"Australia",
"Austria",
"Austrian Empire",
"Azerbaijan",
"Baden",
"Bahamas, The",
"Bahrain",
"Bangladesh",
"Barbados",
"Bavaria",
"Belarus",
"Belgium",
"Belize",
"Benin (Dahomey)",
"Bolivia",
"Bosnia and Herzegovina",
"Botswana",
];
await prisma.testScenario.createMany({
data: [
{
experimentId: defaultId,
sortIndex: 0,
variableValues: {
country: "Spain",
},
},
{
experimentId: defaultId,
sortIndex: 1,
variableValues: {
country: "USA",
},
},
{
experimentId: defaultId,
sortIndex: 2,
variableValues: {
country: "Chile",
},
},
],
data: countries.map((country, i) => ({
experimentId: defaultId,
sortIndex: i,
variableValues: {
country: country,
},
})),
});
const variants = await prisma.promptVariant.findMany({

View File

@@ -1,19 +1,22 @@
import { Textarea, type TextareaProps } from "@chakra-ui/react";
import ResizeTextarea from "react-textarea-autosize";
import React from "react";
import React, { useLayoutEffect, useState } from "react";
export const AutoResizeTextarea: React.ForwardRefRenderFunction<
HTMLTextAreaElement,
TextareaProps & { minRows?: number }
> = (props, ref) => {
> = ({ minRows = 1, overflowY = "hidden", ...props }, ref) => {
const [isRerendered, setIsRerendered] = useState(false);
useLayoutEffect(() => setIsRerendered(true), []);
return (
<Textarea
minH="unset"
overflow="hidden"
minRows={minRows}
overflowY={isRerendered ? overflowY : "hidden"}
w="100%"
resize="none"
ref={ref}
minRows={1}
transition="height none"
as={ResizeTextarea}
{...props}

View File

@@ -0,0 +1,135 @@
import {
Button,
Modal,
ModalBody,
ModalCloseButton,
ModalContent,
ModalFooter,
ModalHeader,
ModalOverlay,
VStack,
Text,
Spinner,
HStack,
Icon,
} from "@chakra-ui/react";
import { RiExchangeFundsFill } from "react-icons/ri";
import { useState } from "react";
import { ModelStatsCard } from "./ModelStatsCard";
import { ModelSearch } from "./ModelSearch";
import { api } from "~/utils/api";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
import CompareFunctions from "../RefinePromptModal/CompareFunctions";
import { type PromptVariant } from "@prisma/client";
import { isObject, isString } from "lodash-es";
import { type Model, type SupportedProvider } from "~/modelProviders/types";
import frontendModelProviders from "~/modelProviders/frontendModelProviders";
import { keyForModel } from "~/utils/utils";
export const ChangeModelModal = ({
variant,
onClose,
}: {
variant: PromptVariant;
onClose: () => void;
}) => {
const originalModelProviderName = variant.modelProvider as SupportedProvider;
const originalModelProvider = frontendModelProviders[originalModelProviderName];
const originalModel = originalModelProvider.models[variant.model] as Model;
const [selectedModel, setSelectedModel] = useState<Model>(originalModel);
const [convertedModel, setConvertedModel] = useState<Model | undefined>(undefined);
const utils = api.useContext();
const experiment = useExperiment();
const { mutateAsync: getModifiedPromptMutateAsync, data: modifiedPromptFn } =
api.promptVariants.getModifiedPromptFn.useMutation();
const [getModifiedPromptFn, modificationInProgress] = useHandledAsyncCallback(async () => {
if (!experiment) return;
await getModifiedPromptMutateAsync({
id: variant.id,
newModel: selectedModel,
});
setConvertedModel(selectedModel);
}, [getModifiedPromptMutateAsync, onClose, experiment, variant, selectedModel]);
const replaceVariantMutation = api.promptVariants.replaceVariant.useMutation();
const [replaceVariant, replacementInProgress] = useHandledAsyncCallback(async () => {
if (
!variant.experimentId ||
!modifiedPromptFn ||
(isObject(modifiedPromptFn) && "status" in modifiedPromptFn)
)
return;
await replaceVariantMutation.mutateAsync({
id: variant.id,
constructFn: modifiedPromptFn,
});
await utils.promptVariants.list.invalidate();
onClose();
}, [replaceVariantMutation, variant, onClose, modifiedPromptFn]);
const originalModelLabel = keyForModel(originalModel);
const selectedModelLabel = keyForModel(selectedModel);
const convertedModelLabel = convertedModel ? keyForModel(convertedModel) : undefined;
return (
<Modal
isOpen
onClose={onClose}
size={{ base: "xl", sm: "2xl", md: "3xl", lg: "5xl", xl: "7xl" }}
>
<ModalOverlay />
<ModalContent w={1200}>
<ModalHeader>
<HStack>
<Icon as={RiExchangeFundsFill} />
<Text>Change Model</Text>
</HStack>
</ModalHeader>
<ModalCloseButton />
<ModalBody maxW="unset">
<VStack spacing={8}>
<ModelStatsCard label="Original Model" model={originalModel} />
{originalModelLabel !== selectedModelLabel && (
<ModelStatsCard label="New Model" model={selectedModel} />
)}
<ModelSearch selectedModel={selectedModel} setSelectedModel={setSelectedModel} />
{isString(modifiedPromptFn) && (
<CompareFunctions
originalFunction={variant.constructFn}
newFunction={modifiedPromptFn}
leftTitle={originalModelLabel}
rightTitle={convertedModelLabel}
/>
)}
</VStack>
</ModalBody>
<ModalFooter>
<HStack>
<Button
colorScheme="gray"
onClick={getModifiedPromptFn}
minW={24}
isDisabled={originalModel === selectedModel || modificationInProgress}
>
{modificationInProgress ? <Spinner boxSize={4} /> : <Text>Convert</Text>}
</Button>
<Button
colorScheme="blue"
onClick={replaceVariant}
minW={24}
isDisabled={!convertedModel || modificationInProgress || replacementInProgress}
>
{replacementInProgress ? <Spinner boxSize={4} /> : <Text>Accept</Text>}
</Button>
</HStack>
</ModalFooter>
</ModalContent>
</Modal>
);
};

View File

@@ -0,0 +1,50 @@
import { VStack, Text } from "@chakra-ui/react";
import { type LegacyRef, useCallback } from "react";
import Select, { type SingleValue } from "react-select";
import { useElementDimensions } from "~/utils/hooks";
import frontendModelProviders from "~/modelProviders/frontendModelProviders";
import { type Model } from "~/modelProviders/types";
import { keyForModel } from "~/utils/utils";
const modelOptions: { label: string; value: Model }[] = [];
for (const [_, providerValue] of Object.entries(frontendModelProviders)) {
for (const [_, modelValue] of Object.entries(providerValue.models)) {
modelOptions.push({
label: keyForModel(modelValue),
value: modelValue,
});
}
}
export const ModelSearch = ({
selectedModel,
setSelectedModel,
}: {
selectedModel: Model;
setSelectedModel: (model: Model) => void;
}) => {
const handleSelection = useCallback(
(option: SingleValue<{ label: string; value: Model }>) => {
if (!option) return;
setSelectedModel(option.value);
},
[setSelectedModel],
);
const selectedOption = modelOptions.find((option) => option.label === keyForModel(selectedModel));
const [containerRef, containerDimensions] = useElementDimensions();
return (
<VStack ref={containerRef as LegacyRef<HTMLDivElement>} w="full">
<Text>Browse Models</Text>
<Select
styles={{ control: (provided) => ({ ...provided, width: containerDimensions?.width }) }}
value={selectedOption}
options={modelOptions}
onChange={handleSelection}
/>
</VStack>
);
};

View File

@@ -7,11 +7,9 @@ import {
SimpleGrid,
Link,
} from "@chakra-ui/react";
import { modelStats } from "~/modelProviders/modelStats";
import { type SupportedModel } from "~/server/types";
import { type Model } from "~/modelProviders/types";
export const ModelStatsCard = ({ label, model }: { label: string; model: SupportedModel }) => {
const stats = modelStats[model];
export const ModelStatsCard = ({ label, model }: { label: string; model: Model }) => {
return (
<VStack w="full" align="start">
<Text fontWeight="bold" fontSize="sm" textTransform="uppercase">
@@ -22,14 +20,14 @@ export const ModelStatsCard = ({ label, model }: { label: string; model: Support
<HStack w="full" align="flex-start">
<Text flex={1} fontSize="lg">
<Text as="span" color="gray.600">
{stats.provider} /{" "}
{model.provider} /{" "}
</Text>
<Text as="span" fontWeight="bold" color="gray.900">
{model}
{model.name}
</Text>
</Text>
<Link
href={stats.learnMoreUrl}
href={model.learnMoreUrl}
isExternal
color="blue.500"
fontWeight="bold"
@@ -46,26 +44,41 @@ export const ModelStatsCard = ({ label, model }: { label: string; model: Support
fontSize="sm"
columns={{ base: 2, md: 4 }}
>
<SelectedModelLabeledInfo label="Context" info={stats.contextLength} />
<SelectedModelLabeledInfo
label="Input"
info={
<Text>
${(stats.promptTokenPrice * 1000).toFixed(3)}
<Text color="gray.500"> / 1K tokens</Text>
</Text>
}
/>
<SelectedModelLabeledInfo
label="Output"
info={
<Text>
${(stats.promptTokenPrice * 1000).toFixed(3)}
<Text color="gray.500"> / 1K tokens</Text>
</Text>
}
/>
<SelectedModelLabeledInfo label="Speed" info={<Text>{stats.speed}</Text>} />
<SelectedModelLabeledInfo label="Context Window" info={model.contextWindow} />
{model.promptTokenPrice && (
<SelectedModelLabeledInfo
label="Input"
info={
<Text>
${(model.promptTokenPrice * 1000).toFixed(3)}
<Text color="gray.500"> / 1K tokens</Text>
</Text>
}
/>
)}
{model.completionTokenPrice && (
<SelectedModelLabeledInfo
label="Output"
info={
<Text>
${(model.completionTokenPrice * 1000).toFixed(3)}
<Text color="gray.500"> / 1K tokens</Text>
</Text>
}
/>
)}
{model.pricePerSecond && (
<SelectedModelLabeledInfo
label="Price"
info={
<Text>
${model.pricePerSecond.toFixed(3)}
<Text color="gray.500"> / second</Text>
</Text>
}
/>
)}
<SelectedModelLabeledInfo label="Speed" info={<Text>{model.speed}</Text>} />
</SimpleGrid>
</VStack>
</VStack>

View File

@@ -0,0 +1,50 @@
import { Box, Flex, Icon, Spinner } from "@chakra-ui/react";
import { BsPlus } from "react-icons/bs";
import { Text } from "@chakra-ui/react";
import { api } from "~/utils/api";
import { useExperiment, useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
import { cellPadding } from "../constants";
import { ActionButton } from "./ScenariosHeader";
export default function AddVariantButton() {
const experiment = useExperiment();
const mutation = api.promptVariants.create.useMutation();
const utils = api.useContext();
const [onClick, loading] = useHandledAsyncCallback(async () => {
if (!experiment.data) return;
await mutation.mutateAsync({
experimentId: experiment.data.id,
});
await utils.promptVariants.list.invalidate();
}, [mutation]);
const { canModify } = useExperimentAccess();
if (!canModify) return <Box w={cellPadding.x} />;
return (
<Flex w="100%" justifyContent="flex-end">
<ActionButton
onClick={onClick}
py={5}
leftIcon={<Icon as={loading ? Spinner : BsPlus} boxSize={6} mr={loading ? 1 : 0} />}
>
<Text display={{ base: "none", md: "flex" }}>Add Variant</Text>
</ActionButton>
{/* <Button
alignItems="center"
justifyContent="center"
fontWeight="normal"
bgColor="transparent"
_hover={{ bgColor: "gray.100" }}
px={cellPadding.x}
onClick={onClick}
height="unset"
minH={headerMinHeight}
>
<Icon as={loading ? Spinner : BsPlus} boxSize={6} mr={loading ? 1 : 0} />
<Text display={{ base: "none", md: "flex" }}>Add Variant</Text>
</Button> */}
</Flex>
);
}

View File

@@ -18,11 +18,9 @@ export const FloatingLabelInput = ({
transform={isFocused || !!value ? "translateY(-50%)" : "translateY(0)"}
fontSize={isFocused || !!value ? "12px" : "16px"}
transition="all 0.15s"
zIndex="100"
zIndex="5"
bg="white"
px={1}
mt={0}
mb={2}
lineHeight="1"
pointerEvents="none"
color={isFocused ? "blue.500" : "gray.500"}

View File

@@ -1,57 +0,0 @@
import { Button, type ButtonProps, HStack, Spinner, Icon } from "@chakra-ui/react";
import { BsPlus } from "react-icons/bs";
import { api } from "~/utils/api";
import { useExperiment, useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
// Extracted Button styling into reusable component
const StyledButton = ({ children, onClick }: ButtonProps) => (
<Button
fontWeight="normal"
bgColor="transparent"
_hover={{ bgColor: "gray.100" }}
px={2}
onClick={onClick}
>
{children}
</Button>
);
export default function NewScenarioButton() {
const { canModify } = useExperimentAccess();
const experiment = useExperiment();
const mutation = api.scenarios.create.useMutation();
const utils = api.useContext();
const [onClick] = useHandledAsyncCallback(async () => {
if (!experiment.data) return;
await mutation.mutateAsync({
experimentId: experiment.data.id,
});
await utils.scenarios.list.invalidate();
}, [mutation]);
const [onAutogenerate, autogenerating] = useHandledAsyncCallback(async () => {
if (!experiment.data) return;
await mutation.mutateAsync({
experimentId: experiment.data.id,
autogenerate: true,
});
await utils.scenarios.list.invalidate();
}, [mutation]);
if (!canModify) return null;
return (
<HStack spacing={2}>
<StyledButton onClick={onClick}>
<Icon as={BsPlus} boxSize={6} />
Add Scenario
</StyledButton>
<StyledButton onClick={onAutogenerate}>
<Icon as={autogenerating ? Spinner : BsPlus} boxSize={6} mr={autogenerating ? 1 : 0} />
Autogenerate Scenario
</StyledButton>
</HStack>
);
}

View File

@@ -1,40 +0,0 @@
import { Box, Button, Icon, Spinner, Text } from "@chakra-ui/react";
import { BsPlus } from "react-icons/bs";
import { api } from "~/utils/api";
import { useExperiment, useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
import { cellPadding, headerMinHeight } from "../constants";
export default function NewVariantButton() {
const experiment = useExperiment();
const mutation = api.promptVariants.create.useMutation();
const utils = api.useContext();
const [onClick, loading] = useHandledAsyncCallback(async () => {
if (!experiment.data) return;
await mutation.mutateAsync({
experimentId: experiment.data.id,
});
await utils.promptVariants.list.invalidate();
}, [mutation]);
const { canModify } = useExperimentAccess();
if (!canModify) return <Box w={cellPadding.x} />;
return (
<Button
w="100%"
alignItems="center"
justifyContent="center"
fontWeight="normal"
bgColor="transparent"
_hover={{ bgColor: "gray.100" }}
px={cellPadding.x}
onClick={onClick}
height="unset"
minH={headerMinHeight}
>
<Icon as={loading ? Spinner : BsPlus} boxSize={6} mr={loading ? 1 : 0} />
<Text display={{ base: "none", md: "flex" }}>Add Variant</Text>
</Button>
);
}

View File

@@ -10,7 +10,7 @@ import useSocket from "~/utils/useSocket";
import { OutputStats } from "./OutputStats";
import { ErrorHandler } from "./ErrorHandler";
import { CellOptions } from "./CellOptions";
import modelProvidersFrontend from "~/modelProviders/modelProvidersFrontend";
import frontendModelProviders from "~/modelProviders/frontendModelProviders";
export default function OutputCell({
scenario,
@@ -40,7 +40,7 @@ export default function OutputCell({
);
const provider =
modelProvidersFrontend[variant.modelProvider as keyof typeof modelProvidersFrontend];
frontendModelProviders[variant.modelProvider as keyof typeof frontendModelProviders];
type OutputSchema = Parameters<typeof provider.normalizeOutput>[0];
@@ -88,11 +88,9 @@ export default function OutputCell({
}
const normalizedOutput = modelOutput
? // @ts-expect-error TODO FIX ASAP
provider.normalizeOutput(modelOutput.output as unknown as OutputSchema)
? provider.normalizeOutput(modelOutput.output)
: streamedMessage
? // @ts-expect-error TODO FIX ASAP
provider.normalizeOutput(streamedMessage)
? provider.normalizeOutput(streamedMessage)
: null;
if (modelOutput && normalizedOutput?.type === "json") {

View File

@@ -0,0 +1,74 @@
import { Box, HStack, IconButton } from "@chakra-ui/react";
import {
BsChevronDoubleLeft,
BsChevronDoubleRight,
BsChevronLeft,
BsChevronRight,
} from "react-icons/bs";
import { usePage, useScenarios } from "~/utils/hooks";
const ScenarioPaginator = () => {
const [page, setPage] = usePage();
const { data } = useScenarios();
if (!data) return null;
const { scenarios, startIndex, lastPage, count } = data;
const nextPage = () => {
if (page < lastPage) {
setPage(page + 1, "replace");
}
};
const prevPage = () => {
if (page > 1) {
setPage(page - 1, "replace");
}
};
const goToLastPage = () => setPage(lastPage, "replace");
const goToFirstPage = () => setPage(1, "replace");
return (
<HStack pt={4}>
<IconButton
variant="ghost"
size="sm"
onClick={goToFirstPage}
isDisabled={page === 1}
aria-label="Go to first page"
icon={<BsChevronDoubleLeft />}
/>
<IconButton
variant="ghost"
size="sm"
onClick={prevPage}
isDisabled={page === 1}
aria-label="Previous page"
icon={<BsChevronLeft />}
/>
<Box>
{startIndex}-{startIndex + scenarios.length - 1} / {count}
</Box>
<IconButton
variant="ghost"
size="sm"
onClick={nextPage}
isDisabled={page === lastPage}
aria-label="Next page"
icon={<BsChevronRight />}
/>
<IconButton
variant="ghost"
size="sm"
onClick={goToLastPage}
isDisabled={page === lastPage}
aria-label="Go to last page"
icon={<BsChevronDoubleRight />}
/>
</HStack>
);
};
export default ScenarioPaginator;

View File

@@ -4,11 +4,13 @@ import { cellPadding } from "../constants";
import OutputCell from "./OutputCell/OutputCell";
import ScenarioEditor from "./ScenarioEditor";
import type { PromptVariant, Scenario } from "./types";
import { borders } from "./styles";
const ScenarioRow = (props: {
scenario: Scenario;
variants: PromptVariant[];
canHide: boolean;
rowStart: number;
}) => {
const [isHovered, setIsHovered] = useState(false);
@@ -21,15 +23,21 @@ const ScenarioRow = (props: {
onMouseLeave={() => setIsHovered(false)}
sx={isHovered ? highlightStyle : undefined}
borderLeftWidth={1}
{...borders}
rowStart={props.rowStart}
colStart={1}
>
<ScenarioEditor scenario={props.scenario} hovered={isHovered} canHide={props.canHide} />
</GridItem>
{props.variants.map((variant) => (
{props.variants.map((variant, i) => (
<GridItem
key={variant.id}
onMouseEnter={() => setIsHovered(true)}
onMouseLeave={() => setIsHovered(false)}
sx={isHovered ? highlightStyle : undefined}
rowStart={props.rowStart}
colStart={i + 2}
{...borders}
>
<Box h="100%" w="100%" px={cellPadding.x} py={cellPadding.y}>
<OutputCell key={variant.id} scenario={props.scenario} variant={variant} />

View File

@@ -1,52 +1,82 @@
import { Button, GridItem, HStack, Heading } from "@chakra-ui/react";
import {
Button,
type ButtonProps,
HStack,
Text,
Icon,
Menu,
MenuButton,
MenuList,
MenuItem,
IconButton,
Spinner,
} from "@chakra-ui/react";
import { cellPadding } from "../constants";
import { useElementDimensions, useExperimentAccess } from "~/utils/hooks";
import { stickyHeaderStyle } from "./styles";
import { BsPencil } from "react-icons/bs";
import {
useExperiment,
useExperimentAccess,
useHandledAsyncCallback,
useScenarios,
} from "~/utils/hooks";
import { BsGear, BsPencil, BsPlus, BsStars } from "react-icons/bs";
import { useAppStore } from "~/state/store";
import { api } from "~/utils/api";
export const ScenariosHeader = ({
headerRows,
numScenarios,
}: {
headerRows: number;
numScenarios: number;
}) => {
export const ActionButton = (props: ButtonProps) => (
<Button size="sm" variant="ghost" color="gray.600" {...props} />
);
export const ScenariosHeader = () => {
const openDrawer = useAppStore((s) => s.openDrawer);
const { canModify } = useExperimentAccess();
const scenarios = useScenarios();
const [ref, dimensions] = useElementDimensions();
const topValue = dimensions ? `-${dimensions.height - 24}px` : "-455px";
const experiment = useExperiment();
const createScenarioMutation = api.scenarios.create.useMutation();
const utils = api.useContext();
const [onAddScenario, loading] = useHandledAsyncCallback(
async (autogenerate: boolean) => {
if (!experiment.data) return;
await createScenarioMutation.mutateAsync({
experimentId: experiment.data.id,
autogenerate,
});
await utils.scenarios.list.invalidate();
},
[createScenarioMutation],
);
return (
<GridItem
// eslint-disable-next-line @typescript-eslint/no-explicit-any
ref={ref as any}
display="flex"
alignItems="flex-end"
rowSpan={headerRows}
px={cellPadding.x}
py={cellPadding.y}
// Only display the part of the grid item that has content
sx={{ ...stickyHeaderStyle, top: topValue }}
>
<HStack w="100%">
<Heading size="xs" fontWeight="bold" flex={1}>
Scenarios ({numScenarios})
</Heading>
{canModify && (
<Button
size="xs"
variant="ghost"
color="gray.500"
aria-label="Edit"
leftIcon={<BsPencil />}
onClick={openDrawer}
>
Edit Vars
</Button>
)}
</HStack>
</GridItem>
<HStack w="100%" pb={cellPadding.y} pt={0} align="center" spacing={0}>
<Text fontSize={16} fontWeight="bold">
Scenarios ({scenarios.data?.count})
</Text>
{canModify && (
<Menu>
<MenuButton mt={1}>
<IconButton
variant="ghost"
aria-label="Edit Scenarios"
icon={<Icon as={loading ? Spinner : BsGear} />}
/>
</MenuButton>
<MenuList fontSize="md" zIndex="dropdown" mt={-3}>
<MenuItem
icon={<Icon as={BsPlus} boxSize={6} mx={-1} />}
onClick={() => onAddScenario(false)}
>
Add Scenario
</MenuItem>
<MenuItem icon={<BsStars />} onClick={() => onAddScenario(true)}>
Autogenerate Scenario
</MenuItem>
<MenuItem icon={<BsPencil />} onClick={openDrawer}>
Edit Vars
</MenuItem>
</MenuList>
</Menu>
)}
</HStack>
);
};

View File

@@ -1,17 +1,47 @@
import { Box, Button, HStack, Spinner, Tooltip, useToast, Text } from "@chakra-ui/react";
import {
Box,
Button,
HStack,
Spinner,
Tooltip,
useToast,
Text,
IconButton,
} from "@chakra-ui/react";
import { useRef, useEffect, useState, useCallback } from "react";
import { useExperimentAccess, useHandledAsyncCallback, useModifierKeyLabel } from "~/utils/hooks";
import { type PromptVariant } from "./types";
import { api } from "~/utils/api";
import { useAppStore } from "~/state/store";
import { FiMaximize, FiMinimize } from "react-icons/fi";
import { editorBackground } from "~/state/sharedVariantEditor.slice";
export default function VariantEditor(props: { variant: PromptVariant }) {
const { canModify } = useExperimentAccess();
const monaco = useAppStore.use.sharedVariantEditor.monaco();
const editorRef = useRef<ReturnType<NonNullable<typeof monaco>["editor"]["create"]> | null>(null);
const containerRef = useRef<HTMLDivElement | null>(null);
const [editorId] = useState(() => `editor_${Math.random().toString(36).substring(7)}`);
const [isChanged, setIsChanged] = useState(false);
const [isFullscreen, setIsFullscreen] = useState(false);
const toggleFullscreen = useCallback(() => {
setIsFullscreen((prev) => !prev);
editorRef.current?.focus();
}, [setIsFullscreen]);
useEffect(() => {
const handleEsc = (event: KeyboardEvent) => {
if (event.key === "Escape" && isFullscreen) {
toggleFullscreen();
}
};
window.addEventListener("keydown", handleEsc);
return () => window.removeEventListener("keydown", handleEsc);
}, [isFullscreen, toggleFullscreen]);
const lastSavedFn = props.variant.constructFn;
const modifierKey = useModifierKeyLabel();
@@ -99,11 +129,23 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
readOnly: !canModify,
});
// Workaround because otherwise the commands only work on whatever
// editor was loaded on the page last.
// https://github.com/microsoft/monaco-editor/issues/2947#issuecomment-1422265201
editorRef.current.onDidFocusEditorText(() => {
// Workaround because otherwise the command only works on whatever
// editor was loaded on the page last.
// https://github.com/microsoft/monaco-editor/issues/2947#issuecomment-1422265201
editorRef.current?.addCommand(monaco.KeyMod.CtrlCmd | monaco.KeyCode.Enter, onSave);
editorRef.current?.addCommand(monaco.KeyMod.CtrlCmd | monaco.KeyCode.KeyS, onSave);
editorRef.current?.addCommand(
monaco.KeyMod.CtrlCmd | monaco.KeyMod.Shift | monaco.KeyCode.KeyF,
toggleFullscreen,
);
// Exit fullscreen with escape
editorRef.current?.addCommand(monaco.KeyCode.Escape, () => {
if (isFullscreen) {
toggleFullscreen();
}
});
});
editorRef.current.onDidChangeModelContent(checkForChanges);
@@ -132,8 +174,40 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
}, [canModify]);
return (
<Box w="100%" pos="relative">
<div id={editorId} style={{ height: "400px", width: "100%" }}></div>
<Box
w="100%"
ref={containerRef}
sx={
isFullscreen
? {
position: "fixed",
top: 0,
left: 0,
right: 0,
bottom: 0,
}
: { h: "400px", w: "100%" }
}
bgColor={editorBackground}
zIndex={isFullscreen ? 1000 : "unset"}
pos="relative"
_hover={{ ".fullscreen-toggle": { opacity: 1 } }}
>
<Box id={editorId} w="100%" h="100%" />
<Tooltip label={`${modifierKey} + ⇧ + F`}>
<IconButton
className="fullscreen-toggle"
aria-label="Minimize"
icon={isFullscreen ? <FiMinimize /> : <FiMaximize />}
position="absolute"
top={2}
right={2}
onClick={toggleFullscreen}
opacity={0}
transition="opacity 0.2s"
/>
</Tooltip>
{isChanged && (
<HStack pos="absolute" bottom={2} right={2}>
<Button
@@ -146,7 +220,7 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
>
Reset
</Button>
<Tooltip label={`${modifierKey} + Enter`}>
<Tooltip label={`${modifierKey} + S`}>
<Button size="sm" onClick={onSave} colorScheme="blue" w={16} disabled={saveInProgress}>
{saveInProgress ? <Spinner boxSize={4} /> : <Text>Save</Text>}
</Button>

View File

@@ -1,13 +1,14 @@
import { Grid, GridItem } from "@chakra-ui/react";
import { Grid, GridItem, type GridItemProps } from "@chakra-ui/react";
import { api } from "~/utils/api";
import NewScenarioButton from "./NewScenarioButton";
import NewVariantButton from "./NewVariantButton";
import AddVariantButton from "./AddVariantButton";
import ScenarioRow from "./ScenarioRow";
import VariantEditor from "./VariantEditor";
import VariantHeader from "../VariantHeader/VariantHeader";
import VariantStats from "./VariantStats";
import { ScenariosHeader } from "./ScenariosHeader";
import { stickyHeaderStyle } from "./styles";
import { borders } from "./styles";
import { useScenarios } from "~/utils/hooks";
import ScenarioPaginator from "./ScenarioPaginator";
export default function OutputsTable({ experimentId }: { experimentId: string | undefined }) {
const variants = api.promptVariants.list.useQuery(
@@ -15,68 +16,91 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
{ enabled: !!experimentId },
);
const scenarios = api.scenarios.list.useQuery(
{ experimentId: experimentId as string },
{ enabled: !!experimentId },
);
const scenarios = useScenarios();
if (!variants.data || !scenarios.data) return null;
const allCols = variants.data.length + 1;
const headerRows = 3;
const allCols = variants.data.length + 2;
const variantHeaderRows = 3;
const scenarioHeaderRows = 1;
const scenarioFooterRows = 1;
const visibleScenariosCount = scenarios.data.scenarios.length;
const allRows =
variantHeaderRows + scenarioHeaderRows + visibleScenariosCount + scenarioFooterRows;
return (
<Grid
p={4}
pt={4}
pb={24}
pl={4}
display="grid"
gridTemplateColumns={`250px repeat(${variants.data.length}, minmax(300px, 1fr)) auto`}
sx={{
"> *": {
borderColor: "gray.300",
borderBottomWidth: 1,
borderRightWidth: 1,
},
}}
fontSize="sm"
>
<ScenariosHeader headerRows={headerRows} numScenarios={scenarios.data.length} />
{variants.data.map((variant) => (
<VariantHeader key={variant.uiId} variant={variant} canHide={variants.data.length > 1} />
))}
<GridItem
rowSpan={scenarios.data.length + headerRows}
padding={0}
// Have to use `style` instead of emotion style props to work around css specificity issues conflicting with the "> *" selector on Grid
style={{ borderRightWidth: 0, borderBottomWidth: 0 }}
h={8}
sx={stickyHeaderStyle}
>
<NewVariantButton />
<GridItem rowSpan={variantHeaderRows}>
<AddVariantButton />
</GridItem>
{variants.data.map((variant) => (
<GridItem key={variant.uiId}>
<VariantEditor variant={variant} />
</GridItem>
))}
{variants.data.map((variant) => (
<GridItem key={variant.uiId}>
<VariantStats variant={variant} />
</GridItem>
))}
{scenarios.data.map((scenario) => (
{variants.data.map((variant, i) => {
const sharedProps: GridItemProps = {
...borders,
colStart: i + 2,
borderLeftWidth: i === 0 ? 1 : 0,
marginLeft: i === 0 ? "-1px" : 0,
};
return (
<>
<VariantHeader
key={variant.uiId}
variant={variant}
canHide={variants.data.length > 1}
rowStart={1}
{...sharedProps}
/>
<GridItem rowStart={2} {...sharedProps}>
<VariantEditor variant={variant} />
</GridItem>
<GridItem rowStart={3} {...sharedProps}>
<VariantStats variant={variant} />
</GridItem>
</>
);
})}
<GridItem
colSpan={allCols - 1}
rowStart={variantHeaderRows + 1}
colStart={1}
{...borders}
borderRightWidth={0}
>
<ScenariosHeader />
</GridItem>
{scenarios.data.scenarios.map((scenario, i) => (
<ScenarioRow
rowStart={i + variantHeaderRows + scenarioHeaderRows + 2}
key={scenario.uiId}
scenario={scenario}
variants={variants.data}
canHide={scenarios.data.length > 1}
canHide={visibleScenariosCount > 1}
/>
))}
<GridItem borderBottomWidth={0} borderRightWidth={0} w="100%" colSpan={allCols} padding={0}>
<NewScenarioButton />
<GridItem
rowStart={variantHeaderRows + scenarioHeaderRows + visibleScenariosCount + 2}
colStart={1}
colSpan={allCols}
>
<ScenarioPaginator />
</GridItem>
{/* Add some extra padding on the right, because when the table is too wide to fit in the viewport `pr` on the Grid isn't respected. */}
<GridItem rowStart={1} colStart={allCols} rowSpan={allRows} w={4} borderBottomWidth={0} />
</Grid>
);
}

View File

@@ -1,8 +1,13 @@
import { type SystemStyleObject } from "@chakra-ui/react";
import { type GridItemProps, type SystemStyleObject } from "@chakra-ui/react";
export const stickyHeaderStyle: SystemStyleObject = {
position: "sticky",
top: "0",
backgroundColor: "#fff",
zIndex: 1,
zIndex: 10,
};
export const borders: GridItemProps = {
borderRightWidth: 1,
borderBottomWidth: 1,
};

View File

@@ -2,4 +2,4 @@ import { type RouterOutputs } from "~/utils/api";
export type PromptVariant = NonNullable<RouterOutputs["promptVariants"]["list"]>[0];
export type Scenario = NonNullable<RouterOutputs["scenarios"]["list"]>[0];
export type Scenario = NonNullable<RouterOutputs["scenarios"]["list"]>["scenarios"][0];

View File

@@ -1,4 +1,4 @@
import { HStack, VStack, useBreakpointValue } from "@chakra-ui/react";
import { type StackProps, VStack, useBreakpointValue } from "@chakra-ui/react";
import React from "react";
import DiffViewer, { DiffMethod } from "react-diff-viewer";
import Prism from "prismjs";
@@ -19,10 +19,15 @@ const highlightSyntax = (str: string) => {
const CompareFunctions = ({
originalFunction,
newFunction = "",
leftTitle = "Original",
rightTitle = "Modified",
...props
}: {
originalFunction: string;
newFunction?: string;
}) => {
leftTitle?: string;
rightTitle?: string;
} & StackProps) => {
const showSplitView = useBreakpointValue(
{
base: false,
@@ -34,22 +39,20 @@ const CompareFunctions = ({
);
return (
<HStack w="full" spacing={5}>
<VStack w="full" spacing={4} maxH="40vh" fontSize={12} lineHeight={1} overflowY="auto">
<DiffViewer
oldValue={originalFunction}
newValue={newFunction || originalFunction}
splitView={showSplitView}
hideLineNumbers={!showSplitView}
leftTitle="Original"
rightTitle={newFunction ? "Modified" : "Unmodified"}
disableWordDiff={true}
compareMethod={DiffMethod.CHARS}
renderContent={highlightSyntax}
showDiffOnly={false}
/>
</VStack>
</HStack>
<VStack w="full" spacing={4} fontSize={12} lineHeight={1} overflowY="auto" {...props}>
<DiffViewer
oldValue={originalFunction}
newValue={newFunction || originalFunction}
splitView={showSplitView}
hideLineNumbers={!showSplitView}
leftTitle={leftTitle}
rightTitle={rightTitle}
disableWordDiff={true}
compareMethod={DiffMethod.CHARS}
renderContent={highlightSyntax}
showDiffOnly={false}
/>
</VStack>
);
};

View File

@@ -56,7 +56,6 @@ export const CustomInstructionsInput = ({
minW="unset"
size="sm"
onClick={() => onSubmit()}
disabled={!instructions}
variant={instructions ? "solid" : "ghost"}
mr={4}
borderRadius="8"

View File

@@ -1,22 +1,22 @@
import { HStack, Icon, Heading, Text, VStack, GridItem } from "@chakra-ui/react";
import { type IconType } from "react-icons";
import { refineOptions, type RefineOptionLabel } from "./refineOptions";
export const RefineOption = ({
label,
activeLabel,
icon,
desciption,
activeLabel,
onClick,
loading,
}: {
label: RefineOptionLabel;
activeLabel: RefineOptionLabel | undefined;
label: string;
icon: IconType;
onClick: (label: RefineOptionLabel) => void;
desciption: string;
activeLabel: string | undefined;
onClick: (label: string) => void;
loading: boolean;
}) => {
const isActive = activeLabel === label;
const desciption = refineOptions[label].description;
return (
<GridItem w="80" h="44">

View File

@@ -15,17 +15,16 @@ import {
SimpleGrid,
} from "@chakra-ui/react";
import { BsStars } from "react-icons/bs";
import { VscJson } from "react-icons/vsc";
import { TfiThought } from "react-icons/tfi";
import { api } from "~/utils/api";
import { useHandledAsyncCallback } from "~/utils/hooks";
import { type PromptVariant } from "@prisma/client";
import { useState } from "react";
import CompareFunctions from "./CompareFunctions";
import { CustomInstructionsInput } from "./CustomInstructionsInput";
import { type RefineOptionLabel, refineOptions } from "./refineOptions";
import { type RefineOptionInfo, refineOptions } from "./refineOptions";
import { RefineOption } from "./RefineOption";
import { isObject, isString } from "lodash-es";
import { type SupportedProvider } from "~/modelProviders/types";
export const RefinePromptModal = ({
variant,
@@ -36,25 +35,29 @@ export const RefinePromptModal = ({
}) => {
const utils = api.useContext();
const { mutateAsync: getRefinedPromptMutateAsync, data: refinedPromptFn } =
api.promptVariants.getRefinedPromptFn.useMutation();
const providerRefineOptions = refineOptions[variant.modelProvider as SupportedProvider];
const { mutateAsync: getModifiedPromptMutateAsync, data: refinedPromptFn } =
api.promptVariants.getModifiedPromptFn.useMutation();
const [instructions, setInstructions] = useState<string>("");
const [activeRefineOptionLabel, setActiveRefineOptionLabel] = useState<
RefineOptionLabel | undefined
>(undefined);
const [activeRefineOptionLabel, setActiveRefineOptionLabel] = useState<string | undefined>(
undefined,
);
const [getRefinedPromptFn, refiningInProgress] = useHandledAsyncCallback(
async (label?: RefineOptionLabel) => {
const [getModifiedPromptFn, modificationInProgress] = useHandledAsyncCallback(
async (label?: string) => {
if (!variant.experimentId) return;
const updatedInstructions = label ? refineOptions[label].instructions : instructions;
const updatedInstructions = label
? (providerRefineOptions[label] as RefineOptionInfo).instructions
: instructions;
setActiveRefineOptionLabel(label);
await getRefinedPromptMutateAsync({
await getModifiedPromptMutateAsync({
id: variant.id,
instructions: updatedInstructions,
});
},
[getRefinedPromptMutateAsync, onClose, variant, instructions, setActiveRefineOptionLabel],
[getModifiedPromptMutateAsync, onClose, variant, instructions, setActiveRefineOptionLabel],
);
const replaceVariantMutation = api.promptVariants.replaceVariant.useMutation();
@@ -75,7 +78,11 @@ export const RefinePromptModal = ({
}, [replaceVariantMutation, variant, onClose, refinedPromptFn]);
return (
<Modal isOpen onClose={onClose} size={{ base: "xl", sm: "2xl", md: "7xl" }}>
<Modal
isOpen
onClose={onClose}
size={{ base: "xl", sm: "2xl", md: "3xl", lg: "5xl", xl: "7xl" }}
>
<ModalOverlay />
<ModalContent w={1200}>
<ModalHeader>
@@ -88,35 +95,37 @@ export const RefinePromptModal = ({
<ModalBody maxW="unset">
<VStack spacing={8}>
<VStack spacing={4}>
<SimpleGrid columns={{ base: 1, md: 2 }} spacing={8}>
<RefineOption
label="Convert to function call"
activeLabel={activeRefineOptionLabel}
icon={VscJson}
onClick={getRefinedPromptFn}
loading={refiningInProgress}
/>
<RefineOption
label="Add chain of thought"
activeLabel={activeRefineOptionLabel}
icon={TfiThought}
onClick={getRefinedPromptFn}
loading={refiningInProgress}
/>
</SimpleGrid>
<HStack>
<Text color="gray.500">or</Text>
</HStack>
{Object.keys(providerRefineOptions).length && (
<>
<SimpleGrid columns={{ base: 1, md: 2 }} spacing={8}>
{Object.keys(providerRefineOptions).map((label) => (
<RefineOption
key={label}
label={label}
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
icon={providerRefineOptions[label]!.icon}
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
desciption={providerRefineOptions[label]!.description}
activeLabel={activeRefineOptionLabel}
onClick={getModifiedPromptFn}
loading={modificationInProgress}
/>
))}
</SimpleGrid>
<Text color="gray.500">or</Text>
</>
)}
<CustomInstructionsInput
instructions={instructions}
setInstructions={setInstructions}
loading={refiningInProgress}
onSubmit={getRefinedPromptFn}
loading={modificationInProgress}
onSubmit={getModifiedPromptFn}
/>
</VStack>
<CompareFunctions
originalFunction={variant.constructFn}
newFunction={isString(refinedPromptFn) ? refinedPromptFn : undefined}
maxH="40vh"
/>
</VStack>
</ModalBody>
@@ -124,12 +133,10 @@ export const RefinePromptModal = ({
<ModalFooter>
<HStack spacing={4}>
<Button
colorScheme="blue"
onClick={replaceVariant}
minW={24}
disabled={replacementInProgress || !refinedPromptFn}
_disabled={{
bgColor: "blue.500",
}}
isDisabled={replacementInProgress || !refinedPromptFn}
>
{replacementInProgress ? <Spinner boxSize={4} /> : <Text>Accept</Text>}
</Button>

View File

@@ -1,18 +1,22 @@
// Super hacky, but we'll redo the organization when we have more models
export type RefineOptionLabel = "Add chain of thought" | "Convert to function call";
import { type SupportedProvider } from "~/modelProviders/types";
import { VscJson } from "react-icons/vsc";
import { TfiThought } from "react-icons/tfi";
import { type IconType } from "react-icons";
export const refineOptions: Record<
RefineOptionLabel,
{ description: string; instructions: string }
> = {
"Add chain of thought": {
description: "Asking the model to plan its answer can increase accuracy.",
instructions: `Adding chain of thought means asking the model to think about its answer before it gives it to you. This is useful for getting more accurate answers. Do not add an assistant message.
export type RefineOptionInfo = { icon: IconType; description: string; instructions: string };
export const refineOptions: Record<SupportedProvider, { [key: string]: RefineOptionInfo }> = {
"openai/ChatCompletion": {
"Add chain of thought": {
icon: VscJson,
description: "Asking the model to plan its answer can increase accuracy.",
instructions: `Adding chain of thought means asking the model to think about its answer before it gives it to you. This is useful for getting more accurate answers. Do not add an assistant message.
This is what a prompt looks like before adding chain of thought:
prompt = {
definePrompt("openai/ChatCompletion", {
model: "gpt-4",
stream: true,
messages: [
@@ -25,11 +29,11 @@ export const refineOptions: Record<
content: \`This is the user's message: \${scenario.user_message}. Return "positive" or "negative" or "neutral"\`,
},
],
};
});
This is what one looks like after adding chain of thought:
prompt = {
definePrompt("openai/ChatCompletion", {
model: "gpt-4",
stream: true,
messages: [
@@ -42,13 +46,13 @@ export const refineOptions: Record<
content: \`This is the user's message: \${scenario.user_message}. Return "positive" or "negative" or "neutral". Explain your answer before you give a score, then return the score on a new line.\`,
},
],
};
});
Here's another example:
Before:
prompt = {
definePrompt("openai/ChatCompletion", {
model: "gpt-3.5-turbo",
messages: [
{
@@ -78,11 +82,11 @@ export const refineOptions: Record<
function_call: {
name: "score_post",
},
};
});
After:
prompt = {
definePrompt("openai/ChatCompletion", {
model: "gpt-3.5-turbo",
messages: [
{
@@ -115,17 +119,18 @@ export const refineOptions: Record<
function_call: {
name: "score_post",
},
};
});
Add chain of thought to the original prompt.`,
},
"Convert to function call": {
description: "Use function calls to get output from the model in a more structured way.",
instructions: `OpenAI functions are a specialized way for an LLM to return output.
},
"Convert to function call": {
icon: TfiThought,
description: "Use function calls to get output from the model in a more structured way.",
instructions: `OpenAI functions are a specialized way for an LLM to return output.
This is what a prompt looks like before adding a function:
prompt = {
definePrompt("openai/ChatCompletion", {
model: "gpt-4",
stream: true,
messages: [
@@ -138,11 +143,11 @@ export const refineOptions: Record<
content: \`This is the user's message: \${scenario.user_message}. Return "positive" or "negative" or "neutral"\`,
},
],
};
});
This is what one looks like after adding a function:
prompt = {
definePrompt("openai/ChatCompletion", {
model: "gpt-4",
stream: true,
messages: [
@@ -172,13 +177,13 @@ export const refineOptions: Record<
function_call: {
name: "extract_sentiment",
},
};
});
Here's another example of adding a function:
Before:
prompt = {
definePrompt("openai/ChatCompletion", {
model: "gpt-3.5-turbo",
messages: [
{
@@ -196,11 +201,11 @@ export const refineOptions: Record<
},
],
temperature: 0,
};
});
After:
prompt = {
definePrompt("openai/ChatCompletion", {
model: "gpt-3.5-turbo",
messages: [
{
@@ -230,8 +235,53 @@ export const refineOptions: Record<
function_call: {
name: "score_post",
},
};
});
Another example
Before:
definePrompt("openai/ChatCompletion", {
model: "gpt-3.5-turbo",
stream: true,
messages: [
{
role: "system",
content: \`Write 'Start experimenting!' in \${scenario.language}\`,
},
],
});
After:
definePrompt("openai/ChatCompletion", {
model: "gpt-3.5-turbo",
messages: [
{
role: "system",
content: \`Write 'Start experimenting!' in \${scenario.language}\`,
},
],
functions: [
{
name: "write_in_language",
parameters: {
type: "object",
properties: {
text: {
type: "string",
},
},
},
},
],
function_call: {
name: "write_in_language",
},
});
Add an OpenAI function that takes one or more nested parameters that match the expected output from this prompt.`,
},
},
"replicate/llama2": {},
};

View File

@@ -1,85 +0,0 @@
import {
Button,
Modal,
ModalBody,
ModalCloseButton,
ModalContent,
ModalFooter,
ModalHeader,
ModalOverlay,
VStack,
Text,
Spinner,
HStack,
Icon,
} from "@chakra-ui/react";
import { RiExchangeFundsFill } from "react-icons/ri";
import { useState } from "react";
import { type SupportedModel } from "~/server/types";
import { ModelStatsCard } from "./ModelStatsCard";
import { SelectModelSearch } from "./SelectModelSearch";
import { api } from "~/utils/api";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
export const SelectModelModal = ({
originalModel,
variantId,
onClose,
}: {
originalModel: SupportedModel;
variantId: string;
onClose: () => void;
}) => {
const [selectedModel, setSelectedModel] = useState<SupportedModel>(originalModel);
const utils = api.useContext();
const experiment = useExperiment();
const createMutation = api.promptVariants.create.useMutation();
const [createNewVariant, creationInProgress] = useHandledAsyncCallback(async () => {
if (!experiment?.data?.id) return;
await createMutation.mutateAsync({
experimentId: experiment?.data?.id,
variantId,
newModel: selectedModel,
});
await utils.promptVariants.list.invalidate();
onClose();
}, [createMutation, experiment?.data?.id, variantId, onClose]);
return (
<Modal isOpen onClose={onClose} size={{ base: "xl", sm: "2xl", md: "3xl" }}>
<ModalOverlay />
<ModalContent w={1200}>
<ModalHeader>
<HStack>
<Icon as={RiExchangeFundsFill} />
<Text>Change Model</Text>
</HStack>
</ModalHeader>
<ModalCloseButton />
<ModalBody maxW="unset">
<VStack spacing={8}>
<ModelStatsCard label="Original Model" model={originalModel} />
{originalModel !== selectedModel && (
<ModelStatsCard label="New Model" model={selectedModel} />
)}
<SelectModelSearch selectedModel={selectedModel} setSelectedModel={setSelectedModel} />
</VStack>
</ModalBody>
<ModalFooter>
<Button
colorScheme="blue"
onClick={createNewVariant}
minW={24}
disabled={originalModel === selectedModel}
>
{creationInProgress ? <Spinner boxSize={4} /> : <Text>Continue</Text>}
</Button>
</ModalFooter>
</ModalContent>
</Modal>
);
};

View File

@@ -1,47 +0,0 @@
import { VStack, Text } from "@chakra-ui/react";
import { type LegacyRef, useCallback } from "react";
import Select, { type SingleValue } from "react-select";
import { type SupportedModel } from "~/server/types";
import { useElementDimensions } from "~/utils/hooks";
const modelOptions: { value: SupportedModel; label: string }[] = [
{ value: "gpt-3.5-turbo", label: "gpt-3.5-turbo" },
{ value: "gpt-3.5-turbo-0613", label: "gpt-3.5-turbo-0613" },
{ value: "gpt-3.5-turbo-16k", label: "gpt-3.5-turbo-16k" },
{ value: "gpt-3.5-turbo-16k-0613", label: "gpt-3.5-turbo-16k-0613" },
{ value: "gpt-4", label: "gpt-4" },
{ value: "gpt-4-0613", label: "gpt-4-0613" },
{ value: "gpt-4-32k", label: "gpt-4-32k" },
{ value: "gpt-4-32k-0613", label: "gpt-4-32k-0613" },
];
export const SelectModelSearch = ({
selectedModel,
setSelectedModel,
}: {
selectedModel: SupportedModel;
setSelectedModel: (model: SupportedModel) => void;
}) => {
const handleSelection = useCallback(
(option: SingleValue<{ value: SupportedModel; label: string }>) => {
if (!option) return;
setSelectedModel(option.value);
},
[setSelectedModel],
);
const selectedOption = modelOptions.find((option) => option.value === selectedModel);
const [containerRef, containerDimensions] = useElementDimensions();
return (
<VStack ref={containerRef as LegacyRef<HTMLDivElement>} w="full">
<Text>Browse Models</Text>
<Select
styles={{ control: (provided) => ({ ...provided, width: containerDimensions?.width }) }}
value={selectedOption}
options={modelOptions}
onChange={handleSelection}
/>
</VStack>
);
};

View File

@@ -3,28 +3,34 @@ import { type PromptVariant } from "../OutputsTable/types";
import { api } from "~/utils/api";
import { RiDraggable } from "react-icons/ri";
import { useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
import { HStack, Icon, Text, GridItem } from "@chakra-ui/react"; // Changed here
import { HStack, Icon, Text, GridItem, type GridItemProps } from "@chakra-ui/react"; // Changed here
import { cellPadding, headerMinHeight } from "../constants";
import AutoResizeTextArea from "../AutoResizeTextArea";
import { stickyHeaderStyle } from "../OutputsTable/styles";
import VariantHeaderMenuButton from "./VariantHeaderMenuButton";
export default function VariantHeader(props: { variant: PromptVariant; canHide: boolean }) {
export default function VariantHeader(
allProps: {
variant: PromptVariant;
canHide: boolean;
} & GridItemProps,
) {
const { variant, canHide, ...gridItemProps } = allProps;
const { canModify } = useExperimentAccess();
const utils = api.useContext();
const [isDragTarget, setIsDragTarget] = useState(false);
const [isInputHovered, setIsInputHovered] = useState(false);
const [label, setLabel] = useState(props.variant.label);
const [label, setLabel] = useState(variant.label);
const updateMutation = api.promptVariants.update.useMutation();
const [onSaveLabel] = useHandledAsyncCallback(async () => {
if (label && label !== props.variant.label) {
if (label && label !== variant.label) {
await updateMutation.mutateAsync({
id: props.variant.id,
id: variant.id,
updates: { label: label },
});
}
}, [updateMutation, props.variant.id, props.variant.label, label]);
}, [updateMutation, variant.id, variant.label, label]);
const reorderMutation = api.promptVariants.reorder.useMutation();
const [onReorder] = useHandledAsyncCallback(
@@ -32,7 +38,7 @@ export default function VariantHeader(props: { variant: PromptVariant; canHide:
e.preventDefault();
setIsDragTarget(false);
const draggedId = e.dataTransfer.getData("text/plain");
const droppedId = props.variant.id;
const droppedId = variant.id;
if (!draggedId || !droppedId || draggedId === droppedId) return;
await reorderMutation.mutateAsync({
draggedId,
@@ -40,16 +46,16 @@ export default function VariantHeader(props: { variant: PromptVariant; canHide:
});
await utils.promptVariants.list.invalidate();
},
[reorderMutation, props.variant.id],
[reorderMutation, variant.id],
);
const [menuOpen, setMenuOpen] = useState(false);
if (!canModify) {
return (
<GridItem padding={0} sx={stickyHeaderStyle} borderTopWidth={1}>
<GridItem padding={0} sx={stickyHeaderStyle} borderTopWidth={1} {...gridItemProps}>
<Text fontSize={16} fontWeight="bold" px={cellPadding.x} py={cellPadding.y}>
{props.variant.label}
{variant.label}
</Text>
</GridItem>
);
@@ -64,6 +70,7 @@ export default function VariantHeader(props: { variant: PromptVariant; canHide:
zIndex: menuOpen ? "dropdown" : stickyHeaderStyle.zIndex,
}}
borderTopWidth={1}
{...gridItemProps}
>
<HStack
spacing={4}
@@ -71,7 +78,7 @@ export default function VariantHeader(props: { variant: PromptVariant; canHide:
minH={headerMinHeight}
draggable={!isInputHovered}
onDragStart={(e) => {
e.dataTransfer.setData("text/plain", props.variant.id);
e.dataTransfer.setData("text/plain", variant.id);
e.currentTarget.style.opacity = "0.4";
}}
onDragEnd={(e) => {
@@ -112,8 +119,8 @@ export default function VariantHeader(props: { variant: PromptVariant; canHide:
onMouseLeave={() => setIsInputHovered(false)}
/>
<VariantHeaderMenuButton
variant={props.variant}
canHide={props.canHide}
variant={variant}
canHide={canHide}
menuOpen={menuOpen}
setMenuOpen={setMenuOpen}
/>

View File

@@ -17,8 +17,7 @@ import { FaRegClone } from "react-icons/fa";
import { useState } from "react";
import { RefinePromptModal } from "../RefinePromptModal/RefinePromptModal";
import { RiExchangeFundsFill } from "react-icons/ri";
import { SelectModelModal } from "../SelectModelModal/SelectModelModal";
import { type SupportedModel } from "~/server/types";
import { ChangeModelModal } from "../ChangeModelModal/ChangeModelModal";
export default function VariantHeaderMenuButton({
variant,
@@ -51,7 +50,7 @@ export default function VariantHeaderMenuButton({
await utils.promptVariants.list.invalidate();
}, [hideMutation, variant.id]);
const [selectModelModalOpen, setSelectModelModalOpen] = useState(false);
const [changeModelModalOpen, setChangeModelModalOpen] = useState(false);
const [refinePromptModalOpen, setRefinePromptModalOpen] = useState(false);
return (
@@ -73,7 +72,7 @@ export default function VariantHeaderMenuButton({
</MenuItem>
<MenuItem
icon={<Icon as={RiExchangeFundsFill} boxSize={5} />}
onClick={() => setSelectModelModalOpen(true)}
onClick={() => setChangeModelModalOpen(true)}
>
Change Model
</MenuItem>
@@ -98,12 +97,8 @@ export default function VariantHeaderMenuButton({
)}
</MenuList>
</Menu>
{selectModelModalOpen && (
<SelectModelModal
originalModel={variant.model as SupportedModel}
variantId={variant.id}
onClose={() => setSelectModelModalOpen(false)}
/>
{changeModelModalOpen && (
<ChangeModelModal variant={variant} onClose={() => setChangeModelModalOpen(false)} />
)}
{refinePromptModalOpen && (
<RefinePromptModal variant={variant} onClose={() => setRefinePromptModalOpen(false)} />

View File

@@ -9,7 +9,6 @@ export const env = createEnv({
server: {
DATABASE_URL: z.string().url(),
NODE_ENV: z.enum(["development", "test", "production"]).default("development"),
OPENAI_API_KEY: z.string().min(1),
RESTRICT_PRISMA_LOGS: z
.string()
.optional()
@@ -17,7 +16,8 @@ export const env = createEnv({
.transform((val) => val.toLowerCase() === "true"),
GITHUB_CLIENT_ID: z.string().min(1),
GITHUB_CLIENT_SECRET: z.string().min(1),
REPLICATE_API_TOKEN: z.string().min(1),
OPENAI_API_KEY: z.string().min(1),
REPLICATE_API_TOKEN: z.string().default("placeholder"),
},
/**

View File
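The schema change above means REPLICATE_API_TOKEN no longer has to be set for env validation to pass. A minimal zod sketch of the new behavior:

import { z } from "zod";

// Before: z.string().min(1), so an unset variable failed validation at boot.
// After: an unset variable falls back to a placeholder.
const replicateToken = z.string().default("placeholder");
replicateToken.parse(undefined); // => "placeholder"

A real token is still required for replicate/llama2 completions to actually succeed; the default only lets the server start without one.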

@@ -1,14 +1,15 @@
import openaiChatCompletionFrontend from "./openai-ChatCompletion/frontend";
import replicateLlama2Frontend from "./replicate-llama2/frontend";
import { type SupportedProvider, type FrontendModelProvider } from "./types";
// TODO: make sure we get a typescript error if you forget to add a provider here
// Keep attributes here that need to be accessible from the frontend. We can't
// just include them in the default `modelProviders` object because it has some
// transitive dependencies that can only be imported on the server.
const modelProvidersFrontend = {
const frontendModelProviders: Record<SupportedProvider, FrontendModelProvider<any, any>> = {
"openai/ChatCompletion": openaiChatCompletionFrontend,
"replicate/llama2": replicateLlama2Frontend,
} as const;
};
export default modelProvidersFrontend;
export default frontendModelProviders;

View File

@@ -1,9 +1,10 @@
import openaiChatCompletion from "./openai-ChatCompletion";
import replicateLlama2 from "./replicate-llama2";
import { type SupportedProvider, type ModelProvider } from "./types";
const modelProviders = {
const modelProviders: Record<SupportedProvider, ModelProvider<any, any, any>> = {
"openai/ChatCompletion": openaiChatCompletion,
"replicate/llama2": replicateLlama2,
} as const;
};
export default modelProviders;

View File
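Typing both provider maps as Record<SupportedProvider, ...> (here and in the frontend map above) makes the keys exhaustive: forgetting to register a provider becomes a compile error. A usage sketch, import paths assumed:

import modelProviders from "~/modelProviders/modelProviders"; // paths assumed
import frontendModelProviders from "~/modelProviders/frontendModelProviders";

// Keys are exhaustive over SupportedProvider, so both lookups typecheck:
const backend = modelProviders["replicate/llama2"];
const frontend = frontendModelProviders["replicate/llama2"];
// Trade-off: the `any` type arguments erase per-provider input/output types here.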

@@ -1,77 +0,0 @@
import { type SupportedModel } from "../server/types";
interface ModelStats {
contextLength: number;
promptTokenPrice: number;
completionTokenPrice: number;
speed: "fast" | "medium" | "slow";
provider: "OpenAI";
learnMoreUrl: string;
}
export const modelStats: Record<SupportedModel, ModelStats> = {
"gpt-4": {
contextLength: 8192,
promptTokenPrice: 0.00003,
completionTokenPrice: 0.00006,
speed: "medium",
provider: "OpenAI",
learnMoreUrl: "https://openai.com/gpt-4",
},
"gpt-4-0613": {
contextLength: 8192,
promptTokenPrice: 0.00003,
completionTokenPrice: 0.00006,
speed: "medium",
provider: "OpenAI",
learnMoreUrl: "https://openai.com/gpt-4",
},
"gpt-4-32k": {
contextLength: 32768,
promptTokenPrice: 0.00006,
completionTokenPrice: 0.00012,
speed: "medium",
provider: "OpenAI",
learnMoreUrl: "https://openai.com/gpt-4",
},
"gpt-4-32k-0613": {
contextLength: 32768,
promptTokenPrice: 0.00006,
completionTokenPrice: 0.00012,
speed: "medium",
provider: "OpenAI",
learnMoreUrl: "https://openai.com/gpt-4",
},
"gpt-3.5-turbo": {
contextLength: 4096,
promptTokenPrice: 0.0000015,
completionTokenPrice: 0.000002,
speed: "fast",
provider: "OpenAI",
learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
},
"gpt-3.5-turbo-0613": {
contextLength: 4096,
promptTokenPrice: 0.0000015,
completionTokenPrice: 0.000002,
speed: "fast",
provider: "OpenAI",
learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
},
"gpt-3.5-turbo-16k": {
contextLength: 16384,
promptTokenPrice: 0.000003,
completionTokenPrice: 0.000004,
speed: "fast",
provider: "OpenAI",
learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
},
"gpt-3.5-turbo-16k-0613": {
contextLength: 16384,
promptTokenPrice: 0.000003,
completionTokenPrice: 0.000004,
speed: "fast",
provider: "OpenAI",
learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
},
};

View File

@@ -56,6 +56,14 @@ modelProperty.type = "string";
modelProperty.enum = modelProperty.oneOf[1].enum;
delete modelProperty["oneOf"];
// The default of "inf" confuses the Typescript generator, so can just remove it
assert(
"max_tokens" in completionRequestSchema.properties &&
isObject(completionRequestSchema.properties.max_tokens) &&
"default" in completionRequestSchema.properties.max_tokens,
);
delete completionRequestSchema.properties.max_tokens["default"];
// Get the directory of the current script
const currentDirectory = path.dirname(import.meta.url).replace("file://", "");

View File

@@ -150,7 +150,6 @@
},
"max_tokens": {
"description": "The maximum number of [tokens](/tokenizer) to generate in the chat completion.\n\nThe total length of input tokens and generated tokens is limited by the model's context length. [Example Python code](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb) for counting tokens.\n",
"default": "inf",
"type": "integer"
},
"presence_penalty": {

View File

@@ -1,8 +1,50 @@
import { type JsonValue } from "type-fest";
import { type OpenaiChatModelProvider } from ".";
import { type ModelProviderFrontend } from "../types";
import { type SupportedModel } from ".";
import { type FrontendModelProvider } from "../types";
import { type ChatCompletion } from "openai/resources/chat";
const frontendModelProvider: FrontendModelProvider<SupportedModel, ChatCompletion> = {
name: "OpenAI ChatCompletion",
models: {
"gpt-4-0613": {
name: "GPT-4",
contextWindow: 8192,
promptTokenPrice: 0.00003,
completionTokenPrice: 0.00006,
speed: "medium",
provider: "openai/ChatCompletion",
learnMoreUrl: "https://openai.com/gpt-4",
},
"gpt-4-32k-0613": {
name: "GPT-4 32k",
contextWindow: 32768,
promptTokenPrice: 0.00006,
completionTokenPrice: 0.00012,
speed: "medium",
provider: "openai/ChatCompletion",
learnMoreUrl: "https://openai.com/gpt-4",
},
"gpt-3.5-turbo-0613": {
name: "GPT-3.5 Turbo",
contextWindow: 4096,
promptTokenPrice: 0.0000015,
completionTokenPrice: 0.000002,
speed: "fast",
provider: "openai/ChatCompletion",
learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
},
"gpt-3.5-turbo-16k-0613": {
name: "GPT-3.5 Turbo 16k",
contextWindow: 16384,
promptTokenPrice: 0.000003,
completionTokenPrice: 0.000004,
speed: "fast",
provider: "openai/ChatCompletion",
learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
},
},
const modelProviderFrontend: ModelProviderFrontend<OpenaiChatModelProvider> = {
normalizeOutput: (output) => {
const message = output.choices[0]?.message;
if (!message)
@@ -39,4 +81,4 @@ const modelProviderFrontend: ModelProviderFrontend<OpenaiChatModelProvider> = {
},
};
export default modelProviderFrontend;
export default frontendModelProvider;

View File

@@ -8,10 +8,10 @@ import { countOpenAIChatTokens } from "~/utils/countTokens";
import { type CompletionResponse } from "../types";
import { omit } from "lodash-es";
import { openai } from "~/server/utils/openai";
import { type OpenAIChatModel } from "~/server/types";
import { truthyFilter } from "~/utils/utils";
import { APIError } from "openai";
import { modelStats } from "../modelStats";
import frontendModelProvider from "./frontend";
import modelProvider, { type SupportedModel } from ".";
const mergeStreamedChunks = (
base: ChatCompletion | null,
@@ -60,6 +60,7 @@ export async function getCompletion(
let finalCompletion: ChatCompletion | null = null;
let promptTokens: number | undefined = undefined;
let completionTokens: number | undefined = undefined;
const modelName = modelProvider.getModel(input) as SupportedModel;
try {
if (onStream) {
@@ -81,12 +82,9 @@ export async function getCompletion(
};
}
try {
promptTokens = countOpenAIChatTokens(
input.model as keyof typeof OpenAIChatModel,
input.messages,
);
promptTokens = countOpenAIChatTokens(modelName, input.messages);
completionTokens = countOpenAIChatTokens(
input.model as keyof typeof OpenAIChatModel,
modelName,
finalCompletion.choices.map((c) => c.message).filter(truthyFilter),
);
} catch (err) {
@@ -106,10 +104,10 @@ export async function getCompletion(
}
const timeToComplete = Date.now() - start;
const stats = modelStats[input.model as keyof typeof OpenAIChatModel];
const { promptTokenPrice, completionTokenPrice } = frontendModelProvider.models[modelName];
let cost = undefined;
if (stats && promptTokens && completionTokens) {
cost = promptTokens * stats.promptTokenPrice + completionTokens * stats.completionTokenPrice;
if (promptTokenPrice && completionTokenPrice && promptTokens && completionTokens) {
cost = promptTokens * promptTokenPrice + completionTokens * completionTokenPrice;
}
return {

View File
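A quick worked example of the new cost computation, using the gpt-4-0613 prices from the frontend provider:

// gpt-4-0613: promptTokenPrice 0.00003, completionTokenPrice 0.00006
const promptTokens = 1000;
const completionTokens = 500;
const cost = promptTokens * 0.00003 + completionTokens * 0.00006;
// = 0.03 + 0.03 = 0.06 dollars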

@@ -3,6 +3,7 @@ import { type ModelProvider } from "../types";
import inputSchema from "./codegen/input.schema.json";
import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat";
import { getCompletion } from "./getCompletion";
import frontendModelProvider from "./frontend";
const supportedModels = [
"gpt-4-0613",
@@ -11,7 +12,7 @@ const supportedModels = [
"gpt-3.5-turbo-16k-0613",
] as const;
type SupportedModel = (typeof supportedModels)[number];
export type SupportedModel = (typeof supportedModels)[number];
export type OpenaiChatModelProvider = ModelProvider<
SupportedModel,
@@ -20,25 +21,6 @@ export type OpenaiChatModelProvider = ModelProvider<
>;
const modelProvider: OpenaiChatModelProvider = {
name: "OpenAI ChatCompletion",
models: {
"gpt-4-0613": {
name: "GPT-4",
learnMore: "https://openai.com/gpt-4",
},
"gpt-4-32k-0613": {
name: "GPT-4 32k",
learnMore: "https://openai.com/gpt-4",
},
"gpt-3.5-turbo-0613": {
name: "GPT-3.5 Turbo",
learnMore: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
},
"gpt-3.5-turbo-16k-0613": {
name: "GPT-3.5 Turbo 16k",
learnMore: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
},
},
getModel: (input) => {
if (supportedModels.includes(input.model as SupportedModel))
return input.model as SupportedModel;
@@ -57,6 +39,7 @@ const modelProvider: OpenaiChatModelProvider = {
inputSchema: inputSchema as JSONSchema4,
shouldStream: (input) => input.stream ?? false,
getCompletion,
...frontendModelProvider,
};
export default modelProvider;

View File

@@ -1,7 +1,36 @@
import { type ReplicateLlama2Provider } from ".";
import { type ModelProviderFrontend } from "../types";
import { type SupportedModel, type ReplicateLlama2Output } from ".";
import { type FrontendModelProvider } from "../types";
const frontendModelProvider: FrontendModelProvider<SupportedModel, ReplicateLlama2Output> = {
name: "Replicate Llama2",
models: {
"7b-chat": {
name: "LLama 2 7B Chat",
contextWindow: 4096,
pricePerSecond: 0.0023,
speed: "fast",
provider: "replicate/llama2",
learnMoreUrl: "https://replicate.com/a16z-infra/llama7b-v2-chat",
},
"13b-chat": {
name: "LLama 2 13B Chat",
contextWindow: 4096,
pricePerSecond: 0.0023,
speed: "medium",
provider: "replicate/llama2",
learnMoreUrl: "https://replicate.com/a16z-infra/llama13b-v2-chat",
},
"70b-chat": {
name: "LLama 2 70B Chat",
contextWindow: 4096,
pricePerSecond: 0.0032,
speed: "slow",
provider: "replicate/llama2",
learnMoreUrl: "https://replicate.com/replicate/llama70b-v2-chat",
},
},
const modelProviderFrontend: ModelProviderFrontend<ReplicateLlama2Provider> = {
normalizeOutput: (output) => {
return {
type: "text",
@@ -10,4 +39,4 @@ const modelProviderFrontend: ModelProviderFrontend<ReplicateLlama2Provider> = {
},
};
export default modelProviderFrontend;
export default frontendModelProvider;

View File

@@ -27,8 +27,6 @@ export async function getCompletion(
input: rest,
});
console.log("stream?", onStream);
const interval = onStream
? // eslint-disable-next-line @typescript-eslint/no-misused-promises
setInterval(async () => {

View File

@@ -1,9 +1,10 @@
import { type ModelProvider } from "../types";
import frontendModelProvider from "./frontend";
import { getCompletion } from "./getCompletion";
const supportedModels = ["7b-chat", "13b-chat", "70b-chat"] as const;
type SupportedModel = (typeof supportedModels)[number];
export type SupportedModel = (typeof supportedModels)[number];
export type ReplicateLlama2Input = {
model: SupportedModel;
@@ -25,12 +26,6 @@ export type ReplicateLlama2Provider = ModelProvider<
>;
const modelProvider: ReplicateLlama2Provider = {
name: "OpenAI ChatCompletion",
models: {
"7b-chat": {},
"13b-chat": {},
"70b-chat": {},
},
getModel: (input) => {
if (supportedModels.includes(input.model)) return input.model;
@@ -69,6 +64,7 @@ const modelProvider: ReplicateLlama2Provider = {
},
shouldStream: (input) => input.stream ?? false,
getCompletion,
...frontendModelProvider,
};
export default modelProvider;

View File

@@ -1,9 +1,33 @@
import { type JSONSchema4 } from "json-schema";
import { type JsonValue } from "type-fest";
import { z } from "zod";
type ModelProviderModel = {
name?: string;
learnMore?: string;
const ZodSupportedProvider = z.union([
z.literal("openai/ChatCompletion"),
z.literal("replicate/llama2"),
]);
export type SupportedProvider = z.infer<typeof ZodSupportedProvider>;
export const ZodModel = z.object({
name: z.string(),
contextWindow: z.number(),
promptTokenPrice: z.number().optional(),
completionTokenPrice: z.number().optional(),
pricePerSecond: z.number().optional(),
speed: z.union([z.literal("fast"), z.literal("medium"), z.literal("slow")]),
provider: ZodSupportedProvider,
description: z.string().optional(),
learnMoreUrl: z.string().optional(),
});
export type Model = z.infer<typeof ZodModel>;
export type FrontendModelProvider<SupportedModels extends string, OutputSchema> = {
name: string;
models: Record<SupportedModels, Model>;
normalizeOutput: (output: OutputSchema) => NormalizedOutput;
};
export type CompletionResponse<T> =
@@ -19,8 +43,6 @@ export type CompletionResponse<T> =
};
export type ModelProvider<SupportedModels extends string, InputSchema, OutputSchema> = {
name: string;
models: Record<SupportedModels, ModelProviderModel>;
getModel: (input: InputSchema) => SupportedModels | null;
shouldStream: (input: InputSchema) => boolean;
inputSchema: JSONSchema4;
@@ -31,7 +53,7 @@ export type ModelProvider<SupportedModels extends string, InputSchema, OutputSch
// This is just a convenience for type inference, don't use it at runtime
_outputSchema?: OutputSchema | null;
};
} & FrontendModelProvider<SupportedModels, OutputSchema>;
export type NormalizedOutput =
| {
@@ -42,7 +64,3 @@ export type NormalizedOutput =
type: "json";
value: JsonValue;
};
export type ModelProviderFrontend<ModelProviderT extends ModelProvider<any, any, any>> = {
normalizeOutput: (output: NonNullable<ModelProviderT["_outputSchema"]>) => NormalizedOutput;
};

View File
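For illustration, a value that satisfies the new ZodModel schema (mirroring the gpt-4-0613 entry defined earlier); parse adds a runtime check on top of the compile-time Model type:

import { ZodModel, type Model } from "~/modelProviders/types";

const gpt4: Model = {
  name: "GPT-4",
  contextWindow: 8192,
  promptTokenPrice: 0.00003,
  completionTokenPrice: 0.00006,
  speed: "medium",
  provider: "openai/ChatCompletion",
  learnMoreUrl: "https://openai.com/gpt-4",
};

ZodModel.parse(gpt4); // throws a ZodError if the shape drifts from the schema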

@@ -7,6 +7,8 @@ import "~/utils/analytics";
import Head from "next/head";
import { ChakraThemeProvider } from "~/theme/ChakraThemeProvider";
import { SyncAppStore } from "~/state/sync";
import NextAdapterApp from "next-query-params/app";
import { QueryParamProvider } from "use-query-params";
const MyApp: AppType<{ session: Session | null }> = ({
Component,
@@ -24,7 +26,9 @@ const MyApp: AppType<{ session: Session | null }> = ({
<SyncAppStore />
<Favicon />
<ChakraThemeProvider>
<Component {...pageProps} />
<QueryParamProvider adapter={NextAdapterApp}>
<Component {...pageProps} />
</QueryParamProvider>
</ChakraThemeProvider>
</SessionProvider>
</>

View File

@@ -20,22 +20,25 @@ export default function ExperimentsPage() {
const experiments = api.experiments.list.useQuery();
const user = useSession().data;
const authLoading = useSession().status === "loading";
if (user === null) {
if (user === null || authLoading) {
return (
<AppShell title="Experiments">
<Center h="100%">
<Text>
<Link
onClick={() => {
signIn("github").catch(console.error);
}}
textDecor="underline"
>
Sign in
</Link>{" "}
to view or create new experiments!
</Text>
{!authLoading && (
<Text>
<Link
onClick={() => {
signIn("github").catch(console.error);
}}
textDecor="underline"
>
Sign in
</Link>{" "}
to view or create new experiments!
</Text>
)}
</Center>
</AppShell>
);

View File

@@ -98,7 +98,7 @@ export const experimentsRouter = createTRPCRouter({
},
});
const [variant, _, scenario] = await prisma.$transaction([
const [variant, _, scenario1, scenario2, scenario3] = await prisma.$transaction([
prisma.promptVariant.create({
data: {
experimentId: exp.id,
@@ -121,7 +121,7 @@ export const experimentsRouter = createTRPCRouter({
messages: [
{
role: "system",
content: \`"Return 'this is output for the scenario "${"$"}{scenario.text}"'\`,
content: \`Write 'Start experimenting!' in ${"$"}{scenario.language}\`,
},
],
});`,
@@ -133,20 +133,38 @@ export const experimentsRouter = createTRPCRouter({
prisma.templateVariable.create({
data: {
experimentId: exp.id,
label: "text",
label: "language",
},
}),
prisma.testScenario.create({
data: {
experimentId: exp.id,
variableValues: {
text: "This is a test scenario.",
language: "English",
},
},
}),
prisma.testScenario.create({
data: {
experimentId: exp.id,
variableValues: {
language: "Spanish",
},
},
}),
prisma.testScenario.create({
data: {
experimentId: exp.id,
variableValues: {
language: "German",
},
},
}),
]);
await generateNewCell(variant.id, scenario.id);
await generateNewCell(variant.id, scenario1.id);
await generateNewCell(variant.id, scenario2.id);
await generateNewCell(variant.id, scenario3.id);
return exp;
}),

View File

@@ -2,7 +2,6 @@ import { z } from "zod";
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import { generateNewCell } from "~/server/utils/generateNewCell";
import { type SupportedModel } from "~/server/types";
import userError from "~/server/utils/error";
import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
import { reorderPromptVariants } from "~/server/utils/reorderPromptVariants";
@@ -10,6 +9,7 @@ import { type PromptVariant } from "@prisma/client";
import { deriveNewConstructFn } from "~/server/utils/deriveNewContructFn";
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
import parseConstructFn from "~/server/utils/parseConstructFn";
import { ZodModel } from "~/modelProviders/types";
export const promptVariantsRouter = createTRPCRouter({
list: publicProcedure
@@ -144,7 +144,7 @@ export const promptVariantsRouter = createTRPCRouter({
z.object({
experimentId: z.string(),
variantId: z.string().optional(),
newModel: z.string().optional(),
newModel: ZodModel.optional(),
}),
)
.mutation(async ({ input, ctx }) => {
@@ -186,10 +186,7 @@ export const promptVariantsRouter = createTRPCRouter({
? `${originalVariant?.label} Copy`
: `Prompt Variant ${largestSortIndex + 2}`;
const newConstructFn = await deriveNewConstructFn(
originalVariant,
input.newModel as SupportedModel,
);
const newConstructFn = await deriveNewConstructFn(originalVariant, input.newModel);
const createNewVariantAction = prisma.promptVariant.create({
data: {
@@ -284,11 +281,12 @@ export const promptVariantsRouter = createTRPCRouter({
return updatedPromptVariant;
}),
getRefinedPromptFn: protectedProcedure
getModifiedPromptFn: protectedProcedure
.input(
z.object({
id: z.string(),
instructions: z.string(),
instructions: z.string().optional(),
newModel: ZodModel.optional(),
}),
)
.mutation(async ({ input, ctx }) => {
@@ -307,7 +305,7 @@ export const promptVariantsRouter = createTRPCRouter({
const promptConstructionFn = await deriveNewConstructFn(
existing,
constructedPrompt.model as SupportedModel,
input.newModel,
input.instructions,
);

View File

@@ -7,21 +7,39 @@ import { runAllEvals } from "~/server/utils/evaluations";
import { generateNewCell } from "~/server/utils/generateNewCell";
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
const PAGE_SIZE = 10;
export const scenariosRouter = createTRPCRouter({
list: publicProcedure
.input(z.object({ experimentId: z.string() }))
.input(z.object({ experimentId: z.string(), page: z.number() }))
.query(async ({ input, ctx }) => {
await requireCanViewExperiment(input.experimentId, ctx);
return await prisma.testScenario.findMany({
const { experimentId, page } = input;
const scenarios = await prisma.testScenario.findMany({
where: {
experimentId: input.experimentId,
experimentId,
visible: true,
},
orderBy: {
sortIndex: "asc",
orderBy: { sortIndex: "asc" },
skip: (page - 1) * PAGE_SIZE,
take: PAGE_SIZE,
});
const count = await prisma.testScenario.count({
where: {
experimentId,
visible: true,
},
});
return {
scenarios,
startIndex: (page - 1) * PAGE_SIZE + 1,
lastPage: Math.ceil(count / PAGE_SIZE),
count,
};
}),
create: protectedProcedure
@@ -34,22 +52,21 @@ export const scenariosRouter = createTRPCRouter({
.mutation(async ({ input, ctx }) => {
await requireCanModifyExperiment(input.experimentId, ctx);
const maxSortIndex =
(
await prisma.testScenario.aggregate({
where: {
experimentId: input.experimentId,
},
_max: {
sortIndex: true,
},
})
)._max.sortIndex ?? 0;
await prisma.testScenario.updateMany({
where: {
experimentId: input.experimentId,
},
data: {
sortIndex: {
increment: 1,
},
},
});
const createNewScenarioAction = prisma.testScenario.create({
data: {
experimentId: input.experimentId,
sortIndex: maxSortIndex + 1,
sortIndex: 0,
variableValues: input.autogenerate
? await autogenerateScenarioValues(input.experimentId)
: {},

View File

@@ -99,7 +99,6 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
const provider = modelProviders[prompt.modelProvider];
// @ts-expect-error TODO FIX ASAP
const streamingChannel = provider.shouldStream(prompt.modelInput) ? generateChannel() : null;
if (streamingChannel) {
@@ -116,8 +115,6 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
: null;
for (let i = 0; true; i++) {
// @ts-expect-error TODO FIX ASAP
const response = await provider.getCompletion(prompt.modelInput, onStream);
if (response.type === "success") {
const inputHash = hashPrompt(prompt);
@@ -126,7 +123,7 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
data: {
scenarioVariantCellId,
inputHash,
output: response.value as unknown as Prisma.InputJsonObject,
output: response.value as Prisma.InputJsonObject,
timeToComplete: response.timeToComplete,
promptTokens: response.promptTokens,
completionTokens: response.completionTokens,
@@ -154,7 +151,7 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
errorMessage: response.message,
statusCode: response.statusCode,
retryTime: shouldRetry ? new Date(Date.now() + delay) : null,
retrievalStatus: shouldRetry ? "PENDING" : "ERROR",
retrievalStatus: "ERROR",
},
});

View File

@@ -1,12 +0,0 @@
export enum OpenAIChatModel {
"gpt-4" = "gpt-4",
"gpt-4-0613" = "gpt-4-0613",
"gpt-4-32k" = "gpt-4-32k",
"gpt-4-32k-0613" = "gpt-4-32k-0613",
"gpt-3.5-turbo" = "gpt-3.5-turbo",
"gpt-3.5-turbo-0613" = "gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k" = "gpt-3.5-turbo-16k",
"gpt-3.5-turbo-16k-0613" = "gpt-3.5-turbo-16k-0613",
}
export type SupportedModel = keyof typeof OpenAIChatModel;

View File

@@ -1,18 +1,18 @@
import { type PromptVariant } from "@prisma/client";
import { type SupportedModel } from "../types";
import ivm from "isolated-vm";
import dedent from "dedent";
import { openai } from "./openai";
import { getApiShapeForModel } from "./getTypesForModel";
import { isObject } from "lodash-es";
import { type CompletionCreateParams } from "openai/resources/chat/completions";
import formatPromptConstructor from "~/utils/formatPromptConstructor";
import { type SupportedProvider, type Model } from "~/modelProviders/types";
import modelProviders from "~/modelProviders/modelProviders";
const isolate = new ivm.Isolate({ memoryLimit: 128 });
export async function deriveNewConstructFn(
originalVariant: PromptVariant | null,
newModel?: SupportedModel,
newModel?: Model,
instructions?: string,
) {
if (originalVariant && !newModel && !instructions) {
@@ -36,10 +36,11 @@ export async function deriveNewConstructFn(
const NUM_RETRIES = 5;
const requestUpdatedPromptFunction = async (
originalVariant: PromptVariant,
newModel?: SupportedModel,
newModel?: Model,
instructions?: string,
) => {
const originalModel = originalVariant.model as SupportedModel;
const originalModelProvider = modelProviders[originalVariant.modelProvider as SupportedProvider];
const originalModel = originalModelProvider.models[originalVariant.model] as Model;
let newContructionFn = "";
for (let i = 0; i < NUM_RETRIES; i++) {
try {
@@ -47,17 +48,33 @@ const requestUpdatedPromptFunction = async (
{
role: "system",
content: `Your job is to update prompt constructor functions. Here is the api shape for the current model:\n---\n${JSON.stringify(
getApiShapeForModel(originalModel),
originalModelProvider.inputSchema,
null,
2,
)}\n\nDo not add any assistant messages.`,
},
{
role: "user",
content: `This is the current prompt constructor function:\n---\n${originalVariant.constructFn}`,
},
];
if (newModel) {
messages.push({
role: "user",
content: `Return the prompt constructor function for ${newModel} given the following prompt constructor function for ${originalModel}:\n---\n${originalVariant.constructFn}`,
content: `Return the prompt constructor function for ${newModel.name} given the existing prompt constructor function for ${originalModel.name}`,
});
if (newModel.provider !== originalModel.provider) {
messages.push({
role: "user",
content: `The old provider was ${originalModel.provider}. The new provider is ${
newModel.provider
}. Here is the schema for the new model:\n---\n${JSON.stringify(
modelProviders[newModel.provider].inputSchema,
null,
2,
)}`,
});
}
}
if (instructions) {
messages.push({
@@ -65,10 +82,6 @@ const requestUpdatedPromptFunction = async (
content: instructions,
});
}
messages.push({
role: "system",
content: "The prompt variable has already been declared, so do not declare it again.",
});
const completion = await openai.chat.completions.create({
model: "gpt-4",
messages,

View File
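For illustration, a hedged sketch of the updated call shape: newModel is now a full Model object rather than a bare model-name string, and switching providers injects the new provider's input schema into the chat messages:

// Hypothetical call; variable names are illustrative.
const llama13b = frontendModelProviders["replicate/llama2"].models["13b-chat"];
const newConstructFn = await deriveNewConstructFn(
  originalVariant,                // the existing PromptVariant
  llama13b,                       // optional new Model
  "Keep the system prompt short", // optional instructions
);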

@@ -4,8 +4,9 @@ import { queueLLMRetrievalTask } from "./queueLLMRetrievalTask";
import parseConstructFn from "./parseConstructFn";
import { type JsonObject } from "type-fest";
import hashPrompt from "./hashPrompt";
import { omit } from "lodash-es";
export const generateNewCell = async (variantId: string, scenarioId: string) => {
export const generateNewCell = async (variantId: string, scenarioId: string): Promise<void> => {
const variant = await prisma.promptVariant.findUnique({
where: {
id: variantId,
@@ -18,7 +19,7 @@ export const generateNewCell = async (variantId: string, scenarioId: string) =>
},
});
if (!variant || !scenario) return null;
if (!variant || !scenario) return;
let cell = await prisma.scenarioVariantCell.findUnique({
where: {
@@ -32,7 +33,7 @@ export const generateNewCell = async (variantId: string, scenarioId: string) =>
},
});
if (cell) return cell;
if (cell) return;
const parsedConstructFn = await parseConstructFn(
variant.constructFn,
@@ -40,7 +41,7 @@ export const generateNewCell = async (variantId: string, scenarioId: string) =>
);
if ("error" in parsedConstructFn) {
return await prisma.scenarioVariantCell.create({
await prisma.scenarioVariantCell.create({
data: {
promptVariantId: variantId,
testScenarioId: scenarioId,
@@ -49,6 +50,7 @@ export const generateNewCell = async (variantId: string, scenarioId: string) =>
retrievalStatus: "ERROR",
},
});
return;
}
const inputHash = hashPrompt(parsedConstructFn);
@@ -69,29 +71,33 @@ export const generateNewCell = async (variantId: string, scenarioId: string) =>
where: { inputHash },
});
let newModelOutput;
if (matchingModelOutput) {
newModelOutput = await prisma.modelOutput.create({
const newModelOutput = await prisma.modelOutput.create({
data: {
...omit(matchingModelOutput, ["id"]),
scenarioVariantCellId: cell.id,
inputHash,
output: matchingModelOutput.output as Prisma.InputJsonValue,
timeToComplete: matchingModelOutput.timeToComplete,
cost: matchingModelOutput.cost,
promptTokens: matchingModelOutput.promptTokens,
completionTokens: matchingModelOutput.completionTokens,
createdAt: matchingModelOutput.createdAt,
updatedAt: matchingModelOutput.updatedAt,
},
});
await prisma.scenarioVariantCell.update({
where: { id: cell.id },
data: { retrievalStatus: "COMPLETE" },
});
// Copy over all eval results as well
await Promise.all(
(
await prisma.outputEvaluation.findMany({ where: { modelOutputId: matchingModelOutput.id } })
).map(async (evaluation) => {
await prisma.outputEvaluation.create({
data: {
...omit(evaluation, ["id"]),
modelOutputId: newModelOutput.id,
},
});
}),
);
} else {
cell = await queueLLMRetrievalTask(cell.id);
}
return { ...cell, modelOutput: newModelOutput };
};

View File
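The clone above swaps the hand-copied column list for lodash's omit, which strips the primary key so Prisma can assign a fresh one. A minimal illustration:

import { omit } from "lodash-es";

const matching = { id: "out_1", inputHash: "abc123", cost: 0.002 };
omit(matching, ["id"]); // => { inputHash: "abc123", cost: 0.002 }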

@@ -1,6 +0,0 @@
import { type SupportedModel } from "../types";
export const getApiShapeForModel = (model: SupportedModel) => {
// if (model in OpenAIChatModel) return openAIChatApiShape;
return "";
};

View File

@@ -70,7 +70,6 @@ export default async function parseConstructFn(
// We've validated the JSON schema so this should be safe
const input = prompt.input as Parameters<(typeof provider)["getModel"]>[0];
// @ts-expect-error TODO FIX ASAP
const model = provider.getModel(input);
if (!model) {
return {
@@ -80,8 +79,6 @@ export default async function parseConstructFn(
return {
modelProvider: prompt.modelProvider as keyof typeof modelProviders,
// @ts-expect-error TODO FIX ASAP
model,
modelInput: input,
};

View File

@@ -8,9 +8,9 @@ export const editorBackground = "#fafafa";
export type SharedVariantEditorSlice = {
monaco: null | ReturnType<typeof loader.__getMonacoInstance>;
loadMonaco: () => Promise<void>;
scenarios: RouterOutputs["scenarios"]["list"];
scenarios: RouterOutputs["scenarios"]["list"]["scenarios"];
updateScenariosModel: () => void;
setScenarios: (scenarios: RouterOutputs["scenarios"]["list"]) => void;
setScenarios: (scenarios: RouterOutputs["scenarios"]["list"]["scenarios"]) => void;
};
export const createVariantEditorSlice: SliceCreator<SharedVariantEditorSlice> = (set, get) => ({

View File

@@ -1,17 +1,14 @@
import { useEffect } from "react";
import { api } from "~/utils/api";
import { useExperiment } from "~/utils/hooks";
import { useScenarios } from "~/utils/hooks";
import { useAppStore } from "./store";
export function useSyncVariantEditor() {
const experiment = useExperiment();
const scenarios = api.scenarios.list.useQuery(
{ experimentId: experiment.data?.id ?? "" },
{ enabled: !!experiment.data?.id },
);
const scenarios = useScenarios();
useEffect(() => {
if (scenarios.data) {
useAppStore.getState().sharedVariantEditor.setScenarios(scenarios.data);
useAppStore.getState().sharedVariantEditor.setScenarios(scenarios.data.scenarios);
}
}, [scenarios.data]);
}

View File

@@ -1,6 +1,6 @@
import { type ChatCompletion } from "openai/resources/chat";
import { GPTTokens } from "gpt-tokens";
import { type OpenAIChatModel } from "~/server/types";
import { type SupportedModel } from "~/modelProviders/openai-ChatCompletion";
interface GPTTokensMessageItem {
name?: string;
@@ -9,7 +9,7 @@ interface GPTTokensMessageItem {
}
export const countOpenAIChatTokens = (
model: keyof typeof OpenAIChatModel,
model: SupportedModel,
messages: ChatCompletion.Choice.Message[],
) => {
return new GPTTokens({ model, messages: messages as unknown as GPTTokensMessageItem[] })

View File
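Usage stays the same, but the model argument is now typed against the provider's SupportedModel union instead of the deleted OpenAIChatModel enum. A sketch:

// Sketch; message shape follows openai's ChatCompletion.Choice.Message.
const promptTokens = countOpenAIChatTokens("gpt-4-0613", [
  { role: "user", content: "Hello!" } as ChatCompletion.Choice.Message,
]);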

@@ -1,6 +1,7 @@
import { useRouter } from "next/router";
import { type RefObject, useCallback, useEffect, useRef, useState } from "react";
import { api } from "~/utils/api";
import { NumberParam, useQueryParam, withDefault } from "use-query-params";
export const useExperiment = () => {
const router = useRouter();
@@ -93,3 +94,15 @@ export const useElementDimensions = (): [RefObject<HTMLElement>, Dimensions | un
return [ref, dimensions];
};
export const usePage = () => useQueryParam("page", withDefault(NumberParam, 1));
export const useScenarios = () => {
const experiment = useExperiment();
const [page] = usePage();
return api.scenarios.list.useQuery(
{ experimentId: experiment.data?.id ?? "", page },
{ enabled: experiment.data?.id != null },
);
};

View File
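A sketch of how these hooks compose in a pager; page round-trips through the ?page= query param (wired up by the QueryParamProvider added in _app above):

// Sketch; component context assumed.
const [page, setPage] = usePage(); // reads ?page=, defaults to 1
const scenarios = useScenarios(); // refetches whenever page changes
const lastPage = scenarios.data?.lastPage ?? 1;
// Pager buttons might then look like:
// <Button isDisabled={page <= 1} onClick={() => setPage(page - 1)}>Prev</Button>
// <Button isDisabled={page >= lastPage} onClick={() => setPage(page + 1)}>Next</Button>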

@@ -1 +1,5 @@
import { type Model } from "~/modelProviders/types";
export const truthyFilter = <T>(x: T | null | undefined): x is T => Boolean(x);
export const keyForModel = (model: Model) => `${model.provider}/${model.name}`;
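keyForModel joins the provider key with the human-readable model name. For example, with the gpt-4-0613 entry from the OpenAI frontend provider:

import frontendModelProviders from "~/modelProviders/frontendModelProviders"; // path assumed

const gpt4 = frontendModelProviders["openai/ChatCompletion"].models["gpt-4-0613"];
keyForModel(gpt4); // => "openai/ChatCompletion/GPT-4"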