Compare commits

..

6 Commits

Author SHA1 Message Date
Kyle Corbitt
731406d1f4 Pseudo function signatures
Show pseudo function signatures in the variant editor box as a UX hint that you're typing in javascript and have access to the scenario.
2023-07-14 13:56:45 -07:00
Kyle Corbitt
3c59e4b774 Merge pull request #42 from OpenPipe/autoformat
implement format on save
2023-07-14 12:56:41 -07:00
Kyle Corbitt
972b1f2333 Merge pull request #41 from OpenPipe/github-actions
CI checks
2023-07-14 11:40:42 -07:00
Kyle Corbitt
7321f3deda CI checks 2023-07-14 11:36:47 -07:00
Kyle Corbitt
2bd41fdfbf Merge pull request #40 from OpenPipe:completion-costs
store model and use to calculate completion costs
2023-07-14 11:07:15 -07:00
Kyle Corbitt
a5378b106b store model and use to calculate completion costs 2023-07-14 11:06:07 -07:00
15 changed files with 136 additions and 88 deletions

51
.github/workflows/ci.yaml vendored Normal file
View File

@@ -0,0 +1,51 @@
name: CI checks
on:
pull_request:
branches: [main]
jobs:
run-checks:
runs-on: ubuntu-latest
steps:
- name: Check out code
uses: actions/checkout@v2
- name: Set up Node.js
uses: actions/setup-node@v2
with:
node-version: "20"
- uses: pnpm/action-setup@v2
name: Install pnpm
id: pnpm-install
with:
version: 8.6.1
run_install: false
- name: Get pnpm store directory
id: pnpm-cache
shell: bash
run: |
echo "STORE_PATH=$(pnpm store path)" >> $GITHUB_OUTPUT
- uses: actions/cache@v3
name: Setup pnpm cache
with:
path: ${{ steps.pnpm-cache.outputs.STORE_PATH }}
key: ${{ runner.os }}-pnpm-store-${{ hashFiles('**/pnpm-lock.yaml') }}
restore-keys: |
${{ runner.os }}-pnpm-store-
- name: Install Dependencies
run: pnpm install
- name: Check types
run: pnpm tsc
- name: Lint
run: SKIP_ENV_VALIDATION=1 pnpm lint
- name: Check prettier
run: pnpm prettier . --check

View File

@@ -27,9 +27,10 @@ model Experiment {
model PromptVariant {
id String @id @default(uuid()) @db.Uuid
label String
label String
constructFn String
model String @default("gpt-3.5-turbo")
uiId String @default(uuid()) @db.Uuid
visible Boolean @default(true)

View File

@@ -37,10 +37,6 @@ export default function OutputCell({
// if (variant.config === null || Object.keys(variant.config).length === 0)
// disabledReason = "Save your prompt variant to see output";
// const model = getModelName(variant.config as JSONSerializable);
// TODO: Temporarily hardcoding this while we get other stuff working
const model = "gpt-3.5-turbo";
const outputMutation = api.outputs.get.useMutation();
const [output, setOutput] = useState<RouterOutputs["outputs"]["get"]>(null);
@@ -140,7 +136,7 @@ export default function OutputCell({
{ maxLength: 40 },
)}
</SyntaxHighlighter>
<OutputStats model={model} modelOutput={output} scenario={scenario} />
<OutputStats model={variant.model} modelOutput={output} scenario={scenario} />
</Box>
);
}
@@ -150,7 +146,7 @@ export default function OutputCell({
return (
<Flex w="100%" h="100%" direction="column" justifyContent="space-between" whiteSpace="pre-wrap">
{contentToDisplay}
{output && <OutputStats model={model} modelOutput={output} scenario={scenario} />}
{output && <OutputStats model={variant.model} modelOutput={output} scenario={scenario} />}
</Flex>
);
}

View File

@@ -17,7 +17,7 @@ export const OutputStats = ({
modelOutput,
scenario,
}: {
model: SupportedModel | null;
model: SupportedModel | string | null;
modelOutput: ModelOutput;
scenario: Scenario;
}) => {

View File

@@ -1,9 +1,10 @@
import { Box, Button, HStack, Tooltip, useToast } from "@chakra-ui/react";
import { Box, Button, HStack, Tooltip, VStack, useToast } from "@chakra-ui/react";
import { useRef, useEffect, useState, useCallback } from "react";
import { useHandledAsyncCallback, useModifierKeyLabel } from "~/utils/hooks";
import { type PromptVariant } from "./types";
import { api } from "~/utils/api";
import { useAppStore } from "~/state/store";
import { editorBackground } from "~/state/sharedVariantEditor.slice";
// import openAITypes from "~/codegen/openai.types.ts.txt";
export default function VariantConfigEditor(props: { variant: PromptVariant }) {
@@ -64,10 +65,17 @@ export default function VariantConfigEditor(props: { variant: PromptVariant }) {
return;
}
await replaceVariant.mutateAsync({
const resp = await replaceVariant.mutateAsync({
id: props.variant.id,
constructFn: currentFn,
});
if (resp.status === "error") {
return toast({
title: "Error saving variant",
description: resp.message,
status: "error",
});
}
await utils.promptVariants.list.invalidate();
@@ -122,21 +130,21 @@ export default function VariantConfigEditor(props: { variant: PromptVariant }) {
/* eslint-disable-next-line react-hooks/exhaustive-deps */
}, [monaco, editorId]);
// useEffect(() => {
// const savedConfigChanged = lastSavedFn !== savedConfig;
// lastSavedFn = savedConfig;
// if (savedConfigChanged && editorRef.current?.getValue() !== savedConfig) {
// editorRef.current?.setValue(savedConfig);
// }
// checkForChanges();
// }, [savedConfig, checkForChanges]);
return (
<Box w="100%" pos="relative">
<VStack
spacing={0}
align="stretch"
fontSize="xs"
fontWeight="bold"
color="gray.600"
py={2}
bgColor={editorBackground}
>
<code>{`function constructPrompt(scenario: Scenario): Prompt {`}</code>
<div id={editorId} style={{ height: "300px", width: "100%" }}></div>
<code>{`return prompt; }`}</code>
</VStack>
{isChanged && (
<HStack pos="absolute" bottom={2} right={2}>
<Button

View File

@@ -75,6 +75,7 @@ export const experimentsRouter = createTRPCRouter({
stream: true,
messages: [{ role: "system", content: "Return 'Ready to go!'" }],
}`,
model: "gpt-3.5-turbo-0613",
},
}),
prisma.testScenario.create({

View File

@@ -6,6 +6,7 @@ import type { Prisma } from "@prisma/client";
import { reevaluateVariant } from "~/server/utils/evaluations";
import { getCompletion } from "~/server/utils/getCompletion";
import { constructPrompt } from "~/server/utils/constructPrompt";
import { type CompletionCreateParams } from "openai/resources/chat";
export const modelOutputsRouter = createTRPCRouter({
get: publicProcedure
@@ -43,7 +44,7 @@ export const modelOutputsRouter = createTRPCRouter({
if (!variant || !scenario) return null;
const prompt = await constructPrompt(variant, scenario);
const prompt = await constructPrompt(variant, scenario.variableValues);
const inputHash = crypto.createHash("sha256").update(JSON.stringify(prompt)).digest("hex");
@@ -65,7 +66,10 @@ export const modelOutputsRouter = createTRPCRouter({
};
} else {
try {
modelResponse = await getCompletion(prompt, input.channel);
modelResponse = await getCompletion(
prompt as unknown as CompletionCreateParams,
input.channel,
);
} catch (e) {
console.error(e);
throw e;

View File

@@ -1,6 +1,10 @@
import { isObject } from "lodash";
import { z } from "zod";
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import { OpenAIChatModel } from "~/server/types";
import { constructPrompt } from "~/server/utils/constructPrompt";
import userError from "~/server/utils/error";
import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
import { calculateTokenCost } from "~/utils/calculateTokenCost";
@@ -57,14 +61,10 @@ export const promptVariantsRouter = createTRPCRouter({
},
});
// TODO: fix this
const model = "gpt-3.5-turbo-0613";
// const model = getModelName(variant.config);
const promptTokens = overallTokens._sum?.promptTokens ?? 0;
const overallPromptCost = calculateTokenCost(model, promptTokens);
const overallPromptCost = calculateTokenCost(variant.model, promptTokens);
const completionTokens = overallTokens._sum?.completionTokens ?? 0;
const overallCompletionCost = calculateTokenCost(model, completionTokens, true);
const overallCompletionCost = calculateTokenCost(variant.model, completionTokens, true);
const overallCost = overallPromptCost + overallCompletionCost;
@@ -106,6 +106,7 @@ export const promptVariantsRouter = createTRPCRouter({
label: `Prompt Variant ${largestSortIndex + 2}`,
sortIndex: (lastVariant?.sortIndex ?? 0) + 1,
constructFn: lastVariant?.constructFn ?? "",
model: lastVariant?.model ?? "gpt-3.5-turbo",
},
});
@@ -185,6 +186,27 @@ export const promptVariantsRouter = createTRPCRouter({
throw new Error(`Prompt Variant with id ${input.id} does not exist`);
}
let model = existing.model;
try {
const contructedPrompt = await constructPrompt({ constructFn: input.constructFn }, null);
if (!isObject(contructedPrompt)) {
return userError("Prompt is not an object");
}
if (!("model" in contructedPrompt)) {
return userError("Prompt does not define a model");
}
if (
typeof contructedPrompt.model !== "string" ||
!(contructedPrompt.model in OpenAIChatModel)
) {
return userError("Prompt defines an invalid model");
}
model = contructedPrompt.model;
} catch (e) {
return userError((e as Error).message);
}
// Create a duplicate with only the config changed
const newVariant = await prisma.promptVariant.create({
data: {
@@ -193,11 +215,12 @@ export const promptVariantsRouter = createTRPCRouter({
sortIndex: existing.sortIndex,
uiId: existing.uiId,
constructFn: input.constructFn,
model,
},
});
// Hide anything with the same uiId besides the new one
const hideOldVariantsAction = prisma.promptVariant.updateMany({
const hideOldVariants = prisma.promptVariant.updateMany({
where: {
uiId: existing.uiId,
id: {
@@ -209,12 +232,9 @@ export const promptVariantsRouter = createTRPCRouter({
},
});
await prisma.$transaction([
hideOldVariantsAction,
recordExperimentUpdated(existing.experimentId),
]);
await prisma.$transaction([hideOldVariants, recordExperimentUpdated(existing.experimentId)]);
return newVariant;
return { status: "ok" } as const;
}),
reorder: publicProcedure

View File

@@ -7,10 +7,8 @@ test.skip("constructPrompt", async () => {
constructFn: `prompt = { "fooz": "bar" }`,
},
{
variableValues: {
foo: "bar",
},
},
);
console.log(constructed);

View File

@@ -6,10 +6,8 @@ const isolate = new ivm.Isolate({ memoryLimit: 128 });
export async function constructPrompt(
variant: Pick<PromptVariant, "constructFn">,
testScenario: Pick<TestScenario, "variableValues">,
scenario: TestScenario["variableValues"],
): Promise<JSONSerializable> {
const scenario = testScenario.variableValues as JSONSerializable;
const code = `
const scenario = ${JSON.stringify(scenario, null, 2)};
let prompt

View File

@@ -0,0 +1,6 @@
export default function userError(message: string): { status: "error"; message: string } {
return {
status: "error",
message,
};
}

View File

@@ -4,14 +4,11 @@ import { Prisma } from "@prisma/client";
import { streamChatCompletion } from "./openai";
import { wsConnection } from "~/utils/wsConnection";
import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat";
import { type JSONSerializable, OpenAIChatModel } from "../types";
import { type OpenAIChatModel } from "../types";
import { env } from "~/env.mjs";
import { countOpenAIChatTokens } from "~/utils/countTokens";
import { getModelName } from "./getModelName";
import { rateLimitErrorMessage } from "~/sharedStrings";
env;
type CompletionResponse = {
output: Prisma.InputJsonValue | typeof Prisma.JsonNull;
statusCode: number;
@@ -22,35 +19,7 @@ type CompletionResponse = {
};
export async function getCompletion(
payload: JSONSerializable,
channel?: string,
): Promise<CompletionResponse> {
const modelName = getModelName(payload);
if (!modelName)
return {
output: Prisma.JsonNull,
statusCode: 400,
errorMessage: "Invalid payload provided",
timeToComplete: 0,
};
if (modelName in OpenAIChatModel) {
return getOpenAIChatCompletion(
payload as unknown as CompletionCreateParams,
env.OPENAI_API_KEY,
channel,
);
}
return {
output: Prisma.JsonNull,
statusCode: 400,
errorMessage: "Invalid model provided",
timeToComplete: 0,
};
}
export async function getOpenAIChatCompletion(
payload: CompletionCreateParams,
apiKey: string,
channel?: string,
): Promise<CompletionResponse> {
// If functions are enabled, disable streaming so that we get the full response with token counts
@@ -60,7 +29,7 @@ export async function getOpenAIChatCompletion(
method: "POST",
headers: {
"Content-Type": "application/json",
Authorization: `Bearer ${apiKey}`,
Authorization: `Bearer ${env.OPENAI_API_KEY}`,
},
body: JSON.stringify(payload),
});

View File

@@ -1,9 +0,0 @@
import { isObject } from "lodash";
import { type JSONSerializable, type SupportedModel } from "../types";
import { type Prisma } from "@prisma/client";
export function getModelName(config: JSONSerializable | Prisma.JsonValue): SupportedModel | null {
if (!isObject(config)) return null;
if ("model" in config && typeof config.model === "string") return config.model as SupportedModel;
return null;
}

View File

@@ -9,6 +9,8 @@ import parserTypescript from "prettier/plugins/typescript";
import parserEstree from "prettier/plugins/estree";
import { type languages } from "monaco-editor/esm/vs/editor/editor.api";
export const editorBackground = "#fafafa";
export type SharedVariantEditorSlice = {
monaco: null | ReturnType<typeof loader.__getMonacoInstance>;
loadMonaco: () => Promise<void>;
@@ -40,6 +42,9 @@ const customFormatter: languages.DocumentFormattingEditProvider = {
export const createVariantEditorSlice: SliceCreator<SharedVariantEditorSlice> = (set, get) => ({
monaco: loader.__getMonacoInstance(),
loadMonaco: async () => {
// We only want to run this client-side
if (typeof window === "undefined") return;
const monaco = await loader.init();
monaco.editor.defineTheme("customTheme", {
@@ -47,7 +52,7 @@ export const createVariantEditorSlice: SliceCreator<SharedVariantEditorSlice> =
inherit: true,
rules: [],
colors: {
"editor.background": "#fafafa",
"editor.background": editorBackground,
},
});

View File

@@ -23,7 +23,7 @@ const openAICompletionTokensToDollars: { [key in OpenAIChatModel]: number } = {
};
export const calculateTokenCost = (
model: SupportedModel | null,
model: SupportedModel | string | null,
numTokens: number,
isCompletion = false,
) => {