mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2024-08-29 01:18:33 +03:00
feat(frontend): Improve models input UI/UX in settings (#3530)
* Create helper functions * Add map according to litellm docs * Create ModelSelector * Extend model selector * use autocomplete from nextui * Improve keys without providers * Handle models without a provider * Add verified section and some empty handling * Add support for default or previously set models * Update tests * Lint * Remove modifier * Fix typescript error * Functionality for switching to custom model * Add verified models * Respond to resetting to default * Comment
This commit is contained in:
193
frontend/src/components/modals/settings/ModelSelector.test.tsx
Normal file
193
frontend/src/components/modals/settings/ModelSelector.test.tsx
Normal file
@@ -0,0 +1,193 @@
|
||||
import React from "react";
|
||||
import { describe, it, expect, vi } from "vitest";
|
||||
import { render, screen } from "@testing-library/react";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import { ModelSelector } from "./ModelSelector";
|
||||
|
||||
describe("ModelSelector", () => {
|
||||
const models = {
|
||||
openai: {
|
||||
separator: "/",
|
||||
models: ["gpt-4o", "gpt-3.5-turbo"],
|
||||
},
|
||||
azure: {
|
||||
separator: "/",
|
||||
models: ["ada", "gpt-35-turbo"],
|
||||
},
|
||||
vertex_ai: {
|
||||
separator: "/",
|
||||
models: ["chat-bison", "chat-bison-32k"],
|
||||
},
|
||||
cohere: {
|
||||
separator: ".",
|
||||
models: ["command-r-v1:0"],
|
||||
},
|
||||
};
|
||||
|
||||
it("should display the provider selector", async () => {
|
||||
const user = userEvent.setup();
|
||||
const onModelChange = vi.fn();
|
||||
render(<ModelSelector models={models} onModelChange={onModelChange} />);
|
||||
|
||||
const selector = screen.getByLabelText("Provider");
|
||||
expect(selector).toBeInTheDocument();
|
||||
|
||||
await user.click(selector);
|
||||
|
||||
expect(screen.getByText("OpenAI")).toBeInTheDocument();
|
||||
expect(screen.getByText("Azure")).toBeInTheDocument();
|
||||
expect(screen.getByText("VertexAI")).toBeInTheDocument();
|
||||
expect(screen.getByText("cohere")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should disable the model selector if the provider is not selected", async () => {
|
||||
const user = userEvent.setup();
|
||||
const onModelChange = vi.fn();
|
||||
render(<ModelSelector models={models} onModelChange={onModelChange} />);
|
||||
|
||||
const modelSelector = screen.getByLabelText("Model");
|
||||
expect(modelSelector).toBeDisabled();
|
||||
|
||||
const providerSelector = screen.getByLabelText("Provider");
|
||||
await user.click(providerSelector);
|
||||
|
||||
const vertexAI = screen.getByText("VertexAI");
|
||||
await user.click(vertexAI);
|
||||
|
||||
expect(modelSelector).not.toBeDisabled();
|
||||
});
|
||||
|
||||
it("should display the model selector", async () => {
|
||||
const user = userEvent.setup();
|
||||
const onModelChange = vi.fn();
|
||||
render(<ModelSelector models={models} onModelChange={onModelChange} />);
|
||||
|
||||
const providerSelector = screen.getByLabelText("Provider");
|
||||
await user.click(providerSelector);
|
||||
|
||||
const azureProvider = screen.getByText("Azure");
|
||||
await user.click(azureProvider);
|
||||
|
||||
const modelSelector = screen.getByLabelText("Model");
|
||||
await user.click(modelSelector);
|
||||
|
||||
expect(screen.getByText("ada")).toBeInTheDocument();
|
||||
expect(screen.getByText("gpt-35-turbo")).toBeInTheDocument();
|
||||
|
||||
await user.click(providerSelector);
|
||||
const vertexProvider = screen.getByText("VertexAI");
|
||||
await user.click(vertexProvider);
|
||||
|
||||
await user.click(modelSelector);
|
||||
|
||||
expect(screen.getByText("chat-bison")).toBeInTheDocument();
|
||||
expect(screen.getByText("chat-bison-32k")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should display the actual litellm model ID as the user is making the selections", async () => {
|
||||
const user = userEvent.setup();
|
||||
const onModelChange = vi.fn();
|
||||
render(<ModelSelector models={models} onModelChange={onModelChange} />);
|
||||
|
||||
const id = screen.getByTestId("model-id");
|
||||
const providerSelector = screen.getByLabelText("Provider");
|
||||
const modelSelector = screen.getByLabelText("Model");
|
||||
|
||||
expect(id).toHaveTextContent("No model selected");
|
||||
|
||||
await user.click(providerSelector);
|
||||
await user.click(screen.getByText("Azure"));
|
||||
|
||||
expect(id).toHaveTextContent("azure/");
|
||||
|
||||
await user.click(modelSelector);
|
||||
await user.click(screen.getByText("ada"));
|
||||
expect(id).toHaveTextContent("azure/ada");
|
||||
|
||||
await user.click(providerSelector);
|
||||
await user.click(screen.getByText("cohere"));
|
||||
expect(id).toHaveTextContent("cohere.");
|
||||
|
||||
await user.click(modelSelector);
|
||||
await user.click(screen.getByText("command-r-v1:0"));
|
||||
expect(id).toHaveTextContent("cohere.command-r-v1:0");
|
||||
});
|
||||
|
||||
it("should call onModelChange when the model is changed", async () => {
|
||||
const user = userEvent.setup();
|
||||
const onModelChange = vi.fn();
|
||||
render(<ModelSelector models={models} onModelChange={onModelChange} />);
|
||||
|
||||
const providerSelector = screen.getByLabelText("Provider");
|
||||
const modelSelector = screen.getByLabelText("Model");
|
||||
|
||||
await user.click(providerSelector);
|
||||
await user.click(screen.getByText("Azure"));
|
||||
|
||||
await user.click(modelSelector);
|
||||
await user.click(screen.getByText("ada"));
|
||||
|
||||
expect(onModelChange).toHaveBeenCalledTimes(1);
|
||||
expect(onModelChange).toHaveBeenCalledWith("azure/ada");
|
||||
|
||||
await user.click(modelSelector);
|
||||
await user.click(screen.getByText("gpt-35-turbo"));
|
||||
|
||||
expect(onModelChange).toHaveBeenCalledTimes(2);
|
||||
expect(onModelChange).toHaveBeenCalledWith("azure/gpt-35-turbo");
|
||||
|
||||
await user.click(providerSelector);
|
||||
await user.click(screen.getByText("cohere"));
|
||||
|
||||
await user.click(modelSelector);
|
||||
await user.click(screen.getByText("command-r-v1:0"));
|
||||
|
||||
expect(onModelChange).toHaveBeenCalledTimes(3);
|
||||
expect(onModelChange).toHaveBeenCalledWith("cohere.command-r-v1:0");
|
||||
});
|
||||
|
||||
it("should clear the model ID when the provider is cleared", async () => {
|
||||
const user = userEvent.setup();
|
||||
const onModelChange = vi.fn();
|
||||
render(<ModelSelector models={models} onModelChange={onModelChange} />);
|
||||
|
||||
const providerSelector = screen.getByLabelText("Provider");
|
||||
const modelSelector = screen.getByLabelText("Model");
|
||||
|
||||
await user.click(providerSelector);
|
||||
await user.click(screen.getByText("Azure"));
|
||||
|
||||
await user.click(modelSelector);
|
||||
await user.click(screen.getByText("ada"));
|
||||
|
||||
expect(screen.getByTestId("model-id")).toHaveTextContent("azure/ada");
|
||||
|
||||
await user.clear(providerSelector);
|
||||
|
||||
expect(screen.getByTestId("model-id")).toHaveTextContent(
|
||||
"No model selected",
|
||||
);
|
||||
});
|
||||
|
||||
it("should have a default value if passed", async () => {
|
||||
const onModelChange = vi.fn();
|
||||
render(
|
||||
<ModelSelector
|
||||
models={models}
|
||||
onModelChange={onModelChange}
|
||||
defaultModel="azure/ada"
|
||||
/>,
|
||||
);
|
||||
|
||||
expect(screen.getByTestId("model-id")).toHaveTextContent("azure/ada");
|
||||
expect(screen.getByLabelText("Provider")).toHaveValue("Azure");
|
||||
expect(screen.getByLabelText("Model")).toHaveValue("ada");
|
||||
});
|
||||
|
||||
it.todo("should disable provider if isDisabled is true");
|
||||
|
||||
it.todo(
|
||||
"should display the verified models in the correct order",
|
||||
async () => {},
|
||||
);
|
||||
});
|
||||
133
frontend/src/components/modals/settings/ModelSelector.tsx
Normal file
133
frontend/src/components/modals/settings/ModelSelector.tsx
Normal file
@@ -0,0 +1,133 @@
|
||||
import {
|
||||
Autocomplete,
|
||||
AutocompleteItem,
|
||||
AutocompleteSection,
|
||||
} from "@nextui-org/react";
|
||||
import React from "react";
|
||||
import { mapProvider } from "#/utils/mapProvider";
|
||||
import { VERIFIED_MODELS, VERIFIED_PROVIDERS } from "#/utils/verified-models";
|
||||
import { extractModelAndProvider } from "#/utils/extractModelAndProvider";
|
||||
|
||||
interface ModelSelectorProps {
|
||||
isDisabled?: boolean;
|
||||
models: Record<string, { separator: string; models: string[] }>;
|
||||
onModelChange: (model: string) => void;
|
||||
defaultModel?: string;
|
||||
}
|
||||
|
||||
export function ModelSelector({
|
||||
isDisabled,
|
||||
models,
|
||||
onModelChange,
|
||||
defaultModel,
|
||||
}: ModelSelectorProps) {
|
||||
const [litellmId, setLitellmId] = React.useState<string | null>(null);
|
||||
const [selectedProvider, setSelectedProvider] = React.useState<string | null>(
|
||||
null,
|
||||
);
|
||||
const [selectedModel, setSelectedModel] = React.useState<string | null>(null);
|
||||
|
||||
React.useEffect(() => {
|
||||
if (defaultModel) {
|
||||
// runs when resetting to defaults
|
||||
const { provider, model } = extractModelAndProvider(defaultModel);
|
||||
|
||||
setLitellmId(defaultModel);
|
||||
setSelectedProvider(provider);
|
||||
setSelectedModel(model);
|
||||
}
|
||||
}, [defaultModel]);
|
||||
|
||||
const handleChangeProvider = (provider: string) => {
|
||||
setSelectedProvider(provider);
|
||||
setSelectedModel(null);
|
||||
|
||||
const separator = models[provider]?.separator || "";
|
||||
setLitellmId(provider + separator);
|
||||
};
|
||||
|
||||
const handleChangeModel = (model: string) => {
|
||||
const separator = models[selectedProvider || ""]?.separator || "";
|
||||
const fullModel = selectedProvider + separator + model;
|
||||
setLitellmId(fullModel);
|
||||
onModelChange(fullModel);
|
||||
setSelectedModel(model);
|
||||
};
|
||||
|
||||
const clear = () => {
|
||||
setSelectedProvider(null);
|
||||
setLitellmId(null);
|
||||
};
|
||||
|
||||
return (
|
||||
<div data-testid="model-selector" className="flex flex-col gap-2">
|
||||
<span className="text-center italic text-gray-500" data-testid="model-id">
|
||||
{litellmId?.replace("other", "") || "No model selected"}
|
||||
</span>
|
||||
|
||||
<div className="flex flex-col gap-3">
|
||||
<Autocomplete
|
||||
isDisabled={isDisabled}
|
||||
label="Provider"
|
||||
placeholder="Select a provider"
|
||||
isClearable={false}
|
||||
onSelectionChange={(e) => {
|
||||
if (e?.toString()) handleChangeProvider(e.toString());
|
||||
}}
|
||||
onInputChange={(value) => !value && clear()}
|
||||
defaultSelectedKey={selectedProvider ?? undefined}
|
||||
selectedKey={selectedProvider}
|
||||
>
|
||||
<AutocompleteSection title="Verified">
|
||||
{Object.keys(models)
|
||||
.filter((provider) => VERIFIED_PROVIDERS.includes(provider))
|
||||
.map((provider) => (
|
||||
<AutocompleteItem key={provider} value={provider}>
|
||||
{mapProvider(provider)}
|
||||
</AutocompleteItem>
|
||||
))}
|
||||
</AutocompleteSection>
|
||||
<AutocompleteSection title="Others">
|
||||
{Object.keys(models)
|
||||
.filter((provider) => !VERIFIED_PROVIDERS.includes(provider))
|
||||
.map((provider) => (
|
||||
<AutocompleteItem key={provider} value={provider}>
|
||||
{mapProvider(provider)}
|
||||
</AutocompleteItem>
|
||||
))}
|
||||
</AutocompleteSection>
|
||||
</Autocomplete>
|
||||
|
||||
<Autocomplete
|
||||
label="Model"
|
||||
placeholder="Select a model"
|
||||
onSelectionChange={(e) => {
|
||||
if (e?.toString()) handleChangeModel(e.toString());
|
||||
}}
|
||||
isDisabled={isDisabled || !selectedProvider}
|
||||
selectedKey={selectedModel}
|
||||
defaultSelectedKey={selectedModel ?? undefined}
|
||||
>
|
||||
<AutocompleteSection title="Verified">
|
||||
{models[selectedProvider || ""]?.models
|
||||
.filter((model) => VERIFIED_MODELS.includes(model))
|
||||
.map((model) => (
|
||||
<AutocompleteItem key={model} value={model}>
|
||||
{model}
|
||||
</AutocompleteItem>
|
||||
))}
|
||||
</AutocompleteSection>
|
||||
<AutocompleteSection title="Others">
|
||||
{models[selectedProvider || ""]?.models
|
||||
.filter((model) => !VERIFIED_MODELS.includes(model))
|
||||
.map((model) => (
|
||||
<AutocompleteItem key={model} value={model}>
|
||||
{model}
|
||||
</AutocompleteItem>
|
||||
))}
|
||||
</AutocompleteSection>
|
||||
</Autocomplete>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -6,6 +6,8 @@ import { Settings } from "#/services/settings";
|
||||
import SettingsForm from "./SettingsForm";
|
||||
|
||||
const onModelChangeMock = vi.fn();
|
||||
const onCustomModelChangeMock = vi.fn();
|
||||
const onModelTypeChangeMock = vi.fn();
|
||||
const onAgentChangeMock = vi.fn();
|
||||
const onLanguageChangeMock = vi.fn();
|
||||
const onAPIKeyChangeMock = vi.fn();
|
||||
@@ -18,7 +20,9 @@ const renderSettingsForm = (settings?: Settings) => {
|
||||
disabled={false}
|
||||
settings={
|
||||
settings || {
|
||||
LLM_MODEL: "model1",
|
||||
LLM_MODEL: "gpt-4o",
|
||||
CUSTOM_LLM_MODEL: "",
|
||||
USING_CUSTOM_MODEL: false,
|
||||
AGENT: "agent1",
|
||||
LANGUAGE: "en",
|
||||
LLM_API_KEY: "sk-...",
|
||||
@@ -26,10 +30,12 @@ const renderSettingsForm = (settings?: Settings) => {
|
||||
SECURITY_ANALYZER: "analyzer1",
|
||||
}
|
||||
}
|
||||
models={["model1", "model2", "model3"]}
|
||||
models={["gpt-4o", "gpt-3.5-turbo", "azure/ada"]}
|
||||
agents={["agent1", "agent2", "agent3"]}
|
||||
securityAnalyzers={["analyzer1", "analyzer2", "analyzer3"]}
|
||||
onModelChange={onModelChangeMock}
|
||||
onCustomModelChange={onCustomModelChangeMock}
|
||||
onModelTypeChange={onModelTypeChangeMock}
|
||||
onAgentChange={onAgentChangeMock}
|
||||
onLanguageChange={onLanguageChangeMock}
|
||||
onAPIKeyChange={onAPIKeyChangeMock}
|
||||
@@ -43,7 +49,8 @@ describe("SettingsForm", () => {
|
||||
it("should display the first values in the array by default", () => {
|
||||
renderSettingsForm();
|
||||
|
||||
const modelInput = screen.getByRole("combobox", { name: "model" });
|
||||
const providerInput = screen.getByRole("combobox", { name: "Provider" });
|
||||
const modelInput = screen.getByRole("combobox", { name: "Model" });
|
||||
const agentInput = screen.getByRole("combobox", { name: "agent" });
|
||||
const languageInput = screen.getByRole("combobox", { name: "language" });
|
||||
const apiKeyInput = screen.getByTestId("apikey");
|
||||
@@ -52,7 +59,8 @@ describe("SettingsForm", () => {
|
||||
name: "securityanalyzer",
|
||||
});
|
||||
|
||||
expect(modelInput).toHaveValue("model1");
|
||||
expect(providerInput).toHaveValue("OpenAI");
|
||||
expect(modelInput).toHaveValue("gpt-4o");
|
||||
expect(agentInput).toHaveValue("agent1");
|
||||
expect(languageInput).toHaveValue("English");
|
||||
expect(apiKeyInput).toHaveValue("sk-...");
|
||||
@@ -62,7 +70,9 @@ describe("SettingsForm", () => {
|
||||
|
||||
it("should display the existing values if they are present", () => {
|
||||
renderSettingsForm({
|
||||
LLM_MODEL: "model2",
|
||||
LLM_MODEL: "gpt-3.5-turbo",
|
||||
CUSTOM_LLM_MODEL: "",
|
||||
USING_CUSTOM_MODEL: false,
|
||||
AGENT: "agent2",
|
||||
LANGUAGE: "es",
|
||||
LLM_API_KEY: "sk-...",
|
||||
@@ -70,14 +80,16 @@ describe("SettingsForm", () => {
|
||||
SECURITY_ANALYZER: "analyzer2",
|
||||
});
|
||||
|
||||
const modelInput = screen.getByRole("combobox", { name: "model" });
|
||||
const providerInput = screen.getByRole("combobox", { name: "Provider" });
|
||||
const modelInput = screen.getByRole("combobox", { name: "Model" });
|
||||
const agentInput = screen.getByRole("combobox", { name: "agent" });
|
||||
const languageInput = screen.getByRole("combobox", { name: "language" });
|
||||
const securityAnalyzerInput = screen.getByRole("combobox", {
|
||||
name: "securityanalyzer",
|
||||
});
|
||||
|
||||
expect(modelInput).toHaveValue("model2");
|
||||
expect(providerInput).toHaveValue("OpenAI");
|
||||
expect(modelInput).toHaveValue("gpt-3.5-turbo");
|
||||
expect(agentInput).toHaveValue("agent2");
|
||||
expect(languageInput).toHaveValue("Español");
|
||||
expect(securityAnalyzerInput).toHaveValue("analyzer2");
|
||||
@@ -87,18 +99,22 @@ describe("SettingsForm", () => {
|
||||
renderWithProviders(
|
||||
<SettingsForm
|
||||
settings={{
|
||||
LLM_MODEL: "model1",
|
||||
LLM_MODEL: "gpt-4o",
|
||||
CUSTOM_LLM_MODEL: "",
|
||||
USING_CUSTOM_MODEL: false,
|
||||
AGENT: "agent1",
|
||||
LANGUAGE: "en",
|
||||
LLM_API_KEY: "sk-...",
|
||||
CONFIRMATION_MODE: true,
|
||||
SECURITY_ANALYZER: "analyzer1",
|
||||
}}
|
||||
models={["model1", "model2", "model3"]}
|
||||
models={["gpt-4o", "gpt-3.5-turbo", "azure/ada"]}
|
||||
agents={["agent1", "agent2", "agent3"]}
|
||||
securityAnalyzers={["analyzer1", "analyzer2", "analyzer3"]}
|
||||
disabled
|
||||
onModelChange={onModelChangeMock}
|
||||
onCustomModelChange={onCustomModelChangeMock}
|
||||
onModelTypeChange={onModelTypeChangeMock}
|
||||
onAgentChange={onAgentChangeMock}
|
||||
onLanguageChange={onLanguageChangeMock}
|
||||
onAPIKeyChange={onAPIKeyChangeMock}
|
||||
@@ -106,7 +122,9 @@ describe("SettingsForm", () => {
|
||||
onSecurityAnalyzerChange={onSecurityAnalyzerChangeMock}
|
||||
/>,
|
||||
);
|
||||
const modelInput = screen.getByRole("combobox", { name: "model" });
|
||||
|
||||
const providerInput = screen.getByRole("combobox", { name: "Provider" });
|
||||
const modelInput = screen.getByRole("combobox", { name: "Model" });
|
||||
const agentInput = screen.getByRole("combobox", { name: "agent" });
|
||||
const languageInput = screen.getByRole("combobox", { name: "language" });
|
||||
const confirmationModeInput = screen.getByTestId("confirmationmode");
|
||||
@@ -114,6 +132,7 @@ describe("SettingsForm", () => {
|
||||
name: "securityanalyzer",
|
||||
});
|
||||
|
||||
expect(providerInput).toBeDisabled();
|
||||
expect(modelInput).toBeDisabled();
|
||||
expect(agentInput).toBeDisabled();
|
||||
expect(languageInput).toBeDisabled();
|
||||
@@ -122,22 +141,6 @@ describe("SettingsForm", () => {
|
||||
});
|
||||
|
||||
describe("onChange handlers", () => {
|
||||
it("should call the onModelChange handler when the model changes", async () => {
|
||||
renderSettingsForm();
|
||||
|
||||
const modelInput = screen.getByRole("combobox", { name: "model" });
|
||||
await act(async () => {
|
||||
await userEvent.click(modelInput);
|
||||
});
|
||||
|
||||
const model3 = screen.getByText("model3");
|
||||
await act(async () => {
|
||||
await userEvent.click(model3);
|
||||
});
|
||||
|
||||
expect(onModelChangeMock).toHaveBeenCalledWith("model3");
|
||||
});
|
||||
|
||||
it("should call the onAgentChange handler when the agent changes", async () => {
|
||||
const user = userEvent.setup();
|
||||
renderSettingsForm();
|
||||
@@ -182,4 +185,76 @@ describe("SettingsForm", () => {
|
||||
expect(onAPIKeyChangeMock).toHaveBeenCalledWith("sk-...x");
|
||||
});
|
||||
});
|
||||
|
||||
describe("Setting a custom LLM model", () => {
|
||||
it("should display the fetched models by default", () => {
|
||||
renderSettingsForm();
|
||||
|
||||
const modelSelector = screen.getByTestId("model-selector");
|
||||
expect(modelSelector).toBeInTheDocument();
|
||||
|
||||
const customModelInput = screen.queryByTestId("custom-model-input");
|
||||
expect(customModelInput).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should switch to the custom model input when the custom model toggle is clicked", async () => {
|
||||
const user = userEvent.setup();
|
||||
renderSettingsForm();
|
||||
|
||||
const customModelToggle = screen.getByTestId("custom-model-toggle");
|
||||
await user.click(customModelToggle);
|
||||
|
||||
const modelSelector = screen.queryByTestId("model-selector");
|
||||
expect(modelSelector).not.toBeInTheDocument();
|
||||
|
||||
const customModelInput = screen.getByTestId("custom-model-input");
|
||||
expect(customModelInput).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should call the onCustomModelChange handler when the custom model input changes", async () => {
|
||||
const user = userEvent.setup();
|
||||
renderSettingsForm();
|
||||
|
||||
const customModelToggle = screen.getByTestId("custom-model-toggle");
|
||||
await user.click(customModelToggle);
|
||||
|
||||
const customModelInput = screen.getByTestId("custom-model-input");
|
||||
await userEvent.type(customModelInput, "my/custom-model");
|
||||
|
||||
expect(onCustomModelChangeMock).toHaveBeenCalledWith("my/custom-model");
|
||||
expect(onModelTypeChangeMock).toHaveBeenCalledWith("custom");
|
||||
});
|
||||
|
||||
it("should have custom model switched if using custom model", () => {
|
||||
renderWithProviders(
|
||||
<SettingsForm
|
||||
settings={{
|
||||
LLM_MODEL: "gpt-4o",
|
||||
CUSTOM_LLM_MODEL: "CUSTOM_MODEL",
|
||||
USING_CUSTOM_MODEL: true,
|
||||
AGENT: "agent1",
|
||||
LANGUAGE: "en",
|
||||
LLM_API_KEY: "sk-...",
|
||||
CONFIRMATION_MODE: true,
|
||||
SECURITY_ANALYZER: "analyzer1",
|
||||
}}
|
||||
models={["gpt-4o", "gpt-3.5-turbo", "azure/ada"]}
|
||||
agents={["agent1", "agent2", "agent3"]}
|
||||
securityAnalyzers={["analyzer1", "analyzer2", "analyzer3"]}
|
||||
disabled
|
||||
onModelChange={onModelChangeMock}
|
||||
onCustomModelChange={onCustomModelChangeMock}
|
||||
onModelTypeChange={onModelTypeChangeMock}
|
||||
onAgentChange={onAgentChangeMock}
|
||||
onLanguageChange={onLanguageChangeMock}
|
||||
onAPIKeyChange={onAPIKeyChangeMock}
|
||||
onConfirmationModeChange={onConfirmationModeChangeMock}
|
||||
onSecurityAnalyzerChange={onSecurityAnalyzerChangeMock}
|
||||
/>,
|
||||
);
|
||||
|
||||
const customModelToggle = screen.getByTestId("custom-model-toggle");
|
||||
expect(customModelToggle).toHaveAttribute("aria-checked", "true");
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -6,6 +6,8 @@ import { AvailableLanguages } from "../../../i18n";
|
||||
import { I18nKey } from "../../../i18n/declaration";
|
||||
import { AutocompleteCombobox } from "./AutocompleteCombobox";
|
||||
import { Settings } from "#/services/settings";
|
||||
import { organizeModelsAndProviders } from "#/utils/organizeModelsAndProviders";
|
||||
import { ModelSelector } from "./ModelSelector";
|
||||
|
||||
interface SettingsFormProps {
|
||||
settings: Settings;
|
||||
@@ -15,6 +17,8 @@ interface SettingsFormProps {
|
||||
disabled: boolean;
|
||||
|
||||
onModelChange: (model: string) => void;
|
||||
onCustomModelChange: (model: string) => void;
|
||||
onModelTypeChange: (type: "custom" | "default") => void;
|
||||
onAPIKeyChange: (apiKey: string) => void;
|
||||
onAgentChange: (agent: string) => void;
|
||||
onLanguageChange: (language: string) => void;
|
||||
@@ -29,6 +33,8 @@ function SettingsForm({
|
||||
securityAnalyzers,
|
||||
disabled,
|
||||
onModelChange,
|
||||
onCustomModelChange,
|
||||
onModelTypeChange,
|
||||
onAPIKeyChange,
|
||||
onAgentChange,
|
||||
onLanguageChange,
|
||||
@@ -38,20 +44,46 @@ function SettingsForm({
|
||||
const { t } = useTranslation();
|
||||
const { isOpen: isVisible, onOpenChange: onVisibleChange } = useDisclosure();
|
||||
const [isAgentSelectEnabled, setIsAgentSelectEnabled] = React.useState(false);
|
||||
const [usingCustomModel, setUsingCustomModel] = React.useState(
|
||||
settings.USING_CUSTOM_MODEL,
|
||||
);
|
||||
|
||||
const changeModelType = (type: "custom" | "default") => {
|
||||
if (type === "custom") {
|
||||
setUsingCustomModel(true);
|
||||
onModelTypeChange("custom");
|
||||
} else {
|
||||
setUsingCustomModel(false);
|
||||
onModelTypeChange("default");
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
<AutocompleteCombobox
|
||||
ariaLabel="model"
|
||||
items={models.map((model) => ({ value: model, label: model }))}
|
||||
defaultKey={settings.LLM_MODEL}
|
||||
onChange={(e) => {
|
||||
onModelChange(e);
|
||||
}}
|
||||
tooltip={t(I18nKey.SETTINGS$MODEL_TOOLTIP)}
|
||||
allowCustomValue // user can type in a custom LLM model that is not in the list
|
||||
disabled={disabled}
|
||||
/>
|
||||
<Switch
|
||||
data-testid="custom-model-toggle"
|
||||
aria-checked={usingCustomModel}
|
||||
isSelected={usingCustomModel}
|
||||
onValueChange={(value) => changeModelType(value ? "custom" : "default")}
|
||||
>
|
||||
Use custom model
|
||||
</Switch>
|
||||
{usingCustomModel && (
|
||||
<Input
|
||||
data-testid="custom-model-input"
|
||||
label="Custom Model"
|
||||
onValueChange={onCustomModelChange}
|
||||
defaultValue={settings.CUSTOM_LLM_MODEL}
|
||||
/>
|
||||
)}
|
||||
{!usingCustomModel && (
|
||||
<ModelSelector
|
||||
isDisabled={disabled}
|
||||
models={organizeModelsAndProviders(models)}
|
||||
onModelChange={onModelChange}
|
||||
defaultModel={settings.LLM_MODEL}
|
||||
/>
|
||||
)}
|
||||
<Input
|
||||
label="API Key"
|
||||
isDisabled={disabled}
|
||||
|
||||
@@ -24,6 +24,8 @@ vi.mock("#/services/settings", async (importOriginal) => ({
|
||||
...(await importOriginal<typeof import("#/services/settings")>()),
|
||||
getSettings: vi.fn().mockReturnValue({
|
||||
LLM_MODEL: "gpt-4o",
|
||||
CUSTOM_LLM_MODEL: "",
|
||||
USING_CUSTOM_MODEL: false,
|
||||
AGENT: "CodeActAgent",
|
||||
LANGUAGE: "en",
|
||||
LLM_API_KEY: "sk-...",
|
||||
@@ -32,6 +34,8 @@ vi.mock("#/services/settings", async (importOriginal) => ({
|
||||
}),
|
||||
getDefaultSettings: vi.fn().mockReturnValue({
|
||||
LLM_MODEL: "gpt-4o",
|
||||
CUSTOM_LLM_MODEL: "",
|
||||
USING_CUSTOM_MODEL: false,
|
||||
AGENT: "CodeActAgent",
|
||||
LANGUAGE: "en",
|
||||
LLM_API_KEY: "",
|
||||
@@ -46,7 +50,14 @@ vi.mock("#/services/options", async (importOriginal) => ({
|
||||
...(await importOriginal<typeof import("#/services/options")>()),
|
||||
fetchModels: vi
|
||||
.fn()
|
||||
.mockResolvedValue(Promise.resolve(["model1", "model2", "model3"])),
|
||||
.mockResolvedValue(
|
||||
Promise.resolve([
|
||||
"gpt-4o",
|
||||
"gpt-3.5-turbo",
|
||||
"azure/ada",
|
||||
"cohere.command-r-v1:0",
|
||||
]),
|
||||
),
|
||||
fetchAgents: vi
|
||||
.fn()
|
||||
.mockResolvedValue(Promise.resolve(["agent1", "agent2", "agent3"])),
|
||||
@@ -104,6 +115,8 @@ describe("SettingsModal", () => {
|
||||
describe("onHandleSave", () => {
|
||||
const initialSettings: Settings = {
|
||||
LLM_MODEL: "gpt-4o",
|
||||
CUSTOM_LLM_MODEL: "",
|
||||
USING_CUSTOM_MODEL: false,
|
||||
AGENT: "CodeActAgent",
|
||||
LANGUAGE: "en",
|
||||
LLM_API_KEY: "sk-...",
|
||||
@@ -122,17 +135,22 @@ describe("SettingsModal", () => {
|
||||
await assertModelsAndAgentsFetched();
|
||||
|
||||
const saveButton = screen.getByRole("button", { name: /save/i });
|
||||
const modelInput = screen.getByRole("combobox", { name: "model" });
|
||||
const providerInput = screen.getByRole("combobox", { name: "Provider" });
|
||||
const modelInput = screen.getByRole("combobox", { name: "Model" });
|
||||
|
||||
await user.click(providerInput);
|
||||
const azure = screen.getByText("Azure");
|
||||
await user.click(azure);
|
||||
|
||||
await user.click(modelInput);
|
||||
const model3 = screen.getByText("model3");
|
||||
|
||||
const model3 = screen.getByText("ada");
|
||||
await user.click(model3);
|
||||
|
||||
await user.click(saveButton);
|
||||
|
||||
expect(saveSettings).toHaveBeenCalledWith({
|
||||
...initialSettings,
|
||||
LLM_MODEL: "model3",
|
||||
LLM_MODEL: "azure/ada",
|
||||
});
|
||||
});
|
||||
|
||||
@@ -146,12 +164,17 @@ describe("SettingsModal", () => {
|
||||
);
|
||||
|
||||
const saveButton = screen.getByRole("button", { name: /save/i });
|
||||
const modelInput = screen.getByRole("combobox", { name: "model" });
|
||||
const providerInput = screen.getByRole("combobox", { name: "Provider" });
|
||||
const modelInput = screen.getByRole("combobox", { name: "Model" });
|
||||
|
||||
await user.click(providerInput);
|
||||
const openai = screen.getByText("OpenAI");
|
||||
await user.click(openai);
|
||||
|
||||
await user.click(modelInput);
|
||||
const model3 = screen.getByText("model3");
|
||||
|
||||
const model3 = screen.getByText("gpt-3.5-turbo");
|
||||
await user.click(model3);
|
||||
|
||||
await user.click(saveButton);
|
||||
|
||||
expect(startNewSessionSpy).toHaveBeenCalled();
|
||||
@@ -167,12 +190,17 @@ describe("SettingsModal", () => {
|
||||
);
|
||||
|
||||
const saveButton = screen.getByRole("button", { name: /save/i });
|
||||
const modelInput = screen.getByRole("combobox", { name: "model" });
|
||||
const providerInput = screen.getByRole("combobox", { name: "Provider" });
|
||||
const modelInput = screen.getByRole("combobox", { name: "Model" });
|
||||
|
||||
await user.click(providerInput);
|
||||
const cohere = screen.getByText("cohere");
|
||||
await user.click(cohere);
|
||||
|
||||
await user.click(modelInput);
|
||||
const model3 = screen.getByText("model3");
|
||||
|
||||
const model3 = screen.getByText("command-r-v1:0");
|
||||
await user.click(model3);
|
||||
|
||||
await user.click(saveButton);
|
||||
|
||||
expect(toastSpy).toHaveBeenCalledTimes(4);
|
||||
@@ -213,12 +241,17 @@ describe("SettingsModal", () => {
|
||||
});
|
||||
|
||||
const saveButton = screen.getByRole("button", { name: /save/i });
|
||||
const modelInput = screen.getByRole("combobox", { name: "model" });
|
||||
const providerInput = screen.getByRole("combobox", { name: "Provider" });
|
||||
const modelInput = screen.getByRole("combobox", { name: "Model" });
|
||||
|
||||
await user.click(providerInput);
|
||||
const cohere = screen.getByText("cohere");
|
||||
await user.click(cohere);
|
||||
|
||||
await user.click(modelInput);
|
||||
const model3 = screen.getByText("model3");
|
||||
|
||||
const model3 = screen.getByText("command-r-v1:0");
|
||||
await user.click(model3);
|
||||
|
||||
await user.click(saveButton);
|
||||
|
||||
expect(onOpenChangeMock).toHaveBeenCalledWith(false);
|
||||
|
||||
@@ -63,8 +63,10 @@ function SettingsModal({ isOpen, onOpenChange }: SettingsProps) {
|
||||
React.useEffect(() => {
|
||||
(async () => {
|
||||
try {
|
||||
setModels(await fetchModels());
|
||||
setAgents(await fetchAgents());
|
||||
const fetchedModels = await fetchModels();
|
||||
const fetchedAgents = await fetchAgents();
|
||||
setModels(fetchedModels);
|
||||
setAgents(fetchedAgents);
|
||||
setSecurityAnalyzers(await fetchSecurityAnalyzers());
|
||||
} catch (error) {
|
||||
toast.error("settings", t(I18nKey.CONFIGURATION$ERROR_FETCH_MODELS));
|
||||
@@ -81,6 +83,20 @@ function SettingsModal({ isOpen, onOpenChange }: SettingsProps) {
|
||||
}));
|
||||
};
|
||||
|
||||
const handleCustomModelChange = (model: string) => {
|
||||
setSettings((prev) => ({
|
||||
...prev,
|
||||
CUSTOM_LLM_MODEL: model,
|
||||
}));
|
||||
};
|
||||
|
||||
const handleModelTypeChange = (type: "custom" | "default") => {
|
||||
setSettings((prev) => ({
|
||||
...prev,
|
||||
USING_CUSTOM_MODEL: type === "custom",
|
||||
}));
|
||||
};
|
||||
|
||||
const handleAgentChange = (agent: string) => {
|
||||
setSettings((prev) => ({ ...prev, AGENT: agent }));
|
||||
};
|
||||
@@ -189,6 +205,8 @@ function SettingsModal({ isOpen, onOpenChange }: SettingsProps) {
|
||||
agents={agents}
|
||||
securityAnalyzers={securityAnalyzers}
|
||||
onModelChange={handleModelChange}
|
||||
onCustomModelChange={handleCustomModelChange}
|
||||
onModelTypeChange={handleModelTypeChange}
|
||||
onAgentChange={handleAgentChange}
|
||||
onLanguageChange={handleLanguageChange}
|
||||
onAPIKeyChange={handleAPIKeyChange}
|
||||
|
||||
@@ -11,9 +11,16 @@ const setupSpy = vi.spyOn(Session, "_setupSocket").mockImplementation(() => {
|
||||
});
|
||||
|
||||
describe("startNewSession", () => {
|
||||
afterEach(() => {
|
||||
sendSpy.mockClear();
|
||||
setupSpy.mockClear();
|
||||
});
|
||||
|
||||
it("Should start a new session with the current settings", () => {
|
||||
const settings: Settings = {
|
||||
LLM_MODEL: "llm_value",
|
||||
CUSTOM_LLM_MODEL: "",
|
||||
USING_CUSTOM_MODEL: false,
|
||||
AGENT: "agent_value",
|
||||
LANGUAGE: "language_value",
|
||||
LLM_API_KEY: "sk-...",
|
||||
@@ -32,4 +39,33 @@ describe("startNewSession", () => {
|
||||
expect(setupSpy).toHaveBeenCalledTimes(1);
|
||||
expect(sendSpy).toHaveBeenCalledWith(JSON.stringify(event));
|
||||
});
|
||||
|
||||
it("should start with the custom llm if set", () => {
|
||||
const settings: Settings = {
|
||||
LLM_MODEL: "llm_value",
|
||||
CUSTOM_LLM_MODEL: "custom_llm_value",
|
||||
USING_CUSTOM_MODEL: true,
|
||||
AGENT: "agent_value",
|
||||
LANGUAGE: "language_value",
|
||||
LLM_API_KEY: "sk-...",
|
||||
CONFIRMATION_MODE: true,
|
||||
SECURITY_ANALYZER: "analyzer",
|
||||
};
|
||||
|
||||
const event = {
|
||||
action: ActionType.INIT,
|
||||
args: settings,
|
||||
};
|
||||
|
||||
saveSettings(settings);
|
||||
Session.startNewSession();
|
||||
|
||||
expect(setupSpy).toHaveBeenCalledTimes(1);
|
||||
expect(sendSpy).toHaveBeenCalledWith(
|
||||
JSON.stringify({
|
||||
...event,
|
||||
args: { ...settings, LLM_MODEL: "custom_llm_value" },
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -46,7 +46,15 @@ class Session {
|
||||
|
||||
private static _initializeAgent = () => {
|
||||
const settings = getSettings();
|
||||
const event = { action: ActionType.INIT, args: settings };
|
||||
const event = {
|
||||
action: ActionType.INIT,
|
||||
args: {
|
||||
...settings,
|
||||
LLM_MODEL: settings.USING_CUSTOM_MODEL
|
||||
? settings.CUSTOM_LLM_MODEL
|
||||
: settings.LLM_MODEL,
|
||||
},
|
||||
};
|
||||
const eventString = JSON.stringify(event);
|
||||
Session.send(eventString);
|
||||
};
|
||||
|
||||
@@ -18,6 +18,8 @@ describe("getSettings", () => {
|
||||
it("should get the stored settings", () => {
|
||||
(localStorage.getItem as Mock)
|
||||
.mockReturnValueOnce("llm_value")
|
||||
.mockReturnValueOnce("custom_llm_value")
|
||||
.mockReturnValueOnce("true")
|
||||
.mockReturnValueOnce("agent_value")
|
||||
.mockReturnValueOnce("language_value")
|
||||
.mockReturnValueOnce("api_key")
|
||||
@@ -28,6 +30,8 @@ describe("getSettings", () => {
|
||||
|
||||
expect(settings).toEqual({
|
||||
LLM_MODEL: "llm_value",
|
||||
CUSTOM_LLM_MODEL: "custom_llm_value",
|
||||
USING_CUSTOM_MODEL: true,
|
||||
AGENT: "agent_value",
|
||||
LANGUAGE: "language_value",
|
||||
LLM_API_KEY: "api_key",
|
||||
@@ -43,12 +47,16 @@ describe("getSettings", () => {
|
||||
.mockReturnValueOnce(null)
|
||||
.mockReturnValueOnce(null)
|
||||
.mockReturnValueOnce(null)
|
||||
.mockReturnValueOnce(null)
|
||||
.mockReturnValueOnce(null)
|
||||
.mockReturnValueOnce(null);
|
||||
|
||||
const settings = getSettings();
|
||||
|
||||
expect(settings).toEqual({
|
||||
LLM_MODEL: DEFAULT_SETTINGS.LLM_MODEL,
|
||||
CUSTOM_LLM_MODEL: "",
|
||||
USING_CUSTOM_MODEL: DEFAULT_SETTINGS.USING_CUSTOM_MODEL,
|
||||
AGENT: DEFAULT_SETTINGS.AGENT,
|
||||
LANGUAGE: DEFAULT_SETTINGS.LANGUAGE,
|
||||
LLM_API_KEY: "",
|
||||
@@ -62,6 +70,8 @@ describe("saveSettings", () => {
|
||||
it("should save the settings", () => {
|
||||
const settings: Settings = {
|
||||
LLM_MODEL: "llm_value",
|
||||
CUSTOM_LLM_MODEL: "custom_llm_value",
|
||||
USING_CUSTOM_MODEL: true,
|
||||
AGENT: "agent_value",
|
||||
LANGUAGE: "language_value",
|
||||
LLM_API_KEY: "some_key",
|
||||
@@ -72,6 +82,14 @@ describe("saveSettings", () => {
|
||||
saveSettings(settings);
|
||||
|
||||
expect(localStorage.setItem).toHaveBeenCalledWith("LLM_MODEL", "llm_value");
|
||||
expect(localStorage.setItem).toHaveBeenCalledWith(
|
||||
"CUSTOM_LLM_MODEL",
|
||||
"custom_llm_value",
|
||||
);
|
||||
expect(localStorage.setItem).toHaveBeenCalledWith(
|
||||
"USING_CUSTOM_MODEL",
|
||||
"true",
|
||||
);
|
||||
expect(localStorage.setItem).toHaveBeenCalledWith("AGENT", "agent_value");
|
||||
expect(localStorage.setItem).toHaveBeenCalledWith(
|
||||
"LANGUAGE",
|
||||
@@ -122,6 +140,8 @@ describe("getSettingsDifference", () => {
|
||||
beforeEach(() => {
|
||||
(localStorage.getItem as Mock)
|
||||
.mockReturnValueOnce("llm_value")
|
||||
.mockReturnValueOnce("custom_llm_value")
|
||||
.mockReturnValueOnce("false")
|
||||
.mockReturnValueOnce("agent_value")
|
||||
.mockReturnValueOnce("language_value");
|
||||
});
|
||||
@@ -129,6 +149,8 @@ describe("getSettingsDifference", () => {
|
||||
it("should return updated settings", () => {
|
||||
const settings = {
|
||||
LLM_MODEL: "new_llm_value",
|
||||
CUSTOM_LLM_MODEL: "custom_llm_value",
|
||||
USING_CUSTOM_MODEL: true,
|
||||
AGENT: "new_agent_value",
|
||||
LANGUAGE: "language_value",
|
||||
};
|
||||
@@ -136,6 +158,7 @@ describe("getSettingsDifference", () => {
|
||||
const updatedSettings = getSettingsDifference(settings);
|
||||
|
||||
expect(updatedSettings).toEqual({
|
||||
USING_CUSTOM_MODEL: true,
|
||||
LLM_MODEL: "new_llm_value",
|
||||
AGENT: "new_agent_value",
|
||||
});
|
||||
|
||||
@@ -2,6 +2,8 @@ const LATEST_SETTINGS_VERSION = 1;
|
||||
|
||||
export type Settings = {
|
||||
LLM_MODEL: string;
|
||||
CUSTOM_LLM_MODEL: string;
|
||||
USING_CUSTOM_MODEL: boolean;
|
||||
AGENT: string;
|
||||
LANGUAGE: string;
|
||||
LLM_API_KEY: string;
|
||||
@@ -12,7 +14,9 @@ export type Settings = {
|
||||
type SettingsInput = Settings[keyof Settings];
|
||||
|
||||
export const DEFAULT_SETTINGS: Settings = {
|
||||
LLM_MODEL: "gpt-4o",
|
||||
LLM_MODEL: "openai/gpt-4o",
|
||||
CUSTOM_LLM_MODEL: "",
|
||||
USING_CUSTOM_MODEL: false,
|
||||
AGENT: "CodeActAgent",
|
||||
LANGUAGE: "en",
|
||||
LLM_API_KEY: "",
|
||||
@@ -54,6 +58,9 @@ export const getDefaultSettings = (): Settings => DEFAULT_SETTINGS;
|
||||
*/
|
||||
export const getSettings = (): Settings => {
|
||||
const model = localStorage.getItem("LLM_MODEL");
|
||||
const customModel = localStorage.getItem("CUSTOM_LLM_MODEL");
|
||||
const usingCustomModel =
|
||||
localStorage.getItem("USING_CUSTOM_MODEL") === "true";
|
||||
const agent = localStorage.getItem("AGENT");
|
||||
const language = localStorage.getItem("LANGUAGE");
|
||||
const apiKey = localStorage.getItem("LLM_API_KEY");
|
||||
@@ -62,6 +69,8 @@ export const getSettings = (): Settings => {
|
||||
|
||||
return {
|
||||
LLM_MODEL: model || DEFAULT_SETTINGS.LLM_MODEL,
|
||||
CUSTOM_LLM_MODEL: customModel || DEFAULT_SETTINGS.CUSTOM_LLM_MODEL,
|
||||
USING_CUSTOM_MODEL: usingCustomModel || DEFAULT_SETTINGS.USING_CUSTOM_MODEL,
|
||||
AGENT: agent || DEFAULT_SETTINGS.AGENT,
|
||||
LANGUAGE: language || DEFAULT_SETTINGS.LANGUAGE,
|
||||
LLM_API_KEY: apiKey || DEFAULT_SETTINGS.LLM_API_KEY,
|
||||
|
||||
62
frontend/src/utils/extractModelAndProvider.test.ts
Normal file
62
frontend/src/utils/extractModelAndProvider.test.ts
Normal file
@@ -0,0 +1,62 @@
|
||||
import { describe, it, expect } from "vitest";
|
||||
import { extractModelAndProvider } from "./extractModelAndProvider";
|
||||
|
||||
describe("extractModelAndProvider", () => {
|
||||
it("should work", () => {
|
||||
expect(extractModelAndProvider("azure/ada")).toEqual({
|
||||
provider: "azure",
|
||||
model: "ada",
|
||||
separator: "/",
|
||||
});
|
||||
|
||||
expect(
|
||||
extractModelAndProvider("azure/standard/1024-x-1024/dall-e-2"),
|
||||
).toEqual({
|
||||
provider: "azure",
|
||||
model: "standard/1024-x-1024/dall-e-2",
|
||||
separator: "/",
|
||||
});
|
||||
|
||||
expect(extractModelAndProvider("vertex_ai_beta/chat-bison")).toEqual({
|
||||
provider: "vertex_ai_beta",
|
||||
model: "chat-bison",
|
||||
separator: "/",
|
||||
});
|
||||
|
||||
expect(extractModelAndProvider("cohere.command-r-v1:0")).toEqual({
|
||||
provider: "cohere",
|
||||
model: "command-r-v1:0",
|
||||
separator: ".",
|
||||
});
|
||||
|
||||
expect(
|
||||
extractModelAndProvider(
|
||||
"cloudflare/@cf/mistral/mistral-7b-instruct-v0.1",
|
||||
),
|
||||
).toEqual({
|
||||
provider: "cloudflare",
|
||||
model: "@cf/mistral/mistral-7b-instruct-v0.1",
|
||||
separator: "/",
|
||||
});
|
||||
|
||||
expect(extractModelAndProvider("together-ai-21.1b-41b")).toEqual({
|
||||
provider: "",
|
||||
model: "together-ai-21.1b-41b",
|
||||
separator: "",
|
||||
});
|
||||
});
|
||||
|
||||
it("should add provider for popular models", () => {
|
||||
expect(extractModelAndProvider("gpt-3.5-turbo")).toEqual({
|
||||
provider: "openai",
|
||||
model: "gpt-3.5-turbo",
|
||||
separator: "/",
|
||||
});
|
||||
|
||||
expect(extractModelAndProvider("gpt-4o")).toEqual({
|
||||
provider: "openai",
|
||||
model: "gpt-4o",
|
||||
separator: "/",
|
||||
});
|
||||
});
|
||||
});
|
||||
49
frontend/src/utils/extractModelAndProvider.ts
Normal file
49
frontend/src/utils/extractModelAndProvider.ts
Normal file
@@ -0,0 +1,49 @@
|
||||
import { isNumber } from "./isNumber";
|
||||
import { VERIFIED_OPENAI_MODELS } from "./verified-models";
|
||||
|
||||
/**
|
||||
* Checks if the split array is actually a version number.
|
||||
* @param split The split array of the model string
|
||||
* @returns Boolean indicating if the split is actually a version number
|
||||
*
|
||||
* @example
|
||||
* const split = ["gpt-3", "5-turbo"] // incorrectly split from "gpt-3.5-turbo"
|
||||
* splitIsActuallyVersion(split) // returns true
|
||||
*/
|
||||
const splitIsActuallyVersion = (split: string[]) =>
|
||||
split[1] && split[1][0] && isNumber(split[1][0]);
|
||||
|
||||
/**
|
||||
* Given a model string, extract the provider and model name. Currently the supported separators are "/" and "."
|
||||
* @param model The model string
|
||||
* @returns An object containing the provider, model name, and separator
|
||||
*
|
||||
* @example
|
||||
* extractModelAndProvider("azure/ada")
|
||||
* // returns { provider: "azure", model: "ada", separator: "/" }
|
||||
*
|
||||
* extractModelAndProvider("cohere.command-r-v1:0")
|
||||
* // returns { provider: "cohere", model: "command-r-v1:0", separator: "." }
|
||||
*/
|
||||
export const extractModelAndProvider = (model: string) => {
|
||||
let separator = "/";
|
||||
let split = model.split(separator);
|
||||
if (split.length === 1) {
|
||||
// no "/" separator found, try with "."
|
||||
separator = ".";
|
||||
split = model.split(separator);
|
||||
if (splitIsActuallyVersion(split)) {
|
||||
split = [split.join(separator)]; // undo the split
|
||||
}
|
||||
}
|
||||
if (split.length === 1) {
|
||||
// no "/" or "." separator found
|
||||
if (VERIFIED_OPENAI_MODELS.includes(split[0])) {
|
||||
return { provider: "openai", model: split[0], separator: "/" };
|
||||
}
|
||||
// return as model only
|
||||
return { provider: "", model, separator: "" };
|
||||
}
|
||||
const [provider, ...modelId] = split;
|
||||
return { provider, model: modelId.join(separator), separator };
|
||||
};
|
||||
9
frontend/src/utils/isNumber.test.ts
Normal file
9
frontend/src/utils/isNumber.test.ts
Normal file
@@ -0,0 +1,9 @@
|
||||
import { test, expect } from "vitest";
|
||||
import { isNumber } from "./isNumber";
|
||||
|
||||
test("isNumber", () => {
|
||||
expect(isNumber(1)).toBe(true);
|
||||
expect(isNumber(0)).toBe(true);
|
||||
expect(isNumber("3")).toBe(true);
|
||||
expect(isNumber("0")).toBe(true);
|
||||
});
|
||||
2
frontend/src/utils/isNumber.ts
Normal file
2
frontend/src/utils/isNumber.ts
Normal file
@@ -0,0 +1,2 @@
|
||||
export const isNumber = (value: string | number): boolean =>
|
||||
!Number.isNaN(Number(value));
|
||||
27
frontend/src/utils/mapProvider.test.ts
Normal file
27
frontend/src/utils/mapProvider.test.ts
Normal file
@@ -0,0 +1,27 @@
|
||||
import { test, expect } from "vitest";
|
||||
import { mapProvider } from "./mapProvider";
|
||||
|
||||
test("mapProvider", () => {
|
||||
expect(mapProvider("azure")).toBe("Azure");
|
||||
expect(mapProvider("azure_ai")).toBe("Azure AI Studio");
|
||||
expect(mapProvider("vertex_ai")).toBe("VertexAI");
|
||||
expect(mapProvider("palm")).toBe("PaLM");
|
||||
expect(mapProvider("gemini")).toBe("Gemini");
|
||||
expect(mapProvider("anthropic")).toBe("Anthropic");
|
||||
expect(mapProvider("sagemaker")).toBe("AWS SageMaker");
|
||||
expect(mapProvider("bedrock")).toBe("AWS Bedrock");
|
||||
expect(mapProvider("mistral")).toBe("Mistral AI");
|
||||
expect(mapProvider("anyscale")).toBe("Anyscale");
|
||||
expect(mapProvider("databricks")).toBe("Databricks");
|
||||
expect(mapProvider("ollama")).toBe("Ollama");
|
||||
expect(mapProvider("perlexity")).toBe("Perplexity AI");
|
||||
expect(mapProvider("friendliai")).toBe("FriendliAI");
|
||||
expect(mapProvider("groq")).toBe("Groq");
|
||||
expect(mapProvider("fireworks_ai")).toBe("Fireworks AI");
|
||||
expect(mapProvider("cloudflare")).toBe("Cloudflare Workers AI");
|
||||
expect(mapProvider("deepinfra")).toBe("DeepInfra");
|
||||
expect(mapProvider("ai21")).toBe("AI21");
|
||||
expect(mapProvider("replicate")).toBe("Replicate");
|
||||
expect(mapProvider("voyage")).toBe("Voyage AI");
|
||||
expect(mapProvider("openrouter")).toBe("OpenRouter");
|
||||
});
|
||||
30
frontend/src/utils/mapProvider.ts
Normal file
30
frontend/src/utils/mapProvider.ts
Normal file
@@ -0,0 +1,30 @@
|
||||
export const MAP_PROVIDER = {
|
||||
openai: "OpenAI",
|
||||
azure: "Azure",
|
||||
azure_ai: "Azure AI Studio",
|
||||
vertex_ai: "VertexAI",
|
||||
palm: "PaLM",
|
||||
gemini: "Gemini",
|
||||
anthropic: "Anthropic",
|
||||
sagemaker: "AWS SageMaker",
|
||||
bedrock: "AWS Bedrock",
|
||||
mistral: "Mistral AI",
|
||||
anyscale: "Anyscale",
|
||||
databricks: "Databricks",
|
||||
ollama: "Ollama",
|
||||
perlexity: "Perplexity AI",
|
||||
friendliai: "FriendliAI",
|
||||
groq: "Groq",
|
||||
fireworks_ai: "Fireworks AI",
|
||||
cloudflare: "Cloudflare Workers AI",
|
||||
deepinfra: "DeepInfra",
|
||||
ai21: "AI21",
|
||||
replicate: "Replicate",
|
||||
voyage: "Voyage AI",
|
||||
openrouter: "OpenRouter",
|
||||
};
|
||||
|
||||
export const mapProvider = (provider: string) =>
|
||||
Object.keys(MAP_PROVIDER).includes(provider)
|
||||
? MAP_PROVIDER[provider as keyof typeof MAP_PROVIDER]
|
||||
: provider;
|
||||
51
frontend/src/utils/organizeModelsAndProviders.test.ts
Normal file
51
frontend/src/utils/organizeModelsAndProviders.test.ts
Normal file
@@ -0,0 +1,51 @@
|
||||
import { test } from "vitest";
|
||||
import { organizeModelsAndProviders } from "./organizeModelsAndProviders";
|
||||
|
||||
test("organizeModelsAndProviders", () => {
|
||||
const models = [
|
||||
"azure/ada",
|
||||
"azure/gpt-35-turbo",
|
||||
"azure/gpt-3-turbo",
|
||||
"azure/standard/1024-x-1024/dall-e-2",
|
||||
"vertex_ai_beta/chat-bison",
|
||||
"vertex_ai_beta/chat-bison-32k",
|
||||
"sagemaker/meta-textgeneration-llama-2-13b",
|
||||
"cohere.command-r-v1:0",
|
||||
"cloudflare/@cf/mistral/mistral-7b-instruct-v0.1",
|
||||
"gpt-4o",
|
||||
"together-ai-21.1b-41b",
|
||||
"gpt-3.5-turbo",
|
||||
];
|
||||
|
||||
const object = organizeModelsAndProviders(models);
|
||||
|
||||
expect(object).toEqual({
|
||||
azure: {
|
||||
separator: "/",
|
||||
models: [
|
||||
"ada",
|
||||
"gpt-35-turbo",
|
||||
"gpt-3-turbo",
|
||||
"standard/1024-x-1024/dall-e-2",
|
||||
],
|
||||
},
|
||||
vertex_ai_beta: {
|
||||
separator: "/",
|
||||
models: ["chat-bison", "chat-bison-32k"],
|
||||
},
|
||||
sagemaker: { separator: "/", models: ["meta-textgeneration-llama-2-13b"] },
|
||||
cohere: { separator: ".", models: ["command-r-v1:0"] },
|
||||
cloudflare: {
|
||||
separator: "/",
|
||||
models: ["@cf/mistral/mistral-7b-instruct-v0.1"],
|
||||
},
|
||||
openai: {
|
||||
separator: "/",
|
||||
models: ["gpt-4o", "gpt-3.5-turbo"],
|
||||
},
|
||||
other: {
|
||||
separator: "",
|
||||
models: ["together-ai-21.1b-41b"],
|
||||
},
|
||||
});
|
||||
});
|
||||
42
frontend/src/utils/organizeModelsAndProviders.ts
Normal file
42
frontend/src/utils/organizeModelsAndProviders.ts
Normal file
@@ -0,0 +1,42 @@
|
||||
import { extractModelAndProvider } from "./extractModelAndProvider";
|
||||
|
||||
/**
|
||||
* Given a list of models, organize them by provider
|
||||
* @param models The list of models
|
||||
* @returns An object containing the provider and models
|
||||
*
|
||||
* @example
|
||||
* const models = [
|
||||
* "azure/ada",
|
||||
* "azure/gpt-35-turbo",
|
||||
* "cohere.command-r-v1:0",
|
||||
* ];
|
||||
*
|
||||
* organizeModelsAndProviders(models);
|
||||
* // returns {
|
||||
* // azure: {
|
||||
* // separator: "/",
|
||||
* // models: ["ada", "gpt-35-turbo"],
|
||||
* // },
|
||||
* // cohere: {
|
||||
* // separator: ".",
|
||||
* // models: ["command-r-v1:0"],
|
||||
* // },
|
||||
* // }
|
||||
*/
|
||||
export const organizeModelsAndProviders = (models: string[]) => {
|
||||
const object: Record<string, { separator: string; models: string[] }> = {};
|
||||
models.forEach((model) => {
|
||||
const {
|
||||
separator,
|
||||
provider,
|
||||
model: modelId,
|
||||
} = extractModelAndProvider(model);
|
||||
const key = provider || "other";
|
||||
if (!object[key]) {
|
||||
object[key] = { separator, models: [] };
|
||||
}
|
||||
object[key].models.push(modelId);
|
||||
});
|
||||
return object;
|
||||
};
|
||||
14
frontend/src/utils/verified-models.ts
Normal file
14
frontend/src/utils/verified-models.ts
Normal file
@@ -0,0 +1,14 @@
|
||||
// Here are the list of verified models and providers that we know work well with OpenHands.
|
||||
export const VERIFIED_PROVIDERS = ["openai", "azure", "anthropic"];
|
||||
export const VERIFIED_MODELS = ["gpt-4o", "claude-3-5-sonnet-20240620-v1:0"];
|
||||
|
||||
// LiteLLM does not return OpenAI models with the provider, so we list them here to set them ourselves for consistency
|
||||
// (e.g., they return `gpt-4o` instead of `openai/gpt-4o`)
|
||||
export const VERIFIED_OPENAI_MODELS = [
|
||||
"gpt-4o",
|
||||
"gpt-4o-mini",
|
||||
"gpt-4-turbo",
|
||||
"gpt-4",
|
||||
"gpt-4-32k",
|
||||
"gpt-3.5-turbo",
|
||||
];
|
||||
Reference in New Issue
Block a user