From f24668c35c4c9d4ad55ee8fce933da47ab6ab1eb Mon Sep 17 00:00:00 2001 From: hlo-world Date: Sat, 15 Feb 2025 21:16:15 -0500 Subject: [PATCH] feat: add num_ctx slider when provider is ollama and add predefined model names for ollama --- src/utils/default_config_settings.py | 31 +++++++++++++----------- src/utils/utils.py | 2 +- webui.py | 36 ++++++++++++++++++++++++---- 3 files changed, 49 insertions(+), 20 deletions(-) diff --git a/src/utils/default_config_settings.py b/src/utils/default_config_settings.py index 92515e5..e6fa88f 100644 --- a/src/utils/default_config_settings.py +++ b/src/utils/default_config_settings.py @@ -14,6 +14,7 @@ def default_config(): "tool_calling_method": "auto", "llm_provider": "openai", "llm_model_name": "gpt-4o", + "llm_num_ctx": 32000, "llm_temperature": 1.0, "llm_base_url": "", "llm_api_key": "", @@ -59,20 +60,21 @@ def save_current_config(*args): "tool_calling_method": args[4], "llm_provider": args[5], "llm_model_name": args[6], - "llm_temperature": args[7], - "llm_base_url": args[8], - "llm_api_key": args[9], - "use_own_browser": args[10], - "keep_browser_open": args[11], - "headless": args[12], - "disable_security": args[13], - "enable_recording": args[14], - "window_w": args[15], - "window_h": args[16], - "save_recording_path": args[17], - "save_trace_path": args[18], - "save_agent_history_path": args[19], - "task": args[20], + "llm_num_ctx": args[7], + "llm_temperature": args[8], + "llm_base_url": args[9], + "llm_api_key": args[10], + "use_own_browser": args[11], + "keep_browser_open": args[12], + "headless": args[13], + "disable_security": args[14], + "enable_recording": args[15], + "window_w": args[16], + "window_h": args[17], + "save_recording_path": args[18], + "save_trace_path": args[19], + "save_agent_history_path": args[20], + "task": args[21], } return save_config_to_file(current_config) @@ -89,6 +91,7 @@ def update_ui_from_config(config_file): gr.update(value=loaded_config.get("tool_calling_method", True)), gr.update(value=loaded_config.get("llm_provider", "openai")), gr.update(value=loaded_config.get("llm_model_name", "gpt-4o")), + gr.update(value=loaded_config.get("llm_num_ctx", 32000)), gr.update(value=loaded_config.get("llm_temperature", 1.0)), gr.update(value=loaded_config.get("llm_base_url", "")), gr.update(value=loaded_config.get("llm_api_key", "")), diff --git a/src/utils/utils.py b/src/utils/utils.py index 223d028..223dba5 100644 --- a/src/utils/utils.py +++ b/src/utils/utils.py @@ -167,7 +167,7 @@ model_names = { "openai": ["gpt-4o", "gpt-4", "gpt-3.5-turbo", "o3-mini"], "deepseek": ["deepseek-chat", "deepseek-reasoner"], "google": ["gemini-2.0-flash", "gemini-2.0-flash-thinking-exp", "gemini-1.5-flash-latest", "gemini-1.5-flash-8b-latest", "gemini-2.0-flash-thinking-exp-01-21", "gemini-2.0-pro-exp-02-05"], - "ollama": ["qwen2.5:7b", "llama2:7b", "deepseek-r1:14b", "deepseek-r1:32b"], + "ollama": ["qwen2.5:7b", "qwen2.5:14b", "qwen2.5:32b", "qwen2.5-coder:14b", "qwen2.5-coder:32b", "llama2:7b", "deepseek-r1:14b", "deepseek-r1:32b"], "azure_openai": ["gpt-4o", "gpt-4", "gpt-3.5-turbo"], "mistral": ["pixtral-large-latest", "mistral-large-latest", "mistral-small-latest", "ministral-8b-latest"], "alibaba": ["qwen-plus", "qwen-max", "qwen-turbo", "qwen-long"] diff --git a/webui.py b/webui.py index 8e9d6b2..cde23d1 100644 --- a/webui.py +++ b/webui.py @@ -99,6 +99,7 @@ async def run_browser_agent( agent_type, llm_provider, llm_model_name, + llm_num_ctx, llm_temperature, llm_base_url, llm_api_key, @@ -143,6 +144,7 @@ async def run_browser_agent( llm = utils.get_llm_model( provider=llm_provider, model_name=llm_model_name, + num_ctx=llm_num_ctx, temperature=llm_temperature, base_url=llm_base_url, api_key=llm_api_key, @@ -431,6 +433,7 @@ async def run_with_stream( agent_type, llm_provider, llm_model_name, + llm_num_ctx, llm_temperature, llm_base_url, llm_api_key, @@ -459,6 +462,7 @@ async def run_with_stream( agent_type=agent_type, llm_provider=llm_provider, llm_model_name=llm_model_name, + llm_num_ctx=llm_num_ctx, llm_temperature=llm_temperature, llm_base_url=llm_base_url, llm_api_key=llm_api_key, @@ -491,6 +495,7 @@ async def run_with_stream( agent_type=agent_type, llm_provider=llm_provider, llm_model_name=llm_model_name, + llm_num_ctx=llm_num_ctx, llm_temperature=llm_temperature, llm_base_url=llm_base_url, llm_api_key=llm_api_key, @@ -623,7 +628,7 @@ async def close_global_browser(): await _global_browser.close() _global_browser = None -async def run_deep_search(research_task, max_search_iteration_input, max_query_per_iter_input, llm_provider, llm_model_name, llm_temperature, llm_base_url, llm_api_key, use_vision, use_own_browser, headless): +async def run_deep_search(research_task, max_search_iteration_input, max_query_per_iter_input, llm_provider, llm_model_name, llm_num_ctx, llm_temperature, llm_base_url, llm_api_key, use_vision, use_own_browser, headless): from src.utils.deep_research import deep_research global _global_agent_state @@ -633,6 +638,7 @@ async def run_deep_search(research_task, max_search_iteration_input, max_query_p llm = utils.get_llm_model( provider=llm_provider, model_name=llm_model_name, + num_ctx=llm_num_ctx, temperature=llm_temperature, base_url=llm_base_url, api_key=llm_api_key, @@ -736,6 +742,15 @@ def create_ui(config, theme_name="Ocean"): allow_custom_value=True, # Allow users to input custom model names info="Select a model from the dropdown or type a custom model name" ) + llm_num_ctx = gr.Slider( + minimum=2**8, + maximum=2**16, + value=config['llm_num_ctx'], + step=1, + label="Max Context Length", + info="Controls max context length model needs to handle (less = faster)", + visible=config['llm_provider'] == "ollama" + ) llm_temperature = gr.Slider( minimum=0.0, maximum=2.0, @@ -757,6 +772,17 @@ def create_ui(config, theme_name="Ocean"): info="Your API key (leave blank to use .env)" ) + # Change event to update context length slider + def update_llm_num_ctx_visibility(llm_provider): + return gr.update(visible=llm_provider == "ollama") + + # Bind the change event of llm_provider to update the visibility of context length slider + llm_provider.change( + fn=update_llm_num_ctx_visibility, + inputs=llm_provider, + outputs=llm_num_ctx + ) + with gr.TabItem("🌐 Browser Settings", id=3): with gr.Group(): with gr.Row(): @@ -899,7 +925,7 @@ def create_ui(config, theme_name="Ocean"): run_button.click( fn=run_with_stream, inputs=[ - agent_type, llm_provider, llm_model_name, llm_temperature, llm_base_url, llm_api_key, + agent_type, llm_provider, llm_model_name, llm_num_ctx, llm_temperature, llm_base_url, llm_api_key, use_own_browser, keep_browser_open, headless, disable_security, window_w, window_h, save_recording_path, save_agent_history_path, save_trace_path, # Include the new path enable_recording, task, add_infos, max_steps, use_vision, max_actions_per_step, tool_calling_method @@ -921,7 +947,7 @@ def create_ui(config, theme_name="Ocean"): # Run Deep Research research_button.click( fn=run_deep_search, - inputs=[research_task_input, max_search_iteration_input, max_query_per_iter_input, llm_provider, llm_model_name, llm_temperature, llm_base_url, llm_api_key, use_vision, use_own_browser, headless], + inputs=[research_task_input, max_search_iteration_input, max_query_per_iter_input, llm_provider, llm_model_name, llm_num_ctx, llm_temperature, llm_base_url, llm_api_key, use_vision, use_own_browser, headless], outputs=[markdown_output_display, markdown_download, stop_research_button, research_button] ) # Bind the stop button click event after errors_output is defined @@ -987,7 +1013,7 @@ def create_ui(config, theme_name="Ocean"): inputs=[config_file_input], outputs=[ agent_type, max_steps, max_actions_per_step, use_vision, tool_calling_method, - llm_provider, llm_model_name, llm_temperature, llm_base_url, llm_api_key, + llm_provider, llm_model_name, llm_num_ctx, llm_temperature, llm_base_url, llm_api_key, use_own_browser, keep_browser_open, headless, disable_security, enable_recording, window_w, window_h, save_recording_path, save_trace_path, save_agent_history_path, task, config_status @@ -998,7 +1024,7 @@ def create_ui(config, theme_name="Ocean"): fn=save_current_config, inputs=[ agent_type, max_steps, max_actions_per_step, use_vision, tool_calling_method, - llm_provider, llm_model_name, llm_temperature, llm_base_url, llm_api_key, + llm_provider, llm_model_name, llm_num_ctx, llm_temperature, llm_base_url, llm_api_key, use_own_browser, keep_browser_open, headless, disable_security, enable_recording, window_w, window_h, save_recording_path, save_trace_path, save_agent_history_path, task,