More work on modelProviders

I think everything that's OpenAI-specific is inside modelProviders at this point, so we can get started adding more providers.
This commit is contained in:
Kyle Corbitt
2023-07-20 18:54:26 -07:00
parent ded6678e97
commit 332a2101c0
21 changed files with 344 additions and 330 deletions

View File

@@ -6,11 +6,11 @@ import SyntaxHighlighter from "react-syntax-highlighter";
import { docco } from "react-syntax-highlighter/dist/cjs/styles/hljs";
import stringify from "json-stringify-pretty-compact";
import { type ReactElement, useState, useEffect } from "react";
import { type ChatCompletion } from "openai/resources/chat";
import useSocket from "~/utils/useSocket";
import { OutputStats } from "./OutputStats";
import { ErrorHandler } from "./ErrorHandler";
import { CellOptions } from "./CellOptions";
import modelProvidersFrontend from "~/modelProviders/modelProvidersFrontend";
export default function OutputCell({
scenario,
@@ -33,15 +33,17 @@ export default function OutputCell({
if (!templateHasVariables) disabledReason = "Add a value to the scenario variables to see output";
// if (variant.config === null || Object.keys(variant.config).length === 0)
// disabledReason = "Save your prompt variant to see output";
const [refetchInterval, setRefetchInterval] = useState(0);
const { data: cell, isLoading: queryLoading } = api.scenarioVariantCells.get.useQuery(
{ scenarioId: scenario.id, variantId: variant.id },
{ refetchInterval },
);
const provider =
modelProvidersFrontend[variant.modelProvider as keyof typeof modelProvidersFrontend];
type OutputSchema = Parameters<typeof provider.normalizeOutput>[0];
const { mutateAsync: hardRefetchMutate } = api.scenarioVariantCells.forceRefetch.useMutation();
const [hardRefetch, hardRefetching] = useHandledAsyncCallback(async () => {
await hardRefetchMutate({ scenarioId: scenario.id, variantId: variant.id });
@@ -66,8 +68,7 @@ export default function OutputCell({
const modelOutput = cell?.modelOutput;
// Disconnect from socket if we're not streaming anymore
const streamedMessage = useSocket(cell?.streamingChannel);
const streamedContent = streamedMessage?.choices?.[0]?.message?.content;
const streamedMessage = useSocket<OutputSchema>(cell?.streamingChannel);
if (!vars) return null;
@@ -86,19 +87,13 @@ export default function OutputCell({
return <ErrorHandler cell={cell} refetchOutput={hardRefetch} />;
}
const response = modelOutput?.output as unknown as ChatCompletion;
const message = response?.choices?.[0]?.message;
if (modelOutput && message?.function_call) {
const rawArgs = message.function_call.arguments ?? "null";
let parsedArgs: string;
try {
parsedArgs = JSON.parse(rawArgs);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} catch (e: any) {
parsedArgs = `Failed to parse arguments as JSON: '${rawArgs}' ERROR: ${e.message as string}`;
}
const normalizedOutput = modelOutput
? provider.normalizeOutput(modelOutput.output as unknown as OutputSchema)
: streamedMessage
? provider.normalizeOutput(streamedMessage)
: null;
if (modelOutput && normalizedOutput?.type === "json") {
return (
<VStack
w="100%"
@@ -119,13 +114,7 @@ export default function OutputCell({
}}
wrapLines
>
{stringify(
{
function: message.function_call.name,
args: parsedArgs,
},
{ maxLength: 40 },
)}
{stringify(normalizedOutput.value, { maxLength: 40 })}
</SyntaxHighlighter>
</VStack>
<OutputStats modelOutput={modelOutput} scenario={scenario} />
@@ -133,8 +122,7 @@ export default function OutputCell({
);
}
const contentToDisplay =
message?.content ?? streamedContent ?? JSON.stringify(modelOutput?.output);
const contentToDisplay = (normalizedOutput?.type === "text" && normalizedOutput.value) || "";
return (
<VStack w="100%" h="100%" justifyContent="space-between" whiteSpace="pre-wrap">

View File

@@ -50,8 +50,6 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
// Make sure the user defined the prompt with the string "prompt\w*=" somewhere
const promptRegex = /definePrompt\(/;
if (!promptRegex.test(currentFn)) {
console.log("no prompt");
console.log(currentFn);
toast({
title: "Missing prompt",
description: "Please define the prompt (eg. `definePrompt(...`",