More work on modelProviders

I think everything that's OpenAI-specific is inside modelProviders at this point, so we can get started adding more providers.
This commit is contained in:
Kyle Corbitt
2023-07-20 18:54:26 -07:00
parent ded6678e97
commit 332a2101c0
21 changed files with 344 additions and 330 deletions

View File

@@ -6,11 +6,11 @@ import SyntaxHighlighter from "react-syntax-highlighter";
import { docco } from "react-syntax-highlighter/dist/cjs/styles/hljs"; import { docco } from "react-syntax-highlighter/dist/cjs/styles/hljs";
import stringify from "json-stringify-pretty-compact"; import stringify from "json-stringify-pretty-compact";
import { type ReactElement, useState, useEffect } from "react"; import { type ReactElement, useState, useEffect } from "react";
import { type ChatCompletion } from "openai/resources/chat";
import useSocket from "~/utils/useSocket"; import useSocket from "~/utils/useSocket";
import { OutputStats } from "./OutputStats"; import { OutputStats } from "./OutputStats";
import { ErrorHandler } from "./ErrorHandler"; import { ErrorHandler } from "./ErrorHandler";
import { CellOptions } from "./CellOptions"; import { CellOptions } from "./CellOptions";
import modelProvidersFrontend from "~/modelProviders/modelProvidersFrontend";
export default function OutputCell({ export default function OutputCell({
scenario, scenario,
@@ -33,15 +33,17 @@ export default function OutputCell({
if (!templateHasVariables) disabledReason = "Add a value to the scenario variables to see output"; if (!templateHasVariables) disabledReason = "Add a value to the scenario variables to see output";
// if (variant.config === null || Object.keys(variant.config).length === 0)
// disabledReason = "Save your prompt variant to see output";
const [refetchInterval, setRefetchInterval] = useState(0); const [refetchInterval, setRefetchInterval] = useState(0);
const { data: cell, isLoading: queryLoading } = api.scenarioVariantCells.get.useQuery( const { data: cell, isLoading: queryLoading } = api.scenarioVariantCells.get.useQuery(
{ scenarioId: scenario.id, variantId: variant.id }, { scenarioId: scenario.id, variantId: variant.id },
{ refetchInterval }, { refetchInterval },
); );
const provider =
modelProvidersFrontend[variant.modelProvider as keyof typeof modelProvidersFrontend];
type OutputSchema = Parameters<typeof provider.normalizeOutput>[0];
const { mutateAsync: hardRefetchMutate } = api.scenarioVariantCells.forceRefetch.useMutation(); const { mutateAsync: hardRefetchMutate } = api.scenarioVariantCells.forceRefetch.useMutation();
const [hardRefetch, hardRefetching] = useHandledAsyncCallback(async () => { const [hardRefetch, hardRefetching] = useHandledAsyncCallback(async () => {
await hardRefetchMutate({ scenarioId: scenario.id, variantId: variant.id }); await hardRefetchMutate({ scenarioId: scenario.id, variantId: variant.id });
@@ -66,8 +68,7 @@ export default function OutputCell({
const modelOutput = cell?.modelOutput; const modelOutput = cell?.modelOutput;
// Disconnect from socket if we're not streaming anymore // Disconnect from socket if we're not streaming anymore
const streamedMessage = useSocket(cell?.streamingChannel); const streamedMessage = useSocket<OutputSchema>(cell?.streamingChannel);
const streamedContent = streamedMessage?.choices?.[0]?.message?.content;
if (!vars) return null; if (!vars) return null;
@@ -86,19 +87,13 @@ export default function OutputCell({
return <ErrorHandler cell={cell} refetchOutput={hardRefetch} />; return <ErrorHandler cell={cell} refetchOutput={hardRefetch} />;
} }
const response = modelOutput?.output as unknown as ChatCompletion; const normalizedOutput = modelOutput
const message = response?.choices?.[0]?.message; ? provider.normalizeOutput(modelOutput.output as unknown as OutputSchema)
: streamedMessage
if (modelOutput && message?.function_call) { ? provider.normalizeOutput(streamedMessage)
const rawArgs = message.function_call.arguments ?? "null"; : null;
let parsedArgs: string;
try {
parsedArgs = JSON.parse(rawArgs);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} catch (e: any) {
parsedArgs = `Failed to parse arguments as JSON: '${rawArgs}' ERROR: ${e.message as string}`;
}
if (modelOutput && normalizedOutput?.type === "json") {
return ( return (
<VStack <VStack
w="100%" w="100%"
@@ -119,13 +114,7 @@ export default function OutputCell({
}} }}
wrapLines wrapLines
> >
{stringify( {stringify(normalizedOutput.value, { maxLength: 40 })}
{
function: message.function_call.name,
args: parsedArgs,
},
{ maxLength: 40 },
)}
</SyntaxHighlighter> </SyntaxHighlighter>
</VStack> </VStack>
<OutputStats modelOutput={modelOutput} scenario={scenario} /> <OutputStats modelOutput={modelOutput} scenario={scenario} />
@@ -133,8 +122,7 @@ export default function OutputCell({
); );
} }
const contentToDisplay = const contentToDisplay = (normalizedOutput?.type === "text" && normalizedOutput.value) || "";
message?.content ?? streamedContent ?? JSON.stringify(modelOutput?.output);
return ( return (
<VStack w="100%" h="100%" justifyContent="space-between" whiteSpace="pre-wrap"> <VStack w="100%" h="100%" justifyContent="space-between" whiteSpace="pre-wrap">

View File

@@ -50,8 +50,6 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
// Make sure the user defined the prompt with the string "prompt\w*=" somewhere // Make sure the user defined the prompt with the string "prompt\w*=" somewhere
const promptRegex = /definePrompt\(/; const promptRegex = /definePrompt\(/;
if (!promptRegex.test(currentFn)) { if (!promptRegex.test(currentFn)) {
console.log("no prompt");
console.log(currentFn);
toast({ toast({
title: "Missing prompt", title: "Missing prompt",
description: "Please define the prompt (eg. `definePrompt(...`", description: "Please define the prompt (eg. `definePrompt(...`",

View File

@@ -1,5 +1,5 @@
import { type JSONSchema4Object } from "json-schema"; import { type JSONSchema4Object } from "json-schema";
import modelProviders from "."; import modelProviders from "./modelProviders";
import { compile } from "json-schema-to-typescript"; import { compile } from "json-schema-to-typescript";
import dedent from "dedent"; import dedent from "dedent";

View File

@@ -0,0 +1,10 @@
import modelProviderFrontend from "./openai-ChatCompletion/frontend";
// Keep attributes here that need to be accessible from the frontend. We can't
// just include them in the default `modelProviders` object because it has some
// transient dependencies that can only be imported on the server.
const modelProvidersFrontend = {
"openai/ChatCompletion": modelProviderFrontend,
} as const;
export default modelProvidersFrontend;

View File

@@ -0,0 +1,42 @@
import { type JsonValue } from "type-fest";
import { type OpenaiChatModelProvider } from ".";
import { type ModelProviderFrontend } from "../types";
const modelProviderFrontend: ModelProviderFrontend<OpenaiChatModelProvider> = {
normalizeOutput: (output) => {
const message = output.choices[0]?.message;
if (!message)
return {
type: "json",
value: output as unknown as JsonValue,
};
if (message.content) {
return {
type: "text",
value: message.content,
};
} else if (message.function_call) {
let args = message.function_call.arguments ?? "";
try {
args = JSON.parse(args);
} catch (e) {
// Ignore
}
return {
type: "json",
value: {
...message.function_call,
arguments: args,
},
};
} else {
return {
type: "json",
value: message as unknown as JsonValue,
};
}
},
};
export default modelProviderFrontend;

View File

@@ -0,0 +1,142 @@
/* eslint-disable @typescript-eslint/no-unsafe-call */
import {
type ChatCompletionChunk,
type ChatCompletion,
type CompletionCreateParams,
} from "openai/resources/chat";
import { countOpenAIChatTokens } from "~/utils/countTokens";
import { type CompletionResponse } from "../types";
import { omit } from "lodash-es";
import { openai } from "~/server/utils/openai";
import { type OpenAIChatModel } from "~/server/types";
import { truthyFilter } from "~/utils/utils";
import { APIError } from "openai";
import { modelStats } from "../modelStats";
const mergeStreamedChunks = (
base: ChatCompletion | null,
chunk: ChatCompletionChunk,
): ChatCompletion => {
if (base === null) {
return mergeStreamedChunks({ ...chunk, choices: [] }, chunk);
}
const choices = [...base.choices];
for (const choice of chunk.choices) {
const baseChoice = choices.find((c) => c.index === choice.index);
if (baseChoice) {
baseChoice.finish_reason = choice.finish_reason ?? baseChoice.finish_reason;
baseChoice.message = baseChoice.message ?? { role: "assistant" };
if (choice.delta?.content)
baseChoice.message.content =
((baseChoice.message.content as string) ?? "") + (choice.delta.content ?? "");
if (choice.delta?.function_call) {
const fnCall = baseChoice.message.function_call ?? {};
fnCall.name =
((fnCall.name as string) ?? "") + ((choice.delta.function_call.name as string) ?? "");
fnCall.arguments =
((fnCall.arguments as string) ?? "") +
((choice.delta.function_call.arguments as string) ?? "");
}
} else {
choices.push({ ...omit(choice, "delta"), message: { role: "assistant", ...choice.delta } });
}
}
const merged: ChatCompletion = {
...base,
choices,
};
return merged;
};
export async function getCompletion(
input: CompletionCreateParams,
onStream: ((partialOutput: ChatCompletion) => void) | null,
): Promise<CompletionResponse<ChatCompletion>> {
const start = Date.now();
let finalCompletion: ChatCompletion | null = null;
let promptTokens: number | undefined = undefined;
let completionTokens: number | undefined = undefined;
try {
if (onStream) {
const resp = await openai.chat.completions.create(
{ ...input, stream: true },
{
maxRetries: 0,
},
);
for await (const part of resp) {
finalCompletion = mergeStreamedChunks(finalCompletion, part);
onStream(finalCompletion);
}
if (!finalCompletion) {
return {
type: "error",
message: "Streaming failed to return a completion",
autoRetry: false,
};
}
try {
promptTokens = countOpenAIChatTokens(
input.model as keyof typeof OpenAIChatModel,
input.messages,
);
completionTokens = countOpenAIChatTokens(
input.model as keyof typeof OpenAIChatModel,
finalCompletion.choices.map((c) => c.message).filter(truthyFilter),
);
} catch (err) {
// TODO handle this, library seems like maybe it doesn't work with function calls?
console.error(err);
}
} else {
const resp = await openai.chat.completions.create(
{ ...input, stream: false },
{
maxRetries: 0,
},
);
finalCompletion = resp;
promptTokens = resp.usage?.prompt_tokens ?? 0;
completionTokens = resp.usage?.completion_tokens ?? 0;
}
const timeToComplete = Date.now() - start;
const stats = modelStats[input.model as keyof typeof OpenAIChatModel];
let cost = undefined;
if (stats && promptTokens && completionTokens) {
cost = promptTokens * stats.promptTokenPrice + completionTokens * stats.completionTokenPrice;
}
return {
type: "success",
statusCode: 200,
value: finalCompletion,
timeToComplete,
promptTokens,
completionTokens,
cost,
};
} catch (error: unknown) {
console.error("ERROR IS", error);
if (error instanceof APIError) {
return {
type: "error",
message: error.message,
autoRetry: error.status === 429 || error.status === 503,
statusCode: error.status,
};
} else {
console.error(error);
return {
type: "error",
message: (error as Error).message,
autoRetry: true,
};
}
}
}

View File

@@ -1,7 +1,8 @@
import { type JSONSchema4 } from "json-schema"; import { type JSONSchema4 } from "json-schema";
import { type ModelProvider } from "../types"; import { type ModelProvider } from "../types";
import inputSchema from "./codegen/input.schema.json"; import inputSchema from "./codegen/input.schema.json";
import { type CompletionCreateParams } from "openai/resources/chat"; import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat";
import { getCompletion } from "./getCompletion";
const supportedModels = [ const supportedModels = [
"gpt-4-0613", "gpt-4-0613",
@@ -12,7 +13,13 @@ const supportedModels = [
type SupportedModel = (typeof supportedModels)[number]; type SupportedModel = (typeof supportedModels)[number];
const modelProvider: ModelProvider<SupportedModel, CompletionCreateParams> = { export type OpenaiChatModelProvider = ModelProvider<
SupportedModel,
CompletionCreateParams,
ChatCompletion
>;
const modelProvider: OpenaiChatModelProvider = {
name: "OpenAI ChatCompletion", name: "OpenAI ChatCompletion",
models: { models: {
"gpt-4-0613": { "gpt-4-0613": {
@@ -49,6 +56,7 @@ const modelProvider: ModelProvider<SupportedModel, CompletionCreateParams> = {
}, },
inputSchema: inputSchema as JSONSchema4, inputSchema: inputSchema as JSONSchema4,
shouldStream: (input) => input.stream ?? false, shouldStream: (input) => input.stream ?? false,
getCompletion,
}; };
export default modelProvider; export default modelProvider;

View File

@@ -1,14 +1,48 @@
import { type JSONSchema4 } from "json-schema"; import { type JSONSchema4 } from "json-schema";
import { type JsonValue } from "type-fest";
export type ModelProviderModel = { type ModelProviderModel = {
name: string; name: string;
learnMore: string; learnMore: string;
}; };
export type ModelProvider<SupportedModels extends string, InputSchema> = { export type CompletionResponse<T> =
| { type: "error"; message: string; autoRetry: boolean; statusCode?: number }
| {
type: "success";
value: T;
timeToComplete: number;
statusCode: number;
promptTokens?: number;
completionTokens?: number;
cost?: number;
};
export type ModelProvider<SupportedModels extends string, InputSchema, OutputSchema> = {
name: string; name: string;
models: Record<SupportedModels, ModelProviderModel>; models: Record<SupportedModels, ModelProviderModel>;
getModel: (input: InputSchema) => SupportedModels | null; getModel: (input: InputSchema) => SupportedModels | null;
shouldStream: (input: InputSchema) => boolean; shouldStream: (input: InputSchema) => boolean;
inputSchema: JSONSchema4; inputSchema: JSONSchema4;
getCompletion: (
input: InputSchema,
onStream: ((partialOutput: OutputSchema) => void) | null,
) => Promise<CompletionResponse<OutputSchema>>;
// This is just a convenience for type inference, don't use it at runtime
_outputSchema?: OutputSchema | null;
};
export type NormalizedOutput =
| {
type: "text";
value: string;
}
| {
type: "json";
value: JsonValue;
};
export type ModelProviderFrontend<ModelProviderT extends ModelProvider<any, any, any>> = {
normalizeOutput: (output: NonNullable<ModelProviderT["_outputSchema"]>) => NormalizedOutput;
}; };

View File

@@ -0,0 +1,19 @@
import "dotenv/config";
import { openai } from "../utils/openai";
const resp = await openai.chat.completions.create({
model: "gpt-3.5-turbo-0613",
stream: true,
messages: [
{
role: "user",
content: "count to 20",
},
],
});
for await (const part of resp) {
console.log("part", part);
}
console.log("final resp", resp);

View File

@@ -1,15 +1,18 @@
import { prisma } from "~/server/db"; import { prisma } from "~/server/db";
import defineTask from "./defineTask"; import defineTask from "./defineTask";
import { type CompletionResponse, getOpenAIChatCompletion } from "../utils/getCompletion";
import { sleep } from "../utils/sleep"; import { sleep } from "../utils/sleep";
import { generateChannel } from "~/utils/generateChannel"; import { generateChannel } from "~/utils/generateChannel";
import { runEvalsForOutput } from "../utils/evaluations"; import { runEvalsForOutput } from "../utils/evaluations";
import { type CompletionCreateParams } from "openai/resources/chat";
import { type Prisma } from "@prisma/client"; import { type Prisma } from "@prisma/client";
import parseConstructFn from "../utils/parseConstructFn"; import parseConstructFn from "../utils/parseConstructFn";
import hashPrompt from "../utils/hashPrompt"; import hashPrompt from "../utils/hashPrompt";
import { type JsonObject } from "type-fest"; import { type JsonObject } from "type-fest";
import modelProviders from "~/modelProviders"; import modelProviders from "~/modelProviders/modelProviders";
import { wsConnection } from "~/utils/wsConnection";
export type queryLLMJob = {
scenarioVariantCellId: string;
};
const MAX_AUTO_RETRIES = 10; const MAX_AUTO_RETRIES = 10;
const MIN_DELAY = 500; // milliseconds const MIN_DELAY = 500; // milliseconds
@@ -21,51 +24,6 @@ function calculateDelay(numPreviousTries: number): number {
return baseDelay + jitter; return baseDelay + jitter;
} }
const getCompletionWithRetries = async (
cellId: string,
payload: JsonObject,
channel?: string,
): Promise<CompletionResponse> => {
let modelResponse: CompletionResponse | null = null;
try {
for (let i = 0; i < MAX_AUTO_RETRIES; i++) {
modelResponse = await getOpenAIChatCompletion(
payload as unknown as CompletionCreateParams,
channel,
);
if (
(modelResponse.statusCode !== 429 && modelResponse.statusCode !== 503) ||
i === MAX_AUTO_RETRIES - 1
) {
return modelResponse;
}
const delay = calculateDelay(i);
await prisma.scenarioVariantCell.update({
where: { id: cellId },
data: {
errorMessage: "Rate limit exceeded",
statusCode: 429,
retryTime: new Date(Date.now() + delay),
},
});
// TODO: Maybe requeue the job so other jobs can run in the future?
await sleep(delay);
}
throw new Error("Max retries limit reached");
} catch (error: unknown) {
return {
statusCode: modelResponse?.statusCode ?? 500,
errorMessage: modelResponse?.errorMessage ?? (error as Error).message,
output: null,
timeToComplete: 0,
};
}
};
export type queryLLMJob = {
scenarioVariantCellId: string;
};
export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => { export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
const { scenarioVariantCellId } = task; const { scenarioVariantCellId } = task;
const cell = await prisma.scenarioVariantCell.findUnique({ const cell = await prisma.scenarioVariantCell.findUnique({
@@ -141,57 +99,67 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
const provider = modelProviders[prompt.modelProvider]; const provider = modelProviders[prompt.modelProvider];
const streamingEnabled = provider.shouldStream(prompt.modelInput); const streamingChannel = provider.shouldStream(prompt.modelInput) ? generateChannel() : null;
let streamingChannel;
if (streamingEnabled) { if (streamingChannel) {
streamingChannel = generateChannel();
// Save streaming channel so that UI can connect to it // Save streaming channel so that UI can connect to it
await prisma.scenarioVariantCell.update({ await prisma.scenarioVariantCell.update({
where: { id: scenarioVariantCellId }, where: { id: scenarioVariantCellId },
data: { streamingChannel }, data: { streamingChannel },
}); });
} }
const onStream = streamingChannel
? (partialOutput: (typeof provider)["_outputSchema"]) => {
wsConnection.emit("message", { channel: streamingChannel, payload: partialOutput });
}
: null;
const modelResponse = await getCompletionWithRetries( for (let i = 0; true; i++) {
scenarioVariantCellId, const response = await provider.getCompletion(prompt.modelInput, onStream);
prompt.modelInput as unknown as JsonObject, if (response.type === "success") {
streamingChannel, const inputHash = hashPrompt(prompt);
);
let modelOutput = null; const modelOutput = await prisma.modelOutput.create({
if (modelResponse.statusCode === 200) { data: {
const inputHash = hashPrompt(prompt); scenarioVariantCellId,
inputHash,
modelOutput = await prisma.modelOutput.create({ output: response.value as unknown as Prisma.InputJsonObject,
data: { timeToComplete: response.timeToComplete,
scenarioVariantCellId, promptTokens: response.promptTokens,
inputHash, completionTokens: response.completionTokens,
output: modelResponse.output as unknown as Prisma.InputJsonObject, cost: response.cost,
timeToComplete: modelResponse.timeToComplete,
promptTokens: modelResponse.promptTokens,
completionTokens: modelResponse.completionTokens,
cost: modelResponse.cost,
},
});
}
await prisma.scenarioVariantCell.update({
where: { id: scenarioVariantCellId },
data: {
statusCode: modelResponse.statusCode,
errorMessage: modelResponse.errorMessage,
streamingChannel: null,
retrievalStatus: modelOutput ? "COMPLETE" : "ERROR",
modelOutput: {
connect: {
id: modelOutput?.id,
}, },
}, });
},
});
if (modelOutput) { await prisma.scenarioVariantCell.update({
await runEvalsForOutput(variant.experimentId, scenario, modelOutput); where: { id: scenarioVariantCellId },
data: {
statusCode: response.statusCode,
retrievalStatus: "COMPLETE",
},
});
await runEvalsForOutput(variant.experimentId, scenario, modelOutput);
break;
} else {
const shouldRetry = response.autoRetry && i < MAX_AUTO_RETRIES;
const delay = calculateDelay(i);
await prisma.scenarioVariantCell.update({
where: { id: scenarioVariantCellId },
data: {
errorMessage: response.message,
statusCode: response.statusCode,
retryTime: shouldRetry ? new Date(Date.now() + delay) : null,
retrievalStatus: shouldRetry ? "PENDING" : "ERROR",
},
});
if (shouldRetry) {
await sleep(delay);
} else {
break;
}
}
} }
}); });

View File

@@ -1,14 +1,3 @@
export type JSONSerializable =
| string
| number
| boolean
| null
| JSONSerializable[]
| { [key: string]: JSONSerializable };
// Placeholder for now
export type OpenAIChatConfig = NonNullable<JSONSerializable>;
export enum OpenAIChatModel { export enum OpenAIChatModel {
"gpt-4" = "gpt-4", "gpt-4" = "gpt-4",
"gpt-4-0613" = "gpt-4-0613", "gpt-4-0613" = "gpt-4-0613",

View File

@@ -1,5 +1,3 @@
import { type JSONSerializable } from "../types";
export type VariableMap = Record<string, string>; export type VariableMap = Record<string, string>;
// Escape quotes to match the way we encode JSON // Escape quotes to match the way we encode JSON
@@ -15,24 +13,3 @@ export function escapeRegExp(str: string) {
export function fillTemplate(template: string, variables: VariableMap): string { export function fillTemplate(template: string, variables: VariableMap): string {
return template.replace(/{{\s*(\w+)\s*}}/g, (_, key: string) => variables[key] || ""); return template.replace(/{{\s*(\w+)\s*}}/g, (_, key: string) => variables[key] || "");
} }
export function fillTemplateJson<T extends JSONSerializable>(
template: T,
variables: VariableMap,
): T {
if (typeof template === "string") {
return fillTemplate(template, variables) as T;
} else if (Array.isArray(template)) {
return template.map((item) => fillTemplateJson(item, variables)) as T;
} else if (typeof template === "object" && template !== null) {
return Object.keys(template).reduce(
(acc, key) => {
acc[key] = fillTemplateJson(template[key] as JSONSerializable, variables);
return acc;
},
{} as { [key: string]: JSONSerializable } & T,
);
} else {
return template;
}
}

View File

@@ -46,6 +46,7 @@ export const generateNewCell = async (variantId: string, scenarioId: string) =>
testScenarioId: scenarioId, testScenarioId: scenarioId,
statusCode: 400, statusCode: 400,
errorMessage: parsedConstructFn.error, errorMessage: parsedConstructFn.error,
retrievalStatus: "ERROR",
}, },
}); });
} }
@@ -57,6 +58,7 @@ export const generateNewCell = async (variantId: string, scenarioId: string) =>
promptVariantId: variantId, promptVariantId: variantId,
testScenarioId: scenarioId, testScenarioId: scenarioId,
prompt: parsedConstructFn.modelInput as unknown as Prisma.InputJsonValue, prompt: parsedConstructFn.modelInput as unknown as Prisma.InputJsonValue,
retrievalStatus: "PENDING",
}, },
include: { include: {
modelOutput: true, modelOutput: true,
@@ -83,6 +85,10 @@ export const generateNewCell = async (variantId: string, scenarioId: string) =>
updatedAt: matchingModelOutput.updatedAt, updatedAt: matchingModelOutput.updatedAt,
}, },
}); });
await prisma.scenarioVariantCell.update({
where: { id: cell.id },
data: { retrievalStatus: "COMPLETE" },
});
} else { } else {
cell = await queueLLMRetrievalTask(cell.id); cell = await queueLLMRetrievalTask(cell.id);
} }

View File

@@ -1,107 +0,0 @@
/* eslint-disable @typescript-eslint/no-unsafe-call */
import { isObject } from "lodash-es";
import { streamChatCompletion } from "./openai";
import { wsConnection } from "~/utils/wsConnection";
import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat";
import { type SupportedModel, type OpenAIChatModel } from "../types";
import { env } from "~/env.mjs";
import { countOpenAIChatTokens } from "~/utils/countTokens";
import { rateLimitErrorMessage } from "~/sharedStrings";
import { modelStats } from "../../modelProviders/modelStats";
export type CompletionResponse = {
output: ChatCompletion | null;
statusCode: number;
errorMessage: string | null;
timeToComplete: number;
promptTokens?: number;
completionTokens?: number;
cost?: number;
};
export async function getOpenAIChatCompletion(
payload: CompletionCreateParams,
channel?: string,
): Promise<CompletionResponse> {
// If functions are enabled, disable streaming so that we get the full response with token counts
if (payload.functions?.length) payload.stream = false;
const start = Date.now();
const response = await fetch("https://api.openai.com/v1/chat/completions", {
method: "POST",
headers: {
"Content-Type": "application/json",
Authorization: `Bearer ${env.OPENAI_API_KEY}`,
},
body: JSON.stringify(payload),
});
const resp: CompletionResponse = {
output: null,
errorMessage: null,
statusCode: response.status,
timeToComplete: 0,
};
try {
if (payload.stream) {
const completion = streamChatCompletion(payload as unknown as CompletionCreateParams);
let finalOutput: ChatCompletion | null = null;
await (async () => {
for await (const partialCompletion of completion) {
finalOutput = partialCompletion;
wsConnection.emit("message", { channel, payload: partialCompletion });
}
})().catch((err) => console.error(err));
if (finalOutput) {
resp.output = finalOutput;
resp.timeToComplete = Date.now() - start;
}
} else {
resp.timeToComplete = Date.now() - start;
resp.output = await response.json();
}
if (!response.ok) {
if (response.status === 429) {
resp.errorMessage = rateLimitErrorMessage;
} else if (
isObject(resp.output) &&
"error" in resp.output &&
isObject(resp.output.error) &&
"message" in resp.output.error
) {
// If it's an object, try to get the error message
resp.errorMessage = resp.output.error.message?.toString() ?? "Unknown error";
}
}
if (isObject(resp.output) && "usage" in resp.output) {
const usage = resp.output.usage as unknown as ChatCompletion.Usage;
resp.promptTokens = usage.prompt_tokens;
resp.completionTokens = usage.completion_tokens;
} else if (isObject(resp.output) && "choices" in resp.output) {
const model = payload.model as unknown as OpenAIChatModel;
resp.promptTokens = countOpenAIChatTokens(model, payload.messages);
const choices = resp.output.choices as unknown as ChatCompletion.Choice[];
const message = choices[0]?.message;
if (message) {
const messages = [message];
resp.completionTokens = countOpenAIChatTokens(model, messages);
}
}
const stats = modelStats[resp.output?.model as SupportedModel];
if (stats && resp.promptTokens && resp.completionTokens) {
resp.cost =
resp.promptTokens * stats.promptTokenPrice +
resp.completionTokens * stats.completionTokenPrice;
}
} catch (e) {
console.error(e);
if (response.ok) {
resp.errorMessage = "Failed to parse response";
}
}
return resp;
}

View File

@@ -1,64 +1,5 @@
import { omit } from "lodash-es";
import { env } from "~/env.mjs"; import { env } from "~/env.mjs";
import OpenAI from "openai"; import OpenAI from "openai";
import {
type ChatCompletion,
type ChatCompletionChunk,
type CompletionCreateParams,
} from "openai/resources/chat";
export const openai = new OpenAI({ apiKey: env.OPENAI_API_KEY }); export const openai = new OpenAI({ apiKey: env.OPENAI_API_KEY });
export const mergeStreamedChunks = (
base: ChatCompletion | null,
chunk: ChatCompletionChunk,
): ChatCompletion => {
if (base === null) {
return mergeStreamedChunks({ ...chunk, choices: [] }, chunk);
}
const choices = [...base.choices];
for (const choice of chunk.choices) {
const baseChoice = choices.find((c) => c.index === choice.index);
if (baseChoice) {
baseChoice.finish_reason = choice.finish_reason ?? baseChoice.finish_reason;
baseChoice.message = baseChoice.message ?? { role: "assistant" };
if (choice.delta?.content)
baseChoice.message.content =
((baseChoice.message.content as string) ?? "") + (choice.delta.content ?? "");
if (choice.delta?.function_call) {
const fnCall = baseChoice.message.function_call ?? {};
fnCall.name =
((fnCall.name as string) ?? "") + ((choice.delta.function_call.name as string) ?? "");
fnCall.arguments =
((fnCall.arguments as string) ?? "") +
((choice.delta.function_call.arguments as string) ?? "");
}
} else {
choices.push({ ...omit(choice, "delta"), message: { role: "assistant", ...choice.delta } });
}
}
const merged: ChatCompletion = {
...base,
choices,
};
return merged;
};
export const streamChatCompletion = async function* (body: CompletionCreateParams) {
// eslint-disable-next-line @typescript-eslint/no-unsafe-call
const resp = await openai.chat.completions.create({
...body,
stream: true,
});
let mergedChunks: ChatCompletion | null = null;
for await (const part of resp) {
mergedChunks = mergeStreamedChunks(mergedChunks, part);
yield mergedChunks;
}
};

View File

@@ -1,4 +1,4 @@
import modelProviders from "~/modelProviders"; import modelProviders from "~/modelProviders/modelProviders";
import ivm from "isolated-vm"; import ivm from "isolated-vm";
import { isObject, isString } from "lodash-es"; import { isObject, isString } from "lodash-es";
import { type JsonObject } from "type-fest"; import { type JsonObject } from "type-fest";

View File

@@ -22,7 +22,6 @@ export function SyncAppStore() {
const setApi = useAppStore((state) => state.setApi); const setApi = useAppStore((state) => state.setApi);
useEffect(() => { useEffect(() => {
console.log("setting api", utils);
setApi(utils); setApi(utils);
}, [utils, setApi]); }, [utils, setApi]);

View File

@@ -9,7 +9,7 @@ interface GPTTokensMessageItem {
} }
export const countOpenAIChatTokens = ( export const countOpenAIChatTokens = (
model: OpenAIChatModel, model: keyof typeof OpenAIChatModel,
messages: ChatCompletion.Choice.Message[], messages: ChatCompletion.Choice.Message[],
) => { ) => {
return new GPTTokens({ model, messages: messages as unknown as GPTTokensMessageItem[] }) return new GPTTokens({ model, messages: messages as unknown as GPTTokensMessageItem[] })

View File

@@ -1,13 +1,12 @@
import { type ChatCompletion } from "openai/resources/chat";
import { useRef, useState, useEffect } from "react"; import { useRef, useState, useEffect } from "react";
import { io, type Socket } from "socket.io-client"; import { io, type Socket } from "socket.io-client";
import { env } from "~/env.mjs"; import { env } from "~/env.mjs";
const url = env.NEXT_PUBLIC_SOCKET_URL; const url = env.NEXT_PUBLIC_SOCKET_URL;
export default function useSocket(channel?: string | null) { export default function useSocket<T>(channel?: string | null) {
const socketRef = useRef<Socket>(); const socketRef = useRef<Socket>();
const [message, setMessage] = useState<ChatCompletion | null>(null); const [message, setMessage] = useState<T | null>(null);
useEffect(() => { useEffect(() => {
if (!channel) return; if (!channel) return;
@@ -21,7 +20,7 @@ export default function useSocket(channel?: string | null) {
socketRef.current?.emit("join", channel); socketRef.current?.emit("join", channel);
// Listen for 'message' events // Listen for 'message' events
socketRef.current?.on("message", (message: ChatCompletion) => { socketRef.current?.on("message", (message: T) => {
setMessage(message); setMessage(message);
}); });
}); });

1
src/utils/utils.ts Normal file
View File

@@ -0,0 +1 @@
export const truthyFilter = <T>(x: T | null | undefined): x is T => Boolean(x);