Files
AutoAgent/autoagent/registry.py
2025-02-12 21:56:40 +08:00

226 lines
7.8 KiB
Python

from typing import Callable, Dict, Any, Union, Literal, List, Optional
from dataclasses import dataclass, asdict
import inspect
import os
import functools
import tiktoken
MAX_OUTPUT_LENGTH = 12000
def encode_string_by_tiktoken(content: str, model_name: str = "gpt-4o"):
ENCODER = tiktoken.encoding_for_model(model_name)
tokens = ENCODER.encode(content)
return tokens
def decode_tokens_by_tiktoken(tokens: list[int], model_name: str = "gpt-4o"):
ENCODER = tiktoken.encoding_for_model(model_name)
content = ENCODER.decode(tokens)
return content
def truncate_output(output: str, max_length: int = MAX_OUTPUT_LENGTH) -> str:
"""Truncate output if it exceeds max_length"""
tokens = encode_string_by_tiktoken(output)
if len(tokens) > max_length:
return decode_tokens_by_tiktoken(tokens[:max_length]) + f"\n\n[TOOL WARNING] Output truncated, exceeded {max_length} tokens)\n[TOOL SUGGESTION] Maybe this tool with direct output is not an optimal choice, consider save the output to a file in the `workplace/` directory to implement the same functionality."
return output
@dataclass
class FunctionInfo:
name: str
func_name: str
func: Callable
args: List[str]
docstring: Optional[str]
body: str
return_type: Optional[str]
file_path: Optional[str]
def to_dict(self) -> dict:
# using asdict, but exclude func field because it cannot be serialized
d = asdict(self)
d.pop('func') # remove func field
return d
@classmethod
def from_dict(cls, data: dict) -> 'FunctionInfo':
# if you need to create an object from a dictionary
if 'func' not in data:
data['func'] = None # or other default value
return cls(**data)
class Registry:
_instance = None
_registry: Dict[str, Dict[str, Callable]] = {
"tools": {},
"agents": {},
"plugin_tools": {},
"plugin_agents": {},
"workflows": {}
}
_registry_info: Dict[str, Dict[str, FunctionInfo]] = {
"tools": {},
"agents": {},
"plugin_tools": {},
"plugin_agents": {},
"workflows": {}
}
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def register(self,
type: Literal["tool", "agent", "plugin_tool", "plugin_agent", "workflow"],
name: str = None,
func_name: str = None):
"""
统一的注册装饰器
Args:
type: 注册类型,"tool""agent"
name: 可选的注册名称
"""
def decorator(func: Callable):
nonlocal name
if name is None:
name = func.__name__
# if type == "agent" and name.startswith('get_'):
# name = name[4:] # 对 agent 移除 'get_' 前缀
# 获取函数的文件路径
if type == "plugin_tool":
original_func = func # 保存原始函数
@functools.wraps(original_func)
def wrapped_func(*args, **kwargs):
result = original_func(*args, **kwargs) # 调用原始函数
if isinstance(result, str):
return truncate_output(result)
return result
else:
wrapped_func = func
try:
file_path = os.path.abspath(inspect.getfile(func))
except:
file_path = "Unknown"
# 获取函数信息
signature = inspect.signature(func)
args = list(signature.parameters.keys())
docstring = inspect.getdoc(func)
# 获取函数体
source_lines = inspect.getsource(func)
# 移除装饰器和函数定义行
body_lines = source_lines.split('\n')[1:] # 跳过装饰器行
while body_lines and (body_lines[0].strip().startswith('@') or 'def ' in body_lines[0]):
body_lines = body_lines[1:]
body = '\n'.join(body_lines)
# 获取返回类型提示
return_type = None
if signature.return_annotation != inspect.Signature.empty:
return_type = str(signature.return_annotation)
# 创建函数信息对象
func_info = FunctionInfo(
name=name,
func_name=func_name,
func=wrapped_func if type == "plugin_tool" else func,
args=args,
docstring=docstring,
body=body,
return_type=return_type,
file_path=file_path # 添加文件路径
)
registry_type = f"{type}s"
self._registry[registry_type][func_name] = wrapped_func if type == "plugin_tool" else func
self._registry_info[registry_type][name] = func_info
return wrapped_func if type == "plugin_tool" else func
return decorator
@property
def tools(self) -> Dict[str, Callable]:
return self._registry["tools"]
@property
def agents(self) -> Dict[str, Callable]:
return self._registry["agents"]
@property
def plugin_tools(self) -> Dict[str, Callable]:
return self._registry["plugin_tools"]
@property
def plugin_agents(self) -> Dict[str, Callable]:
return self._registry["plugin_agents"]
@property
def workflows(self) -> Dict[str, Callable]:
return self._registry["workflows"]
@property
def tools_info(self) -> Dict[str, FunctionInfo]:
return self._registry_info["tools"]
@property
def agents_info(self) -> Dict[str, FunctionInfo]:
return self._registry_info["agents"]
@property
def plugin_tools_info(self) -> Dict[str, FunctionInfo]:
return self._registry_info["plugin_tools"]
@property
def plugin_agents_info(self) -> Dict[str, FunctionInfo]:
return self._registry_info["plugin_agents"]
@property
def workflows_info(self) -> Dict[str, FunctionInfo]:
return self._registry_info["workflows"]
@property
def display_plugin_tools_info(self):
display_info = {}
for name, info in self.plugin_tools_info.items():
tmp_info = info.to_dict().copy()
tmp_info.pop('func', None)
display_info[name] = tmp_info
return display_info
@property
def display_plugin_agents_info(self):
display_info = {}
for name, info in self.plugin_agents_info.items():
tmp_info = info.to_dict().copy()
tmp_info.pop('func', None)
display_info[name] = tmp_info
return display_info
@property
def display_workflows_info(self):
display_info = {}
for name, info in self.workflows_info.items():
tmp_info = info.to_dict().copy()
tmp_info.pop('func', None)
display_info[name] = tmp_info
return display_info
# 创建全局实例
registry = Registry()
# 便捷的注册函数
def register_tool(name: str = None):
func_name = name
return registry.register(type="tool", name=name, func_name=func_name)
def register_agent(name: str = None, func_name: str = None):
return registry.register(type="agent", name=name, func_name=func_name)
def register_plugin_tool(name: str = None):
func_name = name
return registry.register(type="plugin_tool", name=name, func_name=func_name)
def register_plugin_agent(name: str = None, func_name: str = None):
return registry.register(type="plugin_agent", name=name, func_name=func_name)
def register_workflow(name: str = None):
func_name = name
return registry.register(type="workflow", name=name, func_name=func_name)