Compare commits
7 Commits
space-out-
...
change-mod
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
01343efb6a | ||
|
|
c7aaaea426 | ||
|
|
332e7afb0c | ||
|
|
fe08e29f47 | ||
|
|
89ce730e52 | ||
|
|
ad87c1b2eb | ||
|
|
58ddc72cbb |
@@ -17,9 +17,6 @@ DATABASE_URL="postgresql://postgres:postgres@localhost:5432/openpipe?schema=publ
|
||||
# https://help.openai.com/en/articles/4936850-where-do-i-find-my-secret-api-key
|
||||
OPENAI_API_KEY=""
|
||||
|
||||
# Replicate API token. Create a token here: https://replicate.com/account/api-tokens
|
||||
REPLICATE_API_TOKEN=""
|
||||
|
||||
NEXT_PUBLIC_SOCKET_URL="http://localhost:3318"
|
||||
|
||||
# Next Auth
|
||||
|
||||
@@ -43,8 +43,7 @@ Natively supports [OpenAI function calls](https://openai.com/blog/function-calli
|
||||
|
||||
## Supported Models
|
||||
|
||||
- All models available through the OpenAI [chat completion API](https://platform.openai.com/docs/guides/gpt/chat-completions-api)
|
||||
- Llama2 [7b chat](https://replicate.com/a16z-infra/llama7b-v2-chat), [13b chat](https://replicate.com/a16z-infra/llama13b-v2-chat), [70b chat](https://replicate.com/replicate/llama70b-v2-chat).
|
||||
OpenPipe currently supports GPT-3.5 and GPT-4. Wider model support is planned.
|
||||
|
||||
## Running Locally
|
||||
|
||||
|
||||
@@ -73,7 +73,6 @@
|
||||
"react-syntax-highlighter": "^15.5.0",
|
||||
"react-textarea-autosize": "^8.5.0",
|
||||
"recast": "^0.23.3",
|
||||
"replicate": "^0.12.3",
|
||||
"socket.io": "^4.7.1",
|
||||
"socket.io-client": "^4.7.1",
|
||||
"superjson": "1.12.2",
|
||||
|
||||
8
pnpm-lock.yaml
generated
8
pnpm-lock.yaml
generated
@@ -161,9 +161,6 @@ dependencies:
|
||||
recast:
|
||||
specifier: ^0.23.3
|
||||
version: 0.23.3
|
||||
replicate:
|
||||
specifier: ^0.12.3
|
||||
version: 0.12.3
|
||||
socket.io:
|
||||
specifier: ^4.7.1
|
||||
version: 4.7.1
|
||||
@@ -6991,11 +6988,6 @@ packages:
|
||||
functions-have-names: 1.2.3
|
||||
dev: true
|
||||
|
||||
/replicate@0.12.3:
|
||||
resolution: {integrity: sha512-HVWKPoVhWVTONlWk+lUXmq9Vy2J8MxBJMtDBQq3dA5uq71ZzKTh0xvJfvzW4+VLBjhBeL7tkdua6hZJmKfzAPQ==}
|
||||
engines: {git: '>=2.11.0', node: '>=16.6.0', npm: '>=7.19.0', yarn: '>=1.7.0'}
|
||||
dev: false
|
||||
|
||||
/require-directory@2.1.1:
|
||||
resolution: {integrity: sha512-fGxEI7+wsG9xrvdjsrlmL22OMTTiHRwAMroiEeMgq8gzoLC/PQr7RsRDSTLUg/bZAZtF+TVIkHc6/4RIKrui+Q==}
|
||||
engines: {node: '>=0.10.0'}
|
||||
|
||||
@@ -1,50 +0,0 @@
|
||||
import { VStack, Text } from "@chakra-ui/react";
|
||||
import { type LegacyRef, useCallback } from "react";
|
||||
import Select, { type SingleValue } from "react-select";
|
||||
import { useElementDimensions } from "~/utils/hooks";
|
||||
|
||||
import frontendModelProviders from "~/modelProviders/frontendModelProviders";
|
||||
import { type Model } from "~/modelProviders/types";
|
||||
import { keyForModel } from "~/utils/utils";
|
||||
|
||||
const modelOptions: { label: string; value: Model }[] = [];
|
||||
|
||||
for (const [_, providerValue] of Object.entries(frontendModelProviders)) {
|
||||
for (const [_, modelValue] of Object.entries(providerValue.models)) {
|
||||
modelOptions.push({
|
||||
label: keyForModel(modelValue),
|
||||
value: modelValue,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
export const ModelSearch = ({
|
||||
selectedModel,
|
||||
setSelectedModel,
|
||||
}: {
|
||||
selectedModel: Model;
|
||||
setSelectedModel: (model: Model) => void;
|
||||
}) => {
|
||||
const handleSelection = useCallback(
|
||||
(option: SingleValue<{ label: string; value: Model }>) => {
|
||||
if (!option) return;
|
||||
setSelectedModel(option.value);
|
||||
},
|
||||
[setSelectedModel],
|
||||
);
|
||||
const selectedOption = modelOptions.find((option) => option.label === keyForModel(selectedModel));
|
||||
|
||||
const [containerRef, containerDimensions] = useElementDimensions();
|
||||
|
||||
return (
|
||||
<VStack ref={containerRef as LegacyRef<HTMLDivElement>} w="full">
|
||||
<Text>Browse Models</Text>
|
||||
<Select
|
||||
styles={{ control: (provided) => ({ ...provided, width: containerDimensions?.width }) }}
|
||||
value={selectedOption}
|
||||
options={modelOptions}
|
||||
onChange={handleSelection}
|
||||
/>
|
||||
</VStack>
|
||||
);
|
||||
};
|
||||
@@ -1,48 +0,0 @@
|
||||
import { Box, Flex, Icon, Spinner } from "@chakra-ui/react";
|
||||
import { BsPlus } from "react-icons/bs";
|
||||
import { api } from "~/utils/api";
|
||||
import { useExperiment, useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
|
||||
import { cellPadding } from "../constants";
|
||||
import { ActionButton } from "./ScenariosHeader";
|
||||
|
||||
export default function AddVariantButton() {
|
||||
const experiment = useExperiment();
|
||||
const mutation = api.promptVariants.create.useMutation();
|
||||
const utils = api.useContext();
|
||||
|
||||
const [onClick, loading] = useHandledAsyncCallback(async () => {
|
||||
if (!experiment.data) return;
|
||||
await mutation.mutateAsync({
|
||||
experimentId: experiment.data.id,
|
||||
});
|
||||
await utils.promptVariants.list.invalidate();
|
||||
}, [mutation]);
|
||||
|
||||
const { canModify } = useExperimentAccess();
|
||||
if (!canModify) return <Box w={cellPadding.x} />;
|
||||
|
||||
return (
|
||||
<Flex w="100%" justifyContent="flex-end">
|
||||
<ActionButton
|
||||
onClick={onClick}
|
||||
leftIcon={<Icon as={loading ? Spinner : BsPlus} boxSize={6} mr={loading ? 1 : 0} />}
|
||||
>
|
||||
Add Variant
|
||||
</ActionButton>
|
||||
{/* <Button
|
||||
alignItems="center"
|
||||
justifyContent="center"
|
||||
fontWeight="normal"
|
||||
bgColor="transparent"
|
||||
_hover={{ bgColor: "gray.100" }}
|
||||
px={cellPadding.x}
|
||||
onClick={onClick}
|
||||
height="unset"
|
||||
minH={headerMinHeight}
|
||||
>
|
||||
<Icon as={loading ? Spinner : BsPlus} boxSize={6} mr={loading ? 1 : 0} />
|
||||
<Text display={{ base: "none", md: "flex" }}>Add Variant</Text>
|
||||
</Button> */}
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
61
src/components/OutputsTable/NewScenarioButton.tsx
Normal file
61
src/components/OutputsTable/NewScenarioButton.tsx
Normal file
@@ -0,0 +1,61 @@
|
||||
import { Button, type ButtonProps, HStack, Spinner, Icon } from "@chakra-ui/react";
|
||||
import { BsPlus } from "react-icons/bs";
|
||||
import { api } from "~/utils/api";
|
||||
import { useExperiment, useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
|
||||
|
||||
// Extracted Button styling into reusable component
|
||||
const StyledButton = ({ children, onClick }: ButtonProps) => (
|
||||
<Button
|
||||
fontWeight="normal"
|
||||
bgColor="transparent"
|
||||
_hover={{ bgColor: "gray.100" }}
|
||||
px={2}
|
||||
onClick={onClick}
|
||||
>
|
||||
{children}
|
||||
</Button>
|
||||
);
|
||||
|
||||
export default function NewScenarioButton() {
|
||||
const { canModify } = useExperimentAccess();
|
||||
|
||||
const experiment = useExperiment();
|
||||
const mutation = api.scenarios.create.useMutation();
|
||||
const utils = api.useContext();
|
||||
|
||||
const [onClick] = useHandledAsyncCallback(async () => {
|
||||
if (!experiment.data) return;
|
||||
await mutation.mutateAsync({
|
||||
experimentId: experiment.data.id,
|
||||
});
|
||||
await utils.scenarios.list.invalidate();
|
||||
}, [mutation]);
|
||||
|
||||
const [onAutogenerate, autogenerating] = useHandledAsyncCallback(async () => {
|
||||
if (!experiment.data) return;
|
||||
await mutation.mutateAsync({
|
||||
experimentId: experiment.data.id,
|
||||
autogenerate: true,
|
||||
});
|
||||
await utils.scenarios.list.invalidate();
|
||||
}, [mutation]);
|
||||
|
||||
if (!canModify) return null;
|
||||
|
||||
return (
|
||||
<HStack spacing={2}>
|
||||
<StyledButton onClick={onClick}>
|
||||
<Icon as={BsPlus} boxSize={6} />
|
||||
Add Scenario
|
||||
</StyledButton>
|
||||
<StyledButton onClick={onAutogenerate}>
|
||||
<Icon
|
||||
as={autogenerating ? Spinner : BsPlus}
|
||||
boxSize={autogenerating ? 4 : 6}
|
||||
mr={autogenerating ? 2 : 0}
|
||||
/>
|
||||
Autogenerate Scenario
|
||||
</StyledButton>
|
||||
</HStack>
|
||||
);
|
||||
}
|
||||
40
src/components/OutputsTable/NewVariantButton.tsx
Normal file
40
src/components/OutputsTable/NewVariantButton.tsx
Normal file
@@ -0,0 +1,40 @@
|
||||
import { Box, Button, Icon, Spinner, Text } from "@chakra-ui/react";
|
||||
import { BsPlus } from "react-icons/bs";
|
||||
import { api } from "~/utils/api";
|
||||
import { useExperiment, useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
|
||||
import { cellPadding, headerMinHeight } from "../constants";
|
||||
|
||||
export default function NewVariantButton() {
|
||||
const experiment = useExperiment();
|
||||
const mutation = api.promptVariants.create.useMutation();
|
||||
const utils = api.useContext();
|
||||
|
||||
const [onClick, loading] = useHandledAsyncCallback(async () => {
|
||||
if (!experiment.data) return;
|
||||
await mutation.mutateAsync({
|
||||
experimentId: experiment.data.id,
|
||||
});
|
||||
await utils.promptVariants.list.invalidate();
|
||||
}, [mutation]);
|
||||
|
||||
const { canModify } = useExperimentAccess();
|
||||
if (!canModify) return <Box w={cellPadding.x} />;
|
||||
|
||||
return (
|
||||
<Button
|
||||
w="100%"
|
||||
alignItems="center"
|
||||
justifyContent="center"
|
||||
fontWeight="normal"
|
||||
bgColor="transparent"
|
||||
_hover={{ bgColor: "gray.100" }}
|
||||
px={cellPadding.x}
|
||||
onClick={onClick}
|
||||
height="unset"
|
||||
minH={headerMinHeight}
|
||||
>
|
||||
<Icon as={loading ? Spinner : BsPlus} boxSize={6} mr={loading ? 1 : 0} />
|
||||
<Text display={{ base: "none", md: "flex" }}>Add Variant</Text>
|
||||
</Button>
|
||||
);
|
||||
}
|
||||
@@ -10,7 +10,7 @@ import useSocket from "~/utils/useSocket";
|
||||
import { OutputStats } from "./OutputStats";
|
||||
import { ErrorHandler } from "./ErrorHandler";
|
||||
import { CellOptions } from "./CellOptions";
|
||||
import frontendModelProviders from "~/modelProviders/frontendModelProviders";
|
||||
import modelProvidersFrontend from "~/modelProviders/modelProvidersFrontend";
|
||||
|
||||
export default function OutputCell({
|
||||
scenario,
|
||||
@@ -40,7 +40,7 @@ export default function OutputCell({
|
||||
);
|
||||
|
||||
const provider =
|
||||
frontendModelProviders[variant.modelProvider as keyof typeof frontendModelProviders];
|
||||
modelProvidersFrontend[variant.modelProvider as keyof typeof modelProvidersFrontend];
|
||||
|
||||
type OutputSchema = Parameters<typeof provider.normalizeOutput>[0];
|
||||
|
||||
@@ -88,7 +88,7 @@ export default function OutputCell({
|
||||
}
|
||||
|
||||
const normalizedOutput = modelOutput
|
||||
? provider.normalizeOutput(modelOutput.output)
|
||||
? provider.normalizeOutput(modelOutput.output as unknown as OutputSchema)
|
||||
: streamedMessage
|
||||
? provider.normalizeOutput(streamedMessage)
|
||||
: null;
|
||||
|
||||
@@ -4,13 +4,11 @@ import { cellPadding } from "../constants";
|
||||
import OutputCell from "./OutputCell/OutputCell";
|
||||
import ScenarioEditor from "./ScenarioEditor";
|
||||
import type { PromptVariant, Scenario } from "./types";
|
||||
import { borders } from "./styles";
|
||||
|
||||
const ScenarioRow = (props: {
|
||||
scenario: Scenario;
|
||||
variants: PromptVariant[];
|
||||
canHide: boolean;
|
||||
rowStart: number;
|
||||
}) => {
|
||||
const [isHovered, setIsHovered] = useState(false);
|
||||
|
||||
@@ -23,21 +21,15 @@ const ScenarioRow = (props: {
|
||||
onMouseLeave={() => setIsHovered(false)}
|
||||
sx={isHovered ? highlightStyle : undefined}
|
||||
borderLeftWidth={1}
|
||||
{...borders}
|
||||
rowStart={props.rowStart}
|
||||
colStart={1}
|
||||
>
|
||||
<ScenarioEditor scenario={props.scenario} hovered={isHovered} canHide={props.canHide} />
|
||||
</GridItem>
|
||||
{props.variants.map((variant, i) => (
|
||||
{props.variants.map((variant) => (
|
||||
<GridItem
|
||||
key={variant.id}
|
||||
onMouseEnter={() => setIsHovered(true)}
|
||||
onMouseLeave={() => setIsHovered(false)}
|
||||
sx={isHovered ? highlightStyle : undefined}
|
||||
rowStart={props.rowStart}
|
||||
colStart={i + 2}
|
||||
{...borders}
|
||||
>
|
||||
<Box h="100%" w="100%" px={cellPadding.x} py={cellPadding.y}>
|
||||
<OutputCell key={variant.id} scenario={props.scenario} variant={variant} />
|
||||
|
||||
@@ -1,73 +1,52 @@
|
||||
import {
|
||||
Button,
|
||||
type ButtonProps,
|
||||
HStack,
|
||||
Text,
|
||||
Icon,
|
||||
Menu,
|
||||
MenuButton,
|
||||
MenuList,
|
||||
MenuItem,
|
||||
IconButton,
|
||||
Spinner,
|
||||
} from "@chakra-ui/react";
|
||||
import { Button, GridItem, HStack, Heading } from "@chakra-ui/react";
|
||||
import { cellPadding } from "../constants";
|
||||
import { useExperiment, useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
|
||||
import { BsGear, BsPencil, BsPlus, BsStars } from "react-icons/bs";
|
||||
import { useElementDimensions, useExperimentAccess } from "~/utils/hooks";
|
||||
import { stickyHeaderStyle } from "./styles";
|
||||
import { BsPencil } from "react-icons/bs";
|
||||
import { useAppStore } from "~/state/store";
|
||||
import { api } from "~/utils/api";
|
||||
|
||||
export const ActionButton = (props: ButtonProps) => (
|
||||
<Button size="sm" variant="ghost" color="gray.600" {...props} />
|
||||
);
|
||||
|
||||
export const ScenariosHeader = (props: { numScenarios: number }) => {
|
||||
export const ScenariosHeader = ({
|
||||
headerRows,
|
||||
numScenarios,
|
||||
}: {
|
||||
headerRows: number;
|
||||
numScenarios: number;
|
||||
}) => {
|
||||
const openDrawer = useAppStore((s) => s.openDrawer);
|
||||
const { canModify } = useExperimentAccess();
|
||||
|
||||
const experiment = useExperiment();
|
||||
const createScenarioMutation = api.scenarios.create.useMutation();
|
||||
const utils = api.useContext();
|
||||
|
||||
const [onAddScenario, loading] = useHandledAsyncCallback(
|
||||
async (autogenerate: boolean) => {
|
||||
if (!experiment.data) return;
|
||||
await createScenarioMutation.mutateAsync({
|
||||
experimentId: experiment.data.id,
|
||||
autogenerate,
|
||||
});
|
||||
await utils.scenarios.list.invalidate();
|
||||
},
|
||||
[createScenarioMutation],
|
||||
);
|
||||
const [ref, dimensions] = useElementDimensions();
|
||||
const topValue = dimensions ? `-${dimensions.height - 24}px` : "-455px";
|
||||
|
||||
return (
|
||||
<HStack w="100%" pb={cellPadding.y} pt={0} align="center" spacing={0}>
|
||||
<Text fontSize={16} fontWeight="bold">
|
||||
Scenarios ({props.numScenarios})
|
||||
</Text>
|
||||
{canModify && (
|
||||
<Menu>
|
||||
<MenuButton mt={1}>
|
||||
<IconButton
|
||||
variant="ghost"
|
||||
aria-label="Edit Scenarios"
|
||||
icon={<Icon as={loading ? Spinner : BsGear} />}
|
||||
/>
|
||||
</MenuButton>
|
||||
<MenuList fontSize="md">
|
||||
<MenuItem icon={<Icon as={BsPlus} boxSize={6} />} onClick={() => onAddScenario(false)}>
|
||||
Add Scenario
|
||||
</MenuItem>
|
||||
<MenuItem icon={<BsStars />} onClick={() => onAddScenario(true)}>
|
||||
Autogenerate Scenario
|
||||
</MenuItem>
|
||||
<MenuItem icon={<BsPencil />} onClick={openDrawer}>
|
||||
Edit Vars
|
||||
</MenuItem>
|
||||
</MenuList>
|
||||
</Menu>
|
||||
)}
|
||||
</HStack>
|
||||
<GridItem
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
ref={ref as any}
|
||||
display="flex"
|
||||
alignItems="flex-end"
|
||||
rowSpan={headerRows}
|
||||
px={cellPadding.x}
|
||||
py={cellPadding.y}
|
||||
// Only display the part of the grid item that has content
|
||||
sx={{ ...stickyHeaderStyle, top: topValue }}
|
||||
>
|
||||
<HStack w="100%">
|
||||
<Heading size="xs" fontWeight="bold" flex={1}>
|
||||
Scenarios ({numScenarios})
|
||||
</Heading>
|
||||
{canModify && (
|
||||
<Button
|
||||
size="xs"
|
||||
variant="ghost"
|
||||
color="gray.500"
|
||||
aria-label="Edit"
|
||||
leftIcon={<BsPencil />}
|
||||
onClick={openDrawer}
|
||||
>
|
||||
Edit Vars
|
||||
</Button>
|
||||
)}
|
||||
</HStack>
|
||||
</GridItem>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
import { Grid, GridItem, type GridItemProps } from "@chakra-ui/react";
|
||||
import { Grid, GridItem } from "@chakra-ui/react";
|
||||
import { api } from "~/utils/api";
|
||||
import AddVariantButton from "./AddVariantButton";
|
||||
import NewScenarioButton from "./NewScenarioButton";
|
||||
import NewVariantButton from "./NewVariantButton";
|
||||
import ScenarioRow from "./ScenarioRow";
|
||||
import VariantEditor from "./VariantEditor";
|
||||
import VariantHeader from "../VariantHeader/VariantHeader";
|
||||
import VariantStats from "./VariantStats";
|
||||
import { ScenariosHeader } from "./ScenariosHeader";
|
||||
import { borders } from "./styles";
|
||||
import { stickyHeaderStyle } from "./styles";
|
||||
|
||||
export default function OutputsTable({ experimentId }: { experimentId: string | undefined }) {
|
||||
const variants = api.promptVariants.list.useQuery(
|
||||
@@ -21,76 +22,61 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
|
||||
|
||||
if (!variants.data || !scenarios.data) return null;
|
||||
|
||||
const allCols = variants.data.length + 2;
|
||||
const variantHeaderRows = 3;
|
||||
const scenarioHeaderRows = 1;
|
||||
const allRows = variantHeaderRows + scenarioHeaderRows + scenarios.data.length;
|
||||
const allCols = variants.data.length + 1;
|
||||
const headerRows = 3;
|
||||
|
||||
return (
|
||||
<Grid
|
||||
pt={4}
|
||||
p={4}
|
||||
pb={24}
|
||||
pl={4}
|
||||
display="grid"
|
||||
gridTemplateColumns={`250px repeat(${variants.data.length}, minmax(300px, 1fr)) auto`}
|
||||
sx={{
|
||||
"> *": {
|
||||
borderColor: "gray.300",
|
||||
borderBottomWidth: 1,
|
||||
borderRightWidth: 1,
|
||||
},
|
||||
}}
|
||||
fontSize="sm"
|
||||
>
|
||||
<GridItem rowSpan={variantHeaderRows}>
|
||||
<AddVariantButton />
|
||||
</GridItem>
|
||||
|
||||
{variants.data.map((variant, i) => {
|
||||
const sharedProps: GridItemProps = {
|
||||
...borders,
|
||||
colStart: i + 2,
|
||||
borderLeftWidth: i === 0 ? 1 : 0,
|
||||
};
|
||||
return (
|
||||
<>
|
||||
<VariantHeader
|
||||
key={variant.uiId}
|
||||
variant={variant}
|
||||
canHide={variants.data.length > 1}
|
||||
rowStart={1}
|
||||
{...sharedProps}
|
||||
/>
|
||||
<GridItem rowStart={2} {...sharedProps}>
|
||||
<VariantEditor variant={variant} />
|
||||
</GridItem>
|
||||
<GridItem rowStart={3} {...sharedProps}>
|
||||
<VariantStats variant={variant} />
|
||||
</GridItem>
|
||||
</>
|
||||
);
|
||||
})}
|
||||
<ScenariosHeader headerRows={headerRows} numScenarios={scenarios.data.length} />
|
||||
|
||||
{variants.data.map((variant) => (
|
||||
<VariantHeader key={variant.uiId} variant={variant} canHide={variants.data.length > 1} />
|
||||
))}
|
||||
<GridItem
|
||||
colSpan={allCols - 1}
|
||||
rowStart={variantHeaderRows + 1}
|
||||
colStart={1}
|
||||
{...borders}
|
||||
borderRightWidth={0}
|
||||
rowSpan={scenarios.data.length + headerRows}
|
||||
padding={0}
|
||||
// Have to use `style` instead of emotion style props to work around css specificity issues conflicting with the "> *" selector on Grid
|
||||
style={{ borderRightWidth: 0, borderBottomWidth: 0 }}
|
||||
h={8}
|
||||
sx={stickyHeaderStyle}
|
||||
>
|
||||
<ScenariosHeader numScenarios={scenarios.data.length} />
|
||||
<NewVariantButton />
|
||||
</GridItem>
|
||||
|
||||
{scenarios.data.map((scenario, i) => (
|
||||
{variants.data.map((variant) => (
|
||||
<GridItem key={variant.uiId}>
|
||||
<VariantEditor variant={variant} />
|
||||
</GridItem>
|
||||
))}
|
||||
{variants.data.map((variant) => (
|
||||
<GridItem key={variant.uiId}>
|
||||
<VariantStats variant={variant} />
|
||||
</GridItem>
|
||||
))}
|
||||
{scenarios.data.map((scenario) => (
|
||||
<ScenarioRow
|
||||
rowStart={i + variantHeaderRows + scenarioHeaderRows + 2}
|
||||
key={scenario.uiId}
|
||||
scenario={scenario}
|
||||
variants={variants.data}
|
||||
canHide={scenarios.data.length > 1}
|
||||
/>
|
||||
))}
|
||||
|
||||
{/* Add some extra padding on the right, because when the table is too wide to fit in the viewport `pr` on the Grid isn't respected. */}
|
||||
<GridItem rowStart={1} colStart={allCols} rowSpan={allRows} w={4} borderBottomWidth={0} />
|
||||
<GridItem borderBottomWidth={0} borderRightWidth={0} w="100%" colSpan={allCols} padding={0}>
|
||||
<NewScenarioButton />
|
||||
</GridItem>
|
||||
</Grid>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { type GridItemProps, type SystemStyleObject } from "@chakra-ui/react";
|
||||
import { type SystemStyleObject } from "@chakra-ui/react";
|
||||
|
||||
export const stickyHeaderStyle: SystemStyleObject = {
|
||||
position: "sticky",
|
||||
@@ -6,8 +6,3 @@ export const stickyHeaderStyle: SystemStyleObject = {
|
||||
backgroundColor: "#fff",
|
||||
zIndex: 10,
|
||||
};
|
||||
|
||||
export const borders: GridItemProps = {
|
||||
borderRightWidth: 1,
|
||||
borderBottomWidth: 1,
|
||||
};
|
||||
|
||||
@@ -1,22 +1,22 @@
|
||||
import { HStack, Icon, Heading, Text, VStack, GridItem } from "@chakra-ui/react";
|
||||
import { type IconType } from "react-icons";
|
||||
import { refineOptions, type RefineOptionLabel } from "./refineOptions";
|
||||
|
||||
export const RefineOption = ({
|
||||
label,
|
||||
icon,
|
||||
desciption,
|
||||
activeLabel,
|
||||
icon,
|
||||
onClick,
|
||||
loading,
|
||||
}: {
|
||||
label: string;
|
||||
label: RefineOptionLabel;
|
||||
activeLabel: RefineOptionLabel | undefined;
|
||||
icon: IconType;
|
||||
desciption: string;
|
||||
activeLabel: string | undefined;
|
||||
onClick: (label: string) => void;
|
||||
onClick: (label: RefineOptionLabel) => void;
|
||||
loading: boolean;
|
||||
}) => {
|
||||
const isActive = activeLabel === label;
|
||||
const desciption = refineOptions[label].description;
|
||||
|
||||
return (
|
||||
<GridItem w="80" h="44">
|
||||
|
||||
@@ -15,16 +15,17 @@ import {
|
||||
SimpleGrid,
|
||||
} from "@chakra-ui/react";
|
||||
import { BsStars } from "react-icons/bs";
|
||||
import { VscJson } from "react-icons/vsc";
|
||||
import { TfiThought } from "react-icons/tfi";
|
||||
import { api } from "~/utils/api";
|
||||
import { useHandledAsyncCallback } from "~/utils/hooks";
|
||||
import { type PromptVariant } from "@prisma/client";
|
||||
import { useState } from "react";
|
||||
import CompareFunctions from "./CompareFunctions";
|
||||
import { CustomInstructionsInput } from "./CustomInstructionsInput";
|
||||
import { type RefineOptionInfo, refineOptions } from "./refineOptions";
|
||||
import { type RefineOptionLabel, refineOptions } from "./refineOptions";
|
||||
import { RefineOption } from "./RefineOption";
|
||||
import { isObject, isString } from "lodash-es";
|
||||
import { type SupportedProvider } from "~/modelProviders/types";
|
||||
|
||||
export const RefinePromptModal = ({
|
||||
variant,
|
||||
@@ -35,22 +36,18 @@ export const RefinePromptModal = ({
|
||||
}) => {
|
||||
const utils = api.useContext();
|
||||
|
||||
const providerRefineOptions = refineOptions[variant.modelProvider as SupportedProvider];
|
||||
|
||||
const { mutateAsync: getModifiedPromptMutateAsync, data: refinedPromptFn } =
|
||||
api.promptVariants.getModifiedPromptFn.useMutation();
|
||||
const [instructions, setInstructions] = useState<string>("");
|
||||
|
||||
const [activeRefineOptionLabel, setActiveRefineOptionLabel] = useState<string | undefined>(
|
||||
undefined,
|
||||
);
|
||||
const [activeRefineOptionLabel, setActiveRefineOptionLabel] = useState<
|
||||
RefineOptionLabel | undefined
|
||||
>(undefined);
|
||||
|
||||
const [getModifiedPromptFn, modificationInProgress] = useHandledAsyncCallback(
|
||||
async (label?: string) => {
|
||||
async (label?: RefineOptionLabel) => {
|
||||
if (!variant.experimentId) return;
|
||||
const updatedInstructions = label
|
||||
? (providerRefineOptions[label] as RefineOptionInfo).instructions
|
||||
: instructions;
|
||||
const updatedInstructions = label ? refineOptions[label].instructions : instructions;
|
||||
setActiveRefineOptionLabel(label);
|
||||
await getModifiedPromptMutateAsync({
|
||||
id: variant.id,
|
||||
@@ -95,26 +92,25 @@ export const RefinePromptModal = ({
|
||||
<ModalBody maxW="unset">
|
||||
<VStack spacing={8}>
|
||||
<VStack spacing={4}>
|
||||
{Object.keys(providerRefineOptions).length && (
|
||||
<>
|
||||
<SimpleGrid columns={{ base: 1, md: 2 }} spacing={8}>
|
||||
{Object.keys(providerRefineOptions).map((label) => (
|
||||
<RefineOption
|
||||
key={label}
|
||||
label={label}
|
||||
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
||||
icon={providerRefineOptions[label]!.icon}
|
||||
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
||||
desciption={providerRefineOptions[label]!.description}
|
||||
activeLabel={activeRefineOptionLabel}
|
||||
onClick={getModifiedPromptFn}
|
||||
loading={modificationInProgress}
|
||||
/>
|
||||
))}
|
||||
</SimpleGrid>
|
||||
<Text color="gray.500">or</Text>
|
||||
</>
|
||||
)}
|
||||
<SimpleGrid columns={{ base: 1, md: 2 }} spacing={8}>
|
||||
<RefineOption
|
||||
label="Convert to function call"
|
||||
activeLabel={activeRefineOptionLabel}
|
||||
icon={VscJson}
|
||||
onClick={getModifiedPromptFn}
|
||||
loading={modificationInProgress}
|
||||
/>
|
||||
<RefineOption
|
||||
label="Add chain of thought"
|
||||
activeLabel={activeRefineOptionLabel}
|
||||
icon={TfiThought}
|
||||
onClick={getModifiedPromptFn}
|
||||
loading={modificationInProgress}
|
||||
/>
|
||||
</SimpleGrid>
|
||||
<HStack>
|
||||
<Text color="gray.500">or</Text>
|
||||
</HStack>
|
||||
<CustomInstructionsInput
|
||||
instructions={instructions}
|
||||
setInstructions={setInstructions}
|
||||
|
||||
@@ -1,21 +1,17 @@
|
||||
// Super hacky, but we'll redo the organization when we have more models
|
||||
|
||||
import { type SupportedProvider } from "~/modelProviders/types";
|
||||
import { VscJson } from "react-icons/vsc";
|
||||
import { TfiThought } from "react-icons/tfi";
|
||||
import { type IconType } from "react-icons";
|
||||
|
||||
export type RefineOptionInfo = { icon: IconType; description: string; instructions: string };
|
||||
|
||||
export const refineOptions: Record<SupportedProvider, { [key: string]: RefineOptionInfo }> = {
|
||||
"openai/ChatCompletion": {
|
||||
"Add chain of thought": {
|
||||
icon: VscJson,
|
||||
description: "Asking the model to plan its answer can increase accuracy.",
|
||||
instructions: `Adding chain of thought means asking the model to think about its answer before it gives it to you. This is useful for getting more accurate answers. Do not add an assistant message.
|
||||
export type RefineOptionLabel = "Add chain of thought" | "Convert to function call";
|
||||
|
||||
export const refineOptions: Record<
|
||||
RefineOptionLabel,
|
||||
{ description: string; instructions: string }
|
||||
> = {
|
||||
"Add chain of thought": {
|
||||
description: "Asking the model to plan its answer can increase accuracy.",
|
||||
instructions: `Adding chain of thought means asking the model to think about its answer before it gives it to you. This is useful for getting more accurate answers. Do not add an assistant message.
|
||||
|
||||
This is what a prompt looks like before adding chain of thought:
|
||||
|
||||
|
||||
definePrompt("openai/ChatCompletion", {
|
||||
model: "gpt-4",
|
||||
stream: true,
|
||||
@@ -59,9 +55,9 @@ export const refineOptions: Record<SupportedProvider, { [key: string]: RefineOpt
|
||||
role: "user",
|
||||
content: \`Title: \${scenario.title}
|
||||
Body: \${scenario.body}
|
||||
|
||||
|
||||
Need: \${scenario.need}
|
||||
|
||||
|
||||
Rate likelihood on 1-3 scale.\`,
|
||||
},
|
||||
],
|
||||
@@ -93,9 +89,9 @@ export const refineOptions: Record<SupportedProvider, { [key: string]: RefineOpt
|
||||
role: "user",
|
||||
content: \`Title: \${scenario.title}
|
||||
Body: \${scenario.body}
|
||||
|
||||
|
||||
Need: \${scenario.need}
|
||||
|
||||
|
||||
Rate likelihood on 1-3 scale. Provide an explanation, but always provide a score afterward.\`,
|
||||
},
|
||||
],
|
||||
@@ -122,14 +118,13 @@ export const refineOptions: Record<SupportedProvider, { [key: string]: RefineOpt
|
||||
});
|
||||
|
||||
Add chain of thought to the original prompt.`,
|
||||
},
|
||||
"Convert to function call": {
|
||||
icon: TfiThought,
|
||||
description: "Use function calls to get output from the model in a more structured way.",
|
||||
instructions: `OpenAI functions are a specialized way for an LLM to return output.
|
||||
|
||||
},
|
||||
"Convert to function call": {
|
||||
description: "Use function calls to get output from the model in a more structured way.",
|
||||
instructions: `OpenAI functions are a specialized way for an LLM to return output.
|
||||
|
||||
This is what a prompt looks like before adding a function:
|
||||
|
||||
|
||||
definePrompt("openai/ChatCompletion", {
|
||||
model: "gpt-4",
|
||||
stream: true,
|
||||
@@ -144,9 +139,9 @@ export const refineOptions: Record<SupportedProvider, { [key: string]: RefineOpt
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
|
||||
This is what one looks like after adding a function:
|
||||
|
||||
|
||||
definePrompt("openai/ChatCompletion", {
|
||||
model: "gpt-4",
|
||||
stream: true,
|
||||
@@ -192,11 +187,11 @@ export const refineOptions: Record<SupportedProvider, { [key: string]: RefineOpt
|
||||
|
||||
title: \${scenario.title}
|
||||
body: \${scenario.body}
|
||||
|
||||
|
||||
On a scale from 1 to 3, how likely is it that the person writing this post has the following need? If you are not sure, make your best guess, or answer 1.
|
||||
|
||||
|
||||
Need: \${scenario.need}
|
||||
|
||||
|
||||
Answer one integer between 1 and 3.\`,
|
||||
},
|
||||
],
|
||||
@@ -212,9 +207,9 @@ export const refineOptions: Record<SupportedProvider, { [key: string]: RefineOpt
|
||||
role: "user",
|
||||
content: \`Title: \${scenario.title}
|
||||
Body: \${scenario.body}
|
||||
|
||||
|
||||
Need: \${scenario.need}
|
||||
|
||||
|
||||
Rate likelihood on 1-3 scale.\`,
|
||||
},
|
||||
],
|
||||
@@ -236,52 +231,7 @@ export const refineOptions: Record<SupportedProvider, { [key: string]: RefineOpt
|
||||
name: "score_post",
|
||||
},
|
||||
});
|
||||
|
||||
Another example
|
||||
|
||||
Before:
|
||||
|
||||
definePrompt("openai/ChatCompletion", {
|
||||
model: "gpt-3.5-turbo",
|
||||
stream: true,
|
||||
messages: [
|
||||
{
|
||||
role: "system",
|
||||
content: \`Write 'Start experimenting!' in \${scenario.language}\`,
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
After:
|
||||
|
||||
definePrompt("openai/ChatCompletion", {
|
||||
model: "gpt-3.5-turbo",
|
||||
messages: [
|
||||
{
|
||||
role: "system",
|
||||
content: \`Write 'Start experimenting!' in \${scenario.language}\`,
|
||||
},
|
||||
],
|
||||
functions: [
|
||||
{
|
||||
name: "write_in_language",
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: {
|
||||
text: {
|
||||
type: "string",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
function_call: {
|
||||
name: "write_in_language",
|
||||
},
|
||||
});
|
||||
|
||||
|
||||
Add an OpenAI function that takes one or more nested parameters that match the expected output from this prompt.`,
|
||||
},
|
||||
},
|
||||
"replicate/llama2": {},
|
||||
};
|
||||
|
||||
@@ -7,9 +7,11 @@ import {
|
||||
SimpleGrid,
|
||||
Link,
|
||||
} from "@chakra-ui/react";
|
||||
import { type Model } from "~/modelProviders/types";
|
||||
import { modelStats } from "~/modelProviders/modelStats";
|
||||
import { type SupportedModel } from "~/server/types";
|
||||
|
||||
export const ModelStatsCard = ({ label, model }: { label: string; model: Model }) => {
|
||||
export const ModelStatsCard = ({ label, model }: { label: string; model: SupportedModel }) => {
|
||||
const stats = modelStats[model];
|
||||
return (
|
||||
<VStack w="full" align="start">
|
||||
<Text fontWeight="bold" fontSize="sm" textTransform="uppercase">
|
||||
@@ -20,14 +22,14 @@ export const ModelStatsCard = ({ label, model }: { label: string; model: Model }
|
||||
<HStack w="full" align="flex-start">
|
||||
<Text flex={1} fontSize="lg">
|
||||
<Text as="span" color="gray.600">
|
||||
{model.provider} /{" "}
|
||||
{stats.provider} /{" "}
|
||||
</Text>
|
||||
<Text as="span" fontWeight="bold" color="gray.900">
|
||||
{model.name}
|
||||
{model}
|
||||
</Text>
|
||||
</Text>
|
||||
<Link
|
||||
href={model.learnMoreUrl}
|
||||
href={stats.learnMoreUrl}
|
||||
isExternal
|
||||
color="blue.500"
|
||||
fontWeight="bold"
|
||||
@@ -44,41 +46,26 @@ export const ModelStatsCard = ({ label, model }: { label: string; model: Model }
|
||||
fontSize="sm"
|
||||
columns={{ base: 2, md: 4 }}
|
||||
>
|
||||
<SelectedModelLabeledInfo label="Context Window" info={model.contextWindow} />
|
||||
{model.promptTokenPrice && (
|
||||
<SelectedModelLabeledInfo
|
||||
label="Input"
|
||||
info={
|
||||
<Text>
|
||||
${(model.promptTokenPrice * 1000).toFixed(3)}
|
||||
<Text color="gray.500"> / 1K tokens</Text>
|
||||
</Text>
|
||||
}
|
||||
/>
|
||||
)}
|
||||
{model.completionTokenPrice && (
|
||||
<SelectedModelLabeledInfo
|
||||
label="Output"
|
||||
info={
|
||||
<Text>
|
||||
${(model.completionTokenPrice * 1000).toFixed(3)}
|
||||
<Text color="gray.500"> / 1K tokens</Text>
|
||||
</Text>
|
||||
}
|
||||
/>
|
||||
)}
|
||||
{model.pricePerSecond && (
|
||||
<SelectedModelLabeledInfo
|
||||
label="Price"
|
||||
info={
|
||||
<Text>
|
||||
${model.pricePerSecond.toFixed(3)}
|
||||
<Text color="gray.500"> / second</Text>
|
||||
</Text>
|
||||
}
|
||||
/>
|
||||
)}
|
||||
<SelectedModelLabeledInfo label="Speed" info={<Text>{model.speed}</Text>} />
|
||||
<SelectedModelLabeledInfo label="Context" info={stats.contextLength} />
|
||||
<SelectedModelLabeledInfo
|
||||
label="Input"
|
||||
info={
|
||||
<Text>
|
||||
${(stats.promptTokenPrice * 1000).toFixed(3)}
|
||||
<Text color="gray.500"> / 1K tokens</Text>
|
||||
</Text>
|
||||
}
|
||||
/>
|
||||
<SelectedModelLabeledInfo
|
||||
label="Output"
|
||||
info={
|
||||
<Text>
|
||||
${(stats.promptTokenPrice * 1000).toFixed(3)}
|
||||
<Text color="gray.500"> / 1K tokens</Text>
|
||||
</Text>
|
||||
}
|
||||
/>
|
||||
<SelectedModelLabeledInfo label="Speed" info={<Text>{stats.speed}</Text>} />
|
||||
</SimpleGrid>
|
||||
</VStack>
|
||||
</VStack>
|
||||
@@ -15,29 +15,25 @@ import {
|
||||
} from "@chakra-ui/react";
|
||||
import { RiExchangeFundsFill } from "react-icons/ri";
|
||||
import { useState } from "react";
|
||||
import { type SupportedModel } from "~/server/types";
|
||||
import { ModelStatsCard } from "./ModelStatsCard";
|
||||
import { ModelSearch } from "./ModelSearch";
|
||||
import { SelectModelSearch } from "./SelectModelSearch";
|
||||
import { api } from "~/utils/api";
|
||||
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
|
||||
import CompareFunctions from "../RefinePromptModal/CompareFunctions";
|
||||
import { type PromptVariant } from "@prisma/client";
|
||||
import { isObject, isString } from "lodash-es";
|
||||
import { type Model, type SupportedProvider } from "~/modelProviders/types";
|
||||
import frontendModelProviders from "~/modelProviders/frontendModelProviders";
|
||||
import { keyForModel } from "~/utils/utils";
|
||||
|
||||
export const ChangeModelModal = ({
|
||||
export const SelectModelModal = ({
|
||||
variant,
|
||||
onClose,
|
||||
}: {
|
||||
variant: PromptVariant;
|
||||
onClose: () => void;
|
||||
}) => {
|
||||
const originalModelProviderName = variant.modelProvider as SupportedProvider;
|
||||
const originalModelProvider = frontendModelProviders[originalModelProviderName];
|
||||
const originalModel = originalModelProvider.models[variant.model] as Model;
|
||||
const [selectedModel, setSelectedModel] = useState<Model>(originalModel);
|
||||
const [convertedModel, setConvertedModel] = useState<Model | undefined>(undefined);
|
||||
const originalModel = variant.model as SupportedModel;
|
||||
const [selectedModel, setSelectedModel] = useState<SupportedModel>(originalModel);
|
||||
const [convertedModel, setConvertedModel] = useState<SupportedModel | undefined>(undefined);
|
||||
const utils = api.useContext();
|
||||
|
||||
const experiment = useExperiment();
|
||||
@@ -72,10 +68,6 @@ export const ChangeModelModal = ({
|
||||
onClose();
|
||||
}, [replaceVariantMutation, variant, onClose, modifiedPromptFn]);
|
||||
|
||||
const originalModelLabel = keyForModel(originalModel);
|
||||
const selectedModelLabel = keyForModel(selectedModel);
|
||||
const convertedModelLabel = convertedModel ? keyForModel(convertedModel) : undefined;
|
||||
|
||||
return (
|
||||
<Modal
|
||||
isOpen
|
||||
@@ -94,16 +86,16 @@ export const ChangeModelModal = ({
|
||||
<ModalBody maxW="unset">
|
||||
<VStack spacing={8}>
|
||||
<ModelStatsCard label="Original Model" model={originalModel} />
|
||||
{originalModelLabel !== selectedModelLabel && (
|
||||
{originalModel !== selectedModel && (
|
||||
<ModelStatsCard label="New Model" model={selectedModel} />
|
||||
)}
|
||||
<ModelSearch selectedModel={selectedModel} setSelectedModel={setSelectedModel} />
|
||||
<SelectModelSearch selectedModel={selectedModel} setSelectedModel={setSelectedModel} />
|
||||
{isString(modifiedPromptFn) && (
|
||||
<CompareFunctions
|
||||
originalFunction={variant.constructFn}
|
||||
newFunction={modifiedPromptFn}
|
||||
leftTitle={originalModelLabel}
|
||||
rightTitle={convertedModelLabel}
|
||||
leftTitle={originalModel}
|
||||
rightTitle={convertedModel}
|
||||
/>
|
||||
)}
|
||||
</VStack>
|
||||
47
src/components/SelectModelModal/SelectModelSearch.tsx
Normal file
47
src/components/SelectModelModal/SelectModelSearch.tsx
Normal file
@@ -0,0 +1,47 @@
|
||||
import { VStack, Text } from "@chakra-ui/react";
|
||||
import { type LegacyRef, useCallback } from "react";
|
||||
import Select, { type SingleValue } from "react-select";
|
||||
import { type SupportedModel } from "~/server/types";
|
||||
import { useElementDimensions } from "~/utils/hooks";
|
||||
|
||||
const modelOptions: { value: SupportedModel; label: string }[] = [
|
||||
{ value: "gpt-3.5-turbo", label: "gpt-3.5-turbo" },
|
||||
{ value: "gpt-3.5-turbo-0613", label: "gpt-3.5-turbo-0613" },
|
||||
{ value: "gpt-3.5-turbo-16k", label: "gpt-3.5-turbo-16k" },
|
||||
{ value: "gpt-3.5-turbo-16k-0613", label: "gpt-3.5-turbo-16k-0613" },
|
||||
{ value: "gpt-4", label: "gpt-4" },
|
||||
{ value: "gpt-4-0613", label: "gpt-4-0613" },
|
||||
{ value: "gpt-4-32k", label: "gpt-4-32k" },
|
||||
{ value: "gpt-4-32k-0613", label: "gpt-4-32k-0613" },
|
||||
];
|
||||
|
||||
export const SelectModelSearch = ({
|
||||
selectedModel,
|
||||
setSelectedModel,
|
||||
}: {
|
||||
selectedModel: SupportedModel;
|
||||
setSelectedModel: (model: SupportedModel) => void;
|
||||
}) => {
|
||||
const handleSelection = useCallback(
|
||||
(option: SingleValue<{ value: SupportedModel; label: string }>) => {
|
||||
if (!option) return;
|
||||
setSelectedModel(option.value);
|
||||
},
|
||||
[setSelectedModel],
|
||||
);
|
||||
const selectedOption = modelOptions.find((option) => option.value === selectedModel);
|
||||
|
||||
const [containerRef, containerDimensions] = useElementDimensions();
|
||||
|
||||
return (
|
||||
<VStack ref={containerRef as LegacyRef<HTMLDivElement>} w="full">
|
||||
<Text>Browse Models</Text>
|
||||
<Select
|
||||
styles={{ control: (provided) => ({ ...provided, width: containerDimensions?.width }) }}
|
||||
value={selectedOption}
|
||||
options={modelOptions}
|
||||
onChange={handleSelection}
|
||||
/>
|
||||
</VStack>
|
||||
);
|
||||
};
|
||||
@@ -3,34 +3,28 @@ import { type PromptVariant } from "../OutputsTable/types";
|
||||
import { api } from "~/utils/api";
|
||||
import { RiDraggable } from "react-icons/ri";
|
||||
import { useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
|
||||
import { HStack, Icon, Text, GridItem, type GridItemProps } from "@chakra-ui/react"; // Changed here
|
||||
import { HStack, Icon, Text, GridItem } from "@chakra-ui/react"; // Changed here
|
||||
import { cellPadding, headerMinHeight } from "../constants";
|
||||
import AutoResizeTextArea from "../AutoResizeTextArea";
|
||||
import { stickyHeaderStyle } from "../OutputsTable/styles";
|
||||
import VariantHeaderMenuButton from "./VariantHeaderMenuButton";
|
||||
|
||||
export default function VariantHeader(
|
||||
allProps: {
|
||||
variant: PromptVariant;
|
||||
canHide: boolean;
|
||||
} & GridItemProps,
|
||||
) {
|
||||
const { variant, canHide, ...gridItemProps } = allProps;
|
||||
export default function VariantHeader(props: { variant: PromptVariant; canHide: boolean }) {
|
||||
const { canModify } = useExperimentAccess();
|
||||
const utils = api.useContext();
|
||||
const [isDragTarget, setIsDragTarget] = useState(false);
|
||||
const [isInputHovered, setIsInputHovered] = useState(false);
|
||||
const [label, setLabel] = useState(variant.label);
|
||||
const [label, setLabel] = useState(props.variant.label);
|
||||
|
||||
const updateMutation = api.promptVariants.update.useMutation();
|
||||
const [onSaveLabel] = useHandledAsyncCallback(async () => {
|
||||
if (label && label !== variant.label) {
|
||||
if (label && label !== props.variant.label) {
|
||||
await updateMutation.mutateAsync({
|
||||
id: variant.id,
|
||||
id: props.variant.id,
|
||||
updates: { label: label },
|
||||
});
|
||||
}
|
||||
}, [updateMutation, variant.id, variant.label, label]);
|
||||
}, [updateMutation, props.variant.id, props.variant.label, label]);
|
||||
|
||||
const reorderMutation = api.promptVariants.reorder.useMutation();
|
||||
const [onReorder] = useHandledAsyncCallback(
|
||||
@@ -38,7 +32,7 @@ export default function VariantHeader(
|
||||
e.preventDefault();
|
||||
setIsDragTarget(false);
|
||||
const draggedId = e.dataTransfer.getData("text/plain");
|
||||
const droppedId = variant.id;
|
||||
const droppedId = props.variant.id;
|
||||
if (!draggedId || !droppedId || draggedId === droppedId) return;
|
||||
await reorderMutation.mutateAsync({
|
||||
draggedId,
|
||||
@@ -46,16 +40,16 @@ export default function VariantHeader(
|
||||
});
|
||||
await utils.promptVariants.list.invalidate();
|
||||
},
|
||||
[reorderMutation, variant.id],
|
||||
[reorderMutation, props.variant.id],
|
||||
);
|
||||
|
||||
const [menuOpen, setMenuOpen] = useState(false);
|
||||
|
||||
if (!canModify) {
|
||||
return (
|
||||
<GridItem padding={0} sx={stickyHeaderStyle} borderTopWidth={1} {...gridItemProps}>
|
||||
<GridItem padding={0} sx={stickyHeaderStyle} borderTopWidth={1}>
|
||||
<Text fontSize={16} fontWeight="bold" px={cellPadding.x} py={cellPadding.y}>
|
||||
{variant.label}
|
||||
{props.variant.label}
|
||||
</Text>
|
||||
</GridItem>
|
||||
);
|
||||
@@ -70,7 +64,6 @@ export default function VariantHeader(
|
||||
zIndex: menuOpen ? "dropdown" : stickyHeaderStyle.zIndex,
|
||||
}}
|
||||
borderTopWidth={1}
|
||||
{...gridItemProps}
|
||||
>
|
||||
<HStack
|
||||
spacing={4}
|
||||
@@ -78,7 +71,7 @@ export default function VariantHeader(
|
||||
minH={headerMinHeight}
|
||||
draggable={!isInputHovered}
|
||||
onDragStart={(e) => {
|
||||
e.dataTransfer.setData("text/plain", variant.id);
|
||||
e.dataTransfer.setData("text/plain", props.variant.id);
|
||||
e.currentTarget.style.opacity = "0.4";
|
||||
}}
|
||||
onDragEnd={(e) => {
|
||||
@@ -119,8 +112,8 @@ export default function VariantHeader(
|
||||
onMouseLeave={() => setIsInputHovered(false)}
|
||||
/>
|
||||
<VariantHeaderMenuButton
|
||||
variant={variant}
|
||||
canHide={canHide}
|
||||
variant={props.variant}
|
||||
canHide={props.canHide}
|
||||
menuOpen={menuOpen}
|
||||
setMenuOpen={setMenuOpen}
|
||||
/>
|
||||
|
||||
@@ -17,7 +17,7 @@ import { FaRegClone } from "react-icons/fa";
|
||||
import { useState } from "react";
|
||||
import { RefinePromptModal } from "../RefinePromptModal/RefinePromptModal";
|
||||
import { RiExchangeFundsFill } from "react-icons/ri";
|
||||
import { ChangeModelModal } from "../ChangeModelModal/ChangeModelModal";
|
||||
import { SelectModelModal } from "../SelectModelModal/SelectModelModal";
|
||||
|
||||
export default function VariantHeaderMenuButton({
|
||||
variant,
|
||||
@@ -50,7 +50,7 @@ export default function VariantHeaderMenuButton({
|
||||
await utils.promptVariants.list.invalidate();
|
||||
}, [hideMutation, variant.id]);
|
||||
|
||||
const [changeModelModalOpen, setChangeModelModalOpen] = useState(false);
|
||||
const [selectModelModalOpen, setSelectModelModalOpen] = useState(false);
|
||||
const [refinePromptModalOpen, setRefinePromptModalOpen] = useState(false);
|
||||
|
||||
return (
|
||||
@@ -72,7 +72,7 @@ export default function VariantHeaderMenuButton({
|
||||
</MenuItem>
|
||||
<MenuItem
|
||||
icon={<Icon as={RiExchangeFundsFill} boxSize={5} />}
|
||||
onClick={() => setChangeModelModalOpen(true)}
|
||||
onClick={() => setSelectModelModalOpen(true)}
|
||||
>
|
||||
Change Model
|
||||
</MenuItem>
|
||||
@@ -97,8 +97,8 @@ export default function VariantHeaderMenuButton({
|
||||
)}
|
||||
</MenuList>
|
||||
</Menu>
|
||||
{changeModelModalOpen && (
|
||||
<ChangeModelModal variant={variant} onClose={() => setChangeModelModalOpen(false)} />
|
||||
{selectModelModalOpen && (
|
||||
<SelectModelModal variant={variant} onClose={() => setSelectModelModalOpen(false)} />
|
||||
)}
|
||||
{refinePromptModalOpen && (
|
||||
<RefinePromptModal variant={variant} onClose={() => setRefinePromptModalOpen(false)} />
|
||||
|
||||
@@ -9,6 +9,7 @@ export const env = createEnv({
|
||||
server: {
|
||||
DATABASE_URL: z.string().url(),
|
||||
NODE_ENV: z.enum(["development", "test", "production"]).default("development"),
|
||||
OPENAI_API_KEY: z.string().min(1),
|
||||
RESTRICT_PRISMA_LOGS: z
|
||||
.string()
|
||||
.optional()
|
||||
@@ -16,8 +17,6 @@ export const env = createEnv({
|
||||
.transform((val) => val.toLowerCase() === "true"),
|
||||
GITHUB_CLIENT_ID: z.string().min(1),
|
||||
GITHUB_CLIENT_SECRET: z.string().min(1),
|
||||
OPENAI_API_KEY: z.string().min(1),
|
||||
REPLICATE_API_TOKEN: z.string().default("placeholder"),
|
||||
},
|
||||
|
||||
/**
|
||||
@@ -43,7 +42,6 @@ export const env = createEnv({
|
||||
NEXT_PUBLIC_SOCKET_URL: process.env.NEXT_PUBLIC_SOCKET_URL,
|
||||
GITHUB_CLIENT_ID: process.env.GITHUB_CLIENT_ID,
|
||||
GITHUB_CLIENT_SECRET: process.env.GITHUB_CLIENT_SECRET,
|
||||
REPLICATE_API_TOKEN: process.env.REPLICATE_API_TOKEN,
|
||||
},
|
||||
/**
|
||||
* Run `build` or `dev` with `SKIP_ENV_VALIDATION` to skip env validation.
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
import openaiChatCompletionFrontend from "./openai-ChatCompletion/frontend";
|
||||
import replicateLlama2Frontend from "./replicate-llama2/frontend";
|
||||
import { type SupportedProvider, type FrontendModelProvider } from "./types";
|
||||
|
||||
// TODO: make sure we get a typescript error if you forget to add a provider here
|
||||
|
||||
// Keep attributes here that need to be accessible from the frontend. We can't
|
||||
// just include them in the default `modelProviders` object because it has some
|
||||
// transient dependencies that can only be imported on the server.
|
||||
const frontendModelProviders: Record<SupportedProvider, FrontendModelProvider<any, any>> = {
|
||||
"openai/ChatCompletion": openaiChatCompletionFrontend,
|
||||
"replicate/llama2": replicateLlama2Frontend,
|
||||
};
|
||||
|
||||
export default frontendModelProviders;
|
||||
@@ -1,10 +1,7 @@
|
||||
import openaiChatCompletion from "./openai-ChatCompletion";
|
||||
import replicateLlama2 from "./replicate-llama2";
|
||||
import { type SupportedProvider, type ModelProvider } from "./types";
|
||||
|
||||
const modelProviders: Record<SupportedProvider, ModelProvider<any, any, any>> = {
|
||||
const modelProviders = {
|
||||
"openai/ChatCompletion": openaiChatCompletion,
|
||||
"replicate/llama2": replicateLlama2,
|
||||
};
|
||||
} as const;
|
||||
|
||||
export default modelProviders;
|
||||
|
||||
10
src/modelProviders/modelProvidersFrontend.ts
Normal file
10
src/modelProviders/modelProvidersFrontend.ts
Normal file
@@ -0,0 +1,10 @@
|
||||
import modelProviderFrontend from "./openai-ChatCompletion/frontend";
|
||||
|
||||
// Keep attributes here that need to be accessible from the frontend. We can't
|
||||
// just include them in the default `modelProviders` object because it has some
|
||||
// transient dependencies that can only be imported on the server.
|
||||
const modelProvidersFrontend = {
|
||||
"openai/ChatCompletion": modelProviderFrontend,
|
||||
} as const;
|
||||
|
||||
export default modelProvidersFrontend;
|
||||
77
src/modelProviders/modelStats.ts
Normal file
77
src/modelProviders/modelStats.ts
Normal file
@@ -0,0 +1,77 @@
|
||||
import { type SupportedModel } from "../server/types";
|
||||
|
||||
interface ModelStats {
|
||||
contextLength: number;
|
||||
promptTokenPrice: number;
|
||||
completionTokenPrice: number;
|
||||
speed: "fast" | "medium" | "slow";
|
||||
provider: "OpenAI";
|
||||
learnMoreUrl: string;
|
||||
}
|
||||
|
||||
export const modelStats: Record<SupportedModel, ModelStats> = {
|
||||
"gpt-4": {
|
||||
contextLength: 8192,
|
||||
promptTokenPrice: 0.00003,
|
||||
completionTokenPrice: 0.00006,
|
||||
speed: "medium",
|
||||
provider: "OpenAI",
|
||||
learnMoreUrl: "https://openai.com/gpt-4",
|
||||
},
|
||||
"gpt-4-0613": {
|
||||
contextLength: 8192,
|
||||
promptTokenPrice: 0.00003,
|
||||
completionTokenPrice: 0.00006,
|
||||
speed: "medium",
|
||||
provider: "OpenAI",
|
||||
learnMoreUrl: "https://openai.com/gpt-4",
|
||||
},
|
||||
"gpt-4-32k": {
|
||||
contextLength: 32768,
|
||||
promptTokenPrice: 0.00006,
|
||||
completionTokenPrice: 0.00012,
|
||||
speed: "medium",
|
||||
provider: "OpenAI",
|
||||
learnMoreUrl: "https://openai.com/gpt-4",
|
||||
},
|
||||
"gpt-4-32k-0613": {
|
||||
contextLength: 32768,
|
||||
promptTokenPrice: 0.00006,
|
||||
completionTokenPrice: 0.00012,
|
||||
speed: "medium",
|
||||
provider: "OpenAI",
|
||||
learnMoreUrl: "https://openai.com/gpt-4",
|
||||
},
|
||||
"gpt-3.5-turbo": {
|
||||
contextLength: 4096,
|
||||
promptTokenPrice: 0.0000015,
|
||||
completionTokenPrice: 0.000002,
|
||||
speed: "fast",
|
||||
provider: "OpenAI",
|
||||
learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
|
||||
},
|
||||
"gpt-3.5-turbo-0613": {
|
||||
contextLength: 4096,
|
||||
promptTokenPrice: 0.0000015,
|
||||
completionTokenPrice: 0.000002,
|
||||
speed: "fast",
|
||||
provider: "OpenAI",
|
||||
learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
|
||||
},
|
||||
"gpt-3.5-turbo-16k": {
|
||||
contextLength: 16384,
|
||||
promptTokenPrice: 0.000003,
|
||||
completionTokenPrice: 0.000004,
|
||||
speed: "fast",
|
||||
provider: "OpenAI",
|
||||
learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
|
||||
},
|
||||
"gpt-3.5-turbo-16k-0613": {
|
||||
contextLength: 16384,
|
||||
promptTokenPrice: 0.000003,
|
||||
completionTokenPrice: 0.000004,
|
||||
speed: "fast",
|
||||
provider: "OpenAI",
|
||||
learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
|
||||
},
|
||||
};
|
||||
@@ -56,14 +56,6 @@ modelProperty.type = "string";
|
||||
modelProperty.enum = modelProperty.oneOf[1].enum;
|
||||
delete modelProperty["oneOf"];
|
||||
|
||||
// The default of "inf" confuses the Typescript generator, so can just remove it
|
||||
assert(
|
||||
"max_tokens" in completionRequestSchema.properties &&
|
||||
isObject(completionRequestSchema.properties.max_tokens) &&
|
||||
"default" in completionRequestSchema.properties.max_tokens,
|
||||
);
|
||||
delete completionRequestSchema.properties.max_tokens["default"];
|
||||
|
||||
// Get the directory of the current script
|
||||
const currentDirectory = path.dirname(import.meta.url).replace("file://", "");
|
||||
|
||||
|
||||
@@ -150,6 +150,7 @@
|
||||
},
|
||||
"max_tokens": {
|
||||
"description": "The maximum number of [tokens](/tokenizer) to generate in the chat completion.\n\nThe total length of input tokens and generated tokens is limited by the model's context length. [Example Python code](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb) for counting tokens.\n",
|
||||
"default": "inf",
|
||||
"type": "integer"
|
||||
},
|
||||
"presence_penalty": {
|
||||
|
||||
@@ -1,50 +1,8 @@
|
||||
import { type JsonValue } from "type-fest";
|
||||
import { type SupportedModel } from ".";
|
||||
import { type FrontendModelProvider } from "../types";
|
||||
import { type ChatCompletion } from "openai/resources/chat";
|
||||
|
||||
const frontendModelProvider: FrontendModelProvider<SupportedModel, ChatCompletion> = {
|
||||
name: "OpenAI ChatCompletion",
|
||||
|
||||
models: {
|
||||
"gpt-4-0613": {
|
||||
name: "GPT-4",
|
||||
contextWindow: 8192,
|
||||
promptTokenPrice: 0.00003,
|
||||
completionTokenPrice: 0.00006,
|
||||
speed: "medium",
|
||||
provider: "openai/ChatCompletion",
|
||||
learnMoreUrl: "https://openai.com/gpt-4",
|
||||
},
|
||||
"gpt-4-32k-0613": {
|
||||
name: "GPT-4 32k",
|
||||
contextWindow: 32768,
|
||||
promptTokenPrice: 0.00006,
|
||||
completionTokenPrice: 0.00012,
|
||||
speed: "medium",
|
||||
provider: "openai/ChatCompletion",
|
||||
learnMoreUrl: "https://openai.com/gpt-4",
|
||||
},
|
||||
"gpt-3.5-turbo-0613": {
|
||||
name: "GPT-3.5 Turbo",
|
||||
contextWindow: 4096,
|
||||
promptTokenPrice: 0.0000015,
|
||||
completionTokenPrice: 0.000002,
|
||||
speed: "fast",
|
||||
provider: "openai/ChatCompletion",
|
||||
learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
|
||||
},
|
||||
"gpt-3.5-turbo-16k-0613": {
|
||||
name: "GPT-3.5 Turbo 16k",
|
||||
contextWindow: 16384,
|
||||
promptTokenPrice: 0.000003,
|
||||
completionTokenPrice: 0.000004,
|
||||
speed: "fast",
|
||||
provider: "openai/ChatCompletion",
|
||||
learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
|
||||
},
|
||||
},
|
||||
import { type OpenaiChatModelProvider } from ".";
|
||||
import { type ModelProviderFrontend } from "../types";
|
||||
|
||||
const modelProviderFrontend: ModelProviderFrontend<OpenaiChatModelProvider> = {
|
||||
normalizeOutput: (output) => {
|
||||
const message = output.choices[0]?.message;
|
||||
if (!message)
|
||||
@@ -81,4 +39,4 @@ const frontendModelProvider: FrontendModelProvider<SupportedModel, ChatCompletio
|
||||
},
|
||||
};
|
||||
|
||||
export default frontendModelProvider;
|
||||
export default modelProviderFrontend;
|
||||
|
||||
@@ -8,10 +8,10 @@ import { countOpenAIChatTokens } from "~/utils/countTokens";
|
||||
import { type CompletionResponse } from "../types";
|
||||
import { omit } from "lodash-es";
|
||||
import { openai } from "~/server/utils/openai";
|
||||
import { type OpenAIChatModel } from "~/server/types";
|
||||
import { truthyFilter } from "~/utils/utils";
|
||||
import { APIError } from "openai";
|
||||
import frontendModelProvider from "./frontend";
|
||||
import modelProvider, { type SupportedModel } from ".";
|
||||
import { modelStats } from "../modelStats";
|
||||
|
||||
const mergeStreamedChunks = (
|
||||
base: ChatCompletion | null,
|
||||
@@ -60,7 +60,6 @@ export async function getCompletion(
|
||||
let finalCompletion: ChatCompletion | null = null;
|
||||
let promptTokens: number | undefined = undefined;
|
||||
let completionTokens: number | undefined = undefined;
|
||||
const modelName = modelProvider.getModel(input) as SupportedModel;
|
||||
|
||||
try {
|
||||
if (onStream) {
|
||||
@@ -82,9 +81,12 @@ export async function getCompletion(
|
||||
};
|
||||
}
|
||||
try {
|
||||
promptTokens = countOpenAIChatTokens(modelName, input.messages);
|
||||
promptTokens = countOpenAIChatTokens(
|
||||
input.model as keyof typeof OpenAIChatModel,
|
||||
input.messages,
|
||||
);
|
||||
completionTokens = countOpenAIChatTokens(
|
||||
modelName,
|
||||
input.model as keyof typeof OpenAIChatModel,
|
||||
finalCompletion.choices.map((c) => c.message).filter(truthyFilter),
|
||||
);
|
||||
} catch (err) {
|
||||
@@ -104,10 +106,10 @@ export async function getCompletion(
|
||||
}
|
||||
const timeToComplete = Date.now() - start;
|
||||
|
||||
const { promptTokenPrice, completionTokenPrice } = frontendModelProvider.models[modelName];
|
||||
const stats = modelStats[input.model as keyof typeof OpenAIChatModel];
|
||||
let cost = undefined;
|
||||
if (promptTokenPrice && completionTokenPrice && promptTokens && completionTokens) {
|
||||
cost = promptTokens * promptTokenPrice + completionTokens * completionTokenPrice;
|
||||
if (stats && promptTokens && completionTokens) {
|
||||
cost = promptTokens * stats.promptTokenPrice + completionTokens * stats.completionTokenPrice;
|
||||
}
|
||||
|
||||
return {
|
||||
|
||||
@@ -3,7 +3,6 @@ import { type ModelProvider } from "../types";
|
||||
import inputSchema from "./codegen/input.schema.json";
|
||||
import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat";
|
||||
import { getCompletion } from "./getCompletion";
|
||||
import frontendModelProvider from "./frontend";
|
||||
|
||||
const supportedModels = [
|
||||
"gpt-4-0613",
|
||||
@@ -12,7 +11,7 @@ const supportedModels = [
|
||||
"gpt-3.5-turbo-16k-0613",
|
||||
] as const;
|
||||
|
||||
export type SupportedModel = (typeof supportedModels)[number];
|
||||
type SupportedModel = (typeof supportedModels)[number];
|
||||
|
||||
export type OpenaiChatModelProvider = ModelProvider<
|
||||
SupportedModel,
|
||||
@@ -21,6 +20,25 @@ export type OpenaiChatModelProvider = ModelProvider<
|
||||
>;
|
||||
|
||||
const modelProvider: OpenaiChatModelProvider = {
|
||||
name: "OpenAI ChatCompletion",
|
||||
models: {
|
||||
"gpt-4-0613": {
|
||||
name: "GPT-4",
|
||||
learnMore: "https://openai.com/gpt-4",
|
||||
},
|
||||
"gpt-4-32k-0613": {
|
||||
name: "GPT-4 32k",
|
||||
learnMore: "https://openai.com/gpt-4",
|
||||
},
|
||||
"gpt-3.5-turbo-0613": {
|
||||
name: "GPT-3.5 Turbo",
|
||||
learnMore: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
|
||||
},
|
||||
"gpt-3.5-turbo-16k-0613": {
|
||||
name: "GPT-3.5 Turbo 16k",
|
||||
learnMore: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
|
||||
},
|
||||
},
|
||||
getModel: (input) => {
|
||||
if (supportedModels.includes(input.model as SupportedModel))
|
||||
return input.model as SupportedModel;
|
||||
@@ -39,7 +57,6 @@ const modelProvider: OpenaiChatModelProvider = {
|
||||
inputSchema: inputSchema as JSONSchema4,
|
||||
shouldStream: (input) => input.stream ?? false,
|
||||
getCompletion,
|
||||
...frontendModelProvider,
|
||||
};
|
||||
|
||||
export default modelProvider;
|
||||
|
||||
@@ -1,42 +0,0 @@
|
||||
import { type SupportedModel, type ReplicateLlama2Output } from ".";
|
||||
import { type FrontendModelProvider } from "../types";
|
||||
|
||||
const frontendModelProvider: FrontendModelProvider<SupportedModel, ReplicateLlama2Output> = {
|
||||
name: "Replicate Llama2",
|
||||
|
||||
models: {
|
||||
"7b-chat": {
|
||||
name: "LLama 2 7B Chat",
|
||||
contextWindow: 4096,
|
||||
pricePerSecond: 0.0023,
|
||||
speed: "fast",
|
||||
provider: "replicate/llama2",
|
||||
learnMoreUrl: "https://replicate.com/a16z-infra/llama7b-v2-chat",
|
||||
},
|
||||
"13b-chat": {
|
||||
name: "LLama 2 13B Chat",
|
||||
contextWindow: 4096,
|
||||
pricePerSecond: 0.0023,
|
||||
speed: "medium",
|
||||
provider: "replicate/llama2",
|
||||
learnMoreUrl: "https://replicate.com/a16z-infra/llama13b-v2-chat",
|
||||
},
|
||||
"70b-chat": {
|
||||
name: "LLama 2 70B Chat",
|
||||
contextWindow: 4096,
|
||||
pricePerSecond: 0.0032,
|
||||
speed: "slow",
|
||||
provider: "replicate/llama2",
|
||||
learnMoreUrl: "https://replicate.com/replicate/llama70b-v2-chat",
|
||||
},
|
||||
},
|
||||
|
||||
normalizeOutput: (output) => {
|
||||
return {
|
||||
type: "text",
|
||||
value: output.join(""),
|
||||
};
|
||||
},
|
||||
};
|
||||
|
||||
export default frontendModelProvider;
|
||||
@@ -1,62 +0,0 @@
|
||||
import { env } from "~/env.mjs";
|
||||
import { type ReplicateLlama2Input, type ReplicateLlama2Output } from ".";
|
||||
import { type CompletionResponse } from "../types";
|
||||
import Replicate from "replicate";
|
||||
|
||||
const replicate = new Replicate({
|
||||
auth: env.REPLICATE_API_TOKEN || "",
|
||||
});
|
||||
|
||||
const modelIds: Record<ReplicateLlama2Input["model"], string> = {
|
||||
"7b-chat": "3725a659b5afff1a0ba9bead5fac3899d998feaad00e07032ca2b0e35eb14f8a",
|
||||
"13b-chat": "5c785d117c5bcdd1928d5a9acb1ffa6272d6cf13fcb722e90886a0196633f9d3",
|
||||
"70b-chat": "e951f18578850b652510200860fc4ea62b3b16fac280f83ff32282f87bbd2e48",
|
||||
};
|
||||
|
||||
export async function getCompletion(
|
||||
input: ReplicateLlama2Input,
|
||||
onStream: ((partialOutput: string[]) => void) | null,
|
||||
): Promise<CompletionResponse<ReplicateLlama2Output>> {
|
||||
const start = Date.now();
|
||||
|
||||
const { model, stream, ...rest } = input;
|
||||
|
||||
try {
|
||||
const prediction = await replicate.predictions.create({
|
||||
version: modelIds[model],
|
||||
input: rest,
|
||||
});
|
||||
|
||||
console.log("stream?", onStream);
|
||||
|
||||
const interval = onStream
|
||||
? // eslint-disable-next-line @typescript-eslint/no-misused-promises
|
||||
setInterval(async () => {
|
||||
const partialPrediction = await replicate.predictions.get(prediction.id);
|
||||
|
||||
if (partialPrediction.output) onStream(partialPrediction.output as ReplicateLlama2Output);
|
||||
}, 500)
|
||||
: null;
|
||||
|
||||
const resp = await replicate.wait(prediction, {});
|
||||
if (interval) clearInterval(interval);
|
||||
|
||||
const timeToComplete = Date.now() - start;
|
||||
|
||||
if (resp.error) throw new Error(resp.error as string);
|
||||
|
||||
return {
|
||||
type: "success",
|
||||
statusCode: 200,
|
||||
value: resp.output as ReplicateLlama2Output,
|
||||
timeToComplete,
|
||||
};
|
||||
} catch (error: unknown) {
|
||||
console.error("ERROR IS", error);
|
||||
return {
|
||||
type: "error",
|
||||
message: (error as Error).message,
|
||||
autoRetry: true,
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -1,70 +0,0 @@
|
||||
import { type ModelProvider } from "../types";
|
||||
import frontendModelProvider from "./frontend";
|
||||
import { getCompletion } from "./getCompletion";
|
||||
|
||||
const supportedModels = ["7b-chat", "13b-chat", "70b-chat"] as const;
|
||||
|
||||
export type SupportedModel = (typeof supportedModels)[number];
|
||||
|
||||
export type ReplicateLlama2Input = {
|
||||
model: SupportedModel;
|
||||
prompt: string;
|
||||
stream?: boolean;
|
||||
max_length?: number;
|
||||
temperature?: number;
|
||||
top_p?: number;
|
||||
repetition_penalty?: number;
|
||||
debug?: boolean;
|
||||
};
|
||||
|
||||
export type ReplicateLlama2Output = string[];
|
||||
|
||||
export type ReplicateLlama2Provider = ModelProvider<
|
||||
SupportedModel,
|
||||
ReplicateLlama2Input,
|
||||
ReplicateLlama2Output
|
||||
>;
|
||||
|
||||
const modelProvider: ReplicateLlama2Provider = {
|
||||
getModel: (input) => {
|
||||
if (supportedModels.includes(input.model)) return input.model;
|
||||
|
||||
return null;
|
||||
},
|
||||
inputSchema: {
|
||||
type: "object",
|
||||
properties: {
|
||||
model: {
|
||||
type: "string",
|
||||
enum: supportedModels as unknown as string[],
|
||||
},
|
||||
prompt: {
|
||||
type: "string",
|
||||
},
|
||||
stream: {
|
||||
type: "boolean",
|
||||
},
|
||||
max_length: {
|
||||
type: "number",
|
||||
},
|
||||
temperature: {
|
||||
type: "number",
|
||||
},
|
||||
top_p: {
|
||||
type: "number",
|
||||
},
|
||||
repetition_penalty: {
|
||||
type: "number",
|
||||
},
|
||||
debug: {
|
||||
type: "boolean",
|
||||
},
|
||||
},
|
||||
required: ["model", "prompt"],
|
||||
},
|
||||
shouldStream: (input) => input.stream ?? false,
|
||||
getCompletion,
|
||||
...frontendModelProvider,
|
||||
};
|
||||
|
||||
export default modelProvider;
|
||||
@@ -1,33 +1,9 @@
|
||||
import { type JSONSchema4 } from "json-schema";
|
||||
import { type JsonValue } from "type-fest";
|
||||
import { z } from "zod";
|
||||
|
||||
const ZodSupportedProvider = z.union([
|
||||
z.literal("openai/ChatCompletion"),
|
||||
z.literal("replicate/llama2"),
|
||||
]);
|
||||
|
||||
export type SupportedProvider = z.infer<typeof ZodSupportedProvider>;
|
||||
|
||||
export const ZodModel = z.object({
|
||||
name: z.string(),
|
||||
contextWindow: z.number(),
|
||||
promptTokenPrice: z.number().optional(),
|
||||
completionTokenPrice: z.number().optional(),
|
||||
pricePerSecond: z.number().optional(),
|
||||
speed: z.union([z.literal("fast"), z.literal("medium"), z.literal("slow")]),
|
||||
provider: ZodSupportedProvider,
|
||||
description: z.string().optional(),
|
||||
learnMoreUrl: z.string().optional(),
|
||||
});
|
||||
|
||||
export type Model = z.infer<typeof ZodModel>;
|
||||
|
||||
export type FrontendModelProvider<SupportedModels extends string, OutputSchema> = {
|
||||
type ModelProviderModel = {
|
||||
name: string;
|
||||
models: Record<SupportedModels, Model>;
|
||||
|
||||
normalizeOutput: (output: OutputSchema) => NormalizedOutput;
|
||||
learnMore: string;
|
||||
};
|
||||
|
||||
export type CompletionResponse<T> =
|
||||
@@ -43,6 +19,8 @@ export type CompletionResponse<T> =
|
||||
};
|
||||
|
||||
export type ModelProvider<SupportedModels extends string, InputSchema, OutputSchema> = {
|
||||
name: string;
|
||||
models: Record<SupportedModels, ModelProviderModel>;
|
||||
getModel: (input: InputSchema) => SupportedModels | null;
|
||||
shouldStream: (input: InputSchema) => boolean;
|
||||
inputSchema: JSONSchema4;
|
||||
@@ -53,7 +31,7 @@ export type ModelProvider<SupportedModels extends string, InputSchema, OutputSch
|
||||
|
||||
// This is just a convenience for type inference, don't use it at runtime
|
||||
_outputSchema?: OutputSchema | null;
|
||||
} & FrontendModelProvider<SupportedModels, OutputSchema>;
|
||||
};
|
||||
|
||||
export type NormalizedOutput =
|
||||
| {
|
||||
@@ -64,3 +42,7 @@ export type NormalizedOutput =
|
||||
type: "json";
|
||||
value: JsonValue;
|
||||
};
|
||||
|
||||
export type ModelProviderFrontend<ModelProviderT extends ModelProvider<any, any, any>> = {
|
||||
normalizeOutput: (output: NonNullable<ModelProviderT["_outputSchema"]>) => NormalizedOutput;
|
||||
};
|
||||
|
||||
@@ -109,7 +109,8 @@ export const experimentsRouter = createTRPCRouter({
|
||||
constructFn: dedent`
|
||||
/**
|
||||
* Use Javascript to define an OpenAI chat completion
|
||||
* (https://platform.openai.com/docs/api-reference/chat/create).
|
||||
* (https://platform.openai.com/docs/api-reference/chat/create) and
|
||||
* assign it to the \`prompt\` variable.
|
||||
*
|
||||
* You have access to the current scenario in the \`scenario\`
|
||||
* variable.
|
||||
|
||||
@@ -2,6 +2,7 @@ import { z } from "zod";
|
||||
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
|
||||
import { prisma } from "~/server/db";
|
||||
import { generateNewCell } from "~/server/utils/generateNewCell";
|
||||
import { type SupportedModel } from "~/server/types";
|
||||
import userError from "~/server/utils/error";
|
||||
import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
|
||||
import { reorderPromptVariants } from "~/server/utils/reorderPromptVariants";
|
||||
@@ -9,7 +10,6 @@ import { type PromptVariant } from "@prisma/client";
|
||||
import { deriveNewConstructFn } from "~/server/utils/deriveNewContructFn";
|
||||
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
|
||||
import parseConstructFn from "~/server/utils/parseConstructFn";
|
||||
import { ZodModel } from "~/modelProviders/types";
|
||||
|
||||
export const promptVariantsRouter = createTRPCRouter({
|
||||
list: publicProcedure
|
||||
@@ -144,7 +144,7 @@ export const promptVariantsRouter = createTRPCRouter({
|
||||
z.object({
|
||||
experimentId: z.string(),
|
||||
variantId: z.string().optional(),
|
||||
newModel: ZodModel.optional(),
|
||||
newModel: z.string().optional(),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
@@ -186,7 +186,10 @@ export const promptVariantsRouter = createTRPCRouter({
|
||||
? `${originalVariant?.label} Copy`
|
||||
: `Prompt Variant ${largestSortIndex + 2}`;
|
||||
|
||||
const newConstructFn = await deriveNewConstructFn(originalVariant, input.newModel);
|
||||
const newConstructFn = await deriveNewConstructFn(
|
||||
originalVariant,
|
||||
input.newModel as SupportedModel,
|
||||
);
|
||||
|
||||
const createNewVariantAction = prisma.promptVariant.create({
|
||||
data: {
|
||||
@@ -286,7 +289,7 @@ export const promptVariantsRouter = createTRPCRouter({
|
||||
z.object({
|
||||
id: z.string(),
|
||||
instructions: z.string().optional(),
|
||||
newModel: ZodModel.optional(),
|
||||
newModel: z.string().optional(),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
@@ -305,7 +308,7 @@ export const promptVariantsRouter = createTRPCRouter({
|
||||
|
||||
const promptConstructionFn = await deriveNewConstructFn(
|
||||
existing,
|
||||
input.newModel,
|
||||
input.newModel as SupportedModel | undefined,
|
||||
input.instructions,
|
||||
);
|
||||
|
||||
|
||||
@@ -34,21 +34,22 @@ export const scenariosRouter = createTRPCRouter({
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
await requireCanModifyExperiment(input.experimentId, ctx);
|
||||
|
||||
await prisma.testScenario.updateMany({
|
||||
where: {
|
||||
experimentId: input.experimentId,
|
||||
},
|
||||
data: {
|
||||
sortIndex: {
|
||||
increment: 1,
|
||||
},
|
||||
},
|
||||
});
|
||||
const maxSortIndex =
|
||||
(
|
||||
await prisma.testScenario.aggregate({
|
||||
where: {
|
||||
experimentId: input.experimentId,
|
||||
},
|
||||
_max: {
|
||||
sortIndex: true,
|
||||
},
|
||||
})
|
||||
)._max.sortIndex ?? 0;
|
||||
|
||||
const createNewScenarioAction = prisma.testScenario.create({
|
||||
data: {
|
||||
experimentId: input.experimentId,
|
||||
sortIndex: 0,
|
||||
sortIndex: maxSortIndex + 1,
|
||||
variableValues: input.autogenerate
|
||||
? await autogenerateScenarioValues(input.experimentId)
|
||||
: {},
|
||||
|
||||
@@ -1,26 +1,26 @@
|
||||
/* eslint-disable */
|
||||
// /* eslint-disable */
|
||||
|
||||
import "dotenv/config";
|
||||
import Replicate from "replicate";
|
||||
// import "dotenv/config";
|
||||
// import Replicate from "replicate";
|
||||
|
||||
const replicate = new Replicate({
|
||||
auth: process.env.REPLICATE_API_TOKEN || "",
|
||||
});
|
||||
// const replicate = new Replicate({
|
||||
// auth: process.env.REPLICATE_API_TOKEN || "",
|
||||
// });
|
||||
|
||||
console.log("going to run");
|
||||
const prediction = await replicate.predictions.create({
|
||||
version: "3725a659b5afff1a0ba9bead5fac3899d998feaad00e07032ca2b0e35eb14f8a",
|
||||
input: {
|
||||
prompt: "...",
|
||||
},
|
||||
});
|
||||
// console.log("going to run");
|
||||
// const prediction = await replicate.predictions.create({
|
||||
// version: "e951f18578850b652510200860fc4ea62b3b16fac280f83ff32282f87bbd2e48",
|
||||
// input: {
|
||||
// prompt: "...",
|
||||
// },
|
||||
// });
|
||||
|
||||
console.log("waiting");
|
||||
setInterval(() => {
|
||||
replicate.predictions.get(prediction.id).then((prediction) => {
|
||||
console.log(prediction);
|
||||
});
|
||||
}, 500);
|
||||
// const output = await replicate.wait(prediction, {});
|
||||
// console.log("waiting");
|
||||
// setInterval(() => {
|
||||
// replicate.predictions.get(prediction.id).then((prediction) => {
|
||||
// console.log(prediction.output);
|
||||
// });
|
||||
// }, 500);
|
||||
// // const output = await replicate.wait(prediction, {});
|
||||
|
||||
// console.log(output);
|
||||
// // console.log(output);
|
||||
|
||||
@@ -123,7 +123,7 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
|
||||
data: {
|
||||
scenarioVariantCellId,
|
||||
inputHash,
|
||||
output: response.value as Prisma.InputJsonObject,
|
||||
output: response.value as unknown as Prisma.InputJsonObject,
|
||||
timeToComplete: response.timeToComplete,
|
||||
promptTokens: response.promptTokens,
|
||||
completionTokens: response.completionTokens,
|
||||
@@ -151,7 +151,7 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
|
||||
errorMessage: response.message,
|
||||
statusCode: response.statusCode,
|
||||
retryTime: shouldRetry ? new Date(Date.now() + delay) : null,
|
||||
retrievalStatus: "ERROR",
|
||||
retrievalStatus: shouldRetry ? "PENDING" : "ERROR",
|
||||
},
|
||||
});
|
||||
|
||||
|
||||
12
src/server/types.ts
Normal file
12
src/server/types.ts
Normal file
@@ -0,0 +1,12 @@
|
||||
export enum OpenAIChatModel {
|
||||
"gpt-4" = "gpt-4",
|
||||
"gpt-4-0613" = "gpt-4-0613",
|
||||
"gpt-4-32k" = "gpt-4-32k",
|
||||
"gpt-4-32k-0613" = "gpt-4-32k-0613",
|
||||
"gpt-3.5-turbo" = "gpt-3.5-turbo",
|
||||
"gpt-3.5-turbo-0613" = "gpt-3.5-turbo-0613",
|
||||
"gpt-3.5-turbo-16k" = "gpt-3.5-turbo-16k",
|
||||
"gpt-3.5-turbo-16k-0613" = "gpt-3.5-turbo-16k-0613",
|
||||
}
|
||||
|
||||
export type SupportedModel = keyof typeof OpenAIChatModel;
|
||||
@@ -1,18 +1,18 @@
|
||||
import { type PromptVariant } from "@prisma/client";
|
||||
import { type SupportedModel } from "../types";
|
||||
import ivm from "isolated-vm";
|
||||
import dedent from "dedent";
|
||||
import { openai } from "./openai";
|
||||
import { getApiShapeForModel } from "./getTypesForModel";
|
||||
import { isObject } from "lodash-es";
|
||||
import { type CompletionCreateParams } from "openai/resources/chat/completions";
|
||||
import formatPromptConstructor from "~/utils/formatPromptConstructor";
|
||||
import { type SupportedProvider, type Model } from "~/modelProviders/types";
|
||||
import modelProviders from "~/modelProviders/modelProviders";
|
||||
|
||||
const isolate = new ivm.Isolate({ memoryLimit: 128 });
|
||||
|
||||
export async function deriveNewConstructFn(
|
||||
originalVariant: PromptVariant | null,
|
||||
newModel?: Model,
|
||||
newModel?: SupportedModel,
|
||||
instructions?: string,
|
||||
) {
|
||||
if (originalVariant && !newModel && !instructions) {
|
||||
@@ -36,11 +36,10 @@ export async function deriveNewConstructFn(
|
||||
const NUM_RETRIES = 5;
|
||||
const requestUpdatedPromptFunction = async (
|
||||
originalVariant: PromptVariant,
|
||||
newModel?: Model,
|
||||
newModel?: SupportedModel,
|
||||
instructions?: string,
|
||||
) => {
|
||||
const originalModelProvider = modelProviders[originalVariant.modelProvider as SupportedProvider];
|
||||
const originalModel = originalModelProvider.models[originalVariant.model] as Model;
|
||||
const originalModel = originalVariant.model as SupportedModel;
|
||||
let newContructionFn = "";
|
||||
for (let i = 0; i < NUM_RETRIES; i++) {
|
||||
try {
|
||||
@@ -48,33 +47,17 @@ const requestUpdatedPromptFunction = async (
|
||||
{
|
||||
role: "system",
|
||||
content: `Your job is to update prompt constructor functions. Here is the api shape for the current model:\n---\n${JSON.stringify(
|
||||
originalModelProvider.inputSchema,
|
||||
getApiShapeForModel(originalModel),
|
||||
null,
|
||||
2,
|
||||
)}\n\nDo not add any assistant messages.`,
|
||||
},
|
||||
{
|
||||
role: "user",
|
||||
content: `This is the current prompt constructor function:\n---\n${originalVariant.constructFn}`,
|
||||
},
|
||||
];
|
||||
if (newModel) {
|
||||
messages.push({
|
||||
role: "user",
|
||||
content: `Return the prompt constructor function for ${newModel.name} given the existing prompt constructor function for ${originalModel.name}`,
|
||||
content: `Return the prompt constructor function for ${newModel} given the following prompt constructor function for ${originalModel}:\n---\n${originalVariant.constructFn}`,
|
||||
});
|
||||
if (newModel.provider !== originalModel.provider) {
|
||||
messages.push({
|
||||
role: "user",
|
||||
content: `The old provider was ${originalModel.provider}. The new provider is ${
|
||||
newModel.provider
|
||||
}. Here is the schema for the new model:\n---\n${JSON.stringify(
|
||||
modelProviders[newModel.provider].inputSchema,
|
||||
null,
|
||||
2,
|
||||
)}`,
|
||||
});
|
||||
}
|
||||
}
|
||||
if (instructions) {
|
||||
messages.push({
|
||||
@@ -82,6 +65,10 @@ const requestUpdatedPromptFunction = async (
|
||||
content: instructions,
|
||||
});
|
||||
}
|
||||
messages.push({
|
||||
role: "system",
|
||||
content: "The prompt variable has already been declared, so do not declare it again.",
|
||||
});
|
||||
const completion = await openai.chat.completions.create({
|
||||
model: "gpt-4",
|
||||
messages,
|
||||
|
||||
@@ -4,9 +4,8 @@ import { queueLLMRetrievalTask } from "./queueLLMRetrievalTask";
|
||||
import parseConstructFn from "./parseConstructFn";
|
||||
import { type JsonObject } from "type-fest";
|
||||
import hashPrompt from "./hashPrompt";
|
||||
import { omit } from "lodash-es";
|
||||
|
||||
export const generateNewCell = async (variantId: string, scenarioId: string): Promise<void> => {
|
||||
export const generateNewCell = async (variantId: string, scenarioId: string) => {
|
||||
const variant = await prisma.promptVariant.findUnique({
|
||||
where: {
|
||||
id: variantId,
|
||||
@@ -19,7 +18,7 @@ export const generateNewCell = async (variantId: string, scenarioId: string): Pr
|
||||
},
|
||||
});
|
||||
|
||||
if (!variant || !scenario) return;
|
||||
if (!variant || !scenario) return null;
|
||||
|
||||
let cell = await prisma.scenarioVariantCell.findUnique({
|
||||
where: {
|
||||
@@ -33,7 +32,7 @@ export const generateNewCell = async (variantId: string, scenarioId: string): Pr
|
||||
},
|
||||
});
|
||||
|
||||
if (cell) return;
|
||||
if (cell) return cell;
|
||||
|
||||
const parsedConstructFn = await parseConstructFn(
|
||||
variant.constructFn,
|
||||
@@ -41,7 +40,7 @@ export const generateNewCell = async (variantId: string, scenarioId: string): Pr
|
||||
);
|
||||
|
||||
if ("error" in parsedConstructFn) {
|
||||
await prisma.scenarioVariantCell.create({
|
||||
return await prisma.scenarioVariantCell.create({
|
||||
data: {
|
||||
promptVariantId: variantId,
|
||||
testScenarioId: scenarioId,
|
||||
@@ -50,7 +49,6 @@ export const generateNewCell = async (variantId: string, scenarioId: string): Pr
|
||||
retrievalStatus: "ERROR",
|
||||
},
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
const inputHash = hashPrompt(parsedConstructFn);
|
||||
@@ -71,33 +69,29 @@ export const generateNewCell = async (variantId: string, scenarioId: string): Pr
|
||||
where: { inputHash },
|
||||
});
|
||||
|
||||
let newModelOutput;
|
||||
|
||||
if (matchingModelOutput) {
|
||||
const newModelOutput = await prisma.modelOutput.create({
|
||||
newModelOutput = await prisma.modelOutput.create({
|
||||
data: {
|
||||
...omit(matchingModelOutput, ["id"]),
|
||||
scenarioVariantCellId: cell.id,
|
||||
inputHash,
|
||||
output: matchingModelOutput.output as Prisma.InputJsonValue,
|
||||
timeToComplete: matchingModelOutput.timeToComplete,
|
||||
cost: matchingModelOutput.cost,
|
||||
promptTokens: matchingModelOutput.promptTokens,
|
||||
completionTokens: matchingModelOutput.completionTokens,
|
||||
createdAt: matchingModelOutput.createdAt,
|
||||
updatedAt: matchingModelOutput.updatedAt,
|
||||
},
|
||||
});
|
||||
await prisma.scenarioVariantCell.update({
|
||||
where: { id: cell.id },
|
||||
data: { retrievalStatus: "COMPLETE" },
|
||||
});
|
||||
|
||||
// Copy over all eval results as well
|
||||
await Promise.all(
|
||||
(
|
||||
await prisma.outputEvaluation.findMany({ where: { modelOutputId: matchingModelOutput.id } })
|
||||
).map(async (evaluation) => {
|
||||
await prisma.outputEvaluation.create({
|
||||
data: {
|
||||
...omit(evaluation, ["id"]),
|
||||
modelOutputId: newModelOutput.id,
|
||||
},
|
||||
});
|
||||
}),
|
||||
);
|
||||
} else {
|
||||
cell = await queueLLMRetrievalTask(cell.id);
|
||||
}
|
||||
|
||||
return { ...cell, modelOutput: newModelOutput };
|
||||
};
|
||||
|
||||
6
src/server/utils/getTypesForModel.ts
Normal file
6
src/server/utils/getTypesForModel.ts
Normal file
@@ -0,0 +1,6 @@
|
||||
import { type SupportedModel } from "../types";
|
||||
|
||||
export const getApiShapeForModel = (model: SupportedModel) => {
|
||||
// if (model in OpenAIChatModel) return openAIChatApiShape;
|
||||
return "";
|
||||
};
|
||||
@@ -1,6 +1,6 @@
|
||||
import { type ChatCompletion } from "openai/resources/chat";
|
||||
import { GPTTokens } from "gpt-tokens";
|
||||
import { type SupportedModel } from "~/modelProviders/openai-ChatCompletion";
|
||||
import { type OpenAIChatModel } from "~/server/types";
|
||||
|
||||
interface GPTTokensMessageItem {
|
||||
name?: string;
|
||||
@@ -9,7 +9,7 @@ interface GPTTokensMessageItem {
|
||||
}
|
||||
|
||||
export const countOpenAIChatTokens = (
|
||||
model: SupportedModel,
|
||||
model: keyof typeof OpenAIChatModel,
|
||||
messages: ChatCompletion.Choice.Message[],
|
||||
) => {
|
||||
return new GPTTokens({ model, messages: messages as unknown as GPTTokensMessageItem[] })
|
||||
|
||||
@@ -1,5 +1 @@
|
||||
import { type Model } from "~/modelProviders/types";
|
||||
|
||||
export const truthyFilter = <T>(x: T | null | undefined): x is T => Boolean(x);
|
||||
|
||||
export const keyForModel = (model: Model) => `${model.provider}/${model.name}`;
|
||||
|
||||
Reference in New Issue
Block a user