Compare commits

...

8 Commits

Author SHA1 Message Date
Kyle Corbitt
cc1d1178da Fullscreen editor 2023-07-21 22:19:38 -07:00
David Corbitt
7466db63df Make REPLICATE_API_TOKEN optional 2023-07-21 20:23:38 -07:00
David Corbitt
79a0b03bf8 Add another function call example 2023-07-21 20:16:36 -07:00
arcticfly
6fb7a82d72 Add support for switching to Llama models (#80)
* Add support for switching to Llama models

* Fix prettier
2023-07-21 20:10:59 -07:00
Kyle Corbitt
4ea30a3ba3 Merge pull request #79 from OpenPipe/copy-evals
Copy over evals when new cell created
2023-07-21 18:43:44 -07:00
Kyle Corbitt
52d1d5c7ee Copy over evals when new cell created
Fixes a bug where new cells generated as clones of existing cells didn't get the eval results cloned as well.
2023-07-21 18:40:40 -07:00
Kyle Corbitt
46036a44d2 small README update 2023-07-21 14:32:07 -07:00
Kyle Corbitt
3753fe5c16 Merge pull request #78 from OpenPipe/bugfix-max-tokens
Fix typescript hints for max_tokens
2023-07-21 12:10:00 -07:00
23 changed files with 444 additions and 312 deletions

View File

@@ -43,7 +43,8 @@ Natively supports [OpenAI function calls](https://openai.com/blog/function-calli
## Supported Models
OpenPipe currently supports GPT-3.5 and GPT-4. Wider model support is planned.
- All models available through the OpenAI [chat completion API](https://platform.openai.com/docs/guides/gpt/chat-completions-api)
- Llama2 [7b chat](https://replicate.com/a16z-infra/llama7b-v2-chat), [13b chat](https://replicate.com/a16z-infra/llama13b-v2-chat), [70b chat](https://replicate.com/replicate/llama70b-v2-chat).
## Running Locally

View File

@@ -15,25 +15,29 @@ import {
} from "@chakra-ui/react";
import { RiExchangeFundsFill } from "react-icons/ri";
import { useState } from "react";
import { type SupportedModel } from "~/server/types";
import { ModelStatsCard } from "./ModelStatsCard";
import { SelectModelSearch } from "./SelectModelSearch";
import { ModelSearch } from "./ModelSearch";
import { api } from "~/utils/api";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
import CompareFunctions from "../RefinePromptModal/CompareFunctions";
import { type PromptVariant } from "@prisma/client";
import { isObject, isString } from "lodash-es";
import { type Model, type SupportedProvider } from "~/modelProviders/types";
import frontendModelProviders from "~/modelProviders/frontendModelProviders";
import { keyForModel } from "~/utils/utils";
export const SelectModelModal = ({
export const ChangeModelModal = ({
variant,
onClose,
}: {
variant: PromptVariant;
onClose: () => void;
}) => {
const originalModel = variant.model as SupportedModel;
const [selectedModel, setSelectedModel] = useState<SupportedModel>(originalModel);
const [convertedModel, setConvertedModel] = useState<SupportedModel | undefined>(undefined);
const originalModelProviderName = variant.modelProvider as SupportedProvider;
const originalModelProvider = frontendModelProviders[originalModelProviderName];
const originalModel = originalModelProvider.models[variant.model] as Model;
const [selectedModel, setSelectedModel] = useState<Model>(originalModel);
const [convertedModel, setConvertedModel] = useState<Model | undefined>(undefined);
const utils = api.useContext();
const experiment = useExperiment();
@@ -68,6 +72,10 @@ export const SelectModelModal = ({
onClose();
}, [replaceVariantMutation, variant, onClose, modifiedPromptFn]);
const originalModelLabel = keyForModel(originalModel);
const selectedModelLabel = keyForModel(selectedModel);
const convertedModelLabel = convertedModel ? keyForModel(convertedModel) : undefined;
return (
<Modal
isOpen
@@ -86,16 +94,16 @@ export const SelectModelModal = ({
<ModalBody maxW="unset">
<VStack spacing={8}>
<ModelStatsCard label="Original Model" model={originalModel} />
{originalModel !== selectedModel && (
{originalModelLabel !== selectedModelLabel && (
<ModelStatsCard label="New Model" model={selectedModel} />
)}
<SelectModelSearch selectedModel={selectedModel} setSelectedModel={setSelectedModel} />
<ModelSearch selectedModel={selectedModel} setSelectedModel={setSelectedModel} />
{isString(modifiedPromptFn) && (
<CompareFunctions
originalFunction={variant.constructFn}
newFunction={modifiedPromptFn}
leftTitle={originalModel}
rightTitle={convertedModel}
leftTitle={originalModelLabel}
rightTitle={convertedModelLabel}
/>
)}
</VStack>

View File

@@ -0,0 +1,50 @@
import { VStack, Text } from "@chakra-ui/react";
import { type LegacyRef, useCallback } from "react";
import Select, { type SingleValue } from "react-select";
import { useElementDimensions } from "~/utils/hooks";
import frontendModelProviders from "~/modelProviders/frontendModelProviders";
import { type Model } from "~/modelProviders/types";
import { keyForModel } from "~/utils/utils";
// Flat list of every model from every frontend provider, in iteration order,
// shaped as { label, value } options for the <Select> below. The label is the
// string produced by keyForModel, which is also what selection lookup compares on.
// Built once when the module is loaded.
const modelOptions: { label: string; value: Model }[] = Object.values(
  frontendModelProviders,
).flatMap((provider) =>
  Object.values(provider.models).map((model) => ({
    label: keyForModel(model),
    value: model,
  })),
);
/**
 * Dropdown for picking one of the models listed in `modelOptions`.
 *
 * @param selectedModel - The model currently chosen; it is matched to an option
 *   by comparing `keyForModel(selectedModel)` with each option's label.
 * @param setSelectedModel - Called with the model of the option the user picks.
 */
export const ModelSearch = ({
  selectedModel,
  setSelectedModel,
}: {
  selectedModel: Model;
  setSelectedModel: (model: Model) => void;
}) => {
  // react-select reports a cleared selection as null; in that case keep the
  // current model instead of propagating "nothing".
  const handleSelection = useCallback(
    (picked: SingleValue<(typeof modelOptions)[number]>) => {
      if (picked) {
        setSelectedModel(picked.value);
      }
    },
    [setSelectedModel],
  );

  // Options are matched by label rather than by object identity.
  const currentOption = modelOptions.find(({ label }) => label === keyForModel(selectedModel));

  // The measured width of the wrapper is fed to the select control below.
  const [wrapperRef, wrapperDimensions] = useElementDimensions();

  return (
    <VStack ref={wrapperRef as LegacyRef<HTMLDivElement>} w="full">
      <Text>Browse Models</Text>
      <Select
        styles={{
          control: (base) => ({ ...base, width: wrapperDimensions?.width }),
        }}
        value={currentOption}
        options={modelOptions}
        onChange={handleSelection}
      />
    </VStack>
  );
};

View File

@@ -7,11 +7,9 @@ import {
SimpleGrid,
Link,
} from "@chakra-ui/react";
import { modelStats } from "~/modelProviders/modelStats";
import { type SupportedModel } from "~/server/types";
import { type Model } from "~/modelProviders/types";
export const ModelStatsCard = ({ label, model }: { label: string; model: SupportedModel }) => {
const stats = modelStats[model];
export const ModelStatsCard = ({ label, model }: { label: string; model: Model }) => {
return (
<VStack w="full" align="start">
<Text fontWeight="bold" fontSize="sm" textTransform="uppercase">
@@ -22,14 +20,14 @@ export const ModelStatsCard = ({ label, model }: { label: string; model: Support
<HStack w="full" align="flex-start">
<Text flex={1} fontSize="lg">
<Text as="span" color="gray.600">
{stats.provider} /{" "}
{model.provider} /{" "}
</Text>
<Text as="span" fontWeight="bold" color="gray.900">
{model}
{model.name}
</Text>
</Text>
<Link
href={stats.learnMoreUrl}
href={model.learnMoreUrl}
isExternal
color="blue.500"
fontWeight="bold"
@@ -46,26 +44,41 @@ export const ModelStatsCard = ({ label, model }: { label: string; model: Support
fontSize="sm"
columns={{ base: 2, md: 4 }}
>
<SelectedModelLabeledInfo label="Context" info={stats.contextLength} />
<SelectedModelLabeledInfo
label="Input"
info={
<Text>
${(stats.promptTokenPrice * 1000).toFixed(3)}
<Text color="gray.500"> / 1K tokens</Text>
</Text>
}
/>
<SelectedModelLabeledInfo
label="Output"
info={
<Text>
${(stats.promptTokenPrice * 1000).toFixed(3)}
<Text color="gray.500"> / 1K tokens</Text>
</Text>
}
/>
<SelectedModelLabeledInfo label="Speed" info={<Text>{stats.speed}</Text>} />
<SelectedModelLabeledInfo label="Context Window" info={model.contextWindow} />
{model.promptTokenPrice && (
<SelectedModelLabeledInfo
label="Input"
info={
<Text>
${(model.promptTokenPrice * 1000).toFixed(3)}
<Text color="gray.500"> / 1K tokens</Text>
</Text>
}
/>
)}
{model.completionTokenPrice && (
<SelectedModelLabeledInfo
label="Output"
info={
<Text>
${(model.completionTokenPrice * 1000).toFixed(3)}
<Text color="gray.500"> / 1K tokens</Text>
</Text>
}
/>
)}
{model.pricePerSecond && (
<SelectedModelLabeledInfo
label="Price"
info={
<Text>
${model.pricePerSecond.toFixed(3)}
<Text color="gray.500"> / second</Text>
</Text>
}
/>
)}
<SelectedModelLabeledInfo label="Speed" info={<Text>{model.speed}</Text>} />
</SimpleGrid>
</VStack>
</VStack>

View File

@@ -1,17 +1,47 @@
import { Box, Button, HStack, Spinner, Tooltip, useToast, Text } from "@chakra-ui/react";
import {
Box,
Button,
HStack,
Spinner,
Tooltip,
useToast,
Text,
IconButton,
} from "@chakra-ui/react";
import { useRef, useEffect, useState, useCallback } from "react";
import { useExperimentAccess, useHandledAsyncCallback, useModifierKeyLabel } from "~/utils/hooks";
import { type PromptVariant } from "./types";
import { api } from "~/utils/api";
import { useAppStore } from "~/state/store";
import { FiMaximize, FiMinimize } from "react-icons/fi";
import { editorBackground } from "~/state/sharedVariantEditor.slice";
export default function VariantEditor(props: { variant: PromptVariant }) {
const { canModify } = useExperimentAccess();
const monaco = useAppStore.use.sharedVariantEditor.monaco();
const editorRef = useRef<ReturnType<NonNullable<typeof monaco>["editor"]["create"]> | null>(null);
const containerRef = useRef<HTMLDivElement | null>(null);
const [editorId] = useState(() => `editor_${Math.random().toString(36).substring(7)}`);
const [isChanged, setIsChanged] = useState(false);
const [isFullscreen, setIsFullscreen] = useState(false);
const toggleFullscreen = useCallback(() => {
setIsFullscreen((prev) => !prev);
editorRef.current?.focus();
}, [setIsFullscreen]);
useEffect(() => {
const handleEsc = (event: KeyboardEvent) => {
if (event.key === "Escape" && isFullscreen) {
toggleFullscreen();
}
};
window.addEventListener("keydown", handleEsc);
return () => window.removeEventListener("keydown", handleEsc);
}, [isFullscreen, toggleFullscreen]);
const lastSavedFn = props.variant.constructFn;
const modifierKey = useModifierKeyLabel();
@@ -99,11 +129,23 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
readOnly: !canModify,
});
// Workaround because otherwise the commands only work on whatever
// editor was loaded on the page last.
// https://github.com/microsoft/monaco-editor/issues/2947#issuecomment-1422265201
editorRef.current.onDidFocusEditorText(() => {
// Workaround because otherwise the command only works on whatever
// editor was loaded on the page last.
// https://github.com/microsoft/monaco-editor/issues/2947#issuecomment-1422265201
editorRef.current?.addCommand(monaco.KeyMod.CtrlCmd | monaco.KeyCode.Enter, onSave);
editorRef.current?.addCommand(monaco.KeyMod.CtrlCmd | monaco.KeyCode.KeyS, onSave);
editorRef.current?.addCommand(
monaco.KeyMod.CtrlCmd | monaco.KeyMod.Shift | monaco.KeyCode.KeyF,
toggleFullscreen,
);
// Exit fullscreen with escape
editorRef.current?.addCommand(monaco.KeyCode.Escape, () => {
if (isFullscreen) {
toggleFullscreen();
}
});
});
editorRef.current.onDidChangeModelContent(checkForChanges);
@@ -132,8 +174,40 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
}, [canModify]);
return (
<Box w="100%" pos="relative">
<div id={editorId} style={{ height: "400px", width: "100%" }}></div>
<Box
w="100%"
ref={containerRef}
sx={
isFullscreen
? {
position: "fixed",
top: 0,
left: 0,
right: 0,
bottom: 0,
}
: { h: "400px", w: "100%" }
}
bgColor={editorBackground}
zIndex={isFullscreen ? 1000 : "unset"}
pos="relative"
_hover={{ ".fullscreen-toggle": { opacity: 1 } }}
>
<Box id={editorId} w="100%" h="100%" />
<Tooltip label={`${modifierKey} + ⇧ + F`}>
<IconButton
className="fullscreen-toggle"
aria-label="Minimize"
icon={isFullscreen ? <FiMinimize /> : <FiMaximize />}
position="absolute"
top={2}
right={2}
onClick={toggleFullscreen}
opacity={0}
transition="opacity 0.2s"
/>
</Tooltip>
{isChanged && (
<HStack pos="absolute" bottom={2} right={2}>
<Button
@@ -146,7 +220,7 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
>
Reset
</Button>
<Tooltip label={`${modifierKey} + Enter`}>
<Tooltip label={`${modifierKey} + S`}>
<Button size="sm" onClick={onSave} colorScheme="blue" w={16} disabled={saveInProgress}>
{saveInProgress ? <Spinner boxSize={4} /> : <Text>Save</Text>}
</Button>

View File

@@ -1,22 +1,22 @@
import { HStack, Icon, Heading, Text, VStack, GridItem } from "@chakra-ui/react";
import { type IconType } from "react-icons";
import { refineOptions, type RefineOptionLabel } from "./refineOptions";
export const RefineOption = ({
label,
activeLabel,
icon,
desciption,
activeLabel,
onClick,
loading,
}: {
label: RefineOptionLabel;
activeLabel: RefineOptionLabel | undefined;
label: string;
icon: IconType;
onClick: (label: RefineOptionLabel) => void;
desciption: string;
activeLabel: string | undefined;
onClick: (label: string) => void;
loading: boolean;
}) => {
const isActive = activeLabel === label;
const desciption = refineOptions[label].description;
return (
<GridItem w="80" h="44">

View File

@@ -15,17 +15,16 @@ import {
SimpleGrid,
} from "@chakra-ui/react";
import { BsStars } from "react-icons/bs";
import { VscJson } from "react-icons/vsc";
import { TfiThought } from "react-icons/tfi";
import { api } from "~/utils/api";
import { useHandledAsyncCallback } from "~/utils/hooks";
import { type PromptVariant } from "@prisma/client";
import { useState } from "react";
import CompareFunctions from "./CompareFunctions";
import { CustomInstructionsInput } from "./CustomInstructionsInput";
import { type RefineOptionLabel, refineOptions } from "./refineOptions";
import { type RefineOptionInfo, refineOptions } from "./refineOptions";
import { RefineOption } from "./RefineOption";
import { isObject, isString } from "lodash-es";
import { type SupportedProvider } from "~/modelProviders/types";
export const RefinePromptModal = ({
variant,
@@ -36,18 +35,22 @@ export const RefinePromptModal = ({
}) => {
const utils = api.useContext();
const providerRefineOptions = refineOptions[variant.modelProvider as SupportedProvider];
const { mutateAsync: getModifiedPromptMutateAsync, data: refinedPromptFn } =
api.promptVariants.getModifiedPromptFn.useMutation();
const [instructions, setInstructions] = useState<string>("");
const [activeRefineOptionLabel, setActiveRefineOptionLabel] = useState<
RefineOptionLabel | undefined
>(undefined);
const [activeRefineOptionLabel, setActiveRefineOptionLabel] = useState<string | undefined>(
undefined,
);
const [getModifiedPromptFn, modificationInProgress] = useHandledAsyncCallback(
async (label?: RefineOptionLabel) => {
async (label?: string) => {
if (!variant.experimentId) return;
const updatedInstructions = label ? refineOptions[label].instructions : instructions;
const updatedInstructions = label
? (providerRefineOptions[label] as RefineOptionInfo).instructions
: instructions;
setActiveRefineOptionLabel(label);
await getModifiedPromptMutateAsync({
id: variant.id,
@@ -92,25 +95,26 @@ export const RefinePromptModal = ({
<ModalBody maxW="unset">
<VStack spacing={8}>
<VStack spacing={4}>
<SimpleGrid columns={{ base: 1, md: 2 }} spacing={8}>
<RefineOption
label="Convert to function call"
activeLabel={activeRefineOptionLabel}
icon={VscJson}
onClick={getModifiedPromptFn}
loading={modificationInProgress}
/>
<RefineOption
label="Add chain of thought"
activeLabel={activeRefineOptionLabel}
icon={TfiThought}
onClick={getModifiedPromptFn}
loading={modificationInProgress}
/>
</SimpleGrid>
<HStack>
<Text color="gray.500">or</Text>
</HStack>
{Object.keys(providerRefineOptions).length && (
<>
<SimpleGrid columns={{ base: 1, md: 2 }} spacing={8}>
{Object.keys(providerRefineOptions).map((label) => (
<RefineOption
key={label}
label={label}
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
icon={providerRefineOptions[label]!.icon}
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
desciption={providerRefineOptions[label]!.description}
activeLabel={activeRefineOptionLabel}
onClick={getModifiedPromptFn}
loading={modificationInProgress}
/>
))}
</SimpleGrid>
<Text color="gray.500">or</Text>
</>
)}
<CustomInstructionsInput
instructions={instructions}
setInstructions={setInstructions}

View File

@@ -1,17 +1,21 @@
// Super hacky, but we'll redo the organization when we have more models
export type RefineOptionLabel = "Add chain of thought" | "Convert to function call";
import { type SupportedProvider } from "~/modelProviders/types";
import { VscJson } from "react-icons/vsc";
import { TfiThought } from "react-icons/tfi";
import { type IconType } from "react-icons";
export type RefineOptionInfo = { icon: IconType; description: string; instructions: string };
export const refineOptions: Record<SupportedProvider, { [key: string]: RefineOptionInfo }> = {
"openai/ChatCompletion": {
"Add chain of thought": {
icon: VscJson,
description: "Asking the model to plan its answer can increase accuracy.",
instructions: `Adding chain of thought means asking the model to think about its answer before it gives it to you. This is useful for getting more accurate answers. Do not add an assistant message.
export const refineOptions: Record<
RefineOptionLabel,
{ description: string; instructions: string }
> = {
"Add chain of thought": {
description: "Asking the model to plan its answer can increase accuracy.",
instructions: `Adding chain of thought means asking the model to think about its answer before it gives it to you. This is useful for getting more accurate answers. Do not add an assistant message.
This is what a prompt looks like before adding chain of thought:
definePrompt("openai/ChatCompletion", {
model: "gpt-4",
stream: true,
@@ -55,9 +59,9 @@ export const refineOptions: Record<
role: "user",
content: \`Title: \${scenario.title}
Body: \${scenario.body}
Need: \${scenario.need}
Rate likelihood on 1-3 scale.\`,
},
],
@@ -89,9 +93,9 @@ export const refineOptions: Record<
role: "user",
content: \`Title: \${scenario.title}
Body: \${scenario.body}
Need: \${scenario.need}
Rate likelihood on 1-3 scale. Provide an explanation, but always provide a score afterward.\`,
},
],
@@ -118,13 +122,14 @@ export const refineOptions: Record<
});
Add chain of thought to the original prompt.`,
},
"Convert to function call": {
description: "Use function calls to get output from the model in a more structured way.",
instructions: `OpenAI functions are a specialized way for an LLM to return output.
},
"Convert to function call": {
icon: TfiThought,
description: "Use function calls to get output from the model in a more structured way.",
instructions: `OpenAI functions are a specialized way for an LLM to return output.
This is what a prompt looks like before adding a function:
definePrompt("openai/ChatCompletion", {
model: "gpt-4",
stream: true,
@@ -139,9 +144,9 @@ export const refineOptions: Record<
},
],
});
This is what one looks like after adding a function:
definePrompt("openai/ChatCompletion", {
model: "gpt-4",
stream: true,
@@ -187,11 +192,11 @@ export const refineOptions: Record<
title: \${scenario.title}
body: \${scenario.body}
On a scale from 1 to 3, how likely is it that the person writing this post has the following need? If you are not sure, make your best guess, or answer 1.
Need: \${scenario.need}
Answer one integer between 1 and 3.\`,
},
],
@@ -207,9 +212,9 @@ export const refineOptions: Record<
role: "user",
content: \`Title: \${scenario.title}
Body: \${scenario.body}
Need: \${scenario.need}
Rate likelihood on 1-3 scale.\`,
},
],
@@ -231,7 +236,52 @@ export const refineOptions: Record<
name: "score_post",
},
});
Another example
Before:
definePrompt("openai/ChatCompletion", {
model: "gpt-3.5-turbo",
stream: true,
messages: [
{
role: "system",
content: \`Write 'Start experimenting!' in \${scenario.language}\`,
},
],
});
After:
definePrompt("openai/ChatCompletion", {
model: "gpt-3.5-turbo",
messages: [
{
role: "system",
content: \`Write 'Start experimenting!' in \${scenario.language}\`,
},
],
functions: [
{
name: "write_in_language",
parameters: {
type: "object",
properties: {
text: {
type: "string",
},
},
},
},
],
function_call: {
name: "write_in_language",
},
});
Add an OpenAI function that takes one or more nested parameters that match the expected output from this prompt.`,
},
},
"replicate/llama2": {},
};

View File

@@ -1,47 +0,0 @@
import { VStack, Text } from "@chakra-ui/react";
import { type LegacyRef, useCallback } from "react";
import Select, { type SingleValue } from "react-select";
import { type SupportedModel } from "~/server/types";
import { useElementDimensions } from "~/utils/hooks";
// Options offered by the model dropdown. Every option uses the model id as both
// its value and its visible label, so the list is derived from the ids alone.
const modelOptions: { value: SupportedModel; label: string }[] = (
  [
    "gpt-3.5-turbo",
    "gpt-3.5-turbo-0613",
    "gpt-3.5-turbo-16k",
    "gpt-3.5-turbo-16k-0613",
    "gpt-4",
    "gpt-4-0613",
    "gpt-4-32k",
    "gpt-4-32k-0613",
  ] as const
).map((modelId) => ({ value: modelId, label: modelId }));
/**
 * Dropdown for picking one of the models listed in `modelOptions`.
 *
 * @param selectedModel - Id of the model currently chosen; the option whose
 *   value equals it is shown as selected.
 * @param setSelectedModel - Called with the id of the option the user picks.
 */
export const SelectModelSearch = ({
  selectedModel,
  setSelectedModel,
}: {
  selectedModel: SupportedModel;
  setSelectedModel: (model: SupportedModel) => void;
}) => {
  // react-select reports a cleared selection as null; in that case keep the
  // current model instead of propagating "nothing".
  const handleSelection = useCallback(
    (picked: SingleValue<(typeof modelOptions)[number]>) => {
      if (picked) {
        setSelectedModel(picked.value);
      }
    },
    [setSelectedModel],
  );

  const currentOption = modelOptions.find(({ value }) => value === selectedModel);

  // The measured width of the wrapper is fed to the select control below.
  const [wrapperRef, wrapperDimensions] = useElementDimensions();

  return (
    <VStack ref={wrapperRef as LegacyRef<HTMLDivElement>} w="full">
      <Text>Browse Models</Text>
      <Select
        styles={{
          control: (base) => ({ ...base, width: wrapperDimensions?.width }),
        }}
        value={currentOption}
        options={modelOptions}
        onChange={handleSelection}
      />
    </VStack>
  );
};

View File

@@ -17,7 +17,7 @@ import { FaRegClone } from "react-icons/fa";
import { useState } from "react";
import { RefinePromptModal } from "../RefinePromptModal/RefinePromptModal";
import { RiExchangeFundsFill } from "react-icons/ri";
import { SelectModelModal } from "../SelectModelModal/SelectModelModal";
import { ChangeModelModal } from "../ChangeModelModal/ChangeModelModal";
export default function VariantHeaderMenuButton({
variant,
@@ -50,7 +50,7 @@ export default function VariantHeaderMenuButton({
await utils.promptVariants.list.invalidate();
}, [hideMutation, variant.id]);
const [selectModelModalOpen, setSelectModelModalOpen] = useState(false);
const [changeModelModalOpen, setChangeModelModalOpen] = useState(false);
const [refinePromptModalOpen, setRefinePromptModalOpen] = useState(false);
return (
@@ -72,7 +72,7 @@ export default function VariantHeaderMenuButton({
</MenuItem>
<MenuItem
icon={<Icon as={RiExchangeFundsFill} boxSize={5} />}
onClick={() => setSelectModelModalOpen(true)}
onClick={() => setChangeModelModalOpen(true)}
>
Change Model
</MenuItem>
@@ -97,8 +97,8 @@ export default function VariantHeaderMenuButton({
)}
</MenuList>
</Menu>
{selectModelModalOpen && (
<SelectModelModal variant={variant} onClose={() => setSelectModelModalOpen(false)} />
{changeModelModalOpen && (
<ChangeModelModal variant={variant} onClose={() => setChangeModelModalOpen(false)} />
)}
{refinePromptModalOpen && (
<RefinePromptModal variant={variant} onClose={() => setRefinePromptModalOpen(false)} />

View File

@@ -9,7 +9,6 @@ export const env = createEnv({
server: {
DATABASE_URL: z.string().url(),
NODE_ENV: z.enum(["development", "test", "production"]).default("development"),
OPENAI_API_KEY: z.string().min(1),
RESTRICT_PRISMA_LOGS: z
.string()
.optional()
@@ -17,7 +16,8 @@ export const env = createEnv({
.transform((val) => val.toLowerCase() === "true"),
GITHUB_CLIENT_ID: z.string().min(1),
GITHUB_CLIENT_SECRET: z.string().min(1),
REPLICATE_API_TOKEN: z.string().min(1),
OPENAI_API_KEY: z.string().min(1),
REPLICATE_API_TOKEN: z.string().default("placeholder"),
},
/**

View File

@@ -1,77 +0,0 @@
import { type SupportedModel } from "../server/types";
/** Coarse speed rating attached to each model. */
type ModelSpeed = "fast" | "medium" | "slow";

/**
 * Static facts recorded for one chat model: context size, the price of a
 * single prompt / completion token, a speed rating, the provider name and a
 * link to further reading.
 */
interface ModelStats {
  contextLength: number;
  promptTokenPrice: number;
  completionTokenPrice: number;
  speed: ModelSpeed;
  provider: "OpenAI";
  learnMoreUrl: string;
}

const GPT_4_LEARN_MORE_URL = "https://openai.com/gpt-4";
const GPT_3_5_LEARN_MORE_URL = "https://platform.openai.com/docs/guides/gpt/chat-completions-api";

// Assembles one table entry. Every entry gets a fresh object with its fields
// in the same order.
const openAiModel = (
  contextLength: number,
  promptTokenPrice: number,
  completionTokenPrice: number,
  speed: ModelSpeed,
  learnMoreUrl: string,
): ModelStats => ({
  contextLength,
  promptTokenPrice,
  completionTokenPrice,
  speed,
  provider: "OpenAI",
  learnMoreUrl,
});

// All GPT-4 entries share the "medium" rating and the GPT-4 link.
const gpt4 = (
  contextLength: number,
  promptTokenPrice: number,
  completionTokenPrice: number,
): ModelStats =>
  openAiModel(contextLength, promptTokenPrice, completionTokenPrice, "medium", GPT_4_LEARN_MORE_URL);

// All GPT-3.5 entries share the "fast" rating and the chat-completions guide link.
const gpt35 = (
  contextLength: number,
  promptTokenPrice: number,
  completionTokenPrice: number,
): ModelStats =>
  openAiModel(contextLength, promptTokenPrice, completionTokenPrice, "fast", GPT_3_5_LEARN_MORE_URL);

/** Lookup table from supported model id to its static stats. */
export const modelStats: Record<SupportedModel, ModelStats> = {
  "gpt-4": gpt4(8192, 0.00003, 0.00006),
  "gpt-4-0613": gpt4(8192, 0.00003, 0.00006),
  "gpt-4-32k": gpt4(32768, 0.00006, 0.00012),
  "gpt-4-32k-0613": gpt4(32768, 0.00006, 0.00012),
  "gpt-3.5-turbo": gpt35(4096, 0.0000015, 0.000002),
  "gpt-3.5-turbo-0613": gpt35(4096, 0.0000015, 0.000002),
  "gpt-3.5-turbo-16k": gpt35(16384, 0.000003, 0.000004),
  "gpt-3.5-turbo-16k-0613": gpt35(16384, 0.000003, 0.000004),
};

View File

@@ -9,19 +9,39 @@ const frontendModelProvider: FrontendModelProvider<SupportedModel, ChatCompletio
models: {
"gpt-4-0613": {
name: "GPT-4",
learnMore: "https://openai.com/gpt-4",
contextWindow: 8192,
promptTokenPrice: 0.00003,
completionTokenPrice: 0.00006,
speed: "medium",
provider: "openai/ChatCompletion",
learnMoreUrl: "https://openai.com/gpt-4",
},
"gpt-4-32k-0613": {
name: "GPT-4 32k",
learnMore: "https://openai.com/gpt-4",
contextWindow: 32768,
promptTokenPrice: 0.00006,
completionTokenPrice: 0.00012,
speed: "medium",
provider: "openai/ChatCompletion",
learnMoreUrl: "https://openai.com/gpt-4",
},
"gpt-3.5-turbo-0613": {
name: "GPT-3.5 Turbo",
learnMore: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
contextWindow: 4096,
promptTokenPrice: 0.0000015,
completionTokenPrice: 0.000002,
speed: "fast",
provider: "openai/ChatCompletion",
learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
},
"gpt-3.5-turbo-16k-0613": {
name: "GPT-3.5 Turbo 16k",
learnMore: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
contextWindow: 16384,
promptTokenPrice: 0.000003,
completionTokenPrice: 0.000004,
speed: "fast",
provider: "openai/ChatCompletion",
learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
},
},

View File

@@ -8,10 +8,10 @@ import { countOpenAIChatTokens } from "~/utils/countTokens";
import { type CompletionResponse } from "../types";
import { omit } from "lodash-es";
import { openai } from "~/server/utils/openai";
import { type OpenAIChatModel } from "~/server/types";
import { truthyFilter } from "~/utils/utils";
import { APIError } from "openai";
import { modelStats } from "../modelStats";
import frontendModelProvider from "./frontend";
import modelProvider, { type SupportedModel } from ".";
const mergeStreamedChunks = (
base: ChatCompletion | null,
@@ -60,6 +60,7 @@ export async function getCompletion(
let finalCompletion: ChatCompletion | null = null;
let promptTokens: number | undefined = undefined;
let completionTokens: number | undefined = undefined;
const modelName = modelProvider.getModel(input) as SupportedModel;
try {
if (onStream) {
@@ -81,12 +82,9 @@ export async function getCompletion(
};
}
try {
promptTokens = countOpenAIChatTokens(
input.model as keyof typeof OpenAIChatModel,
input.messages,
);
promptTokens = countOpenAIChatTokens(modelName, input.messages);
completionTokens = countOpenAIChatTokens(
input.model as keyof typeof OpenAIChatModel,
modelName,
finalCompletion.choices.map((c) => c.message).filter(truthyFilter),
);
} catch (err) {
@@ -106,10 +104,10 @@ export async function getCompletion(
}
const timeToComplete = Date.now() - start;
const stats = modelStats[input.model as keyof typeof OpenAIChatModel];
const { promptTokenPrice, completionTokenPrice } = frontendModelProvider.models[modelName];
let cost = undefined;
if (stats && promptTokens && completionTokens) {
cost = promptTokens * stats.promptTokenPrice + completionTokens * stats.completionTokenPrice;
if (promptTokenPrice && completionTokenPrice && promptTokens && completionTokens) {
cost = promptTokens * promptTokenPrice + completionTokens * completionTokenPrice;
}
return {

View File

@@ -5,9 +5,30 @@ const frontendModelProvider: FrontendModelProvider<SupportedModel, ReplicateLlam
name: "Replicate Llama2",
models: {
"7b-chat": {},
"13b-chat": {},
"70b-chat": {},
"7b-chat": {
name: "LLama 2 7B Chat",
contextWindow: 4096,
pricePerSecond: 0.0023,
speed: "fast",
provider: "replicate/llama2",
learnMoreUrl: "https://replicate.com/a16z-infra/llama7b-v2-chat",
},
"13b-chat": {
name: "LLama 2 13B Chat",
contextWindow: 4096,
pricePerSecond: 0.0023,
speed: "medium",
provider: "replicate/llama2",
learnMoreUrl: "https://replicate.com/a16z-infra/llama13b-v2-chat",
},
"70b-chat": {
name: "LLama 2 70B Chat",
contextWindow: 4096,
pricePerSecond: 0.0032,
speed: "slow",
provider: "replicate/llama2",
learnMoreUrl: "https://replicate.com/replicate/llama70b-v2-chat",
},
},
normalizeOutput: (output) => {

View File

@@ -1,16 +1,31 @@
import { type JSONSchema4 } from "json-schema";
import { type JsonValue } from "type-fest";
import { z } from "zod";
export type SupportedProvider = "openai/ChatCompletion" | "replicate/llama2";
const ZodSupportedProvider = z.union([
z.literal("openai/ChatCompletion"),
z.literal("replicate/llama2"),
]);
type ModelInfo = {
name?: string;
learnMore?: string;
};
export type SupportedProvider = z.infer<typeof ZodSupportedProvider>;
export const ZodModel = z.object({
name: z.string(),
contextWindow: z.number(),
promptTokenPrice: z.number().optional(),
completionTokenPrice: z.number().optional(),
pricePerSecond: z.number().optional(),
speed: z.union([z.literal("fast"), z.literal("medium"), z.literal("slow")]),
provider: ZodSupportedProvider,
description: z.string().optional(),
learnMoreUrl: z.string().optional(),
});
export type Model = z.infer<typeof ZodModel>;
export type FrontendModelProvider<SupportedModels extends string, OutputSchema> = {
name: string;
models: Record<SupportedModels, ModelInfo>;
models: Record<SupportedModels, Model>;
normalizeOutput: (output: OutputSchema) => NormalizedOutput;
};

View File

@@ -2,7 +2,6 @@ import { z } from "zod";
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import { generateNewCell } from "~/server/utils/generateNewCell";
import { type SupportedModel } from "~/server/types";
import userError from "~/server/utils/error";
import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
import { reorderPromptVariants } from "~/server/utils/reorderPromptVariants";
@@ -10,6 +9,7 @@ import { type PromptVariant } from "@prisma/client";
import { deriveNewConstructFn } from "~/server/utils/deriveNewContructFn";
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
import parseConstructFn from "~/server/utils/parseConstructFn";
import { ZodModel } from "~/modelProviders/types";
export const promptVariantsRouter = createTRPCRouter({
list: publicProcedure
@@ -144,7 +144,7 @@ export const promptVariantsRouter = createTRPCRouter({
z.object({
experimentId: z.string(),
variantId: z.string().optional(),
newModel: z.string().optional(),
newModel: ZodModel.optional(),
}),
)
.mutation(async ({ input, ctx }) => {
@@ -186,10 +186,7 @@ export const promptVariantsRouter = createTRPCRouter({
? `${originalVariant?.label} Copy`
: `Prompt Variant ${largestSortIndex + 2}`;
const newConstructFn = await deriveNewConstructFn(
originalVariant,
input.newModel as SupportedModel,
);
const newConstructFn = await deriveNewConstructFn(originalVariant, input.newModel);
const createNewVariantAction = prisma.promptVariant.create({
data: {
@@ -289,7 +286,7 @@ export const promptVariantsRouter = createTRPCRouter({
z.object({
id: z.string(),
instructions: z.string().optional(),
newModel: z.string().optional(),
newModel: ZodModel.optional(),
}),
)
.mutation(async ({ input, ctx }) => {
@@ -308,7 +305,7 @@ export const promptVariantsRouter = createTRPCRouter({
const promptConstructionFn = await deriveNewConstructFn(
existing,
input.newModel as SupportedModel | undefined,
input.newModel,
input.instructions,
);

View File

@@ -1,12 +0,0 @@
/**
 * String enum of OpenAI chat model identifiers.
 * Every member's name is identical to its string value, so either can be used as the model id.
 */
export enum OpenAIChatModel {
"gpt-4" = "gpt-4",
"gpt-4-0613" = "gpt-4-0613",
"gpt-4-32k" = "gpt-4-32k",
"gpt-4-32k-0613" = "gpt-4-32k-0613",
"gpt-3.5-turbo" = "gpt-3.5-turbo",
"gpt-3.5-turbo-0613" = "gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k" = "gpt-3.5-turbo-16k",
"gpt-3.5-turbo-16k-0613" = "gpt-3.5-turbo-16k-0613",
}
/** Union of the `OpenAIChatModel` member names, e.g. `"gpt-4"` or `"gpt-3.5-turbo"`. */
export type SupportedModel = keyof typeof OpenAIChatModel;

View File

@@ -1,18 +1,18 @@
import { type PromptVariant } from "@prisma/client";
import { type SupportedModel } from "../types";
import ivm from "isolated-vm";
import dedent from "dedent";
import { openai } from "./openai";
import { getApiShapeForModel } from "./getTypesForModel";
import { isObject } from "lodash-es";
import { type CompletionCreateParams } from "openai/resources/chat/completions";
import formatPromptConstructor from "~/utils/formatPromptConstructor";
import { type SupportedProvider, type Model } from "~/modelProviders/types";
import modelProviders from "~/modelProviders/modelProviders";
const isolate = new ivm.Isolate({ memoryLimit: 128 });
export async function deriveNewConstructFn(
originalVariant: PromptVariant | null,
newModel?: SupportedModel,
newModel?: Model,
instructions?: string,
) {
if (originalVariant && !newModel && !instructions) {
@@ -36,10 +36,11 @@ export async function deriveNewConstructFn(
const NUM_RETRIES = 5;
const requestUpdatedPromptFunction = async (
originalVariant: PromptVariant,
newModel?: SupportedModel,
newModel?: Model,
instructions?: string,
) => {
const originalModel = originalVariant.model as SupportedModel;
const originalModelProvider = modelProviders[originalVariant.modelProvider as SupportedProvider];
const originalModel = originalModelProvider.models[originalVariant.model] as Model;
let newContructionFn = "";
for (let i = 0; i < NUM_RETRIES; i++) {
try {
@@ -47,7 +48,7 @@ const requestUpdatedPromptFunction = async (
{
role: "system",
content: `Your job is to update prompt constructor functions. Here is the api shape for the current model:\n---\n${JSON.stringify(
getApiShapeForModel(originalModel),
originalModelProvider.inputSchema,
null,
2,
)}\n\nDo not add any assistant messages.`,
@@ -60,8 +61,20 @@ const requestUpdatedPromptFunction = async (
if (newModel) {
messages.push({
role: "user",
content: `Return the prompt constructor function for ${newModel} given the existing prompt constructor function for ${originalModel}`,
content: `Return the prompt constructor function for ${newModel.name} given the existing prompt constructor function for ${originalModel.name}`,
});
if (newModel.provider !== originalModel.provider) {
messages.push({
role: "user",
content: `The old provider was ${originalModel.provider}. The new provider is ${
newModel.provider
}. Here is the schema for the new model:\n---\n${JSON.stringify(
modelProviders[newModel.provider].inputSchema,
null,
2,
)}`,
});
}
}
if (instructions) {
messages.push({

View File

@@ -4,8 +4,9 @@ import { queueLLMRetrievalTask } from "./queueLLMRetrievalTask";
import parseConstructFn from "./parseConstructFn";
import { type JsonObject } from "type-fest";
import hashPrompt from "./hashPrompt";
import { omit } from "lodash-es";
export const generateNewCell = async (variantId: string, scenarioId: string) => {
export const generateNewCell = async (variantId: string, scenarioId: string): Promise<void> => {
const variant = await prisma.promptVariant.findUnique({
where: {
id: variantId,
@@ -18,7 +19,7 @@ export const generateNewCell = async (variantId: string, scenarioId: string) =>
},
});
if (!variant || !scenario) return null;
if (!variant || !scenario) return;
let cell = await prisma.scenarioVariantCell.findUnique({
where: {
@@ -32,7 +33,7 @@ export const generateNewCell = async (variantId: string, scenarioId: string) =>
},
});
if (cell) return cell;
if (cell) return;
const parsedConstructFn = await parseConstructFn(
variant.constructFn,
@@ -40,7 +41,7 @@ export const generateNewCell = async (variantId: string, scenarioId: string) =>
);
if ("error" in parsedConstructFn) {
return await prisma.scenarioVariantCell.create({
await prisma.scenarioVariantCell.create({
data: {
promptVariantId: variantId,
testScenarioId: scenarioId,
@@ -49,6 +50,7 @@ export const generateNewCell = async (variantId: string, scenarioId: string) =>
retrievalStatus: "ERROR",
},
});
return;
}
const inputHash = hashPrompt(parsedConstructFn);
@@ -69,29 +71,33 @@ export const generateNewCell = async (variantId: string, scenarioId: string) =>
where: { inputHash },
});
let newModelOutput;
if (matchingModelOutput) {
newModelOutput = await prisma.modelOutput.create({
const newModelOutput = await prisma.modelOutput.create({
data: {
...omit(matchingModelOutput, ["id"]),
scenarioVariantCellId: cell.id,
inputHash,
output: matchingModelOutput.output as Prisma.InputJsonValue,
timeToComplete: matchingModelOutput.timeToComplete,
cost: matchingModelOutput.cost,
promptTokens: matchingModelOutput.promptTokens,
completionTokens: matchingModelOutput.completionTokens,
createdAt: matchingModelOutput.createdAt,
updatedAt: matchingModelOutput.updatedAt,
},
});
await prisma.scenarioVariantCell.update({
where: { id: cell.id },
data: { retrievalStatus: "COMPLETE" },
});
// Copy over all eval results as well
await Promise.all(
(
await prisma.outputEvaluation.findMany({ where: { modelOutputId: matchingModelOutput.id } })
).map(async (evaluation) => {
await prisma.outputEvaluation.create({
data: {
...omit(evaluation, ["id"]),
modelOutputId: newModelOutput.id,
},
});
}),
);
} else {
cell = await queueLLMRetrievalTask(cell.id);
}
return { ...cell, modelOutput: newModelOutput };
};

View File

@@ -1,6 +0,0 @@
import { type SupportedModel } from "../types";
/**
 * Returns the API shape description for the given model.
 *
 * This is a stub: `model` is ignored and the result is always the empty string.
 * A per-model lookup (e.g. returning an OpenAI chat API shape when
 * `model in OpenAIChatModel`) is not implemented.
 *
 * @param model - The model whose API shape is requested (currently unused).
 * @returns An empty string.
 */
export const getApiShapeForModel = (model: SupportedModel) => "";

View File

@@ -1,6 +1,6 @@
import { type ChatCompletion } from "openai/resources/chat";
import { GPTTokens } from "gpt-tokens";
import { type OpenAIChatModel } from "~/server/types";
import { type SupportedModel } from "~/modelProviders/openai-ChatCompletion";
interface GPTTokensMessageItem {
name?: string;
@@ -9,7 +9,7 @@ interface GPTTokensMessageItem {
}
export const countOpenAIChatTokens = (
model: keyof typeof OpenAIChatModel,
model: SupportedModel,
messages: ChatCompletion.Choice.Message[],
) => {
return new GPTTokens({ model, messages: messages as unknown as GPTTokensMessageItem[] })

View File

@@ -1 +1,5 @@
import { type Model } from "~/modelProviders/types";
/**
 * Type-guard predicate for dropping empty entries, e.g. `items.filter(truthyFilter)`.
 *
 * The check is plain truthiness, not just a null/undefined check: besides `null`
 * and `undefined`, values such as `0`, `""`, `false` and `NaN` are rejected too.
 *
 * @param x - The value to test.
 * @returns `true` when `x` is truthy, `false` otherwise.
 */
export const truthyFilter = <T>(x: T | null | undefined): x is T => {
  return !!x;
};
/**
 * Builds a string key for a model by joining its provider and its name with a slash,
 * i.e. `<provider>/<name>`.
 *
 * @param model - The model to build a key for.
 * @returns The combined `provider/name` string.
 */
export const keyForModel = (model: Model) => {
  const { provider, name } = model;
  return `${provider}/${name}`;
};