Allow user to create a version of their current prompt with a new model (#58)

* Add dropdown header for model switching

* Allow variant duplication

* Fix prettier

* Use env variable to restrict prisma logs

* Fix env.mjs

* Remove unnecessary scroll bar from function call output

* Properly record when 404 error occurs in queryLLM task

* Add SelectedModelInfo in SelectModelModal

* Add react-select

* Calculate new prompt after switching model

* Send newly selected model with creation request

* Get new prompt construction function back from GPT-4

* Fix prettier

* Fix prettier
arcticfly authored on 2023-07-18 18:24:04 -07:00 (committed by GitHub)
parent fa5b1ab1c5
commit e0e64c4207
18 changed files with 634 additions and 113 deletions


@@ -61,6 +61,7 @@
"react": "18.2.0",
"react-dom": "18.2.0",
"react-icons": "^4.10.1",
"react-select": "^5.7.4",
"react-syntax-highlighter": "^15.5.0",
"react-textarea-autosize": "^8.5.0",
"socket.io": "^4.7.1",

pnpm-lock.yaml (generated)

@@ -1,4 +1,4 @@
lockfileVersion: '6.1'
lockfileVersion: '6.0'
settings:
autoInstallPeers: true
@@ -128,6 +128,9 @@ dependencies:
react-icons:
specifier: ^4.10.1
version: 4.10.1(react@18.2.0)
react-select:
specifier: ^5.7.4
version: 5.7.4(@types/react@18.2.6)(react-dom@18.2.0)(react@18.2.0)
react-syntax-highlighter:
specifier: ^15.5.0
version: 15.5.0(react@18.2.0)
@@ -2248,6 +2251,16 @@ packages:
engines: {node: ^12.22.0 || ^14.17.0 || >=16.0.0}
dev: true
/@floating-ui/core@1.3.1:
resolution: {integrity: sha512-Bu+AMaXNjrpjh41znzHqaz3r2Nr8hHuHZT6V2LBKMhyMl0FgKA62PNYbqnfgmzOhoWZj70Zecisbo4H1rotP5g==}
dev: false
/@floating-ui/dom@1.4.5:
resolution: {integrity: sha512-96KnRWkRnuBSSFbj0sFGwwOUd8EkiecINVl0O9wiZlZ64EkpyAOG3Xc2vKKNJmru0Z7RqWNymA+6b8OZqjgyyw==}
dependencies:
'@floating-ui/core': 1.3.1
dev: false
/@graphile/logger@0.2.0:
resolution: {integrity: sha512-jjcWBokl9eb1gVJ85QmoaQ73CQ52xAaOCF29ukRbYNl6lY+ts0ErTaDYOBlejcbUs2OpaiqYLO5uDhyLFzWw4w==}
dev: false
@@ -2832,6 +2845,12 @@ packages:
'@types/react': 18.2.6
dev: true
/@types/react-transition-group@4.4.6:
resolution: {integrity: sha512-VnCdSxfcm08KjsJVQcfBmhEQAPnLB8G08hAxn39azX1qYBQ/5RVQuoHuKIcfKOdncuaUvEpFKFzEvbtIMsfVew==}
dependencies:
'@types/react': 18.2.6
dev: false
/@types/react@18.2.6:
resolution: {integrity: sha512-wRZClXn//zxCFW+ye/D2qY65UsYP1Fpex2YXorHc8awoNamkMZSvBxwxdYVInsHOZZd2Ppq8isnSzJL5Mpf8OA==}
dependencies:
@@ -3896,6 +3915,13 @@ packages:
esutils: 2.0.3
dev: true
/dom-helpers@5.2.1:
resolution: {integrity: sha512-nRCa7CK3VTrM2NmGkIy4cbK7IZlgBE/PYMn55rrXefr5xXDP0LdtfPnblFDoVdcAfslJ7or6iqAUnx0CCGIWQA==}
dependencies:
'@babel/runtime': 7.22.6
csstype: 3.1.2
dev: false
/dotenv@16.3.1:
resolution: {integrity: sha512-IPzF4w4/Rd94bA9imS68tZBaYyBWSCE47V1RGuMrB94iyTOIEwRmVL2x/4An+6mETpLrKJ5hQkB8W4kFAadeIQ==}
engines: {node: '>=12'}
@@ -5451,6 +5477,10 @@ packages:
engines: {node: '>= 0.6'}
dev: false
/memoize-one@6.0.0:
resolution: {integrity: sha512-rkpe71W0N0c0Xz6QD0eJETuWAJGnJ9afsl1srmwPrI+yBCkge5EycXXbYRyvL29zZVUWQCY7InPRCv3GDXuZNw==}
dev: false
/merge-descriptors@1.0.1:
resolution: {integrity: sha512-cCi6g3/Zr1iqQi6ySbseM1Xvooa98N0w31jzUYrXPX2xqObmFGHJ0tQ5u74H3mVh7wLouTseZyYIq39g8cNp1w==}
dev: false
@@ -6370,6 +6400,27 @@ packages:
use-sidecar: 1.1.2(@types/react@18.2.6)(react@18.2.0)
dev: false
/react-select@5.7.4(@types/react@18.2.6)(react-dom@18.2.0)(react@18.2.0):
resolution: {integrity: sha512-NhuE56X+p9QDFh4BgeygHFIvJJszO1i1KSkg/JPcIJrbovyRtI+GuOEa4XzFCEpZRAEoEI8u/cAHK+jG/PgUzQ==}
peerDependencies:
react: ^16.8.0 || ^17.0.0 || ^18.0.0
react-dom: ^16.8.0 || ^17.0.0 || ^18.0.0
dependencies:
'@babel/runtime': 7.22.6
'@emotion/cache': 11.11.0
'@emotion/react': 11.11.1(@types/react@18.2.6)(react@18.2.0)
'@floating-ui/dom': 1.4.5
'@types/react-transition-group': 4.4.6
memoize-one: 6.0.0
prop-types: 15.8.1
react: 18.2.0
react-dom: 18.2.0(react@18.2.0)
react-transition-group: 4.4.5(react-dom@18.2.0)(react@18.2.0)
use-isomorphic-layout-effect: 1.1.2(@types/react@18.2.6)(react@18.2.0)
transitivePeerDependencies:
- '@types/react'
dev: false
/react-ssr-prepass@1.5.0(react@18.2.0):
resolution: {integrity: sha512-yFNHrlVEReVYKsLI5lF05tZoHveA5pGzjFbFJY/3pOqqjGOmMmqx83N4hIjN2n6E1AOa+eQEUxs3CgRnPmT0RQ==}
peerDependencies:
@@ -6422,6 +6473,20 @@ packages:
- '@types/react'
dev: false
/react-transition-group@4.4.5(react-dom@18.2.0)(react@18.2.0):
resolution: {integrity: sha512-pZcd1MCJoiKiBR2NRxeCRg13uCXbydPnmB4EOeRrY7480qNWO8IIgQG6zlDkm6uRMsURXPuKq0GWtiM59a5Q6g==}
peerDependencies:
react: '>=16.6.0'
react-dom: '>=16.6.0'
dependencies:
'@babel/runtime': 7.22.6
dom-helpers: 5.2.1
loose-envify: 1.4.0
prop-types: 15.8.1
react: 18.2.0
react-dom: 18.2.0(react@18.2.0)
dev: false
/react@18.2.0:
resolution: {integrity: sha512-/3IjMdb2L9QbBdWiW5e3P2/npwMBaU9mHCSCUzNln0ZCYbcfTsGbTJrU/kGemdH2IWmB2ioZ+zkxtmq6g09fGQ==}
engines: {node: '>=0.10.0'}


@@ -106,7 +106,7 @@ export default function OutputCell({
h="100%"
fontSize="xs"
flexWrap="wrap"
overflowX="auto"
overflowX="hidden"
justifyContent="space-between"
>
<VStack w="full" flex={1} spacing={0}>


@@ -4,7 +4,7 @@ import NewScenarioButton from "./NewScenarioButton";
import NewVariantButton from "./NewVariantButton";
import ScenarioRow from "./ScenarioRow";
import VariantEditor from "./VariantEditor";
import VariantHeader from "./VariantHeader";
import VariantHeader from "../VariantHeader/VariantHeader";
import VariantStats from "./VariantStats";
import { ScenariosHeader } from "./ScenariosHeader";
import { stickyHeaderStyle } from "./styles";


@@ -0,0 +1,83 @@
import {
Heading,
VStack,
Text,
HStack,
type StackProps,
Icon,
Button,
GridItem,
SimpleGrid,
} from "@chakra-ui/react";
import { BsChevronRight } from "react-icons/bs";
import { modelStats } from "~/server/modelStats";
import { type SupportedModel } from "~/server/types";
export const ModelStatsCard = ({ label, model }: { label: string; model: SupportedModel }) => {
const stats = modelStats[model];
return (
<VStack w="full" spacing={6} bgColor="gray.100" p={4} borderRadius={8}>
<HStack w="full" justifyContent="space-between">
<Text fontWeight="bold" fontSize="xs">
{label}
</Text>
<Button variant="link" onClick={() => window.open(stats.learnMoreUrl, "_blank")}>
<HStack alignItems="center" spacing={0} color="blue.500" fontWeight="bold">
<Text fontSize="xs">Learn More</Text>
<Icon as={BsChevronRight} boxSize={3} strokeWidth={1} />
</HStack>
</Button>
</HStack>
<HStack w="full" justifyContent="space-between">
<Heading as="h3" size="md">
{model}
</Heading>
<Text fontWeight="bold">{stats.provider}</Text>
</HStack>
<SimpleGrid
w="full"
justifyContent="space-between"
alignItems="flex-start"
fontSize="sm"
columns={{ base: 2, md: 4 }}
>
<SelectedModelLabeledInfo label="Context" info={stats.contextLength} />
<SelectedModelLabeledInfo
label="Input"
info={
<Text>
${(stats.promptTokenPrice * 1000).toFixed(3)}
<Text color="gray.500"> / 1K tokens</Text>
</Text>
}
/>
<SelectedModelLabeledInfo
label="Output"
info={
<Text>
${(stats.completionTokenPrice * 1000).toFixed(3)}
<Text color="gray.500"> / 1K tokens</Text>
</Text>
}
/>
<SelectedModelLabeledInfo label="Speed" info={<Text>{stats.speed}</Text>} />
</SimpleGrid>
</VStack>
);
};
const SelectedModelLabeledInfo = ({
label,
info,
...props
}: {
label: string;
info: string | number | React.ReactElement;
} & StackProps) => (
<GridItem>
<VStack alignItems="flex-start" {...props}>
<Text fontWeight="bold">{label}</Text>
<Text>{info}</Text>
</VStack>
</GridItem>
);


@@ -0,0 +1,77 @@
import {
Button,
Modal,
ModalBody,
ModalCloseButton,
ModalContent,
ModalFooter,
ModalHeader,
ModalOverlay,
VStack,
Text,
Spinner,
} from "@chakra-ui/react";
import { useState } from "react";
import { type SupportedModel } from "~/server/types";
import { ModelStatsCard } from "./ModelStatsCard";
import { SelectModelSearch } from "./SelectModelSearch";
import { api } from "~/utils/api";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
export const SelectModelModal = ({
originalModel,
variantId,
onClose,
}: {
originalModel: SupportedModel;
variantId: string;
onClose: () => void;
}) => {
const [selectedModel, setSelectedModel] = useState<SupportedModel>(originalModel);
const utils = api.useContext();
const experiment = useExperiment();
const duplicateMutation = api.promptVariants.create.useMutation();
const [createNewVariant, creationInProgress] = useHandledAsyncCallback(async () => {
if (!experiment?.data?.id) return;
await duplicateMutation.mutateAsync({
experimentId: experiment?.data?.id,
variantId,
newModel: selectedModel,
});
await utils.promptVariants.list.invalidate();
onClose();
}, [duplicateMutation, experiment?.data?.id, variantId, selectedModel, onClose]);
return (
<Modal isOpen onClose={onClose} size={{ base: "xl", sm: "2xl", md: "3xl" }}>
<ModalOverlay />
<ModalContent w={1200}>
<ModalHeader>Select a New Model</ModalHeader>
<ModalCloseButton />
<ModalBody maxW="unset">
<VStack spacing={4}>
<ModelStatsCard label="ORIGINAL MODEL" model={originalModel} />
{originalModel !== selectedModel && (
<ModelStatsCard label="SELECTED MODEL" model={selectedModel} />
)}
<SelectModelSearch selectedModel={selectedModel} setSelectedModel={setSelectedModel} />
</VStack>
</ModalBody>
<ModalFooter>
<Button
colorScheme="blue"
onClick={createNewVariant}
w={20}
disabled={originalModel === selectedModel}
>
{creationInProgress ? <Spinner boxSize={4} /> : <Text>Continue</Text>}
</Button>
</ModalFooter>
</ModalContent>
</Modal>
);
};


@@ -0,0 +1,47 @@
import { VStack, Text } from "@chakra-ui/react";
import { type LegacyRef, useCallback } from "react";
import Select, { type SingleValue } from "react-select";
import { type SupportedModel } from "~/server/types";
import { useElementDimensions } from "~/utils/hooks";
const modelOptions: { value: SupportedModel; label: string }[] = [
{ value: "gpt-3.5-turbo", label: "gpt-3.5-turbo" },
{ value: "gpt-3.5-turbo-0613", label: "gpt-3.5-turbo-0613" },
{ value: "gpt-3.5-turbo-16k", label: "gpt-3.5-turbo-16k" },
{ value: "gpt-3.5-turbo-16k-0613", label: "gpt-3.5-turbo-16k-0613" },
{ value: "gpt-4", label: "gpt-4" },
{ value: "gpt-4-0613", label: "gpt-4-0613" },
{ value: "gpt-4-32k", label: "gpt-4-32k" },
{ value: "gpt-4-32k-0613", label: "gpt-4-32k-0613" },
];
export const SelectModelSearch = ({
selectedModel,
setSelectedModel,
}: {
selectedModel: SupportedModel;
setSelectedModel: (model: SupportedModel) => void;
}) => {
const handleSelection = useCallback(
(option: SingleValue<{ value: SupportedModel; label: string }>) => {
if (!option) return;
setSelectedModel(option.value);
},
[setSelectedModel],
);
const selectedOption = modelOptions.find((option) => option.value === selectedModel);
const [containerRef, containerDimensions] = useElementDimensions();
return (
<VStack ref={containerRef as LegacyRef<HTMLDivElement>} w="full">
<Text>Browse Models</Text>
<Select
styles={{ control: (provided) => ({ ...provided, width: containerDimensions?.width }) }}
value={selectedOption}
options={modelOptions}
onChange={handleSelection}
/>
</VStack>
);
};


@@ -1,26 +1,13 @@
import { useState, type DragEvent } from "react";
import { type PromptVariant } from "./types";
import { type PromptVariant } from "../OutputsTable/types";
import { api } from "~/utils/api";
import { useHandledAsyncCallback } from "~/utils/hooks";
import {
Button,
HStack,
Icon,
Menu,
MenuButton,
MenuItem,
MenuList,
MenuDivider,
Text,
GridItem,
Spinner,
} from "@chakra-ui/react"; // Changed here
import { BsFillTrashFill, BsGear } from "react-icons/bs";
import { FaRegClone } from "react-icons/fa";
import { RiDraggable, RiExchangeFundsFill } from "react-icons/ri";
import { HStack, Icon, GridItem } from "@chakra-ui/react"; // Changed here
import { RiDraggable } from "react-icons/ri";
import { cellPadding, headerMinHeight } from "../constants";
import AutoResizeTextArea from "../AutoResizeTextArea";
import { stickyHeaderStyle } from "./styles";
import { stickyHeaderStyle } from "../OutputsTable/styles";
import VariantHeaderMenuButton from "./VariantHeaderMenuButton";
export default function VariantHeader(props: { variant: PromptVariant; canHide: boolean }) {
const utils = api.useContext();
@@ -38,14 +25,6 @@ export default function VariantHeader(props: { variant: PromptVariant; canHide:
}
}, [updateMutation, props.variant.id, props.variant.label, label]);
const hideMutation = api.promptVariants.hide.useMutation();
const [onHide] = useHandledAsyncCallback(async () => {
await hideMutation.mutateAsync({
id: props.variant.id,
});
await utils.promptVariants.list.invalidate();
}, [hideMutation, props.variant.id]);
const reorderMutation = api.promptVariants.reorder.useMutation();
const [onReorder] = useHandledAsyncCallback(
async (e: DragEvent<HTMLDivElement>) => {
@@ -64,21 +43,13 @@ export default function VariantHeader(props: { variant: PromptVariant; canHide:
);
const [menuOpen, setMenuOpen] = useState(false);
const duplicateMutation = api.promptVariants.create.useMutation();
const [duplicateVariant, duplicationInProgress] = useHandledAsyncCallback(async () => {
await duplicateMutation.mutateAsync({
experimentId: props.variant.experimentId,
variantId: props.variant.id,
});
await utils.promptVariants.list.invalidate();
}, [duplicateMutation, props.variant.experimentId, props.variant.id]);
return (
<GridItem
padding={0}
sx={{
...stickyHeaderStyle,
// Ensure that the menu always appears above the sticky header of other variants
zIndex: menuOpen ? "dropdown" : stickyHeaderStyle.zIndex,
}}
borderTopWidth={1}
@@ -129,42 +100,12 @@ export default function VariantHeader(props: { variant: PromptVariant; canHide:
onMouseEnter={() => setIsInputHovered(true)}
onMouseLeave={() => setIsInputHovered(false)}
/>
<Menu
z-index="dropdown"
onOpen={() => setMenuOpen(true)}
onClose={() => setMenuOpen(false)}
>
{duplicationInProgress ? (
<Spinner boxSize={4} mx={3} my={3} />
) : (
<MenuButton>
<Button variant="ghost">
<Icon as={BsGear} />
</Button>
</MenuButton>
)}
<MenuList mt={-3} fontSize="md">
<MenuItem icon={<Icon as={FaRegClone} boxSize={4} w={5} />} onClick={duplicateVariant}>
Duplicate
</MenuItem>
<MenuItem icon={<Icon as={RiExchangeFundsFill} boxSize={5} />}>Change Model</MenuItem>
{props.canHide && (
<>
<MenuDivider />
<MenuItem
onClick={onHide}
icon={<Icon as={BsFillTrashFill} boxSize={5} />}
color="red.600"
_hover={{ backgroundColor: "red.50" }}
>
<Text>Hide</Text>
</MenuItem>
</>
)}
</MenuList>
</Menu>
<VariantHeaderMenuButton
variant={props.variant}
canHide={props.canHide}
menuOpen={menuOpen}
setMenuOpen={setMenuOpen}
/>
</HStack>
</GridItem>
);


@@ -0,0 +1,102 @@
import { type PromptVariant } from "../OutputsTable/types";
import { api } from "~/utils/api";
import { useHandledAsyncCallback } from "~/utils/hooks";
import {
Button,
Icon,
Menu,
MenuButton,
MenuItem,
MenuList,
MenuDivider,
Text,
Spinner,
} from "@chakra-ui/react"; // Changed here
import { BsFillTrashFill, BsGear } from "react-icons/bs";
import { FaRegClone } from "react-icons/fa";
import { RiExchangeFundsFill } from "react-icons/ri";
import { useState } from "react";
import { SelectModelModal } from "../SelectModelModal/SelectModelModal";
import { type SupportedModel } from "~/server/types";
export default function VariantHeaderMenuButton({
variant,
canHide,
menuOpen,
setMenuOpen,
}: {
variant: PromptVariant;
canHide: boolean;
menuOpen: boolean;
setMenuOpen: (open: boolean) => void;
}) {
const utils = api.useContext();
const duplicateMutation = api.promptVariants.create.useMutation();
const [duplicateVariant, duplicationInProgress] = useHandledAsyncCallback(async () => {
await duplicateMutation.mutateAsync({
experimentId: variant.experimentId,
variantId: variant.id,
});
await utils.promptVariants.list.invalidate();
}, [duplicateMutation, variant.experimentId, variant.id]);
const hideMutation = api.promptVariants.hide.useMutation();
const [onHide] = useHandledAsyncCallback(async () => {
await hideMutation.mutateAsync({
id: variant.id,
});
await utils.promptVariants.list.invalidate();
}, [hideMutation, variant.id]);
const [selectModelModalOpen, setSelectModelModalOpen] = useState(false);
return (
<>
<Menu isOpen={menuOpen} onOpen={() => setMenuOpen(true)} onClose={() => setMenuOpen(false)}>
{duplicationInProgress ? (
<Spinner boxSize={4} mx={3} my={3} />
) : (
<MenuButton>
<Button variant="ghost">
<Icon as={BsGear} />
</Button>
</MenuButton>
)}
<MenuList mt={-3} fontSize="md">
<MenuItem icon={<Icon as={FaRegClone} boxSize={4} w={5} />} onClick={duplicateVariant}>
Duplicate
</MenuItem>
<MenuItem
icon={<Icon as={RiExchangeFundsFill} boxSize={5} />}
onClick={() => setSelectModelModalOpen(true)}
>
Change Model
</MenuItem>
{canHide && (
<>
<MenuDivider />
<MenuItem
onClick={onHide}
icon={<Icon as={BsFillTrashFill} boxSize={5} />}
color="red.600"
_hover={{ backgroundColor: "red.50" }}
>
<Text>Hide</Text>
</MenuItem>
</>
)}
</MenuList>
</Menu>
{selectModelModalOpen && (
<SelectModelModal
originalModel={variant.model as SupportedModel}
variantId={variant.id}
onClose={() => setSelectModelModalOpen(false)}
/>
)}
</>
);
}


@@ -10,6 +10,11 @@ export const env = createEnv({
DATABASE_URL: z.string().url(),
NODE_ENV: z.enum(["development", "test", "production"]).default("development"),
OPENAI_API_KEY: z.string().min(1),
RESTRICT_PRISMA_LOGS: z
.string()
.optional()
.default("false")
.transform((val) => val.toLowerCase() === "true"),
},
/**
@@ -35,6 +40,7 @@ export const env = createEnv({
DATABASE_URL: process.env.DATABASE_URL,
NODE_ENV: process.env.NODE_ENV,
OPENAI_API_KEY: process.env.OPENAI_API_KEY,
RESTRICT_PRISMA_LOGS: process.env.RESTRICT_PRISMA_LOGS,
NEXT_PUBLIC_POSTHOG_KEY: process.env.NEXT_PUBLIC_POSTHOG_KEY,
NEXT_PUBLIC_IS_PUBLIC_PLAYGROUND: process.env.NEXT_PUBLIC_IS_PUBLIC_PLAYGROUND,
NEXT_PUBLIC_SOCKET_URL: process.env.NEXT_PUBLIC_SOCKET_URL,
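For reference, a standalone sketch of how the zod chain above resolves common inputs (assumption: this mirrors the RESTRICT_PRISMA_LOGS schema exactly):

import { z } from "zod";

// Only the literal string "true" (any casing) enables restriction; an
// unset variable falls back to the "false" default before the transform.
const restrictPrismaLogs = z
  .string()
  .optional()
  .default("false")
  .transform((val) => val.toLowerCase() === "true");

restrictPrismaLogs.parse("TRUE"); // true
restrictPrismaLogs.parse("1"); // false -- only "true" counts
restrictPrismaLogs.parse(undefined); // false (default applies)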


@@ -1,16 +1,16 @@
import dedent from "dedent";
import { isObject } from "lodash-es";
import { z } from "zod";
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import { generateNewCell } from "~/server/utils/generateNewCell";
import { OpenAIChatModel } from "~/server/types";
import { OpenAIChatModel, type SupportedModel } from "~/server/types";
import { constructPrompt } from "~/server/utils/constructPrompt";
import userError from "~/server/utils/error";
import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
import { calculateTokenCost } from "~/utils/calculateTokenCost";
import { reorderPromptVariants } from "~/server/utils/reorderPromptVariants";
import { type PromptVariant } from "@prisma/client";
import { deriveNewConstructFn } from "~/server/utils/deriveNewContructFn";
export const promptVariantsRouter = createTRPCRouter({
list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
@@ -138,6 +138,7 @@ export const promptVariantsRouter = createTRPCRouter({
z.object({
experimentId: z.string(),
variantId: z.string().optional(),
newModel: z.string().optional(),
}),
)
.mutation(async ({ input }) => {
@@ -177,23 +178,17 @@ export const promptVariantsRouter = createTRPCRouter({
? `${originalVariant?.label} Copy`
: `Prompt Variant ${largestSortIndex + 2}`;
const newConstructFn = await deriveNewConstructFn(
originalVariant,
input.newModel as SupportedModel,
);
const createNewVariantAction = prisma.promptVariant.create({
data: {
experimentId: input.experimentId,
label: newVariantLabel,
sortIndex: (originalVariant?.sortIndex ?? 0) + 1,
constructFn:
originalVariant?.constructFn ??
dedent`
prompt = {
model: "gpt-3.5-turbo",
messages: [
{
role: "system",
content: "Return 'Hello, world!'",
}
]
}`,
constructFn: newConstructFn,
model: input.newModel ?? originalVariant?.model ?? "gpt-3.5-turbo",
},
});


@@ -8,7 +8,10 @@ const globalForPrisma = globalThis as unknown as {
export const prisma =
globalForPrisma.prisma ??
new PrismaClient({
log: env.NODE_ENV === "development" ? ["query", "error", "warn"] : ["error"],
log:
env.NODE_ENV === "development" && !env.RESTRICT_PRISMA_LOGS
? ["query", "error", "warn"]
: ["error"],
});
if (env.NODE_ENV !== "production") globalForPrisma.prisma = prisma;
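The effect of the new flag on log levels, sketched as a pure function (a hypothetical refactor for clarity, not part of the commit):

type PrismaLogLevel = "query" | "error" | "warn";

// Makes the three outcomes of the ternary above explicit.
function prismaLogLevels(nodeEnv: string, restrictLogs: boolean): PrismaLogLevel[] {
  return nodeEnv === "development" && !restrictLogs
    ? ["query", "error", "warn"]
    : ["error"];
}

prismaLogLevels("development", false); // ["query", "error", "warn"]
prismaLogLevels("development", true); // ["error"] -- the flag silences query logs
prismaLogLevels("production", false); // ["error"]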

src/server/modelStats.ts (new file)

@@ -0,0 +1,77 @@
import { type SupportedModel } from "./types";
interface ModelStats {
contextLength: number;
promptTokenPrice: number;
completionTokenPrice: number;
speed: "fast" | "medium" | "slow";
provider: "OpenAI";
learnMoreUrl: string;
}
export const modelStats: Record<SupportedModel, ModelStats> = {
"gpt-4": {
contextLength: 8192,
promptTokenPrice: 0.00003,
completionTokenPrice: 0.00006,
speed: "medium",
provider: "OpenAI",
learnMoreUrl: "https://openai.com/gpt-4",
},
"gpt-4-0613": {
contextLength: 8192,
promptTokenPrice: 0.00003,
completionTokenPrice: 0.00006,
speed: "medium",
provider: "OpenAI",
learnMoreUrl: "https://openai.com/gpt-4",
},
"gpt-4-32k": {
contextLength: 32768,
promptTokenPrice: 0.00006,
completionTokenPrice: 0.00012,
speed: "medium",
provider: "OpenAI",
learnMoreUrl: "https://openai.com/gpt-4",
},
"gpt-4-32k-0613": {
contextLength: 32768,
promptTokenPrice: 0.00006,
completionTokenPrice: 0.00012,
speed: "medium",
provider: "OpenAI",
learnMoreUrl: "https://openai.com/gpt-4",
},
"gpt-3.5-turbo": {
contextLength: 4096,
promptTokenPrice: 0.0000015,
completionTokenPrice: 0.000002,
speed: "fast",
provider: "OpenAI",
learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
},
"gpt-3.5-turbo-0613": {
contextLength: 4096,
promptTokenPrice: 0.0000015,
completionTokenPrice: 0.000002,
speed: "fast",
provider: "OpenAI",
learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
},
"gpt-3.5-turbo-16k": {
contextLength: 16384,
promptTokenPrice: 0.000003,
completionTokenPrice: 0.000004,
speed: "fast",
provider: "OpenAI",
learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
},
"gpt-3.5-turbo-16k-0613": {
contextLength: 16384,
promptTokenPrice: 0.000003,
completionTokenPrice: 0.000004,
speed: "fast",
provider: "OpenAI",
learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
},
};
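Prices are stored per token; ModelStatsCard above multiplies by 1000 to display per-1K figures. A hedged usage sketch:

import { modelStats } from "~/server/modelStats";

// gpt-4: $0.00003 per prompt token and $0.00006 per completion token.
const gpt4 = modelStats["gpt-4"];
console.log(`$${(gpt4.promptTokenPrice * 1000).toFixed(3)} / 1K prompt tokens`); // $0.030
console.log(`$${(gpt4.completionTokenPrice * 1000).toFixed(3)} / 1K completion tokens`); // $0.060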


@@ -67,6 +67,14 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
include: { modelOutput: true },
});
if (!cell) {
await prisma.scenarioVariantCell.update({
where: { id: scenarioVariantCellId },
data: {
statusCode: 404,
errorMessage: "Cell not found",
retrievalStatus: "ERROR",
},
});
return;
}
@@ -85,6 +93,14 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
where: { id: cell.promptVariantId },
});
if (!variant) {
await prisma.scenarioVariantCell.update({
where: { id: scenarioVariantCellId },
data: {
statusCode: 404,
errorMessage: "Prompt Variant not found",
retrievalStatus: "ERROR",
},
});
return;
}
@@ -92,6 +108,14 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
where: { id: cell.testScenarioId },
});
if (!scenario) {
await prisma.scenarioVariantCell.update({
where: { id: scenarioVariantCellId },
data: {
statusCode: 404,
errorMessage: "Scenario not found",
retrievalStatus: "ERROR",
},
});
return;
}
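The same 404 bookkeeping now appears in all three missing-record branches; a hypothetical helper (not in the commit) showing the shared shape:

import { prisma } from "~/server/db";

// Each branch would become e.g.
// await markCellMissing(scenarioVariantCellId, "Scenario not found"); return;
async function markCellMissing(scenarioVariantCellId: string, errorMessage: string) {
  await prisma.scenarioVariantCell.update({
    where: { id: scenarioVariantCellId },
    data: { statusCode: 404, errorMessage, retrievalStatus: "ERROR" },
  });
}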


@@ -0,0 +1,107 @@
import { type PromptVariant } from "@prisma/client";
import { type SupportedModel } from "../types";
import ivm from "isolated-vm";
import dedent from "dedent";
import { openai } from "./openai";
import { getApiShapeForModel } from "./getTypesForModel";
import { isObject } from "lodash-es";
const isolate = new ivm.Isolate({ memoryLimit: 128 });
export async function deriveNewConstructFn(
originalVariant: PromptVariant | null,
newModel?: SupportedModel,
) {
if (originalVariant && !newModel) {
return originalVariant.constructFn;
}
if (originalVariant && newModel) {
return await getPromptFunctionForNewModel(originalVariant, newModel);
}
return dedent`
prompt = {
model: "gpt-3.5-turbo",
messages: [
{
role: "system",
content: "Return 'Hello, world!'",
}
]
}`;
}
const NUM_RETRIES = 5;
const getPromptFunctionForNewModel = async (
originalVariant: PromptVariant,
newModel: SupportedModel,
) => {
const originalModel = originalVariant.model as SupportedModel;
let newContructionFn = "";
for (let i = 0; i < NUM_RETRIES; i++) {
try {
// TODO: Add api shape info to prompt
const completion = await openai.chat.completions.create({
model: "gpt-4",
messages: [
{
role: "system",
content: `Your job is to translate prompt constructor functions from ${originalModel} to ${newModel}. Here is the API shape for the original model:\n---\n${JSON.stringify(
getApiShapeForModel(originalModel),
null,
2,
)}\n\nThe prompt variable has already been declared.`,
},
{
role: "user",
content: `Return the prompt constructor function for ${newModel} given the following prompt constructor function for ${originalModel}:\n---\n${originalVariant.constructFn}`,
},
],
functions: [
{
name: "translate_prompt_constructor_function",
parameters: {
type: "object",
properties: {
new_prompt_function: {
type: "string",
description: "The new prompt function, runnable in typescript",
},
},
},
},
],
function_call: {
name: "translate_prompt_constructor_function",
},
});
const argString = completion.choices[0]?.message?.function_call?.arguments || "{}";
const code = `
global.contructPromptFunctionArgs = ${argString};
`;
const context = await isolate.createContext();
const jail = context.global;
await jail.set("global", jail.derefInto());
const script = await isolate.compileScript(code);
await script.run(context);
const contructPromptFunctionArgs = (await context.global.get(
"contructPromptFunctionArgs",
)) as ivm.Reference;
const args = await contructPromptFunctionArgs.copy(); // Get the actual value from the isolate
if (args && isObject(args) && "new_prompt_function" in args) {
newContructionFn = args.new_prompt_function as string;
break;
}
} catch (e) {
console.error(e);
}
}
return newContructionFn;
};
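A hedged usage sketch of the helper's three paths (names taken from the diff above; the empty-string return after five failed GPT-4 attempts is a caveat callers must handle):

import { deriveNewConstructFn } from "~/server/utils/deriveNewContructFn";

// Assumption: originalVariant is a PromptVariant row loaded via Prisma.
const kept = await deriveNewConstructFn(originalVariant); // no new model: constructFn returned unchanged
const translated = await deriveNewConstructFn(originalVariant, "gpt-4-32k"); // GPT-4 translation; "" if all retries fail
const fallback = await deriveNewConstructFn(null); // no original: default "Hello, world!" prompt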

View File

@@ -21,6 +21,13 @@ export type CompletionResponse = {
export async function getCompletion(
payload: CompletionCreateParams,
channel?: string,
): Promise<CompletionResponse> {
return getOpenAIChatCompletion(payload, channel);
}
export async function getOpenAIChatCompletion(
payload: CompletionCreateParams,
channel?: string,
): Promise<CompletionResponse> {
// If functions are enabled, disable streaming so that we get the full response with token counts
if (payload.functions?.length) payload.stream = false;


@@ -0,0 +1,7 @@
import { OpenAIChatModel, type SupportedModel } from "../types";
import openAIChatApiShape from "~/codegen/openai.types.ts.txt";
export const getApiShapeForModel = (model: SupportedModel) => {
if (model in OpenAIChatModel) return openAIChatApiShape;
return "";
};
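A brief note on the lookup: assuming OpenAIChatModel is a string enum keyed by model name, the `in` check matches every supported chat model and anything else gets an empty shape:

getApiShapeForModel("gpt-4"); // the OpenAI chat API type text from openai.types.ts.txt
getApiShapeForModel("some-future-model" as SupportedModel); // "" -- no shape available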


@@ -1,27 +1,6 @@
import { modelStats } from "~/server/modelStats";
import { type SupportedModel, OpenAIChatModel } from "~/server/types";
const openAIPromptTokensToDollars: { [key in OpenAIChatModel]: number } = {
"gpt-4": 0.00003,
"gpt-4-0613": 0.00003,
"gpt-4-32k": 0.00006,
"gpt-4-32k-0613": 0.00006,
"gpt-3.5-turbo": 0.0000015,
"gpt-3.5-turbo-0613": 0.0000015,
"gpt-3.5-turbo-16k": 0.000003,
"gpt-3.5-turbo-16k-0613": 0.000003,
};
const openAICompletionTokensToDollars: { [key in OpenAIChatModel]: number } = {
"gpt-4": 0.00006,
"gpt-4-0613": 0.00006,
"gpt-4-32k": 0.00012,
"gpt-4-32k-0613": 0.00012,
"gpt-3.5-turbo": 0.000002,
"gpt-3.5-turbo-0613": 0.000002,
"gpt-3.5-turbo-16k": 0.000004,
"gpt-3.5-turbo-16k-0613": 0.000004,
};
export const calculateTokenCost = (
model: SupportedModel | string | null,
numTokens: number,
@@ -40,7 +19,7 @@ const calculateOpenAIChatTokenCost = (
isCompletion: boolean,
) => {
const tokensToDollars = isCompletion
? openAICompletionTokensToDollars[model]
: openAIPromptTokensToDollars[model];
? modelStats[model].completionTokenPrice
: modelStats[model].promptTokenPrice;
return tokensToDollars * numTokens;
};
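A worked example of the consolidated pricing (assumption: calculateTokenCost's third argument is the isCompletion flag forwarded to calculateOpenAIChatTokenCost above):

import { calculateTokenCost } from "~/utils/calculateTokenCost";

calculateTokenCost("gpt-4", 500, true); // 500 * 0.00006 = $0.03 (completion tokens)
calculateTokenCost("gpt-4", 500, false); // 500 * 0.00003 = $0.015 (prompt tokens)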