Add Datasets (#118)
* Add dataset (without entries) * Fix dataset hook * Add dataset rows * Add buttons to import/generate data * Add GenerateDataModal * Autogenerate and save data * Fix prettier * Fix types * Add dataset pagination * Fix prettier * Use useDisclosure * Allow generate data modal fadeaway * hide/show data in env var * Fix prettier
This commit is contained in:
2
@types/nextjs-routes.d.ts
vendored
2
@types/nextjs-routes.d.ts
vendored
@@ -16,6 +16,8 @@ declare module "nextjs-routes" {
|
|||||||
| StaticRoute<"/api/experiments/og-image">
|
| StaticRoute<"/api/experiments/og-image">
|
||||||
| StaticRoute<"/api/sentry-example-api">
|
| StaticRoute<"/api/sentry-example-api">
|
||||||
| DynamicRoute<"/api/trpc/[trpc]", { "trpc": string }>
|
| DynamicRoute<"/api/trpc/[trpc]", { "trpc": string }>
|
||||||
|
| DynamicRoute<"/data/[id]", { "id": string }>
|
||||||
|
| StaticRoute<"/data">
|
||||||
| DynamicRoute<"/experiments/[id]", { "id": string }>
|
| DynamicRoute<"/experiments/[id]", { "id": string }>
|
||||||
| StaticRoute<"/experiments">
|
| StaticRoute<"/experiments">
|
||||||
| StaticRoute<"/">
|
| StaticRoute<"/">
|
||||||
|
|||||||
28
prisma/migrations/20230804042305_add_datasets/migration.sql
Normal file
28
prisma/migrations/20230804042305_add_datasets/migration.sql
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
-- CreateTable
|
||||||
|
CREATE TABLE "Dataset" (
|
||||||
|
"id" UUID NOT NULL,
|
||||||
|
"name" TEXT NOT NULL,
|
||||||
|
"organizationId" UUID NOT NULL,
|
||||||
|
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||||
|
|
||||||
|
CONSTRAINT "Dataset_pkey" PRIMARY KEY ("id")
|
||||||
|
);
|
||||||
|
|
||||||
|
-- CreateTable
|
||||||
|
CREATE TABLE "DatasetEntry" (
|
||||||
|
"id" UUID NOT NULL,
|
||||||
|
"input" TEXT NOT NULL,
|
||||||
|
"output" TEXT,
|
||||||
|
"datasetId" UUID NOT NULL,
|
||||||
|
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||||
|
|
||||||
|
CONSTRAINT "DatasetEntry_pkey" PRIMARY KEY ("id")
|
||||||
|
);
|
||||||
|
|
||||||
|
-- AddForeignKey
|
||||||
|
ALTER TABLE "Dataset" ADD CONSTRAINT "Dataset_organizationId_fkey" FOREIGN KEY ("organizationId") REFERENCES "Organization"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||||
|
|
||||||
|
-- AddForeignKey
|
||||||
|
ALTER TABLE "DatasetEntry" ADD CONSTRAINT "DatasetEntry_datasetId_fkey" FOREIGN KEY ("datasetId") REFERENCES "Dataset"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||||
@@ -174,6 +174,32 @@ model OutputEvaluation {
|
|||||||
@@unique([modelResponseId, evaluationId])
|
@@unique([modelResponseId, evaluationId])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
model Dataset {
|
||||||
|
id String @id @default(uuid()) @db.Uuid
|
||||||
|
|
||||||
|
name String
|
||||||
|
datasetEntries DatasetEntry[]
|
||||||
|
|
||||||
|
organizationId String @db.Uuid
|
||||||
|
organization Organization @relation(fields: [organizationId], references: [id], onDelete: Cascade)
|
||||||
|
|
||||||
|
createdAt DateTime @default(now())
|
||||||
|
updatedAt DateTime @updatedAt
|
||||||
|
}
|
||||||
|
|
||||||
|
model DatasetEntry {
|
||||||
|
id String @id @default(uuid()) @db.Uuid
|
||||||
|
|
||||||
|
input String
|
||||||
|
output String?
|
||||||
|
|
||||||
|
datasetId String @db.Uuid
|
||||||
|
dataset Dataset? @relation(fields: [datasetId], references: [id], onDelete: Cascade)
|
||||||
|
|
||||||
|
createdAt DateTime @default(now())
|
||||||
|
updatedAt DateTime @updatedAt
|
||||||
|
}
|
||||||
|
|
||||||
model Organization {
|
model Organization {
|
||||||
id String @id @default(uuid()) @db.Uuid
|
id String @id @default(uuid()) @db.Uuid
|
||||||
personalOrgUserId String? @unique @db.Uuid
|
personalOrgUserId String? @unique @db.Uuid
|
||||||
@@ -183,6 +209,7 @@ model Organization {
|
|||||||
updatedAt DateTime @updatedAt
|
updatedAt DateTime @updatedAt
|
||||||
organizationUsers OrganizationUser[]
|
organizationUsers OrganizationUser[]
|
||||||
experiments Experiment[]
|
experiments Experiment[]
|
||||||
|
datasets Dataset[]
|
||||||
}
|
}
|
||||||
|
|
||||||
enum OrganizationUserRole {
|
enum OrganizationUserRole {
|
||||||
|
|||||||
@@ -1,18 +1,29 @@
|
|||||||
import { Button, Spinner, InputGroup, InputRightElement, Icon, HStack } from "@chakra-ui/react";
|
import {
|
||||||
|
Button,
|
||||||
|
Spinner,
|
||||||
|
InputGroup,
|
||||||
|
InputRightElement,
|
||||||
|
Icon,
|
||||||
|
HStack,
|
||||||
|
type InputGroupProps,
|
||||||
|
} from "@chakra-ui/react";
|
||||||
import { IoMdSend } from "react-icons/io";
|
import { IoMdSend } from "react-icons/io";
|
||||||
import AutoResizeTextArea from "../AutoResizeTextArea";
|
import AutoResizeTextArea from "./AutoResizeTextArea";
|
||||||
|
|
||||||
export const CustomInstructionsInput = ({
|
export const CustomInstructionsInput = ({
|
||||||
instructions,
|
instructions,
|
||||||
setInstructions,
|
setInstructions,
|
||||||
loading,
|
loading,
|
||||||
onSubmit,
|
onSubmit,
|
||||||
|
placeholder = "Send custom instructions",
|
||||||
|
...props
|
||||||
}: {
|
}: {
|
||||||
instructions: string;
|
instructions: string;
|
||||||
setInstructions: (instructions: string) => void;
|
setInstructions: (instructions: string) => void;
|
||||||
loading: boolean;
|
loading: boolean;
|
||||||
onSubmit: () => void;
|
onSubmit: () => void;
|
||||||
}) => {
|
placeholder?: string;
|
||||||
|
} & InputGroupProps) => {
|
||||||
return (
|
return (
|
||||||
<InputGroup
|
<InputGroup
|
||||||
size="md"
|
size="md"
|
||||||
@@ -22,6 +33,7 @@ export const CustomInstructionsInput = ({
|
|||||||
borderRadius={8}
|
borderRadius={8}
|
||||||
alignItems="center"
|
alignItems="center"
|
||||||
colorScheme="orange"
|
colorScheme="orange"
|
||||||
|
{...props}
|
||||||
>
|
>
|
||||||
<AutoResizeTextArea
|
<AutoResizeTextArea
|
||||||
value={instructions}
|
value={instructions}
|
||||||
@@ -33,7 +45,7 @@ export const CustomInstructionsInput = ({
|
|||||||
onSubmit();
|
onSubmit();
|
||||||
}
|
}
|
||||||
}}
|
}}
|
||||||
placeholder="Send custom instructions"
|
placeholder={placeholder}
|
||||||
py={4}
|
py={4}
|
||||||
pl={4}
|
pl={4}
|
||||||
pr={12}
|
pr={12}
|
||||||
@@ -1,73 +1,20 @@
|
|||||||
import { Box, HStack, IconButton } from "@chakra-ui/react";
|
import { useScenarios } from "~/utils/hooks";
|
||||||
import {
|
import Paginator from "../Paginator";
|
||||||
BsChevronDoubleLeft,
|
|
||||||
BsChevronDoubleRight,
|
|
||||||
BsChevronLeft,
|
|
||||||
BsChevronRight,
|
|
||||||
} from "react-icons/bs";
|
|
||||||
import { usePage, useScenarios } from "~/utils/hooks";
|
|
||||||
|
|
||||||
const ScenarioPaginator = () => {
|
const ScenarioPaginator = () => {
|
||||||
const [page, setPage] = usePage();
|
|
||||||
const { data } = useScenarios();
|
const { data } = useScenarios();
|
||||||
|
|
||||||
if (!data) return null;
|
if (!data) return null;
|
||||||
|
|
||||||
const { scenarios, startIndex, lastPage, count } = data;
|
const { scenarios, startIndex, lastPage, count } = data;
|
||||||
|
|
||||||
const nextPage = () => {
|
|
||||||
if (page < lastPage) {
|
|
||||||
setPage(page + 1, "replace");
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
const prevPage = () => {
|
|
||||||
if (page > 1) {
|
|
||||||
setPage(page - 1, "replace");
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
const goToLastPage = () => setPage(lastPage, "replace");
|
|
||||||
const goToFirstPage = () => setPage(1, "replace");
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<HStack pt={4}>
|
<Paginator
|
||||||
<IconButton
|
numItemsLoaded={scenarios.length}
|
||||||
variant="ghost"
|
startIndex={startIndex}
|
||||||
size="sm"
|
lastPage={lastPage}
|
||||||
onClick={goToFirstPage}
|
count={count}
|
||||||
isDisabled={page === 1}
|
/>
|
||||||
aria-label="Go to first page"
|
|
||||||
icon={<BsChevronDoubleLeft />}
|
|
||||||
/>
|
|
||||||
<IconButton
|
|
||||||
variant="ghost"
|
|
||||||
size="sm"
|
|
||||||
onClick={prevPage}
|
|
||||||
isDisabled={page === 1}
|
|
||||||
aria-label="Previous page"
|
|
||||||
icon={<BsChevronLeft />}
|
|
||||||
/>
|
|
||||||
<Box>
|
|
||||||
{startIndex}-{startIndex + scenarios.length - 1} / {count}
|
|
||||||
</Box>
|
|
||||||
<IconButton
|
|
||||||
variant="ghost"
|
|
||||||
size="sm"
|
|
||||||
onClick={nextPage}
|
|
||||||
isDisabled={page === lastPage}
|
|
||||||
aria-label="Next page"
|
|
||||||
icon={<BsChevronRight />}
|
|
||||||
/>
|
|
||||||
<IconButton
|
|
||||||
variant="ghost"
|
|
||||||
size="sm"
|
|
||||||
onClick={goToLastPage}
|
|
||||||
isDisabled={page === lastPage}
|
|
||||||
aria-label="Go to last page"
|
|
||||||
icon={<BsChevronDoubleRight />}
|
|
||||||
/>
|
|
||||||
</HStack>
|
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
79
src/components/Paginator.tsx
Normal file
79
src/components/Paginator.tsx
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
import { Box, HStack, IconButton } from "@chakra-ui/react";
|
||||||
|
import {
|
||||||
|
BsChevronDoubleLeft,
|
||||||
|
BsChevronDoubleRight,
|
||||||
|
BsChevronLeft,
|
||||||
|
BsChevronRight,
|
||||||
|
} from "react-icons/bs";
|
||||||
|
import { usePage } from "~/utils/hooks";
|
||||||
|
|
||||||
|
const Paginator = ({
|
||||||
|
numItemsLoaded,
|
||||||
|
startIndex,
|
||||||
|
lastPage,
|
||||||
|
count,
|
||||||
|
}: {
|
||||||
|
numItemsLoaded: number;
|
||||||
|
startIndex: number;
|
||||||
|
lastPage: number;
|
||||||
|
count: number;
|
||||||
|
}) => {
|
||||||
|
const [page, setPage] = usePage();
|
||||||
|
|
||||||
|
const nextPage = () => {
|
||||||
|
if (page < lastPage) {
|
||||||
|
setPage(page + 1, "replace");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const prevPage = () => {
|
||||||
|
if (page > 1) {
|
||||||
|
setPage(page - 1, "replace");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const goToLastPage = () => setPage(lastPage, "replace");
|
||||||
|
const goToFirstPage = () => setPage(1, "replace");
|
||||||
|
|
||||||
|
return (
|
||||||
|
<HStack pt={4}>
|
||||||
|
<IconButton
|
||||||
|
variant="ghost"
|
||||||
|
size="sm"
|
||||||
|
onClick={goToFirstPage}
|
||||||
|
isDisabled={page === 1}
|
||||||
|
aria-label="Go to first page"
|
||||||
|
icon={<BsChevronDoubleLeft />}
|
||||||
|
/>
|
||||||
|
<IconButton
|
||||||
|
variant="ghost"
|
||||||
|
size="sm"
|
||||||
|
onClick={prevPage}
|
||||||
|
isDisabled={page === 1}
|
||||||
|
aria-label="Previous page"
|
||||||
|
icon={<BsChevronLeft />}
|
||||||
|
/>
|
||||||
|
<Box>
|
||||||
|
{startIndex}-{startIndex + numItemsLoaded - 1} / {count}
|
||||||
|
</Box>
|
||||||
|
<IconButton
|
||||||
|
variant="ghost"
|
||||||
|
size="sm"
|
||||||
|
onClick={nextPage}
|
||||||
|
isDisabled={page === lastPage}
|
||||||
|
aria-label="Next page"
|
||||||
|
icon={<BsChevronRight />}
|
||||||
|
/>
|
||||||
|
<IconButton
|
||||||
|
variant="ghost"
|
||||||
|
size="sm"
|
||||||
|
onClick={goToLastPage}
|
||||||
|
isDisabled={page === lastPage}
|
||||||
|
aria-label="Go to last page"
|
||||||
|
icon={<BsChevronDoubleRight />}
|
||||||
|
/>
|
||||||
|
</HStack>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default Paginator;
|
||||||
@@ -20,7 +20,7 @@ import { useHandledAsyncCallback, useVisibleScenarioIds } from "~/utils/hooks";
|
|||||||
import { type PromptVariant } from "@prisma/client";
|
import { type PromptVariant } from "@prisma/client";
|
||||||
import { useState } from "react";
|
import { useState } from "react";
|
||||||
import CompareFunctions from "./CompareFunctions";
|
import CompareFunctions from "./CompareFunctions";
|
||||||
import { CustomInstructionsInput } from "./CustomInstructionsInput";
|
import { CustomInstructionsInput } from "../CustomInstructionsInput";
|
||||||
import { RefineAction } from "./RefineAction";
|
import { RefineAction } from "./RefineAction";
|
||||||
import { isObject, isString } from "lodash-es";
|
import { isObject, isString } from "lodash-es";
|
||||||
import { type RefinementAction, type SupportedProvider } from "~/modelProviders/types";
|
import { type RefinementAction, type SupportedProvider } from "~/modelProviders/types";
|
||||||
@@ -122,7 +122,7 @@ export const RefinePromptModal = ({
|
|||||||
instructions={instructions}
|
instructions={instructions}
|
||||||
setInstructions={setInstructions}
|
setInstructions={setInstructions}
|
||||||
loading={modificationInProgress}
|
loading={modificationInProgress}
|
||||||
onSubmit={getModifiedPromptFn}
|
onSubmit={() => getModifiedPromptFn()}
|
||||||
/>
|
/>
|
||||||
</VStack>
|
</VStack>
|
||||||
<CompareFunctions
|
<CompareFunctions
|
||||||
|
|||||||
110
src/components/datasets/DatasetCard.tsx
Normal file
110
src/components/datasets/DatasetCard.tsx
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
import {
|
||||||
|
HStack,
|
||||||
|
Icon,
|
||||||
|
VStack,
|
||||||
|
Text,
|
||||||
|
Divider,
|
||||||
|
Spinner,
|
||||||
|
AspectRatio,
|
||||||
|
SkeletonText,
|
||||||
|
} from "@chakra-ui/react";
|
||||||
|
import { RiDatabase2Line } from "react-icons/ri";
|
||||||
|
import { formatTimePast } from "~/utils/dayjs";
|
||||||
|
import Link from "next/link";
|
||||||
|
import { useRouter } from "next/router";
|
||||||
|
import { BsPlusSquare } from "react-icons/bs";
|
||||||
|
import { api } from "~/utils/api";
|
||||||
|
import { useHandledAsyncCallback } from "~/utils/hooks";
|
||||||
|
|
||||||
|
type DatasetData = {
|
||||||
|
name: string;
|
||||||
|
numEntries: number;
|
||||||
|
id: string;
|
||||||
|
createdAt: Date;
|
||||||
|
updatedAt: Date;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const DatasetCard = ({ dataset }: { dataset: DatasetData }) => {
|
||||||
|
return (
|
||||||
|
<AspectRatio ratio={1.2} w="full">
|
||||||
|
<VStack
|
||||||
|
as={Link}
|
||||||
|
href={{ pathname: "/data/[id]", query: { id: dataset.id } }}
|
||||||
|
bg="gray.50"
|
||||||
|
_hover={{ bg: "gray.100" }}
|
||||||
|
transition="background 0.2s"
|
||||||
|
cursor="pointer"
|
||||||
|
borderColor="gray.200"
|
||||||
|
borderWidth={1}
|
||||||
|
p={4}
|
||||||
|
justify="space-between"
|
||||||
|
>
|
||||||
|
<HStack w="full" color="gray.700" justify="center">
|
||||||
|
<Icon as={RiDatabase2Line} boxSize={4} />
|
||||||
|
<Text fontWeight="bold">{dataset.name}</Text>
|
||||||
|
</HStack>
|
||||||
|
<HStack h="full" spacing={4} flex={1} align="center">
|
||||||
|
<CountLabel label="Rows" count={dataset.numEntries} />
|
||||||
|
</HStack>
|
||||||
|
<HStack w="full" color="gray.500" fontSize="xs" textAlign="center">
|
||||||
|
<Text flex={1}>Created {formatTimePast(dataset.createdAt)}</Text>
|
||||||
|
<Divider h={4} orientation="vertical" />
|
||||||
|
<Text flex={1}>Updated {formatTimePast(dataset.updatedAt)}</Text>
|
||||||
|
</HStack>
|
||||||
|
</VStack>
|
||||||
|
</AspectRatio>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
const CountLabel = ({ label, count }: { label: string; count: number }) => {
|
||||||
|
return (
|
||||||
|
<VStack alignItems="center" flex={1}>
|
||||||
|
<Text color="gray.500" fontWeight="bold">
|
||||||
|
{label}
|
||||||
|
</Text>
|
||||||
|
<Text fontSize="sm" color="gray.500">
|
||||||
|
{count}
|
||||||
|
</Text>
|
||||||
|
</VStack>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export const NewDatasetCard = () => {
|
||||||
|
const router = useRouter();
|
||||||
|
const createMutation = api.datasets.create.useMutation();
|
||||||
|
const [createDataset, isLoading] = useHandledAsyncCallback(async () => {
|
||||||
|
const newDataset = await createMutation.mutateAsync({ label: "New Dataset" });
|
||||||
|
await router.push({ pathname: "/data/[id]", query: { id: newDataset.id } });
|
||||||
|
}, [createMutation, router]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<AspectRatio ratio={1.2} w="full">
|
||||||
|
<VStack
|
||||||
|
align="center"
|
||||||
|
justify="center"
|
||||||
|
_hover={{ cursor: "pointer", bg: "gray.50" }}
|
||||||
|
transition="background 0.2s"
|
||||||
|
cursor="pointer"
|
||||||
|
borderColor="gray.200"
|
||||||
|
borderWidth={1}
|
||||||
|
p={4}
|
||||||
|
onClick={createDataset}
|
||||||
|
>
|
||||||
|
<Icon as={isLoading ? Spinner : BsPlusSquare} boxSize={8} />
|
||||||
|
<Text display={{ base: "none", md: "block" }} ml={2}>
|
||||||
|
New Dataset
|
||||||
|
</Text>
|
||||||
|
</VStack>
|
||||||
|
</AspectRatio>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export const DatasetCardSkeleton = () => (
|
||||||
|
<AspectRatio ratio={1.2} w="full">
|
||||||
|
<VStack align="center" borderColor="gray.200" borderWidth={1} p={4} bg="gray.50">
|
||||||
|
<SkeletonText noOfLines={1} w="80%" />
|
||||||
|
<SkeletonText noOfLines={2} w="60%" />
|
||||||
|
<SkeletonText noOfLines={1} w="80%" />
|
||||||
|
</VStack>
|
||||||
|
</AspectRatio>
|
||||||
|
);
|
||||||
21
src/components/datasets/DatasetEntriesPaginator.tsx
Normal file
21
src/components/datasets/DatasetEntriesPaginator.tsx
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
import { useDatasetEntries } from "~/utils/hooks";
|
||||||
|
import Paginator from "../Paginator";
|
||||||
|
|
||||||
|
const DatasetEntriesPaginator = () => {
|
||||||
|
const { data } = useDatasetEntries();
|
||||||
|
|
||||||
|
if (!data) return null;
|
||||||
|
|
||||||
|
const { entries, startIndex, lastPage, count } = data;
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Paginator
|
||||||
|
numItemsLoaded={entries.length}
|
||||||
|
startIndex={startIndex}
|
||||||
|
lastPage={lastPage}
|
||||||
|
count={count}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default DatasetEntriesPaginator;
|
||||||
43
src/components/datasets/DatasetEntriesTable.tsx
Normal file
43
src/components/datasets/DatasetEntriesTable.tsx
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
import {
|
||||||
|
type StackProps,
|
||||||
|
VStack,
|
||||||
|
Table,
|
||||||
|
Th,
|
||||||
|
Tr,
|
||||||
|
Thead,
|
||||||
|
Tbody,
|
||||||
|
Text,
|
||||||
|
HStack,
|
||||||
|
} from "@chakra-ui/react";
|
||||||
|
import { useDatasetEntries } from "~/utils/hooks";
|
||||||
|
import TableRow from "./TableRow";
|
||||||
|
import DatasetEntriesPaginator from "./DatasetEntriesPaginator";
|
||||||
|
|
||||||
|
const DatasetEntriesTable = (props: StackProps) => {
|
||||||
|
const { data } = useDatasetEntries();
|
||||||
|
|
||||||
|
return (
|
||||||
|
<VStack justifyContent="space-between" {...props}>
|
||||||
|
<Table variant="simple" sx={{ "table-layout": "fixed", width: "full" }}>
|
||||||
|
<Thead>
|
||||||
|
<Tr>
|
||||||
|
<Th>Input</Th>
|
||||||
|
<Th>Output</Th>
|
||||||
|
</Tr>
|
||||||
|
</Thead>
|
||||||
|
<Tbody>{data?.entries.map((entry) => <TableRow key={entry.id} entry={entry} />)}</Tbody>
|
||||||
|
</Table>
|
||||||
|
{(!data || data.entries.length) === 0 ? (
|
||||||
|
<Text alignSelf="flex-start" pl={6} color="gray.500">
|
||||||
|
No entries found
|
||||||
|
</Text>
|
||||||
|
) : (
|
||||||
|
<HStack justifyContent="flex-start">
|
||||||
|
<DatasetEntriesPaginator />
|
||||||
|
</HStack>
|
||||||
|
)}
|
||||||
|
</VStack>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default DatasetEntriesTable;
|
||||||
@@ -0,0 +1,26 @@
|
|||||||
|
import { Button, HStack, useDisclosure } from "@chakra-ui/react";
|
||||||
|
import { BiImport } from "react-icons/bi";
|
||||||
|
import { BsStars } from "react-icons/bs";
|
||||||
|
|
||||||
|
import { GenerateDataModal } from "./GenerateDataModal";
|
||||||
|
|
||||||
|
export const DatasetHeaderButtons = () => {
|
||||||
|
const generateModalDisclosure = useDisclosure();
|
||||||
|
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
<HStack>
|
||||||
|
<Button leftIcon={<BiImport />} colorScheme="blue" variant="ghost">
|
||||||
|
Import Data
|
||||||
|
</Button>
|
||||||
|
<Button leftIcon={<BsStars />} colorScheme="blue" onClick={generateModalDisclosure.onOpen}>
|
||||||
|
Generate Data
|
||||||
|
</Button>
|
||||||
|
</HStack>
|
||||||
|
<GenerateDataModal
|
||||||
|
isOpen={generateModalDisclosure.isOpen}
|
||||||
|
onClose={generateModalDisclosure.onClose}
|
||||||
|
/>
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
};
|
||||||
@@ -0,0 +1,100 @@
|
|||||||
|
import {
|
||||||
|
Modal,
|
||||||
|
ModalBody,
|
||||||
|
ModalCloseButton,
|
||||||
|
ModalContent,
|
||||||
|
ModalHeader,
|
||||||
|
ModalOverlay,
|
||||||
|
ModalFooter,
|
||||||
|
Text,
|
||||||
|
HStack,
|
||||||
|
VStack,
|
||||||
|
Icon,
|
||||||
|
NumberInput,
|
||||||
|
NumberInputField,
|
||||||
|
NumberInputStepper,
|
||||||
|
NumberIncrementStepper,
|
||||||
|
NumberDecrementStepper,
|
||||||
|
} from "@chakra-ui/react";
|
||||||
|
import { BsStars } from "react-icons/bs";
|
||||||
|
import { useState } from "react";
|
||||||
|
import { CustomInstructionsInput } from "~/components/CustomInstructionsInput";
|
||||||
|
import { useDataset, useHandledAsyncCallback } from "~/utils/hooks";
|
||||||
|
import { api } from "~/utils/api";
|
||||||
|
|
||||||
|
export const GenerateDataModal = ({
|
||||||
|
isOpen,
|
||||||
|
onClose,
|
||||||
|
}: {
|
||||||
|
isOpen: boolean;
|
||||||
|
onClose: () => void;
|
||||||
|
}) => {
|
||||||
|
const utils = api.useContext();
|
||||||
|
|
||||||
|
const datasetId = useDataset().data?.id;
|
||||||
|
const [instructions, setInstructions] = useState<string>(
|
||||||
|
"Each row should contain an email body. Half of the emails should contain event details, and the other half should not.",
|
||||||
|
);
|
||||||
|
const [numToGenerate, setNumToGenerate] = useState<number>(20);
|
||||||
|
|
||||||
|
const generateInputsMutation = api.datasetEntries.autogenerateInputs.useMutation();
|
||||||
|
|
||||||
|
const [generateEntries, generateEntriesInProgress] = useHandledAsyncCallback(async () => {
|
||||||
|
if (!instructions || !numToGenerate || !datasetId) return;
|
||||||
|
await generateInputsMutation.mutateAsync({
|
||||||
|
datasetId,
|
||||||
|
instructions,
|
||||||
|
numToGenerate,
|
||||||
|
});
|
||||||
|
await utils.datasetEntries.list.invalidate();
|
||||||
|
onClose();
|
||||||
|
}, [generateInputsMutation, onClose, instructions, numToGenerate, datasetId]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Modal isOpen={isOpen} onClose={onClose} size={{ base: "xl", sm: "2xl", md: "3xl" }}>
|
||||||
|
<ModalOverlay />
|
||||||
|
<ModalContent w={1200}>
|
||||||
|
<ModalHeader>
|
||||||
|
<HStack>
|
||||||
|
<Icon as={BsStars} />
|
||||||
|
<Text>Generate Data</Text>
|
||||||
|
</HStack>
|
||||||
|
</ModalHeader>
|
||||||
|
<ModalCloseButton />
|
||||||
|
<ModalBody maxW="unset">
|
||||||
|
<VStack w="full" spacing={8} padding={8} alignItems="flex-start">
|
||||||
|
<VStack alignItems="flex-start" spacing={2}>
|
||||||
|
<Text fontWeight="bold">Number of Rows:</Text>
|
||||||
|
<NumberInput
|
||||||
|
step={5}
|
||||||
|
defaultValue={15}
|
||||||
|
min={0}
|
||||||
|
max={100}
|
||||||
|
onChange={(valueString) => setNumToGenerate(parseInt(valueString) || 0)}
|
||||||
|
value={numToGenerate}
|
||||||
|
w="24"
|
||||||
|
>
|
||||||
|
<NumberInputField />
|
||||||
|
<NumberInputStepper>
|
||||||
|
<NumberIncrementStepper />
|
||||||
|
<NumberDecrementStepper />
|
||||||
|
</NumberInputStepper>
|
||||||
|
</NumberInput>
|
||||||
|
</VStack>
|
||||||
|
<VStack alignItems="flex-start" w="full" spacing={2}>
|
||||||
|
<Text fontWeight="bold">Row Description:</Text>
|
||||||
|
<CustomInstructionsInput
|
||||||
|
instructions={instructions}
|
||||||
|
setInstructions={setInstructions}
|
||||||
|
onSubmit={generateEntries}
|
||||||
|
loading={generateEntriesInProgress}
|
||||||
|
placeholder="Each row should contain..."
|
||||||
|
/>
|
||||||
|
</VStack>
|
||||||
|
</VStack>
|
||||||
|
</ModalBody>
|
||||||
|
<ModalFooter />
|
||||||
|
</ModalContent>
|
||||||
|
</Modal>
|
||||||
|
);
|
||||||
|
};
|
||||||
13
src/components/datasets/TableRow.tsx
Normal file
13
src/components/datasets/TableRow.tsx
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
import { Td, Tr } from "@chakra-ui/react";
|
||||||
|
import { type DatasetEntry } from "@prisma/client";
|
||||||
|
|
||||||
|
const TableRow = ({ entry }: { entry: DatasetEntry }) => {
|
||||||
|
return (
|
||||||
|
<Tr key={entry.id}>
|
||||||
|
<Td>{entry.input}</Td>
|
||||||
|
<Td>{entry.output}</Td>
|
||||||
|
</Tr>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default TableRow;
|
||||||
@@ -5,7 +5,7 @@ import { BsGearFill } from "react-icons/bs";
|
|||||||
import { TbGitFork } from "react-icons/tb";
|
import { TbGitFork } from "react-icons/tb";
|
||||||
import { useAppStore } from "~/state/store";
|
import { useAppStore } from "~/state/store";
|
||||||
|
|
||||||
export const HeaderButtons = () => {
|
export const ExperimentHeaderButtons = () => {
|
||||||
const experiment = useExperiment();
|
const experiment = useExperiment();
|
||||||
|
|
||||||
const canModify = experiment.data?.access.canModify ?? false;
|
const canModify = experiment.data?.access.canModify ?? false;
|
||||||
@@ -8,42 +8,43 @@ import {
|
|||||||
Text,
|
Text,
|
||||||
Box,
|
Box,
|
||||||
type BoxProps,
|
type BoxProps,
|
||||||
type LinkProps,
|
Link as ChakraLink,
|
||||||
Link,
|
|
||||||
Flex,
|
Flex,
|
||||||
} from "@chakra-ui/react";
|
} from "@chakra-ui/react";
|
||||||
import Head from "next/head";
|
import Head from "next/head";
|
||||||
|
import Link, { type LinkProps } from "next/link";
|
||||||
import { BsGithub, BsPersonCircle } from "react-icons/bs";
|
import { BsGithub, BsPersonCircle } from "react-icons/bs";
|
||||||
import { useRouter } from "next/router";
|
import { useRouter } from "next/router";
|
||||||
import { type IconType } from "react-icons";
|
import { type IconType } from "react-icons";
|
||||||
import { RiFlaskLine } from "react-icons/ri";
|
import { RiDatabase2Line, RiFlaskLine } from "react-icons/ri";
|
||||||
import { signIn, useSession } from "next-auth/react";
|
import { signIn, useSession } from "next-auth/react";
|
||||||
import UserMenu from "./UserMenu";
|
import UserMenu from "./UserMenu";
|
||||||
|
import { env } from "~/env.mjs";
|
||||||
|
|
||||||
type IconLinkProps = BoxProps & LinkProps & { label?: string; icon: IconType };
|
type IconLinkProps = BoxProps & LinkProps & { label?: string; icon: IconType; href: string };
|
||||||
|
|
||||||
const IconLink = ({ icon, label, href, target, color, ...props }: IconLinkProps) => {
|
const IconLink = ({ icon, label, href, color, ...props }: IconLinkProps) => {
|
||||||
const router = useRouter();
|
const router = useRouter();
|
||||||
const isActive = href && router.pathname.startsWith(href);
|
const isActive = href && router.pathname.startsWith(href);
|
||||||
return (
|
return (
|
||||||
<HStack
|
<Link href={href} style={{ width: "100%" }}>
|
||||||
w="full"
|
<HStack
|
||||||
p={4}
|
w="full"
|
||||||
color={color}
|
p={4}
|
||||||
as={Link}
|
color={color}
|
||||||
href={href}
|
as={ChakraLink}
|
||||||
target={target}
|
bgColor={isActive ? "gray.200" : "transparent"}
|
||||||
bgColor={isActive ? "gray.200" : "transparent"}
|
_hover={{ bgColor: "gray.300", textDecoration: "none" }}
|
||||||
_hover={{ bgColor: "gray.200", textDecoration: "none" }}
|
justifyContent="start"
|
||||||
justifyContent="start"
|
cursor="pointer"
|
||||||
cursor="pointer"
|
{...props}
|
||||||
{...props}
|
>
|
||||||
>
|
<Icon as={icon} boxSize={6} mr={2} />
|
||||||
<Icon as={icon} boxSize={6} mr={2} />
|
<Text fontWeight="bold" fontSize="sm">
|
||||||
<Text fontWeight="bold" fontSize="sm">
|
{label}
|
||||||
{label}
|
</Text>
|
||||||
</Text>
|
</HStack>
|
||||||
</HStack>
|
</Link>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -72,16 +73,28 @@ const NavSidebar = () => {
|
|||||||
{user != null && (
|
{user != null && (
|
||||||
<>
|
<>
|
||||||
<IconLink icon={RiFlaskLine} label="Experiments" href="/experiments" />
|
<IconLink icon={RiFlaskLine} label="Experiments" href="/experiments" />
|
||||||
|
{env.NEXT_PUBLIC_SHOW_DATA && (
|
||||||
|
<IconLink icon={RiDatabase2Line} label="Data" href="/data" />
|
||||||
|
)}
|
||||||
</>
|
</>
|
||||||
)}
|
)}
|
||||||
{user === null && (
|
{user === null && (
|
||||||
<IconLink
|
<HStack
|
||||||
icon={BsPersonCircle}
|
w="full"
|
||||||
label="Sign In"
|
p={4}
|
||||||
|
as={ChakraLink}
|
||||||
|
_hover={{ bgColor: "gray.300", textDecoration: "none" }}
|
||||||
|
justifyContent="start"
|
||||||
|
cursor="pointer"
|
||||||
onClick={() => {
|
onClick={() => {
|
||||||
signIn("github").catch(console.error);
|
signIn("github").catch(console.error);
|
||||||
}}
|
}}
|
||||||
/>
|
>
|
||||||
|
<Icon as={BsPersonCircle} boxSize={6} mr={2} />
|
||||||
|
<Text fontWeight="bold" fontSize="sm">
|
||||||
|
Sign In
|
||||||
|
</Text>
|
||||||
|
</HStack>
|
||||||
)}
|
)}
|
||||||
</VStack>
|
</VStack>
|
||||||
{user ? (
|
{user ? (
|
||||||
@@ -90,7 +103,7 @@ const NavSidebar = () => {
|
|||||||
<Divider />
|
<Divider />
|
||||||
)}
|
)}
|
||||||
<VStack spacing={0} align="center">
|
<VStack spacing={0} align="center">
|
||||||
<Link
|
<ChakraLink
|
||||||
href="https://github.com/openpipe/openpipe"
|
href="https://github.com/openpipe/openpipe"
|
||||||
target="_blank"
|
target="_blank"
|
||||||
color="gray.500"
|
color="gray.500"
|
||||||
@@ -98,7 +111,7 @@ const NavSidebar = () => {
|
|||||||
p={2}
|
p={2}
|
||||||
>
|
>
|
||||||
<Icon as={BsGithub} boxSize={6} />
|
<Icon as={BsGithub} boxSize={6} />
|
||||||
</Link>
|
</ChakraLink>
|
||||||
</VStack>
|
</VStack>
|
||||||
</VStack>
|
</VStack>
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ export const env = createEnv({
|
|||||||
NEXT_PUBLIC_SOCKET_URL: z.string().url().default("http://localhost:3318"),
|
NEXT_PUBLIC_SOCKET_URL: z.string().url().default("http://localhost:3318"),
|
||||||
NEXT_PUBLIC_HOST: z.string().url().default("http://localhost:3000"),
|
NEXT_PUBLIC_HOST: z.string().url().default("http://localhost:3000"),
|
||||||
NEXT_PUBLIC_SENTRY_DSN: z.string().optional(),
|
NEXT_PUBLIC_SENTRY_DSN: z.string().optional(),
|
||||||
|
NEXT_PUBLIC_SHOW_DATA: z.string().optional(),
|
||||||
},
|
},
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -46,6 +47,7 @@ export const env = createEnv({
|
|||||||
NEXT_PUBLIC_POSTHOG_KEY: process.env.NEXT_PUBLIC_POSTHOG_KEY,
|
NEXT_PUBLIC_POSTHOG_KEY: process.env.NEXT_PUBLIC_POSTHOG_KEY,
|
||||||
NEXT_PUBLIC_SOCKET_URL: process.env.NEXT_PUBLIC_SOCKET_URL,
|
NEXT_PUBLIC_SOCKET_URL: process.env.NEXT_PUBLIC_SOCKET_URL,
|
||||||
NEXT_PUBLIC_HOST: process.env.NEXT_PUBLIC_HOST,
|
NEXT_PUBLIC_HOST: process.env.NEXT_PUBLIC_HOST,
|
||||||
|
NEXT_PUBLIC_SHOW_DATA: process.env.NEXT_PUBLIC_SHOW_DATA,
|
||||||
GITHUB_CLIENT_ID: process.env.GITHUB_CLIENT_ID,
|
GITHUB_CLIENT_ID: process.env.GITHUB_CLIENT_ID,
|
||||||
GITHUB_CLIENT_SECRET: process.env.GITHUB_CLIENT_SECRET,
|
GITHUB_CLIENT_SECRET: process.env.GITHUB_CLIENT_SECRET,
|
||||||
REPLICATE_API_TOKEN: process.env.REPLICATE_API_TOKEN,
|
REPLICATE_API_TOKEN: process.env.REPLICATE_API_TOKEN,
|
||||||
|
|||||||
99
src/pages/data/[id].tsx
Normal file
99
src/pages/data/[id].tsx
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
import {
|
||||||
|
Box,
|
||||||
|
Breadcrumb,
|
||||||
|
BreadcrumbItem,
|
||||||
|
Center,
|
||||||
|
Flex,
|
||||||
|
Icon,
|
||||||
|
Input,
|
||||||
|
VStack,
|
||||||
|
} from "@chakra-ui/react";
|
||||||
|
import Link from "next/link";
|
||||||
|
|
||||||
|
import { useRouter } from "next/router";
|
||||||
|
import { useState, useEffect } from "react";
|
||||||
|
import { RiDatabase2Line } from "react-icons/ri";
|
||||||
|
import AppShell from "~/components/nav/AppShell";
|
||||||
|
import { api } from "~/utils/api";
|
||||||
|
import { useDataset, useHandledAsyncCallback } from "~/utils/hooks";
|
||||||
|
import DatasetEntriesTable from "~/components/datasets/DatasetEntriesTable";
|
||||||
|
import { DatasetHeaderButtons } from "~/components/datasets/DatasetHeaderButtons/DatasetHeaderButtons";
|
||||||
|
|
||||||
|
export default function Dataset() {
|
||||||
|
const router = useRouter();
|
||||||
|
const utils = api.useContext();
|
||||||
|
|
||||||
|
const dataset = useDataset();
|
||||||
|
const datasetId = router.query.id as string;
|
||||||
|
|
||||||
|
const [name, setName] = useState(dataset.data?.name || "");
|
||||||
|
useEffect(() => {
|
||||||
|
setName(dataset.data?.name || "");
|
||||||
|
}, [dataset.data?.name]);
|
||||||
|
|
||||||
|
const updateMutation = api.datasets.update.useMutation();
|
||||||
|
const [onSaveName] = useHandledAsyncCallback(async () => {
|
||||||
|
if (name && name !== dataset.data?.name && dataset.data?.id) {
|
||||||
|
await updateMutation.mutateAsync({
|
||||||
|
id: dataset.data.id,
|
||||||
|
updates: { name: name },
|
||||||
|
});
|
||||||
|
await Promise.all([utils.datasets.list.invalidate(), utils.datasets.get.invalidate()]);
|
||||||
|
}
|
||||||
|
}, [updateMutation, dataset.data?.id, dataset.data?.name, name]);
|
||||||
|
|
||||||
|
if (!dataset.isLoading && !dataset.data) {
|
||||||
|
return (
|
||||||
|
<AppShell title="Dataset not found">
|
||||||
|
<Center h="100%">
|
||||||
|
<div>Dataset not found 😕</div>
|
||||||
|
</Center>
|
||||||
|
</AppShell>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<AppShell title={dataset.data?.name}>
|
||||||
|
<VStack h="full">
|
||||||
|
<Flex
|
||||||
|
pl={4}
|
||||||
|
pr={8}
|
||||||
|
py={2}
|
||||||
|
w="full"
|
||||||
|
direction={{ base: "column", sm: "row" }}
|
||||||
|
alignItems={{ base: "flex-start", sm: "center" }}
|
||||||
|
>
|
||||||
|
<Breadcrumb flex={1} mt={1}>
|
||||||
|
<BreadcrumbItem>
|
||||||
|
<Link href="/data">
|
||||||
|
<Flex alignItems="center" _hover={{ textDecoration: "underline" }}>
|
||||||
|
<Icon as={RiDatabase2Line} boxSize={4} mr={2} /> Datasets
|
||||||
|
</Flex>
|
||||||
|
</Link>
|
||||||
|
</BreadcrumbItem>
|
||||||
|
<BreadcrumbItem isCurrentPage>
|
||||||
|
<Input
|
||||||
|
size="sm"
|
||||||
|
value={name}
|
||||||
|
onChange={(e) => setName(e.target.value)}
|
||||||
|
onBlur={onSaveName}
|
||||||
|
borderWidth={1}
|
||||||
|
borderColor="transparent"
|
||||||
|
fontSize={16}
|
||||||
|
px={0}
|
||||||
|
minW={{ base: 100, lg: 300 }}
|
||||||
|
flex={1}
|
||||||
|
_hover={{ borderColor: "gray.300" }}
|
||||||
|
_focus={{ borderColor: "blue.500", outline: "none" }}
|
||||||
|
/>
|
||||||
|
</BreadcrumbItem>
|
||||||
|
</Breadcrumb>
|
||||||
|
<DatasetHeaderButtons />
|
||||||
|
</Flex>
|
||||||
|
<Box w="full" overflowX="auto" flex={1} pl={4} pr={8} pt={8}>
|
||||||
|
{datasetId && <DatasetEntriesTable />}
|
||||||
|
</Box>
|
||||||
|
</VStack>
|
||||||
|
</AppShell>
|
||||||
|
);
|
||||||
|
}
|
||||||
83
src/pages/data/index.tsx
Normal file
83
src/pages/data/index.tsx
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
import {
|
||||||
|
SimpleGrid,
|
||||||
|
Icon,
|
||||||
|
VStack,
|
||||||
|
Breadcrumb,
|
||||||
|
BreadcrumbItem,
|
||||||
|
Flex,
|
||||||
|
Center,
|
||||||
|
Text,
|
||||||
|
Link,
|
||||||
|
HStack,
|
||||||
|
} from "@chakra-ui/react";
|
||||||
|
import AppShell from "~/components/nav/AppShell";
|
||||||
|
import { api } from "~/utils/api";
|
||||||
|
import { signIn, useSession } from "next-auth/react";
|
||||||
|
import { RiDatabase2Line } from "react-icons/ri";
|
||||||
|
import {
|
||||||
|
DatasetCard,
|
||||||
|
DatasetCardSkeleton,
|
||||||
|
NewDatasetCard,
|
||||||
|
} from "~/components/datasets/DatasetCard";
|
||||||
|
|
||||||
|
export default function DatasetsPage() {
|
||||||
|
const datasets = api.datasets.list.useQuery();
|
||||||
|
|
||||||
|
const user = useSession().data;
|
||||||
|
const authLoading = useSession().status === "loading";
|
||||||
|
|
||||||
|
if (user === null || authLoading) {
|
||||||
|
return (
|
||||||
|
<AppShell title="Data">
|
||||||
|
<Center h="100%">
|
||||||
|
{!authLoading && (
|
||||||
|
<Text>
|
||||||
|
<Link
|
||||||
|
onClick={() => {
|
||||||
|
signIn("github").catch(console.error);
|
||||||
|
}}
|
||||||
|
textDecor="underline"
|
||||||
|
>
|
||||||
|
Sign in
|
||||||
|
</Link>{" "}
|
||||||
|
to view or create new datasets!
|
||||||
|
</Text>
|
||||||
|
)}
|
||||||
|
</Center>
|
||||||
|
</AppShell>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<AppShell title="Data">
|
||||||
|
<VStack alignItems={"flex-start"} px={4} py={2}>
|
||||||
|
<HStack minH={8} align="center" pt={2}>
|
||||||
|
<Breadcrumb flex={1}>
|
||||||
|
<BreadcrumbItem>
|
||||||
|
<Flex alignItems="center">
|
||||||
|
<Icon as={RiDatabase2Line} boxSize={4} mr={2} /> Datasets
|
||||||
|
</Flex>
|
||||||
|
</BreadcrumbItem>
|
||||||
|
</Breadcrumb>
|
||||||
|
</HStack>
|
||||||
|
<SimpleGrid w="full" columns={{ base: 1, md: 2, lg: 3, xl: 4 }} spacing={8} p="4">
|
||||||
|
<NewDatasetCard />
|
||||||
|
{datasets.data && !datasets.isLoading ? (
|
||||||
|
datasets?.data?.map((dataset) => (
|
||||||
|
<DatasetCard
|
||||||
|
key={dataset.id}
|
||||||
|
dataset={{ ...dataset, numEntries: dataset._count.datasetEntries }}
|
||||||
|
/>
|
||||||
|
))
|
||||||
|
) : (
|
||||||
|
<>
|
||||||
|
<DatasetCardSkeleton />
|
||||||
|
<DatasetCardSkeleton />
|
||||||
|
<DatasetCardSkeleton />
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
</SimpleGrid>
|
||||||
|
</VStack>
|
||||||
|
</AppShell>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -21,7 +21,7 @@ import { api } from "~/utils/api";
|
|||||||
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
|
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
|
||||||
import { useAppStore } from "~/state/store";
|
import { useAppStore } from "~/state/store";
|
||||||
import { useSyncVariantEditor } from "~/state/sync";
|
import { useSyncVariantEditor } from "~/state/sync";
|
||||||
import { HeaderButtons } from "~/components/experiments/HeaderButtons/HeaderButtons";
|
import { ExperimentHeaderButtons } from "~/components/experiments/ExperimentHeaderButtons/ExperimentHeaderButtons";
|
||||||
import Head from "next/head";
|
import Head from "next/head";
|
||||||
|
|
||||||
// TODO: import less to fix deployment with server side props
|
// TODO: import less to fix deployment with server side props
|
||||||
@@ -142,7 +142,7 @@ export default function Experiment() {
|
|||||||
)}
|
)}
|
||||||
</BreadcrumbItem>
|
</BreadcrumbItem>
|
||||||
</Breadcrumb>
|
</Breadcrumb>
|
||||||
<HeaderButtons />
|
<ExperimentHeaderButtons />
|
||||||
</Flex>
|
</Flex>
|
||||||
<ExperimentSettingsDrawer />
|
<ExperimentSettingsDrawer />
|
||||||
<Box w="100%" overflowX="auto" flex={1}>
|
<Box w="100%" overflowX="auto" flex={1}>
|
||||||
|
|||||||
97
src/server/api/autogenerate/autogenerateDatasetInputs.ts
Normal file
97
src/server/api/autogenerate/autogenerateDatasetInputs.ts
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
import { type ChatCompletion } from "openai/resources/chat";
|
||||||
|
import { openai } from "../../utils/openai";
|
||||||
|
import { isAxiosError } from "./utils";
|
||||||
|
import { type APIResponse } from "openai/core";
|
||||||
|
import { sleep } from "~/server/utils/sleep";
|
||||||
|
|
||||||
|
const MAX_AUTO_RETRIES = 50;
|
||||||
|
const MIN_DELAY = 500; // milliseconds
|
||||||
|
const MAX_DELAY = 15000; // milliseconds
|
||||||
|
|
||||||
|
function calculateDelay(numPreviousTries: number): number {
|
||||||
|
const baseDelay = Math.min(MAX_DELAY, MIN_DELAY * Math.pow(2, numPreviousTries));
|
||||||
|
const jitter = Math.random() * baseDelay;
|
||||||
|
return baseDelay + jitter;
|
||||||
|
}
|
||||||
|
|
||||||
|
const getCompletionWithBackoff = async (
|
||||||
|
getCompletion: () => Promise<APIResponse<ChatCompletion>>,
|
||||||
|
) => {
|
||||||
|
let completion;
|
||||||
|
let tries = 0;
|
||||||
|
while (tries < MAX_AUTO_RETRIES) {
|
||||||
|
try {
|
||||||
|
completion = await getCompletion();
|
||||||
|
break;
|
||||||
|
} catch (e) {
|
||||||
|
if (isAxiosError(e)) {
|
||||||
|
console.error(e?.response?.data?.error?.message);
|
||||||
|
} else {
|
||||||
|
await sleep(calculateDelay(tries));
|
||||||
|
console.error(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tries++;
|
||||||
|
}
|
||||||
|
return completion;
|
||||||
|
};
|
||||||
|
|
||||||
|
const MAX_BATCH_SIZE = 5;
|
||||||
|
|
||||||
|
export const autogenerateDatasetInputs = async (
|
||||||
|
numToGenerate: number,
|
||||||
|
customInstructions: string,
|
||||||
|
): Promise<string[]> => {
|
||||||
|
const batchSizes = Array.from({ length: Math.ceil(numToGenerate / MAX_BATCH_SIZE) }, (_, i) =>
|
||||||
|
i === Math.ceil(numToGenerate / MAX_BATCH_SIZE) - 1
|
||||||
|
? numToGenerate % MAX_BATCH_SIZE
|
||||||
|
: MAX_BATCH_SIZE,
|
||||||
|
);
|
||||||
|
|
||||||
|
const getCompletion = (batchSize: number) =>
|
||||||
|
openai.chat.completions.create({
|
||||||
|
model: "gpt-4",
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
role: "system",
|
||||||
|
content: `The user needs ${batchSize} rows of data that match the following instructions:\n---\n" + ${customInstructions}`,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
functions: [
|
||||||
|
{
|
||||||
|
name: "add_list_of_data",
|
||||||
|
description: "Add a list of data to the database",
|
||||||
|
parameters: {
|
||||||
|
type: "object",
|
||||||
|
properties: {
|
||||||
|
rows: {
|
||||||
|
type: "array",
|
||||||
|
description: "The rows of data that match the instructions",
|
||||||
|
items: {
|
||||||
|
type: "string",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
|
||||||
|
function_call: { name: "add_list_of_data" },
|
||||||
|
temperature: 0.5,
|
||||||
|
});
|
||||||
|
|
||||||
|
const completionCallbacks = batchSizes.map((batchSize) =>
|
||||||
|
getCompletionWithBackoff(() => getCompletion(batchSize)),
|
||||||
|
);
|
||||||
|
|
||||||
|
const completions = await Promise.all(completionCallbacks);
|
||||||
|
|
||||||
|
const rows = completions.flatMap((completion) => {
|
||||||
|
const parsed = JSON.parse(
|
||||||
|
completion?.choices[0]?.message?.function_call?.arguments ?? "{rows: []}",
|
||||||
|
) as { rows: string[] };
|
||||||
|
return parsed.rows;
|
||||||
|
});
|
||||||
|
|
||||||
|
return rows;
|
||||||
|
};
|
||||||
@@ -1,26 +1,9 @@
|
|||||||
import { type CompletionCreateParams } from "openai/resources/chat";
|
import { type CompletionCreateParams } from "openai/resources/chat";
|
||||||
import { prisma } from "../db";
|
import { prisma } from "../../db";
|
||||||
import { openai } from "../utils/openai";
|
import { openai } from "../../utils/openai";
|
||||||
import { pick } from "lodash-es";
|
import { pick } from "lodash-es";
|
||||||
|
import { isAxiosError } from "./utils";
|
||||||
|
|
||||||
type AxiosError = {
|
|
||||||
response?: {
|
|
||||||
data?: {
|
|
||||||
error?: {
|
|
||||||
message?: string;
|
|
||||||
};
|
|
||||||
};
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
function isAxiosError(error: unknown): error is AxiosError {
|
|
||||||
if (typeof error === "object" && error !== null) {
|
|
||||||
// Initial check
|
|
||||||
const err = error as AxiosError;
|
|
||||||
return err.response?.data?.error?.message !== undefined; // Check structure
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
export const autogenerateScenarioValues = async (
|
export const autogenerateScenarioValues = async (
|
||||||
experimentId: string,
|
experimentId: string,
|
||||||
): Promise<Record<string, string>> => {
|
): Promise<Record<string, string>> => {
|
||||||
18
src/server/api/autogenerate/utils.ts
Normal file
18
src/server/api/autogenerate/utils.ts
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
type AxiosError = {
|
||||||
|
response?: {
|
||||||
|
data?: {
|
||||||
|
error?: {
|
||||||
|
message?: string;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
export function isAxiosError(error: unknown): error is AxiosError {
|
||||||
|
if (typeof error === "object" && error !== null) {
|
||||||
|
// Initial check
|
||||||
|
const err = error as AxiosError;
|
||||||
|
return err.response?.data?.error?.message !== undefined; // Check structure
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
@@ -6,6 +6,8 @@ import { scenarioVariantCellsRouter } from "./routers/scenarioVariantCells.route
|
|||||||
import { templateVarsRouter } from "./routers/templateVariables.router";
|
import { templateVarsRouter } from "./routers/templateVariables.router";
|
||||||
import { evaluationsRouter } from "./routers/evaluations.router";
|
import { evaluationsRouter } from "./routers/evaluations.router";
|
||||||
import { worldChampsRouter } from "./routers/worldChamps.router";
|
import { worldChampsRouter } from "./routers/worldChamps.router";
|
||||||
|
import { datasetsRouter } from "./routers/datasets.router";
|
||||||
|
import { datasetEntries } from "./routers/datasetEntries.router";
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This is the primary router for your server.
|
* This is the primary router for your server.
|
||||||
@@ -20,6 +22,8 @@ export const appRouter = createTRPCRouter({
|
|||||||
templateVars: templateVarsRouter,
|
templateVars: templateVarsRouter,
|
||||||
evaluations: evaluationsRouter,
|
evaluations: evaluationsRouter,
|
||||||
worldChamps: worldChampsRouter,
|
worldChamps: worldChampsRouter,
|
||||||
|
datasets: datasetsRouter,
|
||||||
|
datasetEntries: datasetEntries,
|
||||||
});
|
});
|
||||||
|
|
||||||
// export type definition of API
|
// export type definition of API
|
||||||
|
|||||||
143
src/server/api/routers/datasetEntries.router.ts
Normal file
143
src/server/api/routers/datasetEntries.router.ts
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
import { z } from "zod";
|
||||||
|
import { createTRPCRouter, protectedProcedure } from "~/server/api/trpc";
|
||||||
|
import { prisma } from "~/server/db";
|
||||||
|
import { requireCanModifyDataset, requireCanViewDataset } from "~/utils/accessControl";
|
||||||
|
import { autogenerateDatasetInputs } from "../autogenerate/autogenerateDatasetInputs";
|
||||||
|
|
||||||
|
const PAGE_SIZE = 10;
|
||||||
|
|
||||||
|
export const datasetEntries = createTRPCRouter({
|
||||||
|
list: protectedProcedure
|
||||||
|
.input(z.object({ datasetId: z.string(), page: z.number() }))
|
||||||
|
.query(async ({ input, ctx }) => {
|
||||||
|
await requireCanViewDataset(input.datasetId, ctx);
|
||||||
|
|
||||||
|
const { datasetId, page } = input;
|
||||||
|
|
||||||
|
const entries = await prisma.datasetEntry.findMany({
|
||||||
|
where: {
|
||||||
|
datasetId,
|
||||||
|
},
|
||||||
|
orderBy: { createdAt: "asc" },
|
||||||
|
skip: (page - 1) * PAGE_SIZE,
|
||||||
|
take: PAGE_SIZE,
|
||||||
|
});
|
||||||
|
|
||||||
|
const count = await prisma.datasetEntry.count({
|
||||||
|
where: {
|
||||||
|
datasetId,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
return {
|
||||||
|
entries,
|
||||||
|
startIndex: (page - 1) * PAGE_SIZE + 1,
|
||||||
|
lastPage: Math.ceil(count / PAGE_SIZE),
|
||||||
|
count,
|
||||||
|
};
|
||||||
|
}),
|
||||||
|
createOne: protectedProcedure
|
||||||
|
.input(
|
||||||
|
z.object({
|
||||||
|
datasetId: z.string(),
|
||||||
|
input: z.string(),
|
||||||
|
output: z.string().optional(),
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.mutation(async ({ input, ctx }) => {
|
||||||
|
await requireCanModifyDataset(input.datasetId, ctx);
|
||||||
|
|
||||||
|
return await prisma.datasetEntry.create({
|
||||||
|
data: {
|
||||||
|
datasetId: input.datasetId,
|
||||||
|
input: input.input,
|
||||||
|
output: input.output,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}),
|
||||||
|
|
||||||
|
autogenerateInputs: protectedProcedure
|
||||||
|
.input(
|
||||||
|
z.object({
|
||||||
|
datasetId: z.string(),
|
||||||
|
numToGenerate: z.number(),
|
||||||
|
instructions: z.string(),
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.mutation(async ({ input, ctx }) => {
|
||||||
|
await requireCanModifyDataset(input.datasetId, ctx);
|
||||||
|
|
||||||
|
const dataset = await prisma.dataset.findUnique({
|
||||||
|
where: {
|
||||||
|
id: input.datasetId,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!dataset) {
|
||||||
|
throw new Error(`Dataset with id ${input.datasetId} does not exist`);
|
||||||
|
}
|
||||||
|
|
||||||
|
const entryInputs = await autogenerateDatasetInputs(input.numToGenerate, input.instructions);
|
||||||
|
|
||||||
|
const createdEntries = await prisma.datasetEntry.createMany({
|
||||||
|
data: entryInputs.map((entryInput) => ({
|
||||||
|
datasetId: input.datasetId,
|
||||||
|
input: entryInput,
|
||||||
|
})),
|
||||||
|
});
|
||||||
|
|
||||||
|
return createdEntries;
|
||||||
|
}),
|
||||||
|
|
||||||
|
delete: protectedProcedure
|
||||||
|
.input(z.object({ id: z.string() }))
|
||||||
|
.mutation(async ({ input, ctx }) => {
|
||||||
|
const datasetId = (
|
||||||
|
await prisma.datasetEntry.findUniqueOrThrow({
|
||||||
|
where: { id: input.id },
|
||||||
|
})
|
||||||
|
).datasetId;
|
||||||
|
|
||||||
|
await requireCanModifyDataset(datasetId, ctx);
|
||||||
|
|
||||||
|
return await prisma.datasetEntry.delete({
|
||||||
|
where: {
|
||||||
|
id: input.id,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}),
|
||||||
|
|
||||||
|
update: protectedProcedure
|
||||||
|
.input(
|
||||||
|
z.object({
|
||||||
|
id: z.string(),
|
||||||
|
updates: z.object({
|
||||||
|
input: z.string(),
|
||||||
|
output: z.string().optional(),
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.mutation(async ({ input, ctx }) => {
|
||||||
|
const existing = await prisma.datasetEntry.findUnique({
|
||||||
|
where: {
|
||||||
|
id: input.id,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!existing) {
|
||||||
|
throw new Error(`dataEntry with id ${input.id} does not exist`);
|
||||||
|
}
|
||||||
|
|
||||||
|
await requireCanModifyDataset(existing.datasetId, ctx);
|
||||||
|
|
||||||
|
return await prisma.datasetEntry.update({
|
||||||
|
where: {
|
||||||
|
id: input.id,
|
||||||
|
},
|
||||||
|
data: {
|
||||||
|
input: input.updates.input,
|
||||||
|
output: input.updates.output,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}),
|
||||||
|
});
|
||||||
91
src/server/api/routers/datasets.router.ts
Normal file
91
src/server/api/routers/datasets.router.ts
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
import { z } from "zod";
|
||||||
|
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
|
||||||
|
import { prisma } from "~/server/db";
|
||||||
|
import {
|
||||||
|
requireCanModifyDataset,
|
||||||
|
requireCanViewDataset,
|
||||||
|
requireNothing,
|
||||||
|
} from "~/utils/accessControl";
|
||||||
|
import userOrg from "~/server/utils/userOrg";
|
||||||
|
|
||||||
|
export const datasetsRouter = createTRPCRouter({
|
||||||
|
list: protectedProcedure.query(async ({ ctx }) => {
|
||||||
|
// Anyone can list experiments
|
||||||
|
requireNothing(ctx);
|
||||||
|
|
||||||
|
const datasets = await prisma.dataset.findMany({
|
||||||
|
where: {
|
||||||
|
organization: {
|
||||||
|
organizationUsers: {
|
||||||
|
some: { userId: ctx.session.user.id },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
orderBy: {
|
||||||
|
createdAt: "desc",
|
||||||
|
},
|
||||||
|
include: {
|
||||||
|
_count: {
|
||||||
|
select: { datasetEntries: true },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
return datasets;
|
||||||
|
}),
|
||||||
|
|
||||||
|
get: publicProcedure.input(z.object({ id: z.string() })).query(async ({ input, ctx }) => {
|
||||||
|
await requireCanViewDataset(input.id, ctx);
|
||||||
|
return await prisma.dataset.findFirstOrThrow({
|
||||||
|
where: { id: input.id },
|
||||||
|
});
|
||||||
|
}),
|
||||||
|
|
||||||
|
create: protectedProcedure.input(z.object({})).mutation(async ({ ctx }) => {
|
||||||
|
// Anyone can create an experiment
|
||||||
|
requireNothing(ctx);
|
||||||
|
|
||||||
|
const numDatasets = await prisma.dataset.count({
|
||||||
|
where: {
|
||||||
|
organization: {
|
||||||
|
organizationUsers: {
|
||||||
|
some: { userId: ctx.session.user.id },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
return await prisma.dataset.create({
|
||||||
|
data: {
|
||||||
|
name: `Dataset ${numDatasets + 1}`,
|
||||||
|
organizationId: (await userOrg(ctx.session.user.id)).id,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}),
|
||||||
|
|
||||||
|
update: protectedProcedure
|
||||||
|
.input(z.object({ id: z.string(), updates: z.object({ name: z.string() }) }))
|
||||||
|
.mutation(async ({ input, ctx }) => {
|
||||||
|
await requireCanModifyDataset(input.id, ctx);
|
||||||
|
return await prisma.dataset.update({
|
||||||
|
where: {
|
||||||
|
id: input.id,
|
||||||
|
},
|
||||||
|
data: {
|
||||||
|
name: input.updates.name,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}),
|
||||||
|
|
||||||
|
delete: protectedProcedure
|
||||||
|
.input(z.object({ id: z.string() }))
|
||||||
|
.mutation(async ({ input, ctx }) => {
|
||||||
|
await requireCanModifyDataset(input.id, ctx);
|
||||||
|
|
||||||
|
await prisma.dataset.delete({
|
||||||
|
where: {
|
||||||
|
id: input.id,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}),
|
||||||
|
});
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
import { z } from "zod";
|
import { z } from "zod";
|
||||||
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
|
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
|
||||||
import { prisma } from "~/server/db";
|
import { prisma } from "~/server/db";
|
||||||
import { autogenerateScenarioValues } from "../autogen";
|
import { autogenerateScenarioValues } from "../autogenerate/autogenerateScenarioValues";
|
||||||
import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
|
import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
|
||||||
import { runAllEvals } from "~/server/utils/evaluations";
|
import { runAllEvals } from "~/server/utils/evaluations";
|
||||||
import { generateNewCell } from "~/server/utils/generateNewCell";
|
import { generateNewCell } from "~/server/utils/generateNewCell";
|
||||||
|
|||||||
@@ -16,6 +16,33 @@ export const requireNothing = (ctx: TRPCContext) => {
|
|||||||
ctx.markAccessControlRun();
|
ctx.markAccessControlRun();
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export const requireCanViewDataset = async (datasetId: string, ctx: TRPCContext) => {
|
||||||
|
const dataset = await prisma.dataset.findFirst({
|
||||||
|
where: {
|
||||||
|
id: datasetId,
|
||||||
|
organization: {
|
||||||
|
organizationUsers: {
|
||||||
|
some: {
|
||||||
|
role: { in: [OrganizationUserRole.ADMIN, OrganizationUserRole.MEMBER] },
|
||||||
|
userId: ctx.session?.user.id,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!dataset) {
|
||||||
|
throw new TRPCError({ code: "UNAUTHORIZED" });
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.markAccessControlRun();
|
||||||
|
};
|
||||||
|
|
||||||
|
export const requireCanModifyDataset = async (datasetId: string, ctx: TRPCContext) => {
|
||||||
|
// Right now all users who can view a dataset can also modify it
|
||||||
|
await requireCanViewDataset(datasetId, ctx);
|
||||||
|
};
|
||||||
|
|
||||||
export const requireCanViewExperiment = async (experimentId: string, ctx: TRPCContext) => {
|
export const requireCanViewExperiment = async (experimentId: string, ctx: TRPCContext) => {
|
||||||
await prisma.experiment.findFirst({
|
await prisma.experiment.findFirst({
|
||||||
where: { id: experimentId },
|
where: { id: experimentId },
|
||||||
|
|||||||
@@ -17,6 +17,26 @@ export const useExperimentAccess = () => {
|
|||||||
return useExperiment().data?.access ?? { canView: false, canModify: false };
|
return useExperiment().data?.access ?? { canView: false, canModify: false };
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export const useDataset = () => {
|
||||||
|
const router = useRouter();
|
||||||
|
const dataset = api.datasets.get.useQuery(
|
||||||
|
{ id: router.query.id as string },
|
||||||
|
{ enabled: !!router.query.id },
|
||||||
|
);
|
||||||
|
|
||||||
|
return dataset;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const useDatasetEntries = () => {
|
||||||
|
const dataset = useDataset();
|
||||||
|
const [page] = usePage();
|
||||||
|
|
||||||
|
return api.datasetEntries.list.useQuery(
|
||||||
|
{ datasetId: dataset.data?.id ?? "", page },
|
||||||
|
{ enabled: dataset.data?.id != null },
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
type AsyncFunction<T extends unknown[], U> = (...args: T) => Promise<U>;
|
type AsyncFunction<T extends unknown[], U> = (...args: T) => Promise<U>;
|
||||||
|
|
||||||
export function useHandledAsyncCallback<T extends unknown[], U>(
|
export function useHandledAsyncCallback<T extends unknown[], U>(
|
||||||
|
|||||||
Reference in New Issue
Block a user