Merge pull request #302 from hlo-world/num_ctx-for-ollama
feat: add num_ctx slider when provider is ollama and add predefined model names for ollama
@@ -14,6 +14,7 @@ def default_config():
         "tool_calling_method": "auto",
         "llm_provider": "openai",
         "llm_model_name": "gpt-4o",
+        "llm_num_ctx": 32000,
         "llm_temperature": 1.0,
         "llm_base_url": "",
         "llm_api_key": "",
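Reconstructed from the context lines above, the LLM-related slice of the default config now carries the new key (a readable restatement of the hunk, not extra code from the PR; the dict name is illustrative):

```python
# Slice of default_config() after this change; keys outside the hunk omitted.
DEFAULT_LLM_SETTINGS = {
    "tool_calling_method": "auto",
    "llm_provider": "openai",
    "llm_model_name": "gpt-4o",
    "llm_num_ctx": 32000,   # new: context window size, consumed by Ollama only
    "llm_temperature": 1.0,
    "llm_base_url": "",
    "llm_api_key": "",
}
```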
@@ -59,20 +60,21 @@ def save_current_config(*args):
         "tool_calling_method": args[4],
         "llm_provider": args[5],
         "llm_model_name": args[6],
-        "llm_temperature": args[7],
-        "llm_base_url": args[8],
-        "llm_api_key": args[9],
-        "use_own_browser": args[10],
-        "keep_browser_open": args[11],
-        "headless": args[12],
-        "disable_security": args[13],
-        "enable_recording": args[14],
-        "window_w": args[15],
-        "window_h": args[16],
-        "save_recording_path": args[17],
-        "save_trace_path": args[18],
-        "save_agent_history_path": args[19],
-        "task": args[20],
+        "llm_num_ctx": args[7],
+        "llm_temperature": args[8],
+        "llm_base_url": args[9],
+        "llm_api_key": args[10],
+        "use_own_browser": args[11],
+        "keep_browser_open": args[12],
+        "headless": args[13],
+        "disable_security": args[14],
+        "enable_recording": args[15],
+        "window_w": args[16],
+        "window_h": args[17],
+        "save_recording_path": args[18],
+        "save_trace_path": args[19],
+        "save_agent_history_path": args[20],
+        "task": args[21],
     }
     return save_config_to_file(current_config)

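Because `save_current_config` unpacks `*args` by position, inserting `llm_num_ctx` at index 7 shifts every later key up by one, which is why the whole tail of the mapping changes here; the `inputs=[...]` list bound to the save button (final hunk below) must be reordered in lockstep. A hypothetical safeguard against that kind of drift, with the key order taken from this hunk and the first four names inferred from the save-button inputs list:

```python
# Hypothetical alternative (not in the PR): derive the mapping from one
# canonical key list so an ordering mismatch fails loudly instead of
# silently saving values under the wrong keys.
CONFIG_KEYS = [
    "agent_type", "max_steps", "max_actions_per_step", "use_vision",
    "tool_calling_method", "llm_provider", "llm_model_name", "llm_num_ctx",
    "llm_temperature", "llm_base_url", "llm_api_key", "use_own_browser",
    "keep_browser_open", "headless", "disable_security", "enable_recording",
    "window_w", "window_h", "save_recording_path", "save_trace_path",
    "save_agent_history_path", "task",
]

def save_current_config(*args):
    if len(args) != len(CONFIG_KEYS):
        raise ValueError(f"expected {len(CONFIG_KEYS)} values, got {len(args)}")
    # save_config_to_file comes from the surrounding module (shown in context)
    return save_config_to_file(dict(zip(CONFIG_KEYS, args)))
```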
@@ -89,6 +91,7 @@ def update_ui_from_config(config_file):
             gr.update(value=loaded_config.get("tool_calling_method", True)),
             gr.update(value=loaded_config.get("llm_provider", "openai")),
             gr.update(value=loaded_config.get("llm_model_name", "gpt-4o")),
+            gr.update(value=loaded_config.get("llm_num_ctx", 32000)),
             gr.update(value=loaded_config.get("llm_temperature", 1.0)),
             gr.update(value=loaded_config.get("llm_base_url", "")),
             gr.update(value=loaded_config.get("llm_api_key", "")),
@@ -167,7 +167,7 @@ model_names = {
     "openai": ["gpt-4o", "gpt-4", "gpt-3.5-turbo", "o3-mini"],
     "deepseek": ["deepseek-chat", "deepseek-reasoner"],
     "google": ["gemini-2.0-flash", "gemini-2.0-flash-thinking-exp", "gemini-1.5-flash-latest", "gemini-1.5-flash-8b-latest", "gemini-2.0-flash-thinking-exp-01-21", "gemini-2.0-pro-exp-02-05"],
-    "ollama": ["qwen2.5:7b", "llama2:7b", "deepseek-r1:14b", "deepseek-r1:32b"],
+    "ollama": ["qwen2.5:7b", "qwen2.5:14b", "qwen2.5:32b", "qwen2.5-coder:14b", "qwen2.5-coder:32b", "llama2:7b", "deepseek-r1:14b", "deepseek-r1:32b"],
     "azure_openai": ["gpt-4o", "gpt-4", "gpt-3.5-turbo"],
     "mistral": ["pixtral-large-latest", "mistral-large-latest", "mistral-small-latest", "ministral-8b-latest"],
     "alibaba": ["qwen-plus", "qwen-max", "qwen-turbo", "qwen-long"],
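The expanded Ollama entries are only dropdown labels; each tag must already be pulled into the local Ollama instance (e.g. `ollama pull qwen2.5:14b`) before it can serve requests. A small hypothetical helper using Ollama's REST API, whose `GET /api/tags` endpoint lists installed models:

```python
# Hypothetical helper (not in the PR): check whether a dropdown tag is
# installed locally via Ollama's GET /api/tags endpoint.
import requests

def is_model_installed(tag: str, base_url: str = "http://localhost:11434") -> bool:
    resp = requests.get(f"{base_url}/api/tags", timeout=5)
    resp.raise_for_status()
    return any(m["name"] == tag for m in resp.json().get("models", []))

print(is_model_installed("qwen2.5:14b"))  # False until the tag is pulled
```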
webui.py (36 lines changed)
@@ -100,6 +100,7 @@ async def run_browser_agent(
         agent_type,
         llm_provider,
         llm_model_name,
+        llm_num_ctx,
         llm_temperature,
         llm_base_url,
         llm_api_key,
@@ -144,6 +145,7 @@ async def run_browser_agent(
         llm = utils.get_llm_model(
             provider=llm_provider,
             model_name=llm_model_name,
+            num_ctx=llm_num_ctx,
             temperature=llm_temperature,
             base_url=llm_base_url,
             api_key=llm_api_key,
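`utils.get_llm_model` itself is outside this diff, so how `num_ctx` reaches the model is not shown; a plausible sketch of its Ollama branch, assuming it wraps `langchain_ollama.ChatOllama` (which accepts a `num_ctx` parameter mapping to Ollama's context-window option):

```python
# Plausible sketch only; the real utils.get_llm_model is not in this diff.
from langchain_ollama import ChatOllama

def get_llm_model(provider: str, model_name: str, num_ctx: int = 32000,
                  temperature: float = 1.0, base_url: str = "",
                  api_key: str = "", **kwargs):
    if provider == "ollama":
        return ChatOllama(
            model=model_name,
            num_ctx=num_ctx,        # context window; smaller = faster, less memory
            temperature=temperature,
            base_url=base_url or "http://localhost:11434",
        )
    raise NotImplementedError(f"sketch covers only 'ollama', got {provider!r}")
```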
@@ -435,6 +437,7 @@ async def run_with_stream(
         agent_type,
         llm_provider,
         llm_model_name,
+        llm_num_ctx,
         llm_temperature,
         llm_base_url,
         llm_api_key,
@@ -463,6 +466,7 @@ async def run_with_stream(
             agent_type=agent_type,
             llm_provider=llm_provider,
             llm_model_name=llm_model_name,
+            llm_num_ctx=llm_num_ctx,
             llm_temperature=llm_temperature,
             llm_base_url=llm_base_url,
             llm_api_key=llm_api_key,
@@ -495,6 +499,7 @@ async def run_with_stream(
             agent_type=agent_type,
             llm_provider=llm_provider,
             llm_model_name=llm_model_name,
+            llm_num_ctx=llm_num_ctx,
             llm_temperature=llm_temperature,
             llm_base_url=llm_base_url,
             llm_api_key=llm_api_key,
@@ -627,7 +632,7 @@ async def close_global_browser():
         await _global_browser.close()
         _global_browser = None

-async def run_deep_search(research_task, max_search_iteration_input, max_query_per_iter_input, llm_provider, llm_model_name, llm_temperature, llm_base_url, llm_api_key, use_vision, use_own_browser, headless):
+async def run_deep_search(research_task, max_search_iteration_input, max_query_per_iter_input, llm_provider, llm_model_name, llm_num_ctx, llm_temperature, llm_base_url, llm_api_key, use_vision, use_own_browser, headless):
     from src.utils.deep_research import deep_research
     global _global_agent_state

@@ -637,6 +642,7 @@ async def run_deep_search(research_task, max_search_iteration_input, max_query_p
     llm = utils.get_llm_model(
         provider=llm_provider,
         model_name=llm_model_name,
+        num_ctx=llm_num_ctx,
         temperature=llm_temperature,
         base_url=llm_base_url,
         api_key=llm_api_key,
@@ -740,6 +746,15 @@ def create_ui(config, theme_name="Ocean"):
                     allow_custom_value=True,  # Allow users to input custom model names
                     info="Select a model from the dropdown or type a custom model name"
                 )
+                llm_num_ctx = gr.Slider(
+                    minimum=2**8,
+                    maximum=2**16,
+                    value=config['llm_num_ctx'],
+                    step=1,
+                    label="Max Context Length",
+                    info="Controls max context length model needs to handle (less = faster)",
+                    visible=config['llm_provider'] == "ollama"
+                )
                 llm_temperature = gr.Slider(
                     minimum=0.0,
                     maximum=2.0,
@@ -761,6 +776,17 @@ def create_ui(config, theme_name="Ocean"):
                     info="Your API key (leave blank to use .env)"
                 )

+            # Change event to update context length slider
+            def update_llm_num_ctx_visibility(llm_provider):
+                return gr.update(visible=llm_provider == "ollama")
+
+            # Bind the change event of llm_provider to update the visibility of context length slider
+            llm_provider.change(
+                fn=update_llm_num_ctx_visibility,
+                inputs=llm_provider,
+                outputs=llm_num_ctx
+            )
+
         with gr.TabItem("🌐 Browser Settings", id=3):
             with gr.Group():
                 with gr.Row():
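The slider's bounds work out to 2**8 = 256 and 2**16 = 65536 tokens, and the binding above hides it whenever the provider is not `ollama`. The same show/hide mechanics in a minimal, self-contained Gradio sketch (hypothetical demo, not part of the PR):

```python
# Standalone demo of the visibility pattern: gr.update toggles an existing
# component's `visible` flag when the dropdown value changes.
import gradio as gr

with gr.Blocks() as demo:
    provider = gr.Dropdown(["openai", "ollama"], value="openai", label="LLM Provider")
    num_ctx = gr.Slider(minimum=2**8, maximum=2**16, value=32000, step=1,
                        label="Max Context Length", visible=False)
    provider.change(fn=lambda p: gr.update(visible=p == "ollama"),
                    inputs=provider, outputs=num_ctx)

demo.launch()
```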
@@ -903,7 +929,7 @@ def create_ui(config, theme_name="Ocean"):
        run_button.click(
            fn=run_with_stream,
            inputs=[
-                agent_type, llm_provider, llm_model_name, llm_temperature, llm_base_url, llm_api_key,
+                agent_type, llm_provider, llm_model_name, llm_num_ctx, llm_temperature, llm_base_url, llm_api_key,
                use_own_browser, keep_browser_open, headless, disable_security, window_w, window_h,
                save_recording_path, save_agent_history_path, save_trace_path,  # Include the new path
                enable_recording, task, add_infos, max_steps, use_vision, max_actions_per_step, tool_calling_method
@@ -925,7 +951,7 @@ def create_ui(config, theme_name="Ocean"):
        # Run Deep Research
        research_button.click(
            fn=run_deep_search,
-            inputs=[research_task_input, max_search_iteration_input, max_query_per_iter_input, llm_provider, llm_model_name, llm_temperature, llm_base_url, llm_api_key, use_vision, use_own_browser, headless],
+            inputs=[research_task_input, max_search_iteration_input, max_query_per_iter_input, llm_provider, llm_model_name, llm_num_ctx, llm_temperature, llm_base_url, llm_api_key, use_vision, use_own_browser, headless],
            outputs=[markdown_output_display, markdown_download, stop_research_button, research_button]
        )
        # Bind the stop button click event after errors_output is defined
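Gradio passes `inputs` to the handler positionally, so this list must mirror `run_deep_search`'s parameter order exactly; a mismatch would hand, say, the temperature to the `llm_num_ctx` parameter without raising any error. A hypothetical startup assertion (names copied from the new signature above):

```python
import inspect

# Hypothetical guard (not in the PR): fail fast if the Gradio inputs list
# drifts out of sync with run_deep_search's positional parameters.
EXPECTED_ORDER = [
    "research_task", "max_search_iteration_input", "max_query_per_iter_input",
    "llm_provider", "llm_model_name", "llm_num_ctx", "llm_temperature",
    "llm_base_url", "llm_api_key", "use_vision", "use_own_browser", "headless",
]
assert list(inspect.signature(run_deep_search).parameters) == EXPECTED_ORDER
```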
@@ -991,7 +1017,7 @@ def create_ui(config, theme_name="Ocean"):
            inputs=[config_file_input],
            outputs=[
                agent_type, max_steps, max_actions_per_step, use_vision, tool_calling_method,
-                llm_provider, llm_model_name, llm_temperature, llm_base_url, llm_api_key,
+                llm_provider, llm_model_name, llm_num_ctx, llm_temperature, llm_base_url, llm_api_key,
                use_own_browser, keep_browser_open, headless, disable_security, enable_recording,
                window_w, window_h, save_recording_path, save_trace_path, save_agent_history_path,
                task, config_status
@@ -1002,7 +1028,7 @@ def create_ui(config, theme_name="Ocean"):
            fn=save_current_config,
            inputs=[
                agent_type, max_steps, max_actions_per_step, use_vision, tool_calling_method,
-                llm_provider, llm_model_name, llm_temperature, llm_base_url, llm_api_key,
+                llm_provider, llm_model_name, llm_num_ctx, llm_temperature, llm_base_url, llm_api_key,
                use_own_browser, keep_browser_open, headless, disable_security,
                enable_recording, window_w, window_h, save_recording_path, save_trace_path,
                save_agent_history_path, task,