feat(frontend): Improve models input UI/UX in settings (#3530)

* Create helper functions

* Add map according to litellm docs

* Create ModelSelector

* Extend model selector

* use autocomplete from nextui

* Improve keys without providers

* Handle models without a provider

* Add verified section and some empty handling

* Add support for default or previously set models

* Update tests

* Lint

* Remove modifier

* Fix typescript error

* Functionality for switching to custom model

* Add verified models

* Respond to resetting to default

* Comment
This commit is contained in:
sp.wack
2024-08-23 20:06:15 +03:00
committed by GitHub
parent b63dec4b2e
commit 07e750f038
19 changed files with 901 additions and 55 deletions

View File

@@ -0,0 +1,193 @@
import React from "react";
import { describe, it, expect, vi } from "vitest";
import { render, screen } from "@testing-library/react";
import userEvent from "@testing-library/user-event";
import { ModelSelector } from "./ModelSelector";
describe("ModelSelector", () => {
const models = {
openai: {
separator: "/",
models: ["gpt-4o", "gpt-3.5-turbo"],
},
azure: {
separator: "/",
models: ["ada", "gpt-35-turbo"],
},
vertex_ai: {
separator: "/",
models: ["chat-bison", "chat-bison-32k"],
},
cohere: {
separator: ".",
models: ["command-r-v1:0"],
},
};
it("should display the provider selector", async () => {
const user = userEvent.setup();
const onModelChange = vi.fn();
render(<ModelSelector models={models} onModelChange={onModelChange} />);
const selector = screen.getByLabelText("Provider");
expect(selector).toBeInTheDocument();
await user.click(selector);
expect(screen.getByText("OpenAI")).toBeInTheDocument();
expect(screen.getByText("Azure")).toBeInTheDocument();
expect(screen.getByText("VertexAI")).toBeInTheDocument();
expect(screen.getByText("cohere")).toBeInTheDocument();
});
it("should disable the model selector if the provider is not selected", async () => {
const user = userEvent.setup();
const onModelChange = vi.fn();
render(<ModelSelector models={models} onModelChange={onModelChange} />);
const modelSelector = screen.getByLabelText("Model");
expect(modelSelector).toBeDisabled();
const providerSelector = screen.getByLabelText("Provider");
await user.click(providerSelector);
const vertexAI = screen.getByText("VertexAI");
await user.click(vertexAI);
expect(modelSelector).not.toBeDisabled();
});
it("should display the model selector", async () => {
const user = userEvent.setup();
const onModelChange = vi.fn();
render(<ModelSelector models={models} onModelChange={onModelChange} />);
const providerSelector = screen.getByLabelText("Provider");
await user.click(providerSelector);
const azureProvider = screen.getByText("Azure");
await user.click(azureProvider);
const modelSelector = screen.getByLabelText("Model");
await user.click(modelSelector);
expect(screen.getByText("ada")).toBeInTheDocument();
expect(screen.getByText("gpt-35-turbo")).toBeInTheDocument();
await user.click(providerSelector);
const vertexProvider = screen.getByText("VertexAI");
await user.click(vertexProvider);
await user.click(modelSelector);
expect(screen.getByText("chat-bison")).toBeInTheDocument();
expect(screen.getByText("chat-bison-32k")).toBeInTheDocument();
});
it("should display the actual litellm model ID as the user is making the selections", async () => {
const user = userEvent.setup();
const onModelChange = vi.fn();
render(<ModelSelector models={models} onModelChange={onModelChange} />);
const id = screen.getByTestId("model-id");
const providerSelector = screen.getByLabelText("Provider");
const modelSelector = screen.getByLabelText("Model");
expect(id).toHaveTextContent("No model selected");
await user.click(providerSelector);
await user.click(screen.getByText("Azure"));
expect(id).toHaveTextContent("azure/");
await user.click(modelSelector);
await user.click(screen.getByText("ada"));
expect(id).toHaveTextContent("azure/ada");
await user.click(providerSelector);
await user.click(screen.getByText("cohere"));
expect(id).toHaveTextContent("cohere.");
await user.click(modelSelector);
await user.click(screen.getByText("command-r-v1:0"));
expect(id).toHaveTextContent("cohere.command-r-v1:0");
});
it("should call onModelChange when the model is changed", async () => {
const user = userEvent.setup();
const onModelChange = vi.fn();
render(<ModelSelector models={models} onModelChange={onModelChange} />);
const providerSelector = screen.getByLabelText("Provider");
const modelSelector = screen.getByLabelText("Model");
await user.click(providerSelector);
await user.click(screen.getByText("Azure"));
await user.click(modelSelector);
await user.click(screen.getByText("ada"));
expect(onModelChange).toHaveBeenCalledTimes(1);
expect(onModelChange).toHaveBeenCalledWith("azure/ada");
await user.click(modelSelector);
await user.click(screen.getByText("gpt-35-turbo"));
expect(onModelChange).toHaveBeenCalledTimes(2);
expect(onModelChange).toHaveBeenCalledWith("azure/gpt-35-turbo");
await user.click(providerSelector);
await user.click(screen.getByText("cohere"));
await user.click(modelSelector);
await user.click(screen.getByText("command-r-v1:0"));
expect(onModelChange).toHaveBeenCalledTimes(3);
expect(onModelChange).toHaveBeenCalledWith("cohere.command-r-v1:0");
});
it("should clear the model ID when the provider is cleared", async () => {
const user = userEvent.setup();
const onModelChange = vi.fn();
render(<ModelSelector models={models} onModelChange={onModelChange} />);
const providerSelector = screen.getByLabelText("Provider");
const modelSelector = screen.getByLabelText("Model");
await user.click(providerSelector);
await user.click(screen.getByText("Azure"));
await user.click(modelSelector);
await user.click(screen.getByText("ada"));
expect(screen.getByTestId("model-id")).toHaveTextContent("azure/ada");
await user.clear(providerSelector);
expect(screen.getByTestId("model-id")).toHaveTextContent(
"No model selected",
);
});
it("should have a default value if passed", async () => {
const onModelChange = vi.fn();
render(
<ModelSelector
models={models}
onModelChange={onModelChange}
defaultModel="azure/ada"
/>,
);
expect(screen.getByTestId("model-id")).toHaveTextContent("azure/ada");
expect(screen.getByLabelText("Provider")).toHaveValue("Azure");
expect(screen.getByLabelText("Model")).toHaveValue("ada");
});
it.todo("should disable provider if isDisabled is true");
it.todo(
"should display the verified models in the correct order",
async () => {},
);
});

View File

@@ -0,0 +1,133 @@
import {
Autocomplete,
AutocompleteItem,
AutocompleteSection,
} from "@nextui-org/react";
import React from "react";
import { mapProvider } from "#/utils/mapProvider";
import { VERIFIED_MODELS, VERIFIED_PROVIDERS } from "#/utils/verified-models";
import { extractModelAndProvider } from "#/utils/extractModelAndProvider";
interface ModelSelectorProps {
isDisabled?: boolean;
models: Record<string, { separator: string; models: string[] }>;
onModelChange: (model: string) => void;
defaultModel?: string;
}
/**
 * Two-step selector that assembles a litellm model ID from a provider and a
 * model name (e.g. "azure" + "/" + "ada" → "azure/ada").
 *
 * @param isDisabled disables both autocompletes
 * @param models map of provider → { separator, models }
 *               (see organizeModelsAndProviders)
 * @param onModelChange called with the full litellm ID once a model is picked
 * @param defaultModel previously saved litellm ID used to pre-select both fields
 */
export function ModelSelector({
  isDisabled,
  models,
  onModelChange,
  defaultModel,
}: ModelSelectorProps) {
  // Full litellm ID being assembled: provider + separator [+ model].
  const [litellmId, setLitellmId] = React.useState<string | null>(null);
  const [selectedProvider, setSelectedProvider] = React.useState<string | null>(
    null,
  );
  const [selectedModel, setSelectedModel] = React.useState<string | null>(null);

  React.useEffect(() => {
    if (defaultModel) {
      // Runs on mount and when resetting to defaults: split the saved ID back
      // into provider/model so both autocompletes are pre-filled.
      const { provider, model } = extractModelAndProvider(defaultModel);
      setLitellmId(defaultModel);
      setSelectedProvider(provider);
      setSelectedModel(model);
    }
  }, [defaultModel]);

  const handleChangeProvider = (provider: string) => {
    setSelectedProvider(provider);
    // A provider switch invalidates the previous model choice.
    setSelectedModel(null);

    const separator = models[provider]?.separator || "";
    setLitellmId(provider + separator);
  };

  const handleChangeModel = (model: string) => {
    const separator = models[selectedProvider || ""]?.separator || "";
    const fullModel = selectedProvider + separator + model;
    setLitellmId(fullModel);
    onModelChange(fullModel);
    setSelectedModel(model);
  };

  const clear = () => {
    setSelectedProvider(null);
    // Also drop the model so the (now disabled) model field does not keep a
    // stale selection from the previous provider.
    setSelectedModel(null);
    setLitellmId(null);
    // NOTE(review): the parent is not notified here, so it may still hold the
    // last value passed to onModelChange — confirm whether a "cleared"
    // callback is desired.
  };

  return (
    <div data-testid="model-selector" className="flex flex-col gap-2">
      <span className="text-center italic text-gray-500" data-testid="model-id">
        {/* Strip only a leading "other" pseudo-provider prefix; an unanchored
            replace would mangle model names containing "other" anywhere. */}
        {litellmId?.replace(/^other/, "") || "No model selected"}
      </span>

      <div className="flex flex-col gap-3">
        <Autocomplete
          isDisabled={isDisabled}
          label="Provider"
          placeholder="Select a provider"
          isClearable={false}
          onSelectionChange={(e) => {
            if (e?.toString()) handleChangeProvider(e.toString());
          }}
          // An emptied input (user deleted the text) resets the selection.
          onInputChange={(value) => !value && clear()}
          defaultSelectedKey={selectedProvider ?? undefined}
          selectedKey={selectedProvider}
        >
          <AutocompleteSection title="Verified">
            {Object.keys(models)
              .filter((provider) => VERIFIED_PROVIDERS.includes(provider))
              .map((provider) => (
                <AutocompleteItem key={provider} value={provider}>
                  {mapProvider(provider)}
                </AutocompleteItem>
              ))}
          </AutocompleteSection>
          <AutocompleteSection title="Others">
            {Object.keys(models)
              .filter((provider) => !VERIFIED_PROVIDERS.includes(provider))
              .map((provider) => (
                <AutocompleteItem key={provider} value={provider}>
                  {mapProvider(provider)}
                </AutocompleteItem>
              ))}
          </AutocompleteSection>
        </Autocomplete>

        <Autocomplete
          label="Model"
          placeholder="Select a model"
          onSelectionChange={(e) => {
            if (e?.toString()) handleChangeModel(e.toString());
          }}
          // Unusable until a provider has been chosen.
          isDisabled={isDisabled || !selectedProvider}
          selectedKey={selectedModel}
          defaultSelectedKey={selectedModel ?? undefined}
        >
          <AutocompleteSection title="Verified">
            {models[selectedProvider || ""]?.models
              .filter((model) => VERIFIED_MODELS.includes(model))
              .map((model) => (
                <AutocompleteItem key={model} value={model}>
                  {model}
                </AutocompleteItem>
              ))}
          </AutocompleteSection>
          <AutocompleteSection title="Others">
            {models[selectedProvider || ""]?.models
              .filter((model) => !VERIFIED_MODELS.includes(model))
              .map((model) => (
                <AutocompleteItem key={model} value={model}>
                  {model}
                </AutocompleteItem>
              ))}
          </AutocompleteSection>
        </Autocomplete>
      </div>
    </div>
  );
}

View File

@@ -6,6 +6,8 @@ import { Settings } from "#/services/settings";
import SettingsForm from "./SettingsForm";
const onModelChangeMock = vi.fn();
const onCustomModelChangeMock = vi.fn();
const onModelTypeChangeMock = vi.fn();
const onAgentChangeMock = vi.fn();
const onLanguageChangeMock = vi.fn();
const onAPIKeyChangeMock = vi.fn();
@@ -18,7 +20,9 @@ const renderSettingsForm = (settings?: Settings) => {
disabled={false}
settings={
settings || {
LLM_MODEL: "model1",
LLM_MODEL: "gpt-4o",
CUSTOM_LLM_MODEL: "",
USING_CUSTOM_MODEL: false,
AGENT: "agent1",
LANGUAGE: "en",
LLM_API_KEY: "sk-...",
@@ -26,10 +30,12 @@ const renderSettingsForm = (settings?: Settings) => {
SECURITY_ANALYZER: "analyzer1",
}
}
models={["model1", "model2", "model3"]}
models={["gpt-4o", "gpt-3.5-turbo", "azure/ada"]}
agents={["agent1", "agent2", "agent3"]}
securityAnalyzers={["analyzer1", "analyzer2", "analyzer3"]}
onModelChange={onModelChangeMock}
onCustomModelChange={onCustomModelChangeMock}
onModelTypeChange={onModelTypeChangeMock}
onAgentChange={onAgentChangeMock}
onLanguageChange={onLanguageChangeMock}
onAPIKeyChange={onAPIKeyChangeMock}
@@ -43,7 +49,8 @@ describe("SettingsForm", () => {
it("should display the first values in the array by default", () => {
renderSettingsForm();
const modelInput = screen.getByRole("combobox", { name: "model" });
const providerInput = screen.getByRole("combobox", { name: "Provider" });
const modelInput = screen.getByRole("combobox", { name: "Model" });
const agentInput = screen.getByRole("combobox", { name: "agent" });
const languageInput = screen.getByRole("combobox", { name: "language" });
const apiKeyInput = screen.getByTestId("apikey");
@@ -52,7 +59,8 @@ describe("SettingsForm", () => {
name: "securityanalyzer",
});
expect(modelInput).toHaveValue("model1");
expect(providerInput).toHaveValue("OpenAI");
expect(modelInput).toHaveValue("gpt-4o");
expect(agentInput).toHaveValue("agent1");
expect(languageInput).toHaveValue("English");
expect(apiKeyInput).toHaveValue("sk-...");
@@ -62,7 +70,9 @@ describe("SettingsForm", () => {
it("should display the existing values if they are present", () => {
renderSettingsForm({
LLM_MODEL: "model2",
LLM_MODEL: "gpt-3.5-turbo",
CUSTOM_LLM_MODEL: "",
USING_CUSTOM_MODEL: false,
AGENT: "agent2",
LANGUAGE: "es",
LLM_API_KEY: "sk-...",
@@ -70,14 +80,16 @@ describe("SettingsForm", () => {
SECURITY_ANALYZER: "analyzer2",
});
const modelInput = screen.getByRole("combobox", { name: "model" });
const providerInput = screen.getByRole("combobox", { name: "Provider" });
const modelInput = screen.getByRole("combobox", { name: "Model" });
const agentInput = screen.getByRole("combobox", { name: "agent" });
const languageInput = screen.getByRole("combobox", { name: "language" });
const securityAnalyzerInput = screen.getByRole("combobox", {
name: "securityanalyzer",
});
expect(modelInput).toHaveValue("model2");
expect(providerInput).toHaveValue("OpenAI");
expect(modelInput).toHaveValue("gpt-3.5-turbo");
expect(agentInput).toHaveValue("agent2");
expect(languageInput).toHaveValue("Español");
expect(securityAnalyzerInput).toHaveValue("analyzer2");
@@ -87,18 +99,22 @@ describe("SettingsForm", () => {
renderWithProviders(
<SettingsForm
settings={{
LLM_MODEL: "model1",
LLM_MODEL: "gpt-4o",
CUSTOM_LLM_MODEL: "",
USING_CUSTOM_MODEL: false,
AGENT: "agent1",
LANGUAGE: "en",
LLM_API_KEY: "sk-...",
CONFIRMATION_MODE: true,
SECURITY_ANALYZER: "analyzer1",
}}
models={["model1", "model2", "model3"]}
models={["gpt-4o", "gpt-3.5-turbo", "azure/ada"]}
agents={["agent1", "agent2", "agent3"]}
securityAnalyzers={["analyzer1", "analyzer2", "analyzer3"]}
disabled
onModelChange={onModelChangeMock}
onCustomModelChange={onCustomModelChangeMock}
onModelTypeChange={onModelTypeChangeMock}
onAgentChange={onAgentChangeMock}
onLanguageChange={onLanguageChangeMock}
onAPIKeyChange={onAPIKeyChangeMock}
@@ -106,7 +122,9 @@ describe("SettingsForm", () => {
onSecurityAnalyzerChange={onSecurityAnalyzerChangeMock}
/>,
);
const modelInput = screen.getByRole("combobox", { name: "model" });
const providerInput = screen.getByRole("combobox", { name: "Provider" });
const modelInput = screen.getByRole("combobox", { name: "Model" });
const agentInput = screen.getByRole("combobox", { name: "agent" });
const languageInput = screen.getByRole("combobox", { name: "language" });
const confirmationModeInput = screen.getByTestId("confirmationmode");
@@ -114,6 +132,7 @@ describe("SettingsForm", () => {
name: "securityanalyzer",
});
expect(providerInput).toBeDisabled();
expect(modelInput).toBeDisabled();
expect(agentInput).toBeDisabled();
expect(languageInput).toBeDisabled();
@@ -122,22 +141,6 @@ describe("SettingsForm", () => {
});
describe("onChange handlers", () => {
it("should call the onModelChange handler when the model changes", async () => {
renderSettingsForm();
const modelInput = screen.getByRole("combobox", { name: "model" });
await act(async () => {
await userEvent.click(modelInput);
});
const model3 = screen.getByText("model3");
await act(async () => {
await userEvent.click(model3);
});
expect(onModelChangeMock).toHaveBeenCalledWith("model3");
});
it("should call the onAgentChange handler when the agent changes", async () => {
const user = userEvent.setup();
renderSettingsForm();
@@ -182,4 +185,76 @@ describe("SettingsForm", () => {
expect(onAPIKeyChangeMock).toHaveBeenCalledWith("sk-...x");
});
});
describe("Setting a custom LLM model", () => {
it("should display the fetched models by default", () => {
renderSettingsForm();
const modelSelector = screen.getByTestId("model-selector");
expect(modelSelector).toBeInTheDocument();
const customModelInput = screen.queryByTestId("custom-model-input");
expect(customModelInput).not.toBeInTheDocument();
});
it("should switch to the custom model input when the custom model toggle is clicked", async () => {
const user = userEvent.setup();
renderSettingsForm();
const customModelToggle = screen.getByTestId("custom-model-toggle");
await user.click(customModelToggle);
const modelSelector = screen.queryByTestId("model-selector");
expect(modelSelector).not.toBeInTheDocument();
const customModelInput = screen.getByTestId("custom-model-input");
expect(customModelInput).toBeInTheDocument();
});
it("should call the onCustomModelChange handler when the custom model input changes", async () => {
const user = userEvent.setup();
renderSettingsForm();
const customModelToggle = screen.getByTestId("custom-model-toggle");
await user.click(customModelToggle);
const customModelInput = screen.getByTestId("custom-model-input");
await userEvent.type(customModelInput, "my/custom-model");
expect(onCustomModelChangeMock).toHaveBeenCalledWith("my/custom-model");
expect(onModelTypeChangeMock).toHaveBeenCalledWith("custom");
});
it("should have custom model switched if using custom model", () => {
renderWithProviders(
<SettingsForm
settings={{
LLM_MODEL: "gpt-4o",
CUSTOM_LLM_MODEL: "CUSTOM_MODEL",
USING_CUSTOM_MODEL: true,
AGENT: "agent1",
LANGUAGE: "en",
LLM_API_KEY: "sk-...",
CONFIRMATION_MODE: true,
SECURITY_ANALYZER: "analyzer1",
}}
models={["gpt-4o", "gpt-3.5-turbo", "azure/ada"]}
agents={["agent1", "agent2", "agent3"]}
securityAnalyzers={["analyzer1", "analyzer2", "analyzer3"]}
disabled
onModelChange={onModelChangeMock}
onCustomModelChange={onCustomModelChangeMock}
onModelTypeChange={onModelTypeChangeMock}
onAgentChange={onAgentChangeMock}
onLanguageChange={onLanguageChangeMock}
onAPIKeyChange={onAPIKeyChangeMock}
onConfirmationModeChange={onConfirmationModeChangeMock}
onSecurityAnalyzerChange={onSecurityAnalyzerChangeMock}
/>,
);
const customModelToggle = screen.getByTestId("custom-model-toggle");
expect(customModelToggle).toHaveAttribute("aria-checked", "true");
});
});
});

View File

@@ -6,6 +6,8 @@ import { AvailableLanguages } from "../../../i18n";
import { I18nKey } from "../../../i18n/declaration";
import { AutocompleteCombobox } from "./AutocompleteCombobox";
import { Settings } from "#/services/settings";
import { organizeModelsAndProviders } from "#/utils/organizeModelsAndProviders";
import { ModelSelector } from "./ModelSelector";
interface SettingsFormProps {
settings: Settings;
@@ -15,6 +17,8 @@ interface SettingsFormProps {
disabled: boolean;
onModelChange: (model: string) => void;
onCustomModelChange: (model: string) => void;
onModelTypeChange: (type: "custom" | "default") => void;
onAPIKeyChange: (apiKey: string) => void;
onAgentChange: (agent: string) => void;
onLanguageChange: (language: string) => void;
@@ -29,6 +33,8 @@ function SettingsForm({
securityAnalyzers,
disabled,
onModelChange,
onCustomModelChange,
onModelTypeChange,
onAPIKeyChange,
onAgentChange,
onLanguageChange,
@@ -38,20 +44,46 @@ function SettingsForm({
const { t } = useTranslation();
const { isOpen: isVisible, onOpenChange: onVisibleChange } = useDisclosure();
const [isAgentSelectEnabled, setIsAgentSelectEnabled] = React.useState(false);
const [usingCustomModel, setUsingCustomModel] = React.useState(
settings.USING_CUSTOM_MODEL,
);
const changeModelType = (type: "custom" | "default") => {
if (type === "custom") {
setUsingCustomModel(true);
onModelTypeChange("custom");
} else {
setUsingCustomModel(false);
onModelTypeChange("default");
}
};
return (
<>
<AutocompleteCombobox
ariaLabel="model"
items={models.map((model) => ({ value: model, label: model }))}
defaultKey={settings.LLM_MODEL}
onChange={(e) => {
onModelChange(e);
}}
tooltip={t(I18nKey.SETTINGS$MODEL_TOOLTIP)}
allowCustomValue // user can type in a custom LLM model that is not in the list
disabled={disabled}
/>
<Switch
data-testid="custom-model-toggle"
aria-checked={usingCustomModel}
isSelected={usingCustomModel}
onValueChange={(value) => changeModelType(value ? "custom" : "default")}
>
Use custom model
</Switch>
{usingCustomModel && (
<Input
data-testid="custom-model-input"
label="Custom Model"
onValueChange={onCustomModelChange}
defaultValue={settings.CUSTOM_LLM_MODEL}
/>
)}
{!usingCustomModel && (
<ModelSelector
isDisabled={disabled}
models={organizeModelsAndProviders(models)}
onModelChange={onModelChange}
defaultModel={settings.LLM_MODEL}
/>
)}
<Input
label="API Key"
isDisabled={disabled}

View File

@@ -24,6 +24,8 @@ vi.mock("#/services/settings", async (importOriginal) => ({
...(await importOriginal<typeof import("#/services/settings")>()),
getSettings: vi.fn().mockReturnValue({
LLM_MODEL: "gpt-4o",
CUSTOM_LLM_MODEL: "",
USING_CUSTOM_MODEL: false,
AGENT: "CodeActAgent",
LANGUAGE: "en",
LLM_API_KEY: "sk-...",
@@ -32,6 +34,8 @@ vi.mock("#/services/settings", async (importOriginal) => ({
}),
getDefaultSettings: vi.fn().mockReturnValue({
LLM_MODEL: "gpt-4o",
CUSTOM_LLM_MODEL: "",
USING_CUSTOM_MODEL: false,
AGENT: "CodeActAgent",
LANGUAGE: "en",
LLM_API_KEY: "",
@@ -46,7 +50,14 @@ vi.mock("#/services/options", async (importOriginal) => ({
...(await importOriginal<typeof import("#/services/options")>()),
fetchModels: vi
.fn()
.mockResolvedValue(Promise.resolve(["model1", "model2", "model3"])),
.mockResolvedValue(
Promise.resolve([
"gpt-4o",
"gpt-3.5-turbo",
"azure/ada",
"cohere.command-r-v1:0",
]),
),
fetchAgents: vi
.fn()
.mockResolvedValue(Promise.resolve(["agent1", "agent2", "agent3"])),
@@ -104,6 +115,8 @@ describe("SettingsModal", () => {
describe("onHandleSave", () => {
const initialSettings: Settings = {
LLM_MODEL: "gpt-4o",
CUSTOM_LLM_MODEL: "",
USING_CUSTOM_MODEL: false,
AGENT: "CodeActAgent",
LANGUAGE: "en",
LLM_API_KEY: "sk-...",
@@ -122,17 +135,22 @@ describe("SettingsModal", () => {
await assertModelsAndAgentsFetched();
const saveButton = screen.getByRole("button", { name: /save/i });
const modelInput = screen.getByRole("combobox", { name: "model" });
const providerInput = screen.getByRole("combobox", { name: "Provider" });
const modelInput = screen.getByRole("combobox", { name: "Model" });
await user.click(providerInput);
const azure = screen.getByText("Azure");
await user.click(azure);
await user.click(modelInput);
const model3 = screen.getByText("model3");
const model3 = screen.getByText("ada");
await user.click(model3);
await user.click(saveButton);
expect(saveSettings).toHaveBeenCalledWith({
...initialSettings,
LLM_MODEL: "model3",
LLM_MODEL: "azure/ada",
});
});
@@ -146,12 +164,17 @@ describe("SettingsModal", () => {
);
const saveButton = screen.getByRole("button", { name: /save/i });
const modelInput = screen.getByRole("combobox", { name: "model" });
const providerInput = screen.getByRole("combobox", { name: "Provider" });
const modelInput = screen.getByRole("combobox", { name: "Model" });
await user.click(providerInput);
const openai = screen.getByText("OpenAI");
await user.click(openai);
await user.click(modelInput);
const model3 = screen.getByText("model3");
const model3 = screen.getByText("gpt-3.5-turbo");
await user.click(model3);
await user.click(saveButton);
expect(startNewSessionSpy).toHaveBeenCalled();
@@ -167,12 +190,17 @@ describe("SettingsModal", () => {
);
const saveButton = screen.getByRole("button", { name: /save/i });
const modelInput = screen.getByRole("combobox", { name: "model" });
const providerInput = screen.getByRole("combobox", { name: "Provider" });
const modelInput = screen.getByRole("combobox", { name: "Model" });
await user.click(providerInput);
const cohere = screen.getByText("cohere");
await user.click(cohere);
await user.click(modelInput);
const model3 = screen.getByText("model3");
const model3 = screen.getByText("command-r-v1:0");
await user.click(model3);
await user.click(saveButton);
expect(toastSpy).toHaveBeenCalledTimes(4);
@@ -213,12 +241,17 @@ describe("SettingsModal", () => {
});
const saveButton = screen.getByRole("button", { name: /save/i });
const modelInput = screen.getByRole("combobox", { name: "model" });
const providerInput = screen.getByRole("combobox", { name: "Provider" });
const modelInput = screen.getByRole("combobox", { name: "Model" });
await user.click(providerInput);
const cohere = screen.getByText("cohere");
await user.click(cohere);
await user.click(modelInput);
const model3 = screen.getByText("model3");
const model3 = screen.getByText("command-r-v1:0");
await user.click(model3);
await user.click(saveButton);
expect(onOpenChangeMock).toHaveBeenCalledWith(false);

View File

@@ -63,8 +63,10 @@ function SettingsModal({ isOpen, onOpenChange }: SettingsProps) {
React.useEffect(() => {
(async () => {
try {
setModels(await fetchModels());
setAgents(await fetchAgents());
const fetchedModels = await fetchModels();
const fetchedAgents = await fetchAgents();
setModels(fetchedModels);
setAgents(fetchedAgents);
setSecurityAnalyzers(await fetchSecurityAnalyzers());
} catch (error) {
toast.error("settings", t(I18nKey.CONFIGURATION$ERROR_FETCH_MODELS));
@@ -81,6 +83,20 @@ function SettingsModal({ isOpen, onOpenChange }: SettingsProps) {
}));
};
const handleCustomModelChange = (model: string) => {
setSettings((prev) => ({
...prev,
CUSTOM_LLM_MODEL: model,
}));
};
const handleModelTypeChange = (type: "custom" | "default") => {
setSettings((prev) => ({
...prev,
USING_CUSTOM_MODEL: type === "custom",
}));
};
const handleAgentChange = (agent: string) => {
setSettings((prev) => ({ ...prev, AGENT: agent }));
};
@@ -189,6 +205,8 @@ function SettingsModal({ isOpen, onOpenChange }: SettingsProps) {
agents={agents}
securityAnalyzers={securityAnalyzers}
onModelChange={handleModelChange}
onCustomModelChange={handleCustomModelChange}
onModelTypeChange={handleModelTypeChange}
onAgentChange={handleAgentChange}
onLanguageChange={handleLanguageChange}
onAPIKeyChange={handleAPIKeyChange}

View File

@@ -11,9 +11,16 @@ const setupSpy = vi.spyOn(Session, "_setupSocket").mockImplementation(() => {
});
describe("startNewSession", () => {
afterEach(() => {
sendSpy.mockClear();
setupSpy.mockClear();
});
it("Should start a new session with the current settings", () => {
const settings: Settings = {
LLM_MODEL: "llm_value",
CUSTOM_LLM_MODEL: "",
USING_CUSTOM_MODEL: false,
AGENT: "agent_value",
LANGUAGE: "language_value",
LLM_API_KEY: "sk-...",
@@ -32,4 +39,33 @@ describe("startNewSession", () => {
expect(setupSpy).toHaveBeenCalledTimes(1);
expect(sendSpy).toHaveBeenCalledWith(JSON.stringify(event));
});
it("should start with the custom llm if set", () => {
const settings: Settings = {
LLM_MODEL: "llm_value",
CUSTOM_LLM_MODEL: "custom_llm_value",
USING_CUSTOM_MODEL: true,
AGENT: "agent_value",
LANGUAGE: "language_value",
LLM_API_KEY: "sk-...",
CONFIRMATION_MODE: true,
SECURITY_ANALYZER: "analyzer",
};
const event = {
action: ActionType.INIT,
args: settings,
};
saveSettings(settings);
Session.startNewSession();
expect(setupSpy).toHaveBeenCalledTimes(1);
expect(sendSpy).toHaveBeenCalledWith(
JSON.stringify({
...event,
args: { ...settings, LLM_MODEL: "custom_llm_value" },
}),
);
});
});

View File

@@ -46,7 +46,15 @@ class Session {
private static _initializeAgent = () => {
const settings = getSettings();
const event = { action: ActionType.INIT, args: settings };
const event = {
action: ActionType.INIT,
args: {
...settings,
LLM_MODEL: settings.USING_CUSTOM_MODEL
? settings.CUSTOM_LLM_MODEL
: settings.LLM_MODEL,
},
};
const eventString = JSON.stringify(event);
Session.send(eventString);
};

View File

@@ -18,6 +18,8 @@ describe("getSettings", () => {
it("should get the stored settings", () => {
(localStorage.getItem as Mock)
.mockReturnValueOnce("llm_value")
.mockReturnValueOnce("custom_llm_value")
.mockReturnValueOnce("true")
.mockReturnValueOnce("agent_value")
.mockReturnValueOnce("language_value")
.mockReturnValueOnce("api_key")
@@ -28,6 +30,8 @@ describe("getSettings", () => {
expect(settings).toEqual({
LLM_MODEL: "llm_value",
CUSTOM_LLM_MODEL: "custom_llm_value",
USING_CUSTOM_MODEL: true,
AGENT: "agent_value",
LANGUAGE: "language_value",
LLM_API_KEY: "api_key",
@@ -43,12 +47,16 @@ describe("getSettings", () => {
.mockReturnValueOnce(null)
.mockReturnValueOnce(null)
.mockReturnValueOnce(null)
.mockReturnValueOnce(null)
.mockReturnValueOnce(null)
.mockReturnValueOnce(null);
const settings = getSettings();
expect(settings).toEqual({
LLM_MODEL: DEFAULT_SETTINGS.LLM_MODEL,
CUSTOM_LLM_MODEL: "",
USING_CUSTOM_MODEL: DEFAULT_SETTINGS.USING_CUSTOM_MODEL,
AGENT: DEFAULT_SETTINGS.AGENT,
LANGUAGE: DEFAULT_SETTINGS.LANGUAGE,
LLM_API_KEY: "",
@@ -62,6 +70,8 @@ describe("saveSettings", () => {
it("should save the settings", () => {
const settings: Settings = {
LLM_MODEL: "llm_value",
CUSTOM_LLM_MODEL: "custom_llm_value",
USING_CUSTOM_MODEL: true,
AGENT: "agent_value",
LANGUAGE: "language_value",
LLM_API_KEY: "some_key",
@@ -72,6 +82,14 @@ describe("saveSettings", () => {
saveSettings(settings);
expect(localStorage.setItem).toHaveBeenCalledWith("LLM_MODEL", "llm_value");
expect(localStorage.setItem).toHaveBeenCalledWith(
"CUSTOM_LLM_MODEL",
"custom_llm_value",
);
expect(localStorage.setItem).toHaveBeenCalledWith(
"USING_CUSTOM_MODEL",
"true",
);
expect(localStorage.setItem).toHaveBeenCalledWith("AGENT", "agent_value");
expect(localStorage.setItem).toHaveBeenCalledWith(
"LANGUAGE",
@@ -122,6 +140,8 @@ describe("getSettingsDifference", () => {
beforeEach(() => {
(localStorage.getItem as Mock)
.mockReturnValueOnce("llm_value")
.mockReturnValueOnce("custom_llm_value")
.mockReturnValueOnce("false")
.mockReturnValueOnce("agent_value")
.mockReturnValueOnce("language_value");
});
@@ -129,6 +149,8 @@ describe("getSettingsDifference", () => {
it("should return updated settings", () => {
const settings = {
LLM_MODEL: "new_llm_value",
CUSTOM_LLM_MODEL: "custom_llm_value",
USING_CUSTOM_MODEL: true,
AGENT: "new_agent_value",
LANGUAGE: "language_value",
};
@@ -136,6 +158,7 @@ describe("getSettingsDifference", () => {
const updatedSettings = getSettingsDifference(settings);
expect(updatedSettings).toEqual({
USING_CUSTOM_MODEL: true,
LLM_MODEL: "new_llm_value",
AGENT: "new_agent_value",
});

View File

@@ -2,6 +2,8 @@ const LATEST_SETTINGS_VERSION = 1;
export type Settings = {
LLM_MODEL: string;
CUSTOM_LLM_MODEL: string;
USING_CUSTOM_MODEL: boolean;
AGENT: string;
LANGUAGE: string;
LLM_API_KEY: string;
@@ -12,7 +14,9 @@ export type Settings = {
type SettingsInput = Settings[keyof Settings];
export const DEFAULT_SETTINGS: Settings = {
LLM_MODEL: "gpt-4o",
LLM_MODEL: "openai/gpt-4o",
CUSTOM_LLM_MODEL: "",
USING_CUSTOM_MODEL: false,
AGENT: "CodeActAgent",
LANGUAGE: "en",
LLM_API_KEY: "",
@@ -54,6 +58,9 @@ export const getDefaultSettings = (): Settings => DEFAULT_SETTINGS;
*/
export const getSettings = (): Settings => {
const model = localStorage.getItem("LLM_MODEL");
const customModel = localStorage.getItem("CUSTOM_LLM_MODEL");
const usingCustomModel =
localStorage.getItem("USING_CUSTOM_MODEL") === "true";
const agent = localStorage.getItem("AGENT");
const language = localStorage.getItem("LANGUAGE");
const apiKey = localStorage.getItem("LLM_API_KEY");
@@ -62,6 +69,8 @@ export const getSettings = (): Settings => {
return {
LLM_MODEL: model || DEFAULT_SETTINGS.LLM_MODEL,
CUSTOM_LLM_MODEL: customModel || DEFAULT_SETTINGS.CUSTOM_LLM_MODEL,
USING_CUSTOM_MODEL: usingCustomModel || DEFAULT_SETTINGS.USING_CUSTOM_MODEL,
AGENT: agent || DEFAULT_SETTINGS.AGENT,
LANGUAGE: language || DEFAULT_SETTINGS.LANGUAGE,
LLM_API_KEY: apiKey || DEFAULT_SETTINGS.LLM_API_KEY,

View File

@@ -0,0 +1,62 @@
import { describe, it, expect } from "vitest";
import { extractModelAndProvider } from "./extractModelAndProvider";
describe("extractModelAndProvider", () => {
it("should work", () => {
expect(extractModelAndProvider("azure/ada")).toEqual({
provider: "azure",
model: "ada",
separator: "/",
});
expect(
extractModelAndProvider("azure/standard/1024-x-1024/dall-e-2"),
).toEqual({
provider: "azure",
model: "standard/1024-x-1024/dall-e-2",
separator: "/",
});
expect(extractModelAndProvider("vertex_ai_beta/chat-bison")).toEqual({
provider: "vertex_ai_beta",
model: "chat-bison",
separator: "/",
});
expect(extractModelAndProvider("cohere.command-r-v1:0")).toEqual({
provider: "cohere",
model: "command-r-v1:0",
separator: ".",
});
expect(
extractModelAndProvider(
"cloudflare/@cf/mistral/mistral-7b-instruct-v0.1",
),
).toEqual({
provider: "cloudflare",
model: "@cf/mistral/mistral-7b-instruct-v0.1",
separator: "/",
});
expect(extractModelAndProvider("together-ai-21.1b-41b")).toEqual({
provider: "",
model: "together-ai-21.1b-41b",
separator: "",
});
});
it("should add provider for popular models", () => {
expect(extractModelAndProvider("gpt-3.5-turbo")).toEqual({
provider: "openai",
model: "gpt-3.5-turbo",
separator: "/",
});
expect(extractModelAndProvider("gpt-4o")).toEqual({
provider: "openai",
model: "gpt-4o",
separator: "/",
});
});
});

View File

@@ -0,0 +1,49 @@
import { isNumber } from "./isNumber";
import { VERIFIED_OPENAI_MODELS } from "./verified-models";
/**
 * Checks if the split array is actually a version number.
 * @param split The split array of the model string
 * @returns Boolean indicating if the split is actually a version number
 *
 * @example
 * const split = ["gpt-3", "5-turbo"] // incorrectly split from "gpt-3.5-turbo"
 * splitIsActuallyVersion(split) // returns true
 */
const splitIsActuallyVersion = (split: string[]): boolean =>
  // Boolean() because `a && b && c` would otherwise leak `undefined`/"" as the
  // return value when there is no second segment or it is empty.
  Boolean(split[1] && split[1][0] && isNumber(split[1][0]));
/**
 * Given a model string, extract the provider and model name. Currently the supported separators are "/" and "."
 * @param model The model string
 * @returns An object containing the provider, model name, and separator
 *
 * @example
 * extractModelAndProvider("azure/ada")
 * // returns { provider: "azure", model: "ada", separator: "/" }
 *
 * extractModelAndProvider("cohere.command-r-v1:0")
 * // returns { provider: "cohere", model: "command-r-v1:0", separator: "." }
 */
export const extractModelAndProvider = (model: string) => {
  // Try the "/" convention first (e.g. "azure/ada").
  let separator = "/";
  let parts = model.split(separator);

  if (parts.length === 1) {
    // No "/" found — fall back to the "." convention (e.g. "cohere.command-r-v1:0").
    separator = ".";
    parts = model.split(separator);
    if (splitIsActuallyVersion(parts)) {
      // The "." was part of a version (e.g. "gpt-3.5-turbo"), not a separator:
      // undo the split and treat the whole string as a single token.
      parts = [parts.join(separator)];
    }
  }

  if (parts.length > 1) {
    // Everything after the first segment is the model id (it may itself
    // contain separators, e.g. "azure/standard/1024-x-1024/dall-e-2").
    const [provider, ...rest] = parts;
    return { provider, model: rest.join(separator), separator };
  }

  // No separator at all; fill in the provider for known OpenAI models.
  if (VERIFIED_OPENAI_MODELS.includes(parts[0])) {
    return { provider: "openai", model: parts[0], separator: "/" };
  }

  // Unknown bare model — report it with no provider.
  return { provider: "", model, separator: "" };
};

View File

@@ -0,0 +1,9 @@
import { test, expect } from "vitest";
import { isNumber } from "./isNumber";
test("isNumber", () => {
expect(isNumber(1)).toBe(true);
expect(isNumber(0)).toBe(true);
expect(isNumber("3")).toBe(true);
expect(isNumber("0")).toBe(true);
});

View File

@@ -0,0 +1,2 @@
/**
 * Checks whether a value can be interpreted as a number.
 * @param value The value to check (a number, or a string that may hold one)
 * @returns true if the value is numeric, false otherwise
 */
export const isNumber = (value: string | number): boolean => {
  // Number("") and Number("   ") both coerce to 0, which would wrongly
  // classify empty/whitespace-only strings as numeric.
  if (typeof value === "string" && value.trim() === "") return false;
  return !Number.isNaN(Number(value));
};

View File

@@ -0,0 +1,27 @@
import { test, expect } from "vitest";
import { mapProvider } from "./mapProvider";
test("mapProvider", () => {
expect(mapProvider("azure")).toBe("Azure");
expect(mapProvider("azure_ai")).toBe("Azure AI Studio");
expect(mapProvider("vertex_ai")).toBe("VertexAI");
expect(mapProvider("palm")).toBe("PaLM");
expect(mapProvider("gemini")).toBe("Gemini");
expect(mapProvider("anthropic")).toBe("Anthropic");
expect(mapProvider("sagemaker")).toBe("AWS SageMaker");
expect(mapProvider("bedrock")).toBe("AWS Bedrock");
expect(mapProvider("mistral")).toBe("Mistral AI");
expect(mapProvider("anyscale")).toBe("Anyscale");
expect(mapProvider("databricks")).toBe("Databricks");
expect(mapProvider("ollama")).toBe("Ollama");
expect(mapProvider("perlexity")).toBe("Perplexity AI");
expect(mapProvider("friendliai")).toBe("FriendliAI");
expect(mapProvider("groq")).toBe("Groq");
expect(mapProvider("fireworks_ai")).toBe("Fireworks AI");
expect(mapProvider("cloudflare")).toBe("Cloudflare Workers AI");
expect(mapProvider("deepinfra")).toBe("DeepInfra");
expect(mapProvider("ai21")).toBe("AI21");
expect(mapProvider("replicate")).toBe("Replicate");
expect(mapProvider("voyage")).toBe("Voyage AI");
expect(mapProvider("openrouter")).toBe("OpenRouter");
});

View File

@@ -0,0 +1,30 @@
// Human-readable display names for raw provider keys.
// NOTE(review): "perlexity" looks like a typo of LiteLLM's "perplexity"
// provider key — confirm against the LiteLLM provider list before renaming,
// since lookups elsewhere may rely on the current spelling.
export const MAP_PROVIDER = {
  openai: "OpenAI",
  azure: "Azure",
  azure_ai: "Azure AI Studio",
  vertex_ai: "VertexAI",
  palm: "PaLM",
  gemini: "Gemini",
  anthropic: "Anthropic",
  sagemaker: "AWS SageMaker",
  bedrock: "AWS Bedrock",
  mistral: "Mistral AI",
  anyscale: "Anyscale",
  databricks: "Databricks",
  ollama: "Ollama",
  perlexity: "Perplexity AI",
  friendliai: "FriendliAI",
  groq: "Groq",
  fireworks_ai: "Fireworks AI",
  cloudflare: "Cloudflare Workers AI",
  deepinfra: "DeepInfra",
  ai21: "AI21",
  replicate: "Replicate",
  voyage: "Voyage AI",
  openrouter: "OpenRouter",
};

/**
 * Maps a raw provider key to its display name.
 * @param provider The raw provider key (e.g. "azure_ai")
 * @returns The display name (e.g. "Azure AI Studio"), or the key itself when
 *   no mapping exists
 */
export const mapProvider = (provider: string) =>
  // hasOwnProperty is O(1) (vs. allocating Object.keys() per call) and, unlike
  // the `in` operator, does not match inherited keys such as "toString".
  Object.prototype.hasOwnProperty.call(MAP_PROVIDER, provider)
    ? MAP_PROVIDER[provider as keyof typeof MAP_PROVIDER]
    : provider;

View File

@@ -0,0 +1,51 @@
// Fix: `expect` was used below but never imported (only `test` was), which
// throws a ReferenceError when the test runs.
import { test, expect } from "vitest";
import { organizeModelsAndProviders } from "./organizeModelsAndProviders";
test("organizeModelsAndProviders", () => {
  const models = [
    "azure/ada",
    "azure/gpt-35-turbo",
    "azure/gpt-3-turbo",
    "azure/standard/1024-x-1024/dall-e-2",
    "vertex_ai_beta/chat-bison",
    "vertex_ai_beta/chat-bison-32k",
    "sagemaker/meta-textgeneration-llama-2-13b",
    "cohere.command-r-v1:0",
    "cloudflare/@cf/mistral/mistral-7b-instruct-v0.1",
    "gpt-4o",
    "together-ai-21.1b-41b",
    "gpt-3.5-turbo",
  ];
  const object = organizeModelsAndProviders(models);
  expect(object).toEqual({
    azure: {
      separator: "/",
      models: [
        "ada",
        "gpt-35-turbo",
        "gpt-3-turbo",
        "standard/1024-x-1024/dall-e-2",
      ],
    },
    vertex_ai_beta: {
      separator: "/",
      models: ["chat-bison", "chat-bison-32k"],
    },
    sagemaker: { separator: "/", models: ["meta-textgeneration-llama-2-13b"] },
    cohere: { separator: ".", models: ["command-r-v1:0"] },
    cloudflare: {
      separator: "/",
      models: ["@cf/mistral/mistral-7b-instruct-v0.1"],
    },
    openai: {
      separator: "/",
      models: ["gpt-4o", "gpt-3.5-turbo"],
    },
    other: {
      separator: "",
      models: ["together-ai-21.1b-41b"],
    },
  });
});

View File

@@ -0,0 +1,42 @@
import { extractModelAndProvider } from "./extractModelAndProvider";
/**
 * Given a list of models, organize them by provider
 * @param models The list of models
 * @returns An object containing the provider and models
 *
 * @example
 * const models = [
 *   "azure/ada",
 *   "azure/gpt-35-turbo",
 *   "cohere.command-r-v1:0",
 * ];
 *
 * organizeModelsAndProviders(models);
 * // returns {
 * //   azure: {
 * //     separator: "/",
 * //     models: ["ada", "gpt-35-turbo"],
 * //   },
 * //   cohere: {
 * //     separator: ".",
 * //     models: ["command-r-v1:0"],
 * //   },
 * // }
 */
export const organizeModelsAndProviders = (models: string[]) => {
  const grouped: Record<string, { separator: string; models: string[] }> = {};

  for (const raw of models) {
    const { separator, provider, model: modelId } = extractModelAndProvider(raw);
    // Models without a recognizable provider are bucketed under "other".
    const bucket = provider || "other";
    if (!grouped[bucket]) {
      grouped[bucket] = { separator, models: [] };
    }
    grouped[bucket].models.push(modelId);
  }

  return grouped;
};

View File

@@ -0,0 +1,14 @@
// Verified models and providers that are known to work well with OpenHands.
export const VERIFIED_PROVIDERS = ["openai", "azure", "anthropic"];
export const VERIFIED_MODELS = ["gpt-4o", "claude-3-5-sonnet-20240620-v1:0"];
// LiteLLM does not return OpenAI models with the provider, so we list them here to set them ourselves for consistency
// (e.g., they return `gpt-4o` instead of `openai/gpt-4o`)
export const VERIFIED_OPENAI_MODELS = [
  "gpt-4o",
  "gpt-4o-mini",
  "gpt-4-turbo",
  "gpt-4",
  "gpt-4-32k",
  "gpt-3.5-turbo",
];