Merge pull request #151 from vvincent1234/feat/deepseek-r1
Feat/deepseek r1
@@ -95,6 +95,12 @@ class CustomAgent(Agent):
             max_actions_per_step=max_actions_per_step,
             tool_call_in_content=tool_call_in_content,
         )
+        if self.llm.model_name in ["deepseek-reasoner"]:
+            self.use_function_calling = False
+            # TODO: deepseek-reasoner only supports a 64000-token context
+            self.max_input_tokens = 64000
+        else:
+            self.use_function_calling = True
         self.add_infos = add_infos
         self.agent_state = agent_state
         self.message_manager = CustomMassageManager(
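Context for the gate above: when the configured model is deepseek-reasoner, the agent switches off LangChain function calling and instead parses the model's raw text. A minimal standalone sketch of that decision (attribute names taken from this diff; the class itself is hypothetical):

class FunctionCallingGate:
    def __init__(self, model_name: str):
        # deepseek-reasoner exposes no tool/function-calling API, so the agent
        # must fall back to parsing JSON out of the model's plain-text reply
        self.use_function_calling = model_name not in ["deepseek-reasoner"]
        if not self.use_function_calling:
            # deepseek-reasoner supports a 64000-token context
            self.max_input_tokens = 64000

assert FunctionCallingGate("deepseek-reasoner").use_function_calling is False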
@@ -107,6 +113,7 @@ class CustomAgent(Agent):
             max_error_length=self.max_error_length,
             max_actions_per_step=self.max_actions_per_step,
             tool_call_in_content=tool_call_in_content,
+            use_function_calling=self.use_function_calling
         )

     def _setup_action_models(self) -> None:
@@ -127,7 +134,8 @@ class CustomAgent(Agent):

         logger.info(f"{emoji} Eval: {response.current_state.prev_action_evaluation}")
         logger.info(f"🧠 New Memory: {response.current_state.important_contents}")
-        logger.info(f"⏳ Task Progress: {response.current_state.completed_contents}")
+        logger.info(f"⏳ Task Progress: \n{response.current_state.task_progress}")
+        logger.info(f"📋 Future Plans: \n{response.current_state.future_plans}")
         logger.info(f"🤔 Thought: {response.current_state.thought}")
         logger.info(f"🎯 Summary: {response.current_state.summary}")
         for i, action in enumerate(response.action):
@@ -153,28 +161,54 @@ class CustomAgent(Agent):
         ):
             step_info.memory += important_contents + "\n"

-        completed_contents = model_output.current_state.completed_contents
-        if completed_contents and "None" not in completed_contents:
-            step_info.task_progress = completed_contents
+        task_progress = model_output.current_state.task_progress
+        if task_progress and "None" not in task_progress:
+            step_info.task_progress = task_progress
+
+        future_plans = model_output.current_state.future_plans
+        if future_plans and "None" not in future_plans:
+            step_info.future_plans = future_plans

     @time_execution_async("--get_next_action")
     async def get_next_action(self, input_messages: list[BaseMessage]) -> AgentOutput:
         """Get next action from LLM based on current state"""
-        try:
-            structured_llm = self.llm.with_structured_output(self.AgentOutput, include_raw=True)
-            response: dict[str, Any] = await structured_llm.ainvoke(input_messages)  # type: ignore
-
-            parsed: AgentOutput = response['parsed']
-            # cut the number of actions to max_actions_per_step
-            parsed.action = parsed.action[: self.max_actions_per_step]
-            self._log_response(parsed)
-            self.n_steps += 1
-
-            return parsed
-        except Exception as e:
-            # If something goes wrong, try to invoke the LLM again without structured output,
-            # and manually parse the response. Temporary solution for DeepSeek.
-            ret = self.llm.invoke(input_messages)
-            if isinstance(ret.content, list):
-                parsed_json = json.loads(ret.content[0].replace("```json", "").replace("```", ""))
-            else:
+        if self.use_function_calling:
+            try:
+                structured_llm = self.llm.with_structured_output(self.AgentOutput, include_raw=True)
+                response: dict[str, Any] = await structured_llm.ainvoke(input_messages)  # type: ignore
+
+                parsed: AgentOutput = response['parsed']
+                # cut the number of actions to max_actions_per_step
+                parsed.action = parsed.action[: self.max_actions_per_step]
+                self._log_response(parsed)
+                self.n_steps += 1
+
+                return parsed
+            except Exception as e:
+                # If something goes wrong, try to invoke the LLM again without structured output,
+                # and manually parse the response. Temporary solution for DeepSeek.
+                ret = self.llm.invoke(input_messages)
+                if isinstance(ret.content, list):
+                    parsed_json = json.loads(ret.content[0].replace("```json", "").replace("```", ""))
+                else:
+                    parsed_json = json.loads(ret.content.replace("```json", "").replace("```", ""))
+                parsed: AgentOutput = self.AgentOutput(**parsed_json)
+                if parsed is None:
+                    raise ValueError(f'Could not parse response.')
+
+                # cut the number of actions to max_actions_per_step
+                parsed.action = parsed.action[: self.max_actions_per_step]
+                self._log_response(parsed)
+                self.n_steps += 1
+
+                return parsed
+        else:
+            ret = self.llm.invoke(input_messages)
+            if not self.use_function_calling:
+                self.message_manager._add_message_with_tokens(ret)
+            logger.info(f"🤯 Start Deep Thinking: ")
+            logger.info(ret.reasoning_content)
+            logger.info(f"🤯 End Deep Thinking")
+            if isinstance(ret.content, list):
+                parsed_json = json.loads(ret.content[0].replace("```json", "").replace("```", ""))
+            else:
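The fallback branches above recover from failed structured output by stripping markdown fences and parsing the JSON by hand. A self-contained sketch of that recovery step (the sample payload is invented for illustration):

import json

def parse_fenced_json(raw: str) -> dict:
    # Strip the ```json ... ``` fences the model may wrap around its output,
    # then parse the remainder, mirroring the except-branch above.
    return json.loads(raw.replace("```json", "").replace("```", ""))

sample = '```json\n{"current_state": {"thought": "open google"}, "action": []}\n```'
print(parse_fenced_json(sample)["action"])  # []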
@@ -204,14 +238,22 @@ class CustomAgent(Agent):
         input_messages = self.message_manager.get_messages()
         model_output = await self.get_next_action(input_messages)
         self.update_step_info(model_output, step_info)
-        logger.info(f"🧠 All Memory: {step_info.memory}")
+        logger.info(f"🧠 All Memory: \n{step_info.memory}")
         self._save_conversation(input_messages, model_output)
-        self.message_manager._remove_last_state_message()  # we don't want the whole state in the chat history
-        self.message_manager.add_model_output(model_output)
+        if self.use_function_calling:
+            self.message_manager._remove_last_state_message()  # we don't want the whole state in the chat history
+            self.message_manager.add_model_output(model_output)

         result: list[ActionResult] = await self.controller.multi_act(
             model_output.action, self.browser_context
         )
+        if len(result) != len(model_output.action):
+            for ri in range(len(result), len(model_output.action)):
+                result.append(ActionResult(extracted_content=None,
+                                           include_in_memory=True,
+                                           error=f"{model_output.action[ri].model_dump_json(exclude_unset=True)} failed to execute. \
+                    Something new appeared after action {model_output.action[len(result) - 1].model_dump_json(exclude_unset=True)}",
+                                           is_done=False))
         self._last_result = result

         if len(result) > 0 and result[-1].is_done:
@@ -369,6 +411,7 @@ class CustomAgent(Agent):
             max_steps=max_steps,
             memory="",
             task_progress="",
+            future_plans=""
         )

         for step in range(max_steps):
@@ -39,6 +39,7 @@ class CustomMassageManager(MessageManager):
         max_error_length: int = 400,
         max_actions_per_step: int = 10,
         tool_call_in_content: bool = False,
+        use_function_calling: bool = True
     ):
         super().__init__(
             llm=llm,
@@ -53,41 +54,52 @@ class CustomMassageManager(MessageManager):
             max_actions_per_step=max_actions_per_step,
             tool_call_in_content=tool_call_in_content,
         )

+        self.use_function_calling = use_function_calling
         # Custom: Move Task info to state_message
         self.history = MessageHistory()
         self._add_message_with_tokens(self.system_prompt)
-        tool_calls = [
-            {
-                'name': 'CustomAgentOutput',
-                'args': {
-                    'current_state': {
-                        'prev_action_evaluation': 'Unknown - No previous actions to evaluate.',
-                        'important_contents': '',
-                        'completed_contents': '',
-                        'thought': 'Now Google is open. Need to type OpenAI to search.',
-                        'summary': 'Type OpenAI to search.',
-                    },
-                    'action': [],
-                },
-                'id': '',
-                'type': 'tool_call',
-            }
-        ]
-        if self.tool_call_in_content:
-            # openai throws error if tool_calls are not responded -> move to content
-            example_tool_call = AIMessage(
-                content=f'{tool_calls}',
-                tool_calls=[],
-            )
-        else:
-            example_tool_call = AIMessage(
-                content=f'',
-                tool_calls=tool_calls,
-            )
-
-        self._add_message_with_tokens(example_tool_call)
+        if self.use_function_calling:
+            tool_calls = [
+                {
+                    'name': 'CustomAgentOutput',
+                    'args': {
+                        'current_state': {
+                            'prev_action_evaluation': 'Unknown - No previous actions to evaluate.',
+                            'important_contents': '',
+                            'completed_contents': '',
+                            'thought': 'Now Google is open. Need to type OpenAI to search.',
+                            'summary': 'Type OpenAI to search.',
+                        },
+                        'action': [],
+                    },
+                    'id': '',
+                    'type': 'tool_call',
+                }
+            ]
+            if self.tool_call_in_content:
+                # openai throws error if tool_calls are not responded -> move to content
+                example_tool_call = AIMessage(
+                    content=f'{tool_calls}',
+                    tool_calls=[],
+                )
+            else:
+                example_tool_call = AIMessage(
+                    content=f'',
+                    tool_calls=tool_calls,
+                )
+
+            self._add_message_with_tokens(example_tool_call)
+
+    def cut_messages(self):
+        """Get current message list, potentially trimmed to max tokens"""
+        diff = self.history.total_tokens - self.max_input_tokens
+        i = 1  # start from 1 to keep system message in history
+        while diff > 0 and i < len(self.history.messages):
+            self.history.remove_message(i)
+            diff = self.history.total_tokens - self.max_input_tokens
+            i += 1

     def add_state_message(
         self,
         state: BrowserState,
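The new cut_messages above trims history from the oldest non-system message until the token total fits max_input_tokens (needed because deepseek-reasoner's budget is capped at 64000 earlier in this diff). A simplified standalone sketch of the same idea over (message, token_count) pairs (names hypothetical; the diff's loop advances an index rather than always popping position 1):

def trim_history(history: list[tuple[str, int]], max_input_tokens: int) -> list[tuple[str, int]]:
    # index 0 is the system message and is always kept
    history = list(history)
    while sum(t for _, t in history) > max_input_tokens and len(history) > 1:
        history.pop(1)  # drop the oldest non-system message
    return history

print(trim_history([("system", 50), ("old", 900), ("recent", 100)], 200))
# [('system', 50), ('recent', 100)]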
@@ -95,21 +107,6 @@ class CustomMassageManager(MessageManager):
         step_info: Optional[AgentStepInfo] = None,
     ) -> None:
         """Add browser state as human message"""
-
-        # if keep in memory, add directly to history and add state without result
-        if result:
-            for r in result:
-                if r.include_in_memory:
-                    if r.extracted_content:
-                        msg = HumanMessage(content=str(r.extracted_content))
-                        self._add_message_with_tokens(msg)
-                    if r.error:
-                        msg = HumanMessage(
-                            content=str(r.error)[-self.max_error_length:]
-                        )
-                        self._add_message_with_tokens(msg)
-            result = None  # if result in history, we don't want to add it again
-
         # otherwise add state message and result to next message (which will not stay in memory)
         state_message = CustomAgentMessagePrompt(
             state,
@@ -3,7 +3,7 @@
 # @Author  : wenshao
 # @ProjectName: browser-use-webui
 # @FileName: custom_prompts.py

 import pdb
 from typing import List, Optional

 from browser_use.agent.prompts import SystemPrompt
@@ -25,8 +25,9 @@ class CustomSystemPrompt(SystemPrompt):
        "current_state": {
            "prev_action_evaluation": "Success|Failed|Unknown - Analyze the current elements and the image to check if the previous goals/actions are successful like intended by the task. Ignore the action result. The website is the ground truth. Also mention if something unexpected happened like new suggestions in an input field. Shortly state why/why not. Note that the result you output must be consistent with the reasoning you output afterwards. If you consider it to be 'Failed,' you should reflect on this during your thought.",
            "important_contents": "Output important contents closely related to user\'s instruction or task on the current page. If there is, please output the contents. If not, please output empty string ''.",
-           "completed_contents": "Update the input Task Progress. Completed contents is a general summary of the current contents that have been completed. Just summarize the contents that have been actually completed based on the current page and the history operations. Please list each completed item individually, such as: 1. Input username. 2. Input Password. 3. Click confirm button",
-           "thought": "Think about the requirements that have been completed in previous operations and the requirements that need to be completed in the next one operation. If the output of prev_action_evaluation is 'Failed', please reflect and output your reflection here. If you think you have entered the wrong page, consider to go back to the previous page in next action.",
+           "task_progress": "Task Progress is a general summary of the current contents that have been completed. Just summarize the contents that have been actually completed based on the content at current step and the history operations. Please list each completed item individually, such as: 1. Input username. 2. Input Password. 3. Click confirm button. Please return string type not a list.",
+           "future_plans": "Based on the user's request and the current state, outline the remaining steps needed to complete the task. This should be a concise list of actions yet to be performed, such as: 1. Select a date. 2. Choose a specific time slot. 3. Confirm booking. Please return string type not a list.",
+           "thought": "Think about the requirements that have been completed in previous operations and the requirements that need to be completed in the next one operation. If your output of prev_action_evaluation is 'Failed', please reflect and output your reflection here.",
            "summary": "Please generate a brief natural language description for the operation in next actions based on your Thought."
        },
        "action": [
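With completed_contents replaced by task_progress and future_plans, a conforming model reply now looks roughly like this (field names from the schema above; values invented for illustration):

example_response = {
    "current_state": {
        "prev_action_evaluation": "Success - The search page loaded as intended.",
        "important_contents": "",
        "task_progress": "1. Opened google.com. 2. Typed the query.",
        "future_plans": "1. Click the first result. 2. Extract the answer.",
        "thought": "Results are visible; open the top link next.",
        "summary": "Click the first search result.",
    },
    "action": [],
}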
@@ -70,6 +71,7 @@ class CustomSystemPrompt(SystemPrompt):
    - Don't hallucinate actions.
    - If the task requires specific information - make sure to include everything in the done function. This is what the user will see.
    - If you are running out of steps (current step), think about speeding it up, and ALWAYS use the done action as the last action.
+   - Note that you must verify if you've truly fulfilled the user's request by examining the actual page content, not just by looking at the actions you output but also whether the action is executed successfully. Pay particular attention when errors occur during action execution.

 6. VISUAL CONTEXT:
    - When an image is provided, use it to understand the page layout
@@ -100,10 +102,9 @@ class CustomSystemPrompt(SystemPrompt):
    1. Task: The user\'s instructions you need to complete.
    2. Hints(Optional): Some hints to help you complete the user\'s instructions.
    3. Memory: Important contents are recorded during historical operations for use in subsequent operations.
-   4. Task Progress: Up to the current page, the content you have completed can be understood as the progress of the task.
-   5. Current URL: The webpage you're currently on
-   6. Available Tabs: List of open browser tabs
-   7. Interactive Elements: List in the format:
+   4. Current URL: The webpage you're currently on
+   5. Available Tabs: List of open browser tabs
+   6. Interactive Elements: List in the format:
       index[:]<element_type>element_text</element_type>
       - index: Numeric identifier for interaction
      - element_type: HTML element type (button, input, etc.)
@@ -162,20 +163,27 @@ class CustomAgentMessagePrompt:
         self.step_info = step_info

     def get_user_message(self) -> HumanMessage:
+        if self.step_info:
+            step_info_description = f'Current step: {self.step_info.step_number + 1}/{self.step_info.max_steps}'
+        else:
+            step_info_description = ''
+
+        elements_text = self.state.element_tree.clickable_elements_to_string(include_attributes=self.include_attributes)
+        if not elements_text:
+            elements_text = 'empty page'
         state_description = f"""
-1. Task: {self.step_info.task}
-2. Hints(Optional):
-{self.step_info.add_infos}
-3. Memory:
-{self.step_info.memory}
-4. Task Progress:
-{self.step_info.task_progress}
-5. Current url: {self.state.url}
-6. Available tabs:
-{self.state.tabs}
-7. Interactive elements:
-{self.state.element_tree.clickable_elements_to_string(include_attributes=self.include_attributes)}
-"""
+{step_info_description}
+1. Task: {self.step_info.task}
+2. Hints(Optional):
+{self.step_info.add_infos}
+3. Memory:
+{self.step_info.memory}
+4. Current url: {self.state.url}
+5. Available tabs:
+{self.state.tabs}
+6. Interactive elements:
+{elements_text}
+"""

         if self.result:
             for i, result in enumerate(self.result):
@@ -202,4 +210,4 @@ class CustomAgentMessagePrompt:
             ]
         )

-        return HumanMessage(content=state_description)
+        return HumanMessage(content=state_description)
@@ -20,6 +20,7 @@ class CustomAgentStepInfo:
     add_infos: str
     memory: str
     task_progress: str
+    future_plans: str


 class CustomAgentBrain(BaseModel):
@@ -27,7 +28,8 @@ class CustomAgentBrain(BaseModel):

     prev_action_evaluation: str
     important_contents: str
-    completed_contents: str
+    task_progress: str
+    future_plans: str
     thought: str
     summary: str

src/utils/llm.py  (new file, 101 lines)
@@ -0,0 +1,101 @@
+from openai import OpenAI
+import pdb
+from langchain_openai import ChatOpenAI
+from langchain_core.globals import get_llm_cache
+from langchain_core.language_models.base import (
+    BaseLanguageModel,
+    LangSmithParams,
+    LanguageModelInput,
+)
+from langchain_core.load import dumpd, dumps
+from langchain_core.messages import (
+    AIMessage,
+    SystemMessage,
+    AnyMessage,
+    BaseMessage,
+    BaseMessageChunk,
+    HumanMessage,
+    convert_to_messages,
+    message_chunk_to_message,
+)
+from langchain_core.outputs import (
+    ChatGeneration,
+    ChatGenerationChunk,
+    ChatResult,
+    LLMResult,
+    RunInfo,
+)
+from langchain_core.output_parsers.base import OutputParserLike
+from langchain_core.runnables import Runnable, RunnableConfig
+from langchain_core.tools import BaseTool
+
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Literal,
+    Optional,
+    Union,
+    cast,
+)
+
+class DeepSeekR1ChatOpenAI(ChatOpenAI):
+
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super().__init__(*args, **kwargs)
+        self.client = OpenAI(
+            base_url=kwargs.get("base_url"),
+            api_key=kwargs.get("api_key")
+        )
+
+    async def ainvoke(
+        self,
+        input: LanguageModelInput,
+        config: Optional[RunnableConfig] = None,
+        *,
+        stop: Optional[list[str]] = None,
+        **kwargs: Any,
+    ) -> AIMessage:
+        message_history = []
+        for input_ in input:
+            if isinstance(input_, SystemMessage):
+                message_history.append({"role": "system", "content": input_.content})
+            elif isinstance(input_, AIMessage):
+                message_history.append({"role": "assistant", "content": input_.content})
+            else:
+                message_history.append({"role": "user", "content": input_.content})
+
+        response = self.client.chat.completions.create(
+            model=self.model_name,
+            messages=message_history
+        )
+
+        reasoning_content = response.choices[0].message.reasoning_content
+        content = response.choices[0].message.content
+        return AIMessage(content=content, reasoning_content=reasoning_content)
+
+    def invoke(
+        self,
+        input: LanguageModelInput,
+        config: Optional[RunnableConfig] = None,
+        *,
+        stop: Optional[list[str]] = None,
+        **kwargs: Any,
+    ) -> AIMessage:
+        message_history = []
+        for input_ in input:
+            if isinstance(input_, SystemMessage):
+                message_history.append({"role": "system", "content": input_.content})
+            elif isinstance(input_, AIMessage):
+                message_history.append({"role": "assistant", "content": input_.content})
+            else:
+                message_history.append({"role": "user", "content": input_.content})
+
+        response = self.client.chat.completions.create(
+            model=self.model_name,
+            messages=message_history
+        )
+
+        reasoning_content = response.choices[0].message.reasoning_content
+        content = response.choices[0].message.content
+        return AIMessage(content=content, reasoning_content=reasoning_content)
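A minimal usage sketch of the wrapper above, mirroring the test added later in this PR (the DEEPSEEK_ENDPOINT / DEEPSEEK_API_KEY environment variables are the ones used elsewhere in the diff):

import os
from langchain_core.messages import HumanMessage, SystemMessage
from src.utils.llm import DeepSeekR1ChatOpenAI

llm = DeepSeekR1ChatOpenAI(
    model="deepseek-reasoner",
    base_url=os.getenv("DEEPSEEK_ENDPOINT", ""),
    api_key=os.getenv("DEEPSEEK_API_KEY", ""),
)
msg = llm.invoke([
    SystemMessage(content="you are a helpful AI assistant"),
    HumanMessage(content="9.11 and 9.8, which is greater?"),
])
print(msg.reasoning_content)  # the reasoner's exposed chain of thought
print(msg.content)            # the final answer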
@@ -16,6 +16,8 @@ from langchain_ollama import ChatOllama
 from langchain_openai import AzureChatOpenAI, ChatOpenAI
 import gradio as gr

+from .llm import DeepSeekR1ChatOpenAI
+
 def get_llm_model(provider: str, **kwargs):
     """
     Get the LLM model
@@ -68,12 +70,20 @@ def get_llm_model(provider: str, **kwargs):
         else:
             api_key = kwargs.get("api_key")

-        return ChatOpenAI(
-            model=kwargs.get("model_name", "deepseek-chat"),
-            temperature=kwargs.get("temperature", 0.0),
-            base_url=base_url,
-            api_key=api_key,
-        )
+        if kwargs.get("model_name", "deepseek-chat") == "deepseek-reasoner":
+            return DeepSeekR1ChatOpenAI(
+                model=kwargs.get("model_name", "deepseek-reasoner"),
+                temperature=kwargs.get("temperature", 0.0),
+                base_url=base_url,
+                api_key=api_key,
+            )
+        else:
+            return ChatOpenAI(
+                model=kwargs.get("model_name", "deepseek-chat"),
+                temperature=kwargs.get("temperature", 0.0),
+                base_url=base_url,
+                api_key=api_key,
+            )
     elif provider == "gemini":
         if not kwargs.get("api_key", ""):
             api_key = os.getenv("GOOGLE_API_KEY", "")
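After this change, asking the deepseek provider for the reasoner model routes through DeepSeekR1ChatOpenAI instead of plain ChatOpenAI; for example:

from src.utils import utils

# returns DeepSeekR1ChatOpenAI; any other deepseek model name still returns ChatOpenAI
llm = utils.get_llm_model(
    provider="deepseek",
    model_name="deepseek-reasoner",
    temperature=0.8,
)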
@@ -114,7 +124,7 @@ def get_llm_model(provider: str, **kwargs):
     model_names = {
         "anthropic": ["claude-3-5-sonnet-20240620", "claude-3-opus-20240229"],
         "openai": ["gpt-4o", "gpt-4", "gpt-3.5-turbo"],
-        "deepseek": ["deepseek-chat"],
+        "deepseek": ["deepseek-chat", "deepseek-reasoner"],
         "gemini": ["gemini-2.0-flash-exp", "gemini-2.0-flash-thinking-exp", "gemini-1.5-flash-latest", "gemini-1.5-flash-8b-latest", "gemini-2.0-flash-thinking-exp-1219"],
         "ollama": ["qwen2.5:7b", "llama2:7b"],
         "azure_openai": ["gpt-4o", "gpt-4", "gpt-3.5-turbo"]
@@ -247,16 +247,16 @@ async def test_browser_use_custom_v2():
     #     api_key=os.getenv("GOOGLE_API_KEY", "")
     # )

-    # llm = utils.get_llm_model(
-    #     provider="deepseek",
-    #     model_name="deepseek-chat",
-    #     temperature=0.8
-    # )
-
     llm = utils.get_llm_model(
-        provider="ollama", model_name="qwen2.5:7b", temperature=0.5
+        provider="deepseek",
+        model_name="deepseek-reasoner",
+        temperature=0.8
     )

+    # llm = utils.get_llm_model(
+    #     provider="ollama", model_name="qwen2.5:7b", temperature=0.5
+    # )
+
     controller = CustomController()
     use_own_browser = False
     disable_security = True
@@ -114,6 +114,33 @@ def test_deepseek_model():
     ai_msg = llm.invoke([message])
     print(ai_msg.content)

+def test_deepseek_r1_model():
+    from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
+    from src.utils import utils
+
+    llm = utils.get_llm_model(
+        provider="deepseek",
+        model_name="deepseek-reasoner",
+        temperature=0.8,
+        base_url=os.getenv("DEEPSEEK_ENDPOINT", ""),
+        api_key=os.getenv("DEEPSEEK_API_KEY", "")
+    )
+    messages = []
+    sys_message = SystemMessage(
+        content=[{"type": "text", "text": "you are a helpful AI assistant"}]
+    )
+    messages.append(sys_message)
+    user_message = HumanMessage(
+        content=[
+            {"type": "text", "text": "9.11 and 9.8, which is greater?"}
+        ]
+    )
+    messages.append(user_message)
+    ai_msg = llm.invoke(messages)
+    print(ai_msg.reasoning_content)
+    print(ai_msg.content)
+    print(llm.model_name)
+    pdb.set_trace()
+
 def test_ollama_model():
     from langchain_ollama import ChatOllama
@@ -128,4 +155,5 @@ if __name__ == '__main__':
     # test_gemini_model()
     # test_azure_openai_model()
     # test_deepseek_model()
-    test_ollama_model()
+    # test_ollama_model()
+    test_deepseek_r1_model()