Compare commits: model-prov...bugfix-max
15 Commits
| Author | SHA1 | Date |
|---|---|---|
| | 213a00a8e6 | |
| | af9943eefc | |
| | 741128e0f4 | |
| | aff14539d8 | |
| | 1af81a50a9 | |
| | 7e1fbb3767 | |
| | a5d972005e | |
| | a180b5bef2 | |
| | 55c697223e | |
| | 9978075867 | |
| | 372c2512c9 | |
| | 1822fe198e | |
| | f06e1db3db | |
| | 9314a86857 | |
| | 54dcb4a567 | |
@@ -17,6 +17,9 @@ DATABASE_URL="postgresql://postgres:postgres@localhost:5432/openpipe?schema=publ
 # https://help.openai.com/en/articles/4936850-where-do-i-find-my-secret-api-key
 OPENAI_API_KEY=""
 
+# Replicate API token. Create a token here: https://replicate.com/account/api-tokens
+REPLICATE_API_TOKEN=""
+
 NEXT_PUBLIC_SOCKET_URL="http://localhost:3318"
 
 # Next Auth
@@ -1,19 +1,22 @@
 import { Textarea, type TextareaProps } from "@chakra-ui/react";
 import ResizeTextarea from "react-textarea-autosize";
-import React from "react";
+import React, { useLayoutEffect, useState } from "react";
 
 export const AutoResizeTextarea: React.ForwardRefRenderFunction<
   HTMLTextAreaElement,
   TextareaProps & { minRows?: number }
-> = (props, ref) => {
+> = ({ minRows = 1, overflowY = "hidden", ...props }, ref) => {
+  const [isRerendered, setIsRerendered] = useState(false);
+  useLayoutEffect(() => setIsRerendered(true), []);
+
   return (
     <Textarea
       minH="unset"
-      overflow="hidden"
+      minRows={minRows}
+      overflowY={isRerendered ? overflowY : "hidden"}
       w="100%"
       resize="none"
       ref={ref}
-      minRows={1}
       transition="height none"
       as={ResizeTextarea}
       {...props}
@@ -18,11 +18,9 @@ export const FloatingLabelInput = ({
   transform={isFocused || !!value ? "translateY(-50%)" : "translateY(0)"}
   fontSize={isFocused || !!value ? "12px" : "16px"}
   transition="all 0.15s"
-  zIndex="100"
+  zIndex="5"
   bg="white"
   px={1}
-  mt={0}
-  mb={2}
   lineHeight="1"
   pointerEvents="none"
   color={isFocused ? "blue.500" : "gray.500"}
@@ -49,7 +49,11 @@ export default function NewScenarioButton() {
         Add Scenario
       </StyledButton>
       <StyledButton onClick={onAutogenerate}>
-        <Icon as={autogenerating ? Spinner : BsPlus} boxSize={6} mr={autogenerating ? 1 : 0} />
+        <Icon
+          as={autogenerating ? Spinner : BsPlus}
+          boxSize={autogenerating ? 4 : 6}
+          mr={autogenerating ? 2 : 0}
+        />
         Autogenerate Scenario
       </StyledButton>
     </HStack>
@@ -10,7 +10,7 @@ import useSocket from "~/utils/useSocket";
 import { OutputStats } from "./OutputStats";
 import { ErrorHandler } from "./ErrorHandler";
 import { CellOptions } from "./CellOptions";
-import modelProvidersFrontend from "~/modelProviders/modelProvidersFrontend";
+import frontendModelProviders from "~/modelProviders/frontendModelProviders";
 
 export default function OutputCell({
   scenario,
@@ -40,7 +40,7 @@ export default function OutputCell({
   );
 
   const provider =
-    modelProvidersFrontend[variant.modelProvider as keyof typeof modelProvidersFrontend];
+    frontendModelProviders[variant.modelProvider as keyof typeof frontendModelProviders];
 
   type OutputSchema = Parameters<typeof provider.normalizeOutput>[0];
 
@@ -88,11 +88,9 @@ export default function OutputCell({
   }
 
   const normalizedOutput = modelOutput
-    ? // @ts-expect-error TODO FIX ASAP
-      provider.normalizeOutput(modelOutput.output as unknown as OutputSchema)
+    ? provider.normalizeOutput(modelOutput.output)
     : streamedMessage
-    ? // @ts-expect-error TODO FIX ASAP
-      provider.normalizeOutput(streamedMessage)
+    ? provider.normalizeOutput(streamedMessage)
     : null;
 
   if (modelOutput && normalizedOutput?.type === "json") {
@@ -4,5 +4,5 @@ export const stickyHeaderStyle: SystemStyleObject = {
   position: "sticky",
   top: "0",
   backgroundColor: "#fff",
-  zIndex: 1,
+  zIndex: 10,
 };
@@ -1,4 +1,4 @@
-import { HStack, VStack, useBreakpointValue } from "@chakra-ui/react";
+import { type StackProps, VStack, useBreakpointValue } from "@chakra-ui/react";
 import React from "react";
 import DiffViewer, { DiffMethod } from "react-diff-viewer";
 import Prism from "prismjs";
@@ -19,10 +19,15 @@ const highlightSyntax = (str: string) => {
 const CompareFunctions = ({
   originalFunction,
   newFunction = "",
+  leftTitle = "Original",
+  rightTitle = "Modified",
+  ...props
 }: {
   originalFunction: string;
   newFunction?: string;
-}) => {
+  leftTitle?: string;
+  rightTitle?: string;
+} & StackProps) => {
   const showSplitView = useBreakpointValue(
     {
       base: false,
@@ -34,22 +39,20 @@ const CompareFunctions = ({
   );
 
   return (
-    <HStack w="full" spacing={5}>
-      <VStack w="full" spacing={4} maxH="40vh" fontSize={12} lineHeight={1} overflowY="auto">
-        <DiffViewer
-          oldValue={originalFunction}
-          newValue={newFunction || originalFunction}
-          splitView={showSplitView}
-          hideLineNumbers={!showSplitView}
-          leftTitle="Original"
-          rightTitle={newFunction ? "Modified" : "Unmodified"}
-          disableWordDiff={true}
-          compareMethod={DiffMethod.CHARS}
-          renderContent={highlightSyntax}
-          showDiffOnly={false}
-        />
-      </VStack>
-    </HStack>
+    <VStack w="full" spacing={4} fontSize={12} lineHeight={1} overflowY="auto" {...props}>
+      <DiffViewer
+        oldValue={originalFunction}
+        newValue={newFunction || originalFunction}
+        splitView={showSplitView}
+        hideLineNumbers={!showSplitView}
+        leftTitle={leftTitle}
+        rightTitle={rightTitle}
+        disableWordDiff={true}
+        compareMethod={DiffMethod.CHARS}
+        renderContent={highlightSyntax}
+        showDiffOnly={false}
+      />
+    </VStack>
   );
 };
@@ -56,7 +56,6 @@ export const CustomInstructionsInput = ({
   minW="unset"
   size="sm"
   onClick={() => onSubmit()}
-  disabled={!instructions}
   variant={instructions ? "solid" : "ghost"}
   mr={4}
   borderRadius="8"
@@ -36,25 +36,25 @@ export const RefinePromptModal = ({
 }) => {
   const utils = api.useContext();
 
-  const { mutateAsync: getRefinedPromptMutateAsync, data: refinedPromptFn } =
-    api.promptVariants.getRefinedPromptFn.useMutation();
+  const { mutateAsync: getModifiedPromptMutateAsync, data: refinedPromptFn } =
+    api.promptVariants.getModifiedPromptFn.useMutation();
   const [instructions, setInstructions] = useState<string>("");
 
   const [activeRefineOptionLabel, setActiveRefineOptionLabel] = useState<
     RefineOptionLabel | undefined
   >(undefined);
 
-  const [getRefinedPromptFn, refiningInProgress] = useHandledAsyncCallback(
+  const [getModifiedPromptFn, modificationInProgress] = useHandledAsyncCallback(
     async (label?: RefineOptionLabel) => {
       if (!variant.experimentId) return;
       const updatedInstructions = label ? refineOptions[label].instructions : instructions;
       setActiveRefineOptionLabel(label);
-      await getRefinedPromptMutateAsync({
+      await getModifiedPromptMutateAsync({
         id: variant.id,
         instructions: updatedInstructions,
       });
     },
-    [getRefinedPromptMutateAsync, onClose, variant, instructions, setActiveRefineOptionLabel],
+    [getModifiedPromptMutateAsync, onClose, variant, instructions, setActiveRefineOptionLabel],
   );
 
   const replaceVariantMutation = api.promptVariants.replaceVariant.useMutation();
@@ -75,7 +75,11 @@ export const RefinePromptModal = ({
   }, [replaceVariantMutation, variant, onClose, refinedPromptFn]);
 
   return (
-    <Modal isOpen onClose={onClose} size={{ base: "xl", sm: "2xl", md: "7xl" }}>
+    <Modal
+      isOpen
+      onClose={onClose}
+      size={{ base: "xl", sm: "2xl", md: "3xl", lg: "5xl", xl: "7xl" }}
+    >
       <ModalOverlay />
       <ModalContent w={1200}>
         <ModalHeader>
@@ -93,15 +97,15 @@ export const RefinePromptModal = ({
             label="Convert to function call"
             activeLabel={activeRefineOptionLabel}
             icon={VscJson}
-            onClick={getRefinedPromptFn}
-            loading={refiningInProgress}
+            onClick={getModifiedPromptFn}
+            loading={modificationInProgress}
           />
           <RefineOption
             label="Add chain of thought"
             activeLabel={activeRefineOptionLabel}
             icon={TfiThought}
-            onClick={getRefinedPromptFn}
-            loading={refiningInProgress}
+            onClick={getModifiedPromptFn}
+            loading={modificationInProgress}
           />
         </SimpleGrid>
         <HStack>
@@ -110,13 +114,14 @@ export const RefinePromptModal = ({
           <CustomInstructionsInput
             instructions={instructions}
             setInstructions={setInstructions}
-            loading={refiningInProgress}
-            onSubmit={getRefinedPromptFn}
+            loading={modificationInProgress}
+            onSubmit={getModifiedPromptFn}
           />
         </VStack>
         <CompareFunctions
           originalFunction={variant.constructFn}
           newFunction={isString(refinedPromptFn) ? refinedPromptFn : undefined}
+          maxH="40vh"
         />
       </VStack>
     </ModalBody>
@@ -124,12 +129,10 @@ export const RefinePromptModal = ({
     <ModalFooter>
       <HStack spacing={4}>
         <Button
           colorScheme="blue"
           onClick={replaceVariant}
           minW={24}
-          disabled={replacementInProgress || !refinedPromptFn}
-          _disabled={{
-            bgColor: "blue.500",
-          }}
+          isDisabled={replacementInProgress || !refinedPromptFn}
         >
           {replacementInProgress ? <Spinner boxSize={4} /> : <Text>Accept</Text>}
         </Button>
@@ -12,7 +12,7 @@ export const refineOptions: Record<
 
 This is what a prompt looks like before adding chain of thought:
 
-prompt = {
+definePrompt("openai/ChatCompletion", {
   model: "gpt-4",
   stream: true,
   messages: [
@@ -25,11 +25,11 @@ export const refineOptions: Record<
       content: \`This is the user's message: \${scenario.user_message}. Return "positive" or "negative" or "neutral"\`,
     },
   ],
-};
+});
 
 This is what one looks like after adding chain of thought:
 
-prompt = {
+definePrompt("openai/ChatCompletion", {
   model: "gpt-4",
   stream: true,
   messages: [
@@ -42,13 +42,13 @@ export const refineOptions: Record<
       content: \`This is the user's message: \${scenario.user_message}. Return "positive" or "negative" or "neutral". Explain your answer before you give a score, then return the score on a new line.\`,
     },
   ],
-};
+});
 
 Here's another example:
 
 Before:
 
-prompt = {
+definePrompt("openai/ChatCompletion", {
   model: "gpt-3.5-turbo",
   messages: [
     {
@@ -78,11 +78,11 @@ export const refineOptions: Record<
   function_call: {
     name: "score_post",
   },
-};
+});
 
 After:
 
-prompt = {
+definePrompt("openai/ChatCompletion", {
   model: "gpt-3.5-turbo",
   messages: [
     {
@@ -115,7 +115,7 @@ export const refineOptions: Record<
   function_call: {
     name: "score_post",
   },
-};
+});
 
 Add chain of thought to the original prompt.`,
   },
@@ -125,7 +125,7 @@ export const refineOptions: Record<
 
 This is what a prompt looks like before adding a function:
 
-prompt = {
+definePrompt("openai/ChatCompletion", {
   model: "gpt-4",
   stream: true,
   messages: [
@@ -138,11 +138,11 @@ export const refineOptions: Record<
       content: \`This is the user's message: \${scenario.user_message}. Return "positive" or "negative" or "neutral"\`,
     },
   ],
-};
+});
 
 This is what one looks like after adding a function:
 
-prompt = {
+definePrompt("openai/ChatCompletion", {
   model: "gpt-4",
   stream: true,
   messages: [
@@ -172,13 +172,13 @@ export const refineOptions: Record<
   function_call: {
     name: "extract_sentiment",
   },
-};
+});
 
 Here's another example of adding a function:
 
 Before:
 
-prompt = {
+definePrompt("openai/ChatCompletion", {
   model: "gpt-3.5-turbo",
   messages: [
     {
@@ -196,11 +196,11 @@ export const refineOptions: Record<
     },
   ],
   temperature: 0,
-};
+});
 
 After:
 
-prompt = {
+definePrompt("openai/ChatCompletion", {
   model: "gpt-3.5-turbo",
   messages: [
     {
@@ -230,7 +230,7 @@ export const refineOptions: Record<
   function_call: {
     name: "score_post",
   },
-};
+});
 
 Add an OpenAI function that takes one or more nested parameters that match the expected output from this prompt.`,
   },
@@ -20,36 +20,60 @@ import { ModelStatsCard } from "./ModelStatsCard";
 import { SelectModelSearch } from "./SelectModelSearch";
 import { api } from "~/utils/api";
 import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
+import CompareFunctions from "../RefinePromptModal/CompareFunctions";
+import { type PromptVariant } from "@prisma/client";
+import { isObject, isString } from "lodash-es";
 
 export const SelectModelModal = ({
-  originalModel,
-  variantId,
+  variant,
   onClose,
 }: {
-  originalModel: SupportedModel;
-  variantId: string;
+  variant: PromptVariant;
   onClose: () => void;
 }) => {
+  const originalModel = variant.model as SupportedModel;
   const [selectedModel, setSelectedModel] = useState<SupportedModel>(originalModel);
+  const [convertedModel, setConvertedModel] = useState<SupportedModel | undefined>(undefined);
   const utils = api.useContext();
 
   const experiment = useExperiment();
 
-  const createMutation = api.promptVariants.create.useMutation();
+  const { mutateAsync: getModifiedPromptMutateAsync, data: modifiedPromptFn } =
+    api.promptVariants.getModifiedPromptFn.useMutation();
 
-  const [createNewVariant, creationInProgress] = useHandledAsyncCallback(async () => {
-    if (!experiment?.data?.id) return;
-    await createMutation.mutateAsync({
-      experimentId: experiment?.data?.id,
-      variantId,
+  const [getModifiedPromptFn, modificationInProgress] = useHandledAsyncCallback(async () => {
+    if (!experiment) return;
+
+    await getModifiedPromptMutateAsync({
+      id: variant.id,
+      newModel: selectedModel,
     });
+    setConvertedModel(selectedModel);
+  }, [getModifiedPromptMutateAsync, onClose, experiment, variant, selectedModel]);
+
+  const replaceVariantMutation = api.promptVariants.replaceVariant.useMutation();
+
+  const [replaceVariant, replacementInProgress] = useHandledAsyncCallback(async () => {
+    if (
+      !variant.experimentId ||
+      !modifiedPromptFn ||
+      (isObject(modifiedPromptFn) && "status" in modifiedPromptFn)
+    )
+      return;
+    await replaceVariantMutation.mutateAsync({
+      id: variant.id,
+      constructFn: modifiedPromptFn,
+    });
     await utils.promptVariants.list.invalidate();
     onClose();
-  }, [createMutation, experiment?.data?.id, variantId, onClose]);
+  }, [replaceVariantMutation, variant, onClose, modifiedPromptFn]);
 
   return (
-    <Modal isOpen onClose={onClose} size={{ base: "xl", sm: "2xl", md: "3xl" }}>
+    <Modal
+      isOpen
+      onClose={onClose}
+      size={{ base: "xl", sm: "2xl", md: "3xl", lg: "5xl", xl: "7xl" }}
+    >
       <ModalOverlay />
       <ModalContent w={1200}>
         <ModalHeader>
@@ -66,18 +90,36 @@ export const SelectModelModal = ({
           <ModelStatsCard label="New Model" model={selectedModel} />
         )}
         <SelectModelSearch selectedModel={selectedModel} setSelectedModel={setSelectedModel} />
+        {isString(modifiedPromptFn) && (
+          <CompareFunctions
+            originalFunction={variant.constructFn}
+            newFunction={modifiedPromptFn}
+            leftTitle={originalModel}
+            rightTitle={convertedModel}
+          />
+        )}
       </VStack>
     </ModalBody>
 
     <ModalFooter>
-      <Button
-        colorScheme="blue"
-        onClick={createNewVariant}
-        minW={24}
-        disabled={originalModel === selectedModel}
-      >
-        {creationInProgress ? <Spinner boxSize={4} /> : <Text>Continue</Text>}
-      </Button>
+      <HStack>
+        <Button
+          colorScheme="gray"
+          onClick={getModifiedPromptFn}
+          minW={24}
+          isDisabled={originalModel === selectedModel || modificationInProgress}
+        >
+          {modificationInProgress ? <Spinner boxSize={4} /> : <Text>Convert</Text>}
+        </Button>
+        <Button
+          colorScheme="blue"
+          onClick={replaceVariant}
+          minW={24}
+          isDisabled={!convertedModel || modificationInProgress || replacementInProgress}
+        >
+          {replacementInProgress ? <Spinner boxSize={4} /> : <Text>Accept</Text>}
+        </Button>
+      </HStack>
     </ModalFooter>
   </ModalContent>
 </Modal>
@@ -18,7 +18,6 @@ import { useState } from "react";
 import { RefinePromptModal } from "../RefinePromptModal/RefinePromptModal";
 import { RiExchangeFundsFill } from "react-icons/ri";
 import { SelectModelModal } from "../SelectModelModal/SelectModelModal";
-import { type SupportedModel } from "~/server/types";
 
 export default function VariantHeaderMenuButton({
   variant,
@@ -99,11 +98,7 @@
       </MenuList>
     </Menu>
     {selectModelModalOpen && (
-      <SelectModelModal
-        originalModel={variant.model as SupportedModel}
-        variantId={variant.id}
-        onClose={() => setSelectModelModalOpen(false)}
-      />
+      <SelectModelModal variant={variant} onClose={() => setSelectModelModalOpen(false)} />
     )}
     {refinePromptModalOpen && (
       <RefinePromptModal variant={variant} onClose={() => setRefinePromptModalOpen(false)} />
@@ -1,14 +1,15 @@
 import openaiChatCompletionFrontend from "./openai-ChatCompletion/frontend";
 import replicateLlama2Frontend from "./replicate-llama2/frontend";
+import { type SupportedProvider, type FrontendModelProvider } from "./types";
 
 // TODO: make sure we get a typescript error if you forget to add a provider here
 
 // Keep attributes here that need to be accessible from the frontend. We can't
 // just include them in the default `modelProviders` object because it has some
 // transient dependencies that can only be imported on the server.
-const modelProvidersFrontend = {
+const frontendModelProviders: Record<SupportedProvider, FrontendModelProvider<any, any>> = {
   "openai/ChatCompletion": openaiChatCompletionFrontend,
   "replicate/llama2": replicateLlama2Frontend,
-} as const;
+};
 
-export default modelProvidersFrontend;
+export default frontendModelProviders;
@@ -1,9 +1,10 @@
 import openaiChatCompletion from "./openai-ChatCompletion";
 import replicateLlama2 from "./replicate-llama2";
+import { type SupportedProvider, type ModelProvider } from "./types";
 
-const modelProviders = {
+const modelProviders: Record<SupportedProvider, ModelProvider<any, any, any>> = {
   "openai/ChatCompletion": openaiChatCompletion,
   "replicate/llama2": replicateLlama2,
-} as const;
+};
 
 export default modelProviders;
@@ -56,6 +56,14 @@ modelProperty.type = "string";
 modelProperty.enum = modelProperty.oneOf[1].enum;
 delete modelProperty["oneOf"];
 
+// The default of "inf" confuses the Typescript generator, so can just remove it
+assert(
+  "max_tokens" in completionRequestSchema.properties &&
+    isObject(completionRequestSchema.properties.max_tokens) &&
+    "default" in completionRequestSchema.properties.max_tokens,
+);
+delete completionRequestSchema.properties.max_tokens["default"];
+
 // Get the directory of the current script
 const currentDirectory = path.dirname(import.meta.url).replace("file://", "");
 
@@ -150,7 +150,6 @@
   },
   "max_tokens": {
     "description": "The maximum number of [tokens](/tokenizer) to generate in the chat completion.\n\nThe total length of input tokens and generated tokens is limited by the model's context length. [Example Python code](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb) for counting tokens.\n",
-    "default": "inf",
     "type": "integer"
   },
   "presence_penalty": {
@@ -1,8 +1,30 @@
 import { type JsonValue } from "type-fest";
-import { type OpenaiChatModelProvider } from ".";
-import { type ModelProviderFrontend } from "../types";
+import { type SupportedModel } from ".";
+import { type FrontendModelProvider } from "../types";
 import { type ChatCompletion } from "openai/resources/chat";
 
-const modelProviderFrontend: ModelProviderFrontend<OpenaiChatModelProvider> = {
+const frontendModelProvider: FrontendModelProvider<SupportedModel, ChatCompletion> = {
+  name: "OpenAI ChatCompletion",
+
+  models: {
+    "gpt-4-0613": {
+      name: "GPT-4",
+      learnMore: "https://openai.com/gpt-4",
+    },
+    "gpt-4-32k-0613": {
+      name: "GPT-4 32k",
+      learnMore: "https://openai.com/gpt-4",
+    },
+    "gpt-3.5-turbo-0613": {
+      name: "GPT-3.5 Turbo",
+      learnMore: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
+    },
+    "gpt-3.5-turbo-16k-0613": {
+      name: "GPT-3.5 Turbo 16k",
+      learnMore: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
+    },
+  },
+
   normalizeOutput: (output) => {
     const message = output.choices[0]?.message;
     if (!message)
@@ -39,4 +61,4 @@ const modelProviderFrontend: ModelProviderFrontend<OpenaiChatModelProvider> = {
   },
 };
 
-export default modelProviderFrontend;
+export default frontendModelProvider;
@@ -3,6 +3,7 @@ import { type ModelProvider } from "../types";
 import inputSchema from "./codegen/input.schema.json";
 import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat";
 import { getCompletion } from "./getCompletion";
+import frontendModelProvider from "./frontend";
 
 const supportedModels = [
   "gpt-4-0613",
@@ -11,7 +12,7 @@ const supportedModels = [
   "gpt-3.5-turbo-16k-0613",
 ] as const;
 
-type SupportedModel = (typeof supportedModels)[number];
+export type SupportedModel = (typeof supportedModels)[number];
 
 export type OpenaiChatModelProvider = ModelProvider<
   SupportedModel,
@@ -20,25 +21,6 @@ export type OpenaiChatModelProvider = ModelProvider<
 >;
 
 const modelProvider: OpenaiChatModelProvider = {
-  name: "OpenAI ChatCompletion",
-  models: {
-    "gpt-4-0613": {
-      name: "GPT-4",
-      learnMore: "https://openai.com/gpt-4",
-    },
-    "gpt-4-32k-0613": {
-      name: "GPT-4 32k",
-      learnMore: "https://openai.com/gpt-4",
-    },
-    "gpt-3.5-turbo-0613": {
-      name: "GPT-3.5 Turbo",
-      learnMore: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
-    },
-    "gpt-3.5-turbo-16k-0613": {
-      name: "GPT-3.5 Turbo 16k",
-      learnMore: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
-    },
-  },
   getModel: (input) => {
     if (supportedModels.includes(input.model as SupportedModel))
       return input.model as SupportedModel;
@@ -57,6 +39,7 @@ const modelProvider: OpenaiChatModelProvider = {
   inputSchema: inputSchema as JSONSchema4,
   shouldStream: (input) => input.stream ?? false,
   getCompletion,
+  ...frontendModelProvider,
 };
 
 export default modelProvider;
@@ -1,7 +1,15 @@
-import { type ReplicateLlama2Provider } from ".";
-import { type ModelProviderFrontend } from "../types";
+import { type SupportedModel, type ReplicateLlama2Output } from ".";
+import { type FrontendModelProvider } from "../types";
 
-const modelProviderFrontend: ModelProviderFrontend<ReplicateLlama2Provider> = {
+const frontendModelProvider: FrontendModelProvider<SupportedModel, ReplicateLlama2Output> = {
+  name: "Replicate Llama2",
+
+  models: {
+    "7b-chat": {},
+    "13b-chat": {},
+    "70b-chat": {},
+  },
+
   normalizeOutput: (output) => {
     return {
       type: "text",
@@ -10,4 +18,4 @@ const modelProviderFrontend: ModelProviderFrontend<ReplicateLlama2Provider> = {
   },
 };
 
-export default modelProviderFrontend;
+export default frontendModelProvider;
@@ -1,9 +1,10 @@
 import { type ModelProvider } from "../types";
+import frontendModelProvider from "./frontend";
 import { getCompletion } from "./getCompletion";
 
 const supportedModels = ["7b-chat", "13b-chat", "70b-chat"] as const;
 
-type SupportedModel = (typeof supportedModels)[number];
+export type SupportedModel = (typeof supportedModels)[number];
 
 export type ReplicateLlama2Input = {
   model: SupportedModel;
@@ -25,12 +26,6 @@ export type ReplicateLlama2Provider = ModelProvider<
 >;
 
 const modelProvider: ReplicateLlama2Provider = {
-  name: "OpenAI ChatCompletion",
-  models: {
-    "7b-chat": {},
-    "13b-chat": {},
-    "70b-chat": {},
-  },
   getModel: (input) => {
     if (supportedModels.includes(input.model)) return input.model;
 
@@ -69,6 +64,7 @@ const modelProvider: ReplicateLlama2Provider = {
   },
   shouldStream: (input) => input.stream ?? false,
   getCompletion,
+  ...frontendModelProvider,
 };
 
 export default modelProvider;
@@ -1,11 +1,20 @@
 import { type JSONSchema4 } from "json-schema";
 import { type JsonValue } from "type-fest";
 
-type ModelProviderModel = {
+export type SupportedProvider = "openai/ChatCompletion" | "replicate/llama2";
+
+type ModelInfo = {
   name?: string;
   learnMore?: string;
 };
 
+export type FrontendModelProvider<SupportedModels extends string, OutputSchema> = {
+  name: string;
+  models: Record<SupportedModels, ModelInfo>;
+
+  normalizeOutput: (output: OutputSchema) => NormalizedOutput;
+};
+
 export type CompletionResponse<T> =
   | { type: "error"; message: string; autoRetry: boolean; statusCode?: number }
   | {
@@ -19,8 +28,6 @@ export type CompletionResponse<T> =
     };
 
 export type ModelProvider<SupportedModels extends string, InputSchema, OutputSchema> = {
-  name: string;
-  models: Record<SupportedModels, ModelProviderModel>;
   getModel: (input: InputSchema) => SupportedModels | null;
   shouldStream: (input: InputSchema) => boolean;
   inputSchema: JSONSchema4;
@@ -31,7 +38,7 @@ export type ModelProvider<SupportedModels extends string, InputSchema, OutputSch
 
   // This is just a convenience for type inference, don't use it at runtime
   _outputSchema?: OutputSchema | null;
-};
+} & FrontendModelProvider<SupportedModels, OutputSchema>;
 
 export type NormalizedOutput =
   | {
@@ -42,7 +49,3 @@ export type NormalizedOutput =
       type: "json";
       value: JsonValue;
     };
-
-export type ModelProviderFrontend<ModelProviderT extends ModelProvider<any, any, any>> = {
-  normalizeOutput: (output: NonNullable<ModelProviderT["_outputSchema"]>) => NormalizedOutput;
-};
@@ -20,22 +20,25 @@ export default function ExperimentsPage() {
   const experiments = api.experiments.list.useQuery();
 
   const user = useSession().data;
+  const authLoading = useSession().status === "loading";
 
-  if (user === null) {
+  if (user === null || authLoading) {
     return (
       <AppShell title="Experiments">
         <Center h="100%">
-          <Text>
-            <Link
-              onClick={() => {
-                signIn("github").catch(console.error);
-              }}
-              textDecor="underline"
-            >
-              Sign in
-            </Link>{" "}
-            to view or create new experiments!
-          </Text>
+          {!authLoading && (
+            <Text>
+              <Link
+                onClick={() => {
+                  signIn("github").catch(console.error);
+                }}
+                textDecor="underline"
+              >
+                Sign in
+              </Link>{" "}
+              to view or create new experiments!
+            </Text>
+          )}
         </Center>
       </AppShell>
     );
@@ -98,7 +98,7 @@ export const experimentsRouter = createTRPCRouter({
     },
   });
 
-  const [variant, _, scenario] = await prisma.$transaction([
+  const [variant, _, scenario1, scenario2, scenario3] = await prisma.$transaction([
     prisma.promptVariant.create({
       data: {
         experimentId: exp.id,
@@ -121,7 +121,7 @@ export const experimentsRouter = createTRPCRouter({
   messages: [
     {
       role: "system",
-      content: \`"Return 'this is output for the scenario "${"$"}{scenario.text}"'\`,
+      content: \`Write 'Start experimenting!' in ${"$"}{scenario.language}\`,
     },
   ],
 });`,
@@ -133,20 +133,38 @@ export const experimentsRouter = createTRPCRouter({
     prisma.templateVariable.create({
       data: {
         experimentId: exp.id,
-        label: "text",
+        label: "language",
       },
     }),
     prisma.testScenario.create({
       data: {
         experimentId: exp.id,
         variableValues: {
-          text: "This is a test scenario.",
+          language: "English",
         },
       },
     }),
+    prisma.testScenario.create({
+      data: {
+        experimentId: exp.id,
+        variableValues: {
+          language: "Spanish",
+        },
+      },
+    }),
+    prisma.testScenario.create({
+      data: {
+        experimentId: exp.id,
+        variableValues: {
+          language: "German",
+        },
+      },
+    }),
   ]);
 
-  await generateNewCell(variant.id, scenario.id);
+  await generateNewCell(variant.id, scenario1.id);
+  await generateNewCell(variant.id, scenario2.id);
+  await generateNewCell(variant.id, scenario3.id);
 
   return exp;
 }),
@@ -284,11 +284,12 @@ export const promptVariantsRouter = createTRPCRouter({
     return updatedPromptVariant;
   }),
 
-  getRefinedPromptFn: protectedProcedure
+  getModifiedPromptFn: protectedProcedure
     .input(
       z.object({
         id: z.string(),
-        instructions: z.string(),
+        instructions: z.string().optional(),
+        newModel: z.string().optional(),
       }),
     )
     .mutation(async ({ input, ctx }) => {
@@ -307,7 +308,7 @@ export const promptVariantsRouter = createTRPCRouter({
 
       const promptConstructionFn = await deriveNewConstructFn(
         existing,
-        constructedPrompt.model as SupportedModel,
+        input.newModel as SupportedModel | undefined,
         input.instructions,
       );
 
@@ -99,7 +99,6 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
 
   const provider = modelProviders[prompt.modelProvider];
 
-  // @ts-expect-error TODO FIX ASAP
   const streamingChannel = provider.shouldStream(prompt.modelInput) ? generateChannel() : null;
 
   if (streamingChannel) {
@@ -116,8 +115,6 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
       : null;
 
   for (let i = 0; true; i++) {
-    // @ts-expect-error TODO FIX ASAP
-
     const response = await provider.getCompletion(prompt.modelInput, onStream);
     if (response.type === "success") {
       const inputHash = hashPrompt(prompt);
@@ -126,7 +123,7 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
         data: {
           scenarioVariantCellId,
           inputHash,
-          output: response.value as unknown as Prisma.InputJsonObject,
+          output: response.value as Prisma.InputJsonObject,
           timeToComplete: response.timeToComplete,
           promptTokens: response.promptTokens,
           completionTokens: response.completionTokens,
@@ -154,7 +151,7 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
           errorMessage: response.message,
          statusCode: response.statusCode,
           retryTime: shouldRetry ? new Date(Date.now() + delay) : null,
-          retrievalStatus: shouldRetry ? "PENDING" : "ERROR",
+          retrievalStatus: "ERROR",
         },
       });
 
@@ -52,11 +52,15 @@ const requestUpdatedPromptFunction = async (
       2,
     )}\n\nDo not add any assistant messages.`,
     },
+    {
+      role: "user",
+      content: `This is the current prompt constructor function:\n---\n${originalVariant.constructFn}`,
+    },
   ];
   if (newModel) {
     messages.push({
       role: "user",
-      content: `Return the prompt constructor function for ${newModel} given the following prompt constructor function for ${originalModel}:\n---\n${originalVariant.constructFn}`,
+      content: `Return the prompt constructor function for ${newModel} given the existing prompt constructor function for ${originalModel}`,
     });
   }
   if (instructions) {
@@ -65,10 +69,6 @@ const requestUpdatedPromptFunction = async (
       content: instructions,
     });
   }
-  messages.push({
-    role: "system",
-    content: "The prompt variable has already been declared, so do not declare it again.",
-  });
  const completion = await openai.chat.completions.create({
     model: "gpt-4",
     messages,
@@ -70,7 +70,6 @@ export default async function parseConstructFn(
   // We've validated the JSON schema so this should be safe
   const input = prompt.input as Parameters<(typeof provider)["getModel"]>[0];
 
-  // @ts-expect-error TODO FIX ASAP
   const model = provider.getModel(input);
   if (!model) {
     return {
@@ -80,8 +79,6 @@
 
   return {
     modelProvider: prompt.modelProvider as keyof typeof modelProviders,
-    // @ts-expect-error TODO FIX ASAP
-
     model,
     modelInput: input,
   };