move app to app/ subdir
This commit is contained in:
27
app/src/components/AutoResizeTextArea.tsx
Normal file
27
app/src/components/AutoResizeTextArea.tsx
Normal file
@@ -0,0 +1,27 @@
|
||||
import { Textarea, type TextareaProps } from "@chakra-ui/react";
|
||||
import ResizeTextarea from "react-textarea-autosize";
|
||||
import React, { useLayoutEffect, useState } from "react";
|
||||
|
||||
export const AutoResizeTextarea: React.ForwardRefRenderFunction<
|
||||
HTMLTextAreaElement,
|
||||
TextareaProps & { minRows?: number }
|
||||
> = ({ minRows = 1, overflowY = "hidden", ...props }, ref) => {
|
||||
const [isRerendered, setIsRerendered] = useState(false);
|
||||
useLayoutEffect(() => setIsRerendered(true), []);
|
||||
|
||||
return (
|
||||
<Textarea
|
||||
minH="unset"
|
||||
minRows={minRows}
|
||||
overflowY={isRerendered ? overflowY : "hidden"}
|
||||
w="100%"
|
||||
resize="none"
|
||||
ref={ref}
|
||||
transition="height none"
|
||||
as={ResizeTextarea}
|
||||
{...props}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default React.forwardRef(AutoResizeTextarea);
|
||||
142
app/src/components/ChangeModelModal/ChangeModelModal.tsx
Normal file
142
app/src/components/ChangeModelModal/ChangeModelModal.tsx
Normal file
@@ -0,0 +1,142 @@
|
||||
import {
|
||||
Button,
|
||||
HStack,
|
||||
Icon,
|
||||
Modal,
|
||||
ModalBody,
|
||||
ModalCloseButton,
|
||||
ModalContent,
|
||||
ModalFooter,
|
||||
ModalHeader,
|
||||
ModalOverlay,
|
||||
Spinner,
|
||||
Text,
|
||||
VStack,
|
||||
} from "@chakra-ui/react";
|
||||
import { type PromptVariant } from "@prisma/client";
|
||||
import { isObject, isString } from "lodash-es";
|
||||
import { useState } from "react";
|
||||
import { RiExchangeFundsFill } from "react-icons/ri";
|
||||
import { type ProviderModel } from "~/modelProviders/types";
|
||||
import { api } from "~/utils/api";
|
||||
import { useExperiment, useHandledAsyncCallback, useVisibleScenarioIds } from "~/utils/hooks";
|
||||
import { lookupModel, modelLabel } from "~/utils/utils";
|
||||
import CompareFunctions from "../RefinePromptModal/CompareFunctions";
|
||||
import { ModelSearch } from "./ModelSearch";
|
||||
import { ModelStatsCard } from "./ModelStatsCard";
|
||||
|
||||
export const ChangeModelModal = ({
|
||||
variant,
|
||||
onClose,
|
||||
}: {
|
||||
variant: PromptVariant;
|
||||
onClose: () => void;
|
||||
}) => {
|
||||
const originalModel = lookupModel(variant.modelProvider, variant.model);
|
||||
const [selectedModel, setSelectedModel] = useState({
|
||||
provider: variant.modelProvider,
|
||||
model: variant.model,
|
||||
} as ProviderModel);
|
||||
const [convertedModel, setConvertedModel] = useState<ProviderModel | undefined>();
|
||||
const visibleScenarios = useVisibleScenarioIds();
|
||||
|
||||
const utils = api.useContext();
|
||||
|
||||
const experiment = useExperiment();
|
||||
|
||||
const { mutateAsync: getModifiedPromptMutateAsync, data: modifiedPromptFn } =
|
||||
api.promptVariants.getModifiedPromptFn.useMutation();
|
||||
|
||||
const [getModifiedPromptFn, modificationInProgress] = useHandledAsyncCallback(async () => {
|
||||
if (!experiment) return;
|
||||
|
||||
await getModifiedPromptMutateAsync({
|
||||
id: variant.id,
|
||||
newModel: selectedModel,
|
||||
});
|
||||
setConvertedModel(selectedModel);
|
||||
}, [getModifiedPromptMutateAsync, onClose, experiment, variant, selectedModel]);
|
||||
|
||||
const replaceVariantMutation = api.promptVariants.replaceVariant.useMutation();
|
||||
|
||||
const [replaceVariant, replacementInProgress] = useHandledAsyncCallback(async () => {
|
||||
if (
|
||||
!variant.experimentId ||
|
||||
!modifiedPromptFn ||
|
||||
(isObject(modifiedPromptFn) && "status" in modifiedPromptFn)
|
||||
)
|
||||
return;
|
||||
await replaceVariantMutation.mutateAsync({
|
||||
id: variant.id,
|
||||
promptConstructor: modifiedPromptFn,
|
||||
streamScenarios: visibleScenarios,
|
||||
});
|
||||
await utils.promptVariants.list.invalidate();
|
||||
onClose();
|
||||
}, [replaceVariantMutation, variant, onClose, modifiedPromptFn]);
|
||||
|
||||
const originalLabel = modelLabel(variant.modelProvider, variant.model);
|
||||
const selectedLabel = modelLabel(selectedModel.provider, selectedModel.model);
|
||||
const convertedLabel =
|
||||
convertedModel && modelLabel(convertedModel.provider, convertedModel.model);
|
||||
|
||||
return (
|
||||
<Modal
|
||||
isOpen
|
||||
onClose={onClose}
|
||||
size={{ base: "xl", sm: "2xl", md: "3xl", lg: "5xl", xl: "7xl" }}
|
||||
>
|
||||
<ModalOverlay />
|
||||
<ModalContent w={1200}>
|
||||
<ModalHeader>
|
||||
<HStack>
|
||||
<Icon as={RiExchangeFundsFill} />
|
||||
<Text>Change Model</Text>
|
||||
</HStack>
|
||||
</ModalHeader>
|
||||
<ModalCloseButton />
|
||||
<ModalBody maxW="unset">
|
||||
<VStack spacing={8}>
|
||||
<ModelStatsCard label="Original Model" model={originalModel} />
|
||||
{originalLabel !== selectedLabel && (
|
||||
<ModelStatsCard
|
||||
label="New Model"
|
||||
model={lookupModel(selectedModel.provider, selectedModel.model)}
|
||||
/>
|
||||
)}
|
||||
<ModelSearch selectedModel={selectedModel} setSelectedModel={setSelectedModel} />
|
||||
{isString(modifiedPromptFn) && (
|
||||
<CompareFunctions
|
||||
originalFunction={variant.promptConstructor}
|
||||
newFunction={modifiedPromptFn}
|
||||
leftTitle={originalLabel}
|
||||
rightTitle={convertedLabel}
|
||||
/>
|
||||
)}
|
||||
</VStack>
|
||||
</ModalBody>
|
||||
|
||||
<ModalFooter>
|
||||
<HStack>
|
||||
<Button
|
||||
colorScheme="gray"
|
||||
onClick={getModifiedPromptFn}
|
||||
minW={24}
|
||||
isDisabled={originalLabel === selectedLabel || modificationInProgress}
|
||||
>
|
||||
{modificationInProgress ? <Spinner boxSize={4} /> : <Text>Convert</Text>}
|
||||
</Button>
|
||||
<Button
|
||||
colorScheme="blue"
|
||||
onClick={replaceVariant}
|
||||
minW={24}
|
||||
isDisabled={!convertedModel || modificationInProgress || replacementInProgress}
|
||||
>
|
||||
{replacementInProgress ? <Spinner boxSize={4} /> : <Text>Accept</Text>}
|
||||
</Button>
|
||||
</HStack>
|
||||
</ModalFooter>
|
||||
</ModalContent>
|
||||
</Modal>
|
||||
);
|
||||
};
|
||||
36
app/src/components/ChangeModelModal/ModelSearch.tsx
Normal file
36
app/src/components/ChangeModelModal/ModelSearch.tsx
Normal file
@@ -0,0 +1,36 @@
|
||||
import { Text, VStack } from "@chakra-ui/react";
|
||||
import { type LegacyRef } from "react";
|
||||
import Select from "react-select";
|
||||
import { useElementDimensions } from "~/utils/hooks";
|
||||
|
||||
import { flatMap } from "lodash-es";
|
||||
import frontendModelProviders from "~/modelProviders/frontendModelProviders";
|
||||
import { type ProviderModel } from "~/modelProviders/types";
|
||||
import { modelLabel } from "~/utils/utils";
|
||||
|
||||
const modelOptions = flatMap(Object.entries(frontendModelProviders), ([providerId, provider]) =>
|
||||
Object.entries(provider.models).map(([modelId]) => ({
|
||||
provider: providerId,
|
||||
model: modelId,
|
||||
})),
|
||||
) as ProviderModel[];
|
||||
|
||||
export const ModelSearch = (props: {
|
||||
selectedModel: ProviderModel;
|
||||
setSelectedModel: (model: ProviderModel) => void;
|
||||
}) => {
|
||||
const [containerRef, containerDimensions] = useElementDimensions();
|
||||
|
||||
return (
|
||||
<VStack ref={containerRef as LegacyRef<HTMLDivElement>} w="full" fontFamily="inconsolata">
|
||||
<Text fontWeight="bold">Browse Models</Text>
|
||||
<Select<ProviderModel>
|
||||
styles={{ control: (provided) => ({ ...provided, width: containerDimensions?.width }) }}
|
||||
getOptionLabel={(data) => modelLabel(data.provider, data.model)}
|
||||
getOptionValue={(data) => modelLabel(data.provider, data.model)}
|
||||
options={modelOptions}
|
||||
onChange={(option) => option && props.setSelectedModel(option)}
|
||||
/>
|
||||
</VStack>
|
||||
);
|
||||
};
|
||||
117
app/src/components/ChangeModelModal/ModelStatsCard.tsx
Normal file
117
app/src/components/ChangeModelModal/ModelStatsCard.tsx
Normal file
@@ -0,0 +1,117 @@
|
||||
import {
|
||||
GridItem,
|
||||
HStack,
|
||||
Link,
|
||||
SimpleGrid,
|
||||
Text,
|
||||
VStack,
|
||||
type StackProps,
|
||||
} from "@chakra-ui/react";
|
||||
import { type lookupModel } from "~/utils/utils";
|
||||
|
||||
export const ModelStatsCard = ({
|
||||
label,
|
||||
model,
|
||||
}: {
|
||||
label: string;
|
||||
model: ReturnType<typeof lookupModel>;
|
||||
}) => {
|
||||
if (!model) return null;
|
||||
return (
|
||||
<VStack w="full" align="start">
|
||||
<Text fontWeight="bold" fontSize="sm" textTransform="uppercase">
|
||||
{label}
|
||||
</Text>
|
||||
|
||||
<VStack
|
||||
w="full"
|
||||
spacing={6}
|
||||
borderWidth={1}
|
||||
borderColor="gray.300"
|
||||
p={4}
|
||||
borderRadius={8}
|
||||
fontFamily="inconsolata"
|
||||
>
|
||||
<HStack w="full" align="flex-start">
|
||||
<VStack flex={1} fontSize="lg" alignItems="flex-start">
|
||||
<Text as="span" fontWeight="bold" color="gray.900">
|
||||
{model.name}
|
||||
</Text>
|
||||
<Text as="span" color="gray.600" fontSize="sm">
|
||||
Provider: {model.provider}
|
||||
</Text>
|
||||
</VStack>
|
||||
<Link
|
||||
href={model.learnMoreUrl}
|
||||
isExternal
|
||||
color="blue.500"
|
||||
fontWeight="bold"
|
||||
fontSize="sm"
|
||||
ml={2}
|
||||
>
|
||||
Learn More
|
||||
</Link>
|
||||
</HStack>
|
||||
<SimpleGrid
|
||||
w="full"
|
||||
justifyContent="space-between"
|
||||
alignItems="flex-start"
|
||||
fontSize="sm"
|
||||
columns={{ base: 2, md: 4 }}
|
||||
>
|
||||
<SelectedModelLabeledInfo label="Context Window" info={model.contextWindow} />
|
||||
{model.promptTokenPrice && (
|
||||
<SelectedModelLabeledInfo
|
||||
label="Input"
|
||||
info={
|
||||
<Text>
|
||||
${(model.promptTokenPrice * 1000).toFixed(3)}
|
||||
<Text color="gray.500"> / 1K tokens</Text>
|
||||
</Text>
|
||||
}
|
||||
/>
|
||||
)}
|
||||
{model.completionTokenPrice && (
|
||||
<SelectedModelLabeledInfo
|
||||
label="Output"
|
||||
info={
|
||||
<Text>
|
||||
${(model.completionTokenPrice * 1000).toFixed(3)}
|
||||
<Text color="gray.500"> / 1K tokens</Text>
|
||||
</Text>
|
||||
}
|
||||
/>
|
||||
)}
|
||||
{model.pricePerSecond && (
|
||||
<SelectedModelLabeledInfo
|
||||
label="Price"
|
||||
info={
|
||||
<Text>
|
||||
${model.pricePerSecond.toFixed(3)}
|
||||
<Text color="gray.500"> / second</Text>
|
||||
</Text>
|
||||
}
|
||||
/>
|
||||
)}
|
||||
<SelectedModelLabeledInfo label="Speed" info={<Text>{model.speed}</Text>} />
|
||||
</SimpleGrid>
|
||||
</VStack>
|
||||
</VStack>
|
||||
);
|
||||
};
|
||||
|
||||
const SelectedModelLabeledInfo = ({
|
||||
label,
|
||||
info,
|
||||
...props
|
||||
}: {
|
||||
label: string;
|
||||
info: string | number | React.ReactElement;
|
||||
} & StackProps) => (
|
||||
<GridItem>
|
||||
<VStack alignItems="flex-start" {...props}>
|
||||
<Text fontWeight="bold">{label}</Text>
|
||||
<Text>{info}</Text>
|
||||
</VStack>
|
||||
</GridItem>
|
||||
);
|
||||
86
app/src/components/CustomInstructionsInput.tsx
Normal file
86
app/src/components/CustomInstructionsInput.tsx
Normal file
@@ -0,0 +1,86 @@
|
||||
import {
|
||||
Button,
|
||||
Spinner,
|
||||
InputGroup,
|
||||
InputRightElement,
|
||||
Icon,
|
||||
HStack,
|
||||
type InputGroupProps,
|
||||
} from "@chakra-ui/react";
|
||||
import { IoMdSend } from "react-icons/io";
|
||||
import AutoResizeTextArea from "./AutoResizeTextArea";
|
||||
|
||||
export const CustomInstructionsInput = ({
|
||||
instructions,
|
||||
setInstructions,
|
||||
loading,
|
||||
onSubmit,
|
||||
placeholder = "Send custom instructions",
|
||||
...props
|
||||
}: {
|
||||
instructions: string;
|
||||
setInstructions: (instructions: string) => void;
|
||||
loading: boolean;
|
||||
onSubmit: () => void;
|
||||
placeholder?: string;
|
||||
} & InputGroupProps) => {
|
||||
return (
|
||||
<InputGroup
|
||||
size="md"
|
||||
w="full"
|
||||
maxW="600"
|
||||
boxShadow="0 0 40px 4px rgba(0, 0, 0, 0.1);"
|
||||
borderRadius={8}
|
||||
alignItems="center"
|
||||
colorScheme="orange"
|
||||
{...props}
|
||||
>
|
||||
<AutoResizeTextArea
|
||||
value={instructions}
|
||||
onChange={(e) => setInstructions(e.target.value)}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === "Enter" && !e.metaKey && !e.ctrlKey && !e.shiftKey) {
|
||||
e.preventDefault();
|
||||
e.currentTarget.blur();
|
||||
onSubmit();
|
||||
}
|
||||
}}
|
||||
placeholder={placeholder}
|
||||
py={4}
|
||||
pl={4}
|
||||
pr={12}
|
||||
colorScheme="orange"
|
||||
borderColor="gray.300"
|
||||
borderWidth={1}
|
||||
_hover={{
|
||||
borderColor: "gray.300",
|
||||
}}
|
||||
_focus={{
|
||||
borderColor: "gray.300",
|
||||
}}
|
||||
isDisabled={loading}
|
||||
/>
|
||||
<HStack></HStack>
|
||||
<InputRightElement width="8" height="full">
|
||||
<Button
|
||||
h="8"
|
||||
w="8"
|
||||
minW="unset"
|
||||
size="sm"
|
||||
onClick={() => onSubmit()}
|
||||
variant={instructions ? "solid" : "ghost"}
|
||||
mr={4}
|
||||
borderRadius="8"
|
||||
bgColor={instructions ? "orange.400" : "transparent"}
|
||||
colorScheme="orange"
|
||||
>
|
||||
{loading ? (
|
||||
<Spinner boxSize={4} />
|
||||
) : (
|
||||
<Icon as={IoMdSend} color={instructions ? "white" : "gray.500"} boxSize={5} />
|
||||
)}
|
||||
</Button>
|
||||
</InputRightElement>
|
||||
</InputGroup>
|
||||
);
|
||||
};
|
||||
69
app/src/components/ExperimentSettingsDrawer/DeleteButton.tsx
Normal file
69
app/src/components/ExperimentSettingsDrawer/DeleteButton.tsx
Normal file
@@ -0,0 +1,69 @@
|
||||
import {
|
||||
Button,
|
||||
Icon,
|
||||
AlertDialog,
|
||||
AlertDialogBody,
|
||||
AlertDialogFooter,
|
||||
AlertDialogHeader,
|
||||
AlertDialogContent,
|
||||
AlertDialogOverlay,
|
||||
useDisclosure,
|
||||
Text,
|
||||
} from "@chakra-ui/react";
|
||||
|
||||
import { useRouter } from "next/router";
|
||||
import { useRef } from "react";
|
||||
import { BsTrash } from "react-icons/bs";
|
||||
import { api } from "~/utils/api";
|
||||
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
|
||||
|
||||
export const DeleteButton = () => {
|
||||
const experiment = useExperiment();
|
||||
const mutation = api.experiments.delete.useMutation();
|
||||
const utils = api.useContext();
|
||||
const router = useRouter();
|
||||
|
||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||
const cancelRef = useRef<HTMLButtonElement>(null);
|
||||
|
||||
const [onDeleteConfirm] = useHandledAsyncCallback(async () => {
|
||||
if (!experiment.data?.id) return;
|
||||
await mutation.mutateAsync({ id: experiment.data.id });
|
||||
await utils.experiments.list.invalidate();
|
||||
await router.push({ pathname: "/experiments" });
|
||||
onClose();
|
||||
}, [mutation, experiment.data?.id, router]);
|
||||
|
||||
return (
|
||||
<>
|
||||
<Button size="sm" variant="ghost" colorScheme="red" fontWeight="normal" onClick={onOpen}>
|
||||
<Icon as={BsTrash} boxSize={4} />
|
||||
<Text ml={2}>Delete Experiment</Text>
|
||||
</Button>
|
||||
|
||||
<AlertDialog isOpen={isOpen} leastDestructiveRef={cancelRef} onClose={onClose}>
|
||||
<AlertDialogOverlay>
|
||||
<AlertDialogContent>
|
||||
<AlertDialogHeader fontSize="lg" fontWeight="bold">
|
||||
Delete Experiment
|
||||
</AlertDialogHeader>
|
||||
|
||||
<AlertDialogBody>
|
||||
If you delete this experiment all the associated prompts and scenarios will be deleted
|
||||
as well. Are you sure?
|
||||
</AlertDialogBody>
|
||||
|
||||
<AlertDialogFooter>
|
||||
<Button ref={cancelRef} onClick={onClose}>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button colorScheme="red" onClick={onDeleteConfirm} ml={3}>
|
||||
Delete
|
||||
</Button>
|
||||
</AlertDialogFooter>
|
||||
</AlertDialogContent>
|
||||
</AlertDialogOverlay>
|
||||
</AlertDialog>
|
||||
</>
|
||||
);
|
||||
};
|
||||
@@ -0,0 +1,40 @@
|
||||
import {
|
||||
Drawer,
|
||||
DrawerBody,
|
||||
DrawerCloseButton,
|
||||
DrawerContent,
|
||||
DrawerHeader,
|
||||
DrawerOverlay,
|
||||
Heading,
|
||||
VStack,
|
||||
} from "@chakra-ui/react";
|
||||
import EditScenarioVars from "../OutputsTable/EditScenarioVars";
|
||||
import EditEvaluations from "../OutputsTable/EditEvaluations";
|
||||
import { useAppStore } from "~/state/store";
|
||||
import { DeleteButton } from "./DeleteButton";
|
||||
|
||||
export default function ExperimentSettingsDrawer() {
|
||||
const isOpen = useAppStore((state) => state.drawerOpen);
|
||||
const closeDrawer = useAppStore((state) => state.closeDrawer);
|
||||
|
||||
return (
|
||||
<Drawer isOpen={isOpen} placement="right" onClose={closeDrawer} size="md">
|
||||
<DrawerOverlay />
|
||||
<DrawerContent>
|
||||
<DrawerCloseButton />
|
||||
<DrawerHeader>
|
||||
<Heading size="md">Experiment Settings</Heading>
|
||||
</DrawerHeader>
|
||||
<DrawerBody h="full" pb={4}>
|
||||
<VStack h="full" justifyContent="space-between">
|
||||
<VStack spacing={6}>
|
||||
<EditScenarioVars />
|
||||
<EditEvaluations />
|
||||
</VStack>
|
||||
<DeleteButton />
|
||||
</VStack>
|
||||
</DrawerBody>
|
||||
</DrawerContent>
|
||||
</Drawer>
|
||||
);
|
||||
}
|
||||
16
app/src/components/Favicon.tsx
Normal file
16
app/src/components/Favicon.tsx
Normal file
@@ -0,0 +1,16 @@
|
||||
import Head from "next/head";
|
||||
|
||||
export default function Favicon() {
|
||||
return (
|
||||
<Head>
|
||||
<link rel="apple-touch-icon" sizes="180x180" href="/favicons/apple-touch-icon.png" />
|
||||
<link rel="icon" type="image/png" sizes="32x32" href="/favicons/favicon-32x32.png" />
|
||||
<link rel="icon" type="image/png" sizes="16x16" href="/favicons/favicon-16x16.png" />
|
||||
<link rel="manifest" href="/favicons/site.webmanifest" />
|
||||
<link rel="shortcut icon" href="/favicons/favicon.ico" />
|
||||
<meta name="msapplication-TileColor" content="#da532c" />
|
||||
<meta name="msapplication-config" content="/favicons/browserconfig.xml" />
|
||||
<meta name="theme-color" content="#ffffff" />
|
||||
</Head>
|
||||
);
|
||||
}
|
||||
57
app/src/components/OutputsTable/AddVariantButton.tsx
Normal file
57
app/src/components/OutputsTable/AddVariantButton.tsx
Normal file
@@ -0,0 +1,57 @@
|
||||
import { Box, Flex, Icon, Spinner } from "@chakra-ui/react";
|
||||
import { BsPlus } from "react-icons/bs";
|
||||
import { Text } from "@chakra-ui/react";
|
||||
import { api } from "~/utils/api";
|
||||
import {
|
||||
useExperiment,
|
||||
useExperimentAccess,
|
||||
useHandledAsyncCallback,
|
||||
useVisibleScenarioIds,
|
||||
} from "~/utils/hooks";
|
||||
import { cellPadding } from "../constants";
|
||||
import { ActionButton } from "./ScenariosHeader";
|
||||
|
||||
export default function AddVariantButton() {
|
||||
const experiment = useExperiment();
|
||||
const mutation = api.promptVariants.create.useMutation();
|
||||
const utils = api.useContext();
|
||||
const visibleScenarios = useVisibleScenarioIds();
|
||||
|
||||
const [onClick, loading] = useHandledAsyncCallback(async () => {
|
||||
if (!experiment.data) return;
|
||||
await mutation.mutateAsync({
|
||||
experimentId: experiment.data.id,
|
||||
streamScenarios: visibleScenarios,
|
||||
});
|
||||
await utils.promptVariants.list.invalidate();
|
||||
}, [mutation]);
|
||||
|
||||
const { canModify } = useExperimentAccess();
|
||||
if (!canModify) return <Box w={cellPadding.x} />;
|
||||
|
||||
return (
|
||||
<Flex w="100%" justifyContent="flex-end">
|
||||
<ActionButton
|
||||
onClick={onClick}
|
||||
py={5}
|
||||
leftIcon={<Icon as={loading ? Spinner : BsPlus} boxSize={6} mr={loading ? 1 : 0} />}
|
||||
>
|
||||
<Text display={{ base: "none", md: "flex" }}>Add Variant</Text>
|
||||
</ActionButton>
|
||||
{/* <Button
|
||||
alignItems="center"
|
||||
justifyContent="center"
|
||||
fontWeight="normal"
|
||||
bgColor="transparent"
|
||||
_hover={{ bgColor: "gray.100" }}
|
||||
px={cellPadding.x}
|
||||
onClick={onClick}
|
||||
height="unset"
|
||||
minH={headerMinHeight}
|
||||
>
|
||||
<Icon as={loading ? Spinner : BsPlus} boxSize={6} mr={loading ? 1 : 0} />
|
||||
<Text display={{ base: "none", md: "flex" }}>Add Variant</Text>
|
||||
</Button> */}
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
240
app/src/components/OutputsTable/EditEvaluations.tsx
Normal file
240
app/src/components/OutputsTable/EditEvaluations.tsx
Normal file
@@ -0,0 +1,240 @@
|
||||
import {
|
||||
Text,
|
||||
Button,
|
||||
HStack,
|
||||
Heading,
|
||||
Icon,
|
||||
Input,
|
||||
Stack,
|
||||
VStack,
|
||||
FormControl,
|
||||
FormLabel,
|
||||
Select,
|
||||
FormHelperText,
|
||||
Code,
|
||||
} from "@chakra-ui/react";
|
||||
import { type Evaluation, EvalType } from "@prisma/client";
|
||||
import { useCallback, useState } from "react";
|
||||
import { BsPencil, BsX } from "react-icons/bs";
|
||||
import { api } from "~/utils/api";
|
||||
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
|
||||
import AutoResizeTextArea from "../AutoResizeTextArea";
|
||||
|
||||
type EvalValues = Pick<Evaluation, "label" | "value" | "evalType">;
|
||||
|
||||
export function EvaluationEditor(props: {
|
||||
evaluation: Evaluation | null;
|
||||
defaultName?: string;
|
||||
onSave: (id: string | undefined, vals: EvalValues) => void;
|
||||
onCancel: () => void;
|
||||
}) {
|
||||
const [values, setValues] = useState<EvalValues>({
|
||||
label: props.evaluation?.label ?? props.defaultName ?? "",
|
||||
value: props.evaluation?.value ?? "",
|
||||
evalType: props.evaluation?.evalType ?? "CONTAINS",
|
||||
});
|
||||
|
||||
return (
|
||||
<VStack borderTopWidth={1} borderColor="gray.200" py={4}>
|
||||
<HStack w="100%">
|
||||
<FormControl flex={1}>
|
||||
<FormLabel fontSize="sm">Eval Name</FormLabel>
|
||||
<Input
|
||||
size="sm"
|
||||
value={values.label}
|
||||
onChange={(e) => setValues((values) => ({ ...values, label: e.target.value }))}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormControl flex={1}>
|
||||
<FormLabel fontSize="sm">Eval Type</FormLabel>
|
||||
<Select
|
||||
size="sm"
|
||||
value={values.evalType}
|
||||
onChange={(e) =>
|
||||
setValues((values) => ({
|
||||
...values,
|
||||
evalType: e.target.value as EvalType,
|
||||
}))
|
||||
}
|
||||
>
|
||||
{Object.values(EvalType).map((type) => (
|
||||
<option key={type} value={type}>
|
||||
{type}
|
||||
</option>
|
||||
))}
|
||||
</Select>
|
||||
</FormControl>
|
||||
</HStack>
|
||||
{["CONTAINS", "DOES_NOT_CONTAIN"].includes(values.evalType) && (
|
||||
<FormControl>
|
||||
<FormLabel fontSize="sm">Match String</FormLabel>
|
||||
<Input
|
||||
size="sm"
|
||||
value={values.value}
|
||||
onChange={(e) => setValues((values) => ({ ...values, value: e.target.value }))}
|
||||
/>
|
||||
<FormHelperText>
|
||||
This string will be interpreted as a regex and checked against each model output. You
|
||||
can include scenario variables using <Code>{"{{curly_braces}}"}</Code>
|
||||
</FormHelperText>
|
||||
</FormControl>
|
||||
)}
|
||||
{values.evalType === "GPT4_EVAL" && (
|
||||
<FormControl pt={2}>
|
||||
<FormLabel fontSize="sm">GPT4 Instructions</FormLabel>
|
||||
<AutoResizeTextArea
|
||||
size="sm"
|
||||
value={values.value}
|
||||
onChange={(e) => setValues((values) => ({ ...values, value: e.target.value }))}
|
||||
minRows={3}
|
||||
/>
|
||||
<FormHelperText>
|
||||
Give instructions to GPT-4 for how to evaluate your prompt. It will have access to the
|
||||
full scenario as well as the output it is evaluating. It will <strong>not</strong> have
|
||||
access to the specific prompt variant, so be sure to be clear about the task you want it
|
||||
to perform.
|
||||
</FormHelperText>
|
||||
</FormControl>
|
||||
)}
|
||||
<HStack alignSelf="flex-end">
|
||||
<Button size="sm" onClick={props.onCancel} colorScheme="gray">
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
size="sm"
|
||||
onClick={() => props.onSave(props.evaluation?.id, values)}
|
||||
colorScheme="blue"
|
||||
>
|
||||
Save
|
||||
</Button>
|
||||
</HStack>
|
||||
</VStack>
|
||||
);
|
||||
}
|
||||
|
||||
export default function EditEvaluations() {
|
||||
const experiment = useExperiment();
|
||||
const evaluations =
|
||||
api.evaluations.list.useQuery({ experimentId: experiment.data?.id ?? "" }).data ?? [];
|
||||
|
||||
const [editingId, setEditingId] = useState<string | null>(null);
|
||||
|
||||
const utils = api.useContext();
|
||||
const createMutation = api.evaluations.create.useMutation();
|
||||
const updateMutation = api.evaluations.update.useMutation();
|
||||
|
||||
const deleteMutation = api.evaluations.delete.useMutation();
|
||||
const [onDelete] = useHandledAsyncCallback(async (id: string) => {
|
||||
await deleteMutation.mutateAsync({ id });
|
||||
await utils.evaluations.list.invalidate();
|
||||
await utils.promptVariants.stats.invalidate();
|
||||
}, []);
|
||||
|
||||
const [onSave] = useHandledAsyncCallback(async (id: string | undefined, vals: EvalValues) => {
|
||||
setEditingId(null);
|
||||
if (!experiment.data?.id) return;
|
||||
|
||||
if (id) {
|
||||
await updateMutation.mutateAsync({
|
||||
id,
|
||||
updates: vals,
|
||||
});
|
||||
} else {
|
||||
await createMutation.mutateAsync({
|
||||
experimentId: experiment.data.id,
|
||||
...vals,
|
||||
});
|
||||
}
|
||||
await utils.evaluations.list.invalidate();
|
||||
await utils.promptVariants.stats.invalidate();
|
||||
await utils.scenarioVariantCells.get.invalidate();
|
||||
}, []);
|
||||
|
||||
const onCancel = useCallback(() => {
|
||||
setEditingId(null);
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<Stack>
|
||||
<Heading size="sm">Evaluations</Heading>
|
||||
<Stack spacing={4}>
|
||||
<Text fontSize="sm">
|
||||
Evaluations allow you to compare prompt performance in an automated way.
|
||||
</Text>
|
||||
<Stack spacing={2}>
|
||||
{evaluations.map((evaluation) =>
|
||||
editingId == evaluation.id ? (
|
||||
<EvaluationEditor
|
||||
evaluation={evaluation}
|
||||
onSave={onSave}
|
||||
onCancel={onCancel}
|
||||
key={evaluation.id}
|
||||
/>
|
||||
) : (
|
||||
<HStack
|
||||
fontSize="sm"
|
||||
borderTopWidth={1}
|
||||
borderColor="gray.200"
|
||||
py={4}
|
||||
align="center"
|
||||
key={evaluation.id}
|
||||
>
|
||||
<Text fontWeight="bold">{evaluation.label}</Text>
|
||||
<Text flex={1}>
|
||||
{evaluation.evalType}: "{evaluation.value}"
|
||||
</Text>
|
||||
<Button
|
||||
variant="unstyled"
|
||||
color="gray.400"
|
||||
height="unset"
|
||||
width="unset"
|
||||
minW="unset"
|
||||
onClick={() => setEditingId(evaluation.id)}
|
||||
_hover={{
|
||||
color: "gray.800",
|
||||
cursor: "pointer",
|
||||
}}
|
||||
>
|
||||
<Icon as={BsPencil} boxSize={4} />
|
||||
</Button>
|
||||
<Button
|
||||
variant="unstyled"
|
||||
color="gray.400"
|
||||
height="unset"
|
||||
width="unset"
|
||||
minW="unset"
|
||||
onClick={() => onDelete(evaluation.id)}
|
||||
_hover={{
|
||||
color: "gray.800",
|
||||
cursor: "pointer",
|
||||
}}
|
||||
>
|
||||
<Icon as={BsX} boxSize={6} />
|
||||
</Button>
|
||||
</HStack>
|
||||
),
|
||||
)}
|
||||
{editingId == null && (
|
||||
<Button
|
||||
onClick={() => setEditingId("new")}
|
||||
alignSelf="flex-start"
|
||||
size="sm"
|
||||
mt={4}
|
||||
colorScheme="blue"
|
||||
>
|
||||
Add Evaluation
|
||||
</Button>
|
||||
)}
|
||||
{editingId == "new" && (
|
||||
<EvaluationEditor
|
||||
evaluation={null}
|
||||
defaultName={`Eval${evaluations.length + 1}`}
|
||||
onSave={onSave}
|
||||
onCancel={onCancel}
|
||||
/>
|
||||
)}
|
||||
</Stack>
|
||||
</Stack>
|
||||
</Stack>
|
||||
);
|
||||
}
|
||||
103
app/src/components/OutputsTable/EditScenarioVars.tsx
Normal file
103
app/src/components/OutputsTable/EditScenarioVars.tsx
Normal file
@@ -0,0 +1,103 @@
|
||||
import { Text, Button, HStack, Heading, Icon, Input, Stack } from "@chakra-ui/react";
|
||||
import { useState } from "react";
|
||||
import { BsCheck, BsX } from "react-icons/bs";
|
||||
import { api } from "~/utils/api";
|
||||
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
|
||||
|
||||
export default function EditScenarioVars() {
|
||||
const experiment = useExperiment();
|
||||
const vars =
|
||||
api.templateVars.list.useQuery({ experimentId: experiment.data?.id ?? "" }).data ?? [];
|
||||
|
||||
const [newVariable, setNewVariable] = useState<string>("");
|
||||
const newVarIsValid = newVariable.length > 0 && !vars.map((v) => v.label).includes(newVariable);
|
||||
|
||||
const utils = api.useContext();
|
||||
const addVarMutation = api.templateVars.create.useMutation();
|
||||
const [onAddVar] = useHandledAsyncCallback(async () => {
|
||||
if (!experiment.data?.id) return;
|
||||
if (!newVarIsValid) return;
|
||||
await addVarMutation.mutateAsync({
|
||||
experimentId: experiment.data.id,
|
||||
label: newVariable,
|
||||
});
|
||||
await utils.templateVars.list.invalidate();
|
||||
setNewVariable("");
|
||||
}, [addVarMutation, experiment.data?.id, newVarIsValid, newVariable]);
|
||||
|
||||
const deleteMutation = api.templateVars.delete.useMutation();
|
||||
const [onDeleteVar] = useHandledAsyncCallback(async (id: string) => {
|
||||
await deleteMutation.mutateAsync({ id });
|
||||
await utils.templateVars.list.invalidate();
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<Stack>
|
||||
<Heading size="sm">Scenario Variables</Heading>
|
||||
<Stack spacing={2}>
|
||||
<Text fontSize="sm">
|
||||
Scenario variables can be used in your prompt variants as well as evaluations.
|
||||
</Text>
|
||||
<HStack spacing={0}>
|
||||
<Input
|
||||
placeholder="Add Scenario Variable"
|
||||
size="sm"
|
||||
borderTopRadius={0}
|
||||
borderRightRadius={0}
|
||||
value={newVariable}
|
||||
onChange={(e) => setNewVariable(e.target.value)}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === "Enter") {
|
||||
e.preventDefault();
|
||||
onAddVar();
|
||||
}
|
||||
// If the user types a space, replace it with an underscore
|
||||
if (e.key === " ") {
|
||||
e.preventDefault();
|
||||
setNewVariable((v) => v + "_");
|
||||
}
|
||||
}}
|
||||
/>
|
||||
<Button
|
||||
size="xs"
|
||||
height="100%"
|
||||
borderLeftRadius={0}
|
||||
isDisabled={!newVarIsValid}
|
||||
onClick={onAddVar}
|
||||
>
|
||||
<Icon as={BsCheck} boxSize={8} />
|
||||
</Button>
|
||||
</HStack>
|
||||
|
||||
<HStack spacing={2} py={4} wrap="wrap">
|
||||
{vars.map((variable) => (
|
||||
<HStack
|
||||
key={variable.id}
|
||||
spacing={0}
|
||||
bgColor="blue.100"
|
||||
color="blue.600"
|
||||
pl={2}
|
||||
pr={0}
|
||||
fontWeight="bold"
|
||||
>
|
||||
<Text fontSize="sm" flex={1}>
|
||||
{variable.label}
|
||||
</Text>
|
||||
<Button
|
||||
size="xs"
|
||||
variant="ghost"
|
||||
colorScheme="blue"
|
||||
p="unset"
|
||||
minW="unset"
|
||||
px="unset"
|
||||
onClick={() => onDeleteVar(variable.id)}
|
||||
>
|
||||
<Icon as={BsX} boxSize={6} color="blue.800" />
|
||||
</Button>
|
||||
</HStack>
|
||||
))}
|
||||
</HStack>
|
||||
</Stack>
|
||||
</Stack>
|
||||
);
|
||||
}
|
||||
46
app/src/components/OutputsTable/FloatingLabelInput.tsx
Normal file
46
app/src/components/OutputsTable/FloatingLabelInput.tsx
Normal file
@@ -0,0 +1,46 @@
|
||||
import { FormLabel, FormControl, type TextareaProps } from "@chakra-ui/react";
|
||||
import { useState } from "react";
|
||||
import AutoResizeTextArea from "../AutoResizeTextArea";
|
||||
|
||||
export const FloatingLabelInput = ({
|
||||
label,
|
||||
value,
|
||||
...props
|
||||
}: { label: string; value: string } & TextareaProps) => {
|
||||
const [isFocused, setIsFocused] = useState(false);
|
||||
|
||||
return (
|
||||
<FormControl position="relative">
|
||||
<FormLabel
|
||||
position="absolute"
|
||||
left="10px"
|
||||
top={isFocused || !!value ? 0 : 3}
|
||||
transform={isFocused || !!value ? "translateY(-50%)" : "translateY(0)"}
|
||||
fontSize={isFocused || !!value ? "12px" : "16px"}
|
||||
transition="all 0.15s"
|
||||
zIndex="5"
|
||||
bg="white"
|
||||
px={1}
|
||||
lineHeight="1"
|
||||
pointerEvents="none"
|
||||
color={isFocused ? "blue.500" : "gray.500"}
|
||||
>
|
||||
{label}
|
||||
</FormLabel>
|
||||
<AutoResizeTextArea
|
||||
px={3}
|
||||
pt={3}
|
||||
pb={2}
|
||||
onFocus={() => setIsFocused(true)}
|
||||
onBlur={() => setIsFocused(false)}
|
||||
borderRadius="md"
|
||||
borderColor={isFocused ? "blue.500" : "gray.400"}
|
||||
autoComplete="off"
|
||||
value={value}
|
||||
overflowY="auto"
|
||||
overflowX="hidden"
|
||||
{...props}
|
||||
/>
|
||||
</FormControl>
|
||||
);
|
||||
};
|
||||
197
app/src/components/OutputsTable/OutputCell/OutputCell.tsx
Normal file
197
app/src/components/OutputsTable/OutputCell/OutputCell.tsx
Normal file
@@ -0,0 +1,197 @@
|
||||
import { api } from "~/utils/api";
|
||||
import { type PromptVariant, type Scenario } from "../types";
|
||||
import { type StackProps, Text, VStack } from "@chakra-ui/react";
|
||||
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
|
||||
import SyntaxHighlighter from "react-syntax-highlighter";
|
||||
import { docco } from "react-syntax-highlighter/dist/cjs/styles/hljs";
|
||||
import stringify from "json-stringify-pretty-compact";
|
||||
import { type ReactElement, useState, useEffect, Fragment, useCallback } from "react";
|
||||
import useSocket from "~/utils/useSocket";
|
||||
import { OutputStats } from "./OutputStats";
|
||||
import { RetryCountdown } from "./RetryCountdown";
|
||||
import frontendModelProviders from "~/modelProviders/frontendModelProviders";
|
||||
import { ResponseLog } from "./ResponseLog";
|
||||
import { CellOptions } from "./TopActions";
|
||||
|
||||
const WAITING_MESSAGE_INTERVAL = 20000;
|
||||
|
||||
export default function OutputCell({
|
||||
scenario,
|
||||
variant,
|
||||
}: {
|
||||
scenario: Scenario;
|
||||
variant: PromptVariant;
|
||||
}): ReactElement | null {
|
||||
const utils = api.useContext();
|
||||
const experiment = useExperiment();
|
||||
const vars = api.templateVars.list.useQuery({
|
||||
experimentId: experiment.data?.id ?? "",
|
||||
}).data;
|
||||
|
||||
const scenarioVariables = scenario.variableValues as Record<string, string>;
|
||||
const templateHasVariables =
|
||||
vars?.length === 0 || vars?.some((v) => scenarioVariables[v.label] !== undefined);
|
||||
|
||||
let disabledReason: string | null = null;
|
||||
|
||||
if (!templateHasVariables) disabledReason = "Add a value to the scenario variables to see output";
|
||||
|
||||
const [refetchInterval, setRefetchInterval] = useState(0);
|
||||
const { data: cell, isLoading: queryLoading } = api.scenarioVariantCells.get.useQuery(
|
||||
{ scenarioId: scenario.id, variantId: variant.id },
|
||||
{ refetchInterval },
|
||||
);
|
||||
|
||||
const provider =
|
||||
frontendModelProviders[variant.modelProvider as keyof typeof frontendModelProviders];
|
||||
|
||||
type OutputSchema = Parameters<typeof provider.normalizeOutput>[0];
|
||||
|
||||
const { mutateAsync: hardRefetchMutate } = api.scenarioVariantCells.forceRefetch.useMutation();
|
||||
const [hardRefetch, hardRefetching] = useHandledAsyncCallback(async () => {
|
||||
await hardRefetchMutate({ scenarioId: scenario.id, variantId: variant.id });
|
||||
await utils.scenarioVariantCells.get.invalidate({
|
||||
scenarioId: scenario.id,
|
||||
variantId: variant.id,
|
||||
});
|
||||
await utils.promptVariants.stats.invalidate({
|
||||
variantId: variant.id,
|
||||
});
|
||||
}, [hardRefetchMutate, scenario.id, variant.id]);
|
||||
|
||||
const fetchingOutput = queryLoading || hardRefetching;
|
||||
|
||||
const awaitingOutput =
|
||||
!cell ||
|
||||
!cell.evalsComplete ||
|
||||
cell.retrievalStatus === "PENDING" ||
|
||||
cell.retrievalStatus === "IN_PROGRESS" ||
|
||||
hardRefetching;
|
||||
useEffect(() => setRefetchInterval(awaitingOutput ? 1000 : 0), [awaitingOutput]);
|
||||
|
||||
// TODO: disconnect from socket if we're not streaming anymore
|
||||
const streamedMessage = useSocket<OutputSchema>(cell?.id);
|
||||
|
||||
const mostRecentResponse = cell?.modelResponses[cell.modelResponses.length - 1];
|
||||
|
||||
const CellWrapper = useCallback(
|
||||
({ children, ...props }: StackProps) => (
|
||||
<VStack w="full" alignItems="flex-start" {...props} px={2} py={2} h="100%">
|
||||
{cell && (
|
||||
<CellOptions refetchingOutput={hardRefetching} refetchOutput={hardRefetch} cell={cell} />
|
||||
)}
|
||||
<VStack w="full" alignItems="flex-start" maxH={500} overflowY="auto" flex={1}>
|
||||
{children}
|
||||
</VStack>
|
||||
{mostRecentResponse && (
|
||||
<OutputStats modelResponse={mostRecentResponse} scenario={scenario} />
|
||||
)}
|
||||
</VStack>
|
||||
),
|
||||
[hardRefetching, hardRefetch, mostRecentResponse, scenario, cell],
|
||||
);
|
||||
|
||||
if (!vars) return null;
|
||||
|
||||
if (!cell && !fetchingOutput)
|
||||
return (
|
||||
<CellWrapper>
|
||||
<Text color="gray.500">Error retrieving output</Text>
|
||||
</CellWrapper>
|
||||
);
|
||||
|
||||
if (cell && cell.errorMessage) {
|
||||
return (
|
||||
<CellWrapper>
|
||||
<Text color="red.500">{cell.errorMessage}</Text>
|
||||
</CellWrapper>
|
||||
);
|
||||
}
|
||||
|
||||
if (disabledReason) return <Text color="gray.500">{disabledReason}</Text>;
|
||||
|
||||
const showLogs = !streamedMessage && !mostRecentResponse?.output;
|
||||
|
||||
if (showLogs)
|
||||
return (
|
||||
<CellWrapper alignItems="flex-start" fontFamily="inconsolata, monospace" spacing={0}>
|
||||
{cell?.jobQueuedAt && <ResponseLog time={cell.jobQueuedAt} title="Job queued" />}
|
||||
{cell?.jobStartedAt && <ResponseLog time={cell.jobStartedAt} title="Job started" />}
|
||||
{cell?.modelResponses?.map((response) => {
|
||||
let numWaitingMessages = 0;
|
||||
const relativeWaitingTime = response.receivedAt
|
||||
? response.receivedAt.getTime()
|
||||
: Date.now();
|
||||
if (response.requestedAt) {
|
||||
numWaitingMessages = Math.floor(
|
||||
(relativeWaitingTime - response.requestedAt.getTime()) / WAITING_MESSAGE_INTERVAL,
|
||||
);
|
||||
}
|
||||
return (
|
||||
<Fragment key={response.id}>
|
||||
{response.requestedAt && (
|
||||
<ResponseLog time={response.requestedAt} title="Request sent to API" />
|
||||
)}
|
||||
{response.requestedAt &&
|
||||
Array.from({ length: numWaitingMessages }, (_, i) => (
|
||||
<ResponseLog
|
||||
key={`waiting-${i}`}
|
||||
time={
|
||||
new Date(
|
||||
(response.requestedAt?.getTime?.() ?? 0) +
|
||||
(i + 1) * WAITING_MESSAGE_INTERVAL,
|
||||
)
|
||||
}
|
||||
title="Waiting for response..."
|
||||
/>
|
||||
))}
|
||||
{response.receivedAt && (
|
||||
<ResponseLog
|
||||
time={response.receivedAt}
|
||||
title="Response received from API"
|
||||
message={`statusCode: ${response.statusCode ?? ""}\n ${
|
||||
response.errorMessage ?? ""
|
||||
}`}
|
||||
/>
|
||||
)}
|
||||
</Fragment>
|
||||
);
|
||||
}) ?? null}
|
||||
{mostRecentResponse?.retryTime && (
|
||||
<RetryCountdown retryTime={mostRecentResponse.retryTime} />
|
||||
)}
|
||||
</CellWrapper>
|
||||
);
|
||||
|
||||
const normalizedOutput = mostRecentResponse?.output
|
||||
? provider.normalizeOutput(mostRecentResponse?.output)
|
||||
: streamedMessage
|
||||
? provider.normalizeOutput(streamedMessage)
|
||||
: null;
|
||||
|
||||
if (mostRecentResponse?.output && normalizedOutput?.type === "json") {
|
||||
return (
|
||||
<CellWrapper>
|
||||
<SyntaxHighlighter
|
||||
customStyle={{ overflowX: "unset", width: "100%", flex: 1 }}
|
||||
language="json"
|
||||
style={docco}
|
||||
lineProps={{
|
||||
style: { wordBreak: "break-all", whiteSpace: "pre-wrap" },
|
||||
}}
|
||||
wrapLines
|
||||
>
|
||||
{stringify(normalizedOutput.value, { maxLength: 40 })}
|
||||
</SyntaxHighlighter>
|
||||
</CellWrapper>
|
||||
);
|
||||
}
|
||||
|
||||
const contentToDisplay = (normalizedOutput?.type === "text" && normalizedOutput.value) || "";
|
||||
|
||||
return (
|
||||
<CellWrapper>
|
||||
<Text>{contentToDisplay}</Text>
|
||||
</CellWrapper>
|
||||
);
|
||||
}
|
||||
76
app/src/components/OutputsTable/OutputCell/OutputStats.tsx
Normal file
76
app/src/components/OutputsTable/OutputCell/OutputStats.tsx
Normal file
@@ -0,0 +1,76 @@
|
||||
import { type Scenario } from "../types";
|
||||
import { type RouterOutputs } from "~/utils/api";
|
||||
import { HStack, Icon, Text, Tooltip } from "@chakra-ui/react";
|
||||
import { BsCheck, BsClock, BsCurrencyDollar, BsX } from "react-icons/bs";
|
||||
import { CostTooltip } from "~/components/tooltip/CostTooltip";
|
||||
|
||||
const SHOW_TIME = true;
|
||||
|
||||
export const OutputStats = ({
|
||||
modelResponse,
|
||||
}: {
|
||||
modelResponse: NonNullable<
|
||||
NonNullable<RouterOutputs["scenarioVariantCells"]["get"]>["modelResponses"][0]
|
||||
>;
|
||||
scenario: Scenario;
|
||||
}) => {
|
||||
const timeToComplete =
|
||||
modelResponse.receivedAt && modelResponse.requestedAt
|
||||
? modelResponse.receivedAt.getTime() - modelResponse.requestedAt.getTime()
|
||||
: 0;
|
||||
|
||||
const promptTokens = modelResponse.promptTokens;
|
||||
const completionTokens = modelResponse.completionTokens;
|
||||
|
||||
return (
|
||||
<HStack
|
||||
w="full"
|
||||
align="center"
|
||||
color="gray.500"
|
||||
fontSize="2xs"
|
||||
mt={{ base: 0, md: 1 }}
|
||||
alignItems="flex-end"
|
||||
>
|
||||
<HStack flex={1} flexWrap="wrap">
|
||||
{modelResponse.outputEvaluations.map((evaluation) => {
|
||||
const passed = evaluation.result > 0.5;
|
||||
return (
|
||||
<Tooltip
|
||||
isDisabled={!evaluation.details}
|
||||
label={evaluation.details}
|
||||
key={evaluation.id}
|
||||
shouldWrapChildren
|
||||
>
|
||||
<HStack spacing={0}>
|
||||
<Text>{evaluation.evaluation.label}</Text>
|
||||
<Icon
|
||||
as={passed ? BsCheck : BsX}
|
||||
color={passed ? "green.500" : "red.500"}
|
||||
boxSize={6}
|
||||
/>
|
||||
</HStack>
|
||||
</Tooltip>
|
||||
);
|
||||
})}
|
||||
</HStack>
|
||||
{modelResponse.cost && (
|
||||
<CostTooltip
|
||||
promptTokens={promptTokens}
|
||||
completionTokens={completionTokens}
|
||||
cost={modelResponse.cost}
|
||||
>
|
||||
<HStack spacing={0}>
|
||||
<Icon as={BsCurrencyDollar} />
|
||||
<Text mr={1}>{modelResponse.cost.toFixed(3)}</Text>
|
||||
</HStack>
|
||||
</CostTooltip>
|
||||
)}
|
||||
{SHOW_TIME && (
|
||||
<HStack spacing={0.5}>
|
||||
<Icon as={BsClock} />
|
||||
<Text>{(timeToComplete / 1000).toFixed(2)}s</Text>
|
||||
</HStack>
|
||||
)}
|
||||
</HStack>
|
||||
);
|
||||
};
|
||||
36
app/src/components/OutputsTable/OutputCell/PromptModal.tsx
Normal file
36
app/src/components/OutputsTable/OutputCell/PromptModal.tsx
Normal file
@@ -0,0 +1,36 @@
|
||||
import {
|
||||
Modal,
|
||||
ModalBody,
|
||||
ModalCloseButton,
|
||||
ModalContent,
|
||||
ModalHeader,
|
||||
ModalOverlay,
|
||||
type UseDisclosureReturn,
|
||||
} from "@chakra-ui/react";
|
||||
import { type RouterOutputs } from "~/utils/api";
|
||||
import { JSONTree } from "react-json-tree";
|
||||
|
||||
export default function ExpandedModal(props: {
|
||||
cell: NonNullable<RouterOutputs["scenarioVariantCells"]["get"]>;
|
||||
disclosure: UseDisclosureReturn;
|
||||
}) {
|
||||
return (
|
||||
<Modal isOpen={props.disclosure.isOpen} onClose={props.disclosure.onClose} size="2xl">
|
||||
<ModalOverlay />
|
||||
<ModalContent>
|
||||
<ModalHeader>Prompt</ModalHeader>
|
||||
<ModalCloseButton />
|
||||
<ModalBody>
|
||||
<JSONTree
|
||||
data={props.cell.prompt}
|
||||
invertTheme={true}
|
||||
theme="chalk"
|
||||
shouldExpandNodeInitially={() => true}
|
||||
getItemString={() => ""}
|
||||
hideRoot
|
||||
/>
|
||||
</ModalBody>
|
||||
</ModalContent>
|
||||
</Modal>
|
||||
);
|
||||
}
|
||||
22
app/src/components/OutputsTable/OutputCell/ResponseLog.tsx
Normal file
22
app/src/components/OutputsTable/OutputCell/ResponseLog.tsx
Normal file
@@ -0,0 +1,22 @@
|
||||
import { HStack, VStack, Text } from "@chakra-ui/react";
|
||||
import dayjs from "dayjs";
|
||||
|
||||
export const ResponseLog = ({
|
||||
time,
|
||||
title,
|
||||
message,
|
||||
}: {
|
||||
time: Date;
|
||||
title: string;
|
||||
message?: string;
|
||||
}) => {
|
||||
return (
|
||||
<VStack spacing={0} alignItems="flex-start">
|
||||
<HStack>
|
||||
<Text>{dayjs(time).format("HH:mm:ss")}</Text>
|
||||
<Text>{title}</Text>
|
||||
</HStack>
|
||||
{message && <Text pl={4}>{message}</Text>}
|
||||
</VStack>
|
||||
);
|
||||
};
|
||||
@@ -0,0 +1,39 @@
|
||||
import { Text } from "@chakra-ui/react";
|
||||
import { useEffect, useState } from "react";
|
||||
import pluralize from "pluralize";
|
||||
|
||||
export const RetryCountdown = ({ retryTime }: { retryTime: Date }) => {
|
||||
const [msToWait, setMsToWait] = useState(0);
|
||||
|
||||
useEffect(() => {
|
||||
const initialWaitTime = retryTime.getTime() - Date.now();
|
||||
const msModuloOneSecond = initialWaitTime % 1000;
|
||||
let remainingTime = initialWaitTime - msModuloOneSecond;
|
||||
setMsToWait(remainingTime);
|
||||
|
||||
let interval: NodeJS.Timeout;
|
||||
const timeout = setTimeout(() => {
|
||||
interval = setInterval(() => {
|
||||
remainingTime -= 1000;
|
||||
setMsToWait(remainingTime);
|
||||
|
||||
if (remainingTime <= 0) {
|
||||
clearInterval(interval);
|
||||
}
|
||||
}, 1000);
|
||||
}, msModuloOneSecond);
|
||||
|
||||
return () => {
|
||||
clearInterval(interval);
|
||||
clearTimeout(timeout);
|
||||
};
|
||||
}, [retryTime]);
|
||||
|
||||
if (msToWait <= 0) return null;
|
||||
|
||||
return (
|
||||
<Text color="red.600" fontSize="sm">
|
||||
Retrying in {pluralize("second", Math.ceil(msToWait / 1000), true)}...
|
||||
</Text>
|
||||
);
|
||||
};
|
||||
53
app/src/components/OutputsTable/OutputCell/TopActions.tsx
Normal file
53
app/src/components/OutputsTable/OutputCell/TopActions.tsx
Normal file
@@ -0,0 +1,53 @@
|
||||
import { HStack, Icon, IconButton, Spinner, Tooltip, useDisclosure } from "@chakra-ui/react";
|
||||
import { BsArrowClockwise, BsInfoCircle } from "react-icons/bs";
|
||||
import { useExperimentAccess } from "~/utils/hooks";
|
||||
import ExpandedModal from "./PromptModal";
|
||||
import { type RouterOutputs } from "~/utils/api";
|
||||
|
||||
export const CellOptions = ({
|
||||
cell,
|
||||
refetchingOutput,
|
||||
refetchOutput,
|
||||
}: {
|
||||
cell: RouterOutputs["scenarioVariantCells"]["get"];
|
||||
refetchingOutput: boolean;
|
||||
refetchOutput: () => void;
|
||||
}) => {
|
||||
const { canModify } = useExperimentAccess();
|
||||
|
||||
const modalDisclosure = useDisclosure();
|
||||
|
||||
return (
|
||||
<HStack justifyContent="flex-end" w="full">
|
||||
{cell && (
|
||||
<>
|
||||
<Tooltip label="See Prompt">
|
||||
<IconButton
|
||||
aria-label="See Prompt"
|
||||
icon={<Icon as={BsInfoCircle} boxSize={4} />}
|
||||
onClick={modalDisclosure.onOpen}
|
||||
size="xs"
|
||||
colorScheme="gray"
|
||||
color="gray.500"
|
||||
variant="ghost"
|
||||
/>
|
||||
</Tooltip>
|
||||
<ExpandedModal cell={cell} disclosure={modalDisclosure} />
|
||||
</>
|
||||
)}
|
||||
{canModify && (
|
||||
<Tooltip label="Refetch output">
|
||||
<IconButton
|
||||
size="xs"
|
||||
color="gray.500"
|
||||
variant="ghost"
|
||||
cursor="pointer"
|
||||
onClick={refetchOutput}
|
||||
aria-label="refetch output"
|
||||
icon={<Icon as={refetchingOutput ? Spinner : BsArrowClockwise} boxSize={4} />}
|
||||
/>
|
||||
</Tooltip>
|
||||
)}
|
||||
</HStack>
|
||||
);
|
||||
};
|
||||
207
app/src/components/OutputsTable/ScenarioEditor.tsx
Normal file
207
app/src/components/OutputsTable/ScenarioEditor.tsx
Normal file
@@ -0,0 +1,207 @@
|
||||
import { isEqual } from "lodash-es";
|
||||
import { useEffect, useState, type DragEvent } from "react";
|
||||
import { api } from "~/utils/api";
|
||||
import { useExperiment, useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
|
||||
import { type Scenario } from "./types";
|
||||
|
||||
import {
|
||||
Box,
|
||||
Button,
|
||||
HStack,
|
||||
Icon,
|
||||
IconButton,
|
||||
Spinner,
|
||||
Text,
|
||||
Tooltip,
|
||||
VStack,
|
||||
} from "@chakra-ui/react";
|
||||
import { BsArrowsAngleExpand, BsX } from "react-icons/bs";
|
||||
import { cellPadding } from "../constants";
|
||||
import { FloatingLabelInput } from "./FloatingLabelInput";
|
||||
import { ScenarioEditorModal } from "./ScenarioEditorModal";
|
||||
|
||||
export default function ScenarioEditor({
|
||||
scenario,
|
||||
...props
|
||||
}: {
|
||||
scenario: Scenario;
|
||||
hovered: boolean;
|
||||
canHide: boolean;
|
||||
}) {
|
||||
const { canModify } = useExperimentAccess();
|
||||
|
||||
const savedValues = scenario.variableValues as Record<string, string>;
|
||||
const utils = api.useContext();
|
||||
const [isDragTarget, setIsDragTarget] = useState(false);
|
||||
const [variableInputHovered, setVariableInputHovered] = useState(false);
|
||||
|
||||
const [values, setValues] = useState<Record<string, string>>(savedValues);
|
||||
|
||||
useEffect(() => {
|
||||
if (savedValues) setValues(savedValues);
|
||||
}, [savedValues]);
|
||||
|
||||
const experiment = useExperiment();
|
||||
const vars = api.templateVars.list.useQuery({ experimentId: experiment.data?.id ?? "" });
|
||||
|
||||
const variableLabels = vars.data?.map((v) => v.label) ?? [];
|
||||
|
||||
const hasChanged = !isEqual(savedValues, values);
|
||||
|
||||
const mutation = api.scenarios.replaceWithValues.useMutation();
|
||||
|
||||
const [onSave] = useHandledAsyncCallback(async () => {
|
||||
await mutation.mutateAsync({
|
||||
id: scenario.id,
|
||||
values,
|
||||
});
|
||||
await utils.scenarios.list.invalidate();
|
||||
}, [mutation, values]);
|
||||
|
||||
const hideMutation = api.scenarios.hide.useMutation();
|
||||
const [onHide, hidingInProgress] = useHandledAsyncCallback(async () => {
|
||||
await hideMutation.mutateAsync({
|
||||
id: scenario.id,
|
||||
});
|
||||
await utils.scenarios.list.invalidate();
|
||||
await utils.promptVariants.stats.invalidate();
|
||||
}, [hideMutation, scenario.id]);
|
||||
|
||||
const reorderMutation = api.scenarios.reorder.useMutation();
|
||||
const [onReorder] = useHandledAsyncCallback(
|
||||
async (e: DragEvent<HTMLDivElement>) => {
|
||||
e.preventDefault();
|
||||
setIsDragTarget(false);
|
||||
const draggedId = e.dataTransfer.getData("text/plain");
|
||||
const droppedId = scenario.id;
|
||||
if (!draggedId || !droppedId || draggedId === droppedId) return;
|
||||
await reorderMutation.mutateAsync({
|
||||
draggedId,
|
||||
droppedId,
|
||||
});
|
||||
await utils.scenarios.list.invalidate();
|
||||
},
|
||||
[reorderMutation, scenario.id],
|
||||
);
|
||||
|
||||
const [scenarioEditorModalOpen, setScenarioEditorModalOpen] = useState(false);
|
||||
|
||||
return (
|
||||
<>
|
||||
<HStack
|
||||
alignItems="flex-start"
|
||||
px={cellPadding.x}
|
||||
py={cellPadding.y}
|
||||
spacing={0}
|
||||
height="100%"
|
||||
draggable={!variableInputHovered}
|
||||
onDragStart={(e) => {
|
||||
e.dataTransfer.setData("text/plain", scenario.id);
|
||||
e.currentTarget.style.opacity = "0.4";
|
||||
}}
|
||||
onDragEnd={(e) => {
|
||||
e.currentTarget.style.opacity = "1";
|
||||
}}
|
||||
onDragOver={(e) => {
|
||||
e.preventDefault();
|
||||
setIsDragTarget(true);
|
||||
}}
|
||||
onDragLeave={() => {
|
||||
setIsDragTarget(false);
|
||||
}}
|
||||
onDrop={onReorder}
|
||||
backgroundColor={isDragTarget ? "gray.100" : "transparent"}
|
||||
>
|
||||
{variableLabels.length === 0 ? (
|
||||
<Box color="gray.500">
|
||||
{vars.data ? "No scenario variables configured" : "Loading..."}
|
||||
</Box>
|
||||
) : (
|
||||
<VStack spacing={4} flex={1} py={2}>
|
||||
<HStack justifyContent="space-between" w="100%" align="center" spacing={0}>
|
||||
<Text flex={1}>Scenario</Text>
|
||||
<Tooltip label="Expand" hasArrow>
|
||||
<IconButton
|
||||
aria-label="Expand"
|
||||
icon={<Icon as={BsArrowsAngleExpand} boxSize={3} />}
|
||||
onClick={() => setScenarioEditorModalOpen(true)}
|
||||
size="xs"
|
||||
colorScheme="gray"
|
||||
color="gray.500"
|
||||
variant="ghost"
|
||||
/>
|
||||
</Tooltip>
|
||||
{canModify && props.canHide && (
|
||||
<Tooltip label="Delete" hasArrow>
|
||||
<IconButton
|
||||
aria-label="Delete"
|
||||
icon={
|
||||
<Icon
|
||||
as={hidingInProgress ? Spinner : BsX}
|
||||
boxSize={hidingInProgress ? 4 : 6}
|
||||
/>
|
||||
}
|
||||
onClick={onHide}
|
||||
size="xs"
|
||||
display="flex"
|
||||
colorScheme="gray"
|
||||
color="gray.500"
|
||||
variant="ghost"
|
||||
/>
|
||||
</Tooltip>
|
||||
)}
|
||||
</HStack>
|
||||
{variableLabels.map((key) => {
|
||||
const value = values[key] ?? "";
|
||||
return (
|
||||
<FloatingLabelInput
|
||||
key={key}
|
||||
label={key}
|
||||
isDisabled={!canModify}
|
||||
style={{ width: "100%" }}
|
||||
maxHeight={32}
|
||||
value={value}
|
||||
onChange={(e) => {
|
||||
setValues((prev) => ({ ...prev, [key]: e.target.value }));
|
||||
}}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === "Enter" && (e.metaKey || e.ctrlKey)) {
|
||||
e.preventDefault();
|
||||
e.currentTarget.blur();
|
||||
onSave();
|
||||
}
|
||||
}}
|
||||
onMouseEnter={() => setVariableInputHovered(true)}
|
||||
onMouseLeave={() => setVariableInputHovered(false)}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
{hasChanged && (
|
||||
<HStack justify="right">
|
||||
<Button
|
||||
size="sm"
|
||||
onMouseDown={() => {
|
||||
setValues(savedValues);
|
||||
}}
|
||||
colorScheme="gray"
|
||||
>
|
||||
Reset
|
||||
</Button>
|
||||
<Button size="sm" onMouseDown={onSave} colorScheme="blue">
|
||||
Save
|
||||
</Button>
|
||||
</HStack>
|
||||
)}
|
||||
</VStack>
|
||||
)}
|
||||
</HStack>
|
||||
{scenarioEditorModalOpen && (
|
||||
<ScenarioEditorModal
|
||||
scenarioId={scenario.id}
|
||||
initialValues={savedValues}
|
||||
onClose={() => setScenarioEditorModalOpen(false)}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
}
|
||||
123
app/src/components/OutputsTable/ScenarioEditorModal.tsx
Normal file
123
app/src/components/OutputsTable/ScenarioEditorModal.tsx
Normal file
@@ -0,0 +1,123 @@
|
||||
import {
|
||||
Button,
|
||||
HStack,
|
||||
Modal,
|
||||
ModalBody,
|
||||
ModalCloseButton,
|
||||
ModalContent,
|
||||
ModalFooter,
|
||||
ModalHeader,
|
||||
ModalOverlay,
|
||||
Spinner,
|
||||
Text,
|
||||
VStack,
|
||||
} from "@chakra-ui/react";
|
||||
import { useEffect, useState } from "react";
|
||||
import { isEqual } from "lodash-es";
|
||||
|
||||
import { api } from "~/utils/api";
|
||||
import {
|
||||
useScenario,
|
||||
useHandledAsyncCallback,
|
||||
useExperiment,
|
||||
useExperimentAccess,
|
||||
} from "~/utils/hooks";
|
||||
import { FloatingLabelInput } from "./FloatingLabelInput";
|
||||
|
||||
export const ScenarioEditorModal = ({
|
||||
scenarioId,
|
||||
initialValues,
|
||||
onClose,
|
||||
}: {
|
||||
scenarioId: string;
|
||||
initialValues: Record<string, string>;
|
||||
onClose: () => void;
|
||||
}) => {
|
||||
const utils = api.useContext();
|
||||
const experiment = useExperiment();
|
||||
const { canModify } = useExperimentAccess();
|
||||
const scenario = useScenario(scenarioId);
|
||||
|
||||
const savedValues = scenario.data?.variableValues as Record<string, string>;
|
||||
|
||||
const [values, setValues] = useState<Record<string, string>>(initialValues);
|
||||
|
||||
useEffect(() => {
|
||||
if (savedValues) setValues(savedValues);
|
||||
}, [savedValues]);
|
||||
|
||||
const hasChanged = !isEqual(savedValues, values);
|
||||
|
||||
const mutation = api.scenarios.replaceWithValues.useMutation();
|
||||
|
||||
const [onSave, saving] = useHandledAsyncCallback(async () => {
|
||||
await mutation.mutateAsync({
|
||||
id: scenarioId,
|
||||
values,
|
||||
});
|
||||
await utils.scenarios.list.invalidate();
|
||||
}, [mutation, values]);
|
||||
|
||||
const vars = api.templateVars.list.useQuery({ experimentId: experiment.data?.id ?? "" });
|
||||
const variableLabels = vars.data?.map((v) => v.label) ?? [];
|
||||
|
||||
return (
|
||||
<Modal
|
||||
isOpen
|
||||
onClose={onClose}
|
||||
size={{ base: "xl", sm: "2xl", md: "3xl", lg: "5xl", xl: "7xl" }}
|
||||
>
|
||||
<ModalOverlay />
|
||||
<ModalContent w={1200}>
|
||||
<ModalHeader />
|
||||
<ModalCloseButton />
|
||||
<ModalBody maxW="unset">
|
||||
<VStack spacing={8}>
|
||||
{values &&
|
||||
variableLabels.map((key) => {
|
||||
const value = values[key] ?? "";
|
||||
return (
|
||||
<FloatingLabelInput
|
||||
key={key}
|
||||
label={key}
|
||||
isDisabled={!canModify}
|
||||
_disabled={{ opacity: 1 }}
|
||||
style={{ width: "100%" }}
|
||||
value={value}
|
||||
onChange={(e) => {
|
||||
setValues((prev) => ({ ...prev, [key]: e.target.value }));
|
||||
}}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === "Enter" && (e.metaKey || e.ctrlKey)) {
|
||||
e.preventDefault();
|
||||
e.currentTarget.blur();
|
||||
onSave();
|
||||
}
|
||||
}}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
</VStack>
|
||||
</ModalBody>
|
||||
|
||||
<ModalFooter>
|
||||
{canModify && (
|
||||
<HStack>
|
||||
<Button
|
||||
colorScheme="gray"
|
||||
onClick={() => setValues(savedValues)}
|
||||
minW={24}
|
||||
isDisabled={!hasChanged}
|
||||
>
|
||||
<Text>Reset</Text>
|
||||
</Button>
|
||||
<Button colorScheme="blue" onClick={onSave} minW={24} isDisabled={!hasChanged}>
|
||||
{saving ? <Spinner boxSize={4} /> : <Text>Save</Text>}
|
||||
</Button>
|
||||
</HStack>
|
||||
)}
|
||||
</ModalFooter>
|
||||
</ModalContent>
|
||||
</Modal>
|
||||
);
|
||||
};
|
||||
21
app/src/components/OutputsTable/ScenarioPaginator.tsx
Normal file
21
app/src/components/OutputsTable/ScenarioPaginator.tsx
Normal file
@@ -0,0 +1,21 @@
|
||||
import { useScenarios } from "~/utils/hooks";
|
||||
import Paginator from "../Paginator";
|
||||
|
||||
const ScenarioPaginator = () => {
|
||||
const { data } = useScenarios();
|
||||
|
||||
if (!data) return null;
|
||||
|
||||
const { scenarios, startIndex, lastPage, count } = data;
|
||||
|
||||
return (
|
||||
<Paginator
|
||||
numItemsLoaded={scenarios.length}
|
||||
startIndex={startIndex}
|
||||
lastPage={lastPage}
|
||||
count={count}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default ScenarioPaginator;
|
||||
48
app/src/components/OutputsTable/ScenarioRow.tsx
Normal file
48
app/src/components/OutputsTable/ScenarioRow.tsx
Normal file
@@ -0,0 +1,48 @@
|
||||
import { GridItem } from "@chakra-ui/react";
|
||||
import React, { useState } from "react";
|
||||
import OutputCell from "./OutputCell/OutputCell";
|
||||
import ScenarioEditor from "./ScenarioEditor";
|
||||
import type { PromptVariant, Scenario } from "./types";
|
||||
import { borders } from "./styles";
|
||||
|
||||
const ScenarioRow = (props: {
|
||||
scenario: Scenario;
|
||||
variants: PromptVariant[];
|
||||
canHide: boolean;
|
||||
rowStart: number;
|
||||
}) => {
|
||||
const [isHovered, setIsHovered] = useState(false);
|
||||
|
||||
const highlightStyle = { backgroundColor: "gray.50" };
|
||||
|
||||
return (
|
||||
<>
|
||||
<GridItem
|
||||
onMouseEnter={() => setIsHovered(true)}
|
||||
onMouseLeave={() => setIsHovered(false)}
|
||||
sx={isHovered ? highlightStyle : undefined}
|
||||
borderLeftWidth={1}
|
||||
{...borders}
|
||||
rowStart={props.rowStart}
|
||||
colStart={1}
|
||||
>
|
||||
<ScenarioEditor scenario={props.scenario} hovered={isHovered} canHide={props.canHide} />
|
||||
</GridItem>
|
||||
{props.variants.map((variant, i) => (
|
||||
<GridItem
|
||||
key={variant.id}
|
||||
onMouseEnter={() => setIsHovered(true)}
|
||||
onMouseLeave={() => setIsHovered(false)}
|
||||
sx={isHovered ? highlightStyle : undefined}
|
||||
rowStart={props.rowStart}
|
||||
colStart={i + 2}
|
||||
{...borders}
|
||||
>
|
||||
<OutputCell key={variant.id} scenario={props.scenario} variant={variant} />
|
||||
</GridItem>
|
||||
))}
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
export default ScenarioRow;
|
||||
82
app/src/components/OutputsTable/ScenariosHeader.tsx
Normal file
82
app/src/components/OutputsTable/ScenariosHeader.tsx
Normal file
@@ -0,0 +1,82 @@
|
||||
import {
|
||||
Button,
|
||||
type ButtonProps,
|
||||
HStack,
|
||||
Text,
|
||||
Icon,
|
||||
Menu,
|
||||
MenuButton,
|
||||
MenuList,
|
||||
MenuItem,
|
||||
IconButton,
|
||||
Spinner,
|
||||
} from "@chakra-ui/react";
|
||||
import { cellPadding } from "../constants";
|
||||
import {
|
||||
useExperiment,
|
||||
useExperimentAccess,
|
||||
useHandledAsyncCallback,
|
||||
useScenarios,
|
||||
} from "~/utils/hooks";
|
||||
import { BsGear, BsPencil, BsPlus, BsStars } from "react-icons/bs";
|
||||
import { useAppStore } from "~/state/store";
|
||||
import { api } from "~/utils/api";
|
||||
|
||||
export const ActionButton = (props: ButtonProps) => (
|
||||
<Button size="sm" variant="ghost" color="gray.600" {...props} />
|
||||
);
|
||||
|
||||
export const ScenariosHeader = () => {
|
||||
const openDrawer = useAppStore((s) => s.openDrawer);
|
||||
const { canModify } = useExperimentAccess();
|
||||
const scenarios = useScenarios();
|
||||
|
||||
const experiment = useExperiment();
|
||||
const createScenarioMutation = api.scenarios.create.useMutation();
|
||||
const utils = api.useContext();
|
||||
|
||||
const [onAddScenario, loading] = useHandledAsyncCallback(
|
||||
async (autogenerate: boolean) => {
|
||||
if (!experiment.data) return;
|
||||
await createScenarioMutation.mutateAsync({
|
||||
experimentId: experiment.data.id,
|
||||
autogenerate,
|
||||
});
|
||||
await utils.scenarios.list.invalidate();
|
||||
},
|
||||
[createScenarioMutation],
|
||||
);
|
||||
|
||||
return (
|
||||
<HStack w="100%" pb={cellPadding.y} pt={0} align="center" spacing={0}>
|
||||
<Text fontSize={16} fontWeight="bold">
|
||||
Scenarios ({scenarios.data?.count})
|
||||
</Text>
|
||||
{canModify && (
|
||||
<Menu>
|
||||
<MenuButton
|
||||
as={IconButton}
|
||||
mt={1}
|
||||
variant="ghost"
|
||||
aria-label="Edit Scenarios"
|
||||
icon={<Icon as={loading ? Spinner : BsGear} />}
|
||||
/>
|
||||
<MenuList fontSize="md" zIndex="dropdown" mt={-3}>
|
||||
<MenuItem
|
||||
icon={<Icon as={BsPlus} boxSize={6} mx="-5px" />}
|
||||
onClick={() => onAddScenario(false)}
|
||||
>
|
||||
Add Scenario
|
||||
</MenuItem>
|
||||
<MenuItem icon={<BsStars />} onClick={() => onAddScenario(true)}>
|
||||
Autogenerate Scenario
|
||||
</MenuItem>
|
||||
<MenuItem icon={<BsPencil />} onClick={openDrawer}>
|
||||
Edit Vars
|
||||
</MenuItem>
|
||||
</MenuList>
|
||||
</Menu>
|
||||
)}
|
||||
</HStack>
|
||||
);
|
||||
};
|
||||
239
app/src/components/OutputsTable/VariantEditor.tsx
Normal file
239
app/src/components/OutputsTable/VariantEditor.tsx
Normal file
@@ -0,0 +1,239 @@
|
||||
import {
|
||||
Box,
|
||||
Button,
|
||||
HStack,
|
||||
IconButton,
|
||||
Spinner,
|
||||
Text,
|
||||
Tooltip,
|
||||
useToast,
|
||||
} from "@chakra-ui/react";
|
||||
import { useCallback, useEffect, useRef, useState } from "react";
|
||||
import { FiMaximize, FiMinimize } from "react-icons/fi";
|
||||
import { editorBackground } from "~/state/sharedVariantEditor.slice";
|
||||
import { useAppStore } from "~/state/store";
|
||||
import { api } from "~/utils/api";
|
||||
import {
|
||||
useExperimentAccess,
|
||||
useHandledAsyncCallback,
|
||||
useModifierKeyLabel,
|
||||
useVisibleScenarioIds,
|
||||
} from "~/utils/hooks";
|
||||
import { type PromptVariant } from "./types";
|
||||
|
||||
export default function VariantEditor(props: { variant: PromptVariant }) {
|
||||
const { canModify } = useExperimentAccess();
|
||||
const monaco = useAppStore.use.sharedVariantEditor.monaco();
|
||||
const editorRef = useRef<ReturnType<NonNullable<typeof monaco>["editor"]["create"]> | null>(null);
|
||||
const containerRef = useRef<HTMLDivElement | null>(null);
|
||||
const [editorId] = useState(() => `editor_${Math.random().toString(36).substring(7)}`);
|
||||
const [isChanged, setIsChanged] = useState(false);
|
||||
|
||||
const [isFullscreen, setIsFullscreen] = useState(false);
|
||||
|
||||
const toggleFullscreen = useCallback(() => {
|
||||
setIsFullscreen((prev) => !prev);
|
||||
editorRef.current?.focus();
|
||||
}, [setIsFullscreen]);
|
||||
|
||||
useEffect(() => {
|
||||
const handleEsc = (event: KeyboardEvent) => {
|
||||
if (event.key === "Escape" && isFullscreen) {
|
||||
toggleFullscreen();
|
||||
}
|
||||
};
|
||||
|
||||
window.addEventListener("keydown", handleEsc);
|
||||
return () => window.removeEventListener("keydown", handleEsc);
|
||||
}, [isFullscreen, toggleFullscreen]);
|
||||
|
||||
const lastSavedFn = props.variant.promptConstructor;
|
||||
|
||||
const modifierKey = useModifierKeyLabel();
|
||||
|
||||
const checkForChanges = useCallback(() => {
|
||||
if (!editorRef.current) return;
|
||||
const currentFn = editorRef.current.getValue();
|
||||
setIsChanged(currentFn.length > 0 && currentFn !== lastSavedFn);
|
||||
}, [lastSavedFn]);
|
||||
|
||||
const matchUpdatedSavedFn = useCallback(() => {
|
||||
if (!editorRef.current) return;
|
||||
editorRef.current.setValue(lastSavedFn);
|
||||
setIsChanged(false);
|
||||
}, [lastSavedFn]);
|
||||
|
||||
useEffect(matchUpdatedSavedFn, [matchUpdatedSavedFn, lastSavedFn]);
|
||||
|
||||
const replaceVariant = api.promptVariants.replaceVariant.useMutation();
|
||||
const utils = api.useContext();
|
||||
const toast = useToast();
|
||||
const visibleScenarios = useVisibleScenarioIds();
|
||||
|
||||
const [onSave, saveInProgress] = useHandledAsyncCallback(async () => {
|
||||
if (!editorRef.current) return;
|
||||
|
||||
await editorRef.current.getAction("editor.action.formatDocument")?.run();
|
||||
|
||||
const currentFn = editorRef.current.getValue();
|
||||
|
||||
if (!currentFn) return;
|
||||
|
||||
// Check if the editor has any typescript errors
|
||||
const model = editorRef.current.getModel();
|
||||
if (!model) return;
|
||||
|
||||
// Make sure the user defined the prompt with the string "prompt\w*=" somewhere
|
||||
const promptRegex = /definePrompt\(/;
|
||||
if (!promptRegex.test(currentFn)) {
|
||||
toast({
|
||||
title: "Missing prompt",
|
||||
description: "Please define the prompt (eg. `definePrompt(...`",
|
||||
status: "error",
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
const resp = await replaceVariant.mutateAsync({
|
||||
id: props.variant.id,
|
||||
promptConstructor: currentFn,
|
||||
streamScenarios: visibleScenarios,
|
||||
});
|
||||
if (resp.status === "error") {
|
||||
return toast({
|
||||
title: "Error saving variant",
|
||||
description: resp.message,
|
||||
status: "error",
|
||||
});
|
||||
}
|
||||
|
||||
setIsChanged(false);
|
||||
|
||||
await utils.promptVariants.list.invalidate();
|
||||
}, [checkForChanges]);
|
||||
|
||||
useEffect(() => {
|
||||
if (monaco) {
|
||||
const container = document.getElementById(editorId) as HTMLElement;
|
||||
|
||||
editorRef.current = monaco.editor.create(container, {
|
||||
value: lastSavedFn,
|
||||
language: "typescript",
|
||||
theme: "customTheme",
|
||||
lineNumbers: "off",
|
||||
minimap: { enabled: false },
|
||||
wrappingIndent: "indent",
|
||||
wrappingStrategy: "advanced",
|
||||
wordWrap: "on",
|
||||
folding: false,
|
||||
scrollbar: {
|
||||
alwaysConsumeMouseWheel: false,
|
||||
verticalScrollbarSize: 0,
|
||||
},
|
||||
wordWrapBreakAfterCharacters: "",
|
||||
wordWrapBreakBeforeCharacters: "",
|
||||
quickSuggestions: true,
|
||||
readOnly: !canModify,
|
||||
});
|
||||
|
||||
// Workaround because otherwise the commands only work on whatever
|
||||
// editor was loaded on the page last.
|
||||
// https://github.com/microsoft/monaco-editor/issues/2947#issuecomment-1422265201
|
||||
editorRef.current.onDidFocusEditorText(() => {
|
||||
editorRef.current?.addCommand(monaco.KeyMod.CtrlCmd | monaco.KeyCode.KeyS, onSave);
|
||||
|
||||
editorRef.current?.addCommand(
|
||||
monaco.KeyMod.CtrlCmd | monaco.KeyMod.Shift | monaco.KeyCode.KeyF,
|
||||
toggleFullscreen,
|
||||
);
|
||||
|
||||
// Exit fullscreen with escape
|
||||
editorRef.current?.addCommand(monaco.KeyCode.Escape, () => {
|
||||
if (isFullscreen) {
|
||||
toggleFullscreen();
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
editorRef.current.onDidChangeModelContent(checkForChanges);
|
||||
|
||||
const resizeObserver = new ResizeObserver(() => {
|
||||
editorRef.current?.layout();
|
||||
});
|
||||
resizeObserver.observe(container);
|
||||
|
||||
return () => {
|
||||
resizeObserver.disconnect();
|
||||
editorRef.current?.dispose();
|
||||
};
|
||||
}
|
||||
|
||||
// We intentionally skip the onSave and props.savedConfig dependencies here because
|
||||
// we don't want to re-render the editor from scratch
|
||||
/* eslint-disable-next-line react-hooks/exhaustive-deps */
|
||||
}, [monaco, editorId]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!editorRef.current) return;
|
||||
editorRef.current.updateOptions({
|
||||
readOnly: !canModify,
|
||||
});
|
||||
}, [canModify]);
|
||||
|
||||
return (
|
||||
<Box
|
||||
w="100%"
|
||||
ref={containerRef}
|
||||
sx={
|
||||
isFullscreen
|
||||
? {
|
||||
position: "fixed",
|
||||
top: 0,
|
||||
left: 0,
|
||||
right: 0,
|
||||
bottom: 0,
|
||||
}
|
||||
: { h: "400px", w: "100%" }
|
||||
}
|
||||
bgColor={editorBackground}
|
||||
zIndex={isFullscreen ? 1000 : "unset"}
|
||||
pos="relative"
|
||||
_hover={{ ".fullscreen-toggle": { opacity: 1 } }}
|
||||
>
|
||||
<Box id={editorId} w="100%" h="100%" />
|
||||
<Tooltip label={`${modifierKey} + ⇧ + F`}>
|
||||
<IconButton
|
||||
className="fullscreen-toggle"
|
||||
aria-label="Minimize"
|
||||
icon={isFullscreen ? <FiMinimize /> : <FiMaximize />}
|
||||
position="absolute"
|
||||
top={2}
|
||||
right={2}
|
||||
onClick={toggleFullscreen}
|
||||
opacity={0}
|
||||
transition="opacity 0.2s"
|
||||
/>
|
||||
</Tooltip>
|
||||
|
||||
{isChanged && (
|
||||
<HStack pos="absolute" bottom={2} right={2}>
|
||||
<Button
|
||||
colorScheme="gray"
|
||||
size="sm"
|
||||
onClick={() => {
|
||||
editorRef.current?.setValue(lastSavedFn);
|
||||
checkForChanges();
|
||||
}}
|
||||
>
|
||||
Reset
|
||||
</Button>
|
||||
<Tooltip label={`${modifierKey} + S`}>
|
||||
<Button size="sm" onClick={onSave} colorScheme="blue" w={16} disabled={saveInProgress}>
|
||||
{saveInProgress ? <Spinner boxSize={4} /> : <Text>Save</Text>}
|
||||
</Button>
|
||||
</Tooltip>
|
||||
</HStack>
|
||||
)}
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
83
app/src/components/OutputsTable/VariantStats.tsx
Normal file
83
app/src/components/OutputsTable/VariantStats.tsx
Normal file
@@ -0,0 +1,83 @@
|
||||
import { HStack, Icon, Text, useToken } from "@chakra-ui/react";
|
||||
import { type PromptVariant } from "./types";
|
||||
import { cellPadding } from "../constants";
|
||||
import { api } from "~/utils/api";
|
||||
import chroma from "chroma-js";
|
||||
import { BsCurrencyDollar } from "react-icons/bs";
|
||||
import { CostTooltip } from "../tooltip/CostTooltip";
|
||||
import { useEffect, useState } from "react";
|
||||
|
||||
export default function VariantStats(props: { variant: PromptVariant }) {
|
||||
const [refetchInterval, setRefetchInterval] = useState(0);
|
||||
const { data } = api.promptVariants.stats.useQuery(
|
||||
{
|
||||
variantId: props.variant.id,
|
||||
},
|
||||
{
|
||||
initialData: {
|
||||
evalResults: [],
|
||||
overallCost: 0,
|
||||
promptTokens: 0,
|
||||
completionTokens: 0,
|
||||
scenarioCount: 0,
|
||||
outputCount: 0,
|
||||
awaitingEvals: false,
|
||||
},
|
||||
refetchInterval,
|
||||
},
|
||||
);
|
||||
|
||||
// Poll every two seconds while we are waiting for LLM retrievals to finish
|
||||
useEffect(() => setRefetchInterval(data.awaitingEvals ? 5000 : 0), [data.awaitingEvals]);
|
||||
|
||||
const [passColor, neutralColor, failColor] = useToken("colors", [
|
||||
"green.500",
|
||||
"gray.500",
|
||||
"red.500",
|
||||
]);
|
||||
|
||||
const scale = chroma.scale([failColor, neutralColor, passColor]).domain([0, 0.5, 1]);
|
||||
|
||||
const showNumFinished = data.scenarioCount > 0 && data.scenarioCount !== data.outputCount;
|
||||
|
||||
return (
|
||||
<HStack
|
||||
justifyContent="space-between"
|
||||
alignItems="flex-end"
|
||||
mx="2"
|
||||
fontSize="xs"
|
||||
py={cellPadding.y}
|
||||
>
|
||||
<HStack px={cellPadding.x} flexWrap="wrap">
|
||||
{showNumFinished && (
|
||||
<Text>
|
||||
{data.outputCount} / {data.scenarioCount}
|
||||
</Text>
|
||||
)}
|
||||
{data.evalResults.map((result) => {
|
||||
const passedFrac = result.passCount / result.totalCount;
|
||||
return (
|
||||
<HStack key={result.id}>
|
||||
<Text>{result.label}</Text>
|
||||
<Text color={scale(passedFrac).hex()} fontWeight="bold">
|
||||
{(passedFrac * 100).toFixed(1)}%
|
||||
</Text>
|
||||
</HStack>
|
||||
);
|
||||
})}
|
||||
</HStack>
|
||||
{data.overallCost && (
|
||||
<CostTooltip
|
||||
promptTokens={data.promptTokens}
|
||||
completionTokens={data.completionTokens}
|
||||
cost={data.overallCost}
|
||||
>
|
||||
<HStack spacing={0} align="center" color="gray.500">
|
||||
<Icon as={BsCurrencyDollar} />
|
||||
<Text mr={1}>{data.overallCost.toFixed(3)}</Text>
|
||||
</HStack>
|
||||
</CostTooltip>
|
||||
)}
|
||||
</HStack>
|
||||
);
|
||||
}
|
||||
107
app/src/components/OutputsTable/index.tsx
Normal file
107
app/src/components/OutputsTable/index.tsx
Normal file
@@ -0,0 +1,107 @@
|
||||
import { Grid, GridItem, type GridItemProps } from "@chakra-ui/react";
|
||||
import { api } from "~/utils/api";
|
||||
import AddVariantButton from "./AddVariantButton";
|
||||
import ScenarioRow from "./ScenarioRow";
|
||||
import VariantEditor from "./VariantEditor";
|
||||
import VariantHeader from "../VariantHeader/VariantHeader";
|
||||
import VariantStats from "./VariantStats";
|
||||
import { ScenariosHeader } from "./ScenariosHeader";
|
||||
import { borders } from "./styles";
|
||||
import { useScenarios } from "~/utils/hooks";
|
||||
import ScenarioPaginator from "./ScenarioPaginator";
|
||||
import { Fragment } from "react";
|
||||
|
||||
export default function OutputsTable({ experimentId }: { experimentId: string | undefined }) {
|
||||
const variants = api.promptVariants.list.useQuery(
|
||||
{ experimentId: experimentId as string },
|
||||
{ enabled: !!experimentId },
|
||||
);
|
||||
|
||||
const scenarios = useScenarios();
|
||||
|
||||
if (!variants.data || !scenarios.data) return null;
|
||||
|
||||
const allCols = variants.data.length + 2;
|
||||
const variantHeaderRows = 3;
|
||||
const scenarioHeaderRows = 1;
|
||||
const scenarioFooterRows = 1;
|
||||
const visibleScenariosCount = scenarios.data.scenarios.length;
|
||||
const allRows =
|
||||
variantHeaderRows + scenarioHeaderRows + visibleScenariosCount + scenarioFooterRows;
|
||||
|
||||
return (
|
||||
<Grid
|
||||
pt={4}
|
||||
pb={24}
|
||||
pl={8}
|
||||
display="grid"
|
||||
gridTemplateColumns={`250px repeat(${variants.data.length}, minmax(320px, 1fr)) auto`}
|
||||
sx={{
|
||||
"> *": {
|
||||
borderColor: "gray.300",
|
||||
},
|
||||
}}
|
||||
fontSize="sm"
|
||||
>
|
||||
<GridItem rowSpan={variantHeaderRows}>
|
||||
<AddVariantButton />
|
||||
</GridItem>
|
||||
|
||||
{variants.data.map((variant, i) => {
|
||||
const sharedProps: GridItemProps = {
|
||||
...borders,
|
||||
colStart: i + 2,
|
||||
borderLeftWidth: i === 0 ? 1 : 0,
|
||||
marginLeft: i === 0 ? "-1px" : 0,
|
||||
backgroundColor: "gray.100",
|
||||
};
|
||||
return (
|
||||
<Fragment key={variant.uiId}>
|
||||
<VariantHeader
|
||||
variant={variant}
|
||||
canHide={variants.data.length > 1}
|
||||
rowStart={1}
|
||||
{...sharedProps}
|
||||
/>
|
||||
<GridItem rowStart={2} {...sharedProps}>
|
||||
<VariantEditor variant={variant} />
|
||||
</GridItem>
|
||||
<GridItem rowStart={3} {...sharedProps}>
|
||||
<VariantStats variant={variant} />
|
||||
</GridItem>
|
||||
</Fragment>
|
||||
);
|
||||
})}
|
||||
|
||||
<GridItem
|
||||
colSpan={allCols - 1}
|
||||
rowStart={variantHeaderRows + 1}
|
||||
colStart={1}
|
||||
{...borders}
|
||||
borderRightWidth={0}
|
||||
>
|
||||
<ScenariosHeader />
|
||||
</GridItem>
|
||||
|
||||
{scenarios.data.scenarios.map((scenario, i) => (
|
||||
<ScenarioRow
|
||||
rowStart={i + variantHeaderRows + scenarioHeaderRows + 2}
|
||||
key={scenario.uiId}
|
||||
scenario={scenario}
|
||||
variants={variants.data}
|
||||
canHide={visibleScenariosCount > 1}
|
||||
/>
|
||||
))}
|
||||
<GridItem
|
||||
rowStart={variantHeaderRows + scenarioHeaderRows + visibleScenariosCount + 2}
|
||||
colStart={1}
|
||||
colSpan={allCols}
|
||||
>
|
||||
<ScenarioPaginator />
|
||||
</GridItem>
|
||||
|
||||
{/* Add some extra padding on the right, because when the table is too wide to fit in the viewport `pr` on the Grid isn't respected. */}
|
||||
<GridItem rowStart={1} colStart={allCols} rowSpan={allRows} w={4} borderBottomWidth={0} />
|
||||
</Grid>
|
||||
);
|
||||
}
|
||||
6
app/src/components/OutputsTable/styles.ts
Normal file
6
app/src/components/OutputsTable/styles.ts
Normal file
@@ -0,0 +1,6 @@
|
||||
import { type GridItemProps } from "@chakra-ui/react";
|
||||
|
||||
export const borders: GridItemProps = {
|
||||
borderRightWidth: 1,
|
||||
borderBottomWidth: 1,
|
||||
};
|
||||
5
app/src/components/OutputsTable/types.ts
Normal file
5
app/src/components/OutputsTable/types.ts
Normal file
@@ -0,0 +1,5 @@
|
||||
import { type RouterOutputs } from "~/utils/api";
|
||||
|
||||
export type PromptVariant = NonNullable<RouterOutputs["promptVariants"]["list"]>[0];
|
||||
|
||||
export type Scenario = NonNullable<RouterOutputs["scenarios"]["list"]>["scenarios"][0];
|
||||
79
app/src/components/Paginator.tsx
Normal file
79
app/src/components/Paginator.tsx
Normal file
@@ -0,0 +1,79 @@
|
||||
import { Box, HStack, IconButton } from "@chakra-ui/react";
|
||||
import {
|
||||
BsChevronDoubleLeft,
|
||||
BsChevronDoubleRight,
|
||||
BsChevronLeft,
|
||||
BsChevronRight,
|
||||
} from "react-icons/bs";
|
||||
import { usePage } from "~/utils/hooks";
|
||||
|
||||
const Paginator = ({
|
||||
numItemsLoaded,
|
||||
startIndex,
|
||||
lastPage,
|
||||
count,
|
||||
}: {
|
||||
numItemsLoaded: number;
|
||||
startIndex: number;
|
||||
lastPage: number;
|
||||
count: number;
|
||||
}) => {
|
||||
const [page, setPage] = usePage();
|
||||
|
||||
const nextPage = () => {
|
||||
if (page < lastPage) {
|
||||
setPage(page + 1, "replace");
|
||||
}
|
||||
};
|
||||
|
||||
const prevPage = () => {
|
||||
if (page > 1) {
|
||||
setPage(page - 1, "replace");
|
||||
}
|
||||
};
|
||||
|
||||
const goToLastPage = () => setPage(lastPage, "replace");
|
||||
const goToFirstPage = () => setPage(1, "replace");
|
||||
|
||||
return (
|
||||
<HStack pt={4}>
|
||||
<IconButton
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
onClick={goToFirstPage}
|
||||
isDisabled={page === 1}
|
||||
aria-label="Go to first page"
|
||||
icon={<BsChevronDoubleLeft />}
|
||||
/>
|
||||
<IconButton
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
onClick={prevPage}
|
||||
isDisabled={page === 1}
|
||||
aria-label="Previous page"
|
||||
icon={<BsChevronLeft />}
|
||||
/>
|
||||
<Box>
|
||||
{startIndex}-{startIndex + numItemsLoaded - 1} / {count}
|
||||
</Box>
|
||||
<IconButton
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
onClick={nextPage}
|
||||
isDisabled={page === lastPage}
|
||||
aria-label="Next page"
|
||||
icon={<BsChevronRight />}
|
||||
/>
|
||||
<IconButton
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
onClick={goToLastPage}
|
||||
isDisabled={page === lastPage}
|
||||
aria-label="Go to last page"
|
||||
icon={<BsChevronDoubleRight />}
|
||||
/>
|
||||
</HStack>
|
||||
);
|
||||
};
|
||||
|
||||
export default Paginator;
|
||||
59
app/src/components/RefinePromptModal/CompareFunctions.tsx
Normal file
59
app/src/components/RefinePromptModal/CompareFunctions.tsx
Normal file
@@ -0,0 +1,59 @@
|
||||
import { type StackProps, VStack, useBreakpointValue } from "@chakra-ui/react";
|
||||
import React from "react";
|
||||
import DiffViewer, { DiffMethod } from "react-diff-viewer";
|
||||
import Prism from "prismjs";
|
||||
import "prismjs/components/prism-javascript";
|
||||
import "prismjs/themes/prism.css"; // choose a theme you like
|
||||
|
||||
const highlightSyntax = (str: string) => {
|
||||
let highlighted;
|
||||
try {
|
||||
highlighted = Prism.highlight(str, Prism.languages.javascript as Prism.Grammar, "javascript");
|
||||
} catch (e) {
|
||||
console.error("Error highlighting:", e);
|
||||
highlighted = str;
|
||||
}
|
||||
return <pre style={{ display: "inline" }} dangerouslySetInnerHTML={{ __html: highlighted }} />;
|
||||
};
|
||||
|
||||
const CompareFunctions = ({
|
||||
originalFunction,
|
||||
newFunction = "",
|
||||
leftTitle = "Original",
|
||||
rightTitle = "Modified",
|
||||
...props
|
||||
}: {
|
||||
originalFunction: string;
|
||||
newFunction?: string;
|
||||
leftTitle?: string;
|
||||
rightTitle?: string;
|
||||
} & StackProps) => {
|
||||
const showSplitView = useBreakpointValue(
|
||||
{
|
||||
base: false,
|
||||
md: true,
|
||||
},
|
||||
{
|
||||
fallback: "base",
|
||||
},
|
||||
);
|
||||
|
||||
return (
|
||||
<VStack w="full" spacing={4} fontSize={12} lineHeight={1} overflowY="auto" {...props}>
|
||||
<DiffViewer
|
||||
oldValue={originalFunction}
|
||||
newValue={newFunction || originalFunction}
|
||||
splitView={showSplitView}
|
||||
hideLineNumbers={!showSplitView}
|
||||
leftTitle={leftTitle}
|
||||
rightTitle={rightTitle}
|
||||
disableWordDiff={true}
|
||||
compareMethod={DiffMethod.CHARS}
|
||||
renderContent={highlightSyntax}
|
||||
showDiffOnly={false}
|
||||
/>
|
||||
</VStack>
|
||||
);
|
||||
};
|
||||
|
||||
export default CompareFunctions;
|
||||
65
app/src/components/RefinePromptModal/RefineAction.tsx
Normal file
65
app/src/components/RefinePromptModal/RefineAction.tsx
Normal file
@@ -0,0 +1,65 @@
|
||||
import { HStack, Icon, Heading, Text, VStack, GridItem } from "@chakra-ui/react";
|
||||
import { type IconType } from "react-icons";
|
||||
import { BsStars } from "react-icons/bs";
|
||||
|
||||
export const RefineAction = ({
|
||||
label,
|
||||
icon,
|
||||
desciption,
|
||||
activeLabel,
|
||||
onClick,
|
||||
loading,
|
||||
}: {
|
||||
label: string;
|
||||
icon?: IconType;
|
||||
desciption: string;
|
||||
activeLabel: string | undefined;
|
||||
onClick: (label: string) => void;
|
||||
loading: boolean;
|
||||
}) => {
|
||||
const isActive = activeLabel === label;
|
||||
|
||||
return (
|
||||
<GridItem w="80" h="44">
|
||||
<VStack
|
||||
w="full"
|
||||
h="full"
|
||||
onClick={() => {
|
||||
!loading && onClick(label);
|
||||
}}
|
||||
borderColor={isActive ? "blue.500" : "gray.200"}
|
||||
borderWidth={2}
|
||||
borderRadius={16}
|
||||
padding={6}
|
||||
backgroundColor="gray.50"
|
||||
_hover={
|
||||
loading
|
||||
? undefined
|
||||
: {
|
||||
backgroundColor: "gray.100",
|
||||
}
|
||||
}
|
||||
spacing={8}
|
||||
boxShadow="0 0 40px 4px rgba(0, 0, 0, 0.1);"
|
||||
cursor="pointer"
|
||||
opacity={loading ? 0.5 : 1}
|
||||
>
|
||||
<HStack cursor="pointer" spacing={6} fontSize="sm" fontWeight="medium" color="gray.500">
|
||||
<Icon as={icon || BsStars} boxSize={12} />
|
||||
<Heading size="md" fontFamily="inconsolata, monospace">
|
||||
{label}
|
||||
</Heading>
|
||||
</HStack>
|
||||
<Text
|
||||
fontSize="sm"
|
||||
color="gray.500"
|
||||
flexWrap="wrap"
|
||||
wordBreak="break-word"
|
||||
overflowWrap="break-word"
|
||||
>
|
||||
{desciption}
|
||||
</Text>
|
||||
</VStack>
|
||||
</GridItem>
|
||||
);
|
||||
};
|
||||
151
app/src/components/RefinePromptModal/RefinePromptModal.tsx
Normal file
151
app/src/components/RefinePromptModal/RefinePromptModal.tsx
Normal file
@@ -0,0 +1,151 @@
|
||||
import {
|
||||
Button,
|
||||
Modal,
|
||||
ModalBody,
|
||||
ModalCloseButton,
|
||||
ModalContent,
|
||||
ModalFooter,
|
||||
ModalHeader,
|
||||
ModalOverlay,
|
||||
VStack,
|
||||
Text,
|
||||
Spinner,
|
||||
HStack,
|
||||
Icon,
|
||||
SimpleGrid,
|
||||
} from "@chakra-ui/react";
|
||||
import { BsStars } from "react-icons/bs";
|
||||
import { api } from "~/utils/api";
|
||||
import { useHandledAsyncCallback, useVisibleScenarioIds } from "~/utils/hooks";
|
||||
import { type PromptVariant } from "@prisma/client";
|
||||
import { useState } from "react";
|
||||
import CompareFunctions from "./CompareFunctions";
|
||||
import { CustomInstructionsInput } from "../CustomInstructionsInput";
|
||||
import { RefineAction } from "./RefineAction";
|
||||
import { isObject, isString } from "lodash-es";
|
||||
import { type RefinementAction, type SupportedProvider } from "~/modelProviders/types";
|
||||
import frontendModelProviders from "~/modelProviders/frontendModelProviders";
|
||||
|
||||
export const RefinePromptModal = ({
|
||||
variant,
|
||||
onClose,
|
||||
}: {
|
||||
variant: PromptVariant;
|
||||
onClose: () => void;
|
||||
}) => {
|
||||
const utils = api.useContext();
|
||||
const visibleScenarios = useVisibleScenarioIds();
|
||||
|
||||
const refinementActions =
|
||||
frontendModelProviders[variant.modelProvider as SupportedProvider].refinementActions || {};
|
||||
|
||||
const { mutateAsync: getModifiedPromptMutateAsync, data: refinedPromptFn } =
|
||||
api.promptVariants.getModifiedPromptFn.useMutation();
|
||||
const [instructions, setInstructions] = useState<string>("");
|
||||
|
||||
const [activeRefineActionLabel, setActiveRefineActionLabel] = useState<string | undefined>(
|
||||
undefined,
|
||||
);
|
||||
|
||||
const [getModifiedPromptFn, modificationInProgress] = useHandledAsyncCallback(
|
||||
async (label?: string) => {
|
||||
if (!variant.experimentId) return;
|
||||
const updatedInstructions = label
|
||||
? (refinementActions[label] as RefinementAction).instructions
|
||||
: instructions;
|
||||
setActiveRefineActionLabel(label);
|
||||
await getModifiedPromptMutateAsync({
|
||||
id: variant.id,
|
||||
instructions: updatedInstructions,
|
||||
});
|
||||
},
|
||||
[getModifiedPromptMutateAsync, onClose, variant, instructions, setActiveRefineActionLabel],
|
||||
);
|
||||
|
||||
const replaceVariantMutation = api.promptVariants.replaceVariant.useMutation();
|
||||
|
||||
const [replaceVariant, replacementInProgress] = useHandledAsyncCallback(async () => {
|
||||
if (
|
||||
!variant.experimentId ||
|
||||
!refinedPromptFn ||
|
||||
(isObject(refinedPromptFn) && "status" in refinedPromptFn)
|
||||
)
|
||||
return;
|
||||
await replaceVariantMutation.mutateAsync({
|
||||
id: variant.id,
|
||||
promptConstructor: refinedPromptFn,
|
||||
streamScenarios: visibleScenarios,
|
||||
});
|
||||
await utils.promptVariants.list.invalidate();
|
||||
onClose();
|
||||
}, [replaceVariantMutation, variant, onClose, refinedPromptFn]);
|
||||
|
||||
return (
|
||||
<Modal
|
||||
isOpen
|
||||
onClose={onClose}
|
||||
size={{ base: "xl", sm: "2xl", md: "3xl", lg: "5xl", xl: "7xl" }}
|
||||
>
|
||||
<ModalOverlay />
|
||||
<ModalContent w={1200}>
|
||||
<ModalHeader>
|
||||
<HStack>
|
||||
<Icon as={BsStars} />
|
||||
<Text>Refine with GPT-4</Text>
|
||||
</HStack>
|
||||
</ModalHeader>
|
||||
<ModalCloseButton />
|
||||
<ModalBody maxW="unset">
|
||||
<VStack spacing={8}>
|
||||
<VStack spacing={4} w="full">
|
||||
{Object.keys(refinementActions).length && (
|
||||
<>
|
||||
<SimpleGrid columns={{ base: 1, md: 2 }} spacing={8}>
|
||||
{Object.keys(refinementActions).map((label) => (
|
||||
<RefineAction
|
||||
key={label}
|
||||
label={label}
|
||||
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
||||
icon={refinementActions[label]!.icon}
|
||||
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
||||
desciption={refinementActions[label]!.description}
|
||||
activeLabel={activeRefineActionLabel}
|
||||
onClick={getModifiedPromptFn}
|
||||
loading={modificationInProgress}
|
||||
/>
|
||||
))}
|
||||
</SimpleGrid>
|
||||
<Text color="gray.500">or</Text>
|
||||
</>
|
||||
)}
|
||||
<CustomInstructionsInput
|
||||
instructions={instructions}
|
||||
setInstructions={setInstructions}
|
||||
loading={modificationInProgress}
|
||||
onSubmit={() => getModifiedPromptFn()}
|
||||
/>
|
||||
</VStack>
|
||||
<CompareFunctions
|
||||
originalFunction={variant.promptConstructor}
|
||||
newFunction={isString(refinedPromptFn) ? refinedPromptFn : undefined}
|
||||
maxH="40vh"
|
||||
/>
|
||||
</VStack>
|
||||
</ModalBody>
|
||||
|
||||
<ModalFooter>
|
||||
<HStack spacing={4}>
|
||||
<Button
|
||||
colorScheme="blue"
|
||||
onClick={replaceVariant}
|
||||
minW={24}
|
||||
isDisabled={replacementInProgress || !refinedPromptFn}
|
||||
>
|
||||
{replacementInProgress ? <Spinner boxSize={4} /> : <Text>Accept</Text>}
|
||||
</Button>
|
||||
</HStack>
|
||||
</ModalFooter>
|
||||
</ModalContent>
|
||||
</Modal>
|
||||
);
|
||||
};
|
||||
141
app/src/components/VariantHeader/VariantHeader.tsx
Normal file
141
app/src/components/VariantHeader/VariantHeader.tsx
Normal file
@@ -0,0 +1,141 @@
|
||||
import { useState, type DragEvent } from "react";
|
||||
import { type PromptVariant } from "../OutputsTable/types";
|
||||
import { api } from "~/utils/api";
|
||||
import { RiDraggable } from "react-icons/ri";
|
||||
import { useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
|
||||
import { HStack, Icon, Text, GridItem, type GridItemProps } from "@chakra-ui/react"; // Changed here
|
||||
import { cellPadding, headerMinHeight } from "../constants";
|
||||
import AutoResizeTextArea from "../AutoResizeTextArea";
|
||||
import VariantHeaderMenuButton from "./VariantHeaderMenuButton";
|
||||
|
||||
export default function VariantHeader(
|
||||
allProps: {
|
||||
variant: PromptVariant;
|
||||
canHide: boolean;
|
||||
} & GridItemProps,
|
||||
) {
|
||||
const { variant, canHide, ...gridItemProps } = allProps;
|
||||
const { canModify } = useExperimentAccess();
|
||||
const utils = api.useContext();
|
||||
const [isDragTarget, setIsDragTarget] = useState(false);
|
||||
const [isInputHovered, setIsInputHovered] = useState(false);
|
||||
const [label, setLabel] = useState(variant.label);
|
||||
|
||||
const updateMutation = api.promptVariants.update.useMutation();
|
||||
const [onSaveLabel] = useHandledAsyncCallback(async () => {
|
||||
if (label && label !== variant.label) {
|
||||
await updateMutation.mutateAsync({
|
||||
id: variant.id,
|
||||
updates: { label: label },
|
||||
});
|
||||
}
|
||||
}, [updateMutation, variant.id, variant.label, label]);
|
||||
|
||||
const reorderMutation = api.promptVariants.reorder.useMutation();
|
||||
const [onReorder] = useHandledAsyncCallback(
|
||||
async (e: DragEvent<HTMLDivElement>) => {
|
||||
e.preventDefault();
|
||||
setIsDragTarget(false);
|
||||
const draggedId = e.dataTransfer.getData("text/plain");
|
||||
const droppedId = variant.id;
|
||||
if (!draggedId || !droppedId || draggedId === droppedId) return;
|
||||
await reorderMutation.mutateAsync({
|
||||
draggedId,
|
||||
droppedId,
|
||||
});
|
||||
await utils.promptVariants.list.invalidate();
|
||||
},
|
||||
[reorderMutation, variant.id],
|
||||
);
|
||||
|
||||
const [menuOpen, setMenuOpen] = useState(false);
|
||||
|
||||
if (!canModify) {
|
||||
return (
|
||||
<GridItem
|
||||
padding={0}
|
||||
sx={{
|
||||
position: "sticky",
|
||||
top: "0",
|
||||
// Ensure that the menu always appears above the sticky header of other variants
|
||||
zIndex: menuOpen ? "dropdown" : 10,
|
||||
}}
|
||||
borderTopWidth={1}
|
||||
{...gridItemProps}
|
||||
>
|
||||
<Text fontSize={16} fontWeight="bold" px={cellPadding.x} py={cellPadding.y}>
|
||||
{variant.label}
|
||||
</Text>
|
||||
</GridItem>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<GridItem
|
||||
padding={0}
|
||||
sx={{
|
||||
position: "sticky",
|
||||
top: "0",
|
||||
// Ensure that the menu always appears above the sticky header of other variants
|
||||
zIndex: menuOpen ? "dropdown" : 10,
|
||||
}}
|
||||
borderTopWidth={1}
|
||||
{...gridItemProps}
|
||||
>
|
||||
<HStack
|
||||
spacing={2}
|
||||
alignItems="flex-start"
|
||||
minH={headerMinHeight}
|
||||
draggable={!isInputHovered}
|
||||
onDragStart={(e) => {
|
||||
e.dataTransfer.setData("text/plain", variant.id);
|
||||
e.currentTarget.style.opacity = "0.4";
|
||||
}}
|
||||
onDragEnd={(e) => {
|
||||
e.currentTarget.style.opacity = "1";
|
||||
}}
|
||||
onDragOver={(e) => {
|
||||
e.preventDefault();
|
||||
setIsDragTarget(true);
|
||||
}}
|
||||
onDragLeave={() => {
|
||||
setIsDragTarget(false);
|
||||
}}
|
||||
onDrop={onReorder}
|
||||
backgroundColor={isDragTarget ? "gray.200" : "gray.100"}
|
||||
h="full"
|
||||
>
|
||||
<Icon
|
||||
as={RiDraggable}
|
||||
boxSize={6}
|
||||
mt={2}
|
||||
color="gray.400"
|
||||
_hover={{ color: "gray.800", cursor: "pointer" }}
|
||||
/>
|
||||
<AutoResizeTextArea
|
||||
size="sm"
|
||||
value={label}
|
||||
onChange={(e) => setLabel(e.target.value)}
|
||||
onBlur={onSaveLabel}
|
||||
placeholder="Variant Name"
|
||||
borderWidth={1}
|
||||
borderColor="transparent"
|
||||
fontWeight="bold"
|
||||
fontSize={16}
|
||||
_hover={{ borderColor: "gray.300" }}
|
||||
_focus={{ borderColor: "blue.500", outline: "none" }}
|
||||
flex={1}
|
||||
px={cellPadding.x}
|
||||
onMouseEnter={() => setIsInputHovered(true)}
|
||||
onMouseLeave={() => setIsInputHovered(false)}
|
||||
/>
|
||||
<VariantHeaderMenuButton
|
||||
variant={variant}
|
||||
canHide={canHide}
|
||||
menuOpen={menuOpen}
|
||||
setMenuOpen={setMenuOpen}
|
||||
/>
|
||||
</HStack>
|
||||
</GridItem>
|
||||
);
|
||||
}
|
||||
107
app/src/components/VariantHeader/VariantHeaderMenuButton.tsx
Normal file
107
app/src/components/VariantHeader/VariantHeaderMenuButton.tsx
Normal file
@@ -0,0 +1,107 @@
|
||||
import { type PromptVariant } from "../OutputsTable/types";
|
||||
import { api } from "~/utils/api";
|
||||
import { useHandledAsyncCallback, useVisibleScenarioIds } from "~/utils/hooks";
|
||||
import {
|
||||
Icon,
|
||||
Menu,
|
||||
MenuButton,
|
||||
MenuItem,
|
||||
MenuList,
|
||||
MenuDivider,
|
||||
Text,
|
||||
Spinner,
|
||||
IconButton,
|
||||
} from "@chakra-ui/react";
|
||||
import { BsFillTrashFill, BsGear, BsStars } from "react-icons/bs";
|
||||
import { FaRegClone } from "react-icons/fa";
|
||||
import { useState } from "react";
|
||||
import { RefinePromptModal } from "../RefinePromptModal/RefinePromptModal";
|
||||
import { RiExchangeFundsFill } from "react-icons/ri";
|
||||
import { ChangeModelModal } from "../ChangeModelModal/ChangeModelModal";
|
||||
|
||||
export default function VariantHeaderMenuButton({
|
||||
variant,
|
||||
canHide,
|
||||
menuOpen,
|
||||
setMenuOpen,
|
||||
}: {
|
||||
variant: PromptVariant;
|
||||
canHide: boolean;
|
||||
menuOpen: boolean;
|
||||
setMenuOpen: (open: boolean) => void;
|
||||
}) {
|
||||
const utils = api.useContext();
|
||||
|
||||
const duplicateMutation = api.promptVariants.create.useMutation();
|
||||
const visibleScenarios = useVisibleScenarioIds();
|
||||
|
||||
const [duplicateVariant, duplicationInProgress] = useHandledAsyncCallback(async () => {
|
||||
await duplicateMutation.mutateAsync({
|
||||
experimentId: variant.experimentId,
|
||||
variantId: variant.id,
|
||||
streamScenarios: visibleScenarios,
|
||||
});
|
||||
await utils.promptVariants.list.invalidate();
|
||||
}, [duplicateMutation, variant.experimentId, variant.id]);
|
||||
|
||||
const hideMutation = api.promptVariants.hide.useMutation();
|
||||
const [onHide] = useHandledAsyncCallback(async () => {
|
||||
await hideMutation.mutateAsync({
|
||||
id: variant.id,
|
||||
});
|
||||
await utils.promptVariants.list.invalidate();
|
||||
}, [hideMutation, variant.id]);
|
||||
|
||||
const [changeModelModalOpen, setChangeModelModalOpen] = useState(false);
|
||||
const [refinePromptModalOpen, setRefinePromptModalOpen] = useState(false);
|
||||
|
||||
return (
|
||||
<>
|
||||
<Menu isOpen={menuOpen} onOpen={() => setMenuOpen(true)} onClose={() => setMenuOpen(false)}>
|
||||
<MenuButton
|
||||
as={IconButton}
|
||||
variant="ghost"
|
||||
aria-label="Edit Scenarios"
|
||||
icon={<Icon as={duplicationInProgress ? Spinner : BsGear} />}
|
||||
/>
|
||||
|
||||
<MenuList mt={-3} fontSize="md">
|
||||
<MenuItem icon={<Icon as={FaRegClone} boxSize={4} w={5} />} onClick={duplicateVariant}>
|
||||
Duplicate
|
||||
</MenuItem>
|
||||
<MenuItem
|
||||
icon={<Icon as={RiExchangeFundsFill} boxSize={5} />}
|
||||
onClick={() => setChangeModelModalOpen(true)}
|
||||
>
|
||||
Change Model
|
||||
</MenuItem>
|
||||
<MenuItem
|
||||
icon={<Icon as={BsStars} boxSize={5} />}
|
||||
onClick={() => setRefinePromptModalOpen(true)}
|
||||
>
|
||||
Refine
|
||||
</MenuItem>
|
||||
{canHide && (
|
||||
<>
|
||||
<MenuDivider />
|
||||
<MenuItem
|
||||
onClick={onHide}
|
||||
icon={<Icon as={BsFillTrashFill} boxSize={5} />}
|
||||
color="red.600"
|
||||
_hover={{ backgroundColor: "red.50" }}
|
||||
>
|
||||
<Text>Hide</Text>
|
||||
</MenuItem>
|
||||
</>
|
||||
)}
|
||||
</MenuList>
|
||||
</Menu>
|
||||
{changeModelModalOpen && (
|
||||
<ChangeModelModal variant={variant} onClose={() => setChangeModelModalOpen(false)} />
|
||||
)}
|
||||
{refinePromptModalOpen && (
|
||||
<RefinePromptModal variant={variant} onClose={() => setRefinePromptModalOpen(false)} />
|
||||
)}
|
||||
</>
|
||||
);
|
||||
}
|
||||
6
app/src/components/constants.ts
Normal file
6
app/src/components/constants.ts
Normal file
@@ -0,0 +1,6 @@
|
||||
export const cellPadding = {
|
||||
x: 4,
|
||||
y: 2,
|
||||
};
|
||||
|
||||
export const headerMinHeight = 8;
|
||||
110
app/src/components/datasets/DatasetCard.tsx
Normal file
110
app/src/components/datasets/DatasetCard.tsx
Normal file
@@ -0,0 +1,110 @@
|
||||
import {
|
||||
HStack,
|
||||
Icon,
|
||||
VStack,
|
||||
Text,
|
||||
Divider,
|
||||
Spinner,
|
||||
AspectRatio,
|
||||
SkeletonText,
|
||||
} from "@chakra-ui/react";
|
||||
import { RiDatabase2Line } from "react-icons/ri";
|
||||
import { formatTimePast } from "~/utils/dayjs";
|
||||
import Link from "next/link";
|
||||
import { useRouter } from "next/router";
|
||||
import { BsPlusSquare } from "react-icons/bs";
|
||||
import { api } from "~/utils/api";
|
||||
import { useHandledAsyncCallback } from "~/utils/hooks";
|
||||
|
||||
type DatasetData = {
|
||||
name: string;
|
||||
numEntries: number;
|
||||
id: string;
|
||||
createdAt: Date;
|
||||
updatedAt: Date;
|
||||
};
|
||||
|
||||
export const DatasetCard = ({ dataset }: { dataset: DatasetData }) => {
|
||||
return (
|
||||
<AspectRatio ratio={1.2} w="full">
|
||||
<VStack
|
||||
as={Link}
|
||||
href={{ pathname: "/data/[id]", query: { id: dataset.id } }}
|
||||
bg="gray.50"
|
||||
_hover={{ bg: "gray.100" }}
|
||||
transition="background 0.2s"
|
||||
cursor="pointer"
|
||||
borderColor="gray.200"
|
||||
borderWidth={1}
|
||||
p={4}
|
||||
justify="space-between"
|
||||
>
|
||||
<HStack w="full" color="gray.700" justify="center">
|
||||
<Icon as={RiDatabase2Line} boxSize={4} />
|
||||
<Text fontWeight="bold">{dataset.name}</Text>
|
||||
</HStack>
|
||||
<HStack h="full" spacing={4} flex={1} align="center">
|
||||
<CountLabel label="Rows" count={dataset.numEntries} />
|
||||
</HStack>
|
||||
<HStack w="full" color="gray.500" fontSize="xs" textAlign="center">
|
||||
<Text flex={1}>Created {formatTimePast(dataset.createdAt)}</Text>
|
||||
<Divider h={4} orientation="vertical" />
|
||||
<Text flex={1}>Updated {formatTimePast(dataset.updatedAt)}</Text>
|
||||
</HStack>
|
||||
</VStack>
|
||||
</AspectRatio>
|
||||
);
|
||||
};
|
||||
|
||||
const CountLabel = ({ label, count }: { label: string; count: number }) => {
|
||||
return (
|
||||
<VStack alignItems="center" flex={1}>
|
||||
<Text color="gray.500" fontWeight="bold">
|
||||
{label}
|
||||
</Text>
|
||||
<Text fontSize="sm" color="gray.500">
|
||||
{count}
|
||||
</Text>
|
||||
</VStack>
|
||||
);
|
||||
};
|
||||
|
||||
export const NewDatasetCard = () => {
|
||||
const router = useRouter();
|
||||
const createMutation = api.datasets.create.useMutation();
|
||||
const [createDataset, isLoading] = useHandledAsyncCallback(async () => {
|
||||
const newDataset = await createMutation.mutateAsync({ label: "New Dataset" });
|
||||
await router.push({ pathname: "/data/[id]", query: { id: newDataset.id } });
|
||||
}, [createMutation, router]);
|
||||
|
||||
return (
|
||||
<AspectRatio ratio={1.2} w="full">
|
||||
<VStack
|
||||
align="center"
|
||||
justify="center"
|
||||
_hover={{ cursor: "pointer", bg: "gray.50" }}
|
||||
transition="background 0.2s"
|
||||
cursor="pointer"
|
||||
borderColor="gray.200"
|
||||
borderWidth={1}
|
||||
p={4}
|
||||
onClick={createDataset}
|
||||
>
|
||||
<Icon as={isLoading ? Spinner : BsPlusSquare} boxSize={8} />
|
||||
<Text display={{ base: "none", md: "block" }} ml={2}>
|
||||
New Dataset
|
||||
</Text>
|
||||
</VStack>
|
||||
</AspectRatio>
|
||||
);
|
||||
};
|
||||
|
||||
export const DatasetCardSkeleton = () => (
|
||||
<AspectRatio ratio={1.2} w="full">
|
||||
<VStack align="center" borderColor="gray.200" borderWidth={1} p={4} bg="gray.50">
|
||||
<SkeletonText noOfLines={1} w="80%" />
|
||||
<SkeletonText noOfLines={2} w="60%" />
|
||||
<SkeletonText noOfLines={1} w="80%" />
|
||||
</VStack>
|
||||
</AspectRatio>
|
||||
);
|
||||
21
app/src/components/datasets/DatasetEntriesPaginator.tsx
Normal file
21
app/src/components/datasets/DatasetEntriesPaginator.tsx
Normal file
@@ -0,0 +1,21 @@
|
||||
import { useDatasetEntries } from "~/utils/hooks";
|
||||
import Paginator from "../Paginator";
|
||||
|
||||
const DatasetEntriesPaginator = () => {
|
||||
const { data } = useDatasetEntries();
|
||||
|
||||
if (!data) return null;
|
||||
|
||||
const { entries, startIndex, lastPage, count } = data;
|
||||
|
||||
return (
|
||||
<Paginator
|
||||
numItemsLoaded={entries.length}
|
||||
startIndex={startIndex}
|
||||
lastPage={lastPage}
|
||||
count={count}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default DatasetEntriesPaginator;
|
||||
31
app/src/components/datasets/DatasetEntriesTable.tsx
Normal file
31
app/src/components/datasets/DatasetEntriesTable.tsx
Normal file
@@ -0,0 +1,31 @@
|
||||
import { type StackProps, VStack, Table, Th, Tr, Thead, Tbody, Text } from "@chakra-ui/react";
|
||||
import { useDatasetEntries } from "~/utils/hooks";
|
||||
import TableRow from "./TableRow";
|
||||
import DatasetEntriesPaginator from "./DatasetEntriesPaginator";
|
||||
|
||||
const DatasetEntriesTable = (props: StackProps) => {
|
||||
const { data } = useDatasetEntries();
|
||||
|
||||
return (
|
||||
<VStack justifyContent="space-between" {...props}>
|
||||
<Table variant="simple" sx={{ "table-layout": "fixed", width: "full" }}>
|
||||
<Thead>
|
||||
<Tr>
|
||||
<Th>Input</Th>
|
||||
<Th>Output</Th>
|
||||
</Tr>
|
||||
</Thead>
|
||||
<Tbody>{data?.entries.map((entry) => <TableRow key={entry.id} entry={entry} />)}</Tbody>
|
||||
</Table>
|
||||
{(!data || data.entries.length) === 0 ? (
|
||||
<Text alignSelf="flex-start" pl={6} color="gray.500">
|
||||
No entries found
|
||||
</Text>
|
||||
) : (
|
||||
<DatasetEntriesPaginator />
|
||||
)}
|
||||
</VStack>
|
||||
);
|
||||
};
|
||||
|
||||
export default DatasetEntriesTable;
|
||||
@@ -0,0 +1,26 @@
|
||||
import { Button, HStack, useDisclosure } from "@chakra-ui/react";
|
||||
import { BiImport } from "react-icons/bi";
|
||||
import { BsStars } from "react-icons/bs";
|
||||
|
||||
import { GenerateDataModal } from "./GenerateDataModal";
|
||||
|
||||
export const DatasetHeaderButtons = () => {
|
||||
const generateModalDisclosure = useDisclosure();
|
||||
|
||||
return (
|
||||
<>
|
||||
<HStack>
|
||||
<Button leftIcon={<BiImport />} colorScheme="blue" variant="ghost">
|
||||
Import Data
|
||||
</Button>
|
||||
<Button leftIcon={<BsStars />} colorScheme="blue" onClick={generateModalDisclosure.onOpen}>
|
||||
Generate Data
|
||||
</Button>
|
||||
</HStack>
|
||||
<GenerateDataModal
|
||||
isOpen={generateModalDisclosure.isOpen}
|
||||
onClose={generateModalDisclosure.onClose}
|
||||
/>
|
||||
</>
|
||||
);
|
||||
};
|
||||
@@ -0,0 +1,128 @@
|
||||
import {
|
||||
Modal,
|
||||
ModalBody,
|
||||
ModalCloseButton,
|
||||
ModalContent,
|
||||
ModalHeader,
|
||||
ModalOverlay,
|
||||
ModalFooter,
|
||||
Text,
|
||||
HStack,
|
||||
VStack,
|
||||
Icon,
|
||||
NumberInput,
|
||||
NumberInputField,
|
||||
NumberInputStepper,
|
||||
NumberIncrementStepper,
|
||||
NumberDecrementStepper,
|
||||
Button,
|
||||
} from "@chakra-ui/react";
|
||||
import { BsStars } from "react-icons/bs";
|
||||
import { useState } from "react";
|
||||
import { useDataset, useHandledAsyncCallback } from "~/utils/hooks";
|
||||
import { api } from "~/utils/api";
|
||||
import AutoResizeTextArea from "~/components/AutoResizeTextArea";
|
||||
|
||||
export const GenerateDataModal = ({
|
||||
isOpen,
|
||||
onClose,
|
||||
}: {
|
||||
isOpen: boolean;
|
||||
onClose: () => void;
|
||||
}) => {
|
||||
const utils = api.useContext();
|
||||
|
||||
const datasetId = useDataset().data?.id;
|
||||
|
||||
const [numToGenerate, setNumToGenerate] = useState<number>(20);
|
||||
const [inputDescription, setInputDescription] = useState<string>(
|
||||
"Each input should contain an email body. Half of the emails should contain event details, and the other half should not.",
|
||||
);
|
||||
const [outputDescription, setOutputDescription] = useState<string>(
|
||||
`Each output should contain "true" or "false", where "true" indicates that the email contains event details.`,
|
||||
);
|
||||
|
||||
const generateEntriesMutation = api.datasetEntries.autogenerateEntries.useMutation();
|
||||
|
||||
const [generateEntries, generateEntriesInProgress] = useHandledAsyncCallback(async () => {
|
||||
if (!inputDescription || !outputDescription || !numToGenerate || !datasetId) return;
|
||||
await generateEntriesMutation.mutateAsync({
|
||||
datasetId,
|
||||
inputDescription,
|
||||
outputDescription,
|
||||
numToGenerate,
|
||||
});
|
||||
await utils.datasetEntries.list.invalidate();
|
||||
onClose();
|
||||
}, [
|
||||
generateEntriesMutation,
|
||||
onClose,
|
||||
inputDescription,
|
||||
outputDescription,
|
||||
numToGenerate,
|
||||
datasetId,
|
||||
]);
|
||||
|
||||
return (
|
||||
<Modal isOpen={isOpen} onClose={onClose} size={{ base: "xl", sm: "2xl", md: "3xl" }}>
|
||||
<ModalOverlay />
|
||||
<ModalContent w={1200}>
|
||||
<ModalHeader>
|
||||
<HStack>
|
||||
<Icon as={BsStars} />
|
||||
<Text>Generate Data</Text>
|
||||
</HStack>
|
||||
</ModalHeader>
|
||||
<ModalCloseButton />
|
||||
<ModalBody maxW="unset">
|
||||
<VStack w="full" spacing={8} padding={8} alignItems="flex-start">
|
||||
<VStack alignItems="flex-start" spacing={2}>
|
||||
<Text fontWeight="bold">Number of Rows:</Text>
|
||||
<NumberInput
|
||||
step={5}
|
||||
defaultValue={15}
|
||||
min={0}
|
||||
max={100}
|
||||
onChange={(valueString) => setNumToGenerate(parseInt(valueString) || 0)}
|
||||
value={numToGenerate}
|
||||
w="24"
|
||||
>
|
||||
<NumberInputField />
|
||||
<NumberInputStepper>
|
||||
<NumberIncrementStepper />
|
||||
<NumberDecrementStepper />
|
||||
</NumberInputStepper>
|
||||
</NumberInput>
|
||||
</VStack>
|
||||
<VStack alignItems="flex-start" w="full" spacing={2}>
|
||||
<Text fontWeight="bold">Input Description:</Text>
|
||||
<AutoResizeTextArea
|
||||
value={inputDescription}
|
||||
onChange={(e) => setInputDescription(e.target.value)}
|
||||
placeholder="Each input should contain..."
|
||||
/>
|
||||
</VStack>
|
||||
<VStack alignItems="flex-start" w="full" spacing={2}>
|
||||
<Text fontWeight="bold">Output Description (optional):</Text>
|
||||
<AutoResizeTextArea
|
||||
value={outputDescription}
|
||||
onChange={(e) => setOutputDescription(e.target.value)}
|
||||
placeholder="The output should contain..."
|
||||
/>
|
||||
</VStack>
|
||||
</VStack>
|
||||
</ModalBody>
|
||||
<ModalFooter>
|
||||
<Button
|
||||
colorScheme="blue"
|
||||
isLoading={generateEntriesInProgress}
|
||||
isDisabled={!numToGenerate || !inputDescription || !outputDescription}
|
||||
onClick={generateEntries}
|
||||
>
|
||||
Generate
|
||||
</Button>
|
||||
</ModalFooter>
|
||||
</ModalContent>
|
||||
</Modal>
|
||||
);
|
||||
};
|
||||
13
app/src/components/datasets/TableRow.tsx
Normal file
13
app/src/components/datasets/TableRow.tsx
Normal file
@@ -0,0 +1,13 @@
|
||||
import { Td, Tr } from "@chakra-ui/react";
|
||||
import { type DatasetEntry } from "@prisma/client";
|
||||
|
||||
const TableRow = ({ entry }: { entry: DatasetEntry }) => {
|
||||
return (
|
||||
<Tr key={entry.id}>
|
||||
<Td>{entry.input}</Td>
|
||||
<Td>{entry.output}</Td>
|
||||
</Tr>
|
||||
);
|
||||
};
|
||||
|
||||
export default TableRow;
|
||||
114
app/src/components/experiments/ExperimentCard.tsx
Normal file
114
app/src/components/experiments/ExperimentCard.tsx
Normal file
@@ -0,0 +1,114 @@
|
||||
import {
|
||||
HStack,
|
||||
Icon,
|
||||
VStack,
|
||||
Text,
|
||||
Divider,
|
||||
Spinner,
|
||||
AspectRatio,
|
||||
SkeletonText,
|
||||
} from "@chakra-ui/react";
|
||||
import { RiFlaskLine } from "react-icons/ri";
|
||||
import { formatTimePast } from "~/utils/dayjs";
|
||||
import Link from "next/link";
|
||||
import { useRouter } from "next/router";
|
||||
import { BsPlusSquare } from "react-icons/bs";
|
||||
import { api } from "~/utils/api";
|
||||
import { useHandledAsyncCallback } from "~/utils/hooks";
|
||||
|
||||
type ExperimentData = {
|
||||
testScenarioCount: number;
|
||||
promptVariantCount: number;
|
||||
id: string;
|
||||
label: string;
|
||||
sortIndex: number;
|
||||
createdAt: Date;
|
||||
updatedAt: Date;
|
||||
};
|
||||
|
||||
export const ExperimentCard = ({ exp }: { exp: ExperimentData }) => {
|
||||
return (
|
||||
<AspectRatio ratio={1.2} w="full">
|
||||
<VStack
|
||||
as={Link}
|
||||
href={{ pathname: "/experiments/[id]", query: { id: exp.id } }}
|
||||
bg="gray.50"
|
||||
_hover={{ bg: "gray.100" }}
|
||||
transition="background 0.2s"
|
||||
cursor="pointer"
|
||||
borderColor="gray.200"
|
||||
borderWidth={1}
|
||||
p={4}
|
||||
justify="space-between"
|
||||
>
|
||||
<HStack w="full" color="gray.700" justify="center">
|
||||
<Icon as={RiFlaskLine} boxSize={4} />
|
||||
<Text fontWeight="bold">{exp.label}</Text>
|
||||
</HStack>
|
||||
<HStack h="full" spacing={4} flex={1} align="center">
|
||||
<CountLabel label="Variants" count={exp.promptVariantCount} />
|
||||
<Divider h={12} orientation="vertical" />
|
||||
<CountLabel label="Scenarios" count={exp.testScenarioCount} />
|
||||
</HStack>
|
||||
<HStack w="full" color="gray.500" fontSize="xs" textAlign="center">
|
||||
<Text flex={1}>Created {formatTimePast(exp.createdAt)}</Text>
|
||||
<Divider h={4} orientation="vertical" />
|
||||
<Text flex={1}>Updated {formatTimePast(exp.updatedAt)}</Text>
|
||||
</HStack>
|
||||
</VStack>
|
||||
</AspectRatio>
|
||||
);
|
||||
};
|
||||
|
||||
const CountLabel = ({ label, count }: { label: string; count: number }) => {
|
||||
return (
|
||||
<VStack alignItems="center" flex={1}>
|
||||
<Text color="gray.500" fontWeight="bold">
|
||||
{label}
|
||||
</Text>
|
||||
<Text fontSize="sm" color="gray.500">
|
||||
{count}
|
||||
</Text>
|
||||
</VStack>
|
||||
);
|
||||
};
|
||||
|
||||
export const NewExperimentCard = () => {
|
||||
const router = useRouter();
|
||||
const createMutation = api.experiments.create.useMutation();
|
||||
const [createExperiment, isLoading] = useHandledAsyncCallback(async () => {
|
||||
const newExperiment = await createMutation.mutateAsync({ label: "New Experiment" });
|
||||
await router.push({ pathname: "/experiments/[id]", query: { id: newExperiment.id } });
|
||||
}, [createMutation, router]);
|
||||
|
||||
return (
|
||||
<AspectRatio ratio={1.2} w="full">
|
||||
<VStack
|
||||
align="center"
|
||||
justify="center"
|
||||
_hover={{ cursor: "pointer", bg: "gray.50" }}
|
||||
transition="background 0.2s"
|
||||
cursor="pointer"
|
||||
borderColor="gray.200"
|
||||
borderWidth={1}
|
||||
p={4}
|
||||
onClick={createExperiment}
|
||||
>
|
||||
<Icon as={isLoading ? Spinner : BsPlusSquare} boxSize={8} />
|
||||
<Text display={{ base: "none", md: "block" }} ml={2}>
|
||||
New Experiment
|
||||
</Text>
|
||||
</VStack>
|
||||
</AspectRatio>
|
||||
);
|
||||
};
|
||||
|
||||
export const ExperimentCardSkeleton = () => (
|
||||
<AspectRatio ratio={1.2} w="full">
|
||||
<VStack align="center" borderColor="gray.200" borderWidth={1} p={4} bg="gray.50">
|
||||
<SkeletonText noOfLines={1} w="80%" />
|
||||
<SkeletonText noOfLines={2} w="60%" />
|
||||
<SkeletonText noOfLines={1} w="80%" />
|
||||
</VStack>
|
||||
</AspectRatio>
|
||||
);
|
||||
@@ -0,0 +1,57 @@
|
||||
import {
|
||||
Button,
|
||||
AlertDialog,
|
||||
AlertDialogBody,
|
||||
AlertDialogFooter,
|
||||
AlertDialogHeader,
|
||||
AlertDialogContent,
|
||||
AlertDialogOverlay,
|
||||
} from "@chakra-ui/react";
|
||||
|
||||
import { useRouter } from "next/router";
|
||||
import { useRef } from "react";
|
||||
import { api } from "~/utils/api";
|
||||
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
|
||||
|
||||
export const DeleteDialog = ({ onClose }: { onClose: () => void }) => {
|
||||
const experiment = useExperiment();
|
||||
const deleteMutation = api.experiments.delete.useMutation();
|
||||
const utils = api.useContext();
|
||||
const router = useRouter();
|
||||
|
||||
const cancelRef = useRef<HTMLButtonElement>(null);
|
||||
|
||||
const [onDeleteConfirm] = useHandledAsyncCallback(async () => {
|
||||
if (!experiment.data?.id) return;
|
||||
await deleteMutation.mutateAsync({ id: experiment.data.id });
|
||||
await utils.experiments.list.invalidate();
|
||||
await router.push({ pathname: "/experiments" });
|
||||
onClose();
|
||||
}, [deleteMutation, experiment.data?.id, router]);
|
||||
|
||||
return (
|
||||
<AlertDialog isOpen leastDestructiveRef={cancelRef} onClose={onClose}>
|
||||
<AlertDialogOverlay>
|
||||
<AlertDialogContent>
|
||||
<AlertDialogHeader fontSize="lg" fontWeight="bold">
|
||||
Delete Experiment
|
||||
</AlertDialogHeader>
|
||||
|
||||
<AlertDialogBody>
|
||||
If you delete this experiment all the associated prompts and scenarios will be deleted
|
||||
as well. Are you sure?
|
||||
</AlertDialogBody>
|
||||
|
||||
<AlertDialogFooter>
|
||||
<Button ref={cancelRef} onClick={onClose}>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button colorScheme="red" onClick={onDeleteConfirm} ml={3}>
|
||||
Delete
|
||||
</Button>
|
||||
</AlertDialogFooter>
|
||||
</AlertDialogContent>
|
||||
</AlertDialogOverlay>
|
||||
</AlertDialog>
|
||||
);
|
||||
};
|
||||
@@ -0,0 +1,42 @@
|
||||
import { Button, HStack, Icon, Spinner, Text } from "@chakra-ui/react";
|
||||
import { useOnForkButtonPressed } from "./useOnForkButtonPressed";
|
||||
import { useExperiment } from "~/utils/hooks";
|
||||
import { BsGearFill } from "react-icons/bs";
|
||||
import { TbGitFork } from "react-icons/tb";
|
||||
import { useAppStore } from "~/state/store";
|
||||
|
||||
export const ExperimentHeaderButtons = () => {
|
||||
const experiment = useExperiment();
|
||||
|
||||
const canModify = experiment.data?.access.canModify ?? false;
|
||||
|
||||
const { onForkButtonPressed, isForking } = useOnForkButtonPressed();
|
||||
|
||||
const openDrawer = useAppStore((s) => s.openDrawer);
|
||||
|
||||
if (experiment.isLoading) return null;
|
||||
|
||||
return (
|
||||
<HStack spacing={0} mt={{ base: 2, md: 0 }}>
|
||||
<Button
|
||||
onClick={onForkButtonPressed}
|
||||
mr={4}
|
||||
colorScheme={canModify ? undefined : "orange"}
|
||||
bgColor={canModify ? undefined : "orange.400"}
|
||||
minW={0}
|
||||
variant={{ base: "solid", md: canModify ? "ghost" : "solid" }}
|
||||
>
|
||||
{isForking ? <Spinner boxSize={5} /> : <Icon as={TbGitFork} boxSize={5} />}
|
||||
<Text ml={2}>Fork</Text>
|
||||
</Button>
|
||||
{canModify && (
|
||||
<Button variant={{ base: "solid", md: "ghost" }} onClick={openDrawer}>
|
||||
<HStack>
|
||||
<Icon as={BsGearFill} />
|
||||
<Text>Settings</Text>
|
||||
</HStack>
|
||||
</Button>
|
||||
)}
|
||||
</HStack>
|
||||
);
|
||||
};
|
||||
@@ -0,0 +1,30 @@
|
||||
import { useCallback } from "react";
|
||||
import { api } from "~/utils/api";
|
||||
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
|
||||
import { signIn, useSession } from "next-auth/react";
|
||||
import { useRouter } from "next/router";
|
||||
|
||||
export const useOnForkButtonPressed = () => {
|
||||
const router = useRouter();
|
||||
|
||||
const user = useSession().data;
|
||||
const experiment = useExperiment();
|
||||
|
||||
const forkMutation = api.experiments.fork.useMutation();
|
||||
|
||||
const [onFork, isForking] = useHandledAsyncCallback(async () => {
|
||||
if (!experiment.data?.id) return;
|
||||
const forkedExperimentId = await forkMutation.mutateAsync({ id: experiment.data.id });
|
||||
await router.push({ pathname: "/experiments/[id]", query: { id: forkedExperimentId } });
|
||||
}, [forkMutation, experiment.data?.id, router]);
|
||||
|
||||
const onForkButtonPressed = useCallback(() => {
|
||||
if (user === null) {
|
||||
signIn("github").catch(console.error);
|
||||
} else {
|
||||
onFork();
|
||||
}
|
||||
}, [onFork, user]);
|
||||
|
||||
return { onForkButtonPressed, isForking };
|
||||
};
|
||||
151
app/src/components/nav/AppShell.tsx
Normal file
151
app/src/components/nav/AppShell.tsx
Normal file
@@ -0,0 +1,151 @@
|
||||
import { useState, useEffect } from "react";
|
||||
import {
|
||||
Heading,
|
||||
VStack,
|
||||
Icon,
|
||||
HStack,
|
||||
Image,
|
||||
Text,
|
||||
Box,
|
||||
type BoxProps,
|
||||
Link as ChakraLink,
|
||||
Flex,
|
||||
} from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import Link, { type LinkProps } from "next/link";
|
||||
import { BsGithub, BsPersonCircle } from "react-icons/bs";
|
||||
import { useRouter } from "next/router";
|
||||
import { type IconType } from "react-icons";
|
||||
import { RiDatabase2Line, RiFlaskLine } from "react-icons/ri";
|
||||
import { signIn, useSession } from "next-auth/react";
|
||||
import UserMenu from "./UserMenu";
|
||||
import { env } from "~/env.mjs";
|
||||
|
||||
type IconLinkProps = BoxProps & LinkProps & { label?: string; icon: IconType; href: string };
|
||||
|
||||
const IconLink = ({ icon, label, href, color, ...props }: IconLinkProps) => {
|
||||
const router = useRouter();
|
||||
const isActive = href && router.pathname.startsWith(href);
|
||||
return (
|
||||
<Link href={href} style={{ width: "100%" }}>
|
||||
<HStack
|
||||
w="full"
|
||||
p={4}
|
||||
color={color}
|
||||
as={ChakraLink}
|
||||
bgColor={isActive ? "gray.200" : "transparent"}
|
||||
_hover={{ bgColor: "gray.300", textDecoration: "none" }}
|
||||
justifyContent="start"
|
||||
cursor="pointer"
|
||||
{...props}
|
||||
>
|
||||
<Icon as={icon} boxSize={6} mr={2} />
|
||||
<Text fontWeight="bold" fontSize="sm">
|
||||
{label}
|
||||
</Text>
|
||||
</HStack>
|
||||
</Link>
|
||||
);
|
||||
};
|
||||
|
||||
const Divider = () => <Box h="1px" bgColor="gray.200" />;
|
||||
|
||||
const NavSidebar = () => {
|
||||
const user = useSession().data;
|
||||
|
||||
return (
|
||||
<VStack
|
||||
align="stretch"
|
||||
bgColor="gray.100"
|
||||
py={2}
|
||||
pb={0}
|
||||
height="100%"
|
||||
w={{ base: "56px", md: "200px" }}
|
||||
overflow="hidden"
|
||||
>
|
||||
<HStack as={Link} href="/" _hover={{ textDecoration: "none" }} spacing={0} px={4} py={2}>
|
||||
<Image src="/logo.svg" alt="" boxSize={6} mr={4} />
|
||||
<Heading size="md" fontFamily="inconsolata, monospace">
|
||||
OpenPipe
|
||||
</Heading>
|
||||
</HStack>
|
||||
<VStack spacing={0} align="flex-start" overflowY="auto" overflowX="hidden" flex={1}>
|
||||
{user != null && (
|
||||
<>
|
||||
<IconLink icon={RiFlaskLine} label="Experiments" href="/experiments" />
|
||||
{env.NEXT_PUBLIC_SHOW_DATA && (
|
||||
<IconLink icon={RiDatabase2Line} label="Data" href="/data" />
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
{user === null && (
|
||||
<HStack
|
||||
w="full"
|
||||
p={4}
|
||||
as={ChakraLink}
|
||||
_hover={{ bgColor: "gray.300", textDecoration: "none" }}
|
||||
justifyContent="start"
|
||||
cursor="pointer"
|
||||
onClick={() => {
|
||||
signIn("github").catch(console.error);
|
||||
}}
|
||||
>
|
||||
<Icon as={BsPersonCircle} boxSize={6} mr={2} />
|
||||
<Text fontWeight="bold" fontSize="sm">
|
||||
Sign In
|
||||
</Text>
|
||||
</HStack>
|
||||
)}
|
||||
</VStack>
|
||||
{user ? (
|
||||
<UserMenu user={user} borderColor={"gray.200"} borderTopWidth={1} borderBottomWidth={1} />
|
||||
) : (
|
||||
<Divider />
|
||||
)}
|
||||
<VStack spacing={0} align="center">
|
||||
<ChakraLink
|
||||
href="https://github.com/openpipe/openpipe"
|
||||
target="_blank"
|
||||
color="gray.500"
|
||||
_hover={{ color: "gray.800" }}
|
||||
p={2}
|
||||
>
|
||||
<Icon as={BsGithub} boxSize={6} />
|
||||
</ChakraLink>
|
||||
</VStack>
|
||||
</VStack>
|
||||
);
|
||||
};
|
||||
|
||||
export default function AppShell(props: { children: React.ReactNode; title?: string }) {
|
||||
const [vh, setVh] = useState("100vh"); // Default height to prevent flicker on initial render
|
||||
|
||||
useEffect(() => {
|
||||
const setHeight = () => {
|
||||
const vh = window.innerHeight * 0.01;
|
||||
document.documentElement.style.setProperty("--vh", `${vh}px`);
|
||||
setVh(`calc(var(--vh, 1vh) * 100)`);
|
||||
};
|
||||
setHeight(); // Set the height at the start
|
||||
|
||||
window.addEventListener("resize", setHeight);
|
||||
window.addEventListener("orientationchange", setHeight);
|
||||
|
||||
return () => {
|
||||
window.removeEventListener("resize", setHeight);
|
||||
window.removeEventListener("orientationchange", setHeight);
|
||||
};
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<Flex h={vh} w="100vw">
|
||||
<Head>
|
||||
<title>{props.title ? `${props.title} | OpenPipe` : "OpenPipe"}</title>
|
||||
</Head>
|
||||
<NavSidebar />
|
||||
<Box h="100%" flex={1} overflowY="auto">
|
||||
{props.children}
|
||||
</Box>
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
76
app/src/components/nav/UserMenu.tsx
Normal file
76
app/src/components/nav/UserMenu.tsx
Normal file
@@ -0,0 +1,76 @@
|
||||
import {
|
||||
HStack,
|
||||
Icon,
|
||||
Image,
|
||||
VStack,
|
||||
Text,
|
||||
Popover,
|
||||
PopoverTrigger,
|
||||
PopoverContent,
|
||||
Link,
|
||||
useColorMode,
|
||||
type StackProps,
|
||||
} from "@chakra-ui/react";
|
||||
import { type Session } from "next-auth";
|
||||
import { signOut } from "next-auth/react";
|
||||
import { BsBoxArrowRight, BsChevronRight, BsPersonCircle } from "react-icons/bs";
|
||||
|
||||
export default function UserMenu({ user, ...rest }: { user: Session } & StackProps) {
|
||||
const { colorMode } = useColorMode();
|
||||
|
||||
const profileImage = user.user.image ? (
|
||||
<Image src={user.user.image} alt="profile picture" boxSize={8} borderRadius="50%" />
|
||||
) : (
|
||||
<Icon as={BsPersonCircle} boxSize={6} />
|
||||
);
|
||||
|
||||
return (
|
||||
<>
|
||||
<Popover placement="right">
|
||||
<PopoverTrigger>
|
||||
<HStack
|
||||
// Weird values to make mobile look right; can clean up when we make the sidebar disappear on mobile
|
||||
px={3}
|
||||
spacing={3}
|
||||
py={2}
|
||||
{...rest}
|
||||
cursor="pointer"
|
||||
_hover={{
|
||||
bgColor: colorMode === "light" ? "gray.200" : "gray.700",
|
||||
}}
|
||||
>
|
||||
{profileImage}
|
||||
<VStack spacing={0} align="start" flex={1} flexShrink={1}>
|
||||
<Text fontWeight="bold" fontSize="sm">
|
||||
{user.user.name}
|
||||
</Text>
|
||||
<Text color="gray.500" fontSize="xs">
|
||||
{user.user.email}
|
||||
</Text>
|
||||
</VStack>
|
||||
<Icon as={BsChevronRight} boxSize={4} color="gray.500" />
|
||||
</HStack>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent _focusVisible={{ boxShadow: "unset", outline: "unset" }} maxW="200px">
|
||||
<VStack align="stretch" spacing={0}>
|
||||
{/* sign out */}
|
||||
<HStack
|
||||
as={Link}
|
||||
onClick={() => {
|
||||
signOut().catch(console.error);
|
||||
}}
|
||||
px={4}
|
||||
py={2}
|
||||
spacing={4}
|
||||
color="gray.500"
|
||||
fontSize="sm"
|
||||
>
|
||||
<Icon as={BsBoxArrowRight} boxSize={6} />
|
||||
<Text>Sign out</Text>
|
||||
</HStack>
|
||||
</VStack>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
</>
|
||||
);
|
||||
}
|
||||
55
app/src/components/tooltip/CostTooltip.tsx
Normal file
55
app/src/components/tooltip/CostTooltip.tsx
Normal file
@@ -0,0 +1,55 @@
|
||||
import { HStack, Icon, Text, Tooltip, type TooltipProps, VStack, Divider } from "@chakra-ui/react";
|
||||
import { BsCurrencyDollar } from "react-icons/bs";
|
||||
|
||||
type CostTooltipProps = {
|
||||
promptTokens: number | null;
|
||||
completionTokens: number | null;
|
||||
cost: number;
|
||||
} & TooltipProps;
|
||||
|
||||
export const CostTooltip = ({
|
||||
promptTokens,
|
||||
completionTokens,
|
||||
cost,
|
||||
children,
|
||||
...props
|
||||
}: CostTooltipProps) => {
|
||||
return (
|
||||
<Tooltip
|
||||
borderRadius="8"
|
||||
color="gray.800"
|
||||
bgColor="gray.50"
|
||||
borderWidth={1}
|
||||
hasArrow
|
||||
shouldWrapChildren
|
||||
label={
|
||||
<VStack fontSize="sm" w="200" spacing={4}>
|
||||
<VStack spacing={0}>
|
||||
<Text fontWeight="bold">Cost</Text>
|
||||
<HStack spacing={0}>
|
||||
<Icon as={BsCurrencyDollar} />
|
||||
<Text>{cost.toFixed(6)}</Text>
|
||||
</HStack>
|
||||
</VStack>
|
||||
<VStack spacing={1}>
|
||||
<Text fontWeight="bold">Token Usage</Text>
|
||||
<HStack>
|
||||
<VStack w="28" spacing={1}>
|
||||
<Text>Prompt</Text>
|
||||
<Text>{promptTokens ?? 0}</Text>
|
||||
</VStack>
|
||||
<Divider borderColor="gray.200" h={8} orientation="vertical" />
|
||||
<VStack w="28" spacing={1}>
|
||||
<Text whiteSpace="nowrap">Completion</Text>
|
||||
<Text>{completionTokens ?? 0}</Text>
|
||||
</VStack>
|
||||
</HStack>
|
||||
</VStack>
|
||||
</VStack>
|
||||
}
|
||||
{...props}
|
||||
>
|
||||
{children}
|
||||
</Tooltip>
|
||||
);
|
||||
};
|
||||
63
app/src/env.mjs
Normal file
63
app/src/env.mjs
Normal file
@@ -0,0 +1,63 @@
|
||||
import { createEnv } from "@t3-oss/env-nextjs";
|
||||
import { z } from "zod";
|
||||
|
||||
export const env = createEnv({
|
||||
/**
|
||||
* Specify your server-side environment variables schema here. This way you can ensure the app
|
||||
* isn't built with invalid env vars.
|
||||
*/
|
||||
server: {
|
||||
DATABASE_URL: z.string().url(),
|
||||
NODE_ENV: z.enum(["development", "test", "production"]).default("development"),
|
||||
RESTRICT_PRISMA_LOGS: z
|
||||
.string()
|
||||
.optional()
|
||||
.default("false")
|
||||
.transform((val) => val.toLowerCase() === "true"),
|
||||
GITHUB_CLIENT_ID: z.string().min(1),
|
||||
GITHUB_CLIENT_SECRET: z.string().min(1),
|
||||
OPENAI_API_KEY: z.string().min(1),
|
||||
REPLICATE_API_TOKEN: z.string().default("placeholder"),
|
||||
ANTHROPIC_API_KEY: z.string().default("placeholder"),
|
||||
SENTRY_AUTH_TOKEN: z.string().optional(),
|
||||
},
|
||||
|
||||
/**
|
||||
* Specify your client-side environment variables schema here. This way you can ensure the app
|
||||
* isn't built with invalid env vars. To expose them to the client, prefix them with
|
||||
* `NEXT_PUBLIC_`.
|
||||
*/
|
||||
client: {
|
||||
NEXT_PUBLIC_POSTHOG_KEY: z.string().optional(),
|
||||
NEXT_PUBLIC_SOCKET_URL: z.string().url().default("http://localhost:3318"),
|
||||
NEXT_PUBLIC_HOST: z.string().url().default("http://localhost:3000"),
|
||||
NEXT_PUBLIC_SENTRY_DSN: z.string().optional(),
|
||||
NEXT_PUBLIC_SHOW_DATA: z.string().optional(),
|
||||
},
|
||||
|
||||
/**
|
||||
* You can't destruct `process.env` as a regular object in the Next.js edge runtimes (e.g.
|
||||
* middlewares) or client-side so we need to destruct manually.
|
||||
*/
|
||||
runtimeEnv: {
|
||||
DATABASE_URL: process.env.DATABASE_URL,
|
||||
NODE_ENV: process.env.NODE_ENV,
|
||||
OPENAI_API_KEY: process.env.OPENAI_API_KEY,
|
||||
RESTRICT_PRISMA_LOGS: process.env.RESTRICT_PRISMA_LOGS,
|
||||
NEXT_PUBLIC_POSTHOG_KEY: process.env.NEXT_PUBLIC_POSTHOG_KEY,
|
||||
NEXT_PUBLIC_SOCKET_URL: process.env.NEXT_PUBLIC_SOCKET_URL,
|
||||
NEXT_PUBLIC_HOST: process.env.NEXT_PUBLIC_HOST,
|
||||
NEXT_PUBLIC_SHOW_DATA: process.env.NEXT_PUBLIC_SHOW_DATA,
|
||||
GITHUB_CLIENT_ID: process.env.GITHUB_CLIENT_ID,
|
||||
GITHUB_CLIENT_SECRET: process.env.GITHUB_CLIENT_SECRET,
|
||||
REPLICATE_API_TOKEN: process.env.REPLICATE_API_TOKEN,
|
||||
ANTHROPIC_API_KEY: process.env.ANTHROPIC_API_KEY,
|
||||
NEXT_PUBLIC_SENTRY_DSN: process.env.NEXT_PUBLIC_SENTRY_DSN,
|
||||
SENTRY_AUTH_TOKEN: process.env.SENTRY_AUTH_TOKEN,
|
||||
},
|
||||
/**
|
||||
* Run `build` or `dev` with `SKIP_ENV_VALIDATION` to skip env validation.
|
||||
* This is especially useful for Docker builds.
|
||||
*/
|
||||
skipValidation: !!process.env.SKIP_ENV_VALIDATION,
|
||||
});
|
||||
@@ -0,0 +1,69 @@
|
||||
/* eslint-disable @typescript-eslint/no-var-requires */
|
||||
|
||||
import YAML from "yaml";
|
||||
import fs from "fs";
|
||||
import path from "path";
|
||||
import { openapiSchemaToJsonSchema } from "@openapi-contrib/openapi-schema-to-json-schema";
|
||||
import $RefParser from "@apidevtools/json-schema-ref-parser";
|
||||
import { type JSONObject } from "superjson/dist/types";
|
||||
import assert from "assert";
|
||||
import { type JSONSchema4Object } from "json-schema";
|
||||
import { isObject } from "lodash-es";
|
||||
|
||||
// @ts-expect-error for some reason missing from types
|
||||
import parserEstree from "prettier/plugins/estree";
|
||||
import parserBabel from "prettier/plugins/babel";
|
||||
import prettier from "prettier/standalone";
|
||||
|
||||
const OPENAPI_URL =
|
||||
"https://raw.githubusercontent.com/tryAGI/Anthropic/1c0871e861de60a4c3a843cb90e17d63e86c234a/docs/openapi.yaml";
|
||||
|
||||
// Fetch the openapi document
|
||||
const response = await fetch(OPENAPI_URL);
|
||||
const openApiYaml = await response.text();
|
||||
|
||||
// Parse the yaml document
|
||||
let schema = YAML.parse(openApiYaml) as JSONObject;
|
||||
schema = openapiSchemaToJsonSchema(schema);
|
||||
|
||||
const jsonSchema = await $RefParser.dereference(schema);
|
||||
|
||||
assert("components" in jsonSchema);
|
||||
const completionRequestSchema = jsonSchema.components.schemas
|
||||
.CreateCompletionRequest as JSONSchema4Object;
|
||||
|
||||
// We need to do a bit of surgery here since the Monaco editor doesn't like
|
||||
// the fact that the schema says `model` can be either a string or an enum,
|
||||
// and displays a warning in the editor. Let's stick with just an enum for
|
||||
// now and drop the string option.
|
||||
assert(
|
||||
"properties" in completionRequestSchema &&
|
||||
isObject(completionRequestSchema.properties) &&
|
||||
"model" in completionRequestSchema.properties &&
|
||||
isObject(completionRequestSchema.properties.model),
|
||||
);
|
||||
|
||||
const modelProperty = completionRequestSchema.properties.model;
|
||||
assert(
|
||||
"oneOf" in modelProperty &&
|
||||
Array.isArray(modelProperty.oneOf) &&
|
||||
modelProperty.oneOf.length === 2 &&
|
||||
isObject(modelProperty.oneOf[1]) &&
|
||||
"enum" in modelProperty.oneOf[1],
|
||||
"Expected model to have oneOf length of 2",
|
||||
);
|
||||
modelProperty.type = "string";
|
||||
modelProperty.enum = modelProperty.oneOf[1].enum;
|
||||
delete modelProperty["oneOf"];
|
||||
|
||||
// Get the directory of the current script
|
||||
const currentDirectory = path.dirname(import.meta.url).replace("file://", "");
|
||||
|
||||
// Write the JSON schema to a file in the current directory
|
||||
fs.writeFileSync(
|
||||
path.join(currentDirectory, "input.schema.json"),
|
||||
await prettier.format(JSON.stringify(completionRequestSchema, null, 2), {
|
||||
parser: "json",
|
||||
plugins: [parserBabel, parserEstree],
|
||||
}),
|
||||
);
|
||||
@@ -0,0 +1,63 @@
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"model": {
|
||||
"description": "The model that will complete your prompt.",
|
||||
"x-oaiTypeLabel": "string",
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"claude-2",
|
||||
"claude-2.0",
|
||||
"claude-instant-1",
|
||||
"claude-instant-1.1"
|
||||
]
|
||||
},
|
||||
"prompt": {
|
||||
"description": "The prompt that you want Claude to complete.\n\nFor proper response generation you will need to format your prompt as follows:\n\"\\n\\nHuman: all instructions for the assistant\\n\\nAssistant:\". The prompt string should begin with the characters \"Human:\" and end with \"Assistant:\".",
|
||||
"default": "<|endoftext|>",
|
||||
"example": "\\n\\nHuman: What is the correct translation of ${scenario.input}? I would like a long analysis followed by a short answer.\\n\\nAssistant:",
|
||||
"type": "string"
|
||||
},
|
||||
"max_tokens_to_sample": {
|
||||
"type": "integer",
|
||||
"minimum": 1,
|
||||
"default": 256,
|
||||
"nullable": true,
|
||||
"description": "The maximum number of tokens to generate before stopping."
|
||||
},
|
||||
"temperature": {
|
||||
"type": "number",
|
||||
"minimum": 0,
|
||||
"maximum": 1,
|
||||
"nullable": true,
|
||||
"description": "Amount of randomness injected into the response.\n\nDefaults to 1."
|
||||
},
|
||||
"top_p": {
|
||||
"type": "number",
|
||||
"minimum": 0,
|
||||
"maximum": 1,
|
||||
"nullable": true,
|
||||
"description": "Use nucleus sampling.\n\nYou should either alter temperature or top_p, but not both.\n"
|
||||
},
|
||||
"top_k": {
|
||||
"type": "number",
|
||||
"minimum": 0,
|
||||
"default": 5,
|
||||
"nullable": true,
|
||||
"description": "Only sample from the top K options for each subsequent token."
|
||||
},
|
||||
"stream": {
|
||||
"description": "Whether to incrementally stream the response using server-sent events.",
|
||||
"type": "boolean",
|
||||
"nullable": true,
|
||||
"default": false
|
||||
},
|
||||
"stop_sequences": {
|
||||
"description": "Sequences that will cause the model to stop generating completion text.\nBy default, our models stop on \"\\n\\nHuman:\".",
|
||||
"default": null,
|
||||
"nullable": true,
|
||||
"type": "array"
|
||||
}
|
||||
},
|
||||
"required": ["model", "prompt", "max_tokens_to_sample"]
|
||||
}
|
||||
42
app/src/modelProviders/anthropic-completion/frontend.ts
Normal file
42
app/src/modelProviders/anthropic-completion/frontend.ts
Normal file
@@ -0,0 +1,42 @@
|
||||
import { type Completion } from "@anthropic-ai/sdk/resources";
|
||||
import { type SupportedModel } from ".";
|
||||
import { type FrontendModelProvider } from "../types";
|
||||
import { refinementActions } from "./refinementActions";
|
||||
|
||||
const frontendModelProvider: FrontendModelProvider<SupportedModel, Completion> = {
|
||||
name: "Replicate Llama2",
|
||||
|
||||
models: {
|
||||
"claude-2.0": {
|
||||
name: "Claude 2.0",
|
||||
contextWindow: 100000,
|
||||
promptTokenPrice: 11.02 / 1000000,
|
||||
completionTokenPrice: 32.68 / 1000000,
|
||||
speed: "medium",
|
||||
provider: "anthropic/completion",
|
||||
learnMoreUrl: "https://www.anthropic.com/product",
|
||||
apiDocsUrl: "https://docs.anthropic.com/claude/reference/complete_post",
|
||||
},
|
||||
"claude-instant-1.1": {
|
||||
name: "Claude Instant 1.1",
|
||||
contextWindow: 100000,
|
||||
promptTokenPrice: 1.63 / 1000000,
|
||||
completionTokenPrice: 5.51 / 1000000,
|
||||
speed: "fast",
|
||||
provider: "anthropic/completion",
|
||||
learnMoreUrl: "https://www.anthropic.com/product",
|
||||
apiDocsUrl: "https://docs.anthropic.com/claude/reference/complete_post",
|
||||
},
|
||||
},
|
||||
|
||||
refinementActions,
|
||||
|
||||
normalizeOutput: (output) => {
|
||||
return {
|
||||
type: "text",
|
||||
value: output.completion,
|
||||
};
|
||||
},
|
||||
};
|
||||
|
||||
export default frontendModelProvider;
|
||||
86
app/src/modelProviders/anthropic-completion/getCompletion.ts
Normal file
86
app/src/modelProviders/anthropic-completion/getCompletion.ts
Normal file
@@ -0,0 +1,86 @@
|
||||
import { env } from "~/env.mjs";
|
||||
import { type CompletionResponse } from "../types";
|
||||
|
||||
import Anthropic, { APIError } from "@anthropic-ai/sdk";
|
||||
import { type Completion, type CompletionCreateParams } from "@anthropic-ai/sdk/resources";
|
||||
import { isObject, isString } from "lodash-es";
|
||||
|
||||
const anthropic = new Anthropic({
|
||||
apiKey: env.ANTHROPIC_API_KEY,
|
||||
});
|
||||
|
||||
export async function getCompletion(
|
||||
input: CompletionCreateParams,
|
||||
onStream: ((partialOutput: Completion) => void) | null,
|
||||
): Promise<CompletionResponse<Completion>> {
|
||||
const start = Date.now();
|
||||
let finalCompletion: Completion | null = null;
|
||||
|
||||
try {
|
||||
if (onStream) {
|
||||
const resp = await anthropic.completions.create(
|
||||
{ ...input, stream: true },
|
||||
{
|
||||
maxRetries: 0,
|
||||
},
|
||||
);
|
||||
|
||||
for await (const part of resp) {
|
||||
if (finalCompletion === null) {
|
||||
finalCompletion = part;
|
||||
} else {
|
||||
finalCompletion = { ...part, completion: finalCompletion.completion + part.completion };
|
||||
}
|
||||
onStream(finalCompletion);
|
||||
}
|
||||
if (!finalCompletion) {
|
||||
return {
|
||||
type: "error",
|
||||
message: "Streaming failed to return a completion",
|
||||
autoRetry: false,
|
||||
};
|
||||
}
|
||||
} else {
|
||||
const resp = await anthropic.completions.create(
|
||||
{ ...input, stream: false },
|
||||
{
|
||||
maxRetries: 0,
|
||||
},
|
||||
);
|
||||
finalCompletion = resp;
|
||||
}
|
||||
const timeToComplete = Date.now() - start;
|
||||
|
||||
return {
|
||||
type: "success",
|
||||
statusCode: 200,
|
||||
value: finalCompletion,
|
||||
timeToComplete,
|
||||
};
|
||||
} catch (error: unknown) {
|
||||
console.log("CAUGHT ERROR", error);
|
||||
if (error instanceof APIError) {
|
||||
const message =
|
||||
isObject(error.error) &&
|
||||
"error" in error.error &&
|
||||
isObject(error.error.error) &&
|
||||
"message" in error.error.error &&
|
||||
isString(error.error.error.message)
|
||||
? error.error.error.message
|
||||
: error.message;
|
||||
|
||||
return {
|
||||
type: "error",
|
||||
message,
|
||||
autoRetry: error.status === 429 || error.status === 503,
|
||||
statusCode: error.status,
|
||||
};
|
||||
} else {
|
||||
return {
|
||||
type: "error",
|
||||
message: (error as Error).message,
|
||||
autoRetry: true,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
34
app/src/modelProviders/anthropic-completion/index.ts
Normal file
34
app/src/modelProviders/anthropic-completion/index.ts
Normal file
@@ -0,0 +1,34 @@
|
||||
import { type JSONSchema4 } from "json-schema";
|
||||
import { type ModelProvider } from "../types";
|
||||
import inputSchema from "./codegen/input.schema.json";
|
||||
import { getCompletion } from "./getCompletion";
|
||||
import frontendModelProvider from "./frontend";
|
||||
import type { Completion, CompletionCreateParams } from "@anthropic-ai/sdk/resources";
|
||||
|
||||
const supportedModels = ["claude-2.0", "claude-instant-1.1"] as const;
|
||||
|
||||
export type SupportedModel = (typeof supportedModels)[number];
|
||||
|
||||
export type AnthropicProvider = ModelProvider<SupportedModel, CompletionCreateParams, Completion>;
|
||||
|
||||
const modelProvider: AnthropicProvider = {
|
||||
getModel: (input) => {
|
||||
if (supportedModels.includes(input.model as SupportedModel))
|
||||
return input.model as SupportedModel;
|
||||
|
||||
const modelMaps: Record<string, SupportedModel> = {
|
||||
"claude-2": "claude-2.0",
|
||||
"claude-instant-1": "claude-instant-1.1",
|
||||
};
|
||||
|
||||
if (input.model in modelMaps) return modelMaps[input.model] as SupportedModel;
|
||||
|
||||
return null;
|
||||
},
|
||||
inputSchema: inputSchema as JSONSchema4,
|
||||
canStream: true,
|
||||
getCompletion,
|
||||
...frontendModelProvider,
|
||||
};
|
||||
|
||||
export default modelProvider;
|
||||
@@ -0,0 +1,3 @@
|
||||
import { type RefinementAction } from "../types";
|
||||
|
||||
export const refinementActions: Record<string, RefinementAction> = {};
|
||||
15
app/src/modelProviders/frontendModelProviders.ts
Normal file
15
app/src/modelProviders/frontendModelProviders.ts
Normal file
@@ -0,0 +1,15 @@
|
||||
import openaiChatCompletionFrontend from "./openai-ChatCompletion/frontend";
|
||||
import replicateLlama2Frontend from "./replicate-llama2/frontend";
|
||||
import anthropicFrontend from "./anthropic-completion/frontend";
|
||||
import { type SupportedProvider, type FrontendModelProvider } from "./types";
|
||||
|
||||
// Keep attributes here that need to be accessible from the frontend. We can't
|
||||
// just include them in the default `modelProviders` object because it has some
|
||||
// transient dependencies that can only be imported on the server.
|
||||
const frontendModelProviders: Record<SupportedProvider, FrontendModelProvider<any, any>> = {
|
||||
"openai/ChatCompletion": openaiChatCompletionFrontend,
|
||||
"replicate/llama2": replicateLlama2Frontend,
|
||||
"anthropic/completion": anthropicFrontend,
|
||||
};
|
||||
|
||||
export default frontendModelProviders;
|
||||
36
app/src/modelProviders/generateTypes.ts
Normal file
36
app/src/modelProviders/generateTypes.ts
Normal file
@@ -0,0 +1,36 @@
|
||||
import { type JSONSchema4Object } from "json-schema";
|
||||
import modelProviders from "./modelProviders";
|
||||
import { compile } from "json-schema-to-typescript";
|
||||
import dedent from "dedent";
|
||||
|
||||
export default async function generateTypes() {
|
||||
const combinedSchema = {
|
||||
type: "object",
|
||||
properties: {} as Record<string, JSONSchema4Object>,
|
||||
};
|
||||
|
||||
Object.entries(modelProviders).forEach(([id, provider]) => {
|
||||
combinedSchema.properties[id] = provider.inputSchema;
|
||||
});
|
||||
|
||||
Object.entries(modelProviders).forEach(([id, provider]) => {
|
||||
combinedSchema.properties[id] = provider.inputSchema;
|
||||
});
|
||||
|
||||
const promptTypes = (
|
||||
await compile(combinedSchema as JSONSchema4Object, "PromptTypes", {
|
||||
additionalProperties: false,
|
||||
bannerComment: dedent`
|
||||
/**
|
||||
* This type map defines the input types for each model provider.
|
||||
*/
|
||||
`,
|
||||
})
|
||||
).replace(/export interface PromptTypes/g, "interface PromptTypes");
|
||||
|
||||
return dedent`
|
||||
${promptTypes}
|
||||
|
||||
declare function definePrompt<T extends keyof PromptTypes>(modelProvider: T, input: PromptTypes[T])
|
||||
`;
|
||||
}
|
||||
12
app/src/modelProviders/modelProviders.ts
Normal file
12
app/src/modelProviders/modelProviders.ts
Normal file
@@ -0,0 +1,12 @@
|
||||
import openaiChatCompletion from "./openai-ChatCompletion";
|
||||
import replicateLlama2 from "./replicate-llama2";
|
||||
import anthropicCompletion from "./anthropic-completion";
|
||||
import { type SupportedProvider, type ModelProvider } from "./types";
|
||||
|
||||
const modelProviders: Record<SupportedProvider, ModelProvider<any, any, any>> = {
|
||||
"openai/ChatCompletion": openaiChatCompletion,
|
||||
"replicate/llama2": replicateLlama2,
|
||||
"anthropic/completion": anthropicCompletion,
|
||||
};
|
||||
|
||||
export default modelProviders;
|
||||
@@ -0,0 +1,77 @@
|
||||
/* eslint-disable @typescript-eslint/no-var-requires */
|
||||
|
||||
import YAML from "yaml";
|
||||
import fs from "fs";
|
||||
import path from "path";
|
||||
import { openapiSchemaToJsonSchema } from "@openapi-contrib/openapi-schema-to-json-schema";
|
||||
import $RefParser from "@apidevtools/json-schema-ref-parser";
|
||||
import { type JSONObject } from "superjson/dist/types";
|
||||
import assert from "assert";
|
||||
import { type JSONSchema4Object } from "json-schema";
|
||||
import { isObject } from "lodash-es";
|
||||
|
||||
// @ts-expect-error for some reason missing from types
|
||||
import parserEstree from "prettier/plugins/estree";
|
||||
import parserBabel from "prettier/plugins/babel";
|
||||
import prettier from "prettier/standalone";
|
||||
|
||||
const OPENAPI_URL =
|
||||
"https://raw.githubusercontent.com/openai/openai-openapi/0c432eb66fd0c758fd8b9bd69db41c1096e5f4db/openapi.yaml";
|
||||
|
||||
// Fetch the openapi document
|
||||
const response = await fetch(OPENAPI_URL);
|
||||
const openApiYaml = await response.text();
|
||||
|
||||
// Parse the yaml document
|
||||
let schema = YAML.parse(openApiYaml) as JSONObject;
|
||||
schema = openapiSchemaToJsonSchema(schema);
|
||||
|
||||
const jsonSchema = await $RefParser.dereference(schema);
|
||||
|
||||
assert("components" in jsonSchema);
|
||||
const completionRequestSchema = jsonSchema.components.schemas
|
||||
.CreateChatCompletionRequest as JSONSchema4Object;
|
||||
|
||||
// We need to do a bit of surgery here since the Monaco editor doesn't like
|
||||
// the fact that the schema says `model` can be either a string or an enum,
|
||||
// and displays a warning in the editor. Let's stick with just an enum for
|
||||
// now and drop the string option.
|
||||
assert(
|
||||
"properties" in completionRequestSchema &&
|
||||
isObject(completionRequestSchema.properties) &&
|
||||
"model" in completionRequestSchema.properties &&
|
||||
isObject(completionRequestSchema.properties.model),
|
||||
);
|
||||
|
||||
const modelProperty = completionRequestSchema.properties.model;
|
||||
assert(
|
||||
"oneOf" in modelProperty &&
|
||||
Array.isArray(modelProperty.oneOf) &&
|
||||
modelProperty.oneOf.length === 2 &&
|
||||
isObject(modelProperty.oneOf[1]) &&
|
||||
"enum" in modelProperty.oneOf[1],
|
||||
"Expected model to have oneOf length of 2",
|
||||
);
|
||||
modelProperty.type = "string";
|
||||
modelProperty.enum = modelProperty.oneOf[1].enum;
|
||||
delete modelProperty["oneOf"];
|
||||
|
||||
// The default of "inf" confuses the Typescript generator, so can just remove it
|
||||
assert(
|
||||
"max_tokens" in completionRequestSchema.properties &&
|
||||
isObject(completionRequestSchema.properties.max_tokens) &&
|
||||
"default" in completionRequestSchema.properties.max_tokens,
|
||||
);
|
||||
delete completionRequestSchema.properties.max_tokens["default"];
|
||||
|
||||
// Get the directory of the current script
|
||||
const currentDirectory = path.dirname(import.meta.url).replace("file://", "");
|
||||
|
||||
// Write the JSON schema to a file in the current directory
|
||||
fs.writeFileSync(
|
||||
path.join(currentDirectory, "input.schema.json"),
|
||||
await prettier.format(JSON.stringify(completionRequestSchema, null, 2), {
|
||||
parser: "json",
|
||||
plugins: [parserBabel, parserEstree],
|
||||
}),
|
||||
);
|
||||
@@ -0,0 +1,185 @@
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"model": {
|
||||
"description": "ID of the model to use. See the [model endpoint compatibility](/docs/models/model-endpoint-compatibility) table for details on which models work with the Chat API.",
|
||||
"example": "gpt-3.5-turbo",
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"gpt-4",
|
||||
"gpt-4-0613",
|
||||
"gpt-4-32k",
|
||||
"gpt-4-32k-0613",
|
||||
"gpt-3.5-turbo",
|
||||
"gpt-3.5-turbo-16k",
|
||||
"gpt-3.5-turbo-0613",
|
||||
"gpt-3.5-turbo-16k-0613"
|
||||
]
|
||||
},
|
||||
"messages": {
|
||||
"description": "A list of messages comprising the conversation so far. [Example Python code](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb).",
|
||||
"type": "array",
|
||||
"minItems": 1,
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"role": {
|
||||
"type": "string",
|
||||
"enum": ["system", "user", "assistant", "function"],
|
||||
"description": "The role of the messages author. One of `system`, `user`, `assistant`, or `function`."
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "The contents of the message. `content` is required for all messages except assistant messages with function calls."
|
||||
},
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "The name of the author of this message. `name` is required if role is `function`, and it should be the name of the function whose response is in the `content`. May contain a-z, A-Z, 0-9, and underscores, with a maximum length of 64 characters."
|
||||
},
|
||||
"function_call": {
|
||||
"type": "object",
|
||||
"description": "The name and arguments of a function that should be called, as generated by the model.",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "The name of the function to call."
|
||||
},
|
||||
"arguments": {
|
||||
"type": "string",
|
||||
"description": "The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function."
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["role"]
|
||||
}
|
||||
},
|
||||
"functions": {
|
||||
"description": "A list of functions the model may generate JSON inputs for.",
|
||||
"type": "array",
|
||||
"minItems": 1,
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "The name of the function to be called. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64."
|
||||
},
|
||||
"description": {
|
||||
"type": "string",
|
||||
"description": "The description of what the function does."
|
||||
},
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"description": "The parameters the functions accepts, described as a JSON Schema object. See the [guide](/docs/guides/gpt/function-calling) for examples, and the [JSON Schema reference](https://json-schema.org/understanding-json-schema/) for documentation about the format.",
|
||||
"additionalProperties": true
|
||||
}
|
||||
},
|
||||
"required": ["name"]
|
||||
}
|
||||
},
|
||||
"function_call": {
|
||||
"description": "Controls how the model responds to function calls. \"none\" means the model does not call a function, and responds to the end-user. \"auto\" means the model can pick between an end-user or calling a function. Specifying a particular function via `{\"name\":\\ \"my_function\"}` forces the model to call that function. \"none\" is the default when no functions are present. \"auto\" is the default if functions are present.",
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "string",
|
||||
"enum": ["none", "auto"]
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "The name of the function to call."
|
||||
}
|
||||
},
|
||||
"required": ["name"]
|
||||
}
|
||||
]
|
||||
},
|
||||
"temperature": {
|
||||
"type": "number",
|
||||
"minimum": 0,
|
||||
"maximum": 2,
|
||||
"default": 1,
|
||||
"example": 1,
|
||||
"nullable": true,
|
||||
"description": "What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.\n\nWe generally recommend altering this or `top_p` but not both.\n"
|
||||
},
|
||||
"top_p": {
|
||||
"type": "number",
|
||||
"minimum": 0,
|
||||
"maximum": 1,
|
||||
"default": 1,
|
||||
"example": 1,
|
||||
"nullable": true,
|
||||
"description": "An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.\n\nWe generally recommend altering this or `temperature` but not both.\n"
|
||||
},
|
||||
"n": {
|
||||
"type": "integer",
|
||||
"minimum": 1,
|
||||
"maximum": 128,
|
||||
"default": 1,
|
||||
"example": 1,
|
||||
"nullable": true,
|
||||
"description": "How many chat completion choices to generate for each input message."
|
||||
},
|
||||
"stream": {
|
||||
"description": "If set, partial message deltas will be sent, like in ChatGPT. Tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream terminated by a `data: [DONE]` message. [Example Python code](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_stream_completions.ipynb).\n",
|
||||
"type": "boolean",
|
||||
"nullable": true,
|
||||
"default": false
|
||||
},
|
||||
"stop": {
|
||||
"description": "Up to 4 sequences where the API will stop generating further tokens.\n",
|
||||
"default": null,
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "string",
|
||||
"nullable": true
|
||||
},
|
||||
{
|
||||
"type": "array",
|
||||
"minItems": 1,
|
||||
"maxItems": 4,
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
"max_tokens": {
|
||||
"description": "The maximum number of [tokens](/tokenizer) to generate in the chat completion.\n\nThe total length of input tokens and generated tokens is limited by the model's context length. [Example Python code](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb) for counting tokens.\n",
|
||||
"type": "integer"
|
||||
},
|
||||
"presence_penalty": {
|
||||
"type": "number",
|
||||
"default": 0,
|
||||
"minimum": -2,
|
||||
"maximum": 2,
|
||||
"nullable": true,
|
||||
"description": "Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.\n\n[See more information about frequency and presence penalties.](/docs/api-reference/parameter-details)\n"
|
||||
},
|
||||
"frequency_penalty": {
|
||||
"type": "number",
|
||||
"default": 0,
|
||||
"minimum": -2,
|
||||
"maximum": 2,
|
||||
"nullable": true,
|
||||
"description": "Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.\n\n[See more information about frequency and presence penalties.](/docs/api-reference/parameter-details)\n"
|
||||
},
|
||||
"logit_bias": {
|
||||
"type": "object",
|
||||
"x-oaiTypeLabel": "map",
|
||||
"default": null,
|
||||
"nullable": true,
|
||||
"description": "Modify the likelihood of specified tokens appearing in the completion.\n\nAccepts a json object that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token.\n"
|
||||
},
|
||||
"user": {
|
||||
"type": "string",
|
||||
"example": "user-1234",
|
||||
"description": "A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).\n"
|
||||
}
|
||||
},
|
||||
"required": ["model", "messages"]
|
||||
}
|
||||
87
app/src/modelProviders/openai-ChatCompletion/frontend.ts
Normal file
87
app/src/modelProviders/openai-ChatCompletion/frontend.ts
Normal file
@@ -0,0 +1,87 @@
|
||||
import { type JsonValue } from "type-fest";
|
||||
import { type SupportedModel } from ".";
|
||||
import { type FrontendModelProvider } from "../types";
|
||||
import { type ChatCompletion } from "openai/resources/chat";
|
||||
import { refinementActions } from "./refinementActions";
|
||||
|
||||
const frontendModelProvider: FrontendModelProvider<SupportedModel, ChatCompletion> = {
|
||||
name: "OpenAI ChatCompletion",
|
||||
|
||||
models: {
|
||||
"gpt-4-0613": {
|
||||
name: "GPT-4",
|
||||
contextWindow: 8192,
|
||||
promptTokenPrice: 0.00003,
|
||||
completionTokenPrice: 0.00006,
|
||||
speed: "medium",
|
||||
provider: "openai/ChatCompletion",
|
||||
learnMoreUrl: "https://openai.com/gpt-4",
|
||||
},
|
||||
"gpt-4-32k-0613": {
|
||||
name: "GPT-4 32k",
|
||||
contextWindow: 32768,
|
||||
promptTokenPrice: 0.00006,
|
||||
completionTokenPrice: 0.00012,
|
||||
speed: "medium",
|
||||
provider: "openai/ChatCompletion",
|
||||
learnMoreUrl: "https://openai.com/gpt-4",
|
||||
},
|
||||
"gpt-3.5-turbo-0613": {
|
||||
name: "GPT-3.5 Turbo",
|
||||
contextWindow: 4096,
|
||||
promptTokenPrice: 0.0000015,
|
||||
completionTokenPrice: 0.000002,
|
||||
speed: "fast",
|
||||
provider: "openai/ChatCompletion",
|
||||
learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
|
||||
},
|
||||
"gpt-3.5-turbo-16k-0613": {
|
||||
name: "GPT-3.5 Turbo 16k",
|
||||
contextWindow: 16384,
|
||||
promptTokenPrice: 0.000003,
|
||||
completionTokenPrice: 0.000004,
|
||||
speed: "fast",
|
||||
provider: "openai/ChatCompletion",
|
||||
learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
|
||||
},
|
||||
},
|
||||
|
||||
refinementActions,
|
||||
|
||||
normalizeOutput: (output) => {
|
||||
const message = output.choices[0]?.message;
|
||||
if (!message)
|
||||
return {
|
||||
type: "json",
|
||||
value: output as unknown as JsonValue,
|
||||
};
|
||||
|
||||
if (message.content) {
|
||||
return {
|
||||
type: "text",
|
||||
value: message.content,
|
||||
};
|
||||
} else if (message.function_call) {
|
||||
let args = message.function_call.arguments ?? "";
|
||||
try {
|
||||
args = JSON.parse(args);
|
||||
} catch (e) {
|
||||
// Ignore
|
||||
}
|
||||
return {
|
||||
type: "json",
|
||||
value: {
|
||||
...message.function_call,
|
||||
arguments: args,
|
||||
},
|
||||
};
|
||||
} else {
|
||||
return {
|
||||
type: "json",
|
||||
value: message as unknown as JsonValue,
|
||||
};
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
export default frontendModelProvider;
|
||||
152
app/src/modelProviders/openai-ChatCompletion/getCompletion.ts
Normal file
152
app/src/modelProviders/openai-ChatCompletion/getCompletion.ts
Normal file
@@ -0,0 +1,152 @@
|
||||
/* eslint-disable @typescript-eslint/no-unsafe-call */
|
||||
import {
|
||||
type ChatCompletionChunk,
|
||||
type ChatCompletion,
|
||||
type CompletionCreateParams,
|
||||
} from "openai/resources/chat";
|
||||
import { countOpenAIChatTokens } from "~/utils/countTokens";
|
||||
import { type CompletionResponse } from "../types";
|
||||
import { isArray, isString, omit } from "lodash-es";
|
||||
import { openai } from "~/server/utils/openai";
|
||||
import { truthyFilter } from "~/utils/utils";
|
||||
import { APIError } from "openai";
|
||||
import frontendModelProvider from "./frontend";
|
||||
import modelProvider, { type SupportedModel } from ".";
|
||||
|
||||
const mergeStreamedChunks = (
|
||||
base: ChatCompletion | null,
|
||||
chunk: ChatCompletionChunk,
|
||||
): ChatCompletion => {
|
||||
if (base === null) {
|
||||
return mergeStreamedChunks({ ...chunk, choices: [] }, chunk);
|
||||
}
|
||||
|
||||
const choices = [...base.choices];
|
||||
for (const choice of chunk.choices) {
|
||||
const baseChoice = choices.find((c) => c.index === choice.index);
|
||||
if (baseChoice) {
|
||||
baseChoice.finish_reason = choice.finish_reason ?? baseChoice.finish_reason;
|
||||
baseChoice.message = baseChoice.message ?? { role: "assistant" };
|
||||
|
||||
if (choice.delta?.content)
|
||||
baseChoice.message.content =
|
||||
((baseChoice.message.content as string) ?? "") + (choice.delta.content ?? "");
|
||||
if (choice.delta?.function_call) {
|
||||
const fnCall = baseChoice.message.function_call ?? {};
|
||||
fnCall.name =
|
||||
((fnCall.name as string) ?? "") + ((choice.delta.function_call.name as string) ?? "");
|
||||
fnCall.arguments =
|
||||
((fnCall.arguments as string) ?? "") +
|
||||
((choice.delta.function_call.arguments as string) ?? "");
|
||||
}
|
||||
} else {
|
||||
// @ts-expect-error the types are correctly telling us that finish_reason
|
||||
// could be null, but don't want to fix it right now.
|
||||
choices.push({ ...omit(choice, "delta"), message: { role: "assistant", ...choice.delta } });
|
||||
}
|
||||
}
|
||||
|
||||
const merged: ChatCompletion = {
|
||||
...base,
|
||||
choices,
|
||||
};
|
||||
|
||||
return merged;
|
||||
};
|
||||
|
||||
export async function getCompletion(
|
||||
input: CompletionCreateParams,
|
||||
onStream: ((partialOutput: ChatCompletion) => void) | null,
|
||||
): Promise<CompletionResponse<ChatCompletion>> {
|
||||
const start = Date.now();
|
||||
let finalCompletion: ChatCompletion | null = null;
|
||||
let promptTokens: number | undefined = undefined;
|
||||
let completionTokens: number | undefined = undefined;
|
||||
const modelName = modelProvider.getModel(input) as SupportedModel;
|
||||
|
||||
try {
|
||||
if (onStream) {
|
||||
console.log("got started");
|
||||
const resp = await openai.chat.completions.create(
|
||||
{ ...input, stream: true },
|
||||
{
|
||||
maxRetries: 0,
|
||||
},
|
||||
);
|
||||
for await (const part of resp) {
|
||||
console.log("got part", part);
|
||||
finalCompletion = mergeStreamedChunks(finalCompletion, part);
|
||||
onStream(finalCompletion);
|
||||
}
|
||||
console.log("got final", finalCompletion);
|
||||
if (!finalCompletion) {
|
||||
return {
|
||||
type: "error",
|
||||
message: "Streaming failed to return a completion",
|
||||
autoRetry: false,
|
||||
};
|
||||
}
|
||||
try {
|
||||
promptTokens = countOpenAIChatTokens(modelName, input.messages);
|
||||
completionTokens = countOpenAIChatTokens(
|
||||
modelName,
|
||||
finalCompletion.choices.map((c) => c.message).filter(truthyFilter),
|
||||
);
|
||||
} catch (err) {
|
||||
// TODO handle this, library seems like maybe it doesn't work with function calls?
|
||||
console.error(err);
|
||||
}
|
||||
} else {
|
||||
const resp = await openai.chat.completions.create(
|
||||
{ ...input, stream: false },
|
||||
{
|
||||
maxRetries: 0,
|
||||
},
|
||||
);
|
||||
finalCompletion = resp;
|
||||
promptTokens = resp.usage?.prompt_tokens ?? 0;
|
||||
completionTokens = resp.usage?.completion_tokens ?? 0;
|
||||
}
|
||||
const timeToComplete = Date.now() - start;
|
||||
|
||||
const { promptTokenPrice, completionTokenPrice } = frontendModelProvider.models[modelName];
|
||||
let cost = undefined;
|
||||
if (promptTokenPrice && completionTokenPrice && promptTokens && completionTokens) {
|
||||
cost = promptTokens * promptTokenPrice + completionTokens * completionTokenPrice;
|
||||
}
|
||||
|
||||
return {
|
||||
type: "success",
|
||||
statusCode: 200,
|
||||
value: finalCompletion,
|
||||
timeToComplete,
|
||||
promptTokens,
|
||||
completionTokens,
|
||||
cost,
|
||||
};
|
||||
} catch (error: unknown) {
|
||||
if (error instanceof APIError) {
|
||||
// The types from the sdk are wrong
|
||||
const rawMessage = error.message as string | string[];
|
||||
// If the message is not a string, stringify it
|
||||
const message = isString(rawMessage)
|
||||
? rawMessage
|
||||
: isArray(rawMessage)
|
||||
? rawMessage.map((m) => m.toString()).join("\n")
|
||||
: (rawMessage as any).toString();
|
||||
return {
|
||||
type: "error",
|
||||
message,
|
||||
autoRetry: error.status === 429 || error.status === 503,
|
||||
statusCode: error.status,
|
||||
};
|
||||
} else {
|
||||
console.error(error);
|
||||
return {
|
||||
type: "error",
|
||||
message: (error as Error).message,
|
||||
autoRetry: true,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
45
app/src/modelProviders/openai-ChatCompletion/index.ts
Normal file
45
app/src/modelProviders/openai-ChatCompletion/index.ts
Normal file
@@ -0,0 +1,45 @@
|
||||
import { type JSONSchema4 } from "json-schema";
|
||||
import { type ModelProvider } from "../types";
|
||||
import inputSchema from "./codegen/input.schema.json";
|
||||
import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat";
|
||||
import { getCompletion } from "./getCompletion";
|
||||
import frontendModelProvider from "./frontend";
|
||||
|
||||
const supportedModels = [
|
||||
"gpt-4-0613",
|
||||
"gpt-4-32k-0613",
|
||||
"gpt-3.5-turbo-0613",
|
||||
"gpt-3.5-turbo-16k-0613",
|
||||
] as const;
|
||||
|
||||
export type SupportedModel = (typeof supportedModels)[number];
|
||||
|
||||
export type OpenaiChatModelProvider = ModelProvider<
|
||||
SupportedModel,
|
||||
CompletionCreateParams,
|
||||
ChatCompletion
|
||||
>;
|
||||
|
||||
const modelProvider: OpenaiChatModelProvider = {
|
||||
getModel: (input) => {
|
||||
if (supportedModels.includes(input.model as SupportedModel))
|
||||
return input.model as SupportedModel;
|
||||
|
||||
const modelMaps: Record<string, SupportedModel> = {
|
||||
"gpt-4": "gpt-4-0613",
|
||||
"gpt-4-32k": "gpt-4-32k-0613",
|
||||
"gpt-3.5-turbo": "gpt-3.5-turbo-0613",
|
||||
"gpt-3.5-turbo-16k": "gpt-3.5-turbo-16k-0613",
|
||||
};
|
||||
|
||||
if (input.model in modelMaps) return modelMaps[input.model] as SupportedModel;
|
||||
|
||||
return null;
|
||||
},
|
||||
inputSchema: inputSchema as JSONSchema4,
|
||||
canStream: true,
|
||||
getCompletion,
|
||||
...frontendModelProvider,
|
||||
};
|
||||
|
||||
export default modelProvider;
|
||||
@@ -0,0 +1,279 @@
|
||||
import { TfiThought } from "react-icons/tfi";
|
||||
import { type RefinementAction } from "../types";
|
||||
import { VscJson } from "react-icons/vsc";
|
||||
|
||||
export const refinementActions: Record<string, RefinementAction> = {
|
||||
"Add chain of thought": {
|
||||
icon: VscJson,
|
||||
description: "Asking the model to plan its answer can increase accuracy.",
|
||||
instructions: `Adding chain of thought means asking the model to think about its answer before it gives it to you. This is useful for getting more accurate answers. Do not add an assistant message.
|
||||
|
||||
This is what a prompt looks like before adding chain of thought:
|
||||
|
||||
definePrompt("openai/ChatCompletion", {
|
||||
model: "gpt-4",
|
||||
stream: true,
|
||||
messages: [
|
||||
{
|
||||
role: "system",
|
||||
content: \`Evaluate sentiment.\`,
|
||||
},
|
||||
{
|
||||
role: "user",
|
||||
content: \`This is the user's message: \${scenario.user_message}. Return "positive" or "negative" or "neutral"\`,
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
This is what one looks like after adding chain of thought:
|
||||
|
||||
definePrompt("openai/ChatCompletion", {
|
||||
model: "gpt-4",
|
||||
stream: true,
|
||||
messages: [
|
||||
{
|
||||
role: "system",
|
||||
content: \`Evaluate sentiment.\`,
|
||||
},
|
||||
{
|
||||
role: "user",
|
||||
content: \`This is the user's message: \${scenario.user_message}. Return "positive" or "negative" or "neutral". Explain your answer before you give a score, then return the score on a new line.\`,
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
Here's another example:
|
||||
|
||||
Before:
|
||||
|
||||
definePrompt("openai/ChatCompletion", {
|
||||
model: "gpt-3.5-turbo",
|
||||
messages: [
|
||||
{
|
||||
role: "user",
|
||||
content: \`Title: \${scenario.title}
|
||||
Body: \${scenario.body}
|
||||
|
||||
Need: \${scenario.need}
|
||||
|
||||
Rate likelihood on 1-3 scale.\`,
|
||||
},
|
||||
],
|
||||
temperature: 0,
|
||||
functions: [
|
||||
{
|
||||
name: "score_post",
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: {
|
||||
score: {
|
||||
type: "number",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
function_call: {
|
||||
name: "score_post",
|
||||
},
|
||||
});
|
||||
|
||||
After:
|
||||
|
||||
definePrompt("openai/ChatCompletion", {
|
||||
model: "gpt-3.5-turbo",
|
||||
messages: [
|
||||
{
|
||||
role: "user",
|
||||
content: \`Title: \${scenario.title}
|
||||
Body: \${scenario.body}
|
||||
|
||||
Need: \${scenario.need}
|
||||
|
||||
Rate likelihood on 1-3 scale. Provide an explanation, but always provide a score afterward.\`,
|
||||
},
|
||||
],
|
||||
temperature: 0,
|
||||
functions: [
|
||||
{
|
||||
name: "score_post",
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: {
|
||||
explanation: {
|
||||
type: "string",
|
||||
}
|
||||
score: {
|
||||
type: "number",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
function_call: {
|
||||
name: "score_post",
|
||||
},
|
||||
});
|
||||
|
||||
Add chain of thought to the original prompt.`,
|
||||
},
|
||||
"Convert to function call": {
|
||||
icon: TfiThought,
|
||||
description: "Use function calls to get output from the model in a more structured way.",
|
||||
instructions: `OpenAI functions are a specialized way for an LLM to return output.
|
||||
|
||||
This is what a prompt looks like before adding a function:
|
||||
|
||||
definePrompt("openai/ChatCompletion", {
|
||||
model: "gpt-4",
|
||||
stream: true,
|
||||
messages: [
|
||||
{
|
||||
role: "system",
|
||||
content: \`Evaluate sentiment.\`,
|
||||
},
|
||||
{
|
||||
role: "user",
|
||||
content: \`This is the user's message: \${scenario.user_message}. Return "positive" or "negative" or "neutral"\`,
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
This is what one looks like after adding a function:
|
||||
|
||||
definePrompt("openai/ChatCompletion", {
|
||||
model: "gpt-4",
|
||||
stream: true,
|
||||
messages: [
|
||||
{
|
||||
role: "system",
|
||||
content: "Evaluate sentiment.",
|
||||
},
|
||||
{
|
||||
role: "user",
|
||||
content: scenario.user_message,
|
||||
},
|
||||
],
|
||||
functions: [
|
||||
{
|
||||
name: "extract_sentiment",
|
||||
parameters: {
|
||||
type: "object", // parameters must always be an object with a properties key
|
||||
properties: { // properties key is required
|
||||
sentiment: {
|
||||
type: "string",
|
||||
description: "one of positive/negative/neutral",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
function_call: {
|
||||
name: "extract_sentiment",
|
||||
},
|
||||
});
|
||||
|
||||
Here's another example of adding a function:
|
||||
|
||||
Before:
|
||||
|
||||
definePrompt("openai/ChatCompletion", {
|
||||
model: "gpt-3.5-turbo",
|
||||
messages: [
|
||||
{
|
||||
role: "user",
|
||||
content: \`Here is the title and body of a reddit post I am interested in:
|
||||
|
||||
title: \${scenario.title}
|
||||
body: \${scenario.body}
|
||||
|
||||
On a scale from 1 to 3, how likely is it that the person writing this post has the following need? If you are not sure, make your best guess, or answer 1.
|
||||
|
||||
Need: \${scenario.need}
|
||||
|
||||
Answer one integer between 1 and 3.\`,
|
||||
},
|
||||
],
|
||||
temperature: 0,
|
||||
});
|
||||
|
||||
After:
|
||||
|
||||
definePrompt("openai/ChatCompletion", {
|
||||
model: "gpt-3.5-turbo",
|
||||
messages: [
|
||||
{
|
||||
role: "user",
|
||||
content: \`Title: \${scenario.title}
|
||||
Body: \${scenario.body}
|
||||
|
||||
Need: \${scenario.need}
|
||||
|
||||
Rate likelihood on 1-3 scale.\`,
|
||||
},
|
||||
],
|
||||
temperature: 0,
|
||||
functions: [
|
||||
{
|
||||
name: "score_post",
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: {
|
||||
score: {
|
||||
type: "number",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
function_call: {
|
||||
name: "score_post",
|
||||
},
|
||||
});
|
||||
|
||||
Another example
|
||||
|
||||
Before:
|
||||
|
||||
definePrompt("openai/ChatCompletion", {
|
||||
model: "gpt-3.5-turbo",
|
||||
stream: true,
|
||||
messages: [
|
||||
{
|
||||
role: "system",
|
||||
content: \`Write 'Start experimenting!' in \${scenario.language}\`,
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
After:
|
||||
|
||||
definePrompt("openai/ChatCompletion", {
|
||||
model: "gpt-3.5-turbo",
|
||||
messages: [
|
||||
{
|
||||
role: "system",
|
||||
content: \`Write 'Start experimenting!' in \${scenario.language}\`,
|
||||
},
|
||||
],
|
||||
functions: [
|
||||
{
|
||||
name: "write_in_language",
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: {
|
||||
text: {
|
||||
type: "string",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
function_call: {
|
||||
name: "write_in_language",
|
||||
},
|
||||
});
|
||||
|
||||
Add an OpenAI function that takes one or more nested parameters that match the expected output from this prompt.`,
|
||||
},
|
||||
};
|
||||
45
app/src/modelProviders/replicate-llama2/frontend.ts
Normal file
45
app/src/modelProviders/replicate-llama2/frontend.ts
Normal file
@@ -0,0 +1,45 @@
|
||||
import { type SupportedModel, type ReplicateLlama2Output } from ".";
|
||||
import { type FrontendModelProvider } from "../types";
|
||||
import { refinementActions } from "./refinementActions";
|
||||
|
||||
const frontendModelProvider: FrontendModelProvider<SupportedModel, ReplicateLlama2Output> = {
|
||||
name: "Replicate Llama2",
|
||||
|
||||
models: {
|
||||
"7b-chat": {
|
||||
name: "LLama 2 7B Chat",
|
||||
contextWindow: 4096,
|
||||
pricePerSecond: 0.0023,
|
||||
speed: "fast",
|
||||
provider: "replicate/llama2",
|
||||
learnMoreUrl: "https://replicate.com/a16z-infra/llama7b-v2-chat",
|
||||
},
|
||||
"13b-chat": {
|
||||
name: "LLama 2 13B Chat",
|
||||
contextWindow: 4096,
|
||||
pricePerSecond: 0.0023,
|
||||
speed: "medium",
|
||||
provider: "replicate/llama2",
|
||||
learnMoreUrl: "https://replicate.com/a16z-infra/llama13b-v2-chat",
|
||||
},
|
||||
"70b-chat": {
|
||||
name: "LLama 2 70B Chat",
|
||||
contextWindow: 4096,
|
||||
pricePerSecond: 0.0032,
|
||||
speed: "slow",
|
||||
provider: "replicate/llama2",
|
||||
learnMoreUrl: "https://replicate.com/replicate/llama70b-v2-chat",
|
||||
},
|
||||
},
|
||||
|
||||
refinementActions,
|
||||
|
||||
normalizeOutput: (output) => {
|
||||
return {
|
||||
type: "text",
|
||||
value: output.join(""),
|
||||
};
|
||||
},
|
||||
};
|
||||
|
||||
export default frontendModelProvider;
|
||||
60
app/src/modelProviders/replicate-llama2/getCompletion.ts
Normal file
60
app/src/modelProviders/replicate-llama2/getCompletion.ts
Normal file
@@ -0,0 +1,60 @@
|
||||
import { env } from "~/env.mjs";
|
||||
import { type ReplicateLlama2Input, type ReplicateLlama2Output } from ".";
|
||||
import { type CompletionResponse } from "../types";
|
||||
import Replicate from "replicate";
|
||||
|
||||
const replicate = new Replicate({
|
||||
auth: env.REPLICATE_API_TOKEN || "",
|
||||
});
|
||||
|
||||
const modelIds: Record<ReplicateLlama2Input["model"], string> = {
|
||||
"7b-chat": "4f0b260b6a13eb53a6b1891f089d57c08f41003ae79458be5011303d81a394dc",
|
||||
"13b-chat": "2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52",
|
||||
"70b-chat": "2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1",
|
||||
};
|
||||
|
||||
export async function getCompletion(
|
||||
input: ReplicateLlama2Input,
|
||||
onStream: ((partialOutput: string[]) => void) | null,
|
||||
): Promise<CompletionResponse<ReplicateLlama2Output>> {
|
||||
const start = Date.now();
|
||||
|
||||
const { model, ...rest } = input;
|
||||
|
||||
try {
|
||||
const prediction = await replicate.predictions.create({
|
||||
version: modelIds[model],
|
||||
input: rest,
|
||||
});
|
||||
|
||||
const interval = onStream
|
||||
? // eslint-disable-next-line @typescript-eslint/no-misused-promises
|
||||
setInterval(async () => {
|
||||
const partialPrediction = await replicate.predictions.get(prediction.id);
|
||||
|
||||
if (partialPrediction.output) onStream(partialPrediction.output as ReplicateLlama2Output);
|
||||
}, 500)
|
||||
: null;
|
||||
|
||||
const resp = await replicate.wait(prediction, {});
|
||||
if (interval) clearInterval(interval);
|
||||
|
||||
const timeToComplete = Date.now() - start;
|
||||
|
||||
if (resp.error) throw new Error(resp.error as string);
|
||||
|
||||
return {
|
||||
type: "success",
|
||||
statusCode: 200,
|
||||
value: resp.output as ReplicateLlama2Output,
|
||||
timeToComplete,
|
||||
};
|
||||
} catch (error: unknown) {
|
||||
console.error("ERROR IS", error);
|
||||
return {
|
||||
type: "error",
|
||||
message: (error as Error).message,
|
||||
autoRetry: true,
|
||||
};
|
||||
}
|
||||
}
|
||||
81
app/src/modelProviders/replicate-llama2/index.ts
Normal file
81
app/src/modelProviders/replicate-llama2/index.ts
Normal file
@@ -0,0 +1,81 @@
|
||||
import { type ModelProvider } from "../types";
|
||||
import frontendModelProvider from "./frontend";
|
||||
import { getCompletion } from "./getCompletion";
|
||||
|
||||
const supportedModels = ["7b-chat", "13b-chat", "70b-chat"] as const;
|
||||
|
||||
export type SupportedModel = (typeof supportedModels)[number];
|
||||
|
||||
export type ReplicateLlama2Input = {
|
||||
model: SupportedModel;
|
||||
prompt: string;
|
||||
max_length?: number;
|
||||
temperature?: number;
|
||||
top_p?: number;
|
||||
repetition_penalty?: number;
|
||||
debug?: boolean;
|
||||
};
|
||||
|
||||
export type ReplicateLlama2Output = string[];
|
||||
|
||||
export type ReplicateLlama2Provider = ModelProvider<
|
||||
SupportedModel,
|
||||
ReplicateLlama2Input,
|
||||
ReplicateLlama2Output
|
||||
>;
|
||||
|
||||
const modelProvider: ReplicateLlama2Provider = {
|
||||
getModel: (input) => {
|
||||
if (supportedModels.includes(input.model)) return input.model;
|
||||
|
||||
return null;
|
||||
},
|
||||
inputSchema: {
|
||||
type: "object",
|
||||
properties: {
|
||||
model: {
|
||||
type: "string",
|
||||
enum: supportedModels as unknown as string[],
|
||||
},
|
||||
system_prompt: {
|
||||
type: "string",
|
||||
description:
|
||||
"System prompt to send to Llama v2. This is prepended to the prompt and helps guide system behavior.",
|
||||
},
|
||||
prompt: {
|
||||
type: "string",
|
||||
description: "Prompt to send to Llama v2.",
|
||||
},
|
||||
max_new_tokens: {
|
||||
type: "number",
|
||||
description:
|
||||
"Maximum number of tokens to generate. A word is generally 2-3 tokens (minimum: 1)",
|
||||
},
|
||||
temperature: {
|
||||
type: "number",
|
||||
description:
|
||||
"Adjusts randomness of outputs, 0.1 is a good starting value. (minimum: 0.01; maximum: 5)",
|
||||
},
|
||||
top_p: {
|
||||
type: "number",
|
||||
description:
|
||||
"When decoding text, samples from the top p percentage of most likely tokens; lower to ignore less likely tokens (minimum: 0.01; maximum: 1)",
|
||||
},
|
||||
repetition_penalty: {
|
||||
type: "number",
|
||||
description:
|
||||
"Penalty for repeated words in generated text; 1 is no penalty, values greater than 1 discourage repetition, less than 1 encourage it. (minimum: 0.01; maximum: 5)",
|
||||
},
|
||||
debug: {
|
||||
type: "boolean",
|
||||
description: "provide debugging output in logs",
|
||||
},
|
||||
},
|
||||
required: ["model", "prompt"],
|
||||
},
|
||||
canStream: true,
|
||||
getCompletion,
|
||||
...frontendModelProvider,
|
||||
};
|
||||
|
||||
export default modelProvider;
|
||||
@@ -0,0 +1,3 @@
|
||||
import { type RefinementAction } from "../types";
|
||||
|
||||
export const refinementActions: Record<string, RefinementAction> = {};
|
||||
72
app/src/modelProviders/types.ts
Normal file
72
app/src/modelProviders/types.ts
Normal file
@@ -0,0 +1,72 @@
|
||||
import { type JSONSchema4 } from "json-schema";
|
||||
import { type IconType } from "react-icons";
|
||||
import { type JsonValue } from "type-fest";
|
||||
import { z } from "zod";
|
||||
|
||||
export const ZodSupportedProvider = z.union([
|
||||
z.literal("openai/ChatCompletion"),
|
||||
z.literal("replicate/llama2"),
|
||||
z.literal("anthropic/completion"),
|
||||
]);
|
||||
|
||||
export type SupportedProvider = z.infer<typeof ZodSupportedProvider>;
|
||||
|
||||
export type Model = {
|
||||
name: string;
|
||||
contextWindow: number;
|
||||
promptTokenPrice?: number;
|
||||
completionTokenPrice?: number;
|
||||
pricePerSecond?: number;
|
||||
speed: "fast" | "medium" | "slow";
|
||||
provider: SupportedProvider;
|
||||
description?: string;
|
||||
learnMoreUrl?: string;
|
||||
apiDocsUrl?: string;
|
||||
};
|
||||
|
||||
export type ProviderModel = { provider: z.infer<typeof ZodSupportedProvider>; model: string };
|
||||
|
||||
export type RefinementAction = { icon?: IconType; description: string; instructions: string };
|
||||
|
||||
export type FrontendModelProvider<SupportedModels extends string, OutputSchema> = {
|
||||
name: string;
|
||||
models: Record<SupportedModels, Model>;
|
||||
refinementActions?: Record<string, RefinementAction>;
|
||||
|
||||
normalizeOutput: (output: OutputSchema) => NormalizedOutput;
|
||||
};
|
||||
|
||||
export type CompletionResponse<T> =
|
||||
| { type: "error"; message: string; autoRetry: boolean; statusCode?: number }
|
||||
| {
|
||||
type: "success";
|
||||
value: T;
|
||||
timeToComplete: number;
|
||||
statusCode: number;
|
||||
promptTokens?: number;
|
||||
completionTokens?: number;
|
||||
cost?: number;
|
||||
};
|
||||
|
||||
export type ModelProvider<SupportedModels extends string, InputSchema, OutputSchema> = {
|
||||
getModel: (input: InputSchema) => SupportedModels | null;
|
||||
canStream: boolean;
|
||||
inputSchema: JSONSchema4;
|
||||
getCompletion: (
|
||||
input: InputSchema,
|
||||
onStream: ((partialOutput: OutputSchema) => void) | null,
|
||||
) => Promise<CompletionResponse<OutputSchema>>;
|
||||
|
||||
// This is just a convenience for type inference, don't use it at runtime
|
||||
_outputSchema?: OutputSchema | null;
|
||||
} & FrontendModelProvider<SupportedModels, OutputSchema>;
|
||||
|
||||
export type NormalizedOutput =
|
||||
| {
|
||||
type: "text";
|
||||
value: string;
|
||||
}
|
||||
| {
|
||||
type: "json";
|
||||
value: JsonValue;
|
||||
};
|
||||
50
app/src/pages/_app.tsx
Normal file
50
app/src/pages/_app.tsx
Normal file
@@ -0,0 +1,50 @@
|
||||
import { type Session } from "next-auth";
|
||||
import { SessionProvider } from "next-auth/react";
|
||||
import { type AppType } from "next/app";
|
||||
import { api } from "~/utils/api";
|
||||
import Favicon from "~/components/Favicon";
|
||||
import Head from "next/head";
|
||||
import { ChakraThemeProvider } from "~/theme/ChakraThemeProvider";
|
||||
import { SyncAppStore } from "~/state/sync";
|
||||
import NextAdapterApp from "next-query-params/app";
|
||||
import { QueryParamProvider } from "use-query-params";
|
||||
import { SessionIdentifier } from "~/utils/analytics/clientAnalytics";
|
||||
|
||||
const MyApp: AppType<{ session: Session | null }> = ({
|
||||
Component,
|
||||
pageProps: { session, ...pageProps },
|
||||
}) => {
|
||||
return (
|
||||
<>
|
||||
<Head>
|
||||
<meta
|
||||
name="viewport"
|
||||
content="width=device-width, initial-scale=1, maximum-scale=1, user-scalable=0"
|
||||
/>
|
||||
<meta name="og:title" content="OpenPipe: Open-Source Lab for LLMs" key="title" />
|
||||
<meta
|
||||
name="og:description"
|
||||
content="OpenPipe is a powerful playground for quickly optimizing performance, cost, and speed across models."
|
||||
key="description"
|
||||
/>
|
||||
<meta name="og:image" content="/og.png" key="og-image" />
|
||||
<meta property="og:image:height" content="630" />
|
||||
<meta property="og:image:width" content="1200" />
|
||||
<meta name="twitter:card" content="summary_large_image" />
|
||||
<meta name="twitter:image" content="/og.png" />
|
||||
</Head>
|
||||
<SessionProvider session={session}>
|
||||
<SyncAppStore />
|
||||
<Favicon />
|
||||
<SessionIdentifier />
|
||||
<ChakraThemeProvider>
|
||||
<QueryParamProvider adapter={NextAdapterApp}>
|
||||
<Component {...pageProps} />
|
||||
</QueryParamProvider>
|
||||
</ChakraThemeProvider>
|
||||
</SessionProvider>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
export default api.withTRPC(MyApp);
|
||||
23
app/src/pages/account/signin.tsx
Normal file
23
app/src/pages/account/signin.tsx
Normal file
@@ -0,0 +1,23 @@
|
||||
import { signIn, useSession } from "next-auth/react";
|
||||
import { useRouter } from "next/router";
|
||||
import { useEffect } from "react";
|
||||
import AppShell from "~/components/nav/AppShell";
|
||||
|
||||
export default function SignIn() {
|
||||
const session = useSession().data;
|
||||
const router = useRouter();
|
||||
|
||||
useEffect(() => {
|
||||
if (session) {
|
||||
router.push("/experiments").catch(console.error);
|
||||
} else if (session === null) {
|
||||
signIn("github").catch(console.error);
|
||||
}
|
||||
}, [session, router]);
|
||||
|
||||
return (
|
||||
<AppShell>
|
||||
<div />
|
||||
</AppShell>
|
||||
);
|
||||
}
|
||||
4
app/src/pages/api/auth/[...nextauth].ts
Normal file
4
app/src/pages/api/auth/[...nextauth].ts
Normal file
@@ -0,0 +1,4 @@
|
||||
import NextAuth from "next-auth";
|
||||
import { authOptions } from "~/server/auth";
|
||||
|
||||
export default NextAuth(authOptions);
|
||||
81
app/src/pages/api/experiments/og-image.tsx
Normal file
81
app/src/pages/api/experiments/og-image.tsx
Normal file
@@ -0,0 +1,81 @@
|
||||
import { ImageResponse } from "@vercel/og";
|
||||
import { type NextApiRequest, type NextApiResponse } from "next";
|
||||
|
||||
export const config = {
|
||||
runtime: "experimental-edge",
|
||||
};
|
||||
|
||||
const inconsolataRegularFontP = fetch(
|
||||
new URL("../../../../public/fonts/Inconsolata_SemiExpanded-Medium.ttf", import.meta.url),
|
||||
).then((res) => res.arrayBuffer());
|
||||
|
||||
const OgImage = async (req: NextApiRequest, _res: NextApiResponse) => {
|
||||
// @ts-expect-error - nextUrl is not defined on NextApiRequest for some reason
|
||||
const searchParams = req.nextUrl?.searchParams as URLSearchParams;
|
||||
const experimentLabel = searchParams.get("experimentLabel");
|
||||
const variantsCount = searchParams.get("variantsCount");
|
||||
const scenariosCount = searchParams.get("scenariosCount");
|
||||
|
||||
const inconsolataRegularFont = await inconsolataRegularFontP;
|
||||
|
||||
return new ImageResponse(
|
||||
(
|
||||
<div
|
||||
style={{
|
||||
width: "100%",
|
||||
height: "100%",
|
||||
display: "flex",
|
||||
flexDirection: "column",
|
||||
alignItems: "center",
|
||||
justifyContent: "center",
|
||||
fontSize: 48,
|
||||
padding: "48px",
|
||||
background: "white",
|
||||
position: "relative",
|
||||
}}
|
||||
>
|
||||
<div
|
||||
style={{
|
||||
position: "absolute",
|
||||
top: 0,
|
||||
left: 0,
|
||||
display: "flex",
|
||||
alignItems: "center",
|
||||
padding: 48,
|
||||
}}
|
||||
>
|
||||
{/* eslint-disable-next-line @next/next/no-img-element */}
|
||||
<img
|
||||
src="https://app.openpipe.ai/logo.svg"
|
||||
alt="OpenPipe Logo"
|
||||
height={100}
|
||||
width={120}
|
||||
/>
|
||||
<div style={{ marginLeft: 24, fontSize: 64, fontFamily: "Inconsolata" }}>OpenPipe</div>
|
||||
</div>
|
||||
|
||||
<div style={{ display: "flex", fontSize: 72, marginTop: 108 }}>{experimentLabel}</div>
|
||||
<div style={{ display: "flex", flexDirection: "column", marginTop: 36 }}>
|
||||
<div style={{ display: "flex" }}>
|
||||
<span style={{ width: 320 }}>Variants:</span> {variantsCount}
|
||||
</div>
|
||||
<div style={{ display: "flex", marginTop: 24 }}>
|
||||
<span style={{ width: 320 }}>Scenarios:</span> {scenariosCount}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
),
|
||||
{
|
||||
fonts: [
|
||||
{
|
||||
name: "inconsolata",
|
||||
data: inconsolataRegularFont,
|
||||
style: "normal",
|
||||
weight: 400,
|
||||
},
|
||||
],
|
||||
},
|
||||
);
|
||||
};
|
||||
|
||||
export default OgImage;
|
||||
6
app/src/pages/api/sentry-example-api.js
Normal file
6
app/src/pages/api/sentry-example-api.js
Normal file
@@ -0,0 +1,6 @@
|
||||
// A faulty API route to test Sentry's error monitoring
|
||||
// @ts-expect-error just a test file, don't care about types
|
||||
export default function handler(_req, res) {
|
||||
throw new Error("Sentry Example API Route Error");
|
||||
res.status(200).json({ name: "John Doe" });
|
||||
}
|
||||
16
app/src/pages/api/trpc/[trpc].ts
Normal file
16
app/src/pages/api/trpc/[trpc].ts
Normal file
@@ -0,0 +1,16 @@
|
||||
import { createNextApiHandler } from "@trpc/server/adapters/next";
|
||||
import { env } from "~/env.mjs";
|
||||
import { appRouter } from "~/server/api/root.router";
|
||||
import { createTRPCContext } from "~/server/api/trpc";
|
||||
|
||||
// export API handler
|
||||
export default createNextApiHandler({
|
||||
router: appRouter,
|
||||
createContext: createTRPCContext,
|
||||
onError:
|
||||
env.NODE_ENV === "development"
|
||||
? ({ path, error }) => {
|
||||
console.error(`❌ tRPC failed on ${path ?? "<no-path>"}: ${error.message}`);
|
||||
}
|
||||
: undefined,
|
||||
});
|
||||
99
app/src/pages/data/[id].tsx
Normal file
99
app/src/pages/data/[id].tsx
Normal file
@@ -0,0 +1,99 @@
|
||||
import {
|
||||
Box,
|
||||
Breadcrumb,
|
||||
BreadcrumbItem,
|
||||
Center,
|
||||
Flex,
|
||||
Icon,
|
||||
Input,
|
||||
VStack,
|
||||
} from "@chakra-ui/react";
|
||||
import Link from "next/link";
|
||||
|
||||
import { useRouter } from "next/router";
|
||||
import { useState, useEffect } from "react";
|
||||
import { RiDatabase2Line } from "react-icons/ri";
|
||||
import AppShell from "~/components/nav/AppShell";
|
||||
import { api } from "~/utils/api";
|
||||
import { useDataset, useHandledAsyncCallback } from "~/utils/hooks";
|
||||
import DatasetEntriesTable from "~/components/datasets/DatasetEntriesTable";
|
||||
import { DatasetHeaderButtons } from "~/components/datasets/DatasetHeaderButtons/DatasetHeaderButtons";
|
||||
|
||||
export default function Dataset() {
|
||||
const router = useRouter();
|
||||
const utils = api.useContext();
|
||||
|
||||
const dataset = useDataset();
|
||||
const datasetId = router.query.id as string;
|
||||
|
||||
const [name, setName] = useState(dataset.data?.name || "");
|
||||
useEffect(() => {
|
||||
setName(dataset.data?.name || "");
|
||||
}, [dataset.data?.name]);
|
||||
|
||||
const updateMutation = api.datasets.update.useMutation();
|
||||
const [onSaveName] = useHandledAsyncCallback(async () => {
|
||||
if (name && name !== dataset.data?.name && dataset.data?.id) {
|
||||
await updateMutation.mutateAsync({
|
||||
id: dataset.data.id,
|
||||
updates: { name: name },
|
||||
});
|
||||
await Promise.all([utils.datasets.list.invalidate(), utils.datasets.get.invalidate()]);
|
||||
}
|
||||
}, [updateMutation, dataset.data?.id, dataset.data?.name, name]);
|
||||
|
||||
if (!dataset.isLoading && !dataset.data) {
|
||||
return (
|
||||
<AppShell title="Dataset not found">
|
||||
<Center h="100%">
|
||||
<div>Dataset not found 😕</div>
|
||||
</Center>
|
||||
</AppShell>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<AppShell title={dataset.data?.name}>
|
||||
<VStack h="full">
|
||||
<Flex
|
||||
pl={4}
|
||||
pr={8}
|
||||
py={2}
|
||||
w="full"
|
||||
direction={{ base: "column", sm: "row" }}
|
||||
alignItems={{ base: "flex-start", sm: "center" }}
|
||||
>
|
||||
<Breadcrumb flex={1} mt={1}>
|
||||
<BreadcrumbItem>
|
||||
<Link href="/data">
|
||||
<Flex alignItems="center" _hover={{ textDecoration: "underline" }}>
|
||||
<Icon as={RiDatabase2Line} boxSize={4} mr={2} /> Datasets
|
||||
</Flex>
|
||||
</Link>
|
||||
</BreadcrumbItem>
|
||||
<BreadcrumbItem isCurrentPage>
|
||||
<Input
|
||||
size="sm"
|
||||
value={name}
|
||||
onChange={(e) => setName(e.target.value)}
|
||||
onBlur={onSaveName}
|
||||
borderWidth={1}
|
||||
borderColor="transparent"
|
||||
fontSize={16}
|
||||
px={0}
|
||||
minW={{ base: 100, lg: 300 }}
|
||||
flex={1}
|
||||
_hover={{ borderColor: "gray.300" }}
|
||||
_focus={{ borderColor: "blue.500", outline: "none" }}
|
||||
/>
|
||||
</BreadcrumbItem>
|
||||
</Breadcrumb>
|
||||
<DatasetHeaderButtons />
|
||||
</Flex>
|
||||
<Box w="full" overflowX="auto" flex={1} pl={4} pr={8} pt={8} pb={16}>
|
||||
{datasetId && <DatasetEntriesTable />}
|
||||
</Box>
|
||||
</VStack>
|
||||
</AppShell>
|
||||
);
|
||||
}
|
||||
83
app/src/pages/data/index.tsx
Normal file
83
app/src/pages/data/index.tsx
Normal file
@@ -0,0 +1,83 @@
|
||||
import {
|
||||
SimpleGrid,
|
||||
Icon,
|
||||
VStack,
|
||||
Breadcrumb,
|
||||
BreadcrumbItem,
|
||||
Flex,
|
||||
Center,
|
||||
Text,
|
||||
Link,
|
||||
HStack,
|
||||
} from "@chakra-ui/react";
|
||||
import AppShell from "~/components/nav/AppShell";
|
||||
import { api } from "~/utils/api";
|
||||
import { signIn, useSession } from "next-auth/react";
|
||||
import { RiDatabase2Line } from "react-icons/ri";
|
||||
import {
|
||||
DatasetCard,
|
||||
DatasetCardSkeleton,
|
||||
NewDatasetCard,
|
||||
} from "~/components/datasets/DatasetCard";
|
||||
|
||||
export default function DatasetsPage() {
|
||||
const datasets = api.datasets.list.useQuery();
|
||||
|
||||
const user = useSession().data;
|
||||
const authLoading = useSession().status === "loading";
|
||||
|
||||
if (user === null || authLoading) {
|
||||
return (
|
||||
<AppShell title="Data">
|
||||
<Center h="100%">
|
||||
{!authLoading && (
|
||||
<Text>
|
||||
<Link
|
||||
onClick={() => {
|
||||
signIn("github").catch(console.error);
|
||||
}}
|
||||
textDecor="underline"
|
||||
>
|
||||
Sign in
|
||||
</Link>{" "}
|
||||
to view or create new datasets!
|
||||
</Text>
|
||||
)}
|
||||
</Center>
|
||||
</AppShell>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<AppShell title="Data">
|
||||
<VStack alignItems={"flex-start"} px={4} py={2}>
|
||||
<HStack minH={8} align="center" pt={2}>
|
||||
<Breadcrumb flex={1}>
|
||||
<BreadcrumbItem>
|
||||
<Flex alignItems="center">
|
||||
<Icon as={RiDatabase2Line} boxSize={4} mr={2} /> Datasets
|
||||
</Flex>
|
||||
</BreadcrumbItem>
|
||||
</Breadcrumb>
|
||||
</HStack>
|
||||
<SimpleGrid w="full" columns={{ base: 1, md: 2, lg: 3, xl: 4 }} spacing={8} p="4">
|
||||
<NewDatasetCard />
|
||||
{datasets.data && !datasets.isLoading ? (
|
||||
datasets?.data?.map((dataset) => (
|
||||
<DatasetCard
|
||||
key={dataset.id}
|
||||
dataset={{ ...dataset, numEntries: dataset._count.datasetEntries }}
|
||||
/>
|
||||
))
|
||||
) : (
|
||||
<>
|
||||
<DatasetCardSkeleton />
|
||||
<DatasetCardSkeleton />
|
||||
<DatasetCardSkeleton />
|
||||
</>
|
||||
)}
|
||||
</SimpleGrid>
|
||||
</VStack>
|
||||
</AppShell>
|
||||
);
|
||||
}
|
||||
155
app/src/pages/experiments/[id].tsx
Normal file
155
app/src/pages/experiments/[id].tsx
Normal file
@@ -0,0 +1,155 @@
|
||||
import {
|
||||
Box,
|
||||
Breadcrumb,
|
||||
BreadcrumbItem,
|
||||
Center,
|
||||
Flex,
|
||||
Icon,
|
||||
Input,
|
||||
Text,
|
||||
VStack,
|
||||
} from "@chakra-ui/react";
|
||||
import Link from "next/link";
|
||||
|
||||
import { useRouter } from "next/router";
|
||||
import { useState, useEffect } from "react";
|
||||
import { RiFlaskLine } from "react-icons/ri";
|
||||
import OutputsTable from "~/components/OutputsTable";
|
||||
import ExperimentSettingsDrawer from "~/components/ExperimentSettingsDrawer/ExperimentSettingsDrawer";
|
||||
import AppShell from "~/components/nav/AppShell";
|
||||
import { api } from "~/utils/api";
|
||||
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
|
||||
import { useAppStore } from "~/state/store";
|
||||
import { useSyncVariantEditor } from "~/state/sync";
|
||||
import { ExperimentHeaderButtons } from "~/components/experiments/ExperimentHeaderButtons/ExperimentHeaderButtons";
|
||||
import Head from "next/head";
|
||||
|
||||
// TODO: import less to fix deployment with server side props
|
||||
// export const getServerSideProps = async (context: GetServerSidePropsContext<{ id: string }>) => {
|
||||
// const experimentId = context.params?.id as string;
|
||||
|
||||
// const helpers = createServerSideHelpers({
|
||||
// router: appRouter,
|
||||
// ctx: createInnerTRPCContext({ session: null }),
|
||||
// transformer: superjson, // optional - adds superjson serialization
|
||||
// });
|
||||
|
||||
// // prefetch query
|
||||
// await helpers.experiments.stats.prefetch({ id: experimentId });
|
||||
|
||||
// return {
|
||||
// props: {
|
||||
// trpcState: helpers.dehydrate(),
|
||||
// },
|
||||
// };
|
||||
// };
|
||||
|
||||
export default function Experiment() {
|
||||
const router = useRouter();
|
||||
const utils = api.useContext();
|
||||
useSyncVariantEditor();
|
||||
|
||||
const experiment = useExperiment();
|
||||
const experimentStats = api.experiments.stats.useQuery(
|
||||
{ id: router.query.id as string },
|
||||
{
|
||||
enabled: !!router.query.id,
|
||||
},
|
||||
);
|
||||
const stats = experimentStats.data;
|
||||
|
||||
useEffect(() => {
|
||||
useAppStore.getState().sharedVariantEditor.loadMonaco().catch(console.error);
|
||||
});
|
||||
|
||||
const [label, setLabel] = useState(experiment.data?.label || "");
|
||||
useEffect(() => {
|
||||
setLabel(experiment.data?.label || "");
|
||||
}, [experiment.data?.label]);
|
||||
|
||||
const updateMutation = api.experiments.update.useMutation();
|
||||
const [onSaveLabel] = useHandledAsyncCallback(async () => {
|
||||
if (label && label !== experiment.data?.label && experiment.data?.id) {
|
||||
await updateMutation.mutateAsync({
|
||||
id: experiment.data.id,
|
||||
updates: { label: label },
|
||||
});
|
||||
await Promise.all([utils.experiments.list.invalidate(), utils.experiments.get.invalidate()]);
|
||||
}
|
||||
}, [updateMutation, experiment.data?.id, experiment.data?.label, label]);
|
||||
|
||||
if (!experiment.isLoading && !experiment.data) {
|
||||
return (
|
||||
<AppShell title="Experiment not found">
|
||||
<Center h="100%">
|
||||
<div>Experiment not found 😕</div>
|
||||
</Center>
|
||||
</AppShell>
|
||||
);
|
||||
}
|
||||
|
||||
const canModify = experiment.data?.access.canModify ?? false;
|
||||
|
||||
return (
|
||||
<>
|
||||
{stats && (
|
||||
<Head>
|
||||
<meta property="og:title" content={stats.experimentLabel} key="title" />
|
||||
<meta
|
||||
property="og:image"
|
||||
content={`/api/experiments/og-image?experimentLabel=${stats.experimentLabel}&variantsCount=${stats.promptVariantCount}&scenariosCount=${stats.testScenarioCount}`}
|
||||
key="og-image"
|
||||
/>
|
||||
</Head>
|
||||
)}
|
||||
<AppShell title={experiment.data?.label}>
|
||||
<VStack h="full">
|
||||
<Flex
|
||||
px={4}
|
||||
py={2}
|
||||
w="full"
|
||||
direction={{ base: "column", sm: "row" }}
|
||||
alignItems={{ base: "flex-start", sm: "center" }}
|
||||
>
|
||||
<Breadcrumb flex={1}>
|
||||
<BreadcrumbItem>
|
||||
<Link href="/experiments">
|
||||
<Flex alignItems="center" _hover={{ textDecoration: "underline" }}>
|
||||
<Icon as={RiFlaskLine} boxSize={4} mr={2} /> Experiments
|
||||
</Flex>
|
||||
</Link>
|
||||
</BreadcrumbItem>
|
||||
<BreadcrumbItem isCurrentPage>
|
||||
{canModify ? (
|
||||
<Input
|
||||
size="sm"
|
||||
value={label}
|
||||
onChange={(e) => setLabel(e.target.value)}
|
||||
onBlur={onSaveLabel}
|
||||
borderWidth={1}
|
||||
borderColor="transparent"
|
||||
fontSize={16}
|
||||
px={0}
|
||||
minW={{ base: 100, lg: 300 }}
|
||||
flex={1}
|
||||
_hover={{ borderColor: "gray.300" }}
|
||||
_focus={{ borderColor: "blue.500", outline: "none" }}
|
||||
/>
|
||||
) : (
|
||||
<Text fontSize={16} px={0} minW={{ base: 100, lg: 300 }} flex={1}>
|
||||
{experiment.data?.label}
|
||||
</Text>
|
||||
)}
|
||||
</BreadcrumbItem>
|
||||
</Breadcrumb>
|
||||
<ExperimentHeaderButtons />
|
||||
</Flex>
|
||||
<ExperimentSettingsDrawer />
|
||||
<Box w="100%" overflowX="auto" flex={1}>
|
||||
<OutputsTable experimentId={router.query.id as string | undefined} />
|
||||
</Box>
|
||||
</VStack>
|
||||
</AppShell>
|
||||
</>
|
||||
);
|
||||
}
|
||||
78
app/src/pages/experiments/index.tsx
Normal file
78
app/src/pages/experiments/index.tsx
Normal file
@@ -0,0 +1,78 @@
|
||||
import {
|
||||
SimpleGrid,
|
||||
Icon,
|
||||
VStack,
|
||||
Breadcrumb,
|
||||
BreadcrumbItem,
|
||||
Flex,
|
||||
Center,
|
||||
Text,
|
||||
Link,
|
||||
HStack,
|
||||
} from "@chakra-ui/react";
|
||||
import { RiFlaskLine } from "react-icons/ri";
|
||||
import AppShell from "~/components/nav/AppShell";
|
||||
import { api } from "~/utils/api";
|
||||
import {
|
||||
ExperimentCard,
|
||||
ExperimentCardSkeleton,
|
||||
NewExperimentCard,
|
||||
} from "~/components/experiments/ExperimentCard";
|
||||
import { signIn, useSession } from "next-auth/react";
|
||||
|
||||
export default function ExperimentsPage() {
|
||||
const experiments = api.experiments.list.useQuery();
|
||||
|
||||
const user = useSession().data;
|
||||
const authLoading = useSession().status === "loading";
|
||||
|
||||
if (user === null || authLoading) {
|
||||
return (
|
||||
<AppShell title="Experiments">
|
||||
<Center h="100%">
|
||||
{!authLoading && (
|
||||
<Text>
|
||||
<Link
|
||||
onClick={() => {
|
||||
signIn("github").catch(console.error);
|
||||
}}
|
||||
textDecor="underline"
|
||||
>
|
||||
Sign in
|
||||
</Link>{" "}
|
||||
to view or create new experiments!
|
||||
</Text>
|
||||
)}
|
||||
</Center>
|
||||
</AppShell>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<AppShell title="Experiments">
|
||||
<VStack alignItems={"flex-start"} px={4} py={2}>
|
||||
<HStack minH={8} align="center" pt={2}>
|
||||
<Breadcrumb flex={1}>
|
||||
<BreadcrumbItem>
|
||||
<Flex alignItems="center">
|
||||
<Icon as={RiFlaskLine} boxSize={4} mr={2} /> Experiments
|
||||
</Flex>
|
||||
</BreadcrumbItem>
|
||||
</Breadcrumb>
|
||||
</HStack>
|
||||
<SimpleGrid w="full" columns={{ base: 1, md: 2, lg: 3, xl: 4 }} spacing={8} p="4">
|
||||
<NewExperimentCard />
|
||||
{experiments.data && !experiments.isLoading ? (
|
||||
experiments?.data?.map((exp) => <ExperimentCard key={exp.id} exp={exp} />)
|
||||
) : (
|
||||
<>
|
||||
<ExperimentCardSkeleton />
|
||||
<ExperimentCardSkeleton />
|
||||
<ExperimentCardSkeleton />
|
||||
</>
|
||||
)}
|
||||
</SimpleGrid>
|
||||
</VStack>
|
||||
</AppShell>
|
||||
);
|
||||
}
|
||||
15
app/src/pages/index.tsx
Normal file
15
app/src/pages/index.tsx
Normal file
@@ -0,0 +1,15 @@
|
||||
import { type GetServerSideProps } from "next";
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/require-await
|
||||
export const getServerSideProps: GetServerSideProps = async () => {
|
||||
return {
|
||||
redirect: {
|
||||
destination: "/experiments",
|
||||
permanent: false,
|
||||
},
|
||||
};
|
||||
};
|
||||
|
||||
export default function Home() {
|
||||
return null;
|
||||
}
|
||||
84
app/src/pages/sentry-example-page.js
Normal file
84
app/src/pages/sentry-example-page.js
Normal file
@@ -0,0 +1,84 @@
|
||||
import Head from "next/head";
|
||||
import * as Sentry from "@sentry/nextjs";
|
||||
|
||||
export default function Home() {
|
||||
return (
|
||||
<div>
|
||||
<Head>
|
||||
<title>Sentry Onboarding</title>
|
||||
<meta name="description" content="Test Sentry for your Next.js app!" />
|
||||
</Head>
|
||||
|
||||
<main
|
||||
style={{
|
||||
minHeight: "100vh",
|
||||
display: "flex",
|
||||
flexDirection: "column",
|
||||
justifyContent: "center",
|
||||
alignItems: "center",
|
||||
}}
|
||||
>
|
||||
<h1 style={{ fontSize: "4rem", margin: "14px 0" }}>
|
||||
<svg
|
||||
style={{
|
||||
height: "1em",
|
||||
}}
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
viewBox="0 0 200 44"
|
||||
>
|
||||
<path
|
||||
fill="currentColor"
|
||||
d="M124.32,28.28,109.56,9.22h-3.68V34.77h3.73V15.19l15.18,19.58h3.26V9.22h-3.73ZM87.15,23.54h13.23V20.22H87.14V12.53h14.93V9.21H83.34V34.77h18.92V31.45H87.14ZM71.59,20.3h0C66.44,19.06,65,18.08,65,15.7c0-2.14,1.89-3.59,4.71-3.59a12.06,12.06,0,0,1,7.07,2.55l2-2.83a14.1,14.1,0,0,0-9-3c-5.06,0-8.59,3-8.59,7.27,0,4.6,3,6.19,8.46,7.52C74.51,24.74,76,25.78,76,28.11s-2,3.77-5.09,3.77a12.34,12.34,0,0,1-8.3-3.26l-2.25,2.69a15.94,15.94,0,0,0,10.42,3.85c5.48,0,9-2.95,9-7.51C79.75,23.79,77.47,21.72,71.59,20.3ZM195.7,9.22l-7.69,12-7.64-12h-4.46L186,24.67V34.78h3.84V24.55L200,9.22Zm-64.63,3.46h8.37v22.1h3.84V12.68h8.37V9.22H131.08ZM169.41,24.8c3.86-1.07,6-3.77,6-7.63,0-4.91-3.59-8-9.38-8H154.67V34.76h3.8V25.58h6.45l6.48,9.2h4.44l-7-9.82Zm-10.95-2.5V12.6h7.17c3.74,0,5.88,1.77,5.88,4.84s-2.29,4.86-5.84,4.86Z M29,2.26a4.67,4.67,0,0,0-8,0L14.42,13.53A32.21,32.21,0,0,1,32.17,40.19H27.55A27.68,27.68,0,0,0,12.09,17.47L6,28a15.92,15.92,0,0,1,9.23,12.17H4.62A.76.76,0,0,1,4,39.06l2.94-5a10.74,10.74,0,0,0-3.36-1.9l-2.91,5a4.54,4.54,0,0,0,1.69,6.24A4.66,4.66,0,0,0,4.62,44H19.15a19.4,19.4,0,0,0-8-17.31l2.31-4A23.87,23.87,0,0,1,23.76,44H36.07a35.88,35.88,0,0,0-16.41-31.8l4.67-8a.77.77,0,0,1,1.05-.27c.53.29,20.29,34.77,20.66,35.17a.76.76,0,0,1-.68,1.13H40.6q.09,1.91,0,3.81h4.78A4.59,4.59,0,0,0,50,39.43a4.49,4.49,0,0,0-.62-2.28Z"
|
||||
></path>
|
||||
</svg>
|
||||
</h1>
|
||||
|
||||
<p>Get started by sending us a sample error:</p>
|
||||
<button
|
||||
type="button"
|
||||
style={{
|
||||
padding: "12px",
|
||||
cursor: "pointer",
|
||||
backgroundColor: "#AD6CAA",
|
||||
borderRadius: "4px",
|
||||
border: "none",
|
||||
color: "white",
|
||||
fontSize: "14px",
|
||||
margin: "18px",
|
||||
}}
|
||||
onClick={async () => {
|
||||
const transaction = Sentry.startTransaction({
|
||||
name: "Example Frontend Transaction",
|
||||
});
|
||||
|
||||
Sentry.configureScope((scope) => {
|
||||
scope.setSpan(transaction);
|
||||
});
|
||||
|
||||
try {
|
||||
const res = await fetch("/api/sentry-example-api");
|
||||
if (!res.ok) {
|
||||
throw new Error("Sentry Example Frontend Error");
|
||||
}
|
||||
} finally {
|
||||
transaction.finish();
|
||||
}
|
||||
}}
|
||||
>
|
||||
Throw error!
|
||||
</button>
|
||||
|
||||
<p>
|
||||
Next, look for the error on the{" "}
|
||||
<a href="https://openpipe.sentry.io/issues/?project=4505642011394048">Issues Page</a>.
|
||||
</p>
|
||||
<p style={{ marginTop: "24px" }}>
|
||||
For more information, see{" "}
|
||||
<a href="https://docs.sentry.io/platforms/javascript/guides/nextjs/">
|
||||
https://docs.sentry.io/platforms/javascript/guides/nextjs/
|
||||
</a>
|
||||
</p>
|
||||
</main>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
15
app/src/pages/world-champs/index.tsx
Normal file
15
app/src/pages/world-champs/index.tsx
Normal file
@@ -0,0 +1,15 @@
|
||||
import { type GetServerSideProps } from "next";
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/require-await
|
||||
export const getServerSideProps: GetServerSideProps = async () => {
|
||||
return {
|
||||
redirect: {
|
||||
destination: "/world-champs/signup",
|
||||
permanent: false,
|
||||
},
|
||||
};
|
||||
};
|
||||
|
||||
export default function WorldChamps() {
|
||||
return null;
|
||||
}
|
||||
265
app/src/pages/world-champs/signup.tsx
Normal file
265
app/src/pages/world-champs/signup.tsx
Normal file
@@ -0,0 +1,265 @@
|
||||
import {
|
||||
Box,
|
||||
type BoxProps,
|
||||
Button,
|
||||
DarkMode,
|
||||
GlobalStyle,
|
||||
HStack,
|
||||
Heading,
|
||||
Icon,
|
||||
Link,
|
||||
Table,
|
||||
Tbody,
|
||||
Td,
|
||||
Text,
|
||||
type TextProps,
|
||||
Th,
|
||||
Tr,
|
||||
VStack,
|
||||
useInterval,
|
||||
Image,
|
||||
Flex,
|
||||
} from "@chakra-ui/react";
|
||||
import { signIn, useSession } from "next-auth/react";
|
||||
import Head from "next/head";
|
||||
import { useCallback, useState } from "react";
|
||||
import { BsGithub } from "react-icons/bs";
|
||||
import UserMenu from "~/components/nav/UserMenu";
|
||||
import { api } from "~/utils/api";
|
||||
import dayjs from "~/utils/dayjs";
|
||||
import { useHandledAsyncCallback } from "~/utils/hooks";
|
||||
import GitHubButton from "react-github-btn";
|
||||
|
||||
const TopNavbar = () => (
|
||||
<HStack px={4} py={2} align="center" justify="center">
|
||||
<HStack
|
||||
as={Link}
|
||||
href="/"
|
||||
_hover={{ textDecoration: "none" }}
|
||||
spacing={0}
|
||||
py={2}
|
||||
pr={16}
|
||||
flex={1}
|
||||
sx={{
|
||||
".widget": {
|
||||
display: "block",
|
||||
},
|
||||
}}
|
||||
>
|
||||
<Image src="/logo.svg" alt="" boxSize={6} mr={4} />
|
||||
<Heading size="md" fontFamily="inconsolata, monospace">
|
||||
OpenPipe
|
||||
</Heading>
|
||||
</HStack>
|
||||
<Box pt="6px">
|
||||
<GitHubButton
|
||||
href="https://github.com/openpipe/openpipe"
|
||||
data-color-scheme="no-preference: dark; light: dark; dark: dark;"
|
||||
data-size="large"
|
||||
aria-label="Follow @openpipe on GitHub"
|
||||
>
|
||||
Github
|
||||
</GitHubButton>
|
||||
</Box>
|
||||
</HStack>
|
||||
);
|
||||
|
||||
// Shows how long until the competition starts. Refreshes every second
|
||||
function CountdownTimer(props: { date: Date } & TextProps) {
|
||||
const [now, setNow] = useState(dayjs());
|
||||
|
||||
useInterval(() => {
|
||||
setNow(dayjs());
|
||||
}, 1000);
|
||||
|
||||
const { date, ...rest } = props;
|
||||
|
||||
const kickoff = dayjs(date);
|
||||
const diff = kickoff.diff(now, "second");
|
||||
const days = Math.floor(diff / 86400);
|
||||
const hours = Math.floor((diff % 86400) / 3600);
|
||||
const minutes = Math.floor((diff % 3600) / 60);
|
||||
const seconds = Math.floor(diff % 60);
|
||||
|
||||
return (
|
||||
<Text {...rest} suppressHydrationWarning>
|
||||
<Text as="span" fontWeight="bold">
|
||||
Kickoff in
|
||||
</Text>{" "}
|
||||
{days}d {hours}h {minutes}m {seconds}s
|
||||
</Text>
|
||||
);
|
||||
}
|
||||
|
||||
function ApplicationStatus(props: BoxProps) {
|
||||
const user = useSession().data;
|
||||
const entrant = api.worldChamps.userStatus.useQuery().data;
|
||||
const applyMutation = api.worldChamps.apply.useMutation();
|
||||
|
||||
const utils = api.useContext();
|
||||
|
||||
const [onSignIn] = useHandledAsyncCallback(async () => {
|
||||
await signIn("github");
|
||||
}, []);
|
||||
|
||||
const [onApply] = useHandledAsyncCallback(async () => {
|
||||
await applyMutation.mutateAsync();
|
||||
await utils.worldChamps.userStatus.invalidate();
|
||||
}, []);
|
||||
|
||||
const Wrapper = useCallback(
|
||||
(wrapperProps: BoxProps) => (
|
||||
<Box {...props} {...wrapperProps} minH="120px" alignItems="center" justifyItems="center" />
|
||||
),
|
||||
[props],
|
||||
);
|
||||
|
||||
if (user === null) {
|
||||
return (
|
||||
<Wrapper>
|
||||
<Button onClick={onSignIn} colorScheme="orange" leftIcon={<Icon as={BsGithub} />}>
|
||||
Connect GitHub to apply
|
||||
</Button>
|
||||
</Wrapper>
|
||||
);
|
||||
} else if (user) {
|
||||
return (
|
||||
<Wrapper>
|
||||
<Flex flexDirection={{ base: "column", md: "row" }} alignItems="center">
|
||||
<UserMenu
|
||||
user={user}
|
||||
borderRadius={2}
|
||||
borderColor={"gray.700"}
|
||||
borderWidth={1}
|
||||
pr={6}
|
||||
mr={{ base: 0, md: 8 }}
|
||||
mb={{ base: 8, md: 0 }}
|
||||
/>
|
||||
<Box flex={1}>
|
||||
{entrant?.approved ? (
|
||||
<Text fontSize="sm">
|
||||
You're accepted! We'll send you more details before August 14th.
|
||||
</Text>
|
||||
) : entrant ? (
|
||||
<Text fontSize="sm">
|
||||
✅ Application submitted successfully. We'll notify you by email before August 14th.{" "}
|
||||
<Link
|
||||
href="https://github.com/openpipe/openpipe"
|
||||
isExternal
|
||||
textDecor="underline"
|
||||
fontWeight="bold"
|
||||
>
|
||||
Star our Github ⭐
|
||||
</Link>{" "}
|
||||
for updates while you wait!
|
||||
</Text>
|
||||
) : (
|
||||
<Button onClick={onApply} colorScheme="orange">
|
||||
Apply to compete
|
||||
</Button>
|
||||
)}
|
||||
</Box>
|
||||
</Flex>
|
||||
</Wrapper>
|
||||
);
|
||||
}
|
||||
|
||||
return <Wrapper />;
|
||||
}
|
||||
|
||||
export default function Signup() {
|
||||
return (
|
||||
<DarkMode>
|
||||
<GlobalStyle />
|
||||
|
||||
<Head>
|
||||
<title>🏆 Prompt Engineering World Championships</title>
|
||||
<meta property="og:title" content="🏆 Prompt Engineering World Championships" key="title" />
|
||||
<meta
|
||||
property="og:description"
|
||||
content="Think you have what it takes to be the best? Compete with the world's top prompt engineers and see where you rank!"
|
||||
key="description"
|
||||
/>
|
||||
</Head>
|
||||
|
||||
<Box color="gray.200" minH="100vh" w="full">
|
||||
<TopNavbar />
|
||||
<VStack mx="auto" py={24} maxW="2xl" px={4} align="center" fontSize="lg">
|
||||
<Heading size="lg" textAlign="center">
|
||||
🏆 Prompt Engineering World Championships
|
||||
</Heading>
|
||||
<CountdownTimer
|
||||
date={new Date("2023-08-14T00:00:00Z")}
|
||||
fontSize="2xl"
|
||||
alignSelf="center"
|
||||
color="gray.500"
|
||||
/>
|
||||
|
||||
<ApplicationStatus py={8} alignSelf="center" />
|
||||
|
||||
<Text fontSize="lg" textAlign="left">
|
||||
Think you have what it takes to be the best? Compete with the world's top prompt
|
||||
engineers and see where you rank!
|
||||
</Text>
|
||||
|
||||
<Heading size="lg" pt={12} alignSelf="left">
|
||||
Event Details
|
||||
</Heading>
|
||||
<Table variant="simple">
|
||||
<Tbody
|
||||
sx={{
|
||||
th: {
|
||||
base: { px: 0 },
|
||||
md: { px: 6 },
|
||||
},
|
||||
}}
|
||||
>
|
||||
<Tr>
|
||||
<Th>Kickoff</Th>
|
||||
<Td>August 14</Td>
|
||||
</Tr>
|
||||
<Tr>
|
||||
<Th>Prize</Th>
|
||||
<Td>$15,000 grand prize + smaller category prizes.</Td>
|
||||
</Tr>
|
||||
<Tr>
|
||||
<Th>Events</Th>
|
||||
<Td>
|
||||
Optimize prompts for multiple tasks selected from academic benchmarks and
|
||||
real-world applications.
|
||||
</Td>
|
||||
</Tr>
|
||||
<Tr>
|
||||
<Th>Models</Th>
|
||||
<Td>Separate "weight classes" for GPT 3.5, Claude Instant, and Llama 2.</Td>
|
||||
</Tr>
|
||||
<Tr>
|
||||
<Th>Qualifications</Th>
|
||||
<Td>Open to entrants with any level of experience.</Td>
|
||||
</Tr>
|
||||
<Tr>
|
||||
<Th>Certificates</Th>
|
||||
<Td>Certificate of mastery for all qualifying participants.</Td>
|
||||
</Tr>
|
||||
<Tr>
|
||||
<Th>Cost</Th>
|
||||
<Td>
|
||||
<strong>Free</strong>. We'll cover your inference budget.
|
||||
</Td>
|
||||
</Tr>
|
||||
<Tr>
|
||||
<Th>Questions?</Th>
|
||||
<Td>
|
||||
<Link href="mailto:world-champs@openpipe.ai" textDecor="underline">
|
||||
Email us
|
||||
</Link>{" "}
|
||||
with any follow-up questions!
|
||||
</Td>
|
||||
</Tr>
|
||||
</Tbody>
|
||||
</Table>
|
||||
</VStack>
|
||||
</Box>
|
||||
</DarkMode>
|
||||
);
|
||||
}
|
||||
10
app/src/promptConstructor/format.test.ts
Normal file
10
app/src/promptConstructor/format.test.ts
Normal file
@@ -0,0 +1,10 @@
|
||||
import { expect, test } from "vitest";
|
||||
import { stripTypes } from "./format";
|
||||
|
||||
test("stripTypes", () => {
|
||||
expect(stripTypes(`const foo: string = "bar";`)).toBe(`const foo = "bar";`);
|
||||
});
|
||||
|
||||
test("stripTypes with invalid syntax", () => {
|
||||
expect(stripTypes(`asdf foo: string = "bar"`)).toBe(`asdf foo: string = "bar"`);
|
||||
});
|
||||
31
app/src/promptConstructor/format.ts
Normal file
31
app/src/promptConstructor/format.ts
Normal file
@@ -0,0 +1,31 @@
|
||||
import prettier from "prettier/standalone";
|
||||
import parserTypescript from "prettier/plugins/typescript";
|
||||
|
||||
// @ts-expect-error for some reason missing from types
|
||||
import parserEstree from "prettier/plugins/estree";
|
||||
|
||||
import * as babel from "@babel/standalone";
|
||||
|
||||
export function stripTypes(tsCode: string): string {
|
||||
const options = {
|
||||
presets: ["typescript"],
|
||||
filename: "file.ts",
|
||||
};
|
||||
|
||||
try {
|
||||
const result = babel.transform(tsCode, options);
|
||||
return result.code ?? tsCode;
|
||||
} catch (error) {
|
||||
// console.error("Error stripping types", error);
|
||||
return tsCode;
|
||||
}
|
||||
}
|
||||
|
||||
export default async function formatPromptConstructor(code: string): Promise<string> {
|
||||
return await prettier.format(stripTypes(code), {
|
||||
parser: "typescript",
|
||||
plugins: [parserTypescript, parserEstree],
|
||||
// We're showing these in pretty narrow panes so let's keep the print width low
|
||||
printWidth: 60,
|
||||
});
|
||||
}
|
||||
56
app/src/promptConstructor/migrate.test.ts
Normal file
56
app/src/promptConstructor/migrate.test.ts
Normal file
@@ -0,0 +1,56 @@
|
||||
import "dotenv/config";
|
||||
import dedent from "dedent";
|
||||
import { expect, test } from "vitest";
|
||||
import { migrate1to2, migrate2to3 } from "./migrate";
|
||||
|
||||
test("migrate1to2", () => {
|
||||
const promptConstructor = dedent`
|
||||
// Test comment
|
||||
|
||||
prompt = {
|
||||
model: "gpt-3.5-turbo-0613",
|
||||
messages: [
|
||||
{
|
||||
role: "user",
|
||||
content: "What is the capital of China?"
|
||||
}
|
||||
]
|
||||
}
|
||||
`;
|
||||
|
||||
const migrated = migrate1to2(promptConstructor);
|
||||
expect(migrated).toBe(dedent`
|
||||
// Test comment
|
||||
|
||||
definePrompt("openai/ChatCompletion", {
|
||||
model: "gpt-3.5-turbo-0613",
|
||||
messages: [
|
||||
{
|
||||
role: "user",
|
||||
content: "What is the capital of China?"
|
||||
}
|
||||
]
|
||||
})
|
||||
`);
|
||||
});
|
||||
|
||||
test("migrate2to3", () => {
|
||||
const promptConstructor = dedent`
|
||||
// Test comment
|
||||
|
||||
definePrompt("anthropic", {
|
||||
model: "claude-2.0",
|
||||
prompt: "What is the capital of China?"
|
||||
})
|
||||
`;
|
||||
|
||||
const migrated = migrate2to3(promptConstructor);
|
||||
expect(migrated).toBe(dedent`
|
||||
// Test comment
|
||||
|
||||
definePrompt("anthropic/completion", {
|
||||
model: "claude-2.0",
|
||||
prompt: "What is the capital of China?"
|
||||
})
|
||||
`);
|
||||
});
|
||||
125
app/src/promptConstructor/migrate.ts
Normal file
125
app/src/promptConstructor/migrate.ts
Normal file
@@ -0,0 +1,125 @@
|
||||
import "dotenv/config";
|
||||
import * as recast from "recast";
|
||||
import { type ASTNode } from "ast-types";
|
||||
import { fileURLToPath } from "url";
|
||||
import parsePromptConstructor from "./parse";
|
||||
import { prisma } from "~/server/db";
|
||||
import { promptConstructorVersion } from "./version";
|
||||
const { builders: b } = recast.types;
|
||||
|
||||
export const migrate1to2 = (fnBody: string): string => {
|
||||
const ast: ASTNode = recast.parse(fnBody);
|
||||
|
||||
recast.visit(ast, {
|
||||
visitAssignmentExpression(path) {
|
||||
const node = path.node;
|
||||
if ("name" in node.left && node.left.name === "prompt") {
|
||||
const functionCall = b.callExpression(b.identifier("definePrompt"), [
|
||||
b.literal("openai/ChatCompletion"),
|
||||
node.right,
|
||||
]);
|
||||
path.replace(functionCall);
|
||||
}
|
||||
return false;
|
||||
},
|
||||
});
|
||||
|
||||
return recast.print(ast).code;
|
||||
};
|
||||
|
||||
export const migrate2to3 = (fnBody: string): string => {
|
||||
const ast: ASTNode = recast.parse(fnBody);
|
||||
|
||||
recast.visit(ast, {
|
||||
visitCallExpression(path) {
|
||||
const node = path.node;
|
||||
|
||||
// Check if the function being called is 'definePrompt'
|
||||
if (
|
||||
recast.types.namedTypes.Identifier.check(node.callee) &&
|
||||
node.callee.name === "definePrompt" &&
|
||||
node.arguments.length > 0 &&
|
||||
recast.types.namedTypes.Literal.check(node.arguments[0]) &&
|
||||
node.arguments[0].value === "anthropic"
|
||||
) {
|
||||
node.arguments[0].value = "anthropic/completion";
|
||||
}
|
||||
|
||||
return false;
|
||||
},
|
||||
});
|
||||
|
||||
return recast.print(ast).code;
|
||||
};
|
||||
|
||||
const migrations: Record<number, (fnBody: string) => string> = {
|
||||
2: migrate1to2,
|
||||
3: migrate2to3,
|
||||
};
|
||||
|
||||
const applyMigrations = (
|
||||
promptConstructor: string,
|
||||
currentVersion: number,
|
||||
targetVersion: number,
|
||||
) => {
|
||||
let migratedFn = promptConstructor;
|
||||
|
||||
for (let v = currentVersion + 1; v <= targetVersion; v++) {
|
||||
const migrationFn = migrations[v];
|
||||
if (migrationFn) {
|
||||
migratedFn = migrationFn(migratedFn);
|
||||
}
|
||||
}
|
||||
|
||||
return migratedFn;
|
||||
};
|
||||
|
||||
export default async function migrateConstructFns(targetVersion: number) {
|
||||
const prompts = await prisma.promptVariant.findMany({
|
||||
where: { promptConstructorVersion: { lt: targetVersion } },
|
||||
});
|
||||
console.log(`Migrating ${prompts.length} prompts to version ${targetVersion}`);
|
||||
await Promise.all(
|
||||
prompts.map(async (variant) => {
|
||||
const currentVersion = variant.promptConstructorVersion;
|
||||
|
||||
try {
|
||||
const migratedFn = applyMigrations(
|
||||
variant.promptConstructor,
|
||||
currentVersion,
|
||||
targetVersion,
|
||||
);
|
||||
|
||||
const parsedFn = await parsePromptConstructor(migratedFn);
|
||||
if ("error" in parsedFn) {
|
||||
throw new Error(parsedFn.error);
|
||||
}
|
||||
await prisma.promptVariant.update({
|
||||
where: {
|
||||
id: variant.id,
|
||||
},
|
||||
data: {
|
||||
promptConstructor: migratedFn,
|
||||
promptConstructorVersion: targetVersion,
|
||||
modelProvider: parsedFn.modelProvider,
|
||||
model: parsedFn.model,
|
||||
},
|
||||
});
|
||||
} catch (e) {
|
||||
console.error("Error migrating promptConstructor for variant", variant.id, e);
|
||||
}
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
// If we're running this file directly, run the migration to the latest version
|
||||
if (process.argv.at(-1) === fileURLToPath(import.meta.url)) {
|
||||
const latestVersion = Math.max(...Object.keys(migrations).map(Number));
|
||||
if (latestVersion !== promptConstructorVersion) {
|
||||
throw new Error(
|
||||
`The latest migration is ${latestVersion}, but the promptConstructorVersion is ${promptConstructorVersion}`,
|
||||
);
|
||||
}
|
||||
await migrateConstructFns(promptConstructorVersion);
|
||||
console.log("Done");
|
||||
}
|
||||
45
app/src/promptConstructor/parse.test.ts
Normal file
45
app/src/promptConstructor/parse.test.ts
Normal file
@@ -0,0 +1,45 @@
|
||||
import { expect, test } from "vitest";
|
||||
import parsePromptConstructor from "./parse";
|
||||
import assert from "assert";
|
||||
|
||||
// Note: this has to be run with `vitest --no-threads` option or else
|
||||
// isolated-vm seems to throw errors
|
||||
test("parsePromptConstructor", async () => {
|
||||
const constructed = await parsePromptConstructor(
|
||||
`
|
||||
// These sometimes have a comment
|
||||
|
||||
definePrompt("openai/ChatCompletion", {
|
||||
model: "gpt-3.5-turbo-0613",
|
||||
messages: [
|
||||
{
|
||||
role: "user",
|
||||
content: \`What is the capital of \${scenario.country}?\`
|
||||
}
|
||||
]
|
||||
})
|
||||
`,
|
||||
{ country: "Bolivia" },
|
||||
);
|
||||
|
||||
expect(constructed).toEqual({
|
||||
modelProvider: "openai/ChatCompletion",
|
||||
model: "gpt-3.5-turbo-0613",
|
||||
modelInput: {
|
||||
messages: [
|
||||
{
|
||||
content: "What is the capital of Bolivia?",
|
||||
role: "user",
|
||||
},
|
||||
],
|
||||
model: "gpt-3.5-turbo-0613",
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
test("bad syntax", async () => {
|
||||
const parsed = await parsePromptConstructor(`definePrompt("openai/ChatCompletion", {`);
|
||||
|
||||
assert("error" in parsed);
|
||||
expect(parsed.error).toContain("Unexpected end of input");
|
||||
});
|
||||
92
app/src/promptConstructor/parse.ts
Normal file
92
app/src/promptConstructor/parse.ts
Normal file
@@ -0,0 +1,92 @@
|
||||
import modelProviders from "~/modelProviders/modelProviders";
|
||||
import ivm from "isolated-vm";
|
||||
import { isObject, isString } from "lodash-es";
|
||||
import { type JsonObject } from "type-fest";
|
||||
import { validate } from "jsonschema";
|
||||
|
||||
export type ParsedPromptConstructor<T extends keyof typeof modelProviders> = {
|
||||
modelProvider: T;
|
||||
model: keyof (typeof modelProviders)[T]["models"];
|
||||
modelInput: Parameters<(typeof modelProviders)[T]["getModel"]>[0];
|
||||
};
|
||||
|
||||
const isolate = new ivm.Isolate({ memoryLimit: 128 });
|
||||
|
||||
export default async function parsePromptConstructor(
|
||||
promptConstructor: string,
|
||||
scenario: JsonObject | undefined = {},
|
||||
): Promise<ParsedPromptConstructor<keyof typeof modelProviders> | { error: string }> {
|
||||
try {
|
||||
const modifiedConstructFn = promptConstructor.replace(
|
||||
"definePrompt(",
|
||||
"global.prompt = definePrompt(",
|
||||
);
|
||||
|
||||
const code = `
|
||||
const scenario = ${JSON.stringify(scenario ?? {}, null, 2)};
|
||||
|
||||
const definePrompt = (modelProvider, input) => ({
|
||||
modelProvider,
|
||||
input
|
||||
})
|
||||
|
||||
${modifiedConstructFn}
|
||||
`;
|
||||
|
||||
const context = await isolate.createContext();
|
||||
const jail = context.global;
|
||||
await jail.set("global", jail.derefInto());
|
||||
|
||||
const script = await isolate.compileScript(code);
|
||||
await script.run(context);
|
||||
const promptReference = (await context.global.get("prompt")) as ivm.Reference;
|
||||
const prompt = await promptReference.copy();
|
||||
|
||||
if (!isObject(prompt)) {
|
||||
return { error: "definePrompt did not return an object" };
|
||||
}
|
||||
if (!("modelProvider" in prompt) || !isString(prompt.modelProvider)) {
|
||||
return { error: "definePrompt did not return a valid modelProvider" };
|
||||
}
|
||||
|
||||
const provider =
|
||||
prompt.modelProvider in modelProviders &&
|
||||
modelProviders[prompt.modelProvider as keyof typeof modelProviders];
|
||||
if (!provider) {
|
||||
return { error: "definePrompt did not return a known modelProvider" };
|
||||
}
|
||||
if (!("input" in prompt) || !isObject(prompt.input)) {
|
||||
return { error: "definePrompt did not return an input" };
|
||||
}
|
||||
|
||||
const validationResult = validate(prompt.input, provider.inputSchema);
|
||||
if (!validationResult.valid)
|
||||
return {
|
||||
error: `definePrompt did not return a valid input: ${validationResult.errors
|
||||
.map((e) => e.stack)
|
||||
.join(", ")}`,
|
||||
};
|
||||
|
||||
// We've validated the JSON schema so this should be safe
|
||||
const input = prompt.input as Parameters<(typeof provider)["getModel"]>[0];
|
||||
|
||||
const model = provider.getModel(input);
|
||||
if (!model) {
|
||||
return {
|
||||
error: `definePrompt did not return a known model for the provider ${prompt.modelProvider}`,
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
modelProvider: prompt.modelProvider as keyof typeof modelProviders,
|
||||
model,
|
||||
modelInput: input,
|
||||
};
|
||||
} catch (e) {
|
||||
const msg =
|
||||
isObject(e) && "message" in e && isString(e.message)
|
||||
? e.message
|
||||
: "unknown error parsing definePrompt script";
|
||||
return { error: msg };
|
||||
}
|
||||
}
|
||||
1
app/src/promptConstructor/version.ts
Normal file
1
app/src/promptConstructor/version.ts
Normal file
@@ -0,0 +1 @@
|
||||
export const promptConstructorVersion = 3;
|
||||
108
app/src/server/api/autogenerate/autogenerateDatasetEntries.ts
Normal file
108
app/src/server/api/autogenerate/autogenerateDatasetEntries.ts
Normal file
@@ -0,0 +1,108 @@
|
||||
import { type ChatCompletion } from "openai/resources/chat";
|
||||
import { openai } from "../../utils/openai";
|
||||
import { isAxiosError } from "./utils";
|
||||
import { type APIResponse } from "openai/core";
|
||||
import { sleep } from "~/server/utils/sleep";
|
||||
|
||||
const MAX_AUTO_RETRIES = 50;
|
||||
const MIN_DELAY = 500; // milliseconds
|
||||
const MAX_DELAY = 15000; // milliseconds
|
||||
|
||||
function calculateDelay(numPreviousTries: number): number {
|
||||
const baseDelay = Math.min(MAX_DELAY, MIN_DELAY * Math.pow(2, numPreviousTries));
|
||||
const jitter = Math.random() * baseDelay;
|
||||
return baseDelay + jitter;
|
||||
}
|
||||
|
||||
const getCompletionWithBackoff = async (
|
||||
getCompletion: () => Promise<APIResponse<ChatCompletion>>,
|
||||
) => {
|
||||
let completion;
|
||||
let tries = 0;
|
||||
while (tries < MAX_AUTO_RETRIES) {
|
||||
try {
|
||||
completion = await getCompletion();
|
||||
break;
|
||||
} catch (e) {
|
||||
if (isAxiosError(e)) {
|
||||
console.error(e?.response?.data?.error?.message);
|
||||
} else {
|
||||
await sleep(calculateDelay(tries));
|
||||
console.error(e);
|
||||
}
|
||||
}
|
||||
tries++;
|
||||
}
|
||||
return completion;
|
||||
};
|
||||
// TODO: Add seeds to ensure batches don't contain duplicate data
|
||||
const MAX_BATCH_SIZE = 5;
|
||||
|
||||
export const autogenerateDatasetEntries = async (
|
||||
numToGenerate: number,
|
||||
inputDescription: string,
|
||||
outputDescription: string,
|
||||
): Promise<{ input: string; output: string }[]> => {
|
||||
const batchSizes = Array.from({ length: Math.ceil(numToGenerate / MAX_BATCH_SIZE) }, (_, i) =>
|
||||
i === Math.ceil(numToGenerate / MAX_BATCH_SIZE) - 1 && numToGenerate % MAX_BATCH_SIZE
|
||||
? numToGenerate % MAX_BATCH_SIZE
|
||||
: MAX_BATCH_SIZE,
|
||||
);
|
||||
|
||||
const getCompletion = (batchSize: number) =>
|
||||
openai.chat.completions.create({
|
||||
model: "gpt-4",
|
||||
messages: [
|
||||
{
|
||||
role: "system",
|
||||
content: `The user needs ${batchSize} rows of data, each with an input and an output.\n---\n The input should follow these requirements: ${inputDescription}\n---\n The output should follow these requirements: ${outputDescription}`,
|
||||
},
|
||||
],
|
||||
functions: [
|
||||
{
|
||||
name: "add_list_of_data",
|
||||
description: "Add a list of data to the database",
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: {
|
||||
rows: {
|
||||
type: "array",
|
||||
description: "The rows of data that match the description",
|
||||
items: {
|
||||
type: "object",
|
||||
properties: {
|
||||
input: {
|
||||
type: "string",
|
||||
description: "The input for this row",
|
||||
},
|
||||
output: {
|
||||
type: "string",
|
||||
description: "The output for this row",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
|
||||
function_call: { name: "add_list_of_data" },
|
||||
temperature: 0.5,
|
||||
});
|
||||
|
||||
const completionCallbacks = batchSizes.map((batchSize) =>
|
||||
getCompletionWithBackoff(() => getCompletion(batchSize)),
|
||||
);
|
||||
|
||||
const completions = await Promise.all(completionCallbacks);
|
||||
|
||||
const rows = completions.flatMap((completion) => {
|
||||
const parsed = JSON.parse(
|
||||
completion?.choices[0]?.message?.function_call?.arguments ?? "{rows: []}",
|
||||
) as { rows: { input: string; output: string }[] };
|
||||
return parsed.rows;
|
||||
});
|
||||
|
||||
return rows;
|
||||
};
|
||||
118
app/src/server/api/autogenerate/autogenerateScenarioValues.ts
Normal file
118
app/src/server/api/autogenerate/autogenerateScenarioValues.ts
Normal file
@@ -0,0 +1,118 @@
|
||||
import { type CompletionCreateParams } from "openai/resources/chat";
|
||||
import { prisma } from "../../db";
|
||||
import { openai } from "../../utils/openai";
|
||||
import { pick } from "lodash-es";
|
||||
import { isAxiosError } from "./utils";
|
||||
|
||||
export const autogenerateScenarioValues = async (
|
||||
experimentId: string,
|
||||
): Promise<Record<string, string>> => {
|
||||
const [experiment, variables, existingScenarios, prompt] = await Promise.all([
|
||||
prisma.experiment.findUnique({
|
||||
where: {
|
||||
id: experimentId,
|
||||
},
|
||||
}),
|
||||
prisma.templateVariable.findMany({
|
||||
where: {
|
||||
experimentId,
|
||||
},
|
||||
}),
|
||||
prisma.testScenario.findMany({
|
||||
where: {
|
||||
experimentId,
|
||||
visible: true,
|
||||
},
|
||||
orderBy: {
|
||||
sortIndex: "asc",
|
||||
},
|
||||
take: 10,
|
||||
}),
|
||||
prisma.promptVariant.findFirst({
|
||||
where: {
|
||||
experimentId,
|
||||
visible: true,
|
||||
},
|
||||
orderBy: {
|
||||
sortIndex: "asc",
|
||||
},
|
||||
}),
|
||||
]);
|
||||
|
||||
if (!experiment || !(variables?.length > 0) || !prompt) return {};
|
||||
|
||||
const messages: CompletionCreateParams.CreateChatCompletionRequestNonStreaming["messages"] = [
|
||||
{
|
||||
role: "system",
|
||||
content:
|
||||
"The user is testing multiple scenarios against the same prompt. Attempt to generate a new scenario that is different from the others.",
|
||||
},
|
||||
];
|
||||
|
||||
messages.push({
|
||||
role: "user",
|
||||
content: `Prompt constructor function:\n---\n${prompt.promptConstructor}`,
|
||||
});
|
||||
|
||||
existingScenarios
|
||||
.map(
|
||||
(scenario) =>
|
||||
pick(
|
||||
scenario.variableValues,
|
||||
variables.map((variable) => variable.label),
|
||||
) as Record<string, string>,
|
||||
)
|
||||
.filter((vals) => Object.keys(vals ?? {}).length > 0)
|
||||
.forEach((vals) => {
|
||||
messages.push({
|
||||
role: "assistant",
|
||||
content: null,
|
||||
function_call: {
|
||||
name: "add_scenario",
|
||||
arguments: JSON.stringify(vals),
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
const variableProperties = variables.reduce(
|
||||
(acc, variable) => {
|
||||
acc[variable.label] = { type: "string" };
|
||||
return acc;
|
||||
},
|
||||
{} as Record<string, { type: "string" }>,
|
||||
);
|
||||
|
||||
try {
|
||||
const completion = await openai.chat.completions.create({
|
||||
model: "gpt-3.5-turbo-0613",
|
||||
messages,
|
||||
functions: [
|
||||
{
|
||||
name: "add_scenario",
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: variableProperties,
|
||||
},
|
||||
},
|
||||
],
|
||||
|
||||
function_call: { name: "add_scenario" },
|
||||
temperature: 0.5,
|
||||
});
|
||||
|
||||
const parsed = JSON.parse(
|
||||
completion.choices[0]?.message?.function_call?.arguments ?? "{}",
|
||||
) as Record<string, string>;
|
||||
return parsed;
|
||||
} catch (e) {
|
||||
// If it's an axios error, try to get the error message
|
||||
if (isAxiosError(e)) {
|
||||
console.error(e?.response?.data?.error?.message);
|
||||
} else {
|
||||
console.error(e);
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
return {};
|
||||
};
|
||||
18
app/src/server/api/autogenerate/utils.ts
Normal file
18
app/src/server/api/autogenerate/utils.ts
Normal file
@@ -0,0 +1,18 @@
|
||||
type AxiosError = {
|
||||
response?: {
|
||||
data?: {
|
||||
error?: {
|
||||
message?: string;
|
||||
};
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
export function isAxiosError(error: unknown): error is AxiosError {
|
||||
if (typeof error === "object" && error !== null) {
|
||||
// Initial check
|
||||
const err = error as AxiosError;
|
||||
return err.response?.data?.error?.message !== undefined; // Check structure
|
||||
}
|
||||
return false;
|
||||
}
|
||||
30
app/src/server/api/root.router.ts
Normal file
30
app/src/server/api/root.router.ts
Normal file
@@ -0,0 +1,30 @@
|
||||
import { promptVariantsRouter } from "~/server/api/routers/promptVariants.router";
|
||||
import { createTRPCRouter } from "~/server/api/trpc";
|
||||
import { experimentsRouter } from "./routers/experiments.router";
|
||||
import { scenariosRouter } from "./routers/scenarios.router";
|
||||
import { scenarioVariantCellsRouter } from "./routers/scenarioVariantCells.router";
|
||||
import { templateVarsRouter } from "./routers/templateVariables.router";
|
||||
import { evaluationsRouter } from "./routers/evaluations.router";
|
||||
import { worldChampsRouter } from "./routers/worldChamps.router";
|
||||
import { datasetsRouter } from "./routers/datasets.router";
|
||||
import { datasetEntries } from "./routers/datasetEntries.router";
|
||||
|
||||
/**
|
||||
* This is the primary router for your server.
|
||||
*
|
||||
* All routers added in /api/routers should be manually added here.
|
||||
*/
|
||||
export const appRouter = createTRPCRouter({
|
||||
promptVariants: promptVariantsRouter,
|
||||
experiments: experimentsRouter,
|
||||
scenarios: scenariosRouter,
|
||||
scenarioVariantCells: scenarioVariantCellsRouter,
|
||||
templateVars: templateVarsRouter,
|
||||
evaluations: evaluationsRouter,
|
||||
worldChamps: worldChampsRouter,
|
||||
datasets: datasetsRouter,
|
||||
datasetEntries: datasetEntries,
|
||||
});
|
||||
|
||||
// export type definition of API
|
||||
export type AppRouter = typeof appRouter;
|
||||
149
app/src/server/api/routers/datasetEntries.router.ts
Normal file
149
app/src/server/api/routers/datasetEntries.router.ts
Normal file
@@ -0,0 +1,149 @@
|
||||
import { z } from "zod";
|
||||
import { createTRPCRouter, protectedProcedure } from "~/server/api/trpc";
|
||||
import { prisma } from "~/server/db";
|
||||
import { requireCanModifyDataset, requireCanViewDataset } from "~/utils/accessControl";
|
||||
import { autogenerateDatasetEntries } from "../autogenerate/autogenerateDatasetEntries";
|
||||
|
||||
const PAGE_SIZE = 10;
|
||||
|
||||
export const datasetEntries = createTRPCRouter({
|
||||
list: protectedProcedure
|
||||
.input(z.object({ datasetId: z.string(), page: z.number() }))
|
||||
.query(async ({ input, ctx }) => {
|
||||
await requireCanViewDataset(input.datasetId, ctx);
|
||||
|
||||
const { datasetId, page } = input;
|
||||
|
||||
const entries = await prisma.datasetEntry.findMany({
|
||||
where: {
|
||||
datasetId,
|
||||
},
|
||||
orderBy: { createdAt: "desc" },
|
||||
skip: (page - 1) * PAGE_SIZE,
|
||||
take: PAGE_SIZE,
|
||||
});
|
||||
|
||||
const count = await prisma.datasetEntry.count({
|
||||
where: {
|
||||
datasetId,
|
||||
},
|
||||
});
|
||||
|
||||
return {
|
||||
entries,
|
||||
startIndex: (page - 1) * PAGE_SIZE + 1,
|
||||
lastPage: Math.ceil(count / PAGE_SIZE),
|
||||
count,
|
||||
};
|
||||
}),
|
||||
createOne: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
datasetId: z.string(),
|
||||
input: z.string(),
|
||||
output: z.string().optional(),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
await requireCanModifyDataset(input.datasetId, ctx);
|
||||
|
||||
return await prisma.datasetEntry.create({
|
||||
data: {
|
||||
datasetId: input.datasetId,
|
||||
input: input.input,
|
||||
output: input.output,
|
||||
},
|
||||
});
|
||||
}),
|
||||
|
||||
autogenerateEntries: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
datasetId: z.string(),
|
||||
numToGenerate: z.number(),
|
||||
inputDescription: z.string(),
|
||||
outputDescription: z.string(),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
await requireCanModifyDataset(input.datasetId, ctx);
|
||||
|
||||
const dataset = await prisma.dataset.findUnique({
|
||||
where: {
|
||||
id: input.datasetId,
|
||||
},
|
||||
});
|
||||
|
||||
if (!dataset) {
|
||||
throw new Error(`Dataset with id ${input.datasetId} does not exist`);
|
||||
}
|
||||
|
||||
const entries = await autogenerateDatasetEntries(
|
||||
input.numToGenerate,
|
||||
input.inputDescription,
|
||||
input.outputDescription,
|
||||
);
|
||||
|
||||
const createdEntries = await prisma.datasetEntry.createMany({
|
||||
data: entries.map((entry) => ({
|
||||
datasetId: input.datasetId,
|
||||
input: entry.input,
|
||||
output: entry.output,
|
||||
})),
|
||||
});
|
||||
|
||||
return createdEntries;
|
||||
}),
|
||||
|
||||
delete: protectedProcedure
|
||||
.input(z.object({ id: z.string() }))
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const datasetId = (
|
||||
await prisma.datasetEntry.findUniqueOrThrow({
|
||||
where: { id: input.id },
|
||||
})
|
||||
).datasetId;
|
||||
|
||||
await requireCanModifyDataset(datasetId, ctx);
|
||||
|
||||
return await prisma.datasetEntry.delete({
|
||||
where: {
|
||||
id: input.id,
|
||||
},
|
||||
});
|
||||
}),
|
||||
|
||||
update: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
id: z.string(),
|
||||
updates: z.object({
|
||||
input: z.string(),
|
||||
output: z.string().optional(),
|
||||
}),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const existing = await prisma.datasetEntry.findUnique({
|
||||
where: {
|
||||
id: input.id,
|
||||
},
|
||||
});
|
||||
|
||||
if (!existing) {
|
||||
throw new Error(`dataEntry with id ${input.id} does not exist`);
|
||||
}
|
||||
|
||||
await requireCanModifyDataset(existing.datasetId, ctx);
|
||||
|
||||
return await prisma.datasetEntry.update({
|
||||
where: {
|
||||
id: input.id,
|
||||
},
|
||||
data: {
|
||||
input: input.updates.input,
|
||||
output: input.updates.output,
|
||||
},
|
||||
});
|
||||
}),
|
||||
});
|
||||
91
app/src/server/api/routers/datasets.router.ts
Normal file
91
app/src/server/api/routers/datasets.router.ts
Normal file
@@ -0,0 +1,91 @@
|
||||
import { z } from "zod";
|
||||
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
|
||||
import { prisma } from "~/server/db";
|
||||
import {
|
||||
requireCanModifyDataset,
|
||||
requireCanViewDataset,
|
||||
requireNothing,
|
||||
} from "~/utils/accessControl";
|
||||
import userOrg from "~/server/utils/userOrg";
|
||||
|
||||
export const datasetsRouter = createTRPCRouter({
|
||||
list: protectedProcedure.query(async ({ ctx }) => {
|
||||
// Anyone can list experiments
|
||||
requireNothing(ctx);
|
||||
|
||||
const datasets = await prisma.dataset.findMany({
|
||||
where: {
|
||||
organization: {
|
||||
organizationUsers: {
|
||||
some: { userId: ctx.session.user.id },
|
||||
},
|
||||
},
|
||||
},
|
||||
orderBy: {
|
||||
createdAt: "desc",
|
||||
},
|
||||
include: {
|
||||
_count: {
|
||||
select: { datasetEntries: true },
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
return datasets;
|
||||
}),
|
||||
|
||||
get: publicProcedure.input(z.object({ id: z.string() })).query(async ({ input, ctx }) => {
|
||||
await requireCanViewDataset(input.id, ctx);
|
||||
return await prisma.dataset.findFirstOrThrow({
|
||||
where: { id: input.id },
|
||||
});
|
||||
}),
|
||||
|
||||
create: protectedProcedure.input(z.object({})).mutation(async ({ ctx }) => {
|
||||
// Anyone can create an experiment
|
||||
requireNothing(ctx);
|
||||
|
||||
const numDatasets = await prisma.dataset.count({
|
||||
where: {
|
||||
organization: {
|
||||
organizationUsers: {
|
||||
some: { userId: ctx.session.user.id },
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
return await prisma.dataset.create({
|
||||
data: {
|
||||
name: `Dataset ${numDatasets + 1}`,
|
||||
organizationId: (await userOrg(ctx.session.user.id)).id,
|
||||
},
|
||||
});
|
||||
}),
|
||||
|
||||
update: protectedProcedure
|
||||
.input(z.object({ id: z.string(), updates: z.object({ name: z.string() }) }))
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
await requireCanModifyDataset(input.id, ctx);
|
||||
return await prisma.dataset.update({
|
||||
where: {
|
||||
id: input.id,
|
||||
},
|
||||
data: {
|
||||
name: input.updates.name,
|
||||
},
|
||||
});
|
||||
}),
|
||||
|
||||
delete: protectedProcedure
|
||||
.input(z.object({ id: z.string() }))
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
await requireCanModifyDataset(input.id, ctx);
|
||||
|
||||
await prisma.dataset.delete({
|
||||
where: {
|
||||
id: input.id,
|
||||
},
|
||||
});
|
||||
}),
|
||||
});
|
||||
94
app/src/server/api/routers/evaluations.router.ts
Normal file
94
app/src/server/api/routers/evaluations.router.ts
Normal file
@@ -0,0 +1,94 @@
|
||||
import { EvalType } from "@prisma/client";
|
||||
import { z } from "zod";
|
||||
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
|
||||
import { prisma } from "~/server/db";
|
||||
import { queueRunNewEval } from "~/server/tasks/runNewEval.task";
|
||||
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
|
||||
|
||||
export const evaluationsRouter = createTRPCRouter({
|
||||
list: publicProcedure
|
||||
.input(z.object({ experimentId: z.string() }))
|
||||
.query(async ({ input, ctx }) => {
|
||||
await requireCanViewExperiment(input.experimentId, ctx);
|
||||
|
||||
return await prisma.evaluation.findMany({
|
||||
where: {
|
||||
experimentId: input.experimentId,
|
||||
},
|
||||
orderBy: { createdAt: "asc" },
|
||||
});
|
||||
}),
|
||||
|
||||
create: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
experimentId: z.string(),
|
||||
label: z.string(),
|
||||
value: z.string(),
|
||||
evalType: z.nativeEnum(EvalType),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
await requireCanModifyExperiment(input.experimentId, ctx);
|
||||
|
||||
await prisma.evaluation.create({
|
||||
data: {
|
||||
experimentId: input.experimentId,
|
||||
label: input.label,
|
||||
value: input.value,
|
||||
evalType: input.evalType,
|
||||
},
|
||||
});
|
||||
|
||||
await queueRunNewEval(input.experimentId);
|
||||
}),
|
||||
|
||||
update: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
id: z.string(),
|
||||
updates: z.object({
|
||||
label: z.string().optional(),
|
||||
value: z.string().optional(),
|
||||
evalType: z.nativeEnum(EvalType).optional(),
|
||||
}),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const { experimentId } = await prisma.evaluation.findUniqueOrThrow({
|
||||
where: { id: input.id },
|
||||
});
|
||||
await requireCanModifyExperiment(experimentId, ctx);
|
||||
|
||||
const evaluation = await prisma.evaluation.update({
|
||||
where: { id: input.id },
|
||||
data: {
|
||||
label: input.updates.label,
|
||||
value: input.updates.value,
|
||||
evalType: input.updates.evalType,
|
||||
},
|
||||
});
|
||||
|
||||
await prisma.outputEvaluation.deleteMany({
|
||||
where: {
|
||||
evaluationId: evaluation.id,
|
||||
},
|
||||
});
|
||||
// Re-run all evals. Other eval results will already be cached, so this
|
||||
// should only re-run the updated one.
|
||||
await queueRunNewEval(experimentId);
|
||||
}),
|
||||
|
||||
delete: protectedProcedure
|
||||
.input(z.object({ id: z.string() }))
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const { experimentId } = await prisma.evaluation.findUniqueOrThrow({
|
||||
where: { id: input.id },
|
||||
});
|
||||
await requireCanModifyExperiment(experimentId, ctx);
|
||||
|
||||
await prisma.evaluation.delete({
|
||||
where: { id: input.id },
|
||||
});
|
||||
}),
|
||||
});
|
||||
419
app/src/server/api/routers/experiments.router.ts
Normal file
419
app/src/server/api/routers/experiments.router.ts
Normal file
@@ -0,0 +1,419 @@
|
||||
import { z } from "zod";
|
||||
import { v4 as uuidv4 } from "uuid";
|
||||
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
|
||||
import { type Prisma } from "@prisma/client";
|
||||
import { prisma } from "~/server/db";
|
||||
import dedent from "dedent";
|
||||
import { generateNewCell } from "~/server/utils/generateNewCell";
|
||||
import {
|
||||
canModifyExperiment,
|
||||
requireCanModifyExperiment,
|
||||
requireCanViewExperiment,
|
||||
requireNothing,
|
||||
} from "~/utils/accessControl";
|
||||
import userOrg from "~/server/utils/userOrg";
|
||||
import generateTypes from "~/modelProviders/generateTypes";
|
||||
import { promptConstructorVersion } from "~/promptConstructor/version";
|
||||
|
||||
export const experimentsRouter = createTRPCRouter({
|
||||
stats: publicProcedure.input(z.object({ id: z.string() })).query(async ({ input, ctx }) => {
|
||||
await requireCanViewExperiment(input.id, ctx);
|
||||
|
||||
const [experiment, promptVariantCount, testScenarioCount] = await prisma.$transaction([
|
||||
prisma.experiment.findFirstOrThrow({
|
||||
where: { id: input.id },
|
||||
}),
|
||||
prisma.promptVariant.count({
|
||||
where: {
|
||||
experimentId: input.id,
|
||||
visible: true,
|
||||
},
|
||||
}),
|
||||
prisma.testScenario.count({
|
||||
where: {
|
||||
experimentId: input.id,
|
||||
visible: true,
|
||||
},
|
||||
}),
|
||||
]);
|
||||
|
||||
return {
|
||||
experimentLabel: experiment.label,
|
||||
promptVariantCount,
|
||||
testScenarioCount,
|
||||
};
|
||||
}),
|
||||
list: protectedProcedure.query(async ({ ctx }) => {
|
||||
// Anyone can list experiments
|
||||
requireNothing(ctx);
|
||||
|
||||
const experiments = await prisma.experiment.findMany({
|
||||
where: {
|
||||
organization: {
|
||||
organizationUsers: {
|
||||
some: { userId: ctx.session.user.id },
|
||||
},
|
||||
},
|
||||
},
|
||||
orderBy: {
|
||||
sortIndex: "desc",
|
||||
},
|
||||
});
|
||||
|
||||
// TODO: look for cleaner way to do this. Maybe aggregate?
|
||||
const experimentsWithCounts = await Promise.all(
|
||||
experiments.map(async (experiment) => {
|
||||
const visibleTestScenarioCount = await prisma.testScenario.count({
|
||||
where: {
|
||||
experimentId: experiment.id,
|
||||
visible: true,
|
||||
},
|
||||
});
|
||||
|
||||
const visiblePromptVariantCount = await prisma.promptVariant.count({
|
||||
where: {
|
||||
experimentId: experiment.id,
|
||||
visible: true,
|
||||
},
|
||||
});
|
||||
|
||||
return {
|
||||
...experiment,
|
||||
testScenarioCount: visibleTestScenarioCount,
|
||||
promptVariantCount: visiblePromptVariantCount,
|
||||
};
|
||||
}),
|
||||
);
|
||||
|
||||
return experimentsWithCounts;
|
||||
}),
|
||||
|
||||
get: publicProcedure.input(z.object({ id: z.string() })).query(async ({ input, ctx }) => {
|
||||
await requireCanViewExperiment(input.id, ctx);
|
||||
const experiment = await prisma.experiment.findFirstOrThrow({
|
||||
where: { id: input.id },
|
||||
});
|
||||
|
||||
const canModify = ctx.session?.user.id
|
||||
? await canModifyExperiment(experiment.id, ctx.session?.user.id)
|
||||
: false;
|
||||
|
||||
return {
|
||||
...experiment,
|
||||
access: {
|
||||
canView: true,
|
||||
canModify,
|
||||
},
|
||||
};
|
||||
}),
|
||||
|
||||
fork: protectedProcedure.input(z.object({ id: z.string() })).mutation(async ({ input, ctx }) => {
|
||||
await requireCanViewExperiment(input.id, ctx);
|
||||
|
||||
const [
|
||||
existingExp,
|
||||
existingVariants,
|
||||
existingScenarios,
|
||||
existingCells,
|
||||
evaluations,
|
||||
templateVariables,
|
||||
] = await prisma.$transaction([
|
||||
prisma.experiment.findUniqueOrThrow({
|
||||
where: {
|
||||
id: input.id,
|
||||
},
|
||||
}),
|
||||
prisma.promptVariant.findMany({
|
||||
where: {
|
||||
experimentId: input.id,
|
||||
visible: true,
|
||||
},
|
||||
}),
|
||||
prisma.testScenario.findMany({
|
||||
where: {
|
||||
experimentId: input.id,
|
||||
visible: true,
|
||||
},
|
||||
}),
|
||||
prisma.scenarioVariantCell.findMany({
|
||||
where: {
|
||||
testScenario: {
|
||||
visible: true,
|
||||
},
|
||||
promptVariant: {
|
||||
experimentId: input.id,
|
||||
visible: true,
|
||||
},
|
||||
},
|
||||
include: {
|
||||
modelResponses: {
|
||||
include: {
|
||||
outputEvaluations: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}),
|
||||
prisma.evaluation.findMany({
|
||||
where: {
|
||||
experimentId: input.id,
|
||||
},
|
||||
}),
|
||||
prisma.templateVariable.findMany({
|
||||
where: {
|
||||
experimentId: input.id,
|
||||
},
|
||||
}),
|
||||
]);
|
||||
|
||||
const newExperimentId = uuidv4();
|
||||
|
||||
const existingToNewVariantIds = new Map<string, string>();
|
||||
const variantsToCreate: Prisma.PromptVariantCreateManyInput[] = [];
|
||||
for (const variant of existingVariants) {
|
||||
const newVariantId = uuidv4();
|
||||
existingToNewVariantIds.set(variant.id, newVariantId);
|
||||
variantsToCreate.push({
|
||||
...variant,
|
||||
id: newVariantId,
|
||||
experimentId: newExperimentId,
|
||||
});
|
||||
}
|
||||
|
||||
const existingToNewScenarioIds = new Map<string, string>();
|
||||
const scenariosToCreate: Prisma.TestScenarioCreateManyInput[] = [];
|
||||
for (const scenario of existingScenarios) {
|
||||
const newScenarioId = uuidv4();
|
||||
existingToNewScenarioIds.set(scenario.id, newScenarioId);
|
||||
scenariosToCreate.push({
|
||||
...scenario,
|
||||
id: newScenarioId,
|
||||
experimentId: newExperimentId,
|
||||
variableValues: scenario.variableValues as Prisma.InputJsonValue,
|
||||
});
|
||||
}
|
||||
|
||||
const existingToNewEvaluationIds = new Map<string, string>();
|
||||
const evaluationsToCreate: Prisma.EvaluationCreateManyInput[] = [];
|
||||
for (const evaluation of evaluations) {
|
||||
const newEvaluationId = uuidv4();
|
||||
existingToNewEvaluationIds.set(evaluation.id, newEvaluationId);
|
||||
evaluationsToCreate.push({
|
||||
...evaluation,
|
||||
id: newEvaluationId,
|
||||
experimentId: newExperimentId,
|
||||
});
|
||||
}
|
||||
|
||||
const cellsToCreate: Prisma.ScenarioVariantCellCreateManyInput[] = [];
|
||||
const modelResponsesToCreate: Prisma.ModelResponseCreateManyInput[] = [];
|
||||
const outputEvaluationsToCreate: Prisma.OutputEvaluationCreateManyInput[] = [];
|
||||
for (const cell of existingCells) {
|
||||
const newCellId = uuidv4();
|
||||
const { modelResponses, ...cellData } = cell;
|
||||
cellsToCreate.push({
|
||||
...cellData,
|
||||
id: newCellId,
|
||||
promptVariantId: existingToNewVariantIds.get(cell.promptVariantId) ?? "",
|
||||
testScenarioId: existingToNewScenarioIds.get(cell.testScenarioId) ?? "",
|
||||
prompt: (cell.prompt as Prisma.InputJsonValue) ?? undefined,
|
||||
});
|
||||
for (const modelResponse of modelResponses) {
|
||||
const newModelResponseId = uuidv4();
|
||||
const { outputEvaluations, ...modelResponseData } = modelResponse;
|
||||
modelResponsesToCreate.push({
|
||||
...modelResponseData,
|
||||
id: newModelResponseId,
|
||||
scenarioVariantCellId: newCellId,
|
||||
output: (modelResponse.output as Prisma.InputJsonValue) ?? undefined,
|
||||
});
|
||||
for (const evaluation of outputEvaluations) {
|
||||
outputEvaluationsToCreate.push({
|
||||
...evaluation,
|
||||
id: uuidv4(),
|
||||
modelResponseId: newModelResponseId,
|
||||
evaluationId: existingToNewEvaluationIds.get(evaluation.evaluationId) ?? "",
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const templateVariablesToCreate: Prisma.TemplateVariableCreateManyInput[] = [];
|
||||
for (const templateVariable of templateVariables) {
|
||||
templateVariablesToCreate.push({
|
||||
...templateVariable,
|
||||
id: uuidv4(),
|
||||
experimentId: newExperimentId,
|
||||
});
|
||||
}
|
||||
|
||||
const maxSortIndex =
|
||||
(
|
||||
await prisma.experiment.aggregate({
|
||||
_max: {
|
||||
sortIndex: true,
|
||||
},
|
||||
})
|
||||
)._max?.sortIndex ?? 0;
|
||||
|
||||
await prisma.$transaction([
|
||||
prisma.experiment.create({
|
||||
data: {
|
||||
id: newExperimentId,
|
||||
sortIndex: maxSortIndex + 1,
|
||||
label: `${existingExp.label} (forked)`,
|
||||
organizationId: (await userOrg(ctx.session.user.id)).id,
|
||||
},
|
||||
}),
|
||||
prisma.promptVariant.createMany({
|
||||
data: variantsToCreate,
|
||||
}),
|
||||
prisma.testScenario.createMany({
|
||||
data: scenariosToCreate,
|
||||
}),
|
||||
prisma.scenarioVariantCell.createMany({
|
||||
data: cellsToCreate,
|
||||
}),
|
||||
prisma.modelResponse.createMany({
|
||||
data: modelResponsesToCreate,
|
||||
}),
|
||||
prisma.evaluation.createMany({
|
||||
data: evaluationsToCreate,
|
||||
}),
|
||||
prisma.outputEvaluation.createMany({
|
||||
data: outputEvaluationsToCreate,
|
||||
}),
|
||||
prisma.templateVariable.createMany({
|
||||
data: templateVariablesToCreate,
|
||||
}),
|
||||
]);
|
||||
|
||||
return newExperimentId;
|
||||
}),
|
||||
|
||||
create: protectedProcedure.input(z.object({})).mutation(async ({ ctx }) => {
|
||||
// Anyone can create an experiment
|
||||
requireNothing(ctx);
|
||||
|
||||
const organizationId = (await userOrg(ctx.session.user.id)).id;
|
||||
|
||||
const maxSortIndex =
|
||||
(
|
||||
await prisma.experiment.aggregate({
|
||||
_max: {
|
||||
sortIndex: true,
|
||||
},
|
||||
where: { organizationId },
|
||||
})
|
||||
)._max?.sortIndex ?? 0;
|
||||
|
||||
const exp = await prisma.experiment.create({
|
||||
data: {
|
||||
sortIndex: maxSortIndex + 1,
|
||||
label: `Experiment ${maxSortIndex + 1}`,
|
||||
organizationId,
|
||||
},
|
||||
});
|
||||
|
||||
const [variant, _, scenario1, scenario2, scenario3] = await prisma.$transaction([
|
||||
prisma.promptVariant.create({
|
||||
data: {
|
||||
experimentId: exp.id,
|
||||
label: "Prompt Variant 1",
|
||||
sortIndex: 0,
|
||||
// The interpolated $ is necessary until dedent incorporates
|
||||
// https://github.com/dmnd/dedent/pull/46
|
||||
promptConstructor: dedent`
|
||||
/**
|
||||
* Use Javascript to define an OpenAI chat completion
|
||||
* (https://platform.openai.com/docs/api-reference/chat/create).
|
||||
*
|
||||
* You have access to the current scenario in the \`scenario\`
|
||||
* variable.
|
||||
*/
|
||||
|
||||
definePrompt("openai/ChatCompletion", {
|
||||
model: "gpt-3.5-turbo-0613",
|
||||
stream: true,
|
||||
messages: [
|
||||
{
|
||||
role: "system",
|
||||
content: \`Write 'Start experimenting!' in ${"$"}{scenario.language}\`,
|
||||
},
|
||||
],
|
||||
});`,
|
||||
model: "gpt-3.5-turbo-0613",
|
||||
modelProvider: "openai/ChatCompletion",
|
||||
promptConstructorVersion,
|
||||
},
|
||||
}),
|
||||
prisma.templateVariable.create({
|
||||
data: {
|
||||
experimentId: exp.id,
|
||||
label: "language",
|
||||
},
|
||||
}),
|
||||
prisma.testScenario.create({
|
||||
data: {
|
||||
experimentId: exp.id,
|
||||
variableValues: {
|
||||
language: "English",
|
||||
},
|
||||
},
|
||||
}),
|
||||
prisma.testScenario.create({
|
||||
data: {
|
||||
experimentId: exp.id,
|
||||
variableValues: {
|
||||
language: "Spanish",
|
||||
},
|
||||
},
|
||||
}),
|
||||
prisma.testScenario.create({
|
||||
data: {
|
||||
experimentId: exp.id,
|
||||
variableValues: {
|
||||
language: "German",
|
||||
},
|
||||
},
|
||||
}),
|
||||
]);
|
||||
|
||||
await generateNewCell(variant.id, scenario1.id);
|
||||
await generateNewCell(variant.id, scenario2.id);
|
||||
await generateNewCell(variant.id, scenario3.id);
|
||||
|
||||
return exp;
|
||||
}),
|
||||
|
||||
update: protectedProcedure
|
||||
.input(z.object({ id: z.string(), updates: z.object({ label: z.string() }) }))
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
await requireCanModifyExperiment(input.id, ctx);
|
||||
return await prisma.experiment.update({
|
||||
where: {
|
||||
id: input.id,
|
||||
},
|
||||
data: {
|
||||
label: input.updates.label,
|
||||
},
|
||||
});
|
||||
}),
|
||||
|
||||
delete: protectedProcedure
|
||||
.input(z.object({ id: z.string() }))
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
await requireCanModifyExperiment(input.id, ctx);
|
||||
|
||||
await prisma.experiment.delete({
|
||||
where: {
|
||||
id: input.id,
|
||||
},
|
||||
});
|
||||
}),
|
||||
|
||||
// Keeping these on `experiment` for now because we might want to limit the
|
||||
// providers based on your account/experiment
|
||||
promptTypes: publicProcedure.query(async () => {
|
||||
return await generateTypes();
|
||||
}),
|
||||
});
|
||||
419
app/src/server/api/routers/promptVariants.router.ts
Normal file
419
app/src/server/api/routers/promptVariants.router.ts
Normal file
@@ -0,0 +1,419 @@
|
||||
import { z } from "zod";
|
||||
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
|
||||
import { prisma } from "~/server/db";
|
||||
import { Prisma } from "@prisma/client";
|
||||
import { generateNewCell } from "~/server/utils/generateNewCell";
|
||||
import userError from "~/server/utils/error";
|
||||
import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
|
||||
import { reorderPromptVariants } from "~/server/utils/reorderPromptVariants";
|
||||
import { type PromptVariant } from "@prisma/client";
|
||||
import { deriveNewConstructFn } from "~/server/utils/deriveNewContructFn";
|
||||
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
|
||||
import modelProviders from "~/modelProviders/modelProviders";
|
||||
import { ZodSupportedProvider } from "~/modelProviders/types";
|
||||
import parsePromptConstructor from "~/promptConstructor/parse";
|
||||
import { promptConstructorVersion } from "~/promptConstructor/version";
|
||||
|
||||
export const promptVariantsRouter = createTRPCRouter({
|
||||
list: publicProcedure
|
||||
.input(z.object({ experimentId: z.string() }))
|
||||
.query(async ({ input, ctx }) => {
|
||||
await requireCanViewExperiment(input.experimentId, ctx);
|
||||
|
||||
return await prisma.promptVariant.findMany({
|
||||
where: {
|
||||
experimentId: input.experimentId,
|
||||
visible: true,
|
||||
},
|
||||
orderBy: { sortIndex: "asc" },
|
||||
});
|
||||
}),
|
||||
|
||||
stats: publicProcedure
|
||||
.input(z.object({ variantId: z.string() }))
|
||||
.query(async ({ input, ctx }) => {
|
||||
const variant = await prisma.promptVariant.findUnique({
|
||||
where: {
|
||||
id: input.variantId,
|
||||
},
|
||||
});
|
||||
|
||||
if (!variant) {
|
||||
throw new Error(`Prompt Variant with id ${input.variantId} does not exist`);
|
||||
}
|
||||
|
||||
await requireCanViewExperiment(variant.experimentId, ctx);
|
||||
|
||||
const outputEvals = await prisma.outputEvaluation.groupBy({
|
||||
by: ["evaluationId"],
|
||||
_sum: {
|
||||
result: true,
|
||||
},
|
||||
_count: {
|
||||
id: true,
|
||||
},
|
||||
where: {
|
||||
modelResponse: {
|
||||
outdated: false,
|
||||
output: { not: Prisma.AnyNull },
|
||||
scenarioVariantCell: {
|
||||
promptVariant: {
|
||||
id: input.variantId,
|
||||
visible: true,
|
||||
},
|
||||
testScenario: {
|
||||
visible: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const evals = await prisma.evaluation.findMany({
|
||||
where: {
|
||||
experimentId: variant.experimentId,
|
||||
},
|
||||
});
|
||||
|
||||
const evalResults = evals.map((evalItem) => {
|
||||
const evalResult = outputEvals.find(
|
||||
(outputEval) => outputEval.evaluationId === evalItem.id,
|
||||
);
|
||||
return {
|
||||
id: evalItem.id,
|
||||
label: evalItem.label,
|
||||
passCount: evalResult?._sum?.result ?? 0,
|
||||
totalCount: evalResult?._count?.id ?? 1,
|
||||
};
|
||||
});
|
||||
|
||||
const scenarioCount = await prisma.testScenario.count({
|
||||
where: {
|
||||
experimentId: variant.experimentId,
|
||||
visible: true,
|
||||
},
|
||||
});
|
||||
const outputCount = await prisma.scenarioVariantCell.count({
|
||||
where: {
|
||||
promptVariantId: input.variantId,
|
||||
testScenario: { visible: true },
|
||||
modelResponses: {
|
||||
some: {
|
||||
outdated: false,
|
||||
output: {
|
||||
not: Prisma.AnyNull,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const overallTokens = await prisma.modelResponse.aggregate({
|
||||
where: {
|
||||
outdated: false,
|
||||
output: {
|
||||
not: Prisma.AnyNull,
|
||||
},
|
||||
scenarioVariantCell: {
|
||||
promptVariantId: input.variantId,
|
||||
testScenario: {
|
||||
visible: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
_sum: {
|
||||
cost: true,
|
||||
promptTokens: true,
|
||||
completionTokens: true,
|
||||
},
|
||||
});
|
||||
|
||||
const promptTokens = overallTokens._sum?.promptTokens ?? 0;
|
||||
const completionTokens = overallTokens._sum?.completionTokens ?? 0;
|
||||
|
||||
const awaitingEvals = !!evalResults.find(
|
||||
(result) => result.totalCount < scenarioCount * evals.length,
|
||||
);
|
||||
|
||||
return {
|
||||
evalResults,
|
||||
promptTokens,
|
||||
completionTokens,
|
||||
overallCost: overallTokens._sum?.cost ?? 0,
|
||||
scenarioCount,
|
||||
outputCount,
|
||||
awaitingEvals,
|
||||
};
|
||||
}),
|
||||
|
||||
create: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
experimentId: z.string(),
|
||||
variantId: z.string().optional(),
|
||||
streamScenarios: z.array(z.string()),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
await requireCanViewExperiment(input.experimentId, ctx);
|
||||
|
||||
let originalVariant: PromptVariant | null = null;
|
||||
if (input.variantId) {
|
||||
originalVariant = await prisma.promptVariant.findUnique({
|
||||
where: {
|
||||
id: input.variantId,
|
||||
},
|
||||
});
|
||||
} else {
|
||||
originalVariant = await prisma.promptVariant.findFirst({
|
||||
where: {
|
||||
experimentId: input.experimentId,
|
||||
visible: true,
|
||||
},
|
||||
orderBy: {
|
||||
sortIndex: "desc",
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
const largestSortIndex =
|
||||
(
|
||||
await prisma.promptVariant.aggregate({
|
||||
where: {
|
||||
experimentId: input.experimentId,
|
||||
},
|
||||
_max: {
|
||||
sortIndex: true,
|
||||
},
|
||||
})
|
||||
)._max?.sortIndex ?? 0;
|
||||
|
||||
const newVariantLabel =
|
||||
input.variantId && originalVariant
|
||||
? `${originalVariant?.label} Copy`
|
||||
: `Prompt Variant ${largestSortIndex + 2}`;
|
||||
|
||||
const newConstructFn = await deriveNewConstructFn(originalVariant);
|
||||
|
||||
const createNewVariantAction = prisma.promptVariant.create({
|
||||
data: {
|
||||
experimentId: input.experimentId,
|
||||
label: newVariantLabel,
|
||||
sortIndex: (originalVariant?.sortIndex ?? 0) + 1,
|
||||
promptConstructor: newConstructFn,
|
||||
promptConstructorVersion:
|
||||
originalVariant?.promptConstructorVersion ?? promptConstructorVersion,
|
||||
model: originalVariant?.model ?? "gpt-3.5-turbo",
|
||||
modelProvider: originalVariant?.modelProvider ?? "openai/ChatCompletion",
|
||||
},
|
||||
});
|
||||
|
||||
const [newVariant] = await prisma.$transaction([
|
||||
createNewVariantAction,
|
||||
recordExperimentUpdated(input.experimentId),
|
||||
]);
|
||||
|
||||
if (originalVariant) {
|
||||
// Insert new variant to right of original variant
|
||||
await reorderPromptVariants(newVariant.id, originalVariant.id, true);
|
||||
}
|
||||
|
||||
const scenarios = await prisma.testScenario.findMany({
|
||||
where: {
|
||||
experimentId: input.experimentId,
|
||||
visible: true,
|
||||
},
|
||||
});
|
||||
|
||||
for (const scenario of scenarios) {
|
||||
await generateNewCell(newVariant.id, scenario.id, {
|
||||
stream: input.streamScenarios.includes(scenario.id),
|
||||
});
|
||||
}
|
||||
|
||||
return newVariant;
|
||||
}),
|
||||
|
||||
update: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
id: z.string(),
|
||||
updates: z.object({
|
||||
label: z.string().optional(),
|
||||
}),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const existing = await prisma.promptVariant.findUnique({
|
||||
where: {
|
||||
id: input.id,
|
||||
},
|
||||
});
|
||||
|
||||
if (!existing) {
|
||||
throw new Error(`Prompt Variant with id ${input.id} does not exist`);
|
||||
}
|
||||
|
||||
await requireCanModifyExperiment(existing.experimentId, ctx);
|
||||
|
||||
const updatePromptVariantAction = prisma.promptVariant.update({
|
||||
where: {
|
||||
id: input.id,
|
||||
},
|
||||
data: input.updates,
|
||||
});
|
||||
|
||||
const [updatedPromptVariant] = await prisma.$transaction([
|
||||
updatePromptVariantAction,
|
||||
recordExperimentUpdated(existing.experimentId),
|
||||
]);
|
||||
|
||||
return updatedPromptVariant;
|
||||
}),
|
||||
|
||||
hide: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
id: z.string(),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const { experimentId } = await prisma.promptVariant.findUniqueOrThrow({
|
||||
where: { id: input.id },
|
||||
});
|
||||
await requireCanModifyExperiment(experimentId, ctx);
|
||||
|
||||
const updatedPromptVariant = await prisma.promptVariant.update({
|
||||
where: { id: input.id },
|
||||
data: { visible: false, experiment: { update: { updatedAt: new Date() } } },
|
||||
});
|
||||
|
||||
return updatedPromptVariant;
|
||||
}),
|
||||
|
||||
getModifiedPromptFn: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
id: z.string(),
|
||||
instructions: z.string().optional(),
|
||||
newModel: z
|
||||
.object({
|
||||
provider: ZodSupportedProvider,
|
||||
model: z.string(),
|
||||
})
|
||||
.optional(),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const existing = await prisma.promptVariant.findUniqueOrThrow({
|
||||
where: {
|
||||
id: input.id,
|
||||
},
|
||||
});
|
||||
await requireCanModifyExperiment(existing.experimentId, ctx);
|
||||
|
||||
const constructedPrompt = await parsePromptConstructor(existing.promptConstructor);
|
||||
|
||||
if ("error" in constructedPrompt) {
|
||||
return userError(constructedPrompt.error);
|
||||
}
|
||||
|
||||
const model = input.newModel
|
||||
? modelProviders[input.newModel.provider].models[input.newModel.model]
|
||||
: undefined;
|
||||
|
||||
const promptConstructionFn = await deriveNewConstructFn(existing, model, input.instructions);
|
||||
|
||||
// TODO: Validate promptConstructionFn
|
||||
// TODO: Record in some sort of history
|
||||
|
||||
return promptConstructionFn;
|
||||
}),
|
||||
|
||||
replaceVariant: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
id: z.string(),
|
||||
promptConstructor: z.string(),
|
||||
streamScenarios: z.array(z.string()),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const existing = await prisma.promptVariant.findUniqueOrThrow({
|
||||
where: {
|
||||
id: input.id,
|
||||
},
|
||||
});
|
||||
await requireCanModifyExperiment(existing.experimentId, ctx);
|
||||
|
||||
if (!existing) {
|
||||
throw new Error(`Prompt Variant with id ${input.id} does not exist`);
|
||||
}
|
||||
|
||||
const parsedPrompt = await parsePromptConstructor(input.promptConstructor);
|
||||
|
||||
if ("error" in parsedPrompt) {
|
||||
return userError(parsedPrompt.error);
|
||||
}
|
||||
|
||||
// Create a duplicate with only the config changed
|
||||
const newVariant = await prisma.promptVariant.create({
|
||||
data: {
|
||||
experimentId: existing.experimentId,
|
||||
label: existing.label,
|
||||
sortIndex: existing.sortIndex,
|
||||
uiId: existing.uiId,
|
||||
promptConstructor: input.promptConstructor,
|
||||
promptConstructorVersion: existing.promptConstructorVersion,
|
||||
modelProvider: parsedPrompt.modelProvider,
|
||||
model: parsedPrompt.model,
|
||||
},
|
||||
});
|
||||
|
||||
// Hide anything with the same uiId besides the new one
|
||||
const hideOldVariants = prisma.promptVariant.updateMany({
|
||||
where: {
|
||||
uiId: existing.uiId,
|
||||
id: {
|
||||
not: newVariant.id,
|
||||
},
|
||||
},
|
||||
data: {
|
||||
visible: false,
|
||||
},
|
||||
});
|
||||
|
||||
await prisma.$transaction([hideOldVariants, recordExperimentUpdated(existing.experimentId)]);
|
||||
|
||||
const scenarios = await prisma.testScenario.findMany({
|
||||
where: {
|
||||
experimentId: newVariant.experimentId,
|
||||
visible: true,
|
||||
},
|
||||
});
|
||||
|
||||
for (const scenario of scenarios) {
|
||||
await generateNewCell(newVariant.id, scenario.id, {
|
||||
stream: input.streamScenarios.includes(scenario.id),
|
||||
});
|
||||
}
|
||||
|
||||
return { status: "ok" } as const;
|
||||
}),
|
||||
|
||||
reorder: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
draggedId: z.string(),
|
||||
droppedId: z.string(),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const { experimentId } = await prisma.promptVariant.findUniqueOrThrow({
|
||||
where: { id: input.draggedId },
|
||||
});
|
||||
await requireCanModifyExperiment(experimentId, ctx);
|
||||
|
||||
await reorderPromptVariants(input.draggedId, input.droppedId);
|
||||
}),
|
||||
});
|
||||
99
app/src/server/api/routers/scenarioVariantCells.router.ts
Normal file
99
app/src/server/api/routers/scenarioVariantCells.router.ts
Normal file
@@ -0,0 +1,99 @@
|
||||
import { z } from "zod";
|
||||
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
|
||||
import { prisma } from "~/server/db";
|
||||
import { queueQueryModel } from "~/server/tasks/queryModel.task";
|
||||
import { generateNewCell } from "~/server/utils/generateNewCell";
|
||||
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
|
||||
|
||||
export const scenarioVariantCellsRouter = createTRPCRouter({
|
||||
get: publicProcedure
|
||||
.input(
|
||||
z.object({
|
||||
scenarioId: z.string(),
|
||||
variantId: z.string(),
|
||||
}),
|
||||
)
|
||||
.query(async ({ input, ctx }) => {
|
||||
const { experimentId } = await prisma.testScenario.findUniqueOrThrow({
|
||||
where: { id: input.scenarioId },
|
||||
});
|
||||
await requireCanViewExperiment(experimentId, ctx);
|
||||
|
||||
const [cell, numTotalEvals] = await prisma.$transaction([
|
||||
prisma.scenarioVariantCell.findUnique({
|
||||
where: {
|
||||
promptVariantId_testScenarioId: {
|
||||
promptVariantId: input.variantId,
|
||||
testScenarioId: input.scenarioId,
|
||||
},
|
||||
},
|
||||
include: {
|
||||
modelResponses: {
|
||||
where: {
|
||||
outdated: false,
|
||||
},
|
||||
include: {
|
||||
outputEvaluations: {
|
||||
include: {
|
||||
evaluation: {
|
||||
select: { label: true },
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}),
|
||||
prisma.evaluation.count({
|
||||
where: { experimentId },
|
||||
}),
|
||||
]);
|
||||
|
||||
if (!cell) return null;
|
||||
|
||||
const lastResponse = cell.modelResponses?.[cell.modelResponses?.length - 1];
|
||||
const evalsComplete = lastResponse?.outputEvaluations?.length === numTotalEvals;
|
||||
|
||||
return {
|
||||
...cell,
|
||||
evalsComplete,
|
||||
};
|
||||
}),
|
||||
forceRefetch: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
scenarioId: z.string(),
|
||||
variantId: z.string(),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const { experimentId } = await prisma.testScenario.findUniqueOrThrow({
|
||||
where: { id: input.scenarioId },
|
||||
});
|
||||
|
||||
await requireCanModifyExperiment(experimentId, ctx);
|
||||
|
||||
const cell = await prisma.scenarioVariantCell.findUnique({
|
||||
where: {
|
||||
promptVariantId_testScenarioId: {
|
||||
promptVariantId: input.variantId,
|
||||
testScenarioId: input.scenarioId,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
if (!cell) {
|
||||
await generateNewCell(input.variantId, input.scenarioId, { stream: true });
|
||||
return;
|
||||
}
|
||||
|
||||
await prisma.modelResponse.updateMany({
|
||||
where: { scenarioVariantCellId: cell.id },
|
||||
data: {
|
||||
outdated: true,
|
||||
},
|
||||
});
|
||||
|
||||
await queueQueryModel(cell.id, true);
|
||||
}),
|
||||
});
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user