Compare commits

1 commit · publish-py ... arcticfly-

| Author | SHA1 | Date |
|---|---|---|
|  | 092b48552d |  |
.github/ISSUE_TEMPLATE/sweep-fast-template.yml (vendored, 14 changes)
@@ -1,14 +0,0 @@
name: Sweep Fast Issue
title: 'Sweep (fast): '
description: For few-line fixes to be handled by Sweep, an AI-powered junior developer. Sweep will use GPT-3.5 to quickly create a PR for very small changes.
labels: sweep
body:
  - type: textarea
    id: description
    attributes:
      label: Details
      description: Tell Sweep where and what to edit and provide enough context for a new developer to the codebase
      placeholder: |
        Bugs: The bug might be in ... file. Here are the logs: ...
        Features: the new endpoint should use the ... class from ... file because it contains ... logic.
        Refactors: We are migrating this function to ... version because ...
.github/ISSUE_TEMPLATE/sweep-slow-template.yml (vendored, 14 changes)
@@ -1,14 +0,0 @@
name: Sweep Slow Issue
title: 'Sweep (slow): '
description: For larger bugs, features, refactors, and tests to be handled by Sweep, an AI-powered junior developer. Sweep will perform a deeper search and more self-reviews but will take longer.
labels: sweep
body:
  - type: textarea
    id: description
    attributes:
      label: Details
      description: Tell Sweep where and what to edit and provide enough context for a new developer to the codebase
      placeholder: |
        Bugs: The bug might be in ... file. Here are the logs: ...
        Features: the new endpoint should use the ... class from ... file because it contains ... logic.
        Refactors: We are migrating this function to ... version because ...
.github/ISSUE_TEMPLATE/sweep-template.yml (vendored, 14 changes)
@@ -1,14 +0,0 @@
name: Sweep Issue
title: 'Sweep: '
description: For small bugs, features, refactors, and tests to be handled by Sweep, an AI-powered junior developer.
labels: sweep
body:
  - type: textarea
    id: description
    attributes:
      label: Details
      description: Tell Sweep where and what to edit and provide enough context for a new developer to the codebase
      placeholder: |
        Bugs: The bug might be in ... file. Here are the logs: ...
        Features: the new endpoint should use the ... class from ... file because it contains ... logic.
        Refactors: We are migrating this function to ... version because ...
README.md (15 changes)
@@ -16,7 +16,6 @@
<a href='http://makeapullrequest.com'><img alt='PRs Welcome' src='https://img.shields.io/badge/PRs-welcome-brightgreen.svg?style=flat-square'/></a>
<a href="https://github.com/openpipe/openpipe/graphs/commit-activity"><img alt="GitHub commit activity" src="https://img.shields.io/github/commit-activity/m/openpipe/openpipe?style=flat-square"/></a>
<a href="https://github.com/openpipe/openpipe/issues"><img alt="GitHub closed issues" src="https://img.shields.io/github/issues-closed/openpipe/openpipe?style=flat-square"/></a>
<img src="https://img.shields.io/badge/Y%20Combinator-S23-orange?style=flat-square" alt="Y Combinator S23">
</p>

<p align="center">

@@ -24,21 +23,21 @@
</p>

<br>
Use powerful but expensive LLMs to fine-tune smaller and cheaper models suited to your exact needs. Evaluate model and prompt combinations in the playground. Query your past requests and export optimized training data. Try it out at https://app.openpipe.ai or <a href="#running-locally">run it locally</a>.
Use powerful but expensive LLMs to fine-tune smaller and cheaper models suited to your exact needs. Evaluate model and prompt combinations in the playground. Query your past requests and export optimized training data.
<br>

## Features
## 🪛 Features

* <b>Fine-Tune</b>
  * Easy integration with OpenPipe's SDK in both Python and JS.
  * Swiftly query logs using intuitive built-in filters.
  * Export data in multiple training formats, including Alpaca and ChatGPT, with deduplication.

* <b>Experiment</b>
  * Bulk-test wide-reaching scenarios using code templating.
  * Seamlessly translate prompts across different model APIs.
  * Tap into autogenerated scenarios for fresh test perspectives.

* <b>Fine-Tune (Beta)</b>
  * Easy integration with OpenPipe's SDK in both Python and JS.
  * Swiftly query logs using intuitive built-in filters.
  * Export data in multiple training formats, including Alpaca and ChatGPT, with deduplication.

<img src="https://github.com/openpipe/openpipe/assets/41524992/eaa8b92d-4536-4f63-bbef-4b0b1a60f6b5" alt="fine-tune demo">
@@ -79,8 +79,7 @@
"nextjs-routes": "^2.0.1",
"nodemailer": "^6.9.4",
"openai": "4.0.0-beta.7",
"openpipe": "^0.3.0",
"openpipe-dev": "workspace:^",
"openpipe": "workspace:*",
"pg": "^8.11.2",
"pluralize": "^8.0.0",
"posthog-js": "^1.75.3",
@@ -1,12 +0,0 @@
import { prisma } from "~/server/db";

// delete most recent fineTune
const mostRecentFineTune = await prisma.fineTune.findFirst({
  orderBy: { createdAt: "desc" },
});

if (mostRecentFineTune) {
  await prisma.fineTune.delete({
    where: { id: mostRecentFineTune.id },
  });
}
@@ -80,7 +80,7 @@ const MODEL_RESPONSE_TEMPLATES: {
    },
    respStatus: 200,
    respPayload: {
      id: "chatcmpl-7",
      id: "chatcmpl-7lNspqePJWVyXwXebupxb1eMozo6Q",
      model: "gpt-3.5-turbo-0613",
      usage: {
        total_tokens: 241,
@@ -108,7 +108,7 @@ const MODEL_RESPONSE_TEMPLATES: {
    inputTokens: 236,
    outputTokens: 5,
    finishReason: "stop",
    tags: [{ name: "prompt_id", value: "define_func" }],
    tags: [],
  },
  {
    reqPayload: {
@@ -167,7 +167,7 @@ const MODEL_RESPONSE_TEMPLATES: {
    },
    respStatus: 200,
    respPayload: {
      id: "chatcmpl-7",
      id: "chatcmpl-7lNifmc5AncyAvleZRDBhAcLFYBIT",
      model: "gpt-3.5-turbo-0613",
      usage: {
        total_tokens: 227,
@@ -210,7 +210,7 @@ const MODEL_RESPONSE_TEMPLATES: {
    },
    respStatus: 200,
    respPayload: {
      id: "chatcmpl-7",
      id: "chatcmpl-7lNh1TtrsJVgz3Nj70bKkZZk7xPi7",
      model: "gpt-3.5-turbo-0613",
      usage: {
        total_tokens: 21,
@@ -234,7 +234,7 @@ const MODEL_RESPONSE_TEMPLATES: {
    inputTokens: 14,
    outputTokens: 7,
    finishReason: "stop",
    tags: [{ name: "prompt_id", value: "translate_text" }],
    tags: [{ name: "prompt_id", value: "id2" }],
  },
  {
    reqPayload: {
@@ -281,7 +281,7 @@ const MODEL_RESPONSE_TEMPLATES: {
    },
    respStatus: 200,
    respPayload: {
      id: "chatcmpl-7",
      id: "chatcmpl-7lQS3MktOT8BTgNEytl9dkyssCQqL",
      model: "gpt-4-0613",
      usage: {
        total_tokens: 2910,
@@ -311,7 +311,7 @@ const MODEL_RESPONSE_TEMPLATES: {
    outputTokens: 108,
    finishReason: "stop",
    tags: [
      { name: "prompt_id", value: "chatcmpl-7" },
      { name: "prompt_id", value: "chatcmpl-7lQS3MktOT8BTgNEytl9dkyssCQqL" },
      { name: "some_other_tag", value: "some_other_value" },
    ],
  },
@@ -339,7 +339,7 @@ const loggedCallsToCreate: Prisma.LoggedCallCreateManyInput[] = [];
const loggedCallModelResponsesToCreate: Prisma.LoggedCallModelResponseCreateManyInput[] = [];
const loggedCallsToUpdate: Prisma.LoggedCallUpdateArgs[] = [];
const loggedCallTagsToCreate: Prisma.LoggedCallTagCreateManyInput[] = [];
for (let i = 0; i < 11437; i++) {
for (let i = 0; i < 1437; i++) {
  const loggedCallId = uuidv4();
  const loggedCallModelResponseId = uuidv4();
  const template =
@@ -1,4 +1,3 @@
import { useState, useMemo, useCallback } from "react";
import {
  Button,
  HStack,
@@ -15,18 +14,16 @@ import {
  VStack,
} from "@chakra-ui/react";
import { type PromptVariant } from "@prisma/client";
import { isString } from "lodash-es";
import { isObject, isString } from "lodash-es";
import { useState } from "react";
import { RiExchangeFundsFill } from "react-icons/ri";

import { type ProviderModel } from "~/modelProviders/types";
import { api } from "~/utils/api";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
import { useExperiment, useHandledAsyncCallback, useVisibleScenarioIds } from "~/utils/hooks";
import { lookupModel, modelLabel } from "~/utils/utils";
import CompareFunctions from "../RefinePromptModal/CompareFunctions";
import { ModelSearch } from "./ModelSearch";
import { ModelStatsCard } from "./ModelStatsCard";
import { maybeReportError } from "~/utils/errorHandling/maybeReportError";
import { useAppStore } from "~/state/store";

export const ChangeModelModal = ({
  variant,
@@ -35,43 +32,48 @@ export const ChangeModelModal = ({
  variant: PromptVariant;
  onClose: () => void;
}) => {
  const editorOptionsMap = useAppStore((s) => s.sharedVariantEditor.editorOptionsMap);
  const originalPromptFn = useMemo(
    () => editorOptionsMap[variant.uiId]?.getContent() || "",
    [editorOptionsMap, variant.uiId],
  );

  const originalModel = lookupModel(variant.modelProvider, variant.model);
  const [selectedModel, setSelectedModel] = useState({
    provider: variant.modelProvider,
    model: variant.model,
  } as ProviderModel);
  const [convertedModel, setConvertedModel] = useState<ProviderModel | undefined>();
  const [modifiedPromptFn, setModifiedPromptFn] = useState<string>();
  const visibleScenarios = useVisibleScenarioIds();

  const utils = api.useContext();

  const experiment = useExperiment();

  const { mutateAsync: getModifiedPromptMutateAsync } =
  const { mutateAsync: getModifiedPromptMutateAsync, data: modifiedPromptFn } =
    api.promptVariants.getModifiedPromptFn.useMutation();

  const [getModifiedPromptFn, modificationInProgress] = useHandledAsyncCallback(async () => {
    if (!experiment) return;

    const resp = await getModifiedPromptMutateAsync({
    await getModifiedPromptMutateAsync({
      id: variant.id,
      originalPromptFn,
      newModel: selectedModel,
    });
    if (maybeReportError(resp)) return;
    setModifiedPromptFn(resp.payload);
    setConvertedModel(selectedModel);
  }, [getModifiedPromptMutateAsync, onClose, experiment, variant, selectedModel]);

  const replaceVariant = useCallback(() => {
    if (!modifiedPromptFn) return;
    editorOptionsMap[variant.uiId]?.setContent(modifiedPromptFn);
  const replaceVariantMutation = api.promptVariants.replaceVariant.useMutation();

  const [replaceVariant, replacementInProgress] = useHandledAsyncCallback(async () => {
    if (
      !variant.experimentId ||
      !modifiedPromptFn ||
      (isObject(modifiedPromptFn) && "status" in modifiedPromptFn)
    )
      return;
    await replaceVariantMutation.mutateAsync({
      id: variant.id,
      promptConstructor: modifiedPromptFn,
      streamScenarios: visibleScenarios,
    });
    await utils.promptVariants.list.invalidate();
    onClose();
  }, [variant.uiId, editorOptionsMap, onClose, modifiedPromptFn]);
  }, [replaceVariantMutation, variant, onClose, modifiedPromptFn]);

  const originalLabel = modelLabel(variant.modelProvider, variant.model);
  const selectedLabel = modelLabel(selectedModel.provider, selectedModel.model);
@@ -128,9 +130,9 @@ export const ChangeModelModal = ({
  colorScheme="blue"
  onClick={replaceVariant}
  minW={24}
  isDisabled={!convertedModel || modificationInProgress}
  isDisabled={!convertedModel || modificationInProgress || replacementInProgress}
>
  Accept
  {replacementInProgress ? <Spinner boxSize={4} /> : <Text>Accept</Text>}
</Button>
</HStack>
</ModalFooter>
@@ -1,41 +1,74 @@
import { Button, Icon, useDisclosure, Text } from "@chakra-ui/react";
import { useRouter } from "next/router";
import { BsTrash } from "react-icons/bs";
import {
  Button,
  Icon,
  AlertDialog,
  AlertDialogBody,
  AlertDialogFooter,
  AlertDialogHeader,
  AlertDialogContent,
  AlertDialogOverlay,
  useDisclosure,
  Text,
} from "@chakra-ui/react";

import { useRouter } from "next/router";
import { useRef } from "react";
import { BsTrash } from "react-icons/bs";
import { useAppStore } from "~/state/store";
import { api } from "~/utils/api";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
import DeleteExperimentDialog from "../experiments/DeleteExperimentDialog";

export const DeleteButton = () => {
  const experiment = useExperiment();
  const mutation = api.experiments.delete.useMutation();
  const utils = api.useContext();
  const router = useRouter();

  const disclosure = useDisclosure();

  const closeDrawer = useAppStore((s) => s.closeDrawer);
  const [onDelete] = useHandledAsyncCallback(async () => {

  const { isOpen, onOpen, onClose } = useDisclosure();
  const cancelRef = useRef<HTMLButtonElement>(null);

  const [onDeleteConfirm] = useHandledAsyncCallback(async () => {
    if (!experiment.data?.id) return;
    await mutation.mutateAsync({ id: experiment.data.id });
    await utils.experiments.list.invalidate();
    await router.push({ pathname: "/experiments" });
    closeDrawer();
  }, [router, closeDrawer]);

    onClose();
  }, [mutation, experiment.data?.id, router]);

  return (
    <>
      <Button
        size="sm"
        variant="ghost"
        colorScheme="red"
        fontWeight="normal"
        onClick={disclosure.onOpen}
      >
      <Button size="sm" variant="ghost" colorScheme="red" fontWeight="normal" onClick={onOpen}>
        <Icon as={BsTrash} boxSize={4} />
        <Text ml={2}>Delete Experiment</Text>
      </Button>

      <DeleteExperimentDialog
        experimentId={experiment.data?.id}
        onDelete={onDelete}
        disclosure={disclosure}
      />
      <AlertDialog isOpen={isOpen} leastDestructiveRef={cancelRef} onClose={onClose}>
        <AlertDialogOverlay>
          <AlertDialogContent>
            <AlertDialogHeader fontSize="lg" fontWeight="bold">
              Delete Experiment
            </AlertDialogHeader>

            <AlertDialogBody>
              If you delete this experiment all the associated prompts and scenarios will be deleted
              as well. Are you sure?
            </AlertDialogBody>

            <AlertDialogFooter>
              <Button ref={cancelRef} onClick={onClose}>
                Cancel
              </Button>
              <Button colorScheme="red" onClick={onDeleteConfirm} ml={3}>
                Delete
              </Button>
            </AlertDialogFooter>
          </AlertDialogContent>
        </AlertDialogOverlay>
      </AlertDialog>
    </>
  );
};
@@ -147,10 +147,9 @@ export default function OutputCell({
<ResponseLog
  time={response.receivedAt}
  title="Response received from API"
  message={[
    response.statusCode ? `Status: ${response.statusCode}\n` : "",
    response.errorMessage ?? "",
  ].join("")}
  message={`statusCode: ${response.statusCode ?? ""}\n ${
    response.errorMessage ?? ""
  }`}
/>
)}
</Fragment>
@@ -10,7 +10,7 @@ import {
} from "@chakra-ui/react";
import { useCallback, useEffect, useRef, useState } from "react";
import { FiMaximize, FiMinimize } from "react-icons/fi";
import { type CreatedEditor, editorBackground } from "~/state/sharedVariantEditor.slice";
import { editorBackground } from "~/state/sharedVariantEditor.slice";
import { useAppStore } from "~/state/store";
import { api } from "~/utils/api";
import {
@@ -24,10 +24,8 @@ import { type PromptVariant } from "./types";
export default function VariantEditor(props: { variant: PromptVariant }) {
  const { canModify } = useExperimentAccess();
  const monaco = useAppStore.use.sharedVariantEditor.monaco();
  const updateOptionsForEditor = useAppStore.use.sharedVariantEditor.updateOptionsForEditor();
  const editorRef = useRef<CreatedEditor | null>(null);
  const editorRef = useRef<ReturnType<NonNullable<typeof monaco>["editor"]["create"]> | null>(null);
  const containerRef = useRef<HTMLDivElement | null>(null);
  const lastSavedFnRef = useRef(props.variant.promptConstructor);
  const [editorId] = useState(() => `editor_${Math.random().toString(36).substring(7)}`);
  const [isChanged, setIsChanged] = useState(false);

@@ -50,18 +48,22 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
  }, [isFullscreen, toggleFullscreen]);

  const lastSavedFn = props.variant.promptConstructor;
  useEffect(() => {
    // Store in ref so that we can access it dynamically
    lastSavedFnRef.current = lastSavedFn;
  }, [lastSavedFn]);

  const modifierKey = useModifierKeyLabel();

  const checkForChanges = useCallback(() => {
    if (!editorRef.current) return;
    const currentFn = editorRef.current.getValue();
    setIsChanged(currentFn.length > 0 && currentFn !== lastSavedFnRef.current);
  }, [editorRef]);
    setIsChanged(currentFn.length > 0 && currentFn !== lastSavedFn);
  }, [lastSavedFn]);

  const matchUpdatedSavedFn = useCallback(() => {
    if (!editorRef.current) return;
    editorRef.current.setValue(lastSavedFn);
    setIsChanged(false);
  }, [lastSavedFn]);

  useEffect(matchUpdatedSavedFn, [matchUpdatedSavedFn, lastSavedFn]);

  const replaceVariant = api.promptVariants.replaceVariant.useMutation();
  const utils = api.useContext();
@@ -134,11 +136,6 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
    readOnly: !canModify,
  });

  updateOptionsForEditor(props.variant.uiId, {
    getContent: () => editorRef.current?.getValue() || "",
    setContent: (content) => editorRef.current?.setValue(content),
  });

  // Workaround because otherwise the commands only work on whatever
  // editor was loaded on the page last.
  // https://github.com/microsoft/monaco-editor/issues/2947#issuecomment-1422265201
@@ -158,7 +155,7 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
    });
  });

  const checkForChangesListener = editorRef.current.onDidChangeModelContent(checkForChanges);
  editorRef.current.onDidChangeModelContent(checkForChanges);

  const resizeObserver = new ResizeObserver(() => {
    editorRef.current?.layout();
@@ -167,7 +164,6 @@ export default function VariantEditor(props: { variant: PromptVariant }) {

  return () => {
    resizeObserver.disconnect();
    checkForChangesListener.dispose();
    editorRef.current?.dispose();
  };
}
@@ -175,7 +171,7 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
  // We intentionally skip the onSave and props.savedConfig dependencies here because
  // we don't want to re-render the editor from scratch
  /* eslint-disable-next-line react-hooks/exhaustive-deps */
}, [monaco, editorId, updateOptionsForEditor]);
}, [monaco, editorId]);

useEffect(() => {
  if (!editorRef.current) return;
@@ -1,4 +1,3 @@
import { useState, useMemo, useCallback } from "react";
import {
  Button,
  Modal,
@@ -10,23 +9,22 @@ import {
  ModalOverlay,
  VStack,
  Text,
  Spinner,
  HStack,
  Icon,
  SimpleGrid,
} from "@chakra-ui/react";
import { BsStars } from "react-icons/bs";
import { api } from "~/utils/api";
import { useHandledAsyncCallback } from "~/utils/hooks";
import { useHandledAsyncCallback, useVisibleScenarioIds } from "~/utils/hooks";
import { type PromptVariant } from "@prisma/client";

import { useState } from "react";
import CompareFunctions from "./CompareFunctions";
import { CustomInstructionsInput } from "../CustomInstructionsInput";
import { RefineAction } from "./RefineAction";
import { isString } from "lodash-es";
import { isObject, isString } from "lodash-es";
import { type RefinementAction, type SupportedProvider } from "~/modelProviders/types";
import frontendModelProviders from "~/modelProviders/frontendModelProviders";
import { useAppStore } from "~/state/store";
import { maybeReportError } from "~/utils/errorHandling/maybeReportError";

export const RefinePromptModal = ({
  variant,
@@ -35,23 +33,19 @@ export const RefinePromptModal = ({
  variant: PromptVariant;
  onClose: () => void;
}) => {
  const editorOptionsMap = useAppStore((s) => s.sharedVariantEditor.editorOptionsMap);
  const originalPromptFn = useMemo(
    () => editorOptionsMap[variant.uiId]?.getContent() || "",
    [editorOptionsMap, variant.uiId],
  );
  const utils = api.useContext();
  const visibleScenarios = useVisibleScenarioIds();

  const refinementActions =
    frontendModelProviders[variant.modelProvider as SupportedProvider].refinementActions || {};

  const { mutateAsync: getModifiedPromptMutateAsync } =
  const { mutateAsync: getModifiedPromptMutateAsync, data: refinedPromptFn } =
    api.promptVariants.getModifiedPromptFn.useMutation();
  const [instructions, setInstructions] = useState<string>("");

  const [activeRefineActionLabel, setActiveRefineActionLabel] = useState<string | undefined>(
    undefined,
  );
  const [refinedPromptFn, setRefinedPromptFn] = useState<string>();

  const [getModifiedPromptFn, modificationInProgress] = useHandledAsyncCallback(
    async (label?: string) => {
@@ -60,22 +54,31 @@ export const RefinePromptModal = ({
        ? (refinementActions[label] as RefinementAction).instructions
        : instructions;
      setActiveRefineActionLabel(label);
      const resp = await getModifiedPromptMutateAsync({
      await getModifiedPromptMutateAsync({
        id: variant.id,
        originalPromptFn,
        instructions: updatedInstructions,
      });
      if (maybeReportError(resp)) return;
      setRefinedPromptFn(resp.payload);
    },
    [getModifiedPromptMutateAsync, onClose, variant, instructions, setActiveRefineActionLabel],
  );

  const replaceVariant = useCallback(() => {
    if (!refinedPromptFn) return;
    editorOptionsMap[variant.uiId]?.setContent(refinedPromptFn);
  const replaceVariantMutation = api.promptVariants.replaceVariant.useMutation();

  const [replaceVariant, replacementInProgress] = useHandledAsyncCallback(async () => {
    if (
      !variant.experimentId ||
      !refinedPromptFn ||
      (isObject(refinedPromptFn) && "status" in refinedPromptFn)
    )
      return;
    await replaceVariantMutation.mutateAsync({
      id: variant.id,
      promptConstructor: refinedPromptFn,
      streamScenarios: visibleScenarios,
    });
    await utils.promptVariants.list.invalidate();
    onClose();
  }, [variant.uiId, editorOptionsMap, onClose, refinedPromptFn]);
  }, [replaceVariantMutation, variant, onClose, refinedPromptFn]);

  return (
    <Modal
@@ -123,7 +126,7 @@ export const RefinePromptModal = ({
  />
</VStack>
<CompareFunctions
  originalFunction={originalPromptFn}
  originalFunction={variant.promptConstructor}
  newFunction={isString(refinedPromptFn) ? refinedPromptFn : undefined}
  maxH="40vh"
/>
@@ -136,9 +139,9 @@ export const RefinePromptModal = ({
  colorScheme="blue"
  onClick={replaceVariant}
  minW={24}
  isDisabled={!refinedPromptFn}
  isDisabled={replacementInProgress || !refinedPromptFn}
>
  Accept
  {replacementInProgress ? <Spinner boxSize={4} /> : <Text>Accept</Text>}
</Button>
</HStack>
</ModalFooter>
app/src/components/StatsCard.tsx (new file, 26 changes)
@@ -0,0 +1,26 @@
import { VStack, HStack, type StackProps, Text, Divider } from "@chakra-ui/react";
import Link, { type LinkProps } from "next/link";

const StatsCard = ({
  title,
  href,
  children,
  ...rest
}: { title: string; href: string } & StackProps & LinkProps) => {
  return (
    <VStack flex={1} borderWidth={1} padding={4} borderRadius={4} borderColor="gray.300" {...rest}>
      <HStack w="full" justifyContent="space-between">
        <Text fontSize="md" fontWeight="bold">
          {title}
        </Text>
        <Link href={href}>
          <Text color="blue">View all</Text>
        </Link>
      </HStack>
      <Divider />
      {children}
    </VStack>
  );
};

export default StatsCard;
@@ -2,12 +2,11 @@ import { Card, CardHeader, Heading, Table, Tbody, HStack, Button, Text } from "@
import { useState } from "react";
import Link from "next/link";
import { useLoggedCalls } from "~/utils/hooks";
import { EmptyTableRow, TableHeader, TableRow } from "../requestLogs/TableRow";
import { TableHeader, TableRow } from "../requestLogs/TableRow";

export default function LoggedCallsTable() {
  const { data: loggedCalls } = useLoggedCalls(false);

  const [expandedRow, setExpandedRow] = useState<string | null>(null);
  const { data: loggedCalls } = useLoggedCalls();

  return (
    <Card width="100%" overflow="hidden">
@@ -24,26 +23,22 @@ export default function LoggedCallsTable() {
      <Table>
        <TableHeader />
        <Tbody>
          {loggedCalls?.calls.length ? (
            loggedCalls?.calls.map((loggedCall) => {
              return (
                <TableRow
                  key={loggedCall.id}
                  loggedCall={loggedCall}
                  isExpanded={loggedCall.id === expandedRow}
                  onToggle={() => {
                    if (loggedCall.id === expandedRow) {
                      setExpandedRow(null);
                    } else {
                      setExpandedRow(loggedCall.id);
                    }
                  }}
                />
              );
            })
          ) : (
            <EmptyTableRow filtersApplied={false} />
          )}
          {loggedCalls?.calls.map((loggedCall) => {
            return (
              <TableRow
                key={loggedCall.id}
                loggedCall={loggedCall}
                isExpanded={loggedCall.id === expandedRow}
                onToggle={() => {
                  if (loggedCall.id === expandedRow) {
                    setExpandedRow(null);
                  } else {
                    setExpandedRow(loggedCall.id);
                  }
                }}
              />
            );
          })}
        </Tbody>
      </Table>
    </Card>
@@ -1,66 +0,0 @@
import { useRef } from "react";
import {
  type UseDisclosureReturn,
  AlertDialog,
  AlertDialogOverlay,
  AlertDialogContent,
  AlertDialogHeader,
  AlertDialogBody,
  AlertDialogFooter,
  Button,
} from "@chakra-ui/react";
import { api } from "~/utils/api";

import { useHandledAsyncCallback } from "~/utils/hooks";

const DeleteExperimentDialog = ({
  experimentId,
  onDelete,
  disclosure,
}: {
  experimentId?: string;
  onDelete?: () => void;
  disclosure: UseDisclosureReturn;
}) => {
  const cancelRef = useRef<HTMLButtonElement>(null);

  const mutation = api.experiments.delete.useMutation();
  const utils = api.useContext();

  const [onDeleteConfirm] = useHandledAsyncCallback(async () => {
    if (!experimentId) return;
    await mutation.mutateAsync({ id: experimentId });
    await utils.experiments.list.invalidate();
    onDelete?.();

    disclosure.onClose();
  }, [mutation, experimentId, disclosure.onClose]);

  return (
    <AlertDialog leastDestructiveRef={cancelRef} {...disclosure}>
      <AlertDialogOverlay>
        <AlertDialogContent>
          <AlertDialogHeader fontSize="lg" fontWeight="bold">
            Delete Experiment
          </AlertDialogHeader>

          <AlertDialogBody>
            If you delete this experiment all the associated prompts and scenarios will be deleted
            as well. Are you sure?
          </AlertDialogBody>

          <AlertDialogFooter>
            <Button ref={cancelRef} onClick={disclosure.onClose}>
              Cancel
            </Button>
            <Button colorScheme="red" onClick={onDeleteConfirm} ml={3}>
              Delete
            </Button>
          </AlertDialogFooter>
        </AlertDialogContent>
      </AlertDialogOverlay>
    </AlertDialog>
  );
};

export default DeleteExperimentDialog;
@@ -1,4 +1,3 @@
import { type MouseEvent, useState } from "react";
import {
  HStack,
  Icon,
@@ -9,29 +8,17 @@ import {
  AspectRatio,
  SkeletonText,
  Card,
  useDisclosure,
  Box,
  Menu,
  MenuButton,
  MenuList,
  MenuItem,
  IconButton,
  useToast,
} from "@chakra-ui/react";
import { RiFlaskLine } from "react-icons/ri";
import { formatTimePast } from "~/utils/dayjs";
import Link from "next/link";
import { useRouter } from "next/router";
import { BsPlusSquare, BsThreeDotsVertical, BsLink45Deg, BsTrash } from "react-icons/bs";

import { formatTimePast } from "~/utils/dayjs";
import { type RouterOutputs, api } from "~/utils/api";
import { BsPlusSquare } from "react-icons/bs";
import { RouterOutputs, api } from "~/utils/api";
import { useHandledAsyncCallback } from "~/utils/hooks";
import { useAppStore } from "~/state/store";
import DeleteExperimentDialog from "./DeleteExperimentDialog";

export const ExperimentCard = ({ exp }: { exp: RouterOutputs["experiments"]["list"][0] }) => {
  const [isMenuHovered, setIsMenuHovered] = useState(false);

  return (
    <Card
      w="full"
@@ -40,7 +27,7 @@ export const ExperimentCard = ({ exp }: { exp: RouterOutputs["experiments"]["lis
      p={4}
      bg="white"
      borderRadius={4}
      _hover={{ bg: isMenuHovered ? undefined : "gray.100" }}
      _hover={{ bg: "gray.100" }}
      transition="background 0.2s"
      aspectRatio={1.2}
    >
@@ -51,17 +38,9 @@ export const ExperimentCard = ({ exp }: { exp: RouterOutputs["experiments"]["lis
        href={{ pathname: "/experiments/[experimentSlug]", query: { experimentSlug: exp.slug } }}
        justify="space-between"
      >
        <HStack w="full" justify="space-between" spacing={0}>
          <Box w={6} />
          <HStack color="gray.700" justify="center">
            <Icon as={RiFlaskLine} boxSize={4} />
            <Text fontWeight="bold">{exp.label}</Text>
          </HStack>
          <CardMenu
            experimentId={exp.id}
            experimentSlug={exp.slug}
            setIsMenuHovered={setIsMenuHovered}
          />
        <HStack w="full" color="gray.700" justify="center">
          <Icon as={RiFlaskLine} boxSize={4} />
          <Text fontWeight="bold">{exp.label}</Text>
        </HStack>
        <HStack h="full" spacing={4} flex={1} align="center">
          <CountLabel label="Variants" count={exp.promptVariantCount} />
@@ -78,75 +57,6 @@ export const ExperimentCard = ({ exp }: { exp: RouterOutputs["experiments"]["lis
  );
};

const CardMenu = ({
  experimentId,
  experimentSlug,
  setIsMenuHovered,
}: {
  experimentId: string;
  experimentSlug: string;
  setIsMenuHovered: (isHovered: boolean) => void;
}) => {
  const deleteDisclosure = useDisclosure();
  const menuDisclosure = useDisclosure();
  const toast = useToast();
  const [copyShareLink] = useHandledAsyncCallback(
    async (e: MouseEvent<HTMLButtonElement>) => {
      if (typeof window === "undefined") return;
      e.preventDefault();
      e.stopPropagation();
      const shareLink = `${window.location.origin}/experiments/${experimentSlug}`;
      await navigator.clipboard.writeText(shareLink);
      toast({
        title: "Share link copied to clipboard",
        status: "success",
        duration: 2000,
        isClosable: true,
      });
      menuDisclosure.onClose();
    },
    [toast, menuDisclosure.onClose, experimentSlug],
  );
  return (
    <>
      <Menu isLazy {...menuDisclosure}>
        <MenuButton
          as={IconButton}
          aria-label="Options"
          icon={<BsThreeDotsVertical />}
          variant="ghost"
          onClick={(e) => {
            e.preventDefault();
            e.stopPropagation();
            menuDisclosure.onOpen();
          }}
          onMouseEnter={() => setIsMenuHovered(true)}
          onMouseLeave={() => setIsMenuHovered(false)}
          boxSize={6}
          minW={0}
        />
        <MenuList>
          <MenuItem icon={<Icon as={BsLink45Deg} boxSize={5} />} onClick={copyShareLink}>
            Copy Link
          </MenuItem>
          <MenuItem
            icon={<Icon as={BsTrash} boxSize={5} />}
            onClick={(e) => {
              e.preventDefault();
              e.stopPropagation();
              deleteDisclosure.onOpen();
            }}
            color="red.500"
          >
            Delete
          </MenuItem>
        </MenuList>
      </Menu>
      <DeleteExperimentDialog experimentId={experimentId} disclosure={deleteDisclosure} />
    </>
  );
};

const CountLabel = ({ label, count }: { label: string; count: number }) => {
  return (
    <VStack alignItems="center" flex={1}>
@@ -188,7 +98,9 @@ export const NewExperimentCard = () => {
    >
      <VStack align="center" justify="center" w="full" h="full" p={4} onClick={createExperiment}>
        <Icon as={isLoading ? Spinner : BsPlusSquare} boxSize={8} />
        <Text ml={2}>New Experiment</Text>
        <Text display={{ base: "none", md: "block" }} ml={2}>
          New Experiment
        </Text>
      </VStack>
    </Card>
  );
@@ -13,18 +13,15 @@ import {
} from "@chakra-ui/react";
import Head from "next/head";
import Link from "next/link";
import { useRouter } from "next/router";
import { BsGearFill, BsGithub, BsPersonCircle } from "react-icons/bs";
import { IoStatsChartOutline } from "react-icons/io5";
import { RiHome3Line, RiFlaskLine } from "react-icons/ri";
import { AiOutlineThunderbolt } from "react-icons/ai";
import { FaReadme } from "react-icons/fa";
import { FaRobot } from "react-icons/fa";
import { signIn, useSession } from "next-auth/react";

import ProjectMenu from "./ProjectMenu";
import NavSidebarOption from "./NavSidebarOption";
import IconLink from "./IconLink";
import { BetaModal } from "../BetaModal";
import { BetaModal } from "./BetaModal";
import { useAppStore } from "~/state/store";

const Divider = () => <Box h="1px" bgColor="gray.300" w="full" />;
@@ -76,9 +73,9 @@ const NavSidebar = () => {
<ProjectMenu />
<Divider />

<IconLink icon={RiHome3Line} label="Dashboard" href="/dashboard" />
<IconLink icon={IoStatsChartOutline} label="Request Logs" href="/request-logs" />
<IconLink icon={AiOutlineThunderbolt} label="Fine Tunes" href="/fine-tunes" beta />
<IconLink icon={RiHome3Line} label="Dashboard" href="/dashboard" beta />
<IconLink icon={IoStatsChartOutline} label="Request Logs" href="/request-logs" beta />
<IconLink icon={FaRobot} label="Fine Tunes" href="/fine-tunes" beta />
<IconLink icon={RiFlaskLine} label="Experiments" href="/experiments" />
<VStack w="full" alignItems="flex-start" spacing={0} pt={8}>
  <Text
@@ -114,22 +111,7 @@ const NavSidebar = () => {
  </NavSidebarOption>
)}
</VStack>
<HStack
  w="full"
  px={{ base: 2, md: 4 }}
  py={{ base: 1, md: 2 }}
  as={ChakraLink}
  justifyContent="start"
  href="https://docs.openpipe.ai"
  target="_blank"
  color="gray.500"
  spacing={1}
>
  <Icon as={FaReadme} boxSize={4} mr={2} />
  <Text fontWeight="bold" fontSize="sm">
    Read the Docs
  </Text>
</HStack>

<Divider />
<VStack spacing={0} align="center">
  <ChakraLink
@@ -158,7 +140,6 @@ export default function AppShell({
  requireBeta?: boolean;
}) {
  const [vh, setVh] = useState("100vh"); // Default height to prevent flicker on initial render
  const router = useRouter();

  useEffect(() => {
    const setHeight = () => {
@@ -200,7 +181,7 @@ export default function AppShell({
        {children}
      </Box>
    </Flex>
    <BetaModal isOpen={!!requireBeta && flagsLoaded && !flags.betaAccess} onClose={router.back} />
    {requireBeta && flagsLoaded && !flags.betaAccess && <BetaModal />}
  </>
);
}
@@ -13,17 +13,19 @@ import {
  Link,
} from "@chakra-ui/react";
import { BsStars } from "react-icons/bs";
import { useRouter } from "next/router";
import { useSession } from "next-auth/react";

export const BetaModal = ({ isOpen, onClose }: { isOpen: boolean; onClose: () => void }) => {
export const BetaModal = () => {
  const router = useRouter();
  const session = useSession();

  const email = session.data?.user.email ?? "";

  return (
    <Modal
      isOpen={isOpen}
      onClose={onClose}
      isOpen
      onClose={router.back}
      closeOnOverlayClick={false}
      size={{ base: "xl", md: "2xl" }}
    >
@@ -54,7 +56,7 @@ export const BetaModal = ({ isOpen, onClose }: { isOpen: boolean; onClose: () =>
  >
    Join Waitlist
  </Button>
  <Button colorScheme="blue" onClick={onClose}>
  <Button colorScheme="blue" onClick={router.back}>
    Done
  </Button>
</HStack>
@@ -57,7 +57,6 @@ export default function ProjectMenu() {
  await utils.projects.list.invalidate();
  setSelectedProjectId(newProj.id);
  await router.push({ pathname: "/project/settings" });
  popover.onClose();
}, [createMutation, router]);

const user = useSession().data;
@@ -1,50 +1,29 @@
import { useState } from "react";

import { Button, HStack, type ButtonProps, Icon, Text } from "@chakra-ui/react";
import { type IconType } from "react-icons";
import { useAppStore } from "~/state/store";
import { BetaModal } from "../BetaModal";

const ActionButton = ({
  icon,
  label,
  requireBeta = false,
  onClick,
  ...buttonProps
}: {
  icon: IconType;
  label: string;
  requireBeta?: boolean;
  onClick?: () => void;
} & ButtonProps) => {
  const flags = useAppStore((s) => s.featureFlags.featureFlags);
  const flagsLoaded = useAppStore((s) => s.featureFlags.flagsLoaded);

  const [betaModalOpen, setBetaModalOpen] = useState(false);

  const isBetaBlocked = requireBeta && flagsLoaded && !flags.betaAccess;
}: { icon: IconType; label: string } & ButtonProps) => {
  return (
    <>
      <Button
        colorScheme="blue"
        color="black"
        bgColor="white"
        borderColor="gray.300"
        borderRadius={4}
        variant="outline"
        size="sm"
        fontSize="sm"
        fontWeight="normal"
        onClick={isBetaBlocked ? () => setBetaModalOpen(true) : onClick}
        {...buttonProps}
      >
        <HStack spacing={1}>
          {icon && <Icon as={icon} color={requireBeta ? "orange.400" : undefined} />}
          <Text display={{ base: "none", md: "flex" }}>{label}</Text>
        </HStack>
      </Button>
      <BetaModal isOpen={betaModalOpen} onClose={() => setBetaModalOpen(false)} />
    </>
    <Button
      colorScheme="blue"
      color="black"
      bgColor="white"
      borderColor="gray.300"
      borderRadius={4}
      variant="outline"
      size="sm"
      fontSize="sm"
      fontWeight="normal"
      {...buttonProps}
    >
      <HStack spacing={1}>
        {icon && <Icon as={icon} />}
        <Text display={{ base: "none", md: "flex" }}>{label}</Text>
      </HStack>
    </Button>
  );
};

@@ -47,7 +47,6 @@ const ExportButton = () => {
  label="Export"
  icon={BiExport}
  isDisabled={selectedLogIds.size === 0}
  requireBeta
/>
<ExportLogsModal disclosure={disclosure} />
</>
@@ -16,7 +16,7 @@ import {
  type UseDisclosureReturn,
  Input,
} from "@chakra-ui/react";
import { AiTwotoneThunderbolt } from "react-icons/ai";
import { FaRobot } from "react-icons/fa";
import humanId from "human-id";
import { useRouter } from "next/router";

@@ -39,9 +39,8 @@ const FineTuneButton = () => {
<ActionButton
  onClick={disclosure.onOpen}
  label="Fine Tune"
  icon={AiTwotoneThunderbolt}
  icon={FaRobot}
  isDisabled={selectedLogIds.size === 0}
  requireBeta
/>
<FineTuneModal disclosure={disclosure} />
</>
@@ -91,7 +90,7 @@ const FineTuneModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) => {
<ModalContent w={1200}>
  <ModalHeader>
    <HStack>
      <Icon as={AiTwotoneThunderbolt} />
      <Icon as={FaRobot} />
      <Text>Fine Tune</Text>
    </HStack>
  </ModalHeader>
@@ -1,7 +1,7 @@
import { Card, Table, Tbody } from "@chakra-ui/react";
import { useState } from "react";
import { useLoggedCalls } from "~/utils/hooks";
import { TableHeader, TableRow, EmptyTableRow } from "./TableRow";
import { TableHeader, TableRow } from "./TableRow";

export default function LoggedCallsTable() {
  const [expandedRow, setExpandedRow] = useState<string | null>(null);
@@ -12,27 +12,23 @@ export default function LoggedCallsTable() {
      <Table>
        <TableHeader showOptions />
        <Tbody>
          {loggedCalls?.calls.length ? (
            loggedCalls?.calls?.map((loggedCall) => {
              return (
                <TableRow
                  key={loggedCall.id}
                  loggedCall={loggedCall}
                  isExpanded={loggedCall.id === expandedRow}
                  onToggle={() => {
                    if (loggedCall.id === expandedRow) {
                      setExpandedRow(null);
                    } else {
                      setExpandedRow(loggedCall.id);
                    }
                  }}
                  showOptions
                />
              );
            })
          ) : (
            <EmptyTableRow />
          )}
          {loggedCalls?.calls?.map((loggedCall) => {
            return (
              <TableRow
                key={loggedCall.id}
                loggedCall={loggedCall}
                isExpanded={loggedCall.id === expandedRow}
                onToggle={() => {
                  if (loggedCall.id === expandedRow) {
                    setExpandedRow(null);
                  } else {
                    setExpandedRow(loggedCall.id);
                  }
                }}
                showOptions
              />
            );
          })}
        </Tbody>
      </Table>
    </Card>
@@ -13,7 +13,6 @@ import {
  ButtonGroup,
  Text,
  Checkbox,
  Link as ChakraLink,
} from "@chakra-ui/react";
import Link from "next/link";

@@ -199,41 +198,3 @@ export const TableRow = ({
    </>
  );
};

export const EmptyTableRow = ({ filtersApplied = true }: { filtersApplied?: boolean }) => {
  const visibleColumns = useAppStore((s) => s.columnVisibility.visibleColumns);
  const filters = useAppStore((state) => state.logFilters.filters);
  const { isLoading } = useLoggedCalls();

  if (isLoading) return null;

  if (filters.length && filtersApplied) {
    return (
      <Tr>
        <Td w="full" colSpan={visibleColumns.size + 1}>
          <Text color="gray.500" textAlign="center" w="full" p={4}>
            No matching request logs found. Try removing some filters.
          </Text>
        </Td>
      </Tr>
    );
  }

  return (
    <Tr>
      <Td w="full" colSpan={visibleColumns.size + 1}>
        <Text color="gray.500" textAlign="center" w="full" p={4}>
          This project has no request logs. Learn how to add request logs to your project in our{" "}
          <ChakraLink
            href="https://docs.openpipe.ai/getting-started/quick-start"
            target="_blank"
            color="blue.600"
          >
            Quick Start
          </ChakraLink>{" "}
          guide.
        </Text>
      </Td>
    </Tr>
  );
};
@@ -2,7 +2,7 @@
import { isArray, isString } from "lodash-es";
import { APIError } from "openai";
import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat";
import mergeChunks from "openpipe/openai/mergeChunks";
import mergeChunks from "openpipe/src/openai/mergeChunks";
import { openai } from "~/server/utils/openai";
import { type CompletionResponse } from "../types";

@@ -17,23 +17,10 @@ const modelEndpoints: Record<OpenpipeChatInput["model"], string> = {
  "NousResearch/Nous-Hermes-llama-2-7b": "https://ua1bpc6kv3dgge-8000.proxy.runpod.net/v1",
};

const CUSTOM_MODELS_ENABLED = false;

export async function getCompletion(
  input: OpenpipeChatInput,
  onStream: ((partialOutput: OpenpipeChatOutput) => void) | null,
): Promise<CompletionResponse<OpenpipeChatOutput>> {
  // Temporarily disable these models because of GPU constraints

  if (!CUSTOM_MODELS_ENABLED) {
    return {
      type: "error",
      message:
        "We've disabled this model temporarily because of GPU capacity constraints. Check back later.",
      autoRetry: false,
    };
  }

  const { model, messages, ...rest } = input;

  const templatedPrompt = frontendModelProvider.models[model].templatePrompt?.(messages);
@@ -8,8 +8,8 @@ const replicate = new Replicate({
});

const modelIds: Record<ReplicateLlama2Input["model"], string> = {
  "7b-chat": "d24902e3fa9b698cc208b5e63136c4e26e828659a9f09827ca6ec5bb83014381",
  "13b-chat": "9dff94b1bed5af738655d4a7cbcdcde2bd503aa85c94334fe1f42af7f3dd5ee3",
  "7b-chat": "7b0bfc9aff140d5b75bacbed23e91fd3c34b01a1e958d32132de6e0a19796e2c",
  "13b-chat": "2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52",
  "70b-chat": "2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1",
};
@@ -33,7 +33,7 @@ export default function Dashboard() {
  );

  return (
    <AppShell title="Dashboard" requireAuth>
    <AppShell title="Dashboard" requireAuth requireBeta>
      <VStack px={8} py={8} alignItems="flex-start" spacing={4}>
        <Text fontSize="2xl" fontWeight="bold">
          Dashboard

@@ -19,7 +19,7 @@ export default function LoggedCalls() {
  const [filtersShown, setFiltersShown] = useState(true);

  return (
    <AppShell title="Request Logs" requireAuth>
    <AppShell title="Request Logs" requireAuth requireBeta>
      <Box h="100vh" overflowY="scroll">
        <VStack px={8} py={8} alignItems="flex-start" spacing={4} w="full">
          <Text fontSize="2xl" fontWeight="bold">
@@ -35,7 +35,6 @@ export default function LoggedCalls() {
  label="Experiment"
  icon={RiFlaskLine}
  isDisabled={selectedLogIds.size === 0}
  requireBeta
/>
<ExportButton />
<ColumnVisiblityDropdown />
@@ -196,10 +196,7 @@ export const promptVariantsRouter = createTRPCRouter({
  ? `${originalVariant?.label} Copy`
  : `Prompt Variant ${largestSortIndex + 2}`;

const newConstructFn = await deriveNewConstructFn(
  originalVariant,
  originalVariant?.promptConstructor,
);
const newConstructFn = await deriveNewConstructFn(originalVariant);

const createNewVariantAction = prisma.promptVariant.create({
  data: {
@@ -301,7 +298,6 @@ export const promptVariantsRouter = createTRPCRouter({
.input(
  z.object({
    id: z.string(),
    originalPromptFn: z.string(),
    instructions: z.string().optional(),
    newModel: z
      .object({
@@ -319,21 +315,22 @@ export const promptVariantsRouter = createTRPCRouter({
});
await requireCanModifyExperiment(existing.experimentId, ctx);

const constructedPrompt = await parsePromptConstructor(existing.promptConstructor);

if ("error" in constructedPrompt) {
  return error(constructedPrompt.error);
}

const model = input.newModel
  ? modelProviders[input.newModel.provider].models[input.newModel.model]
  : undefined;

const promptConstructionFn = await deriveNewConstructFn(
  existing,
  input.originalPromptFn,
  model,
  input.instructions,
);
const promptConstructionFn = await deriveNewConstructFn(existing, model, input.instructions);

// TODO: Validate promptConstructionFn
// TODO: Record in some sort of history

return success(promptConstructionFn);
return promptConstructionFn;
}),

replaceVariant: protectedProcedure
@@ -12,37 +12,30 @@ const isolate = new ivm.Isolate({ memoryLimit: 128 });

export async function deriveNewConstructFn(
  originalVariant: PromptVariant | null,
  originalPromptFn?: string,
  newModel?: Model,
  instructions?: string,
) {
  if (originalPromptFn && !newModel && !instructions) {
    return originalPromptFn;
  if (originalVariant && !newModel && !instructions) {
    return originalVariant.promptConstructor;
  }
  if (originalVariant && originalPromptFn && (newModel || instructions)) {
    return await requestUpdatedPromptFunction(
      originalVariant,
      originalPromptFn,
      newModel,
      instructions,
    );
  if (originalVariant && (newModel || instructions)) {
    return await requestUpdatedPromptFunction(originalVariant, newModel, instructions);
  }
  return dedent`
    definePrompt("openai/ChatCompletion", {
      model: "gpt-3.5-turbo-0613",
      messages: [
        {
          role: "system",
          content: \`Hello, world!\`,
        },
      ],
    });`;
    prompt = {
      model: "gpt-3.5-turbo",
      messages: [
        {
          role: "system",
          content: "Return 'Hello, world!'",
        }
      ]
    }`;
}

const NUM_RETRIES = 5;
const requestUpdatedPromptFunction = async (
  originalVariant: PromptVariant,
  originalPromptFn: string,
  newModel?: Model,
  instructions?: string,
) => {
@@ -62,7 +55,7 @@ const requestUpdatedPromptFunction = async (
  },
  {
    role: "user",
    content: `This is the current prompt constructor function:\n---\n${originalPromptFn}`,
    content: `This is the current prompt constructor function:\n---\n${originalVariant.promptConstructor}`,
  },
];
if (newModel) {
@@ -1,6 +1,6 @@
import fs from "fs";
import path from "path";
import OpenAI, { type ClientOptions } from "openpipe/openai";
import OpenAI, { type ClientOptions } from "openpipe/src/openai";

import { env } from "~/env.mjs";
@@ -1,26 +1,16 @@
import loader, { type Monaco } from "@monaco-editor/loader";

import { type RouterOutputs } from "~/utils/api";
import { type SliceCreator } from "./store";
import loader from "@monaco-editor/loader";
import formatPromptConstructor from "~/promptConstructor/format";

export const editorBackground = "#fafafa";

export type CreatedEditor = ReturnType<Monaco["editor"]["create"]>;

type EditorOptions = {
  getContent: () => string;
  setContent: (content: string) => void;
};

export type SharedVariantEditorSlice = {
  monaco: null | Monaco;
  monaco: null | ReturnType<typeof loader.__getMonacoInstance>;
  loadMonaco: () => Promise<void>;
  scenarioVars: RouterOutputs["scenarioVars"]["list"];
  updateScenariosModel: () => void;
  setScenarioVars: (scenarioVars: RouterOutputs["scenarioVars"]["list"]) => void;
  editorOptionsMap: Record<string, EditorOptions>;
  updateOptionsForEditor: (uiId: string, { getContent, setContent }: EditorOptions) => void;
};

export const createVariantEditorSlice: SliceCreator<SharedVariantEditorSlice> = (set, get) => ({
@@ -103,10 +93,4 @@ export const createVariantEditorSlice: SliceCreator<SharedVariantEditorSlice> =
      );
    }
  },
  editorOptionsMap: {},
  updateOptionsForEditor: (uiId, options) => {
    set((state) => {
      state.sharedVariantEditor.editorOptionsMap[uiId] = options;
    });
  },
});
@@ -148,13 +148,13 @@ export const useScenarioVars = () => {
  );
};

export const useLoggedCalls = (applyFilters = true) => {
export const useLoggedCalls = () => {
  const selectedProjectId = useAppStore((state) => state.selectedProjectId);
  const { page, pageSize } = usePageParams();
  const filters = useAppStore((state) => state.logFilters.filters);

  const { data, isLoading, ...rest } = api.loggedCalls.list.useQuery(
    { projectId: selectedProjectId ?? "", page, pageSize, filters: applyFilters ? filters : [] },
    { projectId: selectedProjectId ?? "", page, pageSize, filters },
    { enabled: !!selectedProjectId },
  );
@@ -3,7 +3,6 @@
This client allows you automatically report your OpenAI calls to [OpenPipe](https://openpipe.ai/). OpenPipe

## Installation

`pip install openpipe`

## Usage
@@ -16,7 +15,7 @@ This client allows you automatically report your OpenAI calls to [OpenPipe](http
from openpipe import openai, configure_openpipe
import os

# Set the OpenPipe API key you got in step (2) above.
# Set the OpenPipe API key you got in step (3) above.
# If you have the `OPENPIPE_API_KEY` environment variable set we'll read from it by default.
configure_openpipe(api_key=os.getenv("OPENPIPE_API_KEY"))

@@ -24,7 +23,7 @@ configure_openpipe(api_key=os.getenv("OPENPIPE_API_KEY"))
openai.api_key = os.getenv("OPENAI_API_KEY")
```

You can now use your new OpenAI client, which functions identically to the generic OpenAI client while also reporting calls to your OpenPipe instance.
You can use the OpenPipe client for normal

## Special Features

@@ -38,4 +37,4 @@ completion = openai.ChatCompletion.create(
  messages=[{"role": "system", "content": "count to 10"}],
  openpipe={"tags": {"prompt_id": "counting"}},
)
```
```
@@ -6,9 +6,11 @@ from openpipe.api_client.client import AuthenticatedClient
from openpipe.api_client.models.report_json_body_tags import (
    ReportJsonBodyTags,
)
import toml
import time
import os
import pkg_resources

version = toml.load("pyproject.toml")["tool"]["poetry"]["version"]

configured_client = AuthenticatedClient(
    base_url="https://app.openpipe.ai/api/v1", token=""
@@ -21,7 +23,7 @@ if os.environ.get("OPENPIPE_API_KEY"):
def _get_tags(openpipe_options):
    tags = openpipe_options.get("tags") or {}
    tags["$sdk"] = "python"
    tags["$sdk.version"] = pkg_resources.get_distribution('openpipe').version
    tags["$sdk.version"] = version

    return ReportJsonBodyTags.from_dict(tags)
23 client-libs/python/poetry.lock (generated)
@@ -1056,7 +1056,6 @@ files = [
|
||||
{file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"},
|
||||
{file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"},
|
||||
{file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"},
|
||||
{file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"},
|
||||
{file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"},
|
||||
{file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"},
|
||||
{file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"},
|
||||
@@ -1064,15 +1063,8 @@ files = [
|
||||
{file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"},
|
||||
{file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"},
|
||||
{file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"},
|
||||
{file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"},
|
||||
{file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"},
|
||||
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"},
|
||||
{file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"},
|
||||
{file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"},
|
||||
{file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"},
|
||||
@@ -1089,7 +1081,6 @@ files = [
|
||||
{file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"},
|
||||
{file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"},
|
||||
{file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"},
|
||||
{file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"},
|
||||
{file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"},
|
||||
{file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"},
|
||||
{file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"},
|
||||
@@ -1097,7 +1088,6 @@ files = [
|
||||
{file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"},
|
||||
{file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"},
|
||||
{file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"},
|
||||
{file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"},
|
||||
{file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"},
|
||||
{file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"},
|
||||
{file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"},
|
||||
@@ -1157,6 +1147,17 @@ files = [
|
||||
{file = "sniffio-1.3.0.tar.gz", hash = "sha256:e60305c5e5d314f5389259b7f22aaa33d8f7dee49763119234af3755c55b9101"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "toml"
|
||||
version = "0.10.2"
|
||||
description = "Python Library for Tom's Obvious, Minimal Language"
|
||||
optional = false
|
||||
python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*"
|
||||
files = [
|
||||
{file = "toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"},
|
||||
{file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tomli"
|
||||
version = "2.0.1"
|
||||
@@ -1366,4 +1367,4 @@ multidict = ">=4.0"
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.9"
|
||||
content-hash = "f50c3ee43ebb9510bf42b9a16d8d6a92d561bec40e8f3c11fb2614e92a5b756f"
|
||||
content-hash = "e93c2ecac1b81a4fc1f9ad3dcedf03b1126cc6815e084ae233da7d3ece313ade"
|
||||
|
||||
@@ -1,11 +0,0 @@
#!/bin/bash
set -e

# Check if PYPI_OPENPIPE_TOKEN is set
if [[ -z "${PYPI_OPENPIPE_TOKEN}" ]]; then
  echo "Error: PYPI_OPENPIPE_TOKEN is not set."
  exit 1
fi

# If the token is set, proceed with publishing
poetry publish --build --username=__token__ --password=$PYPI_OPENPIPE_TOKEN
@@ -1,6 +1,6 @@
[tool.poetry]
name = "openpipe"
version = "3.1.2"
version = "3.0.1"
description = "Python client library for the OpenPipe service"
authors = ["Kyle Corbitt <kyle@openpipe.ai>"]
license = "Apache-2.0"
@@ -14,6 +14,7 @@ openai = "^0.27.8"
httpx = "^0.24.1"
attrs = "^23.1.0"
python-dateutil = "^2.8.2"
toml = "^0.10.2"

[tool.poetry.dev-dependencies]
@@ -1,70 +0,0 @@
# OpenPipe Node API Library

[openpipe on npm](https://npmjs.org/package/openpipe)

This library wraps TypeScript or JavaScript OpenAI API calls and logs additional data to the configured `OPENPIPE_BASE_URL` for further processing.

It is fully compatible with OpenAI's SDK and logs both streaming and non-streaming requests and responses.

<!-- To learn more about using OpenPipe, check out our [Documentation](https://docs.openpipe.ai/docs/api). -->

## Installation

```sh
npm install --save openpipe
# or
yarn add openpipe
```

## Usage

1. Create a project at https://app.openpipe.ai
2. Find your project's API key at https://app.openpipe.ai/project/settings
3. Configure the OpenPipe client as shown below.

```js
// import OpenAI from 'openai'
import OpenAI from "openpipe/openai";

// Fully compatible with original OpenAI initialization
const openai = new OpenAI({
  apiKey: "my api key", // defaults to process.env["OPENAI_API_KEY"]
  // openpipe key is optional
  openpipe: {
    apiKey: "my api key", // defaults to process.env["OPENPIPE_API_KEY"]
    baseUrl: "my url", // defaults to process.env["OPENPIPE_BASE_URL"] or https://app.openpipe.ai/api/v1 if not set
  },
});

async function main() {
  // Allows optional openpipe object
  const completion = await openai.chat.completions.create({
    messages: [{ role: "user", content: "Say this is a test" }],
    model: "gpt-3.5-turbo",
    // optional
    openpipe: {
      // Add custom searchable tags
      tags: {
        prompt_id: "getCompletion",
        any_key: "any_value",
      },
    },
  });

  console.log(completion.choices);
}

main();
```

## FAQ

<b><i>How do I report calls to my self-hosted instance?</i></b>

Start an instance by following the instructions on [Running Locally](https://github.com/OpenPipe/OpenPipe#running-locally). Once it's running, point your `OPENPIPE_BASE_URL` to your self-hosted instance.

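For example, you could point the SDK at a locally running instance through an environment variable (the URL below is purely illustrative):

```sh
# Illustrative: replace with wherever your self-hosted OpenPipe instance is reachable
export OPENPIPE_BASE_URL="http://localhost:3000/api/v1"
```
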
<b><i>What if my `OPENPIPE_BASE_URL` is misconfigured or my instance goes down? Will my OpenAI calls stop working?</i></b>

Your OpenAI calls will continue to function as expected no matter what. The SDK handles logging errors gracefully without affecting OpenAI inference.

See the [GitHub repo](https://github.com/OpenPipe/OpenPipe) for more details.
||||
@@ -1,27 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# Adapted from https://github.com/openai/openai-node/blob/master/build
|
||||
|
||||
set -exuo pipefail
|
||||
|
||||
rm -rf dist /tmp/openpipe-build-dist
|
||||
|
||||
mkdir /tmp/openpipe-build-dist
|
||||
|
||||
cp -rp * /tmp/openpipe-build-dist
|
||||
|
||||
# Rename package name in package.json
|
||||
python3 -c "
|
||||
import json
|
||||
with open('/tmp/openpipe-build-dist/package.json', 'r') as f:
|
||||
data = json.load(f)
|
||||
data['name'] = 'openpipe'
|
||||
with open('/tmp/openpipe-build-dist/package.json', 'w') as f:
|
||||
json.dump(data, f, indent=4)
|
||||
"
|
||||
|
||||
rm -rf /tmp/openpipe-build-dist/node_modules
|
||||
mv /tmp/openpipe-build-dist dist
|
||||
|
||||
# build to .js files
|
||||
(cd dist && npm exec tsc -- --noEmit false)
|
||||
@@ -1 +1,3 @@
|
||||
export * as openai from "./openai";
|
||||
// main.ts or index.ts at the root level
|
||||
export * as OpenAI from "./src/openai";
|
||||
export * as OpenAILegacy from "./src/openai-legacy";
|
||||
|
||||
@@ -1,17 +1,14 @@
|
||||
{
|
||||
"name": "openpipe-dev",
|
||||
"version": "0.3.5",
|
||||
"name": "openpipe",
|
||||
"version": "0.1.0",
|
||||
"type": "module",
|
||||
"description": "Metrics and auto-evaluation for LLM calls",
|
||||
"scripts": {
|
||||
"build": "./build.sh",
|
||||
"build": "tsc",
|
||||
"test": "vitest"
|
||||
},
|
||||
"main": "./index.ts",
|
||||
"publishConfig": {
|
||||
"access": "public",
|
||||
"main": "./index.js"
|
||||
},
|
||||
"main": "dist/index.js",
|
||||
"types": "dist/index.d.ts",
|
||||
"keywords": [],
|
||||
"author": "",
|
||||
"license": "Apache-2.0",
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# Adapted from https://github.com/openai/openai-node/blob/master/build
|
||||
|
||||
set -exuo pipefail
|
||||
|
||||
./build.sh
|
||||
|
||||
(cd dist && pnpm publish --access public)
|
||||
85 client-libs/typescript/src/openai-legacy/index.ts (new file)
@@ -0,0 +1,85 @@
|
||||
import * as openPipeClient from "../codegen";
|
||||
import * as openai from "openai-legacy";
|
||||
import { version } from "../../package.json";
|
||||
|
||||
// Anything we don't override we want to pass through to openai directly
|
||||
export * as openAILegacy from "openai-legacy";
|
||||
|
||||
type OPConfigurationParameters = {
|
||||
apiKey?: string;
|
||||
basePath?: string;
|
||||
};
|
||||
|
||||
export class Configuration extends openai.Configuration {
|
||||
public qkConfig?: openPipeClient.Configuration;
|
||||
|
||||
constructor(
|
||||
config: openai.ConfigurationParameters & {
|
||||
opParameters?: OPConfigurationParameters;
|
||||
}
|
||||
) {
|
||||
super(config);
|
||||
if (config.opParameters) {
|
||||
this.qkConfig = new openPipeClient.Configuration(config.opParameters);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type CreateChatCompletion = InstanceType<typeof openai.OpenAIApi>["createChatCompletion"];
|
||||
|
||||
export class OpenAIApi extends openai.OpenAIApi {
|
||||
public openPipeApi?: openPipeClient.DefaultApi;
|
||||
|
||||
constructor(config: Configuration) {
|
||||
super(config);
|
||||
if (config.qkConfig) {
|
||||
this.openPipeApi = new openPipeClient.DefaultApi(config.qkConfig);
|
||||
}
|
||||
}
|
||||
|
||||
public async createChatCompletion(
|
||||
createChatCompletionRequest: Parameters<CreateChatCompletion>[0],
|
||||
options?: Parameters<CreateChatCompletion>[1]
|
||||
): ReturnType<CreateChatCompletion> {
|
||||
const requestedAt = Date.now();
|
||||
let resp: Awaited<ReturnType<CreateChatCompletion>> | null = null;
|
||||
let respPayload: openai.CreateChatCompletionResponse | null = null;
|
||||
let statusCode: number | undefined = undefined;
|
||||
let errorMessage: string | undefined;
|
||||
try {
|
||||
resp = await super.createChatCompletion(createChatCompletionRequest, options);
|
||||
respPayload = resp.data;
|
||||
statusCode = resp.status;
|
||||
} catch (err) {
|
||||
console.error("Error in createChatCompletion");
|
||||
if ("isAxiosError" in err && err.isAxiosError) {
|
||||
errorMessage = err.response?.data?.error?.message;
|
||||
respPayload = err.response?.data;
|
||||
statusCode = err.response?.status;
|
||||
} else if ("message" in err) {
|
||||
errorMessage = err.message.toString();
|
||||
}
|
||||
throw err;
|
||||
} finally {
|
||||
this.openPipeApi
|
||||
?.externalApiReport({
|
||||
requestedAt,
|
||||
receivedAt: Date.now(),
|
||||
reqPayload: createChatCompletionRequest,
|
||||
respPayload: respPayload,
|
||||
statusCode: statusCode,
|
||||
errorMessage,
|
||||
tags: {
|
||||
client: "openai-js",
|
||||
clientVersion: version,
|
||||
},
|
||||
})
|
||||
.catch((err) => {
|
||||
console.error("Error reporting to OP", err);
|
||||
});
|
||||
}
|
||||
|
||||
console.log("done");
|
||||
return resp;
|
||||
}
|
||||
}
|
||||
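A possible usage sketch for the wrapper defined above; the import path, environment variables, and model name are illustrative assumptions and not part of the diff:

```typescript
// Hypothetical usage of the Configuration/OpenAIApi wrapper added in this file.
import { Configuration, OpenAIApi } from "openpipe/openai-legacy"; // import path assumed

const config = new Configuration({
  apiKey: process.env.OPENAI_API_KEY ?? "",
  // opParameters configures the OpenPipe reporting client; omit it to skip reporting.
  opParameters: {
    apiKey: process.env.OPENPIPE_API_KEY ?? "",
    basePath: "https://app.openpipe.ai/api/v1",
  },
});

const openai = new OpenAIApi(config);

async function main() {
  // Calls are made exactly as with the legacy openai client; reporting happens in `finally`.
  const resp = await openai.createChatCompletion({
    model: "gpt-3.5-turbo",
    messages: [{ role: "user", content: "Say this is a test" }],
  });
  console.log(resp.data.choices[0]?.message?.content);
}

main();
```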
@@ -80,7 +80,6 @@ test("bad call streaming", async () => {
|
||||
stream: true,
|
||||
});
|
||||
} catch (e) {
|
||||
// @ts-expect-error need to check for error type
|
||||
await e.openpipe.reportingFinished;
|
||||
const lastLogged = await lastLoggedCall();
|
||||
expect(lastLogged?.modelResponse?.errorMessage).toEqual(
|
||||
@@ -97,9 +96,7 @@ test("bad call", async () => {
|
||||
messages: [{ role: "system", content: "count to 10" }],
|
||||
});
|
||||
} catch (e) {
|
||||
// @ts-expect-error need to check for error type
|
||||
assert("openpipe" in e);
|
||||
// @ts-expect-error need to check for error type
|
||||
await e.openpipe.reportingFinished;
|
||||
const lastLogged = await lastLoggedCall();
|
||||
expect(lastLogged?.modelResponse?.errorMessage).toEqual(
|
||||
@@ -123,8 +120,7 @@ test("caching", async () => {
|
||||
|
||||
await completion.openpipe.reportingFinished;
|
||||
const firstLogged = await lastLoggedCall();
|
||||
|
||||
expect(completion.choices[0]?.message.content).toEqual(
|
||||
expect(completion.choices[0].message.content).toEqual(
|
||||
firstLogged?.modelResponse?.respPayload.choices[0].message.content,
|
||||
);
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import pkg from "./package.json";
|
||||
|
||||
import pkg from "../package.json";
|
||||
import { DefaultService } from "./codegen";
|
||||
|
||||
export type OpenPipeConfig = {
|
||||
@@ -14,12 +14,9 @@
|
||||
"isolatedModules": true,
|
||||
"incremental": true,
|
||||
"noUncheckedIndexedAccess": true,
|
||||
"noEmit": true,
|
||||
"sourceMap": true,
|
||||
"declaration": true,
|
||||
"declarationMap": true,
|
||||
"rootDir": "."
|
||||
"baseUrl": ".",
|
||||
"outDir": "dist"
|
||||
},
|
||||
"include": ["**/*.ts"],
|
||||
"include": ["src/**/*.ts"],
|
||||
"exclude": ["node_modules"]
|
||||
}
|
||||
|
||||
@@ -1,23 +0,0 @@
---
title: "How reporting works"
description: "Our SDK wraps calls and forwards requests"
---

### Does reporting calls add latency to streamed requests?

Streamed requests won't have any added latency. The SDK forwards each streamed token as it's received from the server while
simultaneously collecting it in the response it will report to your OpenPipe instance once the entire response has been received.
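Conceptually, the pass-through can be pictured like the minimal sketch below. This is not the SDK's actual code, just an illustration of the idea: each chunk is yielded to the caller immediately and a copy is kept for reporting afterwards.

```typescript
// Minimal sketch, not the real implementation: forward every chunk as it arrives,
// then report the collected response once the stream is finished.
async function* teeAndReport(
  stream: AsyncIterable<string>,
  report: (fullResponse: string) => Promise<void>,
): AsyncGenerator<string> {
  const collected: string[] = [];
  for await (const chunk of stream) {
    collected.push(chunk); // keep a copy for reporting
    yield chunk; // forward to the caller with no added wait
  }
  // Report only after the entire response has been received; swallow reporting
  // errors so inference is never affected.
  await report(collected.join("")).catch(() => {});
}
```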

#### Your OpenAI key never leaves your machine.

Calls to OpenAI are carried out by our SDK **on your machine**, meaning that your API key is secure, and you'll
continue getting uninterrupted inference even if your OpenPipe instance goes down.

## <br />

### Want to dig deeper? Take a peek in our open-source code.

We benefit from a growing community of developers and customers who are
dedicated to improving the OpenPipe experience. Our [open source repo](https://github.com/openpipe/openpipe)
is an opportunity for developers to confirm the quality of our offering
and to make improvements when they can.
@@ -1,8 +0,0 @@
---
title: "Experiments"
description: "
Template multiple scenarios into combinations of prompts and models to compare their output. Use flexible regex and GPT-4 evaluations to assess completion quality.
Quickly iterate and spot model shortcomings before deployment."
---

<Frame></Frame>
@@ -1,8 +0,0 @@
---
title: "Export Data - Beta"
sidebarTitle: "Export Data"
description: "
Export your past requests as a JSONL file in an Alpaca or OpenAI fine-tuning format or in their raw form."
---

<Frame></Frame>
@@ -1,8 +0,0 @@
---
title: "Fine Tuning - Beta"
sidebarTitle: "Fine Tuning"
description: "
Fine tune your data on specific logs. Filter by prompt id and exclude requests with an undesirable output."
---

<Frame></Frame>
@@ -1,7 +0,0 @@
---
title: "Log Filters"
description: "
Search and filter your past LLM requests to inspect your responses and build a training dataset."
---

<Frame></Frame>
@@ -1,114 +0,0 @@
---
title: "Installing the SDK"
---

Use the OpenPipe SDK as a drop-in replacement for the generic OpenAI package. We currently support logging OpenAI calls, and support for more LLM providers will be added soon.

<Tabs>
<Tab title="Python">

Find the SDK at https://pypi.org/project/openpipe/

## Simple Integration

Add `OPENPIPE_API_KEY` to your environment variables.

```bash
export OPENPIPE_API_KEY=opk-<your-api-key>
# Or you can set it in your code, as shown in the example below
```

Replace this line

```python
from openai import openai
```

with this one

```python
from openpipe import openai
```

## Adding Searchable Tags

OpenPipe has a concept of "tagging." This is very useful for grouping a certain set of completions together.
When you're using a dataset for fine-tuning, you can select all the prompts that match a certain set of tags. Here's how you can use the tagging feature:

```python
from openpipe import openai, configure_openpipe
import os

# If you have the `OPENPIPE_API_KEY` environment variable set
# we'll read from it by default.
configure_openpipe(api_key=os.getenv("OPENPIPE_API_KEY"))

# Configure OpenAI the same way you would normally
openai.api_key = os.getenv("OPENAI_API_KEY")

completion = openai.ChatCompletion.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "system", "content": "count to 10"}],
    openpipe={"tags": {"prompt_id": "counting", "any_key": "any_value"}},
)
```

</Tab>
<Tab title="NodeJS">

Find the SDK at https://www.npmjs.com/package/openpipe

## Simple Integration

Add `OPENPIPE_API_KEY` to your environment variables.

```bash
export OPENPIPE_API_KEY=opk-<your-api-key>
# Or you can set it in your code, as shown in the example below
```

Replace this line

```typescript
import OpenAI from "openai";
```

with this one

```typescript
import OpenAI from "openpipe/openai";
```

## Adding Searchable Tags

OpenPipe has a concept of "tagging." This is very useful for grouping a certain set of completions together.
When you're using a dataset for fine-tuning, you can select all the prompts that match a certain set of tags. Here's how you can use the tagging feature:

```typescript
// Fully compatible with original OpenAI initialization
const openai = new OpenAI({
  apiKey: "my api key", // defaults to process.env["OPENAI_API_KEY"]
  // openpipe key is optional
  openpipe: {
    apiKey: "my api key", // defaults to process.env["OPENPIPE_API_KEY"]
    baseUrl: "my url", // defaults to process.env["OPENPIPE_BASE_URL"] or https://app.openpipe.ai/api/v1 if not set
  },
});

const completion = await openai.chat.completions.create({
  messages: [{ role: "user", content: "Count to 10" }],
  model: "gpt-3.5-turbo",
  // optional
  openpipe: {
    // Add custom searchable tags
    tags: {
      prompt_id: "counting",
      any_key: "any_value",
    },
  },
});
```

</Tab>
</Tabs>
@@ -1,35 +0,0 @@
---
title: "Quick Start"
description: "Get started with OpenPipe in a few quick steps."
---

## Step 1: Create your OpenPipe Account

If you don't already have one, create an account with OpenPipe at https://app.openpipe.ai/. You can sign up with GitHub, so you don't need to remember an extra password.

## Step 2: Find your Project API key

In order to capture your calls and fine-tune a model on them, we need an API key to authenticate you and determine which project to store your logs under.

<Note>
  When you created your account, a project was automatically configured for you as well. Find its
  API key at https://app.openpipe.ai/project/settings.
</Note>
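For example, the key can be exposed to the SDK as an environment variable (the same command shown on the next page):

```bash
export OPENPIPE_API_KEY=opk-<your-api-key>
```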

## Step 3: Integrate the OpenPipe SDK

You're done with the hard part! Learn how to integrate the OpenPipe SDK on the next page.

<CardGroup cols={2}>
  <Card
    title="OpenPipe SDK"
    icon={
      <svg role="img" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg">
        <title>OpenPipe</title>
<path d="M22.2819 9.8211a5.9847 5.9847 0 0 0-.5157-4.9108 6.0462 6.0462 0 0 0-6.5098-2.9A6.0651 6.0651 0 0 0 4.9807 4.1818a5.9847 5.9847 0 0 0-3.9977 2.9 6.0462 6.0462 0 0 0 .7427 7.0966 5.98 5.98 0 0 0 .511 4.9107 6.051 6.051 0 0 0 6.5146 2.9001A5.9847 5.9847 0 0 0 13.2599 24a6.0557 6.0557 0 0 0 5.7718-4.2058 5.9894 5.9894 0 0 0 3.9977-2.9001 6.0557 6.0557 0 0 0-.7475-7.0729zm-9.022 12.6081a4.4755 4.4755 0 0 1-2.8764-1.0408l.1419-.0804 4.7783-2.7582a.7948.7948 0 0 0 .3927-.6813v-6.7369l2.02 1.1686a.071.071 0 0 1 .038.052v5.5826a4.504 4.504 0 0 1-4.4945 4.4944zm-9.6607-4.1254a4.4708 4.4708 0 0 1-.5346-3.0137l.142.0852 4.783 2.7582a.7712.7712 0 0 0 .7806 0l5.8428-3.3685v2.3324a.0804.0804 0 0 1-.0332.0615L9.74 19.9502a4.4992 4.4992 0 0 1-6.1408-1.6464zM2.3408 7.8956a4.485 4.485 0 0 1 2.3655-1.9728V11.6a.7664.7664 0 0 0 .3879.6765l5.8144 3.3543-2.0201 1.1685a.0757.0757 0 0 1-.071 0l-4.8303-2.7865A4.504 4.504 0 0 1 2.3408 7.872zm16.5963 3.8558L13.1038 8.364 15.1192 7.2a.0757.0757 0 0 1 .071 0l4.8303 2.7913a4.4944 4.4944 0 0 1-.6765 8.1042v-5.6772a.79.79 0 0 0-.407-.667zm2.0107-3.0231l-.142-.0852-4.7735-2.7818a.7759.7759 0 0 0-.7854 0L9.409 9.2297V6.8974a.0662.0662 0 0 1 .0284-.0615l4.8303-2.7866a4.4992 4.4992 0 0 1 6.6802 4.66zM8.3065 12.863l-2.02-1.1638a.0804.0804 0 0 1-.038-.0567V6.0742a4.4992 4.4992 0 0 1 7.3757-3.4537l-.142.0805L8.704 5.459a.7948.7948 0 0 0-.3927.6813zm1.0976-2.3654l2.602-1.4998 2.6069 1.4998v2.9994l-2.5974 1.4997-2.6067-1.4997Z" />
      </svg>
    }
    iconType="duotone"
    href="/getting-started/openpipe-sdk"
  ></Card>
</CardGroup>
@@ -1,18 +0,0 @@
---
title: "OpenPipe Documentation"
sidebarTitle: "Introduction"
description: "
Product-focused teams use OpenPipe's seamless fine-tuning and monitoring services to decrease the cost and latency of their LLM operations.
You can use OpenPipe to collect and analyze LLM logs, create fine-tuned models, and compare output from multiple models given the same input."
---

<Frame></Frame>

<CardGroup cols={2}>
  <Card title="Get Started" icon="code">
    Quickly integrate the OpenPipe SDK into your application and start collecting data.
  </Card>
  <Card title="Features" icon="lightbulb">
    View the platform features OpenPipe provides and learn how to use them.
  </Card>
</CardGroup>
@@ -1,65 +0,0 @@
|
||||
{
|
||||
"name": "OpenPipe",
|
||||
"logo": {
|
||||
"light": "/logo/light.svg",
|
||||
"dark": "/logo/dark.svg"
|
||||
},
|
||||
"favicon": "/favicon.webp",
|
||||
"colors": {
|
||||
"primary": "#FF5733",
|
||||
"light": "#FF5733",
|
||||
"dark": "#FF5733"
|
||||
},
|
||||
"modeToggle": {
|
||||
"default": "light"
|
||||
},
|
||||
"topbarCtaButton": {
|
||||
"name": "Sign In",
|
||||
"url": "https://app.openpipe.ai"
|
||||
},
|
||||
"anchors": [
|
||||
{
|
||||
"name": "GitHub",
|
||||
"icon": "github",
|
||||
"url": "https://github.com/openpipe/openpipe"
|
||||
}
|
||||
],
|
||||
"feedback": {
|
||||
"suggestEdit": true,
|
||||
"raiseIssue": true
|
||||
},
|
||||
"navigation": [
|
||||
{
|
||||
"group": "Welcome",
|
||||
"pages": ["introduction", "overview"]
|
||||
},
|
||||
{
|
||||
"group": "Getting Started",
|
||||
"pages": ["getting-started/quick-start", "getting-started/openpipe-sdk"]
|
||||
},
|
||||
{
|
||||
"group": "Features",
|
||||
"pages": [
|
||||
"features/log-filters",
|
||||
"features/exporting-data",
|
||||
"features/fine-tuning",
|
||||
"features/experiments"
|
||||
]
|
||||
},
|
||||
{
|
||||
"group": "FAQ",
|
||||
"pages": ["faq/how-reporting-works"]
|
||||
}
|
||||
],
|
||||
"topbarLinks": [
|
||||
{
|
||||
"name": "Github",
|
||||
"url": "https://github.com/OpenPipe/OpenPipe"
|
||||
}
|
||||
],
|
||||
"footerSocials": {
|
||||
"twitter": "https://twitter.com/OpenPipeAI",
|
||||
"linkedin": "https://www.linkedin.com/company/openpipe/about/",
|
||||
"github": "https://github.com/OpenPipe/OpenPipe"
|
||||
}
|
||||
}
|
||||
@@ -1,30 +0,0 @@
---
title: "Overview"
description: "OpenPipe is a streamlined platform designed to help product-focused teams train specialized LLM models as replacements for slow and expensive prompts."
---

## Who We Are

We're a team of full-stack engineers and machine learning researchers working to streamline the process of integrating fine-tuned models into any application. Our goal is to make the fine-tuning process accessible to everyone.

## What We Provide

Here are a few of the features we offer:

- **Data Capture**: OpenPipe automatically captures every request and response sent through our drop-in replacement SDK and stores it for your future use.

- **Monitoring**: OpenPipe provides intuitive tools to view the frequency and cost of your LLM requests, and provides a special tool for viewing requests with error status codes.

- **Searchable Logs**: We enable you to search your past requests, and provide a simple protocol for tagging them by prompt id for easy filtering.

- **Fine-Tuning**: With all your LLM requests and responses in one place, it's easy to select the data you want to fine-tune on and kick off a job.

- **Model Hosting**: After we've trained your model, OpenPipe will automatically begin hosting it. Accessing your model will require an API key from your project.

- **Unified SDK**: Switching requests from your previous LLM provider to your new model is as simple as changing the model name. All our models implement the OpenAI inference format, so you won't have to change how you parse its response.

- **Data Export**: OpenPipe allows you to download your request logs or the fine-tuned models you've trained at any time for easy self-hosting.

- **Experimentation**: The fine-tunes you've created on OpenPipe are immediately available for you to run inference on in our experimentation playground.
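To illustrate the "Unified SDK" point above: switching to a hosted fine-tune amounts to changing the model name, as in the sketch below. The model identifier and client setup are purely illustrative; the call is assumed to run inside an async function with `openai` constructed as in the SDK docs.

```typescript
// Illustrative sketch: only the model name changes; the request and response
// shapes stay OpenAI-compatible, so no other code needs to be updated.
const completion = await openai.chat.completions.create({
  // model: "gpt-3.5-turbo",                  // before: the general-purpose model
  model: "your-fine-tuned-model-id",          // after: your hosted fine-tune (identifier is illustrative)
  messages: [{ role: "user", content: "Count to 10" }],
});
console.log(completion.choices[0]?.message?.content);
```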

Welcome to the OpenPipe community!
4 examples/.gitignore (vendored)
@@ -1,4 +0,0 @@
|
||||
axolotl/
|
||||
models/
|
||||
data/
|
||||
wandb/
|
||||
@@ -1,7 +0,0 @@
|
||||
OPENAI_API_KEY="[your OpenAI API key]"
|
||||
OPENPIPE_API_KEY="[your OpenPipe API key from https://app.openpipe.ai/project/settings]"
|
||||
|
||||
# You'll need this to download the Llama 2 weights from Hugging Face
|
||||
HUGGING_FACE_HUB_TOKEN="[Your Hugging Face Hub token]"
|
||||
|
||||
WANDB_API_KEY="[Optionally, you can set a Weights & Biases API key to track your training run. Create it at https://wandb.ai/settings]"
|
||||
@@ -1,473 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Current Time: 2023-08-24 21:25:06\n",
|
||||
"Current Time: 2023-08-24 21:25:36\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import time\n",
|
||||
"\n",
|
||||
"while True:\n",
|
||||
" current_time = time.strftime(\"%Y-%m-%d %H:%M:%S\", time.localtime())\n",
|
||||
" print(f\"Current Time: {current_time}\")\n",
|
||||
" time.sleep(30)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"I'm pretty happy with my model's accuracy relative to GPT-4. How does it compare cost-wise?\n",
|
||||
"\n",
|
||||
"I'll really push this to its limits -- let's see how quickly our poor model can classify the [full 2-million-recipe dataset](https://huggingface.co/datasets/corbt/all-recipes) 😈."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Requirement already satisfied: datasets==2.14.4 in /usr/local/lib/python3.10/dist-packages (2.14.4)\n",
|
||||
"Requirement already satisfied: vllm==0.1.3 in /usr/local/lib/python3.10/dist-packages (0.1.3)\n",
|
||||
"Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (1.24.4)\n",
|
||||
"Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (12.0.1)\n",
|
||||
"Requirement already satisfied: dill<0.3.8,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (0.3.7)\n",
|
||||
"Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (2.0.3)\n",
|
||||
"Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (2.28.1)\n",
|
||||
"Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (4.66.1)\n",
|
||||
"Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (3.3.0)\n",
|
||||
"Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (0.70.15)\n",
|
||||
"Requirement already satisfied: fsspec[http]>=2021.11.1 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (2023.6.0)\n",
|
||||
"Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (3.8.5)\n",
|
||||
"Requirement already satisfied: huggingface-hub<1.0.0,>=0.14.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (0.16.4)\n",
|
||||
"Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (23.1)\n",
|
||||
"Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (6.0)\n",
|
||||
"Requirement already satisfied: ninja in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (1.11.1)\n",
|
||||
"Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (5.9.5)\n",
|
||||
"Requirement already satisfied: ray>=2.5.1 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (2.6.3)\n",
|
||||
"Requirement already satisfied: sentencepiece in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.1.99)\n",
|
||||
"Requirement already satisfied: torch>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (2.0.1+cu118)\n",
|
||||
"Requirement already satisfied: transformers>=4.31.0 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (4.33.0.dev0)\n",
|
||||
"Requirement already satisfied: xformers>=0.0.19 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.0.21)\n",
|
||||
"Requirement already satisfied: fastapi in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.101.1)\n",
|
||||
"Requirement already satisfied: uvicorn in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.23.2)\n",
|
||||
"Requirement already satisfied: pydantic<2 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (1.10.12)\n",
|
||||
"Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (23.1.0)\n",
|
||||
"Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (2.1.1)\n",
|
||||
"Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (6.0.4)\n",
|
||||
"Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (4.0.3)\n",
|
||||
"Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (1.9.2)\n",
|
||||
"Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (1.4.0)\n",
|
||||
"Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (1.3.1)\n",
|
||||
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0.0,>=0.14.0->datasets==2.14.4) (3.9.0)\n",
|
||||
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0.0,>=0.14.0->datasets==2.14.4) (4.7.1)\n",
|
||||
"Requirement already satisfied: click>=7.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (8.1.7)\n",
|
||||
"Requirement already satisfied: jsonschema in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (4.18.0)\n",
|
||||
"Requirement already satisfied: msgpack<2.0.0,>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (1.0.5)\n",
|
||||
"Requirement already satisfied: protobuf!=3.19.5,>=3.15.3 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (4.24.1)\n",
|
||||
"Requirement already satisfied: grpcio>=1.42.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (1.57.0)\n",
|
||||
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets==2.14.4) (3.4)\n",
|
||||
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets==2.14.4) (1.26.13)\n",
|
||||
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets==2.14.4) (2022.12.7)\n",
|
||||
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (1.11.1)\n",
|
||||
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (3.0)\n",
|
||||
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (3.1.2)\n",
|
||||
"Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (2.0.0)\n",
|
||||
"Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=2.0.0->vllm==0.1.3) (3.25.0)\n",
|
||||
"Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=2.0.0->vllm==0.1.3) (15.0.7)\n",
|
||||
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (2023.8.8)\n",
|
||||
"Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (0.13.3)\n",
|
||||
"Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (0.3.2)\n",
|
||||
"Requirement already satisfied: starlette<0.28.0,>=0.27.0 in /usr/local/lib/python3.10/dist-packages (from fastapi->vllm==0.1.3) (0.27.0)\n",
|
||||
"Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets==2.14.4) (2.8.2)\n",
|
||||
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets==2.14.4) (2023.3)\n",
|
||||
"Requirement already satisfied: tzdata>=2022.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets==2.14.4) (2023.3)\n",
|
||||
"Requirement already satisfied: h11>=0.8 in /usr/local/lib/python3.10/dist-packages (from uvicorn->vllm==0.1.3) (0.14.0)\n",
|
||||
"Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil>=2.8.2->pandas->datasets==2.14.4) (1.16.0)\n",
|
||||
"Requirement already satisfied: anyio<5,>=3.4.0 in /usr/local/lib/python3.10/dist-packages (from starlette<0.28.0,>=0.27.0->fastapi->vllm==0.1.3) (3.7.1)\n",
|
||||
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=2.0.0->vllm==0.1.3) (2.1.2)\n",
|
||||
"Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (2023.6.1)\n",
|
||||
"Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (0.29.1)\n",
|
||||
"Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (0.8.10)\n",
|
||||
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=2.0.0->vllm==0.1.3) (1.2.1)\n",
|
||||
"Requirement already satisfied: sniffio>=1.1 in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.4.0->starlette<0.28.0,>=0.27.0->fastapi->vllm==0.1.3) (1.3.0)\n",
|
||||
"Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.4.0->starlette<0.28.0,>=0.27.0->fastapi->vllm==0.1.3) (1.1.2)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
|
||||
"\u001b[0m\n",
|
||||
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.1.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.2.1\u001b[0m\n",
|
||||
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython -m pip install --upgrade pip\u001b[0m\n",
|
||||
"Note: you may need to restart the kernel to use updated packages.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"%pip install datasets==2.14.4 vllm==0.1.3"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Number of recipes: 2,147,248\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from datasets import load_dataset\n",
|
||||
"\n",
|
||||
"all_recipes = load_dataset(\"corbt/all-recipes\")[\"train\"][\"input\"]\n",
|
||||
"\n",
|
||||
"print(f\"Number of recipes: {len(all_recipes):,}\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO 08-24 19:38:29 llm_engine.py:70] Initializing an LLM engine with config: model='./models/run1/merged', tokenizer='./models/run1/merged', tokenizer_mode=auto, trust_remote_code=False, dtype=torch.float16, use_dummy_weights=False, download_dir=None, use_np_weights=False, tensor_parallel_size=1, seed=0)\n",
|
||||
"INFO 08-24 19:39:48 llm_engine.py:196] # GPU blocks: 3419, # CPU blocks: 512\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from vllm import LLM, SamplingParams\n",
|
||||
"\n",
|
||||
"llm = LLM(model=\"./models/run1/merged\", max_num_batched_tokens=4096)\n",
|
||||
"\n",
|
||||
"sampling_params = SamplingParams(\n",
|
||||
" # 120 should be fine for the work we're doing here.\n",
|
||||
" max_tokens=120,\n",
|
||||
" # This is a deterministic task so temperature=0 is best.\n",
|
||||
" temperature=0,\n",
|
||||
")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Start time: 1692906050.3340027\n",
|
||||
"Processing recipes 0 to 10,000...\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Processed prompts: 100%|██████████| 10000/10000 [04:51<00:00, 34.30it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Processing recipes 10,000 to 20,000...\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Processed prompts: 100%|██████████| 10000/10000 [04:54<00:00, 33.98it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Processing recipes 20,000 to 30,000...\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Processed prompts: 100%|██████████| 10000/10000 [04:53<00:00, 34.11it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Processing recipes 30,000 to 40,000...\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Processed prompts: 100%|██████████| 10000/10000 [04:53<00:00, 34.11it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Processing recipes 40,000 to 50,000...\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Processed prompts: 48%|████▊ | 4796/10000 [02:21<03:18, 26.22it/s]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"ename": "KeyboardInterrupt",
|
||||
"evalue": "",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
|
||||
"Cell \u001b[0;32mIn[6], line 12\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[39mfor\u001b[39;00m i \u001b[39min\u001b[39;00m \u001b[39mrange\u001b[39m(\u001b[39m0\u001b[39m, \u001b[39mlen\u001b[39m(all_recipes), BATCH_SIZE):\n\u001b[1;32m 11\u001b[0m \u001b[39mprint\u001b[39m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mProcessing recipes \u001b[39m\u001b[39m{\u001b[39;00mi\u001b[39m:\u001b[39;00m\u001b[39m,\u001b[39m\u001b[39m}\u001b[39;00m\u001b[39m to \u001b[39m\u001b[39m{\u001b[39;00mi\u001b[39m+\u001b[39mBATCH_SIZE\u001b[39m:\u001b[39;00m\u001b[39m,\u001b[39m\u001b[39m}\u001b[39;00m\u001b[39m...\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[0;32m---> 12\u001b[0m outputs \u001b[39m=\u001b[39m llm\u001b[39m.\u001b[39;49mgenerate(all_recipes[i:i\u001b[39m+\u001b[39;49mBATCH_SIZE], sampling_params\u001b[39m=\u001b[39;49msampling_params)\n\u001b[1;32m 14\u001b[0m all_outputs\u001b[39m.\u001b[39mextend([o\u001b[39m.\u001b[39moutputs[\u001b[39m0\u001b[39m]\u001b[39m.\u001b[39mtext \u001b[39mfor\u001b[39;00m o \u001b[39min\u001b[39;00m outputs])\n\u001b[1;32m 16\u001b[0m end_time \u001b[39m=\u001b[39m time\u001b[39m.\u001b[39mtime()\n",
|
||||
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/vllm/entrypoints/llm.py:130\u001b[0m, in \u001b[0;36mLLM.generate\u001b[0;34m(self, prompts, sampling_params, prompt_token_ids, use_tqdm)\u001b[0m\n\u001b[1;32m 128\u001b[0m token_ids \u001b[39m=\u001b[39m prompt_token_ids[i]\n\u001b[1;32m 129\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_add_request(prompt, sampling_params, token_ids)\n\u001b[0;32m--> 130\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_run_engine(use_tqdm)\n",
|
||||
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/vllm/entrypoints/llm.py:150\u001b[0m, in \u001b[0;36mLLM._run_engine\u001b[0;34m(self, use_tqdm)\u001b[0m\n\u001b[1;32m 148\u001b[0m outputs: List[RequestOutput] \u001b[39m=\u001b[39m []\n\u001b[1;32m 149\u001b[0m \u001b[39mwhile\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mllm_engine\u001b[39m.\u001b[39mhas_unfinished_requests():\n\u001b[0;32m--> 150\u001b[0m step_outputs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mllm_engine\u001b[39m.\u001b[39;49mstep()\n\u001b[1;32m 151\u001b[0m \u001b[39mfor\u001b[39;00m output \u001b[39min\u001b[39;00m step_outputs:\n\u001b[1;32m 152\u001b[0m \u001b[39mif\u001b[39;00m output\u001b[39m.\u001b[39mfinished:\n",
|
||||
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/vllm/engine/llm_engine.py:313\u001b[0m, in \u001b[0;36mLLMEngine.step\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 307\u001b[0m \u001b[39mreturn\u001b[39;00m [\n\u001b[1;32m 308\u001b[0m RequestOutput\u001b[39m.\u001b[39mfrom_seq_group(seq_group)\n\u001b[1;32m 309\u001b[0m \u001b[39mfor\u001b[39;00m seq_group \u001b[39min\u001b[39;00m scheduler_outputs\u001b[39m.\u001b[39mignored_seq_groups\n\u001b[1;32m 310\u001b[0m ]\n\u001b[1;32m 312\u001b[0m \u001b[39m# Execute the model.\u001b[39;00m\n\u001b[0;32m--> 313\u001b[0m output \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_run_workers(\n\u001b[1;32m 314\u001b[0m \u001b[39m\"\u001b[39;49m\u001b[39mexecute_model\u001b[39;49m\u001b[39m\"\u001b[39;49m,\n\u001b[1;32m 315\u001b[0m seq_group_metadata_list\u001b[39m=\u001b[39;49mseq_group_metadata_list,\n\u001b[1;32m 316\u001b[0m blocks_to_swap_in\u001b[39m=\u001b[39;49mscheduler_outputs\u001b[39m.\u001b[39;49mblocks_to_swap_in,\n\u001b[1;32m 317\u001b[0m blocks_to_swap_out\u001b[39m=\u001b[39;49mscheduler_outputs\u001b[39m.\u001b[39;49mblocks_to_swap_out,\n\u001b[1;32m 318\u001b[0m blocks_to_copy\u001b[39m=\u001b[39;49mscheduler_outputs\u001b[39m.\u001b[39;49mblocks_to_copy,\n\u001b[1;32m 319\u001b[0m )\n\u001b[1;32m 320\u001b[0m \u001b[39m# Update the scheduler with the model outputs.\u001b[39;00m\n\u001b[1;32m 321\u001b[0m seq_groups \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mscheduler\u001b[39m.\u001b[39mupdate(output)\n",
|
||||
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/vllm/engine/llm_engine.py:470\u001b[0m, in \u001b[0;36mLLMEngine._run_workers\u001b[0;34m(self, method, get_all_outputs, *args, **kwargs)\u001b[0m\n\u001b[1;32m 467\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 468\u001b[0m executor \u001b[39m=\u001b[39m \u001b[39mgetattr\u001b[39m(worker, method)\n\u001b[0;32m--> 470\u001b[0m output \u001b[39m=\u001b[39m executor(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 471\u001b[0m all_outputs\u001b[39m.\u001b[39mappend(output)\n\u001b[1;32m 473\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mparallel_config\u001b[39m.\u001b[39mworker_use_ray:\n",
|
||||
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:115\u001b[0m, in \u001b[0;36mcontext_decorator.<locals>.decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[39m@functools\u001b[39m\u001b[39m.\u001b[39mwraps(func)\n\u001b[1;32m 113\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mdecorate_context\u001b[39m(\u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs):\n\u001b[1;32m 114\u001b[0m \u001b[39mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 115\u001b[0m \u001b[39mreturn\u001b[39;00m func(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n",
|
||||
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/vllm/worker/worker.py:293\u001b[0m, in \u001b[0;36mWorker.execute_model\u001b[0;34m(self, seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)\u001b[0m\n\u001b[1;32m 289\u001b[0m input_tokens, input_positions, input_metadata \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_prepare_inputs(\n\u001b[1;32m 290\u001b[0m seq_group_metadata_list)\n\u001b[1;32m 292\u001b[0m \u001b[39m# Execute the model.\u001b[39;00m\n\u001b[0;32m--> 293\u001b[0m output \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mmodel(\n\u001b[1;32m 294\u001b[0m input_ids\u001b[39m=\u001b[39;49minput_tokens,\n\u001b[1;32m 295\u001b[0m positions\u001b[39m=\u001b[39;49minput_positions,\n\u001b[1;32m 296\u001b[0m kv_caches\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mgpu_cache,\n\u001b[1;32m 297\u001b[0m input_metadata\u001b[39m=\u001b[39;49minput_metadata,\n\u001b[1;32m 298\u001b[0m cache_events\u001b[39m=\u001b[39;49mcache_events,\n\u001b[1;32m 299\u001b[0m )\n\u001b[1;32m 300\u001b[0m \u001b[39mreturn\u001b[39;00m output\n",
|
||||
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1502\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
|
||||
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/llama.py:255\u001b[0m, in \u001b[0;36mLlamaForCausalLM.forward\u001b[0;34m(self, input_ids, positions, kv_caches, input_metadata, cache_events)\u001b[0m\n\u001b[1;32m 245\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\n\u001b[1;32m 246\u001b[0m \u001b[39mself\u001b[39m,\n\u001b[1;32m 247\u001b[0m input_ids: torch\u001b[39m.\u001b[39mTensor,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 251\u001b[0m cache_events: Optional[List[torch\u001b[39m.\u001b[39mcuda\u001b[39m.\u001b[39mEvent]],\n\u001b[1;32m 252\u001b[0m ) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Dict[\u001b[39mint\u001b[39m, SequenceOutputs]:\n\u001b[1;32m 253\u001b[0m hidden_states \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmodel(input_ids, positions, kv_caches,\n\u001b[1;32m 254\u001b[0m input_metadata, cache_events)\n\u001b[0;32m--> 255\u001b[0m next_tokens \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49msampler(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mlm_head\u001b[39m.\u001b[39;49mweight, hidden_states,\n\u001b[1;32m 256\u001b[0m input_metadata)\n\u001b[1;32m 257\u001b[0m \u001b[39mreturn\u001b[39;00m next_tokens\n",
|
||||
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1502\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
|
||||
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/vllm/model_executor/layers/sampler.py:44\u001b[0m, in \u001b[0;36mSampler.forward\u001b[0;34m(self, embedding, hidden_states, input_metadata, embedding_bias)\u001b[0m\n\u001b[1;32m 36\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\n\u001b[1;32m 37\u001b[0m \u001b[39mself\u001b[39m,\n\u001b[1;32m 38\u001b[0m embedding: torch\u001b[39m.\u001b[39mTensor,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 42\u001b[0m ) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Dict[\u001b[39mint\u001b[39m, SequenceOutputs]:\n\u001b[1;32m 43\u001b[0m \u001b[39m# Get the hidden states that we use for sampling.\u001b[39;00m\n\u001b[0;32m---> 44\u001b[0m hidden_states \u001b[39m=\u001b[39m _prune_hidden_states(hidden_states, input_metadata)\n\u001b[1;32m 46\u001b[0m \u001b[39m# Get the logits for the next tokens.\u001b[39;00m\n\u001b[1;32m 47\u001b[0m logits \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mmatmul(hidden_states, embedding\u001b[39m.\u001b[39mt())\n",
|
||||
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# We'll process our recipes in batches of 10,000.\n",
|
||||
"\n",
|
||||
"import time\n",
|
||||
"\n",
|
||||
"BATCH_SIZE = 10000\n",
|
||||
"all_outputs = []\n",
|
||||
"\n",
|
||||
"start_time = time.time()\n",
|
||||
"print(f\"Start time: {start_time}\")\n",
|
||||
"for i in range(0, len(all_recipes), BATCH_SIZE):\n",
|
||||
" print(f\"Processing recipes {i:,} to {i+BATCH_SIZE:,}...\")\n",
|
||||
" outputs = llm.generate(\n",
|
||||
" all_recipes[i : i + BATCH_SIZE], sampling_params=sampling_params\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" all_outputs.extend([o.outputs[0].text for o in outputs])\n",
|
||||
"\n",
|
||||
"end_time = time.time()\n",
|
||||
"print(f\"End time: {end_time}\")\n",
|
||||
"print(f\"Total hours: {((end_time - start_time) / 3600):.2f}\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Nice! I've processed all 2,147,248 recipes in under 17 hours. Let's do a cost comparison with GPT-3.5 and GPT-4. I'll use the GPT-4 latency/cost numbers based on the 5000 samples used to generate our model's training data."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>Model</th>\n",
|
||||
" <th>Cost to Classify One Recipe</th>\n",
|
||||
" <th>Cost to Classify Entire Dataset</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>Llama 2 7B (finetuned)</td>\n",
|
||||
" <td>0.000009</td>\n",
|
||||
" <td>18.86</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>GPT-3.5</td>\n",
|
||||
" <td>0.000481</td>\n",
|
||||
" <td>1,033.26</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>GPT-3.5 (finetuned)</td>\n",
|
||||
" <td>0.004044</td>\n",
|
||||
" <td>8,683.47</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>GPT-4</td>\n",
|
||||
" <td>0.010800</td>\n",
|
||||
" <td>23,190.28</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" Model Cost to Classify One Recipe \\\n",
|
||||
"0 Llama 2 7B (finetuned) 0.000009 \n",
|
||||
"1 GPT-3.5 0.000481 \n",
|
||||
"2 GPT-3.5 (finetuned) 0.004044 \n",
|
||||
"3 GPT-4 0.010800 \n",
|
||||
"\n",
|
||||
" Cost to Classify Entire Dataset \n",
|
||||
"0 18.86 \n",
|
||||
"1 1,033.26 \n",
|
||||
"2 8,683.47 \n",
|
||||
"3 23,190.28 "
|
||||
]
|
||||
},
|
||||
"execution_count": 19,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import pandas as pd\n",
|
||||
"\n",
|
||||
"# I used an on-demand Nvidia L40 on RunPod for this, at an hourly cost of $1.14.\n",
|
||||
"finetuned_hourly_cost = 1.14\n",
|
||||
"\n",
|
||||
"finetuned_total_hours = 16.54\n",
|
||||
"\n",
|
||||
"finetuned_avg_cost = finetuned_hourly_cost * finetuned_total_hours / len(all_recipes)\n",
|
||||
"\n",
|
||||
"# The average input and output tokens calculated by OpenAI, based on the 5000 recipes I sent them\n",
|
||||
"avg_input_tokens = 276\n",
|
||||
"avg_output_tokens = 42\n",
|
||||
"\n",
|
||||
"# Token pricing from https://openai.com/pricing\n",
|
||||
"gpt_4_avg_cost = avg_input_tokens * 0.03 / 1000 + avg_output_tokens * 0.06 / 1000\n",
|
||||
"\n",
|
||||
"gpt_35_avg_cost = avg_input_tokens * 0.0015 / 1000 + avg_output_tokens * 0.0016 / 1000\n",
|
||||
"\n",
|
||||
"gpt_35_finetuned_avg_cost = (\n",
|
||||
" avg_input_tokens * 0.012 / 1000 + avg_output_tokens * 0.016 / 1000 + 0.06 / 1000\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Multiply the number of recipes\n",
|
||||
"# gpt_4_cost = len(all_recipes) * gpt_4_avg_cost\n",
|
||||
"# gpt_35_cost = len(all_recipes) * gpt_35_avg_cost\n",
|
||||
"# gpt_35_finetuned_cost = len(all_recipes) * gpt_35_finetuned_avg_cost\n",
|
||||
"\n",
|
||||
"# Let's put this in a dataframe for easier comparison.\n",
|
||||
"\n",
|
||||
"costs = pd.DataFrame(\n",
|
||||
" {\n",
|
||||
" \"Model\": [\n",
|
||||
" \"Llama 2 7B (finetuned)\",\n",
|
||||
" \"GPT-3.5\",\n",
|
||||
" \"GPT-3.5 (finetuned)\",\n",
|
||||
" \"GPT-4\",\n",
|
||||
" ],\n",
|
||||
" \"Cost to Classify One Recipe\": [\n",
|
||||
" finetuned_avg_cost,\n",
|
||||
" gpt_35_avg_cost,\n",
|
||||
" gpt_35_finetuned_avg_cost,\n",
|
||||
" gpt_4_avg_cost,\n",
|
||||
" ],\n",
|
||||
" }\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"costs[\"Cost to Classify Entire Dataset\"] = (\n",
|
||||
" costs[\"Cost to Classify One Recipe\"] * len(all_recipes)\n",
|
||||
").map(lambda x: f\"{x:,.2f}\")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"costs\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"...and just for fun, let's figure out how many recipes my pescatarian basement-dwelling brother can make! 😂"
|
||||
]
|
||||
},
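A quick way to answer that, sketched below, assumes `all_outputs` holds the raw function-call JSON strings generated above and reuses the classification fields from the evaluation notebook (`has_non_fish_meat`, `requires_oven`); mapping "pescatarian" to no non-fish meat and "basement-dwelling" to no oven is just a guess at the joke, not part of the original code.

```python
import json

def parse_args(output: str) -> dict:
    """Pull the classification arguments out of one raw function-call response."""
    response = json.loads(output)
    return json.loads(response["function_call"]["arguments"])

# Count recipes that contain no non-fish meat and don't need an oven.
matching = 0
for output in all_outputs:
    try:
        args = parse_args(output)
    except (json.JSONDecodeError, KeyError):
        continue  # skip any malformed generations
    if not args.get("has_non_fish_meat", True) and not args.get("requires_oven", True):
        matching += 1

print(f"Recipes he could make: {matching:,} of {len(all_outputs):,}")
```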
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.6"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
@@ -1,10 +0,0 @@

# OpenPipe demo: fine-tuning your own model

Hi there! This repository should give you a brief overview of how to fine-tune a competitive model from start to finish. You should review the notebooks in this directory in the following order:

1. [./generate-data.ipynb](./generate-data.ipynb): Demonstrates how to generate a sample dataset of GPT-4 completions, store it using OpenPipe, and then export it in a format suitable for training a model.
2. [./train.ipynb](./train.ipynb): Trains a Llama 2 7B model on the dataset from step (1).
3. [./evaluate.ipynb](./evaluate.ipynb): Evaluates the model we trained using a special test set that we set aside in step (1).
4. [./benchmark.ipynb](./benchmark.ipynb): A script to compare costs and completion latencies between our fine-tuned model, GPT-3.5, and GPT-4.

If you want to follow along yourself, I recommend using [RunPod](https://www.runpod.io/). The training scripts we use will run on any of their GPUs with 24GB of vRAM or more.
@@ -1,432 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"I'm pretty happy with my model's accuracy relative to GPT-4. How does it compare cost-wise?\n",
|
||||
"\n",
|
||||
"I'll really push this to its limits -- let's see how quickly our poor model can classify the [full 2-million-recipe dataset](https://huggingface.co/datasets/corbt/all-recipes) 😈."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Requirement already satisfied: datasets==2.14.4 in /usr/local/lib/python3.10/dist-packages (2.14.4)\n",
|
||||
"Requirement already satisfied: vllm==0.1.3 in /usr/local/lib/python3.10/dist-packages (0.1.3)\n",
|
||||
"Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (1.24.4)\n",
|
||||
"Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (12.0.1)\n",
|
||||
"Requirement already satisfied: dill<0.3.8,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (0.3.7)\n",
|
||||
"Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (2.0.3)\n",
|
||||
"Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (2.28.1)\n",
|
||||
"Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (4.66.1)\n",
|
||||
"Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (3.3.0)\n",
|
||||
"Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (0.70.15)\n",
|
||||
"Requirement already satisfied: fsspec[http]>=2021.11.1 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (2023.6.0)\n",
|
||||
"Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (3.8.5)\n",
|
||||
"Requirement already satisfied: huggingface-hub<1.0.0,>=0.14.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (0.16.4)\n",
|
||||
"Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (23.1)\n",
|
||||
"Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (6.0)\n",
|
||||
"Requirement already satisfied: ninja in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (1.11.1)\n",
|
||||
"Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (5.9.5)\n",
|
||||
"Requirement already satisfied: ray>=2.5.1 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (2.6.3)\n",
|
||||
"Requirement already satisfied: sentencepiece in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.1.99)\n",
|
||||
"Requirement already satisfied: torch>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (2.0.1+cu118)\n",
|
||||
"Requirement already satisfied: transformers>=4.31.0 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (4.33.0.dev0)\n",
|
||||
"Requirement already satisfied: xformers>=0.0.19 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.0.21)\n",
|
||||
"Requirement already satisfied: fastapi in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.101.1)\n",
|
||||
"Requirement already satisfied: uvicorn in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.23.2)\n",
|
||||
"Requirement already satisfied: pydantic<2 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (1.10.12)\n",
|
||||
"Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (23.1.0)\n",
|
||||
"Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (2.1.1)\n",
|
||||
"Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (6.0.4)\n",
|
||||
"Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (4.0.3)\n",
|
||||
"Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (1.9.2)\n",
|
||||
"Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (1.4.0)\n",
|
||||
"Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (1.3.1)\n",
|
||||
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0.0,>=0.14.0->datasets==2.14.4) (3.9.0)\n",
|
||||
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0.0,>=0.14.0->datasets==2.14.4) (4.7.1)\n",
|
||||
"Requirement already satisfied: click>=7.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (8.1.7)\n",
|
||||
"Requirement already satisfied: jsonschema in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (4.18.0)\n",
|
||||
"Requirement already satisfied: msgpack<2.0.0,>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (1.0.5)\n",
|
||||
"Requirement already satisfied: protobuf!=3.19.5,>=3.15.3 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (4.24.1)\n",
|
||||
"Requirement already satisfied: grpcio>=1.42.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (1.57.0)\n",
|
||||
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets==2.14.4) (3.4)\n",
|
||||
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets==2.14.4) (1.26.13)\n",
|
||||
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets==2.14.4) (2022.12.7)\n",
|
||||
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (1.11.1)\n",
|
||||
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (3.0)\n",
|
||||
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (3.1.2)\n",
|
||||
"Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (2.0.0)\n",
|
||||
"Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=2.0.0->vllm==0.1.3) (3.25.0)\n",
|
||||
"Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=2.0.0->vllm==0.1.3) (15.0.7)\n",
|
||||
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (2023.8.8)\n",
|
||||
"Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (0.13.3)\n",
|
||||
"Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (0.3.2)\n",
|
||||
"Requirement already satisfied: starlette<0.28.0,>=0.27.0 in /usr/local/lib/python3.10/dist-packages (from fastapi->vllm==0.1.3) (0.27.0)\n",
|
||||
"Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets==2.14.4) (2.8.2)\n",
|
||||
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets==2.14.4) (2023.3)\n",
|
||||
"Requirement already satisfied: tzdata>=2022.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets==2.14.4) (2023.3)\n",
|
||||
"Requirement already satisfied: h11>=0.8 in /usr/local/lib/python3.10/dist-packages (from uvicorn->vllm==0.1.3) (0.14.0)\n",
|
||||
"Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil>=2.8.2->pandas->datasets==2.14.4) (1.16.0)\n",
|
||||
"Requirement already satisfied: anyio<5,>=3.4.0 in /usr/local/lib/python3.10/dist-packages (from starlette<0.28.0,>=0.27.0->fastapi->vllm==0.1.3) (3.7.1)\n",
|
||||
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=2.0.0->vllm==0.1.3) (2.1.2)\n",
|
||||
"Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (2023.6.1)\n",
|
||||
"Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (0.29.1)\n",
|
||||
"Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (0.8.10)\n",
|
||||
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=2.0.0->vllm==0.1.3) (1.2.1)\n",
|
||||
"Requirement already satisfied: sniffio>=1.1 in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.4.0->starlette<0.28.0,>=0.27.0->fastapi->vllm==0.1.3) (1.3.0)\n",
|
||||
"Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.4.0->starlette<0.28.0,>=0.27.0->fastapi->vllm==0.1.3) (1.1.2)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
|
||||
"\u001b[0m\n",
|
||||
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.1.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.2.1\u001b[0m\n",
|
||||
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython -m pip install --upgrade pip\u001b[0m\n",
|
||||
"Note: you may need to restart the kernel to use updated packages.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"%pip install datasets==2.14.4 vllm==0.1.3"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Number of recipes: 2,147,248\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from datasets import load_dataset\n",
|
||||
"\n",
|
||||
"all_recipes = load_dataset(\"corbt/all-recipes\")[\"train\"][\"input\"]\n",
|
||||
"\n",
|
||||
"print(f\"Number of recipes: {len(all_recipes):,}\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO 08-24 19:38:29 llm_engine.py:70] Initializing an LLM engine with config: model='./models/run1/merged', tokenizer='./models/run1/merged', tokenizer_mode=auto, trust_remote_code=False, dtype=torch.float16, use_dummy_weights=False, download_dir=None, use_np_weights=False, tensor_parallel_size=1, seed=0)\n",
|
||||
"INFO 08-24 19:39:48 llm_engine.py:196] # GPU blocks: 3419, # CPU blocks: 512\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from vllm import LLM, SamplingParams\n",
|
||||
"\n",
|
||||
"llm = LLM(model=\"./models/run1/merged\", max_num_batched_tokens=4096)\n",
|
||||
"\n",
|
||||
"sampling_params = SamplingParams(\n",
|
||||
" # 120 should be fine for the work we're doing here.\n",
|
||||
" max_tokens=120,\n",
|
||||
" # This is a deterministic task so temperature=0 is best.\n",
|
||||
" temperature=0,\n",
|
||||
")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Start time: 1692906050.3340027\n",
|
||||
"Processing recipes 0 to 10,000...\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Processed prompts: 100%|██████████| 10000/10000 [04:51<00:00, 34.30it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Processing recipes 10,000 to 20,000...\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Processed prompts: 100%|██████████| 10000/10000 [04:54<00:00, 33.98it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Processing recipes 20,000 to 30,000...\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Processed prompts: 100%|██████████| 10000/10000 [04:53<00:00, 34.11it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Processing recipes 30,000 to 40,000...\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Processed prompts: 100%|██████████| 10000/10000 [04:53<00:00, 34.11it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Processing recipes 40,000 to 50,000...\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Processed prompts: 48%|████▊ | 4796/10000 [02:21<03:18, 26.22it/s]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"ename": "KeyboardInterrupt",
|
||||
"evalue": "",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
|
||||
"Cell \u001b[0;32mIn[6], line 12\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[39mfor\u001b[39;00m i \u001b[39min\u001b[39;00m \u001b[39mrange\u001b[39m(\u001b[39m0\u001b[39m, \u001b[39mlen\u001b[39m(all_recipes), BATCH_SIZE):\n\u001b[1;32m 11\u001b[0m \u001b[39mprint\u001b[39m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mProcessing recipes \u001b[39m\u001b[39m{\u001b[39;00mi\u001b[39m:\u001b[39;00m\u001b[39m,\u001b[39m\u001b[39m}\u001b[39;00m\u001b[39m to \u001b[39m\u001b[39m{\u001b[39;00mi\u001b[39m+\u001b[39mBATCH_SIZE\u001b[39m:\u001b[39;00m\u001b[39m,\u001b[39m\u001b[39m}\u001b[39;00m\u001b[39m...\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[0;32m---> 12\u001b[0m outputs \u001b[39m=\u001b[39m llm\u001b[39m.\u001b[39;49mgenerate(all_recipes[i:i\u001b[39m+\u001b[39;49mBATCH_SIZE], sampling_params\u001b[39m=\u001b[39;49msampling_params)\n\u001b[1;32m 14\u001b[0m all_outputs\u001b[39m.\u001b[39mextend([o\u001b[39m.\u001b[39moutputs[\u001b[39m0\u001b[39m]\u001b[39m.\u001b[39mtext \u001b[39mfor\u001b[39;00m o \u001b[39min\u001b[39;00m outputs])\n\u001b[1;32m 16\u001b[0m end_time \u001b[39m=\u001b[39m time\u001b[39m.\u001b[39mtime()\n",
|
||||
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/vllm/entrypoints/llm.py:130\u001b[0m, in \u001b[0;36mLLM.generate\u001b[0;34m(self, prompts, sampling_params, prompt_token_ids, use_tqdm)\u001b[0m\n\u001b[1;32m 128\u001b[0m token_ids \u001b[39m=\u001b[39m prompt_token_ids[i]\n\u001b[1;32m 129\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_add_request(prompt, sampling_params, token_ids)\n\u001b[0;32m--> 130\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_run_engine(use_tqdm)\n",
|
||||
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/vllm/entrypoints/llm.py:150\u001b[0m, in \u001b[0;36mLLM._run_engine\u001b[0;34m(self, use_tqdm)\u001b[0m\n\u001b[1;32m 148\u001b[0m outputs: List[RequestOutput] \u001b[39m=\u001b[39m []\n\u001b[1;32m 149\u001b[0m \u001b[39mwhile\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mllm_engine\u001b[39m.\u001b[39mhas_unfinished_requests():\n\u001b[0;32m--> 150\u001b[0m step_outputs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mllm_engine\u001b[39m.\u001b[39;49mstep()\n\u001b[1;32m 151\u001b[0m \u001b[39mfor\u001b[39;00m output \u001b[39min\u001b[39;00m step_outputs:\n\u001b[1;32m 152\u001b[0m \u001b[39mif\u001b[39;00m output\u001b[39m.\u001b[39mfinished:\n",
|
||||
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/vllm/engine/llm_engine.py:313\u001b[0m, in \u001b[0;36mLLMEngine.step\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 307\u001b[0m \u001b[39mreturn\u001b[39;00m [\n\u001b[1;32m 308\u001b[0m RequestOutput\u001b[39m.\u001b[39mfrom_seq_group(seq_group)\n\u001b[1;32m 309\u001b[0m \u001b[39mfor\u001b[39;00m seq_group \u001b[39min\u001b[39;00m scheduler_outputs\u001b[39m.\u001b[39mignored_seq_groups\n\u001b[1;32m 310\u001b[0m ]\n\u001b[1;32m 312\u001b[0m \u001b[39m# Execute the model.\u001b[39;00m\n\u001b[0;32m--> 313\u001b[0m output \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_run_workers(\n\u001b[1;32m 314\u001b[0m \u001b[39m\"\u001b[39;49m\u001b[39mexecute_model\u001b[39;49m\u001b[39m\"\u001b[39;49m,\n\u001b[1;32m 315\u001b[0m seq_group_metadata_list\u001b[39m=\u001b[39;49mseq_group_metadata_list,\n\u001b[1;32m 316\u001b[0m blocks_to_swap_in\u001b[39m=\u001b[39;49mscheduler_outputs\u001b[39m.\u001b[39;49mblocks_to_swap_in,\n\u001b[1;32m 317\u001b[0m blocks_to_swap_out\u001b[39m=\u001b[39;49mscheduler_outputs\u001b[39m.\u001b[39;49mblocks_to_swap_out,\n\u001b[1;32m 318\u001b[0m blocks_to_copy\u001b[39m=\u001b[39;49mscheduler_outputs\u001b[39m.\u001b[39;49mblocks_to_copy,\n\u001b[1;32m 319\u001b[0m )\n\u001b[1;32m 320\u001b[0m \u001b[39m# Update the scheduler with the model outputs.\u001b[39;00m\n\u001b[1;32m 321\u001b[0m seq_groups \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mscheduler\u001b[39m.\u001b[39mupdate(output)\n",
|
||||
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/vllm/engine/llm_engine.py:470\u001b[0m, in \u001b[0;36mLLMEngine._run_workers\u001b[0;34m(self, method, get_all_outputs, *args, **kwargs)\u001b[0m\n\u001b[1;32m 467\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 468\u001b[0m executor \u001b[39m=\u001b[39m \u001b[39mgetattr\u001b[39m(worker, method)\n\u001b[0;32m--> 470\u001b[0m output \u001b[39m=\u001b[39m executor(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 471\u001b[0m all_outputs\u001b[39m.\u001b[39mappend(output)\n\u001b[1;32m 473\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mparallel_config\u001b[39m.\u001b[39mworker_use_ray:\n",
|
||||
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:115\u001b[0m, in \u001b[0;36mcontext_decorator.<locals>.decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[39m@functools\u001b[39m\u001b[39m.\u001b[39mwraps(func)\n\u001b[1;32m 113\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mdecorate_context\u001b[39m(\u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs):\n\u001b[1;32m 114\u001b[0m \u001b[39mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 115\u001b[0m \u001b[39mreturn\u001b[39;00m func(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n",
|
||||
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/vllm/worker/worker.py:293\u001b[0m, in \u001b[0;36mWorker.execute_model\u001b[0;34m(self, seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)\u001b[0m\n\u001b[1;32m 289\u001b[0m input_tokens, input_positions, input_metadata \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_prepare_inputs(\n\u001b[1;32m 290\u001b[0m seq_group_metadata_list)\n\u001b[1;32m 292\u001b[0m \u001b[39m# Execute the model.\u001b[39;00m\n\u001b[0;32m--> 293\u001b[0m output \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mmodel(\n\u001b[1;32m 294\u001b[0m input_ids\u001b[39m=\u001b[39;49minput_tokens,\n\u001b[1;32m 295\u001b[0m positions\u001b[39m=\u001b[39;49minput_positions,\n\u001b[1;32m 296\u001b[0m kv_caches\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mgpu_cache,\n\u001b[1;32m 297\u001b[0m input_metadata\u001b[39m=\u001b[39;49minput_metadata,\n\u001b[1;32m 298\u001b[0m cache_events\u001b[39m=\u001b[39;49mcache_events,\n\u001b[1;32m 299\u001b[0m )\n\u001b[1;32m 300\u001b[0m \u001b[39mreturn\u001b[39;00m output\n",
|
||||
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1502\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
|
||||
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/llama.py:255\u001b[0m, in \u001b[0;36mLlamaForCausalLM.forward\u001b[0;34m(self, input_ids, positions, kv_caches, input_metadata, cache_events)\u001b[0m\n\u001b[1;32m 245\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\n\u001b[1;32m 246\u001b[0m \u001b[39mself\u001b[39m,\n\u001b[1;32m 247\u001b[0m input_ids: torch\u001b[39m.\u001b[39mTensor,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 251\u001b[0m cache_events: Optional[List[torch\u001b[39m.\u001b[39mcuda\u001b[39m.\u001b[39mEvent]],\n\u001b[1;32m 252\u001b[0m ) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Dict[\u001b[39mint\u001b[39m, SequenceOutputs]:\n\u001b[1;32m 253\u001b[0m hidden_states \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmodel(input_ids, positions, kv_caches,\n\u001b[1;32m 254\u001b[0m input_metadata, cache_events)\n\u001b[0;32m--> 255\u001b[0m next_tokens \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49msampler(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mlm_head\u001b[39m.\u001b[39;49mweight, hidden_states,\n\u001b[1;32m 256\u001b[0m input_metadata)\n\u001b[1;32m 257\u001b[0m \u001b[39mreturn\u001b[39;00m next_tokens\n",
|
||||
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1502\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
|
||||
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/vllm/model_executor/layers/sampler.py:44\u001b[0m, in \u001b[0;36mSampler.forward\u001b[0;34m(self, embedding, hidden_states, input_metadata, embedding_bias)\u001b[0m\n\u001b[1;32m 36\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\n\u001b[1;32m 37\u001b[0m \u001b[39mself\u001b[39m,\n\u001b[1;32m 38\u001b[0m embedding: torch\u001b[39m.\u001b[39mTensor,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 42\u001b[0m ) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Dict[\u001b[39mint\u001b[39m, SequenceOutputs]:\n\u001b[1;32m 43\u001b[0m \u001b[39m# Get the hidden states that we use for sampling.\u001b[39;00m\n\u001b[0;32m---> 44\u001b[0m hidden_states \u001b[39m=\u001b[39m _prune_hidden_states(hidden_states, input_metadata)\n\u001b[1;32m 46\u001b[0m \u001b[39m# Get the logits for the next tokens.\u001b[39;00m\n\u001b[1;32m 47\u001b[0m logits \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mmatmul(hidden_states, embedding\u001b[39m.\u001b[39mt())\n",
|
||||
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# We'll process our recipes in batches of 10,000.\n",
|
||||
"\n",
|
||||
"import time\n",
|
||||
"\n",
|
||||
"BATCH_SIZE = 10000\n",
|
||||
"all_outputs = []\n",
|
||||
"\n",
|
||||
"start_time = time.time()\n",
|
||||
"print(f\"Start time: {start_time}\")\n",
|
||||
"for i in range(0, len(all_recipes), BATCH_SIZE):\n",
|
||||
" print(f\"Processing recipes {i:,} to {i+BATCH_SIZE:,}...\")\n",
|
||||
" outputs = llm.generate(\n",
|
||||
" all_recipes[i : i + BATCH_SIZE], sampling_params=sampling_params\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" all_outputs.extend([o.outputs[0].text for o in outputs])\n",
|
||||
"\n",
|
||||
"end_time = time.time()\n",
|
||||
"print(f\"End time: {end_time}\")\n",
|
||||
"print(f\"Total hours: {((end_time - start_time) / 3600):.2f}\")\n",
|
||||
"\n",
|
||||
"# Ended up running this in a separate script to leave it running in the background.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Nice! I've processed all 2,147,248 recipes in under 17 hours. Let's do a cost comparison with GPT-3.5 and GPT-4. I'll use the GPT-4 latency/cost numbers based on the 5000 samples used to generate our model's training data."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>Model</th>\n",
|
||||
" <th>Cost to Classify One Recipe</th>\n",
|
||||
" <th>Cost to Classify Entire Dataset</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>Llama 2 7B (finetuned)</td>\n",
|
||||
" <td>0.000009</td>\n",
|
||||
" <td>18.86</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>GPT-3.5</td>\n",
|
||||
" <td>0.000481</td>\n",
|
||||
" <td>1,033.26</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>GPT-3.5 (finetuned)</td>\n",
|
||||
" <td>0.004044</td>\n",
|
||||
" <td>8,683.47</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>GPT-4</td>\n",
|
||||
" <td>0.010800</td>\n",
|
||||
" <td>23,190.28</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" Model Cost to Classify One Recipe \\\n",
|
||||
"0 Llama 2 7B (finetuned) 0.000009 \n",
|
||||
"1 GPT-3.5 0.000481 \n",
|
||||
"2 GPT-3.5 (finetuned) 0.004044 \n",
|
||||
"3 GPT-4 0.010800 \n",
|
||||
"\n",
|
||||
" Cost to Classify Entire Dataset \n",
|
||||
"0 18.86 \n",
|
||||
"1 1,033.26 \n",
|
||||
"2 8,683.47 \n",
|
||||
"3 23,190.28 "
|
||||
]
|
||||
},
|
||||
"execution_count": 19,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import pandas as pd\n",
|
||||
"\n",
|
||||
"# I used an on-demand Nvidia L40 on RunPod for this, at an hourly cost of $1.14.\n",
|
||||
"finetuned_hourly_cost = 1.14\n",
|
||||
"\n",
|
||||
"finetuned_total_hours = 16.5\n",
|
||||
"\n",
|
||||
"finetuned_avg_cost = finetuned_hourly_cost * finetuned_total_hours / len(all_recipes)\n",
|
||||
"\n",
|
||||
"# The average input and output tokens for OpenAI, based on the 5000 recipes I\n",
|
||||
"# sent them when generating training data.\n",
|
||||
"avg_input_tokens = 276\n",
|
||||
"avg_output_tokens = 42\n",
|
||||
"\n",
|
||||
"# Token pricing from https://openai.com/pricing\n",
|
||||
"gpt_4_avg_cost = avg_input_tokens * 0.03 / 1000 + avg_output_tokens * 0.06 / 1000\n",
|
||||
"\n",
|
||||
"gpt_35_avg_cost = avg_input_tokens * 0.0015 / 1000 + avg_output_tokens * 0.0016 / 1000\n",
|
||||
"\n",
|
||||
"gpt_35_finetuned_avg_cost = (\n",
|
||||
" avg_input_tokens * 0.012 / 1000 + avg_output_tokens * 0.016 / 1000 + 0.06 / 1000\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"costs = pd.DataFrame(\n",
|
||||
" {\n",
|
||||
" \"Model\": [\n",
|
||||
" \"Llama 2 7B (finetuned)\",\n",
|
||||
" \"GPT-3.5\",\n",
|
||||
" \"GPT-3.5 (finetuned)\",\n",
|
||||
" \"GPT-4\",\n",
|
||||
" ],\n",
|
||||
" \"Cost to Classify One Recipe\": [\n",
|
||||
" finetuned_avg_cost,\n",
|
||||
" gpt_35_avg_cost,\n",
|
||||
" gpt_35_finetuned_avg_cost,\n",
|
||||
" gpt_4_avg_cost,\n",
|
||||
" ],\n",
|
||||
" }\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"costs[\"Cost to Classify Entire Dataset\"] = (\n",
|
||||
" costs[\"Cost to Classify One Recipe\"] * len(all_recipes)\n",
|
||||
").map(lambda x: f\"{x:,.2f}\")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"costs\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.6"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
@@ -1,663 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"I have a model in `./models/run1/merged` that was trained on GPT-4's outputs to classify recipes. I need to figure out whether it does a good job at classifying recipes. I'll install dependencies first."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Requirement already satisfied: vllm==0.1.3 in /usr/local/lib/python3.10/dist-packages (0.1.3)\n",
|
||||
"Requirement already satisfied: pandas==2.0.3 in /usr/local/lib/python3.10/dist-packages (2.0.3)\n",
|
||||
"Requirement already satisfied: ninja in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (1.11.1)\n",
|
||||
"Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (5.9.5)\n",
|
||||
"Requirement already satisfied: ray>=2.5.1 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (2.6.3)\n",
|
||||
"Requirement already satisfied: sentencepiece in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.1.99)\n",
|
||||
"Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (1.24.4)\n",
|
||||
"Requirement already satisfied: torch>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (2.0.1+cu118)\n",
|
||||
"Requirement already satisfied: transformers>=4.31.0 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (4.33.0.dev0)\n",
|
||||
"Requirement already satisfied: xformers>=0.0.19 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.0.21)\n",
|
||||
"Requirement already satisfied: fastapi in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.101.1)\n",
|
||||
"Requirement already satisfied: uvicorn in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.23.2)\n",
|
||||
"Requirement already satisfied: pydantic<2 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (1.10.12)\n",
|
||||
"Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas==2.0.3) (2.8.2)\n",
|
||||
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas==2.0.3) (2023.3)\n",
|
||||
"Requirement already satisfied: tzdata>=2022.1 in /usr/local/lib/python3.10/dist-packages (from pandas==2.0.3) (2023.3)\n",
|
||||
"Requirement already satisfied: typing-extensions>=4.2.0 in /usr/local/lib/python3.10/dist-packages (from pydantic<2->vllm==0.1.3) (4.7.1)\n",
|
||||
"Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil>=2.8.2->pandas==2.0.3) (1.16.0)\n",
|
||||
"Requirement already satisfied: click>=7.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (8.1.7)\n",
|
||||
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (3.9.0)\n",
|
||||
"Requirement already satisfied: jsonschema in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (4.18.0)\n",
|
||||
"Requirement already satisfied: msgpack<2.0.0,>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (1.0.5)\n",
|
||||
"Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (23.1)\n",
|
||||
"Requirement already satisfied: protobuf!=3.19.5,>=3.15.3 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (4.24.1)\n",
|
||||
"Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (6.0)\n",
|
||||
"Requirement already satisfied: aiosignal in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (1.3.1)\n",
|
||||
"Requirement already satisfied: frozenlist in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (1.4.0)\n",
|
||||
"Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (2.28.1)\n",
|
||||
"Requirement already satisfied: grpcio>=1.42.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (1.57.0)\n",
|
||||
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (1.11.1)\n",
|
||||
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (3.0)\n",
|
||||
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (3.1.2)\n",
|
||||
"Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (2.0.0)\n",
|
||||
"Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=2.0.0->vllm==0.1.3) (3.25.0)\n",
|
||||
"Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=2.0.0->vllm==0.1.3) (15.0.7)\n",
|
||||
"Requirement already satisfied: huggingface-hub<1.0,>=0.15.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (0.16.4)\n",
|
||||
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (2023.8.8)\n",
|
||||
"Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (0.13.3)\n",
|
||||
"Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (0.3.2)\n",
|
||||
"Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (4.66.1)\n",
|
||||
"Requirement already satisfied: starlette<0.28.0,>=0.27.0 in /usr/local/lib/python3.10/dist-packages (from fastapi->vllm==0.1.3) (0.27.0)\n",
|
||||
"Requirement already satisfied: h11>=0.8 in /usr/local/lib/python3.10/dist-packages (from uvicorn->vllm==0.1.3) (0.14.0)\n",
|
||||
"Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.15.1->transformers>=4.31.0->vllm==0.1.3) (2023.6.0)\n",
|
||||
"Requirement already satisfied: anyio<5,>=3.4.0 in /usr/local/lib/python3.10/dist-packages (from starlette<0.28.0,>=0.27.0->fastapi->vllm==0.1.3) (3.7.1)\n",
|
||||
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=2.0.0->vllm==0.1.3) (2.1.2)\n",
|
||||
"Requirement already satisfied: attrs>=22.2.0 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (23.1.0)\n",
|
||||
"Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (2023.6.1)\n",
|
||||
"Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (0.29.1)\n",
|
||||
"Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (0.8.10)\n",
|
||||
"Requirement already satisfied: charset-normalizer<3,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->ray>=2.5.1->vllm==0.1.3) (2.1.1)\n",
|
||||
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->ray>=2.5.1->vllm==0.1.3) (3.4)\n",
|
||||
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->ray>=2.5.1->vllm==0.1.3) (1.26.13)\n",
|
||||
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->ray>=2.5.1->vllm==0.1.3) (2022.12.7)\n",
|
||||
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=2.0.0->vllm==0.1.3) (1.2.1)\n",
|
||||
"Requirement already satisfied: sniffio>=1.1 in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.4.0->starlette<0.28.0,>=0.27.0->fastapi->vllm==0.1.3) (1.3.0)\n",
|
||||
"Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.4.0->starlette<0.28.0,>=0.27.0->fastapi->vllm==0.1.3) (1.1.2)\n",
|
||||
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
|
||||
"\u001b[0m\n",
|
||||
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.1.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.2.1\u001b[0m\n",
|
||||
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython3.10 -m pip install --upgrade pip\u001b[0m\n",
|
||||
"Note: you may need to restart the kernel to use updated packages.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"%pip install vllm==0.1.3 pandas==2.0.3"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Remember I got a \"test.jsonl\" file from OpenPipe back in [./prepare.ipynb](./prepare.ipynb)? That's data from our dataset that we didn't use in training, so we can use it to check our model's performance."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import pandas as pd\n",
|
||||
"\n",
|
||||
"test_data = pd.read_json(\"./data/test.jsonl\", lines=True)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"During the training process Axolotl transformed our data into an instruction/response format known as the \"Alpaca format\" based on [the project that introduced it](https://github.com/tatsu-lab/stanford_alpaca). I need to transform my test data into the same format for best results."
|
||||
]
|
||||
},
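For reference, the prompt this produces follows the Alpaca template shown in the sample output below. Here is a minimal sketch of building it by hand, in case you'd rather not depend on axolotl; it simply mirrors the observed output rather than axolotl's internals:

```python
# Alpaca-style prompt wrapper, mirroring the sample printed by UnpromptedPrompter below.
def format_prompt_manually(instruction: str) -> str:
    return f"### Instruction:\n{instruction}\n\n### Response:\n"
```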
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Sample prompt:\n",
|
||||
"--------------\n",
|
||||
"### Instruction:\n",
|
||||
"[{\"role\":\"system\",\"content\":\"Your goal is to classify a recipe along several dimensions.Pay attention to the instructions.\"},{\"role\":\"user\",\"content\":\"Pan Gravy\\n\\nIngredients:\\n- 1/3 cup all purpose flour\\n- 1/3 cup turkey drippings\\n- 3 cup water or broth\\n- 1/8 to 1/4 teaspoon salt\\n- 1/8 tsp pepper\\n\\nDirections:\\n- In a skillet or roasting pan, add flour to drippings; blend well.\\n- Cook over medium heat 2 to 3 minutes until smooth and light brown, stirring constantly.\\n- Add water; cook until mixture boils and thickens, stirring constantly.\\n- Stir in salt and pepper.\\n- *Flour and drippings can be decreased to 1/4 cup each for thinner gravy.\\n- *\"}]\n",
|
||||
"\n",
|
||||
"### Response:\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from axolotl.prompters import UnpromptedPrompter\n",
|
||||
"\n",
|
||||
"prompter = UnpromptedPrompter()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def format_prompt(input: str) -> str:\n",
|
||||
" return next(prompter.build_prompt(input))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"prompts = test_data[\"instruction\"].apply(format_prompt)\n",
|
||||
"\n",
|
||||
"print(f\"Sample prompt:\\n--------------\\n{prompts[0]}\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Next up, I'll use [vLLM](https://vllm.readthedocs.io/en/latest/) to efficiently process all the prompts in our test data with our own model."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO 08-25 03:58:49 llm_engine.py:70] Initializing an LLM engine with config: model='./models/run1/merged', tokenizer='./models/run1/merged', tokenizer_mode=auto, trust_remote_code=False, dtype=torch.float16, use_dummy_weights=False, download_dir=None, use_np_weights=False, tensor_parallel_size=1, seed=0)\n",
|
||||
"INFO 08-25 03:59:40 llm_engine.py:196] # GPU blocks: 3419, # CPU blocks: 512\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Processed prompts: 100%|██████████| 500/500 [00:37<00:00, 13.42it/s]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Sample output:\n",
|
||||
"--------------\n",
|
||||
"{\"role\":\"assistant\",\"content\":null,\"function_call\":{\"name\":\"classify\",\"arguments\":\"{\\n\\\"has_non_fish_meat\\\": true,\\n\\\"requires_oven\\\": false,\\n\\\"requires_stove\\\": true,\\n\\\"cook_time_over_30_mins\\\": false,\\n\\\"main_dish\\\": false\\n}\"}}\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from vllm import LLM, SamplingParams\n",
|
||||
"\n",
|
||||
"llm = LLM(model=\"./models/run1/merged\", max_num_batched_tokens=4096)\n",
|
||||
"\n",
|
||||
"sampling_params = SamplingParams(\n",
|
||||
" # 120 should be fine for the work we're doing here.\n",
|
||||
" max_tokens=120,\n",
|
||||
" # This is a deterministic task so temperature=0 is best.\n",
|
||||
" temperature=0,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"my_outputs = llm.generate(prompts, sampling_params=sampling_params)\n",
|
||||
"my_outputs = [o.outputs[0].text for o in my_outputs]\n",
|
||||
"\n",
|
||||
"test_data[\"my_outputs\"] = my_outputs\n",
|
||||
"\n",
|
||||
"print(f\"Sample output:\\n--------------\\n{my_outputs[0]}\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Ok, we have our outputs! There are 5 categories we classify each recipe on, so let's check what percentage of the time our model's output matches GPT-4's. I'll write a quick eval function for that:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Overall accuracy: 0.95\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import json\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def parse_fn_call(str):\n",
|
||||
" \"\"\"Parse the function call arguments from the response\"\"\"\n",
|
||||
" response_dict = json.loads(str)\n",
|
||||
" args_dict = json.loads(response_dict[\"function_call\"][\"arguments\"])\n",
|
||||
"\n",
|
||||
" return args_dict\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def calculate_accuracy(row):\n",
|
||||
" \"\"\"Calculate the fraction of my model's outputs that match the reference outputs\"\"\"\n",
|
||||
" true_outputs = parse_fn_call(row[\"output\"])\n",
|
||||
" my_outputs = parse_fn_call(row[\"my_outputs\"])\n",
|
||||
"\n",
|
||||
" num_matching_outputs = 0\n",
|
||||
" for key in true_outputs.keys():\n",
|
||||
" if key in my_outputs and true_outputs[key] == my_outputs[key]:\n",
|
||||
" num_matching_outputs += 1\n",
|
||||
"\n",
|
||||
" return num_matching_outputs / len(true_outputs)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"test_data[\"accuracy\"] = test_data.apply(calculate_accuracy, axis=1)\n",
|
||||
"\n",
|
||||
"print(f\"Overall accuracy: {test_data['accuracy'].mean():.2f}\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Not bad! However, there are still a few rows where the model outputs don't match. Let's take a closer look."
|
||||
]
|
||||
},
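Below is one way to surface those rows, sketched here rather than taken from the original cell: filter to imperfect rows, print the recipe from the user message, and show only the labels where GPT-4 and my model disagree. It assumes the `test_data` columns and the `parse_fn_call` helper defined above.

```python
# Inspect the first few disagreements between GPT-4's labels and my model's.
mismatches = test_data[test_data["accuracy"] < 1]

for _, row in mismatches.head(5).iterrows():
    # The instruction column holds the JSON-encoded message list; the recipe is the user message.
    recipe = json.loads(row["instruction"])[1]["content"]
    print(recipe)

    true_outputs = parse_fn_call(row["output"])
    my_outputs = parse_fn_call(row["my_outputs"])
    diff_keys = [k for k in true_outputs if my_outputs.get(k) != true_outputs[k]]

    display(
        pd.DataFrame(
            {
                "GPT-4": [true_outputs[k] for k in diff_keys],
                "My model": [my_outputs.get(k) for k in diff_keys],
            },
            index=diff_keys,
        )
    )
```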
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Alligator Sauce Piquant\n",
|
||||
"\n",
|
||||
"Ingredients:\n",
|
||||
"- 2 lb. alligator, boneless and cubed *\n",
|
||||
"- 4 onions, diced\n",
|
||||
"- 1 c. parsley, chopped\n",
|
||||
"- 4 stalks celery, chopped\n",
|
||||
"- 1 bell pepper, diced\n",
|
||||
"- 1 c. catsup\n",
|
||||
"- 2 Tbsp. Heinz steak sauce\n",
|
||||
"- 2 Tbsp. soy sauce\n",
|
||||
"- 2 Tbsp. Louisiana hot sauce\n",
|
||||
"- 2 Tbsp. cornstarch\n",
|
||||
"- 1 tsp. salt\n",
|
||||
"- 2 tsp. red pepper (ground)\n",
|
||||
"- 1/4 c. cooking oil\n",
|
||||
"\n",
|
||||
"Directions:\n",
|
||||
"- *Alligator must be free of all fat; also dark meat is the best (leg and body meat), boneless.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>GPT-4</th>\n",
|
||||
" <th>My model</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>cook_time_over_30_mins</th>\n",
|
||||
" <td>True</td>\n",
|
||||
" <td>False</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>main_dish</th>\n",
|
||||
" <td>True</td>\n",
|
||||
" <td>False</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" GPT-4 My model\n",
|
||||
"cook_time_over_30_mins True False\n",
|
||||
"main_dish True False"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Veggie Casserole\n",
|
||||
"\n",
|
||||
"Ingredients:\n",
|
||||
"- 1 (8 oz.) bag mixed veggies (corn, peas, carrots, green beans), steamed\n",
|
||||
"- 1 c. celery\n",
|
||||
"- 1 c. onions\n",
|
||||
"- 1 c. Cheddar cheese\n",
|
||||
"- 1 c. mayonnaise\n",
|
||||
"\n",
|
||||
"Directions:\n",
|
||||
"- Mix above ingredients.\n",
|
||||
"- Bake at 350° for 30 minutes, until bubbly.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>GPT-4</th>\n",
|
||||
" <th>My model</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>main_dish</th>\n",
|
||||
" <td>False</td>\n",
|
||||
" <td>True</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" GPT-4 My model\n",
|
||||
"main_dish False True"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Rhonda'S Butter Chess Pie\n",
|
||||
"\n",
|
||||
"Ingredients:\n",
|
||||
"- 5 eggs\n",
|
||||
"- 1 stick melted butter\n",
|
||||
"- 2 c. sugar\n",
|
||||
"- 1 tsp. vanilla\n",
|
||||
"- 1 Tbsp. cornstarch\n",
|
||||
"- 1/2 c. buttermilk\n",
|
||||
"- unbaked 9-inch deep dish pie shell\n",
|
||||
"\n",
|
||||
"Directions:\n",
|
||||
"- Mix eggs with sugar and cornstarch until smooth.\n",
|
||||
"- Add melted butter, vanilla and buttermilk.\n",
|
||||
"- Bake at 350° for 30 minutes or until done.\n",
|
||||
"- Let cool and chill.\n",
|
||||
"- Similar to Furr's Butter Chess Pie.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>GPT-4</th>\n",
|
||||
" <th>My model</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>cook_time_over_30_mins</th>\n",
|
||||
" <td>False</td>\n",
|
||||
" <td>True</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" GPT-4 My model\n",
|
||||
"cook_time_over_30_mins False True"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Broccoli Gorgonzola Cream Soup\n",
|
||||
"\n",
|
||||
"Ingredients:\n",
|
||||
"- 2 heads Broccoli\n",
|
||||
"- 700 milliliters Water\n",
|
||||
"- 1 Onion, Peeled And Cut Into Chunks\n",
|
||||
"- 1 pinch Salt\n",
|
||||
"- 1 teaspoon Oregano\n",
|
||||
"- 1 Potato, Peeled And Cut Into Chunks\n",
|
||||
"- 200 grams Crumbled Gorgonzola\n",
|
||||
"- 1 Tablespoon Finely Grated Parmesan\n",
|
||||
"\n",
|
||||
"Directions:\n",
|
||||
"- Cut off the hard trunks of the broccoli and cut it into small pieces. Prepare a pot with water, add broccoli, onion, salt and oregano and boil for about 30 minutes.\n",
|
||||
"- Add the peeled potato and boil for another 20 minutes. When vegetables are cooked, strain and save the stock.\n",
|
||||
"- Using a hand blender, puree vegetables, adding as much stock as desired. Bring soup back to heat over low heat, and sir in gorgonzola. Remove from heat and add Parmesan.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>GPT-4</th>\n",
|
||||
" <th>My model</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>main_dish</th>\n",
|
||||
" <td>False</td>\n",
|
||||
" <td>True</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" GPT-4 My model\n",
|
||||
"main_dish False True"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Wild Rice With Cucumber And Feta\n",
|
||||
"\n",
|
||||
"Ingredients:\n",
|
||||
"- 1 (8.5-ounce) package precooked wild rice (such as Archer Farms)\n",
|
||||
"- 1 cup diced English cucumber\n",
|
||||
"- 1 1/2 tablespoons olive oil\n",
|
||||
"- 1 tablespoon fresh lemon juice\n",
|
||||
"- 2 ounces crumbled feta cheese\n",
|
||||
"- 1/2 teaspoon pepper\n",
|
||||
"- 1/4 teaspoon salt\n",
|
||||
"\n",
|
||||
"Directions:\n",
|
||||
"- Prepare rice according to the package directions.\n",
|
||||
"- Combine cooked rice, cucumber, olive oil, lemon juice, and crumbled feta cheese in a medium bowl; toss to coat. Stir in pepper and salt.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>GPT-4</th>\n",
|
||||
" <th>My model</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>main_dish</th>\n",
|
||||
" <td>True</td>\n",
|
||||
" <td>False</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" GPT-4 My model\n",
|
||||
"main_dish True False"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"np.random.seed(42)\n",
|
||||
"\n",
|
||||
"for row in test_data[test_data.accuracy < 1].sample(5).itertuples():\n",
|
||||
" print(json.loads(row.instruction)[1][\"content\"])\n",
|
||||
"\n",
|
||||
" gpt4_output = parse_fn_call(row.output)\n",
|
||||
" my_output = parse_fn_call(row.my_outputs)\n",
|
||||
"\n",
|
||||
" table = pd.DataFrame(\n",
|
||||
" {\n",
|
||||
" \"GPT-4\": gpt4_output,\n",
|
||||
" \"My model\": my_output,\n",
|
||||
" }\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" table = table[table[\"GPT-4\"] != table[\"My model\"]]\n",
|
||||
" display(table)\n"
|
||||
]
|
||||
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Looking at the outputs, it's clear that our model still makes some mistakes. But at the same time, there are plenty of examples like \"Rhonda's Butter Chess Pie\" where our model gets it right, even though GPT-4 got it wrong! And there are also cases like the \"Veggie Casserole\", where the \"right\" answer is genuinely ambiguous and both answers are defensible.\n",
"\n",
"Interested in cost/latency benchmarking? You can check out [./benchmarking.ipynb](./benchmarking.ipynb) for an overview of my findings!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
@@ -1,353 +0,0 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this notebook I'm using the OpenPipe client to capture a set of calls to the OpenAI API.\n",
"\n",
"For this example I'll blithely throw engineering best practices to the wind and use the notebook itself to manage dependencies. 😁"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Requirement already satisfied: openpipe==3.0.3 in /usr/local/lib/python3.10/dist-packages (3.0.3)\n",
|
||||
"Requirement already satisfied: python-dotenv==1.0.0 in /usr/local/lib/python3.10/dist-packages (1.0.0)\n",
|
||||
"Requirement already satisfied: joblib==1.3.2 in /usr/local/lib/python3.10/dist-packages (1.3.2)\n",
|
||||
"Requirement already satisfied: attrs<24.0.0,>=23.1.0 in /usr/local/lib/python3.10/dist-packages (from openpipe==3.0.3) (23.1.0)\n",
|
||||
"Requirement already satisfied: httpx<0.25.0,>=0.24.1 in /usr/local/lib/python3.10/dist-packages (from openpipe==3.0.3) (0.24.1)\n",
|
||||
"Requirement already satisfied: openai<0.28.0,>=0.27.8 in /usr/local/lib/python3.10/dist-packages (from openpipe==3.0.3) (0.27.9)\n",
|
||||
"Requirement already satisfied: python-dateutil<3.0.0,>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from openpipe==3.0.3) (2.8.2)\n",
|
||||
"Requirement already satisfied: toml<0.11.0,>=0.10.2 in /usr/local/lib/python3.10/dist-packages (from openpipe==3.0.3) (0.10.2)\n",
|
||||
"Requirement already satisfied: certifi in /usr/local/lib/python3.10/dist-packages (from httpx<0.25.0,>=0.24.1->openpipe==3.0.3) (2022.12.7)\n",
|
||||
"Requirement already satisfied: httpcore<0.18.0,>=0.15.0 in /usr/local/lib/python3.10/dist-packages (from httpx<0.25.0,>=0.24.1->openpipe==3.0.3) (0.17.3)\n",
|
||||
"Requirement already satisfied: idna in /usr/local/lib/python3.10/dist-packages (from httpx<0.25.0,>=0.24.1->openpipe==3.0.3) (3.4)\n",
|
||||
"Requirement already satisfied: sniffio in /usr/local/lib/python3.10/dist-packages (from httpx<0.25.0,>=0.24.1->openpipe==3.0.3) (1.3.0)\n",
|
||||
"Requirement already satisfied: requests>=2.20 in /usr/local/lib/python3.10/dist-packages (from openai<0.28.0,>=0.27.8->openpipe==3.0.3) (2.28.1)\n",
|
||||
"Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from openai<0.28.0,>=0.27.8->openpipe==3.0.3) (4.66.1)\n",
|
||||
"Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from openai<0.28.0,>=0.27.8->openpipe==3.0.3) (3.8.5)\n",
|
||||
"Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil<3.0.0,>=2.8.2->openpipe==3.0.3) (1.16.0)\n",
|
||||
"Requirement already satisfied: h11<0.15,>=0.13 in /usr/local/lib/python3.10/dist-packages (from httpcore<0.18.0,>=0.15.0->httpx<0.25.0,>=0.24.1->openpipe==3.0.3) (0.14.0)\n",
|
||||
"Requirement already satisfied: anyio<5.0,>=3.0 in /usr/local/lib/python3.10/dist-packages (from httpcore<0.18.0,>=0.15.0->httpx<0.25.0,>=0.24.1->openpipe==3.0.3) (3.7.1)\n",
|
||||
"Requirement already satisfied: charset-normalizer<3,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.20->openai<0.28.0,>=0.27.8->openpipe==3.0.3) (2.1.1)\n",
|
||||
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.20->openai<0.28.0,>=0.27.8->openpipe==3.0.3) (1.26.13)\n",
|
||||
"Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->openai<0.28.0,>=0.27.8->openpipe==3.0.3) (6.0.4)\n",
|
||||
"Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.10/dist-packages (from aiohttp->openai<0.28.0,>=0.27.8->openpipe==3.0.3) (4.0.3)\n",
|
||||
"Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->openai<0.28.0,>=0.27.8->openpipe==3.0.3) (1.9.2)\n",
|
||||
"Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->openai<0.28.0,>=0.27.8->openpipe==3.0.3) (1.4.0)\n",
|
||||
"Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->openai<0.28.0,>=0.27.8->openpipe==3.0.3) (1.3.1)\n",
|
||||
"Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from anyio<5.0,>=3.0->httpcore<0.18.0,>=0.15.0->httpx<0.25.0,>=0.24.1->openpipe==3.0.3) (1.1.2)\n",
|
||||
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
|
||||
"\u001b[0m\n",
|
||||
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.1.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.2.1\u001b[0m\n",
|
||||
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython3.10 -m pip install --upgrade pip\u001b[0m\n",
|
||||
"Note: you may need to restart the kernel to use updated packages.\n"
|
||||
]
|
||||
}
],
"source": [
"%pip install openpipe==3.0.3 python-dotenv==1.0.0 joblib==1.3.2 datasets==2.14.4"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"When working with remote datasets (or any data, really), it's a good idea to visually inspect some samples to make sure they look the way you expect. I'll print a recipe."
]
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Recipe dataset shape:\n",
|
||||
"------------------\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"Dataset({\n",
|
||||
" features: ['recipe'],\n",
|
||||
" num_rows: 5000\n",
|
||||
"})"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"First recipe:\n",
|
||||
"------------------ Shrimp Creole\n",
|
||||
"\n",
|
||||
"Ingredients:\n",
|
||||
"- 20 shrimp (8 oz.)\n",
|
||||
"- 2 c. (16 oz. can) tomato sauce\n",
|
||||
"- 1 small onion, chopped\n",
|
||||
"- 1 celery stalk, chopped\n",
|
||||
"- 1/4 green bell pepper, diced\n",
|
||||
"- 1/4 c. sliced mushrooms\n",
|
||||
"- 3 Tbsp. parsley\n",
|
||||
"- 1/2 tsp. pepper\n",
|
||||
"- 1 to 1-1/2 c. brown rice, prepared according to pkg. directions (not included in exchanges)\n",
|
||||
"\n",
|
||||
"Directions:\n",
|
||||
"- Peel, devein and wash shrimp; set aside.\n",
|
||||
"- (If shrimp are frozen, let thaw first in the refrigerator.)\n",
|
||||
"- Simmer tomato sauce, onion, celery, green pepper, mushrooms, parsley and pepper in skillet for 30 minutes.\n",
|
||||
"- Add shrimp and cook 10 to 15 minutes more, until shrimp are tender.\n",
|
||||
"- Serve over brown rice.\n",
|
||||
"- Serves 2.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from datasets import load_dataset\n",
|
||||
"\n",
|
||||
"recipes = load_dataset(\"corbt/unlabeled-recipes\")[\"train\"]\n",
|
||||
"print(\"Recipe dataset shape:\\n------------------\")\n",
|
||||
"display(recipes)\n",
|
||||
"print(\"First recipe:\\n------------------\", recipes[\"recipe\"][0])\n"
|
||||
]
|
||||
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Mm, delicious. Anyway, we need to generate a training dataset. We'll call GPT-4 on each of our examples.\n",
"\n",
"In this case, I'll ask GPT-4 to classify each recipe along 5 dimensions:\n",
" - has_non_fish_meat\n",
" - requires_oven\n",
" - requires_stove\n",
" - cook_time_over_30_mins\n",
" - main_dish\n",
"\n",
"That looks like a pretty random list, but there's actually an important unifying thread: I'm looking for meals that my pescatarian brother/co-founder can make in his kitchen-less, near-window-less basement apartment in San Francisco! (If you haven't tried to get an apartment in SF you probably think I'm joking 😂.)\n",
"\n",
"I'll use [OpenPipe](https://github.com/openpipe/openpipe) to track the API calls and form a training dataset. To follow along you'll need to create a free OpenPipe account, then copy your API key from https://app.openpipe.ai/project/settings into a file called `.env`. You can see an example in [./.env.example](./.env.example)."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Classifying first recipe:\n",
"------------------\n",
"{'has_non_fish_meat': False, 'requires_oven': False, 'requires_stove': True, 'cook_time_over_30_mins': True, 'main_dish': True}\n"
]
}
],
"source": [
"from openpipe import openai, configure_openpipe\n",
"import json\n",
"import os\n",
"import dotenv\n",
"\n",
"# Use `dotenv` to load the contents of the `.env` file into the environment\n",
"dotenv.load_dotenv()\n",
"\n",
"# Configure OpenPipe using the API key from the environment\n",
"configure_openpipe(api_key=os.environ[\"OPENPIPE_API_KEY\"])\n",
"\n",
"# Configure OpenAI using the API key from the environment\n",
"openai.api_key = os.environ[\"OPENAI_API_KEY\"]\n",
"\n",
"\n",
"def classify_recipe(recipe: str):\n",
"    completion = openai.ChatCompletion.create(\n",
"        model=\"gpt-4\",\n",
"        messages=[\n",
"            {\n",
"                \"role\": \"system\",\n",
"                \"content\": \"Your goal is to classify a recipe along several dimensions.Pay attention to the instructions.\",\n",
"            },\n",
"            {\n",
"                \"role\": \"user\",\n",
"                \"content\": recipe,\n",
"            },\n",
"        ],\n",
"        functions=[\n",
"            {\n",
"                \"name\": \"classify\",\n",
"                \"parameters\": {\n",
"                    \"type\": \"object\",\n",
"                    \"properties\": {\n",
"                        \"has_non_fish_meat\": {\n",
"                            \"type\": \"boolean\",\n",
"                            \"description\": \"True if the recipe contains any meat or meat products (eg. chicken broth) besides fish\",\n",
"                        },\n",
"                        \"requires_oven\": {\n",
"                            \"type\": \"boolean\",\n",
"                            \"description\": \"True if the recipe requires an oven\",\n",
"                        },\n",
"                        \"requires_stove\": {\n",
"                            \"type\": \"boolean\",\n",
"                            \"description\": \"True if the recipe requires a stove\",\n",
"                        },\n",
"                        \"cook_time_over_30_mins\": {\n",
"                            \"type\": \"boolean\",\n",
"                            \"description\": \"True if the recipe takes over 30 minutes to prepare and cook, including waiting time\",\n",
"                        },\n",
"                        \"main_dish\": {\n",
"                            \"type\": \"boolean\",\n",
"                            \"description\": \"True if the recipe can be served as a main dish\",\n",
"                        },\n",
"                    },\n",
"                    \"required\": [\n",
"                        \"has_non_fish_meat\",\n",
"                        \"requires_oven\",\n",
"                        \"requires_stove\",\n",
"                        \"cook_time_over_30_mins\",\n",
"                        \"main_dish\",\n",
"                    ],\n",
"                },\n",
"            }\n",
"        ],\n",
"        function_call={\n",
"            \"name\": \"classify\",\n",
"        },\n",
"        openpipe={\"tags\": {\"prompt_id\": \"classify-recipe\"}, \"cache\": True},\n",
"    )\n",
"    return json.loads(completion.choices[0].message.function_call.arguments)\n",
"\n",
"\n",
"print(\"Classifying first recipe:\\n------------------\")\n",
"print(classify_recipe(recipes[\"recipe\"][0]))\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"That's working, so I'll go ahead and classify all 5000 recipes with GPT-4. Using GPT-4 for this is slowwww and costs about $40. The model I'm fine-tuning will be much faster -- we'll see if we can make it as good!"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Classifying recipe 0/5000: Shrimp Creole\n",
|
||||
"Classifying recipe 100/5000: Spoon Bread\n",
|
||||
"Classifying recipe 200/5000: Quadrangle Grille'S Pumpkin-Walnut Cheesecake\n",
|
||||
"Classifying recipe 300/5000: Broccoli Casserole\n",
|
||||
"Classifying recipe 400/5000: Paal Payasam (3-Ingredient Rice Pudding)\n",
|
||||
"Classifying recipe 500/5000: Dirt Dessert\n",
|
||||
"Classifying recipe 600/5000: Dolma, Stuffed Dried Peppers And Eggplants\n",
|
||||
"Classifying recipe 700/5000: Party Pecan Pies\n",
|
||||
"Classifying recipe 800/5000: Pie Crust\n",
|
||||
"Classifying recipe 900/5000: Russian Dressing(Salad Dressing) \n",
|
||||
"Classifying recipe 1000/5000: O'Brien Potatoes\n",
|
||||
"Classifying recipe 1100/5000: Monster Cookies\n",
|
||||
"Classifying recipe 1200/5000: Striped Fruit Pops\n",
|
||||
"Classifying recipe 1300/5000: Cute Heart-Shaped Fried Egg\n",
|
||||
"Classifying recipe 1400/5000: Steak Marinade\n",
|
||||
"Classifying recipe 1500/5000: Bbq Sauce For Fish Recipe\n",
|
||||
"Classifying recipe 1600/5000: Barbecue Ranch Salad\n",
|
||||
"Classifying recipe 1700/5000: White Fudge\n",
|
||||
"Classifying recipe 1800/5000: Seaton Chocolate Chip Cookies\n",
|
||||
"Classifying recipe 1900/5000: Beef Stroganoff\n",
|
||||
"Classifying recipe 2000/5000: Lemon Delight\n",
|
||||
"Classifying recipe 2100/5000: Cream Cheese Chicken Chili\n",
|
||||
"Classifying recipe 2200/5000: Bean Salad\n",
|
||||
"Classifying recipe 2300/5000: Green Beans Almondine\n",
|
||||
"Classifying recipe 2400/5000: Radish-And-Avocado Salad\n",
|
||||
"Classifying recipe 2500/5000: Salsa Rojo\n",
|
||||
"Classifying recipe 2600/5000: Pepperoni Bread\n",
|
||||
"Classifying recipe 2700/5000: Sabzi Polow\n",
|
||||
"Classifying recipe 2800/5000: Italian Vegetable Pizzas\n",
|
||||
"Classifying recipe 2900/5000: Hot Fudge Sauce, Soda Shop Style\n",
|
||||
"Classifying recipe 3000/5000: Meatball Soup With Vegetables And Brown Rice\n",
|
||||
"Classifying recipe 3100/5000: Herbed Potatoes And Onions\n",
|
||||
"Classifying recipe 3200/5000: Apple Crunch Pie (2 Extra Servings)\n",
|
||||
"Classifying recipe 3300/5000: Pineapple-Orange Punch\n",
|
||||
"Classifying recipe 3400/5000: Turkey Veggie Burgers With Avocado Mayo\n",
|
||||
"Classifying recipe 3500/5000: Pear & Goat Cheese Salad\n",
|
||||
"Classifying recipe 3600/5000: Triple Chocolate Cookies\n",
|
||||
"Classifying recipe 3700/5000: Strawberry Banana Yogurt Pops\n",
|
||||
"Classifying recipe 3800/5000: Chicken Croquettes\n",
|
||||
"Classifying recipe 3900/5000: Mushroom Casserole\n",
|
||||
"Classifying recipe 4000/5000: Vegetarian Summer Roll\n",
|
||||
"Classifying recipe 4100/5000: Prune Cake\n",
|
||||
"Classifying recipe 4200/5000: Strawberry Sorbet\n",
|
||||
"Classifying recipe 4300/5000: Lemonade Chicken\n",
|
||||
"Classifying recipe 4400/5000: Crock-Pot Vegetarian Chili\n",
|
||||
"Classifying recipe 4500/5000: Grandma Dickrell'S Molasses Cake - 1936\n",
|
||||
"Classifying recipe 4600/5000: Creamed Corn Casserole\n",
|
||||
"Classifying recipe 4700/5000: Homemade Croutons\n",
|
||||
"Classifying recipe 4800/5000: Potatoes With Leeks And Gruyere\n",
|
||||
"Classifying recipe 4900/5000: Chocolate Oatmeal Cookie\n"
|
||||
]
|
||||
}
],
"source": [
"for i, recipe in enumerate(recipes[\"recipe\"]):\n",
"    if i % 100 == 0:\n",
"        recipe_title = recipe.split(\"\\n\")[0]\n",
"        print(f\"Classifying recipe {i}/{len(recipes)}: {recipe_title}\")\n",
"    try:\n",
"        classify_recipe(recipe)\n",
"    except Exception as e:\n",
"        print(f\"Error classifying recipe {i}: {e}\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Ok, now that my recipes are classified I'll download the training data. Just go to https://app.openpipe.ai/request-logs, select all the logs you created, and click \"Export\". The default 10% testing split is fine for this dataset size.\n",
"\n",
"I got two files from that: `train.jsonl` and `test.jsonl`. I moved both of them into this repository under `./data/`.\n",
"\n",
"Next up I'll train the model -- check out [./train.ipynb](./train.ipynb) for details!"
]
}
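For reference, the evaluation notebook earlier in this diff reads that exported test split into a `test_data` frame with `instruction` and `output` columns. Loading it might look roughly like this; the exact export schema isn't shown here, so treat the field names as an assumption based on how the evaluation cells index them:

```python
# A sketch, not part of the original notebooks: load the exported test split.
# Assumes each JSONL record has an "instruction" field (a JSON-encoded message list)
# and an "output" field (the serialized assistant reply), matching how the
# evaluation cells access row["instruction"] and row["output"].
import json
import pandas as pd

test_data = pd.read_json("./data/test.jsonl", lines=True)
print(json.loads(test_data["instruction"][0])[1]["content"])  # the recipe text
```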
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.6"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
@@ -1,73 +0,0 @@
# This file is used by the training script in train.ipynb. You can read more about
# the format and see more examples at https://github.com/OpenAccess-AI-Collective/axolotl.
# One of the parameters you might want to play around with is `num_epochs`: if you have a
# smaller dataset, increasing it can have good results.

base_model: meta-llama/Llama-2-7b-hf
base_model_config: meta-llama/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true

load_in_8bit: true
load_in_4bit: false
strict: false

datasets:
  - path: ./data/train.jsonl
    type: alpaca_instruct.load_no_prompt
dataset_prepared_path: ./data/last_run_prepared
val_set_size: 0.05
output_dir: ./models/run1

sequence_len: 4096
sample_packing: true

adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:

# This will report stats from your training run to https://wandb.ai/. If you don't want to create a wandb account you can comment this section out.
wandb_project: classify-recipes
wandb_entity:
wandb_watch:
wandb_run_id: run1
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 5
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
eval_steps: 20
save_steps: 60
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  bos_token: "<s>"
  eos_token: "</s>"
  unk_token: "<unk>"
@@ -1,37 +0,0 @@
import yaml
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from peft import PeftModel
import os


def merge_lora_model(config_file: str):
    config = yaml.load(open(config_file, "r"), Loader=yaml.FullLoader)

    base_model = config["base_model"]
    lora_model = config["output_dir"]
    merged_model = f"{lora_model}/merged"

    if os.path.exists(merged_model):
        print(f"Model {merged_model} already exists, skipping")
        return merged_model

    print("Loading base model")
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        return_dict=True,
        torch_dtype=torch.float16,
    )

    print("Loading PEFT model")
    model = PeftModel.from_pretrained(model, lora_model)
    print("Running merge_and_unload")
    model = model.merge_and_unload()

    tokenizer = AutoTokenizer.from_pretrained(base_model)

    model.save_pretrained(merged_model)
    tokenizer.save_pretrained(merged_model)
    print(f"Model saved to {merged_model}")

    return merged_model
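Presumably train.ipynb calls this helper once training finishes; usage would look something like the sketch below (the module name and config filename are placeholders, since neither appears in this diff), after which the merged directory is what the evaluation notebook hands to vLLM.

```python
# Hypothetical usage of merge_lora_model; "merge_lora" and "config.yml" are
# placeholder names, not paths taken from the repository.
from vllm import LLM

from merge_lora import merge_lora_model

merged_dir = merge_lora_model("config.yml")  # e.g. the axolotl config shown above
llm = LLM(model=merged_dir, max_num_batched_tokens=4096)
```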
18
pnpm-lock.yaml
generated
@@ -174,10 +174,7 @@ importers:
specifier: 4.0.0-beta.7
version: 4.0.0-beta.7(encoding@0.1.13)
openpipe:
specifier: ^0.3.0
version: 0.3.0
openpipe-dev:
specifier: workspace:^
specifier: workspace:*
version: link:../client-libs/typescript
pg:
specifier: ^8.11.2
@@ -7250,19 +7247,6 @@ packages:
oidc-token-hash: 5.0.3
dev: false

/openpipe@0.3.0:
resolution: {integrity: sha512-0hhk3Aq0kUxzvNb36vm9vssxMHYZvgJOg5wKeepRhVthW4ygBWftHZjR4PHyOtvjcRmnJ/v4h8xd/IINu5ypnQ==}
dependencies:
encoding: 0.1.13
form-data: 4.0.0
lodash-es: 4.17.21
node-fetch: 2.6.12(encoding@0.1.13)
openai-beta: /openai@4.0.0-beta.7(encoding@0.1.13)
openai-legacy: /openai@3.3.0
transitivePeerDependencies:
- debug
dev: false

/optionator@0.9.3:
resolution: {integrity: sha512-JjCoypp+jKn1ttEFExxhetCKeJt9zhAgAve5FXHixTvFDW/5aEktX9bufBKLRRMdU7bNtpLfcGu94B3cdEJgjg==}
engines: {node: '>= 0.8.0'}