Define refinement actions in the model providers (#87)

* Add descriptions of fields in llama 2 input schema

* Let GPT-4 know when the provider stays the same

* Allow refetching in the event of any errors

* Define refinement actions in model providers

* Fix prettier
This commit is contained in:
arcticfly
2023-07-23 17:37:08 -07:00
committed by GitHub
parent 3dbb06ec00
commit 2b2e0ab8ee
11 changed files with 346 additions and 307 deletions

View File

@@ -21,10 +21,10 @@ import { type PromptVariant } from "@prisma/client";
import { useState } from "react";
import CompareFunctions from "./CompareFunctions";
import { CustomInstructionsInput } from "./CustomInstructionsInput";
import { type RefineOptionInfo, refineOptions } from "./refineOptions";
import { RefineOption } from "./RefineOption";
import { RefineAction } from "./RefineAction";
import { isObject, isString } from "lodash-es";
import { type SupportedProvider } from "~/modelProviders/types";
import { type RefinementAction, type SupportedProvider } from "~/modelProviders/types";
import frontendModelProviders from "~/modelProviders/frontendModelProviders";
export const RefinePromptModal = ({
variant,
@@ -35,13 +35,14 @@ export const RefinePromptModal = ({
}) => {
const utils = api.useContext();
const providerRefineOptions = refineOptions[variant.modelProvider as SupportedProvider];
const refinementActions =
frontendModelProviders[variant.modelProvider as SupportedProvider].refinementActions || {};
const { mutateAsync: getModifiedPromptMutateAsync, data: refinedPromptFn } =
api.promptVariants.getModifiedPromptFn.useMutation();
const [instructions, setInstructions] = useState<string>("");
const [activeRefineOptionLabel, setActiveRefineOptionLabel] = useState<string | undefined>(
const [activeRefineActionLabel, setActiveRefineActionLabel] = useState<string | undefined>(
undefined,
);
@@ -49,15 +50,15 @@ export const RefinePromptModal = ({
async (label?: string) => {
if (!variant.experimentId) return;
const updatedInstructions = label
? (providerRefineOptions[label] as RefineOptionInfo).instructions
? (refinementActions[label] as RefinementAction).instructions
: instructions;
setActiveRefineOptionLabel(label);
setActiveRefineActionLabel(label);
await getModifiedPromptMutateAsync({
id: variant.id,
instructions: updatedInstructions,
});
},
[getModifiedPromptMutateAsync, onClose, variant, instructions, setActiveRefineOptionLabel],
[getModifiedPromptMutateAsync, onClose, variant, instructions, setActiveRefineActionLabel],
);
const replaceVariantMutation = api.promptVariants.replaceVariant.useMutation();
@@ -95,18 +96,18 @@ export const RefinePromptModal = ({
<ModalBody maxW="unset">
<VStack spacing={8}>
<VStack spacing={4}>
{Object.keys(providerRefineOptions).length && (
{Object.keys(refinementActions).length && (
<>
<SimpleGrid columns={{ base: 1, md: 2 }} spacing={8}>
{Object.keys(providerRefineOptions).map((label) => (
<RefineOption
{Object.keys(refinementActions).map((label) => (
<RefineAction
key={label}
label={label}
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
icon={providerRefineOptions[label]!.icon}
icon={refinementActions[label]!.icon}
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
desciption={providerRefineOptions[label]!.description}
activeLabel={activeRefineOptionLabel}
desciption={refinementActions[label]!.description}
activeLabel={activeRefineActionLabel}
onClick={getModifiedPromptFn}
loading={modificationInProgress}
/>