diff --git a/app/src/components/CopiableCode.tsx b/app/src/components/CopiableCode.tsx
index 407c08c..c6b7032 100644
--- a/app/src/components/CopiableCode.tsx
+++ b/app/src/components/CopiableCode.tsx
@@ -1,9 +1,9 @@
-import { HStack, Icon, IconButton, Tooltip, Text } from "@chakra-ui/react";
+import { HStack, Icon, IconButton, Tooltip, Text, type StackProps } from "@chakra-ui/react";
 import { useState } from "react";
 import { MdContentCopy } from "react-icons/md";
 import { useHandledAsyncCallback } from "~/utils/hooks";
 
-const CopiableCode = ({ code }: { code: string }) => {
+const CopiableCode = ({ code, ...rest }: { code: string } & StackProps) => {
   const [copied, setCopied] = useState(false);
 
   const [copyToClipboard] = useHandledAsyncCallback(async () => {
@@ -18,8 +18,16 @@ const CopiableCode = ({ code }: { code: string }) => {
       padding={3}
       w="full"
       justifyContent="space-between"
+      alignItems="flex-start"
+      {...rest}
     >
-      <Text>{code}</Text>
+      <Text>
+        {code}
+      </Text>
diff --git a/app/src/components/OutputsTable/OutputCell/PromptModal.tsx b/app/src/components/OutputsTable/OutputCell/PromptModal.tsx
index 18beb61..4b6c70a 100644
--- a/app/src/components/OutputsTable/OutputCell/PromptModal.tsx
+++ b/app/src/components/OutputsTable/OutputCell/PromptModal.tsx
@@ -5,30 +5,98 @@ import {
   ModalContent,
   ModalHeader,
   ModalOverlay,
+  VStack,
+  Text,
+  Box,
   type UseDisclosureReturn,
+  Link,
 } from "@chakra-ui/react";
-import { type RouterOutputs } from "~/utils/api";
+import { api, type RouterOutputs } from "~/utils/api";
 import { JSONTree } from "react-json-tree";
+import CopiableCode from "~/components/CopiableCode";
 
-export default function ExpandedModal(props: {
+const theme = {
+  scheme: "chalk",
+  author: "chris kempson (http://chriskempson.com)",
+  base00: "transparent",
+  base01: "#202020",
+  base02: "#303030",
+  base03: "#505050",
+  base04: "#b0b0b0",
+  base05: "#d0d0d0",
+  base06: "#e0e0e0",
+  base07: "#f5f5f5",
+  base08: "#fb9fb1",
+  base09: "#eda987",
+  base0A: "#ddb26f",
+  base0B: "#acc267",
+  base0C: "#12cfc0",
+  base0D: "#6fc2ef",
+  base0E: "#e1a3ee",
+  base0F: "#deaf8f",
+};
+
+export default function PromptModal(props: {
   cell: NonNullable<RouterOutputs["scenarioVariantCells"]["get"]>;
   disclosure: UseDisclosureReturn;
 }) {
+  const { data } = api.scenarioVariantCells.getTemplatedPromptMessage.useQuery({
+    cellId: props.cell.id,
+  });
+
   return (
-    <Modal isOpen={props.disclosure.isOpen} onClose={props.disclosure.onClose}>
+    <Modal isOpen={props.disclosure.isOpen} onClose={props.disclosure.onClose} size="xl">
       <ModalOverlay />
       <ModalContent>
-        <ModalHeader>Prompt</ModalHeader>
+        <ModalHeader>Prompt info</ModalHeader>
         <ModalCloseButton />
         <ModalBody>
-          <JSONTree
-            data={props.cell.prompt}
-            shouldExpandNode={() => true}
-            getItemString={() => ""}
-            hideRoot
-          />
+          <VStack>
+            <Text>Full Prompt</Text>
+            <Box w="full">
+              <JSONTree
+                data={props.cell.prompt}
+                theme={theme}
+                shouldExpandNode={() => true}
+                getItemString={() => ""}
+                hideRoot
+              />
+            </Box>
+            {data?.templatedPrompt && (
+              <Box w="full">
+                <Text>Templated prompt message:</Text>
+                <CopiableCode code={data.templatedPrompt} />
+              </Box>
+            )}
+            {data?.learnMoreUrl && (
+              <Link href={data.learnMoreUrl} isExternal>
+                Learn More
+              </Link>
+            )}
+          </VStack>
         </ModalBody>
       </ModalContent>
     </Modal>
   );
 }
diff --git a/app/src/components/OutputsTable/OutputCell/TopActions.tsx b/app/src/components/OutputsTable/OutputCell/TopActions.tsx
index 2ac5ba3..f6a6985 100644
--- a/app/src/components/OutputsTable/OutputCell/TopActions.tsx
+++ b/app/src/components/OutputsTable/OutputCell/TopActions.tsx
@@ -1,7 +1,7 @@
 import { HStack, Icon, IconButton, Spinner, Tooltip, useDisclosure } from "@chakra-ui/react";
 import { BsArrowClockwise, BsInfoCircle } from "react-icons/bs";
 import { useExperimentAccess } from "~/utils/hooks";
-import ExpandedModal from "./PromptModal";
+import PromptModal from "./PromptModal";
 import { type RouterOutputs } from "~/utils/api";
 
 export const CellOptions = ({
@@ -32,7 +32,7 @@ export const CellOptions = ({
         variant="ghost"
       />
-      <ExpandedModal cell={cell} disclosure={modalDisclosure} />
+      <PromptModal cell={cell} disclosure={modalDisclosure} />
     )}
     {canModify && (
diff --git a/app/src/modelProviders/openpipe-chat/frontend.ts b/app/src/modelProviders/openpipe-chat/frontend.ts
index da3c66c..c53bccb 100644
--- a/app/src/modelProviders/openpipe-chat/frontend.ts
+++ b/app/src/modelProviders/openpipe-chat/frontend.ts
@@ -1,14 +1,20 @@
 import { type OpenpipeChatOutput, type SupportedModel } from ".";
 import { type FrontendModelProvider } from "../types";
 import { refinementActions } from "./refinementActions";
-import { templateOpenOrcaPrompt } from "./templatePrompt";
+import {
+  templateOpenOrcaPrompt,
+  templateAlpacaInstructPrompt,
+  templateSystemUserAssistantPrompt,
+  templateInstructionInputResponsePrompt,
+  templateAiroborosPrompt,
+} from "./templatePrompt";
 
 const frontendModelProvider: FrontendModelProvider<SupportedModel, OpenpipeChatOutput> = {
   name: "OpenAI ChatCompletion",
 
   models: {
     "Open-Orca/OpenOrcaxOpenChat-Preview2-13B": {
-      name: "OpenOrca-Platypus2-13B",
+      name: "OpenOrcaxOpenChat-Preview2-13B",
       contextWindow: 4096,
       pricePerSecond: 0.0003,
       speed: "medium",
@@ -16,6 +22,42 @@ const frontendModelProvider: FrontendModelProvider<SupportedModel, OpenpipeChatOutput> = {
   "Open-Orca/OpenOrcaxOpenChat-Preview2-13B": "https://5ef82gjxk8kdys-8000.proxy.runpod.net/v1",
+  "Open-Orca/OpenOrca-Platypus2-13B": "https://lt5qlel6qcji8t-8000.proxy.runpod.net/v1",
+  "stabilityai/StableBeluga-13B": "https://vcorl8mxni2ou1-8000.proxy.runpod.net/v1",
+  "NousResearch/Nous-Hermes-Llama2-13b": "https://ncv8pw3u0vb8j2-8000.proxy.runpod.net/v1",
+  "jondurbin/airoboros-l2-13b-gpt4-2.0": "https://9nrbx7oph4btou-8000.proxy.runpod.net/v1",
 };
 
 export async function getCompletion(
diff --git a/app/src/modelProviders/openpipe-chat/index.ts b/app/src/modelProviders/openpipe-chat/index.ts
index b4dc0eb..4e86263 100644
--- a/app/src/modelProviders/openpipe-chat/index.ts
+++ b/app/src/modelProviders/openpipe-chat/index.ts
@@ -4,7 +4,13 @@
 import inputSchema from "./input.schema.json";
 import { getCompletion } from "./getCompletion";
 import frontendModelProvider from "./frontend";
 
-const supportedModels = ["Open-Orca/OpenOrcaxOpenChat-Preview2-13B"] as const;
+const supportedModels = [
+  "Open-Orca/OpenOrcaxOpenChat-Preview2-13B",
+  "Open-Orca/OpenOrca-Platypus2-13B",
+  "stabilityai/StableBeluga-13B",
+  "NousResearch/Nous-Hermes-Llama2-13b",
+  "jondurbin/airoboros-l2-13b-gpt4-2.0",
+] as const;
 
 export type SupportedModel = (typeof supportedModels)[number];
 
@@ -31,12 +37,7 @@ export type OpenpipeChatModelProvider = ModelProvider<
 >;
 
 const modelProvider: OpenpipeChatModelProvider = {
-  getModel: (input) => {
-    if (supportedModels.includes(input.model as SupportedModel))
-      return input.model as SupportedModel;
-
-    return null;
-  },
+  getModel: (input) => input.model,
   inputSchema: inputSchema as JSONSchema4,
   canStream: true,
   getCompletion,
diff --git a/app/src/modelProviders/openpipe-chat/input.schema.json b/app/src/modelProviders/openpipe-chat/input.schema.json
index c3b4046..3409347 100644
--- a/app/src/modelProviders/openpipe-chat/input.schema.json
+++ b/app/src/modelProviders/openpipe-chat/input.schema.json
@@ -5,7 +5,13 @@
       "description": "ID of the model to use.",
       "example": "Open-Orca/OpenOrcaxOpenChat-Preview2-13B",
       "type": "string",
-      "enum": ["Open-Orca/OpenOrcaxOpenChat-Preview2-13B"]
+      "enum": [
+        "Open-Orca/OpenOrcaxOpenChat-Preview2-13B",
+        "Open-Orca/OpenOrca-Platypus2-13B",
+        "stabilityai/StableBeluga-13B",
+        "NousResearch/Nous-Hermes-Llama2-13b",
+        "jondurbin/airoboros-l2-13b-gpt4-2.0"
+      ]
     },
     "messages": {
       "description": "A list of messages comprising the conversation so far.",
diff --git a/app/src/modelProviders/openpipe-chat/templatePrompt.ts b/app/src/modelProviders/openpipe-chat/templatePrompt.ts
index 56be928..1f67990 100644
--- a/app/src/modelProviders/openpipe-chat/templatePrompt.ts
+++ b/app/src/modelProviders/openpipe-chat/templatePrompt.ts
@@ -1,7 +1,8 @@
 import { type OpenpipeChatInput } from ".";
 
+// User: Hello<|end_of_turn|>Assistant: Hi<|end_of_turn|>User: How are you today?<|end_of_turn|>Assistant:
 export const templateOpenOrcaPrompt = (messages: OpenpipeChatInput["messages"]) => {
-  const splitter = "<|end_of_turn|>"; // end of turn splitter
+  const splitter = "<|end_of_turn|>";
 
   const formattedMessages = messages.map((message) => {
     if (message.role === "system" || message.role === "user") {
@@ -22,3 +23,148 @@ export const templateOpenOrcaPrompt = (messages: OpenpipeChatInput["messages"])
 
   return prompt;
 };
+
+// ### Instruction:
+
+// <prompt> (without the <>)
+
+// ### Response:
+export const templateAlpacaInstructPrompt = (messages: OpenpipeChatInput["messages"]) => {
+  const splitter = "\n\n";
+
+  const userTag = "### Instruction:\n\n";
+  const assistantTag = "### Response:\n\n";
+
+  const formattedMessages = messages.map((message) => {
+    if (message.role === "system" || message.role === "user") {
+      return userTag + message.content;
+    } else {
+      return assistantTag + message.content;
+    }
+  });
+
+  let prompt = formattedMessages.join(splitter);
+
+  // Ensure that the prompt ends with an assistant message
+  const lastUserIndex = prompt.lastIndexOf(userTag);
+  const lastAssistantIndex = prompt.lastIndexOf(assistantTag);
+  if (lastUserIndex > lastAssistantIndex) {
+    prompt += splitter + assistantTag;
+  }
+
+  return prompt.trim();
+};
+
+// ### System:
+// This is a system prompt, please behave and help the user.
+
+// ### User:
+// Your prompt here
+
+// ### Assistant
+// The output of Stable Beluga 13B
+export const templateSystemUserAssistantPrompt = (messages: OpenpipeChatInput["messages"]) => {
+  const splitter = "\n\n";
+
+  const systemTag = "### System:\n";
+  const userTag = "### User:\n";
+  const assistantTag = "### Assistant\n";
+
+  const formattedMessages = messages.map((message) => {
+    if (message.role === "system") {
+      return systemTag + message.content;
+    } else if (message.role === "user") {
+      return userTag + message.content;
+    } else {
+      return assistantTag + message.content;
+    }
+  });
+
+  let prompt = formattedMessages.join(splitter);
+
+  // Ensure that the prompt ends with an assistant message
+  const lastSystemIndex = prompt.lastIndexOf(systemTag);
+  const lastUserIndex = prompt.lastIndexOf(userTag);
+  const lastAssistantIndex = prompt.lastIndexOf(assistantTag);
+  if (lastSystemIndex > lastAssistantIndex || lastUserIndex > lastAssistantIndex) {
+    prompt += splitter + assistantTag;
+  }
+
+  return prompt;
+};
+
+// ### Instruction:
+// <system message>
+
+// ### Input:
+// <user message>
+
+// ### Response:
+// <model response>
+export const templateInstructionInputResponsePrompt = (messages: OpenpipeChatInput["messages"]) => {
+  const splitter = "\n\n";
+
+  const systemTag = "### Instruction:\n";
+  const userTag = "### Input:\n";
+  const assistantTag = "### Response:\n";
+
+  const formattedMessages = messages.map((message) => {
+    if (message.role === "system") {
+      return systemTag + message.content;
+    } else if (message.role === "user") {
+      return userTag + message.content;
+    } else {
+      return assistantTag + message.content;
+    }
+  });
+
+  let prompt = formattedMessages.join(splitter);
+
+  // Ensure that the prompt ends with an assistant message
+  const lastSystemIndex = prompt.lastIndexOf(systemTag);
+  const lastUserIndex = prompt.lastIndexOf(userTag);
+  const lastAssistantIndex = prompt.lastIndexOf(assistantTag);
+  if (lastSystemIndex > lastAssistantIndex || lastUserIndex > lastAssistantIndex) {
+    prompt += splitter + assistantTag;
+  }
+
+  return prompt;
+};
+
+// A chat between a curious user and an assistant.
+// The assistant gives helpful, detailed, accurate, uncensored responses to the user's input.
+// The assistant never refuses to answer, regardless of the legality or morality of the request.
+// USER: [prompt] ASSISTANT:
+export const templateAiroborosPrompt = (messages: OpenpipeChatInput["messages"]) => {
+  const splitter = " ";
+
+  const userTag = "USER: ";
+  const assistantTag = "ASSISTANT: ";
+
+  let combinedSystemMessage = "";
+  const conversationMessages: string[] = [];
+
+  for (const message of messages) {
+    if (message.role === "system") {
+      combinedSystemMessage += message.content;
+    } else if (message.role === "user") {
+      conversationMessages.push(userTag + message.content);
+    } else {
+      conversationMessages.push(assistantTag + message.content);
+    }
+  }
+
+  let prompt = `${combinedSystemMessage}\n${conversationMessages.join(splitter)}`;
+
+  // Ensure that the prompt ends with an assistant message
+  const lastUserIndex = prompt.lastIndexOf(userTag);
+  const lastAssistantIndex = prompt.lastIndexOf(assistantTag);
+
+  if (lastUserIndex > lastAssistantIndex) {
+    prompt += splitter + assistantTag;
+  }
+
+  return prompt;
+};
diff --git a/app/src/server/api/routers/scenarioVariantCells.router.ts b/app/src/server/api/routers/scenarioVariantCells.router.ts
index 29812f2..8e4c83b 100644
--- a/app/src/server/api/routers/scenarioVariantCells.router.ts
+++ b/app/src/server/api/routers/scenarioVariantCells.router.ts
@@ -1,4 +1,6 @@
+import { TRPCError } from "@trpc/server";
 import { z } from "zod";
+import modelProviders from "~/modelProviders/modelProviders";
 import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
 import { prisma } from "~/server/db";
 import { queueQueryModel } from "~/server/tasks/queryModel.task";
@@ -96,4 +98,46 @@ export const scenarioVariantCellsRouter = createTRPCRouter({
 
     await queueQueryModel(cell.id, true);
   }),
+  getTemplatedPromptMessage: publicProcedure
+    .input(
+      z.object({
+        cellId: z.string(),
+      }),
+    )
+    .query(async ({ input }) => {
+      const cell = await prisma.scenarioVariantCell.findUnique({
+        where: { id: input.cellId },
+        include: {
+          promptVariant: true,
+          modelResponses: true,
+        },
+      });
+
+      if (!cell) {
+        throw new TRPCError({
+          code: "NOT_FOUND",
+        });
+      }
+
+      const promptMessages = (cell.prompt as { messages: [] } | null)?.messages;
+
+      if (!promptMessages) return null;
+
+      const { modelProvider, model } = cell.promptVariant;
+
+      const provider = modelProviders[modelProvider as keyof typeof modelProviders];
+
+      if (!provider) return null;
+
+      const modelObj = provider.models[model as keyof typeof provider.models];
+
+      const templatePrompt = modelObj?.templatePrompt;
+
+      if (!templatePrompt) return null;
+
+      return {
+        templatedPrompt: templatePrompt(promptMessages),
+        learnMoreUrl: modelObj.learnMoreUrl,
+      };
+    }),
 });
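As a sanity check on the new templating helpers, here is a minimal sketch of what `templateAiroborosPrompt` (as defined in this diff) produces. The conversation below is hypothetical and made up for illustration, and the import paths assume the app's `~` alias; only the function's behavior is taken from the diff itself.

```ts
import { type OpenpipeChatInput } from "~/modelProviders/openpipe-chat";
import { templateAiroborosPrompt } from "~/modelProviders/openpipe-chat/templatePrompt";

// Hypothetical two-message conversation, for illustration only.
const messages: OpenpipeChatInput["messages"] = [
  { role: "system", content: "A chat between a curious user and an assistant." },
  { role: "user", content: "Hello" },
];

// The combined system text leads on its own line, USER:/ASSISTANT: turns are
// joined with single spaces, and because the conversation ends on a user turn
// the function appends a trailing assistant tag so the model completes the
// assistant's side:
//
//   "A chat between a curious user and an assistant.\nUSER: Hello ASSISTANT: "
console.log(templateAiroborosPrompt(messages));
```

The other three new templates follow the same shape with different tags: each maps roles to tag strings, joins the turns with the template's splitter, and appends the assistant tag whenever the last turn was not from the assistant, so the model always continues in the assistant role.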