From 931ec2ab6e18018a3b748e6db394a1c4ac414be4 Mon Sep 17 00:00:00 2001 From: liukuikun <24622904+Harold-lkk@users.noreply.github.com> Date: Mon, 21 Aug 2023 13:16:50 +0800 Subject: [PATCH] [Init] Initial commit * Initial commit --- .gitignore | 1 + .pre-commit-config.yaml | 50 ++ .pylintrc | 428 ++++++++++++++++++ README.md | 81 ++++ README_zh-CN.md | 80 ++++ docs/.gitkeep | 0 docs/overview.md | 23 + examples/autogpt_example.py | 46 ++ examples/chat.py | 35 ++ examples/hf_react_example.py | 37 ++ examples/react_example.py | 36 ++ examples/rewoo_example.py | 20 + lagent/actions/__init__.py | 10 + lagent/actions/action_executor.py | 84 ++++ lagent/actions/base_action.py | 57 +++ lagent/actions/builtin_actions.py | 100 ++++ lagent/actions/llm_qa.py | 56 +++ lagent/actions/python_interpreter.py | 133 ++++++ lagent/actions/serper_search.py | 175 +++++++ lagent/agents/__init__.py | 6 + lagent/agents/autogpt.py | 288 ++++++++++++ lagent/agents/base_agent.py | 49 ++ lagent/agents/react.py | 217 +++++++++ lagent/agents/rewoo.py | 232 ++++++++++ lagent/llms/__init__.py | 10 + lagent/llms/base_api.py | 240 ++++++++++ lagent/llms/base_llm.py | 146 ++++++ lagent/llms/huggingface.py | 128 ++++++ lagent/llms/openai.py | 242 ++++++++++ lagent/schema.py | 79 ++++ lagent/utils/__init__.py | 3 + lagent/utils/package.py | 9 + lagent/version.py | 35 ++ projects/.gitkeep | 0 requirements.txt | 2 + requirements/optional.txt | 2 + requirements/runtime.txt | 5 + setup.cfg | 24 + setup.py | 107 +++++ tests/data/search.json | 144 ++++++ tests/test_actions/test_builtin_actions.py | 45 ++ tests/test_actions/test_python_interpreter.py | 21 + tests/test_actions/test_serper_search.py | 35 ++ tests/test_agents/test_rewoo.py | 87 ++++ 44 files changed, 3608 insertions(+) create mode 100644 .pre-commit-config.yaml create mode 100644 .pylintrc create mode 100644 README.md create mode 100644 README_zh-CN.md create mode 100644 docs/.gitkeep create mode 100644 docs/overview.md create mode 100644 examples/autogpt_example.py create mode 100644 examples/chat.py create mode 100644 examples/hf_react_example.py create mode 100644 examples/react_example.py create mode 100644 examples/rewoo_example.py create mode 100644 lagent/actions/__init__.py create mode 100644 lagent/actions/action_executor.py create mode 100644 lagent/actions/base_action.py create mode 100644 lagent/actions/builtin_actions.py create mode 100644 lagent/actions/llm_qa.py create mode 100644 lagent/actions/python_interpreter.py create mode 100644 lagent/actions/serper_search.py create mode 100644 lagent/agents/__init__.py create mode 100644 lagent/agents/autogpt.py create mode 100644 lagent/agents/base_agent.py create mode 100644 lagent/agents/react.py create mode 100644 lagent/agents/rewoo.py create mode 100644 lagent/llms/__init__.py create mode 100644 lagent/llms/base_api.py create mode 100644 lagent/llms/base_llm.py create mode 100644 lagent/llms/huggingface.py create mode 100644 lagent/llms/openai.py create mode 100644 lagent/schema.py create mode 100644 lagent/utils/__init__.py create mode 100644 lagent/utils/package.py create mode 100644 lagent/version.py create mode 100644 projects/.gitkeep create mode 100644 requirements.txt create mode 100644 requirements/optional.txt create mode 100644 requirements/runtime.txt create mode 100644 setup.cfg create mode 100644 setup.py create mode 100644 tests/data/search.json create mode 100644 tests/test_actions/test_builtin_actions.py create mode 100644 tests/test_actions/test_python_interpreter.py create mode 100644 
tests/test_actions/test_serper_search.py create mode 100644 tests/test_agents/test_rewoo.py diff --git a/.gitignore b/.gitignore index 68bc17f..0561ce6 100644 --- a/.gitignore +++ b/.gitignore @@ -158,3 +158,4 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ +.vscode/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..4a5c2af --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,50 @@ +exclude: ^(tests/data|scripts|ftdp/protocols|ftdp/template_configs|ftdp/tool_dicts)/ +repos: + - repo: https://github.com/PyCQA/flake8 + rev: 5.0.4 + hooks: + - id: flake8 + - repo: https://github.com/PyCQA/isort + rev: 5.11.5 + hooks: + - id: isort + - repo: https://github.com/pre-commit/mirrors-yapf + rev: v0.32.0 + hooks: + - id: yapf + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.3.0 + hooks: + - id: trailing-whitespace + - id: check-yaml + - id: end-of-file-fixer + - id: requirements-txt-fixer + - id: double-quote-string-fixer + - id: check-merge-conflict + - id: fix-encoding-pragma + args: ["--remove"] + - id: mixed-line-ending + args: ["--fix=lf"] + - repo: https://github.com/executablebooks/mdformat + rev: 0.7.9 + hooks: + - id: mdformat + args: ["--number"] + additional_dependencies: + - mdformat-openmmlab + - mdformat_frontmatter + - linkify-it-py + - repo: https://github.com/codespell-project/codespell + rev: v2.2.1 + hooks: + - id: codespell + - repo: https://github.com/myint/docformatter + rev: v1.3.1 + hooks: + - id: docformatter + args: ["--in-place", "--wrap-descriptions", "79"] + - repo: https://github.com/asottile/pyupgrade + rev: v3.0.0 + hooks: + - id: pyupgrade + args: ["--py36-plus"] diff --git a/.pylintrc b/.pylintrc new file mode 100644 index 0000000..73cb3ae --- /dev/null +++ b/.pylintrc @@ -0,0 +1,428 @@ +# This Pylint rcfile contains a best-effort configuration to uphold the +# best-practices and style described in the Google Python style guide: +# https://google.github.io/styleguide/pyguide.html +# +# Its canonical open-source location is: +# https://google.github.io/styleguide/pylintrc + +[MASTER] + +# Files or directories to be skipped. They should be base names, not paths. +ignore=third_party,storage + +# Files or directories matching the regex patterns are skipped. The regex +# matches against base names, not paths. +ignore-patterns= + +# Pickle collected data for later comparisons. +persistent=no + +# List of plugins (as comma separated values of python modules names) to load, +# usually to register additional checkers. +load-plugins= + +# Use multiple processes to speed up Pylint. +jobs=4 + +# Allow loading of arbitrary C extensions. Extensions are imported into the +# active Python interpreter and may run arbitrary code. +unsafe-load-any-extension=no + + +[MESSAGES CONTROL] + +# Only show warnings with the listed confidence levels. Leave empty to show +# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED +confidence= + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time (only on the command line, not in the configuration file where +# it should appear only once). See also the "--disable" option for examples. +#enable= + +# Disable the message, report, category or checker with the given id(s). 
You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once).You can also use "--disable=all" to +# disable everything first and then reenable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use"--disable=all --enable=classes +# --disable=W" +disable=abstract-method, + apply-builtin, + arguments-differ, + attribute-defined-outside-init, + backtick, + bad-option-value, + basestring-builtin, + buffer-builtin, + c-extension-no-member, + consider-using-enumerate, + cmp-builtin, + cmp-method, + coerce-builtin, + coerce-method, + delslice-method, + div-method, + duplicate-code, + eq-without-hash, + execfile-builtin, + file-builtin, + filter-builtin-not-iterating, + fixme, + getslice-method, + global-statement, + hex-method, + idiv-method, + implicit-str-concat, + import-error, + import-self, + import-star-module-level, + inconsistent-return-statements, + input-builtin, + intern-builtin, + invalid-str-codec, + locally-disabled, + long-builtin, + long-suffix, + map-builtin-not-iterating, + misplaced-comparison-constant, + missing-function-docstring, + metaclass-assignment, + next-method-called, + next-method-defined, + no-absolute-import, + no-else-break, + no-else-continue, + no-else-raise, + no-else-return, + no-init, # added + no-member, + no-name-in-module, + no-self-use, + nonzero-method, + oct-method, + old-division, + old-ne-operator, + old-octal-literal, + old-raise-syntax, + parameter-unpacking, + print-statement, + raising-string, + range-builtin-not-iterating, + raw_input-builtin, + rdiv-method, + reduce-builtin, + relative-import, + reload-builtin, + round-builtin, + setslice-method, + signature-differs, + standarderror-builtin, + suppressed-message, + sys-max-int, + too-few-public-methods, + too-many-ancestors, + too-many-arguments, + too-many-boolean-expressions, + too-many-branches, + too-many-instance-attributes, + too-many-locals, + too-many-nested-blocks, + too-many-public-methods, + too-many-return-statements, + too-many-statements, + trailing-newlines, + unichr-builtin, + unicode-builtin, + unnecessary-pass, + unpacking-in-except, + useless-else-on-loop, + useless-object-inheritance, + useless-suppression, + using-cmp-argument, + wrong-import-order, + xrange-builtin, + zip-builtin-not-iterating, + + +[REPORTS] + +# Set the output format. Available formats are text, parseable, colorized, msvs +# (visual studio) and html. You can also give a reporter class, eg +# mypackage.mymodule.MyReporterClass. +output-format=colorized + +# Tells whether to display a full report or only the messages +reports=no + +# Python expression which should return a note less than 10 (10 is the highest +# note). You have access to the variables errors warning, statement which +# respectively contain the number of errors / warnings messages and the total +# number of statements analyzed. This is used by the global evaluation report +# (RP0004). +evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) + +# Template used to display messages. This is a python new-style format string +# used to format the message information. 
See doc for all details +#msg-template= + + +[BASIC] + +# Good variable names which should always be accepted, separated by a comma +good-names=main,_ + +# Bad variable names which should always be refused, separated by a comma +bad-names= + +# Colon-delimited sets of names that determine each other's naming style when +# the name regexes allow several styles. +name-group= + +# Include a hint for the correct naming format with invalid-name +include-naming-hint=no + +# List of decorators that produce properties, such as abc.abstractproperty. Add +# to this list to register other decorators that produce valid properties. +property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl + +# Regular expression matching correct function names +function-rgx=^(?:(?PsetUp|tearDown|setUpModule|tearDownModule)|(?P_?[A-Z][a-zA-Z0-9]*)|(?P_?[a-z][a-z0-9_]*))$ + +# Regular expression matching correct variable names +variable-rgx=^[a-z][a-z0-9_]*$ + +# Regular expression matching correct constant names +const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ + +# Regular expression matching correct attribute names +attr-rgx=^_{0,2}[a-z][a-z0-9_]*$ + +# Regular expression matching correct argument names +argument-rgx=^[a-z][a-z0-9_]*$ + +# Regular expression matching correct class attribute names +class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ + +# Regular expression matching correct inline iteration names +inlinevar-rgx=^[a-z][a-z0-9_]*$ + +# Regular expression matching correct class names +class-rgx=^_?[A-Z][a-zA-Z0-9]*$ + +# Regular expression matching correct module names +module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$ + +# Regular expression matching correct method names +method-rgx=(?x)^(?:(?P_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P_{0,2}[a-z][a-z0-9_]*))$ + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$ + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=10 + + +[TYPECHECK] + +# List of decorators that produce context managers, such as +# contextlib.contextmanager. Add to this list to register other decorators that +# produce valid context managers. +contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager + +# Tells whether missing members accessed in mixin class should be ignored. A +# mixin class is detected if its name ends with "mixin" (case insensitive). +ignore-mixin-members=yes + +# List of module names for which member attributes should not be checked +# (useful for modules/projects where namespaces are manipulated during runtime +# and thus existing member attributes cannot be deduced by static analysis. It +# supports qualified module names, as well as Unix pattern matching. +ignored-modules= + +# List of class names for which member attributes should not be checked (useful +# for classes with dynamically set attributes). This supports the use of +# qualified names. +ignored-classes=optparse.Values,thread._local,_thread._local + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. 
Python regular +# expressions are accepted. +generated-members= + + +[FORMAT] + +# Maximum number of characters on a single line. +max-line-length=120 + +# TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt +# lines made too long by directives to pytype. + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=(?x)( + ^\s*(\#\ )??$| + ^\s*(from\s+\S+\s+)?import\s+.+$) + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=yes + +# Maximum number of lines in a module +max-module-lines=99999 + +# String used as indentation unit. The internal Google style guide mandates 2 +# spaces. Google's externaly-published style guide says 4, consistent with +# PEP 8. Here, we use 2 spaces, for conformity with many open-sourced Google +# projects (like TensorFlow). +indent-string=' ' + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren=4 + +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. +expected-line-ending-format= + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes=TODO + + +[STRING] + +# This flag controls whether inconsistent-quotes generates a warning when the +# character used as a quote delimiter is used inconsistently within a module. +check-quote-consistency=yes + + +[VARIABLES] + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# A regular expression matching the name of dummy variables (i.e. expectedly +# not used). +dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_) + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid to define new builtins when possible. +additional-builtins= + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. +callbacks=cb_,_cb + +# List of qualified module names which can have objects that can redefine +# builtins. +redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools + + +[LOGGING] + +# Logging modules to check that the string format arguments are in logging +# function parameter format +logging-modules=logging,absl.logging,tensorflow.io.logging + + +[SIMILARITIES] + +# Minimum lines number of a similarity. +min-similarity-lines=4 + +# Ignore comments when computing similarities. +ignore-comments=yes + +# Ignore docstrings when computing similarities. +ignore-docstrings=yes + +# Ignore imports when computing similarities. +ignore-imports=no + + +[SPELLING] + +# Spelling dictionary name. Available dictionaries: none. To make it working +# install python-enchant package. +spelling-dict= + +# List of comma separated words that should not be checked. +spelling-ignore-words= + +# A path to a file that contains private dictionary; one word per line. +spelling-private-dict-file= + +# Tells whether to store unknown words to indicated private dictionary in +# --spelling-private-dict-file option instead of raising a message. +spelling-store-unknown-words=no + + +[IMPORTS] + +# Deprecated modules which should not be used, separated by a comma +deprecated-modules=regsub, + TERMIOS, + Bastion, + rexec, + sets + +# Create a graph of every (i.e. 
internal and external) dependencies in the + given file (report RP0402 must not be disabled) +import-graph= + +# Create a graph of external dependencies in the given file (report RP0402 must +# not be disabled) +ext-import-graph= + +# Create a graph of internal dependencies in the given file (report RP0402 must +# not be disabled) +int-import-graph= + +# Force import order to recognize a module as part of the standard +# compatibility libraries. +known-standard-library= + +# Force import order to recognize a module as part of a third party library. +known-third-party=enchant, absl + +# Analyse import fallback blocks. This can be used to support both Python 2 and +# 3 compatible code, which means that the block might have code that exists +# only in one or another interpreter, leading to false positives when analysed. +analyse-fallback-blocks=no + + +[CLASSES] + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__, + __new__, + setUp + +# List of member names, which should be excluded from the protected access +# warning. +exclude-protected=_asdict, + _fields, + _replace, + _source, + _make + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls, + class_ + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=mcs + + +[EXCEPTIONS] + +# Exceptions that will emit a warning when being caught. Defaults to +# "Exception" +overgeneral-exceptions=builtins.BaseException, + builtins.Exception diff --git a/README.md b/README.md new file mode 100644 index 0000000..85177fe --- /dev/null +++ b/README.md @@ -0,0 +1,81 @@ +# LAgent: Large Language Model as Agent + +English | [简体中文](README_zh-CN.md) + +## Introduction + +LAgent is an open-source LLM agent framework that enables people to efficiently turn a large language model into an agent. It also provides some typical tools to extend the abilities of LLMs. The overview of our framework is shown below: + +![image](https://github.com/InternLM/lagent/assets/24351120/e104171e-4baf-43b3-8e6d-90cff1b298b6) + +### Major Features + +- **Support multiple agent frameworks out of the box.** We implement ReAct, AutoGPT and ReWOO, which enable the agents to drive LLMs through multiple rounds of reasoning and tool invocation. + +- **Extremely simple and easy to extend.** The framework is quite simple, with a clear project structure. With only 20 lines of code, you are able to construct your own agent. It also supports three typical tools: Python interpreter, API call, and Google search. + +- **Support various large language models.** We support different LLMs, including API-based (GPT3.5/4) and HuggingFace-based (LLaMa2, InternLM) models. + +## Getting Started + +Please see [Overview](docs/overview.md) for a general introduction to LAgent. Meanwhile, we provide extremely simple code for a quick start. You may refer to [examples](examples/) for more details. + +### Installation + +``` +git clone https://github.com/InternLM/lagent.git +cd lagent +pip install -e . 
+``` + +### Run a ReAct agent with GPT3.5 backend + +```python +from lagent.agents import ReACT +from lagent.actions.action_executor import ActionExecutor +from lagent.llms import GPTAPI +from lagent.actions import SerperSearch, PythonInterpreter + +llm = GPTAPI(model_type='gpt-3.5-turbo') +# SerperSearch reads the SERPER_API_KEY environment variable if no api_key is given +search_tool = SerperSearch() +python_interpreter = PythonInterpreter() + +chatbot = ReACT( + llm=llm, + action_executor=ActionExecutor( + actions=[search_tool, python_interpreter]), +) + +response = chatbot.chat('What profession do Nicholas Ray and Elia Kazan have in common?') +print(response['response']) +>>> They are both film directors. +``` + +### Run a ReAct agent with a HuggingFace backend + +NOTE: If you want to run a HuggingFace model, please run `pip install -e .[all]` first. + +```python +from lagent.agents import ReACT +from lagent.actions.action_executor import ActionExecutor +from lagent.llms import HFTransformer +from lagent.actions import SerperSearch, PythonInterpreter + +llm = HFTransformer('internlm/internlm-chat-7b') +search_tool = SerperSearch() +python_interpreter = PythonInterpreter() + +chatbot = ReACT( + llm=llm, + action_executor=ActionExecutor( + actions=[search_tool, python_interpreter]), +) + +response = chatbot.chat('若$z=-1+\sqrt{3}i$,则$\frac{z}{{z\overline{z}-1}}=\left(\ \ \right)$ (A) $-1+\sqrt{3}i$ (B) $-1-\sqrt{3}i$ (C) $-\frac{1}{3}+\frac{{\sqrt{3}}}{3}i$ (D) $-\frac{1}{3}-\frac{{\sqrt{3}}}{3}i$') +print(response['response']) +>>> 根据已有的信息,可以求得$z=-1+\\sqrt{3}i$,然后代入计算,得到结果为$-\\frac{1}{3}+\\frac{{\\sqrt{3}}}{3}i$。因此,答案是(C)。 +``` + +## License + +This project is released under the [Apache 2.0 license](LICENSE). diff --git a/README_zh-CN.md b/README_zh-CN.md new file mode 100644 index 0000000..6303405 --- /dev/null +++ b/README_zh-CN.md @@ -0,0 +1,80 @@ +# LAgent: Large Language Model as Agent + +[English](README.md) | 简体中文 + +## 简介 + +LAgent是一个开源的LLM代理框架，支持用户快速地将一个大语言模型转变为多种类型的智能体，并提供了一些典型工具为大语言模型赋能。它的整个框架图如下： + +![image](https://github.com/InternLM/lagent/assets/24351120/e104171e-4baf-43b3-8e6d-90cff1b298b6) + +### 主要特点 + +- **实现了多种类型的智能体。** 我们支持了经典的 ReAct、AutoGPT 和 ReWOO 等智能体，这些智能体能够调用大语言模型进行多轮的推理和工具调用。 + +- **框架简单易拓展。** 框架的代码结构清晰且简单，只需要不到20行代码你就能够创造出一个你自己的智能体。同时我们支持了Python解释器、API 调用和搜索三类常用典型工具。 + +- **灵活支持多个大语言模型。** 我们提供了多种大语言模型支持，包括 InternLM、Llama-2 等开源模型和 GPT-4/3.5 等基于 API 的闭源模型。 + +## 教程 + +请阅读[概述](docs/overview.md)对LAgent进行初步的了解。同时，我们提供了两个非常简单的代码示例帮助你快速入门。你也可以阅读[examples](examples/)获得更多的例子参考。 + +### 安装 + +``` +git clone https://github.com/InternLM/lagent.git +cd lagent +pip install -e . +``` + +### 用GPT3.5构建一个ReAct代理 + +```python +from lagent.agents import ReACT +from lagent.actions.action_executor import ActionExecutor +from lagent.llms import GPTAPI +from lagent.actions import SerperSearch, PythonInterpreter + +llm = GPTAPI(model_type='gpt-3.5-turbo') +search_tool = SerperSearch() +python_interpreter = PythonInterpreter() + +chatbot = ReACT( + llm=llm, + action_executor=ActionExecutor( + actions=[search_tool, python_interpreter]), +) + +response = chatbot.chat('What profession do Nicholas Ray and Elia Kazan have in common?') +print(response['response']) +>>> They are both film directors. +``` + +### 用HuggingFace构建一个ReAct代理 + +注意：如果你想要启动一个HuggingFace的模型，请先运行`pip install -e .[all]`。
+ +```python +from lagent.agents import ReACT +from lagent.actions.action_executor import ActionExecutor +from lagent.llms import HFTransformer +from lagent.actions import SerperSearch, PythonInterpreter + +llm = HFTransformer('internlm/internlm-chat-7b') +search_tool = SerperSearch() +python_interpreter = PythonInterpreter() + +chatbot = ReACT( + llm=llm, + action_executor=ActionExecutor( + actions=[search_tool, python_interpreter]), +) + +response = chatbot.chat('若$z=-1+\sqrt{3}i$,则$\frac{z}{{z\overline{z}-1}}=\left(\ \ \right)$ (A) $-1+\sqrt{3}i$ (B) $-1-\sqrt{3}i$ (C) $-\frac{1}{3}+\frac{{\sqrt{3}}}{3}i$ (D) $-\frac{1}{3}-\frac{{\sqrt{3}}}{3}i$') +print(response['response']) +>>> 根据已有的信息,可以求得$z=-1+\\sqrt{3}i$,然后代入计算,得到结果为$-\\frac{1}{3}+\\frac{{\\sqrt{3}}}{3}i$。因此,答案是(C)。 +``` + +## 开源许可证 + +该项目采用[Apache 2.0 开源许可证](LICENSE)。 diff --git a/docs/.gitkeep b/docs/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/docs/overview.md b/docs/overview.md new file mode 100644 index 0000000..c91cdf5 --- /dev/null +++ b/docs/overview.md @@ -0,0 +1,23 @@ +# OVERVIEW + +This chapter introduces you to the framework of LAgent and provides links to detailed tutorials about LAgent. + +## What is LAgent + +LAgent is an open-source LLM agent framework that enables people to efficiently turn a large language model into an agent. It also provides some typical tools to extend the abilities of LLMs, and the whole framework is shown below: + +![image](https://github.com/InternLM/lagent/assets/24351120/e104171e-4baf-43b3-8e6d-90cff1b298b6) + +LAgent consists of three main parts: agents, llms, and actions. + +- **agents** provides agent implementations, such as ReAct and AutoGPT. +- **llms** supports various large language models, including open-source models (Llama-2, InternLM) through HuggingFace and closed-source models like GPT3.5/4. +- **actions** contains a series of actions, as well as an action executor to manage all actions. + +## How to Use + +Here is a detailed step-by-step guide to learn more about LAgent: + +1. For installation instructions, please see [get_started](get_started.md). + +2. We provide several examples of building agents with LAgent in [examples](examples/); simply run `python examples/react_example.py`. diff --git a/examples/autogpt_example.py b/examples/autogpt_example.py new file mode 100644 index 0000000..ac338a2 --- /dev/null +++ b/examples/autogpt_example.py @@ -0,0 +1,46 @@ +from lagent.actions.action_executor import ActionExecutor +from lagent.actions.builtin_actions import FinishAction +from lagent.actions.python_interpreter import PythonInterpreter +from lagent.agents.autogpt import AutoGPT +from lagent.llms.openai import GPTAPI + + +def input_prompt(): + print('\ndouble enter to end input >>> ', end='') + sentinel = '' # ends when this string is seen + return '\n'.join(iter(input, sentinel)) + + +def main(): + # set OPEN_API_KEY in your environment + model = GPTAPI(model_type='gpt-3.5-turbo') + + chatbot = AutoGPT( + llm=model, + action_executor=ActionExecutor( + actions=[ + PythonInterpreter(), + ], + finish_action=FinishAction( + description=( + 'Goals are accomplished and there is nothing left ' + 'to do. 
Parameter: {"response: "final response ' + 'for the goal"}')), + finish_in_action=True), + ) + + while True: + try: + prompt = input_prompt() + except UnicodeDecodeError: + print('UnicodeDecodeError') + continue + if prompt == 'exit': + exit(0) + + agent_return = chatbot.chat(prompt) + print(agent_return.response) + + +if __name__ == '__main__': + main() diff --git a/examples/chat.py b/examples/chat.py new file mode 100644 index 0000000..0520e3e --- /dev/null +++ b/examples/chat.py @@ -0,0 +1,35 @@ +from argparse import ArgumentParser + +from lagent.llms.openai import GPTAPI + + +def parse_args(): + parser = ArgumentParser(description='chatbot') + parser.add_argument('--mode', default='chat') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + model = GPTAPI(model_type='gpt-3.5-turbo', ) + history = [] + while True: + try: + prompt = input('>>> ') + except UnicodeDecodeError: + print('UnicodeDecodeError') + continue + if prompt == 'exit': + exit(0) + if args.mode == 'chat': + history.append(dict(role='user', content=prompt)) + response = model.generate_from_template(history, max_out_len=512) + history.append(dict(role='assistant', content=response)) + elif args.mode == 'generate': + response = model.generate(prompt, max_out_len=512) + print('Assistant:', response) + + +if __name__ == '__main__': + main() diff --git a/examples/hf_react_example.py b/examples/hf_react_example.py new file mode 100644 index 0000000..c9e18f7 --- /dev/null +++ b/examples/hf_react_example.py @@ -0,0 +1,37 @@ +from lagent.actions.action_executor import ActionExecutor +from lagent.actions.python_interpreter import PythonInterpreter +from lagent.agents.react import ReACT +from lagent.llms.huggingface import HFTransformer + +model = HFTransformer( + path='internlm/internlm-chat-7b', + meta_template=[ + dict(role='system', begin='<|System|>:', end='\n'), + dict(role='user', begin='<|User|>:', end='\n'), + dict(role='assistant', begin='<|Bot|>:', end='\n', generate=True) + ], +) + +chatbot = ReACT( + llm=model, + action_executor=ActionExecutor(actions=[PythonInterpreter()]), +) + + +def input_prompt(): + print('\ndouble enter to end input >>> ', end='') + sentinel = '' # ends when this string is seen + return '\n'.join(iter(input, sentinel)) + + +while True: + try: + prompt = input_prompt() + except UnicodeDecodeError: + print('UnicodeDecodeError') + continue + if prompt == 'exit': + exit(0) + + agent_return = chatbot.chat(prompt) + print(agent_return.response) diff --git a/examples/react_example.py b/examples/react_example.py new file mode 100644 index 0000000..c6a1269 --- /dev/null +++ b/examples/react_example.py @@ -0,0 +1,36 @@ +from lagent.actions.action_executor import ActionExecutor +from lagent.actions.python_interpreter import PythonInterpreter +from lagent.agents.react import ReACT +from lagent.llms.openai import GPTAPI + + +def input_prompt(): + print('\ndouble enter to end input >>> ', end='') + sentinel = '' # ends when this string is seen + return '\n'.join(iter(input, sentinel)) + + +def main(): + # set OPEN_API_KEY in your environment + model = GPTAPI(model_type='gpt-3.5-turbo', ) + + chatbot = ReACT( + llm=model, + action_executor=ActionExecutor(actions=[PythonInterpreter()]), + ) + + while True: + try: + prompt = input_prompt() + except UnicodeDecodeError: + print('UnicodeDecodeError') + continue + if prompt == 'exit': + exit(0) + + agent_return = chatbot.chat(prompt) + print(agent_return.response) + + +if __name__ == '__main__': + main() diff --git 
a/examples/rewoo_example.py b/examples/rewoo_example.py new file mode 100644 index 0000000..541bdb2 --- /dev/null +++ b/examples/rewoo_example.py @@ -0,0 +1,20 @@ +from lagent.actions.action_executor import ActionExecutor +from lagent.actions.llm_qa import LLMQA +from lagent.actions.serper_search import SerperSearch +from lagent.agents.rewoo import ReWOO +from lagent.llms.openai import GPTAPI + +model = GPTAPI(model_type='gpt-3.5-turbo') +# please set the serper search API key +search_tool = SerperSearch(api_key=None) +llmqa_tool = LLMQA(model) + +chatbot = ReWOO( + llm=model, + action_executor=ActionExecutor(actions=[llmqa_tool, search_tool]), +) + +prompt = 'What profession does Nicholas Ray and Elia Kazan have in common' + +agent_return = chatbot.chat(prompt) +print(agent_return.response) diff --git a/lagent/actions/__init__.py b/lagent/actions/__init__.py new file mode 100644 index 0000000..66a12a4 --- /dev/null +++ b/lagent/actions/__init__.py @@ -0,0 +1,10 @@ +from .action_executor import ActionExecutor +from .base_action import BaseAction +from .builtin_actions import FinishAction, InvalidAction, NoAction +from .python_interpreter import PythonInterpreter +from .serper_search import SerperSearch + +__all__ = [ + 'BaseAction', 'ActionExecutor', 'InvalidAction', 'NoAction', + 'FinishAction', 'SerperSearch', 'PythonInterpreter' +] diff --git a/lagent/actions/action_executor.py b/lagent/actions/action_executor.py new file mode 100644 index 0000000..c3186e6 --- /dev/null +++ b/lagent/actions/action_executor.py @@ -0,0 +1,84 @@ +from typing import Any, Dict, List, Union + +from lagent.schema import ActionReturn, ActionValidCode +from .base_action import BaseAction +from .builtin_actions import FinishAction, InvalidAction, NoAction + + +class ActionExecutor: + """The action executor class. + + Args: + actions (Union[BaseAction, List[BaseAction]]): The action or actions. + invalid_action (BaseAction, optional): The invalid action. Defaults to + InvalidAction(). + no_action (BaseAction, optional): The no action. + Defaults to NoAction(). + finish_action (BaseAction, optional): The finish action. Defaults to + FinishAction(). + finish_in_action (bool, optional): Whether the finish action is in the + action list. Defaults to False. 
+ """ + + def __init__(self, + actions: Union[BaseAction, List[BaseAction]], + invalid_action: BaseAction = InvalidAction(), + no_action: BaseAction = NoAction(), + finish_action: BaseAction = FinishAction(), + finish_in_action: bool = False): + if isinstance(actions, BaseAction): + actions = [actions] + + for action in actions: + assert isinstance(action, BaseAction), \ + f'action must be BaseAction, but got {type(action)}' + if finish_in_action: + actions.append(finish_action) + self.actions = {action.name: action for action in actions} + self.invalid_action = invalid_action + self.no_action = no_action + self.finish_action = finish_action + + def get_actions_info(self, only_enable: bool = True) -> Dict: + if only_enable: + return { + k: v.description + for k, v in self.actions.items() if v.enable + } + else: + return {k: v.description for k, v in self.actions.items()} + + def is_valid(self, name: str): + return name in self.actions and self.actions[name].enable + + def action_names(self, only_enable: bool = True): + if only_enable: + return [k for k, v in self.actions.items() if v.enable] + else: + return list(self.actions.keys()) + + def add_action(self, action: BaseAction): + assert isinstance(action, BaseAction), \ + f'action must be BaseAction, but got {type(action)}' + self.actions[action.name] = action + + def del_action(self, name: str): + if name in self.actions: + del self.actions[name] + + def __call__(self, name: str, command: Any) -> ActionReturn: + if isinstance(command, str): + args, kwargs = (command, ), {} + else: + args, kwargs = (), command + if not self.is_valid(name): + if name == self.no_action.name: + action_return = self.no_action.run(*args, **kwargs) + elif name == self.finish_action.name: + action_return = self.finish_action.run(*args, **kwargs) + else: + action_return = self.invalid_action(*args, **kwargs) + else: + action_return = self.actions[name].run(*args, **kwargs) + action_return.valid = ActionValidCode.OPEN + return action_return diff --git a/lagent/actions/base_action.py b/lagent/actions/base_action.py new file mode 100644 index 0000000..951b7f6 --- /dev/null +++ b/lagent/actions/base_action.py @@ -0,0 +1,57 @@ +from typing import Optional + +from lagent.schema import ActionReturn + + +class BaseAction: + """Base class for all actions. + + Args: + description (str, optional): The description of the action. Defaults to + None. + name (str, optional): The name of the action. If None, the name will + be class nameDefaults to None. + enable (bool, optional): Whether the action is enabled. Defaults to + True. + disable_description (str, optional): The description of the action when + it is disabled. Defaults to None. 
+ """ + + def __init__(self, + description: Optional[str] = None, + name: Optional[str] = None, + enable: bool = True, + disable_description: Optional[str] = None) -> None: + if name is None: + name = self.__class__.__name__ + self._name = name + self._description = description + self._disable_description = disable_description + self._enable = enable + + def __call__(self, *args, **kwargs) -> ActionReturn: + raise NotImplementedError + + def __repr__(self): + return f'{self.name}:{self.description}' + + def __str__(self): + return self.__repr__() + + def run(self, *args, **kwargs) -> ActionReturn: + return self.__call__(*args, **kwargs) + + @property + def enable(self): + return self._enable + + @property + def name(self): + return self._name + + @property + def description(self): + if self.enable: + return self._description + else: + return self._disable_description diff --git a/lagent/actions/builtin_actions.py b/lagent/actions/builtin_actions.py new file mode 100644 index 0000000..20f5c2a --- /dev/null +++ b/lagent/actions/builtin_actions.py @@ -0,0 +1,100 @@ +from typing import Optional + +from lagent.actions.base_action import BaseAction +from lagent.schema import ActionReturn, ActionStatusCode, ActionValidCode + + +class InvalidAction(BaseAction): + """This is a invalid action class, which is used to return error message + when the action is invalid. + + Args: + err_msg (str): The error message. Defaults to 'The action is invalid, + please check the action name'. + + Returns: + ActionReturn: The action return. + """ + + def __init__(self, + err_msg: + str = 'The action is invalid, please check the action name.', + **kwargs) -> None: + + super().__init__(enable=False, **kwargs) + self._err_msg = err_msg + + def __call__(self, err_msg: Optional[str] = None): + """Return the error message. + + Args: + err_msg (str, optional): The error message. If err_msg is not None, + it will be returned, otherwise the default error message will + be returned. Defaults to None. + """ + action_return = ActionReturn( + url=None, + args=dict(text=err_msg), + errmsg=err_msg if err_msg else self._err_msg, + type=self.name, + valid=ActionValidCode.INVALID, + state=ActionStatusCode.API_ERROR) + return action_return + + +class NoAction(BaseAction): + """This is a no action class, which is used to return error message when + the response does not follow the format. + + Args: + err_msg (str): The error message. Defaults to + 'Please follow the format'. + """ + + def __init__(self, err_msg: str = 'Please follow the format', **kwargs): + + super().__init__(enable=False, **kwargs) + self._err_msg = err_msg + + def __call__(self, err_msg: Optional[str] = None): + """Return the error message. + + Args: + err_msg (str, optional): The error message. If err_msg is not None, + it will be returned, otherwise the default error message will + be returned. Defaults to None. + + Returns: + ActionReturn: The action return. + """ + action_return = ActionReturn( + url=None, + args=dict(text=err_msg), + type=self.name, + errmsg=err_msg if err_msg else self._err_msg, + valid=ActionValidCode.INVALID, + state=ActionStatusCode.API_ERROR) + return action_return + + +class FinishAction(BaseAction): + """This is a finish action class, which is used to return the final + result.""" + + def __call__(self, response: str) -> ActionReturn: + """Return the final result. + + Args: + response (str): The final result. + + Returns: + ActionReturn: The action return. 
+ """ + action_return = ActionReturn( + url=None, + args=dict(text=response), + result=dict(text=response), + type=self.name, + valid=ActionValidCode.FINISH, + state=ActionStatusCode.SUCCESS) + return action_return diff --git a/lagent/actions/llm_qa.py b/lagent/actions/llm_qa.py new file mode 100644 index 0000000..89faad6 --- /dev/null +++ b/lagent/actions/llm_qa.py @@ -0,0 +1,56 @@ +from typing import Optional, Union + +from lagent.llms.base_api import BaseAPIModel +from lagent.llms.base_llm import BaseModel +from lagent.schema import ActionReturn, ActionStatusCode +from .base_action import BaseAction + +DEFAULT_DESCRIPTION = """一个像你一样的大语言预训练模型,当你需要获得一些常识或简单世界知识时可以问它。 +当你很有把握自己直接解决问题时可以优先使用它。输入应该是一个询问语句, 且每个问题尽可能简单。 +""" + + +class LLMQA(BaseAction): + """An LLM Wrapper as BaseAction type. + + Args: + llm (BaseModel or BaseAPIModel): a LLM service which can chat. + description (str): The description of the action. Defaults to + None. + name (str, optional): The name of the action. If None, the name will + be class nameDefaults to None. + enable (bool, optional): Whether the action is enabled. Defaults to + True. + disable_description (str, optional): The description of the action when + it is disabled. Defaults to None. + """ + + def __init__(self, + llm: Union[BaseModel, BaseAPIModel], + description: str = DEFAULT_DESCRIPTION, + name: Optional[str] = None, + enable: bool = True, + disable_description: Optional[str] = None) -> None: + super().__init__(description, name, enable, disable_description) + + self._llm = llm + + def __call__(self, query: str) -> ActionReturn: + """Return the QA response. + + Args: + query (str): The query content. + + Returns: + ActionReturn: The action return. + """ + + tool_return = ActionReturn(url=None, args=None) + try: + response = self._llm.generate_from_template(query, 512) + tool_return.result = dict(text=str(response)) + tool_return.state = ActionStatusCode.SUCCESS + except Exception as e: + tool_return.result = dict(text=str(e)) + tool_return.state = ActionStatusCode.API_ERROR + return tool_return diff --git a/lagent/actions/python_interpreter.py b/lagent/actions/python_interpreter.py new file mode 100644 index 0000000..e10102b --- /dev/null +++ b/lagent/actions/python_interpreter.py @@ -0,0 +1,133 @@ +import copy +import io +from contextlib import redirect_stdout +from typing import Any, Optional + +from func_timeout import FunctionTimedOut, func_set_timeout + +from lagent.actions.base_action import BaseAction +from lagent.schema import ActionReturn, ActionStatusCode + + +class GenericRuntime: + GLOBAL_DICT = {} + LOCAL_DICT = None + HEADERS = [] + + def __init__(self): + self._global_vars = copy.copy(self.GLOBAL_DICT) + self._local_vars = copy.copy( + self.LOCAL_DICT) if self.LOCAL_DICT else None + + for c in self.HEADERS: + self.exec_code(c) + + def exec_code(self, code_piece: str) -> None: + exec(code_piece, self._global_vars) + + def eval_code(self, expr: str) -> Any: + return eval(expr, self._global_vars) + + +DEFAULT_DESCRIPTION = """用来执行Python代码。代码必须是一个函数, +函数名必须得是 'solution',代码对应你的思考过程。代码实例格式如下: +```python +# import 依赖包 +import xxx +def solution(): + # 初始化一些变量 + variable_names_with_real_meaning = xxx + # 步骤一 + mid_variable = func(variable_names_with_real_meaning) + # 步骤 x + mid_variable = func(mid_variable) + # 最后结果 + final_answer = func(mid_variable) + return final_answer +```""" + + +class PythonInterpreter(BaseAction): + """A Python executor that can execute Python scripts. + + Args: + description (str): The description of the action. 
Defaults to + DEFAULT_DESCRIPTION. + answer_symbol (str, Optional): the answer symbol from LLM + answer_expr (str, Optional): the answer function name of the Python + script. Default to 'solution()'. + answer_from_stdout (boolean): whether the execution results is from + stdout. + name (str, optional): The name of the action. If None, the name will + be class nameDefaults to None. + enable (bool, optional): Whether the action is enabled. Defaults to + True. + disable_description (str, optional): The description of the action when + it is disabled. Defaults to None. + timeout (int): Upper bound of waiting time for Python script execution. + """ + + def __init__(self, + description: str = DEFAULT_DESCRIPTION, + answer_symbol: Optional[str] = None, + answer_expr: Optional[str] = 'solution()', + answer_from_stdout: bool = False, + name: Optional[str] = None, + enable: bool = True, + disable_description: Optional[str] = None, + timeout: int = 20) -> None: + super().__init__(description, name, enable, disable_description) + + self.answer_symbol = answer_symbol + self.answer_expr = answer_expr + self.answer_from_stdout = answer_from_stdout + self.timeout = timeout + + def __call__(self, command: str) -> ActionReturn: + self.runtime = GenericRuntime() + try: + tool_return = func_set_timeout(self.timeout)(self._call)(command) + except FunctionTimedOut as e: + tool_return = ActionReturn(url=None, args=None, type=self.name) + tool_return.errmsg = repr(e) + tool_return.state = ActionStatusCode.API_ERROR + return tool_return + + def _call(self, command: str) -> ActionReturn: + tool_return = ActionReturn(url=None, args=None, type=self.name) + try: + if '```python' in command: + command = command.split('```python')[1].split('```')[0] + elif '```' in command: + command = command.split('```')[1].split('```')[0] + tool_return.args = dict(text='```python\n' + command + '\n```') + command = command.split('\n') + + if self.answer_from_stdout: + program_io = io.StringIO() + with redirect_stdout(program_io): + self.runtime.exec_code('\n'.join(command)) + program_io.seek(0) + res = program_io.readlines()[-1] + elif self.answer_symbol: + self.runtime.exec_code('\n'.join(command)) + res = self.runtime._global_vars[self.answer_symbol] + elif self.answer_expr: + self.runtime.exec_code('\n'.join(command)) + res = self.runtime.eval_code(self.answer_expr) + else: + self.runtime.exec_code('\n'.join(command[:-1])) + res = self.runtime.eval_code(command[-1]) + except Exception as e: + tool_return.errmsg = repr(e) + tool_return.type = self.name + tool_return.state = ActionStatusCode.API_ERROR + return tool_return + try: + tool_return.result = dict(text=str(res)) + tool_return.state = ActionStatusCode.SUCCESS + except Exception as e: + tool_return.errmsg = repr(e) + tool_return.type = self.name + tool_return.state = ActionStatusCode.API_ERROR + return tool_return diff --git a/lagent/actions/serper_search.py b/lagent/actions/serper_search.py new file mode 100644 index 0000000..346f671 --- /dev/null +++ b/lagent/actions/serper_search.py @@ -0,0 +1,175 @@ +import os +from typing import List, Optional, Tuple, Union + +import requests + +from lagent.schema import ActionReturn, ActionStatusCode +from .base_action import BaseAction + +DEFAULT_DESCRIPTION = """一个可以从谷歌搜索结果的API。 +当你需要对于一个特定问题找到简短明了的回答时,可以使用它。 +输入应该是一个搜索查询。 +""" + + +class SerperSearch(BaseAction): + """Wrapper around the Serper.dev Google Search API. + + To use, you should pass your serper API key to the constructor. 
+ + Code is modified from lang-chain GoogleSerperAPIWrapper + (https://github.com/langchain-ai/langchain/blob/ba5f + baba704a2d729a4b8f568ed70d7c53e799bb/libs/langchain/ + langchain/utilities/google_serper.py) + + Args: + api_key (str): API KEY to use serper google search API, + You can create a free API key at https://serper.dev. + timeout (int): Upper bound of waiting time for an serper request. + search_type (str): Serper API support ['search', 'images', 'news', + 'places'] types of search, currently we only support 'search'. + k (int): select first k results in the search results as response. + description (str): The description of the action. Defaults to + None. + name (str, optional): The name of the action. If None, the name will + be class nameDefaults to None. + enable (bool, optional): Whether the action is enabled. Defaults to + True. + disable_description (str, optional): The description of the action when + it is disabled. Defaults to None. + """ + result_key_for_type = { + 'news': 'news', + 'places': 'places', + 'images': 'images', + 'search': 'organic', + } + + def __init__(self, + api_key: Optional[str] = None, + timeout: int = 5, + search_type: str = 'search', + k: int = 10, + description: str = DEFAULT_DESCRIPTION, + name: Optional[str] = None, + enable: bool = True, + disable_description: Optional[str] = None) -> None: + super().__init__(description, name, enable, disable_description) + + api_key = os.environ.get('SERPER_API_KEY', api_key) + if api_key is None: + raise ValueError( + 'Please Set Serper API key either in the environment ' + ' as SERPER_API_KEY or pass it as `api_key` parameter.') + self.api_key = api_key + self.timeout = timeout + self.search_type = search_type + self.k = k + + def __call__(self, query: str) -> ActionReturn: + """Return the search response. + + Args: + query (str): The search content. + + Returns: + ActionReturn: The action return. + """ + + tool_return = ActionReturn(url=None, args=None) + status_code, response = self._search( + query, search_type=self.search_type, k=self.k) + # convert search results to ToolReturn format + if status_code == -1: + tool_return.errmsg = response + tool_return.state = ActionStatusCode.HTTP_ERROR + elif status_code == 200: + parsed_res = self._parse_results(response) + tool_return.result = dict(text=str(parsed_res)) + tool_return.state = ActionStatusCode.SUCCESS + else: + tool_return.errmsg = str(status_code) + tool_return.state = ActionStatusCode.API_ERROR + return tool_return + + def _parse_results(self, results: dict) -> Union[str, List[str]]: + """Parse the search results from Serper API. + + Args: + results (dict): The search content from Serper API + in json format. + + Returns: + List[str]: The parsed search results. 
+ """ + + snippets = [] + + if results.get('answerBox'): + answer_box = results.get('answerBox', {}) + if answer_box.get('answer'): + return [answer_box.get('answer')] + elif answer_box.get('snippet'): + return [answer_box.get('snippet').replace('\n', ' ')] + elif answer_box.get('snippetHighlighted'): + return answer_box.get('snippetHighlighted') + + if results.get('knowledgeGraph'): + kg = results.get('knowledgeGraph', {}) + title = kg.get('title') + entity_type = kg.get('type') + if entity_type: + snippets.append(f'{title}: {entity_type}.') + description = kg.get('description') + if description: + snippets.append(description) + for attribute, value in kg.get('attributes', {}).items(): + snippets.append(f'{title} {attribute}: {value}.') + + for result in results[self.result_key_for_type[ + self.search_type]][:self.k]: + if 'snippet' in result: + snippets.append(result['snippet']) + for attribute, value in result.get('attributes', {}).items(): + snippets.append(f'{attribute}: {value}.') + + if len(snippets) == 0: + return ['No good Google Search Result was found'] + return snippets + + def _search(self, + search_term: str, + search_type: str = 'search', + **kwargs) -> Tuple[int, Union[dict, str]]: + """HTTP requests to Serper API. + + Args: + search_term (str): The search query. + search_type (str): search type supported by Serper API, + default to 'search'. + + Returns: + tuple: the return value is a tuple contains: + - status_code (int): HTTP status code from Serper API. + - response (dict): response context with json format. + """ + headers = { + 'X-API-KEY': self.api_key or '', + 'Content-Type': 'application/json', + } + params = { + 'q': search_term, + **{ + key: value + for key, value in kwargs.items() if value is not None + }, + } + try: + response = requests.post( + f'https://google.serper.dev/{search_type}', + headers=headers, + params=params, + timeout=self.timeout) + except Exception as e: + return -1, str(e) + return response.status_code, response.json() diff --git a/lagent/agents/__init__.py b/lagent/agents/__init__.py new file mode 100644 index 0000000..0e6490d --- /dev/null +++ b/lagent/agents/__init__.py @@ -0,0 +1,6 @@ +from .autogpt import AutoGPT +from .base_agent import BaseAgent +from .react import ReACT +from .rewoo import ReWOO + +__all__ = ['BaseAgent', 'ReACT', 'AutoGPT', 'ReWOO'] diff --git a/lagent/agents/autogpt.py b/lagent/agents/autogpt.py new file mode 100644 index 0000000..cd07028 --- /dev/null +++ b/lagent/agents/autogpt.py @@ -0,0 +1,288 @@ +# flake8: noqa +import ast +import platform +from typing import Dict, List, Optional, Tuple, Union + +from jsonschema import Draft7Validator + +from lagent.actions import ActionExecutor +from lagent.llms.base_api import BaseAPIModel +from lagent.llms.base_llm import BaseModel +from lagent.schema import ActionReturn, ActionStatusCode, AgentReturn +from .base_agent import BaseAgent + +DEFAULT_TRIGGERING_PROMPT = ('Determine exactly one command to use based on ' + 'the given goals and the progress you have made ' + 'so far, and respond using the JSON schema ' + 'specified previously:') + +DEFAULT_PREFIX = """You are {ai_name}, {role_description}. Your decisions must always be made independently without seeking user assistance. Play to your strengths as an LLM and pursue simple strategies with no legal complications. +The OS you are running on is: {os_info} +## Constraints +You operate within the following constraints: +1. ~4000 word limit for short term memory. 
Your short term memory is short, so immediately save important information to files. +2. 'If you are unsure how you previously did something or want to recall past events, thinking about similar events will help you remember. +3. No user assistance +4. Exclusively use the commands listed below e.g. command_name +## Commands +You have access to the following commands: +{tool_description} +## Resources +You can leverage access to the following resources: +1. Internet access for searches and information gathering. +2. Long Term memory management.', 'File output.', 'Command execution +## Best practices +1. Continuously review and analyze your actions to ensure you are performing to the best of your abilities. +2. Constructively self-criticize your big-picture behavior constantly. +3. Reflect on past decisions and strategies to refine your approach. +4. Every command has a cost, so be smart and efficient. Aim to complete tasks in the least number of steps. +## Goals +For your task, you must fulfill the following goals: +{ai_goals} +""" + +DEFAULT_CALL_PROTOCOL = """Respond strictly with JSON. The JSON should be compatible with the TypeScript type `Response` from the following: +```ts +interface Response { + thoughts: { + // Thoughts + text: string; + reasoning: string; + // Short markdown-style bullet list that conveys the long-term plan + plan: string; + // Constructive self-criticism + criticism: string; + // Summary of thoughts to say to the user + speak: string; + }; + command: { + name: string; + args: Record; + }; +} +``` +""" +DEFAULT_SCHEMA = { + '$schema': 'http://json-schema.org/draft-07/schema#', + 'type': 'object', + 'properties': { + 'thoughts': { + 'type': 'object', + 'properties': { + 'text': { + 'type': 'string', + 'description': 'thoughts' + }, + 'reasoning': { + 'type': 'string' + }, + 'plan': { + 'type': + 'string', + 'description': + '- short bulleted\n- list that conveys\n- long-term plan' + }, + 'criticism': { + 'type': 'string', + 'description': 'constructive self-criticism' + }, + 'speak': { + 'type': 'string', + 'description': 'thoughts summary to say to user' + } + }, + 'required': ['text', 'reasoning', 'plan', 'criticism', 'speak'], + 'additionalProperties': False + }, + 'command': { + 'type': 'object', + 'properties': { + 'name': { + 'type': 'string' + }, + 'args': { + 'type': 'object' + } + }, + 'required': ['name', 'args'], + 'additionalProperties': False + } + }, + 'required': ['thoughts', 'command'], + 'additionalProperties': False +} + + +class AutoGPTProtocol: + """A wrapper of AutoGPT prompt which manages the response from LLM and + generate desired prompts in a AutoGPT format. + + Args: + ai_name (str): the name of the agent, default to 'AutoGPT' + role_description (str): description of the role, e.g., System, User + prefix (str): the prefix prompt for AutoGPT + call_protocol (str): the request prompt which defines the protocol + of return format from LLM. + valid_schema (dict): defines the schema of the return format. + triggering_prompt (str): the predefined trigger prompt. 
+ """ + + def __init__(self, + ai_name: Optional[str] = 'AutoGPT', + role_description: Optional[str] = '', + prefix: str = DEFAULT_PREFIX, + call_protocol: str = DEFAULT_CALL_PROTOCOL, + valid_schema: str = DEFAULT_SCHEMA, + triggering_prompt: str = DEFAULT_TRIGGERING_PROMPT) -> None: + self.ai_name = ai_name + self.role_description = role_description + self.prefix = prefix + self.call_protocol = call_protocol + self.valid_schema = valid_schema + self.triggering_prompt = triggering_prompt + + def parse(self, response: str, + action_executor: ActionExecutor) -> Tuple[str, str]: + """Parse the action returns in a AutoGPT format. + + Args: + response (str): The response from LLM with AutoGPT format. + action_executor (ActionExecutor): Action executor to + provide no_action/finish_action name. + + Returns: + tuple: the return value is a tuple contains: + - action (str): the extracted action name. + - action_input (str): the corresponding action input. + """ + try: + if response.startswith('```') and response.endswith('```'): + # Discard the first and last ```, then re-join in case the response naturally included ``` + response = '```'.join(response.split('```')[1:-1]) + response = ast.literal_eval(response) + validator = Draft7Validator(self.valid_schema) + valid = True + if errors := sorted( + validator.iter_errors(response), key=lambda e: e.path): + valid = False + if not valid: + return action_executor.no_action, 'Validation of response failed:\n ' + ';\n '.join( + [str(e) for e in errors]) + try: + if 'command' not in response: + return action_executor.no_action, "Missing 'command' object in JSON" + if not isinstance(response, dict): + return action_executor.no_action, f'The previous message sent was not a dictionary {response}' + command = response['command'] + if not isinstance(command, dict): + return action_executor.no_action, "'command' object is not a dictionary" + if 'name' not in command: + return action_executor.no_action, "Missing 'name' field in 'command' object" + command_name = command['name'] + # Use an empty dictionary if 'args' field is not present in 'command' object + arguments = command.get('args', {}) + return command_name, arguments + except Exception as e: + return action_executor.no_action, repr(e) + except SyntaxError as e: + return action_executor.no_action, f'Your response could not be parsed: {repr(e)} \nRemember to only respond using the specified format above!' + + def format(self, goal: str, inner_history: List[Dict], + action_executor: ActionExecutor) -> List[Dict]: + """Generate the AutoGPT format prompt. + + Args: + goal (str): The user request. + inner_history (List[Dict]): The log in the current run. + action_executor (ActionExecutor): the action manager to + execute actions. + Returns: + List[Dict]: AutoGPT format prompt. + """ + import distro + formatted_data = [] + os_name = platform.system() + os_info = ( + platform.platform(terse=True) + if os_name != 'Linux' else distro.name(pretty=True)) + prefix = self.prefix.format( + ai_name=self.ai_name, + role_description=self.role_description, + tool_description=action_executor.get_actions_info(), + ai_goals=goal, + os_info=os_info, + ) + formatted_data.append(dict(role='system', content=prefix)) + formatted_data.append(dict(role='system', content=self.call_protocol)) + formatted_data += inner_history + formatted_data.append( + dict(role='user', content=self.triggering_prompt)) + return formatted_data + + def format_response(self, action_return): + """format the final response at current step. 
+ + Args: + action_return (ActionReturn): return value of the current action. + + Returns: + str: the final response at current step. + """ + if action_return.state == ActionStatusCode.SUCCESS: + response = action_return.result['text'] + response = f'Command {action_return.type} returned: {response}' + else: + response = action_return.errmsg + return response + + +class AutoGPT(BaseAgent): + """An implementation of AutoGPT (https://github.com/Significant- + Gravitas/Auto-GPT) + + Args: + llm (BaseModel or BaseAPIModel): a LLM service which can chat + and act as backend. + action_executor (ActionExecutor): an action executor to manage + all actions and their response. + protocol (ReActProtocol): a wrapper to generate prompt and + parse the response from LLM / actions. + max_turn (int): the maximum number of trails for LLM to generate + plans that can be successfully parsed by ReWOO protocol. + """ + + def __init__(self, + llm: Union[BaseModel, BaseAPIModel], + action_executor: ActionExecutor, + protocol: AutoGPTProtocol = AutoGPTProtocol(), + max_turn: int = 2): + self.max_turn = max_turn + super().__init__( + llm=llm, action_executor=action_executor, protocol=protocol) + + def chat(self, goal: str) -> AgentReturn: + self._inner_history = [] + agent_return = AgentReturn() + default_response = '对不起,我无法回答你的问题' + for _ in range(self.max_turn): + prompt = self._protocol.format( + goal=goal, + inner_history=self._inner_history, + action_executor=self._action_executor) + response = self._llm.generate_from_template(prompt, 512) + self._inner_history.append( + dict(role='assistant', content=response)) + action, action_input = self._protocol.parse( + response, self._action_executor) + action_return: ActionReturn = self._action_executor( + action, action_input) + agent_return.actions.append(action_return) + if action_return.type == self._action_executor.finish_action.name: + agent_return.response = action_return.result['text'] + return agent_return + self._inner_history.append( + dict( + role='system', + content=self._protocol.format_response(action_return))) + agent_return.response = default_response + return agent_return diff --git a/lagent/agents/base_agent.py b/lagent/agents/base_agent.py new file mode 100644 index 0000000..d311860 --- /dev/null +++ b/lagent/agents/base_agent.py @@ -0,0 +1,49 @@ +from typing import List + +from lagent.actions import ActionExecutor +from lagent.actions.base_action import BaseAction +from lagent.llms.base_llm import BaseModel +from lagent.schema import AgentReturn + + +class BaseAgent: + """BaseAgent is the base class of all agents. + + Args: + llm (BaseModel): the language model. + action_executor (ActionExecutor): the action executor. + protocol (object): the protocol of the agent, which is used to + generate the prompt of the agent and parse the response from + the llm. + """ + + def __init__(self, llm: BaseModel, action_executor: ActionExecutor, + protocol: object) -> None: + + self._session_history = [] + self._llm = llm + self._action_executor = action_executor + self._protocol = protocol + + def add_action(self, action: BaseAction) -> None: + """Add an action to the action executor. + + Args: + action (BaseAction): the action to be added. + """ + self._action_executor.add_action(action) + + def del_action(self, name: str) -> None: + """Delete an action from the action executor. + + Args: + name (str): the name of the action to be deleted. 
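+
+ Example:
+ A minimal sketch; the name is assumed to be the one the action was
+ registered under in the executor (typically the action class name).
+
+ .. code-block:: python
+
+ agent.del_action('PythonInterpreter')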
+ """ + self._action_executor.del_action(name) + + def chat(self, message: str) -> AgentReturn: + raise NotImplementedError + + @property + def session_history(self) -> List: + return self._session_history diff --git a/lagent/agents/react.py b/lagent/agents/react.py new file mode 100644 index 0000000..da1bf75 --- /dev/null +++ b/lagent/agents/react.py @@ -0,0 +1,217 @@ +from typing import Dict, List, Tuple, Union + +from lagent.actions import ActionExecutor +from lagent.llms.base_api import BaseAPIModel +from lagent.llms.base_llm import BaseModel +from lagent.schema import ActionReturn, ActionStatusCode, AgentReturn +from .base_agent import BaseAgent + +CALL_PROTOCOL = """你是一个可以调用外部工具的助手,可以使用的工具包括: +{tool_description} +如果使用工具请遵循以下格式回复: +``` +{thought}思考你当前步骤需要解决什么问题,是否需要使用工具 +{action}工具名称,你的工具必须从 [{action_names}] 选择 +{action_input}工具输入参数 +``` +工具返回按照以下格式回复: +``` +{response}调用工具后的结果 +``` +如果你已经知道了答案,或者你不需要工具,请遵循以下格式回复 +``` +{thought}给出最终答案的思考过程 +{finish}最终答案 +``` +开始!""" + + +class ReACTProtocol: + """A wrapper of ReACT prompt which manages the response from LLM and + generate desired prompts in a ReACT format. + + Args: + thought (dict): the information of thought pattern + action (dict): the information of action pattern + action_input (dict): the information of action_input pattern + response (dict): the information of response pattern + finish (dict): the information of finish pattern + call_protocol (str): the format of ReACT + force_stop (str): the prompt to force LLM to generate response + """ + + def __init__(self, + thought: dict = dict( + role='THOUGHT', + begin='Thought:', + end='\n', + belong='assistant'), + action: dict = dict(role='ACTION', begin='Action:', end='\n'), + action_input: dict = dict( + role='ARGS', begin='ActionInput:', end='\n'), + response: dict = dict( + role='RESPONSE', begin='Response:', end='\n'), + finish: dict = dict( + role='FINISH', begin='FinalAnswer:', end='\n'), + call_protocol: str = CALL_PROTOCOL, + force_stop: str = '你需要基于历史消息返回一个最终结果') -> None: + self.call_protocol = call_protocol + self.force_stop = force_stop + self.thought = thought + self.action = action + self.action_input = action_input + self.response = response + self.finish = finish + + def format(self, + chat_history: List[Dict], + inner_step: List[Dict], + action_executor: ActionExecutor, + force_stop: bool = False) -> list: + """Generate the ReACT format prompt. + + Args: + chat_history (List[Dict]): The history log in previous runs. + inner_step (List[Dict]): The log in the current run. + action_executor (ActionExecutor): the action manager to + execute actions. + force_stop (boolean): whether force the agent to give responses + under pre-defined turns. + + Returns: + List[Dict]: ReACT format prompt. + """ + + call_protocol = self.call_protocol.format( + tool_description=action_executor.get_actions_info(), + action_names=action_executor.action_names(), + thought=self.thought['begin'], + action=self.action['begin'], + action_input=self.action_input['begin'], + response=self.response['begin'], + finish=self.finish['begin'], + ) + formatted = [] + formatted.append(dict(role='system', content=call_protocol)) + formatted += chat_history + formatted += inner_step + if force_stop: + formatted.append(dict(role='system', content=self.force_stop)) + return formatted + + def parse( + self, + message: str, + action_executor: ActionExecutor, + ) -> Tuple[str, str, str]: + """Parse the action returns in a ReACT format. + + Args: + message (str): The response from LLM with ReACT format. 
+ action_executor (ActionExecutor): Action executor to + provide no_action/finish_action name. + + Returns: + tuple: the return value is a tuple contains: + - thought (str): contain LLM thought of the current step. + - action (str): contain action scheduled by LLM. + - action_input (str): contain the required action input + for current action. + """ + + import re + thought = message.split(self.action['begin'])[0] + thought = thought.split(self.thought['begin'])[-1] + thought = thought.split(self.finish['begin'])[0] + if self.finish['begin'] in message: + final_answer = message.split(self.finish['begin'])[-1] + return thought, action_executor.finish_action.name, final_answer + + action_regex = f"{self.action['begin']}(.*?)\n" + args_regex = f"{self.action_input['begin']}(.*)" + action_match = re.findall(action_regex, message) + if not action_match: + return thought, action_executor.no_action.name, '' + action = action_match[-1] + arg_match = re.findall(args_regex, message, re.DOTALL) + + if not arg_match: + return thought, action_executor.no_action.name, '' + action_input = arg_match[-1] + return thought, action.strip(), action_input.strip().strip('"') + + def format_response(self, action_return: ActionReturn) -> str: + """format the final response at current step. + + Args: + action_return (ActionReturn): return value of the current action. + + Returns: + str: the final response at current step. + """ + if action_return.state == ActionStatusCode.SUCCESS: + response = action_return.result['text'] + else: + response = action_return.errmsg + return self.response['begin'] + response + self.response['end'] + + +class ReACT(BaseAgent): + """An implementation of ReACT (https://arxiv.org/abs/2210.03629) + + Args: + llm (BaseModel or BaseAPIModel): a LLM service which can chat + and act as backend. + action_executor (ActionExecutor): an action executor to manage + all actions and their response. + protocol (ReActProtocol): a wrapper to generate prompt and + parse the response from LLM / actions. + max_turn (int): the maximum number of trails for LLM to generate + plans that can be successfully parsed by ReWOO protocol. 
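+
+ Example:
+ A minimal sketch of running the agent; ``model`` is assumed to be an
+ already constructed LLM wrapper such as ``GPTAPI`` or
+ ``HFTransformerCasualLM``.
+
+ .. code-block:: python
+
+ from lagent.actions import ActionExecutor
+ from lagent.actions.python_interpreter import PythonInterpreter
+
+ agent = ReACT(
+ llm=model,
+ action_executor=ActionExecutor(actions=[PythonInterpreter()]))
+ agent_return = agent.chat('Use Python to compute the sum of 1 to 100.')
+ print(agent_return.response)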
+ """ + + def __init__(self, + llm: Union[BaseModel, BaseAPIModel], + action_executor: ActionExecutor, + protocol: ReACTProtocol = ReACTProtocol(), + max_turn: int = 2) -> None: + self.max_turn = max_turn + super().__init__( + llm=llm, action_executor=action_executor, protocol=protocol) + + def chat(self, message: str) -> AgentReturn: + self._inner_history = [] + self._inner_history.append(dict(role='user', content=message)) + agent_return = AgentReturn() + force_stop = False + default_response = '对不起,我无法回答你的问题' + for turn in range(self.max_turn): + prompt = self._protocol.format( + chat_history=self.session_history, + inner_step=self._inner_history, + action_executor=self._action_executor, + force_stop=force_stop) + response = self._llm.generate_from_template(prompt, 512) + self._inner_history.append( + dict(role='assistant', content=response)) + thought, action, action_input = self._protocol.parse( + response, self._action_executor) + action_return: ActionReturn = self._action_executor( + action, action_input) + action_return.thought = thought + agent_return.actions.append(action_return) + if action_return.type == self._action_executor.finish_action.name: + agent_return.response = action_return.result['text'] + return agent_return + self._inner_history.append( + dict( + role='system', + content=self._protocol.format_response(action_return))) + if turn == self.max_turn - 1: + force_stop = True + agent_return.response = default_response + # only append the user and final response + self._session_history.append(dict(role='user', content=message)) + self._session_history.append( + dict(role='assistant', content=agent_return.response)) + return agent_return diff --git a/lagent/agents/rewoo.py b/lagent/agents/rewoo.py new file mode 100644 index 0000000..c20d775 --- /dev/null +++ b/lagent/agents/rewoo.py @@ -0,0 +1,232 @@ +import re +import warnings +from typing import Dict, List, Optional, Tuple, Union + +from lagent.actions import ActionExecutor +from lagent.llms.base_api import BaseAPIModel +from lagent.llms.base_llm import BaseModel +from lagent.schema import ActionReturn, ActionStatusCode, AgentReturn +from .base_agent import BaseAgent + +PLANER_PROMPT = """你是一个任务分解器, 你需要将用户的问题拆分成多个简单的子任务。 +请拆分出多个子任务项,从而能够得到充分的信息以解决问题, 返回格式如下: +``` +Plan: 当前子任务要解决的问题 +#E[id] = 工具名称[工具参数] +Plan: 当前子任务要解决的问题 +#E[id] = 工具名称[工具参数] +``` +其中 +1. #E[id] 用于存储Plan id的执行结果, 可被用作占位符。 +2. 每个 #E[id] 所执行的内容应与当前Plan解决的问题严格对应。 +3. 工具参数可以是正常输入text, 或是 #E[依赖的索引], 或是两者都可以。 +4. 工具名称必须从一下工具中选择: +{tool_description} +注意: 每个Plan后有且仅能跟随一个#E。 +开始!""" + +WORKER_PROMPT = """ +Thought: {thought}\nResponse: {action_resp}\n +""" + +SOLVER_PROMPT = """解决接下来的任务或者问题。为了帮助你,我们提供了一些相关的计划 +和相应的解答。注意其中一些信息可能存在噪声,因此你需要谨慎的使用它们。\n +{question}\n{worker_log}\n现在开始回答这个任务或者问题。请直接回答这个问题, +不要包含其他不需要的文字。{question}\n +""" + + +class ReWOOProtocol: + """A wrapper of ReWOO prompt which manages the response from LLM and + generate desired prompts in a ReWOO format. 
+ + Args: + planner_prompt (str): prompt template for planner + solver_prompt (str): prompt template for solver + """ + + def __init__( + self, + planner_prompt: str = PLANER_PROMPT, + worker_prompt: str = WORKER_PROMPT, + solver_prompt: str = SOLVER_PROMPT, + ) -> None: + self.planner_prompt = planner_prompt + self.worker_prompt = worker_prompt + self.solver_prompt = solver_prompt + + def format_planner(self, + chat_history: List[Dict], + inner_step: List[Dict], + action_executor: ActionExecutor, + reformat_request: Optional[str] = '') -> List[Dict]: + """Generate the planner prompt required by ReWOO. + + Args: + chat_history (List[Dict]): The history log in previous runs. + inner_step (List[Dict]): The log in the current run. + action_executor (ActionExecutor): the action manager to execute + actions. + reformat_request (str): the error feedback if the LLM fails to + generate required format for planner. + + Returns: + List[Dict]: ReWOO format prompt for planner. + """ + planner_prompt = self.planner_prompt.format( + tool_description=action_executor.get_actions_info(), ) + formatted = [] + formatted.append(dict(role='system', content=planner_prompt)) + formatted += chat_history + formatted += inner_step + if reformat_request != '': + formatted.append( + dict( + role='system', + content='回答格式错误: %s. 请重新重新回答: ' % reformat_request)) + return formatted + + def parse_worker( + self, + message: str, + ) -> Tuple[List[str], List[str], List[str]]: + """Parse the LLM generated planner response and convert it into the + worker format. + + Args: + message (str): The response from LLM with ReWOO planner format. + + Returns: + tuple: the return value is a tuple contains: + - thought_list (List(str)): contain LLM thoughts of the user + request. + - action_list (List(str)): contain actions scheduled by LLM. + - action_input_list (List(str)): contain the required action + input for above actions. + """ + action_list = [] + action_input_list = [] + thought_list = [] + thoughts = re.findall('Plan: (.+)', message) + action_units = re.findall('#E[0-9]* = (.+)', message) + assert len(thoughts) == len(action_units), \ + 'Each Plan should only correspond to only ONE action' + for thought, action_unit in zip(thoughts, action_units): + action_name, action_input = re.findall(r'(.*?)\[(.*?)\]', + action_unit.strip())[0] + action_list.append(action_name.strip()) + action_input_list.append(action_input.strip()) + thought_list.append(thought.strip()) + return thought_list, action_list, action_input_list + + def format_solver( + self, question: str, thought_list: List[str], + action_return_list: List[ActionReturn]) -> Tuple[str, str]: + """Generate the prompt for solver in a ReWOO format. + + Args: + question (str): The user request in the current run. + thought_list (List[str]): thoughts generated from LLM for + each action. + action_return_list (List[ActionReturn]): action returns + from workers. + + Returns: + tuple: the return value is a tuple contains: + - solver_prompt (str): the generated prompt for solver + in a ReWOO format. + - worker_log (str): contain action responses from workers. + Used for inner log. 
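+
+ Example:
+ A minimal sketch; ``search_return`` is assumed to be an ``ActionReturn``
+ produced by a worker, with the tool output stored in ``result['text']``.
+
+ .. code-block:: python
+
+ solver_prompt, worker_log = protocol.format_solver(
+ question='What is the capital of China?',
+ thought_list=['Search for the capital of China'],
+ action_return_list=[search_return])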
+ """ + worker_log = '' + for thought, action_return in zip(thought_list, action_return_list): + if action_return.state == ActionStatusCode.SUCCESS: + action_resp = action_return.result['text'] + else: + action_resp = action_return.errmsg + worker_response = self.worker_prompt.format( + thought=thought, action_resp=action_resp) + worker_log += worker_response + solver_prompt = self.solver_prompt.format( + question=question, worker_log=worker_log) + return solver_prompt, worker_log + + +class ReWOO(BaseAgent): + """An implementation of ReWOO (https://arxiv.org/abs/2305.18323) + + Args: + llm (BaseModel or BaseAPIModel): a LLM service which can chat + and act as planner / solver. + action_executor (ActionExecutor): an action executor to manage + all actions and their response. + protocol (ReWOOProtocol): a wrapper to generate prompt and + parse the response from LLM / actions. + max_turn (int): the maximum number of trails for LLM to generate + plans that can be successfully parsed by ReWOO protocol. + """ + + def __init__(self, + llm: Union[BaseModel, BaseAPIModel], + action_executor: ActionExecutor, + protocol: ReWOOProtocol = ReWOOProtocol(), + max_turn: int = 2) -> None: + super().__init__( + llm=llm, action_executor=action_executor, protocol=protocol) + + self.max_turn = max_turn + + def chat(self, message: str) -> AgentReturn: + self._inner_history = [] + self._inner_history.append(dict(role='user', content=message)) + agent_return = AgentReturn() + + # planner + turn_id = 0 + reformat_request = '' + while turn_id < self.max_turn: + planner_prompt = self._protocol.format_planner( + chat_history=self.session_history, + inner_step=self._inner_history, + action_executor=self._action_executor, + reformat_request=reformat_request) + response = self._llm.generate_from_template(planner_prompt, 512) + self._inner_history.append( + dict(role='assistant', content=response)) + try: + thoughts, actions, actions_input = self._protocol.parse_worker( + response) + break + except Exception as e: + turn_id += 1 + reformat_request = str(e) + + if turn_id >= self.max_turn: + warnings.warn('\nUnable to parse LLM responses in %d turns, ' + 'directly request solver for question answer...' 
% + self.max_turn) + actions = [] + thoughts = [] + action_responses = [] + # workers + action_responses = [] + for action_id in range(len(actions)): + # we need to change actions_input inplace + prev_ptrs = re.findall(r'#E\d+', actions_input[action_id]) + for prev_ptr in prev_ptrs: + ptr_num = int(prev_ptr.strip('#E')) - 1 # start from 0 + actions_input[action_id] = actions_input[action_id].replace( + prev_ptr, action_responses[ptr_num].result['text']) + action_return: ActionReturn = self._action_executor( + actions[action_id], actions_input[action_id]) + action_responses.append(action_return) + + solver_prompt, worker_log = self._protocol.format_solver( + message, thoughts, action_responses) + self._inner_history.append(dict(role='system', content=worker_log)) + + final_response = self._llm.generate_from_template(solver_prompt, 512) + self._inner_history.append( + dict(role='assistant', content=final_response)) + agent_return.response = final_response + return agent_return diff --git a/lagent/llms/__init__.py b/lagent/llms/__init__.py new file mode 100644 index 0000000..202910c --- /dev/null +++ b/lagent/llms/__init__.py @@ -0,0 +1,10 @@ +from lagent.utils import is_module_exist +from .base_api import BaseAPIModel +from .base_llm import BaseModel +from .openai import GPTAPI + +__all__ = ['BaseModel', 'BaseAPIModel', 'GPTAPI'] + +if is_module_exist('transformers'): + from .huggingface import HFTransformer, HFTransformerCasualLM # noqa: F401 + __all__.extend(['HFTransformer', 'HFTransformerCasualLM']) diff --git a/lagent/llms/base_api.py b/lagent/llms/base_api.py new file mode 100644 index 0000000..ac84d82 --- /dev/null +++ b/lagent/llms/base_api.py @@ -0,0 +1,240 @@ +import re +import threading +import warnings +from abc import abstractclassmethod +from time import sleep +from typing import Dict, List, Optional, Tuple, Union + +from .base_llm import BaseModel + + +class BaseAPIModel(BaseModel): + """Base class for API model wrapper. + + Args: + model_type (str): The type of model. + query_per_second (int): The maximum queries allowed per second + between two consecutive calls of the API. Defaults to 1. + retry (int): Number of retires if the API call fails. Defaults to 2. + max_seq_len (int): The maximum sequence length of the model. Defaults + to 2048. + meta_template (Dict, optional): The model's meta prompt + template if needed, in case the requirement of injecting or + wrapping of any meta instructions. + """ + + is_api: bool = True + + def __init__(self, + model_type: str, + query_per_second: int = 1, + retry: int = 2, + max_seq_len: int = 2048, + meta_template: Optional[Dict] = None): + self.model_type = model_type + self.max_seq_len = max_seq_len + self.meta_template = meta_template + self.retry = retry + self.query_per_second = query_per_second + self.token_bucket = TokenBucket(query_per_second) + self.template_parser = APITemplateParser(meta_template) + + @abstractclassmethod + def generate(self, inputs, max_out_len: int) -> List[str]: + """Generate results given a list of inputs. + + Args: + inputs (List[str or list]): A list of strings or PromptDicts. + The PromptDict should be organized in OpenCompass' + API format. + max_out_len (int): The maximum length of the output. + + Returns: + List[str]: A list of generated strings. + """ + + def get_token_len(self, prompt: str) -> int: + """Get lengths of the tokenized string. Only English and Chinese + characters are counted for now. Users are encouraged to override this + method if more accurate length is needed. 
+ + Args: + prompt (str): Input string. + + Returns: + int: Length of the input tokens + """ + + english_parts = re.findall(r'[A-Za-z0-9]+', prompt) + chinese_parts = re.findall(r'[\u4e00-\u9FFF]+', prompt) + + # Count English words + english_count = sum(len(part.split()) for part in english_parts) + + # Count Chinese words + chinese_count = sum(len(part) for part in chinese_parts) + + return english_count + chinese_count + + def wait(self): + """Wait till the next query can be sent. + + Applicable in both single-thread and multi-thread environments. + """ + return self.token_bucket.get_token() + + def to(self, device): + pass + + +class APITemplateParser: + """Intermidate prompt template parser, specifically for API models. + + Args: + meta_template (Dict): The meta template for the model. + """ + + def __init__(self, meta_template: Optional[Dict] = None): + self.meta_template = meta_template + # Check meta template + if meta_template: + assert isinstance(meta_template, list) + self.roles: Dict[str, dict] = dict() # maps role name to config + for item in meta_template: + assert isinstance(item, dict) + assert item['role'] not in self.roles, \ + 'role in meta prompt must be unique!' + self.roles[item['role']] = item.copy() + + def parse_template(self, dialog: List[Union[str, List]]): + """Parse the intermidate prompt template, and wrap it with meta + template if applicable. When the meta template is set and the input is + a list, the return value will be a list containing the full + conversation history. Each item looks like: + + .. code-block:: python + + {'role': 'user', 'content': '...'}). + + Args: + dialog (List[str or list]): An intermidate prompt + template (potentially before being wrapped by meta template). + mode (str): Parsing mode. Choices are 'ppl' and 'gen'. + + Returns: + List[str or list]: The finalized prompt or a conversation. + """ + assert isinstance(dialog, (str, list)) + if isinstance(dialog, str): + return dialog + if self.meta_template: + + prompt = list() + # Whether to keep generating the prompt + generate = True + for i, item in enumerate(dialog): + if not generate: + break + if isinstance(item, str): + if item.strip(): + # TODO: logger + warnings.warn('Non-empty string in prompt template ' + 'will be ignored in API models.') + else: + api_prompts = self._prompt2api(item) + prompt.append(api_prompts) + + # merge the consecutive prompts assigned to the same role + new_prompt = list([prompt[0]]) + last_role = prompt[0]['role'] + for item in prompt[1:]: + if item['role'] == last_role: + new_prompt[-1]['content'] += '\n' + item['content'] + else: + last_role = item['role'] + new_prompt.append(item) + prompt = new_prompt + + else: + # in case the model does not have any meta template + prompt = '' + last_sep = '' + for item in dialog: + if isinstance(item, str): + if item: + prompt += last_sep + item + elif item.get('content', ''): + prompt += last_sep + item.get('content', '') + last_sep = '\n' + return prompt + + def _prompt2api(self, prompts: Union[List, str]) -> Tuple[str, bool]: + """Convert the prompts to a API-style prompts, given an updated + role_dict. + + Args: + prompts (Union[List, str]): The prompts to be converted. + role_dict (Dict[str, Dict]): The updated role dict. + for_gen (bool): If True, the prompts will be converted for + generation tasks. The conversion stops before the first + role whose "generate" is set to True. + + Returns: + Tuple[str, bool]: The converted string, and whether the follow-up + conversion should be proceeded. 
+ """ + if isinstance(prompts, str): + return prompts + elif isinstance(prompts, dict): + api_role = self._role2api_role(prompts) + return api_role + + res = [] + for prompt in prompts: + if isinstance(prompt, str): + raise TypeError('Mixing str without explicit role is not ' + 'allowed in API models!') + else: + api_role = self._role2api_role(prompt) + res.append(api_role) + return res + + def _role2api_role(self, role_prompt: Dict) -> Tuple[str, bool]: + + merged_prompt = self.roles.get( + role_prompt['role'], + self.roles.get( + self.roles[role_prompt['role']].get('fallback_role'))) + res = {} + res['role'] = merged_prompt['api_role'] + res['content'] = merged_prompt.get('begin', '') + res['content'] += role_prompt.get('content', '') + res['content'] += merged_prompt.get('end', '') + return res + + +class TokenBucket: + """A token bucket for rate limiting. + + Args: + rate (float): The rate of the token bucket. + """ + + def __init__(self, rate: float) -> None: + self._rate = rate + self._tokens = threading.Semaphore(0) + self.started = False + + def _add_tokens(self): + """Add tokens to the bucket.""" + while True: + if self._tokens._value < self._rate: + self._tokens.release() + sleep(1 / self._rate) + + def get_token(self): + """Get a token from the bucket.""" + if not self.started: + self.started = True + threading.Thread(target=self._add_tokens, daemon=True).start() + self._tokens.acquire() diff --git a/lagent/llms/base_llm.py b/lagent/llms/base_llm.py new file mode 100644 index 0000000..890bada --- /dev/null +++ b/lagent/llms/base_llm.py @@ -0,0 +1,146 @@ +from abc import abstractclassmethod +from typing import Dict, List, Optional, Tuple, Union + + +class BaseModel: + """Base class for model wrapper. + + Args: + path (str): The path to the model. + max_seq_len (int): The maximum sequence length of the model. Defaults + to 2048. + tokenizer_only (bool): If True, only the tokenizer will be initialized. + Defaults to False. + meta_template (Dict, optional): The model's meta prompt + template if needed, in case the requirement of injecting or + wrapping of any meta instructions. + """ + + is_api: bool = False + + def __init__(self, + path: str, + max_seq_len: int = 2048, + tokenizer_only: bool = False, + meta_template: Optional[Dict] = None): + self.path = path + self.max_seq_len = max_seq_len + self.tokenizer_only = tokenizer_only + # meta template + self.template_parser = LMTemplateParser(meta_template) + self.eos_token_id = None + if meta_template and 'eos_token_id' in meta_template: + self.eos_token_id = meta_template['eos_token_id'] + + @abstractclassmethod + def generate(self, inputs: List[str], max_out_len: int) -> List[str]: + """Generate results given a list of inputs. + + Args: + inputs (List[str]): A list of strings. + max_out_len (int): The maximum length of the output. + + Returns: + List[str]: A list of generated strings. + """ + + def parse_template(self, dialog) -> str: + """Parse a prompt template, and wrap it with meta template if + applicable. + + Args: + dialog (List[str or PromptList]): A prompt + template (potentially before being wrapped by meta template). + mode (str): Parsing mode. Choices are 'ppl' and 'gen'. + + Returns: + str: The final string. + """ + return self.template_parser.parse_template(dialog) + + def generate_from_template(self, templates, max_out_len: int, **kwargs): + """Generate completion from a list of templates. + + Args: + templates (List[PromptType]): A list of templates. + max_out_len (int): The maximum length of the output. 
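+
+ Example:
+ A minimal sketch; ``model`` is assumed to be an instance of a concrete
+ subclass, since ``generate`` itself is abstract here.
+
+ .. code-block:: python
+
+ templates = [dict(role='user', content='Hi!')]
+ output = model.generate_from_template(templates, max_out_len=512)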
+ """ + inputs = self.parse_template(templates) + return self.generate(inputs, max_out_len=max_out_len, **kwargs) + + def to(self, device): + self.model.to(device) + + +class LMTemplateParser: + """Intermidate prompt template parser, specifically for language models. + + Args: + meta_template (Dict): The meta template for the model. + """ + + def __init__(self, meta_template: Optional[Dict] = None): + self.meta_template = meta_template + if meta_template: + assert isinstance(meta_template, list) + self.roles: Dict[str, dict] = dict() # maps role name to config + for item in meta_template: + assert isinstance(item, dict) + assert item['role'] not in self.roles, \ + 'role in meta prompt must be unique!' + self.roles[item['role']] = item.copy() + + def parse_template(self, dialog) -> str: + """Parse a prompt template, and wrap it with meta template if + applicable. + + Args: + dialog (List[str or PromptList]): A prompt + template (potentially before being wrapped by meta template). + mode (str): Parsing mode. Choices are 'ppl' and 'gen'. + + Returns: + str: The final string. + """ + assert isinstance(dialog, (str, list)) + if isinstance(dialog, str): + return dialog + if self.meta_template: + + prompt = '' + for index, item in enumerate(dialog): + if isinstance(item, str): + prompt += item + else: + new_str = self._prompt2str(item, index == len(dialog) - 1) + prompt += new_str + else: + # in case the model does not have any meta template + prompt = '' + last_sep = '' + for item in dialog: + if isinstance(item, str): + if item: + prompt += last_sep + item + elif item.get('content', ''): + prompt += last_sep + item.get('prompt', '') + last_sep = '\n' + return prompt + + def _prompt2str(self, + prompt: Union[List, str, Dict], + last: bool = False) -> Tuple[str, bool]: + if isinstance(prompt, str): + return prompt + merged_prompt = self.roles.get( + prompt['role'], + self.roles.get(self.roles[prompt['role']].get('fallback_role'))) + res = merged_prompt.get('begin', '') + if last and merged_prompt.get('generate', False): + res += prompt.get('content', '') + return res + res += prompt.get('content', '') + merged_prompt.get('end', '') + if last and merged_prompt['role'] != 'assistant': + res += self.roles['assistant']['begin'] + return res + return res diff --git a/lagent/llms/huggingface.py b/lagent/llms/huggingface.py new file mode 100644 index 0000000..199df5c --- /dev/null +++ b/lagent/llms/huggingface.py @@ -0,0 +1,128 @@ +from typing import Dict, List, Optional + +import torch + +from .base_llm import BaseModel + + +class HFTransformer(BaseModel): + """Model wrapper around HuggingFace general models. + + Adapted from OpenCompass (https://github.com/InternLM/opencompass + /blob/main/opencompass/models/huggingface.py) + + Args: + path (str): The name or path to HuggingFace's model. + max_seq_len (int): The maximum length of the input sequence. Defaults + to 2048. + tokenizer_path (str): The path to the tokenizer. Defaults to None. + tokenizer_kwargs (dict): Keyword arguments for the tokenizer. + Defaults to {}. + tokenizer_only (bool): If True, only the tokenizer will be initialized. + Defaults to False. + model_kwargs (dict): Keyword arguments for the model, used in loader. + Defaults to dict(device_map='auto'). + meta_template (Dict, optional): The model's meta prompt + template if needed, in case the requirement of injecting or + wrapping of any meta instructions. 
+ extract_pred_after_decode (bool): Whether to extract the prediction + string from the decoded output string, instead of extract the + prediction tokens before decoding. Defaults to False. + batch_padding (bool): If False, inference with be performed in for-loop + without batch padding. + + Note: + About ``extract_pred_after_decode``: Commonly, we should extract the + the prediction tokens before decoding. But for some tokenizers using + ``sentencepiece``, like LLaMA, this behavior may change the number of + whitespaces, which is harmful for Python programming tasks. + """ + + def __init__(self, + path: str, + max_seq_len: int = 2048, + tokenizer_path: Optional[str] = None, + tokenizer_kwargs: dict = dict(), + tokenizer_only: bool = False, + model_kwargs: dict = dict(device_map='auto'), + meta_template: Optional[Dict] = None, + extract_pred_after_decode: bool = False, + batch_padding: bool = False): + super().__init__( + path=path, + max_seq_len=max_seq_len, + tokenizer_only=tokenizer_only, + meta_template=meta_template) + self._load_tokenizer( + path=path, + tokenizer_path=tokenizer_path, + tokenizer_kwargs=tokenizer_kwargs) + self.batch_padding = batch_padding + self.extract_pred_after_decode = extract_pred_after_decode + if not tokenizer_only: + self._load_model(path=path, model_kwargs=model_kwargs) + + def _load_tokenizer(self, path: str, tokenizer_path: Optional[str], + tokenizer_kwargs: dict): + from transformers import AutoTokenizer + self.tokenizer = AutoTokenizer.from_pretrained( + tokenizer_path if tokenizer_path else path, + trust_remote_code=True, + **tokenizer_kwargs) + if self.tokenizer.pad_token_id is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + + def _load_model(self, path: str, model_kwargs: dict): + from transformers import AutoModel + model_kwargs.setdefault('torch_dtype', torch.float16) + self.model = AutoModel.from_pretrained( + path, trust_remote_code=True, **model_kwargs) + self.model.eval() + + def generate(self, inputs: List[str], max_out_len: int, + **kwargs) -> List[str]: + if isinstance(inputs, str): + inputs = [inputs] + if self.extract_pred_after_decode: + prompt_lens = [len(input_) for input_ in inputs] + + input_ids = self.tokenizer( + inputs, truncation=True, + max_length=self.max_seq_len - max_out_len)['input_ids'] + input_ids = torch.tensor(input_ids, device=self.model.device) + outputs = self.model.generate( + input_ids, max_new_tokens=max_out_len, **kwargs) + + if not self.extract_pred_after_decode: + outputs = outputs[:, input_ids.shape[1]:] + + decodeds = self.tokenizer.batch_decode( + outputs, skip_special_tokens=True) + if self.extract_pred_after_decode: + decodeds = [ + token[len_:] for token, len_ in zip(decodeds, prompt_lens) + ] + + return decodeds[0] + + def generate_from_template(self, templates, max_out_len: int, **kwargs): + """Generate completion from a list of templates. + + Args: + templates (List[PromptType]): A list of templates. + max_out_len (int): The maximum length of the output. 
+ """ + inputs = self.parse_template(templates) + response = self.generate(inputs, max_out_len=max_out_len, **kwargs) + return response.replace( + self.template_parser.roles['assistant']['end'].strip(), + '').strip() + + +class HFTransformerCasualLM(HFTransformer): + + def _load_model(self, path: str, model_kwargs: dict): + from transformers import AutoModelForCausalLM + model_kwargs.setdefault('torch_dtype', torch.float16) + self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs) + self.model.eval() diff --git a/lagent/llms/openai.py b/lagent/llms/openai.py new file mode 100644 index 0000000..0ec5966 --- /dev/null +++ b/lagent/llms/openai.py @@ -0,0 +1,242 @@ +import json +import os +import time +# from concurrent.futures import ThreadPoolExecutor +from logging import getLogger +from threading import Lock +from typing import Dict, List, Optional, Union + +import requests + +from .base_api import BaseAPIModel + +OPENAI_API_BASE = 'https://api.openai.com/v1/chat/completions' + + +class GPTAPI(BaseAPIModel): + """Model wrapper around OpenAI's models. + + Args: + model_type (str): The name of OpenAI's model. + max_seq_len (int): The maximum allowed sequence length of a model. + Note that the length of prompt + generated tokens shall not exceed + this value. Defaults to 2048. + query_per_second (int): The maximum queries allowed per second + between two consecutive calls of the API. Defaults to 1. + retry (int): Number of retires if the API call fails. Defaults to 2. + key (str or List[str]): OpenAI key(s). In particular, when it + is set to "ENV", the key will be fetched from the environment + variable $OPENAI_API_KEY, as how openai defaults to be. If it's a + list, the keys will be used in round-robin manner. Defaults to + 'ENV'. + org (str or List[str], optional): OpenAI organization(s). If not + specified, OpenAI uses the default organization bound to each API + key. If specified, the orgs will be posted with each request in + round-robin manner. Defaults to None. + meta_template (Dict, optional): The model's meta prompt + template if needed, in case the requirement of injecting or + wrapping of any meta instructions. + openai_api_base (str): The base url of OpenAI's API. Defaults to + 'https://api.openai.com/v1/chat/completions'. + temperature (float, optional): What sampling temperature to use. + If not None, will override the temperature in the `generate()` + call. Defaults to None. 
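+
+ Example:
+ A minimal sketch; with ``key='ENV'`` the key is read from the
+ ``OPENAI_API_KEY`` environment variable, which is assumed to be set.
+
+ .. code-block:: python
+
+ gpt = GPTAPI(model_type='gpt-3.5-turbo', key='ENV')
+ response = gpt.generate('Hello!', max_out_len=128)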
+ """ + + is_api: bool = True + + def __init__(self, + model_type: str = 'gpt-3.5-turbo', + max_seq_len: int = 4096, + query_per_second: int = 1, + retry: int = 2, + key: Union[str, List[str]] = 'ENV', + org: Optional[Union[str, List[str]]] = None, + meta_template: Optional[Dict] = [ + dict(role='system', api_role='system'), + dict(role='user', api_role='user'), + dict(role='assistant', api_role='assistant') + ], + openai_api_base: str = OPENAI_API_BASE, + temperature: Optional[float] = None): + + super().__init__( + model_type=model_type, + max_seq_len=max_seq_len, + meta_template=meta_template, + query_per_second=query_per_second, + retry=retry) + self.logger = getLogger(__name__) + self.temperature = temperature + + if isinstance(key, str): + self.keys = [os.getenv('OPENAI_API_KEY') if key == 'ENV' else key] + else: + self.keys = key + + # record invalid keys and skip them when requesting API + # - keys have insufficient_quota + self.invalid_keys = set() + + self.key_ctr = 0 + if isinstance(org, str): + self.orgs = [org] + else: + self.orgs = org + self.org_ctr = 0 + self.url = openai_api_base + self.model_type = model_type + + # max num token for gpt-3.5-turbo is 4097 + context_window = 4096 + if '32k' in self.model_type: + context_window = 32768 + elif '16k' in self.model_type: + context_window = 16384 + elif 'gpt-4' in self.model_type: + context_window = 8192 + self.context_window = context_window + + def generate( + self, + inputs: Union[List, str], + max_out_len: int = 512, + temperature: float = 0.7, + ) -> List[str]: + """Generate results given a list of inputs. + + Args: + inputs (List[str or List]): A list of strings or PromptDicts. + The PromptDict should be organized in OpenCompass' + API format. + max_out_len (int): The maximum length of the output. + temperature (float): What sampling temperature to use, + between 0 and 2. Higher values like 0.8 will make the output + more random, while lower values like 0.2 will make it more + focused and deterministic. Defaults to 0.7. + + Returns: + List[str]: A list of generated strings. + """ + if self.temperature is not None: + temperature = self.temperature + return self._generate(inputs, max_out_len, temperature) + + def _generate(self, input: str or List, max_out_len: int, + temperature: float) -> str: + """Generate results given a list of inputs. + + Args: + inputs (str or List): A string or PromptDict. + The PromptDict should be organized in OpenCompass' + API format. + max_out_len (int): The maximum length of the output. + temperature (float): What sampling temperature to use, + between 0 and 2. Higher values like 0.8 will make the output + more random, while lower values like 0.2 will make it more + focused and deterministic. + + Returns: + str: The generated string. 
+ """ + assert isinstance(input, (str, list, dict)) + + if isinstance(input, str): + messages = [{'role': 'user', 'content': input}] + elif isinstance(input, dict): + messages = [input] + else: + messages = input + + # Hold out 100 tokens due to potential errors in tiktoken calculation + max_out_len = min( + max_out_len, + self.context_window - self.get_token_len(str(input)) - 100) + if max_out_len <= 0: + return '' + + max_num_retries = 0 + while max_num_retries < self.retry: + self.wait() + + with Lock(): + if len(self.invalid_keys) == len(self.keys): + raise RuntimeError('All keys have insufficient quota.') + + # find the next valid key + while True: + self.key_ctr += 1 + if self.key_ctr == len(self.keys): + self.key_ctr = 0 + + if self.keys[self.key_ctr] not in self.invalid_keys: + break + + key = self.keys[self.key_ctr] + + header = { + 'Authorization': f'Bearer {key}', + 'content-type': 'application/json', + } + + if self.orgs: + with Lock(): + self.org_ctr += 1 + if self.org_ctr == len(self.orgs): + self.org_ctr = 0 + header['OpenAI-Organization'] = self.orgs[self.org_ctr] + + try: + data = dict( + model=self.model_type, + messages=messages, + max_tokens=max_out_len, + n=1, + stop=None, + temperature=temperature, + ) + raw_response = requests.post( + self.url, headers=header, data=json.dumps(data)) + except requests.ConnectionError: + print('Got connection error, retrying...') + continue + try: + response = raw_response.json() + except requests.JSONDecodeError: + print('JsonDecode error, got', str(raw_response.content)) + continue + try: + return response['choices'][0]['message']['content'].strip() + except KeyError: + if 'error' in response: + if response['error']['code'] == 'rate_limit_exceeded': + time.sleep(1) + continue + elif response['error']['code'] == 'insufficient_quota': + self.invalid_keys.add(key) + self.logger.warn(f'insufficient_quota key: {key}') + continue + + print('Find error message in response: ', + str(response['error'])) + max_num_retries += 1 + + raise RuntimeError('Calling OpenAI failed after retrying for ' + f'{max_num_retries} times. Check the logs for ' + 'details.') + + def get_token_len(self, prompt: str) -> int: + """Get lengths of the tokenized string. Only English and Chinese + characters are counted for now. Users are encouraged to override this + method if more accurate length is needed. + + Args: + prompt (str): Input string. 
+ + Returns: + int: Length of the input tokens + """ + import tiktoken + self.tiktoken = tiktoken + enc = self.tiktoken.encoding_for_model(self.model_type) + return len(enc.encode(prompt)) diff --git a/lagent/schema.py b/lagent/schema.py new file mode 100644 index 0000000..925df6c --- /dev/null +++ b/lagent/schema.py @@ -0,0 +1,79 @@ +from dataclasses import asdict, dataclass, field +from enum import Enum +from typing import Dict, List, Optional, Union + +from lagent.utils import is_module_exist + + +def enum_dict_factory(inputs): + inputs = [(i[0], i[-1].value) if isinstance(i[-1], Enum) else i + for i in inputs] + return dict(inputs) + + +def dataclass2dict(data): + return asdict(data, dict_factory=enum_dict_factory) + + +class ActionStatusCode(int, Enum): + ING = 1 + SUCCESS = 0 + HTTP_ERROR = -1000 # http error + ARGS_ERROR = -1001 # 参数错误 + API_ERROR = -1002 # 不知道的API错误 + + +class ActionValidCode(int, Enum): + FINISH = 1 + OPEN = 0 + CLOSED = -1 + INVALID = -2 + ABSENT = -3 # NO ACTION + + +@dataclass +class ActionReturn: + args: Dict + url: Optional[str] = None + type: Optional[str] = None + result: Optional[str] = None + errmsg: Optional[str] = None + state: Union[ActionStatusCode, int] = ActionStatusCode.SUCCESS + thought: Optional[str] = None + valid: Optional[ActionValidCode] = ActionValidCode.OPEN + + +class AgentStatusCode(Enum): + END = 0 # end of streaming + STREAM_ING = 1 # response is in streaming + SERVER_ERR = -1 # triton server's error + SESSION_CLOSED = -2 # session has been closed + SESSION_OUT_OF_LIMIT = -3 # request length out of limit + CMD = 2 # return command + SESSION_INVALID_ARG = -4 # invalid argument + SESSION_READY = 3 # session is ready for inference + + +@dataclass +class AgentReturn: + actions: List[ActionReturn] = field(default_factory=list) + response: str = '' + inner_steps: List = field(default_factory=list) + errmsg: Optional[str] = None + + +if is_module_exist('lmdeploy'): + from lmdeploy.serve.turbomind.chatbot import StatusCode + STATE_MAP = { + StatusCode.TRITON_STREAM_END: AgentStatusCode.END, + StatusCode.TRITON_SERVER_ERR: AgentStatusCode.SERVER_ERR, + StatusCode.TRITON_SESSION_CLOSED: AgentStatusCode.SESSION_CLOSED, + StatusCode.TRITON_STREAM_ING: AgentStatusCode.STREAM_ING, + StatusCode.TRITON_SESSION_OUT_OF_LIMIT: + AgentStatusCode.SESSION_OUT_OF_LIMIT, + StatusCode.TRITON_SESSION_INVALID_ARG: + AgentStatusCode.SESSION_INVALID_ARG, + StatusCode.TRITON_SESSION_READY: AgentStatusCode.SESSION_READY + } +else: + STATE_MAP = {} diff --git a/lagent/utils/__init__.py b/lagent/utils/__init__.py new file mode 100644 index 0000000..f681e70 --- /dev/null +++ b/lagent/utils/__init__.py @@ -0,0 +1,3 @@ +from .package import is_module_exist + +__all__ = ['is_module_exist'] diff --git a/lagent/utils/package.py b/lagent/utils/package.py new file mode 100644 index 0000000..3b092b4 --- /dev/null +++ b/lagent/utils/package.py @@ -0,0 +1,9 @@ +import importlib + + +def is_module_exist(module_name): + spec = importlib.util.find_spec(module_name) + if spec is None: + return False + else: + return True diff --git a/lagent/version.py b/lagent/version.py new file mode 100644 index 0000000..0fb3ffc --- /dev/null +++ b/lagent/version.py @@ -0,0 +1,35 @@ +# Copyright (c) OpenMMLab. All rights reserved. +__version__ = '0.1.0' + + +def parse_version_info(version_str: str, length: int = 4) -> tuple: + """Parse a version string into a tuple. + + Args: + version_str (str): The version string. + length (int): The maximum number of version levels. Default: 4. 
+ + Returns: + tuple[int | str]: The version info, e.g., "1.3.0" is parsed into + (1, 3, 0, 0, 0, 0), and "2.0.0rc1" is parsed into + (2, 0, 0, 0, 'rc', 1) (when length is set to 4). + """ + from packaging.version import parse + version = parse(version_str) + assert version.release, f'failed to parse version {version_str}' + release = list(version.release) + release = release[:length] + if len(release) < length: + release = release + [0] * (length - len(release)) + if version.is_prerelease: + release.extend(list(version.pre)) # type: ignore + elif version.is_postrelease: + release.extend(list(version.post)) # type: ignore + else: + release.extend([0, 0]) + return tuple(release) + + +version_info = tuple(int(x) for x in __version__.split('.')[:3]) + +__all__ = ['__version__', 'version_info', 'parse_version_info'] diff --git a/projects/.gitkeep b/projects/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..751e92e --- /dev/null +++ b/requirements.txt @@ -0,0 +1,2 @@ +-r requirements/optional.txt +-r requirements/runtime.txt diff --git a/requirements/optional.txt b/requirements/optional.txt new file mode 100644 index 0000000..4f492dd --- /dev/null +++ b/requirements/optional.txt @@ -0,0 +1,2 @@ +torch +transformers diff --git a/requirements/runtime.txt b/requirements/runtime.txt new file mode 100644 index 0000000..2e601be --- /dev/null +++ b/requirements/runtime.txt @@ -0,0 +1,5 @@ +distro +func_timeout +jsonschema +requests +tiktoken diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..1512bde --- /dev/null +++ b/setup.cfg @@ -0,0 +1,24 @@ +[isort] +line_length = 79 +multi_line_output = 0 +extra_standard_library = setuptools +known_first_party = mmdet +known_third_party = PIL,asynctest,cityscapesscripts,cv2,gather_models,matplotlib,mmcv,mmengine,numpy,onnx,onnxruntime,pycocotools,parameterized,pytest,pytorch_sphinx_theme,requests,scipy,seaborn,six,terminaltables,torch,ts,yaml +no_lines_before = STDLIB,LOCALFOLDER +default_section = THIRDPARTY + +[yapf] +BASED_ON_STYLE = pep8 +BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = true +SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true + +# ignore-words-list needs to be lowercase format. For example, if we want to +# ignore word "BA", then we need to append "ba" to ignore-words-list rather +# than "BA" +[codespell] +skip = *.ipynb +quiet-level = 3 +ignore-words-list = patten,nd,ty,mot,hist,formating,winn,gool,datas,wan,confids,TOOD,tood,ba,warmup,nam,DOTA,dota,conveyer + +[flake8] +per-file-ignores = ftdp/configs/*: F401,F403,F405 diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..c0defe9 --- /dev/null +++ b/setup.py @@ -0,0 +1,107 @@ +from pathlib import Path +from setuptools import find_packages, setup + + +def get_version(): + version_file = 'lagent/version.py' + with open(version_file, encoding='utf-8') as f: + exec(compile(f.read(), version_file, 'exec')) + return locals()['__version__'] + + +def parse_requirements(fname='requirements.txt', with_version=True): + """Parse the package dependencies listed in a requirements file but strip + specific version information. + + Args: + fname (str): Path to requirements file. + with_version (bool, default=False): If True, include version specs. + Returns: + info (list[str]): List of requirements items. 
+ CommandLine: + python -c "import setup; print(setup.parse_requirements())" + """ + import re + import sys + from os.path import exists + + require_fpath = fname + + def parse_line(line): + """Parse information from a line in a requirements text file.""" + if line.startswith('-r '): + # Allow specifying requirements in other files + target = line.split(' ')[1] + for info in parse_require_file(target): + yield info + else: + info = {'line': line} + if line.startswith('-e '): + info['package'] = line.split('#egg=')[1] + else: + # Remove versioning from the package + pat = '(' + '|'.join(['>=', '==', '>']) + ')' + parts = re.split(pat, line, maxsplit=1) + parts = [p.strip() for p in parts] + + info['package'] = parts[0] + if len(parts) > 1: + op, rest = parts[1:] + if ';' in rest: + # Handle platform specific dependencies + # http://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-platform-specific-dependencies + version, platform_deps = map(str.strip, + rest.split(';')) + info['platform_deps'] = platform_deps + else: + version = rest # NOQA + info['version'] = (op, version) + yield info + + def parse_require_file(fpath): + with open(fpath) as f: + for line in f.readlines(): + line = line.strip() + if line and not line.startswith('#'): + yield from parse_line(line) + + def gen_packages_items(): + if exists(require_fpath): + for info in parse_require_file(require_fpath): + parts = [info['package']] + if with_version and 'version' in info: + parts.extend(info['version']) + if not sys.version.startswith('3.4'): + # apparently package_deps are broken in 3.4 + platform_deps = info.get('platform_deps') + if platform_deps is not None: + parts.append(';' + platform_deps) + item = ''.join(parts) + yield item + + packages = list(gen_packages_items()) + return packages + + +if __name__ == '__main__': + with Path(Path(__file__).parent, + 'README.md').open(encoding='utf-8') as file: + long_description = file.read() + + setup( + name='lagent', + packages=find_packages(), + include_package_data=True, + version=get_version(), + license='Apache 2.0', + description='An open-source framework for language agent research.', + long_description=long_description, + long_description_content_type='text/markdown', + data_files=[('.', ['README.md'])], + keywords=['artificial general intelligence', 'agent', 'agi', 'llm'], + install_requires=parse_requirements('requirements/runtime.txt'), + extras_require={ + 'all': parse_requirements('requirements.txt'), + 'optional': parse_requirements('requirements/optional.txt'), + }, + ) diff --git a/tests/data/search.json b/tests/data/search.json new file mode 100644 index 0000000..0aaa790 --- /dev/null +++ b/tests/data/search.json @@ -0,0 +1,144 @@ +{ + "searchParameters": { + "q": "What is the capital of China?", + "gl": "us", + "hl": "en", + "num": 10, + "type": "search" + }, + "answerBox": { + "title": "China / Capital", + "answer": "Beijing" + }, + "organic": [ + { + "title": "Beijing - Wikipedia", + "link": "https://en.wikipedia.org/wiki/Beijing", + "snippet": "With over 21 million residents, Beijing is the world's most populous national capital city as well as China's second largest city after Shanghai.", + "sitelinks": [ + { + "title": "Etymology", + "link": "https://en.wikipedia.org/wiki/Beijing#Etymology" + }, + { + "title": "History", + "link": "https://en.wikipedia.org/wiki/Beijing#History" + }, + { + "title": "Geography", + "link": "https://en.wikipedia.org/wiki/Beijing#Geography" + } + ], + "position": 1 + }, + { + "title": "What is the Capital of China? 
- Mappr", + "link": "https://www.mappr.co/capital-cities/china/", + "snippet": "Beijing, also known as Peking, is the capital of China and one of the most populous cities in the world. It is the country's political, educational, ...", + "position": 2 + }, + { + "title": "Google Map of the City of Beijing, capital of P.R. China - Nations Online Project", + "link": "https://www.nationsonline.org/oneworld/map/google_map_Beijing.htm", + "snippet": "Google Earth: Searchable map/satellite view of Beijing, capital city of P.R. China. City Coordinates: 39°54′50″N 116°23′30″E, Bookmark/share this page ...", + "position": 3 + }, + { + "title": "Capital City of China - CountryReports.org", + "link": "https://www.countryreports.org/country/china/capital-city.htm", + "snippet": "Capital City, Beijing ; Capital location, 39 55 N, 116 23 E ; Capital - history, (Peking) Founded about 3,000 years ago on the site of a former Chinese capital, ...", + "position": 4 + }, + { + "title": "Capital of China - Beijing", + "link": "https://www.chinahighlights.com/beijing/capital-of-china.htm", + "snippet": "In Chinese, Beijing means 'Northern Capital'. Bei means 'north' and jing means 'capital'. In the history of China, Beijing's name has changed ...", + "date": "Dec 15, 2022", + "position": 5 + }, + { + "title": "Beijing is the capital of the People's Republic of China. It is the world's most populous capital city, with over 21 million residents within an... | By DL&D Consult | Facebook", + "link": "https://facebook.com/dldconsult/videos/beijing-capital-city-of-china/373001500555301/", + "snippet": "Beijing is an important world capital and global power ...", + "date": "Oct 19, 2020", + "attributes": { + "Duration": "2:58", + "Posted": "Oct 19, 2020" + }, + "imageUrl": "https://encrypted-tbn1.gstatic.com/images?q=tbn:ANd9GcQExx68yUr7xP_1wcRapEKlT5bxe4ptMa6WaLnwXdVAtAdloa7WeTIvCoJp", + "position": 6 + }, + { + "title": "What is the Capital of China? - WorldAtlas", + "link": "https://www.worldatlas.com/articles/what-is-the-capital-of-china.html", + "snippet": "The capital of China is Beijing.", + "date": "Jul 3, 2018", + "position": 7 + }, + { + "title": "A Chinese capital that's not Beijing - BBC Travel", + "link": "https://www.bbc.com/travel/article/20151008-a-chinese-capital-thats-not-beijing", + "snippet": "Beijing may be the capital of China today, but for many centuries the country was ruled from Nanjing, a historic city located on the shores ...", + "date": "Oct 13, 2015", + "position": 8 + }, + { + "title": "Beijing | Province, City, History, Map, & Facts - Britannica", + "link": "https://www.britannica.com/place/Beijing", + "snippet": "Beijing, city, province-level shi (municipality), and capital of the People's Republic of China. The city has been an integral part of China's history over ...", + "position": 9 + } + ], + "peopleAlsoAsk": [ + { + "question": "Does China have 2 capitals?", + "snippet": "There are traditionally four major historical capitals of China referred to as\nthe 'Four Great Ancient Capitals of China' (simplified Chinese: 中国四大古都;\ntraditional Chinese: 中國四大古都; pinyin: Zhōngguó Sì Dà Gǔ Dū). 
The four are\nBeijing, Nanjing, Luoyang and Xi\"an (Chang\"an).", + "title": "Historical capitals of China - Wikipedia", + "link": "https://en.wikipedia.org/wiki/Historical_capitals_of_China" + }, + { + "question": "What is the capital city of China USA?", + "snippet": "Capital City\nBeijing\nCapital - time difference\nUTC+8 (13 hours ahead of Washington, DC, during Standard Time) note; despite its\nsize, all of China falls within one time zone", + "title": "Capital City of China - CountryReports.org", + "link": "https://www.countryreports.org/country/china/capital-city.htm" + }, + { + "question": "Is Hong Kong is a part of China?", + "snippet": "Hong Kong (US: /ˈhɒŋkɒŋ/ or UK: /hɒŋˈkɒŋ/; Chinese: 香港, Cantonese: [hœ́ːŋ.kɔ̌ːŋ]\n( listen)), officially the Hong Kong Special Administrative Region of the\nPeople's Republic of China (abbr. Hong Kong SAR or HKSAR), is a city and a\nspecial administrative region in China.", + "title": "Hong Kong - Wikipedia", + "link": "https://en.wikipedia.org/wiki/Hong_Kong" + }, + { + "question": "Why China changed its capital?", + "snippet": "Once in charge, it wasn't uncommon for a new emperor to shift the imperial\ncapital in order to: Rebuild after a great loss, as in the Han era when Liu Bang\nmoved the capital from Xianyang to nearby Chang'an (now Xi'an), after the former\nwas destroyed during a rebellion.", + "title": "Why do they keep moving the Capital? - China Simplified", + "link": "https://www.chinasimplified.com/2014/09/29/why-do-they-keep-moving-the-capital/" + } + ], + "relatedSearches": [ + { + "query": "China map" + }, + { + "query": "Where is the capital of china on a map" + }, + { + "query": "Beijing population" + }, + { + "query": "Capital of Korea" + }, + { + "query": "Beijing pronunciation" + }, + { + "query": "What is the capital of India" + }, + { + "query": "What is the capital of Japan" + }, + { + "query": "What is the capital of Pakistan" + } + ] +} \ No newline at end of file diff --git a/tests/test_actions/test_builtin_actions.py b/tests/test_actions/test_builtin_actions.py new file mode 100644 index 0000000..670c793 --- /dev/null +++ b/tests/test_actions/test_builtin_actions.py @@ -0,0 +1,45 @@ +from unittest import TestCase + +from lagent.actions.builtin_actions import (FinishAction, InvalidAction, + NoAction) +from lagent.schema import ActionStatusCode + + +class TestFinishAction(TestCase): + + def test_call(self): + action = FinishAction() + response = 'finish' + action_return = action(response) + self.assertEqual(action_return.state, ActionStatusCode.SUCCESS) + self.assertDictEqual(action_return.result, dict(text='finish')) + + +class TestInvalidAction(TestCase): + + def test_call(self): + action = InvalidAction() + response = 'invalid' + action_return = action(response) + self.assertEqual(action_return.state, ActionStatusCode.API_ERROR) + self.assertEqual(action_return.errmsg, response) + + action = InvalidAction(err_msg='error') + action_return = action() + self.assertEqual(action_return.state, ActionStatusCode.API_ERROR) + self.assertEqual(action_return.errmsg, 'error') + + +class TestNoAction(TestCase): + + def test_call(self): + action = NoAction() + response = 'no' + action_return = action(response) + self.assertEqual(action_return.state, ActionStatusCode.API_ERROR) + self.assertEqual(action_return.errmsg, response) + + action = NoAction(err_msg='error') + action_return = action() + self.assertEqual(action_return.state, ActionStatusCode.API_ERROR) + self.assertEqual(action_return.errmsg, 'error') diff --git 
a/tests/test_actions/test_python_interpreter.py b/tests/test_actions/test_python_interpreter.py new file mode 100644 index 0000000..f698c77 --- /dev/null +++ b/tests/test_actions/test_python_interpreter.py @@ -0,0 +1,21 @@ +from unittest import TestCase + +from lagent.actions.python_interpreter import PythonInterpreter +from lagent.schema import ActionStatusCode + + +class TestPythonInterpreter(TestCase): + + def test_python_executor(self): + python_executor = PythonInterpreter() + tool_return = python_executor( + '```python\ndef solution():\n    return 1\n```') + self.assertEqual(tool_return.state, ActionStatusCode.SUCCESS) + self.assertDictEqual(tool_return.result, dict(text='1')) + + def test_timeout(self): + python_executor = PythonInterpreter(timeout=2) + tool_return = python_executor( + '```python\ndef solution():\n    while True:\n        pass\n```') + self.assertEqual(tool_return.state, ActionStatusCode.API_ERROR) + self.assertIn('FunctionTimedOut', tool_return.errmsg) diff --git a/tests/test_actions/test_serper_search.py b/tests/test_actions/test_serper_search.py new file mode 100644 index 0000000..19f1d56 --- /dev/null +++ b/tests/test_actions/test_serper_search.py @@ -0,0 +1,35 @@ +import json +from unittest import TestCase, mock + +from lagent.actions import SerperSearch +from lagent.schema import ActionStatusCode + + +class TestSerperSearch(TestCase): + + @mock.patch.object(SerperSearch, '_search') + def test_search_tool(self, mock_search_func): + mock_response = (200, json.load(open('tests/data/search.json'))) + mock_search_func.return_value = mock_response + search_tool = SerperSearch(api_key='abc') + tool_return = search_tool.run("What's the capital of China?") + self.assertEqual(tool_return.state, ActionStatusCode.SUCCESS) + self.assertDictEqual(tool_return.result, dict(text="['Beijing']")) + + @mock.patch.object(SerperSearch, '_search') + def test_api_error(self, mock_search_func): + mock_response = (403, {'message': 'bad requests'}) + mock_search_func.return_value = mock_response + search_tool = SerperSearch(api_key='abc') + tool_return = search_tool.run("What's the capital of China?") + self.assertEqual(tool_return.state, ActionStatusCode.API_ERROR) + self.assertEqual(tool_return.errmsg, str(403)) + + @mock.patch.object(SerperSearch, '_search') + def test_http_error(self, mock_search_func): + mock_response = (-1, 'HTTPSConnectionPool') + mock_search_func.return_value = mock_response + search_tool = SerperSearch(api_key='abc') + tool_return = search_tool.run("What's the capital of China?") + self.assertEqual(tool_return.state, ActionStatusCode.HTTP_ERROR) + self.assertEqual(tool_return.errmsg, 'HTTPSConnectionPool') diff --git a/tests/test_agents/test_rewoo.py b/tests/test_agents/test_rewoo.py new file mode 100644 index 0000000..52fa1ed --- /dev/null +++ b/tests/test_agents/test_rewoo.py @@ -0,0 +1,87 @@ +from unittest import TestCase, mock + +from lagent.actions import ActionExecutor +from lagent.actions.llm_qa import LLMQA +from lagent.actions.serper_search import SerperSearch +from lagent.agents.rewoo import ReWOO, ReWOOProtocol +from lagent.schema import ActionReturn, ActionStatusCode + + +class TestReWOO(TestCase): + + @mock.patch.object(SerperSearch, 'run') + @mock.patch.object(LLMQA, 'run') + @mock.patch.object(ReWOOProtocol, 'parse_worker') + def test_normal_chat(self, mock_parse_worker_func, mock_qa_func, + mock_search_func): + mock_model = mock.Mock() + mock_model.generate_from_template.return_value = 'LLM response' + + mock_parse_worker_func.return_value = (['Thought1', 
'Thought2' + ], ['LLMQA', 'SerperSearch'], + ['abc', 'abc']) + + search_return = ActionReturn(args=None) + search_return.state = ActionStatusCode.SUCCESS + search_return.result = dict(text='search_return') + mock_search_func.return_value = search_return + + qa_return = ActionReturn(args=None) + qa_return.state = ActionStatusCode.SUCCESS + qa_return.result = dict(text='qa_return') + mock_qa_func.return_value = qa_return + + chatbot = ReWOO( + llm=mock_model, + action_executor=ActionExecutor(actions=[ + LLMQA(mock_model), + SerperSearch(api_key=''), + ])) + agent_return = chatbot.chat('abc') + self.assertEqual(agent_return.response, 'LLM response') + + def test_parse_worker(self): + prompt = ReWOOProtocol() + message = """ + Plan: a. + #E1 = tool1["a"] + #E2 = tool2["b"] + """ + try: + thoughts, actions, actions_input = prompt.parse_worker(message) + except Exception as e: + self.assertEqual( + 'Each Plan should only correspond to only ONE action', str(e)) + else: + self.assertFalse( + True, 'it should raise exception when the format is incorrect') + + message = """ + Plan: a. + #E1 = tool1("a") + Plan: b. + #E2 = tool2["b"] + """ + try: + thoughts, actions, actions_input = prompt.parse_worker(message) + except Exception as e: + self.assertIsInstance(e, BaseException) + else: + self.assertFalse( + True, 'it should raise exception when the format is incorrect') + + message = """ + Plan: a. + #E1 = tool1["a"] + Plan: b. + #E2 = tool2["b"] + """ + try: + thoughts, actions, actions_input = prompt.parse_worker(message) + except Exception: + self.assertFalse( + True, + 'it should not raise exception when the format is correct') + self.assertEqual(thoughts, ['a.', 'b.']) + self.assertEqual(actions, ['tool1', 'tool2']) + self.assertEqual(actions_input, ['"a"', '"b"'])
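For quick offline verification outside the unittest classes, the bundled fixture can drive the same mocking pattern directly. The following is a minimal sketch, not part of the patch, assuming the (status_code, payload) tuple contract that the tests above use for SerperSearch._search, a placeholder API key, and execution from the repository root so the fixture path resolves:

import json
from unittest import mock

from lagent.actions import SerperSearch
from lagent.schema import ActionStatusCode

# Load the bundled search fixture (tests/data/search.json).
with open('tests/data/search.json') as f:
    payload = json.load(f)

# Patch the network call with a (status_code, payload) tuple, mirroring the tests.
with mock.patch.object(SerperSearch, '_search', return_value=(200, payload)):
    tool = SerperSearch(api_key='dummy-key')  # placeholder key, never sent
    ret = tool.run("What's the capital of China?")
    assert ret.state == ActionStatusCode.SUCCESS
    print(ret.result)  # expected to mention Beijing, per the fixture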