store model and use to calculate completion costs

This commit is contained in:
Kyle Corbitt
2023-07-14 11:06:07 -07:00
parent 0371dacfca
commit a5378b106b
14 changed files with 67 additions and 73 deletions

View File

@@ -7,9 +7,7 @@ test.skip("constructPrompt", async () => {
constructFn: `prompt = { "fooz": "bar" }`,
},
{
variableValues: {
foo: "bar",
},
foo: "bar",
},
);

View File

@@ -6,10 +6,8 @@ const isolate = new ivm.Isolate({ memoryLimit: 128 });
export async function constructPrompt(
variant: Pick<PromptVariant, "constructFn">,
testScenario: Pick<TestScenario, "variableValues">,
scenario: TestScenario["variableValues"],
): Promise<JSONSerializable> {
const scenario = testScenario.variableValues as JSONSerializable;
const code = `
const scenario = ${JSON.stringify(scenario, null, 2)};
let prompt

View File

@@ -0,0 +1,6 @@
export default function userError(message: string): { status: "error"; message: string } {
return {
status: "error",
message,
};
}

View File

@@ -4,14 +4,11 @@ import { Prisma } from "@prisma/client";
import { streamChatCompletion } from "./openai";
import { wsConnection } from "~/utils/wsConnection";
import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat";
import { type JSONSerializable, OpenAIChatModel } from "../types";
import { type OpenAIChatModel } from "../types";
import { env } from "~/env.mjs";
import { countOpenAIChatTokens } from "~/utils/countTokens";
import { getModelName } from "./getModelName";
import { rateLimitErrorMessage } from "~/sharedStrings";
env;
type CompletionResponse = {
output: Prisma.InputJsonValue | typeof Prisma.JsonNull;
statusCode: number;
@@ -22,35 +19,7 @@ type CompletionResponse = {
};
export async function getCompletion(
payload: JSONSerializable,
channel?: string,
): Promise<CompletionResponse> {
const modelName = getModelName(payload);
if (!modelName)
return {
output: Prisma.JsonNull,
statusCode: 400,
errorMessage: "Invalid payload provided",
timeToComplete: 0,
};
if (modelName in OpenAIChatModel) {
return getOpenAIChatCompletion(
payload as unknown as CompletionCreateParams,
env.OPENAI_API_KEY,
channel,
);
}
return {
output: Prisma.JsonNull,
statusCode: 400,
errorMessage: "Invalid model provided",
timeToComplete: 0,
};
}
export async function getOpenAIChatCompletion(
payload: CompletionCreateParams,
apiKey: string,
channel?: string,
): Promise<CompletionResponse> {
// If functions are enabled, disable streaming so that we get the full response with token counts
@@ -60,7 +29,7 @@ export async function getOpenAIChatCompletion(
method: "POST",
headers: {
"Content-Type": "application/json",
Authorization: `Bearer ${apiKey}`,
Authorization: `Bearer ${env.OPENAI_API_KEY}`,
},
body: JSON.stringify(payload),
});

View File

@@ -1,9 +0,0 @@
import { isObject } from "lodash";
import { type JSONSerializable, type SupportedModel } from "../types";
import { type Prisma } from "@prisma/client";
export function getModelName(config: JSONSerializable | Prisma.JsonValue): SupportedModel | null {
if (!isObject(config)) return null;
if ("model" in config && typeof config.model === "string") return config.model as SupportedModel;
return null;
}