Compare commits

24 Commits

| SHA1 |
| --- |
| 1eb10a3467 |
| b5533b036d |
| 39e00f25df |
| 183ef77a84 |
| a2a0d91b29 |
| 85853da263 |
| 487919476b |
| 1292ea7dc7 |
| 016ee62cc6 |
| 1b5fd996d3 |
| 799c6a32c4 |
| bdd6a9b4df |
| a53bad2077 |
| 54e5b615f4 |
| 73de598d42 |
| adefd97a27 |
| fdaacb87a4 |
| c42c884900 |
| 30f77daa07 |
| e46e59cc6a |
| 740c3f960b |
| 93162b27b1 |
| b49397aa93 |
| 06bedec5a9 |
@@ -2,7 +2,11 @@ version: 2
formats: all

build:
  os: ubuntu-22.04
  tools:
    python: "3.10"

python:
  version: 3.7
  install:
    - requirements: requirements/docs.txt
  install:
    - requirements: requirements/docs.txt
@@ -83,7 +83,7 @@ Below is an example of running ReWOO with GPT-3.5

```python
# Import necessary modules and classes from the "lagent" library.
from lagent.agents import ReWOO
from lagent.actions import ActionExecutor, GoogleSearch, LLMQA
from lagent.actions import ActionExecutor, GoogleSearch
from lagent.llms import GPTAPI

# Initialize the Language Model (llm) and provide your API key.

@@ -92,14 +92,11 @@ llm = GPTAPI(model_type='gpt-3.5-turbo', key=['Your OPENAI_API_KEY'])

# Initialize the Google Search tool and provide your API key.
search_tool = GoogleSearch(api_key='Your SERPER_API_KEY')

# Initialize the LLMQA tool using the Language Model (llm).
llmqa_tool = LLMQA(llm)

# Create a chatbot by configuring the ReWOO agent.
chatbot = ReWOO(
    llm=llm,  # Provide the Language Model instance.
    action_executor=ActionExecutor(
        actions=[search_tool, llmqa_tool]  # Specify the actions the chatbot can perform.
        actions=[search_tool]  # Specify the actions the chatbot can perform.
    ),
)

@@ -154,6 +151,7 @@ response = chatbot.chat(

print(response.response)  # Output the response generated by the chatbot.
>>> $-\\frac{1}{3}+\\frac{{\\sqrt{3}}}{3}i$
```

### All Thanks To Our Contributors:

<a href="https://github.com/InternLM/lagent/graphs/contributors">
  <img src="https://contrib.rocks/image?repo=InternLM/lagent" />
</a>
14
docs/en/_templates/autoapi/index.rst
Normal file
@@ -0,0 +1,14 @@
API Reference
=============

This page contains auto-generated API reference documentation.

.. toctree::
   :titlesonly:
   :maxdepth: 3

   {% for page in pages %}
   {% if page.top_level_object and page.display %}
   {{ page.include_path }}
   {% endif %}
   {% endfor %}
112
docs/en/_templates/autoapi/python/module.rst
Normal file
@@ -0,0 +1,112 @@
{% if not obj.display %}
:orphan:

{% endif %}
:py:mod:`{{ obj.name if obj.name.count(".") <= 1 else obj.short_name }}`
=========={{ "=" * (obj.name|length if obj.name.count(".") <= 1 else obj.short_name|length) }}

.. py:module:: {{ obj.name }}

{% if obj.docstring %}
.. autoapi-nested-parse::

   {{ obj.docstring|indent(3) }}

{% endif %}

{% block subpackages %}
{% set visible_subpackages = obj.subpackages|selectattr("display")|list %}
{% if visible_subpackages %}
Subpackages
-----------
.. toctree::
   :titlesonly:
   :maxdepth: 3

{% for subpackage in visible_subpackages %}
   {{ subpackage.short_name }}/index.rst
{% endfor %}


{% endif %}
{% endblock %}
{% block submodules %}
{% set visible_submodules = obj.submodules|selectattr("display")|list %}
{% if visible_submodules %}
Submodules
----------
.. toctree::
   :titlesonly:
   :maxdepth: 1

{% for submodule in visible_submodules %}
   {{ submodule.short_name }}/index.rst
{% endfor %}


{% endif %}
{% endblock %}
{% block content %}
{% if obj.type is equalto("package") %}
{% set visible_children = obj.children|selectattr("display")|list %}
{% else %}
{% set visible_children = obj.children|selectattr("display")|rejectattr("imported")|list %}
{% endif %}
{% if visible_children %}
{{ obj.type|title }} Contents
{{ "-" * obj.type|length }}---------

{% set visible_classes = visible_children|selectattr("type", "equalto", "class")|list %}
{% set visible_functions = visible_children|selectattr("type", "equalto", "function")|list %}
{% set visible_attributes = visible_children|selectattr("type", "equalto", "data")|list %}
{% if "show-module-summary" in autoapi_options and (visible_classes or visible_functions) %}
{% block classes scoped %}
{% if visible_classes %}
Classes
~~~~~~~

.. autoapisummary::

{% for klass in visible_classes %}
   {{ klass.id }}
{% endfor %}


{% endif %}
{% endblock %}

{% block functions scoped %}
{% if visible_functions %}
Functions
~~~~~~~~~

.. autoapisummary::

{% for function in visible_functions %}
   {{ function.id }}
{% endfor %}


{% endif %}
{% endblock %}

{% block attributes scoped %}
{% if visible_attributes %}
Attributes
~~~~~~~~~~

.. autoapisummary::

{% for attribute in visible_attributes %}
   {{ attribute.id }}
{% endfor %}


{% endif %}
{% endblock %}
{% endif %}
{% for obj_item in visible_children %}
{{ obj_item.render()|indent(0) }}
{% endfor %}
{% endif %}
{% endblock %}
110
docs/en/conf.py
@@ -11,17 +11,16 @@
# documentation root, use os.path.abspath to make it absolute, like shown here.

import os
import re
import sys

import pytorch_sphinx_theme

sys.path.insert(0, os.path.abspath('../../'))
sys.path.insert(0, os.path.abspath('../..'))

# -- Project information -----------------------------------------------------

project = 'Lagent'
copyright = '2020-2030, InternLM'
author = 'InternLM'
language = 'en'

# The full version, including alpha/beta/rc tags
version_file = '../../lagent/version.py'

@@ -36,97 +35,74 @@ release = __version__
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
    'sphinx_rtd_theme',
    'myst_nb',
    'autoapi.extension',
    'sphinx_markdown_tables',
    'sphinx.ext.autodoc',
    'sphinx.ext.napoleon',
    'sphinx.ext.viewcode',
    'sphinx_markdown_tables',
    'sphinx_copybutton',
    'myst_parser',
    'sphinx.ext.intersphinx',
    'sphinx.ext.autodoc.typehints',
    'sphinx.ext.autosummary',
    'sphinx.ext.autosectionlabel',
    'sphinx_tabs.tabs',
]

nb_output_stderr = 'remove-warn'
autodoc_typehints = 'description'
autosummary_generate = True  # Turn on sphinx.ext.autosummary

# Ignore >>> when copying code
copybutton_prompt_text = r'>>> |\.\.\. '
copybutton_prompt_is_regexp = True

myst_enable_extensions = ['colon_fence']
# sphinx-autoapi configuration
autoapi_dirs = ['../../lagent']
autoapi_options = [
    'members',
    'undoc-members',
    'show-inheritance',
    'show-module-summary',
]
autoapi_ignore = ['*migrations*', '*command.py', '*cli.py']
autoapi_template_dir = '_templates/autoapi'
autoapi_add_toctree_entry = False

# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']

# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
source_suffix = {
    '.rst': 'restructuredtext',
    '.md': 'markdown',
}

# The master toctree document.
master_doc = 'index'

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
exclude_patterns = []

# -- Options for HTML output -------------------------------------------------

# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
# html_theme = 'sphinx_rtd_theme'
html_theme = 'pytorch_sphinx_theme'
html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()]
html_theme = 'sphinx_rtd_theme'
html_theme_options = {
    'menu': [
        {
            'name': 'GitHub',
            'url': 'https://github.com/InternLM/lagent'
        },
    ],
    # Specify the language of shared menu
    'menu_lang': 'en'
    'navigation_depth': 3,
    'titles_only': False,
    'style_nav_header_background': '#4fabab',
}

language = 'en'
html_context = {
    'display_github': True,
    'github_host': 'github.com',
    'github_user': 'InternLM',
    'github_repo': 'lagent',
    'github_version': 'main',
    'conf_py_path': '/docs/en/',
}
html_title = 'Lagent'
html_logo = '../imgs/lagent_logo.png'
html_favicon = '../imgs/lagent_icon.png'

master_doc = 'index'

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
# so a file named 'default.css' will overwrite the builtin 'default.css'.
html_static_path = ['_static']

html_css_files = [
    'https://cdn.datatables.net/1.13.2/css/dataTables.bootstrap5.min.css',
    'css/readthedocs.css'
]
html_js_files = [
    'https://cdn.datatables.net/1.13.2/js/jquery.dataTables.min.js',
    'https://cdn.datatables.net/1.13.2/js/dataTables.bootstrap5.min.js',
    'js/collapsed.js',
    'js/table.js',
]

myst_heading_anchors = 4

intersphinx_mapping = {
    'python': ('https://docs.python.org/3', None),
    'numpy': ('https://numpy.org/doc/stable', None),
    'torch': ('https://pytorch.org/docs/stable/', None),
}

def custom_skip(app, what, name, obj, skip, options):
    if what in ['data', 'function', 'class'] and re.search('logger', name):
        skip = True
    return skip


def builder_inited_handler(app):
    pass


def setup(app):
    app.connect('builder-inited', builder_inited_handler)
def setup(sphinx):
    sphinx.connect('autoapi-skip-member', custom_skip)
19
docs/en/get_started/install.md
Normal file
@@ -0,0 +1,19 @@
# Installation

## With pip

Install with pip (Recommended).

```bash
pip install lagent
```

## From source

Optionally, you could also build Lagent from source in case you want to modify the code:

```bash
git clone https://github.com/InternLM/lagent.git
cd lagent
pip install -e .
```
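As a quick sanity check after installing, you can print the package version. A minimal sketch, assuming `lagent` re-exports the `__version__` defined in `lagent/version.py`:

```python
# Sketch: verify the installation.
# Assumes __version__ (defined in lagent/version.py) is importable from the package.
import lagent

print(lagent.__version__)
```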
@@ -1,4 +1,4 @@
# OVERVIEW
# Overview

This chapter introduces you to the framework of Lagent, and provides links to detailed tutorials about Lagent.

@@ -18,6 +18,6 @@ Lagent consists of 3 main parts, agents, llms, and actions.

Here is a detailed step-by-step guide to learn more about Lagent:

1. For installation instructions, please see [README](../README.md).
1. For installation instructions, please see [README](https://github.com/InternLM/lagent/blob/main/README.md).

2. We provide several examples of building agents with Lagent in [examples](examples/); simply run `python examples/react_example.py`.
2. We provide several examples of building agents with Lagent in [examples](https://github.com/InternLM/lagent/tree/main/examples); simply run `python examples/react_example.py`.
89
docs/en/get_started/quickstart.md
Normal file
@@ -0,0 +1,89 @@
# Quickstart

Using Lagent, you can easily build agents with just a few lines of code.

## Run a ReWOO agent with GPT-3.5

Below is an example of running ReWOO with GPT-3.5.

```python
# Import necessary modules and classes from the "lagent" library.
from lagent.agents import ReWOO
from lagent.actions import ActionExecutor, GoogleSearch
from lagent.llms import GPTAPI

# Initialize the Language Model (llm) and provide your API key.
llm = GPTAPI(model_type='gpt-3.5-turbo', key=['Your OPENAI_API_KEY'])

# Initialize the Google Search tool and provide your API key.
search_tool = GoogleSearch(api_key='Your SERPER_API_KEY')

# Create a chatbot by configuring the ReWOO agent.
chatbot = ReWOO(
    llm=llm,  # Provide the Language Model instance.
    action_executor=ActionExecutor(
        actions=[search_tool]  # Specify the actions the chatbot can perform.
    ),
)

# Ask the chatbot a question and store the response.
response = chatbot.chat('What profession does Nicholas Ray and Elia Kazan have in common')

# Print the chatbot's response.
print(response.response)  # Output the response generated by the chatbot.
```

```python
>>> Film director.
```
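Hard-coding API keys in scripts makes them easy to leak. A common alternative is to read them from environment variables; a minimal sketch, assuming `GPTAPI` and `GoogleSearch` accept the key arguments exactly as shown above and that the variables are set in your shell:

```python
import os

# Sketch: read credentials from the environment instead of hard-coding them.
# Assumes OPENAI_API_KEY and SERPER_API_KEY are exported in the shell.
llm = GPTAPI(model_type='gpt-3.5-turbo', key=[os.environ['OPENAI_API_KEY']])
search_tool = GoogleSearch(api_key=os.environ['SERPER_API_KEY'])
```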
## Run a ReAct agent with InternLM

NOTE: If you want to run a HuggingFace model, please run `pip install -e .[all]` first.

```python
# Import necessary modules and classes from the "lagent" library.
from lagent.agents import ReAct
from lagent.actions import ActionExecutor, GoogleSearch, PythonInterpreter
from lagent.llms import HFTransformer

from lagent.llms.meta_template import INTERNLM2_META as META

# Initialize the HFTransformer-based Language Model (llm) and
# provide the model name.
llm = HFTransformer(path='internlm/internlm2-chat-7b', meta_template=META)

# Initialize the Google Search tool and provide your API key.
search_tool = GoogleSearch(api_key='Your SERPER_API_KEY')

# Initialize the Python Interpreter tool.
python_interpreter = PythonInterpreter()

# Create a chatbot by configuring the ReAct agent.
# Specify the actions the chatbot can perform.
chatbot = ReAct(
    llm=llm,  # Provide the Language Model instance.
    action_executor=ActionExecutor(
        actions=[search_tool, python_interpreter]),
)
# Ask the chatbot a mathematical question in LaTeX format.
# (A raw string keeps Python from interpreting LaTeX backslashes such as \f as escapes.)
response = chatbot.chat(r'若$z=-1+\sqrt{3}i$,则$\frac{z}{{z\overline{z}-1}}=\left(\ \ \right)$')

# Print the chatbot's response.
print(response.response)  # Output the response generated by the chatbot.
```

```python
>>> $-\\frac{1}{3}+\\frac{{\\sqrt{3}}}{3}i$
```

## Run ReAct Web Demo

```bash
# You need to install streamlit first
# pip install streamlit
streamlit run examples/react_web_demo.py
```

Then you can chat through the UI shown below:

![image](https://github.com/InternLM/lagent/assets/24622904/3aebb8b4-07d1-42a2-9da3-46080c556f68)
@@ -8,13 +8,31 @@ You can switch between English and Chinese in the lower-left corner of the layout.
   :caption: Get Started

   get_started/overview.md
   get_started/install.md
   get_started/quickstart.md

.. toctree::
   :maxdepth: 2
   :caption: Tutorials

   tutorials/action.md

.. toctree::
   :caption: Switch Language

   switch_language.md

.. toctree::
   :maxdepth: 1
   :caption: API Reference

   autoapi/lagent/actions/index
   autoapi/lagent/agents/index
   autoapi/lagent/llms/index
   autoapi/lagent/utils/index
   autoapi/lagent/schema/index
   autoapi/lagent/version/index


Indices and tables
==================
396
docs/en/tutorials/action.md
Normal file
@@ -0,0 +1,396 @@
# Action

Actions, also called **tools**, provide a suite of functions LLM-driven agents can use to interact with the real world and perform complex tasks.

## Basic Concepts

### Tool & Toolkit

There are two categories of tools:

* tool: provides only one API to call.
* toolkit: implements multiple APIs that undertake different sub-tasks.

### Tool Description

In Lagent, a tool description is a dictionary containing the action's core usage information, which LLMs observe for decision-making.

For simple tools, the description can be created as follows:

```python
TOOL_DESCRIPTION = {
    'name': 'bold',  # name of the tool
    'description': 'a function used to make text bold',  # introduce the tool's function
    'parameters': [  # a list of parameters the tool takes
        {
            'name': 'text', 'type': 'STRING', 'description': 'input content'
        }
    ],
    'required': ['text'],  # specify the names of required parameters
}
```

In some situations there may also be optional `return_data` and `parameter_description` keys, describing the returns and the argument-passing format respectively.

```{attention}
`parameter_description` is usually inserted into the tool description automatically by the action's parser. It will be introduced in [Interface Design](#interface-design).
```

For toolkits, the description is very similar but nests sub-methods:

```python
TOOL_DESCRIPTION = {
    'name': 'PhraseEmphasis',  # name of the toolkit
    'description': 'a toolkit which provides different styles of text emphasis',  # introduce the tool's function
    'api_list': [
        {
            'name': 'bold',
            'description': 'make text bold',
            'parameters': [
                {
                    'name': 'text', 'type': 'STRING', 'description': 'input content'
                }
            ],
            'required': ['text']
        },
        {
            'name': 'italic',
            'description': 'make text italic',
            'parameters': [
                {
                    'name': 'text', 'type': 'STRING', 'description': 'input content'
                }
            ],
            'required': ['text']
        }
    ]
}
```

## Make Functions Tools

It's not necessary to prepare an extra description for a defined function. Lagent provides a decorator `tool_api` which can conveniently turn a function into a tool: it automatically parses the function's type hints and docstrings to generate the description dictionary and binds it to the attribute `api_description`.

```python
from lagent import tool_api

@tool_api
def bold(text: str) -> str:
    """make text bold

    Args:
        text (str): input text

    Returns:
        str: bold text
    """
    return '**' + text + '**'


bold.api_description
```

```python
{'name': 'bold',
 'description': 'make text bold',
 'parameters': [{'name': 'text',
   'type': 'STRING',
   'description': 'input text'}],
 'required': ['text']}
```

Once `returns_named_value` is enabled, you should declare the name of the return data, which will be processed to form a new field `return_data`:

```python
@tool_api(returns_named_value=True)
def bold(text: str) -> str:
    """make text bold

    Args:
        text (str): input text

    Returns:
        bold_text (str): bold text
    """
    return '**' + text + '**'

bold.api_description
```

```python
{'name': 'bold',
 'description': 'make text bold',
 'parameters': [{'name': 'text',
   'type': 'STRING',
   'description': 'input text'}],
 'required': ['text'],
 'return_data': [{'name': 'bold_text',
   'description': 'bold text',
   'type': 'STRING'}]}
```

Sometimes the tool may return a `dict` or `tuple`, and you may want to elaborate each member in `return_data` rather than treat the return value as a whole. Set `explode_return=True` and list the members in the Returns section of the docstring.

```python
@tool_api(explode_return=True)
def list_args(a: str, b: int, c: float = 0.0) -> dict:
    """Return arguments in dict format

    Args:
        a (str): a
        b (int): b
        c (float): c

    Returns:
        dict: input arguments
            - a (str): a
            - b (int): b
            - c: c
    """
    return {'a': a, 'b': b, 'c': c}
```

```python
{'name': 'list_args',
 'description': 'Return arguments in dict format',
 'parameters': [{'name': 'a', 'type': 'STRING', 'description': 'a'},
  {'name': 'b', 'type': 'NUMBER', 'description': 'b'},
  {'name': 'c', 'type': 'FLOAT', 'description': 'c'}],
 'required': ['a', 'b'],
 'return_data': [{'name': 'a', 'description': 'a', 'type': 'STRING'},
  {'name': 'b', 'description': 'b', 'type': 'NUMBER'},
  {'name': 'c', 'description': 'c'}]}
```

```{warning}
Only Google-style Python docstrings are currently supported.
```

## Interface Design

`BaseAction(description=None, parser=JsonParser, enable=True)` is the base class all actions should inherit from. It takes three initialization arguments:

* **description**: a tool description dictionary, used to set the instance attribute `description`. Mostly you don't need to pass this argument explicitly, since the metaclass of `BaseAction` searches for methods decorated by `tool_api` and assembles their `api_description` into a class attribute `__tool_description__`; if the initial `description` is left null, `__tool_description__` is copied as `description`.
* **parser**: a `BaseParser` class. It instantiates a parser used to validate the arguments of the APIs in `description`.

For example, `JsonParser` requires arguments to be passed in the format of JSON or a `dict`. To make LLMs aware of this, it inserts a field `parameter_description` into the `description`.

```python
from lagent import BaseAction

action = BaseAction(
    {
        'name': 'bold',
        'description': 'a function used to make text bold',
        'parameters': [
            {
                'name': 'text', 'type': 'STRING', 'description': 'input content'
            }
        ],
        'required': ['text']
    }
)
action.description
```

```python
{'name': 'bold',
 'description': 'a function used to make text bold',
 'parameters': [{'name': 'text',
   'type': 'STRING',
   'description': 'input content'}],
 'required': ['text'],
 'parameter_description': '如果调用该工具,你必须使用Json格式 {key: value} 传参,其中key为参数名称'}
```

* **enable**: specifies whether the tool is available.

### Custom Action

A simple tool must have its `run` method implemented, while the APIs of a toolkit should avoid naming conflicts with this reserved word.

```python
class Bold(BaseAction):

    @tool_api
    def run(self, text: str):
        """make text bold

        Args:
            text (str): input text

        Returns:
            str: bold text
        """
        return '**' + text + '**'


class PhraseEmphasis(BaseAction):
    """a toolkit which provides different styles of text emphasis"""

    @tool_api
    def bold(self, text):
        """make text bold

        Args:
            text (str): input text

        Returns:
            str: bold text
        """
        return '**' + text + '**'

    @tool_api
    def italic(self, text):
        """make text italic

        Args:
            text (str): input text

        Returns:
            str: italic text
        """
        return '*' + text + '*'

# Inspect the default description
# Bold.__tool_description__, PhraseEmphasis.__tool_description__
```
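As a quick check, such a class can be instantiated and called directly. A minimal sketch, assuming the `Bold` class above and the default `JsonParser` (arguments passed as a JSON string or `dict`):

```python
# Sketch: minimal usage of the Bold action defined above
# (assumes the default JsonParser, so arguments arrive as JSON or a dict).
bold_action = Bold()
ret = bold_action('{"text": "hello"}')  # dispatches to the reserved `run` API
print(ret.result)  # expected to look like [{'type': 'text', 'content': '**hello**'}]
```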
### Auto-registration

Any subclass of `BaseAction` is registered automatically. You can use `list_tools()` and `get_tool()` to view all tools and initialize one by name.

```python
from lagent import list_tools, get_tool

list_tools()
```

```python
['BaseAction',
 'InvalidAction',
 'NoAction',
 'FinishAction',
 'ArxivSearch',
 'BINGMap',
 'GoogleScholar',
 'GoogleSearch',
 'IPythonInterpreter',
 'PPT',
 'PythonInterpreter',
 'Bold',
 'PhraseEmphasis']
```

Create a `PhraseEmphasis` object:

```python
action = get_tool('PhraseEmphasis')
action.description
```

```python
{'name': 'PhraseEmphasis',
 'description': 'a toolkit which provides different styles of text emphasis',
 'api_list': [{'name': 'bold',
   'description': 'make text bold',
   'parameters': [{'name': 'text',
     'type': 'STRING',
     'description': 'input text'}],
   'required': ['text'],
   'parameter_description': '如果调用该工具,你必须使用Json格式 {key: value} 传参,其中key为参数名称'},
  {'name': 'italic',
   'description': 'make text italic',
   'parameters': [{'name': 'text',
     'type': 'STRING',
     'description': 'input text'}],
   'required': ['text'],
   'parameter_description': '如果调用该工具,你必须使用Json格式 {key: value} 传参,其中key为参数名称'}]}
```

## Tool Calling

### Run a Tool

The `__call__` method of an `Action` takes two arguments:

* `inputs`: its type depends on the action's parser; it is often a string in a specific format generated by the LLM.
  + `JsonParser`: allows passing arguments as a JSON-format string or a Python `dict`.
  + `TupleParser`: allows passing arguments as a tuple-literal string or a Python `tuple`.
* `name`: which API to call. Defaults to `run` (see the sketch after the example below).

It returns an `ActionReturn` object which encapsulates the calling details:

* `args`: dictionary of action inputs.
* `type`: action name.
* `result`: list of dicts, each containing two keys, 'type' and 'content'. When errors occur, it is `None`.
* `errmsg`: error message. Defaults to `None`.

Below is an example:

```python
from lagent import IPythonInterpreter, TupleParser

action1 = IPythonInterpreter()
ret = action1('{"command": "import math;math.sqrt(100)"}')
print(ret.result)
ret = action1({'command': 'import math;math.sqrt(100)'})
print(ret.result)

action2 = IPythonInterpreter(parser=TupleParser)
ret = action2('("import math;math.sqrt(100)", )')
print(ret.result)
ret = action2(('import math;math.sqrt(100)',))
print(ret.result)
```

```python
[{'type': 'text', 'content': '10.0'}]
[{'type': 'text', 'content': '10.0'}]
[{'type': 'text', 'content': '10.0'}]
[{'type': 'text', 'content': '10.0'}]
```
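For toolkits, the `name` argument selects the sub-API. A minimal sketch, assuming the `PhraseEmphasis` toolkit defined in the Custom Action section and the default `JsonParser`:

```python
# Sketch: call a specific toolkit API by name
# (assumes the PhraseEmphasis toolkit from the Custom Action section).
action = PhraseEmphasis()
ret = action('{"text": "hello"}', name='italic')
print(ret.result)  # expected to look like [{'type': 'text', 'content': '*hello*'}]
```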
### Dynamic Invocation

Lagent provides an `ActionExecutor` to manage multiple tools. It flattens the `api_list` of each toolkit and renames every API to `{tool_name}.{api_name}`.

```python
from lagent import ActionExecutor, ArxivSearch, IPythonInterpreter

executor = ActionExecutor(actions=[ArxivSearch(), IPythonInterpreter()])
executor.get_actions_info()  # This information is fed to LLMs as the tool meta prompt
```

```python
[{'name': 'ArxivSearch.get_arxiv_article_information',
  'description': 'Run Arxiv search and get the article meta information.',
  'parameters': [{'name': 'query',
    'type': 'STRING',
    'description': 'the content of search query'}],
  'required': ['query'],
  'return_data': [{'name': 'content',
    'description': 'a list of 3 arxiv search papers',
    'type': 'STRING'}],
  'parameter_description': '如果调用该工具,你必须使用Json格式 {key: value} 传参,其中key为参数名称'},
 {'name': 'IPythonInterpreter',
  'description': "When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files. Internet access for this session is disabled. Do not make external web requests or API calls as they will fail.",
  'parameters': [{'name': 'command',
    'type': 'STRING',
    'description': 'Python code'},
   {'name': 'timeout',
    'type': 'NUMBER',
    'description': 'Upper bound of waiting time for Python script execution.'}],
  'required': ['command'],
  'parameter_description': '如果调用该工具,你必须使用Json格式 {key: value} 传参,其中key为参数名称'}]
```

Trigger an action through the executor:

```python
ret = executor('IPythonInterpreter', '{"command": "import math;math.sqrt(100)"}')
ret.result
```

```python
[{'type': 'text', 'content': '10.0'}]
```
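Flattened toolkit APIs are triggered the same way, by their renamed `{tool_name}.{api_name}` identifiers. A sketch, assuming the `executor` built above and network access for the Arxiv query (the query string is an arbitrary example):

```python
# Sketch: trigger a flattened toolkit API through the executor.
ret = executor('ArxivSearch.get_arxiv_article_information',
               '{"query": "LLM agent frameworks"}')
print(ret.result)
```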
14
docs/zh_cn/_templates/autoapi/index.rst
Normal file
@@ -0,0 +1,14 @@
API Reference
=============

This page contains auto-generated API reference documentation.

.. toctree::
   :titlesonly:
   :maxdepth: 3

   {% for page in pages %}
   {% if page.top_level_object and page.display %}
   {{ page.include_path }}
   {% endif %}
   {% endfor %}
112
docs/zh_cn/_templates/autoapi/python/module.rst
Normal file
@@ -0,0 +1,112 @@
{% if not obj.display %}
:orphan:

{% endif %}
:py:mod:`{{ obj.name if obj.name.count(".") <= 1 else obj.short_name }}`
=========={{ "=" * (obj.name|length if obj.name.count(".") <= 1 else obj.short_name|length) }}

.. py:module:: {{ obj.name }}

{% if obj.docstring %}
.. autoapi-nested-parse::

   {{ obj.docstring|indent(3) }}

{% endif %}

{% block subpackages %}
{% set visible_subpackages = obj.subpackages|selectattr("display")|list %}
{% if visible_subpackages %}
Subpackages
-----------
.. toctree::
   :titlesonly:
   :maxdepth: 3

{% for subpackage in visible_subpackages %}
   {{ subpackage.short_name }}/index.rst
{% endfor %}


{% endif %}
{% endblock %}
{% block submodules %}
{% set visible_submodules = obj.submodules|selectattr("display")|list %}
{% if visible_submodules %}
Submodules
----------
.. toctree::
   :titlesonly:
   :maxdepth: 1

{% for submodule in visible_submodules %}
   {{ submodule.short_name }}/index.rst
{% endfor %}


{% endif %}
{% endblock %}
{% block content %}
{% if obj.type is equalto("package") %}
{% set visible_children = obj.children|selectattr("display")|list %}
{% else %}
{% set visible_children = obj.children|selectattr("display")|rejectattr("imported")|list %}
{% endif %}
{% if visible_children %}
{{ obj.type|title }} Contents
{{ "-" * obj.type|length }}---------

{% set visible_classes = visible_children|selectattr("type", "equalto", "class")|list %}
{% set visible_functions = visible_children|selectattr("type", "equalto", "function")|list %}
{% set visible_attributes = visible_children|selectattr("type", "equalto", "data")|list %}
{% if "show-module-summary" in autoapi_options and (visible_classes or visible_functions) %}
{% block classes scoped %}
{% if visible_classes %}
Classes
~~~~~~~

.. autoapisummary::

{% for klass in visible_classes %}
   {{ klass.id }}
{% endfor %}


{% endif %}
{% endblock %}

{% block functions scoped %}
{% if visible_functions %}
Functions
~~~~~~~~~

.. autoapisummary::

{% for function in visible_functions %}
   {{ function.id }}
{% endfor %}


{% endif %}
{% endblock %}

{% block attributes scoped %}
{% if visible_attributes %}
Attributes
~~~~~~~~~~

.. autoapisummary::

{% for attribute in visible_attributes %}
   {{ attribute.id }}
{% endfor %}


{% endif %}
{% endblock %}
{% endif %}
{% for obj_item in visible_children %}
{{ obj_item.render()|indent(0) }}
{% endfor %}
{% endif %}
{% endblock %}
@@ -11,18 +11,16 @@
# documentation root, use os.path.abspath to make it absolute, like shown here.

import os
import subprocess
import re
import sys

import pytorch_sphinx_theme

sys.path.insert(0, os.path.abspath('../../'))
sys.path.insert(0, os.path.abspath('../..'))

# -- Project information -----------------------------------------------------

project = 'Lagent'
copyright = '2020-2030, InternLM'
author = 'InternLM'
language = 'zh_CN'

# The full version, including alpha/beta/rc tags
version_file = '../../lagent/version.py'

@@ -37,97 +35,74 @@ release = __version__
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
    'sphinx_rtd_theme',
    'myst_nb',
    'autoapi.extension',
    'sphinx_markdown_tables',
    'sphinx.ext.autodoc',
    'sphinx.ext.napoleon',
    'sphinx.ext.viewcode',
    'sphinx_markdown_tables',
    'sphinx_copybutton',
    'myst_parser',
    'sphinx.ext.intersphinx',
    'sphinx.ext.autodoc.typehints',
    'sphinx.ext.autosummary',
    'sphinx.ext.autosectionlabel',
    'sphinx_tabs.tabs',
]

nb_output_stderr = 'remove-warn'
autodoc_typehints = 'description'

autosummary_generate = True  # Turn on sphinx.ext.autosummary
# Ignore >>> when copying code
copybutton_prompt_text = r'>>> |\.\.\. '
copybutton_prompt_is_regexp = True

myst_enable_extensions = ['colon_fence']
# sphinx-autoapi configuration
autoapi_dirs = ['../../lagent']
autoapi_options = [
    'members',
    'undoc-members',
    'show-inheritance',
    'show-module-summary',
]
autoapi_ignore = ['*migrations*', '*command.py', '*cli.py']
autoapi_template_dir = '_templates/autoapi'
autoapi_add_toctree_entry = False

# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']

# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
source_suffix = {
    '.rst': 'restructuredtext',
    '.md': 'markdown',
}

# The master toctree document.
master_doc = 'index'

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
exclude_patterns = []

# -- Options for HTML output -------------------------------------------------

# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
# html_theme = 'sphinx_rtd_theme'
html_theme = 'pytorch_sphinx_theme'
html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()]
html_theme = 'sphinx_rtd_theme'
html_theme_options = {
    'menu': [
        {
            'name': 'GitHub',
            'url': 'https://github.com/InternLM/lagent'
        },
    ],
    # Specify the language of shared menu
    'menu_lang': 'cn',
    'navigation_depth': 3,
    'titles_only': False,
    'style_nav_header_background': '#4fabab',
}

language = 'zh_CN'
html_context = {
    'display_github': True,
    'github_host': 'github.com',
    'github_user': 'InternLM',
    'github_repo': 'lagent',
    'github_version': 'main',
    'conf_py_path': '/docs/en/',
}
html_title = 'Lagent'
html_logo = '../imgs/lagent_logo.png'
html_favicon = '../imgs/lagent_icon.png'

master_doc = 'index'

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
# so a file named 'default.css' will overwrite the builtin 'default.css'.
html_static_path = ['_static']
html_css_files = [
    'https://cdn.datatables.net/1.13.2/css/dataTables.bootstrap5.min.css',
    'css/readthedocs.css'
]
html_js_files = [
    'https://cdn.datatables.net/1.13.2/js/jquery.dataTables.min.js',
    'https://cdn.datatables.net/1.13.2/js/dataTables.bootstrap5.min.js',
    'js/collapsed.js',
    'js/table.js',
]

myst_heading_anchors = 4

# Configuration for intersphinx
intersphinx_mapping = {
    'python': ('https://docs.python.org/3', None),
    'numpy': ('https://numpy.org/doc/stable', None),
    'torch': ('https://pytorch.org/docs/stable/', None),
}


def builder_inited_handler(app):
    subprocess.run(['./cp_origin_docs.sh'])
def custom_skip(app, what, name, obj, skip, options):
    if what in ['data', 'function', 'class'] and re.search('logger', name):
        skip = True
    return skip


def setup(app):
    app.connect('builder-inited', builder_inited_handler)
def setup(sphinx):
    sphinx.connect('autoapi-skip-member', custom_skip)
19
docs/zh_cn/get_started/install.md
Normal file
@@ -0,0 +1,19 @@
# Installation

## With pip

Installing with pip is recommended:

```bash
pip install lagent
```

## From source

If you want to modify parts of the code, you can build Lagent from source:

```bash
git clone https://github.com/InternLM/lagent.git
cd lagent
pip install -e .
```
23
docs/zh_cn/get_started/overview.md
Normal file
@@ -0,0 +1,23 @@
# Overview

This chapter introduces the architecture of Lagent and provides links to detailed Lagent tutorials.

## What is Lagent

Lagent is an open-source LLM agent framework that lets users quickly turn a large language model into an agent, and it provides typical tools to unlock the potential of the LLM. The framework is illustrated below:

![image](https://github.com/InternLM/lagent/assets/24622904/cb851b31-6932-422e-a776-b1aa68f2a64f)

Lagent consists of three main modules: agents, llms, and actions.

- **agents** implements various agents, such as ReAct and AutoGPT.
- **llms** supports a variety of large language models, including open-source models hosted on HuggingFace (Llama-2, InternLM) and closed-source models such as GPT-3.5/4.
- **actions** contains a series of tools and provides a tool executor to manage them uniformly.

## How to Use

Here are detailed tutorials for learning more about Lagent:

1. For installation instructions, please refer to the [README](https://github.com/InternLM/lagent/blob/main/README.md).

2. Some examples of building agents are provided in [examples](https://github.com/InternLM/lagent/tree/main/examples); simply run the scripts, e.g. `python examples/react_example.py`.
87
docs/zh_cn/get_started/quickstart.md
Normal file
@@ -0,0 +1,87 @@
# Quickstart

With Lagent, you can build a large language model agent in just a few lines of code.

## A ReWOO agent driven by GPT-3.5

Below is an example of running ReWOO with GPT-3.5:

```python
# Import the necessary modules and classes from Lagent
from lagent.agents import ReWOO
from lagent.actions import ActionExecutor, GoogleSearch
from lagent.llms import GPTAPI

# Initialize the LLM; you may need to provide an API key
llm = GPTAPI(model_type='gpt-3.5-turbo', key=['Your OPENAI_API_KEY'])

# Initialize the Google Search tool; you may need to provide an API key
search_tool = GoogleSearch(api_key='Your SERPER_API_KEY')

# Configure the ReWOO agent to create a chatbot
chatbot = ReWOO(
    llm=llm,  # the large language model instance
    action_executor=ActionExecutor(
        actions=[search_tool]  # specify the tools the agent can call
    ),
)

# Ask a question and get the reply
response = chatbot.chat('What profession does Nicholas Ray and Elia Kazan have in common')

# Print the reply
print(response.response)
```

```python
>>> Film director.
```

## A ReAct agent driven by InternLM

Note: if you want to use a HuggingFace model, please run `pip install -e .[all]` first.

```python
# Import the necessary modules and classes from Lagent
from lagent.agents import ReAct
from lagent.actions import ActionExecutor, GoogleSearch, PythonInterpreter
from lagent.llms import HFTransformer

from lagent.llms.meta_template import INTERNLM2_META as META

# Initialize the HFTransformer model
llm = HFTransformer(path='internlm/internlm2-chat-7b', meta_template=META)

# Initialize the Google Search tool; you may need to provide an API key
search_tool = GoogleSearch(api_key='Your SERPER_API_KEY')

# Initialize the Python interpreter tool
python_interpreter = PythonInterpreter()

# Configure the ReAct agent to create a chatbot
chatbot = ReAct(
    llm=llm,  # the large language model instance
    action_executor=ActionExecutor(
        actions=[search_tool, python_interpreter]),  # specify the tools the agent can call
)
# Ask a math question in LaTeX format
# (a raw string keeps Python from interpreting LaTeX backslashes such as \f as escapes)
response = chatbot.chat(r'若$z=-1+\sqrt{3}i$,则$\frac{z}{{z\overline{z}-1}}=\left(\ \ \right)$')

# Print the reply
print(response.response)
```

```python
>>> $-\\frac{1}{3}+\\frac{{\\sqrt{3}}}{3}i$
```

## Launch the ReAct Web Demo

```bash
# You need to install streamlit first
# pip install streamlit
streamlit run examples/react_web_demo.py
```

You can then chat through the UI shown below:

![image](https://github.com/InternLM/lagent/assets/24622904/3aebb8b4-07d1-42a2-9da3-46080c556f68)
@@ -8,12 +8,32 @@
   :caption: Get Started

   get_started/overview.md
   get_started/install.md
   get_started/quickstart.md

.. toctree::
   :maxdepth: 2
   :caption: Tutorials

   tutorials/action.md

.. toctree::
   :caption: Switch Language

   switch_language.md

.. toctree::
   :maxdepth: 1
   :caption: API Reference

   autoapi/lagent/actions/index
   autoapi/lagent/agents/index
   autoapi/lagent/llms/index
   autoapi/lagent/utils/index
   autoapi/lagent/schema/index
   autoapi/lagent/version/index


Indices and tables
==================
394
docs/zh_cn/tutorials/action.md
Normal file
@@ -0,0 +1,394 @@
# Action

Actions, also called tools, provide a suite of functions that LLM-driven agents use to interact with the real world and perform complex tasks.

## Basic Concepts

### Tool & Toolkit

There are two types of tools:

* Simple tool: provides only one API to call.
* Toolkit: implements multiple APIs that undertake different sub-tasks.

### Tool Description

In Lagent, a tool description is a dictionary that characterizes how a tool is called; it can be observed by the LLM and used for decision-making.

For simple tools, the description can be declared in the following format:

```python
TOOL_DESCRIPTION = {
    'name': 'bold',  # name of the tool
    'description': 'a function used to make text bold',  # introduce the tool's function
    'parameters': [  # the list of parameters the tool takes
        {
            'name': 'text', 'type': 'STRING', 'description': 'input content'
        }
    ],
    'required': ['text'],  # specify the names of required parameters
}
```

In some situations, the description may also contain `return_data` and `parameter_description` fields, which describe the returns and the argument-passing format respectively.

```{attention}
`parameter_description` is usually inserted into the tool description automatically by the action's parser, as introduced in [Interface Design](#id6).
```

For toolkits, the description is very similar but nests sub-methods:

```python
TOOL_DESCRIPTION = {
    'name': 'PhraseEmphasis',  # name of the toolkit
    'description': 'a toolkit which provides different styles of text emphasis',  # introduce the toolkit's function
    'api_list': [
        {
            'name': 'bold',
            'description': 'make text bold',
            'parameters': [
                {
                    'name': 'text', 'type': 'STRING', 'description': 'input content'
                }
            ],
            'required': ['text']
        },
        {
            'name': 'italic',
            'description': 'make text italic',
            'parameters': [
                {
                    'name': 'text', 'type': 'STRING', 'description': 'input content'
                }
            ],
            'required': ['text']
        }
    ]
}
```

## Make Functions Tools

There is no need to write an extra description for an already defined function. Lagent provides the decorator `tool_api`, which generates the description dictionary by automatically parsing the function's type hints and docstring, and binds it to the attribute `api_description`.

```python
from lagent import tool_api

@tool_api
def bold(text: str) -> str:
    """make text bold

    Args:
        text (str): input text

    Returns:
        str: bold text
    """
    return '**' + text + '**'


bold.api_description
```

```python
{'name': 'bold',
 'description': 'make text bold',
 'parameters': [{'name': 'text',
   'type': 'STRING',
   'description': 'input text'}],
 'required': ['text']}
```

Once `returns_named_value` is enabled, you should declare the name of the return value, which will be processed into a new field `return_data`:

```python
@tool_api(returns_named_value=True)
def bold(text: str) -> str:
    """make text bold

    Args:
        text (str): input text

    Returns:
        bold_text (str): bold text
    """
    return '**' + text + '**'

bold.api_description
```

```python
{'name': 'bold',
 'description': 'make text bold',
 'parameters': [{'name': 'text',
   'type': 'STRING',
   'description': 'input text'}],
 'required': ['text'],
 'return_data': [{'name': 'bold_text',
   'description': 'bold text',
   'type': 'STRING'}]}
```

Sometimes a tool may return a `dict` or `tuple`. If you want to detail the meaning of each member in `return_data` rather than treat them as a whole, set `explode_return=True` and list the members in the Returns section of the docstring.

```python
@tool_api(explode_return=True)
def list_args(a: str, b: int, c: float = 0.0) -> dict:
    """Return arguments in dict format

    Args:
        a (str): a
        b (int): b
        c (float): c

    Returns:
        dict: input arguments
            - a (str): a
            - b (int): b
            - c: c
    """
    return {'a': a, 'b': b, 'c': c}
```

```python
{'name': 'list_args',
 'description': 'Return arguments in dict format',
 'parameters': [{'name': 'a', 'type': 'STRING', 'description': 'a'},
  {'name': 'b', 'type': 'NUMBER', 'description': 'b'},
  {'name': 'c', 'type': 'FLOAT', 'description': 'c'}],
 'required': ['a', 'b'],
 'return_data': [{'name': 'a', 'description': 'a', 'type': 'STRING'},
  {'name': 'b', 'description': 'b', 'type': 'NUMBER'},
  {'name': 'c', 'description': 'c'}]}
```

```{warning}
Only Google-style Python docstrings are currently supported.
```

## Interface Design

`BaseAction(description=None, parser=JsonParser, enable=True)` is the base class all actions should inherit from. It takes three initialization arguments:

* **description**: a tool description dictionary used to set the instance attribute `description`. Usually you don't need to pass it explicitly, because the metaclass of `BaseAction` looks up the methods decorated by `tool_api` and assembles their `api_description` into a class attribute `__tool_description__`; if `description` is left empty at initialization, the instance attribute falls back to `__tool_description__`.
* **parser**: a `BaseParser` class, used to instantiate an action parser that validates the arguments of the APIs described in `description`. For example, `JsonParser` requires the model to pass a JSON-format string or a Python dictionary when calling the tool; to make the LLM aware of this requirement, it inserts a `parameter_description` field into `description`.

```python
from lagent import BaseAction

action = BaseAction(
    {
        'name': 'bold',
        'description': 'a function used to make text bold',
        'parameters': [
            {
                'name': 'text', 'type': 'STRING', 'description': 'input content'
            }
        ],
        'required': ['text']
    }
)
action.description
```

```python
{'name': 'bold',
 'description': 'a function used to make text bold',
 'parameters': [{'name': 'text',
   'type': 'STRING',
   'description': 'input content'}],
 'required': ['text'],
 'parameter_description': '如果调用该工具,你必须使用Json格式 {key: value} 传参,其中key为参数名称'}
```

* **enable**: specifies whether the action is available.

### Custom Action

A simple tool must implement the `run` method, while a toolkit should avoid naming any of its sub-APIs after this reserved word.

```python
class Bold(BaseAction):

    @tool_api
    def run(self, text: str):
        """make text bold

        Args:
            text (str): input text

        Returns:
            str: bold text
        """
        return '**' + text + '**'


class PhraseEmphasis(BaseAction):
    """a toolkit which provides different styles of text emphasis"""

    @tool_api
    def bold(self, text):
        """make text bold

        Args:
            text (str): input text

        Returns:
            str: bold text
        """
        return '**' + text + '**'

    @tool_api
    def italic(self, text):
        """make text italic

        Args:
            text (str): input text

        Returns:
            str: italic text
        """
        return '*' + text + '*'

# Inspect the default tool descriptions
# Bold.__tool_description__, PhraseEmphasis.__tool_description__
```

### Auto-registration

Any subclass of `BaseAction` is registered automatically. You can use `list_tools()` and `get_tool()` to view all tool classes and initialize one by name.

```python
from lagent import list_tools, get_tool

list_tools()
```

```python
['BaseAction',
 'InvalidAction',
 'NoAction',
 'FinishAction',
 'ArxivSearch',
 'BINGMap',
 'GoogleScholar',
 'GoogleSearch',
 'IPythonInterpreter',
 'PPT',
 'PythonInterpreter',
 'Bold',
 'PhraseEmphasis']
```

Create a `PhraseEmphasis` object:

```python
action = get_tool('PhraseEmphasis')
action.description
```

```python
{'name': 'PhraseEmphasis',
 'description': 'a toolkit which provides different styles of text emphasis',
 'api_list': [{'name': 'bold',
   'description': 'make text bold',
   'parameters': [{'name': 'text',
     'type': 'STRING',
     'description': 'input text'}],
   'required': ['text'],
   'parameter_description': '如果调用该工具,你必须使用Json格式 {key: value} 传参,其中key为参数名称'},
  {'name': 'italic',
   'description': 'make text italic',
   'parameters': [{'name': 'text',
     'type': 'STRING',
     'description': 'input text'}],
   'required': ['text'],
   'parameter_description': '如果调用该工具,你必须使用Json格式 {key: value} 传参,其中key为参数名称'}]}
```

## Tool Calling

### Run a Tool

The `__call__` method of an `Action` takes two arguments:

* `inputs`: its type is related to the `BaseParser` bound to the action; it is usually a string generated by the large language model.
  + `JsonParser`: allows passing a JSON-format string or a Python dictionary.
  + `TupleParser`: allows passing a tuple-literal string or a Python tuple.
* `name`: which API to call; defaults to `run`.

The tool returns an `ActionReturn` object that encapsulates the calling details:

* `args`: a dictionary of the action's inputs.
* `type`: the action name.
* `result`: a list of dicts, each with two keys, 'type' and 'content'; when an exception occurs this field is `None`.
* `errmsg`: the error message, `None` by default (see the sketch after the example below).

Here is an example:

```python
from lagent import IPythonInterpreter, TupleParser

action1 = IPythonInterpreter()
ret = action1('{"command": "import math;math.sqrt(100)"}')
print(ret.result)
ret = action1({'command': 'import math;math.sqrt(100)'})
print(ret.result)

action2 = IPythonInterpreter(parser=TupleParser)
ret = action2('("import math;math.sqrt(100)", )')
print(ret.result)
ret = action2(('import math;math.sqrt(100)',))
print(ret.result)
```

```python
[{'type': 'text', 'content': '10.0'}]
[{'type': 'text', 'content': '10.0'}]
[{'type': 'text', 'content': '10.0'}]
[{'type': 'text', 'content': '10.0'}]
```
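When a call fails, for example because the arguments cannot be parsed, `result` is `None` and the error is reported through `errmsg`. A minimal sketch of checking for this, assuming the `action1` instance above:

```python
# Sketch: inspect an ActionReturn on failure (assumes action1 from above).
ret = action1('this is not valid json')
if ret.result is None:
    print(ret.errmsg)  # the parsing/execution error message
```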
### Dynamic Invocation

Lagent provides the `ActionExecutor` interface to manage multiple tools. It flattens the `api_list` of each toolkit and renames every API to `{tool_name}.{api_name}`.

```python
from lagent import ActionExecutor, ArxivSearch, IPythonInterpreter

executor = ActionExecutor(actions=[ArxivSearch(), IPythonInterpreter()])
executor.get_actions_info()  # this result becomes part of the LLM system prompt
```

```python
[{'name': 'ArxivSearch.get_arxiv_article_information',
  'description': 'Run Arxiv search and get the article meta information.',
  'parameters': [{'name': 'query',
    'type': 'STRING',
    'description': 'the content of search query'}],
  'required': ['query'],
  'return_data': [{'name': 'content',
    'description': 'a list of 3 arxiv search papers',
    'type': 'STRING'}],
  'parameter_description': '如果调用该工具,你必须使用Json格式 {key: value} 传参,其中key为参数名称'},
 {'name': 'IPythonInterpreter',
  'description': "When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files. Internet access for this session is disabled. Do not make external web requests or API calls as they will fail.",
  'parameters': [{'name': 'command',
    'type': 'STRING',
    'description': 'Python code'},
   {'name': 'timeout',
    'type': 'NUMBER',
    'description': 'Upper bound of waiting time for Python script execution.'}],
  'required': ['command'],
  'parameter_description': '如果调用该工具,你必须使用Json格式 {key: value} 传参,其中key为参数名称'}]
```

Trigger a tool through the action executor:

```python
ret = executor('IPythonInterpreter', '{"command": "import math;math.sqrt(100)"}')
ret.result
```

```python
[{'type': 'text', 'content': '10.0'}]
```
333
examples/internlm2_agent_web_demo.py
Normal file
333
examples/internlm2_agent_web_demo.py
Normal file
@@ -0,0 +1,333 @@
|
||||
import copy
import hashlib
import json
import os

import streamlit as st

from lagent.actions import ActionExecutor, ArxivSearch, GoogleScholar, IPythonInterpreter
from lagent.agents.internlm2_agent import (INTERPRETER_CN, META_INS, PLUGIN_CN,
                                           Internlm2Agent, Interlm2Protocol)
from lagent.llms.lmdepoly_wrapper import LMDeployClient
from lagent.llms.meta_template import INTERNLM2_META as META
from lagent.schema import AgentStatusCode

# from streamlit.logger import get_logger


class SessionState:

    def init_state(self):
        """Initialize session state variables."""
        st.session_state['assistant'] = []
        st.session_state['user'] = []

        action_list = [
            GoogleScholar(
                api_key=('a558de7dee10146326ca86fbaa0736b'
                         'dd947c9e646cd3f14da5aff177d6b2ff0')),
            ArxivSearch(),
        ]
        st.session_state['plugin_map'] = {
            action.name: action
            for action in action_list
        }
        st.session_state['model_map'] = {}
        st.session_state['model_selected'] = None
        st.session_state['plugin_actions'] = set()
        st.session_state['history'] = []

    def clear_state(self):
        """Clear the existing session state."""
        st.session_state['assistant'] = []
        st.session_state['user'] = []
        st.session_state['model_selected'] = None
        st.session_state['file'] = set()
        if 'chatbot' in st.session_state:
            st.session_state['chatbot']._session_history = []


class StreamlitUI:

    def __init__(self, session_state: SessionState):
        self.init_streamlit()
        self.session_state = session_state

    def init_streamlit(self):
        """Initialize Streamlit's UI settings."""
        st.set_page_config(
            layout='wide',
            page_title='lagent-web',
            page_icon='./docs/imgs/lagent_icon.png')
        st.header(':robot_face: :blue[Lagent] Web Demo ', divider='rainbow')
        st.sidebar.title('模型控制')
        st.session_state['file'] = set()
        st.session_state['ip'] = None

    def setup_sidebar(self):
        """Setup the sidebar for model and plugin selection."""
        model_name = st.sidebar.selectbox('模型选择:', options=['internlm'])
        meta_prompt = st.sidebar.text_area('系统提示词', value=META_INS)
        da_prompt = st.sidebar.text_area('数据分析提示词', value=INTERPRETER_CN)
        plugin_prompt = st.sidebar.text_area('插件提示词', value=PLUGIN_CN)
        model_ip = st.sidebar.text_input('模型IP:', value='10.140.0.220:23333')
        if model_name != st.session_state[
                'model_selected'] or st.session_state['ip'] != model_ip:
            st.session_state['ip'] = model_ip
            model = self.init_model(model_name, model_ip)
            self.session_state.clear_state()
            st.session_state['model_selected'] = model_name
            if 'chatbot' in st.session_state:
                del st.session_state['chatbot']
        else:
            model = st.session_state['model_map'][model_name]

        plugin_name = st.sidebar.multiselect(
            '插件选择',
            options=list(st.session_state['plugin_map'].keys()),
            default=[],
        )
        da_flag = st.sidebar.checkbox(
            '数据分析',
            value=False,
        )
        plugin_action = [
            st.session_state['plugin_map'][name] for name in plugin_name
        ]

        if 'chatbot' in st.session_state:
            if len(plugin_action) > 0:
                st.session_state['chatbot']._action_executor = ActionExecutor(
                    actions=plugin_action)
            else:
                st.session_state['chatbot']._action_executor = None
            if da_flag:
                st.session_state[
                    'chatbot']._interpreter_executor = ActionExecutor(
                        actions=[IPythonInterpreter()])
            else:
                st.session_state['chatbot']._interpreter_executor = None
            st.session_state['chatbot']._protocol._meta_template = meta_prompt
            st.session_state['chatbot']._protocol.plugin_prompt = plugin_prompt
            st.session_state[
                'chatbot']._protocol.interpreter_prompt = da_prompt
        if st.sidebar.button('清空对话', key='clear'):
            self.session_state.clear_state()
        uploaded_file = st.sidebar.file_uploader('上传文件')

        return model_name, model, plugin_action, uploaded_file, model_ip

    def init_model(self, option, ip=None):
        """Initialize the model based on the selected option."""
        model_url = f'http://{ip}'
        st.session_state['model_map'][option] = LMDeployClient(
            path='internlm2-chat-20b',
            url=model_url,
            meta_template=META,
            top_p=0.8,
            top_k=100,
            temperature=0,
            repetition_penalty=1.0,
            stop_words=['<|im_end|>'])
        return st.session_state['model_map'][option]

    def initialize_chatbot(self, model, plugin_action):
        """Initialize the chatbot with the given model and plugin actions."""
        return Internlm2Agent(
            llm=model,
            protocol=Interlm2Protocol(
                tool=dict(
                    begin='{start_token}{name}\n',
                    start_token='<|action_start|>',
                    name_map=dict(
                        plugin='<|plugin|>', interpreter='<|interpreter|>'),
                    belong='assistant',
                    end='<|action_end|>\n',
                ), ),
        )

    def render_user(self, prompt: str):
        with st.chat_message('user'):
            st.markdown(prompt)

    def render_assistant(self, agent_return):
        with st.chat_message('assistant'):
            for action in agent_return.actions:
                if (action) and (action.type != 'FinishAction'):
                    self.render_action(action)
            st.markdown(agent_return.response)

    def render_plugin_args(self, action):
        action_name = action.type
        args = action.args
        import json
        parameter_dict = dict(name=action_name, parameters=args)
        parameter_str = '```json\n' + json.dumps(
            parameter_dict, indent=4, ensure_ascii=False) + '\n```'
        st.markdown(parameter_str)

    def render_interpreter_args(self, action):
        st.info(action.type)
        st.markdown(action.args['text'])

    def render_action(self, action):
        st.markdown(action.thought)
        if action.type == 'IPythonInterpreter':
            self.render_interpreter_args(action)
        elif action.type == 'FinishAction':
            pass
        else:
            self.render_plugin_args(action)
        self.render_action_results(action)

    def render_action_results(self, action):
        """Render the results of action, including text, images, videos, and
        audios."""
        if (isinstance(action.result, dict)):
            if 'text' in action.result:
                st.markdown('```\n' + action.result['text'] + '\n```')
            if 'image' in action.result:
                # image_path = action.result['image']
                for image_path in action.result['image']:
                    image_data = open(image_path, 'rb').read()
                    st.image(image_data, caption='Generated Image')
            if 'video' in action.result:
                video_data = action.result['video']
                video_data = open(video_data, 'rb').read()
                st.video(video_data)
            if 'audio' in action.result:
                audio_data = action.result['audio']
                audio_data = open(audio_data, 'rb').read()
                st.audio(audio_data)
        elif isinstance(action.result, list):
            for item in action.result:
                if item['type'] == 'text':
                    st.markdown('```\n' + item['content'] + '\n```')
                elif item['type'] == 'image':
                    image_data = open(item['content'], 'rb').read()
                    st.image(image_data, caption='Generated Image')
                elif item['type'] == 'video':
                    video_data = open(item['content'], 'rb').read()
                    st.video(video_data)
                elif item['type'] == 'audio':
                    audio_data = open(item['content'], 'rb').read()
                    st.audio(audio_data)
        if action.errmsg:
            st.error(action.errmsg)


def main():
    # logger = get_logger(__name__)
    # Initialize Streamlit UI and setup sidebar
    if 'ui' not in st.session_state:
        session_state = SessionState()
        session_state.init_state()
        st.session_state['ui'] = StreamlitUI(session_state)

    else:
        st.set_page_config(
            layout='wide',
            page_title='lagent-web',
            page_icon='./docs/imgs/lagent_icon.png')
        st.header(':robot_face: :blue[Lagent] Web Demo ', divider='rainbow')
    _, model, plugin_action, uploaded_file, _ = st.session_state[
        'ui'].setup_sidebar()

    # Initialize chatbot if it is not already initialized
    # or if the model has changed
    if 'chatbot' not in st.session_state or model != st.session_state[
            'chatbot']._llm:
        st.session_state['chatbot'] = st.session_state[
            'ui'].initialize_chatbot(model, plugin_action)
        st.session_state['session_history'] = []

    for prompt, agent_return in zip(st.session_state['user'],
                                    st.session_state['assistant']):
        st.session_state['ui'].render_user(prompt)
        st.session_state['ui'].render_assistant(agent_return)

    if user_input := st.chat_input(''):
        with st.container():
            st.session_state['ui'].render_user(user_input)
        st.session_state['user'].append(user_input)
        # Add file uploader to sidebar
        if (uploaded_file
                and uploaded_file.name not in st.session_state['file']):

            st.session_state['file'].add(uploaded_file.name)
            file_bytes = uploaded_file.read()
            file_type = uploaded_file.type
            if 'image' in file_type:
                st.image(file_bytes, caption='Uploaded Image')
            elif 'video' in file_type:
                # st.video does not accept a caption argument
                st.video(file_bytes)
            elif 'audio' in file_type:
                # st.audio does not accept a caption argument
                st.audio(file_bytes)
            # Save the file to a temporary location and get the path

            postfix = uploaded_file.name.split('.')[-1]
            # prefix = str(uuid.uuid4())
            prefix = hashlib.md5(file_bytes).hexdigest()
            filename = f'{prefix}.{postfix}'
            file_path = os.path.join(root_dir, filename)
            with open(file_path, 'wb') as tmpfile:
                tmpfile.write(file_bytes)
            file_size = os.stat(file_path).st_size / 1024 / 1024
            file_size = f'{round(file_size, 2)} MB'
            # st.write(f'File saved at: {file_path}')
            user_input = [
                dict(role='user', content=user_input),
                dict(
                    role='user',
                    content=json.dumps(dict(path=file_path, size=file_size)),
                    name='file')
            ]
        if isinstance(user_input, str):
            user_input = [dict(role='user', content=user_input)]
        st.session_state['last_status'] = AgentStatusCode.SESSION_READY
        for agent_return in st.session_state['chatbot'].stream_chat(
                st.session_state['session_history'] + user_input):
            if agent_return.state == AgentStatusCode.PLUGIN_RETURN:
                with st.container():
                    st.session_state['ui'].render_plugin_args(
                        agent_return.actions[-1])
                    st.session_state['ui'].render_action_results(
                        agent_return.actions[-1])
            elif agent_return.state == AgentStatusCode.CODE_RETURN:
                with st.container():
                    st.session_state['ui'].render_action_results(
                        agent_return.actions[-1])
            elif (agent_return.state == AgentStatusCode.STREAM_ING
                  or agent_return.state == AgentStatusCode.CODING):
                # st.markdown(agent_return.response)
                # Clear the placeholder's current content and show the new one
                with st.container():
                    if agent_return.state != st.session_state['last_status']:
                        st.session_state['temp'] = ''
                        placeholder = st.empty()
                        st.session_state['placeholder'] = placeholder
                    if isinstance(agent_return.response, dict):
                        action = f"\n\n {agent_return.response['name']}: \n\n"
                        action_input = agent_return.response['parameters']
                        if agent_return.response['name'] == 'IPythonInterpreter':
                            action_input = action_input['command']
                        response = action + action_input
                    else:
                        response = agent_return.response
                    st.session_state['temp'] = response
                    st.session_state['placeholder'].markdown(
                        st.session_state['temp'])
            elif agent_return.state == AgentStatusCode.END:
                st.session_state['session_history'] += (user_input + agent_return.inner_steps)
                agent_return = copy.deepcopy(agent_return)
                agent_return.response = st.session_state['temp']
                st.session_state['assistant'].append(
                    copy.deepcopy(agent_return))
            st.session_state['last_status'] = agent_return.state


if __name__ == '__main__':
    root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    root_dir = os.path.join(root_dir, 'tmp_dir')
    os.makedirs(root_dir, exist_ok=True)
    main()
@@ -1,11 +1,61 @@
from typing import Type

from .action_executor import ActionExecutor
from .base_action import BaseAction
from .arxiv_search import ArxivSearch
from .base_action import TOOL_REGISTRY, BaseAction, tool_api
from .bing_map import BINGMap
from .builtin_actions import FinishAction, InvalidAction, NoAction
from .google_scholar_search import GoogleScholar
from .google_search import GoogleSearch
from .llm_qa import LLMQA
from .ipython_interpreter import IPythonInterpreter
from .parser import BaseParser, JsonParser, TupleParser
from .ppt import PPT
from .python_interpreter import PythonInterpreter

__all__ = [
    'BaseAction', 'ActionExecutor', 'InvalidAction', 'NoAction',
    'FinishAction', 'GoogleSearch', 'PythonInterpreter', 'LLMQA'
    'BaseAction', 'ActionExecutor', 'InvalidAction', 'FinishAction',
    'NoAction', 'BINGMap', 'ArxivSearch', 'FinishAction', 'GoogleSearch',
    'GoogleScholar', 'IPythonInterpreter', 'PythonInterpreter', 'PPT',
    'BaseParser', 'JsonParser', 'TupleParser', 'tool_api', 'list_tools',
    'get_tool_cls', 'get_tool'
]


def list_tools(with_class: bool = False):
    """List available tools

    Args:
        with_class (bool): whether to return the action class along
            with its name. Defaults to ``False``.

    Returns:
        list: all action names
    """
    return list(TOOL_REGISTRY.items()) if with_class else list(
        TOOL_REGISTRY.keys())


def get_tool_cls(specifier: str) -> Type[BaseAction]:
    """Get the action class

    Args:
        specifier (:class:`str`): tool name

    Returns:
        Type[BaseAction]: action class
    """
    return TOOL_REGISTRY.get_class(specifier)


def get_tool(specifier: str, *args, **kwargs) -> BaseAction:
    """Instantiate an action

    Args:
        specifier (str): tool name
        args: positional arguments passed to the action's ``__init__`` method
        kwargs: keyword arguments passed to the action's ``__init__`` method

    Returns:
        :class:`BaseAction`: action object
    """
    return TOOL_REGISTRY.get(specifier, *args, **kwargs)
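The registry helpers above can be exercised directly. A hedged usage sketch (the available names depend on which actions have been imported and registered):

```python
from lagent.actions import get_tool, list_tools

print(list_tools())  # names collected in TOOL_REGISTRY, e.g. ['ArxivSearch', ...]
interpreter = get_tool('IPythonInterpreter', timeout=40)  # extra args go to __init__
```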
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Union
from typing import Dict, List, Union

from lagent.schema import ActionReturn, ActionValidCode
from .base_action import BaseAction
@@ -39,14 +39,20 @@ class ActionExecutor:
        self.no_action = no_action
        self.finish_action = finish_action

    def get_actions_info(self, only_enable: bool = True) -> Dict:
        if only_enable:
            return {
                k: v.description
                for k, v in self.actions.items() if v.enable
            }
        else:
            return {k: v.description for k, v in self.actions.items()}
    def get_actions_info(self) -> List[Dict]:
        actions = []
        for action_name, action in self.actions.items():
            if not action.enable:
                continue
            if action.is_toolkit:
                for api in action.description['api_list']:
                    api_desc = api.copy()
                    api_desc['name'] = f"{action_name}.{api_desc['name']}"
                    actions.append(api_desc)
            else:
                action_desc = action.description.copy()
                actions.append(action_desc)
        return actions

    def is_valid(self, name: str):
        return name in self.actions and self.actions[name].enable
@@ -66,19 +72,17 @@ class ActionExecutor:
        if name in self.actions:
            del self.actions[name]

    def __call__(self, name: str, command: Any) -> ActionReturn:
        if isinstance(command, str):
            args, kwargs = (command, ), {}
        else:
            args, kwargs = (), command
        if not self.is_valid(name):
    def __call__(self, name: str, command: str) -> ActionReturn:
        action_name, api_name = (
            name.split('.') if '.' in name else (name, 'run'))
        if not self.is_valid(action_name):
            if name == self.no_action.name:
                action_return = self.no_action.run(*args, **kwargs)
                action_return = self.no_action(command)
            elif name == self.finish_action.name:
                action_return = self.finish_action.run(*args, **kwargs)
                action_return = self.finish_action(command)
            else:
                action_return = self.invalid_action(*args, **kwargs)
                action_return = self.invalid_action(command)
        else:
            action_return = self.actions[name].run(*args, **kwargs)
            action_return = self.actions[action_name](command, api_name)
        action_return.valid = ActionValidCode.OPEN
        return action_return
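The dispatch rule in the rewritten `__call__` is easy to trace by hand; a small sketch of the two cases:

```python
# Flattened toolkit API: routed to the named method of the toolkit.
name = 'ArxivSearch.get_arxiv_article_information'
print(name.split('.') if '.' in name else (name, 'run'))
# -> ['ArxivSearch', 'get_arxiv_article_information']

# Bare tool name: falls back to the tool's default 'run' API.
name = 'IPythonInterpreter'
print(name.split('.') if '.' in name else (name, 'run'))
# -> ('IPythonInterpreter', 'run')
```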
56
lagent/actions/arxiv_search.py
Normal file
@@ -0,0 +1,56 @@
from typing import Optional, Type

import arxiv

from lagent.actions.base_action import BaseAction, tool_api
from lagent.actions.parser import BaseParser, JsonParser
from lagent.schema import ActionReturn, ActionStatusCode


class ArxivSearch(BaseAction):
    """Search information from Arxiv.org. \
Useful for when you need to answer questions about Physics, Mathematics, \
Computer Science, Quantitative Biology, Quantitative Finance, Statistics, \
Electrical Engineering, and Economics from scientific articles on arxiv.org.
    """

    def __init__(self,
                 top_k_results: int = 3,
                 max_query_len: int = 300,
                 doc_content_chars_max: int = 1500,
                 description: Optional[dict] = None,
                 parser: Type[BaseParser] = JsonParser,
                 enable: bool = True):
        super().__init__(description, parser, enable)
        self.top_k_results = top_k_results
        self.max_query_len = max_query_len
        self.doc_content_chars_max = doc_content_chars_max

    @tool_api(explode_return=True)
    def get_arxiv_article_information(self, query: str) -> dict:
        """Run Arxiv search and get the article meta information.

        Args:
            query (:class:`str`): the content of search query

        Returns:
            :class:`dict`: article information
                * content (str): a list of 3 arxiv search papers
        """
        try:
            results = arxiv.Search(  # type: ignore
                query[:self.max_query_len],
                max_results=self.top_k_results).results()
        except Exception as exc:
            return ActionReturn(
                errmsg=f'Arxiv exception: {exc}',
                state=ActionStatusCode.HTTP_ERROR)
        docs = [
            f'Published: {result.updated.date()}\nTitle: {result.title}\n'
            f'Authors: {", ".join(a.name for a in result.authors)}\n'
            f'Summary: {result.summary[:self.doc_content_chars_max]}'
            for result in results
        ]
        if docs:
            return {'content': '\n\n'.join(docs)}
        return {'content': 'No good Arxiv Result was found'}
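A hedged sketch of driving the new action directly through the JSON call interface of `BaseAction.__call__` (output abridged; requires the `arxiv` package and network access):

```python
from lagent.actions import ArxivSearch

action = ArxivSearch(top_k_results=1)
ret = action('{"query": "InternLM"}', 'get_arxiv_article_information')
print(ret.result)  # e.g. [{'type': 'text', 'content': 'Published: ...'}]
```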
@@ -1,57 +1,364 @@
from typing import Optional
import inspect
import logging
import re
from abc import ABCMeta
from copy import deepcopy
from functools import wraps
from typing import Annotated, Callable, Optional, Type, get_args, get_origin

from lagent.schema import ActionReturn
from class_registry import AutoRegister, ClassRegistry
from griffe import Docstring
from griffe.enumerations import DocstringSectionKind

from ..schema import ActionReturn, ActionStatusCode
from .parser import BaseParser, JsonParser, ParseError

logging.getLogger('griffe').setLevel(logging.ERROR)

TOOL_REGISTRY = ClassRegistry('__tool_name__', unique=True)


class BaseAction:
def tool_api(func: Optional[Callable] = None,
             *,
             explode_return: bool = False,
             returns_named_value: bool = False,
             **kwargs):
    """Turn functions into tools. It will parse typehints as well as docstrings
    to build the tool description and attach it to functions via an attribute
    ``api_description``.

    Examples:

        .. code-block:: python

            # typehints have higher priority than docstrings
            from typing import Annotated

            @tool_api
            def add(a: Annotated[int, 'augend'], b: Annotated[int, 'addend'] = 1):
                '''Add operation

                Args:
                    x (int): a
                    y (int): b
                '''
                return a + b

            print(add.api_description)

    Args:
        func (Optional[Callable]): function to decorate. Defaults to ``None``.
        explode_return (bool): whether to flatten the dictionary or tuple return
            as the ``return_data`` field. When enabled, it is recommended to
            annotate the member in docstrings. Defaults to ``False``.

            .. code-block:: python

                @tool_api(explode_return=True)
                def foo(a, b):
                    '''A simple function

                    Args:
                        a (int): a
                        b (int): b

                    Returns:
                        dict: information of inputs
                            * x: value of a
                            * y: value of b
                    '''
                    return {'x': a, 'y': b}

                print(foo.api_description)

        returns_named_value (bool): whether to parse ``thing: Description`` in
            returns sections as a name and description, rather than a type and
            description. When true, type must be wrapped in parentheses:
            ``(int): Description``. When false, parentheses are optional but
            the items cannot be named: ``int: Description``. Defaults to ``False``.

    Returns:
        Callable: wrapped function or partial decorator

    Important:
        ``return_data`` field will be added to ``api_description`` only
        when ``explode_return`` or ``returns_named_value`` is enabled.
    """

    def _detect_type(string):
        field_type = 'STRING'
        if 'list' in string:
            field_type = 'Array'
        elif 'str' not in string:
            if 'float' in string:
                field_type = 'FLOAT'
            elif 'int' in string:
                field_type = 'NUMBER'
            elif 'bool' in string:
                field_type = 'BOOLEAN'
        return field_type

    def _explode(desc):
        kvs = []
        desc = '\nArgs:\n' + '\n'.join([
            '    ' + item.lstrip(' -+*#.')
            for item in desc.split('\n')[1:] if item.strip()
        ])
        docs = Docstring(desc).parse('google')
        if not docs:
            return kvs
        if docs[0].kind is DocstringSectionKind.parameters:
            for d in docs[0].value:
                d = d.as_dict()
                if not d['annotation']:
                    d.pop('annotation')
                else:
                    d['type'] = _detect_type(d.pop('annotation').lower())
                kvs.append(d)
        return kvs

    def _parse_tool(function):
        # remove rst syntax
        docs = Docstring(
            re.sub(':(.+?):`(.+?)`', '\\2', function.__doc__ or '')).parse(
                'google', returns_named_value=returns_named_value, **kwargs)
        desc = dict(
            name=function.__name__,
            description=docs[0].value
            if docs[0].kind is DocstringSectionKind.text else '',
            parameters=[],
            required=[],
        )
        args_doc, returns_doc = {}, []
        for doc in docs:
            if doc.kind is DocstringSectionKind.parameters:
                for d in doc.value:
                    d = d.as_dict()
                    d['type'] = _detect_type(d.pop('annotation').lower())
                    args_doc[d['name']] = d
            if doc.kind is DocstringSectionKind.returns:
                for d in doc.value:
                    d = d.as_dict()
                    if not d['name']:
                        d.pop('name')
                    if not d['annotation']:
                        d.pop('annotation')
                    else:
                        d['type'] = _detect_type(d.pop('annotation').lower())
                    returns_doc.append(d)

        sig = inspect.signature(function)
        for name, param in sig.parameters.items():
            if name == 'self':
                continue
            parameter = dict(
                name=param.name,
                type='STRING',
                description=args_doc.get(param.name,
                                         {}).get('description', ''))
            annotation = param.annotation
            if annotation is inspect.Signature.empty:
                parameter['type'] = args_doc.get(param.name,
                                                 {}).get('type', 'STRING')
            else:
                if get_origin(annotation) is Annotated:
                    annotation, info = get_args(annotation)
                    if info:
                        parameter['description'] = info
                while get_origin(annotation):
                    annotation = get_args(annotation)
                parameter['type'] = _detect_type(str(annotation))
            desc['parameters'].append(parameter)
            if param.default is inspect.Signature.empty:
                desc['required'].append(param.name)

        return_data = []
        if explode_return:
            return_data = _explode(returns_doc[0]['description'])
        elif returns_named_value:
            return_data = returns_doc
        if return_data:
            desc['return_data'] = return_data
        return desc

    if callable(func):

        @wraps(func)
        def wrapper(self, *args, **kwargs):
            return func(self, *args, **kwargs)

        wrapper.api_description = _parse_tool(func)
        return wrapper

    def decorate(func):

        @wraps(func)
        def wrapper(self, *args, **kwargs):
            return func(self, *args, **kwargs)

        wrapper.api_description = _parse_tool(func)
        return wrapper

    return decorate
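# Illustrative sketch of the structure `tool_api` attaches, inferred from
# `_parse_tool` above (values approximate). Given
#
#     @tool_api
#     def add(self, a: Annotated[int, 'augend'], b: Annotated[int, 'addend'] = 1):
#         '''Add operation'''
#         return a + b
#
# `add.api_description` is roughly:
#
#     {'name': 'add',
#      'description': 'Add operation',
#      'parameters': [{'name': 'a', 'type': 'NUMBER', 'description': 'augend'},
#                     {'name': 'b', 'type': 'NUMBER', 'description': 'addend'}],
#      'required': ['a']}  # `b` has a default, so it is not required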

class ToolMeta(ABCMeta):
    """Metaclass of tools"""

    def __new__(mcs, name, base, attrs):
        is_toolkit, tool_desc = True, dict(
            name=attrs.setdefault('__tool_name__', name),
            description=Docstring(attrs.get('__doc__',
                                            '')).parse('google')[0].value)
        for key, value in attrs.items():
            if callable(value) and hasattr(value, 'api_description'):
                api_desc = getattr(value, 'api_description')
                if key == 'run':
                    tool_desc['parameters'] = api_desc['parameters']
                    tool_desc['required'] = api_desc['required']
                    if api_desc['description']:
                        tool_desc['description'] = api_desc['description']
                    if api_desc.get('return_data'):
                        tool_desc['return_data'] = api_desc['return_data']
                    is_toolkit = False
                    break
                tool_desc.setdefault('api_list', []).append(api_desc)
        attrs['_is_toolkit'] = is_toolkit
        attrs['__tool_description__'] = tool_desc
        return super().__new__(mcs, name, base, attrs)


class BaseAction(metaclass=AutoRegister(TOOL_REGISTRY, ToolMeta)):
    """Base class for all actions.

    Args:
        description (str, optional): The description of the action. Defaults to
            None.
        name (str, optional): The name of the action. If None, the name will
            be class name. Defaults to None.
        enable (bool, optional): Whether the action is enabled. Defaults to
            True.
        disable_description (str, optional): The description of the action when
            it is disabled. Defaults to None.
        description (:class:`Optional[dict]`): The description of the action.
            Defaults to ``None``.
        parser (:class:`Type[BaseParser]`): The parser class to process the
            action's inputs and outputs. Defaults to :class:`JsonParser`.
        enable (:class:`bool`): Whether the action is enabled. Defaults to
            ``True``.

    Examples:

        * simple tool

        .. code-block:: python

            class Bold(BaseAction):
                '''Make text bold'''

                @tool_api
                def run(self, text: str):
                    '''
                    Args:
                        text (str): input text

                    Returns:
                        str: bold text
                    '''
                    return '**' + text + '**'

            action = Bold()

        * toolkit with multiple APIs

        .. code-block:: python

            class Calculator(BaseAction):
                '''Calculator'''

                @tool_api
                def add(self, a, b):
                    '''Add operation

                    Args:
                        a (int): augend
                        b (int): addend

                    Returns:
                        int: sum
                    '''
                    return a + b

                @tool_api
                def sub(self, a, b):
                    '''Subtraction operation

                    Args:
                        a (int): minuend
                        b (int): subtrahend

                    Returns:
                        int: difference
                    '''
                    return a - b

            action = Calculator()
    """

    def __init__(self,
                 description: Optional[str] = None,
                 name: Optional[str] = None,
                 enable: bool = True,
                 disable_description: Optional[str] = None) -> None:
        if name is None:
            name = self.__class__.__name__
        self._name = name
        self._description = description
        self._disable_description = disable_description
                 description: Optional[dict] = None,
                 parser: Type[BaseParser] = JsonParser,
                 enable: bool = True):
        self._description = deepcopy(description or self.__tool_description__)
        self._name = self._description['name']
        self._parser = parser(self)
        self._enable = enable

    def __call__(self, *args, **kwargs) -> ActionReturn:
        raise NotImplementedError

    def __repr__(self):
        return f'{self.name}:{self.description}'

    def __str__(self):
        return self.__repr__()

    def run(self, *args, **kwargs) -> ActionReturn:
        return self.__call__(*args, **kwargs)

    @property
    def enable(self):
        return self._enable
    def __call__(self, inputs: str, name='run') -> ActionReturn:
        fallback_args = {'inputs': inputs, 'name': name}
        if not hasattr(self, name):
            return ActionReturn(
                fallback_args,
                type=self.name,
                errmsg=f'invalid API: {name}',
                state=ActionStatusCode.API_ERROR)
        try:
            inputs = self._parser.parse_inputs(inputs, name)
        except ParseError as exc:
            return ActionReturn(
                fallback_args,
                type=self.name,
                errmsg=exc.err_msg,
                state=ActionStatusCode.ARGS_ERROR)
        try:
            outputs = getattr(self, name)(**inputs)
        except Exception as exc:
            return ActionReturn(
                inputs,
                type=self.name,
                errmsg=str(exc),
                state=ActionStatusCode.API_ERROR)
        if isinstance(outputs, ActionReturn):
            action_return = outputs
            if not action_return.args:
                action_return.args = inputs
            if not action_return.type:
                action_return.type = self.name
        else:
            result = self._parser.parse_outputs(outputs)
            action_return = ActionReturn(inputs, type=self.name, result=result)
        return action_return

    @property
    def name(self):
        return self._name

    @property
    def description(self):
        if self.enable:
            return self._description
        else:
            return self._disable_description
    def enable(self):
        return self._enable

    @property
    def is_toolkit(self):
        return self._is_toolkit

    @property
    def description(self) -> dict:
        """Description of the tool"""
        return self._description

    def __repr__(self):
        return f'{self.description}'

    __str__ = __repr__
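Putting the pieces together, a hedged end-to-end sketch of the new `BaseAction.__call__` flow (JSON string in, typed result list out; the result format matches what `JsonParser` is shown producing earlier in the docs):

```python
class Bold(BaseAction):
    '''Make text bold'''

    @tool_api
    def run(self, text: str):
        '''
        Args:
            text (str): input text
        '''
        return '**' + text + '**'


ret = Bold()('{"text": "hello"}')
print(ret.result)  # [{'type': 'text', 'content': '**hello**'}]
```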
145
lagent/actions/bing_map.py
Normal file
@@ -0,0 +1,145 @@
import json
import os
from typing import Optional, Type

import requests

from lagent.actions.base_action import BaseAction, tool_api
from lagent.actions.parser import BaseParser, JsonParser


class BINGMap(BaseAction):
    """BING Map plugin for looking up map information"""

    def __init__(self,
                 key: Optional[str] = None,
                 description: Optional[dict] = None,
                 parser: Type[BaseParser] = JsonParser,
                 enable: bool = True) -> None:
        super().__init__(description, parser, enable)
        key = os.environ.get('BING_MAP_KEY', key)
        if key is None:
            raise ValueError(
                'Please set BING Map API key either in the environment '
                'as BING_MAP_KEY or pass it as `key` parameter.')
        self.key = key
        self.base_url = 'http://dev.virtualearth.net/REST/V1/'

    @tool_api(explode_return=True)
    def get_distance(self, start: str, end: str) -> dict:
        """Get the distance between two locations in km.

        Args:
            start (:class:`str`): The start location
            end (:class:`str`): The end location

        Returns:
            :class:`dict`: distance information
                * distance (str): the distance in km.
        """
        # Request URL
        url = self.base_url + 'Routes/Driving?o=json&wp.0=' + start + '&wp.1=' + end + '&key=' + self.key
        # GET request
        r = requests.get(url)
        # TODO check request status?
        data = json.loads(r.text)
        # Extract route information
        route = data['resourceSets'][0]['resources'][0]
        # Extract distance in km
        distance = route['travelDistance']
        return dict(distance=distance)

    @tool_api(explode_return=True)
    def get_route(self, start: str, end: str) -> dict:
        """Get the route between two locations in km.

        Args:
            start (:class:`str`): The start location
            end (:class:`str`): The end location

        Returns:
            :class:`dict`: route information
                * route (list): the route, a list of actions.
        """
        # Request URL
        url = self.base_url + 'Routes/Driving?o=json&wp.0=' + start + '&wp.1=' + end + '&key=' + self.key
        # GET request
        r = requests.get(url)
        data = json.loads(r.text)
        # Extract route information
        route = data['resourceSets'][0]['resources'][0]
        itinerary = route['routeLegs'][0]['itineraryItems']
        # Extract route text information
        route_text = []
        for item in itinerary:
            if 'instruction' in item:
                route_text.append(item['instruction']['text'])
        return dict(route=route_text)

    @tool_api(explode_return=True)
    def get_coordinates(self, location: str) -> dict:
        """Get the coordinates of a location.

        Args:
            location (:class:`str`): the location that needs coordinates.

        Returns:
            :class:`dict`: coordinates information
                * latitude (float): the latitude of the location.
                * longitude (float): the longitude of the location.
        """
        url = self.base_url + 'Locations'
        params = {'query': location, 'key': self.key}
        response = requests.get(url, params=params)
        json_data = response.json()
        coordinates = json_data['resourceSets'][0]['resources'][0]['point'][
            'coordinates']
        return dict(latitude=coordinates[0], longitude=coordinates[1])

    @tool_api(explode_return=True)
    def search_nearby(self,
                      search_term: str,
                      places: str = 'unknown',
                      latitude: float = 0.0,
                      longitude: float = 0.0,
                      radius: int = 5000) -> dict:  # radius in meters
        """Search for places nearby a location, within a given radius, and \
return the results into a list. You can use either the places name or the \
latitude and longitude.

        Args:
            search_term (:class:`str`): the place name.
            places (:class:`str`): the name of the location. Defaults to ``'unknown'``.
            latitude (:class:`float`): the latitude of the location. Defaults to ``0.0``.
            longitude (:class:`float`): the longitude of the location. Defaults to ``0.0``.
            radius (:class:`int`): radius in meters. Defaults to ``5000``.

        Returns:
            :class:`dict`: places information
                * places (list): the list of places, each place is a dict with name and address, at most 5 places.
        """
        url = self.base_url + 'LocalSearch'
        if places != 'unknown':
            pos = self.get_coordinates(**{'location': places})
            # get_coordinates returns a plain dict, so index it by key
            latitude, longitude = pos['latitude'], pos['longitude']
        # Build the request query string
        params = {
            'query': search_term,
            'userLocation': f'{latitude},{longitude}',
            'radius': radius,
            'key': self.key
        }
        # Make the request
        response = requests.get(url, params=params)
        # Parse the response
        response_data = json.loads(response.content)
        # Get the results
        results = response_data['resourceSets'][0]['resources']
        addresses = []
        for result in results:
            name = result['name']
            address = result['Address']['formattedAddress']
            addresses.append(dict(name=name, address=address))
            if len(addresses) == 5:
                break
        return dict(place=addresses)
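A hedged sketch of one of these APIs through the call interface (requires a valid `BING_MAP_KEY`; the key below is a placeholder):

```python
from lagent.actions import BINGMap

bing_map = BINGMap(key='YOUR_BING_MAP_KEY')
ret = bing_map('{"location": "Shanghai"}', 'get_coordinates')
print(ret.result)
```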
@@ -1,6 +1,7 @@
from typing import Optional

from lagent.actions.base_action import BaseAction
from lagent.actions.base_action import BaseAction, tool_api
from lagent.actions.parser import BaseParser
from lagent.schema import ActionReturn, ActionStatusCode, ActionValidCode


@@ -19,12 +20,13 @@ class InvalidAction(BaseAction):
    def __init__(self,
                 err_msg:
                 str = 'The action is invalid, please check the action name.',
                 **kwargs) -> None:

        super().__init__(enable=False, **kwargs)
                 description: Optional[dict] = None,
                 parser=BaseParser) -> None:
        super().__init__(description, parser, enable=False)
        self._err_msg = err_msg

    def __call__(self, err_msg: Optional[str] = None):
    @tool_api
    def run(self, err_msg: Optional[str] = None) -> ActionReturn:
        """Return the error message.

        Args:
@@ -35,7 +37,7 @@ class InvalidAction(BaseAction):
        action_return = ActionReturn(
            url=None,
            args=dict(text=err_msg),
            errmsg=err_msg if err_msg else self._err_msg,
            errmsg=err_msg or self._err_msg,
            type=self.name,
            valid=ActionValidCode.INVALID,
            state=ActionStatusCode.API_ERROR)
@@ -51,12 +53,15 @@ class NoAction(BaseAction):
        'Please follow the format'.
    """

    def __init__(self, err_msg: str = 'Please follow the format', **kwargs):

        super().__init__(enable=False, **kwargs)
    def __init__(self,
                 err_msg: str = 'Please follow the format',
                 description: Optional[dict] = None,
                 parser=BaseParser):
        super().__init__(description, parser, enable=False)
        self._err_msg = err_msg

    def __call__(self, err_msg: Optional[str] = None):
    @tool_api
    def run(self, err_msg: Optional[str] = None) -> ActionReturn:
        """Return the error message.

        Args:
@@ -71,7 +76,7 @@ class NoAction(BaseAction):
            url=None,
            args=dict(text=err_msg),
            type=self.name,
            errmsg=err_msg if err_msg else self._err_msg,
            errmsg=err_msg or self._err_msg,
            valid=ActionValidCode.INVALID,
            state=ActionStatusCode.API_ERROR)
        return action_return
@@ -81,7 +86,11 @@ class FinishAction(BaseAction):
    """This is a finish action class, which is used to return the final
    result."""

    def __call__(self, response: str) -> ActionReturn:
    def __init__(self, description: Optional[dict] = None, parser=BaseParser):
        super().__init__(description, parser, enable=True)

    @tool_api
    def run(self, response: str) -> ActionReturn:
        """Return the final result.

        Args:
@@ -93,7 +102,7 @@ class FinishAction(BaseAction):
        action_return = ActionReturn(
            url=None,
            args=dict(text=response),
            result=dict(text=response),
            result=[dict(type='text', content=response)],
            type=self.name,
            valid=ActionValidCode.FINISH,
            state=ActionStatusCode.SUCCESS)
267
lagent/actions/google_scholar_search.py
Normal file
@@ -0,0 +1,267 @@
import os
from typing import Optional, Type

from serpapi import GoogleSearch

from lagent.actions.base_action import BaseAction, tool_api
from lagent.schema import ActionReturn, ActionStatusCode
from .parser import BaseParser, JsonParser


class GoogleScholar(BaseAction):
    """Plugin for Google Scholar search.

    Args:
        api_key (str): API key for the SerpApi Google Scholar engine
            (the ``serpapi`` package is used under the hood). You can create
            a free API key at https://serpapi.com.
        description (dict): The description of the action. Defaults to ``None``.
        parser (Type[BaseParser]): The parser class to process the
            action's inputs and outputs. Defaults to :class:`JsonParser`.
        enable (bool, optional): Whether the action is enabled. Defaults to
            True.
    """

    def __init__(self,
                 api_key: Optional[str] = None,
                 description: Optional[dict] = None,
                 parser: Type[BaseParser] = JsonParser,
                 enable: bool = True):
        super().__init__(description, parser, enable)
        api_key = os.environ.get('SERPER_API_KEY', api_key)
        if api_key is None:
            raise ValueError(
                'Please set Serper API key either in the environment '
                'as SERPER_API_KEY or pass it as `api_key` parameter.')
        self.api_key = api_key

    @tool_api(explode_return=True)
    def search_google_scholar(
        self,
        query: str,
        cites: Optional[str] = None,
        as_ylo: Optional[int] = None,
        as_yhi: Optional[int] = None,
        scisbd: Optional[int] = None,
        cluster: Optional[str] = None,
        hl: Optional[str] = None,
        lr: Optional[str] = None,
        start: Optional[int] = None,
        num: Optional[int] = None,
        as_sdt: Optional[str] = None,
        safe: Optional[str] = None,
        filter: Optional[str] = None,
        as_vis: Optional[str] = None,
    ) -> dict:
        """Search for scholarly articles based on a query according to Google Scholar.

        Args:
            query (str): The query to search for.
            cites (Optional[str]): The unique ID of an article for triggering "Cited By" searches.
            as_ylo (Optional[int]): The starting year for results (e.g., if as_ylo=2018, results before this year will be omitted).
            as_yhi (Optional[int]): The ending year for results (e.g., if as_yhi=2018, results after this year will be omitted).
            scisbd (Optional[int]): Defines articles added in the last year, sorted by date. It can be set to 1 to include only abstracts, or 2 to include everything.
            cluster (Optional[str]): The unique ID of an article for triggering "All Versions" searches.
            hl (Optional[str]): The language to use for the Google Scholar search.
            lr (Optional[str]): One or multiple languages to limit the search to.
            start (Optional[int]): The result offset for pagination (0 is the first page of results, 10 is the 2nd page, etc.)
            num (Optional[int]): The maximum number of results to return, limited to 20.
            as_sdt (Optional[str]): Can be used either as a search type or a filter.
            safe (Optional[str]): The level of filtering for adult content.
            filter (Optional[str]): Defines if the filters for 'Similar Results' and 'Omitted Results' are on or off.
            as_vis (Optional[str]): Defines whether to include citations or not.

        Returns:
            :class:`dict`: article information
                - title: a list of the titles of the three selected papers
                - cited_by: a list of the citation numbers of the three selected papers
                - organic_id: a list of the organic results' ids of the three selected papers
                - pub_info: publication information of selected papers
        """
        params = {
            'q': query,
            'engine': 'google_scholar',
            'api_key': self.api_key,
            'cites': cites,
            'as_ylo': as_ylo,
            'as_yhi': as_yhi,
            'scisbd': scisbd,
            'cluster': cluster,
            'hl': hl,
            'lr': lr,
            'start': start,
            'num': num,
            'as_sdt': as_sdt,
            'safe': safe,
            'filter': filter,
            'as_vis': as_vis
        }
        search = GoogleSearch(params)
        try:
            r = search.get_dict()
            results = r['organic_results']
            title = []
            snippets = []
            cited_by = []
            organic_id = []
            pub_info = []
            for item in results[:3]:
                title.append(item['title'])
                pub_info.append(item['publication_info']['summary'])
                citation = item['inline_links'].get('cited_by', {'total': ''})
                cited_by.append(citation['total'])
                snippets.append(item['snippet'])
                organic_id.append(item['result_id'])
            return dict(
                title=title,
                cited_by=cited_by,
                organic_id=organic_id,
                snippets=snippets)
        except Exception as e:
            return ActionReturn(
                errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)

    @tool_api(explode_return=True)
    def get_author_information(self,
                               author_id: str,
                               hl: Optional[str] = None,
                               view_op: Optional[str] = None,
                               sort: Optional[str] = None,
                               citation_id: Optional[str] = None,
                               start: Optional[int] = None,
                               num: Optional[int] = None,
                               no_cache: Optional[bool] = None,
                               async_req: Optional[bool] = None,
                               output: Optional[str] = None) -> dict:
        """Search for an author's information by the author's ID provided by get_author_id.

        Args:
            author_id (str): Required. The ID of an author.
            hl (Optional[str]): The language to use for the Google Scholar Author search. Default is 'en'.
            view_op (Optional[str]): Used for viewing specific parts of a page.
            sort (Optional[str]): Used for sorting and refining articles.
            citation_id (Optional[str]): Used for retrieving individual article citation.
            start (Optional[int]): Defines the result offset. Default is 0.
            num (Optional[int]): Defines the number of results to return. Default is 20.
            no_cache (Optional[bool]): Forces SerpApi to fetch the results even if a cached version is already present. Default is False.
            async_req (Optional[bool]): Defines the way you want to submit your search to SerpApi. Default is False.
            output (Optional[str]): Defines the final output you want. Default is 'json'.

        Returns:
            :class:`dict`: author information
                * name: author's name
                * affiliation: the affiliation of the author
                * articles: at most 3 articles by the author
                * website: the author's homepage url
        """
        params = {
            'engine': 'google_scholar_author',
            'author_id': author_id,
            'api_key': self.api_key,
            'hl': hl,
            'view_op': view_op,
            'sort': sort,
            'citation_id': citation_id,
            'start': start,
            'num': num,
            'no_cache': no_cache,
            'async': async_req,
            'output': output
        }
        try:
            search = GoogleSearch(params)
            results = search.get_dict()
            author = results['author']
            articles = results.get('articles', [])
            return dict(
                name=author['name'],
                affiliations=author.get('affiliations', ''),
                website=author.get('website', ''),
                articles=[
                    dict(title=article['title'], authors=article['authors'])
                    for article in articles[:3]
                ])
        except Exception as e:
            return ActionReturn(
                errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)

    @tool_api(explode_return=True)
    def get_citation_format(self,
                            q: str,
                            no_cache: Optional[bool] = None,
                            async_: Optional[bool] = None,
                            output: Optional[str] = 'json') -> dict:
        """Get the MLA citation format of an organic search result, identified by the organic_result id provided by search_google_scholar.

        Args:
            q (str): ID of an individual Google Scholar organic search result.
            no_cache (Optional[bool]): If set to True, will force SerpApi to fetch the Google Scholar Cite results even if a cached version is already present. Defaults to None.
            async_ (Optional[bool]): If set to True, will submit search to SerpApi and retrieve results later. Defaults to None.
            output (Optional[str]): Final output format. Set to 'json' to get a structured JSON of the results, or 'html' to get the raw html retrieved. Defaults to 'json'.

        Returns:
            :class:`dict`: citation format
                * authors: the authors of the article
                * citation: the citation format of the article
        """
        params = {
            'q': q,
            'engine': 'google_scholar_cite',
            'api_key': self.api_key,
            'no_cache': no_cache,
            'async': async_,
            'output': output
        }
        try:
            search = GoogleSearch(params)
            results = search.get_dict()
            citation = results['citations']
            citation_info = citation[0]['snippet']
            return citation_info
        except Exception as e:
            return ActionReturn(
                errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)

    @tool_api(explode_return=True)
    def get_author_id(self,
                      mauthors: str,
                      hl: Optional[str] = 'en',
                      after_author: Optional[str] = None,
                      before_author: Optional[str] = None,
                      no_cache: Optional[bool] = False,
                      _async: Optional[bool] = False,
                      output: Optional[str] = 'json') -> dict:
        """Get an author's ID by their name.

        Args:
            mauthors (str): Defines the author you want to search for.
            hl (Optional[str]): Defines the language to use for the Google Scholar Profiles search. It's a two-letter language code. (e.g., 'en' for English, 'es' for Spanish, or 'fr' for French). Defaults to 'en'.
            after_author (Optional[str]): Defines the next page token. It is used for retrieving the next page results. The parameter takes precedence over the before_author parameter. Defaults to None.
            before_author (Optional[str]): Defines the previous page token. It is used for retrieving the previous page results. Defaults to None.
            no_cache (Optional[bool]): Will force SerpApi to fetch the Google Scholar Profiles results even if a cached version is already present. Defaults to False.
            _async (Optional[bool]): Defines the way you want to submit your search to SerpApi. Defaults to False.
            output (Optional[str]): Defines the final output you want. It can be set to 'json' (default) to get a structured JSON of the results, or 'html' to get the raw html retrieved. Defaults to 'json'.

        Returns:
            :class:`dict`: author id
                * author_id: the author_id of the author
        """
        params = {
            'mauthors': mauthors,
            'engine': 'google_scholar_profiles',
            'api_key': self.api_key,
            'hl': hl,
            'after_author': after_author,
            'before_author': before_author,
            'no_cache': no_cache,
            'async': _async,
            'output': output
        }
        try:
            search = GoogleSearch(params)
            results = search.get_dict()
            profile = results['profiles']
            author_info = dict(author_id=profile[0]['author_id'])
            return author_info
        except Exception as e:
            return ActionReturn(
                errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)
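A hedged sketch of chaining two of the Scholar APIs (requires a valid SerpApi key; the author name and key are illustrative):

```python
from lagent.actions import GoogleScholar

scholar = GoogleScholar(api_key='YOUR_SERPAPI_KEY')
ret = scholar('{"mauthors": "Kaiming He"}', 'get_author_id')  # resolve the author ID
print(ret.result)  # the returned author_id then feeds get_author_information
```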
@@ -1,15 +1,11 @@
import os
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Tuple, Type, Union

import requests

from lagent.schema import ActionReturn, ActionStatusCode
from .base_action import BaseAction

DEFAULT_DESCRIPTION = """一个可以从谷歌搜索结果的API。
当你需要对于一个特定问题找到简短明了的回答时,可以使用它。
输入应该是一个搜索查询。
"""
from .base_action import BaseAction, tool_api
from .parser import BaseParser, JsonParser


class GoogleSearch(BaseAction):
@@ -28,15 +24,10 @@ class GoogleSearch(BaseAction):
        timeout (int): Upper bound of waiting time for a serper request.
        search_type (str): Serper API support ['search', 'images', 'news',
            'places'] types of search, currently we only support 'search'.
        k (int): select first k results in the search results as response.
        description (str): The description of the action. Defaults to
            None.
        name (str, optional): The name of the action. If None, the name will
            be class name. Defaults to None.
        enable (bool, optional): Whether the action is enabled. Defaults to
            True.
        disable_description (str, optional): The description of the action when
            it is disabled. Defaults to None.
        description (dict): The description of the action. Defaults to ``None``.
        parser (Type[BaseParser]): The parser class to process the
            action's inputs and outputs. Defaults to :class:`JsonParser`.
        enable (bool): Whether the action is enabled. Defaults to ``True``.
    """
    result_key_for_type = {
        'news': 'news',
@@ -49,36 +40,29 @@ class GoogleSearch(BaseAction):
                 api_key: Optional[str] = None,
                 timeout: int = 5,
                 search_type: str = 'search',
                 k: int = 10,
                 description: str = DEFAULT_DESCRIPTION,
                 name: Optional[str] = None,
                 enable: bool = True,
                 disable_description: Optional[str] = None) -> None:
        super().__init__(description, name, enable, disable_description)

                 description: Optional[dict] = None,
                 parser: Type[BaseParser] = JsonParser,
                 enable: bool = True):
        super().__init__(description, parser, enable)
        api_key = os.environ.get('SERPER_API_KEY', api_key)
        if api_key is None:
            raise ValueError(
                'Please set Serper API key either in the environment '
                ' as SERPER_API_KEY or pass it as `api_key` parameter.')
                'as SERPER_API_KEY or pass it as `api_key` parameter.')
        self.api_key = api_key
        self.timeout = timeout
        self.search_type = search_type
        self.k = k

    def __call__(self, query: str) -> ActionReturn:
        """Return the search response.

    @tool_api
    def run(self, query: str, k: int = 10) -> ActionReturn:
        """一个可以从谷歌搜索结果的API。当你需要对于一个特定问题找到简短明了的回答时,可以使用它。输入应该是一个搜索查询。

        Args:
            query (str): The search content.

        Returns:
            ActionReturn: The action return.
            query (str): the search content
            k (int): select first k results in the search results as response
        """

        tool_return = ActionReturn(url=None, args=None, type=self.name)
        status_code, response = self._search(
            query, search_type=self.search_type, k=self.k)
        tool_return = ActionReturn(type=self.name)
        status_code, response = self._search(query, k=k)
        # convert search results to ToolReturn format
        if status_code == -1:
            tool_return.errmsg = response
@@ -139,7 +123,7 @@ class GoogleSearch(BaseAction):

    def _search(self,
                search_term: str,
                search_type: str = 'search',
                search_type: Optional[str] = None,
                **kwargs) -> Tuple[int, Union[dict, str]]:
        """HTTP requests to Serper API.

@@ -166,7 +150,7 @@ class GoogleSearch(BaseAction):
        }
        try:
            response = requests.post(
                f'https://google.serper.dev/{search_type}',
                f'https://google.serper.dev/{search_type or self.search_type}',
                headers=headers,
                params=params,
                timeout=self.timeout)
296
lagent/actions/ipython_interpreter.py
Normal file
@@ -0,0 +1,296 @@
import base64
import io
import logging
import os
import queue
import re
import signal
import sys
import traceback
import uuid
from typing import Optional, Tuple, Type

import json5
import PIL.Image
from jupyter_client import KernelManager

from lagent.actions.base_action import BaseAction, tool_api
from lagent.actions.parser import BaseParser, JsonParser
from lagent.schema import ActionReturn, ActionStatusCode

START_CODE = """
def input(*args, **kwargs):
    raise NotImplementedError('Python input() function is disabled.')

get_ipython().system = lambda *args: print('Assume we have this package, ! is disabled!')
{}
"""  # noqa


class TimeoutError(Exception):
    pass


class IPythonInterpreter(BaseAction):
    """An IPython executor that can execute Python scripts in a jupyter manner.

    Args:
        timeout (int): Upper bound of waiting time for Python script execution.
            Defaults to 20.
        user_data_dir (str, optional): Specifies the user data directory for file
            loading. If set to `ENV`, use `USER_DATA_DIR` environment variable.
            Defaults to `ENV`.
        work_dir (str, optional): Specify which directory to save output images to.
            Defaults to ``'./work_dir/tmp_dir'``.
        description (dict): The description of the action. Defaults to ``None``.
        parser (Type[BaseParser]): The parser class to process the
            action's inputs and outputs. Defaults to :class:`JsonParser`.
        enable (bool, optional): Whether the action is enabled. Defaults to ``True``.
    """

    _KERNEL_CLIENTS = {}

    def __init__(self,
                 timeout: int = 20,
                 user_data_dir: str = 'ENV',
                 work_dir='./work_dir/tmp_dir',
                 description: Optional[dict] = None,
                 parser: Type[BaseParser] = JsonParser,
                 enable: bool = True):
        super().__init__(description, parser, enable)

        self.timeout = timeout
        if user_data_dir == 'ENV':
            user_data_dir = os.environ.get('USER_DATA_DIR', '')

        if user_data_dir:
            user_data_dir = os.path.dirname(user_data_dir)
            user_data_dir = f"import os\nos.chdir('{user_data_dir}')"
        self.user_data_dir = user_data_dir
        self._initialized = False
        self.work_dir = work_dir
        if not os.path.exists(self.work_dir):
            os.makedirs(self.work_dir, exist_ok=True)

    @staticmethod
    def start_kernel():
        # start the kernel and manager
        km = KernelManager()
        km.start_kernel()
        kc = km.client()
        return km, kc

    def initialize(self):
        if self._initialized:
            return
        pid = os.getpid()
        if pid not in self._KERNEL_CLIENTS:
            self._KERNEL_CLIENTS[pid] = self.start_kernel()
        self.kernel_manager, self.kernel_client = self._KERNEL_CLIENTS[pid]
        self._initialized = True
        self._call(START_CODE.format(self.user_data_dir), None)

    def reset(self):
        if not self._initialized:
            self.initialize()
        else:
            code = "get_ipython().run_line_magic('reset', '-f')\n" + \
                START_CODE.format(self.user_data_dir)
            self._call(code, None)

    def _call(self,
              command: str,
              timeout: Optional[int] = None) -> Tuple[str, bool]:
        self.initialize()
        command = extract_code(command)

        # check previous remaining result
        while True:
            try:
                msg = self.kernel_client.get_iopub_msg(timeout=5)
                msg_type = msg['msg_type']
                if msg_type == 'status':
                    if msg['content'].get('execution_state') == 'idle':
                        break
            except queue.Empty:
                # assume no result
                break

        self.kernel_client.execute(command)

        def _inner_call():
            result = ''
            images = []
            succeed = True
            image_idx = 0

            while True:
                text = ''
                image = ''
                finished = False
                msg_type = 'error'
                try:
                    msg = self.kernel_client.get_iopub_msg(timeout=20)
                    msg_type = msg['msg_type']
                    if msg_type == 'status':
                        if msg['content'].get('execution_state') == 'idle':
                            finished = True
                    elif msg_type == 'execute_result':
                        text = msg['content']['data'].get('text/plain', '')
                        if 'image/png' in msg['content']['data']:
                            image_b64 = msg['content']['data']['image/png']
|
||||
image_url = publish_image_to_local(
|
||||
image_b64, self.work_dir)
|
||||
image_idx += 1
|
||||
image = '' % (image_idx, image_url)
|
||||
|
||||
elif msg_type == 'display_data':
|
||||
if 'image/png' in msg['content']['data']:
|
||||
image_b64 = msg['content']['data']['image/png']
|
||||
image_url = publish_image_to_local(
|
||||
image_b64, self.work_dir)
|
||||
image_idx += 1
|
||||
image = '' % (image_idx, image_url)
|
||||
|
||||
else:
|
||||
text = msg['content']['data'].get('text/plain', '')
|
||||
elif msg_type == 'stream':
|
||||
msg_type = msg['content']['name'] # stdout, stderr
|
||||
text = msg['content']['text']
|
||||
elif msg_type == 'error':
|
||||
succeed = False
|
||||
text = escape_ansi('\n'.join(
|
||||
msg['content']['traceback']))
|
||||
if 'M6_CODE_INTERPRETER_TIMEOUT' in text:
|
||||
text = f'Timeout. No response after {timeout} seconds.' # noqa
|
||||
except queue.Empty:
|
||||
# stop current task in case break next input.
|
||||
self.kernel_manager.interrupt_kernel()
|
||||
succeed = False
|
||||
text = f'Timeout. No response after {timeout} seconds.'
|
||||
finished = True
|
||||
except Exception:
|
||||
succeed = False
|
||||
msg = ''.join(traceback.format_exception(*sys.exc_info()))
|
||||
# text = 'The code interpreter encountered an unexpected error.' # noqa
|
||||
text = msg
|
||||
logging.warning(msg)
|
||||
finished = True
|
||||
if text:
|
||||
# result += f'\n\n{msg_type}:\n\n```\n{text}\n```'
|
||||
result += f'{text}'
|
||||
|
||||
if image:
|
||||
images.append(image_url)
|
||||
if finished:
|
||||
return succeed, dict(text=result, image=images)
|
||||
|
||||
try:
|
||||
if timeout:
|
||||
|
||||
def handler(signum, frame):
|
||||
raise TimeoutError()
|
||||
|
||||
signal.signal(signal.SIGALRM, handler)
|
||||
signal.alarm(timeout)
|
||||
succeed, result = _inner_call()
|
||||
except TimeoutError:
|
||||
succeed = False
|
||||
text = 'The code interpreter encountered an unexpected error.'
|
||||
result = f'\n\nerror:\n\n```\n{text}\n```'
|
||||
finally:
|
||||
if timeout:
|
||||
signal.alarm(0)
|
||||
|
||||
# result = result.strip('\n')
|
||||
return succeed, result
|
||||
|
||||
@tool_api
|
||||
def run(self, command: str, timeout: Optional[int] = None) -> ActionReturn:
|
||||
"""When you send a message containing Python code to python, it will be \
|
||||
executed in a stateful Jupyter notebook environment. python will respond with \
|
||||
the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' \
|
||||
can be used to save and persist user files. Internet access for this session is \
|
||||
disabled. Do not make external web requests or API calls as they will fail.
|
||||
|
||||
Args:
|
||||
command (:class:`str`): Python code
|
||||
timeout (:class:`Optional[int]`): Upper bound of waiting time for Python script execution.
|
||||
"""
|
||||
tool_return = ActionReturn(url=None, args=None, type=self.name)
|
||||
tool_return.args = dict(text=command)
|
||||
succeed, result = self._call(command, timeout)
|
||||
if succeed:
|
||||
text = result['text']
|
||||
image = result.get('image', [])
|
||||
resp = [dict(type="text", content=text)]
|
||||
if image:
|
||||
resp.extend([dict(type="image", content=im) for im in image])
|
||||
tool_return.result = resp
|
||||
# tool_return.result = dict(
|
||||
# text=result['text'], image=result.get('image', [])[0])
|
||||
tool_return.state = ActionStatusCode.SUCCESS
|
||||
else:
|
||||
tool_return.errmsg = result.get("text", "") if isinstance(
|
||||
result, dict) else result
|
||||
tool_return.state = ActionStatusCode.API_ERROR
|
||||
return tool_return
|
||||
|
||||
|
||||
def extract_code(text):
|
||||
# Match triple backtick blocks first
|
||||
triple_match = re.search(r'```[^\n]*\n(.+?)```', text, re.DOTALL)
|
||||
# Match single backtick blocks second
|
||||
single_match = re.search(r'`([^`]*)`', text, re.DOTALL)
|
||||
if triple_match:
|
||||
text = triple_match.group(1)
|
||||
elif single_match:
|
||||
text = single_match.group(1)
|
||||
else:
|
||||
try:
|
||||
text = json5.loads(text)['code']
|
||||
except Exception:
|
||||
pass
|
||||
# If no code blocks found, return original text
|
||||
return text
|
||||
|
||||
|
||||
def escape_ansi(line):
|
||||
ansi_escape = re.compile(r'(?:\x1B[@-_]|[\x80-\x9F])[0-?]*[ -/]*[@-~]')
|
||||
return ansi_escape.sub('', line)
|
||||
|
||||
|
||||
def publish_image_to_local(image_base64: str, work_dir='./work_dir/tmp_dir'):
|
||||
image_file = str(uuid.uuid4()) + '.png'
|
||||
local_image_file = os.path.join(work_dir, image_file)
|
||||
|
||||
png_bytes = base64.b64decode(image_base64)
|
||||
assert isinstance(png_bytes, bytes)
|
||||
bytes_io = io.BytesIO(png_bytes)
|
||||
PIL.Image.open(bytes_io).save(local_image_file, 'png')
|
||||
|
||||
return local_image_file
|
||||
|
||||
|
||||
# local test for code interpreter
|
||||
def get_multiline_input(hint):
|
||||
print(hint)
|
||||
print('// Press ENTER to make a new line. Press CTRL-D to end input.')
|
||||
lines = []
|
||||
while True:
|
||||
try:
|
||||
line = input()
|
||||
except EOFError: # CTRL-D
|
||||
break
|
||||
lines.append(line)
|
||||
print('// Input received.')
|
||||
if lines:
|
||||
return '\n'.join(lines)
|
||||
else:
|
||||
return ''
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
code_interpreter = IPythonInterpreter()
|
||||
while True:
|
||||
print(code_interpreter(get_multiline_input('Enter python code:')))
|
||||
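The `extract_code` helper above accepts raw model output in several shapes; a quick illustration of its precedence rules (the inputs below are made up):

```python
from lagent.actions.ipython_interpreter import extract_code

# A fenced block wins over everything else.
assert extract_code('```python\nprint(1)\n```') == 'print(1)\n'
# With no backticks at all, a JSON payload with a `code` field is accepted.
assert extract_code('{"code": "print(2)"}') == 'print(2)'
```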
@@ -1,56 +0,0 @@
from typing import Optional, Union

from lagent.llms.base_api import BaseAPIModel
from lagent.llms.base_llm import BaseModel
from lagent.schema import ActionReturn, ActionStatusCode
from .base_action import BaseAction

DEFAULT_DESCRIPTION = """A large pretrained language model just like yourself. Ask it when you need common sense or simple world knowledge.
Prefer to use it when you are fairly confident you can solve the problem directly yourself. The input should be a question, and each question should be as simple as possible.
"""


class LLMQA(BaseAction):
    """An LLM Wrapper as BaseAction type.

    Args:
        llm (BaseModel or BaseAPIModel): a LLM service which can chat.
        description (str): The description of the action. Defaults to
            None.
        name (str, optional): The name of the action. If None, the name will
            be class name. Defaults to None.
        enable (bool, optional): Whether the action is enabled. Defaults to
            True.
        disable_description (str, optional): The description of the action when
            it is disabled. Defaults to None.
    """

    def __init__(self,
                 llm: Union[BaseModel, BaseAPIModel],
                 description: str = DEFAULT_DESCRIPTION,
                 name: Optional[str] = None,
                 enable: bool = True,
                 disable_description: Optional[str] = None) -> None:
        super().__init__(description, name, enable, disable_description)

        self._llm = llm

    def __call__(self, query: str) -> ActionReturn:
        """Return the QA response.

        Args:
            query (str): The query content.

        Returns:
            ActionReturn: The action return.
        """

        tool_return = ActionReturn(url=None, args=None)
        try:
            response = self._llm.generate_from_template(query, 512)
            tool_return.result = dict(text=str(response))
            tool_return.state = ActionStatusCode.SUCCESS
        except Exception as e:
            tool_return.result = dict(text=str(e))
            tool_return.state = ActionStatusCode.API_ERROR
        return tool_return
139 lagent/actions/parser.py Normal file
@@ -0,0 +1,139 @@
import json
import re
from ast import literal_eval
from typing import Any, List, Union


class ParseError(Exception):
    """Parsing exception class."""

    def __init__(self, err_msg: str):
        self.err_msg = err_msg


class BaseParser:
    """Base parser to process inputs and outputs of actions.

    Args:
        action (:class:`BaseAction`): action to validate

    Attributes:
        PARAMETER_DESCRIPTION (:class:`str`): declare the input format which
            LLMs should follow when generating arguments for decided tools.
    """

    PARAMETER_DESCRIPTION: str = ''

    def __init__(self, action):
        self.action = action
        self._api2param = {}
        self._api2required = {}
        # perform basic argument validation
        if action.description:
            for api in action.description.get('api_list',
                                              [action.description]):
                name = (f'{action.name}.{api["name"]}'
                        if self.action.is_toolkit else api['name'])
                required_parameters = set(api['required'])
                all_parameters = {j['name'] for j in api['parameters']}
                if not required_parameters.issubset(all_parameters):
                    raise ValueError(
                        f'unknown parameters for function "{name}": '
                        f'{required_parameters - all_parameters}')
                if self.PARAMETER_DESCRIPTION:
                    api['parameter_description'] = self.PARAMETER_DESCRIPTION
                api_name = api['name'] if self.action.is_toolkit else 'run'
                self._api2param[api_name] = api['parameters']
                self._api2required[api_name] = api['required']

    def parse_inputs(self, inputs: str, name: str = 'run') -> dict:
        """Parse inputs LLMs generate for the action.

        Args:
            inputs (:class:`str`): input string extracted from responses

        Returns:
            :class:`dict`: processed input
        """
        inputs = {self._api2param[name][0]['name']: inputs}
        return inputs

    def parse_outputs(self, outputs: Any) -> List[dict]:
        """Parse outputs returned by the action.

        Args:
            outputs (:class:`Any`): raw output of the action

        Returns:
            :class:`List[dict]`: processed output of which each member is a
                dictionary with two keys - 'type' and 'content'.
        """
        if isinstance(outputs, dict):
            outputs = json.dumps(outputs, ensure_ascii=False)
        elif not isinstance(outputs, str):
            outputs = str(outputs)
        return [{'type': 'text', 'content': outputs}]


class JsonParser(BaseParser):
    """Json parser to convert input string into a dictionary.

    Args:
        action (:class:`BaseAction`): action to validate
    """

    PARAMETER_DESCRIPTION = ('If you call this tool, you must pass arguments in '
                             'the JSON format {key: value}, where key is the parameter name.')

    def parse_inputs(self,
                     inputs: Union[str, dict],
                     name: str = 'run') -> dict:
        if not isinstance(inputs, dict):
            try:
                match = re.search(r'^\s*(```json\n)?(.*)\n```\s*$', inputs,
                                  re.S)
                if match:
                    inputs = match.group(2).strip()
                inputs = json.loads(inputs)
            except json.JSONDecodeError as exc:
                raise ParseError(f'invalid json format: {inputs}') from exc
        input_keys = set(inputs)
        all_keys = {param['name'] for param in self._api2param[name]}
        if not input_keys.issubset(all_keys):
            raise ParseError(f'unknown arguments: {input_keys - all_keys}')
        required_keys = set(self._api2required[name])
        if not input_keys.issuperset(required_keys):
            raise ParseError(
                f'missing required arguments: {required_keys - input_keys}')
        return inputs


class TupleParser(BaseParser):
    """Tuple parser to convert input string into a tuple.

    Args:
        action (:class:`BaseAction`): action to validate
    """

    PARAMETER_DESCRIPTION = ('If you call this tool, you must pass arguments in the '
                             'tuple format (arg1, arg2, arg3), and the arguments are ordered.')

    def parse_inputs(self,
                     inputs: Union[str, tuple],
                     name: str = 'run') -> dict:
        if not isinstance(inputs, tuple):
            try:
                inputs = literal_eval(inputs)
            except Exception as exc:
                raise ParseError(f'invalid tuple format: {inputs}') from exc
        if len(inputs) < len(self._api2required[name]):
            raise ParseError(
                f'API takes {len(self._api2required[name])} required positional '
                f'arguments but {len(inputs)} were given')
        if len(inputs) > len(self._api2param[name]):
            raise ParseError(
                f'API takes {len(self._api2param[name])} positional arguments '
                f'but {len(inputs)} were given')
        inputs = {
            self._api2param[name][i]['name']: item
            for i, item in enumerate(inputs)
        }
        return inputs
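`JsonParser` sits between the LLM's raw argument string and an action's `run` method. A sketch of the validation it performs; the stub action and its description dict are invented for illustration:

```python
from lagent.actions.parser import JsonParser, ParseError

class StubAction:  # hypothetical stand-in for a BaseAction
    is_toolkit = False
    name = 'stub'
    description = dict(
        name='stub',
        parameters=[dict(name='query', type='STRING', description='search text')],
        required=['query'])

parser = JsonParser(StubAction())
print(parser.parse_inputs('{"query": "hi"}'))  # {'query': 'hi'}
try:
    parser.parse_inputs('{"q": "hi"}')         # wrong key
except ParseError as e:
    print(e.err_msg)                           # unknown arguments: {'q'}
```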
157 lagent/actions/ppt.py Normal file
@@ -0,0 +1,157 @@
from typing import Dict, Optional, Type

from PIL import Image
from pptx import Presentation

from lagent.actions.base_action import BaseAction, tool_api
from lagent.actions.parser import BaseParser, JsonParser

THEME_MAPPING = {
    'Default': {
        'template': None,
        'title': 'Title Slide',
        'single': 'Title and Content',
        'two': 'Two Content',
    }
}


class PPT(BaseAction):
    """Plugin to create ppt slides with text, paragraph, images in good looking styles."""

    def __init__(self,
                 theme_mapping: Optional[Dict[str, dict]] = None,
                 description: Optional[dict] = None,
                 parser: Type[BaseParser] = JsonParser,
                 enable: bool = True):
        super().__init__(description, parser, enable)
        self.theme_mapping = theme_mapping or THEME_MAPPING
        self.pointer = None
        self.location = None

    @tool_api(explode_return=True)
    def create_file(self, theme: str, abs_location: str) -> dict:
        """Create a pptx file with specific themes.

        Args:
            theme (:class:`str`): the theme used
            abs_location (:class:`str`): the ppt file's absolute location

        Returns:
            :class:`dict`: operation status
                * status: the result of the execution
        """
        self.location = abs_location
        try:
            self.pointer = Presentation(self.theme_mapping[theme]['template'])
            self.pointer.slide_master.name = theme
            # print('created')
        except Exception as e:
            print(e)
        return dict(status='created a ppt file.')

    @tool_api(explode_return=True)
    def add_first_page(self, title: str, subtitle: str) -> dict:
        """Add the first page of ppt.

        Args:
            title (:class:`str`): the title of ppt
            subtitle (:class:`str`): the subtitle of ppt

        Returns:
            :class:`dict`: operation status
                * status: the result of the execution
        """
        layout_name = self.theme_mapping[
            self.pointer.slide_master.name]['title']
        layout = next(i for i in self.pointer.slide_master.slide_layouts
                      if i.name == layout_name)
        slide = self.pointer.slides.add_slide(layout)
        ph_title, ph_subtitle = slide.placeholders
        ph_title.text = title
        if subtitle:
            ph_subtitle.text = subtitle
        return dict(status='added page')

    @tool_api(explode_return=True)
    def add_text_page(self, title: str, bullet_items: str) -> dict:
        """Add text page of ppt.

        Args:
            title (:class:`str`): the title of the page
            bullet_items (:class:`str`): bullet_items should be string, for multiple bullet items, please use [SPAN] to separate them.

        Returns:
            :class:`dict`: operation status
                * status: the result of the execution
        """
        layout_name = self.theme_mapping[
            self.pointer.slide_master.name]['single']
        layout = next(i for i in self.pointer.slide_master.slide_layouts
                      if i.name == layout_name)
        slide = self.pointer.slides.add_slide(layout)
        ph_title, ph_body = slide.placeholders
        ph_title.text = title
        ph = ph_body
        tf = ph.text_frame
        for i, item in enumerate(bullet_items.split('[SPAN]')):
            if i == 0:
                p = tf.paragraphs[0]
            else:
                p = tf.add_paragraph()
            p.text = item.strip()
            p.level = 0
        return dict(status='added page')

    @tool_api(explode_return=True)
    def add_text_image_page(self, title: str, bullet_items: str,
                            image: str) -> dict:
        """Add a text page with one image. Image should be a path.

        Args:
            title (:class:`str`): the title of the page
            bullet_items (:class:`str`): bullet_items should be string, for multiple bullet items, please use [SPAN] to separate them.
            image (:class:`str`): the path of the image

        Returns:
            :class:`dict`: operation status
                * status: the result of the execution
        """
        layout_name = self.theme_mapping[self.pointer.slide_master.name]['two']
        layout = next(i for i in self.pointer.slide_master.slide_layouts
                      if i.name == layout_name)
        slide = self.pointer.slides.add_slide(layout)
        ph_title, ph_body1, ph_body2 = slide.placeholders
        ph_title.text = title
        ph = ph_body2
        # `image` is a file path, so open it with PIL to read its aspect ratio.
        image_pil = Image.open(image)
        left = ph.left
        width = ph.width
        height = int(width / image_pil.width * image_pil.height)
        top = (ph.top + (ph.top + ph.height)) // 2 - height // 2
        slide.shapes.add_picture(image, left, top, width, height)

        ph = ph_body1
        tf = ph.text_frame
        for i, item in enumerate(bullet_items.split('[SPAN]')):
            if i == 0:
                p = tf.paragraphs[0]
            else:
                p = tf.add_paragraph()
            p.text = item.strip()
            p.level = 0

        return dict(status='added page')

    @tool_api(explode_return=True)
    def submit_file(self) -> dict:
        """When all steps done, YOU MUST use submit_file() to submit your work.

        Returns:
            :class:`dict`: operation status
                * status: the result of the execution
        """
        # file_path = os.path.join(self.CACHE_DIR, f'{self._return_timestamp()}.pptx')
        # self.pointer.save(file_path)
        # retrieval_url = upload_file(file_path)
        self.pointer.save(self.location)
        return dict(status=f'submitted. view ppt at {self.location}')
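A hypothetical end-to-end run of the new action (requires `python-pptx`; the methods are called directly here rather than through an `ActionExecutor`, and the output path is made up):

```python
from lagent.actions.ppt import PPT

ppt = PPT()
ppt.create_file(theme='Default', abs_location='/tmp/demo.pptx')
ppt.add_first_page(title='Quarterly Report', subtitle='2023 Q3')
ppt.add_text_page(title='Highlights',
                  bullet_items='Revenue up 12%[SPAN]Churn down 3%')
print(ppt.submit_file())  # {'status': 'submitted. view ppt at /tmp/demo.pptx'}
```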
@@ -1,11 +1,12 @@
import copy
import io
from contextlib import redirect_stdout
-from typing import Any, Optional
+from typing import Any, Optional, Type

from func_timeout import FunctionTimedOut, func_set_timeout

-from lagent.actions.base_action import BaseAction
+from lagent.actions.base_action import BaseAction, tool_api
+from lagent.actions.parser import BaseParser, JsonParser
from lagent.schema import ActionReturn, ActionStatusCode

@@ -29,72 +30,70 @@ class GenericRuntime:
        return eval(expr, self._global_vars)


-DEFAULT_DESCRIPTION = """Used to execute Python code. The code must be a function,
-and the function name must be 'solution'; the code corresponds to your thought process. Example format:
-```python
-# import required packages
-import xxx
-def solution():
-    # initialize some variables
-    variable_names_with_real_meaning = xxx
-    # step 1
-    mid_variable = func(variable_names_with_real_meaning)
-    # step x
-    mid_variable = func(mid_variable)
-    # final result
-    final_answer = func(mid_variable)
-    return final_answer
-```"""


class PythonInterpreter(BaseAction):
    """A Python executor that can execute Python scripts.

    Args:
-        description (str): The description of the action. Defaults to
-            DEFAULT_DESCRIPTION.
-        answer_symbol (str, Optional): the answer symbol from LLM
+        answer_symbol (str, Optional): the answer symbol from LLM. Defaults to ``None``.
        answer_expr (str, Optional): the answer function name of the Python
-            script. Default to 'solution()'.
-        answer_from_stdout (boolean): whether the execution results is from
-            stdout.
-        name (str, optional): The name of the action. If None, the name will
-            be class name. Defaults to None.
+            script. Defaults to ``'solution()'``.
+        answer_from_stdout (boolean, Optional): whether the execution results is from
+            stdout. Defaults to ``False``.
+        timeout (int, Optional): Upper bound of waiting time for Python script execution.
+            Defaults to ``20``.
+        description (dict, Optional): The description of the action. Defaults to ``None``.
+        parser (Type[BaseParser]): The parser class to process the
+            action's inputs and outputs. Defaults to :class:`JsonParser`.
        enable (bool, optional): Whether the action is enabled. Defaults to
-            True.
-        disable_description (str, optional): The description of the action when
-            it is disabled. Defaults to None.
-        timeout (int): Upper bound of waiting time for Python script execution.
+            ``True``.
    """

    def __init__(self,
-                 description: str = DEFAULT_DESCRIPTION,
                 answer_symbol: Optional[str] = None,
                 answer_expr: Optional[str] = 'solution()',
                 answer_from_stdout: bool = False,
-                 name: Optional[str] = None,
-                 enable: bool = True,
-                 disable_description: Optional[str] = None,
-                 timeout: int = 20) -> None:
-        super().__init__(description, name, enable, disable_description)
-
+                 timeout: int = 20,
+                 description: Optional[dict] = None,
+                 parser: Type[BaseParser] = JsonParser,
+                 enable: bool = True) -> None:
+        super().__init__(description, parser, enable)
        self.answer_symbol = answer_symbol
        self.answer_expr = answer_expr
        self.answer_from_stdout = answer_from_stdout
        self.timeout = timeout

-    def __call__(self, command: str) -> ActionReturn:
+    @tool_api
+    def run(self, command: str) -> ActionReturn:
+        """Used to execute Python code. The code must be a function, and the function name must be 'solution'; the code corresponds to your thought process. Example format:
+        ```python
+        # import required packages
+        import xxx
+        def solution():
+            # initialize some variables
+            variable_names_with_real_meaning = xxx
+            # step 1
+            mid_variable = func(variable_names_with_real_meaning)
+            # step x
+            mid_variable = func(mid_variable)
+            # final result
+            final_answer = func(mid_variable)
+            return final_answer
+        ```

+        Args:
+            command (:class:`str`): Python code snippet
+        """
        self.runtime = GenericRuntime()
        try:
            tool_return = func_set_timeout(self.timeout)(self._call)(command)
        except FunctionTimedOut as e:
-            tool_return = ActionReturn(url=None, args=None, type=self.name)
+            tool_return = ActionReturn(type=self.name)
            tool_return.errmsg = repr(e)
            tool_return.state = ActionStatusCode.API_ERROR
        return tool_return

    def _call(self, command: str) -> ActionReturn:
-        tool_return = ActionReturn(url=None, args=None, type=self.name)
+        tool_return = ActionReturn(type=self.name)
        try:
            if '```python' in command:
                command = command.split('```python')[1].split('```')[0]
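The refactor keeps the `solution()` contract but exposes it through `run`. A sketch of a direct call (the `lagent.actions` export path is an assumption; agents would call it via an executor):

```python
from lagent.actions import PythonInterpreter  # assumed export path

interpreter = PythonInterpreter()
ret = interpreter.run('''```python
def solution():
    return sum(range(1, 101))
```''')
print(ret.result)  # parsed execution result on success; see ret.errmsg otherwise
```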
@@ -1,6 +1,5 @@
# flake8: noqa
import ast
-import copy
import platform
from typing import Dict, List, Optional, Tuple, Union

@@ -220,21 +219,20 @@ class AutoGPTProtocol:
            dict(role='user', content=self.triggering_prompt))
        return formatted_data

-    def format_response(self, action_return):
+    def format_response(self, action_return) -> dict:
        """format the final response at current step.

        Args:
            action_return (ActionReturn): return value of the current action.

        Returns:
-            str: the final response at current step.
+            dict: the final response at current step.
        """
        if action_return.state == ActionStatusCode.SUCCESS:
-            response = action_return.result['text']
-            response = f'Command {action_return.type} returned: {response}'
+            response = f'Command {action_return.type} returned: {action_return.format_result()}'
        else:
            response = action_return.errmsg
-        return response
+        return dict(role='system', content=response)


class AutoGPT(BaseAgent):
@@ -261,30 +259,26 @@ class AutoGPT(BaseAgent):
        super().__init__(
            llm=llm, action_executor=action_executor, protocol=protocol)

-    def chat(self, goal: str) -> AgentReturn:
-        self._inner_history = []
+    def chat(self, goal: str, **kwargs) -> AgentReturn:
+        inner_history = []
        agent_return = AgentReturn()
        default_response = 'Sorry that I cannot answer your question.'
        for _ in range(self.max_turn):
            prompt = self._protocol.format(
                goal=goal,
-                inner_history=self._inner_history,
+                inner_history=inner_history,
                action_executor=self._action_executor)
-            response = self._llm.generate_from_template(prompt, 512)
-            self._inner_history.append(
-                dict(role='assistant', content=response))
+            response = self._llm.chat(prompt, **kwargs)
+            inner_history.append(dict(role='assistant', content=response))
            action, action_input = self._protocol.parse(
                response, self._action_executor)
            action_return: ActionReturn = self._action_executor(
                action, action_input)
            agent_return.actions.append(action_return)
            if action_return.type == self._action_executor.finish_action.name:
-                agent_return.response = action_return.result['text']
+                agent_return.response = action_return.format_result()
                return agent_return
-            self._inner_history.append(
-                dict(
-                    role='system',
-                    content=self._protocol.format_response(action_return)))
-            agent_return.inner_steps = copy.deepcopy(self._inner_history)
+            inner_history.append(self._protocol.format_response(action_return))
+        agent_return.inner_steps = inner_history
        agent_return.response = default_response
        return agent_return
@@ -1,5 +1,3 @@
-from typing import List
-
from lagent.actions import ActionExecutor
from lagent.actions.base_action import BaseAction
from lagent.llms.base_llm import BaseModel
@@ -19,8 +17,6 @@ class BaseAgent:

    def __init__(self, llm: BaseModel, action_executor: ActionExecutor,
                 protocol: object) -> None:
-
-        self._session_history = []
        self._llm = llm
        self._action_executor = action_executor
        self._protocol = protocol
@@ -41,9 +37,5 @@ class BaseAgent:
        """
        self._action_executor.del_action(name)

-    def chat(self, message: str) -> AgentReturn:
+    def chat(self, message: str, **kwargs) -> AgentReturn:
        raise NotImplementedError
-
-    @property
-    def session_history(self) -> List:
-        return self._session_history
363 lagent/agents/internlm2_agent.py Normal file
@@ -0,0 +1,363 @@
import json
import logging
from copy import deepcopy
from typing import Dict, List, Union, Optional

from lagent import BaseAgent
from lagent.actions import ActionExecutor
from lagent.llms import BaseAPIModel, BaseModel
from lagent.schema import (ActionReturn, ActionStatusCode, AgentReturn,
                           AgentStatusCode)

API_PREFIX = (
    "This is the subfunction for tool '{tool_name}', you can use this tool. "
    'The description of this function is: \n{description}')

META_INS = ('You are InternLM, a large language model trained by PJLab. '
            'Answer as concisely as possible. '
            'When tools and the code interpreter are enabled, choose the '
            'appropriate tool to call according to the need.')

INTERPRETER_CN = (
    'You can now send code to a Jupyter Notebook and execute it in the following format:'
    '\n<|action_start|><|interpreter|>```python\n\ncode\n\n```\n'
    '\nWhen you meet the following kinds of problems, call the Jupyter Notebook in the format above '
    'to solve them, and give a friendly reply based on the execution results:\n'
    '1. File operations and data import, e.g. handling CSV, JSON and other file formats\n'
    '2. Data analysis or processing, e.g. data manipulation or plotting such as line and bar charts\n'
    '3. Math-related problems. For a math problem, analyze it and give code to solve it')

PLUGIN_CN = (
    'You can use the following tools:'
    '\n{prompt}\n'
    'When you need to use a tool, you can use the following format:\n'
    '<|action_start|><|plugin|>{{"name": "tool name", "parameters": {{parameters}}}}\n'
    'If you already have enough information, give the answer directly. '
    'Avoid unnecessary tool calls! Also mind which tools are available to you; '
    'do not make tools up!')


class Internlm2Protocol:

    def __init__(
        self,
        meta_prompt: str = META_INS,
        interpreter_prompt: str = INTERPRETER_CN,
        plugin_prompt: str = PLUGIN_CN,
        few_shot: Optional[List] = None,
        language: Dict = dict(
            begin='',
            end='',
            belong='assistant',
        ),
        tool: Dict = dict(
            begin='{start_token}{name}\n',
            start_token='<|action_start|>',
            name_map=dict(plugin='<|plugin|>', interpreter='<|interpreter|>'),
            belong='assistant',
            end='<|action_end|>\n',
        ),
        execute: Dict = dict(
            role='execute', begin='', end='', fallback_role='environment'),
    ) -> None:
        self.meta_prompt = meta_prompt
        self.interpreter_prompt = interpreter_prompt
        self.plugin_prompt = plugin_prompt
        self.roles_cfg = dict(tool=tool, language=language)
        self.language = language
        self.execute = execute
        self.tool = tool
        self.few_shot = few_shot

    def format_sub_role(self, messages: List[Dict]) -> List[Dict]:

        def format_interpreter(message):
            if isinstance(message['content'], dict):
                # assert message['content']['name'] == 'IPythonInterpreter'
                return dict(
                    role=message['role'],
                    name=message['name'],
                    content=message['content']['parameters']['command'])
            else:
                return message

        def format_plugin(message):
            if isinstance(message['content'], dict):
                return dict(
                    role=message['role'],
                    name=message['name'],
                    content=json.dumps(message['content']))
            else:
                return message

        new_message = list()
        for message in messages:
            if message['role'] in [
                    'assistant', 'user', 'system', 'environment'
            ]:
                new_message.append(message)
                continue
            role_cfg = self.roles_cfg[message['role']]
            begin = role_cfg['begin']
            if message['role'] == 'tool':
                if message['name'] == 'interpreter':
                    message = format_interpreter(message)
                elif message['name'] == 'plugin':
                    message = format_plugin(message)
                else:
                    raise NotImplementedError
                begin = role_cfg['begin'].format(
                    start_token=role_cfg.get('start_token', ''),
                    name=role_cfg.get('name_map', {}).get(message['name'], ''))
            new_content = begin + message['content'] + role_cfg['end']
            if role_cfg.get('fallback_role'):
                new_message.append(
                    dict(role=role_cfg['fallback_role'], content=new_content))
            elif role_cfg.get('belong'):
                if new_message[-1]['role'] != role_cfg.get('belong'):
                    new_message.append(
                        dict(role=role_cfg.get('belong'), content=new_content))
                else:
                    new_message[-1]['content'] += new_content
            else:
                new_message.append(
                    dict(role=message['role'], content=new_content))

        return new_message

    def format(self,
               inner_step: List[Dict],
               plugin_executor: ActionExecutor = None,
               interpreter_executor: ActionExecutor = None,
               **kwargs) -> list:
        formatted = []
        if self.meta_prompt:
            formatted.append(dict(role='system', content=self.meta_prompt))
        if interpreter_executor and self.interpreter_prompt:
            interpreter_info = interpreter_executor.get_actions_info()[0]
            interpreter_prompt = self.interpreter_prompt.format(
                code_prompt=interpreter_info['description'])
            formatted.append(
                dict(
                    role='system',
                    content=interpreter_prompt,
                    name='interpreter'))
        if plugin_executor and plugin_executor.actions and self.plugin_prompt:
            plugin_descriptions = []
            for api_info in plugin_executor.get_actions_info():
                plugin = deepcopy(api_info)
                if isinstance(api_info, dict):
                    tool_name = api_info['name'].split('.')[0]
                    plugin['description'] = API_PREFIX.format(
                        tool_name=tool_name, description=plugin['description'])
                plugin_descriptions.append(plugin)
            plugin_prompt = self.plugin_prompt.format(
                prompt=json.dumps(
                    plugin_descriptions, ensure_ascii=False, indent=4))
            formatted.append(
                dict(role='system', content=plugin_prompt, name='plugin'))
        if self.few_shot:
            for few_shot in self.few_shot:
                formatted += self.format_sub_role(few_shot)
        formatted += self.format_sub_role(inner_step)
        return formatted

    def parse(self, message, plugin_executor: ActionExecutor,
              interpreter_executor: ActionExecutor):
        if self.language['begin']:
            message = message.split(self.language['begin'])[-1]
        if self.tool['name_map']['plugin'] in message:
            message, action = message.split(
                f"{self.tool['start_token']}{self.tool['name_map']['plugin']}")
            action = action.split(self.tool['end'].strip())[0]
            return 'plugin', message, action
        if self.tool['name_map']['interpreter'] in message:
            message, code = message.split(
                f"{self.tool['start_token']}"
                f"{self.tool['name_map']['interpreter']}")
            code = code.split(self.tool['end'].strip())[0].strip()
            return 'interpreter', message, dict(
                name=interpreter_executor.action_names()[0],
                parameters=dict(command=code))
        return None, message, None

    def format_response(self, action_return, name) -> dict:
        if action_return.state == ActionStatusCode.SUCCESS:
            response = action_return.format_result()
        else:
            response = action_return.errmsg
        content = self.execute['begin'] + response + self.execute['end']
        if self.execute.get('fallback_role'):
            return dict(
                role=self.execute['fallback_role'], content=content, name=name)
        elif self.execute.get('belong'):
            return dict(
                role=self.execute['belong'], content=content, name=name)
        return dict(role=self.execute['role'], content=response, name=name)


class Internlm2Agent(BaseAgent):

    def __init__(self,
                 llm: Union[BaseModel, BaseAPIModel],
                 plugin_executor: ActionExecutor = None,
                 interpreter_executor: ActionExecutor = None,
                 protocol=Internlm2Protocol(),
                 max_turn: int = 3) -> None:
        self.max_turn = max_turn
        self._interpreter_executor = interpreter_executor
        super().__init__(
            llm=llm, action_executor=plugin_executor, protocol=protocol)

    def chat(self, message: Union[str, Dict], **kwargs) -> AgentReturn:
        if isinstance(message, str):
            message = dict(role='user', content=message)
        if isinstance(message, dict):
            message = [message]
        inner_history = message[:]
        offset = len(inner_history)
        agent_return = AgentReturn()
        for _ in range(self.max_turn):
            # list of dict
            prompt = self._protocol.format(
                inner_step=inner_history,
                plugin_executor=self._action_executor,
                interpreter_executor=self._interpreter_executor,
            )
            response = self._llm.chat(prompt, **kwargs)
            name, language, action = self._protocol.parse(
                message=response,
                plugin_executor=self._action_executor,
                interpreter_executor=self._interpreter_executor,
            )
            if name:
                if name == 'plugin':
                    if self._action_executor:
                        executor = self._action_executor
                    else:
                        logging.info(msg='No plugin is instantiated!')
                        continue
                    try:
                        action = json.loads(action)
                    except Exception as e:
                        logging.info(msg=f'Invalid action {e}')
                        continue
                elif name == 'interpreter':
                    if self._interpreter_executor:
                        executor = self._interpreter_executor
                    else:
                        logging.info(msg='No interpreter is instantiated!')
                        continue
                else:
                    logging.info(
                        msg=(f"Invalid name '{name}'. Currently only 'plugin' "
                             "and 'interpreter' are supported."))
                    continue
                action_return: ActionReturn = executor(action['name'],
                                                       action['parameters'])
                action_return.thought = language
                agent_return.actions.append(action_return)
            inner_history.append(dict(role='language', content=language))
            if not name or action_return.type == executor.finish_action.name:
                agent_return.response = language
                agent_return.state = AgentStatusCode.END
                break
            else:
                inner_history.append(
                    dict(role='tool', content=action, name=name))
                inner_history.append(
                    self._protocol.format_response(action_return, name=name))
            yield agent_return
        agent_return.inner_steps = inner_history[offset:]
        yield agent_return

    def stream_chat(self, message: List[dict], **kwargs) -> AgentReturn:
        if isinstance(message, str):
            message = dict(role='user', content=message)
        if isinstance(message, dict):
            message = [message]
        inner_history = message[:]
        offset = len(inner_history)
        agent_return = AgentReturn()
        last_agent_state = AgentStatusCode.SESSION_READY
        for _ in range(self.max_turn):
            # list of dict
            prompt = self._protocol.format(
                inner_step=inner_history,
                plugin_executor=self._action_executor,
                interpreter_executor=self._interpreter_executor,
            )
            response = ''
            for model_state, res, _ in self._llm.stream_chat(
                    prompt, **kwargs):
                response = res
                if model_state.value < 0:
                    agent_return.state = model_state
                    yield deepcopy(agent_return)
                    return
                else:
                    name, language, action = self._protocol.parse(
                        message=response,
                        plugin_executor=self._action_executor,
                        interpreter_executor=self._interpreter_executor,
                    )
                    if name:
                        if model_state == AgentStatusCode.END:
                            agent_state = last_agent_state + 1
                            if name == 'plugin':
                                if self._action_executor:
                                    executor = self._action_executor
                                else:
                                    logging.info(
                                        msg='No plugin is instantiated!')
                                    continue
                                try:
                                    action = json.loads(action)
                                except Exception as e:
                                    logging.info(msg=f'Invalid action {e}')
                                    continue
                            elif name == 'interpreter':
                                if self._interpreter_executor:
                                    executor = self._interpreter_executor
                                else:
                                    logging.info(
                                        msg='No interpreter is instantiated!')
                                    continue
                            agent_return.state = agent_state
                            agent_return.response = action
                        else:
                            agent_state = (
                                AgentStatusCode.PLUGIN_START if name
                                == 'plugin' else AgentStatusCode.CODING)
                            if agent_state != last_agent_state:
                                # agent_return.state = agent_state
                                agent_return.response = language
                                yield deepcopy(agent_return)
                            agent_return.state = agent_state
                            agent_return.response = action
                    else:
                        agent_state = AgentStatusCode.STREAM_ING
                        agent_return.state = agent_state
                        agent_return.response = language
                    last_agent_state = agent_state
                    yield deepcopy(agent_return)
            if name:
                action_return: ActionReturn = executor(action['name'],
                                                       action['parameters'])
                action_return.thought = language
                agent_return.actions.append(action_return)
            inner_history.append(dict(role='language', content=language))
            if not name or action_return.type == executor.finish_action.name:
                agent_return.response = language
                agent_return.state = AgentStatusCode.END
                break
            else:
                inner_history.append(
                    dict(role='tool', content=action, name=name))
                inner_history.append(
                    self._protocol.format_response(action_return, name=name))
                agent_state += 1
                agent_return.state = agent_state
                yield agent_return
        agent_return.inner_steps = deepcopy(inner_history[offset:])
        agent_return.state = AgentStatusCode.END
        yield agent_return
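The protocol's `parse` method is the pivot of both `chat` and `stream_chat`: it decides whether a completion is plain language, a plugin call, or interpreter code. A small sketch (the raw completion below is made up; the module path follows the new file's location):

```python
from lagent.agents.internlm2_agent import Internlm2Protocol

protocol = Internlm2Protocol()
response = ('I will look that up.'
            '<|action_start|><|plugin|>{"name": "GoogleSearch.run", '
            '"parameters": {"query": "weather in Shanghai"}}<|action_end|>\n')
name, thought, action = protocol.parse(
    response, plugin_executor=None, interpreter_executor=None)
print(name)     # 'plugin'
print(thought)  # 'I will look that up.'
print(action)   # the JSON string the agent passes to json.loads
```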
@@ -1,4 +1,3 @@
-import copy
from typing import Dict, List, Tuple, Union

from lagent.actions import ActionExecutor
@@ -43,7 +42,7 @@ To use a tool, please use the following format:
The response after utilizing tools should using the following format:
```
{response}the results after call the tool.
-``
+```
If you already know the answer, or you do not need to use tools,
please using the following format to reply:
```
@@ -170,20 +169,22 @@ class ReActProtocol:
            action_input = arg_match[-1]
        return thought, action.strip(), action_input.strip().strip('"')

-    def format_response(self, action_return: ActionReturn) -> str:
+    def format_response(self, action_return: ActionReturn) -> dict:
        """format the final response at current step.

        Args:
            action_return (ActionReturn): return value of the current action.

        Returns:
-            str: the final response at current step.
+            dict: the final response at current step.
        """
        if action_return.state == ActionStatusCode.SUCCESS:
-            response = action_return.result['text']
+            response = action_return.format_result()
        else:
            response = action_return.errmsg
-        return self.response['begin'] + response + self.response['end']
+        return dict(
+            role='system',
+            content=self.response['begin'] + response + self.response['end'])


class ReAct(BaseAgent):
@@ -210,20 +211,27 @@ class ReAct(BaseAgent):
        super().__init__(
            llm=llm, action_executor=action_executor, protocol=protocol)

-    def chat(self, message: str) -> AgentReturn:
-        self._inner_history = []
-        self._inner_history.append(dict(role='user', content=message))
+    def chat(self, message: Union[str, dict, List[dict]],
+             **kwargs) -> AgentReturn:
+        if isinstance(message, str):
+            inner_history = [dict(role='user', content=message)]
+        elif isinstance(message, dict):
+            inner_history = [message]
+        elif isinstance(message, list):
+            inner_history = message[:]
+        else:
+            raise TypeError(f'unsupported type: {type(message)}')
+        offset = len(inner_history)
        agent_return = AgentReturn()
        default_response = 'Sorry that I cannot answer your question.'
        for turn in range(self.max_turn):
            prompt = self._protocol.format(
-                chat_history=self.session_history,
-                inner_step=self._inner_history,
+                chat_history=[],
+                inner_step=inner_history,
                action_executor=self._action_executor,
                force_stop=(turn == self.max_turn - 1))
-            response = self._llm.generate_from_template(prompt, 512)
-            self._inner_history.append(
-                dict(role='assistant', content=response))
+            response = self._llm.chat(prompt, **kwargs)
+            inner_history.append(dict(role='assistant', content=response))
            thought, action, action_input = self._protocol.parse(
                response, self._action_executor)
            action_return: ActionReturn = self._action_executor(
@@ -231,17 +239,10 @@ class ReAct(BaseAgent):
            action_return.thought = thought
            agent_return.actions.append(action_return)
            if action_return.type == self._action_executor.finish_action.name:
-                agent_return.response = action_return.result['text']
+                agent_return.response = action_return.format_result()
                break
-            self._inner_history.append(
-                dict(
-                    role='system',
-                    content=self._protocol.format_response(action_return)))
+            inner_history.append(self._protocol.format_response(action_return))
        else:
            agent_return.response = default_response
-        agent_return.inner_steps = copy.deepcopy(self._inner_history)
-        # only append the user and final response
-        self._session_history.append(dict(role='user', content=message))
-        self._session_history.append(
-            dict(role='assistant', content=agent_return.response))
+        agent_return.inner_steps = inner_history[offset:]
        return agent_return
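The widened signature means callers can now hand `chat` either a bare string or ready-made chat messages. A sketch under the package's usual construction pattern (the `ReAct` and `PythonInterpreter` import paths and the API key are placeholders, not taken from this diff):

```python
from lagent.agents import ReAct                       # assumed export path
from lagent.actions import ActionExecutor, PythonInterpreter
from lagent.llms import GPTAPI

llm = GPTAPI(model_type='gpt-3.5-turbo', key=['Your OPENAI_API_KEY'])
agent = ReAct(llm=llm,
              action_executor=ActionExecutor(actions=[PythonInterpreter()]))

# Both forms are accepted after this change.
print(agent.chat('What is 2**10?').response)
print(agent.chat([dict(role='user', content='What is 2**10?')]).response)
```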
@@ -1,4 +1,3 @@
-import copy
import re
import warnings
from typing import Dict, List, Optional, Tuple, Union
@@ -192,7 +191,7 @@ class ReWOOProtocol:
        worker_log = ''
        for thought, action_return in zip(thought_list, action_return_list):
            if action_return.state == ActionStatusCode.SUCCESS:
-                action_resp = action_return.result['text']
+                action_resp = action_return.format_result()
            else:
                action_resp = action_return.errmsg
            worker_response = self.worker_prompt.format(
@@ -227,9 +226,17 @@ class ReWOO(BaseAgent):

        self.max_turn = max_turn

-    def chat(self, message: str) -> AgentReturn:
-        self._inner_history = []
-        self._inner_history.append(dict(role='user', content=message))
+    def chat(self, message: Union[str, dict, List[dict]],
+             **kwargs) -> AgentReturn:
+        if isinstance(message, str):
+            inner_history = [dict(role='user', content=message)]
+        elif isinstance(message, dict):
+            inner_history = [message]
+        elif isinstance(message, list):
+            inner_history = message[:]
+        else:
+            raise TypeError(f'unsupported type: {type(message)}')
+        offset = len(inner_history)
        agent_return = AgentReturn()

        # planner
@@ -237,13 +244,12 @@ class ReWOO(BaseAgent):
        reformat_request = ''
        while turn_id < self.max_turn:
            planner_prompt = self._protocol.format_planner(
-                chat_history=self.session_history,
-                inner_step=self._inner_history,
+                chat_history=[],
+                inner_step=inner_history,
                action_executor=self._action_executor,
                reformat_request=reformat_request)
-            response = self._llm.generate_from_template(planner_prompt, 512)
-            self._inner_history.append(
-                dict(role='assistant', content=response))
+            response = self._llm.chat(planner_prompt, **kwargs)
+            inner_history.append(dict(role='assistant', content=response))
            try:
                thoughts, actions, actions_input = self._protocol.parse_worker(
                    response)
@@ -267,18 +273,17 @@ class ReWOO(BaseAgent):
            for prev_ptr in prev_ptrs:
                ptr_num = int(prev_ptr.strip('#E')) - 1  # start from 0
                actions_input[action_id] = actions_input[action_id].replace(
-                    prev_ptr, action_responses[ptr_num].result['text'])
+                    prev_ptr, action_responses[ptr_num].format_result())
            action_return: ActionReturn = self._action_executor(
                actions[action_id], actions_input[action_id])
            action_responses.append(action_return)

        solver_prompt, worker_log = self._protocol.format_solver(
            message, thoughts, action_responses)
-        self._inner_history.append(dict(role='system', content=worker_log))
+        inner_history.append(dict(role='system', content=worker_log))

-        final_response = self._llm.generate_from_template(solver_prompt, 512)
-        self._inner_history.append(
-            dict(role='assistant', content=final_response))
-        agent_return.inner_steps = copy.deepcopy(self._inner_history)
+        final_response = self._llm.chat(solver_prompt, **kwargs)
+        inner_history.append(dict(role='assistant', content=final_response))
+        agent_return.inner_steps = inner_history[offset:]
        agent_return.response = final_response
        return agent_return
@@ -8,7 +8,3 @@ __all__ = ['BaseModel', 'BaseAPIModel', 'GPTAPI']
if is_module_exist('transformers'):
    from .huggingface import HFTransformer, HFTransformerCasualLM  # noqa: F401
    __all__.extend(['HFTransformer', 'HFTransformerCasualLM'])
-
-if is_module_exist('lmdeploy'):
-    from .lmdeploy import TritonClient, TurboMind  # noqa: F401
-    __all__.extend(['TritonClient', 'TurboMind'])
@@ -27,7 +27,7 @@ class APITemplateParser:
            'role in meta prompt must be unique!'
        self.roles[item['role']] = item.copy()

-    def parse_template(self, dialog: List[Union[str, List]]):
+    def __call__(self, dialog: List[Union[str, List]]):
        """Parse the intermediate prompt template, and wrap it with meta
        template if applicable. When the meta template is set and the input is
        a list, the return value will be a list containing the full
@@ -155,7 +155,14 @@ class BaseAPIModel(BaseModel):
                 retry: int = 2,
                 max_seq_len: int = 2048,
                 template_parser: 'APITemplateParser' = APITemplateParser,
-                 meta_template: Optional[Dict] = None):
+                 meta_template: Optional[Dict] = None,
+                 *,
+                 max_out_len: int = 512,
+                 top_p: float = 0.8,
+                 top_k: float = None,
+                 temperature: float = 0.8,
+                 repetition_penalty: float = 0.0,
+                 stop_words: Union[List[str], str] = None):
        self.model_type = model_type
        self.max_seq_len = max_seq_len
        self.meta_template = meta_template
@@ -165,53 +172,21 @@ class BaseAPIModel(BaseModel):
        if template_parser:
            self.template_parser = template_parser(meta_template)

-    @abstractclassmethod
-    def generate(self, inputs, max_out_len: int) -> List[str]:
-        """Generate results given a list of inputs.
-
-        Args:
-            inputs (List[str or list]): A list of strings or PromptDicts.
-                The PromptDict should be organized in OpenCompass'
-                API format.
-            max_out_len (int): The maximum length of the output.
-
-        Returns:
-            List[str]: A list of generated strings.
-        """
-
-    def get_token_len(self, prompt: str) -> int:
-        """Get lengths of the tokenized string. Only English and Chinese
-        characters are counted for now. Users are encouraged to override this
-        method if more accurate length is needed.
-
-        Args:
-            prompt (str): Input string.
-
-        Returns:
-            int: Length of the input tokens
-        """
-
-        english_parts = re.findall(r'[A-Za-z0-9]+', prompt)
-        chinese_parts = re.findall(r'[\u4e00-\u9FFF]+', prompt)
-
-        # Count English words
-        english_count = sum(len(part.split()) for part in english_parts)
-
-        # Count Chinese words
-        chinese_count = sum(len(part) for part in chinese_parts)
-
-        return english_count + chinese_count
+        self.gen_params = dict(
+            max_out_len=max_out_len,
+            top_p=top_p,
+            top_k=top_k,
+            temperature=temperature,
+            repetition_penalty=repetition_penalty,
+            stop_words=stop_words)

-    def wait(self):
+    def _wait(self):
        """Wait till the next query can be sent.

        Applicable in both single-thread and multi-thread environments.
        """
        return self.token_bucket.get_token()

-    def to(self, device):
-        pass
-

class TokenBucket:
    """A token bucket for rate limiting.
@@ -1,5 +1,7 @@
from abc import abstractclassmethod
+from copy import copy
from typing import Dict, List, Optional, Tuple, Union
+from warnings import warn


class LMTemplateParser:
@@ -21,7 +23,7 @@ class LMTemplateParser:
                'role in meta prompt must be unique!'
            self.roles[item['role']] = item.copy()

-    def parse_template(self, dialog) -> str:
+    def __call__(self, dialog) -> str:
        """Parse a prompt template, and wrap it with meta template if
        applicable.
@@ -111,12 +113,17 @@ class BaseModel:

    def __init__(self,
                 path: str,
                 max_seq_len: int = 2048,
                 tokenizer_only: bool = False,
                 template_parser: 'LMTemplateParser' = LMTemplateParser,
-                 meta_template: Optional[List[Dict]] = None):
+                 meta_template: Optional[List[Dict]] = None,
+                 *,
+                 max_tokens: int = 512,
+                 top_p: float = 0.8,
+                 top_k: float = None,
+                 temperature: float = 0.8,
+                 repetition_penalty: float = 1.0,
+                 stop_words: Union[List[str], str] = None):
        self.path = path
        self.max_seq_len = max_seq_len
        self.tokenizer_only = tokenizer_only
        # meta template
        self.template_parser = template_parser(meta_template)
@@ -124,41 +131,99 @@ class BaseModel:
        if meta_template and 'eos_token_id' in meta_template:
            self.eos_token_id = meta_template['eos_token_id']

-    @abstractclassmethod
-    def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
-        """Generate results given a list of inputs.
+        self.gen_params = dict(
+            max_tokens=max_tokens,
+            top_p=top_p,
+            top_k=top_k,
+            temperature=temperature,
+            repetition_penalty=repetition_penalty,
+            stop_words=stop_words)
+
+    def generate(self, inputs: Union[str, List[str]], **gen_params) -> str:
+        """Generate results given a str (or list of) inputs.

        Args:
-            inputs (List[str]): A list of strings.
-            max_out_len (int): The maximum length of the output.
+            inputs (Union[str, List[str]]):
+            gen_params (dict): The input params for generation.

        Returns:
-            List[str]: A list of generated strings.
-        """
+            Union[str, List[str]]: A (list of) generated strings.

-    def parse_template(self, dialog) -> str:
-        """Parse a prompt template, and wrap it with meta template if
-        applicable.
+        eg.
+            batched = True
+            if isinstance(inputs, str):
+                inputs = [inputs]
+                batched = False
+            response = ['']
+            if batched:
+                return response
+            return response[0]
+        """
+        raise NotImplementedError
+
+    def stream_generate(self, inputs: str, **gen_params) -> List[str]:
+        """Generate results as streaming given a str inputs.

        Args:
-            dialog (List[str or PromptList]): A prompt
-                template (potentially before being wrapped by meta template).
-            mode (str): Parsing mode. Choices are 'ppl' and 'gen'.
+            inputs (str):
+            gen_params (dict): The input params for generation.

        Returns:
-            str: The final string.
+            str: A generated string.
        """
-        return self.template_parser.parse_template(dialog)
+        raise NotImplementedError

-    def generate_from_template(self, templates, max_out_len: int, **kwargs):
+    def chat(self, inputs: Union[List[dict], List[List[dict]]], **gen_params):
        """Generate completion from a list of templates.

        Args:
-            templates (List[PromptType]): A list of templates.
-            max_out_len (int): The maximum length of the output.
+            inputs (Union[List[dict], List[List[dict]]]):
+            gen_params (dict): The input params for generation.
        Returns:
        """
-        inputs = self.parse_template(templates)
-        return self.generate(inputs, max_out_len=max_out_len, **kwargs)
+        if isinstance(inputs[0], list):
+            _inputs = list()
+            for msg in inputs:
+                _inputs.append(self.template_parser(msg))
+        else:
+            _inputs = self.template_parser(inputs)
+        return self.generate(_inputs, **gen_params)

-    def to(self, device):
-        self.model.to(device)
+    def generate_from_template(
+            self,
+            inputs: Union[List[dict], List[List[dict]]],
+            **gen_params):
+        warn(
+            'This function will be deprecated after three months and will be '
+            'replaced. Please use `.chat()`',
+            DeprecationWarning, 2)
+        return self.chat(inputs, **gen_params)
+
+    def stream_chat(self, inputs: List[dict], **gen_params):
+        """Generate results as streaming given a list of templates.
+
+        Args:
+            inputs (List[dict]):
+            gen_params (dict): The input params for generation.
+        Returns:
+        """
+        raise NotImplementedError
+
+    def tokenize(self, prompts: Union[str, List[str], List[dict],
+                                      List[List[dict]]]):
+        """Tokenize the input prompts.
+
+        Args:
+            prompts (str | List[str]): user's prompt, or a batch of prompts
+
+        Returns:
+            Tuple(numpy.ndarray, numpy.ndarray, numpy.ndarray): prompt's token
+            ids, ids' length and requested output length
+        """
+        raise NotImplementedError
+
+    def update_gen_params(self, **kwargs):
+        gen_params = copy(self.gen_params)
+        gen_params.update(kwargs)
+        return gen_params
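The net effect of this refactor: sampling defaults live in `gen_params`, message-style inputs go through `.chat()`, and `generate_from_template()` survives only as a deprecated alias. A sketch of the new calling convention (the concrete model object is illustrative):

```python
# Hypothetical usage of the new BaseModel.chat() interface; `model` is any
# concrete subclass (e.g. an HF or LMDeploy wrapper defined below).
messages = [dict(role='user', content='Hi, please introduce yourself.')]

# Per-call overrides are merged into the defaults via update_gen_params().
reply = model.chat(messages, max_tokens=256, temperature=0.7)
print(reply)

# A batch of conversations returns a list of replies.
replies = model.chat([messages, messages], temperature=0.2)
```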
@@ -1,15 +1,19 @@
-from typing import Dict, List, Optional
-
-import torch
+import copy
+import warnings
+import logging
+from typing import Dict, List, Optional, Union
+from dataclasses import asdict

from .base_llm import BaseModel

+logger = logging.getLogger(__name__)


class HFTransformer(BaseModel):
    """Model wrapper around HuggingFace general models.

-    Adapted from OpenCompass (https://github.com/InternLM/opencompass
-    /blob/main/opencompass/models/huggingface.py)
+    Adapted from Internlm (https://github.com/InternLM/InternLM/blob/main/
+    chat/web_demo.py)

    Args:
        path (str): The name or path to HuggingFace's model.
@@ -25,52 +29,40 @@ class HFTransformer(BaseModel):
        meta_template (Dict, optional): The model's meta prompt
            template if needed, in case the requirement of injecting or
            wrapping of any meta instructions.
-        extract_pred_after_decode (bool): Whether to extract the prediction
-            string from the decoded output string, instead of extracting the
-            prediction tokens before decoding. Defaults to False.
-        batch_padding (bool): If False, inference will be performed in a
-            for-loop without batch padding.
-
-    Note:
-        About ``extract_pred_after_decode``: Commonly, we should extract the
-        prediction tokens before decoding. But for some tokenizers using
-        ``sentencepiece``, like LLaMA, this behavior may change the number of
-        whitespaces, which is harmful for Python programming tasks.
    """

-    def __init__(
-            self,
-            path: str,
-            max_seq_len: int = 2048,
-            tokenizer_path: Optional[str] = None,
-            tokenizer_kwargs: dict = dict(),
-            tokenizer_only: bool = False,
-            model_kwargs: dict = dict(device_map='auto'),
-            meta_template: Optional[Dict] = [
-                dict(role='system', begin='<|System|>:', end='\n'),
-                dict(role='user', begin='<|User|>:', end='\n'),
-                dict(
-                    role='assistant',
-                    begin='<|Bot|>:',
-                    end='<eoa>\n',
-                    generate=True)
-            ],  # default meta template for InternLM-7b
-            extract_pred_after_decode: bool = False,
-            batch_padding: bool = False):
+    def __init__(self,
+                 path: str,
+                 tokenizer_path: Optional[str] = None,
+                 tokenizer_kwargs: dict = dict(),
+                 tokenizer_only: bool = False,
+                 model_kwargs: dict = dict(device_map='auto'),
+                 meta_template: Optional[Dict] = None,
+                 **kwargs):
        super().__init__(
            path=path,
-            max_seq_len=max_seq_len,
            tokenizer_only=tokenizer_only,
-            meta_template=meta_template)
+            meta_template=meta_template,
+            **kwargs)

        self._load_tokenizer(
            path=path,
            tokenizer_path=tokenizer_path,
            tokenizer_kwargs=tokenizer_kwargs)
-        self.batch_padding = batch_padding
-        self.extract_pred_after_decode = extract_pred_after_decode
        if not tokenizer_only:
            self._load_model(path=path, model_kwargs=model_kwargs)

+        from transformers.generation.utils import (LogitsProcessorList,
+                                                   StoppingCriteriaList)
+        self.logits_processor = LogitsProcessorList()
+        self.stopping_criteria = StoppingCriteriaList()
+        self.prefix_allowed_tokens_fn = None
+
+        stop_words_id = []
+        for sw in self.gen_params.get('stop_words', []):
+            stop_words_id.append(self.tokenizer(sw)['input_ids'][1])
+        self.additional_eos_token_id = stop_words_id

    def _load_tokenizer(self, path: str, tokenizer_path: Optional[str],
                        tokenizer_kwargs: dict):
        from transformers import AutoTokenizer
@@ -82,57 +74,182 @@ class HFTransformer(BaseModel):
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def _load_model(self, path: str, model_kwargs: dict):
        import torch
        from transformers import AutoModel
        model_kwargs.setdefault('torch_dtype', torch.float16)
        self.model = AutoModel.from_pretrained(
            path, trust_remote_code=True, **model_kwargs)
        self.model.eval()

-    def generate(self, inputs: List[str], max_out_len: int,
-                 **kwargs) -> List[str]:
-        if isinstance(inputs, str):
-            inputs = [inputs]
-        if self.extract_pred_after_decode:
-            prompt_lens = [len(input_) for input_ in inputs]
+    def tokenize(self, inputs: str):
+        assert isinstance(inputs, str)
+        inputs = self.tokenizer(
+            inputs, return_tensors='pt', return_length=True)
+        return inputs['input_ids'].tolist()

-        input_ids = self.tokenizer(
-            inputs, truncation=True,
-            max_length=self.max_seq_len - max_out_len)['input_ids']
-        input_ids = torch.tensor(input_ids, device=self.model.device)
-        outputs = self.model.generate(
-            input_ids=input_ids, max_new_tokens=max_out_len, **kwargs)
+    def generate(
+        self,
+        inputs: List[str],
+        do_sample=True,
+        **kwargs,
+    ):
+        for chunk in self.stream_generate(inputs, do_sample, **kwargs):
+            response = chunk
+        return response

-        if not self.extract_pred_after_decode:
-            outputs = outputs[:, input_ids.shape[1]:]
+    def stream_generate(
+        self,
+        inputs: List[str],
+        do_sample=True,
+        **kwargs,
+    ):
+        import torch
+        from torch import nn
+        with torch.no_grad():
+            batched = True
+            if isinstance(inputs, str):
+                inputs = [inputs]
+                batched = False
+            # import pdb; pdb.set_trace()
+            inputs = self.tokenizer(
+                inputs, padding=True, return_tensors='pt', return_length=True)
+            input_length = inputs['length']
+            for k, v in inputs.items():
+                inputs[k] = v.cuda()
+            input_ids = inputs['input_ids']
+            attention_mask = inputs['attention_mask']
+            batch_size, input_ids_seq_length = input_ids.shape[
+                0], input_ids.shape[-1]  # noqa: F841  # pylint: disable=W0612
+            generation_config = self.model.generation_config
+            generation_config = copy.deepcopy(generation_config)
+            new_gen_params = self.update_gen_params(**kwargs)
+            generation_config.update(**new_gen_params)
+            generation_config.update(**kwargs)
+            model_kwargs = generation_config.to_dict()
+            model_kwargs['attention_mask'] = attention_mask
+            _, eos_token_id = (  # noqa: F841  # pylint: disable=W0612
+                generation_config.bos_token_id,
+                generation_config.eos_token_id,
+            )
+            if isinstance(eos_token_id, int):
+                eos_token_id = [eos_token_id]
+            if self.additional_eos_token_id is not None:
+                eos_token_id.extend(self.additional_eos_token_id)
+            eos_token_id_tensor = torch.tensor(eos_token_id).to(
+                input_ids.device) if eos_token_id is not None else None
+            has_default_max_length = (
+                kwargs.get('max_length') is None
+                and generation_config.max_length is not None)
+            if has_default_max_length and generation_config.max_new_tokens is None:
+                warnings.warn(
+                    f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
+                    'This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we'
+                    ' recommend using `max_new_tokens` to control the maximum length of the generation.',
+                    UserWarning,
+                )
+            elif generation_config.max_new_tokens is not None:
+                generation_config.max_length = (
+                    generation_config.max_new_tokens + input_ids_seq_length)
+                if not has_default_max_length:
+                    logger.warn(  # pylint: disable=W4902
+                        f'Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(='
+                        f'{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. '
+                        'Please refer to the documentation for more information. '
+                        '(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)',
+                        UserWarning,
+                    )

-        decodeds = self.tokenizer.batch_decode(
-            outputs, skip_special_tokens=True)
-        if self.extract_pred_after_decode:
-            decodeds = [
-                token[len_:] for token, len_ in zip(decodeds, prompt_lens)
-            ]
+            if input_ids_seq_length >= generation_config.max_length:
+                input_ids_string = 'input_ids'
+                logger.warning(
+                    f'Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to'
+                    f' {generation_config.max_length}. This can lead to unexpected behavior. You should consider'
+                    ' increasing `max_new_tokens`.')

-        return decodeds[0]
+            # 2. Set generation parameters if not already defined
+            logits_processor = self.logits_processor
+            stopping_criteria = self.stopping_criteria

-    def generate_from_template(self, templates, max_out_len: int, **kwargs):
-        """Generate completion from a list of templates.
+            logits_processor = self.model._get_logits_processor(
+                generation_config=generation_config,
+                input_ids_seq_length=input_ids_seq_length,
+                encoder_input_ids=input_ids,
+                prefix_allowed_tokens_fn=self.prefix_allowed_tokens_fn,
+                logits_processor=logits_processor,
+            )

-        Args:
-            templates (List[PromptType]): A list of templates.
-            max_out_len (int): The maximum length of the output.
-        """
-        inputs = self.parse_template(templates)
-        response = self.generate(inputs, max_out_len=max_out_len, **kwargs)
-        end_token = self.template_parser.meta_template[0]['end'].strip()
-        # return response.replace(
-        #     self.template_parser.roles['assistant']['end'].strip(),
-        #     '').strip()
-        return response.split(end_token.strip())[0]
+            stopping_criteria = self.model._get_stopping_criteria(
+                generation_config=generation_config,
+                stopping_criteria=stopping_criteria)
+            logits_warper = self.model._get_logits_warper(generation_config)

+            unfinished_sequences = input_ids.new(batch_size).fill_(1)
+            scores = None
+            while True:
+                model_inputs = self.model.prepare_inputs_for_generation(
+                    input_ids, **model_kwargs)
+                # forward pass to get next token
+                outputs = self.model(
+                    **model_inputs,
+                    return_dict=True,
+                    output_attentions=False,
+                    output_hidden_states=False,
+                )
+
+                next_token_logits = outputs.logits[:, -1, :]
+
+                # pre-process distribution
+                next_token_scores = logits_processor(input_ids,
+                                                     next_token_logits)
+                next_token_scores = logits_warper(input_ids, next_token_scores)
+
+                # sample
+                probs = nn.functional.softmax(next_token_scores, dim=-1)
+                if do_sample:
+                    next_tokens = torch.multinomial(
+                        probs, num_samples=1).squeeze(1)
+                else:
+                    next_tokens = torch.argmax(probs, dim=-1)
+
+                # update generated ids, model inputs, and length for next step
+                input_ids = torch.cat([input_ids, next_tokens[:, None]],
+                                      dim=-1)
+                model_kwargs = self.model._update_model_kwargs_for_generation(
+                    outputs, model_kwargs, is_encoder_decoder=False)
+                unfinished_sequences = unfinished_sequences.mul(
+                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(
+                        eos_token_id_tensor.unsqueeze(1)).prod(dim=0))
+                # output_token_ids = input_ids.cpu()[:, input_length:].tolist()
+                output_token_ids = input_ids.cpu().tolist()
+                for i in range(len(output_token_ids)):
+                    output_token_ids[i] = output_token_ids[i][:][
+                        input_length[i]:]
+                    # Find the first occurrence of an EOS token in the sequence
+                    first_eos_idx = next(
+                        (idx
+                         for idx, token_id in enumerate(output_token_ids[i])
+                         if token_id in eos_token_id), None)
+                    # If an EOS token is found, keep only the part before it
+                    if first_eos_idx is not None:
+                        output_token_ids[i] = output_token_ids[
+                            i][:first_eos_idx]
+
+                response = self.tokenizer.batch_decode(output_token_ids)
+                # print(response)
+                if not batched:
+                    yield response[0]
+                else:
+                    yield response
+                # stop when each sentence is finished, or if we exceed the maximum length
+                if (unfinished_sequences.max() == 0
+                        or stopping_criteria(input_ids, scores)):
+                    break


class HFTransformerCasualLM(HFTransformer):

    def _load_model(self, path: str, model_kwargs: dict):
        import torch
        from transformers import AutoModelForCausalLM
        model_kwargs.setdefault('torch_dtype', torch.float16)
        self.model = AutoModelForCausalLM.from_pretrained(
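With this change, `generate()` simply drains `stream_generate()`, and each yielded chunk is the full decoded text so far. A sketch of the intended usage, assuming a CUDA device and an illustrative model id:

```python
# Hypothetical usage of the streaming HF wrapper above (requires a GPU,
# since stream_generate() moves tensors with .cuda()).
model = HFTransformerCasualLM('internlm/internlm-chat-7b', max_tokens=128)
for text_so_far in model.stream_generate('Write a haiku about autumn.'):
    print(text_so_far)  # progressively longer decoded text
final = model.generate('Write a haiku about autumn.')  # last chunk only
```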
@@ -1,143 +0,0 @@
import dataclasses
import os.path as osp
import random

import lmdeploy.turbomind.chat as tm_chat
from lmdeploy import turbomind as tm
from lmdeploy.serve.turbomind.chatbot import Chatbot, Session, get_logger
from lmdeploy.tokenizer import Tokenizer

from .base_llm import BaseModel


class TritonClient(Chatbot, BaseModel):

    def __init__(self, meta_template=None, **kwargs):
        """TritonClient is a wrapper of TritonClient for LLM.

        Args:
            model_name (str): the name of the model
            max_out_len (int): the expected generated token numbers
            log_level (str): log level
        """
        BaseModel.__init__(self, meta_template=meta_template, path=None)
        Chatbot.__init__(self, **kwargs)

    def generate(self,
                 prompt: str,
                 session_id: int = 2967,
                 request_id: str = '',
                 max_out_len: int = None,
                 sequence_start: bool = True,
                 sequence_end: bool = True,
                 *args,
                 **kwargs):
        """Start a new round conversation of a session. Return the chat
        completions in non-stream mode.

        Args:
            session_id (int): the identical id of a session
            prompt (str): user's prompt in this round conversation
            request_id (str): the identical id of this round conversation
            max_out_len (int): the expected generated token numbers
            sequence_start (bool): start flag of a session
            sequence_end (bool): end flag of a session

        Returns:
            tuple(Status, str, int): status, text/chat completion,
            generated token number
        """
        assert isinstance(session_id, int), \
            f'INT session id is required, but got {type(session_id)}'

        logger = get_logger(log_level=self.log_level)
        logger.info(f'session {session_id}, request_id {request_id}, '
                    f'max_out_len {max_out_len}')

        if self._session is None:
            sequence_start = True
            self._session = Session(session_id=session_id)
        elif self._session.status == 0:
            logger.error(f'session {session_id} has been ended. Please set '
                         f'`sequence_start` to True if you want to restart it')
            return ''

        self._session.status = 1
        self._session.request_id = request_id
        self._session.response = ''

        status, res, _ = None, '', 0
        for status, res, _ in self._stream_infer(self._session, prompt,
                                                 max_out_len, sequence_start,
                                                 sequence_end):
            if status.value < 0:
                break
        if status.value == 0:
            self._session.histories = \
                self._session.histories + self._session.prompt + \
                self._session.response
            return res
        else:
            return ''

    def generate_from_template(self, templates, max_out_len: int, **kwargs):
        """Generate completion from a list of templates.

        Args:
            templates (List[PromptType]): A list of templates.
            max_out_len (int): The maximum length of the output.
        """
        inputs = self.parse_template(templates)
        response = self.generate(inputs, max_out_len=max_out_len, **kwargs)
        # The return of turbomind contains <eoa>; here we hard-code its removal.
        response = response.replace(
            self.template_parser.roles['assistant']['end'].strip(),
            '').strip()
        return response


class TurboMind(BaseModel):

    def __init__(self,
                 path: str,
                 max_seq_len: int = 8192,
                 tokenizer_only: bool = False,
                 meta_template=None,
                 tp=1,
                 **kwargs):
        super().__init__(
            path=path,
            max_seq_len=max_seq_len,
            tokenizer_only=tokenizer_only,
            meta_template=meta_template)
        tokenizer_model_path = osp.join(path, 'triton_models', 'tokenizer')
        self.tokenizer = Tokenizer(tokenizer_model_path)
        self.tm_model = tm.TurboMind(
            path, eos_id=self.tokenizer.eos_token_id, tp=tp)
        self.generator = self.tm_model.create_instance()

        model_name = self.tm_model.model_name
        self.model = tm_chat.MODELS.get(model_name)(
            capability='completion', **kwargs)
        self._session_id = 0

    def generate(self, prompt, **kwargs):
        seed = random.getrandbits(64)
        input_ids = self.tokenizer.encode(prompt)
        gen_param = tm_chat.get_gen_param(
            'completion', self.model.sampling_param, step=0, nth_round=1)
        response_size = 0
        self._session_id = (self._session_id + 1) % 100000
        for outputs in self.generator.stream_infer(
                session_id=self._session_id,
                input_ids=[input_ids],
                stream_output=False,
                **dataclasses.asdict(gen_param),
                ignore_eos=False,
                random_seed=seed):
            res, tokens = outputs[0]
            # decode res
            response = self.tokenizer.decode(
                res.tolist(), offset=response_size)
            response = tm_chat.valid_str(response)
        return response
383  lagent/llms/lmdepoly_wrapper.py  (new file)
@@ -0,0 +1,383 @@
from typing import List, Optional, Union

from lagent.llms.base_llm import BaseModel
from lagent.schema import AgentStatusCode
from lagent.utils.util import filter_suffix


class TritonClient(BaseModel):
    """TritonClient is a wrapper of TritonClient for LLM.

    Args:
        tritonserver_addr (str): the address in format "ip:port" of
            triton inference server
        model_name (str): the name of the model
        session_len (int): the context size
        max_tokens (int): the expected generated token numbers
    """

    def __init__(self,
                 tritonserver_addr: str,
                 model_name: str,
                 session_len: int = 32768,
                 log_level: str = 'WARNING',
                 **kwargs):
        super().__init__(path=None, **kwargs)
        from lmdeploy.serve.turbomind.chatbot import Chatbot, StatusCode
        self.state_map = {
            StatusCode.TRITON_STREAM_END: AgentStatusCode.END,
            StatusCode.TRITON_SERVER_ERR: AgentStatusCode.SERVER_ERR,
            StatusCode.TRITON_SESSION_CLOSED: AgentStatusCode.SESSION_CLOSED,
            StatusCode.TRITON_STREAM_ING: AgentStatusCode.STREAM_ING,
            StatusCode.TRITON_SESSION_OUT_OF_LIMIT:
            AgentStatusCode.SESSION_OUT_OF_LIMIT,
            StatusCode.TRITON_SESSION_INVALID_ARG:
            AgentStatusCode.SESSION_INVALID_ARG,
            StatusCode.TRITON_SESSION_READY: AgentStatusCode.SESSION_READY
        }
        self.chatbot = Chatbot(
            tritonserver_addr=tritonserver_addr,
            model_name=model_name,
            session_len=session_len,
            log_level=log_level,
            **kwargs)

    def generate(self,
                 inputs: Union[str, List[str]],
                 session_id: int = 2967,
                 request_id: str = '',
                 max_tokens: int = 512,
                 sequence_start: bool = True,
                 sequence_end: bool = True,
                 **kwargs):
        """Start a new round conversation of a session. Return the chat
        completions in non-stream mode.

        Args:
            inputs (str, List[str]): user's prompt(s) in this round
            session_id (int): the identical id of a session
            request_id (str): the identical id of this round conversation
            max_tokens (int): the expected generated token numbers
            sequence_start (bool): start flag of a session
            sequence_end (bool): end flag of a session

        Returns:
            (a list of/batched) text/chat completion
        """
        from lmdeploy.serve.turbomind.chatbot import Session, get_logger
        if isinstance(inputs, str):
            inputs = [inputs]
        prompt = inputs

        assert isinstance(session_id, int), \
            f'INT session id is required, but got {type(session_id)}'

        logger = get_logger(log_level=self.chatbot.log_level)
        logger.info(f'session {session_id}, request_id {request_id}, '
                    f'max_out_len {max_tokens}')

        if self.chatbot._session is None:
            sequence_start = True
            self.chatbot._session = Session(session_id=session_id)
        elif self.chatbot._session.status == 0:
            logger.error(f'session {session_id} has been ended. Please set '
                         f'`sequence_start` to True if you want to restart it')
            return ''

        self.chatbot._session.status = 1
        self.chatbot._session.request_id = request_id
        self.chatbot._session.response = ''

        self.chatbot.cfg = self._update_gen_params(
            max_tokens=max_tokens, **kwargs)

        status, res, _ = None, '', 0
        for status, res, _ in self.chatbot._stream_infer(
                self.chatbot._session, prompt, max_tokens, sequence_start,
                sequence_end):
            if status.value < 0:
                break
        if status.value == 0:
            self.chatbot._session.histories = (
                self.chatbot._session.histories +
                self.chatbot._session.prompt + self.chatbot._session.response)
            # remove stop_words
            res = filter_suffix(res, self.gen_params.get('stop_words'))
            return res
        else:
            return ''

    def stream_chat(self,
                    inputs: List[dict],
                    session_id: int = 2967,
                    request_id: str = '',
                    max_tokens: int = 512,
                    sequence_start: bool = True,
                    sequence_end: bool = True,
                    **kwargs):
        """Start a new round conversation of a session. Return the chat
        completions in stream mode.

        Args:
            session_id (int): the identical id of a session
            inputs (List[dict]): user's inputs in this round conversation
            request_id (str): the identical id of this round conversation
            max_tokens (int): the expected generated token numbers
            sequence_start (bool): start flag of a session
            sequence_end (bool): end flag of a session

        Returns:
            tuple(Status, str, int): status, text/chat completion,
            generated token number
        """
        from lmdeploy.serve.turbomind.chatbot import (Session, StatusCode,
                                                      get_logger)
        assert isinstance(session_id, int), \
            f'INT session id is required, but got {type(session_id)}'

        logger = get_logger(log_level=self.chatbot.log_level)
        logger.info(f'session {session_id}, request_id {request_id}, '
                    f'max_out_len {max_tokens}')

        if self.chatbot._session is None:
            sequence_start = True
            self.chatbot._session = Session(session_id=session_id)
        elif self.chatbot._session.status == 0:
            logger.error(f'session {session_id} has been ended. Please set '
                         f'`sequence_start` to True if you want to restart it')
            return ''

        self.chatbot._session.status = 1
        self.chatbot._session.request_id = request_id
        self.chatbot._session.response = ''

        self.chatbot.cfg = self._update_gen_params(
            max_tokens=max_tokens, **kwargs)
        prompt = self.template_parser(inputs)

        status, res, _ = None, '', 0
        for status, res, _ in self.chatbot._stream_infer(
                self.chatbot._session, prompt, max_tokens, sequence_start,
                sequence_end):
            if status == StatusCode.TRITON_STREAM_END:  # remove stop_words
                res = filter_suffix(res, self.gen_params.get('stop_words'))
            if status.value < 0:
                break
            else:
                yield self.state_map.get(status), res, _
        if status.value == 0:
            self.chatbot._session.histories = (
                self.chatbot._session.histories +
                self.chatbot._session.prompt + self.chatbot._session.response)
            yield self.state_map.get(status), res, _
        else:
            return ''

    def _update_gen_params(self, **kwargs):
        import mmengine
        new_gen_params = self.update_gen_params(**kwargs)
        self.gen_params['stop_words'] = new_gen_params.pop('stop_words')
        stop_words = self.chatbot._stop_words(
            self.gen_params.get('stop_words'))
        cfg = mmengine.Config(
            dict(
                session_len=self.chatbot.model.session_len,
                stop_words=stop_words,
                bad_words=self.chatbot.cfg.bad_words,
                **new_gen_params))
        return cfg


class LMDeployPipeline(BaseModel):
    """
    Args:
        path (str): The path to the model.
            It could be one of the following options:
            - i) A local directory path of a turbomind model which is
                converted by `lmdeploy convert` command or downloaded
                from ii) and iii).
            - ii) The model_id of a lmdeploy-quantized model hosted
                inside a model repo on huggingface.co, such as
                "InternLM/internlm-chat-20b-4bit",
                "lmdeploy/llama2-chat-70b-4bit", etc.
            - iii) The model_id of a model hosted inside a model repo
                on huggingface.co, such as "internlm/internlm-chat-7b",
                "Qwen/Qwen-7B-Chat", "baichuan-inc/Baichuan2-7B-Chat"
                and so on.
        model_name (str): needed when model_path is a pytorch model on
            huggingface.co, such as "internlm/internlm-chat-7b",
            "Qwen/Qwen-7B-Chat", "baichuan-inc/Baichuan2-7B-Chat" and so on.
        tp (int):
        pipeline_cfg (dict):
    """

    def __init__(self,
                 path: str,
                 model_name: Optional[str] = None,
                 tp: int = 1,
                 pipeline_cfg=dict(),
                 **kwargs):

        super().__init__(path=path, **kwargs)
        from lmdeploy import pipeline
        self.model = pipeline(
            model_path=self.path, model_name=model_name, tp=tp, **pipeline_cfg)

    def generate(self,
                 inputs: Union[str, List[str]],
                 do_preprocess=None,
                 **kwargs):
        batched = True
        if isinstance(inputs, str):
            inputs = [inputs]
            batched = False
        prompt = inputs
        gen_params = self.update_gen_params(**kwargs)
        response = self.model.batch_infer(
            prompt, do_preprocess=do_preprocess, **gen_params)
        response = [resp.text for resp in response]
        # remove stop_words
        response = filter_suffix(response, self.gen_params.get('stop_words'))
        if batched:
            return response
        return response[0]


class LMDeployServer(BaseModel):
    """
    Args:
        path (str): The path to the model.
            It could be one of the following options:
            - i) A local directory path of a turbomind model which is
                converted by `lmdeploy convert` command or downloaded
                from ii) and iii).
            - ii) The model_id of a lmdeploy-quantized model hosted
                inside a model repo on huggingface.co, such as
                "InternLM/internlm-chat-20b-4bit",
                "lmdeploy/llama2-chat-70b-4bit", etc.
            - iii) The model_id of a model hosted inside a model repo
                on huggingface.co, such as "internlm/internlm-chat-7b",
                "Qwen/Qwen-7B-Chat", "baichuan-inc/Baichuan2-7B-Chat"
                and so on.
        model_name (str): needed when model_path is a pytorch model on
            huggingface.co, such as "internlm/internlm-chat-7b",
            "Qwen/Qwen-7B-Chat", "baichuan-inc/Baichuan2-7B-Chat" and so on.
        server_name (str): host ip for serving
        server_port (int): server port
        tp (int):
        log_level (str): set log level whose value among
            [CRITICAL, ERROR, WARNING, INFO, DEBUG]
    """

    def __init__(self,
                 path: str,
                 model_name: Optional[str] = None,
                 server_name: str = '0.0.0.0',
                 server_port: int = 23333,
                 tp: int = 1,
                 log_level: str = 'WARNING',
                 serve_cfg=dict(),
                 **kwargs):
        super().__init__(path=path, **kwargs)
        # TODO: get_logger issue in multiprocessing
        import lmdeploy
        self.client = lmdeploy.serve(
            model_path=self.path,
            model_name=model_name,
            server_name=server_name,
            server_port=server_port,
            tp=tp,
            log_level=log_level,
            **serve_cfg)

    def generate(self,
                 inputs: Union[str, List[str]],
                 session_id: int = 2967,
                 sequence_start: bool = True,
                 sequence_end: bool = True,
                 ignore_eos: bool = False,
                 timeout: int = 30,
                 **kwargs) -> List[str]:
        batched = True
        if isinstance(inputs, str):
            inputs = [inputs]
            batched = False

        gen_params = self.update_gen_params(**kwargs)

        resp = [''] * len(inputs)
        for text in self.client.completions_v1(
                self.path,
                inputs,
                session_id=session_id,
                sequence_start=sequence_start,
                sequence_end=sequence_end,
                stream=False,
                ignore_eos=ignore_eos,
                timeout=timeout,
                **gen_params):
            resp = [
                resp[i] + item['text']
                for i, item in enumerate(text['choices'])
            ]
        # remove stop_words
        resp = filter_suffix(resp, self.gen_params.get('stop_words'))
        if not batched:
            return resp[0]
        return resp

    def stream_chat(self,
                    inputs: List[dict],
                    session_id=0,
                    sequence_start: bool = True,
                    sequence_end: bool = True,
                    stream: bool = True,
                    ignore_eos: bool = False,
                    timeout: int = 30,
                    **kwargs):

        gen_params = self.update_gen_params(**kwargs)
        prompt = self.template_parser(inputs)

        resp = ''
        finished = False
        stop_words = self.gen_params.get('stop_words')
        for text in self.client.completions_v1(
                self.path,
                prompt,
                session_id=session_id,
                sequence_start=sequence_start,
                sequence_end=sequence_end,
                stream=stream,
                ignore_eos=ignore_eos,
                timeout=timeout,
                **gen_params):
            resp += text['choices'][0]['text']
            if not resp:
                continue
            # remove stop_words
            for sw in stop_words:
                if sw in resp:
                    resp = filter_suffix(resp, stop_words)
                    finished = True
                    break
            yield AgentStatusCode.STREAM_ING, resp, None
            if finished:
                break
        yield AgentStatusCode.END, resp, None


class LMDeployClient(LMDeployServer):
    """
    Args:
        path (str): The path to the model.
        url (str):
    """

    def __init__(self, path: str, url: str, **kwargs):
        BaseModel.__init__(self, path=path, **kwargs)
        from lmdeploy.serve.openai.api_client import APIClient
        self.client = APIClient(url)
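Of the three wrappers, LMDeployPipeline is the simplest to try locally; a sketch under the assumption that lmdeploy is installed (model id illustrative):

```python
# Hypothetical local usage of the LMDeployPipeline wrapper above.
llm = LMDeployPipeline(
    path='internlm/internlm-chat-7b',  # HF model id or converted model dir
    tp=1,
    stop_words=['<eoa>'],  # suffixes stripped from outputs via filter_suffix
)
print(llm.generate('Introduce the InternLM project in one sentence.'))
```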
@@ -1,7 +1,7 @@
import json
import os
import time
-# from concurrent.futures import ThreadPoolExecutor
+from concurrent.futures import ThreadPoolExecutor, wait
from logging import getLogger
from threading import Lock
from typing import Dict, List, Optional, Union
@@ -38,9 +38,8 @@ class GPTAPI(BaseAPIModel):
            wrapping of any meta instructions.
        openai_api_base (str): The base url of OpenAI's API. Defaults to
            'https://api.openai.com/v1/chat/completions'.
-        temperature (float, optional): What sampling temperature to use.
-            If not None, will override the temperature in the `generate()`
-            call. Defaults to None.
+        gen_params: Default generation configuration which could be
+            overridden on the fly during generation.
    """

    is_api: bool = True
@@ -58,16 +57,15 @@ class GPTAPI(BaseAPIModel):
                dict(role='assistant', api_role='assistant')
            ],
            openai_api_base: str = OPENAI_API_BASE,
-            temperature: Optional[float] = None):
+            **gen_params):
        super().__init__(
            model_type=model_type,
            max_seq_len=max_seq_len,
            meta_template=meta_template,
            query_per_second=query_per_second,
-            retry=retry)
+            retry=retry,
+            **gen_params)
        self.logger = getLogger(__name__)
-        self.temperature = temperature

        if isinstance(key, str):
            self.keys = [os.getenv('OPENAI_API_KEY') if key == 'ENV' else key]
@@ -97,67 +95,57 @@ class GPTAPI(BaseAPIModel):
            context_window = 8192
        self.context_window = context_window

-    def generate(
+    def chat(
        self,
-        inputs: Union[List, str],
-        max_out_len: int = 512,
-        temperature: float = 0.7,
-    ) -> List[str]:
-        """Generate results given a list of inputs.
+        inputs: Union[List[dict], List[List[dict]]],
+        **gen_params,
+    ) -> Union[str, List[str]]:
+        """Generate responses given the contexts.

        Args:
-            inputs (List[str or List]): A list of strings or PromptDicts.
-                The PromptDict should be organized in OpenCompass'
-                API format.
-            max_out_len (int): The maximum length of the output.
-            temperature (float): What sampling temperature to use,
-                between 0 and 2. Higher values like 0.8 will make the output
-                more random, while lower values like 0.2 will make it more
-                focused and deterministic. Defaults to 0.7.
+            inputs (Union[List[dict], List[List[dict]]]): a list of messages
+                or list of lists of messages
+            gen_params: additional generation configuration

        Returns:
-            List[str]: A list of generated strings.
+            Union[str, List[str]]: generated string(s)
        """
-        if self.temperature is not None:
-            temperature = self.temperature
-        return self._generate(inputs, max_out_len, temperature)
+        assert isinstance(inputs, list)
+        gen_params = {**self.gen_params, **gen_params}
+        with ThreadPoolExecutor(max_workers=20) as executor:
+            tasks = [
+                executor.submit(self._chat, messages, **gen_params)
+                for messages in ([inputs]
+                                 if isinstance(inputs[0], dict) else inputs)
+            ]
+        wait(tasks)
+        ret = [task.result() for task in tasks]
+        return ret[0] if isinstance(inputs[0], dict) else ret

-    def _generate(self, input: str or List, max_out_len: int,
-                  temperature: float) -> str:
-        """Generate results given a list of inputs.
+    def _chat(self, messages: List[dict], **gen_params) -> str:
+        """Generate completion from a list of templates.

        Args:
-            inputs (str or List): A string or PromptDict.
-                The PromptDict should be organized in OpenCompass'
-                API format.
-            max_out_len (int): The maximum length of the output.
-            temperature (float): What sampling temperature to use,
-                between 0 and 2. Higher values like 0.8 will make the output
-                more random, while lower values like 0.2 will make it more
-                focused and deterministic.
+            messages (List[dict]): a list of prompt dictionaries
+            gen_params: additional generation configuration

        Returns:
            str: The generated string.
        """
-        assert isinstance(input, (str, list, dict))
-
-        if isinstance(input, str):
-            messages = [{'role': 'user', 'content': input}]
-        elif isinstance(input, dict):
-            messages = [input]
-        else:
-            messages = input
+        assert isinstance(messages, list)
+        gen_params = gen_params.copy()

        # Hold out 100 tokens due to potential errors in tiktoken calculation
        max_out_len = min(
-            max_out_len,
-            self.context_window - self.get_token_len(str(input)) - 100)
+            gen_params.pop('max_out_len'),
+            self.context_window - len(self.tokenize(str(messages))) - 100)
        if max_out_len <= 0:
            return ''

        max_num_retries = 0
        while max_num_retries < self.retry:
-            self.wait()
+            self._wait()

            with Lock():
                if len(self.invalid_keys) == len(self.keys):
@@ -192,8 +180,9 @@ class GPTAPI(BaseAPIModel):
            messages=messages,
            max_tokens=max_out_len,
            n=1,
-            stop=None,
-            temperature=temperature,
+            stop=gen_params.pop('stop_words'),
+            frequency_penalty=gen_params.pop('repetition_penalty'),
+            **gen_params,
        )
        raw_response = requests.post(
            self.url, headers=header, data=json.dumps(data))
@@ -225,18 +214,16 @@ class GPTAPI(BaseAPIModel):
                            f'{max_num_retries} times. Check the logs for '
                            'details.')

-    def get_token_len(self, prompt: str) -> int:
-        """Get lengths of the tokenized string. Only English and Chinese
-        characters are counted for now. Users are encouraged to override this
-        method if more accurate length is needed.
+    def tokenize(self, prompt: str) -> list:
+        """Tokenize the input prompt.

        Args:
            prompt (str): Input string.

        Returns:
-            int: Length of the input tokens
+            list: token ids
        """
        import tiktoken
        self.tiktoken = tiktoken
        enc = self.tiktoken.encoding_for_model(self.model_type)
-        return len(enc.encode(prompt))
+        return enc.encode(prompt)
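So `chat()` now accepts either a single conversation or a batch, fanning the batch out over a thread pool; a sketch (the key is a placeholder):

```python
# Hypothetical usage of the reworked GPTAPI.chat().
llm = GPTAPI(model_type='gpt-3.5-turbo', key='YOUR_OPENAI_API_KEY')

# A single conversation returns a single string.
answer = llm.chat([dict(role='user', content='What is a token bucket?')])

# A batch of conversations returns a list of strings, generated concurrently.
answers = llm.chat([
    [dict(role='user', content='Define rate limiting.')],
    [dict(role='user', content='Define backpressure.')],
])
```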
@@ -1,6 +1,6 @@
from dataclasses import asdict, dataclass, field
from enum import Enum
-from typing import Dict, List, Optional, Union
+from typing import List, Optional, Union

from lagent.utils import is_module_exist
@@ -33,47 +33,49 @@ class ActionValidCode(int, Enum):

@dataclass
class ActionReturn:
-    args: Dict
+    args: Optional[dict] = None
    url: Optional[str] = None
    type: Optional[str] = None
-    result: Optional[str] = None
+    result: Optional[List[dict]] = None
    errmsg: Optional[str] = None
    state: Union[ActionStatusCode, int] = ActionStatusCode.SUCCESS
    thought: Optional[str] = None
    valid: Optional[ActionValidCode] = ActionValidCode.OPEN

+    def format_result(self) -> str:
+        """Concatenate items in result."""
+        result = []
+        for item in self.result or []:
+            if item['type'] == 'text':
+                result.append(item['content'])
+            else:
+                result.append(f"[{item['type']}]({item['content']})")
+        result = '\n'.join(result)
+        return result

-class AgentStatusCode(Enum):
-    END = 0  # end of streaming

+# Must inherit from int so that asdict() can convert AgentStatusCode to int
+class AgentStatusCode(int, Enum):
+    END = 0  # end of streaming; return this round's history
    STREAM_ING = 1  # response is in streaming
    SERVER_ERR = -1  # triton server's error
    SESSION_CLOSED = -2  # session has been closed
    SESSION_OUT_OF_LIMIT = -3  # request length out of limit
+    CMD = 2  # return command
+    PLUGIN_START = 3  # start tool
+    PLUGIN_END = 4  # finish tool
+    PLUGIN_RETURN = 5  # tool return
+
+    CODING = 6  # start python
+    CODE_END = 7  # end python
+    CODE_RETURN = 8  # python return
    SESSION_INVALID_ARG = -4  # invalid argument
-    SESSION_READY = 3  # session is ready for inference
+    SESSION_READY = 2  # session is ready for inference


@dataclass
class AgentReturn:
+    state: Union[AgentStatusCode, int] = AgentStatusCode.END
    actions: List[ActionReturn] = field(default_factory=list)
    response: str = ''
+    inner_steps: List = field(default_factory=list)
    errmsg: Optional[str] = None


-if is_module_exist('lmdeploy'):
-    from lmdeploy.serve.turbomind.chatbot import StatusCode
-    STATE_MAP = {
-        StatusCode.TRITON_STREAM_END: AgentStatusCode.END,
-        StatusCode.TRITON_SERVER_ERR: AgentStatusCode.SERVER_ERR,
-        StatusCode.TRITON_SESSION_CLOSED: AgentStatusCode.SESSION_CLOSED,
-        StatusCode.TRITON_STREAM_ING: AgentStatusCode.STREAM_ING,
-        StatusCode.TRITON_SESSION_OUT_OF_LIMIT:
-        AgentStatusCode.SESSION_OUT_OF_LIMIT,
-        StatusCode.TRITON_SESSION_INVALID_ARG:
-        AgentStatusCode.SESSION_INVALID_ARG,
-        StatusCode.TRITON_SESSION_READY: AgentStatusCode.SESSION_READY
-    }
-else:
-    STATE_MAP = {}
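`format_result()` flattens the structured result list into displayable text, rendering non-text items as markdown-style links; for example:

```python
# Behavior of ActionReturn.format_result() on a mixed result list.
ret = ActionReturn(result=[
    dict(type='text', content='The answer is 42.'),
    dict(type='image', content='figures/answer.png'),
])
print(ret.format_result())
# The answer is 42.
# [image](figures/answer.png)
```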
30  lagent/utils/util.py  (new file)
@@ -0,0 +1,30 @@
from typing import List, Optional, Union


def filter_suffix(response: Union[str, List[str]],
                  suffixes: Optional[List[str]] = None) -> Union[str, List[str]]:
    """Filter response with suffixes.

    Args:
        response (Union[str, List[str]]): generated responses by LLMs.
        suffixes (List[str]): a list of suffixes to be deleted.

    Return:
        Union[str, List[str]]: a clean response (or a list of them).
    """
    if suffixes is None:
        return response
    batched = True
    if isinstance(response, str):
        response = [response]
        batched = False
    processed = []
    for resp in response:
        for item in suffixes:
            # if response.endswith(item):
            #     response = response[:len(response) - len(item)]
            if item in resp:
                resp = resp.split(item)[0]
        processed.append(resp)
    if not batched:
        return processed[0]
    return processed
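`filter_suffix()` truncates each response at the first matching stop word rather than only trimming a trailing suffix; for example:

```python
# Behavior of filter_suffix() as defined above.
filter_suffix('Paris is the capital<eoa> extra', ['<eoa>'])
# -> 'Paris is the capital'
filter_suffix(['a<eoa>b', 'c'], ['<eoa>'])
# -> ['a', 'c']
```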
@@ -1,9 +1,13 @@
-docutils==0.16.0
+docutils==0.18.1
markdown>=3.4.0
myst-parser
--e git+https://github.com/open-mmlab/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
-sphinx==4.0.2
+myst-nb
+# -e git+https://github.com/open-mmlab/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
+# sphinx==4.0.2
+sphinx==6.1.0
sphinx-tabs
sphinx_copybutton
sphinx_markdown_tables>=0.0.16
+sphinx-rtd-theme==1.3.0
tabulate
+astroid<3.0.0
+sphinx-autoapi
@@ -3,3 +3,9 @@ func_timeout
jsonschema
requests
tiktoken
+griffe
+phx-class-registry
+jupyter
+jupyter_client
+json5
+pillow