add deepseek-r1 ollama
@@ -11,6 +11,8 @@ AZURE_OPENAI_API_KEY=
 DEEPSEEK_ENDPOINT=https://api.deepseek.com
 DEEPSEEK_API_KEY=
 
+OLLAMA_ENDPOINT=http://localhost:11434
+
 # Set to false to disable anonymized telemetry
 ANONYMIZED_TELEMETRY=true
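The new OLLAMA_ENDPOINT variable is consumed by the get_llm_model() change further down. A minimal sketch of the lookup, using the same default as the diff:

```python
import os

# Falls back to the local Ollama default when the variable is unset,
# mirroring the lookup added to get_llm_model() below.
base_url = os.getenv("OLLAMA_ENDPOINT", "http://localhost:11434")
```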
@@ -98,7 +98,7 @@ class CustomAgent(Agent):
             register_done_callback=register_done_callback,
             tool_calling_method=tool_calling_method
         )
-        if self.model_name == "deepseek-reasoner":
+        if self.model_name in ["deepseek-reasoner"] or self.model_name.startswith("deepseek-r1"):
             # deepseek-reasoner does not support function calling
             self.use_deepseek_r1 = True
             # deepseek-reasoner only support 64000 context
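The widened check matches any local `deepseek-r1` tag (for example `deepseek-r1:14b`) as well as the hosted `deepseek-reasoner` model; neither supports function calling, so the agent falls back to raw-JSON prompting. The same predicate as a standalone sketch (the helper name is hypothetical; the commit inlines the check):

```python
def is_deepseek_reasoning_model(model_name: str) -> bool:
    # Hypothetical helper; the diff inlines this test in CustomAgent.__init__
    # and sets self.use_deepseek_r1 = True when it matches.
    return model_name in ["deepseek-reasoner"] or model_name.startswith("deepseek-r1")
```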
@@ -191,6 +191,7 @@ class CustomAgent(Agent):
                 parsed_json = json.loads(ai_message.content.replace("```json", "").replace("```", ""))
                 parsed: AgentOutput = self.AgentOutput(**parsed_json)
                 if parsed is None:
+                    logger.debug(ai_message.content)
                     raise ValueError(f'Could not parse response.')
             else:
                 ai_message = self.llm.invoke(input_messages)
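Because these models return raw text rather than tool calls, the agent strips Markdown code fences before `json.loads`. A slightly more defensive version of that cleanup, as a hypothetical helper rather than what the commit ships:

```python
import json

def parse_fenced_json(raw: str) -> dict:
    # Equivalent to the inline replace("```json", "").replace("```", "")
    # in the diff, plus whitespace stripping for safety.
    cleaned = raw.replace("```json", "").replace("```", "").strip()
    return json.loads(cleaned)
```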
@@ -201,6 +202,7 @@ class CustomAgent(Agent):
                 parsed_json = json.loads(ai_message.content.replace("```json", "").replace("```", ""))
                 parsed: AgentOutput = self.AgentOutput(**parsed_json)
             if parsed is None:
+                logger.debug(ai_message.content)
                 raise ValueError(f'Could not parse response.')
 
         # cut the number of actions to max_actions_per_step
@@ -229,6 +231,9 @@ class CustomAgent(Agent):
             self.update_step_info(model_output, step_info)
             logger.info(f"🧠 All Memory: \n{step_info.memory}")
             self._save_conversation(input_messages, model_output)
+            # should we remove last state message? at least, deepseek-reasoner cannot remove
+            if self.model_name != "deepseek-reasoner":
+                self.message_manager._remove_last_state_message()
         except Exception as e:
             # model call failed, remove last state message from history
             self.message_manager._remove_last_state_message()
@@ -253,7 +258,7 @@ class CustomAgent(Agent):
             self.consecutive_failures = 0
 
         except Exception as e:
-            result = self._handle_step_error(e)
+            result = await self._handle_step_error(e)
             self._last_result = result
 
         finally:
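The added `await` matters because `_handle_step_error` is declared `async`: without it, `result` is bound to an unawaited coroutine object instead of the error result. A minimal illustration, assuming nothing about the real method beyond its async signature:

```python
import asyncio

async def _handle_step_error(e: Exception) -> str:
    # Stand-in for the real coroutine; only the async signature is assumed.
    return f"handled: {e}"

async def main() -> None:
    wrong = _handle_step_error(ValueError("boom"))        # coroutine object, not a result
    right = await _handle_step_error(ValueError("boom"))  # the actual result
    print(type(wrong).__name__, "vs", right)
    wrong.close()  # suppress the "coroutine was never awaited" warning

asyncio.run(main())
```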
@@ -26,12 +26,7 @@ class CustomSystemPrompt(SystemPrompt):
         "summary": "Please generate a brief natural language description for the operation in next actions based on your Thought."
     },
     "action": [
-        {
-            "action_name": {
-                // action-specific parameters
-            }
-        },
-        // ... more actions in sequence
+        * actions in sequences, please refer to **Common action sequences**. Each output action MUST be formated as: \{action_name\: action_params\}*
     ]
 }
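For reference, a response that satisfies the reworded rule, with each action a single `{action_name: action_params}` object. The action names come from the **Common action sequences** examples in this file; the wrapper key around `summary` is assumed, since this hunk does not show it:

```python
# Hedged example of a conforming model response.
example_response = {
    "current_state": {  # wrapper key assumed; not visible in this hunk
        "summary": "Open the page and extract its content."
    },
    "action": [
        {"go_to_url": {"url": "https://example.com"}},
        {"extract_page_content": {}},
    ],
}
```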
@@ -44,7 +39,6 @@ class CustomSystemPrompt(SystemPrompt):
        {"click_element": {"index": 3}}
      ]
    - Navigation and extraction: [
-       {"open_new_tab": {}},
        {"go_to_url": {"url": "https://example.com"}},
        {"extract_page_content": {}}
      ]
@@ -127,7 +121,7 @@ class CustomSystemPrompt(SystemPrompt):
         AGENT_PROMPT = f"""You are a precise browser automation agent that interacts with websites through structured commands. Your role is to:
     1. Analyze the provided webpage elements and structure
     2. Plan a sequence of actions to accomplish the given task
-    3. Respond with valid JSON containing your action sequence and state assessment
+    3. Your final result MUST be a valid JSON as the **RESPONSE FORMAT** described, containing your action sequence and state assessment, No need extra content to expalin.
 
     Current date and time: {time_str}
 
@@ -200,15 +194,16 @@ class CustomAgentMessagePrompt(AgentMessagePrompt):
 """
 
         if self.result:
             for i, result in enumerate(self.result):
+                if result.include_in_memory:
                     if result.extracted_content:
-                        state_description += f"\nResult of action {i + 1}/{len(self.result)}: {result.extracted_content}"
+                        state_description += f"\nResult of previous action {i + 1}/{len(self.result)}: {result.extracted_content}"
                     if result.error:
                         # only use last 300 characters of error
                         error = result.error[-self.max_error_length:]
                         state_description += (
-                            f"\nError of action {i + 1}/{len(self.result)}: ...{error}"
+                            f"\nError of previous action {i + 1}/{len(self.result)}: ...{error}"
                         )
 
         if self.state.screenshot:
@@ -25,6 +25,7 @@ from langchain_core.outputs import (
     LLMResult,
     RunInfo,
 )
+from langchain_ollama import ChatOllama
 from langchain_core.output_parsers.base import OutputParserLike
 from langchain_core.runnables import Runnable, RunnableConfig
 from langchain_core.tools import BaseTool
@@ -98,4 +99,38 @@ class DeepSeekR1ChatOpenAI(ChatOpenAI):
 
         reasoning_content = response.choices[0].message.reasoning_content
         content = response.choices[0].message.content
         return AIMessage(content=content, reasoning_content=reasoning_content)
+
+
+class DeepSeekR1ChatOllama(ChatOllama):
+
+    async def ainvoke(
+        self,
+        input: LanguageModelInput,
+        config: Optional[RunnableConfig] = None,
+        *,
+        stop: Optional[list[str]] = None,
+        **kwargs: Any,
+    ) -> AIMessage:
+        org_ai_message = await super().ainvoke(input=input)
+        org_content = org_ai_message.content
+        reasoning_content = org_content.split("</think>")[0].replace("<think>", "")
+        content = org_content.split("</think>")[1]
+        if "**JSON Response:**" in content:
+            content = content.split("**JSON Response:**")[-1]
+        return AIMessage(content=content, reasoning_content=reasoning_content)
+
+    def invoke(
+        self,
+        input: LanguageModelInput,
+        config: Optional[RunnableConfig] = None,
+        *,
+        stop: Optional[list[str]] = None,
+        **kwargs: Any,
+    ) -> AIMessage:
+        org_ai_message = super().invoke(input=input)
+        org_content = org_ai_message.content
+        reasoning_content = org_content.split("</think>")[0].replace("<think>", "")
+        content = org_content.split("</think>")[1]
+        if "**JSON Response:**" in content:
+            content = content.split("**JSON Response:**")[-1]
+        return AIMessage(content=content, reasoning_content=reasoning_content)
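Both overrides split on the first `</think>` tag, which raises `IndexError` when a model omits the tag. A fallback-safe variant of the same parsing, offered as a sketch rather than a fix in the commit:

```python
def split_reasoning(org_content: str) -> tuple[str, str]:
    # Same extraction as DeepSeekR1ChatOllama above, but tolerant of a
    # response that carries no <think>...</think> block.
    if "</think>" in org_content:
        reasoning_content, content = org_content.split("</think>", 1)
        reasoning_content = reasoning_content.replace("<think>", "")
    else:
        reasoning_content, content = "", org_content
    if "**JSON Response:**" in content:
        content = content.split("**JSON Response:**")[-1]
    return reasoning_content.strip(), content.strip()
```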
@@ -10,7 +10,7 @@ from langchain_ollama import ChatOllama
 from langchain_openai import AzureChatOpenAI, ChatOpenAI
 import gradio as gr
 
-from .llm import DeepSeekR1ChatOpenAI
+from .llm import DeepSeekR1ChatOpenAI, DeepSeekR1ChatOllama
 
 def get_llm_model(provider: str, **kwargs):
     """
@@ -89,12 +89,25 @@ def get_llm_model(provider: str, **kwargs):
             google_api_key=api_key,
         )
     elif provider == "ollama":
-        return ChatOllama(
-            model=kwargs.get("model_name", "qwen2.5:7b"),
-            temperature=kwargs.get("temperature", 0.0),
-            num_ctx=kwargs.get("num_ctx", 32000),
-            base_url=kwargs.get("base_url", "http://localhost:11434"),
-        )
+        if not kwargs.get("base_url", ""):
+            base_url = os.getenv("OLLAMA_ENDPOINT", "http://localhost:11434")
+        else:
+            base_url = kwargs.get("base_url")
+
+        if kwargs.get("model_name", "qwen2.5:7b").startswith("deepseek-r1"):
+            return DeepSeekR1ChatOllama(
+                model=kwargs.get("model_name", "deepseek-r1:7b"),
+                temperature=kwargs.get("temperature", 0.0),
+                num_ctx=kwargs.get("num_ctx", 32000),
+                base_url=kwargs.get("base_url", base_url),
+            )
+        else:
+            return ChatOllama(
+                model=kwargs.get("model_name", "qwen2.5:7b"),
+                temperature=kwargs.get("temperature", 0.0),
+                num_ctx=kwargs.get("num_ctx", 32000),
+                base_url=kwargs.get("base_url", base_url),
+            )
     elif provider == "azure_openai":
         if not kwargs.get("base_url", ""):
             base_url = os.getenv("AZURE_OPENAI_ENDPOINT", "")
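With this branch in place, any `ollama` model whose name starts with `deepseek-r1` is routed through the new wrapper automatically. A usage sketch; the `from src.utils import utils` path is assumed from the `src.utils.llm` layout used in the tests:

```python
from src.utils import utils  # import path assumed from the test files

# Dispatches to DeepSeekR1ChatOllama because of the deepseek-r1 prefix;
# any other model name on this provider still returns plain ChatOllama.
llm = utils.get_llm_model(
    provider="ollama",
    model_name="deepseek-r1:14b",
    temperature=0.0,
    base_url="http://localhost:11434",
)
```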
@@ -120,7 +133,7 @@ model_names = {
     "openai": ["gpt-4o", "gpt-4", "gpt-3.5-turbo"],
     "deepseek": ["deepseek-chat", "deepseek-reasoner"],
     "gemini": ["gemini-2.0-flash-exp", "gemini-2.0-flash-thinking-exp", "gemini-1.5-flash-latest", "gemini-1.5-flash-8b-latest", "gemini-2.0-flash-thinking-exp-1219"],
-    "ollama": ["qwen2.5:7b", "llama2:7b"],
+    "ollama": ["qwen2.5:7b", "llama2:7b", "deepseek-r1:14b", "deepseek-r1:32b"],
     "azure_openai": ["gpt-4o", "gpt-4", "gpt-3.5-turbo"]
 }
@@ -257,22 +257,26 @@ async def test_browser_use_custom_v2():
     #     temperature=0.8
     # )
 
-    llm = utils.get_llm_model(
-        provider="deepseek",
-        model_name="deepseek-chat",
-        temperature=0.8
-    )
+    # llm = utils.get_llm_model(
+    #     provider="deepseek",
+    #     model_name="deepseek-chat",
+    #     temperature=0.8
+    # )
 
     # llm = utils.get_llm_model(
     #     provider="ollama", model_name="qwen2.5:7b", temperature=0.5
     # )
 
+    # llm = utils.get_llm_model(
+    #     provider="ollama", model_name="deepseek-r1:14b", temperature=0.5
+    # )
+
     controller = CustomController()
     use_own_browser = False
     disable_security = True
     use_vision = False  # Set to False when using DeepSeek
 
-    max_actions_per_step = 1
+    max_actions_per_step = 10
     playwright = None
     browser = None
    browser_context = None
@@ -303,7 +307,7 @@ async def test_browser_use_custom_v2():
             )
         )
     agent = CustomAgent(
-        task="give me stock price of Nvidia and tesla",
+        task="go to google.com and type 'Nvidia' click search and give me the first url",
         add_infos="",  # some hints for llm to complete the task
         llm=llm,
         browser=browser,
@@ -142,12 +142,21 @@ def test_ollama_model():
     llm = ChatOllama(model="qwen2.5:7b")
     ai_msg = llm.invoke("Sing a ballad of LangChain.")
     print(ai_msg.content)
 
+
+def test_deepseek_r1_ollama_model():
+    from src.utils.llm import DeepSeekR1ChatOllama
+
+    llm = DeepSeekR1ChatOllama(model="deepseek-r1:14b")
+    ai_msg = llm.invoke("how many r in strawberry?")
+    print(ai_msg.content)
+    pdb.set_trace()
+
 
 if __name__ == '__main__':
     # test_openai_model()
     # test_gemini_model()
     # test_azure_openai_model()
-    test_deepseek_model()
+    # test_deepseek_model()
     # test_ollama_model()
-    # test_deepseek_r1_model()
+    test_deepseek_r1_model()
+    # test_deepseek_r1_ollama_model()
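The new test drops into `pdb` right after printing, so it is meant for interactive runs. The same call without the breakpoint, assuming a local Ollama server with the `deepseek-r1:14b` model pulled:

```python
from src.utils.llm import DeepSeekR1ChatOllama

llm = DeepSeekR1ChatOllama(model="deepseek-r1:14b")
ai_msg = llm.invoke("how many r in strawberry?")
print(ai_msg.reasoning_content)  # the <think> block split out by the wrapper
print(ai_msg.content)            # the final answer
```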
webui.py
@@ -658,7 +658,8 @@ def create_ui(config, theme_name="Ocean"):
                         interactive=True,
                         allow_custom_value=True,  # Allow users to input custom model names
                         choices=["auto", "json_schema", "function_calling"],
-                        info="Tool Calls Funtion Name"
+                        info="Tool Calls Funtion Name",
+                        visible=False
                     )
 
             with gr.TabItem("🔧 LLM Configuration", id=2):
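`visible=False` keeps the tool-calling dropdown in the component graph while hiding it from the page, so it can later be re-shown without rebuilding the UI. A standalone sketch; the label and default value are assumed, as the hunk does not show them:

```python
import gradio as gr

with gr.Blocks() as demo:
    tool_calling_method = gr.Dropdown(
        label="Tool Calling Method",      # label assumed, not shown in the hunk
        choices=["auto", "json_schema", "function_calling"],
        value="auto",                     # default assumed
        interactive=True,
        allow_custom_value=True,          # allow users to input custom model names
        info="Tool Calls Function Name",  # the commit spells this "Funtion"
        visible=False,                    # hidden by default, as in the diff
    )
```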