Compare commits

13 Commits: `model-prov` ... `change-mod`

| Author | SHA1 | Date |
|---|---|---|
|  | 01343efb6a |  |
|  | c7aaaea426 |  |
|  | 332e7afb0c |  |
|  | fe08e29f47 |  |
|  | 89ce730e52 |  |
|  | ad87c1b2eb |  |
|  | 58ddc72cbb |  |
|  | 9978075867 |  |
|  | 372c2512c9 |  |
|  | 1822fe198e |  |
|  | f06e1db3db |  |
|  | 9314a86857 |  |
|  | 54dcb4a567 |  |
**package.json**

```diff
@@ -73,7 +73,6 @@
     "react-syntax-highlighter": "^15.5.0",
     "react-textarea-autosize": "^8.5.0",
     "recast": "^0.23.3",
-    "replicate": "^0.12.3",
     "socket.io": "^4.7.1",
     "socket.io-client": "^4.7.1",
     "superjson": "1.12.2",
```
**pnpm-lock.yaml** (generated; 8 changed lines)
```diff
@@ -161,9 +161,6 @@ dependencies:
   recast:
     specifier: ^0.23.3
     version: 0.23.3
-  replicate:
-    specifier: ^0.12.3
-    version: 0.12.3
   socket.io:
     specifier: ^4.7.1
     version: 4.7.1
@@ -6991,11 +6988,6 @@ packages:
       functions-have-names: 1.2.3
     dev: true

-  /replicate@0.12.3:
-    resolution: {integrity: sha512-HVWKPoVhWVTONlWk+lUXmq9Vy2J8MxBJMtDBQq3dA5uq71ZzKTh0xvJfvzW4+VLBjhBeL7tkdua6hZJmKfzAPQ==}
-    engines: {git: '>=2.11.0', node: '>=16.6.0', npm: '>=7.19.0', yarn: '>=1.7.0'}
-    dev: false
-
   /require-directory@2.1.1:
     resolution: {integrity: sha512-fGxEI7+wsG9xrvdjsrlmL22OMTTiHRwAMroiEeMgq8gzoLC/PQr7RsRDSTLUg/bZAZtF+TVIkHc6/4RIKrui+Q==}
     engines: {node: '>=0.10.0'}
```
**AutoResizeTextarea**

```diff
@@ -1,19 +1,22 @@
 import { Textarea, type TextareaProps } from "@chakra-ui/react";
 import ResizeTextarea from "react-textarea-autosize";
-import React from "react";
+import React, { useLayoutEffect, useState } from "react";

 export const AutoResizeTextarea: React.ForwardRefRenderFunction<
   HTMLTextAreaElement,
   TextareaProps & { minRows?: number }
-> = (props, ref) => {
+> = ({ minRows = 1, overflowY = "hidden", ...props }, ref) => {
+  const [isRerendered, setIsRerendered] = useState(false);
+  useLayoutEffect(() => setIsRerendered(true), []);
+
   return (
     <Textarea
       minH="unset"
-      overflow="hidden"
+      minRows={minRows}
+      overflowY={isRerendered ? overflowY : "hidden"}
       w="100%"
       resize="none"
       ref={ref}
-      minRows={1}
       transition="height none"
       as={ResizeTextarea}
       {...props}
```
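The `AutoResizeTextarea` change avoids a hydration mismatch: `useLayoutEffect` never runs during server rendering, so `isRerendered` stays `false` and the textarea keeps `overflowY="hidden"` until the first client render, after which the caller's `overflowY` takes effect. A minimal sketch of the same guard as a reusable hook (the hook name is hypothetical, not part of this diff):

```ts
import { useLayoutEffect, useState } from "react";

// false during SSR and the initial render, true after the first client
// render -- the same guard as the isRerendered state above.
export function useHasHydrated(): boolean {
  const [hasHydrated, setHasHydrated] = useState(false);
  useLayoutEffect(() => setHasHydrated(true), []);
  return hasHydrated;
}
```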
**FloatingLabelInput**

```diff
@@ -18,11 +18,9 @@ export const FloatingLabelInput = ({
       transform={isFocused || !!value ? "translateY(-50%)" : "translateY(0)"}
       fontSize={isFocused || !!value ? "12px" : "16px"}
       transition="all 0.15s"
-      zIndex="100"
+      zIndex="5"
       bg="white"
       px={1}
-      mt={0}
-      mb={2}
       lineHeight="1"
       pointerEvents="none"
       color={isFocused ? "blue.500" : "gray.500"}
```
**NewScenarioButton**

```diff
@@ -49,7 +49,11 @@ export default function NewScenarioButton() {
         Add Scenario
       </StyledButton>
       <StyledButton onClick={onAutogenerate}>
-        <Icon as={autogenerating ? Spinner : BsPlus} boxSize={6} mr={autogenerating ? 1 : 0} />
+        <Icon
+          as={autogenerating ? Spinner : BsPlus}
+          boxSize={autogenerating ? 4 : 6}
+          mr={autogenerating ? 2 : 0}
+        />
         Autogenerate Scenario
       </StyledButton>
     </HStack>
```
**OutputCell**

```diff
@@ -88,11 +88,9 @@ export default function OutputCell({
   }

   const normalizedOutput = modelOutput
-    ? // @ts-expect-error TODO FIX ASAP
-      provider.normalizeOutput(modelOutput.output as unknown as OutputSchema)
+    ? provider.normalizeOutput(modelOutput.output as unknown as OutputSchema)
     : streamedMessage
-    ? // @ts-expect-error TODO FIX ASAP
-      provider.normalizeOutput(streamedMessage)
+    ? provider.normalizeOutput(streamedMessage)
     : null;

   if (modelOutput && normalizedOutput?.type === "json") {
```
**stickyHeaderStyle**

```diff
@@ -4,5 +4,5 @@ export const stickyHeaderStyle: SystemStyleObject = {
   position: "sticky",
   top: "0",
   backgroundColor: "#fff",
-  zIndex: 1,
+  zIndex: 10,
 };
```
**CompareFunctions**

```diff
@@ -1,4 +1,4 @@
-import { HStack, VStack, useBreakpointValue } from "@chakra-ui/react";
+import { type StackProps, VStack, useBreakpointValue } from "@chakra-ui/react";
 import React from "react";
 import DiffViewer, { DiffMethod } from "react-diff-viewer";
 import Prism from "prismjs";
@@ -19,10 +19,15 @@ const highlightSyntax = (str: string) => {
 const CompareFunctions = ({
   originalFunction,
   newFunction = "",
+  leftTitle = "Original",
+  rightTitle = "Modified",
+  ...props
 }: {
   originalFunction: string;
   newFunction?: string;
-}) => {
+  leftTitle?: string;
+  rightTitle?: string;
+} & StackProps) => {
   const showSplitView = useBreakpointValue(
     {
       base: false,
@@ -34,22 +39,20 @@ const CompareFunctions = ({
   );

   return (
-    <HStack w="full" spacing={5}>
-      <VStack w="full" spacing={4} maxH="40vh" fontSize={12} lineHeight={1} overflowY="auto">
+    <VStack w="full" spacing={4} fontSize={12} lineHeight={1} overflowY="auto" {...props}>
       <DiffViewer
         oldValue={originalFunction}
         newValue={newFunction || originalFunction}
         splitView={showSplitView}
         hideLineNumbers={!showSplitView}
-        leftTitle="Original"
-        rightTitle={newFunction ? "Modified" : "Unmodified"}
+        leftTitle={leftTitle}
+        rightTitle={rightTitle}
         disableWordDiff={true}
         compareMethod={DiffMethod.CHARS}
         renderContent={highlightSyntax}
         showDiffOnly={false}
       />
     </VStack>
-    </HStack>
   );
 };

```
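`CompareFunctions` now accepts optional `leftTitle`/`rightTitle` props and forwards any remaining `StackProps` to its outer `VStack`, so callers own the layout (the hard-coded `maxH="40vh"` moves to `RefinePromptModal` below). A hypothetical call site, using the import path that `SelectModelModal` uses later in this diff:

```tsx
import CompareFunctions from "../RefinePromptModal/CompareFunctions";

// Titles and layout overrides come from the caller; maxH reaches the outer
// VStack through the StackProps spread.
export const ExampleDiff = () => (
  <CompareFunctions
    originalFunction={'prompt = { model: "gpt-4" };'}
    newFunction={'definePrompt("openai/ChatCompletion", { model: "gpt-4" });'}
    leftTitle="gpt-4"
    rightTitle="gpt-3.5-turbo"
    maxH="40vh"
  />
);
```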
**CustomInstructionsInput**

```diff
@@ -56,7 +56,6 @@ export const CustomInstructionsInput = ({
       minW="unset"
       size="sm"
       onClick={() => onSubmit()}
-      disabled={!instructions}
       variant={instructions ? "solid" : "ghost"}
       mr={4}
       borderRadius="8"
```
**RefinePromptModal**

```diff
@@ -36,25 +36,25 @@ export const RefinePromptModal = ({
 }) => {
   const utils = api.useContext();

-  const { mutateAsync: getRefinedPromptMutateAsync, data: refinedPromptFn } =
-    api.promptVariants.getRefinedPromptFn.useMutation();
+  const { mutateAsync: getModifiedPromptMutateAsync, data: refinedPromptFn } =
+    api.promptVariants.getModifiedPromptFn.useMutation();
   const [instructions, setInstructions] = useState<string>("");

   const [activeRefineOptionLabel, setActiveRefineOptionLabel] = useState<
     RefineOptionLabel | undefined
   >(undefined);

-  const [getRefinedPromptFn, refiningInProgress] = useHandledAsyncCallback(
+  const [getModifiedPromptFn, modificationInProgress] = useHandledAsyncCallback(
     async (label?: RefineOptionLabel) => {
       if (!variant.experimentId) return;
       const updatedInstructions = label ? refineOptions[label].instructions : instructions;
       setActiveRefineOptionLabel(label);
-      await getRefinedPromptMutateAsync({
+      await getModifiedPromptMutateAsync({
         id: variant.id,
         instructions: updatedInstructions,
       });
     },
-    [getRefinedPromptMutateAsync, onClose, variant, instructions, setActiveRefineOptionLabel],
+    [getModifiedPromptMutateAsync, onClose, variant, instructions, setActiveRefineOptionLabel],
   );

   const replaceVariantMutation = api.promptVariants.replaceVariant.useMutation();
@@ -75,7 +75,11 @@ export const RefinePromptModal = ({
   }, [replaceVariantMutation, variant, onClose, refinedPromptFn]);

   return (
-    <Modal isOpen onClose={onClose} size={{ base: "xl", sm: "2xl", md: "7xl" }}>
+    <Modal
+      isOpen
+      onClose={onClose}
+      size={{ base: "xl", sm: "2xl", md: "3xl", lg: "5xl", xl: "7xl" }}
+    >
       <ModalOverlay />
       <ModalContent w={1200}>
         <ModalHeader>
@@ -93,15 +97,15 @@ export const RefinePromptModal = ({
             label="Convert to function call"
             activeLabel={activeRefineOptionLabel}
             icon={VscJson}
-            onClick={getRefinedPromptFn}
-            loading={refiningInProgress}
+            onClick={getModifiedPromptFn}
+            loading={modificationInProgress}
           />
           <RefineOption
             label="Add chain of thought"
             activeLabel={activeRefineOptionLabel}
             icon={TfiThought}
-            onClick={getRefinedPromptFn}
-            loading={refiningInProgress}
+            onClick={getModifiedPromptFn}
+            loading={modificationInProgress}
           />
         </SimpleGrid>
         <HStack>
@@ -110,13 +114,14 @@ export const RefinePromptModal = ({
           <CustomInstructionsInput
             instructions={instructions}
             setInstructions={setInstructions}
-            loading={refiningInProgress}
-            onSubmit={getRefinedPromptFn}
+            loading={modificationInProgress}
+            onSubmit={getModifiedPromptFn}
           />
         </VStack>
         <CompareFunctions
           originalFunction={variant.constructFn}
           newFunction={isString(refinedPromptFn) ? refinedPromptFn : undefined}
+          maxH="40vh"
         />
       </VStack>
     </ModalBody>
@@ -124,12 +129,10 @@ export const RefinePromptModal = ({
     <ModalFooter>
       <HStack spacing={4}>
         <Button
+          colorScheme="blue"
           onClick={replaceVariant}
           minW={24}
-          disabled={replacementInProgress || !refinedPromptFn}
-          _disabled={{
-            bgColor: "blue.500",
-          }}
+          isDisabled={replacementInProgress || !refinedPromptFn}
         >
           {replacementInProgress ? <Spinner boxSize={4} /> : <Text>Accept</Text>}
         </Button>
```
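Alongside the `getRefinedPromptFn` → `getModifiedPromptFn` rename, the modal gains intermediate `lg`/`xl` sizes instead of jumping from `2xl` to `7xl`, the inline `maxH="40vh"` moves onto `CompareFunctions` now that it forwards `StackProps`, and the Accept button switches from the DOM `disabled` attribute to Chakra UI's supported `isDisabled` prop, which makes the `_disabled` style override unnecessary.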
**refineOptions**

```diff
@@ -12,7 +12,7 @@ export const refineOptions: Record<

     This is what a prompt looks like before adding chain of thought:

-    prompt = {
+    definePrompt("openai/ChatCompletion", {
       model: "gpt-4",
       stream: true,
       messages: [
@@ -25,11 +25,11 @@ export const refineOptions: Record<
           content: \`This is the user's message: \${scenario.user_message}. Return "positive" or "negative" or "neutral"\`,
         },
       ],
-    };
+    });

     This is what one looks like after adding chain of thought:

-    prompt = {
+    definePrompt("openai/ChatCompletion", {
       model: "gpt-4",
       stream: true,
       messages: [
@@ -42,13 +42,13 @@ export const refineOptions: Record<
           content: \`This is the user's message: \${scenario.user_message}. Return "positive" or "negative" or "neutral". Explain your answer before you give a score, then return the score on a new line.\`,
         },
       ],
-    };
+    });

     Here's another example:

     Before:

-    prompt = {
+    definePrompt("openai/ChatCompletion", {
       model: "gpt-3.5-turbo",
       messages: [
         {
@@ -78,11 +78,11 @@ export const refineOptions: Record<
       function_call: {
         name: "score_post",
       },
-    };
+    });

     After:

-    prompt = {
+    definePrompt("openai/ChatCompletion", {
       model: "gpt-3.5-turbo",
       messages: [
         {
@@ -115,7 +115,7 @@ export const refineOptions: Record<
       function_call: {
         name: "score_post",
       },
-    };
+    });

     Add chain of thought to the original prompt.\`,
   },
@@ -125,7 +125,7 @@ export const refineOptions: Record<

     This is what a prompt looks like before adding a function:

-    prompt = {
+    definePrompt("openai/ChatCompletion", {
       model: "gpt-4",
       stream: true,
       messages: [
@@ -138,11 +138,11 @@ export const refineOptions: Record<
           content: \`This is the user's message: \${scenario.user_message}. Return "positive" or "negative" or "neutral"\`,
         },
       ],
-    };
+    });

     This is what one looks like after adding a function:

-    prompt = {
+    definePrompt("openai/ChatCompletion", {
       model: "gpt-4",
       stream: true,
       messages: [
@@ -172,13 +172,13 @@ export const refineOptions: Record<
       function_call: {
         name: "extract_sentiment",
       },
-    };
+    });

     Here's another example of adding a function:

     Before:

-    prompt = {
+    definePrompt("openai/ChatCompletion", {
       model: "gpt-3.5-turbo",
       messages: [
         {
@@ -196,11 +196,11 @@ export const refineOptions: Record<
         },
       ],
       temperature: 0,
-    };
+    });

     After:

-    prompt = {
+    definePrompt("openai/ChatCompletion", {
       model: "gpt-3.5-turbo",
       messages: [
         {
@@ -230,7 +230,7 @@ export const refineOptions: Record<
       function_call: {
         name: "score_post",
       },
-    };
+    });

     Add an OpenAI function that takes one or more nested parameters that match the expected output from this prompt.\`,
   },
```
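Every few-shot example in `refineOptions` swaps the bare `prompt = { ... };` assignment for a `definePrompt("openai/ChatCompletion", { ... });` call, matching the new construction-function format. A self-contained sketch of the target shape (the `declare` stubs stand in for globals the prompt sandbox provides; they are not part of this diff):

```ts
// Stubs so the sketch compiles on its own; the real definePrompt and
// scenario come from the prompt-construction environment.
declare function definePrompt(provider: "openai/ChatCompletion", input: unknown): void;
declare const scenario: Record<string, string>;

definePrompt("openai/ChatCompletion", {
  model: "gpt-4",
  stream: true,
  messages: [
    {
      role: "system",
      content: `This is the user's message: ${scenario.user_message}. Return "positive" or "negative" or "neutral"`,
    },
  ],
});
```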
**SelectModelModal**

```diff
@@ -20,36 +20,60 @@ import { ModelStatsCard } from "./ModelStatsCard";
 import { SelectModelSearch } from "./SelectModelSearch";
 import { api } from "~/utils/api";
 import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
+import CompareFunctions from "../RefinePromptModal/CompareFunctions";
+import { type PromptVariant } from "@prisma/client";
+import { isObject, isString } from "lodash-es";

 export const SelectModelModal = ({
-  originalModel,
-  variantId,
+  variant,
   onClose,
 }: {
-  originalModel: SupportedModel;
-  variantId: string;
+  variant: PromptVariant;
   onClose: () => void;
 }) => {
+  const originalModel = variant.model as SupportedModel;
   const [selectedModel, setSelectedModel] = useState<SupportedModel>(originalModel);
+  const [convertedModel, setConvertedModel] = useState<SupportedModel | undefined>(undefined);
   const utils = api.useContext();

   const experiment = useExperiment();

-  const createMutation = api.promptVariants.create.useMutation();
+  const { mutateAsync: getModifiedPromptMutateAsync, data: modifiedPromptFn } =
+    api.promptVariants.getModifiedPromptFn.useMutation();

-  const [createNewVariant, creationInProgress] = useHandledAsyncCallback(async () => {
-    if (!experiment?.data?.id) return;
-    await createMutation.mutateAsync({
-      experimentId: experiment?.data?.id,
-      variantId,
+  const [getModifiedPromptFn, modificationInProgress] = useHandledAsyncCallback(async () => {
+    if (!experiment) return;
+    await getModifiedPromptMutateAsync({
+      id: variant.id,
       newModel: selectedModel,
     });
+    setConvertedModel(selectedModel);
+  }, [getModifiedPromptMutateAsync, onClose, experiment, variant, selectedModel]);
+
+  const replaceVariantMutation = api.promptVariants.replaceVariant.useMutation();
+
+  const [replaceVariant, replacementInProgress] = useHandledAsyncCallback(async () => {
+    if (
+      !variant.experimentId ||
+      !modifiedPromptFn ||
+      (isObject(modifiedPromptFn) && "status" in modifiedPromptFn)
+    )
+      return;
+    await replaceVariantMutation.mutateAsync({
+      id: variant.id,
+      constructFn: modifiedPromptFn,
+    });
     await utils.promptVariants.list.invalidate();
     onClose();
-  }, [createMutation, experiment?.data?.id, variantId, onClose]);
+  }, [replaceVariantMutation, variant, onClose, modifiedPromptFn]);

   return (
-    <Modal isOpen onClose={onClose} size={{ base: "xl", sm: "2xl", md: "3xl" }}>
+    <Modal
+      isOpen
+      onClose={onClose}
+      size={{ base: "xl", sm: "2xl", md: "3xl", lg: "5xl", xl: "7xl" }}
+    >
       <ModalOverlay />
       <ModalContent w={1200}>
         <ModalHeader>
@@ -66,18 +90,36 @@ export const SelectModelModal = ({
           <ModelStatsCard label="New Model" model={selectedModel} />
         )}
         <SelectModelSearch selectedModel={selectedModel} setSelectedModel={setSelectedModel} />
+        {isString(modifiedPromptFn) && (
+          <CompareFunctions
+            originalFunction={variant.constructFn}
+            newFunction={modifiedPromptFn}
+            leftTitle={originalModel}
+            rightTitle={convertedModel}
+          />
+        )}
       </VStack>
     </ModalBody>

     <ModalFooter>
+      <HStack>
+        <Button
+          colorScheme="gray"
+          onClick={getModifiedPromptFn}
+          minW={24}
+          isDisabled={originalModel === selectedModel || modificationInProgress}
+        >
+          {modificationInProgress ? <Spinner boxSize={4} /> : <Text>Convert</Text>}
+        </Button>
         <Button
           colorScheme="blue"
-          onClick={createNewVariant}
+          onClick={replaceVariant}
           minW={24}
-          disabled={originalModel === selectedModel}
+          isDisabled={!convertedModel || modificationInProgress || replacementInProgress}
         >
-          {creationInProgress ? <Spinner boxSize={4} /> : <Text>Continue</Text>}
+          {replacementInProgress ? <Spinner boxSize={4} /> : <Text>Accept</Text>}
         </Button>
+      </HStack>
     </ModalFooter>
   </ModalContent>
 </Modal>
```
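`SelectModelModal` now takes the whole `PromptVariant` instead of `originalModel`/`variantId` and replaces the one-click create flow with a two-step one: Convert calls the shared `getModifiedPromptFn` mutation with `newModel` and shows the result in `CompareFunctions` (titled with the source and target model), then Accept writes the converted `constructFn` back through `replaceVariant`, mirroring `RefinePromptModal`.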
**VariantHeaderMenuButton**

```diff
@@ -18,7 +18,6 @@ import { useState } from "react";
 import { RefinePromptModal } from "../RefinePromptModal/RefinePromptModal";
 import { RiExchangeFundsFill } from "react-icons/ri";
 import { SelectModelModal } from "../SelectModelModal/SelectModelModal";
-import { type SupportedModel } from "~/server/types";

 export default function VariantHeaderMenuButton({
   variant,
@@ -99,11 +98,7 @@ export default function VariantHeaderMenuButton({
       </MenuList>
     </Menu>
     {selectModelModalOpen && (
-      <SelectModelModal
-        originalModel={variant.model as SupportedModel}
-        variantId={variant.id}
-        onClose={() => setSelectModelModalOpen(false)}
-      />
+      <SelectModelModal variant={variant} onClose={() => setSelectModelModalOpen(false)} />
     )}
     {refinePromptModalOpen && (
       <RefinePromptModal variant={variant} onClose={() => setRefinePromptModalOpen(false)} />
```
**env.mjs**

```diff
@@ -17,7 +17,6 @@ export const env = createEnv({
       .transform((val) => val.toLowerCase() === "true"),
     GITHUB_CLIENT_ID: z.string().min(1),
     GITHUB_CLIENT_SECRET: z.string().min(1),
-    REPLICATE_API_TOKEN: z.string().min(1),
   },

   /**
@@ -43,7 +42,6 @@ export const env = createEnv({
     NEXT_PUBLIC_SOCKET_URL: process.env.NEXT_PUBLIC_SOCKET_URL,
     GITHUB_CLIENT_ID: process.env.GITHUB_CLIENT_ID,
     GITHUB_CLIENT_SECRET: process.env.GITHUB_CLIENT_SECRET,
-    REPLICATE_API_TOKEN: process.env.REPLICATE_API_TOKEN,
   },
   /**
    * Run `build` or `dev` with `SKIP_ENV_VALIDATION` to skip env validation.
```
**modelProviders index**

```diff
@@ -1,9 +1,7 @@
 import openaiChatCompletion from "./openai-ChatCompletion";
-import replicateLlama2 from "./replicate-llama2";

 const modelProviders = {
   "openai/ChatCompletion": openaiChatCompletion,
-  "replicate/llama2": replicateLlama2,
 } as const;

 export default modelProviders;
```
**modelProviders frontend index**

```diff
@@ -1,14 +1,10 @@
-import openaiChatCompletionFrontend from "./openai-ChatCompletion/frontend";
-import replicateLlama2Frontend from "./replicate-llama2/frontend";
+import modelProviderFrontend from "./openai-ChatCompletion/frontend";

-// TODO: make sure we get a typescript error if you forget to add a provider here
-
 // Keep attributes here that need to be accessible from the frontend. We can't
 // just include them in the default `modelProviders` object because it has some
 // transient dependencies that can only be imported on the server.
 const modelProvidersFrontend = {
-  "openai/ChatCompletion": openaiChatCompletionFrontend,
-  "replicate/llama2": replicateLlama2Frontend,
+  "openai/ChatCompletion": modelProviderFrontend,
 } as const;

 export default modelProvidersFrontend;
```
**replicate-llama2/frontend** (file deleted)

```diff
@@ -1,13 +0,0 @@
-import { type ReplicateLlama2Provider } from ".";
-import { type ModelProviderFrontend } from "../types";
-
-const modelProviderFrontend: ModelProviderFrontend<ReplicateLlama2Provider> = {
-  normalizeOutput: (output) => {
-    return {
-      type: "text",
-      value: output.join(""),
-    };
-  },
-};
-
-export default modelProviderFrontend;
```
**replicate-llama2/getCompletion** (file deleted)

```diff
@@ -1,62 +0,0 @@
-import { env } from "~/env.mjs";
-import { type ReplicateLlama2Input, type ReplicateLlama2Output } from ".";
-import { type CompletionResponse } from "../types";
-import Replicate from "replicate";
-
-const replicate = new Replicate({
-  auth: env.REPLICATE_API_TOKEN || "",
-});
-
-const modelIds: Record<ReplicateLlama2Input["model"], string> = {
-  "7b-chat": "3725a659b5afff1a0ba9bead5fac3899d998feaad00e07032ca2b0e35eb14f8a",
-  "13b-chat": "5c785d117c5bcdd1928d5a9acb1ffa6272d6cf13fcb722e90886a0196633f9d3",
-  "70b-chat": "e951f18578850b652510200860fc4ea62b3b16fac280f83ff32282f87bbd2e48",
-};
-
-export async function getCompletion(
-  input: ReplicateLlama2Input,
-  onStream: ((partialOutput: string[]) => void) | null,
-): Promise<CompletionResponse<ReplicateLlama2Output>> {
-  const start = Date.now();
-
-  const { model, stream, ...rest } = input;
-
-  try {
-    const prediction = await replicate.predictions.create({
-      version: modelIds[model],
-      input: rest,
-    });
-
-    console.log("stream?", onStream);
-
-    const interval = onStream
-      ? // eslint-disable-next-line @typescript-eslint/no-misused-promises
-        setInterval(async () => {
-          const partialPrediction = await replicate.predictions.get(prediction.id);
-
-          if (partialPrediction.output) onStream(partialPrediction.output as ReplicateLlama2Output);
-        }, 500)
-      : null;
-
-    const resp = await replicate.wait(prediction, {});
-    if (interval) clearInterval(interval);
-
-    const timeToComplete = Date.now() - start;
-
-    if (resp.error) throw new Error(resp.error as string);
-
-    return {
-      type: "success",
-      statusCode: 200,
-      value: resp.output as ReplicateLlama2Output,
-      timeToComplete,
-    };
-  } catch (error: unknown) {
-    console.error("ERROR IS", error);
-    return {
-      type: "error",
-      message: (error as Error).message,
-      autoRetry: true,
-    };
-  }
-}
```
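Before its removal, this provider faked streaming by polling: after creating a prediction it fetched partial output on a 500 ms `setInterval`, forwarded it to `onStream`, and cleared the interval once `replicate.wait` resolved.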
**replicate-llama2 provider index** (file deleted)

```diff
@@ -1,74 +0,0 @@
-import { type ModelProvider } from "../types";
-import { getCompletion } from "./getCompletion";
-
-const supportedModels = ["7b-chat", "13b-chat", "70b-chat"] as const;
-
-type SupportedModel = (typeof supportedModels)[number];
-
-export type ReplicateLlama2Input = {
-  model: SupportedModel;
-  prompt: string;
-  stream?: boolean;
-  max_length?: number;
-  temperature?: number;
-  top_p?: number;
-  repetition_penalty?: number;
-  debug?: boolean;
-};
-
-export type ReplicateLlama2Output = string[];
-
-export type ReplicateLlama2Provider = ModelProvider<
-  SupportedModel,
-  ReplicateLlama2Input,
-  ReplicateLlama2Output
->;
-
-const modelProvider: ReplicateLlama2Provider = {
-  name: "OpenAI ChatCompletion",
-  models: {
-    "7b-chat": {},
-    "13b-chat": {},
-    "70b-chat": {},
-  },
-  getModel: (input) => {
-    if (supportedModels.includes(input.model)) return input.model;
-
-    return null;
-  },
-  inputSchema: {
-    type: "object",
-    properties: {
-      model: {
-        type: "string",
-        enum: supportedModels as unknown as string[],
-      },
-      prompt: {
-        type: "string",
-      },
-      stream: {
-        type: "boolean",
-      },
-      max_length: {
-        type: "number",
-      },
-      temperature: {
-        type: "number",
-      },
-      top_p: {
-        type: "number",
-      },
-      repetition_penalty: {
-        type: "number",
-      },
-      debug: {
-        type: "boolean",
-      },
-    },
-    required: ["model", "prompt"],
-  },
-  shouldStream: (input) => input.stream ?? false,
-  getCompletion,
-};
-
-export default modelProvider;
```
**modelProviders types**

```diff
@@ -2,8 +2,8 @@ import { type JSONSchema4 } from "json-schema";
 import { type JsonValue } from "type-fest";

 type ModelProviderModel = {
-  name?: string;
-  learnMore?: string;
+  name: string;
+  learnMore: string;
 };

 export type CompletionResponse<T> =
```
**ExperimentsPage**

```diff
@@ -20,11 +20,13 @@ export default function ExperimentsPage() {
   const experiments = api.experiments.list.useQuery();

   const user = useSession().data;
+  const authLoading = useSession().status === "loading";

-  if (user === null) {
+  if (user === null || authLoading) {
     return (
       <AppShell title="Experiments">
         <Center h="100%">
+          {!authLoading && (
           <Text>
             <Link
               onClick={() => {
@@ -36,6 +38,7 @@ export default function ExperimentsPage() {
             </Link>{" "}
             to view or create new experiments!
           </Text>
+          )}
         </Center>
       </AppShell>
     );
```
**experimentsRouter**

```diff
@@ -98,7 +98,7 @@ export const experimentsRouter = createTRPCRouter({
       },
     });

-    const [variant, _, scenario] = await prisma.$transaction([
+    const [variant, _, scenario1, scenario2, scenario3] = await prisma.$transaction([
       prisma.promptVariant.create({
         data: {
           experimentId: exp.id,
@@ -109,7 +109,8 @@ export const experimentsRouter = createTRPCRouter({
           constructFn: dedent`
             /**
              * Use Javascript to define an OpenAI chat completion
-             * (https://platform.openai.com/docs/api-reference/chat/create).
+             * (https://platform.openai.com/docs/api-reference/chat/create) and
+             * assign it to the \`prompt\` variable.
              *
              * You have access to the current scenario in the \`scenario\`
              * variable.
@@ -121,7 +122,7 @@ export const experimentsRouter = createTRPCRouter({
             messages: [
               {
                 role: "system",
-                content: \`"Return 'this is output for the scenario "${"$"}{scenario.text}"'\`,
+                content: \`Write 'Start experimenting!' in ${"$"}{scenario.language}\`,
               },
             ],
           });`,
@@ -133,20 +134,38 @@ export const experimentsRouter = createTRPCRouter({
       prisma.templateVariable.create({
         data: {
           experimentId: exp.id,
-          label: "text",
+          label: "language",
         },
       }),
       prisma.testScenario.create({
         data: {
           experimentId: exp.id,
           variableValues: {
-            text: "This is a test scenario.",
+            language: "English",
+          },
+        },
+      }),
+      prisma.testScenario.create({
+        data: {
+          experimentId: exp.id,
+          variableValues: {
+            language: "Spanish",
+          },
+        },
+      }),
+      prisma.testScenario.create({
+        data: {
+          experimentId: exp.id,
+          variableValues: {
+            language: "German",
           },
         },
       }),
     ]);

-    await generateNewCell(variant.id, scenario.id);
+    await generateNewCell(variant.id, scenario1.id);
+    await generateNewCell(variant.id, scenario2.id);
+    await generateNewCell(variant.id, scenario3.id);

     return exp;
   }),
```
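New experiments are now seeded with a `language` template variable and three scenarios (English, Spanish, German) in a single `$transaction`, with a cell generated per scenario; the starter prompt's system message becomes a translated "Start experimenting!" greeting.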
**promptVariantsRouter**

```diff
@@ -284,11 +284,12 @@ export const promptVariantsRouter = createTRPCRouter({
       return updatedPromptVariant;
     }),

-  getRefinedPromptFn: protectedProcedure
+  getModifiedPromptFn: protectedProcedure
     .input(
       z.object({
         id: z.string(),
-        instructions: z.string(),
+        instructions: z.string().optional(),
+        newModel: z.string().optional(),
       }),
     )
     .mutation(async ({ input, ctx }) => {
@@ -307,7 +308,7 @@ export const promptVariantsRouter = createTRPCRouter({

       const promptConstructionFn = await deriveNewConstructFn(
         existing,
-        constructedPrompt.model as SupportedModel,
+        input.newModel as SupportedModel | undefined,
         input.instructions,
       );

```
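The renamed `getModifiedPromptFn` procedure now serves both modals: `RefinePromptModal` passes `instructions`, `SelectModelModal` passes `newModel`, and both fields are optional in the input schema. A sketch of the two client-side call shapes (the `declare` stubs stand in for the tRPC hook and props shown in the modal diffs above):

```ts
// Stubs so the sketch stands alone; in the app, mutateAsync comes from
// api.promptVariants.getModifiedPromptFn.useMutation() and variant is a prop.
declare const getModifiedPromptMutateAsync: (input: {
  id: string;
  instructions?: string;
  newModel?: string;
}) => Promise<unknown>;
declare const variant: { id: string };

// RefinePromptModal path: refine the prompt with instructions.
await getModifiedPromptMutateAsync({ id: variant.id, instructions: "Add chain of thought" });

// SelectModelModal path: convert the prompt to a different model (example value).
await getModifiedPromptMutateAsync({ id: variant.id, newModel: "gpt-3.5-turbo" });
```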
**Scratch Replicate script** (contents commented out)

```diff
@@ -1,26 +1,26 @@
-/* eslint-disable */
+// /* eslint-disable */

-import "dotenv/config";
-import Replicate from "replicate";
+// import "dotenv/config";
+// import Replicate from "replicate";

-const replicate = new Replicate({
-  auth: process.env.REPLICATE_API_TOKEN || "",
-});
+// const replicate = new Replicate({
+//   auth: process.env.REPLICATE_API_TOKEN || "",
+// });

-console.log("going to run");
-const prediction = await replicate.predictions.create({
-  version: "3725a659b5afff1a0ba9bead5fac3899d998feaad00e07032ca2b0e35eb14f8a",
-  input: {
-    prompt: "...",
-  },
-});
+// console.log("going to run");
+// const prediction = await replicate.predictions.create({
+//   version: "e951f18578850b652510200860fc4ea62b3b16fac280f83ff32282f87bbd2e48",
+//   input: {
+//     prompt: "...",
+//   },
+// });

-console.log("waiting");
-setInterval(() => {
-  replicate.predictions.get(prediction.id).then((prediction) => {
-    console.log(prediction);
-  });
-}, 500);
-// const output = await replicate.wait(prediction, {});
+// console.log("waiting");
+// setInterval(() => {
+//   replicate.predictions.get(prediction.id).then((prediction) => {
+//     console.log(prediction.output);
+//   });
+// }, 500);
+// // const output = await replicate.wait(prediction, {});

-// console.log(output);
+// // console.log(output);
```
**queryLLM task**

```diff
@@ -99,7 +99,6 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {

   const provider = modelProviders[prompt.modelProvider];

-  // @ts-expect-error TODO FIX ASAP
   const streamingChannel = provider.shouldStream(prompt.modelInput) ? generateChannel() : null;

   if (streamingChannel) {
@@ -116,8 +115,6 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
       : null;

   for (let i = 0; true; i++) {
-    // @ts-expect-error TODO FIX ASAP
-
     const response = await provider.getCompletion(prompt.modelInput, onStream);
     if (response.type === "success") {
       const inputHash = hashPrompt(prompt);
```
**parseConstructFn**

```diff
@@ -70,7 +70,6 @@ export default async function parseConstructFn(
   // We've validated the JSON schema so this should be safe
   const input = prompt.input as Parameters<(typeof provider)["getModel"]>[0];

-  // @ts-expect-error TODO FIX ASAP
   const model = provider.getModel(input);
   if (!model) {
     return {
@@ -80,8 +79,6 @@ export default async function parseConstructFn(

   return {
     modelProvider: prompt.modelProvider as keyof typeof modelProviders,
-    // @ts-expect-error TODO FIX ASAP
-
     model,
     modelInput: input,
   };
```