Compare commits

..

13 Commits

Author SHA1 Message Date
David Corbitt
01343efb6a Properly use isDisabled 2023-07-20 22:54:38 -07:00
David Corbitt
c7aaaea426 Update instructions 2023-07-20 22:42:51 -07:00
David Corbitt
332e7afb0c Pass variant to SelectModelModal 2023-07-20 22:39:32 -07:00
David Corbitt
fe08e29f47 Show prompt comparison in SelectModelModal 2023-07-20 22:39:14 -07:00
David Corbitt
89ce730e52 Accept newModel in getModifiedPromptFn 2023-07-20 22:38:42 -07:00
David Corbitt
ad87c1b2eb Change RefinePromptModal styles 2023-07-20 22:38:09 -07:00
David Corbitt
58ddc72cbb Make CompareFunctions more configurable 2023-07-20 22:36:21 -07:00
arcticfly
9978075867 Fix auth flicker (#75)
* Remove experiments flicker for unauthenticated users

* Decrease size of NewScenarioButton spinner
2023-07-20 20:46:31 -07:00
Kyle Corbitt
372c2512c9 Merge pull request #73 from OpenPipe/model-providers
More work on modelProviders
2023-07-20 18:56:14 -07:00
arcticfly
1822fe198e Initially render AutoResizeTextArea without overflow (#72)
* Rerender resized text area with scroll

* Remove default hidden overflow
2023-07-20 15:00:09 -07:00
Kyle Corbitt
f06e1db3db Merge pull request #71 from OpenPipe/model-providers
Prep for more model providers
2023-07-20 14:55:31 -07:00
arcticfly
9314a86857 Use translation in initial scenarios (#70) 2023-07-20 14:28:48 -07:00
David Corbitt
54dcb4a567 Prevent text input labels from overlaying scenarios header 2023-07-20 14:28:36 -07:00
26 changed files with 204 additions and 308 deletions

View File

@@ -73,7 +73,6 @@
"react-syntax-highlighter": "^15.5.0",
"react-textarea-autosize": "^8.5.0",
"recast": "^0.23.3",
"replicate": "^0.12.3",
"socket.io": "^4.7.1",
"socket.io-client": "^4.7.1",
"superjson": "1.12.2",

8
pnpm-lock.yaml generated
View File

@@ -161,9 +161,6 @@ dependencies:
recast:
specifier: ^0.23.3
version: 0.23.3
replicate:
specifier: ^0.12.3
version: 0.12.3
socket.io:
specifier: ^4.7.1
version: 4.7.1
@@ -6991,11 +6988,6 @@ packages:
functions-have-names: 1.2.3
dev: true
/replicate@0.12.3:
resolution: {integrity: sha512-HVWKPoVhWVTONlWk+lUXmq9Vy2J8MxBJMtDBQq3dA5uq71ZzKTh0xvJfvzW4+VLBjhBeL7tkdua6hZJmKfzAPQ==}
engines: {git: '>=2.11.0', node: '>=16.6.0', npm: '>=7.19.0', yarn: '>=1.7.0'}
dev: false
/require-directory@2.1.1:
resolution: {integrity: sha512-fGxEI7+wsG9xrvdjsrlmL22OMTTiHRwAMroiEeMgq8gzoLC/PQr7RsRDSTLUg/bZAZtF+TVIkHc6/4RIKrui+Q==}
engines: {node: '>=0.10.0'}

View File

@@ -1,19 +1,22 @@
import { Textarea, type TextareaProps } from "@chakra-ui/react";
import ResizeTextarea from "react-textarea-autosize";
import React from "react";
import React, { useLayoutEffect, useState } from "react";
export const AutoResizeTextarea: React.ForwardRefRenderFunction<
HTMLTextAreaElement,
TextareaProps & { minRows?: number }
> = (props, ref) => {
> = ({ minRows = 1, overflowY = "hidden", ...props }, ref) => {
const [isRerendered, setIsRerendered] = useState(false);
useLayoutEffect(() => setIsRerendered(true), []);
return (
<Textarea
minH="unset"
overflow="hidden"
minRows={minRows}
overflowY={isRerendered ? overflowY : "hidden"}
w="100%"
resize="none"
ref={ref}
minRows={1}
transition="height none"
as={ResizeTextarea}
{...props}

View File

@@ -18,11 +18,9 @@ export const FloatingLabelInput = ({
transform={isFocused || !!value ? "translateY(-50%)" : "translateY(0)"}
fontSize={isFocused || !!value ? "12px" : "16px"}
transition="all 0.15s"
zIndex="100"
zIndex="5"
bg="white"
px={1}
mt={0}
mb={2}
lineHeight="1"
pointerEvents="none"
color={isFocused ? "blue.500" : "gray.500"}

View File

@@ -49,7 +49,11 @@ export default function NewScenarioButton() {
Add Scenario
</StyledButton>
<StyledButton onClick={onAutogenerate}>
<Icon as={autogenerating ? Spinner : BsPlus} boxSize={6} mr={autogenerating ? 1 : 0} />
<Icon
as={autogenerating ? Spinner : BsPlus}
boxSize={autogenerating ? 4 : 6}
mr={autogenerating ? 2 : 0}
/>
Autogenerate Scenario
</StyledButton>
</HStack>

View File

@@ -88,11 +88,9 @@ export default function OutputCell({
}
const normalizedOutput = modelOutput
? // @ts-expect-error TODO FIX ASAP
provider.normalizeOutput(modelOutput.output as unknown as OutputSchema)
? provider.normalizeOutput(modelOutput.output as unknown as OutputSchema)
: streamedMessage
? // @ts-expect-error TODO FIX ASAP
provider.normalizeOutput(streamedMessage)
? provider.normalizeOutput(streamedMessage)
: null;
if (modelOutput && normalizedOutput?.type === "json") {

View File

@@ -4,5 +4,5 @@ export const stickyHeaderStyle: SystemStyleObject = {
position: "sticky",
top: "0",
backgroundColor: "#fff",
zIndex: 1,
zIndex: 10,
};

View File

@@ -1,4 +1,4 @@
import { HStack, VStack, useBreakpointValue } from "@chakra-ui/react";
import { type StackProps, VStack, useBreakpointValue } from "@chakra-ui/react";
import React from "react";
import DiffViewer, { DiffMethod } from "react-diff-viewer";
import Prism from "prismjs";
@@ -19,10 +19,15 @@ const highlightSyntax = (str: string) => {
const CompareFunctions = ({
originalFunction,
newFunction = "",
leftTitle = "Original",
rightTitle = "Modified",
...props
}: {
originalFunction: string;
newFunction?: string;
}) => {
leftTitle?: string;
rightTitle?: string;
} & StackProps) => {
const showSplitView = useBreakpointValue(
{
base: false,
@@ -34,22 +39,20 @@ const CompareFunctions = ({
);
return (
<HStack w="full" spacing={5}>
<VStack w="full" spacing={4} maxH="40vh" fontSize={12} lineHeight={1} overflowY="auto">
<DiffViewer
oldValue={originalFunction}
newValue={newFunction || originalFunction}
splitView={showSplitView}
hideLineNumbers={!showSplitView}
leftTitle="Original"
rightTitle={newFunction ? "Modified" : "Unmodified"}
disableWordDiff={true}
compareMethod={DiffMethod.CHARS}
renderContent={highlightSyntax}
showDiffOnly={false}
/>
</VStack>
</HStack>
<VStack w="full" spacing={4} fontSize={12} lineHeight={1} overflowY="auto" {...props}>
<DiffViewer
oldValue={originalFunction}
newValue={newFunction || originalFunction}
splitView={showSplitView}
hideLineNumbers={!showSplitView}
leftTitle={leftTitle}
rightTitle={rightTitle}
disableWordDiff={true}
compareMethod={DiffMethod.CHARS}
renderContent={highlightSyntax}
showDiffOnly={false}
/>
</VStack>
);
};

View File

@@ -56,7 +56,6 @@ export const CustomInstructionsInput = ({
minW="unset"
size="sm"
onClick={() => onSubmit()}
disabled={!instructions}
variant={instructions ? "solid" : "ghost"}
mr={4}
borderRadius="8"

View File

@@ -36,25 +36,25 @@ export const RefinePromptModal = ({
}) => {
const utils = api.useContext();
const { mutateAsync: getRefinedPromptMutateAsync, data: refinedPromptFn } =
api.promptVariants.getRefinedPromptFn.useMutation();
const { mutateAsync: getModifiedPromptMutateAsync, data: refinedPromptFn } =
api.promptVariants.getModifiedPromptFn.useMutation();
const [instructions, setInstructions] = useState<string>("");
const [activeRefineOptionLabel, setActiveRefineOptionLabel] = useState<
RefineOptionLabel | undefined
>(undefined);
const [getRefinedPromptFn, refiningInProgress] = useHandledAsyncCallback(
const [getModifiedPromptFn, modificationInProgress] = useHandledAsyncCallback(
async (label?: RefineOptionLabel) => {
if (!variant.experimentId) return;
const updatedInstructions = label ? refineOptions[label].instructions : instructions;
setActiveRefineOptionLabel(label);
await getRefinedPromptMutateAsync({
await getModifiedPromptMutateAsync({
id: variant.id,
instructions: updatedInstructions,
});
},
[getRefinedPromptMutateAsync, onClose, variant, instructions, setActiveRefineOptionLabel],
[getModifiedPromptMutateAsync, onClose, variant, instructions, setActiveRefineOptionLabel],
);
const replaceVariantMutation = api.promptVariants.replaceVariant.useMutation();
@@ -75,7 +75,11 @@ export const RefinePromptModal = ({
}, [replaceVariantMutation, variant, onClose, refinedPromptFn]);
return (
<Modal isOpen onClose={onClose} size={{ base: "xl", sm: "2xl", md: "7xl" }}>
<Modal
isOpen
onClose={onClose}
size={{ base: "xl", sm: "2xl", md: "3xl", lg: "5xl", xl: "7xl" }}
>
<ModalOverlay />
<ModalContent w={1200}>
<ModalHeader>
@@ -93,15 +97,15 @@ export const RefinePromptModal = ({
label="Convert to function call"
activeLabel={activeRefineOptionLabel}
icon={VscJson}
onClick={getRefinedPromptFn}
loading={refiningInProgress}
onClick={getModifiedPromptFn}
loading={modificationInProgress}
/>
<RefineOption
label="Add chain of thought"
activeLabel={activeRefineOptionLabel}
icon={TfiThought}
onClick={getRefinedPromptFn}
loading={refiningInProgress}
onClick={getModifiedPromptFn}
loading={modificationInProgress}
/>
</SimpleGrid>
<HStack>
@@ -110,13 +114,14 @@ export const RefinePromptModal = ({
<CustomInstructionsInput
instructions={instructions}
setInstructions={setInstructions}
loading={refiningInProgress}
onSubmit={getRefinedPromptFn}
loading={modificationInProgress}
onSubmit={getModifiedPromptFn}
/>
</VStack>
<CompareFunctions
originalFunction={variant.constructFn}
newFunction={isString(refinedPromptFn) ? refinedPromptFn : undefined}
maxH="40vh"
/>
</VStack>
</ModalBody>
@@ -124,12 +129,10 @@ export const RefinePromptModal = ({
<ModalFooter>
<HStack spacing={4}>
<Button
colorScheme="blue"
onClick={replaceVariant}
minW={24}
disabled={replacementInProgress || !refinedPromptFn}
_disabled={{
bgColor: "blue.500",
}}
isDisabled={replacementInProgress || !refinedPromptFn}
>
{replacementInProgress ? <Spinner boxSize={4} /> : <Text>Accept</Text>}
</Button>

View File

@@ -12,7 +12,7 @@ export const refineOptions: Record<
This is what a prompt looks like before adding chain of thought:
prompt = {
definePrompt("openai/ChatCompletion", {
model: "gpt-4",
stream: true,
messages: [
@@ -25,11 +25,11 @@ export const refineOptions: Record<
content: \`This is the user's message: \${scenario.user_message}. Return "positive" or "negative" or "neutral"\`,
},
],
};
});
This is what one looks like after adding chain of thought:
prompt = {
definePrompt("openai/ChatCompletion", {
model: "gpt-4",
stream: true,
messages: [
@@ -42,13 +42,13 @@ export const refineOptions: Record<
content: \`This is the user's message: \${scenario.user_message}. Return "positive" or "negative" or "neutral". Explain your answer before you give a score, then return the score on a new line.\`,
},
],
};
});
Here's another example:
Before:
prompt = {
definePrompt("openai/ChatCompletion", {
model: "gpt-3.5-turbo",
messages: [
{
@@ -78,11 +78,11 @@ export const refineOptions: Record<
function_call: {
name: "score_post",
},
};
});
After:
prompt = {
definePrompt("openai/ChatCompletion", {
model: "gpt-3.5-turbo",
messages: [
{
@@ -115,7 +115,7 @@ export const refineOptions: Record<
function_call: {
name: "score_post",
},
};
});
Add chain of thought to the original prompt.`,
},
@@ -125,7 +125,7 @@ export const refineOptions: Record<
This is what a prompt looks like before adding a function:
prompt = {
definePrompt("openai/ChatCompletion", {
model: "gpt-4",
stream: true,
messages: [
@@ -138,11 +138,11 @@ export const refineOptions: Record<
content: \`This is the user's message: \${scenario.user_message}. Return "positive" or "negative" or "neutral"\`,
},
],
};
});
This is what one looks like after adding a function:
prompt = {
definePrompt("openai/ChatCompletion", {
model: "gpt-4",
stream: true,
messages: [
@@ -172,13 +172,13 @@ export const refineOptions: Record<
function_call: {
name: "extract_sentiment",
},
};
});
Here's another example of adding a function:
Before:
prompt = {
definePrompt("openai/ChatCompletion", {
model: "gpt-3.5-turbo",
messages: [
{
@@ -196,11 +196,11 @@ export const refineOptions: Record<
},
],
temperature: 0,
};
});
After:
prompt = {
definePrompt("openai/ChatCompletion", {
model: "gpt-3.5-turbo",
messages: [
{
@@ -230,7 +230,7 @@ export const refineOptions: Record<
function_call: {
name: "score_post",
},
};
});
Add an OpenAI function that takes one or more nested parameters that match the expected output from this prompt.`,
},

View File

@@ -20,36 +20,60 @@ import { ModelStatsCard } from "./ModelStatsCard";
import { SelectModelSearch } from "./SelectModelSearch";
import { api } from "~/utils/api";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
import CompareFunctions from "../RefinePromptModal/CompareFunctions";
import { type PromptVariant } from "@prisma/client";
import { isObject, isString } from "lodash-es";
export const SelectModelModal = ({
originalModel,
variantId,
variant,
onClose,
}: {
originalModel: SupportedModel;
variantId: string;
variant: PromptVariant;
onClose: () => void;
}) => {
const originalModel = variant.model as SupportedModel;
const [selectedModel, setSelectedModel] = useState<SupportedModel>(originalModel);
const [convertedModel, setConvertedModel] = useState<SupportedModel | undefined>(undefined);
const utils = api.useContext();
const experiment = useExperiment();
const createMutation = api.promptVariants.create.useMutation();
const { mutateAsync: getModifiedPromptMutateAsync, data: modifiedPromptFn } =
api.promptVariants.getModifiedPromptFn.useMutation();
const [createNewVariant, creationInProgress] = useHandledAsyncCallback(async () => {
if (!experiment?.data?.id) return;
await createMutation.mutateAsync({
experimentId: experiment?.data?.id,
variantId,
const [getModifiedPromptFn, modificationInProgress] = useHandledAsyncCallback(async () => {
if (!experiment) return;
await getModifiedPromptMutateAsync({
id: variant.id,
newModel: selectedModel,
});
setConvertedModel(selectedModel);
}, [getModifiedPromptMutateAsync, onClose, experiment, variant, selectedModel]);
const replaceVariantMutation = api.promptVariants.replaceVariant.useMutation();
const [replaceVariant, replacementInProgress] = useHandledAsyncCallback(async () => {
if (
!variant.experimentId ||
!modifiedPromptFn ||
(isObject(modifiedPromptFn) && "status" in modifiedPromptFn)
)
return;
await replaceVariantMutation.mutateAsync({
id: variant.id,
constructFn: modifiedPromptFn,
});
await utils.promptVariants.list.invalidate();
onClose();
}, [createMutation, experiment?.data?.id, variantId, onClose]);
}, [replaceVariantMutation, variant, onClose, modifiedPromptFn]);
return (
<Modal isOpen onClose={onClose} size={{ base: "xl", sm: "2xl", md: "3xl" }}>
<Modal
isOpen
onClose={onClose}
size={{ base: "xl", sm: "2xl", md: "3xl", lg: "5xl", xl: "7xl" }}
>
<ModalOverlay />
<ModalContent w={1200}>
<ModalHeader>
@@ -66,18 +90,36 @@ export const SelectModelModal = ({
<ModelStatsCard label="New Model" model={selectedModel} />
)}
<SelectModelSearch selectedModel={selectedModel} setSelectedModel={setSelectedModel} />
{isString(modifiedPromptFn) && (
<CompareFunctions
originalFunction={variant.constructFn}
newFunction={modifiedPromptFn}
leftTitle={originalModel}
rightTitle={convertedModel}
/>
)}
</VStack>
</ModalBody>
<ModalFooter>
<Button
colorScheme="blue"
onClick={createNewVariant}
minW={24}
disabled={originalModel === selectedModel}
>
{creationInProgress ? <Spinner boxSize={4} /> : <Text>Continue</Text>}
</Button>
<HStack>
<Button
colorScheme="gray"
onClick={getModifiedPromptFn}
minW={24}
isDisabled={originalModel === selectedModel || modificationInProgress}
>
{modificationInProgress ? <Spinner boxSize={4} /> : <Text>Convert</Text>}
</Button>
<Button
colorScheme="blue"
onClick={replaceVariant}
minW={24}
isDisabled={!convertedModel || modificationInProgress || replacementInProgress}
>
{replacementInProgress ? <Spinner boxSize={4} /> : <Text>Accept</Text>}
</Button>
</HStack>
</ModalFooter>
</ModalContent>
</Modal>

View File

@@ -18,7 +18,6 @@ import { useState } from "react";
import { RefinePromptModal } from "../RefinePromptModal/RefinePromptModal";
import { RiExchangeFundsFill } from "react-icons/ri";
import { SelectModelModal } from "../SelectModelModal/SelectModelModal";
import { type SupportedModel } from "~/server/types";
export default function VariantHeaderMenuButton({
variant,
@@ -99,11 +98,7 @@ export default function VariantHeaderMenuButton({
</MenuList>
</Menu>
{selectModelModalOpen && (
<SelectModelModal
originalModel={variant.model as SupportedModel}
variantId={variant.id}
onClose={() => setSelectModelModalOpen(false)}
/>
<SelectModelModal variant={variant} onClose={() => setSelectModelModalOpen(false)} />
)}
{refinePromptModalOpen && (
<RefinePromptModal variant={variant} onClose={() => setRefinePromptModalOpen(false)} />

View File

@@ -17,7 +17,6 @@ export const env = createEnv({
.transform((val) => val.toLowerCase() === "true"),
GITHUB_CLIENT_ID: z.string().min(1),
GITHUB_CLIENT_SECRET: z.string().min(1),
REPLICATE_API_TOKEN: z.string().min(1),
},
/**
@@ -43,7 +42,6 @@ export const env = createEnv({
NEXT_PUBLIC_SOCKET_URL: process.env.NEXT_PUBLIC_SOCKET_URL,
GITHUB_CLIENT_ID: process.env.GITHUB_CLIENT_ID,
GITHUB_CLIENT_SECRET: process.env.GITHUB_CLIENT_SECRET,
REPLICATE_API_TOKEN: process.env.REPLICATE_API_TOKEN,
},
/**
* Run `build` or `dev` with `SKIP_ENV_VALIDATION` to skip env validation.

View File

@@ -1,9 +1,7 @@
import openaiChatCompletion from "./openai-ChatCompletion";
import replicateLlama2 from "./replicate-llama2";
const modelProviders = {
"openai/ChatCompletion": openaiChatCompletion,
"replicate/llama2": replicateLlama2,
} as const;
export default modelProviders;

View File

@@ -1,14 +1,10 @@
import openaiChatCompletionFrontend from "./openai-ChatCompletion/frontend";
import replicateLlama2Frontend from "./replicate-llama2/frontend";
// TODO: make sure we get a typescript error if you forget to add a provider here
import modelProviderFrontend from "./openai-ChatCompletion/frontend";
// Keep attributes here that need to be accessible from the frontend. We can't
// just include them in the default `modelProviders` object because it has some
// transient dependencies that can only be imported on the server.
const modelProvidersFrontend = {
"openai/ChatCompletion": openaiChatCompletionFrontend,
"replicate/llama2": replicateLlama2Frontend,
"openai/ChatCompletion": modelProviderFrontend,
} as const;
export default modelProvidersFrontend;

View File

@@ -1,13 +0,0 @@
import { type ReplicateLlama2Provider } from ".";
import { type ModelProviderFrontend } from "../types";
const modelProviderFrontend: ModelProviderFrontend<ReplicateLlama2Provider> = {
normalizeOutput: (output) => {
return {
type: "text",
value: output.join(""),
};
},
};
export default modelProviderFrontend;

View File

@@ -1,62 +0,0 @@
import { env } from "~/env.mjs";
import { type ReplicateLlama2Input, type ReplicateLlama2Output } from ".";
import { type CompletionResponse } from "../types";
import Replicate from "replicate";
const replicate = new Replicate({
auth: env.REPLICATE_API_TOKEN || "",
});
const modelIds: Record<ReplicateLlama2Input["model"], string> = {
"7b-chat": "3725a659b5afff1a0ba9bead5fac3899d998feaad00e07032ca2b0e35eb14f8a",
"13b-chat": "5c785d117c5bcdd1928d5a9acb1ffa6272d6cf13fcb722e90886a0196633f9d3",
"70b-chat": "e951f18578850b652510200860fc4ea62b3b16fac280f83ff32282f87bbd2e48",
};
export async function getCompletion(
input: ReplicateLlama2Input,
onStream: ((partialOutput: string[]) => void) | null,
): Promise<CompletionResponse<ReplicateLlama2Output>> {
const start = Date.now();
const { model, stream, ...rest } = input;
try {
const prediction = await replicate.predictions.create({
version: modelIds[model],
input: rest,
});
console.log("stream?", onStream);
const interval = onStream
? // eslint-disable-next-line @typescript-eslint/no-misused-promises
setInterval(async () => {
const partialPrediction = await replicate.predictions.get(prediction.id);
if (partialPrediction.output) onStream(partialPrediction.output as ReplicateLlama2Output);
}, 500)
: null;
const resp = await replicate.wait(prediction, {});
if (interval) clearInterval(interval);
const timeToComplete = Date.now() - start;
if (resp.error) throw new Error(resp.error as string);
return {
type: "success",
statusCode: 200,
value: resp.output as ReplicateLlama2Output,
timeToComplete,
};
} catch (error: unknown) {
console.error("ERROR IS", error);
return {
type: "error",
message: (error as Error).message,
autoRetry: true,
};
}
}

View File

@@ -1,74 +0,0 @@
import { type ModelProvider } from "../types";
import { getCompletion } from "./getCompletion";
const supportedModels = ["7b-chat", "13b-chat", "70b-chat"] as const;
type SupportedModel = (typeof supportedModels)[number];
export type ReplicateLlama2Input = {
model: SupportedModel;
prompt: string;
stream?: boolean;
max_length?: number;
temperature?: number;
top_p?: number;
repetition_penalty?: number;
debug?: boolean;
};
export type ReplicateLlama2Output = string[];
export type ReplicateLlama2Provider = ModelProvider<
SupportedModel,
ReplicateLlama2Input,
ReplicateLlama2Output
>;
const modelProvider: ReplicateLlama2Provider = {
name: "OpenAI ChatCompletion",
models: {
"7b-chat": {},
"13b-chat": {},
"70b-chat": {},
},
getModel: (input) => {
if (supportedModels.includes(input.model)) return input.model;
return null;
},
inputSchema: {
type: "object",
properties: {
model: {
type: "string",
enum: supportedModels as unknown as string[],
},
prompt: {
type: "string",
},
stream: {
type: "boolean",
},
max_length: {
type: "number",
},
temperature: {
type: "number",
},
top_p: {
type: "number",
},
repetition_penalty: {
type: "number",
},
debug: {
type: "boolean",
},
},
required: ["model", "prompt"],
},
shouldStream: (input) => input.stream ?? false,
getCompletion,
};
export default modelProvider;

View File

@@ -2,8 +2,8 @@ import { type JSONSchema4 } from "json-schema";
import { type JsonValue } from "type-fest";
type ModelProviderModel = {
name?: string;
learnMore?: string;
name: string;
learnMore: string;
};
export type CompletionResponse<T> =

View File

@@ -20,22 +20,25 @@ export default function ExperimentsPage() {
const experiments = api.experiments.list.useQuery();
const user = useSession().data;
const authLoading = useSession().status === "loading";
if (user === null) {
if (user === null || authLoading) {
return (
<AppShell title="Experiments">
<Center h="100%">
<Text>
<Link
onClick={() => {
signIn("github").catch(console.error);
}}
textDecor="underline"
>
Sign in
</Link>{" "}
to view or create new experiments!
</Text>
{!authLoading && (
<Text>
<Link
onClick={() => {
signIn("github").catch(console.error);
}}
textDecor="underline"
>
Sign in
</Link>{" "}
to view or create new experiments!
</Text>
)}
</Center>
</AppShell>
);

View File

@@ -98,7 +98,7 @@ export const experimentsRouter = createTRPCRouter({
},
});
const [variant, _, scenario] = await prisma.$transaction([
const [variant, _, scenario1, scenario2, scenario3] = await prisma.$transaction([
prisma.promptVariant.create({
data: {
experimentId: exp.id,
@@ -109,7 +109,8 @@ export const experimentsRouter = createTRPCRouter({
constructFn: dedent`
/**
* Use Javascript to define an OpenAI chat completion
* (https://platform.openai.com/docs/api-reference/chat/create).
* (https://platform.openai.com/docs/api-reference/chat/create) and
* assign it to the \`prompt\` variable.
*
* You have access to the current scenario in the \`scenario\`
* variable.
@@ -121,7 +122,7 @@ export const experimentsRouter = createTRPCRouter({
messages: [
{
role: "system",
content: \`"Return 'this is output for the scenario "${"$"}{scenario.text}"'\`,
content: \`Write 'Start experimenting!' in ${"$"}{scenario.language}\`,
},
],
});`,
@@ -133,20 +134,38 @@ export const experimentsRouter = createTRPCRouter({
prisma.templateVariable.create({
data: {
experimentId: exp.id,
label: "text",
label: "language",
},
}),
prisma.testScenario.create({
data: {
experimentId: exp.id,
variableValues: {
text: "This is a test scenario.",
language: "English",
},
},
}),
prisma.testScenario.create({
data: {
experimentId: exp.id,
variableValues: {
language: "Spanish",
},
},
}),
prisma.testScenario.create({
data: {
experimentId: exp.id,
variableValues: {
language: "German",
},
},
}),
]);
await generateNewCell(variant.id, scenario.id);
await generateNewCell(variant.id, scenario1.id);
await generateNewCell(variant.id, scenario2.id);
await generateNewCell(variant.id, scenario3.id);
return exp;
}),

View File

@@ -284,11 +284,12 @@ export const promptVariantsRouter = createTRPCRouter({
return updatedPromptVariant;
}),
getRefinedPromptFn: protectedProcedure
getModifiedPromptFn: protectedProcedure
.input(
z.object({
id: z.string(),
instructions: z.string(),
instructions: z.string().optional(),
newModel: z.string().optional(),
}),
)
.mutation(async ({ input, ctx }) => {
@@ -307,7 +308,7 @@ export const promptVariantsRouter = createTRPCRouter({
const promptConstructionFn = await deriveNewConstructFn(
existing,
constructedPrompt.model as SupportedModel,
input.newModel as SupportedModel | undefined,
input.instructions,
);

View File

@@ -1,26 +1,26 @@
/* eslint-disable */
// /* eslint-disable */
import "dotenv/config";
import Replicate from "replicate";
// import "dotenv/config";
// import Replicate from "replicate";
const replicate = new Replicate({
auth: process.env.REPLICATE_API_TOKEN || "",
});
// const replicate = new Replicate({
// auth: process.env.REPLICATE_API_TOKEN || "",
// });
console.log("going to run");
const prediction = await replicate.predictions.create({
version: "3725a659b5afff1a0ba9bead5fac3899d998feaad00e07032ca2b0e35eb14f8a",
input: {
prompt: "...",
},
});
// console.log("going to run");
// const prediction = await replicate.predictions.create({
// version: "e951f18578850b652510200860fc4ea62b3b16fac280f83ff32282f87bbd2e48",
// input: {
// prompt: "...",
// },
// });
console.log("waiting");
setInterval(() => {
replicate.predictions.get(prediction.id).then((prediction) => {
console.log(prediction);
});
}, 500);
// const output = await replicate.wait(prediction, {});
// console.log("waiting");
// setInterval(() => {
// replicate.predictions.get(prediction.id).then((prediction) => {
// console.log(prediction.output);
// });
// }, 500);
// // const output = await replicate.wait(prediction, {});
// console.log(output);
// // console.log(output);

View File

@@ -99,7 +99,6 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
const provider = modelProviders[prompt.modelProvider];
// @ts-expect-error TODO FIX ASAP
const streamingChannel = provider.shouldStream(prompt.modelInput) ? generateChannel() : null;
if (streamingChannel) {
@@ -116,8 +115,6 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
: null;
for (let i = 0; true; i++) {
// @ts-expect-error TODO FIX ASAP
const response = await provider.getCompletion(prompt.modelInput, onStream);
if (response.type === "success") {
const inputHash = hashPrompt(prompt);

View File

@@ -70,7 +70,6 @@ export default async function parseConstructFn(
// We've validated the JSON schema so this should be safe
const input = prompt.input as Parameters<(typeof provider)["getModel"]>[0];
// @ts-expect-error TODO FIX ASAP
const model = provider.getModel(input);
if (!model) {
return {
@@ -80,8 +79,6 @@ export default async function parseConstructFn(
return {
modelProvider: prompt.modelProvider as keyof typeof modelProviders,
// @ts-expect-error TODO FIX ASAP
model,
modelInput: input,
};