Add dataset entry deletion and export
This commit is contained in:
101
app/src/components/datasets/DeleteButton.tsx
Normal file
101
app/src/components/datasets/DeleteButton.tsx
Normal file
@@ -0,0 +1,101 @@
|
||||
import {
|
||||
Modal,
|
||||
ModalOverlay,
|
||||
ModalContent,
|
||||
ModalHeader,
|
||||
ModalCloseButton,
|
||||
ModalBody,
|
||||
ModalFooter,
|
||||
HStack,
|
||||
VStack,
|
||||
Icon,
|
||||
Text,
|
||||
Button,
|
||||
useDisclosure,
|
||||
type UseDisclosureReturn,
|
||||
} from "@chakra-ui/react";
|
||||
import { BsTrash } from "react-icons/bs";
|
||||
|
||||
import { useHandledAsyncCallback, useDataset } from "~/utils/hooks";
|
||||
import { api } from "~/utils/api";
|
||||
import { useAppStore } from "~/state/store";
|
||||
import ActionButton from "../ActionButton";
|
||||
import { maybeReportError } from "~/utils/errorHandling/maybeReportError";
|
||||
import pluralize from "pluralize";
|
||||
|
||||
const DeleteButton = () => {
|
||||
const selectedIds = useAppStore((s) => s.selectedDatasetEntries.selectedIds);
|
||||
|
||||
const disclosure = useDisclosure();
|
||||
|
||||
return (
|
||||
<>
|
||||
<ActionButton
|
||||
onClick={disclosure.onOpen}
|
||||
label="Delete"
|
||||
icon={BsTrash}
|
||||
isDisabled={selectedIds.size === 0}
|
||||
requireBeta
|
||||
/>
|
||||
<DeleteDatasetEntriesModal disclosure={disclosure} />
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
export default DeleteButton;
|
||||
|
||||
const DeleteDatasetEntriesModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) => {
|
||||
const dataset = useDataset().data;
|
||||
const selectedIds = useAppStore((s) => s.selectedDatasetEntries.selectedIds);
|
||||
const clearSelectedIds = useAppStore((s) => s.selectedDatasetEntries.clearSelectedIds);
|
||||
|
||||
const deleteRowsMutation = api.datasetEntries.delete.useMutation();
|
||||
|
||||
const utils = api.useContext();
|
||||
|
||||
const [deleteRows, deletionInProgress] = useHandledAsyncCallback(async () => {
|
||||
if (!dataset?.id || !selectedIds.size) return;
|
||||
const response = await deleteRowsMutation.mutateAsync({
|
||||
ids: Array.from(selectedIds),
|
||||
});
|
||||
|
||||
if (maybeReportError(response)) return;
|
||||
|
||||
await utils.datasetEntries.list.invalidate();
|
||||
disclosure.onClose();
|
||||
clearSelectedIds();
|
||||
}, [deleteRowsMutation, dataset, selectedIds, utils]);
|
||||
|
||||
return (
|
||||
<Modal size={{ base: "xl", md: "2xl" }} {...disclosure}>
|
||||
<ModalOverlay />
|
||||
<ModalContent w={1200}>
|
||||
<ModalHeader>
|
||||
<HStack>
|
||||
<Icon as={BsTrash} />
|
||||
<Text>Delete Logs</Text>
|
||||
</HStack>
|
||||
</ModalHeader>
|
||||
<ModalCloseButton />
|
||||
<ModalBody maxW="unset">
|
||||
<VStack w="full" spacing={8} pt={4} alignItems="flex-start">
|
||||
<Text>
|
||||
Are you sure you want to delete the <b>{selectedIds.size}</b>{" "}
|
||||
{pluralize("row", selectedIds.size)} rows you've selected?
|
||||
</Text>
|
||||
</VStack>
|
||||
</ModalBody>
|
||||
<ModalFooter>
|
||||
<HStack>
|
||||
<Button colorScheme="gray" onClick={disclosure.onClose} minW={24}>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button colorScheme="red" onClick={deleteRows} isLoading={deletionInProgress} minW={24}>
|
||||
Delete
|
||||
</Button>
|
||||
</HStack>
|
||||
</ModalFooter>
|
||||
</ModalContent>
|
||||
</Modal>
|
||||
);
|
||||
};
|
||||
182
app/src/components/datasets/ExportButton.tsx
Normal file
182
app/src/components/datasets/ExportButton.tsx
Normal file
@@ -0,0 +1,182 @@
|
||||
import { useState, useEffect } from "react";
|
||||
import {
|
||||
Modal,
|
||||
ModalOverlay,
|
||||
ModalContent,
|
||||
ModalHeader,
|
||||
ModalCloseButton,
|
||||
ModalBody,
|
||||
ModalFooter,
|
||||
HStack,
|
||||
VStack,
|
||||
Icon,
|
||||
Text,
|
||||
Button,
|
||||
Checkbox,
|
||||
NumberInput,
|
||||
NumberInputField,
|
||||
NumberInputStepper,
|
||||
NumberIncrementStepper,
|
||||
NumberDecrementStepper,
|
||||
Collapse,
|
||||
Flex,
|
||||
useDisclosure,
|
||||
type UseDisclosureReturn,
|
||||
} from "@chakra-ui/react";
|
||||
import { AiOutlineDownload } from "react-icons/ai";
|
||||
|
||||
import { useHandledAsyncCallback, useDataset } from "~/utils/hooks";
|
||||
import { api } from "~/utils/api";
|
||||
import { useAppStore } from "~/state/store";
|
||||
import ActionButton from "../ActionButton";
|
||||
import { FiChevronUp, FiChevronDown } from "react-icons/fi";
|
||||
import InfoCircle from "../InfoCircle";
|
||||
|
||||
const ExportButton = () => {
|
||||
const selectedIds = useAppStore((s) => s.selectedDatasetEntries.selectedIds);
|
||||
|
||||
const disclosure = useDisclosure();
|
||||
|
||||
return (
|
||||
<>
|
||||
<ActionButton
|
||||
onClick={disclosure.onOpen}
|
||||
label="Download"
|
||||
icon={AiOutlineDownload}
|
||||
isDisabled={selectedIds.size === 0}
|
||||
requireBeta
|
||||
/>
|
||||
<ExportDatasetEntriesModal disclosure={disclosure} />
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
export default ExportButton;
|
||||
|
||||
const ExportDatasetEntriesModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) => {
|
||||
const dataset = useDataset().data;
|
||||
const selectedIds = useAppStore((s) => s.selectedDatasetEntries.selectedIds);
|
||||
const clearSelectedIds = useAppStore((s) => s.selectedDatasetEntries.clearSelectedIds);
|
||||
|
||||
const [testingSplit, setTestingSplit] = useState(10);
|
||||
const [removeDuplicates, setRemoveDuplicates] = useState(false);
|
||||
const [showAdvancedOptions, setShowAdvancedOptions] = useState(false);
|
||||
|
||||
useEffect(() => {
|
||||
if (disclosure.isOpen) {
|
||||
setTestingSplit(10);
|
||||
setRemoveDuplicates(false);
|
||||
}
|
||||
}, [disclosure.isOpen]);
|
||||
|
||||
const exportDataMutation = api.datasetEntries.export.useMutation();
|
||||
|
||||
const [exportData, exportInProgress] = useHandledAsyncCallback(async () => {
|
||||
if (!dataset?.id || !selectedIds.size || !testingSplit) return;
|
||||
const response = await exportDataMutation.mutateAsync({
|
||||
datasetId: dataset.id,
|
||||
datasetEntryIds: Array.from(selectedIds),
|
||||
testingSplit,
|
||||
removeDuplicates,
|
||||
});
|
||||
|
||||
const dataUrl = `data:application/pdf;base64,${response}`;
|
||||
const blob = await fetch(dataUrl).then((res) => res.blob());
|
||||
const url = URL.createObjectURL(blob);
|
||||
const a = document.createElement("a");
|
||||
|
||||
a.href = url;
|
||||
a.download = `data.zip`;
|
||||
document.body.appendChild(a);
|
||||
a.click();
|
||||
document.body.removeChild(a);
|
||||
|
||||
disclosure.onClose();
|
||||
clearSelectedIds();
|
||||
}, [exportDataMutation, dataset, selectedIds, testingSplit, removeDuplicates]);
|
||||
|
||||
return (
|
||||
<Modal size={{ base: "xl", md: "2xl" }} {...disclosure}>
|
||||
<ModalOverlay />
|
||||
<ModalContent w={1200}>
|
||||
<ModalHeader>
|
||||
<HStack>
|
||||
<Icon as={AiOutlineDownload} />
|
||||
<Text>Export Logs</Text>
|
||||
</HStack>
|
||||
</ModalHeader>
|
||||
<ModalCloseButton />
|
||||
<ModalBody maxW="unset">
|
||||
<VStack w="full" spacing={8} pt={4} alignItems="flex-start">
|
||||
<Text>
|
||||
We'll export the <b>{selectedIds.size}</b> rows you have selected in the OpenAI
|
||||
training format.
|
||||
</Text>
|
||||
<VStack alignItems="flex-start" spacing={4}>
|
||||
<Flex
|
||||
flexDir={{ base: "column", md: "row" }}
|
||||
alignItems={{ base: "flex-start", md: "center" }}
|
||||
>
|
||||
<HStack w={48} alignItems="center" spacing={1}>
|
||||
<Text fontWeight="bold">Testing Split:</Text>
|
||||
<InfoCircle tooltipText="The percent of your logs that will be reserved for testing and saved in another file. Logs are split randomly." />
|
||||
</HStack>
|
||||
<HStack>
|
||||
<NumberInput
|
||||
defaultValue={10}
|
||||
onChange={(_, num) => setTestingSplit(num)}
|
||||
min={1}
|
||||
max={100}
|
||||
w={48}
|
||||
>
|
||||
<NumberInputField />
|
||||
<NumberInputStepper>
|
||||
<NumberIncrementStepper />
|
||||
<NumberDecrementStepper />
|
||||
</NumberInputStepper>
|
||||
</NumberInput>
|
||||
</HStack>
|
||||
</Flex>
|
||||
</VStack>
|
||||
<VStack alignItems="flex-start" spacing={0}>
|
||||
<Button
|
||||
variant="unstyled"
|
||||
color="blue.600"
|
||||
onClick={() => setShowAdvancedOptions(!showAdvancedOptions)}
|
||||
>
|
||||
<HStack>
|
||||
<Text>Advanced Options</Text>
|
||||
<Icon as={showAdvancedOptions ? FiChevronUp : FiChevronDown} />
|
||||
</HStack>
|
||||
</Button>
|
||||
<Collapse in={showAdvancedOptions} unmountOnExit={true}>
|
||||
<VStack align="stretch" pt={4}>
|
||||
<HStack>
|
||||
<Checkbox
|
||||
colorScheme="blue"
|
||||
isChecked={removeDuplicates}
|
||||
onChange={(e) => setRemoveDuplicates(e.target.checked)}
|
||||
>
|
||||
<Text>Remove duplicates</Text>
|
||||
</Checkbox>
|
||||
<InfoCircle tooltipText="To avoid overfitting and speed up training, automatically deduplicate logs with matching input and output." />
|
||||
</HStack>
|
||||
</VStack>
|
||||
</Collapse>
|
||||
</VStack>
|
||||
</VStack>
|
||||
</ModalBody>
|
||||
<ModalFooter>
|
||||
<HStack>
|
||||
<Button colorScheme="gray" onClick={disclosure.onClose} minW={24}>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button colorScheme="blue" onClick={exportData} isLoading={exportInProgress} minW={24}>
|
||||
Download
|
||||
</Button>
|
||||
</HStack>
|
||||
</ModalFooter>
|
||||
</ModalContent>
|
||||
</Modal>
|
||||
);
|
||||
};
|
||||
@@ -18,17 +18,13 @@ import {
|
||||
} from "@chakra-ui/react";
|
||||
import { AiOutlineCloudUpload, AiOutlineFile } from "react-icons/ai";
|
||||
|
||||
import { useDataset, useDatasetEntries, useHandledAsyncCallback } from "~/utils/hooks";
|
||||
import { useDataset, useHandledAsyncCallback } from "~/utils/hooks";
|
||||
import { api } from "~/utils/api";
|
||||
import ActionButton from "../ActionButton";
|
||||
import { validateTrainingRows, type TrainingRow, parseJSONL } from "./validateTrainingRows";
|
||||
import pluralize from "pluralize";
|
||||
|
||||
const ImportDataButton = () => {
|
||||
const datasetEntries = useDatasetEntries().data;
|
||||
|
||||
const numEntries = datasetEntries?.matchingEntryIds.length || 0;
|
||||
|
||||
const disclosure = useDisclosure();
|
||||
|
||||
return (
|
||||
@@ -38,7 +34,6 @@ const ImportDataButton = () => {
|
||||
label="Import Data"
|
||||
icon={AiOutlineCloudUpload}
|
||||
iconBoxSize={4}
|
||||
isDisabled={numEntries === 0}
|
||||
requireBeta
|
||||
/>
|
||||
<ImportDataModal disclosure={disclosure} />
|
||||
@@ -110,6 +105,7 @@ const ImportDataModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) =>
|
||||
|
||||
const [sendJSONL, sendingInProgress] = useHandledAsyncCallback(async () => {
|
||||
if (!dataset || !trainingRows) return;
|
||||
|
||||
await sendJSONLMutation.mutateAsync({
|
||||
datasetId: dataset.id,
|
||||
jsonl: JSON.stringify(trainingRows),
|
||||
|
||||
@@ -15,8 +15,14 @@ export const validateTrainingRows = (rows: unknown): string | null => {
|
||||
if (!Array.isArray(rows)) return "training data is not an array";
|
||||
for (let i = 0; i < rows.length; i++) {
|
||||
const row = rows[i] as TrainingRow;
|
||||
const error = validateTrainingRow(row);
|
||||
if (error) return `row ${i}: ${error}`;
|
||||
let errorMessage: string | null = null;
|
||||
try {
|
||||
errorMessage = validateTrainingRow(row);
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
} catch (error: any) {
|
||||
errorMessage = error.message;
|
||||
}
|
||||
if (errorMessage) return `row ${i + 1}: ${errorMessage}`;
|
||||
}
|
||||
|
||||
return null;
|
||||
@@ -39,7 +45,7 @@ const validateTrainingRow = (row: TrainingRow): string | null => {
|
||||
return "input contains item with function_call but no name";
|
||||
|
||||
// Validate output
|
||||
if (row.output !== undefined) {
|
||||
if (row.output) {
|
||||
if (typeof row.output !== "object") return "output is not an object";
|
||||
if (!row.output.content && !row.output.function_call)
|
||||
return "output contains no content or function_call";
|
||||
|
||||
@@ -26,6 +26,8 @@ import { useAppStore } from "~/state/store";
|
||||
import FineTuneButton from "~/components/datasets/FineTuneButton";
|
||||
import ExperimentButton from "~/components/datasets/ExperimentButton";
|
||||
import ImportDataButton from "~/components/datasets/ImportDataButton";
|
||||
import DownloadButton from "~/components/datasets/ExportButton";
|
||||
import DeleteButton from "~/components/datasets/DeleteButton";
|
||||
|
||||
export default function Dataset() {
|
||||
const utils = api.useContext();
|
||||
@@ -101,8 +103,10 @@ export default function Dataset() {
|
||||
<VStack px={8} py={8} alignItems="flex-start" spacing={4} w="full">
|
||||
<HStack w="full" justifyContent="flex-end">
|
||||
<FineTuneButton />
|
||||
<ExperimentButton />
|
||||
<ImportDataButton />
|
||||
<ExperimentButton />
|
||||
<DownloadButton />
|
||||
<DeleteButton />
|
||||
</HStack>
|
||||
<DatasetEntriesTable />
|
||||
<DatasetEntryPaginator />
|
||||
|
||||
@@ -8,6 +8,7 @@ import {
|
||||
} from "openai/resources/chat";
|
||||
import { TRPCError } from "@trpc/server";
|
||||
import { shuffle } from "lodash-es";
|
||||
import archiver from "archiver";
|
||||
|
||||
import { createTRPCRouter, protectedProcedure } from "~/server/api/trpc";
|
||||
import { prisma } from "~/server/db";
|
||||
@@ -15,6 +16,9 @@ import { requireCanModifyProject, requireCanViewProject } from "~/utils/accessCo
|
||||
import { error, success } from "~/utils/errorHandling/standardResponses";
|
||||
import { countOpenAIChatTokens } from "~/utils/countTokens";
|
||||
import { type TrainingRow, validateTrainingRows } from "~/components/datasets/validateTrainingRows";
|
||||
import hashObject from "~/server/utils/hashObject";
|
||||
import { type JsonValue } from "type-fest";
|
||||
import { WritableStreamBuffer } from "stream-buffers";
|
||||
|
||||
export const datasetEntriesRouter = createTRPCRouter({
|
||||
list: protectedProcedure
|
||||
@@ -121,6 +125,10 @@ export const datasetEntriesRouter = createTRPCRouter({
|
||||
return error("No loggedCallIds or jsonl provided");
|
||||
}
|
||||
|
||||
const startingTime = Date.now();
|
||||
|
||||
console.log("1", 0);
|
||||
|
||||
let trainingRows: TrainingRow[];
|
||||
|
||||
if (input.loggedCallIds) {
|
||||
@@ -156,11 +164,6 @@ export const datasetEntriesRouter = createTRPCRouter({
|
||||
| undefined;
|
||||
if (resp && resp.choices?.[0]) {
|
||||
output = resp.choices[0].message;
|
||||
} else {
|
||||
output = {
|
||||
role: "assistant",
|
||||
content: "",
|
||||
};
|
||||
}
|
||||
return {
|
||||
input: inputMessages as unknown as CreateChatCompletionRequestMessage[],
|
||||
@@ -168,6 +171,7 @@ export const datasetEntriesRouter = createTRPCRouter({
|
||||
};
|
||||
});
|
||||
} else {
|
||||
console.log("2", Date.now() - startingTime);
|
||||
trainingRows = JSON.parse(input.jsonl as string) as TrainingRow[];
|
||||
const validationError = validateTrainingRows(trainingRows);
|
||||
if (validationError) {
|
||||
@@ -175,6 +179,8 @@ export const datasetEntriesRouter = createTRPCRouter({
|
||||
}
|
||||
}
|
||||
|
||||
console.log("3", Date.now() - startingTime);
|
||||
|
||||
const [existingTrainingCount, existingTestingCount] = await prisma.$transaction([
|
||||
prisma.datasetEntry.count({
|
||||
where: {
|
||||
@@ -190,6 +196,8 @@ export const datasetEntriesRouter = createTRPCRouter({
|
||||
}),
|
||||
]);
|
||||
|
||||
console.log("4", Date.now() - startingTime);
|
||||
|
||||
const newTotalEntries = existingTrainingCount + existingTestingCount + trainingRows.length;
|
||||
const numTrainingToAdd = Math.floor(trainingRatio * newTotalEntries) - existingTrainingCount;
|
||||
const numTestingToAdd = trainingRows.length - numTrainingToAdd;
|
||||
@@ -208,7 +216,10 @@ export const datasetEntriesRouter = createTRPCRouter({
|
||||
datasetEntriesToCreate.push({
|
||||
datasetId: datasetId,
|
||||
input: row.input as unknown as Prisma.InputJsonValue,
|
||||
output: row.output as unknown as Prisma.InputJsonValue,
|
||||
output: (row.output as unknown as Prisma.InputJsonValue) ?? {
|
||||
role: "assistant",
|
||||
content: "",
|
||||
},
|
||||
inputTokens: countOpenAIChatTokens(
|
||||
"gpt-4-0613",
|
||||
row.input as unknown as CreateChatCompletionRequestMessage[],
|
||||
@@ -218,6 +229,8 @@ export const datasetEntriesRouter = createTRPCRouter({
|
||||
});
|
||||
}
|
||||
|
||||
console.log("5", Date.now() - startingTime);
|
||||
|
||||
// Ensure dataset and dataset entries are created atomically
|
||||
await prisma.$transaction([
|
||||
prisma.dataset.upsert({
|
||||
@@ -235,6 +248,8 @@ export const datasetEntriesRouter = createTRPCRouter({
|
||||
}),
|
||||
]);
|
||||
|
||||
console.log("6", Date.now() - startingTime);
|
||||
|
||||
return success(datasetId);
|
||||
}),
|
||||
|
||||
@@ -327,4 +342,68 @@ export const datasetEntriesRouter = createTRPCRouter({
|
||||
|
||||
return success("Dataset entries deleted");
|
||||
}),
|
||||
|
||||
export: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
datasetId: z.string(),
|
||||
datasetEntryIds: z.string().array(),
|
||||
testingSplit: z.number(),
|
||||
removeDuplicates: z.boolean(),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const { projectId } = await prisma.dataset.findUniqueOrThrow({
|
||||
where: { id: input.datasetId },
|
||||
});
|
||||
await requireCanViewProject(projectId, ctx);
|
||||
|
||||
const datasetEntries = await ctx.prisma.datasetEntry.findMany({
|
||||
where: {
|
||||
id: {
|
||||
in: input.datasetEntryIds,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
let rows: TrainingRow[] = datasetEntries.map((entry) => ({
|
||||
input: entry.input as unknown as CreateChatCompletionRequestMessage[],
|
||||
output: entry.output as unknown as CreateChatCompletionRequestMessage,
|
||||
}));
|
||||
|
||||
if (input.removeDuplicates) {
|
||||
const deduplicatedRows = [];
|
||||
const rowHashSet = new Set<string>();
|
||||
for (const row of rows) {
|
||||
const rowHash = hashObject(row as unknown as JsonValue);
|
||||
if (!rowHashSet.has(rowHash)) {
|
||||
rowHashSet.add(rowHash);
|
||||
deduplicatedRows.push(row);
|
||||
}
|
||||
}
|
||||
rows = deduplicatedRows;
|
||||
}
|
||||
|
||||
const splitIndex = Math.floor((rows.length * input.testingSplit) / 100);
|
||||
|
||||
const testingData = rows.slice(0, splitIndex);
|
||||
const trainingData = rows.slice(splitIndex);
|
||||
|
||||
// Convert arrays to JSONL format
|
||||
const trainingDataJSONL = trainingData.map((item) => JSON.stringify(item)).join("\n");
|
||||
const testingDataJSONL = testingData.map((item) => JSON.stringify(item)).join("\n");
|
||||
|
||||
const output = new WritableStreamBuffer();
|
||||
const archive = archiver("zip");
|
||||
|
||||
archive.pipe(output);
|
||||
archive.append(trainingDataJSONL, { name: "train.jsonl" });
|
||||
archive.append(testingDataJSONL, { name: "test.jsonl" });
|
||||
await archive.finalize();
|
||||
|
||||
// Convert buffer to base64
|
||||
const base64 = output.getContents().toString("base64");
|
||||
|
||||
return base64;
|
||||
}),
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user