Only pass in model and provider
I got somewhat confused by the extra fields, sorry. Also makes some frontend changes to track that state more directly although in retrospect not sure the frontend changes make things any better.
This commit is contained in:
@@ -1,5 +1,7 @@
|
|||||||
import {
|
import {
|
||||||
Button,
|
Button,
|
||||||
|
HStack,
|
||||||
|
Icon,
|
||||||
Modal,
|
Modal,
|
||||||
ModalBody,
|
ModalBody,
|
||||||
ModalCloseButton,
|
ModalCloseButton,
|
||||||
@@ -7,24 +9,21 @@ import {
|
|||||||
ModalFooter,
|
ModalFooter,
|
||||||
ModalHeader,
|
ModalHeader,
|
||||||
ModalOverlay,
|
ModalOverlay,
|
||||||
VStack,
|
|
||||||
Text,
|
|
||||||
Spinner,
|
Spinner,
|
||||||
HStack,
|
Text,
|
||||||
Icon,
|
VStack,
|
||||||
} from "@chakra-ui/react";
|
} from "@chakra-ui/react";
|
||||||
import { RiExchangeFundsFill } from "react-icons/ri";
|
|
||||||
import { useState } from "react";
|
|
||||||
import { ModelStatsCard } from "./ModelStatsCard";
|
|
||||||
import { ModelSearch } from "./ModelSearch";
|
|
||||||
import { api } from "~/utils/api";
|
|
||||||
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
|
|
||||||
import CompareFunctions from "../RefinePromptModal/CompareFunctions";
|
|
||||||
import { type PromptVariant } from "@prisma/client";
|
import { type PromptVariant } from "@prisma/client";
|
||||||
import { isObject, isString } from "lodash-es";
|
import { isObject, isString } from "lodash-es";
|
||||||
import { type Model, type SupportedProvider } from "~/modelProviders/types";
|
import { useState } from "react";
|
||||||
import frontendModelProviders from "~/modelProviders/frontendModelProviders";
|
import { RiExchangeFundsFill } from "react-icons/ri";
|
||||||
import { keyForModel } from "~/utils/utils";
|
import { type ProviderModel } from "~/modelProviders/types";
|
||||||
|
import { api } from "~/utils/api";
|
||||||
|
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
|
||||||
|
import { lookupModel, modelLabel } from "~/utils/utils";
|
||||||
|
import CompareFunctions from "../RefinePromptModal/CompareFunctions";
|
||||||
|
import { ModelSearch } from "./ModelSearch";
|
||||||
|
import { ModelStatsCard } from "./ModelStatsCard";
|
||||||
|
|
||||||
export const ChangeModelModal = ({
|
export const ChangeModelModal = ({
|
||||||
variant,
|
variant,
|
||||||
@@ -33,11 +32,13 @@ export const ChangeModelModal = ({
|
|||||||
variant: PromptVariant;
|
variant: PromptVariant;
|
||||||
onClose: () => void;
|
onClose: () => void;
|
||||||
}) => {
|
}) => {
|
||||||
const originalModelProviderName = variant.modelProvider as SupportedProvider;
|
const originalModel = lookupModel(variant.modelProvider, variant.model);
|
||||||
const originalModelProvider = frontendModelProviders[originalModelProviderName];
|
const [selectedModel, setSelectedModel] = useState({
|
||||||
const originalModel = originalModelProvider.models[variant.model] as Model;
|
provider: variant.modelProvider,
|
||||||
const [selectedModel, setSelectedModel] = useState<Model>(originalModel);
|
model: variant.model,
|
||||||
const [convertedModel, setConvertedModel] = useState<Model | undefined>(undefined);
|
} as ProviderModel);
|
||||||
|
const [convertedModel, setConvertedModel] = useState<ProviderModel | undefined>();
|
||||||
|
|
||||||
const utils = api.useContext();
|
const utils = api.useContext();
|
||||||
|
|
||||||
const experiment = useExperiment();
|
const experiment = useExperiment();
|
||||||
@@ -72,9 +73,10 @@ export const ChangeModelModal = ({
|
|||||||
onClose();
|
onClose();
|
||||||
}, [replaceVariantMutation, variant, onClose, modifiedPromptFn]);
|
}, [replaceVariantMutation, variant, onClose, modifiedPromptFn]);
|
||||||
|
|
||||||
const originalModelLabel = keyForModel(originalModel);
|
const originalLabel = modelLabel(variant.modelProvider, variant.model);
|
||||||
const selectedModelLabel = keyForModel(selectedModel);
|
const selectedLabel = modelLabel(selectedModel.provider, selectedModel.model);
|
||||||
const convertedModelLabel = convertedModel ? keyForModel(convertedModel) : undefined;
|
const convertedLabel =
|
||||||
|
convertedModel && modelLabel(convertedModel.provider, convertedModel.model);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Modal
|
<Modal
|
||||||
@@ -94,16 +96,19 @@ export const ChangeModelModal = ({
|
|||||||
<ModalBody maxW="unset">
|
<ModalBody maxW="unset">
|
||||||
<VStack spacing={8}>
|
<VStack spacing={8}>
|
||||||
<ModelStatsCard label="Original Model" model={originalModel} />
|
<ModelStatsCard label="Original Model" model={originalModel} />
|
||||||
{originalModelLabel !== selectedModelLabel && (
|
{originalLabel !== selectedLabel && (
|
||||||
<ModelStatsCard label="New Model" model={selectedModel} />
|
<ModelStatsCard
|
||||||
|
label="New Model"
|
||||||
|
model={lookupModel(selectedModel.provider, selectedModel.model)}
|
||||||
|
/>
|
||||||
)}
|
)}
|
||||||
<ModelSearch selectedModel={selectedModel} setSelectedModel={setSelectedModel} />
|
<ModelSearch selectedModel={selectedModel} setSelectedModel={setSelectedModel} />
|
||||||
{isString(modifiedPromptFn) && (
|
{isString(modifiedPromptFn) && (
|
||||||
<CompareFunctions
|
<CompareFunctions
|
||||||
originalFunction={variant.constructFn}
|
originalFunction={variant.constructFn}
|
||||||
newFunction={modifiedPromptFn}
|
newFunction={modifiedPromptFn}
|
||||||
leftTitle={originalModelLabel}
|
leftTitle={originalLabel}
|
||||||
rightTitle={convertedModelLabel}
|
rightTitle={convertedLabel}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
</VStack>
|
</VStack>
|
||||||
@@ -115,7 +120,7 @@ export const ChangeModelModal = ({
|
|||||||
colorScheme="gray"
|
colorScheme="gray"
|
||||||
onClick={getModifiedPromptFn}
|
onClick={getModifiedPromptFn}
|
||||||
minW={24}
|
minW={24}
|
||||||
isDisabled={originalModel === selectedModel || modificationInProgress}
|
isDisabled={originalLabel === selectedLabel || modificationInProgress}
|
||||||
>
|
>
|
||||||
{modificationInProgress ? <Spinner boxSize={4} /> : <Text>Convert</Text>}
|
{modificationInProgress ? <Spinner boxSize={4} /> : <Text>Convert</Text>}
|
||||||
</Button>
|
</Button>
|
||||||
|
|||||||
@@ -1,49 +1,35 @@
|
|||||||
import { VStack, Text } from "@chakra-ui/react";
|
import { Text, VStack } from "@chakra-ui/react";
|
||||||
import { type LegacyRef, useCallback } from "react";
|
import { type LegacyRef } from "react";
|
||||||
import Select, { type SingleValue } from "react-select";
|
import Select from "react-select";
|
||||||
import { useElementDimensions } from "~/utils/hooks";
|
import { useElementDimensions } from "~/utils/hooks";
|
||||||
|
|
||||||
|
import { flatMap } from "lodash-es";
|
||||||
import frontendModelProviders from "~/modelProviders/frontendModelProviders";
|
import frontendModelProviders from "~/modelProviders/frontendModelProviders";
|
||||||
import { type Model } from "~/modelProviders/types";
|
import { type ProviderModel } from "~/modelProviders/types";
|
||||||
import { keyForModel } from "~/utils/utils";
|
import { modelLabel } from "~/utils/utils";
|
||||||
|
|
||||||
const modelOptions: { label: string; value: Model }[] = [];
|
const modelOptions = flatMap(Object.entries(frontendModelProviders), ([providerId, provider]) =>
|
||||||
|
Object.entries(provider.models).map(([modelId]) => ({
|
||||||
|
provider: providerId,
|
||||||
|
model: modelId,
|
||||||
|
})),
|
||||||
|
) as ProviderModel[];
|
||||||
|
|
||||||
for (const [_, providerValue] of Object.entries(frontendModelProviders)) {
|
export const ModelSearch = (props: {
|
||||||
for (const [_, modelValue] of Object.entries(providerValue.models)) {
|
selectedModel: ProviderModel;
|
||||||
modelOptions.push({
|
setSelectedModel: (model: ProviderModel) => void;
|
||||||
label: keyForModel(modelValue),
|
|
||||||
value: modelValue,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
export const ModelSearch = ({
|
|
||||||
selectedModel,
|
|
||||||
setSelectedModel,
|
|
||||||
}: {
|
|
||||||
selectedModel: Model;
|
|
||||||
setSelectedModel: (model: Model) => void;
|
|
||||||
}) => {
|
}) => {
|
||||||
const handleSelection = useCallback(
|
|
||||||
(option: SingleValue<{ label: string; value: Model }>) => {
|
|
||||||
if (!option) return;
|
|
||||||
setSelectedModel(option.value);
|
|
||||||
},
|
|
||||||
[setSelectedModel],
|
|
||||||
);
|
|
||||||
const selectedOption = modelOptions.find((option) => option.label === keyForModel(selectedModel));
|
|
||||||
|
|
||||||
const [containerRef, containerDimensions] = useElementDimensions();
|
const [containerRef, containerDimensions] = useElementDimensions();
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<VStack ref={containerRef as LegacyRef<HTMLDivElement>} w="full">
|
<VStack ref={containerRef as LegacyRef<HTMLDivElement>} w="full">
|
||||||
<Text>Browse Models</Text>
|
<Text>Browse Models</Text>
|
||||||
<Select
|
<Select<ProviderModel>
|
||||||
styles={{ control: (provided) => ({ ...provided, width: containerDimensions?.width }) }}
|
styles={{ control: (provided) => ({ ...provided, width: containerDimensions?.width }) }}
|
||||||
value={selectedOption}
|
getOptionLabel={(data) => modelLabel(data.provider, data.model)}
|
||||||
|
getOptionValue={(data) => modelLabel(data.provider, data.model)}
|
||||||
options={modelOptions}
|
options={modelOptions}
|
||||||
onChange={handleSelection}
|
onChange={(option) => option && props.setSelectedModel(option)}
|
||||||
/>
|
/>
|
||||||
</VStack>
|
</VStack>
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -1,15 +1,22 @@
|
|||||||
import {
|
import {
|
||||||
VStack,
|
|
||||||
Text,
|
|
||||||
HStack,
|
|
||||||
type StackProps,
|
|
||||||
GridItem,
|
GridItem,
|
||||||
SimpleGrid,
|
HStack,
|
||||||
Link,
|
Link,
|
||||||
|
SimpleGrid,
|
||||||
|
Text,
|
||||||
|
VStack,
|
||||||
|
type StackProps,
|
||||||
} from "@chakra-ui/react";
|
} from "@chakra-ui/react";
|
||||||
import { type Model } from "~/modelProviders/types";
|
import { type lookupModel } from "~/utils/utils";
|
||||||
|
|
||||||
export const ModelStatsCard = ({ label, model }: { label: string; model: Model }) => {
|
export const ModelStatsCard = ({
|
||||||
|
label,
|
||||||
|
model,
|
||||||
|
}: {
|
||||||
|
label: string;
|
||||||
|
model: ReturnType<typeof lookupModel>;
|
||||||
|
}) => {
|
||||||
|
if (!model) return null;
|
||||||
return (
|
return (
|
||||||
<VStack w="full" align="start">
|
<VStack w="full" align="start">
|
||||||
<Text fontWeight="bold" fontSize="sm" textTransform="uppercase">
|
<Text fontWeight="bold" fontSize="sm" textTransform="uppercase">
|
||||||
|
|||||||
@@ -3,26 +3,26 @@ import { type IconType } from "react-icons";
|
|||||||
import { type JsonValue } from "type-fest";
|
import { type JsonValue } from "type-fest";
|
||||||
import { z } from "zod";
|
import { z } from "zod";
|
||||||
|
|
||||||
const ZodSupportedProvider = z.union([
|
export const ZodSupportedProvider = z.union([
|
||||||
z.literal("openai/ChatCompletion"),
|
z.literal("openai/ChatCompletion"),
|
||||||
z.literal("replicate/llama2"),
|
z.literal("replicate/llama2"),
|
||||||
]);
|
]);
|
||||||
|
|
||||||
export type SupportedProvider = z.infer<typeof ZodSupportedProvider>;
|
export type SupportedProvider = z.infer<typeof ZodSupportedProvider>;
|
||||||
|
|
||||||
export const ZodModel = z.object({
|
export type Model = {
|
||||||
name: z.string(),
|
name: string;
|
||||||
contextWindow: z.number(),
|
contextWindow: number;
|
||||||
promptTokenPrice: z.number().optional(),
|
promptTokenPrice?: number;
|
||||||
completionTokenPrice: z.number().optional(),
|
completionTokenPrice?: number;
|
||||||
pricePerSecond: z.number().optional(),
|
pricePerSecond?: number;
|
||||||
speed: z.union([z.literal("fast"), z.literal("medium"), z.literal("slow")]),
|
speed: "fast" | "medium" | "slow";
|
||||||
provider: ZodSupportedProvider,
|
provider: SupportedProvider;
|
||||||
description: z.string().optional(),
|
description?: string;
|
||||||
learnMoreUrl: z.string().optional(),
|
learnMoreUrl?: string;
|
||||||
});
|
};
|
||||||
|
|
||||||
export type Model = z.infer<typeof ZodModel>;
|
export type ProviderModel = { provider: z.infer<typeof ZodSupportedProvider>; model: string };
|
||||||
|
|
||||||
export type RefinementAction = { icon?: IconType; description: string; instructions: string };
|
export type RefinementAction = { icon?: IconType; description: string; instructions: string };
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,8 @@ import { type PromptVariant } from "@prisma/client";
|
|||||||
import { deriveNewConstructFn } from "~/server/utils/deriveNewContructFn";
|
import { deriveNewConstructFn } from "~/server/utils/deriveNewContructFn";
|
||||||
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
|
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
|
||||||
import parseConstructFn from "~/server/utils/parseConstructFn";
|
import parseConstructFn from "~/server/utils/parseConstructFn";
|
||||||
import { ZodModel } from "~/modelProviders/types";
|
import modelProviders from "~/modelProviders/modelProviders";
|
||||||
|
import { ZodSupportedProvider } from "~/modelProviders/types";
|
||||||
|
|
||||||
export const promptVariantsRouter = createTRPCRouter({
|
export const promptVariantsRouter = createTRPCRouter({
|
||||||
list: publicProcedure
|
list: publicProcedure
|
||||||
@@ -144,7 +145,6 @@ export const promptVariantsRouter = createTRPCRouter({
|
|||||||
z.object({
|
z.object({
|
||||||
experimentId: z.string(),
|
experimentId: z.string(),
|
||||||
variantId: z.string().optional(),
|
variantId: z.string().optional(),
|
||||||
newModel: ZodModel.optional(),
|
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
.mutation(async ({ input, ctx }) => {
|
.mutation(async ({ input, ctx }) => {
|
||||||
@@ -186,7 +186,7 @@ export const promptVariantsRouter = createTRPCRouter({
|
|||||||
? `${originalVariant?.label} Copy`
|
? `${originalVariant?.label} Copy`
|
||||||
: `Prompt Variant ${largestSortIndex + 2}`;
|
: `Prompt Variant ${largestSortIndex + 2}`;
|
||||||
|
|
||||||
const newConstructFn = await deriveNewConstructFn(originalVariant, input.newModel);
|
const newConstructFn = await deriveNewConstructFn(originalVariant);
|
||||||
|
|
||||||
const createNewVariantAction = prisma.promptVariant.create({
|
const createNewVariantAction = prisma.promptVariant.create({
|
||||||
data: {
|
data: {
|
||||||
@@ -286,7 +286,12 @@ export const promptVariantsRouter = createTRPCRouter({
|
|||||||
z.object({
|
z.object({
|
||||||
id: z.string(),
|
id: z.string(),
|
||||||
instructions: z.string().optional(),
|
instructions: z.string().optional(),
|
||||||
newModel: ZodModel.optional(),
|
newModel: z
|
||||||
|
.object({
|
||||||
|
provider: ZodSupportedProvider,
|
||||||
|
model: z.string(),
|
||||||
|
})
|
||||||
|
.optional(),
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
.mutation(async ({ input, ctx }) => {
|
.mutation(async ({ input, ctx }) => {
|
||||||
@@ -303,11 +308,11 @@ export const promptVariantsRouter = createTRPCRouter({
|
|||||||
return userError(constructedPrompt.error);
|
return userError(constructedPrompt.error);
|
||||||
}
|
}
|
||||||
|
|
||||||
const promptConstructionFn = await deriveNewConstructFn(
|
const model = input.newModel
|
||||||
existing,
|
? modelProviders[input.newModel.provider].models[input.newModel.model]
|
||||||
input.newModel,
|
: undefined;
|
||||||
input.instructions,
|
|
||||||
);
|
const promptConstructionFn = await deriveNewConstructFn(existing, model, input.instructions);
|
||||||
|
|
||||||
// TODO: Validate promptConstructionFn
|
// TODO: Validate promptConstructionFn
|
||||||
// TODO: Record in some sort of history
|
// TODO: Record in some sort of history
|
||||||
|
|||||||
@@ -1,5 +1,12 @@
|
|||||||
import { type Model } from "~/modelProviders/types";
|
import frontendModelProviders from "~/modelProviders/frontendModelProviders";
|
||||||
|
import { type ProviderModel } from "~/modelProviders/types";
|
||||||
|
|
||||||
export const truthyFilter = <T>(x: T | null | undefined): x is T => Boolean(x);
|
export const truthyFilter = <T>(x: T | null | undefined): x is T => Boolean(x);
|
||||||
|
|
||||||
export const keyForModel = (model: Model) => `${model.provider}/${model.name}`;
|
export const lookupModel = (provider: string, model: string) => {
|
||||||
|
const modelObj = frontendModelProviders[provider as ProviderModel["provider"]]?.models[model];
|
||||||
|
return modelObj ? { ...modelObj, provider } : null;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const modelLabel = (provider: string, model: string) =>
|
||||||
|
`${provider}/${lookupModel(provider, model)?.name ?? model}`;
|
||||||
|
|||||||
Reference in New Issue
Block a user