[Init] Initial commit
* Initial commit
1
.gitignore
vendored
@@ -158,3 +158,4 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.vscode/
50
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,50 @@
|
||||
exclude: ^(tests/data|scripts|ftdp/protocols|ftdp/template_configs|ftdp/tool_dicts)/
|
||||
repos:
|
||||
- repo: https://github.com/PyCQA/flake8
|
||||
rev: 5.0.4
|
||||
hooks:
|
||||
- id: flake8
|
||||
- repo: https://github.com/PyCQA/isort
|
||||
rev: 5.11.5
|
||||
hooks:
|
||||
- id: isort
|
||||
- repo: https://github.com/pre-commit/mirrors-yapf
|
||||
rev: v0.32.0
|
||||
hooks:
|
||||
- id: yapf
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.3.0
|
||||
hooks:
|
||||
- id: trailing-whitespace
|
||||
- id: check-yaml
|
||||
- id: end-of-file-fixer
|
||||
- id: requirements-txt-fixer
|
||||
- id: double-quote-string-fixer
|
||||
- id: check-merge-conflict
|
||||
- id: fix-encoding-pragma
|
||||
args: ["--remove"]
|
||||
- id: mixed-line-ending
|
||||
args: ["--fix=lf"]
|
||||
- repo: https://github.com/executablebooks/mdformat
|
||||
rev: 0.7.9
|
||||
hooks:
|
||||
- id: mdformat
|
||||
args: ["--number"]
|
||||
additional_dependencies:
|
||||
- mdformat-openmmlab
|
||||
- mdformat_frontmatter
|
||||
- linkify-it-py
|
||||
- repo: https://github.com/codespell-project/codespell
|
||||
rev: v2.2.1
|
||||
hooks:
|
||||
- id: codespell
|
||||
- repo: https://github.com/myint/docformatter
|
||||
rev: v1.3.1
|
||||
hooks:
|
||||
- id: docformatter
|
||||
args: ["--in-place", "--wrap-descriptions", "79"]
|
||||
- repo: https://github.com/asottile/pyupgrade
|
||||
rev: v3.0.0
|
||||
hooks:
|
||||
- id: pyupgrade
|
||||
args: ["--py36-plus"]
|
||||
428
.pylintrc
Normal file
@@ -0,0 +1,428 @@
|
||||
# This Pylint rcfile contains a best-effort configuration to uphold the
|
||||
# best-practices and style described in the Google Python style guide:
|
||||
# https://google.github.io/styleguide/pyguide.html
|
||||
#
|
||||
# Its canonical open-source location is:
|
||||
# https://google.github.io/styleguide/pylintrc
|
||||
|
||||
[MASTER]
|
||||
|
||||
# Files or directories to be skipped. They should be base names, not paths.
|
||||
ignore=third_party,storage
|
||||
|
||||
# Files or directories matching the regex patterns are skipped. The regex
|
||||
# matches against base names, not paths.
|
||||
ignore-patterns=
|
||||
|
||||
# Pickle collected data for later comparisons.
|
||||
persistent=no
|
||||
|
||||
# List of plugins (as comma separated values of python modules names) to load,
|
||||
# usually to register additional checkers.
|
||||
load-plugins=
|
||||
|
||||
# Use multiple processes to speed up Pylint.
|
||||
jobs=4
|
||||
|
||||
# Allow loading of arbitrary C extensions. Extensions are imported into the
|
||||
# active Python interpreter and may run arbitrary code.
|
||||
unsafe-load-any-extension=no
|
||||
|
||||
|
||||
[MESSAGES CONTROL]
|
||||
|
||||
# Only show warnings with the listed confidence levels. Leave empty to show
|
||||
# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
|
||||
confidence=
|
||||
|
||||
# Enable the message, report, category or checker with the given id(s). You can
|
||||
# either give multiple identifiers separated by comma (,) or put this option
|
||||
# multiple time (only on the command line, not in the configuration file where
|
||||
# it should appear only once). See also the "--disable" option for examples.
|
||||
#enable=
|
||||
|
||||
# Disable the message, report, category or checker with the given id(s). You
|
||||
# can either give multiple identifiers separated by comma (,) or put this
|
||||
# option multiple times (only on the command line, not in the configuration
|
||||
# file where it should appear only once). You can also use "--disable=all" to
|
||||
# disable everything first and then reenable specific checks. For example, if
|
||||
# you want to run only the similarities checker, you can use "--disable=all
|
||||
# --enable=similarities". If you want to run only the classes checker, but have
|
||||
# no Warning level messages displayed, use "--disable=all --enable=classes
|
||||
# --disable=W"
|
||||
disable=abstract-method,
|
||||
apply-builtin,
|
||||
arguments-differ,
|
||||
attribute-defined-outside-init,
|
||||
backtick,
|
||||
bad-option-value,
|
||||
basestring-builtin,
|
||||
buffer-builtin,
|
||||
c-extension-no-member,
|
||||
consider-using-enumerate,
|
||||
cmp-builtin,
|
||||
cmp-method,
|
||||
coerce-builtin,
|
||||
coerce-method,
|
||||
delslice-method,
|
||||
div-method,
|
||||
duplicate-code,
|
||||
eq-without-hash,
|
||||
execfile-builtin,
|
||||
file-builtin,
|
||||
filter-builtin-not-iterating,
|
||||
fixme,
|
||||
getslice-method,
|
||||
global-statement,
|
||||
hex-method,
|
||||
idiv-method,
|
||||
implicit-str-concat,
|
||||
import-error,
|
||||
import-self,
|
||||
import-star-module-level,
|
||||
inconsistent-return-statements,
|
||||
input-builtin,
|
||||
intern-builtin,
|
||||
invalid-str-codec,
|
||||
locally-disabled,
|
||||
long-builtin,
|
||||
long-suffix,
|
||||
map-builtin-not-iterating,
|
||||
misplaced-comparison-constant,
|
||||
missing-function-docstring,
|
||||
metaclass-assignment,
|
||||
next-method-called,
|
||||
next-method-defined,
|
||||
no-absolute-import,
|
||||
no-else-break,
|
||||
no-else-continue,
|
||||
no-else-raise,
|
||||
no-else-return,
|
||||
no-init, # added
|
||||
no-member,
|
||||
no-name-in-module,
|
||||
no-self-use,
|
||||
nonzero-method,
|
||||
oct-method,
|
||||
old-division,
|
||||
old-ne-operator,
|
||||
old-octal-literal,
|
||||
old-raise-syntax,
|
||||
parameter-unpacking,
|
||||
print-statement,
|
||||
raising-string,
|
||||
range-builtin-not-iterating,
|
||||
raw_input-builtin,
|
||||
rdiv-method,
|
||||
reduce-builtin,
|
||||
relative-import,
|
||||
reload-builtin,
|
||||
round-builtin,
|
||||
setslice-method,
|
||||
signature-differs,
|
||||
standarderror-builtin,
|
||||
suppressed-message,
|
||||
sys-max-int,
|
||||
too-few-public-methods,
|
||||
too-many-ancestors,
|
||||
too-many-arguments,
|
||||
too-many-boolean-expressions,
|
||||
too-many-branches,
|
||||
too-many-instance-attributes,
|
||||
too-many-locals,
|
||||
too-many-nested-blocks,
|
||||
too-many-public-methods,
|
||||
too-many-return-statements,
|
||||
too-many-statements,
|
||||
trailing-newlines,
|
||||
unichr-builtin,
|
||||
unicode-builtin,
|
||||
unnecessary-pass,
|
||||
unpacking-in-except,
|
||||
useless-else-on-loop,
|
||||
useless-object-inheritance,
|
||||
useless-suppression,
|
||||
using-cmp-argument,
|
||||
wrong-import-order,
|
||||
xrange-builtin,
|
||||
zip-builtin-not-iterating,
|
||||
|
||||
|
||||
[REPORTS]
|
||||
|
||||
# Set the output format. Available formats are text, parseable, colorized, msvs
|
||||
# (visual studio) and html. You can also give a reporter class, eg
|
||||
# mypackage.mymodule.MyReporterClass.
|
||||
output-format=colorized
|
||||
|
||||
# Tells whether to display a full report or only the messages
|
||||
reports=no
|
||||
|
||||
# Python expression which should return a note less than 10 (10 is the highest
|
||||
# note). You have access to the variables errors warning, statement which
|
||||
# respectively contain the number of errors / warnings messages and the total
|
||||
# number of statements analyzed. This is used by the global evaluation report
|
||||
# (RP0004).
|
||||
evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
|
||||
|
||||
# Template used to display messages. This is a python new-style format string
|
||||
# used to format the message information. See doc for all details
|
||||
#msg-template=
|
||||
|
||||
|
||||
[BASIC]
|
||||
|
||||
# Good variable names which should always be accepted, separated by a comma
|
||||
good-names=main,_
|
||||
|
||||
# Bad variable names which should always be refused, separated by a comma
|
||||
bad-names=
|
||||
|
||||
# Colon-delimited sets of names that determine each other's naming style when
|
||||
# the name regexes allow several styles.
|
||||
name-group=
|
||||
|
||||
# Include a hint for the correct naming format with invalid-name
|
||||
include-naming-hint=no
|
||||
|
||||
# List of decorators that produce properties, such as abc.abstractproperty. Add
|
||||
# to this list to register other decorators that produce valid properties.
|
||||
property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl
|
||||
|
||||
# Regular expression matching correct function names
|
||||
function-rgx=^(?:(?P<exempt>setUp|tearDown|setUpModule|tearDownModule)|(?P<camel_case>_?[A-Z][a-zA-Z0-9]*)|(?P<snake_case>_?[a-z][a-z0-9_]*))$
|
||||
|
||||
# Regular expression matching correct variable names
|
||||
variable-rgx=^[a-z][a-z0-9_]*$
|
||||
|
||||
# Regular expression matching correct constant names
|
||||
const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
|
||||
|
||||
# Regular expression matching correct attribute names
|
||||
attr-rgx=^_{0,2}[a-z][a-z0-9_]*$
|
||||
|
||||
# Regular expression matching correct argument names
|
||||
argument-rgx=^[a-z][a-z0-9_]*$
|
||||
|
||||
# Regular expression matching correct class attribute names
|
||||
class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
|
||||
|
||||
# Regular expression matching correct inline iteration names
|
||||
inlinevar-rgx=^[a-z][a-z0-9_]*$
|
||||
|
||||
# Regular expression matching correct class names
|
||||
class-rgx=^_?[A-Z][a-zA-Z0-9]*$
|
||||
|
||||
# Regular expression matching correct module names
|
||||
module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$
|
||||
|
||||
# Regular expression matching correct method names
|
||||
method-rgx=(?x)^(?:(?P<exempt>_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P<camel_case>_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P<snake_case>_{0,2}[a-z][a-z0-9_]*))$
|
||||
|
||||
# Regular expression which should only match function or class names that do
|
||||
# not require a docstring.
|
||||
no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$
|
||||
|
||||
# Minimum line length for functions/classes that require docstrings, shorter
|
||||
# ones are exempt.
|
||||
docstring-min-length=10
|
||||
|
||||
|
||||
[TYPECHECK]
|
||||
|
||||
# List of decorators that produce context managers, such as
|
||||
# contextlib.contextmanager. Add to this list to register other decorators that
|
||||
# produce valid context managers.
|
||||
contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager
|
||||
|
||||
# Tells whether missing members accessed in mixin class should be ignored. A
|
||||
# mixin class is detected if its name ends with "mixin" (case insensitive).
|
||||
ignore-mixin-members=yes
|
||||
|
||||
# List of module names for which member attributes should not be checked
|
||||
# (useful for modules/projects where namespaces are manipulated during runtime
|
||||
# and thus existing member attributes cannot be deduced by static analysis. It
|
||||
# supports qualified module names, as well as Unix pattern matching.
|
||||
ignored-modules=
|
||||
|
||||
# List of class names for which member attributes should not be checked (useful
|
||||
# for classes with dynamically set attributes). This supports the use of
|
||||
# qualified names.
|
||||
ignored-classes=optparse.Values,thread._local,_thread._local
|
||||
|
||||
# List of members which are set dynamically and missed by pylint inference
|
||||
# system, and so shouldn't trigger E1101 when accessed. Python regular
|
||||
# expressions are accepted.
|
||||
generated-members=
|
||||
|
||||
|
||||
[FORMAT]
|
||||
|
||||
# Maximum number of characters on a single line.
|
||||
max-line-length=120
|
||||
|
||||
# TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt
|
||||
# lines made too long by directives to pytype.
|
||||
|
||||
# Regexp for a line that is allowed to be longer than the limit.
|
||||
ignore-long-lines=(?x)(
|
||||
^\s*(\#\ )?<?https?://\S+>?$|
|
||||
^\s*(from\s+\S+\s+)?import\s+.+$)
|
||||
|
||||
# Allow the body of an if to be on the same line as the test if there is no
|
||||
# else.
|
||||
single-line-if-stmt=yes
|
||||
|
||||
# Maximum number of lines in a module
|
||||
max-module-lines=99999
|
||||
|
||||
# String used as indentation unit. The internal Google style guide mandates 2
|
||||
# spaces. Google's externally-published style guide says 4, consistent with
|
||||
# PEP 8. Here, we use 2 spaces, for conformity with many open-sourced Google
|
||||
# projects (like TensorFlow).
|
||||
indent-string=' '
|
||||
|
||||
# Number of spaces of indent required inside a hanging or continued line.
|
||||
indent-after-paren=4
|
||||
|
||||
# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
|
||||
expected-line-ending-format=
|
||||
|
||||
|
||||
[MISCELLANEOUS]
|
||||
|
||||
# List of note tags to take in consideration, separated by a comma.
|
||||
notes=TODO
|
||||
|
||||
|
||||
[STRING]
|
||||
|
||||
# This flag controls whether inconsistent-quotes generates a warning when the
|
||||
# character used as a quote delimiter is used inconsistently within a module.
|
||||
check-quote-consistency=yes
|
||||
|
||||
|
||||
[VARIABLES]
|
||||
|
||||
# Tells whether we should check for unused import in __init__ files.
|
||||
init-import=no
|
||||
|
||||
# A regular expression matching the name of dummy variables (i.e. expectedly
|
||||
# not used).
|
||||
dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_)
|
||||
|
||||
# List of additional names supposed to be defined in builtins. Remember that
|
||||
# you should avoid to define new builtins when possible.
|
||||
additional-builtins=
|
||||
|
||||
# List of strings which can identify a callback function by name. A callback
|
||||
# name must start or end with one of those strings.
|
||||
callbacks=cb_,_cb
|
||||
|
||||
# List of qualified module names which can have objects that can redefine
|
||||
# builtins.
|
||||
redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools
|
||||
|
||||
|
||||
[LOGGING]
|
||||
|
||||
# Logging modules to check that the string format arguments are in logging
|
||||
# function parameter format
|
||||
logging-modules=logging,absl.logging,tensorflow.io.logging
|
||||
|
||||
|
||||
[SIMILARITIES]
|
||||
|
||||
# Minimum lines number of a similarity.
|
||||
min-similarity-lines=4
|
||||
|
||||
# Ignore comments when computing similarities.
|
||||
ignore-comments=yes
|
||||
|
||||
# Ignore docstrings when computing similarities.
|
||||
ignore-docstrings=yes
|
||||
|
||||
# Ignore imports when computing similarities.
|
||||
ignore-imports=no
|
||||
|
||||
|
||||
[SPELLING]
|
||||
|
||||
# Spelling dictionary name. Available dictionaries: none. To make it work,
# install the python-enchant package.
|
||||
spelling-dict=
|
||||
|
||||
# List of comma separated words that should not be checked.
|
||||
spelling-ignore-words=
|
||||
|
||||
# A path to a file that contains private dictionary; one word per line.
|
||||
spelling-private-dict-file=
|
||||
|
||||
# Tells whether to store unknown words to indicated private dictionary in
|
||||
# --spelling-private-dict-file option instead of raising a message.
|
||||
spelling-store-unknown-words=no
|
||||
|
||||
|
||||
[IMPORTS]
|
||||
|
||||
# Deprecated modules which should not be used, separated by a comma
|
||||
deprecated-modules=regsub,
|
||||
TERMIOS,
|
||||
Bastion,
|
||||
rexec,
|
||||
sets
|
||||
|
||||
# Create a graph of every (i.e. internal and external) dependencies in the
|
||||
# given file (report RP0402 must not be disabled)
|
||||
import-graph=
|
||||
|
||||
# Create a graph of external dependencies in the given file (report RP0402 must
|
||||
# not be disabled)
|
||||
ext-import-graph=
|
||||
|
||||
# Create a graph of internal dependencies in the given file (report RP0402 must
|
||||
# not be disabled)
|
||||
int-import-graph=
|
||||
|
||||
# Force import order to recognize a module as part of the standard
|
||||
# compatibility libraries.
|
||||
known-standard-library=
|
||||
|
||||
# Force import order to recognize a module as part of a third party library.
|
||||
known-third-party=enchant, absl
|
||||
|
||||
# Analyse import fallback blocks. This can be used to support both Python 2 and
|
||||
# 3 compatible code, which means that the block might have code that exists
|
||||
# only in one or another interpreter, leading to false positives when analysed.
|
||||
analyse-fallback-blocks=no
|
||||
|
||||
|
||||
[CLASSES]
|
||||
|
||||
# List of method names used to declare (i.e. assign) instance attributes.
|
||||
defining-attr-methods=__init__,
|
||||
__new__,
|
||||
setUp
|
||||
|
||||
# List of member names, which should be excluded from the protected access
|
||||
# warning.
|
||||
exclude-protected=_asdict,
|
||||
_fields,
|
||||
_replace,
|
||||
_source,
|
||||
_make
|
||||
|
||||
# List of valid names for the first argument in a class method.
|
||||
valid-classmethod-first-arg=cls,
|
||||
class_
|
||||
|
||||
# List of valid names for the first argument in a metaclass class method.
|
||||
valid-metaclass-classmethod-first-arg=mcs
|
||||
|
||||
|
||||
[EXCEPTIONS]
|
||||
|
||||
# Exceptions that will emit a warning when being caught. Defaults to
|
||||
# "Exception"
|
||||
overgeneral-exceptions=builtins.BaseException,
|
||||
builtins.Exception
|
||||
81
README.md
Normal file
@@ -0,0 +1,81 @@
# LAgent: Large Language Model as Agent

English | [简体中文](README_zh-CN.md)

## Introduction

LAgent is an open-source LLM agent framework that enables users to efficiently turn a large language model into an agent. It also provides some typical tools to augment the capabilities of LLMs. An overview of our framework is shown below:

![image](https://github.com/InternLM/lagent/assets/24622904/ba33b28f-2ab4-4b2e-a129-4bae4bb00c3e)

### Major Features

- **Supports multiple agent frameworks out of the box.** We implement ReAct, AutoGPT and ReWOO, which enable the agents to drive LLMs through multiple rounds of reasoning and tool invocation.

- **Extremely simple and easy to extend.** The framework is quite simple, with a clear project structure. With only 20 lines of code, you are able to construct your own agent. It also supports three typical tools: Python interpreter, API call, and Google search.

- **Supports various large language models.** We support different LLMs, including API-based (GPT-3.5/4) and HuggingFace-based (Llama 2, InternLM) models.

## Getting Started

Please see [Overview](docs/overview.md) for a general introduction to LAgent. Meanwhile, we provide extremely simple code for a quick start. You may refer to [examples](examples/) for more details.

### Installation

```
git clone https://github.com/InternLM/lagent.git
cd lagent
pip install -e .
```

### Run a ReAct agent with the GPT-3.5 backend

```python
from lagent.agents import ReAct
from lagent.actions.action_executor import ActionExecutor
from lagent.llms import GPTAPI
from lagent.tools import SerperSearch, PythonInterpreter

llm = GPTAPI(model_type='gpt-3.5-turbo')
search_tool = SerperSearch()
python_interpreter = PythonInterpreter()

chatbot = ReAct(
    llm=llm,
    action_executor=ActionExecutor(
        actions=[search_tool, python_interpreter]),
)

response = chatbot.chat('What profession does Nicholas Ray and Elia Kazan have in common')
print(response['response'])
>>> They are both film directors.
```

### Run a ReAct agent with a HuggingFace backend

NOTE: If you want to run a HuggingFace model, please run `pip install -e .[all]` first.

```python
from lagent.agents import ReAct
from lagent.actions.action_executor import ActionExecutor
from lagent.llms import HFTransformer
from lagent.tools import SerperSearch, PythonInterpreter

llm = HFTransformer('internlm/internlm-chat-7b')
search_tool = SerperSearch()
python_interpreter = PythonInterpreter()

chatbot = ReAct(
    llm=llm,
    action_executor=ActionExecutor(
        actions=[search_tool, python_interpreter]),
)

response = chatbot.chat('若$z=-1+\sqrt{3}i$,则$\frac{z}{{z\overline{z}-1}}=\left(\ \ \right)$ (A) $-1+\sqrt{3}i$ (B) $-1-\sqrt{3}i$ (C) $-\frac{1}{3}+\frac{{\sqrt{3}}}{3}i$ (D) $-\frac{1}{3}-\frac{{\sqrt{3}}}{3}i$')
print(response['response'])
>>> 根据已有的信息,可以求得$z=-1+\\sqrt{3}i$,然后代入计算,得到结果为$-\\frac{1}{3}+\\frac{{\\sqrt{3}}}{3}i$。因此,答案是(C)。
```

## License

This project is released under the [Apache 2.0 license](LICENSE).
80
README_zh-CN.md
Normal file
@@ -0,0 +1,80 @@
|
||||
# LAgent: Large Language Model as Agent
|
||||
|
||||
[English](README.md) | 简体中文
|
||||
|
||||
## 简介
|
||||
|
||||
LAgent是一个开源的LLM代理框架,支持用户快速地将一个大语言模型转变为多种类型的智能体,并提供了一些典型工具为大语言模型赋能。它的整个框架图如下:
|
||||
|
||||

|
||||
|
||||
### 主要特点
|
||||
|
||||
- **实现了多种类型的智能体,** 我们支持了经典的 ReAct,AutoGPT 和 ReWoo 等智能体,这些智能体能够调用大语言模型进行多轮的推理和工具调用。
|
||||
|
||||
- **框架简单易拓展.** 框架的代码结构清晰且简单,只需要不到20行代码你就能够创造出一个你自己的agent。同时我们支持了Python解释器、API 调用和搜索三类常用典型工具。
|
||||
|
||||
- **灵活支持多个大语言模型.** 我们提供了多种大语言模型支持,包括 InternLM、Llama-2 等开源模型和 GPT-4/3.5 等基于 API 的闭源模型。
|
||||
|
||||
## 教程
|
||||
|
||||
请阅读[概述](docs/overview.md)对LAgent进行初步的了解。同时, 我们提供了两个非常简单的code帮助你快速入门。 你也可以阅读[examples](examples/)获得更多的例子参考。
|
||||
|
||||
### 安装
|
||||
|
||||
```
|
||||
git clone https://github.com/InternLM/lagent.git
|
||||
cd lagent
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
### 用GPT3.5构建一个ReAct代理
|
||||
|
||||
```python
|
||||
from lagent.agents import ReAct
from lagent.actions.action_executor import ActionExecutor
|
||||
from lagent.llms import GPTAPI
|
||||
from lagent.tools import SerperSearch, PythonInterpreter
|
||||
|
||||
llm = GPTAPI(model_type='gpt-3.5-turbo')
|
||||
search_tool = SerperSearch()
|
||||
python_interpreter = PythonInterpreter()
|
||||
|
||||
chatbot = ReAct(
|
||||
llm=llm,
|
||||
action_executor=ActionExecutor(
|
||||
actions=[search_tool, python_interpreter]),
|
||||
)
|
||||
|
||||
response = chatbot.chat('What profession does Nicholas Ray and Elia Kazan have in common')
|
||||
print(response['response'])
|
||||
>>> They are both film directors.
|
||||
```
|
||||
|
||||
### 用HuggingFace构建一个ReAct代理
|
||||
|
||||
注意:如果你想要启动一个HuggingFace的模型,请先运行`pip install -e .[all]`。
|
||||
|
||||
```python
|
||||
from lagent.agents import ReAct
|
||||
from lagent.actions.action_executor import ActionExecutor
|
||||
from lagent.llms import HFTransformer
|
||||
from lagent.tools import SerperSearch, PythonInterpreter
|
||||
|
||||
llm = HFTransformer('internlm/internlm-chat-7b')
|
||||
search_tool = SerperSearch()
|
||||
python_interpreter = PythonInterpreter()
|
||||
|
||||
chatbot = ReAct(
|
||||
llm=llm,
|
||||
action_executor=ActionExecutor(
|
||||
actions=[search_tool, python_interpreter]),
|
||||
)
|
||||
|
||||
response = chatbot.chat('若$z=-1+\sqrt{3}i$,则$\frac{z}{{z\overline{z}-1}}=\left(\ \ \right)$ (A) $-1+\sqrt{3}i$ (B) $-1-\sqrt{3}i$ (C) $-\frac{1}{3}+\frac{{\sqrt{3}}}{3}i$ (D) $-\frac{1}{3}-\frac{{\sqrt{3}}}{3}i$')
|
||||
print(response['response'])
|
||||
>>> 根据已有的信息,可以求得$z=-1+\\sqrt{3}i$,然后代入计算,得到结果为$-\\frac{1}{3}+\\frac{{\\sqrt{3}}}{3}i$。因此,答案是(C)。
|
||||
```
|
||||
|
||||
## 开源许可证
|
||||
|
||||
该项目采用[Apache 2.0 开源许可证](LICENSE)。
|
||||
0
docs/.gitkeep
Normal file
23
docs/overview.md
Normal file
@@ -0,0 +1,23 @@
# OVERVIEW

This chapter introduces you to the framework of LAgent, and provides links to detailed tutorials about LAgent.

## What is LAgent

LAgent is an open-source LLM agent framework that enables users to efficiently turn a large language model into an agent. It also provides some typical tools to augment the capabilities of LLMs. The whole framework is shown below:

![image](https://github.com/InternLM/lagent/assets/24622904/c26159e3-4d5b-4f85-83ef-bd49a6c8f13a)

LAgent consists of 3 main parts: agents, llms, and actions; a minimal sketch of how they fit together is shown after the list below.

- **agents** provides agent implementations, such as ReAct and AutoGPT.
- **llms** supports various large language models, including open-source models (Llama 2, InternLM) served through HuggingFace, as well as closed-source models such as GPT-3.5/4.
- **actions** contains a series of actions, as well as an action executor to manage all actions.
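
The sketch below mirrors `examples/react_example.py` from this commit and shows how the three parts compose; the query string is only an illustration.

```python
from lagent.actions import ActionExecutor, PythonInterpreter  # actions
from lagent.agents import ReACT                               # agents
from lagent.llms.openai import GPTAPI                         # llms

# Set your OpenAI API key in the environment first (see examples/react_example.py).
llm = GPTAPI(model_type='gpt-3.5-turbo')

# The llm generates text, actions execute tools, and the agent ties them together.
chatbot = ReACT(
    llm=llm,
    action_executor=ActionExecutor(actions=[PythonInterpreter()]),
)

agent_return = chatbot.chat('Compute the sum of the first 100 positive integers.')
print(agent_return.response)
```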

## How to Use

Here is a detailed step-by-step guide to learn more about LAgent:

1. For installation instructions, please see [get_started](get_started.md).

2. We provide several examples of building agents with LAgent in [examples](examples/); you can try one by simply running `python examples/react_example.py`.
46
examples/autogpt_example.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from lagent.actions.action_executor import ActionExecutor
|
||||
from lagent.actions.builtin_actions import FinishAction
|
||||
from lagent.actions.python_interpreter import PythonInterpreter
|
||||
from lagent.agents.autogpt import AutoGPT
|
||||
from lagent.llms.openai import GPTAPI
|
||||
|
||||
|
||||
def input_prompt():
|
||||
print('\ndouble enter to end input >>> ', end='')
|
||||
sentinel = '' # ends when this string is seen
|
||||
return '\n'.join(iter(input, sentinel))
|
||||
|
||||
|
||||
def main():
|
||||
# set OPEN_API_KEY in your environment
|
||||
model = GPTAPI(model_type='gpt-3.5-turbo')
|
||||
|
||||
chatbot = AutoGPT(
|
||||
llm=model,
|
||||
action_executor=ActionExecutor(
|
||||
actions=[
|
||||
PythonInterpreter(),
|
||||
],
|
||||
finish_action=FinishAction(
|
||||
description=(
|
||||
'Goals are accomplished and there is nothing left '
|
||||
'to do. Parameter: {"response": "final response '
|
||||
'for the goal"}')),
|
||||
finish_in_action=True),
|
||||
)
|
||||
|
||||
while True:
|
||||
try:
|
||||
prompt = input_prompt()
|
||||
except UnicodeDecodeError:
|
||||
print('UnicodeDecodeError')
|
||||
continue
|
||||
if prompt == 'exit':
|
||||
exit(0)
|
||||
|
||||
agent_return = chatbot.chat(prompt)
|
||||
print(agent_return.response)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
35
examples/chat.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from argparse import ArgumentParser
|
||||
|
||||
from lagent.llms.openai import GPTAPI
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = ArgumentParser(description='chatbot')
|
||||
parser.add_argument('--mode', default='chat')
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
model = GPTAPI(model_type='gpt-3.5-turbo', )
|
||||
history = []
|
||||
while True:
|
||||
try:
|
||||
prompt = input('>>> ')
|
||||
except UnicodeDecodeError:
|
||||
print('UnicodeDecodeError')
|
||||
continue
|
||||
if prompt == 'exit':
|
||||
exit(0)
|
||||
if args.mode == 'chat':
|
||||
history.append(dict(role='user', content=prompt))
|
||||
response = model.generate_from_template(history, max_out_len=512)
|
||||
history.append(dict(role='assistant', content=response))
|
||||
elif args.mode == 'generate':
|
||||
response = model.generate(prompt, max_out_len=512)
|
||||
print('Assistant:', response)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
37
examples/hf_react_example.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from lagent.actions.action_executor import ActionExecutor
|
||||
from lagent.actions.python_interpreter import PythonInterpreter
|
||||
from lagent.agents.react import ReACT
|
||||
from lagent.llms.huggingface import HFTransformer
|
||||
|
||||
model = HFTransformer(
|
||||
path='internlm/internlm-chat-7b',
|
||||
meta_template=[
|
||||
dict(role='system', begin='<|System|>:', end='<TOKENS_UNUSED_2>\n'),
|
||||
dict(role='user', begin='<|User|>:', end='<eoh>\n'),
|
||||
dict(role='assistant', begin='<|Bot|>:', end='<eoa>\n', generate=True)
|
||||
],
|
||||
)
|
||||
|
||||
chatbot = ReACT(
|
||||
llm=model,
|
||||
action_executor=ActionExecutor(actions=[PythonInterpreter()]),
|
||||
)
|
||||
|
||||
|
||||
def input_prompt():
|
||||
print('\ndouble enter to end input >>> ', end='')
|
||||
sentinel = '' # ends when this string is seen
|
||||
return '\n'.join(iter(input, sentinel))
|
||||
|
||||
|
||||
while True:
|
||||
try:
|
||||
prompt = input_prompt()
|
||||
except UnicodeDecodeError:
|
||||
print('UnicodeDecodeError')
|
||||
continue
|
||||
if prompt == 'exit':
|
||||
exit(0)
|
||||
|
||||
agent_return = chatbot.chat(prompt)
|
||||
print(agent_return.response)
|
||||
36
examples/react_example.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from lagent.actions.action_executor import ActionExecutor
|
||||
from lagent.actions.python_interpreter import PythonInterpreter
|
||||
from lagent.agents.react import ReACT
|
||||
from lagent.llms.openai import GPTAPI
|
||||
|
||||
|
||||
def input_prompt():
|
||||
print('\ndouble enter to end input >>> ', end='')
|
||||
sentinel = '' # ends when this string is seen
|
||||
return '\n'.join(iter(input, sentinel))
|
||||
|
||||
|
||||
def main():
|
||||
# set OPEN_API_KEY in your environment
|
||||
model = GPTAPI(model_type='gpt-3.5-turbo', )
|
||||
|
||||
chatbot = ReACT(
|
||||
llm=model,
|
||||
action_executor=ActionExecutor(actions=[PythonInterpreter()]),
|
||||
)
|
||||
|
||||
while True:
|
||||
try:
|
||||
prompt = input_prompt()
|
||||
except UnicodeDecodeError:
|
||||
print('UnicodeDecodeError')
|
||||
continue
|
||||
if prompt == 'exit':
|
||||
exit(0)
|
||||
|
||||
agent_return = chatbot.chat(prompt)
|
||||
print(agent_return.response)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
20
examples/rewoo_example.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from lagent.actions.action_executor import ActionExecutor
|
||||
from lagent.actions.llm_qa import LLMQA
|
||||
from lagent.actions.serper_search import SerperSearch
|
||||
from lagent.agents.rewoo import ReWOO
|
||||
from lagent.llms.openai import GPTAPI
|
||||
|
||||
model = GPTAPI(model_type='gpt-3.5-turbo')
|
||||
# please set the serper search API key
|
||||
search_tool = SerperSearch(api_key=None)
|
||||
llmqa_tool = LLMQA(model)
|
||||
|
||||
chatbot = ReWOO(
|
||||
llm=model,
|
||||
action_executor=ActionExecutor(actions=[llmqa_tool, search_tool]),
|
||||
)
|
||||
|
||||
prompt = 'What profession does Nicholas Ray and Elia Kazan have in common'
|
||||
|
||||
agent_return = chatbot.chat(prompt)
|
||||
print(agent_return.response)
|
||||
10
lagent/actions/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from .action_executor import ActionExecutor
|
||||
from .base_action import BaseAction
|
||||
from .builtin_actions import FinishAction, InvalidAction, NoAction
|
||||
from .python_interpreter import PythonInterpreter
|
||||
from .serper_search import SerperSearch
|
||||
|
||||
__all__ = [
|
||||
'BaseAction', 'ActionExecutor', 'InvalidAction', 'NoAction',
|
||||
'FinishAction', 'SerperSearch', 'PythonInterpreter'
|
||||
]
|
||||
84
lagent/actions/action_executor.py
Normal file
@@ -0,0 +1,84 @@
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from lagent.schema import ActionReturn, ActionValidCode
|
||||
from .base_action import BaseAction
|
||||
from .builtin_actions import FinishAction, InvalidAction, NoAction
|
||||
|
||||
|
||||
class ActionExecutor:
|
||||
"""The action executor class.
|
||||
|
||||
Args:
|
||||
actions (Union[BaseAction, List[BaseAction]]): The action or actions.
|
||||
invalid_action (BaseAction, optional): The invalid action. Defaults to
|
||||
InvalidAction().
|
||||
no_action (BaseAction, optional): The no action.
|
||||
Defaults to NoAction().
|
||||
finish_action (BaseAction, optional): The finish action. Defaults to
|
||||
FinishAction().
|
||||
finish_in_action (bool, optional): Whether the finish action is in the
|
||||
action list. Defaults to False.
|
||||
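
    Example (an illustrative sketch; the inline script is only a stand-in):
        >>> from lagent.actions import PythonInterpreter
        >>> executor = ActionExecutor(actions=[PythonInterpreter()])
        >>> # Call a registered action by name; a plain string command is passed
        >>> # through as the single positional argument of that action.
        >>> ret = executor('PythonInterpreter', 'def solution():\n    return 1 + 1')
        >>> ret.result['text']
        '2'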
"""
|
||||
|
||||
def __init__(self,
|
||||
actions: Union[BaseAction, List[BaseAction]],
|
||||
invalid_action: BaseAction = InvalidAction(),
|
||||
no_action: BaseAction = NoAction(),
|
||||
finish_action: BaseAction = FinishAction(),
|
||||
finish_in_action: bool = False):
|
||||
if isinstance(actions, BaseAction):
|
||||
actions = [actions]
|
||||
|
||||
for action in actions:
|
||||
assert isinstance(action, BaseAction), \
|
||||
f'action must be BaseAction, but got {type(action)}'
|
||||
if finish_in_action:
|
||||
actions.append(finish_action)
|
||||
self.actions = {action.name: action for action in actions}
|
||||
self.invalid_action = invalid_action
|
||||
self.no_action = no_action
|
||||
self.finish_action = finish_action
|
||||
|
||||
def get_actions_info(self, only_enable: bool = True) -> Dict:
|
||||
if only_enable:
|
||||
return {
|
||||
k: v.description
|
||||
for k, v in self.actions.items() if v.enable
|
||||
}
|
||||
else:
|
||||
return {k: v.description for k, v in self.actions.items()}
|
||||
|
||||
def is_valid(self, name: str):
|
||||
return name in self.actions and self.actions[name].enable
|
||||
|
||||
def action_names(self, only_enable: bool = True):
|
||||
if only_enable:
|
||||
return [k for k, v in self.actions.items() if v.enable]
|
||||
else:
|
||||
return list(self.actions.keys())
|
||||
|
||||
def add_action(self, action: BaseAction):
|
||||
assert isinstance(action, BaseAction), \
|
||||
f'action must be BaseAction, but got {type(action)}'
|
||||
self.actions[action.name] = action
|
||||
|
||||
def del_action(self, name: str):
|
||||
if name in self.actions:
|
||||
del self.actions[name]
|
||||
|
||||
def __call__(self, name: str, command: Any) -> ActionReturn:
|
||||
if isinstance(command, str):
|
||||
args, kwargs = (command, ), {}
|
||||
else:
|
||||
args, kwargs = (), command
|
||||
if not self.is_valid(name):
|
||||
if name == self.no_action.name:
|
||||
action_return = self.no_action.run(*args, **kwargs)
|
||||
elif name == self.finish_action.name:
|
||||
action_return = self.finish_action.run(*args, **kwargs)
|
||||
else:
|
||||
action_return = self.invalid_action(*args, **kwargs)
|
||||
else:
|
||||
action_return = self.actions[name].run(*args, **kwargs)
|
||||
action_return.valid = ActionValidCode.OPEN
|
||||
return action_return
|
||||
57
lagent/actions/base_action.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from typing import Optional
|
||||
|
||||
from lagent.schema import ActionReturn
|
||||
|
||||
|
||||
class BaseAction:
|
||||
"""Base class for all actions.
|
||||
|
||||
Args:
|
||||
description (str, optional): The description of the action. Defaults to
|
||||
None.
|
||||
name (str, optional): The name of the action. If None, the name will
|
||||
be the class name. Defaults to None.
|
||||
enable (bool, optional): Whether the action is enabled. Defaults to
|
||||
True.
|
||||
disable_description (str, optional): The description of the action when
|
||||
it is disabled. Defaults to None.
|
||||
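
    Example (a minimal sketch of a custom action; ``EchoAction`` is
    hypothetical and only illustrates the subclassing pattern):
        >>> class EchoAction(BaseAction):
        ...     def __call__(self, text: str) -> ActionReturn:
        ...         return ActionReturn(
        ...             url=None,
        ...             args=dict(text=text),
        ...             result=dict(text=text),
        ...             type=self.name)
        >>> EchoAction(description='Echo the input text.').run('hello').result['text']
        'hello'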
"""
|
||||
|
||||
def __init__(self,
|
||||
description: Optional[str] = None,
|
||||
name: Optional[str] = None,
|
||||
enable: bool = True,
|
||||
disable_description: Optional[str] = None) -> None:
|
||||
if name is None:
|
||||
name = self.__class__.__name__
|
||||
self._name = name
|
||||
self._description = description
|
||||
self._disable_description = disable_description
|
||||
self._enable = enable
|
||||
|
||||
def __call__(self, *args, **kwargs) -> ActionReturn:
|
||||
raise NotImplementedError
|
||||
|
||||
def __repr__(self):
|
||||
return f'{self.name}:{self.description}'
|
||||
|
||||
def __str__(self):
|
||||
return self.__repr__()
|
||||
|
||||
def run(self, *args, **kwargs) -> ActionReturn:
|
||||
return self.__call__(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def enable(self):
|
||||
return self._enable
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
if self.enable:
|
||||
return self._description
|
||||
else:
|
||||
return self._disable_description
|
||||
100
lagent/actions/builtin_actions.py
Normal file
@@ -0,0 +1,100 @@
|
||||
from typing import Optional
|
||||
|
||||
from lagent.actions.base_action import BaseAction
|
||||
from lagent.schema import ActionReturn, ActionStatusCode, ActionValidCode
|
||||
|
||||
|
||||
class InvalidAction(BaseAction):
    """This is an invalid action class, which is used to return an error message
|
||||
when the action is invalid.
|
||||
|
||||
Args:
|
||||
err_msg (str): The error message. Defaults to 'The action is invalid,
|
||||
please check the action name'.
|
||||
|
||||
Returns:
|
||||
ActionReturn: The action return.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
err_msg:
|
||||
str = 'The action is invalid, please check the action name.',
|
||||
**kwargs) -> None:
|
||||
|
||||
super().__init__(enable=False, **kwargs)
|
||||
self._err_msg = err_msg
|
||||
|
||||
def __call__(self, err_msg: Optional[str] = None):
|
||||
"""Return the error message.
|
||||
|
||||
Args:
|
||||
err_msg (str, optional): The error message. If err_msg is not None,
|
||||
it will be returned, otherwise the default error message will
|
||||
be returned. Defaults to None.
|
||||
"""
|
||||
action_return = ActionReturn(
|
||||
url=None,
|
||||
args=dict(text=err_msg),
|
||||
errmsg=err_msg if err_msg else self._err_msg,
|
||||
type=self.name,
|
||||
valid=ActionValidCode.INVALID,
|
||||
state=ActionStatusCode.API_ERROR)
|
||||
return action_return
|
||||
|
||||
|
||||
class NoAction(BaseAction):
    """This is a no-action class, which is used to return an error message when
|
||||
the response does not follow the format.
|
||||
|
||||
Args:
|
||||
err_msg (str): The error message. Defaults to
|
||||
'Please follow the format'.
|
||||
"""
|
||||
|
||||
def __init__(self, err_msg: str = 'Please follow the format', **kwargs):
|
||||
|
||||
super().__init__(enable=False, **kwargs)
|
||||
self._err_msg = err_msg
|
||||
|
||||
def __call__(self, err_msg: Optional[str] = None):
|
||||
"""Return the error message.
|
||||
|
||||
Args:
|
||||
err_msg (str, optional): The error message. If err_msg is not None,
|
||||
it will be returned, otherwise the default error message will
|
||||
be returned. Defaults to None.
|
||||
|
||||
Returns:
|
||||
ActionReturn: The action return.
|
||||
"""
|
||||
action_return = ActionReturn(
|
||||
url=None,
|
||||
args=dict(text=err_msg),
|
||||
type=self.name,
|
||||
errmsg=err_msg if err_msg else self._err_msg,
|
||||
valid=ActionValidCode.INVALID,
|
||||
state=ActionStatusCode.API_ERROR)
|
||||
return action_return
|
||||
|
||||
|
||||
class FinishAction(BaseAction):
|
||||
"""This is a finish action class, which is used to return the final
|
||||
result."""
|
||||
|
||||
def __call__(self, response: str) -> ActionReturn:
|
||||
"""Return the final result.
|
||||
|
||||
Args:
|
||||
response (str): The final result.
|
||||
|
||||
Returns:
|
||||
ActionReturn: The action return.
|
||||
"""
|
||||
action_return = ActionReturn(
|
||||
url=None,
|
||||
args=dict(text=response),
|
||||
result=dict(text=response),
|
||||
type=self.name,
|
||||
valid=ActionValidCode.FINISH,
|
||||
state=ActionStatusCode.SUCCESS)
|
||||
return action_return
|
||||
56
lagent/actions/llm_qa.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from typing import Optional, Union
|
||||
|
||||
from lagent.llms.base_api import BaseAPIModel
|
||||
from lagent.llms.base_llm import BaseModel
|
||||
from lagent.schema import ActionReturn, ActionStatusCode
|
||||
from .base_action import BaseAction
|
||||
|
||||
DEFAULT_DESCRIPTION = """一个像你一样的大语言预训练模型,当你需要获得一些常识或简单世界知识时可以问它。
|
||||
当你很有把握自己直接解决问题时可以优先使用它。输入应该是一个询问语句, 且每个问题尽可能简单。
|
||||
"""
|
||||
|
||||
|
||||
class LLMQA(BaseAction):
|
||||
"""An LLM Wrapper as BaseAction type.
|
||||
|
||||
Args:
|
||||
llm (BaseModel or BaseAPIModel): an LLM service which can chat.
|
||||
description (str): The description of the action. Defaults to
|
||||
None.
|
||||
name (str, optional): The name of the action. If None, the name will
|
||||
be the class name. Defaults to None.
|
||||
enable (bool, optional): Whether the action is enabled. Defaults to
|
||||
True.
|
||||
disable_description (str, optional): The description of the action when
|
||||
it is disabled. Defaults to None.
|
||||
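
    Example (an illustrative sketch; ``llm`` stands for any model that
    implements ``generate_from_template``):
        >>> qa_tool = LLMQA(llm)
        >>> ret = qa_tool('What is a large language model?')
        >>> ret.result['text']  # the generated answer on success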
"""
|
||||
|
||||
def __init__(self,
|
||||
llm: Union[BaseModel, BaseAPIModel],
|
||||
description: str = DEFAULT_DESCRIPTION,
|
||||
name: Optional[str] = None,
|
||||
enable: bool = True,
|
||||
disable_description: Optional[str] = None) -> None:
|
||||
super().__init__(description, name, enable, disable_description)
|
||||
|
||||
self._llm = llm
|
||||
|
||||
def __call__(self, query: str) -> ActionReturn:
|
||||
"""Return the QA response.
|
||||
|
||||
Args:
|
||||
query (str): The query content.
|
||||
|
||||
Returns:
|
||||
ActionReturn: The action return.
|
||||
"""
|
||||
|
||||
tool_return = ActionReturn(url=None, args=None)
|
||||
try:
|
||||
response = self._llm.generate_from_template(query, 512)
|
||||
tool_return.result = dict(text=str(response))
|
||||
tool_return.state = ActionStatusCode.SUCCESS
|
||||
except Exception as e:
|
||||
tool_return.result = dict(text=str(e))
|
||||
tool_return.state = ActionStatusCode.API_ERROR
|
||||
return tool_return
|
||||
133
lagent/actions/python_interpreter.py
Normal file
@@ -0,0 +1,133 @@
|
||||
import copy
|
||||
import io
|
||||
from contextlib import redirect_stdout
|
||||
from typing import Any, Optional
|
||||
|
||||
from func_timeout import FunctionTimedOut, func_set_timeout
|
||||
|
||||
from lagent.actions.base_action import BaseAction
|
||||
from lagent.schema import ActionReturn, ActionStatusCode
|
||||
|
||||
|
||||
class GenericRuntime:
|
||||
GLOBAL_DICT = {}
|
||||
LOCAL_DICT = None
|
||||
HEADERS = []
|
||||
|
||||
def __init__(self):
|
||||
self._global_vars = copy.copy(self.GLOBAL_DICT)
|
||||
self._local_vars = copy.copy(
|
||||
self.LOCAL_DICT) if self.LOCAL_DICT else None
|
||||
|
||||
for c in self.HEADERS:
|
||||
self.exec_code(c)
|
||||
|
||||
def exec_code(self, code_piece: str) -> None:
|
||||
exec(code_piece, self._global_vars)
|
||||
|
||||
def eval_code(self, expr: str) -> Any:
|
||||
return eval(expr, self._global_vars)
|
||||
|
||||
|
||||
DEFAULT_DESCRIPTION = """用来执行Python代码。代码必须是一个函数,
|
||||
函数名必须得是 'solution',代码对应你的思考过程。代码实例格式如下:
|
||||
```python
|
||||
# import 依赖包
|
||||
import xxx
|
||||
def solution():
|
||||
# 初始化一些变量
|
||||
variable_names_with_real_meaning = xxx
|
||||
# 步骤一
|
||||
mid_variable = func(variable_names_with_real_meaning)
|
||||
# 步骤 x
|
||||
mid_variable = func(mid_variable)
|
||||
# 最后结果
|
||||
final_answer = func(mid_variable)
|
||||
return final_answer
|
||||
```"""
|
||||
|
||||
|
||||
class PythonInterpreter(BaseAction):
|
||||
"""A Python executor that can execute Python scripts.
|
||||
|
||||
Args:
|
||||
description (str): The description of the action. Defaults to
|
||||
DEFAULT_DESCRIPTION.
|
||||
answer_symbol (str, Optional): the answer symbol from LLM
|
||||
answer_expr (str, Optional): the answer function name of the Python
|
||||
script. Defaults to 'solution()'.
|
||||
answer_from_stdout (bool): whether the execution result is read from
|
||||
stdout.
|
||||
name (str, optional): The name of the action. If None, the name will
|
||||
be the class name. Defaults to None.
|
||||
enable (bool, optional): Whether the action is enabled. Defaults to
|
||||
True.
|
||||
disable_description (str, optional): The description of the action when
|
||||
it is disabled. Defaults to None.
|
||||
timeout (int): Upper bound of waiting time for Python script execution.
|
||||
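
    Example (a minimal usage sketch; the submitted script must define
    ``solution()``):
        >>> interpreter = PythonInterpreter()
        >>> ret = interpreter.run('```python\ndef solution():\n    return 1 + 1\n```')
        >>> ret.result['text']
        '2'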
"""
|
||||
|
||||
def __init__(self,
|
||||
description: str = DEFAULT_DESCRIPTION,
|
||||
answer_symbol: Optional[str] = None,
|
||||
answer_expr: Optional[str] = 'solution()',
|
||||
answer_from_stdout: bool = False,
|
||||
name: Optional[str] = None,
|
||||
enable: bool = True,
|
||||
disable_description: Optional[str] = None,
|
||||
timeout: int = 20) -> None:
|
||||
super().__init__(description, name, enable, disable_description)
|
||||
|
||||
self.answer_symbol = answer_symbol
|
||||
self.answer_expr = answer_expr
|
||||
self.answer_from_stdout = answer_from_stdout
|
||||
self.timeout = timeout
|
||||
|
||||
def __call__(self, command: str) -> ActionReturn:
|
||||
self.runtime = GenericRuntime()
|
||||
try:
|
||||
tool_return = func_set_timeout(self.timeout)(self._call)(command)
|
||||
except FunctionTimedOut as e:
|
||||
tool_return = ActionReturn(url=None, args=None, type=self.name)
|
||||
tool_return.errmsg = repr(e)
|
||||
tool_return.state = ActionStatusCode.API_ERROR
|
||||
return tool_return
|
||||
|
||||
def _call(self, command: str) -> ActionReturn:
|
||||
tool_return = ActionReturn(url=None, args=None, type=self.name)
|
||||
try:
|
||||
if '```python' in command:
|
||||
command = command.split('```python')[1].split('```')[0]
|
||||
elif '```' in command:
|
||||
command = command.split('```')[1].split('```')[0]
|
||||
tool_return.args = dict(text='```python\n' + command + '\n```')
|
||||
command = command.split('\n')
|
||||
|
||||
if self.answer_from_stdout:
|
||||
program_io = io.StringIO()
|
||||
with redirect_stdout(program_io):
|
||||
self.runtime.exec_code('\n'.join(command))
|
||||
program_io.seek(0)
|
||||
res = program_io.readlines()[-1]
|
||||
elif self.answer_symbol:
|
||||
self.runtime.exec_code('\n'.join(command))
|
||||
res = self.runtime._global_vars[self.answer_symbol]
|
||||
elif self.answer_expr:
|
||||
self.runtime.exec_code('\n'.join(command))
|
||||
res = self.runtime.eval_code(self.answer_expr)
|
||||
else:
|
||||
self.runtime.exec_code('\n'.join(command[:-1]))
|
||||
res = self.runtime.eval_code(command[-1])
|
||||
except Exception as e:
|
||||
tool_return.errmsg = repr(e)
|
||||
tool_return.type = self.name
|
||||
tool_return.state = ActionStatusCode.API_ERROR
|
||||
return tool_return
|
||||
try:
|
||||
tool_return.result = dict(text=str(res))
|
||||
tool_return.state = ActionStatusCode.SUCCESS
|
||||
except Exception as e:
|
||||
tool_return.errmsg = repr(e)
|
||||
tool_return.type = self.name
|
||||
tool_return.state = ActionStatusCode.API_ERROR
|
||||
return tool_return
|
||||
175
lagent/actions/serper_search.py
Normal file
@@ -0,0 +1,175 @@
|
||||
import os
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import requests
|
||||
|
||||
from lagent.schema import ActionReturn, ActionStatusCode
|
||||
from .base_action import BaseAction
|
||||
|
||||
DEFAULT_DESCRIPTION = """一个可以从谷歌搜索结果的API。
|
||||
当你需要对于一个特定问题找到简短明了的回答时,可以使用它。
|
||||
输入应该是一个搜索查询。
|
||||
"""
|
||||
|
||||
|
||||
class SerperSearch(BaseAction):
|
||||
"""Wrapper around the Serper.dev Google Search API.
|
||||
|
||||
To use, you should pass your serper API key to the constructor.
|
||||
|
||||
Code is modified from the LangChain GoogleSerperAPIWrapper
|
||||
(https://github.com/langchain-ai/langchain/blob/ba5f
|
||||
baba704a2d729a4b8f568ed70d7c53e799bb/libs/langchain/
|
||||
langchain/utilities/google_serper.py)
|
||||
|
||||
Args:
|
||||
api_key (str): API KEY to use serper google search API,
|
||||
You can create a free API key at https://serper.dev.
|
||||
timeout (int): Upper bound of waiting time for a Serper request.
|
||||
search_type (str): the Serper API supports ['search', 'images', 'news',
|
||||
'places'] search types; currently only 'search' is supported.
|
||||
k (int): select the first k results from the search results as the response.
|
||||
description (str): The description of the action. Defaults to
|
||||
None.
|
||||
name (str, optional): The name of the action. If None, the name will
|
||||
be the class name. Defaults to None.
|
||||
enable (bool, optional): Whether the action is enabled. Defaults to
|
||||
True.
|
||||
disable_description (str, optional): The description of the action when
|
||||
it is disabled. Defaults to None.
|
||||
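
    Example (an illustrative sketch; requires a valid Serper API key):
        >>> search = SerperSearch(api_key='your-serper-key')
        >>> ret = search('the capital of France')
        >>> ret.result['text']  # stringified snippets on success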
"""
|
||||
result_key_for_type = {
|
||||
'news': 'news',
|
||||
'places': 'places',
|
||||
'images': 'images',
|
||||
'search': 'organic',
|
||||
}
|
||||
|
||||
def __init__(self,
|
||||
api_key: Optional[str] = None,
|
||||
timeout: int = 5,
|
||||
search_type: str = 'search',
|
||||
k: int = 10,
|
||||
description: str = DEFAULT_DESCRIPTION,
|
||||
name: Optional[str] = None,
|
||||
enable: bool = True,
|
||||
disable_description: Optional[str] = None) -> None:
|
||||
super().__init__(description, name, enable, disable_description)
|
||||
|
||||
api_key = os.environ.get('SERPER_API_KEY', api_key)
|
||||
if api_key is None:
|
||||
raise ValueError(
|
||||
'Please Set Serper API key either in the environment '
|
||||
' as SERPER_API_KEY or pass it as `api_key` parameter.')
|
||||
self.api_key = api_key
|
||||
self.timeout = timeout
|
||||
self.search_type = search_type
|
||||
self.k = k
|
||||
|
||||
def __call__(self, query: str) -> ActionReturn:
|
||||
"""Return the search response.
|
||||
|
||||
Args:
|
||||
query (str): The search content.
|
||||
|
||||
Returns:
|
||||
ActionReturn: The action return.
|
||||
"""
|
||||
|
||||
tool_return = ActionReturn(url=None, args=None)
|
||||
status_code, response = self._search(
|
||||
query, search_type=self.search_type, k=self.k)
|
||||
# convert search results to ToolReturn format
|
||||
if status_code == -1:
|
||||
tool_return.errmsg = response
|
||||
tool_return.state = ActionStatusCode.HTTP_ERROR
|
||||
elif status_code == 200:
|
||||
parsed_res = self._parse_results(response)
|
||||
tool_return.result = dict(text=str(parsed_res))
|
||||
tool_return.state = ActionStatusCode.SUCCESS
|
||||
else:
|
||||
tool_return.errmsg = str(status_code)
|
||||
tool_return.state = ActionStatusCode.API_ERROR
|
||||
return tool_return
|
||||
|
||||
def _parse_results(self, results: dict) -> Union[str, List[str]]:
|
||||
"""Parse the search results from Serper API.
|
||||
|
||||
Args:
|
||||
results (dict): The search content from Serper API
|
||||
in json format.
|
||||
|
||||
Returns:
|
||||
List[str]: The parsed search results.
|
||||
"""
|
||||
|
||||
snippets = []
|
||||
|
||||
if results.get('answerBox'):
|
||||
answer_box = results.get('answerBox', {})
|
||||
if answer_box.get('answer'):
|
||||
return [answer_box.get('answer')]
|
||||
elif answer_box.get('snippet'):
|
||||
return [answer_box.get('snippet').replace('\n', ' ')]
|
||||
elif answer_box.get('snippetHighlighted'):
|
||||
return answer_box.get('snippetHighlighted')
|
||||
|
||||
if results.get('knowledgeGraph'):
|
||||
kg = results.get('knowledgeGraph', {})
|
||||
title = kg.get('title')
|
||||
entity_type = kg.get('type')
|
||||
if entity_type:
|
||||
snippets.append(f'{title}: {entity_type}.')
|
||||
description = kg.get('description')
|
||||
if description:
|
||||
snippets.append(description)
|
||||
for attribute, value in kg.get('attributes', {}).items():
|
||||
snippets.append(f'{title} {attribute}: {value}.')
|
||||
|
||||
for result in results[self.result_key_for_type[
|
||||
self.search_type]][:self.k]:
|
||||
if 'snippet' in result:
|
||||
snippets.append(result['snippet'])
|
||||
for attribute, value in result.get('attributes', {}).items():
|
||||
snippets.append(f'{attribute}: {value}.')
|
||||
|
||||
if len(snippets) == 0:
|
||||
return ['No good Google Search Result was found']
|
||||
return snippets
|
||||
|
||||
def _search(self,
|
||||
search_term: str,
|
||||
search_type: str = 'search',
|
||||
**kwargs) -> Tuple[int, Union[dict, str]]:
|
||||
"""HTTP requests to Serper API.
|
||||
|
||||
Args:
|
||||
search_term (str): The search query.
|
||||
search_type (str): search type supported by Serper API,
|
||||
default to 'search'.
|
||||
|
||||
Returns:
|
||||
tuple: the return value is a tuple contains:
|
||||
- status_code (int): HTTP status code from Serper API.
|
||||
- response (dict): response context with json format.
|
||||
"""
|
||||
headers = {
|
||||
'X-API-KEY': self.api_key or '',
|
||||
'Content-Type': 'application/json',
|
||||
}
|
||||
params = {
|
||||
'q': search_term,
|
||||
**{
|
||||
key: value
|
||||
for key, value in kwargs.items() if value is not None
|
||||
},
|
||||
}
|
||||
try:
|
||||
response = requests.post(
|
||||
f'https://google.serper.dev/{search_type}',
|
||||
headers=headers,
|
||||
params=params,
|
||||
timeout=self.timeout)
|
||||
except Exception as e:
|
||||
return -1, str(e)
|
||||
return response.status_code, response.json()
|
||||
6
lagent/agents/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from .autogpt import AutoGPT
|
||||
from .base_agent import BaseAgent
|
||||
from .react import ReACT
|
||||
from .rewoo import ReWOO
|
||||
|
||||
__all__ = ['BaseAgent', 'ReACT', 'AutoGPT', 'ReWOO']
|
||||
288
lagent/agents/autogpt.py
Normal file
@@ -0,0 +1,288 @@
|
||||
# flake8: noqa
|
||||
import ast
|
||||
import platform
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
from jsonschema import Draft7Validator
|
||||
|
||||
from lagent.actions import ActionExecutor
|
||||
from lagent.llms.base_api import BaseAPIModel
|
||||
from lagent.llms.base_llm import BaseModel
|
||||
from lagent.schema import ActionReturn, ActionStatusCode, AgentReturn
|
||||
from .base_agent import BaseAgent
|
||||
|
||||
DEFAULT_TRIGGERING_PROMPT = ('Determine exactly one command to use based on '
|
||||
'the given goals and the progress you have made '
|
||||
'so far, and respond using the JSON schema '
|
||||
'specified previously:')
|
||||
|
||||
DEFAULT_PREFIX = """You are {ai_name}, {role_description}. Your decisions must always be made independently without seeking user assistance. Play to your strengths as an LLM and pursue simple strategies with no legal complications.
|
||||
The OS you are running on is: {os_info}
|
||||
## Constraints
|
||||
You operate within the following constraints:
|
||||
1. ~4000 word limit for short term memory. Your short term memory is short, so immediately save important information to files.
|
||||
2. If you are unsure how you previously did something or want to recall past events, thinking about similar events will help you remember.
|
||||
3. No user assistance
|
||||
4. Exclusively use the commands listed below e.g. command_name
|
||||
## Commands
|
||||
You have access to the following commands:
|
||||
{tool_description}
|
||||
## Resources
|
||||
You can leverage access to the following resources:
|
||||
1. Internet access for searches and information gathering.
|
||||
2. Long Term memory management.
3. File output.
4. Command execution.
|
||||
## Best practices
|
||||
1. Continuously review and analyze your actions to ensure you are performing to the best of your abilities.
|
||||
2. Constructively self-criticize your big-picture behavior constantly.
|
||||
3. Reflect on past decisions and strategies to refine your approach.
|
||||
4. Every command has a cost, so be smart and efficient. Aim to complete tasks in the least number of steps.
|
||||
## Goals
|
||||
For your task, you must fulfill the following goals:
|
||||
{ai_goals}
|
||||
"""
|
||||
|
||||
DEFAULT_CALL_PROTOCOL = """Respond strictly with JSON. The JSON should be compatible with the TypeScript type `Response` from the following:
|
||||
```ts
|
||||
interface Response {
|
||||
thoughts: {
|
||||
// Thoughts
|
||||
text: string;
|
||||
reasoning: string;
|
||||
// Short markdown-style bullet list that conveys the long-term plan
|
||||
plan: string;
|
||||
// Constructive self-criticism
|
||||
criticism: string;
|
||||
// Summary of thoughts to say to the user
|
||||
speak: string;
|
||||
};
|
||||
command: {
|
||||
name: string;
|
||||
args: Record<string, any>;
|
||||
};
|
||||
}
|
||||
```
|
||||
"""
|
||||
DEFAULT_SCHEMA = {
|
||||
'$schema': 'http://json-schema.org/draft-07/schema#',
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'thoughts': {
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'text': {
|
||||
'type': 'string',
|
||||
'description': 'thoughts'
|
||||
},
|
||||
'reasoning': {
|
||||
'type': 'string'
|
||||
},
|
||||
'plan': {
|
||||
'type':
|
||||
'string',
|
||||
'description':
|
||||
'- short bulleted\n- list that conveys\n- long-term plan'
|
||||
},
|
||||
'criticism': {
|
||||
'type': 'string',
|
||||
'description': 'constructive self-criticism'
|
||||
},
|
||||
'speak': {
|
||||
'type': 'string',
|
||||
'description': 'thoughts summary to say to user'
|
||||
}
|
||||
},
|
||||
'required': ['text', 'reasoning', 'plan', 'criticism', 'speak'],
|
||||
'additionalProperties': False
|
||||
},
|
||||
'command': {
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'name': {
|
||||
'type': 'string'
|
||||
},
|
||||
'args': {
|
||||
'type': 'object'
|
||||
}
|
||||
},
|
||||
'required': ['name', 'args'],
|
||||
'additionalProperties': False
|
||||
}
|
||||
},
|
||||
'required': ['thoughts', 'command'],
|
||||
'additionalProperties': False
|
||||
}
|
||||
|
||||
|
||||
class AutoGPTProtocol:
|
||||
"""A wrapper of AutoGPT prompt which manages the response from LLM and
|
||||
generates desired prompts in an AutoGPT format.
|
||||
|
||||
Args:
|
||||
ai_name (str): the name of the agent, defaults to 'AutoGPT'
|
||||
role_description (str): the description of the agent's role, filled into the prefix prompt after its name
|
||||
prefix (str): the prefix prompt for AutoGPT
|
||||
call_protocol (str): the request prompt which defines the protocol
|
||||
of return format from LLM.
|
||||
valid_schema (dict): defines the schema of the return format.
|
||||
triggering_prompt (str): the predefined trigger prompt.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
ai_name: Optional[str] = 'AutoGPT',
|
||||
role_description: Optional[str] = '',
|
||||
prefix: str = DEFAULT_PREFIX,
|
||||
call_protocol: str = DEFAULT_CALL_PROTOCOL,
|
||||
valid_schema: dict = DEFAULT_SCHEMA,
|
||||
triggering_prompt: str = DEFAULT_TRIGGERING_PROMPT) -> None:
|
||||
self.ai_name = ai_name
|
||||
self.role_description = role_description
|
||||
self.prefix = prefix
|
||||
self.call_protocol = call_protocol
|
||||
self.valid_schema = valid_schema
|
||||
self.triggering_prompt = triggering_prompt
|
||||
|
||||
def parse(self, response: str,
|
||||
action_executor: ActionExecutor) -> Tuple[str, str]:
|
||||
"""Parse the action returns in a AutoGPT format.
|
||||
|
||||
Args:
|
||||
response (str): The response from LLM with AutoGPT format.
|
||||
action_executor (ActionExecutor): Action executor to
|
||||
provide no_action/finish_action name.
|
||||
|
||||
Returns:
|
||||
tuple: the return value is a tuple that contains:
|
||||
- action (str): the extracted action name.
|
||||
- action_input (str): the corresponding action input.
|
||||
"""
|
||||
try:
|
||||
if response.startswith('```') and response.endswith('```'):
|
||||
# Discard the first and last ```, then re-join in case the response naturally included ```
|
||||
response = '```'.join(response.split('```')[1:-1])
|
||||
response = ast.literal_eval(response)
|
||||
validator = Draft7Validator(self.valid_schema)
|
||||
valid = True
|
||||
if errors := sorted(
|
||||
validator.iter_errors(response), key=lambda e: e.path):
|
||||
valid = False
|
||||
if not valid:
|
||||
return action_executor.no_action, 'Validation of response failed:\n ' + ';\n '.join(
|
||||
[str(e) for e in errors])
|
||||
try:
|
||||
if 'command' not in response:
|
||||
return action_executor.no_action, "Missing 'command' object in JSON"
|
||||
if not isinstance(response, dict):
|
||||
return action_executor.no_action, f'The previous message sent was not a dictionary {response}'
|
||||
command = response['command']
|
||||
if not isinstance(command, dict):
|
||||
return action_executor.no_action, "'command' object is not a dictionary"
|
||||
if 'name' not in command:
|
||||
return action_executor.no_action, "Missing 'name' field in 'command' object"
|
||||
command_name = command['name']
|
||||
# Use an empty dictionary if 'args' field is not present in 'command' object
|
||||
arguments = command.get('args', {})
|
||||
return command_name, arguments
|
||||
except Exception as e:
|
||||
return action_executor.no_action, repr(e)
|
||||
except SyntaxError as e:
|
||||
return action_executor.no_action, f'Your response could not be parsed: {repr(e)} \nRemember to only respond using the specified format above!'
|
||||
|
||||
def format(self, goal: str, inner_history: List[Dict],
|
||||
action_executor: ActionExecutor) -> List[Dict]:
|
||||
"""Generate the AutoGPT format prompt.
|
||||
|
||||
Args:
|
||||
goal (str): The user request.
|
||||
inner_history (List[Dict]): The log in the current run.
|
||||
action_executor (ActionExecutor): the action manager to
|
||||
execute actions.
|
||||
Returns:
|
||||
List[Dict]: AutoGPT format prompt.
|
||||
"""
|
||||
import distro
|
||||
formatted_data = []
|
||||
os_name = platform.system()
|
||||
os_info = (
|
||||
platform.platform(terse=True)
|
||||
if os_name != 'Linux' else distro.name(pretty=True))
|
||||
prefix = self.prefix.format(
|
||||
ai_name=self.ai_name,
|
||||
role_description=self.role_description,
|
||||
tool_description=action_executor.get_actions_info(),
|
||||
ai_goals=goal,
|
||||
os_info=os_info,
|
||||
)
|
||||
formatted_data.append(dict(role='system', content=prefix))
|
||||
formatted_data.append(dict(role='system', content=self.call_protocol))
|
||||
formatted_data += inner_history
|
||||
formatted_data.append(
|
||||
dict(role='user', content=self.triggering_prompt))
|
||||
return formatted_data
|
||||
|
||||
def format_response(self, action_return):
|
||||
"""format the final response at current step.
|
||||
|
||||
Args:
|
||||
action_return (ActionReturn): return value of the current action.
|
||||
|
||||
Returns:
|
||||
str: the final response at current step.
|
||||
"""
|
||||
if action_return.state == ActionStatusCode.SUCCESS:
|
||||
response = action_return.result['text']
|
||||
response = f'Command {action_return.type} returned: {response}'
|
||||
else:
|
||||
response = action_return.errmsg
|
||||
return response
|
||||
|
||||
|
||||
class AutoGPT(BaseAgent):
|
||||
"""An implementation of AutoGPT (https://github.com/Significant-
|
||||
Gravitas/Auto-GPT)
|
||||
|
||||
Args:
|
||||
llm (BaseModel or BaseAPIModel): a LLM service which can chat
|
||||
and act as backend.
|
||||
action_executor (ActionExecutor): an action executor to manage
|
||||
all actions and their response.
|
||||
protocol (AutoGPTProtocol): a wrapper to generate prompt and
|
||||
parse the response from LLM / actions.
|
||||
max_turn (int): the maximum number of turns the agent will run
|
||||
before returning the default response. Defaults to 2.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
llm: Union[BaseModel, BaseAPIModel],
|
||||
action_executor: ActionExecutor,
|
||||
protocol: AutoGPTProtocol = AutoGPTProtocol(),
|
||||
max_turn: int = 2):
|
||||
self.max_turn = max_turn
|
||||
super().__init__(
|
||||
llm=llm, action_executor=action_executor, protocol=protocol)
|
||||
|
||||
def chat(self, goal: str) -> AgentReturn:
|
||||
self._inner_history = []
|
||||
agent_return = AgentReturn()
|
||||
default_response = '对不起,我无法回答你的问题'
|
||||
for _ in range(self.max_turn):
|
||||
prompt = self._protocol.format(
|
||||
goal=goal,
|
||||
inner_history=self._inner_history,
|
||||
action_executor=self._action_executor)
|
||||
response = self._llm.generate_from_template(prompt, 512)
|
||||
self._inner_history.append(
|
||||
dict(role='assistant', content=response))
|
||||
action, action_input = self._protocol.parse(
|
||||
response, self._action_executor)
|
||||
action_return: ActionReturn = self._action_executor(
|
||||
action, action_input)
|
||||
agent_return.actions.append(action_return)
|
||||
if action_return.type == self._action_executor.finish_action.name:
|
||||
agent_return.response = action_return.result['text']
|
||||
return agent_return
|
||||
self._inner_history.append(
|
||||
dict(
|
||||
role='system',
|
||||
content=self._protocol.format_response(action_return)))
|
||||
agent_return.response = default_response
|
||||
return agent_return
|
||||
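A minimal usage sketch for the class above. The `ActionExecutor` construction is an assumption: the actions module is defined elsewhere in this commit, and the empty action list is only a placeholder for whatever tools get registered.

```python
# Sketch only: the ActionExecutor signature and the empty tool list
# are assumptions, not part of this file.
from lagent.actions import ActionExecutor
from lagent.agents import AutoGPT
from lagent.llms import GPTAPI

llm = GPTAPI(model_type='gpt-3.5-turbo')       # reads $OPENAI_API_KEY when key='ENV'
executor = ActionExecutor(actions=[])          # register tool actions here (assumed signature)
agent = AutoGPT(llm=llm, action_executor=executor, max_turn=4)

ret = agent.chat('Write a one-line summary of this repository to notes.txt')
print(ret.response)                            # final answer, or the default fallback text
for step in ret.actions:                       # one ActionReturn per executed command
    print(step.type, step.state)
```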
49
lagent/agents/base_agent.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from typing import List
|
||||
|
||||
from lagent.actions import ActionExecutor
|
||||
from lagent.actions.base_action import BaseAction
|
||||
from lagent.llms.base_llm import BaseModel
|
||||
from lagent.schema import AgentReturn
|
||||
|
||||
|
||||
class BaseAgent:
|
||||
"""BaseAgent is the base class of all agents.
|
||||
|
||||
Args:
|
||||
llm (BaseModel): the language model.
|
||||
action_executor (ActionExecutor): the action executor.
|
||||
protocol (object): the protocol of the agent, which is used to
|
||||
generate the prompt of the agent and parse the response from
|
||||
the llm.
|
||||
"""
|
||||
|
||||
def __init__(self, llm: BaseModel, action_executor: ActionExecutor,
|
||||
protocol: object) -> None:
|
||||
|
||||
self._session_history = []
|
||||
self._llm = llm
|
||||
self._action_executor = action_executor
|
||||
self._protocol = protocol
|
||||
|
||||
def add_action(self, action: BaseAction) -> None:
|
||||
"""Add an action to the action executor.
|
||||
|
||||
Args:
|
||||
action (BaseAction): the action to be added.
|
||||
"""
|
||||
self._action_executor.add_action(action)
|
||||
|
||||
def del_action(self, name: str) -> None:
|
||||
"""Delete an action from the action executor.
|
||||
|
||||
Args:
|
||||
name (str): the name of the action to be deleted.
|
||||
"""
|
||||
self._action_executor.del_action(name)
|
||||
|
||||
def chat(self, message: str) -> AgentReturn:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def session_history(self) -> List:
|
||||
return self._session_history
|
||||
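Because `chat` is deliberately left unimplemented, the smallest possible concrete agent just formats a message, queries the backend, and records the exchange. A sketch with no protocol or tools involved:

```python
from lagent.agents.base_agent import BaseAgent
from lagent.schema import AgentReturn


class EchoAgent(BaseAgent):
    """Toy agent: forwards the user message to the LLM and returns the reply."""

    def chat(self, message: str) -> AgentReturn:
        agent_return = AgentReturn()
        # _llm, _action_executor and _protocol are populated by BaseAgent.__init__
        agent_return.response = self._llm.generate_from_template(
            [dict(role='user', content=message)], 512)
        self._session_history.append(dict(role='user', content=message))
        self._session_history.append(
            dict(role='assistant', content=agent_return.response))
        return agent_return
```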
217
lagent/agents/react.py
Normal file
@@ -0,0 +1,217 @@
|
||||
from typing import Dict, List, Tuple, Union
|
||||
|
||||
from lagent.actions import ActionExecutor
|
||||
from lagent.llms.base_api import BaseAPIModel
|
||||
from lagent.llms.base_llm import BaseModel
|
||||
from lagent.schema import ActionReturn, ActionStatusCode, AgentReturn
|
||||
from .base_agent import BaseAgent
|
||||
|
||||
CALL_PROTOCOL = """你是一个可以调用外部工具的助手,可以使用的工具包括:
|
||||
{tool_description}
|
||||
如果使用工具请遵循以下格式回复:
|
||||
```
|
||||
{thought}思考你当前步骤需要解决什么问题,是否需要使用工具
|
||||
{action}工具名称,你的工具必须从 [{action_names}] 选择
|
||||
{action_input}工具输入参数
|
||||
```
|
||||
工具返回按照以下格式回复:
|
||||
```
|
||||
{response}调用工具后的结果
|
||||
```
|
||||
如果你已经知道了答案,或者你不需要工具,请遵循以下格式回复
|
||||
```
|
||||
{thought}给出最终答案的思考过程
|
||||
{finish}最终答案
|
||||
```
|
||||
开始!"""
|
||||
|
||||
|
||||
class ReACTProtocol:
|
||||
"""A wrapper of ReACT prompt which manages the response from LLM and
|
||||
generates desired prompts in a ReACT format.
|
||||
|
||||
Args:
|
||||
thought (dict): the information of thought pattern
|
||||
action (dict): the information of action pattern
|
||||
action_input (dict): the information of action_input pattern
|
||||
response (dict): the information of response pattern
|
||||
finish (dict): the information of finish pattern
|
||||
call_protocol (str): the format of ReACT
|
||||
force_stop (str): the prompt to force LLM to generate response
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
thought: dict = dict(
|
||||
role='THOUGHT',
|
||||
begin='Thought:',
|
||||
end='\n',
|
||||
belong='assistant'),
|
||||
action: dict = dict(role='ACTION', begin='Action:', end='\n'),
|
||||
action_input: dict = dict(
|
||||
role='ARGS', begin='ActionInput:', end='\n'),
|
||||
response: dict = dict(
|
||||
role='RESPONSE', begin='Response:', end='\n'),
|
||||
finish: dict = dict(
|
||||
role='FINISH', begin='FinalAnswer:', end='\n'),
|
||||
call_protocol: str = CALL_PROTOCOL,
|
||||
force_stop: str = '你需要基于历史消息返回一个最终结果') -> None:
|
||||
self.call_protocol = call_protocol
|
||||
self.force_stop = force_stop
|
||||
self.thought = thought
|
||||
self.action = action
|
||||
self.action_input = action_input
|
||||
self.response = response
|
||||
self.finish = finish
|
||||
|
||||
def format(self,
|
||||
chat_history: List[Dict],
|
||||
inner_step: List[Dict],
|
||||
action_executor: ActionExecutor,
|
||||
force_stop: bool = False) -> list:
|
||||
"""Generate the ReACT format prompt.
|
||||
|
||||
Args:
|
||||
chat_history (List[Dict]): The history log in previous runs.
|
||||
inner_step (List[Dict]): The log in the current run.
|
||||
action_executor (ActionExecutor): the action manager to
|
||||
execute actions.
|
||||
force_stop (boolean): whether force the agent to give responses
|
||||
under pre-defined turns.
|
||||
|
||||
Returns:
|
||||
List[Dict]: ReACT format prompt.
|
||||
"""
|
||||
|
||||
call_protocol = self.call_protocol.format(
|
||||
tool_description=action_executor.get_actions_info(),
|
||||
action_names=action_executor.action_names(),
|
||||
thought=self.thought['begin'],
|
||||
action=self.action['begin'],
|
||||
action_input=self.action_input['begin'],
|
||||
response=self.response['begin'],
|
||||
finish=self.finish['begin'],
|
||||
)
|
||||
formatted = []
|
||||
formatted.append(dict(role='system', content=call_protocol))
|
||||
formatted += chat_history
|
||||
formatted += inner_step
|
||||
if force_stop:
|
||||
formatted.append(dict(role='system', content=self.force_stop))
|
||||
return formatted
|
||||
|
||||
def parse(
|
||||
self,
|
||||
message: str,
|
||||
action_executor: ActionExecutor,
|
||||
) -> Tuple[str, str, str]:
|
||||
"""Parse the action returns in a ReACT format.
|
||||
|
||||
Args:
|
||||
message (str): The response from LLM with ReACT format.
|
||||
action_executor (ActionExecutor): Action executor to
|
||||
provide no_action/finish_action name.
|
||||
|
||||
Returns:
|
||||
tuple: the return value is a tuple that contains:
|
||||
- thought (str): contain LLM thought of the current step.
|
||||
- action (str): contain action scheduled by LLM.
|
||||
- action_input (str): contain the required action input
|
||||
for current action.
|
||||
"""
|
||||
|
||||
import re
|
||||
thought = message.split(self.action['begin'])[0]
|
||||
thought = thought.split(self.thought['begin'])[-1]
|
||||
thought = thought.split(self.finish['begin'])[0]
|
||||
if self.finish['begin'] in message:
|
||||
final_answer = message.split(self.finish['begin'])[-1]
|
||||
return thought, action_executor.finish_action.name, final_answer
|
||||
|
||||
action_regex = f"{self.action['begin']}(.*?)\n"
|
||||
args_regex = f"{self.action_input['begin']}(.*)"
|
||||
action_match = re.findall(action_regex, message)
|
||||
if not action_match:
|
||||
return thought, action_executor.no_action.name, ''
|
||||
action = action_match[-1]
|
||||
arg_match = re.findall(args_regex, message, re.DOTALL)
|
||||
|
||||
if not arg_match:
|
||||
return thought, action_executor.no_action.name, ''
|
||||
action_input = arg_match[-1]
|
||||
return thought, action.strip(), action_input.strip().strip('"')
|
||||
|
||||
def format_response(self, action_return: ActionReturn) -> str:
|
||||
"""format the final response at current step.
|
||||
|
||||
Args:
|
||||
action_return (ActionReturn): return value of the current action.
|
||||
|
||||
Returns:
|
||||
str: the final response at current step.
|
||||
"""
|
||||
if action_return.state == ActionStatusCode.SUCCESS:
|
||||
response = action_return.result['text']
|
||||
else:
|
||||
response = action_return.errmsg
|
||||
return self.response['begin'] + response + self.response['end']
|
||||
|
||||
|
||||
class ReACT(BaseAgent):
|
||||
"""An implementation of ReACT (https://arxiv.org/abs/2210.03629)
|
||||
|
||||
Args:
|
||||
llm (BaseModel or BaseAPIModel): a LLM service which can chat
|
||||
and act as backend.
|
||||
action_executor (ActionExecutor): an action executor to manage
|
||||
all actions and their response.
|
||||
protocol (ReACTProtocol): a wrapper to generate prompt and
|
||||
parse the response from LLM / actions.
|
||||
max_turn (int): the maximum number of turns the agent will run
|
||||
before it is forced to give a final answer. Defaults to 2.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
llm: Union[BaseModel, BaseAPIModel],
|
||||
action_executor: ActionExecutor,
|
||||
protocol: ReACTProtocol = ReACTProtocol(),
|
||||
max_turn: int = 2) -> None:
|
||||
self.max_turn = max_turn
|
||||
super().__init__(
|
||||
llm=llm, action_executor=action_executor, protocol=protocol)
|
||||
|
||||
def chat(self, message: str) -> AgentReturn:
|
||||
self._inner_history = []
|
||||
self._inner_history.append(dict(role='user', content=message))
|
||||
agent_return = AgentReturn()
|
||||
force_stop = False
|
||||
default_response = '对不起,我无法回答你的问题'
|
||||
for turn in range(self.max_turn):
|
||||
prompt = self._protocol.format(
|
||||
chat_history=self.session_history,
|
||||
inner_step=self._inner_history,
|
||||
action_executor=self._action_executor,
|
||||
force_stop=force_stop)
|
||||
response = self._llm.generate_from_template(prompt, 512)
|
||||
self._inner_history.append(
|
||||
dict(role='assistant', content=response))
|
||||
thought, action, action_input = self._protocol.parse(
|
||||
response, self._action_executor)
|
||||
action_return: ActionReturn = self._action_executor(
|
||||
action, action_input)
|
||||
action_return.thought = thought
|
||||
agent_return.actions.append(action_return)
|
||||
if action_return.type == self._action_executor.finish_action.name:
|
||||
agent_return.response = action_return.result['text']
|
||||
return agent_return
|
||||
self._inner_history.append(
|
||||
dict(
|
||||
role='system',
|
||||
content=self._protocol.format_response(action_return)))
|
||||
if turn == self.max_turn - 1:
|
||||
force_stop = True
|
||||
agent_return.response = default_response
|
||||
# only append the user and final response
|
||||
self._session_history.append(dict(role='user', content=message))
|
||||
self._session_history.append(
|
||||
dict(role='assistant', content=agent_return.response))
|
||||
return agent_return
|
||||
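The protocol's begin markers are plain keyword arguments, so swapping in custom delimiters or a different backend only touches the constructor. A usage sketch; the checkpoint name and the empty tool list below are placeholders, not values used in this commit:

```python
from lagent.actions import ActionExecutor
from lagent.agents import ReACT
from lagent.agents.react import ReACTProtocol
from lagent.llms import HFTransformerCasualLM

llm = HFTransformerCasualLM('internlm/internlm-chat-7b')   # example checkpoint only
agent = ReACT(
    llm=llm,
    action_executor=ActionExecutor(actions=[]),            # assumed signature, no tools yet
    protocol=ReACTProtocol(finish=dict(role='FINISH', begin='Final:', end='\n')),
    max_turn=3)
print(agent.chat('What is 1 + 1?').response)
```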
232
lagent/agents/rewoo.py
Normal file
@@ -0,0 +1,232 @@
|
||||
import re
|
||||
import warnings
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
from lagent.actions import ActionExecutor
|
||||
from lagent.llms.base_api import BaseAPIModel
|
||||
from lagent.llms.base_llm import BaseModel
|
||||
from lagent.schema import ActionReturn, ActionStatusCode, AgentReturn
|
||||
from .base_agent import BaseAgent
|
||||
|
||||
PLANER_PROMPT = """你是一个任务分解器, 你需要将用户的问题拆分成多个简单的子任务。
|
||||
请拆分出多个子任务项,从而能够得到充分的信息以解决问题, 返回格式如下:
|
||||
```
|
||||
Plan: 当前子任务要解决的问题
|
||||
#E[id] = 工具名称[工具参数]
|
||||
Plan: 当前子任务要解决的问题
|
||||
#E[id] = 工具名称[工具参数]
|
||||
```
|
||||
其中
|
||||
1. #E[id] 用于存储Plan id的执行结果, 可被用作占位符。
|
||||
2. 每个 #E[id] 所执行的内容应与当前Plan解决的问题严格对应。
|
||||
3. 工具参数可以是正常输入text, 或是 #E[依赖的索引], 或是两者都可以。
|
||||
4. 工具名称必须从以下工具中选择:
|
||||
{tool_description}
|
||||
注意: 每个Plan后有且仅能跟随一个#E。
|
||||
开始!"""
|
||||
|
||||
WORKER_PROMPT = """
|
||||
Thought: {thought}\nResponse: {action_resp}\n
|
||||
"""
|
||||
|
||||
SOLVER_PROMPT = """解决接下来的任务或者问题。为了帮助你,我们提供了一些相关的计划
|
||||
和相应的解答。注意其中一些信息可能存在噪声,因此你需要谨慎的使用它们。\n
|
||||
{question}\n{worker_log}\n现在开始回答这个任务或者问题。请直接回答这个问题,
|
||||
不要包含其他不需要的文字。{question}\n
|
||||
"""
|
||||
|
||||
|
||||
class ReWOOProtocol:
|
||||
"""A wrapper of ReWOO prompt which manages the response from LLM and
|
||||
generates desired prompts in a ReWOO format.
|
||||
|
||||
Args:
|
||||
planner_prompt (str): prompt template for planner
|
||||
solver_prompt (str): prompt template for solver
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
planner_prompt: str = PLANER_PROMPT,
|
||||
worker_prompt: str = WORKER_PROMPT,
|
||||
solver_prompt: str = SOLVER_PROMPT,
|
||||
) -> None:
|
||||
self.planner_prompt = planner_prompt
|
||||
self.worker_prompt = worker_prompt
|
||||
self.solver_prompt = solver_prompt
|
||||
|
||||
def format_planner(self,
|
||||
chat_history: List[Dict],
|
||||
inner_step: List[Dict],
|
||||
action_executor: ActionExecutor,
|
||||
reformat_request: Optional[str] = '') -> List[Dict]:
|
||||
"""Generate the planner prompt required by ReWOO.
|
||||
|
||||
Args:
|
||||
chat_history (List[Dict]): The history log in previous runs.
|
||||
inner_step (List[Dict]): The log in the current run.
|
||||
action_executor (ActionExecutor): the action manager to execute
|
||||
actions.
|
||||
reformat_request (str): the error feedback if the LLM fails to
|
||||
generate the required format for the planner.
|
||||
|
||||
Returns:
|
||||
List[Dict]: ReWOO format prompt for planner.
|
||||
"""
|
||||
planner_prompt = self.planner_prompt.format(
|
||||
tool_description=action_executor.get_actions_info(), )
|
||||
formatted = []
|
||||
formatted.append(dict(role='system', content=planner_prompt))
|
||||
formatted += chat_history
|
||||
formatted += inner_step
|
||||
if reformat_request != '':
|
||||
formatted.append(
|
||||
dict(
|
||||
role='system',
|
||||
content='回答格式错误: %s. 请重新回答: ' % reformat_request))
|
||||
return formatted
|
||||
|
||||
def parse_worker(
|
||||
self,
|
||||
message: str,
|
||||
) -> Tuple[List[str], List[str], List[str]]:
|
||||
"""Parse the LLM generated planner response and convert it into the
|
||||
worker format.
|
||||
|
||||
Args:
|
||||
message (str): The response from LLM with ReWOO planner format.
|
||||
|
||||
Returns:
|
||||
tuple: the return value is a tuple that contains:
|
||||
- thought_list (List(str)): contain LLM thoughts of the user
|
||||
request.
|
||||
- action_list (List(str)): contain actions scheduled by LLM.
|
||||
- action_input_list (List(str)): contain the required action
|
||||
input for above actions.
|
||||
"""
|
||||
action_list = []
|
||||
action_input_list = []
|
||||
thought_list = []
|
||||
thoughts = re.findall('Plan: (.+)', message)
|
||||
action_units = re.findall('#E[0-9]* = (.+)', message)
|
||||
assert len(thoughts) == len(action_units), \
|
||||
'Each Plan should correspond to only ONE action'
|
||||
for thought, action_unit in zip(thoughts, action_units):
|
||||
action_name, action_input = re.findall(r'(.*?)\[(.*?)\]',
|
||||
action_unit.strip())[0]
|
||||
action_list.append(action_name.strip())
|
||||
action_input_list.append(action_input.strip())
|
||||
thought_list.append(thought.strip())
|
||||
return thought_list, action_list, action_input_list
|
||||
|
||||
def format_solver(
|
||||
self, question: str, thought_list: List[str],
|
||||
action_return_list: List[ActionReturn]) -> Tuple[str, str]:
|
||||
"""Generate the prompt for solver in a ReWOO format.
|
||||
|
||||
Args:
|
||||
question (str): The user request in the current run.
|
||||
thought_list (List[str]): thoughts generated from LLM for
|
||||
each action.
|
||||
action_return_list (List[ActionReturn]): action returns
|
||||
from workers.
|
||||
|
||||
Returns:
|
||||
tuple: the return value is a tuple that contains:
|
||||
- solver_prompt (str): the generated prompt for solver
|
||||
in a ReWOO format.
|
||||
- worker_log (str): contain action responses from workers.
|
||||
Used for inner log.
|
||||
"""
|
||||
worker_log = ''
|
||||
for thought, action_return in zip(thought_list, action_return_list):
|
||||
if action_return.state == ActionStatusCode.SUCCESS:
|
||||
action_resp = action_return.result['text']
|
||||
else:
|
||||
action_resp = action_return.errmsg
|
||||
worker_response = self.worker_prompt.format(
|
||||
thought=thought, action_resp=action_resp)
|
||||
worker_log += worker_response
|
||||
solver_prompt = self.solver_prompt.format(
|
||||
question=question, worker_log=worker_log)
|
||||
return solver_prompt, worker_log
|
||||
|
||||
|
||||
class ReWOO(BaseAgent):
|
||||
"""An implementation of ReWOO (https://arxiv.org/abs/2305.18323)
|
||||
|
||||
Args:
|
||||
llm (BaseModel or BaseAPIModel): a LLM service which can chat
|
||||
and act as planner / solver.
|
||||
action_executor (ActionExecutor): an action executor to manage
|
||||
all actions and their response.
|
||||
protocol (ReWOOProtocol): a wrapper to generate prompt and
|
||||
parse the response from LLM / actions.
|
||||
max_turn (int): the maximum number of trials for LLM to generate
|
||||
plans that can be successfully parsed by ReWOO protocol.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
llm: Union[BaseModel, BaseAPIModel],
|
||||
action_executor: ActionExecutor,
|
||||
protocol: ReWOOProtocol = ReWOOProtocol(),
|
||||
max_turn: int = 2) -> None:
|
||||
super().__init__(
|
||||
llm=llm, action_executor=action_executor, protocol=protocol)
|
||||
|
||||
self.max_turn = max_turn
|
||||
|
||||
def chat(self, message: str) -> AgentReturn:
|
||||
self._inner_history = []
|
||||
self._inner_history.append(dict(role='user', content=message))
|
||||
agent_return = AgentReturn()
|
||||
|
||||
# planner
|
||||
turn_id = 0
|
||||
reformat_request = ''
|
||||
while turn_id < self.max_turn:
|
||||
planner_prompt = self._protocol.format_planner(
|
||||
chat_history=self.session_history,
|
||||
inner_step=self._inner_history,
|
||||
action_executor=self._action_executor,
|
||||
reformat_request=reformat_request)
|
||||
response = self._llm.generate_from_template(planner_prompt, 512)
|
||||
self._inner_history.append(
|
||||
dict(role='assistant', content=response))
|
||||
try:
|
||||
thoughts, actions, actions_input = self._protocol.parse_worker(
|
||||
response)
|
||||
break
|
||||
except Exception as e:
|
||||
turn_id += 1
|
||||
reformat_request = str(e)
|
||||
|
||||
if turn_id >= self.max_turn:
|
||||
warnings.warn('\nUnable to parse LLM responses in %d turns, '
|
||||
'directly request solver for question answer...' %
|
||||
self.max_turn)
|
||||
actions = []
|
||||
thoughts = []
|
||||
action_responses = []
|
||||
# workers
|
||||
action_responses = []
|
||||
for action_id in range(len(actions)):
|
||||
# we need to change actions_input inplace
|
||||
prev_ptrs = re.findall(r'#E\d+', actions_input[action_id])
|
||||
for prev_ptr in prev_ptrs:
|
||||
ptr_num = int(prev_ptr.strip('#E')) - 1 # start from 0
|
||||
actions_input[action_id] = actions_input[action_id].replace(
|
||||
prev_ptr, action_responses[ptr_num].result['text'])
|
||||
action_return: ActionReturn = self._action_executor(
|
||||
actions[action_id], actions_input[action_id])
|
||||
action_responses.append(action_return)
|
||||
|
||||
solver_prompt, worker_log = self._protocol.format_solver(
|
||||
message, thoughts, action_responses)
|
||||
self._inner_history.append(dict(role='system', content=worker_log))
|
||||
|
||||
final_response = self._llm.generate_from_template(solver_prompt, 512)
|
||||
self._inner_history.append(
|
||||
dict(role='assistant', content=final_response))
|
||||
agent_return.response = final_response
|
||||
return agent_return
|
||||
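The planner must emit alternating `Plan:` / `#E[id] = Tool[args]` lines; the snippet below reproduces the regular expressions from `parse_worker` on a well-formed plan so the expected format is easy to see (the tool names are made up):

```python
import re

plan = ('Plan: look up the weather for tomorrow\n'
        '#E1 = WeatherQuery[tomorrow]\n'
        'Plan: turn the forecast into travel advice\n'
        '#E2 = TravelAdvisor[#E1]\n')

thoughts = re.findall('Plan: (.+)', plan)
action_units = re.findall('#E[0-9]* = (.+)', plan)
for thought, unit in zip(thoughts, action_units):
    name, args = re.findall(r'(.*?)\[(.*?)\]', unit.strip())[0]
    print(f'{thought!r} -> {name.strip()}({args.strip()})')
# 'look up the weather for tomorrow' -> WeatherQuery(tomorrow)
# 'turn the forecast into travel advice' -> TravelAdvisor(#E1)
```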
10
lagent/llms/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from lagent.utils import is_module_exist
|
||||
from .base_api import BaseAPIModel
|
||||
from .base_llm import BaseModel
|
||||
from .openai import GPTAPI
|
||||
|
||||
__all__ = ['BaseModel', 'BaseAPIModel', 'GPTAPI']
|
||||
|
||||
if is_module_exist('transformers'):
|
||||
from .huggingface import HFTransformer, HFTransformerCasualLM # noqa: F401
|
||||
__all__.extend(['HFTransformer', 'HFTransformerCasualLM'])
|
||||
240
lagent/llms/base_api.py
Normal file
@@ -0,0 +1,240 @@
|
||||
import re
|
||||
import threading
|
||||
import warnings
|
||||
from abc import abstractclassmethod
|
||||
from time import sleep
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
from .base_llm import BaseModel
|
||||
|
||||
|
||||
class BaseAPIModel(BaseModel):
|
||||
"""Base class for API model wrapper.
|
||||
|
||||
Args:
|
||||
model_type (str): The type of model.
|
||||
query_per_second (int): The maximum queries allowed per second
|
||||
between two consecutive calls of the API. Defaults to 1.
|
||||
retry (int): Number of retries if the API call fails. Defaults to 2.
|
||||
max_seq_len (int): The maximum sequence length of the model. Defaults
|
||||
to 2048.
|
||||
meta_template (Dict, optional): The model's meta prompt
|
||||
template if needed, in case the requirement of injecting or
|
||||
wrapping of any meta instructions.
|
||||
"""
|
||||
|
||||
is_api: bool = True
|
||||
|
||||
def __init__(self,
|
||||
model_type: str,
|
||||
query_per_second: int = 1,
|
||||
retry: int = 2,
|
||||
max_seq_len: int = 2048,
|
||||
meta_template: Optional[Dict] = None):
|
||||
self.model_type = model_type
|
||||
self.max_seq_len = max_seq_len
|
||||
self.meta_template = meta_template
|
||||
self.retry = retry
|
||||
self.query_per_second = query_per_second
|
||||
self.token_bucket = TokenBucket(query_per_second)
|
||||
self.template_parser = APITemplateParser(meta_template)
|
||||
|
||||
@abstractclassmethod
|
||||
def generate(self, inputs, max_out_len: int) -> List[str]:
|
||||
"""Generate results given a list of inputs.
|
||||
|
||||
Args:
|
||||
inputs (List[str or list]): A list of strings or PromptDicts.
|
||||
The PromptDict should be organized in OpenCompass'
|
||||
API format.
|
||||
max_out_len (int): The maximum length of the output.
|
||||
|
||||
Returns:
|
||||
List[str]: A list of generated strings.
|
||||
"""
|
||||
|
||||
def get_token_len(self, prompt: str) -> int:
|
||||
"""Get lengths of the tokenized string. Only English and Chinese
|
||||
characters are counted for now. Users are encouraged to override this
|
||||
method if more accurate length is needed.
|
||||
|
||||
Args:
|
||||
prompt (str): Input string.
|
||||
|
||||
Returns:
|
||||
int: Length of the input tokens
|
||||
"""
|
||||
|
||||
english_parts = re.findall(r'[A-Za-z0-9]+', prompt)
|
||||
chinese_parts = re.findall(r'[\u4e00-\u9FFF]+', prompt)
|
||||
|
||||
# Count English words
|
||||
english_count = sum(len(part.split()) for part in english_parts)
|
||||
|
||||
# Count Chinese words
|
||||
chinese_count = sum(len(part) for part in chinese_parts)
|
||||
|
||||
return english_count + chinese_count
|
||||
|
||||
def wait(self):
|
||||
"""Wait till the next query can be sent.
|
||||
|
||||
Applicable in both single-thread and multi-thread environments.
|
||||
"""
|
||||
return self.token_bucket.get_token()
|
||||
|
||||
def to(self, device):
|
||||
pass
|
||||
|
||||
|
||||
class APITemplateParser:
|
||||
"""Intermidate prompt template parser, specifically for API models.
|
||||
|
||||
Args:
|
||||
meta_template (Dict): The meta template for the model.
|
||||
"""
|
||||
|
||||
def __init__(self, meta_template: Optional[Dict] = None):
|
||||
self.meta_template = meta_template
|
||||
# Check meta template
|
||||
if meta_template:
|
||||
assert isinstance(meta_template, list)
|
||||
self.roles: Dict[str, dict] = dict() # maps role name to config
|
||||
for item in meta_template:
|
||||
assert isinstance(item, dict)
|
||||
assert item['role'] not in self.roles, \
|
||||
'role in meta prompt must be unique!'
|
||||
self.roles[item['role']] = item.copy()
|
||||
|
||||
def parse_template(self, dialog: List[Union[str, List]]):
|
||||
"""Parse the intermidate prompt template, and wrap it with meta
|
||||
template if applicable. When the meta template is set and the input is
|
||||
a list, the return value will be a list containing the full
|
||||
conversation history. Each item looks like:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
{'role': 'user', 'content': '...'}
|
||||
|
||||
Args:
|
||||
dialog (List[str or list]): An intermediate prompt
|
||||
template (potentially before being wrapped by meta template).
|
||||
|
||||
Returns:
|
||||
List[str or list]: The finalized prompt or a conversation.
|
||||
"""
|
||||
assert isinstance(dialog, (str, list))
|
||||
if isinstance(dialog, str):
|
||||
return dialog
|
||||
if self.meta_template:
|
||||
|
||||
prompt = list()
|
||||
# Whether to keep generating the prompt
|
||||
generate = True
|
||||
for i, item in enumerate(dialog):
|
||||
if not generate:
|
||||
break
|
||||
if isinstance(item, str):
|
||||
if item.strip():
|
||||
# TODO: logger
|
||||
warnings.warn('Non-empty string in prompt template '
|
||||
'will be ignored in API models.')
|
||||
else:
|
||||
api_prompts = self._prompt2api(item)
|
||||
prompt.append(api_prompts)
|
||||
|
||||
# merge the consecutive prompts assigned to the same role
|
||||
new_prompt = list([prompt[0]])
|
||||
last_role = prompt[0]['role']
|
||||
for item in prompt[1:]:
|
||||
if item['role'] == last_role:
|
||||
new_prompt[-1]['content'] += '\n' + item['content']
|
||||
else:
|
||||
last_role = item['role']
|
||||
new_prompt.append(item)
|
||||
prompt = new_prompt
|
||||
|
||||
else:
|
||||
# in case the model does not have any meta template
|
||||
prompt = ''
|
||||
last_sep = ''
|
||||
for item in dialog:
|
||||
if isinstance(item, str):
|
||||
if item:
|
||||
prompt += last_sep + item
|
||||
elif item.get('content', ''):
|
||||
prompt += last_sep + item.get('content', '')
|
||||
last_sep = '\n'
|
||||
return prompt
|
||||
|
||||
def _prompt2api(self, prompts: Union[List, str]) -> Tuple[str, bool]:
|
||||
"""Convert the prompts to a API-style prompts, given an updated
|
||||
role_dict.
|
||||
|
||||
Args:
|
||||
prompts (Union[List, str]): The prompts to be converted.
|
||||
|
||||
Returns:
|
||||
Tuple[str, bool]: The converted string, and whether the follow-up
|
||||
conversion should proceed.
|
||||
"""
|
||||
if isinstance(prompts, str):
|
||||
return prompts
|
||||
elif isinstance(prompts, dict):
|
||||
api_role = self._role2api_role(prompts)
|
||||
return api_role
|
||||
|
||||
res = []
|
||||
for prompt in prompts:
|
||||
if isinstance(prompt, str):
|
||||
raise TypeError('Mixing str without explicit role is not '
|
||||
'allowed in API models!')
|
||||
else:
|
||||
api_role = self._role2api_role(prompt)
|
||||
res.append(api_role)
|
||||
return res
|
||||
|
||||
def _role2api_role(self, role_prompt: Dict) -> Tuple[str, bool]:
|
||||
|
||||
merged_prompt = self.roles.get(
|
||||
role_prompt['role'],
|
||||
self.roles.get(
|
||||
self.roles[role_prompt['role']].get('fallback_role')))
|
||||
res = {}
|
||||
res['role'] = merged_prompt['api_role']
|
||||
res['content'] = merged_prompt.get('begin', '')
|
||||
res['content'] += role_prompt.get('content', '')
|
||||
res['content'] += merged_prompt.get('end', '')
|
||||
return res
|
||||
|
||||
|
||||
class TokenBucket:
|
||||
"""A token bucket for rate limiting.
|
||||
|
||||
Args:
|
||||
rate (float): The rate of the token bucket.
|
||||
"""
|
||||
|
||||
def __init__(self, rate: float) -> None:
|
||||
self._rate = rate
|
||||
self._tokens = threading.Semaphore(0)
|
||||
self.started = False
|
||||
|
||||
def _add_tokens(self):
|
||||
"""Add tokens to the bucket."""
|
||||
while True:
|
||||
if self._tokens._value < self._rate:
|
||||
self._tokens.release()
|
||||
sleep(1 / self._rate)
|
||||
|
||||
def get_token(self):
|
||||
"""Get a token from the bucket."""
|
||||
if not self.started:
|
||||
self.started = True
|
||||
threading.Thread(target=self._add_tokens, daemon=True).start()
|
||||
self._tokens.acquire()
|
||||
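The meta-template handling is easiest to see in isolation: given a role table, `parse_template` maps each dialog item onto its `api_role` and merges consecutive messages that share a role. A self-contained check:

```python
from lagent.llms.base_api import APITemplateParser

parser = APITemplateParser([
    dict(role='system', api_role='system'),
    dict(role='user', api_role='user'),
    dict(role='assistant', api_role='assistant'),
])

dialog = [
    dict(role='system', content='You are a helpful assistant.'),
    dict(role='system', content='Answer concisely.'),   # merged into the turn above
    dict(role='user', content='Hi!'),
]
print(parser.parse_template(dialog))
# [{'role': 'system', 'content': 'You are a helpful assistant.\nAnswer concisely.'},
#  {'role': 'user', 'content': 'Hi!'}]
```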
146
lagent/llms/base_llm.py
Normal file
@@ -0,0 +1,146 @@
|
||||
from abc import abstractclassmethod
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
|
||||
class BaseModel:
|
||||
"""Base class for model wrapper.
|
||||
|
||||
Args:
|
||||
path (str): The path to the model.
|
||||
max_seq_len (int): The maximum sequence length of the model. Defaults
|
||||
to 2048.
|
||||
tokenizer_only (bool): If True, only the tokenizer will be initialized.
|
||||
Defaults to False.
|
||||
meta_template (Dict, optional): The model's meta prompt
|
||||
template if needed, in case the requirement of injecting or
|
||||
wrapping of any meta instructions.
|
||||
"""
|
||||
|
||||
is_api: bool = False
|
||||
|
||||
def __init__(self,
|
||||
path: str,
|
||||
max_seq_len: int = 2048,
|
||||
tokenizer_only: bool = False,
|
||||
meta_template: Optional[Dict] = None):
|
||||
self.path = path
|
||||
self.max_seq_len = max_seq_len
|
||||
self.tokenizer_only = tokenizer_only
|
||||
# meta template
|
||||
self.template_parser = LMTemplateParser(meta_template)
|
||||
self.eos_token_id = None
|
||||
if meta_template and 'eos_token_id' in meta_template:
|
||||
self.eos_token_id = meta_template['eos_token_id']
|
||||
|
||||
@abstractclassmethod
|
||||
def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
|
||||
"""Generate results given a list of inputs.
|
||||
|
||||
Args:
|
||||
inputs (List[str]): A list of strings.
|
||||
max_out_len (int): The maximum length of the output.
|
||||
|
||||
Returns:
|
||||
List[str]: A list of generated strings.
|
||||
"""
|
||||
|
||||
def parse_template(self, dialog) -> str:
|
||||
"""Parse a prompt template, and wrap it with meta template if
|
||||
applicable.
|
||||
|
||||
Args:
|
||||
dialog (List[str or PromptList]): A prompt
|
||||
template (potentially before being wrapped by meta template).
|
||||
|
||||
Returns:
|
||||
str: The final string.
|
||||
"""
|
||||
return self.template_parser.parse_template(dialog)
|
||||
|
||||
def generate_from_template(self, templates, max_out_len: int, **kwargs):
|
||||
"""Generate completion from a list of templates.
|
||||
|
||||
Args:
|
||||
templates (List[PromptType]): A list of templates.
|
||||
max_out_len (int): The maximum length of the output.
|
||||
"""
|
||||
inputs = self.parse_template(templates)
|
||||
return self.generate(inputs, max_out_len=max_out_len, **kwargs)
|
||||
|
||||
def to(self, device):
|
||||
self.model.to(device)
|
||||
|
||||
|
||||
class LMTemplateParser:
|
||||
"""Intermidate prompt template parser, specifically for language models.
|
||||
|
||||
Args:
|
||||
meta_template (Dict): The meta template for the model.
|
||||
"""
|
||||
|
||||
def __init__(self, meta_template: Optional[Dict] = None):
|
||||
self.meta_template = meta_template
|
||||
if meta_template:
|
||||
assert isinstance(meta_template, list)
|
||||
self.roles: Dict[str, dict] = dict() # maps role name to config
|
||||
for item in meta_template:
|
||||
assert isinstance(item, dict)
|
||||
assert item['role'] not in self.roles, \
|
||||
'role in meta prompt must be unique!'
|
||||
self.roles[item['role']] = item.copy()
|
||||
|
||||
def parse_template(self, dialog) -> str:
|
||||
"""Parse a prompt template, and wrap it with meta template if
|
||||
applicable.
|
||||
|
||||
Args:
|
||||
dialog (List[str or PromptList]): A prompt
|
||||
template (potentially before being wrapped by meta template).
|
||||
|
||||
Returns:
|
||||
str: The final string.
|
||||
"""
|
||||
assert isinstance(dialog, (str, list))
|
||||
if isinstance(dialog, str):
|
||||
return dialog
|
||||
if self.meta_template:
|
||||
|
||||
prompt = ''
|
||||
for index, item in enumerate(dialog):
|
||||
if isinstance(item, str):
|
||||
prompt += item
|
||||
else:
|
||||
new_str = self._prompt2str(item, index == len(dialog) - 1)
|
||||
prompt += new_str
|
||||
else:
|
||||
# in case the model does not have any meta template
|
||||
prompt = ''
|
||||
last_sep = ''
|
||||
for item in dialog:
|
||||
if isinstance(item, str):
|
||||
if item:
|
||||
prompt += last_sep + item
|
||||
elif item.get('content', ''):
|
||||
prompt += last_sep + item.get('content', '')
|
||||
last_sep = '\n'
|
||||
return prompt
|
||||
|
||||
def _prompt2str(self,
|
||||
prompt: Union[List, str, Dict],
|
||||
last: bool = False) -> Tuple[str, bool]:
|
||||
if isinstance(prompt, str):
|
||||
return prompt
|
||||
merged_prompt = self.roles.get(
|
||||
prompt['role'],
|
||||
self.roles.get(self.roles[prompt['role']].get('fallback_role')))
|
||||
res = merged_prompt.get('begin', '')
|
||||
if last and merged_prompt.get('generate', False):
|
||||
res += prompt.get('content', '')
|
||||
return res
|
||||
res += prompt.get('content', '') + merged_prompt.get('end', '')
|
||||
if last and merged_prompt['role'] != 'assistant':
|
||||
res += self.roles['assistant']['begin']
|
||||
return res
|
||||
return res
|
||||
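For local models the dialog is flattened into a single string instead: every role contributes its `begin`/`end` markers, and when the last item is not from the assistant, the assistant's `begin` is appended so the model continues from there. A sketch with made-up markers:

```python
from lagent.llms.base_llm import LMTemplateParser

parser = LMTemplateParser([
    dict(role='system', begin='<|System|>:', end='\n'),
    dict(role='user', begin='<|User|>:', end='\n'),
    dict(role='assistant', begin='<|Bot|>:', end='\n', generate=True),
])

prompt = parser.parse_template([
    dict(role='system', content='You are a helpful assistant.'),
    dict(role='user', content='Hi!'),
])
print(prompt)
# <|System|>:You are a helpful assistant.
# <|User|>:Hi!
# <|Bot|>:
```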
128
lagent/llms/huggingface.py
Normal file
@@ -0,0 +1,128 @@
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from .base_llm import BaseModel
|
||||
|
||||
|
||||
class HFTransformer(BaseModel):
|
||||
"""Model wrapper around HuggingFace general models.
|
||||
|
||||
Adapted from OpenCompass (https://github.com/InternLM/opencompass
|
||||
/blob/main/opencompass/models/huggingface.py)
|
||||
|
||||
Args:
|
||||
path (str): The name or path to HuggingFace's model.
|
||||
max_seq_len (int): The maximum length of the input sequence. Defaults
|
||||
to 2048.
|
||||
tokenizer_path (str): The path to the tokenizer. Defaults to None.
|
||||
tokenizer_kwargs (dict): Keyword arguments for the tokenizer.
|
||||
Defaults to {}.
|
||||
tokenizer_only (bool): If True, only the tokenizer will be initialized.
|
||||
Defaults to False.
|
||||
model_kwargs (dict): Keyword arguments for the model, used in loader.
|
||||
Defaults to dict(device_map='auto').
|
||||
meta_template (Dict, optional): The model's meta prompt
|
||||
template if needed, in case the requirement of injecting or
|
||||
wrapping of any meta instructions.
|
||||
extract_pred_after_decode (bool): Whether to extract the prediction
|
||||
string from the decoded output string, instead of extract the
|
||||
prediction tokens before decoding. Defaults to False.
|
||||
batch_padding (bool): If False, inference with be performed in for-loop
|
||||
without batch padding.
|
||||
|
||||
Note:
|
||||
About ``extract_pred_after_decode``: Commonly, we should extract the
|
||||
the prediction tokens before decoding. But for some tokenizers using
|
||||
``sentencepiece``, like LLaMA, this behavior may change the number of
|
||||
whitespaces, which is harmful for Python programming tasks.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
path: str,
|
||||
max_seq_len: int = 2048,
|
||||
tokenizer_path: Optional[str] = None,
|
||||
tokenizer_kwargs: dict = dict(),
|
||||
tokenizer_only: bool = False,
|
||||
model_kwargs: dict = dict(device_map='auto'),
|
||||
meta_template: Optional[Dict] = None,
|
||||
extract_pred_after_decode: bool = False,
|
||||
batch_padding: bool = False):
|
||||
super().__init__(
|
||||
path=path,
|
||||
max_seq_len=max_seq_len,
|
||||
tokenizer_only=tokenizer_only,
|
||||
meta_template=meta_template)
|
||||
self._load_tokenizer(
|
||||
path=path,
|
||||
tokenizer_path=tokenizer_path,
|
||||
tokenizer_kwargs=tokenizer_kwargs)
|
||||
self.batch_padding = batch_padding
|
||||
self.extract_pred_after_decode = extract_pred_after_decode
|
||||
if not tokenizer_only:
|
||||
self._load_model(path=path, model_kwargs=model_kwargs)
|
||||
|
||||
def _load_tokenizer(self, path: str, tokenizer_path: Optional[str],
|
||||
tokenizer_kwargs: dict):
|
||||
from transformers import AutoTokenizer
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
tokenizer_path if tokenizer_path else path,
|
||||
trust_remote_code=True,
|
||||
**tokenizer_kwargs)
|
||||
if self.tokenizer.pad_token_id is None:
|
||||
self.tokenizer.pad_token = self.tokenizer.eos_token
|
||||
|
||||
def _load_model(self, path: str, model_kwargs: dict):
|
||||
from transformers import AutoModel
|
||||
model_kwargs.setdefault('torch_dtype', torch.float16)
|
||||
self.model = AutoModel.from_pretrained(
|
||||
path, trust_remote_code=True, **model_kwargs)
|
||||
self.model.eval()
|
||||
|
||||
def generate(self, inputs: List[str], max_out_len: int,
|
||||
**kwargs) -> List[str]:
|
||||
if isinstance(inputs, str):
|
||||
inputs = [inputs]
|
||||
if self.extract_pred_after_decode:
|
||||
prompt_lens = [len(input_) for input_ in inputs]
|
||||
|
||||
input_ids = self.tokenizer(
|
||||
inputs, truncation=True,
|
||||
max_length=self.max_seq_len - max_out_len)['input_ids']
|
||||
input_ids = torch.tensor(input_ids, device=self.model.device)
|
||||
outputs = self.model.generate(
|
||||
input_ids, max_new_tokens=max_out_len, **kwargs)
|
||||
|
||||
if not self.extract_pred_after_decode:
|
||||
outputs = outputs[:, input_ids.shape[1]:]
|
||||
|
||||
decodeds = self.tokenizer.batch_decode(
|
||||
outputs, skip_special_tokens=True)
|
||||
if self.extract_pred_after_decode:
|
||||
decodeds = [
|
||||
token[len_:] for token, len_ in zip(decodeds, prompt_lens)
|
||||
]
|
||||
|
||||
return decodeds[0]
|
||||
|
||||
def generate_from_template(self, templates, max_out_len: int, **kwargs):
|
||||
"""Generate completion from a list of templates.
|
||||
|
||||
Args:
|
||||
templates (List[PromptType]): A list of templates.
|
||||
max_out_len (int): The maximum length of the output.
|
||||
"""
|
||||
inputs = self.parse_template(templates)
|
||||
response = self.generate(inputs, max_out_len=max_out_len, **kwargs)
|
||||
return response.replace(
|
||||
self.template_parser.roles['assistant']['end'].strip(),
|
||||
'').strip()
|
||||
|
||||
|
||||
class HFTransformerCasualLM(HFTransformer):
|
||||
|
||||
def _load_model(self, path: str, model_kwargs: dict):
|
||||
from transformers import AutoModelForCausalLM
|
||||
model_kwargs.setdefault('torch_dtype', torch.float16)
|
||||
self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs)
|
||||
self.model.eval()
|
||||
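A loading sketch for the causal-LM wrapper; the checkpoint name and the chat markers below are illustrative assumptions, not values taken from this commit. Note that the assistant role needs an `end` marker because `generate_from_template` strips it from the decoded output.

```python
from lagent.llms.huggingface import HFTransformerCasualLM

llm = HFTransformerCasualLM(
    path='internlm/internlm-chat-7b',                 # any causal-LM checkpoint works
    meta_template=[
        dict(role='system', begin='<|System|>:', end='\n'),
        dict(role='user', begin='<|User|>:', end='\n'),
        dict(role='assistant', begin='<|Bot|>:', end='<eoa>\n', generate=True),
    ])
print(llm.generate_from_template([dict(role='user', content='Hello!')], 64))
```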
242
lagent/llms/openai.py
Normal file
@@ -0,0 +1,242 @@
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
# from concurrent.futures import ThreadPoolExecutor
|
||||
from logging import getLogger
|
||||
from threading import Lock
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import requests
|
||||
|
||||
from .base_api import BaseAPIModel
|
||||
|
||||
OPENAI_API_BASE = 'https://api.openai.com/v1/chat/completions'
|
||||
|
||||
|
||||
class GPTAPI(BaseAPIModel):
|
||||
"""Model wrapper around OpenAI's models.
|
||||
|
||||
Args:
|
||||
model_type (str): The name of OpenAI's model.
|
||||
max_seq_len (int): The maximum allowed sequence length of a model.
|
||||
Note that the length of prompt + generated tokens shall not exceed
|
||||
this value. Defaults to 4096.
|
||||
query_per_second (int): The maximum queries allowed per second
|
||||
between two consecutive calls of the API. Defaults to 1.
|
||||
retry (int): Number of retries if the API call fails. Defaults to 2.
|
||||
key (str or List[str]): OpenAI key(s). In particular, when it
|
||||
is set to "ENV", the key will be fetched from the environment
|
||||
variable $OPENAI_API_KEY, as how openai defaults to be. If it's a
|
||||
list, the keys will be used in round-robin manner. Defaults to
|
||||
'ENV'.
|
||||
org (str or List[str], optional): OpenAI organization(s). If not
|
||||
specified, OpenAI uses the default organization bound to each API
|
||||
key. If specified, the orgs will be posted with each request in
|
||||
round-robin manner. Defaults to None.
|
||||
meta_template (Dict, optional): The model's meta prompt
|
||||
template if needed, in case the requirement of injecting or
|
||||
wrapping of any meta instructions.
|
||||
openai_api_base (str): The base url of OpenAI's API. Defaults to
|
||||
'https://api.openai.com/v1/chat/completions'.
|
||||
temperature (float, optional): What sampling temperature to use.
|
||||
If not None, will override the temperature in the `generate()`
|
||||
call. Defaults to None.
|
||||
"""
|
||||
|
||||
is_api: bool = True
|
||||
|
||||
def __init__(self,
|
||||
model_type: str = 'gpt-3.5-turbo',
|
||||
max_seq_len: int = 4096,
|
||||
query_per_second: int = 1,
|
||||
retry: int = 2,
|
||||
key: Union[str, List[str]] = 'ENV',
|
||||
org: Optional[Union[str, List[str]]] = None,
|
||||
meta_template: Optional[Dict] = [
|
||||
dict(role='system', api_role='system'),
|
||||
dict(role='user', api_role='user'),
|
||||
dict(role='assistant', api_role='assistant')
|
||||
],
|
||||
openai_api_base: str = OPENAI_API_BASE,
|
||||
temperature: Optional[float] = None):
|
||||
|
||||
super().__init__(
|
||||
model_type=model_type,
|
||||
max_seq_len=max_seq_len,
|
||||
meta_template=meta_template,
|
||||
query_per_second=query_per_second,
|
||||
retry=retry)
|
||||
self.logger = getLogger(__name__)
|
||||
self.temperature = temperature
|
||||
|
||||
if isinstance(key, str):
|
||||
self.keys = [os.getenv('OPENAI_API_KEY') if key == 'ENV' else key]
|
||||
else:
|
||||
self.keys = key
|
||||
|
||||
# record invalid keys and skip them when requesting API
|
||||
# - keys have insufficient_quota
|
||||
self.invalid_keys = set()
|
||||
|
||||
self.key_ctr = 0
|
||||
if isinstance(org, str):
|
||||
self.orgs = [org]
|
||||
else:
|
||||
self.orgs = org
|
||||
self.org_ctr = 0
|
||||
self.url = openai_api_base
|
||||
self.model_type = model_type
|
||||
|
||||
# max num token for gpt-3.5-turbo is 4097
|
||||
context_window = 4096
|
||||
if '32k' in self.model_type:
|
||||
context_window = 32768
|
||||
elif '16k' in self.model_type:
|
||||
context_window = 16384
|
||||
elif 'gpt-4' in self.model_type:
|
||||
context_window = 8192
|
||||
self.context_window = context_window
|
||||
|
||||
def generate(
|
||||
self,
|
||||
inputs: Union[List, str],
|
||||
max_out_len: int = 512,
|
||||
temperature: float = 0.7,
|
||||
) -> List[str]:
|
||||
"""Generate results given a list of inputs.
|
||||
|
||||
Args:
|
||||
inputs (List[str or List]): A list of strings or PromptDicts.
|
||||
The PromptDict should be organized in OpenCompass'
|
||||
API format.
|
||||
max_out_len (int): The maximum length of the output.
|
||||
temperature (float): What sampling temperature to use,
|
||||
between 0 and 2. Higher values like 0.8 will make the output
|
||||
more random, while lower values like 0.2 will make it more
|
||||
focused and deterministic. Defaults to 0.7.
|
||||
|
||||
Returns:
|
||||
List[str]: A list of generated strings.
|
||||
"""
|
||||
if self.temperature is not None:
|
||||
temperature = self.temperature
|
||||
return self._generate(inputs, max_out_len, temperature)
|
||||
|
||||
def _generate(self, input: Union[str, List], max_out_len: int,
|
||||
temperature: float) -> str:
|
||||
"""Generate results given a list of inputs.
|
||||
|
||||
Args:
|
||||
inputs (str or List): A string or PromptDict.
|
||||
The PromptDict should be organized in OpenCompass'
|
||||
API format.
|
||||
max_out_len (int): The maximum length of the output.
|
||||
temperature (float): What sampling temperature to use,
|
||||
between 0 and 2. Higher values like 0.8 will make the output
|
||||
more random, while lower values like 0.2 will make it more
|
||||
focused and deterministic.
|
||||
|
||||
Returns:
|
||||
str: The generated string.
|
||||
"""
|
||||
assert isinstance(input, (str, list, dict))
|
||||
|
||||
if isinstance(input, str):
|
||||
messages = [{'role': 'user', 'content': input}]
|
||||
elif isinstance(input, dict):
|
||||
messages = [input]
|
||||
else:
|
||||
messages = input
|
||||
|
||||
# Hold out 100 tokens due to potential errors in tiktoken calculation
|
||||
max_out_len = min(
|
||||
max_out_len,
|
||||
self.context_window - self.get_token_len(str(input)) - 100)
|
||||
if max_out_len <= 0:
|
||||
return ''
|
||||
|
||||
max_num_retries = 0
|
||||
while max_num_retries < self.retry:
|
||||
self.wait()
|
||||
|
||||
with Lock():
|
||||
if len(self.invalid_keys) == len(self.keys):
|
||||
raise RuntimeError('All keys have insufficient quota.')
|
||||
|
||||
# find the next valid key
|
||||
while True:
|
||||
self.key_ctr += 1
|
||||
if self.key_ctr == len(self.keys):
|
||||
self.key_ctr = 0
|
||||
|
||||
if self.keys[self.key_ctr] not in self.invalid_keys:
|
||||
break
|
||||
|
||||
key = self.keys[self.key_ctr]
|
||||
|
||||
header = {
|
||||
'Authorization': f'Bearer {key}',
|
||||
'content-type': 'application/json',
|
||||
}
|
||||
|
||||
if self.orgs:
|
||||
with Lock():
|
||||
self.org_ctr += 1
|
||||
if self.org_ctr == len(self.orgs):
|
||||
self.org_ctr = 0
|
||||
header['OpenAI-Organization'] = self.orgs[self.org_ctr]
|
||||
|
||||
try:
|
||||
data = dict(
|
||||
model=self.model_type,
|
||||
messages=messages,
|
||||
max_tokens=max_out_len,
|
||||
n=1,
|
||||
stop=None,
|
||||
temperature=temperature,
|
||||
)
|
||||
raw_response = requests.post(
|
||||
self.url, headers=header, data=json.dumps(data))
|
||||
except requests.ConnectionError:
|
||||
print('Got connection error, retrying...')
|
||||
continue
|
||||
try:
|
||||
response = raw_response.json()
|
||||
except requests.JSONDecodeError:
|
||||
print('JsonDecode error, got', str(raw_response.content))
|
||||
continue
|
||||
try:
|
||||
return response['choices'][0]['message']['content'].strip()
|
||||
except KeyError:
|
||||
if 'error' in response:
|
||||
if response['error']['code'] == 'rate_limit_exceeded':
|
||||
time.sleep(1)
|
||||
continue
|
||||
elif response['error']['code'] == 'insufficient_quota':
|
||||
self.invalid_keys.add(key)
|
||||
self.logger.warn(f'insufficient_quota key: {key}')
|
||||
continue
|
||||
|
||||
print('Find error message in response: ',
|
||||
str(response['error']))
|
||||
max_num_retries += 1
|
||||
|
||||
raise RuntimeError('Calling OpenAI failed after retrying for '
|
||||
f'{max_num_retries} times. Check the logs for '
|
||||
'details.')
|
||||
|
||||
def get_token_len(self, prompt: str) -> int:
|
||||
"""Get lengths of the tokenized string. Only English and Chinese
|
||||
characters are counted for now. Users are encouraged to override this
|
||||
method if more accurate length is needed.
|
||||
|
||||
Args:
|
||||
prompt (str): Input string.
|
||||
|
||||
Returns:
|
||||
int: Length of the input tokens
|
||||
"""
|
||||
import tiktoken
|
||||
self.tiktoken = tiktoken
|
||||
enc = self.tiktoken.encoding_for_model(self.model_type)
|
||||
return len(enc.encode(prompt))
|
||||
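A quick call-path sketch for the wrapper above; in practice the key is usually left as `key='ENV'` with `$OPENAI_API_KEY` exported.

```python
from lagent.llms import GPTAPI

gpt = GPTAPI(model_type='gpt-3.5-turbo', key='ENV', retry=3)

# A plain string becomes a single user message; a list of role dicts is sent as-is.
print(gpt.generate('Say hi in one word.', max_out_len=16))
print(gpt.generate(
    [dict(role='system', content='Be terse.'),
     dict(role='user', content='What is 2 + 2?')],
    max_out_len=16))
```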
79
lagent/schema.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from lagent.utils import is_module_exist
|
||||
|
||||
|
||||
def enum_dict_factory(inputs):
|
||||
inputs = [(i[0], i[-1].value) if isinstance(i[-1], Enum) else i
|
||||
for i in inputs]
|
||||
return dict(inputs)
|
||||
|
||||
|
||||
def dataclass2dict(data):
|
||||
return asdict(data, dict_factory=enum_dict_factory)
|
||||
|
||||
|
||||
class ActionStatusCode(int, Enum):
|
||||
ING = 1
|
||||
SUCCESS = 0
|
||||
HTTP_ERROR = -1000 # http error
|
||||
ARGS_ERROR = -1001 # 参数错误
|
||||
API_ERROR = -1002 # 不知道的API错误
|
||||
|
||||
|
||||
class ActionValidCode(int, Enum):
|
||||
FINISH = 1
|
||||
OPEN = 0
|
||||
CLOSED = -1
|
||||
INVALID = -2
|
||||
ABSENT = -3 # NO ACTION
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActionReturn:
|
||||
args: Dict
|
||||
url: Optional[str] = None
|
||||
type: Optional[str] = None
|
||||
result: Optional[str] = None
|
||||
errmsg: Optional[str] = None
|
||||
state: Union[ActionStatusCode, int] = ActionStatusCode.SUCCESS
|
||||
thought: Optional[str] = None
|
||||
valid: Optional[ActionValidCode] = ActionValidCode.OPEN
|
||||
|
||||
|
||||
class AgentStatusCode(Enum):
|
||||
END = 0 # end of streaming
|
||||
STREAM_ING = 1 # response is in streaming
|
||||
SERVER_ERR = -1 # triton server's error
|
||||
SESSION_CLOSED = -2 # session has been closed
|
||||
SESSION_OUT_OF_LIMIT = -3 # request length out of limit
|
||||
CMD = 2 # return command
|
||||
SESSION_INVALID_ARG = -4 # invalid argument
|
||||
SESSION_READY = 3 # session is ready for inference
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentReturn:
|
||||
actions: List[ActionReturn] = field(default_factory=list)
|
||||
response: str = ''
|
||||
inner_steps: List = field(default_factory=list)
|
||||
errmsg: Optional[str] = None
|
||||
|
||||
|
||||
if is_module_exist('lmdeploy'):
|
||||
from lmdeploy.serve.turbomind.chatbot import StatusCode
|
||||
STATE_MAP = {
|
||||
StatusCode.TRITON_STREAM_END: AgentStatusCode.END,
|
||||
StatusCode.TRITON_SERVER_ERR: AgentStatusCode.SERVER_ERR,
|
||||
StatusCode.TRITON_SESSION_CLOSED: AgentStatusCode.SESSION_CLOSED,
|
||||
StatusCode.TRITON_STREAM_ING: AgentStatusCode.STREAM_ING,
|
||||
StatusCode.TRITON_SESSION_OUT_OF_LIMIT:
|
||||
AgentStatusCode.SESSION_OUT_OF_LIMIT,
|
||||
StatusCode.TRITON_SESSION_INVALID_ARG:
|
||||
AgentStatusCode.SESSION_INVALID_ARG,
|
||||
StatusCode.TRITON_SESSION_READY: AgentStatusCode.SESSION_READY
|
||||
}
|
||||
else:
|
||||
STATE_MAP = {}
|
||||
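The custom dict factory is what lets an `ActionReturn` be serialized with plain integers instead of enum members; a small check (the action name is hypothetical):

```python
from lagent.schema import ActionReturn, ActionStatusCode, dataclass2dict

ret = ActionReturn(
    args={'text': 'print(1 + 1)'},
    type='PythonInterpreter',          # hypothetical action name
    result={'text': '2'},
    state=ActionStatusCode.SUCCESS)

record = dataclass2dict(ret)
print(record['state'])                 # 0 (ActionStatusCode.SUCCESS)
print(record['valid'])                 # 0 (ActionValidCode.OPEN, the default)
```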
3
lagent/utils/__init__.py
Normal file
3
lagent/utils/__init__.py
Normal file
@@ -0,0 +1,3 @@
from .package import is_module_exist

__all__ = ['is_module_exist']

9
lagent/utils/package.py
Normal file
9
lagent/utils/package.py
Normal file
@@ -0,0 +1,9 @@
import importlib


def is_module_exist(module_name):
    spec = importlib.util.find_spec(module_name)
    if spec is None:
        return False
    else:
        return True

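A hypothetical usage sketch mirroring the lmdeploy gate in lagent/schema.py; the dependency name is only an example, and note that some interpreter states need an explicit "import importlib.util" for importlib.util.find_spec to be reachable:

from lagent.utils import is_module_exist

if is_module_exist('transformers'):  # example optional dependency
    import transformers  # noqa: F401  # safe: the module is importable
else:
    transformers = None  # feature relying on it stays disabled
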
35
lagent/version.py
Normal file
35
lagent/version.py
Normal file
@@ -0,0 +1,35 @@
# Copyright (c) OpenMMLab. All rights reserved.
__version__ = '0.1.0'


def parse_version_info(version_str: str, length: int = 4) -> tuple:
    """Parse a version string into a tuple.

    Args:
        version_str (str): The version string.
        length (int): The maximum number of version levels. Default: 4.

    Returns:
        tuple[int | str]: The version info, e.g., "1.3.0" is parsed into
        (1, 3, 0, 0, 0, 0), and "2.0.0rc1" is parsed into
        (2, 0, 0, 0, 'rc', 1) (when length is set to 4).
    """
    from packaging.version import parse
    version = parse(version_str)
    assert version.release, f'failed to parse version {version_str}'
    release = list(version.release)
    release = release[:length]
    if len(release) < length:
        release = release + [0] * (length - len(release))
    if version.is_prerelease:
        release.extend(list(version.pre))  # type: ignore
    elif version.is_postrelease:
        release.extend(list(version.post))  # type: ignore
    else:
        release.extend([0, 0])
    return tuple(release)


version_info = tuple(int(x) for x in __version__.split('.')[:3])

__all__ = ['__version__', 'version_info', 'parse_version_info']

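A quick check of the behaviour documented in the docstring above; it assumes the packaging dependency that parse_version_info imports lazily is available:

from lagent.version import parse_version_info

# both expectations follow directly from the docstring examples
assert parse_version_info('1.3.0') == (1, 3, 0, 0, 0, 0)
assert parse_version_info('2.0.0rc1') == (2, 0, 0, 0, 'rc', 1)
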
0
projects/.gitkeep
Normal file
0
projects/.gitkeep
Normal file
2
requirements.txt
Normal file
2
requirements.txt
Normal file
@@ -0,0 +1,2 @@
-r requirements/optional.txt
-r requirements/runtime.txt

2
requirements/optional.txt
Normal file
2
requirements/optional.txt
Normal file
@@ -0,0 +1,2 @@
torch
transformers

5
requirements/runtime.txt
Normal file
5
requirements/runtime.txt
Normal file
@@ -0,0 +1,5 @@
distro
func_timeout
jsonschema
requests
tiktoken

24
setup.cfg
Normal file
24
setup.cfg
Normal file
@@ -0,0 +1,24 @@
[isort]
line_length = 79
multi_line_output = 0
extra_standard_library = setuptools
known_first_party = mmdet
known_third_party = PIL,asynctest,cityscapesscripts,cv2,gather_models,matplotlib,mmcv,mmengine,numpy,onnx,onnxruntime,pycocotools,parameterized,pytest,pytorch_sphinx_theme,requests,scipy,seaborn,six,terminaltables,torch,ts,yaml
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY

[yapf]
BASED_ON_STYLE = pep8
BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = true
SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true

# ignore-words-list needs to be lowercase format. For example, if we want to
# ignore word "BA", then we need to append "ba" to ignore-words-list rather
# than "BA"
[codespell]
skip = *.ipynb
quiet-level = 3
ignore-words-list = patten,nd,ty,mot,hist,formating,winn,gool,datas,wan,confids,TOOD,tood,ba,warmup,nam,DOTA,dota,conveyer

[flake8]
per-file-ignores = ftdp/configs/*: F401,F403,F405

107
setup.py
Normal file
107
setup.py
Normal file
@@ -0,0 +1,107 @@
from pathlib import Path
from setuptools import find_packages, setup


def get_version():
    version_file = 'lagent/version.py'
    with open(version_file, encoding='utf-8') as f:
        exec(compile(f.read(), version_file, 'exec'))
    return locals()['__version__']


def parse_requirements(fname='requirements.txt', with_version=True):
    """Parse the package dependencies listed in a requirements file but strip
    specific version information.

    Args:
        fname (str): Path to requirements file.
        with_version (bool, default=True): If True, include version specs.
    Returns:
        info (list[str]): List of requirements items.
    CommandLine:
        python -c "import setup; print(setup.parse_requirements())"
    """
    import re
    import sys
    from os.path import exists

    require_fpath = fname

    def parse_line(line):
        """Parse information from a line in a requirements text file."""
        if line.startswith('-r '):
            # Allow specifying requirements in other files
            target = line.split(' ')[1]
            for info in parse_require_file(target):
                yield info
        else:
            info = {'line': line}
            if line.startswith('-e '):
                info['package'] = line.split('#egg=')[1]
            else:
                # Remove versioning from the package
                pat = '(' + '|'.join(['>=', '==', '>']) + ')'
                parts = re.split(pat, line, maxsplit=1)
                parts = [p.strip() for p in parts]

                info['package'] = parts[0]
                if len(parts) > 1:
                    op, rest = parts[1:]
                    if ';' in rest:
                        # Handle platform specific dependencies
                        # http://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-platform-specific-dependencies
                        version, platform_deps = map(str.strip,
                                                     rest.split(';'))
                        info['platform_deps'] = platform_deps
                    else:
                        version = rest  # NOQA
                    info['version'] = (op, version)
            yield info

    def parse_require_file(fpath):
        with open(fpath) as f:
            for line in f.readlines():
                line = line.strip()
                if line and not line.startswith('#'):
                    yield from parse_line(line)

    def gen_packages_items():
        if exists(require_fpath):
            for info in parse_require_file(require_fpath):
                parts = [info['package']]
                if with_version and 'version' in info:
                    parts.extend(info['version'])
                if not sys.version.startswith('3.4'):
                    # apparently package_deps are broken in 3.4
                    platform_deps = info.get('platform_deps')
                    if platform_deps is not None:
                        parts.append(';' + platform_deps)
                item = ''.join(parts)
                yield item

    packages = list(gen_packages_items())
    return packages


if __name__ == '__main__':
    with Path(Path(__file__).parent,
              'README.md').open(encoding='utf-8') as file:
        long_description = file.read()

    setup(
        name='lagent',
        packages=find_packages(),
        include_package_data=True,
        version=get_version(),
        license='Apache 2.0',
        description='An open-source framework for language agent research.',
        long_description=long_description,
        long_description_content_type='text/markdown',
        data_files=[('.', ['README.md'])],
        keywords=['artificial general intelligence', 'agent', 'agi', 'llm'],
        install_requires=parse_requirements('requirements/runtime.txt'),
        extras_require={
            'all': parse_requirements('requirements.txt'),
            'optional': parse_requirements('requirements/optional.txt'),
        },
    )

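The CommandLine example from the parse_requirements docstring, spelled out; run it from the repository root so the relative requirement paths resolve (setup() itself only executes under __main__, so the import is safe):

import setup  # the setup.py above

print(setup.parse_requirements('requirements/runtime.txt'))
# expected (unpinned): ['distro', 'func_timeout', 'jsonschema', 'requests', 'tiktoken']
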
144
tests/data/search.json
Normal file
144
tests/data/search.json
Normal file
@@ -0,0 +1,144 @@
{
  "searchParameters": {
    "q": "What is the capital of China?",
    "gl": "us",
    "hl": "en",
    "num": 10,
    "type": "search"
  },
  "answerBox": {
    "title": "China / Capital",
    "answer": "Beijing"
  },
  "organic": [
    {
      "title": "Beijing - Wikipedia",
      "link": "https://en.wikipedia.org/wiki/Beijing",
      "snippet": "With over 21 million residents, Beijing is the world's most populous national capital city as well as China's second largest city after Shanghai.",
      "sitelinks": [
        {
          "title": "Etymology",
          "link": "https://en.wikipedia.org/wiki/Beijing#Etymology"
        },
        {
          "title": "History",
          "link": "https://en.wikipedia.org/wiki/Beijing#History"
        },
        {
          "title": "Geography",
          "link": "https://en.wikipedia.org/wiki/Beijing#Geography"
        }
      ],
      "position": 1
    },
    {
      "title": "What is the Capital of China? - Mappr",
      "link": "https://www.mappr.co/capital-cities/china/",
      "snippet": "Beijing, also known as Peking, is the capital of China and one of the most populous cities in the world. It is the country's political, educational, ...",
      "position": 2
    },
    {
      "title": "Google Map of the City of Beijing, capital of P.R. China - Nations Online Project",
      "link": "https://www.nationsonline.org/oneworld/map/google_map_Beijing.htm",
      "snippet": "Google Earth: Searchable map/satellite view of Beijing, capital city of P.R. China. City Coordinates: 39°54′50″N 116°23′30″E, Bookmark/share this page ...",
      "position": 3
    },
    {
      "title": "Capital City of China - CountryReports.org",
      "link": "https://www.countryreports.org/country/china/capital-city.htm",
      "snippet": "Capital City, Beijing ; Capital location, 39 55 N, 116 23 E ; Capital - history, (Peking) Founded about 3,000 years ago on the site of a former Chinese capital, ...",
      "position": 4
    },
    {
      "title": "Capital of China - Beijing",
      "link": "https://www.chinahighlights.com/beijing/capital-of-china.htm",
      "snippet": "In Chinese, Beijing means 'Northern Capital'. Bei means 'north' and jing means 'capital'. In the history of China, Beijing's name has changed ...",
      "date": "Dec 15, 2022",
      "position": 5
    },
    {
      "title": "Beijing is the capital of the People's Republic of China. It is the world's most populous capital city, with over 21 million residents within an... | By DL&D Consult | Facebook",
      "link": "https://facebook.com/dldconsult/videos/beijing-capital-city-of-china/373001500555301/",
      "snippet": "Beijing is an important world capital and global power ...",
      "date": "Oct 19, 2020",
      "attributes": {
        "Duration": "2:58",
        "Posted": "Oct 19, 2020"
      },
      "imageUrl": "https://encrypted-tbn1.gstatic.com/images?q=tbn:ANd9GcQExx68yUr7xP_1wcRapEKlT5bxe4ptMa6WaLnwXdVAtAdloa7WeTIvCoJp",
      "position": 6
    },
    {
      "title": "What is the Capital of China? - WorldAtlas",
      "link": "https://www.worldatlas.com/articles/what-is-the-capital-of-china.html",
      "snippet": "The capital of China is Beijing.",
      "date": "Jul 3, 2018",
      "position": 7
    },
    {
      "title": "A Chinese capital that's not Beijing - BBC Travel",
      "link": "https://www.bbc.com/travel/article/20151008-a-chinese-capital-thats-not-beijing",
      "snippet": "Beijing may be the capital of China today, but for many centuries the country was ruled from Nanjing, a historic city located on the shores ...",
      "date": "Oct 13, 2015",
      "position": 8
    },
    {
      "title": "Beijing | Province, City, History, Map, & Facts - Britannica",
      "link": "https://www.britannica.com/place/Beijing",
      "snippet": "Beijing, city, province-level shi (municipality), and capital of the People's Republic of China. The city has been an integral part of China's history over ...",
      "position": 9
    }
  ],
  "peopleAlsoAsk": [
    {
      "question": "Does China have 2 capitals?",
      "snippet": "There are traditionally four major historical capitals of China referred to as\nthe 'Four Great Ancient Capitals of China' (simplified Chinese: 中国四大古都;\ntraditional Chinese: 中國四大古都; pinyin: Zhōngguó Sì Dà Gǔ Dū). The four are\nBeijing, Nanjing, Luoyang and Xi\"an (Chang\"an).",
      "title": "Historical capitals of China - Wikipedia",
      "link": "https://en.wikipedia.org/wiki/Historical_capitals_of_China"
    },
    {
      "question": "What is the capital city of China USA?",
      "snippet": "Capital City\nBeijing\nCapital - time difference\nUTC+8 (13 hours ahead of Washington, DC, during Standard Time) note; despite its\nsize, all of China falls within one time zone",
      "title": "Capital City of China - CountryReports.org",
      "link": "https://www.countryreports.org/country/china/capital-city.htm"
    },
    {
      "question": "Is Hong Kong is a part of China?",
      "snippet": "Hong Kong (US: /ˈhɒŋkɒŋ/ or UK: /hɒŋˈkɒŋ/; Chinese: 香港, Cantonese: [hœ́ːŋ.kɔ̌ːŋ]\n( listen)), officially the Hong Kong Special Administrative Region of the\nPeople's Republic of China (abbr. Hong Kong SAR or HKSAR), is a city and a\nspecial administrative region in China.",
      "title": "Hong Kong - Wikipedia",
      "link": "https://en.wikipedia.org/wiki/Hong_Kong"
    },
    {
      "question": "Why China changed its capital?",
      "snippet": "Once in charge, it wasn't uncommon for a new emperor to shift the imperial\ncapital in order to: Rebuild after a great loss, as in the Han era when Liu Bang\nmoved the capital from Xianyang to nearby Chang'an (now Xi'an), after the former\nwas destroyed during a rebellion.",
      "title": "Why do they keep moving the Capital? - China Simplified",
      "link": "https://www.chinasimplified.com/2014/09/29/why-do-they-keep-moving-the-capital/"
    }
  ],
  "relatedSearches": [
    {
      "query": "China map"
    },
    {
      "query": "Where is the capital of china on a map"
    },
    {
      "query": "Beijing population"
    },
    {
      "query": "Capital of Korea"
    },
    {
      "query": "Beijing pronunciation"
    },
    {
      "query": "What is the capital of India"
    },
    {
      "query": "What is the capital of Japan"
    },
    {
      "query": "What is the capital of Pakistan"
    }
  ]
}

45
tests/test_actions/test_builtin_actions.py
Normal file
45
tests/test_actions/test_builtin_actions.py
Normal file
@@ -0,0 +1,45 @@
from unittest import TestCase

from lagent.actions.builtin_actions import (FinishAction, InvalidAction,
                                            NoAction)
from lagent.schema import ActionStatusCode


class TestFinishAction(TestCase):

    def test_call(self):
        action = FinishAction()
        response = 'finish'
        action_return = action(response)
        self.assertEqual(action_return.state, ActionStatusCode.SUCCESS)
        self.assertDictEqual(action_return.result, dict(text='finish'))


class TestInvalidAction(TestCase):

    def test_call(self):
        action = InvalidAction()
        response = 'invalid'
        action_return = action(response)
        self.assertEqual(action_return.state, ActionStatusCode.API_ERROR)
        self.assertEqual(action_return.errmsg, response)

        action = InvalidAction(err_msg='error')
        action_return = action()
        self.assertEqual(action_return.state, ActionStatusCode.API_ERROR)
        self.assertEqual(action_return.errmsg, 'error')


class TestNoAction(TestCase):

    def test_call(self):
        action = NoAction()
        response = 'no'
        action_return = action(response)
        self.assertEqual(action_return.state, ActionStatusCode.API_ERROR)
        self.assertEqual(action_return.errmsg, response)

        action = NoAction(err_msg='error')
        action_return = action()
        self.assertEqual(action_return.state, ActionStatusCode.API_ERROR)
        self.assertEqual(action_return.errmsg, 'error')

21
tests/test_actions/test_python_interpreter.py
Normal file
21
tests/test_actions/test_python_interpreter.py
Normal file
@@ -0,0 +1,21 @@
from unittest import TestCase

from lagent.actions.python_interpreter import PythonInterpreter
from lagent.schema import ActionStatusCode


class TestPythonInterpreter(TestCase):

    def test_python_executor(self):
        python_executor = PythonInterpreter()
        tool_return = python_executor(
            '```python\ndef solution():\n    return 1\n```')
        self.assertEqual(tool_return.state, ActionStatusCode.SUCCESS)
        self.assertDictEqual(tool_return.result, dict(text='1'))

    def test_timeout(self):
        python_executor = PythonInterpreter(timeout=2)
        tool_return = python_executor(
            '```python\ndef solution():\n    while True:\n        pass\n```')
        self.assertEqual(tool_return.state, ActionStatusCode.API_ERROR)
        self.assertIn('FunctionTimedOut', tool_return.errmsg)

35
tests/test_actions/test_serper_search.py
Normal file
35
tests/test_actions/test_serper_search.py
Normal file
@@ -0,0 +1,35 @@
import json
from unittest import TestCase, mock

from lagent.actions import SerperSearch
from lagent.schema import ActionStatusCode


class TestSerperSearch(TestCase):

    @mock.patch.object(SerperSearch, '_search')
    def test_search_tool(self, mock_search_func):
        mock_response = (200, json.load(open('tests/data/search.json')))
        mock_search_func.return_value = mock_response
        search_tool = SerperSearch(api_key='abc')
        tool_return = search_tool.run("What's the capital of China?")
        self.assertEqual(tool_return.state, ActionStatusCode.SUCCESS)
        self.assertDictEqual(tool_return.result, dict(text="['Beijing']"))

    @mock.patch.object(SerperSearch, '_search')
    def test_api_error(self, mock_search_func):
        mock_response = (403, {'message': 'bad requests'})
        mock_search_func.return_value = mock_response
        search_tool = SerperSearch(api_key='abc')
        tool_return = search_tool.run("What's the capital of China?")
        self.assertEqual(tool_return.state, ActionStatusCode.API_ERROR)
        self.assertEqual(tool_return.errmsg, str(403))

    @mock.patch.object(SerperSearch, '_search')
    def test_http_error(self, mock_search_func):
        mock_response = (-1, 'HTTPSConnectionPool')
        mock_search_func.return_value = mock_response
        search_tool = SerperSearch(api_key='abc')
        tool_return = search_tool.run("What's the capital of China?")
        self.assertEqual(tool_return.state, ActionStatusCode.HTTP_ERROR)
        self.assertEqual(tool_return.errmsg, 'HTTPSConnectionPool')

87
tests/test_agents/test_rewoo.py
Normal file
87
tests/test_agents/test_rewoo.py
Normal file
@@ -0,0 +1,87 @@
from unittest import TestCase, mock

from lagent.actions import ActionExecutor
from lagent.actions.llm_qa import LLMQA
from lagent.actions.serper_search import SerperSearch
from lagent.agents.rewoo import ReWOO, ReWOOProtocol
from lagent.schema import ActionReturn, ActionStatusCode


class TestReWOO(TestCase):

    @mock.patch.object(SerperSearch, 'run')
    @mock.patch.object(LLMQA, 'run')
    @mock.patch.object(ReWOOProtocol, 'parse_worker')
    def test_normal_chat(self, mock_parse_worker_func, mock_qa_func,
                         mock_search_func):
        mock_model = mock.Mock()
        mock_model.generate_from_template.return_value = 'LLM response'

        mock_parse_worker_func.return_value = (['Thought1', 'Thought2'
                                                ], ['LLMQA', 'SerperSearch'],
                                               ['abc', 'abc'])

        search_return = ActionReturn(args=None)
        search_return.state = ActionStatusCode.SUCCESS
        search_return.result = dict(text='search_return')
        mock_search_func.return_value = search_return

        qa_return = ActionReturn(args=None)
        qa_return.state = ActionStatusCode.SUCCESS
        qa_return.result = dict(text='qa_return')
        mock_qa_func.return_value = qa_return

        chatbot = ReWOO(
            llm=mock_model,
            action_executor=ActionExecutor(actions=[
                LLMQA(mock_model),
                SerperSearch(api_key=''),
            ]))
        agent_return = chatbot.chat('abc')
        self.assertEqual(agent_return.response, 'LLM response')

    def test_parse_worker(self):
        prompt = ReWOOProtocol()
        message = """
Plan: a.
#E1 = tool1["a"]
#E2 = tool2["b"]
"""
        try:
            thoughts, actions, actions_input = prompt.parse_worker(message)
        except Exception as e:
            self.assertEqual(
                'Each Plan should only correspond to only ONE action', str(e))
        else:
            self.assertFalse(
                True, 'it should raise exception when the format is incorrect')

        message = """
Plan: a.
#E1 = tool1("a")
Plan: b.
#E2 = tool2["b"]
"""
        try:
            thoughts, actions, actions_input = prompt.parse_worker(message)
        except Exception as e:
            self.assertIsInstance(e, BaseException)
        else:
            self.assertFalse(
                True, 'it should raise exception when the format is incorrect')

        message = """
Plan: a.
#E1 = tool1["a"]
Plan: b.
#E2 = tool2["b"]
"""
        try:
            thoughts, actions, actions_input = prompt.parse_worker(message)
        except Exception:
            self.assertFalse(
                True,
                'it should not raise exception when the format is correct')
        self.assertEqual(thoughts, ['a.', 'b.'])
        self.assertEqual(actions, ['tool1', 'tool2'])
        self.assertEqual(actions_input, ['"a"', '"b"'])
