[Refactor] Refactor Lagent (#97)

* Feature: redesign BaseModel (#80)

* redesign BaseModel

* update docstring

* update BaseModel

* [Refactor] improve `Action` and `ActionExecutor` (#83)

* [Fix]: fix turbomind (#81)

fix turbomind

* add parsers

* skip ActionReturn in postprocessing

* check existence of API name

* add exception catch in action executing

* validate input arguments

* modify returned structure of `get_actions_info`

* adapt tools to the new protocol

* remove `LLMQA` action

---------

Co-authored-by: RangiLyu <lyuchqi@gmail.com>
Co-authored-by: wangzy <wangziyi@pjlab.org.cn>

* [Feature] add tools (#89)

* add new tools

* update PPT

* chores

* update action module init

---------

Co-authored-by: wangzy <wangziyi@pjlab.org.cn>

* rename func 'completion' to 'generate' (#90)

* [Feature] support batch inference in API models (#91)

* implement `chat`

* update agent interfaces

* redundancy reduction

---------

Co-authored-by: wangzy <wangziyi@pjlab.org.cn>

* Feature: lmdeploy_wrapper implemented BaseModel (#86)

* [Fix]: fix turbomind (#81)

fix turbomind

* Feature: lmdeploy_wrapper implemented BaseModel

* remove comments of 'llms.__init__'

* update of 'llms.__init__'

* update lmdepoly_wrapper with 'gen_params'

* add property 'state_map' in `__init__` and use APIClient for streaming inference

* implement func 'generate' in LMDeployClient with 'APIClient'

* fix bug of TritonClient

* add docstr for LMDeployPipeline & LMDeployServer

* class LMDeployClient inherits class LMDeployServer

* initialize LMDeployClient with BaseModel.__init__ and use field 'max_tokens' to control model output

* add TODO

* move 'import mmengine' to func '_update_gen_params'

---------

Co-authored-by: RangiLyu <lyuchqi@gmail.com>

* Fix APITemplateParser object is not callable (#95)

fix APITemplateParser object is not callable

Co-authored-by: wangzy <wangziyi@pjlab.org.cn>

* [Feat] support StreamAgent (#82)

* [Feat] support StreamAgent

* update `StreamAgent`

* truncate inner history

---------

Co-authored-by: wangzy <wangziyi@pjlab.org.cn>

* [Feat] hf llm implemented BaseModel (#92)

* Feature: huggingface implemented BaseModel

* hf llm implemented BaseModel

* fix bug of hf llm

* inject attention_mask during inference

* remove unnecessary

* [Feature] support building tool descriptions automatically (#96)

* redundancy reduction

* add `tool_api` to annotate a tool method

* improve json parsing

* enhance parsers

* update README.md

---------

Co-authored-by: wangzy <wangziyi@pjlab.org.cn>

* Enhance tool annotation (#98)

improve `tool_api`

Co-authored-by: wangzy <wangziyi@pjlab.org.cn>

* [Docs] initialize the documentation (#99)

init the docs

Co-authored-by: wangzy <wangziyi@pjlab.org.cn>

* Modify the structure of `ActionReturn`'s result (#102)

* modify the structure of action results

* fix docstrings

---------

Co-authored-by: wangzy <wangziyi@pjlab.org.cn>

* Fix .readthedocs.yml (#104)

fix rtd config

Co-authored-by: wangzy <wangziyi@pjlab.org.cn>

* [Feature] support IPython interpreter action (#103)

* add ipython interpreter

* update requirements

* remove `return_list` argument

---------

Co-authored-by: wangzy <wangziyi@pjlab.org.cn>

* Fix BINGMap key (#105)

fix the fallback value

Co-authored-by: wangzy <wangziyi@pjlab.org.cn>

* StreamAgent inference demo (#106)

* update cfg & fix bug of StreamAgent

* fix bug of func 'stream_chat'

* streamlit demo with full response

* enhance stream chat

* fix bug of stream chat

* fix and file rename

* add exception catch for func 'chat'

---------

Co-authored-by: liujiangning <liujiangning@pjlab.org.cn>

* [Docs] Add action tutorials (#107)

* add `get_started` chapter

* fix docstrings

* add action.md

* add zh docs

---------

Co-authored-by: wangzy <wangziyi@pjlab.org.cn>

* Fix returns of OpenAI interface (#108)

fix `BaseAPIModel` chat returns

Co-authored-by: wangzy <wangziyi@pjlab.org.cn>

* Feat: add a warning for func 'generate_from_template' (#109)

* add a warning for func 'generate_from_template'

* clearer alerts for deprecation

* clearer alerts for deprecation

---------

Co-authored-by: liujiangning30 <147385819+liujiangning30@users.noreply.github.com>
Co-authored-by: BraisedPork <46232992+braisedpork1964@users.noreply.github.com>
Co-authored-by: RangiLyu <lyuchqi@gmail.com>
Co-authored-by: wangzy <wangziyi@pjlab.org.cn>
Co-authored-by: liujiangning <liujiangning@pjlab.org.cn>
Author: liukuikun
Date: 2024-01-30 12:48:21 +08:00
Committed by: GitHub
Parent: 85b91cc652
Commit: 95b68c821a
48 changed files with 4534 additions and 797 deletions


@@ -2,7 +2,11 @@ version: 2
 formats: all
+build:
+  os: ubuntu-22.04
+  tools:
+    python: "3.10"
 python:
-  version: 3.7
-  install:
-    - requirements: requirements/docs.txt
+  install:
+    - requirements: requirements/docs.txt


@@ -83,7 +83,7 @@ Below is an example of running ReWOO with GPT-3.5
 ```python
 # Import necessary modules and classes from the "lagent" library.
 from lagent.agents import ReWOO
-from lagent.actions import ActionExecutor, GoogleSearch, LLMQA
+from lagent.actions import ActionExecutor, GoogleSearch
 from lagent.llms import GPTAPI

 # Initialize the Language Model (llm) and provide your API key.
@@ -92,14 +92,11 @@ llm = GPTAPI(model_type='gpt-3.5-turbo', key=['Your OPENAI_API_KEY'])
 # Initialize the Google Search tool and provide your API key.
 search_tool = GoogleSearch(api_key='Your SERPER_API_KEY')

-# Initialize the LLMQA tool using the Language Model (llm).
-llmqa_tool = LLMQA(llm)
-
 # Create a chatbot by configuring the ReWOO agent.
 chatbot = ReWOO(
     llm=llm,  # Provide the Language Model instance.
     action_executor=ActionExecutor(
-        actions=[search_tool, llmqa_tool]  # Specify the actions the chatbot can perform.
+        actions=[search_tool]  # Specify the actions the chatbot can perform.
     ),
 )
@@ -154,6 +151,7 @@ response = chatbot.chat(
 print(response.response)  # Output the response generated by the chatbot.
 >>> $-\\frac{1}{3}+\\frac{{\\sqrt{3}}}{3}i$
 ```
+
+### All Thanks To Our Contributors:
+<a href="https://github.com/InternLM/lagent/graphs/contributors">
+  <img src="https://contrib.rocks/image?repo=InternLM/lagent" />


@@ -0,0 +1,14 @@
API Reference
=============

This page contains auto-generated API reference documentation.

.. toctree::
   :titlesonly:
   :maxdepth: 3

{% for page in pages %}
{% if page.top_level_object and page.display %}
   {{ page.include_path }}
{% endif %}
{% endfor %}


@@ -0,0 +1,112 @@
{% if not obj.display %}
:orphan:

{% endif %}
:py:mod:`{{ obj.name if obj.name.count(".") <= 1 else obj.short_name }}`
=========={{ "=" * (obj.name|length if obj.name.count(".") <= 1 else obj.short_name|length) }}

.. py:module:: {{ obj.name }}

{% if obj.docstring %}
.. autoapi-nested-parse::

   {{ obj.docstring|indent(3) }}

{% endif %}

{% block subpackages %}
{% set visible_subpackages = obj.subpackages|selectattr("display")|list %}
{% if visible_subpackages %}
Subpackages
-----------
.. toctree::
   :titlesonly:
   :maxdepth: 3

{% for subpackage in visible_subpackages %}
   {{ subpackage.short_name }}/index.rst
{% endfor %}

{% endif %}
{% endblock %}

{% block submodules %}
{% set visible_submodules = obj.submodules|selectattr("display")|list %}
{% if visible_submodules %}
Submodules
----------
.. toctree::
   :titlesonly:
   :maxdepth: 1

{% for submodule in visible_submodules %}
   {{ submodule.short_name }}/index.rst
{% endfor %}

{% endif %}
{% endblock %}

{% block content %}
{% if obj.type is equalto("package") %}
{% set visible_children = obj.children|selectattr("display")|list %}
{% else %}
{% set visible_children = obj.children|selectattr("display")|rejectattr("imported")|list %}
{% endif %}
{% if visible_children %}
{{ obj.type|title }} Contents
{{ "-" * obj.type|length }}---------

{% set visible_classes = visible_children|selectattr("type", "equalto", "class")|list %}
{% set visible_functions = visible_children|selectattr("type", "equalto", "function")|list %}
{% set visible_attributes = visible_children|selectattr("type", "equalto", "data")|list %}
{% if "show-module-summary" in autoapi_options and (visible_classes or visible_functions) %}
{% block classes scoped %}
{% if visible_classes %}
Classes
~~~~~~~

.. autoapisummary::

{% for klass in visible_classes %}
   {{ klass.id }}
{% endfor %}

{% endif %}
{% endblock %}

{% block functions scoped %}
{% if visible_functions %}
Functions
~~~~~~~~~

.. autoapisummary::

{% for function in visible_functions %}
   {{ function.id }}
{% endfor %}

{% endif %}
{% endblock %}

{% block attributes scoped %}
{% if visible_attributes %}
Attributes
~~~~~~~~~~

.. autoapisummary::

{% for attribute in visible_attributes %}
   {{ attribute.id }}
{% endfor %}

{% endif %}
{% endblock %}
{% endif %}
{% for obj_item in visible_children %}
{{ obj_item.render()|indent(0) }}
{% endfor %}
{% endif %}
{% endblock %}


@@ -11,17 +11,16 @@
# documentation root, use os.path.abspath to make it absolute, like shown here.
import os
import re
import sys

import pytorch_sphinx_theme

sys.path.insert(0, os.path.abspath('../../'))
sys.path.insert(0, os.path.abspath('../..'))

# -- Project information -----------------------------------------------------
project = 'Lagent'
copyright = '2020-2030, InternLM'
author = 'InternLM'
language = 'en'

# The full version, including alpha/beta/rc tags
version_file = '../../lagent/version.py'
@@ -36,97 +35,74 @@ release = __version__
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
    'sphinx_rtd_theme',
    'myst_nb',
    'autoapi.extension',
    'sphinx_markdown_tables',
    'sphinx.ext.autodoc',
    'sphinx.ext.napoleon',
    'sphinx.ext.viewcode',
    'sphinx_markdown_tables',
    'sphinx_copybutton',
    'myst_parser',
    'sphinx.ext.intersphinx',
    'sphinx.ext.autodoc.typehints',
    'sphinx.ext.autosummary',
    'sphinx.ext.autosectionlabel',
    'sphinx_tabs.tabs',
]
nb_output_stderr = 'remove-warn'
autodoc_typehints = 'description'
autosummary_generate = True  # Turn on sphinx.ext.autosummary

# Ignore >>> when copying code
copybutton_prompt_text = r'>>> |\.\.\. '
copybutton_prompt_is_regexp = True
myst_enable_extensions = ['colon_fence']

# sphinx-autoapi configuration
autoapi_dirs = ['../../lagent']
autoapi_options = [
    'members',
    'undoc-members',
    'show-inheritance',
    'show-module-summary',
]
autoapi_ignore = ['*migrations*', '*command.py', '*cli.py']
autoapi_template_dir = '_templates/autoapi'
autoapi_add_toctree_entry = False

# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']

# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
source_suffix = {
    '.rst': 'restructuredtext',
    '.md': 'markdown',
}

# The master toctree document.
master_doc = 'index'

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
exclude_patterns = []

# -- Options for HTML output -------------------------------------------------

# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
# html_theme = 'sphinx_rtd_theme'
html_theme = 'pytorch_sphinx_theme'
html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()]
html_theme = 'sphinx_rtd_theme'
html_theme_options = {
    'menu': [
        {
            'name': 'GitHub',
            'url': 'https://github.com/InternLM/lagent'
        },
    ],
    # Specify the language of shared menu
    'menu_lang': 'en',
    'navigation_depth': 3,
    'titles_only': False,
    'style_nav_header_background': '#4fabab',
}
language = 'en'

html_context = {
    'display_github': True,
    'github_host': 'github.com',
    'github_user': 'InternLM',
    'github_repo': 'lagent',
    'github_version': 'main',
    'conf_py_path': '/docs/en/',
}
html_title = 'Lagent'
html_logo = '../imgs/lagent_logo.png'
html_favicon = '../imgs/lagent_icon.png'

master_doc = 'index'

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
# so a file named 'default.css' will overwrite the builtin 'default.css'.
html_static_path = ['_static']
html_css_files = [
    'https://cdn.datatables.net/1.13.2/css/dataTables.bootstrap5.min.css',
    'css/readthedocs.css'
]
html_js_files = [
    'https://cdn.datatables.net/1.13.2/js/jquery.dataTables.min.js',
    'https://cdn.datatables.net/1.13.2/js/dataTables.bootstrap5.min.js',
    'js/collapsed.js',
    'js/table.js',
]
myst_heading_anchors = 4

intersphinx_mapping = {
    'python': ('https://docs.python.org/3', None),
    'numpy': ('https://numpy.org/doc/stable', None),
    'torch': ('https://pytorch.org/docs/stable/', None),
}


def custom_skip(app, what, name, obj, skip, options):
    if what in ['data', 'function', 'class'] and re.search('logger', name):
        skip = True
    return skip


def builder_inited_handler(app):
    pass


def setup(app):
    app.connect('builder-inited', builder_inited_handler)


def setup(sphinx):
    sphinx.connect('autoapi-skip-member', custom_skip)


@@ -0,0 +1,19 @@
# Installation
## With pip
Install with pip (Recommended).
```bash
pip install lagent
```
## From source
Optionally, you could also build Lagent from source in case you want to modify the code:
```bash
git clone https://github.com/InternLM/lagent.git
cd lagent
pip install -e .
```
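
To verify the installation, a minimal check (assuming the package exposes `__version__` via `lagent/version.py`, as the API reference in these docs suggests):

```python
import lagent

print(lagent.__version__)  # prints the installed version string
```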


@@ -1,4 +1,4 @@
-# OVERVIEW
+# Overview

 This chapter introduces you to the framework of Lagent, and provides links to detailed tutorials about Lagent.
@@ -18,6 +18,6 @@ Lagent consists of 3 main parts, agents, llms, and actions.
 Here is a detailed step-by-step guide to learn more about Lagent:

-1. For installation instructions, please see [README](../README.md).
+1. For installation instructions, please see [README](https://github.com/InternLM/lagent/blob/main/README.md).
-2. We provide several examples to build agents with Lagent in [examples](examples/) by simply run `python examples/react_example.py`.
+2. We provide several examples of building agents with Lagent in [examples](https://github.com/InternLM/lagent/tree/main/examples); simply run `python examples/react_example.py`.


@@ -0,0 +1,89 @@
# Quickstart

Using Lagent, you can easily build agents with just a few lines of code.

## Run a ReWOO agent with GPT-3.5

Below is an example of running ReWOO with GPT-3.5

```python
# Import necessary modules and classes from the "lagent" library.
from lagent.agents import ReWOO
from lagent.actions import ActionExecutor, GoogleSearch
from lagent.llms import GPTAPI

# Initialize the Language Model (llm) and provide your API key.
llm = GPTAPI(model_type='gpt-3.5-turbo', key=['Your OPENAI_API_KEY'])

# Initialize the Google Search tool and provide your API key.
search_tool = GoogleSearch(api_key='Your SERPER_API_KEY')

# Create a chatbot by configuring the ReWOO agent.
chatbot = ReWOO(
    llm=llm,  # Provide the Language Model instance.
    action_executor=ActionExecutor(
        actions=[search_tool]  # Specify the actions the chatbot can perform.
    ),
)

# Ask the chatbot a question and store the response.
response = chatbot.chat('What profession does Nicholas Ray and Elia Kazan have in common')

# Print the chatbot's response.
print(response.response)  # Output the response generated by the chatbot.
```

```python
>>> Film director.
```

## Run a ReAct agent with InternLM

NOTE: If you want to run a HuggingFace model, please run `pip install -e .[all]` first.

```python
# Import necessary modules and classes from the "lagent" library.
from lagent.agents import ReAct
from lagent.actions import ActionExecutor, GoogleSearch, PythonInterpreter
from lagent.llms import HFTransformer
from lagent.llms.meta_template import INTERNLM2_META as META

# Initialize the HFTransformer-based Language Model (llm) and
# provide the model name.
llm = HFTransformer(path='internlm/internlm2-chat-7b', meta_template=META)

# Initialize the Google Search tool and provide your API key.
search_tool = GoogleSearch(api_key='Your SERPER_API_KEY')

# Initialize the Python Interpreter tool.
python_interpreter = PythonInterpreter()

# Create a chatbot by configuring the ReAct agent.
# Specify the actions the chatbot can perform.
chatbot = ReAct(
    llm=llm,  # Provide the Language Model instance.
    action_executor=ActionExecutor(
        actions=[search_tool, python_interpreter]),
)

# Ask the chatbot a mathematical question in LaTeX format.
response = chatbot.chat('若$z=-1+\sqrt{3}i$,则$\frac{z}{{z\overline{z}-1}}=\left(\ \ \right)$')

# Print the chatbot's response.
print(response.response)  # Output the response generated by the chatbot.
```

```python
>>> $-\\frac{1}{3}+\\frac{{\\sqrt{3}}}{3}i$
```

## Run ReAct Web Demo

```bash
# You need to install streamlit first
# pip install streamlit
streamlit run examples/react_web_demo.py
```

Then you can chat through the UI shown below

![image](https://github.com/InternLM/lagent/assets/24622904/3aebb8b4-07d1-42a2-9da3-46080c556f68)


@@ -8,13 +8,31 @@ You can switch between English and Chinese in the lower-left corner of the layou
   :caption: Get Started

   get_started/overview.md
   get_started/install.md
   get_started/quickstart.md

.. toctree::
   :maxdepth: 2
   :caption: Tutorials

   tutorials/action.md

.. toctree::
   :caption: Switch Language

   switch_language.md

.. toctree::
   :maxdepth: 1
   :caption: API Reference

   autoapi/lagent/actions/index
   autoapi/lagent/agents/index
   autoapi/lagent/llms/index
   autoapi/lagent/utils/index
   autoapi/lagent/schema/index
   autoapi/lagent/version/index

Indices and tables
==================

docs/en/tutorials/action.md (new file, 396 lines)

@@ -0,0 +1,396 @@
# Action

Actions, also called **tools**, provide a suite of functions that LLM-driven agents can use to interact with the real world and perform complex tasks.

## Basic Concepts

### Tool & Toolkit

There are two categories of tools:

* tool: provides only one API to call.
* toolkit: implements multiple APIs that undertake different sub-tasks.

### Tool Description

In Lagent, a tool description is a dictionary that contains the action's core usage information, which LLMs observe for decision-making.

For simple tools, the description can be created as follows

```python
TOOL_DESCRIPTION = {
    'name': 'bold',  # name of the tool
    'description': 'a function used to make text bold',  # introduce the tool's function
    'parameters': [  # a list of parameters the tool takes
        {
            'name': 'text', 'type': 'STRING', 'description': 'input content'
        }
    ],
    'required': ['text'],  # specify names of parameters required
}
```

In some situations the description may also contain optional `return_data` and `parameter_description` keys, describing the returns and the argument-passing format respectively.

```{attention}
`parameter_description` is usually inserted into the tool description automatically by the action's parser. It will be introduced in [Interface Design](#interface-design).
```

For toolkits, the description is very similar but nests sub-methods

```python
TOOL_DESCRIPTION = {
    'name': 'PhraseEmphasis',  # name of the toolkit
    'description': 'a toolkit which provides different styles of text emphasis',  # introduce the tool's function
    'api_list': [
        {
            'name': 'bold',
            'description': 'make text bold',
            'parameters': [
                {
                    'name': 'text', 'type': 'STRING', 'description': 'input content'
                }
            ],
            'required': ['text']
        },
        {
            'name': 'italic',
            'description': 'make text italic',
            'parameters': [
                {
                    'name': 'text', 'type': 'STRING', 'description': 'input content'
                }
            ],
            'required': ['text']
        }
    ]
}
```

## Make Functions Tools

It's not necessary to prepare an extra description for a defined function. Lagent provides a decorator `tool_api` which can conveniently turn a function into a tool by automatically parsing the function's type hints and docstrings to generate the description dictionary, binding it to the attribute `api_description`.

```python
from lagent import tool_api

@tool_api
def bold(text: str) -> str:
    """make text bold

    Args:
        text (str): input text

    Returns:
        str: bold text
    """
    return '**' + text + '**'


bold.api_description
```

```python
{'name': 'bold',
 'description': 'make text bold',
 'parameters': [{'name': 'text',
   'type': 'STRING',
   'description': 'input text'}],
 'required': ['text']}
```

Once `returns_named_value` is enabled, you should declare the name of the return data, which will be processed to form a new field `return_data`:

```python
@tool_api(returns_named_value=True)
def bold(text: str) -> str:
    """make text bold

    Args:
        text (str): input text

    Returns:
        bold_text (str): bold text
    """
    return '**' + text + '**'


bold.api_description
```

```python
{'name': 'bold',
 'description': 'make text bold',
 'parameters': [{'name': 'text',
   'type': 'STRING',
   'description': 'input text'}],
 'required': ['text'],
 'return_data': [{'name': 'bold_text',
   'description': 'bold text',
   'type': 'STRING'}]}
```

Sometimes the tool may return a `dict` or `tuple`, and you may want to elaborate each member in `return_data` rather than take them as a whole. Set `explode_return=True` and list them in the Returns section of the docstring.

```python
@tool_api(explode_return=True)
def list_args(a: str, b: int, c: float = 0.0) -> dict:
    """Return arguments in dict format

    Args:
        a (str): a
        b (int): b
        c (float): c

    Returns:
        dict: input arguments
            - a (str): a
            - b (int): b
            - c: c
    """
    return {'a': a, 'b': b, 'c': c}
```

```python
{'name': 'list_args',
 'description': 'Return arguments in dict format',
 'parameters': [{'name': 'a', 'type': 'STRING', 'description': 'a'},
  {'name': 'b', 'type': 'NUMBER', 'description': 'b'},
  {'name': 'c', 'type': 'FLOAT', 'description': 'c'}],
 'required': ['a', 'b'],
 'return_data': [{'name': 'a', 'description': 'a', 'type': 'STRING'},
  {'name': 'b', 'description': 'b', 'type': 'NUMBER'},
  {'name': 'c', 'description': 'c'}]}
```

```{warning}
Only Google-style Python docstrings are currently supported.
```

## Interface Design

`BaseAction(description=None, parser=JsonParser, enable=True)` is the base class all actions should inherit from. It takes three initialization arguments

* **description**: a tool description dictionary, used to set the instance attribute `description`. Usually you don't need to pass this argument explicitly: the metaclass of `BaseAction` collects methods decorated by `tool_api` and assembles their `api_description` into a class attribute `__tool_description__`; if the initial `description` is left null, `__tool_description__` is copied as `description`.

* **parser**: a `BaseParser` class. It instantiates a parser used to validate the arguments of APIs in `description`.

  For example, `JsonParser` requires arguments to be passed in the format of JSON or `dict`. To make LLMs aware of this, it inserts a field `parameter_description` into the `description`.

```python
from lagent import BaseAction

action = BaseAction(
    {
        'name': 'bold',
        'description': 'a function used to make text bold',
        'parameters': [
            {
                'name': 'text', 'type': 'STRING', 'description': 'input content'
            }
        ],
        'required': ['text']
    }
)
action.description
```

```python
{'name': 'bold',
 'description': 'a function used to make text bold',
 'parameters': [{'name': 'text',
   'type': 'STRING',
   'description': 'input content'}],
 'required': ['text'],
 'parameter_description': '如果调用该工具你必须使用Json格式 {key: value} 传参其中key为参数名称'}
```

* **enable**: specify whether the tool is available.
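
A minimal sketch of how `enable` interacts with the executor, assuming the subclass forwards `enable` to `BaseAction.__init__` (hypothetical usage; outputs are expected values based on the behavior described in [Tool Calling](#tool-calling) below):

```python
from lagent import ActionExecutor, PythonInterpreter

disabled_tool = PythonInterpreter(enable=False)  # assumed to forward `enable`
executor = ActionExecutor(actions=[disabled_tool])
executor.get_actions_info()  # expected: [] -- disabled tools are filtered out
executor.is_valid('PythonInterpreter')  # expected: False
```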
### Custom Action

A simple tool must have its `run` method implemented, while toolkit APIs should avoid naming conflicts with this reserved word.

```python
class Bold(BaseAction):

    @tool_api
    def run(self, text: str):
        """make text bold

        Args:
            text (str): input text

        Returns:
            str: bold text
        """
        return '**' + text + '**'


class PhraseEmphasis(BaseAction):
    """a toolkit which provides different styles of text emphasis"""

    @tool_api
    def bold(self, text):
        """make text bold

        Args:
            text (str): input text

        Returns:
            str: bold text
        """
        return '**' + text + '**'

    @tool_api
    def italic(self, text):
        """make text italic

        Args:
            text (str): input text

        Returns:
            str: italic text
        """
        return '*' + text + '*'


# Inspect the default description
# Bold.__tool_description__, PhraseEmphasis.__tool_description__
```

### Auto-registration

Any subclass of `BaseAction` is registered automatically. You can use `list_tools()` and `get_tool()` to view all tools and initialize them by name.

```python
from lagent import list_tools, get_tool

list_tools()
```

```python
['BaseAction',
 'InvalidAction',
 'NoAction',
 'FinishAction',
 'ArxivSearch',
 'BINGMap',
 'GoogleScholar',
 'GoogleSearch',
 'IPythonInterpreter',
 'PPT',
 'PythonInterpreter',
 'Bold',
 'PhraseEmphasis']
```

Create a `PhraseEmphasis` object

```python
action = get_tool('PhraseEmphasis')
action.description
```

```python
{'name': 'PhraseEmphasis',
 'description': 'a toolkit which provides different styles of text emphasis',
 'api_list': [{'name': 'bold',
   'description': 'make text bold',
   'parameters': [{'name': 'text',
     'type': 'STRING',
     'description': 'input text'}],
   'required': ['text'],
   'parameter_description': '如果调用该工具你必须使用Json格式 {key: value} 传参其中key为参数名称'},
  {'name': 'italic',
   'description': 'make text italic',
   'parameters': [{'name': 'text',
     'type': 'STRING',
     'description': 'input text'}],
   'required': ['text'],
   'parameter_description': '如果调用该工具你必须使用Json格式 {key: value} 传参其中key为参数名称'}]}
```

## Tool Calling

### Run a Tool

The `__call__` method of an `Action` takes two arguments

* `inputs`: depends on the action's parser; often a string in a specific format generated by LLMs.
  + `JsonParser`: allows passing arguments as a JSON string or a Python `dict`.
  + `TupleParser`: allows passing arguments as a tuple-literal string or a Python `tuple`.
* `name`: which API to call. Defaults to `run`.

It returns an `ActionReturn` object which encapsulates calling details

* `args`: dictionary of action inputs.
* `type`: action name.
* `result`: list of dicts, each containing two keys, 'type' and 'content'. When an error occurs, it is `None`.
* `errmsg`: error message. Defaults to `None`.

Below is an example

```python
from lagent import IPythonInterpreter, TupleParser

action1 = IPythonInterpreter()
ret = action1('{"command": "import math;math.sqrt(100)"}')
print(ret.result)
ret = action1({'command': 'import math;math.sqrt(100)'})
print(ret.result)

action2 = IPythonInterpreter(parser=TupleParser)
ret = action2('("import math;math.sqrt(100)", )')
print(ret.result)
ret = action2(('import math;math.sqrt(100)',))
print(ret.result)
```

```python
[{'type': 'text', 'content': '10.0'}]
[{'type': 'text', 'content': '10.0'}]
[{'type': 'text', 'content': '10.0'}]
[{'type': 'text', 'content': '10.0'}]
```
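
For toolkits, the `name` argument selects which API to run. A sketch reusing the `PhraseEmphasis` toolkit defined above (outputs are expected values, not verified):

```python
emphasis = PhraseEmphasis()
ret = emphasis('{"text": "Lagent"}', name='italic')
print(ret.result)  # expected: [{'type': 'text', 'content': '*Lagent*'}]
```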
### Dynamic Invocation

Lagent provides an `ActionExecutor` to manage multiple tools. It flattens the `api_list` of toolkits and renames each API to `{tool_name}.{api_name}`.

```python
from lagent import ActionExecutor, ArxivSearch, IPythonInterpreter

executor = ActionExecutor(actions=[ArxivSearch(), IPythonInterpreter()])
executor.get_actions_info()  # This information is fed to LLMs as the tool meta prompt
```

```python
[{'name': 'ArxivSearch.get_arxiv_article_information',
  'description': 'Run Arxiv search and get the article meta information.',
  'parameters': [{'name': 'query',
    'type': 'STRING',
    'description': 'the content of search query'}],
  'required': ['query'],
  'return_data': [{'name': 'content',
    'description': 'a list of 3 arxiv search papers',
    'type': 'STRING'}],
  'parameter_description': '如果调用该工具你必须使用Json格式 {key: value} 传参其中key为参数名称'},
 {'name': 'IPythonInterpreter',
  'description': "When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files. Internet access for this session is disabled. Do not make external web requests or API calls as they will fail.",
  'parameters': [{'name': 'command',
    'type': 'STRING',
    'description': 'Python code'},
   {'name': 'timeout',
    'type': 'NUMBER',
    'description': 'Upper bound of waiting time for Python script execution.'}],
  'required': ['command'],
  'parameter_description': '如果调用该工具你必须使用Json格式 {key: value} 传参其中key为参数名称'}]
```

Trigger an action through the executor

```python
ret = executor('IPythonInterpreter', '{"command": "import math;math.sqrt(100)"}')
ret.result
```

```python
[{'type': 'text', 'content': '10.0'}]
```
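
Toolkit APIs are triggered through their flattened names; a sketch with a hypothetical query:

```python
ret = executor('ArxivSearch.get_arxiv_article_information',
               '{"query": "large language model agents"}')
print(ret.result)
```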


@@ -0,0 +1,14 @@
API Reference
=============

This page contains auto-generated API reference documentation.

.. toctree::
   :titlesonly:
   :maxdepth: 3

{% for page in pages %}
{% if page.top_level_object and page.display %}
   {{ page.include_path }}
{% endif %}
{% endfor %}


@@ -0,0 +1,112 @@
{% if not obj.display %}
:orphan:

{% endif %}
:py:mod:`{{ obj.name if obj.name.count(".") <= 1 else obj.short_name }}`
=========={{ "=" * (obj.name|length if obj.name.count(".") <= 1 else obj.short_name|length) }}

.. py:module:: {{ obj.name }}

{% if obj.docstring %}
.. autoapi-nested-parse::

   {{ obj.docstring|indent(3) }}

{% endif %}

{% block subpackages %}
{% set visible_subpackages = obj.subpackages|selectattr("display")|list %}
{% if visible_subpackages %}
Subpackages
-----------
.. toctree::
   :titlesonly:
   :maxdepth: 3

{% for subpackage in visible_subpackages %}
   {{ subpackage.short_name }}/index.rst
{% endfor %}

{% endif %}
{% endblock %}

{% block submodules %}
{% set visible_submodules = obj.submodules|selectattr("display")|list %}
{% if visible_submodules %}
Submodules
----------
.. toctree::
   :titlesonly:
   :maxdepth: 1

{% for submodule in visible_submodules %}
   {{ submodule.short_name }}/index.rst
{% endfor %}

{% endif %}
{% endblock %}

{% block content %}
{% if obj.type is equalto("package") %}
{% set visible_children = obj.children|selectattr("display")|list %}
{% else %}
{% set visible_children = obj.children|selectattr("display")|rejectattr("imported")|list %}
{% endif %}
{% if visible_children %}
{{ obj.type|title }} Contents
{{ "-" * obj.type|length }}---------

{% set visible_classes = visible_children|selectattr("type", "equalto", "class")|list %}
{% set visible_functions = visible_children|selectattr("type", "equalto", "function")|list %}
{% set visible_attributes = visible_children|selectattr("type", "equalto", "data")|list %}
{% if "show-module-summary" in autoapi_options and (visible_classes or visible_functions) %}
{% block classes scoped %}
{% if visible_classes %}
Classes
~~~~~~~

.. autoapisummary::

{% for klass in visible_classes %}
   {{ klass.id }}
{% endfor %}

{% endif %}
{% endblock %}

{% block functions scoped %}
{% if visible_functions %}
Functions
~~~~~~~~~

.. autoapisummary::

{% for function in visible_functions %}
   {{ function.id }}
{% endfor %}

{% endif %}
{% endblock %}

{% block attributes scoped %}
{% if visible_attributes %}
Attributes
~~~~~~~~~~

.. autoapisummary::

{% for attribute in visible_attributes %}
   {{ attribute.id }}
{% endfor %}

{% endif %}
{% endblock %}
{% endif %}
{% for obj_item in visible_children %}
{{ obj_item.render()|indent(0) }}
{% endfor %}
{% endif %}
{% endblock %}


@@ -11,18 +11,16 @@
# documentation root, use os.path.abspath to make it absolute, like shown here.
import os
import subprocess
import re
import sys

import pytorch_sphinx_theme

sys.path.insert(0, os.path.abspath('../../'))
sys.path.insert(0, os.path.abspath('../..'))

# -- Project information -----------------------------------------------------
project = 'Lagent'
copyright = '2020-2030, InternLM'
author = 'InternLM'
language = 'zh_CN'

# The full version, including alpha/beta/rc tags
version_file = '../../lagent/version.py'
@@ -37,97 +35,74 @@ release = __version__
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
    'sphinx_rtd_theme',
    'myst_nb',
    'autoapi.extension',
    'sphinx_markdown_tables',
    'sphinx.ext.autodoc',
    'sphinx.ext.napoleon',
    'sphinx.ext.viewcode',
    'sphinx_markdown_tables',
    'sphinx_copybutton',
    'myst_parser',
    'sphinx.ext.intersphinx',
    'sphinx.ext.autodoc.typehints',
    'sphinx.ext.autosummary',
    'sphinx.ext.autosectionlabel',
    'sphinx_tabs.tabs',
]
nb_output_stderr = 'remove-warn'
autodoc_typehints = 'description'
autosummary_generate = True  # Turn on sphinx.ext.autosummary

# Ignore >>> when copying code
copybutton_prompt_text = r'>>> |\.\.\. '
copybutton_prompt_is_regexp = True
myst_enable_extensions = ['colon_fence']

# sphinx-autoapi configuration
autoapi_dirs = ['../../lagent']
autoapi_options = [
    'members',
    'undoc-members',
    'show-inheritance',
    'show-module-summary',
]
autoapi_ignore = ['*migrations*', '*command.py', '*cli.py']
autoapi_template_dir = '_templates/autoapi'
autoapi_add_toctree_entry = False

# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']

# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
source_suffix = {
    '.rst': 'restructuredtext',
    '.md': 'markdown',
}

# The master toctree document.
master_doc = 'index'

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
exclude_patterns = []

# -- Options for HTML output -------------------------------------------------

# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
# html_theme = 'sphinx_rtd_theme'
html_theme = 'pytorch_sphinx_theme'
html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()]
html_theme = 'sphinx_rtd_theme'
html_theme_options = {
    'menu': [
        {
            'name': 'GitHub',
            'url': 'https://github.com/InternLM/lagent'
        },
    ],
    # Specify the language of shared menu
    'menu_lang': 'cn',
    'navigation_depth': 3,
    'titles_only': False,
    'style_nav_header_background': '#4fabab',
}
language = 'zh_CN'

html_context = {
    'display_github': True,
    'github_host': 'github.com',
    'github_user': 'InternLM',
    'github_repo': 'lagent',
    'github_version': 'main',
    'conf_py_path': '/docs/en/',
}
html_title = 'Lagent'
html_logo = '../imgs/lagent_logo.png'
html_favicon = '../imgs/lagent_icon.png'

master_doc = 'index'

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
# so a file named 'default.css' will overwrite the builtin 'default.css'.
html_static_path = ['_static']
html_css_files = [
    'https://cdn.datatables.net/1.13.2/css/dataTables.bootstrap5.min.css',
    'css/readthedocs.css'
]
html_js_files = [
    'https://cdn.datatables.net/1.13.2/js/jquery.dataTables.min.js',
    'https://cdn.datatables.net/1.13.2/js/dataTables.bootstrap5.min.js',
    'js/collapsed.js',
    'js/table.js',
]
myst_heading_anchors = 4

# Configuration for intersphinx
intersphinx_mapping = {
    'python': ('https://docs.python.org/3', None),
    'numpy': ('https://numpy.org/doc/stable', None),
    'torch': ('https://pytorch.org/docs/stable/', None),
}


def builder_inited_handler(app):
    subprocess.run(['./cp_origin_docs.sh'])


def custom_skip(app, what, name, obj, skip, options):
    if what in ['data', 'function', 'class'] and re.search('logger', name):
        skip = True
    return skip


def setup(app):
    app.connect('builder-inited', builder_inited_handler)


def setup(sphinx):
    sphinx.connect('autoapi-skip-member', custom_skip)


@@ -0,0 +1,19 @@
# 安装方式
## pip安装
推荐使用 pip 安装
```bash
pip install lagent
```
## 源码安装
如需修改部分功能,可以从源码构建 Lagent
```bash
git clone https://github.com/InternLM/lagent.git
cd lagent
pip install -e .
```


@@ -0,0 +1,23 @@
# 总览
本章节将介绍 Lagent 的架构,并提供 Lagent 详细教程的链接。
## Lagent 是什么
Lagent 是一个开源的 LLM 智能体框架,允许使用者快速将一个大语言模型转换成智能体,并提供一些典型工具来激发大语言模型的潜能。Lagent 框架图如下:
![image](https://github.com/InternLM/lagent/assets/24351120/e104171e-4baf-43b3-8e6d-90cff1b298b6)
Lagent 包含三个主要模块:agents、llms 和 actions。
- **agents** 实现了多种智能体,如 ReAct、AutoGPT。
- **llms** 支持多种大语言模型,包括在 HuggingFace 上托管的开源模型(Llama-2、InternLM)及 GPT-3.5/4 等闭源模型。
- **actions** 包含一系列工具,并提供工具执行器来统一管理。
## 如何使用
以下是帮助您了解关于 Lagent 更多信息的详细教程:
1. 安装请参考 [README](https://github.com/InternLM/lagent/blob/main/README.md).
2. 一些构建智能体的实例 [examples](https://github.com/InternLM/lagent/tree/main/examples),直接运行脚本即可,如 `python examples/react_example.py`.


@@ -0,0 +1,87 @@
# 快速上手

借助 Lagent,仅需几行代码就能构建大语言模型智能体。

## GPT-3.5 驱动的 ReWOO 智能体

下面是使用 GPT-3.5 运行 ReWOO 的示例

```python
# 从 Lagent 导入必要的模块和类
from lagent.agents import ReWOO
from lagent.actions import ActionExecutor, GoogleSearch
from lagent.llms import GPTAPI

# 初始化 LLM,你可能需要提供 API 密钥
llm = GPTAPI(model_type='gpt-3.5-turbo', key=['Your OPENAI_API_KEY'])

# 初始化 Google 搜索工具,你可能需要提供 API 密钥
search_tool = GoogleSearch(api_key='Your SERPER_API_KEY')

# 配置 ReWOO 智能体,创建聊天机器人
chatbot = ReWOO(
    llm=llm,  # 大语言模型实例
    action_executor=ActionExecutor(
        actions=[search_tool]  # 指定智能体可以调用的工具
    ),
)

# 询问问题并获取回复
response = chatbot.chat('What profession does Nicholas Ray and Elia Kazan have in common')

# 打印回复
print(response.response)
```

```python
>>> Film director.
```

## InternLM 驱动的 ReAct 智能体

注意,如果你想使用 HuggingFace 模型,请先运行 `pip install -e .[all]`

```python
# 从 Lagent 导入必要的模块和类
from lagent.agents import ReAct
from lagent.actions import ActionExecutor, GoogleSearch, PythonInterpreter
from lagent.llms import HFTransformer
from lagent.llms.meta_template import INTERNLM2_META as META

# 初始化 HFTransformer 模型
llm = HFTransformer(path='internlm/internlm2-chat-7b', meta_template=META)

# 初始化 Google 搜索工具,你可能需要提供 API 密钥
search_tool = GoogleSearch(api_key='Your SERPER_API_KEY')

# 初始化 Python 代码解释器
python_interpreter = PythonInterpreter()

# 配置 ReAct 智能体,创建聊天机器人
chatbot = ReAct(
    llm=llm,  # 大语言模型实例
    action_executor=ActionExecutor(
        actions=[search_tool, python_interpreter]),  # 指定智能体可以调用的工具
)

# 询问 LaTeX 格式的数学问题
response = chatbot.chat('若$z=-1+\sqrt{3}i$,则$\frac{z}{{z\overline{z}-1}}=\left(\ \ \right)$')

# 打印回复
print(response.response)
```

```python
>>> $-\\frac{1}{3}+\\frac{{\\sqrt{3}}}{3}i$
```

## 启动 ReAct 网页 App

```bash
# 你需要先安装 streamlit
# pip install streamlit
streamlit run examples/react_web_demo.py
```

然后你可以通过下图所示的 UI 界面进行对话

![image](https://github.com/InternLM/lagent/assets/24622904/3aebb8b4-07d1-42a2-9da3-46080c556f68)


@@ -8,12 +8,32 @@
   :caption: 新手入门

   get_started/overview.md
   get_started/install.md
   get_started/quickstart.md

.. toctree::
   :maxdepth: 2
   :caption: 教程

   tutorials/action.md

.. toctree::
   :caption: 切换语言

   switch_language.md

.. toctree::
   :maxdepth: 1
   :caption: API 参考

   autoapi/lagent/actions/index
   autoapi/lagent/agents/index
   autoapi/lagent/llms/index
   autoapi/lagent/utils/index
   autoapi/lagent/schema/index
   autoapi/lagent/version/index

导引
==================


@@ -0,0 +1,394 @@
# 动作

动作,也被称为**工具**,提供了一套 LLM 驱动的智能体用来与真实世界交互并执行复杂任务的函数。

## 基本概念

### 工具 & 工具包

有两种类型的工具:

* 简单工具:只提供一个 API 接口供调用。
* 工具包:实现多个 API 接口,承担不同的子任务。

### 工具描述

在 Lagent 中,工具描述是一个刻画工具调用方式的字典,能够被 LLM 观察并用于决策。

对于简单工具,描述可按如下格式声明:

```python
TOOL_DESCRIPTION = {
    'name': 'bold',  # 工具名称
    'description': 'a function used to make text bold',  # 介绍工具的功能
    'parameters': [  # 这个工具所需要的参数列表
        {
            'name': 'text', 'type': 'STRING', 'description': 'input content'
        }
    ],
    'required': ['text'],  # 指定必需的参数名
}
```

在某些情况下,可能还包含 `return_data` 和 `parameter_description` 字段,分别描述返回内容及参数传递格式。

```{attention}
`parameter_description` 通常被动作的解析器自动插入到工具描述中,这部分将在[接口设计](#id6)中进行介绍。
```

对于工具包,描述非常相似,但嵌套了子方法

```python
TOOL_DESCRIPTION = {
    'name': 'PhraseEmphasis',  # 工具包的名字
    'description': 'a toolkit which provides different styles of text emphasis',  # 介绍工具包的功能
    'api_list': [
        {
            'name': 'bold',
            'description': 'make text bold',
            'parameters': [
                {
                    'name': 'text', 'type': 'STRING', 'description': 'input content'
                }
            ],
            'required': ['text']
        },
        {
            'name': 'italic',
            'description': 'make text italic',
            'parameters': [
                {
                    'name': 'text', 'type': 'STRING', 'description': 'input content'
                }
            ],
            'required': ['text']
        }
    ]
}
```

## 将函数转换为工具

对于已定义好的函数,无需人工添加额外的描述。在 Lagent 中,我们提供了一个修饰器 `tool_api`,它可以通过自动解析函数的类型提示和文档字符串来生成描述字典,并将其绑定到属性 `api_description`。

```python
from lagent import tool_api

@tool_api
def bold(text: str) -> str:
    """make text bold

    Args:
        text (str): input text

    Returns:
        str: bold text
    """
    return '**' + text + '**'


bold.api_description
```

```python
{'name': 'bold',
 'description': 'make text bold',
 'parameters': [{'name': 'text',
   'type': 'STRING',
   'description': 'input text'}],
 'required': ['text']}
```

一旦启用 `returns_named_value`,您应当声明返回值的名称,这将被处理成一个新的字段 `return_data`:

```python
@tool_api(returns_named_value=True)
def bold(text: str) -> str:
    """make text bold

    Args:
        text (str): input text

    Returns:
        bold_text (str): bold text
    """
    return '**' + text + '**'


bold.api_description
```

```python
{'name': 'bold',
 'description': 'make text bold',
 'parameters': [{'name': 'text',
   'type': 'STRING',
   'description': 'input text'}],
 'required': ['text'],
 'return_data': [{'name': 'bold_text',
   'description': 'bold text',
   'type': 'STRING'}]}
```

有时工具可能返回一个 `dict` 或 `tuple`,如果你想在 `return_data` 中详细说明每个成员的含义而不是把它们当作一个整体,设置 `explode_return=True` 并在文档字符串的 Returns 部分中罗列它们。

```python
@tool_api(explode_return=True)
def list_args(a: str, b: int, c: float = 0.0) -> dict:
    """Return arguments in dict format

    Args:
        a (str): a
        b (int): b
        c (float): c

    Returns:
        dict: input arguments
            - a (str): a
            - b (int): b
            - c: c
    """
    return {'a': a, 'b': b, 'c': c}
```

```python
{'name': 'list_args',
 'description': 'Return arguments in dict format',
 'parameters': [{'name': 'a', 'type': 'STRING', 'description': 'a'},
  {'name': 'b', 'type': 'NUMBER', 'description': 'b'},
  {'name': 'c', 'type': 'FLOAT', 'description': 'c'}],
 'required': ['a', 'b'],
 'return_data': [{'name': 'a', 'description': 'a', 'type': 'STRING'},
  {'name': 'b', 'description': 'b', 'type': 'NUMBER'},
  {'name': 'c', 'description': 'c'}]}
```

```{warning}
目前仅支持 Google 格式的 Python 文档字符串。
```

## 接口设计

`BaseAction(description=None, parser=JsonParser, enable=True)` 是所有动作应该继承的基类,它接收三个初始化参数:

* **description**:一个工具描述的字典,用于设置实例属性 `description`。通常不需要显式地传递这个参数,因为 `BaseAction` 的元类将查找被 `tool_api` 装饰的方法,并组装它们的 `api_description` 构造一个类属性 `__tool_description__`;如果实例化时 `description` 为空,那么该实例属性将置为 `__tool_description__`。

* **parser**:`BaseParser` 类,用于实例化一个动作解析器,校验 `description` 所描述的工具的参数。例如,`JsonParser` 会要求模型在调用工具时传入一个 JSON 格式字符串或者 Python 字典;为了让 LLM 感知到该指令,它会在 `description` 中插入一个 `parameter_description` 字段。

```python
from lagent import BaseAction

action = BaseAction(
    {
        'name': 'bold',
        'description': 'a function used to make text bold',
        'parameters': [
            {
                'name': 'text', 'type': 'STRING', 'description': 'input content'
            }
        ],
        'required': ['text']
    }
)
action.description
```

```python
{'name': 'bold',
 'description': 'a function used to make text bold',
 'parameters': [{'name': 'text',
   'type': 'STRING',
   'description': 'input content'}],
 'required': ['text'],
 'parameter_description': '如果调用该工具你必须使用Json格式 {key: value} 传参其中key为参数名称'}
```

* **enable**:指明该动作是否生效。
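
下面是一个关于 `enable` 如何与执行器协作的最小示例(假设子类会把 `enable` 透传给 `BaseAction.__init__`,输出为预期结果,未经验证):被禁用的工具仍可实例化,但不会出现在提供给 LLM 的工具信息中。

```python
from lagent import ActionExecutor, PythonInterpreter

disabled_tool = PythonInterpreter(enable=False)  # 假设子类透传 `enable` 参数
executor = ActionExecutor(actions=[disabled_tool])
executor.get_actions_info()  # 预期结果:[],被禁用的工具会被过滤
executor.is_valid('PythonInterpreter')  # 预期结果:False
```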
### 自定义动作

一个简单工具必须实现 `run` 方法,而工具包的各子 API 则应避免使用该保留名称。

```python
class Bold(BaseAction):

    @tool_api
    def run(self, text: str):
        """make text bold

        Args:
            text (str): input text

        Returns:
            str: bold text
        """
        return '**' + text + '**'


class PhraseEmphasis(BaseAction):
    """a toolkit which provides different styles of text emphasis"""

    @tool_api
    def bold(self, text):
        """make text bold

        Args:
            text (str): input text

        Returns:
            str: bold text
        """
        return '**' + text + '**'

    @tool_api
    def italic(self, text):
        """make text italic

        Args:
            text (str): input text

        Returns:
            str: italic text
        """
        return '*' + text + '*'


# 查看默认工具描述
# Bold.__tool_description__, PhraseEmphasis.__tool_description__
```

### 自动注册

任何 `BaseAction` 的子类都会自动被注册。你可以使用 `list_tools()` 和 `get_tool()` 来查看所有工具类,并通过工具名进行初始化。

```python
from lagent import list_tools, get_tool

list_tools()
```

```python
['BaseAction',
 'InvalidAction',
 'NoAction',
 'FinishAction',
 'ArxivSearch',
 'BINGMap',
 'GoogleScholar',
 'GoogleSearch',
 'IPythonInterpreter',
 'PPT',
 'PythonInterpreter',
 'Bold',
 'PhraseEmphasis']
```

创建一个 `PhraseEmphasis` 对象。

```python
action = get_tool('PhraseEmphasis')
action.description
```

```python
{'name': 'PhraseEmphasis',
 'description': 'a toolkit which provides different styles of text emphasis',
 'api_list': [{'name': 'bold',
   'description': 'make text bold',
   'parameters': [{'name': 'text',
     'type': 'STRING',
     'description': 'input text'}],
   'required': ['text'],
   'parameter_description': '如果调用该工具你必须使用Json格式 {key: value} 传参其中key为参数名称'},
  {'name': 'italic',
   'description': 'make text italic',
   'parameters': [{'name': 'text',
     'type': 'STRING',
     'description': 'input text'}],
   'required': ['text'],
   'parameter_description': '如果调用该工具你必须使用Json格式 {key: value} 传参其中key为参数名称'}]}
```

## 工具调用

### 执行工具

`Action` 的 `__call__` 方法需要传入两个参数:

* `inputs`:其类型与动作绑定的 `BaseParser` 相关,通常是由大语言模型生成的特定格式字符串。
  + `JsonParser`:允许传入 JSON 格式字符串或 Python 字典。
  + `TupleParser`:允许传入字面量为元组的字符串或 Python 元组。
* `name`:调用哪个 API,默认为 `run`。

工具会返回一个封装了调用细节的 `ActionReturn` 对象。

* `args`:一个字典,表示该动作的入参。
* `type`:动作名称。
* `result`:以字典为成员的列表,每个字典包含两个键——'type' 和 'content',发生异常时该字段为 `None`。
* `errmsg`:错误信息,默认为 `None`。

以下是一个例子:

```python
from lagent import IPythonInterpreter, TupleParser

action1 = IPythonInterpreter()
ret = action1('{"command": "import math;math.sqrt(100)"}')
print(ret.result)
ret = action1({'command': 'import math;math.sqrt(100)'})
print(ret.result)

action2 = IPythonInterpreter(parser=TupleParser)
ret = action2('("import math;math.sqrt(100)", )')
print(ret.result)
ret = action2(('import math;math.sqrt(100)',))
print(ret.result)
```

```python
[{'type': 'text', 'content': '10.0'}]
[{'type': 'text', 'content': '10.0'}]
[{'type': 'text', 'content': '10.0'}]
[{'type': 'text', 'content': '10.0'}]
```
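
对于工具包,可通过 `name` 参数选择要执行的 API。以下示例复用前文定义的 `PhraseEmphasis` 工具包(输出为预期结果,未经验证):

```python
emphasis = PhraseEmphasis()
ret = emphasis('{"text": "Lagent"}', name='italic')
print(ret.result)  # 预期结果:[{'type': 'text', 'content': '*Lagent*'}]
```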
### 动态触发

Lagent 提供 `ActionExecutor` 接口管理多个工具,它会将工具包的 `api_list` 平展,并将各 API 更名为 `{tool_name}.{api_name}`。

```python
from lagent import ActionExecutor, ArxivSearch, IPythonInterpreter

executor = ActionExecutor(actions=[ArxivSearch(), IPythonInterpreter()])
executor.get_actions_info()  # 该结果会作为 LLM 系统提示词的一部分
```

```python
[{'name': 'ArxivSearch.get_arxiv_article_information',
  'description': 'Run Arxiv search and get the article meta information.',
  'parameters': [{'name': 'query',
    'type': 'STRING',
    'description': 'the content of search query'}],
  'required': ['query'],
  'return_data': [{'name': 'content',
    'description': 'a list of 3 arxiv search papers',
    'type': 'STRING'}],
  'parameter_description': '如果调用该工具你必须使用Json格式 {key: value} 传参其中key为参数名称'},
 {'name': 'IPythonInterpreter',
  'description': "When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files. Internet access for this session is disabled. Do not make external web requests or API calls as they will fail.",
  'parameters': [{'name': 'command',
    'type': 'STRING',
    'description': 'Python code'},
   {'name': 'timeout',
    'type': 'NUMBER',
    'description': 'Upper bound of waiting time for Python script execution.'}],
  'required': ['command'],
  'parameter_description': '如果调用该工具你必须使用Json格式 {key: value} 传参其中key为参数名称'}]
```

通过动作执行器来触发一个工具:

```python
ret = executor('IPythonInterpreter', '{"command": "import math;math.sqrt(100)"}')
ret.result
```

```python
[{'type': 'text', 'content': '10.0'}]
```
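
工具包的 API 通过平展后的名称触发,下面是一个使用假设查询的示例:

```python
ret = executor('ArxivSearch.get_arxiv_article_information',
               '{"query": "large language model agents"}')
print(ret.result)
```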


@@ -0,0 +1,333 @@
import copy
import hashlib
import json
import os

import streamlit as st

from lagent.actions import ActionExecutor, ArxivSearch, GoogleScholar, IPythonInterpreter
from lagent.agents.internlm2_agent import (INTERPRETER_CN, META_INS, PLUGIN_CN,
                                           Internlm2Agent, Interlm2Protocol)
from lagent.llms.lmdepoly_wrapper import LMDeployClient
from lagent.llms.meta_template import INTERNLM2_META as META
from lagent.schema import AgentStatusCode

# from streamlit.logger import get_logger


class SessionState:

    def init_state(self):
        """Initialize session state variables."""
        st.session_state['assistant'] = []
        st.session_state['user'] = []

        action_list = [
            GoogleScholar(
                api_key=('a558de7dee10146326ca86fbaa0736b'
                         'dd947c9e646cd3f14da5aff177d6b2ff0')),
            ArxivSearch(),
        ]
        st.session_state['plugin_map'] = {
            action.name: action
            for action in action_list
        }
        st.session_state['model_map'] = {}
        st.session_state['model_selected'] = None
        st.session_state['plugin_actions'] = set()
        st.session_state['history'] = []

    def clear_state(self):
        """Clear the existing session state."""
        st.session_state['assistant'] = []
        st.session_state['user'] = []
        st.session_state['model_selected'] = None
        st.session_state['file'] = set()
        if 'chatbot' in st.session_state:
            st.session_state['chatbot']._session_history = []


class StreamlitUI:

    def __init__(self, session_state: SessionState):
        self.init_streamlit()
        self.session_state = session_state

    def init_streamlit(self):
        """Initialize Streamlit's UI settings."""
        st.set_page_config(
            layout='wide',
            page_title='lagent-web',
            page_icon='./docs/imgs/lagent_icon.png')
        st.header(':robot_face: :blue[Lagent] Web Demo ', divider='rainbow')
        st.sidebar.title('模型控制')
        st.session_state['file'] = set()
        st.session_state['ip'] = None

    def setup_sidebar(self):
        """Setup the sidebar for model and plugin selection."""
        model_name = st.sidebar.selectbox('模型选择:', options=['internlm'])
        meta_prompt = st.sidebar.text_area('系统提示词', value=META_INS)
        da_prompt = st.sidebar.text_area('数据分析提示词', value=INTERPRETER_CN)
        plugin_prompt = st.sidebar.text_area('插件提示词', value=PLUGIN_CN)
        model_ip = st.sidebar.text_input('模型IP', value='10.140.0.220:23333')
        if model_name != st.session_state[
                'model_selected'] or st.session_state['ip'] != model_ip:
            st.session_state['ip'] = model_ip
            model = self.init_model(model_name, model_ip)
            self.session_state.clear_state()
            st.session_state['model_selected'] = model_name
            if 'chatbot' in st.session_state:
                del st.session_state['chatbot']
        else:
            model = st.session_state['model_map'][model_name]

        plugin_name = st.sidebar.multiselect(
            '插件选择',
            options=list(st.session_state['plugin_map'].keys()),
            default=[],
        )
        da_flag = st.sidebar.checkbox(
            '数据分析',
            value=False,
        )
        plugin_action = [
            st.session_state['plugin_map'][name] for name in plugin_name
        ]

        if 'chatbot' in st.session_state:
            if len(plugin_action) > 0:
                st.session_state['chatbot']._action_executor = ActionExecutor(
                    actions=plugin_action)
            else:
                st.session_state['chatbot']._action_executor = None
            if da_flag:
                st.session_state[
                    'chatbot']._interpreter_executor = ActionExecutor(
                        actions=[IPythonInterpreter()])
            else:
                st.session_state['chatbot']._interpreter_executor = None
            st.session_state['chatbot']._protocol._meta_template = meta_prompt
            st.session_state['chatbot']._protocol.plugin_prompt = plugin_prompt
            st.session_state[
                'chatbot']._protocol.interpreter_prompt = da_prompt

        if st.sidebar.button('清空对话', key='clear'):
            self.session_state.clear_state()
        uploaded_file = st.sidebar.file_uploader('上传文件')

        return model_name, model, plugin_action, uploaded_file, model_ip

    def init_model(self, option, ip=None):
        """Initialize the model based on the selected option."""
        model_url = f'http://{ip}'
        st.session_state['model_map'][option] = LMDeployClient(
            path='internlm2-chat-20b',
            url=model_url,
            meta_template=META,
            top_p=0.8,
            top_k=100,
            temperature=0,
            repetition_penalty=1.0,
            stop_words=['<|im_end|>'])
        return st.session_state['model_map'][option]

    def initialize_chatbot(self, model, plugin_action):
        """Initialize the chatbot with the given model and plugin actions."""
        return Internlm2Agent(
            llm=model,
            protocol=Interlm2Protocol(
                tool=dict(
                    begin='{start_token}{name}\n',
                    start_token='<|action_start|>',
                    name_map=dict(
                        plugin='<|plugin|>', interpreter='<|interpreter|>'),
                    belong='assistant',
                    end='<|action_end|>\n',
                ), ),
        )

    def render_user(self, prompt: str):
        with st.chat_message('user'):
            st.markdown(prompt)

    def render_assistant(self, agent_return):
        with st.chat_message('assistant'):
            for action in agent_return.actions:
                if (action) and (action.type != 'FinishAction'):
                    self.render_action(action)
            st.markdown(agent_return.response)

    def render_plugin_args(self, action):
        action_name = action.type
        args = action.args
        parameter_dict = dict(name=action_name, parameters=args)
        parameter_str = '```json\n' + json.dumps(
            parameter_dict, indent=4, ensure_ascii=False) + '\n```'
        st.markdown(parameter_str)

    def render_interpreter_args(self, action):
        st.info(action.type)
        st.markdown(action.args['text'])

    def render_action(self, action):
        st.markdown(action.thought)
        if action.type == 'IPythonInterpreter':
            self.render_interpreter_args(action)
        elif action.type == 'FinishAction':
            pass
        else:
            self.render_plugin_args(action)
        self.render_action_results(action)

    def render_action_results(self, action):
        """Render the results of action, including text, images, videos, and
        audios."""
        if isinstance(action.result, dict):
            if 'text' in action.result:
                st.markdown('```\n' + action.result['text'] + '\n```')
            if 'image' in action.result:
                # image_path = action.result['image']
                for image_path in action.result['image']:
                    image_data = open(image_path, 'rb').read()
                    st.image(image_data, caption='Generated Image')
            if 'video' in action.result:
                video_data = action.result['video']
                video_data = open(video_data, 'rb').read()
                st.video(video_data)
            if 'audio' in action.result:
                audio_data = action.result['audio']
                audio_data = open(audio_data, 'rb').read()
                st.audio(audio_data)
        elif isinstance(action.result, list):
            for item in action.result:
                if item['type'] == 'text':
                    st.markdown('```\n' + item['content'] + '\n```')
                elif item['type'] == 'image':
                    image_data = open(item['content'], 'rb').read()
                    st.image(image_data, caption='Generated Image')
                elif item['type'] == 'video':
                    video_data = open(item['content'], 'rb').read()
                    st.video(video_data)
                elif item['type'] == 'audio':
                    audio_data = open(item['content'], 'rb').read()
                    st.audio(audio_data)
        if action.errmsg:
            st.error(action.errmsg)


def main():
    # logger = get_logger(__name__)
    # Initialize Streamlit UI and setup sidebar
    if 'ui' not in st.session_state:
        session_state = SessionState()
        session_state.init_state()
        st.session_state['ui'] = StreamlitUI(session_state)
    else:
        st.set_page_config(
            layout='wide',
            page_title='lagent-web',
            page_icon='./docs/imgs/lagent_icon.png')
        st.header(':robot_face: :blue[Lagent] Web Demo ', divider='rainbow')
    _, model, plugin_action, uploaded_file, _ = st.session_state[
        'ui'].setup_sidebar()

    # Initialize chatbot if it is not already initialized
    # or if the model has changed
    if 'chatbot' not in st.session_state or model != st.session_state[
            'chatbot']._llm:
        st.session_state['chatbot'] = st.session_state[
            'ui'].initialize_chatbot(model, plugin_action)
        st.session_state['session_history'] = []

    for prompt, agent_return in zip(st.session_state['user'],
                                    st.session_state['assistant']):
        st.session_state['ui'].render_user(prompt)
        st.session_state['ui'].render_assistant(agent_return)

    if user_input := st.chat_input(''):
        with st.container():
            st.session_state['ui'].render_user(user_input)
        st.session_state['user'].append(user_input)
        # Add file uploader to sidebar
        if (uploaded_file
                and uploaded_file.name not in st.session_state['file']):
            st.session_state['file'].add(uploaded_file.name)
            file_bytes = uploaded_file.read()
            file_type = uploaded_file.type
            if 'image' in file_type:
                st.image(file_bytes, caption='Uploaded Image')
            elif 'video' in file_type:
                st.video(file_bytes)  # st.video takes no caption argument
            elif 'audio' in file_type:
                st.audio(file_bytes)  # st.audio takes no caption argument
            # Save the file to a temporary location and get the path
            postfix = uploaded_file.name.split('.')[-1]
            # prefix = str(uuid.uuid4())
            prefix = hashlib.md5(file_bytes).hexdigest()
            filename = f'{prefix}.{postfix}'
            file_path = os.path.join(root_dir, filename)
            with open(file_path, 'wb') as tmpfile:
                tmpfile.write(file_bytes)
            file_size = os.stat(file_path).st_size / 1024 / 1024
            file_size = f'{round(file_size, 2)} MB'
            # st.write(f'File saved at: {file_path}')
            user_input = [
                dict(role='user', content=user_input),
                dict(
                    role='user',
                    content=json.dumps(dict(path=file_path, size=file_size)),
                    name='file')
            ]
        if isinstance(user_input, str):
            user_input = [dict(role='user', content=user_input)]
        st.session_state['last_status'] = AgentStatusCode.SESSION_READY
        for agent_return in st.session_state['chatbot'].stream_chat(
                st.session_state['session_history'] + user_input):
            if agent_return.state == AgentStatusCode.PLUGIN_RETURN:
                with st.container():
                    st.session_state['ui'].render_plugin_args(
                        agent_return.actions[-1])
                    st.session_state['ui'].render_action_results(
                        agent_return.actions[-1])
            elif agent_return.state == AgentStatusCode.CODE_RETURN:
                with st.container():
                    st.session_state['ui'].render_action_results(
                        agent_return.actions[-1])
            elif (agent_return.state == AgentStatusCode.STREAM_ING
                  or agent_return.state == AgentStatusCode.CODING):
                # st.markdown(agent_return.response)
                # Clear the placeholder's current content and show the new one
                with st.container():
                    if agent_return.state != st.session_state['last_status']:
                        st.session_state['temp'] = ''
                        placeholder = st.empty()
                        st.session_state['placeholder'] = placeholder
                    if isinstance(agent_return.response, dict):
                        action = f"\n\n {agent_return.response['name']}: \n\n"
                        action_input = agent_return.response['parameters']
                        if agent_return.response[
                                'name'] == 'IPythonInterpreter':
                            action_input = action_input['command']
                        response = action + action_input
                    else:
                        response = agent_return.response
                    st.session_state['temp'] = response
                    st.session_state['placeholder'].markdown(
                        st.session_state['temp'])
            elif agent_return.state == AgentStatusCode.END:
                st.session_state['session_history'] += (
                    user_input + agent_return.inner_steps)
                agent_return = copy.deepcopy(agent_return)
                agent_return.response = st.session_state['temp']
                st.session_state['assistant'].append(
                    copy.deepcopy(agent_return))
            st.session_state['last_status'] = agent_return.state


if __name__ == '__main__':
    root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    root_dir = os.path.join(root_dir, 'tmp_dir')
    os.makedirs(root_dir, exist_ok=True)
    main()


@@ -1,11 +1,61 @@
from typing import Type
from .action_executor import ActionExecutor
from .base_action import BaseAction
from .arxiv_search import ArxivSearch
from .base_action import TOOL_REGISTRY, BaseAction, tool_api
from .bing_map import BINGMap
from .builtin_actions import FinishAction, InvalidAction, NoAction
from .google_scholar_search import GoogleScholar
from .google_search import GoogleSearch
from .llm_qa import LLMQA
from .ipython_interpreter import IPythonInterpreter
from .parser import BaseParser, JsonParser, TupleParser
from .ppt import PPT
from .python_interpreter import PythonInterpreter
__all__ = [
'BaseAction', 'ActionExecutor', 'InvalidAction', 'NoAction',
'FinishAction', 'GoogleSearch', 'PythonInterpreter', 'LLMQA'
'BaseAction', 'ActionExecutor', 'InvalidAction', 'FinishAction',
'NoAction', 'BINGMap', 'ArxivSearch', 'GoogleSearch',
'GoogleScholar', 'IPythonInterpreter', 'PythonInterpreter', 'PPT',
'BaseParser', 'JsonParser', 'TupleParser', 'tool_api', 'list_tools',
'get_tool_cls', 'get_tool'
]
def list_tools(with_class: bool = False):
"""List available tools
Args:
with_class (bool): whether to return the action class along
with its name. Defaults to ``False``.
Returns:
list: all action names
"""
return list(TOOL_REGISTRY.items()) if with_class else list(
TOOL_REGISTRY.keys())
def get_tool_cls(specifier: str) -> Type[BaseAction]:
"""Get the action class
Args:
specifier (:class:`str`): tool name
Returns:
Type[BaseAction]: action class
"""
return TOOL_REGISTRY.get_class(specifier)
def get_tool(specifier: str, *args, **kwargs) -> BaseAction:
"""Instantiate an action
Args:
specifier (str): tool name
args: positional arguments passed to the action's ``__init__`` method
kwargs: keyword arguments passed to the action's ``__init__`` method
Returns:
:class:`BaseAction`: action object
"""
return TOOL_REGISTRY.get(specifier, *args, **kwargs)
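
A minimal usage sketch of the new registry helpers (illustrative, not part of the diff; the tool name assumes ArxivSearch has been imported and thus registered):

from lagent.actions import get_tool, get_tool_cls, list_tools

print(list_tools())  # registered tool names, e.g. ['ArxivSearch', ...]
print(list_tools(with_class=True))  # (name, class) pairs
searcher_cls = get_tool_cls('ArxivSearch')  # look up the class only
searcher = get_tool('ArxivSearch', top_k_results=1)  # instantiate with init kwargs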


@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Union
from typing import Dict, List, Union
from lagent.schema import ActionReturn, ActionValidCode
from .base_action import BaseAction
@@ -39,14 +39,20 @@ class ActionExecutor:
self.no_action = no_action
self.finish_action = finish_action
def get_actions_info(self, only_enable: bool = True) -> Dict:
if only_enable:
return {
k: v.description
for k, v in self.actions.items() if v.enable
}
else:
return {k: v.description for k, v in self.actions.items()}
def get_actions_info(self) -> List[Dict]:
actions = []
for action_name, action in self.actions.items():
if not action.enable:
continue
if action.is_toolkit:
for api in action.description['api_list']:
api_desc = api.copy()
api_desc['name'] = f"{action_name}.{api_desc['name']}"
actions.append(api_desc)
else:
action_desc = action.description.copy()
actions.append(action_desc)
return actions
def is_valid(self, name: str):
return name in self.actions and self.actions[name].enable
@@ -66,19 +72,17 @@ class ActionExecutor:
if name in self.actions:
del self.actions[name]
def __call__(self, name: str, command: Any) -> ActionReturn:
if isinstance(command, str):
args, kwargs = (command, ), {}
else:
args, kwargs = (), command
if not self.is_valid(name):
def __call__(self, name: str, command: str) -> ActionReturn:
action_name, api_name = (
name.split('.') if '.' in name else (name, 'run'))
if not self.is_valid(action_name):
if name == self.no_action.name:
action_return = self.no_action.run(*args, **kwargs)
action_return = self.no_action(command)
elif name == self.finish_action.name:
action_return = self.finish_action.run(*args, **kwargs)
action_return = self.finish_action(command)
else:
action_return = self.invalid_action(*args, **kwargs)
action_return = self.invalid_action(command)
else:
action_return = self.actions[name].run(*args, **kwargs)
action_return = self.actions[action_name](command, api_name)
action_return.valid = ActionValidCode.OPEN
return action_return
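
A dispatch sketch for the reworked `__call__` (assuming the ArxivSearch tool from this PR is available): a dotted name selects a toolkit API, while a bare name falls back to `run`.

from lagent.actions import ActionExecutor, ArxivSearch

executor = ActionExecutor(actions=[ArxivSearch()])
# the JSON argument string is decoded by the tool's JsonParser
ret = executor('ArxivSearch.get_arxiv_article_information',
               '{"query": "tool-augmented agents"}')
print(ret.result)  # [{'type': 'text', 'content': ...}] on success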


@@ -0,0 +1,56 @@
from typing import Optional, Type
import arxiv
from lagent.actions.base_action import BaseAction, tool_api
from lagent.actions.parser import BaseParser, JsonParser
from lagent.schema import ActionReturn, ActionStatusCode
class ArxivSearch(BaseAction):
"""Search information from Arxiv.org. \
Useful for when you need to answer questions about Physics, Mathematics, \
Computer Science, Quantitative Biology, Quantitative Finance, Statistics, \
Electrical Engineering, and Economics from scientific articles on arxiv.org.
"""
def __init__(self,
top_k_results: int = 3,
max_query_len: int = 300,
doc_content_chars_max: int = 1500,
description: Optional[dict] = None,
parser: Type[BaseParser] = JsonParser,
enable: bool = True):
super().__init__(description, parser, enable)
self.top_k_results = top_k_results
self.max_query_len = max_query_len
self.doc_content_chars_max = doc_content_chars_max
@tool_api(explode_return=True)
def get_arxiv_article_information(self, query: str) -> dict:
"""Run Arxiv search and get the article meta information.
Args:
query (:class:`str`): the content of search query
Returns:
:class:`dict`: article information
* content (str): a list of 3 arxiv search papers
"""
try:
results = arxiv.Search( # type: ignore
query[:self.max_query_len],
max_results=self.top_k_results).results()
except Exception as exc:
return ActionReturn(
errmsg=f'Arxiv exception: {exc}',
state=ActionStatusCode.HTTP_ERROR)
docs = [
f'Published: {result.updated.date()}\nTitle: {result.title}\n'
f'Authors: {", ".join(a.name for a in result.authors)}\n'
f'Summary: {result.summary[:self.doc_content_chars_max]}'
for result in results
]
if docs:
return {'content': '\n\n'.join(docs)}
return {'content': 'No good Arxiv Result was found'}
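
A direct-call sketch (requires the `arxiv` package; the second argument names the API method to invoke):

from lagent.actions import ArxivSearch

action = ArxivSearch(top_k_results=1)
ret = action('{"query": "agent frameworks"}', 'get_arxiv_article_information')
print(ret.result)  # [{'type': 'text', 'content': ...}], or check ret.errmsg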


@@ -1,57 +1,364 @@
from typing import Optional
import inspect
import logging
import re
from abc import ABCMeta
from copy import deepcopy
from functools import wraps
from typing import Annotated, Callable, Optional, Type, get_args, get_origin
from lagent.schema import ActionReturn
from class_registry import AutoRegister, ClassRegistry
from griffe import Docstring
from griffe.enumerations import DocstringSectionKind
from ..schema import ActionReturn, ActionStatusCode
from .parser import BaseParser, JsonParser, ParseError
logging.getLogger('griffe').setLevel(logging.ERROR)
TOOL_REGISTRY = ClassRegistry('__tool_name__', unique=True)
class BaseAction:
def tool_api(func: Optional[Callable] = None,
*,
explode_return: bool = False,
returns_named_value: bool = False,
**kwargs):
"""Turn functions into tools. It will parse typehints as well as docstrings
to build the tool description and attach it to functions via an attribute
``api_description``.
Examples:
.. code-block:: python
# typehints has higher priority than docstrings
from typing import Annotated
@tool_api
def add(a: Annotated[int, 'augend'], b: Annotated[int, 'addend'] = 1):
'''Add operation
Args:
x (int): a
y (int): b
'''
return a + b
print(add.api_description)
Args:
func (Optional[Callable]): function to decorate. Defaults to ``None``.
explode_return (bool): whether to flatten the dictionary or tuple return
as the ``return_data`` field. When enabled, it is recommended to
annotate the member in docstrings. Defaults to ``False``.
.. code-block:: python
@tool_api(explode_return=True)
def foo(a, b):
'''A simple function
Args:
a (int): a
b (int): b
Returns:
dict: information of inputs
* x: value of a
* y: value of b
'''
return {'x': a, 'y': b}
print(foo.api_description)
returns_named_value (bool): whether to parse ``thing: Description`` in
returns sections as a name and description, rather than a type and
description. When true, type must be wrapped in parentheses:
``(int): Description``. When false, parentheses are optional but
the items cannot be named: ``int: Description``. Defaults to ``False``.
Returns:
Callable: wrapped function or partial decorator
Important:
``return_data`` field will be added to ``api_description`` only
when ``explode_return`` or ``returns_named_value`` is enabled.
"""
def _detect_type(string):
field_type = 'STRING'
if 'list' in string:
field_type = 'Array'
elif 'str' not in string:
if 'float' in string:
field_type = 'FLOAT'
elif 'int' in string:
field_type = 'NUMBER'
elif 'bool' in string:
field_type = 'BOOLEAN'
return field_type
def _explode(desc):
kvs = []
desc = '\nArgs:\n' + '\n'.join([
' ' + item.lstrip(' -+*#.')
for item in desc.split('\n')[1:] if item.strip()
])
docs = Docstring(desc).parse('google')
if not docs:
return kvs
if docs[0].kind is DocstringSectionKind.parameters:
for d in docs[0].value:
d = d.as_dict()
if not d['annotation']:
d.pop('annotation')
else:
d['type'] = _detect_type(d.pop('annotation').lower())
kvs.append(d)
return kvs
def _parse_tool(function):
# remove rst syntax
docs = Docstring(
re.sub(':(.+?):`(.+?)`', '\\2', function.__doc__ or '')).parse(
'google', returns_named_value=returns_named_value, **kwargs)
desc = dict(
name=function.__name__,
description=docs[0].value
if docs[0].kind is DocstringSectionKind.text else '',
parameters=[],
required=[],
)
args_doc, returns_doc = {}, []
for doc in docs:
if doc.kind is DocstringSectionKind.parameters:
for d in doc.value:
d = d.as_dict()
d['type'] = _detect_type(d.pop('annotation').lower())
args_doc[d['name']] = d
if doc.kind is DocstringSectionKind.returns:
for d in doc.value:
d = d.as_dict()
if not d['name']:
d.pop('name')
if not d['annotation']:
d.pop('annotation')
else:
d['type'] = _detect_type(d.pop('annotation').lower())
returns_doc.append(d)
sig = inspect.signature(function)
for name, param in sig.parameters.items():
if name == 'self':
continue
parameter = dict(
name=param.name,
type='STRING',
description=args_doc.get(param.name,
{}).get('description', ''))
annotation = param.annotation
if annotation is inspect.Signature.empty:
parameter['type'] = args_doc.get(param.name,
{}).get('type', 'STRING')
else:
if get_origin(annotation) is Annotated:
annotation, info = get_args(annotation)
if info:
parameter['description'] = info
while get_origin(annotation):
annotation = get_args(annotation)
parameter['type'] = _detect_type(str(annotation))
desc['parameters'].append(parameter)
if param.default is inspect.Signature.empty:
desc['required'].append(param.name)
return_data = []
if explode_return:
return_data = _explode(returns_doc[0]['description'])
elif returns_named_value:
return_data = returns_doc
if return_data:
desc['return_data'] = return_data
return desc
if callable(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
return func(self, *args, **kwargs)
wrapper.api_description = _parse_tool(func)
return wrapper
def decorate(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
return func(self, *args, **kwargs)
wrapper.api_description = _parse_tool(func)
return wrapper
return decorate
class ToolMeta(ABCMeta):
"""Metaclass of tools"""
def __new__(mcs, name, base, attrs):
is_toolkit, tool_desc = True, dict(
name=attrs.setdefault('__tool_name__', name),
description=Docstring(attrs.get('__doc__',
'')).parse('google')[0].value)
for key, value in attrs.items():
if callable(value) and hasattr(value, 'api_description'):
api_desc = getattr(value, 'api_description')
if key == 'run':
tool_desc['parameters'] = api_desc['parameters']
tool_desc['required'] = api_desc['required']
if api_desc['description']:
tool_desc['description'] = api_desc['description']
if api_desc.get('return_data'):
tool_desc['return_data'] = api_desc['return_data']
is_toolkit = False
break
tool_desc.setdefault('api_list', []).append(api_desc)
attrs['_is_toolkit'] = is_toolkit
attrs['__tool_description__'] = tool_desc
return super().__new__(mcs, name, base, attrs)
class BaseAction(metaclass=AutoRegister(TOOL_REGISTRY, ToolMeta)):
"""Base class for all actions.
Args:
description (str, optional): The description of the action. Defaults to
None.
name (str, optional): The name of the action. If None, the name will
be class name. Defaults to None.
enable (bool, optional): Whether the action is enabled. Defaults to
True.
disable_description (str, optional): The description of the action when
it is disabled. Defaults to None.
description (:class:`Optional[dict]`): The description of the action.
Defaults to ``None``.
parser (:class:`Type[BaseParser]`): The parser class to process the
action's inputs and outputs. Defaults to :class:`JsonParser`.
enable (:class:`bool`): Whether the action is enabled. Defaults to
``True``.
Examples:
* simple tool
.. code-block:: python
class Bold(BaseAction):
'''Make text bold'''
@tool_api
def run(self, text: str):
'''
Args:
text (str): input text
Returns:
str: bold text
'''
return '**' + text + '**'
action = Bold()
* toolkit with multiple APIs
.. code-block:: python
class Calculator(BaseAction):
'''Calculator'''
@tool_api
def add(self, a, b):
'''Add operation
Args:
a (int): augend
b (int): addend
Returns:
int: sum
'''
return a + b
@tool_api
def sub(self, a, b):
'''Subtraction operation
Args:
a (int): minuend
b (int): subtrahend
Returns:
int: difference
'''
return a - b
action = Calculator()
"""
def __init__(self,
description: Optional[str] = None,
name: Optional[str] = None,
enable: bool = True,
disable_description: Optional[str] = None) -> None:
if name is None:
name = self.__class__.__name__
self._name = name
self._description = description
self._disable_description = disable_description
description: Optional[dict] = None,
parser: Type[BaseParser] = JsonParser,
enable: bool = True):
self._description = deepcopy(description or self.__tool_description__)
self._name = self._description['name']
self._parser = parser(self)
self._enable = enable
def __call__(self, *args, **kwargs) -> ActionReturn:
raise NotImplementedError
def __repr__(self):
return f'{self.name}:{self.description}'
def __str__(self):
return self.__repr__()
def run(self, *args, **kwargs) -> ActionReturn:
return self.__call__(*args, **kwargs)
@property
def enable(self):
return self._enable
def __call__(self, inputs: str, name='run') -> ActionReturn:
fallback_args = {'inputs': inputs, 'name': name}
if not hasattr(self, name):
return ActionReturn(
fallback_args,
type=self.name,
errmsg=f'invalid API: {name}',
state=ActionStatusCode.API_ERROR)
try:
inputs = self._parser.parse_inputs(inputs, name)
except ParseError as exc:
return ActionReturn(
fallback_args,
type=self.name,
errmsg=exc.err_msg,
state=ActionStatusCode.ARGS_ERROR)
try:
outputs = getattr(self, name)(**inputs)
except Exception as exc:
return ActionReturn(
inputs,
type=self.name,
errmsg=str(exc),
state=ActionStatusCode.API_ERROR)
if isinstance(outputs, ActionReturn):
action_return = outputs
if not action_return.args:
action_return.args = inputs
if not action_return.type:
action_return.type = self.name
else:
result = self._parser.parse_outputs(outputs)
action_return = ActionReturn(inputs, type=self.name, result=result)
return action_return
@property
def name(self):
return self._name
@property
def description(self):
if self.enable:
return self._description
else:
return self._disable_description
def enable(self):
return self._enable
@property
def is_toolkit(self):
return self._is_toolkit
@property
def description(self) -> dict:
"""Description of the tool"""
return self._description
def __repr__(self):
return f'{self.description}'
__str__ = __repr__
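
An end-to-end sketch of the new tool protocol, mirroring the Bold example from the docstring above: `tool_api` builds the schema from typehints and the docstring, and `__call__` parses a JSON argument string before dispatching.

from lagent.actions.base_action import BaseAction, tool_api

class Bold(BaseAction):
    """Make text bold"""

    @tool_api
    def run(self, text: str):
        """
        Args:
            text (str): input text

        Returns:
            str: bold text
        """
        return '**' + text + '**'

action = Bold()
print(action.description)  # schema auto-generated from the docstring
ret = action('{"text": "hello"}')  # the default API name is 'run'
print(ret.result)  # [{'type': 'text', 'content': '**hello**'}]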

lagent/actions/bing_map.py

@@ -0,0 +1,145 @@
import json
import os
from typing import Optional, Type
import requests
from lagent.actions.base_action import BaseAction, tool_api
from lagent.actions.parser import BaseParser, JsonParser
class BINGMap(BaseAction):
"""BING Map plugin for looking up map information"""
def __init__(self,
key: Optional[str] = None,
description: Optional[dict] = None,
parser: Type[BaseParser] = JsonParser,
enable: bool = True) -> None:
super().__init__(description, parser, enable)
key = os.environ.get('BING_MAP_KEY', key)
if key is None:
raise ValueError(
'Please set BING Map API key either in the environment '
'as BING_MAP_KEY or pass it as `key` parameter.')
self.key = key
self.base_url = 'http://dev.virtualearth.net/REST/V1/'
@tool_api(explode_return=True)
def get_distance(self, start: str, end: str) -> dict:
"""Get the distance between two locations in km.
Args:
start (:class:`str`): The start location
end (:class:`str`): The end location
Returns:
:class:`dict`: distance information
* distance (str): the distance in km.
"""
# Request URL
url = self.base_url + 'Routes/Driving?o=json&wp.0=' + start + '&wp.1=' + end + '&key=' + self.key
# GET request
r = requests.get(url)
# TODO check request status?
data = json.loads(r.text)
# Extract route information
route = data['resourceSets'][0]['resources'][0]
# Extract the travel distance (km, per the docstring)
distance = route['travelDistance']
return dict(distance=distance)
@tool_api(explode_return=True)
def get_route(self, start: str, end: str) -> dict:
"""Get the route between two locations in km.
Args:
start (:class:`str`): The start location
end (:class:`str`): The end location
Returns:
:class:`dict`: route information
* route (list): the route, a list of actions.
"""
# Request URL
url = self.base_url + 'Routes/Driving?o=json&wp.0=' + start + '&wp.1=' + end + '&key=' + self.key
# GET request
r = requests.get(url)
data = json.loads(r.text)
# Extract route information
route = data['resourceSets'][0]['resources'][0]
itinerary = route['routeLegs'][0]['itineraryItems']
# Extract route text information
route_text = []
for item in itinerary:
if 'instruction' in item:
route_text.append(item['instruction']['text'])
return dict(route=route_text)
@tool_api(explode_return=True)
def get_coordinates(self, location: str) -> dict:
"""Get the coordinates of a location.
Args:
location (:class:`str`): the location to get coordinates for.
Returns:
:class:`dict`: coordinates information
* latitude (float): the latitude of the location.
* longitude (float): the longitude of the location.
"""
url = self.base_url + 'Locations'
params = {'query': location, 'key': self.key}
response = requests.get(url, params=params)
json_data = response.json()
coordinates = json_data['resourceSets'][0]['resources'][0]['point'][
'coordinates']
return dict(latitude=coordinates[0], longitude=coordinates[1])
@tool_api(explode_return=True)
def search_nearby(self,
search_term: str,
places: str = 'unknown',
latitude: float = 0.0,
longitude: float = 0.0,
radius: int = 5000) -> dict: # radius in meters
"""Search for places nearby a location, within a given radius, and \
return the results into a list. You can use either the places name or the \
latitude and longitude.
Args:
search_term (:class:`str`): the place name.
places (:class:`str`): the name of the location. Defaults to ``'unknown'``.
latitude (:class:`float`): the latitude of the location. Defaults to ``0.0``.
longitude (:class:`float`): the longitude of the location. Defaults to ``0.0``.
radius (:class:`int`): radius in meters. Defaults to ``5000``.
Returns:
:class:`dict`: places information
* places (list): the list of places, each place is a dict with name and address, at most 5 places.
"""
url = self.base_url + 'LocalSearch'
if places != 'unknown':
pos = self.get_coordinates(location=places)
latitude, longitude = pos['latitude'], pos['longitude']
# Build the request query string
params = {
'query': search_term,
'userLocation': f'{latitude},{longitude}',
'radius': radius,
'key': self.key
}
# Make the request
response = requests.get(url, params=params)
# Parse the response
response_data = json.loads(response.content)
# Get the results
results = response_data['resourceSets'][0]['resources']
addresses = []
for result in results:
name = result['name']
address = result['Address']['formattedAddress']
addresses.append(dict(name=name, address=address))
if len(addresses) == 5:
break
return dict(places=addresses)
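
A usage sketch (needs a valid key in the BING_MAP_KEY environment variable; the key below is a placeholder):

import os
os.environ.setdefault('BING_MAP_KEY', '<your-key>')

from lagent.actions import BINGMap

bing_map = BINGMap()
ret = bing_map('{"start": "Shanghai", "end": "Beijing"}', 'get_distance')
print(ret.result)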


@@ -1,6 +1,7 @@
from typing import Optional
from lagent.actions.base_action import BaseAction
from lagent.actions.base_action import BaseAction, tool_api
from lagent.actions.parser import BaseParser
from lagent.schema import ActionReturn, ActionStatusCode, ActionValidCode
@@ -19,12 +20,13 @@ class InvalidAction(BaseAction):
def __init__(self,
err_msg:
str = 'The action is invalid, please check the action name.',
**kwargs) -> None:
super().__init__(enable=False, **kwargs)
description: Optional[dict] = None,
parser=BaseParser) -> None:
super().__init__(description, parser, enable=False)
self._err_msg = err_msg
def __call__(self, err_msg: Optional[str] = None):
@tool_api
def run(self, err_msg: Optional[str] = None) -> ActionReturn:
"""Return the error message.
Args:
@@ -35,7 +37,7 @@ class InvalidAction(BaseAction):
action_return = ActionReturn(
url=None,
args=dict(text=err_msg),
errmsg=err_msg if err_msg else self._err_msg,
errmsg=err_msg or self._err_msg,
type=self.name,
valid=ActionValidCode.INVALID,
state=ActionStatusCode.API_ERROR)
@@ -51,12 +53,15 @@ class NoAction(BaseAction):
'Please follow the format'.
"""
def __init__(self, err_msg: str = 'Please follow the format', **kwargs):
super().__init__(enable=False, **kwargs)
def __init__(self,
err_msg: str = 'Please follow the format',
description: Optional[dict] = None,
parser=BaseParser):
super().__init__(description, parser, enable=False)
self._err_msg = err_msg
def __call__(self, err_msg: Optional[str] = None):
@tool_api
def run(self, err_msg: Optional[str] = None) -> ActionReturn:
"""Return the error message.
Args:
@@ -71,7 +76,7 @@ class NoAction(BaseAction):
url=None,
args=dict(text=err_msg),
type=self.name,
errmsg=err_msg if err_msg else self._err_msg,
errmsg=err_msg or self._err_msg,
valid=ActionValidCode.INVALID,
state=ActionStatusCode.API_ERROR)
return action_return
@@ -81,7 +86,11 @@ class FinishAction(BaseAction):
"""This is a finish action class, which is used to return the final
result."""
def __call__(self, response: str) -> ActionReturn:
def __init__(self, description: Optional[dict] = None, parser=BaseParser):
super().__init__(description, parser, enable=True)
@tool_api
def run(self, response: str) -> ActionReturn:
"""Return the final result.
Args:
@@ -93,7 +102,7 @@ class FinishAction(BaseAction):
action_return = ActionReturn(
url=None,
args=dict(text=response),
result=dict(text=response),
result=[dict(type='text', content=response)],
type=self.name,
valid=ActionValidCode.FINISH,
state=ActionStatusCode.SUCCESS)


@@ -0,0 +1,267 @@
import os
from typing import Optional, Type
from serpapi import GoogleSearch
from lagent.actions.base_action import BaseAction, tool_api
from lagent.schema import ActionReturn, ActionStatusCode
from .parser import BaseParser, JsonParser
class GoogleScholar(BaseAction):
"""Plugin for google scholar search
Args:
api_key (str): API key used by the Serper Google Search API.
You can create a free API key at https://serper.dev.
description (dict): The description of the action. Defaults to ``None``.
parser (Type[BaseParser]): The parser class to process the
action's inputs and outputs. Defaults to :class:`JsonParser`.
enable (bool, optional): Whether the action is enabled. Defaults to
True.
"""
def __init__(self,
api_key: Optional[str] = None,
description: Optional[dict] = None,
parser: Type[BaseParser] = JsonParser,
enable: bool = True):
super().__init__(description, parser, enable)
api_key = os.environ.get('SERPER_API_KEY', api_key)
if api_key is None:
raise ValueError(
'Please set Serper API key either in the environment '
'as SERPER_API_KEY or pass it as `api_key` parameter.')
self.api_key = api_key
@tool_api(explode_return=True)
def search_google_scholar(
self,
query: str,
cites: Optional[str] = None,
as_ylo: Optional[int] = None,
as_yhi: Optional[int] = None,
scisbd: Optional[int] = None,
cluster: Optional[str] = None,
hl: Optional[str] = None,
lr: Optional[str] = None,
start: Optional[int] = None,
num: Optional[int] = None,
as_sdt: Optional[str] = None,
safe: Optional[str] = None,
filter: Optional[str] = None,
as_vis: Optional[str] = None,
) -> dict:
"""Search for scholarly articles based on a query according to the google scholar
Args:
query (str): The query to search for.
cites (Optional[str]): The unique ID of an article for triggering "Cited By" searches.
as_ylo (Optional[int]): The starting year for results (e.g., if as_ylo=2018, results before this year will be omitted).
as_yhi (Optional[int]): The ending year for results (e.g., if as_yhi=2018, results after this year will be omitted).
scisbd (Optional[int]): Defines articles added in the last year, sorted by date. It can be set to 1 to include only abstracts, or 2 to include everything.
cluster (Optional[str]): The unique ID of an article for triggering "All Versions" searches.
hl (Optional[str]): The language to use for the Google Scholar search.
lr (Optional[str]): One or multiple languages to limit the search to.
start (Optional[int]): The result offset for pagination (0 is the first page of results, 10 is the 2nd page, etc.)
num (Optional[int]): The maximum number of results to return, limited to 20.
as_sdt (Optional[str]): Can be used either as a search type or a filter.
safe (Optional[str]): The level of filtering for adult content.
filter (Optional[str]): Defines if the filters for 'Similar Results' and 'Omitted Results' are on or off.
as_vis (Optional[str]): Defines whether to include citations or not.
Returns:
:class:`dict`: article information
* title: a list of the titles of the three selected papers
* cited_by: a list of the citation numbers of the three selected papers
* organic_id: a list of the organic results' ids of the three selected papers
* snippets: a list of the snippets of the three selected papers
"""
params = {
'q': query,
'engine': 'google_scholar',
'api_key': self.api_key,
'cites': cites,
'as_ylo': as_ylo,
'as_yhi': as_yhi,
'scisbd': scisbd,
'cluster': cluster,
'hl': hl,
'lr': lr,
'start': start,
'num': num,
'as_sdt': as_sdt,
'safe': safe,
'filter': filter,
'as_vis': as_vis
}
search = GoogleSearch(params)
try:
r = search.get_dict()
results = r['organic_results']
title = []
snippets = []
cited_by = []
organic_id = []
pub_info = []
for item in results[:3]:
title.append(item['title'])
pub_info.append(item['publication_info']['summary'])
citation = item['inline_links'].get('cited_by', {'total': ''})
cited_by.append(citation['total'])
snippets.append(item['snippet'])
organic_id.append(item['result_id'])
return dict(
title=title,
cited_by=cited_by,
organic_id=organic_id,
snippets=snippets)
except Exception as e:
return ActionReturn(
errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)
@tool_api(explode_return=True)
def get_author_information(self,
author_id: str,
hl: Optional[str] = None,
view_op: Optional[str] = None,
sort: Optional[str] = None,
citation_id: Optional[str] = None,
start: Optional[int] = None,
num: Optional[int] = None,
no_cache: Optional[bool] = None,
async_req: Optional[bool] = None,
output: Optional[str] = None) -> dict:
"""Search for an author's information by author's id provided by get_author_id.
Args:
author_id (str): Required. The ID of an author.
hl (Optional[str]): The language to use for the Google Scholar Author search. Default is 'en'.
view_op (Optional[str]): Used for viewing specific parts of a page.
sort (Optional[str]): Used for sorting and refining articles.
citation_id (Optional[str]): Used for retrieving individual article citation.
start (Optional[int]): Defines the result offset. Default is 0.
num (Optional[int]): Defines the number of results to return. Default is 20.
no_cache (Optional[bool]): Forces SerpApi to fetch the results even if a cached version is already present. Default is False.
async_req (Optional[bool]): Defines the way you want to submit your search to SerpApi. Default is False.
output (Optional[str]): Defines the final output you want. Default is 'json'.
Returns:
:class:`dict`: author information
* name: author's name
* affiliations: the affiliation of the author
* articles: at most 3 articles by the author
* website: the author's homepage url
"""
params = {
'engine': 'google_scholar_author',
'author_id': author_id,
'api_key': self.api_key,
'hl': hl,
'view_op': view_op,
'sort': sort,
'citation_id': citation_id,
'start': start,
'num': num,
'no_cache': no_cache,
'async': async_req,
'output': output
}
try:
search = GoogleSearch(params)
results = search.get_dict()
author = results['author']
articles = results.get('articles', [])
return dict(
name=author['name'],
affiliations=author.get('affiliations', ''),
website=author.get('website', ''),
articles=[
dict(title=article['title'], authors=article['authors'])
for article in articles[:3]
])
except Exception as e:
return ActionReturn(
errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)
@tool_api(explode_return=True)
def get_citation_format(self,
q: str,
no_cache: Optional[bool] = None,
async_: Optional[bool] = None,
output: Optional[str] = 'json') -> dict:
"""Function to get MLA citation format by an identification of organic_result's id provided by search_google_scholar.
Args:
q (str): ID of an individual Google Scholar organic search result.
no_cache (Optional[bool]): If set to True, will force SerpApi to fetch the Google Scholar Cite results even if a cached version is already present. Defaults to None.
async_ (Optional[bool]): If set to True, will submit search to SerpApi and retrieve results later. Defaults to None.
output (Optional[str]): Final output format. Set to 'json' to get a structured JSON of the results, or 'html' to get the raw html retrieved. Defaults to 'json'.
Returns:
:class:`dict`: citation format
* authors: the authors of the article
* citation: the citation format of the article
"""
params = {
'q': q,
'engine': 'google_scholar_cite',
'api_key': self.api_key,
'no_cache': no_cache,
'async': async_,
'output': output
}
try:
search = GoogleSearch(params)
results = search.get_dict()
citation = results['citations']
citation_info = citation[0]['snippet']
return citation_info
except Exception as e:
return ActionReturn(
errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)
@tool_api(explode_return=True)
def get_author_id(self,
mauthors: str,
hl: Optional[str] = 'en',
after_author: Optional[str] = None,
before_author: Optional[str] = None,
no_cache: Optional[bool] = False,
_async: Optional[bool] = False,
output: Optional[str] = 'json') -> dict:
"""The getAuthorId function is used to get the author's id by his or her name.
Args:
mauthors (str): Defines the author you want to search for.
hl (Optional[str]): Defines the language to use for the Google Scholar Profiles search. It's a two-letter language code. (e.g., 'en' for English, 'es' for Spanish, or 'fr' for French). Defaults to 'en'.
after_author (Optional[str]): Defines the next page token. It is used for retrieving the next page results. The parameter has the precedence over before_author parameter. Defaults to None.
before_author (Optional[str]): Defines the previous page token. It is used for retrieving the previous page results. Defaults to None.
no_cache (Optional[bool]): Will force SerpApi to fetch the Google Scholar Profiles results even if a cached version is already present. Defaults to False.
_async (Optional[bool]): Defines the way you want to submit your search to SerpApi. Defaults to False.
output (Optional[str]): Defines the final output you want. It can be set to 'json' (default) to get a structured JSON of the results, or 'html' to get the raw html retrieved. Defaults to 'json'.
Returns:
:class:`dict`: author id
* author_id: the author_id of the author
"""
params = {
'mauthors': mauthors,
'engine': 'google_scholar_profiles',
'api_key': self.api_key,
'hl': hl,
'after_author': after_author,
'before_author': before_author,
'no_cache': no_cache,
'async': _async,
'output': output
}
try:
search = GoogleSearch(params)
results = search.get_dict()
profile = results['profiles']
author_info = dict(author_id=profile[0]['author_id'])
return author_info
except Exception as e:
return ActionReturn(
errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)
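
A chained-call sketch (needs SERPER_API_KEY): first resolve an author ID, which can then be fed to get_author_information.

from lagent.actions import GoogleScholar

scholar = GoogleScholar()  # reads SERPER_API_KEY from the environment
ret = scholar('{"mauthors": "Yann LeCun"}', 'get_author_id')
print(ret.result)  # contains the author_id on success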


@@ -1,15 +1,11 @@
import os
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Tuple, Type, Union
import requests
from lagent.schema import ActionReturn, ActionStatusCode
from .base_action import BaseAction
DEFAULT_DESCRIPTION = """一个可以从谷歌搜索结果的API。
当你需要对于一个特定问题找到简短明了的回答时,可以使用它。
输入应该是一个搜索查询。
"""
from .base_action import BaseAction, tool_api
from .parser import BaseParser, JsonParser
class GoogleSearch(BaseAction):
@@ -28,15 +24,10 @@ class GoogleSearch(BaseAction):
timeout (int): Upper bound of waiting time for a serper request.
search_type (str): Serper API support ['search', 'images', 'news',
'places'] types of search, currently we only support 'search'.
k (int): select first k results in the search results as response.
description (str): The description of the action. Defaults to
None.
name (str, optional): The name of the action. If None, the name will
be class name. Defaults to None.
enable (bool, optional): Whether the action is enabled. Defaults to
True.
disable_description (str, optional): The description of the action when
it is disabled. Defaults to None.
description (dict): The description of the action. Defaults to ``None``.
parser (Type[BaseParser]): The parser class to process the
action's inputs and outputs. Defaults to :class:`JsonParser`.
enable (bool): Whether the action is enabled. Defaults to ``True``.
"""
result_key_for_type = {
'news': 'news',
@@ -49,36 +40,29 @@ class GoogleSearch(BaseAction):
api_key: Optional[str] = None,
timeout: int = 5,
search_type: str = 'search',
k: int = 10,
description: str = DEFAULT_DESCRIPTION,
name: Optional[str] = None,
enable: bool = True,
disable_description: Optional[str] = None) -> None:
super().__init__(description, name, enable, disable_description)
description: Optional[dict] = None,
parser: Type[BaseParser] = JsonParser,
enable: bool = True):
super().__init__(description, parser, enable)
api_key = os.environ.get('SERPER_API_KEY', api_key)
if api_key is None:
raise ValueError(
'Please set Serper API key either in the environment '
' as SERPER_API_KEY or pass it as `api_key` parameter.')
'as SERPER_API_KEY or pass it as `api_key` parameter.')
self.api_key = api_key
self.timeout = timeout
self.search_type = search_type
self.k = k
def __call__(self, query: str) -> ActionReturn:
"""Return the search response.
@tool_api
def run(self, query: str, k: int = 10) -> ActionReturn:
"""一个可以从谷歌搜索结果的API。当你需要对于一个特定问题找到简短明了的回答时可以使用它。输入应该是一个搜索查询。
Args:
query (str): The search content.
Returns:
ActionReturn: The action return.
query (str): the search content
k (int): select first k results in the search results as response
"""
tool_return = ActionReturn(url=None, args=None, type=self.name)
status_code, response = self._search(
query, search_type=self.search_type, k=self.k)
tool_return = ActionReturn(type=self.name)
status_code, response = self._search(query, k=k)
# convert search results to ToolReturn format
if status_code == -1:
tool_return.errmsg = response
@@ -139,7 +123,7 @@ class GoogleSearch(BaseAction):
def _search(self,
search_term: str,
search_type: str = 'search',
search_type: Optional[str] = None,
**kwargs) -> Tuple[int, Union[dict, str]]:
"""HTTP requests to Serper API.
@@ -166,7 +150,7 @@ class GoogleSearch(BaseAction):
}
try:
response = requests.post(
f'https://google.serper.dev/{search_type}',
f'https://google.serper.dev/{search_type or self.search_type}',
headers=headers,
params=params,
timeout=self.timeout)
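
A usage sketch highlighting that `k` is now a per-call argument rather than a constructor field:

from lagent.actions import GoogleSearch

search = GoogleSearch()  # SERPER_API_KEY read from the environment
ret = search('{"query": "lagent agent framework", "k": 3}')  # bare call -> run
print(ret.result)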


@@ -0,0 +1,296 @@
import base64
import io
import logging
import os
import queue
import re
import signal
import sys
import traceback
import uuid
from typing import Optional, Tuple, Type
import json5
import PIL.Image
from jupyter_client import KernelManager
from lagent.actions.base_action import BaseAction, tool_api
from lagent.actions.parser import BaseParser, JsonParser
from lagent.schema import ActionReturn, ActionStatusCode
START_CODE = """
def input(*args, **kwargs):
raise NotImplementedError('Python input() function is disabled.')
get_ipython().system = lambda *args: print('Assume we have this package, ! is disabled!')
{}
""" # noqa
class TimeoutError(Exception):
pass
class IPythonInterpreter(BaseAction):
"""A IPython executor that can execute Python scripts in a jupyter manner.
Args:
timeout (int): Upper bound of waiting time for Python script execution.
Defaults to 20.
user_data_dir (str, optional): Specifies the user data directory for file
loading. If set to `ENV`, the `USER_DATA_DIR` environment variable is used.
Defaults to `ENV`.
work_dir (str, optional): Specify which directory to save output images to.
Defaults to ``'./work_dir/tmp_dir'``.
description (dict): The description of the action. Defaults to ``None``.
parser (Type[BaseParser]): The parser class to process the
action's inputs and outputs. Defaults to :class:`JsonParser`.
enable (bool, optional): Whether the action is enabled. Defaults to ``True``.
"""
_KERNEL_CLIENTS = {}
def __init__(self,
timeout: int = 20,
user_data_dir: str = 'ENV',
work_dir='./work_dir/tmp_dir',
description: Optional[dict] = None,
parser: Type[BaseParser] = JsonParser,
enable: bool = True):
super().__init__(description, parser, enable)
self.timeout = timeout
if user_data_dir == 'ENV':
user_data_dir = os.environ.get('USER_DATA_DIR', '')
if user_data_dir:
user_data_dir = os.path.dirname(user_data_dir)
user_data_dir = f"import os\nos.chdir('{user_data_dir}')"
self.user_data_dir = user_data_dir
self._initialized = False
self.work_dir = work_dir
if not os.path.exists(self.work_dir):
os.makedirs(self.work_dir, exist_ok=True)
@staticmethod
def start_kernel():
# start the kernel and manager
km = KernelManager()
km.start_kernel()
kc = km.client()
return km, kc
def initialize(self):
if self._initialized:
return
pid = os.getpid()
if pid not in self._KERNEL_CLIENTS:
self._KERNEL_CLIENTS[pid] = self.start_kernel()
self.kernel_manager, self.kernel_client = self._KERNEL_CLIENTS[pid]
self._initialized = True
self._call(START_CODE.format(self.user_data_dir), None)
def reset(self):
if not self._initialized:
self.initialize()
else:
code = "get_ipython().run_line_magic('reset', '-f')\n" + \
START_CODE.format(self.user_data_dir)
self._call(code, None)
def _call(self,
command: str,
timeout: Optional[int] = None) -> Tuple[str, bool]:
self.initialize()
command = extract_code(command)
# check previous remaining result
while True:
try:
msg = self.kernel_client.get_iopub_msg(timeout=5)
msg_type = msg['msg_type']
if msg_type == 'status':
if msg['content'].get('execution_state') == 'idle':
break
except queue.Empty:
# assume no result
break
self.kernel_client.execute(command)
def _inner_call():
result = ''
images = []
succeed = True
image_idx = 0
while True:
text = ''
image = ''
finished = False
msg_type = 'error'
try:
msg = self.kernel_client.get_iopub_msg(timeout=20)
msg_type = msg['msg_type']
if msg_type == 'status':
if msg['content'].get('execution_state') == 'idle':
finished = True
elif msg_type == 'execute_result':
text = msg['content']['data'].get('text/plain', '')
if 'image/png' in msg['content']['data']:
image_b64 = msg['content']['data']['image/png']
image_url = publish_image_to_local(
image_b64, self.work_dir)
image_idx += 1
image = '![fig-%03d](%s)' % (image_idx, image_url)
elif msg_type == 'display_data':
if 'image/png' in msg['content']['data']:
image_b64 = msg['content']['data']['image/png']
image_url = publish_image_to_local(
image_b64, self.work_dir)
image_idx += 1
image = '![fig-%03d](%s)' % (image_idx, image_url)
else:
text = msg['content']['data'].get('text/plain', '')
elif msg_type == 'stream':
msg_type = msg['content']['name'] # stdout, stderr
text = msg['content']['text']
elif msg_type == 'error':
succeed = False
text = escape_ansi('\n'.join(
msg['content']['traceback']))
if 'M6_CODE_INTERPRETER_TIMEOUT' in text:
text = f'Timeout. No response after {timeout} seconds.' # noqa
except queue.Empty:
# stop current task in case break next input.
self.kernel_manager.interrupt_kernel()
succeed = False
text = f'Timeout. No response after {timeout} seconds.'
finished = True
except Exception:
succeed = False
msg = ''.join(traceback.format_exception(*sys.exc_info()))
# text = 'The code interpreter encountered an unexpected error.' # noqa
text = msg
logging.warning(msg)
finished = True
if text:
# result += f'\n\n{msg_type}:\n\n```\n{text}\n```'
result += f'{text}'
if image:
images.append(image_url)
if finished:
return succeed, dict(text=result, image=images)
try:
if timeout:
def handler(signum, frame):
raise TimeoutError()
signal.signal(signal.SIGALRM, handler)
signal.alarm(timeout)
succeed, result = _inner_call()
except TimeoutError:
succeed = False
text = 'The code interpreter encountered an unexpected error.'
result = f'\n\nerror:\n\n```\n{text}\n```'
finally:
if timeout:
signal.alarm(0)
# result = result.strip('\n')
return succeed, result
@tool_api
def run(self, command: str, timeout: Optional[int] = None) -> ActionReturn:
"""When you send a message containing Python code to python, it will be \
executed in a stateful Jupyter notebook environment. python will respond with \
the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' \
can be used to save and persist user files. Internet access for this session is \
disabled. Do not make external web requests or API calls as they will fail.
Args:
command (:class:`str`): Python code
timeout (:class:`Optional[int]`): Upper bound of waiting time for Python script execution.
"""
tool_return = ActionReturn(url=None, args=None, type=self.name)
tool_return.args = dict(text=command)
succeed, result = self._call(command, timeout)
if succeed:
text = result['text']
image = result.get('image', [])
resp = [dict(type="text", content=text)]
if image:
resp.extend([dict(type="image", content=im) for im in image])
tool_return.result = resp
# tool_return.result = dict(
# text=result['text'], image=result.get('image', [])[0])
tool_return.state = ActionStatusCode.SUCCESS
else:
tool_return.errmsg = result.get("text", "") if isinstance(
result, dict) else result
tool_return.state = ActionStatusCode.API_ERROR
return tool_return
def extract_code(text):
# Match triple backtick blocks first
triple_match = re.search(r'```[^\n]*\n(.+?)```', text, re.DOTALL)
# Match single backtick blocks second
single_match = re.search(r'`([^`]*)`', text, re.DOTALL)
if triple_match:
text = triple_match.group(1)
elif single_match:
text = single_match.group(1)
else:
try:
text = json5.loads(text)['code']
except Exception:
pass
# If no code blocks found, return original text
return text
def escape_ansi(line):
ansi_escape = re.compile(r'(?:\x1B[@-_]|[\x80-\x9F])[0-?]*[ -/]*[@-~]')
return ansi_escape.sub('', line)
def publish_image_to_local(image_base64: str, work_dir='./work_dir/tmp_dir'):
image_file = str(uuid.uuid4()) + '.png'
local_image_file = os.path.join(work_dir, image_file)
png_bytes = base64.b64decode(image_base64)
assert isinstance(png_bytes, bytes)
bytes_io = io.BytesIO(png_bytes)
PIL.Image.open(bytes_io).save(local_image_file, 'png')
return local_image_file
# local test for code interpreter
def get_multiline_input(hint):
print(hint)
print('// Press ENTER to make a new line. Press CTRL-D to end input.')
lines = []
while True:
try:
line = input()
except EOFError: # CTRL-D
break
lines.append(line)
print('// Input received.')
if lines:
return '\n'.join(lines)
else:
return ''
if __name__ == '__main__':
code_interpreter = IPythonInterpreter()
while True:
print(code_interpreter(get_multiline_input('Enter python code:')))
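
A statefulness sketch: variables persist across calls because execution happens in a single Jupyter kernel per process.

from lagent.actions import IPythonInterpreter

py = IPythonInterpreter()
py('{"command": "x = 21"}')  # state is kept in the kernel
ret = py('{"command": "print(x * 2)"}')
print(ret.result)  # text output containing 42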


@@ -1,56 +0,0 @@
from typing import Optional, Union
from lagent.llms.base_api import BaseAPIModel
from lagent.llms.base_llm import BaseModel
from lagent.schema import ActionReturn, ActionStatusCode
from .base_action import BaseAction
DEFAULT_DESCRIPTION = """一个像你一样的大语言预训练模型,当你需要获得一些常识或简单世界知识时可以问它。
当你很有把握自己直接解决问题时可以优先使用它。输入应该是一个询问语句, 且每个问题尽可能简单。
"""
class LLMQA(BaseAction):
"""An LLM Wrapper as BaseAction type.
Args:
llm (BaseModel or BaseAPIModel): a LLM service which can chat.
description (str): The description of the action. Defaults to
None.
name (str, optional): The name of the action. If None, the name will
be class name. Defaults to None.
enable (bool, optional): Whether the action is enabled. Defaults to
True.
disable_description (str, optional): The description of the action when
it is disabled. Defaults to None.
"""
def __init__(self,
llm: Union[BaseModel, BaseAPIModel],
description: str = DEFAULT_DESCRIPTION,
name: Optional[str] = None,
enable: bool = True,
disable_description: Optional[str] = None) -> None:
super().__init__(description, name, enable, disable_description)
self._llm = llm
def __call__(self, query: str) -> ActionReturn:
"""Return the QA response.
Args:
query (str): The query content.
Returns:
ActionReturn: The action return.
"""
tool_return = ActionReturn(url=None, args=None)
try:
response = self._llm.generate_from_template(query, 512)
tool_return.result = dict(text=str(response))
tool_return.state = ActionStatusCode.SUCCESS
except Exception as e:
tool_return.result = dict(text=str(e))
tool_return.state = ActionStatusCode.API_ERROR
return tool_return

lagent/actions/parser.py

@@ -0,0 +1,139 @@
import json
import re
from ast import literal_eval
from typing import Any, List, Union
class ParseError(Exception):
"""Parsing exception class"""
def __init__(self, err_msg: str):
self.err_msg = err_msg
class BaseParser:
"""Base parser to process inputs and outputs of actions.
Args:
action (:class:`BaseAction`): action to validate
Attributes:
PARAMETER_DESCRIPTION (:class:`str`): declare the input format which
LLMs should follow when generating arguments for decided tools.
"""
PARAMETER_DESCRIPTION: str = ''
def __init__(self, action):
self.action = action
self._api2param = {}
self._api2required = {}
# perform basic argument validation
if action.description:
for api in action.description.get('api_list',
[action.description]):
name = (f'{action.name}.{api["name"]}'
if self.action.is_toolkit else api['name'])
required_parameters = set(api['required'])
all_parameters = {j['name'] for j in api['parameters']}
if not required_parameters.issubset(all_parameters):
raise ValueError(
f'unknown parameters for function "{name}": '
f'{required_parameters - all_parameters}')
if self.PARAMETER_DESCRIPTION:
api['parameter_description'] = self.PARAMETER_DESCRIPTION
api_name = api['name'] if self.action.is_toolkit else 'run'
self._api2param[api_name] = api['parameters']
self._api2required[api_name] = api['required']
def parse_inputs(self, inputs: str, name: str = 'run') -> dict:
"""parse inputs LLMs generate for the action
Args:
inputs (:class:`str`): input string extracted from responses
Returns:
:class:`dict`: processed input
"""
inputs = {self._api2param[name][0]['name']: inputs}
return inputs
def parse_outputs(self, outputs: Any) -> List[dict]:
"""parser outputs returned by the action
Args:
outputs (:class:`Any`): raw output of the action
Returns:
:class:`List[dict]`: processed output of which each member is a
dictionary with two keys - 'type' and 'content'.
"""
if isinstance(outputs, dict):
outputs = json.dumps(outputs, ensure_ascii=False)
elif not isinstance(outputs, str):
outputs = str(outputs)
return [{'type': 'text', 'content': outputs}]
class JsonParser(BaseParser):
"""Json parser to convert input string into a dictionary.
Args:
action (:class:`BaseAction`): action to validate
"""
PARAMETER_DESCRIPTION = '如果调用该工具你必须使用Json格式 {key: value} 传参其中key为参数名称'
def parse_inputs(self,
inputs: Union[str, dict],
name: str = 'run') -> dict:
if not isinstance(inputs, dict):
try:
match = re.search(r'^\s*(```json\n)?(.*)\n```\s*$', inputs,
re.S)
if match:
inputs = match.group(2).strip()
inputs = json.loads(inputs)
except json.JSONDecodeError as exc:
raise ParseError(f'invalid json format: {inputs}') from exc
input_keys = set(inputs)
all_keys = {param['name'] for param in self._api2param[name]}
if not input_keys.issubset(all_keys):
raise ParseError(f'unknown arguments: {input_keys - all_keys}')
required_keys = set(self._api2required[name])
if not input_keys.issuperset(required_keys):
raise ParseError(
f'missing required arguments: {required_keys - input_keys}')
return inputs
class TupleParser(BaseParser):
"""Tuple parser to convert input string into a tuple.
Args:
action (:class:`BaseAction`): action to validate
"""
PARAMETER_DESCRIPTION = '如果调用该工具你必须使用Tuple格式 (arg1, arg2, arg3) 传参,且参数是有序的'
def parse_inputs(self,
inputs: Union[str, tuple],
name: str = 'run') -> dict:
if not isinstance(inputs, tuple):
try:
inputs = literal_eval(inputs)
except Exception as exc:
raise ParseError(f'invalid tuple format: {inputs}') from exc
if len(inputs) < len(self._api2required[name]):
raise ParseError(
f'API takes {len(self._api2required[name])} required positional '
f'arguments but {len(inputs)} were given')
if len(inputs) > len(self._api2param[name]):
raise ParseError(
f'API takes {len(self._api2param[name])} positional arguments '
f'but {len(inputs)} were given')
inputs = {
self._api2param[name][i]['name']: item
for i, item in enumerate(inputs)
}
return inputs
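
A sketch contrasting the two parsers on a toy toolkit (adapted from the Calculator example in base_action.py above):

from lagent.actions.base_action import BaseAction, tool_api
from lagent.actions.parser import TupleParser

class Calculator(BaseAction):
    """Calculator"""

    @tool_api
    def add(self, a: int, b: int) -> int:
        """Add operation

        Args:
            a (int): augend
            b (int): addend

        Returns:
            int: sum
        """
        return a + b

print(Calculator()('{"a": 1, "b": 2}', 'add'))  # JsonParser (default), named args
print(Calculator(parser=TupleParser)('(1, 2)', 'add'))  # TupleParser, ordered args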

lagent/actions/ppt.py

@@ -0,0 +1,157 @@
from typing import Dict, Optional, Type
from PIL import Image
from pptx import Presentation
from lagent.actions.base_action import BaseAction, tool_api
from lagent.actions.parser import BaseParser, JsonParser
THEME_MAPPING = {
'Default': {
'template': None,
'title': 'Title Slide',
'single': 'Title and Content',
'two': 'Two Content',
}
}
class PPT(BaseAction):
"""Plugin to create ppt slides with text, paragraph, images in good looking styles"""
def __init__(self,
theme_mapping: Optional[Dict[str, dict]] = None,
description: Optional[dict] = None,
parser: Type[BaseParser] = JsonParser,
enable: bool = True):
super().__init__(description, parser, enable)
self.theme_mapping = theme_mapping or THEME_MAPPING
self.pointer = None
self.location = None
@tool_api(explode_return=True)
def create_file(self, theme: str, abs_location: str) -> dict:
"""Create a pptx file with specific themes
Args:
theme (:class:`str`): the theme used
abs_location (:class:`str`): the ppt file's absolute location
Returns:
:class:`dict`: operation status
* status: the result of the execution
"""
self.location = abs_location
try:
self.pointer = Presentation(self.theme_mapping[theme]['template'])
self.pointer.slide_master.name = theme
# print('created')
except Exception as e:
print(e)
return dict(status='created a ppt file.')
@tool_api(explode_return=True)
def add_first_page(self, title: str, subtitle: str) -> dict:
"""Add the first page of ppt.
Args:
title (:class:`str`): the title of ppt
subtitle (:class:`str`): the subtitle of ppt
Returns:
:class:`dict`: operation status
* status: the result of the execution
"""
layout_name = self.theme_mapping[
self.pointer.slide_master.name]['title']
layout = next(i for i in self.pointer.slide_master.slide_layouts
if i.name == layout_name)
slide = self.pointer.slides.add_slide(layout)
ph_title, ph_subtitle = slide.placeholders
ph_title.text = title
if subtitle:
ph_subtitle.text = subtitle
return dict(status='added page')
@tool_api(explode_return=True)
def add_text_page(self, title: str, bullet_items: str) -> dict:
"""Add text page of ppt
Args:
title (:class:`str`): the title of the page
bullet_items (:class:`str`): bullet_items should be string, for multiple bullet items, please use [SPAN] to separate them.
Returns:
:class:`dict`: operation status
* status: the result of the execution
"""
layout_name = self.theme_mapping[
self.pointer.slide_master.name]['single']
layout = next(i for i in self.pointer.slide_master.slide_layouts
if i.name == layout_name)
slide = self.pointer.slides.add_slide(layout)
ph_title, ph_body = slide.placeholders
ph_title.text = title
ph = ph_body
tf = ph.text_frame
for i, item in enumerate(bullet_items.split('[SPAN]')):
if i == 0:
p = tf.paragraphs[0]
else:
p = tf.add_paragraph()
p.text = item.strip()
p.level = 0
return dict(status='added page')
@tool_api(explode_return=True)
def add_text_image_page(self, title: str, bullet_items: str,
image: str) -> dict:
"""Add a text page with one image. Image should be a path
Args:
title (:class:`str`): the title of the page
bullet_items (:class:`str`): bullet_items should be string, for multiple bullet items, please use [SPAN] to separate them.
image (:class:`str`): the path of the image
Returns:
:class:`dict`: operation status
* status: the result of the execution
"""
layout_name = self.theme_mapping[self.pointer.slide_master.name]['two']
layout = next(i for i in self.pointer.slide_master.slide_layouts
if i.name == layout_name)
slide = self.pointer.slides.add_slide(layout)
ph_title, ph_body1, ph_body2 = slide.placeholders
ph_title.text = title
ph = ph_body2
image_pil = Image.open(image)  # `image` is a file path per the docstring
left = ph.left
width = ph.width
height = int(width / image_pil.width * image_pil.height)
top = (ph.top + (ph.top + ph.height)) // 2 - height // 2
slide.shapes.add_picture(image, left, top, width, height)
ph = ph_body1
tf = ph.text_frame
for i, item in enumerate(bullet_items.split('[SPAN]')):
if i == 0:
p = tf.paragraphs[0]
else:
p = tf.add_paragraph()
p.text = item.strip()
p.level = 0
return dict(status='added page')
@tool_api(explode_return=True)
def submit_file(self) -> dict:
"""When all steps done, YOU MUST use submit_file() to submit your work.
Returns:
:class:`dict`: operation status
* status: the result of the execution
"""
# file_path = os.path.join(self.CACHE_DIR, f'{self._return_timestamp()}.pptx')
# self.pointer.save(file_path)
# retrieval_url = upload_file(file_path)
self.pointer.save(self.location)
return dict(status=f'submitted. view ppt at {self.location}')
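
A typical call-sequence sketch (the tool is stateful: create a file, add pages, then submit; paths and text are illustrative):

from lagent.actions import PPT

ppt = PPT()
ppt('{"theme": "Default", "abs_location": "/tmp/demo.pptx"}', 'create_file')
ppt('{"title": "Lagent", "subtitle": "Refactor overview"}', 'add_first_page')
ppt('{"title": "Tools", "bullet_items": "actions[SPAN]parsers"}', 'add_text_page')
ppt('{}', 'submit_file')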


@@ -1,11 +1,12 @@
import copy
import io
from contextlib import redirect_stdout
from typing import Any, Optional
from typing import Any, Optional, Type
from func_timeout import FunctionTimedOut, func_set_timeout
from lagent.actions.base_action import BaseAction
from lagent.actions.base_action import BaseAction, tool_api
from lagent.actions.parser import BaseParser, JsonParser
from lagent.schema import ActionReturn, ActionStatusCode
@@ -29,72 +30,70 @@ class GenericRuntime:
return eval(expr, self._global_vars)
DEFAULT_DESCRIPTION = """用来执行Python代码。代码必须是一个函数
函数名必须得是 'solution',代码对应你的思考过程。代码实例格式如下:
```python
# import 依赖包
import xxx
def solution():
# 初始化一些变量
variable_names_with_real_meaning = xxx
# 步骤一
mid_variable = func(variable_names_with_real_meaning)
# 步骤 x
mid_variable = func(mid_variable)
# 最后结果
final_answer = func(mid_variable)
return final_answer
```"""
class PythonInterpreter(BaseAction):
"""A Python executor that can execute Python scripts.
Args:
description (str): The description of the action. Defaults to
DEFAULT_DESCRIPTION.
answer_symbol (str, Optional): the answer symbol from LLM
answer_symbol (str, Optional): the answer symbol from LLM. Defaults to ``None``.
answer_expr (str, Optional): the answer function name of the Python
script. Default to 'solution()'.
answer_from_stdout (boolean): whether the execution results is from
stdout.
name (str, optional): The name of the action. If None, the name will
be class nameDefaults to None.
script. Defaults to ``'solution()'``.
answer_from_stdout (boolean, Optional): whether the execution results is from
stdout. Defaults to ``False``.
timeout (int, Optional): Upper bound of waiting time for Python script execution.
Defaults to ``20``.
description (dict, Optional): The description of the action. Defaults to ``None``.
parser (Type[BaseParser]): The parser class to process the
action's inputs and outputs. Defaults to :class:`JsonParser`.
enable (bool, optional): Whether the action is enabled. Defaults to
True.
disable_description (str, optional): The description of the action when
it is disabled. Defaults to None.
timeout (int): Upper bound of waiting time for Python script execution.
``True``.
"""
def __init__(self,
description: str = DEFAULT_DESCRIPTION,
answer_symbol: Optional[str] = None,
answer_expr: Optional[str] = 'solution()',
answer_from_stdout: bool = False,
name: Optional[str] = None,
enable: bool = True,
disable_description: Optional[str] = None,
timeout: int = 20) -> None:
super().__init__(description, name, enable, disable_description)
timeout: int = 20,
description: Optional[dict] = None,
parser: Type[BaseParser] = JsonParser,
enable: bool = True) -> None:
super().__init__(description, parser, enable)
self.answer_symbol = answer_symbol
self.answer_expr = answer_expr
self.answer_from_stdout = answer_from_stdout
self.timeout = timeout
def __call__(self, command: str) -> ActionReturn:
@tool_api
def run(self, command: str) -> ActionReturn:
"""用来执行Python代码。代码必须是一个函数函数名必须得是 'solution',代码对应你的思考过程。代码实例格式如下:
```python
# import 依赖包
import xxx
def solution():
# 初始化一些变量
variable_names_with_real_meaning = xxx
# 步骤一
mid_variable = func(variable_names_with_real_meaning)
# 步骤 x
mid_variable = func(mid_variable)
# 最后结果
final_answer = func(mid_variable)
return final_answer
```
Args:
command (:class:`str`): Python code snippet
"""
self.runtime = GenericRuntime()
try:
tool_return = func_set_timeout(self.timeout)(self._call)(command)
except FunctionTimedOut as e:
tool_return = ActionReturn(url=None, args=None, type=self.name)
tool_return = ActionReturn(type=self.name)
tool_return.errmsg = repr(e)
tool_return.state = ActionStatusCode.API_ERROR
return tool_return
def _call(self, command: str) -> ActionReturn:
tool_return = ActionReturn(url=None, args=None, type=self.name)
tool_return = ActionReturn(type=self.name)
try:
if '```python' in command:
command = command.split('```python')[1].split('```')[0]
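A hedged usage sketch of the refactored interface, assuming the module path `lagent.actions.python_interpreter` and that `BaseAction.__call__` routes a JSON-encoded input through `JsonParser` to the `tool_api`-annotated `run`:

```python
import json

from lagent.actions.python_interpreter import PythonInterpreter

interpreter = PythonInterpreter(timeout=5)
snippet = "```python\ndef solution():\n    return 21 * 2\n```"
# JsonParser decodes the argument string into {'command': ...} for `run`
ret = interpreter(json.dumps({'command': snippet}))
print(ret.state, ret.result)
```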

View File

@@ -1,6 +1,5 @@
# flake8: noqa
import ast
import copy
import platform
from typing import Dict, List, Optional, Tuple, Union
@@ -220,21 +219,20 @@ class AutoGPTProtocol:
dict(role='user', content=self.triggering_prompt))
return formatted_data
def format_response(self, action_return):
def format_response(self, action_return) -> dict:
"""format the final response at current step.
Args:
action_return (ActionReturn): return value of the current action.
Returns:
str: the final response at current step.
dict: the final response at current step.
"""
if action_return.state == ActionStatusCode.SUCCESS:
response = action_return.result['text']
response = f'Command {action_return.type} returned: {response}'
response = f'Command {action_return.type} returned: {action_return.format_result()}'
else:
response = action_return.errmsg
return response
return dict(role='system', content=response)
class AutoGPT(BaseAgent):
@@ -261,30 +259,26 @@ class AutoGPT(BaseAgent):
super().__init__(
llm=llm, action_executor=action_executor, protocol=protocol)
def chat(self, goal: str) -> AgentReturn:
self._inner_history = []
def chat(self, goal: str, **kwargs) -> AgentReturn:
inner_history = []
agent_return = AgentReturn()
default_response = 'Sorry that I cannot answer your question.'
for _ in range(self.max_turn):
prompt = self._protocol.format(
goal=goal,
inner_history=self._inner_history,
inner_history=inner_history,
action_executor=self._action_executor)
response = self._llm.generate_from_template(prompt, 512)
self._inner_history.append(
dict(role='assistant', content=response))
response = self._llm.chat(prompt, **kwargs)
inner_history.append(dict(role='assistant', content=response))
action, action_input = self._protocol.parse(
response, self._action_executor)
action_return: ActionReturn = self._action_executor(
action, action_input)
agent_return.actions.append(action_return)
if action_return.type == self._action_executor.finish_action.name:
agent_return.response = action_return.result['text']
agent_return.response = action_return.format_result()
return agent_return
self._inner_history.append(
dict(
role='system',
content=self._protocol.format_response(action_return)))
agent_return.inner_steps = copy.deepcopy(self._inner_history)
inner_history.append(self._protocol.format_response(action_return))
agent_return.inner_steps = inner_history
agent_return.response = default_response
return agent_return
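With the refactor, `chat` keeps its history in a local `inner_history` and forwards `**kwargs` to `llm.chat`; a hedged end-to-end sketch (the LLM and tool choices here are placeholders):

```python
from lagent.actions import ActionExecutor, PythonInterpreter
from lagent.agents.autogpt import AutoGPT
from lagent.llms import GPTAPI

agent = AutoGPT(
    llm=GPTAPI(model_type='gpt-3.5-turbo'),
    action_executor=ActionExecutor(actions=[PythonInterpreter()]),
)
# extra kwargs such as `max_out_len` flow through to the LLM backend
print(agent.chat('Compute 3**7 and report the result.', max_out_len=512).response)
```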

View File

@@ -1,5 +1,3 @@
from typing import List
from lagent.actions import ActionExecutor
from lagent.actions.base_action import BaseAction
from lagent.llms.base_llm import BaseModel
@@ -19,8 +17,6 @@ class BaseAgent:
def __init__(self, llm: BaseModel, action_executor: ActionExecutor,
protocol: object) -> None:
self._session_history = []
self._llm = llm
self._action_executor = action_executor
self._protocol = protocol
@@ -41,9 +37,5 @@ class BaseAgent:
"""
self._action_executor.del_action(name)
def chat(self, message: str) -> AgentReturn:
def chat(self, message: str, **kwargs) -> AgentReturn:
raise NotImplementedError
@property
def session_history(self) -> List:
return self._session_history

View File

@@ -0,0 +1,363 @@
import json
import logging
from copy import deepcopy
from typing import Dict, List, Union, Optional
from lagent.schema import AgentReturn, AgentStatusCode
from lagent import BaseAgent
from lagent.actions import ActionExecutor
from lagent.llms import BaseAPIModel, BaseModel
from lagent.schema import ActionReturn, ActionStatusCode
API_PREFIX = (
"This is the subfunction for tool '{tool_name}', you can use this tool. "
'The description of this function is: \n{description}')
META_INS = ('You are InternLM, a large language model trained by PJLab. '
'Answer as concisely as possible. '
'当开启工具以及代码时,根据需求选择合适的工具进行调用')
INTERPRETER_CN = ('你现在可以通过如下格式向 Jupyter Notebook 发送并执行代码:'
'\n<|action_start|><|interpreter|>```python\n\n代码\n\n```\n'
'\n当遇到以下问题时,请使用上述格式调用 Jupyter Notebook 去解决,并根据执行结果做出友好的回复:\n'
'1. 文件操作和数据导入比如处理CSV、JSON等格式文件\n'
'2. 数据分析或处理,比如数据操作或图像绘制如折线图、柱状图等\n'
'3. 数学相关的问题。当遇到数学问题时,你需要分析题目,并给出代码去解决这个题目')
PLUGIN_CN = (
'你可以使用如下工具:'
'\n{prompt}\n'
'当你需要使用工具时,你可以使用如下格式:\n'
'<|action_start|><|plugin|>{{"name": "工具名称", "parameters": {{参数}}}}\n'
'如果你已经获得足够信息,请直接给出答案. 避免不必要的工具调用! '
'同时注意你可以使用的工具,不要随意捏造!')
class Internlm2Protocol:
def __init__(
self,
meta_prompt: str=META_INS,
interpreter_prompt: str=INTERPRETER_CN,
plugin_prompt: str=PLUGIN_CN,
few_shot: Optional[List]=None,
language: Dict=dict(
begin='',
end='',
belong='assistant',
),
tool: Dict=dict(
begin='{start_token}{name}\n',
start_token='<|action_start|>',
name_map=dict(plugin='<|plugin|>', interpreter='<|interpreter|>'),
belong='assistant',
end='<|action_end|>\n',
),
execute: Dict = dict(
role='execute', begin='', end='', fallback_role='environment'),
) -> None:
self.meta_prompt = meta_prompt
self.interpreter_prompt = interpreter_prompt
self.plugin_prompt = plugin_prompt
self.roles_cfg = dict(tool=tool, language=language)
self.language = language
self.execute = execute
self.tool = tool
self.few_shot = few_shot
def format_sub_role(self, messages: List[Dict]) -> List[Dict]:
def format_interpreter(message):
if isinstance(message['content'], dict):
# assert message['content']['name'] == 'IPythonInterpreter'
return dict(
role=message['role'],
name=message['name'],
content=message['content']['parameters']['command'])
else:
return message
def format_plugin(message):
if isinstance(message['content'], dict):
return dict(
role=message['role'],
name=message['name'],
content=json.dumps(message['content']))
else:
return message
new_message = list()
for message in messages:
if message['role'] in [
'assistant', 'user', 'system', 'environment'
]:
new_message.append(message)
continue
role_cfg = self.roles_cfg[message['role']]
begin = role_cfg['begin']
if message['role'] == 'tool':
if message['name'] == 'interpreter':
message = format_interpreter(message)
elif message['name'] == 'plugin':
message = format_plugin(message)
else:
raise NotImplementedError
begin = role_cfg['begin'].format(
start_token=role_cfg.get('start_token', ''),
name=role_cfg.get('name_map', {}).get(message['name'], ''))
new_content = begin + message['content'] + role_cfg['end']
if role_cfg.get('fallback_role'):
new_message.append(
dict(role=role_cfg['fallback_role'], content=new_content))
elif role_cfg.get('belong'):
if new_message[-1]['role'] != role_cfg.get('belong'):
new_message.append(
dict(role=role_cfg.get('belong'), content=new_content))
else:
new_message[-1]['content'] += new_content
else:
new_message.append(
dict(role=message['role'], content=new_content))
return new_message
def format(self,
inner_step: List[Dict],
plugin_executor: ActionExecutor = None,
interpreter_executor: ActionExecutor = None,
**kwargs) -> list:
formatted = []
if self.meta_prompt:
formatted.append(dict(role='system', content=self.meta_prompt))
if interpreter_executor and self.interpreter_prompt:
interpreter_info = interpreter_executor.get_actions_info()[0]
interpreter_prompt = self.interpreter_prompt.format(
code_prompt=interpreter_info['description'])
formatted.append(
dict(
role='system',
content=interpreter_prompt,
name='interpreter'))
if plugin_executor and plugin_executor.actions and self.plugin_prompt:
plugin_descriptions = []
for api_info in plugin_executor.get_actions_info():
plugin = deepcopy(api_info)
if isinstance(api_info, dict):
tool_name = api_info['name'].split('.')[0]
plugin['description'] = API_PREFIX.format(
tool_name=tool_name, description=plugin['description'])
plugin_descriptions.append(plugin)
plugin_prompt = self.plugin_prompt.format(
prompt=json.dumps(
plugin_descriptions, ensure_ascii=False, indent=4))
formatted.append(
dict(role='system', content=plugin_prompt, name='plugin'))
if self.few_shot:
for few_shot in self.few_shot:
formatted += self.format_sub_role(few_shot)
formatted += self.format_sub_role(inner_step)
return formatted
def parse(self, message, plugin_executor: ActionExecutor,
interpreter_executor: ActionExecutor):
if self.language['begin']:
message = message.split(self.language['begin'])[-1]
if self.tool['name_map']['plugin'] in message:
message, action = message.split(
f"{self.tool['start_token']}{self.tool['name_map']['plugin']}")
action = action.split(self.tool['end'].strip())[0]
return 'plugin', message, action
if self.tool['name_map']['interpreter'] in message:
message, code = message.split(
f"{self.tool['start_token']}"
f"{self.tool['name_map']['interpreter']}")
code = code.split(self.tool['end'].strip())[0].strip()
return 'interpreter', message, dict(
name=interpreter_executor.action_names()[0],
parameters=dict(command=code))
return None, message, None
def format_response(self, action_return, name) -> dict:
if action_return.state == ActionStatusCode.SUCCESS:
response = action_return.format_result()
else:
response = action_return.errmsg
content = self.execute['begin'] + response + self.execute['end']
if self.execute.get('fallback_role'):
return dict(
role=self.execute['fallback_role'], content=content, name=name)
elif self.execute.get('belong'):
return dict(
role=self.execute['belong'], content=content, name=name)
return dict(role=self.execute['role'], content=content, name=name)
class Internlm2Agent(BaseAgent):
def __init__(self,
llm: Union[BaseModel, BaseAPIModel],
plugin_executor: ActionExecutor = None,
interpreter_executor: ActionExecutor = None,
protocol=Internlm2Protocol(),
max_turn: int = 3) -> None:
self.max_turn = max_turn
self._interpreter_executor = interpreter_executor
super().__init__(
llm=llm, action_executor=plugin_executor, protocol=protocol)
def chat(self, message: Union[str, Dict], **kwargs) -> AgentReturn:
if isinstance(message, str):
message = dict(role='user', content=message)
if isinstance(message, dict):
message = [message]
inner_history = message[:]
offset = len(inner_history)
agent_return = AgentReturn()
for _ in range(self.max_turn):
# list of dict
prompt = self._protocol.format(
inner_step=inner_history,
plugin_executor=self._action_executor,
interpreter_executor=self._interpreter_executor,
)
response = self._llm.chat(prompt, **kwargs)
name, language, action = self._protocol.parse(
message=response,
plugin_executor=self._action_executor,
interpreter_executor=self._interpreter_executor,
)
if name:
if name == 'plugin':
if self._action_executor:
executor = self._action_executor
else:
logging.info(msg='No plugin is instantiated!')
continue
try:
action = json.loads(action)
except Exception as e:
logging.info(
msg=f'Invalid action {e}')
continue
elif name == 'interpreter':
if self._interpreter_executor:
executor = self._interpreter_executor
else:
logging.info(msg='No interpreter is instantiated!')
continue
else:
logging.info(
msg=(f"Invalid name '{name}'. Currently only 'plugin' "
"and 'interpreter' are supported."))
continue
action_return: ActionReturn = executor(action['name'],
action['parameters'])
action_return.thought = language
agent_return.actions.append(action_return)
inner_history.append(dict(role='language', content=language))
if not name or action_return.type == executor.finish_action.name:
agent_return.response = language
agent_return.state = AgentStatusCode.END
break
else:
inner_history.append(
dict(role='tool', content=action, name=name))
inner_history.append(
self._protocol.format_response(action_return, name=name))
yield agent_return
agent_return.inner_steps = inner_history[offset:]
yield agent_return
def stream_chat(self, message: List[dict], **kwargs) -> AgentReturn:
if isinstance(message, str):
message = dict(role='user', content=message)
if isinstance(message, dict):
message = [message]
inner_history = message[:]
offset = len(inner_history)
agent_return = AgentReturn()
last_agent_state = AgentStatusCode.SESSION_READY
for _ in range(self.max_turn):
# list of dict
prompt = self._protocol.format(
inner_step=inner_history,
plugin_executor=self._action_executor,
interpreter_executor=self._interpreter_executor,
)
response = ''
for model_state, res, _ in self._llm.stream_chat(
prompt, **kwargs):
response = res
if model_state.value < 0:
agent_return.state = model_state
yield deepcopy(agent_return)
return
else:
name, language, action = self._protocol.parse(
message=response,
plugin_executor=self._action_executor,
interpreter_executor=self._interpreter_executor,
)
if name:
if model_state == AgentStatusCode.END:
agent_state = last_agent_state + 1
if name == 'plugin':
if self._action_executor:
executor = self._action_executor
else:
logging.info(
msg='No plugin is instantiated!')
continue
try:
action = json.loads(action)
except Exception as e:
logging.info(
msg=f'Invalid action {e}')
continue
elif name == 'interpreter':
if self._interpreter_executor:
executor = self._interpreter_executor
else:
logging.info(
msg='No interpreter is instantiated!')
continue
agent_return.state = agent_state
agent_return.response = action
else:
agent_state = (
AgentStatusCode.PLUGIN_START if name
== 'plugin' else AgentStatusCode.CODING)
if agent_state != last_agent_state:
# agent_return.state = agent_state
agent_return.response = language
yield deepcopy(agent_return)
agent_return.state = agent_state
agent_return.response = action
else:
agent_state = AgentStatusCode.STREAM_ING
agent_return.state = agent_state
agent_return.response = language
last_agent_state = agent_state
yield deepcopy(agent_return)
if name:
action_return: ActionReturn = executor(action['name'],
action['parameters'])
action_return.thought = language
agent_return.actions.append(action_return)
inner_history.append(dict(role='language', content=language))
if not name or action_return.type == executor.finish_action.name:
agent_return.response = language
agent_return.state = AgentStatusCode.END
break
else:
inner_history.append(
dict(role='tool', content=action, name=name))
inner_history.append(
self._protocol.format_response(action_return, name=name))
agent_state += 1
agent_return.state = agent_state
yield agent_return
agent_return.inner_steps = deepcopy(inner_history[offset:])
agent_return.state = AgentStatusCode.END
yield agent_return
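A hedged consumption sketch for the streaming interface; the served model name, URL, and import paths are assumptions. `stream_chat` yields intermediate `AgentReturn` snapshots whose `state` follows `AgentStatusCode`:

```python
from lagent.agents.internlm2_agent import Internlm2Agent
from lagent.llms.lmdeploy_wrapper import LMDeployClient

agent = Internlm2Agent(
    llm=LMDeployClient(path='internlm2-chat-7b', url='http://0.0.0.0:23333'),
)
for snapshot in agent.stream_chat([dict(role='user', content='hi')]):
    print(snapshot.state, snapshot.response)
```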

View File

@@ -1,4 +1,3 @@
import copy
from typing import Dict, List, Tuple, Union
from lagent.actions import ActionExecutor
@@ -43,7 +42,7 @@ To use a tool, please use the following format:
The response after utilizing tools should use the following format:
```
{response}the results after call the tool.
``
```
If you already know the answer, or you do not need to use tools,
please use the following format to reply:
```
@@ -170,20 +169,22 @@ class ReActProtocol:
action_input = arg_match[-1]
return thought, action.strip(), action_input.strip().strip('"')
def format_response(self, action_return: ActionReturn) -> str:
def format_response(self, action_return: ActionReturn) -> dict:
"""format the final response at current step.
Args:
action_return (ActionReturn): return value of the current action.
Returns:
str: the final response at current step.
dict: the final response at current step.
"""
if action_return.state == ActionStatusCode.SUCCESS:
response = action_return.result['text']
response = action_return.format_result()
else:
response = action_return.errmsg
return self.response['begin'] + response + self.response['end']
return dict(
role='system',
content=self.response['begin'] + response + self.response['end'])
class ReAct(BaseAgent):
@@ -210,20 +211,27 @@ class ReAct(BaseAgent):
super().__init__(
llm=llm, action_executor=action_executor, protocol=protocol)
def chat(self, message: str) -> AgentReturn:
self._inner_history = []
self._inner_history.append(dict(role='user', content=message))
def chat(self, message: Union[str, dict, List[dict]],
**kwargs) -> AgentReturn:
if isinstance(message, str):
inner_history = [dict(role='user', content=message)]
elif isinstance(message, dict):
inner_history = [message]
elif isinstance(message, list):
inner_history = message[:]
else:
raise TypeError(f'unsupported type: {type(message)}')
offset = len(inner_history)
agent_return = AgentReturn()
default_response = 'Sorry that I cannot answer your question.'
for turn in range(self.max_turn):
prompt = self._protocol.format(
chat_history=self.session_history,
inner_step=self._inner_history,
chat_history=[],
inner_step=inner_history,
action_executor=self._action_executor,
force_stop=(turn == self.max_turn - 1))
response = self._llm.generate_from_template(prompt, 512)
self._inner_history.append(
dict(role='assistant', content=response))
response = self._llm.chat(prompt, **kwargs)
inner_history.append(dict(role='assistant', content=response))
thought, action, action_input = self._protocol.parse(
response, self._action_executor)
action_return: ActionReturn = self._action_executor(
@@ -231,17 +239,10 @@ class ReAct(BaseAgent):
action_return.thought = thought
agent_return.actions.append(action_return)
if action_return.type == self._action_executor.finish_action.name:
agent_return.response = action_return.result['text']
agent_return.response = action_return.format_result()
break
self._inner_history.append(
dict(
role='system',
content=self._protocol.format_response(action_return)))
inner_history.append(self._protocol.format_response(action_return))
else:
agent_return.response = default_response
agent_return.inner_steps = copy.deepcopy(self._inner_history)
# only append the user and final response
self._session_history.append(dict(role='user', content=message))
self._session_history.append(
dict(role='assistant', content=agent_return.response))
agent_return.inner_steps = inner_history[offset:]
return agent_return
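`chat` now normalizes three input shapes into `inner_history`, so callers can resume a dialogue by passing prior messages; a hedged sketch with placeholder LLM and tools:

```python
from lagent.actions import ActionExecutor, PythonInterpreter
from lagent.agents import ReAct
from lagent.llms import GPTAPI

agent = ReAct(
    llm=GPTAPI(model_type='gpt-3.5-turbo'),
    action_executor=ActionExecutor(actions=[PythonInterpreter()]),
)
agent.chat('What is 2**10?')                             # plain string
agent.chat(dict(role='user', content='What is 2**10?'))  # single message
agent.chat([                                             # full history
    dict(role='user', content='Hi'),
    dict(role='assistant', content='Hello!'),
    dict(role='user', content='Now compute 2**10.'),
])
```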

View File

@@ -1,4 +1,3 @@
import copy
import re
import warnings
from typing import Dict, List, Optional, Tuple, Union
@@ -192,7 +191,7 @@ class ReWOOProtocol:
worker_log = ''
for thought, action_return in zip(thought_list, action_return_list):
if action_return.state == ActionStatusCode.SUCCESS:
action_resp = action_return.result['text']
action_resp = action_return.format_result()
else:
action_resp = action_return.errmsg
worker_response = self.worker_prompt.format(
@@ -227,9 +226,17 @@ class ReWOO(BaseAgent):
self.max_turn = max_turn
def chat(self, message: str) -> AgentReturn:
self._inner_history = []
self._inner_history.append(dict(role='user', content=message))
def chat(self, message: Union[str, dict, List[dict]],
**kwargs) -> AgentReturn:
if isinstance(message, str):
inner_history = [dict(role='user', content=message)]
elif isinstance(message, dict):
inner_history = [message]
elif isinstance(message, list):
inner_history = message[:]
else:
raise TypeError(f'unsupported type: {type(message)}')
offset = len(inner_history)
agent_return = AgentReturn()
# planner
@@ -237,13 +244,12 @@ class ReWOO(BaseAgent):
reformat_request = ''
while turn_id < self.max_turn:
planner_prompt = self._protocol.format_planner(
chat_history=self.session_history,
inner_step=self._inner_history,
chat_history=[],
inner_step=inner_history,
action_executor=self._action_executor,
reformat_request=reformat_request)
response = self._llm.generate_from_template(planner_prompt, 512)
self._inner_history.append(
dict(role='assistant', content=response))
response = self._llm.chat(planner_prompt, **kwargs)
inner_history.append(dict(role='assistant', content=response))
try:
thoughts, actions, actions_input = self._protocol.parse_worker(
response)
@@ -267,18 +273,17 @@ class ReWOO(BaseAgent):
for prev_ptr in prev_ptrs:
ptr_num = int(prev_ptr.strip('#E')) - 1 # start from 0
actions_input[action_id] = actions_input[action_id].replace(
prev_ptr, action_responses[ptr_num].result['text'])
prev_ptr, action_responses[ptr_num].format_result())
action_return: ActionReturn = self._action_executor(
actions[action_id], actions_input[action_id])
action_responses.append(action_return)
solver_prompt, worker_log = self._protocol.format_solver(
message, thoughts, action_responses)
self._inner_history.append(dict(role='system', content=worker_log))
inner_history.append(dict(role='system', content=worker_log))
final_response = self._llm.generate_from_template(solver_prompt, 512)
self._inner_history.append(
dict(role='assistant', content=final_response))
agent_return.inner_steps = copy.deepcopy(self._inner_history)
final_response = self._llm.chat(solver_prompt, **kwargs)
inner_history.append(dict(role='assistant', content=final_response))
agent_return.inner_steps = inner_history[offset:]
agent_return.response = final_response
return agent_return
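The `#E` substitution above rewrites evidence pointers in later action inputs with earlier results; a self-contained illustration of that mechanism:

```python
import re

# results of previously executed actions, in order
action_results = ['Python 3.0 was released in 2008.']
action_input = 'Summarize #E1 in one sentence.'

for ptr in re.findall(r'#E\d+', action_input):
    idx = int(ptr.strip('#E')) - 1  # pointers start from #E1
    action_input = action_input.replace(ptr, action_results[idx])
print(action_input)
# Summarize Python 3.0 was released in 2008. in one sentence.
```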

View File

@@ -8,7 +8,3 @@ __all__ = ['BaseModel', 'BaseAPIModel', 'GPTAPI']
if is_module_exist('transformers'):
from .huggingface import HFTransformer, HFTransformerCasualLM # noqa: F401
__all__.extend(['HFTransformer', 'HFTransformerCasualLM'])
if is_module_exist('lmdeploy'):
from .lmdeploy import TritonClient, TurboMind # noqa: F401
__all__.extend(['TritonClient', 'TurboMind'])

View File

@@ -27,7 +27,7 @@ class APITemplateParser:
'role in meta prompt must be unique!'
self.roles[item['role']] = item.copy()
def parse_template(self, dialog: List[Union[str, List]]):
def __call__(self, dialog: List[Union[str, List]]):
"""Parse the intermidate prompt template, and wrap it with meta
template if applicable. When the meta template is set and the input is
a list, the return value will be a list containing the full
@@ -155,7 +155,14 @@ class BaseAPIModel(BaseModel):
retry: int = 2,
max_seq_len: int = 2048,
template_parser: 'APITemplateParser' = APITemplateParser,
meta_template: Optional[Dict] = None):
meta_template: Optional[Dict] = None,
*,
max_out_len: int = 512,
top_p: float = 0.8,
top_k: float = None,
temperature: float = 0.8,
repetition_penalty: float = 0.0,
stop_words: Union[List[str], str] = None):
self.model_type = model_type
self.max_seq_len = max_seq_len
self.meta_template = meta_template
@@ -165,53 +172,21 @@ class BaseAPIModel(BaseModel):
if template_parser:
self.template_parser = template_parser(meta_template)
@abstractclassmethod
def generate(self, inputs, max_out_len: int) -> List[str]:
"""Generate results given a list of inputs.
self.gen_params = dict(
max_out_len=max_out_len,
top_p=top_p,
top_k=top_k,
temperature=temperature,
repetition_penalty=repetition_penalty,
stop_words=stop_words)
Args:
inputs (List[str or list]): A list of strings or PromptDicts.
The PromptDict should be organized in OpenCompass'
API format.
max_out_len (int): The maximum length of the output.
Returns:
List[str]: A list of generated strings.
"""
def get_token_len(self, prompt: str) -> int:
"""Get lengths of the tokenized string. Only English and Chinese
characters are counted for now. Users are encouraged to override this
method if more accurate length is needed.
Args:
prompt (str): Input string.
Returns:
int: Length of the input tokens
"""
english_parts = re.findall(r'[A-Za-z0-9]+', prompt)
chinese_parts = re.findall(r'[\u4e00-\u9FFF]+', prompt)
# Count English words
english_count = sum(len(part.split()) for part in english_parts)
# Count Chinese words
chinese_count = sum(len(part) for part in chinese_parts)
return english_count + chinese_count
def wait(self):
def _wait(self):
"""Wait till the next query can be sent.
Applicable in both single-thread and multi-thread environments.
"""
return self.token_bucket.get_token()
def to(self, device):
pass
class TokenBucket:
"""A token bucket for rate limiting.

View File

@@ -1,5 +1,7 @@
from abc import abstractclassmethod
from copy import copy
from typing import Dict, List, Optional, Tuple, Union
from warnings import warn
class LMTemplateParser:
@@ -21,7 +23,7 @@ class LMTemplateParser:
'role in meta prompt must be unique!'
self.roles[item['role']] = item.copy()
def parse_template(self, dialog) -> str:
def __call__(self, dialog) -> str:
"""Parse a prompt template, and wrap it with meta template if
applicable.
@@ -111,12 +113,17 @@ class BaseModel:
def __init__(self,
path: str,
max_seq_len: int = 2048,
tokenizer_only: bool = False,
template_parser: 'LMTemplateParser' = LMTemplateParser,
meta_template: Optional[List[Dict]] = None):
meta_template: Optional[List[Dict]] = None,
*,
max_tokens: int = 512,
top_p: float = 0.8,
top_k: float = None,
temperature: float = 0.8,
repetition_penalty: float = 1.0,
stop_words: Union[List[str], str] = None):
self.path = path
self.max_seq_len = max_seq_len
self.tokenizer_only = tokenizer_only
# meta template
self.template_parser = template_parser(meta_template)
@@ -124,41 +131,99 @@ class BaseModel:
if meta_template and 'eos_token_id' in meta_template:
self.eos_token_id = meta_template['eos_token_id']
@abstractclassmethod
def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
"""Generate results given a list of inputs.
self.gen_params = dict(
max_tokens=max_tokens,
top_p=top_p,
top_k=top_k,
temperature=temperature,
repetition_penalty=repetition_penalty,
stop_words=stop_words)
def generate(self, inputs: Union[str, List[str]], **gen_params) -> str:
"""Generate results given a str (or list of) inputs.
Args:
inputs (List[str]): A list of strings.
max_out_len (int): The maximum length of the output.
inputs (Union[str, List[str]]):
gen_params (dict): The input params for generation.
Returns:
List[str]: A list of generated strings.
"""
Union[str, List[str]]: A (list of) generated strings.
def parse_template(self, dialog) -> str:
"""Parse a prompt template, and wrap it with meta template if
applicable.
e.g.
batched = True
if isinstance(inputs, str):
inputs = [inputs]
batched = False
response = ['']
if batched:
return response
return response[0]
"""
raise NotImplementedError
def stream_generate(self, inputs: str, **gen_params) -> List[str]:
"""Generate results as streaming given a str inputs.
Args:
dialog (List[str or PromptList]): A prompt
template (potentially before being wrapped by meta template).
mode (str): Parsing mode. Choices are 'ppl' and 'gen'.
inputs (str):
gen_params (dict): The input params for generation.
Returns:
str: The final string.
str: A generated string.
"""
return self.template_parser.parse_template(dialog)
raise NotImplementedError
def generate_from_template(self, templates, max_out_len: int, **kwargs):
def chat(self, inputs: Union[List[dict], List[List[dict]]], **gen_params):
"""Generate completion from a list of templates.
Args:
templates (List[PromptType]): A list of templates.
max_out_len (int): The maximum length of the output.
inputs (Union[List[dict], List[List[dict]]]):
gen_params (dict): The input params for generation.
Returns:
"""
inputs = self.parse_template(templates)
return self.generate(inputs, max_out_len=max_out_len, **kwargs)
if isinstance(inputs[0], list):
_inputs = list()
for msg in inputs:
_inputs.append(self.template_parser(msg))
inputs = _inputs
else:
inputs = self.template_parser(inputs)
return self.generate(inputs, **gen_params)
def to(self, device):
self.model.to(device)
def generate_from_template(
self,
inputs: Union[List[dict], List[List[dict]]],
**gen_params
):
warn(
"This function will be deprecated after three months and will be replaced."
"Please use `.chat()`",
DeprecationWarning, 2)
return self.chat(inputs, **gen_params)
def stream_chat(self, inputs: List[dict], **gen_params):
"""Generate results as streaming given a list of templates.
Args:
inputs (List[dict]):
gen_params (dict): The input params for generation.
Returns:
"""
raise NotImplementedError
def tokenize(self, prompts: Union[str, List[str], List[dict],
List[List[dict]]]):
"""Tokenize the input prompts.
Args:
prompts (str | List[str]): user's prompt, or a batch of prompts
Returns:
Tuple(numpy.ndarray, numpy.ndarray, numpy.ndarray): prompt's token
ids, ids' length and requested output length
"""
raise NotImplementedError
def update_gen_params(self, **kwargs):
gen_params = copy(self.gen_params)
gen_params.update(kwargs)
return gen_params
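`update_gen_params` copies the stored defaults before merging per-call overrides, so `self.gen_params` itself never mutates; a minimal standalone sketch of the same merge:

```python
from copy import copy

defaults = dict(max_tokens=512, top_p=0.8, temperature=0.8)

def update_gen_params(**kwargs):
    gen_params = copy(defaults)  # shallow copy keeps the defaults intact
    gen_params.update(kwargs)
    return gen_params

print(update_gen_params(temperature=0.1)['temperature'])  # 0.1
print(defaults['temperature'])                            # 0.8 (unchanged)
```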

View File

@@ -1,15 +1,19 @@
from typing import Dict, List, Optional
import torch
import copy
import warnings
import logging
from typing import Dict, List, Optional, Union
from dataclasses import asdict
from .base_llm import BaseModel
logger = logging.getLogger(__name__)
class HFTransformer(BaseModel):
"""Model wrapper around HuggingFace general models.
Adapted from OpenCompass (https://github.com/InternLM/opencompass
/blob/main/opencompass/models/huggingface.py)
Adapted from Internlm (https://github.com/InternLM/InternLM/blob/main/
chat/web_demo.py)
Args:
path (str): The name or path to HuggingFace's model.
@@ -25,52 +29,40 @@ class HFTransformer(BaseModel):
meta_template (Dict, optional): The model's meta prompt
template if needed, in case the requirement of injecting or
wrapping of any meta instructions.
extract_pred_after_decode (bool): Whether to extract the prediction
string from the decoded output string, instead of extract the
prediction tokens before decoding. Defaults to False.
batch_padding (bool): If False, inference with be performed in for-loop
without batch padding.
Note:
About ``extract_pred_after_decode``: Commonly, we should extract the
the prediction tokens before decoding. But for some tokenizers using
``sentencepiece``, like LLaMA, this behavior may change the number of
whitespaces, which is harmful for Python programming tasks.
"""
def __init__(
self,
path: str,
max_seq_len: int = 2048,
tokenizer_path: Optional[str] = None,
tokenizer_kwargs: dict = dict(),
tokenizer_only: bool = False,
model_kwargs: dict = dict(device_map='auto'),
meta_template: Optional[Dict] = [
dict(role='system', begin='<|System|>:', end='\n'),
dict(role='user', begin='<|User|>:', end='\n'),
dict(
role='assistant',
begin='<|Bot|>:',
end='<eoa>\n',
generate=True)
], # default meta template for InternLM-7b
extract_pred_after_decode: bool = False,
batch_padding: bool = False):
def __init__(self,
path: str,
tokenizer_path: Optional[str] = None,
tokenizer_kwargs: dict = dict(),
tokenizer_only: bool = False,
model_kwargs: dict = dict(device_map='auto'),
meta_template: Optional[Dict] = None,
**kwargs):
super().__init__(
path=path,
max_seq_len=max_seq_len,
tokenizer_only=tokenizer_only,
meta_template=meta_template)
meta_template=meta_template,
**kwargs)
self._load_tokenizer(
path=path,
tokenizer_path=tokenizer_path,
tokenizer_kwargs=tokenizer_kwargs)
self.batch_padding = batch_padding
self.extract_pred_after_decode = extract_pred_after_decode
if not tokenizer_only:
self._load_model(path=path, model_kwargs=model_kwargs)
from transformers.generation.utils import (LogitsProcessorList,
StoppingCriteriaList)
self.logits_processor = LogitsProcessorList()
self.stopping_criteria = StoppingCriteriaList()
self.prefix_allowed_tokens_fn = None
stop_words_id = []
for sw in self.gen_params.get('stop_words') or []:
stop_words_id.append(self.tokenizer(sw)['input_ids'][1])
self.additional_eos_token_id = stop_words_id
def _load_tokenizer(self, path: str, tokenizer_path: Optional[str],
tokenizer_kwargs: dict):
from transformers import AutoTokenizer
@@ -82,57 +74,182 @@ class HFTransformer(BaseModel):
self.tokenizer.pad_token = self.tokenizer.eos_token
def _load_model(self, path: str, model_kwargs: dict):
import torch
from transformers import AutoModel
model_kwargs.setdefault('torch_dtype', torch.float16)
self.model = AutoModel.from_pretrained(
path, trust_remote_code=True, **model_kwargs)
self.model.eval()
def generate(self, inputs: List[str], max_out_len: int,
**kwargs) -> List[str]:
if isinstance(inputs, str):
inputs = [inputs]
if self.extract_pred_after_decode:
prompt_lens = [len(input_) for input_ in inputs]
def tokenize(self, inputs: str):
assert isinstance(inputs, str)
inputs = self.tokenizer(
inputs, return_tensors='pt', return_length=True)
return inputs['input_ids'].tolist()
input_ids = self.tokenizer(
inputs, truncation=True,
max_length=self.max_seq_len - max_out_len)['input_ids']
input_ids = torch.tensor(input_ids, device=self.model.device)
outputs = self.model.generate(
input_ids=input_ids, max_new_tokens=max_out_len, **kwargs)
def generate(
self,
inputs: List[str],
do_sample=True,
**kwargs,
):
for chunk in self.stream_generate(inputs, do_sample, **kwargs):
response = chunk
return response
if not self.extract_pred_after_decode:
outputs = outputs[:, input_ids.shape[1]:]
def stream_generate(
self,
inputs: List[str],
do_sample=True,
**kwargs,
):
import torch
from torch import nn
with torch.no_grad():
batched = True
if isinstance(inputs, str):
inputs = [inputs]
batched = False
inputs = self.tokenizer(
inputs, padding=True, return_tensors='pt', return_length=True)
input_length = inputs['length']
for k, v in inputs.items():
inputs[k] = v.cuda()
input_ids = inputs['input_ids']
attention_mask = inputs['attention_mask']
batch_size, input_ids_seq_length = input_ids.shape[
0], input_ids.shape[-1] # noqa: F841 # pylint: disable=W0612
generation_config = self.model.generation_config
generation_config = copy.deepcopy(generation_config)
new_gen_params = self.update_gen_params(**kwargs)
generation_config.update(**new_gen_params)
generation_config.update(**kwargs)
model_kwargs = generation_config.to_dict()
model_kwargs['attention_mask'] = attention_mask
_, eos_token_id = ( # noqa: F841 # pylint: disable=W0612
generation_config.bos_token_id,
generation_config.eos_token_id,
)
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
if self.additional_eos_token_id is not None:
eos_token_id.extend(self.additional_eos_token_id)
eos_token_id_tensor = torch.tensor(eos_token_id).to(
input_ids.device) if eos_token_id is not None else None
has_default_max_length = (
kwargs.get('max_length') is None
and generation_config.max_length is not None)
if has_default_max_length and generation_config.max_new_tokens is None:
warnings.warn(
f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
'This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we'
' recommend using `max_new_tokens` to control the maximum length of the generation.',
UserWarning,
)
elif generation_config.max_new_tokens is not None:
generation_config.max_length = (
generation_config.max_new_tokens + input_ids_seq_length)
if not has_default_max_length:
logger.warn( # pylint: disable=W4902
f'Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(='
f'{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. '
'Please refer to the documentation for more information. '
'(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)',
UserWarning,
)
decodeds = self.tokenizer.batch_decode(
outputs, skip_special_tokens=True)
if self.extract_pred_after_decode:
decodeds = [
token[len_:] for token, len_ in zip(decodeds, prompt_lens)
]
if input_ids_seq_length >= generation_config.max_length:
input_ids_string = 'input_ids'
logger.warning(
f'Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to'
f' {generation_config.max_length}. This can lead to unexpected behavior. You should consider'
' increasing `max_new_tokens`.')
return decodeds[0]
# 2. Set generation parameters if not already defined
logits_processor = self.logits_processor
stopping_criteria = self.stopping_criteria
def generate_from_template(self, templates, max_out_len: int, **kwargs):
"""Generate completion from a list of templates.
logits_processor = self.model._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_seq_length,
encoder_input_ids=input_ids,
prefix_allowed_tokens_fn=self.prefix_allowed_tokens_fn,
logits_processor=logits_processor,
)
Args:
templates (List[PromptType]): A list of templates.
max_out_len (int): The maximum length of the output.
"""
inputs = self.parse_template(templates)
response = self.generate(inputs, max_out_len=max_out_len, **kwargs)
end_token = self.template_parser.meta_template[0]['end'].strip()
# return response.replace(
# self.template_parser.roles['assistant']['end'].strip(),
# '').strip()
return response.split(end_token.strip())[0]
stopping_criteria = self.model._get_stopping_criteria(
generation_config=generation_config,
stopping_criteria=stopping_criteria)
logits_warper = self.model._get_logits_warper(generation_config)
unfinished_sequences = input_ids.new(batch_size).fill_(1)
scores = None
while True:
model_inputs = self.model.prepare_inputs_for_generation(
input_ids, **model_kwargs)
# forward pass to get next token
outputs = self.model(
**model_inputs,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
)
next_token_logits = outputs.logits[:, -1, :]
# pre-process distribution
next_token_scores = logits_processor(input_ids,
next_token_logits)
next_token_scores = logits_warper(input_ids, next_token_scores)
# sample
probs = nn.functional.softmax(next_token_scores, dim=-1)
if do_sample:
next_tokens = torch.multinomial(
probs, num_samples=1).squeeze(1)
else:
next_tokens = torch.argmax(probs, dim=-1)
# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]],
dim=-1)
model_kwargs = self.model._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=False)
unfinished_sequences = unfinished_sequences.mul(
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(
eos_token_id_tensor.unsqueeze(1)).prod(dim=0))
# output_token_ids = input_ids.cpu()[:, input_length:].tolist()
output_token_ids = input_ids.cpu().tolist()
for i in range(len(output_token_ids)):
output_token_ids[i] = output_token_ids[i][:][
input_length[i]:]
# Find the first occurrence of an EOS token in the sequence
first_eos_idx = next(
(idx
for idx, token_id in enumerate(output_token_ids[i])
if token_id in eos_token_id), None)
# If an EOS token is found, keep only the tokens before it
if first_eos_idx is not None:
output_token_ids[i] = output_token_ids[
i][:first_eos_idx]
response = self.tokenizer.batch_decode(output_token_ids)
if not batched:
yield response[0]
else:
yield response
# stop when each sentence is finished, or if we exceed the maximum length
if (unfinished_sequences.max() == 0
or stopping_criteria(input_ids, scores)):
break
class HFTransformerCasualLM(HFTransformer):
def _load_model(self, path: str, model_kwargs: dict):
import torch
from transformers import AutoModelForCausalLM
model_kwargs.setdefault('torch_dtype', torch.float16)
self.model = AutoModelForCausalLM.from_pretrained(

View File

@@ -1,143 +0,0 @@
import dataclasses
import os.path as osp
import random
import lmdeploy.turbomind.chat as tm_chat
from lmdeploy import turbomind as tm
from lmdeploy.serve.turbomind.chatbot import Chatbot, Session, get_logger
from lmdeploy.tokenizer import Tokenizer
from .base_llm import BaseModel
class TritonClient(Chatbot, BaseModel):
def __init__(self, meta_template=None, **kwargs):
"""TritonClient is a wrapper of TritonClient for LLM.
Args:
model_name (str): the name of the model
max_out_len (int): the expected generated token numbers
log_level (str): log level
"""
BaseModel.__init__(self, meta_template=meta_template, path=None)
Chatbot.__init__(self, **kwargs)
def generate(self,
prompt: str,
session_id: int = 2967,
request_id: str = '',
max_out_len: int = None,
sequence_start: bool = True,
sequence_end: bool = True,
*args,
**kwargs):
"""Start a new round conversation of a session. Return the chat
completions in non-stream mode.
Args:
session_id (int): the identical id of a session
prompt (str): user's prompt in this round conversation
request_id (str): the identical id of this round conversation
max_out_len (int): the expected generated token numbers
sequence_start (bool): start flag of a session
sequence_end (bool): end flag of a session
Returns:
tuple(Status, str, int): status, text/chat completion,
generated token number
"""
assert isinstance(session_id, int), \
f'INT session id is required, but got {type(session_id)}'
logger = get_logger(log_level=self.log_level)
logger.info(f'session {session_id}, request_id {request_id}, '
f'max_out_len {max_out_len}')
if self._session is None:
sequence_start = True
self._session = Session(session_id=session_id)
elif self._session.status == 0:
logger.error(f'session {session_id} has been ended. Please set '
f'`sequence_start` to True if you want to restart it')
return ''
self._session.status = 1
self._session.request_id = request_id
self._session.response = ''
status, res, _ = None, '', 0
for status, res, _ in self._stream_infer(self._session, prompt,
max_out_len, sequence_start,
sequence_end):
if status.value < 0:
break
if status.value == 0:
self._session.histories = \
self._session.histories + self._session.prompt + \
self._session.response
return res
else:
return ''
def generate_from_template(self, templates, max_out_len: int, **kwargs):
"""Generate completion from a list of templates.
Args:
templates (List[PromptType]): A list of templates.
max_out_len (int): The maximum length of the output.
"""
inputs = self.parse_template(templates)
response = self.generate(inputs, max_out_len=max_out_len, **kwargs)
# The output of turbomind contains <eoa>; we hard-code its removal here.
response = response.replace(
self.template_parser.roles['assistant']['end'].strip(),
'').strip()
return response
class TurboMind(BaseModel):
def __init__(self,
path: str,
max_seq_len: int = 8192,
tokenizer_only: bool = False,
meta_template=None,
tp=1,
**kwargs):
super().__init__(
path=path,
max_seq_len=max_seq_len,
tokenizer_only=tokenizer_only,
meta_template=meta_template)
tokenizer_model_path = osp.join(path, 'triton_models', 'tokenizer')
self.tokenizer = Tokenizer(tokenizer_model_path)
self.tm_model = tm.TurboMind(
path, eos_id=self.tokenizer.eos_token_id, tp=tp)
self.generator = self.tm_model.create_instance()
model_name = self.tm_model.model_name
self.model = tm_chat.MODELS.get(model_name)(
capability='completion', **kwargs)
self._session_id = 0
def generate(self, prompt, **kwargs):
seed = random.getrandbits(64)
input_ids = self.tokenizer.encode(prompt)
gen_param = tm_chat.get_gen_param(
'completion', self.model.sampling_param, step=0, nth_round=1)
response_size = 0
self._session_id = (self._session_id + 1) % 100000
for outputs in self.generator.stream_infer(
session_id=self._session_id,
input_ids=[input_ids],
stream_output=False,
**dataclasses.asdict(gen_param),
ignore_eos=False,
random_seed=seed):
res, tokens = outputs[0]
# decode res
response = self.tokenizer.decode(
res.tolist(), offset=response_size)
response = tm_chat.valid_str(response)
return response

View File

@@ -0,0 +1,383 @@
from typing import List, Optional, Union
from lagent.llms.base_llm import BaseModel
from lagent.schema import AgentStatusCode
from lagent.utils.util import filter_suffix
class TritonClient(BaseModel):
"""TritonClient is a wrapper of TritonClient for LLM.
Args:
tritonserver_addr (str): the address in format "ip:port" of
triton inference server
model_name (str): the name of the model
session_len (int): the context size
max_tokens (int): the expected generated token numbers
"""
def __init__(self,
tritonserver_addr: str,
model_name: str,
session_len: int = 32768,
log_level: str = 'WARNING',
**kwargs):
super().__init__(path=None, **kwargs)
from lmdeploy.serve.turbomind.chatbot import Chatbot, StatusCode
self.state_map = {
StatusCode.TRITON_STREAM_END: AgentStatusCode.END,
StatusCode.TRITON_SERVER_ERR: AgentStatusCode.SERVER_ERR,
StatusCode.TRITON_SESSION_CLOSED: AgentStatusCode.SESSION_CLOSED,
StatusCode.TRITON_STREAM_ING: AgentStatusCode.STREAM_ING,
StatusCode.TRITON_SESSION_OUT_OF_LIMIT:
AgentStatusCode.SESSION_OUT_OF_LIMIT,
StatusCode.TRITON_SESSION_INVALID_ARG:
AgentStatusCode.SESSION_INVALID_ARG,
StatusCode.TRITON_SESSION_READY: AgentStatusCode.SESSION_READY
}
self.chatbot = Chatbot(
tritonserver_addr=tritonserver_addr,
model_name=model_name,
session_len=session_len,
log_level=log_level,
**kwargs)
def generate(self,
inputs: Union[str, List[str]],
session_id: int = 2967,
request_id: str = '',
max_tokens: int = 512,
sequence_start: bool = True,
sequence_end: bool = True,
**kwargs):
"""Start a new round conversation of a session. Return the chat
completions in non-stream mode.
Args:
inputs (str, List[str]): user's prompt(s) in this round
session_id (int): the identical id of a session
request_id (str): the identical id of this round conversation
max_tokens (int): the expected generated token numbers
sequence_start (bool): start flag of a session
sequence_end (bool): end flag of a session
Returns:
(a list of/batched) text/chat completion
"""
from lmdeploy.serve.turbomind.chatbot import Session, get_logger
if isinstance(inputs, str):
inputs = [inputs]
prompt = inputs
assert isinstance(session_id, int), \
f'INT session id is required, but got {type(session_id)}'
logger = get_logger(log_level=self.chatbot.log_level)
logger.info(f'session {session_id}, request_id {request_id}, '
f'max_tokens {max_tokens}')
if self.chatbot._session is None:
sequence_start = True
self.chatbot._session = Session(session_id=session_id)
elif self.chatbot._session.status == 0:
logger.error(f'session {session_id} has been ended. Please set '
f'`sequence_start` to True if you want to restart it')
return ''
self.chatbot._session.status = 1
self.chatbot._session.request_id = request_id
self.chatbot._session.response = ''
self.chatbot.cfg = self._update_gen_params(
max_tokens=max_tokens, **kwargs)
status, res, _ = None, '', 0
for status, res, _ in self.chatbot._stream_infer(
self.chatbot._session, prompt, max_tokens, sequence_start,
sequence_end):
if status.value < 0:
break
if status.value == 0:
self.chatbot._session.histories = (
self.chatbot._session.histories +
self.chatbot._session.prompt + self.chatbot._session.response)
# remove stop_words
res = filter_suffix(res, self.gen_params.get('stop_words'))
return res
else:
return ''
def stream_chat(self,
inputs: List[dict],
session_id: int = 2967,
request_id: str = '',
max_tokens: int = 512,
sequence_start: bool = True,
sequence_end: bool = True,
**kwargs):
"""Start a new round conversation of a session. Return the chat
completions in stream mode.
Args:
session_id (int): the identical id of a session
inputs (List[dict]): user's inputs in this round conversation
request_id (str): the identical id of this round conversation
max_tokens (int): the expected generated token numbers
sequence_start (bool): start flag of a session
sequence_end (bool): end flag of a session
Returns:
tuple(Status, str, int): status, text/chat completion,
generated token number
"""
from lmdeploy.serve.turbomind.chatbot import (Session, StatusCode,
get_logger)
assert isinstance(session_id, int), \
f'INT session id is required, but got {type(session_id)}'
logger = get_logger(log_level=self.chatbot.log_level)
logger.info(f'session {session_id}, request_id {request_id}, '
f'max_tokens {max_tokens}')
if self.chatbot._session is None:
sequence_start = True
self.chatbot._session = Session(session_id=session_id)
elif self.chatbot._session.status == 0:
logger.error(f'session {session_id} has been ended. Please set '
f'`sequence_start` to True if you want to restart it')
return ''
self.chatbot._session.status = 1
self.chatbot._session.request_id = request_id
self.chatbot._session.response = ''
self.chatbot.cfg = self._update_gen_params(
max_tokens=max_tokens, **kwargs)
prompt = self.template_parser(inputs)
status, res, _ = None, '', 0
for status, res, _ in self.chatbot._stream_infer(
self.chatbot._session, prompt, max_tokens, sequence_start,
sequence_end):
if status == StatusCode.TRITON_STREAM_END: # remove stop_words
res = filter_suffix(res, self.gen_params.get('stop_words'))
if status.value < 0:
break
else:
yield self.state_map.get(status), res, _
if status.value == 0:
self.chatbot._session.histories = (
self.chatbot._session.histories +
self.chatbot._session.prompt + self.chatbot._session.response)
yield self.state_map.get(status), res, _
else:
return ''
def _update_gen_params(self, **kwargs):
import mmengine
new_gen_params = self.update_gen_params(**kwargs)
self.gen_params['stop_words'] = new_gen_params.pop('stop_words')
stop_words = self.chatbot._stop_words(
self.gen_params.get('stop_words'))
cfg = mmengine.Config(
dict(
session_len=self.chatbot.model.session_len,
stop_words=stop_words,
bad_words=self.chatbot.cfg.bad_words,
**new_gen_params))
return cfg
class LMDeployPipeline(BaseModel):
"""
Args:
path (str): The path to the model.
It could be one of the following options:
- i) A local directory path of a turbomind model which is
converted by the `lmdeploy convert` command or downloaded
from ii) and iii).
- ii) The model_id of a lmdeploy-quantized model hosted
inside a model repo on huggingface.co, such as
"InternLM/internlm-chat-20b-4bit",
"lmdeploy/llama2-chat-70b-4bit", etc.
- iii) The model_id of a model hosted inside a model repo
on huggingface.co, such as "internlm/internlm-chat-7b",
"Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
and so on.
model_name (str): needed when model_path is a pytorch model on
huggingface.co, such as "internlm/internlm-chat-7b",
"Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat" and so on.
tp (int): tensor parallelism degree. Defaults to 1.
pipeline_cfg (dict): extra keyword arguments forwarded to `lmdeploy.pipeline`.
"""
def __init__(self,
path: str,
model_name: Optional[str] = None,
tp: int = 1,
pipeline_cfg=dict(),
**kwargs):
super().__init__(path=path, **kwargs)
from lmdeploy import pipeline
self.model = pipeline(
model_path=self.path, model_name=model_name, tp=tp, **pipeline_cfg)
def generate(self,
inputs: Union[str, List[str]],
do_preprocess=None,
**kwargs):
batched = True
if isinstance(inputs, str):
inputs = [inputs]
batched = False
prompt = inputs
gen_params = self.update_gen_params(**kwargs)
response = self.model.batch_infer(
prompt, do_preprocess=do_preprocess, **gen_params)
response = [resp.text for resp in response]
# remove stop_words
response = filter_suffix(response, self.gen_params.get('stop_words'))
if batched:
return response
return response[0]
class LMDeployServer(BaseModel):
"""
Args:
path (str): The path to the model.
It could be one of the following options:
- i) A local directory path of a turbomind model which is
converted by the `lmdeploy convert` command or downloaded from
ii) and iii).
- ii) The model_id of a lmdeploy-quantized model hosted
inside a model repo on huggingface.co, such as
"InternLM/internlm-chat-20b-4bit",
"lmdeploy/llama2-chat-70b-4bit", etc.
- iii) The model_id of a model hosted inside a model repo
on huggingface.co, such as "internlm/internlm-chat-7b",
"Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
and so on.
model_name (str): needed when model_path is a pytorch model on
huggingface.co, such as "internlm/internlm-chat-7b",
"Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat" and so on.
server_name (str): host ip for serving
server_port (int): server port
tp (int): tensor parallelism degree. Defaults to 1.
log_level (str): set log level whose value among
[CRITICAL, ERROR, WARNING, INFO, DEBUG]
"""
def __init__(self,
path: str,
model_name: Optional[str] = None,
server_name: str = '0.0.0.0',
server_port: int = 23333,
tp: int = 1,
log_level: str = 'WARNING',
serve_cfg=dict(),
**kwargs):
super().__init__(path=path, **kwargs)
# TODO get_logger issue in multi processing
import lmdeploy
self.client = lmdeploy.serve(
model_path=self.path,
model_name=model_name,
server_name=server_name,
server_port=server_port,
tp=tp,
log_level=log_level,
**serve_cfg)
def generate(self,
inputs: Union[str, List[str]],
session_id: int = 2967,
sequence_start: bool = True,
sequence_end: bool = True,
ignore_eos: bool = False,
timeout: int = 30,
**kwargs) -> List[str]:
batched = True
if isinstance(inputs, str):
inputs = [inputs]
batched = False
gen_params = self.update_gen_params(**kwargs)
resp = [''] * len(inputs)
for text in self.client.completions_v1(
self.path,
inputs,
session_id=session_id,
sequence_start=sequence_start,
sequence_end=sequence_end,
stream=False,
ignore_eos=ignore_eos,
timeout=timeout,
**gen_params):
resp = [
resp[i] + item['text']
for i, item in enumerate(text['choices'])
]
# remove stop_words
resp = filter_suffix(resp, self.gen_params.get('stop_words'))
if not batched:
return resp[0]
return resp
def stream_chat(self,
inputs: List[dict],
session_id=0,
sequence_start: bool = True,
sequence_end: bool = True,
stream: bool = True,
ignore_eos: bool = False,
timeout: int = 30,
**kwargs):
gen_params = self.update_gen_params(**kwargs)
prompt = self.template_parser(inputs)
resp = ''
finished = False
stop_words = self.gen_params.get('stop_words')
for text in self.client.completions_v1(
self.path,
prompt,
session_id=session_id,
sequence_start=sequence_start,
sequence_end=sequence_end,
stream=stream,
ignore_eos=ignore_eos,
timeout=timeout,
**gen_params):
resp += text['choices'][0]['text']
if not resp:
continue
# remove stop_words
for sw in stop_words or []:  # stop_words may be None
if sw in resp:
resp = filter_suffix(resp, stop_words)
finished = True
break
yield AgentStatusCode.STREAM_ING, resp, None
if finished:
break
yield AgentStatusCode.END, resp, None
class LMDeployClient(LMDeployServer):
"""
Args:
path (str): The path to the model.
url (str): the address of a running api_server, e.g. 'http://0.0.0.0:23333'.
"""
def __init__(self, path: str, url: str, **kwargs):
BaseModel.__init__(self, path=path, **kwargs)
from lmdeploy.serve.openai.api_client import APIClient
self.client = APIClient(url)
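A hedged sketch of talking to a running `lmdeploy` api_server through the thin `LMDeployClient`; the model name, URL, stop word, and import path are placeholders:

```python
from lagent.llms.lmdeploy_wrapper import LMDeployClient

llm = LMDeployClient(
    path='internlm2-chat-7b',
    url='http://0.0.0.0:23333',
    stop_words=['<|im_end|>'],
)
print(llm.generate('Write a haiku about refactoring.'))
for state, text, _ in llm.stream_chat([dict(role='user', content='hi')]):
    print(state, text)
```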

View File

@@ -1,7 +1,7 @@
import json
import os
import time
# from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import ThreadPoolExecutor, wait
from logging import getLogger
from threading import Lock
from typing import Dict, List, Optional, Union
@@ -38,9 +38,8 @@ class GPTAPI(BaseAPIModel):
wrapping of any meta instructions.
openai_api_base (str): The base url of OpenAI's API. Defaults to
'https://api.openai.com/v1/chat/completions'.
temperature (float, optional): What sampling temperature to use.
If not None, will override the temperature in the `generate()`
call. Defaults to None.
gen_params: Default generation configuration which can be overridden
on the fly during generation.
"""
is_api: bool = True
@@ -58,16 +57,15 @@ class GPTAPI(BaseAPIModel):
dict(role='assistant', api_role='assistant')
],
openai_api_base: str = OPENAI_API_BASE,
temperature: Optional[float] = None):
**gen_params):
super().__init__(
model_type=model_type,
max_seq_len=max_seq_len,
meta_template=meta_template,
query_per_second=query_per_second,
retry=retry)
retry=retry,
**gen_params)
self.logger = getLogger(__name__)
self.temperature = temperature
if isinstance(key, str):
self.keys = [os.getenv('OPENAI_API_KEY') if key == 'ENV' else key]
@@ -97,67 +95,57 @@ class GPTAPI(BaseAPIModel):
context_window = 8192
self.context_window = context_window
def generate(
def chat(
self,
inputs: Union[List, str],
max_out_len: int = 512,
temperature: float = 0.7,
) -> List[str]:
"""Generate results given a list of inputs.
inputs: Union[List[dict], List[List[dict]]],
**gen_params,
) -> Union[str, List[str]]:
"""Generate responses given the contexts.
Args:
inputs (List[str or List]): A list of strings or PromptDicts.
The PromptDict should be organized in OpenCompass'
API format.
max_out_len (int): The maximum length of the output.
temperature (float): What sampling temperature to use,
between 0 and 2. Higher values like 0.8 will make the output
more random, while lower values like 0.2 will make it more
focused and deterministic. Defaults to 0.7.
inputs (Union[List[dict], List[List[dict]]]): a list of messages
or list of lists of messages
gen_params: additional generation configuration
Returns:
List[str]: A list of generated strings.
Union[str, List[str]]: generated string(s)
"""
if self.temperature is not None:
temperature = self.temperature
return self._generate(inputs, max_out_len, temperature)
assert isinstance(inputs, list)
batched = True
if isinstance(inputs[0], dict):
inputs = [inputs]
batched = False
gen_params = {**self.gen_params, **gen_params}
with ThreadPoolExecutor(max_workers=20) as executor:
tasks = [
executor.submit(self._chat, messages, **gen_params)
for messages in inputs
]
wait(tasks)
ret = [task.result() for task in tasks]
return ret if batched else ret[0]
def _generate(self, input: str or List, max_out_len: int,
temperature: float) -> str:
"""Generate results given a list of inputs.
def _chat(self, messages: List[dict], **gen_params) -> str:
"""Generate completion from a list of templates.
Args:
inputs (str or List): A string or PromptDict.
The PromptDict should be organized in OpenCompass'
API format.
max_out_len (int): The maximum length of the output.
temperature (float): What sampling temperature to use,
between 0 and 2. Higher values like 0.8 will make the output
more random, while lower values like 0.2 will make it more
focused and deterministic.
messages (List[dict]): a list of prompt dictionaries
gen_params: additional generation configuration
Returns:
str: The generated string.
"""
assert isinstance(input, (str, list, dict))
if isinstance(input, str):
messages = [{'role': 'user', 'content': input}]
elif isinstance(input, dict):
messages = [input]
else:
messages = input
assert isinstance(messages, list)
gen_params = gen_params.copy()
# Hold out 100 tokens due to potential errors in tiktoken calculation
max_out_len = min(
max_out_len,
self.context_window - self.get_token_len(str(input)) - 100)
gen_params.pop('max_out_len'),
self.context_window - len(self.tokenize(str(input))) - 100)
if max_out_len <= 0:
return ''
max_num_retries = 0
while max_num_retries < self.retry:
self.wait()
self._wait()
with Lock():
if len(self.invalid_keys) == len(self.keys):
@@ -192,8 +180,9 @@ class GPTAPI(BaseAPIModel):
                 messages=messages,
                 max_tokens=max_out_len,
                 n=1,
-                stop=None,
-                temperature=temperature,
+                stop=gen_params.pop('stop_words'),
+                frequency_penalty=gen_params.pop('repetition_penalty'),
+                **gen_params,
             )
             raw_response = requests.post(
                 self.url, headers=header, data=json.dumps(data))
@@ -225,18 +214,16 @@ class GPTAPI(BaseAPIModel):
                 f'{max_num_retries} times. Check the logs for '
                 'details.')

-    def get_token_len(self, prompt: str) -> int:
-        """Get lengths of the tokenized string. Only English and Chinese
-        characters are counted for now. Users are encouraged to override this
-        method if more accurate length is needed.
+    def tokenize(self, prompt: str) -> list:
+        """Tokenize the input prompt.

         Args:
             prompt (str): Input string.

         Returns:
-            int: Length of the input tokens
+            list: token ids
         """
         import tiktoken
         self.tiktoken = tiktoken
         enc = self.tiktoken.encoding_for_model(self.model_type)
-        return len(enc.encode(prompt))
+        return enc.encode(prompt)
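With `generate` replaced by `chat`, a single conversation and a batch now go through the same entry point. A usage sketch, assuming `GPTAPI` is exported from `lagent.llms`, `OPENAI_API_KEY` is set in the environment (the `key='ENV'` default), and the default `gen_params` supply `max_out_len`, `stop_words`, and `repetition_penalty`:

```python
from lagent.llms import GPTAPI

model = GPTAPI(model_type='gpt-3.5-turbo')

# A single conversation (a list of message dicts) returns one string.
reply = model.chat([{'role': 'user', 'content': 'Say hi in French.'}])

# A batch of conversations returns a list of strings; the requests are
# dispatched concurrently by the ThreadPoolExecutor inside `chat`.
replies = model.chat(
    [[{'role': 'user', 'content': 'What is 2 + 2?'}],
     [{'role': 'user', 'content': 'Name a prime larger than 10.'}]],
    temperature=0.2)
```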

lagent/schema.py
@@ -1,6 +1,6 @@
 from dataclasses import asdict, dataclass, field
 from enum import Enum
-from typing import Dict, List, Optional, Union
+from typing import List, Optional, Union

 from lagent.utils import is_module_exist
@@ -33,47 +33,49 @@ class ActionValidCode(int, Enum):
 @dataclass
 class ActionReturn:
-    args: Dict
+    args: Optional[dict] = None
     url: Optional[str] = None
     type: Optional[str] = None
-    result: Optional[str] = None
+    result: Optional[List[dict]] = None
     errmsg: Optional[str] = None
     state: Union[ActionStatusCode, int] = ActionStatusCode.SUCCESS
     thought: Optional[str] = None
     valid: Optional[ActionValidCode] = ActionValidCode.OPEN

+    def format_result(self) -> str:
+        """Concatenate items in result."""
+        result = []
+        for item in self.result or []:
+            if item['type'] == 'text':
+                result.append(item['content'])
+            else:
+                result.append(f"[{item['type']}]({item['content']})")
+        result = '\n'.join(result)
+        return result

-class AgentStatusCode(Enum):
-    END = 0  # end of streaming
+# Inherit from int so that asdict can convert AgentStatusCode to int
+class AgentStatusCode(int, Enum):
+    END = 0  # end of streaming, return this turn's history
     STREAM_ING = 1  # response is in streaming
     SERVER_ERR = -1  # triton server's error
     SESSION_CLOSED = -2  # session has been closed
     SESSION_OUT_OF_LIMIT = -3  # request length out of limit
+    CMD = 2  # return command
+    PLUGIN_START = 3  # start tool
+    PLUGIN_END = 4  # finish tool
+    PLUGIN_RETURN = 5  # tool return
+    CODING = 6  # start python
+    CODE_END = 7  # end python
+    CODE_RETURN = 8  # python return
     SESSION_INVALID_ARG = -4  # invalid argument
-    SESSION_READY = 3  # session is ready for inference
+    SESSION_READY = 2  # session is ready for inference

 @dataclass
 class AgentReturn:
+    state: Union[AgentStatusCode, int] = AgentStatusCode.END
     actions: List[ActionReturn] = field(default_factory=list)
     response: str = ''
+    inner_steps: List = field(default_factory=list)
+    errmsg: Optional[str] = None

+if is_module_exist('lmdeploy'):
+    from lmdeploy.serve.turbomind.chatbot import StatusCode
+    STATE_MAP = {
+        StatusCode.TRITON_STREAM_END: AgentStatusCode.END,
+        StatusCode.TRITON_SERVER_ERR: AgentStatusCode.SERVER_ERR,
+        StatusCode.TRITON_SESSION_CLOSED: AgentStatusCode.SESSION_CLOSED,
+        StatusCode.TRITON_STREAM_ING: AgentStatusCode.STREAM_ING,
+        StatusCode.TRITON_SESSION_OUT_OF_LIMIT:
+        AgentStatusCode.SESSION_OUT_OF_LIMIT,
+        StatusCode.TRITON_SESSION_INVALID_ARG:
+        AgentStatusCode.SESSION_INVALID_ARG,
+        StatusCode.TRITON_SESSION_READY: AgentStatusCode.SESSION_READY
+    }
+else:
+    STATE_MAP = {}
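Since `result` is now a list of typed chunks rather than a plain string, mixed text and file outputs can be rendered uniformly via `format_result`, and the `int` mixin keeps statuses JSON-serializable. A small sketch of the new structure (the image path is made up for illustration):

```python
import json
from dataclasses import asdict

from lagent.schema import ActionReturn, AgentReturn

ret = ActionReturn(
    args={'query': 'weather in Beijing'},
    result=[
        {'type': 'text', 'content': 'Sunny and 25 degrees.'},
        {'type': 'image', 'content': '/tmp/forecast.png'},  # hypothetical
    ],
)
# Text chunks pass through; other types render as "[type](content)".
print(ret.format_result())
# Sunny and 25 degrees.
# [image](/tmp/forecast.png)

# Because AgentStatusCode mixes in int, asdict() output serializes cleanly.
print(json.dumps(asdict(AgentReturn())))
# {"state": 0, "actions": [], "response": "", "inner_steps": [], "errmsg": null}
```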

lagent/utils/util.py (new file)

@@ -0,0 +1,30 @@
+from typing import List, Optional, Union
+
+
+def filter_suffix(response: Union[str, List[str]],
+                  suffixes: Optional[List[str]] = None
+                  ) -> Union[str, List[str]]:
+    """Filter response with suffixes.
+
+    Args:
+        response (Union[str, List[str]]): generated responses by LLMs.
+        suffixes (Optional[List[str]]): a list of suffixes to be deleted.
+
+    Returns:
+        Union[str, List[str]]: a clean response or a list of them.
+    """
+    if suffixes is None:
+        return response
+    batched = True
+    if isinstance(response, str):
+        response = [response]
+        batched = False
+    processed = []
+    for resp in response:
+        for item in suffixes:
+            # Truncate everything from the first occurrence of a suffix.
+            if item in resp:
+                resp = resp.split(item)[0]
+        processed.append(resp)
+    if not batched:
+        return processed[0]
+    return processed
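A quick sanity check of the helper's single and batched behavior, importing it directly from its module path:

```python
from lagent.utils.util import filter_suffix

assert filter_suffix('Hello<eoa>', ['<eoa>']) == 'Hello'
assert filter_suffix(['Hi<eoa>', 'Bye'], ['<eoa>']) == ['Hi', 'Bye']
# With no suffixes given, the input passes through unchanged.
assert filter_suffix('unchanged') == 'unchanged'
```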

requirements/docs.txt
@@ -1,9 +1,13 @@
-docutils==0.16.0
+docutils==0.18.1
 markdown>=3.4.0
 myst-parser
--e git+https://github.com/open-mmlab/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
-sphinx==4.0.2
+myst-nb
+# -e git+https://github.com/open-mmlab/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
+# sphinx==4.0.2
+sphinx==6.1.0
 sphinx-tabs
 sphinx_copybutton
 sphinx_markdown_tables>=0.0.16
+sphinx-rtd-theme==1.3.0
 tabulate
+astroid<3.0.0
+sphinx-autoapi

requirements/runtime.txt
@@ -3,3 +3,9 @@ func_timeout
 jsonschema
 requests
 tiktoken
+griffe
+phx-class-registry
+jupyter
+jupyter_client
+json5
+pillow