Only pass in model and provider

I got somewhat confused by the extra fields, sorry.

Also makes some frontend changes to track that state more directly although in retrospect not sure the frontend changes make things any better.
This commit is contained in:
Kyle Corbitt
2023-07-24 17:20:38 -07:00
parent e0b457c6c5
commit 9952dd93d8
6 changed files with 101 additions and 91 deletions

View File

@@ -1,5 +1,7 @@
import { import {
Button, Button,
HStack,
Icon,
Modal, Modal,
ModalBody, ModalBody,
ModalCloseButton, ModalCloseButton,
@@ -7,24 +9,21 @@ import {
ModalFooter, ModalFooter,
ModalHeader, ModalHeader,
ModalOverlay, ModalOverlay,
VStack,
Text,
Spinner, Spinner,
HStack, Text,
Icon, VStack,
} from "@chakra-ui/react"; } from "@chakra-ui/react";
import { RiExchangeFundsFill } from "react-icons/ri";
import { useState } from "react";
import { ModelStatsCard } from "./ModelStatsCard";
import { ModelSearch } from "./ModelSearch";
import { api } from "~/utils/api";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
import CompareFunctions from "../RefinePromptModal/CompareFunctions";
import { type PromptVariant } from "@prisma/client"; import { type PromptVariant } from "@prisma/client";
import { isObject, isString } from "lodash-es"; import { isObject, isString } from "lodash-es";
import { type Model, type SupportedProvider } from "~/modelProviders/types"; import { useState } from "react";
import frontendModelProviders from "~/modelProviders/frontendModelProviders"; import { RiExchangeFundsFill } from "react-icons/ri";
import { keyForModel } from "~/utils/utils"; import { type ProviderModel } from "~/modelProviders/types";
import { api } from "~/utils/api";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
import { lookupModel, modelLabel } from "~/utils/utils";
import CompareFunctions from "../RefinePromptModal/CompareFunctions";
import { ModelSearch } from "./ModelSearch";
import { ModelStatsCard } from "./ModelStatsCard";
export const ChangeModelModal = ({ export const ChangeModelModal = ({
variant, variant,
@@ -33,11 +32,13 @@ export const ChangeModelModal = ({
variant: PromptVariant; variant: PromptVariant;
onClose: () => void; onClose: () => void;
}) => { }) => {
const originalModelProviderName = variant.modelProvider as SupportedProvider; const originalModel = lookupModel(variant.modelProvider, variant.model);
const originalModelProvider = frontendModelProviders[originalModelProviderName]; const [selectedModel, setSelectedModel] = useState({
const originalModel = originalModelProvider.models[variant.model] as Model; provider: variant.modelProvider,
const [selectedModel, setSelectedModel] = useState<Model>(originalModel); model: variant.model,
const [convertedModel, setConvertedModel] = useState<Model | undefined>(undefined); } as ProviderModel);
const [convertedModel, setConvertedModel] = useState<ProviderModel | undefined>();
const utils = api.useContext(); const utils = api.useContext();
const experiment = useExperiment(); const experiment = useExperiment();
@@ -72,9 +73,10 @@ export const ChangeModelModal = ({
onClose(); onClose();
}, [replaceVariantMutation, variant, onClose, modifiedPromptFn]); }, [replaceVariantMutation, variant, onClose, modifiedPromptFn]);
const originalModelLabel = keyForModel(originalModel); const originalLabel = modelLabel(variant.modelProvider, variant.model);
const selectedModelLabel = keyForModel(selectedModel); const selectedLabel = modelLabel(selectedModel.provider, selectedModel.model);
const convertedModelLabel = convertedModel ? keyForModel(convertedModel) : undefined; const convertedLabel =
convertedModel && modelLabel(convertedModel.provider, convertedModel.model);
return ( return (
<Modal <Modal
@@ -94,16 +96,19 @@ export const ChangeModelModal = ({
<ModalBody maxW="unset"> <ModalBody maxW="unset">
<VStack spacing={8}> <VStack spacing={8}>
<ModelStatsCard label="Original Model" model={originalModel} /> <ModelStatsCard label="Original Model" model={originalModel} />
{originalModelLabel !== selectedModelLabel && ( {originalLabel !== selectedLabel && (
<ModelStatsCard label="New Model" model={selectedModel} /> <ModelStatsCard
label="New Model"
model={lookupModel(selectedModel.provider, selectedModel.model)}
/>
)} )}
<ModelSearch selectedModel={selectedModel} setSelectedModel={setSelectedModel} /> <ModelSearch selectedModel={selectedModel} setSelectedModel={setSelectedModel} />
{isString(modifiedPromptFn) && ( {isString(modifiedPromptFn) && (
<CompareFunctions <CompareFunctions
originalFunction={variant.constructFn} originalFunction={variant.constructFn}
newFunction={modifiedPromptFn} newFunction={modifiedPromptFn}
leftTitle={originalModelLabel} leftTitle={originalLabel}
rightTitle={convertedModelLabel} rightTitle={convertedLabel}
/> />
)} )}
</VStack> </VStack>
@@ -115,7 +120,7 @@ export const ChangeModelModal = ({
colorScheme="gray" colorScheme="gray"
onClick={getModifiedPromptFn} onClick={getModifiedPromptFn}
minW={24} minW={24}
isDisabled={originalModel === selectedModel || modificationInProgress} isDisabled={originalLabel === selectedLabel || modificationInProgress}
> >
{modificationInProgress ? <Spinner boxSize={4} /> : <Text>Convert</Text>} {modificationInProgress ? <Spinner boxSize={4} /> : <Text>Convert</Text>}
</Button> </Button>

View File

@@ -1,49 +1,35 @@
import { VStack, Text } from "@chakra-ui/react"; import { Text, VStack } from "@chakra-ui/react";
import { type LegacyRef, useCallback } from "react"; import { type LegacyRef } from "react";
import Select, { type SingleValue } from "react-select"; import Select from "react-select";
import { useElementDimensions } from "~/utils/hooks"; import { useElementDimensions } from "~/utils/hooks";
import { flatMap } from "lodash-es";
import frontendModelProviders from "~/modelProviders/frontendModelProviders"; import frontendModelProviders from "~/modelProviders/frontendModelProviders";
import { type Model } from "~/modelProviders/types"; import { type ProviderModel } from "~/modelProviders/types";
import { keyForModel } from "~/utils/utils"; import { modelLabel } from "~/utils/utils";
const modelOptions: { label: string; value: Model }[] = []; const modelOptions = flatMap(Object.entries(frontendModelProviders), ([providerId, provider]) =>
Object.entries(provider.models).map(([modelId]) => ({
provider: providerId,
model: modelId,
})),
) as ProviderModel[];
for (const [_, providerValue] of Object.entries(frontendModelProviders)) { export const ModelSearch = (props: {
for (const [_, modelValue] of Object.entries(providerValue.models)) { selectedModel: ProviderModel;
modelOptions.push({ setSelectedModel: (model: ProviderModel) => void;
label: keyForModel(modelValue),
value: modelValue,
});
}
}
export const ModelSearch = ({
selectedModel,
setSelectedModel,
}: {
selectedModel: Model;
setSelectedModel: (model: Model) => void;
}) => { }) => {
const handleSelection = useCallback(
(option: SingleValue<{ label: string; value: Model }>) => {
if (!option) return;
setSelectedModel(option.value);
},
[setSelectedModel],
);
const selectedOption = modelOptions.find((option) => option.label === keyForModel(selectedModel));
const [containerRef, containerDimensions] = useElementDimensions(); const [containerRef, containerDimensions] = useElementDimensions();
return ( return (
<VStack ref={containerRef as LegacyRef<HTMLDivElement>} w="full"> <VStack ref={containerRef as LegacyRef<HTMLDivElement>} w="full">
<Text>Browse Models</Text> <Text>Browse Models</Text>
<Select <Select<ProviderModel>
styles={{ control: (provided) => ({ ...provided, width: containerDimensions?.width }) }} styles={{ control: (provided) => ({ ...provided, width: containerDimensions?.width }) }}
value={selectedOption} getOptionLabel={(data) => modelLabel(data.provider, data.model)}
getOptionValue={(data) => modelLabel(data.provider, data.model)}
options={modelOptions} options={modelOptions}
onChange={handleSelection} onChange={(option) => option && props.setSelectedModel(option)}
/> />
</VStack> </VStack>
); );

View File

@@ -1,15 +1,22 @@
import { import {
VStack,
Text,
HStack,
type StackProps,
GridItem, GridItem,
SimpleGrid, HStack,
Link, Link,
SimpleGrid,
Text,
VStack,
type StackProps,
} from "@chakra-ui/react"; } from "@chakra-ui/react";
import { type Model } from "~/modelProviders/types"; import { type lookupModel } from "~/utils/utils";
export const ModelStatsCard = ({ label, model }: { label: string; model: Model }) => { export const ModelStatsCard = ({
label,
model,
}: {
label: string;
model: ReturnType<typeof lookupModel>;
}) => {
if (!model) return null;
return ( return (
<VStack w="full" align="start"> <VStack w="full" align="start">
<Text fontWeight="bold" fontSize="sm" textTransform="uppercase"> <Text fontWeight="bold" fontSize="sm" textTransform="uppercase">

View File

@@ -3,26 +3,26 @@ import { type IconType } from "react-icons";
import { type JsonValue } from "type-fest"; import { type JsonValue } from "type-fest";
import { z } from "zod"; import { z } from "zod";
const ZodSupportedProvider = z.union([ export const ZodSupportedProvider = z.union([
z.literal("openai/ChatCompletion"), z.literal("openai/ChatCompletion"),
z.literal("replicate/llama2"), z.literal("replicate/llama2"),
]); ]);
export type SupportedProvider = z.infer<typeof ZodSupportedProvider>; export type SupportedProvider = z.infer<typeof ZodSupportedProvider>;
export const ZodModel = z.object({ export type Model = {
name: z.string(), name: string;
contextWindow: z.number(), contextWindow: number;
promptTokenPrice: z.number().optional(), promptTokenPrice?: number;
completionTokenPrice: z.number().optional(), completionTokenPrice?: number;
pricePerSecond: z.number().optional(), pricePerSecond?: number;
speed: z.union([z.literal("fast"), z.literal("medium"), z.literal("slow")]), speed: "fast" | "medium" | "slow";
provider: ZodSupportedProvider, provider: SupportedProvider;
description: z.string().optional(), description?: string;
learnMoreUrl: z.string().optional(), learnMoreUrl?: string;
}); };
export type Model = z.infer<typeof ZodModel>; export type ProviderModel = { provider: z.infer<typeof ZodSupportedProvider>; model: string };
export type RefinementAction = { icon?: IconType; description: string; instructions: string }; export type RefinementAction = { icon?: IconType; description: string; instructions: string };

View File

@@ -9,7 +9,8 @@ import { type PromptVariant } from "@prisma/client";
import { deriveNewConstructFn } from "~/server/utils/deriveNewContructFn"; import { deriveNewConstructFn } from "~/server/utils/deriveNewContructFn";
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl"; import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
import parseConstructFn from "~/server/utils/parseConstructFn"; import parseConstructFn from "~/server/utils/parseConstructFn";
import { ZodModel } from "~/modelProviders/types"; import modelProviders from "~/modelProviders/modelProviders";
import { ZodSupportedProvider } from "~/modelProviders/types";
export const promptVariantsRouter = createTRPCRouter({ export const promptVariantsRouter = createTRPCRouter({
list: publicProcedure list: publicProcedure
@@ -144,7 +145,6 @@ export const promptVariantsRouter = createTRPCRouter({
z.object({ z.object({
experimentId: z.string(), experimentId: z.string(),
variantId: z.string().optional(), variantId: z.string().optional(),
newModel: ZodModel.optional(),
}), }),
) )
.mutation(async ({ input, ctx }) => { .mutation(async ({ input, ctx }) => {
@@ -186,7 +186,7 @@ export const promptVariantsRouter = createTRPCRouter({
? `${originalVariant?.label} Copy` ? `${originalVariant?.label} Copy`
: `Prompt Variant ${largestSortIndex + 2}`; : `Prompt Variant ${largestSortIndex + 2}`;
const newConstructFn = await deriveNewConstructFn(originalVariant, input.newModel); const newConstructFn = await deriveNewConstructFn(originalVariant);
const createNewVariantAction = prisma.promptVariant.create({ const createNewVariantAction = prisma.promptVariant.create({
data: { data: {
@@ -286,7 +286,12 @@ export const promptVariantsRouter = createTRPCRouter({
z.object({ z.object({
id: z.string(), id: z.string(),
instructions: z.string().optional(), instructions: z.string().optional(),
newModel: ZodModel.optional(), newModel: z
.object({
provider: ZodSupportedProvider,
model: z.string(),
})
.optional(),
}), }),
) )
.mutation(async ({ input, ctx }) => { .mutation(async ({ input, ctx }) => {
@@ -303,11 +308,11 @@ export const promptVariantsRouter = createTRPCRouter({
return userError(constructedPrompt.error); return userError(constructedPrompt.error);
} }
const promptConstructionFn = await deriveNewConstructFn( const model = input.newModel
existing, ? modelProviders[input.newModel.provider].models[input.newModel.model]
input.newModel, : undefined;
input.instructions,
); const promptConstructionFn = await deriveNewConstructFn(existing, model, input.instructions);
// TODO: Validate promptConstructionFn // TODO: Validate promptConstructionFn
// TODO: Record in some sort of history // TODO: Record in some sort of history

View File

@@ -1,5 +1,12 @@
import { type Model } from "~/modelProviders/types"; import frontendModelProviders from "~/modelProviders/frontendModelProviders";
import { type ProviderModel } from "~/modelProviders/types";
export const truthyFilter = <T>(x: T | null | undefined): x is T => Boolean(x); export const truthyFilter = <T>(x: T | null | undefined): x is T => Boolean(x);
export const keyForModel = (model: Model) => `${model.provider}/${model.name}`; export const lookupModel = (provider: string, model: string) => {
const modelObj = frontendModelProviders[provider as ProviderModel["provider"]]?.models[model];
return modelObj ? { ...modelObj, provider } : null;
};
export const modelLabel = (provider: string, model: string) =>
`${provider}/${lookupModel(provider, model)?.name ?? model}`;