Add dataset entry deletion and export

This commit is contained in:
David Corbitt
2023-09-06 22:01:39 -07:00
parent 7c4ab151a4
commit 5aadf3c2ba
6 changed files with 384 additions and 16 deletions

View File

@@ -0,0 +1,101 @@
import {
Modal,
ModalOverlay,
ModalContent,
ModalHeader,
ModalCloseButton,
ModalBody,
ModalFooter,
HStack,
VStack,
Icon,
Text,
Button,
useDisclosure,
type UseDisclosureReturn,
} from "@chakra-ui/react";
import { BsTrash } from "react-icons/bs";
import { useHandledAsyncCallback, useDataset } from "~/utils/hooks";
import { api } from "~/utils/api";
import { useAppStore } from "~/state/store";
import ActionButton from "../ActionButton";
import { maybeReportError } from "~/utils/errorHandling/maybeReportError";
import pluralize from "pluralize";
const DeleteButton = () => {
const selectedIds = useAppStore((s) => s.selectedDatasetEntries.selectedIds);
const disclosure = useDisclosure();
return (
<>
<ActionButton
onClick={disclosure.onOpen}
label="Delete"
icon={BsTrash}
isDisabled={selectedIds.size === 0}
requireBeta
/>
<DeleteDatasetEntriesModal disclosure={disclosure} />
</>
);
};
export default DeleteButton;
const DeleteDatasetEntriesModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) => {
const dataset = useDataset().data;
const selectedIds = useAppStore((s) => s.selectedDatasetEntries.selectedIds);
const clearSelectedIds = useAppStore((s) => s.selectedDatasetEntries.clearSelectedIds);
const deleteRowsMutation = api.datasetEntries.delete.useMutation();
const utils = api.useContext();
const [deleteRows, deletionInProgress] = useHandledAsyncCallback(async () => {
if (!dataset?.id || !selectedIds.size) return;
const response = await deleteRowsMutation.mutateAsync({
ids: Array.from(selectedIds),
});
if (maybeReportError(response)) return;
await utils.datasetEntries.list.invalidate();
disclosure.onClose();
clearSelectedIds();
}, [deleteRowsMutation, dataset, selectedIds, utils]);
return (
<Modal size={{ base: "xl", md: "2xl" }} {...disclosure}>
<ModalOverlay />
<ModalContent w={1200}>
<ModalHeader>
<HStack>
<Icon as={BsTrash} />
<Text>Delete Logs</Text>
</HStack>
</ModalHeader>
<ModalCloseButton />
<ModalBody maxW="unset">
<VStack w="full" spacing={8} pt={4} alignItems="flex-start">
<Text>
Are you sure you want to delete the <b>{selectedIds.size}</b>{" "}
{pluralize("row", selectedIds.size)} rows you've selected?
</Text>
</VStack>
</ModalBody>
<ModalFooter>
<HStack>
<Button colorScheme="gray" onClick={disclosure.onClose} minW={24}>
Cancel
</Button>
<Button colorScheme="red" onClick={deleteRows} isLoading={deletionInProgress} minW={24}>
Delete
</Button>
</HStack>
</ModalFooter>
</ModalContent>
</Modal>
);
};

View File

@@ -0,0 +1,182 @@
import { useState, useEffect } from "react";
import {
Modal,
ModalOverlay,
ModalContent,
ModalHeader,
ModalCloseButton,
ModalBody,
ModalFooter,
HStack,
VStack,
Icon,
Text,
Button,
Checkbox,
NumberInput,
NumberInputField,
NumberInputStepper,
NumberIncrementStepper,
NumberDecrementStepper,
Collapse,
Flex,
useDisclosure,
type UseDisclosureReturn,
} from "@chakra-ui/react";
import { AiOutlineDownload } from "react-icons/ai";
import { useHandledAsyncCallback, useDataset } from "~/utils/hooks";
import { api } from "~/utils/api";
import { useAppStore } from "~/state/store";
import ActionButton from "../ActionButton";
import { FiChevronUp, FiChevronDown } from "react-icons/fi";
import InfoCircle from "../InfoCircle";
const ExportButton = () => {
const selectedIds = useAppStore((s) => s.selectedDatasetEntries.selectedIds);
const disclosure = useDisclosure();
return (
<>
<ActionButton
onClick={disclosure.onOpen}
label="Download"
icon={AiOutlineDownload}
isDisabled={selectedIds.size === 0}
requireBeta
/>
<ExportDatasetEntriesModal disclosure={disclosure} />
</>
);
};
export default ExportButton;
const ExportDatasetEntriesModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) => {
const dataset = useDataset().data;
const selectedIds = useAppStore((s) => s.selectedDatasetEntries.selectedIds);
const clearSelectedIds = useAppStore((s) => s.selectedDatasetEntries.clearSelectedIds);
const [testingSplit, setTestingSplit] = useState(10);
const [removeDuplicates, setRemoveDuplicates] = useState(false);
const [showAdvancedOptions, setShowAdvancedOptions] = useState(false);
useEffect(() => {
if (disclosure.isOpen) {
setTestingSplit(10);
setRemoveDuplicates(false);
}
}, [disclosure.isOpen]);
const exportDataMutation = api.datasetEntries.export.useMutation();
const [exportData, exportInProgress] = useHandledAsyncCallback(async () => {
if (!dataset?.id || !selectedIds.size || !testingSplit) return;
const response = await exportDataMutation.mutateAsync({
datasetId: dataset.id,
datasetEntryIds: Array.from(selectedIds),
testingSplit,
removeDuplicates,
});
const dataUrl = `data:application/pdf;base64,${response}`;
const blob = await fetch(dataUrl).then((res) => res.blob());
const url = URL.createObjectURL(blob);
const a = document.createElement("a");
a.href = url;
a.download = `data.zip`;
document.body.appendChild(a);
a.click();
document.body.removeChild(a);
disclosure.onClose();
clearSelectedIds();
}, [exportDataMutation, dataset, selectedIds, testingSplit, removeDuplicates]);
return (
<Modal size={{ base: "xl", md: "2xl" }} {...disclosure}>
<ModalOverlay />
<ModalContent w={1200}>
<ModalHeader>
<HStack>
<Icon as={AiOutlineDownload} />
<Text>Export Logs</Text>
</HStack>
</ModalHeader>
<ModalCloseButton />
<ModalBody maxW="unset">
<VStack w="full" spacing={8} pt={4} alignItems="flex-start">
<Text>
We'll export the <b>{selectedIds.size}</b> rows you have selected in the OpenAI
training format.
</Text>
<VStack alignItems="flex-start" spacing={4}>
<Flex
flexDir={{ base: "column", md: "row" }}
alignItems={{ base: "flex-start", md: "center" }}
>
<HStack w={48} alignItems="center" spacing={1}>
<Text fontWeight="bold">Testing Split:</Text>
<InfoCircle tooltipText="The percent of your logs that will be reserved for testing and saved in another file. Logs are split randomly." />
</HStack>
<HStack>
<NumberInput
defaultValue={10}
onChange={(_, num) => setTestingSplit(num)}
min={1}
max={100}
w={48}
>
<NumberInputField />
<NumberInputStepper>
<NumberIncrementStepper />
<NumberDecrementStepper />
</NumberInputStepper>
</NumberInput>
</HStack>
</Flex>
</VStack>
<VStack alignItems="flex-start" spacing={0}>
<Button
variant="unstyled"
color="blue.600"
onClick={() => setShowAdvancedOptions(!showAdvancedOptions)}
>
<HStack>
<Text>Advanced Options</Text>
<Icon as={showAdvancedOptions ? FiChevronUp : FiChevronDown} />
</HStack>
</Button>
<Collapse in={showAdvancedOptions} unmountOnExit={true}>
<VStack align="stretch" pt={4}>
<HStack>
<Checkbox
colorScheme="blue"
isChecked={removeDuplicates}
onChange={(e) => setRemoveDuplicates(e.target.checked)}
>
<Text>Remove duplicates</Text>
</Checkbox>
<InfoCircle tooltipText="To avoid overfitting and speed up training, automatically deduplicate logs with matching input and output." />
</HStack>
</VStack>
</Collapse>
</VStack>
</VStack>
</ModalBody>
<ModalFooter>
<HStack>
<Button colorScheme="gray" onClick={disclosure.onClose} minW={24}>
Cancel
</Button>
<Button colorScheme="blue" onClick={exportData} isLoading={exportInProgress} minW={24}>
Download
</Button>
</HStack>
</ModalFooter>
</ModalContent>
</Modal>
);
};

View File

@@ -18,17 +18,13 @@ import {
} from "@chakra-ui/react";
import { AiOutlineCloudUpload, AiOutlineFile } from "react-icons/ai";
import { useDataset, useDatasetEntries, useHandledAsyncCallback } from "~/utils/hooks";
import { useDataset, useHandledAsyncCallback } from "~/utils/hooks";
import { api } from "~/utils/api";
import ActionButton from "../ActionButton";
import { validateTrainingRows, type TrainingRow, parseJSONL } from "./validateTrainingRows";
import pluralize from "pluralize";
const ImportDataButton = () => {
const datasetEntries = useDatasetEntries().data;
const numEntries = datasetEntries?.matchingEntryIds.length || 0;
const disclosure = useDisclosure();
return (
@@ -38,7 +34,6 @@ const ImportDataButton = () => {
label="Import Data"
icon={AiOutlineCloudUpload}
iconBoxSize={4}
isDisabled={numEntries === 0}
requireBeta
/>
<ImportDataModal disclosure={disclosure} />
@@ -110,6 +105,7 @@ const ImportDataModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) =>
const [sendJSONL, sendingInProgress] = useHandledAsyncCallback(async () => {
if (!dataset || !trainingRows) return;
await sendJSONLMutation.mutateAsync({
datasetId: dataset.id,
jsonl: JSON.stringify(trainingRows),

View File

@@ -15,8 +15,14 @@ export const validateTrainingRows = (rows: unknown): string | null => {
if (!Array.isArray(rows)) return "training data is not an array";
for (let i = 0; i < rows.length; i++) {
const row = rows[i] as TrainingRow;
const error = validateTrainingRow(row);
if (error) return `row ${i}: ${error}`;
let errorMessage: string | null = null;
try {
errorMessage = validateTrainingRow(row);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} catch (error: any) {
errorMessage = error.message;
}
if (errorMessage) return `row ${i + 1}: ${errorMessage}`;
}
return null;
@@ -39,7 +45,7 @@ const validateTrainingRow = (row: TrainingRow): string | null => {
return "input contains item with function_call but no name";
// Validate output
if (row.output !== undefined) {
if (row.output) {
if (typeof row.output !== "object") return "output is not an object";
if (!row.output.content && !row.output.function_call)
return "output contains no content or function_call";

View File

@@ -26,6 +26,8 @@ import { useAppStore } from "~/state/store";
import FineTuneButton from "~/components/datasets/FineTuneButton";
import ExperimentButton from "~/components/datasets/ExperimentButton";
import ImportDataButton from "~/components/datasets/ImportDataButton";
import DownloadButton from "~/components/datasets/ExportButton";
import DeleteButton from "~/components/datasets/DeleteButton";
export default function Dataset() {
const utils = api.useContext();
@@ -101,8 +103,10 @@ export default function Dataset() {
<VStack px={8} py={8} alignItems="flex-start" spacing={4} w="full">
<HStack w="full" justifyContent="flex-end">
<FineTuneButton />
<ExperimentButton />
<ImportDataButton />
<ExperimentButton />
<DownloadButton />
<DeleteButton />
</HStack>
<DatasetEntriesTable />
<DatasetEntryPaginator />

View File

@@ -8,6 +8,7 @@ import {
} from "openai/resources/chat";
import { TRPCError } from "@trpc/server";
import { shuffle } from "lodash-es";
import archiver from "archiver";
import { createTRPCRouter, protectedProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
@@ -15,6 +16,9 @@ import { requireCanModifyProject, requireCanViewProject } from "~/utils/accessCo
import { error, success } from "~/utils/errorHandling/standardResponses";
import { countOpenAIChatTokens } from "~/utils/countTokens";
import { type TrainingRow, validateTrainingRows } from "~/components/datasets/validateTrainingRows";
import hashObject from "~/server/utils/hashObject";
import { type JsonValue } from "type-fest";
import { WritableStreamBuffer } from "stream-buffers";
export const datasetEntriesRouter = createTRPCRouter({
list: protectedProcedure
@@ -121,6 +125,10 @@ export const datasetEntriesRouter = createTRPCRouter({
return error("No loggedCallIds or jsonl provided");
}
const startingTime = Date.now();
console.log("1", 0);
let trainingRows: TrainingRow[];
if (input.loggedCallIds) {
@@ -156,11 +164,6 @@ export const datasetEntriesRouter = createTRPCRouter({
| undefined;
if (resp && resp.choices?.[0]) {
output = resp.choices[0].message;
} else {
output = {
role: "assistant",
content: "",
};
}
return {
input: inputMessages as unknown as CreateChatCompletionRequestMessage[],
@@ -168,6 +171,7 @@ export const datasetEntriesRouter = createTRPCRouter({
};
});
} else {
console.log("2", Date.now() - startingTime);
trainingRows = JSON.parse(input.jsonl as string) as TrainingRow[];
const validationError = validateTrainingRows(trainingRows);
if (validationError) {
@@ -175,6 +179,8 @@ export const datasetEntriesRouter = createTRPCRouter({
}
}
console.log("3", Date.now() - startingTime);
const [existingTrainingCount, existingTestingCount] = await prisma.$transaction([
prisma.datasetEntry.count({
where: {
@@ -190,6 +196,8 @@ export const datasetEntriesRouter = createTRPCRouter({
}),
]);
console.log("4", Date.now() - startingTime);
const newTotalEntries = existingTrainingCount + existingTestingCount + trainingRows.length;
const numTrainingToAdd = Math.floor(trainingRatio * newTotalEntries) - existingTrainingCount;
const numTestingToAdd = trainingRows.length - numTrainingToAdd;
@@ -208,7 +216,10 @@ export const datasetEntriesRouter = createTRPCRouter({
datasetEntriesToCreate.push({
datasetId: datasetId,
input: row.input as unknown as Prisma.InputJsonValue,
output: row.output as unknown as Prisma.InputJsonValue,
output: (row.output as unknown as Prisma.InputJsonValue) ?? {
role: "assistant",
content: "",
},
inputTokens: countOpenAIChatTokens(
"gpt-4-0613",
row.input as unknown as CreateChatCompletionRequestMessage[],
@@ -218,6 +229,8 @@ export const datasetEntriesRouter = createTRPCRouter({
});
}
console.log("5", Date.now() - startingTime);
// Ensure dataset and dataset entries are created atomically
await prisma.$transaction([
prisma.dataset.upsert({
@@ -235,6 +248,8 @@ export const datasetEntriesRouter = createTRPCRouter({
}),
]);
console.log("6", Date.now() - startingTime);
return success(datasetId);
}),
@@ -327,4 +342,68 @@ export const datasetEntriesRouter = createTRPCRouter({
return success("Dataset entries deleted");
}),
export: protectedProcedure
.input(
z.object({
datasetId: z.string(),
datasetEntryIds: z.string().array(),
testingSplit: z.number(),
removeDuplicates: z.boolean(),
}),
)
.mutation(async ({ input, ctx }) => {
const { projectId } = await prisma.dataset.findUniqueOrThrow({
where: { id: input.datasetId },
});
await requireCanViewProject(projectId, ctx);
const datasetEntries = await ctx.prisma.datasetEntry.findMany({
where: {
id: {
in: input.datasetEntryIds,
},
},
});
let rows: TrainingRow[] = datasetEntries.map((entry) => ({
input: entry.input as unknown as CreateChatCompletionRequestMessage[],
output: entry.output as unknown as CreateChatCompletionRequestMessage,
}));
if (input.removeDuplicates) {
const deduplicatedRows = [];
const rowHashSet = new Set<string>();
for (const row of rows) {
const rowHash = hashObject(row as unknown as JsonValue);
if (!rowHashSet.has(rowHash)) {
rowHashSet.add(rowHash);
deduplicatedRows.push(row);
}
}
rows = deduplicatedRows;
}
const splitIndex = Math.floor((rows.length * input.testingSplit) / 100);
const testingData = rows.slice(0, splitIndex);
const trainingData = rows.slice(splitIndex);
// Convert arrays to JSONL format
const trainingDataJSONL = trainingData.map((item) => JSON.stringify(item)).join("\n");
const testingDataJSONL = testingData.map((item) => JSON.stringify(item)).join("\n");
const output = new WritableStreamBuffer();
const archive = archiver("zip");
archive.pipe(output);
archive.append(trainingDataJSONL, { name: "train.jsonl" });
archive.append(testingDataJSONL, { name: "test.jsonl" });
await archive.finalize();
// Convert buffer to base64
const base64 = output.getContents().toString("base64");
return base64;
}),
});