[Init] Initial commit

* Initial commit
This commit is contained in:
liukuikun
2023-08-21 13:16:50 +08:00
committed by GitHub
parent f671e8c20c
commit 931ec2ab6e
44 changed files with 3608 additions and 0 deletions

1
.gitignore vendored

@@ -158,3 +158,4 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.vscode/

50
.pre-commit-config.yaml Normal file

@@ -0,0 +1,50 @@
exclude: ^(tests/data|scripts|ftdp/protocols|ftdp/template_configs|ftdp/tool_dicts)/
repos:
- repo: https://github.com/PyCQA/flake8
rev: 5.0.4
hooks:
- id: flake8
- repo: https://github.com/PyCQA/isort
rev: 5.11.5
hooks:
- id: isort
- repo: https://github.com/pre-commit/mirrors-yapf
rev: v0.32.0
hooks:
- id: yapf
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.3.0
hooks:
- id: trailing-whitespace
- id: check-yaml
- id: end-of-file-fixer
- id: requirements-txt-fixer
- id: double-quote-string-fixer
- id: check-merge-conflict
- id: fix-encoding-pragma
args: ["--remove"]
- id: mixed-line-ending
args: ["--fix=lf"]
- repo: https://github.com/executablebooks/mdformat
rev: 0.7.9
hooks:
- id: mdformat
args: ["--number"]
additional_dependencies:
- mdformat-openmmlab
- mdformat_frontmatter
- linkify-it-py
- repo: https://github.com/codespell-project/codespell
rev: v2.2.1
hooks:
- id: codespell
- repo: https://github.com/myint/docformatter
rev: v1.3.1
hooks:
- id: docformatter
args: ["--in-place", "--wrap-descriptions", "79"]
- repo: https://github.com/asottile/pyupgrade
rev: v3.0.0
hooks:
- id: pyupgrade
args: ["--py36-plus"]

428
.pylintrc Normal file

@@ -0,0 +1,428 @@
# This Pylint rcfile contains a best-effort configuration to uphold the
# best-practices and style described in the Google Python style guide:
# https://google.github.io/styleguide/pyguide.html
#
# Its canonical open-source location is:
# https://google.github.io/styleguide/pylintrc
[MASTER]
# Files or directories to be skipped. They should be base names, not paths.
ignore=third_party,storage
# Files or directories matching the regex patterns are skipped. The regex
# matches against base names, not paths.
ignore-patterns=
# Pickle collected data for later comparisons.
persistent=no
# List of plugins (as comma separated values of python modules names) to load,
# usually to register additional checkers.
load-plugins=
# Use multiple processes to speed up Pylint.
jobs=4
# Allow loading of arbitrary C extensions. Extensions are imported into the
# active Python interpreter and may run arbitrary code.
unsafe-load-any-extension=no
[MESSAGES CONTROL]
# Only show warnings with the listed confidence levels. Leave empty to show
# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
confidence=
# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
# multiple time (only on the command line, not in the configuration file where
# it should appear only once). See also the "--disable" option for examples.
#enable=
# Disable the message, report, category or checker with the given id(s). You
# can either give multiple identifiers separated by comma (,) or put this
# option multiple times (only on the command line, not in the configuration
# file where it should appear only once).You can also use "--disable=all" to
# disable everything first and then reenable specific checks. For example, if
# you want to run only the similarities checker, you can use "--disable=all
# --enable=similarities". If you want to run only the classes checker, but have
# no Warning level messages displayed, use"--disable=all --enable=classes
# --disable=W"
disable=abstract-method,
apply-builtin,
arguments-differ,
attribute-defined-outside-init,
backtick,
bad-option-value,
basestring-builtin,
buffer-builtin,
c-extension-no-member,
consider-using-enumerate,
cmp-builtin,
cmp-method,
coerce-builtin,
coerce-method,
delslice-method,
div-method,
duplicate-code,
eq-without-hash,
execfile-builtin,
file-builtin,
filter-builtin-not-iterating,
fixme,
getslice-method,
global-statement,
hex-method,
idiv-method,
implicit-str-concat,
import-error,
import-self,
import-star-module-level,
inconsistent-return-statements,
input-builtin,
intern-builtin,
invalid-str-codec,
locally-disabled,
long-builtin,
long-suffix,
map-builtin-not-iterating,
misplaced-comparison-constant,
missing-function-docstring,
metaclass-assignment,
next-method-called,
next-method-defined,
no-absolute-import,
no-else-break,
no-else-continue,
no-else-raise,
no-else-return,
no-init, # added
no-member,
no-name-in-module,
no-self-use,
nonzero-method,
oct-method,
old-division,
old-ne-operator,
old-octal-literal,
old-raise-syntax,
parameter-unpacking,
print-statement,
raising-string,
range-builtin-not-iterating,
raw_input-builtin,
rdiv-method,
reduce-builtin,
relative-import,
reload-builtin,
round-builtin,
setslice-method,
signature-differs,
standarderror-builtin,
suppressed-message,
sys-max-int,
too-few-public-methods,
too-many-ancestors,
too-many-arguments,
too-many-boolean-expressions,
too-many-branches,
too-many-instance-attributes,
too-many-locals,
too-many-nested-blocks,
too-many-public-methods,
too-many-return-statements,
too-many-statements,
trailing-newlines,
unichr-builtin,
unicode-builtin,
unnecessary-pass,
unpacking-in-except,
useless-else-on-loop,
useless-object-inheritance,
useless-suppression,
using-cmp-argument,
wrong-import-order,
xrange-builtin,
zip-builtin-not-iterating,
[REPORTS]
# Set the output format. Available formats are text, parseable, colorized, msvs
# (visual studio) and html. You can also give a reporter class, eg
# mypackage.mymodule.MyReporterClass.
output-format=colorized
# Tells whether to display a full report or only the messages
reports=no
# Python expression which should return a note less than 10 (10 is the highest
# note). You have access to the variables errors warning, statement which
# respectively contain the number of errors / warnings messages and the total
# number of statements analyzed. This is used by the global evaluation report
# (RP0004).
evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
# Template used to display messages. This is a python new-style format string
# used to format the message information. See doc for all details
#msg-template=
[BASIC]
# Good variable names which should always be accepted, separated by a comma
good-names=main,_
# Bad variable names which should always be refused, separated by a comma
bad-names=
# Colon-delimited sets of names that determine each other's naming style when
# the name regexes allow several styles.
name-group=
# Include a hint for the correct naming format with invalid-name
include-naming-hint=no
# List of decorators that produce properties, such as abc.abstractproperty. Add
# to this list to register other decorators that produce valid properties.
property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl
# Regular expression matching correct function names
function-rgx=^(?:(?P<exempt>setUp|tearDown|setUpModule|tearDownModule)|(?P<camel_case>_?[A-Z][a-zA-Z0-9]*)|(?P<snake_case>_?[a-z][a-z0-9_]*))$
# Regular expression matching correct variable names
variable-rgx=^[a-z][a-z0-9_]*$
# Regular expression matching correct constant names
const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
# Regular expression matching correct attribute names
attr-rgx=^_{0,2}[a-z][a-z0-9_]*$
# Regular expression matching correct argument names
argument-rgx=^[a-z][a-z0-9_]*$
# Regular expression matching correct class attribute names
class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
# Regular expression matching correct inline iteration names
inlinevar-rgx=^[a-z][a-z0-9_]*$
# Regular expression matching correct class names
class-rgx=^_?[A-Z][a-zA-Z0-9]*$
# Regular expression matching correct module names
module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$
# Regular expression matching correct method names
method-rgx=(?x)^(?:(?P<exempt>_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P<camel_case>_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P<snake_case>_{0,2}[a-z][a-z0-9_]*))$
# Regular expression which should only match function or class names that do
# not require a docstring.
no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$
# Minimum line length for functions/classes that require docstrings, shorter
# ones are exempt.
docstring-min-length=10
[TYPECHECK]
# List of decorators that produce context managers, such as
# contextlib.contextmanager. Add to this list to register other decorators that
# produce valid context managers.
contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager
# Tells whether missing members accessed in mixin class should be ignored. A
# mixin class is detected if its name ends with "mixin" (case insensitive).
ignore-mixin-members=yes
# List of module names for which member attributes should not be checked
# (useful for modules/projects where namespaces are manipulated during runtime
# and thus existing member attributes cannot be deduced by static analysis. It
# supports qualified module names, as well as Unix pattern matching.
ignored-modules=
# List of class names for which member attributes should not be checked (useful
# for classes with dynamically set attributes). This supports the use of
# qualified names.
ignored-classes=optparse.Values,thread._local,_thread._local
# List of members which are set dynamically and missed by pylint inference
# system, and so shouldn't trigger E1101 when accessed. Python regular
# expressions are accepted.
generated-members=
[FORMAT]
# Maximum number of characters on a single line.
max-line-length=120
# TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt
# lines made too long by directives to pytype.
# Regexp for a line that is allowed to be longer than the limit.
ignore-long-lines=(?x)(
^\s*(\#\ )?<?https?://\S+>?$|
^\s*(from\s+\S+\s+)?import\s+.+$)
# Allow the body of an if to be on the same line as the test if there is no
# else.
single-line-if-stmt=yes
# Maximum number of lines in a module
max-module-lines=99999
# String used as indentation unit. The internal Google style guide mandates 2
# spaces. Google's externally-published style guide says 4, consistent with
# PEP 8. Here, we use 2 spaces, for conformity with many open-sourced Google
# projects (like TensorFlow).
indent-string=' '
# Number of spaces of indent required inside a hanging or continued line.
indent-after-paren=4
# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
expected-line-ending-format=
[MISCELLANEOUS]
# List of note tags to take in consideration, separated by a comma.
notes=TODO
[STRING]
# This flag controls whether inconsistent-quotes generates a warning when the
# character used as a quote delimiter is used inconsistently within a module.
check-quote-consistency=yes
[VARIABLES]
# Tells whether we should check for unused import in __init__ files.
init-import=no
# A regular expression matching the name of dummy variables (i.e. expectedly
# not used).
dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_)
# List of additional names supposed to be defined in builtins. Remember that
# you should avoid to define new builtins when possible.
additional-builtins=
# List of strings which can identify a callback function by name. A callback
# name must start or end with one of those strings.
callbacks=cb_,_cb
# List of qualified module names which can have objects that can redefine
# builtins.
redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools
[LOGGING]
# Logging modules to check that the string format arguments are in logging
# function parameter format
logging-modules=logging,absl.logging,tensorflow.io.logging
[SIMILARITIES]
# Minimum lines number of a similarity.
min-similarity-lines=4
# Ignore comments when computing similarities.
ignore-comments=yes
# Ignore docstrings when computing similarities.
ignore-docstrings=yes
# Ignore imports when computing similarities.
ignore-imports=no
[SPELLING]
# Spelling dictionary name. Available dictionaries: none. To make it working
# install python-enchant package.
spelling-dict=
# List of comma separated words that should not be checked.
spelling-ignore-words=
# A path to a file that contains private dictionary; one word per line.
spelling-private-dict-file=
# Tells whether to store unknown words to indicated private dictionary in
# --spelling-private-dict-file option instead of raising a message.
spelling-store-unknown-words=no
[IMPORTS]
# Deprecated modules which should not be used, separated by a comma
deprecated-modules=regsub,
TERMIOS,
Bastion,
rexec,
sets
# Create a graph of every (i.e. internal and external) dependencies in the
# given file (report RP0402 must not be disabled)
import-graph=
# Create a graph of external dependencies in the given file (report RP0402 must
# not be disabled)
ext-import-graph=
# Create a graph of internal dependencies in the given file (report RP0402 must
# not be disabled)
int-import-graph=
# Force import order to recognize a module as part of the standard
# compatibility libraries.
known-standard-library=
# Force import order to recognize a module as part of a third party library.
known-third-party=enchant, absl
# Analyse import fallback blocks. This can be used to support both Python 2 and
# 3 compatible code, which means that the block might have code that exists
# only in one or another interpreter, leading to false positives when analysed.
analyse-fallback-blocks=no
[CLASSES]
# List of method names used to declare (i.e. assign) instance attributes.
defining-attr-methods=__init__,
__new__,
setUp
# List of member names, which should be excluded from the protected access
# warning.
exclude-protected=_asdict,
_fields,
_replace,
_source,
_make
# List of valid names for the first argument in a class method.
valid-classmethod-first-arg=cls,
class_
# List of valid names for the first argument in a metaclass class method.
valid-metaclass-classmethod-first-arg=mcs
[EXCEPTIONS]
# Exceptions that will emit a warning when being caught. Defaults to
# "Exception"
overgeneral-exceptions=builtins.BaseException,
builtins.Exception

81
README.md Normal file

@@ -0,0 +1,81 @@
# LAgent: Large Language Model as Agent
English | [简体中文](README_zh-CN.md)
## Introduction
LAgent is an open-source LLM agent framework that enables people to efficiently turn a large language model into an agent. It also provides some typical tools to augment the abilities of LLMs. An overview of the framework is shown below:
![image](https://github.com/InternLM/lagent/assets/24351120/e104171e-4baf-43b3-8e6d-90cff1b298b6)
### Major Features
- **Support multiple agent frameworks out of the box.** We implement ReAct, AutoGPT and ReWOO, which enable the agents to drive LLMs through multiple trials of reasoning and tool utilization.
- **Extremely simple and easy to extend.** The framework is quite simple with a clear project structure. With only 20 lines of code, you are able to construct your own agent. It also supports three typical tools: Python interpreter, API call, and Google search.
- **Support various large language models.** We support different LLMs, including API-based (GPT-3.5/4) and HuggingFace-based (Llama-2, InternLM) models.
## Getting Started
Please see [Overview](docs/overview.md) for a general introduction to LAgent. Meanwhile, we provide extremely simple code snippets for a quick start. You may refer to [examples](examples/) for more details.
### Installation
```
git clone https://github.com/InternLM/lagent.git
cd lagent
pip install -e .
```
### Run a ReAct agent with GPT3.5 backend
```python
from lagent.agents import ReACT
from lagent.actions.action_executor import ActionExecutor
from lagent.llms import GPTAPI
from lagent.actions import SerperSearch, PythonInterpreter

llm = GPTAPI(model_type='gpt-3.5-turbo')
search_tool = SerperSearch()  # requires SERPER_API_KEY in the environment or api_key=...
python_interpreter = PythonInterpreter()
chatbot = ReACT(
    llm=llm,
    action_executor=ActionExecutor(
        actions=[search_tool, python_interpreter]),
)
response = chatbot.chat('What profession does Nicholas Ray and Elia Kazan have in common')
print(response['response'])
>>> They are both film directors.
```
### Run a ReAct agent with HuggingFace backend
NOTE: If you want to run a HuggingFace model, please run `pip install -e .[all]` first.
```python
from lagent.agents import ReACT
from lagent.actions.action_executor import ActionExecutor
from lagent.llms import HFTransformer
from lagent.actions import SerperSearch, PythonInterpreter

llm = HFTransformer('internlm/internlm-chat-7b')
search_tool = SerperSearch()  # requires SERPER_API_KEY in the environment or api_key=...
python_interpreter = PythonInterpreter()
chatbot = ReACT(
    llm=llm,
    action_executor=ActionExecutor(
        actions=[search_tool, python_interpreter]),
)
response = chatbot.chat('若$z=-1+\sqrt{3}i$,则$\frac{z}{{z\overline{z}-1}}=\left(\ \ \right)$ (A) $-1+\sqrt{3}i$ (B) $-1-\sqrt{3}i$ (C) $-\frac{1}{3}+\frac{{\sqrt{3}}}{3}i$ (D) $-\frac{1}{3}-\frac{{\sqrt{3}}}{3}i$')
print(response['response'])
>>> 根据已有的信息可以求得$z=-1+\\sqrt{3}i$然后代入计算得到结果为$-\\frac{1}{3}+\\frac{{\\sqrt{3}}}{3}i$。因此答案是C)。
```
## License
This project is released under the [Apache 2.0 license](LICENSE).

80
README_zh-CN.md Normal file

@@ -0,0 +1,80 @@
# LAgent: Large Language Model as Agent
[English](README.md) | 简体中文
## 简介
LAgent 是一个开源的 LLM 代理框架,支持用户快速地将一个大语言模型转变为多种类型的智能体,并提供了一些典型工具为大语言模型赋能。它的整个框架图如下:
![image](https://github.com/InternLM/lagent/assets/24351120/e104171e-4baf-43b3-8e6d-90cff1b298b6)
### 主要特点
- **实现了多种类型的智能体。** 我们支持了经典的 ReAct、AutoGPT 和 ReWOO 等智能体,这些智能体能够调用大语言模型进行多轮的推理和工具调用。
- **框架简单,易于拓展。** 框架的代码结构清晰且简单,只需要不到 20 行代码,你就能够创造出一个自己的智能体。同时我们支持了 Python 解释器、API 调用和搜索三类常用典型工具。
- **灵活支持多个大语言模型。** 我们提供了多种大语言模型支持,包括 InternLM、Llama-2 等开源模型和 GPT-4/3.5 等基于 API 的闭源模型。
## 教程
请阅读[概述](docs/overview.md)对 LAgent 进行初步的了解。同时,我们提供了两份非常简单的示例代码帮助你快速入门。你也可以阅读 [examples](examples/) 获得更多的参考例子。
### 安装
```
git clone https://github.com/InternLM/lagent.git
cd lagent
pip install -e .
```
### 用GPT3.5构建一个ReAct代理
```python
from lagent.agents import ReACT
from lagent.actions.action_executor import ActionExecutor
from lagent.llms import GPTAPI
from lagent.actions import SerperSearch, PythonInterpreter

llm = GPTAPI(model_type='gpt-3.5-turbo')
search_tool = SerperSearch()  # 需要设置 SERPER_API_KEY 环境变量或传入 api_key
python_interpreter = PythonInterpreter()
chatbot = ReACT(
    llm=llm,
    action_executor=ActionExecutor(
        actions=[search_tool, python_interpreter]),
)
response = chatbot.chat('What profession does Nicholas Ray and Elia Kazan have in common')
print(response['response'])
>>> They are both film directors.
```
### 用HuggingFace构建一个ReAct代理
注意:如果你想要使用一个 HuggingFace 的模型,请先运行 `pip install -e .[all]`。
```python
from lagent.agents import ReACT
from lagent.actions.action_executor import ActionExecutor
from lagent.llms import HFTransformer
from lagent.actions import SerperSearch, PythonInterpreter

llm = HFTransformer('internlm/internlm-chat-7b')
search_tool = SerperSearch()  # 需要设置 SERPER_API_KEY 环境变量或传入 api_key
python_interpreter = PythonInterpreter()
chatbot = ReACT(
    llm=llm,
    action_executor=ActionExecutor(
        actions=[search_tool, python_interpreter]),
)
response = chatbot.chat('若$z=-1+\sqrt{3}i$,则$\frac{z}{{z\overline{z}-1}}=\left(\ \ \right)$ (A) $-1+\sqrt{3}i$ (B) $-1-\sqrt{3}i$ (C) $-\frac{1}{3}+\frac{{\sqrt{3}}}{3}i$ (D) $-\frac{1}{3}-\frac{{\sqrt{3}}}{3}i$')
print(response['response'])
>>> 根据已有的信息可以求得$z=-1+\\sqrt{3}i$然后代入计算得到结果为$-\\frac{1}{3}+\\frac{{\\sqrt{3}}}{3}i$。因此答案是C)。
```
## 开源许可证
该项目采用[Apache 2.0 开源许可证](LICENSE)。

0
docs/.gitkeep Normal file

23
docs/overview.md Normal file

@@ -0,0 +1,23 @@
# OVERVIEW
This chapter introduces you to the framework of LAgent, and provides links to detailed tutorials about LAgent.
## What is LAgent
LAgent is an open-source LLM agent framework that enables people to efficiently turn a large language model into an agent. It also provides some typical tools to augment the abilities of LLMs, and the whole framework is shown below:
![image](https://github.com/InternLM/lagent/assets/24351120/e104171e-4baf-43b3-8e6d-90cff1b298b6)
LAgent consists of three main parts: agents, llms, and actions (a minimal composition sketch follows the list below).
- **agents** provides agent implementation, such as ReAct, AutoGPT.
- **llms** supports various large language models, including open-source models (Llama-2, InternLM) loaded through HuggingFace, as well as closed-source models like GPT-3.5/4.
- **actions** contains a series of actions, as well as an action executor to manage all actions.
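A minimal sketch, not part of this commit, of how the three parts compose; it assumes the class and module names defined elsewhere in this commit (`ReACT`, `GPTAPI`, `ActionExecutor`, `PythonInterpreter`) and an OpenAI key configured as in the examples.

```python
from lagent.actions import ActionExecutor, PythonInterpreter  # actions
from lagent.agents import ReACT                                # agents
from lagent.llms.openai import GPTAPI                          # llms

llm = GPTAPI(model_type='gpt-3.5-turbo')  # set OPEN_API_KEY in your environment
agent = ReACT(
    llm=llm,
    action_executor=ActionExecutor(actions=[PythonInterpreter()]))
print(agent.chat('What is 123 * 456?').response)
```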
## How to Use
Here is a detailed step-by-step guide to learn more about LAgent:
1. For installation instructions, please see [get_started](get_started.md).
2. We provide several examples of building agents with LAgent in [examples](examples/); try them by simply running `python examples/react_example.py`.


@@ -0,0 +1,46 @@
from lagent.actions.action_executor import ActionExecutor
from lagent.actions.builtin_actions import FinishAction
from lagent.actions.python_interpreter import PythonInterpreter
from lagent.agents.autogpt import AutoGPT
from lagent.llms.openai import GPTAPI
def input_prompt():
print('\ndouble enter to end input >>> ', end='')
sentinel = '' # ends when this string is seen
return '\n'.join(iter(input, sentinel))
def main():
# set OPEN_API_KEY in your environment
model = GPTAPI(model_type='gpt-3.5-turbo')
chatbot = AutoGPT(
llm=model,
action_executor=ActionExecutor(
actions=[
PythonInterpreter(),
],
finish_action=FinishAction(
description=(
'Goals are accomplished and there is nothing left '
'to do. Parameter: {"response: "final response '
'for the goal"}')),
finish_in_action=True),
)
while True:
try:
prompt = input_prompt()
except UnicodeDecodeError:
print('UnicodeDecodeError')
continue
if prompt == 'exit':
exit(0)
agent_return = chatbot.chat(prompt)
print(agent_return.response)
if __name__ == '__main__':
main()

35
examples/chat.py Normal file

@@ -0,0 +1,35 @@
from argparse import ArgumentParser
from lagent.llms.openai import GPTAPI
def parse_args():
parser = ArgumentParser(description='chatbot')
parser.add_argument('--mode', default='chat')
args = parser.parse_args()
return args
def main():
args = parse_args()
model = GPTAPI(model_type='gpt-3.5-turbo', )
history = []
while True:
try:
prompt = input('>>> ')
except UnicodeDecodeError:
print('UnicodeDecodeError')
continue
if prompt == 'exit':
exit(0)
if args.mode == 'chat':
history.append(dict(role='user', content=prompt))
response = model.generate_from_template(history, max_out_len=512)
history.append(dict(role='assistant', content=response))
elif args.mode == 'generate':
response = model.generate(prompt, max_out_len=512)
print('Assistant:', response)
if __name__ == '__main__':
main()


@@ -0,0 +1,37 @@
from lagent.actions.action_executor import ActionExecutor
from lagent.actions.python_interpreter import PythonInterpreter
from lagent.agents.react import ReACT
from lagent.llms.huggingface import HFTransformer
model = HFTransformer(
path='internlm/internlm-chat-7b',
meta_template=[
dict(role='system', begin='<|System|>:', end='<TOKENS_UNUSED_2>\n'),
dict(role='user', begin='<|User|>:', end='<eoh>\n'),
dict(role='assistant', begin='<|Bot|>:', end='<eoa>\n', generate=True)
],
)
chatbot = ReACT(
llm=model,
action_executor=ActionExecutor(actions=[PythonInterpreter()]),
)
def input_prompt():
print('\ndouble enter to end input >>> ', end='')
sentinel = '' # ends when this string is seen
return '\n'.join(iter(input, sentinel))
while True:
try:
prompt = input_prompt()
except UnicodeDecodeError:
print('UnicodeDecodeError')
continue
if prompt == 'exit':
exit(0)
agent_return = chatbot.chat(prompt)
print(agent_return.response)

36
examples/react_example.py Normal file

@@ -0,0 +1,36 @@
from lagent.actions.action_executor import ActionExecutor
from lagent.actions.python_interpreter import PythonInterpreter
from lagent.agents.react import ReACT
from lagent.llms.openai import GPTAPI
def input_prompt():
print('\ndouble enter to end input >>> ', end='')
sentinel = '' # ends when this string is seen
return '\n'.join(iter(input, sentinel))
def main():
# set OPEN_API_KEY in your environment
model = GPTAPI(model_type='gpt-3.5-turbo', )
chatbot = ReACT(
llm=model,
action_executor=ActionExecutor(actions=[PythonInterpreter()]),
)
while True:
try:
prompt = input_prompt()
except UnicodeDecodeError:
print('UnicodeDecodeError')
continue
if prompt == 'exit':
exit(0)
agent_return = chatbot.chat(prompt)
print(agent_return.response)
if __name__ == '__main__':
main()

20
examples/rewoo_example.py Normal file

@@ -0,0 +1,20 @@
from lagent.actions.action_executor import ActionExecutor
from lagent.actions.llm_qa import LLMQA
from lagent.actions.serper_search import SerperSearch
from lagent.agents.rewoo import ReWOO
from lagent.llms.openai import GPTAPI
model = GPTAPI(model_type='gpt-3.5-turbo')
# please set the serper search API key
search_tool = SerperSearch(api_key=None)
llmqa_tool = LLMQA(model)
chatbot = ReWOO(
llm=model,
action_executor=ActionExecutor(actions=[llmqa_tool, search_tool]),
)
prompt = 'What profession does Nicholas Ray and Elia Kazan have in common'
agent_return = chatbot.chat(prompt)
print(agent_return.response)


@@ -0,0 +1,10 @@
from .action_executor import ActionExecutor
from .base_action import BaseAction
from .builtin_actions import FinishAction, InvalidAction, NoAction
from .python_interpreter import PythonInterpreter
from .serper_search import SerperSearch
__all__ = [
'BaseAction', 'ActionExecutor', 'InvalidAction', 'NoAction',
'FinishAction', 'SerperSearch', 'PythonInterpreter'
]


@@ -0,0 +1,84 @@
from typing import Any, Dict, List, Union
from lagent.schema import ActionReturn, ActionValidCode
from .base_action import BaseAction
from .builtin_actions import FinishAction, InvalidAction, NoAction
class ActionExecutor:
"""The action executor class.
Args:
actions (Union[BaseAction, List[BaseAction]]): The action or actions.
invalid_action (BaseAction, optional): The invalid action. Defaults to
InvalidAction().
no_action (BaseAction, optional): The no action.
Defaults to NoAction().
finish_action (BaseAction, optional): The finish action. Defaults to
FinishAction().
finish_in_action (bool, optional): Whether the finish action is in the
action list. Defaults to False.
"""
def __init__(self,
actions: Union[BaseAction, List[BaseAction]],
invalid_action: BaseAction = InvalidAction(),
no_action: BaseAction = NoAction(),
finish_action: BaseAction = FinishAction(),
finish_in_action: bool = False):
if isinstance(actions, BaseAction):
actions = [actions]
for action in actions:
assert isinstance(action, BaseAction), \
f'action must be BaseAction, but got {type(action)}'
if finish_in_action:
actions.append(finish_action)
self.actions = {action.name: action for action in actions}
self.invalid_action = invalid_action
self.no_action = no_action
self.finish_action = finish_action
def get_actions_info(self, only_enable: bool = True) -> Dict:
if only_enable:
return {
k: v.description
for k, v in self.actions.items() if v.enable
}
else:
return {k: v.description for k, v in self.actions.items()}
def is_valid(self, name: str):
return name in self.actions and self.actions[name].enable
def action_names(self, only_enable: bool = True):
if only_enable:
return [k for k, v in self.actions.items() if v.enable]
else:
return list(self.actions.keys())
def add_action(self, action: BaseAction):
assert isinstance(action, BaseAction), \
f'action must be BaseAction, but got {type(action)}'
self.actions[action.name] = action
def del_action(self, name: str):
if name in self.actions:
del self.actions[name]
def __call__(self, name: str, command: Any) -> ActionReturn:
if isinstance(command, str):
args, kwargs = (command, ), {}
else:
args, kwargs = (), command
if not self.is_valid(name):
if name == self.no_action.name:
action_return = self.no_action.run(*args, **kwargs)
elif name == self.finish_action.name:
action_return = self.finish_action.run(*args, **kwargs)
else:
action_return = self.invalid_action(*args, **kwargs)
else:
action_return = self.actions[name].run(*args, **kwargs)
action_return.valid = ActionValidCode.OPEN
return action_return
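A short usage sketch, not from the commit, of how the executor dispatches by action name; the `PythonInterpreter` action and its `solution()` convention come from the interpreter file later in this commit.

```python
from lagent.actions.action_executor import ActionExecutor
from lagent.actions.python_interpreter import PythonInterpreter

executor = ActionExecutor(actions=[PythonInterpreter()])
print(executor.action_names())      # enabled action names, e.g. ['PythonInterpreter']
print(executor.get_actions_info())  # {name: description} used to build prompts

# A plain string command is forwarded as the single positional argument of the
# named action's run(); unknown names fall back to the invalid action.
ret = executor('PythonInterpreter', 'def solution():\n    return 1 + 1')
print(ret.state, ret.result)        # ActionStatusCode.SUCCESS, {'text': '2'}
```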


@@ -0,0 +1,57 @@
from typing import Optional
from lagent.schema import ActionReturn
class BaseAction:
"""Base class for all actions.
Args:
description (str, optional): The description of the action. Defaults to
None.
name (str, optional): The name of the action. If None, the name will
be the class name. Defaults to None.
enable (bool, optional): Whether the action is enabled. Defaults to
True.
disable_description (str, optional): The description of the action when
it is disabled. Defaults to None.
"""
def __init__(self,
description: Optional[str] = None,
name: Optional[str] = None,
enable: bool = True,
disable_description: Optional[str] = None) -> None:
if name is None:
name = self.__class__.__name__
self._name = name
self._description = description
self._disable_description = disable_description
self._enable = enable
def __call__(self, *args, **kwargs) -> ActionReturn:
raise NotImplementedError
def __repr__(self):
return f'{self.name}:{self.description}'
def __str__(self):
return self.__repr__()
def run(self, *args, **kwargs) -> ActionReturn:
return self.__call__(*args, **kwargs)
@property
def enable(self):
return self._enable
@property
def name(self):
return self._name
@property
def description(self):
if self.enable:
return self._description
else:
return self._disable_description
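A hedged sketch, not in this commit, of a custom action: a subclass only needs to implement `__call__` and return an `ActionReturn`, whose fields follow the usage in the other action files of this commit.

```python
from lagent.actions.base_action import BaseAction
from lagent.schema import ActionReturn, ActionStatusCode


class EchoAction(BaseAction):
    """Toy action that returns its input text unchanged."""

    def __call__(self, text: str) -> ActionReturn:
        return ActionReturn(
            url=None,
            args=dict(text=text),
            result=dict(text=text),
            type=self.name,
            state=ActionStatusCode.SUCCESS)


echo = EchoAction(description='Echo the given text back.')
print(echo.name)                 # 'EchoAction' (defaults to the class name)
print(echo.run('hello').result)  # {'text': 'hello'}
```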


@@ -0,0 +1,100 @@
from typing import Optional
from lagent.actions.base_action import BaseAction
from lagent.schema import ActionReturn, ActionStatusCode, ActionValidCode
class InvalidAction(BaseAction):
"""This is a invalid action class, which is used to return error message
when the action is invalid.
Args:
err_msg (str): The error message. Defaults to 'The action is invalid,
please check the action name'.
Returns:
ActionReturn: The action return.
"""
def __init__(self,
err_msg:
str = 'The action is invalid, please check the action name.',
**kwargs) -> None:
super().__init__(enable=False, **kwargs)
self._err_msg = err_msg
def __call__(self, err_msg: Optional[str] = None):
"""Return the error message.
Args:
err_msg (str, optional): The error message. If err_msg is not None,
it will be returned, otherwise the default error message will
be returned. Defaults to None.
"""
action_return = ActionReturn(
url=None,
args=dict(text=err_msg),
errmsg=err_msg if err_msg else self._err_msg,
type=self.name,
valid=ActionValidCode.INVALID,
state=ActionStatusCode.API_ERROR)
return action_return
class NoAction(BaseAction):
"""This is a no action class, which is used to return error message when
the response does not follow the format.
Args:
err_msg (str): The error message. Defaults to
'Please follow the format'.
"""
def __init__(self, err_msg: str = 'Please follow the format', **kwargs):
super().__init__(enable=False, **kwargs)
self._err_msg = err_msg
def __call__(self, err_msg: Optional[str] = None):
"""Return the error message.
Args:
err_msg (str, optional): The error message. If err_msg is not None,
it will be returned, otherwise the default error message will
be returned. Defaults to None.
Returns:
ActionReturn: The action return.
"""
action_return = ActionReturn(
url=None,
args=dict(text=err_msg),
type=self.name,
errmsg=err_msg if err_msg else self._err_msg,
valid=ActionValidCode.INVALID,
state=ActionStatusCode.API_ERROR)
return action_return
class FinishAction(BaseAction):
"""This is a finish action class, which is used to return the final
result."""
def __call__(self, response: str) -> ActionReturn:
"""Return the final result.
Args:
response (str): The final result.
Returns:
ActionReturn: The action return.
"""
action_return = ActionReturn(
url=None,
args=dict(text=response),
result=dict(text=response),
type=self.name,
valid=ActionValidCode.FINISH,
state=ActionStatusCode.SUCCESS)
return action_return

56
lagent/actions/llm_qa.py Normal file

@@ -0,0 +1,56 @@
from typing import Optional, Union
from lagent.llms.base_api import BaseAPIModel
from lagent.llms.base_llm import BaseModel
from lagent.schema import ActionReturn, ActionStatusCode
from .base_action import BaseAction
DEFAULT_DESCRIPTION = """一个像你一样的大语言预训练模型,当你需要获得一些常识或简单世界知识时可以问它。
当你很有把握自己直接解决问题时可以优先使用它。输入应该是一个询问语句, 且每个问题尽可能简单。
"""
class LLMQA(BaseAction):
"""An LLM Wrapper as BaseAction type.
Args:
llm (BaseModel or BaseAPIModel): an LLM service which can chat.
description (str): The description of the action. Defaults to
None.
name (str, optional): The name of the action. If None, the name will
be the class name. Defaults to None.
enable (bool, optional): Whether the action is enabled. Defaults to
True.
disable_description (str, optional): The description of the action when
it is disabled. Defaults to None.
"""
def __init__(self,
llm: Union[BaseModel, BaseAPIModel],
description: str = DEFAULT_DESCRIPTION,
name: Optional[str] = None,
enable: bool = True,
disable_description: Optional[str] = None) -> None:
super().__init__(description, name, enable, disable_description)
self._llm = llm
def __call__(self, query: str) -> ActionReturn:
"""Return the QA response.
Args:
query (str): The query content.
Returns:
ActionReturn: The action return.
"""
tool_return = ActionReturn(url=None, args=None)
try:
response = self._llm.generate_from_template(query, 512)
tool_return.result = dict(text=str(response))
tool_return.state = ActionStatusCode.SUCCESS
except Exception as e:
tool_return.result = dict(text=str(e))
tool_return.state = ActionStatusCode.API_ERROR
return tool_return
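A brief usage sketch, not part of the commit, mirroring how `examples/rewoo_example.py` above wires `LLMQA` to the GPT backend; the question is only illustrative.

```python
from lagent.actions.llm_qa import LLMQA
from lagent.llms.openai import GPTAPI
from lagent.schema import ActionStatusCode

model = GPTAPI(model_type='gpt-3.5-turbo')  # set OPEN_API_KEY in your environment
qa = LLMQA(model)

ret = qa('What is the capital of France?')  # keep each question short and simple
if ret.state == ActionStatusCode.SUCCESS:
    print(ret.result['text'])
else:
    print(ret.result)  # on failure the result text holds the exception message
```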


@@ -0,0 +1,133 @@
import copy
import io
from contextlib import redirect_stdout
from typing import Any, Optional
from func_timeout import FunctionTimedOut, func_set_timeout
from lagent.actions.base_action import BaseAction
from lagent.schema import ActionReturn, ActionStatusCode
class GenericRuntime:
GLOBAL_DICT = {}
LOCAL_DICT = None
HEADERS = []
def __init__(self):
self._global_vars = copy.copy(self.GLOBAL_DICT)
self._local_vars = copy.copy(
self.LOCAL_DICT) if self.LOCAL_DICT else None
for c in self.HEADERS:
self.exec_code(c)
def exec_code(self, code_piece: str) -> None:
exec(code_piece, self._global_vars)
def eval_code(self, expr: str) -> Any:
return eval(expr, self._global_vars)
DEFAULT_DESCRIPTION = """用来执行Python代码。代码必须是一个函数
函数名必须得是 'solution',代码对应你的思考过程。代码实例格式如下:
```python
# import 依赖包
import xxx
def solution():
# 初始化一些变量
variable_names_with_real_meaning = xxx
# 步骤一
mid_variable = func(variable_names_with_real_meaning)
# 步骤 x
mid_variable = func(mid_variable)
# 最后结果
final_answer = func(mid_variable)
return final_answer
```"""
class PythonInterpreter(BaseAction):
"""A Python executor that can execute Python scripts.
Args:
description (str): The description of the action. Defaults to
DEFAULT_DESCRIPTION.
answer_symbol (str, Optional): the answer symbol from LLM
answer_expr (str, Optional): the answer function name of the Python
script. Default to 'solution()'.
answer_from_stdout (boolean): whether the execution result is from
stdout.
name (str, optional): The name of the action. If None, the name will
be the class name. Defaults to None.
enable (bool, optional): Whether the action is enabled. Defaults to
True.
disable_description (str, optional): The description of the action when
it is disabled. Defaults to None.
timeout (int): Upper bound of waiting time for Python script execution.
"""
def __init__(self,
description: str = DEFAULT_DESCRIPTION,
answer_symbol: Optional[str] = None,
answer_expr: Optional[str] = 'solution()',
answer_from_stdout: bool = False,
name: Optional[str] = None,
enable: bool = True,
disable_description: Optional[str] = None,
timeout: int = 20) -> None:
super().__init__(description, name, enable, disable_description)
self.answer_symbol = answer_symbol
self.answer_expr = answer_expr
self.answer_from_stdout = answer_from_stdout
self.timeout = timeout
def __call__(self, command: str) -> ActionReturn:
self.runtime = GenericRuntime()
try:
tool_return = func_set_timeout(self.timeout)(self._call)(command)
except FunctionTimedOut as e:
tool_return = ActionReturn(url=None, args=None, type=self.name)
tool_return.errmsg = repr(e)
tool_return.state = ActionStatusCode.API_ERROR
return tool_return
def _call(self, command: str) -> ActionReturn:
tool_return = ActionReturn(url=None, args=None, type=self.name)
try:
if '```python' in command:
command = command.split('```python')[1].split('```')[0]
elif '```' in command:
command = command.split('```')[1].split('```')[0]
tool_return.args = dict(text='```python\n' + command + '\n```')
command = command.split('\n')
if self.answer_from_stdout:
program_io = io.StringIO()
with redirect_stdout(program_io):
self.runtime.exec_code('\n'.join(command))
program_io.seek(0)
res = program_io.readlines()[-1]
elif self.answer_symbol:
self.runtime.exec_code('\n'.join(command))
res = self.runtime._global_vars[self.answer_symbol]
elif self.answer_expr:
self.runtime.exec_code('\n'.join(command))
res = self.runtime.eval_code(self.answer_expr)
else:
self.runtime.exec_code('\n'.join(command[:-1]))
res = self.runtime.eval_code(command[-1])
except Exception as e:
tool_return.errmsg = repr(e)
tool_return.type = self.name
tool_return.state = ActionStatusCode.API_ERROR
return tool_return
try:
tool_return.result = dict(text=str(res))
tool_return.state = ActionStatusCode.SUCCESS
except Exception as e:
tool_return.errmsg = repr(e)
tool_return.type = self.name
tool_return.state = ActionStatusCode.API_ERROR
return tool_return
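A small sketch, not from the commit, of the call format that `DEFAULT_DESCRIPTION` above asks for: code defining `solution()`, evaluated through the default `answer_expr`; a fenced code block around the code would be stripped the same way.

```python
from lagent.actions.python_interpreter import PythonInterpreter

interpreter = PythonInterpreter()  # answer_expr defaults to 'solution()'
code = '\n'.join([
    'def solution():',
    '    # sum of the integers from 1 to 100',
    '    return sum(range(1, 101))',
])

ret = interpreter(code)
print(ret.state)           # ActionStatusCode.SUCCESS
print(ret.result['text'])  # '5050'
```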


@@ -0,0 +1,175 @@
import os
from typing import List, Optional, Tuple, Union
import requests
from lagent.schema import ActionReturn, ActionStatusCode
from .base_action import BaseAction
DEFAULT_DESCRIPTION = """一个可以从谷歌搜索结果的API。
当你需要对于一个特定问题找到简短明了的回答时,可以使用它。
输入应该是一个搜索查询。
"""
class SerperSearch(BaseAction):
"""Wrapper around the Serper.dev Google Search API.
To use, you should pass your Serper API key to the constructor.
Code is modified from the LangChain GoogleSerperAPIWrapper
(https://github.com/langchain-ai/langchain/blob/ba5f
baba704a2d729a4b8f568ed70d7c53e799bb/libs/langchain/
langchain/utilities/google_serper.py)
Args:
api_key (str): API KEY to use serper google search API,
You can create a free API key at https://serper.dev.
timeout (int): Upper bound of waiting time for a Serper request.
search_type (str): The Serper API supports ['search', 'images', 'news',
'places'] search types; currently we only support 'search'.
k (int): select the first k results in the search results as the response.
description (str): The description of the action. Defaults to
None.
name (str, optional): The name of the action. If None, the name will
be the class name. Defaults to None.
enable (bool, optional): Whether the action is enabled. Defaults to
True.
disable_description (str, optional): The description of the action when
it is disabled. Defaults to None.
"""
result_key_for_type = {
'news': 'news',
'places': 'places',
'images': 'images',
'search': 'organic',
}
def __init__(self,
api_key: Optional[str] = None,
timeout: int = 5,
search_type: str = 'search',
k: int = 10,
description: str = DEFAULT_DESCRIPTION,
name: Optional[str] = None,
enable: bool = True,
disable_description: Optional[str] = None) -> None:
super().__init__(description, name, enable, disable_description)
api_key = os.environ.get('SERPER_API_KEY', api_key)
if api_key is None:
raise ValueError(
'Please set the Serper API key either in the environment '
'as SERPER_API_KEY or pass it as the `api_key` parameter.')
self.api_key = api_key
self.timeout = timeout
self.search_type = search_type
self.k = k
def __call__(self, query: str) -> ActionReturn:
"""Return the search response.
Args:
query (str): The search content.
Returns:
ActionReturn: The action return.
"""
tool_return = ActionReturn(url=None, args=None)
status_code, response = self._search(
query, search_type=self.search_type, k=self.k)
# convert search results to ToolReturn format
if status_code == -1:
tool_return.errmsg = response
tool_return.state = ActionStatusCode.HTTP_ERROR
elif status_code == 200:
parsed_res = self._parse_results(response)
tool_return.result = dict(text=str(parsed_res))
tool_return.state = ActionStatusCode.SUCCESS
else:
tool_return.errmsg = str(status_code)
tool_return.state = ActionStatusCode.API_ERROR
return tool_return
def _parse_results(self, results: dict) -> Union[str, List[str]]:
"""Parse the search results from Serper API.
Args:
results (dict): The search content from Serper API
in json format.
Returns:
List[str]: The parsed search results.
"""
snippets = []
if results.get('answerBox'):
answer_box = results.get('answerBox', {})
if answer_box.get('answer'):
return [answer_box.get('answer')]
elif answer_box.get('snippet'):
return [answer_box.get('snippet').replace('\n', ' ')]
elif answer_box.get('snippetHighlighted'):
return answer_box.get('snippetHighlighted')
if results.get('knowledgeGraph'):
kg = results.get('knowledgeGraph', {})
title = kg.get('title')
entity_type = kg.get('type')
if entity_type:
snippets.append(f'{title}: {entity_type}.')
description = kg.get('description')
if description:
snippets.append(description)
for attribute, value in kg.get('attributes', {}).items():
snippets.append(f'{title} {attribute}: {value}.')
for result in results[self.result_key_for_type[
self.search_type]][:self.k]:
if 'snippet' in result:
snippets.append(result['snippet'])
for attribute, value in result.get('attributes', {}).items():
snippets.append(f'{attribute}: {value}.')
if len(snippets) == 0:
return ['No good Google Search Result was found']
return snippets
def _search(self,
search_term: str,
search_type: str = 'search',
**kwargs) -> Tuple[int, Union[dict, str]]:
"""HTTP requests to Serper API.
Args:
search_term (str): The search query.
search_type (str): search type supported by Serper API,
default to 'search'.
Returns:
tuple: the return value is a tuple that contains:
- status_code (int): HTTP status code from Serper API.
- response (dict): response context with json format.
"""
headers = {
'X-API-KEY': self.api_key or '',
'Content-Type': 'application/json',
}
params = {
'q': search_term,
**{
key: value
for key, value in kwargs.items() if value is not None
},
}
try:
response = requests.post(
f'https://google.serper.dev/{search_type}',
headers=headers,
params=params,
timeout=self.timeout)
except Exception as e:
return -1, str(e)
return response.status_code, response.json()
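A hedged usage sketch, not part of the commit; as the constructor above shows, the key can come from the `SERPER_API_KEY` environment variable or the `api_key` argument, and the key value here is a placeholder.

```python
from lagent.actions.serper_search import SerperSearch
from lagent.schema import ActionStatusCode

search = SerperSearch(api_key='your-serper-key', k=3)  # placeholder key
ret = search('InternLM large language model')
if ret.state == ActionStatusCode.SUCCESS:
    print(ret.result['text'])  # stringified list of result snippets
else:
    print(ret.errmsg)
```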


@@ -0,0 +1,6 @@
from .autogpt import AutoGPT
from .base_agent import BaseAgent
from .react import ReACT
from .rewoo import ReWOO
__all__ = ['BaseAgent', 'ReACT', 'AutoGPT', 'ReWOO']

288
lagent/agents/autogpt.py Normal file

@@ -0,0 +1,288 @@
# flake8: noqa
import ast
import platform
from typing import Dict, List, Optional, Tuple, Union
from jsonschema import Draft7Validator
from lagent.actions import ActionExecutor
from lagent.llms.base_api import BaseAPIModel
from lagent.llms.base_llm import BaseModel
from lagent.schema import ActionReturn, ActionStatusCode, AgentReturn
from .base_agent import BaseAgent
DEFAULT_TRIGGERING_PROMPT = ('Determine exactly one command to use based on '
'the given goals and the progress you have made '
'so far, and respond using the JSON schema '
'specified previously:')
DEFAULT_PREFIX = """You are {ai_name}, {role_description}. Your decisions must always be made independently without seeking user assistance. Play to your strengths as an LLM and pursue simple strategies with no legal complications.
The OS you are running on is: {os_info}
## Constraints
You operate within the following constraints:
1. ~4000 word limit for short term memory. Your short term memory is short, so immediately save important information to files.
2. If you are unsure how you previously did something or want to recall past events, thinking about similar events will help you remember.
3. No user assistance
4. Exclusively use the commands listed below e.g. command_name
## Commands
You have access to the following commands:
{tool_description}
## Resources
You can leverage access to the following resources:
1. Internet access for searches and information gathering.
2. Long Term memory management.
3. File output.
4. Command execution
## Best practices
1. Continuously review and analyze your actions to ensure you are performing to the best of your abilities.
2. Constructively self-criticize your big-picture behavior constantly.
3. Reflect on past decisions and strategies to refine your approach.
4. Every command has a cost, so be smart and efficient. Aim to complete tasks in the least number of steps.
## Goals
For your task, you must fulfill the following goals:
{ai_goals}
"""
DEFAULT_CALL_PROTOCOL = """Respond strictly with JSON. The JSON should be compatible with the TypeScript type `Response` from the following:
```ts
interface Response {
thoughts: {
// Thoughts
text: string;
reasoning: string;
// Short markdown-style bullet list that conveys the long-term plan
plan: string;
// Constructive self-criticism
criticism: string;
// Summary of thoughts to say to the user
speak: string;
};
command: {
name: string;
args: Record<string, any>;
};
}
```
"""
DEFAULT_SCHEMA = {
'$schema': 'http://json-schema.org/draft-07/schema#',
'type': 'object',
'properties': {
'thoughts': {
'type': 'object',
'properties': {
'text': {
'type': 'string',
'description': 'thoughts'
},
'reasoning': {
'type': 'string'
},
'plan': {
'type':
'string',
'description':
'- short bulleted\n- list that conveys\n- long-term plan'
},
'criticism': {
'type': 'string',
'description': 'constructive self-criticism'
},
'speak': {
'type': 'string',
'description': 'thoughts summary to say to user'
}
},
'required': ['text', 'reasoning', 'plan', 'criticism', 'speak'],
'additionalProperties': False
},
'command': {
'type': 'object',
'properties': {
'name': {
'type': 'string'
},
'args': {
'type': 'object'
}
},
'required': ['name', 'args'],
'additionalProperties': False
}
},
'required': ['thoughts', 'command'],
'additionalProperties': False
}
class AutoGPTProtocol:
"""A wrapper of AutoGPT prompt which manages the response from LLM and
generate desired prompts in a AutoGPT format.
Args:
ai_name (str): the name of the agent, default to 'AutoGPT'
role_description (str): description of the role, e.g., System, User
prefix (str): the prefix prompt for AutoGPT
call_protocol (str): the request prompt which defines the protocol
of return format from LLM.
valid_schema (dict): defines the schema of the return format.
triggering_prompt (str): the predefined trigger prompt.
"""
def __init__(self,
ai_name: Optional[str] = 'AutoGPT',
role_description: Optional[str] = '',
prefix: str = DEFAULT_PREFIX,
call_protocol: str = DEFAULT_CALL_PROTOCOL,
valid_schema: dict = DEFAULT_SCHEMA,
triggering_prompt: str = DEFAULT_TRIGGERING_PROMPT) -> None:
self.ai_name = ai_name
self.role_description = role_description
self.prefix = prefix
self.call_protocol = call_protocol
self.valid_schema = valid_schema
self.triggering_prompt = triggering_prompt
def parse(self, response: str,
action_executor: ActionExecutor) -> Tuple[str, str]:
"""Parse the action returns in a AutoGPT format.
Args:
response (str): The response from LLM with AutoGPT format.
action_executor (ActionExecutor): Action executor to
provide no_action/finish_action name.
Returns:
tuple: the return value is a tuple that contains:
- action (str): the extracted action name.
- action_input (str): the corresponding action input.
"""
try:
if response.startswith('```') and response.endswith('```'):
# Discard the first and last ```, then re-join in case the response naturally included ```
response = '```'.join(response.split('```')[1:-1])
response = ast.literal_eval(response)
validator = Draft7Validator(self.valid_schema)
valid = True
if errors := sorted(
validator.iter_errors(response), key=lambda e: e.path):
valid = False
if not valid:
return action_executor.no_action, 'Validation of response failed:\n ' + ';\n '.join(
[str(e) for e in errors])
try:
if 'command' not in response:
return action_executor.no_action, "Missing 'command' object in JSON"
if not isinstance(response, dict):
return action_executor.no_action, f'The previous message sent was not a dictionary {response}'
command = response['command']
if not isinstance(command, dict):
return action_executor.no_action, "'command' object is not a dictionary"
if 'name' not in command:
return action_executor.no_action, "Missing 'name' field in 'command' object"
command_name = command['name']
# Use an empty dictionary if 'args' field is not present in 'command' object
arguments = command.get('args', {})
return command_name, arguments
except Exception as e:
return action_executor.no_action, repr(e)
except SyntaxError as e:
return action_executor.no_action, f'Your response could not be parsed: {repr(e)} \nRemember to only respond using the specified format above!'
def format(self, goal: str, inner_history: List[Dict],
action_executor: ActionExecutor) -> List[Dict]:
"""Generate the AutoGPT format prompt.
Args:
goal (str): The user request.
inner_history (List[Dict]): The log in the current run.
action_executor (ActionExecutor): the action manager to
execute actions.
Returns:
List[Dict]: AutoGPT format prompt.
"""
import distro
formatted_data = []
os_name = platform.system()
os_info = (
platform.platform(terse=True)
if os_name != 'Linux' else distro.name(pretty=True))
prefix = self.prefix.format(
ai_name=self.ai_name,
role_description=self.role_description,
tool_description=action_executor.get_actions_info(),
ai_goals=goal,
os_info=os_info,
)
formatted_data.append(dict(role='system', content=prefix))
formatted_data.append(dict(role='system', content=self.call_protocol))
formatted_data += inner_history
formatted_data.append(
dict(role='user', content=self.triggering_prompt))
return formatted_data
def format_response(self, action_return):
"""format the final response at current step.
Args:
action_return (ActionReturn): return value of the current action.
Returns:
str: the final response at current step.
"""
if action_return.state == ActionStatusCode.SUCCESS:
response = action_return.result['text']
response = f'Command {action_return.type} returned: {response}'
else:
response = action_return.errmsg
return response
class AutoGPT(BaseAgent):
"""An implementation of AutoGPT (https://github.com/Significant-
Gravitas/Auto-GPT)
Args:
llm (BaseModel or BaseAPIModel): an LLM service which can chat
and act as the backend.
action_executor (ActionExecutor): an action executor to manage
all actions and their response.
protocol (AutoGPTProtocol): a wrapper to generate prompts and
parse the responses from the LLM / actions.
max_turn (int): the maximum number of trials for the LLM to generate
plans that can be successfully parsed by the AutoGPT protocol.
"""
def __init__(self,
llm: Union[BaseModel, BaseAPIModel],
action_executor: ActionExecutor,
protocol: AutoGPTProtocol = AutoGPTProtocol(),
max_turn: int = 2):
self.max_turn = max_turn
super().__init__(
llm=llm, action_executor=action_executor, protocol=protocol)
def chat(self, goal: str) -> AgentReturn:
self._inner_history = []
agent_return = AgentReturn()
default_response = '对不起,我无法回答你的问题'
for _ in range(self.max_turn):
prompt = self._protocol.format(
goal=goal,
inner_history=self._inner_history,
action_executor=self._action_executor)
response = self._llm.generate_from_template(prompt, 512)
self._inner_history.append(
dict(role='assistant', content=response))
action, action_input = self._protocol.parse(
response, self._action_executor)
action_return: ActionReturn = self._action_executor(
action, action_input)
agent_return.actions.append(action_return)
if action_return.type == self._action_executor.finish_action.name:
agent_return.response = action_return.result['text']
return agent_return
self._inner_history.append(
dict(
role='system',
content=self._protocol.format_response(action_return)))
agent_return.response = default_response
return agent_return
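A minimal sketch, not in the commit, of a reply that satisfies `DEFAULT_SCHEMA` above; the command name and arguments are invented for illustration, and the validation mirrors the `Draft7Validator` check inside `AutoGPTProtocol.parse`.

```python
from jsonschema import Draft7Validator

from lagent.agents.autogpt import DEFAULT_SCHEMA

response = {
    'thoughts': {
        'text': 'I should look the fact up on the web.',
        'reasoning': 'A web search is the fastest way to get it.',
        'plan': '- search the web\n- summarise the result',
        'criticism': 'Keep the query short.',
        'speak': 'Searching the web for the answer.',
    },
    'command': {
        'name': 'SerperSearch',
        'args': {'query': 'height of Mount Everest'},
    },
}

# An empty error list means the reply can be parsed into (command name, args).
assert not list(Draft7Validator(DEFAULT_SCHEMA).iter_errors(response))
```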


@@ -0,0 +1,49 @@
from typing import List
from lagent.actions import ActionExecutor
from lagent.actions.base_action import BaseAction
from lagent.llms.base_llm import BaseModel
from lagent.schema import AgentReturn
class BaseAgent:
"""BaseAgent is the base class of all agents.
Args:
llm (BaseModel): the language model.
action_executor (ActionExecutor): the action executor.
protocol (object): the protocol of the agent, which is used to
generate the prompt of the agent and parse the response from
the llm.
"""
def __init__(self, llm: BaseModel, action_executor: ActionExecutor,
protocol: object) -> None:
self._session_history = []
self._llm = llm
self._action_executor = action_executor
self._protocol = protocol
def add_action(self, action: BaseAction) -> None:
"""Add an action to the action executor.
Args:
action (BaseAction): the action to be added.
"""
self._action_executor.add_action(action)
def del_action(self, name: str) -> None:
"""Delete an action from the action executor.
Args:
name (str): the name of the action to be deleted.
"""
self._action_executor.del_action(name)
def chat(self, message: str) -> AgentReturn:
raise NotImplementedError
@property
def session_history(self) -> List:
return self._session_history
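A hedged sketch, not part of the commit, of the smallest agent one could build on `BaseAgent`: one LLM call followed by one hard-coded action call; the prompt format and the `PythonInterpreter` action name are assumptions based on the other files in this commit.

```python
from lagent.actions import ActionExecutor, PythonInterpreter
from lagent.agents.base_agent import BaseAgent
from lagent.llms.openai import GPTAPI
from lagent.schema import AgentReturn


class OneShotAgent(BaseAgent):
    """Ask the LLM once, run its reply through PythonInterpreter once."""

    def chat(self, message: str) -> AgentReturn:
        agent_return = AgentReturn()
        prompt = [dict(role='user', content=message)]
        code = self._llm.generate_from_template(prompt, 512)
        action_return = self._action_executor('PythonInterpreter', code)
        agent_return.actions.append(action_return)
        agent_return.response = str(action_return.result)
        return agent_return


agent = OneShotAgent(
    llm=GPTAPI(model_type='gpt-3.5-turbo'),
    action_executor=ActionExecutor(actions=[PythonInterpreter()]),
    protocol=None)  # this toy agent formats its own prompt
```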

217
lagent/agents/react.py Normal file

@@ -0,0 +1,217 @@
from typing import Dict, List, Tuple, Union
from lagent.actions import ActionExecutor
from lagent.llms.base_api import BaseAPIModel
from lagent.llms.base_llm import BaseModel
from lagent.schema import ActionReturn, ActionStatusCode, AgentReturn
from .base_agent import BaseAgent
CALL_PROTOCOL = """你是一个可以调用外部工具的助手,可以使用的工具包括:
{tool_description}
如果使用工具请遵循以下格式回复:
```
{thought}思考你当前步骤需要解决什么问题,是否需要使用工具
{action}工具名称,你的工具必须从 [{action_names}] 选择
{action_input}工具输入参数
```
工具返回按照以下格式回复:
```
{response}调用工具后的结果
```
如果你已经知道了答案,或者你不需要工具,请遵循以下格式回复
```
{thought}给出最终答案的思考过程
{finish}最终答案
```
开始!"""
class ReACTProtocol:
"""A wrapper of ReACT prompt which manages the response from LLM and
generate desired prompts in a ReACT format.
Args:
thought (dict): the information of thought pattern
action (dict): the information of action pattern
action_input (dict): the information of action_input pattern
response (dict): the information of response pattern
finish (dict): the information of finish pattern
call_protocol (str): the format of ReACT
force_stop (str): the prompt to force LLM to generate response
"""
def __init__(self,
thought: dict = dict(
role='THOUGHT',
begin='Thought:',
end='\n',
belong='assistant'),
action: dict = dict(role='ACTION', begin='Action:', end='\n'),
action_input: dict = dict(
role='ARGS', begin='ActionInput:', end='\n'),
response: dict = dict(
role='RESPONSE', begin='Response:', end='\n'),
finish: dict = dict(
role='FINISH', begin='FinalAnswer:', end='\n'),
call_protocol: str = CALL_PROTOCOL,
force_stop: str = '你需要基于历史消息返回一个最终结果') -> None:
self.call_protocol = call_protocol
self.force_stop = force_stop
self.thought = thought
self.action = action
self.action_input = action_input
self.response = response
self.finish = finish
def format(self,
chat_history: List[Dict],
inner_step: List[Dict],
action_executor: ActionExecutor,
force_stop: bool = False) -> list:
"""Generate the ReACT format prompt.
Args:
chat_history (List[Dict]): The history log in previous runs.
inner_step (List[Dict]): The log in the current run.
action_executor (ActionExecutor): the action manager to
execute actions.
force_stop (boolean): whether to force the agent to give responses
under pre-defined turns.
Returns:
List[Dict]: ReACT format prompt.
"""
call_protocol = self.call_protocol.format(
tool_description=action_executor.get_actions_info(),
action_names=action_executor.action_names(),
thought=self.thought['begin'],
action=self.action['begin'],
action_input=self.action_input['begin'],
response=self.response['begin'],
finish=self.finish['begin'],
)
formatted = []
formatted.append(dict(role='system', content=call_protocol))
formatted += chat_history
formatted += inner_step
if force_stop:
formatted.append(dict(role='system', content=self.force_stop))
return formatted
def parse(
self,
message: str,
action_executor: ActionExecutor,
) -> Tuple[str, str, str]:
"""Parse the action returns in a ReACT format.
Args:
message (str): The response from LLM with ReACT format.
action_executor (ActionExecutor): Action executor to
provide no_action/finish_action name.
Returns:
tuple: the return value is a tuple that contains:
- thought (str): contain LLM thought of the current step.
- action (str): contain action scheduled by LLM.
- action_input (str): contain the required action input
for current action.
"""
import re
thought = message.split(self.action['begin'])[0]
thought = thought.split(self.thought['begin'])[-1]
thought = thought.split(self.finish['begin'])[0]
if self.finish['begin'] in message:
final_answer = message.split(self.finish['begin'])[-1]
return thought, action_executor.finish_action.name, final_answer
action_regex = f"{self.action['begin']}(.*?)\n"
args_regex = f"{self.action_input['begin']}(.*)"
action_match = re.findall(action_regex, message)
if not action_match:
return thought, action_executor.no_action.name, ''
action = action_match[-1]
arg_match = re.findall(args_regex, message, re.DOTALL)
if not arg_match:
return thought, action_executor.no_action.name, ''
action_input = arg_match[-1]
return thought, action.strip(), action_input.strip().strip('"')
def format_response(self, action_return: ActionReturn) -> str:
"""format the final response at current step.
Args:
action_return (ActionReturn): return value of the current action.
Returns:
str: the final response at current step.
"""
if action_return.state == ActionStatusCode.SUCCESS:
response = action_return.result['text']
else:
response = action_return.errmsg
return self.response['begin'] + response + self.response['end']
class ReACT(BaseAgent):
"""An implementation of ReACT (https://arxiv.org/abs/2210.03629)
Args:
        llm (BaseModel or BaseAPIModel): an LLM service which can chat
and act as backend.
action_executor (ActionExecutor): an action executor to manage
all actions and their response.
        protocol (ReACTProtocol): a wrapper to generate prompt and
parse the response from LLM / actions.
        max_turn (int): the maximum number of turns the agent is allowed to
            take before it is forced to give a final answer.
"""
def __init__(self,
llm: Union[BaseModel, BaseAPIModel],
action_executor: ActionExecutor,
protocol: ReACTProtocol = ReACTProtocol(),
max_turn: int = 2) -> None:
self.max_turn = max_turn
super().__init__(
llm=llm, action_executor=action_executor, protocol=protocol)
def chat(self, message: str) -> AgentReturn:
self._inner_history = []
self._inner_history.append(dict(role='user', content=message))
agent_return = AgentReturn()
force_stop = False
default_response = '对不起,我无法回答你的问题'
for turn in range(self.max_turn):
prompt = self._protocol.format(
chat_history=self.session_history,
inner_step=self._inner_history,
action_executor=self._action_executor,
force_stop=force_stop)
response = self._llm.generate_from_template(prompt, 512)
self._inner_history.append(
dict(role='assistant', content=response))
thought, action, action_input = self._protocol.parse(
response, self._action_executor)
action_return: ActionReturn = self._action_executor(
action, action_input)
action_return.thought = thought
agent_return.actions.append(action_return)
if action_return.type == self._action_executor.finish_action.name:
agent_return.response = action_return.result['text']
return agent_return
self._inner_history.append(
dict(
role='system',
content=self._protocol.format_response(action_return)))
if turn == self.max_turn - 1:
force_stop = True
agent_return.response = default_response
# only append the user and final response
self._session_history.append(dict(role='user', content=message))
self._session_history.append(
dict(role='assistant', content=agent_return.response))
return agent_return
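As a quick illustration of how the pieces above fit together, here is a minimal usage sketch (not part of this commit); it assumes the ReACT class is importable from lagent.agents.react and that an OpenAI key is exported as $OPENAI_API_KEY.

from lagent.actions import ActionExecutor
from lagent.actions.python_interpreter import PythonInterpreter
from lagent.agents.react import ReACT
from lagent.llms.openai import GPTAPI

# Build the agent from the classes defined in this commit; max_turn=4 is an
# arbitrary bound on the number of think/act cycles before force_stop.
llm = GPTAPI(model_type='gpt-3.5-turbo')
agent = ReACT(
    llm=llm,
    action_executor=ActionExecutor(actions=[PythonInterpreter()]),
    max_turn=4)
print(agent.chat('用 Python 计算 3 的 10 次方').response)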

232
lagent/agents/rewoo.py Normal file
View File

@@ -0,0 +1,232 @@
import re
import warnings
from typing import Dict, List, Optional, Tuple, Union
from lagent.actions import ActionExecutor
from lagent.llms.base_api import BaseAPIModel
from lagent.llms.base_llm import BaseModel
from lagent.schema import ActionReturn, ActionStatusCode, AgentReturn
from .base_agent import BaseAgent
PLANER_PROMPT = """你是一个任务分解器, 你需要将用户的问题拆分成多个简单的子任务。
请拆分出多个子任务项,从而能够得到充分的信息以解决问题, 返回格式如下:
```
Plan: 当前子任务要解决的问题
#E[id] = 工具名称[工具参数]
Plan: 当前子任务要解决的问题
#E[id] = 工具名称[工具参数]
```
其中
1. #E[id] 用于存储Plan id的执行结果, 可被用作占位符。
2. 每个 #E[id] 所执行的内容应与当前Plan解决的问题严格对应。
3. 工具参数可以是正常输入text, 或是 #E[依赖的索引], 或是两者都可以。
4. 工具名称必须从以下工具中选择:
{tool_description}
注意: 每个Plan后有且仅能跟随一个#E。
开始!"""
WORKER_PROMPT = """
Thought: {thought}\nResponse: {action_resp}\n
"""
SOLVER_PROMPT = """解决接下来的任务或者问题。为了帮助你,我们提供了一些相关的计划
和相应的解答。注意其中一些信息可能存在噪声,因此你需要谨慎的使用它们。\n
{question}\n{worker_log}\n现在开始回答这个任务或者问题。请直接回答这个问题,
不要包含其他不需要的文字。{question}\n
"""
class ReWOOProtocol:
"""A wrapper of ReWOO prompt which manages the response from LLM and
generate desired prompts in a ReWOO format.
Args:
        planner_prompt (str): prompt template for the planner
        worker_prompt (str): prompt template for the workers
        solver_prompt (str): prompt template for the solver
"""
def __init__(
self,
        planner_prompt: str = PLANNER_PROMPT,
worker_prompt: str = WORKER_PROMPT,
solver_prompt: str = SOLVER_PROMPT,
) -> None:
self.planner_prompt = planner_prompt
self.worker_prompt = worker_prompt
self.solver_prompt = solver_prompt
def format_planner(self,
chat_history: List[Dict],
inner_step: List[Dict],
action_executor: ActionExecutor,
reformat_request: Optional[str] = '') -> List[Dict]:
"""Generate the planner prompt required by ReWOO.
Args:
chat_history (List[Dict]): The history log in previous runs.
inner_step (List[Dict]): The log in the current run.
action_executor (ActionExecutor): the action manager to execute
actions.
reformat_request (str): the error feedback if the LLM fails to
generate required format for planner.
Returns:
List[Dict]: ReWOO format prompt for planner.
"""
planner_prompt = self.planner_prompt.format(
tool_description=action_executor.get_actions_info(), )
formatted = []
formatted.append(dict(role='system', content=planner_prompt))
formatted += chat_history
formatted += inner_step
if reformat_request != '':
formatted.append(
dict(
role='system',
                    content='回答格式错误: %s. 请重新回答: ' % reformat_request))
return formatted
def parse_worker(
self,
message: str,
) -> Tuple[List[str], List[str], List[str]]:
"""Parse the LLM generated planner response and convert it into the
worker format.
Args:
message (str): The response from LLM with ReWOO planner format.
Returns:
            tuple: the return value is a tuple that contains:
                - thought_list (List(str)): contains the LLM thoughts of the
                  user request.
                - action_list (List(str)): contains the actions scheduled
                  by the LLM.
                - action_input_list (List(str)): contains the required action
                  inputs for the above actions.
"""
action_list = []
action_input_list = []
thought_list = []
thoughts = re.findall('Plan: (.+)', message)
action_units = re.findall('#E[0-9]* = (.+)', message)
assert len(thoughts) == len(action_units), \
'Each Plan should only correspond to only ONE action'
for thought, action_unit in zip(thoughts, action_units):
action_name, action_input = re.findall(r'(.*?)\[(.*?)\]',
action_unit.strip())[0]
action_list.append(action_name.strip())
action_input_list.append(action_input.strip())
thought_list.append(thought.strip())
return thought_list, action_list, action_input_list
def format_solver(
self, question: str, thought_list: List[str],
action_return_list: List[ActionReturn]) -> Tuple[str, str]:
"""Generate the prompt for solver in a ReWOO format.
Args:
question (str): The user request in the current run.
thought_list (List[str]): thoughts generated from LLM for
each action.
action_return_list (List[ActionReturn]): action returns
from workers.
Returns:
            tuple: the return value is a tuple that contains:
                - solver_prompt (str): the generated prompt for the solver
                  in a ReWOO format.
                - worker_log (str): contains action responses from workers.
                  Used for the inner log.
"""
worker_log = ''
for thought, action_return in zip(thought_list, action_return_list):
if action_return.state == ActionStatusCode.SUCCESS:
action_resp = action_return.result['text']
else:
action_resp = action_return.errmsg
worker_response = self.worker_prompt.format(
thought=thought, action_resp=action_resp)
worker_log += worker_response
solver_prompt = self.solver_prompt.format(
question=question, worker_log=worker_log)
return solver_prompt, worker_log
class ReWOO(BaseAgent):
"""An implementation of ReWOO (https://arxiv.org/abs/2305.18323)
Args:
        llm (BaseModel or BaseAPIModel): an LLM service which can chat
and act as planner / solver.
action_executor (ActionExecutor): an action executor to manage
all actions and their response.
protocol (ReWOOProtocol): a wrapper to generate prompt and
parse the response from LLM / actions.
        max_turn (int): the maximum number of trials for the LLM to generate
plans that can be successfully parsed by ReWOO protocol.
"""
def __init__(self,
llm: Union[BaseModel, BaseAPIModel],
action_executor: ActionExecutor,
protocol: ReWOOProtocol = ReWOOProtocol(),
max_turn: int = 2) -> None:
super().__init__(
llm=llm, action_executor=action_executor, protocol=protocol)
self.max_turn = max_turn
def chat(self, message: str) -> AgentReturn:
self._inner_history = []
self._inner_history.append(dict(role='user', content=message))
agent_return = AgentReturn()
# planner
turn_id = 0
reformat_request = ''
while turn_id < self.max_turn:
planner_prompt = self._protocol.format_planner(
chat_history=self.session_history,
inner_step=self._inner_history,
action_executor=self._action_executor,
reformat_request=reformat_request)
response = self._llm.generate_from_template(planner_prompt, 512)
self._inner_history.append(
dict(role='assistant', content=response))
try:
thoughts, actions, actions_input = self._protocol.parse_worker(
response)
break
except Exception as e:
turn_id += 1
reformat_request = str(e)
if turn_id >= self.max_turn:
warnings.warn('\nUnable to parse LLM responses in %d turns, '
'directly request solver for question answer...' %
self.max_turn)
actions = []
thoughts = []
action_responses = []
# workers
action_responses = []
for action_id in range(len(actions)):
# we need to change actions_input inplace
prev_ptrs = re.findall(r'#E\d+', actions_input[action_id])
for prev_ptr in prev_ptrs:
ptr_num = int(prev_ptr.strip('#E')) - 1 # start from 0
actions_input[action_id] = actions_input[action_id].replace(
prev_ptr, action_responses[ptr_num].result['text'])
action_return: ActionReturn = self._action_executor(
actions[action_id], actions_input[action_id])
action_responses.append(action_return)
solver_prompt, worker_log = self._protocol.format_solver(
message, thoughts, action_responses)
self._inner_history.append(dict(role='system', content=worker_log))
final_response = self._llm.generate_from_template(solver_prompt, 512)
self._inner_history.append(
dict(role='assistant', content=final_response))
agent_return.response = final_response
return agent_return
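A comparable sketch for the ReWOO agent, wiring in the LLMQA and SerperSearch actions that ship in this commit; the Serper key below is only a placeholder.

from lagent.actions import ActionExecutor
from lagent.actions.llm_qa import LLMQA
from lagent.actions.serper_search import SerperSearch
from lagent.agents.rewoo import ReWOO
from lagent.llms.openai import GPTAPI

# Plan-then-execute: the planner decomposes the question, workers run the
# scheduled actions, and the solver writes the answer from the worker log.
llm = GPTAPI(model_type='gpt-3.5-turbo')
agent = ReWOO(
    llm=llm,
    action_executor=ActionExecutor(
        actions=[LLMQA(llm), SerperSearch(api_key='<YOUR_SERPER_KEY>')]))
print(agent.chat('What is the capital of China?').response)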

10
lagent/llms/__init__.py Normal file
View File

@@ -0,0 +1,10 @@
from lagent.utils import is_module_exist
from .base_api import BaseAPIModel
from .base_llm import BaseModel
from .openai import GPTAPI
__all__ = ['BaseModel', 'BaseAPIModel', 'GPTAPI']
if is_module_exist('transformers'):
from .huggingface import HFTransformer, HFTransformerCasualLM # noqa: F401
__all__.extend(['HFTransformer', 'HFTransformerCasualLM'])

240
lagent/llms/base_api.py Normal file
View File

@@ -0,0 +1,240 @@
import re
import threading
import warnings
from abc import abstractmethod
from time import sleep
from typing import Dict, List, Optional, Tuple, Union
from .base_llm import BaseModel
class BaseAPIModel(BaseModel):
"""Base class for API model wrapper.
Args:
model_type (str): The type of model.
query_per_second (int): The maximum queries allowed per second
between two consecutive calls of the API. Defaults to 1.
        retry (int): Number of retries if the API call fails. Defaults to 2.
max_seq_len (int): The maximum sequence length of the model. Defaults
to 2048.
        meta_template (Dict, optional): The model's meta prompt
            template, if needed, in case any meta instructions need to be
            injected or wrapped around the prompt.
"""
is_api: bool = True
def __init__(self,
model_type: str,
query_per_second: int = 1,
retry: int = 2,
max_seq_len: int = 2048,
meta_template: Optional[Dict] = None):
self.model_type = model_type
self.max_seq_len = max_seq_len
self.meta_template = meta_template
self.retry = retry
self.query_per_second = query_per_second
self.token_bucket = TokenBucket(query_per_second)
self.template_parser = APITemplateParser(meta_template)
    @abstractmethod
def generate(self, inputs, max_out_len: int) -> List[str]:
"""Generate results given a list of inputs.
Args:
inputs (List[str or list]): A list of strings or PromptDicts.
The PromptDict should be organized in OpenCompass'
API format.
max_out_len (int): The maximum length of the output.
Returns:
List[str]: A list of generated strings.
"""
def get_token_len(self, prompt: str) -> int:
"""Get lengths of the tokenized string. Only English and Chinese
characters are counted for now. Users are encouraged to override this
method if more accurate length is needed.
Args:
prompt (str): Input string.
Returns:
int: Length of the input tokens
"""
english_parts = re.findall(r'[A-Za-z0-9]+', prompt)
chinese_parts = re.findall(r'[\u4e00-\u9FFF]+', prompt)
# Count English words
english_count = sum(len(part.split()) for part in english_parts)
# Count Chinese words
chinese_count = sum(len(part) for part in chinese_parts)
return english_count + chinese_count
def wait(self):
"""Wait till the next query can be sent.
Applicable in both single-thread and multi-thread environments.
"""
return self.token_bucket.get_token()
def to(self, device):
pass
class APITemplateParser:
"""Intermidate prompt template parser, specifically for API models.
Args:
meta_template (Dict): The meta template for the model.
"""
def __init__(self, meta_template: Optional[Dict] = None):
self.meta_template = meta_template
# Check meta template
if meta_template:
assert isinstance(meta_template, list)
self.roles: Dict[str, dict] = dict() # maps role name to config
for item in meta_template:
assert isinstance(item, dict)
assert item['role'] not in self.roles, \
'role in meta prompt must be unique!'
self.roles[item['role']] = item.copy()
def parse_template(self, dialog: List[Union[str, List]]):
"""Parse the intermidate prompt template, and wrap it with meta
template if applicable. When the meta template is set and the input is
a list, the return value will be a list containing the full
conversation history. Each item looks like:
.. code-block:: python
            {'role': 'user', 'content': '...'}
Args:
            dialog (List[str or list]): An intermediate prompt
                template (potentially before being wrapped by meta template).
Returns:
List[str or list]: The finalized prompt or a conversation.
"""
assert isinstance(dialog, (str, list))
if isinstance(dialog, str):
return dialog
if self.meta_template:
prompt = list()
# Whether to keep generating the prompt
generate = True
for i, item in enumerate(dialog):
if not generate:
break
if isinstance(item, str):
if item.strip():
# TODO: logger
warnings.warn('Non-empty string in prompt template '
'will be ignored in API models.')
else:
api_prompts = self._prompt2api(item)
prompt.append(api_prompts)
# merge the consecutive prompts assigned to the same role
new_prompt = list([prompt[0]])
last_role = prompt[0]['role']
for item in prompt[1:]:
if item['role'] == last_role:
new_prompt[-1]['content'] += '\n' + item['content']
else:
last_role = item['role']
new_prompt.append(item)
prompt = new_prompt
else:
# in case the model does not have any meta template
prompt = ''
last_sep = ''
for item in dialog:
if isinstance(item, str):
if item:
prompt += last_sep + item
elif item.get('content', ''):
prompt += last_sep + item.get('content', '')
last_sep = '\n'
return prompt
    def _prompt2api(self,
                    prompts: Union[List, str]) -> Union[str, Dict, List]:
        """Convert the prompts to API-style prompts.
        Args:
            prompts (Union[List, str]): The prompts to be converted.
        Returns:
            Union[str, Dict, List]: The converted API-style prompt(s).
"""
if isinstance(prompts, str):
return prompts
elif isinstance(prompts, dict):
api_role = self._role2api_role(prompts)
return api_role
res = []
for prompt in prompts:
if isinstance(prompt, str):
raise TypeError('Mixing str without explicit role is not '
'allowed in API models!')
else:
api_role = self._role2api_role(prompt)
res.append(api_role)
return res
def _role2api_role(self, role_prompt: Dict) -> Tuple[str, bool]:
merged_prompt = self.roles.get(
role_prompt['role'],
self.roles.get(
self.roles[role_prompt['role']].get('fallback_role')))
res = {}
res['role'] = merged_prompt['api_role']
res['content'] = merged_prompt.get('begin', '')
res['content'] += role_prompt.get('content', '')
res['content'] += merged_prompt.get('end', '')
return res
class TokenBucket:
"""A token bucket for rate limiting.
Args:
rate (float): The rate of the token bucket.
"""
def __init__(self, rate: float) -> None:
self._rate = rate
self._tokens = threading.Semaphore(0)
self.started = False
def _add_tokens(self):
"""Add tokens to the bucket."""
while True:
if self._tokens._value < self._rate:
self._tokens.release()
sleep(1 / self._rate)
def get_token(self):
"""Get a token from the bucket."""
if not self.started:
self.started = True
threading.Thread(target=self._add_tokens, daemon=True).start()
self._tokens.acquire()
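A standalone sketch of the TokenBucket behaviour: with rate=2, the refill thread releases a token roughly every 0.5 s, so callers are throttled to about two requests per second.

import time

from lagent.llms.base_api import TokenBucket

bucket = TokenBucket(rate=2)
start = time.time()
for i in range(4):
    bucket.get_token()  # blocks until the refill thread releases a token
    print(f'request {i} acquired at +{time.time() - start:.2f}s')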

146
lagent/llms/base_llm.py Normal file
View File

@@ -0,0 +1,146 @@
from abc import abstractmethod
from typing import Dict, List, Optional, Tuple, Union
class BaseModel:
"""Base class for model wrapper.
Args:
path (str): The path to the model.
max_seq_len (int): The maximum sequence length of the model. Defaults
to 2048.
tokenizer_only (bool): If True, only the tokenizer will be initialized.
Defaults to False.
        meta_template (Dict, optional): The model's meta prompt
            template, if needed, in case any meta instructions need to be
            injected or wrapped around the prompt.
"""
is_api: bool = False
def __init__(self,
path: str,
max_seq_len: int = 2048,
tokenizer_only: bool = False,
meta_template: Optional[Dict] = None):
self.path = path
self.max_seq_len = max_seq_len
self.tokenizer_only = tokenizer_only
# meta template
self.template_parser = LMTemplateParser(meta_template)
self.eos_token_id = None
if meta_template and 'eos_token_id' in meta_template:
self.eos_token_id = meta_template['eos_token_id']
    @abstractmethod
def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
"""Generate results given a list of inputs.
Args:
inputs (List[str]): A list of strings.
max_out_len (int): The maximum length of the output.
Returns:
List[str]: A list of generated strings.
"""
def parse_template(self, dialog) -> str:
"""Parse a prompt template, and wrap it with meta template if
applicable.
Args:
            dialog (List[str or PromptList]): A prompt
                template (potentially before being wrapped by meta template).
Returns:
str: The final string.
"""
return self.template_parser.parse_template(dialog)
def generate_from_template(self, templates, max_out_len: int, **kwargs):
"""Generate completion from a list of templates.
Args:
templates (List[PromptType]): A list of templates.
max_out_len (int): The maximum length of the output.
"""
inputs = self.parse_template(templates)
return self.generate(inputs, max_out_len=max_out_len, **kwargs)
def to(self, device):
self.model.to(device)
class LMTemplateParser:
"""Intermidate prompt template parser, specifically for language models.
Args:
meta_template (Dict): The meta template for the model.
"""
def __init__(self, meta_template: Optional[Dict] = None):
self.meta_template = meta_template
if meta_template:
assert isinstance(meta_template, list)
self.roles: Dict[str, dict] = dict() # maps role name to config
for item in meta_template:
assert isinstance(item, dict)
assert item['role'] not in self.roles, \
'role in meta prompt must be unique!'
self.roles[item['role']] = item.copy()
def parse_template(self, dialog) -> str:
"""Parse a prompt template, and wrap it with meta template if
applicable.
Args:
            dialog (List[str or PromptList]): A prompt
                template (potentially before being wrapped by meta template).
Returns:
str: The final string.
"""
assert isinstance(dialog, (str, list))
if isinstance(dialog, str):
return dialog
if self.meta_template:
prompt = ''
for index, item in enumerate(dialog):
if isinstance(item, str):
prompt += item
else:
new_str = self._prompt2str(item, index == len(dialog) - 1)
prompt += new_str
else:
# in case the model does not have any meta template
prompt = ''
last_sep = ''
for item in dialog:
if isinstance(item, str):
if item:
prompt += last_sep + item
elif item.get('content', ''):
                    prompt += last_sep + item.get('content', '')
last_sep = '\n'
return prompt
def _prompt2str(self,
prompt: Union[List, str, Dict],
last: bool = False) -> Tuple[str, bool]:
if isinstance(prompt, str):
return prompt
merged_prompt = self.roles.get(
prompt['role'],
self.roles.get(self.roles[prompt['role']].get('fallback_role')))
res = merged_prompt.get('begin', '')
if last and merged_prompt.get('generate', False):
res += prompt.get('content', '')
return res
res += prompt.get('content', '') + merged_prompt.get('end', '')
if last and merged_prompt['role'] != 'assistant':
res += self.roles['assistant']['begin']
return res
return res
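To see how a meta template shapes the final prompt, here is a sketch with made-up role markers (the tags below are illustrative, not shipped defaults); the trailing assistant 'begin' is appended because the last message is not from the assistant.

from lagent.llms.base_llm import LMTemplateParser

meta_template = [
    dict(role='system', begin='<|System|>:', end='\n'),
    dict(role='user', begin='<|User|>:', end='\n'),
    dict(role='assistant', begin='<|Bot|>:', end='\n', generate=True),
]
parser = LMTemplateParser(meta_template)
dialog = [
    dict(role='system', content='You are a helpful assistant.'),
    dict(role='user', content='Hi there'),
]
print(parser.parse_template(dialog))
# -> '<|System|>:You are a helpful assistant.\n<|User|>:Hi there\n<|Bot|>:'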

128
lagent/llms/huggingface.py Normal file
View File

@@ -0,0 +1,128 @@
from typing import Dict, List, Optional
import torch
from .base_llm import BaseModel
class HFTransformer(BaseModel):
"""Model wrapper around HuggingFace general models.
Adapted from OpenCompass (https://github.com/InternLM/opencompass
/blob/main/opencompass/models/huggingface.py)
Args:
path (str): The name or path to HuggingFace's model.
max_seq_len (int): The maximum length of the input sequence. Defaults
to 2048.
tokenizer_path (str): The path to the tokenizer. Defaults to None.
tokenizer_kwargs (dict): Keyword arguments for the tokenizer.
Defaults to {}.
tokenizer_only (bool): If True, only the tokenizer will be initialized.
Defaults to False.
model_kwargs (dict): Keyword arguments for the model, used in loader.
Defaults to dict(device_map='auto').
        meta_template (Dict, optional): The model's meta prompt
            template, if needed, in case any meta instructions need to be
            injected or wrapped around the prompt.
extract_pred_after_decode (bool): Whether to extract the prediction
            string from the decoded output string, instead of extracting the
prediction tokens before decoding. Defaults to False.
        batch_padding (bool): If False, inference will be performed in a
            for-loop without batch padding.
Note:
        About ``extract_pred_after_decode``: Commonly, we should extract the
        prediction tokens before decoding. But for some tokenizers using
``sentencepiece``, like LLaMA, this behavior may change the number of
whitespaces, which is harmful for Python programming tasks.
"""
def __init__(self,
path: str,
max_seq_len: int = 2048,
tokenizer_path: Optional[str] = None,
tokenizer_kwargs: dict = dict(),
tokenizer_only: bool = False,
model_kwargs: dict = dict(device_map='auto'),
meta_template: Optional[Dict] = None,
extract_pred_after_decode: bool = False,
batch_padding: bool = False):
super().__init__(
path=path,
max_seq_len=max_seq_len,
tokenizer_only=tokenizer_only,
meta_template=meta_template)
self._load_tokenizer(
path=path,
tokenizer_path=tokenizer_path,
tokenizer_kwargs=tokenizer_kwargs)
self.batch_padding = batch_padding
self.extract_pred_after_decode = extract_pred_after_decode
if not tokenizer_only:
self._load_model(path=path, model_kwargs=model_kwargs)
def _load_tokenizer(self, path: str, tokenizer_path: Optional[str],
tokenizer_kwargs: dict):
from transformers import AutoTokenizer
self.tokenizer = AutoTokenizer.from_pretrained(
tokenizer_path if tokenizer_path else path,
trust_remote_code=True,
**tokenizer_kwargs)
if self.tokenizer.pad_token_id is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
def _load_model(self, path: str, model_kwargs: dict):
from transformers import AutoModel
model_kwargs.setdefault('torch_dtype', torch.float16)
self.model = AutoModel.from_pretrained(
path, trust_remote_code=True, **model_kwargs)
self.model.eval()
def generate(self, inputs: List[str], max_out_len: int,
**kwargs) -> List[str]:
if isinstance(inputs, str):
inputs = [inputs]
if self.extract_pred_after_decode:
prompt_lens = [len(input_) for input_ in inputs]
input_ids = self.tokenizer(
inputs, truncation=True,
max_length=self.max_seq_len - max_out_len)['input_ids']
input_ids = torch.tensor(input_ids, device=self.model.device)
outputs = self.model.generate(
input_ids, max_new_tokens=max_out_len, **kwargs)
if not self.extract_pred_after_decode:
outputs = outputs[:, input_ids.shape[1]:]
decodeds = self.tokenizer.batch_decode(
outputs, skip_special_tokens=True)
if self.extract_pred_after_decode:
decodeds = [
token[len_:] for token, len_ in zip(decodeds, prompt_lens)
]
return decodeds[0]
def generate_from_template(self, templates, max_out_len: int, **kwargs):
"""Generate completion from a list of templates.
Args:
templates (List[PromptType]): A list of templates.
max_out_len (int): The maximum length of the output.
"""
inputs = self.parse_template(templates)
response = self.generate(inputs, max_out_len=max_out_len, **kwargs)
return response.replace(
self.template_parser.roles['assistant']['end'].strip(),
'').strip()
class HFTransformerCasualLM(HFTransformer):
def _load_model(self, path: str, model_kwargs: dict):
from transformers import AutoModelForCausalLM
model_kwargs.setdefault('torch_dtype', torch.float16)
self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs)
self.model.eval()
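A minimal sketch of local inference with the causal-LM wrapper; 'gpt2' is only an example checkpoint, the first call downloads the weights, and device_map='auto' requires the accelerate package.

from lagent.llms.huggingface import HFTransformerCasualLM

# torch_dtype='auto' avoids forcing fp16 on CPU-only machines.
llm = HFTransformerCasualLM(
    path='gpt2',
    model_kwargs=dict(device_map='auto', torch_dtype='auto'))
# generate() accepts a single string and returns the decoded continuation.
print(llm.generate('The capital of China is', max_out_len=16))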

242
lagent/llms/openai.py Normal file
View File

@@ -0,0 +1,242 @@
import json
import os
import time
# from concurrent.futures import ThreadPoolExecutor
from logging import getLogger
from threading import Lock
from typing import Dict, List, Optional, Union
import requests
from .base_api import BaseAPIModel
OPENAI_API_BASE = 'https://api.openai.com/v1/chat/completions'
class GPTAPI(BaseAPIModel):
"""Model wrapper around OpenAI's models.
Args:
model_type (str): The name of OpenAI's model.
max_seq_len (int): The maximum allowed sequence length of a model.
Note that the length of prompt + generated tokens shall not exceed
this value. Defaults to 2048.
query_per_second (int): The maximum queries allowed per second
between two consecutive calls of the API. Defaults to 1.
        retry (int): Number of retries if the API call fails. Defaults to 2.
key (str or List[str]): OpenAI key(s). In particular, when it
is set to "ENV", the key will be fetched from the environment
variable $OPENAI_API_KEY, as how openai defaults to be. If it's a
list, the keys will be used in round-robin manner. Defaults to
'ENV'.
org (str or List[str], optional): OpenAI organization(s). If not
specified, OpenAI uses the default organization bound to each API
key. If specified, the orgs will be posted with each request in
round-robin manner. Defaults to None.
        meta_template (Dict, optional): The model's meta prompt
            template, if needed, in case any meta instructions need to be
            injected or wrapped around the prompt.
openai_api_base (str): The base url of OpenAI's API. Defaults to
'https://api.openai.com/v1/chat/completions'.
temperature (float, optional): What sampling temperature to use.
If not None, will override the temperature in the `generate()`
call. Defaults to None.
"""
is_api: bool = True
def __init__(self,
model_type: str = 'gpt-3.5-turbo',
max_seq_len: int = 4096,
query_per_second: int = 1,
retry: int = 2,
key: Union[str, List[str]] = 'ENV',
org: Optional[Union[str, List[str]]] = None,
                 meta_template: Optional[List[Dict]] = [
dict(role='system', api_role='system'),
dict(role='user', api_role='user'),
dict(role='assistant', api_role='assistant')
],
openai_api_base: str = OPENAI_API_BASE,
temperature: Optional[float] = None):
super().__init__(
model_type=model_type,
max_seq_len=max_seq_len,
meta_template=meta_template,
query_per_second=query_per_second,
retry=retry)
self.logger = getLogger(__name__)
self.temperature = temperature
if isinstance(key, str):
self.keys = [os.getenv('OPENAI_API_KEY') if key == 'ENV' else key]
else:
self.keys = key
# record invalid keys and skip them when requesting API
# - keys have insufficient_quota
self.invalid_keys = set()
self.key_ctr = 0
if isinstance(org, str):
self.orgs = [org]
else:
self.orgs = org
self.org_ctr = 0
self.url = openai_api_base
self.model_type = model_type
# max num token for gpt-3.5-turbo is 4097
context_window = 4096
if '32k' in self.model_type:
context_window = 32768
elif '16k' in self.model_type:
context_window = 16384
elif 'gpt-4' in self.model_type:
context_window = 8192
self.context_window = context_window
def generate(
self,
inputs: Union[List, str],
max_out_len: int = 512,
temperature: float = 0.7,
) -> List[str]:
"""Generate results given a list of inputs.
Args:
inputs (List[str or List]): A list of strings or PromptDicts.
The PromptDict should be organized in OpenCompass'
API format.
max_out_len (int): The maximum length of the output.
temperature (float): What sampling temperature to use,
between 0 and 2. Higher values like 0.8 will make the output
more random, while lower values like 0.2 will make it more
focused and deterministic. Defaults to 0.7.
Returns:
List[str]: A list of generated strings.
"""
if self.temperature is not None:
temperature = self.temperature
return self._generate(inputs, max_out_len, temperature)
    def _generate(self, input: Union[str, List], max_out_len: int,
temperature: float) -> str:
"""Generate results given a list of inputs.
Args:
inputs (str or List): A string or PromptDict.
The PromptDict should be organized in OpenCompass'
API format.
max_out_len (int): The maximum length of the output.
temperature (float): What sampling temperature to use,
between 0 and 2. Higher values like 0.8 will make the output
more random, while lower values like 0.2 will make it more
focused and deterministic.
Returns:
str: The generated string.
"""
assert isinstance(input, (str, list, dict))
if isinstance(input, str):
messages = [{'role': 'user', 'content': input}]
elif isinstance(input, dict):
messages = [input]
else:
messages = input
# Hold out 100 tokens due to potential errors in tiktoken calculation
max_out_len = min(
max_out_len,
self.context_window - self.get_token_len(str(input)) - 100)
if max_out_len <= 0:
return ''
max_num_retries = 0
while max_num_retries < self.retry:
self.wait()
with Lock():
if len(self.invalid_keys) == len(self.keys):
raise RuntimeError('All keys have insufficient quota.')
# find the next valid key
while True:
self.key_ctr += 1
if self.key_ctr == len(self.keys):
self.key_ctr = 0
if self.keys[self.key_ctr] not in self.invalid_keys:
break
key = self.keys[self.key_ctr]
header = {
'Authorization': f'Bearer {key}',
'content-type': 'application/json',
}
if self.orgs:
with Lock():
self.org_ctr += 1
if self.org_ctr == len(self.orgs):
self.org_ctr = 0
header['OpenAI-Organization'] = self.orgs[self.org_ctr]
try:
data = dict(
model=self.model_type,
messages=messages,
max_tokens=max_out_len,
n=1,
stop=None,
temperature=temperature,
)
raw_response = requests.post(
self.url, headers=header, data=json.dumps(data))
except requests.ConnectionError:
print('Got connection error, retrying...')
continue
try:
response = raw_response.json()
except requests.JSONDecodeError:
print('JsonDecode error, got', str(raw_response.content))
continue
try:
return response['choices'][0]['message']['content'].strip()
except KeyError:
if 'error' in response:
if response['error']['code'] == 'rate_limit_exceeded':
time.sleep(1)
continue
elif response['error']['code'] == 'insufficient_quota':
self.invalid_keys.add(key)
                        self.logger.warning(f'insufficient_quota key: {key}')
continue
print('Find error message in response: ',
str(response['error']))
max_num_retries += 1
raise RuntimeError('Calling OpenAI failed after retrying for '
f'{max_num_retries} times. Check the logs for '
'details.')
def get_token_len(self, prompt: str) -> int:
"""Get lengths of the tokenized string. Only English and Chinese
characters are counted for now. Users are encouraged to override this
method if more accurate length is needed.
Args:
prompt (str): Input string.
Returns:
int: Length of the input tokens
"""
import tiktoken
self.tiktoken = tiktoken
enc = self.tiktoken.encoding_for_model(self.model_type)
return len(enc.encode(prompt))
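A short sketch of calling the wrapper directly; it expects $OPENAI_API_KEY in the environment while key stays at its 'ENV' default, and the prompt below is just an example.

from lagent.llms.openai import GPTAPI

gpt = GPTAPI(model_type='gpt-3.5-turbo', retry=3)
messages = [dict(role='user', content='Say hello in one word.')]
# generate() accepts a string, a single message dict, or a list of messages.
print(gpt.generate(messages, max_out_len=32, temperature=0.2))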

79
lagent/schema.py Normal file
View File

@@ -0,0 +1,79 @@
from dataclasses import asdict, dataclass, field
from enum import Enum
from typing import Dict, List, Optional, Union
from lagent.utils import is_module_exist
def enum_dict_factory(inputs):
inputs = [(i[0], i[-1].value) if isinstance(i[-1], Enum) else i
for i in inputs]
return dict(inputs)
def dataclass2dict(data):
return asdict(data, dict_factory=enum_dict_factory)
class ActionStatusCode(int, Enum):
ING = 1
SUCCESS = 0
HTTP_ERROR = -1000 # http error
    ARGS_ERROR = -1001  # argument error
    API_ERROR = -1002  # unknown API error
class ActionValidCode(int, Enum):
FINISH = 1
OPEN = 0
CLOSED = -1
INVALID = -2
ABSENT = -3 # NO ACTION
@dataclass
class ActionReturn:
args: Dict
url: Optional[str] = None
type: Optional[str] = None
    result: Optional[Dict] = None
errmsg: Optional[str] = None
state: Union[ActionStatusCode, int] = ActionStatusCode.SUCCESS
thought: Optional[str] = None
valid: Optional[ActionValidCode] = ActionValidCode.OPEN
class AgentStatusCode(Enum):
END = 0 # end of streaming
STREAM_ING = 1 # response is in streaming
SERVER_ERR = -1 # triton server's error
SESSION_CLOSED = -2 # session has been closed
SESSION_OUT_OF_LIMIT = -3 # request length out of limit
CMD = 2 # return command
SESSION_INVALID_ARG = -4 # invalid argument
SESSION_READY = 3 # session is ready for inference
@dataclass
class AgentReturn:
actions: List[ActionReturn] = field(default_factory=list)
response: str = ''
inner_steps: List = field(default_factory=list)
errmsg: Optional[str] = None
if is_module_exist('lmdeploy'):
from lmdeploy.serve.turbomind.chatbot import StatusCode
STATE_MAP = {
StatusCode.TRITON_STREAM_END: AgentStatusCode.END,
StatusCode.TRITON_SERVER_ERR: AgentStatusCode.SERVER_ERR,
StatusCode.TRITON_SESSION_CLOSED: AgentStatusCode.SESSION_CLOSED,
StatusCode.TRITON_STREAM_ING: AgentStatusCode.STREAM_ING,
StatusCode.TRITON_SESSION_OUT_OF_LIMIT:
AgentStatusCode.SESSION_OUT_OF_LIMIT,
StatusCode.TRITON_SESSION_INVALID_ARG:
AgentStatusCode.SESSION_INVALID_ARG,
StatusCode.TRITON_SESSION_READY: AgentStatusCode.SESSION_READY
}
else:
STATE_MAP = {}
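A small sketch of the schema in use: ActionReturn records one tool call, and dataclass2dict serializes it with enum members flattened to their integer values.

from lagent.schema import ActionReturn, ActionStatusCode, dataclass2dict

ret = ActionReturn(
    args={'text': '1 + 1'},
    type='PythonInterpreter',
    result={'text': '2'},
    state=ActionStatusCode.SUCCESS)
print(dataclass2dict(ret)['state'])  # -> 0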

3
lagent/utils/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
from .package import is_module_exist
__all__ = ['is_module_exist']

9
lagent/utils/package.py Normal file
View File

@@ -0,0 +1,9 @@
import importlib
def is_module_exist(module_name):
spec = importlib.util.find_spec(module_name)
if spec is None:
return False
else:
return True
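This helper backs the optional-dependency guards used in the package __init__ files, e.g. this sketch:

from lagent.utils import is_module_exist

if is_module_exist('transformers'):
    from lagent.llms.huggingface import HFTransformerCasualLM  # noqa: F401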

35
lagent/version.py Normal file
View File

@@ -0,0 +1,35 @@
# Copyright (c) OpenMMLab. All rights reserved.
__version__ = '0.1.0'
def parse_version_info(version_str: str, length: int = 4) -> tuple:
"""Parse a version string into a tuple.
Args:
version_str (str): The version string.
length (int): The maximum number of version levels. Default: 4.
Returns:
tuple[int | str]: The version info, e.g., "1.3.0" is parsed into
(1, 3, 0, 0, 0, 0), and "2.0.0rc1" is parsed into
(2, 0, 0, 0, 'rc', 1) (when length is set to 4).
"""
from packaging.version import parse
version = parse(version_str)
assert version.release, f'failed to parse version {version_str}'
release = list(version.release)
release = release[:length]
if len(release) < length:
release = release + [0] * (length - len(release))
if version.is_prerelease:
release.extend(list(version.pre)) # type: ignore
elif version.is_postrelease:
release.extend(list(version.post)) # type: ignore
else:
release.extend([0, 0])
return tuple(release)
version_info = tuple(int(x) for x in __version__.split('.')[:3])
__all__ = ['__version__', 'version_info', 'parse_version_info']
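Two illustrative calls matching the docstring above:

from lagent.version import parse_version_info

print(parse_version_info('0.1.0'))     # (0, 1, 0, 0, 0, 0)
print(parse_version_info('2.0.0rc1'))  # (2, 0, 0, 0, 'rc', 1)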

0
projects/.gitkeep Normal file
View File

2
requirements.txt Normal file
View File

@@ -0,0 +1,2 @@
-r requirements/optional.txt
-r requirements/runtime.txt

View File

@@ -0,0 +1,2 @@
torch
transformers

5
requirements/runtime.txt Normal file
View File

@@ -0,0 +1,5 @@
distro
func_timeout
jsonschema
requests
tiktoken

24
setup.cfg Normal file
View File

@@ -0,0 +1,24 @@
[isort]
line_length = 79
multi_line_output = 0
extra_standard_library = setuptools
known_first_party = lagent
known_third_party = PIL,asynctest,cityscapesscripts,cv2,gather_models,matplotlib,mmcv,mmengine,numpy,onnx,onnxruntime,pycocotools,parameterized,pytest,pytorch_sphinx_theme,requests,scipy,seaborn,six,terminaltables,torch,ts,yaml
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY
[yapf]
BASED_ON_STYLE = pep8
BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = true
SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true
# ignore-words-list needs to be lowercase format. For example, if we want to
# ignore word "BA", then we need to append "ba" to ignore-words-list rather
# than "BA"
[codespell]
skip = *.ipynb
quiet-level = 3
ignore-words-list = patten,nd,ty,mot,hist,formating,winn,gool,datas,wan,confids,TOOD,tood,ba,warmup,nam,DOTA,dota,conveyer
[flake8]
per-file-ignores = ftdp/configs/*: F401,F403,F405

107
setup.py Normal file
View File

@@ -0,0 +1,107 @@
from pathlib import Path
from setuptools import find_packages, setup
def get_version():
version_file = 'lagent/version.py'
with open(version_file, encoding='utf-8') as f:
exec(compile(f.read(), version_file, 'exec'))
return locals()['__version__']
def parse_requirements(fname='requirements.txt', with_version=True):
"""Parse the package dependencies listed in a requirements file but strip
specific version information.
Args:
fname (str): Path to requirements file.
        with_version (bool, default=True): If True, include version specs.
Returns:
info (list[str]): List of requirements items.
CommandLine:
python -c "import setup; print(setup.parse_requirements())"
"""
import re
import sys
from os.path import exists
require_fpath = fname
def parse_line(line):
"""Parse information from a line in a requirements text file."""
if line.startswith('-r '):
# Allow specifying requirements in other files
target = line.split(' ')[1]
for info in parse_require_file(target):
yield info
else:
info = {'line': line}
if line.startswith('-e '):
info['package'] = line.split('#egg=')[1]
else:
# Remove versioning from the package
pat = '(' + '|'.join(['>=', '==', '>']) + ')'
parts = re.split(pat, line, maxsplit=1)
parts = [p.strip() for p in parts]
info['package'] = parts[0]
if len(parts) > 1:
op, rest = parts[1:]
if ';' in rest:
# Handle platform specific dependencies
# http://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-platform-specific-dependencies
version, platform_deps = map(str.strip,
rest.split(';'))
info['platform_deps'] = platform_deps
else:
version = rest # NOQA
info['version'] = (op, version)
yield info
def parse_require_file(fpath):
with open(fpath) as f:
for line in f.readlines():
line = line.strip()
if line and not line.startswith('#'):
yield from parse_line(line)
def gen_packages_items():
if exists(require_fpath):
for info in parse_require_file(require_fpath):
parts = [info['package']]
if with_version and 'version' in info:
parts.extend(info['version'])
if not sys.version.startswith('3.4'):
# apparently package_deps are broken in 3.4
platform_deps = info.get('platform_deps')
if platform_deps is not None:
parts.append(';' + platform_deps)
item = ''.join(parts)
yield item
packages = list(gen_packages_items())
return packages
if __name__ == '__main__':
with Path(Path(__file__).parent,
'README.md').open(encoding='utf-8') as file:
long_description = file.read()
setup(
name='lagent',
packages=find_packages(),
include_package_data=True,
version=get_version(),
license='Apache 2.0',
description='An open-source framework for language agent research.',
long_description=long_description,
long_description_content_type='text/markdown',
data_files=[('.', ['README.md'])],
keywords=['artificial general intelligence', 'agent', 'agi', 'llm'],
install_requires=parse_requirements('requirements/runtime.txt'),
extras_require={
'all': parse_requirements('requirements.txt'),
'optional': parse_requirements('requirements/optional.txt'),
},
)
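A sketch of what parse_requirements produces when run from the repo root; the demo-reqs.txt file name and its contents are hypothetical, written only for illustration.

import pathlib

from setup import parse_requirements

# Version pins are preserved and a platform marker is re-attached after them.
pathlib.Path('demo-reqs.txt').write_text(
    'torch>=1.10; platform_system=="Linux"\nrequests\n')
print(parse_requirements('demo-reqs.txt'))
# -> ['torch>=1.10;platform_system=="Linux"', 'requests']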

144
tests/data/search.json Normal file
View File

@@ -0,0 +1,144 @@
{
"searchParameters": {
"q": "What is the capital of China?",
"gl": "us",
"hl": "en",
"num": 10,
"type": "search"
},
"answerBox": {
"title": "China / Capital",
"answer": "Beijing"
},
"organic": [
{
"title": "Beijing - Wikipedia",
"link": "https://en.wikipedia.org/wiki/Beijing",
"snippet": "With over 21 million residents, Beijing is the world's most populous national capital city as well as China's second largest city after Shanghai.",
"sitelinks": [
{
"title": "Etymology",
"link": "https://en.wikipedia.org/wiki/Beijing#Etymology"
},
{
"title": "History",
"link": "https://en.wikipedia.org/wiki/Beijing#History"
},
{
"title": "Geography",
"link": "https://en.wikipedia.org/wiki/Beijing#Geography"
}
],
"position": 1
},
{
"title": "What is the Capital of China? - Mappr",
"link": "https://www.mappr.co/capital-cities/china/",
"snippet": "Beijing, also known as Peking, is the capital of China and one of the most populous cities in the world. It is the country's political, educational, ...",
"position": 2
},
{
"title": "Google Map of the City of Beijing, capital of P.R. China - Nations Online Project",
"link": "https://www.nationsonline.org/oneworld/map/google_map_Beijing.htm",
"snippet": "Google Earth: Searchable map/satellite view of Beijing, capital city of P.R. China. City Coordinates: 39°5450″N 116°2330″E, Bookmark/share this page ...",
"position": 3
},
{
"title": "Capital City of China - CountryReports.org",
"link": "https://www.countryreports.org/country/china/capital-city.htm",
"snippet": "Capital City, Beijing ; Capital location, 39 55 N, 116 23 E ; Capital - history, (Peking) Founded about 3,000 years ago on the site of a former Chinese capital, ...",
"position": 4
},
{
"title": "Capital of China - Beijing",
"link": "https://www.chinahighlights.com/beijing/capital-of-china.htm",
"snippet": "In Chinese, Beijing means 'Northern Capital'. Bei means 'north' and jing means 'capital'. In the history of China, Beijing's name has changed ...",
"date": "Dec 15, 2022",
"position": 5
},
{
"title": "Beijing is the capital of the People's Republic of China. It is the world's most populous capital city, with over 21 million residents within an... | By DL&D Consult | Facebook",
"link": "https://facebook.com/dldconsult/videos/beijing-capital-city-of-china/373001500555301/",
"snippet": "Beijing is an important world capital and global power ...",
"date": "Oct 19, 2020",
"attributes": {
"Duration": "2:58",
"Posted": "Oct 19, 2020"
},
"imageUrl": "https://encrypted-tbn1.gstatic.com/images?q=tbn:ANd9GcQExx68yUr7xP_1wcRapEKlT5bxe4ptMa6WaLnwXdVAtAdloa7WeTIvCoJp",
"position": 6
},
{
"title": "What is the Capital of China? - WorldAtlas",
"link": "https://www.worldatlas.com/articles/what-is-the-capital-of-china.html",
"snippet": "The capital of China is Beijing.",
"date": "Jul 3, 2018",
"position": 7
},
{
"title": "A Chinese capital that's not Beijing - BBC Travel",
"link": "https://www.bbc.com/travel/article/20151008-a-chinese-capital-thats-not-beijing",
"snippet": "Beijing may be the capital of China today, but for many centuries the country was ruled from Nanjing, a historic city located on the shores ...",
"date": "Oct 13, 2015",
"position": 8
},
{
"title": "Beijing | Province, City, History, Map, & Facts - Britannica",
"link": "https://www.britannica.com/place/Beijing",
"snippet": "Beijing, city, province-level shi (municipality), and capital of the People's Republic of China. The city has been an integral part of China's history over ...",
"position": 9
}
],
"peopleAlsoAsk": [
{
"question": "Does China have 2 capitals?",
"snippet": "There are traditionally four major historical capitals of China referred to as\nthe 'Four Great Ancient Capitals of China' (simplified Chinese: 中国四大古都;\ntraditional Chinese: 中國四大古都; pinyin: Zhōngguó Sì Dà Gǔ Dū). The four are\nBeijing, Nanjing, Luoyang and Xi\"an (Chang\"an).",
"title": "Historical capitals of China - Wikipedia",
"link": "https://en.wikipedia.org/wiki/Historical_capitals_of_China"
},
{
"question": "What is the capital city of China USA?",
"snippet": "Capital City\nBeijing\nCapital - time difference\nUTC+8 (13 hours ahead of Washington, DC, during Standard Time) note; despite its\nsize, all of China falls within one time zone",
"title": "Capital City of China - CountryReports.org",
"link": "https://www.countryreports.org/country/china/capital-city.htm"
},
{
"question": "Is Hong Kong is a part of China?",
"snippet": "Hong Kong (US: /ˈhɒŋkɒŋ/ or UK: /hɒŋˈkɒŋ/; Chinese: 香港, Cantonese: [hœ́ːŋ.kɔ̌ːŋ]\n( listen)), officially the Hong Kong Special Administrative Region of the\nPeople's Republic of China (abbr. Hong Kong SAR or HKSAR), is a city and a\nspecial administrative region in China.",
"title": "Hong Kong - Wikipedia",
"link": "https://en.wikipedia.org/wiki/Hong_Kong"
},
{
"question": "Why China changed its capital?",
"snippet": "Once in charge, it wasn't uncommon for a new emperor to shift the imperial\ncapital in order to: Rebuild after a great loss, as in the Han era when Liu Bang\nmoved the capital from Xianyang to nearby Chang'an (now Xi'an), after the former\nwas destroyed during a rebellion.",
"title": "Why do they keep moving the Capital? - China Simplified",
"link": "https://www.chinasimplified.com/2014/09/29/why-do-they-keep-moving-the-capital/"
}
],
"relatedSearches": [
{
"query": "China map"
},
{
"query": "Where is the capital of china on a map"
},
{
"query": "Beijing population"
},
{
"query": "Capital of Korea"
},
{
"query": "Beijing pronunciation"
},
{
"query": "What is the capital of India"
},
{
"query": "What is the capital of Japan"
},
{
"query": "What is the capital of Pakistan"
}
]
}

View File

@@ -0,0 +1,45 @@
from unittest import TestCase
from lagent.actions.builtin_actions import (FinishAction, InvalidAction,
NoAction)
from lagent.schema import ActionStatusCode
class TestFinishAction(TestCase):
def test_call(self):
action = FinishAction()
response = 'finish'
action_return = action(response)
self.assertEqual(action_return.state, ActionStatusCode.SUCCESS)
self.assertDictEqual(action_return.result, dict(text='finish'))
class TestInvalidAction(TestCase):
def test_call(self):
action = InvalidAction()
response = 'invalid'
action_return = action(response)
self.assertEqual(action_return.state, ActionStatusCode.API_ERROR)
self.assertEqual(action_return.errmsg, response)
action = InvalidAction(err_msg='error')
action_return = action()
self.assertEqual(action_return.state, ActionStatusCode.API_ERROR)
self.assertEqual(action_return.errmsg, 'error')
class TestNoAction(TestCase):
def test_call(self):
action = NoAction()
response = 'no'
action_return = action(response)
self.assertEqual(action_return.state, ActionStatusCode.API_ERROR)
self.assertEqual(action_return.errmsg, response)
action = NoAction(err_msg='error')
action_return = action()
self.assertEqual(action_return.state, ActionStatusCode.API_ERROR)
self.assertEqual(action_return.errmsg, 'error')

View File

@@ -0,0 +1,21 @@
from unittest import TestCase
from lagent.actions.python_interpreter import PythonInterpreter
from lagent.schema import ActionStatusCode
class TestPythonInterpreter(TestCase):
def test_python_executor(self):
python_executor = PythonInterpreter()
tool_return = python_executor(
'```python\ndef solution():\n return 1\n```')
self.assertEqual(tool_return.state, ActionStatusCode.SUCCESS)
self.assertDictEqual(tool_return.result, dict(text='1'))
def test_timeout(self):
python_executor = PythonInterpreter(timeout=2)
tool_return = python_executor(
'```python\ndef solution():\n while True:\n pass\n```')
self.assertEqual(tool_return.state, ActionStatusCode.API_ERROR)
self.assertIn('FunctionTimedOut', tool_return.errmsg)

View File

@@ -0,0 +1,35 @@
import json
from unittest import TestCase, mock
from lagent.actions import SerperSearch
from lagent.schema import ActionStatusCode
class TestSerperSearch(TestCase):
@mock.patch.object(SerperSearch, '_search')
def test_search_tool(self, mock_search_func):
        with open('tests/data/search.json') as f:
            mock_response = (200, json.load(f))
mock_search_func.return_value = mock_response
search_tool = SerperSearch(api_key='abc')
tool_return = search_tool.run("What's the capital of China?")
self.assertEqual(tool_return.state, ActionStatusCode.SUCCESS)
self.assertDictEqual(tool_return.result, dict(text="['Beijing']"))
@mock.patch.object(SerperSearch, '_search')
def test_api_error(self, mock_search_func):
mock_response = (403, {'message': 'bad requests'})
mock_search_func.return_value = mock_response
search_tool = SerperSearch(api_key='abc')
tool_return = search_tool.run("What's the capital of China?")
self.assertEqual(tool_return.state, ActionStatusCode.API_ERROR)
self.assertEqual(tool_return.errmsg, str(403))
@mock.patch.object(SerperSearch, '_search')
def test_http_error(self, mock_search_func):
mock_response = (-1, 'HTTPSConnectionPool')
mock_search_func.return_value = mock_response
search_tool = SerperSearch(api_key='abc')
tool_return = search_tool.run("What's the capital of China?")
self.assertEqual(tool_return.state, ActionStatusCode.HTTP_ERROR)
self.assertEqual(tool_return.errmsg, 'HTTPSConnectionPool')

View File

@@ -0,0 +1,87 @@
from unittest import TestCase, mock
from lagent.actions import ActionExecutor
from lagent.actions.llm_qa import LLMQA
from lagent.actions.serper_search import SerperSearch
from lagent.agents.rewoo import ReWOO, ReWOOProtocol
from lagent.schema import ActionReturn, ActionStatusCode
class TestReWOO(TestCase):
@mock.patch.object(SerperSearch, 'run')
@mock.patch.object(LLMQA, 'run')
@mock.patch.object(ReWOOProtocol, 'parse_worker')
def test_normal_chat(self, mock_parse_worker_func, mock_qa_func,
mock_search_func):
mock_model = mock.Mock()
mock_model.generate_from_template.return_value = 'LLM response'
        mock_parse_worker_func.return_value = (['Thought1', 'Thought2'],
                                               ['LLMQA', 'SerperSearch'],
                                               ['abc', 'abc'])
search_return = ActionReturn(args=None)
search_return.state = ActionStatusCode.SUCCESS
search_return.result = dict(text='search_return')
mock_search_func.return_value = search_return
qa_return = ActionReturn(args=None)
qa_return.state = ActionStatusCode.SUCCESS
qa_return.result = dict(text='qa_return')
mock_qa_func.return_value = qa_return
chatbot = ReWOO(
llm=mock_model,
action_executor=ActionExecutor(actions=[
LLMQA(mock_model),
SerperSearch(api_key=''),
]))
agent_return = chatbot.chat('abc')
self.assertEqual(agent_return.response, 'LLM response')
def test_parse_worker(self):
prompt = ReWOOProtocol()
message = """
Plan: a.
#E1 = tool1["a"]
#E2 = tool2["b"]
"""
try:
thoughts, actions, actions_input = prompt.parse_worker(message)
except Exception as e:
self.assertEqual(
'Each Plan should only correspond to only ONE action', str(e))
else:
self.assertFalse(
True, 'it should raise exception when the format is incorrect')
message = """
Plan: a.
#E1 = tool1("a")
Plan: b.
#E2 = tool2["b"]
"""
try:
thoughts, actions, actions_input = prompt.parse_worker(message)
except Exception as e:
self.assertIsInstance(e, BaseException)
else:
self.assertFalse(
True, 'it should raise exception when the format is incorrect')
message = """
Plan: a.
#E1 = tool1["a"]
Plan: b.
#E2 = tool2["b"]
"""
try:
thoughts, actions, actions_input = prompt.parse_worker(message)
except Exception:
self.assertFalse(
True,
'it should not raise exception when the format is correct')
self.assertEqual(thoughts, ['a.', 'b.'])
self.assertEqual(actions, ['tool1', 'tool2'])
self.assertEqual(actions_input, ['"a"', '"b"'])