Add upload data modal
This commit is contained in:
@@ -0,0 +1,5 @@
|
|||||||
|
-- AlterTable
|
||||||
|
ALTER TABLE "DatasetEntry" ALTER COLUMN "loggedCallId" DROP NOT NULL,
|
||||||
|
ALTER COLUMN "inputTokens" DROP DEFAULT,
|
||||||
|
ALTER COLUMN "outputTokens" DROP DEFAULT,
|
||||||
|
ALTER COLUMN "type" DROP DEFAULT;
|
||||||
@@ -199,8 +199,8 @@ enum DatasetEntryType {
|
|||||||
model DatasetEntry {
|
model DatasetEntry {
|
||||||
id String @id @default(uuid()) @db.Uuid
|
id String @id @default(uuid()) @db.Uuid
|
||||||
|
|
||||||
loggedCallId String @db.Uuid
|
loggedCallId String? @db.Uuid
|
||||||
loggedCall LoggedCall @relation(fields: [loggedCallId], references: [id], onDelete: Cascade)
|
loggedCall LoggedCall? @relation(fields: [loggedCallId], references: [id], onDelete: Cascade)
|
||||||
|
|
||||||
input Json @default("[]")
|
input Json @default("[]")
|
||||||
output Json?
|
output Json?
|
||||||
|
|||||||
@@ -7,12 +7,14 @@ import { BetaModal } from "./BetaModal";
|
|||||||
|
|
||||||
const ActionButton = ({
|
const ActionButton = ({
|
||||||
icon,
|
icon,
|
||||||
|
iconBoxSize = 3.5,
|
||||||
label,
|
label,
|
||||||
requireBeta = false,
|
requireBeta = false,
|
||||||
onClick,
|
onClick,
|
||||||
...buttonProps
|
...buttonProps
|
||||||
}: {
|
}: {
|
||||||
icon: IconType;
|
icon: IconType;
|
||||||
|
iconBoxSize?: number;
|
||||||
label: string;
|
label: string;
|
||||||
requireBeta?: boolean;
|
requireBeta?: boolean;
|
||||||
onClick?: () => void;
|
onClick?: () => void;
|
||||||
@@ -39,7 +41,9 @@ const ActionButton = ({
|
|||||||
{...buttonProps}
|
{...buttonProps}
|
||||||
>
|
>
|
||||||
<HStack spacing={1}>
|
<HStack spacing={1}>
|
||||||
{icon && <Icon as={icon} color={requireBeta ? "orange.400" : undefined} />}
|
{icon && (
|
||||||
|
<Icon as={icon} boxSize={iconBoxSize} color={requireBeta ? "orange.400" : undefined} />
|
||||||
|
)}
|
||||||
<Text display={{ base: "none", md: "flex" }}>{label}</Text>
|
<Text display={{ base: "none", md: "flex" }}>{label}</Text>
|
||||||
</HStack>
|
</HStack>
|
||||||
</Button>
|
</Button>
|
||||||
|
|||||||
@@ -22,7 +22,6 @@ import { useRouter } from "next/router";
|
|||||||
|
|
||||||
import { useDataset, useDatasetEntries, useHandledAsyncCallback } from "~/utils/hooks";
|
import { useDataset, useDatasetEntries, useHandledAsyncCallback } from "~/utils/hooks";
|
||||||
import { api } from "~/utils/api";
|
import { api } from "~/utils/api";
|
||||||
import { useAppStore } from "~/state/store";
|
|
||||||
import ActionButton from "../ActionButton";
|
import ActionButton from "../ActionButton";
|
||||||
import InputDropdown from "../InputDropdown";
|
import InputDropdown from "../InputDropdown";
|
||||||
import { FiChevronDown } from "react-icons/fi";
|
import { FiChevronDown } from "react-icons/fi";
|
||||||
@@ -53,7 +52,6 @@ const FineTuneButton = () => {
|
|||||||
export default FineTuneButton;
|
export default FineTuneButton;
|
||||||
|
|
||||||
const FineTuneModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) => {
|
const FineTuneModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) => {
|
||||||
const selectedProjectId = useAppStore((s) => s.selectedProjectId);
|
|
||||||
const dataset = useDataset().data;
|
const dataset = useDataset().data;
|
||||||
const datasetEntries = useDatasetEntries().data;
|
const datasetEntries = useDatasetEntries().data;
|
||||||
|
|
||||||
@@ -73,7 +71,7 @@ const FineTuneModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) => {
|
|||||||
const createFineTuneMutation = api.fineTunes.create.useMutation();
|
const createFineTuneMutation = api.fineTunes.create.useMutation();
|
||||||
|
|
||||||
const [createFineTune, creationInProgress] = useHandledAsyncCallback(async () => {
|
const [createFineTune, creationInProgress] = useHandledAsyncCallback(async () => {
|
||||||
if (!selectedProjectId || !modelSlug || !selectedBaseModel || !dataset) return;
|
if (!modelSlug || !selectedBaseModel || !dataset) return;
|
||||||
await createFineTuneMutation.mutateAsync({
|
await createFineTuneMutation.mutateAsync({
|
||||||
slug: modelSlug,
|
slug: modelSlug,
|
||||||
baseModel: selectedBaseModel,
|
baseModel: selectedBaseModel,
|
||||||
@@ -83,7 +81,7 @@ const FineTuneModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) => {
|
|||||||
await utils.fineTunes.list.invalidate();
|
await utils.fineTunes.list.invalidate();
|
||||||
await router.push({ pathname: "/fine-tunes" });
|
await router.push({ pathname: "/fine-tunes" });
|
||||||
disclosure.onClose();
|
disclosure.onClose();
|
||||||
}, [createFineTuneMutation, selectedProjectId, modelSlug, selectedBaseModel]);
|
}, [createFineTuneMutation, modelSlug, selectedBaseModel]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Modal size={{ base: "xl", md: "2xl" }} {...disclosure}>
|
<Modal size={{ base: "xl", md: "2xl" }} {...disclosure}>
|
||||||
|
|||||||
249
app/src/components/datasets/ImportDataButton.tsx
Normal file
249
app/src/components/datasets/ImportDataButton.tsx
Normal file
@@ -0,0 +1,249 @@
|
|||||||
|
import { useState, useEffect, useRef } from "react";
|
||||||
|
import {
|
||||||
|
Modal,
|
||||||
|
ModalOverlay,
|
||||||
|
ModalContent,
|
||||||
|
ModalHeader,
|
||||||
|
ModalCloseButton,
|
||||||
|
ModalBody,
|
||||||
|
ModalFooter,
|
||||||
|
HStack,
|
||||||
|
VStack,
|
||||||
|
Icon,
|
||||||
|
Text,
|
||||||
|
Button,
|
||||||
|
Box,
|
||||||
|
useDisclosure,
|
||||||
|
type UseDisclosureReturn,
|
||||||
|
} from "@chakra-ui/react";
|
||||||
|
import { AiOutlineCloudUpload, AiOutlineFile } from "react-icons/ai";
|
||||||
|
|
||||||
|
import { useDataset, useDatasetEntries, useHandledAsyncCallback } from "~/utils/hooks";
|
||||||
|
import { api } from "~/utils/api";
|
||||||
|
import ActionButton from "../ActionButton";
|
||||||
|
import { validateTrainingRows, type TrainingRow, parseJSONL } from "./validateTrainingRows";
|
||||||
|
import pluralize from "pluralize";
|
||||||
|
|
||||||
|
const ImportDataButton = () => {
|
||||||
|
const datasetEntries = useDatasetEntries().data;
|
||||||
|
|
||||||
|
const numEntries = datasetEntries?.matchingEntryIds.length || 0;
|
||||||
|
|
||||||
|
const disclosure = useDisclosure();
|
||||||
|
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
<ActionButton
|
||||||
|
onClick={disclosure.onOpen}
|
||||||
|
label="Import Data"
|
||||||
|
icon={AiOutlineCloudUpload}
|
||||||
|
iconBoxSize={4}
|
||||||
|
isDisabled={numEntries === 0}
|
||||||
|
requireBeta
|
||||||
|
/>
|
||||||
|
<ImportDataModal disclosure={disclosure} />
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default ImportDataButton;
|
||||||
|
|
||||||
|
const ImportDataModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) => {
|
||||||
|
const dataset = useDataset().data;
|
||||||
|
|
||||||
|
const [validationError, setValidationError] = useState<string | null>(null);
|
||||||
|
const [trainingRows, setTrainingRows] = useState<TrainingRow[] | null>(null);
|
||||||
|
|
||||||
|
const fileInputRef = useRef<HTMLInputElement>(null);
|
||||||
|
|
||||||
|
const handleFileDrop = (e: React.DragEvent<HTMLDivElement>) => {
|
||||||
|
e.preventDefault();
|
||||||
|
const files = e.dataTransfer.files;
|
||||||
|
if (files.length > 0) {
|
||||||
|
processFile(files[0] as File);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleFileChange = (e: React.ChangeEvent<HTMLInputElement>) => {
|
||||||
|
const files = e.target.files;
|
||||||
|
if (files && files.length > 0) {
|
||||||
|
processFile(files[0] as File);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const processFile = (file: File) => {
|
||||||
|
const reader = new FileReader();
|
||||||
|
reader.onload = (e: ProgressEvent<FileReader>) => {
|
||||||
|
const content = e.target?.result as string;
|
||||||
|
// Process the content, e.g., set to state
|
||||||
|
let parsedJSONL;
|
||||||
|
try {
|
||||||
|
parsedJSONL = parseJSONL(content) as TrainingRow[];
|
||||||
|
const validationError = validateTrainingRows(parsedJSONL);
|
||||||
|
if (validationError) {
|
||||||
|
setValidationError(validationError);
|
||||||
|
setTrainingRows(null);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
setTrainingRows(parsedJSONL);
|
||||||
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||||
|
} catch (e: any) {
|
||||||
|
console.log("e is", e);
|
||||||
|
setValidationError("Unable to parse JSONL file: " + (e.message as string));
|
||||||
|
setTrainingRows(null);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
reader.readAsText(file);
|
||||||
|
};
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (disclosure.isOpen) {
|
||||||
|
setTrainingRows(null);
|
||||||
|
setValidationError(null);
|
||||||
|
}
|
||||||
|
}, [disclosure.isOpen]);
|
||||||
|
|
||||||
|
const utils = api.useContext();
|
||||||
|
|
||||||
|
const sendJSONLMutation = api.datasetEntries.create.useMutation();
|
||||||
|
|
||||||
|
const [sendJSONL, sendingInProgress] = useHandledAsyncCallback(async () => {
|
||||||
|
if (!dataset || !trainingRows) return;
|
||||||
|
await sendJSONLMutation.mutateAsync({
|
||||||
|
datasetId: dataset.id,
|
||||||
|
jsonl: JSON.stringify(trainingRows),
|
||||||
|
});
|
||||||
|
|
||||||
|
await utils.datasetEntries.list.invalidate();
|
||||||
|
disclosure.onClose();
|
||||||
|
}, [dataset, trainingRows, sendJSONLMutation]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Modal size={{ base: "xl", md: "2xl" }} {...disclosure}>
|
||||||
|
<ModalOverlay />
|
||||||
|
<ModalContent w={1200}>
|
||||||
|
<ModalHeader>
|
||||||
|
<HStack>
|
||||||
|
<Text>Upload Training Logs</Text>
|
||||||
|
</HStack>
|
||||||
|
</ModalHeader>
|
||||||
|
<ModalCloseButton />
|
||||||
|
<ModalBody maxW="unset" p={8}>
|
||||||
|
<Box w="full" aspectRatio={1.5}>
|
||||||
|
{!trainingRows && !validationError && (
|
||||||
|
<VStack
|
||||||
|
w="full"
|
||||||
|
h="full"
|
||||||
|
stroke="gray.300"
|
||||||
|
justifyContent="center"
|
||||||
|
borderRadius={8}
|
||||||
|
sx={{
|
||||||
|
"background-image": `url("data:image/svg+xml,%3csvg width='100%25' height='100%25' xmlns='http://www.w3.org/2000/svg'%3e%3crect x='2%25' y='2%25' width='96%25' height='96%25' fill='none' stroke='%23eee' stroke-width='4' stroke-dasharray='6%2c 14' stroke-dashoffset='0' stroke-linecap='square' rx='8' ry='8'/%3e%3c/svg%3e")`,
|
||||||
|
}}
|
||||||
|
onDragOver={(e) => e.preventDefault()}
|
||||||
|
onDrop={handleFileDrop}
|
||||||
|
>
|
||||||
|
<JsonFileIcon />
|
||||||
|
<Icon as={AiOutlineCloudUpload} boxSize={24} color="gray.300" />
|
||||||
|
|
||||||
|
<Text fontSize={32} color="gray.500" fontWeight="bold">
|
||||||
|
Drag & Drop
|
||||||
|
</Text>
|
||||||
|
<Text color="gray.500">
|
||||||
|
your .jsonl file here, or{" "}
|
||||||
|
<input
|
||||||
|
type="file"
|
||||||
|
ref={fileInputRef}
|
||||||
|
onChange={handleFileChange}
|
||||||
|
style={{ display: "none" }}
|
||||||
|
accept=".jsonl"
|
||||||
|
/>
|
||||||
|
<Text
|
||||||
|
as="span"
|
||||||
|
textDecor="underline"
|
||||||
|
_hover={{ color: "orange.400" }}
|
||||||
|
cursor="pointer"
|
||||||
|
onClick={() => fileInputRef.current?.click()}
|
||||||
|
>
|
||||||
|
browse
|
||||||
|
</Text>
|
||||||
|
</Text>
|
||||||
|
</VStack>
|
||||||
|
)}
|
||||||
|
{validationError && (
|
||||||
|
<VStack w="full" h="full" justifyContent="center" spacing={8}>
|
||||||
|
<Icon as={AiOutlineFile} boxSize={24} color="gray.300" />
|
||||||
|
<VStack w="full">
|
||||||
|
<Text fontSize={32} color="gray.500" fontWeight="bold">
|
||||||
|
Error
|
||||||
|
</Text>
|
||||||
|
<Text color="gray.500">{validationError}</Text>
|
||||||
|
</VStack>
|
||||||
|
<Text
|
||||||
|
as="span"
|
||||||
|
textDecor="underline"
|
||||||
|
color="gray.500"
|
||||||
|
_hover={{ color: "orange.400" }}
|
||||||
|
cursor="pointer"
|
||||||
|
onClick={() => setValidationError(null)}
|
||||||
|
>
|
||||||
|
Try again
|
||||||
|
</Text>
|
||||||
|
</VStack>
|
||||||
|
)}
|
||||||
|
{trainingRows && !validationError && (
|
||||||
|
<VStack w="full" h="full" justifyContent="center" spacing={8}>
|
||||||
|
<JsonFileIcon />
|
||||||
|
<VStack w="full">
|
||||||
|
<Text fontSize={32} color="gray.500" fontWeight="bold">
|
||||||
|
Success
|
||||||
|
</Text>
|
||||||
|
<Text color="gray.500">
|
||||||
|
We'll upload <b>{trainingRows.length}</b>{" "}
|
||||||
|
{pluralize("row", trainingRows.length)} into <b>{dataset?.name}</b>.{" "}
|
||||||
|
</Text>
|
||||||
|
</VStack>
|
||||||
|
<Text
|
||||||
|
as="span"
|
||||||
|
textDecor="underline"
|
||||||
|
color="gray.500"
|
||||||
|
_hover={{ color: "orange.400" }}
|
||||||
|
cursor="pointer"
|
||||||
|
onClick={() => setTrainingRows(null)}
|
||||||
|
>
|
||||||
|
Change file
|
||||||
|
</Text>
|
||||||
|
</VStack>
|
||||||
|
)}
|
||||||
|
</Box>
|
||||||
|
</ModalBody>
|
||||||
|
<ModalFooter>
|
||||||
|
<HStack>
|
||||||
|
<Button colorScheme="gray" onClick={disclosure.onClose} minW={24}>
|
||||||
|
Cancel
|
||||||
|
</Button>
|
||||||
|
<Button
|
||||||
|
colorScheme="orange"
|
||||||
|
onClick={sendJSONL}
|
||||||
|
isLoading={sendingInProgress}
|
||||||
|
minW={24}
|
||||||
|
isDisabled={!trainingRows || !!validationError}
|
||||||
|
>
|
||||||
|
Upload
|
||||||
|
</Button>
|
||||||
|
</HStack>
|
||||||
|
</ModalFooter>
|
||||||
|
</ModalContent>
|
||||||
|
</Modal>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
const JsonFileIcon = () => (
|
||||||
|
<Box position="relative" display="flex" alignItems="center" justifyContent="center">
|
||||||
|
<Icon as={AiOutlineFile} boxSize={24} color="gray.300" />
|
||||||
|
<Text position="absolute" color="orange.400" fontWeight="bold" fontSize={12} pt={4}>
|
||||||
|
JSONL
|
||||||
|
</Text>
|
||||||
|
</Box>
|
||||||
|
);
|
||||||
53
app/src/components/datasets/validateTrainingRows.ts
Normal file
53
app/src/components/datasets/validateTrainingRows.ts
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
import { type CreateChatCompletionRequestMessage } from "openai/resources/chat";
|
||||||
|
|
||||||
|
export type TrainingRow = {
|
||||||
|
input: CreateChatCompletionRequestMessage[];
|
||||||
|
output?: CreateChatCompletionRequestMessage;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const parseJSONL = (jsonlString: string): unknown[] =>
|
||||||
|
jsonlString
|
||||||
|
.trim()
|
||||||
|
.split("\n")
|
||||||
|
.map((line) => JSON.parse(line) as unknown);
|
||||||
|
|
||||||
|
export const validateTrainingRows = (rows: unknown): string | null => {
|
||||||
|
if (!Array.isArray(rows)) return "training data is not an array";
|
||||||
|
for (let i = 0; i < rows.length; i++) {
|
||||||
|
const row = rows[i] as TrainingRow;
|
||||||
|
const error = validateTrainingRow(row);
|
||||||
|
if (error) return `row ${i}: ${error}`;
|
||||||
|
}
|
||||||
|
|
||||||
|
return null;
|
||||||
|
};
|
||||||
|
|
||||||
|
const validateTrainingRow = (row: TrainingRow): string | null => {
|
||||||
|
if (!row) return "empty row";
|
||||||
|
if (!row.input) return "missing input";
|
||||||
|
|
||||||
|
// Validate input
|
||||||
|
if (!Array.isArray(row.input)) return "input is not an array";
|
||||||
|
if ((row.input as unknown[]).some((x) => typeof x !== "object"))
|
||||||
|
return "input contains invalid item";
|
||||||
|
if (row.input.some((x) => !x)) return "input contains empty item";
|
||||||
|
if (row.input.some((x) => !x.content && !x.function_call))
|
||||||
|
return "input contains item with no content or function_call";
|
||||||
|
if (row.input.some((x) => x.function_call && !x.function_call.arguments))
|
||||||
|
return "input contains item with function_call but no arguments";
|
||||||
|
if (row.input.some((x) => x.function_call && !x.function_call.name))
|
||||||
|
return "input contains item with function_call but no name";
|
||||||
|
|
||||||
|
// Validate output
|
||||||
|
if (row.output !== undefined) {
|
||||||
|
if (typeof row.output !== "object") return "output is not an object";
|
||||||
|
if (!row.output.content && !row.output.function_call)
|
||||||
|
return "output contains no content or function_call";
|
||||||
|
if (row.output.function_call && !row.output.function_call.arguments)
|
||||||
|
return "output contains function_call but no arguments";
|
||||||
|
if (row.output.function_call && !row.output.function_call.name)
|
||||||
|
return "output contains function_call but no name";
|
||||||
|
}
|
||||||
|
|
||||||
|
return null;
|
||||||
|
};
|
||||||
@@ -25,6 +25,7 @@ import DatasetEntryPaginator from "~/components/datasets/DatasetEntryPaginator";
|
|||||||
import { useAppStore } from "~/state/store";
|
import { useAppStore } from "~/state/store";
|
||||||
import FineTuneButton from "~/components/datasets/FineTuneButton";
|
import FineTuneButton from "~/components/datasets/FineTuneButton";
|
||||||
import ExperimentButton from "~/components/datasets/ExperimentButton";
|
import ExperimentButton from "~/components/datasets/ExperimentButton";
|
||||||
|
import ImportDataButton from "~/components/datasets/ImportDataButton";
|
||||||
|
|
||||||
export default function Dataset() {
|
export default function Dataset() {
|
||||||
const utils = api.useContext();
|
const utils = api.useContext();
|
||||||
@@ -101,6 +102,7 @@ export default function Dataset() {
|
|||||||
<HStack w="full" justifyContent="flex-end">
|
<HStack w="full" justifyContent="flex-end">
|
||||||
<FineTuneButton />
|
<FineTuneButton />
|
||||||
<ExperimentButton />
|
<ExperimentButton />
|
||||||
|
<ImportDataButton />
|
||||||
</HStack>
|
</HStack>
|
||||||
<DatasetEntriesTable />
|
<DatasetEntriesTable />
|
||||||
<DatasetEntryPaginator />
|
<DatasetEntryPaginator />
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import { prisma } from "~/server/db";
|
|||||||
import { requireCanModifyProject, requireCanViewProject } from "~/utils/accessControl";
|
import { requireCanModifyProject, requireCanViewProject } from "~/utils/accessControl";
|
||||||
import { error, success } from "~/utils/errorHandling/standardResponses";
|
import { error, success } from "~/utils/errorHandling/standardResponses";
|
||||||
import { countOpenAIChatTokens } from "~/utils/countTokens";
|
import { countOpenAIChatTokens } from "~/utils/countTokens";
|
||||||
|
import { type TrainingRow, validateTrainingRows } from "~/components/datasets/validateTrainingRows";
|
||||||
|
|
||||||
export const datasetEntriesRouter = createTRPCRouter({
|
export const datasetEntriesRouter = createTRPCRouter({
|
||||||
list: protectedProcedure
|
list: protectedProcedure
|
||||||
@@ -94,7 +95,8 @@ export const datasetEntriesRouter = createTRPCRouter({
|
|||||||
name: z.string(),
|
name: z.string(),
|
||||||
})
|
})
|
||||||
.optional(),
|
.optional(),
|
||||||
loggedCallIds: z.string().array(),
|
loggedCallIds: z.string().array().optional(),
|
||||||
|
jsonl: z.string().optional(),
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
.mutation(async ({ input, ctx }) => {
|
.mutation(async ({ input, ctx }) => {
|
||||||
@@ -115,8 +117,14 @@ export const datasetEntriesRouter = createTRPCRouter({
|
|||||||
return error("No datasetId or newDatasetParams provided");
|
return error("No datasetId or newDatasetParams provided");
|
||||||
}
|
}
|
||||||
|
|
||||||
const [loggedCalls, existingTrainingCount, existingTestingCount] = await prisma.$transaction([
|
if (!input.loggedCallIds && !input.jsonl) {
|
||||||
prisma.loggedCall.findMany({
|
return error("No loggedCallIds or jsonl provided");
|
||||||
|
}
|
||||||
|
|
||||||
|
let trainingRows: TrainingRow[];
|
||||||
|
|
||||||
|
if (input.loggedCallIds) {
|
||||||
|
const loggedCalls = await prisma.loggedCall.findMany({
|
||||||
where: {
|
where: {
|
||||||
id: {
|
id: {
|
||||||
in: input.loggedCallIds,
|
in: input.loggedCallIds,
|
||||||
@@ -135,7 +143,39 @@ export const datasetEntriesRouter = createTRPCRouter({
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}),
|
orderBy: { createdAt: "desc" },
|
||||||
|
});
|
||||||
|
|
||||||
|
trainingRows = loggedCalls.map((loggedCall) => {
|
||||||
|
const inputMessages = (
|
||||||
|
loggedCall.modelResponse?.reqPayload as unknown as CompletionCreateParams
|
||||||
|
).messages;
|
||||||
|
let output: ChatCompletion.Choice.Message | undefined = undefined;
|
||||||
|
const resp = loggedCall.modelResponse?.respPayload as unknown as
|
||||||
|
| ChatCompletion
|
||||||
|
| undefined;
|
||||||
|
if (resp && resp.choices?.[0]) {
|
||||||
|
output = resp.choices[0].message;
|
||||||
|
} else {
|
||||||
|
output = {
|
||||||
|
role: "assistant",
|
||||||
|
content: "",
|
||||||
|
};
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
input: inputMessages as unknown as CreateChatCompletionRequestMessage[],
|
||||||
|
output: output as unknown as CreateChatCompletionRequestMessage,
|
||||||
|
};
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
trainingRows = JSON.parse(input.jsonl as string) as TrainingRow[];
|
||||||
|
const validationError = validateTrainingRows(trainingRows);
|
||||||
|
if (validationError) {
|
||||||
|
return error(`Invalid JSONL: ${validationError}`);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const [existingTrainingCount, existingTestingCount] = await prisma.$transaction([
|
||||||
prisma.datasetEntry.count({
|
prisma.datasetEntry.count({
|
||||||
where: {
|
where: {
|
||||||
datasetId,
|
datasetId,
|
||||||
@@ -150,39 +190,32 @@ export const datasetEntriesRouter = createTRPCRouter({
|
|||||||
}),
|
}),
|
||||||
]);
|
]);
|
||||||
|
|
||||||
const shuffledLoggedCalls = shuffle(loggedCalls);
|
const newTotalEntries = existingTrainingCount + existingTestingCount + trainingRows.length;
|
||||||
|
const numTrainingToAdd = Math.floor(trainingRatio * newTotalEntries) - existingTrainingCount;
|
||||||
const totalEntries = existingTrainingCount + existingTestingCount + loggedCalls.length;
|
const numTestingToAdd = trainingRows.length - numTrainingToAdd;
|
||||||
const numTrainingToAdd = Math.floor(trainingRatio * totalEntries) - existingTrainingCount;
|
const typesToAssign = shuffle([
|
||||||
|
...Array(numTrainingToAdd).fill("TRAIN"),
|
||||||
|
...Array(numTestingToAdd).fill("TEST"),
|
||||||
|
]);
|
||||||
const datasetEntriesToCreate: Prisma.DatasetEntryCreateManyInput[] = [];
|
const datasetEntriesToCreate: Prisma.DatasetEntryCreateManyInput[] = [];
|
||||||
|
for (const row of trainingRows) {
|
||||||
let i = 0;
|
let outputTokens = 0;
|
||||||
for (const loggedCall of shuffledLoggedCalls) {
|
if (row.output) {
|
||||||
const inputMessages = (
|
outputTokens = countOpenAIChatTokens("gpt-4-0613", [
|
||||||
loggedCall.modelResponse?.reqPayload as unknown as CompletionCreateParams
|
row.output as unknown as ChatCompletion.Choice.Message,
|
||||||
).messages;
|
]);
|
||||||
let output: ChatCompletion.Choice.Message | undefined = undefined;
|
|
||||||
const resp = loggedCall.modelResponse?.respPayload as unknown as ChatCompletion | undefined;
|
|
||||||
if (resp && resp.choices?.[0]) {
|
|
||||||
output = resp.choices[0].message;
|
|
||||||
} else {
|
|
||||||
output = {
|
|
||||||
role: "assistant",
|
|
||||||
content: "",
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
datasetEntriesToCreate.push({
|
datasetEntriesToCreate.push({
|
||||||
datasetId,
|
datasetId: datasetId,
|
||||||
loggedCallId: loggedCall.id,
|
input: row.input as unknown as Prisma.InputJsonValue,
|
||||||
input: inputMessages as unknown as Prisma.InputJsonValue,
|
output: row.output as unknown as Prisma.InputJsonValue,
|
||||||
output: output as unknown as Prisma.InputJsonValue,
|
inputTokens: countOpenAIChatTokens(
|
||||||
inputTokens: loggedCall.modelResponse?.inputTokens || 0,
|
"gpt-4-0613",
|
||||||
outputTokens: loggedCall.modelResponse?.outputTokens || 0,
|
row.input as unknown as CreateChatCompletionRequestMessage[],
|
||||||
type: i < numTrainingToAdd ? "TRAIN" : "TEST",
|
),
|
||||||
|
outputTokens,
|
||||||
|
type: typesToAssign.pop() as "TRAIN" | "TEST",
|
||||||
});
|
});
|
||||||
i++;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure dataset and dataset entries are created atomically
|
// Ensure dataset and dataset entries are created atomically
|
||||||
@@ -198,7 +231,7 @@ export const datasetEntriesRouter = createTRPCRouter({
|
|||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
prisma.datasetEntry.createMany({
|
prisma.datasetEntry.createMany({
|
||||||
data: shuffle(datasetEntriesToCreate),
|
data: datasetEntriesToCreate,
|
||||||
}),
|
}),
|
||||||
]);
|
]);
|
||||||
|
|
||||||
@@ -242,7 +275,8 @@ export const datasetEntriesRouter = createTRPCRouter({
|
|||||||
|
|
||||||
let parsedOutput = undefined;
|
let parsedOutput = undefined;
|
||||||
let outputTokens = undefined;
|
let outputTokens = undefined;
|
||||||
if (input.updates.output) {
|
// The client might send "null" as a string, so we need to check for that
|
||||||
|
if (input.updates.output && input.updates.output !== "null") {
|
||||||
parsedOutput = JSON.parse(input.updates.output);
|
parsedOutput = JSON.parse(input.updates.output);
|
||||||
outputTokens = countOpenAIChatTokens("gpt-4-0613", [
|
outputTokens = countOpenAIChatTokens("gpt-4-0613", [
|
||||||
parsedOutput as unknown as ChatCompletion.Choice.Message,
|
parsedOutput as unknown as ChatCompletion.Choice.Message,
|
||||||
|
|||||||
Reference in New Issue
Block a user