Compare commits
	
		
			64 Commits
		
	
	
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|   | 581d9fb898 | ||
|   | 26672731e9 | ||
|   | 466d7aa24d | ||
|   | 5f55e12736 | ||
|   | e16a6bfc3a | ||
|   | 605a921878 | ||
|   | 4fd014bebf | ||
|   | 432ffaae8a | ||
|   | a64aa599ce | ||
|   | 662011783e | ||
|   | 3cf20f5011 | ||
|   | 7b71988d09 | ||
|   | 90ef5215b6 | ||
|   | 6a5447663a | ||
|   | a2c23ef9dd | ||
|   | 3be9ec042c | ||
|   | aa5a357a34 | ||
|   | 5650a75f3e | ||
|   | eea6e1cb56 | ||
|   | 42c6d265e1 | ||
|   | e20a768066 | ||
|   | 60244a253a | ||
|   | 990828ceb2 | ||
|   | b2e6d0ee7a | ||
|   | fda869162b | ||
|   | baeed6ed88 | ||
|   | 94959c16da | ||
|   | d61dc03079 | ||
|   | 190fb731be | ||
|   | f887b423fa | ||
|   | 35331a5016 | ||
|   | 5dc6fb5b3d | ||
|   | 7489767651 | ||
|   | 83437b081f | ||
|   | 3d9992182e | ||
|   | ca1ab4b5ec | ||
|   | f974a7bd39 | ||
|   | d943becaec | ||
|   | f90010c867 | ||
|   | ae3c7c37a6 | ||
|   | 559275daf4 | ||
|   | c7c46785bb | ||
|   | c9973b9558 | ||
|   | 94ba3a1a75 | ||
|   | 76a46c9a8c | ||
|   | d4a71f40b5 | ||
|   | 5581fad8ce | ||
|   | 95b68c821a | ||
|   | 85b91cc652 | ||
|   | c89729620f | ||
|   | 6b287605bb | ||
|   | 987618c978 | ||
|   | 8ddde9ba5e | ||
|   | 026eff8704 | ||
|   | 511b038890 | ||
|   | 830b9609c3 | ||
|   | 50641c27c2 | ||
|   | b7ca22adcd | ||
|   | 2ecfb9838b | ||
|   | 060bc2c67a | ||
|   | 1cfe5ac099 | ||
|   | 14b549fa6e | ||
|   | e08ad02dce | ||
|   | 4a6f92a77f | 
							
								
								
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -160,3 +160,4 @@ cython_debug/ | ||||
| #.idea/ | ||||
| .vscode/ | ||||
| docs/*/_build/ | ||||
| tmp_dir/ | ||||
|   | ||||
| @@ -1,11 +1,11 @@ | ||||
| exclude: ^(tests/data|scripts|ftdp/protocols|ftdp/template_configs|ftdp/tool_dicts)/ | ||||
| repos: | ||||
|   - repo: https://github.com/PyCQA/flake8 | ||||
|     rev: 5.0.4 | ||||
|     rev: 7.0.0 | ||||
|     hooks: | ||||
|       - id: flake8 | ||||
|   - repo: https://github.com/PyCQA/isort | ||||
|     rev: 5.11.5 | ||||
|     rev: 5.13.2 | ||||
|     hooks: | ||||
|       - id: isort | ||||
|   - repo: https://github.com/pre-commit/mirrors-yapf | ||||
| @@ -13,7 +13,7 @@ repos: | ||||
|     hooks: | ||||
|       - id: yapf | ||||
|   - repo: https://github.com/pre-commit/pre-commit-hooks | ||||
|     rev: v4.3.0 | ||||
|     rev: v4.5.0 | ||||
|     hooks: | ||||
|       - id: trailing-whitespace | ||||
|       - id: check-yaml | ||||
| @@ -26,7 +26,7 @@ repos: | ||||
|       - id: mixed-line-ending | ||||
|         args: ["--fix=lf"] | ||||
|   - repo: https://github.com/executablebooks/mdformat | ||||
|     rev: 0.7.9 | ||||
|     rev: 0.7.17 | ||||
|     hooks: | ||||
|       - id: mdformat | ||||
|         args: ["--number"] | ||||
| @@ -35,16 +35,11 @@ repos: | ||||
|           - mdformat_frontmatter | ||||
|           - linkify-it-py | ||||
|   - repo: https://github.com/codespell-project/codespell | ||||
|     rev: v2.2.1 | ||||
|     rev: v2.2.6 | ||||
|     hooks: | ||||
|       - id: codespell | ||||
|   - repo: https://github.com/myint/docformatter | ||||
|     rev: v1.3.1 | ||||
|     hooks: | ||||
|       - id: docformatter | ||||
|         args: ["--in-place", "--wrap-descriptions", "79"] | ||||
|   - repo: https://github.com/asottile/pyupgrade | ||||
|     rev: v3.0.0 | ||||
|     rev: v3.15.0 | ||||
|     hooks: | ||||
|       - id: pyupgrade | ||||
|         args: ["--py36-plus"] | ||||
|   | ||||
							
								
								
									
										15
									
								
								.readthedocs.yaml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										15
									
								
								.readthedocs.yaml
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,15 @@ | ||||
| version: 2 | ||||
|  | ||||
| formats: all | ||||
|  | ||||
| build: | ||||
|   os: ubuntu-22.04 | ||||
|   tools: | ||||
|     python: "3.10" | ||||
|  | ||||
| python: | ||||
|   install: | ||||
|     - requirements: requirements/docs.txt | ||||
|  | ||||
| sphinx: | ||||
|   configuration: docs/en/conf.py | ||||
| @@ -1,8 +0,0 @@ | ||||
| version: 2 | ||||
|  | ||||
| formats: all | ||||
|  | ||||
| python: | ||||
|     version: 3.7 | ||||
|     install: | ||||
|       - requirements: requirements/docs.txt | ||||
							
								
								
									
										110
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										110
									
								
								README.md
									
									
									
									
									
								
							| @@ -1,3 +1,4 @@ | ||||
| <div id="top"></div> | ||||
| <div align="center"> | ||||
|   <img src="docs/imgs/lagent_logo.png" width="450"/> | ||||
|  | ||||
| @@ -6,32 +7,24 @@ | ||||
| [](https://github.com/InternLM/lagent/tree/main/LICENSE) | ||||
| [](https://github.com/InternLM/lagent/issues) | ||||
| [](https://github.com/InternLM/lagent/issues) | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
| English | [简体中文](README_zh-CN.md) | [日本語](README_ja_JP.md) | [हिंदी](README_in_HIN.md) | [বাংলা](README_in_beng.md) | ||||
| English | [简体中文](README_zh-CN.md) | [日本語](README_ja_JP.md) | [हिंदी](README_in_HIN.md) | [বাংলা](README_in_beng.md) | [한국어](README_KR_Kr.md) | ||||
|  | ||||
| </div> | ||||
|  | ||||
| <p align="center"> | ||||
|     👋 join us on <a href="https://twitter.com/intern_lm" target="_blank">Twitter</a>, <a href="https://discord.gg/xa29JuW87d" target="_blank">Discord</a> and <a href="https://r.vansin.top/?r=internwx" target="_blank">WeChat</a> | ||||
|     👋 join us on <a href="https://twitter.com/intern_lm" target="_blank">𝕏 (Twitter)</a>, <a href="https://discord.gg/xa29JuW87d" target="_blank">Discord</a> and <a href="https://r.vansin.top/?r=internwx" target="_blank">WeChat</a> | ||||
| </p> | ||||
|  | ||||
| ## Introduction | ||||
| <div align="center"> | ||||
|  | ||||
| Lagent is a lightweight open-source framework that allows users to efficiently build large language model(LLM)-based agents. It also provides some typical tools to augment LLM. The overview of our framework is shown below: | ||||
| https://github.com/InternLM/lagent/assets/24622904/3242f9bf-32d2-4907-8815-e16a75a4ac0e | ||||
|  | ||||
|  | ||||
|  | ||||
| ### Major Features | ||||
|  | ||||
| **0.1.2** was released in 24/10/2023: | ||||
|  | ||||
| - **Support efficient inference engine.** Lagent now supports efficient inference engine [lmdeploy turbomind](https://github.com/InternLM/lmdeploy/tree/main). | ||||
|  | ||||
| - **Support multiple kinds of agents out of box.** Lagent now supports [ReAct](https://arxiv.org/abs/2210.03629), [AutoGPT](https://github.com/Significant-Gravitas/Auto-GPT) and [ReWOO](https://arxiv.org/abs/2305.18323), which can drive the large language models(LLMs) for multiple trials of reasoning and function calling. | ||||
|  | ||||
| - **Extremely simple and easy to extend.** The framework is quite simple with a clear structure. With only 20 lines of code, you are able to construct your own agent. It also supports three typical tools: Python interpreter, API call, and google search. | ||||
|  | ||||
| - **Support various large language models.** We support different LLMs, including API-based (GPT-3.5/4) and open-source (LLaMA 2, InternLM) models. | ||||
| </div> | ||||
|  | ||||
| ## Getting Started | ||||
|  | ||||
| @@ -45,73 +38,58 @@ Install with pip (Recommended). | ||||
| pip install lagent | ||||
| ``` | ||||
|  | ||||
| Optionally, you could also build Lagent from source in case you want to modify the code: | ||||
| ### Run a Web Demo | ||||
|  | ||||
| You need to install Streamlit first. | ||||
|  | ||||
| ```bash | ||||
| git clone https://github.com/InternLM/lagent.git | ||||
| cd lagent | ||||
| pip install -e . | ||||
| ``` | ||||
|  | ||||
| ### Run ReAct Web Demo | ||||
|  | ||||
| ```bash | ||||
| # You need to install streamlit first | ||||
| # pip install streamlit | ||||
| streamlit run examples/react_web_demo.py | ||||
| streamlit run examples/internlm2_agent_web_demo.py | ||||
| ``` | ||||
|  | ||||
| Then you can chat through the UI shown as below | ||||
|  | ||||
| ## What's Lagent? | ||||
|  | ||||
| ### Run a ReWOO agent with GPT-3.5 | ||||
| Lagent is a lightweight open-source framework that allows users to efficiently build large language model(LLM)-based agents. It also provides some typical tools to augment LLM. The overview of our framework is shown below: | ||||
|  | ||||
| Below is an example for running ReWOO with GPT-3.5 | ||||
|  | ||||
|  | ||||
| ```python | ||||
| from lagent.agents import ReWOO | ||||
| from lagent.actions import ActionExecutor, GoogleSearch, LLMQA | ||||
| from lagent.llms import GPTAPI | ||||
| ## Major Features | ||||
|  | ||||
| llm = GPTAPI(model_type='gpt-3.5-turbo', key=['Your OPENAI_API_KEY']) | ||||
| search_tool = GoogleSearch(api_key='Your SERPER_API_KEY') | ||||
| llmqa_tool = LLMQA(llm) | ||||
| - Stream Output: Provides the `stream_chat` interface for streaming output, allowing cool streaming demos right at your local setup. | ||||
| - Interfacing is unified, with a comprehensive design upgrade for enhanced extensibility, including: | ||||
|   - Model: Whether it's the OpenAI API, Transformers, or LMDeploy inference acceleration framework, you can seamlessly switch between models. | ||||
|   - Action: Simple inheritance and decoration allow you to create your own personal toolkit, adaptable to both InternLM and GPT. | ||||
|   - Agent: Consistent with the Model's input interface, the transformation from model to intelligent agent only takes one step, facilitating the exploration and implementation of various agents. | ||||
| - Documentation has been thoroughly upgraded with full API documentation coverage. | ||||
|  | ||||
| chatbot = ReWOO( | ||||
|     llm=llm, | ||||
|     action_executor=ActionExecutor( | ||||
|         actions=[search_tool, llmqa_tool]), | ||||
| ) | ||||
| ## 💻Tech Stack | ||||
|  | ||||
| response = chatbot.chat('What profession does Nicholas Ray and Elia Kazan have in common') | ||||
| print(response.response) | ||||
| >>> Film director. | ||||
| ``` | ||||
| <p> | ||||
|   <a href=""> | ||||
|     <img src="https://img.shields.io/badge/Python-007ACC?style=for-the-badge&logo=python&logoColor=yellow" alt="python" /> | ||||
|   </a> | ||||
|  | ||||
| ### Run a ReAct agent with InternLM | ||||
| ### All Thanks To Our Contributors: | ||||
|  | ||||
| NOTE: If you want to run a HuggingFace model, please run `pip install -e .[all]` first. | ||||
| <a href="https://github.com/InternLM/lagent/graphs/contributors"> | ||||
|   <img src="https://contrib.rocks/image?repo=InternLM/lagent" /> | ||||
| </a> | ||||
|  | ||||
| ```python | ||||
| from lagent.agents import ReAct | ||||
| from lagent.actions import ActionExecutor, GoogleSearch, PythonInterpreter | ||||
| from lagent.llms import HFTransformer | ||||
| ## Citation | ||||
|  | ||||
| llm = HFTransformer('internlm/internlm-chat-7b-v1_1') | ||||
| search_tool = GoogleSearch(api_key='Your SERPER_API_KEY') | ||||
| python_interpreter = PythonInterpreter() | ||||
| If you find this project useful in your research, please consider cite: | ||||
|  | ||||
| chatbot = ReAct( | ||||
|     llm=llm, | ||||
|     action_executor=ActionExecutor( | ||||
|         actions=[search_tool, python_interpreter]), | ||||
| ) | ||||
|  | ||||
| response = chatbot.chat('若$z=-1+\sqrt{3}i$,则$\frac{z}{{z\overline{z}-1}}=\left(\ \ \right)$') | ||||
| print(response.response) | ||||
| >>> $-\\frac{1}{3}+\\frac{{\\sqrt{3}}}{3}i$ | ||||
| ```latex | ||||
| @misc{lagent2023, | ||||
|     title={{Lagent: InternLM} a lightweight open-source framework that allows users to efficiently build large language model(LLM)-based agents}, | ||||
|     author={Lagent Developer Team}, | ||||
|     howpublished = {\url{https://github.com/InternLM/lagent}}, | ||||
|     year={2023} | ||||
| } | ||||
| ``` | ||||
|  | ||||
| ## License | ||||
|  | ||||
| This project is released under the [Apache 2.0 license](LICENSE). | ||||
|  | ||||
| <p align="right"><a href="#top">🔼 Back to top</a></p> | ||||
|   | ||||
							
								
								
									
										69
									
								
								README_KR_Kr.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										69
									
								
								README_KR_Kr.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,69 @@ | ||||
| <div id="top"></div> | ||||
| <div align="center"> | ||||
|   <img src="docs/imgs/lagent_logo.png" width="450"/> | ||||
|  | ||||
| [](https://lagent.readthedocs.io/en/latest/) | ||||
| [](https://pypi.org/project/lagent) | ||||
| [](https://github.com/InternLM/lagent/tree/main/LICENSE) | ||||
| [](https://github.com/InternLM/lagent/issues) | ||||
| [](https://github.com/InternLM/lagent/issues) | ||||
|  | ||||
| English | [简体中文](README_zh-CN.md) | [日本語](README_ja_JP.md) | [हिंदी](README_in_HIN.md) | [বাংলা](README_in_beng.md) | ||||
|  | ||||
| </div> | ||||
|  | ||||
| <p align="center"> | ||||
|     👋 join us on <a href="https://twitter.com/intern_lm" target="_blank">Twitter</a>, <a href="https://discord.gg/xa29JuW87d" target="_blank">Discord</a> and <a href="https://r.vansin.top/?r=internwx" target="_blank">WeChat</a> | ||||
| </p> | ||||
|  | ||||
| <div align="center"> | ||||
|  | ||||
| https://github.com/InternLM/lagent/assets/24622904/cb851b31-6932-422e-a776-b1aa68f2a64f | ||||
|  | ||||
| </div> | ||||
|  | ||||
| ## 시작하기 | ||||
|  | ||||
| 일반적인 Lagent 소개에 대한 [overview](docs/en/get_started/overview.md) 를 확인하십시오. 동시에 빠른 시작을 위한 매우 간단한 코드를 제공합니다. 자세한 내용은 [examples](examples/) 를 참조하십시오. | ||||
|  | ||||
| ### 설치 | ||||
|  | ||||
| pip를 사용하여 설치하십시오 (권장). | ||||
|  | ||||
| ```bash | ||||
| pip install lagent | ||||
| ``` | ||||
|  | ||||
| ### 웹 데모 실행 | ||||
|  | ||||
| 먼저 streamlit을 설치해야 합니다 | ||||
|  | ||||
| ```bash | ||||
| # pip install streamlit | ||||
| streamlit run examples/internlm2_agent_web_demo.py | ||||
| ``` | ||||
|  | ||||
| ## 소개 | ||||
|  | ||||
| Lagent는 사용자가 효율적으로 대규모 언어 모델(LLM) 기반 에이전트를 구축할 수 있게 해주는 경량의 오픈 소스 프레임워크입니다. 또한 LLM을 보강하기 위한 몇 가지 일반적인 도구도 제공합니다. 우리 프레임워크의 개요는 아래와 같이 나와 있습니다: | ||||
|  | ||||
|  | ||||
|  | ||||
| ## 인용 | ||||
|  | ||||
| 이 프로젝트가 귀하의 연구에 유용하다고 생각하면 다음과 같이 인용해 주십시오: | ||||
|  | ||||
| ```latex | ||||
| @misc{lagent2023, | ||||
|     title={{Lagent: InternLM} a lightweight open-source framework that allows users to efficiently build large language model(LLM)-based agents}, | ||||
|     author={Lagent Developer Team}, | ||||
|     howpublished = {\url{https://github.com/InternLM/lagent}}, | ||||
|     year={2023} | ||||
| } | ||||
| ``` | ||||
|  | ||||
| ## 라이선스 | ||||
|  | ||||
| 이 프로젝트는 [Apache 2.0](LICENSE) 하에 공개되었습니다. | ||||
|  | ||||
| <p align="right"><a href="#top">🔼 Back to top</a></p> | ||||
| @@ -1,3 +1,4 @@ | ||||
| <div id="top"></div> | ||||
| <div align="center"> | ||||
|   <img src="docs/imgs/lagent_logo.png" width="450"/> | ||||
|  | ||||
| @@ -7,27 +8,15 @@ | ||||
| [](https://github.com/InternLM/lagent/issues) | ||||
| [](https://github.com/InternLM/lagent/issues) | ||||
|  | ||||
| [English](README.md) | [简体中文](README_zh-CN.md) | [日本語](README_ja_JP.md) | हिंदी | [বাংলা](README_in_beng.md) | ||||
| English | [简体中文](README_zh-CN.md) | [日本語](README_ja_JP.md) | [हिंदी](README_in_HIN.md) | [বাংলা](README_in_beng.md) | [한국어](README_KR_Kr.md) | ||||
|  | ||||
| </div> | ||||
|  | ||||
| <p align="center"> | ||||
|     👋 <a href="https://twitter.com/intern_lm" target="_blank">Twitter</a>, <a href="https://discord.gg/xa29JuW87d" target="_blank">Discord</a> और <a href="https://r.vansin.top/?r=internwx" target="_blank">WeChat</a> पर हमसे जुड़ें | ||||
| </p> | ||||
| <div align="center"> | ||||
|  | ||||
| ## परिचय | ||||
| https://github.com/InternLM/lagent/assets/24622904/cb851b31-6932-422e-a776-b1aa68f2a64f | ||||
|  | ||||
| Lagent एक हल्का ओपन-सोर्स फ्रेमवर्क है जो उपयोगकर्ताओं को बड़े भाषा मॉडल (एलएलएम)-आधारित एजेंटों को कुशलतापूर्वक बनाने की अनुमति देता है। यह एलएलएम को बढ़ाने के लिए कुछ विशिष्ट उपकरण भी प्रदान करता है। हमारे ढांचे का अवलोकन नीचे दिखाया गया है: | ||||
|  | ||||
|  | ||||
|  | ||||
| ### प्रमुख विशेषताएं | ||||
|  | ||||
| - **बॉक्स से बाहर कई प्रकार के एजेंटों का समर्थन करें।** लैजेंट अब समर्थन करता है [ReAct](https://arxiv.org/abs/2210.03629), [AutoGPT](https://github.com/Significant-Gravitas/Auto-GPT) और [ReWOO](https://arxiv.org/abs/2305.18323), जो तर्क और फ़ंक्शन कॉलिंग के कई परीक्षणों के लिए बड़े भाषा मॉडल (एलएलएम) को संचालित कर सकता है। | ||||
|  | ||||
| - **बेहद सरल और विस्तार करने में आसान।** स्पष्ट संरचना के साथ ढांचा काफी सरल है। कोड की केवल 20 पंक्तियों के साथ, आप अपना स्वयं का एजेंट बनाने में सक्षम हैं। यह तीन विशिष्ट टूल का भी समर्थन करता है: पायथन इंटरप्रेटर, एपीआई कॉल और गूगल सर्च। | ||||
|  | ||||
| - **विभिन्न बड़े भाषा मॉडल का समर्थन करें।** हम एपीआई-आधारित (जीपीटी-3.5/4) और ओपन-सोर्स (एलएलएएमए 2, इंटर्नएलएम) मॉडल सहित विभिन्न एलएलएम का समर्थन करते हैं। | ||||
| </div> | ||||
|  | ||||
| ## शुरू करना | ||||
|  | ||||
| @@ -41,73 +30,22 @@ pip के साथ स्थापित करें (अनुशंसि | ||||
| pip install lagent | ||||
| ``` | ||||
|  | ||||
| वैकल्पिक रूप से, यदि आप कोड को संशोधित करना चाहते हैं तो आप स्रोत से लैजेंट भी बना सकते हैं: | ||||
|  | ||||
| ```bash | ||||
| git clone https://github.com/InternLM/lagent.git | ||||
| cd lagent | ||||
| pip install -e . | ||||
| ``` | ||||
|  | ||||
| ### रिएक्ट वेब डेमो चलाएँ | ||||
| ### वेब डेमो चलाएँ | ||||
|  | ||||
| ```bash | ||||
| # You need to install streamlit first | ||||
| # pip install streamlit | ||||
| streamlit run examples/react_web_demo.py | ||||
| streamlit run examples/internlm2_agent_web_demo.py | ||||
| ``` | ||||
|  | ||||
| फिर आप नीचे दिखाए गए यूआई के माध्यम से चैट कर सकते हैं | ||||
|  | ||||
| ## परिचय | ||||
|  | ||||
| ### GPT-3.5 के साथ ReWOO एजेंट चलाएँ | ||||
| Lagent एक हल्का ओपन-सोर्स फ्रेमवर्क है जो उपयोगकर्ताओं को बड़े भाषा मॉडल (एलएलएम)-आधारित एजेंटों को कुशलतापूर्वक बनाने की अनुमति देता है। यह एलएलएम को बढ़ाने के लिए कुछ विशिष्ट उपकरण भी प्रदान करता है। हमारे ढांचे का अवलोकन नीचे दिखाया गया है: | ||||
|  | ||||
| GPT-3.5 के साथ ReWOO चलाने का एक उदाहरण नीचे दिया गया है | ||||
|  | ||||
| ```python | ||||
| from lagent.agents import ReWOO | ||||
| from lagent.actions import ActionExecutor, GoogleSearch, LLMQA | ||||
| from lagent.llms import GPTAPI | ||||
|  | ||||
| llm = GPTAPI(model_type='gpt-3.5-turbo', key=['Your OPENAI_API_KEY']) | ||||
| search_tool = GoogleSearch(api_key='Your SERPER_API_KEY') | ||||
| llmqa_tool = LLMQA(llm) | ||||
|  | ||||
| chatbot = ReWOO( | ||||
|     llm=llm, | ||||
|     action_executor=ActionExecutor( | ||||
|         actions=[search_tool, llmqa_tool]), | ||||
| ) | ||||
|  | ||||
| response = chatbot.chat('What profession does Nicholas Ray and Elia Kazan have in common') | ||||
| print(response.response) | ||||
| >>> Film director. | ||||
| ``` | ||||
|  | ||||
| ### InternLM के साथ एक ReAct एजेंट चलाएँ | ||||
|  | ||||
| नोट: यदि आप हगिंगफेस मॉडल चलाना चाहते हैं, तो कृपया पहले `pip install -e .[all]` चलाएं। | ||||
|  | ||||
| ```python | ||||
| from lagent.agents import ReAct | ||||
| from lagent.actions import ActionExecutor, GoogleSearch, PythonInterpreter | ||||
| from lagent.llms import HFTransformer | ||||
|  | ||||
| llm = HFTransformer('internlm/internlm-chat-7b-v1_1') | ||||
| search_tool = GoogleSearch(api_key='Your SERPER_API_KEY') | ||||
| python_interpreter = PythonInterpreter() | ||||
|  | ||||
| chatbot = ReAct( | ||||
|     llm=llm, | ||||
|     action_executor=ActionExecutor( | ||||
|         actions=[search_tool, python_interpreter]), | ||||
| ) | ||||
|  | ||||
| response = chatbot.chat('若$z=-1+\sqrt{3}i$,则$\frac{z}{{z\overline{z}-1}}=\left(\ \ \right)$') | ||||
| print(response.response) | ||||
| >>> $-\\frac{1}{3}+\\frac{{\\sqrt{3}}}{3}i$ | ||||
| ``` | ||||
|  | ||||
|  | ||||
| ## लाइसेंस | ||||
|  | ||||
| यह प्रोजेक्ट [Apache 2.0 license](LICENSE) के तहत जारी किया गया है। | ||||
|  | ||||
| <p align="right"><a href="#top">🔼 Back to top</a></p> | ||||
|   | ||||
| @@ -1,3 +1,4 @@ | ||||
| <div id="top"></div> | ||||
| <div align="center"> | ||||
|   <img src="docs/imgs/lagent_logo.png" width="450"/> | ||||
|  | ||||
| @@ -7,7 +8,7 @@ | ||||
| [](https://github.com/InternLM/lagent/issues) | ||||
| [](https://github.com/InternLM/lagent/issues) | ||||
|  | ||||
| [English](README.md) | [简体中文](README_zh-CN.md) | [日本語](README_ja_JP.md) | [हिंदी](README_in_HIN.md) | বাংলা | ||||
| English | [简体中文](README_zh-CN.md) | [日本語](README_ja_JP.md) | [हिंदी](README_in_HIN.md) | [বাংলা](README_in_beng.md) | [한국어](README_KR_Kr.md) | ||||
|  | ||||
| </div> | ||||
|  | ||||
| @@ -15,19 +16,11 @@ | ||||
|     👋 <a href="https://twitter.com/intern_lm" target="_blank">Twitter</a>, <a href="https://discord.gg/xa29JuW87d" target="_blank">Discord</a> এবং <a href="https://r.vansin.top/?r=internwx" target="_blank">WeChat</a> সাথে আমাদের সাথে যোগদান করুন | ||||
| </p> | ||||
|  | ||||
| ## পরিচিতি | ||||
| <div align="center"> | ||||
|  | ||||
| লেজেন্ট হল একটি হালকা ওপেন-সোর্স ফ্রেমওয়ার্ক, যা ব্যবহারকারীদের দ্বারা প্রশাসক ভাষা মডেল (LLM) ভিত্তিক এজেন্ট সৃজনশীলভাবে তৈরি করতে দেয়। এটি লেজেন্ট যেসব প্রধান সরঞ্জাম সরবরাহ করে, সেটি নীচে দেখানো হয়: | ||||
| https://github.com/InternLM/lagent/assets/24622904/cb851b31-6932-422e-a776-b1aa68f2a64f | ||||
|  | ||||
|  | ||||
|  | ||||
| ### মৌলিক বৈশিষ্ট্য | ||||
|  | ||||
| - **বাক্যের কিছু প্রকারের এজেন্টের সাথে সমর্থন।** লেজেন্ট এখন [ReAct](https://arxiv.org/abs/2210.03629), [AutoGPT](https://github.com/Significant-Gravitas/Auto-GPT) এবং [ReWOO](https://arxiv.org/abs/2305.18323) সহ বড় ভাষা মডেল (LLM)-কে যুক্ত করতে পারে, যা বড় ভাষা মডেল (LLM) ব্যবহার করে বিভিন্ন কারণের এবং কার্য করার জন্য অনেক পরীক্ষা করতে পারে। | ||||
|  | ||||
| - **অত্যন্ত সহজ এবং বড় করা সম্ভব।** এই ফ্রেমওয়ার্কটি খুব সাধারণ, একটি স্পষ্ট স্ট্রাকচার সহ। শুধুমাত্র 20 লাইনের কোডের সাথে, আপনি নিজের এজেন্ট তৈরি করতে পারেন। এটি সাথে একই সময়ে তিনটি সাধারণ টুলও সমর্থন করে: পাইথন ইন্টারপ্রিটার, API কল, এবং গুগল সার্চ। | ||||
|  | ||||
| - **বিভিন্ন বড় ভাষা মডেলের সমর্থন।** আমরা API-ভিত্তিক (GPT-3.5/4) এবং ওপেন-সোর্স (LLaMA 2, InternLM) মডেলগুলির মধ্যে বিভিন্ন LLM-এসকে সমর্থন করি। | ||||
| </div> | ||||
|  | ||||
| ## শুরু করা | ||||
|  | ||||
| @@ -41,73 +34,22 @@ | ||||
| pip install lagent | ||||
| ``` | ||||
|  | ||||
| আপনি চাইলে সোর্স থেকে লেজেন্ট তৈরি করতে পারেন, কোড পরিবর্তন করতে চাইলে: | ||||
|  | ||||
| ```bash | ||||
| git clone https://github.com/InternLM/lagent.git | ||||
| cd lagent | ||||
| pip install -e . | ||||
| ``` | ||||
|  | ||||
| ### ReAct ওয়েব ডেমো চালান | ||||
| ### ওয়েব ডেমো চালান | ||||
|  | ||||
| ```bash | ||||
| # You need to install streamlit first | ||||
| # pip install streamlit | ||||
| streamlit run examples/react_web_demo.py | ||||
| streamlit run examples/internlm2_agent_web_demo.py | ||||
| ``` | ||||
|  | ||||
| তারপর আপনি নীচে দেওয়া ছবির মাধ্যমে ইউআই দিয়ে চ্যাট করতে পারেন | ||||
|  | ||||
| ## পরিচিতি | ||||
|  | ||||
| ### GPT-3.5 সহ একটি ReWOO এজেন্ট চালান | ||||
| লেজেন্ট হল একটি হালকা ওপেন-সোর্স ফ্রেমওয়ার্ক, যা ব্যবহারকারীদের দ্বারা প্রশাসক ভাষা মডেল (LLM) ভিত্তিক এজেন্ট সৃজনশীলভাবে তৈরি করতে দেয়। এটি লেজেন্ট যেসব প্রধান সরঞ্জাম সরবরাহ করে, সেটি নীচে দেখানো হয়: | ||||
|  | ||||
| নীচে একটি উদাহরণ দেওয়া হল ReWOO সহ GPT-3.5 চালানোর জন্য | ||||
|  | ||||
| ```python | ||||
| from lagent.agents import ReWOO | ||||
| from lagent.actions import ActionExecutor, GoogleSearch, LLMQA | ||||
| from lagent.llms import GPTAPI | ||||
|  | ||||
| llm = GPTAPI(model_type='gpt-3.5-turbo', key=['Your OPENAI_API_KEY']) | ||||
| search_tool = GoogleSearch(api_key='Your SERPER_API_KEY') | ||||
| llmqa_tool = LLMQA(llm) | ||||
|  | ||||
| chatbot = ReWOO( | ||||
|     llm=llm, | ||||
|     action_executor=ActionExecutor( | ||||
|         actions=[search_tool, llmqa_tool]), | ||||
| ) | ||||
|  | ||||
| response = chatbot.chat('What profession does Nicholas Ray and Elia Kazan have in common') | ||||
| print(response.response) | ||||
| >>> Film director. | ||||
| ``` | ||||
|  | ||||
| ### InternLM দিয়ে একটি ReAct এজেন্ট চালান | ||||
|  | ||||
| নোট: আপনি যদি একটি HuggingFace মডেল চালাতে চান, তবে প্রথমে pip install -e .[all] চালানো দরকার। | ||||
|  | ||||
| ```python | ||||
| from lagent.agents import ReAct | ||||
| from lagent.actions import ActionExecutor, GoogleSearch, PythonInterpreter | ||||
| from lagent.llms import HFTransformer | ||||
|  | ||||
| llm = HFTransformer('internlm/internlm-chat-7b-v1_1') | ||||
| search_tool = GoogleSearch(api_key='Your SERPER_API_KEY') | ||||
| python_interpreter = PythonInterpreter() | ||||
|  | ||||
| chatbot = ReAct( | ||||
|     llm=llm, | ||||
|     action_executor=ActionExecutor( | ||||
|         actions=[search_tool, python_interpreter]), | ||||
| ) | ||||
|  | ||||
| response = chatbot.chat('若$z=-1+\sqrt{3}i$,则$\frac{z}{{z\overline{z}-1}}=\left(\ \ \right)$') | ||||
| print(response.response) | ||||
| >>> $-\\frac{1}{3}+\\frac{{\\sqrt{3}}}{3}i$ | ||||
| ``` | ||||
|  | ||||
|  | ||||
| ## লাইসেন্স | ||||
|  | ||||
| এই প্রকল্পটি [Apache 2.0 license](LICENSE) অনুসরণ করে প্রকাশিত হয়। | ||||
|  | ||||
| <p align="right"><a href="#top">🔼 Back to top</a></p> | ||||
|   | ||||
| @@ -1,3 +1,4 @@ | ||||
| <div id="top"></div> | ||||
| <div align="center"> | ||||
|   <img src="docs/imgs/lagent_logo.png" width="450"/> | ||||
|  | ||||
| @@ -7,7 +8,7 @@ | ||||
| [](https://github.com/InternLM/lagent/issues) | ||||
| [](https://github.com/InternLM/lagent/issues) | ||||
|  | ||||
| [English](README.md) | [简体中文](README_zh-CN.md) | 日本語 | [हिंदी](README_in_HIN.md) | [বাংলা](README_in_beng.md) | ||||
| English | [简体中文](README_zh-CN.md) | [日本語](README_ja_JP.md) | [हिंदी](README_in_HIN.md) | [বাংলা](README_in_beng.md) | [한국어](README_KR_Kr.md) | ||||
|  | ||||
| </div> | ||||
|  | ||||
| @@ -15,19 +16,11 @@ | ||||
|     👋 <a href="https://twitter.com/intern_lm" target="_blank">Twitter</a>, <a href="https://discord.gg/xa29JuW87d" target="_blank">Discord</a> そして <a href="https://r.vansin.top/?r=internwx" target="_blank">WeChat</a> に参加する | ||||
| </p> | ||||
|  | ||||
| ## はじめに | ||||
| <div align="center"> | ||||
|  | ||||
| Lagent は、大規模言語モデル(LLM)ベースのエージェントを効率的に構築できる軽量なオープンソースフレームワークです。また、LLM を拡張するための典型的なツールも提供します。我々のフレームワークの概要を以下に示します: | ||||
| https://github.com/InternLM/lagent/assets/24622904/cb851b31-6932-422e-a776-b1aa68f2a64f | ||||
|  | ||||
|  | ||||
|  | ||||
| ### 主な特徴 | ||||
|  | ||||
| - **複数のエージェントをすぐにサポート** Lagent は現在、[ReAct](https://arxiv.org/abs/2210.03629)、[AutoGPT](https://github.com/Significant-Gravitas/Auto-GPT)、[ReWOO](https://arxiv.org/abs/2305.18323) をサポートしており、推論や関数呼び出しの複数の試行に対して大規模言語モデル(LLM)を駆動することができる。 | ||||
|  | ||||
| - **非常にシンプルで、拡張も簡単。** フレームワークは非常にシンプルで、明確な構造を持っています。わずか 20 行のコードで、独自のエージェントを構築することができます。また、3 つの代表的なツールをサポートしています: Python インタプリタ、API コール、google 検索です。 | ||||
|  | ||||
| - **様々な大規模言語モデルをサポート。** API ベース(GPT-3.5/4)やオープンソース(LLaMA 2, InternLM)を含む様々な LLM をサポートしています。 | ||||
| </div> | ||||
|  | ||||
| ## はじめに | ||||
|  | ||||
| @@ -41,73 +34,23 @@ pip でインストールする(推奨)。 | ||||
| pip install lagent | ||||
| ``` | ||||
|  | ||||
| オプションとして、コードを修正したい場合に備えて、Lagent をソースからビルドすることもできる: | ||||
| ### ウェブデモの実行 | ||||
|  | ||||
| 最初に streamlit をインストールする必要があります | ||||
|  | ||||
| ```bash | ||||
| git clone https://github.com/InternLM/lagent.git | ||||
| cd lagent | ||||
| pip install -e . | ||||
| ``` | ||||
|  | ||||
| ### ReAct ウェブデモの実行 | ||||
|  | ||||
| ```bash | ||||
| # 最初に streamlit をインストールする必要があります | ||||
| # pip install streamlit | ||||
| streamlit run examples/react_web_demo.py | ||||
| streamlit run examples/internlm2_agent_web_demo.py | ||||
| ``` | ||||
|  | ||||
| その後、以下のような UI からチャットができます | ||||
|  | ||||
| ## はじめに | ||||
|  | ||||
| ### GPT-3.5 で ReWOO エージェントを動かす | ||||
| Lagent は、大規模言語モデル(LLM)ベースのエージェントを効率的に構築できる軽量なオープンソースフレームワークです。また、LLM を拡張するための典型的なツールも提供します。我々のフレームワークの概要を以下に示します: | ||||
|  | ||||
| 以下は、GPT-3.5 で ReWOO を実行する例です | ||||
|  | ||||
| ```python | ||||
| from lagent.agents import ReWOO | ||||
| from lagent.actions import ActionExecutor, GoogleSearch, LLMQA | ||||
| from lagent.llms import GPTAPI | ||||
|  | ||||
| llm = GPTAPI(model_type='gpt-3.5-turbo', key=['Your OPENAI_API_KEY']) | ||||
| search_tool = GoogleSearch(api_key='Your SERPER_API_KEY') | ||||
| llmqa_tool = LLMQA(llm) | ||||
|  | ||||
| chatbot = ReWOO( | ||||
|     llm=llm, | ||||
|     action_executor=ActionExecutor( | ||||
|         actions=[search_tool, llmqa_tool]), | ||||
| ) | ||||
|  | ||||
| response = chatbot.chat('What profession does Nicholas Ray and Elia Kazan have in common') | ||||
| print(response.response) | ||||
| >>> Film director. | ||||
| ``` | ||||
|  | ||||
| ### InternLM で ReAct エージェントを動かす | ||||
|  | ||||
| 注: Hugging Face モデルを実行したい場合は、まず `pip install -e .[all]` を実行してください。 | ||||
|  | ||||
| ```python | ||||
| from lagent.agents import ReAct | ||||
| from lagent.actions import ActionExecutor, GoogleSearch, PythonInterpreter | ||||
| from lagent.llms import HFTransformer | ||||
|  | ||||
| llm = HFTransformer('internlm/internlm-chat-7b-v1_1') | ||||
| search_tool = GoogleSearch(api_key='Your SERPER_API_KEY') | ||||
| python_interpreter = PythonInterpreter() | ||||
|  | ||||
| chatbot = ReAct( | ||||
|     llm=llm, | ||||
|     action_executor=ActionExecutor( | ||||
|         actions=[search_tool, python_interpreter]), | ||||
| ) | ||||
|  | ||||
| response = chatbot.chat('若$z=-1+\sqrt{3}i$,则$\frac{z}{{z\overline{z}-1}}=\left(\ \ \right)$') | ||||
| print(response.response) | ||||
| >>> $-\\frac{1}{3}+\\frac{{\\sqrt{3}}}{3}i$ | ||||
| ``` | ||||
|  | ||||
|  | ||||
| ## ライセンス | ||||
|  | ||||
| このプロジェクトは [Apache 2.0 license](LICENSE) の下でリリースされています。 | ||||
|  | ||||
| <p align="right"><a href="#top">🔼 Back to top</a></p> | ||||
|   | ||||
							
								
								
									
										104
									
								
								README_zh-CN.md
									
									
									
									
									
								
							
							
						
						
									
										104
									
								
								README_zh-CN.md
									
									
									
									
									
								
							| @@ -1,3 +1,4 @@ | ||||
| <div id="top"></div> | ||||
| <div align="center"> | ||||
|   <img src="docs/imgs/lagent_logo.png" width="450"/> | ||||
|  | ||||
| @@ -7,7 +8,7 @@ | ||||
| [](https://github.com/InternLM/lagent/issues) | ||||
| [](https://github.com/InternLM/lagent/issues) | ||||
|  | ||||
| [English](README.md) | 简体中文 | [日本語](README_ja_JP.md) | [हिंदी](README_in_HIN.md) | [বাংলা](README_in_beng.md) | ||||
| English | [简体中文](README_zh-CN.md) | [日本語](README_ja_JP.md) | [हिंदी](README_in_HIN.md) | [বাংলা](README_in_beng.md) | [한국어](README_KR_Kr.md) | ||||
|  | ||||
| </div> | ||||
|  | ||||
| @@ -15,26 +16,14 @@ | ||||
|     👋 join us on <a href="https://twitter.com/intern_lm" target="_blank">Twitter</a>, <a href="https://discord.gg/xa29JuW87d" target="_blank">Discord</a> and <a href="https://r.vansin.top/?r=internwx" target="_blank">WeChat</a> | ||||
| </p> | ||||
|  | ||||
| <div align="center"> | ||||
|  | ||||
| https://github.com/InternLM/lagent/assets/24622904/cb851b31-6932-422e-a776-b1aa68f2a64f | ||||
|  | ||||
| </div> | ||||
|  | ||||
| [English](README.md) | 简体中文 | ||||
|  | ||||
| ## 简介 | ||||
|  | ||||
| Lagent 是一个轻量级、开源的基于大语言模型的智能体(agent)框架,支持用户快速地将一个大语言模型转变为多种类型的智能体,并提供了一些典型工具为大语言模型赋能。它的整个框架图如下: | ||||
|  | ||||
|  | ||||
|  | ||||
| ### 主要特点 | ||||
|  | ||||
| **0.1.2** 版本已经在 2023.10.24 发布 | ||||
|  | ||||
| - **支持高性能推理.** 我们现在支持了高性能推理 [lmdeploy turbomind](https://github.com/InternLM/lmdeploy/tree/main). | ||||
|  | ||||
| - **实现了多种类型的智能体,** 我们支持了经典的 [ReAct](https://arxiv.org/abs/2210.03629),[AutoGPT](https://github.com/Significant-Gravitas/Auto-GPT) 和 [ReWoo](https://arxiv.org/abs/2305.18323) 等智能体,这些智能体能够调用大语言模型进行多轮的推理和工具调用。 | ||||
|  | ||||
| - **框架简单易拓展.** 框架的代码结构清晰且简单,只需要不到20行代码你就能够创造出一个你自己的智能体(agent)。同时我们支持了 Python 解释器、API 调用和搜索三类常用典型工具。 | ||||
|  | ||||
| - **灵活支持多个大语言模型.** 我们提供了多种大语言模型支持,包括 InternLM、Llama-2 等开源模型和 GPT-4/3.5 等基于 API 的闭源模型。 | ||||
|  | ||||
| ## 教程 | ||||
|  | ||||
| 请阅读[概述](docs/en/get_started/overview.md)对 Lagent 项目进行初步的了解。同时, 我们提供了两个非常简单的样例帮助你快速入门。 你也可以阅读[示例代码](examples/)获得更多的例子参考。 | ||||
| @@ -47,72 +36,45 @@ Lagent 是一个轻量级、开源的基于大语言模型的智能体(agent | ||||
| pip install lagent | ||||
| ``` | ||||
|  | ||||
| 同时,如果你想修改这部分的代码,也可以通过以下命令从源码编译 Lagent: | ||||
| ### 运行一个智能体的网页样例 | ||||
|  | ||||
| 你可能需要先安装 Streamlit 包 | ||||
|  | ||||
| ```bash | ||||
| git clone https://github.com/InternLM/lagent.git | ||||
| cd lagent | ||||
| pip install -e . | ||||
| ``` | ||||
|  | ||||
| ### 运行一个 ReAct 智能体的网页样例 | ||||
|  | ||||
| ```bash | ||||
| # 可能先需要安装 streamlit 包 | ||||
| # pip install streamlit | ||||
| streamlit run examples/react_web_demo.py | ||||
| streamlit run examples/internlm2_agent_web_demo.py | ||||
| ``` | ||||
|  | ||||
| 然后你就可以在网页端和智能体进行对话了,效果如下图所示 | ||||
| ## 简介 | ||||
|  | ||||
|  | ||||
| Lagent 是一个轻量级、开源的基于大语言模型的智能体(agent)框架,支持用户快速地将一个大语言模型转变为多种类型的智能体,并提供了一些典型工具为大语言模型赋能。它的整个框架图如下: | ||||
|  | ||||
| ### 用 GPT-3.5 构建一个 ReWOO 智能体 | ||||
|  | ||||
|  | ||||
| ```python | ||||
| from lagent.agents import ReWOO | ||||
| from lagent.actions import ActionExecutor, GoogleSearch, LLMQA | ||||
| from lagent.llms import GPTAPI | ||||
| ## 特性 | ||||
|  | ||||
| llm = GPTAPI(model_type='gpt-3.5-turbo', key='OPENAI_API_KEY') | ||||
| search_tool = GoogleSearch(api_key='SERPER_API_KEY') | ||||
| llmqa_tool = LLMQA(llm) | ||||
| - 流式输出:提供 `stream_chat` 接口作流式输出,本地就能演示酷炫的流式 Demo。 | ||||
| - 接口统一,设计全面升级,提升拓展性,包括 | ||||
|   - Model : 不论是 OpenAI API, Transformers 还是推理加速框架 LMDeploy 一网打尽,模型切换可以游刃有余; | ||||
|   - Action: 简单的继承和装饰,即可打造自己个人的工具集,不论 InternLM 还是 GPT 均可适配; | ||||
|   - Agent:与 Model 的输入接口保持一致,模型到智能体的蜕变只需一步,便捷各种 agent 的探索实现; | ||||
| - 文档全面升级,API 文档全覆盖。 | ||||
|  | ||||
| chatbot = ReWOO( | ||||
|     llm=llm, | ||||
|     action_executor=ActionExecutor( | ||||
|         actions=[search_tool, llmqa_tool]), | ||||
| ) | ||||
| ## 引用 | ||||
|  | ||||
| response = chatbot.chat('What profession does Nicholas Ray and Elia Kazan have in common') | ||||
| print(response.response) | ||||
| >>> Film director. | ||||
| ``` | ||||
| 如果你觉得本项目对你的研究工作有所帮助,请参考如下 bibtex 引用 Lagent: | ||||
|  | ||||
| ### 用 InternLM 构建一个 ReAct 智能体 | ||||
|  | ||||
| 注意:如果你想要启动一个HuggingFace的模型,请先运行 `pip install -e .[all]`。 | ||||
|  | ||||
| ```python | ||||
| from lagent.agents import ReAct | ||||
| from lagent.actions import ActionExecutor, GoogleSearch, PythonInterpreter | ||||
| from lagent.llms import HFTransformer | ||||
|  | ||||
| llm = HFTransformer('internlm/internlm-chat-7b-v1_1') | ||||
| search_tool = GoogleSearch(api_key='SERPER_API_KEY') | ||||
| python_interpreter = PythonInterpreter() | ||||
|  | ||||
| chatbot = ReAct( | ||||
|     llm=llm, | ||||
|     action_executor=ActionExecutor( | ||||
|         actions=[search_tool, python_interpreter]), | ||||
| ) | ||||
|  | ||||
| response = chatbot.chat('若$z=-1+\sqrt{3}i$,则$\frac{z}{{z\overline{z}-1}}=\left(\ \ \right)$') | ||||
| print(response.response) | ||||
| >>> $-\\frac{1}{3}+\\frac{{\\sqrt{3}}}{3}i | ||||
| ```latex | ||||
| @misc{lagent2023, | ||||
|     title={{Lagent: InternLM} a lightweight open-source framework that allows users to efficiently build large language model(LLM)-based agents}, | ||||
|     author={Lagent Developer Team}, | ||||
|     howpublished = {\url{https://github.com/InternLM/lagent}}, | ||||
|     year={2023} | ||||
| } | ||||
| ``` | ||||
|  | ||||
| ## 开源许可证 | ||||
|  | ||||
| 该项目采用[Apache 2.0 开源许可证](LICENSE)。 | ||||
|  | ||||
| <p align="right"><a href="#top">🔼 Back to top</a></p> | ||||
|   | ||||
| @@ -1,5 +1,5 @@ | ||||
| .header-logo { | ||||
|     background-image: url("../images/robot.png"); | ||||
|     background-image: url("../images/lagent_icon.png"); | ||||
|     background-size: 40px 40px; | ||||
|     height: 40px; | ||||
|     width: 40px; | ||||
|   | ||||
							
								
								
									
										
											BIN
										
									
								
								docs/en/_static/images/lagent_icon.png
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								docs/en/_static/images/lagent_icon.png
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| After Width: | Height: | Size: 11 KiB | 
							
								
								
									
										14
									
								
								docs/en/_templates/autoapi/index.rst
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										14
									
								
								docs/en/_templates/autoapi/index.rst
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,14 @@ | ||||
| API Reference | ||||
| ============= | ||||
|  | ||||
| This page contains auto-generated API reference documentation. | ||||
|  | ||||
| .. toctree:: | ||||
|    :titlesonly: | ||||
|    :maxdepth: 3 | ||||
|  | ||||
|    {% for page in pages %} | ||||
|    {% if page.top_level_object and page.display %} | ||||
|    {{ page.include_path }} | ||||
|    {% endif %} | ||||
|    {% endfor %} | ||||
							
								
								
									
										112
									
								
								docs/en/_templates/autoapi/python/module.rst
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										112
									
								
								docs/en/_templates/autoapi/python/module.rst
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,112 @@ | ||||
| {% if not obj.display %} | ||||
| :orphan: | ||||
|  | ||||
| {% endif %} | ||||
| :py:mod:`{{ obj.name if obj.name.count(".") <= 1 else obj.short_name }}` | ||||
| =========={{ "=" * (obj.name|length if obj.name.count(".") <= 1 else obj.short_name|length) }} | ||||
|  | ||||
| .. py:module:: {{ obj.name }} | ||||
|  | ||||
| {% if obj.docstring %} | ||||
| .. autoapi-nested-parse:: | ||||
|  | ||||
|    {{ obj.docstring|indent(3) }} | ||||
|  | ||||
| {% endif %} | ||||
|  | ||||
| {% block subpackages %} | ||||
| {% set visible_subpackages = obj.subpackages|selectattr("display")|list %} | ||||
| {% if visible_subpackages %} | ||||
| Subpackages | ||||
| ----------- | ||||
| .. toctree:: | ||||
|    :titlesonly: | ||||
|    :maxdepth: 3 | ||||
|  | ||||
| {% for subpackage in visible_subpackages %} | ||||
|    {{ subpackage.short_name }}/index.rst | ||||
| {% endfor %} | ||||
|  | ||||
|  | ||||
| {% endif %} | ||||
| {% endblock %} | ||||
| {% block submodules %} | ||||
| {% set visible_submodules = obj.submodules|selectattr("display")|list %} | ||||
| {% if visible_submodules %} | ||||
| Submodules | ||||
| ---------- | ||||
| .. toctree:: | ||||
|    :titlesonly: | ||||
|    :maxdepth: 1 | ||||
|  | ||||
| {% for submodule in visible_submodules %} | ||||
|    {{ submodule.short_name }}/index.rst | ||||
| {% endfor %} | ||||
|  | ||||
|  | ||||
| {% endif %} | ||||
| {% endblock %} | ||||
| {% block content %} | ||||
| {% if obj.type is equalto("package") %} | ||||
| {% set visible_children = obj.children|selectattr("display")|list %} | ||||
| {% else %} | ||||
| {% set visible_children = obj.children|selectattr("display")|rejectattr("imported")|list %} | ||||
| {% endif %} | ||||
| {% if visible_children %} | ||||
| {{ obj.type|title }} Contents | ||||
| {{ "-" * obj.type|length }}--------- | ||||
|  | ||||
| {% set visible_classes = visible_children|selectattr("type", "equalto", "class")|list %} | ||||
| {% set visible_functions = visible_children|selectattr("type", "equalto", "function")|list %} | ||||
| {% set visible_attributes = visible_children|selectattr("type", "equalto", "data")|list %} | ||||
| {% if "show-module-summary" in autoapi_options and (visible_classes or visible_functions) %} | ||||
| {% block classes scoped %} | ||||
| {% if visible_classes %} | ||||
| Classes | ||||
| ~~~~~~~ | ||||
|  | ||||
| .. autoapisummary:: | ||||
|  | ||||
| {% for klass in visible_classes %} | ||||
|    {{ klass.id }} | ||||
| {% endfor %} | ||||
|  | ||||
|  | ||||
| {% endif %} | ||||
| {% endblock %} | ||||
|  | ||||
| {% block functions scoped %} | ||||
| {% if visible_functions %} | ||||
| Functions | ||||
| ~~~~~~~~~ | ||||
|  | ||||
| .. autoapisummary:: | ||||
|  | ||||
| {% for function in visible_functions %} | ||||
|    {{ function.id }} | ||||
| {% endfor %} | ||||
|  | ||||
|  | ||||
| {% endif %} | ||||
| {% endblock %} | ||||
|  | ||||
| {% block attributes scoped %} | ||||
| {% if visible_attributes %} | ||||
| Attributes | ||||
| ~~~~~~~~~~ | ||||
|  | ||||
| .. autoapisummary:: | ||||
|  | ||||
| {% for attribute in visible_attributes %} | ||||
|    {{ attribute.id }} | ||||
| {% endfor %} | ||||
|  | ||||
|  | ||||
| {% endif %} | ||||
| {% endblock %} | ||||
| {% endif %} | ||||
| {% for obj_item in visible_children %} | ||||
| {{ obj_item.render()|indent(0) }} | ||||
| {% endfor %} | ||||
| {% endif %} | ||||
| {% endblock %} | ||||
							
								
								
									
										110
									
								
								docs/en/conf.py
									
									
									
									
									
								
							
							
						
						
									
										110
									
								
								docs/en/conf.py
									
									
									
									
									
								
							| @@ -11,17 +11,16 @@ | ||||
| # documentation root, use os.path.abspath to make it absolute, like shown here. | ||||
|  | ||||
| import os | ||||
| import re | ||||
| import sys | ||||
|  | ||||
| import pytorch_sphinx_theme | ||||
|  | ||||
| sys.path.insert(0, os.path.abspath('../../')) | ||||
| sys.path.insert(0, os.path.abspath('../..')) | ||||
|  | ||||
| # -- Project information ----------------------------------------------------- | ||||
|  | ||||
| project = 'Lagent' | ||||
| copyright = '2020-2030, InternLM' | ||||
| author = 'InternLM' | ||||
| language = 'en' | ||||
|  | ||||
| # The full version, including alpha/beta/rc tags | ||||
| version_file = '../../lagent/version.py' | ||||
| @@ -36,97 +35,74 @@ release = __version__ | ||||
| # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom | ||||
| # ones. | ||||
| extensions = [ | ||||
|     'sphinx_rtd_theme', | ||||
|     'myst_nb', | ||||
|     'autoapi.extension', | ||||
|     'sphinx_markdown_tables', | ||||
|     'sphinx.ext.autodoc', | ||||
|     'sphinx.ext.napoleon', | ||||
|     'sphinx.ext.viewcode', | ||||
|     'sphinx_markdown_tables', | ||||
|     'sphinx_copybutton', | ||||
|     'myst_parser', | ||||
|     'sphinx.ext.intersphinx', | ||||
|     'sphinx.ext.autodoc.typehints', | ||||
|     'sphinx.ext.autosummary', | ||||
|     'sphinx.ext.autosectionlabel', | ||||
|     'sphinx_tabs.tabs', | ||||
| ] | ||||
|  | ||||
| nb_output_stderr = 'remove-warn' | ||||
| autodoc_typehints = 'description' | ||||
| autosummary_generate = True  # Turn on sphinx.ext.autosummary | ||||
|  | ||||
| # Ignore >>> when copying code | ||||
| copybutton_prompt_text = r'>>> |\.\.\. ' | ||||
| copybutton_prompt_is_regexp = True | ||||
|  | ||||
| myst_enable_extensions = ['colon_fence'] | ||||
| # sphinx-autoapi configuration | ||||
| autoapi_dirs = ['../../lagent'] | ||||
| autoapi_options = [ | ||||
|     'members', | ||||
|     'undoc-members', | ||||
|     'show-inheritance', | ||||
|     'show-module-summary', | ||||
| ] | ||||
| autoapi_ignore = ['*migrations*', '*command.py', '*cli.py'] | ||||
| autoapi_template_dir = '_templates/autoapi' | ||||
| autoapi_add_toctree_entry = False | ||||
|  | ||||
| # Add any paths that contain templates here, relative to this directory. | ||||
| templates_path = ['_templates'] | ||||
|  | ||||
| # The suffix(es) of source filenames. | ||||
| # You can specify multiple suffix as a list of string: | ||||
| # | ||||
| source_suffix = { | ||||
|     '.rst': 'restructuredtext', | ||||
|     '.md': 'markdown', | ||||
| } | ||||
|  | ||||
| # The master toctree document. | ||||
| master_doc = 'index' | ||||
|  | ||||
| # List of patterns, relative to source directory, that match files and | ||||
| # directories to ignore when looking for source files. | ||||
| # This pattern also affects html_static_path and html_extra_path. | ||||
| exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] | ||||
| exclude_patterns = [] | ||||
|  | ||||
| # -- Options for HTML output ------------------------------------------------- | ||||
|  | ||||
| # The theme to use for HTML and HTML Help pages.  See the documentation for | ||||
| # a list of builtin themes. | ||||
| # | ||||
| # html_theme = 'sphinx_rtd_theme' | ||||
| html_theme = 'pytorch_sphinx_theme' | ||||
| html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()] | ||||
| html_theme = 'sphinx_rtd_theme' | ||||
| html_theme_options = { | ||||
|     'menu': [ | ||||
|         { | ||||
|             'name': 'GitHub', | ||||
|             'url': 'https://github.com/InternLM/lagent' | ||||
|         }, | ||||
|     ], | ||||
|     # Specify the language of shared menu | ||||
|     'menu_lang': 'en' | ||||
|     'navigation_depth': 3, | ||||
|     'titles_only': False, | ||||
|     'style_nav_header_background': '#4fabab', | ||||
| } | ||||
|  | ||||
| language = 'en' | ||||
| html_context = { | ||||
|     'display_github': True, | ||||
|     'github_host': 'github.com', | ||||
|     'github_user': 'InternLM', | ||||
|     'github_repo': 'lagent', | ||||
|     'github_version': 'main', | ||||
|     'conf_py_path': '/docs/en/', | ||||
| } | ||||
| html_title = 'Lagent' | ||||
| html_logo = '../imgs/lagent_logo.png' | ||||
| html_favicon = '../imgs/lagent_icon.png' | ||||
|  | ||||
| master_doc = 'index' | ||||
|  | ||||
| # Add any paths that contain custom static files (such as style sheets) here, | ||||
| # relative to this directory. They are copied after the builtin static files, | ||||
| # so a file named "default.css" will overwrite the builtin "default.css". | ||||
| # so a file named 'default.css' will overwrite the builtin 'default.css'. | ||||
| html_static_path = ['_static'] | ||||
|  | ||||
| html_css_files = [ | ||||
|     'https://cdn.datatables.net/1.13.2/css/dataTables.bootstrap5.min.css', | ||||
|     'css/readthedocs.css' | ||||
| ] | ||||
| html_js_files = [ | ||||
|     'https://cdn.datatables.net/1.13.2/js/jquery.dataTables.min.js', | ||||
|     'https://cdn.datatables.net/1.13.2/js/dataTables.bootstrap5.min.js', | ||||
|     'js/collapsed.js', | ||||
|     'js/table.js', | ||||
| ] | ||||
|  | ||||
| myst_heading_anchors = 4 | ||||
|  | ||||
| intersphinx_mapping = { | ||||
|     'python': ('https://docs.python.org/3', None), | ||||
|     'numpy': ('https://numpy.org/doc/stable', None), | ||||
|     'torch': ('https://pytorch.org/docs/stable/', None), | ||||
| } | ||||
| def custom_skip(app, what, name, obj, skip, options): | ||||
|     if what in ['data', 'function', 'class'] and re.search('logger', name): | ||||
|         skip = True | ||||
|     return skip | ||||
|  | ||||
|  | ||||
| def builder_inited_handler(app): | ||||
|     pass | ||||
|  | ||||
|  | ||||
| def setup(app): | ||||
|     app.connect('builder-inited', builder_inited_handler) | ||||
| def setup(sphinx): | ||||
|     sphinx.connect('autoapi-skip-member', custom_skip) | ||||
|   | ||||
							
								
								
									
										19
									
								
								docs/en/get_started/install.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								docs/en/get_started/install.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,19 @@ | ||||
| # Installation | ||||
|  | ||||
| ## With pip | ||||
|  | ||||
| Install with pip (Recommended). | ||||
|  | ||||
| ```bash | ||||
| pip install lagent | ||||
| ``` | ||||
|  | ||||
| ## From source | ||||
|  | ||||
| Optionally, you could also build Lagent from source in case you want to modify the code: | ||||
|  | ||||
| ```bash | ||||
| git clone https://github.com/InternLM/lagent.git | ||||
| cd lagent | ||||
| pip install -e . | ||||
| ``` | ||||
| @@ -1,10 +1,10 @@ | ||||
| # OVERVIEW | ||||
| # Overview | ||||
|  | ||||
| This chapter introduces you to the framework of Lagent, and provides links to detailed tutorials about Lagent. | ||||
|  | ||||
| ## What is Lagent | ||||
|  | ||||
| Lagent is an open source LLM agent framework, which enables people to efficiently turn a large language model to agent. It also provides some typical tools to enlighten the ablility of LLM, and the whole framework is shown below: | ||||
| Lagent is an open source LLM agent framework, which enables people to efficiently turn a large language model to agent. It also provides some typical tools to enlighten the ability of LLM, and the whole framework is shown below: | ||||
|  | ||||
|  | ||||
|  | ||||
| @@ -18,6 +18,6 @@ Lagent consists of 3 main parts, agents, llms, and actions. | ||||
|  | ||||
| Here is a detailed step-by-step guide to learn more about Lagent: | ||||
|  | ||||
| 1. For installation instructions, please see [README](../README.md). | ||||
| 1. For installation instructions, please see [README](https://github.com/InternLM/lagent/blob/main/README.md). | ||||
|  | ||||
| 2. We provide several examples to build agents with Lagent in [examples](examples/) by simply run `python examples/react_example.py`. | ||||
| 2. We provide several examples to build agents with Lagent in [examples](https://github.com/InternLM/lagent/tree/main/examples) by simply run `python examples/react_example.py`. | ||||
|   | ||||
							
								
								
									
										89
									
								
								docs/en/get_started/quickstart.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										89
									
								
								docs/en/get_started/quickstart.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,89 @@ | ||||
| # Quickstart | ||||
|  | ||||
| Using Lagent, you can easily build agents with just a few lines of code. | ||||
|  | ||||
| ## Run a ReWOO agent with GPT-3.5 | ||||
|  | ||||
| Below is an example of running ReWOO with GPT-3.5 | ||||
|  | ||||
| ```python | ||||
| # Import necessary modules and classes from the "lagent" library. | ||||
| from lagent.agents import ReWOO | ||||
| from lagent.actions import ActionExecutor, GoogleSearch | ||||
| from lagent.llms import GPTAPI | ||||
|  | ||||
| # Initialize the Language Model (llm) and provide your API key. | ||||
| llm = GPTAPI(model_type='gpt-3.5-turbo', key=['Your OPENAI_API_KEY']) | ||||
|  | ||||
| # Initialize the Google Search tool and provide your API key. | ||||
| search_tool = GoogleSearch(api_key='Your SERPER_API_KEY') | ||||
|  | ||||
| # Create a chatbot by configuring the ReWOO agent. | ||||
| chatbot = ReWOO( | ||||
|     llm=llm,  # Provide the Language Model instance. | ||||
|     action_executor=ActionExecutor( | ||||
|         actions=[search_tool]  # Specify the actions the chatbot can perform. | ||||
|     ), | ||||
| ) | ||||
|  | ||||
| # Ask the chatbot a question and store the response. | ||||
| response = chatbot.chat('What profession does Nicholas Ray and Elia Kazan have in common') | ||||
|  | ||||
| # Print the chatbot's response. | ||||
| print(response.response)  # Output the response generated by the chatbot. | ||||
| ``` | ||||
|  | ||||
| ```python | ||||
| >>> Film director. | ||||
| ``` | ||||
|  | ||||
| ## Run a ReAct agent with InternLM | ||||
|  | ||||
| NOTE: If you want to run a HuggingFace model, please run `pip install -e .[all]` first. | ||||
|  | ||||
| ```python | ||||
| # Import necessary modules and classes from the "lagent" library. | ||||
| from lagent.agents import ReAct | ||||
| from lagent.actions import ActionExecutor, GoogleSearch, PythonInterpreter | ||||
| from lagent.llms import HFTransformer | ||||
|  | ||||
| from lagent.llms.meta_template import INTERNLM2_META as META | ||||
|  | ||||
| # Initialize the HFTransformer-based Language Model (llm) and | ||||
| # provide the model name. | ||||
| llm = HFTransformer(path='internlm/internlm2-chat-7b', meta_template=META) | ||||
|  | ||||
| # Initialize the Google Search tool and provide your API key. | ||||
| search_tool = GoogleSearch(api_key='Your SERPER_API_KEY') | ||||
|  | ||||
| # Initialize the Python Interpreter tool. | ||||
| python_interpreter = PythonInterpreter() | ||||
|  | ||||
| # Create a chatbot by configuring the ReAct agent. | ||||
| # Specify the actions the chatbot can perform. | ||||
| chatbot = ReAct( | ||||
|     llm=llm,  # Provide the Language Model instance. | ||||
|     action_executor=ActionExecutor( | ||||
|         actions=[search_tool, python_interpreter]), | ||||
| ) | ||||
| # Ask the chatbot a mathematical question in LaTeX format. | ||||
| response = chatbot.chat('若$z=-1+\sqrt{3}i$,则$\frac{z}{{z\overline{z}-1}}=\left(\ \ \right)$') | ||||
|  | ||||
| # Print the chatbot's response. | ||||
| print(response.response)  # Output the response generated by the chatbot. | ||||
| ``` | ||||
|  | ||||
| ```python | ||||
| >>> $-\\frac{1}{3}+\\frac{{\\sqrt{3}}}{3}i$ | ||||
| ``` | ||||
|  | ||||
| ## Run ReAct Web Demo | ||||
|  | ||||
| ```python | ||||
| # You need to install streamlit first | ||||
| # pip install streamlit | ||||
| streamlit run examples/react_web_demo.py | ||||
| ``` | ||||
|  | ||||
| Then you can chat through the UI shown as below | ||||
|  | ||||
| @@ -8,13 +8,31 @@ You can switch between English and Chinese in the lower-left corner of the layou | ||||
|    :caption: Get Started | ||||
|  | ||||
|    get_started/overview.md | ||||
|    get_started/install.md | ||||
|    get_started/quickstart.md | ||||
|  | ||||
| .. toctree:: | ||||
|    :maxdepth: 2 | ||||
|    :caption: Tutorials | ||||
|  | ||||
|    tutorials/action.md | ||||
|  | ||||
| .. toctree:: | ||||
|    :caption: Switch Language | ||||
|  | ||||
|    switch_language.md | ||||
|  | ||||
| .. toctree:: | ||||
|    :maxdepth: 1 | ||||
|    :caption: API Reference | ||||
|  | ||||
|    autoapi/lagent/actions/index | ||||
|    autoapi/lagent/agents/index | ||||
|    autoapi/lagent/llms/index | ||||
|    autoapi/lagent/utils/index | ||||
|    autoapi/lagent/schema/index | ||||
|    autoapi/lagent/version/index | ||||
|  | ||||
|  | ||||
| Indices and tables | ||||
| ================== | ||||
|   | ||||
| @@ -1,3 +1,3 @@ | ||||
| ## <a href='https://lagent.readthedocs.io/en/main/'>English</a> | ||||
| ## <a href='https://lagent.readthedocs.io/en/latest/'>English</a> | ||||
|  | ||||
| ## <a href='https://lagent.readthedocs.io/zh_CN/main/'>简体中文</a> | ||||
| ## <a href='https://lagent.readthedocs.io/zh-cn/latest/'>简体中文</a> | ||||
|   | ||||
							
								
								
									
										400
									
								
								docs/en/tutorials/action.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										400
									
								
								docs/en/tutorials/action.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,400 @@ | ||||
| # Action | ||||
|  | ||||
| Actions, also called **tools**, provide a suite of functions LLM-driven agents can use to interact with the real world and perform complex tasks. | ||||
|  | ||||
| ## Basic Concepts | ||||
|  | ||||
| ### Tool & Toolkit | ||||
|  | ||||
| There are two categories of tools: | ||||
|  | ||||
| - tool: provide only one API to call. | ||||
| - toolkit: implement multiple APIs that undertake different sub-tasks. | ||||
|  | ||||
| ### Tool Description | ||||
|  | ||||
| In Lagent, the tool description is a dictionary containing the action's core information of usage, observed by LLMs for decision-making. | ||||
|  | ||||
| For simple tools, the description can be created as follows | ||||
|  | ||||
| ```python | ||||
| TOOL_DESCRIPTION = { | ||||
|     'name': 'bold',  # name of the tool | ||||
|     'description': 'a function used to make text bold',  # introduce the tool's function | ||||
|     'parameters': [  # a list of parameters the tool take. | ||||
|         { | ||||
|             'name': 'text', 'type': 'STRING', 'description': 'input content' | ||||
|         } | ||||
|     ], | ||||
|     'required': ['text'],  # specify names of parameters required | ||||
| } | ||||
| ``` | ||||
|  | ||||
| In some situations there may be optional `return_data`, `parameter_description` keys describing the returns and argument passing format respectively. | ||||
|  | ||||
| ```{attention} | ||||
| `parameter_description` is usually inserted into the tool description automatically by the action's parser. It will be introduced in [Interface Design](#interface-design) . | ||||
| ``` | ||||
|  | ||||
| For toolkits, the description is very similar but nest submethods | ||||
|  | ||||
| ```python | ||||
| TOOL_DESCRIPTION = { | ||||
|     'name': 'PhraseEmphasis',  # name of the toolkit | ||||
|     'description': 'a toolkit which provides different styles of text emphasis',  # introduce the tool's function | ||||
|     'api_list': [ | ||||
|         { | ||||
|             'name': 'bold', | ||||
|             'description': 'make text bold', | ||||
|             'parameters': [ | ||||
|                 { | ||||
|                     'name': 'text', 'type': 'STRING', 'description': 'input content' | ||||
|                 } | ||||
|             ], | ||||
|             'required': ['text'] | ||||
|         }, | ||||
|         { | ||||
|             'name': 'italic', | ||||
|             'description': 'make text italic', | ||||
|             'parameters': [ | ||||
|                 { | ||||
|                     'name': 'text', 'type': 'STRING', 'description': 'input content' | ||||
|                 } | ||||
|             ], | ||||
|             'required': ['text'] | ||||
|         } | ||||
|     ] | ||||
| } | ||||
| ``` | ||||
|  | ||||
| ## Make Functions Tools | ||||
|  | ||||
| It's not necessary to prepare an extra description for a defined function. In Lagent we provide a decorator `tool_api` which can conveniently turn a function into a tool by automatically parsing the function's typehints and dosctrings to generate the description dictionary and binding it to an attribute `api_description`. | ||||
|  | ||||
| ```python | ||||
| from lagent import tool_api | ||||
|  | ||||
| @tool_api | ||||
| def bold(text: str) -> str: | ||||
|     """make text bold | ||||
|  | ||||
|     Args: | ||||
|         text (str): input text | ||||
|  | ||||
|     Returns: | ||||
|         str: bold text | ||||
|     """ | ||||
|     return '**' + text + '**' | ||||
|  | ||||
|  | ||||
| bold.api_description | ||||
| ``` | ||||
|  | ||||
| ```python | ||||
| {'name': 'bold', | ||||
|  'description': 'make text bold', | ||||
|  'parameters': [{'name': 'text', | ||||
|    'type': 'STRING', | ||||
|    'description': 'input text'}], | ||||
|  'required': ['text']} | ||||
| ``` | ||||
|  | ||||
| Once `returns_named_value` is enabled you should declare the name of the return data, which will be processed to form a new field `return_data`: | ||||
|  | ||||
| ```python | ||||
| @tool_api(returns_named_value=True) | ||||
| def bold(text: str) -> str: | ||||
|     """make text bold | ||||
|  | ||||
|     Args: | ||||
|         text (str): input text | ||||
|  | ||||
|     Returns: | ||||
|         bold_text (str): bold text | ||||
|     """ | ||||
|     return '**' + text + '**' | ||||
|  | ||||
| bold.api_description | ||||
| ``` | ||||
|  | ||||
| ```python | ||||
| {'name': 'bold', | ||||
|  'description': 'make text bold', | ||||
|  'parameters': [{'name': 'text', | ||||
|    'type': 'STRING', | ||||
|    'description': 'input text'}], | ||||
|  'required': ['text'], | ||||
|  'return_data': [{'name': 'bold_text', | ||||
|    'description': 'bold text', | ||||
|    'type': 'STRING'}]} | ||||
| ``` | ||||
|  | ||||
| Sometimes the tool may return a `dict` or `tuple`, and you want to elaborate each member in `return_data` rather than take them as a whole. Set `explode_return=True` and list them in the return part of docstrings. | ||||
|  | ||||
| ```python | ||||
| @tool_api(explode_return=True) | ||||
| def list_args(a: str, b: int, c: float = 0.0) -> dict: | ||||
|     """Return arguments in dict format | ||||
|  | ||||
|     Args: | ||||
|         a (str): a | ||||
|         b (int): b | ||||
|         c (float): c | ||||
|  | ||||
|     Returns: | ||||
|         dict: input arguments | ||||
|             - a (str): a | ||||
|             - b (int): b | ||||
|             - c: c | ||||
|     """ | ||||
|     return {'a': a, 'b': b, 'c': c} | ||||
| ``` | ||||
|  | ||||
| ```python | ||||
| {'name': 'list_args', | ||||
|  'description': 'Return arguments in dict format', | ||||
|  'parameters': [{'name': 'a', 'type': 'STRING', 'description': 'a'}, | ||||
|   {'name': 'b', 'type': 'NUMBER', 'description': 'b'}, | ||||
|   {'name': 'c', 'type': 'FLOAT', 'description': 'c'}], | ||||
|  'required': ['a', 'b'], | ||||
|  'return_data': [{'name': 'a', 'description': 'a', 'type': 'STRING'}, | ||||
|   {'name': 'b', 'description': 'b', 'type': 'NUMBER'}, | ||||
|   {'name': 'c', 'description': 'c'}]} | ||||
| ``` | ||||
|  | ||||
| ```{warning} | ||||
| Only Google style Python docstrings is currently supported. | ||||
| ``` | ||||
|  | ||||
| ## Interface Design | ||||
|  | ||||
| `BaseAction(description=None, parser=JsonParser, enable=True)` is the base class all actions should inherit from. It takes three initialization arguments | ||||
|  | ||||
| - **description**: a tool description dictionary, used set instance attribute `description`. Mostly you don't need explicitly pass this argument since the meta class of `BaseAction` will search methods decorated by `tool_api` and assemble their `api_description` as a class attribute `__tool_description__`, and if the initial `description` is left null, then `__tool_description__` will be copied as `description`. | ||||
|  | ||||
| - **parser**: `BaseParser` class. It will instantialize a parser used to validate the arguments of APIs in `description`. | ||||
|  | ||||
|   For example, `JsonParser` requires arguments passed in the format of JSON or `dict`. To make LLMs aware of this, It inserts a field `parameter_description` into the `description`. | ||||
|  | ||||
|   ```python | ||||
|   from lagent import BaseAction | ||||
|  | ||||
|   action = BaseAction( | ||||
|       { | ||||
|           'name': 'bold', | ||||
|           'description': 'a function used to make text bold', | ||||
|           'parameters': [ | ||||
|               { | ||||
|                   'name': 'text', 'type': 'STRING', 'description': 'input content' | ||||
|               } | ||||
|           ], | ||||
|           'required': ['text'] | ||||
|       } | ||||
|   ) | ||||
|   action.description | ||||
|   ``` | ||||
|  | ||||
|   ```python | ||||
|   {'name': 'bold', | ||||
|    'description': 'a function used to make text bold', | ||||
|    'parameters': [{'name': 'text', | ||||
|    'type': 'STRING', | ||||
|    'description': 'input content'}], | ||||
|    'required': ['text'], | ||||
|    'parameter_description': '如果调用该工具,你必须使用Json格式 {key: value} 传参,其中key为参数名称'} | ||||
|   ``` | ||||
|  | ||||
| - **enable**: specify whether the tool is available. | ||||
|  | ||||
| ### Custom Action | ||||
|  | ||||
| A simple tool must have its `run` method implemented, while APIs of toolkits should avoid naming conflicts with this reserved word. | ||||
|  | ||||
| ```{tip} | ||||
| `run` is allowed not to be decorated by `tool_api` for simple tools unless you want to hint the return data. | ||||
| ``` | ||||
|  | ||||
| ```python | ||||
| class Bold(BaseAction): | ||||
|  | ||||
|     def run(self, text: str): | ||||
|         """make text bold | ||||
|  | ||||
|         Args: | ||||
|             text (str): input text | ||||
|  | ||||
|         Returns: | ||||
|             str: bold text | ||||
|         """ | ||||
|         return '**' + text + '**' | ||||
|  | ||||
| class PhraseEmphasis(BaseAction): | ||||
|     """a toolkit which provides different styles of text emphasis""" | ||||
|  | ||||
|     @tool_api | ||||
|     def bold(self, text): | ||||
|         """make text bold | ||||
|  | ||||
|         Args: | ||||
|             text (str): input text | ||||
|  | ||||
|         Returns: | ||||
|             str: bold text | ||||
|         """ | ||||
|         return '**' + text + '**' | ||||
|  | ||||
|     @tool_api | ||||
|     def italic(self, text): | ||||
|         """make text italic | ||||
|  | ||||
|         Args: | ||||
|             text (str): input text | ||||
|  | ||||
|         Returns: | ||||
|             str: italic text | ||||
|         """ | ||||
|         return '*' + text + '*' | ||||
|  | ||||
| # Inspect the default description | ||||
| # Bold.__tool_description__, PhraseEmphasis.__tool_description__ | ||||
| ``` | ||||
|  | ||||
| ### Auto-registration | ||||
|  | ||||
| Any subclass of `BaseAction` will be registered automatically. You can use `list_tools()` and `get_tool()` to view all tools and initialize by name. | ||||
|  | ||||
| ```python | ||||
| from lagent import list_tools, get_tool | ||||
|  | ||||
| list_tools() | ||||
| ``` | ||||
|  | ||||
| ```python | ||||
| ['BaseAction', | ||||
|  'InvalidAction', | ||||
|  'NoAction', | ||||
|  'FinishAction', | ||||
|  'ArxivSearch', | ||||
|  'BINGMap', | ||||
|  'GoogleScholar', | ||||
|  'GoogleSearch', | ||||
|  'IPythonInterpreter', | ||||
|  'PPT', | ||||
|  'PythonInterpreter', | ||||
|  'Bold', | ||||
|  'PhraseEmphasis'] | ||||
| ``` | ||||
|  | ||||
| Create a `PhraseEmphasis` object | ||||
|  | ||||
| ```python | ||||
| action = get_tool('PhraseEmphasis') | ||||
| action.description | ||||
| ``` | ||||
|  | ||||
| ```python | ||||
| {'name': 'PhraseEmphasis', | ||||
|  'description': 'a toolkit which provides different styles of text emphasis', | ||||
|  'api_list': [{'name': 'bold', | ||||
|    'description': 'make text bold', | ||||
|    'parameters': [{'name': 'text', | ||||
|      'type': 'STRING', | ||||
|      'description': 'input text'}], | ||||
|    'required': ['text'], | ||||
|    'parameter_description': '如果调用该工具,你必须使用Json格式 {key: value} 传参,其中key为参数名称'}, | ||||
|   {'name': 'italic', | ||||
|    'description': 'make text italic', | ||||
|    'parameters': [{'name': 'text', | ||||
|      'type': 'STRING', | ||||
|      'description': 'input text'}], | ||||
|    'required': ['text'], | ||||
|    'parameter_description': '如果调用该工具,你必须使用Json格式 {key: value} 传参,其中key为参数名称'}]} | ||||
| ``` | ||||
|  | ||||
| ## Tool Calling | ||||
|  | ||||
| ### Run a Tool | ||||
|  | ||||
| `__call__` method of `Action` takes two arguments | ||||
|  | ||||
| - `inputs`: It depends on the action's parser. Often a string in specific formats generated by LLMs. | ||||
|   - `JsonParser`: Allow passing arguments in the format of JSON string or Python `dict`. | ||||
|   - `TupleParser`: Allow passing arguments in the format of tuple string format or Python `tuple`. | ||||
| - `name`: Which API to call. Default is `run`. | ||||
|  | ||||
| It returns an `ActionReturn` object which encapsulates calling details | ||||
|  | ||||
| - `args`: Dictionary of action inputs. | ||||
| - `type`: Action name. | ||||
| - `result`: List of dicts. Each contains two keys: 'type' and 'content'. when errors occur, it is `None`. | ||||
| - `errmsg`: Error message. Default is `None`. | ||||
|  | ||||
| Below is an example | ||||
|  | ||||
| ```python | ||||
| from lagent import IPythonInterpreter, TupleParser | ||||
|  | ||||
| action1 = IPythonInterpreter() | ||||
| ret = action1('{"command": "import math;math.sqrt(100)"}') | ||||
| print(ret.result) | ||||
| ret = action1({'command': 'import math;math.sqrt(100)'}) | ||||
| print(ret.result) | ||||
|  | ||||
| action2 = IPythonInterpreter(parser=TupleParser) | ||||
| ret = action2('("import math;math.sqrt(100)", )') | ||||
| print(ret.result) | ||||
| ret = action2(('import math;math.sqrt(100)',)) | ||||
| print(ret.result) | ||||
| ``` | ||||
|  | ||||
| ```python | ||||
| [{'type': 'text', 'content': '10.0'}] | ||||
| [{'type': 'text', 'content': '10.0'}] | ||||
| [{'type': 'text', 'content': '10.0'}] | ||||
| [{'type': 'text', 'content': '10.0'}] | ||||
| ``` | ||||
|  | ||||
| ### Dynamic Invocation | ||||
|  | ||||
| Lagent provides an `ActionExecutor` to manage multiple tools. It will flatten `api_list` of toolkits and rename each `{tool_name}.{api_name}`. | ||||
|  | ||||
| ```python | ||||
| from lagent import ActionExecutor, ArxivSearch, IPythonInterpreter | ||||
|  | ||||
| executor = ActionExecutor(actions=[ArxivSearch(), IPythonInterpreter()]) | ||||
| executor.get_actions_info()  # This information is fed to LLMs as the tool meta prompt | ||||
| ``` | ||||
|  | ||||
| ```python | ||||
| [{'name': 'ArxivSearch.get_arxiv_article_information', | ||||
|   'description': 'Run Arxiv search and get the article meta information.', | ||||
|   'parameters': [{'name': 'query', | ||||
|     'type': 'STRING', | ||||
|     'description': 'the content of search query'}], | ||||
|   'required': ['query'], | ||||
|   'return_data': [{'name': 'content', | ||||
|     'description': 'a list of 3 arxiv search papers', | ||||
|     'type': 'STRING'}], | ||||
|   'parameter_description': '如果调用该工具,你必须使用Json格式 {key: value} 传参,其中key为参数名称'}, | ||||
|  {'name': 'IPythonInterpreter', | ||||
|   'description': "When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files. Internet access for this session is disabled. Do not make external web requests or API calls as they will fail.", | ||||
|   'parameters': [{'name': 'command', | ||||
|     'type': 'STRING', | ||||
|     'description': 'Python code'}, | ||||
|    {'name': 'timeout', | ||||
|     'type': 'NUMBER', | ||||
|     'description': 'Upper bound of waiting time for Python script execution.'}], | ||||
|   'required': ['command'], | ||||
|   'parameter_description': '如果调用该工具,你必须使用Json格式 {key: value} 传参,其中key为参数名称'}] | ||||
| ``` | ||||
|  | ||||
| Trigger an action through the executor | ||||
|  | ||||
| ```python | ||||
| ret = executor('IPythonInterpreter', '{"command": "import math;math.sqrt(100)"}') | ||||
| ret.result | ||||
| ``` | ||||
|  | ||||
| ```python | ||||
| [{'type': 'text', 'content': '10.0'}] | ||||
| ``` | ||||
| @@ -1,23 +0,0 @@ | ||||
| # পরিচিতি | ||||
|  | ||||
| এই অধ্যায় আপনাকে Lagent এর পরিচয় দেয়, এবং Lagent সম্পর্কিত বিস্তারিত টিউটোরিয়ালের লিঙ্ক সরবরাহ করে। | ||||
|  | ||||
| ## Lagent কি | ||||
|  | ||||
| Lagent হল একটি খোলা সোর্স এলএলএম এজেন্ট ফ্রেমওয়ার্ক, যা ব্যক্তিকে বড় ভাষা মডেল ভিত্তিক এজেন্ট তৈরি করতে সক্ষম করে। এটি এলএলএমের সাথে কিছু প্রশাসন টুলসও সরবরাহ করে, এবং আমাদের ফ্রেমওয়ার্কের সারমর্ক নীচে দেখানো হয়: | ||||
|  | ||||
|  | ||||
|  | ||||
| Lagent তিনটি প্রধান বিভাগ বা উপাদান শাখা রেখে: | ||||
|  | ||||
| - **এজেন্টস** এজেন্ট কার্যান্বয়ন সরবরাহ করে, যেমন ReAct, AutoGPT। | ||||
| - **এলএলএমস** বিভিন্ন বড় ভাষা মডেলের সমর্থন করে, যেমন হাগিংফেস মডেল (Llama-2, InterLM) বা GPT3.5/4। | ||||
| - **অ্যাকশন্স** এটিতে অ্যাকশনগুলির একটি শ্রেণি রয়েছে, এবং সব অ্যাকশনগুলি পরিচালনার জন্য একটি অ্যাকশন এক্যুটর থাকে। | ||||
|  | ||||
| ## কীভাবে ব্যবহার করবেন | ||||
|  | ||||
| Lagent সম্পর্কে আরও জানতে এখানে বিস্তারিত ধাপে ধাপ গাইড প্রদান করা হয়: | ||||
|  | ||||
| 1. ইনস্টলেশন নির্দেশিকা দেখতে, দয়া করে [README](../../../README_in_beng.md) দেখুন। | ||||
|  | ||||
| 2. আমরা কিছু উদাহরণ সরবরাহ করে থাকি, যা Lagent এজেন্ট তৈরি করতে সাহায্য করে, `python example/react example.py` চালানো দ্বারা [উদ | ||||
| @@ -1,23 +0,0 @@ | ||||
| # अवलोकन | ||||
|  | ||||
| यह अध्याय आपको Lagent की रूपरेखा से परिचित कराता है, और लैजेंट के बारे में विस्तृत ट्यूटोरियल के लिंक प्रदान करता है। | ||||
|  | ||||
| ## Lagent क्या है | ||||
|  | ||||
| Lagent एक खुला स्रोत एलएलएम एजेंट ढांचा है, जो लोगों को एक बड़े भाषा मॉडल को कुशलतापूर्वक एजेंट में बदलने में सक्षम बनाता है। यह एलएलएम की योग्यता को स्पष्ट करने के लिए कुछ विशिष्ट उपकरण भी प्रदान करता है, और संपूर्ण रूपरेखा नीचे दिखाई गई है: | ||||
|  | ||||
|  | ||||
|  | ||||
| Lagent में 3 मुख्य भाग होते हैं, एजेंट, एलएलएम और क्रियाएं। | ||||
|  | ||||
| - **agents** एजेंट कार्यान्वयन प्रदान करता है, जैसे ReAct, AutoGPT। | ||||
| - **llms** विभिन्न बड़े भाषा मॉडलों का समर्थन करता है, जिसमें हगिंगफेस मॉडल के माध्यम से ओपन-सोर्स मॉडल (Llama-2, InterLM) या GPT3.5/4 जैसे क्लोज-सोर्स मॉडल शामिल हैं। | ||||
| - **actions** इसमें क्रियाओं की एक श्रृंखला होती है, साथ ही सभी क्रियाओं को प्रबंधित करने के लिए एक क्रिया निष्पादक भी होता है। | ||||
|  | ||||
| ## का उपयोग कैसे करें | ||||
|  | ||||
| Lagent के बारे में अधिक जानने के लिए यहां एक विस्तृत चरण-दर-चरण मार्गदर्शिका दी गई है: | ||||
|  | ||||
| 1. स्थापना अनुदेशों के लिए, please see [README](../../../README_in_HIN.md). | ||||
|  | ||||
| 2. हम बस `python example/react example.py` चलाकर [उदाहरण](examples/) में Lagent के साथ एजेंट बनाने के लिए कई उदाहरण प्रदान करते हैं। | ||||
| @@ -1,23 +0,0 @@ | ||||
| # 概要 | ||||
|  | ||||
| この章では Lagent のフレームワークを紹介し、Lagent に関する詳細なチュートリアルへのリンクを提供します。 | ||||
|  | ||||
| ## Lagent とは | ||||
|  | ||||
| Lagent はオープンソースの LLM エージェントフレームワークで、大規模な言語モデルを効率的にエージェントに変換することができます。また、LLM の能力を啓発するためのいくつかの典型的なツールも提供します: | ||||
|  | ||||
|  | ||||
|  | ||||
| Lagent はエージェント、LLMS、アクションの 3 つの主要部分から構成されています。 | ||||
|  | ||||
| - **agents** ReAct、AutoGPT などのエージェント実装を提供する。 | ||||
| - **llms** は、オープンソース・モデル(Llama-2、InterLM)から Hugging Face モデル、あるいは GPT3.5/4 のようなクローズドソースモデルを含む、さまざまな大規模言語モデルをサポートしています。 | ||||
| - **actions** には一連のアクションと、すべてのアクションを管理するアクションエグゼキュータが含まれている。 | ||||
|  | ||||
| ## 使用方法 | ||||
|  | ||||
| Lagent についての詳しいステップバイステップガイドはこちら: | ||||
|  | ||||
| 1. インストール方法については、[README](../README.md)を参照してください。 | ||||
|  | ||||
| 2. python の `examples/react_example.py` を実行するだけで、Lagent でエージェントをビルドする例を [examples](examples/) にいくつか用意しています。 | ||||
							
								
								
									
										15
									
								
								docs/zh_cn/.readthedocs.yaml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										15
									
								
								docs/zh_cn/.readthedocs.yaml
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,15 @@ | ||||
| version: 2 | ||||
|  | ||||
| formats: all | ||||
|  | ||||
| build: | ||||
|   os: ubuntu-22.04 | ||||
|   tools: | ||||
|     python: "3.10" | ||||
|  | ||||
| python: | ||||
|   install: | ||||
|     - requirements: requirements/docs.txt | ||||
|  | ||||
| sphinx: | ||||
|   configuration: docs/zh_cn/conf.py | ||||
| @@ -1,5 +1,5 @@ | ||||
| .header-logo { | ||||
|     background-image: url("../images/robot.png"); | ||||
|     background-image: url("../images/lagent_icon.png"); | ||||
|     background-size: 40px 40px; | ||||
|     height: 40px; | ||||
|     width: 40px; | ||||
|   | ||||
							
								
								
									
										
											BIN
										
									
								
								docs/zh_cn/_static/images/lagent_icon.png
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								docs/zh_cn/_static/images/lagent_icon.png
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| After Width: | Height: | Size: 11 KiB | 
							
								
								
									
										14
									
								
								docs/zh_cn/_templates/autoapi/index.rst
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										14
									
								
								docs/zh_cn/_templates/autoapi/index.rst
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,14 @@ | ||||
| API Reference | ||||
| ============= | ||||
|  | ||||
| This page contains auto-generated API reference documentation. | ||||
|  | ||||
| .. toctree:: | ||||
|    :titlesonly: | ||||
|    :maxdepth: 3 | ||||
|  | ||||
|    {% for page in pages %} | ||||
|    {% if page.top_level_object and page.display %} | ||||
|    {{ page.include_path }} | ||||
|    {% endif %} | ||||
|    {% endfor %} | ||||
							
								
								
									
										112
									
								
								docs/zh_cn/_templates/autoapi/python/module.rst
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										112
									
								
								docs/zh_cn/_templates/autoapi/python/module.rst
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,112 @@ | ||||
| {% if not obj.display %} | ||||
| :orphan: | ||||
|  | ||||
| {% endif %} | ||||
| :py:mod:`{{ obj.name if obj.name.count(".") <= 1 else obj.short_name }}` | ||||
| =========={{ "=" * (obj.name|length if obj.name.count(".") <= 1 else obj.short_name|length) }} | ||||
|  | ||||
| .. py:module:: {{ obj.name }} | ||||
|  | ||||
| {% if obj.docstring %} | ||||
| .. autoapi-nested-parse:: | ||||
|  | ||||
|    {{ obj.docstring|indent(3) }} | ||||
|  | ||||
| {% endif %} | ||||
|  | ||||
| {% block subpackages %} | ||||
| {% set visible_subpackages = obj.subpackages|selectattr("display")|list %} | ||||
| {% if visible_subpackages %} | ||||
| Subpackages | ||||
| ----------- | ||||
| .. toctree:: | ||||
|    :titlesonly: | ||||
|    :maxdepth: 3 | ||||
|  | ||||
| {% for subpackage in visible_subpackages %} | ||||
|    {{ subpackage.short_name }}/index.rst | ||||
| {% endfor %} | ||||
|  | ||||
|  | ||||
| {% endif %} | ||||
| {% endblock %} | ||||
| {% block submodules %} | ||||
| {% set visible_submodules = obj.submodules|selectattr("display")|list %} | ||||
| {% if visible_submodules %} | ||||
| Submodules | ||||
| ---------- | ||||
| .. toctree:: | ||||
|    :titlesonly: | ||||
|    :maxdepth: 1 | ||||
|  | ||||
| {% for submodule in visible_submodules %} | ||||
|    {{ submodule.short_name }}/index.rst | ||||
| {% endfor %} | ||||
|  | ||||
|  | ||||
| {% endif %} | ||||
| {% endblock %} | ||||
| {% block content %} | ||||
| {% if obj.type is equalto("package") %} | ||||
| {% set visible_children = obj.children|selectattr("display")|list %} | ||||
| {% else %} | ||||
| {% set visible_children = obj.children|selectattr("display")|rejectattr("imported")|list %} | ||||
| {% endif %} | ||||
| {% if visible_children %} | ||||
| {{ obj.type|title }} Contents | ||||
| {{ "-" * obj.type|length }}--------- | ||||
|  | ||||
| {% set visible_classes = visible_children|selectattr("type", "equalto", "class")|list %} | ||||
| {% set visible_functions = visible_children|selectattr("type", "equalto", "function")|list %} | ||||
| {% set visible_attributes = visible_children|selectattr("type", "equalto", "data")|list %} | ||||
| {% if "show-module-summary" in autoapi_options and (visible_classes or visible_functions) %} | ||||
| {% block classes scoped %} | ||||
| {% if visible_classes %} | ||||
| Classes | ||||
| ~~~~~~~ | ||||
|  | ||||
| .. autoapisummary:: | ||||
|  | ||||
| {% for klass in visible_classes %} | ||||
|    {{ klass.id }} | ||||
| {% endfor %} | ||||
|  | ||||
|  | ||||
| {% endif %} | ||||
| {% endblock %} | ||||
|  | ||||
| {% block functions scoped %} | ||||
| {% if visible_functions %} | ||||
| Functions | ||||
| ~~~~~~~~~ | ||||
|  | ||||
| .. autoapisummary:: | ||||
|  | ||||
| {% for function in visible_functions %} | ||||
|    {{ function.id }} | ||||
| {% endfor %} | ||||
|  | ||||
|  | ||||
| {% endif %} | ||||
| {% endblock %} | ||||
|  | ||||
| {% block attributes scoped %} | ||||
| {% if visible_attributes %} | ||||
| Attributes | ||||
| ~~~~~~~~~~ | ||||
|  | ||||
| .. autoapisummary:: | ||||
|  | ||||
| {% for attribute in visible_attributes %} | ||||
|    {{ attribute.id }} | ||||
| {% endfor %} | ||||
|  | ||||
|  | ||||
| {% endif %} | ||||
| {% endblock %} | ||||
| {% endif %} | ||||
| {% for obj_item in visible_children %} | ||||
| {{ obj_item.render()|indent(0) }} | ||||
| {% endfor %} | ||||
| {% endif %} | ||||
| {% endblock %} | ||||
| @@ -11,18 +11,16 @@ | ||||
| # documentation root, use os.path.abspath to make it absolute, like shown here. | ||||
|  | ||||
| import os | ||||
| import subprocess | ||||
| import re | ||||
| import sys | ||||
|  | ||||
| import pytorch_sphinx_theme | ||||
|  | ||||
| sys.path.insert(0, os.path.abspath('../../')) | ||||
| sys.path.insert(0, os.path.abspath('../..')) | ||||
|  | ||||
| # -- Project information ----------------------------------------------------- | ||||
|  | ||||
| project = 'Lagent' | ||||
| copyright = '2020-2030, InternLM' | ||||
| author = 'InternLM' | ||||
| language = 'zh_CN' | ||||
|  | ||||
| # The full version, including alpha/beta/rc tags | ||||
| version_file = '../../lagent/version.py' | ||||
| @@ -37,97 +35,74 @@ release = __version__ | ||||
| # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom | ||||
| # ones. | ||||
| extensions = [ | ||||
|     'sphinx_rtd_theme', | ||||
|     'myst_nb', | ||||
|     'autoapi.extension', | ||||
|     'sphinx_markdown_tables', | ||||
|     'sphinx.ext.autodoc', | ||||
|     'sphinx.ext.napoleon', | ||||
|     'sphinx.ext.viewcode', | ||||
|     'sphinx_markdown_tables', | ||||
|     'sphinx_copybutton', | ||||
|     'myst_parser', | ||||
|     'sphinx.ext.intersphinx', | ||||
|     'sphinx.ext.autodoc.typehints', | ||||
|     'sphinx.ext.autosummary', | ||||
|     'sphinx.ext.autosectionlabel', | ||||
|     'sphinx_tabs.tabs', | ||||
| ] | ||||
|  | ||||
| nb_output_stderr = 'remove-warn' | ||||
| autodoc_typehints = 'description' | ||||
|  | ||||
| autosummary_generate = True  # Turn on sphinx.ext.autosummary | ||||
| # Ignore >>> when copying code | ||||
| copybutton_prompt_text = r'>>> |\.\.\. ' | ||||
| copybutton_prompt_is_regexp = True | ||||
|  | ||||
| myst_enable_extensions = ['colon_fence'] | ||||
| # sphinx-autoapi configuration | ||||
| autoapi_dirs = ['../../lagent'] | ||||
| autoapi_options = [ | ||||
|     'members', | ||||
|     'undoc-members', | ||||
|     'show-inheritance', | ||||
|     'show-module-summary', | ||||
| ] | ||||
| autoapi_ignore = ['*migrations*', '*command.py', '*cli.py'] | ||||
| autoapi_template_dir = '_templates/autoapi' | ||||
| autoapi_add_toctree_entry = False | ||||
|  | ||||
| # Add any paths that contain templates here, relative to this directory. | ||||
| templates_path = ['_templates'] | ||||
|  | ||||
| # The suffix(es) of source filenames. | ||||
| # You can specify multiple suffix as a list of string: | ||||
| # | ||||
| source_suffix = { | ||||
|     '.rst': 'restructuredtext', | ||||
|     '.md': 'markdown', | ||||
| } | ||||
|  | ||||
| # The master toctree document. | ||||
| master_doc = 'index' | ||||
|  | ||||
| # List of patterns, relative to source directory, that match files and | ||||
| # directories to ignore when looking for source files. | ||||
| # This pattern also affects html_static_path and html_extra_path. | ||||
| exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] | ||||
| exclude_patterns = [] | ||||
|  | ||||
| # -- Options for HTML output ------------------------------------------------- | ||||
|  | ||||
| # The theme to use for HTML and HTML Help pages.  See the documentation for | ||||
| # a list of builtin themes. | ||||
| # | ||||
| # html_theme = 'sphinx_rtd_theme' | ||||
| html_theme = 'pytorch_sphinx_theme' | ||||
| html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()] | ||||
| html_theme = 'sphinx_rtd_theme' | ||||
| html_theme_options = { | ||||
|     'menu': [ | ||||
|         { | ||||
|             'name': 'GitHub', | ||||
|             'url': 'https://github.com/InternLM/lagent' | ||||
|         }, | ||||
|     ], | ||||
|     # Specify the language of shared menu | ||||
|     'menu_lang': 'cn', | ||||
|     'navigation_depth': 3, | ||||
|     'titles_only': False, | ||||
|     'style_nav_header_background': '#4fabab', | ||||
| } | ||||
|  | ||||
| language = 'zh_CN' | ||||
| html_context = { | ||||
|     'display_github': True, | ||||
|     'github_host': 'github.com', | ||||
|     'github_user': 'InternLM', | ||||
|     'github_repo': 'lagent', | ||||
|     'github_version': 'main', | ||||
|     'conf_py_path': '/docs/zh_cn/', | ||||
| } | ||||
| html_title = 'Lagent' | ||||
| html_logo = '../imgs/lagent_logo.png' | ||||
| html_favicon = '../imgs/lagent_icon.png' | ||||
|  | ||||
| master_doc = 'index' | ||||
|  | ||||
| # Add any paths that contain custom static files (such as style sheets) here, | ||||
| # relative to this directory. They are copied after the builtin static files, | ||||
| # so a file named "default.css" will overwrite the builtin "default.css". | ||||
| # so a file named 'default.css' will overwrite the builtin 'default.css'. | ||||
| html_static_path = ['_static'] | ||||
| html_css_files = [ | ||||
|     'https://cdn.datatables.net/1.13.2/css/dataTables.bootstrap5.min.css', | ||||
|     'css/readthedocs.css' | ||||
| ] | ||||
| html_js_files = [ | ||||
|     'https://cdn.datatables.net/1.13.2/js/jquery.dataTables.min.js', | ||||
|     'https://cdn.datatables.net/1.13.2/js/dataTables.bootstrap5.min.js', | ||||
|     'js/collapsed.js', | ||||
|     'js/table.js', | ||||
| ] | ||||
|  | ||||
| myst_heading_anchors = 4 | ||||
|  | ||||
| # Configuration for intersphinx | ||||
| intersphinx_mapping = { | ||||
|     'python': ('https://docs.python.org/3', None), | ||||
|     'numpy': ('https://numpy.org/doc/stable', None), | ||||
|     'torch': ('https://pytorch.org/docs/stable/', None), | ||||
| } | ||||
|  | ||||
|  | ||||
| def builder_inited_handler(app): | ||||
|     subprocess.run(['./cp_origin_docs.sh']) | ||||
| def custom_skip(app, what, name, obj, skip, options): | ||||
|     if what in ['data', 'function', 'class'] and re.search('logger', name): | ||||
|         skip = True | ||||
|     return skip | ||||
|  | ||||
|  | ||||
| def setup(app): | ||||
|     app.connect('builder-inited', builder_inited_handler) | ||||
| def setup(sphinx): | ||||
|     sphinx.connect('autoapi-skip-member', custom_skip) | ||||
|   | ||||
							
								
								
									
										19
									
								
								docs/zh_cn/get_started/install.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								docs/zh_cn/get_started/install.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,19 @@ | ||||
| # 安装方式 | ||||
|  | ||||
| ## pip安装 | ||||
|  | ||||
| 推荐使用 pip 安装 | ||||
|  | ||||
| ```bash | ||||
| pip install lagent | ||||
| ``` | ||||
|  | ||||
| ## 源码安装 | ||||
|  | ||||
| 如需修改部分功能,可以从源码构建 Lagent | ||||
|  | ||||
| ```bash | ||||
| git clone https://github.com/InternLM/lagent.git | ||||
| cd lagent | ||||
| pip install -e . | ||||
| ``` | ||||
							
								
								
									
										23
									
								
								docs/zh_cn/get_started/overview.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										23
									
								
								docs/zh_cn/get_started/overview.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,23 @@ | ||||
| # 总览 | ||||
|  | ||||
| 本章节将介绍 Lagent 的架构,并提供 Lagent 详细教程的链接。 | ||||
|  | ||||
| ## Lagent 是什么 | ||||
|  | ||||
| Lagent 是一个开源的 LLM 智能体框架,允许使用者快速将一个大语言模型转换成智能体,并提供一些典型工具来激发大语言模型的潜能。Lagent 框架图如下: | ||||
|  | ||||
|  | ||||
|  | ||||
| Lagent 包含三个主要模块:agents,llms 和 actions。 | ||||
|  | ||||
| - **agents** 实现了多种智能体,如 ReAct,AutoGPT。 | ||||
| - **llms** 支持多种大语言模型,包括在 HuggingFace 上托管的开源模型(Llama-2, InterLM)及 GPT3.5/4 等闭源模型。 | ||||
| - **actions** 包含一系列工具,并提供工具执行器来统一管理。 | ||||
|  | ||||
| ## 如何使用 | ||||
|  | ||||
| 以下是帮助您了解关于 Lagent 更多信息的详细教程: | ||||
|  | ||||
| 1. 安装请参考 [README](https://github.com/InternLM/lagent/blob/main/README.md). | ||||
|  | ||||
| 2. 一些构建智能体的实例 [examples](https://github.com/InternLM/lagent/tree/main/examples),直接运行脚本即可,如 `python examples/react_example.py`. | ||||
							
								
								
									
										87
									
								
								docs/zh_cn/get_started/quickstart.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										87
									
								
								docs/zh_cn/get_started/quickstart.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,87 @@ | ||||
| # 快速上手 | ||||
|  | ||||
| 借助 Lagent 仅需几行代码就能构建大语言模型智能体。 | ||||
|  | ||||
| ## GPT-3.5 驱动的 ReWOO 智能体 | ||||
|  | ||||
| 下面是使用 GPT-3.5 运行 ReWOO 的示例 | ||||
|  | ||||
| ```python | ||||
| # 从 Lagent 导入必要的模块和类 | ||||
| from lagent.agents import ReWOO | ||||
| from lagent.actions import ActionExecutor, GoogleSearch | ||||
| from lagent.llms import GPTAPI | ||||
|  | ||||
| # 初始化 LLM,你可能需要提供 API 密钥 | ||||
| llm = GPTAPI(model_type='gpt-3.5-turbo', key=['Your OPENAI_API_KEY']) | ||||
|  | ||||
| # 初始化 Goolge 搜索工具,你可能需要提供 API 密钥 | ||||
| search_tool = GoogleSearch(api_key='Your SERPER_API_KEY') | ||||
|  | ||||
| # 配置 ReWOO 智能体,创建聊天机器人 | ||||
| chatbot = ReWOO( | ||||
|     llm=llm,  # 大语言模型实例 | ||||
|     action_executor=ActionExecutor( | ||||
|         actions=[search_tool]  # 指定智能体可以调用的工具 | ||||
|     ), | ||||
| ) | ||||
|  | ||||
| # 询问问题并获取回复 | ||||
| response = chatbot.chat('What profession does Nicholas Ray and Elia Kazan have in common') | ||||
|  | ||||
| # 打印回复 | ||||
| print(response.response) | ||||
| ``` | ||||
|  | ||||
| ```python | ||||
| >>> Film director. | ||||
| ``` | ||||
|  | ||||
| ## InterLM 驱动的 ReAct 智能体 | ||||
|  | ||||
| 注意,如果你想使用 HuggingFace 模型,请先运行 `pip install -e .[all]` | ||||
|  | ||||
| ```python | ||||
| # 从 Lagent 导入必要的模块和类 | ||||
| from lagent.agents import ReAct | ||||
| from lagent.actions import ActionExecutor, GoogleSearch, PythonInterpreter | ||||
| from lagent.llms import HFTransformer | ||||
|  | ||||
| from lagent.llms.meta_template import INTERNLM2_META as META | ||||
|  | ||||
| # 初始化 HFTransformer 模型 | ||||
| llm = HFTransformer(path='internlm/internlm2-chat-7b', meta_template=META) | ||||
|  | ||||
| # 初始化 Goolge 搜索工具,你可能需要提供 API 密钥 | ||||
| search_tool = GoogleSearch(api_key='Your SERPER_API_KEY') | ||||
|  | ||||
| # 初始化 Python 代码解释其 | ||||
| python_interpreter = PythonInterpreter() | ||||
|  | ||||
| # 配置 ReAct 智能体,创建聊天机器人 | ||||
| chatbot = ReAct( | ||||
|     llm=llm,  # 大语言模型实例 | ||||
|     action_executor=ActionExecutor( | ||||
|         actions=[search_tool, python_interpreter]),  # 指定智能体可以调用的工具 | ||||
| ) | ||||
| # 询问LaTeX格式的数学问题 | ||||
| response = chatbot.chat('若$z=-1+\sqrt{3}i$,则$\frac{z}{{z\overline{z}-1}}=\left(\ \ \right)$') | ||||
|  | ||||
| # 打印回复 | ||||
| print(response.response) | ||||
| ``` | ||||
|  | ||||
| ```python | ||||
| >>> $-\\frac{1}{3}+\\frac{{\\sqrt{3}}}{3}i$ | ||||
| ``` | ||||
|  | ||||
| ## 启动 ReAct 网页 App | ||||
|  | ||||
| ```python | ||||
| # 你需要先安装 streamlit | ||||
| # pip install streamlit | ||||
| streamlit run examples/react_web_demo.py | ||||
| ``` | ||||
|  | ||||
| 然后你可以通过下图所示UI界面进行对话 | ||||
|  | ||||
| @@ -8,12 +8,32 @@ | ||||
|    :caption: 新手入门 | ||||
|  | ||||
|    get_started/overview.md | ||||
|    get_started/install.md | ||||
|    get_started/quickstart.md | ||||
|  | ||||
| .. toctree:: | ||||
|    :maxdepth: 2 | ||||
|    :caption: 教程 | ||||
|  | ||||
|    tutorials/action.md | ||||
|  | ||||
| .. toctree:: | ||||
|    :caption: 切换语言 | ||||
|  | ||||
|    switch_language.md | ||||
|  | ||||
| .. toctree:: | ||||
|    :maxdepth: 1 | ||||
|    :caption: API 参考 | ||||
|  | ||||
|    autoapi/lagent/actions/index | ||||
|    autoapi/lagent/agents/index | ||||
|    autoapi/lagent/llms/index | ||||
|    autoapi/lagent/utils/index | ||||
|    autoapi/lagent/schema/index | ||||
|    autoapi/lagent/version/index | ||||
|  | ||||
|  | ||||
| 导引 | ||||
| ================== | ||||
|  | ||||
|   | ||||
| @@ -1,3 +1,3 @@ | ||||
| ## <a href='https://lagent.readthedocs.io/en/main/'>English</a> | ||||
| ## <a href='https://lagent.readthedocs.io/en/latest/'>English</a> | ||||
|  | ||||
| ## <a href='https://lagent.readthedocs.io/zh_CN/main/'>简体中文</a> | ||||
| ## <a href='https://lagent.readthedocs.io/zh-cn/latest/'>简体中文</a> | ||||
|   | ||||
							
								
								
									
										398
									
								
								docs/zh_cn/tutorials/action.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										398
									
								
								docs/zh_cn/tutorials/action.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,398 @@ | ||||
| # 动作 | ||||
|  | ||||
| 动作,也被称为工具,提供了一套LLM驱动的智能体用来与真实世界交互并执行复杂任务的函数。 | ||||
|  | ||||
| ## 基本概念 | ||||
|  | ||||
| ### 工具 & 工具包 | ||||
|  | ||||
| 有两种类型的工具: | ||||
|  | ||||
| - 简单工具: 只提供一个API接口供调用。 | ||||
| - 工具包: 实现多个API接口,承担不同的子任务。 | ||||
|  | ||||
| ### 工具描述 | ||||
|  | ||||
| 在Lagent中,工具描述是一个刻画工具调用方式的字典,能够被LLM观察并用于决策。 | ||||
|  | ||||
| 对于简单工具,描述可按如下格式声明: | ||||
|  | ||||
| ```python | ||||
| TOOL_DESCRIPTION = { | ||||
|     'name': 'bold',  # 工具名称 | ||||
|     'description': 'a function used to make text bold',  # 介绍工具的功能 | ||||
|     'parameters': [  # 这个工具所需要的参数列表 | ||||
|         { | ||||
|             'name': 'text', 'type': 'STRING', 'description': 'input content' | ||||
|         } | ||||
|     ], | ||||
|     'required': ['text'],  # 指定必需的参数名 | ||||
| } | ||||
| ``` | ||||
|  | ||||
| 在某些情况下,可能还包含 `return_data`,`parameter_description` 字段,分别描述返回内容及参数传递格式。 | ||||
|  | ||||
| ```{attention} | ||||
| `parameter_description` 通常被动作的解析器自动插入到工具描述中,这部分将在[接口设计](#id6)中进行介绍。 | ||||
| ``` | ||||
|  | ||||
| 对于工具包,描述非常相似,但嵌套了子方法 | ||||
|  | ||||
| ```python | ||||
| TOOL_DESCRIPTION = { | ||||
|     'name': 'PhraseEmphasis',  # 工具包的名字 | ||||
|     'description': 'a toolkit which provides different styles of text emphasis',  # 介绍工具包的功能 | ||||
|     'api_list': [ | ||||
|         { | ||||
|             'name': 'bold', | ||||
|             'description': 'make text bold', | ||||
|             'parameters': [ | ||||
|                 { | ||||
|                     'name': 'text', 'type': 'STRING', 'description': 'input content' | ||||
|                 } | ||||
|             ], | ||||
|             'required': ['text'] | ||||
|         }, | ||||
|         { | ||||
|             'name': 'italic', | ||||
|             'description': 'make text italic', | ||||
|             'parameters': [ | ||||
|                 { | ||||
|                     'name': 'text', 'type': 'STRING', 'description': 'input content' | ||||
|                 } | ||||
|             ], | ||||
|             'required': ['text'] | ||||
|         } | ||||
|     ] | ||||
| } | ||||
| ``` | ||||
|  | ||||
| ## 将函数转换为工具 | ||||
|  | ||||
| 对于已定义好的函数,无需人工添加额外的描述。在 Lagent 中,我们提供了一个修饰器 `tool_api`,它可以通过自动解析函数的类型提示和文档字符串来生成描述字典,并将其绑定到属性 `api_description`。 | ||||
|  | ||||
| ```python | ||||
| from lagent import tool_api | ||||
|  | ||||
| @tool_api | ||||
| def bold(text: str) -> str: | ||||
|     """make text bold | ||||
|  | ||||
|     Args: | ||||
|         text (str): input text | ||||
|  | ||||
|     Returns: | ||||
|         str: bold text | ||||
|     """ | ||||
|     return '**' + text + '**' | ||||
|  | ||||
|  | ||||
| bold.api_description | ||||
| ``` | ||||
|  | ||||
| ```python | ||||
| {'name': 'bold', | ||||
|  'description': 'make text bold', | ||||
|  'parameters': [{'name': 'text', | ||||
|    'type': 'STRING', | ||||
|    'description': 'input text'}], | ||||
|  'required': ['text']} | ||||
| ``` | ||||
|  | ||||
| 一旦启用 `returns_named_value`,您应当声明返回值的名称,这将被处理成一个新的字段 `return_data`: | ||||
|  | ||||
| ```python | ||||
| @tool_api(returns_named_value=True) | ||||
| def bold(text: str) -> str: | ||||
|     """make text bold | ||||
|  | ||||
|     Args: | ||||
|         text (str): input text | ||||
|  | ||||
|     Returns: | ||||
|         bold_text (str): bold text | ||||
|     """ | ||||
|     return '**' + text + '**' | ||||
|  | ||||
| bold.api_description | ||||
| ``` | ||||
|  | ||||
| ```python | ||||
| {'name': 'bold', | ||||
|  'description': 'make text bold', | ||||
|  'parameters': [{'name': 'text', | ||||
|    'type': 'STRING', | ||||
|    'description': 'input text'}], | ||||
|  'required': ['text'], | ||||
|  'return_data': [{'name': 'bold_text', | ||||
|    'description': 'bold text', | ||||
|    'type': 'STRING'}]} | ||||
| ``` | ||||
|  | ||||
| 有时工具可能返回一个 `dict` 或 `tuple`,如果你想在 `return_data` 中详细说明每个成员的含义而不是把它们当作一个整体,设置 `explode_return=True` 并在文档字符串的 Returns 部分中罗列它们。 | ||||
|  | ||||
| ```python | ||||
| @tool_api(explode_return=True) | ||||
| def list_args(a: str, b: int, c: float = 0.0) -> dict: | ||||
|     """Return arguments in dict format | ||||
|  | ||||
|     Args: | ||||
|         a (str): a | ||||
|         b (int): b | ||||
|         c (float): c | ||||
|  | ||||
|     Returns: | ||||
|         dict: input arguments | ||||
|             - a (str): a | ||||
|             - b (int): b | ||||
|             - c: c | ||||
|     """ | ||||
|     return {'a': a, 'b': b, 'c': c} | ||||
| ``` | ||||
|  | ||||
| ```python | ||||
| {'name': 'list_args', | ||||
|  'description': 'Return arguments in dict format', | ||||
|  'parameters': [{'name': 'a', 'type': 'STRING', 'description': 'a'}, | ||||
|   {'name': 'b', 'type': 'NUMBER', 'description': 'b'}, | ||||
|   {'name': 'c', 'type': 'FLOAT', 'description': 'c'}], | ||||
|  'required': ['a', 'b'], | ||||
|  'return_data': [{'name': 'a', 'description': 'a', 'type': 'STRING'}, | ||||
|   {'name': 'b', 'description': 'b', 'type': 'NUMBER'}, | ||||
|   {'name': 'c', 'description': 'c'}]} | ||||
| ``` | ||||
|  | ||||
| ```{warning} | ||||
| 目前仅支持 Google 格式的 Python 文档字符串。 | ||||
| ``` | ||||
|  | ||||
| ## 接口设计 | ||||
|  | ||||
| `BaseAction(description=None, parser=JsonParser, enable=True)` 是所有动作应该继承的基类,它接收三个初始化参数: | ||||
|  | ||||
| - **description**:一个工具描述的字典,用于设置实例属性 `description`。通常不需要显式地传递这个参数,因为 `BaseAction` 的元类将查找被 `tool_api` 装饰的方法,并组装它们的 `api_description` 构造一个类属性 `__tool_description__`,如果实例化时 `description` 为空,那么该实例属性将置为 `__tool_description__`。 | ||||
|  | ||||
| - **parser**:`BaseParser` 类,用于实例化一个动作解析器校验 `description` 所描述的工具的参数。例如,`JsonParser` 会要求模型在调用工具时传入一个 JSON 格式字符串或者 Python 字典,为了让 LLM 感知到该指令,它会在 `description` 中插入一个 `parameter_description` 字段。 | ||||
|  | ||||
|   ```python | ||||
|   from lagent import BaseAction | ||||
|  | ||||
|   action = BaseAction( | ||||
|       { | ||||
|           'name': 'bold', | ||||
|           'description': 'a function used to make text bold', | ||||
|           'parameters': [ | ||||
|               { | ||||
|                   'name': 'text', 'type': 'STRING', 'description': 'input content' | ||||
|               } | ||||
|           ], | ||||
|           'required': ['text'] | ||||
|       } | ||||
|   ) | ||||
|   action.description | ||||
|   ``` | ||||
|  | ||||
|   ```python | ||||
|   {'name': 'bold', | ||||
|    'description': 'a function used to make text bold', | ||||
|    'parameters': [{'name': 'text', | ||||
|    'type': 'STRING', | ||||
|    'description': 'input content'}], | ||||
|    'required': ['text'], | ||||
|    'parameter_description': '如果调用该工具,你必须使用Json格式 {key: value} 传参,其中key为参数名称'} | ||||
|   ``` | ||||
|  | ||||
| - **enable**: 指明该动作是否生效。 | ||||
|  | ||||
| ### 自定义动作 | ||||
|  | ||||
| 一个简单工具必须实现 `run` 方法,而工具包则应当避免将各子API名称定义为该保留字段。 | ||||
|  | ||||
| ```{tip} | ||||
| 对于非工具包的 Action,`run` 允许不被 `tool_api` 装饰,除非你想提示返回信息。 | ||||
| ``` | ||||
|  | ||||
| ```python | ||||
| class Bold(BaseAction): | ||||
|  | ||||
|     def run(self, text: str): | ||||
|         """make text bold | ||||
|  | ||||
|         Args: | ||||
|             text (str): input text | ||||
|  | ||||
|         Returns: | ||||
|             str: bold text | ||||
|         """ | ||||
|         return '**' + text + '**' | ||||
|  | ||||
| class PhraseEmphasis(BaseAction): | ||||
|     """a toolkit which provides different styles of text emphasis""" | ||||
|  | ||||
|     @tool_api | ||||
|     def bold(self, text): | ||||
|         """make text bold | ||||
|  | ||||
|         Args: | ||||
|             text (str): input text | ||||
|  | ||||
|         Returns: | ||||
|             str: bold text | ||||
|         """ | ||||
|         return '**' + text + '**' | ||||
|  | ||||
|     @tool_api | ||||
|     def italic(self, text): | ||||
|         """make text italic | ||||
|  | ||||
|         Args: | ||||
|             text (str): input text | ||||
|  | ||||
|         Returns: | ||||
|             str: italic text | ||||
|         """ | ||||
|         return '*' + text + '*' | ||||
|  | ||||
| # 查看默认工具描述 | ||||
| # Bold.__tool_description__, PhraseEmphasis.__tool_description__ | ||||
| ``` | ||||
|  | ||||
| ### 自动注册 | ||||
|  | ||||
| 任何 `BaseAction` 的子类都会自动被注册。你可以使用 `list_tools()` 和 `get_tool()` 来查看所有工具类并通过工具名进行初始化。 | ||||
|  | ||||
| ```python | ||||
| from lagent import list_tools, get_tool | ||||
|  | ||||
| list_tools() | ||||
| ``` | ||||
|  | ||||
| ```python | ||||
| ['BaseAction', | ||||
|  'InvalidAction', | ||||
|  'NoAction', | ||||
|  'FinishAction', | ||||
|  'ArxivSearch', | ||||
|  'BINGMap', | ||||
|  'GoogleScholar', | ||||
|  'GoogleSearch', | ||||
|  'IPythonInterpreter', | ||||
|  'PPT', | ||||
|  'PythonInterpreter', | ||||
|  'Bold', | ||||
|  'PhraseEmphasis'] | ||||
| ``` | ||||
|  | ||||
| 创建一个 `PhraseEmphasis` 对象。 | ||||
|  | ||||
| ```python | ||||
| action = get_tool('PhraseEmphasis') | ||||
| action.description | ||||
| ``` | ||||
|  | ||||
| ```python | ||||
| {'name': 'PhraseEmphasis', | ||||
|  'description': 'a toolkit which provides different styles of text emphasis', | ||||
|  'api_list': [{'name': 'bold', | ||||
|    'description': 'make text bold', | ||||
|    'parameters': [{'name': 'text', | ||||
|      'type': 'STRING', | ||||
|      'description': 'input text'}], | ||||
|    'required': ['text'], | ||||
|    'parameter_description': '如果调用该工具,你必须使用Json格式 {key: value} 传参,其中key为参数名称'}, | ||||
|   {'name': 'italic', | ||||
|    'description': 'make text italic', | ||||
|    'parameters': [{'name': 'text', | ||||
|      'type': 'STRING', | ||||
|      'description': 'input text'}], | ||||
|    'required': ['text'], | ||||
|    'parameter_description': '如果调用该工具,你必须使用Json格式 {key: value} 传参,其中key为参数名称'}]} | ||||
| ``` | ||||
|  | ||||
| ## 工具调用 | ||||
|  | ||||
| ### 执行工具 | ||||
|  | ||||
| `Action` 的 `__call__` 方法需要传入两个参数 | ||||
|  | ||||
| - `inputs`: 其类型与动作绑定的 `BaseParser` 相关,通常是由大语言模型生成的字符串。 | ||||
|   - `JsonParser`: 允许传入 JSON 格式字符串或 Python 字典。 | ||||
|   - `TupleParser`: 允许传入字面量为元组的字符串或 Python 元组。 | ||||
| - `name`: 调用哪个 API,默认为 `run`。 | ||||
|  | ||||
| 工具会返回一个封装了调用细节的 `ActionReturn` 对象。 | ||||
|  | ||||
| - `args`: 一个字典,表示该动作的入参。 | ||||
| - `type`: 动作名称。 | ||||
| - `result`: 以字典为成员的列表,每个字典包含两个键——'type' 和 'content',发生异常时该字段为 `None`。 | ||||
| - `errmsg`: 错误信息,默认为 `None`。 | ||||
|  | ||||
| 以下是一个例子: | ||||
|  | ||||
| ```python | ||||
| from lagent import IPythonInterpreter, TupleParser | ||||
|  | ||||
| action1 = IPythonInterpreter() | ||||
| ret = action1('{"command": "import math;math.sqrt(100)"}') | ||||
| print(ret.result) | ||||
| ret = action1({'command': 'import math;math.sqrt(100)'}) | ||||
| print(ret.result) | ||||
|  | ||||
| action2 = IPythonInterpreter(parser=TupleParser) | ||||
| ret = action2('("import math;math.sqrt(100)", )') | ||||
| print(ret.result) | ||||
| ret = action2(('import math;math.sqrt(100)',)) | ||||
| print(ret.result) | ||||
| ``` | ||||
|  | ||||
| ```python | ||||
| [{'type': 'text', 'content': '10.0'}] | ||||
| [{'type': 'text', 'content': '10.0'}] | ||||
| [{'type': 'text', 'content': '10.0'}] | ||||
| [{'type': 'text', 'content': '10.0'}] | ||||
| ``` | ||||
|  | ||||
| ### 动态触发 | ||||
|  | ||||
| Lagent 提供 `ActionExecutor` 接口管理多个工具,它会将工具包的 `api_list` 平展并将各 API 更名为 `{tool_name}.{api_name}`。 | ||||
|  | ||||
| ```python | ||||
| from lagent import ActionExecutor, ArxivSearch, IPythonInterpreter | ||||
|  | ||||
| executor = ActionExecutor(actions=[ArxivSearch(), IPythonInterpreter()]) | ||||
| executor.get_actions_info()  # 该结果会作为LLM系统提示词的一部分 | ||||
| ``` | ||||
|  | ||||
| ```python | ||||
| [{'name': 'ArxivSearch.get_arxiv_article_information', | ||||
|   'description': 'Run Arxiv search and get the article meta information.', | ||||
|   'parameters': [{'name': 'query', | ||||
|     'type': 'STRING', | ||||
|     'description': 'the content of search query'}], | ||||
|   'required': ['query'], | ||||
|   'return_data': [{'name': 'content', | ||||
|     'description': 'a list of 3 arxiv search papers', | ||||
|     'type': 'STRING'}], | ||||
|   'parameter_description': '如果调用该工具,你必须使用Json格式 {key: value} 传参,其中key为参数名称'}, | ||||
|  {'name': 'IPythonInterpreter', | ||||
|   'description': "When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files. Internet access for this session is disabled. Do not make external web requests or API calls as they will fail.", | ||||
|   'parameters': [{'name': 'command', | ||||
|     'type': 'STRING', | ||||
|     'description': 'Python code'}, | ||||
|    {'name': 'timeout', | ||||
|     'type': 'NUMBER', | ||||
|     'description': 'Upper bound of waiting time for Python script execution.'}], | ||||
|   'required': ['command'], | ||||
|   'parameter_description': '如果调用该工具,你必须使用Json格式 {key: value} 传参,其中key为参数名称'}] | ||||
| ``` | ||||
|  | ||||
| 通过动作执行器来触发一个工具 | ||||
|  | ||||
| ```python | ||||
| ret = executor('IPythonInterpreter', '{"command": "import math;math.sqrt(100)"}') | ||||
| ret.result | ||||
| ``` | ||||
|  | ||||
| ```python | ||||
| [{'type': 'text', 'content': '10.0'}] | ||||
| ``` | ||||
| @@ -1,46 +0,0 @@ | ||||
| from lagent.actions.action_executor import ActionExecutor | ||||
| from lagent.actions.builtin_actions import FinishAction | ||||
| from lagent.actions.python_interpreter import PythonInterpreter | ||||
| from lagent.agents.autogpt import AutoGPT | ||||
| from lagent.llms.openai import GPTAPI | ||||
|  | ||||
|  | ||||
| def input_prompt(): | ||||
|     print('\ndouble enter to end input >>> ', end='') | ||||
|     sentinel = ''  # ends when this string is seen | ||||
|     return '\n'.join(iter(input, sentinel)) | ||||
|  | ||||
|  | ||||
| def main(): | ||||
|     # set OPEN_API_KEY in your environment or directly pass it with key='' | ||||
|     model = GPTAPI(model_type='gpt-3.5-turbo') | ||||
|  | ||||
|     chatbot = AutoGPT( | ||||
|         llm=model, | ||||
|         action_executor=ActionExecutor( | ||||
|             actions=[ | ||||
|                 PythonInterpreter(), | ||||
|             ], | ||||
|             finish_action=FinishAction( | ||||
|                 description=( | ||||
|                     'Goals are accomplished and there is nothing left ' | ||||
|                     'to do. Parameter: {"response: "final response ' | ||||
|                     'for the goal"}')), | ||||
|             finish_in_action=True), | ||||
|     ) | ||||
|  | ||||
|     while True: | ||||
|         try: | ||||
|             prompt = input_prompt() | ||||
|         except UnicodeDecodeError: | ||||
|             print('UnicodeDecodeError') | ||||
|             continue | ||||
|         if prompt == 'exit': | ||||
|             exit(0) | ||||
|  | ||||
|         agent_return = chatbot.chat(prompt) | ||||
|         print(agent_return.response) | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     main() | ||||
| @@ -1,36 +0,0 @@ | ||||
| from argparse import ArgumentParser | ||||
|  | ||||
| from lagent.llms.openai import GPTAPI | ||||
|  | ||||
|  | ||||
| def parse_args(): | ||||
|     parser = ArgumentParser(description='chatbot') | ||||
|     parser.add_argument('--mode', default='chat') | ||||
|     args = parser.parse_args() | ||||
|     return args | ||||
|  | ||||
|  | ||||
| def main(): | ||||
|     args = parse_args() | ||||
|     # set OPEN_API_KEY in your environment or directly pass it with key='' | ||||
|     model = GPTAPI(model_type='gpt-3.5-turbo') | ||||
|     history = [] | ||||
|     while True: | ||||
|         try: | ||||
|             prompt = input('>>> ') | ||||
|         except UnicodeDecodeError: | ||||
|             print('UnicodeDecodeError') | ||||
|             continue | ||||
|         if prompt == 'exit': | ||||
|             exit(0) | ||||
|         if args.mode == 'chat': | ||||
|             history.append(dict(role='user', content=prompt)) | ||||
|             response = model.generate_from_template(history, max_out_len=512) | ||||
|             history.append(dict(role='assistant', content=response)) | ||||
|         elif args.mode == 'generate': | ||||
|             response = model.generate(prompt, max_out_len=512) | ||||
|         print('Assistant:', response) | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     main() | ||||
| @@ -1,37 +0,0 @@ | ||||
| from lagent.actions.action_executor import ActionExecutor | ||||
| from lagent.actions.python_interpreter import PythonInterpreter | ||||
| from lagent.agents.react import ReAct | ||||
| from lagent.llms.huggingface import HFTransformer | ||||
|  | ||||
| model = HFTransformer( | ||||
|     path='internlm/internlm-chat-7b-v1_1', | ||||
|     meta_template=[ | ||||
|         dict(role='system', begin='<|System|>:', end='<TOKENS_UNUSED_2>\n'), | ||||
|         dict(role='user', begin='<|User|>:', end='<eoh>\n'), | ||||
|         dict(role='assistant', begin='<|Bot|>:', end='<eoa>\n', generate=True) | ||||
|     ], | ||||
| ) | ||||
|  | ||||
| chatbot = ReAct( | ||||
|     llm=model, | ||||
|     action_executor=ActionExecutor(actions=[PythonInterpreter()]), | ||||
| ) | ||||
|  | ||||
|  | ||||
| def input_prompt(): | ||||
|     print('\ndouble enter to end input >>> ', end='') | ||||
|     sentinel = ''  # ends when this string is seen | ||||
|     return '\n'.join(iter(input, sentinel)) | ||||
|  | ||||
|  | ||||
| while True: | ||||
|     try: | ||||
|         prompt = input_prompt() | ||||
|     except UnicodeDecodeError: | ||||
|         print('UnicodeDecodeError') | ||||
|         continue | ||||
|     if prompt == 'exit': | ||||
|         exit(0) | ||||
|  | ||||
|     agent_return = chatbot.chat(prompt) | ||||
|     print(agent_return.response) | ||||
							
								
								
									
										99
									
								
								examples/internlm2_agent_cli_demo.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										99
									
								
								examples/internlm2_agent_cli_demo.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,99 @@ | ||||
| from argparse import ArgumentParser | ||||
|  | ||||
| from lagent.actions import ActionExecutor, ArxivSearch, IPythonInterpreter | ||||
| from lagent.agents.internlm2_agent import INTERPRETER_CN, META_CN, PLUGIN_CN, Internlm2Agent, Internlm2Protocol | ||||
| from lagent.llms import HFTransformer | ||||
| from lagent.llms.meta_template import INTERNLM2_META as META | ||||
| from lagent.schema import AgentStatusCode | ||||
|  | ||||
|  | ||||
| def parse_args(): | ||||
|     parser = ArgumentParser(description='chatbot') | ||||
|     parser.add_argument( | ||||
|         '--path', | ||||
|         type=str, | ||||
|         default='internlm/internlm2-chat-20b', | ||||
|         help='The path to the model') | ||||
|     args = parser.parse_args() | ||||
|     return args | ||||
|  | ||||
|  | ||||
| def main(): | ||||
|     args = parse_args() | ||||
|     # Initialize the HFTransformer-based Language Model (llm) | ||||
|     model = HFTransformer( | ||||
|         path=args.path, | ||||
|         meta_template=META, | ||||
|         max_new_tokens=1024, | ||||
|         top_p=0.8, | ||||
|         top_k=None, | ||||
|         temperature=0.1, | ||||
|         repetition_penalty=1.0, | ||||
|         stop_words=['<|im_end|>']) | ||||
|     plugin_executor = ActionExecutor(actions=[ArxivSearch()])  # noqa: F841 | ||||
|     interpreter_executor = ActionExecutor(actions=[IPythonInterpreter()]) | ||||
|  | ||||
|     chatbot = Internlm2Agent( | ||||
|         llm=model, | ||||
|         plugin_executor=None, | ||||
|         interpreter_executor=interpreter_executor, | ||||
|         protocol=Internlm2Protocol( | ||||
|             meta_prompt=META_CN, | ||||
|             interpreter_prompt=INTERPRETER_CN, | ||||
|             plugin_prompt=PLUGIN_CN, | ||||
|             tool=dict( | ||||
|                 begin='{start_token}{name}\n', | ||||
|                 start_token='<|action_start|>', | ||||
|                 name_map=dict( | ||||
|                     plugin='<|plugin|>', interpreter='<|interpreter|>'), | ||||
|                 belong='assistant', | ||||
|                 end='<|action_end|>\n', | ||||
|             ), | ||||
|         ), | ||||
|     ) | ||||
|  | ||||
|     def input_prompt(): | ||||
|         print('\ndouble enter to end input >>> ', end='', flush=True) | ||||
|         sentinel = ''  # ends when this string is seen | ||||
|         return '\n'.join(iter(input, sentinel)) | ||||
|  | ||||
|     history = [] | ||||
|     while True: | ||||
|         try: | ||||
|             prompt = input_prompt() | ||||
|         except UnicodeDecodeError: | ||||
|             print('UnicodeDecodeError') | ||||
|             continue | ||||
|         if prompt == 'exit': | ||||
|             exit(0) | ||||
|         history.append(dict(role='user', content=prompt)) | ||||
|         print('\nInternLm2:', end='') | ||||
|         current_length = 0 | ||||
|         last_status = None | ||||
|         for agent_return in chatbot.stream_chat(history): | ||||
|             status = agent_return.state | ||||
|             if status not in [ | ||||
|                     AgentStatusCode.STREAM_ING, AgentStatusCode.CODING, | ||||
|                     AgentStatusCode.PLUGIN_START | ||||
|             ]: | ||||
|                 continue | ||||
|             if status != last_status: | ||||
|                 current_length = 0 | ||||
|                 print('') | ||||
|             if isinstance(agent_return.response, dict): | ||||
|                 action = f"\n\n {agent_return.response['name']}: \n\n" | ||||
|                 action_input = agent_return.response['parameters'] | ||||
|                 if agent_return.response['name'] == 'IPythonInterpreter': | ||||
|                     action_input = action_input['command'] | ||||
|                 response = action + action_input | ||||
|             else: | ||||
|                 response = agent_return.response | ||||
|             print(response[current_length:], end='', flush=True) | ||||
|             current_length = len(response) | ||||
|             last_status = status | ||||
|         print('') | ||||
|         history.extend(agent_return.inner_steps) | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     main() | ||||
							
								
								
									
										333
									
								
								examples/internlm2_agent_web_demo.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										333
									
								
								examples/internlm2_agent_web_demo.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,333 @@ | ||||
| import copy | ||||
| import hashlib | ||||
| import json | ||||
| import os | ||||
|  | ||||
| import streamlit as st | ||||
|  | ||||
| from lagent.actions import ActionExecutor, ArxivSearch, IPythonInterpreter | ||||
| from lagent.agents.internlm2_agent import INTERPRETER_CN, META_CN, PLUGIN_CN, Internlm2Agent, Internlm2Protocol | ||||
| from lagent.llms.lmdepoly_wrapper import LMDeployClient | ||||
| from lagent.llms.meta_template import INTERNLM2_META as META | ||||
| from lagent.schema import AgentStatusCode | ||||
|  | ||||
| # from streamlit.logger import get_logger | ||||
|  | ||||
|  | ||||
| class SessionState: | ||||
|  | ||||
|     def init_state(self): | ||||
|         """Initialize session state variables.""" | ||||
|         st.session_state['assistant'] = [] | ||||
|         st.session_state['user'] = [] | ||||
|  | ||||
|         action_list = [ | ||||
|             ArxivSearch(), | ||||
|         ] | ||||
|         st.session_state['plugin_map'] = { | ||||
|             action.name: action | ||||
|             for action in action_list | ||||
|         } | ||||
|         st.session_state['model_map'] = {} | ||||
|         st.session_state['model_selected'] = None | ||||
|         st.session_state['plugin_actions'] = set() | ||||
|         st.session_state['history'] = [] | ||||
|  | ||||
|     def clear_state(self): | ||||
|         """Clear the existing session state.""" | ||||
|         st.session_state['assistant'] = [] | ||||
|         st.session_state['user'] = [] | ||||
|         st.session_state['model_selected'] = None | ||||
|         st.session_state['file'] = set() | ||||
|         if 'chatbot' in st.session_state: | ||||
|             st.session_state['chatbot']._session_history = [] | ||||
|  | ||||
|  | ||||
| class StreamlitUI: | ||||
|  | ||||
|     def __init__(self, session_state: SessionState): | ||||
|         self.init_streamlit() | ||||
|         self.session_state = session_state | ||||
|  | ||||
|     def init_streamlit(self): | ||||
|         """Initialize Streamlit's UI settings.""" | ||||
|         st.set_page_config( | ||||
|             layout='wide', | ||||
|             page_title='lagent-web', | ||||
|             page_icon='./docs/imgs/lagent_icon.png') | ||||
|         st.header(':robot_face: :blue[Lagent] Web Demo ', divider='rainbow') | ||||
|         st.sidebar.title('模型控制') | ||||
|         st.session_state['file'] = set() | ||||
|         st.session_state['ip'] = None | ||||
|  | ||||
|     def setup_sidebar(self): | ||||
|         """Setup the sidebar for model and plugin selection.""" | ||||
|         # model_name = st.sidebar.selectbox('模型选择:', options=['internlm']) | ||||
|         model_name = st.sidebar.text_input('模型名称:', value='internlm2-chat-7b') | ||||
|         meta_prompt = st.sidebar.text_area('系统提示词', value=META_CN) | ||||
|         da_prompt = st.sidebar.text_area('数据分析提示词', value=INTERPRETER_CN) | ||||
|         plugin_prompt = st.sidebar.text_area('插件提示词', value=PLUGIN_CN) | ||||
|         model_ip = st.sidebar.text_input('模型IP:', value='10.140.0.220:23333') | ||||
|         if model_name != st.session_state[ | ||||
|                 'model_selected'] or st.session_state['ip'] != model_ip: | ||||
|             st.session_state['ip'] = model_ip | ||||
|             model = self.init_model(model_name, model_ip) | ||||
|             self.session_state.clear_state() | ||||
|             st.session_state['model_selected'] = model_name | ||||
|             if 'chatbot' in st.session_state: | ||||
|                 del st.session_state['chatbot'] | ||||
|         else: | ||||
|             model = st.session_state['model_map'][model_name] | ||||
|  | ||||
|         plugin_name = st.sidebar.multiselect( | ||||
|             '插件选择', | ||||
|             options=list(st.session_state['plugin_map'].keys()), | ||||
|             default=[], | ||||
|         ) | ||||
|         da_flag = st.sidebar.checkbox( | ||||
|             '数据分析', | ||||
|             value=False, | ||||
|         ) | ||||
|         plugin_action = [ | ||||
|             st.session_state['plugin_map'][name] for name in plugin_name | ||||
|         ] | ||||
|  | ||||
|         if 'chatbot' in st.session_state: | ||||
|             if len(plugin_action) > 0: | ||||
|                 st.session_state['chatbot']._action_executor = ActionExecutor( | ||||
|                     actions=plugin_action) | ||||
|             else: | ||||
|                 st.session_state['chatbot']._action_executor = None | ||||
|             if da_flag: | ||||
|                 st.session_state[ | ||||
|                     'chatbot']._interpreter_executor = ActionExecutor( | ||||
|                         actions=[IPythonInterpreter()]) | ||||
|             else: | ||||
|                 st.session_state['chatbot']._interpreter_executor = None | ||||
|             st.session_state['chatbot']._protocol._meta_template = meta_prompt | ||||
|             st.session_state['chatbot']._protocol.plugin_prompt = plugin_prompt | ||||
|             st.session_state[ | ||||
|                 'chatbot']._protocol.interpreter_prompt = da_prompt | ||||
|         if st.sidebar.button('清空对话', key='clear'): | ||||
|             self.session_state.clear_state() | ||||
|         uploaded_file = st.sidebar.file_uploader('上传文件') | ||||
|  | ||||
|         return model_name, model, plugin_action, uploaded_file, model_ip | ||||
|  | ||||
|     def init_model(self, model_name, ip=None): | ||||
|         """Initialize the model based on the input model name.""" | ||||
|         model_url = f'http://{ip}' | ||||
|         st.session_state['model_map'][model_name] = LMDeployClient( | ||||
|             model_name=model_name, | ||||
|             url=model_url, | ||||
|             meta_template=META, | ||||
|             max_new_tokens=1024, | ||||
|             top_p=0.8, | ||||
|             top_k=100, | ||||
|             temperature=0, | ||||
|             repetition_penalty=1.0, | ||||
|             stop_words=['<|im_end|>']) | ||||
|         return st.session_state['model_map'][model_name] | ||||
|  | ||||
|     def initialize_chatbot(self, model, plugin_action): | ||||
|         """Initialize the chatbot with the given model and plugin actions.""" | ||||
|         return Internlm2Agent( | ||||
|             llm=model, | ||||
|             protocol=Internlm2Protocol( | ||||
|                 tool=dict( | ||||
|                     begin='{start_token}{name}\n', | ||||
|                     start_token='<|action_start|>', | ||||
|                     name_map=dict( | ||||
|                         plugin='<|plugin|>', interpreter='<|interpreter|>'), | ||||
|                     belong='assistant', | ||||
|                     end='<|action_end|>\n', | ||||
|                 ), ), | ||||
|             max_turn=7) | ||||
|  | ||||
|     def render_user(self, prompt: str): | ||||
|         with st.chat_message('user'): | ||||
|             st.markdown(prompt) | ||||
|  | ||||
|     def render_assistant(self, agent_return): | ||||
|         with st.chat_message('assistant'): | ||||
|             for action in agent_return.actions: | ||||
|                 if (action) and (action.type != 'FinishAction'): | ||||
|                     self.render_action(action) | ||||
|             st.markdown(agent_return.response) | ||||
|  | ||||
|     def render_plugin_args(self, action): | ||||
|         action_name = action.type | ||||
|         args = action.args | ||||
|         import json | ||||
|         parameter_dict = dict(name=action_name, parameters=args) | ||||
|         parameter_str = '```json\n' + json.dumps( | ||||
|             parameter_dict, indent=4, ensure_ascii=False) + '\n```' | ||||
|         st.markdown(parameter_str) | ||||
|  | ||||
|     def render_interpreter_args(self, action): | ||||
|         st.info(action.type) | ||||
|         st.markdown(action.args['text']) | ||||
|  | ||||
|     def render_action(self, action): | ||||
|         st.markdown(action.thought) | ||||
|         if action.type == 'IPythonInterpreter': | ||||
|             self.render_interpreter_args(action) | ||||
|         elif action.type == 'FinishAction': | ||||
|             pass | ||||
|         else: | ||||
|             self.render_plugin_args(action) | ||||
|         self.render_action_results(action) | ||||
|  | ||||
|     def render_action_results(self, action): | ||||
|         """Render the results of action, including text, images, videos, and | ||||
|         audios.""" | ||||
|         if (isinstance(action.result, dict)): | ||||
|             if 'text' in action.result: | ||||
|                 st.markdown('```\n' + action.result['text'] + '\n```') | ||||
|             if 'image' in action.result: | ||||
|                 # image_path = action.result['image'] | ||||
|                 for image_path in action.result['image']: | ||||
|                     image_data = open(image_path, 'rb').read() | ||||
|                     st.image(image_data, caption='Generated Image') | ||||
|             if 'video' in action.result: | ||||
|                 video_data = action.result['video'] | ||||
|                 video_data = open(video_data, 'rb').read() | ||||
|                 st.video(video_data) | ||||
|             if 'audio' in action.result: | ||||
|                 audio_data = action.result['audio'] | ||||
|                 audio_data = open(audio_data, 'rb').read() | ||||
|                 st.audio(audio_data) | ||||
|         elif isinstance(action.result, list): | ||||
|             for item in action.result: | ||||
|                 if item['type'] == 'text': | ||||
|                     st.markdown('```\n' + item['content'] + '\n```') | ||||
|                 elif item['type'] == 'image': | ||||
|                     image_data = open(item['content'], 'rb').read() | ||||
|                     st.image(image_data, caption='Generated Image') | ||||
|                 elif item['type'] == 'video': | ||||
|                     video_data = open(item['content'], 'rb').read() | ||||
|                     st.video(video_data) | ||||
|                 elif item['type'] == 'audio': | ||||
|                     audio_data = open(item['content'], 'rb').read() | ||||
|                     st.audio(audio_data) | ||||
|         if action.errmsg: | ||||
|             st.error(action.errmsg) | ||||
|  | ||||
|  | ||||
| def main(): | ||||
|     # logger = get_logger(__name__) | ||||
|     # Initialize Streamlit UI and setup sidebar | ||||
|     if 'ui' not in st.session_state: | ||||
|         session_state = SessionState() | ||||
|         session_state.init_state() | ||||
|         st.session_state['ui'] = StreamlitUI(session_state) | ||||
|  | ||||
|     else: | ||||
|         st.set_page_config( | ||||
|             layout='wide', | ||||
|             page_title='lagent-web', | ||||
|             page_icon='./docs/imgs/lagent_icon.png') | ||||
|         st.header(':robot_face: :blue[Lagent] Web Demo ', divider='rainbow') | ||||
|     _, model, plugin_action, uploaded_file, _ = st.session_state[ | ||||
|         'ui'].setup_sidebar() | ||||
|  | ||||
|     # Initialize chatbot if it is not already initialized | ||||
|     # or if the model has changed | ||||
|     if 'chatbot' not in st.session_state or model != st.session_state[ | ||||
|             'chatbot']._llm: | ||||
|         st.session_state['chatbot'] = st.session_state[ | ||||
|             'ui'].initialize_chatbot(model, plugin_action) | ||||
|         st.session_state['session_history'] = [] | ||||
|  | ||||
|     for prompt, agent_return in zip(st.session_state['user'], | ||||
|                                     st.session_state['assistant']): | ||||
|         st.session_state['ui'].render_user(prompt) | ||||
|         st.session_state['ui'].render_assistant(agent_return) | ||||
|  | ||||
|     if user_input := st.chat_input(''): | ||||
|         with st.container(): | ||||
|             st.session_state['ui'].render_user(user_input) | ||||
|         st.session_state['user'].append(user_input) | ||||
|         # Add file uploader to sidebar | ||||
|         if (uploaded_file | ||||
|                 and uploaded_file.name not in st.session_state['file']): | ||||
|  | ||||
|             st.session_state['file'].add(uploaded_file.name) | ||||
|             file_bytes = uploaded_file.read() | ||||
|             file_type = uploaded_file.type | ||||
|             if 'image' in file_type: | ||||
|                 st.image(file_bytes, caption='Uploaded Image') | ||||
|             elif 'video' in file_type: | ||||
|                 st.video(file_bytes, caption='Uploaded Video') | ||||
|             elif 'audio' in file_type: | ||||
|                 st.audio(file_bytes, caption='Uploaded Audio') | ||||
|             # Save the file to a temporary location and get the path | ||||
|  | ||||
|             postfix = uploaded_file.name.split('.')[-1] | ||||
|             # prefix = str(uuid.uuid4()) | ||||
|             prefix = hashlib.md5(file_bytes).hexdigest() | ||||
|             filename = f'{prefix}.{postfix}' | ||||
|             file_path = os.path.join(root_dir, filename) | ||||
|             with open(file_path, 'wb') as tmpfile: | ||||
|                 tmpfile.write(file_bytes) | ||||
|             file_size = os.stat(file_path).st_size / 1024 / 1024 | ||||
|             file_size = f'{round(file_size, 2)} MB' | ||||
|             # st.write(f'File saved at: {file_path}') | ||||
|             user_input = [ | ||||
|                 dict(role='user', content=user_input), | ||||
|                 dict( | ||||
|                     role='user', | ||||
|                     content=json.dumps(dict(path=file_path, size=file_size)), | ||||
|                     name='file') | ||||
|             ] | ||||
|         if isinstance(user_input, str): | ||||
|             user_input = [dict(role='user', content=user_input)] | ||||
|         st.session_state['last_status'] = AgentStatusCode.SESSION_READY | ||||
|         for agent_return in st.session_state['chatbot'].stream_chat( | ||||
|                 st.session_state['session_history'] + user_input): | ||||
|             if agent_return.state == AgentStatusCode.PLUGIN_RETURN: | ||||
|                 with st.container(): | ||||
|                     st.session_state['ui'].render_plugin_args( | ||||
|                         agent_return.actions[-1]) | ||||
|                     st.session_state['ui'].render_action_results( | ||||
|                         agent_return.actions[-1]) | ||||
|             elif agent_return.state == AgentStatusCode.CODE_RETURN: | ||||
|                 with st.container(): | ||||
|                     st.session_state['ui'].render_action_results( | ||||
|                         agent_return.actions[-1]) | ||||
|             elif (agent_return.state == AgentStatusCode.STREAM_ING | ||||
|                   or agent_return.state == AgentStatusCode.CODING): | ||||
|                 # st.markdown(agent_return.response) | ||||
|                 # 清除占位符的当前内容,并显示新内容 | ||||
|                 with st.container(): | ||||
|                     if agent_return.state != st.session_state['last_status']: | ||||
|                         st.session_state['temp'] = '' | ||||
|                         placeholder = st.empty() | ||||
|                         st.session_state['placeholder'] = placeholder | ||||
|                     if isinstance(agent_return.response, dict): | ||||
|                         action = f"\n\n {agent_return.response['name']}: \n\n" | ||||
|                         action_input = agent_return.response['parameters'] | ||||
|                         if agent_return.response[ | ||||
|                                 'name'] == 'IPythonInterpreter': | ||||
|                             action_input = action_input['command'] | ||||
|                         response = action + action_input | ||||
|                     else: | ||||
|                         response = agent_return.response | ||||
|                     st.session_state['temp'] = response | ||||
|                     st.session_state['placeholder'].markdown( | ||||
|                         st.session_state['temp']) | ||||
|             elif agent_return.state == AgentStatusCode.END: | ||||
|                 st.session_state['session_history'] += ( | ||||
|                     user_input + agent_return.inner_steps) | ||||
|                 agent_return = copy.deepcopy(agent_return) | ||||
|                 agent_return.response = st.session_state['temp'] | ||||
|                 st.session_state['assistant'].append( | ||||
|                     copy.deepcopy(agent_return)) | ||||
|             st.session_state['last_status'] = agent_return.state | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) | ||||
|     root_dir = os.path.join(root_dir, 'tmp_dir') | ||||
|     os.makedirs(root_dir, exist_ok=True) | ||||
|     main() | ||||
							
								
								
									
										332
									
								
								examples/internlm2_agent_web_demo_hf.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										332
									
								
								examples/internlm2_agent_web_demo_hf.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,332 @@ | ||||
| import copy | ||||
| import hashlib | ||||
| import json | ||||
| import os | ||||
|  | ||||
| import streamlit as st | ||||
|  | ||||
| from lagent.actions import ActionExecutor, ArxivSearch, IPythonInterpreter | ||||
| from lagent.agents.internlm2_agent import INTERPRETER_CN, META_CN, PLUGIN_CN, Internlm2Agent, Internlm2Protocol | ||||
| from lagent.llms import HFTransformer | ||||
| from lagent.llms.meta_template import INTERNLM2_META as META | ||||
| from lagent.schema import AgentStatusCode | ||||
|  | ||||
| # from streamlit.logger import get_logger | ||||
|  | ||||
|  | ||||
| class SessionState: | ||||
|  | ||||
|     def init_state(self): | ||||
|         """Initialize session state variables.""" | ||||
|         st.session_state['assistant'] = [] | ||||
|         st.session_state['user'] = [] | ||||
|  | ||||
|         action_list = [ | ||||
|             ArxivSearch(), | ||||
|         ] | ||||
|         st.session_state['plugin_map'] = { | ||||
|             action.name: action | ||||
|             for action in action_list | ||||
|         } | ||||
|         st.session_state['model_map'] = {} | ||||
|         st.session_state['model_selected'] = None | ||||
|         st.session_state['plugin_actions'] = set() | ||||
|         st.session_state['history'] = [] | ||||
|  | ||||
|     def clear_state(self): | ||||
|         """Clear the existing session state.""" | ||||
|         st.session_state['assistant'] = [] | ||||
|         st.session_state['user'] = [] | ||||
|         st.session_state['model_selected'] = None | ||||
|         st.session_state['file'] = set() | ||||
|         if 'chatbot' in st.session_state: | ||||
|             st.session_state['chatbot']._session_history = [] | ||||
|  | ||||
|  | ||||
| class StreamlitUI: | ||||
|  | ||||
|     def __init__(self, session_state: SessionState): | ||||
|         self.init_streamlit() | ||||
|         self.session_state = session_state | ||||
|  | ||||
|     def init_streamlit(self): | ||||
|         """Initialize Streamlit's UI settings.""" | ||||
|         st.set_page_config( | ||||
|             layout='wide', | ||||
|             page_title='lagent-web', | ||||
|             page_icon='./docs/imgs/lagent_icon.png') | ||||
|         st.header(':robot_face: :blue[Lagent] Web Demo ', divider='rainbow') | ||||
|         st.sidebar.title('模型控制') | ||||
|         st.session_state['file'] = set() | ||||
|         st.session_state['model_path'] = None | ||||
|  | ||||
|     def setup_sidebar(self): | ||||
|         """Setup the sidebar for model and plugin selection.""" | ||||
|         # model_name = st.sidebar.selectbox('模型选择:', options=['internlm']) | ||||
|         model_name = st.sidebar.text_input('模型名称:', value='internlm2-chat-7b') | ||||
|         meta_prompt = st.sidebar.text_area('系统提示词', value=META_CN) | ||||
|         da_prompt = st.sidebar.text_area('数据分析提示词', value=INTERPRETER_CN) | ||||
|         plugin_prompt = st.sidebar.text_area('插件提示词', value=PLUGIN_CN) | ||||
|         model_path = st.sidebar.text_input( | ||||
|             '模型路径:', value='internlm/internlm2-chat-20b') | ||||
|         if model_name != st.session_state['model_selected'] or st.session_state[ | ||||
|                 'model_path'] != model_path: | ||||
|             st.session_state['model_path'] = model_path | ||||
|             model = self.init_model(model_name, model_path) | ||||
|             self.session_state.clear_state() | ||||
|             st.session_state['model_selected'] = model_name | ||||
|             if 'chatbot' in st.session_state: | ||||
|                 del st.session_state['chatbot'] | ||||
|         else: | ||||
|             model = st.session_state['model_map'][model_name] | ||||
|  | ||||
|         plugin_name = st.sidebar.multiselect( | ||||
|             '插件选择', | ||||
|             options=list(st.session_state['plugin_map'].keys()), | ||||
|             default=[], | ||||
|         ) | ||||
|         da_flag = st.sidebar.checkbox( | ||||
|             '数据分析', | ||||
|             value=False, | ||||
|         ) | ||||
|         plugin_action = [ | ||||
|             st.session_state['plugin_map'][name] for name in plugin_name | ||||
|         ] | ||||
|  | ||||
|         if 'chatbot' in st.session_state: | ||||
|             if len(plugin_action) > 0: | ||||
|                 st.session_state['chatbot']._action_executor = ActionExecutor( | ||||
|                     actions=plugin_action) | ||||
|             else: | ||||
|                 st.session_state['chatbot']._action_executor = None | ||||
|             if da_flag: | ||||
|                 st.session_state[ | ||||
|                     'chatbot']._interpreter_executor = ActionExecutor( | ||||
|                         actions=[IPythonInterpreter()]) | ||||
|             else: | ||||
|                 st.session_state['chatbot']._interpreter_executor = None | ||||
|             st.session_state['chatbot']._protocol._meta_template = meta_prompt | ||||
|             st.session_state['chatbot']._protocol.plugin_prompt = plugin_prompt | ||||
|             st.session_state[ | ||||
|                 'chatbot']._protocol.interpreter_prompt = da_prompt | ||||
|         if st.sidebar.button('清空对话', key='clear'): | ||||
|             self.session_state.clear_state() | ||||
|         uploaded_file = st.sidebar.file_uploader('上传文件') | ||||
|  | ||||
|         return model_name, model, plugin_action, uploaded_file, model_path | ||||
|  | ||||
|     def init_model(self, model_name, path): | ||||
|         """Initialize the model based on the input model name.""" | ||||
|         st.session_state['model_map'][model_name] = HFTransformer( | ||||
|             path=path, | ||||
|             meta_template=META, | ||||
|             max_new_tokens=1024, | ||||
|             top_p=0.8, | ||||
|             top_k=None, | ||||
|             temperature=0.1, | ||||
|             repetition_penalty=1.0, | ||||
|             stop_words=['<|im_end|>']) | ||||
|         return st.session_state['model_map'][model_name] | ||||
|  | ||||
|     def initialize_chatbot(self, model, plugin_action): | ||||
|         """Initialize the chatbot with the given model and plugin actions.""" | ||||
|         return Internlm2Agent( | ||||
|             llm=model, | ||||
|             protocol=Internlm2Protocol( | ||||
|                 tool=dict( | ||||
|                     begin='{start_token}{name}\n', | ||||
|                     start_token='<|action_start|>', | ||||
|                     name_map=dict( | ||||
|                         plugin='<|plugin|>', interpreter='<|interpreter|>'), | ||||
|                     belong='assistant', | ||||
|                     end='<|action_end|>\n', | ||||
|                 ), ), | ||||
|             max_turn=7) | ||||
|  | ||||
|     def render_user(self, prompt: str): | ||||
|         with st.chat_message('user'): | ||||
|             st.markdown(prompt) | ||||
|  | ||||
|     def render_assistant(self, agent_return): | ||||
|         with st.chat_message('assistant'): | ||||
|             for action in agent_return.actions: | ||||
|                 if (action) and (action.type != 'FinishAction'): | ||||
|                     self.render_action(action) | ||||
|             st.markdown(agent_return.response) | ||||
|  | ||||
|     def render_plugin_args(self, action): | ||||
|         action_name = action.type | ||||
|         args = action.args | ||||
|         import json | ||||
|         parameter_dict = dict(name=action_name, parameters=args) | ||||
|         parameter_str = '```json\n' + json.dumps( | ||||
|             parameter_dict, indent=4, ensure_ascii=False) + '\n```' | ||||
|         st.markdown(parameter_str) | ||||
|  | ||||
|     def render_interpreter_args(self, action): | ||||
|         st.info(action.type) | ||||
|         st.markdown(action.args['text']) | ||||
|  | ||||
|     def render_action(self, action): | ||||
|         st.markdown(action.thought) | ||||
|         if action.type == 'IPythonInterpreter': | ||||
|             self.render_interpreter_args(action) | ||||
|         elif action.type == 'FinishAction': | ||||
|             pass | ||||
|         else: | ||||
|             self.render_plugin_args(action) | ||||
|         self.render_action_results(action) | ||||
|  | ||||
|     def render_action_results(self, action): | ||||
|         """Render the results of action, including text, images, videos, and | ||||
|         audios.""" | ||||
|         if (isinstance(action.result, dict)): | ||||
|             if 'text' in action.result: | ||||
|                 st.markdown('```\n' + action.result['text'] + '\n```') | ||||
|             if 'image' in action.result: | ||||
|                 # image_path = action.result['image'] | ||||
|                 for image_path in action.result['image']: | ||||
|                     image_data = open(image_path, 'rb').read() | ||||
|                     st.image(image_data, caption='Generated Image') | ||||
|             if 'video' in action.result: | ||||
|                 video_data = action.result['video'] | ||||
|                 video_data = open(video_data, 'rb').read() | ||||
|                 st.video(video_data) | ||||
|             if 'audio' in action.result: | ||||
|                 audio_data = action.result['audio'] | ||||
|                 audio_data = open(audio_data, 'rb').read() | ||||
|                 st.audio(audio_data) | ||||
|         elif isinstance(action.result, list): | ||||
|             for item in action.result: | ||||
|                 if item['type'] == 'text': | ||||
|                     st.markdown('```\n' + item['content'] + '\n```') | ||||
|                 elif item['type'] == 'image': | ||||
|                     image_data = open(item['content'], 'rb').read() | ||||
|                     st.image(image_data, caption='Generated Image') | ||||
|                 elif item['type'] == 'video': | ||||
|                     video_data = open(item['content'], 'rb').read() | ||||
|                     st.video(video_data) | ||||
|                 elif item['type'] == 'audio': | ||||
|                     audio_data = open(item['content'], 'rb').read() | ||||
|                     st.audio(audio_data) | ||||
|         if action.errmsg: | ||||
|             st.error(action.errmsg) | ||||
|  | ||||
|  | ||||
| def main(): | ||||
|     # logger = get_logger(__name__) | ||||
|     # Initialize Streamlit UI and setup sidebar | ||||
|     if 'ui' not in st.session_state: | ||||
|         session_state = SessionState() | ||||
|         session_state.init_state() | ||||
|         st.session_state['ui'] = StreamlitUI(session_state) | ||||
|  | ||||
|     else: | ||||
|         st.set_page_config( | ||||
|             layout='wide', | ||||
|             page_title='lagent-web', | ||||
|             page_icon='./docs/imgs/lagent_icon.png') | ||||
|         st.header(':robot_face: :blue[Lagent] Web Demo ', divider='rainbow') | ||||
|     _, model, plugin_action, uploaded_file, _ = st.session_state[ | ||||
|         'ui'].setup_sidebar() | ||||
|  | ||||
|     # Initialize chatbot if it is not already initialized | ||||
|     # or if the model has changed | ||||
|     if 'chatbot' not in st.session_state or model != st.session_state[ | ||||
|             'chatbot']._llm: | ||||
|         st.session_state['chatbot'] = st.session_state[ | ||||
|             'ui'].initialize_chatbot(model, plugin_action) | ||||
|         st.session_state['session_history'] = [] | ||||
|  | ||||
|     for prompt, agent_return in zip(st.session_state['user'], | ||||
|                                     st.session_state['assistant']): | ||||
|         st.session_state['ui'].render_user(prompt) | ||||
|         st.session_state['ui'].render_assistant(agent_return) | ||||
|  | ||||
|     if user_input := st.chat_input(''): | ||||
|         with st.container(): | ||||
|             st.session_state['ui'].render_user(user_input) | ||||
|         st.session_state['user'].append(user_input) | ||||
|         # Add file uploader to sidebar | ||||
|         if (uploaded_file | ||||
|                 and uploaded_file.name not in st.session_state['file']): | ||||
|  | ||||
|             st.session_state['file'].add(uploaded_file.name) | ||||
|             file_bytes = uploaded_file.read() | ||||
|             file_type = uploaded_file.type | ||||
|             if 'image' in file_type: | ||||
|                 st.image(file_bytes, caption='Uploaded Image') | ||||
|             elif 'video' in file_type: | ||||
|                 st.video(file_bytes, caption='Uploaded Video') | ||||
|             elif 'audio' in file_type: | ||||
|                 st.audio(file_bytes, caption='Uploaded Audio') | ||||
|             # Save the file to a temporary location and get the path | ||||
|  | ||||
|             postfix = uploaded_file.name.split('.')[-1] | ||||
|             # prefix = str(uuid.uuid4()) | ||||
|             prefix = hashlib.md5(file_bytes).hexdigest() | ||||
|             filename = f'{prefix}.{postfix}' | ||||
|             file_path = os.path.join(root_dir, filename) | ||||
|             with open(file_path, 'wb') as tmpfile: | ||||
|                 tmpfile.write(file_bytes) | ||||
|             file_size = os.stat(file_path).st_size / 1024 / 1024 | ||||
|             file_size = f'{round(file_size, 2)} MB' | ||||
|             # st.write(f'File saved at: {file_path}') | ||||
|             user_input = [ | ||||
|                 dict(role='user', content=user_input), | ||||
|                 dict( | ||||
|                     role='user', | ||||
|                     content=json.dumps(dict(path=file_path, size=file_size)), | ||||
|                     name='file') | ||||
|             ] | ||||
|         if isinstance(user_input, str): | ||||
|             user_input = [dict(role='user', content=user_input)] | ||||
|         st.session_state['last_status'] = AgentStatusCode.SESSION_READY | ||||
|         for agent_return in st.session_state['chatbot'].stream_chat( | ||||
|                 st.session_state['session_history'] + user_input): | ||||
|             if agent_return.state == AgentStatusCode.PLUGIN_RETURN: | ||||
|                 with st.container(): | ||||
|                     st.session_state['ui'].render_plugin_args( | ||||
|                         agent_return.actions[-1]) | ||||
|                     st.session_state['ui'].render_action_results( | ||||
|                         agent_return.actions[-1]) | ||||
|             elif agent_return.state == AgentStatusCode.CODE_RETURN: | ||||
|                 with st.container(): | ||||
|                     st.session_state['ui'].render_action_results( | ||||
|                         agent_return.actions[-1]) | ||||
|             elif (agent_return.state == AgentStatusCode.STREAM_ING | ||||
|                   or agent_return.state == AgentStatusCode.CODING): | ||||
|                 # st.markdown(agent_return.response) | ||||
|                 # 清除占位符的当前内容,并显示新内容 | ||||
|                 with st.container(): | ||||
|                     if agent_return.state != st.session_state['last_status']: | ||||
|                         st.session_state['temp'] = '' | ||||
|                         placeholder = st.empty() | ||||
|                         st.session_state['placeholder'] = placeholder | ||||
|                     if isinstance(agent_return.response, dict): | ||||
|                         action = f"\n\n {agent_return.response['name']}: \n\n" | ||||
|                         action_input = agent_return.response['parameters'] | ||||
|                         if agent_return.response[ | ||||
|                                 'name'] == 'IPythonInterpreter': | ||||
|                             action_input = action_input['command'] | ||||
|                         response = action + action_input | ||||
|                     else: | ||||
|                         response = agent_return.response | ||||
|                     st.session_state['temp'] = response | ||||
|                     st.session_state['placeholder'].markdown( | ||||
|                         st.session_state['temp']) | ||||
|             elif agent_return.state == AgentStatusCode.END: | ||||
|                 st.session_state['session_history'] += ( | ||||
|                     user_input + agent_return.inner_steps) | ||||
|                 agent_return = copy.deepcopy(agent_return) | ||||
|                 agent_return.response = st.session_state['temp'] | ||||
|                 st.session_state['assistant'].append( | ||||
|                     copy.deepcopy(agent_return)) | ||||
|             st.session_state['last_status'] = agent_return.state | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) | ||||
|     root_dir = os.path.join(root_dir, 'tmp_dir') | ||||
|     os.makedirs(root_dir, exist_ok=True) | ||||
|     main() | ||||
							
								
								
									
										63
									
								
								examples/model_cli_demo.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										63
									
								
								examples/model_cli_demo.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,63 @@ | ||||
| from argparse import ArgumentParser | ||||
|  | ||||
| from lagent.llms import HFTransformer | ||||
| from lagent.llms.meta_template import INTERNLM2_META as META | ||||
|  | ||||
|  | ||||
| def parse_args(): | ||||
|     parser = ArgumentParser(description='chatbot') | ||||
|     parser.add_argument( | ||||
|         '--path', | ||||
|         type=str, | ||||
|         default='internlm/internlm2-chat-20b', | ||||
|         help='The path to the model') | ||||
|     parser.add_argument( | ||||
|         '--mode', | ||||
|         type=str, | ||||
|         default='chat', | ||||
|         help='Completion through chat or generate') | ||||
|     args = parser.parse_args() | ||||
|     return args | ||||
|  | ||||
|  | ||||
| def main(): | ||||
|     args = parse_args() | ||||
|     # Initialize the HFTransformer-based Language Model (llm) | ||||
|     model = HFTransformer( | ||||
|         path=args.path, | ||||
|         meta_template=META, | ||||
|         max_new_tokens=1024, | ||||
|         top_p=0.8, | ||||
|         top_k=None, | ||||
|         temperature=0.1, | ||||
|         repetition_penalty=1.0, | ||||
|         stop_words=['<|im_end|>']) | ||||
|  | ||||
|     def input_prompt(): | ||||
|         print('\ndouble enter to end input >>> ', end='', flush=True) | ||||
|         sentinel = ''  # ends when this string is seen | ||||
|         return '\n'.join(iter(input, sentinel)) | ||||
|  | ||||
|     history = [] | ||||
|     while True: | ||||
|         try: | ||||
|             prompt = input_prompt() | ||||
|         except UnicodeDecodeError: | ||||
|             print('UnicodeDecodeError') | ||||
|             continue | ||||
|         if prompt == 'exit': | ||||
|             exit(0) | ||||
|         history.append(dict(role='user', content=prompt)) | ||||
|         if args.mode == 'generate': | ||||
|             history = [dict(role='user', content=prompt)] | ||||
|         print('\nInternLm2:', end='') | ||||
|         current_length = 0 | ||||
|         for status, response, _ in model.stream_chat(history): | ||||
|             print(response[current_length:], end='', flush=True) | ||||
|             current_length = len(response) | ||||
|         history.append(dict(role='assistant', content=response)) | ||||
|         print('') | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     main() | ||||
| @@ -1,36 +0,0 @@ | ||||
| from lagent.actions.action_executor import ActionExecutor | ||||
| from lagent.actions.python_interpreter import PythonInterpreter | ||||
| from lagent.agents.react import ReAct | ||||
| from lagent.llms.openai import GPTAPI | ||||
|  | ||||
|  | ||||
| def input_prompt(): | ||||
|     print('\ndouble enter to end input >>> ', end='') | ||||
|     sentinel = ''  # ends when this string is seen | ||||
|     return '\n'.join(iter(input, sentinel)) | ||||
|  | ||||
|  | ||||
| def main(): | ||||
|     # set OPEN_API_KEY in your environment or directly pass it with key='' | ||||
|     model = GPTAPI(model_type='gpt-3.5-turbo') | ||||
|  | ||||
|     chatbot = ReAct( | ||||
|         llm=model, | ||||
|         action_executor=ActionExecutor(actions=[PythonInterpreter()]), | ||||
|     ) | ||||
|  | ||||
|     while True: | ||||
|         try: | ||||
|             prompt = input_prompt() | ||||
|         except UnicodeDecodeError: | ||||
|             print('UnicodeDecodeError') | ||||
|             continue | ||||
|         if prompt == 'exit': | ||||
|             exit(0) | ||||
|  | ||||
|         agent_return = chatbot.chat(prompt) | ||||
|         print(agent_return.response) | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     main() | ||||
| @@ -1,215 +0,0 @@ | ||||
| import copy | ||||
| import os | ||||
|  | ||||
| import streamlit as st | ||||
| from streamlit.logger import get_logger | ||||
|  | ||||
| from lagent.actions import ActionExecutor, GoogleSearch, PythonInterpreter | ||||
| from lagent.agents.react import ReAct | ||||
| from lagent.llms import GPTAPI | ||||
| from lagent.llms.huggingface import HFTransformerCasualLM | ||||
|  | ||||
|  | ||||
| class SessionState: | ||||
|  | ||||
|     def init_state(self): | ||||
|         """Initialize session state variables.""" | ||||
|         st.session_state['assistant'] = [] | ||||
|         st.session_state['user'] = [] | ||||
|  | ||||
|         action_list = [PythonInterpreter(), GoogleSearch()] | ||||
|         st.session_state['plugin_map'] = { | ||||
|             action.name: action | ||||
|             for action in action_list | ||||
|         } | ||||
|         st.session_state['model_map'] = {} | ||||
|         st.session_state['model_selected'] = None | ||||
|         st.session_state['plugin_actions'] = set() | ||||
|  | ||||
|     def clear_state(self): | ||||
|         """Clear the existing session state.""" | ||||
|         st.session_state['assistant'] = [] | ||||
|         st.session_state['user'] = [] | ||||
|         st.session_state['model_selected'] = None | ||||
|         if 'chatbot' in st.session_state: | ||||
|             st.session_state['chatbot']._session_history = [] | ||||
|  | ||||
|  | ||||
| class StreamlitUI: | ||||
|  | ||||
|     def __init__(self, session_state: SessionState): | ||||
|         self.init_streamlit() | ||||
|         self.session_state = session_state | ||||
|  | ||||
|     def init_streamlit(self): | ||||
|         """Initialize Streamlit's UI settings.""" | ||||
|         st.set_page_config( | ||||
|             layout='wide', | ||||
|             page_title='lagent-web', | ||||
|             page_icon='./docs/imgs/lagent_icon.png') | ||||
|         st.header(':robot_face: :blue[Lagent] Web Demo ', divider='rainbow') | ||||
|         st.sidebar.title('模型控制') | ||||
|  | ||||
|     def setup_sidebar(self): | ||||
|         """Setup the sidebar for model and plugin selection.""" | ||||
|         model_name = st.sidebar.selectbox( | ||||
|             '模型选择:', options=['gpt-3.5-turbo', 'internlm']) | ||||
|         if model_name != st.session_state['model_selected']: | ||||
|             model = self.init_model(model_name) | ||||
|             self.session_state.clear_state() | ||||
|             st.session_state['model_selected'] = model_name | ||||
|             if 'chatbot' in st.session_state: | ||||
|                 del st.session_state['chatbot'] | ||||
|         else: | ||||
|             model = st.session_state['model_map'][model_name] | ||||
|  | ||||
|         plugin_name = st.sidebar.multiselect( | ||||
|             '插件选择', | ||||
|             options=list(st.session_state['plugin_map'].keys()), | ||||
|             default=[list(st.session_state['plugin_map'].keys())[0]], | ||||
|         ) | ||||
|  | ||||
|         plugin_action = [ | ||||
|             st.session_state['plugin_map'][name] for name in plugin_name | ||||
|         ] | ||||
|         if 'chatbot' in st.session_state: | ||||
|             st.session_state['chatbot']._action_executor = ActionExecutor( | ||||
|                 actions=plugin_action) | ||||
|         if st.sidebar.button('清空对话', key='clear'): | ||||
|             self.session_state.clear_state() | ||||
|         uploaded_file = st.sidebar.file_uploader( | ||||
|             '上传文件', type=['png', 'jpg', 'jpeg', 'mp4', 'mp3', 'wav']) | ||||
|         return model_name, model, plugin_action, uploaded_file | ||||
|  | ||||
|     def init_model(self, option): | ||||
|         """Initialize the model based on the selected option.""" | ||||
|         if option not in st.session_state['model_map']: | ||||
|             if option.startswith('gpt'): | ||||
|                 st.session_state['model_map'][option] = GPTAPI( | ||||
|                     model_type=option) | ||||
|             else: | ||||
|                 st.session_state['model_map'][option] = HFTransformerCasualLM( | ||||
|                     'internlm/internlm-chat-7b-v1_1') | ||||
|         return st.session_state['model_map'][option] | ||||
|  | ||||
|     def initialize_chatbot(self, model, plugin_action): | ||||
|         """Initialize the chatbot with the given model and plugin actions.""" | ||||
|         return ReAct( | ||||
|             llm=model, action_executor=ActionExecutor(actions=plugin_action)) | ||||
|  | ||||
|     def render_user(self, prompt: str): | ||||
|         with st.chat_message('user'): | ||||
|             st.markdown(prompt) | ||||
|  | ||||
|     def render_assistant(self, agent_return): | ||||
|         with st.chat_message('assistant'): | ||||
|             for action in agent_return.actions: | ||||
|                 if (action): | ||||
|                     self.render_action(action) | ||||
|             st.markdown(agent_return.response) | ||||
|  | ||||
|     def render_action(self, action): | ||||
|         with st.expander(action.type, expanded=True): | ||||
|             st.markdown( | ||||
|                 "<p style='text-align: left;display:flex;'> <span style='font-size:14px;font-weight:600;width:70px;text-align-last: justify;'>插    件</span><span style='width:14px;text-align:left;display:block;'>:</span><span style='flex:1;'>"  # noqa E501 | ||||
|                 + action.type + '</span></p>', | ||||
|                 unsafe_allow_html=True) | ||||
|             st.markdown( | ||||
|                 "<p style='text-align: left;display:flex;'> <span style='font-size:14px;font-weight:600;width:70px;text-align-last: justify;'>思考步骤</span><span style='width:14px;text-align:left;display:block;'>:</span><span style='flex:1;'>"  # noqa E501 | ||||
|                 + action.thought + '</span></p>', | ||||
|                 unsafe_allow_html=True) | ||||
|             if (isinstance(action.args, dict) and 'text' in action.args): | ||||
|                 st.markdown( | ||||
|                     "<p style='text-align: left;display:flex;'><span style='font-size:14px;font-weight:600;width:70px;text-align-last: justify;'> 执行内容</span><span style='width:14px;text-align:left;display:block;'>:</span></p>",  # noqa E501 | ||||
|                     unsafe_allow_html=True) | ||||
|                 st.markdown(action.args['text']) | ||||
|             self.render_action_results(action) | ||||
|  | ||||
|     def render_action_results(self, action): | ||||
|         """Render the results of action, including text, images, videos, and | ||||
|         audios.""" | ||||
|         if (isinstance(action.result, dict)): | ||||
|             st.markdown( | ||||
|                 "<p style='text-align: left;display:flex;'><span style='font-size:14px;font-weight:600;width:70px;text-align-last: justify;'> 执行结果</span><span style='width:14px;text-align:left;display:block;'>:</span></p>",  # noqa E501 | ||||
|                 unsafe_allow_html=True) | ||||
|             if 'text' in action.result: | ||||
|                 st.markdown( | ||||
|                     "<p style='text-align: left;'>" + action.result['text'] + | ||||
|                     '</p>', | ||||
|                     unsafe_allow_html=True) | ||||
|             if 'image' in action.result: | ||||
|                 image_path = action.result['image'] | ||||
|                 image_data = open(image_path, 'rb').read() | ||||
|                 st.image(image_data, caption='Generated Image') | ||||
|             if 'video' in action.result: | ||||
|                 video_data = action.result['video'] | ||||
|                 video_data = open(video_data, 'rb').read() | ||||
|                 st.video(video_data) | ||||
|             if 'audio' in action.result: | ||||
|                 audio_data = action.result['audio'] | ||||
|                 audio_data = open(audio_data, 'rb').read() | ||||
|                 st.audio(audio_data) | ||||
|  | ||||
|  | ||||
| def main(): | ||||
|     logger = get_logger(__name__) | ||||
|     # Initialize Streamlit UI and setup sidebar | ||||
|     if 'ui' not in st.session_state: | ||||
|         session_state = SessionState() | ||||
|         session_state.init_state() | ||||
|         st.session_state['ui'] = StreamlitUI(session_state) | ||||
|  | ||||
|     else: | ||||
|         st.set_page_config( | ||||
|             layout='wide', | ||||
|             page_title='lagent-web', | ||||
|             page_icon='./docs/imgs/lagent_icon.png') | ||||
|         st.header(':robot_face: :blue[Lagent] Web Demo ', divider='rainbow') | ||||
|     model_name, model, plugin_action, uploaded_file = st.session_state[ | ||||
|         'ui'].setup_sidebar() | ||||
|  | ||||
|     # Initialize chatbot if it is not already initialized | ||||
|     # or if the model has changed | ||||
|     if 'chatbot' not in st.session_state or model != st.session_state[ | ||||
|             'chatbot']._llm: | ||||
|         st.session_state['chatbot'] = st.session_state[ | ||||
|             'ui'].initialize_chatbot(model, plugin_action) | ||||
|  | ||||
|     for prompt, agent_return in zip(st.session_state['user'], | ||||
|                                     st.session_state['assistant']): | ||||
|         st.session_state['ui'].render_user(prompt) | ||||
|         st.session_state['ui'].render_assistant(agent_return) | ||||
|     # User input form at the bottom (this part will be at the bottom) | ||||
|     # with st.form(key='my_form', clear_on_submit=True): | ||||
|  | ||||
|     if user_input := st.chat_input(''): | ||||
|         st.session_state['ui'].render_user(user_input) | ||||
|         st.session_state['user'].append(user_input) | ||||
|         # Add file uploader to sidebar | ||||
|         if uploaded_file: | ||||
|             file_bytes = uploaded_file.read() | ||||
|             file_type = uploaded_file.type | ||||
|             if 'image' in file_type: | ||||
|                 st.image(file_bytes, caption='Uploaded Image') | ||||
|             elif 'video' in file_type: | ||||
|                 st.video(file_bytes, caption='Uploaded Video') | ||||
|             elif 'audio' in file_type: | ||||
|                 st.audio(file_bytes, caption='Uploaded Audio') | ||||
|             # Save the file to a temporary location and get the path | ||||
|             file_path = os.path.join(root_dir, uploaded_file.name) | ||||
|             with open(file_path, 'wb') as tmpfile: | ||||
|                 tmpfile.write(file_bytes) | ||||
|             st.write(f'File saved at: {file_path}') | ||||
|             user_input = '我上传了一个图像,路径为: {file_path}. {user_input}'.format( | ||||
|                 file_path=file_path, user_input=user_input) | ||||
|         agent_return = st.session_state['chatbot'].chat(user_input) | ||||
|         st.session_state['assistant'].append(copy.deepcopy(agent_return)) | ||||
|         logger.info(agent_return.inner_steps) | ||||
|         st.session_state['ui'].render_assistant(agent_return) | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) | ||||
|     root_dir = os.path.join(root_dir, 'tmp_dir') | ||||
|     os.makedirs(root_dir, exist_ok=True) | ||||
|     main() | ||||
| @@ -1,19 +0,0 @@ | ||||
| from lagent.actions import LLMQA, ActionExecutor, GoogleSearch | ||||
| from lagent.agents import ReWOO | ||||
| from lagent.llms.openai import GPTAPI | ||||
|  | ||||
| # set OPEN_API_KEY in your environment or directly pass it with key='' | ||||
| model = GPTAPI(model_type='gpt-3.5-turbo') | ||||
| # please set the serper search API key | ||||
| search_tool = GoogleSearch(api_key='SERPER_API_KEY') | ||||
| llmqa_tool = LLMQA(model) | ||||
|  | ||||
| chatbot = ReWOO( | ||||
|     llm=model, | ||||
|     action_executor=ActionExecutor(actions=[llmqa_tool, search_tool]), | ||||
| ) | ||||
|  | ||||
| prompt = 'What profession does Nicholas Ray and Elia Kazan have in common' | ||||
|  | ||||
| agent_return = chatbot.chat(prompt) | ||||
| print(agent_return.response) | ||||
| @@ -1,47 +0,0 @@ | ||||
| import argparse | ||||
|  | ||||
| from lagent.actions.action_executor import ActionExecutor | ||||
| from lagent.actions.python_interpreter import PythonInterpreter | ||||
| from lagent.agents.react import ReAct | ||||
| from lagent.llms.lmdeploy import TurboMind | ||||
|  | ||||
| def parse_args(): | ||||
|     parser = argparse.ArgumentParser() | ||||
|     parser.add_argument('--path', type=str, help="The path to the model") | ||||
|     args = parser.parse_args() | ||||
|     return args | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     args = parse_args() | ||||
|     model = TurboMind( | ||||
|         path=args.path, | ||||
|         meta_template=[ | ||||
|             dict(role='system', begin='<|System|>:', end='<TOKENS_UNUSED_2>\n'), | ||||
|             dict(role='user', begin='<|User|>:', end='<eoh>\n'), | ||||
|             dict(role='assistant', begin='<|Bot|>:', end='<eoa>\n', generate=True) | ||||
|         ], | ||||
|     ) | ||||
|  | ||||
|     chatbot = ReAct( | ||||
|         llm=model, | ||||
|         action_executor=ActionExecutor(actions=[PythonInterpreter()]), | ||||
|     ) | ||||
|  | ||||
|  | ||||
|     def input_prompt(): | ||||
|         print('\ndouble enter to end input >>> ', end='') | ||||
|         sentinel = ''  # ends when this string is seen | ||||
|         return '\n'.join(iter(input, sentinel)) | ||||
|  | ||||
|  | ||||
|     while True: | ||||
|         try: | ||||
|             prompt = input_prompt() | ||||
|         except UnicodeDecodeError: | ||||
|             print('UnicodeDecodeError') | ||||
|             continue | ||||
|         if prompt == 'exit': | ||||
|             exit(0) | ||||
|  | ||||
|         agent_return = chatbot.chat(prompt) | ||||
|         print(agent_return.response) | ||||
| @@ -1,11 +1,62 @@ | ||||
| from typing import Type | ||||
|  | ||||
| from .action_executor import ActionExecutor | ||||
| from .base_action import BaseAction | ||||
| from .arxiv_search import ArxivSearch | ||||
| from .base_action import TOOL_REGISTRY, BaseAction, tool_api | ||||
| from .bing_map import BINGMap | ||||
| from .builtin_actions import FinishAction, InvalidAction, NoAction | ||||
| from .google_scholar_search import GoogleScholar | ||||
| from .google_search import GoogleSearch | ||||
| from .llm_qa import LLMQA | ||||
| from .ipython_interactive import IPythonInteractive | ||||
| from .ipython_interpreter import IPythonInterpreter | ||||
| from .parser import BaseParser, JsonParser, TupleParser | ||||
| from .ppt import PPT | ||||
| from .python_interpreter import PythonInterpreter | ||||
|  | ||||
| __all__ = [ | ||||
|     'BaseAction', 'ActionExecutor', 'InvalidAction', 'NoAction', | ||||
|     'FinishAction', 'GoogleSearch', 'PythonInterpreter', 'LLMQA' | ||||
|     'BaseAction', 'ActionExecutor', 'InvalidAction', 'FinishAction', | ||||
|     'NoAction', 'BINGMap', 'ArxivSearch', 'FinishAction', 'GoogleSearch', | ||||
|     'GoogleScholar', 'IPythonInterpreter', 'IPythonInteractive', | ||||
|     'PythonInterpreter', 'PPT', 'BaseParser', 'JsonParser', 'TupleParser', | ||||
|     'tool_api', 'list_tools', 'get_tool_cls', 'get_tool' | ||||
| ] | ||||
|  | ||||
|  | ||||
| def list_tools(with_class: bool = False): | ||||
|     """List available tools. | ||||
|  | ||||
|     Args: | ||||
|         with_class (bool): whether to return the action class along | ||||
|             with its name. Defaults to ``False``. | ||||
|  | ||||
|     Returns: | ||||
|         list: all action names | ||||
|     """ | ||||
|     return list(TOOL_REGISTRY.items()) if with_class else list( | ||||
|         TOOL_REGISTRY.keys()) | ||||
|  | ||||
|  | ||||
| def get_tool_cls(specifier: str) -> Type[BaseAction]: | ||||
|     """Get the action class. | ||||
|  | ||||
|     Args: | ||||
|         specifier (:class:`str`): tool name | ||||
|  | ||||
|     Returns: | ||||
|         Type[BaseAction]: action class | ||||
|     """ | ||||
|     return TOOL_REGISTRY.get_class(specifier) | ||||
|  | ||||
|  | ||||
| def get_tool(specifier: str, *args, **kwargs) -> BaseAction: | ||||
|     """Instantiate an action. | ||||
|  | ||||
|     Args: | ||||
|         specifier (str): tool name | ||||
|         args: positional arguments passed to the action's ``__init__`` method | ||||
|         kwargs: keyword arguments passed to the action's ``__init__`` method | ||||
|  | ||||
|     Returns: | ||||
|         :class:`BaseAction`: action object | ||||
|     """ | ||||
|     return TOOL_REGISTRY.get(specifier, *args, **kwargs) | ||||
|   | ||||
| @@ -1,4 +1,4 @@ | ||||
| from typing import Any, Dict, List, Union | ||||
| from typing import Dict, List, Union | ||||
|  | ||||
| from lagent.schema import ActionReturn, ActionValidCode | ||||
| from .base_action import BaseAction | ||||
| @@ -39,14 +39,20 @@ class ActionExecutor: | ||||
|         self.no_action = no_action | ||||
|         self.finish_action = finish_action | ||||
|  | ||||
|     def get_actions_info(self, only_enable: bool = True) -> Dict: | ||||
|         if only_enable: | ||||
|             return { | ||||
|                 k: v.description | ||||
|                 for k, v in self.actions.items() if v.enable | ||||
|             } | ||||
|         else: | ||||
|             return {k: v.description for k, v in self.actions.items()} | ||||
|     def get_actions_info(self) -> List[Dict]: | ||||
|         actions = [] | ||||
|         for action_name, action in self.actions.items(): | ||||
|             if not action.enable: | ||||
|                 continue | ||||
|             if action.is_toolkit: | ||||
|                 for api in action.description['api_list']: | ||||
|                     api_desc = api.copy() | ||||
|                     api_desc['name'] = f"{action_name}.{api_desc['name']}" | ||||
|                     actions.append(api_desc) | ||||
|             else: | ||||
|                 action_desc = action.description.copy() | ||||
|                 actions.append(action_desc) | ||||
|         return actions | ||||
|  | ||||
|     def is_valid(self, name: str): | ||||
|         return name in self.actions and self.actions[name].enable | ||||
| @@ -66,19 +72,17 @@ class ActionExecutor: | ||||
|         if name in self.actions: | ||||
|             del self.actions[name] | ||||
|  | ||||
|     def __call__(self, name: str, command: Any) -> ActionReturn: | ||||
|         if isinstance(command, str): | ||||
|             args, kwargs = (command, ), {} | ||||
|         else: | ||||
|             args, kwargs = (), command | ||||
|         if not self.is_valid(name): | ||||
|     def __call__(self, name: str, command: str) -> ActionReturn: | ||||
|         action_name, api_name = ( | ||||
|             name.split('.') if '.' in name else (name, 'run')) | ||||
|         if not self.is_valid(action_name): | ||||
|             if name == self.no_action.name: | ||||
|                 action_return = self.no_action.run(*args, **kwargs) | ||||
|                 action_return = self.no_action(command) | ||||
|             elif name == self.finish_action.name: | ||||
|                 action_return = self.finish_action.run(*args, **kwargs) | ||||
|                 action_return = self.finish_action(command) | ||||
|             else: | ||||
|                 action_return = self.invalid_action(*args, **kwargs) | ||||
|                 action_return = self.invalid_action(command) | ||||
|         else: | ||||
|             action_return = self.actions[name].run(*args, **kwargs) | ||||
|             action_return = self.actions[action_name](command, api_name) | ||||
|             action_return.valid = ActionValidCode.OPEN | ||||
|         return action_return | ||||
|   | ||||
							
								
								
									
										56
									
								
								lagent/actions/arxiv_search.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										56
									
								
								lagent/actions/arxiv_search.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,56 @@ | ||||
| from typing import Optional, Type | ||||
|  | ||||
| from lagent.actions.base_action import BaseAction, tool_api | ||||
| from lagent.actions.parser import BaseParser, JsonParser | ||||
| from lagent.schema import ActionReturn, ActionStatusCode | ||||
|  | ||||
|  | ||||
| class ArxivSearch(BaseAction): | ||||
|     """Search information from Arxiv.org. \ | ||||
| Useful for when you need to answer questions about Physics, Mathematics, \ | ||||
| Computer Science, Quantitative Biology, Quantitative Finance, Statistics, \ | ||||
| Electrical Engineering, and Economics from scientific articles on arxiv.org. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, | ||||
|                  top_k_results: int = 3, | ||||
|                  max_query_len: int = 300, | ||||
|                  doc_content_chars_max: int = 1500, | ||||
|                  description: Optional[dict] = None, | ||||
|                  parser: Type[BaseParser] = JsonParser, | ||||
|                  enable: bool = True): | ||||
|         super().__init__(description, parser, enable) | ||||
|         self.top_k_results = top_k_results | ||||
|         self.max_query_len = max_query_len | ||||
|         self.doc_content_chars_max = doc_content_chars_max | ||||
|  | ||||
|     @tool_api(explode_return=True) | ||||
|     def get_arxiv_article_information(self, query: str) -> dict: | ||||
|         """Run Arxiv search and get the article meta information. | ||||
|  | ||||
|         Args: | ||||
|             query (:class:`str`): the content of search query | ||||
|  | ||||
|         Returns: | ||||
|             :class:`dict`: article information | ||||
|                 * content (str): a list of 3 arxiv search papers | ||||
|         """ | ||||
|         import arxiv | ||||
|  | ||||
|         try: | ||||
|             results = arxiv.Search(  # type: ignore | ||||
|                 query[:self.max_query_len], | ||||
|                 max_results=self.top_k_results).results() | ||||
|         except Exception as exc: | ||||
|             return ActionReturn( | ||||
|                 errmsg=f'Arxiv exception: {exc}', | ||||
|                 state=ActionStatusCode.HTTP_ERROR) | ||||
|         docs = [ | ||||
|             f'Published: {result.updated.date()}\nTitle: {result.title}\n' | ||||
|             f'Authors: {", ".join(a.name for a in result.authors)}\n' | ||||
|             f'Summary: {result.summary[:self.doc_content_chars_max]}' | ||||
|             for result in results | ||||
|         ] | ||||
|         if docs: | ||||
|             return {'content': '\n\n'.join(docs)} | ||||
|         return {'content': 'No good Arxiv Result was found'} | ||||
| @@ -1,57 +1,385 @@ | ||||
| from typing import Optional | ||||
| import inspect | ||||
| import logging | ||||
| import re | ||||
| from abc import ABCMeta | ||||
| from copy import deepcopy | ||||
| from functools import wraps | ||||
| from typing import Callable, Optional, Type, get_args, get_origin | ||||
|  | ||||
| from lagent.schema import ActionReturn | ||||
| try: | ||||
|     from typing import Annotated | ||||
| except ImportError: | ||||
|     from typing_extensions import Annotated | ||||
|  | ||||
| from class_registry import AutoRegister, ClassRegistry | ||||
| from griffe import Docstring | ||||
| from griffe.enumerations import DocstringSectionKind | ||||
|  | ||||
| from ..schema import ActionReturn, ActionStatusCode | ||||
| from .parser import BaseParser, JsonParser, ParseError | ||||
|  | ||||
| logging.getLogger('griffe').setLevel(logging.ERROR) | ||||
|  | ||||
| TOOL_REGISTRY = ClassRegistry('__tool_name__', unique=True) | ||||
|  | ||||
|  | ||||
| class BaseAction: | ||||
| def tool_api(func: Optional[Callable] = None, | ||||
|              *, | ||||
|              explode_return: bool = False, | ||||
|              returns_named_value: bool = False, | ||||
|              **kwargs): | ||||
|     """Turn functions into tools. It will parse typehints as well as docstrings | ||||
|     to build the tool description and attach it to functions via an attribute | ||||
|     ``api_description``. | ||||
|  | ||||
|     Examples: | ||||
|  | ||||
|         .. code-block:: python | ||||
|  | ||||
|             # typehints has higher priority than docstrings | ||||
|             from typing import Annotated | ||||
|  | ||||
|             @tool_api | ||||
|             def add(a: Annotated[int, 'augend'], b: Annotated[int, 'addend'] = 1): | ||||
|                 '''Add operation | ||||
|  | ||||
|                 Args: | ||||
|                     x (int): a | ||||
|                     y (int): b | ||||
|                 ''' | ||||
|                 return a + b | ||||
|  | ||||
|             print(add.api_description) | ||||
|  | ||||
|     Args: | ||||
|         func (Optional[Callable]): function to decorate. Defaults to ``None``. | ||||
|         explode_return (bool): whether to flatten the dictionary or tuple return | ||||
|             as the ``return_data`` field. When enabled, it is recommended to | ||||
|             annotate the member in docstrings. Defaults to ``False``. | ||||
|  | ||||
|             .. code-block:: python | ||||
|  | ||||
|                 @tool_api(explode_return=True) | ||||
|                 def foo(a, b): | ||||
|                     '''A simple function | ||||
|  | ||||
|                     Args: | ||||
|                         a (int): a | ||||
|                         b (int): b | ||||
|  | ||||
|                     Returns: | ||||
|                         dict: information of inputs | ||||
|                             * x: value of a | ||||
|                             * y: value of b | ||||
|                     ''' | ||||
|                     return {'x': a, 'y': b} | ||||
|  | ||||
|                 print(foo.api_description) | ||||
|  | ||||
|         returns_named_value (bool): whether to parse ``thing: Description`` in | ||||
|             returns sections as a name and description, rather than a type and | ||||
|             description. When true, type must be wrapped in parentheses: | ||||
|             ``(int): Description``. When false, parentheses are optional but | ||||
|             the items cannot be named: ``int: Description``. Defaults to ``False``. | ||||
|  | ||||
|     Returns: | ||||
|         Callable: wrapped function or partial decorator | ||||
|  | ||||
|     Important: | ||||
|         ``return_data`` field will be added to ``api_description`` only | ||||
|         when ``explode_return`` or ``returns_named_value`` is enabled. | ||||
|     """ | ||||
|  | ||||
|     def _detect_type(string): | ||||
|         field_type = 'STRING' | ||||
|         if 'list' in string: | ||||
|             field_type = 'Array' | ||||
|         elif 'str' not in string: | ||||
|             if 'float' in string: | ||||
|                 field_type = 'FLOAT' | ||||
|             elif 'int' in string: | ||||
|                 field_type = 'NUMBER' | ||||
|             elif 'bool' in string: | ||||
|                 field_type = 'BOOLEAN' | ||||
|         return field_type | ||||
|  | ||||
|     def _explode(desc): | ||||
|         kvs = [] | ||||
|         desc = '\nArgs:\n' + '\n'.join([ | ||||
|             '    ' + item.lstrip(' -+*#.') | ||||
|             for item in desc.split('\n')[1:] if item.strip() | ||||
|         ]) | ||||
|         docs = Docstring(desc).parse('google') | ||||
|         if not docs: | ||||
|             return kvs | ||||
|         if docs[0].kind is DocstringSectionKind.parameters: | ||||
|             for d in docs[0].value: | ||||
|                 d = d.as_dict() | ||||
|                 if not d['annotation']: | ||||
|                     d.pop('annotation') | ||||
|                 else: | ||||
|                     d['type'] = _detect_type(d.pop('annotation').lower()) | ||||
|                 kvs.append(d) | ||||
|         return kvs | ||||
|  | ||||
|     def _parse_tool(function): | ||||
|         # remove rst syntax | ||||
|         docs = Docstring( | ||||
|             re.sub(':(.+?):`(.+?)`', '\\2', function.__doc__ or '')).parse( | ||||
|                 'google', returns_named_value=returns_named_value, **kwargs) | ||||
|         desc = dict( | ||||
|             name=function.__name__, | ||||
|             description=docs[0].value | ||||
|             if docs[0].kind is DocstringSectionKind.text else '', | ||||
|             parameters=[], | ||||
|             required=[], | ||||
|         ) | ||||
|         args_doc, returns_doc = {}, [] | ||||
|         for doc in docs: | ||||
|             if doc.kind is DocstringSectionKind.parameters: | ||||
|                 for d in doc.value: | ||||
|                     d = d.as_dict() | ||||
|                     d['type'] = _detect_type(d.pop('annotation').lower()) | ||||
|                     args_doc[d['name']] = d | ||||
|             if doc.kind is DocstringSectionKind.returns: | ||||
|                 for d in doc.value: | ||||
|                     d = d.as_dict() | ||||
|                     if not d['name']: | ||||
|                         d.pop('name') | ||||
|                     if not d['annotation']: | ||||
|                         d.pop('annotation') | ||||
|                     else: | ||||
|                         d['type'] = _detect_type(d.pop('annotation').lower()) | ||||
|                     returns_doc.append(d) | ||||
|  | ||||
|         sig = inspect.signature(function) | ||||
|         for name, param in sig.parameters.items(): | ||||
|             if name == 'self': | ||||
|                 continue | ||||
|             parameter = dict( | ||||
|                 name=param.name, | ||||
|                 type='STRING', | ||||
|                 description=args_doc.get(param.name, | ||||
|                                          {}).get('description', '')) | ||||
|             annotation = param.annotation | ||||
|             if annotation is inspect.Signature.empty: | ||||
|                 parameter['type'] = args_doc.get(param.name, | ||||
|                                                  {}).get('type', 'STRING') | ||||
|             else: | ||||
|                 if get_origin(annotation) is Annotated: | ||||
|                     annotation, info = get_args(annotation) | ||||
|                     if info: | ||||
|                         parameter['description'] = info | ||||
|                 while get_origin(annotation): | ||||
|                     annotation = get_args(annotation) | ||||
|                 parameter['type'] = _detect_type(str(annotation)) | ||||
|             desc['parameters'].append(parameter) | ||||
|             if param.default is inspect.Signature.empty: | ||||
|                 desc['required'].append(param.name) | ||||
|  | ||||
|         return_data = [] | ||||
|         if explode_return: | ||||
|             return_data = _explode(returns_doc[0]['description']) | ||||
|         elif returns_named_value: | ||||
|             return_data = returns_doc | ||||
|         if return_data: | ||||
|             desc['return_data'] = return_data | ||||
|         return desc | ||||
|  | ||||
|     if callable(func): | ||||
|  | ||||
|         @wraps(func) | ||||
|         def wrapper(self, *args, **kwargs): | ||||
|             return func(self, *args, **kwargs) | ||||
|  | ||||
|         wrapper.api_description = _parse_tool(func) | ||||
|         return wrapper | ||||
|  | ||||
|     def decorate(func): | ||||
|  | ||||
|         @wraps(func) | ||||
|         def wrapper(self, *args, **kwargs): | ||||
|             return func(self, *args, **kwargs) | ||||
|  | ||||
|         wrapper.api_description = _parse_tool(func) | ||||
|         return wrapper | ||||
|  | ||||
|     return decorate | ||||
|  | ||||
|  | ||||
| class ToolMeta(ABCMeta): | ||||
|     """Metaclass of tools.""" | ||||
|  | ||||
|     def __new__(mcs, name, base, attrs): | ||||
|         is_toolkit, tool_desc = True, dict( | ||||
|             name=attrs.setdefault('__tool_name__', name), | ||||
|             description=Docstring(attrs.get('__doc__', | ||||
|                                             '')).parse('google')[0].value) | ||||
|         for key, value in attrs.items(): | ||||
|             if callable(value) and hasattr(value, 'api_description'): | ||||
|                 api_desc = getattr(value, 'api_description') | ||||
|                 if key == 'run': | ||||
|                     tool_desc['parameters'] = api_desc['parameters'] | ||||
|                     tool_desc['required'] = api_desc['required'] | ||||
|                     if api_desc['description']: | ||||
|                         tool_desc['description'] = api_desc['description'] | ||||
|                     if api_desc.get('return_data'): | ||||
|                         tool_desc['return_data'] = api_desc['return_data'] | ||||
|                     is_toolkit = False | ||||
|                 else: | ||||
|                     tool_desc.setdefault('api_list', []).append(api_desc) | ||||
|         if not is_toolkit and 'api_list' in tool_desc: | ||||
|             raise KeyError('`run` and other tool APIs can not be implemented ' | ||||
|                            'at the same time') | ||||
|         if is_toolkit and 'api_list' not in tool_desc: | ||||
|             is_toolkit = False | ||||
|             if callable(attrs.get('run')): | ||||
|                 run_api = tool_api(attrs['run']) | ||||
|                 api_desc = run_api.api_description | ||||
|                 tool_desc['parameters'] = api_desc['parameters'] | ||||
|                 tool_desc['required'] = api_desc['required'] | ||||
|                 if api_desc['description']: | ||||
|                     tool_desc['description'] = api_desc['description'] | ||||
|                 if api_desc.get('return_data'): | ||||
|                     tool_desc['return_data'] = api_desc['return_data'] | ||||
|                 attrs['run'] = run_api | ||||
|             else: | ||||
|                 tool_desc['parameters'], tool_desc['required'] = [], [] | ||||
|         attrs['_is_toolkit'] = is_toolkit | ||||
|         attrs['__tool_description__'] = tool_desc | ||||
|         return super().__new__(mcs, name, base, attrs) | ||||
|  | ||||
|  | ||||
| class BaseAction(metaclass=AutoRegister(TOOL_REGISTRY, ToolMeta)): | ||||
|     """Base class for all actions. | ||||
|  | ||||
|     Args: | ||||
|         description (str, optional): The description of the action. Defaults to | ||||
|             None. | ||||
|         name (str, optional): The name of the action. If None, the name will | ||||
|             be class name. Defaults to None. | ||||
|         enable (bool, optional): Whether the action is enabled. Defaults to | ||||
|             True. | ||||
|         disable_description (str, optional): The description of the action when | ||||
|             it is disabled. Defaults to None. | ||||
|         description (:class:`Optional[dict]`): The description of the action. | ||||
|             Defaults to ``None``. | ||||
|         parser (:class:`Type[BaseParser]`): The parser class to process the | ||||
|             action's inputs and outputs. Defaults to :class:`JsonParser`. | ||||
|         enable (:class:`bool`): Whether the action is enabled. Defaults to | ||||
|             ``True``. | ||||
|  | ||||
|     Examples: | ||||
|  | ||||
|         * simple tool | ||||
|  | ||||
|         .. code-block:: python | ||||
|  | ||||
|             class Bold(BaseAction): | ||||
|                 '''Make text bold''' | ||||
|  | ||||
|                 def run(self, text: str): | ||||
|                     ''' | ||||
|                     Args: | ||||
|                         text (str): input text | ||||
|  | ||||
|                     Returns: | ||||
|                         str: bold text | ||||
|                     ''' | ||||
|                     return '**' + text + '**' | ||||
|  | ||||
|             action = Bold() | ||||
|  | ||||
|         * toolkit with multiple APIs | ||||
|  | ||||
|         .. code-block:: python | ||||
|  | ||||
|             class Calculator(BaseAction): | ||||
|                 '''Calculator''' | ||||
|  | ||||
|                 @tool_api | ||||
|                 def add(self, a, b): | ||||
|                     '''Add operation | ||||
|  | ||||
|                     Args: | ||||
|                         a (int): augend | ||||
|                         b (int): addend | ||||
|  | ||||
|                     Returns: | ||||
|                         int: sum | ||||
|                     ''' | ||||
|                     return a + b | ||||
|  | ||||
|                 @tool_api | ||||
|                 def sub(self, a, b): | ||||
|                     '''Subtraction operation | ||||
|  | ||||
|                     Args: | ||||
|                         a (int): minuend | ||||
|                         b (int): subtrahend | ||||
|  | ||||
|                     Returns: | ||||
|                         int: difference | ||||
|                     ''' | ||||
|                     return a - b | ||||
|  | ||||
|             action = Calculator() | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, | ||||
|                  description: Optional[str] = None, | ||||
|                  name: Optional[str] = None, | ||||
|                  enable: bool = True, | ||||
|                  disable_description: Optional[str] = None) -> None: | ||||
|         if name is None: | ||||
|             name = self.__class__.__name__ | ||||
|         self._name = name | ||||
|         self._description = description | ||||
|         self._disable_description = disable_description | ||||
|                  description: Optional[dict] = None, | ||||
|                  parser: Type[BaseParser] = JsonParser, | ||||
|                  enable: bool = True): | ||||
|         self._description = deepcopy(description or self.__tool_description__) | ||||
|         self._name = self._description['name'] | ||||
|         self._parser = parser(self) | ||||
|         self._enable = enable | ||||
|  | ||||
|     def __call__(self, *args, **kwargs) -> ActionReturn: | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return f'{self.name}:{self.description}' | ||||
|  | ||||
|     def __str__(self): | ||||
|         return self.__repr__() | ||||
|  | ||||
|     def run(self, *args, **kwargs) -> ActionReturn: | ||||
|         return self.__call__(*args, **kwargs) | ||||
|  | ||||
|     @property | ||||
|     def enable(self): | ||||
|         return self._enable | ||||
|     def __call__(self, inputs: str, name='run') -> ActionReturn: | ||||
|         fallback_args = {'inputs': inputs, 'name': name} | ||||
|         if not hasattr(self, name): | ||||
|             return ActionReturn( | ||||
|                 fallback_args, | ||||
|                 type=self.name, | ||||
|                 errmsg=f'invalid API: {name}', | ||||
|                 state=ActionStatusCode.API_ERROR) | ||||
|         try: | ||||
|             inputs = self._parser.parse_inputs(inputs, name) | ||||
|         except ParseError as exc: | ||||
|             return ActionReturn( | ||||
|                 fallback_args, | ||||
|                 type=self.name, | ||||
|                 errmsg=exc.err_msg, | ||||
|                 state=ActionStatusCode.ARGS_ERROR) | ||||
|         try: | ||||
|             outputs = getattr(self, name)(**inputs) | ||||
|         except Exception as exc: | ||||
|             return ActionReturn( | ||||
|                 inputs, | ||||
|                 type=self.name, | ||||
|                 errmsg=str(exc), | ||||
|                 state=ActionStatusCode.API_ERROR) | ||||
|         if isinstance(outputs, ActionReturn): | ||||
|             action_return = outputs | ||||
|             if not action_return.args: | ||||
|                 action_return.args = inputs | ||||
|             if not action_return.type: | ||||
|                 action_return.type = self.name | ||||
|         else: | ||||
|             result = self._parser.parse_outputs(outputs) | ||||
|             action_return = ActionReturn(inputs, type=self.name, result=result) | ||||
|         return action_return | ||||
|  | ||||
|     @property | ||||
|     def name(self): | ||||
|         return self._name | ||||
|  | ||||
|     @property | ||||
|     def description(self): | ||||
|         if self.enable: | ||||
|             return self._description | ||||
|         else: | ||||
|             return self._disable_description | ||||
|     def enable(self): | ||||
|         return self._enable | ||||
|  | ||||
|     @property | ||||
|     def is_toolkit(self): | ||||
|         return self._is_toolkit | ||||
|  | ||||
|     @property | ||||
|     def description(self) -> dict: | ||||
|         """Description of the tool.""" | ||||
|         return self._description | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return f'{self.description}' | ||||
|  | ||||
|     __str__ = __repr__ | ||||
|   | ||||
							
								
								
									
										144
									
								
								lagent/actions/bing_map.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										144
									
								
								lagent/actions/bing_map.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,144 @@ | ||||
| # flake8: noqa: E501 | ||||
| import json | ||||
| import os | ||||
| from typing import Optional, Type | ||||
|  | ||||
| import requests | ||||
|  | ||||
| from lagent.actions.base_action import BaseAction, tool_api | ||||
| from lagent.actions.parser import BaseParser, JsonParser | ||||
|  | ||||
|  | ||||
| class BINGMap(BaseAction): | ||||
|     """BING Map plugin for looking up map information.""" | ||||
|  | ||||
|     def __init__(self, | ||||
|                  key: Optional[str] = None, | ||||
|                  description: Optional[dict] = None, | ||||
|                  parser: Type[BaseParser] = JsonParser, | ||||
|                  enable: bool = True) -> None: | ||||
|         super().__init__(description, parser, enable) | ||||
|         key = os.environ.get('BING_MAP_KEY', key) | ||||
|         if key is None: | ||||
|             raise ValueError( | ||||
|                 'Please set BING Map API key either in the environment ' | ||||
|                 'as BING_MAP_KEY or pass it as `key` parameter.') | ||||
|         self.key = key | ||||
|         self.base_url = 'http://dev.virtualearth.net/REST/V1/' | ||||
|  | ||||
|     @tool_api(explode_return=True) | ||||
|     def get_distance(self, start: str, end: str) -> dict: | ||||
|         """Get the distance between two locations in km. | ||||
|  | ||||
|         Args: | ||||
|             start (:class:`str`): The start location | ||||
|             end (:class:`str`): The end location | ||||
|  | ||||
|         Returns: | ||||
|             :class:`dict`: distance information | ||||
|                 * distance (str): the distance in km. | ||||
|         """ | ||||
|         # Request URL | ||||
|         url = self.base_url + 'Routes/Driving?o=json&wp.0=' + start + '&wp.1=' + end + '&key=' + self.key | ||||
|         # GET request | ||||
|         r = requests.get(url) | ||||
|         # TODO check request status? | ||||
|         data = json.loads(r.text) | ||||
|         # Extract route information | ||||
|         route = data['resourceSets'][0]['resources'][0] | ||||
|         # Extract distance in miles | ||||
|         distance = route['travelDistance'] | ||||
|         return dict(distance=distance) | ||||
|  | ||||
|     @tool_api(explode_return=True) | ||||
|     def get_route(self, start: str, end: str) -> dict: | ||||
|         """Get the route between two locations in km. | ||||
|  | ||||
|         Args: | ||||
|             start (:class:`str`): The start location | ||||
|             end (:class:`str`): The end location | ||||
|  | ||||
|         Returns: | ||||
|             :class:`dict`: route information | ||||
|                 * route (list): the route, a list of actions. | ||||
|         """ | ||||
|         # Request URL | ||||
|         url = self.base_url + 'Routes/Driving?o=json&wp.0=' + start + '&wp.1=' + end + '&key=' + self.key | ||||
|         # GET request | ||||
|         r = requests.get(url) | ||||
|         data = json.loads(r.text) | ||||
|         # Extract route information | ||||
|         route = data['resourceSets'][0]['resources'][0] | ||||
|         itinerary = route['routeLegs'][0]['itineraryItems'] | ||||
|         # Extract route text information | ||||
|         route_text = [] | ||||
|         for item in itinerary: | ||||
|             if 'instruction' in item: | ||||
|                 route_text.append(item['instruction']['text']) | ||||
|         return dict(route=route_text) | ||||
|  | ||||
|     @tool_api(explode_return=True) | ||||
|     def get_coordinates(self, location: str) -> dict: | ||||
|         """Get the coordinates of a location. | ||||
|  | ||||
|         Args: | ||||
|             location (:class:`str`): the location need to get coordinates. | ||||
|  | ||||
|         Returns: | ||||
|             :class:`dict`: coordinates information | ||||
|                 * latitude (float): the latitude of the location. | ||||
|                 * longitude (float): the longitude of the location. | ||||
|         """ | ||||
|         url = self.base_url + 'Locations' | ||||
|         params = {'query': location, 'key': self.key} | ||||
|         response = requests.get(url, params=params) | ||||
|         json_data = response.json() | ||||
|         coordinates = json_data['resourceSets'][0]['resources'][0]['point'][ | ||||
|             'coordinates'] | ||||
|         return dict(latitude=coordinates[0], longitude=coordinates[1]) | ||||
|  | ||||
|     @tool_api(explode_return=True) | ||||
|     def search_nearby(self, | ||||
|                       search_term: str, | ||||
|                       places: str = 'unknown', | ||||
|                       latitude: float = 0.0, | ||||
|                       longitude: float = 0.0, | ||||
|                       radius: int = 5000) -> dict: | ||||
|         """Search for places nearby a location, within a given radius, and return the results into a list. You can use either the places name or the latitude and longitude. | ||||
|  | ||||
|         Args: | ||||
|             search_term (:class:`str`): the place name. | ||||
|             places (:class:`str`): the name of the location. Defaults to ``'unknown'``. | ||||
|             latitude (:class:`float`): the latitude of the location. Defaults to ``0.0``. | ||||
|             longitude (:class:`float`): the longitude of the location. Defaults to ``0.0``. | ||||
|             radius (:class:`int`): radius in meters. Defaults to ``5000``. | ||||
|  | ||||
|         Returns: | ||||
|             :class:`dict`: places information | ||||
|                 * places (list): the list of places, each place is a dict with name and address, at most 5 places. | ||||
|         """ | ||||
|         url = self.base_url + 'LocalSearch' | ||||
|         if places != 'unknown': | ||||
|             pos = self.get_coordinates(**{'location': places}) | ||||
|             latitude, longitude = pos[1]['latitude'], pos[1]['longitude'] | ||||
|         # Build the request query string | ||||
|         params = { | ||||
|             'query': search_term, | ||||
|             'userLocation': f'{latitude},{longitude}', | ||||
|             'radius': radius, | ||||
|             'key': self.key | ||||
|         } | ||||
|         # Make the request | ||||
|         response = requests.get(url, params=params) | ||||
|         # Parse the response | ||||
|         response_data = json.loads(response.content) | ||||
|         # Get the results | ||||
|         results = response_data['resourceSets'][0]['resources'] | ||||
|         addresses = [] | ||||
|         for result in results: | ||||
|             name = result['name'] | ||||
|             address = result['Address']['formattedAddress'] | ||||
|             addresses.append(dict(name=name, address=address)) | ||||
|             if len(addresses) == 5: | ||||
|                 break | ||||
|         return dict(place=addresses) | ||||
| @@ -1,6 +1,7 @@ | ||||
| from typing import Optional | ||||
|  | ||||
| from lagent.actions.base_action import BaseAction | ||||
| from lagent.actions.base_action import BaseAction, tool_api | ||||
| from lagent.actions.parser import BaseParser | ||||
| from lagent.schema import ActionReturn, ActionStatusCode, ActionValidCode | ||||
|  | ||||
|  | ||||
| @@ -19,12 +20,13 @@ class InvalidAction(BaseAction): | ||||
|     def __init__(self, | ||||
|                  err_msg: | ||||
|                  str = 'The action is invalid, please check the action name.', | ||||
|                  **kwargs) -> None: | ||||
|  | ||||
|         super().__init__(enable=False, **kwargs) | ||||
|                  description: Optional[dict] = None, | ||||
|                  parser=BaseParser) -> None: | ||||
|         super().__init__(description, parser, enable=False) | ||||
|         self._err_msg = err_msg | ||||
|  | ||||
|     def __call__(self, err_msg: Optional[str] = None): | ||||
|     @tool_api | ||||
|     def run(self, err_msg: Optional[str] = None) -> ActionReturn: | ||||
|         """Return the error message. | ||||
|  | ||||
|         Args: | ||||
| @@ -35,7 +37,7 @@ class InvalidAction(BaseAction): | ||||
|         action_return = ActionReturn( | ||||
|             url=None, | ||||
|             args=dict(text=err_msg), | ||||
|             errmsg=err_msg if err_msg else self._err_msg, | ||||
|             errmsg=err_msg or self._err_msg, | ||||
|             type=self.name, | ||||
|             valid=ActionValidCode.INVALID, | ||||
|             state=ActionStatusCode.API_ERROR) | ||||
| @@ -51,12 +53,15 @@ class NoAction(BaseAction): | ||||
|             'Please follow the format'. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, err_msg: str = 'Please follow the format', **kwargs): | ||||
|  | ||||
|         super().__init__(enable=False, **kwargs) | ||||
|     def __init__(self, | ||||
|                  err_msg: str = 'Please follow the format', | ||||
|                  description: Optional[dict] = None, | ||||
|                  parser=BaseParser): | ||||
|         super().__init__(description, parser, enable=False) | ||||
|         self._err_msg = err_msg | ||||
|  | ||||
|     def __call__(self, err_msg: Optional[str] = None): | ||||
|     @tool_api | ||||
|     def run(self, err_msg: Optional[str] = None) -> ActionReturn: | ||||
|         """Return the error message. | ||||
|  | ||||
|         Args: | ||||
| @@ -71,7 +76,7 @@ class NoAction(BaseAction): | ||||
|             url=None, | ||||
|             args=dict(text=err_msg), | ||||
|             type=self.name, | ||||
|             errmsg=err_msg if err_msg else self._err_msg, | ||||
|             errmsg=err_msg or self._err_msg, | ||||
|             valid=ActionValidCode.INVALID, | ||||
|             state=ActionStatusCode.API_ERROR) | ||||
|         return action_return | ||||
| @@ -81,7 +86,11 @@ class FinishAction(BaseAction): | ||||
|     """This is a finish action class, which is used to return the final | ||||
|     result.""" | ||||
|  | ||||
|     def __call__(self, response: str) -> ActionReturn: | ||||
|     def __init__(self, description: Optional[dict] = None, parser=BaseParser): | ||||
|         super().__init__(description, parser, enable=True) | ||||
|  | ||||
|     @tool_api | ||||
|     def run(self, response: str) -> ActionReturn: | ||||
|         """Return the final result. | ||||
|  | ||||
|         Args: | ||||
| @@ -93,7 +102,7 @@ class FinishAction(BaseAction): | ||||
|         action_return = ActionReturn( | ||||
|             url=None, | ||||
|             args=dict(text=response), | ||||
|             result=dict(text=response), | ||||
|             result=[dict(type='text', content=response)], | ||||
|             type=self.name, | ||||
|             valid=ActionValidCode.FINISH, | ||||
|             state=ActionStatusCode.SUCCESS) | ||||
|   | ||||
							
								
								
									
										270
									
								
								lagent/actions/google_scholar_search.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										270
									
								
								lagent/actions/google_scholar_search.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,270 @@ | ||||
| # flake8: noqa: E501 | ||||
| import os | ||||
| from typing import Optional, Type | ||||
|  | ||||
| from lagent.actions.base_action import BaseAction, tool_api | ||||
| from lagent.schema import ActionReturn, ActionStatusCode | ||||
| from .parser import BaseParser, JsonParser | ||||
|  | ||||
|  | ||||
| class GoogleScholar(BaseAction): | ||||
|     """Plugin for google scholar search. | ||||
|  | ||||
|     Args: | ||||
|         api_key (str): API KEY to use serper google search API, | ||||
|             You can create a free API key at https://serper.dev. | ||||
|         description (dict): The description of the action. Defaults to ``None``. | ||||
|         parser (Type[BaseParser]): The parser class to process the | ||||
|             action's inputs and outputs. Defaults to :class:`JsonParser`. | ||||
|         enable (bool, optional): Whether the action is enabled. Defaults to | ||||
|             True. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, | ||||
|                  api_key: Optional[str] = None, | ||||
|                  description: Optional[dict] = None, | ||||
|                  parser: Type[BaseParser] = JsonParser, | ||||
|                  enable: bool = True): | ||||
|         super().__init__(description, parser, enable) | ||||
|         api_key = os.environ.get('SERPER_API_KEY', api_key) | ||||
|         if api_key is None: | ||||
|             raise ValueError( | ||||
|                 'Please set Serper API key either in the environment ' | ||||
|                 'as SERPER_API_KEY or pass it as `api_key` parameter.') | ||||
|         self.api_key = api_key | ||||
|  | ||||
|     @tool_api(explode_return=True) | ||||
|     def search_google_scholar( | ||||
|         self, | ||||
|         query: str, | ||||
|         cites: Optional[str] = None, | ||||
|         as_ylo: Optional[int] = None, | ||||
|         as_yhi: Optional[int] = None, | ||||
|         scisbd: Optional[int] = None, | ||||
|         cluster: Optional[str] = None, | ||||
|         hl: Optional[str] = None, | ||||
|         lr: Optional[str] = None, | ||||
|         start: Optional[int] = None, | ||||
|         num: Optional[int] = None, | ||||
|         as_sdt: Optional[str] = None, | ||||
|         safe: Optional[str] = None, | ||||
|         filter: Optional[str] = None, | ||||
|         as_vis: Optional[str] = None, | ||||
|     ) -> dict: | ||||
|         """Search for scholarly articles based on a query according to the google scholar. | ||||
|  | ||||
|         Args: | ||||
|             query (str): The query to search for. | ||||
|             cites (Optional[str]): The unique ID of an article for triggering "Cited By" searches. | ||||
|             as_ylo (Optional[int]): The starting year for results (e.g., if as_ylo=2018, results before this year will be omitted). | ||||
|             as_yhi (Optional[int]): The ending year for results (e.g., if as_yhi=2018, results after this year will be omitted). | ||||
|             scisbd (Optional[int]): Defines articles added in the last year, sorted by date. It can be set to 1 to include only abstracts, or 2 to include everything. | ||||
|             cluster (Optional[str]): The unique ID of an article for triggering "All Versions" searches. | ||||
|             hl (Optional[str]): The language to use for the Google Scholar search. | ||||
|             lr (Optional[str]): One or multiple languages to limit the search to. | ||||
|             start (Optional[int]): The result offset for pagination (0 is the first page of results, 10 is the 2nd page, etc.) | ||||
|             num (Optional[int]): The maximum number of results to return, limited to 20. | ||||
|             as_sdt (Optional[str]): Can be used either as a search type or a filter. | ||||
|             safe (Optional[str]): The level of filtering for adult content. | ||||
|             filter (Optional[str]): Defines if the filters for 'Similar Results' and 'Omitted Results' are on or off. | ||||
|             as_vis (Optional[str]): Defines whether to include citations or not. | ||||
|  | ||||
|         Returns: | ||||
|             :class:`dict`: article information | ||||
|                 - title: a list of the titles of the three selected papers | ||||
|                 - cited_by: a list of the citation numbers of the three selected papers | ||||
|                 - organic_id: a list of the organic results' ids of the three selected papers | ||||
|                 - pub_info: publication information of selected papers | ||||
|         """ | ||||
|         from serpapi import GoogleSearch | ||||
|         params = { | ||||
|             'q': query, | ||||
|             'engine': 'google_scholar', | ||||
|             'api_key': self.api_key, | ||||
|             'cites': cites, | ||||
|             'as_ylo': as_ylo, | ||||
|             'as_yhi': as_yhi, | ||||
|             'scisbd': scisbd, | ||||
|             'cluster': cluster, | ||||
|             'hl': hl, | ||||
|             'lr': lr, | ||||
|             'start': start, | ||||
|             'num': num, | ||||
|             'as_sdt': as_sdt, | ||||
|             'safe': safe, | ||||
|             'filter': filter, | ||||
|             'as_vis': as_vis | ||||
|         } | ||||
|         search = GoogleSearch(params) | ||||
|         try: | ||||
|             r = search.get_dict() | ||||
|             results = r['organic_results'] | ||||
|             title = [] | ||||
|             snippets = [] | ||||
|             cited_by = [] | ||||
|             organic_id = [] | ||||
|             pub_info = [] | ||||
|             for item in results[:3]: | ||||
|                 title.append(item['title']) | ||||
|                 pub_info.append(item['publication_info']['summary']) | ||||
|                 citation = item['inline_links'].get('cited_by', {'total': ''}) | ||||
|                 cited_by.append(citation['total']) | ||||
|                 snippets.append(item['snippet']) | ||||
|                 organic_id.append(item['result_id']) | ||||
|             return dict( | ||||
|                 title=title, | ||||
|                 cited_by=cited_by, | ||||
|                 organic_id=organic_id, | ||||
|                 snippets=snippets) | ||||
|         except Exception as e: | ||||
|             return ActionReturn( | ||||
|                 errmsg=str(e), state=ActionStatusCode.HTTP_ERROR) | ||||
|  | ||||
|     @tool_api(explode_return=True) | ||||
|     def get_author_information(self, | ||||
|                                author_id: str, | ||||
|                                hl: Optional[str] = None, | ||||
|                                view_op: Optional[str] = None, | ||||
|                                sort: Optional[str] = None, | ||||
|                                citation_id: Optional[str] = None, | ||||
|                                start: Optional[int] = None, | ||||
|                                num: Optional[int] = None, | ||||
|                                no_cache: Optional[bool] = None, | ||||
|                                async_req: Optional[bool] = None, | ||||
|                                output: Optional[str] = None) -> dict: | ||||
|         """Search for an author's information by author's id provided by get_author_id. | ||||
|  | ||||
|         Args: | ||||
|             author_id (str): Required. The ID of an author. | ||||
|             hl (Optional[str]): The language to use for the Google Scholar Author search. Default is 'en'. | ||||
|             view_op (Optional[str]): Used for viewing specific parts of a page. | ||||
|             sort (Optional[str]): Used for sorting and refining articles. | ||||
|             citation_id (Optional[str]): Used for retrieving individual article citation. | ||||
|             start (Optional[int]): Defines the result offset. Default is 0. | ||||
|             num (Optional[int]): Defines the number of results to return. Default is 20. | ||||
|             no_cache (Optional[bool]): Forces SerpApi to fetch the results even if a cached version is already present. Default is False. | ||||
|             async_req (Optional[bool]): Defines the way you want to submit your search to SerpApi. Default is False. | ||||
|             output (Optional[str]): Defines the final output you want. Default is 'json'. | ||||
|  | ||||
|         Returns: | ||||
|             :class:`dict`: author information | ||||
|                 * name: author's name | ||||
|                 * affliation: the affliation of the author | ||||
|                 * articles: at most 3 articles by the author | ||||
|                 * website: the author's homepage url | ||||
|         """ | ||||
|         from serpapi import GoogleSearch | ||||
|         params = { | ||||
|             'engine': 'google_scholar_author', | ||||
|             'author_id': author_id, | ||||
|             'api_key': self.api_key, | ||||
|             'hl': hl, | ||||
|             'view_op': view_op, | ||||
|             'sort': sort, | ||||
|             'citation_id': citation_id, | ||||
|             'start': start, | ||||
|             'num': num, | ||||
|             'no_cache': no_cache, | ||||
|             'async': async_req, | ||||
|             'output': output | ||||
|         } | ||||
|         try: | ||||
|             search = GoogleSearch(params) | ||||
|             results = search.get_dict() | ||||
|             author = results['author'] | ||||
|             articles = results.get('articles', []) | ||||
|             return dict( | ||||
|                 name=author['name'], | ||||
|                 affiliations=author.get('affiliations', ''), | ||||
|                 website=author.get('website', ''), | ||||
|                 articles=[ | ||||
|                     dict(title=article['title'], authors=article['authors']) | ||||
|                     for article in articles[:3] | ||||
|                 ]) | ||||
|         except Exception as e: | ||||
|             return ActionReturn( | ||||
|                 errmsg=str(e), state=ActionStatusCode.HTTP_ERROR) | ||||
|  | ||||
|     @tool_api(explode_return=True) | ||||
|     def get_citation_format(self, | ||||
|                             q: str, | ||||
|                             no_cache: Optional[bool] = None, | ||||
|                             async_: Optional[bool] = None, | ||||
|                             output: Optional[str] = 'json') -> dict: | ||||
|         """Function to get MLA citation format by an identification of organic_result's id provided by search_google_scholar. | ||||
|  | ||||
|         Args: | ||||
|             q (str): ID of an individual Google Scholar organic search result. | ||||
|             no_cache (Optional[bool]): If set to True, will force SerpApi to fetch the Google Scholar Cite results even if a cached version is already present. Defaults to None. | ||||
|             async_ (Optional[bool]): If set to True, will submit search to SerpApi and retrieve results later. Defaults to None. | ||||
|             output (Optional[str]): Final output format. Set to 'json' to get a structured JSON of the results, or 'html' to get the raw html retrieved. Defaults to 'json'. | ||||
|  | ||||
|         Returns: | ||||
|             :class:`dict`: citation format | ||||
|                 * authors: the authors of the article | ||||
|                 * citation: the citation format of the article | ||||
|         """ | ||||
|         from serpapi import GoogleSearch | ||||
|         params = { | ||||
|             'q': q, | ||||
|             'engine': 'google_scholar_cite', | ||||
|             'api_key': self.api_key, | ||||
|             'no_cache': no_cache, | ||||
|             'async': async_, | ||||
|             'output': output | ||||
|         } | ||||
|         try: | ||||
|             search = GoogleSearch(params) | ||||
|             results = search.get_dict() | ||||
|             citation = results['citations'] | ||||
|             citation_info = citation[0]['snippet'] | ||||
|             return citation_info | ||||
|         except Exception as e: | ||||
|             return ActionReturn( | ||||
|                 errmsg=str(e), state=ActionStatusCode.HTTP_ERROR) | ||||
|  | ||||
|     @tool_api(explode_return=True) | ||||
|     def get_author_id(self, | ||||
|                       mauthors: str, | ||||
|                       hl: Optional[str] = 'en', | ||||
|                       after_author: Optional[str] = None, | ||||
|                       before_author: Optional[str] = None, | ||||
|                       no_cache: Optional[bool] = False, | ||||
|                       _async: Optional[bool] = False, | ||||
|                       output: Optional[str] = 'json') -> dict: | ||||
|         """The getAuthorId function is used to get the author's id by his or her name. | ||||
|  | ||||
|         Args: | ||||
|             mauthors (str): Defines the author you want to search for. | ||||
|             hl (Optional[str]): Defines the language to use for the Google Scholar Profiles search. It's a two-letter language code. (e.g., 'en' for English, 'es' for Spanish, or 'fr' for French). Defaults to 'en'. | ||||
|             after_author (Optional[str]): Defines the next page token. It is used for retrieving the next page results. The parameter has the precedence over before_author parameter. Defaults to None. | ||||
|             before_author (Optional[str]): Defines the previous page token. It is used for retrieving the previous page results. Defaults to None. | ||||
|             no_cache (Optional[bool]): Will force SerpApi to fetch the Google Scholar Profiles results even if a cached version is already present. Defaults to False. | ||||
|             _async (Optional[bool]): Defines the way you want to submit your search to SerpApi. Defaults to False. | ||||
|             output (Optional[str]): Defines the final output you want. It can be set to 'json' (default) to get a structured JSON of the results, or 'html' to get the raw html retrieved. Defaults to 'json'. | ||||
|  | ||||
|         Returns: | ||||
|             :class:`dict`: author id | ||||
|                 * author_id: the author_id of the author | ||||
|         """ | ||||
|         from serpapi import GoogleSearch | ||||
|         params = { | ||||
|             'mauthors': mauthors, | ||||
|             'engine': 'google_scholar_profiles', | ||||
|             'api_key': self.api_key, | ||||
|             'hl': hl, | ||||
|             'after_author': after_author, | ||||
|             'before_author': before_author, | ||||
|             'no_cache': no_cache, | ||||
|             'async': _async, | ||||
|             'output': output | ||||
|         } | ||||
|         try: | ||||
|             search = GoogleSearch(params) | ||||
|             results = search.get_dict() | ||||
|             profile = results['profiles'] | ||||
|             author_info = dict(author_id=profile[0]['author_id']) | ||||
|             return author_info | ||||
|         except Exception as e: | ||||
|             return ActionReturn( | ||||
|                 errmsg=str(e), state=ActionStatusCode.HTTP_ERROR) | ||||
| @@ -1,15 +1,11 @@ | ||||
| import os | ||||
| from typing import List, Optional, Tuple, Union | ||||
| from typing import List, Optional, Tuple, Type, Union | ||||
|  | ||||
| import requests | ||||
|  | ||||
| from lagent.schema import ActionReturn, ActionStatusCode | ||||
| from .base_action import BaseAction | ||||
|  | ||||
| DEFAULT_DESCRIPTION = """一个可以从谷歌搜索结果的API。 | ||||
| 当你需要对于一个特定问题找到简短明了的回答时,可以使用它。 | ||||
| 输入应该是一个搜索查询。 | ||||
| """ | ||||
| from .base_action import BaseAction, tool_api | ||||
| from .parser import BaseParser, JsonParser | ||||
|  | ||||
|  | ||||
| class GoogleSearch(BaseAction): | ||||
| @@ -28,15 +24,10 @@ class GoogleSearch(BaseAction): | ||||
|         timeout (int): Upper bound of waiting time for a serper request. | ||||
|         search_type (str): Serper API support ['search', 'images', 'news', | ||||
|             'places'] types of search, currently we only support 'search'. | ||||
|         k (int): select first k results in the search results as response. | ||||
|         description (str): The description of the action. Defaults to | ||||
|             None. | ||||
|         name (str, optional): The name of the action. If None, the name will | ||||
|             be class name. Defaults to None. | ||||
|         enable (bool, optional): Whether the action is enabled. Defaults to | ||||
|             True. | ||||
|         disable_description (str, optional): The description of the action when | ||||
|             it is disabled. Defaults to None. | ||||
|         description (dict): The description of the action. Defaults to ``None``. | ||||
|         parser (Type[BaseParser]): The parser class to process the | ||||
|             action's inputs and outputs. Defaults to :class:`JsonParser`. | ||||
|         enable (bool): Whether the action is enabled. Defaults to ``True``. | ||||
|     """ | ||||
|     result_key_for_type = { | ||||
|         'news': 'news', | ||||
| @@ -49,43 +40,36 @@ class GoogleSearch(BaseAction): | ||||
|                  api_key: Optional[str] = None, | ||||
|                  timeout: int = 5, | ||||
|                  search_type: str = 'search', | ||||
|                  k: int = 10, | ||||
|                  description: str = DEFAULT_DESCRIPTION, | ||||
|                  name: Optional[str] = None, | ||||
|                  enable: bool = True, | ||||
|                  disable_description: Optional[str] = None) -> None: | ||||
|         super().__init__(description, name, enable, disable_description) | ||||
|  | ||||
|                  description: Optional[dict] = None, | ||||
|                  parser: Type[BaseParser] = JsonParser, | ||||
|                  enable: bool = True): | ||||
|         super().__init__(description, parser, enable) | ||||
|         api_key = os.environ.get('SERPER_API_KEY', api_key) | ||||
|         if api_key is None: | ||||
|             raise ValueError( | ||||
|                 'Please set Serper API key either in the environment ' | ||||
|                 ' as SERPER_API_KEY or pass it as `api_key` parameter.') | ||||
|                 'as SERPER_API_KEY or pass it as `api_key` parameter.') | ||||
|         self.api_key = api_key | ||||
|         self.timeout = timeout | ||||
|         self.search_type = search_type | ||||
|         self.k = k | ||||
|  | ||||
|     def __call__(self, query: str) -> ActionReturn: | ||||
|         """Return the search response. | ||||
|     @tool_api | ||||
|     def run(self, query: str, k: int = 10) -> ActionReturn: | ||||
|         """一个可以从谷歌搜索结果的API。当你需要对于一个特定问题找到简短明了的回答时,可以使用它。输入应该是一个搜索查询。 | ||||
|  | ||||
|         Args: | ||||
|             query (str): The search content. | ||||
|  | ||||
|         Returns: | ||||
|             ActionReturn: The action return. | ||||
|             query (str): the search content | ||||
|             k (int): select first k results in the search results as response | ||||
|         """ | ||||
|  | ||||
|         tool_return = ActionReturn(url=None, args=None, type=self.name) | ||||
|         status_code, response = self._search( | ||||
|             query, search_type=self.search_type, k=self.k) | ||||
|         tool_return = ActionReturn(type=self.name) | ||||
|         status_code, response = self._search(query, k=k) | ||||
|         # convert search results to ToolReturn format | ||||
|         if status_code == -1: | ||||
|             tool_return.errmsg = response | ||||
|             tool_return.state = ActionStatusCode.HTTP_ERROR | ||||
|         elif status_code == 200: | ||||
|             parsed_res = self._parse_results(response) | ||||
|             tool_return.result = dict(text=str(parsed_res)) | ||||
|             tool_return.result = [dict(type='text', content=str(parsed_res))] | ||||
|             tool_return.state = ActionStatusCode.SUCCESS | ||||
|         else: | ||||
|             tool_return.errmsg = str(status_code) | ||||
| @@ -139,7 +123,7 @@ class GoogleSearch(BaseAction): | ||||
|  | ||||
|     def _search(self, | ||||
|                 search_term: str, | ||||
|                 search_type: str = 'search', | ||||
|                 search_type: Optional[str] = None, | ||||
|                 **kwargs) -> Tuple[int, Union[dict, str]]: | ||||
|         """HTTP requests to Serper API. | ||||
|  | ||||
| @@ -166,7 +150,7 @@ class GoogleSearch(BaseAction): | ||||
|         } | ||||
|         try: | ||||
|             response = requests.post( | ||||
|                 f'https://google.serper.dev/{search_type}', | ||||
|                 f'https://google.serper.dev/{search_type or self.search_type}', | ||||
|                 headers=headers, | ||||
|                 params=params, | ||||
|                 timeout=self.timeout) | ||||
|   | ||||
							
								
								
									
										188
									
								
								lagent/actions/ipython_interactive.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										188
									
								
								lagent/actions/ipython_interactive.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,188 @@ | ||||
| import re | ||||
| from contextlib import redirect_stdout | ||||
| from dataclasses import dataclass | ||||
| from enum import Enum | ||||
| from io import StringIO | ||||
| from typing import Optional, Type | ||||
|  | ||||
| from ..schema import ActionReturn, ActionStatusCode | ||||
| from .base_action import BaseAction, tool_api | ||||
| from .parser import BaseParser, JsonParser | ||||
|  | ||||
|  | ||||
| class Status(str, Enum): | ||||
|     """Execution status.""" | ||||
|     SUCCESS = 'success' | ||||
|     FAILURE = 'failure' | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class ExecutionResult: | ||||
|     """Execution result.""" | ||||
|     status: Status | ||||
|     value: Optional[str] = None | ||||
|     msg: Optional[str] = None | ||||
|  | ||||
|  | ||||
| class IPythonInteractive(BaseAction): | ||||
|     """An interactive IPython shell for code execution. | ||||
|  | ||||
|     Args: | ||||
|         timeout (int): Upper bound of waiting time for Python script execution. | ||||
|             Defaults to ``20``. | ||||
|         max_out_len (int): maximum output length. No truncation occurs if negative. | ||||
|             Defaults to ``2048``. | ||||
|         use_signals (bool): whether signals should be used for timing function out | ||||
|             or the multiprocessing. Set to ``False`` when not running in the main | ||||
|             thread, e.g. web applications. Defaults to ``True`` | ||||
|         description (dict): The description of the action. Defaults to ``None``. | ||||
|         parser (Type[BaseParser]): The parser class to process the | ||||
|             action's inputs and outputs. Defaults to :class:`JsonParser`. | ||||
|         enable (bool): Whether the action is enabled. Defaults to ``True``. | ||||
|     """ | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         timeout: int = 30, | ||||
|         max_out_len: int = 2048, | ||||
|         use_signals: bool = True, | ||||
|         description: Optional[dict] = None, | ||||
|         parser: Type[BaseParser] = JsonParser, | ||||
|         enable: bool = True, | ||||
|     ): | ||||
|         super().__init__(description, parser, enable) | ||||
|         from IPython import InteractiveShell | ||||
|         self.timeout = timeout | ||||
|         self._executor = InteractiveShell() | ||||
|         self._highlighting = re.compile(r'\x1b\[\d{,3}(;\d{,3}){,3}m') | ||||
|         self._max_out_len = max_out_len if max_out_len >= 0 else None | ||||
|         self._use_signals = use_signals | ||||
|  | ||||
|     def reset(self): | ||||
|         """Clear the context.""" | ||||
|         self._executor.reset() | ||||
|  | ||||
|     @tool_api | ||||
|     def run(self, command: str, timeout: Optional[int] = None) -> ActionReturn: | ||||
|         """Launch an IPython Interactive Shell to execute code. | ||||
|  | ||||
|         Args: | ||||
|             command (:class:`str`): Python code snippet | ||||
|             timeout (:class:`Optional[int]`): timeout for execution. | ||||
|                 This argument only works in the main thread. Defaults to ``None``. | ||||
|         """ | ||||
|         from timeout_decorator import timeout as timer | ||||
|         tool_return = ActionReturn(args={'text': command}, type=self.name) | ||||
|         ret = ( | ||||
|             timer(timeout or self.timeout)(self.exec)(command) | ||||
|             if self._use_signals else self.exec(command)) | ||||
|         if ret.status is Status.SUCCESS: | ||||
|             tool_return.result = [{'type': 'text', 'content': ret.value}] | ||||
|             tool_return.state = ActionStatusCode.SUCCESS | ||||
|         else: | ||||
|             tool_return.errmsg = ret.msg | ||||
|             tool_return.state = ActionStatusCode.API_ERROR | ||||
|         return tool_return | ||||
|  | ||||
|     def exec(self, code: str) -> ExecutionResult: | ||||
|         """Run Python scripts in IPython shell. | ||||
|  | ||||
|         Args: | ||||
|             code (:class:`str`): code block | ||||
|  | ||||
|         Returns: | ||||
|             :py:class:`ExecutionResult`: execution result | ||||
|         """ | ||||
|         with StringIO() as io: | ||||
|             with redirect_stdout(io): | ||||
|                 ret = self._executor.run_cell(self.extract_code(code)) | ||||
|                 result = ret.result | ||||
|                 if result is not None: | ||||
|                     return ExecutionResult(Status.SUCCESS, | ||||
|                                            str(result)[:self._max_out_len]) | ||||
|             outs = io.getvalue().strip().split('\n') | ||||
|         if not outs: | ||||
|             return ExecutionResult(Status.SUCCESS, '') | ||||
|         for i, out in enumerate(outs): | ||||
|             if re.search('Error|Traceback', out, re.S): | ||||
|                 if 'TimeoutError' in out: | ||||
|                     return ExecutionResult( | ||||
|                         Status.FAILURE, | ||||
|                         msg=('The code interpreter encountered ' | ||||
|                              'an unexpected error.')) | ||||
|                 err_idx = i | ||||
|                 break | ||||
|         else: | ||||
|             return ExecutionResult(Status.SUCCESS, | ||||
|                                    outs[-1].strip()[:self._max_out_len]) | ||||
|         return ExecutionResult( | ||||
|             Status.FAILURE, | ||||
|             msg=self._highlighting.sub( | ||||
|                 '', '\n'.join(outs[err_idx:])[:self._max_out_len]), | ||||
|         ) | ||||
|  | ||||
|     async def async_exec(self, code: str) -> ExecutionResult: | ||||
|         """Asynchronously run Python scripts in IPython shell. | ||||
|  | ||||
|         Args: | ||||
|             code (:class:`str`): code block | ||||
|  | ||||
|         Returns: | ||||
|             :py:class:`ExecutionResult`: execution result | ||||
|         """ | ||||
|         with StringIO() as io: | ||||
|             with redirect_stdout(io): | ||||
|                 ret = await self._executor.run_cell_async( | ||||
|                     self.extract_code(code)) | ||||
|                 result = ret.result | ||||
|                 if result is not None: | ||||
|                     return ExecutionResult(Status.SUCCESS, | ||||
|                                            str(result)[:self._max_out_len]) | ||||
|             outs = io.getvalue().strip().split('\n') | ||||
|         if not outs: | ||||
|             return ExecutionResult(Status.SUCCESS, '') | ||||
|         for i, out in enumerate(outs): | ||||
|             if re.search('Error|Traceback', out, re.S): | ||||
|                 if 'TimeoutError' in out: | ||||
|                     return ExecutionResult( | ||||
|                         Status.FAILURE, | ||||
|                         msg=('The code interpreter encountered an ' | ||||
|                              'unexpected error.')) | ||||
|                 err_idx = i | ||||
|                 break | ||||
|         else: | ||||
|             return ExecutionResult(Status.SUCCESS, | ||||
|                                    outs[-1].strip()[:self._max_out_len]) | ||||
|         return ExecutionResult( | ||||
|             Status.FAILURE, | ||||
|             msg=self._highlighting.sub( | ||||
|                 '', '\n'.join(outs[err_idx:])[:self._max_out_len]), | ||||
|         ) | ||||
|  | ||||
|     @staticmethod | ||||
|     def extract_code(text: str) -> str: | ||||
|         """Extract Python code from markup languages. | ||||
|  | ||||
|         Args: | ||||
|             text (:class:`str`): Markdown-formatted text | ||||
|  | ||||
|         Returns: | ||||
|             :class:`str`: Python code | ||||
|         """ | ||||
|         import json5 | ||||
|  | ||||
|         # Match triple backtick blocks first | ||||
|         triple_match = re.search(r'```[^\n]*\n(.+?)```', text, re.DOTALL) | ||||
|         # Match single backtick blocks second | ||||
|         single_match = re.search(r'`([^`]*)`', text, re.DOTALL) | ||||
|         if triple_match: | ||||
|             text = triple_match.group(1) | ||||
|         elif single_match: | ||||
|             text = single_match.group(1) | ||||
|         else: | ||||
|             try: | ||||
|                 text = json5.loads(text)['code'] | ||||
|             except Exception: | ||||
|                 pass | ||||
|         # If no code blocks found, return original text | ||||
|         return text | ||||
							
								
								
									
										294
									
								
								lagent/actions/ipython_interpreter.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										294
									
								
								lagent/actions/ipython_interpreter.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,294 @@ | ||||
| # flake8: noqa: E501 | ||||
| import base64 | ||||
| import io | ||||
| import logging | ||||
| import os | ||||
| import queue | ||||
| import re | ||||
| import signal | ||||
| import sys | ||||
| import traceback | ||||
| import uuid | ||||
| from typing import Optional, Tuple, Type | ||||
|  | ||||
| from lagent.actions.base_action import BaseAction, tool_api | ||||
| from lagent.actions.parser import BaseParser, JsonParser | ||||
| from lagent.schema import ActionReturn, ActionStatusCode | ||||
|  | ||||
| START_CODE = """ | ||||
| def input(*args, **kwargs): | ||||
|     raise NotImplementedError('Python input() function is disabled.') | ||||
|  | ||||
| get_ipython().system = lambda *args: print('Assume we have this package, ! is disabled!') | ||||
| {} | ||||
| """  # noqa | ||||
|  | ||||
|  | ||||
| class TimeoutError(Exception): | ||||
|     pass | ||||
|  | ||||
|  | ||||
| class IPythonInterpreter(BaseAction): | ||||
|     """A IPython executor that can execute Python scripts in a jupyter manner. | ||||
|  | ||||
|     Args: | ||||
|         timeout (int): Upper bound of waiting time for Python script execution. | ||||
|             Defaults to 20. | ||||
|         user_data_dir (str, optional): Specified the user data directory for files | ||||
|             loading. If set to `ENV`, use `USER_DATA_DIR` environment variable. | ||||
|             Defaults to `ENV`. | ||||
|         work_dir (str, optional): Specify which directory to save output images to. | ||||
|             Defaults to ``'./work_dir/tmp_dir'``. | ||||
|         description (dict): The description of the action. Defaults to ``None``. | ||||
|         parser (Type[BaseParser]): The parser class to process the | ||||
|             action's inputs and outputs. Defaults to :class:`JsonParser`. | ||||
|         enable (bool, optional): Whether the action is enabled. Defaults to ``True``. | ||||
|     """ | ||||
|  | ||||
|     _KERNEL_CLIENTS = {} | ||||
|  | ||||
|     def __init__(self, | ||||
|                  timeout: int = 20, | ||||
|                  user_data_dir: str = 'ENV', | ||||
|                  work_dir='./work_dir/tmp_dir', | ||||
|                  description: Optional[dict] = None, | ||||
|                  parser: Type[BaseParser] = JsonParser, | ||||
|                  enable: bool = True): | ||||
|         super().__init__(description, parser, enable) | ||||
|  | ||||
|         self.timeout = timeout | ||||
|         if user_data_dir == 'ENV': | ||||
|             user_data_dir = os.environ.get('USER_DATA_DIR', '') | ||||
|  | ||||
|         if user_data_dir: | ||||
|             user_data_dir = os.path.dirname(user_data_dir) | ||||
|             user_data_dir = f"import os\nos.chdir('{user_data_dir}')" | ||||
|         self.user_data_dir = user_data_dir | ||||
|         self._initialized = False | ||||
|         self.work_dir = work_dir | ||||
|         if not os.path.exists(self.work_dir): | ||||
|             os.makedirs(self.work_dir, exist_ok=True) | ||||
|  | ||||
|     @staticmethod | ||||
|     def start_kernel(): | ||||
|         from jupyter_client import KernelManager | ||||
|  | ||||
|         # start the kernel and manager | ||||
|         km = KernelManager() | ||||
|         km.start_kernel() | ||||
|         kc = km.client() | ||||
|         return km, kc | ||||
|  | ||||
|     def initialize(self): | ||||
|         if self._initialized: | ||||
|             return | ||||
|         pid = os.getpid() | ||||
|         if pid not in self._KERNEL_CLIENTS: | ||||
|             self._KERNEL_CLIENTS[pid] = self.start_kernel() | ||||
|         self.kernel_manager, self.kernel_client = self._KERNEL_CLIENTS[pid] | ||||
|         self._initialized = True | ||||
|         self._call(START_CODE.format(self.user_data_dir), None) | ||||
|  | ||||
|     def reset(self): | ||||
|         if not self._initialized: | ||||
|             self.initialize() | ||||
|         else: | ||||
|             code = "get_ipython().run_line_magic('reset', '-f')\n" + \ | ||||
|                 START_CODE.format(self.user_data_dir) | ||||
|             self._call(code, None) | ||||
|  | ||||
|     def _call(self, | ||||
|               command: str, | ||||
|               timeout: Optional[int] = None) -> Tuple[str, bool]: | ||||
|         self.initialize() | ||||
|         command = extract_code(command) | ||||
|  | ||||
|         # check previous remaining result | ||||
|         while True: | ||||
|             try: | ||||
|                 msg = self.kernel_client.get_iopub_msg(timeout=5) | ||||
|                 msg_type = msg['msg_type'] | ||||
|                 if msg_type == 'status': | ||||
|                     if msg['content'].get('execution_state') == 'idle': | ||||
|                         break | ||||
|             except queue.Empty: | ||||
|                 # assume no result | ||||
|                 break | ||||
|  | ||||
|         self.kernel_client.execute(command) | ||||
|  | ||||
|         def _inner_call(): | ||||
|             result = '' | ||||
|             images = [] | ||||
|             succeed = True | ||||
|             image_idx = 0 | ||||
|  | ||||
|             while True: | ||||
|                 text = '' | ||||
|                 image = '' | ||||
|                 finished = False | ||||
|                 msg_type = 'error' | ||||
|                 try: | ||||
|                     msg = self.kernel_client.get_iopub_msg(timeout=20) | ||||
|                     msg_type = msg['msg_type'] | ||||
|                     if msg_type == 'status': | ||||
|                         if msg['content'].get('execution_state') == 'idle': | ||||
|                             finished = True | ||||
|                     elif msg_type == 'execute_result': | ||||
|                         text = msg['content']['data'].get('text/plain', '') | ||||
|                         if 'image/png' in msg['content']['data']: | ||||
|                             image_b64 = msg['content']['data']['image/png'] | ||||
|                             image_url = publish_image_to_local( | ||||
|                                 image_b64, self.work_dir) | ||||
|                             image_idx += 1 | ||||
|                             image = '' % (image_idx, image_url) | ||||
|  | ||||
|                     elif msg_type == 'display_data': | ||||
|                         if 'image/png' in msg['content']['data']: | ||||
|                             image_b64 = msg['content']['data']['image/png'] | ||||
|                             image_url = publish_image_to_local( | ||||
|                                 image_b64, self.work_dir) | ||||
|                             image_idx += 1 | ||||
|                             image = '' % (image_idx, image_url) | ||||
|  | ||||
|                         else: | ||||
|                             text = msg['content']['data'].get('text/plain', '') | ||||
|                     elif msg_type == 'stream': | ||||
|                         msg_type = msg['content']['name']  # stdout, stderr | ||||
|                         text = msg['content']['text'] | ||||
|                     elif msg_type == 'error': | ||||
|                         succeed = False | ||||
|                         text = escape_ansi('\n'.join( | ||||
|                             msg['content']['traceback'])) | ||||
|                         if 'M6_CODE_INTERPRETER_TIMEOUT' in text: | ||||
|                             text = f'Timeout. No response after {timeout} seconds.'  # noqa | ||||
|                 except queue.Empty: | ||||
|                     # stop current task in case break next input. | ||||
|                     self.kernel_manager.interrupt_kernel() | ||||
|                     succeed = False | ||||
|                     text = f'Timeout. No response after {timeout} seconds.' | ||||
|                     finished = True | ||||
|                 except Exception: | ||||
|                     succeed = False | ||||
|                     msg = ''.join(traceback.format_exception(*sys.exc_info())) | ||||
|                     # text = 'The code interpreter encountered an unexpected error.'  # noqa | ||||
|                     text = msg | ||||
|                     logging.warning(msg) | ||||
|                     finished = True | ||||
|                 if text: | ||||
|                     # result += f'\n\n{msg_type}:\n\n```\n{text}\n```' | ||||
|                     result += f'{text}' | ||||
|  | ||||
|                 if image: | ||||
|                     images.append(image_url) | ||||
|                 if finished: | ||||
|                     return succeed, dict(text=result, image=images) | ||||
|  | ||||
|         try: | ||||
|             if timeout: | ||||
|  | ||||
|                 def handler(signum, frame): | ||||
|                     raise TimeoutError() | ||||
|  | ||||
|                 signal.signal(signal.SIGALRM, handler) | ||||
|                 signal.alarm(timeout) | ||||
|             succeed, result = _inner_call() | ||||
|         except TimeoutError: | ||||
|             succeed = False | ||||
|             text = 'The code interpreter encountered an unexpected error.' | ||||
|             result = f'\n\nerror:\n\n```\n{text}\n```' | ||||
|         finally: | ||||
|             if timeout: | ||||
|                 signal.alarm(0) | ||||
|  | ||||
|         # result = result.strip('\n') | ||||
|         return succeed, result | ||||
|  | ||||
|     @tool_api | ||||
|     def run(self, command: str, timeout: Optional[int] = None) -> ActionReturn: | ||||
|         r"""When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files. Internet access for this session is disabled. Do not make external web requests or API calls as they will fail. | ||||
|  | ||||
|         Args: | ||||
|             command (:class:`str`): Python code | ||||
|             timeout (:class:`Optional[int]`): Upper bound of waiting time for Python script execution. | ||||
|         """ | ||||
|         tool_return = ActionReturn(url=None, args=None, type=self.name) | ||||
|         tool_return.args = dict(text=command) | ||||
|         succeed, result = self._call(command, timeout) | ||||
|         if succeed: | ||||
|             text = result['text'] | ||||
|             image = result.get('image', []) | ||||
|             resp = [dict(type='text', content=text)] | ||||
|             if image: | ||||
|                 resp.extend([dict(type='image', content=im) for im in image]) | ||||
|             tool_return.result = resp | ||||
|             # tool_return.result = dict( | ||||
|             #     text=result['text'], image=result.get('image', [])[0]) | ||||
|             tool_return.state = ActionStatusCode.SUCCESS | ||||
|         else: | ||||
|             tool_return.errmsg = result.get('text', '') if isinstance( | ||||
|                 result, dict) else result | ||||
|             tool_return.state = ActionStatusCode.API_ERROR | ||||
|         return tool_return | ||||
|  | ||||
|  | ||||
| def extract_code(text): | ||||
|     import json5 | ||||
|  | ||||
|     # Match triple backtick blocks first | ||||
|     triple_match = re.search(r'```[^\n]*\n(.+?)```', text, re.DOTALL) | ||||
|     # Match single backtick blocks second | ||||
|     single_match = re.search(r'`([^`]*)`', text, re.DOTALL) | ||||
|     if triple_match: | ||||
|         text = triple_match.group(1) | ||||
|     elif single_match: | ||||
|         text = single_match.group(1) | ||||
|     else: | ||||
|         try: | ||||
|             text = json5.loads(text)['code'] | ||||
|         except Exception: | ||||
|             pass | ||||
|     # If no code blocks found, return original text | ||||
|     return text | ||||
|  | ||||
|  | ||||
| def escape_ansi(line): | ||||
|     ansi_escape = re.compile(r'(?:\x1B[@-_]|[\x80-\x9F])[0-?]*[ -/]*[@-~]') | ||||
|     return ansi_escape.sub('', line) | ||||
|  | ||||
|  | ||||
| def publish_image_to_local(image_base64: str, work_dir='./work_dir/tmp_dir'): | ||||
|     import PIL.Image | ||||
|     image_file = str(uuid.uuid4()) + '.png' | ||||
|     local_image_file = os.path.join(work_dir, image_file) | ||||
|  | ||||
|     png_bytes = base64.b64decode(image_base64) | ||||
|     assert isinstance(png_bytes, bytes) | ||||
|     bytes_io = io.BytesIO(png_bytes) | ||||
|     PIL.Image.open(bytes_io).save(local_image_file, 'png') | ||||
|  | ||||
|     return local_image_file | ||||
|  | ||||
|  | ||||
| # local test for code interpreter | ||||
| def get_multiline_input(hint): | ||||
|     print(hint) | ||||
|     print('// Press ENTER to make a new line. Press CTRL-D to end input.') | ||||
|     lines = [] | ||||
|     while True: | ||||
|         try: | ||||
|             line = input() | ||||
|         except EOFError:  # CTRL-D | ||||
|             break | ||||
|         lines.append(line) | ||||
|     print('// Input received.') | ||||
|     if lines: | ||||
|         return '\n'.join(lines) | ||||
|     else: | ||||
|         return '' | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     code_interpreter = IPythonInterpreter() | ||||
|     while True: | ||||
|         print(code_interpreter(get_multiline_input('Enter python code:'))) | ||||
| @@ -1,56 +0,0 @@ | ||||
| from typing import Optional, Union | ||||
|  | ||||
| from lagent.llms.base_api import BaseAPIModel | ||||
| from lagent.llms.base_llm import BaseModel | ||||
| from lagent.schema import ActionReturn, ActionStatusCode | ||||
| from .base_action import BaseAction | ||||
|  | ||||
| DEFAULT_DESCRIPTION = """一个像你一样的大语言预训练模型,当你需要获得一些常识或简单世界知识时可以问它。 | ||||
| 当你很有把握自己直接解决问题时可以优先使用它。输入应该是一个询问语句, 且每个问题尽可能简单。 | ||||
| """ | ||||
|  | ||||
|  | ||||
| class LLMQA(BaseAction): | ||||
|     """An LLM Wrapper as BaseAction type. | ||||
|  | ||||
|     Args: | ||||
|         llm (BaseModel or BaseAPIModel): a LLM service which can chat. | ||||
|         description (str): The description of the action. Defaults to | ||||
|             None. | ||||
|         name (str, optional): The name of the action. If None, the name will | ||||
|             be class name. Defaults to None. | ||||
|         enable (bool, optional): Whether the action is enabled. Defaults to | ||||
|             True. | ||||
|         disable_description (str, optional): The description of the action when | ||||
|             it is disabled. Defaults to None. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, | ||||
|                  llm: Union[BaseModel, BaseAPIModel], | ||||
|                  description: str = DEFAULT_DESCRIPTION, | ||||
|                  name: Optional[str] = None, | ||||
|                  enable: bool = True, | ||||
|                  disable_description: Optional[str] = None) -> None: | ||||
|         super().__init__(description, name, enable, disable_description) | ||||
|  | ||||
|         self._llm = llm | ||||
|  | ||||
|     def __call__(self, query: str) -> ActionReturn: | ||||
|         """Return the QA response. | ||||
|  | ||||
|         Args: | ||||
|             query (str): The query content. | ||||
|  | ||||
|         Returns: | ||||
|             ActionReturn: The action return. | ||||
|         """ | ||||
|  | ||||
|         tool_return = ActionReturn(url=None, args=None) | ||||
|         try: | ||||
|             response = self._llm.generate_from_template(query, 512) | ||||
|             tool_return.result = dict(text=str(response)) | ||||
|             tool_return.state = ActionStatusCode.SUCCESS | ||||
|         except Exception as e: | ||||
|             tool_return.result = dict(text=str(e)) | ||||
|             tool_return.state = ActionStatusCode.API_ERROR | ||||
|         return tool_return | ||||
							
								
								
									
										143
									
								
								lagent/actions/parser.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										143
									
								
								lagent/actions/parser.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,143 @@ | ||||
| import json | ||||
| import re | ||||
| from ast import literal_eval | ||||
| from typing import Any, List, Union | ||||
|  | ||||
|  | ||||
| class ParseError(Exception): | ||||
|     """Parsing exception class.""" | ||||
|  | ||||
|     def __init__(self, err_msg: str): | ||||
|         self.err_msg = err_msg | ||||
|  | ||||
|  | ||||
| class BaseParser: | ||||
|     """Base parser to process inputs and outputs of actions. | ||||
|  | ||||
|     Args: | ||||
|         action (:class:`BaseAction`): action to validate | ||||
|  | ||||
|     Attributes: | ||||
|         PARAMETER_DESCRIPTION (:class:`str`): declare the input format which | ||||
|             LLMs should follow when generating arguments for decided tools. | ||||
|     """ | ||||
|  | ||||
|     PARAMETER_DESCRIPTION: str = '' | ||||
|  | ||||
|     def __init__(self, action): | ||||
|         self.action = action | ||||
|         self._api2param = {} | ||||
|         self._api2required = {} | ||||
|         # perform basic argument validation | ||||
|         if action.description: | ||||
|             for api in action.description.get('api_list', | ||||
|                                               [action.description]): | ||||
|                 name = (f'{action.name}.{api["name"]}' | ||||
|                         if self.action.is_toolkit else api['name']) | ||||
|                 required_parameters = set(api['required']) | ||||
|                 all_parameters = {j['name'] for j in api['parameters']} | ||||
|                 if not required_parameters.issubset(all_parameters): | ||||
|                     raise ValueError( | ||||
|                         f'unknown parameters for function "{name}": ' | ||||
|                         f'{required_parameters - all_parameters}') | ||||
|                 if self.PARAMETER_DESCRIPTION: | ||||
|                     api['parameter_description'] = self.PARAMETER_DESCRIPTION | ||||
|                 api_name = api['name'] if self.action.is_toolkit else 'run' | ||||
|                 self._api2param[api_name] = api['parameters'] | ||||
|                 self._api2required[api_name] = api['required'] | ||||
|  | ||||
|     def parse_inputs(self, inputs: str, name: str = 'run') -> dict: | ||||
|         """Parse inputs LLMs generate for the action. | ||||
|  | ||||
|         Args: | ||||
|             inputs (:class:`str`): input string extracted from responses | ||||
|  | ||||
|         Returns: | ||||
|             :class:`dict`: processed input | ||||
|         """ | ||||
|         inputs = {self._api2param[name][0]['name']: inputs} | ||||
|         return inputs | ||||
|  | ||||
|     def parse_outputs(self, outputs: Any) -> List[dict]: | ||||
|         """Parser outputs returned by the action. | ||||
|  | ||||
|         Args: | ||||
|             outputs (:class:`Any`): raw output of the action | ||||
|  | ||||
|         Returns: | ||||
|             :class:`List[dict]`: processed output of which each member is a | ||||
|                 dictionary with two keys - 'type' and 'content'. | ||||
|         """ | ||||
|         if isinstance(outputs, dict): | ||||
|             outputs = json.dumps(outputs, ensure_ascii=False) | ||||
|         elif not isinstance(outputs, str): | ||||
|             outputs = str(outputs) | ||||
|         return [{'type': 'text', 'content': outputs}] | ||||
|  | ||||
|  | ||||
| class JsonParser(BaseParser): | ||||
|     """Json parser to convert input string into a dictionary. | ||||
|  | ||||
|     Args: | ||||
|         action (:class:`BaseAction`): action to validate | ||||
|     """ | ||||
|  | ||||
|     PARAMETER_DESCRIPTION = ( | ||||
|         'If you call this tool, you must pass arguments in ' | ||||
|         'the JSON format {key: value}, where the key is the parameter name.') | ||||
|  | ||||
|     def parse_inputs(self, | ||||
|                      inputs: Union[str, dict], | ||||
|                      name: str = 'run') -> dict: | ||||
|         if not isinstance(inputs, dict): | ||||
|             try: | ||||
|                 match = re.search(r'^\s*(```json\n)?(.*)\n```\s*$', inputs, | ||||
|                                   re.S) | ||||
|                 if match: | ||||
|                     inputs = match.group(2).strip() | ||||
|                 inputs = json.loads(inputs) | ||||
|             except json.JSONDecodeError as exc: | ||||
|                 raise ParseError(f'invalid json format: {inputs}') from exc | ||||
|         input_keys = set(inputs) | ||||
|         all_keys = {param['name'] for param in self._api2param[name]} | ||||
|         if not input_keys.issubset(all_keys): | ||||
|             raise ParseError(f'unknown arguments: {input_keys - all_keys}') | ||||
|         required_keys = set(self._api2required[name]) | ||||
|         if not input_keys.issuperset(required_keys): | ||||
|             raise ParseError( | ||||
|                 f'missing required arguments: {required_keys - input_keys}') | ||||
|         return inputs | ||||
|  | ||||
|  | ||||
| class TupleParser(BaseParser): | ||||
|     """Tuple parser to convert input string into a tuple. | ||||
|  | ||||
|     Args: | ||||
|         action (:class:`BaseAction`): action to validate | ||||
|     """ | ||||
|  | ||||
|     PARAMETER_DESCRIPTION = ( | ||||
|         'If you call this tool, you must pass arguments in the tuple format ' | ||||
|         'like (arg1, arg2, arg3), and the arguments are ordered.') | ||||
|  | ||||
|     def parse_inputs(self, | ||||
|                      inputs: Union[str, tuple], | ||||
|                      name: str = 'run') -> dict: | ||||
|         if not isinstance(inputs, tuple): | ||||
|             try: | ||||
|                 inputs = literal_eval(inputs) | ||||
|             except Exception as exc: | ||||
|                 raise ParseError(f'invalid tuple format: {inputs}') from exc | ||||
|         if len(inputs) < len(self._api2required[name]): | ||||
|             raise ParseError( | ||||
|                 f'API takes {len(self._api2required[name])} required positional ' | ||||
|                 f'arguments but {len(inputs)} were given') | ||||
|         if len(inputs) > len(self._api2param[name]): | ||||
|             raise ParseError( | ||||
|                 f'API takes {len(self._api2param[name])} positional arguments ' | ||||
|                 f'but {len(inputs)} were given') | ||||
|         inputs = { | ||||
|             self._api2param[name][i]['name']: item | ||||
|             for i, item in enumerate(inputs) | ||||
|         } | ||||
|         return inputs | ||||
							
								
								
									
										158
									
								
								lagent/actions/ppt.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										158
									
								
								lagent/actions/ppt.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,158 @@ | ||||
| from typing import Dict, Optional, Type | ||||
|  | ||||
| from lagent.actions.base_action import BaseAction, tool_api | ||||
| from lagent.actions.parser import BaseParser, JsonParser | ||||
|  | ||||
| THEME_MAPPING = { | ||||
|     'Default': { | ||||
|         'template': None, | ||||
|         'title': 'Title Slide', | ||||
|         'single': 'Title and Content', | ||||
|         'two': 'Two Content', | ||||
|     } | ||||
| } | ||||
|  | ||||
|  | ||||
| class PPT(BaseAction): | ||||
|     """Plugin to create ppt slides with text, paragraph, images in good looking styles.""" | ||||
|  | ||||
|     def __init__(self, | ||||
|                  theme_mapping: Optional[Dict[str, dict]] = None, | ||||
|                  description: Optional[dict] = None, | ||||
|                  parser: Type[BaseParser] = JsonParser, | ||||
|                  enable: bool = True): | ||||
|         super().__init__(description, parser, enable) | ||||
|         self.theme_mapping = theme_mapping or THEME_MAPPING | ||||
|         self.pointer = None | ||||
|         self.location = None | ||||
|  | ||||
|     @tool_api(explode_return=True) | ||||
|     def create_file(self, theme: str, abs_location: str) -> dict: | ||||
|         """Create a pptx file with specific themes. | ||||
|  | ||||
|         Args: | ||||
|             theme (:class:`str`): the theme used. The value should be one of ['Default']. | ||||
|             abs_location (:class:`str`): the ppt file's absolute location | ||||
|  | ||||
|         Returns: | ||||
|             :class:`dict`: operation status | ||||
|                 * status: the result of the execution | ||||
|         """ | ||||
|         from pptx import Presentation | ||||
|         self.location = abs_location | ||||
|         try: | ||||
|             self.pointer = Presentation(self.theme_mapping[theme]['template']) | ||||
|             self.pointer.slide_master.name = theme | ||||
|             # print('created') | ||||
|         except Exception as e: | ||||
|             print(e) | ||||
|         return dict(status='created a ppt file.') | ||||
|  | ||||
|     @tool_api(explode_return=True) | ||||
|     def add_first_page(self, title: str, subtitle: str) -> dict: | ||||
|         """Add the first page of ppt. | ||||
|  | ||||
|         Args: | ||||
|             title (:class:`str`): the title of ppt | ||||
|             subtitle (:class:`str`): the subtitle of ppt | ||||
|  | ||||
|         Returns: | ||||
|             :class:`dict`: operation status | ||||
|                 * status: the result of the execution | ||||
|         """ | ||||
|         layout_name = self.theme_mapping[ | ||||
|             self.pointer.slide_master.name]['title'] | ||||
|         layout = next(i for i in self.pointer.slide_master.slide_layouts | ||||
|                       if i.name == layout_name) | ||||
|         slide = self.pointer.slides.add_slide(layout) | ||||
|         ph_title, ph_subtitle = slide.placeholders | ||||
|         ph_title.text = title | ||||
|         if subtitle: | ||||
|             ph_subtitle.text = subtitle | ||||
|         return dict(status='added page') | ||||
|  | ||||
|     @tool_api(explode_return=True) | ||||
|     def add_text_page(self, title: str, bullet_items: str) -> dict: | ||||
|         """Add text page of ppt. | ||||
|  | ||||
|         Args: | ||||
|             title (:class:`str`): the title of the page | ||||
|             bullet_items (:class:`str`): bullet_items should be string, for multiple bullet items, please use [SPAN] to separate them. | ||||
|  | ||||
|         Returns: | ||||
|             :class:`dict`: operation status | ||||
|                 * status: the result of the execution | ||||
|         """ | ||||
|         layout_name = self.theme_mapping[ | ||||
|             self.pointer.slide_master.name]['single'] | ||||
|         layout = next(i for i in self.pointer.slide_master.slide_layouts | ||||
|                       if i.name == layout_name) | ||||
|         slide = self.pointer.slides.add_slide(layout) | ||||
|         ph_title, ph_body = slide.placeholders | ||||
|         ph_title.text = title | ||||
|         ph = ph_body | ||||
|         tf = ph.text_frame | ||||
|         for i, item in enumerate(bullet_items.split('[SPAN]')): | ||||
|             if i == 0: | ||||
|                 p = tf.paragraphs[0] | ||||
|             else: | ||||
|                 p = tf.add_paragraph() | ||||
|             p.text = item.strip() | ||||
|             p.level = 0 | ||||
|         return dict(status='added page') | ||||
|  | ||||
|     @tool_api(explode_return=True) | ||||
|     def add_text_image_page(self, title: str, bullet_items: str, | ||||
|                             image: str) -> dict: | ||||
|         """Add a text page with one image. Image should be a path. | ||||
|  | ||||
|         Args: | ||||
|             title (:class:`str`): the title of the page | ||||
|             bullet_items (:class:`str`): bullet_items should be string, for multiple bullet items, please use [SPAN] to separate them. | ||||
|             image (:class:`str`): the path of the image | ||||
|  | ||||
|         Returns: | ||||
|             :class:`dict`: operation status | ||||
|                 * status: the result of the execution | ||||
|         """ | ||||
|         from PIL import Image | ||||
|         layout_name = self.theme_mapping[self.pointer.slide_master.name]['two'] | ||||
|         layout = next(i for i in self.pointer.slide_master.slide_layouts | ||||
|                       if i.name == layout_name) | ||||
|         slide = self.pointer.slides.add_slide(layout) | ||||
|         ph_title, ph_body1, ph_body2 = slide.placeholders | ||||
|         ph_title.text = title | ||||
|         ph = ph_body2 | ||||
|         image = Image.open(image) | ||||
|         image_pil = image.to_pil() | ||||
|         left = ph.left | ||||
|         width = ph.width | ||||
|         height = int(width / image_pil.width * image_pil.height) | ||||
|         top = (ph.top + (ph.top + ph.height)) // 2 - height // 2 | ||||
|         slide.shapes.add_picture(image.to_path(), left, top, width, height) | ||||
|  | ||||
|         ph = ph_body1 | ||||
|         tf = ph.text_frame | ||||
|         for i, item in enumerate(bullet_items.split('[SPAN]')): | ||||
|             if i == 0: | ||||
|                 p = tf.paragraphs[0] | ||||
|             else: | ||||
|                 p = tf.add_paragraph() | ||||
|             p.text = item.strip() | ||||
|             p.level = 0 | ||||
|  | ||||
|         return dict(status='added page') | ||||
|  | ||||
|     @tool_api(explode_return=True) | ||||
|     def submit_file(self) -> dict: | ||||
|         """When all steps done, YOU MUST use submit_file() to submit your work. | ||||
|  | ||||
|         Returns: | ||||
|             :class:`dict`: operation status | ||||
|                 * status: the result of the execution | ||||
|         """ | ||||
|         # file_path = os.path.join(self.CACHE_DIR, f'{self._return_timestamp()}.pptx') | ||||
|         # self.pointer.save(file_path) | ||||
|         # retreival_url = upload_file(file_path) | ||||
|         self.pointer.save(self.location) | ||||
|         return dict(status=f'submitted. view ppt at {self.location}') | ||||
| @@ -1,11 +1,11 @@ | ||||
| # flake8: noqa: E501 | ||||
| import copy | ||||
| import io | ||||
| from contextlib import redirect_stdout | ||||
| from typing import Any, Optional | ||||
| from typing import Any, Optional, Type | ||||
|  | ||||
| from func_timeout import FunctionTimedOut, func_set_timeout | ||||
|  | ||||
| from lagent.actions.base_action import BaseAction | ||||
| from lagent.actions.base_action import BaseAction, tool_api | ||||
| from lagent.actions.parser import BaseParser, JsonParser | ||||
| from lagent.schema import ActionReturn, ActionStatusCode | ||||
|  | ||||
|  | ||||
| @@ -29,72 +29,72 @@ class GenericRuntime: | ||||
|         return eval(expr, self._global_vars) | ||||
|  | ||||
|  | ||||
| DEFAULT_DESCRIPTION = """用来执行Python代码。代码必须是一个函数, | ||||
| 函数名必须得是 'solution',代码对应你的思考过程。代码实例格式如下: | ||||
| ```python | ||||
| # import 依赖包 | ||||
| import xxx | ||||
| def solution(): | ||||
|     # 初始化一些变量 | ||||
|     variable_names_with_real_meaning = xxx | ||||
|     # 步骤一 | ||||
|     mid_variable = func(variable_names_with_real_meaning) | ||||
|     # 步骤 x | ||||
|     mid_variable = func(mid_variable) | ||||
|     # 最后结果 | ||||
|     final_answer =  func(mid_variable) | ||||
|     return final_answer | ||||
| ```""" | ||||
|  | ||||
|  | ||||
| class PythonInterpreter(BaseAction): | ||||
|     """A Python executor that can execute Python scripts. | ||||
|  | ||||
|     Args: | ||||
|         description (str): The description of the action. Defaults to | ||||
|             DEFAULT_DESCRIPTION. | ||||
|         answer_symbol (str, Optional): the answer symbol from LLM | ||||
|         answer_symbol (str, Optional): the answer symbol from LLM. Defaults to ``None``. | ||||
|         answer_expr (str, Optional): the answer function name of the Python | ||||
|             script. Default to 'solution()'. | ||||
|         answer_from_stdout (boolean): whether the execution results is from | ||||
|             stdout. | ||||
|         name (str, optional): The name of the action. If None, the name will | ||||
|             be class nameDefaults to None. | ||||
|             script. Defaults to ``'solution()'``. | ||||
|         answer_from_stdout (boolean, Optional): whether the execution results is from | ||||
|             stdout. Defaults to ``False``. | ||||
|         timeout (int, Optional): Upper bound of waiting time for Python script execution. | ||||
|             Defaults to ``20``. | ||||
|         description (dict, Optional): The description of the action. Defaults to ``None``. | ||||
|         parser (Type[BaseParser]): The parser class to process the | ||||
|             action's inputs and outputs. Defaults to :class:`JsonParser`. | ||||
|         enable (bool, optional): Whether the action is enabled. Defaults to | ||||
|             True. | ||||
|         disable_description (str, optional): The description of the action when | ||||
|             it is disabled. Defaults to None. | ||||
|         timeout (int): Upper bound of waiting time for Python script execution. | ||||
|             ``True``. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, | ||||
|                  description: str = DEFAULT_DESCRIPTION, | ||||
|                  answer_symbol: Optional[str] = None, | ||||
|                  answer_expr: Optional[str] = 'solution()', | ||||
|                  answer_from_stdout: bool = False, | ||||
|                  name: Optional[str] = None, | ||||
|                  enable: bool = True, | ||||
|                  disable_description: Optional[str] = None, | ||||
|                  timeout: int = 20) -> None: | ||||
|         super().__init__(description, name, enable, disable_description) | ||||
|  | ||||
|                  timeout: int = 20, | ||||
|                  description: Optional[dict] = None, | ||||
|                  parser: Type[BaseParser] = JsonParser, | ||||
|                  enable: bool = True) -> None: | ||||
|         super().__init__(description, parser, enable) | ||||
|         self.answer_symbol = answer_symbol | ||||
|         self.answer_expr = answer_expr | ||||
|         self.answer_from_stdout = answer_from_stdout | ||||
|         self.timeout = timeout | ||||
|  | ||||
|     def __call__(self, command: str) -> ActionReturn: | ||||
|     @tool_api | ||||
|     def run(self, command: str) -> ActionReturn: | ||||
|         """用来执行Python代码。代码必须是一个函数,函数名必须得是 'solution',代码对应你的思考过程。代码实例格式如下: | ||||
|  | ||||
|         ```python | ||||
|         # import 依赖包 | ||||
|         import xxx | ||||
|         def solution(): | ||||
|             # 初始化一些变量 | ||||
|             variable_names_with_real_meaning = xxx | ||||
|             # 步骤一 | ||||
|             mid_variable = func(variable_names_with_real_meaning) | ||||
|             # 步骤 x | ||||
|             mid_variable = func(mid_variable) | ||||
|             # 最后结果 | ||||
|             final_answer =  func(mid_variable) | ||||
|             return final_answer | ||||
|         ``` | ||||
|  | ||||
|         Args: | ||||
|             command (:class:`str`): Python code snippet | ||||
|         """ | ||||
|         from func_timeout import FunctionTimedOut, func_set_timeout | ||||
|         self.runtime = GenericRuntime() | ||||
|         try: | ||||
|             tool_return = func_set_timeout(self.timeout)(self._call)(command) | ||||
|         except FunctionTimedOut as e: | ||||
|             tool_return = ActionReturn(url=None, args=None, type=self.name) | ||||
|             tool_return = ActionReturn(type=self.name) | ||||
|             tool_return.errmsg = repr(e) | ||||
|             tool_return.state = ActionStatusCode.API_ERROR | ||||
|         return tool_return | ||||
|  | ||||
|     def _call(self, command: str) -> ActionReturn: | ||||
|         tool_return = ActionReturn(url=None, args=None, type=self.name) | ||||
|         tool_return = ActionReturn(type=self.name) | ||||
|         try: | ||||
|             if '```python' in command: | ||||
|                 command = command.split('```python')[1].split('```')[0] | ||||
| @@ -124,7 +124,7 @@ class PythonInterpreter(BaseAction): | ||||
|             tool_return.state = ActionStatusCode.API_ERROR | ||||
|             return tool_return | ||||
|         try: | ||||
|             tool_return.result = dict(text=str(res)) | ||||
|             tool_return.result = [dict(type='text', content=str(res))] | ||||
|             tool_return.state = ActionStatusCode.SUCCESS | ||||
|         except Exception as e: | ||||
|             tool_return.errmsg = repr(e) | ||||
|   | ||||
| @@ -1,4 +1,5 @@ | ||||
| from .autogpt import *  # noqa: F401, F403 | ||||
| from .base_agent import *  # noqa: F401, F403 | ||||
| from .internlm2_agent import *  # noqa: F401, F403 | ||||
| from .react import *  # noqa: F401, F403 | ||||
| from .rewoo import *  # noqa: F401, F403 | ||||
|   | ||||
| @@ -1,6 +1,5 @@ | ||||
| # flake8: noqa | ||||
| import ast | ||||
| import copy | ||||
| import platform | ||||
| from typing import Dict, List, Optional, Tuple, Union | ||||
|  | ||||
| @@ -220,21 +219,20 @@ class AutoGPTProtocol: | ||||
|             dict(role='user', content=self.triggering_prompt)) | ||||
|         return formatted_data | ||||
|  | ||||
|     def format_response(self, action_return): | ||||
|         """format the final response at current step. | ||||
|     def format_response(self, action_return) -> dict: | ||||
|         """Format the final response at current step. | ||||
|  | ||||
|         Args: | ||||
|             action_return (ActionReturn): return value of the current action. | ||||
|  | ||||
|         Returns: | ||||
|             str: the final response at current step. | ||||
|             dict: the final response at current step. | ||||
|         """ | ||||
|         if action_return.state == ActionStatusCode.SUCCESS: | ||||
|             response = action_return.result['text'] | ||||
|             response = f'Command {action_return.type} returned: {response}' | ||||
|             response = f'Command {action_return.type} returned: {response.format_result()}' | ||||
|         else: | ||||
|             response = action_return.errmsg | ||||
|         return response | ||||
|         return dict(role='system', content=response) | ||||
|  | ||||
|  | ||||
| class AutoGPT(BaseAgent): | ||||
| @@ -261,30 +259,26 @@ class AutoGPT(BaseAgent): | ||||
|         super().__init__( | ||||
|             llm=llm, action_executor=action_executor, protocol=protocol) | ||||
|  | ||||
|     def chat(self, goal: str) -> AgentReturn: | ||||
|         self._inner_history = [] | ||||
|     def chat(self, goal: str, **kwargs) -> AgentReturn: | ||||
|         inner_history = [] | ||||
|         agent_return = AgentReturn() | ||||
|         default_response = 'Sorry that I cannot answer your question.' | ||||
|         for _ in range(self.max_turn): | ||||
|             prompt = self._protocol.format( | ||||
|                 goal=goal, | ||||
|                 inner_history=self._inner_history, | ||||
|                 inner_history=inner_history, | ||||
|                 action_executor=self._action_executor) | ||||
|             response = self._llm.generate_from_template(prompt, 512) | ||||
|             self._inner_history.append( | ||||
|                 dict(role='assistant', content=response)) | ||||
|             response = self._llm.chat(prompt, **kwargs) | ||||
|             inner_history.append(dict(role='assistant', content=response)) | ||||
|             action, action_input = self._protocol.parse( | ||||
|                 response, self._action_executor) | ||||
|             action_return: ActionReturn = self._action_executor( | ||||
|                 action, action_input) | ||||
|             agent_return.actions.append(action_return) | ||||
|             if action_return.type == self._action_executor.finish_action.name: | ||||
|                 agent_return.response = action_return.result['text'] | ||||
|                 agent_return.response = action_return.format_result() | ||||
|                 return agent_return | ||||
|             self._inner_history.append( | ||||
|                 dict( | ||||
|                     role='system', | ||||
|                     content=self._protocol.format_response(action_return))) | ||||
|         agent_return.inner_steps = copy.deepcopy(self._inner_history) | ||||
|             inner_history.append(self._protocol.format_response(action_return)) | ||||
|         agent_return.inner_steps = inner_history | ||||
|         agent_return.response = default_response | ||||
|         return agent_return | ||||
|   | ||||
| @@ -1,5 +1,3 @@ | ||||
| from typing import List | ||||
|  | ||||
| from lagent.actions import ActionExecutor | ||||
| from lagent.actions.base_action import BaseAction | ||||
| from lagent.llms.base_llm import BaseModel | ||||
| @@ -19,8 +17,6 @@ class BaseAgent: | ||||
|  | ||||
|     def __init__(self, llm: BaseModel, action_executor: ActionExecutor, | ||||
|                  protocol: object) -> None: | ||||
|  | ||||
|         self._session_history = [] | ||||
|         self._llm = llm | ||||
|         self._action_executor = action_executor | ||||
|         self._protocol = protocol | ||||
| @@ -41,9 +37,5 @@ class BaseAgent: | ||||
|         """ | ||||
|         self._action_executor.del_action(name) | ||||
|  | ||||
|     def chat(self, message: str) -> AgentReturn: | ||||
|     def chat(self, message: str, **kwargs) -> AgentReturn: | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     @property | ||||
|     def session_history(self) -> List: | ||||
|         return self._session_history | ||||
|   | ||||
							
								
								
									
										368
									
								
								lagent/agents/internlm2_agent.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										368
									
								
								lagent/agents/internlm2_agent.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,368 @@ | ||||
| import json | ||||
| import logging | ||||
| from copy import deepcopy | ||||
| from typing import Dict, List, Optional, Union | ||||
|  | ||||
| from lagent.actions import ActionExecutor | ||||
| from lagent.agents.base_agent import BaseAgent | ||||
| from lagent.llms import BaseAPIModel, BaseModel | ||||
| from lagent.schema import ActionReturn, ActionStatusCode, AgentReturn, AgentStatusCode, ModelStatusCode  # noqa: E501 | ||||
|  | ||||
| API_PREFIX = ( | ||||
|     "This is the subfunction for tool '{tool_name}', you can use this tool. " | ||||
|     'The description of this function is: \n{description}') | ||||
|  | ||||
| META_CN = ('当开启工具以及代码时,根据需求选择合适的工具进行调用') | ||||
|  | ||||
| INTERPRETER_CN = ('你现在已经能够在一个有状态的 Jupyter 笔记本环境中运行 Python 代码。' | ||||
|                   '当你向 python 发送含有 Python 代码的消息时,它将在该环境中执行。' | ||||
|                   '这个工具适用于多种场景,如数据分析或处理(包括数据操作、统计分析、图表绘制),' | ||||
|                   '复杂的计算问题(解决数学和物理难题),编程示例(理解编程概念或特性),' | ||||
|                   '文本处理和分析(比如文本解析和自然语言处理),' | ||||
|                   '机器学习和数据科学(用于展示模型训练和数据可视化),' | ||||
|                   '以及文件操作和数据导入(处理CSV、JSON等格式的文件)。') | ||||
|  | ||||
| PLUGIN_CN = ('你可以使用如下工具:' | ||||
|              '\n{prompt}\n' | ||||
|              '如果你已经获得足够信息,请直接给出答案. 避免不必要的工具调用! ' | ||||
|              '同时注意你可以使用的工具,不要随意捏造!') | ||||
|  | ||||
|  | ||||
| class Internlm2Protocol: | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         meta_prompt: str = META_CN, | ||||
|         interpreter_prompt: str = INTERPRETER_CN, | ||||
|         plugin_prompt: str = PLUGIN_CN, | ||||
|         few_shot: Optional[List] = None, | ||||
|         language: Dict = dict( | ||||
|             begin='', | ||||
|             end='', | ||||
|             belong='assistant', | ||||
|         ), | ||||
|         tool: Dict = dict( | ||||
|             begin='{start_token}{name}\n', | ||||
|             start_token='<|action_start|>', | ||||
|             name_map=dict(plugin='<|plugin|>', interpreter='<|interpreter|>'), | ||||
|             belong='assistant', | ||||
|             end='<|action_end|>\n', | ||||
|         ), | ||||
|         execute: Dict = dict( | ||||
|             role='execute', begin='', end='', fallback_role='environment'), | ||||
|     ) -> None: | ||||
|         self.meta_prompt = meta_prompt | ||||
|         self.interpreter_prompt = interpreter_prompt | ||||
|         self.plugin_prompt = plugin_prompt | ||||
|         self.roles_cfg = dict(tool=tool, language=language) | ||||
|         self.language = language | ||||
|         self.execute = execute | ||||
|         self.tool = tool | ||||
|         self.few_shot = few_shot | ||||
|  | ||||
|     def format_sub_role(self, messages: List[Dict]) -> List[Dict]: | ||||
|  | ||||
|         def format_interpreter(message): | ||||
|             if isinstance(message['content'], dict): | ||||
|                 # assert message['content']['name'] == 'IPythonInterpreter' | ||||
|                 return dict( | ||||
|                     role=message['role'], | ||||
|                     name=message['name'], | ||||
|                     content=message['content']['parameters']['command']) | ||||
|             else: | ||||
|                 return message | ||||
|  | ||||
|         def format_plugin(message): | ||||
|             if isinstance(message['content'], dict): | ||||
|                 return dict( | ||||
|                     role=message['role'], | ||||
|                     name=message['name'], | ||||
|                     content=json.dumps(message['content'])) | ||||
|             else: | ||||
|                 return message | ||||
|  | ||||
|         new_message = list() | ||||
|         for message in messages: | ||||
|             if message['role'] in [ | ||||
|                     'assistant', 'user', 'system', 'environment' | ||||
|             ]: | ||||
|                 new_message.append(message) | ||||
|                 continue | ||||
|             role_cfg = self.roles_cfg[message['role']] | ||||
|             begin = role_cfg['begin'] | ||||
|             if message['role'] == 'tool': | ||||
|                 if message['name'] == 'interpreter': | ||||
|                     message = format_interpreter(message) | ||||
|                 elif message['name'] == 'plugin': | ||||
|                     message = format_plugin(message) | ||||
|                 else: | ||||
|                     raise NotImplementedError | ||||
|                 begin = role_cfg['begin'].format( | ||||
|                     start_token=role_cfg.get('start_token', ''), | ||||
|                     name=role_cfg.get('name_map', {}).get(message['name'], '')) | ||||
|             new_content = begin + message['content'] + role_cfg['end'] | ||||
|             if role_cfg.get('fallback_role'): | ||||
|                 new_message.append( | ||||
|                     dict(role=role_cfg['fallback_role'], content=new_content)) | ||||
|             elif role_cfg.get('belong'): | ||||
|                 if new_message[-1]['role'] != role_cfg.get('belong'): | ||||
|                     new_message.append( | ||||
|                         dict(role=role_cfg.get('belong'), content=new_content)) | ||||
|                 else: | ||||
|                     new_message[-1]['content'] += new_content | ||||
|             else: | ||||
|                 new_message.append( | ||||
|                     dict(role=message['role'], content=new_content)) | ||||
|  | ||||
|         return new_message | ||||
|  | ||||
|     def format(self, | ||||
|                inner_step: List[Dict], | ||||
|                plugin_executor: ActionExecutor = None, | ||||
|                interpreter_executor: ActionExecutor = None, | ||||
|                **kwargs) -> list: | ||||
|         formatted = [] | ||||
|         if self.meta_prompt: | ||||
|             formatted.append(dict(role='system', content=self.meta_prompt)) | ||||
|         if interpreter_executor and self.interpreter_prompt: | ||||
|             interpreter_info = interpreter_executor.get_actions_info()[0] | ||||
|             interpreter_prompt = self.interpreter_prompt.format( | ||||
|                 code_prompt=interpreter_info['description']) | ||||
|             formatted.append( | ||||
|                 dict( | ||||
|                     role='system', | ||||
|                     content=interpreter_prompt, | ||||
|                     name='interpreter')) | ||||
|         if plugin_executor and plugin_executor.actions and self.plugin_prompt: | ||||
|             plugin_descriptions = [] | ||||
|             for api_info in plugin_executor.get_actions_info(): | ||||
|                 plugin = deepcopy(api_info) | ||||
|                 if isinstance(api_info, dict): | ||||
|                     tool_name = api_info['name'].split('.')[0] | ||||
|                     plugin['description'] = API_PREFIX.format( | ||||
|                         tool_name=tool_name, description=plugin['description']) | ||||
|                     # only keep required parameters | ||||
|                     required_parameters = [ | ||||
|                         param for param in plugin['parameters'] | ||||
|                         if param['name'] in plugin['required'] | ||||
|                     ] | ||||
|                     plugin['parameters'] = required_parameters | ||||
|                 plugin_descriptions.append(plugin) | ||||
|             plugin_prompt = self.plugin_prompt.format( | ||||
|                 prompt=json.dumps( | ||||
|                     plugin_descriptions, ensure_ascii=False, indent=4)) | ||||
|             formatted.append( | ||||
|                 dict(role='system', content=plugin_prompt, name='plugin')) | ||||
|         if self.few_shot: | ||||
|             for few_shot in self.few_shot: | ||||
|                 formatted += self.format_sub_role(few_shot) | ||||
|         formatted += self.format_sub_role(inner_step) | ||||
|         return formatted | ||||
|  | ||||
|     def parse(self, message, plugin_executor: ActionExecutor, | ||||
|               interpreter_executor: ActionExecutor): | ||||
|         if self.language['begin']: | ||||
|             message = message.split(self.language['begin'])[-1] | ||||
|         if self.tool['name_map']['plugin'] in message: | ||||
|             message, action = message.split( | ||||
|                 f"{self.tool['start_token']}{self.tool['name_map']['plugin']}") | ||||
|             action = action.split(self.tool['end'].strip())[0] | ||||
|             return 'plugin', message, action | ||||
|         if self.tool['name_map']['interpreter'] in message: | ||||
|             message, code = message.split( | ||||
|                 f"{self.tool['start_token']}" | ||||
|                 f"{self.tool['name_map']['interpreter']}") | ||||
|             code = code.split(self.tool['end'].strip())[0].strip() | ||||
|             return 'interpreter', message, dict( | ||||
|                 name=interpreter_executor.action_names()[0], | ||||
|                 parameters=dict(command=code)) | ||||
|         return None, message.split(self.tool['start_token'])[0], None | ||||
|  | ||||
|     def format_response(self, action_return, name) -> dict: | ||||
|         if action_return.state == ActionStatusCode.SUCCESS: | ||||
|             response = action_return.format_result() | ||||
|         else: | ||||
|             response = str(action_return.errmsg) | ||||
|         content = self.execute['begin'] + response + self.execute['end'] | ||||
|         if self.execute.get('fallback_role'): | ||||
|             return dict( | ||||
|                 role=self.execute['fallback_role'], content=content, name=name) | ||||
|         elif self.execute.get('belong'): | ||||
|             return dict( | ||||
|                 role=self.execute['belong'], content=content, name=name) | ||||
|         return dict(role=self.execute['role'], content=response, name=name) | ||||
|  | ||||
|  | ||||
| class Internlm2Agent(BaseAgent): | ||||
|  | ||||
|     def __init__(self, | ||||
|                  llm: Union[BaseModel, BaseAPIModel], | ||||
|                  plugin_executor: ActionExecutor = None, | ||||
|                  interpreter_executor: ActionExecutor = None, | ||||
|                  protocol=Internlm2Protocol(), | ||||
|                  max_turn: int = 3) -> None: | ||||
|         self.max_turn = max_turn | ||||
|         self._interpreter_executor = interpreter_executor | ||||
|         super().__init__( | ||||
|             llm=llm, action_executor=plugin_executor, protocol=protocol) | ||||
|  | ||||
|     def chat(self, message: Union[str, Dict], **kwargs) -> AgentReturn: | ||||
|         if isinstance(message, str): | ||||
|             message = dict(role='user', content=message) | ||||
|         if isinstance(message, dict): | ||||
|             message = [message] | ||||
|         inner_history = message[:] | ||||
|         offset = len(inner_history) | ||||
|         agent_return = AgentReturn() | ||||
|         for _ in range(self.max_turn): | ||||
|             # list of dict | ||||
|             prompt = self._protocol.format( | ||||
|                 inner_step=inner_history, | ||||
|                 plugin_executor=self._action_executor, | ||||
|                 interpreter_executor=self._interpreter_executor, | ||||
|             ) | ||||
|             response = self._llm.chat(prompt, **kwargs) | ||||
|             name, language, action = self._protocol.parse( | ||||
|                 message=response, | ||||
|                 plugin_executor=self._action_executor, | ||||
|                 interpreter_executor=self._interpreter_executor, | ||||
|             ) | ||||
|             if name: | ||||
|                 if name == 'plugin': | ||||
|                     if self._action_executor: | ||||
|                         executor = self._action_executor | ||||
|                     else: | ||||
|                         logging.info(msg='No plugin is instantiated!') | ||||
|                         continue | ||||
|                     try: | ||||
|                         action = json.loads(action) | ||||
|                     except Exception as e: | ||||
|                         logging.info(msg=f'Invaild action {e}') | ||||
|                         continue | ||||
|                 elif name == 'interpreter': | ||||
|                     if self._interpreter_executor: | ||||
|                         executor = self._interpreter_executor | ||||
|                     else: | ||||
|                         logging.info(msg='No interpreter is instantiated!') | ||||
|                         continue | ||||
|                 else: | ||||
|                     logging.info( | ||||
|                         msg=(f"Invalid name '{name}'. Currently only 'plugin' " | ||||
|                              "and 'interpreter' are supported.")) | ||||
|                     continue | ||||
|                 action_return: ActionReturn = executor(action['name'], | ||||
|                                                        action['parameters']) | ||||
|                 action_return.thought = language | ||||
|                 agent_return.actions.append(action_return) | ||||
|             inner_history.append(dict(role='language', content=language)) | ||||
|             if not name or action_return.type == executor.finish_action.name: | ||||
|                 agent_return.response = language | ||||
|                 agent_return.state = AgentStatusCode.END | ||||
|                 break | ||||
|             else: | ||||
|                 inner_history.append( | ||||
|                     dict(role='tool', content=action, name=name)) | ||||
|                 inner_history.append( | ||||
|                     self._protocol.format_response(action_return, name=name)) | ||||
|         agent_return.inner_steps = inner_history[offset:] | ||||
|         return agent_return | ||||
|  | ||||
|     def stream_chat(self, message: List[dict], **kwargs) -> AgentReturn: | ||||
|         if isinstance(message, str): | ||||
|             message = dict(role='user', content=message) | ||||
|         if isinstance(message, dict): | ||||
|             message = [message] | ||||
|         inner_history = message[:] | ||||
|         offset = len(inner_history) | ||||
|         agent_return = AgentReturn() | ||||
|         last_agent_state = AgentStatusCode.SESSION_READY | ||||
|         for _ in range(self.max_turn): | ||||
|             # list of dict | ||||
|             prompt = self._protocol.format( | ||||
|                 inner_step=inner_history, | ||||
|                 plugin_executor=self._action_executor, | ||||
|                 interpreter_executor=self._interpreter_executor, | ||||
|             ) | ||||
|             response = '' | ||||
|             for model_state, res, _ in self._llm.stream_chat(prompt, **kwargs): | ||||
|                 model_state: ModelStatusCode | ||||
|                 response = res | ||||
|                 if model_state.value < 0: | ||||
|                     agent_return.state = getattr(AgentStatusCode, | ||||
|                                                  model_state.name) | ||||
|                     yield deepcopy(agent_return) | ||||
|                     return | ||||
|                 else: | ||||
|                     name, language, action = self._protocol.parse( | ||||
|                         message=response, | ||||
|                         plugin_executor=self._action_executor, | ||||
|                         interpreter_executor=self._interpreter_executor, | ||||
|                     ) | ||||
|                     if name: | ||||
|                         if model_state == ModelStatusCode.END: | ||||
|                             agent_state = last_agent_state + 1 | ||||
|                             if name == 'plugin': | ||||
|                                 if self._action_executor: | ||||
|                                     executor = self._action_executor | ||||
|                                 else: | ||||
|                                     logging.info( | ||||
|                                         msg='No plugin is instantiated!') | ||||
|                                     continue | ||||
|                                 try: | ||||
|                                     action = json.loads(action) | ||||
|                                 except Exception as e: | ||||
|                                     logging.info(msg=f'Invaild action {e}') | ||||
|                                     continue | ||||
|                             elif name == 'interpreter': | ||||
|                                 if self._interpreter_executor: | ||||
|                                     executor = self._interpreter_executor | ||||
|                                 else: | ||||
|                                     logging.info( | ||||
|                                         msg='No interpreter is instantiated!') | ||||
|                                     continue | ||||
|                             agent_return.state = agent_state | ||||
|                             agent_return.response = action | ||||
|                         else: | ||||
|                             agent_state = ( | ||||
|                                 AgentStatusCode.PLUGIN_START if name | ||||
|                                 == 'plugin' else AgentStatusCode.CODING) | ||||
|                             if agent_state != last_agent_state: | ||||
|                                 # agent_return.state = agent_state | ||||
|                                 agent_return.response = language | ||||
|                                 yield deepcopy(agent_return) | ||||
|                             agent_return.state = agent_state | ||||
|                             agent_return.response = action | ||||
|                     else: | ||||
|                         agent_state = AgentStatusCode.STREAM_ING | ||||
|                         agent_return.state = agent_state | ||||
|                         agent_return.response = language | ||||
|                     last_agent_state = agent_state | ||||
|                     yield deepcopy(agent_return) | ||||
|             if name: | ||||
|                 action_return: ActionReturn = executor(action['name'], | ||||
|                                                        action['parameters']) | ||||
|                 action_return.thought = language | ||||
|                 agent_return.actions.append(action_return) | ||||
|             inner_history.append(dict(role='language', content=language)) | ||||
|             if not name: | ||||
|                 agent_return.response = language | ||||
|                 break | ||||
|             elif action_return.type == executor.finish_action.name: | ||||
|                 try: | ||||
|                     response = action_return.args['text']['response'] | ||||
|                 except Exception: | ||||
|                     logging.info(msg='Unable to parse FinishAction.') | ||||
|                     response = '' | ||||
|                 agent_return.response = response | ||||
|                 break | ||||
|             else: | ||||
|                 inner_history.append( | ||||
|                     dict(role='tool', content=action, name=name)) | ||||
|                 inner_history.append( | ||||
|                     self._protocol.format_response(action_return, name=name)) | ||||
|                 agent_state += 1 | ||||
|                 agent_return.state = agent_state | ||||
|                 yield agent_return | ||||
|         agent_return.inner_steps = deepcopy(inner_history[offset:]) | ||||
|         agent_return.state = AgentStatusCode.END | ||||
|         yield agent_return | ||||
| @@ -1,4 +1,3 @@ | ||||
| import copy | ||||
| from typing import Dict, List, Tuple, Union | ||||
|  | ||||
| from lagent.actions import ActionExecutor | ||||
| @@ -43,7 +42,7 @@ To use a tool, please use the following format: | ||||
| The response after utilizing tools should using the following format: | ||||
| ``` | ||||
| {response}the results after call the tool. | ||||
| `` | ||||
| ``` | ||||
| If you already know the answer, or you do not need to use tools, | ||||
| please using the following format to reply: | ||||
| ``` | ||||
| @@ -170,20 +169,22 @@ class ReActProtocol: | ||||
|         action_input = arg_match[-1] | ||||
|         return thought, action.strip(), action_input.strip().strip('"') | ||||
|  | ||||
|     def format_response(self, action_return: ActionReturn) -> str: | ||||
|         """format the final response at current step. | ||||
|     def format_response(self, action_return: ActionReturn) -> dict: | ||||
|         """Format the final response at current step. | ||||
|  | ||||
|         Args: | ||||
|             action_return (ActionReturn): return value of the current action. | ||||
|  | ||||
|         Returns: | ||||
|             str: the final response at current step. | ||||
|             dict: the final response at current step. | ||||
|         """ | ||||
|         if action_return.state == ActionStatusCode.SUCCESS: | ||||
|             response = action_return.result['text'] | ||||
|             response = action_return.format_result() | ||||
|         else: | ||||
|             response = action_return.errmsg | ||||
|         return self.response['begin'] + response + self.response['end'] | ||||
|         return dict( | ||||
|             role='system', | ||||
|             content=self.response['begin'] + response + self.response['end']) | ||||
|  | ||||
|  | ||||
| class ReAct(BaseAgent): | ||||
| @@ -210,20 +211,27 @@ class ReAct(BaseAgent): | ||||
|         super().__init__( | ||||
|             llm=llm, action_executor=action_executor, protocol=protocol) | ||||
|  | ||||
|     def chat(self, message: str) -> AgentReturn: | ||||
|         self._inner_history = [] | ||||
|         self._inner_history.append(dict(role='user', content=message)) | ||||
|     def chat(self, message: Union[str, dict, List[dict]], | ||||
|              **kwargs) -> AgentReturn: | ||||
|         if isinstance(message, str): | ||||
|             inner_history = [dict(role='user', content=message)] | ||||
|         elif isinstance(message, dict): | ||||
|             inner_history = [message] | ||||
|         elif isinstance(message, list): | ||||
|             inner_history = message[:] | ||||
|         else: | ||||
|             raise TypeError(f'unsupported type: {type(message)}') | ||||
|         offset = len(inner_history) | ||||
|         agent_return = AgentReturn() | ||||
|         default_response = 'Sorry that I cannot answer your question.' | ||||
|         for turn in range(self.max_turn): | ||||
|             prompt = self._protocol.format( | ||||
|                 chat_history=self.session_history, | ||||
|                 inner_step=self._inner_history, | ||||
|                 chat_history=[], | ||||
|                 inner_step=inner_history, | ||||
|                 action_executor=self._action_executor, | ||||
|                 force_stop=(turn == self.max_turn - 1)) | ||||
|             response = self._llm.generate_from_template(prompt, 512) | ||||
|             self._inner_history.append( | ||||
|                 dict(role='assistant', content=response)) | ||||
|             response = self._llm.chat(prompt, **kwargs) | ||||
|             inner_history.append(dict(role='assistant', content=response)) | ||||
|             thought, action, action_input = self._protocol.parse( | ||||
|                 response, self._action_executor) | ||||
|             action_return: ActionReturn = self._action_executor( | ||||
| @@ -231,17 +239,10 @@ class ReAct(BaseAgent): | ||||
|             action_return.thought = thought | ||||
|             agent_return.actions.append(action_return) | ||||
|             if action_return.type == self._action_executor.finish_action.name: | ||||
|                 agent_return.response = action_return.result['text'] | ||||
|                 agent_return.response = action_return.format_result() | ||||
|                 break | ||||
|             self._inner_history.append( | ||||
|                 dict( | ||||
|                     role='system', | ||||
|                     content=self._protocol.format_response(action_return))) | ||||
|             inner_history.append(self._protocol.format_response(action_return)) | ||||
|         else: | ||||
|             agent_return.response = default_response | ||||
|         agent_return.inner_steps = copy.deepcopy(self._inner_history) | ||||
|         # only append the user and final response | ||||
|         self._session_history.append(dict(role='user', content=message)) | ||||
|         self._session_history.append( | ||||
|             dict(role='assistant', content=agent_return.response)) | ||||
|         agent_return.inner_steps = inner_history[offset:] | ||||
|         return agent_return | ||||
|   | ||||
| @@ -1,4 +1,3 @@ | ||||
| import copy | ||||
| import re | ||||
| import warnings | ||||
| from typing import Dict, List, Optional, Tuple, Union | ||||
| @@ -192,7 +191,7 @@ class ReWOOProtocol: | ||||
|         worker_log = '' | ||||
|         for thought, action_return in zip(thought_list, action_return_list): | ||||
|             if action_return.state == ActionStatusCode.SUCCESS: | ||||
|                 action_resp = action_return.result['text'] | ||||
|                 action_resp = action_return.format_result() | ||||
|             else: | ||||
|                 action_resp = action_return.errmsg | ||||
|             worker_response = self.worker_prompt.format( | ||||
| @@ -227,9 +226,17 @@ class ReWOO(BaseAgent): | ||||
|  | ||||
|         self.max_turn = max_turn | ||||
|  | ||||
|     def chat(self, message: str) -> AgentReturn: | ||||
|         self._inner_history = [] | ||||
|         self._inner_history.append(dict(role='user', content=message)) | ||||
|     def chat(self, message: Union[str, dict, List[dict]], | ||||
|              **kwargs) -> AgentReturn: | ||||
|         if isinstance(message, str): | ||||
|             inner_history = [dict(role='user', content=message)] | ||||
|         elif isinstance(message, dict): | ||||
|             inner_history = [message] | ||||
|         elif isinstance(message, list): | ||||
|             inner_history = message[:] | ||||
|         else: | ||||
|             raise TypeError(f'unsupported type: {type(message)}') | ||||
|         offset = len(inner_history) | ||||
|         agent_return = AgentReturn() | ||||
|  | ||||
|         # planner | ||||
| @@ -237,13 +244,12 @@ class ReWOO(BaseAgent): | ||||
|         reformat_request = '' | ||||
|         while turn_id < self.max_turn: | ||||
|             planner_prompt = self._protocol.format_planner( | ||||
|                 chat_history=self.session_history, | ||||
|                 inner_step=self._inner_history, | ||||
|                 chat_history=[], | ||||
|                 inner_step=inner_history, | ||||
|                 action_executor=self._action_executor, | ||||
|                 reformat_request=reformat_request) | ||||
|             response = self._llm.generate_from_template(planner_prompt, 512) | ||||
|             self._inner_history.append( | ||||
|                 dict(role='assistant', content=response)) | ||||
|             response = self._llm.chat(planner_prompt, **kwargs) | ||||
|             inner_history.append(dict(role='assistant', content=response)) | ||||
|             try: | ||||
|                 thoughts, actions, actions_input = self._protocol.parse_worker( | ||||
|                     response) | ||||
| @@ -267,18 +273,17 @@ class ReWOO(BaseAgent): | ||||
|             for prev_ptr in prev_ptrs: | ||||
|                 ptr_num = int(prev_ptr.strip('#E')) - 1  # start from 0 | ||||
|                 actions_input[action_id] = actions_input[action_id].replace( | ||||
|                     prev_ptr, action_responses[ptr_num].result['text']) | ||||
|                     prev_ptr, action_responses[ptr_num].format_result()) | ||||
|             action_return: ActionReturn = self._action_executor( | ||||
|                 actions[action_id], actions_input[action_id]) | ||||
|             action_responses.append(action_return) | ||||
|  | ||||
|         solver_prompt, worker_log = self._protocol.format_solver( | ||||
|             message, thoughts, action_responses) | ||||
|         self._inner_history.append(dict(role='system', content=worker_log)) | ||||
|         inner_history.append(dict(role='system', content=worker_log)) | ||||
|  | ||||
|         final_response = self._llm.generate_from_template(solver_prompt, 512) | ||||
|         self._inner_history.append( | ||||
|             dict(role='assistant', content=final_response)) | ||||
|         agent_return.inner_steps = copy.deepcopy(self._inner_history) | ||||
|         final_response = self._llm.chat(solver_prompt, **kwargs) | ||||
|         inner_history.append(dict(role='assistant', content=final_response)) | ||||
|         agent_return.inner_steps = inner_history[offset:] | ||||
|         agent_return.response = final_response | ||||
|         return agent_return | ||||
|   | ||||
| @@ -1,14 +1,13 @@ | ||||
| from lagent.utils import is_module_exist | ||||
| from .base_api import BaseAPIModel | ||||
| from .base_llm import BaseModel | ||||
| from .huggingface import HFTransformer, HFTransformerCasualLM, HFTransformerChat | ||||
| from .lmdepoly_wrapper import LMDeployClient, LMDeployPipeline, LMDeployServer | ||||
| from .meta_template import INTERNLM2_META | ||||
| from .openai import GPTAPI | ||||
| from .vllm_wrapper import VllmModel | ||||
|  | ||||
| __all__ = ['BaseModel', 'BaseAPIModel', 'GPTAPI'] | ||||
|  | ||||
| if is_module_exist('transformers'): | ||||
|     from .huggingface import HFTransformer, HFTransformerCasualLM  # noqa: F401 | ||||
|     __all__.extend(['HFTransformer', 'HFTransformerCasualLM']) | ||||
|  | ||||
| if is_module_exist('lmdeploy'): | ||||
|     from .lmdeploy import TritonClient, TurboMind  # noqa: F401 | ||||
|     __all__.extend(['TritonClient', 'TurboMind']) | ||||
| __all__ = [ | ||||
|     'BaseModel', 'BaseAPIModel', 'GPTAPI', 'LMDeployClient', | ||||
|     'LMDeployPipeline', 'LMDeployServer', 'HFTransformer', | ||||
|     'HFTransformerCasualLM', 'INTERNLM2_META', 'HFTransformerChat', 'VllmModel' | ||||
| ] | ||||
|   | ||||
| @@ -1,92 +1,11 @@ | ||||
| import re | ||||
| import threading | ||||
| import warnings | ||||
| from abc import abstractclassmethod | ||||
| from time import sleep | ||||
| from typing import Dict, List, Optional, Tuple, Union | ||||
|  | ||||
| from .base_llm import BaseModel | ||||
|  | ||||
|  | ||||
| class BaseAPIModel(BaseModel): | ||||
|     """Base class for API model wrapper. | ||||
|  | ||||
|     Args: | ||||
|         model_type (str): The type of model. | ||||
|         query_per_second (int): The maximum queries allowed per second | ||||
|             between two consecutive calls of the API. Defaults to 1. | ||||
|         retry (int): Number of retires if the API call fails. Defaults to 2. | ||||
|         max_seq_len (int): The maximum sequence length of the model. Defaults | ||||
|             to 2048. | ||||
|         meta_template (Dict, optional): The model's meta prompt | ||||
|             template if needed, in case the requirement of injecting or | ||||
|             wrapping of any meta instructions. | ||||
|     """ | ||||
|  | ||||
|     is_api: bool = True | ||||
|  | ||||
|     def __init__(self, | ||||
|                  model_type: str, | ||||
|                  query_per_second: int = 1, | ||||
|                  retry: int = 2, | ||||
|                  max_seq_len: int = 2048, | ||||
|                  meta_template: Optional[Dict] = None): | ||||
|         self.model_type = model_type | ||||
|         self.max_seq_len = max_seq_len | ||||
|         self.meta_template = meta_template | ||||
|         self.retry = retry | ||||
|         self.query_per_second = query_per_second | ||||
|         self.token_bucket = TokenBucket(query_per_second) | ||||
|         self.template_parser = APITemplateParser(meta_template) | ||||
|  | ||||
|     @abstractclassmethod | ||||
|     def generate(self, inputs, max_out_len: int) -> List[str]: | ||||
|         """Generate results given a list of inputs. | ||||
|  | ||||
|         Args: | ||||
|             inputs (List[str or list]): A list of strings or PromptDicts. | ||||
|                 The PromptDict should be organized in OpenCompass' | ||||
|                 API format. | ||||
|             max_out_len (int): The maximum length of the output. | ||||
|  | ||||
|         Returns: | ||||
|             List[str]: A list of generated strings. | ||||
|         """ | ||||
|  | ||||
|     def get_token_len(self, prompt: str) -> int: | ||||
|         """Get lengths of the tokenized string. Only English and Chinese | ||||
|         characters are counted for now. Users are encouraged to override this | ||||
|         method if more accurate length is needed. | ||||
|  | ||||
|         Args: | ||||
|             prompt (str): Input string. | ||||
|  | ||||
|         Returns: | ||||
|             int: Length of the input tokens | ||||
|         """ | ||||
|  | ||||
|         english_parts = re.findall(r'[A-Za-z0-9]+', prompt) | ||||
|         chinese_parts = re.findall(r'[\u4e00-\u9FFF]+', prompt) | ||||
|  | ||||
|         # Count English words | ||||
|         english_count = sum(len(part.split()) for part in english_parts) | ||||
|  | ||||
|         # Count Chinese words | ||||
|         chinese_count = sum(len(part) for part in chinese_parts) | ||||
|  | ||||
|         return english_count + chinese_count | ||||
|  | ||||
|     def wait(self): | ||||
|         """Wait till the next query can be sent. | ||||
|  | ||||
|         Applicable in both single-thread and multi-thread environments. | ||||
|         """ | ||||
|         return self.token_bucket.get_token() | ||||
|  | ||||
|     def to(self, device): | ||||
|         pass | ||||
|  | ||||
|  | ||||
| class APITemplateParser: | ||||
|     """Intermidate prompt template parser, specifically for API models. | ||||
|  | ||||
| @@ -106,7 +25,7 @@ class APITemplateParser: | ||||
|                     'role in meta prompt must be unique!' | ||||
|                 self.roles[item['role']] = item.copy() | ||||
|  | ||||
|     def parse_template(self, dialog: List[Union[str, List]]): | ||||
|     def __call__(self, dialog: List[Union[str, List]]): | ||||
|         """Parse the intermidate prompt template, and wrap it with meta | ||||
|         template if applicable. When the meta template is set and the input is | ||||
|         a list, the return value will be a list containing the full | ||||
| @@ -199,11 +118,10 @@ class APITemplateParser: | ||||
|         return res | ||||
|  | ||||
|     def _role2api_role(self, role_prompt: Dict) -> Tuple[str, bool]: | ||||
|  | ||||
|         merged_prompt = self.roles.get( | ||||
|             role_prompt['role'], | ||||
|             self.roles.get( | ||||
|                 self.roles[role_prompt['role']].get('fallback_role'))) | ||||
|         merged_prompt = self.roles[role_prompt['role']] | ||||
|         if merged_prompt.get('fallback_role'): | ||||
|             merged_prompt = self.roles[self.roles[ | ||||
|                 merged_prompt['fallback_role']]] | ||||
|         res = {} | ||||
|         res['role'] = merged_prompt['api_role'] | ||||
|         res['content'] = merged_prompt.get('begin', '') | ||||
| @@ -212,6 +130,60 @@ class APITemplateParser: | ||||
|         return res | ||||
|  | ||||
|  | ||||
| class BaseAPIModel(BaseModel): | ||||
|     """Base class for API model wrapper. | ||||
|  | ||||
|     Args: | ||||
|         model_type (str): The type of model. | ||||
|         query_per_second (int): The maximum queries allowed per second | ||||
|             between two consecutive calls of the API. Defaults to 1. | ||||
|         retry (int): Number of retires if the API call fails. Defaults to 2. | ||||
|         meta_template (Dict, optional): The model's meta prompt | ||||
|             template if needed, in case the requirement of injecting or | ||||
|             wrapping of any meta instructions. | ||||
|     """ | ||||
|  | ||||
|     is_api: bool = True | ||||
|  | ||||
|     def __init__(self, | ||||
|                  model_type: str, | ||||
|                  query_per_second: int = 1, | ||||
|                  retry: int = 2, | ||||
|                  template_parser: 'APITemplateParser' = APITemplateParser, | ||||
|                  meta_template: Optional[Dict] = None, | ||||
|                  *, | ||||
|                  max_new_tokens: int = 512, | ||||
|                  top_p: float = 0.8, | ||||
|                  top_k: float = None, | ||||
|                  temperature: float = 0.8, | ||||
|                  repetition_penalty: float = 0.0, | ||||
|                  stop_words: Union[List[str], str] = None): | ||||
|         self.model_type = model_type | ||||
|         self.meta_template = meta_template | ||||
|         self.retry = retry | ||||
|         self.query_per_second = query_per_second | ||||
|         self.token_bucket = TokenBucket(query_per_second) | ||||
|         if template_parser: | ||||
|             self.template_parser = template_parser(meta_template) | ||||
|  | ||||
|         if isinstance(stop_words, str): | ||||
|             stop_words = [stop_words] | ||||
|         self.gen_params = dict( | ||||
|             max_new_tokens=max_new_tokens, | ||||
|             top_p=top_p, | ||||
|             top_k=top_k, | ||||
|             temperature=temperature, | ||||
|             repetition_penalty=repetition_penalty, | ||||
|             stop_words=stop_words) | ||||
|  | ||||
|     def _wait(self): | ||||
|         """Wait till the next query can be sent. | ||||
|  | ||||
|         Applicable in both single-thread and multi-thread environments. | ||||
|         """ | ||||
|         return self.token_bucket.get_token() | ||||
|  | ||||
|  | ||||
| class TokenBucket: | ||||
|     """A token bucket for rate limiting. | ||||
|  | ||||
|   | ||||
| @@ -1,75 +1,6 @@ | ||||
| from abc import abstractclassmethod | ||||
| from copy import copy | ||||
| from typing import Dict, List, Optional, Tuple, Union | ||||
|  | ||||
|  | ||||
| class BaseModel: | ||||
|     """Base class for model wrapper. | ||||
|  | ||||
|     Args: | ||||
|         path (str): The path to the model. | ||||
|         max_seq_len (int): The maximum sequence length of the model. Defaults | ||||
|             to 2048. | ||||
|         tokenizer_only (bool): If True, only the tokenizer will be initialized. | ||||
|             Defaults to False. | ||||
|         meta_template (list of dict, optional): The model's meta prompt | ||||
|             template if needed, in case the requirement of injecting or | ||||
|             wrapping of any meta instructions. | ||||
|     """ | ||||
|  | ||||
|     is_api: bool = False | ||||
|  | ||||
|     def __init__(self, | ||||
|                  path: str, | ||||
|                  max_seq_len: int = 2048, | ||||
|                  tokenizer_only: bool = False, | ||||
|                  meta_template: Optional[List[Dict]] = None): | ||||
|         self.path = path | ||||
|         self.max_seq_len = max_seq_len | ||||
|         self.tokenizer_only = tokenizer_only | ||||
|         # meta template | ||||
|         self.template_parser = LMTemplateParser(meta_template) | ||||
|         self.eos_token_id = None | ||||
|         if meta_template and 'eos_token_id' in meta_template: | ||||
|             self.eos_token_id = meta_template['eos_token_id'] | ||||
|  | ||||
|     @abstractclassmethod | ||||
|     def generate(self, inputs: List[str], max_out_len: int) -> List[str]: | ||||
|         """Generate results given a list of inputs. | ||||
|  | ||||
|         Args: | ||||
|             inputs (List[str]): A list of strings. | ||||
|             max_out_len (int): The maximum length of the output. | ||||
|  | ||||
|         Returns: | ||||
|             List[str]: A list of generated strings. | ||||
|         """ | ||||
|  | ||||
|     def parse_template(self, dialog) -> str: | ||||
|         """Parse a prompt template, and wrap it with meta template if | ||||
|         applicable. | ||||
|  | ||||
|         Args: | ||||
|             dialog (List[str or PromptList]): A prompt | ||||
|                 template (potentially before being wrapped by meta template). | ||||
|             mode (str): Parsing mode. Choices are 'ppl' and 'gen'. | ||||
|  | ||||
|         Returns: | ||||
|             str: The final string. | ||||
|         """ | ||||
|         return self.template_parser.parse_template(dialog) | ||||
|  | ||||
|     def generate_from_template(self, templates, max_out_len: int, **kwargs): | ||||
|         """Generate completion from a list of templates. | ||||
|  | ||||
|         Args: | ||||
|             templates (List[PromptType]): A list of templates. | ||||
|             max_out_len (int): The maximum length of the output. | ||||
|         """ | ||||
|         inputs = self.parse_template(templates) | ||||
|         return self.generate(inputs, max_out_len=max_out_len, **kwargs) | ||||
|  | ||||
|     def to(self, device): | ||||
|         self.model.to(device) | ||||
| from warnings import warn | ||||
|  | ||||
|  | ||||
| class LMTemplateParser: | ||||
| @@ -91,7 +22,7 @@ class LMTemplateParser: | ||||
|                     'role in meta prompt must be unique!' | ||||
|                 self.roles[item['role']] = item.copy() | ||||
|  | ||||
|     def parse_template(self, dialog) -> str: | ||||
|     def __call__(self, dialog) -> str: | ||||
|         """Parse a prompt template, and wrap it with meta template if | ||||
|         applicable. | ||||
|  | ||||
| @@ -127,20 +58,171 @@ class LMTemplateParser: | ||||
|                 last_sep = '\n' | ||||
|         return prompt | ||||
|  | ||||
|     def _format_begin(self, role_cfg, message): | ||||
|         name = message.get('name', None) | ||||
|         if name is not None: | ||||
|             begin = role_cfg['begin'].get('with_name', '') | ||||
|             if name in role_cfg['begin'].get('name', {}): | ||||
|                 begin = begin.format(name=role_cfg['begin']['name'][name]) | ||||
|             else: | ||||
|                 begin = begin.format(name=name) | ||||
|         else: | ||||
|             if isinstance(role_cfg.get('begin', ''), str): | ||||
|                 begin = role_cfg.get('begin', '') | ||||
|             elif isinstance(role_cfg['begin'], dict): | ||||
|                 begin = role_cfg['begin'].get('without_name', '') | ||||
|         return begin | ||||
|  | ||||
|     def _prompt2str(self, | ||||
|                     prompt: Union[str, Dict], | ||||
|                     last: bool = False) -> Tuple[str, bool]: | ||||
|         if isinstance(prompt, str): | ||||
|             return prompt | ||||
|         merged_prompt = self.roles.get( | ||||
|             prompt['role'], | ||||
|             self.roles.get(self.roles[prompt['role']].get('fallback_role'))) | ||||
|         res = merged_prompt.get('begin', '') | ||||
|         merged_prompt = self.roles.get(prompt['role']) | ||||
|  | ||||
|         if merged_prompt.get('fallback_role'): | ||||
|             merged_prompt = self.roles.get(merged_prompt['fallback_role']) | ||||
|         begin = self._format_begin(merged_prompt, prompt) | ||||
|         res = begin | ||||
|         if last and merged_prompt.get('generate', False): | ||||
|             res += prompt.get('content', '') | ||||
|             return res | ||||
|         res += prompt.get('content', '') + merged_prompt.get('end', '') | ||||
|         if last and merged_prompt['role'] != 'assistant': | ||||
|             res += self.roles['assistant']['begin'] | ||||
|             res += self._format_begin(self.roles['assistant'], {}) | ||||
|             return res | ||||
|         return res | ||||
|  | ||||
|  | ||||
| class BaseModel: | ||||
|     """Base class for model wrapper. | ||||
|  | ||||
|     Args: | ||||
|         path (str): The path to the model. | ||||
|         max_new_tokens (int): Maximum length of output expected to be generated by the model. Defaults | ||||
|             to 512. | ||||
|         tokenizer_only (bool): If True, only the tokenizer will be initialized. | ||||
|             Defaults to False. | ||||
|         meta_template (list of dict, optional): The model's meta prompt | ||||
|             template if needed, in case the requirement of injecting or | ||||
|             wrapping of any meta instructions. | ||||
|     """ | ||||
|  | ||||
|     is_api: bool = False | ||||
|  | ||||
|     def __init__(self, | ||||
|                  path: str, | ||||
|                  tokenizer_only: bool = False, | ||||
|                  template_parser: 'LMTemplateParser' = LMTemplateParser, | ||||
|                  meta_template: Optional[List[Dict]] = None, | ||||
|                  *, | ||||
|                  max_new_tokens: int = 512, | ||||
|                  top_p: float = 0.8, | ||||
|                  top_k: float = None, | ||||
|                  temperature: float = 0.8, | ||||
|                  repetition_penalty: float = 1.0, | ||||
|                  stop_words: Union[List[str], str] = None): | ||||
|         self.path = path | ||||
|         self.tokenizer_only = tokenizer_only | ||||
|         # meta template | ||||
|         self.template_parser = template_parser(meta_template) | ||||
|         self.eos_token_id = None | ||||
|         if meta_template and 'eos_token_id' in meta_template: | ||||
|             self.eos_token_id = meta_template['eos_token_id'] | ||||
|  | ||||
|         if isinstance(stop_words, str): | ||||
|             stop_words = [stop_words] | ||||
|         self.gen_params = dict( | ||||
|             max_new_tokens=max_new_tokens, | ||||
|             top_p=top_p, | ||||
|             top_k=top_k, | ||||
|             temperature=temperature, | ||||
|             repetition_penalty=repetition_penalty, | ||||
|             stop_words=stop_words) | ||||
|  | ||||
|     def generate(self, inputs: Union[str, List[str]], **gen_params) -> str: | ||||
|         """Generate results given a str (or list of) inputs. | ||||
|  | ||||
|         Args: | ||||
|             inputs (Union[str, List[str]]): | ||||
|             gen_params (dict): The input params for generation. | ||||
|  | ||||
|         Returns: | ||||
|             Union[str, List[str]]: A (list of) generated strings. | ||||
|  | ||||
|         eg. | ||||
|             batched = True | ||||
|             if isinstance(inputs, str): | ||||
|                 inputs = [inputs] | ||||
|                 batched = False | ||||
|             response = [''] | ||||
|             if batched: | ||||
|                 return response | ||||
|             return response[0] | ||||
|         """ | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def stream_generate(self, inputs: str, **gen_params) -> List[str]: | ||||
|         """Generate results as streaming given a str inputs. | ||||
|  | ||||
|         Args: | ||||
|             inputs (str): | ||||
|             gen_params (dict): The input params for generation. | ||||
|  | ||||
|         Returns: | ||||
|             str: A generated string. | ||||
|         """ | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def chat(self, inputs: Union[List[dict], List[List[dict]]], **gen_params): | ||||
|         """Generate completion from a list of templates. | ||||
|  | ||||
|         Args: | ||||
|             inputs (Union[List[dict], List[List[dict]]]): | ||||
|             gen_params (dict): The input params for generation. | ||||
|         Returns: | ||||
|         """ | ||||
|         if isinstance(inputs[0], list): | ||||
|             _inputs = list() | ||||
|             for msg in inputs: | ||||
|                 _inputs.append(self.template_parser(msg)) | ||||
|         else: | ||||
|             _inputs = self.template_parser(inputs) | ||||
|         return self.generate(_inputs, **gen_params) | ||||
|  | ||||
|     def generate_from_template(self, inputs: Union[List[dict], | ||||
|                                                    List[List[dict]]], | ||||
|                                **gen_params): | ||||
|         warn( | ||||
|             'This function will be deprecated after three months' | ||||
|             'and will be replaced. Please use `.chat()`', DeprecationWarning, | ||||
|             2) | ||||
|         return self.chat(inputs, **gen_params) | ||||
|  | ||||
|     def stream_chat(self, inputs: List[dict], **gen_params): | ||||
|         """Generate results as streaming given a list of templates. | ||||
|  | ||||
|         Args: | ||||
|             inputs (Union[List[dict]): | ||||
|             gen_params (dict): The input params for generation. | ||||
|         Returns: | ||||
|         """ | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def tokenize(self, prompts: Union[str, List[str], List[dict], | ||||
|                                       List[List[dict]]]): | ||||
|         """Tokenize the input prompts. | ||||
|  | ||||
|         Args: | ||||
|             prompts(str | List[str]): user's prompt, or a batch prompts | ||||
|  | ||||
|         Returns: | ||||
|             Tuple(numpy.ndarray, numpy.ndarray, numpy.ndarray): prompt's token | ||||
|             ids, ids' length and requested output length | ||||
|         """ | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def update_gen_params(self, **kwargs): | ||||
|         gen_params = copy(self.gen_params) | ||||
|         gen_params.update(kwargs) | ||||
|         return gen_params | ||||
|   | ||||
| @@ -1,20 +1,24 @@ | ||||
| from typing import Dict, List, Optional | ||||
|  | ||||
| import torch | ||||
| import copy | ||||
| import logging | ||||
| from typing import Dict, List, Optional, Union | ||||
|  | ||||
| from lagent.schema import ModelStatusCode | ||||
| from .base_api import APITemplateParser | ||||
| from .base_llm import BaseModel | ||||
|  | ||||
| logger = logging.getLogger(__name__) | ||||
|  | ||||
| logger = logging.getLogger(__name__) | ||||
|  | ||||
|  | ||||
| class HFTransformer(BaseModel): | ||||
|     """Model wrapper around HuggingFace general models. | ||||
|  | ||||
|     Adapted from OpenCompass (https://github.com/InternLM/opencompass | ||||
|     /blob/main/opencompass/models/huggingface.py) | ||||
|     Adapted from Internlm (https://github.com/InternLM/InternLM/blob/main/ | ||||
|         chat/web_demo.py) | ||||
|  | ||||
|     Args: | ||||
|         path (str): The name or path to HuggingFace's model. | ||||
|         max_seq_len (int): The maximum length of the input sequence. Defaults | ||||
|             to 2048. | ||||
|         tokenizer_path (str): The path to the tokenizer. Defaults to None. | ||||
|         tokenizer_kwargs (dict): Keyword arguments for the tokenizer. | ||||
|             Defaults to {}. | ||||
| @@ -25,52 +29,50 @@ class HFTransformer(BaseModel): | ||||
|         meta_template (Dict, optional): The model's meta prompt | ||||
|             template if needed, in case the requirement of injecting or | ||||
|             wrapping of any meta instructions. | ||||
|         extract_pred_after_decode (bool): Whether to extract the prediction | ||||
|             string from the decoded output string, instead of extract the | ||||
|             prediction tokens before decoding. Defaults to False. | ||||
|         batch_padding (bool): If False, inference with be performed in for-loop | ||||
|             without batch padding. | ||||
|  | ||||
|     Note: | ||||
|         About ``extract_pred_after_decode``: Commonly, we should extract the | ||||
|         the prediction tokens before decoding. But for some tokenizers using | ||||
|         ``sentencepiece``, like LLaMA,  this behavior may change the number of | ||||
|         whitespaces, which is harmful for Python programming tasks. | ||||
|     """ | ||||
|  | ||||
|     def __init__( | ||||
|             self, | ||||
|             path: str, | ||||
|             max_seq_len: int = 2048, | ||||
|             tokenizer_path: Optional[str] = None, | ||||
|             tokenizer_kwargs: dict = dict(), | ||||
|             tokenizer_only: bool = False, | ||||
|             model_kwargs: dict = dict(device_map='auto'), | ||||
|             meta_template: Optional[Dict] = [ | ||||
|                 dict(role='system', begin='<|System|>:', end='\n'), | ||||
|                 dict(role='user', begin='<|User|>:', end='\n'), | ||||
|                 dict( | ||||
|                     role='assistant', | ||||
|                     begin='<|Bot|>:', | ||||
|                     end='<eoa>\n', | ||||
|                     generate=True) | ||||
|             ],  # default meta template for InternLM-7b | ||||
|             extract_pred_after_decode: bool = False, | ||||
|             batch_padding: bool = False): | ||||
|     def __init__(self, | ||||
|                  path: str, | ||||
|                  tokenizer_path: Optional[str] = None, | ||||
|                  tokenizer_kwargs: dict = dict(), | ||||
|                  tokenizer_only: bool = False, | ||||
|                  model_kwargs: dict = dict(device_map='auto'), | ||||
|                  meta_template: Optional[Dict] = None, | ||||
|                  stop_words_id: Union[List[int], int] = None, | ||||
|                  **kwargs): | ||||
|         super().__init__( | ||||
|             path=path, | ||||
|             max_seq_len=max_seq_len, | ||||
|             tokenizer_only=tokenizer_only, | ||||
|             meta_template=meta_template) | ||||
|             meta_template=meta_template, | ||||
|             **kwargs) | ||||
|         if isinstance(stop_words_id, int): | ||||
|             stop_words_id = [stop_words_id] | ||||
|         self.gen_params.update(stop_words_id=stop_words_id) | ||||
|         if self.gen_params['stop_words'] is not None and \ | ||||
|                 self.gen_params['stop_words_id'] is not None: | ||||
|             logger.warning('Both stop_words and stop_words_id are specified,' | ||||
|                            'only stop_words_id will be used.') | ||||
|  | ||||
|         self._load_tokenizer( | ||||
|             path=path, | ||||
|             tokenizer_path=tokenizer_path, | ||||
|             tokenizer_kwargs=tokenizer_kwargs) | ||||
|         self.batch_padding = batch_padding | ||||
|         self.extract_pred_after_decode = extract_pred_after_decode | ||||
|         if not tokenizer_only: | ||||
|             self._load_model(path=path, model_kwargs=model_kwargs) | ||||
|  | ||||
|         from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList  # noqa: E501 | ||||
|         self.logits_processor = LogitsProcessorList() | ||||
|         self.stopping_criteria = StoppingCriteriaList() | ||||
|         self.prefix_allowed_tokens_fn = None | ||||
|  | ||||
|         stop_words_id = [] | ||||
|         if self.gen_params.get('stop_words_id'): | ||||
|             stop_words_id = self.gen_params.get('stop_words_id') | ||||
|         elif self.gen_params.get('stop_words'): | ||||
|             for sw in self.gen_params.get('stop_words'): | ||||
|                 stop_words_id.append(self.tokenizer(sw)['input_ids'][-1]) | ||||
|         self.additional_eos_token_id = stop_words_id | ||||
|  | ||||
|     def _load_tokenizer(self, path: str, tokenizer_path: Optional[str], | ||||
|                         tokenizer_kwargs: dict): | ||||
|         from transformers import AutoTokenizer | ||||
| @@ -78,61 +80,260 @@ class HFTransformer(BaseModel): | ||||
|             tokenizer_path if tokenizer_path else path, | ||||
|             trust_remote_code=True, | ||||
|             **tokenizer_kwargs) | ||||
|  | ||||
|         if self.tokenizer.pad_token_id is None: | ||||
|             self.tokenizer.pad_token = self.tokenizer.eos_token | ||||
|             if self.tokenizer.eos_token is not None: | ||||
|                 logger.warning( | ||||
|                     f'Using eos_token_id {self.tokenizer.eos_token} ' | ||||
|                     'as pad_token_id.') | ||||
|                 self.tokenizer.pad_token = self.tokenizer.eos_token | ||||
|             else: | ||||
|                 from transformers.generation import GenerationConfig | ||||
|                 self.gcfg = GenerationConfig.from_pretrained(path) | ||||
|  | ||||
|                 if self.gcfg.pad_token_id is not None: | ||||
|                     logger.warning( | ||||
|                         f'Using pad_token_id {self.gcfg.pad_token_id} ' | ||||
|                         'as pad_token_id.') | ||||
|                     self.tokenizer.pad_token_id = self.gcfg.pad_token_id | ||||
|                 else: | ||||
|                     raise ValueError( | ||||
|                         'pad_token_id is not set for this tokenizer. Try to ' | ||||
|                         'set pad_token_id via passing ' | ||||
|                         '`pad_token_id={PAD_TOKEN_ID}` in model_cfg.') | ||||
|  | ||||
|     def _load_model(self, path: str, model_kwargs: dict): | ||||
|         import torch | ||||
|         from transformers import AutoModel | ||||
|         model_kwargs.setdefault('torch_dtype', torch.float16) | ||||
|         self.model = AutoModel.from_pretrained( | ||||
|             path, trust_remote_code=True, **model_kwargs) | ||||
|         self.model.eval() | ||||
|  | ||||
|     def generate(self, inputs: List[str], max_out_len: int, | ||||
|                  **kwargs) -> List[str]: | ||||
|         if isinstance(inputs, str): | ||||
|             inputs = [inputs] | ||||
|         if self.extract_pred_after_decode: | ||||
|             prompt_lens = [len(input_) for input_ in inputs] | ||||
|     def tokenize(self, inputs: str): | ||||
|         assert isinstance(inputs, str) | ||||
|         inputs = self.tokenizer( | ||||
|             inputs, return_tensors='pt', return_length=True) | ||||
|         return inputs['input_ids'].tolist() | ||||
|  | ||||
|         input_ids = self.tokenizer( | ||||
|             inputs, truncation=True, | ||||
|             max_length=self.max_seq_len - max_out_len)['input_ids'] | ||||
|         input_ids = torch.tensor(input_ids, device=self.model.device) | ||||
|         outputs = self.model.generate( | ||||
|             input_ids=input_ids, max_new_tokens=max_out_len, **kwargs) | ||||
|  | ||||
|         if not self.extract_pred_after_decode: | ||||
|             outputs = outputs[:, input_ids.shape[1]:] | ||||
|  | ||||
|         decodeds = self.tokenizer.batch_decode( | ||||
|             outputs, skip_special_tokens=True) | ||||
|         if self.extract_pred_after_decode: | ||||
|             decodeds = [ | ||||
|                 token[len_:] for token, len_ in zip(decodeds, prompt_lens) | ||||
|             ] | ||||
|  | ||||
|         return decodeds[0] | ||||
|  | ||||
|     def generate_from_template(self, templates, max_out_len: int, **kwargs): | ||||
|         """Generate completion from a list of templates. | ||||
|     def generate( | ||||
|         self, | ||||
|         inputs: Union[str, List[str]], | ||||
|         do_sample: bool = True, | ||||
|         **kwargs, | ||||
|     ): | ||||
|         """Return the chat completions in non-stream mode. | ||||
|  | ||||
|         Args: | ||||
|             templates (List[PromptType]): A list of templates. | ||||
|             max_out_len (int): The maximum length of the output. | ||||
|             inputs (Union[str, List[str]]): input texts to be completed. | ||||
|             do_sample (bool): do sampling if enabled | ||||
|         Returns: | ||||
|             (a list of/batched) text/chat completion | ||||
|         """ | ||||
|         inputs = self.parse_template(templates) | ||||
|         response = self.generate(inputs, max_out_len=max_out_len, **kwargs) | ||||
|         return response.replace( | ||||
|             self.template_parser.roles['assistant']['end'].strip(), | ||||
|             '').strip() | ||||
|         for status, chunk, _ in self.stream_generate(inputs, do_sample, | ||||
|                                                      **kwargs): | ||||
|             response = chunk | ||||
|         return response | ||||
|  | ||||
|     def stream_generate( | ||||
|         self, | ||||
|         inputs: List[str], | ||||
|         do_sample: bool = True, | ||||
|         **kwargs, | ||||
|     ): | ||||
|         """Return the chat completions in stream mode. | ||||
|  | ||||
|         Args: | ||||
|             inputs (Union[str, List[str]]): input texts to be completed. | ||||
|             do_sample (bool): do sampling if enabled | ||||
|         Returns: | ||||
|             tuple(Status, str, int): status, text/chat completion, | ||||
|             generated token number | ||||
|         """ | ||||
|         import torch | ||||
|         from torch import nn | ||||
|         with torch.no_grad(): | ||||
|             batched = True | ||||
|             if isinstance(inputs, str): | ||||
|                 inputs = [inputs] | ||||
|                 batched = False | ||||
|             inputs = self.tokenizer( | ||||
|                 inputs, padding=True, return_tensors='pt', return_length=True) | ||||
|             input_length = inputs['length'] | ||||
|             for k, v in inputs.items(): | ||||
|                 inputs[k] = v.cuda() | ||||
|             input_ids = inputs['input_ids'] | ||||
|             attention_mask = inputs['attention_mask'] | ||||
|             batch_size = input_ids.shape[0] | ||||
|             input_ids_seq_length = input_ids.shape[-1] | ||||
|             generation_config = self.model.generation_config | ||||
|             generation_config = copy.deepcopy(generation_config) | ||||
|             new_gen_params = self.update_gen_params(**kwargs) | ||||
|             generation_config.update(**new_gen_params) | ||||
|             generation_config.update(**kwargs) | ||||
|             model_kwargs = generation_config.to_dict() | ||||
|             model_kwargs['attention_mask'] = attention_mask | ||||
|             _, eos_token_id = (  # noqa: F841  # pylint: disable=W0612 | ||||
|                 generation_config.bos_token_id, | ||||
|                 generation_config.eos_token_id, | ||||
|             ) | ||||
|             if eos_token_id is None: | ||||
|                 if self.gcfg.eos_token_id is not None: | ||||
|                     eos_token_id = self.gcfg.eos_token_id | ||||
|                 else: | ||||
|                     eos_token_id = [] | ||||
|             if isinstance(eos_token_id, int): | ||||
|                 eos_token_id = [eos_token_id] | ||||
|             if self.additional_eos_token_id is not None: | ||||
|                 eos_token_id.extend(self.additional_eos_token_id) | ||||
|             eos_token_id_tensor = torch.tensor(eos_token_id).to( | ||||
|                 input_ids.device) if eos_token_id is not None else None | ||||
|             generation_config.max_length = ( | ||||
|                 generation_config.max_new_tokens + input_ids_seq_length) | ||||
|             # Set generation parameters if not already defined | ||||
|             logits_processor = self.logits_processor | ||||
|             stopping_criteria = self.stopping_criteria | ||||
|  | ||||
|             logits_processor = self.model._get_logits_processor( | ||||
|                 generation_config=generation_config, | ||||
|                 input_ids_seq_length=input_ids_seq_length, | ||||
|                 encoder_input_ids=input_ids, | ||||
|                 prefix_allowed_tokens_fn=self.prefix_allowed_tokens_fn, | ||||
|                 logits_processor=logits_processor, | ||||
|             ) | ||||
|  | ||||
|             stopping_criteria = self.model._get_stopping_criteria( | ||||
|                 generation_config=generation_config, | ||||
|                 stopping_criteria=stopping_criteria) | ||||
|             logits_warper = self.model._get_logits_warper(generation_config) | ||||
|  | ||||
|             unfinished_sequences = input_ids.new(batch_size).fill_(1) | ||||
|             scores = None | ||||
|             while True: | ||||
|                 model_inputs = self.model.prepare_inputs_for_generation( | ||||
|                     input_ids, **model_kwargs) | ||||
|                 # forward pass to get next token | ||||
|                 outputs = self.model( | ||||
|                     **model_inputs, | ||||
|                     return_dict=True, | ||||
|                     output_attentions=False, | ||||
|                     output_hidden_states=False, | ||||
|                 ) | ||||
|  | ||||
|                 next_token_logits = outputs.logits[:, -1, :] | ||||
|  | ||||
|                 # pre-process distribution | ||||
|                 next_token_scores = logits_processor(input_ids, | ||||
|                                                      next_token_logits) | ||||
|                 next_token_scores = logits_warper(input_ids, next_token_scores) | ||||
|  | ||||
|                 # sample | ||||
|                 probs = nn.functional.softmax(next_token_scores, dim=-1) | ||||
|                 if do_sample: | ||||
|                     next_tokens = torch.multinomial( | ||||
|                         probs, num_samples=1).squeeze(1) | ||||
|                 else: | ||||
|                     next_tokens = torch.argmax(probs, dim=-1) | ||||
|  | ||||
|                 # update generated ids, model inputs, | ||||
|                 # and length for next step | ||||
|                 input_ids = torch.cat([input_ids, next_tokens[:, None]], | ||||
|                                       dim=-1) | ||||
|                 model_kwargs = self.model._update_model_kwargs_for_generation(  # noqa: E501 | ||||
|                     outputs, | ||||
|                     model_kwargs, | ||||
|                     is_encoder_decoder=False) | ||||
|                 unfinished_sequences = unfinished_sequences.mul( | ||||
|                     next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne( | ||||
|                         eos_token_id_tensor.unsqueeze(1)).prod(dim=0)) | ||||
|                 output_token_ids = input_ids.cpu().tolist() | ||||
|                 for i in range(len(output_token_ids)): | ||||
|                     output_token_ids[i] = output_token_ids[i][:][ | ||||
|                         input_length[i]:] | ||||
|                     # Find the first occurrence of | ||||
|                     # an EOS token in the sequence | ||||
|                     first_eos_idx = next( | ||||
|                         (idx | ||||
|                          for idx, token_id in enumerate(output_token_ids[i]) | ||||
|                          if token_id in eos_token_id), None) | ||||
|                     # If an EOS token is found, only the previous | ||||
|                     # part of it is retained | ||||
|                     if first_eos_idx is not None: | ||||
|                         output_token_ids[i] = output_token_ids[ | ||||
|                             i][:first_eos_idx] | ||||
|  | ||||
|                 response = self.tokenizer.batch_decode(output_token_ids) | ||||
|                 # print(response) | ||||
|                 if not batched: | ||||
|                     response = response[0] | ||||
|                 yield ModelStatusCode.STREAM_ING, response, None | ||||
|                 # stop when each sentence is finished, | ||||
|                 # or if we exceed the maximum length | ||||
|                 if (unfinished_sequences.max() == 0 | ||||
|                         or stopping_criteria(input_ids, scores)): | ||||
|                     break | ||||
|             yield ModelStatusCode.END, response, None | ||||
|  | ||||
|     def stream_chat( | ||||
|         self, | ||||
|         inputs: List[dict], | ||||
|         do_sample: bool = True, | ||||
|         **kwargs, | ||||
|     ): | ||||
|         """Return the chat completions in stream mode. | ||||
|  | ||||
|         Args: | ||||
|             inputs (List[dict]): input messages to be completed. | ||||
|             do_sample (bool): do sampling if enabled | ||||
|         Returns: | ||||
|             the text/chat completion | ||||
|         """ | ||||
|         prompt = self.template_parser(inputs) | ||||
|         yield from self.stream_generate(prompt, do_sample, **kwargs) | ||||
|  | ||||
|  | ||||
| class HFTransformerCasualLM(HFTransformer): | ||||
|  | ||||
|     def _load_model(self, path: str, model_kwargs: dict): | ||||
|         import torch | ||||
|         from transformers import AutoModelForCausalLM | ||||
|         model_kwargs.setdefault('torch_dtype', torch.float16) | ||||
|         self.model = AutoModelForCausalLM.from_pretrained( | ||||
|             path, trust_remote_code=True, **model_kwargs) | ||||
|         self.model.eval() | ||||
|  | ||||
|  | ||||
| class HFTransformerChat(HFTransformerCasualLM): | ||||
|  | ||||
|     def __init__(self, template_parser=APITemplateParser, **kwargs): | ||||
|         super().__init__(template_parser=template_parser, **kwargs) | ||||
|  | ||||
|     def chat(self, | ||||
|              inputs: Union[List[dict], List[List[dict]]], | ||||
|              do_sample: bool = True, | ||||
|              **kwargs): | ||||
|         """Return the chat completions in stream mode. | ||||
|  | ||||
|         Args: | ||||
|             inputs (Union[List[dict], List[List[dict]]]): input messages to be completed. | ||||
|             do_sample (bool): do sampling if enabled | ||||
|         Returns: | ||||
|             the text/chat completion | ||||
|         """ | ||||
|         # handle batch inference with vanilla for loop | ||||
|         if isinstance(inputs[0], list): | ||||
|             resps = [] | ||||
|             for input in inputs: | ||||
|                 resps.append(self.chat(input, do_sample, **kwargs)) | ||||
|             return resps | ||||
|         prompt = self.template_parser(inputs) | ||||
|         query = prompt[-1]['content'] | ||||
|         history = prompt[:-1] | ||||
|         try: | ||||
|             response, history = self.model.chat( | ||||
|                 self.tokenizer, query, history=history) | ||||
|         except Exception as e: | ||||
|             # handle over-length input error | ||||
|             logger.warning(str(e)) | ||||
|             response = '' | ||||
|         return response | ||||
|   | ||||
| @@ -1,137 +0,0 @@ | ||||
| import dataclasses | ||||
| import os.path as osp | ||||
| import random | ||||
|  | ||||
| import lmdeploy.turbomind.chat as tm_chat | ||||
| from lmdeploy.serve.turbomind.chatbot import Chatbot, Session, get_logger | ||||
|  | ||||
| from .base_llm import BaseModel | ||||
|  | ||||
|  | ||||
| class TritonClient(Chatbot, BaseModel): | ||||
|  | ||||
|     def __init__(self, meta_template=None, **kwargs): | ||||
|         """TritonClient is a wrapper of TritonClient for LLM. | ||||
|  | ||||
|         Args: | ||||
|             model_name (str): the name of the model | ||||
|             max_out_len (int): the expected generated token numbers | ||||
|             log_level (str): log level | ||||
|         """ | ||||
|         BaseModel.__init__(self, meta_template=meta_template, path=None) | ||||
|         Chatbot.__init__(self, **kwargs) | ||||
|  | ||||
|     def generate(self, | ||||
|                  prompt: str, | ||||
|                  session_id: int = 2967, | ||||
|                  request_id: str = '', | ||||
|                  max_out_len: int = None, | ||||
|                  sequence_start: bool = True, | ||||
|                  sequence_end: bool = True, | ||||
|                  *args, | ||||
|                  **kwargs): | ||||
|         """Start a new round conversation of a session. Return the chat | ||||
|         completions in non-stream mode. | ||||
|  | ||||
|         Args: | ||||
|             session_id (int): the identical id of a session | ||||
|             prompt (str): user's prompt in this round conversation | ||||
|             request_id (str): the identical id of this round conversation | ||||
|             max_out_len (int): the expected generated token numbers | ||||
|             sequence_start (bool): start flag of a session | ||||
|             sequence_end (bool): end flag of a session | ||||
|  | ||||
|         Returns: | ||||
|             tuple(Status, str, int): status, text/chat completion, | ||||
|             generated token number | ||||
|         """ | ||||
|         assert isinstance(session_id, int), \ | ||||
|             f'INT session id is required, but got {type(session_id)}' | ||||
|  | ||||
|         logger = get_logger(log_level=self.log_level) | ||||
|         logger.info(f'session {session_id}, request_id {request_id}, ' | ||||
|                     f'max_out_len {max_out_len}') | ||||
|  | ||||
|         if self._session is None: | ||||
|             sequence_start = True | ||||
|             self._session = Session(session_id=session_id) | ||||
|         elif self._session.status == 0: | ||||
|             logger.error(f'session {session_id} has been ended. Please set ' | ||||
|                          f'`sequence_start` be True if you want to restart it') | ||||
|             return '' | ||||
|  | ||||
|         self._session.status = 1 | ||||
|         self._session.request_id = request_id | ||||
|         self._session.response = '' | ||||
|  | ||||
|         status, res, _ = None, '', 0 | ||||
|         for status, res, _ in self._stream_infer(self._session, prompt, | ||||
|                                                  max_out_len, sequence_start, | ||||
|                                                  sequence_end): | ||||
|             if status.value < 0: | ||||
|                 break | ||||
|         if status.value == 0: | ||||
|             self._session.histories = \ | ||||
|                 self._session.histories + self._session.prompt + \ | ||||
|                 self._session.response | ||||
|             return res | ||||
|         else: | ||||
|             return '' | ||||
|  | ||||
|     def generate_from_template(self, templates, max_out_len: int, **kwargs): | ||||
|         """Generate completion from a list of templates. | ||||
|  | ||||
|         Args: | ||||
|             templates (List[PromptType]): A list of templates. | ||||
|             max_out_len (int): The maximum length of the output. | ||||
|         """ | ||||
|         inputs = self.parse_template(templates) | ||||
|         response = self.generate(inputs, max_out_len=max_out_len, **kwargs) | ||||
|         # The return of tuibomind contains <eoa>, here we hard code removes it. | ||||
|         response = response.replace( | ||||
|             self.template_parser.roles['assistant']['end'].strip(), | ||||
|             '').strip() | ||||
|         return response | ||||
|  | ||||
|  | ||||
| class TurboMind(BaseModel): | ||||
|  | ||||
|     def __init__(self, | ||||
|                  path: str, | ||||
|                  max_seq_len: int = 2048, | ||||
|                  tokenizer_only: bool = False, | ||||
|                  meta_template=None, | ||||
|                  tp=1, | ||||
|                  **kwargs): | ||||
|         super().__init__(path, max_seq_len, tokenizer_only, meta_template) | ||||
|         tokenizer_model_path = osp.join(path, 'triton_models', 'tokenizer') | ||||
|         self.tokenizer = tm_chat.Tokenizer(tokenizer_model_path) | ||||
|         self.tm_model = tm_chat.tm.TurboMind( | ||||
|             path, eos_id=self.tokenizer.eos_token_id, tp=tp) | ||||
|         self.generator = self.tm_model.create_instance() | ||||
|  | ||||
|         model_name = self.tm_model.model_name | ||||
|         self.model = tm_chat.MODELS.get(model_name)( | ||||
|             capability='completion', **kwargs) | ||||
|  | ||||
|     def generate(self, prompt, **kwargs): | ||||
|         seed = random.getrandbits(64) | ||||
|         input_ids = self.tokenizer.encode(prompt) | ||||
|         gen_param = tm_chat.get_gen_param( | ||||
|             'completion', self.model.sampling_param, step=0, nth_round=1) | ||||
|  | ||||
|         response_size = 0 | ||||
|         for outputs in self.generator.stream_infer( | ||||
|                 session_id=1, | ||||
|                 input_ids=[input_ids], | ||||
|                 stream_output=False, | ||||
|                 **dataclasses.asdict(gen_param), | ||||
|                 ignore_eos=False, | ||||
|                 random_seed=seed): | ||||
|             res, tokens = outputs[0] | ||||
|             # decode res | ||||
|             response = self.tokenizer.decode( | ||||
|                 res.tolist(), offset=response_size) | ||||
|             response = tm_chat.valid_str(response) | ||||
|  | ||||
|         return response | ||||
							
								
								
									
										452
									
								
								lagent/llms/lmdepoly_wrapper.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										452
									
								
								lagent/llms/lmdepoly_wrapper.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,452 @@ | ||||
| from typing import List, Optional, Union | ||||
|  | ||||
| from lagent.llms.base_llm import BaseModel | ||||
| from lagent.schema import ModelStatusCode | ||||
| from lagent.utils.util import filter_suffix | ||||
|  | ||||
|  | ||||
| class TritonClient(BaseModel): | ||||
|     """TritonClient is a wrapper of TritonClient for LLM. | ||||
|  | ||||
|     Args: | ||||
|         tritonserver_addr (str): the address in format "ip:port" of | ||||
|             triton inference server | ||||
|         model_name (str): the name of the model | ||||
|         session_len (int): the context size | ||||
|         max_tokens (int): the expected generated token numbers | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, | ||||
|                  tritonserver_addr: str, | ||||
|                  model_name: str, | ||||
|                  session_len: int = 32768, | ||||
|                  log_level: str = 'WARNING', | ||||
|                  **kwargs): | ||||
|         super().__init__(path=None, **kwargs) | ||||
|         from lmdeploy.serve.turbomind.chatbot import Chatbot, StatusCode | ||||
|         self.state_map = { | ||||
|             StatusCode.TRITON_STREAM_END: ModelStatusCode.END, | ||||
|             StatusCode.TRITON_SERVER_ERR: ModelStatusCode.SERVER_ERR, | ||||
|             StatusCode.TRITON_SESSION_CLOSED: ModelStatusCode.SESSION_CLOSED, | ||||
|             StatusCode.TRITON_STREAM_ING: ModelStatusCode.STREAM_ING, | ||||
|             StatusCode.TRITON_SESSION_OUT_OF_LIMIT: | ||||
|             ModelStatusCode.SESSION_OUT_OF_LIMIT, | ||||
|             StatusCode.TRITON_SESSION_INVALID_ARG: | ||||
|             ModelStatusCode.SESSION_INVALID_ARG, | ||||
|             StatusCode.TRITON_SESSION_READY: ModelStatusCode.SESSION_READY | ||||
|         } | ||||
|         self.chatbot = Chatbot( | ||||
|             tritonserver_addr=tritonserver_addr, | ||||
|             model_name=model_name, | ||||
|             session_len=session_len, | ||||
|             log_level=log_level, | ||||
|             **kwargs) | ||||
|  | ||||
|     def generate(self, | ||||
|                  inputs: Union[str, List[str]], | ||||
|                  session_id: int = 2967, | ||||
|                  request_id: str = '', | ||||
|                  sequence_start: bool = True, | ||||
|                  sequence_end: bool = True, | ||||
|                  skip_special_tokens: bool = False, | ||||
|                  **kwargs): | ||||
|         """Start a new round conversation of a session. Return the chat | ||||
|         completions in non-stream mode. | ||||
|  | ||||
|         Args: | ||||
|             inputs (str, List[str]): user's prompt(s) in this round | ||||
|             session_id (int): the identical id of a session | ||||
|             request_id (str): the identical id of this round conversation | ||||
|             sequence_start (bool): start flag of a session | ||||
|             sequence_end (bool): end flag of a session | ||||
|             skip_special_tokens (bool): Whether or not to remove special tokens | ||||
|                 in the decoding. Default to be False. | ||||
|         Returns: | ||||
|             (a list of/batched) text/chat completion | ||||
|         """ | ||||
|         from lmdeploy.serve.turbomind.chatbot import Session, get_logger | ||||
|         if isinstance(inputs, str): | ||||
|             inputs = [inputs] | ||||
|         prompt = inputs | ||||
|  | ||||
|         assert isinstance(session_id, int), \ | ||||
|             f'INT session id is required, but got {type(session_id)}' | ||||
|  | ||||
|         self.chatbot.cfg = self._update_gen_params(**kwargs) | ||||
|         max_new_tokens = self.chatbot.cfg.max_new_tokens | ||||
|  | ||||
|         logger = get_logger('service.ft', log_level=self.chatbot.log_level) | ||||
|         logger.info(f'session {session_id}, request_id {request_id}, ' | ||||
|                     f'max_out_len {max_new_tokens}') | ||||
|  | ||||
|         if self.chatbot._session is None: | ||||
|             sequence_start = True | ||||
|             self.chatbot._session = Session(session_id=session_id) | ||||
|         elif self.chatbot._session.status == 0: | ||||
|             logger.error(f'session {session_id} has been ended. Please set ' | ||||
|                          f'`sequence_start` be True if you want to restart it') | ||||
|             return '' | ||||
|  | ||||
|         self.chatbot._session.status = 1 | ||||
|         self.chatbot._session.request_id = request_id | ||||
|         self.chatbot._session.response = '' | ||||
|  | ||||
|         status, res, _ = None, '', 0 | ||||
|         for status, res, _ in self.chatbot._stream_infer( | ||||
|                 self.chatbot._session, | ||||
|                 prompt, | ||||
|                 max_new_tokens, | ||||
|                 sequence_start, | ||||
|                 sequence_end, | ||||
|                 skip_special_tokens=skip_special_tokens): | ||||
|             status = self.state_map.get(status) | ||||
|             if status < ModelStatusCode.END: | ||||
|                 return '' | ||||
|             elif status == ModelStatusCode.END: | ||||
|                 self.chatbot._session.histories = ( | ||||
|                     self.chatbot._session.histories + | ||||
|                     self.chatbot._session.prompt + | ||||
|                     self.chatbot._session.response) | ||||
|                 # remove stop_words | ||||
|                 res = filter_suffix(res, self.gen_params.get('stop_words')) | ||||
|                 return res | ||||
|  | ||||
|     def stream_chat(self, | ||||
|                     inputs: List[dict], | ||||
|                     session_id: int = 2967, | ||||
|                     request_id: str = '', | ||||
|                     sequence_start: bool = True, | ||||
|                     sequence_end: bool = True, | ||||
|                     skip_special_tokens: bool = False, | ||||
|                     **kwargs): | ||||
|         """Start a new round conversation of a session. Return the chat | ||||
|         completions in stream mode. | ||||
|  | ||||
|         Args: | ||||
|             session_id (int): the identical id of a session | ||||
|             inputs (List[dict]): user's inputs in this round conversation | ||||
|             request_id (str): the identical id of this round conversation | ||||
|             sequence_start (bool): start flag of a session | ||||
|             sequence_end (bool): end flag of a session | ||||
|             skip_special_tokens (bool): Whether or not to remove special tokens | ||||
|                 in the decoding. Default to be False. | ||||
|         Returns: | ||||
|             tuple(Status, str, int): status, text/chat completion, | ||||
|             generated token number | ||||
|         """ | ||||
|         from lmdeploy.serve.turbomind.chatbot import Session, get_logger | ||||
|         assert isinstance(session_id, int), \ | ||||
|             f'INT session id is required, but got {type(session_id)}' | ||||
|  | ||||
|         self.chatbot.cfg = self._update_gen_params(**kwargs) | ||||
|         max_new_tokens = self.chatbot.cfg.max_new_tokens | ||||
|  | ||||
|         logger = get_logger('service.ft', log_level=self.chatbot.log_level) | ||||
|         logger.info(f'session {session_id}, request_id {request_id}, ' | ||||
|                     f'max_out_len {max_new_tokens}') | ||||
|  | ||||
|         if self.chatbot._session is None: | ||||
|             sequence_start = True | ||||
|             self.chatbot._session = Session(session_id=session_id) | ||||
|         elif self.chatbot._session.status == 0: | ||||
|             logger.error(f'session {session_id} has been ended. Please set ' | ||||
|                          f'`sequence_start` be True if you want to restart it') | ||||
|             return ModelStatusCode.SESSION_CLOSED, '', 0 | ||||
|  | ||||
|         self.chatbot._session.status = 1 | ||||
|         self.chatbot._session.request_id = request_id | ||||
|         self.chatbot._session.response = '' | ||||
|  | ||||
|         prompt = self.template_parser(inputs) | ||||
|         status, res, _ = None, '', 0 | ||||
|         for status, res, _ in self.chatbot._stream_infer( | ||||
|                 self.chatbot._session, | ||||
|                 prompt, | ||||
|                 max_new_tokens, | ||||
|                 sequence_start, | ||||
|                 sequence_end, | ||||
|                 skip_special_tokens=skip_special_tokens): | ||||
|             status = self.state_map.get(status) | ||||
|             # The stop symbol also appears in the output of the last STREAM_ING state. | ||||
|             res = filter_suffix(res, self.gen_params.get('stop_words')) | ||||
|             if status < ModelStatusCode.END: | ||||
|                 return status, res, _ | ||||
|             elif status == ModelStatusCode.END:  # remove stop_words | ||||
|                 self.chatbot._session.histories = ( | ||||
|                     self.chatbot._session.histories + | ||||
|                     self.chatbot._session.prompt + | ||||
|                     self.chatbot._session.response) | ||||
|                 yield status, res, _ | ||||
|                 break | ||||
|             else: | ||||
|                 yield status, res, _ | ||||
|  | ||||
|     def _update_gen_params(self, **kwargs): | ||||
|         import mmengine | ||||
|         new_gen_params = self.update_gen_params(**kwargs) | ||||
|         self.gen_params['stop_words'] = new_gen_params.pop('stop_words') | ||||
|         stop_words = self.chatbot._stop_words( | ||||
|             self.gen_params.get('stop_words')) | ||||
|         cfg = mmengine.Config( | ||||
|             dict( | ||||
|                 session_len=self.chatbot.model.session_len, | ||||
|                 stop_words=stop_words, | ||||
|                 bad_words=self.chatbot.cfg.bad_words, | ||||
|                 **new_gen_params)) | ||||
|         return cfg | ||||
|  | ||||
|  | ||||
| class LMDeployPipeline(BaseModel): | ||||
|     """ | ||||
|  | ||||
|     Args: | ||||
|         path (str): The path to the model. | ||||
|             It could be one of the following options: | ||||
|                     - i) A local directory path of a turbomind model which is | ||||
|                         converted by `lmdeploy convert` command or download | ||||
|                         from ii) and iii). | ||||
|                     - ii) The model_id of a lmdeploy-quantized model hosted | ||||
|                         inside a model repo on huggingface.co, such as | ||||
|                         "InternLM/internlm-chat-20b-4bit", | ||||
|                         "lmdeploy/llama2-chat-70b-4bit", etc. | ||||
|                     - iii) The model_id of a model hosted inside a model repo | ||||
|                         on huggingface.co, such as "internlm/internlm-chat-7b", | ||||
|                         "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat" | ||||
|                         and so on. | ||||
|         model_name (str): needed when model_path is a pytorch model on | ||||
|             huggingface.co, such as "internlm-chat-7b", | ||||
|             "Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on. | ||||
|         tp (int): tensor parallel | ||||
|         pipeline_cfg (dict): config of pipeline | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, | ||||
|                  path: str, | ||||
|                  model_name: Optional[str] = None, | ||||
|                  tp: int = 1, | ||||
|                  pipeline_cfg=dict(), | ||||
|                  **kwargs): | ||||
|  | ||||
|         super().__init__(path=path, **kwargs) | ||||
|         from lmdeploy import pipeline | ||||
|         self.model = pipeline( | ||||
|             model_path=self.path, model_name=model_name, tp=tp, **pipeline_cfg) | ||||
|  | ||||
|     def generate(self, | ||||
|                  inputs: Union[str, List[str]], | ||||
|                  do_preprocess: bool = None, | ||||
|                  skip_special_tokens: bool = False, | ||||
|                  **kwargs): | ||||
|         """Return the chat completions in non-stream mode. | ||||
|  | ||||
|         Args: | ||||
|             inputs (Union[str, List[str]]): input texts to be completed. | ||||
|             do_preprocess (bool): whether pre-process the messages. Default to | ||||
|                 True, which means chat_template will be applied. | ||||
|             skip_special_tokens (bool): Whether or not to remove special tokens | ||||
|                 in the decoding. Default to be False. | ||||
|         Returns: | ||||
|             (a list of/batched) text/chat completion | ||||
|         """ | ||||
|         from lmdeploy.messages import GenerationConfig | ||||
|  | ||||
|         batched = True | ||||
|         if isinstance(inputs, str): | ||||
|             inputs = [inputs] | ||||
|             batched = False | ||||
|         prompt = inputs | ||||
|         gen_params = self.update_gen_params(**kwargs) | ||||
|         gen_config = GenerationConfig( | ||||
|             skip_special_tokens=skip_special_tokens, **gen_params) | ||||
|         response = self.model.batch_infer( | ||||
|             prompt, gen_config=gen_config, do_preprocess=do_preprocess) | ||||
|         response = [resp.text for resp in response] | ||||
|         # remove stop_words | ||||
|         response = filter_suffix(response, self.gen_params.get('stop_words')) | ||||
|         if batched: | ||||
|             return response | ||||
|         return response[0] | ||||
|  | ||||
|  | ||||
| class LMDeployServer(BaseModel): | ||||
|     """ | ||||
|  | ||||
|     Args: | ||||
|         path (str): The path to the model. | ||||
|             It could be one of the following options: | ||||
|                 - i) A local directory path of a turbomind model which is | ||||
|                     converted by `lmdeploy convert` command or download from | ||||
|                     ii) and iii). | ||||
|                 - ii) The model_id of a lmdeploy-quantized model hosted | ||||
|                     inside a model repo on huggingface.co, such as | ||||
|                     "InternLM/internlm-chat-20b-4bit", | ||||
|                     "lmdeploy/llama2-chat-70b-4bit", etc. | ||||
|                 - iii) The model_id of a model hosted inside a model repo | ||||
|                     on huggingface.co, such as "internlm/internlm-chat-7b", | ||||
|                     "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat" | ||||
|                     and so on. | ||||
|         model_name (str): needed when model_path is a pytorch model on | ||||
|             huggingface.co, such as "internlm-chat-7b", | ||||
|             "Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on. | ||||
|         server_name (str): host ip for serving | ||||
|         server_port (int): server port | ||||
|         tp (int): tensor parallel | ||||
|         log_level (str): set log level whose value among | ||||
|             [CRITICAL, ERROR, WARNING, INFO, DEBUG] | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, | ||||
|                  path: str, | ||||
|                  model_name: Optional[str] = None, | ||||
|                  server_name: str = '0.0.0.0', | ||||
|                  server_port: int = 23333, | ||||
|                  tp: int = 1, | ||||
|                  log_level: str = 'WARNING', | ||||
|                  serve_cfg=dict(), | ||||
|                  **kwargs): | ||||
|         super().__init__(path=path, **kwargs) | ||||
|         self.model_name = model_name | ||||
|         # TODO get_logger issue in multi processing | ||||
|         import lmdeploy | ||||
|         self.client = lmdeploy.serve( | ||||
|             model_path=self.path, | ||||
|             model_name=model_name, | ||||
|             server_name=server_name, | ||||
|             server_port=server_port, | ||||
|             tp=tp, | ||||
|             log_level=log_level, | ||||
|             **serve_cfg) | ||||
|  | ||||
|     def generate(self, | ||||
|                  inputs: Union[str, List[str]], | ||||
|                  session_id: int = 2967, | ||||
|                  sequence_start: bool = True, | ||||
|                  sequence_end: bool = True, | ||||
|                  ignore_eos: bool = False, | ||||
|                  skip_special_tokens: Optional[bool] = False, | ||||
|                  timeout: int = 30, | ||||
|                  **kwargs) -> List[str]: | ||||
|         """Start a new round conversation of a session. Return the chat | ||||
|         completions in non-stream mode. | ||||
|  | ||||
|         Args: | ||||
|             inputs (str, List[str]): user's prompt(s) in this round | ||||
|             session_id (int): the identical id of a session | ||||
|             sequence_start (bool): start flag of a session | ||||
|             sequence_end (bool): end flag of a session | ||||
|             ignore_eos (bool): indicator for ignoring eos | ||||
|             skip_special_tokens (bool): Whether or not to remove special tokens | ||||
|                 in the decoding. Default to be False. | ||||
|             timeout (int): max time to wait for response | ||||
|         Returns: | ||||
|             (a list of/batched) text/chat completion | ||||
|         """ | ||||
|  | ||||
|         batched = True | ||||
|         if isinstance(inputs, str): | ||||
|             inputs = [inputs] | ||||
|             batched = False | ||||
|  | ||||
|         gen_params = self.update_gen_params(**kwargs) | ||||
|         max_new_tokens = gen_params.pop('max_new_tokens') | ||||
|         gen_params.update(max_tokens=max_new_tokens) | ||||
|  | ||||
|         resp = [''] * len(inputs) | ||||
|         for text in self.client.completions_v1( | ||||
|                 self.model_name, | ||||
|                 inputs, | ||||
|                 session_id=session_id, | ||||
|                 sequence_start=sequence_start, | ||||
|                 sequence_end=sequence_end, | ||||
|                 stream=False, | ||||
|                 ignore_eos=ignore_eos, | ||||
|                 skip_special_tokens=skip_special_tokens, | ||||
|                 timeout=timeout, | ||||
|                 **gen_params): | ||||
|             resp = [ | ||||
|                 resp[i] + item['text'] | ||||
|                 for i, item in enumerate(text['choices']) | ||||
|             ] | ||||
|         # remove stop_words | ||||
|         resp = filter_suffix(resp, self.gen_params.get('stop_words')) | ||||
|         if not batched: | ||||
|             return resp[0] | ||||
|         return resp | ||||
|  | ||||
|     def stream_chat(self, | ||||
|                     inputs: List[dict], | ||||
|                     session_id=0, | ||||
|                     sequence_start: bool = True, | ||||
|                     sequence_end: bool = True, | ||||
|                     stream: bool = True, | ||||
|                     ignore_eos: bool = False, | ||||
|                     skip_special_tokens: Optional[bool] = False, | ||||
|                     timeout: int = 30, | ||||
|                     **kwargs): | ||||
|         """Start a new round conversation of a session. Return the chat | ||||
|         completions in stream mode. | ||||
|  | ||||
|         Args: | ||||
|             session_id (int): the identical id of a session | ||||
|             inputs (List[dict]): user's inputs in this round conversation | ||||
|             sequence_start (bool): start flag of a session | ||||
|             sequence_end (bool): end flag of a session | ||||
|             stream (bool): return in a streaming format if enabled | ||||
|             ignore_eos (bool): indicator for ignoring eos | ||||
|             skip_special_tokens (bool): Whether or not to remove special tokens | ||||
|                 in the decoding. Default to be False. | ||||
|             timeout (int): max time to wait for response | ||||
|         Returns: | ||||
|             tuple(Status, str, int): status, text/chat completion, | ||||
|             generated token number | ||||
|         """ | ||||
|         gen_params = self.update_gen_params(**kwargs) | ||||
|         max_new_tokens = gen_params.pop('max_new_tokens') | ||||
|         gen_params.update(max_tokens=max_new_tokens) | ||||
|         prompt = self.template_parser(inputs) | ||||
|  | ||||
|         resp = '' | ||||
|         finished = False | ||||
|         stop_words = self.gen_params.get('stop_words') | ||||
|         for text in self.client.completions_v1( | ||||
|                 self.model_name, | ||||
|                 prompt, | ||||
|                 session_id=session_id, | ||||
|                 sequence_start=sequence_start, | ||||
|                 sequence_end=sequence_end, | ||||
|                 stream=stream, | ||||
|                 ignore_eos=ignore_eos, | ||||
|                 skip_special_tokens=skip_special_tokens, | ||||
|                 timeout=timeout, | ||||
|                 **gen_params): | ||||
|             resp += text['choices'][0]['text'] | ||||
|             if not resp: | ||||
|                 continue | ||||
|             # remove stop_words | ||||
|             for sw in stop_words: | ||||
|                 if sw in resp: | ||||
|                     resp = filter_suffix(resp, stop_words) | ||||
|                     finished = True | ||||
|                     break | ||||
|             yield ModelStatusCode.STREAM_ING, resp, None | ||||
|             if finished: | ||||
|                 break | ||||
|         yield ModelStatusCode.END, resp, None | ||||
|  | ||||
|  | ||||
| class LMDeployClient(LMDeployServer): | ||||
|     """ | ||||
|  | ||||
|     Args: | ||||
|         url (str): communicating address 'http://<ip>:<port>' of | ||||
|             api_server | ||||
|         model_name (str): needed when model_path is a pytorch model on | ||||
|             huggingface.co, such as "internlm-chat-7b", | ||||
|             "Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, url: str, model_name: str, **kwargs): | ||||
|         BaseModel.__init__(self, path=url, **kwargs) | ||||
|         from lmdeploy.serve.openai.api_client import APIClient | ||||
|         self.client = APIClient(url) | ||||
|         self.model_name = model_name | ||||
							
								
								
									
										40
									
								
								lagent/llms/meta_template.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										40
									
								
								lagent/llms/meta_template.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,40 @@ | ||||
| INTERNLM2_META = [ | ||||
|     dict( | ||||
|         role='system', | ||||
|         begin=dict( | ||||
|             with_name='<|im_start|>system name={name}\n', | ||||
|             without_name='<|im_start|>system\n', | ||||
|             name={ | ||||
|                 'interpreter': '<|interpreter|>', | ||||
|                 'plugin': '<|plugin|>', | ||||
|             }), | ||||
|         end='<|im_end|>\n', | ||||
|     ), | ||||
|     dict( | ||||
|         role='user', | ||||
|         begin=dict( | ||||
|             with_name='<|im_start|>user name={name}\n', | ||||
|             without_name='<|im_start|>user\n', | ||||
|         ), | ||||
|         end='<|im_end|>\n'), | ||||
|     dict( | ||||
|         role='assistant', | ||||
|         begin=dict( | ||||
|             with_name='<|im_start|>assistant name={name}\n', | ||||
|             without_name='<|im_start|>assistant\n', | ||||
|             name={ | ||||
|                 'interpreter': '<|interpreter|>', | ||||
|                 'plugin': '<|plugin|>', | ||||
|             }), | ||||
|         end='<|im_end|>\n'), | ||||
|     dict( | ||||
|         role='environment', | ||||
|         begin=dict( | ||||
|             with_name='<|im_start|>environment name={name}\n', | ||||
|             without_name='<|im_start|>environment\n', | ||||
|             name={ | ||||
|                 'interpreter': '<|interpreter|>', | ||||
|                 'plugin': '<|plugin|>', | ||||
|             }), | ||||
|         end='<|im_end|>\n'), | ||||
| ] | ||||
| @@ -1,7 +1,7 @@ | ||||
| import json | ||||
| import os | ||||
| import time | ||||
| # from concurrent.futures import ThreadPoolExecutor | ||||
| from concurrent.futures import ThreadPoolExecutor | ||||
| from logging import getLogger | ||||
| from threading import Lock | ||||
| from typing import Dict, List, Optional, Union | ||||
| @@ -18,9 +18,6 @@ class GPTAPI(BaseAPIModel): | ||||
|  | ||||
|     Args: | ||||
|         model_type (str): The name of OpenAI's model. | ||||
|         max_seq_len (int): The maximum allowed sequence length of a model. | ||||
|             Note that the length of prompt + generated tokens shall not exceed | ||||
|             this value. Defaults to 2048. | ||||
|         query_per_second (int): The maximum queries allowed per second | ||||
|             between two consecutive calls of the API. Defaults to 1. | ||||
|         retry (int): Number of retires if the API call fails. Defaults to 2. | ||||
| @@ -38,16 +35,14 @@ class GPTAPI(BaseAPIModel): | ||||
|             wrapping of any meta instructions. | ||||
|         openai_api_base (str): The base url of OpenAI's API. Defaults to | ||||
|             'https://api.openai.com/v1/chat/completions'. | ||||
|         temperature (float, optional): What sampling temperature to use. | ||||
|             If not None, will override the temperature in the `generate()` | ||||
|             call. Defaults to None. | ||||
|         gen_params: Default generation configuration which could be overridden | ||||
|             on the fly of generation. | ||||
|     """ | ||||
|  | ||||
|     is_api: bool = True | ||||
|  | ||||
|     def __init__(self, | ||||
|                  model_type: str = 'gpt-3.5-turbo', | ||||
|                  max_seq_len: int = 4096, | ||||
|                  query_per_second: int = 1, | ||||
|                  retry: int = 2, | ||||
|                  key: Union[str, List[str]] = 'ENV', | ||||
| @@ -58,16 +53,14 @@ class GPTAPI(BaseAPIModel): | ||||
|                      dict(role='assistant', api_role='assistant') | ||||
|                  ], | ||||
|                  openai_api_base: str = OPENAI_API_BASE, | ||||
|                  temperature: Optional[float] = None): | ||||
|  | ||||
|                  **gen_params): | ||||
|         super().__init__( | ||||
|             model_type=model_type, | ||||
|             max_seq_len=max_seq_len, | ||||
|             meta_template=meta_template, | ||||
|             query_per_second=query_per_second, | ||||
|             retry=retry) | ||||
|             retry=retry, | ||||
|             **gen_params) | ||||
|         self.logger = getLogger(__name__) | ||||
|         self.temperature = temperature | ||||
|  | ||||
|         if isinstance(key, str): | ||||
|             self.keys = [os.getenv('OPENAI_API_KEY') if key == 'ENV' else key] | ||||
| @@ -97,67 +90,57 @@ class GPTAPI(BaseAPIModel): | ||||
|             context_window = 8192 | ||||
|         self.context_window = context_window | ||||
|  | ||||
|     def generate( | ||||
|     def chat( | ||||
|         self, | ||||
|         inputs: Union[List, str], | ||||
|         max_out_len: int = 512, | ||||
|         temperature: float = 0.7, | ||||
|     ) -> List[str]: | ||||
|         """Generate results given a list of inputs. | ||||
|         inputs: Union[List[dict], List[List[dict]]], | ||||
|         **gen_params, | ||||
|     ) -> Union[str, List[str]]: | ||||
|         """Generate responses given the contexts. | ||||
|  | ||||
|         Args: | ||||
|             inputs (List[str or List]): A list of strings or PromptDicts. | ||||
|                 The PromptDict should be organized in OpenCompass' | ||||
|                 API format. | ||||
|             max_out_len (int): The maximum length of the output. | ||||
|             temperature (float): What sampling temperature to use, | ||||
|                 between 0 and 2. Higher values like 0.8 will make the output | ||||
|                 more random, while lower values like 0.2 will make it more | ||||
|                 focused and deterministic. Defaults to 0.7. | ||||
|             inputs (Union[List[dict], List[List[dict]]]): a list of messages | ||||
|                 or list of lists of messages | ||||
|             gen_params: additional generation configuration | ||||
|  | ||||
|         Returns: | ||||
|             List[str]: A list of generated strings. | ||||
|             Union[str, List[str]]: generated string(s) | ||||
|         """ | ||||
|         if self.temperature is not None: | ||||
|             temperature = self.temperature | ||||
|         return self._generate(inputs, max_out_len, temperature) | ||||
|         assert isinstance(inputs, list) | ||||
|         if 'max_tokens' in gen_params: | ||||
|             raise NotImplementedError('unsupported parameter: max_tokens') | ||||
|         gen_params = {**self.gen_params, **gen_params} | ||||
|         with ThreadPoolExecutor(max_workers=20) as executor: | ||||
|             tasks = [ | ||||
|                 executor.submit(self._chat, messages, **gen_params) | ||||
|                 for messages in ( | ||||
|                     [inputs] if isinstance(inputs[0], dict) else inputs) | ||||
|             ] | ||||
|         ret = [task.result() for task in tasks] | ||||
|         return ret[0] if isinstance(inputs[0], dict) else ret | ||||
|  | ||||
|     def _generate(self, input: str or List, max_out_len: int, | ||||
|                   temperature: float) -> str: | ||||
|         """Generate results given a list of inputs. | ||||
|     def _chat(self, messages: List[dict], **gen_params) -> str: | ||||
|         """Generate completion from a list of templates. | ||||
|  | ||||
|         Args: | ||||
|             inputs (str or List): A string or PromptDict. | ||||
|                 The PromptDict should be organized in OpenCompass' | ||||
|                 API format. | ||||
|             max_out_len (int): The maximum length of the output. | ||||
|             temperature (float): What sampling temperature to use, | ||||
|                 between 0 and 2. Higher values like 0.8 will make the output | ||||
|                 more random, while lower values like 0.2 will make it more | ||||
|                 focused and deterministic. | ||||
|             messages (List[dict]): a list of prompt dictionaries | ||||
|             gen_params: additional generation configuration | ||||
|  | ||||
|         Returns: | ||||
|             str: The generated string. | ||||
|         """ | ||||
|         assert isinstance(input, (str, list, dict)) | ||||
|  | ||||
|         if isinstance(input, str): | ||||
|             messages = [{'role': 'user', 'content': input}] | ||||
|         elif isinstance(input, dict): | ||||
|             messages = [input] | ||||
|         else: | ||||
|             messages = input | ||||
|         assert isinstance(messages, list) | ||||
|         gen_params = gen_params.copy() | ||||
|  | ||||
|         # Hold out 100 tokens due to potential errors in tiktoken calculation | ||||
|         max_out_len = min( | ||||
|             max_out_len, | ||||
|             self.context_window - self.get_token_len(str(input)) - 100) | ||||
|         if max_out_len <= 0: | ||||
|         max_tokens = min( | ||||
|             gen_params.pop('max_new_tokens'), | ||||
|             self.context_window - len(self.tokenize(str(input))) - 100) | ||||
|         if max_tokens <= 0: | ||||
|             return '' | ||||
|  | ||||
|         max_num_retries = 0 | ||||
|         while max_num_retries < self.retry: | ||||
|             self.wait() | ||||
|             self._wait() | ||||
|  | ||||
|             with Lock(): | ||||
|                 if len(self.invalid_keys) == len(self.keys): | ||||
| @@ -190,10 +173,11 @@ class GPTAPI(BaseAPIModel): | ||||
|                 data = dict( | ||||
|                     model=self.model_type, | ||||
|                     messages=messages, | ||||
|                     max_tokens=max_out_len, | ||||
|                     max_tokens=max_tokens, | ||||
|                     n=1, | ||||
|                     stop=None, | ||||
|                     temperature=temperature, | ||||
|                     stop=gen_params.pop('stop_words'), | ||||
|                     frequency_penalty=gen_params.pop('repetition_penalty'), | ||||
|                     **gen_params, | ||||
|                 ) | ||||
|                 raw_response = requests.post( | ||||
|                     self.url, headers=header, data=json.dumps(data)) | ||||
| @@ -225,18 +209,16 @@ class GPTAPI(BaseAPIModel): | ||||
|                            f'{max_num_retries} times. Check the logs for ' | ||||
|                            'details.') | ||||
|  | ||||
|     def get_token_len(self, prompt: str) -> int: | ||||
|         """Get lengths of the tokenized string. Only English and Chinese | ||||
|         characters are counted for now. Users are encouraged to override this | ||||
|         method if more accurate length is needed. | ||||
|     def tokenize(self, prompt: str) -> list: | ||||
|         """Tokenize the input prompt. | ||||
|  | ||||
|         Args: | ||||
|             prompt (str): Input string. | ||||
|  | ||||
|         Returns: | ||||
|             int: Length of the input tokens | ||||
|             list: token ids | ||||
|         """ | ||||
|         import tiktoken | ||||
|         self.tiktoken = tiktoken | ||||
|         enc = self.tiktoken.encoding_for_model(self.model_type) | ||||
|         return len(enc.encode(prompt)) | ||||
|         return enc.encode(prompt) | ||||
|   | ||||
							
								
								
									
										71
									
								
								lagent/llms/vllm_wrapper.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										71
									
								
								lagent/llms/vllm_wrapper.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,71 @@ | ||||
| from typing import List, Union | ||||
|  | ||||
| from lagent.llms.base_llm import BaseModel | ||||
| from lagent.utils.util import filter_suffix | ||||
|  | ||||
|  | ||||
| class VllmModel(BaseModel): | ||||
|     """ | ||||
|     A wrapper of vLLM model. | ||||
|  | ||||
|     Args: | ||||
|         path (str): The path to the model. | ||||
|             It could be one of the following options: | ||||
|                     - i) A local directory path of a huggingface model. | ||||
|                     - ii) The model_id of a model hosted inside a model repo | ||||
|                         on huggingface.co, such as "internlm/internlm-chat-7b", | ||||
|                         "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat" | ||||
|                         and so on. | ||||
|         tp (int): tensor parallel | ||||
|         vllm_cfg (dict): Other kwargs for vllm model initialization. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, path: str, tp: int = 1, vllm_cfg=dict(), **kwargs): | ||||
|  | ||||
|         super().__init__(path=path, **kwargs) | ||||
|         from vllm import LLM | ||||
|         self.model = LLM( | ||||
|             model=self.path, | ||||
|             trust_remote_code=True, | ||||
|             tensor_parallel_size=tp, | ||||
|             **vllm_cfg) | ||||
|  | ||||
|     def generate(self, | ||||
|                  inputs: Union[str, List[str]], | ||||
|                  do_preprocess: bool = None, | ||||
|                  skip_special_tokens: bool = False, | ||||
|                  **kwargs): | ||||
|         """Return the chat completions in non-stream mode. | ||||
|  | ||||
|         Args: | ||||
|             inputs (Union[str, List[str]]): input texts to be completed. | ||||
|             do_preprocess (bool): whether pre-process the messages. Default to | ||||
|                 True, which means chat_template will be applied. | ||||
|             skip_special_tokens (bool): Whether or not to remove special tokens | ||||
|                 in the decoding. Default to be False. | ||||
|         Returns: | ||||
|             (a list of/batched) text/chat completion | ||||
|         """ | ||||
|         from vllm import SamplingParams | ||||
|  | ||||
|         batched = True | ||||
|         if isinstance(inputs, str): | ||||
|             inputs = [inputs] | ||||
|             batched = False | ||||
|         prompt = inputs | ||||
|         gen_params = self.update_gen_params(**kwargs) | ||||
|         max_new_tokens = gen_params.pop('max_new_tokens') | ||||
|         stop_words = gen_params.pop('stop_words') | ||||
|  | ||||
|         sampling_config = SamplingParams( | ||||
|             skip_special_tokens=skip_special_tokens, | ||||
|             max_tokens=max_new_tokens, | ||||
|             stop=stop_words, | ||||
|             **gen_params) | ||||
|         response = self.model.generate(prompt, sampling_params=sampling_config) | ||||
|         response = [resp.outputs[0].text for resp in response] | ||||
|         # remove stop_words | ||||
|         response = filter_suffix(response, self.gen_params.get('stop_words')) | ||||
|         if batched: | ||||
|             return response | ||||
|         return response[0] | ||||
| @@ -1,12 +1,10 @@ | ||||
| from dataclasses import asdict, dataclass, field | ||||
| from enum import Enum | ||||
| from typing import Dict, List, Optional, Union | ||||
|  | ||||
| from lagent.utils import is_module_exist | ||||
| from enum import IntEnum | ||||
| from typing import List, Optional, Union | ||||
|  | ||||
|  | ||||
| def enum_dict_factory(inputs): | ||||
|     inputs = [(i[0], i[-1].value) if isinstance(i[-1], Enum) else i | ||||
|     inputs = [(i[0], i[-1].value) if isinstance(i[-1], IntEnum) else i | ||||
|               for i in inputs] | ||||
|     return dict(inputs) | ||||
|  | ||||
| @@ -15,7 +13,7 @@ def dataclass2dict(data): | ||||
|     return asdict(data, dict_factory=enum_dict_factory) | ||||
|  | ||||
|  | ||||
| class ActionStatusCode(int, Enum): | ||||
| class ActionStatusCode(IntEnum): | ||||
|     ING = 1 | ||||
|     SUCCESS = 0 | ||||
|     HTTP_ERROR = -1000  # http error | ||||
| @@ -23,7 +21,7 @@ class ActionStatusCode(int, Enum): | ||||
|     API_ERROR = -1002  # 不知道的API错误 | ||||
|  | ||||
|  | ||||
| class ActionValidCode(int, Enum): | ||||
| class ActionValidCode(IntEnum): | ||||
|     FINISH = 1 | ||||
|     OPEN = 0 | ||||
|     CLOSED = -1 | ||||
| @@ -33,47 +31,58 @@ class ActionValidCode(int, Enum): | ||||
|  | ||||
| @dataclass | ||||
| class ActionReturn: | ||||
|     args: Dict | ||||
|     args: Optional[dict] = None | ||||
|     url: Optional[str] = None | ||||
|     type: Optional[str] = None | ||||
|     result: Optional[str] = None | ||||
|     result: Optional[List[dict]] = None | ||||
|     errmsg: Optional[str] = None | ||||
|     state: Union[ActionStatusCode, int] = ActionStatusCode.SUCCESS | ||||
|     thought: Optional[str] = None | ||||
|     valid: Optional[ActionValidCode] = ActionValidCode.OPEN | ||||
|  | ||||
|     def format_result(self) -> str: | ||||
|         """Concatenate items in result.""" | ||||
|         result = [] | ||||
|         for item in self.result or []: | ||||
|             if item['type'] == 'text': | ||||
|                 result.append(item['content']) | ||||
|             else: | ||||
|                 result.append(f"[{item['type']}]({item['content']})") | ||||
|         result = '\n'.join(result) | ||||
|         return result | ||||
|  | ||||
| class AgentStatusCode(Enum): | ||||
|     END = 0  # end of streaming | ||||
|  | ||||
| # 需要集成int,如此asdict可以把AgentStatusCode 转换成 int | ||||
| class ModelStatusCode(IntEnum): | ||||
|     END = 0  # end of streaming 返回本次history | ||||
|     STREAM_ING = 1  # response is in streaming | ||||
|     SERVER_ERR = -1  # triton server's error | ||||
|     SESSION_CLOSED = -2  # session has been closed | ||||
|     SESSION_OUT_OF_LIMIT = -3  # request length out of limit | ||||
|     CMD = 2  # return command | ||||
|     SESSION_INVALID_ARG = -4  # invalid argument | ||||
|     SESSION_READY = 3  # session is ready for inference | ||||
|     SESSION_READY = 2  # session is ready for inference | ||||
|  | ||||
|  | ||||
| class AgentStatusCode(IntEnum): | ||||
|     END = 0  # end of streaming 返回本次history | ||||
|     STREAM_ING = 1  # response is in streaming | ||||
|     SERVER_ERR = -1  # triton server's error | ||||
|     SESSION_CLOSED = -2  # session has been closed | ||||
|     SESSION_OUT_OF_LIMIT = -3  # request length out of limit | ||||
|     SESSION_INVALID_ARG = -4  # invalid argument | ||||
|     SESSION_READY = 2  # session is ready for inference | ||||
|     PLUGIN_START = 3  # start tool | ||||
|     PLUGIN_END = 4  # finish tool | ||||
|     PLUGIN_RETURN = 5  # finish tool | ||||
|     CODING = 6  # start python | ||||
|     CODE_END = 7  # end python | ||||
|     CODE_RETURN = 8  # python return | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class AgentReturn: | ||||
|     state: Union[AgentStatusCode, int] = AgentStatusCode.END | ||||
|     actions: List[ActionReturn] = field(default_factory=list) | ||||
|     response: str = '' | ||||
|     inner_steps: List = field(default_factory=list) | ||||
|     errmsg: Optional[str] = None | ||||
|  | ||||
|  | ||||
| if is_module_exist('lmdeploy'): | ||||
|     from lmdeploy.serve.turbomind.chatbot import StatusCode | ||||
|     STATE_MAP = { | ||||
|         StatusCode.TRITON_STREAM_END: AgentStatusCode.END, | ||||
|         StatusCode.TRITON_SERVER_ERR: AgentStatusCode.SERVER_ERR, | ||||
|         StatusCode.TRITON_SESSION_CLOSED: AgentStatusCode.SESSION_CLOSED, | ||||
|         StatusCode.TRITON_STREAM_ING: AgentStatusCode.STREAM_ING, | ||||
|         StatusCode.TRITON_SESSION_OUT_OF_LIMIT: | ||||
|         AgentStatusCode.SESSION_OUT_OF_LIMIT, | ||||
|         StatusCode.TRITON_SESSION_INVALID_ARG: | ||||
|         AgentStatusCode.SESSION_INVALID_ARG, | ||||
|         StatusCode.TRITON_SESSION_READY: AgentStatusCode.SESSION_READY | ||||
|     } | ||||
| else: | ||||
|     STATE_MAP = {} | ||||
|   | ||||
							
								
								
									
										31
									
								
								lagent/utils/util.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										31
									
								
								lagent/utils/util.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,31 @@ | ||||
| from typing import List, Optional, Union | ||||
|  | ||||
|  | ||||
| def filter_suffix(response: Union[str, List[str]], | ||||
|                   suffixes: Optional[List[str]] = None) -> str: | ||||
|     """Filter response with suffixes. | ||||
|  | ||||
|     Args: | ||||
|         response (Union[str, List[str]]): generated responses by LLMs. | ||||
|         suffixes (str): a list of suffixes to be deleted. | ||||
|  | ||||
|     Return: | ||||
|         str: a clean response. | ||||
|     """ | ||||
|     if suffixes is None: | ||||
|         return response | ||||
|     batched = True | ||||
|     if isinstance(response, str): | ||||
|         response = [response] | ||||
|         batched = False | ||||
|     processed = [] | ||||
|     for resp in response: | ||||
|         for item in suffixes: | ||||
|             # if response.endswith(item): | ||||
|             #     response = response[:len(response) - len(item)] | ||||
|             if item in resp: | ||||
|                 resp = resp.split(item)[0] | ||||
|         processed.append(resp) | ||||
|     if not batched: | ||||
|         return processed[0] | ||||
|     return processed | ||||
| @@ -1,5 +1,5 @@ | ||||
| # Copyright (c) OpenMMLab. All rights reserved. | ||||
| __version__ = '0.1.2' | ||||
| __version__ = '0.2.2' | ||||
|  | ||||
|  | ||||
| def parse_version_info(version_str: str, length: int = 4) -> tuple: | ||||
|   | ||||
| @@ -1,8 +1,12 @@ | ||||
| docutils==0.16.0 | ||||
| astroid<3.0.0 | ||||
| docutils==0.18.1 | ||||
| markdown>=3.4.0 | ||||
| myst-parser | ||||
| -e git+https://github.com/open-mmlab/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme | ||||
| sphinx==4.0.2 | ||||
| myst-nb | ||||
| # -e git+https://github.com/open-mmlab/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme | ||||
| # sphinx==4.0.2 | ||||
| sphinx==6.1.0 | ||||
| sphinx-autoapi | ||||
| sphinx-rtd-theme==1.3.0 | ||||
| sphinx-tabs | ||||
| sphinx_copybutton | ||||
| sphinx_markdown_tables>=0.0.16 | ||||
|   | ||||
| @@ -1,4 +1,8 @@ | ||||
| lmdeploy | ||||
| streamlit | ||||
| google-search-results | ||||
| lmdeploy>=0.2.3 | ||||
| pillow | ||||
| python-pptx | ||||
| timeout_decorator | ||||
| torch | ||||
| transformers | ||||
| transformers>=4.34 | ||||
| vllm>=0.3.3 | ||||
|   | ||||
| @@ -1,5 +1,13 @@ | ||||
| arxiv | ||||
| distro | ||||
| func_timeout | ||||
| griffe | ||||
| json5 | ||||
| jsonschema | ||||
| jupyter | ||||
| jupyter_client | ||||
| phx-class-registry | ||||
| requests | ||||
| streamlit | ||||
| tiktoken | ||||
| typing-extensions | ||||
|   | ||||
| @@ -1,5 +1,5 @@ | ||||
| [isort] | ||||
| line_length = 79 | ||||
| line_length = 119 | ||||
| multi_line_output = 0 | ||||
| extra_standard_library = setuptools | ||||
| known_first_party = mmdet | ||||
| @@ -18,7 +18,7 @@ SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true | ||||
| [codespell] | ||||
| skip = *.ipynb | ||||
| quiet-level = 3 | ||||
| ignore-words-list = patten,nd,ty,mot,hist,formating,winn,gool,datas,wan,confids,TOOD,tood,ba,warmup,nam,DOTA,dota,conveyer | ||||
| ignore-words-list = patten,nd,ty,mot,hist,formating,winn,gool,datas,wan,confids,TOOD,tood,ba,warmup,nam,DOTA,dota,conveyer,astroid | ||||
|  | ||||
| [flake8] | ||||
| per-file-ignores = ftdp/configs/*: F401,F403,F405 | ||||
| max-line-length = 200 | ||||
|   | ||||
| @@ -1,7 +1,6 @@ | ||||
| from unittest import TestCase | ||||
|  | ||||
| from lagent.actions.builtin_actions import (FinishAction, InvalidAction, | ||||
|                                             NoAction) | ||||
| from lagent.actions.builtin_actions import FinishAction, InvalidAction, NoAction | ||||
| from lagent.schema import ActionStatusCode | ||||
|  | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user