fix gemini error

This commit is contained in:
tjb-tech
2025-02-28 15:22:26 +08:00
parent e9d3589ceb
commit 106c8694e0

View File

@@ -55,6 +55,44 @@ def should_retry_error(exception):
__CTX_VARS_NAME__ = "context_variables"
logger = LoggerManager.get_logger()
def adapt_tools_for_gemini(tools):
"""为 Gemini 模型适配工具定义,确保所有 OBJECT 类型参数都有非空的 properties"""
if tools is None:
return None
adapted_tools = []
for tool in tools:
adapted_tool = copy.deepcopy(tool)
# 检查参数
if "parameters" in adapted_tool["function"]:
params = adapted_tool["function"]["parameters"]
# 处理顶层参数
if params.get("type") == "object":
if "properties" not in params or not params["properties"]:
params["properties"] = {
"dummy": {
"type": "string",
"description": "Dummy property for Gemini compatibility"
}
}
# 处理嵌套参数
if "properties" in params:
for prop_name, prop in params["properties"].items():
if isinstance(prop, dict) and prop.get("type") == "object":
if "properties" not in prop or not prop["properties"]:
prop["properties"] = {
"dummy": {
"type": "string",
"description": "Dummy property for Gemini compatibility"
}
}
adapted_tools.append(adapted_tool)
return adapted_tools
class MetaChain:
def __init__(self, log_path: Union[str, None, MetaChainLogger] = None):
"""
@@ -68,12 +106,12 @@ class MetaChain:
self.logger = MetaChainLogger(log_path=log_path)
# if self.logger.log_path is None: self.logger.info("[Warning] Not specific log path, so log will not be saved", "...", title="Log Path", color="light_cyan3")
# else: self.logger.info("Log file is saved to", self.logger.log_path, "...", title="Log Path", color="light_cyan3")
@retry(
stop=stop_after_attempt(4),
wait=wait_exponential(multiplier=1, min=4, max=60),
retry=should_retry_error,
before_sleep=lambda retry_state: print(f"Retrying... (attempt {retry_state.attempt_number})")
)
# @retry(
# stop=stop_after_attempt(4),
# wait=wait_exponential(multiplier=1, min=4, max=60),
# retry=should_retry_error,
# before_sleep=lambda retry_state: print(f"Retrying... (attempt {retry_state.attempt_number})")
# )
def get_chat_completion(
self,
agent: Agent,
@@ -103,8 +141,12 @@ class MetaChain:
params["properties"].pop(__CTX_VARS_NAME__, None)
if __CTX_VARS_NAME__ in params["required"]:
params["required"].remove(__CTX_VARS_NAME__)
create_model = model_override or agent.model
if "gemini" in create_model.lower():
tools = adapt_tools_for_gemini(tools)
if FN_CALL:
create_model = model_override or agent.model
# create_model = model_override or agent.model
assert litellm.supports_function_calling(model = create_model) == True, f"Model {create_model} does not support function calling, please set `FN_CALL=False` to use non-function calling mode"
create_params = {
"model": create_model,
@@ -130,7 +172,7 @@ class MetaChain:
create_params["parallel_tool_calls"] = agent.parallel_tool_calls
completion_response = completion(**create_params)
else:
create_model = model_override or agent.model
# create_model = model_override or agent.model
assert agent.tool_choice == "required", f"Non-function calling mode MUST use tool_choice = 'required' rather than {agent.tool_choice}"
last_content = messages[-1]["content"]
tools_description = convert_tools_to_description(tools)