diff --git a/app/prisma/migrations/20230906052922_make_dataset_entry_logged_call_optional/migration.sql b/app/prisma/migrations/20230906052922_make_dataset_entry_logged_call_optional/migration.sql
new file mode 100644
index 0000000..bbccc3b
--- /dev/null
+++ b/app/prisma/migrations/20230906052922_make_dataset_entry_logged_call_optional/migration.sql
@@ -0,0 +1,5 @@
+-- AlterTable: make "loggedCallId" nullable and drop the defaults on "inputTokens", "outputTokens" and "type"
+ALTER TABLE "DatasetEntry" ALTER COLUMN "loggedCallId" DROP NOT NULL,
+ALTER COLUMN "inputTokens" DROP DEFAULT,
+ALTER COLUMN "outputTokens" DROP DEFAULT,
+ALTER COLUMN "type" DROP DEFAULT;
diff --git a/app/prisma/schema.prisma b/app/prisma/schema.prisma
index acf3046..feaf327 100644
--- a/app/prisma/schema.prisma
+++ b/app/prisma/schema.prisma
@@ -199,8 +199,8 @@ enum DatasetEntryType {
model DatasetEntry {
id String @id @default(uuid()) @db.Uuid
- loggedCallId String @db.Uuid
- loggedCall LoggedCall @relation(fields: [loggedCallId], references: [id], onDelete: Cascade)
+ loggedCallId String? @db.Uuid
+ loggedCall LoggedCall? @relation(fields: [loggedCallId], references: [id], onDelete: Cascade)
input Json @default("[]")
output Json?
diff --git a/app/src/components/ActionButton.tsx b/app/src/components/ActionButton.tsx
index 6822e02..d966461 100644
--- a/app/src/components/ActionButton.tsx
+++ b/app/src/components/ActionButton.tsx
@@ -7,12 +7,14 @@ import { BetaModal } from "./BetaModal";
const ActionButton = ({
icon,
+ iconBoxSize = 3.5,
label,
requireBeta = false,
onClick,
...buttonProps
}: {
icon: IconType;
+ iconBoxSize?: number;
label: string;
requireBeta?: boolean;
onClick?: () => void;
@@ -39,7 +41,9 @@ const ActionButton = ({
{...buttonProps}
>
- {icon && }
+ {icon && (
+
+ )}
{label}
diff --git a/app/src/components/datasets/FineTuneButton.tsx b/app/src/components/datasets/FineTuneButton.tsx
index 80bcfd1..dfd1b30 100644
--- a/app/src/components/datasets/FineTuneButton.tsx
+++ b/app/src/components/datasets/FineTuneButton.tsx
@@ -22,7 +22,6 @@ import { useRouter } from "next/router";
import { useDataset, useDatasetEntries, useHandledAsyncCallback } from "~/utils/hooks";
import { api } from "~/utils/api";
-import { useAppStore } from "~/state/store";
import ActionButton from "../ActionButton";
import InputDropdown from "../InputDropdown";
import { FiChevronDown } from "react-icons/fi";
@@ -53,7 +52,6 @@ const FineTuneButton = () => {
export default FineTuneButton;
const FineTuneModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) => {
- const selectedProjectId = useAppStore((s) => s.selectedProjectId);
const dataset = useDataset().data;
const datasetEntries = useDatasetEntries().data;
@@ -73,7 +71,7 @@ const FineTuneModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) => {
const createFineTuneMutation = api.fineTunes.create.useMutation();
const [createFineTune, creationInProgress] = useHandledAsyncCallback(async () => {
- if (!selectedProjectId || !modelSlug || !selectedBaseModel || !dataset) return;
+ if (!modelSlug || !selectedBaseModel || !dataset) return;
await createFineTuneMutation.mutateAsync({
slug: modelSlug,
baseModel: selectedBaseModel,
@@ -83,7 +81,7 @@ const FineTuneModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) => {
await utils.fineTunes.list.invalidate();
await router.push({ pathname: "/fine-tunes" });
disclosure.onClose();
- }, [createFineTuneMutation, selectedProjectId, modelSlug, selectedBaseModel]);
+ }, [createFineTuneMutation, modelSlug, selectedBaseModel]);
return (
diff --git a/app/src/components/datasets/ImportDataButton.tsx b/app/src/components/datasets/ImportDataButton.tsx
new file mode 100644
index 0000000..14c8e10
--- /dev/null
+++ b/app/src/components/datasets/ImportDataButton.tsx
@@ -0,0 +1,249 @@
+import { useState, useEffect, useRef } from "react";
+import {
+ Modal,
+ ModalOverlay,
+ ModalContent,
+ ModalHeader,
+ ModalCloseButton,
+ ModalBody,
+ ModalFooter,
+ HStack,
+ VStack,
+ Icon,
+ Text,
+ Button,
+ Box,
+ useDisclosure,
+ type UseDisclosureReturn,
+} from "@chakra-ui/react";
+import { AiOutlineCloudUpload, AiOutlineFile } from "react-icons/ai";
+
+import { useDataset, useDatasetEntries, useHandledAsyncCallback } from "~/utils/hooks";
+import { api } from "~/utils/api";
+import ActionButton from "../ActionButton";
+import { validateTrainingRows, type TrainingRow, parseJSONL } from "./validateTrainingRows";
+import pluralize from "pluralize";
+
+const ImportDataButton = () => { // NOTE(review): the JSX of this component is not visible in this excerpt
+ const datasetEntries = useDatasetEntries().data;
+
+ const numEntries = datasetEntries?.matchingEntryIds.length || 0; // falls back to 0 when data is undefined or nothing matches
+
+ const disclosure = useDisclosure(); // open/close state; presumably handed to ImportDataModal — confirm in the markup
+
+ return (
+ <>
+
+
+ >
+ );
+};
+
+export default ImportDataButton;
+
+const ImportDataModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) => {
+ const dataset = useDataset().data;
+
+ const [validationError, setValidationError] = useState(null); // message rendered in the error view; null when there is none
+ const [trainingRows, setTrainingRows] = useState(null); // rows that parsed and validated; null until a file is accepted
+
+ const fileInputRef = useRef(null);
+
+ const handleFileDrop = (e: React.DragEvent) => { // drop handler: only the first dropped file is processed
+ e.preventDefault();
+ const files = e.dataTransfer.files;
+ if (files.length > 0) {
+ processFile(files[0] as File);
+ }
+ };
+
+ const handleFileChange = (e: React.ChangeEvent) => { // change handler: only the first selected file is processed
+ const files = e.target.files;
+ if (files && files.length > 0) {
+ processFile(files[0] as File);
+ }
+ };
+
+ const processFile = (file: File) => { // reads the file as text, then parses and validates it once loaded
+ const reader = new FileReader();
+ reader.onload = (e: ProgressEvent) => {
+ const content = e.target?.result as string;
+ // Parse the text as JSONL and validate the rows; any failure is reported through validationError
+ let parsedJSONL;
+ try {
+ parsedJSONL = parseJSONL(content) as TrainingRow[];
+ const validationError = validateTrainingRows(parsedJSONL);
+ if (validationError) {
+ setValidationError(validationError);
+ setTrainingRows(null);
+ return;
+ }
+ setTrainingRows(parsedJSONL);
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
+ } catch (e: any) {
+ console.log("e is", e); // NOTE(review): looks like leftover debug logging — confirm and remove
+ setValidationError("Unable to parse JSONL file: " + (e.message as string));
+ setTrainingRows(null);
+ return;
+ }
+ };
+ reader.readAsText(file);
+ };
+
+ useEffect(() => { // clear any previous rows and error each time the modal opens
+ if (disclosure.isOpen) {
+ setTrainingRows(null);
+ setValidationError(null);
+ }
+ }, [disclosure.isOpen]);
+
+ const utils = api.useContext();
+
+ const sendJSONLMutation = api.datasetEntries.create.useMutation();
+
+ const [sendJSONL, sendingInProgress] = useHandledAsyncCallback(async () => { // upload rows, refresh the entries list, close the modal
+ if (!dataset || !trainingRows) return;
+ await sendJSONLMutation.mutateAsync({
+ datasetId: dataset.id,
+ jsonl: JSON.stringify(trainingRows), // sent as one JSON array string, not as line-delimited text
+ });
+
+ await utils.datasetEntries.list.invalidate();
+ disclosure.onClose();
+ }, [dataset, trainingRows, sendJSONLMutation]);
+
+ return (
+
+
+
+
+
+ Upload Training Logs
+
+
+
+
+
+ {!trainingRows && !validationError && (
+ e.preventDefault()}
+ onDrop={handleFileDrop}
+ >
+
+
+
+
+ Drag & Drop
+
+
+ your .jsonl file here, or{" "}
+
+ fileInputRef.current?.click()}
+ >
+ browse
+
+
+
+ )}
+ {validationError && (
+
+
+
+
+ Error
+
+ {validationError}
+
+ setValidationError(null)}
+ >
+ Try again
+
+
+ )}
+ {trainingRows && !validationError && (
+
+
+
+
+ Success
+
+
+ We'll upload {trainingRows.length}{" "}
+ {pluralize("row", trainingRows.length)} into {dataset?.name}.{" "}
+
+
+ setTrainingRows(null)}
+ >
+ Change file
+
+
+ )}
+
+
+
+
+
+
+
+
+
+
+ );
+};
+
+const JsonFileIcon = () => ( // NOTE(review): markup is not visible in this excerpt; only the "JSONL" text label remains
+
+
+
+ JSONL
+
+
+);
diff --git a/app/src/components/datasets/validateTrainingRows.ts b/app/src/components/datasets/validateTrainingRows.ts
new file mode 100644
index 0000000..c78718d
--- /dev/null
+++ b/app/src/components/datasets/validateTrainingRows.ts
@@ -0,0 +1,53 @@
+import { type CreateChatCompletionRequestMessage } from "openai/resources/chat";
+
+export type TrainingRow = {
+ input: CreateChatCompletionRequestMessage[];
+ output?: CreateChatCompletionRequestMessage;
+};
+
+export const parseJSONL = (jsonlString: string): unknown[] => {
+  const lines = jsonlString.split(/\r?\n/).filter((line) => line.trim() !== ""); // blank lines are skipped
+  if (lines.length === 0) throw new SyntaxError("file contains no JSON lines"); // an empty file stays an error
+  return lines.map((line): unknown => JSON.parse(line)); // JSON.parse throws SyntaxError on a malformed line
+};
+
+/** Returns null if every row is valid, otherwise `row <index>: <reason>` for the first bad row. */
+export const validateTrainingRows = (rows: unknown): string | null => {
+  if (!Array.isArray(rows)) return "training data is not an array";
+
+  for (const [index, candidate] of rows.entries()) {
+    const problem = validateTrainingRow(candidate as TrainingRow);
+    if (problem) return `row ${index}: ${problem}`;
+  }
+  return null;
+};
+
+/** Validates one row: returns null when it is well-formed, otherwise a short reason string. */
+const validateTrainingRow = (row: TrainingRow): string | null => {
+  if (!row) return "empty row";
+  if (!row.input) return "missing input";
+  // Validate input: an array of non-null message objects carrying content or a function_call
+  if (!Array.isArray(row.input)) return "input is not an array";
+  if ((row.input as unknown[]).some((x) => typeof x !== "object"))
+    return "input contains invalid item";
+  if (row.input.some((x) => !x)) return "input contains empty item"; // null is typeof "object"
+  if (row.input.some((x) => !x.content && !x.function_call))
+    return "input contains item with no content or function_call";
+  if (row.input.some((x) => x.function_call && !x.function_call.arguments))
+    return "input contains item with function_call but no arguments";
+  if (row.input.some((x) => x.function_call && !x.function_call.name))
+    return "input contains item with function_call but no name";
+
+  // Validate output (optional). typeof null === "object", so JSON null is rejected explicitly;
+  // otherwise the property reads below would throw a TypeError instead of returning a message.
+  if (row.output !== undefined) {
+    if (typeof row.output !== "object" || row.output === null) return "output is not an object";
+    if (!row.output.content && !row.output.function_call)
+      return "output contains no content or function_call";
+    if (row.output.function_call && !row.output.function_call.arguments)
+      return "output contains function_call but no arguments";
+    if (row.output.function_call && !row.output.function_call.name)
+      return "output contains function_call but no name";
+  }
+  return null;
+};
diff --git a/app/src/pages/datasets/[id].tsx b/app/src/pages/datasets/[id].tsx
index 86780a6..541c160 100644
--- a/app/src/pages/datasets/[id].tsx
+++ b/app/src/pages/datasets/[id].tsx
@@ -25,6 +25,7 @@ import DatasetEntryPaginator from "~/components/datasets/DatasetEntryPaginator";
import { useAppStore } from "~/state/store";
import FineTuneButton from "~/components/datasets/FineTuneButton";
import ExperimentButton from "~/components/datasets/ExperimentButton";
+import ImportDataButton from "~/components/datasets/ImportDataButton";
export default function Dataset() {
const utils = api.useContext();
@@ -101,6 +102,7 @@ export default function Dataset() {
+
diff --git a/app/src/server/api/routers/datasetEntries.router.ts b/app/src/server/api/routers/datasetEntries.router.ts
index abfd972..fee852e 100644
--- a/app/src/server/api/routers/datasetEntries.router.ts
+++ b/app/src/server/api/routers/datasetEntries.router.ts
@@ -14,6 +14,7 @@ import { prisma } from "~/server/db";
import { requireCanModifyProject, requireCanViewProject } from "~/utils/accessControl";
import { error, success } from "~/utils/errorHandling/standardResponses";
import { countOpenAIChatTokens } from "~/utils/countTokens";
+import { type TrainingRow, validateTrainingRows } from "~/components/datasets/validateTrainingRows";
export const datasetEntriesRouter = createTRPCRouter({
list: protectedProcedure
@@ -94,7 +95,8 @@ export const datasetEntriesRouter = createTRPCRouter({
name: z.string(),
})
.optional(),
- loggedCallIds: z.string().array(),
+ loggedCallIds: z.string().array().optional(),
+ jsonl: z.string().optional(),
}),
)
.mutation(async ({ input, ctx }) => {
@@ -115,8 +117,14 @@ export const datasetEntriesRouter = createTRPCRouter({
return error("No datasetId or newDatasetParams provided");
}
- const [loggedCalls, existingTrainingCount, existingTestingCount] = await prisma.$transaction([
- prisma.loggedCall.findMany({
+ if (!input.loggedCallIds && !input.jsonl) {
+ return error("No loggedCallIds or jsonl provided");
+ }
+
+ let trainingRows: TrainingRow[];
+
+ if (input.loggedCallIds) {
+ const loggedCalls = await prisma.loggedCall.findMany({
where: {
id: {
in: input.loggedCallIds,
@@ -135,7 +143,39 @@ export const datasetEntriesRouter = createTRPCRouter({
},
},
},
- }),
+ orderBy: { createdAt: "desc" },
+ });
+
+ trainingRows = loggedCalls.map((loggedCall) => {
+ const inputMessages = (
+ loggedCall.modelResponse?.reqPayload as unknown as CompletionCreateParams
+ ).messages;
+ let output: ChatCompletion.Choice.Message | undefined = undefined;
+ const resp = loggedCall.modelResponse?.respPayload as unknown as
+ | ChatCompletion
+ | undefined;
+ if (resp && resp.choices?.[0]) {
+ output = resp.choices[0].message;
+ } else {
+ output = {
+ role: "assistant",
+ content: "",
+ };
+ }
+ return {
+ input: inputMessages as unknown as CreateChatCompletionRequestMessage[],
+ output: output as unknown as CreateChatCompletionRequestMessage,
+ };
+ });
+ } else {
+ trainingRows = JSON.parse(input.jsonl as string) as TrainingRow[];
+ const validationError = validateTrainingRows(trainingRows);
+ if (validationError) {
+ return error(`Invalid JSONL: ${validationError}`);
+ }
+ }
+
+ const [existingTrainingCount, existingTestingCount] = await prisma.$transaction([
prisma.datasetEntry.count({
where: {
datasetId,
@@ -150,39 +190,32 @@ export const datasetEntriesRouter = createTRPCRouter({
}),
]);
- const shuffledLoggedCalls = shuffle(loggedCalls);
-
- const totalEntries = existingTrainingCount + existingTestingCount + loggedCalls.length;
- const numTrainingToAdd = Math.floor(trainingRatio * totalEntries) - existingTrainingCount;
-
+ const newTotalEntries = existingTrainingCount + existingTestingCount + trainingRows.length;
+ const numTrainingToAdd = Math.floor(trainingRatio * newTotalEntries) - existingTrainingCount;
+ const numTestingToAdd = trainingRows.length - numTrainingToAdd;
+ const typesToAssign = shuffle([
+ ...Array(numTrainingToAdd).fill("TRAIN"),
+ ...Array(numTestingToAdd).fill("TEST"),
+ ]);
const datasetEntriesToCreate: Prisma.DatasetEntryCreateManyInput[] = [];
-
- let i = 0;
- for (const loggedCall of shuffledLoggedCalls) {
- const inputMessages = (
- loggedCall.modelResponse?.reqPayload as unknown as CompletionCreateParams
- ).messages;
- let output: ChatCompletion.Choice.Message | undefined = undefined;
- const resp = loggedCall.modelResponse?.respPayload as unknown as ChatCompletion | undefined;
- if (resp && resp.choices?.[0]) {
- output = resp.choices[0].message;
- } else {
- output = {
- role: "assistant",
- content: "",
- };
+ for (const row of trainingRows) {
+ let outputTokens = 0;
+ if (row.output) {
+ outputTokens = countOpenAIChatTokens("gpt-4-0613", [
+ row.output as unknown as ChatCompletion.Choice.Message,
+ ]);
}
-
datasetEntriesToCreate.push({
- datasetId,
- loggedCallId: loggedCall.id,
- input: inputMessages as unknown as Prisma.InputJsonValue,
- output: output as unknown as Prisma.InputJsonValue,
- inputTokens: loggedCall.modelResponse?.inputTokens || 0,
- outputTokens: loggedCall.modelResponse?.outputTokens || 0,
- type: i < numTrainingToAdd ? "TRAIN" : "TEST",
+ datasetId: datasetId,
+ input: row.input as unknown as Prisma.InputJsonValue,
+ output: row.output as unknown as Prisma.InputJsonValue,
+ inputTokens: countOpenAIChatTokens(
+ "gpt-4-0613",
+ row.input as unknown as CreateChatCompletionRequestMessage[],
+ ),
+ outputTokens,
+ type: typesToAssign.pop() as "TRAIN" | "TEST",
});
- i++;
}
// Ensure dataset and dataset entries are created atomically
@@ -198,7 +231,7 @@ export const datasetEntriesRouter = createTRPCRouter({
},
}),
prisma.datasetEntry.createMany({
- data: shuffle(datasetEntriesToCreate),
+ data: datasetEntriesToCreate,
}),
]);
@@ -242,7 +275,8 @@ export const datasetEntriesRouter = createTRPCRouter({
let parsedOutput = undefined;
let outputTokens = undefined;
- if (input.updates.output) {
+ // The client might send "null" as a string, so we need to check for that
+ if (input.updates.output && input.updates.output !== "null") {
parsedOutput = JSON.parse(input.updates.output);
outputTokens = countOpenAIChatTokens("gpt-4-0613", [
parsedOutput as unknown as ChatCompletion.Choice.Message,