add deepseek-r1 ollama

This commit is contained in:
vincent
2025-01-27 16:36:13 +08:00
parent 6ceaf8de6b
commit 664dce757e
8 changed files with 94 additions and 30 deletions

View File

@@ -11,6 +11,8 @@ AZURE_OPENAI_API_KEY=
DEEPSEEK_ENDPOINT=https://api.deepseek.com
DEEPSEEK_API_KEY=
OLLAMA_ENDPOINT=http://localhost:11434
# Set to false to disable anonymized telemetry
ANONYMIZED_TELEMETRY=true

View File

@@ -98,7 +98,7 @@ class CustomAgent(Agent):
register_done_callback=register_done_callback,
tool_calling_method=tool_calling_method
)
if self.model_name == "deepseek-reasoner":
if self.model_name in ["deepseek-reasoner"] or self.model_name.startswith("deepseek-r1"):
# deepseek-reasoner does not support function calling
self.use_deepseek_r1 = True
# deepseek-reasoner only support 64000 context
@@ -191,6 +191,7 @@ class CustomAgent(Agent):
parsed_json = json.loads(ai_message.content.replace("```json", "").replace("```", ""))
parsed: AgentOutput = self.AgentOutput(**parsed_json)
if parsed is None:
logger.debug(ai_message.content)
raise ValueError(f'Could not parse response.')
else:
ai_message = self.llm.invoke(input_messages)
@@ -201,6 +202,7 @@ class CustomAgent(Agent):
parsed_json = json.loads(ai_message.content.replace("```json", "").replace("```", ""))
parsed: AgentOutput = self.AgentOutput(**parsed_json)
if parsed is None:
logger.debug(ai_message.content)
raise ValueError(f'Could not parse response.')
# cut the number of actions to max_actions_per_step
@@ -229,6 +231,9 @@ class CustomAgent(Agent):
self.update_step_info(model_output, step_info)
logger.info(f"🧠 All Memory: \n{step_info.memory}")
self._save_conversation(input_messages, model_output)
# should we remove last state message? at least, deepseek-reasoner cannot remove
if self.model_name != "deepseek-reasoner":
self.message_manager._remove_last_state_message()
except Exception as e:
# model call failed, remove last state message from history
self.message_manager._remove_last_state_message()
@@ -253,7 +258,7 @@ class CustomAgent(Agent):
self.consecutive_failures = 0
except Exception as e:
result = self._handle_step_error(e)
result = await self._handle_step_error(e)
self._last_result = result
finally:

View File

@@ -26,12 +26,7 @@ class CustomSystemPrompt(SystemPrompt):
"summary": "Please generate a brief natural language description for the operation in next actions based on your Thought."
},
"action": [
{
"action_name": {
// action-specific parameters
}
},
// ... more actions in sequence
* actions in sequences, please refer to **Common action sequences**. Each output action MUST be formated as: \{action_name\: action_params\}*
]
}
@@ -44,7 +39,6 @@ class CustomSystemPrompt(SystemPrompt):
{"click_element": {"index": 3}}
]
- Navigation and extraction: [
{"open_new_tab": {}},
{"go_to_url": {"url": "https://example.com"}},
{"extract_page_content": {}}
]
@@ -127,7 +121,7 @@ class CustomSystemPrompt(SystemPrompt):
AGENT_PROMPT = f"""You are a precise browser automation agent that interacts with websites through structured commands. Your role is to:
1. Analyze the provided webpage elements and structure
2. Plan a sequence of actions to accomplish the given task
3. Respond with valid JSON containing your action sequence and state assessment
3. Your final result MUST be a valid JSON as the **RESPONSE FORMAT** described, containing your action sequence and state assessment, No need extra content to expalin.
Current date and time: {time_str}
@@ -200,15 +194,16 @@ class CustomAgentMessagePrompt(AgentMessagePrompt):
"""
if self.result:
for i, result in enumerate(self.result):
if result.include_in_memory:
if result.extracted_content:
state_description += f"\nResult of action {i + 1}/{len(self.result)}: {result.extracted_content}"
state_description += f"\nResult of previous action {i + 1}/{len(self.result)}: {result.extracted_content}"
if result.error:
# only use last 300 characters of error
error = result.error[-self.max_error_length:]
state_description += (
f"\nError of action {i + 1}/{len(self.result)}: ...{error}"
f"\nError of previous action {i + 1}/{len(self.result)}: ...{error}"
)
if self.state.screenshot:

View File

@@ -25,6 +25,7 @@ from langchain_core.outputs import (
LLMResult,
RunInfo,
)
from langchain_ollama import ChatOllama
from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_core.tools import BaseTool
@@ -98,4 +99,38 @@ class DeepSeekR1ChatOpenAI(ChatOpenAI):
reasoning_content = response.choices[0].message.reasoning_content
content = response.choices[0].message.content
return AIMessage(content=content, reasoning_content=reasoning_content)
class DeepSeekR1ChatOllama(ChatOllama):
async def ainvoke(
self,
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> AIMessage:
org_ai_message = await super().ainvoke(input=input)
org_content = org_ai_message.content
reasoning_content = org_content.split("</think>")[0].replace("<think>", "")
content = org_content.split("</think>")[1]
if "**JSON Response:**" in content:
content = content.split("**JSON Response:**")[-1]
return AIMessage(content=content, reasoning_content=reasoning_content)
def invoke(
self,
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> AIMessage:
org_ai_message = super().invoke(input=input)
org_content = org_ai_message.content
reasoning_content = org_content.split("</think>")[0].replace("<think>", "")
content = org_content.split("</think>")[1]
if "**JSON Response:**" in content:
content = content.split("**JSON Response:**")[-1]
return AIMessage(content=content, reasoning_content=reasoning_content)

View File

@@ -10,7 +10,7 @@ from langchain_ollama import ChatOllama
from langchain_openai import AzureChatOpenAI, ChatOpenAI
import gradio as gr
from .llm import DeepSeekR1ChatOpenAI
from .llm import DeepSeekR1ChatOpenAI, DeepSeekR1ChatOllama
def get_llm_model(provider: str, **kwargs):
"""
@@ -89,12 +89,25 @@ def get_llm_model(provider: str, **kwargs):
google_api_key=api_key,
)
elif provider == "ollama":
return ChatOllama(
model=kwargs.get("model_name", "qwen2.5:7b"),
temperature=kwargs.get("temperature", 0.0),
num_ctx=kwargs.get("num_ctx", 32000),
base_url=kwargs.get("base_url", "http://localhost:11434"),
)
if not kwargs.get("base_url", ""):
base_url = os.getenv("OLLAMA_ENDPOINT", "http://localhost:11434")
else:
base_url = kwargs.get("base_url")
if kwargs.get("model_name", "qwen2.5:7b").startswith("deepseek-r1"):
return DeepSeekR1ChatOllama(
model=kwargs.get("model_name", "deepseek-r1:7b"),
temperature=kwargs.get("temperature", 0.0),
num_ctx=kwargs.get("num_ctx", 32000),
base_url=kwargs.get("base_url", base_url),
)
else:
return ChatOllama(
model=kwargs.get("model_name", "qwen2.5:7b"),
temperature=kwargs.get("temperature", 0.0),
num_ctx=kwargs.get("num_ctx", 32000),
base_url=kwargs.get("base_url", base_url),
)
elif provider == "azure_openai":
if not kwargs.get("base_url", ""):
base_url = os.getenv("AZURE_OPENAI_ENDPOINT", "")
@@ -120,7 +133,7 @@ model_names = {
"openai": ["gpt-4o", "gpt-4", "gpt-3.5-turbo"],
"deepseek": ["deepseek-chat", "deepseek-reasoner"],
"gemini": ["gemini-2.0-flash-exp", "gemini-2.0-flash-thinking-exp", "gemini-1.5-flash-latest", "gemini-1.5-flash-8b-latest", "gemini-2.0-flash-thinking-exp-1219" ],
"ollama": ["qwen2.5:7b", "llama2:7b"],
"ollama": ["qwen2.5:7b", "llama2:7b", "deepseek-r1:14b", "deepseek-r1:32b"],
"azure_openai": ["gpt-4o", "gpt-4", "gpt-3.5-turbo"]
}

View File

@@ -257,22 +257,26 @@ async def test_browser_use_custom_v2():
# temperature=0.8
# )
llm = utils.get_llm_model(
provider="deepseek",
model_name="deepseek-chat",
temperature=0.8
)
# llm = utils.get_llm_model(
# provider="deepseek",
# model_name="deepseek-chat",
# temperature=0.8
# )
# llm = utils.get_llm_model(
# provider="ollama", model_name="qwen2.5:7b", temperature=0.5
# )
# llm = utils.get_llm_model(
# provider="ollama", model_name="deepseek-r1:14b", temperature=0.5
# )
controller = CustomController()
use_own_browser = False
disable_security = True
use_vision = False # Set to False when using DeepSeek
max_actions_per_step = 1
max_actions_per_step = 10
playwright = None
browser = None
browser_context = None
@@ -303,7 +307,7 @@ async def test_browser_use_custom_v2():
)
)
agent = CustomAgent(
task="give me stock price of Nvidia and tesla",
task="go to google.com and type 'Nvidia' click search and give me the first url",
add_infos="", # some hints for llm to complete the task
llm=llm,
browser=browser,

View File

@@ -142,12 +142,21 @@ def test_ollama_model():
llm = ChatOllama(model="qwen2.5:7b")
ai_msg = llm.invoke("Sing a ballad of LangChain.")
print(ai_msg.content)
def test_deepseek_r1_ollama_model():
from src.utils.llm import DeepSeekR1ChatOllama
llm = DeepSeekR1ChatOllama(model="deepseek-r1:14b")
ai_msg = llm.invoke("how many r in strawberry?")
print(ai_msg.content)
pdb.set_trace()
if __name__ == '__main__':
# test_openai_model()
# test_gemini_model()
# test_azure_openai_model()
test_deepseek_model()
# test_deepseek_model()
# test_ollama_model()
# test_deepseek_r1_model()
test_deepseek_r1_model()
# test_deepseek_r1_ollama_model()

View File

@@ -658,7 +658,8 @@ def create_ui(config, theme_name="Ocean"):
interactive=True,
allow_custom_value=True, # Allow users to input custom model names
choices=["auto", "json_schema", "function_calling"],
info="Tool Calls Funtion Name"
info="Tool Calls Funtion Name",
visible=False
)
with gr.TabItem("🔧 LLM Configuration", id=2):