From 7c4ab151a4ff592ec93084f80a837716dfea80b5 Mon Sep 17 00:00:00 2001 From: David Corbitt Date: Wed, 6 Sep 2023 09:24:41 -0700 Subject: [PATCH] Add upload data modal --- .../migration.sql | 5 + app/prisma/schema.prisma | 4 +- app/src/components/ActionButton.tsx | 6 +- .../components/datasets/FineTuneButton.tsx | 6 +- .../components/datasets/ImportDataButton.tsx | 249 ++++++++++++++++++ .../datasets/validateTrainingRows.ts | 53 ++++ app/src/pages/datasets/[id].tsx | 2 + .../api/routers/datasetEntries.router.ts | 104 +++++--- 8 files changed, 387 insertions(+), 42 deletions(-) create mode 100644 app/prisma/migrations/20230906052922_make_dataset_entry_logged_call_optional/migration.sql create mode 100644 app/src/components/datasets/ImportDataButton.tsx create mode 100644 app/src/components/datasets/validateTrainingRows.ts diff --git a/app/prisma/migrations/20230906052922_make_dataset_entry_logged_call_optional/migration.sql b/app/prisma/migrations/20230906052922_make_dataset_entry_logged_call_optional/migration.sql new file mode 100644 index 0000000..bbccc3b --- /dev/null +++ b/app/prisma/migrations/20230906052922_make_dataset_entry_logged_call_optional/migration.sql @@ -0,0 +1,5 @@ +-- AlterTable +ALTER TABLE "DatasetEntry" ALTER COLUMN "loggedCallId" DROP NOT NULL, +ALTER COLUMN "inputTokens" DROP DEFAULT, +ALTER COLUMN "outputTokens" DROP DEFAULT, +ALTER COLUMN "type" DROP DEFAULT; diff --git a/app/prisma/schema.prisma b/app/prisma/schema.prisma index acf3046..feaf327 100644 --- a/app/prisma/schema.prisma +++ b/app/prisma/schema.prisma @@ -199,8 +199,8 @@ enum DatasetEntryType { model DatasetEntry { id String @id @default(uuid()) @db.Uuid - loggedCallId String @db.Uuid - loggedCall LoggedCall @relation(fields: [loggedCallId], references: [id], onDelete: Cascade) + loggedCallId String? @db.Uuid + loggedCall LoggedCall? @relation(fields: [loggedCallId], references: [id], onDelete: Cascade) input Json @default("[]") output Json? 
diff --git a/app/src/components/ActionButton.tsx b/app/src/components/ActionButton.tsx index 6822e02..d966461 100644 --- a/app/src/components/ActionButton.tsx +++ b/app/src/components/ActionButton.tsx @@ -7,12 +7,14 @@ import { BetaModal } from "./BetaModal"; const ActionButton = ({ icon, + iconBoxSize = 3.5, label, requireBeta = false, onClick, ...buttonProps }: { icon: IconType; + iconBoxSize?: number; label: string; requireBeta?: boolean; onClick?: () => void; @@ -39,7 +41,9 @@ const ActionButton = ({ {...buttonProps} > - {icon && } + {icon && ( + + )} {label} diff --git a/app/src/components/datasets/FineTuneButton.tsx b/app/src/components/datasets/FineTuneButton.tsx index 80bcfd1..dfd1b30 100644 --- a/app/src/components/datasets/FineTuneButton.tsx +++ b/app/src/components/datasets/FineTuneButton.tsx @@ -22,7 +22,6 @@ import { useRouter } from "next/router"; import { useDataset, useDatasetEntries, useHandledAsyncCallback } from "~/utils/hooks"; import { api } from "~/utils/api"; -import { useAppStore } from "~/state/store"; import ActionButton from "../ActionButton"; import InputDropdown from "../InputDropdown"; import { FiChevronDown } from "react-icons/fi"; @@ -53,7 +52,6 @@ const FineTuneButton = () => { export default FineTuneButton; const FineTuneModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) => { - const selectedProjectId = useAppStore((s) => s.selectedProjectId); const dataset = useDataset().data; const datasetEntries = useDatasetEntries().data; @@ -73,7 +71,7 @@ const FineTuneModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) => { const createFineTuneMutation = api.fineTunes.create.useMutation(); const [createFineTune, creationInProgress] = useHandledAsyncCallback(async () => { - if (!selectedProjectId || !modelSlug || !selectedBaseModel || !dataset) return; + if (!modelSlug || !selectedBaseModel || !dataset) return; await createFineTuneMutation.mutateAsync({ slug: modelSlug, baseModel: selectedBaseModel, @@ -83,7 +81,7 @@ const 
FineTuneModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) => { await utils.fineTunes.list.invalidate(); await router.push({ pathname: "/fine-tunes" }); disclosure.onClose(); - }, [createFineTuneMutation, selectedProjectId, modelSlug, selectedBaseModel]); + }, [createFineTuneMutation, modelSlug, selectedBaseModel]); return ( diff --git a/app/src/components/datasets/ImportDataButton.tsx b/app/src/components/datasets/ImportDataButton.tsx new file mode 100644 index 0000000..14c8e10 --- /dev/null +++ b/app/src/components/datasets/ImportDataButton.tsx @@ -0,0 +1,249 @@ +import { useState, useEffect, useRef } from "react"; +import { + Modal, + ModalOverlay, + ModalContent, + ModalHeader, + ModalCloseButton, + ModalBody, + ModalFooter, + HStack, + VStack, + Icon, + Text, + Button, + Box, + useDisclosure, + type UseDisclosureReturn, +} from "@chakra-ui/react"; +import { AiOutlineCloudUpload, AiOutlineFile } from "react-icons/ai"; + +import { useDataset, useDatasetEntries, useHandledAsyncCallback } from "~/utils/hooks"; +import { api } from "~/utils/api"; +import ActionButton from "../ActionButton"; +import { validateTrainingRows, type TrainingRow, parseJSONL } from "./validateTrainingRows"; +import pluralize from "pluralize"; + +const ImportDataButton = () => { + const datasetEntries = useDatasetEntries().data; + + const numEntries = datasetEntries?.matchingEntryIds.length || 0; + + const disclosure = useDisclosure(); + + return ( + <> + + + + ); +}; + +export default ImportDataButton; + +const ImportDataModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) => { + const dataset = useDataset().data; + + const [validationError, setValidationError] = useState(null); + const [trainingRows, setTrainingRows] = useState(null); + + const fileInputRef = useRef(null); + + const handleFileDrop = (e: React.DragEvent) => { + e.preventDefault(); + const files = e.dataTransfer.files; + if (files.length > 0) { + processFile(files[0] as File); + } + }; + + const 
handleFileChange = (e: React.ChangeEvent) => { + const files = e.target.files; + if (files && files.length > 0) { + processFile(files[0] as File); + } + }; + + const processFile = (file: File) => { + const reader = new FileReader(); + reader.onload = (e: ProgressEvent) => { + const content = e.target?.result as string; + // Process the content, e.g., set to state + let parsedJSONL; + try { + parsedJSONL = parseJSONL(content) as TrainingRow[]; + const validationError = validateTrainingRows(parsedJSONL); + if (validationError) { + setValidationError(validationError); + setTrainingRows(null); + return; + } + setTrainingRows(parsedJSONL); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } catch (e: any) { + console.log("e is", e); + setValidationError("Unable to parse JSONL file: " + (e.message as string)); + setTrainingRows(null); + return; + } + }; + reader.readAsText(file); + }; + + useEffect(() => { + if (disclosure.isOpen) { + setTrainingRows(null); + setValidationError(null); + } + }, [disclosure.isOpen]); + + const utils = api.useContext(); + + const sendJSONLMutation = api.datasetEntries.create.useMutation(); + + const [sendJSONL, sendingInProgress] = useHandledAsyncCallback(async () => { + if (!dataset || !trainingRows) return; + await sendJSONLMutation.mutateAsync({ + datasetId: dataset.id, + jsonl: JSON.stringify(trainingRows), + }); + + await utils.datasetEntries.list.invalidate(); + disclosure.onClose(); + }, [dataset, trainingRows, sendJSONLMutation]); + + return ( + + + + + + Upload Training Logs + + + + + + {!trainingRows && !validationError && ( + e.preventDefault()} + onDrop={handleFileDrop} + > + + + + + Drag & Drop + + + your .jsonl file here, or{" "} + + fileInputRef.current?.click()} + > + browse + + + + )} + {validationError && ( + + + + + Error + + {validationError} + + setValidationError(null)} + > + Try again + + + )} + {trainingRows && !validationError && ( + + + + + Success + + + We'll upload {trainingRows.length}{" "} + 
{pluralize("row", trainingRows.length)} into {dataset?.name}.{" "} + + setTrainingRows(null)} + > + Change file + + + )} + + + + + + + + + + + ); +}; + +const JsonFileIcon = () => ( + + + + JSONL + + +);
diff --git a/app/src/components/datasets/validateTrainingRows.ts b/app/src/components/datasets/validateTrainingRows.ts
new file mode 100644
index 0000000..c78718d
--- /dev/null
+++ b/app/src/components/datasets/validateTrainingRows.ts
@@ -0,0 +1,84 @@
+import { type CreateChatCompletionRequestMessage } from "openai/resources/chat";
+
+/** One fine-tuning example: the prompt messages plus an optional expected reply. */
+export type TrainingRow = {
+  input: CreateChatCompletionRequestMessage[];
+  output?: CreateChatCompletionRequestMessage;
+};
+
+/**
+ * Parses JSONL text (one JSON value per line) into an array of raw values.
+ *
+ * Blank lines are skipped, so a stray empty line (or a "\r" left over from CRLF
+ * line endings) no longer aborts the whole import.
+ *
+ * @param jsonlString - Raw contents of the .jsonl file.
+ * @returns The parsed value of every non-blank line, in file order.
+ * @throws SyntaxError if a line is not valid JSON (the message is prefixed with
+ *   its 1-based line number) or if the text contains no rows at all.
+ */
+export const parseJSONL = (jsonlString: string): unknown[] => {
+  const rows: unknown[] = [];
+  jsonlString.split(/\r?\n/).forEach((line, index) => {
+    if (line.trim() === "") return;
+    try {
+      rows.push(JSON.parse(line) as unknown);
+    } catch (e: unknown) {
+      const reason = e instanceof Error ? e.message : String(e);
+      throw new SyntaxError(`line ${index + 1}: ${reason}`);
+    }
+  });
+  // An empty file used to fail inside JSON.parse(""); keep rejecting it, but say why.
+  if (rows.length === 0) throw new SyntaxError("file contains no rows");
+  return rows;
+};
+
+/**
+ * Validates parsed training data before it is uploaded or stored.
+ *
+ * @param rows - Untrusted value, e.g. the result of parseJSONL or JSON.parse.
+ * @returns null when every row is valid; otherwise a message describing the first
+ *   invalid row, prefixed with its zero-based index.
+ */
+export const validateTrainingRows = (rows: unknown): string | null => {
+  if (!Array.isArray(rows)) return "training data is not an array";
+  for (let i = 0; i < rows.length; i++) {
+    const error = validateTrainingRow(rows[i] as TrainingRow);
+    if (error) return `row ${i}: ${error}`;
+  }
+
+  return null;
+};
+
+/** Checks one row and returns a description of its first problem, or null if it is valid. */
+const validateTrainingRow = (row: TrainingRow): string | null => {
+  if (!row) return "empty row";
+  if (!row.input) return "missing input";
+
+  // Validate input
+  if (!Array.isArray(row.input)) return "input is not an array";
+  if ((row.input as unknown[]).some((x) => typeof x !== "object"))
+    return "input contains invalid item";
+  if (row.input.some((x) => !x)) return "input contains empty item";
+  if (row.input.some((x) => !x.content && !x.function_call))
+    return "input contains item with no content or function_call";
+  if (row.input.some((x) => x.function_call && !x.function_call.arguments))
+    return "input contains item with function_call but no arguments";
+  if (row.input.some((x) => x.function_call && !x.function_call.name))
+    return "input contains item with function_call but no name";
+
+  // Validate output
+  if (row.output !== undefined) {
+    // typeof null is "object", so null must be rejected explicitly; otherwise the
+    // property reads below throw a TypeError instead of returning a message.
+    if (row.output === null || typeof row.output !== "object") return "output is not an object";
+    if (!row.output.content && !row.output.function_call)
+      return "output contains no content or function_call";
+    if (row.output.function_call && !row.output.function_call.arguments)
+      return "output contains function_call but no arguments";
+    if (row.output.function_call && !row.output.function_call.name)
+      return "output contains function_call but no name";
+  }
+
+  return null;
+};
diff --git a/app/src/pages/datasets/[id].tsx b/app/src/pages/datasets/[id].tsx index 86780a6..541c160 100644 --- a/app/src/pages/datasets/[id].tsx +++ b/app/src/pages/datasets/[id].tsx @@ -25,6 +25,7 @@ import DatasetEntryPaginator from "~/components/datasets/DatasetEntryPaginator"; import { useAppStore } from "~/state/store"; import FineTuneButton from "~/components/datasets/FineTuneButton"; import ExperimentButton from "~/components/datasets/ExperimentButton"; +import ImportDataButton from "~/components/datasets/ImportDataButton"; export default function Dataset() { const utils = api.useContext(); @@ -101,6 +102,7 @@ export default function Dataset() { + diff --git a/app/src/server/api/routers/datasetEntries.router.ts b/app/src/server/api/routers/datasetEntries.router.ts index abfd972..fee852e 100644 --- a/app/src/server/api/routers/datasetEntries.router.ts +++ b/app/src/server/api/routers/datasetEntries.router.ts @@ -14,6 +14,7 @@ import { prisma } from "~/server/db"; import { requireCanModifyProject, requireCanViewProject } from "~/utils/accessControl"; import { error, success } from "~/utils/errorHandling/standardResponses"; import { countOpenAIChatTokens } from "~/utils/countTokens"; +import { type TrainingRow, validateTrainingRows } from "~/components/datasets/validateTrainingRows"; export const datasetEntriesRouter = createTRPCRouter({ list: protectedProcedure @@ -94,7 +95,8 @@ export const datasetEntriesRouter = createTRPCRouter({ name: z.string(), }) .optional(), - 
loggedCallIds: z.string().array(), + loggedCallIds: z.string().array().optional(), + jsonl: z.string().optional(), }), ) .mutation(async ({ input, ctx }) => { @@ -115,8 +117,14 @@ export const datasetEntriesRouter = createTRPCRouter({ return error("No datasetId or newDatasetParams provided"); } - const [loggedCalls, existingTrainingCount, existingTestingCount] = await prisma.$transaction([ - prisma.loggedCall.findMany({ + if (!input.loggedCallIds && !input.jsonl) { + return error("No loggedCallIds or jsonl provided"); + } + + let trainingRows: TrainingRow[]; + + if (input.loggedCallIds) { + const loggedCalls = await prisma.loggedCall.findMany({ where: { id: { in: input.loggedCallIds, @@ -135,7 +143,39 @@ export const datasetEntriesRouter = createTRPCRouter({ }, }, }, - }), + orderBy: { createdAt: "desc" }, + }); + + trainingRows = loggedCalls.map((loggedCall) => { + const inputMessages = ( + loggedCall.modelResponse?.reqPayload as unknown as CompletionCreateParams + ).messages; + let output: ChatCompletion.Choice.Message | undefined = undefined; + const resp = loggedCall.modelResponse?.respPayload as unknown as + | ChatCompletion + | undefined; + if (resp && resp.choices?.[0]) { + output = resp.choices[0].message; + } else { + output = { + role: "assistant", + content: "", + }; + } + return { + input: inputMessages as unknown as CreateChatCompletionRequestMessage[], + output: output as unknown as CreateChatCompletionRequestMessage, + }; + }); + } else { + trainingRows = JSON.parse(input.jsonl as string) as TrainingRow[]; + const validationError = validateTrainingRows(trainingRows); + if (validationError) { + return error(`Invalid JSONL: ${validationError}`); + } + } + + const [existingTrainingCount, existingTestingCount] = await prisma.$transaction([ prisma.datasetEntry.count({ where: { datasetId, @@ -150,39 +190,32 @@ export const datasetEntriesRouter = createTRPCRouter({ }), ]); - const shuffledLoggedCalls = shuffle(loggedCalls); - - const totalEntries = 
existingTrainingCount + existingTestingCount + loggedCalls.length; - const numTrainingToAdd = Math.floor(trainingRatio * totalEntries) - existingTrainingCount; - + const newTotalEntries = existingTrainingCount + existingTestingCount + trainingRows.length; + const numTrainingToAdd = Math.floor(trainingRatio * newTotalEntries) - existingTrainingCount; + const numTestingToAdd = trainingRows.length - numTrainingToAdd; + const typesToAssign = shuffle([ + ...Array(numTrainingToAdd).fill("TRAIN"), + ...Array(numTestingToAdd).fill("TEST"), + ]); const datasetEntriesToCreate: Prisma.DatasetEntryCreateManyInput[] = []; - - let i = 0; - for (const loggedCall of shuffledLoggedCalls) { - const inputMessages = ( - loggedCall.modelResponse?.reqPayload as unknown as CompletionCreateParams - ).messages; - let output: ChatCompletion.Choice.Message | undefined = undefined; - const resp = loggedCall.modelResponse?.respPayload as unknown as ChatCompletion | undefined; - if (resp && resp.choices?.[0]) { - output = resp.choices[0].message; - } else { - output = { - role: "assistant", - content: "", - }; + for (const row of trainingRows) { + let outputTokens = 0; + if (row.output) { + outputTokens = countOpenAIChatTokens("gpt-4-0613", [ + row.output as unknown as ChatCompletion.Choice.Message, + ]); } - datasetEntriesToCreate.push({ - datasetId, - loggedCallId: loggedCall.id, - input: inputMessages as unknown as Prisma.InputJsonValue, - output: output as unknown as Prisma.InputJsonValue, - inputTokens: loggedCall.modelResponse?.inputTokens || 0, - outputTokens: loggedCall.modelResponse?.outputTokens || 0, - type: i < numTrainingToAdd ? 
"TRAIN" : "TEST", + datasetId: datasetId, + input: row.input as unknown as Prisma.InputJsonValue, + output: row.output as unknown as Prisma.InputJsonValue, + inputTokens: countOpenAIChatTokens( + "gpt-4-0613", + row.input as unknown as CreateChatCompletionRequestMessage[], + ), + outputTokens, + type: typesToAssign.pop() as "TRAIN" | "TEST", }); - i++; } // Ensure dataset and dataset entries are created atomically @@ -198,7 +231,7 @@ export const datasetEntriesRouter = createTRPCRouter({ }, }), prisma.datasetEntry.createMany({ - data: shuffle(datasetEntriesToCreate), + data: datasetEntriesToCreate, }), ]); @@ -242,7 +275,8 @@ export const datasetEntriesRouter = createTRPCRouter({ let parsedOutput = undefined; let outputTokens = undefined; - if (input.updates.output) { + // The client might send "null" as a string, so we need to check for that + if (input.updates.output && input.updates.output !== "null") { parsedOutput = JSON.parse(input.updates.output); outputTokens = countOpenAIChatTokens("gpt-4-0613", [ parsedOutput as unknown as ChatCompletion.Choice.Message,