From c6ec901374957d9cd10d98f9c58056d74980d638 Mon Sep 17 00:00:00 2001
From: arcticfly <41524992+arcticfly@users.noreply.github.com>
Date: Wed, 16 Aug 2023 22:37:37 -0700
Subject: [PATCH] Add openpipe/Chat provider with Open-Orca/OpenOrcaxOpenChat-Preview2-13B model (#163)

* Display 4 decimal points in ModelStatsCard

* Add openpipe-chat provider
---
 .../ChangeModelModal/ModelStatsCard.tsx       |   2 +-
 .../modelProviders/frontendModelProviders.ts  |   2 +
 app/src/modelProviders/modelProviders.ts      |   2 +
 .../modelProviders/openpipe-chat/frontend.ts  |  26 +++++
 .../openpipe-chat/getCompletion.ts            | 104 ++++++++++++++++++
 app/src/modelProviders/openpipe-chat/index.ts |  50 +++++++++
 .../openpipe-chat/input.schema.json           |  88 +++++++++++++++
 .../openpipe-chat/refinementActions.ts        |   3 +
 .../openpipe-chat/templatePrompt.ts           |  24 ++++
 app/src/modelProviders/types.ts               |   3 +
 10 files changed, 303 insertions(+), 1 deletion(-)
 create mode 100644 app/src/modelProviders/openpipe-chat/frontend.ts
 create mode 100644 app/src/modelProviders/openpipe-chat/getCompletion.ts
 create mode 100644 app/src/modelProviders/openpipe-chat/index.ts
 create mode 100644 app/src/modelProviders/openpipe-chat/input.schema.json
 create mode 100644 app/src/modelProviders/openpipe-chat/refinementActions.ts
 create mode 100644 app/src/modelProviders/openpipe-chat/templatePrompt.ts

diff --git a/app/src/components/ChangeModelModal/ModelStatsCard.tsx b/app/src/components/ChangeModelModal/ModelStatsCard.tsx
index 4d831cd..1755e1e 100644
--- a/app/src/components/ChangeModelModal/ModelStatsCard.tsx
+++ b/app/src/components/ChangeModelModal/ModelStatsCard.tsx
@@ -87,7 +87,7 @@ export const ModelStatsCard = ({
         label="Price"
         info={
-            ${model.pricePerSecond.toFixed(3)}
+            ${model.pricePerSecond.toFixed(4)}
             / second
         }
diff --git a/app/src/modelProviders/frontendModelProviders.ts b/app/src/modelProviders/frontendModelProviders.ts
index 9950e36..ad5962d 100644
--- a/app/src/modelProviders/frontendModelProviders.ts
+++ b/app/src/modelProviders/frontendModelProviders.ts
@@ -1,6 +1,7 @@
 import openaiChatCompletionFrontend from "./openai-ChatCompletion/frontend";
 import replicateLlama2Frontend from "./replicate-llama2/frontend";
 import anthropicFrontend from "./anthropic-completion/frontend";
+import openpipeFrontend from "./openpipe-chat/frontend";
 import { type SupportedProvider, type FrontendModelProvider } from "./types";
 
 // Keep attributes here that need to be accessible from the frontend. We can't
@@ -10,6 +11,7 @@ const frontendModelProviders: Record> = {
   "openai/ChatCompletion": openaiChatCompletion,
   "replicate/llama2": replicateLlama2,
   "anthropic/completion": anthropicCompletion,
+  "openpipe/Chat": openpipeChatCompletion,
 };
 
 export default modelProviders;
diff --git a/app/src/modelProviders/openpipe-chat/frontend.ts b/app/src/modelProviders/openpipe-chat/frontend.ts
new file mode 100644
index 0000000..da3c66c
--- /dev/null
+++ b/app/src/modelProviders/openpipe-chat/frontend.ts
@@ -0,0 +1,26 @@
+import { type OpenpipeChatOutput, type SupportedModel } from ".";
+import { type FrontendModelProvider } from "../types";
+import { refinementActions } from "./refinementActions";
+import { templateOpenOrcaPrompt } from "./templatePrompt";
+
+const frontendModelProvider: FrontendModelProvider<SupportedModel, OpenpipeChatOutput> = {
+  name: "OpenAI ChatCompletion",
+
+  models: {
+    "Open-Orca/OpenOrcaxOpenChat-Preview2-13B": {
+      name: "OpenOrca-Platypus2-13B",
+      contextWindow: 4096,
+      pricePerSecond: 0.0003,
+      speed: "medium",
+      provider: "openpipe/Chat",
+      learnMoreUrl: "https://huggingface.co/Open-Orca/OpenOrcaxOpenChat-Preview2-13B",
+      templatePrompt: templateOpenOrcaPrompt,
+    },
+  },
+
+  refinementActions,
+
+  normalizeOutput: (output) => ({ type: "text", value: output }),
+};
+
+export default frontendModelProvider;
diff --git a/app/src/modelProviders/openpipe-chat/getCompletion.ts b/app/src/modelProviders/openpipe-chat/getCompletion.ts
new file mode 100644
index 0000000..a314575
--- /dev/null
+++ b/app/src/modelProviders/openpipe-chat/getCompletion.ts
@@ -0,0 +1,104 @@
+/* eslint-disable @typescript-eslint/no-unsafe-call */
+import { isArray, isString } from "lodash-es";
+import OpenAI, { APIError } from "openai";
+
+import { type CompletionResponse } from "../types";
+import { type OpenpipeChatInput, type OpenpipeChatOutput } from ".";
+import frontendModelProvider from "./frontend";
+
+const modelEndpoints: Record<OpenpipeChatInput["model"], string> = {
+  "Open-Orca/OpenOrcaxOpenChat-Preview2-13B": "https://5ef82gjxk8kdys-8000.proxy.runpod.net/v1",
+};
+
+export async function getCompletion(
+  input: OpenpipeChatInput,
+  onStream: ((partialOutput: OpenpipeChatOutput) => void) | null,
+): Promise<CompletionResponse<OpenpipeChatOutput>> {
+  const { model, messages, ...rest } = input;
+
+  const templatedPrompt = frontendModelProvider.models[model].templatePrompt?.(messages);
+
+  if (!templatedPrompt) {
+    return {
+      type: "error",
+      message: "Failed to generate prompt",
+      autoRetry: false,
+    };
+  }
+
+  const openai = new OpenAI({
+    baseURL: modelEndpoints[model],
+  });
+  const start = Date.now();
+  let finalCompletion: OpenpipeChatOutput = "";
+
+  try {
+    if (onStream) {
+      const resp = await openai.completions.create(
+        { model, prompt: templatedPrompt, ...rest, stream: true },
+        {
+          maxRetries: 0,
+        },
+      );
+
+      for await (const part of resp) {
+        finalCompletion += part.choices[0]?.text;
+        onStream(finalCompletion);
+      }
+      if (!finalCompletion) {
+        return {
+          type: "error",
+          message: "Streaming failed to return a completion",
+          autoRetry: false,
+        };
+      }
+    } else {
+      const resp = await openai.completions.create(
+        { model, prompt: templatedPrompt, ...rest, stream: false },
+        {
+          maxRetries: 0,
+        },
+      );
+      finalCompletion = resp.choices[0]?.text || "";
+      if (!finalCompletion) {
+        return {
+          type: "error",
+          message: "Failed to return a completion",
+          autoRetry: false,
+        };
+      }
+    }
+    const timeToComplete = Date.now() - start;
+
+    return {
+      type: "success",
+      statusCode: 200,
+      value: finalCompletion,
+      timeToComplete,
+    };
+  } catch (error: unknown) {
+    if (error instanceof APIError) {
+      // The types from the sdk are wrong
+      const rawMessage = error.message as string | string[];
+      // If the message is not a string, stringify it
+      const message = isString(rawMessage)
+        ? rawMessage
+        : isArray(rawMessage)
+        ? rawMessage.map((m) => m.toString()).join("\n")
+        : (rawMessage as any).toString();
+      return {
+        type: "error",
+        message,
+        autoRetry: error.status === 429 || error.status === 503,
+        statusCode: error.status,
+      };
+    } else {
+      console.error(error);
+      return {
+        type: "error",
+        message: (error as Error).message,
+        autoRetry: true,
+      };
+    }
+  }
+}
diff --git a/app/src/modelProviders/openpipe-chat/index.ts b/app/src/modelProviders/openpipe-chat/index.ts
new file mode 100644
index 0000000..b4dc0eb
--- /dev/null
+++ b/app/src/modelProviders/openpipe-chat/index.ts
@@ -0,0 +1,50 @@
+import { type JSONSchema4 } from "json-schema";
+import { type ModelProvider } from "../types";
+import inputSchema from "./input.schema.json";
+import { getCompletion } from "./getCompletion";
+import frontendModelProvider from "./frontend";
+
+const supportedModels = ["Open-Orca/OpenOrcaxOpenChat-Preview2-13B"] as const;
+
+export type SupportedModel = (typeof supportedModels)[number];
+
+export type OpenpipeChatInput = {
+  model: SupportedModel;
+  messages: {
+    role: "system" | "user" | "assistant";
+    content: string;
+  }[];
+  temperature?: number;
+  top_p?: number;
+  stop?: string[] | string;
+  max_tokens?: number;
+  presence_penalty?: number;
+  frequency_penalty?: number;
+};
+
+export type OpenpipeChatOutput = string;
+
+export type OpenpipeChatModelProvider = ModelProvider<
+  SupportedModel,
+  OpenpipeChatInput,
+  OpenpipeChatOutput
+>;
+
+const modelProvider: OpenpipeChatModelProvider = {
+  getModel: (input) => {
+    if (supportedModels.includes(input.model as SupportedModel))
+      return input.model as SupportedModel;
+
+    return null;
+  },
+  inputSchema: inputSchema as JSONSchema4,
+  canStream: true,
+  getCompletion,
+  getUsage: (input, output) => {
+    // TODO: Implement this
+    return null;
+  },
+  ...frontendModelProvider,
+};
+
+export default modelProvider;
diff --git a/app/src/modelProviders/openpipe-chat/input.schema.json b/app/src/modelProviders/openpipe-chat/input.schema.json
new file mode 100644
index 0000000..c3b4046
--- /dev/null
+++ b/app/src/modelProviders/openpipe-chat/input.schema.json
@@ -0,0 +1,88 @@
+{
+  "type": "object",
+  "properties": {
+    "model": {
+      "description": "ID of the model to use.",
+      "example": "Open-Orca/OpenOrcaxOpenChat-Preview2-13B",
+      "type": "string",
+      "enum": ["Open-Orca/OpenOrcaxOpenChat-Preview2-13B"]
+    },
+    "messages": {
+      "description": "A list of messages comprising the conversation so far.",
+      "type": "array",
+      "minItems": 1,
+      "items": {
+        "type": "object",
+        "properties": {
+          "role": {
+            "type": "string",
+            "enum": ["system", "user", "assistant"],
+            "description": "The role of the messages author. One of `system`, `user`, or `assistant`."
+          },
+          "content": {
+            "type": "string",
+            "description": "The contents of the message. `content` is required for all messages."
+          }
+        },
+        "required": ["role", "content"]
+      }
+    },
+    "temperature": {
+      "type": "number",
+      "minimum": 0,
+      "maximum": 2,
+      "default": 1,
+      "example": 1,
+      "nullable": true,
+      "description": "What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.\n\nWe generally recommend altering this or `top_p` but not both.\n"
+    },
+    "top_p": {
+      "type": "number",
+      "minimum": 0,
+      "maximum": 1,
+      "default": 1,
+      "example": 1,
+      "nullable": true,
+      "description": "An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.\n\nWe generally recommend altering this or `temperature` but not both.\n"
+    },
+    "stop": {
+      "description": "Up to 4 sequences where the API will stop generating further tokens.\n",
+      "default": null,
+      "oneOf": [
+        {
+          "type": "string",
+          "nullable": true
+        },
+        {
+          "type": "array",
+          "minItems": 1,
+          "maxItems": 4,
+          "items": {
+            "type": "string"
+          }
+        }
+      ]
+    },
+    "max_tokens": {
+      "description": "The maximum number of [tokens](/tokenizer) to generate in the chat completion.\n\nThe total length of input tokens and generated tokens is limited by the model's context length. [Example Python code](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb) for counting tokens.\n",
+      "type": "integer"
+    },
+    "presence_penalty": {
+      "type": "number",
+      "default": 0,
+      "minimum": -2,
+      "maximum": 2,
+      "nullable": true,
+      "description": "Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.\n\n[See more information about frequency and presence penalties.](/docs/api-reference/parameter-details)\n"
+    },
+    "frequency_penalty": {
+      "type": "number",
+      "default": 0,
+      "minimum": -2,
+      "maximum": 2,
+      "nullable": true,
+      "description": "Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.\n\n[See more information about frequency and presence penalties.](/docs/api-reference/parameter-details)\n"
+    }
+  },
+  "required": ["model", "messages"]
+}
diff --git a/app/src/modelProviders/openpipe-chat/refinementActions.ts b/app/src/modelProviders/openpipe-chat/refinementActions.ts
new file mode 100644
index 0000000..6df8ac8
--- /dev/null
+++ b/app/src/modelProviders/openpipe-chat/refinementActions.ts
@@ -0,0 +1,3 @@
+import { type RefinementAction } from "../types";
+
+export const refinementActions: Record<string, RefinementAction> = {};
diff --git a/app/src/modelProviders/openpipe-chat/templatePrompt.ts b/app/src/modelProviders/openpipe-chat/templatePrompt.ts
new file mode 100644
index 0000000..56be928
--- /dev/null
+++ b/app/src/modelProviders/openpipe-chat/templatePrompt.ts
@@ -0,0 +1,24 @@
+import { type OpenpipeChatInput } from ".";
+
+export const templateOpenOrcaPrompt = (messages: OpenpipeChatInput["messages"]) => {
+  const splitter = "<|end_of_turn|>"; // end of turn splitter
+
+  const formattedMessages = messages.map((message) => {
+    if (message.role === "system" || message.role === "user") {
+      return "User: " + message.content;
+    } else {
+      return "Assistant: " + message.content;
+    }
+  });
+
+  let prompt = formattedMessages.join(splitter);
+
+  // Ensure that the prompt ends with an assistant message
+  const lastUserIndex = prompt.lastIndexOf("User:");
+  const lastAssistantIndex = prompt.lastIndexOf("Assistant:");
+  if (lastUserIndex > lastAssistantIndex) {
+    prompt += splitter + "Assistant:";
+  }
+
+  return prompt;
+};
diff --git a/app/src/modelProviders/types.ts b/app/src/modelProviders/types.ts
index 6b5e09e..058e16d 100644
--- a/app/src/modelProviders/types.ts
+++ b/app/src/modelProviders/types.ts
@@ -2,11 +2,13 @@ import { type JSONSchema4 } from "json-schema";
 import { type IconType } from "react-icons";
 import { type JsonValue } from "type-fest";
 import { z } from "zod";
+import { type OpenpipeChatInput } from "./openpipe-chat";
 
 export const ZodSupportedProvider = z.union([
   z.literal("openai/ChatCompletion"),
   z.literal("replicate/llama2"),
   z.literal("anthropic/completion"),
+  z.literal("openpipe/Chat"),
 ]);
 
 export type SupportedProvider = z.infer<typeof ZodSupportedProvider>;
@@ -22,6 +24,7 @@ export type Model = {
   description?: string;
   learnMoreUrl?: string;
   apiDocsUrl?: string;
+  templatePrompt?: (initialPrompt: OpenpipeChatInput["messages"]) => string;
 };
 
 export type ProviderModel = { provider: z.infer<typeof ZodSupportedProvider>; model: string };