Compare commits

..

15 Commits

Author SHA1 Message Date
Kyle Corbitt
213a00a8e6 Fix typescript hints for max_tokens 2023-07-21 12:04:58 -07:00
Kyle Corbitt
af9943eefc Merge pull request #77 from OpenPipe/provider-types
Slightly better typings for ModelProviders
2023-07-21 11:51:25 -07:00
Kyle Corbitt
741128e0f4 Better division of labor between frontend and backend model providers
A bit better thinking on which types go where.
2023-07-21 11:49:35 -07:00
David Corbitt
aff14539d8 Add comment to .env.example 2023-07-21 11:29:21 -07:00
David Corbitt
1af81a50a9 Add REPLICATE_API_TOKEN to .env.example 2023-07-21 11:28:14 -07:00
Kyle Corbitt
7e1fbb3767 Slightly better typings for ModelProviders
Still not great because the `any`s loosen some call sites up more than I'd like, but better than the broken types before.
2023-07-21 06:50:05 -07:00
David Corbitt
a5d972005e Add user's current prompt to prompt derivation 2023-07-21 00:43:39 -07:00
arcticfly
a180b5bef2 Show prompt diff when changing models (#76)
* Make CompareFunctions more configurable

* Change RefinePromptModal styles

* Accept newModel in getModifiedPromptFn

* Show prompt comparison in SelectModelModal

* Pass variant to SelectModelModal

* Update instructions

* Properly use isDisabled
2023-07-20 23:26:49 -07:00
Kyle Corbitt
55c697223e Merge pull request #74 from OpenPipe/model-providers
replicate/llama2 provider
2023-07-20 23:21:42 -07:00
arcticfly
9978075867 Fix auth flicker (#75)
* Remove experiments flicker for unauthenticated users

* Decrease size of NewScenarioButton spinner
2023-07-20 20:46:31 -07:00
Kyle Corbitt
372c2512c9 Merge pull request #73 from OpenPipe/model-providers
More work on modelProviders
2023-07-20 18:56:14 -07:00
arcticfly
1822fe198e Initially render AutoResizeTextArea without overflow (#72)
* Rerender resized text area with scroll

* Remove default hidden overflow
2023-07-20 15:00:09 -07:00
Kyle Corbitt
f06e1db3db Merge pull request #71 from OpenPipe/model-providers
Prep for more model providers
2023-07-20 14:55:31 -07:00
arcticfly
9314a86857 Use translation in initial scenarios (#70) 2023-07-20 14:28:48 -07:00
David Corbitt
54dcb4a567 Prevent text input labels from overlaying scenarios header 2023-07-20 14:28:36 -07:00
27 changed files with 259 additions and 174 deletions

View File

@@ -17,6 +17,9 @@ DATABASE_URL="postgresql://postgres:postgres@localhost:5432/openpipe?schema=publ
# https://help.openai.com/en/articles/4936850-where-do-i-find-my-secret-api-key # https://help.openai.com/en/articles/4936850-where-do-i-find-my-secret-api-key
OPENAI_API_KEY="" OPENAI_API_KEY=""
# Replicate API token. Create a token here: https://replicate.com/account/api-tokens
REPLICATE_API_TOKEN=""
NEXT_PUBLIC_SOCKET_URL="http://localhost:3318" NEXT_PUBLIC_SOCKET_URL="http://localhost:3318"
# Next Auth # Next Auth

View File

@@ -1,19 +1,22 @@
import { Textarea, type TextareaProps } from "@chakra-ui/react"; import { Textarea, type TextareaProps } from "@chakra-ui/react";
import ResizeTextarea from "react-textarea-autosize"; import ResizeTextarea from "react-textarea-autosize";
import React from "react"; import React, { useLayoutEffect, useState } from "react";
export const AutoResizeTextarea: React.ForwardRefRenderFunction< export const AutoResizeTextarea: React.ForwardRefRenderFunction<
HTMLTextAreaElement, HTMLTextAreaElement,
TextareaProps & { minRows?: number } TextareaProps & { minRows?: number }
> = (props, ref) => { > = ({ minRows = 1, overflowY = "hidden", ...props }, ref) => {
const [isRerendered, setIsRerendered] = useState(false);
useLayoutEffect(() => setIsRerendered(true), []);
return ( return (
<Textarea <Textarea
minH="unset" minH="unset"
overflow="hidden" minRows={minRows}
overflowY={isRerendered ? overflowY : "hidden"}
w="100%" w="100%"
resize="none" resize="none"
ref={ref} ref={ref}
minRows={1}
transition="height none" transition="height none"
as={ResizeTextarea} as={ResizeTextarea}
{...props} {...props}

View File

@@ -18,11 +18,9 @@ export const FloatingLabelInput = ({
transform={isFocused || !!value ? "translateY(-50%)" : "translateY(0)"} transform={isFocused || !!value ? "translateY(-50%)" : "translateY(0)"}
fontSize={isFocused || !!value ? "12px" : "16px"} fontSize={isFocused || !!value ? "12px" : "16px"}
transition="all 0.15s" transition="all 0.15s"
zIndex="100" zIndex="5"
bg="white" bg="white"
px={1} px={1}
mt={0}
mb={2}
lineHeight="1" lineHeight="1"
pointerEvents="none" pointerEvents="none"
color={isFocused ? "blue.500" : "gray.500"} color={isFocused ? "blue.500" : "gray.500"}

View File

@@ -49,7 +49,11 @@ export default function NewScenarioButton() {
Add Scenario Add Scenario
</StyledButton> </StyledButton>
<StyledButton onClick={onAutogenerate}> <StyledButton onClick={onAutogenerate}>
<Icon as={autogenerating ? Spinner : BsPlus} boxSize={6} mr={autogenerating ? 1 : 0} /> <Icon
as={autogenerating ? Spinner : BsPlus}
boxSize={autogenerating ? 4 : 6}
mr={autogenerating ? 2 : 0}
/>
Autogenerate Scenario Autogenerate Scenario
</StyledButton> </StyledButton>
</HStack> </HStack>

View File

@@ -10,7 +10,7 @@ import useSocket from "~/utils/useSocket";
import { OutputStats } from "./OutputStats"; import { OutputStats } from "./OutputStats";
import { ErrorHandler } from "./ErrorHandler"; import { ErrorHandler } from "./ErrorHandler";
import { CellOptions } from "./CellOptions"; import { CellOptions } from "./CellOptions";
import modelProvidersFrontend from "~/modelProviders/modelProvidersFrontend"; import frontendModelProviders from "~/modelProviders/frontendModelProviders";
export default function OutputCell({ export default function OutputCell({
scenario, scenario,
@@ -40,7 +40,7 @@ export default function OutputCell({
); );
const provider = const provider =
modelProvidersFrontend[variant.modelProvider as keyof typeof modelProvidersFrontend]; frontendModelProviders[variant.modelProvider as keyof typeof frontendModelProviders];
type OutputSchema = Parameters<typeof provider.normalizeOutput>[0]; type OutputSchema = Parameters<typeof provider.normalizeOutput>[0];
@@ -88,11 +88,9 @@ export default function OutputCell({
} }
const normalizedOutput = modelOutput const normalizedOutput = modelOutput
? // @ts-expect-error TODO FIX ASAP ? provider.normalizeOutput(modelOutput.output)
provider.normalizeOutput(modelOutput.output as unknown as OutputSchema)
: streamedMessage : streamedMessage
? // @ts-expect-error TODO FIX ASAP ? provider.normalizeOutput(streamedMessage)
provider.normalizeOutput(streamedMessage)
: null; : null;
if (modelOutput && normalizedOutput?.type === "json") { if (modelOutput && normalizedOutput?.type === "json") {

View File

@@ -4,5 +4,5 @@ export const stickyHeaderStyle: SystemStyleObject = {
position: "sticky", position: "sticky",
top: "0", top: "0",
backgroundColor: "#fff", backgroundColor: "#fff",
zIndex: 1, zIndex: 10,
}; };

View File

@@ -1,4 +1,4 @@
import { HStack, VStack, useBreakpointValue } from "@chakra-ui/react"; import { type StackProps, VStack, useBreakpointValue } from "@chakra-ui/react";
import React from "react"; import React from "react";
import DiffViewer, { DiffMethod } from "react-diff-viewer"; import DiffViewer, { DiffMethod } from "react-diff-viewer";
import Prism from "prismjs"; import Prism from "prismjs";
@@ -19,10 +19,15 @@ const highlightSyntax = (str: string) => {
const CompareFunctions = ({ const CompareFunctions = ({
originalFunction, originalFunction,
newFunction = "", newFunction = "",
leftTitle = "Original",
rightTitle = "Modified",
...props
}: { }: {
originalFunction: string; originalFunction: string;
newFunction?: string; newFunction?: string;
}) => { leftTitle?: string;
rightTitle?: string;
} & StackProps) => {
const showSplitView = useBreakpointValue( const showSplitView = useBreakpointValue(
{ {
base: false, base: false,
@@ -34,22 +39,20 @@ const CompareFunctions = ({
); );
return ( return (
<HStack w="full" spacing={5}> <VStack w="full" spacing={4} fontSize={12} lineHeight={1} overflowY="auto" {...props}>
<VStack w="full" spacing={4} maxH="40vh" fontSize={12} lineHeight={1} overflowY="auto">
<DiffViewer <DiffViewer
oldValue={originalFunction} oldValue={originalFunction}
newValue={newFunction || originalFunction} newValue={newFunction || originalFunction}
splitView={showSplitView} splitView={showSplitView}
hideLineNumbers={!showSplitView} hideLineNumbers={!showSplitView}
leftTitle="Original" leftTitle={leftTitle}
rightTitle={newFunction ? "Modified" : "Unmodified"} rightTitle={rightTitle}
disableWordDiff={true} disableWordDiff={true}
compareMethod={DiffMethod.CHARS} compareMethod={DiffMethod.CHARS}
renderContent={highlightSyntax} renderContent={highlightSyntax}
showDiffOnly={false} showDiffOnly={false}
/> />
</VStack> </VStack>
</HStack>
); );
}; };

View File

@@ -56,7 +56,6 @@ export const CustomInstructionsInput = ({
minW="unset" minW="unset"
size="sm" size="sm"
onClick={() => onSubmit()} onClick={() => onSubmit()}
disabled={!instructions}
variant={instructions ? "solid" : "ghost"} variant={instructions ? "solid" : "ghost"}
mr={4} mr={4}
borderRadius="8" borderRadius="8"

View File

@@ -36,25 +36,25 @@ export const RefinePromptModal = ({
}) => { }) => {
const utils = api.useContext(); const utils = api.useContext();
const { mutateAsync: getRefinedPromptMutateAsync, data: refinedPromptFn } = const { mutateAsync: getModifiedPromptMutateAsync, data: refinedPromptFn } =
api.promptVariants.getRefinedPromptFn.useMutation(); api.promptVariants.getModifiedPromptFn.useMutation();
const [instructions, setInstructions] = useState<string>(""); const [instructions, setInstructions] = useState<string>("");
const [activeRefineOptionLabel, setActiveRefineOptionLabel] = useState< const [activeRefineOptionLabel, setActiveRefineOptionLabel] = useState<
RefineOptionLabel | undefined RefineOptionLabel | undefined
>(undefined); >(undefined);
const [getRefinedPromptFn, refiningInProgress] = useHandledAsyncCallback( const [getModifiedPromptFn, modificationInProgress] = useHandledAsyncCallback(
async (label?: RefineOptionLabel) => { async (label?: RefineOptionLabel) => {
if (!variant.experimentId) return; if (!variant.experimentId) return;
const updatedInstructions = label ? refineOptions[label].instructions : instructions; const updatedInstructions = label ? refineOptions[label].instructions : instructions;
setActiveRefineOptionLabel(label); setActiveRefineOptionLabel(label);
await getRefinedPromptMutateAsync({ await getModifiedPromptMutateAsync({
id: variant.id, id: variant.id,
instructions: updatedInstructions, instructions: updatedInstructions,
}); });
}, },
[getRefinedPromptMutateAsync, onClose, variant, instructions, setActiveRefineOptionLabel], [getModifiedPromptMutateAsync, onClose, variant, instructions, setActiveRefineOptionLabel],
); );
const replaceVariantMutation = api.promptVariants.replaceVariant.useMutation(); const replaceVariantMutation = api.promptVariants.replaceVariant.useMutation();
@@ -75,7 +75,11 @@ export const RefinePromptModal = ({
}, [replaceVariantMutation, variant, onClose, refinedPromptFn]); }, [replaceVariantMutation, variant, onClose, refinedPromptFn]);
return ( return (
<Modal isOpen onClose={onClose} size={{ base: "xl", sm: "2xl", md: "7xl" }}> <Modal
isOpen
onClose={onClose}
size={{ base: "xl", sm: "2xl", md: "3xl", lg: "5xl", xl: "7xl" }}
>
<ModalOverlay /> <ModalOverlay />
<ModalContent w={1200}> <ModalContent w={1200}>
<ModalHeader> <ModalHeader>
@@ -93,15 +97,15 @@ export const RefinePromptModal = ({
label="Convert to function call" label="Convert to function call"
activeLabel={activeRefineOptionLabel} activeLabel={activeRefineOptionLabel}
icon={VscJson} icon={VscJson}
onClick={getRefinedPromptFn} onClick={getModifiedPromptFn}
loading={refiningInProgress} loading={modificationInProgress}
/> />
<RefineOption <RefineOption
label="Add chain of thought" label="Add chain of thought"
activeLabel={activeRefineOptionLabel} activeLabel={activeRefineOptionLabel}
icon={TfiThought} icon={TfiThought}
onClick={getRefinedPromptFn} onClick={getModifiedPromptFn}
loading={refiningInProgress} loading={modificationInProgress}
/> />
</SimpleGrid> </SimpleGrid>
<HStack> <HStack>
@@ -110,13 +114,14 @@ export const RefinePromptModal = ({
<CustomInstructionsInput <CustomInstructionsInput
instructions={instructions} instructions={instructions}
setInstructions={setInstructions} setInstructions={setInstructions}
loading={refiningInProgress} loading={modificationInProgress}
onSubmit={getRefinedPromptFn} onSubmit={getModifiedPromptFn}
/> />
</VStack> </VStack>
<CompareFunctions <CompareFunctions
originalFunction={variant.constructFn} originalFunction={variant.constructFn}
newFunction={isString(refinedPromptFn) ? refinedPromptFn : undefined} newFunction={isString(refinedPromptFn) ? refinedPromptFn : undefined}
maxH="40vh"
/> />
</VStack> </VStack>
</ModalBody> </ModalBody>
@@ -124,12 +129,10 @@ export const RefinePromptModal = ({
<ModalFooter> <ModalFooter>
<HStack spacing={4}> <HStack spacing={4}>
<Button <Button
colorScheme="blue"
onClick={replaceVariant} onClick={replaceVariant}
minW={24} minW={24}
disabled={replacementInProgress || !refinedPromptFn} isDisabled={replacementInProgress || !refinedPromptFn}
_disabled={{
bgColor: "blue.500",
}}
> >
{replacementInProgress ? <Spinner boxSize={4} /> : <Text>Accept</Text>} {replacementInProgress ? <Spinner boxSize={4} /> : <Text>Accept</Text>}
</Button> </Button>

View File

@@ -12,7 +12,7 @@ export const refineOptions: Record<
This is what a prompt looks like before adding chain of thought: This is what a prompt looks like before adding chain of thought:
prompt = { definePrompt("openai/ChatCompletion", {
model: "gpt-4", model: "gpt-4",
stream: true, stream: true,
messages: [ messages: [
@@ -25,11 +25,11 @@ export const refineOptions: Record<
content: \`This is the user's message: \${scenario.user_message}. Return "positive" or "negative" or "neutral"\`, content: \`This is the user's message: \${scenario.user_message}. Return "positive" or "negative" or "neutral"\`,
}, },
], ],
}; });
This is what one looks like after adding chain of thought: This is what one looks like after adding chain of thought:
prompt = { definePrompt("openai/ChatCompletion", {
model: "gpt-4", model: "gpt-4",
stream: true, stream: true,
messages: [ messages: [
@@ -42,13 +42,13 @@ export const refineOptions: Record<
content: \`This is the user's message: \${scenario.user_message}. Return "positive" or "negative" or "neutral". Explain your answer before you give a score, then return the score on a new line.\`, content: \`This is the user's message: \${scenario.user_message}. Return "positive" or "negative" or "neutral". Explain your answer before you give a score, then return the score on a new line.\`,
}, },
], ],
}; });
Here's another example: Here's another example:
Before: Before:
prompt = { definePrompt("openai/ChatCompletion", {
model: "gpt-3.5-turbo", model: "gpt-3.5-turbo",
messages: [ messages: [
{ {
@@ -78,11 +78,11 @@ export const refineOptions: Record<
function_call: { function_call: {
name: "score_post", name: "score_post",
}, },
}; });
After: After:
prompt = { definePrompt("openai/ChatCompletion", {
model: "gpt-3.5-turbo", model: "gpt-3.5-turbo",
messages: [ messages: [
{ {
@@ -115,7 +115,7 @@ export const refineOptions: Record<
function_call: { function_call: {
name: "score_post", name: "score_post",
}, },
}; });
Add chain of thought to the original prompt.`, Add chain of thought to the original prompt.`,
}, },
@@ -125,7 +125,7 @@ export const refineOptions: Record<
This is what a prompt looks like before adding a function: This is what a prompt looks like before adding a function:
prompt = { definePrompt("openai/ChatCompletion", {
model: "gpt-4", model: "gpt-4",
stream: true, stream: true,
messages: [ messages: [
@@ -138,11 +138,11 @@ export const refineOptions: Record<
content: \`This is the user's message: \${scenario.user_message}. Return "positive" or "negative" or "neutral"\`, content: \`This is the user's message: \${scenario.user_message}. Return "positive" or "negative" or "neutral"\`,
}, },
], ],
}; });
This is what one looks like after adding a function: This is what one looks like after adding a function:
prompt = { definePrompt("openai/ChatCompletion", {
model: "gpt-4", model: "gpt-4",
stream: true, stream: true,
messages: [ messages: [
@@ -172,13 +172,13 @@ export const refineOptions: Record<
function_call: { function_call: {
name: "extract_sentiment", name: "extract_sentiment",
}, },
}; });
Here's another example of adding a function: Here's another example of adding a function:
Before: Before:
prompt = { definePrompt("openai/ChatCompletion", {
model: "gpt-3.5-turbo", model: "gpt-3.5-turbo",
messages: [ messages: [
{ {
@@ -196,11 +196,11 @@ export const refineOptions: Record<
}, },
], ],
temperature: 0, temperature: 0,
}; });
After: After:
prompt = { definePrompt("openai/ChatCompletion", {
model: "gpt-3.5-turbo", model: "gpt-3.5-turbo",
messages: [ messages: [
{ {
@@ -230,7 +230,7 @@ export const refineOptions: Record<
function_call: { function_call: {
name: "score_post", name: "score_post",
}, },
}; });
Add an OpenAI function that takes one or more nested parameters that match the expected output from this prompt.`, Add an OpenAI function that takes one or more nested parameters that match the expected output from this prompt.`,
}, },

View File

@@ -20,36 +20,60 @@ import { ModelStatsCard } from "./ModelStatsCard";
import { SelectModelSearch } from "./SelectModelSearch"; import { SelectModelSearch } from "./SelectModelSearch";
import { api } from "~/utils/api"; import { api } from "~/utils/api";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks"; import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
import CompareFunctions from "../RefinePromptModal/CompareFunctions";
import { type PromptVariant } from "@prisma/client";
import { isObject, isString } from "lodash-es";
export const SelectModelModal = ({ export const SelectModelModal = ({
originalModel, variant,
variantId,
onClose, onClose,
}: { }: {
originalModel: SupportedModel; variant: PromptVariant;
variantId: string;
onClose: () => void; onClose: () => void;
}) => { }) => {
const originalModel = variant.model as SupportedModel;
const [selectedModel, setSelectedModel] = useState<SupportedModel>(originalModel); const [selectedModel, setSelectedModel] = useState<SupportedModel>(originalModel);
const [convertedModel, setConvertedModel] = useState<SupportedModel | undefined>(undefined);
const utils = api.useContext(); const utils = api.useContext();
const experiment = useExperiment(); const experiment = useExperiment();
const createMutation = api.promptVariants.create.useMutation(); const { mutateAsync: getModifiedPromptMutateAsync, data: modifiedPromptFn } =
api.promptVariants.getModifiedPromptFn.useMutation();
const [createNewVariant, creationInProgress] = useHandledAsyncCallback(async () => { const [getModifiedPromptFn, modificationInProgress] = useHandledAsyncCallback(async () => {
if (!experiment?.data?.id) return; if (!experiment) return;
await createMutation.mutateAsync({
experimentId: experiment?.data?.id, await getModifiedPromptMutateAsync({
variantId, id: variant.id,
newModel: selectedModel, newModel: selectedModel,
}); });
setConvertedModel(selectedModel);
}, [getModifiedPromptMutateAsync, onClose, experiment, variant, selectedModel]);
const replaceVariantMutation = api.promptVariants.replaceVariant.useMutation();
const [replaceVariant, replacementInProgress] = useHandledAsyncCallback(async () => {
if (
!variant.experimentId ||
!modifiedPromptFn ||
(isObject(modifiedPromptFn) && "status" in modifiedPromptFn)
)
return;
await replaceVariantMutation.mutateAsync({
id: variant.id,
constructFn: modifiedPromptFn,
});
await utils.promptVariants.list.invalidate(); await utils.promptVariants.list.invalidate();
onClose(); onClose();
}, [createMutation, experiment?.data?.id, variantId, onClose]); }, [replaceVariantMutation, variant, onClose, modifiedPromptFn]);
return ( return (
<Modal isOpen onClose={onClose} size={{ base: "xl", sm: "2xl", md: "3xl" }}> <Modal
isOpen
onClose={onClose}
size={{ base: "xl", sm: "2xl", md: "3xl", lg: "5xl", xl: "7xl" }}
>
<ModalOverlay /> <ModalOverlay />
<ModalContent w={1200}> <ModalContent w={1200}>
<ModalHeader> <ModalHeader>
@@ -66,18 +90,36 @@ export const SelectModelModal = ({
<ModelStatsCard label="New Model" model={selectedModel} /> <ModelStatsCard label="New Model" model={selectedModel} />
)} )}
<SelectModelSearch selectedModel={selectedModel} setSelectedModel={setSelectedModel} /> <SelectModelSearch selectedModel={selectedModel} setSelectedModel={setSelectedModel} />
{isString(modifiedPromptFn) && (
<CompareFunctions
originalFunction={variant.constructFn}
newFunction={modifiedPromptFn}
leftTitle={originalModel}
rightTitle={convertedModel}
/>
)}
</VStack> </VStack>
</ModalBody> </ModalBody>
<ModalFooter> <ModalFooter>
<HStack>
<Button
colorScheme="gray"
onClick={getModifiedPromptFn}
minW={24}
isDisabled={originalModel === selectedModel || modificationInProgress}
>
{modificationInProgress ? <Spinner boxSize={4} /> : <Text>Convert</Text>}
</Button>
<Button <Button
colorScheme="blue" colorScheme="blue"
onClick={createNewVariant} onClick={replaceVariant}
minW={24} minW={24}
disabled={originalModel === selectedModel} isDisabled={!convertedModel || modificationInProgress || replacementInProgress}
> >
{creationInProgress ? <Spinner boxSize={4} /> : <Text>Continue</Text>} {replacementInProgress ? <Spinner boxSize={4} /> : <Text>Accept</Text>}
</Button> </Button>
</HStack>
</ModalFooter> </ModalFooter>
</ModalContent> </ModalContent>
</Modal> </Modal>

View File

@@ -18,7 +18,6 @@ import { useState } from "react";
import { RefinePromptModal } from "../RefinePromptModal/RefinePromptModal"; import { RefinePromptModal } from "../RefinePromptModal/RefinePromptModal";
import { RiExchangeFundsFill } from "react-icons/ri"; import { RiExchangeFundsFill } from "react-icons/ri";
import { SelectModelModal } from "../SelectModelModal/SelectModelModal"; import { SelectModelModal } from "../SelectModelModal/SelectModelModal";
import { type SupportedModel } from "~/server/types";
export default function VariantHeaderMenuButton({ export default function VariantHeaderMenuButton({
variant, variant,
@@ -99,11 +98,7 @@ export default function VariantHeaderMenuButton({
</MenuList> </MenuList>
</Menu> </Menu>
{selectModelModalOpen && ( {selectModelModalOpen && (
<SelectModelModal <SelectModelModal variant={variant} onClose={() => setSelectModelModalOpen(false)} />
originalModel={variant.model as SupportedModel}
variantId={variant.id}
onClose={() => setSelectModelModalOpen(false)}
/>
)} )}
{refinePromptModalOpen && ( {refinePromptModalOpen && (
<RefinePromptModal variant={variant} onClose={() => setRefinePromptModalOpen(false)} /> <RefinePromptModal variant={variant} onClose={() => setRefinePromptModalOpen(false)} />

View File

@@ -1,14 +1,15 @@
import openaiChatCompletionFrontend from "./openai-ChatCompletion/frontend"; import openaiChatCompletionFrontend from "./openai-ChatCompletion/frontend";
import replicateLlama2Frontend from "./replicate-llama2/frontend"; import replicateLlama2Frontend from "./replicate-llama2/frontend";
import { type SupportedProvider, type FrontendModelProvider } from "./types";
// TODO: make sure we get a typescript error if you forget to add a provider here // TODO: make sure we get a typescript error if you forget to add a provider here
// Keep attributes here that need to be accessible from the frontend. We can't // Keep attributes here that need to be accessible from the frontend. We can't
// just include them in the default `modelProviders` object because it has some // just include them in the default `modelProviders` object because it has some
// transient dependencies that can only be imported on the server. // transient dependencies that can only be imported on the server.
const modelProvidersFrontend = { const frontendModelProviders: Record<SupportedProvider, FrontendModelProvider<any, any>> = {
"openai/ChatCompletion": openaiChatCompletionFrontend, "openai/ChatCompletion": openaiChatCompletionFrontend,
"replicate/llama2": replicateLlama2Frontend, "replicate/llama2": replicateLlama2Frontend,
} as const; };
export default modelProvidersFrontend; export default frontendModelProviders;

View File

@@ -1,9 +1,10 @@
import openaiChatCompletion from "./openai-ChatCompletion"; import openaiChatCompletion from "./openai-ChatCompletion";
import replicateLlama2 from "./replicate-llama2"; import replicateLlama2 from "./replicate-llama2";
import { type SupportedProvider, type ModelProvider } from "./types";
const modelProviders = { const modelProviders: Record<SupportedProvider, ModelProvider<any, any, any>> = {
"openai/ChatCompletion": openaiChatCompletion, "openai/ChatCompletion": openaiChatCompletion,
"replicate/llama2": replicateLlama2, "replicate/llama2": replicateLlama2,
} as const; };
export default modelProviders; export default modelProviders;

View File

@@ -56,6 +56,14 @@ modelProperty.type = "string";
modelProperty.enum = modelProperty.oneOf[1].enum; modelProperty.enum = modelProperty.oneOf[1].enum;
delete modelProperty["oneOf"]; delete modelProperty["oneOf"];
// The default of "inf" confuses the Typescript generator, so can just remove it
assert(
"max_tokens" in completionRequestSchema.properties &&
isObject(completionRequestSchema.properties.max_tokens) &&
"default" in completionRequestSchema.properties.max_tokens,
);
delete completionRequestSchema.properties.max_tokens["default"];
// Get the directory of the current script // Get the directory of the current script
const currentDirectory = path.dirname(import.meta.url).replace("file://", ""); const currentDirectory = path.dirname(import.meta.url).replace("file://", "");

View File

@@ -150,7 +150,6 @@
}, },
"max_tokens": { "max_tokens": {
"description": "The maximum number of [tokens](/tokenizer) to generate in the chat completion.\n\nThe total length of input tokens and generated tokens is limited by the model's context length. [Example Python code](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb) for counting tokens.\n", "description": "The maximum number of [tokens](/tokenizer) to generate in the chat completion.\n\nThe total length of input tokens and generated tokens is limited by the model's context length. [Example Python code](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb) for counting tokens.\n",
"default": "inf",
"type": "integer" "type": "integer"
}, },
"presence_penalty": { "presence_penalty": {

View File

@@ -1,8 +1,30 @@
import { type JsonValue } from "type-fest"; import { type JsonValue } from "type-fest";
import { type OpenaiChatModelProvider } from "."; import { type SupportedModel } from ".";
import { type ModelProviderFrontend } from "../types"; import { type FrontendModelProvider } from "../types";
import { type ChatCompletion } from "openai/resources/chat";
const frontendModelProvider: FrontendModelProvider<SupportedModel, ChatCompletion> = {
name: "OpenAI ChatCompletion",
models: {
"gpt-4-0613": {
name: "GPT-4",
learnMore: "https://openai.com/gpt-4",
},
"gpt-4-32k-0613": {
name: "GPT-4 32k",
learnMore: "https://openai.com/gpt-4",
},
"gpt-3.5-turbo-0613": {
name: "GPT-3.5 Turbo",
learnMore: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
},
"gpt-3.5-turbo-16k-0613": {
name: "GPT-3.5 Turbo 16k",
learnMore: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
},
},
const modelProviderFrontend: ModelProviderFrontend<OpenaiChatModelProvider> = {
normalizeOutput: (output) => { normalizeOutput: (output) => {
const message = output.choices[0]?.message; const message = output.choices[0]?.message;
if (!message) if (!message)
@@ -39,4 +61,4 @@ const modelProviderFrontend: ModelProviderFrontend<OpenaiChatModelProvider> = {
}, },
}; };
export default modelProviderFrontend; export default frontendModelProvider;

View File

@@ -3,6 +3,7 @@ import { type ModelProvider } from "../types";
import inputSchema from "./codegen/input.schema.json"; import inputSchema from "./codegen/input.schema.json";
import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat"; import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat";
import { getCompletion } from "./getCompletion"; import { getCompletion } from "./getCompletion";
import frontendModelProvider from "./frontend";
const supportedModels = [ const supportedModels = [
"gpt-4-0613", "gpt-4-0613",
@@ -11,7 +12,7 @@ const supportedModels = [
"gpt-3.5-turbo-16k-0613", "gpt-3.5-turbo-16k-0613",
] as const; ] as const;
type SupportedModel = (typeof supportedModels)[number]; export type SupportedModel = (typeof supportedModels)[number];
export type OpenaiChatModelProvider = ModelProvider< export type OpenaiChatModelProvider = ModelProvider<
SupportedModel, SupportedModel,
@@ -20,25 +21,6 @@ export type OpenaiChatModelProvider = ModelProvider<
>; >;
const modelProvider: OpenaiChatModelProvider = { const modelProvider: OpenaiChatModelProvider = {
name: "OpenAI ChatCompletion",
models: {
"gpt-4-0613": {
name: "GPT-4",
learnMore: "https://openai.com/gpt-4",
},
"gpt-4-32k-0613": {
name: "GPT-4 32k",
learnMore: "https://openai.com/gpt-4",
},
"gpt-3.5-turbo-0613": {
name: "GPT-3.5 Turbo",
learnMore: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
},
"gpt-3.5-turbo-16k-0613": {
name: "GPT-3.5 Turbo 16k",
learnMore: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
},
},
getModel: (input) => { getModel: (input) => {
if (supportedModels.includes(input.model as SupportedModel)) if (supportedModels.includes(input.model as SupportedModel))
return input.model as SupportedModel; return input.model as SupportedModel;
@@ -57,6 +39,7 @@ const modelProvider: OpenaiChatModelProvider = {
inputSchema: inputSchema as JSONSchema4, inputSchema: inputSchema as JSONSchema4,
shouldStream: (input) => input.stream ?? false, shouldStream: (input) => input.stream ?? false,
getCompletion, getCompletion,
...frontendModelProvider,
}; };
export default modelProvider; export default modelProvider;

View File

@@ -1,7 +1,15 @@
import { type ReplicateLlama2Provider } from "."; import { type SupportedModel, type ReplicateLlama2Output } from ".";
import { type ModelProviderFrontend } from "../types"; import { type FrontendModelProvider } from "../types";
const frontendModelProvider: FrontendModelProvider<SupportedModel, ReplicateLlama2Output> = {
name: "Replicate Llama2",
models: {
"7b-chat": {},
"13b-chat": {},
"70b-chat": {},
},
const modelProviderFrontend: ModelProviderFrontend<ReplicateLlama2Provider> = {
normalizeOutput: (output) => { normalizeOutput: (output) => {
return { return {
type: "text", type: "text",
@@ -10,4 +18,4 @@ const modelProviderFrontend: ModelProviderFrontend<ReplicateLlama2Provider> = {
}, },
}; };
export default modelProviderFrontend; export default frontendModelProvider;

View File

@@ -1,9 +1,10 @@
import { type ModelProvider } from "../types"; import { type ModelProvider } from "../types";
import frontendModelProvider from "./frontend";
import { getCompletion } from "./getCompletion"; import { getCompletion } from "./getCompletion";
const supportedModels = ["7b-chat", "13b-chat", "70b-chat"] as const; const supportedModels = ["7b-chat", "13b-chat", "70b-chat"] as const;
type SupportedModel = (typeof supportedModels)[number]; export type SupportedModel = (typeof supportedModels)[number];
export type ReplicateLlama2Input = { export type ReplicateLlama2Input = {
model: SupportedModel; model: SupportedModel;
@@ -25,12 +26,6 @@ export type ReplicateLlama2Provider = ModelProvider<
>; >;
const modelProvider: ReplicateLlama2Provider = { const modelProvider: ReplicateLlama2Provider = {
name: "OpenAI ChatCompletion",
models: {
"7b-chat": {},
"13b-chat": {},
"70b-chat": {},
},
getModel: (input) => { getModel: (input) => {
if (supportedModels.includes(input.model)) return input.model; if (supportedModels.includes(input.model)) return input.model;
@@ -69,6 +64,7 @@ const modelProvider: ReplicateLlama2Provider = {
}, },
shouldStream: (input) => input.stream ?? false, shouldStream: (input) => input.stream ?? false,
getCompletion, getCompletion,
...frontendModelProvider,
}; };
export default modelProvider; export default modelProvider;

View File

@@ -1,11 +1,20 @@
import { type JSONSchema4 } from "json-schema"; import { type JSONSchema4 } from "json-schema";
import { type JsonValue } from "type-fest"; import { type JsonValue } from "type-fest";
type ModelProviderModel = { export type SupportedProvider = "openai/ChatCompletion" | "replicate/llama2";
type ModelInfo = {
name?: string; name?: string;
learnMore?: string; learnMore?: string;
}; };
export type FrontendModelProvider<SupportedModels extends string, OutputSchema> = {
name: string;
models: Record<SupportedModels, ModelInfo>;
normalizeOutput: (output: OutputSchema) => NormalizedOutput;
};
export type CompletionResponse<T> = export type CompletionResponse<T> =
| { type: "error"; message: string; autoRetry: boolean; statusCode?: number } | { type: "error"; message: string; autoRetry: boolean; statusCode?: number }
| { | {
@@ -19,8 +28,6 @@ export type CompletionResponse<T> =
}; };
export type ModelProvider<SupportedModels extends string, InputSchema, OutputSchema> = { export type ModelProvider<SupportedModels extends string, InputSchema, OutputSchema> = {
name: string;
models: Record<SupportedModels, ModelProviderModel>;
getModel: (input: InputSchema) => SupportedModels | null; getModel: (input: InputSchema) => SupportedModels | null;
shouldStream: (input: InputSchema) => boolean; shouldStream: (input: InputSchema) => boolean;
inputSchema: JSONSchema4; inputSchema: JSONSchema4;
@@ -31,7 +38,7 @@ export type ModelProvider<SupportedModels extends string, InputSchema, OutputSch
// This is just a convenience for type inference, don't use it at runtime // This is just a convenience for type inference, don't use it at runtime
_outputSchema?: OutputSchema | null; _outputSchema?: OutputSchema | null;
}; } & FrontendModelProvider<SupportedModels, OutputSchema>;
export type NormalizedOutput = export type NormalizedOutput =
| { | {
@@ -42,7 +49,3 @@ export type NormalizedOutput =
type: "json"; type: "json";
value: JsonValue; value: JsonValue;
}; };
export type ModelProviderFrontend<ModelProviderT extends ModelProvider<any, any, any>> = {
normalizeOutput: (output: NonNullable<ModelProviderT["_outputSchema"]>) => NormalizedOutput;
};

View File

@@ -20,11 +20,13 @@ export default function ExperimentsPage() {
const experiments = api.experiments.list.useQuery(); const experiments = api.experiments.list.useQuery();
const user = useSession().data; const user = useSession().data;
const authLoading = useSession().status === "loading";
if (user === null) { if (user === null || authLoading) {
return ( return (
<AppShell title="Experiments"> <AppShell title="Experiments">
<Center h="100%"> <Center h="100%">
{!authLoading && (
<Text> <Text>
<Link <Link
onClick={() => { onClick={() => {
@@ -36,6 +38,7 @@ export default function ExperimentsPage() {
</Link>{" "} </Link>{" "}
to view or create new experiments! to view or create new experiments!
</Text> </Text>
)}
</Center> </Center>
</AppShell> </AppShell>
); );

View File

@@ -98,7 +98,7 @@ export const experimentsRouter = createTRPCRouter({
}, },
}); });
const [variant, _, scenario] = await prisma.$transaction([ const [variant, _, scenario1, scenario2, scenario3] = await prisma.$transaction([
prisma.promptVariant.create({ prisma.promptVariant.create({
data: { data: {
experimentId: exp.id, experimentId: exp.id,
@@ -121,7 +121,7 @@ export const experimentsRouter = createTRPCRouter({
messages: [ messages: [
{ {
role: "system", role: "system",
content: \`"Return 'this is output for the scenario "${"$"}{scenario.text}"'\`, content: \`Write 'Start experimenting!' in ${"$"}{scenario.language}\`,
}, },
], ],
});`, });`,
@@ -133,20 +133,38 @@ export const experimentsRouter = createTRPCRouter({
prisma.templateVariable.create({ prisma.templateVariable.create({
data: { data: {
experimentId: exp.id, experimentId: exp.id,
label: "text", label: "language",
}, },
}), }),
prisma.testScenario.create({ prisma.testScenario.create({
data: { data: {
experimentId: exp.id, experimentId: exp.id,
variableValues: { variableValues: {
text: "This is a test scenario.", language: "English",
},
},
}),
prisma.testScenario.create({
data: {
experimentId: exp.id,
variableValues: {
language: "Spanish",
},
},
}),
prisma.testScenario.create({
data: {
experimentId: exp.id,
variableValues: {
language: "German",
}, },
}, },
}), }),
]); ]);
await generateNewCell(variant.id, scenario.id); await generateNewCell(variant.id, scenario1.id);
await generateNewCell(variant.id, scenario2.id);
await generateNewCell(variant.id, scenario3.id);
return exp; return exp;
}), }),

View File

@@ -284,11 +284,12 @@ export const promptVariantsRouter = createTRPCRouter({
return updatedPromptVariant; return updatedPromptVariant;
}), }),
getRefinedPromptFn: protectedProcedure getModifiedPromptFn: protectedProcedure
.input( .input(
z.object({ z.object({
id: z.string(), id: z.string(),
instructions: z.string(), instructions: z.string().optional(),
newModel: z.string().optional(),
}), }),
) )
.mutation(async ({ input, ctx }) => { .mutation(async ({ input, ctx }) => {
@@ -307,7 +308,7 @@ export const promptVariantsRouter = createTRPCRouter({
const promptConstructionFn = await deriveNewConstructFn( const promptConstructionFn = await deriveNewConstructFn(
existing, existing,
constructedPrompt.model as SupportedModel, input.newModel as SupportedModel | undefined,
input.instructions, input.instructions,
); );

View File

@@ -99,7 +99,6 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
const provider = modelProviders[prompt.modelProvider]; const provider = modelProviders[prompt.modelProvider];
// @ts-expect-error TODO FIX ASAP
const streamingChannel = provider.shouldStream(prompt.modelInput) ? generateChannel() : null; const streamingChannel = provider.shouldStream(prompt.modelInput) ? generateChannel() : null;
if (streamingChannel) { if (streamingChannel) {
@@ -116,8 +115,6 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
: null; : null;
for (let i = 0; true; i++) { for (let i = 0; true; i++) {
// @ts-expect-error TODO FIX ASAP
const response = await provider.getCompletion(prompt.modelInput, onStream); const response = await provider.getCompletion(prompt.modelInput, onStream);
if (response.type === "success") { if (response.type === "success") {
const inputHash = hashPrompt(prompt); const inputHash = hashPrompt(prompt);
@@ -126,7 +123,7 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
data: { data: {
scenarioVariantCellId, scenarioVariantCellId,
inputHash, inputHash,
output: response.value as unknown as Prisma.InputJsonObject, output: response.value as Prisma.InputJsonObject,
timeToComplete: response.timeToComplete, timeToComplete: response.timeToComplete,
promptTokens: response.promptTokens, promptTokens: response.promptTokens,
completionTokens: response.completionTokens, completionTokens: response.completionTokens,
@@ -154,7 +151,7 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
errorMessage: response.message, errorMessage: response.message,
statusCode: response.statusCode, statusCode: response.statusCode,
retryTime: shouldRetry ? new Date(Date.now() + delay) : null, retryTime: shouldRetry ? new Date(Date.now() + delay) : null,
retrievalStatus: shouldRetry ? "PENDING" : "ERROR", retrievalStatus: "ERROR",
}, },
}); });

View File

@@ -52,11 +52,15 @@ const requestUpdatedPromptFunction = async (
2, 2,
)}\n\nDo not add any assistant messages.`, )}\n\nDo not add any assistant messages.`,
}, },
{
role: "user",
content: `This is the current prompt constructor function:\n---\n${originalVariant.constructFn}`,
},
]; ];
if (newModel) { if (newModel) {
messages.push({ messages.push({
role: "user", role: "user",
content: `Return the prompt constructor function for ${newModel} given the following prompt constructor function for ${originalModel}:\n---\n${originalVariant.constructFn}`, content: `Return the prompt constructor function for ${newModel} given the existing prompt constructor function for ${originalModel}`,
}); });
} }
if (instructions) { if (instructions) {
@@ -65,10 +69,6 @@ const requestUpdatedPromptFunction = async (
content: instructions, content: instructions,
}); });
} }
messages.push({
role: "system",
content: "The prompt variable has already been declared, so do not declare it again.",
});
const completion = await openai.chat.completions.create({ const completion = await openai.chat.completions.create({
model: "gpt-4", model: "gpt-4",
messages, messages,

View File

@@ -70,7 +70,6 @@ export default async function parseConstructFn(
// We've validated the JSON schema so this should be safe // We've validated the JSON schema so this should be safe
const input = prompt.input as Parameters<(typeof provider)["getModel"]>[0]; const input = prompt.input as Parameters<(typeof provider)["getModel"]>[0];
// @ts-expect-error TODO FIX ASAP
const model = provider.getModel(input); const model = provider.getModel(input);
if (!model) { if (!model) {
return { return {
@@ -80,8 +79,6 @@ export default async function parseConstructFn(
return { return {
modelProvider: prompt.modelProvider as keyof typeof modelProviders, modelProvider: prompt.modelProvider as keyof typeof modelProviders,
// @ts-expect-error TODO FIX ASAP
model, model,
modelInput: input, modelInput: input,
}; };