Files
AutoAgent/autoagent/util.py
tjb-tech efbf1e3bc8 update
2025-02-18 12:30:06 +08:00

541 lines
18 KiB
Python

import inspect
from datetime import datetime
import socket
import json
import uuid
from typing import Callable, List, Dict, Any, Optional, Callable, Union, get_args, get_origin
from dataclasses import is_dataclass, fields, MISSING
from pydantic import BaseModel
from rich.panel import Panel
from rich.prompt import Prompt
from rich.console import Console
import inquirer
from rich.markdown import Markdown
from prompt_toolkit import PromptSession
from prompt_toolkit.completion import Completer, Completion
from prompt_toolkit.formatted_text import HTML
from prompt_toolkit.styles import Style
def debug_print_swarm(debug: bool, *args: str) -> None:
if not debug:
return
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
message = " ".join(map(str, args))
print(f"\033[97m[\033[90m{timestamp}\033[97m]\033[90m {message}\033[0m")
def print_in_box(text: str, console: Optional[Console] = None, title: str = "", color: str = "white") -> None:
"""
Print the text in a box.
:param text: the text to print.
:param console: the console to print the text.
:param title: the title of the box.
:param color: the border color.
:return:
"""
console = console or Console()
# panel = Panel(text, title=title, border_style=color, expand=True, highlight=True)
# console.print(panel)
console.print('_'*20 + title + '_'*20, style=f"bold {color}")
console.print(text, highlight=True, emoji=True)
def debug_print(debug: bool, *args: str, **kwargs: dict) -> None:
if not debug:
return
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
message = "\n".join(map(str, args))
color = kwargs.get("color", "white")
title = kwargs.get("title", "")
log_str = f"[{timestamp}]\n{message}"
print_in_box(log_str, color=color, title=title)
log_path = kwargs.get("log_path", None)
if log_path:
with open(log_path, 'a') as f:
f.write(log_str + '\n')
def ask_text(question: str, title: str = "User", console: Optional[Console] = None, default_answer: str = "") -> str:
"""
Display a question in a panel and prompt the user for an answer.
:param question: the question to display.
:param title: the title of the panel.
:param console: the console to use.
:return: the user's answer.
"""
console = console or Console()
console.print(Panel(question, title=title, border_style="green"))
answer = Prompt.ask(f"Type your answer here, press Enter to use default answer", default=default_answer)
console.print(Panel(answer, title=title))
return answer
def print_markdown(md_path: str, console: Optional[Console] = None):
console = console or Console()
with open(md_path, 'r') as f:
md_content = f.read()
console.print(Markdown(md_content))
def single_select_menu(options, message: str = ""):
questions = [
inquirer.List(
'choice',
message=message,
choices=options,
),
]
answers = inquirer.prompt(questions)
return answers['choice']
def get_user_confirmation(prompt: str) -> bool:
user_input = prompt.strip().lower()
if user_input in ['y', 'yes', 'true', 't']:
return True
elif user_input in ['n', 'no', 'false', 'f'] or user_input == '':
return False
else:
print("Invalid input. Please enter 'y' for yes or 'n' for no.")
def merge_fields(target, source):
for key, value in source.items():
if isinstance(value, str):
target[key] += value
elif value is not None and isinstance(value, dict):
merge_fields(target[key], value)
def merge_chunk(final_response: dict, delta: dict) -> None:
delta.pop("role", None)
merge_fields(final_response, delta)
tool_calls = delta.get("tool_calls")
if tool_calls and len(tool_calls) > 0:
index = tool_calls[0].pop("index")
merge_fields(final_response["tool_calls"][index], tool_calls[0])
# def function_to_json(func) -> dict:
# """
# Converts a Python function into a JSON-serializable dictionary
# that describes the function's signature, including its name,
# description, and parameters.
# Args:
# func: The function to be converted.
# Returns:
# A dictionary representing the function's signature in JSON format.
# """
# type_map = {
# str: "string",
# int: "integer",
# float: "number",
# bool: "boolean",
# list: "array",
# dict: "object",
# type(None): "null",
# }
# try:
# signature = inspect.signature(func)
# except ValueError as e:
# raise ValueError(
# f"Failed to get signature for function {func.__name__}: {str(e)}"
# )
# parameters = {}
# for param in signature.parameters.values():
# try:
# param_type = type_map.get(param.annotation, "string")
# except KeyError as e:
# raise KeyError(
# f"Unknown type annotation {param.annotation} for parameter {param.name}: {str(e)}"
# )
# parameters[param.name] = {"type": param_type}
# required = [
# param.name
# for param in signature.parameters.values()
# if param.default == inspect._empty
# ]
# return {
# "type": "function",
# "function": {
# "name": func.__name__,
# "description": func.__doc__ or "",
# "parameters": {
# "type": "object",
# "properties": parameters,
# "required": required,
# },
# },
# }
def get_type_info(annotation, base_type_map):
# 处理基本类型
if annotation in base_type_map:
return {"type": base_type_map[annotation]}
# 处理typing类型
origin = get_origin(annotation)
if origin is not None:
args = get_args(annotation)
# 处理List类型
if origin is list or origin is List:
item_type = args[0]
return {
"type": "array",
"items": get_type_info(item_type, base_type_map)
}
# 处理Dict类型
elif origin is dict or origin is Dict:
key_type, value_type = args
if key_type != str:
raise ValueError("Dictionary keys must be strings")
# 如果value_type是TypedDict或Pydantic模型
if (hasattr(value_type, "__annotations__") or
(isinstance(value_type, type) and issubclass(value_type, BaseModel))):
return get_type_info(value_type, base_type_map)
# 普通Dict类型
return {
"type": "object",
"additionalProperties": get_type_info(value_type, base_type_map)
}
# 处理Union类型
elif origin is Union:
types = [get_type_info(arg, base_type_map) for arg in args if arg != type(None)]
if len(types) == 1:
return types[0]
return {"oneOf": types}
# 处理Pydantic模型
if isinstance(annotation, type):
try:
if issubclass(annotation, BaseModel):
schema = annotation.model_json_schema()
# 提取主要schema部分
properties = schema.get("properties", {})
required = schema.get("required", [])
# 处理definitions
definitions = schema.get("$defs", {})
if definitions:
# 如果有引用的定义,直接展开它们
for prop_name, prop_schema in properties.items():
if "$ref" in prop_schema:
ref_name = prop_schema["$ref"].split("/")[-1]
if ref_name in definitions:
properties[prop_name] = definitions[ref_name]
return {
"type": "object",
"properties": properties,
"required": required,
"additionalProperties": False
}
except TypeError:
pass
# 处理dataclass
if is_dataclass(annotation):
properties = {}
required = []
for field in fields(annotation):
properties[field.name] = get_type_info(field.type, base_type_map)
if field.default == field.default_factory == MISSING:
required.append(field.name)
return {
"type": "object",
"properties": properties,
"required": required,
"additionalProperties": False
}
# 处理TypedDict
if hasattr(annotation, "__annotations__"):
properties = {}
required = getattr(annotation, "__required_keys__", annotation.__annotations__.keys())
for key, field_type in annotation.__annotations__.items():
properties[key] = get_type_info(field_type, base_type_map)
return {
"type": "object",
"properties": properties,
"required": list(required),
"additionalProperties": False
}
# 默认返回string类型
return {"type": "string"}
def function_to_json(func) -> dict:
"""
Converts a Python function into a JSON-serializable dictionary
that describes the function's signature, including its name,
description, and parameters.
Args:
func: The function to be converted.
Returns:
A dictionary representing the function's signature in JSON format.
"""
type_map = {
str: "string",
int: "integer",
float: "number",
bool: "boolean",
# list: "array",
# dict: "object",
type(None): "null",
}
# def get_type_info(annotation):
# if hasattr(annotation, "__origin__"): # 处理typing类型
# origin = annotation.__origin__
# if origin is list: # 处理List类型
# item_type = annotation.__args__[0]
# return {
# "type": "array",
# "items": {
# "type": type_map.get(item_type, "string")
# }
# }
# elif origin is dict: # 处理Dict类型
# return {"type": "object"}
# return {"type": type_map.get(annotation, "string")}
try:
signature = inspect.signature(func)
except ValueError as e:
raise ValueError(
f"Failed to get signature for function {func.__name__}: {str(e)}"
)
parameters = {}
# for param in signature.parameters.values():
# try:
# param_type = type_map.get(param.annotation, "string")
# except KeyError as e:
# raise KeyError(
# f"Unknown type annotation {param.annotation} for parameter {param.name}: {str(e)}"
# )
# parameters[param.name] = {"type": param_type}
for param in signature.parameters.values():
if param.name == "context_variables":
continue
try:
param_info = get_type_info(param.annotation, type_map)
if isinstance(param_info, dict) and "additionalProperties" in param_info:
del param_info["additionalProperties"]
parameters[param.name] = get_type_info(param.annotation, type_map)
except KeyError as e:
raise KeyError(f"Unknown type annotation {param.annotation} for parameter {param.name}: {str(e)}")
required = [
param.name
for param in signature.parameters.values()
if param.default == inspect._empty
]
# if not parameters:
# parameters["dummy"] = {
# "type": "string",
# "description": "Dummy parameter (not used). Added to satisfy non-empty schema requirements."
# }
# required = []
return {
"type": "function",
"function": {
"name": func.__name__,
"description": func.__doc__ or "",
"parameters": {
"type": "object",
"properties": parameters,
"required": required,
},
},
}
def run_command_in_container_v1(command, stream_callback: Callable = None):
# TCP parameters
hostname = 'localhost'
port = 12345 # TCP port mapped to the container
buffer_size = 4096
# Create TCP client
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.connect((hostname, port))
s.sendall(command.encode())
full_response = b""
while True:
chunk = s.recv(buffer_size)
if not chunk:
break
full_response += chunk
if stream_callback:
stream_callback(chunk)
if len(chunk) < buffer_size:
# If the received data is less than the buffer size, it may have been received
break
# Decode the complete response
try:
decoded_response = full_response.decode('utf-8')
return json.loads(decoded_response)
except json.JSONDecodeError as e:
print(f"JSON parsing error: {e}")
print(f"Raw response received: {decoded_response}")
return {"status": -1, "result": "Response parsing error"}
def run_command_in_container(command, stream_callback=None):
"""
communicate with docker container and execute command, support stream output
Args:
command: the command to execute
stream_callback: optional callback function, for handling stream output
the function signature should be callback(text: str)
Returns:
dict: the complete JSON result returned by the docker container
"""
hostname = 'localhost'
port = 12345
buffer_size = 4096
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.connect((hostname, port))
s.sendall(command.encode())
partial_line = ""
while True:
chunk = s.recv(buffer_size)
if not chunk:
break
# add new received data to the unfinished data
data = partial_line + chunk.decode('utf-8')
lines = data.split('\n')
# except the last line, process all complete lines
for line in lines[:-1]:
if line:
try:
response = json.loads(line)
if response['type'] == 'chunk':
# process stream output
if stream_callback:
stream_callback(response['data'])
elif response['type'] == 'final':
# return the final result
return {
'status': response['status'],
'result': response['result']
}
except json.JSONDecodeError:
print(f"Invalid JSON: {line}")
# save the possibly unfinished last line
partial_line = lines[-1]
# if the loop ends normally without receiving a final response
return {
'status': -1,
'result': 'Connection closed without final response'
}
def make_tool_message(tools: Callable, args: dict, tool_content: str) -> List[Dict]:
tool_calls = [
{
"type": "function",
"function": {
"name": tools.__name__,
"arguments": json.dumps(args)
},
"id": str(uuid.uuid4()).replace('-', '')[:9]
}
]
return [
{'role': 'assistant', 'tool_calls': tool_calls},
{'role': 'tool', 'content': tool_content, 'name': tools.__name__, 'tool_call_id': tool_calls[0]['id']}
]
def make_message(role: str, content: str):
return [
{'role': role, 'content': content}
]
class UserCompleter(Completer):
def __init__(self, users: List[str]):
super().__init__()
self.users = users
def get_completions(self, document, complete_event):
word = document.get_word_before_cursor()
if word.startswith('@'):
prefix = word[1:] # 去掉@
for user in self.users:
if user.startswith(prefix):
yield Completion(
user,
start_position=-len(prefix),
style='fg:blue bold' # 蓝色加粗
)
def pretty_print_messages(message, **kwargs) -> None:
# for message in messages:
if message["role"] != "assistant" and message["role"] != "tool":
return
console = Console()
if message["role"] == "tool":
console.print("[bold blue]tool execution:[/bold blue]", end=" ")
console.print(f"[bold purple]{message['name']}[/bold purple], result: {message['content']}")
log_path = kwargs.get("log_path", None)
if log_path:
with open(log_path, 'a') as file:
file.write(f"tool execution: {message['name']}, result: {message['content']}\n")
return
# print agent name in blue
console.print(f"[bold blue]{message['sender']}[/bold blue]:", end=" ")
# print response, if any
if message["content"]:
console.print(message["content"], highlight=True, emoji=True)
# print tool calls in purple, if any
tool_calls = message.get("tool_calls") or []
if len(tool_calls) > 1:
console.print()
for tool_call in tool_calls:
f = tool_call["function"]
name, args = f["name"], f["arguments"]
arg_str = json.dumps(json.loads(args)).replace(":", "=")
console.print(f"[bold purple]{name}[/bold purple]({arg_str[1:-1]})")
log_path = kwargs.get("log_path", None)
if log_path:
with open(log_path, 'a') as file:
file.write(f"{message['sender']}: {message['content']}\n")
for tool_call in tool_calls:
f = tool_call["function"]
name, args = f["name"], f["arguments"]
arg_str = json.dumps(json.loads(args)).replace(":", "=")
file.write(f"{name}({arg_str[1:-1]})\n")