Compare commits
6 Commits
autoformat...function-u

| Author | SHA1 | Date |
|---|---|---|
|  | 731406d1f4 |  |
|  | 3c59e4b774 |  |
|  | 972b1f2333 |  |
|  | 7321f3deda |  |
|  | 2bd41fdfbf |  |
|  | a5378b106b |  |
51 .github/workflows/ci.yaml vendored Normal file

@@ -0,0 +1,51 @@
+name: CI checks
+
+on:
+  pull_request:
+    branches: [main]
+
+jobs:
+  run-checks:
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Check out code
+        uses: actions/checkout@v2
+
+      - name: Set up Node.js
+        uses: actions/setup-node@v2
+        with:
+          node-version: "20"
+
+      - uses: pnpm/action-setup@v2
+        name: Install pnpm
+        id: pnpm-install
+        with:
+          version: 8.6.1
+          run_install: false
+
+      - name: Get pnpm store directory
+        id: pnpm-cache
+        shell: bash
+        run: |
+          echo "STORE_PATH=$(pnpm store path)" >> $GITHUB_OUTPUT
+
+      - uses: actions/cache@v3
+        name: Setup pnpm cache
+        with:
+          path: ${{ steps.pnpm-cache.outputs.STORE_PATH }}
+          key: ${{ runner.os }}-pnpm-store-${{ hashFiles('**/pnpm-lock.yaml') }}
+          restore-keys: |
+            ${{ runner.os }}-pnpm-store-
+
+      - name: Install Dependencies
+        run: pnpm install
+
+      - name: Check types
+        run: pnpm tsc
+
+      - name: Lint
+        run: SKIP_ENV_VALIDATION=1 pnpm lint
+
+      - name: Check prettier
+        run: pnpm prettier . --check
@@ -26,10 +26,11 @@ model Experiment {
 }
 
 model PromptVariant {
   id String @id @default(uuid()) @db.Uuid
   label String
   constructFn String
+  model String @default("gpt-3.5-turbo")
 
   uiId String @default(uuid()) @db.Uuid
   visible Boolean @default(true)
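Note: the schema change above gives each `PromptVariant` its own `model` column with a `gpt-3.5-turbo` default, and the routers later in this compare read it as `variant.model`. A minimal sketch of reading the new column, assuming a regenerated Prisma client and at least one visible variant (the query itself is illustrative, not from this diff):

```ts
import { PrismaClient } from "@prisma/client";

const prisma = new PrismaClient();

// `model` is a required String column with a default, so it is always
// present on the record the client returns.
async function logVariantModel() {
  const variant = await prisma.promptVariant.findFirstOrThrow({
    where: { visible: true },
  });
  console.log(variant.model); // "gpt-3.5-turbo" unless the variant overrides it
}
```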
@@ -37,10 +37,6 @@ export default function OutputCell({
   // if (variant.config === null || Object.keys(variant.config).length === 0)
   // disabledReason = "Save your prompt variant to see output";
 
-  // const model = getModelName(variant.config as JSONSerializable);
-  // TODO: Temporarily hardcoding this while we get other stuff working
-  const model = "gpt-3.5-turbo";
-
   const outputMutation = api.outputs.get.useMutation();
 
   const [output, setOutput] = useState<RouterOutputs["outputs"]["get"]>(null);
@@ -140,7 +136,7 @@ export default function OutputCell({
         { maxLength: 40 },
       )}
       </SyntaxHighlighter>
-      <OutputStats model={model} modelOutput={output} scenario={scenario} />
+      <OutputStats model={variant.model} modelOutput={output} scenario={scenario} />
     </Box>
   );
 }
@@ -150,7 +146,7 @@ export default function OutputCell({
   return (
     <Flex w="100%" h="100%" direction="column" justifyContent="space-between" whiteSpace="pre-wrap">
       {contentToDisplay}
-      {output && <OutputStats model={model} modelOutput={output} scenario={scenario} />}
+      {output && <OutputStats model={variant.model} modelOutput={output} scenario={scenario} />}
     </Flex>
   );
 }
@@ -17,7 +17,7 @@ export const OutputStats = ({
   modelOutput,
   scenario,
 }: {
-  model: SupportedModel | null;
+  model: SupportedModel | string | null;
   modelOutput: ModelOutput;
   scenario: Scenario;
 }) => {
@@ -1,9 +1,10 @@
-import { Box, Button, HStack, Tooltip, useToast } from "@chakra-ui/react";
+import { Box, Button, HStack, Tooltip, VStack, useToast } from "@chakra-ui/react";
 import { useRef, useEffect, useState, useCallback } from "react";
 import { useHandledAsyncCallback, useModifierKeyLabel } from "~/utils/hooks";
 import { type PromptVariant } from "./types";
 import { api } from "~/utils/api";
 import { useAppStore } from "~/state/store";
+import { editorBackground } from "~/state/sharedVariantEditor.slice";
 // import openAITypes from "~/codegen/openai.types.ts.txt";
 
 export default function VariantConfigEditor(props: { variant: PromptVariant }) {
@@ -64,10 +65,17 @@ export default function VariantConfigEditor(props: { variant: PromptVariant }) {
       return;
     }
 
-    await replaceVariant.mutateAsync({
+    const resp = await replaceVariant.mutateAsync({
       id: props.variant.id,
       constructFn: currentFn,
     });
+    if (resp.status === "error") {
+      return toast({
+        title: "Error saving variant",
+        description: resp.message,
+        status: "error",
+      });
+    }
 
     await utils.promptVariants.list.invalidate();
 
@@ -122,21 +130,21 @@ export default function VariantConfigEditor(props: { variant: PromptVariant }) {
     /* eslint-disable-next-line react-hooks/exhaustive-deps */
   }, [monaco, editorId]);
 
-  // useEffect(() => {
-  //   const savedConfigChanged = lastSavedFn !== savedConfig;
-
-  //   lastSavedFn = savedConfig;
-
-  //   if (savedConfigChanged && editorRef.current?.getValue() !== savedConfig) {
-  //     editorRef.current?.setValue(savedConfig);
-  //   }
-
-  //   checkForChanges();
-  // }, [savedConfig, checkForChanges]);
-
   return (
     <Box w="100%" pos="relative">
-      <div id={editorId} style={{ height: "300px", width: "100%" }}></div>
+      <VStack
+        spacing={0}
+        align="stretch"
+        fontSize="xs"
+        fontWeight="bold"
+        color="gray.600"
+        py={2}
+        bgColor={editorBackground}
+      >
+        <code>{`function constructPrompt(scenario: Scenario): Prompt {`}</code>
+        <div id={editorId} style={{ height: "300px", width: "100%" }}></div>
+        <code>{`return prompt; }`}</code>
+      </VStack>
       {isChanged && (
         <HStack pos="absolute" bottom={2} right={2}>
           <Button
@@ -75,6 +75,7 @@ export const experimentsRouter = createTRPCRouter({
           stream: true,
           messages: [{ role: "system", content: "Return 'Ready to go!'" }],
         }`,
+        model: "gpt-3.5-turbo-0613",
       },
     }),
     prisma.testScenario.create({
@@ -6,6 +6,7 @@ import type { Prisma } from "@prisma/client";
 import { reevaluateVariant } from "~/server/utils/evaluations";
 import { getCompletion } from "~/server/utils/getCompletion";
 import { constructPrompt } from "~/server/utils/constructPrompt";
+import { type CompletionCreateParams } from "openai/resources/chat";
 
 export const modelOutputsRouter = createTRPCRouter({
   get: publicProcedure
@@ -43,7 +44,7 @@ export const modelOutputsRouter = createTRPCRouter({
 
       if (!variant || !scenario) return null;
 
-      const prompt = await constructPrompt(variant, scenario);
+      const prompt = await constructPrompt(variant, scenario.variableValues);
 
       const inputHash = crypto.createHash("sha256").update(JSON.stringify(prompt)).digest("hex");
 
@@ -65,7 +66,10 @@ export const modelOutputsRouter = createTRPCRouter({
         };
       } else {
         try {
-          modelResponse = await getCompletion(prompt, input.channel);
+          modelResponse = await getCompletion(
+            prompt as unknown as CompletionCreateParams,
+            input.channel,
+          );
         } catch (e) {
           console.error(e);
           throw e;
@@ -1,6 +1,10 @@
+import { isObject } from "lodash";
 import { z } from "zod";
 import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
 import { prisma } from "~/server/db";
+import { OpenAIChatModel } from "~/server/types";
+import { constructPrompt } from "~/server/utils/constructPrompt";
+import userError from "~/server/utils/error";
 import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
 import { calculateTokenCost } from "~/utils/calculateTokenCost";
 
@@ -57,14 +61,10 @@ export const promptVariantsRouter = createTRPCRouter({
         },
       });
 
-      // TODO: fix this
-      const model = "gpt-3.5-turbo-0613";
-      // const model = getModelName(variant.config);
-
       const promptTokens = overallTokens._sum?.promptTokens ?? 0;
-      const overallPromptCost = calculateTokenCost(model, promptTokens);
+      const overallPromptCost = calculateTokenCost(variant.model, promptTokens);
       const completionTokens = overallTokens._sum?.completionTokens ?? 0;
-      const overallCompletionCost = calculateTokenCost(model, completionTokens, true);
+      const overallCompletionCost = calculateTokenCost(variant.model, completionTokens, true);
 
       const overallCost = overallPromptCost + overallCompletionCost;
 
@@ -106,6 +106,7 @@ export const promptVariantsRouter = createTRPCRouter({
           label: `Prompt Variant ${largestSortIndex + 2}`,
           sortIndex: (lastVariant?.sortIndex ?? 0) + 1,
           constructFn: lastVariant?.constructFn ?? "",
+          model: lastVariant?.model ?? "gpt-3.5-turbo",
         },
       });
 
@@ -185,6 +186,27 @@ export const promptVariantsRouter = createTRPCRouter({
        throw new Error(`Prompt Variant with id ${input.id} does not exist`);
      }
 
+      let model = existing.model;
+      try {
+        const contructedPrompt = await constructPrompt({ constructFn: input.constructFn }, null);
+
+        if (!isObject(contructedPrompt)) {
+          return userError("Prompt is not an object");
+        }
+        if (!("model" in contructedPrompt)) {
+          return userError("Prompt does not define a model");
+        }
+        if (
+          typeof contructedPrompt.model !== "string" ||
+          !(contructedPrompt.model in OpenAIChatModel)
+        ) {
+          return userError("Prompt defines an invalid model");
+        }
+        model = contructedPrompt.model;
+      } catch (e) {
+        return userError((e as Error).message);
+      }
+
      // Create a duplicate with only the config changed
      const newVariant = await prisma.promptVariant.create({
        data: {
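Note: the validation added above only accepts a constructed prompt that is an object whose `model` is a key of `OpenAIChatModel`. A hypothetical `constructFn` that would pass the check (the `scenario.text` variable is illustrative, not taken from this diff):

```ts
// A constructFn is plain source text; constructPrompt wraps it with
// `const scenario = ...` and `let prompt` before running it in the sandbox.
const constructFn = `
  prompt = {
    model: "gpt-3.5-turbo-0613",
    messages: [{ role: "user", content: "Summarize: " + scenario.text }],
  };
`;
```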
@@ -193,11 +215,12 @@ export const promptVariantsRouter = createTRPCRouter({
          sortIndex: existing.sortIndex,
          uiId: existing.uiId,
          constructFn: input.constructFn,
+          model,
        },
      });
 
      // Hide anything with the same uiId besides the new one
-      const hideOldVariantsAction = prisma.promptVariant.updateMany({
+      const hideOldVariants = prisma.promptVariant.updateMany({
        where: {
          uiId: existing.uiId,
          id: {
@@ -209,12 +232,9 @@ export const promptVariantsRouter = createTRPCRouter({
        },
      });
 
-      await prisma.$transaction([
-        hideOldVariantsAction,
-        recordExperimentUpdated(existing.experimentId),
-      ]);
+      await prisma.$transaction([hideOldVariants, recordExperimentUpdated(existing.experimentId)]);
 
-      return newVariant;
+      return { status: "ok" } as const;
    }),
 
  reorder: publicProcedure
@@ -7,9 +7,7 @@ test.skip("constructPrompt", async () => {
      constructFn: `prompt = { "fooz": "bar" }`,
    },
    {
-      variableValues: {
-        foo: "bar",
-      },
+      foo: "bar",
    },
  );
 
@@ -6,10 +6,8 @@ const isolate = new ivm.Isolate({ memoryLimit: 128 });
 
 export async function constructPrompt(
   variant: Pick<PromptVariant, "constructFn">,
-  testScenario: Pick<TestScenario, "variableValues">,
+  scenario: TestScenario["variableValues"],
 ): Promise<JSONSerializable> {
-  const scenario = testScenario.variableValues as JSONSerializable;
-
   const code = `
     const scenario = ${JSON.stringify(scenario, null, 2)};
     let prompt
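Note: with the new signature, callers hand `constructPrompt` just the scenario's variable values (or `null` when no scenario applies) instead of a whole `TestScenario` record. The two updated call sites in this compare, excerpted for reference (the surrounding `variant`, `scenario`, and `input` come from the routers' own context):

```ts
// modelOutputs router: generate output for a concrete scenario.
const prompt = await constructPrompt(variant, scenario.variableValues);

// promptVariants router: construct the prompt only to validate its `model` field.
const constructed = await constructPrompt({ constructFn: input.constructFn }, null);
```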
6 src/server/utils/error.ts Normal file

@@ -0,0 +1,6 @@
+export default function userError(message: string): { status: "error"; message: string } {
+  return {
+    status: "error",
+    message,
+  };
+}
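Note: `userError` packages a validation failure as data rather than a thrown TRPC error, and the `update` mutation pairs it with `{ status: "ok" } as const` on success. A minimal sketch of narrowing the resulting union on the caller side (the type alias and handler are illustrative):

```ts
type MutationResult =
  | { status: "ok" }
  | { status: "error"; message: string };

function handleResult(resp: MutationResult) {
  if (resp.status === "error") {
    // Same pattern as the toast in VariantConfigEditor: report and bail out.
    console.error(resp.message);
    return;
  }
  // resp.status is narrowed to "ok" here; safe to refetch or invalidate queries.
}
```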
@@ -4,14 +4,11 @@ import { Prisma } from "@prisma/client";
 import { streamChatCompletion } from "./openai";
 import { wsConnection } from "~/utils/wsConnection";
 import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat";
-import { type JSONSerializable, OpenAIChatModel } from "../types";
+import { type OpenAIChatModel } from "../types";
 import { env } from "~/env.mjs";
 import { countOpenAIChatTokens } from "~/utils/countTokens";
-import { getModelName } from "./getModelName";
 import { rateLimitErrorMessage } from "~/sharedStrings";
 
-env;
-
 type CompletionResponse = {
   output: Prisma.InputJsonValue | typeof Prisma.JsonNull;
   statusCode: number;
@@ -22,35 +19,7 @@ type CompletionResponse = {
 };
 
 export async function getCompletion(
-  payload: JSONSerializable,
-  channel?: string,
-): Promise<CompletionResponse> {
-  const modelName = getModelName(payload);
-  if (!modelName)
-    return {
-      output: Prisma.JsonNull,
-      statusCode: 400,
-      errorMessage: "Invalid payload provided",
-      timeToComplete: 0,
-    };
-  if (modelName in OpenAIChatModel) {
-    return getOpenAIChatCompletion(
-      payload as unknown as CompletionCreateParams,
-      env.OPENAI_API_KEY,
-      channel,
-    );
-  }
-  return {
-    output: Prisma.JsonNull,
-    statusCode: 400,
-    errorMessage: "Invalid model provided",
-    timeToComplete: 0,
-  };
-}
-
-export async function getOpenAIChatCompletion(
   payload: CompletionCreateParams,
-  apiKey: string,
   channel?: string,
 ): Promise<CompletionResponse> {
   // If functions are enabled, disable streaming so that we get the full response with token counts
@@ -60,7 +29,7 @@ export async function getOpenAIChatCompletion(
    method: "POST",
    headers: {
      "Content-Type": "application/json",
-      Authorization: `Bearer ${apiKey}`,
+      Authorization: `Bearer ${env.OPENAI_API_KEY}`,
    },
    body: JSON.stringify(payload),
  });
@@ -1,9 +0,0 @@
-import { isObject } from "lodash";
-import { type JSONSerializable, type SupportedModel } from "../types";
-import { type Prisma } from "@prisma/client";
-
-export function getModelName(config: JSONSerializable | Prisma.JsonValue): SupportedModel | null {
-  if (!isObject(config)) return null;
-  if ("model" in config && typeof config.model === "string") return config.model as SupportedModel;
-  return null;
-}
@@ -9,6 +9,8 @@ import parserTypescript from "prettier/plugins/typescript";
 import parserEstree from "prettier/plugins/estree";
 import { type languages } from "monaco-editor/esm/vs/editor/editor.api";
 
+export const editorBackground = "#fafafa";
+
 export type SharedVariantEditorSlice = {
   monaco: null | ReturnType<typeof loader.__getMonacoInstance>;
   loadMonaco: () => Promise<void>;
@@ -40,6 +42,9 @@ const customFormatter: languages.DocumentFormattingEditProvider = {
 export const createVariantEditorSlice: SliceCreator<SharedVariantEditorSlice> = (set, get) => ({
   monaco: loader.__getMonacoInstance(),
   loadMonaco: async () => {
+    // We only want to run this client-side
+    if (typeof window === "undefined") return;
+
     const monaco = await loader.init();
 
     monaco.editor.defineTheme("customTheme", {
@@ -47,7 +52,7 @@ export const createVariantEditorSlice: SliceCreator<SharedVariantEditorSlice> =
      inherit: true,
      rules: [],
      colors: {
-        "editor.background": "#fafafa",
+        "editor.background": editorBackground,
      },
    });
 
@@ -23,7 +23,7 @@ const openAICompletionTokensToDollars: { [key in OpenAIChatModel]: number } = {
 };
 
 export const calculateTokenCost = (
-  model: SupportedModel | null,
+  model: SupportedModel | string | null,
   numTokens: number,
   isCompletion = false,
 ) => {
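Note: widening the `model` parameter to `SupportedModel | string | null` lets callers pass the free-form `variant.model` column straight through. A sketch of the overall-cost arithmetic as used in the prompt-variants router (the token totals here are illustrative; in the router they come from an aggregate query):

```ts
import { calculateTokenCost } from "~/utils/calculateTokenCost";

const model = "gpt-3.5-turbo"; // e.g. variant.model
const promptTokens = 1200;
const completionTokens = 450;

const overallPromptCost = calculateTokenCost(model, promptTokens);
const overallCompletionCost = calculateTokenCost(model, completionTokens, true);
const overallCost = overallPromptCost + overallCompletionCost;
```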