Export formatted logged calls (#187)
* Export formatted data * Properly update inputMessageHashMap * Hide remove duplicates checkbox in advanced options * Remove unused import
This commit is contained in:
@@ -48,6 +48,7 @@
|
||||
"@trpc/react-query": "^10.26.0",
|
||||
"@trpc/server": "^10.26.0",
|
||||
"@vercel/og": "^0.5.9",
|
||||
"archiver": "^6.0.0",
|
||||
"ast-types": "^0.14.2",
|
||||
"chroma-js": "^2.4.2",
|
||||
"concurrently": "^8.2.0",
|
||||
@@ -99,6 +100,7 @@
|
||||
"replicate": "^0.12.3",
|
||||
"socket.io": "^4.7.1",
|
||||
"socket.io-client": "^4.7.1",
|
||||
"stream-buffers": "^3.0.2",
|
||||
"superjson": "1.12.2",
|
||||
"trpc-openapi": "^1.2.0",
|
||||
"tsx": "^3.12.7",
|
||||
@@ -111,6 +113,7 @@
|
||||
},
|
||||
"devDependencies": {
|
||||
"@openapi-contrib/openapi-schema-to-json-schema": "^4.0.5",
|
||||
"@types/archiver": "^5.3.2",
|
||||
"@types/babel__core": "^7.20.1",
|
||||
"@types/babel__standalone": "^7.1.4",
|
||||
"@types/chroma-js": "^2.4.0",
|
||||
@@ -127,6 +130,7 @@
|
||||
"@types/react": "^18.2.6",
|
||||
"@types/react-dom": "^18.2.4",
|
||||
"@types/react-syntax-highlighter": "^15.5.7",
|
||||
"@types/stream-buffers": "^3.0.4",
|
||||
"@types/uuid": "^9.0.2",
|
||||
"@typescript-eslint/eslint-plugin": "^5.59.6",
|
||||
"@typescript-eslint/parser": "^5.59.6",
|
||||
|
||||
197
app/src/components/requestLogs/ExportButton.tsx
Normal file
197
app/src/components/requestLogs/ExportButton.tsx
Normal file
@@ -0,0 +1,197 @@
|
||||
import { useState, useEffect } from "react";
|
||||
import {
|
||||
Modal,
|
||||
ModalOverlay,
|
||||
ModalContent,
|
||||
ModalHeader,
|
||||
ModalCloseButton,
|
||||
ModalBody,
|
||||
ModalFooter,
|
||||
HStack,
|
||||
VStack,
|
||||
Icon,
|
||||
Text,
|
||||
Button,
|
||||
Checkbox,
|
||||
NumberInput,
|
||||
NumberInputField,
|
||||
NumberInputStepper,
|
||||
NumberIncrementStepper,
|
||||
NumberDecrementStepper,
|
||||
Collapse,
|
||||
useDisclosure,
|
||||
type UseDisclosureReturn,
|
||||
} from "@chakra-ui/react";
|
||||
import { BiExport } from "react-icons/bi";
|
||||
|
||||
import { useHandledAsyncCallback } from "~/utils/hooks";
|
||||
import { api } from "~/utils/api";
|
||||
import { useAppStore } from "~/state/store";
|
||||
import ActionButton from "./ActionButton";
|
||||
import InputDropdown from "../InputDropdown";
|
||||
import { FiChevronUp, FiChevronDown } from "react-icons/fi";
|
||||
|
||||
const SUPPORTED_EXPORT_FORMATS = ["alpaca-finetune", "openai-fine-tune", "unformatted"];
|
||||
|
||||
const ExportButton = () => {
|
||||
const selectedLogIds = useAppStore((s) => s.selectedLogs.selectedLogIds);
|
||||
|
||||
const disclosure = useDisclosure();
|
||||
|
||||
return (
|
||||
<>
|
||||
<ActionButton
|
||||
onClick={disclosure.onOpen}
|
||||
label="Export"
|
||||
icon={BiExport}
|
||||
isDisabled={selectedLogIds.size === 0}
|
||||
/>
|
||||
<ExportLogsModal disclosure={disclosure} />
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
export default ExportButton;
|
||||
|
||||
const ExportLogsModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) => {
|
||||
const selectedProjectId = useAppStore((s) => s.selectedProjectId);
|
||||
const selectedLogIds = useAppStore((s) => s.selectedLogs.selectedLogIds);
|
||||
const clearSelectedLogIds = useAppStore((s) => s.selectedLogs.clearSelectedLogIds);
|
||||
|
||||
const [selectedExportFormat, setSelectedExportFormat] = useState(SUPPORTED_EXPORT_FORMATS[0]);
|
||||
const [testingSplit, setTestingSplit] = useState(10);
|
||||
const [removeDuplicates, setRemoveDuplicates] = useState(true);
|
||||
const [showAdvancedOptions, setShowAdvancedOptions] = useState(false);
|
||||
|
||||
useEffect(() => {
|
||||
if (disclosure.isOpen) {
|
||||
setSelectedExportFormat(SUPPORTED_EXPORT_FORMATS[0]);
|
||||
setTestingSplit(10);
|
||||
setRemoveDuplicates(true);
|
||||
}
|
||||
}, [disclosure.isOpen]);
|
||||
|
||||
const exportLogsMutation = api.loggedCalls.export.useMutation();
|
||||
|
||||
const [exportLogs, exportInProgress] = useHandledAsyncCallback(async () => {
|
||||
if (!selectedProjectId || !selectedLogIds.size || !testingSplit || !selectedExportFormat)
|
||||
return;
|
||||
const response = await exportLogsMutation.mutateAsync({
|
||||
projectId: selectedProjectId,
|
||||
selectedLogIds: Array.from(selectedLogIds),
|
||||
testingSplit,
|
||||
selectedExportFormat,
|
||||
removeDuplicates,
|
||||
});
|
||||
|
||||
const dataUrl = `data:application/pdf;base64,${response}`;
|
||||
const blob = await fetch(dataUrl).then((res) => res.blob());
|
||||
const url = URL.createObjectURL(blob);
|
||||
const a = document.createElement("a");
|
||||
|
||||
a.href = url;
|
||||
a.download = `data.zip`;
|
||||
document.body.appendChild(a);
|
||||
a.click();
|
||||
document.body.removeChild(a);
|
||||
|
||||
disclosure.onClose();
|
||||
clearSelectedLogIds();
|
||||
}, [
|
||||
exportLogsMutation,
|
||||
selectedProjectId,
|
||||
selectedLogIds,
|
||||
testingSplit,
|
||||
selectedExportFormat,
|
||||
removeDuplicates,
|
||||
]);
|
||||
|
||||
return (
|
||||
<Modal size={{ base: "xl", md: "2xl" }} {...disclosure}>
|
||||
<ModalOverlay />
|
||||
<ModalContent w={1200}>
|
||||
<ModalHeader>
|
||||
<HStack>
|
||||
<Icon as={BiExport} />
|
||||
<Text>Export Logs</Text>
|
||||
</HStack>
|
||||
</ModalHeader>
|
||||
<ModalCloseButton />
|
||||
<ModalBody maxW="unset">
|
||||
<VStack w="full" spacing={8} pt={4} alignItems="flex-start">
|
||||
<Text>
|
||||
We'll export the <b>{selectedLogIds.size}</b> logs you have selected in the format of
|
||||
your choice.
|
||||
</Text>
|
||||
<VStack alignItems="flex-start" spacing={4}>
|
||||
<HStack spacing={2}>
|
||||
<Text fontWeight="bold" w={48}>
|
||||
Format:
|
||||
</Text>
|
||||
<InputDropdown
|
||||
options={SUPPORTED_EXPORT_FORMATS}
|
||||
selectedOption={selectedExportFormat}
|
||||
onSelect={(option) => setSelectedExportFormat(option)}
|
||||
inputGroupProps={{ w: 48 }}
|
||||
/>
|
||||
</HStack>
|
||||
<HStack spacing={2}>
|
||||
<Text fontWeight="bold" w={48}>
|
||||
Testing Split:
|
||||
</Text>
|
||||
<HStack>
|
||||
<NumberInput
|
||||
defaultValue={10}
|
||||
onChange={(_, num) => setTestingSplit(num)}
|
||||
min={1}
|
||||
max={100}
|
||||
w={48}
|
||||
>
|
||||
<NumberInputField />
|
||||
<NumberInputStepper>
|
||||
<NumberIncrementStepper />
|
||||
<NumberDecrementStepper />
|
||||
</NumberInputStepper>
|
||||
</NumberInput>
|
||||
</HStack>
|
||||
</HStack>
|
||||
</VStack>
|
||||
<VStack alignItems="flex-start" spacing={0}>
|
||||
<Button
|
||||
variant="unstyled"
|
||||
color="blue.600"
|
||||
onClick={() => setShowAdvancedOptions(!showAdvancedOptions)}
|
||||
>
|
||||
<HStack>
|
||||
<Text>Advanced Options</Text>
|
||||
<Icon as={showAdvancedOptions ? FiChevronUp : FiChevronDown} />
|
||||
</HStack>
|
||||
</Button>
|
||||
<Collapse in={showAdvancedOptions} unmountOnExit={true}>
|
||||
<VStack align="stretch" pt={4}>
|
||||
<Checkbox
|
||||
colorScheme="blue"
|
||||
isChecked={removeDuplicates}
|
||||
onChange={(e) => setRemoveDuplicates(e.target.checked)}
|
||||
>
|
||||
<Text>Remove duplicates? (recommended)</Text>
|
||||
</Checkbox>
|
||||
</VStack>
|
||||
</Collapse>
|
||||
</VStack>
|
||||
</VStack>
|
||||
</ModalBody>
|
||||
<ModalFooter>
|
||||
<HStack>
|
||||
<Button colorScheme="gray" onClick={disclosure.onClose} minW={24}>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button colorScheme="blue" onClick={exportLogs} isLoading={exportInProgress} minW={24}>
|
||||
Export
|
||||
</Button>
|
||||
</HStack>
|
||||
</ModalFooter>
|
||||
</ModalContent>
|
||||
</Modal>
|
||||
);
|
||||
};
|
||||
@@ -11,6 +11,7 @@ import { FiFilter } from "react-icons/fi";
|
||||
import LogFilters from "~/components/requestLogs/LogFilters/LogFilters";
|
||||
import ColumnVisiblityDropdown from "~/components/requestLogs/ColumnVisiblityDropdown";
|
||||
import FineTuneButton from "~/components/requestLogs/FineTuneButton";
|
||||
import ExportButton from "~/components/requestLogs/ExportButton";
|
||||
|
||||
export default function LoggedCalls() {
|
||||
const selectedLogIds = useAppStore((s) => s.selectedLogs.selectedLogIds);
|
||||
@@ -35,6 +36,7 @@ export default function LoggedCalls() {
|
||||
icon={RiFlaskLine}
|
||||
isDisabled={selectedLogIds.size === 0}
|
||||
/>
|
||||
<ExportButton />
|
||||
<ColumnVisiblityDropdown />
|
||||
<ActionButton
|
||||
onClick={() => {
|
||||
|
||||
@@ -1,11 +1,16 @@
|
||||
import { z } from "zod";
|
||||
import { type Expression, type SqlBool, sql, type RawBuilder } from "kysely";
|
||||
import { jsonArrayFrom } from "kysely/helpers/postgres";
|
||||
import archiver from "archiver";
|
||||
import { WritableStreamBuffer } from "stream-buffers";
|
||||
import { type JsonValue } from "type-fest";
|
||||
import { shuffle } from "lodash-es";
|
||||
|
||||
import { createTRPCRouter, protectedProcedure } from "~/server/api/trpc";
|
||||
import { kysely, prisma } from "~/server/db";
|
||||
import { comparators, defaultFilterableFields } from "~/state/logFiltersSlice";
|
||||
import { requireCanViewProject } from "~/utils/accessControl";
|
||||
import hashObject from "~/server/utils/hashObject";
|
||||
|
||||
// create comparator type based off of comparators
|
||||
const comparatorToSqlExpression = (comparator: (typeof comparators)[number], value: string) => {
|
||||
@@ -180,4 +185,101 @@ export const loggedCallsRouter = createTRPCRouter({
|
||||
|
||||
return tags.map((tag) => tag.name);
|
||||
}),
|
||||
export: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
projectId: z.string(),
|
||||
selectedLogIds: z.string().array(),
|
||||
testingSplit: z.number(),
|
||||
selectedExportFormat: z.string(),
|
||||
removeDuplicates: z.boolean(),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
await requireCanViewProject(input.projectId, ctx);
|
||||
|
||||
// Fetch the real data using Prisma
|
||||
const loggedCallsFromDb = await ctx.prisma.loggedCallModelResponse.findMany({
|
||||
where: {
|
||||
originalLoggedCall: {
|
||||
projectId: input.projectId,
|
||||
id: { in: input.selectedLogIds },
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// Convert the database data into the desired format
|
||||
let formattedLoggedCalls: { input: JsonValue[]; output: JsonValue }[] = loggedCallsFromDb.map(
|
||||
(call) => ({
|
||||
input: (call.reqPayload as unknown as Record<string, unknown>).messages as JsonValue[],
|
||||
output: (call.respPayload as unknown as { choices: { message: unknown }[] }).choices[0]
|
||||
?.message as JsonValue,
|
||||
}),
|
||||
);
|
||||
|
||||
if (input.removeDuplicates) {
|
||||
const deduplicatedLoggedCalls = [];
|
||||
const loggedCallHashSet = new Set<string>();
|
||||
for (const loggedCall of formattedLoggedCalls) {
|
||||
const loggedCallHash = hashObject(loggedCall);
|
||||
if (!loggedCallHashSet.has(loggedCallHash)) {
|
||||
loggedCallHashSet.add(loggedCallHash);
|
||||
deduplicatedLoggedCalls.push(loggedCall);
|
||||
}
|
||||
}
|
||||
formattedLoggedCalls = deduplicatedLoggedCalls;
|
||||
}
|
||||
|
||||
// Remove duplicate messages from input
|
||||
const inputMessageHashMap = new Map<string, number>();
|
||||
for (const loggedCall of formattedLoggedCalls) {
|
||||
for (const message of loggedCall.input) {
|
||||
const hash = hashObject(message);
|
||||
if (inputMessageHashMap.has(hash)) {
|
||||
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
||||
inputMessageHashMap.set(hash, inputMessageHashMap.get(hash)! + 1);
|
||||
} else {
|
||||
inputMessageHashMap.set(hash, 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (const loggedCall of formattedLoggedCalls) {
|
||||
loggedCall.input = loggedCall.input.filter((message) => {
|
||||
const hash = hashObject(message);
|
||||
// If the same message appears in a single input multiple times, there is some danger of
|
||||
// it being removed from all logged calls. This is enough of an edge case that we don't
|
||||
// need to worry about it for now.
|
||||
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
||||
return inputMessageHashMap.get(hash)! < formattedLoggedCalls.length;
|
||||
});
|
||||
}
|
||||
|
||||
// Stringify inputs and outputs
|
||||
const stringifiedLoggedCalls = shuffle(formattedLoggedCalls).map((loggedCall) => ({
|
||||
input: JSON.stringify(loggedCall.input),
|
||||
output: JSON.stringify(loggedCall.output),
|
||||
}));
|
||||
|
||||
const splitIndex = Math.floor((stringifiedLoggedCalls.length * input.testingSplit) / 100);
|
||||
|
||||
const testingData = stringifiedLoggedCalls.slice(0, splitIndex);
|
||||
const trainingData = stringifiedLoggedCalls.slice(splitIndex);
|
||||
|
||||
// Convert arrays to JSONL format
|
||||
const trainingDataJSONL = trainingData.map((item) => JSON.stringify(item)).join("\n");
|
||||
const testingDataJSONL = testingData.map((item) => JSON.stringify(item)).join("\n");
|
||||
|
||||
const output = new WritableStreamBuffer();
|
||||
const archive = archiver("zip");
|
||||
|
||||
archive.pipe(output);
|
||||
archive.append(trainingDataJSONL, { name: "train.jsonl" });
|
||||
archive.append(testingDataJSONL, { name: "test.jsonl" });
|
||||
await archive.finalize();
|
||||
|
||||
// Convert buffer to base64
|
||||
const base64 = output.getContents().toString("base64");
|
||||
|
||||
return base64;
|
||||
}),
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user