Allow user to create a version of their current prompt with a new model (#58)
* Add dropdown header for model switching * Allow variant duplication * Fix prettier * Use env variable to restrict prisma logs * Fix env.mjs * Remove unnecessary scroll bar from function call output * Properly record when 404 error occurs in queryLLM task * Add SelectedModelInfo in SelectModelModal * Add react-select * Calculate new prompt after switching model * Send newly selected model with creation request * Get new prompt construction function back from GPT-4 * Fix prettier * Fix prettier
This commit is contained in:
@@ -61,6 +61,7 @@
|
||||
"react": "18.2.0",
|
||||
"react-dom": "18.2.0",
|
||||
"react-icons": "^4.10.1",
|
||||
"react-select": "^5.7.4",
|
||||
"react-syntax-highlighter": "^15.5.0",
|
||||
"react-textarea-autosize": "^8.5.0",
|
||||
"socket.io": "^4.7.1",
|
||||
|
||||
67
pnpm-lock.yaml
generated
67
pnpm-lock.yaml
generated
@@ -1,4 +1,4 @@
|
||||
lockfileVersion: '6.1'
|
||||
lockfileVersion: '6.0'
|
||||
|
||||
settings:
|
||||
autoInstallPeers: true
|
||||
@@ -128,6 +128,9 @@ dependencies:
|
||||
react-icons:
|
||||
specifier: ^4.10.1
|
||||
version: 4.10.1(react@18.2.0)
|
||||
react-select:
|
||||
specifier: ^5.7.4
|
||||
version: 5.7.4(@types/react@18.2.6)(react-dom@18.2.0)(react@18.2.0)
|
||||
react-syntax-highlighter:
|
||||
specifier: ^15.5.0
|
||||
version: 15.5.0(react@18.2.0)
|
||||
@@ -2248,6 +2251,16 @@ packages:
|
||||
engines: {node: ^12.22.0 || ^14.17.0 || >=16.0.0}
|
||||
dev: true
|
||||
|
||||
/@floating-ui/core@1.3.1:
|
||||
resolution: {integrity: sha512-Bu+AMaXNjrpjh41znzHqaz3r2Nr8hHuHZT6V2LBKMhyMl0FgKA62PNYbqnfgmzOhoWZj70Zecisbo4H1rotP5g==}
|
||||
dev: false
|
||||
|
||||
/@floating-ui/dom@1.4.5:
|
||||
resolution: {integrity: sha512-96KnRWkRnuBSSFbj0sFGwwOUd8EkiecINVl0O9wiZlZ64EkpyAOG3Xc2vKKNJmru0Z7RqWNymA+6b8OZqjgyyw==}
|
||||
dependencies:
|
||||
'@floating-ui/core': 1.3.1
|
||||
dev: false
|
||||
|
||||
/@graphile/logger@0.2.0:
|
||||
resolution: {integrity: sha512-jjcWBokl9eb1gVJ85QmoaQ73CQ52xAaOCF29ukRbYNl6lY+ts0ErTaDYOBlejcbUs2OpaiqYLO5uDhyLFzWw4w==}
|
||||
dev: false
|
||||
@@ -2832,6 +2845,12 @@ packages:
|
||||
'@types/react': 18.2.6
|
||||
dev: true
|
||||
|
||||
/@types/react-transition-group@4.4.6:
|
||||
resolution: {integrity: sha512-VnCdSxfcm08KjsJVQcfBmhEQAPnLB8G08hAxn39azX1qYBQ/5RVQuoHuKIcfKOdncuaUvEpFKFzEvbtIMsfVew==}
|
||||
dependencies:
|
||||
'@types/react': 18.2.6
|
||||
dev: false
|
||||
|
||||
/@types/react@18.2.6:
|
||||
resolution: {integrity: sha512-wRZClXn//zxCFW+ye/D2qY65UsYP1Fpex2YXorHc8awoNamkMZSvBxwxdYVInsHOZZd2Ppq8isnSzJL5Mpf8OA==}
|
||||
dependencies:
|
||||
@@ -3896,6 +3915,13 @@ packages:
|
||||
esutils: 2.0.3
|
||||
dev: true
|
||||
|
||||
/dom-helpers@5.2.1:
|
||||
resolution: {integrity: sha512-nRCa7CK3VTrM2NmGkIy4cbK7IZlgBE/PYMn55rrXefr5xXDP0LdtfPnblFDoVdcAfslJ7or6iqAUnx0CCGIWQA==}
|
||||
dependencies:
|
||||
'@babel/runtime': 7.22.6
|
||||
csstype: 3.1.2
|
||||
dev: false
|
||||
|
||||
/dotenv@16.3.1:
|
||||
resolution: {integrity: sha512-IPzF4w4/Rd94bA9imS68tZBaYyBWSCE47V1RGuMrB94iyTOIEwRmVL2x/4An+6mETpLrKJ5hQkB8W4kFAadeIQ==}
|
||||
engines: {node: '>=12'}
|
||||
@@ -5451,6 +5477,10 @@ packages:
|
||||
engines: {node: '>= 0.6'}
|
||||
dev: false
|
||||
|
||||
/memoize-one@6.0.0:
|
||||
resolution: {integrity: sha512-rkpe71W0N0c0Xz6QD0eJETuWAJGnJ9afsl1srmwPrI+yBCkge5EycXXbYRyvL29zZVUWQCY7InPRCv3GDXuZNw==}
|
||||
dev: false
|
||||
|
||||
/merge-descriptors@1.0.1:
|
||||
resolution: {integrity: sha512-cCi6g3/Zr1iqQi6ySbseM1Xvooa98N0w31jzUYrXPX2xqObmFGHJ0tQ5u74H3mVh7wLouTseZyYIq39g8cNp1w==}
|
||||
dev: false
|
||||
@@ -6370,6 +6400,27 @@ packages:
|
||||
use-sidecar: 1.1.2(@types/react@18.2.6)(react@18.2.0)
|
||||
dev: false
|
||||
|
||||
/react-select@5.7.4(@types/react@18.2.6)(react-dom@18.2.0)(react@18.2.0):
|
||||
resolution: {integrity: sha512-NhuE56X+p9QDFh4BgeygHFIvJJszO1i1KSkg/JPcIJrbovyRtI+GuOEa4XzFCEpZRAEoEI8u/cAHK+jG/PgUzQ==}
|
||||
peerDependencies:
|
||||
react: ^16.8.0 || ^17.0.0 || ^18.0.0
|
||||
react-dom: ^16.8.0 || ^17.0.0 || ^18.0.0
|
||||
dependencies:
|
||||
'@babel/runtime': 7.22.6
|
||||
'@emotion/cache': 11.11.0
|
||||
'@emotion/react': 11.11.1(@types/react@18.2.6)(react@18.2.0)
|
||||
'@floating-ui/dom': 1.4.5
|
||||
'@types/react-transition-group': 4.4.6
|
||||
memoize-one: 6.0.0
|
||||
prop-types: 15.8.1
|
||||
react: 18.2.0
|
||||
react-dom: 18.2.0(react@18.2.0)
|
||||
react-transition-group: 4.4.5(react-dom@18.2.0)(react@18.2.0)
|
||||
use-isomorphic-layout-effect: 1.1.2(@types/react@18.2.6)(react@18.2.0)
|
||||
transitivePeerDependencies:
|
||||
- '@types/react'
|
||||
dev: false
|
||||
|
||||
/react-ssr-prepass@1.5.0(react@18.2.0):
|
||||
resolution: {integrity: sha512-yFNHrlVEReVYKsLI5lF05tZoHveA5pGzjFbFJY/3pOqqjGOmMmqx83N4hIjN2n6E1AOa+eQEUxs3CgRnPmT0RQ==}
|
||||
peerDependencies:
|
||||
@@ -6422,6 +6473,20 @@ packages:
|
||||
- '@types/react'
|
||||
dev: false
|
||||
|
||||
/react-transition-group@4.4.5(react-dom@18.2.0)(react@18.2.0):
|
||||
resolution: {integrity: sha512-pZcd1MCJoiKiBR2NRxeCRg13uCXbydPnmB4EOeRrY7480qNWO8IIgQG6zlDkm6uRMsURXPuKq0GWtiM59a5Q6g==}
|
||||
peerDependencies:
|
||||
react: '>=16.6.0'
|
||||
react-dom: '>=16.6.0'
|
||||
dependencies:
|
||||
'@babel/runtime': 7.22.6
|
||||
dom-helpers: 5.2.1
|
||||
loose-envify: 1.4.0
|
||||
prop-types: 15.8.1
|
||||
react: 18.2.0
|
||||
react-dom: 18.2.0(react@18.2.0)
|
||||
dev: false
|
||||
|
||||
/react@18.2.0:
|
||||
resolution: {integrity: sha512-/3IjMdb2L9QbBdWiW5e3P2/npwMBaU9mHCSCUzNln0ZCYbcfTsGbTJrU/kGemdH2IWmB2ioZ+zkxtmq6g09fGQ==}
|
||||
engines: {node: '>=0.10.0'}
|
||||
|
||||
@@ -106,7 +106,7 @@ export default function OutputCell({
|
||||
h="100%"
|
||||
fontSize="xs"
|
||||
flexWrap="wrap"
|
||||
overflowX="auto"
|
||||
overflowX="hidden"
|
||||
justifyContent="space-between"
|
||||
>
|
||||
<VStack w="full" flex={1} spacing={0}>
|
||||
|
||||
@@ -4,7 +4,7 @@ import NewScenarioButton from "./NewScenarioButton";
|
||||
import NewVariantButton from "./NewVariantButton";
|
||||
import ScenarioRow from "./ScenarioRow";
|
||||
import VariantEditor from "./VariantEditor";
|
||||
import VariantHeader from "./VariantHeader";
|
||||
import VariantHeader from "../VariantHeader/VariantHeader";
|
||||
import VariantStats from "./VariantStats";
|
||||
import { ScenariosHeader } from "./ScenariosHeader";
|
||||
import { stickyHeaderStyle } from "./styles";
|
||||
|
||||
83
src/components/SelectModelModal/ModelStatsCard.tsx
Normal file
83
src/components/SelectModelModal/ModelStatsCard.tsx
Normal file
@@ -0,0 +1,83 @@
|
||||
import {
|
||||
Heading,
|
||||
VStack,
|
||||
Text,
|
||||
HStack,
|
||||
type StackProps,
|
||||
Icon,
|
||||
Button,
|
||||
GridItem,
|
||||
SimpleGrid,
|
||||
} from "@chakra-ui/react";
|
||||
import { BsChevronRight } from "react-icons/bs";
|
||||
import { modelStats } from "~/server/modelStats";
|
||||
import { type SupportedModel } from "~/server/types";
|
||||
|
||||
export const ModelStatsCard = ({ label, model }: { label: string; model: SupportedModel }) => {
|
||||
const stats = modelStats[model];
|
||||
return (
|
||||
<VStack w="full" spacing={6} bgColor="gray.100" p={4} borderRadius={8}>
|
||||
<HStack w="full" justifyContent="space-between">
|
||||
<Text fontWeight="bold" fontSize="xs">
|
||||
{label}
|
||||
</Text>
|
||||
<Button variant="link" onClick={() => window.open(stats.learnMoreUrl, "_blank")}>
|
||||
<HStack alignItems="center" spacing={0} color="blue.500" fontWeight="bold">
|
||||
<Text fontSize="xs">Learn More</Text>
|
||||
<Icon as={BsChevronRight} boxSize={3} strokeWidth={1} />
|
||||
</HStack>
|
||||
</Button>
|
||||
</HStack>
|
||||
<HStack w="full" justifyContent="space-between">
|
||||
<Heading as="h3" size="md">
|
||||
{model}
|
||||
</Heading>
|
||||
<Text fontWeight="bold">{stats.provider}</Text>
|
||||
</HStack>
|
||||
<SimpleGrid
|
||||
w="full"
|
||||
justifyContent="space-between"
|
||||
alignItems="flex-start"
|
||||
fontSize="sm"
|
||||
columns={{ base: 2, md: 4 }}
|
||||
>
|
||||
<SelectedModelLabeledInfo label="Context" info={stats.contextLength} />
|
||||
<SelectedModelLabeledInfo
|
||||
label="Input"
|
||||
info={
|
||||
<Text>
|
||||
${(stats.promptTokenPrice * 1000).toFixed(3)}
|
||||
<Text color="gray.500"> / 1K tokens</Text>
|
||||
</Text>
|
||||
}
|
||||
/>
|
||||
<SelectedModelLabeledInfo
|
||||
label="Output"
|
||||
info={
|
||||
<Text>
|
||||
${(stats.promptTokenPrice * 1000).toFixed(3)}
|
||||
<Text color="gray.500"> / 1K tokens</Text>
|
||||
</Text>
|
||||
}
|
||||
/>
|
||||
<SelectedModelLabeledInfo label="Speed" info={<Text>{stats.speed}</Text>} />
|
||||
</SimpleGrid>
|
||||
</VStack>
|
||||
);
|
||||
};
|
||||
|
||||
const SelectedModelLabeledInfo = ({
|
||||
label,
|
||||
info,
|
||||
...props
|
||||
}: {
|
||||
label: string;
|
||||
info: string | number | React.ReactElement;
|
||||
} & StackProps) => (
|
||||
<GridItem>
|
||||
<VStack alignItems="flex-start" {...props}>
|
||||
<Text fontWeight="bold">{label}</Text>
|
||||
<Text>{info}</Text>
|
||||
</VStack>
|
||||
</GridItem>
|
||||
);
|
||||
77
src/components/SelectModelModal/SelectModelModal.tsx
Normal file
77
src/components/SelectModelModal/SelectModelModal.tsx
Normal file
@@ -0,0 +1,77 @@
|
||||
import {
|
||||
Button,
|
||||
Modal,
|
||||
ModalBody,
|
||||
ModalCloseButton,
|
||||
ModalContent,
|
||||
ModalFooter,
|
||||
ModalHeader,
|
||||
ModalOverlay,
|
||||
VStack,
|
||||
Text,
|
||||
Spinner,
|
||||
} from "@chakra-ui/react";
|
||||
import { useState } from "react";
|
||||
import { type SupportedModel } from "~/server/types";
|
||||
import { ModelStatsCard } from "./ModelStatsCard";
|
||||
import { SelectModelSearch } from "./SelectModelSearch";
|
||||
import { api } from "~/utils/api";
|
||||
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
|
||||
|
||||
export const SelectModelModal = ({
|
||||
originalModel,
|
||||
variantId,
|
||||
onClose,
|
||||
}: {
|
||||
originalModel: SupportedModel;
|
||||
variantId: string;
|
||||
onClose: () => void;
|
||||
}) => {
|
||||
const [selectedModel, setSelectedModel] = useState<SupportedModel>(originalModel);
|
||||
const utils = api.useContext();
|
||||
|
||||
const experiment = useExperiment();
|
||||
|
||||
const duplicateMutation = api.promptVariants.create.useMutation();
|
||||
|
||||
const [createNewVariant, creationInProgress] = useHandledAsyncCallback(async () => {
|
||||
if (!experiment?.data?.id) return;
|
||||
await duplicateMutation.mutateAsync({
|
||||
experimentId: experiment?.data?.id,
|
||||
variantId,
|
||||
newModel: selectedModel,
|
||||
});
|
||||
await utils.promptVariants.list.invalidate();
|
||||
onClose();
|
||||
}, [duplicateMutation, experiment?.data?.id, variantId, onClose]);
|
||||
|
||||
return (
|
||||
<Modal isOpen onClose={onClose} size={{ base: "xl", sm: "2xl", md: "3xl" }}>
|
||||
<ModalOverlay />
|
||||
<ModalContent w={1200}>
|
||||
<ModalHeader>Select a New Model</ModalHeader>
|
||||
<ModalCloseButton />
|
||||
<ModalBody maxW="unset">
|
||||
<VStack spacing={4}>
|
||||
<ModelStatsCard label="ORIGINAL MODEL" model={originalModel} />
|
||||
{originalModel !== selectedModel && (
|
||||
<ModelStatsCard label="SELECTED MODEL" model={selectedModel} />
|
||||
)}
|
||||
<SelectModelSearch selectedModel={selectedModel} setSelectedModel={setSelectedModel} />
|
||||
</VStack>
|
||||
</ModalBody>
|
||||
|
||||
<ModalFooter>
|
||||
<Button
|
||||
colorScheme="blue"
|
||||
onClick={createNewVariant}
|
||||
w={20}
|
||||
disabled={originalModel === selectedModel}
|
||||
>
|
||||
{creationInProgress ? <Spinner boxSize={4} /> : <Text>Continue</Text>}
|
||||
</Button>
|
||||
</ModalFooter>
|
||||
</ModalContent>
|
||||
</Modal>
|
||||
);
|
||||
};
|
||||
47
src/components/SelectModelModal/SelectModelSearch.tsx
Normal file
47
src/components/SelectModelModal/SelectModelSearch.tsx
Normal file
@@ -0,0 +1,47 @@
|
||||
import { VStack, Text } from "@chakra-ui/react";
|
||||
import { type LegacyRef, useCallback } from "react";
|
||||
import Select, { type SingleValue } from "react-select";
|
||||
import { type SupportedModel } from "~/server/types";
|
||||
import { useElementDimensions } from "~/utils/hooks";
|
||||
|
||||
const modelOptions: { value: SupportedModel; label: string }[] = [
|
||||
{ value: "gpt-3.5-turbo", label: "gpt-3.5-turbo" },
|
||||
{ value: "gpt-3.5-turbo-0613", label: "gpt-3.5-turbo-0613" },
|
||||
{ value: "gpt-3.5-turbo-16k", label: "gpt-3.5-turbo-16k" },
|
||||
{ value: "gpt-3.5-turbo-16k-0613", label: "gpt-3.5-turbo-16k-0613" },
|
||||
{ value: "gpt-4", label: "gpt-4" },
|
||||
{ value: "gpt-4-0613", label: "gpt-4-0613" },
|
||||
{ value: "gpt-4-32k", label: "gpt-4-32k" },
|
||||
{ value: "gpt-4-32k-0613", label: "gpt-4-32k-0613" },
|
||||
];
|
||||
|
||||
export const SelectModelSearch = ({
|
||||
selectedModel,
|
||||
setSelectedModel,
|
||||
}: {
|
||||
selectedModel: SupportedModel;
|
||||
setSelectedModel: (model: SupportedModel) => void;
|
||||
}) => {
|
||||
const handleSelection = useCallback(
|
||||
(option: SingleValue<{ value: SupportedModel; label: string }>) => {
|
||||
if (!option) return;
|
||||
setSelectedModel(option.value);
|
||||
},
|
||||
[setSelectedModel],
|
||||
);
|
||||
const selectedOption = modelOptions.find((option) => option.value === selectedModel);
|
||||
|
||||
const [containerRef, containerDimensions] = useElementDimensions();
|
||||
|
||||
return (
|
||||
<VStack ref={containerRef as LegacyRef<HTMLDivElement>} w="full">
|
||||
<Text>Browse Models</Text>
|
||||
<Select
|
||||
styles={{ control: (provided) => ({ ...provided, width: containerDimensions?.width }) }}
|
||||
value={selectedOption}
|
||||
options={modelOptions}
|
||||
onChange={handleSelection}
|
||||
/>
|
||||
</VStack>
|
||||
);
|
||||
};
|
||||
@@ -1,26 +1,13 @@
|
||||
import { useState, type DragEvent } from "react";
|
||||
import { type PromptVariant } from "./types";
|
||||
import { type PromptVariant } from "../OutputsTable/types";
|
||||
import { api } from "~/utils/api";
|
||||
import { useHandledAsyncCallback } from "~/utils/hooks";
|
||||
import {
|
||||
Button,
|
||||
HStack,
|
||||
Icon,
|
||||
Menu,
|
||||
MenuButton,
|
||||
MenuItem,
|
||||
MenuList,
|
||||
MenuDivider,
|
||||
Text,
|
||||
GridItem,
|
||||
Spinner,
|
||||
} from "@chakra-ui/react"; // Changed here
|
||||
import { BsFillTrashFill, BsGear } from "react-icons/bs";
|
||||
import { FaRegClone } from "react-icons/fa";
|
||||
import { RiDraggable, RiExchangeFundsFill } from "react-icons/ri";
|
||||
import { HStack, Icon, GridItem } from "@chakra-ui/react"; // Changed here
|
||||
import { RiDraggable } from "react-icons/ri";
|
||||
import { cellPadding, headerMinHeight } from "../constants";
|
||||
import AutoResizeTextArea from "../AutoResizeTextArea";
|
||||
import { stickyHeaderStyle } from "./styles";
|
||||
import { stickyHeaderStyle } from "../OutputsTable/styles";
|
||||
import VariantHeaderMenuButton from "./VariantHeaderMenuButton";
|
||||
|
||||
export default function VariantHeader(props: { variant: PromptVariant; canHide: boolean }) {
|
||||
const utils = api.useContext();
|
||||
@@ -38,14 +25,6 @@ export default function VariantHeader(props: { variant: PromptVariant; canHide:
|
||||
}
|
||||
}, [updateMutation, props.variant.id, props.variant.label, label]);
|
||||
|
||||
const hideMutation = api.promptVariants.hide.useMutation();
|
||||
const [onHide] = useHandledAsyncCallback(async () => {
|
||||
await hideMutation.mutateAsync({
|
||||
id: props.variant.id,
|
||||
});
|
||||
await utils.promptVariants.list.invalidate();
|
||||
}, [hideMutation, props.variant.id]);
|
||||
|
||||
const reorderMutation = api.promptVariants.reorder.useMutation();
|
||||
const [onReorder] = useHandledAsyncCallback(
|
||||
async (e: DragEvent<HTMLDivElement>) => {
|
||||
@@ -64,21 +43,13 @@ export default function VariantHeader(props: { variant: PromptVariant; canHide:
|
||||
);
|
||||
|
||||
const [menuOpen, setMenuOpen] = useState(false);
|
||||
const duplicateMutation = api.promptVariants.create.useMutation();
|
||||
|
||||
const [duplicateVariant, duplicationInProgress] = useHandledAsyncCallback(async () => {
|
||||
await duplicateMutation.mutateAsync({
|
||||
experimentId: props.variant.experimentId,
|
||||
variantId: props.variant.id,
|
||||
});
|
||||
await utils.promptVariants.list.invalidate();
|
||||
}, [duplicateMutation, props.variant.experimentId, props.variant.id]);
|
||||
|
||||
return (
|
||||
<GridItem
|
||||
padding={0}
|
||||
sx={{
|
||||
...stickyHeaderStyle,
|
||||
// Ensure that the menu always appears above the sticky header of other variants
|
||||
zIndex: menuOpen ? "dropdown" : stickyHeaderStyle.zIndex,
|
||||
}}
|
||||
borderTopWidth={1}
|
||||
@@ -129,42 +100,12 @@ export default function VariantHeader(props: { variant: PromptVariant; canHide:
|
||||
onMouseEnter={() => setIsInputHovered(true)}
|
||||
onMouseLeave={() => setIsInputHovered(false)}
|
||||
/>
|
||||
|
||||
<Menu
|
||||
z-index="dropdown"
|
||||
onOpen={() => setMenuOpen(true)}
|
||||
onClose={() => setMenuOpen(false)}
|
||||
>
|
||||
{duplicationInProgress ? (
|
||||
<Spinner boxSize={4} mx={3} my={3} />
|
||||
) : (
|
||||
<MenuButton>
|
||||
<Button variant="ghost">
|
||||
<Icon as={BsGear} />
|
||||
</Button>
|
||||
</MenuButton>
|
||||
)}
|
||||
|
||||
<MenuList mt={-3} fontSize="md">
|
||||
<MenuItem icon={<Icon as={FaRegClone} boxSize={4} w={5} />} onClick={duplicateVariant}>
|
||||
Duplicate
|
||||
</MenuItem>
|
||||
<MenuItem icon={<Icon as={RiExchangeFundsFill} boxSize={5} />}>Change Model</MenuItem>
|
||||
{props.canHide && (
|
||||
<>
|
||||
<MenuDivider />
|
||||
<MenuItem
|
||||
onClick={onHide}
|
||||
icon={<Icon as={BsFillTrashFill} boxSize={5} />}
|
||||
color="red.600"
|
||||
_hover={{ backgroundColor: "red.50" }}
|
||||
>
|
||||
<Text>Hide</Text>
|
||||
</MenuItem>
|
||||
</>
|
||||
)}
|
||||
</MenuList>
|
||||
</Menu>
|
||||
<VariantHeaderMenuButton
|
||||
variant={props.variant}
|
||||
canHide={props.canHide}
|
||||
menuOpen={menuOpen}
|
||||
setMenuOpen={setMenuOpen}
|
||||
/>
|
||||
</HStack>
|
||||
</GridItem>
|
||||
);
|
||||
102
src/components/VariantHeader/VariantHeaderMenuButton.tsx
Normal file
102
src/components/VariantHeader/VariantHeaderMenuButton.tsx
Normal file
@@ -0,0 +1,102 @@
|
||||
import { type PromptVariant } from "../OutputsTable/types";
|
||||
import { api } from "~/utils/api";
|
||||
import { useHandledAsyncCallback } from "~/utils/hooks";
|
||||
import {
|
||||
Button,
|
||||
Icon,
|
||||
Menu,
|
||||
MenuButton,
|
||||
MenuItem,
|
||||
MenuList,
|
||||
MenuDivider,
|
||||
Text,
|
||||
Spinner,
|
||||
} from "@chakra-ui/react"; // Changed here
|
||||
import { BsFillTrashFill, BsGear } from "react-icons/bs";
|
||||
import { FaRegClone } from "react-icons/fa";
|
||||
import { RiExchangeFundsFill } from "react-icons/ri";
|
||||
import { useState } from "react";
|
||||
import { SelectModelModal } from "../SelectModelModal/SelectModelModal";
|
||||
import { type SupportedModel } from "~/server/types";
|
||||
|
||||
export default function VariantHeaderMenuButton({
|
||||
variant,
|
||||
canHide,
|
||||
menuOpen,
|
||||
setMenuOpen,
|
||||
}: {
|
||||
variant: PromptVariant;
|
||||
canHide: boolean;
|
||||
menuOpen: boolean;
|
||||
setMenuOpen: (open: boolean) => void;
|
||||
}) {
|
||||
const utils = api.useContext();
|
||||
|
||||
const duplicateMutation = api.promptVariants.create.useMutation();
|
||||
|
||||
const [duplicateVariant, duplicationInProgress] = useHandledAsyncCallback(async () => {
|
||||
await duplicateMutation.mutateAsync({
|
||||
experimentId: variant.experimentId,
|
||||
variantId: variant.id,
|
||||
});
|
||||
await utils.promptVariants.list.invalidate();
|
||||
}, [duplicateMutation, variant.experimentId, variant.id]);
|
||||
|
||||
const hideMutation = api.promptVariants.hide.useMutation();
|
||||
const [onHide] = useHandledAsyncCallback(async () => {
|
||||
await hideMutation.mutateAsync({
|
||||
id: variant.id,
|
||||
});
|
||||
await utils.promptVariants.list.invalidate();
|
||||
}, [hideMutation, variant.id]);
|
||||
|
||||
const [selectModelModalOpen, setSelectModelModalOpen] = useState(false);
|
||||
|
||||
return (
|
||||
<>
|
||||
<Menu isOpen={menuOpen} onOpen={() => setMenuOpen(true)} onClose={() => setMenuOpen(false)}>
|
||||
{duplicationInProgress ? (
|
||||
<Spinner boxSize={4} mx={3} my={3} />
|
||||
) : (
|
||||
<MenuButton>
|
||||
<Button variant="ghost">
|
||||
<Icon as={BsGear} />
|
||||
</Button>
|
||||
</MenuButton>
|
||||
)}
|
||||
|
||||
<MenuList mt={-3} fontSize="md">
|
||||
<MenuItem icon={<Icon as={FaRegClone} boxSize={4} w={5} />} onClick={duplicateVariant}>
|
||||
Duplicate
|
||||
</MenuItem>
|
||||
<MenuItem
|
||||
icon={<Icon as={RiExchangeFundsFill} boxSize={5} />}
|
||||
onClick={() => setSelectModelModalOpen(true)}
|
||||
>
|
||||
Change Model
|
||||
</MenuItem>
|
||||
{canHide && (
|
||||
<>
|
||||
<MenuDivider />
|
||||
<MenuItem
|
||||
onClick={onHide}
|
||||
icon={<Icon as={BsFillTrashFill} boxSize={5} />}
|
||||
color="red.600"
|
||||
_hover={{ backgroundColor: "red.50" }}
|
||||
>
|
||||
<Text>Hide</Text>
|
||||
</MenuItem>
|
||||
</>
|
||||
)}
|
||||
</MenuList>
|
||||
</Menu>
|
||||
{selectModelModalOpen && (
|
||||
<SelectModelModal
|
||||
originalModel={variant.model as SupportedModel}
|
||||
variantId={variant.id}
|
||||
onClose={() => setSelectModelModalOpen(false)}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
}
|
||||
@@ -10,6 +10,11 @@ export const env = createEnv({
|
||||
DATABASE_URL: z.string().url(),
|
||||
NODE_ENV: z.enum(["development", "test", "production"]).default("development"),
|
||||
OPENAI_API_KEY: z.string().min(1),
|
||||
RESTRICT_PRISMA_LOGS: z
|
||||
.string()
|
||||
.optional()
|
||||
.default("false")
|
||||
.transform((val) => val.toLowerCase() === "true"),
|
||||
},
|
||||
|
||||
/**
|
||||
@@ -35,6 +40,7 @@ export const env = createEnv({
|
||||
DATABASE_URL: process.env.DATABASE_URL,
|
||||
NODE_ENV: process.env.NODE_ENV,
|
||||
OPENAI_API_KEY: process.env.OPENAI_API_KEY,
|
||||
RESTRICT_PRISMA_LOGS: process.env.RESTRICT_PRISMA_LOGS,
|
||||
NEXT_PUBLIC_POSTHOG_KEY: process.env.NEXT_PUBLIC_POSTHOG_KEY,
|
||||
NEXT_PUBLIC_IS_PUBLIC_PLAYGROUND: process.env.NEXT_PUBLIC_IS_PUBLIC_PLAYGROUND,
|
||||
NEXT_PUBLIC_SOCKET_URL: process.env.NEXT_PUBLIC_SOCKET_URL,
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
import dedent from "dedent";
|
||||
import { isObject } from "lodash-es";
|
||||
import { z } from "zod";
|
||||
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
|
||||
import { prisma } from "~/server/db";
|
||||
import { generateNewCell } from "~/server/utils/generateNewCell";
|
||||
import { OpenAIChatModel } from "~/server/types";
|
||||
import { OpenAIChatModel, type SupportedModel } from "~/server/types";
|
||||
import { constructPrompt } from "~/server/utils/constructPrompt";
|
||||
import userError from "~/server/utils/error";
|
||||
import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
|
||||
import { calculateTokenCost } from "~/utils/calculateTokenCost";
|
||||
import { reorderPromptVariants } from "~/server/utils/reorderPromptVariants";
|
||||
import { type PromptVariant } from "@prisma/client";
|
||||
import { deriveNewConstructFn } from "~/server/utils/deriveNewContructFn";
|
||||
|
||||
export const promptVariantsRouter = createTRPCRouter({
|
||||
list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
|
||||
@@ -138,6 +138,7 @@ export const promptVariantsRouter = createTRPCRouter({
|
||||
z.object({
|
||||
experimentId: z.string(),
|
||||
variantId: z.string().optional(),
|
||||
newModel: z.string().optional(),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input }) => {
|
||||
@@ -177,23 +178,17 @@ export const promptVariantsRouter = createTRPCRouter({
|
||||
? `${originalVariant?.label} Copy`
|
||||
: `Prompt Variant ${largestSortIndex + 2}`;
|
||||
|
||||
const newConstructFn = await deriveNewConstructFn(
|
||||
originalVariant,
|
||||
input.newModel as SupportedModel,
|
||||
);
|
||||
|
||||
const createNewVariantAction = prisma.promptVariant.create({
|
||||
data: {
|
||||
experimentId: input.experimentId,
|
||||
label: newVariantLabel,
|
||||
sortIndex: (originalVariant?.sortIndex ?? 0) + 1,
|
||||
constructFn:
|
||||
originalVariant?.constructFn ??
|
||||
dedent`
|
||||
prompt = {
|
||||
model: "gpt-3.5-turbo",
|
||||
messages: [
|
||||
{
|
||||
role: "system",
|
||||
content: "Return 'Hello, world!'",
|
||||
}
|
||||
]
|
||||
}`,
|
||||
constructFn: newConstructFn,
|
||||
model: originalVariant?.model ?? "gpt-3.5-turbo",
|
||||
},
|
||||
});
|
||||
|
||||
@@ -8,7 +8,10 @@ const globalForPrisma = globalThis as unknown as {
|
||||
export const prisma =
|
||||
globalForPrisma.prisma ??
|
||||
new PrismaClient({
|
||||
log: env.NODE_ENV === "development" ? ["query", "error", "warn"] : ["error"],
|
||||
log:
|
||||
env.NODE_ENV === "development" && !env.RESTRICT_PRISMA_LOGS
|
||||
? ["query", "error", "warn"]
|
||||
: ["error"],
|
||||
});
|
||||
|
||||
if (env.NODE_ENV !== "production") globalForPrisma.prisma = prisma;
|
||||
|
||||
77
src/server/modelStats.ts
Normal file
77
src/server/modelStats.ts
Normal file
@@ -0,0 +1,77 @@
|
||||
import { type SupportedModel } from "./types";
|
||||
|
||||
interface ModelStats {
|
||||
contextLength: number;
|
||||
promptTokenPrice: number;
|
||||
completionTokenPrice: number;
|
||||
speed: "fast" | "medium" | "slow";
|
||||
provider: "OpenAI";
|
||||
learnMoreUrl: string;
|
||||
}
|
||||
|
||||
export const modelStats: Record<SupportedModel, ModelStats> = {
|
||||
"gpt-4": {
|
||||
contextLength: 8192,
|
||||
promptTokenPrice: 0.00003,
|
||||
completionTokenPrice: 0.00006,
|
||||
speed: "medium",
|
||||
provider: "OpenAI",
|
||||
learnMoreUrl: "https://openai.com/gpt-4",
|
||||
},
|
||||
"gpt-4-0613": {
|
||||
contextLength: 8192,
|
||||
promptTokenPrice: 0.00003,
|
||||
completionTokenPrice: 0.00006,
|
||||
speed: "medium",
|
||||
provider: "OpenAI",
|
||||
learnMoreUrl: "https://openai.com/gpt-4",
|
||||
},
|
||||
"gpt-4-32k": {
|
||||
contextLength: 32768,
|
||||
promptTokenPrice: 0.00006,
|
||||
completionTokenPrice: 0.00012,
|
||||
speed: "medium",
|
||||
provider: "OpenAI",
|
||||
learnMoreUrl: "https://openai.com/gpt-4",
|
||||
},
|
||||
"gpt-4-32k-0613": {
|
||||
contextLength: 32768,
|
||||
promptTokenPrice: 0.00006,
|
||||
completionTokenPrice: 0.00012,
|
||||
speed: "medium",
|
||||
provider: "OpenAI",
|
||||
learnMoreUrl: "https://openai.com/gpt-4",
|
||||
},
|
||||
"gpt-3.5-turbo": {
|
||||
contextLength: 4096,
|
||||
promptTokenPrice: 0.0000015,
|
||||
completionTokenPrice: 0.000002,
|
||||
speed: "fast",
|
||||
provider: "OpenAI",
|
||||
learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
|
||||
},
|
||||
"gpt-3.5-turbo-0613": {
|
||||
contextLength: 4096,
|
||||
promptTokenPrice: 0.0000015,
|
||||
completionTokenPrice: 0.000002,
|
||||
speed: "fast",
|
||||
provider: "OpenAI",
|
||||
learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
|
||||
},
|
||||
"gpt-3.5-turbo-16k": {
|
||||
contextLength: 16384,
|
||||
promptTokenPrice: 0.000003,
|
||||
completionTokenPrice: 0.000004,
|
||||
speed: "fast",
|
||||
provider: "OpenAI",
|
||||
learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
|
||||
},
|
||||
"gpt-3.5-turbo-16k-0613": {
|
||||
contextLength: 16384,
|
||||
promptTokenPrice: 0.000003,
|
||||
completionTokenPrice: 0.000004,
|
||||
speed: "fast",
|
||||
provider: "OpenAI",
|
||||
learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
|
||||
},
|
||||
};
|
||||
@@ -67,6 +67,14 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
|
||||
include: { modelOutput: true },
|
||||
});
|
||||
if (!cell) {
|
||||
await prisma.scenarioVariantCell.update({
|
||||
where: { id: scenarioVariantCellId },
|
||||
data: {
|
||||
statusCode: 404,
|
||||
errorMessage: "Cell not found",
|
||||
retrievalStatus: "ERROR",
|
||||
},
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -85,6 +93,14 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
|
||||
where: { id: cell.promptVariantId },
|
||||
});
|
||||
if (!variant) {
|
||||
await prisma.scenarioVariantCell.update({
|
||||
where: { id: scenarioVariantCellId },
|
||||
data: {
|
||||
statusCode: 404,
|
||||
errorMessage: "Prompt Variant not found",
|
||||
retrievalStatus: "ERROR",
|
||||
},
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -92,6 +108,14 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
|
||||
where: { id: cell.testScenarioId },
|
||||
});
|
||||
if (!scenario) {
|
||||
await prisma.scenarioVariantCell.update({
|
||||
where: { id: scenarioVariantCellId },
|
||||
data: {
|
||||
statusCode: 404,
|
||||
errorMessage: "Scenario not found",
|
||||
retrievalStatus: "ERROR",
|
||||
},
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
107
src/server/utils/deriveNewContructFn.ts
Normal file
107
src/server/utils/deriveNewContructFn.ts
Normal file
@@ -0,0 +1,107 @@
|
||||
import { type PromptVariant } from "@prisma/client";
|
||||
import { type SupportedModel } from "../types";
|
||||
import ivm from "isolated-vm";
|
||||
import dedent from "dedent";
|
||||
import { openai } from "./openai";
|
||||
import { getApiShapeForModel } from "./getTypesForModel";
|
||||
import { isObject } from "lodash-es";
|
||||
|
||||
const isolate = new ivm.Isolate({ memoryLimit: 128 });
|
||||
|
||||
export async function deriveNewConstructFn(
|
||||
originalVariant: PromptVariant | null,
|
||||
newModel?: SupportedModel,
|
||||
) {
|
||||
if (originalVariant && !newModel) {
|
||||
return originalVariant.constructFn;
|
||||
}
|
||||
if (originalVariant && newModel) {
|
||||
return await getPromptFunctionForNewModel(originalVariant, newModel);
|
||||
}
|
||||
return dedent`
|
||||
prompt = {
|
||||
model: "gpt-3.5-turbo",
|
||||
messages: [
|
||||
{
|
||||
role: "system",
|
||||
content: "Return 'Hello, world!'",
|
||||
}
|
||||
]
|
||||
}`;
|
||||
}
|
||||
|
||||
const NUM_RETRIES = 5;
|
||||
const getPromptFunctionForNewModel = async (
|
||||
originalVariant: PromptVariant,
|
||||
newModel: SupportedModel,
|
||||
) => {
|
||||
const originalModel = originalVariant.model as SupportedModel;
|
||||
let newContructionFn = "";
|
||||
for (let i = 0; i < NUM_RETRIES; i++) {
|
||||
try {
|
||||
// TODO: Add api shape info to prompt
|
||||
const completion = await openai.chat.completions.create({
|
||||
model: "gpt-4",
|
||||
messages: [
|
||||
{
|
||||
role: "system",
|
||||
content: `Your job is to translate prompt constructor functions from ${originalModel} to ${newModel}. Here are is the api shape for the original model:\n---\n${JSON.stringify(
|
||||
getApiShapeForModel(originalModel),
|
||||
null,
|
||||
2,
|
||||
)}\n\nThe prompt variable has already been declared.}`,
|
||||
},
|
||||
{
|
||||
role: "user",
|
||||
content: `Return the prompt constructor function for ${newModel} given the following prompt constructor function for ${originalModel}:\n---\n${originalVariant.constructFn}`,
|
||||
},
|
||||
],
|
||||
functions: [
|
||||
{
|
||||
name: "translate_prompt_constructor_function",
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: {
|
||||
new_prompt_function: {
|
||||
type: "string",
|
||||
description: "The new prompt function, runnable in typescript",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
function_call: {
|
||||
name: "translate_prompt_constructor_function",
|
||||
},
|
||||
});
|
||||
const argString = completion.choices[0]?.message?.function_call?.arguments || "{}";
|
||||
|
||||
const code = `
|
||||
global.contructPromptFunctionArgs = ${argString};
|
||||
`;
|
||||
|
||||
const context = await isolate.createContext();
|
||||
|
||||
const jail = context.global;
|
||||
await jail.set("global", jail.derefInto());
|
||||
|
||||
const script = await isolate.compileScript(code);
|
||||
|
||||
await script.run(context);
|
||||
const contructPromptFunctionArgs = (await context.global.get(
|
||||
"contructPromptFunctionArgs",
|
||||
)) as ivm.Reference;
|
||||
|
||||
const args = await contructPromptFunctionArgs.copy(); // Get the actual value from the isolate
|
||||
|
||||
if (args && isObject(args) && "new_prompt_function" in args) {
|
||||
newContructionFn = args.new_prompt_function as string;
|
||||
break;
|
||||
}
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
}
|
||||
}
|
||||
|
||||
return newContructionFn;
|
||||
};
|
||||
@@ -21,6 +21,13 @@ export type CompletionResponse = {
|
||||
export async function getCompletion(
|
||||
payload: CompletionCreateParams,
|
||||
channel?: string,
|
||||
): Promise<CompletionResponse> {
|
||||
return getOpenAIChatCompletion(payload, channel);
|
||||
}
|
||||
|
||||
export async function getOpenAIChatCompletion(
|
||||
payload: CompletionCreateParams,
|
||||
channel?: string,
|
||||
): Promise<CompletionResponse> {
|
||||
// If functions are enabled, disable streaming so that we get the full response with token counts
|
||||
if (payload.functions?.length) payload.stream = false;
|
||||
|
||||
7
src/server/utils/getTypesForModel.ts
Normal file
7
src/server/utils/getTypesForModel.ts
Normal file
@@ -0,0 +1,7 @@
|
||||
import { OpenAIChatModel, type SupportedModel } from "../types";
|
||||
import openAIChatApiShape from "~/codegen/openai.types.ts.txt";
|
||||
|
||||
export const getApiShapeForModel = (model: SupportedModel) => {
|
||||
if (model in OpenAIChatModel) return openAIChatApiShape;
|
||||
return "";
|
||||
};
|
||||
@@ -1,27 +1,6 @@
|
||||
import { modelStats } from "~/server/modelStats";
|
||||
import { type SupportedModel, OpenAIChatModel } from "~/server/types";
|
||||
|
||||
const openAIPromptTokensToDollars: { [key in OpenAIChatModel]: number } = {
|
||||
"gpt-4": 0.00003,
|
||||
"gpt-4-0613": 0.00003,
|
||||
"gpt-4-32k": 0.00006,
|
||||
"gpt-4-32k-0613": 0.00006,
|
||||
"gpt-3.5-turbo": 0.0000015,
|
||||
"gpt-3.5-turbo-0613": 0.0000015,
|
||||
"gpt-3.5-turbo-16k": 0.000003,
|
||||
"gpt-3.5-turbo-16k-0613": 0.000003,
|
||||
};
|
||||
|
||||
const openAICompletionTokensToDollars: { [key in OpenAIChatModel]: number } = {
|
||||
"gpt-4": 0.00006,
|
||||
"gpt-4-0613": 0.00006,
|
||||
"gpt-4-32k": 0.00012,
|
||||
"gpt-4-32k-0613": 0.00012,
|
||||
"gpt-3.5-turbo": 0.000002,
|
||||
"gpt-3.5-turbo-0613": 0.000002,
|
||||
"gpt-3.5-turbo-16k": 0.000004,
|
||||
"gpt-3.5-turbo-16k-0613": 0.000004,
|
||||
};
|
||||
|
||||
export const calculateTokenCost = (
|
||||
model: SupportedModel | string | null,
|
||||
numTokens: number,
|
||||
@@ -40,7 +19,7 @@ const calculateOpenAIChatTokenCost = (
|
||||
isCompletion: boolean,
|
||||
) => {
|
||||
const tokensToDollars = isCompletion
|
||||
? openAICompletionTokensToDollars[model]
|
||||
: openAIPromptTokensToDollars[model];
|
||||
? modelStats[model].completionTokenPrice
|
||||
: modelStats[model].promptTokenPrice;
|
||||
return tokensToDollars * numTokens;
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user