mirror of
https://github.com/HKUDS/AutoAgent.git
synced 2025-10-09 13:41:35 +03:00
fix gemini error
This commit is contained in:
@@ -55,6 +55,44 @@ def should_retry_error(exception):
|
||||
__CTX_VARS_NAME__ = "context_variables"
|
||||
logger = LoggerManager.get_logger()
|
||||
|
||||
def adapt_tools_for_gemini(tools):
|
||||
"""为 Gemini 模型适配工具定义,确保所有 OBJECT 类型参数都有非空的 properties"""
|
||||
if tools is None:
|
||||
return None
|
||||
|
||||
adapted_tools = []
|
||||
for tool in tools:
|
||||
adapted_tool = copy.deepcopy(tool)
|
||||
|
||||
# 检查参数
|
||||
if "parameters" in adapted_tool["function"]:
|
||||
params = adapted_tool["function"]["parameters"]
|
||||
|
||||
# 处理顶层参数
|
||||
if params.get("type") == "object":
|
||||
if "properties" not in params or not params["properties"]:
|
||||
params["properties"] = {
|
||||
"dummy": {
|
||||
"type": "string",
|
||||
"description": "Dummy property for Gemini compatibility"
|
||||
}
|
||||
}
|
||||
|
||||
# 处理嵌套参数
|
||||
if "properties" in params:
|
||||
for prop_name, prop in params["properties"].items():
|
||||
if isinstance(prop, dict) and prop.get("type") == "object":
|
||||
if "properties" not in prop or not prop["properties"]:
|
||||
prop["properties"] = {
|
||||
"dummy": {
|
||||
"type": "string",
|
||||
"description": "Dummy property for Gemini compatibility"
|
||||
}
|
||||
}
|
||||
|
||||
adapted_tools.append(adapted_tool)
|
||||
return adapted_tools
|
||||
|
||||
class MetaChain:
|
||||
def __init__(self, log_path: Union[str, None, MetaChainLogger] = None):
|
||||
"""
|
||||
@@ -68,12 +106,12 @@ class MetaChain:
|
||||
self.logger = MetaChainLogger(log_path=log_path)
|
||||
# if self.logger.log_path is None: self.logger.info("[Warning] Not specific log path, so log will not be saved", "...", title="Log Path", color="light_cyan3")
|
||||
# else: self.logger.info("Log file is saved to", self.logger.log_path, "...", title="Log Path", color="light_cyan3")
|
||||
@retry(
|
||||
stop=stop_after_attempt(4),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||
retry=should_retry_error,
|
||||
before_sleep=lambda retry_state: print(f"Retrying... (attempt {retry_state.attempt_number})")
|
||||
)
|
||||
# @retry(
|
||||
# stop=stop_after_attempt(4),
|
||||
# wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||
# retry=should_retry_error,
|
||||
# before_sleep=lambda retry_state: print(f"Retrying... (attempt {retry_state.attempt_number})")
|
||||
# )
|
||||
def get_chat_completion(
|
||||
self,
|
||||
agent: Agent,
|
||||
@@ -103,8 +141,12 @@ class MetaChain:
|
||||
params["properties"].pop(__CTX_VARS_NAME__, None)
|
||||
if __CTX_VARS_NAME__ in params["required"]:
|
||||
params["required"].remove(__CTX_VARS_NAME__)
|
||||
create_model = model_override or agent.model
|
||||
|
||||
if "gemini" in create_model.lower():
|
||||
tools = adapt_tools_for_gemini(tools)
|
||||
if FN_CALL:
|
||||
create_model = model_override or agent.model
|
||||
# create_model = model_override or agent.model
|
||||
assert litellm.supports_function_calling(model = create_model) == True, f"Model {create_model} does not support function calling, please set `FN_CALL=False` to use non-function calling mode"
|
||||
create_params = {
|
||||
"model": create_model,
|
||||
@@ -130,7 +172,7 @@ class MetaChain:
|
||||
create_params["parallel_tool_calls"] = agent.parallel_tool_calls
|
||||
completion_response = completion(**create_params)
|
||||
else:
|
||||
create_model = model_override or agent.model
|
||||
# create_model = model_override or agent.model
|
||||
assert agent.tool_choice == "required", f"Non-function calling mode MUST use tool_choice = 'required' rather than {agent.tool_choice}"
|
||||
last_content = messages[-1]["content"]
|
||||
tools_description = convert_tools_to_description(tools)
|
||||
|
||||
Reference in New Issue
Block a user