"""Chat model wrappers that surface DeepSeek-R1 reasoning content.

Provides subclasses of ChatOpenAI and ChatOllama that return AIMessage
objects carrying the model's chain-of-thought in ``reasoning_content``.
"""
import asyncio
import pdb
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Literal,
    Optional,
    Union,
    cast,
)

from langchain_core.globals import get_llm_cache
from langchain_core.language_models.base import (
    BaseLanguageModel,
    LangSmithParams,
    LanguageModelInput,
)
from langchain_core.load import dumpd, dumps
from langchain_core.messages import (
    AIMessage,
    AnyMessage,
    BaseMessage,
    BaseMessageChunk,
    HumanMessage,
    SystemMessage,
    convert_to_messages,
    message_chunk_to_message,
)
from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.outputs import (
    ChatGeneration,
    ChatGenerationChunk,
    ChatResult,
    LLMResult,
    RunInfo,
)
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_core.tools import BaseTool
from langchain_ollama import ChatOllama
from langchain_openai import ChatOpenAI
from openai import OpenAI
class DeepSeekR1ChatOpenAI(ChatOpenAI):
    """ChatOpenAI subclass for DeepSeek-R1 style endpoints.

    DeepSeek-R1 responses carry a ``reasoning_content`` field alongside the
    normal ``content``; the stock LangChain response path drops it.  This
    class calls the OpenAI SDK directly and attaches both fields to the
    returned :class:`AIMessage`.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Initialize the LangChain model plus a raw OpenAI SDK client.

        ``base_url``/``api_key`` fall back to the SDK's environment-variable
        resolution when not supplied in ``kwargs``.
        """
        super().__init__(*args, **kwargs)
        self.client = OpenAI(
            base_url=kwargs.get("base_url"),
            api_key=kwargs.get("api_key"),
        )

    @staticmethod
    def _to_openai_messages(input: LanguageModelInput) -> list[dict[str, Any]]:
        """Convert LangChain messages to OpenAI chat-completions dicts.

        Anything that is not a SystemMessage or AIMessage is treated as a
        user message, matching the original behavior.
        """
        history: list[dict[str, Any]] = []
        for message in input:
            if isinstance(message, SystemMessage):
                role = "system"
            elif isinstance(message, AIMessage):
                role = "assistant"
            else:
                role = "user"
            history.append({"role": role, "content": message.content})
        return history

    def _build_request(
        self,
        input: LanguageModelInput,
        stop: Optional[list[str]],
    ) -> dict[str, Any]:
        """Assemble kwargs for chat.completions.create; omit absent options."""
        request: dict[str, Any] = {
            "model": self.model_name,
            "messages": self._to_openai_messages(input),
        }
        # Previously `stop` was accepted but silently dropped; forward it.
        if stop is not None:
            request["stop"] = stop
        return request

    @staticmethod
    def _to_ai_message(response: Any) -> AIMessage:
        """Build an AIMessage carrying DeepSeek's reasoning_content.

        ``reasoning_content`` is a DeepSeek extension; use getattr so a
        provider that omits it yields None instead of AttributeError.
        """
        choice_message = response.choices[0].message
        return AIMessage(
            content=choice_message.content,
            reasoning_content=getattr(choice_message, "reasoning_content", None),
        )

    async def ainvoke(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> AIMessage:
        """Async invoke returning an AIMessage with reasoning_content.

        The sync SDK client would block the event loop if called directly
        (the original code did exactly that), so the call is offloaded to a
        worker thread.
        """
        request = self._build_request(input, stop)
        response = await asyncio.to_thread(
            self.client.chat.completions.create, **request
        )
        return self._to_ai_message(response)

    def invoke(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> AIMessage:
        """Sync invoke returning an AIMessage with reasoning_content."""
        request = self._build_request(input, stop)
        response = self.client.chat.completions.create(**request)
        return self._to_ai_message(response)
class DeepSeekR1ChatOllama(ChatOllama):
    """ChatOllama subclass that splits DeepSeek-R1 ``<think>`` blocks.

    Ollama returns the model's reasoning inline as ``<think>...</think>``
    followed by the answer.  This class separates the two, exposing the
    reasoning via the ``reasoning_content`` attribute of the returned
    :class:`AIMessage`.
    """

    @staticmethod
    def _split_reasoning(org_content: str) -> AIMessage:
        """Split ``<think>...</think>answer`` text into an AIMessage.

        The original implementation indexed ``split("</think>")[1]``, which
        raised IndexError whenever the tag was absent; partition() handles
        that case by treating the whole output as the answer.
        """
        reasoning_content, sep, content = org_content.partition("</think>")
        if not sep:
            # No closing tag: no separable reasoning block.
            reasoning_content, content = "", org_content
        reasoning_content = reasoning_content.replace("<think>", "")
        # Some responses prefix the JSON answer with a markdown header;
        # keep only what follows it.
        if "**JSON Response:**" in content:
            content = content.split("**JSON Response:**")[-1]
        return AIMessage(content=content, reasoning_content=reasoning_content)

    async def ainvoke(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> AIMessage:
        """Async invoke; forwards config/stop/kwargs (previously dropped)."""
        org_ai_message = await super().ainvoke(
            input=input, config=config, stop=stop, **kwargs
        )
        return self._split_reasoning(org_ai_message.content)

    def invoke(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> AIMessage:
        """Sync invoke; forwards config/stop/kwargs (previously dropped)."""
        org_ai_message = super().invoke(
            input=input, config=config, stop=stop, **kwargs
        )
        return self._split_reasoning(org_ai_message.content)