Add upload data modal

This commit is contained in:
David Corbitt
2023-09-06 09:24:41 -07:00
parent c5c8dbf65e
commit 7c4ab151a4
8 changed files with 387 additions and 42 deletions

View File

@@ -0,0 +1,5 @@
-- AlterTable
-- Relaxes constraints on "DatasetEntry": "loggedCallId" becomes nullable, and
-- the column defaults on "inputTokens", "outputTokens" and "type" are removed,
-- so the database no longer fills those three values in on insert.
ALTER TABLE "DatasetEntry" ALTER COLUMN "loggedCallId" DROP NOT NULL,
ALTER COLUMN "inputTokens" DROP DEFAULT,
ALTER COLUMN "outputTokens" DROP DEFAULT,
ALTER COLUMN "type" DROP DEFAULT;

View File

@@ -199,8 +199,8 @@ enum DatasetEntryType {
model DatasetEntry { model DatasetEntry {
id String @id @default(uuid()) @db.Uuid id String @id @default(uuid()) @db.Uuid
loggedCallId String @db.Uuid loggedCallId String? @db.Uuid
loggedCall LoggedCall @relation(fields: [loggedCallId], references: [id], onDelete: Cascade) loggedCall LoggedCall? @relation(fields: [loggedCallId], references: [id], onDelete: Cascade)
input Json @default("[]") input Json @default("[]")
output Json? output Json?

View File

@@ -7,12 +7,14 @@ import { BetaModal } from "./BetaModal";
const ActionButton = ({ const ActionButton = ({
icon, icon,
iconBoxSize = 3.5,
label, label,
requireBeta = false, requireBeta = false,
onClick, onClick,
...buttonProps ...buttonProps
}: { }: {
icon: IconType; icon: IconType;
iconBoxSize?: number;
label: string; label: string;
requireBeta?: boolean; requireBeta?: boolean;
onClick?: () => void; onClick?: () => void;
@@ -39,7 +41,9 @@ const ActionButton = ({
{...buttonProps} {...buttonProps}
> >
<HStack spacing={1}> <HStack spacing={1}>
{icon && <Icon as={icon} color={requireBeta ? "orange.400" : undefined} />} {icon && (
<Icon as={icon} boxSize={iconBoxSize} color={requireBeta ? "orange.400" : undefined} />
)}
<Text display={{ base: "none", md: "flex" }}>{label}</Text> <Text display={{ base: "none", md: "flex" }}>{label}</Text>
</HStack> </HStack>
</Button> </Button>

View File

@@ -22,7 +22,6 @@ import { useRouter } from "next/router";
import { useDataset, useDatasetEntries, useHandledAsyncCallback } from "~/utils/hooks"; import { useDataset, useDatasetEntries, useHandledAsyncCallback } from "~/utils/hooks";
import { api } from "~/utils/api"; import { api } from "~/utils/api";
import { useAppStore } from "~/state/store";
import ActionButton from "../ActionButton"; import ActionButton from "../ActionButton";
import InputDropdown from "../InputDropdown"; import InputDropdown from "../InputDropdown";
import { FiChevronDown } from "react-icons/fi"; import { FiChevronDown } from "react-icons/fi";
@@ -53,7 +52,6 @@ const FineTuneButton = () => {
export default FineTuneButton; export default FineTuneButton;
const FineTuneModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) => { const FineTuneModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) => {
const selectedProjectId = useAppStore((s) => s.selectedProjectId);
const dataset = useDataset().data; const dataset = useDataset().data;
const datasetEntries = useDatasetEntries().data; const datasetEntries = useDatasetEntries().data;
@@ -73,7 +71,7 @@ const FineTuneModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) => {
const createFineTuneMutation = api.fineTunes.create.useMutation(); const createFineTuneMutation = api.fineTunes.create.useMutation();
const [createFineTune, creationInProgress] = useHandledAsyncCallback(async () => { const [createFineTune, creationInProgress] = useHandledAsyncCallback(async () => {
if (!selectedProjectId || !modelSlug || !selectedBaseModel || !dataset) return; if (!modelSlug || !selectedBaseModel || !dataset) return;
await createFineTuneMutation.mutateAsync({ await createFineTuneMutation.mutateAsync({
slug: modelSlug, slug: modelSlug,
baseModel: selectedBaseModel, baseModel: selectedBaseModel,
@@ -83,7 +81,7 @@ const FineTuneModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) => {
await utils.fineTunes.list.invalidate(); await utils.fineTunes.list.invalidate();
await router.push({ pathname: "/fine-tunes" }); await router.push({ pathname: "/fine-tunes" });
disclosure.onClose(); disclosure.onClose();
}, [createFineTuneMutation, selectedProjectId, modelSlug, selectedBaseModel]); }, [createFineTuneMutation, modelSlug, selectedBaseModel]);
return ( return (
<Modal size={{ base: "xl", md: "2xl" }} {...disclosure}> <Modal size={{ base: "xl", md: "2xl" }} {...disclosure}>

View File

@@ -0,0 +1,249 @@
import { useState, useEffect, useRef } from "react";
import {
Modal,
ModalOverlay,
ModalContent,
ModalHeader,
ModalCloseButton,
ModalBody,
ModalFooter,
HStack,
VStack,
Icon,
Text,
Button,
Box,
useDisclosure,
type UseDisclosureReturn,
} from "@chakra-ui/react";
import { AiOutlineCloudUpload, AiOutlineFile } from "react-icons/ai";
import { useDataset, useDatasetEntries, useHandledAsyncCallback } from "~/utils/hooks";
import { api } from "~/utils/api";
import ActionButton from "../ActionButton";
import { validateTrainingRows, type TrainingRow, parseJSONL } from "./validateTrainingRows";
import pluralize from "pluralize";
/**
 * Toolbar button labelled "Import Data" that opens {@link ImportDataModal}.
 *
 * The button is flagged as a beta feature and is disabled while the current
 * dataset entry query reports no matching entries.
 */
const ImportDataButton = () => {
  const { data: entries } = useDatasetEntries();
  const disclosure = useDisclosure();

  // NOTE(review): disabling the import action when the dataset has no entries
  // prevents seeding an empty dataset from a file. This looks inherited from
  // the fine-tune button; confirm it is intended.
  const entryCount = entries?.matchingEntryIds.length || 0;
  const hasNoEntries = entryCount === 0;

  return (
    <>
      <ActionButton
        label="Import Data"
        icon={AiOutlineCloudUpload}
        iconBoxSize={4}
        requireBeta
        isDisabled={hasNoEntries}
        onClick={disclosure.onOpen}
      />
      <ImportDataModal disclosure={disclosure} />
    </>
  );
};
export default ImportDataButton;
/**
 * Modal that lets the user pick or drag-and-drop a `.jsonl` file, parses and
 * validates it in the browser and, on confirmation, sends the rows to the
 * `datasetEntries.create` mutation for the current dataset.
 *
 * The body renders exactly one of three panels:
 *  - the drop zone, while no rows are loaded and there is no error;
 *  - the error panel, when reading, parsing or validation failed;
 *  - the success panel, when validated rows are ready to upload.
 *
 * @param disclosure - Chakra disclosure controlling the modal's open state.
 */
const ImportDataModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) => {
  const dataset = useDataset().data;

  const [validationError, setValidationError] = useState<string | null>(null);
  const [trainingRows, setTrainingRows] = useState<TrainingRow[] | null>(null);

  const fileInputRef = useRef<HTMLInputElement>(null);

  // Reads the file as text, then parses and validates it. Any failure ends up
  // in `validationError`; success ends up in `trainingRows`.
  const processFile = (file: File) => {
    const reader = new FileReader();

    reader.onload = (e: ProgressEvent<FileReader>) => {
      // readAsText always yields a string result.
      const content = e.target?.result as string;
      try {
        const parsedJSONL = parseJSONL(content) as TrainingRow[];
        const rowsError = validateTrainingRows(parsedJSONL);
        if (rowsError) {
          setValidationError(rowsError);
          setTrainingRows(null);
          return;
        }
        setTrainingRows(parsedJSONL);
      } catch (err: unknown) {
        // Narrow instead of assuming the thrown value has a `message`.
        const reason = err instanceof Error ? err.message : String(err);
        setValidationError("Unable to parse JSONL file: " + reason);
        setTrainingRows(null);
      }
    };

    // Without this handler a failed read would leave the drop zone unchanged
    // and give the user no feedback at all.
    reader.onerror = () => {
      setValidationError("Unable to read file");
      setTrainingRows(null);
    };

    reader.readAsText(file);
  };

  const handleFileDrop = (e: React.DragEvent<HTMLDivElement>) => {
    // Prevent the browser from navigating to the dropped file.
    e.preventDefault();
    const files = e.dataTransfer.files;
    if (files.length > 0) {
      processFile(files[0] as File);
    }
  };

  const handleFileChange = (e: React.ChangeEvent<HTMLInputElement>) => {
    const files = e.target.files;
    if (files && files.length > 0) {
      processFile(files[0] as File);
    }
  };

  // Start from a clean slate every time the modal is opened.
  useEffect(() => {
    if (disclosure.isOpen) {
      setTrainingRows(null);
      setValidationError(null);
    }
  }, [disclosure.isOpen]);

  const utils = api.useContext();

  const sendJSONLMutation = api.datasetEntries.create.useMutation();

  const [sendJSONL, sendingInProgress] = useHandledAsyncCallback(async () => {
    if (!dataset || !trainingRows) return;

    await sendJSONLMutation.mutateAsync({
      datasetId: dataset.id,
      jsonl: JSON.stringify(trainingRows),
    });

    // Refresh the entries list so the newly created rows show up.
    await utils.datasetEntries.list.invalidate();
    disclosure.onClose();
  }, [dataset, trainingRows, sendJSONLMutation]);

  return (
    <Modal size={{ base: "xl", md: "2xl" }} {...disclosure}>
      <ModalOverlay />
      <ModalContent w={1200}>
        <ModalHeader>
          <HStack>
            <Text>Upload Training Logs</Text>
          </HStack>
        </ModalHeader>
        <ModalCloseButton />
        <ModalBody maxW="unset" p={8}>
          <Box w="full" aspectRatio={1.5}>
            {!trainingRows && !validationError && (
              <VStack
                w="full"
                h="full"
                stroke="gray.300"
                justifyContent="center"
                borderRadius={8}
                sx={{
                  "background-image": `url("data:image/svg+xml,%3csvg width='100%25' height='100%25' xmlns='http://www.w3.org/2000/svg'%3e%3crect x='2%25' y='2%25' width='96%25' height='96%25' fill='none' stroke='%23eee' stroke-width='4' stroke-dasharray='6%2c 14' stroke-dashoffset='0' stroke-linecap='square' rx='8' ry='8'/%3e%3c/svg%3e")`,
                }}
                onDragOver={(e) => e.preventDefault()}
                onDrop={handleFileDrop}
              >
                <JsonFileIcon />
                <Icon as={AiOutlineCloudUpload} boxSize={24} color="gray.300" />
                <Text fontSize={32} color="gray.500" fontWeight="bold">
                  Drag & Drop
                </Text>
                <Text color="gray.500">
                  your .jsonl file here, or{" "}
                  <input
                    type="file"
                    ref={fileInputRef}
                    onChange={handleFileChange}
                    style={{ display: "none" }}
                    accept=".jsonl"
                  />
                  <Text
                    as="span"
                    textDecor="underline"
                    _hover={{ color: "orange.400" }}
                    cursor="pointer"
                    onClick={() => fileInputRef.current?.click()}
                  >
                    browse
                  </Text>
                </Text>
              </VStack>
            )}
            {validationError && (
              <VStack w="full" h="full" justifyContent="center" spacing={8}>
                <Icon as={AiOutlineFile} boxSize={24} color="gray.300" />
                <VStack w="full">
                  <Text fontSize={32} color="gray.500" fontWeight="bold">
                    Error
                  </Text>
                  <Text color="gray.500">{validationError}</Text>
                </VStack>
                <Text
                  as="span"
                  textDecor="underline"
                  color="gray.500"
                  _hover={{ color: "orange.400" }}
                  cursor="pointer"
                  onClick={() => setValidationError(null)}
                >
                  Try again
                </Text>
              </VStack>
            )}
            {trainingRows && !validationError && (
              <VStack w="full" h="full" justifyContent="center" spacing={8}>
                <JsonFileIcon />
                <VStack w="full">
                  <Text fontSize={32} color="gray.500" fontWeight="bold">
                    Success
                  </Text>
                  <Text color="gray.500">
                    We'll upload <b>{trainingRows.length}</b>{" "}
                    {pluralize("row", trainingRows.length)} into <b>{dataset?.name}</b>.{" "}
                  </Text>
                </VStack>
                <Text
                  as="span"
                  textDecor="underline"
                  color="gray.500"
                  _hover={{ color: "orange.400" }}
                  cursor="pointer"
                  onClick={() => setTrainingRows(null)}
                >
                  Change file
                </Text>
              </VStack>
            )}
          </Box>
        </ModalBody>
        <ModalFooter>
          <HStack>
            <Button colorScheme="gray" onClick={disclosure.onClose} minW={24}>
              Cancel
            </Button>
            <Button
              colorScheme="orange"
              onClick={sendJSONL}
              isLoading={sendingInProgress}
              minW={24}
              isDisabled={!trainingRows || !!validationError}
            >
              Upload
            </Button>
          </HStack>
        </ModalFooter>
      </ModalContent>
    </Modal>
  );
};
/** File glyph with an orange "JSONL" caption overlaid on it. */
const JsonFileIcon = () => {
  return (
    <Box display="flex" position="relative" justifyContent="center" alignItems="center">
      <Icon boxSize={24} color="gray.300" as={AiOutlineFile} />
      {/* Absolutely positioned so the caption sits on top of the icon. */}
      <Text pt={4} fontSize={12} fontWeight="bold" color="orange.400" position="absolute">
        JSONL
      </Text>
    </Box>
  );
};

View File

@@ -0,0 +1,53 @@
import { type CreateChatCompletionRequestMessage } from "openai/resources/chat";
/**
 * One row of an uploaded JSONL training file, in the shape checked by
 * `validateTrainingRows`.
 */
export type TrainingRow = {
  /** Chat messages that make up the row's input. */
  input: CreateChatCompletionRequestMessage[];
  /** Optional single chat message recorded as the row's output. */
  output?: CreateChatCompletionRequestMessage;
};
/**
 * Parses a JSON Lines document into an array with one parsed value per line.
 *
 * Blank or whitespace-only lines (including those produced by CRLF line
 * endings or a trailing newline) are skipped, so a stray empty line in the
 * middle of the file no longer aborts the import.
 *
 * @param jsonlString - Raw file contents, one JSON value per line.
 * @returns The parsed values, in file order. They are NOT validated.
 * @throws SyntaxError when a line is not valid JSON (the message is prefixed
 *   with the 1-based line number) or when the input contains no JSON lines.
 */
export const parseJSONL = (jsonlString: string): unknown[] => {
  const rows: unknown[] = [];

  jsonlString.split("\n").forEach((line, index) => {
    // trim() also strips the "\r" left behind by CRLF line endings.
    if (line.trim() === "") return;
    try {
      rows.push(JSON.parse(line) as unknown);
    } catch (e: unknown) {
      const reason = e instanceof Error ? e.message : String(e);
      // Keep throwing SyntaxError (as JSON.parse does) but say which line failed.
      throw new SyntaxError(`line ${index + 1}: ${reason}`);
    }
  });

  // Empty input has always been rejected; keep that contract with a clearer message.
  if (rows.length === 0) throw new SyntaxError("file contains no JSON lines");

  return rows;
};
/**
 * Validates a parsed training file.
 *
 * @param rows - Value to check; expected to be an array of training rows.
 * @returns `null` when every row is valid, otherwise a message describing the
 *   first problem found, prefixed with the zero-based index of the offending row.
 */
export const validateTrainingRows = (rows: unknown): string | null => {
  if (!Array.isArray(rows)) return "training data is not an array";

  for (const [index, candidate] of (rows as unknown[]).entries()) {
    const rowError = validateTrainingRow(candidate as TrainingRow);
    // Stop at the first invalid row.
    if (rowError) return `row ${index}: ${rowError}`;
  }

  return null;
};
/**
 * Validates a single training row.
 *
 * Checks are applied in a fixed order and the first failure wins: the row
 * itself, then each check across all `input` messages, then `output`.
 *
 * @param row - Candidate row, typically straight from `JSON.parse`.
 * @returns `null` when the row is valid, otherwise a short description of the
 *   first problem found.
 */
const validateTrainingRow = (row: TrainingRow): string | null => {
  if (!row) return "empty row";
  if (!row.input) return "missing input";

  // Validate input
  if (!Array.isArray(row.input)) return "input is not an array";
  if ((row.input as unknown[]).some((x) => typeof x !== "object"))
    return "input contains invalid item";
  // `typeof null` is "object", so null items are only caught by this check.
  if (row.input.some((x) => !x)) return "input contains empty item";
  if (row.input.some((x) => !x.content && !x.function_call))
    return "input contains item with no content or function_call";
  if (row.input.some((x) => x.function_call && !x.function_call.arguments))
    return "input contains item with function_call but no arguments";
  if (row.input.some((x) => x.function_call && !x.function_call.name))
    return "input contains item with function_call but no name";

  // Validate output
  if (row.output !== undefined) {
    // A JSON `null` output passes a bare `typeof ... === "object"` test and
    // would then throw on property access, so reject it explicitly.
    if (typeof row.output !== "object" || row.output === null) return "output is not an object";
    if (!row.output.content && !row.output.function_call)
      return "output contains no content or function_call";
    if (row.output.function_call && !row.output.function_call.arguments)
      return "output contains function_call but no arguments";
    if (row.output.function_call && !row.output.function_call.name)
      return "output contains function_call but no name";
  }

  return null;
};

View File

@@ -25,6 +25,7 @@ import DatasetEntryPaginator from "~/components/datasets/DatasetEntryPaginator";
import { useAppStore } from "~/state/store"; import { useAppStore } from "~/state/store";
import FineTuneButton from "~/components/datasets/FineTuneButton"; import FineTuneButton from "~/components/datasets/FineTuneButton";
import ExperimentButton from "~/components/datasets/ExperimentButton"; import ExperimentButton from "~/components/datasets/ExperimentButton";
import ImportDataButton from "~/components/datasets/ImportDataButton";
export default function Dataset() { export default function Dataset() {
const utils = api.useContext(); const utils = api.useContext();
@@ -101,6 +102,7 @@ export default function Dataset() {
<HStack w="full" justifyContent="flex-end"> <HStack w="full" justifyContent="flex-end">
<FineTuneButton /> <FineTuneButton />
<ExperimentButton /> <ExperimentButton />
<ImportDataButton />
</HStack> </HStack>
<DatasetEntriesTable /> <DatasetEntriesTable />
<DatasetEntryPaginator /> <DatasetEntryPaginator />

View File

@@ -14,6 +14,7 @@ import { prisma } from "~/server/db";
import { requireCanModifyProject, requireCanViewProject } from "~/utils/accessControl"; import { requireCanModifyProject, requireCanViewProject } from "~/utils/accessControl";
import { error, success } from "~/utils/errorHandling/standardResponses"; import { error, success } from "~/utils/errorHandling/standardResponses";
import { countOpenAIChatTokens } from "~/utils/countTokens"; import { countOpenAIChatTokens } from "~/utils/countTokens";
import { type TrainingRow, validateTrainingRows } from "~/components/datasets/validateTrainingRows";
export const datasetEntriesRouter = createTRPCRouter({ export const datasetEntriesRouter = createTRPCRouter({
list: protectedProcedure list: protectedProcedure
@@ -94,7 +95,8 @@ export const datasetEntriesRouter = createTRPCRouter({
name: z.string(), name: z.string(),
}) })
.optional(), .optional(),
loggedCallIds: z.string().array(), loggedCallIds: z.string().array().optional(),
jsonl: z.string().optional(),
}), }),
) )
.mutation(async ({ input, ctx }) => { .mutation(async ({ input, ctx }) => {
@@ -115,8 +117,14 @@ export const datasetEntriesRouter = createTRPCRouter({
return error("No datasetId or newDatasetParams provided"); return error("No datasetId or newDatasetParams provided");
} }
const [loggedCalls, existingTrainingCount, existingTestingCount] = await prisma.$transaction([ if (!input.loggedCallIds && !input.jsonl) {
prisma.loggedCall.findMany({ return error("No loggedCallIds or jsonl provided");
}
let trainingRows: TrainingRow[];
if (input.loggedCallIds) {
const loggedCalls = await prisma.loggedCall.findMany({
where: { where: {
id: { id: {
in: input.loggedCallIds, in: input.loggedCallIds,
@@ -135,7 +143,39 @@ export const datasetEntriesRouter = createTRPCRouter({
}, },
}, },
}, },
}), orderBy: { createdAt: "desc" },
});
trainingRows = loggedCalls.map((loggedCall) => {
const inputMessages = (
loggedCall.modelResponse?.reqPayload as unknown as CompletionCreateParams
).messages;
let output: ChatCompletion.Choice.Message | undefined = undefined;
const resp = loggedCall.modelResponse?.respPayload as unknown as
| ChatCompletion
| undefined;
if (resp && resp.choices?.[0]) {
output = resp.choices[0].message;
} else {
output = {
role: "assistant",
content: "",
};
}
return {
input: inputMessages as unknown as CreateChatCompletionRequestMessage[],
output: output as unknown as CreateChatCompletionRequestMessage,
};
});
} else {
trainingRows = JSON.parse(input.jsonl as string) as TrainingRow[];
const validationError = validateTrainingRows(trainingRows);
if (validationError) {
return error(`Invalid JSONL: ${validationError}`);
}
}
const [existingTrainingCount, existingTestingCount] = await prisma.$transaction([
prisma.datasetEntry.count({ prisma.datasetEntry.count({
where: { where: {
datasetId, datasetId,
@@ -150,39 +190,32 @@ export const datasetEntriesRouter = createTRPCRouter({
}), }),
]); ]);
const shuffledLoggedCalls = shuffle(loggedCalls); const newTotalEntries = existingTrainingCount + existingTestingCount + trainingRows.length;
const numTrainingToAdd = Math.floor(trainingRatio * newTotalEntries) - existingTrainingCount;
const totalEntries = existingTrainingCount + existingTestingCount + loggedCalls.length; const numTestingToAdd = trainingRows.length - numTrainingToAdd;
const numTrainingToAdd = Math.floor(trainingRatio * totalEntries) - existingTrainingCount; const typesToAssign = shuffle([
...Array(numTrainingToAdd).fill("TRAIN"),
...Array(numTestingToAdd).fill("TEST"),
]);
const datasetEntriesToCreate: Prisma.DatasetEntryCreateManyInput[] = []; const datasetEntriesToCreate: Prisma.DatasetEntryCreateManyInput[] = [];
for (const row of trainingRows) {
let i = 0; let outputTokens = 0;
for (const loggedCall of shuffledLoggedCalls) { if (row.output) {
const inputMessages = ( outputTokens = countOpenAIChatTokens("gpt-4-0613", [
loggedCall.modelResponse?.reqPayload as unknown as CompletionCreateParams row.output as unknown as ChatCompletion.Choice.Message,
).messages; ]);
let output: ChatCompletion.Choice.Message | undefined = undefined;
const resp = loggedCall.modelResponse?.respPayload as unknown as ChatCompletion | undefined;
if (resp && resp.choices?.[0]) {
output = resp.choices[0].message;
} else {
output = {
role: "assistant",
content: "",
};
} }
datasetEntriesToCreate.push({ datasetEntriesToCreate.push({
datasetId, datasetId: datasetId,
loggedCallId: loggedCall.id, input: row.input as unknown as Prisma.InputJsonValue,
input: inputMessages as unknown as Prisma.InputJsonValue, output: row.output as unknown as Prisma.InputJsonValue,
output: output as unknown as Prisma.InputJsonValue, inputTokens: countOpenAIChatTokens(
inputTokens: loggedCall.modelResponse?.inputTokens || 0, "gpt-4-0613",
outputTokens: loggedCall.modelResponse?.outputTokens || 0, row.input as unknown as CreateChatCompletionRequestMessage[],
type: i < numTrainingToAdd ? "TRAIN" : "TEST", ),
outputTokens,
type: typesToAssign.pop() as "TRAIN" | "TEST",
}); });
i++;
} }
// Ensure dataset and dataset entries are created atomically // Ensure dataset and dataset entries are created atomically
@@ -198,7 +231,7 @@ export const datasetEntriesRouter = createTRPCRouter({
}, },
}), }),
prisma.datasetEntry.createMany({ prisma.datasetEntry.createMany({
data: shuffle(datasetEntriesToCreate), data: datasetEntriesToCreate,
}), }),
]); ]);
@@ -242,7 +275,8 @@ export const datasetEntriesRouter = createTRPCRouter({
let parsedOutput = undefined; let parsedOutput = undefined;
let outputTokens = undefined; let outputTokens = undefined;
if (input.updates.output) { // The client might send "null" as a string, so we need to check for that
if (input.updates.output && input.updates.output !== "null") {
parsedOutput = JSON.parse(input.updates.output); parsedOutput = JSON.parse(input.updates.output);
outputTokens = countOpenAIChatTokens("gpt-4-0613", [ outputTokens = countOpenAIChatTokens("gpt-4-0613", [
parsedOutput as unknown as ChatCompletion.Choice.Message, parsedOutput as unknown as ChatCompletion.Choice.Message,