Add support for switching to Llama models (#80)
* Add support for switching to Llama models
* Fix prettier
@@ -15,25 +15,29 @@ import {
 } from "@chakra-ui/react";
 import { RiExchangeFundsFill } from "react-icons/ri";
 import { useState } from "react";
-import { type SupportedModel } from "~/server/types";
 import { ModelStatsCard } from "./ModelStatsCard";
-import { SelectModelSearch } from "./SelectModelSearch";
+import { ModelSearch } from "./ModelSearch";
 import { api } from "~/utils/api";
 import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
 import CompareFunctions from "../RefinePromptModal/CompareFunctions";
 import { type PromptVariant } from "@prisma/client";
 import { isObject, isString } from "lodash-es";
+import { type Model, type SupportedProvider } from "~/modelProviders/types";
+import frontendModelProviders from "~/modelProviders/frontendModelProviders";
+import { keyForModel } from "~/utils/utils";

-export const SelectModelModal = ({
+export const ChangeModelModal = ({
   variant,
   onClose,
 }: {
   variant: PromptVariant;
   onClose: () => void;
 }) => {
-  const originalModel = variant.model as SupportedModel;
-  const [selectedModel, setSelectedModel] = useState<SupportedModel>(originalModel);
-  const [convertedModel, setConvertedModel] = useState<SupportedModel | undefined>(undefined);
+  const originalModelProviderName = variant.modelProvider as SupportedProvider;
+  const originalModelProvider = frontendModelProviders[originalModelProviderName];
+  const originalModel = originalModelProvider.models[variant.model] as Model;
+  const [selectedModel, setSelectedModel] = useState<Model>(originalModel);
+  const [convertedModel, setConvertedModel] = useState<Model | undefined>(undefined);

   const utils = api.useContext();

   const experiment = useExperiment();
@@ -68,6 +72,10 @@ export const SelectModelModal = ({
     onClose();
   }, [replaceVariantMutation, variant, onClose, modifiedPromptFn]);

+  const originalModelLabel = keyForModel(originalModel);
+  const selectedModelLabel = keyForModel(selectedModel);
+  const convertedModelLabel = convertedModel ? keyForModel(convertedModel) : undefined;
+
   return (
     <Modal
       isOpen
@@ -86,16 +94,16 @@ export const SelectModelModal = ({
         <ModalBody maxW="unset">
           <VStack spacing={8}>
             <ModelStatsCard label="Original Model" model={originalModel} />
-            {originalModel !== selectedModel && (
+            {originalModelLabel !== selectedModelLabel && (
               <ModelStatsCard label="New Model" model={selectedModel} />
             )}
-            <SelectModelSearch selectedModel={selectedModel} setSelectedModel={setSelectedModel} />
+            <ModelSearch selectedModel={selectedModel} setSelectedModel={setSelectedModel} />
             {isString(modifiedPromptFn) && (
               <CompareFunctions
                 originalFunction={variant.constructFn}
                 newFunction={modifiedPromptFn}
-                leftTitle={originalModel}
-                rightTitle={convertedModel}
+                leftTitle={originalModelLabel}
+                rightTitle={convertedModelLabel}
               />
             )}
           </VStack>
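Note: the card-visibility check above switches from comparing models to comparing labels because `selectedModel` is now a full `Model` object rather than a string, and object identity would misreport equality. A minimal standalone sketch (the `Model` type is re-declared here for illustration):

```ts
// Two structurally identical Model objects are distinct by identity, so the
// modal compares provider-qualified labels instead of the objects themselves.
type Model = { provider: string; name: string };
const keyForModel = (model: Model) => `${model.provider}/${model.name}`;

const original: Model = { provider: "openai/ChatCompletion", name: "GPT-4" };
const selected: Model = { provider: "openai/ChatCompletion", name: "GPT-4" };

console.log(original === selected); // false: different object identities
console.log(keyForModel(original) !== keyForModel(selected)); // false: same model, "New Model" card stays hidden
```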
src/components/ChangeModelModal/ModelSearch.tsx (new file, 50 lines)
@@ -0,0 +1,50 @@
+import { VStack, Text } from "@chakra-ui/react";
+import { type LegacyRef, useCallback } from "react";
+import Select, { type SingleValue } from "react-select";
+import { useElementDimensions } from "~/utils/hooks";
+
+import frontendModelProviders from "~/modelProviders/frontendModelProviders";
+import { type Model } from "~/modelProviders/types";
+import { keyForModel } from "~/utils/utils";
+
+const modelOptions: { label: string; value: Model }[] = [];
+
+for (const [_, providerValue] of Object.entries(frontendModelProviders)) {
+  for (const [_, modelValue] of Object.entries(providerValue.models)) {
+    modelOptions.push({
+      label: keyForModel(modelValue),
+      value: modelValue,
+    });
+  }
+}
+
+export const ModelSearch = ({
+  selectedModel,
+  setSelectedModel,
+}: {
+  selectedModel: Model;
+  setSelectedModel: (model: Model) => void;
+}) => {
+  const handleSelection = useCallback(
+    (option: SingleValue<{ label: string; value: Model }>) => {
+      if (!option) return;
+      setSelectedModel(option.value);
+    },
+    [setSelectedModel],
+  );
+  const selectedOption = modelOptions.find((option) => option.label === keyForModel(selectedModel));
+
+  const [containerRef, containerDimensions] = useElementDimensions();
+
+  return (
+    <VStack ref={containerRef as LegacyRef<HTMLDivElement>} w="full">
+      <Text>Browse Models</Text>
+      <Select
+        styles={{ control: (provided) => ({ ...provided, width: containerDimensions?.width }) }}
+        value={selectedOption}
+        options={modelOptions}
+        onChange={handleSelection}
+      />
+    </VStack>
+  );
+};
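The option list in `ModelSearch.tsx` is built once at module load by flattening every provider's model map. A rough self-contained sketch of the same double loop, with a hand-inlined subset of provider data standing in for the real `frontendModelProviders` export:

```ts
// Standalone sketch of the option-building loop in ModelSearch.tsx.
type Model = { name: string; provider: string };

const providers: Record<string, { models: Record<string, Model> }> = {
  "openai/ChatCompletion": {
    models: { "gpt-4-0613": { name: "GPT-4", provider: "openai/ChatCompletion" } },
  },
  "replicate/llama2": {
    models: { "7b-chat": { name: "LLama 2 7B Chat", provider: "replicate/llama2" } },
  },
};

const keyForModel = (model: Model) => `${model.provider}/${model.name}`;

// Every model from every provider becomes one react-select option,
// labeled with its provider-qualified key.
const modelOptions: { label: string; value: Model }[] = [];
for (const provider of Object.values(providers)) {
  for (const model of Object.values(provider.models)) {
    modelOptions.push({ label: keyForModel(model), value: model });
  }
}

console.log(modelOptions.map((option) => option.label));
// ["openai/ChatCompletion/GPT-4", "replicate/llama2/LLama 2 7B Chat"]
```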
@@ -7,11 +7,9 @@ import {
   SimpleGrid,
   Link,
 } from "@chakra-ui/react";
-import { modelStats } from "~/modelProviders/modelStats";
-import { type SupportedModel } from "~/server/types";
+import { type Model } from "~/modelProviders/types";

-export const ModelStatsCard = ({ label, model }: { label: string; model: SupportedModel }) => {
-  const stats = modelStats[model];
+export const ModelStatsCard = ({ label, model }: { label: string; model: Model }) => {
   return (
     <VStack w="full" align="start">
       <Text fontWeight="bold" fontSize="sm" textTransform="uppercase">
@@ -22,14 +20,14 @@ export const ModelStatsCard = ({ label, model }: { label: string; model: Support
       <HStack w="full" align="flex-start">
         <Text flex={1} fontSize="lg">
           <Text as="span" color="gray.600">
-            {stats.provider} /{" "}
+            {model.provider} /{" "}
           </Text>
           <Text as="span" fontWeight="bold" color="gray.900">
-            {model}
+            {model.name}
           </Text>
         </Text>
         <Link
-          href={stats.learnMoreUrl}
+          href={model.learnMoreUrl}
           isExternal
           color="blue.500"
           fontWeight="bold"
@@ -46,26 +44,41 @@ export const ModelStatsCard = ({ label, model }: { label: string; model: Support
         fontSize="sm"
         columns={{ base: 2, md: 4 }}
       >
-        <SelectedModelLabeledInfo label="Context" info={stats.contextLength} />
-        <SelectedModelLabeledInfo
-          label="Input"
-          info={
-            <Text>
-              ${(stats.promptTokenPrice * 1000).toFixed(3)}
-              <Text color="gray.500"> / 1K tokens</Text>
-            </Text>
-          }
-        />
-        <SelectedModelLabeledInfo
-          label="Output"
-          info={
-            <Text>
-              ${(stats.promptTokenPrice * 1000).toFixed(3)}
-              <Text color="gray.500"> / 1K tokens</Text>
-            </Text>
-          }
-        />
-        <SelectedModelLabeledInfo label="Speed" info={<Text>{stats.speed}</Text>} />
+        <SelectedModelLabeledInfo label="Context Window" info={model.contextWindow} />
+        {model.promptTokenPrice && (
+          <SelectedModelLabeledInfo
+            label="Input"
+            info={
+              <Text>
+                ${(model.promptTokenPrice * 1000).toFixed(3)}
+                <Text color="gray.500"> / 1K tokens</Text>
+              </Text>
+            }
+          />
+        )}
+        {model.completionTokenPrice && (
+          <SelectedModelLabeledInfo
+            label="Output"
+            info={
+              <Text>
+                ${(model.completionTokenPrice * 1000).toFixed(3)}
+                <Text color="gray.500"> / 1K tokens</Text>
+              </Text>
+            }
+          />
+        )}
+        {model.pricePerSecond && (
+          <SelectedModelLabeledInfo
+            label="Price"
+            info={
+              <Text>
+                ${model.pricePerSecond.toFixed(3)}
+                <Text color="gray.500"> / second</Text>
+              </Text>
+            }
+          />
+        )}
+        <SelectedModelLabeledInfo label="Speed" info={<Text>{model.speed}</Text>} />
       </SimpleGrid>
     </VStack>
   </VStack>
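Because the pricing fields on `Model` are optional (OpenAI models carry token prices, Replicate models carry `pricePerSecond`), the stats card now renders only the price blocks whose fields are present. A plain-function sketch of that branching, using the `gpt-4-0613` and `7b-chat` figures that appear later in this diff:

```ts
// Sketch of the optional-pricing branching in ModelStatsCard. Token-priced
// models get Input/Output lines; per-second models get a Price line.
type Model = {
  promptTokenPrice?: number;
  completionTokenPrice?: number;
  pricePerSecond?: number;
};

function priceLines(model: Model): string[] {
  const lines: string[] = [];
  if (model.promptTokenPrice)
    lines.push(`Input: $${(model.promptTokenPrice * 1000).toFixed(3)} / 1K tokens`);
  if (model.completionTokenPrice)
    lines.push(`Output: $${(model.completionTokenPrice * 1000).toFixed(3)} / 1K tokens`);
  if (model.pricePerSecond) lines.push(`Price: $${model.pricePerSecond.toFixed(3)} / second`);
  return lines;
}

console.log(priceLines({ promptTokenPrice: 0.00003, completionTokenPrice: 0.00006 }));
// ["Input: $0.030 / 1K tokens", "Output: $0.060 / 1K tokens"]
console.log(priceLines({ pricePerSecond: 0.0023 }));
// ["Price: $0.002 / second"]
```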
@@ -1,22 +1,22 @@
 import { HStack, Icon, Heading, Text, VStack, GridItem } from "@chakra-ui/react";
 import { type IconType } from "react-icons";
-import { refineOptions, type RefineOptionLabel } from "./refineOptions";

 export const RefineOption = ({
   label,
-  activeLabel,
   icon,
+  desciption,
+  activeLabel,
   onClick,
   loading,
 }: {
-  label: RefineOptionLabel;
-  activeLabel: RefineOptionLabel | undefined;
+  label: string;
   icon: IconType;
-  onClick: (label: RefineOptionLabel) => void;
+  desciption: string;
+  activeLabel: string | undefined;
+  onClick: (label: string) => void;
   loading: boolean;
 }) => {
   const isActive = activeLabel === label;
-  const desciption = refineOptions[label].description;

   return (
     <GridItem w="80" h="44">
@@ -15,17 +15,16 @@ import {
   SimpleGrid,
 } from "@chakra-ui/react";
 import { BsStars } from "react-icons/bs";
-import { VscJson } from "react-icons/vsc";
-import { TfiThought } from "react-icons/tfi";
 import { api } from "~/utils/api";
 import { useHandledAsyncCallback } from "~/utils/hooks";
 import { type PromptVariant } from "@prisma/client";
 import { useState } from "react";
 import CompareFunctions from "./CompareFunctions";
 import { CustomInstructionsInput } from "./CustomInstructionsInput";
-import { type RefineOptionLabel, refineOptions } from "./refineOptions";
+import { type RefineOptionInfo, refineOptions } from "./refineOptions";
 import { RefineOption } from "./RefineOption";
 import { isObject, isString } from "lodash-es";
+import { type SupportedProvider } from "~/modelProviders/types";

 export const RefinePromptModal = ({
   variant,
@@ -36,18 +35,22 @@ export const RefinePromptModal = ({
 }) => {
   const utils = api.useContext();

+  const providerRefineOptions = refineOptions[variant.modelProvider as SupportedProvider];
+
   const { mutateAsync: getModifiedPromptMutateAsync, data: refinedPromptFn } =
     api.promptVariants.getModifiedPromptFn.useMutation();
   const [instructions, setInstructions] = useState<string>("");

-  const [activeRefineOptionLabel, setActiveRefineOptionLabel] = useState<
-    RefineOptionLabel | undefined
-  >(undefined);
+  const [activeRefineOptionLabel, setActiveRefineOptionLabel] = useState<string | undefined>(
+    undefined,
+  );

   const [getModifiedPromptFn, modificationInProgress] = useHandledAsyncCallback(
-    async (label?: RefineOptionLabel) => {
+    async (label?: string) => {
       if (!variant.experimentId) return;
-      const updatedInstructions = label ? refineOptions[label].instructions : instructions;
+      const updatedInstructions = label
+        ? (providerRefineOptions[label] as RefineOptionInfo).instructions
+        : instructions;
       setActiveRefineOptionLabel(label);
       await getModifiedPromptMutateAsync({
         id: variant.id,
@@ -92,25 +95,26 @@ export const RefinePromptModal = ({
           <ModalBody maxW="unset">
             <VStack spacing={8}>
               <VStack spacing={4}>
-                <SimpleGrid columns={{ base: 1, md: 2 }} spacing={8}>
-                  <RefineOption
-                    label="Convert to function call"
-                    activeLabel={activeRefineOptionLabel}
-                    icon={VscJson}
-                    onClick={getModifiedPromptFn}
-                    loading={modificationInProgress}
-                  />
-                  <RefineOption
-                    label="Add chain of thought"
-                    activeLabel={activeRefineOptionLabel}
-                    icon={TfiThought}
+                {Object.keys(providerRefineOptions).length && (
+                  <>
+                    <SimpleGrid columns={{ base: 1, md: 2 }} spacing={8}>
+                      {Object.keys(providerRefineOptions).map((label) => (
+                        <RefineOption
+                          key={label}
+                          label={label}
+                          // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
+                          icon={providerRefineOptions[label]!.icon}
+                          // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
+                          desciption={providerRefineOptions[label]!.description}
+                          activeLabel={activeRefineOptionLabel}
                           onClick={getModifiedPromptFn}
                           loading={modificationInProgress}
                         />
-                </SimpleGrid>
-                <HStack>
+                      ))}
+                    </SimpleGrid>
                     <Text color="gray.500">or</Text>
-                </HStack>
+                  </>
+                )}
                 <CustomInstructionsInput
                   instructions={instructions}
                   setInstructions={setInstructions}
@@ -1,17 +1,21 @@
 // Super hacky, but we'll redo the organization when we have more models

-export type RefineOptionLabel = "Add chain of thought" | "Convert to function call";
+import { type SupportedProvider } from "~/modelProviders/types";
+import { VscJson } from "react-icons/vsc";
+import { TfiThought } from "react-icons/tfi";
+import { type IconType } from "react-icons";
+
+export type RefineOptionInfo = { icon: IconType; description: string; instructions: string };

-export const refineOptions: Record<
-  RefineOptionLabel,
-  { description: string; instructions: string }
-> = {
-  "Add chain of thought": {
-    description: "Asking the model to plan its answer can increase accuracy.",
-    instructions: `Adding chain of thought means asking the model to think about its answer before it gives it to you. This is useful for getting more accurate answers. Do not add an assistant message.
+export const refineOptions: Record<SupportedProvider, { [key: string]: RefineOptionInfo }> = {
+  "openai/ChatCompletion": {
+    "Add chain of thought": {
+      icon: VscJson,
+      description: "Asking the model to plan its answer can increase accuracy.",
+      instructions: `Adding chain of thought means asking the model to think about its answer before it gives it to you. This is useful for getting more accurate answers. Do not add an assistant message.

 This is what a prompt looks like before adding chain of thought:

 definePrompt("openai/ChatCompletion", {
   model: "gpt-4",
   stream: true,
@@ -55,9 +59,9 @@ export const refineOptions: Record<
       role: "user",
       content: \`Title: \${scenario.title}
 Body: \${scenario.body}

 Need: \${scenario.need}

 Rate likelihood on 1-3 scale.\`,
     },
   ],
@@ -89,9 +93,9 @@ export const refineOptions: Record<
       role: "user",
       content: \`Title: \${scenario.title}
 Body: \${scenario.body}

 Need: \${scenario.need}

 Rate likelihood on 1-3 scale. Provide an explanation, but always provide a score afterward.\`,
     },
   ],
@@ -118,13 +122,14 @@ export const refineOptions: Record<
 });

 Add chain of thought to the original prompt.`,
     },
     "Convert to function call": {
-    description: "Use function calls to get output from the model in a more structured way.",
-    instructions: `OpenAI functions are a specialized way for an LLM to return output.
+      icon: TfiThought,
+      description: "Use function calls to get output from the model in a more structured way.",
+      instructions: `OpenAI functions are a specialized way for an LLM to return output.

 This is what a prompt looks like before adding a function:

 definePrompt("openai/ChatCompletion", {
   model: "gpt-4",
   stream: true,
@@ -139,9 +144,9 @@ export const refineOptions: Record<
     },
   ],
 });

 This is what one looks like after adding a function:

 definePrompt("openai/ChatCompletion", {
   model: "gpt-4",
   stream: true,
@@ -187,11 +192,11 @@ export const refineOptions: Record<

 title: \${scenario.title}
 body: \${scenario.body}

 On a scale from 1 to 3, how likely is it that the person writing this post has the following need? If you are not sure, make your best guess, or answer 1.

 Need: \${scenario.need}

 Answer one integer between 1 and 3.\`,
     },
   ],
@@ -207,9 +212,9 @@ export const refineOptions: Record<
       role: "user",
       content: \`Title: \${scenario.title}
 Body: \${scenario.body}

 Need: \${scenario.need}

 Rate likelihood on 1-3 scale.\`,
     },
   ],
@@ -231,7 +236,9 @@ export const refineOptions: Record<
       name: "score_post",
     },
 });

 Add an OpenAI function that takes one or more nested parameters that match the expected output from this prompt.`,
+    },
   },
+  "replicate/llama2": {},
 };
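`refineOptions` is now keyed by provider first, which lets `RefinePromptModal` look up only the options that apply to the active variant's provider and hide the grid when there are none, as with `replicate/llama2`. A condensed sketch of that lookup (descriptions and instructions elided):

```ts
// Condensed sketch of the provider-first lookup in refineOptions.
type RefineOptionInfo = { description: string; instructions: string };

const refineOptions: Record<string, { [label: string]: RefineOptionInfo }> = {
  "openai/ChatCompletion": {
    "Add chain of thought": { description: "…", instructions: "…" },
    "Convert to function call": { description: "…", instructions: "…" },
  },
  "replicate/llama2": {}, // no refine options yet: the grid is not rendered
};

const labelsFor = (provider: string) => Object.keys(refineOptions[provider] ?? {});

console.log(labelsFor("openai/ChatCompletion")); // ["Add chain of thought", "Convert to function call"]
console.log(labelsFor("replicate/llama2")); // []
```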
@@ -1,47 +0,0 @@
-import { VStack, Text } from "@chakra-ui/react";
-import { type LegacyRef, useCallback } from "react";
-import Select, { type SingleValue } from "react-select";
-import { type SupportedModel } from "~/server/types";
-import { useElementDimensions } from "~/utils/hooks";
-
-const modelOptions: { value: SupportedModel; label: string }[] = [
-  { value: "gpt-3.5-turbo", label: "gpt-3.5-turbo" },
-  { value: "gpt-3.5-turbo-0613", label: "gpt-3.5-turbo-0613" },
-  { value: "gpt-3.5-turbo-16k", label: "gpt-3.5-turbo-16k" },
-  { value: "gpt-3.5-turbo-16k-0613", label: "gpt-3.5-turbo-16k-0613" },
-  { value: "gpt-4", label: "gpt-4" },
-  { value: "gpt-4-0613", label: "gpt-4-0613" },
-  { value: "gpt-4-32k", label: "gpt-4-32k" },
-  { value: "gpt-4-32k-0613", label: "gpt-4-32k-0613" },
-];
-
-export const SelectModelSearch = ({
-  selectedModel,
-  setSelectedModel,
-}: {
-  selectedModel: SupportedModel;
-  setSelectedModel: (model: SupportedModel) => void;
-}) => {
-  const handleSelection = useCallback(
-    (option: SingleValue<{ value: SupportedModel; label: string }>) => {
-      if (!option) return;
-      setSelectedModel(option.value);
-    },
-    [setSelectedModel],
-  );
-  const selectedOption = modelOptions.find((option) => option.value === selectedModel);
-
-  const [containerRef, containerDimensions] = useElementDimensions();
-
-  return (
-    <VStack ref={containerRef as LegacyRef<HTMLDivElement>} w="full">
-      <Text>Browse Models</Text>
-      <Select
-        styles={{ control: (provided) => ({ ...provided, width: containerDimensions?.width }) }}
-        value={selectedOption}
-        options={modelOptions}
-        onChange={handleSelection}
-      />
-    </VStack>
-  );
-};
@@ -17,7 +17,7 @@ import { FaRegClone } from "react-icons/fa";
 import { useState } from "react";
 import { RefinePromptModal } from "../RefinePromptModal/RefinePromptModal";
 import { RiExchangeFundsFill } from "react-icons/ri";
-import { SelectModelModal } from "../SelectModelModal/SelectModelModal";
+import { ChangeModelModal } from "../ChangeModelModal/ChangeModelModal";

 export default function VariantHeaderMenuButton({
   variant,
@@ -50,7 +50,7 @@ export default function VariantHeaderMenuButton({
     await utils.promptVariants.list.invalidate();
   }, [hideMutation, variant.id]);

-  const [selectModelModalOpen, setSelectModelModalOpen] = useState(false);
+  const [changeModelModalOpen, setChangeModelModalOpen] = useState(false);
   const [refinePromptModalOpen, setRefinePromptModalOpen] = useState(false);

   return (
@@ -72,7 +72,7 @@ export default function VariantHeaderMenuButton({
         </MenuItem>
         <MenuItem
           icon={<Icon as={RiExchangeFundsFill} boxSize={5} />}
-          onClick={() => setSelectModelModalOpen(true)}
+          onClick={() => setChangeModelModalOpen(true)}
         >
           Change Model
         </MenuItem>
@@ -97,8 +97,8 @@ export default function VariantHeaderMenuButton({
         )}
       </MenuList>
     </Menu>
-    {selectModelModalOpen && (
-      <SelectModelModal variant={variant} onClose={() => setSelectModelModalOpen(false)} />
+    {changeModelModalOpen && (
+      <ChangeModelModal variant={variant} onClose={() => setChangeModelModalOpen(false)} />
     )}
     {refinePromptModalOpen && (
       <RefinePromptModal variant={variant} onClose={() => setRefinePromptModalOpen(false)} />
@@ -1,77 +0,0 @@
-import { type SupportedModel } from "../server/types";
-
-interface ModelStats {
-  contextLength: number;
-  promptTokenPrice: number;
-  completionTokenPrice: number;
-  speed: "fast" | "medium" | "slow";
-  provider: "OpenAI";
-  learnMoreUrl: string;
-}
-
-export const modelStats: Record<SupportedModel, ModelStats> = {
-  "gpt-4": {
-    contextLength: 8192,
-    promptTokenPrice: 0.00003,
-    completionTokenPrice: 0.00006,
-    speed: "medium",
-    provider: "OpenAI",
-    learnMoreUrl: "https://openai.com/gpt-4",
-  },
-  "gpt-4-0613": {
-    contextLength: 8192,
-    promptTokenPrice: 0.00003,
-    completionTokenPrice: 0.00006,
-    speed: "medium",
-    provider: "OpenAI",
-    learnMoreUrl: "https://openai.com/gpt-4",
-  },
-  "gpt-4-32k": {
-    contextLength: 32768,
-    promptTokenPrice: 0.00006,
-    completionTokenPrice: 0.00012,
-    speed: "medium",
-    provider: "OpenAI",
-    learnMoreUrl: "https://openai.com/gpt-4",
-  },
-  "gpt-4-32k-0613": {
-    contextLength: 32768,
-    promptTokenPrice: 0.00006,
-    completionTokenPrice: 0.00012,
-    speed: "medium",
-    provider: "OpenAI",
-    learnMoreUrl: "https://openai.com/gpt-4",
-  },
-  "gpt-3.5-turbo": {
-    contextLength: 4096,
-    promptTokenPrice: 0.0000015,
-    completionTokenPrice: 0.000002,
-    speed: "fast",
-    provider: "OpenAI",
-    learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
-  },
-  "gpt-3.5-turbo-0613": {
-    contextLength: 4096,
-    promptTokenPrice: 0.0000015,
-    completionTokenPrice: 0.000002,
-    speed: "fast",
-    provider: "OpenAI",
-    learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
-  },
-  "gpt-3.5-turbo-16k": {
-    contextLength: 16384,
-    promptTokenPrice: 0.000003,
-    completionTokenPrice: 0.000004,
-    speed: "fast",
-    provider: "OpenAI",
-    learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
-  },
-  "gpt-3.5-turbo-16k-0613": {
-    contextLength: 16384,
-    promptTokenPrice: 0.000003,
-    completionTokenPrice: 0.000004,
-    speed: "fast",
-    provider: "OpenAI",
-    learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
-  },
-};
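The deleted `modelStats` table is superseded rather than lost: the same figures move onto each model's entry in the frontend providers below, with `contextLength` renamed to `contextWindow`. An illustrative equivalence check, assuming the shapes in this diff:

```ts
// Old: stats keyed by bare model name. New: metadata on the provider's entry.
const oldStats = { "gpt-4": { contextLength: 8192, promptTokenPrice: 0.00003 } };
const newModels = { "gpt-4-0613": { contextWindow: 8192, promptTokenPrice: 0.00003 } };

console.log(oldStats["gpt-4"].contextLength === newModels["gpt-4-0613"].contextWindow); // true
```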
@@ -9,19 +9,39 @@ const frontendModelProvider: FrontendModelProvider<SupportedModel, ChatCompletio
   models: {
     "gpt-4-0613": {
       name: "GPT-4",
-      learnMore: "https://openai.com/gpt-4",
+      contextWindow: 8192,
+      promptTokenPrice: 0.00003,
+      completionTokenPrice: 0.00006,
+      speed: "medium",
+      provider: "openai/ChatCompletion",
+      learnMoreUrl: "https://openai.com/gpt-4",
     },
     "gpt-4-32k-0613": {
       name: "GPT-4 32k",
-      learnMore: "https://openai.com/gpt-4",
+      contextWindow: 32768,
+      promptTokenPrice: 0.00006,
+      completionTokenPrice: 0.00012,
+      speed: "medium",
+      provider: "openai/ChatCompletion",
+      learnMoreUrl: "https://openai.com/gpt-4",
     },
     "gpt-3.5-turbo-0613": {
       name: "GPT-3.5 Turbo",
-      learnMore: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
+      contextWindow: 4096,
+      promptTokenPrice: 0.0000015,
+      completionTokenPrice: 0.000002,
+      speed: "fast",
+      provider: "openai/ChatCompletion",
+      learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
     },
     "gpt-3.5-turbo-16k-0613": {
       name: "GPT-3.5 Turbo 16k",
-      learnMore: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
+      contextWindow: 16384,
+      promptTokenPrice: 0.000003,
+      completionTokenPrice: 0.000004,
+      speed: "fast",
+      provider: "openai/ChatCompletion",
+      learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
     },
   },
@@ -8,10 +8,10 @@ import { countOpenAIChatTokens } from "~/utils/countTokens";
 import { type CompletionResponse } from "../types";
 import { omit } from "lodash-es";
 import { openai } from "~/server/utils/openai";
-import { type OpenAIChatModel } from "~/server/types";
 import { truthyFilter } from "~/utils/utils";
 import { APIError } from "openai";
-import { modelStats } from "../modelStats";
+import frontendModelProvider from "./frontend";
+import modelProvider, { type SupportedModel } from ".";

 const mergeStreamedChunks = (
   base: ChatCompletion | null,
@@ -60,6 +60,7 @@ export async function getCompletion(
   let finalCompletion: ChatCompletion | null = null;
   let promptTokens: number | undefined = undefined;
   let completionTokens: number | undefined = undefined;
+  const modelName = modelProvider.getModel(input) as SupportedModel;

   try {
     if (onStream) {
@@ -81,12 +82,9 @@ export async function getCompletion(
       };
     }
     try {
-      promptTokens = countOpenAIChatTokens(
-        input.model as keyof typeof OpenAIChatModel,
-        input.messages,
-      );
+      promptTokens = countOpenAIChatTokens(modelName, input.messages);
       completionTokens = countOpenAIChatTokens(
-        input.model as keyof typeof OpenAIChatModel,
+        modelName,
         finalCompletion.choices.map((c) => c.message).filter(truthyFilter),
       );
     } catch (err) {
@@ -106,10 +104,10 @@ export async function getCompletion(
   }
   const timeToComplete = Date.now() - start;

-  const stats = modelStats[input.model as keyof typeof OpenAIChatModel];
+  const { promptTokenPrice, completionTokenPrice } = frontendModelProvider.models[modelName];
   let cost = undefined;
-  if (stats && promptTokens && completionTokens) {
-    cost = promptTokens * stats.promptTokenPrice + completionTokens * stats.completionTokenPrice;
+  if (promptTokenPrice && completionTokenPrice && promptTokens && completionTokens) {
+    cost = promptTokens * promptTokenPrice + completionTokens * completionTokenPrice;
   }

   return {
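A small sketch of the reworked cost computation, assuming the `gpt-4-0613` token prices from `frontend.ts`; as in the diff, every operand must be truthy or no cost is reported:

```ts
// Reworked cost computation from getCompletion, in isolation.
const promptTokenPrice = 0.00003; // $ per prompt token
const completionTokenPrice = 0.00006; // $ per completion token

function computeCost(promptTokens?: number, completionTokens?: number): number | undefined {
  // Mirrors the guard in the diff: all four values must be truthy.
  if (promptTokenPrice && completionTokenPrice && promptTokens && completionTokens) {
    return promptTokens * promptTokenPrice + completionTokens * completionTokenPrice;
  }
  return undefined;
}

console.log(computeCost(1000, 500)); // ~0.06 ($0.03 prompt + $0.03 completion)
console.log(computeCost(1000, undefined)); // undefined (token counting failed, no cost)
```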
@@ -5,9 +5,30 @@ const frontendModelProvider: FrontendModelProvider<SupportedModel, ReplicateLlam
   name: "Replicate Llama2",

   models: {
-    "7b-chat": {},
-    "13b-chat": {},
-    "70b-chat": {},
+    "7b-chat": {
+      name: "LLama 2 7B Chat",
+      contextWindow: 4096,
+      pricePerSecond: 0.0023,
+      speed: "fast",
+      provider: "replicate/llama2",
+      learnMoreUrl: "https://replicate.com/a16z-infra/llama7b-v2-chat",
+    },
+    "13b-chat": {
+      name: "LLama 2 13B Chat",
+      contextWindow: 4096,
+      pricePerSecond: 0.0023,
+      speed: "medium",
+      provider: "replicate/llama2",
+      learnMoreUrl: "https://replicate.com/a16z-infra/llama13b-v2-chat",
+    },
+    "70b-chat": {
+      name: "LLama 2 70B Chat",
+      contextWindow: 4096,
+      pricePerSecond: 0.0032,
+      speed: "slow",
+      provider: "replicate/llama2",
+      learnMoreUrl: "https://replicate.com/replicate/llama70b-v2-chat",
+    },
   },

   normalizeOutput: (output) => {
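Unlike the OpenAI entries, the Llama 2 entries carry `pricePerSecond`, reflecting Replicate's runtime-based billing. A hypothetical cost estimate, with the generation time assumed purely for illustration:

```ts
// Hypothetical cost for a Llama 2 generation billed by runtime.
const pricePerSecond = 0.0032; // 70b-chat rate from the hunk above
const runtimeSeconds = 12; // assumed generation time

console.log(`$${(pricePerSecond * runtimeSeconds).toFixed(4)}`); // $0.0384
```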
@@ -1,16 +1,31 @@
 import { type JSONSchema4 } from "json-schema";
 import { type JsonValue } from "type-fest";
+import { z } from "zod";

-export type SupportedProvider = "openai/ChatCompletion" | "replicate/llama2";
+const ZodSupportedProvider = z.union([
+  z.literal("openai/ChatCompletion"),
+  z.literal("replicate/llama2"),
+]);

-type ModelInfo = {
-  name?: string;
-  learnMore?: string;
-};
+export type SupportedProvider = z.infer<typeof ZodSupportedProvider>;
+
+export const ZodModel = z.object({
+  name: z.string(),
+  contextWindow: z.number(),
+  promptTokenPrice: z.number().optional(),
+  completionTokenPrice: z.number().optional(),
+  pricePerSecond: z.number().optional(),
+  speed: z.union([z.literal("fast"), z.literal("medium"), z.literal("slow")]),
+  provider: ZodSupportedProvider,
+  description: z.string().optional(),
+  learnMoreUrl: z.string().optional(),
+});
+
+export type Model = z.infer<typeof ZodModel>;

 export type FrontendModelProvider<SupportedModels extends string, OutputSchema> = {
   name: string;
-  models: Record<SupportedModels, ModelInfo>;
+  models: Record<SupportedModels, Model>;

   normalizeOutput: (output: OutputSchema) => NormalizedOutput;
 };
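Deriving `Model` from a Zod schema means payloads arriving over the API can be validated at runtime instead of being cast from strings. A self-contained re-creation of the schema (the real one lives in `~/modelProviders/types`) with a sample parse:

```ts
import { z } from "zod";

// Standalone re-creation of ZodModel for illustration.
const ZodSupportedProvider = z.union([
  z.literal("openai/ChatCompletion"),
  z.literal("replicate/llama2"),
]);

const ZodModel = z.object({
  name: z.string(),
  contextWindow: z.number(),
  promptTokenPrice: z.number().optional(),
  completionTokenPrice: z.number().optional(),
  pricePerSecond: z.number().optional(),
  speed: z.union([z.literal("fast"), z.literal("medium"), z.literal("slow")]),
  provider: ZodSupportedProvider,
  description: z.string().optional(),
  learnMoreUrl: z.string().optional(),
});

// A full Model object can now be checked at the API boundary:
const parsed = ZodModel.safeParse({
  name: "LLama 2 7B Chat",
  contextWindow: 4096,
  pricePerSecond: 0.0023,
  speed: "fast",
  provider: "replicate/llama2",
});
console.log(parsed.success); // true
```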
@@ -2,7 +2,6 @@ import { z } from "zod";
 import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
 import { prisma } from "~/server/db";
 import { generateNewCell } from "~/server/utils/generateNewCell";
-import { type SupportedModel } from "~/server/types";
 import userError from "~/server/utils/error";
 import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
 import { reorderPromptVariants } from "~/server/utils/reorderPromptVariants";
@@ -10,6 +9,7 @@ import { type PromptVariant } from "@prisma/client";
 import { deriveNewConstructFn } from "~/server/utils/deriveNewContructFn";
 import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
 import parseConstructFn from "~/server/utils/parseConstructFn";
+import { ZodModel } from "~/modelProviders/types";

 export const promptVariantsRouter = createTRPCRouter({
   list: publicProcedure
@@ -144,7 +144,7 @@ export const promptVariantsRouter = createTRPCRouter({
       z.object({
         experimentId: z.string(),
         variantId: z.string().optional(),
-        newModel: z.string().optional(),
+        newModel: ZodModel.optional(),
       }),
     )
     .mutation(async ({ input, ctx }) => {
@@ -186,10 +186,7 @@ export const promptVariantsRouter = createTRPCRouter({
         ? `${originalVariant?.label} Copy`
         : `Prompt Variant ${largestSortIndex + 2}`;

-      const newConstructFn = await deriveNewConstructFn(
-        originalVariant,
-        input.newModel as SupportedModel,
-      );
+      const newConstructFn = await deriveNewConstructFn(originalVariant, input.newModel);

       const createNewVariantAction = prisma.promptVariant.create({
         data: {
@@ -289,7 +286,7 @@ export const promptVariantsRouter = createTRPCRouter({
       z.object({
         id: z.string(),
         instructions: z.string().optional(),
-        newModel: z.string().optional(),
+        newModel: ZodModel.optional(),
       }),
     )
     .mutation(async ({ input, ctx }) => {
@@ -308,7 +305,7 @@ export const promptVariantsRouter = createTRPCRouter({

       const promptConstructionFn = await deriveNewConstructFn(
         existing,
-        input.newModel as SupportedModel | undefined,
+        input.newModel,
         input.instructions,
       );

@@ -1,12 +0,0 @@
-export enum OpenAIChatModel {
-  "gpt-4" = "gpt-4",
-  "gpt-4-0613" = "gpt-4-0613",
-  "gpt-4-32k" = "gpt-4-32k",
-  "gpt-4-32k-0613" = "gpt-4-32k-0613",
-  "gpt-3.5-turbo" = "gpt-3.5-turbo",
-  "gpt-3.5-turbo-0613" = "gpt-3.5-turbo-0613",
-  "gpt-3.5-turbo-16k" = "gpt-3.5-turbo-16k",
-  "gpt-3.5-turbo-16k-0613" = "gpt-3.5-turbo-16k-0613",
-}
-
-export type SupportedModel = keyof typeof OpenAIChatModel;
@@ -1,18 +1,18 @@
 import { type PromptVariant } from "@prisma/client";
-import { type SupportedModel } from "../types";
 import ivm from "isolated-vm";
 import dedent from "dedent";
 import { openai } from "./openai";
-import { getApiShapeForModel } from "./getTypesForModel";
 import { isObject } from "lodash-es";
 import { type CompletionCreateParams } from "openai/resources/chat/completions";
 import formatPromptConstructor from "~/utils/formatPromptConstructor";
+import { type SupportedProvider, type Model } from "~/modelProviders/types";
+import modelProviders from "~/modelProviders/modelProviders";

 const isolate = new ivm.Isolate({ memoryLimit: 128 });

 export async function deriveNewConstructFn(
   originalVariant: PromptVariant | null,
-  newModel?: SupportedModel,
+  newModel?: Model,
   instructions?: string,
 ) {
   if (originalVariant && !newModel && !instructions) {
@@ -36,10 +36,11 @@ export async function deriveNewConstructFn(
 const NUM_RETRIES = 5;
 const requestUpdatedPromptFunction = async (
   originalVariant: PromptVariant,
-  newModel?: SupportedModel,
+  newModel?: Model,
   instructions?: string,
 ) => {
-  const originalModel = originalVariant.model as SupportedModel;
+  const originalModelProvider = modelProviders[originalVariant.modelProvider as SupportedProvider];
+  const originalModel = originalModelProvider.models[originalVariant.model] as Model;
   let newContructionFn = "";
   for (let i = 0; i < NUM_RETRIES; i++) {
     try {
@@ -47,7 +48,7 @@ const requestUpdatedPromptFunction = async (
         {
           role: "system",
           content: `Your job is to update prompt constructor functions. Here is the api shape for the current model:\n---\n${JSON.stringify(
-            getApiShapeForModel(originalModel),
+            originalModelProvider.inputSchema,
             null,
             2,
           )}\n\nDo not add any assistant messages.`,
@@ -60,8 +61,20 @@ const requestUpdatedPromptFunction = async (
     if (newModel) {
       messages.push({
         role: "user",
-        content: `Return the prompt constructor function for ${newModel} given the existing prompt constructor function for ${originalModel}`,
+        content: `Return the prompt constructor function for ${newModel.name} given the existing prompt constructor function for ${originalModel.name}`,
       });
+      if (newModel.provider !== originalModel.provider) {
+        messages.push({
+          role: "user",
+          content: `The old provider was ${originalModel.provider}. The new provider is ${
+            newModel.provider
+          }. Here is the schema for the new model:\n---\n${JSON.stringify(
+            modelProviders[newModel.provider].inputSchema,
+            null,
+            2,
+          )}`,
+        });
+      }
     }
     if (instructions) {
       messages.push({
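When the user switches providers, `requestUpdatedPromptFunction` appends one extra user message carrying the new provider's input schema, so the LLM can translate the prompt constructor across APIs. A standalone sketch of that message assembly, with a stubbed schema table (illustrative only) in place of `modelProviders[...].inputSchema`:

```ts
// Sketch of the message assembly for a cross-provider model switch.
type Model = { name: string; provider: string };
const inputSchemas: Record<string, object> = {
  "replicate/llama2": { model: "string", prompt: "string" }, // stand-in schema
};

function conversionMessages(originalModel: Model, newModel: Model) {
  const messages = [
    {
      role: "user" as const,
      content: `Return the prompt constructor function for ${newModel.name} given the existing prompt constructor function for ${originalModel.name}`,
    },
  ];
  // Only when the provider changes does the model need the new API's shape.
  if (newModel.provider !== originalModel.provider) {
    messages.push({
      role: "user",
      content: `The old provider was ${originalModel.provider}. The new provider is ${
        newModel.provider
      }. Here is the schema for the new model:\n---\n${JSON.stringify(
        inputSchemas[newModel.provider],
        null,
        2,
      )}`,
    });
  }
  return messages;
}

const gpt4 = { name: "GPT-4", provider: "openai/ChatCompletion" };
const llama = { name: "LLama 2 7B Chat", provider: "replicate/llama2" };
console.log(conversionMessages(gpt4, llama).length); // 2: schema message included
```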
@@ -1,6 +0,0 @@
-import { type SupportedModel } from "../types";
-
-export const getApiShapeForModel = (model: SupportedModel) => {
-  // if (model in OpenAIChatModel) return openAIChatApiShape;
-  return "";
-};
@@ -1,6 +1,6 @@
 import { type ChatCompletion } from "openai/resources/chat";
 import { GPTTokens } from "gpt-tokens";
-import { type OpenAIChatModel } from "~/server/types";
+import { type SupportedModel } from "~/modelProviders/openai-ChatCompletion";

 interface GPTTokensMessageItem {
   name?: string;
@@ -9,7 +9,7 @@ interface GPTTokensMessageItem {
 }

 export const countOpenAIChatTokens = (
-  model: keyof typeof OpenAIChatModel,
+  model: SupportedModel,
   messages: ChatCompletion.Choice.Message[],
 ) => {
   return new GPTTokens({ model, messages: messages as unknown as GPTTokensMessageItem[] })
@@ -1 +1,5 @@
+import { type Model } from "~/modelProviders/types";
+
 export const truthyFilter = <T>(x: T | null | undefined): x is T => Boolean(x);
+
+export const keyForModel = (model: Model) => `${model.provider}/${model.name}`;
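`keyForModel` is the new provider-qualified display key used by the modal and the search component. A quick check of the keys it produces for the two providers in this PR:

```ts
// keyForModel builds the provider-qualified key used for option labels and
// card comparisons.
type Model = { provider: string; name: string };
const keyForModel = (model: Model) => `${model.provider}/${model.name}`;

console.log(keyForModel({ provider: "openai/ChatCompletion", name: "GPT-4" }));
// "openai/ChatCompletion/GPT-4"
console.log(keyForModel({ provider: "replicate/llama2", name: "LLama 2 7B Chat" }));
// "replicate/llama2/LLama 2 7B Chat"
```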