Upload training data through Azure Blob Storage
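In brief, large training-data files now reach the server through Azure Blob Storage instead of being posted inline: the browser asks the API for a short-lived SAS service-client URL, uploads the JSONL file straight to the storage container, then records the upload so a background worker can download, validate, and import the rows while reporting progress. A rough sketch of the client-side handoff, using the names introduced in the diff below:

// Sketch of the flow this commit wires up (see ImportDataModal below)
const blobName = await uploadDatasetEntryFile(file); // direct upload to Azure via a SAS URL
await triggerFileDownloadMutation.mutateAsync({
  datasetId: dataset.id,
  blobName,
  fileName: file.name,
  fileSize: file.size,
});
// A DatasetFileUpload row is created and the "importDatasetEntries" task is queued;
// the dataset page polls listFileUploads to show progress via FileUploadCard.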
@@ -40,3 +40,8 @@ SMTP_HOST="placeholder"
SMTP_PORT="placeholder"
SMTP_LOGIN="placeholder"
SMTP_PASSWORD="placeholder"

# Azure credentials are necessary for uploading large training data files
AZURE_STORAGE_ACCOUNT_NAME="placeholder"
AZURE_STORAGE_ACCOUNT_KEY="placeholder"
AZURE_STORAGE_CONTAINER_NAME="placeholder"

@@ -26,6 +26,8 @@
  "dependencies": {
    "@anthropic-ai/sdk": "^0.5.8",
    "@apidevtools/json-schema-ref-parser": "^10.1.0",
    "@azure/identity": "^3.3.0",
    "@azure/storage-blob": "12.15.0",
    "@babel/standalone": "^7.22.9",
    "@chakra-ui/anatomy": "^2.2.0",
    "@chakra-ui/next-js": "^2.1.4",
@@ -69,6 +71,7 @@
    "jsonschema": "^1.4.1",
    "kysely": "^0.26.1",
    "kysely-codegen": "^0.10.1",
    "llama-tokenizer-js": "^1.1.3",
    "lodash-es": "^4.17.21",
    "lucide-react": "^0.265.0",
    "marked": "^7.0.3",

@@ -0,0 +1,23 @@
-- CreateEnum
CREATE TYPE "DatasetFileUploadStatus" AS ENUM ('PENDING', 'DOWNLOADING', 'PROCESSING', 'SAVING', 'COMPLETE', 'ERROR');

-- CreateTable
CREATE TABLE "DatasetFileUpload" (
    "id" UUID NOT NULL,
    "datasetId" UUID NOT NULL,
    "blobName" TEXT NOT NULL,
    "fileName" TEXT NOT NULL,
    "fileSize" INTEGER NOT NULL,
    "progress" INTEGER NOT NULL DEFAULT 0,
    "status" "DatasetFileUploadStatus" NOT NULL DEFAULT 'PENDING',
    "uploadedAt" TIMESTAMP(3) NOT NULL,
    "visible" BOOLEAN NOT NULL DEFAULT true,
    "errorMessage" TEXT,
    "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
    "updatedAt" TIMESTAMP(3) NOT NULL,

    CONSTRAINT "DatasetFileUpload_pkey" PRIMARY KEY ("id")
);

-- AddForeignKey
ALTER TABLE "DatasetFileUpload" ADD CONSTRAINT "DatasetFileUpload_datasetId_fkey" FOREIGN KEY ("datasetId") REFERENCES "Dataset"("id") ON DELETE CASCADE ON UPDATE CASCADE;
@@ -176,12 +176,41 @@ model OutputEvaluation {
  @@unique([modelResponseId, evaluationId])
}


enum DatasetFileUploadStatus {
  PENDING
  DOWNLOADING
  PROCESSING
  SAVING
  COMPLETE
  ERROR
}

model DatasetFileUpload {
  id String @id @default(uuid()) @db.Uuid

  datasetId String @db.Uuid
  dataset Dataset @relation(fields: [datasetId], references: [id], onDelete: Cascade)
  blobName String
  fileName String
  fileSize Int
  progress Int @default(0) // Percentage
  status DatasetFileUploadStatus @default(PENDING)
  uploadedAt DateTime
  visible Boolean @default(true)
  errorMessage String?

  createdAt DateTime @default(now())
  updatedAt DateTime @updatedAt
}

model Dataset {
  id String @id @default(uuid()) @db.Uuid

  name String
  datasetEntries DatasetEntry[]
  fineTunes FineTune[]
  datasetFileUploads DatasetFileUpload[]
  trainingRatio Float @default(0.8)

  projectId String @db.Uuid

app/src/components/datasets/FileUploadCard.tsx (new file, 61 lines)
@@ -0,0 +1,61 @@
import { VStack, HStack, Button, Text, Card, Progress, IconButton } from "@chakra-ui/react";
import { BsX } from "react-icons/bs";

import { type RouterOutputs, api } from "~/utils/api";
import { useHandledAsyncCallback } from "~/utils/hooks";
import { formatFileSize } from "~/utils/utils";

type FileUpload = RouterOutputs["datasets"]["listFileUploads"][0];

const FileUploadCard = ({ fileUpload }: { fileUpload: FileUpload }) => {
  const { id, fileName, fileSize, progress, status, errorMessage } = fileUpload;

  const utils = api.useContext();

  const hideFileUploadMutation = api.datasets.hideFileUpload.useMutation();
  const [hideFileUpload, hidingInProgress] = useHandledAsyncCallback(async () => {
    await hideFileUploadMutation.mutateAsync({ fileUploadId: id });
    await utils.datasets.listFileUploads.invalidate();
  }, [id, hideFileUploadMutation, utils]);

  const [refreshDatasetEntries] = useHandledAsyncCallback(async () => {
    await utils.datasetEntries.list.invalidate();
  }, [utils]);

  return (
    <Card w="full">
      <VStack w="full" alignItems="flex-start" p={4}>
        <HStack w="full" justifyContent="space-between">
          <Text fontWeight="bold">
            Uploading {fileName} ({formatFileSize(fileSize, 2)})
          </Text>
          <HStack spacing={0}>
            {status === "COMPLETE" && (
              <Button variant="ghost" onClick={refreshDatasetEntries} color="orange.400" size="xs">
                Refresh Table
              </Button>
            )}
            <IconButton
              aria-label="Hide file upload"
              as={BsX}
              boxSize={6}
              minW={0}
              variant="ghost"
              isLoading={hidingInProgress}
              onClick={hideFileUpload}
              cursor="pointer"
            />
          </HStack>
        </HStack>

        <Text alignSelf="center" fontSize="xs">
          {errorMessage ? errorMessage : `${status} (${progress}%)`}
        </Text>

        <Progress w="full" value={progress} borderRadius={2} />
      </VStack>
    </Card>
  );
};

export default FileUploadCard;
@@ -1,4 +1,4 @@
import { useState, useEffect, useRef } from "react";
import { useState, useEffect, useRef, useCallback } from "react";
import {
Modal,
ModalOverlay,
@@ -16,13 +16,15 @@ import {
useDisclosure,
type UseDisclosureReturn,
} from "@chakra-ui/react";
import pluralize from "pluralize";
import { AiOutlineCloudUpload, AiOutlineFile } from "react-icons/ai";

import { useDataset, useHandledAsyncCallback } from "~/utils/hooks";
import { api } from "~/utils/api";
import ActionButton from "../ActionButton";
import { validateTrainingRows, type TrainingRow, parseJSONL } from "./validateTrainingRows";
import pluralize from "pluralize";
import { uploadDatasetEntryFile } from "~/utils/azure/website";
import { formatFileSize } from "~/utils/utils";

const ImportDataButton = () => {
const disclosure = useDisclosure();
@@ -48,6 +50,7 @@ const ImportDataModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) =>

const [validationError, setValidationError] = useState<string | null>(null);
const [trainingRows, setTrainingRows] = useState<TrainingRow[] | null>(null);
const [file, setFile] = useState<File | null>(null);

const fileInputRef = useRef<HTMLInputElement>(null);

@@ -67,6 +70,14 @@ const ImportDataModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) =>
};

const processFile = (file: File) => {
setFile(file);

// skip reading if file is larger than 10MB
if (file.size > 10000000) {
setTrainingRows(null);
return;
}

const reader = new FileReader();
reader.onload = (e: ProgressEvent<FileReader>) => {
const content = e.target?.result as string;
@@ -83,7 +94,6 @@ const ImportDataModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) =>
setTrainingRows(parsedJSONL);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} catch (e: any) {
console.log("e is", e);
setValidationError("Unable to parse JSONL file: " + (e.message as string));
setTrainingRows(null);
return;
@@ -92,28 +102,38 @@ const ImportDataModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) =>
reader.readAsText(file);
};

const resetState = useCallback(() => {
setValidationError(null);
setTrainingRows(null);
setFile(null);
}, [setValidationError, setTrainingRows, setFile]);

useEffect(() => {
if (disclosure.isOpen) {
setTrainingRows(null);
setValidationError(null);
resetState();
}
}, [disclosure.isOpen]);
}, [disclosure.isOpen, resetState]);

const triggerFileDownloadMutation = api.datasets.triggerFileDownload.useMutation();

const utils = api.useContext();

const sendJSONLMutation = api.datasetEntries.create.useMutation();

const [sendJSONL, sendingInProgress] = useHandledAsyncCallback(async () => {
if (!dataset || !trainingRows) return;
if (!dataset || !file) return;

await sendJSONLMutation.mutateAsync({
const blobName = await uploadDatasetEntryFile(file);

await triggerFileDownloadMutation.mutateAsync({
datasetId: dataset.id,
jsonl: JSON.stringify(trainingRows),
blobName,
fileName: file.name,
fileSize: file.size,
});

await utils.datasetEntries.list.invalidate();
await utils.datasets.listFileUploads.invalidate();

disclosure.onClose();
}, [dataset, trainingRows, sendJSONLMutation]);
}, [dataset, trainingRows, triggerFileDownloadMutation, file, utils]);

return (
<Modal size={{ base: "xl", md: "2xl" }} {...disclosure}>
@@ -127,7 +147,28 @@ const ImportDataModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) =>
<ModalCloseButton />
<ModalBody maxW="unset" p={8}>
<Box w="full" aspectRatio={1.5}>
{!trainingRows && !validationError && (
{validationError && (
<VStack w="full" h="full" justifyContent="center" spacing={8}>
<Icon as={AiOutlineFile} boxSize={24} color="gray.300" />
<VStack w="full">
<Text fontSize={32} color="gray.500" fontWeight="bold">
Error
</Text>
<Text color="gray.500">{validationError}</Text>
</VStack>
<Text
as="span"
textDecor="underline"
color="gray.500"
_hover={{ color: "orange.400" }}
cursor="pointer"
onClick={resetState}
>
Try again
</Text>
</VStack>
)}
{!validationError && !file && (
<VStack
w="full"
h="full"
@@ -167,38 +208,28 @@ const ImportDataModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) =>
</Text>
</VStack>
)}
{validationError && (
<VStack w="full" h="full" justifyContent="center" spacing={8}>
<Icon as={AiOutlineFile} boxSize={24} color="gray.300" />
<VStack w="full">
<Text fontSize={32} color="gray.500" fontWeight="bold">
Error
</Text>
<Text color="gray.500">{validationError}</Text>
</VStack>
<Text
as="span"
textDecor="underline"
color="gray.500"
_hover={{ color: "orange.400" }}
cursor="pointer"
onClick={() => setValidationError(null)}
>
Try again
</Text>
</VStack>
)}
{trainingRows && !validationError && (
{!validationError && file && (
<VStack w="full" h="full" justifyContent="center" spacing={8}>
<JsonFileIcon />
<VStack w="full">
<Text fontSize={32} color="gray.500" fontWeight="bold">
Success
</Text>
<Text color="gray.500">
We'll upload <b>{trainingRows.length}</b>{" "}
{pluralize("row", trainingRows.length)} into <b>{dataset?.name}</b>.{" "}
</Text>
{trainingRows ? (
<>
<Text fontSize={32} color="gray.500" fontWeight="bold">
Success
</Text>
<Text color="gray.500">
We'll upload <b>{trainingRows.length}</b>{" "}
{pluralize("row", trainingRows.length)} into <b>{dataset?.name}</b>.{" "}
</Text>
</>
) : (
<>
<Text fontSize={32} color="gray.500" fontWeight="bold">
{file.name}
</Text>
<Text color="gray.500">{formatFileSize(file.size)}</Text>
</>
)}
</VStack>
<Text
as="span"
@@ -206,7 +237,7 @@ const ImportDataModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) =>
color="gray.500"
_hover={{ color: "orange.400" }}
cursor="pointer"
onClick={() => setTrainingRows(null)}
onClick={resetState}
>
Change file
</Text>
@@ -224,7 +255,7 @@ const ImportDataModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) =>
onClick={sendJSONL}
isLoading={sendingInProgress}
minW={24}
isDisabled={!trainingRows || !!validationError}
isDisabled={!file || !!validationError}
>
Upload
</Button>

@@ -26,6 +26,9 @@ export const env = createEnv({
SMTP_PORT: z.string().default("placeholder"),
SMTP_LOGIN: z.string().default("placeholder"),
SMTP_PASSWORD: z.string().default("placeholder"),
AZURE_STORAGE_ACCOUNT_NAME: z.string().default("placeholder"),
AZURE_STORAGE_ACCOUNT_KEY: z.string().default("placeholder"),
AZURE_STORAGE_CONTAINER_NAME: z.string().default("placeholder"),
WORKER_CONCURRENCY: z
.string()
.default("10")
@@ -72,6 +75,9 @@ export const env = createEnv({
SMTP_PORT: process.env.SMTP_PORT,
SMTP_LOGIN: process.env.SMTP_LOGIN,
SMTP_PASSWORD: process.env.SMTP_PASSWORD,
AZURE_STORAGE_ACCOUNT_NAME: process.env.AZURE_STORAGE_ACCOUNT_NAME,
AZURE_STORAGE_ACCOUNT_KEY: process.env.AZURE_STORAGE_ACCOUNT_KEY,
AZURE_STORAGE_CONTAINER_NAME: process.env.AZURE_STORAGE_CONTAINER_NAME,
WORKER_CONCURRENCY: process.env.WORKER_CONCURRENCY,
WORKER_MAX_POOL_SIZE: process.env.WORKER_MAX_POOL_SIZE,
},

@@ -28,6 +28,7 @@ import ExperimentButton from "~/components/datasets/ExperimentButton";
import ImportDataButton from "~/components/datasets/ImportDataButton";
import DownloadButton from "~/components/datasets/ExportButton";
import DeleteButton from "~/components/datasets/DeleteButton";
import FileUploadCard from "~/components/datasets/FileUploadCard";

export default function Dataset() {
const utils = api.useContext();
@@ -40,6 +41,19 @@ export default function Dataset() {
setName(dataset.data?.name || "");
}, [dataset.data?.name]);

const [fileUploadsRefetchInterval, setFileUploadsRefetchInterval] = useState<number>(500);
const fileUploads = api.datasets.listFileUploads.useQuery(
{ datasetId: dataset.data?.id as string },
{ enabled: !!dataset.data?.id, refetchInterval: fileUploadsRefetchInterval },
);
useEffect(() => {
if (fileUploads?.data?.some((fu) => fu.status !== "COMPLETE" && fu.status !== "ERROR")) {
setFileUploadsRefetchInterval(500);
} else {
setFileUploadsRefetchInterval(0);
}
}, [fileUploads]);

useEffect(() => {
useAppStore.getState().sharedArgumentsEditor.loadMonaco().catch(console.error);
}, []);
@@ -101,6 +115,13 @@ export default function Dataset() {
<DatasetHeaderButtons openDrawer={drawerDisclosure.onOpen} />
</PageHeaderContainer>
<VStack px={8} py={8} alignItems="flex-start" spacing={4} w="full">
<HStack w="full">
<VStack w="full">
{fileUploads?.data?.map((upload) => (
<FileUploadCard key={upload.id} fileUpload={upload} />
))}
</VStack>
</HStack>
<HStack w="full" justifyContent="flex-end">
<FineTuneButton />
<ImportDataButton />

@@ -1,4 +1,3 @@
import { type Prisma } from "@prisma/client";
import { z } from "zod";
import { v4 as uuidv4 } from "uuid";
import {
@@ -7,18 +6,18 @@ import {
type CreateChatCompletionRequestMessage,
} from "openai/resources/chat";
import { TRPCError } from "@trpc/server";
import { shuffle } from "lodash-es";
import archiver from "archiver";
import { WritableStreamBuffer } from "stream-buffers";

import { createTRPCRouter, protectedProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import { requireCanModifyProject, requireCanViewProject } from "~/utils/accessControl";
import { error, success } from "~/utils/errorHandling/standardResponses";
import { countOpenAIChatTokens } from "~/utils/countTokens";
import { type TrainingRow, validateTrainingRows } from "~/components/datasets/validateTrainingRows";
import { type TrainingRow } from "~/components/datasets/validateTrainingRows";
import hashObject from "~/server/utils/hashObject";
import { type JsonValue } from "type-fest";
import { WritableStreamBuffer } from "stream-buffers";
import { formatEntriesFromTrainingRows } from "~/server/utils/createEntriesFromTrainingRows";

export const datasetEntriesRouter = createTRPCRouter({
list: protectedProcedure
@@ -100,7 +99,6 @@ export const datasetEntriesRouter = createTRPCRouter({
})
.optional(),
loggedCallIds: z.string().array().optional(),
jsonl: z.string().optional(),
}),
)
.mutation(async ({ input, ctx }) => {
@@ -121,104 +119,48 @@ export const datasetEntriesRouter = createTRPCRouter({
return error("No datasetId or newDatasetParams provided");
}

if (!input.loggedCallIds && !input.jsonl) {
return error("No loggedCallIds or jsonl provided");
if (!input.loggedCallIds) {
return error("No loggedCallIds provided");
}

let trainingRows: TrainingRow[];

if (input.loggedCallIds) {
const loggedCalls = await prisma.loggedCall.findMany({
where: {
id: {
in: input.loggedCallIds,
},
modelResponse: {
isNot: null,
const loggedCalls = await prisma.loggedCall.findMany({
where: {
id: {
in: input.loggedCallIds,
},
modelResponse: {
isNot: null,
},
},
include: {
modelResponse: {
select: {
reqPayload: true,
respPayload: true,
inputTokens: true,
outputTokens: true,
},
},
include: {
modelResponse: {
select: {
reqPayload: true,
respPayload: true,
inputTokens: true,
outputTokens: true,
},
},
},
orderBy: { createdAt: "desc" },
});
},
orderBy: { createdAt: "desc" },
});

trainingRows = loggedCalls.map((loggedCall) => {
const inputMessages = (
loggedCall.modelResponse?.reqPayload as unknown as CompletionCreateParams
).messages;
let output: ChatCompletion.Choice.Message | undefined = undefined;
const resp = loggedCall.modelResponse?.respPayload as unknown as
| ChatCompletion
| undefined;
if (resp && resp.choices?.[0]) {
output = resp.choices[0].message;
}
return {
input: inputMessages as unknown as CreateChatCompletionRequestMessage[],
output: output as unknown as CreateChatCompletionRequestMessage,
};
});
} else {
trainingRows = JSON.parse(input.jsonl as string) as TrainingRow[];
const validationError = validateTrainingRows(trainingRows);
if (validationError) {
return error(`Invalid JSONL: ${validationError}`);
const trainingRows = loggedCalls.map((loggedCall) => {
const inputMessages = (
loggedCall.modelResponse?.reqPayload as unknown as CompletionCreateParams
).messages;
let output: ChatCompletion.Choice.Message | undefined = undefined;
const resp = loggedCall.modelResponse?.respPayload as unknown as ChatCompletion | undefined;
if (resp && resp.choices?.[0]) {
output = resp.choices[0].message;
}
}
return {
input: inputMessages as unknown as CreateChatCompletionRequestMessage[],
output: output as unknown as CreateChatCompletionRequestMessage,
};
});

const [existingTrainingCount, existingTestingCount] = await prisma.$transaction([
prisma.datasetEntry.count({
where: {
datasetId,
type: "TRAIN",
},
}),
prisma.datasetEntry.count({
where: {
datasetId,
type: "TEST",
},
}),
]);

const newTotalEntries = existingTrainingCount + existingTestingCount + trainingRows.length;
const numTrainingToAdd = Math.floor(trainingRatio * newTotalEntries) - existingTrainingCount;
const numTestingToAdd = trainingRows.length - numTrainingToAdd;
const typesToAssign = shuffle([
...Array(numTrainingToAdd).fill("TRAIN"),
...Array(numTestingToAdd).fill("TEST"),
]);
const datasetEntriesToCreate: Prisma.DatasetEntryCreateManyInput[] = [];
for (const row of trainingRows) {
let outputTokens = 0;
if (row.output) {
outputTokens = countOpenAIChatTokens("gpt-4-0613", [
row.output as unknown as ChatCompletion.Choice.Message,
]);
}
datasetEntriesToCreate.push({
datasetId: datasetId,
input: row.input as unknown as Prisma.InputJsonValue,
output: (row.output as unknown as Prisma.InputJsonValue) ?? {
role: "assistant",
content: "",
},
inputTokens: countOpenAIChatTokens(
"gpt-4-0613",
row.input as unknown as CreateChatCompletionRequestMessage[],
),
outputTokens,
type: typesToAssign.pop() as "TRAIN" | "TEST",
});
}
const datasetEntriesToCreate = await formatEntriesFromTrainingRows(datasetId, trainingRows);

// Ensure dataset and dataset entries are created atomically
await prisma.$transaction([
@@ -239,7 +181,6 @@ export const datasetEntriesRouter = createTRPCRouter({

return success(datasetId);
}),

update: protectedProcedure
.input(
z.object({

@@ -1,8 +1,11 @@
import { z } from "zod";

import { createTRPCRouter, protectedProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import { requireCanModifyProject, requireCanViewProject } from "~/utils/accessControl";
import { success } from "~/utils/errorHandling/standardResponses";
import { generateServiceClientUrl } from "~/utils/azure/server";
import { queueImportDatasetEntries } from "~/server/tasks/importDatasetEntries.task";

export const datasetsRouter = createTRPCRouter({
get: protectedProcedure.input(z.object({ id: z.string() })).query(async ({ input, ctx }) => {
@@ -94,4 +97,73 @@ export const datasetsRouter = createTRPCRouter({

return success("Dataset deleted");
}),
getServiceClientUrl: protectedProcedure
.input(z.object({ projectId: z.string() }))
.query(async ({ input, ctx }) => {
// The user must at least be authenticated to get a SAS token
await requireCanModifyProject(input.projectId, ctx);
return generateServiceClientUrl();
}),
triggerFileDownload: protectedProcedure
.input(
z.object({
datasetId: z.string(),
blobName: z.string(),
fileName: z.string(),
fileSize: z.number(),
}),
)
.mutation(async ({ input, ctx }) => {
const { projectId } = await prisma.dataset.findUniqueOrThrow({
where: { id: input.datasetId },
});
await requireCanViewProject(projectId, ctx);

const { id } = await prisma.datasetFileUpload.create({
data: {
datasetId: input.datasetId,
blobName: input.blobName,
status: "PENDING",
fileName: input.fileName,
fileSize: input.fileSize,
uploadedAt: new Date(),
},
});

await queueImportDatasetEntries(id);
}),
listFileUploads: protectedProcedure
.input(z.object({ datasetId: z.string() }))
.query(async ({ input, ctx }) => {
const { projectId } = await prisma.dataset.findUniqueOrThrow({
where: { id: input.datasetId },
});
await requireCanViewProject(projectId, ctx);

return await prisma.datasetFileUpload.findMany({
where: {
datasetId: input.datasetId,
visible: true,
},
orderBy: { createdAt: "desc" },
});
}),
hideFileUpload: protectedProcedure
.input(z.object({ fileUploadId: z.string() }))
.mutation(async ({ input, ctx }) => {
const { datasetId } = await prisma.datasetFileUpload.findUniqueOrThrow({
where: { id: input.fileUploadId },
});
const { projectId } = await prisma.dataset.findUniqueOrThrow({
where: { id: datasetId },
});
await requireCanModifyProject(projectId, ctx);

await prisma.datasetFileUpload.update({
where: { id: input.fileUploadId },
data: {
visible: false,
},
});
}),
});

app/src/server/tasks/importDatasetEntries.task.ts (new file, 132 lines)
@@ -0,0 +1,132 @@
import { type DatasetFileUpload } from "@prisma/client";
import { prisma } from "~/server/db";
import defineTask from "./defineTask";
import { downloadBlobToString } from "~/utils/azure/server";
import {
  type TrainingRow,
  validateTrainingRows,
  parseJSONL,
} from "~/components/datasets/validateTrainingRows";
import { formatEntriesFromTrainingRows } from "~/server/utils/createEntriesFromTrainingRows";

export type ImportDatasetEntriesJob = {
  datasetFileUploadId: string;
};

export const importDatasetEntries = defineTask<ImportDatasetEntriesJob>(
  "importDatasetEntries",
  async (task) => {
    const { datasetFileUploadId } = task;
    const datasetFileUpload = await prisma.datasetFileUpload.findUnique({
      where: { id: datasetFileUploadId },
    });
    if (!datasetFileUpload) {
      await prisma.datasetFileUpload.update({
        where: { id: datasetFileUploadId },
        data: {
          errorMessage: "Dataset File Upload not found",
          status: "ERROR",
        },
      });
      return;
    }
    await prisma.datasetFileUpload.update({
      where: { id: datasetFileUploadId },
      data: {
        status: "DOWNLOADING",
        progress: 5,
      },
    });

    const jsonlStr = await downloadBlobToString(datasetFileUpload.blobName);
    const trainingRows = parseJSONL(jsonlStr) as TrainingRow[];
    const validationError = validateTrainingRows(trainingRows);
    if (validationError) {
      await prisma.datasetFileUpload.update({
        where: { id: datasetFileUploadId },
        data: {
          errorMessage: `Invalid JSONL: ${validationError}`,
          status: "ERROR",
        },
      });
      return;
    }

    await prisma.datasetFileUpload.update({
      where: { id: datasetFileUploadId },
      data: {
        status: "PROCESSING",
        progress: 30,
      },
    });

    const updatePromises: Promise<DatasetFileUpload>[] = [];

    const updateCallback = async (progress: number) => {
      await prisma.datasetFileUpload.update({
        where: { id: datasetFileUploadId },
        data: {
          progress: 30 + Math.floor((progress / trainingRows.length) * 69),
        },
      });
    };

    let datasetEntriesToCreate;
    try {
      datasetEntriesToCreate = await formatEntriesFromTrainingRows(
        datasetFileUpload.datasetId,
        trainingRows,
        updateCallback,
        500,
      );
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
    } catch (e: any) {
      await prisma.datasetFileUpload.update({
        where: { id: datasetFileUploadId },
        data: {
          errorMessage: `Error formatting rows: ${e.message as string}`,
          status: "ERROR",
        },
      });
      return;
    }

    await Promise.all(updatePromises);

    await prisma.datasetFileUpload.update({
      where: { id: datasetFileUploadId },
      data: {
        status: "SAVING",
        progress: 99,
      },
    });

    await prisma.datasetEntry.createMany({
      data: datasetEntriesToCreate,
    });

    await prisma.datasetFileUpload.update({
      where: { id: datasetFileUploadId },
      data: {
        status: "COMPLETE",
        progress: 100,
      },
    });
  },
);

export const queueImportDatasetEntries = async (datasetFileUploadId: string) => {
  await Promise.all([
    prisma.datasetFileUpload.update({
      where: {
        id: datasetFileUploadId,
      },
      data: {
        errorMessage: null,
        status: "PENDING",
      },
    }),

    importDatasetEntries.enqueue({ datasetFileUploadId }),
  ]);
};
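Note on the progress numbers above: the task maps the whole import onto a single 0-100 bar (5 once the download starts, 30 when processing begins, 30-99 while rows are formatted, 100 on completion). A worked example of updateCallback, assuming a hypothetical file of 10,000 rows:

// progress reported after 5,000 of 10,000 rows have been formatted:
// 30 + Math.floor((5000 / 10000) * 69) = 30 + 34 = 64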
@@ -5,10 +5,11 @@ import "../../../sentry.server.config";
import { env } from "~/env.mjs";
import { queryModel } from "./queryModel.task";
import { runNewEval } from "./runNewEval.task";
import { importDatasetEntries } from "./importDatasetEntries.task";

console.log("Starting worker");

const registeredTasks = [queryModel, runNewEval];
const registeredTasks = [queryModel, runNewEval, importDatasetEntries];

const taskList = registeredTasks.reduce((acc, task) => {
acc[task.task.identifier] = task.task.handler;

app/src/server/utils/createEntriesFromTrainingRows.ts (new file, 70 lines)
@@ -0,0 +1,70 @@
import { type Prisma } from "@prisma/client";
import { shuffle } from "lodash-es";
import {
  type CreateChatCompletionRequestMessage,
  type ChatCompletion,
} from "openai/resources/chat";

import { prisma } from "~/server/db";
import { type TrainingRow } from "~/components/datasets/validateTrainingRows";
import { countLlamaChatTokens } from "~/utils/countTokens";

export const formatEntriesFromTrainingRows = async (
  datasetId: string,
  trainingRows: TrainingRow[],
  updateCallback?: (progress: number) => Promise<void>,
  updateFrequency = 1000,
) => {
  const [dataset, existingTrainingCount, existingTestingCount] = await prisma.$transaction([
    prisma.dataset.findUnique({ where: { id: datasetId } }),
    prisma.datasetEntry.count({
      where: {
        datasetId,
        type: "TRAIN",
      },
    }),
    prisma.datasetEntry.count({
      where: {
        datasetId,
        type: "TEST",
      },
    }),
  ]);

  const trainingRatio = dataset?.trainingRatio ?? 0.8;

  const newTotalEntries = existingTrainingCount + existingTestingCount + trainingRows.length;
  const numTrainingToAdd = Math.floor(trainingRatio * newTotalEntries) - existingTrainingCount;
  const numTestingToAdd = trainingRows.length - numTrainingToAdd;
  const typesToAssign = shuffle([
    ...Array(numTrainingToAdd).fill("TRAIN"),
    ...Array(numTestingToAdd).fill("TEST"),
  ]);
  const datasetEntriesToCreate: Prisma.DatasetEntryCreateManyInput[] = [];
  let i = 0;
  for (const row of trainingRows) {
    // console.log(row);
    if (updateCallback && i % updateFrequency === 0) await updateCallback(i);
    let outputTokens = 0;
    if (row.output) {
      outputTokens = countLlamaChatTokens([row.output as unknown as ChatCompletion.Choice.Message]);
    }
    // console.log("outputTokens", outputTokens);
    datasetEntriesToCreate.push({
      datasetId: datasetId,
      input: row.input as unknown as Prisma.InputJsonValue,
      output: (row.output as unknown as Prisma.InputJsonValue) ?? {
        role: "assistant",
        content: "",
      },
      inputTokens: countLlamaChatTokens(
        row.input as unknown as CreateChatCompletionRequestMessage[],
      ),
      outputTokens,
      type: typesToAssign.pop() as "TRAIN" | "TEST",
    });
    i++;
  }

  return datasetEntriesToCreate;
};
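The split arithmetic above keeps the whole dataset near its configured trainingRatio rather than splitting only the incoming batch. A worked example with assumed counts: if a dataset already holds 80 TRAIN and 20 TEST entries (ratio 0.8) and 100 new rows arrive, then newTotalEntries = 200, numTrainingToAdd = floor(0.8 * 200) - 80 = 80, and numTestingToAdd = 100 - 80 = 20, so the shuffled typesToAssign gives the new rows an 80/20 split.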
app/src/utils/azure/server.ts (new file, 71 lines)
@@ -0,0 +1,71 @@
import {
  BlobServiceClient,
  generateAccountSASQueryParameters,
  AccountSASPermissions,
  AccountSASServices,
  AccountSASResourceTypes,
  StorageSharedKeyCredential,
  SASProtocol,
} from "@azure/storage-blob";
import { DefaultAzureCredential } from "@azure/identity";

const accountName = process.env.AZURE_STORAGE_ACCOUNT_NAME;
if (!accountName) throw Error("Azure Storage accountName not found");
const accountKey = process.env.AZURE_STORAGE_ACCOUNT_KEY;
if (!accountKey) throw Error("Azure Storage accountKey not found");
const containerName = process.env.AZURE_STORAGE_CONTAINER_NAME;
if (!containerName) throw Error("Azure Storage containerName not found");

const sharedKeyCredential = new StorageSharedKeyCredential(accountName, accountKey);

const blobServiceClient = new BlobServiceClient(
  `https://${accountName}.blob.core.windows.net`,
  new DefaultAzureCredential(),
);

const containerClient = blobServiceClient.getContainerClient(containerName);

export const generateServiceClientUrl = () => {
  const sasOptions = {
    services: AccountSASServices.parse("b").toString(), // blobs
    resourceTypes: AccountSASResourceTypes.parse("sco").toString(), // service, container, object
    permissions: AccountSASPermissions.parse("w"), // write permissions
    protocol: SASProtocol.Https,
    startsOn: new Date(),
    expiresOn: new Date(new Date().valueOf() + 10 * 60 * 1000), // 10 minutes
  };
  let sasToken = generateAccountSASQueryParameters(sasOptions, sharedKeyCredential).toString();

  // remove leading "?"
  sasToken = sasToken[0] === "?" ? sasToken.substring(1) : sasToken;
  return {
    serviceClientUrl: `https://${accountName}.blob.core.windows.net?${sasToken}`,
    containerName,
  };
};

export async function downloadBlobToString(blobName: string) {
  const blobClient = containerClient.getBlobClient(blobName);

  const downloadResponse = await blobClient.download();

  if (!downloadResponse) throw Error("error downloading blob");
  if (!downloadResponse.readableStreamBody)
    throw Error("downloadResponse.readableStreamBody not found");

  const downloaded = await streamToBuffer(downloadResponse.readableStreamBody);
  return downloaded.toString();
}

async function streamToBuffer(readableStream: NodeJS.ReadableStream): Promise<Buffer> {
  return new Promise((resolve, reject) => {
    const chunks: Uint8Array[] = [];
    readableStream.on("data", (data: ArrayBuffer) => {
      chunks.push(data instanceof Buffer ? data : Buffer.from(data));
    });
    readableStream.on("end", () => {
      resolve(Buffer.concat(chunks));
    });
    readableStream.on("error", reject);
  });
}

app/src/utils/azure/website.ts (new file, 30 lines)
@@ -0,0 +1,30 @@
import { BlobServiceClient } from "@azure/storage-blob";
import { v4 as uuidv4 } from "uuid";

import { useAppStore } from "~/state/store";

export const uploadDatasetEntryFile = async (file: File) => {
  const { selectedProjectId: projectId, api } = useAppStore.getState();
  if (!projectId) throw Error("projectId not found");
  if (!api) throw Error("api not initialized");
  const { serviceClientUrl, containerName } = await api.client.datasets.getServiceClientUrl.query({
    projectId,
  });

  const blobServiceClient = new BlobServiceClient(serviceClientUrl);
  // create container client
  const containerClient = blobServiceClient.getContainerClient(containerName);

  // base name without extension
  const basename = file.name.split("/").pop()?.split(".").shift();
  if (!basename) throw Error("basename not found");

  const blobName = `${basename}-${uuidv4()}.jsonl`;
  // create blob client
  const blobClient = containerClient.getBlockBlobClient(blobName);

  // upload file
  await blobClient.uploadData(file);

  return blobName;
};
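Note on blob naming: only the portion of the file name before the first dot is kept, then a UUID and the .jsonl extension are appended. For a hypothetical upload named customer-chats.v2.jsonl the stored blob would be customer-chats-&lt;uuid&gt;.jsonl, so repeated uploads of the same file never collide.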
@@ -1,5 +1,7 @@
import { type ChatCompletion } from "openai/resources/chat";
import { GPTTokens } from "gpt-tokens";
import llamaTokenizer from "llama-tokenizer-js";

import { type SupportedModel } from "~/modelProviders/openai-ChatCompletion";

interface GPTTokensMessageItem {
@@ -22,3 +24,11 @@ export const countOpenAIChatTokens = (
messages: reformattedMessages as unknown as GPTTokensMessageItem[],
}).usedTokens;
};

export const countLlamaChatTokens = (messages: ChatCompletion.Choice.Message[]) => {
const stringToTokenize = messages
.map((message) => message.content || JSON.stringify(message.function_call))
.join("\n");
const tokens = llamaTokenizer.encode(stringToTokenize);
return tokens.length;
};
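For illustration, a minimal usage sketch of the new helper (the message content is made up; the exact count depends on the Llama tokenizer's vocabulary):

// const nTokens = countLlamaChatTokens([
//   { role: "assistant", content: "Hello there" } as ChatCompletion.Choice.Message,
// ]);
// nTokens is the number of Llama tokens in the joined message contents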
@@ -52,3 +52,18 @@ export const parseableToFunctionCall = (str: string) => {

return true;
};

export const formatFileSize = (bytes: number, decimals = 2) => {
if (bytes === 0) return "0 Bytes";

const k = 1024;
const dm = decimals < 0 ? 0 : decimals;
const sizes = ["Bytes", "KB", "MB", "GB", "TB"];

for (const size of sizes) {
if (bytes < k) return `${parseFloat(bytes.toFixed(dm))} ${size}`;
bytes /= k;
}

return "> 1024 TB";
};
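A few illustrative values for the helper above (binary units, default of 2 decimals): formatFileSize(0) returns "0 Bytes", formatFileSize(1536) returns "1.5 KB", and formatFileSize(10000000) returns "9.54 MB".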
@@ -19,7 +19,9 @@
"baseUrl": ".",
"paths": {
"~/*": ["./src/*"]
}
},
"typeRoots": ["./types", "./node_modules/@types"],
"types": ["llama-tokenizer-js", "node"]
},
"include": [
".eslintrc.cjs",

app/types/llama-tokenizer-js/index.d.ts (new file, vendored, 4 lines)
@@ -0,0 +1,4 @@
declare module "llama-tokenizer-js" {
  export function encode(input: string): number[];
  export function decode(input: number[]): string;
}