Compare commits
1 Commits
github-act
...
autoformat
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a20f81939d |
51
.github/workflows/ci.yaml
vendored
51
.github/workflows/ci.yaml
vendored
@@ -1,51 +0,0 @@
|
||||
name: CI checks
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches: [main]
|
||||
|
||||
jobs:
|
||||
run-checks:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v2
|
||||
with:
|
||||
node-version: "20"
|
||||
|
||||
- uses: pnpm/action-setup@v2
|
||||
name: Install pnpm
|
||||
id: pnpm-install
|
||||
with:
|
||||
version: 8.6.1
|
||||
run_install: false
|
||||
|
||||
- name: Get pnpm store directory
|
||||
id: pnpm-cache
|
||||
shell: bash
|
||||
run: |
|
||||
echo "STORE_PATH=$(pnpm store path)" >> $GITHUB_OUTPUT
|
||||
|
||||
- uses: actions/cache@v3
|
||||
name: Setup pnpm cache
|
||||
with:
|
||||
path: ${{ steps.pnpm-cache.outputs.STORE_PATH }}
|
||||
key: ${{ runner.os }}-pnpm-store-${{ hashFiles('**/pnpm-lock.yaml') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-store-
|
||||
|
||||
- name: Install Dependencies
|
||||
run: pnpm install
|
||||
|
||||
- name: Check types
|
||||
run: pnpm tsc
|
||||
|
||||
- name: Lint
|
||||
run: SKIP_ENV_VALIDATION=1 pnpm lint
|
||||
|
||||
- name: Check prettier
|
||||
run: pnpm prettier . --check
|
||||
@@ -48,6 +48,7 @@
|
||||
"openai": "4.0.0-beta.2",
|
||||
"pluralize": "^8.0.0",
|
||||
"posthog-js": "^1.68.4",
|
||||
"prettier": "^3.0.0",
|
||||
"react": "18.2.0",
|
||||
"react-dom": "18.2.0",
|
||||
"react-icons": "^4.10.1",
|
||||
@@ -77,8 +78,8 @@
|
||||
"eslint": "^8.40.0",
|
||||
"eslint-config-next": "^13.4.2",
|
||||
"eslint-plugin-unused-imports": "^2.0.0",
|
||||
"monaco-editor": "^0.40.0",
|
||||
"openapi-typescript": "^6.3.4",
|
||||
"prettier": "^3.0.0",
|
||||
"prisma": "^4.14.0",
|
||||
"raw-loader": "^4.0.2",
|
||||
"typescript": "^5.0.4",
|
||||
|
||||
22
pnpm-lock.yaml
generated
22
pnpm-lock.yaml
generated
@@ -22,7 +22,7 @@ dependencies:
|
||||
version: 11.11.0(@emotion/react@11.11.1)(@types/react@18.2.6)(react@18.2.0)
|
||||
'@monaco-editor/loader':
|
||||
specifier: ^1.3.3
|
||||
version: 1.3.3(monaco-editor@0.39.0)
|
||||
version: 1.3.3(monaco-editor@0.40.0)
|
||||
'@next-auth/prisma-adapter':
|
||||
specifier: ^1.0.5
|
||||
version: 1.0.5(@prisma/client@4.14.0)(next-auth@4.22.1)
|
||||
@@ -107,6 +107,9 @@ dependencies:
|
||||
posthog-js:
|
||||
specifier: ^1.68.4
|
||||
version: 1.68.4
|
||||
prettier:
|
||||
specifier: ^3.0.0
|
||||
version: 3.0.0
|
||||
react:
|
||||
specifier: 18.2.0
|
||||
version: 18.2.0
|
||||
@@ -190,12 +193,12 @@ devDependencies:
|
||||
eslint-plugin-unused-imports:
|
||||
specifier: ^2.0.0
|
||||
version: 2.0.0(@typescript-eslint/eslint-plugin@5.59.6)(eslint@8.40.0)
|
||||
monaco-editor:
|
||||
specifier: ^0.40.0
|
||||
version: 0.40.0
|
||||
openapi-typescript:
|
||||
specifier: ^6.3.4
|
||||
version: 6.3.4
|
||||
prettier:
|
||||
specifier: ^3.0.0
|
||||
version: 3.0.0
|
||||
prisma:
|
||||
specifier: ^4.14.0
|
||||
version: 4.14.0
|
||||
@@ -2029,12 +2032,12 @@ packages:
|
||||
'@jridgewell/sourcemap-codec': 1.4.14
|
||||
dev: true
|
||||
|
||||
/@monaco-editor/loader@1.3.3(monaco-editor@0.39.0):
|
||||
/@monaco-editor/loader@1.3.3(monaco-editor@0.40.0):
|
||||
resolution: {integrity: sha512-6KKF4CTzcJiS8BJwtxtfyYt9shBiEv32ateQ9T4UVogwn4HM/uPo9iJd2Dmbkpz8CM6Y0PDUpjnZzCwC+eYo2Q==}
|
||||
peerDependencies:
|
||||
monaco-editor: '>= 0.21.0 < 1'
|
||||
dependencies:
|
||||
monaco-editor: 0.39.0
|
||||
monaco-editor: 0.40.0
|
||||
state-local: 1.0.7
|
||||
dev: false
|
||||
|
||||
@@ -5135,9 +5138,8 @@ packages:
|
||||
ufo: 1.1.2
|
||||
dev: true
|
||||
|
||||
/monaco-editor@0.39.0:
|
||||
resolution: {integrity: sha512-zhbZ2Nx93tLR8aJmL2zI1mhJpsl87HMebNBM6R8z4pLfs8pj604pIVIVwyF1TivcfNtIPpMXL+nb3DsBmE/x6Q==}
|
||||
dev: false
|
||||
/monaco-editor@0.40.0:
|
||||
resolution: {integrity: sha512-1wymccLEuFSMBvCk/jT1YDW/GuxMLYwnFwF9CDyYCxoTw2Pt379J3FUhwy9c43j51JdcxVPjwk0jm0EVDsBS2g==}
|
||||
|
||||
/ms@2.0.0:
|
||||
resolution: {integrity: sha512-Tpp60P6IUJDTuOq/5Z8cdskzJujfwqfOTkrwIwj7IRISpnkJnT6SyJ4PCPnGMoFjC9ddhal5KVIYtAt97ix05A==}
|
||||
@@ -5637,7 +5639,7 @@ packages:
|
||||
resolution: {integrity: sha512-zBf5eHpwHOGPC47h0zrPyNn+eAEIdEzfywMoYn2XPi0P44Zp0tSq64rq0xAREh4auw2cJZHo9QUob+NqCQky4g==}
|
||||
engines: {node: '>=14'}
|
||||
hasBin: true
|
||||
dev: true
|
||||
dev: false
|
||||
|
||||
/pretty-format@29.6.1:
|
||||
resolution: {integrity: sha512-7jRj+yXO0W7e4/tSJKoR7HRIHLPPjtNaUGG2xxKQnGvPNRkgWcQ0AZX6P4KBRJN4FcTBWb3sa7DVUJmocYuoog==}
|
||||
|
||||
@@ -26,11 +26,10 @@ model Experiment {
|
||||
}
|
||||
|
||||
model PromptVariant {
|
||||
id String @id @default(uuid()) @db.Uuid
|
||||
id String @id @default(uuid()) @db.Uuid
|
||||
label String
|
||||
|
||||
label String
|
||||
constructFn String
|
||||
model String @default("gpt-3.5-turbo")
|
||||
|
||||
uiId String @default(uuid()) @db.Uuid
|
||||
visible Boolean @default(true)
|
||||
|
||||
@@ -37,6 +37,10 @@ export default function OutputCell({
|
||||
// if (variant.config === null || Object.keys(variant.config).length === 0)
|
||||
// disabledReason = "Save your prompt variant to see output";
|
||||
|
||||
// const model = getModelName(variant.config as JSONSerializable);
|
||||
// TODO: Temporarily hardcoding this while we get other stuff working
|
||||
const model = "gpt-3.5-turbo";
|
||||
|
||||
const outputMutation = api.outputs.get.useMutation();
|
||||
|
||||
const [output, setOutput] = useState<RouterOutputs["outputs"]["get"]>(null);
|
||||
@@ -136,7 +140,7 @@ export default function OutputCell({
|
||||
{ maxLength: 40 },
|
||||
)}
|
||||
</SyntaxHighlighter>
|
||||
<OutputStats model={variant.model} modelOutput={output} scenario={scenario} />
|
||||
<OutputStats model={model} modelOutput={output} scenario={scenario} />
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
@@ -146,7 +150,7 @@ export default function OutputCell({
|
||||
return (
|
||||
<Flex w="100%" h="100%" direction="column" justifyContent="space-between" whiteSpace="pre-wrap">
|
||||
{contentToDisplay}
|
||||
{output && <OutputStats model={variant.model} modelOutput={output} scenario={scenario} />}
|
||||
{output && <OutputStats model={model} modelOutput={output} scenario={scenario} />}
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ export const OutputStats = ({
|
||||
modelOutput,
|
||||
scenario,
|
||||
}: {
|
||||
model: SupportedModel | string | null;
|
||||
model: SupportedModel | null;
|
||||
modelOutput: ModelOutput;
|
||||
scenario: Scenario;
|
||||
}) => {
|
||||
|
||||
@@ -27,11 +27,16 @@ export default function VariantConfigEditor(props: { variant: PromptVariant }) {
|
||||
const toast = useToast();
|
||||
|
||||
const [onSave] = useHandledAsyncCallback(async () => {
|
||||
const currentFn = editorRef.current?.getValue();
|
||||
if (!editorRef.current) return;
|
||||
|
||||
await editorRef.current.getAction("editor.action.formatDocument")?.run();
|
||||
|
||||
const currentFn = editorRef.current.getValue();
|
||||
|
||||
if (!currentFn) return;
|
||||
|
||||
// Check if the editor has any typescript errors
|
||||
const model = editorRef.current?.getModel();
|
||||
const model = editorRef.current.getModel();
|
||||
if (!model) return;
|
||||
|
||||
const markers = monaco?.editor.getModelMarkers({ resource: model.uri });
|
||||
@@ -59,17 +64,10 @@ export default function VariantConfigEditor(props: { variant: PromptVariant }) {
|
||||
return;
|
||||
}
|
||||
|
||||
const resp = await replaceVariant.mutateAsync({
|
||||
await replaceVariant.mutateAsync({
|
||||
id: props.variant.id,
|
||||
constructFn: currentFn,
|
||||
});
|
||||
if (resp.status === "error") {
|
||||
return toast({
|
||||
title: "Error saving variant",
|
||||
description: resp.message,
|
||||
status: "error",
|
||||
});
|
||||
}
|
||||
|
||||
await utils.promptVariants.list.invalidate();
|
||||
|
||||
|
||||
@@ -75,7 +75,6 @@ export const experimentsRouter = createTRPCRouter({
|
||||
stream: true,
|
||||
messages: [{ role: "system", content: "Return 'Ready to go!'" }],
|
||||
}`,
|
||||
model: "gpt-3.5-turbo-0613",
|
||||
},
|
||||
}),
|
||||
prisma.testScenario.create({
|
||||
|
||||
@@ -6,7 +6,6 @@ import type { Prisma } from "@prisma/client";
|
||||
import { reevaluateVariant } from "~/server/utils/evaluations";
|
||||
import { getCompletion } from "~/server/utils/getCompletion";
|
||||
import { constructPrompt } from "~/server/utils/constructPrompt";
|
||||
import { type CompletionCreateParams } from "openai/resources/chat";
|
||||
|
||||
export const modelOutputsRouter = createTRPCRouter({
|
||||
get: publicProcedure
|
||||
@@ -44,7 +43,7 @@ export const modelOutputsRouter = createTRPCRouter({
|
||||
|
||||
if (!variant || !scenario) return null;
|
||||
|
||||
const prompt = await constructPrompt(variant, scenario.variableValues);
|
||||
const prompt = await constructPrompt(variant, scenario);
|
||||
|
||||
const inputHash = crypto.createHash("sha256").update(JSON.stringify(prompt)).digest("hex");
|
||||
|
||||
@@ -66,10 +65,7 @@ export const modelOutputsRouter = createTRPCRouter({
|
||||
};
|
||||
} else {
|
||||
try {
|
||||
modelResponse = await getCompletion(
|
||||
prompt as unknown as CompletionCreateParams,
|
||||
input.channel,
|
||||
);
|
||||
modelResponse = await getCompletion(prompt, input.channel);
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
throw e;
|
||||
|
||||
@@ -1,10 +1,6 @@
|
||||
import { isObject } from "lodash";
|
||||
import { z } from "zod";
|
||||
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
|
||||
import { prisma } from "~/server/db";
|
||||
import { OpenAIChatModel } from "~/server/types";
|
||||
import { constructPrompt } from "~/server/utils/constructPrompt";
|
||||
import userError from "~/server/utils/error";
|
||||
import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
|
||||
import { calculateTokenCost } from "~/utils/calculateTokenCost";
|
||||
|
||||
@@ -61,10 +57,14 @@ export const promptVariantsRouter = createTRPCRouter({
|
||||
},
|
||||
});
|
||||
|
||||
// TODO: fix this
|
||||
const model = "gpt-3.5-turbo-0613";
|
||||
// const model = getModelName(variant.config);
|
||||
|
||||
const promptTokens = overallTokens._sum?.promptTokens ?? 0;
|
||||
const overallPromptCost = calculateTokenCost(variant.model, promptTokens);
|
||||
const overallPromptCost = calculateTokenCost(model, promptTokens);
|
||||
const completionTokens = overallTokens._sum?.completionTokens ?? 0;
|
||||
const overallCompletionCost = calculateTokenCost(variant.model, completionTokens, true);
|
||||
const overallCompletionCost = calculateTokenCost(model, completionTokens, true);
|
||||
|
||||
const overallCost = overallPromptCost + overallCompletionCost;
|
||||
|
||||
@@ -106,7 +106,6 @@ export const promptVariantsRouter = createTRPCRouter({
|
||||
label: `Prompt Variant ${largestSortIndex + 2}`,
|
||||
sortIndex: (lastVariant?.sortIndex ?? 0) + 1,
|
||||
constructFn: lastVariant?.constructFn ?? "",
|
||||
model: lastVariant?.model ?? "gpt-3.5-turbo",
|
||||
},
|
||||
});
|
||||
|
||||
@@ -186,27 +185,6 @@ export const promptVariantsRouter = createTRPCRouter({
|
||||
throw new Error(`Prompt Variant with id ${input.id} does not exist`);
|
||||
}
|
||||
|
||||
let model = existing.model;
|
||||
try {
|
||||
const contructedPrompt = await constructPrompt({ constructFn: input.constructFn }, null);
|
||||
|
||||
if (!isObject(contructedPrompt)) {
|
||||
return userError("Prompt is not an object");
|
||||
}
|
||||
if (!("model" in contructedPrompt)) {
|
||||
return userError("Prompt does not define a model");
|
||||
}
|
||||
if (
|
||||
typeof contructedPrompt.model !== "string" ||
|
||||
!(contructedPrompt.model in OpenAIChatModel)
|
||||
) {
|
||||
return userError("Prompt defines an invalid model");
|
||||
}
|
||||
model = contructedPrompt.model;
|
||||
} catch (e) {
|
||||
return userError((e as Error).message);
|
||||
}
|
||||
|
||||
// Create a duplicate with only the config changed
|
||||
const newVariant = await prisma.promptVariant.create({
|
||||
data: {
|
||||
@@ -215,12 +193,11 @@ export const promptVariantsRouter = createTRPCRouter({
|
||||
sortIndex: existing.sortIndex,
|
||||
uiId: existing.uiId,
|
||||
constructFn: input.constructFn,
|
||||
model,
|
||||
},
|
||||
});
|
||||
|
||||
// Hide anything with the same uiId besides the new one
|
||||
const hideOldVariants = prisma.promptVariant.updateMany({
|
||||
const hideOldVariantsAction = prisma.promptVariant.updateMany({
|
||||
where: {
|
||||
uiId: existing.uiId,
|
||||
id: {
|
||||
@@ -232,9 +209,12 @@ export const promptVariantsRouter = createTRPCRouter({
|
||||
},
|
||||
});
|
||||
|
||||
await prisma.$transaction([hideOldVariants, recordExperimentUpdated(existing.experimentId)]);
|
||||
await prisma.$transaction([
|
||||
hideOldVariantsAction,
|
||||
recordExperimentUpdated(existing.experimentId),
|
||||
]);
|
||||
|
||||
return { status: "ok" } as const;
|
||||
return newVariant;
|
||||
}),
|
||||
|
||||
reorder: publicProcedure
|
||||
|
||||
@@ -7,7 +7,9 @@ test.skip("constructPrompt", async () => {
|
||||
constructFn: `prompt = { "fooz": "bar" }`,
|
||||
},
|
||||
{
|
||||
foo: "bar",
|
||||
variableValues: {
|
||||
foo: "bar",
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
|
||||
@@ -6,8 +6,10 @@ const isolate = new ivm.Isolate({ memoryLimit: 128 });
|
||||
|
||||
export async function constructPrompt(
|
||||
variant: Pick<PromptVariant, "constructFn">,
|
||||
scenario: TestScenario["variableValues"],
|
||||
testScenario: Pick<TestScenario, "variableValues">,
|
||||
): Promise<JSONSerializable> {
|
||||
const scenario = testScenario.variableValues as JSONSerializable;
|
||||
|
||||
const code = `
|
||||
const scenario = ${JSON.stringify(scenario, null, 2)};
|
||||
let prompt
|
||||
|
||||
@@ -1,6 +0,0 @@
|
||||
export default function userError(message: string): { status: "error"; message: string } {
|
||||
return {
|
||||
status: "error",
|
||||
message,
|
||||
};
|
||||
}
|
||||
@@ -4,11 +4,14 @@ import { Prisma } from "@prisma/client";
|
||||
import { streamChatCompletion } from "./openai";
|
||||
import { wsConnection } from "~/utils/wsConnection";
|
||||
import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat";
|
||||
import { type OpenAIChatModel } from "../types";
|
||||
import { type JSONSerializable, OpenAIChatModel } from "../types";
|
||||
import { env } from "~/env.mjs";
|
||||
import { countOpenAIChatTokens } from "~/utils/countTokens";
|
||||
import { getModelName } from "./getModelName";
|
||||
import { rateLimitErrorMessage } from "~/sharedStrings";
|
||||
|
||||
env;
|
||||
|
||||
type CompletionResponse = {
|
||||
output: Prisma.InputJsonValue | typeof Prisma.JsonNull;
|
||||
statusCode: number;
|
||||
@@ -19,7 +22,35 @@ type CompletionResponse = {
|
||||
};
|
||||
|
||||
export async function getCompletion(
|
||||
payload: JSONSerializable,
|
||||
channel?: string,
|
||||
): Promise<CompletionResponse> {
|
||||
const modelName = getModelName(payload);
|
||||
if (!modelName)
|
||||
return {
|
||||
output: Prisma.JsonNull,
|
||||
statusCode: 400,
|
||||
errorMessage: "Invalid payload provided",
|
||||
timeToComplete: 0,
|
||||
};
|
||||
if (modelName in OpenAIChatModel) {
|
||||
return getOpenAIChatCompletion(
|
||||
payload as unknown as CompletionCreateParams,
|
||||
env.OPENAI_API_KEY,
|
||||
channel,
|
||||
);
|
||||
}
|
||||
return {
|
||||
output: Prisma.JsonNull,
|
||||
statusCode: 400,
|
||||
errorMessage: "Invalid model provided",
|
||||
timeToComplete: 0,
|
||||
};
|
||||
}
|
||||
|
||||
export async function getOpenAIChatCompletion(
|
||||
payload: CompletionCreateParams,
|
||||
apiKey: string,
|
||||
channel?: string,
|
||||
): Promise<CompletionResponse> {
|
||||
// If functions are enabled, disable streaming so that we get the full response with token counts
|
||||
@@ -29,7 +60,7 @@ export async function getCompletion(
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
Authorization: `Bearer ${env.OPENAI_API_KEY}`,
|
||||
Authorization: `Bearer ${apiKey}`,
|
||||
},
|
||||
body: JSON.stringify(payload),
|
||||
});
|
||||
|
||||
9
src/server/utils/getModelName.ts
Normal file
9
src/server/utils/getModelName.ts
Normal file
@@ -0,0 +1,9 @@
|
||||
import { isObject } from "lodash";
|
||||
import { type JSONSerializable, type SupportedModel } from "../types";
|
||||
import { type Prisma } from "@prisma/client";
|
||||
|
||||
export function getModelName(config: JSONSerializable | Prisma.JsonValue): SupportedModel | null {
|
||||
if (!isObject(config)) return null;
|
||||
if ("model" in config && typeof config.model === "string") return config.model as SupportedModel;
|
||||
return null;
|
||||
}
|
||||
@@ -2,6 +2,12 @@ import { type RouterOutputs } from "~/utils/api";
|
||||
import { type SliceCreator } from "./store";
|
||||
import loader from "@monaco-editor/loader";
|
||||
import openAITypes from "~/codegen/openai.types.ts.txt";
|
||||
import prettier from "prettier/standalone";
|
||||
import parserTypescript from "prettier/plugins/typescript";
|
||||
|
||||
// @ts-expect-error for some reason missing from types
|
||||
import parserEstree from "prettier/plugins/estree";
|
||||
import { type languages } from "monaco-editor/esm/vs/editor/editor.api";
|
||||
|
||||
export type SharedVariantEditorSlice = {
|
||||
monaco: null | ReturnType<typeof loader.__getMonacoInstance>;
|
||||
@@ -11,12 +17,29 @@ export type SharedVariantEditorSlice = {
|
||||
setScenarios: (scenarios: RouterOutputs["scenarios"]["list"]) => void;
|
||||
};
|
||||
|
||||
const customFormatter: languages.DocumentFormattingEditProvider = {
|
||||
provideDocumentFormattingEdits: async (model) => {
|
||||
const val = model.getValue();
|
||||
console.log("going to format!", val);
|
||||
const text = await prettier.format(val, {
|
||||
parser: "typescript",
|
||||
plugins: [parserTypescript, parserEstree],
|
||||
// We're showing these in pretty narrow panes so let's keep the print width low
|
||||
printWidth: 60,
|
||||
});
|
||||
|
||||
return [
|
||||
{
|
||||
range: model.getFullModelRange(),
|
||||
text,
|
||||
},
|
||||
];
|
||||
},
|
||||
};
|
||||
|
||||
export const createVariantEditorSlice: SliceCreator<SharedVariantEditorSlice> = (set, get) => ({
|
||||
monaco: loader.__getMonacoInstance(),
|
||||
loadMonaco: async () => {
|
||||
// We only want to run this client-side
|
||||
if (typeof window === "undefined") return;
|
||||
|
||||
const monaco = await loader.init();
|
||||
|
||||
monaco.editor.defineTheme("customTheme", {
|
||||
@@ -43,6 +66,8 @@ export const createVariantEditorSlice: SliceCreator<SharedVariantEditorSlice> =
|
||||
monaco.Uri.parse("file:///openai.types.ts"),
|
||||
);
|
||||
|
||||
monaco.languages.registerDocumentFormattingEditProvider("typescript", customFormatter);
|
||||
|
||||
set((state) => {
|
||||
state.sharedVariantEditor.monaco = monaco;
|
||||
});
|
||||
|
||||
@@ -23,7 +23,7 @@ const openAICompletionTokensToDollars: { [key in OpenAIChatModel]: number } = {
|
||||
};
|
||||
|
||||
export const calculateTokenCost = (
|
||||
model: SupportedModel | string | null,
|
||||
model: SupportedModel | null,
|
||||
numTokens: number,
|
||||
isCompletion = false,
|
||||
) => {
|
||||
|
||||
Reference in New Issue
Block a user