store model and use to calculate completion costs
@@ -7,9 +7,7 @@ test.skip("constructPrompt", async () => {
       constructFn: `prompt = { "fooz": "bar" }`,
     },
     {
-      variableValues: {
-        foo: "bar",
-      },
+      foo: "bar",
     },
   );
@@ -6,10 +6,8 @@ const isolate = new ivm.Isolate({ memoryLimit: 128 });

 export async function constructPrompt(
   variant: Pick<PromptVariant, "constructFn">,
-  testScenario: Pick<TestScenario, "variableValues">,
+  scenario: TestScenario["variableValues"],
 ): Promise<JSONSerializable> {
-  const scenario = testScenario.variableValues as JSONSerializable;
-
   const code = `
     const scenario = ${JSON.stringify(scenario, null, 2)};
     let prompt
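Note: after this change constructPrompt takes the variableValues object directly instead of the whole TestScenario. A minimal sketch of the new call shape (the argument literals are illustrative, borrowed from the test above):

// Old call shape: pass the scenario wrapper, unwrapped inside constructPrompt.
// await constructPrompt(variant, testScenario);

// New call shape: pass the variable values directly.
const prompt = await constructPrompt(
  { constructFn: `prompt = { "fooz": "bar" }` },
  { foo: "bar" },
);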
src/server/utils/error.ts (new file, +6)
@@ -0,0 +1,6 @@
+export default function userError(message: string): { status: "error"; message: string } {
+  return {
+    status: "error",
+    message,
+  };
+}
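The new userError helper standardizes error payloads. A hedged usage sketch (the import alias follows the repo's `~/` convention seen elsewhere in this commit; the call site itself is hypothetical):

import userError from "~/server/utils/error";

// Hypothetical call site: return a typed error payload instead of throwing.
return userError("Invalid model provided");
// => { status: "error", message: "Invalid model provided" }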
@@ -4,14 +4,11 @@ import { Prisma } from "@prisma/client";
 import { streamChatCompletion } from "./openai";
 import { wsConnection } from "~/utils/wsConnection";
 import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat";
-import { type JSONSerializable, OpenAIChatModel } from "../types";
+import { type OpenAIChatModel } from "../types";
 import { env } from "~/env.mjs";
 import { countOpenAIChatTokens } from "~/utils/countTokens";
-import { getModelName } from "./getModelName";
 import { rateLimitErrorMessage } from "~/sharedStrings";

-env;
-
 type CompletionResponse = {
   output: Prisma.InputJsonValue | typeof Prisma.JsonNull;
   statusCode: number;
@@ -22,35 +19,7 @@ type CompletionResponse = {
 };

-export async function getCompletion(
-  payload: JSONSerializable,
-  channel?: string,
-): Promise<CompletionResponse> {
-  const modelName = getModelName(payload);
-  if (!modelName)
-    return {
-      output: Prisma.JsonNull,
-      statusCode: 400,
-      errorMessage: "Invalid payload provided",
-      timeToComplete: 0,
-    };
-  if (modelName in OpenAIChatModel) {
-    return getOpenAIChatCompletion(
-      payload as unknown as CompletionCreateParams,
-      env.OPENAI_API_KEY,
-      channel,
-    );
-  }
-  return {
-    output: Prisma.JsonNull,
-    statusCode: 400,
-    errorMessage: "Invalid model provided",
-    timeToComplete: 0,
-  };
-}
-
 export async function getOpenAIChatCompletion(
   payload: CompletionCreateParams,
-  apiKey: string,
   channel?: string,
 ): Promise<CompletionResponse> {
   // If functions are enabled, disable streaming so that we get the full response with token counts
@@ -60,7 +29,7 @@ export async function getOpenAIChatCompletion(
     method: "POST",
     headers: {
       "Content-Type": "application/json",
-      Authorization: `Bearer ${apiKey}`,
+      Authorization: `Bearer ${env.OPENAI_API_KEY}`,
     },
     body: JSON.stringify(payload),
   });
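With the apiKey parameter gone, getOpenAIChatCompletion reads env.OPENAI_API_KEY itself. A sketch of the updated call, assuming a caller that already holds a CompletionCreateParams payload (the status check is illustrative):

// Old: the key was threaded through as an argument.
// const res = await getOpenAIChatCompletion(payload, env.OPENAI_API_KEY, channel);

// New: only the payload and the optional websocket channel are passed.
const res: CompletionResponse = await getOpenAIChatCompletion(payload, channel);
if (res.statusCode !== 200) console.error(res.errorMessage);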
@@ -1,9 +0,0 @@
-import { isObject } from "lodash";
-import { type JSONSerializable, type SupportedModel } from "../types";
-import { type Prisma } from "@prisma/client";
-
-export function getModelName(config: JSONSerializable | Prisma.JsonValue): SupportedModel | null {
-  if (!isObject(config)) return null;
-  if ("model" in config && typeof config.model === "string") return config.model as SupportedModel;
-  return null;
-}
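getModelName.ts is deleted along with getCompletion, so payload-based model detection no longer lives in this module. If a call site still needs the model string off an arbitrary JSON payload, an inline equivalent of the removed helper would look like this (a reconstruction for reference, not code from this commit):

import { isObject } from "lodash";

// Inline equivalent of the deleted getModelName helper (hypothetical call site).
const model =
  isObject(payload) && "model" in payload && typeof payload.model === "string"
    ? (payload.model as SupportedModel)
    : null;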