# Compare commits

24 commits
| SHA1 |
|---|
| 1eb10a3467 |
| b5533b036d |
| 39e00f25df |
| 183ef77a84 |
| a2a0d91b29 |
| 85853da263 |
| 487919476b |
| 1292ea7dc7 |
| 016ee62cc6 |
| 1b5fd996d3 |
| 799c6a32c4 |
| bdd6a9b4df |
| a53bad2077 |
| 54e5b615f4 |
| 73de598d42 |
| adefd97a27 |
| fdaacb87a4 |
| c42c884900 |
| 30f77daa07 |
| e46e59cc6a |
| 740c3f960b |
| 93162b27b1 |
| b49397aa93 |
| 06bedec5a9 |
**.readthedocs.yml**

```diff
@@ -2,7 +2,11 @@ version: 2
 
 formats: all
 
+build:
+  os: ubuntu-22.04
+  tools:
+    python: "3.10"
+
 python:
-    version: 3.7
-    install:
-      - requirements: requirements/docs.txt
+  install:
+    - requirements: requirements/docs.txt
```
**README.md**

````diff
@@ -83,7 +83,7 @@ Below is an example of running ReWOO with GPT-3.5
 ```python
 # Import necessary modules and classes from the "lagent" library.
 from lagent.agents import ReWOO
-from lagent.actions import ActionExecutor, GoogleSearch, LLMQA
+from lagent.actions import ActionExecutor, GoogleSearch
 from lagent.llms import GPTAPI
 
@@ -92,14 +92,11 @@ llm = GPTAPI(model_type='gpt-3.5-turbo', key=['Your OPENAI_API_KEY'])
 # Initialize the Google Search tool and provide your API key.
 search_tool = GoogleSearch(api_key='Your SERPER_API_KEY')
 
-# Initialize the LLMQA tool using the Language Model (llm).
-llmqa_tool = LLMQA(llm)
-
 # Create a chatbot by configuring the ReWOO agent.
 chatbot = ReWOO(
     llm=llm,  # Provide the Language Model instance.
     action_executor=ActionExecutor(
-        actions=[search_tool, llmqa_tool]  # Specify the actions the chatbot can perform.
+        actions=[search_tool]  # Specify the actions the chatbot can perform.
     ),
 )
 
@@ -154,6 +151,7 @@ response = chatbot.chat(
 print(response.response)  # Output the response generated by the chatbot.
 >>> $-\\frac{1}{3}+\\frac{{\\sqrt{3}}}{3}i$
 ```
+
 ### All Thanks To Our Contributors:
   <a href="https://github.com/InternLM/lagent/graphs/contributors">
   <img src="https://contrib.rocks/image?repo=InternLM/lagent" />
````

**docs/en/_templates/autoapi/index.rst** (new file, 14 lines)
```rst
API Reference
=============

This page contains auto-generated API reference documentation.

.. toctree::
   :titlesonly:
   :maxdepth: 3

   {% for page in pages %}
   {% if page.top_level_object and page.display %}
   {{ page.include_path }}
   {% endif %}
   {% endfor %}
```

**docs/en/_templates/autoapi/python/module.rst** (new file, 112 lines)
```rst
{% if not obj.display %}
:orphan:

{% endif %}
:py:mod:`{{ obj.name if obj.name.count(".") <= 1 else obj.short_name }}`
=========={{ "=" * (obj.name|length if obj.name.count(".") <= 1 else obj.short_name|length) }}

.. py:module:: {{ obj.name }}

{% if obj.docstring %}
.. autoapi-nested-parse::

   {{ obj.docstring|indent(3) }}

{% endif %}

{% block subpackages %}
{% set visible_subpackages = obj.subpackages|selectattr("display")|list %}
{% if visible_subpackages %}
Subpackages
-----------
.. toctree::
   :titlesonly:
   :maxdepth: 3

{% for subpackage in visible_subpackages %}
   {{ subpackage.short_name }}/index.rst
{% endfor %}


{% endif %}
{% endblock %}
{% block submodules %}
{% set visible_submodules = obj.submodules|selectattr("display")|list %}
{% if visible_submodules %}
Submodules
----------
.. toctree::
   :titlesonly:
   :maxdepth: 1

{% for submodule in visible_submodules %}
   {{ submodule.short_name }}/index.rst
{% endfor %}


{% endif %}
{% endblock %}
{% block content %}
{% if obj.type is equalto("package") %}
{% set visible_children = obj.children|selectattr("display")|list %}
{% else %}
{% set visible_children = obj.children|selectattr("display")|rejectattr("imported")|list %}
{% endif %}
{% if visible_children %}
{{ obj.type|title }} Contents
{{ "-" * obj.type|length }}---------

{% set visible_classes = visible_children|selectattr("type", "equalto", "class")|list %}
{% set visible_functions = visible_children|selectattr("type", "equalto", "function")|list %}
{% set visible_attributes = visible_children|selectattr("type", "equalto", "data")|list %}
{% if "show-module-summary" in autoapi_options and (visible_classes or visible_functions) %}
{% block classes scoped %}
{% if visible_classes %}
Classes
~~~~~~~

.. autoapisummary::

{% for klass in visible_classes %}
   {{ klass.id }}
{% endfor %}


{% endif %}
{% endblock %}

{% block functions scoped %}
{% if visible_functions %}
Functions
~~~~~~~~~

.. autoapisummary::

{% for function in visible_functions %}
   {{ function.id }}
{% endfor %}


{% endif %}
{% endblock %}

{% block attributes scoped %}
{% if visible_attributes %}
Attributes
~~~~~~~~~~

.. autoapisummary::

{% for attribute in visible_attributes %}
   {{ attribute.id }}
{% endfor %}


{% endif %}
{% endblock %}
{% endif %}
{% for obj_item in visible_children %}
{{ obj_item.render()|indent(0) }}
{% endfor %}
{% endif %}
{% endblock %}
```

**docs/en/conf.py**
```diff
@@ -11,17 +11,16 @@
 # documentation root, use os.path.abspath to make it absolute, like shown here.
 
 import os
+import re
 import sys
 
-import pytorch_sphinx_theme
-
-sys.path.insert(0, os.path.abspath('../../'))
+sys.path.insert(0, os.path.abspath('../..'))
 
 # -- Project information -----------------------------------------------------
 
 project = 'Lagent'
 copyright = '2020-2030, InternLM'
 author = 'InternLM'
+language = 'en'
 
 # The full version, including alpha/beta/rc tags
 version_file = '../../lagent/version.py'
@@ -36,97 +35,74 @@ release = __version__
 # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
 # ones.
 extensions = [
+    'sphinx_rtd_theme',
+    'myst_nb',
+    'autoapi.extension',
+    'sphinx_markdown_tables',
-    'sphinx.ext.autodoc',
-    'sphinx.ext.napoleon',
-    'sphinx.ext.viewcode',
-    'sphinx_markdown_tables',
-    'sphinx_copybutton',
-    'myst_parser',
-    'sphinx.ext.intersphinx',
-    'sphinx.ext.autodoc.typehints',
-    'sphinx.ext.autosummary',
-    'sphinx.ext.autosectionlabel',
-    'sphinx_tabs.tabs',
 ]
 
+nb_output_stderr = 'remove-warn'
 autodoc_typehints = 'description'
-autosummary_generate = True  # Turn on sphinx.ext.autosummary
 
 # Ignore >>> when copying code
 copybutton_prompt_text = r'>>> |\.\.\. '
 copybutton_prompt_is_regexp = True
 
 myst_enable_extensions = ['colon_fence']
+# sphinx-autoapi configuration
+autoapi_dirs = ['../../lagent']
+autoapi_options = [
+    'members',
+    'undoc-members',
+    'show-inheritance',
+    'show-module-summary',
+]
+autoapi_ignore = ['*migrations*', '*command.py', '*cli.py']
+autoapi_template_dir = '_templates/autoapi'
+autoapi_add_toctree_entry = False
 
 # Add any paths that contain templates here, relative to this directory.
 templates_path = ['_templates']
 
 # The suffix(es) of source filenames.
 # You can specify multiple suffix as a list of string:
 #
 source_suffix = {
     '.rst': 'restructuredtext',
     '.md': 'markdown',
 }
 
-# The master toctree document.
-master_doc = 'index'
-
 # List of patterns, relative to source directory, that match files and
 # directories to ignore when looking for source files.
 # This pattern also affects html_static_path and html_extra_path.
-exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
+exclude_patterns = []
 
 # -- Options for HTML output -------------------------------------------------
 
 # The theme to use for HTML and HTML Help pages.  See the documentation for
 # a list of builtin themes.
 #
-# html_theme = 'sphinx_rtd_theme'
-html_theme = 'pytorch_sphinx_theme'
-html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()]
+html_theme = 'sphinx_rtd_theme'
 html_theme_options = {
-    'menu': [
-        {
-            'name': 'GitHub',
-            'url': 'https://github.com/InternLM/lagent'
-        },
-    ],
-    # Specify the language of shared menu
-    'menu_lang': 'en'
+    'navigation_depth': 3,
+    'titles_only': False,
+    'style_nav_header_background': '#4fabab',
 }
 
-language = 'en'
+html_context = {
+    'display_github': True,
+    'github_host': 'github.com',
+    'github_user': 'InternLM',
+    'github_repo': 'lagent',
+    'github_version': 'main',
+    'conf_py_path': '/docs/en/',
+}
 html_title = 'Lagent'
 html_logo = '../imgs/lagent_logo.png'
 html_favicon = '../imgs/lagent_icon.png'
 
+master_doc = 'index'
+
 # Add any paths that contain custom static files (such as style sheets) here,
 # relative to this directory. They are copied after the builtin static files,
-# so a file named "default.css" will overwrite the builtin "default.css".
+# so a file named 'default.css' will overwrite the builtin 'default.css'.
 html_static_path = ['_static']
 
-html_css_files = [
-    'https://cdn.datatables.net/1.13.2/css/dataTables.bootstrap5.min.css',
-    'css/readthedocs.css'
-]
-html_js_files = [
-    'https://cdn.datatables.net/1.13.2/js/jquery.dataTables.min.js',
-    'https://cdn.datatables.net/1.13.2/js/dataTables.bootstrap5.min.js',
-    'js/collapsed.js',
-    'js/table.js',
-]
-
-myst_heading_anchors = 4
-
-intersphinx_mapping = {
-    'python': ('https://docs.python.org/3', None),
-    'numpy': ('https://numpy.org/doc/stable', None),
-    'torch': ('https://pytorch.org/docs/stable/', None),
-}
+def custom_skip(app, what, name, obj, skip, options):
+    if what in ['data', 'function', 'class'] and re.search('logger', name):
+        skip = True
+    return skip
 
 
-def builder_inited_handler(app):
-    pass
-
-
-def setup(app):
-    app.connect('builder-inited', builder_inited_handler)
+def setup(sphinx):
+    sphinx.connect('autoapi-skip-member', custom_skip)
```

**docs/en/get_started/install.md** (new file, 19 lines)
# Installation

## With pip

Install with pip (recommended):

```bash
pip install lagent
```

## From source

Optionally, you can also build Lagent from source in case you want to modify the code:

```bash
git clone https://github.com/InternLM/lagent.git
cd lagent
pip install -e .
```

**docs/en/get_started/overview.md**

```diff
@@ -1,4 +1,4 @@
-# OVERVIEW
+# Overview
 
 This chapter introduces you to the framework of Lagent, and provides links to detailed tutorials about Lagent.
 
@@ -18,6 +18,6 @@ Lagent consists of 3 main parts, agents, llms, and actions.
 
 Here is a detailed step-by-step guide to learn more about Lagent:
 
-1. For installation instructions, please see [README](../README.md).
+1. For installation instructions, please see [README](https://github.com/InternLM/lagent/blob/main/README.md).
 
-2. We provide several examples to build agents with Lagent in [examples](examples/) by simply run `python examples/react_example.py`.
+2. We provide several examples of building agents with Lagent in [examples](https://github.com/InternLM/lagent/tree/main/examples); simply run `python examples/react_example.py`.
```

**docs/en/get_started/quickstart.md** (new file, 89 lines)
# Quickstart

Using Lagent, you can easily build agents with just a few lines of code.

## Run a ReWOO agent with GPT-3.5

Below is an example of running ReWOO with GPT-3.5:

```python
# Import necessary modules and classes from the "lagent" library.
from lagent.agents import ReWOO
from lagent.actions import ActionExecutor, GoogleSearch
from lagent.llms import GPTAPI

# Initialize the Language Model (llm) and provide your API key.
llm = GPTAPI(model_type='gpt-3.5-turbo', key=['Your OPENAI_API_KEY'])

# Initialize the Google Search tool and provide your API key.
search_tool = GoogleSearch(api_key='Your SERPER_API_KEY')

# Create a chatbot by configuring the ReWOO agent.
chatbot = ReWOO(
    llm=llm,  # Provide the Language Model instance.
    action_executor=ActionExecutor(
        actions=[search_tool]  # Specify the actions the chatbot can perform.
    ),
)

# Ask the chatbot a question and store the response.
response = chatbot.chat('What profession does Nicholas Ray and Elia Kazan have in common')

# Print the chatbot's response.
print(response.response)  # Output the response generated by the chatbot.
```

```python
>>> Film director.
```

## Run a ReAct agent with InternLM

NOTE: If you want to run a HuggingFace model, please run `pip install -e .[all]` first.

```python
# Import necessary modules and classes from the "lagent" library.
from lagent.agents import ReAct
from lagent.actions import ActionExecutor, GoogleSearch, PythonInterpreter
from lagent.llms import HFTransformer

from lagent.llms.meta_template import INTERNLM2_META as META

# Initialize the HFTransformer-based Language Model (llm) and
# provide the model name.
llm = HFTransformer(path='internlm/internlm2-chat-7b', meta_template=META)

# Initialize the Google Search tool and provide your API key.
search_tool = GoogleSearch(api_key='Your SERPER_API_KEY')

# Initialize the Python Interpreter tool.
python_interpreter = PythonInterpreter()

# Create a chatbot by configuring the ReAct agent.
# Specify the actions the chatbot can perform.
chatbot = ReAct(
    llm=llm,  # Provide the Language Model instance.
    action_executor=ActionExecutor(
        actions=[search_tool, python_interpreter]),
)
# Ask the chatbot a mathematical question in LaTeX format.
response = chatbot.chat('若$z=-1+\sqrt{3}i$,则$\frac{z}{{z\overline{z}-1}}=\left(\ \ \right)$')

# Print the chatbot's response.
print(response.response)  # Output the response generated by the chatbot.
```

```python
>>> $-\\frac{1}{3}+\\frac{{\\sqrt{3}}}{3}i$
```

## Run the ReAct Web Demo

```bash
# You need to install streamlit first
# pip install streamlit
streamlit run examples/react_web_demo.py
```

Then you can chat through the UI shown below.
**docs/en/index.rst**

```diff
@@ -8,13 +8,31 @@ You can switch between English and Chinese in the lower-left corner of the layout.
    :caption: Get Started
 
    get_started/overview.md
+   get_started/install.md
+   get_started/quickstart.md
 
+.. toctree::
+   :maxdepth: 2
+   :caption: Tutorials
+
+   tutorials/action.md
+
 .. toctree::
    :caption: Switch Language
 
    switch_language.md
 
+.. toctree::
+   :maxdepth: 1
+   :caption: API Reference
+
+   autoapi/lagent/actions/index
+   autoapi/lagent/agents/index
+   autoapi/lagent/llms/index
+   autoapi/lagent/utils/index
+   autoapi/lagent/schema/index
+   autoapi/lagent/version/index
+
 
 Indices and tables
 ==================
```

**docs/en/tutorials/action.md** (new file, 396 lines)
							| @@ -0,0 +1,396 @@ | ||||
| # Action | ||||
|  | ||||
| Actions, also called **tools**, provide a suite of functions LLM-driven agents can use to interact with the real world and perform complex tasks. | ||||
|  | ||||
| ## Basic Concepts | ||||
|  | ||||
| ### Tool & Toolkit | ||||
|  | ||||
| There are two categories of tools: | ||||
|  | ||||
| * tool: provide only one API to call. | ||||
| * toolkit: implement multiple APIs that undertake different sub-tasks. | ||||
|  | ||||
| ### Tool Description | ||||
|  | ||||
| In Lagent, the tool description is a dictionary containing the action's core information of usage, observed by LLMs for decision-making. | ||||
|  | ||||
| For simple tools, the description can be created as follows | ||||
|  | ||||
| ```python | ||||
| TOOL_DESCRIPTION = { | ||||
|     'name': 'bold',  # name of the tool | ||||
|     'description': 'a function used to make text bold',  # introduce the tool's function | ||||
|     'parameters': [  # a list of parameters the tool take. | ||||
|         { | ||||
|             'name': 'text', 'type': 'STRING', 'description': 'input content' | ||||
|         } | ||||
|     ], | ||||
|     'required': ['text'],  # specify names of parameters required | ||||
| } | ||||
| ``` | ||||
|  | ||||
| In some situations there may be optional `return_data`, `parameter_description` keys describing the returns and argument passing format respectively. | ||||
|  | ||||
| ```{attention} | ||||
| `parameter_description` is usually inserted into the tool description automatically by the action's parser. It will be introduced in [Interface Design](#interface-design) . | ||||
| ``` | ||||
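
For reference, a description carrying both optional keys might look like the sketch below; the `return_data` entry here is illustrative, and the concrete shapes of both fields are shown in the sections that follow.

```python
TOOL_DESCRIPTION = {
    'name': 'bold',
    'description': 'a function used to make text bold',
    'parameters': [
        {'name': 'text', 'type': 'STRING', 'description': 'input content'}
    ],
    'required': ['text'],
    # optional: describes what the tool returns (illustrative entry)
    'return_data': [
        {'name': 'bold_text', 'description': 'bold text', 'type': 'STRING'}
    ],
    # optional: inserted automatically by the parser (see Interface Design)
    'parameter_description': '如果调用该工具,你必须使用Json格式 {key: value} 传参,其中key为参数名称',
}
```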

For toolkits, the description is very similar but nests sub-APIs:

```python
TOOL_DESCRIPTION = {
    'name': 'PhraseEmphasis',  # name of the toolkit
    'description': 'a toolkit which provides different styles of text emphasis',  # introduce the toolkit's function
    'api_list': [
        {
            'name': 'bold',
            'description': 'make text bold',
            'parameters': [
                {
                    'name': 'text', 'type': 'STRING', 'description': 'input content'
                }
            ],
            'required': ['text']
        },
        {
            'name': 'italic',
            'description': 'make text italic',
            'parameters': [
                {
                    'name': 'text', 'type': 'STRING', 'description': 'input content'
                }
            ],
            'required': ['text']
        }
    ]
}
```

## Make Functions into Tools

It is not necessary to prepare an extra description for a function that is already defined. Lagent provides a decorator `tool_api` that can conveniently turn a function into a tool: it automatically parses the function's type hints and docstring to generate the description dictionary and binds it to the attribute `api_description`.

```python
from lagent import tool_api

@tool_api
def bold(text: str) -> str:
    """make text bold

    Args:
        text (str): input text

    Returns:
        str: bold text
    """
    return '**' + text + '**'


bold.api_description
```

```python
{'name': 'bold',
 'description': 'make text bold',
 'parameters': [{'name': 'text',
   'type': 'STRING',
   'description': 'input text'}],
 'required': ['text']}
```

Once `returns_named_value` is enabled, you should declare the name of the return data, which will be processed to form a new field `return_data`:

```python
@tool_api(returns_named_value=True)
def bold(text: str) -> str:
    """make text bold

    Args:
        text (str): input text

    Returns:
        bold_text (str): bold text
    """
    return '**' + text + '**'

bold.api_description
```

```python
{'name': 'bold',
 'description': 'make text bold',
 'parameters': [{'name': 'text',
   'type': 'STRING',
   'description': 'input text'}],
 'required': ['text'],
 'return_data': [{'name': 'bold_text',
   'description': 'bold text',
   'type': 'STRING'}]}
```

Sometimes a tool may return a `dict` or `tuple`, and you may want to elaborate on each member in `return_data` rather than treat the value as a whole. Set `explode_return=True` and list the members in the Returns section of the docstring:

```python
@tool_api(explode_return=True)
def list_args(a: str, b: int, c: float = 0.0) -> dict:
    """Return arguments in dict format

    Args:
        a (str): a
        b (int): b
        c (float): c

    Returns:
        dict: input arguments
            - a (str): a
            - b (int): b
            - c: c
    """
    return {'a': a, 'b': b, 'c': c}

list_args.api_description
```

```python
{'name': 'list_args',
 'description': 'Return arguments in dict format',
 'parameters': [{'name': 'a', 'type': 'STRING', 'description': 'a'},
  {'name': 'b', 'type': 'NUMBER', 'description': 'b'},
  {'name': 'c', 'type': 'FLOAT', 'description': 'c'}],
 'required': ['a', 'b'],
 'return_data': [{'name': 'a', 'description': 'a', 'type': 'STRING'},
  {'name': 'b', 'description': 'b', 'type': 'NUMBER'},
  {'name': 'c', 'description': 'c'}]}
```

```{warning}
Only Google-style Python docstrings are currently supported.
```

## Interface Design

`BaseAction(description=None, parser=JsonParser, enable=True)` is the base class all actions should inherit from. It takes three initialization arguments:

* **description**: a tool description dictionary used to set the instance attribute `description`. You usually don't need to pass this argument explicitly, since the metaclass of `BaseAction` searches for methods decorated by `tool_api` and assembles their `api_description` into a class attribute `__tool_description__`; if the initial `description` is left null, `__tool_description__` is copied as `description`.
* **parser**: a `BaseParser` class. It is used to instantiate a parser that validates the arguments of the APIs in `description`.

    For example, `JsonParser` requires arguments to be passed in the format of JSON or a `dict`. To make LLMs aware of this, it inserts a field `parameter_description` into the `description`.

    ```python
    from lagent import BaseAction

    action = BaseAction(
        {
            'name': 'bold',
            'description': 'a function used to make text bold',
            'parameters': [
                {
                    'name': 'text', 'type': 'STRING', 'description': 'input content'
                }
            ],
            'required': ['text']
        }
    )
    action.description
    ```

    ```python
    {'name': 'bold',
     'description': 'a function used to make text bold',
     'parameters': [{'name': 'text',
     'type': 'STRING',
     'description': 'input content'}],
     'required': ['text'],
     'parameter_description': '如果调用该工具,你必须使用Json格式 {key: value} 传参,其中key为参数名称'}
    ```
* **enable**: specifies whether the tool is available.

### Custom Action

A simple tool must implement the `run` method, while the APIs of a toolkit should avoid naming conflicts with this reserved word.

```python
class Bold(BaseAction):

    @tool_api
    def run(self, text: str):
        """make text bold

        Args:
            text (str): input text

        Returns:
            str: bold text
        """
        return '**' + text + '**'

class PhraseEmphasis(BaseAction):
    """a toolkit which provides different styles of text emphasis"""

    @tool_api
    def bold(self, text):
        """make text bold

        Args:
            text (str): input text

        Returns:
            str: bold text
        """
        return '**' + text + '**'

    @tool_api
    def italic(self, text):
        """make text italic

        Args:
            text (str): input text

        Returns:
            str: italic text
        """
        return '*' + text + '*'

# Inspect the default description
# Bold.__tool_description__, PhraseEmphasis.__tool_description__
```
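
As a quick check, the tools defined above can be invoked directly. The following is a minimal usage sketch, assuming the default `JsonParser` and the `name` argument described later in [Tool Calling](#tool-calling):

```python
bold = Bold()
ret = bold('{"text": "hello"}')  # a simple tool dispatches to the reserved `run` API
print(ret.result)

emphasis = PhraseEmphasis()
ret = emphasis('{"text": "hello"}', name='italic')  # select a toolkit API by name
print(ret.result)
```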

### Auto-registration

Any subclass of `BaseAction` will be registered automatically. You can use `list_tools()` and `get_tool()` to view all tools and initialize one by name.

```python
from lagent import list_tools, get_tool

list_tools()
```

```python
['BaseAction',
 'InvalidAction',
 'NoAction',
 'FinishAction',
 'ArxivSearch',
 'BINGMap',
 'GoogleScholar',
 'GoogleSearch',
 'IPythonInterpreter',
 'PPT',
 'PythonInterpreter',
 'Bold',
 'PhraseEmphasis']
```

Create a `PhraseEmphasis` object:

```python
action = get_tool('PhraseEmphasis')
action.description
```

```python
{'name': 'PhraseEmphasis',
 'description': 'a toolkit which provides different styles of text emphasis',
 'api_list': [{'name': 'bold',
   'description': 'make text bold',
   'parameters': [{'name': 'text',
     'type': 'STRING',
     'description': 'input text'}],
   'required': ['text'],
   'parameter_description': '如果调用该工具,你必须使用Json格式 {key: value} 传参,其中key为参数名称'},
  {'name': 'italic',
   'description': 'make text italic',
   'parameters': [{'name': 'text',
     'type': 'STRING',
     'description': 'input text'}],
   'required': ['text'],
   'parameter_description': '如果调用该工具,你必须使用Json格式 {key: value} 传参,其中key为参数名称'}]}
```

## Tool Calling

### Run a Tool

The `__call__` method of an action takes two arguments:

* `inputs`: depends on the action's parser; often a string in a specific format generated by LLMs.
  + `JsonParser`: allows arguments to be passed as a JSON-format string or a Python `dict`.
  + `TupleParser`: allows arguments to be passed as a tuple-format string or a Python `tuple`.
* `name`: which API to call. Default is `run`.

It returns an `ActionReturn` object which encapsulates the calling details:

* `args`: dictionary of action inputs.
* `type`: action name.
* `result`: list of dicts, each containing two keys, 'type' and 'content'; when an error occurs, it is `None`.
* `errmsg`: error message. Default is `None`.

Below is an example:

```python
from lagent import IPythonInterpreter, TupleParser

action1 = IPythonInterpreter()
ret = action1('{"command": "import math;math.sqrt(100)"}')
print(ret.result)
ret = action1({'command': 'import math;math.sqrt(100)'})
print(ret.result)

action2 = IPythonInterpreter(parser=TupleParser)
ret = action2('("import math;math.sqrt(100)", )')
print(ret.result)
ret = action2(('import math;math.sqrt(100)',))
print(ret.result)
```

```python
[{'type': 'text', 'content': '10.0'}]
[{'type': 'text', 'content': '10.0'}]
[{'type': 'text', 'content': '10.0'}]
[{'type': 'text', 'content': '10.0'}]
```
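
The other fields listed above can be inspected on the same `ActionReturn` object; a brief sketch based on the field descriptions:

```python
print(ret.type)    # the action name, e.g. 'IPythonInterpreter'
print(ret.args)    # the parsed input arguments as a dict
print(ret.errmsg)  # None unless an error occurred
```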

### Dynamic Invocation

Lagent provides an `ActionExecutor` to manage multiple tools. It flattens the `api_list` of each toolkit and renames each API as `{tool_name}.{api_name}`.

```python
from lagent import ActionExecutor, ArxivSearch, IPythonInterpreter

executor = ActionExecutor(actions=[ArxivSearch(), IPythonInterpreter()])
executor.get_actions_info()  # This information is fed to LLMs as the tool meta prompt
```

```python
[{'name': 'ArxivSearch.get_arxiv_article_information',
  'description': 'Run Arxiv search and get the article meta information.',
  'parameters': [{'name': 'query',
    'type': 'STRING',
    'description': 'the content of search query'}],
  'required': ['query'],
  'return_data': [{'name': 'content',
    'description': 'a list of 3 arxiv search papers',
    'type': 'STRING'}],
  'parameter_description': '如果调用该工具,你必须使用Json格式 {key: value} 传参,其中key为参数名称'},
 {'name': 'IPythonInterpreter',
  'description': "When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files. Internet access for this session is disabled. Do not make external web requests or API calls as they will fail.",
  'parameters': [{'name': 'command',
    'type': 'STRING',
    'description': 'Python code'},
   {'name': 'timeout',
    'type': 'NUMBER',
    'description': 'Upper bound of waiting time for Python script execution.'}],
  'required': ['command'],
  'parameter_description': '如果调用该工具,你必须使用Json格式 {key: value} 传参,其中key为参数名称'}]
```

Trigger an action through the executor:

```python
ret = executor('IPythonInterpreter', '{"command": "import math;math.sqrt(100)"}')
ret.result
```

```python
[{'type': 'text', 'content': '10.0'}]
```

**docs/zh_cn/_templates/autoapi/index.rst** (new file, 14 lines)
```rst
API Reference
=============

This page contains auto-generated API reference documentation.

.. toctree::
   :titlesonly:
   :maxdepth: 3

   {% for page in pages %}
   {% if page.top_level_object and page.display %}
   {{ page.include_path }}
   {% endif %}
   {% endfor %}
```

**docs/zh_cn/_templates/autoapi/python/module.rst** (new file, 112 lines)
```rst
{% if not obj.display %}
:orphan:

{% endif %}
:py:mod:`{{ obj.name if obj.name.count(".") <= 1 else obj.short_name }}`
=========={{ "=" * (obj.name|length if obj.name.count(".") <= 1 else obj.short_name|length) }}

.. py:module:: {{ obj.name }}

{% if obj.docstring %}
.. autoapi-nested-parse::

   {{ obj.docstring|indent(3) }}

{% endif %}

{% block subpackages %}
{% set visible_subpackages = obj.subpackages|selectattr("display")|list %}
{% if visible_subpackages %}
Subpackages
-----------
.. toctree::
   :titlesonly:
   :maxdepth: 3

{% for subpackage in visible_subpackages %}
   {{ subpackage.short_name }}/index.rst
{% endfor %}


{% endif %}
{% endblock %}
{% block submodules %}
{% set visible_submodules = obj.submodules|selectattr("display")|list %}
{% if visible_submodules %}
Submodules
----------
.. toctree::
   :titlesonly:
   :maxdepth: 1

{% for submodule in visible_submodules %}
   {{ submodule.short_name }}/index.rst
{% endfor %}


{% endif %}
{% endblock %}
{% block content %}
{% if obj.type is equalto("package") %}
{% set visible_children = obj.children|selectattr("display")|list %}
{% else %}
{% set visible_children = obj.children|selectattr("display")|rejectattr("imported")|list %}
{% endif %}
{% if visible_children %}
{{ obj.type|title }} Contents
{{ "-" * obj.type|length }}---------

{% set visible_classes = visible_children|selectattr("type", "equalto", "class")|list %}
{% set visible_functions = visible_children|selectattr("type", "equalto", "function")|list %}
{% set visible_attributes = visible_children|selectattr("type", "equalto", "data")|list %}
{% if "show-module-summary" in autoapi_options and (visible_classes or visible_functions) %}
{% block classes scoped %}
{% if visible_classes %}
Classes
~~~~~~~

.. autoapisummary::

{% for klass in visible_classes %}
   {{ klass.id }}
{% endfor %}


{% endif %}
{% endblock %}

{% block functions scoped %}
{% if visible_functions %}
Functions
~~~~~~~~~

.. autoapisummary::

{% for function in visible_functions %}
   {{ function.id }}
{% endfor %}


{% endif %}
{% endblock %}

{% block attributes scoped %}
{% if visible_attributes %}
Attributes
~~~~~~~~~~

.. autoapisummary::

{% for attribute in visible_attributes %}
   {{ attribute.id }}
{% endfor %}


{% endif %}
{% endblock %}
{% endif %}
{% for obj_item in visible_children %}
{{ obj_item.render()|indent(0) }}
{% endfor %}
{% endif %}
{% endblock %}
```

**docs/zh_cn/conf.py**
```diff
@@ -11,18 +11,16 @@
 # documentation root, use os.path.abspath to make it absolute, like shown here.
 
 import os
-import subprocess
+import re
 import sys
 
-import pytorch_sphinx_theme
-
-sys.path.insert(0, os.path.abspath('../../'))
+sys.path.insert(0, os.path.abspath('../..'))
 
 # -- Project information -----------------------------------------------------
 
 project = 'Lagent'
 copyright = '2020-2030, InternLM'
 author = 'InternLM'
+language = 'zh_CN'
 
 # The full version, including alpha/beta/rc tags
 version_file = '../../lagent/version.py'
@@ -37,97 +35,74 @@ release = __version__
 # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
 # ones.
 extensions = [
+    'sphinx_rtd_theme',
+    'myst_nb',
+    'autoapi.extension',
+    'sphinx_markdown_tables',
-    'sphinx.ext.autodoc',
-    'sphinx.ext.napoleon',
-    'sphinx.ext.viewcode',
-    'sphinx_markdown_tables',
-    'sphinx_copybutton',
-    'myst_parser',
-    'sphinx.ext.intersphinx',
-    'sphinx.ext.autodoc.typehints',
-    'sphinx.ext.autosummary',
-    'sphinx.ext.autosectionlabel',
-    'sphinx_tabs.tabs',
 ]
 
+nb_output_stderr = 'remove-warn'
 autodoc_typehints = 'description'
 
-autosummary_generate = True  # Turn on sphinx.ext.autosummary
 # Ignore >>> when copying code
 copybutton_prompt_text = r'>>> |\.\.\. '
 copybutton_prompt_is_regexp = True
 
 myst_enable_extensions = ['colon_fence']
+# sphinx-autoapi configuration
+autoapi_dirs = ['../../lagent']
+autoapi_options = [
+    'members',
+    'undoc-members',
+    'show-inheritance',
+    'show-module-summary',
+]
+autoapi_ignore = ['*migrations*', '*command.py', '*cli.py']
+autoapi_template_dir = '_templates/autoapi'
+autoapi_add_toctree_entry = False
 
 # Add any paths that contain templates here, relative to this directory.
 templates_path = ['_templates']
 
 # The suffix(es) of source filenames.
 # You can specify multiple suffix as a list of string:
 #
 source_suffix = {
     '.rst': 'restructuredtext',
     '.md': 'markdown',
 }
 
-# The master toctree document.
-master_doc = 'index'
-
 # List of patterns, relative to source directory, that match files and
 # directories to ignore when looking for source files.
 # This pattern also affects html_static_path and html_extra_path.
-exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
+exclude_patterns = []
 
 # -- Options for HTML output -------------------------------------------------
 
 # The theme to use for HTML and HTML Help pages.  See the documentation for
 # a list of builtin themes.
 #
-# html_theme = 'sphinx_rtd_theme'
-html_theme = 'pytorch_sphinx_theme'
-html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()]
+html_theme = 'sphinx_rtd_theme'
 html_theme_options = {
-    'menu': [
-        {
-            'name': 'GitHub',
-            'url': 'https://github.com/InternLM/lagent'
-        },
-    ],
-    # Specify the language of shared menu
-    'menu_lang': 'cn',
+    'navigation_depth': 3,
+    'titles_only': False,
+    'style_nav_header_background': '#4fabab',
 }
 
-language = 'zh_CN'
+html_context = {
+    'display_github': True,
+    'github_host': 'github.com',
+    'github_user': 'InternLM',
+    'github_repo': 'lagent',
+    'github_version': 'main',
+    'conf_py_path': '/docs/en/',
+}
 html_title = 'Lagent'
 html_logo = '../imgs/lagent_logo.png'
 html_favicon = '../imgs/lagent_icon.png'
 
+master_doc = 'index'
+
 # Add any paths that contain custom static files (such as style sheets) here,
 # relative to this directory. They are copied after the builtin static files,
-# so a file named "default.css" will overwrite the builtin "default.css".
+# so a file named 'default.css' will overwrite the builtin 'default.css'.
 html_static_path = ['_static']
-html_css_files = [
-    'https://cdn.datatables.net/1.13.2/css/dataTables.bootstrap5.min.css',
-    'css/readthedocs.css'
-]
-html_js_files = [
-    'https://cdn.datatables.net/1.13.2/js/jquery.dataTables.min.js',
-    'https://cdn.datatables.net/1.13.2/js/dataTables.bootstrap5.min.js',
-    'js/collapsed.js',
-    'js/table.js',
-]
 
-myst_heading_anchors = 4
-
-# Configuration for intersphinx
-intersphinx_mapping = {
-    'python': ('https://docs.python.org/3', None),
-    'numpy': ('https://numpy.org/doc/stable', None),
-    'torch': ('https://pytorch.org/docs/stable/', None),
-}
-
-
-def builder_inited_handler(app):
-    subprocess.run(['./cp_origin_docs.sh'])
+def custom_skip(app, what, name, obj, skip, options):
+    if what in ['data', 'function', 'class'] and re.search('logger', name):
+        skip = True
+    return skip
 
 
-def setup(app):
-    app.connect('builder-inited', builder_inited_handler)
+def setup(sphinx):
+    sphinx.connect('autoapi-skip-member', custom_skip)
```

**docs/zh_cn/get_started/install.md** (new file, 19 lines)
# Installation

## Install with pip

Installing with pip is recommended:

```bash
pip install lagent
```

## Install from source

If you need to modify parts of the code, you can build Lagent from source:

```bash
git clone https://github.com/InternLM/lagent.git
cd lagent
pip install -e .
```

**docs/zh_cn/get_started/overview.md** (new file, 23 lines)
# Overview

This chapter introduces the architecture of Lagent and provides links to its detailed tutorials.

## What is Lagent

Lagent is an open-source LLM agent framework that lets users quickly turn a large language model into an agent, and it provides a set of typical tools to unlock the potential of large language models. The Lagent framework is illustrated below.

Lagent consists of three main modules: agents, llms, and actions.

- **agents** implements various agents, such as ReAct and AutoGPT.
- **llms** supports a variety of large language models, including open-source models hosted on HuggingFace (Llama-2, InternLM) and closed-source models such as GPT-3.5/4.
- **actions** contains a suite of tools and provides a tool executor to manage them uniformly.

## How to Use

Here are detailed tutorials to learn more about Lagent:

1. For installation, please refer to the [README](https://github.com/InternLM/lagent/blob/main/README.md).

2. Some examples of building agents are provided in [examples](https://github.com/InternLM/lagent/tree/main/examples); simply run the scripts, e.g. `python examples/react_example.py`.

**docs/zh_cn/get_started/quickstart.md** (new file, 87 lines)
# Quickstart

With Lagent, you can build a large language model agent with just a few lines of code.

## A GPT-3.5-driven ReWOO agent

Below is an example of running ReWOO with GPT-3.5:

```python
# Import the necessary modules and classes from Lagent.
from lagent.agents import ReWOO
from lagent.actions import ActionExecutor, GoogleSearch
from lagent.llms import GPTAPI

# Initialize the LLM; you may need to provide an API key.
llm = GPTAPI(model_type='gpt-3.5-turbo', key=['Your OPENAI_API_KEY'])

# Initialize the Google Search tool; you may need to provide an API key.
search_tool = GoogleSearch(api_key='Your SERPER_API_KEY')

# Configure the ReWOO agent to create a chatbot.
chatbot = ReWOO(
    llm=llm,  # the large language model instance
    action_executor=ActionExecutor(
        actions=[search_tool]  # the tools the agent may call
    ),
)

# Ask a question and get the reply.
response = chatbot.chat('What profession does Nicholas Ray and Elia Kazan have in common')

# Print the reply.
print(response.response)
```

```python
>>> Film director.
```

## An InternLM-driven ReAct agent

Note: if you want to use a HuggingFace model, please run `pip install -e .[all]` first.

```python
# Import the necessary modules and classes from Lagent.
from lagent.agents import ReAct
from lagent.actions import ActionExecutor, GoogleSearch, PythonInterpreter
from lagent.llms import HFTransformer

from lagent.llms.meta_template import INTERNLM2_META as META

# Initialize the HFTransformer model.
llm = HFTransformer(path='internlm/internlm2-chat-7b', meta_template=META)

# Initialize the Google Search tool; you may need to provide an API key.
search_tool = GoogleSearch(api_key='Your SERPER_API_KEY')

# Initialize the Python interpreter tool.
python_interpreter = PythonInterpreter()

# Configure the ReAct agent to create a chatbot.
chatbot = ReAct(
    llm=llm,  # the large language model instance
    action_executor=ActionExecutor(
        actions=[search_tool, python_interpreter]),  # the tools the agent may call
)
# Ask a math question in LaTeX format.
response = chatbot.chat('若$z=-1+\sqrt{3}i$,则$\frac{z}{{z\overline{z}-1}}=\left(\ \ \right)$')

# Print the reply.
print(response.response)
```

```python
>>> $-\\frac{1}{3}+\\frac{{\\sqrt{3}}}{3}i$
```

## Launch the ReAct web demo

```bash
# You need to install streamlit first
# pip install streamlit
streamlit run examples/react_web_demo.py
```

Then you can chat through the UI shown below.
**docs/zh_cn/index.rst**

```diff
@@ -8,12 +8,32 @@
    :caption: Get Started
 
    get_started/overview.md
+   get_started/install.md
+   get_started/quickstart.md
 
+.. toctree::
+   :maxdepth: 2
+   :caption: Tutorials
+
+   tutorials/action.md
+
 .. toctree::
    :caption: Switch Language
 
    switch_language.md
 
+.. toctree::
+   :maxdepth: 1
+   :caption: API Reference
+
+   autoapi/lagent/actions/index
+   autoapi/lagent/agents/index
+   autoapi/lagent/llms/index
+   autoapi/lagent/utils/index
+   autoapi/lagent/schema/index
+   autoapi/lagent/version/index
+
 
 Guide
 ==================
```

**docs/zh_cn/tutorials/action.md** (new file, 394 lines)
							| @@ -0,0 +1,394 @@ | ||||
| # 动作 | ||||
|  | ||||
| 动作,也被称为工具,提供了一套LLM驱动的智能体用来与真实世界交互并执行复杂任务的函数。 | ||||
|  | ||||
| ## 基本概念 | ||||
|  | ||||
| ### 工具 & 工具包 | ||||
|  | ||||
| 有两种类型的工具: | ||||
|  | ||||
| * 简单工具: 只提供一个API接口供调用。 | ||||
| * 工具包: 实现多个API接口,承担不同的子任务。 | ||||
|  | ||||
| ### 工具描述 | ||||
|  | ||||
| 在Lagent中,工具描述是一个刻画工具调用方式的字典,能够被LLM观察并用于决策。 | ||||
|  | ||||
| 对于简单工具,描述可按如下格式声明: | ||||
|  | ||||
| ```python | ||||
| TOOL_DESCRIPTION = { | ||||
|     'name': 'bold',  # 工具名称 | ||||
|     'description': 'a function used to make text bold',  # 介绍工具的功能 | ||||
|     'parameters': [  # 这个工具所需要的参数列表 | ||||
|         { | ||||
|             'name': 'text', 'type': 'STRING', 'description': 'input content' | ||||
|         } | ||||
|     ], | ||||
|     'required': ['text'],  # 指定必需的参数名 | ||||
| } | ||||
| ``` | ||||
| 在某些情况下,可能还包含 `return_data`,`parameter_description` 字段,分别描述返回内容及参数传递格式。 | ||||
|  | ||||
| ```{attention} | ||||
| `parameter_description` 通常被动作的解析器自动插入到工具描述中,这部分将在[接口设计](#id6)中进行介绍。 | ||||
| ``` | ||||
|  | ||||
| 对于工具包,描述非常相似,但嵌套了子方法 | ||||
|  | ||||
| ```python | ||||
| TOOL_DESCRIPTION = { | ||||
|     'name': 'PhraseEmphasis',  # 工具包的名字 | ||||
|     'description': 'a toolkit which provides different styles of text emphasis',  # 介绍工具包的功能 | ||||
|     'api_list': [ | ||||
|         { | ||||
|             'name': 'bold', | ||||
|             'description': 'make text bold', | ||||
|             'parameters': [ | ||||
|                 { | ||||
|                     'name': 'text', 'type': 'STRING', 'description': 'input content' | ||||
|                 } | ||||
|             ], | ||||
|             'required': ['text'] | ||||
|         }, | ||||
|         { | ||||
|             'name': 'italic', | ||||
|             'description': 'make text italic', | ||||
|             'parameters': [ | ||||
|                 { | ||||
|                     'name': 'text', 'type': 'STRING', 'description': 'input content' | ||||
|                 } | ||||
|             ], | ||||
|             'required': ['text'] | ||||
|         } | ||||
|     ] | ||||
| } | ||||
| ``` | ||||
|  | ||||
| ## 将函数转换为工具 | ||||
|  | ||||
| 对于已定义好的函数,无需人工添加额外的描述。在 Lagent 中,我们提供了一个修饰器 `tool_api`,它可以通过自动解析函数的类型提示和文档字符串来生成描述字典,并将其绑定到属性 `api_description`。 | ||||
|  | ||||
| ```python | ||||
| from lagent import tool_api | ||||
|  | ||||
| @tool_api | ||||
| def bold(text: str) -> str: | ||||
|     """make text bold | ||||
|  | ||||
|     Args: | ||||
|         text (str): input text | ||||
|  | ||||
|     Returns: | ||||
|         str: bold text | ||||
|     """ | ||||
|     return '**' + text + '**' | ||||
|  | ||||
|  | ||||
| bold.api_description | ||||
| ``` | ||||
|  | ||||
| ```python | ||||
| {'name': 'bold', | ||||
|  'description': 'make text bold', | ||||
|  'parameters': [{'name': 'text', | ||||
|    'type': 'STRING', | ||||
|    'description': 'input text'}], | ||||
|  'required': ['text']} | ||||
| ``` | ||||
|  | ||||
| 一旦启用 `returns_named_value`,您应当声明返回值的名称,这将被处理成一个新的字段 `return_data`: | ||||
|  | ||||
| ```python | ||||
| @tool_api(returns_named_value=True) | ||||
| def bold(text: str) -> str: | ||||
|     """make text bold | ||||
|  | ||||
|     Args: | ||||
|         text (str): input text | ||||
|  | ||||
|     Returns: | ||||
|         bold_text (str): bold text | ||||
|     """ | ||||
|     return '**' + text + '**' | ||||
|  | ||||
| bold.api_description | ||||
| ``` | ||||
|  | ||||
| ```python | ||||
| {'name': 'bold', | ||||
|  'description': 'make text bold', | ||||
|  'parameters': [{'name': 'text', | ||||
|    'type': 'STRING', | ||||
|    'description': 'input text'}], | ||||
|  'required': ['text'], | ||||
|  'return_data': [{'name': 'bold_text', | ||||
|    'description': 'bold text', | ||||
|    'type': 'STRING'}]} | ||||
| ``` | ||||
|  | ||||
| Sometimes a tool may return a `dict` or a `tuple`. If you want `return_data` to spell out the meaning of each member rather than treat them as a whole, set `explode_return=True` and enumerate them in the Returns section of the docstring. |
|  | ||||
| ```python | ||||
| @tool_api(explode_return=True) | ||||
| def list_args(a: str, b: int, c: float = 0.0) -> dict: | ||||
|     """Return arguments in dict format | ||||
|  | ||||
|     Args: | ||||
|         a (str): a | ||||
|         b (int): b | ||||
|         c (float): c | ||||
|  | ||||
|     Returns: | ||||
|         dict: input arguments | ||||
|             - a (str): a | ||||
|             - b (int): b | ||||
|             - c: c | ||||
|     """ | ||||
|     return {'a': a, 'b': b, 'c': c} | ||||
| ``` | ||||
|  | ||||
| ```python | ||||
| {'name': 'list_args', | ||||
|  'description': 'Return arguments in dict format', | ||||
|  'parameters': [{'name': 'a', 'type': 'STRING', 'description': 'a'}, | ||||
|   {'name': 'b', 'type': 'NUMBER', 'description': 'b'}, | ||||
|   {'name': 'c', 'type': 'FLOAT', 'description': 'c'}], | ||||
|  'required': ['a', 'b'], | ||||
|  'return_data': [{'name': 'a', 'description': 'a', 'type': 'STRING'}, | ||||
|   {'name': 'b', 'description': 'b', 'type': 'NUMBER'}, | ||||
|   {'name': 'c', 'description': 'c'}]} | ||||
| ``` | ||||
|  | ||||
| ```{warning} |
| Only Google-style Python docstrings are currently supported. |
| ``` |
|  | ||||
| ## Interface Design |
|  |
| `BaseAction(description=None, parser=JsonParser, enable=True)` is the base class all actions should inherit from. It accepts three initialization arguments: |
|  |
| * **description**: a dictionary describing the tool, used to set the instance attribute `description`. It usually does not need to be passed explicitly, because the metaclass of `BaseAction` looks up methods decorated with `tool_api` and assembles their `api_description` into a class attribute `__tool_description__`; if `description` is empty at instantiation, the instance attribute falls back to `__tool_description__`. |
| * **parser**: a `BaseParser` class used to instantiate an action parser that validates arguments against the tool described by `description`. For example, `JsonParser` requires the model to pass a JSON-format string or a Python dictionary when calling the tool; to make the LLM aware of this convention, it inserts a `parameter_description` field into `description`. |
|  | ||||
|     ```python | ||||
|     from lagent import BaseAction | ||||
|  | ||||
|     action = BaseAction( | ||||
|         { | ||||
|             'name': 'bold', | ||||
|             'description': 'a function used to make text bold', | ||||
|             'parameters': [ | ||||
|                 { | ||||
|                     'name': 'text', 'type': 'STRING', 'description': 'input content' | ||||
|                 } | ||||
|             ], | ||||
|             'required': ['text'] | ||||
|         } | ||||
|     ) | ||||
|     action.description | ||||
|     ``` | ||||
|  | ||||
|     ```python | ||||
|     {'name': 'bold', | ||||
|      'description': 'a function used to make text bold', | ||||
|      'parameters': [{'name': 'text', | ||||
|      'type': 'STRING', | ||||
|      'description': 'input content'}], | ||||
|      'required': ['text'], | ||||
|      'parameter_description': '如果调用该工具,你必须使用Json格式 {key: value} 传参,其中key为参数名称'} | ||||
|     ``` | ||||
|  | ||||
| * **enable**: specifies whether the action is available (see the sketch below). |
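|  |
|     A minimal sketch (ours, not from the original guide) of what `enable` controls: a disabled action can still be constructed, but the `ActionExecutor` introduced in the Tool Invocation section below hides it from the LLM and refuses to dispatch to it. |
|  |
|     ```python |
|     from lagent import ActionExecutor, BaseAction |
|  |
|     disabled = BaseAction( |
|         { |
|             'name': 'bold', |
|             'description': 'a function used to make text bold', |
|             'parameters': [ |
|                 {'name': 'text', 'type': 'STRING', 'description': 'input content'} |
|             ], |
|             'required': ['text'], |
|         }, |
|         enable=False, |
|     ) |
|     executor = ActionExecutor(actions=[disabled]) |
|     # Assumed behavior: disabled actions are skipped when assembling the |
|     # tool info for the system prompt and rejected at dispatch time. |
|     assert executor.get_actions_info() == [] |
|     assert not executor.is_valid('bold') |
|     ``` |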
|  | ||||
| ### Custom Action |
|  |
| A simple tool must implement a `run` method, while a toolkit should avoid naming any of its sub-APIs with this reserved word. |
|  | ||||
| ```python | ||||
| class Bold(BaseAction): | ||||
|      | ||||
|     @tool_api | ||||
|     def run(self, text: str): | ||||
|         """make text bold | ||||
|  | ||||
|         Args: | ||||
|             text (str): input text | ||||
|  | ||||
|         Returns: | ||||
|             str: bold text | ||||
|         """ | ||||
|         return '**' + text + '**' | ||||
|  | ||||
| class PhraseEmphasis(BaseAction): | ||||
|     """a toolkit which provides different styles of text emphasis""" | ||||
|  | ||||
|     @tool_api | ||||
|     def bold(self, text): | ||||
|         """make text bold | ||||
|  | ||||
|         Args: | ||||
|             text (str): input text | ||||
|  | ||||
|         Returns: | ||||
|             str: bold text | ||||
|         """ | ||||
|         return '**' + text + '**' | ||||
|  | ||||
|     @tool_api | ||||
|     def italic(self, text): | ||||
|         """make text italic | ||||
|  | ||||
|         Args: | ||||
|             text (str): input text | ||||
|  | ||||
|         Returns: | ||||
|             str: italic text | ||||
|         """ | ||||
|         return '*' + text + '*' | ||||
|  | ||||
| # Inspect the default tool descriptions |
| # Bold.__tool_description__, PhraseEmphasis.__tool_description__ | ||||
| ``` | ||||
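|  |
| As a quick check (a sketch of ours; the output format follows the Tool Invocation section below), these actions can be called directly, and a toolkit API is selected by its name: |
|  |
| ```python |
| bold = Bold() |
| ret = bold({'text': 'hello'})  # the default JsonParser also accepts a dict |
| print(ret.result)  # [{'type': 'text', 'content': '**hello**'}] |
|  |
| emphasis = PhraseEmphasis() |
| ret = emphasis({'text': 'hello'}, 'italic')  # dispatch to the `italic` API |
| print(ret.result)  # [{'type': 'text', 'content': '*hello*'}] |
| ``` |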
|  | ||||
| ### Automatic Registration |
|  |
| Any subclass of `BaseAction` is registered automatically. You can use `list_tools()` to view all tool classes and `get_tool()` to instantiate one by its name. |
|  | ||||
| ```python | ||||
| from lagent import list_tools, get_tool | ||||
|  | ||||
| list_tools() | ||||
| ``` | ||||
|  | ||||
| ```python | ||||
| ['BaseAction', | ||||
|  'InvalidAction', | ||||
|  'NoAction', | ||||
|  'FinishAction', | ||||
|  'ArxivSearch', | ||||
|  'BINGMap', | ||||
|  'GoogleScholar', | ||||
|  'GoogleSearch', | ||||
|  'IPythonInterpreter', | ||||
|  'PPT', | ||||
|  'PythonInterpreter', | ||||
|  'Bold', | ||||
|  'PhraseEmphasis'] | ||||
| ``` | ||||
|  | ||||
| Create a `PhraseEmphasis` object: |
|  | ||||
| ```python | ||||
| action = get_tool('PhraseEmphasis') | ||||
| action.description | ||||
| ``` | ||||
|  | ||||
| ```python | ||||
| {'name': 'PhraseEmphasis', | ||||
|  'description': 'a toolkit which provides different styles of text emphasis', | ||||
|  'api_list': [{'name': 'bold', | ||||
|    'description': 'make text bold', | ||||
|    'parameters': [{'name': 'text', | ||||
|      'type': 'STRING', | ||||
|      'description': 'input text'}], | ||||
|    'required': ['text'], | ||||
|    'parameter_description': '如果调用该工具,你必须使用Json格式 {key: value} 传参,其中key为参数名称'}, | ||||
|   {'name': 'italic', | ||||
|    'description': 'make text italic', | ||||
|    'parameters': [{'name': 'text', | ||||
|      'type': 'STRING', | ||||
|      'description': 'input text'}], | ||||
|    'required': ['text'], | ||||
|    'parameter_description': '如果调用该工具,你必须使用Json格式 {key: value} 传参,其中key为参数名称'}]} | ||||
| ``` | ||||
|  | ||||
|  | ||||
| ## Tool Invocation |
|  |
| ### Run a Tool |
|  |
| The `__call__` method of an `Action` takes two arguments: |
|  |
| * `inputs`: its type depends on the `BaseParser` bound to the action; it is usually a string generated by the large language model. |
|   + `JsonParser`: accepts a JSON-format string or a Python dictionary. |
|   + `TupleParser`: accepts a string whose literal is a tuple, or a Python tuple. |
| * `name`: which API to call. Defaults to `run`. |
|  |
| The tool returns an `ActionReturn` object that encapsulates the details of the call: |
|  |
| * `args`: a dictionary of the action's input arguments. |
| * `type`: the action name. |
| * `result`: a list of dictionaries, each with two keys, 'type' and 'content'; this field is `None` when an exception occurs. |
| * `errmsg`: the error message. Defaults to `None`. |
|  |
| Here is an example: |
|  | ||||
| ```python | ||||
| from lagent import IPythonInterpreter, TupleParser | ||||
|  | ||||
| action1 = IPythonInterpreter() | ||||
| ret = action1('{"command": "import math;math.sqrt(100)"}') | ||||
| print(ret.result) | ||||
| ret = action1({'command': 'import math;math.sqrt(100)'}) | ||||
| print(ret.result) | ||||
|  | ||||
| action2 = IPythonInterpreter(parser=TupleParser) | ||||
| ret = action2('("import math;math.sqrt(100)", )') | ||||
| print(ret.result) | ||||
| ret = action2(('import math;math.sqrt(100)',)) | ||||
| print(ret.result) | ||||
| ``` | ||||
|  | ||||
| ```python | ||||
| [{'type': 'text', 'content': '10.0'}] | ||||
| [{'type': 'text', 'content': '10.0'}] | ||||
| [{'type': 'text', 'content': '10.0'}] | ||||
| [{'type': 'text', 'content': '10.0'}] | ||||
| ``` | ||||
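|  |
| When parsing or dispatch fails, the same call returns an `ActionReturn` whose `errmsg` is set while `result` stays `None`. A minimal sketch (ours; the exact messages come from the parser and `BaseAction`): |
|  |
| ```python |
| ret = action1('not a JSON string')  # the JsonParser rejects malformed input |
| print(ret.errmsg)  # a parse error message |
| print(ret.result)  # None |
|  |
| ret = action1('{"command": "1 + 1"}', 'no_such_api') |
| print(ret.errmsg)  # invalid API: no_such_api |
| ``` |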
|  | ||||
| ### Dynamic Invocation |
|  |
| Lagent provides an `ActionExecutor` interface to manage multiple tools. It flattens the `api_list` of each toolkit and renames every API to `{tool_name}.{api_name}`. |
|  | ||||
| ```python | ||||
| from lagent import ActionExecutor, ArxivSearch, IPythonInterpreter | ||||
|  | ||||
| executor = ActionExecutor(actions=[ArxivSearch(), IPythonInterpreter()]) | ||||
| executor.get_actions_info()  # this result becomes part of the LLM's system prompt |
| ``` | ||||
|  | ||||
| ```python | ||||
| [{'name': 'ArxivSearch.get_arxiv_article_information', | ||||
|   'description': 'Run Arxiv search and get the article meta information.', | ||||
|   'parameters': [{'name': 'query', | ||||
|     'type': 'STRING', | ||||
|     'description': 'the content of search query'}], | ||||
|   'required': ['query'], | ||||
|   'return_data': [{'name': 'content', | ||||
|     'description': 'a list of 3 arxiv search papers', | ||||
|     'type': 'STRING'}], | ||||
|   'parameter_description': '如果调用该工具,你必须使用Json格式 {key: value} 传参,其中key为参数名称'}, | ||||
|  {'name': 'IPythonInterpreter', | ||||
|   'description': "When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files. Internet access for this session is disabled. Do not make external web requests or API calls as they will fail.", | ||||
|   'parameters': [{'name': 'command', | ||||
|     'type': 'STRING', | ||||
|     'description': 'Python code'}, | ||||
|    {'name': 'timeout', | ||||
|     'type': 'NUMBER', | ||||
|     'description': 'Upper bound of waiting time for Python script execution.'}], | ||||
|   'required': ['command'], | ||||
|   'parameter_description': '如果调用该工具,你必须使用Json格式 {key: value} 传参,其中key为参数名称'}] | ||||
| ``` | ||||
|  | ||||
| Trigger a tool through the action executor: |
|  | ||||
| ```python | ||||
| ret = executor('IPythonInterpreter', '{"command": "import math;math.sqrt(100)"}') | ||||
| ret.result | ||||
| ``` | ||||
|  | ||||
| ```python | ||||
| [{'type': 'text', 'content': '10.0'}] | ||||
| ``` | ||||
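|  |
| Toolkit APIs go through the same entry point using the flattened dotted name. A short sketch (ours; it issues a live arXiv query, so the actual content will vary): |
|  |
| ```python |
| ret = executor('ArxivSearch.get_arxiv_article_information', |
|                '{"query": "attention is all you need"}') |
| print(ret.result[0]['content'])  # metadata of up to 3 matching papers |
| ``` |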
							
								
								
									
examples/internlm2_agent_web_demo.py (new file, 333 lines)
							| @@ -0,0 +1,333 @@ | ||||
| import copy | ||||
| import hashlib | ||||
| import json | ||||
| import os | ||||
|  | ||||
| import streamlit as st | ||||
|  | ||||
| from lagent.actions import ActionExecutor, ArxivSearch, GoogleScholar, IPythonInterpreter | ||||
| from lagent.agents.internlm2_agent import (INTERPRETER_CN, META_INS, PLUGIN_CN, |
|                                            Internlm2Agent, Internlm2Protocol) |
| from lagent.llms.lmdepoly_wrapper import LMDeployClient | ||||
| from lagent.llms.meta_template import INTERNLM2_META as META | ||||
| from lagent.schema import AgentStatusCode | ||||
|  | ||||
| # from streamlit.logger import get_logger | ||||
|  | ||||
|  | ||||
| class SessionState: | ||||
|  | ||||
|     def init_state(self): | ||||
|         """Initialize session state variables.""" | ||||
|         st.session_state['assistant'] = [] | ||||
|         st.session_state['user'] = [] | ||||
|  | ||||
|         action_list = [ | ||||
|             GoogleScholar( | ||||
|                 api_key=('a558de7dee10146326ca86fbaa0736b' | ||||
|                          'dd947c9e646cd3f14da5aff177d6b2ff0')), | ||||
|             ArxivSearch(), | ||||
|         ] | ||||
|         st.session_state['plugin_map'] = { | ||||
|             action.name: action | ||||
|             for action in action_list | ||||
|         } | ||||
|         st.session_state['model_map'] = {} | ||||
|         st.session_state['model_selected'] = None | ||||
|         st.session_state['plugin_actions'] = set() | ||||
|         st.session_state['history'] = [] | ||||
|  | ||||
|     def clear_state(self): | ||||
|         """Clear the existing session state.""" | ||||
|         st.session_state['assistant'] = [] | ||||
|         st.session_state['user'] = [] | ||||
|         st.session_state['model_selected'] = None | ||||
|         st.session_state['file'] = set() | ||||
|         if 'chatbot' in st.session_state: | ||||
|             st.session_state['chatbot']._session_history = [] | ||||
|  | ||||
|  | ||||
| class StreamlitUI: | ||||
|  | ||||
|     def __init__(self, session_state: SessionState): | ||||
|         self.init_streamlit() | ||||
|         self.session_state = session_state | ||||
|  | ||||
|     def init_streamlit(self): | ||||
|         """Initialize Streamlit's UI settings.""" | ||||
|         st.set_page_config( | ||||
|             layout='wide', | ||||
|             page_title='lagent-web', | ||||
|             page_icon='./docs/imgs/lagent_icon.png') | ||||
|         st.header(':robot_face: :blue[Lagent] Web Demo ', divider='rainbow') | ||||
|         st.sidebar.title('模型控制') | ||||
|         st.session_state['file'] = set() | ||||
|         st.session_state['ip'] = None | ||||
|  | ||||
|     def setup_sidebar(self): | ||||
|         """Setup the sidebar for model and plugin selection.""" | ||||
|         model_name = st.sidebar.selectbox('模型选择:', options=['internlm']) | ||||
|         meta_prompt = st.sidebar.text_area('系统提示词', value=META_INS) | ||||
|         da_prompt = st.sidebar.text_area('数据分析提示词', value=INTERPRETER_CN) | ||||
|         plugin_prompt = st.sidebar.text_area('插件提示词', value=PLUGIN_CN) | ||||
|         model_ip = st.sidebar.text_input('模型IP:', value='10.140.0.220:23333') | ||||
|         if model_name != st.session_state[ | ||||
|                 'model_selected'] or st.session_state['ip'] != model_ip: | ||||
|             st.session_state['ip'] = model_ip | ||||
|             model = self.init_model(model_name, model_ip) | ||||
|             self.session_state.clear_state() | ||||
|             st.session_state['model_selected'] = model_name | ||||
|             if 'chatbot' in st.session_state: | ||||
|                 del st.session_state['chatbot'] | ||||
|         else: | ||||
|             model = st.session_state['model_map'][model_name] | ||||
|  | ||||
|         plugin_name = st.sidebar.multiselect( | ||||
|             '插件选择', | ||||
|             options=list(st.session_state['plugin_map'].keys()), | ||||
|             default=[], | ||||
|         ) | ||||
|         da_flag = st.sidebar.checkbox( | ||||
|             '数据分析', | ||||
|             value=False, | ||||
|         ) | ||||
|         plugin_action = [ | ||||
|             st.session_state['plugin_map'][name] for name in plugin_name | ||||
|         ] | ||||
|  | ||||
|         if 'chatbot' in st.session_state: | ||||
|             if len(plugin_action) > 0: | ||||
|                 st.session_state['chatbot']._action_executor = ActionExecutor( | ||||
|                     actions=plugin_action) | ||||
|             else: | ||||
|                 st.session_state['chatbot']._action_executor = None | ||||
|             if da_flag: | ||||
|                 st.session_state[ | ||||
|                     'chatbot']._interpreter_executor = ActionExecutor( | ||||
|                         actions=[IPythonInterpreter()]) | ||||
|             else: | ||||
|                 st.session_state['chatbot']._interpreter_executor = None | ||||
|             st.session_state['chatbot']._protocol._meta_template = meta_prompt | ||||
|             st.session_state['chatbot']._protocol.plugin_prompt = plugin_prompt | ||||
|             st.session_state[ | ||||
|                 'chatbot']._protocol.interpreter_prompt = da_prompt | ||||
|         if st.sidebar.button('清空对话', key='clear'): | ||||
|             self.session_state.clear_state() | ||||
|         uploaded_file = st.sidebar.file_uploader('上传文件') | ||||
|  | ||||
|         return model_name, model, plugin_action, uploaded_file, model_ip | ||||
|  | ||||
|     def init_model(self, option, ip=None): | ||||
|         """Initialize the model based on the selected option.""" | ||||
|         model_url = f'http://{ip}' | ||||
|         st.session_state['model_map'][option] = LMDeployClient( | ||||
|             path='internlm2-chat-20b', | ||||
|             url=model_url, | ||||
|             meta_template=META, | ||||
|             top_p=0.8, | ||||
|             top_k=100, | ||||
|             temperature=0, | ||||
|             repetition_penalty=1.0, | ||||
|             stop_words=['<|im_end|>']) | ||||
|         return st.session_state['model_map'][option] | ||||
|  | ||||
|     def initialize_chatbot(self, model, plugin_action): | ||||
|         """Initialize the chatbot with the given model and plugin actions.""" | ||||
|         return Internlm2Agent( | ||||
|             llm=model, | ||||
|             protocol=Internlm2Protocol( |
|                 tool=dict( | ||||
|                     begin='{start_token}{name}\n', | ||||
|                     start_token='<|action_start|>', | ||||
|                     name_map=dict( | ||||
|                         plugin='<|plugin|>', interpreter='<|interpreter|>'), | ||||
|                     belong='assistant', | ||||
|                     end='<|action_end|>\n', | ||||
|                 ), ), | ||||
|         ) | ||||
|  | ||||
|     def render_user(self, prompt: str): | ||||
|         with st.chat_message('user'): | ||||
|             st.markdown(prompt) | ||||
|  | ||||
|     def render_assistant(self, agent_return): | ||||
|         with st.chat_message('assistant'): | ||||
|             for action in agent_return.actions: | ||||
|                 if (action) and (action.type != 'FinishAction'): | ||||
|                     self.render_action(action) | ||||
|             st.markdown(agent_return.response) | ||||
|  | ||||
|     def render_plugin_args(self, action): | ||||
|         action_name = action.type | ||||
|         args = action.args | ||||
|         parameter_dict = dict(name=action_name, parameters=args) | ||||
|         parameter_str = '```json\n' + json.dumps( | ||||
|             parameter_dict, indent=4, ensure_ascii=False) + '\n```' | ||||
|         st.markdown(parameter_str) | ||||
|  | ||||
|     def render_interpreter_args(self, action): | ||||
|         st.info(action.type) | ||||
|         st.markdown(action.args['text']) | ||||
|  | ||||
|     def render_action(self, action): | ||||
|         st.markdown(action.thought) | ||||
|         if action.type == 'IPythonInterpreter': | ||||
|             self.render_interpreter_args(action) | ||||
|         elif action.type == 'FinishAction': | ||||
|             pass | ||||
|         else: | ||||
|             self.render_plugin_args(action) | ||||
|         self.render_action_results(action) | ||||
|  | ||||
|     def render_action_results(self, action): | ||||
|         """Render the results of action, including text, images, videos, and | ||||
|         audios.""" | ||||
|         if (isinstance(action.result, dict)): | ||||
|             if 'text' in action.result: | ||||
|                 st.markdown('```\n' + action.result['text'] + '\n```') | ||||
|             if 'image' in action.result: | ||||
|                 # image_path = action.result['image'] | ||||
|                 for image_path in action.result['image']: | ||||
|                     image_data = open(image_path, 'rb').read() | ||||
|                     st.image(image_data, caption='Generated Image') | ||||
|             if 'video' in action.result: | ||||
|                 video_data = action.result['video'] | ||||
|                 video_data = open(video_data, 'rb').read() | ||||
|                 st.video(video_data) | ||||
|             if 'audio' in action.result: | ||||
|                 audio_data = action.result['audio'] | ||||
|                 audio_data = open(audio_data, 'rb').read() | ||||
|                 st.audio(audio_data) | ||||
|         elif isinstance(action.result, list): | ||||
|             for item in action.result: | ||||
|                 if item['type'] == 'text': | ||||
|                     st.markdown('```\n' + item['content'] + '\n```') | ||||
|                 elif item['type'] == 'image': | ||||
|                     image_data = open(item['content'], 'rb').read() | ||||
|                     st.image(image_data, caption='Generated Image') | ||||
|                 elif item['type'] == 'video': | ||||
|                     video_data = open(item['content'], 'rb').read() | ||||
|                     st.video(video_data) | ||||
|                 elif item['type'] == 'audio': | ||||
|                     audio_data = open(item['content'], 'rb').read() | ||||
|                     st.audio(audio_data) | ||||
|         if action.errmsg: | ||||
|             st.error(action.errmsg) | ||||
|  | ||||
|  | ||||
| def main(): | ||||
|     # logger = get_logger(__name__) | ||||
|     # Initialize Streamlit UI and setup sidebar | ||||
|     if 'ui' not in st.session_state: | ||||
|         session_state = SessionState() | ||||
|         session_state.init_state() | ||||
|         st.session_state['ui'] = StreamlitUI(session_state) | ||||
|  | ||||
|     else: | ||||
|         st.set_page_config( | ||||
|             layout='wide', | ||||
|             page_title='lagent-web', | ||||
|             page_icon='./docs/imgs/lagent_icon.png') | ||||
|         st.header(':robot_face: :blue[Lagent] Web Demo ', divider='rainbow') | ||||
|     _, model, plugin_action, uploaded_file, _ = st.session_state[ | ||||
|         'ui'].setup_sidebar() | ||||
|  | ||||
|     # Initialize chatbot if it is not already initialized | ||||
|     # or if the model has changed | ||||
|     if 'chatbot' not in st.session_state or model != st.session_state[ | ||||
|             'chatbot']._llm: | ||||
|         st.session_state['chatbot'] = st.session_state[ | ||||
|             'ui'].initialize_chatbot(model, plugin_action) | ||||
|         st.session_state['session_history'] = [] | ||||
|  | ||||
|     for prompt, agent_return in zip(st.session_state['user'], | ||||
|                                     st.session_state['assistant']): | ||||
|         st.session_state['ui'].render_user(prompt) | ||||
|         st.session_state['ui'].render_assistant(agent_return) | ||||
|  | ||||
|     if user_input := st.chat_input(''): | ||||
|         with st.container(): | ||||
|             st.session_state['ui'].render_user(user_input) | ||||
|         st.session_state['user'].append(user_input) | ||||
|         # Add file uploader to sidebar | ||||
|         if (uploaded_file | ||||
|                 and uploaded_file.name not in st.session_state['file']): | ||||
|  | ||||
|             st.session_state['file'].add(uploaded_file.name) | ||||
|             file_bytes = uploaded_file.read() | ||||
|             file_type = uploaded_file.type | ||||
|             if 'image' in file_type: | ||||
|                 st.image(file_bytes, caption='Uploaded Image') | ||||
|             elif 'video' in file_type: | ||||
|                 st.video(file_bytes, caption='Uploaded Video') | ||||
|             elif 'audio' in file_type: | ||||
|                 st.audio(file_bytes, caption='Uploaded Audio') | ||||
|             # Save the file to a temporary location and get the path | ||||
|  | ||||
|             postfix = uploaded_file.name.split('.')[-1] | ||||
|             # prefix = str(uuid.uuid4()) | ||||
|             prefix = hashlib.md5(file_bytes).hexdigest() | ||||
|             filename = f'{prefix}.{postfix}' | ||||
|             file_path = os.path.join(root_dir, filename) | ||||
|             with open(file_path, 'wb') as tmpfile: | ||||
|                 tmpfile.write(file_bytes) | ||||
|             file_size = os.stat(file_path).st_size / 1024 / 1024 | ||||
|             file_size = f'{round(file_size, 2)} MB' | ||||
|             # st.write(f'File saved at: {file_path}') | ||||
|             user_input = [ | ||||
|                 dict(role='user', content=user_input), | ||||
|                 dict( | ||||
|                     role='user', | ||||
|                     content=json.dumps(dict(path=file_path, size=file_size)), | ||||
|                     name='file') | ||||
|             ] | ||||
|         if isinstance(user_input, str): | ||||
|             user_input = [dict(role='user', content=user_input)] | ||||
|         st.session_state['last_status'] = AgentStatusCode.SESSION_READY | ||||
|         for agent_return in st.session_state['chatbot'].stream_chat( | ||||
|                 st.session_state['session_history'] + user_input): | ||||
|             if agent_return.state == AgentStatusCode.PLUGIN_RETURN: | ||||
|                 with st.container(): | ||||
|                     st.session_state['ui'].render_plugin_args( | ||||
|                         agent_return.actions[-1]) | ||||
|                     st.session_state['ui'].render_action_results( | ||||
|                         agent_return.actions[-1]) | ||||
|             elif agent_return.state == AgentStatusCode.CODE_RETURN: | ||||
|                 with st.container(): | ||||
|                     st.session_state['ui'].render_action_results( | ||||
|                         agent_return.actions[-1]) | ||||
|             elif (agent_return.state == AgentStatusCode.STREAM_ING | ||||
|                   or agent_return.state == AgentStatusCode.CODING): | ||||
|                 # st.markdown(agent_return.response) | ||||
|                 # clear the placeholder's current content and show the new one |
|                 with st.container(): | ||||
|                     if agent_return.state != st.session_state['last_status']: | ||||
|                         st.session_state['temp'] = '' | ||||
|                         placeholder = st.empty() | ||||
|                         st.session_state['placeholder'] = placeholder | ||||
|                     if isinstance(agent_return.response, dict): | ||||
|                         action = f"\n\n {agent_return.response['name']}: \n\n" | ||||
|                         action_input = agent_return.response['parameters'] | ||||
|                         if agent_return.response['name'] == 'IPythonInterpreter': | ||||
|                             action_input = action_input['command'] | ||||
|                         response = action + action_input | ||||
|                     else: | ||||
|                         response = agent_return.response | ||||
|                     st.session_state['temp'] = response | ||||
|                     st.session_state['placeholder'].markdown( | ||||
|                         st.session_state['temp']) | ||||
|             elif agent_return.state == AgentStatusCode.END: | ||||
|                 st.session_state['session_history'] += (user_input + agent_return.inner_steps) | ||||
|                 agent_return = copy.deepcopy(agent_return) | ||||
|                 agent_return.response = st.session_state['temp'] | ||||
|                 st.session_state['assistant'].append( | ||||
|                     copy.deepcopy(agent_return)) | ||||
|             st.session_state['last_status'] = agent_return.state | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) | ||||
|     root_dir = os.path.join(root_dir, 'tmp_dir') | ||||
|     os.makedirs(root_dir, exist_ok=True) | ||||
|     main() | ||||
| @@ -1,11 +1,61 @@ | ||||
| from typing import Type | ||||
|  | ||||
| from .action_executor import ActionExecutor | ||||
| from .base_action import BaseAction | ||||
| from .arxiv_search import ArxivSearch | ||||
| from .base_action import TOOL_REGISTRY, BaseAction, tool_api | ||||
| from .bing_map import BINGMap | ||||
| from .builtin_actions import FinishAction, InvalidAction, NoAction | ||||
| from .google_scholar_search import GoogleScholar | ||||
| from .google_search import GoogleSearch | ||||
| from .llm_qa import LLMQA | ||||
| from .ipython_interpreter import IPythonInterpreter | ||||
| from .parser import BaseParser, JsonParser, TupleParser | ||||
| from .ppt import PPT | ||||
| from .python_interpreter import PythonInterpreter | ||||
|  | ||||
| __all__ = [ | ||||
|     'BaseAction', 'ActionExecutor', 'InvalidAction', 'NoAction', | ||||
|     'FinishAction', 'GoogleSearch', 'PythonInterpreter', 'LLMQA' | ||||
|     'BaseAction', 'ActionExecutor', 'InvalidAction', 'FinishAction', | ||||
|     'NoAction', 'BINGMap', 'ArxivSearch', 'GoogleSearch', |
|     'GoogleScholar', 'IPythonInterpreter', 'PythonInterpreter', 'PPT', | ||||
|     'BaseParser', 'JsonParser', 'TupleParser', 'tool_api', 'list_tools', | ||||
|     'get_tool_cls', 'get_tool' | ||||
| ] | ||||
|  | ||||
|  | ||||
| def list_tools(with_class: bool = False): | ||||
|     """List available tools | ||||
|  | ||||
|     Args: | ||||
|         with_class (bool): whether to return the action class along  | ||||
|             with its name. Defaults to ``False``. | ||||
|  | ||||
|     Returns: | ||||
|         list: all action names | ||||
|     """ | ||||
|     return list(TOOL_REGISTRY.items()) if with_class else list( | ||||
|         TOOL_REGISTRY.keys()) | ||||
|  | ||||
|  | ||||
| def get_tool_cls(specifier: str) -> Type[BaseAction]: | ||||
|     """Get the action class | ||||
|  | ||||
|     Args: | ||||
|         specifier (:class:`str`): tool name | ||||
|  | ||||
|     Returns: | ||||
|         Type[BaseAction]: action class | ||||
|     """ | ||||
|     return TOOL_REGISTRY.get_class(specifier) | ||||
|  | ||||
|  | ||||
| def get_tool(specifier: str, *args, **kwargs) -> BaseAction: | ||||
|     """Instantiate an action | ||||
|  | ||||
|     Args: | ||||
|         specifier (str): tool name | ||||
|         args: positional arguments passed to the action's ``__init__`` method | ||||
|         kwargs: keyword arguments passed to the action's ``__init__`` method | ||||
|  | ||||
|     Returns: | ||||
|         :class:`BaseAction`: action object | ||||
|     """ | ||||
|     return TOOL_REGISTRY.get(specifier, *args, **kwargs) | ||||
|   | ||||
| @@ -1,4 +1,4 @@ | ||||
| from typing import Any, Dict, List, Union | ||||
| from typing import Dict, List, Union | ||||
|  | ||||
| from lagent.schema import ActionReturn, ActionValidCode | ||||
| from .base_action import BaseAction | ||||
| @@ -39,14 +39,20 @@ class ActionExecutor: | ||||
|         self.no_action = no_action | ||||
|         self.finish_action = finish_action | ||||
|  | ||||
|     def get_actions_info(self, only_enable: bool = True) -> Dict: | ||||
|         if only_enable: | ||||
|             return { | ||||
|                 k: v.description | ||||
|                 for k, v in self.actions.items() if v.enable | ||||
|             } | ||||
|         else: | ||||
|             return {k: v.description for k, v in self.actions.items()} | ||||
|     def get_actions_info(self) -> List[Dict]: | ||||
|         actions = [] | ||||
|         for action_name, action in self.actions.items(): | ||||
|             if not action.enable: | ||||
|                 continue | ||||
|             if action.is_toolkit: | ||||
|                 for api in action.description['api_list']: | ||||
|                     api_desc = api.copy() | ||||
|                     api_desc['name'] = f"{action_name}.{api_desc['name']}" | ||||
|                     actions.append(api_desc) | ||||
|             else: | ||||
|                 action_desc = action.description.copy() | ||||
|                 actions.append(action_desc) | ||||
|         return actions | ||||
|  | ||||
|     def is_valid(self, name: str): | ||||
|         return name in self.actions and self.actions[name].enable | ||||
| @@ -66,19 +72,17 @@ class ActionExecutor: | ||||
|         if name in self.actions: | ||||
|             del self.actions[name] | ||||
|  | ||||
|     def __call__(self, name: str, command: Any) -> ActionReturn: | ||||
|         if isinstance(command, str): | ||||
|             args, kwargs = (command, ), {} | ||||
|         else: | ||||
|             args, kwargs = (), command | ||||
|         if not self.is_valid(name): | ||||
|     def __call__(self, name: str, command: str) -> ActionReturn: | ||||
|         action_name, api_name = ( | ||||
|             name.split('.') if '.' in name else (name, 'run')) | ||||
|         if not self.is_valid(action_name): | ||||
|             if name == self.no_action.name: | ||||
|                 action_return = self.no_action.run(*args, **kwargs) | ||||
|                 action_return = self.no_action(command) | ||||
|             elif name == self.finish_action.name: | ||||
|                 action_return = self.finish_action.run(*args, **kwargs) | ||||
|                 action_return = self.finish_action(command) | ||||
|             else: | ||||
|                 action_return = self.invalid_action(*args, **kwargs) | ||||
|                 action_return = self.invalid_action(command) | ||||
|         else: | ||||
|             action_return = self.actions[name].run(*args, **kwargs) | ||||
|             action_return = self.actions[action_name](command, api_name) | ||||
|             action_return.valid = ActionValidCode.OPEN | ||||
|         return action_return | ||||
|   | ||||
							
								
								
									
lagent/actions/arxiv_search.py (new file, 56 lines)
							| @@ -0,0 +1,56 @@ | ||||
| from typing import Optional, Type | ||||
|  | ||||
| import arxiv | ||||
|  | ||||
| from lagent.actions.base_action import BaseAction, tool_api | ||||
| from lagent.actions.parser import BaseParser, JsonParser | ||||
| from lagent.schema import ActionReturn, ActionStatusCode | ||||
|  | ||||
|  | ||||
| class ArxivSearch(BaseAction): | ||||
|     """Search information from Arxiv.org. \ | ||||
| Useful for when you need to answer questions about Physics, Mathematics, \ | ||||
| Computer Science, Quantitative Biology, Quantitative Finance, Statistics, \ | ||||
| Electrical Engineering, and Economics from scientific articles on arxiv.org. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, | ||||
|                  top_k_results: int = 3, | ||||
|                  max_query_len: int = 300, | ||||
|                  doc_content_chars_max: int = 1500, | ||||
|                  description: Optional[dict] = None, | ||||
|                  parser: Type[BaseParser] = JsonParser, | ||||
|                  enable: bool = True): | ||||
|         super().__init__(description, parser, enable) | ||||
|         self.top_k_results = top_k_results | ||||
|         self.max_query_len = max_query_len | ||||
|         self.doc_content_chars_max = doc_content_chars_max | ||||
|  | ||||
|     @tool_api(explode_return=True) | ||||
|     def get_arxiv_article_information(self, query: str) -> dict: | ||||
|         """Run Arxiv search and get the article meta information. | ||||
|  | ||||
|         Args: | ||||
|             query (:class:`str`): the content of search query | ||||
|  | ||||
|         Returns: | ||||
|             :class:`dict`: article information | ||||
|                 * content (str): a list of 3 arxiv search papers | ||||
|         """ | ||||
|         try: | ||||
|             results = arxiv.Search(  # type: ignore | ||||
|                 query[:self.max_query_len], | ||||
|                 max_results=self.top_k_results).results() | ||||
|         except Exception as exc: | ||||
|             return ActionReturn( | ||||
|                 errmsg=f'Arxiv exception: {exc}', | ||||
|                 state=ActionStatusCode.HTTP_ERROR) | ||||
|         docs = [ | ||||
|             f'Published: {result.updated.date()}\nTitle: {result.title}\n' | ||||
|             f'Authors: {", ".join(a.name for a in result.authors)}\n' | ||||
|             f'Summary: {result.summary[:self.doc_content_chars_max]}' | ||||
|             for result in results | ||||
|         ] | ||||
|         if docs: | ||||
|             return {'content': '\n\n'.join(docs)} | ||||
|         return {'content': 'No good Arxiv Result was found'} | ||||
| @@ -1,57 +1,364 @@ | ||||
| from typing import Optional | ||||
| import inspect | ||||
| import logging | ||||
| import re | ||||
| from abc import ABCMeta | ||||
| from copy import deepcopy | ||||
| from functools import wraps | ||||
| from typing import Annotated, Callable, Optional, Type, get_args, get_origin | ||||
|  | ||||
| from lagent.schema import ActionReturn | ||||
| from class_registry import AutoRegister, ClassRegistry | ||||
| from griffe import Docstring | ||||
| from griffe.enumerations import DocstringSectionKind | ||||
|  | ||||
| from ..schema import ActionReturn, ActionStatusCode | ||||
| from .parser import BaseParser, JsonParser, ParseError | ||||
|  | ||||
| logging.getLogger('griffe').setLevel(logging.ERROR) | ||||
|  | ||||
| TOOL_REGISTRY = ClassRegistry('__tool_name__', unique=True) | ||||
|  | ||||
|  | ||||
| class BaseAction: | ||||
| def tool_api(func: Optional[Callable] = None, | ||||
|              *, | ||||
|              explode_return: bool = False, | ||||
|              returns_named_value: bool = False, | ||||
|              **kwargs): | ||||
|     """Turn functions into tools. It will parse typehints as well as docstrings  | ||||
|     to build the tool description and attach it to functions via an attribute  | ||||
|     ``api_description``. | ||||
|      | ||||
|     Examples: | ||||
|      | ||||
|         .. code-block:: python | ||||
|              | ||||
|             # typehints has higher priority than docstrings | ||||
|             from typing import Annotated | ||||
|              | ||||
|             @tool_api | ||||
|             def add(a: Annotated[int, 'augend'], b: Annotated[int, 'addend'] = 1): | ||||
|                 '''Add operation | ||||
|                  | ||||
|                 Args: | ||||
|                     x (int): a | ||||
|                     y (int): b | ||||
|                 ''' | ||||
|                 return a + b | ||||
|              | ||||
|             print(add.api_description) | ||||
|              | ||||
|     Args: | ||||
|         func (Optional[Callable]): function to decorate. Defaults to ``None``. | ||||
|         explode_return (bool): whether to flatten the dictionary or tuple return  | ||||
|             as the ``return_data`` field. When enabled, it is recommended to  | ||||
|             annotate the member in docstrings. Defaults to ``False``. | ||||
|              | ||||
|             .. code-block:: python | ||||
|                  | ||||
|                 @tool_api(explode_return=True) | ||||
|                 def foo(a, b): | ||||
|                     '''A simple function | ||||
|                      | ||||
|                     Args: | ||||
|                         a (int): a | ||||
|                         b (int): b | ||||
|                      | ||||
|                     Returns: | ||||
|                         dict: information of inputs | ||||
|                             * x: value of a | ||||
|                             * y: value of b | ||||
|                     ''' | ||||
|                     return {'x': a, 'y': b} | ||||
|                      | ||||
|                 print(foo.api_description) | ||||
|              | ||||
|         returns_named_value (bool): whether to parse ``thing: Description`` in  | ||||
|             returns sections as a name and description, rather than a type and  | ||||
|             description. When true, type must be wrapped in parentheses:  | ||||
|             ``(int): Description``. When false, parentheses are optional but  | ||||
|             the items cannot be named: ``int: Description``. Defaults to ``False``. | ||||
|              | ||||
|     Returns: | ||||
|         Callable: wrapped function or partial decorator | ||||
|  | ||||
|     Important: | ||||
|         ``return_data`` field will be added to ``api_description`` only | ||||
|         when ``explode_return`` or ``returns_named_value`` is enabled. | ||||
|     """ | ||||
|  | ||||
|     def _detect_type(string): | ||||
|         field_type = 'STRING' | ||||
|         if 'list' in string: | ||||
|             field_type = 'Array' | ||||
|         elif 'str' not in string: | ||||
|             if 'float' in string: | ||||
|                 field_type = 'FLOAT' | ||||
|             elif 'int' in string: | ||||
|                 field_type = 'NUMBER' | ||||
|             elif 'bool' in string: | ||||
|                 field_type = 'BOOLEAN' | ||||
|         return field_type | ||||
|  | ||||
|     def _explode(desc): | ||||
|         kvs = [] | ||||
|         desc = '\nArgs:\n' + '\n'.join([ | ||||
|             '    ' + item.lstrip(' -+*#.') | ||||
|             for item in desc.split('\n')[1:] if item.strip() | ||||
|         ]) | ||||
|         docs = Docstring(desc).parse('google') | ||||
|         if not docs: | ||||
|             return kvs | ||||
|         if docs[0].kind is DocstringSectionKind.parameters: | ||||
|             for d in docs[0].value: | ||||
|                 d = d.as_dict() | ||||
|                 if not d['annotation']: | ||||
|                     d.pop('annotation') | ||||
|                 else: | ||||
|                     d['type'] = _detect_type(d.pop('annotation').lower()) | ||||
|                 kvs.append(d) | ||||
|         return kvs | ||||
|  | ||||
|     def _parse_tool(function): | ||||
|         # remove rst syntax | ||||
|         docs = Docstring( | ||||
|             re.sub(':(.+?):`(.+?)`', '\\2', function.__doc__ or '')).parse( | ||||
|                 'google', returns_named_value=returns_named_value, **kwargs) | ||||
|         desc = dict( | ||||
|             name=function.__name__, | ||||
|             description=docs[0].value | ||||
|             if docs[0].kind is DocstringSectionKind.text else '', | ||||
|             parameters=[], | ||||
|             required=[], | ||||
|         ) | ||||
|         args_doc, returns_doc = {}, [] | ||||
|         for doc in docs: | ||||
|             if doc.kind is DocstringSectionKind.parameters: | ||||
|                 for d in doc.value: | ||||
|                     d = d.as_dict() | ||||
|                     d['type'] = _detect_type(d.pop('annotation').lower()) | ||||
|                     args_doc[d['name']] = d | ||||
|             if doc.kind is DocstringSectionKind.returns: | ||||
|                 for d in doc.value: | ||||
|                     d = d.as_dict() | ||||
|                     if not d['name']: | ||||
|                         d.pop('name') | ||||
|                     if not d['annotation']: | ||||
|                         d.pop('annotation') | ||||
|                     else: | ||||
|                         d['type'] = _detect_type(d.pop('annotation').lower()) | ||||
|                     returns_doc.append(d) | ||||
|  | ||||
|         sig = inspect.signature(function) | ||||
|         for name, param in sig.parameters.items(): | ||||
|             if name == 'self': | ||||
|                 continue | ||||
|             parameter = dict( | ||||
|                 name=param.name, | ||||
|                 type='STRING', | ||||
|                 description=args_doc.get(param.name, | ||||
|                                          {}).get('description', '')) | ||||
|             annotation = param.annotation | ||||
|             if annotation is inspect.Signature.empty: | ||||
|                 parameter['type'] = args_doc.get(param.name, | ||||
|                                                  {}).get('type', 'STRING') | ||||
|             else: | ||||
|                 if get_origin(annotation) is Annotated: | ||||
|                     annotation, info = get_args(annotation) | ||||
|                     if info: | ||||
|                         parameter['description'] = info | ||||
|                 while get_origin(annotation): | ||||
|                     annotation = get_args(annotation) | ||||
|                 parameter['type'] = _detect_type(str(annotation)) | ||||
|             desc['parameters'].append(parameter) | ||||
|             if param.default is inspect.Signature.empty: | ||||
|                 desc['required'].append(param.name) | ||||
|  | ||||
|         return_data = [] | ||||
|         if explode_return: | ||||
|             return_data = _explode(returns_doc[0]['description']) | ||||
|         elif returns_named_value: | ||||
|             return_data = returns_doc | ||||
|         if return_data: | ||||
|             desc['return_data'] = return_data | ||||
|         return desc | ||||
|  | ||||
|     if callable(func): | ||||
|  | ||||
|         @wraps(func) | ||||
|         def wrapper(self, *args, **kwargs): | ||||
|             return func(self, *args, **kwargs) | ||||
|  | ||||
|         wrapper.api_description = _parse_tool(func) | ||||
|         return wrapper | ||||
|  | ||||
|     def decorate(func): | ||||
|  | ||||
|         @wraps(func) | ||||
|         def wrapper(self, *args, **kwargs): | ||||
|             return func(self, *args, **kwargs) | ||||
|  | ||||
|         wrapper.api_description = _parse_tool(func) | ||||
|         return wrapper | ||||
|  | ||||
|     return decorate | ||||
|  | ||||
|  | ||||
| class ToolMeta(ABCMeta): | ||||
|     """Metaclass of tools""" | ||||
|  | ||||
|     def __new__(mcs, name, base, attrs): | ||||
|         is_toolkit, tool_desc = True, dict( | ||||
|             name=attrs.setdefault('__tool_name__', name), | ||||
|             description=Docstring(attrs.get('__doc__', | ||||
|                                             '')).parse('google')[0].value) | ||||
|         for key, value in attrs.items(): | ||||
|             if callable(value) and hasattr(value, 'api_description'): | ||||
|                 api_desc = getattr(value, 'api_description') | ||||
|                 if key == 'run': | ||||
|                     tool_desc['parameters'] = api_desc['parameters'] | ||||
|                     tool_desc['required'] = api_desc['required'] | ||||
|                     if api_desc['description']: | ||||
|                         tool_desc['description'] = api_desc['description'] | ||||
|                     if api_desc.get('return_data'): | ||||
|                         tool_desc['return_data'] = api_desc['return_data'] | ||||
|                     is_toolkit = False | ||||
|                     break | ||||
|                 tool_desc.setdefault('api_list', []).append(api_desc) | ||||
|         attrs['_is_toolkit'] = is_toolkit | ||||
|         attrs['__tool_description__'] = tool_desc | ||||
|         return super().__new__(mcs, name, base, attrs) | ||||
|  | ||||
|  | ||||
| class BaseAction(metaclass=AutoRegister(TOOL_REGISTRY, ToolMeta)): | ||||
|     """Base class for all actions. | ||||
|  | ||||
|     Args: | ||||
|         description (str, optional): The description of the action. Defaults to | ||||
|             None. | ||||
|         name (str, optional): The name of the action. If None, the name will | ||||
|             be class name. Defaults to None. | ||||
|         enable (bool, optional): Whether the action is enabled. Defaults to | ||||
|             True. | ||||
|         disable_description (str, optional): The description of the action when | ||||
|             it is disabled. Defaults to None. | ||||
|         description (:class:`Optional[dict]`): The description of the action. | ||||
|             Defaults to ``None``. | ||||
|         parser (:class:`Type[BaseParser]`): The parser class to process the | ||||
|             action's inputs and outputs. Defaults to :class:`JsonParser`. | ||||
|         enable (:class:`bool`): Whether the action is enabled. Defaults to | ||||
|             ``True``. | ||||
|              | ||||
|     Examples: | ||||
|  | ||||
|         * simple tool | ||||
|  | ||||
|         .. code-block:: python | ||||
|          | ||||
|             class Bold(BaseAction): | ||||
|                 '''Make text bold''' | ||||
|              | ||||
|                 @tool_api | ||||
|                 def run(self, text: str): | ||||
|                     ''' | ||||
|                     Args: | ||||
|                         text (str): input text | ||||
|                          | ||||
|                     Returns: | ||||
|                         str: bold text | ||||
|                     ''' | ||||
|                     return '**' + text + '**' | ||||
|  | ||||
|             action = Bold() | ||||
|  | ||||
|         * toolkit with multiple APIs | ||||
|  | ||||
|         .. code-block:: python | ||||
|          | ||||
|             class Calculator(BaseAction): | ||||
|                 '''Calculator''' | ||||
|                  | ||||
|                 @tool_api | ||||
|                 def add(self, a, b): | ||||
|                     '''Add operation | ||||
|                      | ||||
|                     Args: | ||||
|                         a (int): augend | ||||
|                         b (int): addend | ||||
|                      | ||||
|                     Returns: | ||||
|                         int: sum | ||||
|                     ''' | ||||
|                     return a + b | ||||
|                  | ||||
|                 @tool_api  | ||||
|                 def sub(self, a, b): | ||||
|                     '''Subtraction operation | ||||
|                      | ||||
|                     Args: | ||||
|                         a (int): minuend | ||||
|                         b (int): subtrahend | ||||
|                          | ||||
|                     Returns: | ||||
|                         int: difference | ||||
|                     ''' | ||||
|                     return a - b | ||||
|  | ||||
|             action = Calculator() | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, | ||||
|                  description: Optional[str] = None, | ||||
|                  name: Optional[str] = None, | ||||
|                  enable: bool = True, | ||||
|                  disable_description: Optional[str] = None) -> None: | ||||
|         if name is None: | ||||
|             name = self.__class__.__name__ | ||||
|         self._name = name | ||||
|         self._description = description | ||||
|         self._disable_description = disable_description | ||||
|                  description: Optional[dict] = None, | ||||
|                  parser: Type[BaseParser] = JsonParser, | ||||
|                  enable: bool = True): | ||||
|         self._description = deepcopy(description or self.__tool_description__) | ||||
|         self._name = self._description['name'] | ||||
|         self._parser = parser(self) | ||||
|         self._enable = enable | ||||
|  | ||||
|     def __call__(self, *args, **kwargs) -> ActionReturn: | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return f'{self.name}:{self.description}' | ||||
|  | ||||
|     def __str__(self): | ||||
|         return self.__repr__() | ||||
|  | ||||
|     def run(self, *args, **kwargs) -> ActionReturn: | ||||
|         return self.__call__(*args, **kwargs) | ||||
|  | ||||
|     @property | ||||
|     def enable(self): | ||||
|         return self._enable | ||||
|     def __call__(self, inputs: str, name='run') -> ActionReturn: | ||||
|         fallback_args = {'inputs': inputs, 'name': name} | ||||
|         if not hasattr(self, name): | ||||
|             return ActionReturn( | ||||
|                 fallback_args, | ||||
|                 type=self.name, | ||||
|                 errmsg=f'invalid API: {name}', | ||||
|                 state=ActionStatusCode.API_ERROR) | ||||
|         try: | ||||
|             inputs = self._parser.parse_inputs(inputs, name) | ||||
|         except ParseError as exc: | ||||
|             return ActionReturn( | ||||
|                 fallback_args, | ||||
|                 type=self.name, | ||||
|                 errmsg=exc.err_msg, | ||||
|                 state=ActionStatusCode.ARGS_ERROR) | ||||
|         try: | ||||
|             outputs = getattr(self, name)(**inputs) | ||||
|         except Exception as exc: | ||||
|             return ActionReturn( | ||||
|                 inputs, | ||||
|                 type=self.name, | ||||
|                 errmsg=str(exc), | ||||
|                 state=ActionStatusCode.API_ERROR) | ||||
|         if isinstance(outputs, ActionReturn): | ||||
|             action_return = outputs | ||||
|             if not action_return.args: | ||||
|                 action_return.args = inputs | ||||
|             if not action_return.type: | ||||
|                 action_return.type = self.name | ||||
|         else: | ||||
|             result = self._parser.parse_outputs(outputs) | ||||
|             action_return = ActionReturn(inputs, type=self.name, result=result) | ||||
|         return action_return | ||||
|  | ||||
|     @property | ||||
|     def name(self): | ||||
|         return self._name | ||||
|  | ||||
|     @property | ||||
|     def description(self): | ||||
|         if self.enable: | ||||
|             return self._description | ||||
|         else: | ||||
|             return self._disable_description | ||||
|     def enable(self): | ||||
|         return self._enable | ||||
|  | ||||
|     @property | ||||
|     def is_toolkit(self): | ||||
|         return self._is_toolkit | ||||
|  | ||||
|     @property | ||||
|     def description(self) -> dict: | ||||
|         """Description of the tool""" | ||||
|         return self._description | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return f'{self.description}' | ||||
|  | ||||
|     __str__ = __repr__ | ||||
|   | ||||
							
								
								
									
lagent/actions/bing_map.py (new file, 145 lines)
							| @@ -0,0 +1,145 @@ | ||||
| import json | ||||
| import os | ||||
| from typing import Optional, Type | ||||
|  | ||||
| import requests | ||||
|  | ||||
| from lagent.actions.base_action import BaseAction, tool_api | ||||
| from lagent.actions.parser import BaseParser, JsonParser | ||||
|  | ||||
|  | ||||
| class BINGMap(BaseAction): | ||||
|     """BING Map plugin for looking up map information""" | ||||
|  | ||||
|     def __init__(self, | ||||
|                  key: Optional[str] = None, | ||||
|                  description: Optional[dict] = None, | ||||
|                  parser: Type[BaseParser] = JsonParser, | ||||
|                  enable: bool = True) -> None: | ||||
|         super().__init__(description, parser, enable) | ||||
|         key = os.environ.get('BING_MAP_KEY', key) | ||||
|         if key is None: | ||||
|             raise ValueError( | ||||
|                 'Please set BING Map API key either in the environment ' | ||||
|                 'as BING_MAP_KEY or pass it as `key` parameter.') | ||||
|         self.key = key | ||||
|         self.base_url = 'http://dev.virtualearth.net/REST/V1/' | ||||
|  | ||||
|     @tool_api(explode_return=True) | ||||
|     def get_distance(self, start: str, end: str) -> dict: | ||||
|         """Get the distance between two locations in km. | ||||
|  | ||||
|         Args: | ||||
|             start (:class:`str`): The start location | ||||
|             end (:class:`str`): The end location | ||||
|  | ||||
|         Returns: | ||||
|             :class:`dict`: distance information | ||||
|                 * distance (str): the distance in km. | ||||
|         """ | ||||
|         # Request URL | ||||
|         url = self.base_url + 'Routes/Driving?o=json&wp.0=' + start + '&wp.1=' + end + '&key=' + self.key | ||||
|         # GET request | ||||
|         r = requests.get(url) | ||||
|         # TODO check request status? | ||||
|         data = json.loads(r.text) | ||||
|         # Extract route information | ||||
|         route = data['resourceSets'][0]['resources'][0] | ||||
|         # Extract the travel distance (km) |
|         distance = route['travelDistance'] | ||||
|         return dict(distance=distance) | ||||
|  | ||||
|     @tool_api(explode_return=True) | ||||
|     def get_route(self, start: str, end: str) -> dict: | ||||
|         """Get the route between two locations in km. | ||||
|  | ||||
|         Args: | ||||
|             start (:class:`str`): The start location | ||||
|             end (:class:`str`): The end location | ||||
|  | ||||
|         Returns: | ||||
|             :class:`dict`: route information | ||||
|                 * route (list): the route, a list of actions. | ||||
|         """ | ||||
|         # Request URL | ||||
|         url = self.base_url + 'Routes/Driving?o=json&wp.0=' + start + '&wp.1=' + end + '&key=' + self.key | ||||
|         # GET request | ||||
|         r = requests.get(url) | ||||
|         data = json.loads(r.text) | ||||
|         # Extract route information | ||||
|         route = data['resourceSets'][0]['resources'][0] | ||||
|         itinerary = route['routeLegs'][0]['itineraryItems'] | ||||
|         # Extract route text information | ||||
|         route_text = [] | ||||
|         for item in itinerary: | ||||
|             if 'instruction' in item: | ||||
|                 route_text.append(item['instruction']['text']) | ||||
|         return dict(route=route_text) | ||||
|  | ||||
|     @tool_api(explode_return=True) | ||||
|     def get_coordinates(self, location: str) -> dict: | ||||
|         """Get the coordinates of a location. | ||||
|  | ||||
|         Args: | ||||
|             location (:class:`str`): the location to get coordinates for. | ||||
|  | ||||
|         Returns: | ||||
|             :class:`dict`: coordinates information         | ||||
|                 * latitude (float): the latitude of the location. | ||||
|                 * longitude (float): the longitude of the location. | ||||
|         """ | ||||
|         url = self.base_url + 'Locations' | ||||
|         params = {'query': location, 'key': self.key} | ||||
|         response = requests.get(url, params=params) | ||||
|         json_data = response.json() | ||||
|         coordinates = json_data['resourceSets'][0]['resources'][0]['point'][ | ||||
|             'coordinates'] | ||||
|         return dict(latitude=coordinates[0], longitude=coordinates[1]) | ||||
|  | ||||
|     @tool_api(explode_return=True) | ||||
|     def search_nearby(self, | ||||
|                       search_term: str, | ||||
|                       places: str = 'unknown', | ||||
|                       latitude: float = 0.0, | ||||
|                       longitude: float = 0.0, | ||||
|                       radius: int = 5000) -> dict:  #  radius in meters | ||||
|         """Search for places nearby a location, within a given radius, and \ | ||||
| return the results into a list. You can use either the places name or the \ | ||||
| latitude and longitude. | ||||
|  | ||||
|         Args: | ||||
|             search_term (:class:`str`): the place name. | ||||
|             places (:class:`str`): the name of the location. Defaults to ``'unknown'``. | ||||
|             latitude (:class:`float`): the latitude of the location. Defaults to ``0.0``. | ||||
|             longitude (:class:`float`): the longitude of the location. Defaults to ``0.0``. | ||||
|             radius (:class:`int`): radius in meters. Defaults to ``5000``. | ||||
|  | ||||
|         Returns: | ||||
|             :class:`dict`: places information | ||||
|                 * places (list): the list of places, each place is a dict with name and address, at most 5 places. | ||||
|         """ | ||||
|         url = self.base_url + 'LocalSearch' | ||||
|         if places != 'unknown': | ||||
|             # get_coordinates returns a dict with latitude/longitude keys | ||||
|             pos = self.get_coordinates(location=places) | ||||
|             latitude, longitude = pos['latitude'], pos['longitude'] | ||||
|         # Build the request query string | ||||
|         params = { | ||||
|             'query': search_term, | ||||
|             'userLocation': f'{latitude},{longitude}', | ||||
|             'radius': radius, | ||||
|             'key': self.key | ||||
|         } | ||||
|         # Make the request | ||||
|         response = requests.get(url, params=params) | ||||
|         # Parse the response | ||||
|         response_data = json.loads(response.content) | ||||
|         # Get the results | ||||
|         results = response_data['resourceSets'][0]['resources'] | ||||
|         addresses = [] | ||||
|         for result in results: | ||||
|             name = result['name'] | ||||
|             address = result['Address']['formattedAddress'] | ||||
|             addresses.append(dict(name=name, address=address)) | ||||
|             if len(addresses) == 5: | ||||
|                 break | ||||
|         return dict(places=addresses) | ||||
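
Taken together, `BINGMap` exposes four `tool_api` endpoints. A minimal usage sketch (assuming `BING_MAP_KEY` is set in the environment, that the decorator leaves the methods directly callable, and that the place names are purely illustrative):

```python
from lagent.actions.bing_map import BINGMap

bing_map = BINGMap()  # picks up BING_MAP_KEY from the environment

# Direct calls to the decorated methods return plain dicts.
print(bing_map.get_distance(start='Seattle', end='Portland'))
print(bing_map.get_coordinates(location='Seattle'))
print(bing_map.search_nearby(search_term='coffee', places='Seattle'))
```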
| @@ -1,6 +1,7 @@ | ||||
| from typing import Optional | ||||
|  | ||||
| from lagent.actions.base_action import BaseAction | ||||
| from lagent.actions.base_action import BaseAction, tool_api | ||||
| from lagent.actions.parser import BaseParser | ||||
| from lagent.schema import ActionReturn, ActionStatusCode, ActionValidCode | ||||
|  | ||||
|  | ||||
| @@ -19,12 +20,13 @@ class InvalidAction(BaseAction): | ||||
|     def __init__(self, | ||||
|                  err_msg: | ||||
|                  str = 'The action is invalid, please check the action name.', | ||||
|                  **kwargs) -> None: | ||||
|  | ||||
|         super().__init__(enable=False, **kwargs) | ||||
|                  description: Optional[dict] = None, | ||||
|                  parser=BaseParser) -> None: | ||||
|         super().__init__(description, parser, enable=False) | ||||
|         self._err_msg = err_msg | ||||
|  | ||||
|     def __call__(self, err_msg: Optional[str] = None): | ||||
|     @tool_api | ||||
|     def run(self, err_msg: Optional[str] = None) -> ActionReturn: | ||||
|         """Return the error message. | ||||
|  | ||||
|         Args: | ||||
| @@ -35,7 +37,7 @@ class InvalidAction(BaseAction): | ||||
|         action_return = ActionReturn( | ||||
|             url=None, | ||||
|             args=dict(text=err_msg), | ||||
|             errmsg=err_msg if err_msg else self._err_msg, | ||||
|             errmsg=err_msg or self._err_msg, | ||||
|             type=self.name, | ||||
|             valid=ActionValidCode.INVALID, | ||||
|             state=ActionStatusCode.API_ERROR) | ||||
| @@ -51,12 +53,15 @@ class NoAction(BaseAction): | ||||
|             'Please follow the format'. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, err_msg: str = 'Please follow the format', **kwargs): | ||||
|  | ||||
|         super().__init__(enable=False, **kwargs) | ||||
|     def __init__(self, | ||||
|                  err_msg: str = 'Please follow the format', | ||||
|                  description: Optional[dict] = None, | ||||
|                  parser=BaseParser): | ||||
|         super().__init__(description, parser, enable=False) | ||||
|         self._err_msg = err_msg | ||||
|  | ||||
|     def __call__(self, err_msg: Optional[str] = None): | ||||
|     @tool_api | ||||
|     def run(self, err_msg: Optional[str] = None) -> ActionReturn: | ||||
|         """Return the error message. | ||||
|  | ||||
|         Args: | ||||
| @@ -71,7 +76,7 @@ class NoAction(BaseAction): | ||||
|             url=None, | ||||
|             args=dict(text=err_msg), | ||||
|             type=self.name, | ||||
|             errmsg=err_msg if err_msg else self._err_msg, | ||||
|             errmsg=err_msg or self._err_msg, | ||||
|             valid=ActionValidCode.INVALID, | ||||
|             state=ActionStatusCode.API_ERROR) | ||||
|         return action_return | ||||
| @@ -81,7 +86,11 @@ class FinishAction(BaseAction): | ||||
|     """This is a finish action class, which is used to return the final | ||||
|     result.""" | ||||
|  | ||||
|     def __call__(self, response: str) -> ActionReturn: | ||||
|     def __init__(self, description: Optional[dict] = None, parser=BaseParser): | ||||
|         super().__init__(description, parser, enable=True) | ||||
|  | ||||
|     @tool_api | ||||
|     def run(self, response: str) -> ActionReturn: | ||||
|         """Return the final result. | ||||
|  | ||||
|         Args: | ||||
| @@ -93,7 +102,7 @@ class FinishAction(BaseAction): | ||||
|         action_return = ActionReturn( | ||||
|             url=None, | ||||
|             args=dict(text=response), | ||||
|             result=dict(text=response), | ||||
|             result=[dict(type='text', content=response)], | ||||
|             type=self.name, | ||||
|             valid=ActionValidCode.FINISH, | ||||
|             state=ActionStatusCode.SUCCESS) | ||||
|   | ||||
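
Note that `FinishAction.run` now packs its payload as a list of typed chunks instead of a flat dict, matching the multimodal result format used by the interpreter actions below. A hedged sketch of consuming it, assuming the default `BaseParser` maps a raw string onto the single `response` parameter and that `BaseAction.__call__` passes the `ActionReturn` through:

```python
from lagent.actions.builtin_actions import FinishAction

finish = FinishAction()
ret = finish('All done')  # dispatched to run(response='All done')
for chunk in ret.result:  # result is a list of {'type', 'content'} dicts
    print(chunk['type'], chunk['content'])
```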
							
								
								
									
267 lagent/actions/google_scholar_search.py Normal file
							| @@ -0,0 +1,267 @@ | ||||
| import os | ||||
| from typing import Optional, Type | ||||
|  | ||||
| from serpapi import GoogleSearch | ||||
|  | ||||
| from lagent.actions.base_action import BaseAction, tool_api | ||||
| from lagent.schema import ActionReturn, ActionStatusCode | ||||
| from .parser import BaseParser, JsonParser | ||||
|  | ||||
|  | ||||
| class GoogleScholar(BaseAction): | ||||
|     """Plugin for google scholar search | ||||
|  | ||||
|     Args: | ||||
|         api_key (str): API key for SerpApi, which backs the ``serpapi`` | ||||
|             package used here. It is read from the ``SERPER_API_KEY`` | ||||
|             environment variable if not passed explicitly. | ||||
|         description (dict): The description of the action. Defaults to ``None``. | ||||
|         parser (Type[BaseParser]): The parser class to process the | ||||
|             action's inputs and outputs. Defaults to :class:`JsonParser`. | ||||
|         enable (bool, optional): Whether the action is enabled. Defaults to | ||||
|             True. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, | ||||
|                  api_key: Optional[str] = None, | ||||
|                  description: Optional[dict] = None, | ||||
|                  parser: Type[BaseParser] = JsonParser, | ||||
|                  enable: bool = True): | ||||
|         super().__init__(description, parser, enable) | ||||
|         api_key = os.environ.get('SERPER_API_KEY', api_key) | ||||
|         if api_key is None: | ||||
|             raise ValueError( | ||||
|                 'Please set Serper API key either in the environment ' | ||||
|                 'as SERPER_API_KEY or pass it as `api_key` parameter.') | ||||
|         self.api_key = api_key | ||||
|  | ||||
|     @tool_api(explode_return=True) | ||||
|     def search_google_scholar( | ||||
|         self, | ||||
|         query: str, | ||||
|         cites: Optional[str] = None, | ||||
|         as_ylo: Optional[int] = None, | ||||
|         as_yhi: Optional[int] = None, | ||||
|         scisbd: Optional[int] = None, | ||||
|         cluster: Optional[str] = None, | ||||
|         hl: Optional[str] = None, | ||||
|         lr: Optional[str] = None, | ||||
|         start: Optional[int] = None, | ||||
|         num: Optional[int] = None, | ||||
|         as_sdt: Optional[str] = None, | ||||
|         safe: Optional[str] = None, | ||||
|         filter: Optional[str] = None, | ||||
|         as_vis: Optional[str] = None, | ||||
|     ) -> dict: | ||||
|         """Search for scholarly articles based on a query according to the google scholar | ||||
|  | ||||
|         Args: | ||||
|             query (str): The query to search for. | ||||
|             cites (Optional[str]): The unique ID of an article for triggering "Cited By" searches. | ||||
|             as_ylo (Optional[int]): The starting year for results (e.g., if as_ylo=2018, results before this year will be omitted). | ||||
|             as_yhi (Optional[int]): The ending year for results (e.g., if as_yhi=2018, results after this year will be omitted). | ||||
|             scisbd (Optional[int]): Defines articles added in the last year, sorted by date. It can be set to 1 to include only abstracts, or 2 to include everything. | ||||
|             cluster (Optional[str]): The unique ID of an article for triggering "All Versions" searches. | ||||
|             hl (Optional[str]): The language to use for the Google Scholar search. | ||||
|             lr (Optional[str]): One or multiple languages to limit the search to. | ||||
|             start (Optional[int]): The result offset for pagination (0 is the first page of results, 10 is the 2nd page, etc.) | ||||
|             num (Optional[int]): The maximum number of results to return, limited to 20. | ||||
|             as_sdt (Optional[str]): Can be used either as a search type or a filter. | ||||
|             safe (Optional[str]): The level of filtering for adult content. | ||||
|             filter (Optional[str]): Defines if the filters for 'Similar Results' and 'Omitted Results' are on or off. | ||||
|             as_vis (Optional[str]): Defines whether to include citations or not. | ||||
|  | ||||
|         Returns: | ||||
|             :class:`dict`: article information | ||||
|                 - title: a list of the titles of the three selected papers | ||||
|                 - cited_by: a list of the citation numbers of the three selected papers | ||||
|                 - organic_id: a list of the organic results' ids of the three selected papers | ||||
|                 - pub_info: a list of the publication information of the three selected papers | ||||
|                 - snippets: a list of the snippets of the three selected papers | ||||
|         """ | ||||
|         params = { | ||||
|             'q': query, | ||||
|             'engine': 'google_scholar', | ||||
|             'api_key': self.api_key, | ||||
|             'cites': cites, | ||||
|             'as_ylo': as_ylo, | ||||
|             'as_yhi': as_yhi, | ||||
|             'scisbd': scisbd, | ||||
|             'cluster': cluster, | ||||
|             'hl': hl, | ||||
|             'lr': lr, | ||||
|             'start': start, | ||||
|             'num': num, | ||||
|             'as_sdt': as_sdt, | ||||
|             'safe': safe, | ||||
|             'filter': filter, | ||||
|             'as_vis': as_vis | ||||
|         } | ||||
|         search = GoogleSearch(params) | ||||
|         try: | ||||
|             r = search.get_dict() | ||||
|             results = r['organic_results'] | ||||
|             title = [] | ||||
|             snippets = [] | ||||
|             cited_by = [] | ||||
|             organic_id = [] | ||||
|             pub_info = [] | ||||
|             for item in results[:3]: | ||||
|                 title.append(item['title']) | ||||
|                 pub_info.append(item['publication_info']['summary']) | ||||
|                 citation = item['inline_links'].get('cited_by', {'total': ''}) | ||||
|                 cited_by.append(citation['total']) | ||||
|                 snippets.append(item['snippet']) | ||||
|                 organic_id.append(item['result_id']) | ||||
|             return dict( | ||||
|                 title=title, | ||||
|                 cited_by=cited_by, | ||||
|                 organic_id=organic_id, | ||||
|                 pub_info=pub_info, | ||||
|                 snippets=snippets) | ||||
|         except Exception as e: | ||||
|             return ActionReturn( | ||||
|                 errmsg=str(e), state=ActionStatusCode.HTTP_ERROR) | ||||
|  | ||||
|     @tool_api(explode_return=True) | ||||
|     def get_author_information(self, | ||||
|                                author_id: str, | ||||
|                                hl: Optional[str] = None, | ||||
|                                view_op: Optional[str] = None, | ||||
|                                sort: Optional[str] = None, | ||||
|                                citation_id: Optional[str] = None, | ||||
|                                start: Optional[int] = None, | ||||
|                                num: Optional[int] = None, | ||||
|                                no_cache: Optional[bool] = None, | ||||
|                                async_req: Optional[bool] = None, | ||||
|                                output: Optional[str] = None) -> dict: | ||||
|         """Search for an author's information by author's id provided by get_author_id. | ||||
|  | ||||
|         Args: | ||||
|             author_id (str): Required. The ID of an author. | ||||
|             hl (Optional[str]): The language to use for the Google Scholar Author search. Default is 'en'. | ||||
|             view_op (Optional[str]): Used for viewing specific parts of a page. | ||||
|             sort (Optional[str]): Used for sorting and refining articles. | ||||
|             citation_id (Optional[str]): Used for retrieving individual article citation. | ||||
|             start (Optional[int]): Defines the result offset. Default is 0. | ||||
|             num (Optional[int]): Defines the number of results to return. Default is 20. | ||||
|             no_cache (Optional[bool]): Forces SerpApi to fetch the results even if a cached version is already present. Default is False. | ||||
|             async_req (Optional[bool]): Defines the way you want to submit your search to SerpApi. Default is False. | ||||
|             output (Optional[str]): Defines the final output you want. Default is 'json'. | ||||
|  | ||||
|         Returns: | ||||
|             :class:`dict`: author information | ||||
|                 * name: author's name | ||||
|                 * affiliations: the affiliations of the author | ||||
|                 * articles: at most 3 articles by the author | ||||
|                 * website: the author's homepage url | ||||
|         """ | ||||
|         params = { | ||||
|             'engine': 'google_scholar_author', | ||||
|             'author_id': author_id, | ||||
|             'api_key': self.api_key, | ||||
|             'hl': hl, | ||||
|             'view_op': view_op, | ||||
|             'sort': sort, | ||||
|             'citation_id': citation_id, | ||||
|             'start': start, | ||||
|             'num': num, | ||||
|             'no_cache': no_cache, | ||||
|             'async': async_req, | ||||
|             'output': output | ||||
|         } | ||||
|         try: | ||||
|             search = GoogleSearch(params) | ||||
|             results = search.get_dict() | ||||
|             author = results['author'] | ||||
|             articles = results.get('articles', []) | ||||
|             return dict( | ||||
|                 name=author['name'], | ||||
|                 affiliations=author.get('affiliations', ''), | ||||
|                 website=author.get('website', ''), | ||||
|                 articles=[ | ||||
|                     dict(title=article['title'], authors=article['authors']) | ||||
|                     for article in articles[:3] | ||||
|                 ]) | ||||
|         except Exception as e: | ||||
|             return ActionReturn( | ||||
|                 errmsg=str(e), state=ActionStatusCode.HTTP_ERROR) | ||||
|  | ||||
|     @tool_api(explode_return=True) | ||||
|     def get_citation_format(self, | ||||
|                             q: str, | ||||
|                             no_cache: Optional[bool] = None, | ||||
|                             async_: Optional[bool] = None, | ||||
|                             output: Optional[str] = 'json') -> dict: | ||||
|         """Function to get MLA citation format by an identification of organic_result's id provided by search_google_scholar. | ||||
|  | ||||
|         Args: | ||||
|             q (str): ID of an individual Google Scholar organic search result. | ||||
|             no_cache (Optional[bool]): If set to True, will force SerpApi to fetch the Google Scholar Cite results even if a cached version is already present. Defaults to None. | ||||
|             async_ (Optional[bool]): If set to True, will submit search to SerpApi and retrieve results later. Defaults to None. | ||||
|             output (Optional[str]): Final output format. Set to 'json' to get a structured JSON of the results, or 'html' to get the raw html retrieved. Defaults to 'json'. | ||||
|  | ||||
|         Returns: | ||||
|             :class:`dict`: citation format | ||||
|                 * citation: the MLA citation of the article, which embeds the author list | ||||
|         """ | ||||
|         params = { | ||||
|             'q': q, | ||||
|             'engine': 'google_scholar_cite', | ||||
|             'api_key': self.api_key, | ||||
|             'no_cache': no_cache, | ||||
|             'async': async_, | ||||
|             'output': output | ||||
|         } | ||||
|         try: | ||||
|             search = GoogleSearch(params) | ||||
|             results = search.get_dict() | ||||
|             citation = results['citations'] | ||||
|             citation_info = citation[0]['snippet'] | ||||
|             return dict(citation=citation_info) | ||||
|         except Exception as e: | ||||
|             return ActionReturn( | ||||
|                 errmsg=str(e), state=ActionStatusCode.HTTP_ERROR) | ||||
|  | ||||
|     @tool_api(explode_return=True) | ||||
|     def get_author_id(self, | ||||
|                       mauthors: str, | ||||
|                       hl: Optional[str] = 'en', | ||||
|                       after_author: Optional[str] = None, | ||||
|                       before_author: Optional[str] = None, | ||||
|                       no_cache: Optional[bool] = False, | ||||
|                       _async: Optional[bool] = False, | ||||
|                       output: Optional[str] = 'json') -> dict: | ||||
|         """The getAuthorId function is used to get the author's id by his or her name. | ||||
|  | ||||
|         Args: | ||||
|             mauthors (str): Defines the author you want to search for. | ||||
|             hl (Optional[str]): Defines the language to use for the Google Scholar Profiles search. It's a two-letter language code. (e.g., 'en' for English, 'es' for Spanish, or 'fr' for French). Defaults to 'en'. | ||||
|             after_author (Optional[str]): Defines the next page token. It is used for retrieving the next page results. The parameter has the precedence over before_author parameter. Defaults to None. | ||||
|             before_author (Optional[str]): Defines the previous page token. It is used for retrieving the previous page results. Defaults to None. | ||||
|             no_cache (Optional[bool]): Will force SerpApi to fetch the Google Scholar Profiles results even if a cached version is already present. Defaults to False. | ||||
|             _async (Optional[bool]): Defines the way you want to submit your search to SerpApi. Defaults to False. | ||||
|             output (Optional[str]): Defines the final output you want. It can be set to 'json' (default) to get a structured JSON of the results, or 'html' to get the raw html retrieved. Defaults to 'json'. | ||||
|  | ||||
|         Returns: | ||||
|             :class:`dict`: author id | ||||
|                 * author_id: the author_id of the author | ||||
|         """ | ||||
|         params = { | ||||
|             'mauthors': mauthors, | ||||
|             'engine': 'google_scholar_profiles', | ||||
|             'api_key': self.api_key, | ||||
|             'hl': hl, | ||||
|             'after_author': after_author, | ||||
|             'before_author': before_author, | ||||
|             'no_cache': no_cache, | ||||
|             'async': _async, | ||||
|             'output': output | ||||
|         } | ||||
|         try: | ||||
|             search = GoogleSearch(params) | ||||
|             results = search.get_dict() | ||||
|             profile = results['profiles'] | ||||
|             author_info = dict(author_id=profile[0]['author_id']) | ||||
|             return author_info | ||||
|         except Exception as e: | ||||
|             return ActionReturn( | ||||
|                 errmsg=str(e), state=ActionStatusCode.HTTP_ERROR) | ||||
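
A sketch of chaining these endpoints (assuming a valid SerpApi key in `SERPER_API_KEY`; live results will vary, and the query and author name are only examples):

```python
from lagent.actions.google_scholar_search import GoogleScholar

scholar = GoogleScholar()

# Find papers, then reuse an organic result id for the citation lookup.
papers = scholar.search_google_scholar(query='large language model agents')
print(papers['title'])
print(scholar.get_citation_format(q=papers['organic_id'][0]))

# Resolve an author id by name, then fetch the profile.
author = scholar.get_author_id(mauthors='Geoffrey Hinton')
print(scholar.get_author_information(author['author_id']))
```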
| @@ -1,15 +1,11 @@ | ||||
| import os | ||||
| from typing import List, Optional, Tuple, Union | ||||
| from typing import List, Optional, Tuple, Type, Union | ||||
|  | ||||
| import requests | ||||
|  | ||||
| from lagent.schema import ActionReturn, ActionStatusCode | ||||
| from .base_action import BaseAction | ||||
|  | ||||
| DEFAULT_DESCRIPTION = """一个可以从谷歌搜索结果的API。 | ||||
| 当你需要对于一个特定问题找到简短明了的回答时,可以使用它。 | ||||
| 输入应该是一个搜索查询。 | ||||
| """ | ||||
| from .base_action import BaseAction, tool_api | ||||
| from .parser import BaseParser, JsonParser | ||||
|  | ||||
|  | ||||
| class GoogleSearch(BaseAction): | ||||
| @@ -28,15 +24,10 @@ class GoogleSearch(BaseAction): | ||||
|         timeout (int): Upper bound of waiting time for a serper request. | ||||
|         search_type (str): Serper API support ['search', 'images', 'news', | ||||
|             'places'] types of search, currently we only support 'search'. | ||||
|         k (int): select first k results in the search results as response. | ||||
|         description (str): The description of the action. Defaults to | ||||
|             None. | ||||
|         name (str, optional): The name of the action. If None, the name will | ||||
|             be class name. Defaults to None. | ||||
|         enable (bool, optional): Whether the action is enabled. Defaults to | ||||
|             True. | ||||
|         disable_description (str, optional): The description of the action when | ||||
|             it is disabled. Defaults to None. | ||||
|         description (dict): The description of the action. Defaults to ``None``. | ||||
|         parser (Type[BaseParser]): The parser class to process the | ||||
|             action's inputs and outputs. Defaults to :class:`JsonParser`. | ||||
|         enable (bool): Whether the action is enabled. Defaults to ``True``. | ||||
|     """ | ||||
|     result_key_for_type = { | ||||
|         'news': 'news', | ||||
| @@ -49,36 +40,29 @@ class GoogleSearch(BaseAction): | ||||
|                  api_key: Optional[str] = None, | ||||
|                  timeout: int = 5, | ||||
|                  search_type: str = 'search', | ||||
|                  k: int = 10, | ||||
|                  description: str = DEFAULT_DESCRIPTION, | ||||
|                  name: Optional[str] = None, | ||||
|                  enable: bool = True, | ||||
|                  disable_description: Optional[str] = None) -> None: | ||||
|         super().__init__(description, name, enable, disable_description) | ||||
|  | ||||
|                  description: Optional[dict] = None, | ||||
|                  parser: Type[BaseParser] = JsonParser, | ||||
|                  enable: bool = True): | ||||
|         super().__init__(description, parser, enable) | ||||
|         api_key = os.environ.get('SERPER_API_KEY', api_key) | ||||
|         if api_key is None: | ||||
|             raise ValueError( | ||||
|                 'Please set Serper API key either in the environment ' | ||||
|                 ' as SERPER_API_KEY or pass it as `api_key` parameter.') | ||||
|                 'as SERPER_API_KEY or pass it as `api_key` parameter.') | ||||
|         self.api_key = api_key | ||||
|         self.timeout = timeout | ||||
|         self.search_type = search_type | ||||
|         self.k = k | ||||
|  | ||||
|     def __call__(self, query: str) -> ActionReturn: | ||||
|         """Return the search response. | ||||
|  | ||||
|     @tool_api | ||||
|     def run(self, query: str, k: int = 10) -> ActionReturn: | ||||
|         """一个可以从谷歌搜索结果的API。当你需要对于一个特定问题找到简短明了的回答时,可以使用它。输入应该是一个搜索查询。 | ||||
|          | ||||
|         Args: | ||||
|             query (str): The search content. | ||||
|  | ||||
|         Returns: | ||||
|             ActionReturn: The action return. | ||||
|             query (str): the search content | ||||
|             k (int): select first k results in the search results as response | ||||
|         """ | ||||
|  | ||||
|         tool_return = ActionReturn(url=None, args=None, type=self.name) | ||||
|         status_code, response = self._search( | ||||
|             query, search_type=self.search_type, k=self.k) | ||||
|         tool_return = ActionReturn(type=self.name) | ||||
|         status_code, response = self._search(query, k=k) | ||||
|         # convert search results to ToolReturn format | ||||
|         if status_code == -1: | ||||
|             tool_return.errmsg = response | ||||
| @@ -139,7 +123,7 @@ class GoogleSearch(BaseAction): | ||||
|  | ||||
|     def _search(self, | ||||
|                 search_term: str, | ||||
|                 search_type: str = 'search', | ||||
|                 search_type: Optional[str] = None, | ||||
|                 **kwargs) -> Tuple[int, Union[dict, str]]: | ||||
|         """HTTP requests to Serper API. | ||||
|  | ||||
| @@ -166,7 +150,7 @@ class GoogleSearch(BaseAction): | ||||
|         } | ||||
|         try: | ||||
|             response = requests.post( | ||||
|                 f'https://google.serper.dev/{search_type}', | ||||
|                 f'https://google.serper.dev/{search_type or self.search_type}', | ||||
|                 headers=headers, | ||||
|                 params=params, | ||||
|                 timeout=self.timeout) | ||||
|   | ||||
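
With `k` moved from the constructor into `run`, callers now choose the result count per query. A sketch (assuming a serper.dev key in `SERPER_API_KEY` and that the default `JsonParser` accepts a JSON string of keyword arguments through `BaseAction.__call__`):

```python
from lagent.actions import GoogleSearch

search = GoogleSearch(timeout=10)

# The JSON string is parsed into run(query=..., k=...) by JsonParser.
ret = search('{"query": "InternLM lagent", "k": 3}')
print(ret.result)
```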
							
								
								
									
296 lagent/actions/ipython_interpreter.py Normal file
							| @@ -0,0 +1,296 @@ | ||||
| import base64 | ||||
| import io | ||||
| import logging | ||||
| import os | ||||
| import queue | ||||
| import re | ||||
| import signal | ||||
| import sys | ||||
| import traceback | ||||
| import uuid | ||||
| from typing import Optional, Tuple, Type | ||||
|  | ||||
| import json5 | ||||
| import PIL.Image | ||||
| from jupyter_client import KernelManager | ||||
|  | ||||
| from lagent.actions.base_action import BaseAction, tool_api | ||||
| from lagent.actions.parser import BaseParser, JsonParser | ||||
| from lagent.schema import ActionReturn, ActionStatusCode | ||||
|  | ||||
| START_CODE = """ | ||||
| def input(*args, **kwargs): | ||||
|     raise NotImplementedError('Python input() function is disabled.') | ||||
|  | ||||
| get_ipython().system = lambda *args: print('Assume we have this package, ! is disabled!') | ||||
| {} | ||||
| """  # noqa | ||||
|  | ||||
|  | ||||
| class TimeoutError(Exception): | ||||
|     pass | ||||
|  | ||||
|  | ||||
| class IPythonInterpreter(BaseAction): | ||||
|     """A IPython executor that can execute Python scripts in a jupyter manner. | ||||
|  | ||||
|     Args: | ||||
|         timeout (int): Upper bound of waiting time for Python script execution. | ||||
|             Defaults to 20. | ||||
|         user_data_dir (str, optional): Specifies the user data directory for | ||||
|             file loading. If set to `ENV`, the `USER_DATA_DIR` environment | ||||
|             variable is used. Defaults to `ENV`. | ||||
|         work_dir (str, optional): Specifies the directory to save output | ||||
|             images to. Defaults to ``'./work_dir/tmp_dir'``. | ||||
|         description (dict): The description of the action. Defaults to ``None``. | ||||
|         parser (Type[BaseParser]): The parser class to process the | ||||
|             action's inputs and outputs. Defaults to :class:`JsonParser`. | ||||
|         enable (bool, optional): Whether the action is enabled. Defaults to ``True``. | ||||
|     """ | ||||
|  | ||||
|     _KERNEL_CLIENTS = {} | ||||
|  | ||||
|     def __init__(self, | ||||
|                  timeout: int = 20, | ||||
|                  user_data_dir: str = 'ENV', | ||||
|                  work_dir='./work_dir/tmp_dir', | ||||
|                  description: Optional[dict] = None, | ||||
|                  parser: Type[BaseParser] = JsonParser, | ||||
|                  enable: bool = True): | ||||
|         super().__init__(description, parser, enable) | ||||
|  | ||||
|         self.timeout = timeout | ||||
|         if user_data_dir == 'ENV': | ||||
|             user_data_dir = os.environ.get('USER_DATA_DIR', '') | ||||
|  | ||||
|         if user_data_dir: | ||||
|             user_data_dir = os.path.dirname(user_data_dir) | ||||
|             user_data_dir = f"import os\nos.chdir('{user_data_dir}')" | ||||
|         self.user_data_dir = user_data_dir | ||||
|         self._initialized = False | ||||
|         self.work_dir = work_dir | ||||
|         if not os.path.exists(self.work_dir): | ||||
|             os.makedirs(self.work_dir, exist_ok=True) | ||||
|  | ||||
|     @staticmethod | ||||
|     def start_kernel(): | ||||
|         # start the kernel and manager | ||||
|         km = KernelManager() | ||||
|         km.start_kernel() | ||||
|         kc = km.client() | ||||
|         return km, kc | ||||
|  | ||||
|     def initialize(self): | ||||
|         if self._initialized: | ||||
|             return | ||||
|         pid = os.getpid() | ||||
|         if pid not in self._KERNEL_CLIENTS: | ||||
|             self._KERNEL_CLIENTS[pid] = self.start_kernel() | ||||
|         self.kernel_manager, self.kernel_client = self._KERNEL_CLIENTS[pid] | ||||
|         self._initialized = True | ||||
|         self._call(START_CODE.format(self.user_data_dir), None) | ||||
|  | ||||
|     def reset(self): | ||||
|         if not self._initialized: | ||||
|             self.initialize() | ||||
|         else: | ||||
|             code = "get_ipython().run_line_magic('reset', '-f')\n" + \ | ||||
|                 START_CODE.format(self.user_data_dir) | ||||
|             self._call(code, None) | ||||
|  | ||||
|     def _call(self, | ||||
|               command: str, | ||||
|               timeout: Optional[int] = None) -> Tuple[str, bool]: | ||||
|         self.initialize() | ||||
|         command = extract_code(command) | ||||
|  | ||||
|         # check previous remaining result | ||||
|         while True: | ||||
|             try: | ||||
|                 msg = self.kernel_client.get_iopub_msg(timeout=5) | ||||
|                 msg_type = msg['msg_type'] | ||||
|                 if msg_type == 'status': | ||||
|                     if msg['content'].get('execution_state') == 'idle': | ||||
|                         break | ||||
|             except queue.Empty: | ||||
|                 # assume no result | ||||
|                 break | ||||
|  | ||||
|         self.kernel_client.execute(command) | ||||
|  | ||||
|         def _inner_call(): | ||||
|             result = '' | ||||
|             images = [] | ||||
|             succeed = True | ||||
|             image_idx = 0 | ||||
|  | ||||
|             while True: | ||||
|                 text = '' | ||||
|                 image = '' | ||||
|                 finished = False | ||||
|                 msg_type = 'error' | ||||
|                 try: | ||||
|                     msg = self.kernel_client.get_iopub_msg(timeout=20) | ||||
|                     msg_type = msg['msg_type'] | ||||
|                     if msg_type == 'status': | ||||
|                         if msg['content'].get('execution_state') == 'idle': | ||||
|                             finished = True | ||||
|                     elif msg_type == 'execute_result': | ||||
|                         text = msg['content']['data'].get('text/plain', '') | ||||
|                         if 'image/png' in msg['content']['data']: | ||||
|                             image_b64 = msg['content']['data']['image/png'] | ||||
|                             image_url = publish_image_to_local( | ||||
|                                 image_b64, self.work_dir) | ||||
|                             image_idx += 1 | ||||
|                             image = '![fig-%03d](%s)' % (image_idx, image_url) | ||||
|  | ||||
|                     elif msg_type == 'display_data': | ||||
|                         if 'image/png' in msg['content']['data']: | ||||
|                             image_b64 = msg['content']['data']['image/png'] | ||||
|                             image_url = publish_image_to_local( | ||||
|                                 image_b64, self.work_dir) | ||||
|                             image_idx += 1 | ||||
|                             image = '![fig-%03d](%s)' % (image_idx, image_url) | ||||
|  | ||||
|                         else: | ||||
|                             text = msg['content']['data'].get('text/plain', '') | ||||
|                     elif msg_type == 'stream': | ||||
|                         msg_type = msg['content']['name']  # stdout, stderr | ||||
|                         text = msg['content']['text'] | ||||
|                     elif msg_type == 'error': | ||||
|                         succeed = False | ||||
|                         text = escape_ansi('\n'.join( | ||||
|                             msg['content']['traceback'])) | ||||
|                         if 'M6_CODE_INTERPRETER_TIMEOUT' in text: | ||||
|                             text = f'Timeout. No response after {timeout} seconds.'  # noqa | ||||
|                 except queue.Empty: | ||||
|                     # stop current task in case break next input. | ||||
|                     self.kernel_manager.interrupt_kernel() | ||||
|                     succeed = False | ||||
|                     text = f'Timeout. No response after {timeout} seconds.' | ||||
|                     finished = True | ||||
|                 except Exception: | ||||
|                     succeed = False | ||||
|                     msg = ''.join(traceback.format_exception(*sys.exc_info())) | ||||
|                     # text = 'The code interpreter encountered an unexpected error.'  # noqa | ||||
|                     text = msg | ||||
|                     logging.warning(msg) | ||||
|                     finished = True | ||||
|                 if text: | ||||
|                     # result += f'\n\n{msg_type}:\n\n```\n{text}\n```' | ||||
|                     result += f'{text}' | ||||
|  | ||||
|                 if image: | ||||
|                     images.append(image_url) | ||||
|                 if finished: | ||||
|                     return succeed, dict(text=result, image=images) | ||||
|  | ||||
|         try: | ||||
|             if timeout: | ||||
|  | ||||
|                 def handler(signum, frame): | ||||
|                     raise TimeoutError() | ||||
|  | ||||
|                 signal.signal(signal.SIGALRM, handler) | ||||
|                 signal.alarm(timeout) | ||||
|             succeed, result = _inner_call() | ||||
|         except TimeoutError: | ||||
|             succeed = False | ||||
|             text = f'Timeout. No response after {timeout} seconds.' | ||||
|             result = f'\n\nerror:\n\n```\n{text}\n```' | ||||
|         finally: | ||||
|             if timeout: | ||||
|                 signal.alarm(0) | ||||
|  | ||||
|         # result = result.strip('\n') | ||||
|         return succeed, result | ||||
|  | ||||
|     @tool_api | ||||
|     def run(self, command: str, timeout: Optional[int] = None) -> ActionReturn: | ||||
|         """When you send a message containing Python code to python, it will be \ | ||||
| executed in a stateful Jupyter notebook environment. python will respond with \ | ||||
| the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' \ | ||||
| can be used to save and persist user files. Internet access for this session is \ | ||||
| disabled. Do not make external web requests or API calls as they will fail. | ||||
|  | ||||
|         Args: | ||||
|             command (:class:`str`): Python code | ||||
|             timeout (:class:`Optional[int]`): Upper bound of waiting time for Python script execution. | ||||
|         """ | ||||
|         tool_return = ActionReturn(url=None, args=None, type=self.name) | ||||
|         tool_return.args = dict(text=command) | ||||
|         succeed, result = self._call(command, timeout) | ||||
|         if succeed: | ||||
|             text = result['text'] | ||||
|             image = result.get('image', []) | ||||
|             resp = [dict(type="text", content=text)] | ||||
|             if image: | ||||
|                 resp.extend([dict(type="image", content=im) for im in image]) | ||||
|             tool_return.result = resp | ||||
|             # tool_return.result = dict( | ||||
|             #     text=result['text'], image=result.get('image', [])[0]) | ||||
|             tool_return.state = ActionStatusCode.SUCCESS | ||||
|         else: | ||||
|             tool_return.errmsg = result.get("text", "") if isinstance( | ||||
|                 result, dict) else result | ||||
|             tool_return.state = ActionStatusCode.API_ERROR | ||||
|         return tool_return | ||||
|  | ||||
|  | ||||
| def extract_code(text): | ||||
|     # Match triple backtick blocks first | ||||
|     triple_match = re.search(r'```[^\n]*\n(.+?)```', text, re.DOTALL) | ||||
|     # Match single backtick blocks second | ||||
|     single_match = re.search(r'`([^`]*)`', text, re.DOTALL) | ||||
|     if triple_match: | ||||
|         text = triple_match.group(1) | ||||
|     elif single_match: | ||||
|         text = single_match.group(1) | ||||
|     else: | ||||
|         try: | ||||
|             text = json5.loads(text)['code'] | ||||
|         except Exception: | ||||
|             pass | ||||
|     # If no code blocks found, return original text | ||||
|     return text | ||||
|  | ||||
|  | ||||
| def escape_ansi(line): | ||||
|     ansi_escape = re.compile(r'(?:\x1B[@-_]|[\x80-\x9F])[0-?]*[ -/]*[@-~]') | ||||
|     return ansi_escape.sub('', line) | ||||
|  | ||||
|  | ||||
| def publish_image_to_local(image_base64: str, work_dir='./work_dir/tmp_dir'): | ||||
|     image_file = str(uuid.uuid4()) + '.png' | ||||
|     local_image_file = os.path.join(work_dir, image_file) | ||||
|  | ||||
|     png_bytes = base64.b64decode(image_base64) | ||||
|     assert isinstance(png_bytes, bytes) | ||||
|     bytes_io = io.BytesIO(png_bytes) | ||||
|     PIL.Image.open(bytes_io).save(local_image_file, 'png') | ||||
|  | ||||
|     return local_image_file | ||||
|  | ||||
|  | ||||
| # local test for code interpreter | ||||
| def get_multiline_input(hint): | ||||
|     print(hint) | ||||
|     print('// Press ENTER to make a new line. Press CTRL-D to end input.') | ||||
|     lines = [] | ||||
|     while True: | ||||
|         try: | ||||
|             line = input() | ||||
|         except EOFError:  # CTRL-D | ||||
|             break | ||||
|         lines.append(line) | ||||
|     print('// Input received.') | ||||
|     if lines: | ||||
|         return '\n'.join(lines) | ||||
|     else: | ||||
|         return '' | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     code_interpreter = IPythonInterpreter() | ||||
|     while True: | ||||
|         print(code_interpreter(get_multiline_input('Enter python code:'))) | ||||
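
The `extract_code` helper tolerates several shapes of model output. A quick sketch of its branches (inputs are illustrative; the outputs follow directly from the code above):

```python
from lagent.actions.ipython_interpreter import extract_code

print(repr(extract_code('```python\nprint(1)\n```')))  # 'print(1)\n'
print(repr(extract_code('`print(2)`')))                # 'print(2)'
print(repr(extract_code('{"code": "print(3)"}')))      # 'print(3)' via json5
print(repr(extract_code('plain text')))                # returned unchanged
```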
| @@ -1,56 +0,0 @@ | ||||
| from typing import Optional, Union | ||||
|  | ||||
| from lagent.llms.base_api import BaseAPIModel | ||||
| from lagent.llms.base_llm import BaseModel | ||||
| from lagent.schema import ActionReturn, ActionStatusCode | ||||
| from .base_action import BaseAction | ||||
|  | ||||
| DEFAULT_DESCRIPTION = """一个像你一样的大语言预训练模型,当你需要获得一些常识或简单世界知识时可以问它。 | ||||
| 当你很有把握自己直接解决问题时可以优先使用它。输入应该是一个询问语句, 且每个问题尽可能简单。 | ||||
| """ | ||||
|  | ||||
|  | ||||
| class LLMQA(BaseAction): | ||||
|     """An LLM Wrapper as BaseAction type. | ||||
|  | ||||
|     Args: | ||||
|         llm (BaseModel or BaseAPIModel): a LLM service which can chat. | ||||
|         description (str): The description of the action. Defaults to | ||||
|             None. | ||||
|         name (str, optional): The name of the action. If None, the name will | ||||
|             be class name. Defaults to None. | ||||
|         enable (bool, optional): Whether the action is enabled. Defaults to | ||||
|             True. | ||||
|         disable_description (str, optional): The description of the action when | ||||
|             it is disabled. Defaults to None. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, | ||||
|                  llm: Union[BaseModel, BaseAPIModel], | ||||
|                  description: str = DEFAULT_DESCRIPTION, | ||||
|                  name: Optional[str] = None, | ||||
|                  enable: bool = True, | ||||
|                  disable_description: Optional[str] = None) -> None: | ||||
|         super().__init__(description, name, enable, disable_description) | ||||
|  | ||||
|         self._llm = llm | ||||
|  | ||||
|     def __call__(self, query: str) -> ActionReturn: | ||||
|         """Return the QA response. | ||||
|  | ||||
|         Args: | ||||
|             query (str): The query content. | ||||
|  | ||||
|         Returns: | ||||
|             ActionReturn: The action return. | ||||
|         """ | ||||
|  | ||||
|         tool_return = ActionReturn(url=None, args=None) | ||||
|         try: | ||||
|             response = self._llm.generate_from_template(query, 512) | ||||
|             tool_return.result = dict(text=str(response)) | ||||
|             tool_return.state = ActionStatusCode.SUCCESS | ||||
|         except Exception as e: | ||||
|             tool_return.result = dict(text=str(e)) | ||||
|             tool_return.state = ActionStatusCode.API_ERROR | ||||
|         return tool_return | ||||
							
								
								
									
139 lagent/actions/parser.py Normal file
							| @@ -0,0 +1,139 @@ | ||||
| import json | ||||
| import re | ||||
| from ast import literal_eval | ||||
| from typing import Any, List, Union | ||||
|  | ||||
|  | ||||
| class ParseError(Exception): | ||||
|     """Parsing exception class""" | ||||
|  | ||||
|     def __init__(self, err_msg: str): | ||||
|         self.err_msg = err_msg | ||||
|  | ||||
|  | ||||
| class BaseParser: | ||||
|     """Base parser to process inputs and outputs of actions. | ||||
|  | ||||
|     Args: | ||||
|         action (:class:`BaseAction`): action to validate | ||||
|  | ||||
|     Attributes: | ||||
|         PARAMETER_DESCRIPTION (:class:`str`): declare the input format which | ||||
|             LLMs should follow when generating arguments for decided tools. | ||||
|     """ | ||||
|  | ||||
|     PARAMETER_DESCRIPTION: str = '' | ||||
|  | ||||
|     def __init__(self, action): | ||||
|         self.action = action | ||||
|         self._api2param = {} | ||||
|         self._api2required = {} | ||||
|         # perform basic argument validation | ||||
|         if action.description: | ||||
|             for api in action.description.get('api_list', | ||||
|                                               [action.description]): | ||||
|                 name = (f'{action.name}.{api["name"]}' | ||||
|                         if self.action.is_toolkit else api['name']) | ||||
|                 required_parameters = set(api['required']) | ||||
|                 all_parameters = {j['name'] for j in api['parameters']} | ||||
|                 if not required_parameters.issubset(all_parameters): | ||||
|                     raise ValueError( | ||||
|                         f'unknown parameters for function "{name}": ' | ||||
|                         f'{required_parameters - all_parameters}') | ||||
|                 if self.PARAMETER_DESCRIPTION: | ||||
|                     api['parameter_description'] = self.PARAMETER_DESCRIPTION | ||||
|                 api_name = api['name'] if self.action.is_toolkit else 'run' | ||||
|                 self._api2param[api_name] = api['parameters'] | ||||
|                 self._api2required[api_name] = api['required'] | ||||
|  | ||||
|     def parse_inputs(self, inputs: str, name: str = 'run') -> dict: | ||||
|         """parse inputs LLMs generate for the action | ||||
|  | ||||
|         Args: | ||||
|             inputs (:class:`str`): input string extracted from responses | ||||
|              | ||||
|         Returns: | ||||
|             :class:`dict`: processed input | ||||
|         """ | ||||
|         inputs = {self._api2param[name][0]['name']: inputs} | ||||
|         return inputs | ||||
|  | ||||
|     def parse_outputs(self, outputs: Any) -> List[dict]: | ||||
|         """parser outputs returned by the action | ||||
|  | ||||
|         Args: | ||||
|             outputs (:class:`Any`): raw output of the action | ||||
|  | ||||
|         Returns: | ||||
|             :class:`List[dict]`: processed output of which each member is a  | ||||
|                 dictionary with two keys - 'type' and 'content'. | ||||
|         """ | ||||
|         if isinstance(outputs, dict): | ||||
|             outputs = json.dumps(outputs, ensure_ascii=False) | ||||
|         elif not isinstance(outputs, str): | ||||
|             outputs = str(outputs) | ||||
|         return [{'type': 'text', 'content': outputs}] | ||||
|  | ||||
|  | ||||
| class JsonParser(BaseParser): | ||||
|     """Json parser to convert input string into a dictionary. | ||||
|  | ||||
|     Args: | ||||
|         action (:class:`BaseAction`): action to validate | ||||
|     """ | ||||
|  | ||||
|     PARAMETER_DESCRIPTION = 'If you call this tool, you must pass arguments in JSON format {key: value}, where key is the parameter name' | ||||
|  | ||||
|     def parse_inputs(self, | ||||
|                      inputs: Union[str, dict], | ||||
|                      name: str = 'run') -> dict: | ||||
|         if not isinstance(inputs, dict): | ||||
|             try: | ||||
|                 match = re.search(r'^\s*(```json\n)?(.*)\n```\s*$', inputs, | ||||
|                                   re.S) | ||||
|                 if match: | ||||
|                     inputs = match.group(2).strip() | ||||
|                 inputs = json.loads(inputs) | ||||
|             except json.JSONDecodeError as exc: | ||||
|                 raise ParseError(f'invalid json format: {inputs}') from exc | ||||
|         input_keys = set(inputs) | ||||
|         all_keys = {param['name'] for param in self._api2param[name]} | ||||
|         if not input_keys.issubset(all_keys): | ||||
|             raise ParseError(f'unknown arguments: {input_keys - all_keys}') | ||||
|         required_keys = set(self._api2required[name]) | ||||
|         if not input_keys.issuperset(required_keys): | ||||
|             raise ParseError( | ||||
|                 f'missing required arguments: {required_keys - input_keys}') | ||||
|         return inputs | ||||
|  | ||||
|  | ||||
| class TupleParser(BaseParser): | ||||
|     """Tuple parser to convert input string into a tuple. | ||||
|  | ||||
|     Args: | ||||
|         action (:class:`BaseAction`): action to validate | ||||
|     """ | ||||
|  | ||||
|     PARAMETER_DESCRIPTION = 'If you call this tool, you must pass arguments in tuple format (arg1, arg2, arg3), and the arguments are ordered' | ||||
|  | ||||
|     def parse_inputs(self, | ||||
|                      inputs: Union[str, tuple], | ||||
|                      name: str = 'run') -> dict: | ||||
|         if not isinstance(inputs, tuple): | ||||
|             try: | ||||
|                 inputs = literal_eval(inputs) | ||||
|             except Exception as exc: | ||||
|                 raise ParseError(f'invalid tuple format: {inputs}') from exc | ||||
|         if len(inputs) < len(self._api2required[name]): | ||||
|             raise ParseError( | ||||
|                 f'API takes {len(self._api2required[name])} required positional ' | ||||
|                 f'arguments but {len(inputs)} were given') | ||||
|         if len(inputs) > len(self._api2param[name]): | ||||
|             raise ParseError( | ||||
|                 f'API takes {len(self._api2param[name])} positional arguments ' | ||||
|                 f'but {len(inputs)} were given') | ||||
|         inputs = { | ||||
|             self._api2param[name][i]['name']: item | ||||
|             for i, item in enumerate(inputs) | ||||
|         } | ||||
|         return inputs | ||||
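
The two concrete parsers accept the same action but different wire formats. A hedged comparison on `GoogleSearch.run(query, k)` (assuming `SERPER_API_KEY` is set so the action can be constructed, and that the parser instance is reachable as `_parser`, which is an assumption about `BaseAction` internals):

```python
from lagent.actions.google_search import GoogleSearch
from lagent.actions.parser import JsonParser, TupleParser

json_action = GoogleSearch(parser=JsonParser)
print(json_action._parser.parse_inputs('{"query": "hello", "k": 3}'))
# -> {'query': 'hello', 'k': 3}

tuple_action = GoogleSearch(parser=TupleParser)
print(tuple_action._parser.parse_inputs("('hello', 3)"))
# -> {'query': 'hello', 'k': 3}
```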
							
								
								
									
157 lagent/actions/ppt.py Normal file
							| @@ -0,0 +1,157 @@ | ||||
| from typing import Dict, Optional, Type | ||||
|  | ||||
| import PIL.Image | ||||
| from pptx import Presentation | ||||
|  | ||||
| from lagent.actions.base_action import BaseAction, tool_api | ||||
| from lagent.actions.parser import BaseParser, JsonParser | ||||
|  | ||||
| THEME_MAPPING = { | ||||
|     'Default': { | ||||
|         'template': None, | ||||
|         'title': 'Title Slide', | ||||
|         'single': 'Title and Content', | ||||
|         'two': 'Two Content', | ||||
|     } | ||||
| } | ||||
|  | ||||
|  | ||||
| class PPT(BaseAction): | ||||
|     """Plugin to create ppt slides with text, paragraph, images in good looking styles""" | ||||
|  | ||||
|     def __init__(self, | ||||
|                  theme_mapping: Optional[Dict[str, dict]] = None, | ||||
|                  description: Optional[dict] = None, | ||||
|                  parser: Type[BaseParser] = JsonParser, | ||||
|                  enable: bool = True): | ||||
|         super().__init__(description, parser, enable) | ||||
|         self.theme_mapping = theme_mapping or THEME_MAPPING | ||||
|         self.pointer = None | ||||
|         self.location = None | ||||
|  | ||||
|     @tool_api(explode_return=True) | ||||
|     def create_file(self, theme: str, abs_location: str) -> dict: | ||||
|         """Create a pptx file with specific themes | ||||
|  | ||||
|         Args: | ||||
|             theme (:class:`str`): the theme used | ||||
|             abs_location (:class:`str`): the ppt file's absolute location | ||||
|  | ||||
|         Returns: | ||||
|             :class:`dict`: operation status | ||||
|                 * status: the result of the execution | ||||
|         """ | ||||
|         self.location = abs_location | ||||
|         try: | ||||
|             self.pointer = Presentation(self.theme_mapping[theme]['template']) | ||||
|             self.pointer.slide_master.name = theme | ||||
|             # print('created') | ||||
|         except Exception as e: | ||||
|             print(e) | ||||
|         return dict(status='created a ppt file.') | ||||
|  | ||||
|     @tool_api(explode_return=True) | ||||
|     def add_first_page(self, title: str, subtitle: str) -> dict: | ||||
|         """Add the first page of ppt. | ||||
|  | ||||
|         Args: | ||||
|             title (:class:`str`): the title of ppt | ||||
|             subtitle (:class:`str`): the subtitle of ppt | ||||
|  | ||||
|         Returns: | ||||
|             :class:`dict`: operation status | ||||
|                 * status: the result of the execution | ||||
|         """ | ||||
|         layout_name = self.theme_mapping[ | ||||
|             self.pointer.slide_master.name]['title'] | ||||
|         layout = next(i for i in self.pointer.slide_master.slide_layouts | ||||
|                       if i.name == layout_name) | ||||
|         slide = self.pointer.slides.add_slide(layout) | ||||
|         ph_title, ph_subtitle = slide.placeholders | ||||
|         ph_title.text = title | ||||
|         if subtitle: | ||||
|             ph_subtitle.text = subtitle | ||||
|         return dict(status='added page') | ||||
|  | ||||
|     @tool_api(explode_return=True) | ||||
|     def add_text_page(self, title: str, bullet_items: str) -> dict: | ||||
|         """Add text page of ppt | ||||
|  | ||||
|         Args: | ||||
|             title (:class:`str`): the title of the page | ||||
|             bullet_items (:class:`str`): bullet_items should be string, for multiple bullet items, please use [SPAN] to separate them. | ||||
|  | ||||
|         Returns: | ||||
|             :class:`dict`: operation status | ||||
|                 * status: the result of the execution | ||||
|         """ | ||||
|         layout_name = self.theme_mapping[ | ||||
|             self.pointer.slide_master.name]['single'] | ||||
|         layout = next(i for i in self.pointer.slide_master.slide_layouts | ||||
|                       if i.name == layout_name) | ||||
|         slide = self.pointer.slides.add_slide(layout) | ||||
|         ph_title, ph_body = slide.placeholders | ||||
|         ph_title.text = title | ||||
|         ph = ph_body | ||||
|         tf = ph.text_frame | ||||
|         for i, item in enumerate(bullet_items.split('[SPAN]')): | ||||
|             if i == 0: | ||||
|                 p = tf.paragraphs[0] | ||||
|             else: | ||||
|                 p = tf.add_paragraph() | ||||
|             p.text = item.strip() | ||||
|             p.level = 0 | ||||
|         return dict(status='added page') | ||||
|  | ||||
|     @tool_api(explode_return=True) | ||||
|     def add_text_image_page(self, title: str, bullet_items: str, | ||||
|                             image: str) -> dict: | ||||
|         """Add a text page with one image. Image should be a path | ||||
|  | ||||
|         Args: | ||||
|             title (:class:`str`): the title of the page | ||||
|             bullet_items (:class:`str`): a string of bullet items; separate multiple items with [SPAN]. | ||||
|             image (:class:`str`): the path of the image | ||||
|  | ||||
|         Returns: | ||||
|             :class:`dict`: operation status | ||||
|                 * status: the result of the execution | ||||
|         """ | ||||
|         layout_name = self.theme_mapping[self.pointer.slide_master.name]['two'] | ||||
|         layout = next(i for i in self.pointer.slide_master.slide_layouts | ||||
|                       if i.name == layout_name) | ||||
|         slide = self.pointer.slides.add_slide(layout) | ||||
|         ph_title, ph_body1, ph_body2 = slide.placeholders | ||||
|         ph_title.text = title | ||||
|         ph = ph_body2 | ||||
|         from PIL import Image | ||||
|         image_pil = Image.open(image)  # `image` is a file path per the docstring | ||||
|         left = ph.left | ||||
|         width = ph.width | ||||
|         # scale the picture to the placeholder width, keeping aspect ratio | ||||
|         height = int(width / image_pil.width * image_pil.height) | ||||
|         # vertically center the picture within the placeholder | ||||
|         top = (ph.top + (ph.top + ph.height)) // 2 - height // 2 | ||||
|         slide.shapes.add_picture(image, left, top, width, height) | ||||
|  | ||||
|         ph = ph_body1 | ||||
|         tf = ph.text_frame | ||||
|         for i, item in enumerate(bullet_items.split('[SPAN]')): | ||||
|             if i == 0: | ||||
|                 p = tf.paragraphs[0] | ||||
|             else: | ||||
|                 p = tf.add_paragraph() | ||||
|             p.text = item.strip() | ||||
|             p.level = 0 | ||||
|  | ||||
|         return dict(status='added page') | ||||
|  | ||||
|     @tool_api(explode_return=True) | ||||
|     def submit_file(self) -> dict: | ||||
|         """When all steps done, YOU MUST use submit_file() to submit your work. | ||||
|  | ||||
|         Returns: | ||||
|             :class:`dict`: operation status | ||||
|                 * status: the result of the execution | ||||
|         """ | ||||
|         # file_path = os.path.join(self.CACHE_DIR, f'{self._return_timestamp()}.pptx') | ||||
|         # self.pointer.save(file_path) | ||||
|         # retrieval_url = upload_file(file_path) | ||||
|         self.pointer.save(self.location) | ||||
|         return dict(status=f'submitted. view ppt at {self.location}') | ||||
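Taken together, the APIs above form a create → add pages → submit workflow. Below is a minimal usage sketch; the `PPT` class name, its import path, and the `'Default'` theme key are assumptions, since the class header and `THEME_MAPPING` fall outside this hunk.

```python
# Hypothetical usage of the slide-building action above; the import path,
# class name and theme key are assumptions, not part of this diff.
from lagent.actions.ppt import PPT

ppt = PPT()
ppt.create_file(theme='Default', abs_location='/tmp/demo.pptx')
ppt.add_first_page(title='Quarterly Review', subtitle='Q3 highlights')
# Multiple bullets travel in a single string, separated by [SPAN].
ppt.add_text_page(
    title='Key Results',
    bullet_items='Revenue up 12%[SPAN]Churn down 3%[SPAN]Two new regions')
ppt.submit_file()  # saves to the abs_location recorded by create_file
```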
| @@ -1,11 +1,12 @@ | ||||
| import copy | ||||
| import io | ||||
| from contextlib import redirect_stdout | ||||
| from typing import Any, Optional | ||||
| from typing import Any, Optional, Type | ||||
|  | ||||
| from func_timeout import FunctionTimedOut, func_set_timeout | ||||
|  | ||||
| from lagent.actions.base_action import BaseAction | ||||
| from lagent.actions.base_action import BaseAction, tool_api | ||||
| from lagent.actions.parser import BaseParser, JsonParser | ||||
| from lagent.schema import ActionReturn, ActionStatusCode | ||||
|  | ||||
|  | ||||
| @@ -29,72 +30,70 @@ class GenericRuntime: | ||||
|         return eval(expr, self._global_vars) | ||||
|  | ||||
|  | ||||
| DEFAULT_DESCRIPTION = """用来执行Python代码。代码必须是一个函数, | ||||
| 函数名必须得是 'solution',代码对应你的思考过程。代码实例格式如下: | ||||
| ```python | ||||
| # import 依赖包 | ||||
| import xxx | ||||
| def solution(): | ||||
|     # 初始化一些变量 | ||||
|     variable_names_with_real_meaning = xxx | ||||
|     # 步骤一 | ||||
|     mid_variable = func(variable_names_with_real_meaning) | ||||
|     # 步骤 x | ||||
|     mid_variable = func(mid_variable) | ||||
|     # 最后结果 | ||||
|     final_answer =  func(mid_variable) | ||||
|     return final_answer | ||||
| ```""" | ||||
|  | ||||
|  | ||||
| class PythonInterpreter(BaseAction): | ||||
|     """A Python executor that can execute Python scripts. | ||||
|  | ||||
|     Args: | ||||
|         description (str): The description of the action. Defaults to | ||||
|             DEFAULT_DESCRIPTION. | ||||
|         answer_symbol (str, Optional): the answer symbol from LLM | ||||
|         answer_symbol (str, Optional): the answer symbol from LLM. Defaults to ``None``. | ||||
|         answer_expr (str, Optional): the answer function name of the Python | ||||
|             script. Default to 'solution()'. | ||||
|         answer_from_stdout (boolean): whether the execution results is from | ||||
|             stdout. | ||||
|         name (str, optional): The name of the action. If None, the name will | ||||
|             be the class name. Defaults to None. | ||||
|             script. Defaults to ``'solution()'``. | ||||
|         answer_from_stdout (boolean, Optional): whether the execution result is | ||||
|             read from stdout. Defaults to ``False``. | ||||
|         timeout (int, Optional): Upper bound of waiting time for Python script execution. | ||||
|             Defaults to ``20``. | ||||
|         description (dict, Optional): The description of the action. Defaults to ``None``. | ||||
|         parser (Type[BaseParser]): The parser class to process the | ||||
|             action's inputs and outputs. Defaults to :class:`JsonParser`. | ||||
|         enable (bool, optional): Whether the action is enabled. Defaults to | ||||
|             True. | ||||
|         disable_description (str, optional): The description of the action when | ||||
|             it is disabled. Defaults to None. | ||||
|         timeout (int): Upper bound of waiting time for Python script execution. | ||||
|             ``True``. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, | ||||
|                  description: str = DEFAULT_DESCRIPTION, | ||||
|                  answer_symbol: Optional[str] = None, | ||||
|                  answer_expr: Optional[str] = 'solution()', | ||||
|                  answer_from_stdout: bool = False, | ||||
|                  name: Optional[str] = None, | ||||
|                  enable: bool = True, | ||||
|                  disable_description: Optional[str] = None, | ||||
|                  timeout: int = 20) -> None: | ||||
|         super().__init__(description, name, enable, disable_description) | ||||
|  | ||||
|                  timeout: int = 20, | ||||
|                  description: Optional[dict] = None, | ||||
|                  parser: Type[BaseParser] = JsonParser, | ||||
|                  enable: bool = True) -> None: | ||||
|         super().__init__(description, parser, enable) | ||||
|         self.answer_symbol = answer_symbol | ||||
|         self.answer_expr = answer_expr | ||||
|         self.answer_from_stdout = answer_from_stdout | ||||
|         self.timeout = timeout | ||||
|  | ||||
|     def __call__(self, command: str) -> ActionReturn: | ||||
|     @tool_api | ||||
|     def run(self, command: str) -> ActionReturn: | ||||
|         """用来执行Python代码。代码必须是一个函数,函数名必须得是 'solution',代码对应你的思考过程。代码实例格式如下: | ||||
|         ```python | ||||
|         # import 依赖包 | ||||
|         import xxx | ||||
|         def solution(): | ||||
|             # 初始化一些变量 | ||||
|             variable_names_with_real_meaning = xxx | ||||
|             # 步骤一 | ||||
|             mid_variable = func(variable_names_with_real_meaning) | ||||
|             # 步骤 x | ||||
|             mid_variable = func(mid_variable) | ||||
|             # 最后结果 | ||||
|             final_answer =  func(mid_variable) | ||||
|             return final_answer | ||||
|         ``` | ||||
|  | ||||
|         Args: | ||||
|             command (:class:`str`): Python code snippet | ||||
|         """ | ||||
|         self.runtime = GenericRuntime() | ||||
|         try: | ||||
|             tool_return = func_set_timeout(self.timeout)(self._call)(command) | ||||
|         except FunctionTimedOut as e: | ||||
|             tool_return = ActionReturn(url=None, args=None, type=self.name) | ||||
|             tool_return = ActionReturn(type=self.name) | ||||
|             tool_return.errmsg = repr(e) | ||||
|             tool_return.state = ActionStatusCode.API_ERROR | ||||
|         return tool_return | ||||
|  | ||||
|     def _call(self, command: str) -> ActionReturn: | ||||
|         tool_return = ActionReturn(url=None, args=None, type=self.name) | ||||
|         tool_return = ActionReturn(type=self.name) | ||||
|         try: | ||||
|             if '```python' in command: | ||||
|                 command = command.split('```python')[1].split('```')[0] | ||||
|   | ||||
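The rework above folds the tool description into the `run` docstring and enforces a timeout via `func_set_timeout`. A minimal sketch of driving it directly follows; the import path is an assumption, and in agent use the action is normally invoked through an `ActionExecutor`:

```python
# Sketch: calling the interpreter directly. The command must define a
# `solution()` function, as the tool's docstring requires; `_call` also
# strips an optional ```python fence before executing.
from lagent.actions.python_interpreter import PythonInterpreter  # assumed path

interpreter = PythonInterpreter(timeout=20)
command = '''
def solution():
    # sum of the first 100 positive integers
    return sum(range(1, 101))
'''
action_return = interpreter.run(command)  # an ActionReturn
print(action_return.state, action_return.result)
```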
| @@ -1,6 +1,5 @@ | ||||
| # flake8: noqa | ||||
| import ast | ||||
| import copy | ||||
| import platform | ||||
| from typing import Dict, List, Optional, Tuple, Union | ||||
|  | ||||
| @@ -220,21 +219,20 @@ class AutoGPTProtocol: | ||||
|             dict(role='user', content=self.triggering_prompt)) | ||||
|         return formatted_data | ||||
|  | ||||
|     def format_response(self, action_return): | ||||
|     def format_response(self, action_return) -> dict: | ||||
|         """format the final response at current step. | ||||
|  | ||||
|         Args: | ||||
|             action_return (ActionReturn): return value of the current action. | ||||
|  | ||||
|         Returns: | ||||
|             str: the final response at current step. | ||||
|             dict: the final response at current step. | ||||
|         """ | ||||
|         if action_return.state == ActionStatusCode.SUCCESS: | ||||
|             response = action_return.result['text'] | ||||
|             response = f'Command {action_return.type} returned: {response}' | ||||
|             response = f'Command {action_return.type} returned: {action_return.format_result()}' | ||||
|         else: | ||||
|             response = action_return.errmsg | ||||
|         return response | ||||
|         return dict(role='system', content=response) | ||||
|  | ||||
|  | ||||
| class AutoGPT(BaseAgent): | ||||
| @@ -261,30 +259,26 @@ class AutoGPT(BaseAgent): | ||||
|         super().__init__( | ||||
|             llm=llm, action_executor=action_executor, protocol=protocol) | ||||
|  | ||||
|     def chat(self, goal: str) -> AgentReturn: | ||||
|         self._inner_history = [] | ||||
|     def chat(self, goal: str, **kwargs) -> AgentReturn: | ||||
|         inner_history = [] | ||||
|         agent_return = AgentReturn() | ||||
|         default_response = 'Sorry that I cannot answer your question.' | ||||
|         for _ in range(self.max_turn): | ||||
|             prompt = self._protocol.format( | ||||
|                 goal=goal, | ||||
|                 inner_history=self._inner_history, | ||||
|                 inner_history=inner_history, | ||||
|                 action_executor=self._action_executor) | ||||
|             response = self._llm.generate_from_template(prompt, 512) | ||||
|             self._inner_history.append( | ||||
|                 dict(role='assistant', content=response)) | ||||
|             response = self._llm.chat(prompt, **kwargs) | ||||
|             inner_history.append(dict(role='assistant', content=response)) | ||||
|             action, action_input = self._protocol.parse( | ||||
|                 response, self._action_executor) | ||||
|             action_return: ActionReturn = self._action_executor( | ||||
|                 action, action_input) | ||||
|             agent_return.actions.append(action_return) | ||||
|             if action_return.type == self._action_executor.finish_action.name: | ||||
|                 agent_return.response = action_return.result['text'] | ||||
|                 agent_return.response = action_return.format_result() | ||||
|                 return agent_return | ||||
|             self._inner_history.append( | ||||
|                 dict( | ||||
|                     role='system', | ||||
|                     content=self._protocol.format_response(action_return))) | ||||
|         agent_return.inner_steps = copy.deepcopy(self._inner_history) | ||||
|             inner_history.append(self._protocol.format_response(action_return)) | ||||
|         agent_return.inner_steps = inner_history | ||||
|         agent_return.response = default_response | ||||
|         return agent_return | ||||
|   | ||||
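With `chat` now keeping its scratchpad in a local `inner_history` and delegating generation to `llm.chat`, a run looks like the sketch below; the `AutoGPTProtocol` default and the concrete action are assumptions outside this hunk:

```python
# Hedged sketch of driving the reworked AutoGPT loop; constructor defaults
# (protocol, max_turn) are assumed from context, not shown in this diff.
from lagent.agents.autogpt import AutoGPT, AutoGPTProtocol
from lagent.actions import ActionExecutor, GoogleSearch
from lagent.llms import GPTAPI

agent = AutoGPT(
    llm=GPTAPI(model_type='gpt-3.5-turbo'),
    action_executor=ActionExecutor(actions=[GoogleSearch(api_key='...')]),
    protocol=AutoGPTProtocol())
ret = agent.chat('Find the latest stable release of python-pptx')
print(ret.response)
```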
| @@ -1,5 +1,3 @@ | ||||
| from typing import List | ||||
|  | ||||
| from lagent.actions import ActionExecutor | ||||
| from lagent.actions.base_action import BaseAction | ||||
| from lagent.llms.base_llm import BaseModel | ||||
| @@ -19,8 +17,6 @@ class BaseAgent: | ||||
|  | ||||
|     def __init__(self, llm: BaseModel, action_executor: ActionExecutor, | ||||
|                  protocol: object) -> None: | ||||
|  | ||||
|         self._session_history = [] | ||||
|         self._llm = llm | ||||
|         self._action_executor = action_executor | ||||
|         self._protocol = protocol | ||||
| @@ -41,9 +37,5 @@ class BaseAgent: | ||||
|         """ | ||||
|         self._action_executor.del_action(name) | ||||
|  | ||||
|     def chat(self, message: str) -> AgentReturn: | ||||
|     def chat(self, message: str, **kwargs) -> AgentReturn: | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     @property | ||||
|     def session_history(self) -> List: | ||||
|         return self._session_history | ||||
|   | ||||
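After this change `BaseAgent` no longer tracks a session history; subclasses own their message bookkeeping and only need to implement `chat(message, **kwargs)`. A minimal hypothetical subclass, for illustration only:

```python
# Hypothetical subclass showing the slimmed-down BaseAgent contract;
# EchoAgent is not part of this diff.
from lagent.agents.base_agent import BaseAgent
from lagent.schema import AgentReturn

class EchoAgent(BaseAgent):

    def chat(self, message: str, **kwargs) -> AgentReturn:
        agent_return = AgentReturn()
        # delegate generation to the wrapped LLM; no session state kept here
        agent_return.response = self._llm.chat(
            [dict(role='user', content=message)], **kwargs)
        return agent_return
```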
							
								
								
									
363  lagent/agents/internlm2_agent.py  Normal file
							| @@ -0,0 +1,363 @@ | ||||
| import json | ||||
| import logging | ||||
| from copy import deepcopy | ||||
| from typing import Dict, List, Optional, Union | ||||
|  | ||||
| from lagent import BaseAgent | ||||
| from lagent.actions import ActionExecutor | ||||
| from lagent.llms import BaseAPIModel, BaseModel | ||||
| from lagent.schema import (ActionReturn, ActionStatusCode, AgentReturn, | ||||
|                            AgentStatusCode) | ||||
|  | ||||
| API_PREFIX = ( | ||||
|     "This is the subfunction for tool '{tool_name}', you can use this tool. " | ||||
|     'The description of this function is: \n{description}') | ||||
|  | ||||
| META_INS = ('You are InternLM, a large language model trained by PJLab. ' | ||||
|             'Answer as concisely as possible. ' | ||||
|             '当开启工具以及代码时,根据需求选择合适的工具进行调用') | ||||
|  | ||||
| INTERPRETER_CN = ('你现在可以通过如下格式向 Jupyter Notebook 发送并执行代码:' | ||||
|                   '\n<|action_start|><|interpreter|>```python\n\n代码\n\n```\n' | ||||
|                   '\n当遇到以下问题时,请使用上述格式调用 Jupyter Notebook 去解决,并根据执行结果做出友好的回复:\n' | ||||
|                   '1. 文件操作和数据导入,比如处理CSV、JSON等格式文件\n' | ||||
|                   '2. 数据分析或处理,比如数据操作或图像绘制如折线图、柱状图等\n' | ||||
|                   '3. 数学相关的问题。当遇到数学问题时,你需要分析题目,并给出代码去解决这个题目') | ||||
|  | ||||
| PLUGIN_CN = ( | ||||
|     '你可以使用如下工具:' | ||||
|     '\n{prompt}\n' | ||||
|     '当你需要使用工具时,你可以使用如下格式:\n' | ||||
|     '<|action_start|><|plugin|>{{"name": "工具名称", "parameters": {{参数}}}}\n' | ||||
|     '如果你已经获得足够信息,请直接给出答案. 避免不必要的工具调用! ' | ||||
|     '同时注意你可以使用的工具,不要随意捏造!') | ||||
|  | ||||
|  | ||||
| class Internlm2Protocol: | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         meta_prompt: str = META_INS, | ||||
|         interpreter_prompt: str = INTERPRETER_CN, | ||||
|         plugin_prompt: str = PLUGIN_CN, | ||||
|         few_shot: Optional[List] = None, | ||||
|         language: Dict = dict( | ||||
|             begin='', | ||||
|             end='', | ||||
|             belong='assistant', | ||||
|         ), | ||||
|         tool: Dict = dict( | ||||
|             begin='{start_token}{name}\n', | ||||
|             start_token='<|action_start|>', | ||||
|             name_map=dict(plugin='<|plugin|>', interpreter='<|interpreter|>'), | ||||
|             belong='assistant', | ||||
|             end='<|action_end|>\n', | ||||
|         ), | ||||
|         execute: Dict = dict( | ||||
|             role='execute', begin='', end='', fallback_role='environment'), | ||||
|     ) -> None: | ||||
|         self.meta_prompt = meta_prompt | ||||
|         self.interpreter_prompt = interpreter_prompt | ||||
|         self.plugin_prompt = plugin_prompt | ||||
|         self.roles_cfg = dict(tool=tool, language=language) | ||||
|         self.language = language | ||||
|         self.execute = execute | ||||
|         self.tool = tool | ||||
|         self.few_shot = few_shot | ||||
|  | ||||
|     def format_sub_role(self, messages: List[Dict]) -> List[Dict]: | ||||
|  | ||||
|         def format_interpreter(message): | ||||
|             if isinstance(message['content'], dict): | ||||
|                 # assert message['content']['name'] == 'IPythonInterpreter' | ||||
|                 return dict( | ||||
|                     role=message['role'], | ||||
|                     name=message['name'], | ||||
|                     content=message['content']['parameters']['command']) | ||||
|             else: | ||||
|                 return message | ||||
|  | ||||
|         def format_plugin(message): | ||||
|             if isinstance(message['content'], dict): | ||||
|                 return dict( | ||||
|                     role=message['role'], | ||||
|                     name=message['name'], | ||||
|                     content=json.dumps(message['content'])) | ||||
|             else: | ||||
|                 return message | ||||
|  | ||||
|         new_message = list() | ||||
|         for message in messages: | ||||
|             if message['role'] in [ | ||||
|                     'assistant', 'user', 'system', 'environment' | ||||
|             ]: | ||||
|                 new_message.append(message) | ||||
|                 continue | ||||
|             role_cfg = self.roles_cfg[message['role']] | ||||
|             begin = role_cfg['begin'] | ||||
|             if message['role'] == 'tool': | ||||
|                 if message['name'] == 'interpreter': | ||||
|                     message = format_interpreter(message) | ||||
|                 elif message['name'] == 'plugin': | ||||
|                     message = format_plugin(message) | ||||
|                 else: | ||||
|                     raise NotImplementedError | ||||
|                 begin = role_cfg['begin'].format( | ||||
|                     start_token=role_cfg.get('start_token', ''), | ||||
|                     name=role_cfg.get('name_map', {}).get(message['name'], '')) | ||||
|             new_content = begin + message['content'] + role_cfg['end'] | ||||
|             if role_cfg.get('fallback_role'): | ||||
|                 new_message.append( | ||||
|                     dict(role=role_cfg['fallback_role'], content=new_content)) | ||||
|             elif role_cfg.get('belong'): | ||||
|                 if new_message[-1]['role'] != role_cfg.get('belong'): | ||||
|                     new_message.append( | ||||
|                         dict(role=role_cfg.get('belong'), content=new_content)) | ||||
|                 else: | ||||
|                     new_message[-1]['content'] += new_content | ||||
|             else: | ||||
|                 new_message.append( | ||||
|                     dict(role=message['role'], content=new_content)) | ||||
|  | ||||
|         return new_message | ||||
|  | ||||
|     def format(self, | ||||
|                inner_step: List[Dict], | ||||
|                plugin_executor: ActionExecutor = None, | ||||
|                interpreter_executor: ActionExecutor = None, | ||||
|                **kwargs) -> list: | ||||
|         formatted = [] | ||||
|         if self.meta_prompt: | ||||
|             formatted.append(dict(role='system', content=self.meta_prompt)) | ||||
|         if interpreter_executor and self.interpreter_prompt: | ||||
|             interpreter_info = interpreter_executor.get_actions_info()[0] | ||||
|             interpreter_prompt = self.interpreter_prompt.format( | ||||
|                 code_prompt=interpreter_info['description']) | ||||
|             formatted.append( | ||||
|                 dict( | ||||
|                     role='system', | ||||
|                     content=interpreter_prompt, | ||||
|                     name='interpreter')) | ||||
|         if plugin_executor and plugin_executor.actions and self.plugin_prompt: | ||||
|             plugin_descriptions = [] | ||||
|             for api_info in plugin_executor.get_actions_info(): | ||||
|                 plugin = deepcopy(api_info) | ||||
|                 if isinstance(api_info, dict): | ||||
|                     tool_name = api_info['name'].split('.')[0] | ||||
|                     plugin['description'] = API_PREFIX.format( | ||||
|                         tool_name=tool_name, description=plugin['description']) | ||||
|                 plugin_descriptions.append(plugin) | ||||
|             plugin_prompt = self.plugin_prompt.format( | ||||
|                 prompt=json.dumps( | ||||
|                     plugin_descriptions, ensure_ascii=False, indent=4)) | ||||
|             formatted.append( | ||||
|                 dict(role='system', content=plugin_prompt, name='plugin')) | ||||
|         if self.few_shot: | ||||
|             for few_shot in self.few_shot: | ||||
|                 formatted += self.format_sub_role(few_shot) | ||||
|         formatted += self.format_sub_role(inner_step) | ||||
|         return formatted | ||||
|  | ||||
|     def parse(self, message, plugin_executor: ActionExecutor, | ||||
|               interpreter_executor: ActionExecutor): | ||||
|         if self.language['begin']: | ||||
|             message = message.split(self.language['begin'])[-1] | ||||
|         if self.tool['name_map']['plugin'] in message: | ||||
|             message, action = message.split( | ||||
|                 f"{self.tool['start_token']}{self.tool['name_map']['plugin']}") | ||||
|             action = action.split(self.tool['end'].strip())[0] | ||||
|             return 'plugin', message, action | ||||
|         if self.tool['name_map']['interpreter'] in message: | ||||
|             message, code = message.split( | ||||
|                 f"{self.tool['start_token']}" | ||||
|                 f"{self.tool['name_map']['interpreter']}") | ||||
|             code = code.split(self.tool['end'].strip())[0].strip() | ||||
|             return 'interpreter', message, dict( | ||||
|                 name=interpreter_executor.action_names()[0], | ||||
|                 parameters=dict(command=code)) | ||||
|         return None, message, None | ||||
|  | ||||
|     def format_response(self, action_return, name) -> dict: | ||||
|         if action_return.state == ActionStatusCode.SUCCESS: | ||||
|             response = action_return.format_result() | ||||
|         else: | ||||
|             response = action_return.errmsg | ||||
|         content = self.execute['begin'] + response + self.execute['end'] | ||||
|         if self.execute.get('fallback_role'): | ||||
|             return dict( | ||||
|                 role=self.execute['fallback_role'], content=content, name=name) | ||||
|         elif self.execute.get('belong'): | ||||
|             return dict( | ||||
|                 role=self.execute['belong'], content=content, name=name) | ||||
|         return dict(role=self.execute['role'], content=content, name=name) | ||||
|  | ||||
|  | ||||
| class Internlm2Agent(BaseAgent): | ||||
|  | ||||
|     def __init__(self, | ||||
|                  llm: Union[BaseModel, BaseAPIModel], | ||||
|                  plugin_executor: ActionExecutor = None, | ||||
|                  interpreter_executor: ActionExecutor = None, | ||||
|                  protocol=Internlm2Protocol(), | ||||
|                  max_turn: int = 3) -> None: | ||||
|         self.max_turn = max_turn | ||||
|         self._interpreter_executor = interpreter_executor | ||||
|         super().__init__( | ||||
|             llm=llm, action_executor=plugin_executor, protocol=protocol) | ||||
|  | ||||
|     def chat(self, message: Union[str, Dict, List[Dict]], | ||||
|              **kwargs) -> AgentReturn: | ||||
|         if isinstance(message, str): | ||||
|             message = dict(role='user', content=message) | ||||
|         if isinstance(message, dict): | ||||
|             message = [message] | ||||
|         inner_history = message[:] | ||||
|         offset = len(inner_history) | ||||
|         agent_return = AgentReturn() | ||||
|         for _ in range(self.max_turn): | ||||
|             # list of dict | ||||
|             prompt = self._protocol.format( | ||||
|                 inner_step=inner_history, | ||||
|                 plugin_executor=self._action_executor, | ||||
|                 interpreter_executor=self._interpreter_executor, | ||||
|             ) | ||||
|             response = self._llm.chat(prompt, **kwargs) | ||||
|             name, language, action = self._protocol.parse( | ||||
|                 message=response, | ||||
|                 plugin_executor=self._action_executor, | ||||
|                 interpreter_executor=self._interpreter_executor, | ||||
|             ) | ||||
|             if name: | ||||
|                 if name == 'plugin': | ||||
|                     if self._action_executor: | ||||
|                         executor = self._action_executor | ||||
|                     else: | ||||
|                         logging.info(msg='No plugin is instantiated!') | ||||
|                         continue | ||||
|                     try: | ||||
|                         action = json.loads(action) | ||||
|                     except Exception as e: | ||||
|                         logging.info( | ||||
|                             msg=f'Invalid action: {e}') | ||||
|                         continue | ||||
|                 elif name == 'interpreter': | ||||
|                     if self._interpreter_executor: | ||||
|                         executor = self._interpreter_executor | ||||
|                     else: | ||||
|                         logging.info(msg='No interpreter is instantiated!') | ||||
|                         continue | ||||
|                 else: | ||||
|                     logging.info( | ||||
|                         msg=(f"Invalid name '{name}'. Currently only 'plugin' " | ||||
|                              "and 'interpreter' are supported.")) | ||||
|                     continue | ||||
|                 action_return: ActionReturn = executor(action['name'], | ||||
|                                                        action['parameters']) | ||||
|                 action_return.thought = language | ||||
|                 agent_return.actions.append(action_return) | ||||
|             inner_history.append(dict(role='language', content=language)) | ||||
|             if not name or action_return.type == executor.finish_action.name: | ||||
|                 agent_return.response = language | ||||
|                 agent_return.state = AgentStatusCode.END | ||||
|                 break | ||||
|             else: | ||||
|                 inner_history.append( | ||||
|                     dict(role='tool', content=action, name=name)) | ||||
|                 inner_history.append( | ||||
|                     self._protocol.format_response(action_return, name=name)) | ||||
|             yield agent_return | ||||
|         agent_return.inner_steps = inner_history[offset:] | ||||
|         yield agent_return | ||||
|  | ||||
|     def stream_chat(self, message: Union[str, Dict, List[dict]], | ||||
|                     **kwargs) -> AgentReturn: | ||||
|         if isinstance(message, str): | ||||
|             message = dict(role='user', content=message) | ||||
|         if isinstance(message, dict): | ||||
|             message = [message] | ||||
|         inner_history = message[:] | ||||
|         offset = len(inner_history) | ||||
|         agent_return = AgentReturn() | ||||
|         last_agent_state = AgentStatusCode.SESSION_READY | ||||
|         for _ in range(self.max_turn): | ||||
|             # list of dict | ||||
|             prompt = self._protocol.format( | ||||
|                 inner_step=inner_history, | ||||
|                 plugin_executor=self._action_executor, | ||||
|                 interpreter_executor=self._interpreter_executor, | ||||
|             ) | ||||
|             response = '' | ||||
|             for model_state, res, _ in self._llm.stream_chat( | ||||
|                     prompt, **kwargs): | ||||
|                 response = res | ||||
|                 if model_state.value < 0: | ||||
|                     agent_return.state = model_state | ||||
|                     yield deepcopy(agent_return) | ||||
|                     return | ||||
|                 else: | ||||
|                     name, language, action = self._protocol.parse( | ||||
|                         message=response, | ||||
|                         plugin_executor=self._action_executor, | ||||
|                         interpreter_executor=self._interpreter_executor, | ||||
|                     ) | ||||
|                     if name: | ||||
|                         if model_state == AgentStatusCode.END: | ||||
|                             agent_state = last_agent_state + 1 | ||||
|                             if name == 'plugin': | ||||
|                                 if self._action_executor: | ||||
|                                     executor = self._action_executor | ||||
|                                 else: | ||||
|                                     logging.info( | ||||
|                                         msg='No plugin is instantiated!') | ||||
|                                     continue | ||||
|                                 try: | ||||
|                                     action = json.loads(action) | ||||
|                                 except Exception as e: | ||||
|                                     logging.info( | ||||
|                                         msg=f'Invalid action: {e}') | ||||
|                                     continue | ||||
|                             elif name == 'interpreter': | ||||
|                                 if self._interpreter_executor: | ||||
|                                     executor = self._interpreter_executor | ||||
|                                 else: | ||||
|                                     logging.info( | ||||
|                                         msg='No interpreter is instantiated!') | ||||
|                                     continue | ||||
|                             agent_return.state = agent_state | ||||
|                             agent_return.response = action | ||||
|                         else: | ||||
|                             agent_state = ( | ||||
|                                 AgentStatusCode.PLUGIN_START if name | ||||
|                                 == 'plugin' else AgentStatusCode.CODING) | ||||
|                             if agent_state != last_agent_state: | ||||
|                                 # agent_return.state = agent_state | ||||
|                                 agent_return.response = language | ||||
|                                 yield deepcopy(agent_return) | ||||
|                             agent_return.state = agent_state | ||||
|                             agent_return.response = action | ||||
|                     else: | ||||
|                         agent_state = AgentStatusCode.STREAM_ING | ||||
|                         agent_return.state = agent_state | ||||
|                         agent_return.response = language | ||||
|                     last_agent_state = agent_state | ||||
|                     yield deepcopy(agent_return) | ||||
|             if name: | ||||
|                 action_return: ActionReturn = executor(action['name'], | ||||
|                                                        action['parameters']) | ||||
|                 action_return.thought = language | ||||
|                 agent_return.actions.append(action_return) | ||||
|             inner_history.append(dict(role='language', content=language)) | ||||
|             if not name or action_return.type == executor.finish_action.name: | ||||
|                 agent_return.response = language | ||||
|                 agent_return.state = AgentStatusCode.END | ||||
|                 break | ||||
|             else: | ||||
|                 inner_history.append( | ||||
|                     dict(role='tool', content=action, name=name)) | ||||
|                 inner_history.append( | ||||
|                     self._protocol.format_response(action_return, name=name)) | ||||
|                 agent_state += 1 | ||||
|                 agent_return.state = agent_state | ||||
|                 yield agent_return | ||||
|         agent_return.inner_steps = deepcopy(inner_history[offset:]) | ||||
|         agent_return.state = AgentStatusCode.END | ||||
|         yield agent_return | ||||
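The new agent couples a plugin executor with a dedicated interpreter executor and, in this version, `chat` is a generator that yields intermediate `AgentReturn` snapshots. A wiring sketch follows; the model path, stop words, and concrete actions are assumptions outside this file:

```python
# Hedged sketch of assembling the new Internlm2Agent; everything not shown
# in this diff (model path, stop words, concrete actions) is an assumption.
from lagent.actions import ActionExecutor, GoogleSearch, IPythonInterpreter
from lagent.agents.internlm2_agent import Internlm2Agent, Internlm2Protocol
from lagent.llms import HFTransformer

llm = HFTransformer(
    path='internlm/internlm2-chat-7b',
    stop_words=['<|im_end|>', '<|action_end|>'])
agent = Internlm2Agent(
    llm=llm,
    plugin_executor=ActionExecutor(actions=[GoogleSearch(api_key='...')]),
    interpreter_executor=ActionExecutor(actions=[IPythonInterpreter()]),
    protocol=Internlm2Protocol())

# `chat` yields snapshots; the final one carries inner_steps and state END.
for agent_return in agent.chat('Plot y = x^2 for x in [0, 10]'):
    print(agent_return.state, agent_return.response)
```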
| @@ -1,4 +1,3 @@ | ||||
| import copy | ||||
| from typing import Dict, List, Tuple, Union | ||||
|  | ||||
| from lagent.actions import ActionExecutor | ||||
| @@ -43,7 +42,7 @@ To use a tool, please use the following format: | ||||
| The response after utilizing tools should use the following format: | ||||
| ``` | ||||
| {response}the results after calling the tool. | ||||
| `` | ||||
| ``` | ||||
| If you already know the answer, or you do not need to use tools, | ||||
| please use the following format to reply: | ||||
| ``` | ||||
| @@ -170,20 +169,22 @@ class ReActProtocol: | ||||
|         action_input = arg_match[-1] | ||||
|         return thought, action.strip(), action_input.strip().strip('"') | ||||
|  | ||||
|     def format_response(self, action_return: ActionReturn) -> str: | ||||
|     def format_response(self, action_return: ActionReturn) -> dict: | ||||
|         """format the final response at current step. | ||||
|  | ||||
|         Args: | ||||
|             action_return (ActionReturn): return value of the current action. | ||||
|  | ||||
|         Returns: | ||||
|             str: the final response at current step. | ||||
|             dict: the final response at current step. | ||||
|         """ | ||||
|         if action_return.state == ActionStatusCode.SUCCESS: | ||||
|             response = action_return.result['text'] | ||||
|             response = action_return.format_result() | ||||
|         else: | ||||
|             response = action_return.errmsg | ||||
|         return self.response['begin'] + response + self.response['end'] | ||||
|         return dict( | ||||
|             role='system', | ||||
|             content=self.response['begin'] + response + self.response['end']) | ||||
|  | ||||
|  | ||||
| class ReAct(BaseAgent): | ||||
| @@ -210,20 +211,27 @@ class ReAct(BaseAgent): | ||||
|         super().__init__( | ||||
|             llm=llm, action_executor=action_executor, protocol=protocol) | ||||
|  | ||||
|     def chat(self, message: str) -> AgentReturn: | ||||
|         self._inner_history = [] | ||||
|         self._inner_history.append(dict(role='user', content=message)) | ||||
|     def chat(self, message: Union[str, dict, List[dict]], | ||||
|              **kwargs) -> AgentReturn: | ||||
|         if isinstance(message, str): | ||||
|             inner_history = [dict(role='user', content=message)] | ||||
|         elif isinstance(message, dict): | ||||
|             inner_history = [message] | ||||
|         elif isinstance(message, list): | ||||
|             inner_history = message[:] | ||||
|         else: | ||||
|             raise TypeError(f'unsupported type: {type(message)}') | ||||
|         offset = len(inner_history) | ||||
|         agent_return = AgentReturn() | ||||
|         default_response = 'Sorry that I cannot answer your question.' | ||||
|         for turn in range(self.max_turn): | ||||
|             prompt = self._protocol.format( | ||||
|                 chat_history=self.session_history, | ||||
|                 inner_step=self._inner_history, | ||||
|                 chat_history=[], | ||||
|                 inner_step=inner_history, | ||||
|                 action_executor=self._action_executor, | ||||
|                 force_stop=(turn == self.max_turn - 1)) | ||||
|             response = self._llm.generate_from_template(prompt, 512) | ||||
|             self._inner_history.append( | ||||
|                 dict(role='assistant', content=response)) | ||||
|             response = self._llm.chat(prompt, **kwargs) | ||||
|             inner_history.append(dict(role='assistant', content=response)) | ||||
|             thought, action, action_input = self._protocol.parse( | ||||
|                 response, self._action_executor) | ||||
|             action_return: ActionReturn = self._action_executor( | ||||
| @@ -231,17 +239,10 @@ class ReAct(BaseAgent): | ||||
|             action_return.thought = thought | ||||
|             agent_return.actions.append(action_return) | ||||
|             if action_return.type == self._action_executor.finish_action.name: | ||||
|                 agent_return.response = action_return.result['text'] | ||||
|                 agent_return.response = action_return.format_result() | ||||
|                 break | ||||
|             self._inner_history.append( | ||||
|                 dict( | ||||
|                     role='system', | ||||
|                     content=self._protocol.format_response(action_return))) | ||||
|             inner_history.append(self._protocol.format_response(action_return)) | ||||
|         else: | ||||
|             agent_return.response = default_response | ||||
|         agent_return.inner_steps = copy.deepcopy(self._inner_history) | ||||
|         # only append the user and final response | ||||
|         self._session_history.append(dict(role='user', content=message)) | ||||
|         self._session_history.append( | ||||
|             dict(role='assistant', content=agent_return.response)) | ||||
|         agent_return.inner_steps = inner_history[offset:] | ||||
|         return agent_return | ||||
|   | ||||
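Since `chat` now normalizes its input, the three call forms below are equivalent entry points; the backend model and the action are assumptions, not part of this hunk:

```python
# Hedged sketch of the widened ReAct.chat interface introduced above.
from lagent.agents import ReAct
from lagent.actions import ActionExecutor, PythonInterpreter
from lagent.llms import GPTAPI

agent = ReAct(
    llm=GPTAPI(model_type='gpt-3.5-turbo'),
    action_executor=ActionExecutor(actions=[PythonInterpreter()]))

# A plain string, a single message dict, or a message list all work now:
agent.chat('What is 2 ** 10?')
agent.chat(dict(role='user', content='What is 2 ** 10?'))
agent.chat([dict(role='user', content='What is 2 ** 10?')])
```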
| @@ -1,4 +1,3 @@ | ||||
| import copy | ||||
| import re | ||||
| import warnings | ||||
| from typing import Dict, List, Optional, Tuple, Union | ||||
| @@ -192,7 +191,7 @@ class ReWOOProtocol: | ||||
|         worker_log = '' | ||||
|         for thought, action_return in zip(thought_list, action_return_list): | ||||
|             if action_return.state == ActionStatusCode.SUCCESS: | ||||
|                 action_resp = action_return.result['text'] | ||||
|                 action_resp = action_return.format_result() | ||||
|             else: | ||||
|                 action_resp = action_return.errmsg | ||||
|             worker_response = self.worker_prompt.format( | ||||
| @@ -227,9 +226,17 @@ class ReWOO(BaseAgent): | ||||
|  | ||||
|         self.max_turn = max_turn | ||||
|  | ||||
|     def chat(self, message: str) -> AgentReturn: | ||||
|         self._inner_history = [] | ||||
|         self._inner_history.append(dict(role='user', content=message)) | ||||
|     def chat(self, message: Union[str, dict, List[dict]], | ||||
|              **kwargs) -> AgentReturn: | ||||
|         if isinstance(message, str): | ||||
|             inner_history = [dict(role='user', content=message)] | ||||
|         elif isinstance(message, dict): | ||||
|             inner_history = [message] | ||||
|         elif isinstance(message, list): | ||||
|             inner_history = message[:] | ||||
|         else: | ||||
|             raise TypeError(f'unsupported type: {type(message)}') | ||||
|         offset = len(inner_history) | ||||
|         agent_return = AgentReturn() | ||||
|  | ||||
|         # planner | ||||
| @@ -237,13 +244,12 @@ class ReWOO(BaseAgent): | ||||
|         reformat_request = '' | ||||
|         while turn_id < self.max_turn: | ||||
|             planner_prompt = self._protocol.format_planner( | ||||
|                 chat_history=self.session_history, | ||||
|                 inner_step=self._inner_history, | ||||
|                 chat_history=[], | ||||
|                 inner_step=inner_history, | ||||
|                 action_executor=self._action_executor, | ||||
|                 reformat_request=reformat_request) | ||||
|             response = self._llm.generate_from_template(planner_prompt, 512) | ||||
|             self._inner_history.append( | ||||
|                 dict(role='assistant', content=response)) | ||||
|             response = self._llm.chat(planner_prompt, **kwargs) | ||||
|             inner_history.append(dict(role='assistant', content=response)) | ||||
|             try: | ||||
|                 thoughts, actions, actions_input = self._protocol.parse_worker( | ||||
|                     response) | ||||
| @@ -267,18 +273,17 @@ class ReWOO(BaseAgent): | ||||
|             for prev_ptr in prev_ptrs: | ||||
|                 ptr_num = int(prev_ptr.strip('#E')) - 1  # start from 0 | ||||
|                 actions_input[action_id] = actions_input[action_id].replace( | ||||
|                     prev_ptr, action_responses[ptr_num].result['text']) | ||||
|                     prev_ptr, action_responses[ptr_num].format_result()) | ||||
|             action_return: ActionReturn = self._action_executor( | ||||
|                 actions[action_id], actions_input[action_id]) | ||||
|             action_responses.append(action_return) | ||||
|  | ||||
|         solver_prompt, worker_log = self._protocol.format_solver( | ||||
|             message, thoughts, action_responses) | ||||
|         self._inner_history.append(dict(role='system', content=worker_log)) | ||||
|         inner_history.append(dict(role='system', content=worker_log)) | ||||
|  | ||||
|         final_response = self._llm.generate_from_template(solver_prompt, 512) | ||||
|         self._inner_history.append( | ||||
|             dict(role='assistant', content=final_response)) | ||||
|         agent_return.inner_steps = copy.deepcopy(self._inner_history) | ||||
|         final_response = self._llm.chat(solver_prompt, **kwargs) | ||||
|         inner_history.append(dict(role='assistant', content=final_response)) | ||||
|         agent_return.inner_steps = inner_history[offset:] | ||||
|         agent_return.response = final_response | ||||
|         return agent_return | ||||
|   | ||||
| @@ -8,7 +8,3 @@ __all__ = ['BaseModel', 'BaseAPIModel', 'GPTAPI'] | ||||
| if is_module_exist('transformers'): | ||||
|     from .huggingface import HFTransformer, HFTransformerCasualLM  # noqa: F401 | ||||
|     __all__.extend(['HFTransformer', 'HFTransformerCasualLM']) | ||||
|  | ||||
| if is_module_exist('lmdeploy'): | ||||
|     from .lmdeploy import TritonClient, TurboMind  # noqa: F401 | ||||
|     __all__.extend(['TritonClient', 'TurboMind']) | ||||
|   | ||||
| @@ -27,7 +27,7 @@ class APITemplateParser: | ||||
|                     'role in meta prompt must be unique!' | ||||
|                 self.roles[item['role']] = item.copy() | ||||
|  | ||||
|     def parse_template(self, dialog: List[Union[str, List]]): | ||||
|     def __call__(self, dialog: List[Union[str, List]]): | ||||
|         """Parse the intermidate prompt template, and wrap it with meta | ||||
|         template if applicable. When the meta template is set and the input is | ||||
|         a list, the return value will be a list containing the full | ||||
| @@ -155,7 +155,14 @@ class BaseAPIModel(BaseModel): | ||||
|                  retry: int = 2, | ||||
|                  max_seq_len: int = 2048, | ||||
|                  template_parser: 'APITemplateParser' = APITemplateParser, | ||||
|                  meta_template: Optional[Dict] = None): | ||||
|                  meta_template: Optional[Dict] = None, | ||||
|                  *, | ||||
|                  max_out_len: int = 512, | ||||
|                  top_p: float = 0.8, | ||||
|                  top_k: Optional[float] = None, | ||||
|                  temperature: float = 0.8, | ||||
|                  repetition_penalty: float = 0.0, | ||||
|                  stop_words: Optional[Union[List[str], str]] = None): | ||||
|         self.model_type = model_type | ||||
|         self.max_seq_len = max_seq_len | ||||
|         self.meta_template = meta_template | ||||
| @@ -165,53 +172,21 @@ class BaseAPIModel(BaseModel): | ||||
|         if template_parser: | ||||
|             self.template_parser = template_parser(meta_template) | ||||
|  | ||||
|     @abstractclassmethod | ||||
|     def generate(self, inputs, max_out_len: int) -> List[str]: | ||||
|         """Generate results given a list of inputs. | ||||
|         self.gen_params = dict( | ||||
|             max_out_len=max_out_len, | ||||
|             top_p=top_p, | ||||
|             top_k=top_k, | ||||
|             temperature=temperature, | ||||
|             repetition_penalty=repetition_penalty, | ||||
|             stop_words=stop_words) | ||||
|  | ||||
|         Args: | ||||
|             inputs (List[str or list]): A list of strings or PromptDicts. | ||||
|                 The PromptDict should be organized in OpenCompass' | ||||
|                 API format. | ||||
|             max_out_len (int): The maximum length of the output. | ||||
|  | ||||
|         Returns: | ||||
|             List[str]: A list of generated strings. | ||||
|         """ | ||||
|  | ||||
|     def get_token_len(self, prompt: str) -> int: | ||||
|         """Get lengths of the tokenized string. Only English and Chinese | ||||
|         characters are counted for now. Users are encouraged to override this | ||||
|         method if more accurate length is needed. | ||||
|  | ||||
|         Args: | ||||
|             prompt (str): Input string. | ||||
|  | ||||
|         Returns: | ||||
|             int: Length of the input tokens | ||||
|         """ | ||||
|  | ||||
|         english_parts = re.findall(r'[A-Za-z0-9]+', prompt) | ||||
|         chinese_parts = re.findall(r'[\u4e00-\u9FFF]+', prompt) | ||||
|  | ||||
|         # Count English words | ||||
|         english_count = sum(len(part.split()) for part in english_parts) | ||||
|  | ||||
|         # Count Chinese words | ||||
|         chinese_count = sum(len(part) for part in chinese_parts) | ||||
|  | ||||
|         return english_count + chinese_count | ||||
|  | ||||
|     def wait(self): | ||||
|     def _wait(self): | ||||
|         """Wait till the next query can be sent. | ||||
|  | ||||
|         Applicable in both single-thread and multi-thread environments. | ||||
|         """ | ||||
|         return self.token_bucket.get_token() | ||||
|  | ||||
|     def to(self, device): | ||||
|         pass | ||||
|  | ||||
|  | ||||
| class TokenBucket: | ||||
|     """A token bucket for rate limiting. | ||||
|   | ||||
| @@ -1,5 +1,7 @@ | ||||
| from abc import abstractclassmethod | ||||
| from copy import copy | ||||
| from typing import Dict, List, Optional, Tuple, Union | ||||
| from warnings import warn | ||||
|  | ||||
|  | ||||
| class LMTemplateParser: | ||||
| @@ -21,7 +23,7 @@ class LMTemplateParser: | ||||
|                     'role in meta prompt must be unique!' | ||||
|                 self.roles[item['role']] = item.copy() | ||||
|  | ||||
|     def parse_template(self, dialog) -> str: | ||||
|     def __call__(self, dialog) -> str: | ||||
|         """Parse a prompt template, and wrap it with meta template if | ||||
|         applicable. | ||||
|  | ||||
| @@ -111,12 +113,17 @@ class BaseModel: | ||||
|  | ||||
|     def __init__(self, | ||||
|                  path: str, | ||||
|                  max_seq_len: int = 2048, | ||||
|                  tokenizer_only: bool = False, | ||||
|                  template_parser: 'LMTemplateParser' = LMTemplateParser, | ||||
|                  meta_template: Optional[List[Dict]] = None): | ||||
|                  meta_template: Optional[List[Dict]] = None, | ||||
|                  *, | ||||
|                  max_tokens: int = 512, | ||||
|                  top_p: float = 0.8, | ||||
|                  top_k: Optional[float] = None, | ||||
|                  temperature: float = 0.8, | ||||
|                  repetition_penalty: float = 1.0, | ||||
|                  stop_words: Optional[Union[List[str], str]] = None): | ||||
|         self.path = path | ||||
|         self.max_seq_len = max_seq_len | ||||
|         self.tokenizer_only = tokenizer_only | ||||
|         # meta template | ||||
|         self.template_parser = template_parser(meta_template) | ||||
| @@ -124,41 +131,99 @@ class BaseModel: | ||||
|         if meta_template and 'eos_token_id' in meta_template: | ||||
|             self.eos_token_id = meta_template['eos_token_id'] | ||||
|  | ||||
|     @abstractclassmethod | ||||
|     def generate(self, inputs: List[str], max_out_len: int) -> List[str]: | ||||
|         """Generate results given a list of inputs. | ||||
|         self.gen_params = dict( | ||||
|             max_tokens=max_tokens, | ||||
|             top_p=top_p, | ||||
|             top_k=top_k, | ||||
|             temperature=temperature, | ||||
|             repetition_penalty=repetition_penalty, | ||||
|             stop_words=stop_words) | ||||
|  | ||||
|     def generate(self, inputs: Union[str, List[str]], **gen_params) -> str: | ||||
|         """Generate results given a str (or list of) inputs. | ||||
|  | ||||
|         Args: | ||||
|             inputs (List[str]): A list of strings. | ||||
|             max_out_len (int): The maximum length of the output. | ||||
|             inputs (Union[str, List[str]]): The input prompt(s). | ||||
|             gen_params (dict): The input params for generation. | ||||
|  | ||||
|         Returns: | ||||
|             List[str]: A list of generated strings. | ||||
|         """ | ||||
|             Union[str, List[str]]: A (list of) generated strings. | ||||
|  | ||||
|     def parse_template(self, dialog) -> str: | ||||
|         """Parse a prompt template, and wrap it with meta template if | ||||
|         applicable. | ||||
|         e.g. | ||||
|             batched = True | ||||
|             if isinstance(inputs, str): | ||||
|                 inputs = [inputs] | ||||
|                 batched = False | ||||
|             response = [''] | ||||
|             if batched: | ||||
|                 return response | ||||
|             return response[0] | ||||
|         """ | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def stream_generate(self, inputs: str, **gen_params) -> List[str]: | ||||
|         """Generate results as streaming given a str inputs. | ||||
|  | ||||
|         Args: | ||||
|             dialog (List[str or PromptList]): A prompt | ||||
|                 template (potentially before being wrapped by meta template). | ||||
|             mode (str): Parsing mode. Choices are 'ppl' and 'gen'. | ||||
|             inputs (str): The input string. | ||||
|             gen_params (dict): The input params for generation. | ||||
|  | ||||
|         Returns: | ||||
|             str: The final string. | ||||
|             str: A generated string. | ||||
|         """ | ||||
|         return self.template_parser.parse_template(dialog) | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def generate_from_template(self, templates, max_out_len: int, **kwargs): | ||||
|     def chat(self, inputs: Union[List[dict], List[List[dict]]], **gen_params): | ||||
|         """Generate completion from a list of templates. | ||||
|  | ||||
|         Args: | ||||
|             templates (List[PromptType]): A list of templates. | ||||
|             max_out_len (int): The maximum length of the output. | ||||
|             inputs (Union[List[dict], List[List[dict]]]): a conversation | ||||
|                 or a batch of conversations in message format. | ||||
|             gen_params (dict): The input params for generation. | ||||
|         Returns: | ||||
|         """ | ||||
|         inputs = self.parse_template(templates) | ||||
|         return self.generate(inputs, max_out_len=max_out_len, **kwargs) | ||||
|         if isinstance(inputs[0], list): | ||||
|             # build the parsed batch in a new list; rebinding `inputs` to an | ||||
|             # empty list before iterating over it would skip every message | ||||
|             _inputs = list() | ||||
|             for msg in inputs: | ||||
|                 _inputs.append(self.template_parser(msg)) | ||||
|             inputs = _inputs | ||||
|         else: | ||||
|             inputs = self.template_parser(inputs) | ||||
|         return self.generate(inputs, **gen_params) | ||||
|  | ||||
|     def to(self, device): | ||||
|         self.model.to(device) | ||||
|     def generate_from_template( | ||||
|             self, | ||||
|             inputs: Union[List[dict], List[List[dict]]], | ||||
|             **gen_params): | ||||
|         warn( | ||||
|             'This function will be deprecated after three months and will be ' | ||||
|             'replaced. Please use `.chat()` instead.', | ||||
|             DeprecationWarning, 2) | ||||
|         return self.chat(inputs, **gen_params) | ||||
|  | ||||
|     def stream_chat(self, inputs: List[dict], **gen_params): | ||||
|         """Generate results as streaming given a list of templates. | ||||
|  | ||||
|         Args: | ||||
|             inputs (List[dict]): a conversation in message format. | ||||
|             gen_params (dict): The input params for generation. | ||||
|         Returns: | ||||
|         """ | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def tokenize(self, prompts: Union[str, List[str], List[dict], | ||||
|                                       List[List[dict]]]): | ||||
|         """Tokenize the input prompts. | ||||
|  | ||||
|         Args: | ||||
|             prompts (str | List[str]): a user prompt, or a batch of prompts | ||||
|  | ||||
|         Returns: | ||||
|             Tuple(numpy.ndarray, numpy.ndarray, numpy.ndarray): prompt's token | ||||
|             ids, ids' length and requested output length | ||||
|         """ | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def update_gen_params(self, **kwargs): | ||||
|         gen_params = copy(self.gen_params) | ||||
|         gen_params.update(kwargs) | ||||
|         return gen_params | ||||
|   | ||||
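The refactor above moves sampling defaults into `self.gen_params`, makes the template parser callable, and deprecates `generate_from_template` in favour of `chat`. A sketch of the resulting call pattern; `SomeModel` is a hypothetical concrete subclass that implements `generate`:

```python
# Hedged sketch of the new BaseModel surface; `SomeModel` is hypothetical
# and stands in for any subclass implementing `generate`.
model = SomeModel(path='some/checkpoint', temperature=0.8, stop_words=['<eoa>'])

# update_gen_params returns a merged copy; the stored defaults stay untouched.
params = model.update_gen_params(temperature=0.1)
assert model.gen_params['temperature'] == 0.8
assert params['temperature'] == 0.1

messages = [dict(role='user', content='Hello!')]
reply = model.chat(messages, temperature=0.1)   # preferred entry point
reply = model.generate_from_template(messages)  # now emits DeprecationWarning
```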
| @@ -1,15 +1,19 @@ | ||||
| from typing import Dict, List, Optional | ||||
|  | ||||
| import torch | ||||
| import copy | ||||
| import warnings | ||||
| import logging | ||||
| from typing import Dict, List, Optional, Union | ||||
| from dataclasses import asdict | ||||
|  | ||||
| from .base_llm import BaseModel | ||||
|  | ||||
| logger = logging.getLogger(__name__) | ||||
|  | ||||
|  | ||||
| class HFTransformer(BaseModel): | ||||
|     """Model wrapper around HuggingFace general models. | ||||
|  | ||||
|     Adapted from OpenCompass (https://github.com/InternLM/opencompass | ||||
|     /blob/main/opencompass/models/huggingface.py) | ||||
|     Adapted from InternLM (https://github.com/InternLM/InternLM/blob/main/ | ||||
|         chat/web_demo.py) | ||||
|  | ||||
|     Args: | ||||
|         path (str): The name or path to HuggingFace's model. | ||||
| @@ -25,52 +29,40 @@ class HFTransformer(BaseModel): | ||||
|         meta_template (Dict, optional): The model's meta prompt | ||||
|             template if needed, in case the requirement of injecting or | ||||
|             wrapping of any meta instructions. | ||||
|         extract_pred_after_decode (bool): Whether to extract the prediction | ||||
|             string from the decoded output string, instead of extract the | ||||
|             prediction tokens before decoding. Defaults to False. | ||||
|         batch_padding (bool): If False, inference with be performed in for-loop | ||||
|             without batch padding. | ||||
|  | ||||
|     Note: | ||||
|         About ``extract_pred_after_decode``: Commonly, we should extract the | ||||
|         the prediction tokens before decoding. But for some tokenizers using | ||||
|         ``sentencepiece``, like LLaMA,  this behavior may change the number of | ||||
|         whitespaces, which is harmful for Python programming tasks. | ||||
|     """ | ||||
|  | ||||
|     def __init__( | ||||
|             self, | ||||
|             path: str, | ||||
|             max_seq_len: int = 2048, | ||||
|             tokenizer_path: Optional[str] = None, | ||||
|             tokenizer_kwargs: dict = dict(), | ||||
|             tokenizer_only: bool = False, | ||||
|             model_kwargs: dict = dict(device_map='auto'), | ||||
|             meta_template: Optional[Dict] = [ | ||||
|                 dict(role='system', begin='<|System|>:', end='\n'), | ||||
|                 dict(role='user', begin='<|User|>:', end='\n'), | ||||
|                 dict( | ||||
|                     role='assistant', | ||||
|                     begin='<|Bot|>:', | ||||
|                     end='<eoa>\n', | ||||
|                     generate=True) | ||||
|             ],  # default meta template for InternLM-7b | ||||
|             extract_pred_after_decode: bool = False, | ||||
|             batch_padding: bool = False): | ||||
|     def __init__(self, | ||||
|                  path: str, | ||||
|                  tokenizer_path: Optional[str] = None, | ||||
|                  tokenizer_kwargs: dict = dict(), | ||||
|                  tokenizer_only: bool = False, | ||||
|                  model_kwargs: dict = dict(device_map='auto'), | ||||
|                  meta_template: Optional[Dict] = None, | ||||
|                  **kwargs): | ||||
|         super().__init__( | ||||
|             path=path, | ||||
|             max_seq_len=max_seq_len, | ||||
|             tokenizer_only=tokenizer_only, | ||||
|             meta_template=meta_template) | ||||
|             meta_template=meta_template, | ||||
|             **kwargs) | ||||
|  | ||||
|         self._load_tokenizer( | ||||
|             path=path, | ||||
|             tokenizer_path=tokenizer_path, | ||||
|             tokenizer_kwargs=tokenizer_kwargs) | ||||
|         self.batch_padding = batch_padding | ||||
|         self.extract_pred_after_decode = extract_pred_after_decode | ||||
|         if not tokenizer_only: | ||||
|             self._load_model(path=path, model_kwargs=model_kwargs) | ||||
|  | ||||
|         from transformers.generation.utils import (LogitsProcessorList, | ||||
|                                                    StoppingCriteriaList) | ||||
|         self.logits_processor = LogitsProcessorList() | ||||
|         self.stopping_criteria = StoppingCriteriaList() | ||||
|         self.prefix_allowed_tokens_fn = None | ||||
|  | ||||
|         stop_words_id = [] | ||||
|         for sw in self.gen_params.get('stop_words', []): | ||||
|             stop_words_id.append(self.tokenizer(sw)['input_ids'][1]) | ||||
|         self.additional_eos_token_id = stop_words_id | ||||
|  | ||||
|     def _load_tokenizer(self, path: str, tokenizer_path: Optional[str], | ||||
|                         tokenizer_kwargs: dict): | ||||
|         from transformers import AutoTokenizer | ||||
| @@ -82,57 +74,182 @@ class HFTransformer(BaseModel): | ||||
|             self.tokenizer.pad_token = self.tokenizer.eos_token | ||||
|  | ||||
|     def _load_model(self, path: str, model_kwargs: dict): | ||||
|         import torch | ||||
|         from transformers import AutoModel | ||||
|         model_kwargs.setdefault('torch_dtype', torch.float16) | ||||
|         self.model = AutoModel.from_pretrained( | ||||
|             path, trust_remote_code=True, **model_kwargs) | ||||
|         self.model.eval() | ||||
|  | ||||
|     def generate(self, inputs: List[str], max_out_len: int, | ||||
|                  **kwargs) -> List[str]: | ||||
|         if isinstance(inputs, str): | ||||
|             inputs = [inputs] | ||||
|         if self.extract_pred_after_decode: | ||||
|             prompt_lens = [len(input_) for input_ in inputs] | ||||
|     def tokenize(self, inputs: str): | ||||
|         assert isinstance(inputs, str) | ||||
|         inputs = self.tokenizer( | ||||
|             inputs, return_tensors='pt', return_length=True) | ||||
|         return inputs['input_ids'].tolist() | ||||
|  | ||||
|         input_ids = self.tokenizer( | ||||
|             inputs, truncation=True, | ||||
|             max_length=self.max_seq_len - max_out_len)['input_ids'] | ||||
|         input_ids = torch.tensor(input_ids, device=self.model.device) | ||||
|         outputs = self.model.generate( | ||||
|             input_ids=input_ids, max_new_tokens=max_out_len, **kwargs) | ||||
|     def generate( | ||||
|         self, | ||||
|         inputs: List[str], | ||||
|         do_sample=True, | ||||
|         **kwargs, | ||||
|     ): | ||||
|         response = '' | ||||
|         for chunk in self.stream_generate(inputs, do_sample, **kwargs): | ||||
|             response = chunk | ||||
|         return response | ||||
|  | ||||
|         if not self.extract_pred_after_decode: | ||||
|             outputs = outputs[:, input_ids.shape[1]:] | ||||
|     def stream_generate( | ||||
|         self, | ||||
|         inputs: List[str], | ||||
|         do_sample=True, | ||||
|         **kwargs, | ||||
|     ): | ||||
|         import torch | ||||
|         from torch import nn | ||||
|         with torch.no_grad(): | ||||
|             batched = True | ||||
|             if isinstance(inputs, str): | ||||
|                 inputs = [inputs] | ||||
|                 batched = False | ||||
|             inputs = self.tokenizer( | ||||
|                 inputs, padding=True, return_tensors='pt', return_length=True) | ||||
|             input_length = inputs['length'] | ||||
|             for k, v in inputs.items(): | ||||
|                 inputs[k] = v.cuda() | ||||
|             input_ids = inputs['input_ids'] | ||||
|             attention_mask = inputs['attention_mask'] | ||||
|             batch_size, input_ids_seq_length = input_ids.shape[ | ||||
|                 0], input_ids.shape[-1]  # noqa: F841  # pylint: disable=W0612 | ||||
|             generation_config = self.model.generation_config | ||||
|             generation_config = copy.deepcopy(generation_config) | ||||
|             new_gen_params = self.update_gen_params(**kwargs) | ||||
|             generation_config.update(**new_gen_params) | ||||
|             generation_config.update(**kwargs) | ||||
|             model_kwargs = generation_config.to_dict() | ||||
|             model_kwargs['attention_mask'] = attention_mask | ||||
|             _, eos_token_id = (  # noqa: F841  # pylint: disable=W0612 | ||||
|                 generation_config.bos_token_id, | ||||
|                 generation_config.eos_token_id, | ||||
|             ) | ||||
|             if isinstance(eos_token_id, int): | ||||
|                 eos_token_id = [eos_token_id] | ||||
|             if self.additional_eos_token_id is not None: | ||||
|                 eos_token_id.extend(self.additional_eos_token_id) | ||||
|             eos_token_id_tensor = torch.tensor(eos_token_id).to( | ||||
|                 input_ids.device) if eos_token_id is not None else None | ||||
|             has_default_max_length = ( | ||||
|                 kwargs.get('max_length') is None | ||||
|                 and generation_config.max_length is not None) | ||||
|             if has_default_max_length and generation_config.max_new_tokens is None: | ||||
|                 warnings.warn( | ||||
|                     f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. " | ||||
|                     'This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we' | ||||
|                     ' recommend using `max_new_tokens` to control the maximum length of the generation.', | ||||
|                     UserWarning, | ||||
|                 ) | ||||
|             elif generation_config.max_new_tokens is not None: | ||||
|                 generation_config.max_length = ( | ||||
|                     generation_config.max_new_tokens + input_ids_seq_length) | ||||
|                 if not has_default_max_length: | ||||
|                     logger.warning( | ||||
|                         f'Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=' | ||||
|                         f'{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. ' | ||||
|                         'Please refer to the documentation for more information. ' | ||||
|                         '(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)') | ||||
|  | ||||
|         decodeds = self.tokenizer.batch_decode( | ||||
|             outputs, skip_special_tokens=True) | ||||
|         if self.extract_pred_after_decode: | ||||
|             decodeds = [ | ||||
|                 token[len_:] for token, len_ in zip(decodeds, prompt_lens) | ||||
|             ] | ||||
|             if input_ids_seq_length >= generation_config.max_length: | ||||
|                 input_ids_string = 'input_ids' | ||||
|                 logger.warning( | ||||
|                     f'Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to' | ||||
|                     f' {generation_config.max_length}. This can lead to unexpected behavior. You should consider' | ||||
|                     ' increasing `max_new_tokens`.') | ||||
|  | ||||
|         return decodeds[0] | ||||
|             # 2. Set up logits processors and stopping criteria if not already defined | ||||
|             logits_processor = self.logits_processor | ||||
|             stopping_criteria = self.stopping_criteria | ||||
|  | ||||
|     def generate_from_template(self, templates, max_out_len: int, **kwargs): | ||||
|         """Generate completion from a list of templates. | ||||
|             logits_processor = self.model._get_logits_processor( | ||||
|                 generation_config=generation_config, | ||||
|                 input_ids_seq_length=input_ids_seq_length, | ||||
|                 encoder_input_ids=input_ids, | ||||
|                 prefix_allowed_tokens_fn=self.prefix_allowed_tokens_fn, | ||||
|                 logits_processor=logits_processor, | ||||
|             ) | ||||
|  | ||||
|         Args: | ||||
|             templates (List[PromptType]): A list of templates. | ||||
|             max_out_len (int): The maximum length of the output. | ||||
|         """ | ||||
|         inputs = self.parse_template(templates) | ||||
|         response = self.generate(inputs, max_out_len=max_out_len, **kwargs) | ||||
|         end_token = self.template_parser.meta_template[0]['end'].strip() | ||||
|         # return response.replace( | ||||
|         #     self.template_parser.roles['assistant']['end'].strip(), | ||||
|         #     '').strip() | ||||
|         return response.split(end_token.strip())[0] | ||||
|             stopping_criteria = self.model._get_stopping_criteria( | ||||
|                 generation_config=generation_config, | ||||
|                 stopping_criteria=stopping_criteria) | ||||
|             logits_warper = self.model._get_logits_warper(generation_config) | ||||
|  | ||||
|             unfinished_sequences = input_ids.new(batch_size).fill_(1) | ||||
|             scores = None | ||||
|             while True: | ||||
|                 model_inputs = self.model.prepare_inputs_for_generation( | ||||
|                     input_ids, **model_kwargs) | ||||
|                 # forward pass to get next token | ||||
|                 outputs = self.model( | ||||
|                     **model_inputs, | ||||
|                     return_dict=True, | ||||
|                     output_attentions=False, | ||||
|                     output_hidden_states=False, | ||||
|                 ) | ||||
|  | ||||
|                 next_token_logits = outputs.logits[:, -1, :] | ||||
|  | ||||
|                 # pre-process distribution | ||||
|                 next_token_scores = logits_processor(input_ids, | ||||
|                                                      next_token_logits) | ||||
|                 next_token_scores = logits_warper(input_ids, next_token_scores) | ||||
|  | ||||
|                 # sample | ||||
|                 probs = nn.functional.softmax(next_token_scores, dim=-1) | ||||
|                 if do_sample: | ||||
|                     next_tokens = torch.multinomial( | ||||
|                         probs, num_samples=1).squeeze(1) | ||||
|                 else: | ||||
|                     next_tokens = torch.argmax(probs, dim=-1) | ||||
|  | ||||
|                 # update generated ids, model inputs, and length for next step | ||||
|                 input_ids = torch.cat([input_ids, next_tokens[:, None]], | ||||
|                                       dim=-1) | ||||
|                 model_kwargs = self.model._update_model_kwargs_for_generation( | ||||
|                     outputs, model_kwargs, is_encoder_decoder=False) | ||||
|                 unfinished_sequences = unfinished_sequences.mul( | ||||
|                     next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne( | ||||
|                         eos_token_id_tensor.unsqueeze(1)).prod(dim=0)) | ||||
|                 output_token_ids = input_ids.cpu().tolist() | ||||
|                 for i in range(len(output_token_ids)): | ||||
|                     output_token_ids[i] = output_token_ids[i][ | ||||
|                         input_length[i]:] | ||||
|                     # Find the first occurrence of an EOS token in the sequence | ||||
|                     first_eos_idx = next( | ||||
|                         (idx | ||||
|                          for idx, token_id in enumerate(output_token_ids[i]) | ||||
|                          if token_id in eos_token_id), None) | ||||
|                     # If an EOS token is found, keep only the tokens before it | ||||
|                     if first_eos_idx is not None: | ||||
|                         output_token_ids[i] = output_token_ids[ | ||||
|                             i][:first_eos_idx] | ||||
|  | ||||
|                 response = self.tokenizer.batch_decode(output_token_ids) | ||||
|                 if not batched: | ||||
|                     yield response[0] | ||||
|                 else: | ||||
|                     yield response | ||||
|                 # stop when each sentence is finished, or if we exceed the maximum length | ||||
|                 if (unfinished_sequences.max() == 0 | ||||
|                         or stopping_criteria(input_ids, scores)): | ||||
|                     break | ||||
|  | ||||
|  | ||||
| class HFTransformerCasualLM(HFTransformer): | ||||
|  | ||||
|     def _load_model(self, path: str, model_kwargs: dict): | ||||
|         import torch | ||||
|         from transformers import AutoModelForCausalLM | ||||
|         model_kwargs.setdefault('torch_dtype', torch.float16) | ||||
|         self.model = AutoModelForCausalLM.from_pretrained( | ||||
|   | ||||
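|  | ||||
| A usage sketch for the rewritten HuggingFace wrapper, assuming a local InternLM-style chat checkpoint; the model path and stop word are illustrative, and `stop_words` is assumed to flow into the default `gen_params`. | ||||
|  | ||||
| ```python | ||||
| from lagent.llms import HFTransformerCasualLM | ||||
|  | ||||
| # Illustrative checkpoint and stop word; adjust to your own model. | ||||
| llm = HFTransformerCasualLM('internlm/internlm-chat-7b', | ||||
|                             stop_words=['<eoa>']) | ||||
| # generate() drains stream_generate() and returns the final chunk. | ||||
| print(llm.generate('Tell me a joke.', do_sample=False)) | ||||
| # stream_generate() yields the decoded text produced so far at each step. | ||||
| for text in llm.stream_generate('Tell me a joke.', do_sample=True): | ||||
|     print(text) | ||||
| ``` | ||||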
| @@ -1,143 +0,0 @@ | ||||
| import dataclasses | ||||
| import os.path as osp | ||||
| import random | ||||
|  | ||||
| import lmdeploy.turbomind.chat as tm_chat | ||||
| from lmdeploy import turbomind as tm | ||||
| from lmdeploy.serve.turbomind.chatbot import Chatbot, Session, get_logger | ||||
| from lmdeploy.tokenizer import Tokenizer | ||||
|  | ||||
| from .base_llm import BaseModel | ||||
|  | ||||
|  | ||||
| class TritonClient(Chatbot, BaseModel): | ||||
|  | ||||
|     def __init__(self, meta_template=None, **kwargs): | ||||
|         """TritonClient is a wrapper of TritonClient for LLM. | ||||
|  | ||||
|         Args: | ||||
|             model_name (str): the name of the model | ||||
|             max_out_len (int): the maximum number of generated tokens | ||||
|             log_level (str): log level | ||||
|         """ | ||||
|         BaseModel.__init__(self, meta_template=meta_template, path=None) | ||||
|         Chatbot.__init__(self, **kwargs) | ||||
|  | ||||
|     def generate(self, | ||||
|                  prompt: str, | ||||
|                  session_id: int = 2967, | ||||
|                  request_id: str = '', | ||||
|                  max_out_len: int = None, | ||||
|                  sequence_start: bool = True, | ||||
|                  sequence_end: bool = True, | ||||
|                  *args, | ||||
|                  **kwargs): | ||||
|         """Start a new round conversation of a session. Return the chat | ||||
|         completions in non-stream mode. | ||||
|  | ||||
|         Args: | ||||
|             session_id (int): the identical id of a session | ||||
|             prompt (str): user's prompt in this round conversation | ||||
|             request_id (str): the identical id of this round conversation | ||||
|             max_out_len (int): the maximum number of generated tokens | ||||
|             sequence_start (bool): start flag of a session | ||||
|             sequence_end (bool): end flag of a session | ||||
|  | ||||
|         Returns: | ||||
|             tuple(Status, str, int): status, text/chat completion, | ||||
|             generated token number | ||||
|         """ | ||||
|         assert isinstance(session_id, int), \ | ||||
|             f'INT session id is required, but got {type(session_id)}' | ||||
|  | ||||
|         logger = get_logger(log_level=self.log_level) | ||||
|         logger.info(f'session {session_id}, request_id {request_id}, ' | ||||
|                     f'max_out_len {max_out_len}') | ||||
|  | ||||
|         if self._session is None: | ||||
|             sequence_start = True | ||||
|             self._session = Session(session_id=session_id) | ||||
|         elif self._session.status == 0: | ||||
|             logger.error(f'session {session_id} has been ended. Please set ' | ||||
|                          f'`sequence_start` to True if you want to restart it') | ||||
|             return '' | ||||
|  | ||||
|         self._session.status = 1 | ||||
|         self._session.request_id = request_id | ||||
|         self._session.response = '' | ||||
|  | ||||
|         status, res, _ = None, '', 0 | ||||
|         for status, res, _ in self._stream_infer(self._session, prompt, | ||||
|                                                  max_out_len, sequence_start, | ||||
|                                                  sequence_end): | ||||
|             if status.value < 0: | ||||
|                 break | ||||
|         if status.value == 0: | ||||
|             self._session.histories = \ | ||||
|                 self._session.histories + self._session.prompt + \ | ||||
|                 self._session.response | ||||
|             return res | ||||
|         else: | ||||
|             return '' | ||||
|  | ||||
|     def generate_from_template(self, templates, max_out_len: int, **kwargs): | ||||
|         """Generate completion from a list of templates. | ||||
|  | ||||
|         Args: | ||||
|             templates (List[PromptType]): A list of templates. | ||||
|             max_out_len (int): The maximum length of the output. | ||||
|         """ | ||||
|         inputs = self.parse_template(templates) | ||||
|         response = self.generate(inputs, max_out_len=max_out_len, **kwargs) | ||||
|         # The turbomind output contains <eoa>; it is removed here by hard-coded string replacement. | ||||
|         response = response.replace( | ||||
|             self.template_parser.roles['assistant']['end'].strip(), | ||||
|             '').strip() | ||||
|         return response | ||||
|  | ||||
|  | ||||
| class TurboMind(BaseModel): | ||||
|  | ||||
|     def __init__(self, | ||||
|                  path: str, | ||||
|                  max_seq_len: int = 8192, | ||||
|                  tokenizer_only: bool = False, | ||||
|                  meta_template=None, | ||||
|                  tp=1, | ||||
|                  **kwargs): | ||||
|         super().__init__( | ||||
|             path=path, | ||||
|             max_seq_len=max_seq_len, | ||||
|             tokenizer_only=tokenizer_only, | ||||
|             meta_template=meta_template) | ||||
|         tokenizer_model_path = osp.join(path, 'triton_models', 'tokenizer') | ||||
|         self.tokenizer = Tokenizer(tokenizer_model_path) | ||||
|         self.tm_model = tm.TurboMind( | ||||
|             path, eos_id=self.tokenizer.eos_token_id, tp=tp) | ||||
|         self.generator = self.tm_model.create_instance() | ||||
|  | ||||
|         model_name = self.tm_model.model_name | ||||
|         self.model = tm_chat.MODELS.get(model_name)( | ||||
|             capability='completion', **kwargs) | ||||
|         self._session_id = 0 | ||||
|  | ||||
|     def generate(self, prompt, **kwargs): | ||||
|         seed = random.getrandbits(64) | ||||
|         input_ids = self.tokenizer.encode(prompt) | ||||
|         gen_param = tm_chat.get_gen_param( | ||||
|             'completion', self.model.sampling_param, step=0, nth_round=1) | ||||
|         response_size = 0 | ||||
|         self._session_id = (self._session_id + 1) % 100000 | ||||
|         for outputs in self.generator.stream_infer( | ||||
|                 session_id=self._session_id, | ||||
|                 input_ids=[input_ids], | ||||
|                 stream_output=False, | ||||
|                 **dataclasses.asdict(gen_param), | ||||
|                 ignore_eos=False, | ||||
|                 random_seed=seed): | ||||
|             res, tokens = outputs[0] | ||||
|             # decode res | ||||
|             response = self.tokenizer.decode( | ||||
|                 res.tolist(), offset=response_size) | ||||
|             response = tm_chat.valid_str(response) | ||||
|         return response | ||||
lagent/llms/lmdepoly_wrapper.py (new file, 383 lines)
							| @@ -0,0 +1,383 @@ | ||||
| from typing import List, Optional, Union | ||||
|  | ||||
| from lagent.llms.base_llm import BaseModel | ||||
| from lagent.schema import AgentStatusCode | ||||
| from lagent.utils.util import filter_suffix | ||||
|  | ||||
|  | ||||
| class TritonClient(BaseModel): | ||||
|     """TritonClient is a wrapper of TritonClient for LLM. | ||||
|  | ||||
|     Args: | ||||
|         tritonserver_addr (str): the address in format "ip:port" of | ||||
|             triton inference server | ||||
|         model_name (str): the name of the model | ||||
|         session_len (int): the context size | ||||
|         max_tokens (int): the maximum number of generated tokens | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, | ||||
|                  tritonserver_addr: str, | ||||
|                  model_name: str, | ||||
|                  session_len: int = 32768, | ||||
|                  log_level: str = 'WARNING', | ||||
|                  **kwargs): | ||||
|         super().__init__(path=None, **kwargs) | ||||
|         from lmdeploy.serve.turbomind.chatbot import Chatbot, StatusCode | ||||
|         self.state_map = { | ||||
|             StatusCode.TRITON_STREAM_END: AgentStatusCode.END, | ||||
|             StatusCode.TRITON_SERVER_ERR: AgentStatusCode.SERVER_ERR, | ||||
|             StatusCode.TRITON_SESSION_CLOSED: AgentStatusCode.SESSION_CLOSED, | ||||
|             StatusCode.TRITON_STREAM_ING: AgentStatusCode.STREAM_ING, | ||||
|             StatusCode.TRITON_SESSION_OUT_OF_LIMIT: | ||||
|             AgentStatusCode.SESSION_OUT_OF_LIMIT, | ||||
|             StatusCode.TRITON_SESSION_INVALID_ARG: | ||||
|             AgentStatusCode.SESSION_INVALID_ARG, | ||||
|             StatusCode.TRITON_SESSION_READY: AgentStatusCode.SESSION_READY | ||||
|         } | ||||
|         self.chatbot = Chatbot( | ||||
|             tritonserver_addr=tritonserver_addr, | ||||
|             model_name=model_name, | ||||
|             session_len=session_len, | ||||
|             log_level=log_level, | ||||
|             **kwargs) | ||||
|  | ||||
|     def generate(self, | ||||
|                  inputs: Union[str, List[str]], | ||||
|                  session_id: int = 2967, | ||||
|                  request_id: str = '', | ||||
|                  max_tokens: int = 512, | ||||
|                  sequence_start: bool = True, | ||||
|                  sequence_end: bool = True, | ||||
|                  **kwargs): | ||||
|         """Start a new round conversation of a session. Return the chat | ||||
|         completions in non-stream mode. | ||||
|  | ||||
|         Args: | ||||
|             inputs (str, List[str]): user's prompt(s) in this round | ||||
|             session_id (int): the identical id of a session | ||||
|             request_id (str): the identical id of this round conversation | ||||
|             max_tokens (int): the maximum number of generated tokens | ||||
|             sequence_start (bool): start flag of a session | ||||
|             sequence_end (bool): end flag of a session | ||||
|  | ||||
|         Returns: | ||||
|             (a list of/batched) text/chat completion | ||||
|         """ | ||||
|         from lmdeploy.serve.turbomind.chatbot import Session, get_logger | ||||
|         if isinstance(inputs, str): | ||||
|             inputs = [inputs] | ||||
|         prompt = inputs | ||||
|  | ||||
|         assert isinstance(session_id, int), \ | ||||
|             f'INT session id is required, but got {type(session_id)}' | ||||
|  | ||||
|         logger = get_logger(log_level=self.chatbot.log_level) | ||||
|         logger.info(f'session {session_id}, request_id {request_id}, ' | ||||
|                     f'max_out_len {max_tokens}') | ||||
|  | ||||
|         if self.chatbot._session is None: | ||||
|             sequence_start = True | ||||
|             self.chatbot._session = Session(session_id=session_id) | ||||
|         elif self.chatbot._session.status == 0: | ||||
|             logger.error(f'session {session_id} has been ended. Please set ' | ||||
|                          f'`sequence_start` to True if you want to restart it') | ||||
|             return '' | ||||
|  | ||||
|         self.chatbot._session.status = 1 | ||||
|         self.chatbot._session.request_id = request_id | ||||
|         self.chatbot._session.response = '' | ||||
|  | ||||
|         self.chatbot.cfg = self._update_gen_params( | ||||
|             max_tokens=max_tokens, **kwargs) | ||||
|  | ||||
|         status, res, _ = None, '', 0 | ||||
|         for status, res, _ in self.chatbot._stream_infer( | ||||
|                 self.chatbot._session, prompt, max_tokens, sequence_start, | ||||
|                 sequence_end): | ||||
|             if status.value < 0: | ||||
|                 break | ||||
|         if status.value == 0: | ||||
|             self.chatbot._session.histories = ( | ||||
|                 self.chatbot._session.histories + | ||||
|                 self.chatbot._session.prompt + self.chatbot._session.response) | ||||
|             # remove stop_words | ||||
|             res = filter_suffix(res, self.gen_params.get('stop_words')) | ||||
|             return res | ||||
|         else: | ||||
|             return '' | ||||
|  | ||||
|     def stream_chat(self, | ||||
|                     inputs: List[dict], | ||||
|                     session_id: int = 2967, | ||||
|                     request_id: str = '', | ||||
|                     max_tokens: int = 512, | ||||
|                     sequence_start: bool = True, | ||||
|                     sequence_end: bool = True, | ||||
|                     **kwargs): | ||||
|         """Start a new round conversation of a session. Return the chat | ||||
|         completions in non-stream mode. | ||||
|  | ||||
|         Args: | ||||
|             session_id (int): the identical id of a session | ||||
|             inputs (List[dict]): user's inputs in this round conversation | ||||
|             request_id (str): the identical id of this round conversation | ||||
|             max_tokens (int): the maximum number of generated tokens | ||||
|             sequence_start (bool): start flag of a session | ||||
|             sequence_end (bool): end flag of a session | ||||
|  | ||||
|         Returns: | ||||
|             tuple(Status, str, int): status, text/chat completion, | ||||
|             generated token number | ||||
|         """ | ||||
|         from lmdeploy.serve.turbomind.chatbot import (Session, StatusCode, | ||||
|                                                       get_logger) | ||||
|         assert isinstance(session_id, int), \ | ||||
|             f'INT session id is required, but got {type(session_id)}' | ||||
|  | ||||
|         logger = get_logger(log_level=self.chatbot.log_level) | ||||
|         logger.info(f'session {session_id}, request_id {request_id}, ' | ||||
|                     f'max_out_len {max_tokens}') | ||||
|  | ||||
|         if self.chatbot._session is None: | ||||
|             sequence_start = True | ||||
|             self.chatbot._session = Session(session_id=session_id) | ||||
|         elif self.chatbot._session.status == 0: | ||||
|             logger.error(f'session {session_id} has been ended. Please set ' | ||||
|                          f'`sequence_start` to True if you want to restart it') | ||||
|             return '' | ||||
|  | ||||
|         self.chatbot._session.status = 1 | ||||
|         self.chatbot._session.request_id = request_id | ||||
|         self.chatbot._session.response = '' | ||||
|  | ||||
|         self.chatbot.cfg = self._update_gen_params( | ||||
|             max_tokens=max_tokens, **kwargs) | ||||
|         prompt = self.template_parser(inputs) | ||||
|  | ||||
|         status, res, _ = None, '', 0 | ||||
|         for status, res, _ in self.chatbot._stream_infer( | ||||
|                 self.chatbot._session, prompt, max_tokens, sequence_start, | ||||
|                 sequence_end): | ||||
|             if status == StatusCode.TRITON_STREAM_END:  # remove stop_words | ||||
|                 res = filter_suffix(res, self.gen_params.get('stop_words')) | ||||
|             if status.value < 0: | ||||
|                 break | ||||
|             else: | ||||
|                 yield self.state_map.get(status), res, _ | ||||
|         if status.value == 0: | ||||
|             self.chatbot._session.histories = ( | ||||
|                 self.chatbot._session.histories + | ||||
|                 self.chatbot._session.prompt + self.chatbot._session.response) | ||||
|             yield self.state_map.get(status), res, _ | ||||
|         else: | ||||
|             return '' | ||||
|  | ||||
|     def _update_gen_params(self, **kwargs): | ||||
|         import mmengine | ||||
|         new_gen_params = self.update_gen_params(**kwargs) | ||||
|         self.gen_params['stop_words'] = new_gen_params.pop('stop_words') | ||||
|         stop_words = self.chatbot._stop_words( | ||||
|             self.gen_params.get('stop_words')) | ||||
|         cfg = mmengine.Config( | ||||
|             dict( | ||||
|                 session_len=self.chatbot.model.session_len, | ||||
|                 stop_words=stop_words, | ||||
|                 bad_words=self.chatbot.cfg.bad_words, | ||||
|                 **new_gen_params)) | ||||
|         return cfg | ||||
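|  | ||||
| A usage sketch for the new TritonClient, assuming a turbomind model is already being served by a Triton inference server; the address and model name are placeholders. | ||||
|  | ||||
| ```python | ||||
| from lagent.llms.lmdepoly_wrapper import TritonClient | ||||
|  | ||||
| # Placeholder address/model; requires a running Triton inference server. | ||||
| llm = TritonClient(tritonserver_addr='0.0.0.0:33337', | ||||
|                    model_name='internlm-chat-7b') | ||||
| # Non-streaming completion for a raw prompt. | ||||
| print(llm.generate('Hi, please introduce yourself.', session_id=1)) | ||||
| # Streaming chat over a message list; yields (status, text, extra) tuples. | ||||
| for state, text, _ in llm.stream_chat([dict(role='user', content='Hi')]): | ||||
|     print(state, text) | ||||
| ``` | ||||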
|  | ||||
|  | ||||
| class LMDeployPipeline(BaseModel): | ||||
|     """ | ||||
|  | ||||
|     Args: | ||||
|         path (str): The path to the model. | ||||
|             It could be one of the following options: | ||||
|                 - i) A local directory path of a turbomind model which is | ||||
|                     converted by the `lmdeploy convert` command or downloaded | ||||
|                     from ii) and iii). | ||||
|                 - ii) The model_id of an lmdeploy-quantized model hosted | ||||
|                     inside a model repo on huggingface.co, such as | ||||
|                     "InternLM/internlm-chat-20b-4bit", | ||||
|                     "lmdeploy/llama2-chat-70b-4bit", etc. | ||||
|                 - iii) The model_id of a model hosted inside a model repo | ||||
|                     on huggingface.co, such as "internlm/internlm-chat-7b", | ||||
|                     "Qwen/Qwen-7B-Chat", "baichuan-inc/Baichuan2-7B-Chat" | ||||
|                     and so on. | ||||
|         model_name (str): needed when model_path is a pytorch model on | ||||
|             huggingface.co, such as "internlm/internlm-chat-7b", | ||||
|             "Qwen/Qwen-7B-Chat", "baichuan-inc/Baichuan2-7B-Chat" and so on. | ||||
|         tp (int): tensor parallelism degree. Defaults to 1. | ||||
|         pipeline_cfg (dict): extra keyword arguments passed through to | ||||
|             `lmdeploy.pipeline`. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, | ||||
|                  path: str, | ||||
|                  model_name: Optional[str] = None, | ||||
|                  tp: int = 1, | ||||
|                  pipeline_cfg=dict(), | ||||
|                  **kwargs): | ||||
|  | ||||
|         super().__init__(path=path, **kwargs) | ||||
|         from lmdeploy import pipeline | ||||
|         self.model = pipeline( | ||||
|             model_path=self.path, model_name=model_name, tp=tp, **pipeline_cfg) | ||||
|  | ||||
|     def generate(self, | ||||
|                  inputs: Union[str, List[str]], | ||||
|                  do_preprocess=None, | ||||
|                  **kwargs): | ||||
|         batched = True | ||||
|         if isinstance(inputs, str): | ||||
|             inputs = [inputs] | ||||
|             batched = False | ||||
|         prompt = inputs | ||||
|         gen_params = self.update_gen_params(**kwargs) | ||||
|         response = self.model.batch_infer( | ||||
|             prompt, do_preprocess=do_preprocess, **gen_params) | ||||
|         response = [resp.text for resp in response] | ||||
|         # remove stop_words | ||||
|         response = filter_suffix(response, self.gen_params.get('stop_words')) | ||||
|         if batched: | ||||
|             return response | ||||
|         return response[0] | ||||
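|  | ||||
| For the pipeline wrapper, batching is inferred from the input type, as the sketch below shows; the model id and stop word are illustrative. | ||||
|  | ||||
| ```python | ||||
| from lagent.llms.lmdepoly_wrapper import LMDeployPipeline | ||||
|  | ||||
| llm = LMDeployPipeline(path='internlm/internlm-chat-7b', | ||||
|                        stop_words=['<eoa>']) | ||||
| single = llm.generate('What is the capital of France?')  # -> str | ||||
| batch = llm.generate(['Hello', 'Bonjour'])               # -> List[str] | ||||
| ``` | ||||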
|  | ||||
|  | ||||
| class LMDeployServer(BaseModel): | ||||
|     """ | ||||
|  | ||||
|     Args: | ||||
|         path (str): The path to the model. | ||||
|             It could be one of the following options: | ||||
|                 - i) A local directory path of a turbomind model which is | ||||
|                     converted by the `lmdeploy convert` command or downloaded | ||||
|                     from ii) and iii). | ||||
|                 - ii) The model_id of an lmdeploy-quantized model hosted | ||||
|                     inside a model repo on huggingface.co, such as | ||||
|                     "InternLM/internlm-chat-20b-4bit", | ||||
|                     "lmdeploy/llama2-chat-70b-4bit", etc. | ||||
|                 - iii) The model_id of a model hosted inside a model repo | ||||
|                     on huggingface.co, such as "internlm/internlm-chat-7b", | ||||
|                     "Qwen/Qwen-7B-Chat", "baichuan-inc/Baichuan2-7B-Chat" | ||||
|                     and so on. | ||||
|         model_name (str): needed when model_path is a pytorch model on | ||||
|             huggingface.co, such as "internlm/internlm-chat-7b", | ||||
|             "Qwen/Qwen-7B-Chat", "baichuan-inc/Baichuan2-7B-Chat" and so on. | ||||
|         server_name (str): host ip for serving | ||||
|         server_port (int): server port | ||||
|         tp (int): tensor parallelism degree. Defaults to 1. | ||||
|         log_level (str): set log level whose value among | ||||
|             [CRITICAL, ERROR, WARNING, INFO, DEBUG] | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, | ||||
|                  path: str, | ||||
|                  model_name: Optional[str] = None, | ||||
|                  server_name: str = '0.0.0.0', | ||||
|                  server_port: int = 23333, | ||||
|                  tp: int = 1, | ||||
|                  log_level: str = 'WARNING', | ||||
|                  serve_cfg=dict(), | ||||
|                  **kwargs): | ||||
|         super().__init__(path=path, **kwargs) | ||||
|         # TODO get_logger issue in multi processing | ||||
|         import lmdeploy | ||||
|         self.client = lmdeploy.serve( | ||||
|             model_path=self.path, | ||||
|             model_name=model_name, | ||||
|             server_name=server_name, | ||||
|             server_port=server_port, | ||||
|             tp=tp, | ||||
|             log_level=log_level, | ||||
|             **serve_cfg) | ||||
|  | ||||
|     def generate(self, | ||||
|                  inputs: Union[str, List[str]], | ||||
|                  session_id: int = 2967, | ||||
|                  sequence_start: bool = True, | ||||
|                  sequence_end: bool = True, | ||||
|                  ignore_eos: bool = False, | ||||
|                  timeout: int = 30, | ||||
|                  **kwargs) -> List[str]: | ||||
|         batched = True | ||||
|         if isinstance(inputs, str): | ||||
|             inputs = [inputs] | ||||
|             batched = False | ||||
|  | ||||
|         gen_params = self.update_gen_params(**kwargs) | ||||
|  | ||||
|         resp = [''] * len(inputs) | ||||
|         for text in self.client.completions_v1( | ||||
|                 self.path, | ||||
|                 inputs, | ||||
|                 session_id=session_id, | ||||
|                 sequence_start=sequence_start, | ||||
|                 sequence_end=sequence_end, | ||||
|                 stream=False, | ||||
|                 ignore_eos=ignore_eos, | ||||
|                 timeout=timeout, | ||||
|                 **gen_params): | ||||
|             resp = [ | ||||
|                 resp[i] + item['text'] | ||||
|                 for i, item in enumerate(text['choices']) | ||||
|             ] | ||||
|         # remove stop_words | ||||
|         resp = filter_suffix(resp, self.gen_params.get('stop_words')) | ||||
|         if not batched: | ||||
|             return resp[0] | ||||
|         return resp | ||||
|  | ||||
|     def stream_chat(self, | ||||
|                     inputs: List[dict], | ||||
|                     session_id=0, | ||||
|                     sequence_start: bool = True, | ||||
|                     sequence_end: bool = True, | ||||
|                     stream: bool = True, | ||||
|                     ignore_eos: bool = False, | ||||
|                     timeout: int = 30, | ||||
|                     **kwargs): | ||||
|  | ||||
|         gen_params = self.update_gen_params(**kwargs) | ||||
|         prompt = self.template_parser(inputs) | ||||
|  | ||||
|         resp = '' | ||||
|         finished = False | ||||
|         stop_words = self.gen_params.get('stop_words') or [] | ||||
|         for text in self.client.completions_v1( | ||||
|                 self.path, | ||||
|                 prompt, | ||||
|                 session_id=session_id, | ||||
|                 sequence_start=sequence_start, | ||||
|                 sequence_end=sequence_end, | ||||
|                 stream=stream, | ||||
|                 ignore_eos=ignore_eos, | ||||
|                 timeout=timeout, | ||||
|                 **gen_params): | ||||
|             resp += text['choices'][0]['text'] | ||||
|             if not resp: | ||||
|                 continue | ||||
|             # remove stop_words | ||||
|             for sw in stop_words: | ||||
|                 if sw in resp: | ||||
|                     resp = filter_suffix(resp, stop_words) | ||||
|                     finished = True | ||||
|                     break | ||||
|             yield AgentStatusCode.STREAM_ING, resp, None | ||||
|             if finished: | ||||
|                 break | ||||
|         yield AgentStatusCode.END, resp, None | ||||
|  | ||||
|  | ||||
| class LMDeployClient(LMDeployServer): | ||||
|     """ | ||||
|  | ||||
|     Args: | ||||
|         path (str): The path to the model. | ||||
|         url (str): the URL of a launched LMDeploy API server, | ||||
|             e.g. "http://0.0.0.0:23333". | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, path: str, url: str, **kwargs): | ||||
|         BaseModel.__init__(self, path=path, **kwargs) | ||||
|         from lmdeploy.serve.openai.api_client import APIClient | ||||
|         self.client = APIClient(url) | ||||
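|  | ||||
| A sketch of the client/server split, assuming an API server launched as above is reachable; the model path, stop word, and URL (matching the LMDeployServer defaults) are illustrative. | ||||
|  | ||||
| ```python | ||||
| from lagent.llms.lmdepoly_wrapper import LMDeployClient | ||||
|  | ||||
| client = LMDeployClient(path='internlm/internlm-chat-7b', | ||||
|                         url='http://0.0.0.0:23333', | ||||
|                         stop_words=['<eoa>']) | ||||
| print(client.generate('Hello!')) | ||||
| # stream_chat yields (AgentStatusCode, accumulated_text, None) tuples. | ||||
| for state, text, _ in client.stream_chat( | ||||
|         [dict(role='user', content='Hello!')]): | ||||
|     print(state, text) | ||||
| ``` | ||||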
| @@ -1,7 +1,7 @@ | ||||
| import json | ||||
| import os | ||||
| import time | ||||
| # from concurrent.futures import ThreadPoolExecutor | ||||
| from concurrent.futures import ThreadPoolExecutor, wait | ||||
| from logging import getLogger | ||||
| from threading import Lock | ||||
| from typing import Dict, List, Optional, Union | ||||
| @@ -38,9 +38,8 @@ class GPTAPI(BaseAPIModel): | ||||
|             wrapping of any meta instructions. | ||||
|         openai_api_base (str): The base url of OpenAI's API. Defaults to | ||||
|             'https://api.openai.com/v1/chat/completions'. | ||||
|         temperature (float, optional): What sampling temperature to use. | ||||
|             If not None, will override the temperature in the `generate()` | ||||
|             call. Defaults to None. | ||||
|         gen_params: Default generation configuration which can be overridden | ||||
|             on the fly during generation. | ||||
|     """ | ||||
|  | ||||
|     is_api: bool = True | ||||
| @@ -58,16 +57,15 @@ class GPTAPI(BaseAPIModel): | ||||
|                      dict(role='assistant', api_role='assistant') | ||||
|                  ], | ||||
|                  openai_api_base: str = OPENAI_API_BASE, | ||||
|                  temperature: Optional[float] = None): | ||||
|  | ||||
|                  **gen_params): | ||||
|         super().__init__( | ||||
|             model_type=model_type, | ||||
|             max_seq_len=max_seq_len, | ||||
|             meta_template=meta_template, | ||||
|             query_per_second=query_per_second, | ||||
|             retry=retry) | ||||
|             retry=retry, | ||||
|             **gen_params) | ||||
|         self.logger = getLogger(__name__) | ||||
|         self.temperature = temperature | ||||
|  | ||||
|         if isinstance(key, str): | ||||
|             self.keys = [os.getenv('OPENAI_API_KEY') if key == 'ENV' else key] | ||||
| @@ -97,67 +95,57 @@ class GPTAPI(BaseAPIModel): | ||||
|             context_window = 8192 | ||||
|         self.context_window = context_window | ||||
|  | ||||
|     def generate( | ||||
|     def chat( | ||||
|         self, | ||||
|         inputs: Union[List, str], | ||||
|         max_out_len: int = 512, | ||||
|         temperature: float = 0.7, | ||||
|     ) -> List[str]: | ||||
|         """Generate results given a list of inputs. | ||||
|         inputs: Union[List[dict], List[List[dict]]], | ||||
|         **gen_params, | ||||
|     ) -> Union[str, List[str]]: | ||||
|         """Generate responses given the contexts. | ||||
|  | ||||
|         Args: | ||||
|             inputs (List[str or List]): A list of strings or PromptDicts. | ||||
|                 The PromptDict should be organized in OpenCompass' | ||||
|                 API format. | ||||
|             max_out_len (int): The maximum length of the output. | ||||
|             temperature (float): What sampling temperature to use, | ||||
|                 between 0 and 2. Higher values like 0.8 will make the output | ||||
|                 more random, while lower values like 0.2 will make it more | ||||
|                 focused and deterministic. Defaults to 0.7. | ||||
|             inputs (Union[List[dict], List[List[dict]]]): a list of messages, | ||||
|                 or a list of lists of messages | ||||
|             gen_params: additional generation configuration | ||||
|  | ||||
|         Returns: | ||||
|             List[str]: A list of generated strings. | ||||
|             Union[str, List[str]]: generated string(s) | ||||
|         """ | ||||
|         if self.temperature is not None: | ||||
|             temperature = self.temperature | ||||
|         return self._generate(inputs, max_out_len, temperature) | ||||
|         assert isinstance(inputs, list) | ||||
|         batched = True | ||||
|         if isinstance(inputs[0], dict): | ||||
|             inputs = [inputs] | ||||
|             batched = False | ||||
|         gen_params = {**self.gen_params, **gen_params} | ||||
|         with ThreadPoolExecutor(max_workers=20) as executor: | ||||
|             tasks = [ | ||||
|                 executor.submit(self._chat, messages, **gen_params) | ||||
|                 for messages in inputs | ||||
|             ] | ||||
|         wait(tasks) | ||||
|         ret = [task.result() for task in tasks] | ||||
|         return ret if batched else ret[0] | ||||
|  | ||||
|     def _generate(self, input: str or List, max_out_len: int, | ||||
|                   temperature: float) -> str: | ||||
|         """Generate results given a list of inputs. | ||||
|     def _chat(self, messages: List[dict], **gen_params) -> str: | ||||
|         """Generate completion from a list of templates. | ||||
|  | ||||
|         Args: | ||||
|             inputs (str or List): A string or PromptDict. | ||||
|                 The PromptDict should be organized in OpenCompass' | ||||
|                 API format. | ||||
|             max_out_len (int): The maximum length of the output. | ||||
|             temperature (float): What sampling temperature to use, | ||||
|                 between 0 and 2. Higher values like 0.8 will make the output | ||||
|                 more random, while lower values like 0.2 will make it more | ||||
|                 focused and deterministic. | ||||
|             messages (List[dict]): a list of prompt dictionaries | ||||
|             gen_params: additional generation configuration | ||||
|  | ||||
|         Returns: | ||||
|             str: The generated string. | ||||
|         """ | ||||
|         assert isinstance(input, (str, list, dict)) | ||||
|  | ||||
|         if isinstance(input, str): | ||||
|             messages = [{'role': 'user', 'content': input}] | ||||
|         elif isinstance(input, dict): | ||||
|             messages = [input] | ||||
|         else: | ||||
|             messages = input | ||||
|         assert isinstance(messages, list) | ||||
|         gen_params = gen_params.copy() | ||||
|  | ||||
|         # Hold out 100 tokens due to potential errors in tiktoken calculation | ||||
|         max_out_len = min( | ||||
|             max_out_len, | ||||
|             self.context_window - self.get_token_len(str(input)) - 100) | ||||
|             gen_params.pop('max_out_len'), | ||||
|             self.context_window - len(self.tokenize(str(messages))) - 100) | ||||
|         if max_out_len <= 0: | ||||
|             return '' | ||||
|  | ||||
|         max_num_retries = 0 | ||||
|         while max_num_retries < self.retry: | ||||
|             self.wait() | ||||
|             self._wait() | ||||
|  | ||||
|             with Lock(): | ||||
|                 if len(self.invalid_keys) == len(self.keys): | ||||
| @@ -192,8 +180,9 @@ class GPTAPI(BaseAPIModel): | ||||
|                     messages=messages, | ||||
|                     max_tokens=max_out_len, | ||||
|                     n=1, | ||||
|                     stop=None, | ||||
|                     temperature=temperature, | ||||
|                     stop=gen_params.pop('stop_words'), | ||||
|                     frequency_penalty=gen_params.pop('repetition_penalty'), | ||||
|                     **gen_params, | ||||
|                 ) | ||||
|                 raw_response = requests.post( | ||||
|                     self.url, headers=header, data=json.dumps(data)) | ||||
| @@ -225,18 +214,16 @@ class GPTAPI(BaseAPIModel): | ||||
|                            f'{max_num_retries} times. Check the logs for ' | ||||
|                            'details.') | ||||
|  | ||||
|     def get_token_len(self, prompt: str) -> int: | ||||
|         """Get lengths of the tokenized string. Only English and Chinese | ||||
|         characters are counted for now. Users are encouraged to override this | ||||
|         method if more accurate length is needed. | ||||
|     def tokenize(self, prompt: str) -> list: | ||||
|         """Tokenize the input prompt. | ||||
|  | ||||
|         Args: | ||||
|             prompt (str): Input string. | ||||
|  | ||||
|         Returns: | ||||
|             int: Length of the input tokens | ||||
|             list: token ids | ||||
|         """ | ||||
|         import tiktoken | ||||
|         self.tiktoken = tiktoken | ||||
|         enc = self.tiktoken.encoding_for_model(self.model_type) | ||||
|         return len(enc.encode(prompt)) | ||||
|         return enc.encode(prompt) | ||||
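|  | ||||
| With this refactor, `GPTAPI.chat` accepts either a single message list or a batch of them, fanning batched requests out over a thread pool. A minimal sketch; the API key is a placeholder, and `max_out_len` is assumed to be among the default gen_params. | ||||
|  | ||||
| ```python | ||||
| from lagent.llms import GPTAPI | ||||
|  | ||||
| llm = GPTAPI(model_type='gpt-3.5-turbo', key='YOUR_OPENAI_API_KEY') | ||||
| # A single conversation returns a str. | ||||
| print(llm.chat([dict(role='user', content='Hello!')])) | ||||
| # A batch of conversations returns a List[str], evaluated concurrently. | ||||
| print(llm.chat([[dict(role='user', content='Hi')], | ||||
|                 [dict(role='user', content='Bonjour')]])) | ||||
| ``` | ||||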
|   | ||||
| @@ -1,6 +1,6 @@ | ||||
| from dataclasses import asdict, dataclass, field | ||||
| from enum import Enum | ||||
| from typing import Dict, List, Optional, Union | ||||
| from typing import List, Optional, Union | ||||
|  | ||||
| from lagent.utils import is_module_exist | ||||
|  | ||||
| @@ -33,47 +33,49 @@ class ActionValidCode(int, Enum): | ||||
|  | ||||
| @dataclass | ||||
| class ActionReturn: | ||||
|     args: Dict | ||||
|     args: Optional[dict] = None | ||||
|     url: Optional[str] = None | ||||
|     type: Optional[str] = None | ||||
|     result: Optional[str] = None | ||||
|     result: Optional[List[dict]] = None | ||||
|     errmsg: Optional[str] = None | ||||
|     state: Union[ActionStatusCode, int] = ActionStatusCode.SUCCESS | ||||
|     thought: Optional[str] = None | ||||
|     valid: Optional[ActionValidCode] = ActionValidCode.OPEN | ||||
|  | ||||
|     def format_result(self) -> str: | ||||
|         """Concatenate items in result""" | ||||
|         result = [] | ||||
|         for item in self.result or []: | ||||
|             if item['type'] == 'text': | ||||
|                 result.append(item['content']) | ||||
|             else: | ||||
|                 result.append(f"[{item['type']}]({item['content']})") | ||||
|         result = '\n'.join(result) | ||||
|         return result | ||||
|  | ||||
| class AgentStatusCode(Enum): | ||||
|     END = 0  # end of streaming | ||||
|  | ||||
| # Inherit from int so that `asdict` can convert AgentStatusCode to an int | ||||
| class AgentStatusCode(int, Enum): | ||||
|     END = 0  # end of streaming; the history of this round is returned | ||||
|     STREAM_ING = 1  # response is in streaming | ||||
|     SERVER_ERR = -1  # triton server's error | ||||
|     SESSION_CLOSED = -2  # session has been closed | ||||
|     SESSION_OUT_OF_LIMIT = -3  # request length out of limit | ||||
|     CMD = 2  # return command | ||||
|     PLUGIN_START = 3  # start tool | ||||
|     PLUGIN_END = 4  # finish tool | ||||
|     PLUGIN_RETURN = 5  # tool result returned | ||||
|  | ||||
|     CODING = 6  # start python | ||||
|     CODE_END = 7  # end python | ||||
|     CODE_RETURN = 8  # python return | ||||
|     SESSION_INVALID_ARG = -4  # invalid argument | ||||
|     SESSION_READY = 3  # session is ready for inference | ||||
|     SESSION_READY = 2  # session is ready for inference | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class AgentReturn: | ||||
|     state: Union[AgentStatusCode, int] = AgentStatusCode.END | ||||
|     actions: List[ActionReturn] = field(default_factory=list) | ||||
|     response: str = '' | ||||
|     inner_steps: List = field(default_factory=list) | ||||
|     errmsg: Optional[str] = None | ||||
|  | ||||
|  | ||||
| if is_module_exist('lmdeploy'): | ||||
|     from lmdeploy.serve.turbomind.chatbot import StatusCode | ||||
|     STATE_MAP = { | ||||
|         StatusCode.TRITON_STREAM_END: AgentStatusCode.END, | ||||
|         StatusCode.TRITON_SERVER_ERR: AgentStatusCode.SERVER_ERR, | ||||
|         StatusCode.TRITON_SESSION_CLOSED: AgentStatusCode.SESSION_CLOSED, | ||||
|         StatusCode.TRITON_STREAM_ING: AgentStatusCode.STREAM_ING, | ||||
|         StatusCode.TRITON_SESSION_OUT_OF_LIMIT: | ||||
|         AgentStatusCode.SESSION_OUT_OF_LIMIT, | ||||
|         StatusCode.TRITON_SESSION_INVALID_ARG: | ||||
|         AgentStatusCode.SESSION_INVALID_ARG, | ||||
|         StatusCode.TRITON_SESSION_READY: AgentStatusCode.SESSION_READY | ||||
|     } | ||||
| else: | ||||
|     STATE_MAP = {} | ||||
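|  | ||||
| The switch to `AgentStatusCode(int, Enum)` (see the translated comment above) is what makes the output of `dataclasses.asdict` JSON-serializable: the field remains an enum member, but as an `int` subclass it behaves and serializes like a plain integer. A small self-contained demonstration: | ||||
|  | ||||
| ```python | ||||
| import json | ||||
| from dataclasses import asdict, dataclass | ||||
| from enum import Enum | ||||
|  | ||||
| class Status(int, Enum): | ||||
|     END = 0 | ||||
|  | ||||
| @dataclass | ||||
| class Ret: | ||||
|     state: Status = Status.END | ||||
|  | ||||
| d = asdict(Ret()) | ||||
| # The value is still Status.END, but it is also an int instance, | ||||
| # so json serializes it as a plain number. | ||||
| print(json.dumps(d))  # {"state": 0} | ||||
| ``` | ||||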
|   | ||||
lagent/utils/util.py (new file, 30 lines)
							| @@ -0,0 +1,30 @@ | ||||
| from typing import List, Optional, Union | ||||
|  | ||||
|  | ||||
| def filter_suffix(response: Union[str, List[str]], | ||||
|                   suffixes: Optional[List[str]] = None) -> Union[str, List[str]]: | ||||
|     """Filter response with suffixes. | ||||
|  | ||||
|     Args: | ||||
|         response (Union[str, List[str]]): generated responses by LLMs. | ||||
|         suffixes (Optional[List[str]]): a list of suffixes to be deleted. | ||||
|  | ||||
|     Return: | ||||
|         Union[str, List[str]]: the cleaned response(s). | ||||
|     """ | ||||
|     if suffixes is None: | ||||
|         return response | ||||
|     batched = True | ||||
|     if isinstance(response, str): | ||||
|         response = [response] | ||||
|         batched = False | ||||
|     processed = [] | ||||
|     for resp in response: | ||||
|         for item in suffixes: | ||||
|             if item in resp: | ||||
|                 resp = resp.split(item)[0] | ||||
|         processed.append(resp) | ||||
|     if not batched: | ||||
|         return processed[0] | ||||
|     return processed | ||||
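|  | ||||
| The helper trims everything from the first matched stop word onward and preserves the batched-ness of the input: | ||||
|  | ||||
| ```python | ||||
| from lagent.utils.util import filter_suffix | ||||
|  | ||||
| print(filter_suffix('Hello!<eoa>\n', ['<eoa>']))  # 'Hello!' | ||||
| print(filter_suffix(['a<eoa>', 'b'], ['<eoa>']))  # ['a', 'b'] | ||||
| print(filter_suffix('unchanged', None))           # 'unchanged' | ||||
| ``` | ||||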
| @@ -1,9 +1,13 @@ | ||||
| docutils==0.16.0 | ||||
| docutils==0.18.1 | ||||
| markdown>=3.4.0 | ||||
| myst-parser | ||||
| -e git+https://github.com/open-mmlab/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme | ||||
| sphinx==4.0.2 | ||||
| myst-nb | ||||
| # -e git+https://github.com/open-mmlab/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme | ||||
| # sphinx==4.0.2 | ||||
| sphinx==6.1.0 | ||||
| sphinx-tabs | ||||
| sphinx_copybutton | ||||
| sphinx_markdown_tables>=0.0.16 | ||||
| sphinx-rtd-theme==1.3.0 | ||||
| tabulate | ||||
| astroid<3.0.0 | ||||
| sphinx-autoapi | ||||
|   | ||||
| @@ -3,3 +3,9 @@ func_timeout | ||||
| jsonschema | ||||
| requests | ||||
| tiktoken | ||||
| griffe | ||||
| phx-class-registry | ||||
| jupyter | ||||
| jupyter_client | ||||
| json5 | ||||
| pillow | ||||
|   | ||||