Compare commits

86 Commits

| SHA1 |
| --- |
| 581d9fb898 |
| 26672731e9 |
| 466d7aa24d |
| 5f55e12736 |
| e16a6bfc3a |
| 605a921878 |
| 4fd014bebf |
| 432ffaae8a |
| a64aa599ce |
| 662011783e |
| 3cf20f5011 |
| 7b71988d09 |
| 90ef5215b6 |
| 6a5447663a |
| a2c23ef9dd |
| 3be9ec042c |
| aa5a357a34 |
| 5650a75f3e |
| eea6e1cb56 |
| 42c6d265e1 |
| e20a768066 |
| 60244a253a |
| 990828ceb2 |
| b2e6d0ee7a |
| fda869162b |
| baeed6ed88 |
| 94959c16da |
| d61dc03079 |
| 190fb731be |
| f887b423fa |
| 35331a5016 |
| 5dc6fb5b3d |
| 7489767651 |
| 83437b081f |
| 3d9992182e |
| ca1ab4b5ec |
| f974a7bd39 |
| d943becaec |
| f90010c867 |
| ae3c7c37a6 |
| 559275daf4 |
| c7c46785bb |
| c9973b9558 |
| 94ba3a1a75 |
| 76a46c9a8c |
| d4a71f40b5 |
| 5581fad8ce |
| 95b68c821a |
| 85b91cc652 |
| c89729620f |
| 6b287605bb |
| 987618c978 |
| 8ddde9ba5e |
| 026eff8704 |
| 511b038890 |
| 830b9609c3 |
| 50641c27c2 |
| b7ca22adcd |
| 2ecfb9838b |
| 060bc2c67a |
| 1cfe5ac099 |
| 14b549fa6e |
| e08ad02dce |
| 4a6f92a77f |
| 506164521b |
| 5a7a51e38b |
| 98c83f620e |
| f8383c9938 |
| 540db86c91 |
| a41186b9e0 |
| d908d1b47a |
| 276154c212 |
| c082f6e705 |
| f20ddf3f54 |
| f05069226e |
| d9ec1908b5 |
| ce0081ccb7 |
| e5e74fe4cf |
| 4d3d894226 |
| 0f13661bd3 |
| aa504769b3 |
| 3ba9abf859 |
| 4ba6166544 |
| 3880369ce1 |
| 241cc57cd2 |
| 848f8ffc59 |
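A sketch of reproducing this listing locally with plain git, assuming the repository is cloned and that the list runs from `581d9fb898` at one end of the range to `848f8ffc59` at the other (which end is the newer commit is not shown on this page):

```bash
# Hypothetical reconstruction of the commit range shown above; swap the
# endpoints if the compare view lists newest-first. The ^ assumes the
# boundary commit has a parent.
git log --oneline 848f8ffc59^..581d9fb898
```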
**.github/workflows/publish-to-pypi.yml** (new file, 26 lines)

```yaml
name: deploy

on: push

concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

jobs:
  build-n-publish:
    runs-on: ubuntu-latest
    if: startsWith(github.event.ref, 'refs/tags')
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python 3.7
        uses: actions/setup-python@v1
        with:
          python-version: 3.7
      - name: Build lagent
        run: |
          pip install wheel
          python setup.py sdist bdist_wheel
      - name: Publish distribution to PyPI
        run: |
          pip install twine
          twine upload dist/* -u __token__ -p ${{ secrets.pypi_password }}
```
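Because of the `if: startsWith(github.event.ref, 'refs/tags')` guard, the workflow runs on every push but the build-and-publish job executes only when the pushed ref is a tag. A minimal sketch of cutting a release under this workflow; the version number below is a hypothetical example, not taken from this diff:

```bash
# Pushing a tag ref triggers the build-n-publish job above;
# v0.1.2 is an illustrative tag name.
git tag v0.1.2
git push origin v0.1.2
```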
**.gitignore** (1 line added)

```diff
@@ -160,3 +160,4 @@ cython_debug/
 #.idea/
 .vscode/
 docs/*/_build/
+tmp_dir/
```
**.pre-commit-config.yaml**

```diff
@@ -1,11 +1,11 @@
 exclude: ^(tests/data|scripts|ftdp/protocols|ftdp/template_configs|ftdp/tool_dicts)/
 repos:
   - repo: https://github.com/PyCQA/flake8
-    rev: 5.0.4
+    rev: 7.0.0
     hooks:
       - id: flake8
   - repo: https://github.com/PyCQA/isort
-    rev: 5.11.5
+    rev: 5.13.2
     hooks:
       - id: isort
   - repo: https://github.com/pre-commit/mirrors-yapf
@@ -13,7 +13,7 @@ repos:
     hooks:
       - id: yapf
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.3.0
+    rev: v4.5.0
     hooks:
       - id: trailing-whitespace
       - id: check-yaml
@@ -26,7 +26,7 @@ repos:
       - id: mixed-line-ending
        args: ["--fix=lf"]
   - repo: https://github.com/executablebooks/mdformat
-    rev: 0.7.9
+    rev: 0.7.17
     hooks:
       - id: mdformat
        args: ["--number"]
@@ -35,16 +35,11 @@ repos:
          - mdformat_frontmatter
          - linkify-it-py
   - repo: https://github.com/codespell-project/codespell
-    rev: v2.2.1
+    rev: v2.2.6
     hooks:
       - id: codespell
-  - repo: https://github.com/myint/docformatter
-    rev: v1.3.1
-    hooks:
-      - id: docformatter
-        args: ["--in-place", "--wrap-descriptions", "79"]
   - repo: https://github.com/asottile/pyupgrade
-    rev: v3.0.0
+    rev: v3.15.0
     hooks:
       - id: pyupgrade
        args: ["--py36-plus"]
```
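After bumping hook revisions like this, the updated hooks can be re-run across the whole tree to surface any new findings; this uses the standard `pre-commit` CLI and is not specific to this repository:

```bash
# Install the pre-commit tool, then re-check every file with the
# updated hook versions from .pre-commit-config.yaml.
pip install pre-commit
pre-commit run --all-files
```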
**.readthedocs.yaml** (rewritten: 15 lines added, 8 removed)

```diff
@@ -0,0 +1,15 @@
+version: 2
+
+formats: all
+
+build:
+  os: ubuntu-22.04
+  tools:
+    python: "3.10"
+
+python:
+  install:
+    - requirements: requirements/docs.txt
+
+sphinx:
+  configuration: docs/en/conf.py
@@ -1,8 +0,0 @@
-version: 2
-
-formats: all
-
-python:
-  version: 3.7
-  install:
-    - requirements: requirements/docs.txt
```
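The new Read the Docs configuration points Sphinx at `docs/en/conf.py` and installs `requirements/docs.txt`. A local build can be approximated with the standard Sphinx CLI; the output directory below is an arbitrary choice, not part of this diff:

```bash
pip install -r requirements/docs.txt
# docs/en/_build/html is a hypothetical output path.
sphinx-build docs/en docs/en/_build/html
```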
**README.md** (116 lines changed)

````diff
@@ -1,79 +1,95 @@
-# Lagent: Large Language Model as Agent
-
-English | [简体中文](README_zh-CN.md)
-
-## Introduction
-
-Lagent is an open source LLM agent framework, which enables people to efficiently turn a large language model to agent. It also provides some typical tools to enlighten the ablility of LLM. The overview of our framework is shown below:
-
-### Major Features
-
-- **Support multiple agent frameworks out of box.** We implement ReAct, AutoGPT and ReWOO, which enables the agents to drive LLMs for multiple trails of reasoning and tool utilization.
-
-- **Extremely simple and easy to extend.** The framework is quite simple with clear project structure. With only 20 lines of code, you are able to construct your own agent. It also supports three typical tools: Python interpreter, API call, and google search.
-
-- **Support various large language models.** We support different LLMs, including API-based (GPT3.5/4) and HuggingFace-based (LLaMa2, InternLM) models.
+<div id="top"></div>
+<div align="center">
+  <img src="docs/imgs/lagent_logo.png" width="450"/>
+
+[](https://lagent.readthedocs.io/en/latest/)
+[](https://pypi.org/project/lagent)
+[](https://github.com/InternLM/lagent/tree/main/LICENSE)
+[](https://github.com/InternLM/lagent/issues)
+[](https://github.com/InternLM/lagent/issues)
+
+English | [简体中文](README_zh-CN.md) | [日本語](README_ja_JP.md) | [हिंदी](README_in_HIN.md) | [বাংলা](README_in_beng.md) | [한국어](README_KR_Kr.md)
+
+</div>
+
+<p align="center">
+    👋 join us on <a href="https://twitter.com/intern_lm" target="_blank">𝕏 (Twitter)</a>, <a href="https://discord.gg/xa29JuW87d" target="_blank">Discord</a> and <a href="https://r.vansin.top/?r=internwx" target="_blank">WeChat</a>
+</p>
+
+<div align="center">
+
+https://github.com/InternLM/lagent/assets/24622904/3242f9bf-32d2-4907-8815-e16a75a4ac0e
+
+</div>
 
 ## Getting Started
 
-Please see [Overview](docs/en/get_started/overview.md) for the general introduction of Lagent. Meanwhile, we provide extremely simple code for quick start. You may refer to [examples](examples/) for more details.
+Please see the [overview](docs/en/get_started/overview.md) for the general introduction of Lagent. Meanwhile, we provide extremely simple code for quick start. You may refer to [examples](examples/) for more details.
 
 ### Installation
 
-```
-git clone https://github.com/InternLM/lagent.git
-cd lagent
-pip install -e .
-```
+Install with pip (Recommended).
+
+```bash
+pip install lagent
+```
 
-### Run a ReWOO agent with GPT3.5
+### Run a Web Demo
 
-```python
-from lagent.agents import ReWOO
-from lagent.actions import ActionExecutor, GoogleSearch, LLMQA
-from lagent.llms import GPTAPI
-
-llm = GPTAPI(model_type='gpt-3.5-turbo', key=['OPENAI_API_KEY'])
-search_tool = GoogleSearch(api_key='SERPER_API_KEY')
-llmqa_tool = LLMQA(llm)
-
-chatbot = ReWOO(
-    llm=llm,
-    action_executor=ActionExecutor(
-        actions=[search_tool, llmqa_tool]),
-)
-
-response = chatbot.chat('What profession does Nicholas Ray and Elia Kazan have in common')
-print(response.response)
->>> Film director.
-```
+You need to install Streamlit first.
+
+```bash
+# pip install streamlit
+streamlit run examples/internlm2_agent_web_demo.py
+```
 
-### Run a ReAct agent with InternLM
+## What's Lagent?
 
-NOTE: If you want to run a HuggingFace model, please run `pip install -e .[all]` first.
+Lagent is a lightweight open-source framework that allows users to efficiently build large language model(LLM)-based agents. It also provides some typical tools to augment LLM. The overview of our framework is shown below:
 
-```python
-from lagent.agents import ReAct
-from lagent.actions import ActionExecutor, GoogleSearch, PythonInterpreter
-from lagent.llms import HFTransformer
-
-llm = HFTransformer('internlm/internlm-7b-chat-v1.1')
-search_tool = GoogleSearch()
-python_interpreter = PythonInterpreter()
-
-chatbot = ReAct(
-    llm=llm,
-    action_executor=ActionExecutor(
-        actions=[search_tool, python_interpreter]),
-)
-
-response = chatbot.chat('若$z=-1+\sqrt{3}i$,则$\frac{z}{{z\overline{z}-1}}=\left(\ \ \right)$ (A) $-1+\sqrt{3}i$ (B) $-1-\sqrt{3}i$ (C) $-\frac{1}{3}+\frac{{\sqrt{3}}}{3}i$ (D) $-\frac{1}{3}-\frac{{\sqrt{3}}}{3}i$')
-print(response.response)
->>> 根据已有的信息，可以求得$z=-1+\\sqrt{3}i$，然后代入计算，得到结果为$-\\frac{1}{3}+\\frac{{\\sqrt{3}}}{3}i$。因此，答案是(C)。
-```
+## Major Features
+
+- Stream Output: Provides the `stream_chat` interface for streaming output, allowing cool streaming demos right at your local setup.
+- Interfacing is unified, with a comprehensive design upgrade for enhanced extensibility, including:
+  - Model: Whether it's the OpenAI API, Transformers, or LMDeploy inference acceleration framework, you can seamlessly switch between models.
+  - Action: Simple inheritance and decoration allow you to create your own personal toolkit, adaptable to both InternLM and GPT.
+  - Agent: Consistent with the Model's input interface, the transformation from model to intelligent agent only takes one step, facilitating the exploration and implementation of various agents.
+- Documentation has been thoroughly upgraded with full API documentation coverage.
+
+## 💻Tech Stack
+
+<p>
+  <a href="">
+    <img src="https://img.shields.io/badge/Python-007ACC?style=for-the-badge&logo=python&logoColor=yellow" alt="python" />
+  </a>
+
+### All Thanks To Our Contributors:
+
+<a href="https://github.com/InternLM/lagent/graphs/contributors">
+  <img src="https://contrib.rocks/image?repo=InternLM/lagent" />
+</a>
 
 ## Citation
 
 If you find this project useful in your research, please consider cite:
 
 ```latex
 @misc{lagent2023,
     title={{Lagent: InternLM} a lightweight open-source framework that allows users to efficiently build large language model(LLM)-based agents},
     author={Lagent Developer Team},
     howpublished = {\url{https://github.com/InternLM/lagent}},
     year={2023}
 }
 ```
 
 ## License
 
 This project is released under the [Apache 2.0 license](LICENSE).
+
+<p align="right"><a href="#top">🔼 Back to top</a></p>
````
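The new feature list above mentions a `stream_chat` interface for streaming output. As a rough sketch only; the call pattern and yielded values here are assumptions, not confirmed by this diff:

```python
# Hypothetical streaming loop; consult the lagent API reference for the
# real signature of stream_chat and the structure of what it yields.
from lagent.llms import HFTransformer

llm = HFTransformer(path='internlm/internlm2-chat-7b')
for chunk in llm.stream_chat([{'role': 'user', 'content': 'Hello'}]):
    print(chunk, end='', flush=True)
```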
**README_KR_Kr.md** (new file, 69 lines)

````markdown
<div id="top"></div>
<div align="center">
  <img src="docs/imgs/lagent_logo.png" width="450"/>

[](https://lagent.readthedocs.io/en/latest/)
[](https://pypi.org/project/lagent)
[](https://github.com/InternLM/lagent/tree/main/LICENSE)
[](https://github.com/InternLM/lagent/issues)
[](https://github.com/InternLM/lagent/issues)

English | [简体中文](README_zh-CN.md) | [日本語](README_ja_JP.md) | [हिंदी](README_in_HIN.md) | [বাংলা](README_in_beng.md)

</div>

<p align="center">
    👋 join us on <a href="https://twitter.com/intern_lm" target="_blank">Twitter</a>, <a href="https://discord.gg/xa29JuW87d" target="_blank">Discord</a> and <a href="https://r.vansin.top/?r=internwx" target="_blank">WeChat</a>
</p>

<div align="center">

https://github.com/InternLM/lagent/assets/24622904/cb851b31-6932-422e-a776-b1aa68f2a64f

</div>

## Getting Started

Please see the [overview](docs/en/get_started/overview.md) for a general introduction to Lagent. Meanwhile, we provide extremely simple code for a quick start. You may refer to [examples](examples/) for more details.

### Installation

Install with pip (Recommended).

```bash
pip install lagent
```

### Run a Web Demo

You need to install Streamlit first.

```bash
# pip install streamlit
streamlit run examples/internlm2_agent_web_demo.py
```

## Introduction

Lagent is a lightweight open-source framework that allows users to efficiently build large language model (LLM)-based agents. It also provides some typical tools to augment LLM. The overview of our framework is shown below:

## Citation

If you find this project useful in your research, please consider citing:

```latex
@misc{lagent2023,
    title={{Lagent: InternLM} a lightweight open-source framework that allows users to efficiently build large language model(LLM)-based agents},
    author={Lagent Developer Team},
    howpublished = {\url{https://github.com/InternLM/lagent}},
    year={2023}
}
```

## License

This project is released under the [Apache 2.0 license](LICENSE).

<p align="right"><a href="#top">🔼 Back to top</a></p>
````
**README_in_HIN.md** (new file, 51 lines)

````markdown
<div id="top"></div>
<div align="center">
  <img src="docs/imgs/lagent_logo.png" width="450"/>

[](https://lagent.readthedocs.io/en/latest/)
[](https://pypi.org/project/lagent)
[](https://github.com/InternLM/lagent/tree/main/LICENSE)
[](https://github.com/InternLM/lagent/issues)
[](https://github.com/InternLM/lagent/issues)

English | [简体中文](README_zh-CN.md) | [日本語](README_ja_JP.md) | [हिंदी](README_in_HIN.md) | [বাংলা](README_in_beng.md) | [한국어](README_KR_Kr.md)

</div>

<div align="center">

https://github.com/InternLM/lagent/assets/24622904/cb851b31-6932-422e-a776-b1aa68f2a64f

</div>

## Getting Started

Please see the [overview](docs/in/get_started/overview.md) for a general introduction to Lagent. Meanwhile, we provide extremely simple code for a quick start. You may refer to [examples](examples/) for more details.

### Installation

Install with pip (Recommended).

```bash
pip install lagent
```

### Run a Web Demo

```bash
# You need to install streamlit first
# pip install streamlit
streamlit run examples/internlm2_agent_web_demo.py
```

## Introduction

Lagent is a lightweight open-source framework that allows users to efficiently build large language model (LLM)-based agents. It also provides some typical tools to augment LLM. The overview of our framework is shown below:

## License

This project is released under the [Apache 2.0 license](LICENSE).

<p align="right"><a href="#top">🔼 Back to top</a></p>
````
**README_in_beng.md** (new file, 55 lines)

````markdown
<div id="top"></div>
<div align="center">
  <img src="docs/imgs/lagent_logo.png" width="450"/>

[](https://lagent.readthedocs.io/en/latest/)
[](https://pypi.org/project/lagent)
[](https://github.com/InternLM/lagent/tree/main/LICENSE)
[](https://github.com/InternLM/lagent/issues)
[](https://github.com/InternLM/lagent/issues)

English | [简体中文](README_zh-CN.md) | [日本語](README_ja_JP.md) | [हिंदी](README_in_HIN.md) | [বাংলা](README_in_beng.md) | [한국어](README_KR_Kr.md)

</div>

<p align="center">
    👋 join us on <a href="https://twitter.com/intern_lm" target="_blank">Twitter</a>, <a href="https://discord.gg/xa29JuW87d" target="_blank">Discord</a> and <a href="https://r.vansin.top/?r=internwx" target="_blank">WeChat</a>
</p>

<div align="center">

https://github.com/InternLM/lagent/assets/24622904/cb851b31-6932-422e-a776-b1aa68f2a64f

</div>

## Getting Started

Please see the [overview](docs/en/get_started/overview.md) for a general introduction to Lagent.

## Installation

Install with pip (Recommended).

```bash
pip install lagent
```

### Run a Web Demo

```bash
# You need to install streamlit first
# pip install streamlit
streamlit run examples/internlm2_agent_web_demo.py
```

## Introduction

Lagent is a lightweight open-source framework that allows users to efficiently build large language model (LLM)-based agents. It also provides some typical tools to augment LLM. The overview of our framework is shown below:

## License

This project is released under the [Apache 2.0 license](LICENSE).

<p align="right"><a href="#top">🔼 Back to top</a></p>
````
**README_ja_JP.md** (new file, 56 lines)

````markdown
<div id="top"></div>
<div align="center">
  <img src="docs/imgs/lagent_logo.png" width="450"/>

[](https://lagent.readthedocs.io/en/latest/)
[](https://pypi.org/project/lagent)
[](https://github.com/InternLM/lagent/tree/main/LICENSE)
[](https://github.com/InternLM/lagent/issues)
[](https://github.com/InternLM/lagent/issues)

English | [简体中文](README_zh-CN.md) | [日本語](README_ja_JP.md) | [हिंदी](README_in_HIN.md) | [বাংলা](README_in_beng.md) | [한국어](README_KR_Kr.md)

</div>

<p align="center">
    👋 join us on <a href="https://twitter.com/intern_lm" target="_blank">Twitter</a>, <a href="https://discord.gg/xa29JuW87d" target="_blank">Discord</a> and <a href="https://r.vansin.top/?r=internwx" target="_blank">WeChat</a>
</p>

<div align="center">

https://github.com/InternLM/lagent/assets/24622904/cb851b31-6932-422e-a776-b1aa68f2a64f

</div>

## Getting Started

Please see the [overview](docs/ja/get_started/overview.md) for a general introduction to Lagent. We also provide extremely simple code for a quick start. Please refer to [examples](examples/) for details.

### Installation

Install with pip (Recommended).

```bash
pip install lagent
```

### Run a Web Demo

You need to install Streamlit first.

```bash
# pip install streamlit
streamlit run examples/internlm2_agent_web_demo.py
```

## Introduction

Lagent is a lightweight open-source framework that allows users to efficiently build large language model (LLM)-based agents. It also provides some typical tools to augment LLM. The overview of our framework is shown below:

## License

This project is released under the [Apache 2.0 license](LICENSE).

<p align="right"><a href="#top">🔼 Back to top</a></p>
````
**README_zh-CN.md** (111 lines changed)

````diff
@@ -1,79 +1,80 @@
-# Lagent: Large Language Model as Agent
-
-[English](README.md) | 简体中文
-
-## Introduction
-
-Lagent is an open-source LLM agent framework that lets users quickly turn a large language model into multiple types of agents, and provides some typical tools to empower LLMs. The overall framework is shown below:
-
-### Major Features
-
-- **Multiple types of agents implemented.** We support the classic ReAct, AutoGPT and ReWoo agents, which can drive LLMs through multiple rounds of reasoning and tool invocation.
-
-- **Simple and easy to extend.** The code structure is clear and simple; with fewer than 20 lines of code you can create your own agent. We also support three typical tools: the Python interpreter, API calls, and search.
-
-- **Flexible support for multiple LLMs.** We support a variety of large language models, from open-source models such as InternLM and Llama-2 to API-based closed-source models such as GPT-4/3.5.
+<div id="top"></div>
+<div align="center">
+  <img src="docs/imgs/lagent_logo.png" width="450"/>
+
+[](https://lagent.readthedocs.io/en/latest/)
+[](https://pypi.org/project/lagent)
+[](https://github.com/InternLM/lagent/tree/main/LICENSE)
+[](https://github.com/InternLM/lagent/issues)
+[](https://github.com/InternLM/lagent/issues)
+
+English | [简体中文](README_zh-CN.md) | [日本語](README_ja_JP.md) | [हिंदी](README_in_HIN.md) | [বাংলা](README_in_beng.md) | [한국어](README_KR_Kr.md)
+
+</div>
+
+<p align="center">
+    👋 join us on <a href="https://twitter.com/intern_lm" target="_blank">Twitter</a>, <a href="https://discord.gg/xa29JuW87d" target="_blank">Discord</a> and <a href="https://r.vansin.top/?r=internwx" target="_blank">WeChat</a>
+</p>
+
+<div align="center">
+
+https://github.com/InternLM/lagent/assets/24622904/cb851b31-6932-422e-a776-b1aa68f2a64f
+
+</div>
 
 ## Tutorials
 
-Please read the [overview](docs/en/get_started/overview.md) for a first look at Lagent. We also provide two very simple pieces of code for a quick start. You may also read [examples](examples/) for more references.
+Please read the [overview](docs/en/get_started/overview.md) for a first look at the Lagent project. We also provide two very simple samples for a quick start. You may also read the [example code](examples/) for more references.
 
 ### Installation
 
-```
-git clone https://github.com/InternLM/lagent.git
-cd lagent
-pip install -e .
-```
+Install with pip (Recommended).
+
+```bash
+pip install lagent
+```
 
-### Build a ReWOO agent with GPT-3.5
+### Run a Web Demo of an Agent
 
-```python
-from lagent.agents import ReWOO
-from lagent.actions import ActionExecutor, GoogleSearch, LLMQA
-from lagent.llms import GPTAPI
-
-llm = GPTAPI(model_type='gpt-3.5-turbo', key='OPENAI_API_KEY')
-search_tool = GoogleSearch(api_key='SERPER_API_KEY')
-llmqa_tool = LLMQA(llm)
-
-chatbot = ReWOO(
-    llm=llm,
-    action_executor=ActionExecutor(
-        actions=[search_tool, llmqa_tool]),
-)
-
-response = chatbot.chat('What profession does Nicholas Ray and Elia Kazan have in common')
-print(response.response)
->>> Film director.
-```
+You may need to install the Streamlit package first.
+
+```bash
+# pip install streamlit
+streamlit run examples/internlm2_agent_web_demo.py
+```
 
-### Build a ReAct agent with InternLM
+## Introduction
 
-Note: if you want to launch a HuggingFace model, please run `pip install -e .[all]` first.
+Lagent is a lightweight, open-source LLM-based agent framework that lets users quickly turn a large language model into multiple types of agents, and provides some typical tools to empower LLMs. The overall framework is shown below:
 
-```python
-from lagent.agents import ReAct
-from lagent.actions import ActionExecutor, GoogleSearch, PythonInterpreter
-from lagent.llms import HFTransformer
-
-llm = HFTransformer('internlm/internlm-7b-chat-v1.1')
-search_tool = GoogleSearch(api_key='SERPER_API_KEY')
-python_interpreter = PythonInterpreter()
-
-chatbot = ReAct(
-    llm=llm,
-    action_executor=ActionExecutor(
-        actions=[search_tool, python_interpreter]),
-)
-
-response = chatbot.chat('若$z=-1+\sqrt{3}i$,则$\frac{z}{{z\overline{z}-1}}=\left(\ \ \right)$ (A) $-1+\sqrt{3}i$ (B) $-1-\sqrt{3}i$ (C) $-\frac{1}{3}+\frac{{\sqrt{3}}}{3}i$ (D) $-\frac{1}{3}-\frac{{\sqrt{3}}}{3}i$')
-print(response.response)
->>> 根据已有的信息，可以求得$z=-1+\\sqrt{3}i$，然后代入计算，得到结果为$-\\frac{1}{3}+\\frac{{\\sqrt{3}}}{3}i$。因此，答案是(C)。
-```
+## Features
+
+- Streaming output: the `stream_chat` interface enables streaming output, so you can run cool streaming demos right on your local machine.
+- Unified interfaces, with a comprehensive design upgrade for better extensibility, including:
+  - Model: whether it is the OpenAI API, Transformers, or the LMDeploy inference acceleration framework, you can switch between models with ease;
+  - Action: simple inheritance and decoration let you build your own personal toolbox, compatible with both InternLM and GPT;
+  - Agent: consistent with the Model input interface, turning a model into an agent takes only one step, making it easy to explore and implement all kinds of agents;
+- Documentation fully upgraded, with complete API documentation coverage.
 
 ## Citation
 
 If you find this project helpful for your research, please cite Lagent with the following bibtex:
 
 ```latex
 @misc{lagent2023,
     title={{Lagent: InternLM} a lightweight open-source framework that allows users to efficiently build large language model(LLM)-based agents},
     author={Lagent Developer Team},
     howpublished = {\url{https://github.com/InternLM/lagent}},
     year={2023}
 }
 ```
 
 ## License
 
 This project is released under the [Apache 2.0 license](LICENSE).
+
+<p align="right"><a href="#top">🔼 Back to top</a></p>
````
**docs/en/_static/css/readthedocs.css**

```diff
@@ -1,5 +1,5 @@
 .header-logo {
-  background-image: url("../images/robot.png");
+  background-image: url("../images/lagent_icon.png");
   background-size: 40px 40px;
   height: 40px;
   width: 40px;
```
**docs/en/_static/images/lagent_icon.png**: new binary file (11 KiB, not shown)
**docs/en/_templates/autoapi/index.rst** (new file, 14 lines)

```rst
API Reference
=============

This page contains auto-generated API reference documentation.

.. toctree::
   :titlesonly:
   :maxdepth: 3

   {% for page in pages %}
   {% if page.top_level_object and page.display %}
   {{ page.include_path }}
   {% endif %}
   {% endfor %}
```
**docs/en/_templates/autoapi/python/module.rst** (new file, 112 lines)

```rst
{% if not obj.display %}
:orphan:

{% endif %}
:py:mod:`{{ obj.name if obj.name.count(".") <= 1 else obj.short_name }}`
=========={{ "=" * (obj.name|length if obj.name.count(".") <= 1 else obj.short_name|length) }}

.. py:module:: {{ obj.name }}

{% if obj.docstring %}
.. autoapi-nested-parse::

   {{ obj.docstring|indent(3) }}

{% endif %}

{% block subpackages %}
{% set visible_subpackages = obj.subpackages|selectattr("display")|list %}
{% if visible_subpackages %}
Subpackages
-----------
.. toctree::
   :titlesonly:
   :maxdepth: 3

{% for subpackage in visible_subpackages %}
   {{ subpackage.short_name }}/index.rst
{% endfor %}


{% endif %}
{% endblock %}
{% block submodules %}
{% set visible_submodules = obj.submodules|selectattr("display")|list %}
{% if visible_submodules %}
Submodules
----------
.. toctree::
   :titlesonly:
   :maxdepth: 1

{% for submodule in visible_submodules %}
   {{ submodule.short_name }}/index.rst
{% endfor %}


{% endif %}
{% endblock %}
{% block content %}
{% if obj.type is equalto("package") %}
{% set visible_children = obj.children|selectattr("display")|list %}
{% else %}
{% set visible_children = obj.children|selectattr("display")|rejectattr("imported")|list %}
{% endif %}
{% if visible_children %}
{{ obj.type|title }} Contents
{{ "-" * obj.type|length }}---------

{% set visible_classes = visible_children|selectattr("type", "equalto", "class")|list %}
{% set visible_functions = visible_children|selectattr("type", "equalto", "function")|list %}
{% set visible_attributes = visible_children|selectattr("type", "equalto", "data")|list %}
{% if "show-module-summary" in autoapi_options and (visible_classes or visible_functions) %}
{% block classes scoped %}
{% if visible_classes %}
Classes
~~~~~~~

.. autoapisummary::

{% for klass in visible_classes %}
   {{ klass.id }}
{% endfor %}


{% endif %}
{% endblock %}

{% block functions scoped %}
{% if visible_functions %}
Functions
~~~~~~~~~

.. autoapisummary::

{% for function in visible_functions %}
   {{ function.id }}
{% endfor %}


{% endif %}
{% endblock %}

{% block attributes scoped %}
{% if visible_attributes %}
Attributes
~~~~~~~~~~

.. autoapisummary::

{% for attribute in visible_attributes %}
   {{ attribute.id }}
{% endfor %}


{% endif %}
{% endblock %}
{% endif %}
{% for obj_item in visible_children %}
{{ obj_item.render()|indent(0) }}
{% endfor %}
{% endif %}
{% endblock %}
```
**docs/en/changelog.md** (new file, 16 lines)

```markdown
## Changelog

### v0.1.2 (24/10/2023)

#### Highlights

- Support Efficient Inference Engine [lmdeploy turbomind](https://github.com/InternLM/lmdeploy/tree/main)

#### New Features

- Support Efficient Inference Engine [TurboMind](https://github.com/InternLM/lmdeploy/tree/main): Based on lmdeploy turbomind, Lagent supports the inference of LLaMA and its variant models on NVIDIA GPUs. (#47)

#### Contributors

A total of 2 developers contributed to this release.
Thanks @Harold-lkk @jiangningliu30
```
**docs/en/conf.py** (110 lines changed)

```diff
@@ -11,17 +11,16 @@
 # documentation root, use os.path.abspath to make it absolute, like shown here.
 
 import os
+import re
 import sys
 
-import pytorch_sphinx_theme
-
-sys.path.insert(0, os.path.abspath('../../'))
+sys.path.insert(0, os.path.abspath('../..'))
 
 # -- Project information -----------------------------------------------------
 
 project = 'Lagent'
 copyright = '2020-2030, InternLM'
 author = 'InternLM'
+language = 'en'
 
 # The full version, including alpha/beta/rc tags
 version_file = '../../lagent/version.py'
@@ -36,97 +35,74 @@ release = __version__
 # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
 # ones.
 extensions = [
-    'sphinx.ext.autodoc',
-    'sphinx.ext.napoleon',
-    'sphinx.ext.viewcode',
-    'sphinx_markdown_tables',
-    'sphinx_copybutton',
-    'myst_parser',
-    'sphinx.ext.intersphinx',
-    'sphinx.ext.autodoc.typehints',
-    'sphinx.ext.autosummary',
-    'sphinx.ext.autosectionlabel',
-    'sphinx_tabs.tabs',
+    'sphinx_rtd_theme',
+    'myst_nb',
+    'autoapi.extension',
+    'sphinx_markdown_tables',
 ]
 
+nb_output_stderr = 'remove-warn'
 autodoc_typehints = 'description'
-autosummary_generate = True  # Turn on sphinx.ext.autosummary
 
 # Ignore >>> when copying code
 copybutton_prompt_text = r'>>> |\.\.\. '
 copybutton_prompt_is_regexp = True
 
-myst_enable_extensions = ['colon_fence']
+# sphinx-autoapi configuration
+autoapi_dirs = ['../../lagent']
+autoapi_options = [
+    'members',
+    'undoc-members',
+    'show-inheritance',
+    'show-module-summary',
+]
+autoapi_ignore = ['*migrations*', '*command.py', '*cli.py']
+autoapi_template_dir = '_templates/autoapi'
+autoapi_add_toctree_entry = False
 
 # Add any paths that contain templates here, relative to this directory.
 templates_path = ['_templates']
 
 # The suffix(es) of source filenames.
 # You can specify multiple suffix as a list of string:
 #
 source_suffix = {
     '.rst': 'restructuredtext',
     '.md': 'markdown',
 }
 
-# The master toctree document.
-master_doc = 'index'
-
 # List of patterns, relative to source directory, that match files and
 # directories to ignore when looking for source files.
 # This pattern also affects html_static_path and html_extra_path.
-exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
+exclude_patterns = []
 
 # -- Options for HTML output -------------------------------------------------
 
 # The theme to use for HTML and HTML Help pages. See the documentation for
 # a list of builtin themes.
 #
-# html_theme = 'sphinx_rtd_theme'
-html_theme = 'pytorch_sphinx_theme'
-html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()]
+html_theme = 'sphinx_rtd_theme'
 html_theme_options = {
-    'menu': [
-        {
-            'name': 'GitHub',
-            'url': 'https://github.com/InternLM/lagent'
-        },
-    ],
-    # Specify the language of shared menu
-    'menu_lang': 'en'
+    'navigation_depth': 3,
+    'titles_only': False,
+    'style_nav_header_background': '#4fabab',
 }
 
-language = 'en'
+html_context = {
+    'display_github': True,
+    'github_host': 'github.com',
+    'github_user': 'InternLM',
+    'github_repo': 'lagent',
+    'github_version': 'main',
+    'conf_py_path': '/docs/en/',
+}
+html_title = 'Lagent'
+html_logo = '../imgs/lagent_logo.png'
+html_favicon = '../imgs/lagent_icon.png'
+
+master_doc = 'index'
 
 # Add any paths that contain custom static files (such as style sheets) here,
 # relative to this directory. They are copied after the builtin static files,
-# so a file named "default.css" will overwrite the builtin "default.css".
+# so a file named 'default.css' will overwrite the builtin 'default.css'.
 html_static_path = ['_static']
 
-html_css_files = [
-    'https://cdn.datatables.net/1.13.2/css/dataTables.bootstrap5.min.css',
-    'css/readthedocs.css'
-]
-html_js_files = [
-    'https://cdn.datatables.net/1.13.2/js/jquery.dataTables.min.js',
-    'https://cdn.datatables.net/1.13.2/js/dataTables.bootstrap5.min.js',
-    'js/collapsed.js',
-    'js/table.js',
-]
-
-myst_heading_anchors = 4
-
-intersphinx_mapping = {
-    'python': ('https://docs.python.org/3', None),
-    'numpy': ('https://numpy.org/doc/stable', None),
-    'torch': ('https://pytorch.org/docs/stable/', None),
-}
-
-def builder_inited_handler(app):
-    pass
-
-
-def setup(app):
-    app.connect('builder-inited', builder_inited_handler)
+def custom_skip(app, what, name, obj, skip, options):
+    if what in ['data', 'function', 'class'] and re.search('logger', name):
+        skip = True
    return skip
+
+
+def setup(sphinx):
+    sphinx.connect('autoapi-skip-member', custom_skip)
```
**docs/en/get_started/install.md** (new file, 19 lines)

````markdown
# Installation

## With pip

Install with pip (Recommended).

```bash
pip install lagent
```

## From source

Optionally, you could also build Lagent from source in case you want to modify the code:

```bash
git clone https://github.com/InternLM/lagent.git
cd lagent
pip install -e .
```
````
**docs/en/get_started/overview.md**

```diff
@@ -1,10 +1,10 @@
-# OVERVIEW
+# Overview
 
 This chapter introduces you to the framework of Lagent, and provides links to detailed tutorials about Lagent.
 
 ## What is Lagent
 
-Lagent is an open source LLM agent framework, which enables people to efficiently turn a large language model to agent. It also provides some typical tools to enlighten the ablility of LLM, and the whole framework is shown below:
+Lagent is an open source LLM agent framework, which enables people to efficiently turn a large language model to agent. It also provides some typical tools to enlighten the ability of LLM, and the whole framework is shown below:
 
@@ -18,6 +18,6 @@ Lagent consists of 3 main parts, agents, llms, and actions.
 
 Here is a detailed step-by-step guide to learn more about Lagent:
 
-1. For installation instructions, please see [README](../README.md).
+1. For installation instructions, please see [README](https://github.com/InternLM/lagent/blob/main/README.md).
 
-2. We provide several examples to build agents with Lagent in [examples](examples/) by simply run `python examples/react_example.py`.
+2. We provide several examples to build agents with Lagent in [examples](https://github.com/InternLM/lagent/tree/main/examples) by simply run `python examples/react_example.py`.
```
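Putting the corrected links into practice, a from-source setup that exercises one of the referenced examples might look like the following; the editable install mirrors the `install.md` file added above:

```bash
git clone https://github.com/InternLM/lagent.git
cd lagent
pip install -e .
python examples/react_example.py  # one of the example agents referenced above
```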
**docs/en/get_started/quickstart.md** (new file, 89 lines)

````markdown
# Quickstart

Using Lagent, you can easily build agents with just a few lines of code.

## Run a ReWOO agent with GPT-3.5

Below is an example of running ReWOO with GPT-3.5

```python
# Import necessary modules and classes from the "lagent" library.
from lagent.agents import ReWOO
from lagent.actions import ActionExecutor, GoogleSearch
from lagent.llms import GPTAPI

# Initialize the Language Model (llm) and provide your API key.
llm = GPTAPI(model_type='gpt-3.5-turbo', key=['Your OPENAI_API_KEY'])

# Initialize the Google Search tool and provide your API key.
search_tool = GoogleSearch(api_key='Your SERPER_API_KEY')

# Create a chatbot by configuring the ReWOO agent.
chatbot = ReWOO(
    llm=llm,  # Provide the Language Model instance.
    action_executor=ActionExecutor(
        actions=[search_tool]  # Specify the actions the chatbot can perform.
    ),
)

# Ask the chatbot a question and store the response.
response = chatbot.chat('What profession does Nicholas Ray and Elia Kazan have in common')

# Print the chatbot's response.
print(response.response)  # Output the response generated by the chatbot.
```

```python
>>> Film director.
```

## Run a ReAct agent with InternLM

NOTE: If you want to run a HuggingFace model, please run `pip install -e .[all]` first.

```python
# Import necessary modules and classes from the "lagent" library.
from lagent.agents import ReAct
from lagent.actions import ActionExecutor, GoogleSearch, PythonInterpreter
from lagent.llms import HFTransformer

from lagent.llms.meta_template import INTERNLM2_META as META

# Initialize the HFTransformer-based Language Model (llm) and
# provide the model name.
llm = HFTransformer(path='internlm/internlm2-chat-7b', meta_template=META)

# Initialize the Google Search tool and provide your API key.
search_tool = GoogleSearch(api_key='Your SERPER_API_KEY')

# Initialize the Python Interpreter tool.
python_interpreter = PythonInterpreter()

# Create a chatbot by configuring the ReAct agent.
# Specify the actions the chatbot can perform.
chatbot = ReAct(
    llm=llm,  # Provide the Language Model instance.
    action_executor=ActionExecutor(
        actions=[search_tool, python_interpreter]),
)
# Ask the chatbot a mathematical question in LaTeX format.
response = chatbot.chat('若$z=-1+\sqrt{3}i$,则$\frac{z}{{z\overline{z}-1}}=\left(\ \ \right)$')

# Print the chatbot's response.
print(response.response)  # Output the response generated by the chatbot.
```

```python
>>> $-\\frac{1}{3}+\\frac{{\\sqrt{3}}}{3}i$
```

## Run ReAct Web Demo

```bash
# You need to install streamlit first
# pip install streamlit
streamlit run examples/react_web_demo.py
```

Then you can chat through the UI shown as below.
````
**docs/en/index.rst**

```diff
@@ -8,13 +8,31 @@ You can switch between English and Chinese in the lower-left corner of the layout.
    :caption: Get Started
 
    get_started/overview.md
+   get_started/install.md
+   get_started/quickstart.md
+
+.. toctree::
+   :maxdepth: 2
+   :caption: Tutorials
+
+   tutorials/action.md
 
 .. toctree::
    :caption: Switch Language
 
    switch_language.md
 
+.. toctree::
+   :maxdepth: 1
+   :caption: API Reference
+
+   autoapi/lagent/actions/index
+   autoapi/lagent/agents/index
+   autoapi/lagent/llms/index
+   autoapi/lagent/utils/index
+   autoapi/lagent/schema/index
+   autoapi/lagent/version/index
+
 
 Indices and tables
 ==================
```
**docs/en/switch_language.md**

```diff
@@ -1,3 +1,3 @@
-## <a href='https://lagent.readthedocs.io/en/main/'>English</a>
+## <a href='https://lagent.readthedocs.io/en/latest/'>English</a>
 
-## <a href='https://lagent.readthedocs.io/zh_CN/main/'>简体中文</a>
+## <a href='https://lagent.readthedocs.io/zh-cn/latest/'>简体中文</a>
```
**docs/en/tutorials/action.md** (new file, 400 lines)

````markdown
# Action

Actions, also called **tools**, provide a suite of functions LLM-driven agents can use to interact with the real world and perform complex tasks.

## Basic Concepts

### Tool & Toolkit

There are two categories of tools:

- tool: provides only one API to call.
- toolkit: implements multiple APIs that undertake different sub-tasks.

### Tool Description

In Lagent, the tool description is a dictionary containing the action's core usage information, which LLMs observe for decision-making.

For simple tools, the description can be created as follows:

```python
TOOL_DESCRIPTION = {
    'name': 'bold',  # name of the tool
    'description': 'a function used to make text bold',  # introduce the tool's function
    'parameters': [  # a list of parameters the tool takes
        {
            'name': 'text', 'type': 'STRING', 'description': 'input content'
        }
    ],
    'required': ['text'],  # specify names of required parameters
}
```

In some situations there may be optional `return_data` and `parameter_description` keys describing the returns and the argument-passing format respectively.

```{attention}
`parameter_description` is usually inserted into the tool description automatically by the action's parser. It will be introduced in [Interface Design](#interface-design).
```

For toolkits, the description is very similar but nests sub-methods:

```python
TOOL_DESCRIPTION = {
    'name': 'PhraseEmphasis',  # name of the toolkit
    'description': 'a toolkit which provides different styles of text emphasis',  # introduce the tool's function
    'api_list': [
        {
            'name': 'bold',
            'description': 'make text bold',
            'parameters': [
                {
                    'name': 'text', 'type': 'STRING', 'description': 'input content'
                }
            ],
            'required': ['text']
        },
        {
            'name': 'italic',
            'description': 'make text italic',
            'parameters': [
                {
                    'name': 'text', 'type': 'STRING', 'description': 'input content'
                }
            ],
            'required': ['text']
        }
    ]
}
```

## Make Functions Tools

It's not necessary to prepare an extra description for a defined function. Lagent provides a decorator `tool_api` which can conveniently turn a function into a tool by automatically parsing the function's type hints and docstrings to generate the description dictionary, binding it to the attribute `api_description`.

```python
from lagent import tool_api

@tool_api
def bold(text: str) -> str:
    """make text bold

    Args:
        text (str): input text

    Returns:
        str: bold text
    """
    return '**' + text + '**'


bold.api_description
```

```python
{'name': 'bold',
 'description': 'make text bold',
 'parameters': [{'name': 'text',
   'type': 'STRING',
   'description': 'input text'}],
 'required': ['text']}
```

Once `returns_named_value` is enabled, you should declare the name of the return data, which will be processed to form a new field `return_data`:

```python
@tool_api(returns_named_value=True)
def bold(text: str) -> str:
    """make text bold

    Args:
        text (str): input text

    Returns:
        bold_text (str): bold text
    """
    return '**' + text + '**'


bold.api_description
```

```python
{'name': 'bold',
 'description': 'make text bold',
 'parameters': [{'name': 'text',
   'type': 'STRING',
   'description': 'input text'}],
 'required': ['text'],
 'return_data': [{'name': 'bold_text',
   'description': 'bold text',
   'type': 'STRING'}]}
```

Sometimes the tool may return a `dict` or `tuple`, and you may want to elaborate each member in `return_data` rather than take them as a whole. Set `explode_return=True` and list them in the returns section of the docstring.

```python
@tool_api(explode_return=True)
def list_args(a: str, b: int, c: float = 0.0) -> dict:
    """Return arguments in dict format

    Args:
        a (str): a
        b (int): b
        c (float): c

    Returns:
        dict: input arguments
            - a (str): a
            - b (int): b
            - c: c
    """
    return {'a': a, 'b': b, 'c': c}
```

```python
{'name': 'list_args',
 'description': 'Return arguments in dict format',
 'parameters': [{'name': 'a', 'type': 'STRING', 'description': 'a'},
  {'name': 'b', 'type': 'NUMBER', 'description': 'b'},
  {'name': 'c', 'type': 'FLOAT', 'description': 'c'}],
 'required': ['a', 'b'],
 'return_data': [{'name': 'a', 'description': 'a', 'type': 'STRING'},
  {'name': 'b', 'description': 'b', 'type': 'NUMBER'},
  {'name': 'c', 'description': 'c'}]}
```

```{warning}
Only Google-style Python docstrings are currently supported.
```

## Interface Design

`BaseAction(description=None, parser=JsonParser, enable=True)` is the base class all actions should inherit from. It takes three initialization arguments:

- **description**: a tool description dictionary, used to set the instance attribute `description`. Usually you don't need to pass this argument explicitly, since the metaclass of `BaseAction` searches methods decorated by `tool_api` and assembles their `api_description` into a class attribute `__tool_description__`; if the initial `description` is left null, `__tool_description__` is copied as `description`.

- **parser**: a `BaseParser` class. It instantiates a parser used to validate the arguments of the APIs in `description`.

  For example, `JsonParser` requires arguments to be passed in the format of JSON or `dict`. To make LLMs aware of this, it inserts a field `parameter_description` into the `description`.

  ```python
  from lagent import BaseAction

  action = BaseAction(
      {
          'name': 'bold',
          'description': 'a function used to make text bold',
          'parameters': [
              {
                  'name': 'text', 'type': 'STRING', 'description': 'input content'
              }
          ],
          'required': ['text']
      }
  )
  action.description
  ```

  ```python
  {'name': 'bold',
   'description': 'a function used to make text bold',
   'parameters': [{'name': 'text',
     'type': 'STRING',
     'description': 'input content'}],
   'required': ['text'],
   'parameter_description': '如果调用该工具,你必须使用Json格式 {key: value} 传参,其中key为参数名称'}
  ```

- **enable**: specify whether the tool is available.

### Custom Action

A simple tool must have its `run` method implemented, while APIs of toolkits should avoid naming conflicts with this reserved word.

```{tip}
For simple tools, `run` does not have to be decorated by `tool_api` unless you want to hint the return data.
```

```python
class Bold(BaseAction):

    def run(self, text: str):
        """make text bold

        Args:
            text (str): input text

        Returns:
            str: bold text
        """
        return '**' + text + '**'


class PhraseEmphasis(BaseAction):
    """a toolkit which provides different styles of text emphasis"""

    @tool_api
    def bold(self, text):
        """make text bold

        Args:
            text (str): input text

        Returns:
            str: bold text
        """
        return '**' + text + '**'

    @tool_api
    def italic(self, text):
        """make text italic

        Args:
            text (str): input text

        Returns:
            str: italic text
        """
        return '*' + text + '*'


# Inspect the default description
# Bold.__tool_description__, PhraseEmphasis.__tool_description__
```

### Auto-registration

Any subclass of `BaseAction` is registered automatically. You can use `list_tools()` and `get_tool()` to view all tools and initialize one by name.

```python
from lagent import list_tools, get_tool

list_tools()
```

```python
['BaseAction',
 'InvalidAction',
 'NoAction',
 'FinishAction',
 'ArxivSearch',
 'BINGMap',
 'GoogleScholar',
 'GoogleSearch',
 'IPythonInterpreter',
 'PPT',
 'PythonInterpreter',
 'Bold',
 'PhraseEmphasis']
```

Create a `PhraseEmphasis` object:

```python
action = get_tool('PhraseEmphasis')
action.description
```

```python
{'name': 'PhraseEmphasis',
 'description': 'a toolkit which provides different styles of text emphasis',
 'api_list': [{'name': 'bold',
   'description': 'make text bold',
   'parameters': [{'name': 'text',
     'type': 'STRING',
     'description': 'input text'}],
   'required': ['text'],
   'parameter_description': '如果调用该工具,你必须使用Json格式 {key: value} 传参,其中key为参数名称'},
  {'name': 'italic',
   'description': 'make text italic',
   'parameters': [{'name': 'text',
     'type': 'STRING',
     'description': 'input text'}],
   'required': ['text'],
   'parameter_description': '如果调用该工具,你必须使用Json格式 {key: value} 传参,其中key为参数名称'}]}
```

## Tool Calling

### Run a Tool

The `__call__` method of an `Action` takes two arguments:

- `inputs`: depends on the action's parser. It is often a string in a specific format generated by LLMs.
  - `JsonParser`: allows passing arguments as a JSON string or a Python `dict`.
  - `TupleParser`: allows passing arguments as a tuple-format string or a Python `tuple`.
- `name`: which API to call. Defaults to `run`.

It returns an `ActionReturn` object which encapsulates the calling details:

- `args`: dictionary of action inputs.
- `type`: action name.
- `result`: list of dicts, each containing two keys, 'type' and 'content'. When errors occur, it is `None`.
- `errmsg`: error message. Defaults to `None`.

Below is an example:

```python
from lagent import IPythonInterpreter, TupleParser

action1 = IPythonInterpreter()
ret = action1('{"command": "import math;math.sqrt(100)"}')
print(ret.result)
ret = action1({'command': 'import math;math.sqrt(100)'})
print(ret.result)

action2 = IPythonInterpreter(parser=TupleParser)
ret = action2('("import math;math.sqrt(100)", )')
print(ret.result)
ret = action2(('import math;math.sqrt(100)',))
print(ret.result)
```

```python
[{'type': 'text', 'content': '10.0'}]
[{'type': 'text', 'content': '10.0'}]
[{'type': 'text', 'content': '10.0'}]
[{'type': 'text', 'content': '10.0'}]
```

### Dynamic Invocation

Lagent provides an `ActionExecutor` to manage multiple tools. It flattens the `api_list` of each toolkit and renames each API as `{tool_name}.{api_name}`.

```python
from lagent import ActionExecutor, ArxivSearch, IPythonInterpreter

executor = ActionExecutor(actions=[ArxivSearch(), IPythonInterpreter()])
executor.get_actions_info()  # This information is fed to LLMs as the tool meta prompt
```

```python
[{'name': 'ArxivSearch.get_arxiv_article_information',
  'description': 'Run Arxiv search and get the article meta information.',
  'parameters': [{'name': 'query',
    'type': 'STRING',
    'description': 'the content of search query'}],
  'required': ['query'],
  'return_data': [{'name': 'content',
    'description': 'a list of 3 arxiv search papers',
    'type': 'STRING'}],
  'parameter_description': '如果调用该工具,你必须使用Json格式 {key: value} 传参,其中key为参数名称'},
 {'name': 'IPythonInterpreter',
  'description': "When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files. Internet access for this session is disabled. Do not make external web requests or API calls as they will fail.",
  'parameters': [{'name': 'command',
    'type': 'STRING',
    'description': 'Python code'},
   {'name': 'timeout',
    'type': 'NUMBER',
    'description': 'Upper bound of waiting time for Python script execution.'}],
  'required': ['command'],
  'parameter_description': '如果调用该工具,你必须使用Json格式 {key: value} 传参,其中key为参数名称'}]
```

Trigger an action through the executor:

```python
ret = executor('IPythonInterpreter', '{"command": "import math;math.sqrt(100)"}')
ret.result
```

```python
[{'type': 'text', 'content': '10.0'}]
```
````
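Tying the tutorial together, the custom `Bold` tool defined in its Custom Action section can be dispatched through an `ActionExecutor` in the same way; a minimal sketch, assuming the class definitions from the tutorial have been executed:

```python
from lagent import ActionExecutor

# Bold is the custom tool defined in the tutorial's Custom Action section.
executor = ActionExecutor(actions=[Bold()])
ret = executor('Bold', '{"text": "hello"}')
print(ret.result)  # expected: [{'type': 'text', 'content': '**hello**'}]
```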
**docs/imgs/lagent_icon.png**: new binary file (11 KiB, not shown)

**docs/imgs/lagent_logo.png**: new binary file (21 KiB, not shown)
**docs/zh_cn/.readthedocs.yaml** (new file, 15 lines)

```yaml
version: 2

formats: all

build:
  os: ubuntu-22.04
  tools:
    python: "3.10"

python:
  install:
    - requirements: requirements/docs.txt

sphinx:
  configuration: docs/zh_cn/conf.py
```
**docs/zh_cn/_static/css/readthedocs.css**

```diff
@@ -1,5 +1,5 @@
 .header-logo {
-  background-image: url("../images/robot.png");
+  background-image: url("../images/lagent_icon.png");
   background-size: 40px 40px;
   height: 40px;
   width: 40px;
```
**docs/zh_cn/_static/images/lagent_icon.png**: new binary file (11 KiB, not shown)
**docs/zh_cn/_templates/autoapi/index.rst** (new file, 14 lines)

```rst
API Reference
=============

This page contains auto-generated API reference documentation.

.. toctree::
   :titlesonly:
   :maxdepth: 3

   {% for page in pages %}
   {% if page.top_level_object and page.display %}
   {{ page.include_path }}
   {% endif %}
   {% endfor %}
```
**docs/zh_cn/_templates/autoapi/python/module.rst** (new file, 112 lines)

```rst
{% if not obj.display %}
:orphan:

{% endif %}
:py:mod:`{{ obj.name if obj.name.count(".") <= 1 else obj.short_name }}`
=========={{ "=" * (obj.name|length if obj.name.count(".") <= 1 else obj.short_name|length) }}

.. py:module:: {{ obj.name }}

{% if obj.docstring %}
.. autoapi-nested-parse::

   {{ obj.docstring|indent(3) }}

{% endif %}

{% block subpackages %}
{% set visible_subpackages = obj.subpackages|selectattr("display")|list %}
{% if visible_subpackages %}
Subpackages
-----------
.. toctree::
   :titlesonly:
   :maxdepth: 3

{% for subpackage in visible_subpackages %}
   {{ subpackage.short_name }}/index.rst
{% endfor %}


{% endif %}
{% endblock %}
{% block submodules %}
{% set visible_submodules = obj.submodules|selectattr("display")|list %}
{% if visible_submodules %}
Submodules
----------
.. toctree::
   :titlesonly:
   :maxdepth: 1

{% for submodule in visible_submodules %}
   {{ submodule.short_name }}/index.rst
{% endfor %}


{% endif %}
{% endblock %}
{% block content %}
{% if obj.type is equalto("package") %}
{% set visible_children = obj.children|selectattr("display")|list %}
{% else %}
{% set visible_children = obj.children|selectattr("display")|rejectattr("imported")|list %}
{% endif %}
{% if visible_children %}
{{ obj.type|title }} Contents
{{ "-" * obj.type|length }}---------

{% set visible_classes = visible_children|selectattr("type", "equalto", "class")|list %}
{% set visible_functions = visible_children|selectattr("type", "equalto", "function")|list %}
{% set visible_attributes = visible_children|selectattr("type", "equalto", "data")|list %}
{% if "show-module-summary" in autoapi_options and (visible_classes or visible_functions) %}
{% block classes scoped %}
{% if visible_classes %}
Classes
~~~~~~~

.. autoapisummary::

{% for klass in visible_classes %}
   {{ klass.id }}
{% endfor %}


{% endif %}
{% endblock %}

{% block functions scoped %}
{% if visible_functions %}
Functions
~~~~~~~~~

.. autoapisummary::

{% for function in visible_functions %}
   {{ function.id }}
{% endfor %}


{% endif %}
{% endblock %}

{% block attributes scoped %}
{% if visible_attributes %}
Attributes
~~~~~~~~~~

.. autoapisummary::

{% for attribute in visible_attributes %}
   {{ attribute.id }}
{% endfor %}


{% endif %}
{% endblock %}
{% endif %}
{% for obj_item in visible_children %}
{{ obj_item.render()|indent(0) }}
{% endfor %}
{% endif %}
{% endblock %}
```
@@ -11,18 +11,16 @@
 # documentation root, use os.path.abspath to make it absolute, like shown here.

 import os
-import subprocess
+import re
 import sys

-import pytorch_sphinx_theme
-
-sys.path.insert(0, os.path.abspath('../../'))
+sys.path.insert(0, os.path.abspath('../..'))

 # -- Project information -----------------------------------------------------

 project = 'Lagent'
 copyright = '2020-2030, InternLM'
 author = 'InternLM'
-language = 'zh_CN'

 # The full version, including alpha/beta/rc tags
 version_file = '../../lagent/version.py'
@@ -37,97 +35,74 @@ release = __version__
 # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
 # ones.
 extensions = [
+    'sphinx_rtd_theme',
+    'myst_nb',
+    'autoapi.extension',
+    'sphinx_markdown_tables',
     'sphinx.ext.autodoc',
     'sphinx.ext.napoleon',
     'sphinx.ext.viewcode',
-    'sphinx_markdown_tables',
-    'sphinx_copybutton',
-    'myst_parser',
-    'sphinx.ext.intersphinx',
-    'sphinx.ext.autodoc.typehints',
-    'sphinx.ext.autosummary',
-    'sphinx.ext.autosectionlabel',
-    'sphinx_tabs.tabs',
 ]

+nb_output_stderr = 'remove-warn'
 autodoc_typehints = 'description'

-autosummary_generate = True  # Turn on sphinx.ext.autosummary
-# Ignore >>> when copying code
-copybutton_prompt_text = r'>>> |\.\.\. '
-copybutton_prompt_is_regexp = True
-
-myst_enable_extensions = ['colon_fence']
+# sphinx-autoapi configuration
+autoapi_dirs = ['../../lagent']
+autoapi_options = [
+    'members',
+    'undoc-members',
+    'show-inheritance',
+    'show-module-summary',
+]
+autoapi_ignore = ['*migrations*', '*command.py', '*cli.py']
+autoapi_template_dir = '_templates/autoapi'
+autoapi_add_toctree_entry = False

 # Add any paths that contain templates here, relative to this directory.
 templates_path = ['_templates']

 # The suffix(es) of source filenames.
 # You can specify multiple suffix as a list of string:
 #
 source_suffix = {
     '.rst': 'restructuredtext',
     '.md': 'markdown',
 }

-# The master toctree document.
-master_doc = 'index'
-
 # List of patterns, relative to source directory, that match files and
 # directories to ignore when looking for source files.
 # This pattern also affects html_static_path and html_extra_path.
-exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
+exclude_patterns = []

 # -- Options for HTML output -------------------------------------------------

 # The theme to use for HTML and HTML Help pages.  See the documentation for
 # a list of builtin themes.
 #
-# html_theme = 'sphinx_rtd_theme'
-html_theme = 'pytorch_sphinx_theme'
-html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()]
+html_theme = 'sphinx_rtd_theme'
 html_theme_options = {
-    'menu': [
-        {
-            'name': 'GitHub',
-            'url': 'https://github.com/InternLM/lagent'
-        },
-    ],
-    # Specify the language of shared menu
-    'menu_lang': 'cn',
     'navigation_depth': 3,
     'titles_only': False,
     'style_nav_header_background': '#4fabab',
 }

+language = 'zh_CN'
+html_context = {
+    'display_github': True,
+    'github_host': 'github.com',
+    'github_user': 'InternLM',
+    'github_repo': 'lagent',
+    'github_version': 'main',
+    'conf_py_path': '/docs/zh_cn/',
+}
+html_title = 'Lagent'
+html_logo = '../imgs/lagent_logo.png'
+html_favicon = '../imgs/lagent_icon.png'
+
+master_doc = 'index'

 # Add any paths that contain custom static files (such as style sheets) here,
 # relative to this directory. They are copied after the builtin static files,
-# so a file named "default.css" will overwrite the builtin "default.css".
+# so a file named 'default.css' will overwrite the builtin 'default.css'.
 html_static_path = ['_static']
 html_css_files = [
     'https://cdn.datatables.net/1.13.2/css/dataTables.bootstrap5.min.css',
     'css/readthedocs.css'
 ]
 html_js_files = [
     'https://cdn.datatables.net/1.13.2/js/jquery.dataTables.min.js',
     'https://cdn.datatables.net/1.13.2/js/dataTables.bootstrap5.min.js',
     'js/collapsed.js',
     'js/table.js',
 ]

 myst_heading_anchors = 4

 # Configuration for intersphinx
 intersphinx_mapping = {
     'python': ('https://docs.python.org/3', None),
     'numpy': ('https://numpy.org/doc/stable', None),
     'torch': ('https://pytorch.org/docs/stable/', None),
 }


-def builder_inited_handler(app):
-    subprocess.run(['./cp_origin_docs.sh'])
+def custom_skip(app, what, name, obj, skip, options):
+    if what in ['data', 'function', 'class'] and re.search('logger', name):
+        skip = True
+    return skip


-def setup(app):
-    app.connect('builder-inited', builder_inited_handler)
+def setup(sphinx):
+    sphinx.connect('autoapi-skip-member', custom_skip)
19 docs/zh_cn/get_started/install.md Normal file
@@ -0,0 +1,19 @@
# Installation

## Install with pip

Installing via pip is recommended:

```bash
pip install lagent
```

## Install from source

If you want to modify parts of the framework, you can build Lagent from source:

```bash
git clone https://github.com/InternLM/lagent.git
cd lagent
pip install -e .
```
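One way to verify the installation is to print the package version (assuming `lagent/version.py` defines `__version__`, as the docs build configuration suggests):

```python
from lagent.version import __version__  # the version file read by the docs build
print(__version__)
```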
23 docs/zh_cn/get_started/overview.md Normal file
@@ -0,0 +1,23 @@
# Overview

This chapter introduces the architecture of Lagent and provides links to detailed tutorials.

## What is Lagent

Lagent is an open-source LLM agent framework that lets users quickly turn a large language model into an agent, and it ships a set of typical tools to unleash the model's capabilities. The framework is illustrated below:

![image](https://github.com/InternLM/lagent/assets/24351120/e104171e-4baf-43b3-8e6d-90cff1b298b6)

Lagent consists of three main modules: agents, llms, and actions; a minimal sketch combining them follows this list.

- **agents** implements various agents, such as ReAct and AutoGPT.
- **llms** supports a variety of large language models, including open-source models hosted on HuggingFace (Llama-2, InternLM) and closed-source models such as GPT-3.5/4.
- **actions** contains a series of tools and provides a tool executor to manage them uniformly.
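A minimal sketch of how the three modules combine, following the quickstart pattern in the next chapter (the API keys are placeholders you must supply):

```python
from lagent.agents import ReWOO                          # agents: the agent paradigm
from lagent.llms import GPTAPI                           # llms: the language model backend
from lagent.actions import ActionExecutor, GoogleSearch  # actions: tools and their executor

# The llm drives the agent; the executor dispatches the agent's tool calls.
llm = GPTAPI(model_type='gpt-3.5-turbo', key=['Your OPENAI_API_KEY'])
agent = ReWOO(
    llm=llm,
    action_executor=ActionExecutor(
        actions=[GoogleSearch(api_key='Your SERPER_API_KEY')]),
)
```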
## How to Use

The following detailed tutorials will help you learn more about Lagent:

1. For installation, please refer to the [README](https://github.com/InternLM/lagent/blob/main/README.md).

2. Some agent-building examples live in [examples](https://github.com/InternLM/lagent/tree/main/examples); just run the scripts directly, e.g. `python examples/react_example.py`.
87 docs/zh_cn/get_started/quickstart.md Normal file
@@ -0,0 +1,87 @@
# Quickstart

With Lagent, you can build an LLM agent with just a few lines of code.

## A ReWOO Agent Powered by GPT-3.5

Below is an example of running ReWOO with GPT-3.5:

```python
# Import the necessary modules and classes from Lagent
from lagent.agents import ReWOO
from lagent.actions import ActionExecutor, GoogleSearch
from lagent.llms import GPTAPI

# Initialize the LLM; you may need to provide an API key
llm = GPTAPI(model_type='gpt-3.5-turbo', key=['Your OPENAI_API_KEY'])

# Initialize the Google search tool; you may need to provide an API key
search_tool = GoogleSearch(api_key='Your SERPER_API_KEY')

# Configure the ReWOO agent and create a chatbot
chatbot = ReWOO(
    llm=llm,  # the large language model instance
    action_executor=ActionExecutor(
        actions=[search_tool]  # tools the agent is allowed to invoke
    ),
)

# Ask a question and get the response
response = chatbot.chat('What profession does Nicholas Ray and Elia Kazan have in common')

# Print the response
print(response.response)
```

```python
>>> Film director.
```

## A ReAct Agent Powered by InternLM

Note that if you want to use HuggingFace models, please run `pip install -e .[all]` first.

```python
# Import the necessary modules and classes from Lagent
from lagent.agents import ReAct
from lagent.actions import ActionExecutor, GoogleSearch, PythonInterpreter
from lagent.llms import HFTransformer

from lagent.llms.meta_template import INTERNLM2_META as META

# Initialize the HFTransformer model
llm = HFTransformer(path='internlm/internlm2-chat-7b', meta_template=META)

# Initialize the Google search tool; you may need to provide an API key
search_tool = GoogleSearch(api_key='Your SERPER_API_KEY')

# Initialize the Python interpreter
python_interpreter = PythonInterpreter()

# Configure the ReAct agent and create a chatbot
chatbot = ReAct(
    llm=llm,  # the large language model instance
    action_executor=ActionExecutor(
        actions=[search_tool, python_interpreter]),  # tools the agent is allowed to invoke
)
# Ask a math question in LaTeX format (a raw string keeps '\f' and '\s' literal)
response = chatbot.chat(r'若$z=-1+\sqrt{3}i$,则$\frac{z}{{z\overline{z}-1}}=\left(\ \ \right)$')

# Print the response
print(response.response)
```

```python
>>> $-\\frac{1}{3}+\\frac{{\\sqrt{3}}}{3}i$
```

## Launch a ReAct Web App

```bash
# You need to install streamlit first
# pip install streamlit
streamlit run examples/react_web_demo.py
```

Then you can chat through the UI shown below:
![image](https://github.com/InternLM/lagent/assets/24622904/3aebb8b4-07d1-42a2-9da3-46080c556f68)
@@ -8,12 +8,32 @@
   :caption: Get Started

   get_started/overview.md
   get_started/install.md
   get_started/quickstart.md

.. toctree::
   :maxdepth: 2
   :caption: Tutorials

   tutorials/action.md

.. toctree::
   :caption: Switch Language

   switch_language.md

.. toctree::
   :maxdepth: 1
   :caption: API Reference

   autoapi/lagent/actions/index
   autoapi/lagent/agents/index
   autoapi/lagent/llms/index
   autoapi/lagent/utils/index
   autoapi/lagent/schema/index
   autoapi/lagent/version/index


Indices
==================

@@ -1,3 +1,3 @@
-## <a href='https://lagent.readthedocs.io/en/main/'>English</a>
+## <a href='https://lagent.readthedocs.io/en/latest/'>English</a>

-## <a href='https://lagent.readthedocs.io/zh_CN/main/'>简体中文</a>
+## <a href='https://lagent.readthedocs.io/zh-cn/latest/'>简体中文</a>
398 docs/zh_cn/tutorials/action.md Normal file
@@ -0,0 +1,398 @@
# Actions

Actions, also called tools, provide a suite of functions that LLM-driven agents use to interact with the real world and perform complex tasks.

## Basic Concepts

### Tool & Toolkit

There are two categories of tools:

- Simple tool: provides only one API to call.
- Toolkit: implements multiple APIs that undertake different sub-tasks.

### Tool Description

In Lagent, a tool description is a dictionary that defines the calling convention of a tool; it can be observed by LLMs and used for decision making.

For simple tools, the description can be declared as follows:

```python
TOOL_DESCRIPTION = {
    'name': 'bold',  # name of the tool
    'description': 'a function used to make text bold',  # what the tool does
    'parameters': [  # the list of parameters the tool takes
        {
            'name': 'text', 'type': 'STRING', 'description': 'input content'
        }
    ],
    'required': ['text'],  # names of the required parameters
}
```

In some situations the description may also contain `return_data` and `parameter_description` fields, which describe the returns and the parameter-passing format respectively.

```{attention}
`parameter_description` is usually inserted into the tool description automatically by the action's parser, which is introduced in [Interface Design](#id6).
```

For toolkits, the description is very similar but nests sub-methods:

```python
TOOL_DESCRIPTION = {
    'name': 'PhraseEmphasis',  # name of the toolkit
    'description': 'a toolkit which provides different styles of text emphasis',  # what the toolkit does
    'api_list': [
        {
            'name': 'bold',
            'description': 'make text bold',
            'parameters': [
                {
                    'name': 'text', 'type': 'STRING', 'description': 'input content'
                }
            ],
            'required': ['text']
        },
        {
            'name': 'italic',
            'description': 'make text italic',
            'parameters': [
                {
                    'name': 'text', 'type': 'STRING', 'description': 'input content'
                }
            ],
            'required': ['text']
        }
    ]
}
```

## Make Functions Tools

For functions that are already defined, there is no need to write descriptions by hand. Lagent provides the decorator `tool_api`, which generates the description dictionary by automatically parsing the function's type hints and docstring, and binds it to the attribute `api_description`.

```python
from lagent import tool_api

@tool_api
def bold(text: str) -> str:
    """make text bold

    Args:
        text (str): input text

    Returns:
        str: bold text
    """
    return '**' + text + '**'


bold.api_description
```

```python
{'name': 'bold',
 'description': 'make text bold',
 'parameters': [{'name': 'text',
   'type': 'STRING',
   'description': 'input text'}],
 'required': ['text']}
```

Once `returns_named_value` is enabled, you should declare the name of the return value, which will be processed into a new field `return_data`:

```python
@tool_api(returns_named_value=True)
def bold(text: str) -> str:
    """make text bold

    Args:
        text (str): input text

    Returns:
        bold_text (str): bold text
    """
    return '**' + text + '**'


bold.api_description
```

```python
{'name': 'bold',
 'description': 'make text bold',
 'parameters': [{'name': 'text',
   'type': 'STRING',
   'description': 'input text'}],
 'required': ['text'],
 'return_data': [{'name': 'bold_text',
   'description': 'bold text',
   'type': 'STRING'}]}
```

Sometimes a tool may return a `dict` or a `tuple`. If you want to spell out the meaning of each member in `return_data` instead of treating them as a whole, set `explode_return=True` and enumerate them in the Returns section of the docstring.

```python
@tool_api(explode_return=True)
def list_args(a: str, b: int, c: float = 0.0) -> dict:
    """Return arguments in dict format

    Args:
        a (str): a
        b (int): b
        c (float): c

    Returns:
        dict: input arguments
            - a (str): a
            - b (int): b
            - c: c
    """
    return {'a': a, 'b': b, 'c': c}
```

```python
{'name': 'list_args',
 'description': 'Return arguments in dict format',
 'parameters': [{'name': 'a', 'type': 'STRING', 'description': 'a'},
  {'name': 'b', 'type': 'NUMBER', 'description': 'b'},
  {'name': 'c', 'type': 'FLOAT', 'description': 'c'}],
 'required': ['a', 'b'],
 'return_data': [{'name': 'a', 'description': 'a', 'type': 'STRING'},
  {'name': 'b', 'description': 'b', 'type': 'NUMBER'},
  {'name': 'c', 'description': 'c'}]}
```

```{warning}
Only Google-style Python docstrings are currently supported.
```

## Interface Design

`BaseAction(description=None, parser=JsonParser, enable=True)` is the base class that all actions should inherit from. It accepts three initialization arguments:

- **description**: a tool description dictionary used to set the instance attribute `description`. Usually there is no need to pass it explicitly, because the metaclass of `BaseAction` collects the methods decorated by `tool_api` and assembles their `api_description` into a class attribute `__tool_description__`; if `description` is left empty at instantiation, the instance attribute falls back to `__tool_description__`.

- **parser**: a `BaseParser` class used to instantiate an action parser that validates the arguments of the tool described by `description`. For example, `JsonParser` requires the model to pass a JSON-format string or a Python dictionary when calling the tool; to make the LLM aware of this instruction, it inserts a `parameter_description` field into `description`.

```python
from lagent import BaseAction

action = BaseAction(
    {
        'name': 'bold',
        'description': 'a function used to make text bold',
        'parameters': [
            {
                'name': 'text', 'type': 'STRING', 'description': 'input content'
            }
        ],
        'required': ['text']
    }
)
action.description
```

```python
{'name': 'bold',
 'description': 'a function used to make text bold',
 'parameters': [{'name': 'text',
   'type': 'STRING',
   'description': 'input content'}],
 'required': ['text'],
 'parameter_description': '如果调用该工具,你必须使用Json格式 {key: value} 传参,其中key为参数名称'}
```

- **enable**: specifies whether the action is available.

### Custom Action

A simple tool must implement the `run` method, while a toolkit should avoid naming any of its sub-APIs with this reserved word.

```{tip}
For non-toolkit actions, `run` may be left undecorated by `tool_api`, unless you want to prompt the return information.
```

```python
class Bold(BaseAction):

    def run(self, text: str):
        """make text bold

        Args:
            text (str): input text

        Returns:
            str: bold text
        """
        return '**' + text + '**'


class PhraseEmphasis(BaseAction):
    """a toolkit which provides different styles of text emphasis"""

    @tool_api
    def bold(self, text):
        """make text bold

        Args:
            text (str): input text

        Returns:
            str: bold text
        """
        return '**' + text + '**'

    @tool_api
    def italic(self, text):
        """make text italic

        Args:
            text (str): input text

        Returns:
            str: italic text
        """
        return '*' + text + '*'


# Check the default tool descriptions
# Bold.__tool_description__, PhraseEmphasis.__tool_description__
```
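As a hedged usage sketch of the classes above (default `JsonParser`; `run` is the default API of a simple tool, as detailed in the Tool Calling section below; the `result` payload mentioned in the comments is an assumption following that section's convention):

```python
bold = Bold()
ret = bold('{"text": "hello"}')  # JsonParser accepts a JSON-format string
print(ret.result)                # expected to carry the bold text '**hello**'
```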
### Auto Registration

Any subclass of `BaseAction` is registered automatically. You can use `list_tools()` and `get_tool()` to view all tool classes and initialize one by its name.

```python
from lagent import list_tools, get_tool

list_tools()
```

```python
['BaseAction',
 'InvalidAction',
 'NoAction',
 'FinishAction',
 'ArxivSearch',
 'BINGMap',
 'GoogleScholar',
 'GoogleSearch',
 'IPythonInterpreter',
 'PPT',
 'PythonInterpreter',
 'Bold',
 'PhraseEmphasis']
```

Create a `PhraseEmphasis` object:

```python
action = get_tool('PhraseEmphasis')
action.description
```

```python
{'name': 'PhraseEmphasis',
 'description': 'a toolkit which provides different styles of text emphasis',
 'api_list': [{'name': 'bold',
   'description': 'make text bold',
   'parameters': [{'name': 'text',
     'type': 'STRING',
     'description': 'input text'}],
   'required': ['text'],
   'parameter_description': '如果调用该工具,你必须使用Json格式 {key: value} 传参,其中key为参数名称'},
  {'name': 'italic',
   'description': 'make text italic',
   'parameters': [{'name': 'text',
     'type': 'STRING',
     'description': 'input text'}],
   'required': ['text'],
   'parameter_description': '如果调用该工具,你必须使用Json格式 {key: value} 传参,其中key为参数名称'}]}
```

## Tool Calling

### Run a Tool

`Action`'s `__call__` method takes two arguments:

- `inputs`: its type depends on the `BaseParser` bound to the action; it is usually a string generated by the LLM.
  - `JsonParser`: accepts a JSON-format string or a Python dictionary.
  - `TupleParser`: accepts a string of a tuple literal or a Python tuple.
- `name`: which API to call; defaults to `run`.

A tool returns an `ActionReturn` object that encapsulates the details of the call:

- `args`: a dictionary of the action's input arguments.
- `type`: the action name.
- `result`: a list of dictionaries, each with two keys, 'type' and 'content'; this field is `None` when an exception occurs.
- `errmsg`: the error message; defaults to `None`.

Here is an example:

```python
from lagent import IPythonInterpreter, TupleParser

action1 = IPythonInterpreter()
ret = action1('{"command": "import math;math.sqrt(100)"}')
print(ret.result)
ret = action1({'command': 'import math;math.sqrt(100)'})
print(ret.result)

action2 = IPythonInterpreter(parser=TupleParser)
ret = action2('("import math;math.sqrt(100)", )')
print(ret.result)
ret = action2(('import math;math.sqrt(100)',))
print(ret.result)
```

```python
[{'type': 'text', 'content': '10.0'}]
[{'type': 'text', 'content': '10.0'}]
[{'type': 'text', 'content': '10.0'}]
[{'type': 'text', 'content': '10.0'}]
```
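For a toolkit, `name` selects the sub-API to run. A small sketch with the `PhraseEmphasis` toolkit defined earlier (the exact `result` payload is an assumption following the convention above):

```python
emphasis = PhraseEmphasis()
ret = emphasis('{"text": "hello"}', name='bold')  # call the 'bold' sub-API
print(ret.type, ret.args)  # the action name and the parsed arguments
print(ret.result)          # expected to carry '**hello**'
```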
### Dynamic Invocation

Lagent provides an `ActionExecutor` interface to manage multiple tools. It flattens the `api_list` of each toolkit and renames every API as `{tool_name}.{api_name}`.

```python
from lagent import ActionExecutor, ArxivSearch, IPythonInterpreter

executor = ActionExecutor(actions=[ArxivSearch(), IPythonInterpreter()])
executor.get_actions_info()  # this result becomes part of the LLM system prompt
```

```python
[{'name': 'ArxivSearch.get_arxiv_article_information',
  'description': 'Run Arxiv search and get the article meta information.',
  'parameters': [{'name': 'query',
    'type': 'STRING',
    'description': 'the content of search query'}],
  'required': ['query'],
  'return_data': [{'name': 'content',
    'description': 'a list of 3 arxiv search papers',
    'type': 'STRING'}],
  'parameter_description': '如果调用该工具,你必须使用Json格式 {key: value} 传参,其中key为参数名称'},
 {'name': 'IPythonInterpreter',
  'description': "When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files. Internet access for this session is disabled. Do not make external web requests or API calls as they will fail.",
  'parameters': [{'name': 'command',
    'type': 'STRING',
    'description': 'Python code'},
   {'name': 'timeout',
    'type': 'NUMBER',
    'description': 'Upper bound of waiting time for Python script execution.'}],
  'required': ['command'],
  'parameter_description': '如果调用该工具,你必须使用Json格式 {key: value} 传参,其中key为参数名称'}]
```

Trigger a tool through the action executor:

```python
ret = executor('IPythonInterpreter', '{"command": "import math;math.sqrt(100)"}')
ret.result
```

```python
[{'type': 'text', 'content': '10.0'}]
```
@@ -1,46 +0,0 @@
from lagent.actions.action_executor import ActionExecutor
from lagent.actions.builtin_actions import FinishAction
from lagent.actions.python_interpreter import PythonInterpreter
from lagent.agents.autogpt import AutoGPT
from lagent.llms.openai import GPTAPI


def input_prompt():
    print('\ndouble enter to end input >>> ', end='')
    sentinel = ''  # ends when this string is seen
    return '\n'.join(iter(input, sentinel))


def main():
    # set OPENAI_API_KEY in your environment or directly pass it with key=''
    model = GPTAPI(model_type='gpt-3.5-turbo')

    chatbot = AutoGPT(
        llm=model,
        action_executor=ActionExecutor(
            actions=[
                PythonInterpreter(),
            ],
            finish_action=FinishAction(
                description=(
                    'Goals are accomplished and there is nothing left '
                    'to do. Parameter: {"response: "final response '
                    'for the goal"}')),
            finish_in_action=True),
    )

    while True:
        try:
            prompt = input_prompt()
        except UnicodeDecodeError:
            print('UnicodeDecodeError')
            continue
        if prompt == 'exit':
            exit(0)

        agent_return = chatbot.chat(prompt)
        print(agent_return.response)


if __name__ == '__main__':
    main()
@@ -1,36 +0,0 @@
from argparse import ArgumentParser

from lagent.llms.openai import GPTAPI


def parse_args():
    parser = ArgumentParser(description='chatbot')
    parser.add_argument('--mode', default='chat')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    # set OPENAI_API_KEY in your environment or directly pass it with key=''
    model = GPTAPI(model_type='gpt-3.5-turbo')
    history = []
    while True:
        try:
            prompt = input('>>> ')
        except UnicodeDecodeError:
            print('UnicodeDecodeError')
            continue
        if prompt == 'exit':
            exit(0)
        if args.mode == 'chat':
            history.append(dict(role='user', content=prompt))
            response = model.generate_from_template(history, max_out_len=512)
            history.append(dict(role='assistant', content=response))
        elif args.mode == 'generate':
            response = model.generate(prompt, max_out_len=512)
        print('Assistant:', response)


if __name__ == '__main__':
    main()
@@ -1,37 +0,0 @@
from lagent.actions.action_executor import ActionExecutor
from lagent.actions.python_interpreter import PythonInterpreter
from lagent.agents.react import ReAct
from lagent.llms.huggingface import HFTransformer

model = HFTransformer(
    path='internlm/internlm-chat-7b-v1.1',
    meta_template=[
        dict(role='system', begin='<|System|>:', end='<TOKENS_UNUSED_2>\n'),
        dict(role='user', begin='<|User|>:', end='<eoh>\n'),
        dict(role='assistant', begin='<|Bot|>:', end='<eoa>\n', generate=True)
    ],
)

chatbot = ReAct(
    llm=model,
    action_executor=ActionExecutor(actions=[PythonInterpreter()]),
)


def input_prompt():
    print('\ndouble enter to end input >>> ', end='')
    sentinel = ''  # ends when this string is seen
    return '\n'.join(iter(input, sentinel))


while True:
    try:
        prompt = input_prompt()
    except UnicodeDecodeError:
        print('UnicodeDecodeError')
        continue
    if prompt == 'exit':
        exit(0)

    agent_return = chatbot.chat(prompt)
    print(agent_return.response)
99 examples/internlm2_agent_cli_demo.py Normal file
@@ -0,0 +1,99 @@
from argparse import ArgumentParser

from lagent.actions import ActionExecutor, ArxivSearch, IPythonInterpreter
from lagent.agents.internlm2_agent import INTERPRETER_CN, META_CN, PLUGIN_CN, Internlm2Agent, Internlm2Protocol
from lagent.llms import HFTransformer
from lagent.llms.meta_template import INTERNLM2_META as META
from lagent.schema import AgentStatusCode


def parse_args():
    parser = ArgumentParser(description='chatbot')
    parser.add_argument(
        '--path',
        type=str,
        default='internlm/internlm2-chat-20b',
        help='The path to the model')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    # Initialize the HFTransformer-based Language Model (llm)
    model = HFTransformer(
        path=args.path,
        meta_template=META,
        max_new_tokens=1024,
        top_p=0.8,
        top_k=None,
        temperature=0.1,
        repetition_penalty=1.0,
        stop_words=['<|im_end|>'])
    plugin_executor = ActionExecutor(actions=[ArxivSearch()])  # noqa: F841
    interpreter_executor = ActionExecutor(actions=[IPythonInterpreter()])

    chatbot = Internlm2Agent(
        llm=model,
        plugin_executor=None,
        interpreter_executor=interpreter_executor,
        protocol=Internlm2Protocol(
            meta_prompt=META_CN,
            interpreter_prompt=INTERPRETER_CN,
            plugin_prompt=PLUGIN_CN,
            tool=dict(
                begin='{start_token}{name}\n',
                start_token='<|action_start|>',
                name_map=dict(
                    plugin='<|plugin|>', interpreter='<|interpreter|>'),
                belong='assistant',
                end='<|action_end|>\n',
            ),
        ),
    )

    def input_prompt():
        print('\ndouble enter to end input >>> ', end='', flush=True)
        sentinel = ''  # ends when this string is seen
        return '\n'.join(iter(input, sentinel))

    history = []
    while True:
        try:
            prompt = input_prompt()
        except UnicodeDecodeError:
            print('UnicodeDecodeError')
            continue
        if prompt == 'exit':
            exit(0)
        history.append(dict(role='user', content=prompt))
        print('\nInternLm2:', end='')
        current_length = 0
        last_status = None
        for agent_return in chatbot.stream_chat(history):
            status = agent_return.state
            if status not in [
                    AgentStatusCode.STREAM_ING, AgentStatusCode.CODING,
                    AgentStatusCode.PLUGIN_START
            ]:
                continue
            if status != last_status:
                current_length = 0
                print('')
            if isinstance(agent_return.response, dict):
                action = f"\n\n {agent_return.response['name']}: \n\n"
                action_input = agent_return.response['parameters']
                if agent_return.response['name'] == 'IPythonInterpreter':
                    action_input = action_input['command']
                response = action + action_input
            else:
                response = agent_return.response
            print(response[current_length:], end='', flush=True)
            current_length = len(response)
            last_status = status
        print('')
        history.extend(agent_return.inner_steps)


if __name__ == '__main__':
    main()
333 examples/internlm2_agent_web_demo.py Normal file
@@ -0,0 +1,333 @@
import copy
import hashlib
import json
import os

import streamlit as st

from lagent.actions import ActionExecutor, ArxivSearch, IPythonInterpreter
from lagent.agents.internlm2_agent import INTERPRETER_CN, META_CN, PLUGIN_CN, Internlm2Agent, Internlm2Protocol
from lagent.llms.lmdepoly_wrapper import LMDeployClient
from lagent.llms.meta_template import INTERNLM2_META as META
from lagent.schema import AgentStatusCode

# from streamlit.logger import get_logger


class SessionState:

    def init_state(self):
        """Initialize session state variables."""
        st.session_state['assistant'] = []
        st.session_state['user'] = []

        action_list = [
            ArxivSearch(),
        ]
        st.session_state['plugin_map'] = {
            action.name: action
            for action in action_list
        }
        st.session_state['model_map'] = {}
        st.session_state['model_selected'] = None
        st.session_state['plugin_actions'] = set()
        st.session_state['history'] = []

    def clear_state(self):
        """Clear the existing session state."""
        st.session_state['assistant'] = []
        st.session_state['user'] = []
        st.session_state['model_selected'] = None
        st.session_state['file'] = set()
        if 'chatbot' in st.session_state:
            st.session_state['chatbot']._session_history = []


class StreamlitUI:

    def __init__(self, session_state: SessionState):
        self.init_streamlit()
        self.session_state = session_state

    def init_streamlit(self):
        """Initialize Streamlit's UI settings."""
        st.set_page_config(
            layout='wide',
            page_title='lagent-web',
            page_icon='./docs/imgs/lagent_icon.png')
        st.header(':robot_face: :blue[Lagent] Web Demo ', divider='rainbow')
        st.sidebar.title('模型控制')
        st.session_state['file'] = set()
        st.session_state['ip'] = None

    def setup_sidebar(self):
        """Setup the sidebar for model and plugin selection."""
        # model_name = st.sidebar.selectbox('模型选择:', options=['internlm'])
        model_name = st.sidebar.text_input('模型名称:', value='internlm2-chat-7b')
        meta_prompt = st.sidebar.text_area('系统提示词', value=META_CN)
        da_prompt = st.sidebar.text_area('数据分析提示词', value=INTERPRETER_CN)
        plugin_prompt = st.sidebar.text_area('插件提示词', value=PLUGIN_CN)
        model_ip = st.sidebar.text_input('模型IP:', value='10.140.0.220:23333')
        if model_name != st.session_state[
                'model_selected'] or st.session_state['ip'] != model_ip:
            st.session_state['ip'] = model_ip
            model = self.init_model(model_name, model_ip)
            self.session_state.clear_state()
            st.session_state['model_selected'] = model_name
            if 'chatbot' in st.session_state:
                del st.session_state['chatbot']
        else:
            model = st.session_state['model_map'][model_name]

        plugin_name = st.sidebar.multiselect(
            '插件选择',
            options=list(st.session_state['plugin_map'].keys()),
            default=[],
        )
        da_flag = st.sidebar.checkbox(
            '数据分析',
            value=False,
        )
        plugin_action = [
            st.session_state['plugin_map'][name] for name in plugin_name
        ]

        if 'chatbot' in st.session_state:
            if len(plugin_action) > 0:
                st.session_state['chatbot']._action_executor = ActionExecutor(
                    actions=plugin_action)
            else:
                st.session_state['chatbot']._action_executor = None
            if da_flag:
                st.session_state[
                    'chatbot']._interpreter_executor = ActionExecutor(
                        actions=[IPythonInterpreter()])
            else:
                st.session_state['chatbot']._interpreter_executor = None
            st.session_state['chatbot']._protocol._meta_template = meta_prompt
            st.session_state['chatbot']._protocol.plugin_prompt = plugin_prompt
            st.session_state[
                'chatbot']._protocol.interpreter_prompt = da_prompt
        if st.sidebar.button('清空对话', key='clear'):
            self.session_state.clear_state()
        uploaded_file = st.sidebar.file_uploader('上传文件')

        return model_name, model, plugin_action, uploaded_file, model_ip

    def init_model(self, model_name, ip=None):
        """Initialize the model based on the input model name."""
        model_url = f'http://{ip}'
        st.session_state['model_map'][model_name] = LMDeployClient(
            model_name=model_name,
            url=model_url,
            meta_template=META,
            max_new_tokens=1024,
            top_p=0.8,
            top_k=100,
            temperature=0,
            repetition_penalty=1.0,
            stop_words=['<|im_end|>'])
        return st.session_state['model_map'][model_name]

    def initialize_chatbot(self, model, plugin_action):
        """Initialize the chatbot with the given model and plugin actions."""
        return Internlm2Agent(
            llm=model,
            protocol=Internlm2Protocol(
                tool=dict(
                    begin='{start_token}{name}\n',
                    start_token='<|action_start|>',
                    name_map=dict(
                        plugin='<|plugin|>', interpreter='<|interpreter|>'),
                    belong='assistant',
                    end='<|action_end|>\n',
                ), ),
            max_turn=7)

    def render_user(self, prompt: str):
        with st.chat_message('user'):
            st.markdown(prompt)

    def render_assistant(self, agent_return):
        with st.chat_message('assistant'):
            for action in agent_return.actions:
                if (action) and (action.type != 'FinishAction'):
                    self.render_action(action)
            st.markdown(agent_return.response)

    def render_plugin_args(self, action):
        action_name = action.type
        args = action.args
        import json
        parameter_dict = dict(name=action_name, parameters=args)
        parameter_str = '```json\n' + json.dumps(
            parameter_dict, indent=4, ensure_ascii=False) + '\n```'
        st.markdown(parameter_str)

    def render_interpreter_args(self, action):
        st.info(action.type)
        st.markdown(action.args['text'])

    def render_action(self, action):
        st.markdown(action.thought)
        if action.type == 'IPythonInterpreter':
            self.render_interpreter_args(action)
        elif action.type == 'FinishAction':
            pass
        else:
            self.render_plugin_args(action)
        self.render_action_results(action)

    def render_action_results(self, action):
        """Render the results of action, including text, images, videos, and
        audios."""
        if (isinstance(action.result, dict)):
            if 'text' in action.result:
                st.markdown('```\n' + action.result['text'] + '\n```')
            if 'image' in action.result:
                # image_path = action.result['image']
                for image_path in action.result['image']:
                    image_data = open(image_path, 'rb').read()
                    st.image(image_data, caption='Generated Image')
            if 'video' in action.result:
                video_data = action.result['video']
                video_data = open(video_data, 'rb').read()
                st.video(video_data)
            if 'audio' in action.result:
                audio_data = action.result['audio']
                audio_data = open(audio_data, 'rb').read()
                st.audio(audio_data)
        elif isinstance(action.result, list):
            for item in action.result:
                if item['type'] == 'text':
                    st.markdown('```\n' + item['content'] + '\n```')
                elif item['type'] == 'image':
                    image_data = open(item['content'], 'rb').read()
                    st.image(image_data, caption='Generated Image')
                elif item['type'] == 'video':
                    video_data = open(item['content'], 'rb').read()
                    st.video(video_data)
                elif item['type'] == 'audio':
                    audio_data = open(item['content'], 'rb').read()
                    st.audio(audio_data)
        if action.errmsg:
            st.error(action.errmsg)


def main():
    # logger = get_logger(__name__)
    # Initialize Streamlit UI and setup sidebar
    if 'ui' not in st.session_state:
        session_state = SessionState()
        session_state.init_state()
        st.session_state['ui'] = StreamlitUI(session_state)

    else:
        st.set_page_config(
            layout='wide',
            page_title='lagent-web',
            page_icon='./docs/imgs/lagent_icon.png')
        st.header(':robot_face: :blue[Lagent] Web Demo ', divider='rainbow')
    _, model, plugin_action, uploaded_file, _ = st.session_state[
        'ui'].setup_sidebar()

    # Initialize chatbot if it is not already initialized
    # or if the model has changed
    if 'chatbot' not in st.session_state or model != st.session_state[
            'chatbot']._llm:
        st.session_state['chatbot'] = st.session_state[
            'ui'].initialize_chatbot(model, plugin_action)
        st.session_state['session_history'] = []

    for prompt, agent_return in zip(st.session_state['user'],
                                    st.session_state['assistant']):
        st.session_state['ui'].render_user(prompt)
        st.session_state['ui'].render_assistant(agent_return)

    if user_input := st.chat_input(''):
        with st.container():
            st.session_state['ui'].render_user(user_input)
        st.session_state['user'].append(user_input)
        # Add file uploader to sidebar
        if (uploaded_file
                and uploaded_file.name not in st.session_state['file']):

            st.session_state['file'].add(uploaded_file.name)
            file_bytes = uploaded_file.read()
            file_type = uploaded_file.type
            if 'image' in file_type:
                st.image(file_bytes, caption='Uploaded Image')
            elif 'video' in file_type:
                st.video(file_bytes, caption='Uploaded Video')
            elif 'audio' in file_type:
                st.audio(file_bytes, caption='Uploaded Audio')
            # Save the file to a temporary location and get the path

            postfix = uploaded_file.name.split('.')[-1]
            # prefix = str(uuid.uuid4())
            prefix = hashlib.md5(file_bytes).hexdigest()
            filename = f'{prefix}.{postfix}'
            file_path = os.path.join(root_dir, filename)
            with open(file_path, 'wb') as tmpfile:
                tmpfile.write(file_bytes)
            file_size = os.stat(file_path).st_size / 1024 / 1024
            file_size = f'{round(file_size, 2)} MB'
            # st.write(f'File saved at: {file_path}')
            user_input = [
                dict(role='user', content=user_input),
                dict(
                    role='user',
                    content=json.dumps(dict(path=file_path, size=file_size)),
                    name='file')
            ]
        if isinstance(user_input, str):
            user_input = [dict(role='user', content=user_input)]
        st.session_state['last_status'] = AgentStatusCode.SESSION_READY
        for agent_return in st.session_state['chatbot'].stream_chat(
                st.session_state['session_history'] + user_input):
            if agent_return.state == AgentStatusCode.PLUGIN_RETURN:
                with st.container():
                    st.session_state['ui'].render_plugin_args(
                        agent_return.actions[-1])
                    st.session_state['ui'].render_action_results(
                        agent_return.actions[-1])
            elif agent_return.state == AgentStatusCode.CODE_RETURN:
                with st.container():
                    st.session_state['ui'].render_action_results(
                        agent_return.actions[-1])
            elif (agent_return.state == AgentStatusCode.STREAM_ING
                  or agent_return.state == AgentStatusCode.CODING):
                # st.markdown(agent_return.response)
                # Clear the placeholder's current content and show the new one
                with st.container():
                    if agent_return.state != st.session_state['last_status']:
                        st.session_state['temp'] = ''
                        placeholder = st.empty()
                        st.session_state['placeholder'] = placeholder
                    if isinstance(agent_return.response, dict):
                        action = f"\n\n {agent_return.response['name']}: \n\n"
                        action_input = agent_return.response['parameters']
                        if agent_return.response[
                                'name'] == 'IPythonInterpreter':
                            action_input = action_input['command']
                        response = action + action_input
                    else:
                        response = agent_return.response
                    st.session_state['temp'] = response
                    st.session_state['placeholder'].markdown(
                        st.session_state['temp'])
            elif agent_return.state == AgentStatusCode.END:
                st.session_state['session_history'] += (
                    user_input + agent_return.inner_steps)
                agent_return = copy.deepcopy(agent_return)
                agent_return.response = st.session_state['temp']
                st.session_state['assistant'].append(
                    copy.deepcopy(agent_return))
            st.session_state['last_status'] = agent_return.state


if __name__ == '__main__':
    root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    root_dir = os.path.join(root_dir, 'tmp_dir')
    os.makedirs(root_dir, exist_ok=True)
    main()
332 examples/internlm2_agent_web_demo_hf.py Normal file
@@ -0,0 +1,332 @@
import copy
import hashlib
import json
import os

import streamlit as st

from lagent.actions import ActionExecutor, ArxivSearch, IPythonInterpreter
from lagent.agents.internlm2_agent import INTERPRETER_CN, META_CN, PLUGIN_CN, Internlm2Agent, Internlm2Protocol
from lagent.llms import HFTransformer
from lagent.llms.meta_template import INTERNLM2_META as META
from lagent.schema import AgentStatusCode

# from streamlit.logger import get_logger


class SessionState:

    def init_state(self):
        """Initialize session state variables."""
        st.session_state['assistant'] = []
        st.session_state['user'] = []

        action_list = [
            ArxivSearch(),
        ]
        st.session_state['plugin_map'] = {
            action.name: action
            for action in action_list
        }
        st.session_state['model_map'] = {}
        st.session_state['model_selected'] = None
        st.session_state['plugin_actions'] = set()
        st.session_state['history'] = []

    def clear_state(self):
        """Clear the existing session state."""
        st.session_state['assistant'] = []
        st.session_state['user'] = []
        st.session_state['model_selected'] = None
        st.session_state['file'] = set()
        if 'chatbot' in st.session_state:
            st.session_state['chatbot']._session_history = []


class StreamlitUI:

    def __init__(self, session_state: SessionState):
        self.init_streamlit()
        self.session_state = session_state

    def init_streamlit(self):
        """Initialize Streamlit's UI settings."""
        st.set_page_config(
            layout='wide',
            page_title='lagent-web',
            page_icon='./docs/imgs/lagent_icon.png')
        st.header(':robot_face: :blue[Lagent] Web Demo ', divider='rainbow')
        st.sidebar.title('模型控制')
        st.session_state['file'] = set()
        st.session_state['model_path'] = None

    def setup_sidebar(self):
        """Setup the sidebar for model and plugin selection."""
        # model_name = st.sidebar.selectbox('模型选择:', options=['internlm'])
        model_name = st.sidebar.text_input('模型名称:', value='internlm2-chat-7b')
        meta_prompt = st.sidebar.text_area('系统提示词', value=META_CN)
        da_prompt = st.sidebar.text_area('数据分析提示词', value=INTERPRETER_CN)
        plugin_prompt = st.sidebar.text_area('插件提示词', value=PLUGIN_CN)
        model_path = st.sidebar.text_input(
            '模型路径:', value='internlm/internlm2-chat-20b')
        if model_name != st.session_state['model_selected'] or st.session_state[
                'model_path'] != model_path:
            st.session_state['model_path'] = model_path
            model = self.init_model(model_name, model_path)
            self.session_state.clear_state()
            st.session_state['model_selected'] = model_name
            if 'chatbot' in st.session_state:
                del st.session_state['chatbot']
        else:
            model = st.session_state['model_map'][model_name]

        plugin_name = st.sidebar.multiselect(
            '插件选择',
            options=list(st.session_state['plugin_map'].keys()),
            default=[],
        )
        da_flag = st.sidebar.checkbox(
            '数据分析',
            value=False,
        )
        plugin_action = [
            st.session_state['plugin_map'][name] for name in plugin_name
        ]

        if 'chatbot' in st.session_state:
            if len(plugin_action) > 0:
                st.session_state['chatbot']._action_executor = ActionExecutor(
                    actions=plugin_action)
            else:
                st.session_state['chatbot']._action_executor = None
            if da_flag:
                st.session_state[
                    'chatbot']._interpreter_executor = ActionExecutor(
                        actions=[IPythonInterpreter()])
            else:
                st.session_state['chatbot']._interpreter_executor = None
            st.session_state['chatbot']._protocol._meta_template = meta_prompt
            st.session_state['chatbot']._protocol.plugin_prompt = plugin_prompt
            st.session_state[
                'chatbot']._protocol.interpreter_prompt = da_prompt
        if st.sidebar.button('清空对话', key='clear'):
            self.session_state.clear_state()
        uploaded_file = st.sidebar.file_uploader('上传文件')

        return model_name, model, plugin_action, uploaded_file, model_path

    def init_model(self, model_name, path):
        """Initialize the model based on the input model name."""
        st.session_state['model_map'][model_name] = HFTransformer(
            path=path,
            meta_template=META,
            max_new_tokens=1024,
            top_p=0.8,
            top_k=None,
            temperature=0.1,
            repetition_penalty=1.0,
            stop_words=['<|im_end|>'])
        return st.session_state['model_map'][model_name]

    def initialize_chatbot(self, model, plugin_action):
        """Initialize the chatbot with the given model and plugin actions."""
        return Internlm2Agent(
            llm=model,
            protocol=Internlm2Protocol(
                tool=dict(
                    begin='{start_token}{name}\n',
                    start_token='<|action_start|>',
                    name_map=dict(
                        plugin='<|plugin|>', interpreter='<|interpreter|>'),
                    belong='assistant',
                    end='<|action_end|>\n',
                ), ),
            max_turn=7)

    def render_user(self, prompt: str):
        with st.chat_message('user'):
            st.markdown(prompt)

    def render_assistant(self, agent_return):
        with st.chat_message('assistant'):
            for action in agent_return.actions:
                if (action) and (action.type != 'FinishAction'):
                    self.render_action(action)
            st.markdown(agent_return.response)

    def render_plugin_args(self, action):
        action_name = action.type
        args = action.args
        import json
        parameter_dict = dict(name=action_name, parameters=args)
        parameter_str = '```json\n' + json.dumps(
            parameter_dict, indent=4, ensure_ascii=False) + '\n```'
        st.markdown(parameter_str)

    def render_interpreter_args(self, action):
        st.info(action.type)
        st.markdown(action.args['text'])

    def render_action(self, action):
        st.markdown(action.thought)
        if action.type == 'IPythonInterpreter':
            self.render_interpreter_args(action)
        elif action.type == 'FinishAction':
            pass
        else:
            self.render_plugin_args(action)
        self.render_action_results(action)

    def render_action_results(self, action):
        """Render the results of action, including text, images, videos, and
        audios."""
        if (isinstance(action.result, dict)):
            if 'text' in action.result:
                st.markdown('```\n' + action.result['text'] + '\n```')
            if 'image' in action.result:
                # image_path = action.result['image']
                for image_path in action.result['image']:
                    image_data = open(image_path, 'rb').read()
                    st.image(image_data, caption='Generated Image')
            if 'video' in action.result:
                video_data = action.result['video']
                video_data = open(video_data, 'rb').read()
                st.video(video_data)
            if 'audio' in action.result:
                audio_data = action.result['audio']
                audio_data = open(audio_data, 'rb').read()
                st.audio(audio_data)
        elif isinstance(action.result, list):
            for item in action.result:
                if item['type'] == 'text':
                    st.markdown('```\n' + item['content'] + '\n```')
                elif item['type'] == 'image':
                    image_data = open(item['content'], 'rb').read()
                    st.image(image_data, caption='Generated Image')
                elif item['type'] == 'video':
                    video_data = open(item['content'], 'rb').read()
                    st.video(video_data)
                elif item['type'] == 'audio':
                    audio_data = open(item['content'], 'rb').read()
                    st.audio(audio_data)
        if action.errmsg:
            st.error(action.errmsg)


def main():
    # logger = get_logger(__name__)
    # Initialize Streamlit UI and setup sidebar
    if 'ui' not in st.session_state:
        session_state = SessionState()
        session_state.init_state()
        st.session_state['ui'] = StreamlitUI(session_state)

    else:
        st.set_page_config(
            layout='wide',
            page_title='lagent-web',
            page_icon='./docs/imgs/lagent_icon.png')
        st.header(':robot_face: :blue[Lagent] Web Demo ', divider='rainbow')
    _, model, plugin_action, uploaded_file, _ = st.session_state[
        'ui'].setup_sidebar()

    # Initialize chatbot if it is not already initialized
    # or if the model has changed
    if 'chatbot' not in st.session_state or model != st.session_state[
            'chatbot']._llm:
        st.session_state['chatbot'] = st.session_state[
            'ui'].initialize_chatbot(model, plugin_action)
        st.session_state['session_history'] = []

    for prompt, agent_return in zip(st.session_state['user'],
                                    st.session_state['assistant']):
        st.session_state['ui'].render_user(prompt)
        st.session_state['ui'].render_assistant(agent_return)

    if user_input := st.chat_input(''):
        with st.container():
            st.session_state['ui'].render_user(user_input)
        st.session_state['user'].append(user_input)
        # Add file uploader to sidebar
        if (uploaded_file
                and uploaded_file.name not in st.session_state['file']):

            st.session_state['file'].add(uploaded_file.name)
            file_bytes = uploaded_file.read()
            file_type = uploaded_file.type
            if 'image' in file_type:
                st.image(file_bytes, caption='Uploaded Image')
            elif 'video' in file_type:
                st.video(file_bytes, caption='Uploaded Video')
            elif 'audio' in file_type:
                st.audio(file_bytes, caption='Uploaded Audio')
            # Save the file to a temporary location and get the path

            postfix = uploaded_file.name.split('.')[-1]
            # prefix = str(uuid.uuid4())
            prefix = hashlib.md5(file_bytes).hexdigest()
            filename = f'{prefix}.{postfix}'
            file_path = os.path.join(root_dir, filename)
            with open(file_path, 'wb') as tmpfile:
                tmpfile.write(file_bytes)
            file_size = os.stat(file_path).st_size / 1024 / 1024
            file_size = f'{round(file_size, 2)} MB'
            # st.write(f'File saved at: {file_path}')
            user_input = [
                dict(role='user', content=user_input),
                dict(
                    role='user',
                    content=json.dumps(dict(path=file_path, size=file_size)),
                    name='file')
            ]
        if isinstance(user_input, str):
            user_input = [dict(role='user', content=user_input)]
        st.session_state['last_status'] = AgentStatusCode.SESSION_READY
        for agent_return in st.session_state['chatbot'].stream_chat(
                st.session_state['session_history'] + user_input):
            if agent_return.state == AgentStatusCode.PLUGIN_RETURN:
                with st.container():
                    st.session_state['ui'].render_plugin_args(
                        agent_return.actions[-1])
                    st.session_state['ui'].render_action_results(
                        agent_return.actions[-1])
            elif agent_return.state == AgentStatusCode.CODE_RETURN:
                with st.container():
                    st.session_state['ui'].render_action_results(
                        agent_return.actions[-1])
            elif (agent_return.state == AgentStatusCode.STREAM_ING
                  or agent_return.state == AgentStatusCode.CODING):
                # st.markdown(agent_return.response)
                # Clear the placeholder's current content and show the new one
                with st.container():
                    if agent_return.state != st.session_state['last_status']:
                        st.session_state['temp'] = ''
                        placeholder = st.empty()
                        st.session_state['placeholder'] = placeholder
                    if isinstance(agent_return.response, dict):
                        action = f"\n\n {agent_return.response['name']}: \n\n"
                        action_input = agent_return.response['parameters']
                        if agent_return.response[
                                'name'] == 'IPythonInterpreter':
                            action_input = action_input['command']
                        response = action + action_input
                    else:
                        response = agent_return.response
                    st.session_state['temp'] = response
                    st.session_state['placeholder'].markdown(
                        st.session_state['temp'])
            elif agent_return.state == AgentStatusCode.END:
                st.session_state['session_history'] += (
                    user_input + agent_return.inner_steps)
                agent_return = copy.deepcopy(agent_return)
                agent_return.response = st.session_state['temp']
                st.session_state['assistant'].append(
                    copy.deepcopy(agent_return))
            st.session_state['last_status'] = agent_return.state


if __name__ == '__main__':
    root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    root_dir = os.path.join(root_dir, 'tmp_dir')
    os.makedirs(root_dir, exist_ok=True)
    main()
63
examples/model_cli_demo.py
Normal file
@@ -0,0 +1,63 @@
from argparse import ArgumentParser

from lagent.llms import HFTransformer
from lagent.llms.meta_template import INTERNLM2_META as META


def parse_args():
    parser = ArgumentParser(description='chatbot')
    parser.add_argument(
        '--path',
        type=str,
        default='internlm/internlm2-chat-20b',
        help='The path to the model')
    parser.add_argument(
        '--mode',
        type=str,
        default='chat',
        help='Completion through chat or generate')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    # Initialize the HFTransformer-based Language Model (llm)
    model = HFTransformer(
        path=args.path,
        meta_template=META,
        max_new_tokens=1024,
        top_p=0.8,
        top_k=None,
        temperature=0.1,
        repetition_penalty=1.0,
        stop_words=['<|im_end|>'])

    def input_prompt():
        print('\ndouble enter to end input >>> ', end='', flush=True)
        sentinel = ''  # ends when this string is seen
        return '\n'.join(iter(input, sentinel))

    history = []
    while True:
        try:
            prompt = input_prompt()
        except UnicodeDecodeError:
            print('UnicodeDecodeError')
            continue
        if prompt == 'exit':
            exit(0)
        history.append(dict(role='user', content=prompt))
        if args.mode == 'generate':
            history = [dict(role='user', content=prompt)]
        print('\nInternLm2:', end='')
        current_length = 0
        for status, response, _ in model.stream_chat(history):
            print(response[current_length:], end='', flush=True)
            current_length = len(response)
        history.append(dict(role='assistant', content=response))
        print('')


if __name__ == '__main__':
    main()
@@ -1,36 +0,0 @@
from lagent.actions.action_executor import ActionExecutor
from lagent.actions.python_interpreter import PythonInterpreter
from lagent.agents.react import ReAct
from lagent.llms.openai import GPTAPI


def input_prompt():
    print('\ndouble enter to end input >>> ', end='')
    sentinel = ''  # ends when this string is seen
    return '\n'.join(iter(input, sentinel))


def main():
    # set OPEN_API_KEY in your environment or directly pass it with key=''
    model = GPTAPI(model_type='gpt-3.5-turbo')

    chatbot = ReAct(
        llm=model,
        action_executor=ActionExecutor(actions=[PythonInterpreter()]),
    )

    while True:
        try:
            prompt = input_prompt()
        except UnicodeDecodeError:
            print('UnicodeDecodeError')
            continue
        if prompt == 'exit':
            exit(0)

        agent_return = chatbot.chat(prompt)
        print(agent_return.response)


if __name__ == '__main__':
    main()
@@ -1,19 +0,0 @@
from lagent.actions import LLMQA, ActionExecutor, GoogleSearch
from lagent.agents import ReWOO
from lagent.llms.openai import GPTAPI

# set OPEN_API_KEY in your environment or directly pass it with key=''
model = GPTAPI(model_type='gpt-3.5-turbo')
# please set the serper search API key
search_tool = GoogleSearch(api_key='SERPER_API_KEY')
llmqa_tool = LLMQA(model)

chatbot = ReWOO(
    llm=model,
    action_executor=ActionExecutor(actions=[llmqa_tool, search_tool]),
)

prompt = 'What profession does Nicholas Ray and Elia Kazan have in common'

agent_return = chatbot.chat(prompt)
print(agent_return.response)
8
lagent/__init__.py
Normal file
@@ -0,0 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .actions import *  # noqa: F401, F403
from .agents import *  # noqa: F401, F403
from .llms import *  # noqa: F401, F403
from .schema import *  # noqa: F401, F403
from .version import __version__, version_info

__all__ = ['__version__', 'version_info']
@@ -1,11 +1,62 @@
from typing import Type

from .action_executor import ActionExecutor
from .base_action import BaseAction
from .arxiv_search import ArxivSearch
from .base_action import TOOL_REGISTRY, BaseAction, tool_api
from .bing_map import BINGMap
from .builtin_actions import FinishAction, InvalidAction, NoAction
from .google_scholar_search import GoogleScholar
from .google_search import GoogleSearch
from .llm_qa import LLMQA
from .ipython_interactive import IPythonInteractive
from .ipython_interpreter import IPythonInterpreter
from .parser import BaseParser, JsonParser, TupleParser
from .ppt import PPT
from .python_interpreter import PythonInterpreter

__all__ = [
    'BaseAction', 'ActionExecutor', 'InvalidAction', 'NoAction',
    'FinishAction', 'GoogleSearch', 'PythonInterpreter', 'LLMQA'
    'BaseAction', 'ActionExecutor', 'InvalidAction', 'FinishAction',
    'NoAction', 'BINGMap', 'ArxivSearch', 'FinishAction', 'GoogleSearch',
    'GoogleScholar', 'IPythonInterpreter', 'IPythonInteractive',
    'PythonInterpreter', 'PPT', 'BaseParser', 'JsonParser', 'TupleParser',
    'tool_api', 'list_tools', 'get_tool_cls', 'get_tool'
]


def list_tools(with_class: bool = False):
    """List available tools.

    Args:
        with_class (bool): whether to return the action class along
            with its name. Defaults to ``False``.

    Returns:
        list: all action names
    """
    return list(TOOL_REGISTRY.items()) if with_class else list(
        TOOL_REGISTRY.keys())


def get_tool_cls(specifier: str) -> Type[BaseAction]:
    """Get the action class.

    Args:
        specifier (:class:`str`): tool name

    Returns:
        Type[BaseAction]: action class
    """
    return TOOL_REGISTRY.get_class(specifier)


def get_tool(specifier: str, *args, **kwargs) -> BaseAction:
    """Instantiate an action.

    Args:
        specifier (str): tool name
        args: positional arguments passed to the action's ``__init__`` method
        kwargs: keyword arguments passed to the action's ``__init__`` method

    Returns:
        :class:`BaseAction`: action object
    """
    return TOOL_REGISTRY.get(specifier, *args, **kwargs)
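For orientation, the registry helpers above make every registered action discoverable by name. A minimal usage sketch (the tool name and keyword argument are illustrative, assuming the corresponding tool dependencies are installed):

    from lagent.actions import get_tool, list_tools

    print(list_tools())  # names of all registered actions
    # Positional/keyword arguments are forwarded to the tool's __init__
    search = get_tool('ArxivSearch', top_k_results=1)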
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Union
from typing import Dict, List, Union

from lagent.schema import ActionReturn, ActionValidCode
from .base_action import BaseAction
@@ -39,14 +39,20 @@ class ActionExecutor:
        self.no_action = no_action
        self.finish_action = finish_action

    def get_actions_info(self, only_enable: bool = True) -> Dict:
        if only_enable:
            return {
                k: v.description
                for k, v in self.actions.items() if v.enable
            }
        else:
            return {k: v.description for k, v in self.actions.items()}
    def get_actions_info(self) -> List[Dict]:
        actions = []
        for action_name, action in self.actions.items():
            if not action.enable:
                continue
            if action.is_toolkit:
                for api in action.description['api_list']:
                    api_desc = api.copy()
                    api_desc['name'] = f"{action_name}.{api_desc['name']}"
                    actions.append(api_desc)
            else:
                action_desc = action.description.copy()
                actions.append(action_desc)
        return actions

    def is_valid(self, name: str):
        return name in self.actions and self.actions[name].enable
@@ -66,19 +72,17 @@ class ActionExecutor:
        if name in self.actions:
            del self.actions[name]

    def __call__(self, name: str, command: Any) -> ActionReturn:
        if isinstance(command, str):
            args, kwargs = (command, ), {}
        else:
            args, kwargs = (), command
        if not self.is_valid(name):
    def __call__(self, name: str, command: str) -> ActionReturn:
        action_name, api_name = (
            name.split('.') if '.' in name else (name, 'run'))
        if not self.is_valid(action_name):
            if name == self.no_action.name:
                action_return = self.no_action.run(*args, **kwargs)
                action_return = self.no_action(command)
            elif name == self.finish_action.name:
                action_return = self.finish_action.run(*args, **kwargs)
                action_return = self.finish_action(command)
            else:
                action_return = self.invalid_action(*args, **kwargs)
                action_return = self.invalid_action(command)
        else:
            action_return = self.actions[name].run(*args, **kwargs)
            action_return = self.actions[action_name](command, api_name)
        action_return.valid = ActionValidCode.OPEN
        return action_return
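The reworked `__call__` above routes a dotted name to a toolkit API and a bare name to `run`. A rough sketch of the resulting call pattern (the tool and query are placeholders, not part of this diff):

    executor = ActionExecutor(actions=[ArxivSearch()])
    # 'ArxivSearch.get_arxiv_article_information' selects the toolkit API;
    # a bare 'ArxivSearch' would dispatch to its 'run' method instead.
    ret = executor('ArxivSearch.get_arxiv_article_information',
                   '{"query": "large language model agents"}')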
56
lagent/actions/arxiv_search.py
Normal file
@@ -0,0 +1,56 @@
from typing import Optional, Type

from lagent.actions.base_action import BaseAction, tool_api
from lagent.actions.parser import BaseParser, JsonParser
from lagent.schema import ActionReturn, ActionStatusCode


class ArxivSearch(BaseAction):
    """Search information from Arxiv.org. \
Useful for when you need to answer questions about Physics, Mathematics, \
Computer Science, Quantitative Biology, Quantitative Finance, Statistics, \
Electrical Engineering, and Economics from scientific articles on arxiv.org.
    """

    def __init__(self,
                 top_k_results: int = 3,
                 max_query_len: int = 300,
                 doc_content_chars_max: int = 1500,
                 description: Optional[dict] = None,
                 parser: Type[BaseParser] = JsonParser,
                 enable: bool = True):
        super().__init__(description, parser, enable)
        self.top_k_results = top_k_results
        self.max_query_len = max_query_len
        self.doc_content_chars_max = doc_content_chars_max

    @tool_api(explode_return=True)
    def get_arxiv_article_information(self, query: str) -> dict:
        """Run Arxiv search and get the article meta information.

        Args:
            query (:class:`str`): the content of search query

        Returns:
            :class:`dict`: article information
                * content (str): a list of 3 arxiv search papers
        """
        import arxiv

        try:
            results = arxiv.Search(  # type: ignore
                query[:self.max_query_len],
                max_results=self.top_k_results).results()
        except Exception as exc:
            return ActionReturn(
                errmsg=f'Arxiv exception: {exc}',
                state=ActionStatusCode.HTTP_ERROR)
        docs = [
            f'Published: {result.updated.date()}\nTitle: {result.title}\n'
            f'Authors: {", ".join(a.name for a in result.authors)}\n'
            f'Summary: {result.summary[:self.doc_content_chars_max]}'
            for result in results
        ]
        if docs:
            return {'content': '\n\n'.join(docs)}
        return {'content': 'No good Arxiv Result was found'}
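Assuming the optional `arxiv` dependency is installed, the new action can be exercised through the JSON call interface that the parser-based `BaseAction.__call__` (shown later in this diff) provides; the query string below is just an example:

    action = ArxivSearch(top_k_results=1)
    ret = action('{"query": "chain-of-thought prompting"}',
                 'get_arxiv_article_information')
    print(ret.result)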
@@ -1,57 +1,385 @@
from typing import Optional
import inspect
import logging
import re
from abc import ABCMeta
from copy import deepcopy
from functools import wraps
from typing import Callable, Optional, Type, get_args, get_origin

from lagent.schema import ActionReturn
try:
    from typing import Annotated
except ImportError:
    from typing_extensions import Annotated

from class_registry import AutoRegister, ClassRegistry
from griffe import Docstring
from griffe.enumerations import DocstringSectionKind

from ..schema import ActionReturn, ActionStatusCode
from .parser import BaseParser, JsonParser, ParseError

logging.getLogger('griffe').setLevel(logging.ERROR)

TOOL_REGISTRY = ClassRegistry('__tool_name__', unique=True)


class BaseAction:
def tool_api(func: Optional[Callable] = None,
             *,
             explode_return: bool = False,
             returns_named_value: bool = False,
             **kwargs):
    """Turn functions into tools. It will parse typehints as well as docstrings
    to build the tool description and attach it to functions via an attribute
    ``api_description``.

    Examples:

        .. code-block:: python

            # typehints have higher priority than docstrings
            from typing import Annotated

            @tool_api
            def add(a: Annotated[int, 'augend'], b: Annotated[int, 'addend'] = 1):
                '''Add operation

                Args:
                    x (int): a
                    y (int): b
                '''
                return a + b

            print(add.api_description)

    Args:
        func (Optional[Callable]): function to decorate. Defaults to ``None``.
        explode_return (bool): whether to flatten the dictionary or tuple return
            as the ``return_data`` field. When enabled, it is recommended to
            annotate the member in docstrings. Defaults to ``False``.

            .. code-block:: python

                @tool_api(explode_return=True)
                def foo(a, b):
                    '''A simple function

                    Args:
                        a (int): a
                        b (int): b

                    Returns:
                        dict: information of inputs
                            * x: value of a
                            * y: value of b
                    '''
                    return {'x': a, 'y': b}

                print(foo.api_description)

        returns_named_value (bool): whether to parse ``thing: Description`` in
            returns sections as a name and description, rather than a type and
            description. When true, type must be wrapped in parentheses:
            ``(int): Description``. When false, parentheses are optional but
            the items cannot be named: ``int: Description``. Defaults to ``False``.

    Returns:
        Callable: wrapped function or partial decorator

    Important:
        ``return_data`` field will be added to ``api_description`` only
        when ``explode_return`` or ``returns_named_value`` is enabled.
    """

    def _detect_type(string):
        field_type = 'STRING'
        if 'list' in string:
            field_type = 'Array'
        elif 'str' not in string:
            if 'float' in string:
                field_type = 'FLOAT'
            elif 'int' in string:
                field_type = 'NUMBER'
            elif 'bool' in string:
                field_type = 'BOOLEAN'
        return field_type

    def _explode(desc):
        kvs = []
        desc = '\nArgs:\n' + '\n'.join([
            '    ' + item.lstrip(' -+*#.')
            for item in desc.split('\n')[1:] if item.strip()
        ])
        docs = Docstring(desc).parse('google')
        if not docs:
            return kvs
        if docs[0].kind is DocstringSectionKind.parameters:
            for d in docs[0].value:
                d = d.as_dict()
                if not d['annotation']:
                    d.pop('annotation')
                else:
                    d['type'] = _detect_type(d.pop('annotation').lower())
                kvs.append(d)
        return kvs

    def _parse_tool(function):
        # remove rst syntax
        docs = Docstring(
            re.sub(':(.+?):`(.+?)`', '\\2', function.__doc__ or '')).parse(
                'google', returns_named_value=returns_named_value, **kwargs)
        desc = dict(
            name=function.__name__,
            description=docs[0].value
            if docs[0].kind is DocstringSectionKind.text else '',
            parameters=[],
            required=[],
        )
        args_doc, returns_doc = {}, []
        for doc in docs:
            if doc.kind is DocstringSectionKind.parameters:
                for d in doc.value:
                    d = d.as_dict()
                    d['type'] = _detect_type(d.pop('annotation').lower())
                    args_doc[d['name']] = d
            if doc.kind is DocstringSectionKind.returns:
                for d in doc.value:
                    d = d.as_dict()
                    if not d['name']:
                        d.pop('name')
                    if not d['annotation']:
                        d.pop('annotation')
                    else:
                        d['type'] = _detect_type(d.pop('annotation').lower())
                    returns_doc.append(d)

        sig = inspect.signature(function)
        for name, param in sig.parameters.items():
            if name == 'self':
                continue
            parameter = dict(
                name=param.name,
                type='STRING',
                description=args_doc.get(param.name,
                                         {}).get('description', ''))
            annotation = param.annotation
            if annotation is inspect.Signature.empty:
                parameter['type'] = args_doc.get(param.name,
                                                 {}).get('type', 'STRING')
            else:
                if get_origin(annotation) is Annotated:
                    annotation, info = get_args(annotation)
                    if info:
                        parameter['description'] = info
                while get_origin(annotation):
                    annotation = get_args(annotation)
                parameter['type'] = _detect_type(str(annotation))
            desc['parameters'].append(parameter)
            if param.default is inspect.Signature.empty:
                desc['required'].append(param.name)

        return_data = []
        if explode_return:
            return_data = _explode(returns_doc[0]['description'])
        elif returns_named_value:
            return_data = returns_doc
        if return_data:
            desc['return_data'] = return_data
        return desc

    if callable(func):

        @wraps(func)
        def wrapper(self, *args, **kwargs):
            return func(self, *args, **kwargs)

        wrapper.api_description = _parse_tool(func)
        return wrapper

    def decorate(func):

        @wraps(func)
        def wrapper(self, *args, **kwargs):
            return func(self, *args, **kwargs)

        wrapper.api_description = _parse_tool(func)
        return wrapper

    return decorate


class ToolMeta(ABCMeta):
    """Metaclass of tools."""

    def __new__(mcs, name, base, attrs):
        is_toolkit, tool_desc = True, dict(
            name=attrs.setdefault('__tool_name__', name),
            description=Docstring(attrs.get('__doc__',
                                            '')).parse('google')[0].value)
        for key, value in attrs.items():
            if callable(value) and hasattr(value, 'api_description'):
                api_desc = getattr(value, 'api_description')
                if key == 'run':
                    tool_desc['parameters'] = api_desc['parameters']
                    tool_desc['required'] = api_desc['required']
                    if api_desc['description']:
                        tool_desc['description'] = api_desc['description']
                    if api_desc.get('return_data'):
                        tool_desc['return_data'] = api_desc['return_data']
                    is_toolkit = False
                else:
                    tool_desc.setdefault('api_list', []).append(api_desc)
        if not is_toolkit and 'api_list' in tool_desc:
            raise KeyError('`run` and other tool APIs can not be implemented '
                           'at the same time')
        if is_toolkit and 'api_list' not in tool_desc:
            is_toolkit = False
            if callable(attrs.get('run')):
                run_api = tool_api(attrs['run'])
                api_desc = run_api.api_description
                tool_desc['parameters'] = api_desc['parameters']
                tool_desc['required'] = api_desc['required']
                if api_desc['description']:
                    tool_desc['description'] = api_desc['description']
                if api_desc.get('return_data'):
                    tool_desc['return_data'] = api_desc['return_data']
                attrs['run'] = run_api
            else:
                tool_desc['parameters'], tool_desc['required'] = [], []
        attrs['_is_toolkit'] = is_toolkit
        attrs['__tool_description__'] = tool_desc
        return super().__new__(mcs, name, base, attrs)


class BaseAction(metaclass=AutoRegister(TOOL_REGISTRY, ToolMeta)):
    """Base class for all actions.

    Args:
        description (str, optional): The description of the action. Defaults to
            None.
        name (str, optional): The name of the action. If None, the name will
            be the class name. Defaults to None.
        enable (bool, optional): Whether the action is enabled. Defaults to
            True.
        disable_description (str, optional): The description of the action when
            it is disabled. Defaults to None.
        description (:class:`Optional[dict]`): The description of the action.
            Defaults to ``None``.
        parser (:class:`Type[BaseParser]`): The parser class to process the
            action's inputs and outputs. Defaults to :class:`JsonParser`.
        enable (:class:`bool`): Whether the action is enabled. Defaults to
            ``True``.

    Examples:

        * simple tool

        .. code-block:: python

            class Bold(BaseAction):
                '''Make text bold'''

                def run(self, text: str):
                    '''
                    Args:
                        text (str): input text

                    Returns:
                        str: bold text
                    '''
                    return '**' + text + '**'

            action = Bold()

        * toolkit with multiple APIs

        .. code-block:: python

            class Calculator(BaseAction):
                '''Calculator'''

                @tool_api
                def add(self, a, b):
                    '''Add operation

                    Args:
                        a (int): augend
                        b (int): addend

                    Returns:
                        int: sum
                    '''
                    return a + b

                @tool_api
                def sub(self, a, b):
                    '''Subtraction operation

                    Args:
                        a (int): minuend
                        b (int): subtrahend

                    Returns:
                        int: difference
                    '''
                    return a - b

            action = Calculator()
    """

    def __init__(self,
                 description: Optional[str] = None,
                 name: Optional[str] = None,
                 enable: bool = True,
                 disable_description: Optional[str] = None) -> None:
        if name is None:
            name = self.__class__.__name__
        self._name = name
        self._description = description
        self._disable_description = disable_description
                 description: Optional[dict] = None,
                 parser: Type[BaseParser] = JsonParser,
                 enable: bool = True):
        self._description = deepcopy(description or self.__tool_description__)
        self._name = self._description['name']
        self._parser = parser(self)
        self._enable = enable

    def __call__(self, *args, **kwargs) -> ActionReturn:
        raise NotImplementedError

    def __repr__(self):
        return f'{self.name}:{self.description}'

    def __str__(self):
        return self.__repr__()

    def run(self, *args, **kwargs) -> ActionReturn:
        return self.__call__(*args, **kwargs)

    @property
    def enable(self):
        return self._enable
    def __call__(self, inputs: str, name='run') -> ActionReturn:
        fallback_args = {'inputs': inputs, 'name': name}
        if not hasattr(self, name):
            return ActionReturn(
                fallback_args,
                type=self.name,
                errmsg=f'invalid API: {name}',
                state=ActionStatusCode.API_ERROR)
        try:
            inputs = self._parser.parse_inputs(inputs, name)
        except ParseError as exc:
            return ActionReturn(
                fallback_args,
                type=self.name,
                errmsg=exc.err_msg,
                state=ActionStatusCode.ARGS_ERROR)
        try:
            outputs = getattr(self, name)(**inputs)
        except Exception as exc:
            return ActionReturn(
                inputs,
                type=self.name,
                errmsg=str(exc),
                state=ActionStatusCode.API_ERROR)
        if isinstance(outputs, ActionReturn):
            action_return = outputs
            if not action_return.args:
                action_return.args = inputs
            if not action_return.type:
                action_return.type = self.name
        else:
            result = self._parser.parse_outputs(outputs)
            action_return = ActionReturn(inputs, type=self.name, result=result)
        return action_return

    @property
    def name(self):
        return self._name

    @property
    def description(self):
        if self.enable:
            return self._description
        else:
            return self._disable_description
    def enable(self):
        return self._enable

    @property
    def is_toolkit(self):
        return self._is_toolkit

    @property
    def description(self) -> dict:
        """Description of the tool."""
        return self._description

    def __repr__(self):
        return f'{self.description}'

    __str__ = __repr__
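Since `BaseAction` is now created through `AutoRegister(TOOL_REGISTRY, ToolMeta)`, merely defining a subclass registers it. A small sketch under that assumption (the `Echo` tool is hypothetical, not part of this diff):

    class Echo(BaseAction):
        """Echo the input text."""

        @tool_api
        def run(self, text: str) -> dict:
            """Return the text unchanged.

            Args:
                text (str): input text
            """
            return {'text': text}

    # Defining the subclass is enough to register it:
    assert 'Echo' in list(TOOL_REGISTRY.keys())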
144
lagent/actions/bing_map.py
Normal file
@@ -0,0 +1,144 @@
# flake8: noqa: E501
import json
import os
from typing import Optional, Type

import requests

from lagent.actions.base_action import BaseAction, tool_api
from lagent.actions.parser import BaseParser, JsonParser


class BINGMap(BaseAction):
    """BING Map plugin for looking up map information."""

    def __init__(self,
                 key: Optional[str] = None,
                 description: Optional[dict] = None,
                 parser: Type[BaseParser] = JsonParser,
                 enable: bool = True) -> None:
        super().__init__(description, parser, enable)
        key = os.environ.get('BING_MAP_KEY', key)
        if key is None:
            raise ValueError(
                'Please set BING Map API key either in the environment '
                'as BING_MAP_KEY or pass it as `key` parameter.')
        self.key = key
        self.base_url = 'http://dev.virtualearth.net/REST/V1/'

    @tool_api(explode_return=True)
    def get_distance(self, start: str, end: str) -> dict:
        """Get the distance between two locations in km.

        Args:
            start (:class:`str`): The start location
            end (:class:`str`): The end location

        Returns:
            :class:`dict`: distance information
                * distance (str): the distance in km.
        """
        # Request URL
        url = self.base_url + 'Routes/Driving?o=json&wp.0=' + start + '&wp.1=' + end + '&key=' + self.key
        # GET request
        r = requests.get(url)
        # TODO check request status?
        data = json.loads(r.text)
        # Extract route information
        route = data['resourceSets'][0]['resources'][0]
        # Extract distance in kilometers
        distance = route['travelDistance']
        return dict(distance=distance)

    @tool_api(explode_return=True)
    def get_route(self, start: str, end: str) -> dict:
        """Get the route between two locations in km.

        Args:
            start (:class:`str`): The start location
            end (:class:`str`): The end location

        Returns:
            :class:`dict`: route information
                * route (list): the route, a list of actions.
        """
        # Request URL
        url = self.base_url + 'Routes/Driving?o=json&wp.0=' + start + '&wp.1=' + end + '&key=' + self.key
        # GET request
        r = requests.get(url)
        data = json.loads(r.text)
        # Extract route information
        route = data['resourceSets'][0]['resources'][0]
        itinerary = route['routeLegs'][0]['itineraryItems']
        # Extract route text information
        route_text = []
        for item in itinerary:
            if 'instruction' in item:
                route_text.append(item['instruction']['text'])
        return dict(route=route_text)

    @tool_api(explode_return=True)
    def get_coordinates(self, location: str) -> dict:
        """Get the coordinates of a location.

        Args:
            location (:class:`str`): the location need to get coordinates.

        Returns:
            :class:`dict`: coordinates information
                * latitude (float): the latitude of the location.
                * longitude (float): the longitude of the location.
        """
        url = self.base_url + 'Locations'
        params = {'query': location, 'key': self.key}
        response = requests.get(url, params=params)
        json_data = response.json()
        coordinates = json_data['resourceSets'][0]['resources'][0]['point'][
            'coordinates']
        return dict(latitude=coordinates[0], longitude=coordinates[1])

    @tool_api(explode_return=True)
    def search_nearby(self,
                      search_term: str,
                      places: str = 'unknown',
                      latitude: float = 0.0,
                      longitude: float = 0.0,
                      radius: int = 5000) -> dict:
        """Search for places nearby a location, within a given radius, and return the results as a list. You can use either the places name or the latitude and longitude.

        Args:
            search_term (:class:`str`): the place name.
            places (:class:`str`): the name of the location. Defaults to ``'unknown'``.
            latitude (:class:`float`): the latitude of the location. Defaults to ``0.0``.
            longitude (:class:`float`): the longitude of the location. Defaults to ``0.0``.
            radius (:class:`int`): radius in meters. Defaults to ``5000``.

        Returns:
            :class:`dict`: places information
                * places (list): the list of places, each place is a dict with name and address, at most 5 places.
        """
        url = self.base_url + 'LocalSearch'
        if places != 'unknown':
            pos = self.get_coordinates(**{'location': places})
            latitude, longitude = pos[1]['latitude'], pos[1]['longitude']
        # Build the request query string
        params = {
            'query': search_term,
            'userLocation': f'{latitude},{longitude}',
            'radius': radius,
            'key': self.key
        }
        # Make the request
        response = requests.get(url, params=params)
        # Parse the response
        response_data = json.loads(response.content)
        # Get the results
        results = response_data['resourceSets'][0]['resources']
        addresses = []
        for result in results:
            name = result['name']
            address = result['Address']['formattedAddress']
            addresses.append(dict(name=name, address=address))
            if len(addresses) == 5:
                break
        return dict(place=addresses)
@@ -1,6 +1,7 @@
from typing import Optional

from lagent.actions.base_action import BaseAction
from lagent.actions.base_action import BaseAction, tool_api
from lagent.actions.parser import BaseParser
from lagent.schema import ActionReturn, ActionStatusCode, ActionValidCode


@@ -19,12 +20,13 @@ class InvalidAction(BaseAction):
    def __init__(self,
                 err_msg:
                 str = 'The action is invalid, please check the action name.',
                 **kwargs) -> None:

        super().__init__(enable=False, **kwargs)
                 description: Optional[dict] = None,
                 parser=BaseParser) -> None:
        super().__init__(description, parser, enable=False)
        self._err_msg = err_msg

    def __call__(self, err_msg: Optional[str] = None):
    @tool_api
    def run(self, err_msg: Optional[str] = None) -> ActionReturn:
        """Return the error message.

        Args:
@@ -35,7 +37,7 @@ class InvalidAction(BaseAction):
        action_return = ActionReturn(
            url=None,
            args=dict(text=err_msg),
            errmsg=err_msg if err_msg else self._err_msg,
            errmsg=err_msg or self._err_msg,
            type=self.name,
            valid=ActionValidCode.INVALID,
            state=ActionStatusCode.API_ERROR)
@@ -51,12 +53,15 @@ class NoAction(BaseAction):
        'Please follow the format'.
    """

    def __init__(self, err_msg: str = 'Please follow the format', **kwargs):

        super().__init__(enable=False, **kwargs)
    def __init__(self,
                 err_msg: str = 'Please follow the format',
                 description: Optional[dict] = None,
                 parser=BaseParser):
        super().__init__(description, parser, enable=False)
        self._err_msg = err_msg

    def __call__(self, err_msg: Optional[str] = None):
    @tool_api
    def run(self, err_msg: Optional[str] = None) -> ActionReturn:
        """Return the error message.

        Args:
@@ -71,7 +76,7 @@ class NoAction(BaseAction):
            url=None,
            args=dict(text=err_msg),
            type=self.name,
            errmsg=err_msg if err_msg else self._err_msg,
            errmsg=err_msg or self._err_msg,
            valid=ActionValidCode.INVALID,
            state=ActionStatusCode.API_ERROR)
        return action_return
@@ -81,7 +86,11 @@ class FinishAction(BaseAction):
    """This is a finish action class, which is used to return the final
    result."""

    def __call__(self, response: str) -> ActionReturn:
    def __init__(self, description: Optional[dict] = None, parser=BaseParser):
        super().__init__(description, parser, enable=True)

    @tool_api
    def run(self, response: str) -> ActionReturn:
        """Return the final result.

        Args:
@@ -93,7 +102,7 @@ class FinishAction(BaseAction):
        action_return = ActionReturn(
            url=None,
            args=dict(text=response),
            result=dict(text=response),
            result=[dict(type='text', content=response)],
            type=self.name,
            valid=ActionValidCode.FINISH,
            state=ActionStatusCode.SUCCESS)
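For context, a minimal sketch of driving the reworked `FinishAction` through its new `run` API (the answer string is arbitrary; the printed shape follows the `result` list introduced above):

    finish = FinishAction()
    ret = finish.run('The final answer is 42.')
    # result now carries a typed list rather than a bare dict:
    print(ret.result)  # [{'type': 'text', 'content': 'The final answer is 42.'}]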
270
lagent/actions/google_scholar_search.py
Normal file
@@ -0,0 +1,270 @@
# flake8: noqa: E501
import os
from typing import Optional, Type

from lagent.actions.base_action import BaseAction, tool_api
from lagent.schema import ActionReturn, ActionStatusCode
from .parser import BaseParser, JsonParser


class GoogleScholar(BaseAction):
    """Plugin for google scholar search.

    Args:
        api_key (str): API KEY to use serper google search API,
            You can create a free API key at https://serper.dev.
        description (dict): The description of the action. Defaults to ``None``.
        parser (Type[BaseParser]): The parser class to process the
            action's inputs and outputs. Defaults to :class:`JsonParser`.
        enable (bool, optional): Whether the action is enabled. Defaults to
            True.
    """

    def __init__(self,
                 api_key: Optional[str] = None,
                 description: Optional[dict] = None,
                 parser: Type[BaseParser] = JsonParser,
                 enable: bool = True):
        super().__init__(description, parser, enable)
        api_key = os.environ.get('SERPER_API_KEY', api_key)
        if api_key is None:
            raise ValueError(
                'Please set Serper API key either in the environment '
                'as SERPER_API_KEY or pass it as `api_key` parameter.')
        self.api_key = api_key

    @tool_api(explode_return=True)
    def search_google_scholar(
        self,
        query: str,
        cites: Optional[str] = None,
        as_ylo: Optional[int] = None,
        as_yhi: Optional[int] = None,
        scisbd: Optional[int] = None,
        cluster: Optional[str] = None,
        hl: Optional[str] = None,
        lr: Optional[str] = None,
        start: Optional[int] = None,
        num: Optional[int] = None,
        as_sdt: Optional[str] = None,
        safe: Optional[str] = None,
        filter: Optional[str] = None,
        as_vis: Optional[str] = None,
    ) -> dict:
        """Search for scholarly articles based on a query according to the google scholar.

        Args:
            query (str): The query to search for.
            cites (Optional[str]): The unique ID of an article for triggering "Cited By" searches.
            as_ylo (Optional[int]): The starting year for results (e.g., if as_ylo=2018, results before this year will be omitted).
            as_yhi (Optional[int]): The ending year for results (e.g., if as_yhi=2018, results after this year will be omitted).
            scisbd (Optional[int]): Defines articles added in the last year, sorted by date. It can be set to 1 to include only abstracts, or 2 to include everything.
            cluster (Optional[str]): The unique ID of an article for triggering "All Versions" searches.
            hl (Optional[str]): The language to use for the Google Scholar search.
            lr (Optional[str]): One or multiple languages to limit the search to.
            start (Optional[int]): The result offset for pagination (0 is the first page of results, 10 is the 2nd page, etc.)
            num (Optional[int]): The maximum number of results to return, limited to 20.
            as_sdt (Optional[str]): Can be used either as a search type or a filter.
            safe (Optional[str]): The level of filtering for adult content.
            filter (Optional[str]): Defines if the filters for 'Similar Results' and 'Omitted Results' are on or off.
            as_vis (Optional[str]): Defines whether to include citations or not.

        Returns:
            :class:`dict`: article information
                - title: a list of the titles of the three selected papers
                - cited_by: a list of the citation numbers of the three selected papers
                - organic_id: a list of the organic results' ids of the three selected papers
                - pub_info: publication information of selected papers
        """
        from serpapi import GoogleSearch
        params = {
            'q': query,
            'engine': 'google_scholar',
            'api_key': self.api_key,
            'cites': cites,
            'as_ylo': as_ylo,
            'as_yhi': as_yhi,
            'scisbd': scisbd,
            'cluster': cluster,
            'hl': hl,
            'lr': lr,
            'start': start,
            'num': num,
            'as_sdt': as_sdt,
            'safe': safe,
            'filter': filter,
            'as_vis': as_vis
        }
        search = GoogleSearch(params)
        try:
            r = search.get_dict()
            results = r['organic_results']
            title = []
            snippets = []
            cited_by = []
            organic_id = []
            pub_info = []
            for item in results[:3]:
                title.append(item['title'])
                pub_info.append(item['publication_info']['summary'])
                citation = item['inline_links'].get('cited_by', {'total': ''})
                cited_by.append(citation['total'])
                snippets.append(item['snippet'])
                organic_id.append(item['result_id'])
            return dict(
                title=title,
                cited_by=cited_by,
                organic_id=organic_id,
                snippets=snippets)
        except Exception as e:
            return ActionReturn(
                errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)

    @tool_api(explode_return=True)
    def get_author_information(self,
                               author_id: str,
                               hl: Optional[str] = None,
                               view_op: Optional[str] = None,
                               sort: Optional[str] = None,
                               citation_id: Optional[str] = None,
                               start: Optional[int] = None,
                               num: Optional[int] = None,
                               no_cache: Optional[bool] = None,
                               async_req: Optional[bool] = None,
                               output: Optional[str] = None) -> dict:
        """Search for an author's information by the author's id provided by get_author_id.

        Args:
            author_id (str): Required. The ID of an author.
            hl (Optional[str]): The language to use for the Google Scholar Author search. Default is 'en'.
            view_op (Optional[str]): Used for viewing specific parts of a page.
            sort (Optional[str]): Used for sorting and refining articles.
            citation_id (Optional[str]): Used for retrieving individual article citation.
            start (Optional[int]): Defines the result offset. Default is 0.
            num (Optional[int]): Defines the number of results to return. Default is 20.
            no_cache (Optional[bool]): Forces SerpApi to fetch the results even if a cached version is already present. Default is False.
            async_req (Optional[bool]): Defines the way you want to submit your search to SerpApi. Default is False.
            output (Optional[str]): Defines the final output you want. Default is 'json'.

        Returns:
            :class:`dict`: author information
                * name: author's name
                * affiliation: the affiliation of the author
                * articles: at most 3 articles by the author
                * website: the author's homepage url
        """
        from serpapi import GoogleSearch
        params = {
            'engine': 'google_scholar_author',
            'author_id': author_id,
            'api_key': self.api_key,
            'hl': hl,
            'view_op': view_op,
            'sort': sort,
            'citation_id': citation_id,
            'start': start,
            'num': num,
            'no_cache': no_cache,
            'async': async_req,
            'output': output
        }
        try:
            search = GoogleSearch(params)
            results = search.get_dict()
            author = results['author']
            articles = results.get('articles', [])
            return dict(
                name=author['name'],
                affiliations=author.get('affiliations', ''),
                website=author.get('website', ''),
                articles=[
                    dict(title=article['title'], authors=article['authors'])
                    for article in articles[:3]
                ])
        except Exception as e:
            return ActionReturn(
                errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)

    @tool_api(explode_return=True)
    def get_citation_format(self,
                            q: str,
                            no_cache: Optional[bool] = None,
                            async_: Optional[bool] = None,
                            output: Optional[str] = 'json') -> dict:
        """Function to get MLA citation format by an identification of organic_result's id provided by search_google_scholar.

        Args:
            q (str): ID of an individual Google Scholar organic search result.
            no_cache (Optional[bool]): If set to True, will force SerpApi to fetch the Google Scholar Cite results even if a cached version is already present. Defaults to None.
            async_ (Optional[bool]): If set to True, will submit search to SerpApi and retrieve results later. Defaults to None.
            output (Optional[str]): Final output format. Set to 'json' to get a structured JSON of the results, or 'html' to get the raw html retrieved. Defaults to 'json'.

        Returns:
            :class:`dict`: citation format
                * authors: the authors of the article
                * citation: the citation format of the article
        """
        from serpapi import GoogleSearch
        params = {
            'q': q,
            'engine': 'google_scholar_cite',
            'api_key': self.api_key,
            'no_cache': no_cache,
            'async': async_,
            'output': output
        }
        try:
            search = GoogleSearch(params)
            results = search.get_dict()
            citation = results['citations']
            citation_info = citation[0]['snippet']
            return citation_info
        except Exception as e:
            return ActionReturn(
                errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)

    @tool_api(explode_return=True)
    def get_author_id(self,
                      mauthors: str,
                      hl: Optional[str] = 'en',
                      after_author: Optional[str] = None,
                      before_author: Optional[str] = None,
                      no_cache: Optional[bool] = False,
                      _async: Optional[bool] = False,
                      output: Optional[str] = 'json') -> dict:
        """The getAuthorId function is used to get the author's id by his or her name.

        Args:
            mauthors (str): Defines the author you want to search for.
            hl (Optional[str]): Defines the language to use for the Google Scholar Profiles search. It's a two-letter language code. (e.g., 'en' for English, 'es' for Spanish, or 'fr' for French). Defaults to 'en'.
            after_author (Optional[str]): Defines the next page token. It is used for retrieving the next page results. This parameter has precedence over the before_author parameter. Defaults to None.
            before_author (Optional[str]): Defines the previous page token. It is used for retrieving the previous page results. Defaults to None.
            no_cache (Optional[bool]): Will force SerpApi to fetch the Google Scholar Profiles results even if a cached version is already present. Defaults to False.
            _async (Optional[bool]): Defines the way you want to submit your search to SerpApi. Defaults to False.
            output (Optional[str]): Defines the final output you want. It can be set to 'json' (default) to get a structured JSON of the results, or 'html' to get the raw html retrieved. Defaults to 'json'.

        Returns:
            :class:`dict`: author id
                * author_id: the author_id of the author
        """
        from serpapi import GoogleSearch
        params = {
            'mauthors': mauthors,
            'engine': 'google_scholar_profiles',
            'api_key': self.api_key,
            'hl': hl,
            'after_author': after_author,
            'before_author': before_author,
            'no_cache': no_cache,
            'async': _async,
            'output': output
        }
        try:
            search = GoogleSearch(params)
            results = search.get_dict()
            profile = results['profiles']
            author_info = dict(author_id=profile[0]['author_id'])
            return author_info
        except Exception as e:
            return ActionReturn(
                errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)
@@ -1,15 +1,11 @@
import os
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Tuple, Type, Union

import requests

from lagent.schema import ActionReturn, ActionStatusCode
from .base_action import BaseAction

DEFAULT_DESCRIPTION = """一个可以从谷歌搜索结果的API。
当你需要对于一个特定问题找到简短明了的回答时,可以使用它。
输入应该是一个搜索查询。
"""
from .base_action import BaseAction, tool_api
from .parser import BaseParser, JsonParser


class GoogleSearch(BaseAction):
@@ -25,18 +21,13 @@ class GoogleSearch(BaseAction):
    Args:
        api_key (str): API KEY to use serper google search API,
            You can create a free API key at https://serper.dev.
        timeout (int): Upper bound of waiting time for a serper request.
        timeout (int): Upper bound of waiting time for a serper request.
        search_type (str): Serper API support ['search', 'images', 'news',
            'places'] types of search, currently we only support 'search'.
        k (int): select first k results in the search results as response.
        description (str): The description of the action. Defaults to
            None.
        name (str, optional): The name of the action. If None, the name will
            be the class name. Defaults to None.
        enable (bool, optional): Whether the action is enabled. Defaults to
            True.
        disable_description (str, optional): The description of the action when
            it is disabled. Defaults to None.
        description (dict): The description of the action. Defaults to ``None``.
        parser (Type[BaseParser]): The parser class to process the
            action's inputs and outputs. Defaults to :class:`JsonParser`.
        enable (bool): Whether the action is enabled. Defaults to ``True``.
    """
    result_key_for_type = {
        'news': 'news',
@@ -49,43 +40,36 @@ class GoogleSearch(BaseAction):
                 api_key: Optional[str] = None,
                 timeout: int = 5,
                 search_type: str = 'search',
                 k: int = 10,
                 description: str = DEFAULT_DESCRIPTION,
                 name: Optional[str] = None,
                 enable: bool = True,
                 disable_description: Optional[str] = None) -> None:
        super().__init__(description, name, enable, disable_description)

                 description: Optional[dict] = None,
                 parser: Type[BaseParser] = JsonParser,
                 enable: bool = True):
        super().__init__(description, parser, enable)
        api_key = os.environ.get('SERPER_API_KEY', api_key)
        if api_key is None:
            raise ValueError(
                'Please Set Serper API key either in the environment '
                ' as SERPER_API_KEY or pass it as `api_key` parameter.')
                'Please set Serper API key either in the environment '
                'as SERPER_API_KEY or pass it as `api_key` parameter.')
        self.api_key = api_key
        self.timeout = timeout
        self.search_type = search_type
        self.k = k

    def __call__(self, query: str) -> ActionReturn:
        """Return the search response.
    @tool_api
    def run(self, query: str, k: int = 10) -> ActionReturn:
        """一个可以从谷歌搜索结果的API。当你需要对于一个特定问题找到简短明了的回答时,可以使用它。输入应该是一个搜索查询。

        Args:
            query (str): The search content.

        Returns:
            ActionReturn: The action return.
            query (str): the search content
            k (int): select first k results in the search results as response
        """

        tool_return = ActionReturn(url=None, args=None)
        status_code, response = self._search(
            query, search_type=self.search_type, k=self.k)
        tool_return = ActionReturn(type=self.name)
        status_code, response = self._search(query, k=k)
        # convert search results to ToolReturn format
        if status_code == -1:
            tool_return.errmsg = response
            tool_return.state = ActionStatusCode.HTTP_ERROR
        elif status_code == 200:
            parsed_res = self._parse_results(response)
            tool_return.result = dict(text=str(parsed_res))
            tool_return.result = [dict(type='text', content=str(parsed_res))]
            tool_return.state = ActionStatusCode.SUCCESS
        else:
            tool_return.errmsg = str(status_code)
@@ -139,7 +123,7 @@ class GoogleSearch(BaseAction):

    def _search(self,
                search_term: str,
                search_type: str = 'search',
                search_type: Optional[str] = None,
                **kwargs) -> Tuple[int, Union[dict, str]]:
        """HTTP requests to Serper API.

@@ -166,7 +150,7 @@ class GoogleSearch(BaseAction):
        }
        try:
            response = requests.post(
                f'https://google.serper.dev/{search_type}',
                f'https://google.serper.dev/{search_type or self.search_type}',
                headers=headers,
                params=params,
                timeout=self.timeout)

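A rough sketch of calling the migrated `GoogleSearch` through the JSON interface (requires SERPER_API_KEY in the environment; the query is illustrative and the result shape follows the typed list above):

    search = GoogleSearch()
    ret = search('{"query": "capital of France", "k": 3}')
    if ret.state == ActionStatusCode.SUCCESS:
        print(ret.result[0]['content'])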
188
lagent/actions/ipython_interactive.py
Normal file
@@ -0,0 +1,188 @@
|
||||
import re
|
||||
from contextlib import redirect_stdout
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from io import StringIO
|
||||
from typing import Optional, Type
|
||||
|
||||
from ..schema import ActionReturn, ActionStatusCode
|
||||
from .base_action import BaseAction, tool_api
|
||||
from .parser import BaseParser, JsonParser
|
||||
|
||||
|
||||
class Status(str, Enum):
|
||||
"""Execution status."""
|
||||
SUCCESS = 'success'
|
||||
FAILURE = 'failure'
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExecutionResult:
|
||||
"""Execution result."""
|
||||
status: Status
|
||||
value: Optional[str] = None
|
||||
msg: Optional[str] = None
|
||||
|
||||
|
||||
class IPythonInteractive(BaseAction):
|
||||
"""An interactive IPython shell for code execution.
|
||||
|
||||
Args:
|
||||
timeout (int): Upper bound of waiting time for Python script execution.
|
||||
Defaults to ``20``.
|
||||
max_out_len (int): maximum output length. No truncation occurs if negative.
|
||||
Defaults to ``2048``.
|
||||
use_signals (bool): whether signals should be used for timing function out
|
||||
or the multiprocessing. Set to ``False`` when not running in the main
|
||||
thread, e.g. web applications. Defaults to ``True``
|
||||
description (dict): The description of the action. Defaults to ``None``.
|
||||
parser (Type[BaseParser]): The parser class to process the
|
||||
action's inputs and outputs. Defaults to :class:`JsonParser`.
|
||||
enable (bool): Whether the action is enabled. Defaults to ``True``.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
timeout: int = 30,
|
||||
max_out_len: int = 2048,
|
||||
use_signals: bool = True,
|
||||
description: Optional[dict] = None,
|
||||
parser: Type[BaseParser] = JsonParser,
|
||||
enable: bool = True,
|
||||
):
|
||||
super().__init__(description, parser, enable)
|
||||
from IPython import InteractiveShell
|
||||
self.timeout = timeout
|
||||
self._executor = InteractiveShell()
|
||||
self._highlighting = re.compile(r'\x1b\[\d{,3}(;\d{,3}){,3}m')
|
||||
self._max_out_len = max_out_len if max_out_len >= 0 else None
|
||||
self._use_signals = use_signals
|
||||
|
||||
def reset(self):
|
||||
"""Clear the context."""
|
||||
self._executor.reset()
|
||||
|
||||
@tool_api
|
||||
def run(self, command: str, timeout: Optional[int] = None) -> ActionReturn:
|
||||
"""Launch an IPython Interactive Shell to execute code.
|
||||
|
||||
Args:
|
||||
command (:class:`str`): Python code snippet
|
||||
timeout (:class:`Optional[int]`): timeout for execution.
|
||||
This argument only works in the main thread. Defaults to ``None``.
|
||||
"""
|
||||
from timeout_decorator import timeout as timer
|
||||
tool_return = ActionReturn(args={'text': command}, type=self.name)
|
||||
ret = (
|
||||
timer(timeout or self.timeout)(self.exec)(command)
|
||||
if self._use_signals else self.exec(command))
|
||||
if ret.status is Status.SUCCESS:
|
||||
tool_return.result = [{'type': 'text', 'content': ret.value}]
|
||||
tool_return.state = ActionStatusCode.SUCCESS
|
||||
else:
|
||||
tool_return.errmsg = ret.msg
|
||||
tool_return.state = ActionStatusCode.API_ERROR
|
||||
return tool_return
|
||||
|
||||
def exec(self, code: str) -> ExecutionResult:
|
||||
"""Run Python scripts in IPython shell.
|
||||
|
||||
Args:
|
||||
code (:class:`str`): code block
|
||||
|
||||
Returns:
|
||||
:py:class:`ExecutionResult`: execution result
|
||||
"""
|
||||
with StringIO() as io:
|
||||
with redirect_stdout(io):
|
||||
ret = self._executor.run_cell(self.extract_code(code))
|
||||
result = ret.result
|
||||
if result is not None:
|
||||
return ExecutionResult(Status.SUCCESS,
|
||||
str(result)[:self._max_out_len])
|
||||
outs = io.getvalue().strip().split('\n')
|
||||
if not outs:
|
||||
return ExecutionResult(Status.SUCCESS, '')
|
||||
for i, out in enumerate(outs):
|
||||
if re.search('Error|Traceback', out, re.S):
|
||||
if 'TimeoutError' in out:
|
||||
return ExecutionResult(
|
||||
Status.FAILURE,
|
||||
msg=('The code interpreter encountered '
|
||||
'an unexpected error.'))
|
||||
err_idx = i
|
||||
break
|
||||
else:
|
||||
return ExecutionResult(Status.SUCCESS,
|
||||
outs[-1].strip()[:self._max_out_len])
|
||||
return ExecutionResult(
|
||||
Status.FAILURE,
|
||||
msg=self._highlighting.sub(
|
||||
'', '\n'.join(outs[err_idx:])[:self._max_out_len]),
|
||||
)
|
||||
|
||||
async def async_exec(self, code: str) -> ExecutionResult:
|
||||
"""Asynchronously run Python scripts in IPython shell.
|
||||
|
||||
Args:
|
||||
code (:class:`str`): code block
|
||||
|
||||
Returns:
|
||||
:py:class:`ExecutionResult`: execution result
|
||||
"""
|
||||
with StringIO() as io:
|
||||
with redirect_stdout(io):
|
||||
ret = await self._executor.run_cell_async(
|
||||
self.extract_code(code))
|
||||
result = ret.result
|
||||
if result is not None:
|
||||
return ExecutionResult(Status.SUCCESS,
|
||||
str(result)[:self._max_out_len])
|
||||
outs = io.getvalue().strip().split('\n')
|
||||
if not outs:
|
||||
return ExecutionResult(Status.SUCCESS, '')
|
||||
for i, out in enumerate(outs):
|
||||
if re.search('Error|Traceback', out, re.S):
|
||||
if 'TimeoutError' in out:
|
||||
return ExecutionResult(
|
||||
Status.FAILURE,
|
||||
msg=('The code interpreter encountered an '
|
||||
'unexpected error.'))
|
||||
err_idx = i
|
||||
break
|
||||
else:
|
||||
return ExecutionResult(Status.SUCCESS,
|
||||
outs[-1].strip()[:self._max_out_len])
|
||||
return ExecutionResult(
|
||||
Status.FAILURE,
|
||||
msg=self._highlighting.sub(
|
||||
'', '\n'.join(outs[err_idx:])[:self._max_out_len]),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def extract_code(text: str) -> str:
|
||||
"""Extract Python code from markup languages.
|
||||
|
||||
Args:
|
||||
text (:class:`str`): Markdown-formatted text
|
||||
|
||||
Returns:
|
||||
:class:`str`: Python code
|
||||
"""
|
||||
import json5
|
||||
|
||||
# Match triple backtick blocks first
|
||||
triple_match = re.search(r'```[^\n]*\n(.+?)```', text, re.DOTALL)
|
||||
# Match single backtick blocks second
|
||||
single_match = re.search(r'`([^`]*)`', text, re.DOTALL)
|
||||
if triple_match:
|
||||
text = triple_match.group(1)
|
||||
elif single_match:
|
||||
text = single_match.group(1)
|
||||
else:
|
||||
try:
|
||||
text = json5.loads(text)['code']
|
||||
except Exception:
|
||||
pass
|
||||
# If no code blocks found, return original text
|
||||
return text
|
||||
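As a quick sanity check of the extraction precedence above, this standalone sketch (regexes copied from `extract_code`) shows that a fenced block wins over inline backticks:

import re

text = "Use this:\n```python\nprint('hi')\n```"
triple_match = re.search(r'```[^\n]*\n(.+?)```', text, re.DOTALL)
single_match = re.search(r'`([^`]*)`', text, re.DOTALL)
if triple_match:
    text = triple_match.group(1)
elif single_match:
    text = single_match.group(1)
print(repr(text))  # "print('hi')\n" -- the fenced body, not the inline match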
294
lagent/actions/ipython_interpreter.py
Normal file
@@ -0,0 +1,294 @@
# flake8: noqa: E501
import base64
import io
import logging
import os
import queue
import re
import signal
import sys
import traceback
import uuid
from typing import Optional, Tuple, Type

from lagent.actions.base_action import BaseAction, tool_api
from lagent.actions.parser import BaseParser, JsonParser
from lagent.schema import ActionReturn, ActionStatusCode

START_CODE = """
def input(*args, **kwargs):
    raise NotImplementedError('Python input() function is disabled.')

get_ipython().system = lambda *args: print('Assume we have this package, ! is disabled!')
{}
"""  # noqa


class TimeoutError(Exception):
    pass


class IPythonInterpreter(BaseAction):
    """An IPython executor that can execute Python scripts in a Jupyter manner.

    Args:
        timeout (int): Upper bound of waiting time for Python script execution.
            Defaults to 20.
        user_data_dir (str, optional): Specifies the user data directory for file
            loading. If set to `ENV`, use the `USER_DATA_DIR` environment variable.
            Defaults to `ENV`.
        work_dir (str, optional): Specify which directory to save output images to.
            Defaults to ``'./work_dir/tmp_dir'``.
        description (dict): The description of the action. Defaults to ``None``.
        parser (Type[BaseParser]): The parser class to process the
            action's inputs and outputs. Defaults to :class:`JsonParser`.
        enable (bool, optional): Whether the action is enabled. Defaults to ``True``.
    """

    _KERNEL_CLIENTS = {}

    def __init__(self,
                 timeout: int = 20,
                 user_data_dir: str = 'ENV',
                 work_dir='./work_dir/tmp_dir',
                 description: Optional[dict] = None,
                 parser: Type[BaseParser] = JsonParser,
                 enable: bool = True):
        super().__init__(description, parser, enable)

        self.timeout = timeout
        if user_data_dir == 'ENV':
            user_data_dir = os.environ.get('USER_DATA_DIR', '')

        if user_data_dir:
            user_data_dir = os.path.dirname(user_data_dir)
            user_data_dir = f"import os\nos.chdir('{user_data_dir}')"
        self.user_data_dir = user_data_dir
        self._initialized = False
        self.work_dir = work_dir
        if not os.path.exists(self.work_dir):
            os.makedirs(self.work_dir, exist_ok=True)

    @staticmethod
    def start_kernel():
        from jupyter_client import KernelManager

        # start the kernel and manager
        km = KernelManager()
        km.start_kernel()
        kc = km.client()
        return km, kc

    def initialize(self):
        if self._initialized:
            return
        pid = os.getpid()
        if pid not in self._KERNEL_CLIENTS:
            self._KERNEL_CLIENTS[pid] = self.start_kernel()
        self.kernel_manager, self.kernel_client = self._KERNEL_CLIENTS[pid]
        self._initialized = True
        self._call(START_CODE.format(self.user_data_dir), None)

    def reset(self):
        if not self._initialized:
            self.initialize()
        else:
            code = "get_ipython().run_line_magic('reset', '-f')\n" + \
                START_CODE.format(self.user_data_dir)
            self._call(code, None)

    def _call(self,
              command: str,
              timeout: Optional[int] = None) -> Tuple[str, bool]:
        self.initialize()
        command = extract_code(command)

        # check previous remaining result
        while True:
            try:
                msg = self.kernel_client.get_iopub_msg(timeout=5)
                msg_type = msg['msg_type']
                if msg_type == 'status':
                    if msg['content'].get('execution_state') == 'idle':
                        break
            except queue.Empty:
                # assume no result
                break

        self.kernel_client.execute(command)

        def _inner_call():
            result = ''
            images = []
            succeed = True
            image_idx = 0

            while True:
                text = ''
                image = ''
                finished = False
                msg_type = 'error'
                try:
                    msg = self.kernel_client.get_iopub_msg(timeout=20)
                    msg_type = msg['msg_type']
                    if msg_type == 'status':
                        if msg['content'].get('execution_state') == 'idle':
                            finished = True
                    elif msg_type == 'execute_result':
                        text = msg['content']['data'].get('text/plain', '')
                        if 'image/png' in msg['content']['data']:
                            image_b64 = msg['content']['data']['image/png']
                            image_url = publish_image_to_local(
                                image_b64, self.work_dir)
                            image_idx += 1
                            # markdown image link (format string assumed; lost in extraction)
                            image = '![fig-%d](%s)' % (image_idx, image_url)

                    elif msg_type == 'display_data':
                        if 'image/png' in msg['content']['data']:
                            image_b64 = msg['content']['data']['image/png']
                            image_url = publish_image_to_local(
                                image_b64, self.work_dir)
                            image_idx += 1
                            # markdown image link (format string assumed; lost in extraction)
                            image = '![fig-%d](%s)' % (image_idx, image_url)

                        else:
                            text = msg['content']['data'].get('text/plain', '')
                    elif msg_type == 'stream':
                        msg_type = msg['content']['name']  # stdout, stderr
                        text = msg['content']['text']
                    elif msg_type == 'error':
                        succeed = False
                        text = escape_ansi('\n'.join(
                            msg['content']['traceback']))
                        if 'M6_CODE_INTERPRETER_TIMEOUT' in text:
                            text = f'Timeout. No response after {timeout} seconds.'  # noqa
                except queue.Empty:
                    # stop the current task in case it breaks the next input.
                    self.kernel_manager.interrupt_kernel()
                    succeed = False
                    text = f'Timeout. No response after {timeout} seconds.'
                    finished = True
                except Exception:
                    succeed = False
                    msg = ''.join(traceback.format_exception(*sys.exc_info()))
                    # text = 'The code interpreter encountered an unexpected error.'  # noqa
                    text = msg
                    logging.warning(msg)
                    finished = True
                if text:
                    # result += f'\n\n{msg_type}:\n\n```\n{text}\n```'
                    result += f'{text}'

                if image:
                    images.append(image_url)
                if finished:
                    return succeed, dict(text=result, image=images)

        try:
            if timeout:

                def handler(signum, frame):
                    raise TimeoutError()

                signal.signal(signal.SIGALRM, handler)
                signal.alarm(timeout)
            succeed, result = _inner_call()
        except TimeoutError:
            succeed = False
            text = 'The code interpreter encountered an unexpected error.'
            result = f'\n\nerror:\n\n```\n{text}\n```'
        finally:
            if timeout:
                signal.alarm(0)

        # result = result.strip('\n')
        return succeed, result

    @tool_api
    def run(self, command: str, timeout: Optional[int] = None) -> ActionReturn:
        r"""When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files. Internet access for this session is disabled. Do not make external web requests or API calls as they will fail.

        Args:
            command (:class:`str`): Python code
            timeout (:class:`Optional[int]`): Upper bound of waiting time for Python script execution.
        """
        tool_return = ActionReturn(url=None, args=None, type=self.name)
        tool_return.args = dict(text=command)
        succeed, result = self._call(command, timeout)
        if succeed:
            text = result['text']
            image = result.get('image', [])
            resp = [dict(type='text', content=text)]
            if image:
                resp.extend([dict(type='image', content=im) for im in image])
            tool_return.result = resp
            # tool_return.result = dict(
            #     text=result['text'], image=result.get('image', [])[0])
            tool_return.state = ActionStatusCode.SUCCESS
        else:
            tool_return.errmsg = result.get('text', '') if isinstance(
                result, dict) else result
            tool_return.state = ActionStatusCode.API_ERROR
        return tool_return


def extract_code(text):
    import json5

    # Match triple-backtick blocks first
    triple_match = re.search(r'```[^\n]*\n(.+?)```', text, re.DOTALL)
    # Match single-backtick blocks second
    single_match = re.search(r'`([^`]*)`', text, re.DOTALL)
    if triple_match:
        text = triple_match.group(1)
    elif single_match:
        text = single_match.group(1)
    else:
        try:
            text = json5.loads(text)['code']
        except Exception:
            pass
    # If no code block is found, return the original text
    return text


def escape_ansi(line):
    ansi_escape = re.compile(r'(?:\x1B[@-_]|[\x80-\x9F])[0-?]*[ -/]*[@-~]')
    return ansi_escape.sub('', line)


def publish_image_to_local(image_base64: str, work_dir='./work_dir/tmp_dir'):
    import PIL.Image
    image_file = str(uuid.uuid4()) + '.png'
    local_image_file = os.path.join(work_dir, image_file)

    png_bytes = base64.b64decode(image_base64)
    assert isinstance(png_bytes, bytes)
    bytes_io = io.BytesIO(png_bytes)
    PIL.Image.open(bytes_io).save(local_image_file, 'png')

    return local_image_file


# local test for code interpreter
def get_multiline_input(hint):
    print(hint)
    print('// Press ENTER to make a new line. Press CTRL-D to end input.')
    lines = []
    while True:
        try:
            line = input()
        except EOFError:  # CTRL-D
            break
        lines.append(line)
    print('// Input received.')
    if lines:
        return '\n'.join(lines)
    else:
        return ''


if __name__ == '__main__':
    code_interpreter = IPythonInterpreter()
    while True:
        print(code_interpreter(get_multiline_input('Enter python code:')))
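A minimal local sketch of the `run` API above (assumes `jupyter_client`, `ipykernel`, and lagent are installed; the output shape follows the `ActionReturn` fields set in `run`):

from lagent.actions.ipython_interpreter import IPythonInterpreter

interpreter = IPythonInterpreter(timeout=10, work_dir='./work_dir/tmp_dir')
ret = interpreter.run("import math\nmath.sqrt(2)")
print(ret.state)   # ActionStatusCode.SUCCESS if the kernel replied in time
print(ret.result)  # e.g. [{'type': 'text', 'content': '1.4142135623730951'}]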
@@ -1,56 +0,0 @@
from typing import Optional, Union

from lagent.llms.base_api import BaseAPIModel
from lagent.llms.base_llm import BaseModel
from lagent.schema import ActionReturn, ActionStatusCode
from .base_action import BaseAction

DEFAULT_DESCRIPTION = """一个像你一样的大语言预训练模型,当你需要获得一些常识或简单世界知识时可以问它。
当你很有把握自己直接解决问题时可以优先使用它。输入应该是一个询问语句, 且每个问题尽可能简单。
"""


class LLMQA(BaseAction):
    """An LLM wrapper as a BaseAction type.

    Args:
        llm (BaseModel or BaseAPIModel): an LLM service which can chat.
        description (str): The description of the action. Defaults to
            None.
        name (str, optional): The name of the action. If None, the name will
            be the class name. Defaults to None.
        enable (bool, optional): Whether the action is enabled. Defaults to
            True.
        disable_description (str, optional): The description of the action when
            it is disabled. Defaults to None.
    """

    def __init__(self,
                 llm: Union[BaseModel, BaseAPIModel],
                 description: str = DEFAULT_DESCRIPTION,
                 name: Optional[str] = None,
                 enable: bool = True,
                 disable_description: Optional[str] = None) -> None:
        super().__init__(description, name, enable, disable_description)

        self._llm = llm

    def __call__(self, query: str) -> ActionReturn:
        """Return the QA response.

        Args:
            query (str): The query content.

        Returns:
            ActionReturn: The action return.
        """

        tool_return = ActionReturn(url=None, args=None)
        try:
            response = self._llm.generate_from_template(query, 512)
            tool_return.result = dict(text=str(response))
            tool_return.state = ActionStatusCode.SUCCESS
        except Exception as e:
            tool_return.result = dict(text=str(e))
            tool_return.state = ActionStatusCode.API_ERROR
        return tool_return
143
lagent/actions/parser.py
Normal file
@@ -0,0 +1,143 @@
import json
import re
from ast import literal_eval
from typing import Any, List, Union


class ParseError(Exception):
    """Parsing exception class."""

    def __init__(self, err_msg: str):
        self.err_msg = err_msg


class BaseParser:
    """Base parser to process inputs and outputs of actions.

    Args:
        action (:class:`BaseAction`): action to validate

    Attributes:
        PARAMETER_DESCRIPTION (:class:`str`): declare the input format which
            LLMs should follow when generating arguments for decided tools.
    """

    PARAMETER_DESCRIPTION: str = ''

    def __init__(self, action):
        self.action = action
        self._api2param = {}
        self._api2required = {}
        # perform basic argument validation
        if action.description:
            for api in action.description.get('api_list',
                                              [action.description]):
                name = (f'{action.name}.{api["name"]}'
                        if self.action.is_toolkit else api['name'])
                required_parameters = set(api['required'])
                all_parameters = {j['name'] for j in api['parameters']}
                if not required_parameters.issubset(all_parameters):
                    raise ValueError(
                        f'unknown parameters for function "{name}": '
                        f'{required_parameters - all_parameters}')
                if self.PARAMETER_DESCRIPTION:
                    api['parameter_description'] = self.PARAMETER_DESCRIPTION
                api_name = api['name'] if self.action.is_toolkit else 'run'
                self._api2param[api_name] = api['parameters']
                self._api2required[api_name] = api['required']

    def parse_inputs(self, inputs: str, name: str = 'run') -> dict:
        """Parse inputs LLMs generate for the action.

        Args:
            inputs (:class:`str`): input string extracted from responses

        Returns:
            :class:`dict`: processed input
        """
        inputs = {self._api2param[name][0]['name']: inputs}
        return inputs

    def parse_outputs(self, outputs: Any) -> List[dict]:
        """Parse outputs returned by the action.

        Args:
            outputs (:class:`Any`): raw output of the action

        Returns:
            :class:`List[dict]`: processed output of which each member is a
                dictionary with two keys - 'type' and 'content'.
        """
        if isinstance(outputs, dict):
            outputs = json.dumps(outputs, ensure_ascii=False)
        elif not isinstance(outputs, str):
            outputs = str(outputs)
        return [{'type': 'text', 'content': outputs}]


class JsonParser(BaseParser):
    """JSON parser to convert an input string into a dictionary.

    Args:
        action (:class:`BaseAction`): action to validate
    """

    PARAMETER_DESCRIPTION = (
        'If you call this tool, you must pass arguments in '
        'the JSON format {key: value}, where the key is the parameter name.')

    def parse_inputs(self,
                     inputs: Union[str, dict],
                     name: str = 'run') -> dict:
        if not isinstance(inputs, dict):
            try:
                match = re.search(r'^\s*(```json\n)?(.*)\n```\s*$', inputs,
                                  re.S)
                if match:
                    inputs = match.group(2).strip()
                inputs = json.loads(inputs)
            except json.JSONDecodeError as exc:
                raise ParseError(f'invalid json format: {inputs}') from exc
        input_keys = set(inputs)
        all_keys = {param['name'] for param in self._api2param[name]}
        if not input_keys.issubset(all_keys):
            raise ParseError(f'unknown arguments: {input_keys - all_keys}')
        required_keys = set(self._api2required[name])
        if not input_keys.issuperset(required_keys):
            raise ParseError(
                f'missing required arguments: {required_keys - input_keys}')
        return inputs


class TupleParser(BaseParser):
    """Tuple parser to convert an input string into a tuple.

    Args:
        action (:class:`BaseAction`): action to validate
    """

    PARAMETER_DESCRIPTION = (
        'If you call this tool, you must pass arguments in the tuple format '
        'like (arg1, arg2, arg3), and the arguments are ordered.')

    def parse_inputs(self,
                     inputs: Union[str, tuple],
                     name: str = 'run') -> dict:
        if not isinstance(inputs, tuple):
            try:
                inputs = literal_eval(inputs)
            except Exception as exc:
                raise ParseError(f'invalid tuple format: {inputs}') from exc
        if len(inputs) < len(self._api2required[name]):
            raise ParseError(
                f'API takes {len(self._api2required[name])} required positional '
                f'arguments but {len(inputs)} were given')
        if len(inputs) > len(self._api2param[name]):
            raise ParseError(
                f'API takes {len(self._api2param[name])} positional arguments '
                f'but {len(inputs)} were given')
        inputs = {
            self._api2param[name][i]['name']: item
            for i, item in enumerate(inputs)
        }
        return inputs
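A small sketch of how `JsonParser` validates arguments against an action description (the `DummyAction` class here is illustrative, not part of lagent):

class DummyAction:
    name = 'dummy'
    is_toolkit = False
    description = {
        'name': 'dummy',
        'required': ['query'],
        'parameters': [{'name': 'query', 'type': 'STRING'}],
    }

parser = JsonParser(DummyAction())
print(parser.parse_inputs('{"query": "hi"}'))  # {'query': 'hi'}
# parser.parse_inputs('{"text": "hi"}') raises
# ParseError("unknown arguments: {'text'}") because 'text' is not declared.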
158
lagent/actions/ppt.py
Normal file
@@ -0,0 +1,158 @@
from typing import Dict, Optional, Type

from lagent.actions.base_action import BaseAction, tool_api
from lagent.actions.parser import BaseParser, JsonParser

THEME_MAPPING = {
    'Default': {
        'template': None,
        'title': 'Title Slide',
        'single': 'Title and Content',
        'two': 'Two Content',
    }
}


class PPT(BaseAction):
    """Plugin to create ppt slides with text, paragraphs, and images in good-looking styles."""

    def __init__(self,
                 theme_mapping: Optional[Dict[str, dict]] = None,
                 description: Optional[dict] = None,
                 parser: Type[BaseParser] = JsonParser,
                 enable: bool = True):
        super().__init__(description, parser, enable)
        self.theme_mapping = theme_mapping or THEME_MAPPING
        self.pointer = None
        self.location = None

    @tool_api(explode_return=True)
    def create_file(self, theme: str, abs_location: str) -> dict:
        """Create a pptx file with a specific theme.

        Args:
            theme (:class:`str`): the theme used. The value should be one of ['Default'].
            abs_location (:class:`str`): the ppt file's absolute location

        Returns:
            :class:`dict`: operation status
                * status: the result of the execution
        """
        from pptx import Presentation
        self.location = abs_location
        try:
            self.pointer = Presentation(self.theme_mapping[theme]['template'])
            self.pointer.slide_master.name = theme
            # print('created')
        except Exception as e:
            print(e)
        return dict(status='created a ppt file.')

    @tool_api(explode_return=True)
    def add_first_page(self, title: str, subtitle: str) -> dict:
        """Add the first page of the ppt.

        Args:
            title (:class:`str`): the title of the ppt
            subtitle (:class:`str`): the subtitle of the ppt

        Returns:
            :class:`dict`: operation status
                * status: the result of the execution
        """
        layout_name = self.theme_mapping[
            self.pointer.slide_master.name]['title']
        layout = next(i for i in self.pointer.slide_master.slide_layouts
                      if i.name == layout_name)
        slide = self.pointer.slides.add_slide(layout)
        ph_title, ph_subtitle = slide.placeholders
        ph_title.text = title
        if subtitle:
            ph_subtitle.text = subtitle
        return dict(status='added page')

    @tool_api(explode_return=True)
    def add_text_page(self, title: str, bullet_items: str) -> dict:
        """Add a text page to the ppt.

        Args:
            title (:class:`str`): the title of the page
            bullet_items (:class:`str`): bullet_items should be a string; for multiple bullet items, please use [SPAN] to separate them.

        Returns:
            :class:`dict`: operation status
                * status: the result of the execution
        """
        layout_name = self.theme_mapping[
            self.pointer.slide_master.name]['single']
        layout = next(i for i in self.pointer.slide_master.slide_layouts
                      if i.name == layout_name)
        slide = self.pointer.slides.add_slide(layout)
        ph_title, ph_body = slide.placeholders
        ph_title.text = title
        ph = ph_body
        tf = ph.text_frame
        for i, item in enumerate(bullet_items.split('[SPAN]')):
            if i == 0:
                p = tf.paragraphs[0]
            else:
                p = tf.add_paragraph()
            p.text = item.strip()
            p.level = 0
        return dict(status='added page')

    @tool_api(explode_return=True)
    def add_text_image_page(self, title: str, bullet_items: str,
                            image: str) -> dict:
        """Add a text page with one image. Image should be a path.

        Args:
            title (:class:`str`): the title of the page
            bullet_items (:class:`str`): bullet_items should be a string; for multiple bullet items, please use [SPAN] to separate them.
            image (:class:`str`): the path of the image

        Returns:
            :class:`dict`: operation status
                * status: the result of the execution
        """
        from PIL import Image
        layout_name = self.theme_mapping[self.pointer.slide_master.name]['two']
        layout = next(i for i in self.pointer.slide_master.slide_layouts
                      if i.name == layout_name)
        slide = self.pointer.slides.add_slide(layout)
        ph_title, ph_body1, ph_body2 = slide.placeholders
        ph_title.text = title
        ph = ph_body2
        # open the image once to compute a height that keeps the aspect ratio
        image_pil = Image.open(image)
        left = ph.left
        width = ph.width
        height = int(width / image_pil.width * image_pil.height)
        top = (ph.top + (ph.top + ph.height)) // 2 - height // 2
        slide.shapes.add_picture(image, left, top, width, height)

        ph = ph_body1
        tf = ph.text_frame
        for i, item in enumerate(bullet_items.split('[SPAN]')):
            if i == 0:
                p = tf.paragraphs[0]
            else:
                p = tf.add_paragraph()
            p.text = item.strip()
            p.level = 0

        return dict(status='added page')

    @tool_api(explode_return=True)
    def submit_file(self) -> dict:
        """When all steps are done, YOU MUST use submit_file() to submit your work.

        Returns:
            :class:`dict`: operation status
                * status: the result of the execution
        """
        # file_path = os.path.join(self.CACHE_DIR, f'{self._return_timestamp()}.pptx')
        # self.pointer.save(file_path)
        # retreival_url = upload_file(file_path)
        self.pointer.save(self.location)
        return dict(status=f'submitted. view ppt at {self.location}')
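A hypothetical end-to-end run of the PPT plugin (requires `python-pptx`; the file path and slide contents are illustrative):

ppt = PPT()
ppt.create_file(theme='Default', abs_location='/tmp/demo.pptx')
ppt.add_first_page(title='Quarterly Review', subtitle='2024 Q1')
ppt.add_text_page(title='Highlights',
                  bullet_items='Revenue up[SPAN]Costs down')
print(ppt.submit_file())  # {'status': 'submitted. view ppt at /tmp/demo.pptx'}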
@@ -1,11 +1,11 @@
# flake8: noqa: E501
import copy
import io
from contextlib import redirect_stdout
from typing import Any, Optional
from typing import Any, Optional, Type

from func_timeout import FunctionTimedOut, func_set_timeout

from lagent.actions.base_action import BaseAction
from lagent.actions.base_action import BaseAction, tool_api
from lagent.actions.parser import BaseParser, JsonParser
from lagent.schema import ActionReturn, ActionStatusCode


@@ -29,72 +29,72 @@ class GenericRuntime:
        return eval(expr, self._global_vars)


DEFAULT_DESCRIPTION = """用来执行Python代码。代码必须是一个函数,
函数名必须得是 'solution',代码对应你的思考过程。代码实例格式如下:
```python
# import 依赖包
import xxx
def solution():
    # 初始化一些变量
    variable_names_with_real_meaning = xxx
    # 步骤一
    mid_variable = func(variable_names_with_real_meaning)
    # 步骤 x
    mid_variable = func(mid_variable)
    # 最后结果
    final_answer = func(mid_variable)
    return final_answer
```"""


class PythonInterpreter(BaseAction):
    """A Python executor that can execute Python scripts.

    Args:
        description (str): The description of the action. Defaults to
            DEFAULT_DESCRIPTION.
        answer_symbol (str, Optional): the answer symbol from LLM
        answer_symbol (str, Optional): the answer symbol from LLM. Defaults to ``None``.
        answer_expr (str, Optional): the answer function name of the Python
            script. Default to 'solution()'.
        answer_from_stdout (boolean): whether the execution result is from
            stdout.
        name (str, optional): The name of the action. If None, the name will
            be the class name. Defaults to None.
            script. Defaults to ``'solution()'``.
        answer_from_stdout (boolean, Optional): whether the execution result is from
            stdout. Defaults to ``False``.
        timeout (int, Optional): Upper bound of waiting time for Python script execution.
            Defaults to ``20``.
        description (dict, Optional): The description of the action. Defaults to ``None``.
        parser (Type[BaseParser]): The parser class to process the
            action's inputs and outputs. Defaults to :class:`JsonParser`.
        enable (bool, optional): Whether the action is enabled. Defaults to
            True.
        disable_description (str, optional): The description of the action when
            it is disabled. Defaults to None.
        timeout (int): Upper bound of waiting time for Python script execution.
            ``True``.
    """

    def __init__(self,
                 description: str = DEFAULT_DESCRIPTION,
                 answer_symbol: Optional[str] = None,
                 answer_expr: Optional[str] = 'solution()',
                 answer_from_stdout: bool = False,
                 name: Optional[str] = None,
                 enable: bool = True,
                 disable_description: Optional[str] = None,
                 timeout: int = 20) -> None:
        super().__init__(description, name, enable, disable_description)

                 timeout: int = 20,
                 description: Optional[dict] = None,
                 parser: Type[BaseParser] = JsonParser,
                 enable: bool = True) -> None:
        super().__init__(description, parser, enable)
        self.answer_symbol = answer_symbol
        self.answer_expr = answer_expr
        self.answer_from_stdout = answer_from_stdout
        self.timeout = timeout

    def __call__(self, command: str) -> ActionReturn:
    @tool_api
    def run(self, command: str) -> ActionReturn:
        """用来执行Python代码。代码必须是一个函数,函数名必须得是 'solution',代码对应你的思考过程。代码实例格式如下:

        ```python
        # import 依赖包
        import xxx
        def solution():
            # 初始化一些变量
            variable_names_with_real_meaning = xxx
            # 步骤一
            mid_variable = func(variable_names_with_real_meaning)
            # 步骤 x
            mid_variable = func(mid_variable)
            # 最后结果
            final_answer = func(mid_variable)
            return final_answer
        ```

        Args:
            command (:class:`str`): Python code snippet
        """
        from func_timeout import FunctionTimedOut, func_set_timeout
        self.runtime = GenericRuntime()
        try:
            tool_return = func_set_timeout(self.timeout)(self._call)(command)
        except FunctionTimedOut as e:
            tool_return = ActionReturn(url=None, args=None, type=self.name)
            tool_return = ActionReturn(type=self.name)
            tool_return.errmsg = repr(e)
            tool_return.state = ActionStatusCode.API_ERROR
        return tool_return

    def _call(self, command: str) -> ActionReturn:
        tool_return = ActionReturn(url=None, args=None, type=self.name)
        tool_return = ActionReturn(type=self.name)
        try:
            if '```python' in command:
                command = command.split('```python')[1].split('```')[0]
@@ -124,7 +124,7 @@ class PythonInterpreter(BaseAction):
            tool_return.state = ActionStatusCode.API_ERROR
            return tool_return
        try:
            tool_return.result = dict(text=str(res))
            tool_return.result = [dict(type='text', content=str(res))]
            tool_return.state = ActionStatusCode.SUCCESS
        except Exception as e:
            tool_return.errmsg = repr(e)
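Given the new `run` signature above, a call could look like this (a sketch assuming `func_timeout` and lagent are installed; the fenced command mimics what the LLM is prompted to emit):

from lagent.actions.python_interpreter import PythonInterpreter

interp = PythonInterpreter(timeout=5)
ret = interp.run("```python\ndef solution():\n    return 1 + 1\n```")
print(ret.result)  # [{'type': 'text', 'content': '2'}] on success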
@@ -1,6 +1,5 @@
from .autogpt import AutoGPT
from .base_agent import BaseAgent
from .react import ReAct
from .rewoo import ReWOO

__all__ = ['BaseAgent', 'ReAct', 'AutoGPT', 'ReWOO']
from .autogpt import *  # noqa: F401, F403
from .base_agent import *  # noqa: F401, F403
from .internlm2_agent import *  # noqa: F401, F403
from .react import *  # noqa: F401, F403
from .rewoo import *  # noqa: F401, F403
@@ -219,21 +219,20 @@ class AutoGPTProtocol:
            dict(role='user', content=self.triggering_prompt))
        return formatted_data

    def format_response(self, action_return):
        """format the final response at current step.
    def format_response(self, action_return) -> dict:
        """Format the final response at current step.

        Args:
            action_return (ActionReturn): return value of the current action.

        Returns:
            str: the final response at current step.
            dict: the final response at current step.
        """
        if action_return.state == ActionStatusCode.SUCCESS:
            response = action_return.result['text']
            response = f'Command {action_return.type} returned: {response}'
            response = f'Command {action_return.type} returned: {response.format_result()}'
        else:
            response = action_return.errmsg
        return response
        return dict(role='system', content=response)


class AutoGPT(BaseAgent):
@@ -260,29 +259,26 @@ class AutoGPT(BaseAgent):
        super().__init__(
            llm=llm, action_executor=action_executor, protocol=protocol)

    def chat(self, goal: str) -> AgentReturn:
        self._inner_history = []
    def chat(self, goal: str, **kwargs) -> AgentReturn:
        inner_history = []
        agent_return = AgentReturn()
        default_response = '对不起,我无法回答你的问题'
        default_response = 'Sorry that I cannot answer your question.'
        for _ in range(self.max_turn):
            prompt = self._protocol.format(
                goal=goal,
                inner_history=self._inner_history,
                inner_history=inner_history,
                action_executor=self._action_executor)
            response = self._llm.generate_from_template(prompt, 512)
            self._inner_history.append(
                dict(role='assistant', content=response))
            response = self._llm.chat(prompt, **kwargs)
            inner_history.append(dict(role='assistant', content=response))
            action, action_input = self._protocol.parse(
                response, self._action_executor)
            action_return: ActionReturn = self._action_executor(
                action, action_input)
            agent_return.actions.append(action_return)
            if action_return.type == self._action_executor.finish_action.name:
                agent_return.response = action_return.result['text']
                agent_return.response = action_return.format_result()
                return agent_return
            self._inner_history.append(
                dict(
                    role='system',
                    content=self._protocol.format_response(action_return)))
            inner_history.append(self._protocol.format_response(action_return))
        agent_return.inner_steps = inner_history
        agent_return.response = default_response
        return agent_return
@@ -1,5 +1,3 @@
from typing import List

from lagent.actions import ActionExecutor
from lagent.actions.base_action import BaseAction
from lagent.llms.base_llm import BaseModel
@@ -19,8 +17,6 @@ class BaseAgent:

    def __init__(self, llm: BaseModel, action_executor: ActionExecutor,
                 protocol: object) -> None:

        self._session_history = []
        self._llm = llm
        self._action_executor = action_executor
        self._protocol = protocol
@@ -41,9 +37,5 @@ class BaseAgent:
        """
        self._action_executor.del_action(name)

    def chat(self, message: str) -> AgentReturn:
    def chat(self, message: str, **kwargs) -> AgentReturn:
        raise NotImplementedError

    @property
    def session_history(self) -> List:
        return self._session_history
368
lagent/agents/internlm2_agent.py
Normal file
@@ -0,0 +1,368 @@
import json
import logging
from copy import deepcopy
from typing import Dict, List, Optional, Union

from lagent.actions import ActionExecutor
from lagent.agents.base_agent import BaseAgent
from lagent.llms import BaseAPIModel, BaseModel
from lagent.schema import ActionReturn, ActionStatusCode, AgentReturn, AgentStatusCode, ModelStatusCode  # noqa: E501

API_PREFIX = (
    "This is the subfunction for tool '{tool_name}', you can use this tool. "
    'The description of this function is: \n{description}')

META_CN = ('当开启工具以及代码时,根据需求选择合适的工具进行调用')

INTERPRETER_CN = ('你现在已经能够在一个有状态的 Jupyter 笔记本环境中运行 Python 代码。'
                  '当你向 python 发送含有 Python 代码的消息时,它将在该环境中执行。'
                  '这个工具适用于多种场景,如数据分析或处理(包括数据操作、统计分析、图表绘制),'
                  '复杂的计算问题(解决数学和物理难题),编程示例(理解编程概念或特性),'
                  '文本处理和分析(比如文本解析和自然语言处理),'
                  '机器学习和数据科学(用于展示模型训练和数据可视化),'
                  '以及文件操作和数据导入(处理CSV、JSON等格式的文件)。')

PLUGIN_CN = ('你可以使用如下工具:'
             '\n{prompt}\n'
             '如果你已经获得足够信息,请直接给出答案. 避免不必要的工具调用! '
             '同时注意你可以使用的工具,不要随意捏造!')


class Internlm2Protocol:

    def __init__(
        self,
        meta_prompt: str = META_CN,
        interpreter_prompt: str = INTERPRETER_CN,
        plugin_prompt: str = PLUGIN_CN,
        few_shot: Optional[List] = None,
        language: Dict = dict(
            begin='',
            end='',
            belong='assistant',
        ),
        tool: Dict = dict(
            begin='{start_token}{name}\n',
            start_token='<|action_start|>',
            name_map=dict(plugin='<|plugin|>', interpreter='<|interpreter|>'),
            belong='assistant',
            end='<|action_end|>\n',
        ),
        execute: Dict = dict(
            role='execute', begin='', end='', fallback_role='environment'),
    ) -> None:
        self.meta_prompt = meta_prompt
        self.interpreter_prompt = interpreter_prompt
        self.plugin_prompt = plugin_prompt
        self.roles_cfg = dict(tool=tool, language=language)
        self.language = language
        self.execute = execute
        self.tool = tool
        self.few_shot = few_shot

    def format_sub_role(self, messages: List[Dict]) -> List[Dict]:

        def format_interpreter(message):
            if isinstance(message['content'], dict):
                # assert message['content']['name'] == 'IPythonInterpreter'
                return dict(
                    role=message['role'],
                    name=message['name'],
                    content=message['content']['parameters']['command'])
            else:
                return message

        def format_plugin(message):
            if isinstance(message['content'], dict):
                return dict(
                    role=message['role'],
                    name=message['name'],
                    content=json.dumps(message['content']))
            else:
                return message

        new_message = list()
        for message in messages:
            if message['role'] in [
                    'assistant', 'user', 'system', 'environment'
            ]:
                new_message.append(message)
                continue
            role_cfg = self.roles_cfg[message['role']]
            begin = role_cfg['begin']
            if message['role'] == 'tool':
                if message['name'] == 'interpreter':
                    message = format_interpreter(message)
                elif message['name'] == 'plugin':
                    message = format_plugin(message)
                else:
                    raise NotImplementedError
                begin = role_cfg['begin'].format(
                    start_token=role_cfg.get('start_token', ''),
                    name=role_cfg.get('name_map', {}).get(message['name'], ''))
            new_content = begin + message['content'] + role_cfg['end']
            if role_cfg.get('fallback_role'):
                new_message.append(
                    dict(role=role_cfg['fallback_role'], content=new_content))
            elif role_cfg.get('belong'):
                if new_message[-1]['role'] != role_cfg.get('belong'):
                    new_message.append(
                        dict(role=role_cfg.get('belong'), content=new_content))
                else:
                    new_message[-1]['content'] += new_content
            else:
                new_message.append(
                    dict(role=message['role'], content=new_content))

        return new_message

    def format(self,
               inner_step: List[Dict],
               plugin_executor: ActionExecutor = None,
               interpreter_executor: ActionExecutor = None,
               **kwargs) -> list:
        formatted = []
        if self.meta_prompt:
            formatted.append(dict(role='system', content=self.meta_prompt))
        if interpreter_executor and self.interpreter_prompt:
            interpreter_info = interpreter_executor.get_actions_info()[0]
            interpreter_prompt = self.interpreter_prompt.format(
                code_prompt=interpreter_info['description'])
            formatted.append(
                dict(
                    role='system',
                    content=interpreter_prompt,
                    name='interpreter'))
        if plugin_executor and plugin_executor.actions and self.plugin_prompt:
            plugin_descriptions = []
            for api_info in plugin_executor.get_actions_info():
                plugin = deepcopy(api_info)
                if isinstance(api_info, dict):
                    tool_name = api_info['name'].split('.')[0]
                    plugin['description'] = API_PREFIX.format(
                        tool_name=tool_name, description=plugin['description'])
                # only keep required parameters
                required_parameters = [
                    param for param in plugin['parameters']
                    if param['name'] in plugin['required']
                ]
                plugin['parameters'] = required_parameters
                plugin_descriptions.append(plugin)
            plugin_prompt = self.plugin_prompt.format(
                prompt=json.dumps(
                    plugin_descriptions, ensure_ascii=False, indent=4))
            formatted.append(
                dict(role='system', content=plugin_prompt, name='plugin'))
        if self.few_shot:
            for few_shot in self.few_shot:
                formatted += self.format_sub_role(few_shot)
        formatted += self.format_sub_role(inner_step)
        return formatted

    def parse(self, message, plugin_executor: ActionExecutor,
              interpreter_executor: ActionExecutor):
        if self.language['begin']:
            message = message.split(self.language['begin'])[-1]
        if self.tool['name_map']['plugin'] in message:
            message, action = message.split(
                f"{self.tool['start_token']}{self.tool['name_map']['plugin']}")
            action = action.split(self.tool['end'].strip())[0]
            return 'plugin', message, action
        if self.tool['name_map']['interpreter'] in message:
            message, code = message.split(
                f"{self.tool['start_token']}"
                f"{self.tool['name_map']['interpreter']}")
            code = code.split(self.tool['end'].strip())[0].strip()
            return 'interpreter', message, dict(
                name=interpreter_executor.action_names()[0],
                parameters=dict(command=code))
        return None, message.split(self.tool['start_token'])[0], None

    def format_response(self, action_return, name) -> dict:
        if action_return.state == ActionStatusCode.SUCCESS:
            response = action_return.format_result()
        else:
            response = str(action_return.errmsg)
        content = self.execute['begin'] + response + self.execute['end']
        if self.execute.get('fallback_role'):
            return dict(
                role=self.execute['fallback_role'], content=content, name=name)
        elif self.execute.get('belong'):
            return dict(
                role=self.execute['belong'], content=content, name=name)
        return dict(role=self.execute['role'], content=response, name=name)


class Internlm2Agent(BaseAgent):

    def __init__(self,
                 llm: Union[BaseModel, BaseAPIModel],
                 plugin_executor: ActionExecutor = None,
                 interpreter_executor: ActionExecutor = None,
                 protocol=Internlm2Protocol(),
                 max_turn: int = 3) -> None:
        self.max_turn = max_turn
        self._interpreter_executor = interpreter_executor
        super().__init__(
            llm=llm, action_executor=plugin_executor, protocol=protocol)

    def chat(self, message: Union[str, Dict], **kwargs) -> AgentReturn:
        if isinstance(message, str):
            message = dict(role='user', content=message)
        if isinstance(message, dict):
            message = [message]
        inner_history = message[:]
        offset = len(inner_history)
        agent_return = AgentReturn()
        for _ in range(self.max_turn):
            # list of dict
            prompt = self._protocol.format(
                inner_step=inner_history,
                plugin_executor=self._action_executor,
                interpreter_executor=self._interpreter_executor,
            )
            response = self._llm.chat(prompt, **kwargs)
            name, language, action = self._protocol.parse(
                message=response,
                plugin_executor=self._action_executor,
                interpreter_executor=self._interpreter_executor,
            )
            if name:
                if name == 'plugin':
                    if self._action_executor:
                        executor = self._action_executor
                    else:
                        logging.info(msg='No plugin is instantiated!')
                        continue
                    try:
                        action = json.loads(action)
                    except Exception as e:
                        logging.info(msg=f'Invalid action {e}')
                        continue
                elif name == 'interpreter':
                    if self._interpreter_executor:
                        executor = self._interpreter_executor
                    else:
                        logging.info(msg='No interpreter is instantiated!')
                        continue
                else:
                    logging.info(
                        msg=(f"Invalid name '{name}'. Currently only 'plugin' "
                             "and 'interpreter' are supported."))
                    continue
                action_return: ActionReturn = executor(action['name'],
                                                       action['parameters'])
                action_return.thought = language
                agent_return.actions.append(action_return)
            inner_history.append(dict(role='language', content=language))
            if not name or action_return.type == executor.finish_action.name:
                agent_return.response = language
                agent_return.state = AgentStatusCode.END
                break
            else:
                inner_history.append(
                    dict(role='tool', content=action, name=name))
                inner_history.append(
                    self._protocol.format_response(action_return, name=name))
        agent_return.inner_steps = inner_history[offset:]
        return agent_return

    def stream_chat(self, message: List[dict], **kwargs) -> AgentReturn:
        if isinstance(message, str):
            message = dict(role='user', content=message)
        if isinstance(message, dict):
            message = [message]
        inner_history = message[:]
        offset = len(inner_history)
        agent_return = AgentReturn()
        last_agent_state = AgentStatusCode.SESSION_READY
        for _ in range(self.max_turn):
            # list of dict
            prompt = self._protocol.format(
                inner_step=inner_history,
                plugin_executor=self._action_executor,
                interpreter_executor=self._interpreter_executor,
            )
            response = ''
            for model_state, res, _ in self._llm.stream_chat(prompt, **kwargs):
                model_state: ModelStatusCode
                response = res
                if model_state.value < 0:
                    agent_return.state = getattr(AgentStatusCode,
                                                 model_state.name)
                    yield deepcopy(agent_return)
                    return
                else:
                    name, language, action = self._protocol.parse(
                        message=response,
                        plugin_executor=self._action_executor,
                        interpreter_executor=self._interpreter_executor,
                    )
                    if name:
                        if model_state == ModelStatusCode.END:
                            agent_state = last_agent_state + 1
                            if name == 'plugin':
                                if self._action_executor:
                                    executor = self._action_executor
                                else:
                                    logging.info(
                                        msg='No plugin is instantiated!')
                                    continue
                                try:
                                    action = json.loads(action)
                                except Exception as e:
                                    logging.info(msg=f'Invalid action {e}')
                                    continue
                            elif name == 'interpreter':
                                if self._interpreter_executor:
                                    executor = self._interpreter_executor
                                else:
                                    logging.info(
                                        msg='No interpreter is instantiated!')
                                    continue
                            agent_return.state = agent_state
                            agent_return.response = action
                        else:
                            agent_state = (
                                AgentStatusCode.PLUGIN_START if name
                                == 'plugin' else AgentStatusCode.CODING)
                            if agent_state != last_agent_state:
                                # agent_return.state = agent_state
                                agent_return.response = language
                                yield deepcopy(agent_return)
                            agent_return.state = agent_state
                            agent_return.response = action
                    else:
                        agent_state = AgentStatusCode.STREAM_ING
                        agent_return.state = agent_state
                        agent_return.response = language
                    last_agent_state = agent_state
                    yield deepcopy(agent_return)
            if name:
                action_return: ActionReturn = executor(action['name'],
                                                       action['parameters'])
                action_return.thought = language
                agent_return.actions.append(action_return)
            inner_history.append(dict(role='language', content=language))
            if not name:
                agent_return.response = language
                break
            elif action_return.type == executor.finish_action.name:
                try:
                    response = action_return.args['text']['response']
                except Exception:
                    logging.info(msg='Unable to parse FinishAction.')
                    response = ''
                agent_return.response = response
                break
            else:
                inner_history.append(
                    dict(role='tool', content=action, name=name))
                inner_history.append(
                    self._protocol.format_response(action_return, name=name))
                agent_state += 1
                agent_return.state = agent_state
                yield agent_return
        agent_return.inner_steps = deepcopy(inner_history[offset:])
        agent_return.state = AgentStatusCode.END
        yield agent_return
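To see how `Internlm2Protocol.parse` splits a raw completion into thought and tool call, consider this sketch (the executor is assumed to wrap `IPythonInterpreter`, whose registered name `action_names()` is assumed to return):

from lagent.actions import ActionExecutor
from lagent.actions.ipython_interpreter import IPythonInterpreter

proto = Internlm2Protocol()
executor = ActionExecutor(actions=[IPythonInterpreter()])
msg = ('Let me compute that.<|action_start|><|interpreter|>\n'
       '```python\nprint(1 + 1)\n```<|action_end|>\n')
name, thought, action = proto.parse(
    msg, plugin_executor=None, interpreter_executor=executor)
print(name)    # 'interpreter'
print(action)  # {'name': 'IPythonInterpreter',
               #  'parameters': {'command': '```python\nprint(1 + 1)\n```'}}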
@@ -32,16 +32,17 @@ FORCE_STOP_PROMPT_CN = '你需要基于历史消息返回一个最终结果'
# The English prompts for ReAct

CALL_PROTOCOL_EN = """You are a assistant who can utilize external tools.
{tool_description}
To use a tool, please use the following format:
```
{thought}: Think what you need to solve, do you need to use tools?
{action}: the tool name, should be one of [{action_names}]
{action_input}: the input to the action
{thought}Think what you need to solve, do you need to use tools?
{action}the tool name, should be one of [{action_names}]
{action_input}the input to the action
```
The response after utilizing tools should using the following format:
```
{response}: the results after call the tool.
``
{response}the results after call the tool.
```
If you already know the answer, or you do not need to use tools,
please using the following format to reply:
```
@@ -76,13 +77,13 @@ class ReActProtocol:
                 belong='assistant'),
             action: dict = dict(role='ACTION', begin='Action:', end='\n'),
             action_input: dict = dict(
                 role='ARGS', begin='ActionInput:', end='\n'),
                 role='ARGS', begin='Action Input:', end='\n'),
             response: dict = dict(
                 role='RESPONSE', begin='Response:', end='\n'),
             finish: dict = dict(
                 role='FINISH', begin='FinalAnswer:', end='\n'),
             call_protocol: str = CALL_PROTOCOL_CN,
             force_stop: str = FORCE_STOP_PROMPT_CN) -> None:
                 role='FINISH', begin='Final Answer:', end='\n'),
             call_protocol: str = CALL_PROTOCOL_EN,
             force_stop: str = FORCE_STOP_PROMPT_EN) -> None:
        self.call_protocol = call_protocol
        self.force_stop = force_stop
        self.thought = thought
@@ -168,20 +169,22 @@ class ReActProtocol:
            action_input = arg_match[-1]
        return thought, action.strip(), action_input.strip().strip('"')

    def format_response(self, action_return: ActionReturn) -> str:
        """format the final response at current step.
    def format_response(self, action_return: ActionReturn) -> dict:
        """Format the final response at current step.

        Args:
            action_return (ActionReturn): return value of the current action.

        Returns:
            str: the final response at current step.
            dict: the final response at current step.
        """
        if action_return.state == ActionStatusCode.SUCCESS:
            response = action_return.result['text']
            response = action_return.format_result()
        else:
            response = action_return.errmsg
        return self.response['begin'] + response + self.response['end']
        return dict(
            role='system',
            content=self.response['begin'] + response + self.response['end'])


class ReAct(BaseAgent):
@@ -195,33 +198,40 @@ class ReAct(BaseAgent):
        protocol (ReActProtocol): a wrapper to generate prompt and
            parse the response from LLM / actions.
        max_turn (int): the maximum number of trials for LLM to generate
            plans that can be successfully parsed by ReWOO protocol.
            plans that can be successfully parsed by ReAct protocol.
            Defaults to 4.
    """

    def __init__(self,
                 llm: Union[BaseModel, BaseAPIModel],
                 action_executor: ActionExecutor,
                 protocol: ReActProtocol = ReActProtocol(),
                 max_turn: int = 2) -> None:
                 max_turn: int = 4) -> None:
        self.max_turn = max_turn
        super().__init__(
            llm=llm, action_executor=action_executor, protocol=protocol)

    def chat(self, message: str) -> AgentReturn:
        self._inner_history = []
        self._inner_history.append(dict(role='user', content=message))
    def chat(self, message: Union[str, dict, List[dict]],
             **kwargs) -> AgentReturn:
        if isinstance(message, str):
            inner_history = [dict(role='user', content=message)]
        elif isinstance(message, dict):
            inner_history = [message]
        elif isinstance(message, list):
            inner_history = message[:]
        else:
            raise TypeError(f'unsupported type: {type(message)}')
        offset = len(inner_history)
        agent_return = AgentReturn()
        force_stop = False
        default_response = '对不起,我无法回答你的问题'
        default_response = 'Sorry that I cannot answer your question.'
        for turn in range(self.max_turn):
            prompt = self._protocol.format(
                chat_history=self.session_history,
                inner_step=self._inner_history,
                chat_history=[],
                inner_step=inner_history,
                action_executor=self._action_executor,
                force_stop=force_stop)
            response = self._llm.generate_from_template(prompt, 512)
            self._inner_history.append(
                dict(role='assistant', content=response))
                force_stop=(turn == self.max_turn - 1))
            response = self._llm.chat(prompt, **kwargs)
            inner_history.append(dict(role='assistant', content=response))
            thought, action, action_input = self._protocol.parse(
                response, self._action_executor)
            action_return: ActionReturn = self._action_executor(
@@ -229,17 +239,10 @@ class ReAct(BaseAgent):
            action_return.thought = thought
            agent_return.actions.append(action_return)
            if action_return.type == self._action_executor.finish_action.name:
                agent_return.response = action_return.result['text']
                return agent_return
            self._inner_history.append(
                dict(
                    role='system',
                    content=self._protocol.format_response(action_return)))
            if turn == self.max_turn - 1:
                force_stop = True
        agent_return.response = default_response
        # only append the user and final response
        self._session_history.append(dict(role='user', content=message))
        self._session_history.append(
            dict(role='assistant', content=agent_return.response))
                agent_return.response = action_return.format_result()
                break
            inner_history.append(self._protocol.format_response(action_return))
        else:
            agent_return.response = default_response
        agent_return.inner_steps = inner_history[offset:]
        return agent_return
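Wiring the refactored `chat` above into a runnable agent might look like this (a sketch; the model name is a placeholder, and any lagent `BaseModel` works in place of `GPTAPI`):

from lagent.actions import ActionExecutor
from lagent.actions.python_interpreter import PythonInterpreter
from lagent.agents.react import ReAct
from lagent.llms import GPTAPI

agent = ReAct(
    llm=GPTAPI(model_type='gpt-3.5-turbo'),
    action_executor=ActionExecutor(actions=[PythonInterpreter()]))
print(agent.chat('What is 37 * 41?').response)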
@@ -94,10 +94,10 @@ class ReWOOProtocol:

    def __init__(
        self,
        planner_prompt: str = PLANNER_PROMPT_CN,
        worker_prompt: str = WORKER_PROMPT_CN,
        solver_prompt: str = SOLVER_PROMPT_CN,
        reformat_prompt: str = REFORMAT_PROMPT_CN,
        planner_prompt: str = PLANNER_PROMPT_EN,
        worker_prompt: str = WORKER_PROMPT_EN,
        solver_prompt: str = SOLVER_PROMPT_EN,
        reformat_prompt: str = REFORMAT_PROMPT_EN,
    ) -> None:
        self.planner_prompt = planner_prompt
        self.worker_prompt = worker_prompt
@@ -191,7 +191,7 @@ class ReWOOProtocol:
        worker_log = ''
        for thought, action_return in zip(thought_list, action_return_list):
            if action_return.state == ActionStatusCode.SUCCESS:
                action_resp = action_return.result['text']
                action_resp = action_return.format_result()
            else:
                action_resp = action_return.errmsg
            worker_response = self.worker_prompt.format(
@@ -226,9 +226,17 @@ class ReWOO(BaseAgent):

        self.max_turn = max_turn

    def chat(self, message: str) -> AgentReturn:
        self._inner_history = []
        self._inner_history.append(dict(role='user', content=message))
    def chat(self, message: Union[str, dict, List[dict]],
             **kwargs) -> AgentReturn:
        if isinstance(message, str):
            inner_history = [dict(role='user', content=message)]
        elif isinstance(message, dict):
            inner_history = [message]
        elif isinstance(message, list):
            inner_history = message[:]
        else:
            raise TypeError(f'unsupported type: {type(message)}')
        offset = len(inner_history)
        agent_return = AgentReturn()

        # planner
@@ -236,13 +244,12 @@ class ReWOO(BaseAgent):
        reformat_request = ''
        while turn_id < self.max_turn:
            planner_prompt = self._protocol.format_planner(
                chat_history=self.session_history,
                inner_step=self._inner_history,
                chat_history=[],
                inner_step=inner_history,
                action_executor=self._action_executor,
                reformat_request=reformat_request)
            response = self._llm.generate_from_template(planner_prompt, 512)
            self._inner_history.append(
                dict(role='assistant', content=response))
            response = self._llm.chat(planner_prompt, **kwargs)
            inner_history.append(dict(role='assistant', content=response))
            try:
                thoughts, actions, actions_input = self._protocol.parse_worker(
                    response)
@@ -266,17 +273,17 @@ class ReWOO(BaseAgent):
            for prev_ptr in prev_ptrs:
                ptr_num = int(prev_ptr.strip('#E')) - 1  # start from 0
                actions_input[action_id] = actions_input[action_id].replace(
                    prev_ptr, action_responses[ptr_num].result['text'])
                    prev_ptr, action_responses[ptr_num].format_result())
            action_return: ActionReturn = self._action_executor(
                actions[action_id], actions_input[action_id])
            action_responses.append(action_return)

        solver_prompt, worker_log = self._protocol.format_solver(
            message, thoughts, action_responses)
        self._inner_history.append(dict(role='system', content=worker_log))
        inner_history.append(dict(role='system', content=worker_log))

        final_response = self._llm.generate_from_template(solver_prompt, 512)
        self._inner_history.append(
            dict(role='assistant', content=final_response))
        final_response = self._llm.chat(solver_prompt, **kwargs)
        inner_history.append(dict(role='assistant', content=final_response))
        agent_return.inner_steps = inner_history[offset:]
        agent_return.response = final_response
        return agent_return
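The `#E` bookkeeping in the hunk above substitutes earlier action results into later action inputs; a standalone sketch of that substitution (the names and results are illustrative):

import re

action_input = 'Compare #E1 with #E2'
results = {'#E1': 'Paris', '#E2': 'London'}  # stand-ins for format_result() outputs
for prev_ptr in re.findall(r'#E\d+', action_input):
    action_input = action_input.replace(prev_ptr, results[prev_ptr])
print(action_input)  # Compare Paris with London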
@@ -1,10 +1,13 @@
from lagent.utils import is_module_exist
from .base_api import BaseAPIModel
from .base_llm import BaseModel
from .huggingface import HFTransformer, HFTransformerCasualLM, HFTransformerChat
from .lmdepoly_wrapper import LMDeployClient, LMDeployPipeline, LMDeployServer
from .meta_template import INTERNLM2_META
from .openai import GPTAPI
from .vllm_wrapper import VllmModel

__all__ = ['BaseModel', 'BaseAPIModel', 'GPTAPI']

if is_module_exist('transformers'):
    from .huggingface import HFTransformer, HFTransformerCasualLM  # noqa: F401
    __all__.extend(['HFTransformer', 'HFTransformerCasualLM'])
__all__ = [
    'BaseModel', 'BaseAPIModel', 'GPTAPI', 'LMDeployClient',
    'LMDeployPipeline', 'LMDeployServer', 'HFTransformer',
    'HFTransformerCasualLM', 'INTERNLM2_META', 'HFTransformerChat', 'VllmModel'
]

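With the rewritten `__all__`, every wrapper is importable from the package root; for example (assuming the corresponding backends are installed):

```python
from lagent.llms import (GPTAPI, HFTransformerCasualLM, INTERNLM2_META,
                         LMDeployPipeline, VllmModel)
```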
@@ -1,92 +1,11 @@
import re
import threading
import warnings
from abc import abstractclassmethod
from time import sleep
from typing import Dict, List, Optional, Tuple, Union

from .base_llm import BaseModel


class BaseAPIModel(BaseModel):
    """Base class for API model wrapper.

    Args:
        model_type (str): The type of model.
        query_per_second (int): The maximum queries allowed per second
            between two consecutive calls of the API. Defaults to 1.
        retry (int): Number of retries if the API call fails. Defaults to 2.
        max_seq_len (int): The maximum sequence length of the model. Defaults
            to 2048.
        meta_template (Dict, optional): The model's meta prompt
            template if needed, in case the requirement of injecting or
            wrapping of any meta instructions.
    """

    is_api: bool = True

    def __init__(self,
                 model_type: str,
                 query_per_second: int = 1,
                 retry: int = 2,
                 max_seq_len: int = 2048,
                 meta_template: Optional[Dict] = None):
        self.model_type = model_type
        self.max_seq_len = max_seq_len
        self.meta_template = meta_template
        self.retry = retry
        self.query_per_second = query_per_second
        self.token_bucket = TokenBucket(query_per_second)
        self.template_parser = APITemplateParser(meta_template)

    @abstractclassmethod
    def generate(self, inputs, max_out_len: int) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[str or list]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings.
        """

    def get_token_len(self, prompt: str) -> int:
        """Get lengths of the tokenized string. Only English and Chinese
        characters are counted for now. Users are encouraged to override this
        method if a more accurate length is needed.

        Args:
            prompt (str): Input string.

        Returns:
            int: Length of the input tokens
        """

        english_parts = re.findall(r'[A-Za-z0-9]+', prompt)
        chinese_parts = re.findall(r'[\u4e00-\u9FFF]+', prompt)

        # Count English words
        english_count = sum(len(part.split()) for part in english_parts)

        # Count Chinese words
        chinese_count = sum(len(part) for part in chinese_parts)

        return english_count + chinese_count

    def wait(self):
        """Wait till the next query can be sent.

        Applicable in both single-thread and multi-thread environments.
        """
        return self.token_bucket.get_token()

    def to(self, device):
        pass


class APITemplateParser:
    """Intermediate prompt template parser, specifically for API models.

@@ -106,7 +25,7 @@ class APITemplateParser:
            'role in meta prompt must be unique!'
            self.roles[item['role']] = item.copy()

    def parse_template(self, dialog: List[Union[str, List]]):
    def __call__(self, dialog: List[Union[str, List]]):
        """Parse the intermediate prompt template, and wrap it with meta
        template if applicable. When the meta template is set and the input is
        a list, the return value will be a list containing the full
@@ -119,7 +38,6 @@ class APITemplateParser:
        Args:
            dialog (List[str or list]): An intermediate prompt
                template (potentially before being wrapped by meta template).
            mode (str): Parsing mode. Choices are 'ppl' and 'gen'.

        Returns:
            List[str or list]: The finalized prompt or a conversation.
@@ -200,11 +118,10 @@ class APITemplateParser:
        return res

    def _role2api_role(self, role_prompt: Dict) -> Tuple[str, bool]:
        merged_prompt = self.roles.get(
            role_prompt['role'],
            self.roles.get(
                self.roles[role_prompt['role']].get('fallback_role')))
        merged_prompt = self.roles[role_prompt['role']]
        if merged_prompt.get('fallback_role'):
            merged_prompt = self.roles[self.roles[
                merged_prompt['fallback_role']]]
        res = {}
        res['role'] = merged_prompt['api_role']
        res['content'] = merged_prompt.get('begin', '')
@@ -213,6 +130,60 @@ class APITemplateParser:
        return res


class BaseAPIModel(BaseModel):
    """Base class for API model wrapper.

    Args:
        model_type (str): The type of model.
        query_per_second (int): The maximum queries allowed per second
            between two consecutive calls of the API. Defaults to 1.
        retry (int): Number of retries if the API call fails. Defaults to 2.
        meta_template (Dict, optional): The model's meta prompt
            template if needed, in case the requirement of injecting or
            wrapping of any meta instructions.
    """

    is_api: bool = True

    def __init__(self,
                 model_type: str,
                 query_per_second: int = 1,
                 retry: int = 2,
                 template_parser: 'APITemplateParser' = APITemplateParser,
                 meta_template: Optional[Dict] = None,
                 *,
                 max_new_tokens: int = 512,
                 top_p: float = 0.8,
                 top_k: float = None,
                 temperature: float = 0.8,
                 repetition_penalty: float = 0.0,
                 stop_words: Union[List[str], str] = None):
        self.model_type = model_type
        self.meta_template = meta_template
        self.retry = retry
        self.query_per_second = query_per_second
        self.token_bucket = TokenBucket(query_per_second)
        if template_parser:
            self.template_parser = template_parser(meta_template)

        if isinstance(stop_words, str):
            stop_words = [stop_words]
        self.gen_params = dict(
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            top_k=top_k,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            stop_words=stop_words)

    def _wait(self):
        """Wait till the next query can be sent.

        Applicable in both single-thread and multi-thread environments.
        """
        return self.token_bucket.get_token()


class TokenBucket:
    """A token bucket for rate limiting.

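The `TokenBucket` body is truncated by this hunk. For orientation, a minimal sketch of what such a rate limiter typically looks like — an illustration consistent with the `get_token()` call sites above, not the repository's actual implementation:

```python
import threading
import time

class TokenBucket:
    """Illustrative token bucket: refills at `rate` tokens/second, capped
    at one token, and blocks callers until a token is available."""

    def __init__(self, rate: float):
        self._rate = rate
        self._tokens = 0.0
        self._last = time.time()
        self._lock = threading.Lock()

    def get_token(self):
        """Block until one token is available, then consume it."""
        while True:
            with self._lock:
                now = time.time()
                # Accrue tokens for the elapsed time, capped at 1.0.
                self._tokens = min(
                    1.0, self._tokens + (now - self._last) * self._rate)
                self._last = now
                if self._tokens >= 1.0:
                    self._tokens -= 1.0
                    return
            time.sleep(1.0 / self._rate)
```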
@@ -1,85 +1,17 @@
from abc import abstractclassmethod
from copy import copy
from typing import Dict, List, Optional, Tuple, Union


class BaseModel:
    """Base class for model wrapper.

    Args:
        path (str): The path to the model.
        max_seq_len (int): The maximum sequence length of the model. Defaults
            to 2048.
        tokenizer_only (bool): If True, only the tokenizer will be initialized.
            Defaults to False.
        meta_template (Dict, optional): The model's meta prompt
            template if needed, in case the requirement of injecting or
            wrapping of any meta instructions.
    """

    is_api: bool = False

    def __init__(self,
                 path: str,
                 max_seq_len: int = 2048,
                 tokenizer_only: bool = False,
                 meta_template: Optional[Dict] = None):
        self.path = path
        self.max_seq_len = max_seq_len
        self.tokenizer_only = tokenizer_only
        # meta template
        self.template_parser = LMTemplateParser(meta_template)
        self.eos_token_id = None
        if meta_template and 'eos_token_id' in meta_template:
            self.eos_token_id = meta_template['eos_token_id']

    @abstractclassmethod
    def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[str]): A list of strings.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings.
        """

    def parse_template(self, dialog) -> str:
        """Parse a prompt template, and wrap it with meta template if
        applicable.

        Args:
            dialog (List[str or PromptList]): A prompt
                template (potentially before being wrapped by meta template).
            mode (str): Parsing mode. Choices are 'ppl' and 'gen'.

        Returns:
            str: The final string.
        """
        return self.template_parser.parse_template(dialog)

    def generate_from_template(self, templates, max_out_len: int, **kwargs):
        """Generate completion from a list of templates.

        Args:
            templates (List[PromptType]): A list of templates.
            max_out_len (int): The maximum length of the output.
        """
        inputs = self.parse_template(templates)
        return self.generate(inputs, max_out_len=max_out_len, **kwargs)

    def to(self, device):
        self.model.to(device)
from warnings import warn


class LMTemplateParser:
    """Intermediate prompt template parser, specifically for language models.

    Args:
        meta_template (Dict): The meta template for the model.
        meta_template (list of dict, optional): The meta template for the
            model.
    """

    def __init__(self, meta_template: Optional[Dict] = None):
    def __init__(self, meta_template: Optional[List[Dict]] = None):
        self.meta_template = meta_template
        if meta_template:
            assert isinstance(meta_template, list)
@@ -90,14 +22,13 @@ class LMTemplateParser:
            'role in meta prompt must be unique!'
            self.roles[item['role']] = item.copy()

    def parse_template(self, dialog) -> str:
    def __call__(self, dialog) -> str:
        """Parse a prompt template, and wrap it with meta template if
        applicable.

        Args:
            dialog (List[str or PromptList]): A prompt
                template (potentially before being wrapped by meta template).
            mode (str): Parsing mode. Choices are 'ppl' and 'gen'.

        Returns:
            str: The final string.
@@ -127,20 +58,171 @@ class LMTemplateParser:
            last_sep = '\n'
        return prompt

    def _format_begin(self, role_cfg, message):
        name = message.get('name', None)
        if name is not None:
            begin = role_cfg['begin'].get('with_name', '')
            if name in role_cfg['begin'].get('name', {}):
                begin = begin.format(name=role_cfg['begin']['name'][name])
            else:
                begin = begin.format(name=name)
        else:
            if isinstance(role_cfg.get('begin', ''), str):
                begin = role_cfg.get('begin', '')
            elif isinstance(role_cfg['begin'], dict):
                begin = role_cfg['begin'].get('without_name', '')
        return begin

    def _prompt2str(self,
                    prompt: Union[List, str, Dict],
                    prompt: Union[str, Dict],
                    last: bool = False) -> Tuple[str, bool]:
        if isinstance(prompt, str):
            return prompt
        merged_prompt = self.roles.get(
            prompt['role'],
            self.roles.get(self.roles[prompt['role']].get('fallback_role')))
        res = merged_prompt.get('begin', '')
        merged_prompt = self.roles.get(prompt['role'])

        if merged_prompt.get('fallback_role'):
            merged_prompt = self.roles.get(merged_prompt['fallback_role'])
        begin = self._format_begin(merged_prompt, prompt)
        res = begin
        if last and merged_prompt.get('generate', False):
            res += prompt.get('content', '')
            return res
        res += prompt.get('content', '') + merged_prompt.get('end', '')
        if last and merged_prompt['role'] != 'assistant':
            res += self.roles['assistant']['begin']
            res += self._format_begin(self.roles['assistant'], {})
            return res
        return res


class BaseModel:
    """Base class for model wrapper.

    Args:
        path (str): The path to the model.
        max_new_tokens (int): Maximum length of the output expected to be
            generated by the model. Defaults to 512.
        tokenizer_only (bool): If True, only the tokenizer will be initialized.
            Defaults to False.
        meta_template (list of dict, optional): The model's meta prompt
            template if needed, in case the requirement of injecting or
            wrapping of any meta instructions.
    """

    is_api: bool = False

    def __init__(self,
                 path: str,
                 tokenizer_only: bool = False,
                 template_parser: 'LMTemplateParser' = LMTemplateParser,
                 meta_template: Optional[List[Dict]] = None,
                 *,
                 max_new_tokens: int = 512,
                 top_p: float = 0.8,
                 top_k: float = None,
                 temperature: float = 0.8,
                 repetition_penalty: float = 1.0,
                 stop_words: Union[List[str], str] = None):
        self.path = path
        self.tokenizer_only = tokenizer_only
        # meta template
        self.template_parser = template_parser(meta_template)
        self.eos_token_id = None
        if meta_template and 'eos_token_id' in meta_template:
            self.eos_token_id = meta_template['eos_token_id']

        if isinstance(stop_words, str):
            stop_words = [stop_words]
        self.gen_params = dict(
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            top_k=top_k,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            stop_words=stop_words)

    def generate(self, inputs: Union[str, List[str]], **gen_params) -> str:
        """Generate results given a str (or list of) inputs.

        Args:
            inputs (Union[str, List[str]]):
            gen_params (dict): The input params for generation.

        Returns:
            Union[str, List[str]]: A (list of) generated strings.

        eg.
            batched = True
            if isinstance(inputs, str):
                inputs = [inputs]
                batched = False
            response = ['']
            if batched:
                return response
            return response[0]
        """
        raise NotImplementedError

    def stream_generate(self, inputs: str, **gen_params) -> List[str]:
        """Generate results as streaming given a str input.

        Args:
            inputs (str):
            gen_params (dict): The input params for generation.

        Returns:
            str: A generated string.
        """
        raise NotImplementedError

    def chat(self, inputs: Union[List[dict], List[List[dict]]], **gen_params):
        """Generate completion from a list of templates.

        Args:
            inputs (Union[List[dict], List[List[dict]]]):
            gen_params (dict): The input params for generation.
        Returns:
        """
        if isinstance(inputs[0], list):
            _inputs = list()
            for msg in inputs:
                _inputs.append(self.template_parser(msg))
        else:
            _inputs = self.template_parser(inputs)
        return self.generate(_inputs, **gen_params)

    def generate_from_template(self, inputs: Union[List[dict],
                                                   List[List[dict]]],
                               **gen_params):
        warn(
            'This function will be deprecated after three months '
            'and will be replaced. Please use `.chat()`', DeprecationWarning,
            2)
        return self.chat(inputs, **gen_params)

    def stream_chat(self, inputs: List[dict], **gen_params):
        """Generate results as streaming given a list of templates.

        Args:
            inputs (Union[List[dict]]):
            gen_params (dict): The input params for generation.
        Returns:
        """
        raise NotImplementedError

    def tokenize(self, prompts: Union[str, List[str], List[dict],
                                      List[List[dict]]]):
        """Tokenize the input prompts.

        Args:
            prompts (str | List[str]): user's prompt, or a batch of prompts

        Returns:
            Tuple(numpy.ndarray, numpy.ndarray, numpy.ndarray): prompt's token
            ids, ids' length and requested output length
        """
        raise NotImplementedError

    def update_gen_params(self, **kwargs):
        gen_params = copy(self.gen_params)
        gen_params.update(kwargs)
        return gen_params

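How the parser is meant to flatten messages, roughly — a sketch assuming a simple InternLM-style meta template; the exact spacing depends on the `begin`/`end` strings:

```python
meta_template = [
    dict(role='user', begin='<|User|>:', end='<eoh>\n'),
    dict(role='assistant', begin='<|Bot|>:', end='<eoa>\n', generate=True),
]
parser = LMTemplateParser(meta_template)
prompt = parser([dict(role='user', content='Hi there')])
# roughly: '<|User|>:Hi there<eoh>\n<|Bot|>:' -- the trailing assistant
# `begin` cues the model to start generating.
```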
@@ -1,20 +1,24 @@
from typing import Dict, List, Optional

import torch
import copy
import logging
from typing import Dict, List, Optional, Union

from lagent.schema import ModelStatusCode
from .base_api import APITemplateParser
from .base_llm import BaseModel

logger = logging.getLogger(__name__)

logger = logging.getLogger(__name__)


class HFTransformer(BaseModel):
    """Model wrapper around HuggingFace general models.

    Adapted from OpenCompass (https://github.com/InternLM/opencompass
    /blob/main/opencompass/models/huggingface.py)
    Adapted from Internlm (https://github.com/InternLM/InternLM/blob/main/
    chat/web_demo.py)

    Args:
        path (str): The name or path to HuggingFace's model.
        max_seq_len (int): The maximum length of the input sequence. Defaults
            to 2048.
        tokenizer_path (str): The path to the tokenizer. Defaults to None.
        tokenizer_kwargs (dict): Keyword arguments for the tokenizer.
            Defaults to {}.
@@ -25,55 +29,50 @@ class HFTransformer(BaseModel):
        meta_template (Dict, optional): The model's meta prompt
            template if needed, in case the requirement of injecting or
            wrapping of any meta instructions.
        extract_pred_after_decode (bool): Whether to extract the prediction
            string from the decoded output string, instead of extracting the
            prediction tokens before decoding. Defaults to False.
        batch_padding (bool): If False, inference will be performed in a
            for-loop without batch padding.

    Note:
        About ``extract_pred_after_decode``: Commonly, we should extract the
        prediction tokens before decoding. But for some tokenizers using
        ``sentencepiece``, like LLaMA, this behavior may change the number of
        whitespaces, which is harmful for Python programming tasks.
    """

    def __init__(
            self,
            path: str,
            max_seq_len: int = 2048,
            tokenizer_path: Optional[str] = None,
            tokenizer_kwargs: dict = dict(),
            tokenizer_only: bool = False,
            model_kwargs: dict = dict(device_map='auto'),
            meta_template: Optional[Dict] = [
                dict(
                    role='system',
                    begin='<|System|>:',
                    end='<TOKENS_UNUSED_2>\n'),
                dict(role='user', begin='<|User|>:', end='<eoh>\n'),
                dict(
                    role='assistant',
                    begin='<|Bot|>:',
                    end='<eoa>\n',
                    generate=True)
            ],  # default meta template for InternLM-7b
            extract_pred_after_decode: bool = False,
            batch_padding: bool = False):
    def __init__(self,
                 path: str,
                 tokenizer_path: Optional[str] = None,
                 tokenizer_kwargs: dict = dict(),
                 tokenizer_only: bool = False,
                 model_kwargs: dict = dict(device_map='auto'),
                 meta_template: Optional[Dict] = None,
                 stop_words_id: Union[List[int], int] = None,
                 **kwargs):
        super().__init__(
            path=path,
            max_seq_len=max_seq_len,
            tokenizer_only=tokenizer_only,
            meta_template=meta_template)
            meta_template=meta_template,
            **kwargs)
        if isinstance(stop_words_id, int):
            stop_words_id = [stop_words_id]
        self.gen_params.update(stop_words_id=stop_words_id)
        if self.gen_params['stop_words'] is not None and \
                self.gen_params['stop_words_id'] is not None:
            logger.warning('Both stop_words and stop_words_id are specified, '
                           'only stop_words_id will be used.')

        self._load_tokenizer(
            path=path,
            tokenizer_path=tokenizer_path,
            tokenizer_kwargs=tokenizer_kwargs)
        self.batch_padding = batch_padding
        self.extract_pred_after_decode = extract_pred_after_decode
        if not tokenizer_only:
            self._load_model(path=path, model_kwargs=model_kwargs)

        from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList  # noqa: E501
        self.logits_processor = LogitsProcessorList()
        self.stopping_criteria = StoppingCriteriaList()
        self.prefix_allowed_tokens_fn = None

        stop_words_id = []
        if self.gen_params.get('stop_words_id'):
            stop_words_id = self.gen_params.get('stop_words_id')
        elif self.gen_params.get('stop_words'):
            for sw in self.gen_params.get('stop_words'):
                stop_words_id.append(self.tokenizer(sw)['input_ids'][-1])
        self.additional_eos_token_id = stop_words_id

    def _load_tokenizer(self, path: str, tokenizer_path: Optional[str],
                        tokenizer_kwargs: dict):
        from transformers import AutoTokenizer
@@ -81,60 +80,260 @@ class HFTransformer(BaseModel):
            tokenizer_path if tokenizer_path else path,
            trust_remote_code=True,
            **tokenizer_kwargs)

        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
            if self.tokenizer.eos_token is not None:
                logger.warning(
                    f'Using eos_token_id {self.tokenizer.eos_token} '
                    'as pad_token_id.')
                self.tokenizer.pad_token = self.tokenizer.eos_token
            else:
                from transformers.generation import GenerationConfig
                self.gcfg = GenerationConfig.from_pretrained(path)

                if self.gcfg.pad_token_id is not None:
                    logger.warning(
                        f'Using pad_token_id {self.gcfg.pad_token_id} '
                        'as pad_token_id.')
                    self.tokenizer.pad_token_id = self.gcfg.pad_token_id
                else:
                    raise ValueError(
                        'pad_token_id is not set for this tokenizer. Try to '
                        'set pad_token_id via passing '
                        '`pad_token_id={PAD_TOKEN_ID}` in model_cfg.')

    def _load_model(self, path: str, model_kwargs: dict):
        import torch
        from transformers import AutoModel
        model_kwargs.setdefault('torch_dtype', torch.float16)
        self.model = AutoModel.from_pretrained(
            path, trust_remote_code=True, **model_kwargs)
        self.model.eval()

    def generate(self, inputs: List[str], max_out_len: int,
                 **kwargs) -> List[str]:
        if isinstance(inputs, str):
            inputs = [inputs]
        if self.extract_pred_after_decode:
            prompt_lens = [len(input_) for input_ in inputs]
    def tokenize(self, inputs: str):
        assert isinstance(inputs, str)
        inputs = self.tokenizer(
            inputs, return_tensors='pt', return_length=True)
        return inputs['input_ids'].tolist()

        input_ids = self.tokenizer(
            inputs, truncation=True,
            max_length=self.max_seq_len - max_out_len)['input_ids']
        input_ids = torch.tensor(input_ids, device=self.model.device)
        outputs = self.model.generate(
            input_ids, max_new_tokens=max_out_len, **kwargs)

        if not self.extract_pred_after_decode:
            outputs = outputs[:, input_ids.shape[1]:]

        decodeds = self.tokenizer.batch_decode(
            outputs, skip_special_tokens=True)
        if self.extract_pred_after_decode:
            decodeds = [
                token[len_:] for token, len_ in zip(decodeds, prompt_lens)
            ]

        return decodeds[0]

    def generate_from_template(self, templates, max_out_len: int, **kwargs):
        """Generate completion from a list of templates.
    def generate(
        self,
        inputs: Union[str, List[str]],
        do_sample: bool = True,
        **kwargs,
    ):
        """Return the chat completions in non-stream mode.

        Args:
            templates (List[PromptType]): A list of templates.
            max_out_len (int): The maximum length of the output.
            inputs (Union[str, List[str]]): input texts to be completed.
            do_sample (bool): do sampling if enabled
        Returns:
            (a list of/batched) text/chat completion
        """
        inputs = self.parse_template(templates)
        response = self.generate(inputs, max_out_len=max_out_len, **kwargs)
        return response.replace(
            self.template_parser.roles['assistant']['end'].strip(),
            '').strip()
        for status, chunk, _ in self.stream_generate(inputs, do_sample,
                                                     **kwargs):
            response = chunk
        return response

    def stream_generate(
        self,
        inputs: List[str],
        do_sample: bool = True,
        **kwargs,
    ):
        """Return the chat completions in stream mode.

        Args:
            inputs (Union[str, List[str]]): input texts to be completed.
            do_sample (bool): do sampling if enabled
        Returns:
            tuple(Status, str, int): status, text/chat completion,
            generated token number
        """
        import torch
        from torch import nn
        with torch.no_grad():
            batched = True
            if isinstance(inputs, str):
                inputs = [inputs]
                batched = False
            inputs = self.tokenizer(
                inputs, padding=True, return_tensors='pt', return_length=True)
            input_length = inputs['length']
            for k, v in inputs.items():
                inputs[k] = v.cuda()
            input_ids = inputs['input_ids']
            attention_mask = inputs['attention_mask']
            batch_size = input_ids.shape[0]
            input_ids_seq_length = input_ids.shape[-1]
            generation_config = self.model.generation_config
            generation_config = copy.deepcopy(generation_config)
            new_gen_params = self.update_gen_params(**kwargs)
            generation_config.update(**new_gen_params)
            generation_config.update(**kwargs)
            model_kwargs = generation_config.to_dict()
            model_kwargs['attention_mask'] = attention_mask
            _, eos_token_id = (  # noqa: F841  # pylint: disable=W0612
                generation_config.bos_token_id,
                generation_config.eos_token_id,
            )
            if eos_token_id is None:
                if self.gcfg.eos_token_id is not None:
                    eos_token_id = self.gcfg.eos_token_id
                else:
                    eos_token_id = []
            if isinstance(eos_token_id, int):
                eos_token_id = [eos_token_id]
            if self.additional_eos_token_id is not None:
                eos_token_id.extend(self.additional_eos_token_id)
            eos_token_id_tensor = torch.tensor(eos_token_id).to(
                input_ids.device) if eos_token_id is not None else None
            generation_config.max_length = (
                generation_config.max_new_tokens + input_ids_seq_length)
            # Set generation parameters if not already defined
            logits_processor = self.logits_processor
            stopping_criteria = self.stopping_criteria

            logits_processor = self.model._get_logits_processor(
                generation_config=generation_config,
                input_ids_seq_length=input_ids_seq_length,
                encoder_input_ids=input_ids,
                prefix_allowed_tokens_fn=self.prefix_allowed_tokens_fn,
                logits_processor=logits_processor,
            )

            stopping_criteria = self.model._get_stopping_criteria(
                generation_config=generation_config,
                stopping_criteria=stopping_criteria)
            logits_warper = self.model._get_logits_warper(generation_config)

            unfinished_sequences = input_ids.new(batch_size).fill_(1)
            scores = None
            while True:
                model_inputs = self.model.prepare_inputs_for_generation(
                    input_ids, **model_kwargs)
                # forward pass to get next token
                outputs = self.model(
                    **model_inputs,
                    return_dict=True,
                    output_attentions=False,
                    output_hidden_states=False,
                )

                next_token_logits = outputs.logits[:, -1, :]

                # pre-process distribution
                next_token_scores = logits_processor(input_ids,
                                                     next_token_logits)
                next_token_scores = logits_warper(input_ids, next_token_scores)

                # sample
                probs = nn.functional.softmax(next_token_scores, dim=-1)
                if do_sample:
                    next_tokens = torch.multinomial(
                        probs, num_samples=1).squeeze(1)
                else:
                    next_tokens = torch.argmax(probs, dim=-1)

                # update generated ids, model inputs,
                # and length for next step
                input_ids = torch.cat([input_ids, next_tokens[:, None]],
                                      dim=-1)
                model_kwargs = self.model._update_model_kwargs_for_generation(  # noqa: E501
                    outputs,
                    model_kwargs,
                    is_encoder_decoder=False)
                unfinished_sequences = unfinished_sequences.mul(
                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(
                        eos_token_id_tensor.unsqueeze(1)).prod(dim=0))
                output_token_ids = input_ids.cpu().tolist()
                for i in range(len(output_token_ids)):
                    output_token_ids[i] = output_token_ids[i][:][
                        input_length[i]:]
                    # Find the first occurrence of
                    # an EOS token in the sequence
                    first_eos_idx = next(
                        (idx
                         for idx, token_id in enumerate(output_token_ids[i])
                         if token_id in eos_token_id), None)
                    # If an EOS token is found, only the part
                    # before it is retained
                    if first_eos_idx is not None:
                        output_token_ids[i] = output_token_ids[
                            i][:first_eos_idx]

                response = self.tokenizer.batch_decode(output_token_ids)
                # print(response)
                if not batched:
                    response = response[0]
                yield ModelStatusCode.STREAM_ING, response, None
                # stop when each sentence is finished,
                # or if we exceed the maximum length
                if (unfinished_sequences.max() == 0
                        or stopping_criteria(input_ids, scores)):
                    break
            yield ModelStatusCode.END, response, None

    def stream_chat(
        self,
        inputs: List[dict],
        do_sample: bool = True,
        **kwargs,
    ):
        """Return the chat completions in stream mode.

        Args:
            inputs (List[dict]): input messages to be completed.
            do_sample (bool): do sampling if enabled
        Returns:
            the text/chat completion
        """
        prompt = self.template_parser(inputs)
        yield from self.stream_generate(prompt, do_sample, **kwargs)


class HFTransformerCasualLM(HFTransformer):

    def _load_model(self, path: str, model_kwargs: dict):
        import torch
        from transformers import AutoModelForCausalLM
        model_kwargs.setdefault('torch_dtype', torch.float16)
        self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs)
        self.model = AutoModelForCausalLM.from_pretrained(
            path, trust_remote_code=True, **model_kwargs)
        self.model.eval()


class HFTransformerChat(HFTransformerCasualLM):

    def __init__(self, template_parser=APITemplateParser, **kwargs):
        super().__init__(template_parser=template_parser, **kwargs)

    def chat(self,
             inputs: Union[List[dict], List[List[dict]]],
             do_sample: bool = True,
             **kwargs):
        """Return the chat completions in non-stream mode.

        Args:
            inputs (Union[List[dict], List[List[dict]]]): input messages to be completed.
            do_sample (bool): do sampling if enabled
        Returns:
            the text/chat completion
        """
        # handle batch inference with a vanilla for loop
        if isinstance(inputs[0], list):
            resps = []
            for input in inputs:
                resps.append(self.chat(input, do_sample, **kwargs))
            return resps
        prompt = self.template_parser(inputs)
        query = prompt[-1]['content']
        history = prompt[:-1]
        try:
            response, history = self.model.chat(
                self.tokenizer, query, history=history)
        except Exception as e:
            # handle over-length input error
            logger.warning(str(e))
            response = ''
        return response

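A usage sketch of the streaming interface above (the model id is an assumption for illustration, and a CUDA device is required since inputs are moved with `.cuda()`):

```python
llm = HFTransformerCasualLM(
    path='internlm/internlm2-chat-7b',  # hypothetical model id
    meta_template=INTERNLM2_META)
messages = [dict(role='user', content='Hello!')]
for status, chunk, _ in llm.stream_chat(messages, do_sample=False):
    print(status, chunk)  # progressively longer completions, then END
```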
452
lagent/llms/lmdepoly_wrapper.py
Normal file
@@ -0,0 +1,452 @@
from typing import List, Optional, Union

from lagent.llms.base_llm import BaseModel
from lagent.schema import ModelStatusCode
from lagent.utils.util import filter_suffix


class TritonClient(BaseModel):
    """A wrapper of LMDeploy's TritonClient for LLMs.

    Args:
        tritonserver_addr (str): the address in format "ip:port" of
            triton inference server
        model_name (str): the name of the model
        session_len (int): the context size
        max_tokens (int): the expected generated token numbers
    """

    def __init__(self,
                 tritonserver_addr: str,
                 model_name: str,
                 session_len: int = 32768,
                 log_level: str = 'WARNING',
                 **kwargs):
        super().__init__(path=None, **kwargs)
        from lmdeploy.serve.turbomind.chatbot import Chatbot, StatusCode
        self.state_map = {
            StatusCode.TRITON_STREAM_END: ModelStatusCode.END,
            StatusCode.TRITON_SERVER_ERR: ModelStatusCode.SERVER_ERR,
            StatusCode.TRITON_SESSION_CLOSED: ModelStatusCode.SESSION_CLOSED,
            StatusCode.TRITON_STREAM_ING: ModelStatusCode.STREAM_ING,
            StatusCode.TRITON_SESSION_OUT_OF_LIMIT:
            ModelStatusCode.SESSION_OUT_OF_LIMIT,
            StatusCode.TRITON_SESSION_INVALID_ARG:
            ModelStatusCode.SESSION_INVALID_ARG,
            StatusCode.TRITON_SESSION_READY: ModelStatusCode.SESSION_READY
        }
        self.chatbot = Chatbot(
            tritonserver_addr=tritonserver_addr,
            model_name=model_name,
            session_len=session_len,
            log_level=log_level,
            **kwargs)

    def generate(self,
                 inputs: Union[str, List[str]],
                 session_id: int = 2967,
                 request_id: str = '',
                 sequence_start: bool = True,
                 sequence_end: bool = True,
                 skip_special_tokens: bool = False,
                 **kwargs):
        """Start a new round conversation of a session. Return the chat
        completions in non-stream mode.

        Args:
            inputs (str, List[str]): user's prompt(s) in this round
            session_id (int): the identical id of a session
            request_id (str): the identical id of this round conversation
            sequence_start (bool): start flag of a session
            sequence_end (bool): end flag of a session
            skip_special_tokens (bool): Whether or not to remove special
                tokens in the decoding. Defaults to False.
        Returns:
            (a list of/batched) text/chat completion
        """
        from lmdeploy.serve.turbomind.chatbot import Session, get_logger
        if isinstance(inputs, str):
            inputs = [inputs]
        prompt = inputs

        assert isinstance(session_id, int), \
            f'INT session id is required, but got {type(session_id)}'

        self.chatbot.cfg = self._update_gen_params(**kwargs)
        max_new_tokens = self.chatbot.cfg.max_new_tokens

        logger = get_logger('service.ft', log_level=self.chatbot.log_level)
        logger.info(f'session {session_id}, request_id {request_id}, '
                    f'max_out_len {max_new_tokens}')

        if self.chatbot._session is None:
            sequence_start = True
            self.chatbot._session = Session(session_id=session_id)
        elif self.chatbot._session.status == 0:
            logger.error(f'session {session_id} has been ended. Please set '
                         f'`sequence_start` be True if you want to restart it')
            return ''

        self.chatbot._session.status = 1
        self.chatbot._session.request_id = request_id
        self.chatbot._session.response = ''

        status, res, _ = None, '', 0
        for status, res, _ in self.chatbot._stream_infer(
                self.chatbot._session,
                prompt,
                max_new_tokens,
                sequence_start,
                sequence_end,
                skip_special_tokens=skip_special_tokens):
            status = self.state_map.get(status)
        if status < ModelStatusCode.END:
            return ''
        elif status == ModelStatusCode.END:
            self.chatbot._session.histories = (
                self.chatbot._session.histories +
                self.chatbot._session.prompt +
                self.chatbot._session.response)
            # remove stop_words
            res = filter_suffix(res, self.gen_params.get('stop_words'))
            return res

    def stream_chat(self,
                    inputs: List[dict],
                    session_id: int = 2967,
                    request_id: str = '',
                    sequence_start: bool = True,
                    sequence_end: bool = True,
                    skip_special_tokens: bool = False,
                    **kwargs):
        """Start a new round conversation of a session. Return the chat
        completions in stream mode.

        Args:
            session_id (int): the identical id of a session
            inputs (List[dict]): user's inputs in this round conversation
            request_id (str): the identical id of this round conversation
            sequence_start (bool): start flag of a session
            sequence_end (bool): end flag of a session
            skip_special_tokens (bool): Whether or not to remove special
                tokens in the decoding. Defaults to False.
        Returns:
            tuple(Status, str, int): status, text/chat completion,
            generated token number
        """
        from lmdeploy.serve.turbomind.chatbot import Session, get_logger
        assert isinstance(session_id, int), \
            f'INT session id is required, but got {type(session_id)}'

        self.chatbot.cfg = self._update_gen_params(**kwargs)
        max_new_tokens = self.chatbot.cfg.max_new_tokens

        logger = get_logger('service.ft', log_level=self.chatbot.log_level)
        logger.info(f'session {session_id}, request_id {request_id}, '
                    f'max_out_len {max_new_tokens}')

        if self.chatbot._session is None:
            sequence_start = True
            self.chatbot._session = Session(session_id=session_id)
        elif self.chatbot._session.status == 0:
            logger.error(f'session {session_id} has been ended. Please set '
                         f'`sequence_start` be True if you want to restart it')
            return ModelStatusCode.SESSION_CLOSED, '', 0

        self.chatbot._session.status = 1
        self.chatbot._session.request_id = request_id
        self.chatbot._session.response = ''

        prompt = self.template_parser(inputs)
        status, res, _ = None, '', 0
        for status, res, _ in self.chatbot._stream_infer(
                self.chatbot._session,
                prompt,
                max_new_tokens,
                sequence_start,
                sequence_end,
                skip_special_tokens=skip_special_tokens):
            status = self.state_map.get(status)
            # The stop symbol also appears in the output of the last
            # STREAM_ING state.
            res = filter_suffix(res, self.gen_params.get('stop_words'))
            if status < ModelStatusCode.END:
                return status, res, _
            elif status == ModelStatusCode.END:  # remove stop_words
                self.chatbot._session.histories = (
                    self.chatbot._session.histories +
                    self.chatbot._session.prompt +
                    self.chatbot._session.response)
                yield status, res, _
                break
            else:
                yield status, res, _

    def _update_gen_params(self, **kwargs):
        import mmengine
        new_gen_params = self.update_gen_params(**kwargs)
        self.gen_params['stop_words'] = new_gen_params.pop('stop_words')
        stop_words = self.chatbot._stop_words(
            self.gen_params.get('stop_words'))
        cfg = mmengine.Config(
            dict(
                session_len=self.chatbot.model.session_len,
                stop_words=stop_words,
                bad_words=self.chatbot.cfg.bad_words,
                **new_gen_params))
        return cfg

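A hedged usage sketch of `TritonClient` (the address and model name depend entirely on your Triton deployment and are hypothetical here):

```python
client = TritonClient(
    tritonserver_addr='0.0.0.0:33337',   # hypothetical server address
    model_name='internlm2-chat-7b')      # hypothetical model name
print(client.generate('Hello', session_id=1))
```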
class LMDeployPipeline(BaseModel):
    """

    Args:
        path (str): The path to the model.
            It could be one of the following options:
                - i) A local directory path of a turbomind model which is
                    converted by `lmdeploy convert` command or downloaded
                    from ii) and iii).
                - ii) The model_id of a lmdeploy-quantized model hosted
                    inside a model repo on huggingface.co, such as
                    "InternLM/internlm-chat-20b-4bit",
                    "lmdeploy/llama2-chat-70b-4bit", etc.
                - iii) The model_id of a model hosted inside a model repo
                    on huggingface.co, such as "internlm/internlm-chat-7b",
                    "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
                    and so on.
        model_name (str): needed when model_path is a pytorch model on
            huggingface.co, such as "internlm-chat-7b",
            "Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on.
        tp (int): tensor parallel
        pipeline_cfg (dict): config of pipeline
    """

    def __init__(self,
                 path: str,
                 model_name: Optional[str] = None,
                 tp: int = 1,
                 pipeline_cfg=dict(),
                 **kwargs):

        super().__init__(path=path, **kwargs)
        from lmdeploy import pipeline
        self.model = pipeline(
            model_path=self.path, model_name=model_name, tp=tp, **pipeline_cfg)

    def generate(self,
                 inputs: Union[str, List[str]],
                 do_preprocess: bool = None,
                 skip_special_tokens: bool = False,
                 **kwargs):
        """Return the chat completions in non-stream mode.

        Args:
            inputs (Union[str, List[str]]): input texts to be completed.
            do_preprocess (bool): whether to pre-process the messages.
                Defaults to True, which means chat_template will be applied.
            skip_special_tokens (bool): Whether or not to remove special
                tokens in the decoding. Defaults to False.
        Returns:
            (a list of/batched) text/chat completion
        """
        from lmdeploy.messages import GenerationConfig

        batched = True
        if isinstance(inputs, str):
            inputs = [inputs]
            batched = False
        prompt = inputs
        gen_params = self.update_gen_params(**kwargs)
        gen_config = GenerationConfig(
            skip_special_tokens=skip_special_tokens, **gen_params)
        response = self.model.batch_infer(
            prompt, gen_config=gen_config, do_preprocess=do_preprocess)
        response = [resp.text for resp in response]
        # remove stop_words
        response = filter_suffix(response, self.gen_params.get('stop_words'))
        if batched:
            return response
        return response[0]


class LMDeployServer(BaseModel):
    """

    Args:
        path (str): The path to the model.
            It could be one of the following options:
                - i) A local directory path of a turbomind model which is
                    converted by `lmdeploy convert` command or downloaded
                    from ii) and iii).
                - ii) The model_id of a lmdeploy-quantized model hosted
                    inside a model repo on huggingface.co, such as
                    "InternLM/internlm-chat-20b-4bit",
                    "lmdeploy/llama2-chat-70b-4bit", etc.
                - iii) The model_id of a model hosted inside a model repo
                    on huggingface.co, such as "internlm/internlm-chat-7b",
                    "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
                    and so on.
        model_name (str): needed when model_path is a pytorch model on
            huggingface.co, such as "internlm-chat-7b",
            "Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on.
        server_name (str): host ip for serving
        server_port (int): server port
        tp (int): tensor parallel
        log_level (str): set log level whose value among
            [CRITICAL, ERROR, WARNING, INFO, DEBUG]
    """

    def __init__(self,
                 path: str,
                 model_name: Optional[str] = None,
                 server_name: str = '0.0.0.0',
                 server_port: int = 23333,
                 tp: int = 1,
                 log_level: str = 'WARNING',
                 serve_cfg=dict(),
                 **kwargs):
        super().__init__(path=path, **kwargs)
        self.model_name = model_name
        # TODO get_logger issue in multi processing
        import lmdeploy
        self.client = lmdeploy.serve(
            model_path=self.path,
            model_name=model_name,
            server_name=server_name,
            server_port=server_port,
            tp=tp,
            log_level=log_level,
            **serve_cfg)

    def generate(self,
                 inputs: Union[str, List[str]],
                 session_id: int = 2967,
                 sequence_start: bool = True,
                 sequence_end: bool = True,
                 ignore_eos: bool = False,
                 skip_special_tokens: Optional[bool] = False,
                 timeout: int = 30,
                 **kwargs) -> List[str]:
        """Start a new round conversation of a session. Return the chat
        completions in non-stream mode.

        Args:
            inputs (str, List[str]): user's prompt(s) in this round
            session_id (int): the identical id of a session
            sequence_start (bool): start flag of a session
            sequence_end (bool): end flag of a session
            ignore_eos (bool): indicator for ignoring eos
            skip_special_tokens (bool): Whether or not to remove special
                tokens in the decoding. Defaults to False.
            timeout (int): max time to wait for response
        Returns:
            (a list of/batched) text/chat completion
        """

        batched = True
        if isinstance(inputs, str):
            inputs = [inputs]
            batched = False

        gen_params = self.update_gen_params(**kwargs)
        max_new_tokens = gen_params.pop('max_new_tokens')
        gen_params.update(max_tokens=max_new_tokens)

        resp = [''] * len(inputs)
        for text in self.client.completions_v1(
                self.model_name,
                inputs,
                session_id=session_id,
                sequence_start=sequence_start,
                sequence_end=sequence_end,
                stream=False,
                ignore_eos=ignore_eos,
                skip_special_tokens=skip_special_tokens,
                timeout=timeout,
                **gen_params):
            resp = [
                resp[i] + item['text']
                for i, item in enumerate(text['choices'])
            ]
        # remove stop_words
        resp = filter_suffix(resp, self.gen_params.get('stop_words'))
        if not batched:
            return resp[0]
        return resp

    def stream_chat(self,
                    inputs: List[dict],
                    session_id=0,
                    sequence_start: bool = True,
                    sequence_end: bool = True,
                    stream: bool = True,
                    ignore_eos: bool = False,
                    skip_special_tokens: Optional[bool] = False,
                    timeout: int = 30,
                    **kwargs):
        """Start a new round conversation of a session. Return the chat
        completions in stream mode.

        Args:
            session_id (int): the identical id of a session
            inputs (List[dict]): user's inputs in this round conversation
            sequence_start (bool): start flag of a session
            sequence_end (bool): end flag of a session
            stream (bool): return in a streaming format if enabled
            ignore_eos (bool): indicator for ignoring eos
            skip_special_tokens (bool): Whether or not to remove special
                tokens in the decoding. Defaults to False.
            timeout (int): max time to wait for response
        Returns:
            tuple(Status, str, int): status, text/chat completion,
            generated token number
        """
        gen_params = self.update_gen_params(**kwargs)
        max_new_tokens = gen_params.pop('max_new_tokens')
        gen_params.update(max_tokens=max_new_tokens)
        prompt = self.template_parser(inputs)

        resp = ''
        finished = False
        stop_words = self.gen_params.get('stop_words')
        for text in self.client.completions_v1(
                self.model_name,
                prompt,
                session_id=session_id,
                sequence_start=sequence_start,
                sequence_end=sequence_end,
                stream=stream,
                ignore_eos=ignore_eos,
                skip_special_tokens=skip_special_tokens,
                timeout=timeout,
                **gen_params):
            resp += text['choices'][0]['text']
            if not resp:
                continue
            # remove stop_words
            for sw in stop_words:
                if sw in resp:
                    resp = filter_suffix(resp, stop_words)
                    finished = True
                    break
            yield ModelStatusCode.STREAM_ING, resp, None
            if finished:
                break
        yield ModelStatusCode.END, resp, None


class LMDeployClient(LMDeployServer):
    """

    Args:
        url (str): communicating address 'http://<ip>:<port>' of
            api_server
        model_name (str): needed when model_path is a pytorch model on
            huggingface.co, such as "internlm-chat-7b",
            "Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on.
    """

    def __init__(self, url: str, model_name: str, **kwargs):
        BaseModel.__init__(self, path=url, **kwargs)
        from lmdeploy.serve.openai.api_client import APIClient
        self.client = APIClient(url)
        self.model_name = model_name
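The three classes cover in-process inference, self-hosted serving, and a client for an already-running api_server, respectively. A sketch under assumed model ids:

```python
# In-process pipeline (loads the model locally):
pipe = LMDeployPipeline(path='internlm/internlm2-chat-7b', tp=1)
print(pipe.generate(['Hello', 'Bonjour']))

# Client for an api_server already listening on the default port:
client = LMDeployClient(url='http://127.0.0.1:23333',
                        model_name='internlm2-chat-7b')
print(client.generate('Hello'))
```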
40
lagent/llms/meta_template.py
Normal file
@@ -0,0 +1,40 @@
INTERNLM2_META = [
    dict(
        role='system',
        begin=dict(
            with_name='<|im_start|>system name={name}\n',
            without_name='<|im_start|>system\n',
            name={
                'interpreter': '<|interpreter|>',
                'plugin': '<|plugin|>',
            }),
        end='<|im_end|>\n',
    ),
    dict(
        role='user',
        begin=dict(
            with_name='<|im_start|>user name={name}\n',
            without_name='<|im_start|>user\n',
        ),
        end='<|im_end|>\n'),
    dict(
        role='assistant',
        begin=dict(
            with_name='<|im_start|>assistant name={name}\n',
            without_name='<|im_start|>assistant\n',
            name={
                'interpreter': '<|interpreter|>',
                'plugin': '<|plugin|>',
            }),
        end='<|im_end|>\n'),
    dict(
        role='environment',
        begin=dict(
            with_name='<|im_start|>environment name={name}\n',
            without_name='<|im_start|>environment\n',
            name={
                'interpreter': '<|interpreter|>',
                'plugin': '<|plugin|>',
            }),
        end='<|im_end|>\n'),
]

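To see how the `with_name`/`without_name` variants above render, consider two messages (the expansion follows `LMTemplateParser._format_begin` from the base_llm diff):

```python
# name present and registered -> 'with_name' begin, alias substituted:
#   dict(role='system', content='ok', name='plugin')
#   => '<|im_start|>system name=<|plugin|>\nok<|im_end|>\n'
# no name -> 'without_name' begin:
#   dict(role='user', content='hi')
#   => '<|im_start|>user\nhi<|im_end|>\n'
```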
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
# from concurrent.futures import ThreadPoolExecutor
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from logging import getLogger
|
||||
from threading import Lock
|
||||
from typing import Dict, List, Optional, Union
|
||||
@@ -18,9 +18,6 @@ class GPTAPI(BaseAPIModel):
|
||||
|
||||
Args:
|
||||
model_type (str): The name of OpenAI's model.
|
||||
max_seq_len (int): The maximum allowed sequence length of a model.
|
||||
Note that the length of prompt + generated tokens shall not exceed
|
||||
this value. Defaults to 2048.
|
||||
query_per_second (int): The maximum queries allowed per second
|
||||
between two consecutive calls of the API. Defaults to 1.
|
||||
retry (int): Number of retires if the API call fails. Defaults to 2.
|
||||
@@ -38,16 +35,14 @@ class GPTAPI(BaseAPIModel):
|
||||
wrapping of any meta instructions.
|
||||
openai_api_base (str): The base url of OpenAI's API. Defaults to
|
||||
'https://api.openai.com/v1/chat/completions'.
|
||||
temperature (float, optional): What sampling temperature to use.
|
||||
If not None, will override the temperature in the `generate()`
|
||||
call. Defaults to None.
|
||||
gen_params: Default generation configuration which could be overridden
|
||||
on the fly of generation.
|
||||
"""
|
||||
|
||||
is_api: bool = True
|
||||
|
||||
def __init__(self,
|
||||
model_type: str = 'gpt-3.5-turbo',
|
||||
max_seq_len: int = 4096,
|
||||
query_per_second: int = 1,
|
||||
retry: int = 2,
|
||||
key: Union[str, List[str]] = 'ENV',
|
||||
@@ -58,16 +53,14 @@ class GPTAPI(BaseAPIModel):
|
||||
dict(role='assistant', api_role='assistant')
|
||||
],
|
||||
openai_api_base: str = OPENAI_API_BASE,
|
||||
temperature: Optional[float] = None):
|
||||
|
||||
**gen_params):
|
||||
super().__init__(
|
||||
model_type=model_type,
|
||||
max_seq_len=max_seq_len,
|
||||
meta_template=meta_template,
|
||||
query_per_second=query_per_second,
|
||||
retry=retry)
|
||||
retry=retry,
|
||||
**gen_params)
|
||||
self.logger = getLogger(__name__)
|
||||
self.temperature = temperature
|
||||
|
||||
if isinstance(key, str):
|
||||
self.keys = [os.getenv('OPENAI_API_KEY') if key == 'ENV' else key]
|
||||
@@ -97,67 +90,57 @@ class GPTAPI(BaseAPIModel):
|
||||
context_window = 8192
|
||||
self.context_window = context_window
|
||||
|
||||
def generate(
|
||||
def chat(
|
||||
self,
|
||||
inputs: Union[List, str],
|
||||
max_out_len: int = 512,
|
||||
temperature: float = 0.7,
|
||||
) -> List[str]:
|
||||
"""Generate results given a list of inputs.
|
||||
inputs: Union[List[dict], List[List[dict]]],
|
||||
**gen_params,
|
||||
) -> Union[str, List[str]]:
|
||||
"""Generate responses given the contexts.
|
||||
|
||||
Args:
|
||||
inputs (List[str or List]): A list of strings or PromptDicts.
|
||||
The PromptDict should be organized in OpenCompass'
|
||||
API format.
|
||||
max_out_len (int): The maximum length of the output.
|
||||
temperature (float): What sampling temperature to use,
|
||||
between 0 and 2. Higher values like 0.8 will make the output
|
||||
more random, while lower values like 0.2 will make it more
|
||||
focused and deterministic. Defaults to 0.7.
|
||||
inputs (Union[List[dict], List[List[dict]]]): a list of messages
|
||||
or list of lists of messages
|
||||
gen_params: additional generation configuration
|
||||
|
||||
Returns:
|
||||
List[str]: A list of generated strings.
|
||||
Union[str, List[str]]: generated string(s)
|
||||
"""
|
||||
if self.temperature is not None:
|
||||
temperature = self.temperature
|
||||
return self._generate(inputs, max_out_len, temperature)
|
||||
assert isinstance(inputs, list)
|
||||
if 'max_tokens' in gen_params:
|
||||
raise NotImplementedError('unsupported parameter: max_tokens')
|
||||
gen_params = {**self.gen_params, **gen_params}
|
||||
with ThreadPoolExecutor(max_workers=20) as executor:
|
||||
tasks = [
|
||||
executor.submit(self._chat, messages, **gen_params)
|
||||
for messages in (
|
||||
[inputs] if isinstance(inputs[0], dict) else inputs)
|
||||
]
|
||||
ret = [task.result() for task in tasks]
|
||||
return ret[0] if isinstance(inputs[0], dict) else ret

-    def _generate(self, input: str or List, max_out_len: int,
-                  temperature: float) -> str:
-        """Generate results given a list of inputs.
+    def _chat(self, messages: List[dict], **gen_params) -> str:
+        """Generate completion from a list of templates.

         Args:
-            inputs (str or List): A string or PromptDict.
-                The PromptDict should be organized in OpenCompass'
-                API format.
-            max_out_len (int): The maximum length of the output.
-            temperature (float): What sampling temperature to use,
-                between 0 and 2. Higher values like 0.8 will make the output
-                more random, while lower values like 0.2 will make it more
-                focused and deterministic.
+            messages (List[dict]): a list of prompt dictionaries
+            gen_params: additional generation configuration

         Returns:
             str: The generated string.
         """
-        assert isinstance(input, (str, list, dict))
-
-        if isinstance(input, str):
-            messages = [{'role': 'user', 'content': input}]
-        elif isinstance(input, dict):
-            messages = [input]
-        else:
-            messages = input
+        assert isinstance(messages, list)
+        gen_params = gen_params.copy()

         # Hold out 100 tokens due to potential errors in tiktoken calculation
-        max_out_len = min(
-            max_out_len,
-            self.context_window - self.get_token_len(str(input)) - 100)
-        if max_out_len <= 0:
+        max_tokens = min(
+            gen_params.pop('max_new_tokens'),
+            self.context_window - len(self.tokenize(str(input))) - 100)
+        if max_tokens <= 0:
             return ''

         max_num_retries = 0
         while max_num_retries < self.retry:
-            self.wait()
+            self._wait()

             with Lock():
                 if len(self.invalid_keys) == len(self.keys):
@@ -190,10 +173,11 @@ class GPTAPI(BaseAPIModel):
                 data = dict(
                     model=self.model_type,
                     messages=messages,
-                    max_tokens=max_out_len,
+                    max_tokens=max_tokens,
                     n=1,
-                    stop=None,
-                    temperature=temperature,
+                    stop=gen_params.pop('stop_words'),
+                    frequency_penalty=gen_params.pop('repetition_penalty'),
+                    **gen_params,
                 )
                 raw_response = requests.post(
                     self.url, headers=header, data=json.dumps(data))
@@ -225,18 +209,16 @@ class GPTAPI(BaseAPIModel):
                 f'{max_num_retries} times. Check the logs for '
                 'details.')

-    def get_token_len(self, prompt: str) -> int:
-        """Get lengths of the tokenized string. Only English and Chinese
-        characters are counted for now. Users are encouraged to override this
-        method if more accurate length is needed.
+    def tokenize(self, prompt: str) -> list:
+        """Tokenize the input prompt.

         Args:
             prompt (str): Input string.

         Returns:
-            int: Length of the input tokens
+            list: token ids
         """
         import tiktoken
         self.tiktoken = tiktoken
         enc = self.tiktoken.encoding_for_model(self.model_type)
-        return len(enc.encode(prompt))
+        return enc.encode(prompt)
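`tokenize` now returns the raw token ids rather than their count. A quick sketch with the tiktoken package (the exact ids vary by model encoding):

```python
import tiktoken  # third-party; pip install tiktoken

enc = tiktoken.encoding_for_model('gpt-3.5-turbo')
token_ids = enc.encode('Hello, world!')
print(token_ids)       # a list of ints; this is what tokenize() now returns
print(len(token_ids))  # the old get_token_len() result
```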

71 lagent/llms/vllm_wrapper.py Normal file
@@ -0,0 +1,71 @@
+from typing import List, Union
+
+from lagent.llms.base_llm import BaseModel
+from lagent.utils.util import filter_suffix
+
+
+class VllmModel(BaseModel):
+    """
+    A wrapper of vLLM model.
+
+    Args:
+        path (str): The path to the model.
+            It could be one of the following options:
+                - i) A local directory path of a huggingface model.
+                - ii) The model_id of a model hosted inside a model repo
+                    on huggingface.co, such as "internlm/internlm-chat-7b",
+                    "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
+                    and so on.
+        tp (int): tensor parallel
+        vllm_cfg (dict): Other kwargs for vllm model initialization.
+    """
+
+    def __init__(self, path: str, tp: int = 1, vllm_cfg=dict(), **kwargs):
+
+        super().__init__(path=path, **kwargs)
+        from vllm import LLM
+        self.model = LLM(
+            model=self.path,
+            trust_remote_code=True,
+            tensor_parallel_size=tp,
+            **vllm_cfg)
+
+    def generate(self,
+                 inputs: Union[str, List[str]],
+                 do_preprocess: bool = None,
+                 skip_special_tokens: bool = False,
+                 **kwargs):
+        """Return the chat completions in non-stream mode.
+
+        Args:
+            inputs (Union[str, List[str]]): input texts to be completed.
+            do_preprocess (bool): whether pre-process the messages. Default to
+                True, which means chat_template will be applied.
+            skip_special_tokens (bool): Whether or not to remove special tokens
+                in the decoding. Default to be False.
+        Returns:
+            (a list of/batched) text/chat completion
+        """
+        from vllm import SamplingParams
+
+        batched = True
+        if isinstance(inputs, str):
+            inputs = [inputs]
+            batched = False
+        prompt = inputs
+        gen_params = self.update_gen_params(**kwargs)
+        max_new_tokens = gen_params.pop('max_new_tokens')
+        stop_words = gen_params.pop('stop_words')
+
+        sampling_config = SamplingParams(
+            skip_special_tokens=skip_special_tokens,
+            max_tokens=max_new_tokens,
+            stop=stop_words,
+            **gen_params)
+        response = self.model.generate(prompt, sampling_params=sampling_config)
+        response = [resp.outputs[0].text for resp in response]
+        # remove stop_words
+        response = filter_suffix(response, self.gen_params.get('stop_words'))
+        if batched:
+            return response
+        return response[0]
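A hedged usage sketch of the new wrapper; the model path is illustrative, vllm must be installed, and the generation knobs (`max_new_tokens`, `stop_words`) are assumed to come from the `BaseModel` defaults via `update_gen_params`:

```python
from lagent.llms.vllm_wrapper import VllmModel

# path is illustrative; any local dir or Hugging Face model id works
llm = VllmModel(path='internlm/internlm-chat-7b', tp=1)

print(llm.generate('Write a haiku about mountains.'))  # single prompt -> str
print(llm.generate(['1 + 1 =', '2 + 2 =']))            # batch -> List[str]
```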

@@ -1,12 +1,10 @@
 from dataclasses import asdict, dataclass, field
-from enum import Enum
-from typing import Dict, List, Optional, Union
-
-from lagent.utils import is_module_exist
+from enum import IntEnum
+from typing import List, Optional, Union


 def enum_dict_factory(inputs):
-    inputs = [(i[0], i[-1].value) if isinstance(i[-1], Enum) else i
+    inputs = [(i[0], i[-1].value) if isinstance(i[-1], IntEnum) else i
               for i in inputs]
     return dict(inputs)


@@ -15,7 +13,7 @@ def dataclass2dict(data):
     return asdict(data, dict_factory=enum_dict_factory)
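The move from `Enum` to `IntEnum` changes what `enum_dict_factory` unwraps. A toy sketch (the `Plain`, `Strict`, and `Ret` names are made up for illustration):

```python
from dataclasses import asdict, dataclass
from enum import Enum, IntEnum


class Plain(Enum):
    OK = 0


class Strict(IntEnum):
    OK = 0


def enum_dict_factory(inputs):
    # only IntEnum members are replaced by their plain .value
    inputs = [(i[0], i[-1].value) if isinstance(i[-1], IntEnum) else i
              for i in inputs]
    return dict(inputs)


@dataclass
class Ret:
    a: Plain = Plain.OK
    b: Strict = Strict.OK


print(asdict(Ret(), dict_factory=enum_dict_factory))
# roughly {'a': <Plain.OK: 0>, 'b': 0} -- only the IntEnum member unwraps
```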


-class ActionStatusCode(int, Enum):
+class ActionStatusCode(IntEnum):
     ING = 1
     SUCCESS = 0
     HTTP_ERROR = -1000  # http error
@@ -23,7 +21,7 @@ class ActionStatusCode(int, Enum):
     API_ERROR = -1002  # unknown API error


-class ActionValidCode(int, Enum):
+class ActionValidCode(IntEnum):
     FINISH = 1
     OPEN = 0
     CLOSED = -1
@@ -33,47 +31,58 @@ class ActionValidCode(int, Enum):

 @dataclass
 class ActionReturn:
-    args: Dict
+    args: Optional[dict] = None
     url: Optional[str] = None
     type: Optional[str] = None
-    result: Optional[str] = None
+    result: Optional[List[dict]] = None
     errmsg: Optional[str] = None
     state: Union[ActionStatusCode, int] = ActionStatusCode.SUCCESS
     thought: Optional[str] = None
     valid: Optional[ActionValidCode] = ActionValidCode.OPEN
+
+    def format_result(self) -> str:
+        """Concatenate items in result."""
+        result = []
+        for item in self.result or []:
+            if item['type'] == 'text':
+                result.append(item['content'])
+            else:
+                result.append(f"[{item['type']}]({item['content']})")
+        result = '\n'.join(result)
+        return result
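Illustrative use of the new structured `result` field (the items below are invented): text items are emitted verbatim, anything else becomes a Markdown-style link.

```python
from lagent.schema import ActionReturn

ret = ActionReturn(
    args={'query': 'weather'},
    result=[
        {'type': 'text', 'content': 'Sunny, 25°C'},
        {'type': 'image', 'content': 'figures/forecast.png'},
    ])
print(ret.format_result())
# Sunny, 25°C
# [image](figures/forecast.png)
```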


-class AgentStatusCode(Enum):
-    END = 0  # end of streaming

+# must inherit from int so that asdict can convert AgentStatusCode to int
+class ModelStatusCode(IntEnum):
+    END = 0  # end of streaming; return the history of this round
     STREAM_ING = 1  # response is in streaming
     SERVER_ERR = -1  # triton server's error
     SESSION_CLOSED = -2  # session has been closed
     SESSION_OUT_OF_LIMIT = -3  # request length out of limit
-    CMD = 2  # return command
     SESSION_INVALID_ARG = -4  # invalid argument
-    SESSION_READY = 3  # session is ready for inference
+    SESSION_READY = 2  # session is ready for inference


+class AgentStatusCode(IntEnum):
+    END = 0  # end of streaming; return the history of this round
+    STREAM_ING = 1  # response is in streaming
+    SERVER_ERR = -1  # triton server's error
+    SESSION_CLOSED = -2  # session has been closed
+    SESSION_OUT_OF_LIMIT = -3  # request length out of limit
+    SESSION_INVALID_ARG = -4  # invalid argument
+    SESSION_READY = 2  # session is ready for inference
+    PLUGIN_START = 3  # start tool
+    PLUGIN_END = 4  # finish tool
+    PLUGIN_RETURN = 5  # finish tool
+    CODING = 6  # start python
+    CODE_END = 7  # end python
+    CODE_RETURN = 8  # python return


 @dataclass
 class AgentReturn:
+    state: Union[AgentStatusCode, int] = AgentStatusCode.END
     actions: List[ActionReturn] = field(default_factory=list)
     response: str = ''
+    inner_steps: List = field(default_factory=list)
+    errmsg: Optional[str] = None


-if is_module_exist('lmdeploy'):
-    from lmdeploy.serve.turbomind.chatbot import StatusCode
-    STATE_MAP = {
-        StatusCode.TRITON_STREAM_END: AgentStatusCode.END,
-        StatusCode.TRITON_SERVER_ERR: AgentStatusCode.SERVER_ERR,
-        StatusCode.TRITON_SESSION_CLOSED: AgentStatusCode.SESSION_CLOSED,
-        StatusCode.TRITON_STREAM_ING: AgentStatusCode.STREAM_ING,
-        StatusCode.TRITON_SESSION_OUT_OF_LIMIT:
-        AgentStatusCode.SESSION_OUT_OF_LIMIT,
-        StatusCode.TRITON_SESSION_INVALID_ARG:
-        AgentStatusCode.SESSION_INVALID_ARG,
-        StatusCode.TRITON_SESSION_READY: AgentStatusCode.SESSION_READY
-    }
-else:
-    STATE_MAP = {}
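With the status codes as `IntEnum`, `dataclass2dict` from the same module serializes an `AgentReturn` straight to plain ints, which is exactly what the translated comment above promises. A short sketch:

```python
from lagent.schema import AgentReturn, AgentStatusCode, dataclass2dict

ret = AgentReturn(state=AgentStatusCode.END, response='hi')
print(dataclass2dict(ret)['state'])     # 0 -- a plain int, not an enum member
print(dataclass2dict(ret)['response'])  # 'hi'
```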

@@ -1,8 +1,8 @@
-import importlib
+from importlib.util import find_spec


 def is_module_exist(module_name):
-    spec = importlib.util.find_spec(module_name)
+    spec = find_spec(module_name)
     if spec is None:
         return False
     else:
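The hunk is cut off after `else:`; presumably the remaining branch returns `True`. An equivalent one-liner sketch under that assumption:

```python
from importlib.util import find_spec


def is_module_exist(module_name):
    # find_spec returns None when the module cannot be located
    return find_spec(module_name) is not None


print(is_module_exist('json'))          # True
print(is_module_exist('not_a_module'))  # False
```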
31 lagent/utils/util.py Normal file
@@ -0,0 +1,31 @@
+from typing import List, Optional, Union
+
+
+def filter_suffix(response: Union[str, List[str]],
+                  suffixes: Optional[List[str]] = None) -> str:
+    """Filter response with suffixes.
+
+    Args:
+        response (Union[str, List[str]]): generated responses by LLMs.
+        suffixes (str): a list of suffixes to be deleted.
+
+    Return:
+        str: a clean response.
+    """
+    if suffixes is None:
+        return response
+    batched = True
+    if isinstance(response, str):
+        response = [response]
+        batched = False
+    processed = []
+    for resp in response:
+        for item in suffixes:
+            # if response.endswith(item):
+            #     response = response[:len(response) - len(item)]
+            if item in resp:
+                resp = resp.split(item)[0]
+        processed.append(resp)
+    if not batched:
+        return processed[0]
+    return processed
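Example behavior of `filter_suffix` (inputs invented): a suffix found anywhere in the string truncates it there, and the single-string form round-trips without the batch dimension.

```python
from lagent.utils.util import filter_suffix

print(filter_suffix('4<eoa> trailing junk', ['<eoa>']))  # '4'
print(filter_suffix(['a<eoa>', 'b'], ['<eoa>']))         # ['a', 'b']
print(filter_suffix('unchanged'))                        # suffixes=None -> as-is
```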

@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-__version__ = '0.1.0'
+__version__ = '0.2.2'


 def parse_version_info(version_str: str, length: int = 4) -> tuple:
@@ -1,8 +1,12 @@
-docutils==0.16.0
+astroid<3.0.0
+docutils==0.18.1
 markdown>=3.4.0
 myst-parser
--e git+https://github.com/open-mmlab/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
-sphinx==4.0.2
+myst-nb
+# -e git+https://github.com/open-mmlab/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
+# sphinx==4.0.2
+sphinx==6.1.0
+sphinx-autoapi
+sphinx-rtd-theme==1.3.0
 sphinx-tabs
 sphinx_copybutton
 sphinx_markdown_tables>=0.0.16
@@ -1,2 +1,8 @@
 google-search-results
+lmdeploy>=0.2.3
+pillow
+python-pptx
+timeout_decorator
+torch
-transformers
+transformers>=4.34
+vllm>=0.3.3
@@ -1,5 +1,13 @@
 arxiv
+distro
 func_timeout
+griffe
+json5
 jsonschema
+jupyter
+jupyter_client
+phx-class-registry
 requests
+streamlit
 tiktoken
+typing-extensions
@@ -1,5 +1,5 @@
 [isort]
-line_length = 79
+line_length = 119
 multi_line_output = 0
 extra_standard_library = setuptools
 known_first_party = mmdet
@@ -18,7 +18,7 @@ SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true
 [codespell]
 skip = *.ipynb
 quiet-level = 3
-ignore-words-list = patten,nd,ty,mot,hist,formating,winn,gool,datas,wan,confids,TOOD,tood,ba,warmup,nam,DOTA,dota,conveyer
+ignore-words-list = patten,nd,ty,mot,hist,formating,winn,gool,datas,wan,confids,TOOD,tood,ba,warmup,nam,DOTA,dota,conveyer,astroid

 [flake8]
 per-file-ignores = ftdp/configs/*: F401,F403,F405
 max-line-length = 200
@@ -1,7 +1,6 @@
 from unittest import TestCase

-from lagent.actions.builtin_actions import (FinishAction, InvalidAction,
-                                            NoAction)
+from lagent.actions.builtin_actions import FinishAction, InvalidAction, NoAction
 from lagent.schema import ActionStatusCode