22 Commits
v0.2.0 ... main

Author SHA1 Message Date
liujiangning30
581d9fb898 support demo with hf (#179) 2024-03-26 17:59:01 +08:00
RangiLyu
26672731e9 feat: support vllm (#177)
* feat: support vllm

* update requirements
2024-03-21 17:20:38 +08:00
liujiangning30
466d7aa24d Fix errmsg: cast dict to str (#172) 2024-03-20 14:27:04 +08:00
tackhwa
5f55e12736 fix typo "ablility " in overview.md (#175)
fix typo "ablility " in overview.md
2024-03-20 14:26:33 +08:00
liujiangning30
e16a6bfc3a Fix bug of ppt and googlescholar (#167)
* fix bug of ppt and googlescholar

* Format required parameters
2024-03-04 13:52:06 +08:00
BraisedPork
605a921878 Fix chat return of GPTAPI (#166)
fix chat return data

Co-authored-by: wangzy <wangziyi@pjlab.org.cn>
2024-02-29 13:56:10 +08:00
liujiangning30
4fd014bebf update version (#161) 2024-02-23 19:50:49 +08:00
liujiangning30
432ffaae8a fix bug caused by static model_name (#156) 2024-02-21 14:47:40 +08:00
liukuikun
a64aa599ce fix batch generate (#158) 2024-02-20 15:11:42 +08:00
liujiangning30
662011783e Feat: no_skip_speicial_token (#148)
* Feat: no_skip_speicial_token

* fix: get_logger of lmdeploy

* update lmdeploy requirement
2024-02-19 16:33:13 +08:00
loveSnowBest
3cf20f5011 support inference for pad_token & chatglm chat (#157)
* update code for chatglm

* update code

* handle batch infer for chat

* update warning for cases
2024-02-19 15:54:48 +08:00
liujiangning30
7b71988d09 max_tokens to max_new_tokens (#149)
* Fix: max_new_tokens to max_tokens

* change `max_tokens` to `max_new_tokens` in API models

* max_tokens to max_new_tokens

* inject parameter 'max_new_tokens' for examples

---------

Co-authored-by: wangzy <wangziyi@pjlab.org.cn>
2024-02-06 12:08:53 +08:00
liujiangning30
90ef5215b6 Fix: gen_config in lmdeploypipeline updated by input gen_params (#151) 2024-02-05 15:53:08 +08:00
liujiangning30
6a5447663a Fix: filter_suffix in TritonClient (#150) 2024-02-04 17:36:52 +08:00
liujiangning30
a2c23ef9dd Fix: skip start_token (#145) 2024-02-02 16:09:46 +08:00
liukuikun
3be9ec042c [Enchance] lazy import for actions (#146) 2024-02-02 15:44:27 +08:00
BraisedPork
aa5a357a34 Fix type annotation (#144)
fix type annotation

Co-authored-by: wangzy <wangziyi@pjlab.org.cn>
2024-02-02 11:19:45 +08:00
liukuikun
5650a75f3e update readme demo (#143) 2024-02-01 23:09:47 +08:00
liujiangning30
eea6e1cb56 fix bug of TritonClient (#141) 2024-02-01 20:40:37 +08:00
liujiangning30
42c6d265e1 Fix bug of LMDeployClient (#140)
* Fix bug of LMDeployClient

* fix bug of web_demo
2024-02-01 17:58:10 +08:00
BraisedPork
e20a768066 [Version] Bump v0.2.1 (#139)
bump v0.2.1

Co-authored-by: wangzy <wangziyi@pjlab.org.cn>
2024-02-01 13:14:26 +08:00
BraisedPork
60244a253a Fix docstring format of GoogleScholar (#138)
* fix docstrings

* update pre-commit-hooks

* chores

---------

Co-authored-by: wangzy <wangziyi@pjlab.org.cn>
2024-02-01 13:11:07 +08:00
33 changed files with 686 additions and 218 deletions

.gitignore (vendored)

@@ -160,3 +160,4 @@ cython_debug/
#.idea/
.vscode/
docs/*/_build/
tmp_dir/


@@ -38,11 +38,6 @@ repos:
rev: v2.2.6
hooks:
- id: codespell
- repo: https://github.com/myint/docformatter
rev: v1.7.5
hooks:
- id: docformatter
args: ["--in-place", "--wrap-descriptions", "79"]
- repo: https://github.com/asottile/pyupgrade
rev: v3.15.0
hooks:


@@ -22,7 +22,7 @@ English | [简体中文](README_zh-CN.md) | [日本語](README_ja_JP.md) | [ह
<div align="center">
[![Alt text](https://img.youtube.com/vi/YAelRLi0Zak/0.jpg)](https://www.youtube.com/watch?v=YAelRLi0Zak)
https://github.com/InternLM/lagent/assets/24622904/3242f9bf-32d2-4907-8815-e16a75a4ac0e
</div>


@@ -18,7 +18,7 @@ English | [简体中文](README_zh-CN.md) | [日本語](README_ja_JP.md) | [ह
<div align="center">
[![Alt text](https://img.youtube.com/vi/YAelRLi0Zak/0.jpg)](https://www.youtube.com/watch?v=YAelRLi0Zak)
https://github.com/InternLM/lagent/assets/24622904/cb851b31-6932-422e-a776-b1aa68f2a64f
</div>


@@ -14,7 +14,7 @@ English | [简体中文](README_zh-CN.md) | [日本語](README_ja_JP.md) | [ह
<div align="center">
[![Alt text](https://img.youtube.com/vi/YAelRLi0Zak/0.jpg)](https://www.youtube.com/watch?v=YAelRLi0Zak)
https://github.com/InternLM/lagent/assets/24622904/cb851b31-6932-422e-a776-b1aa68f2a64f
</div>


@@ -18,7 +18,7 @@ English | [简体中文](README_zh-CN.md) | [日本語](README_ja_JP.md) | [ह
<div align="center">
[![Alt text](https://img.youtube.com/vi/YAelRLi0Zak/0.jpg)](https://www.youtube.com/watch?v=YAelRLi0Zak)
https://github.com/InternLM/lagent/assets/24622904/cb851b31-6932-422e-a776-b1aa68f2a64f
</div>


@@ -18,7 +18,7 @@ English | [简体中文](README_zh-CN.md) | [日本語](README_ja_JP.md) | [ह
<div align="center">
[![Alt text](https://img.youtube.com/vi/YAelRLi0Zak/0.jpg)](https://www.youtube.com/watch?v=YAelRLi0Zak)
https://github.com/InternLM/lagent/assets/24622904/cb851b31-6932-422e-a776-b1aa68f2a64f
</div>


@@ -18,7 +18,7 @@ English | [简体中文](README_zh-CN.md) | [日本語](README_ja_JP.md) | [ह
<div align="center">
[![Alt text](https://img.youtube.com/vi/YAelRLi0Zak/0.jpg)](https://www.youtube.com/watch?v=YAelRLi0Zak)
https://github.com/InternLM/lagent/assets/24622904/cb851b31-6932-422e-a776-b1aa68f2a64f
</div>


@@ -4,7 +4,7 @@ This chapter introduces you to the framework of Lagent, and provides links to de
## What is Lagent
Lagent is an open source LLM agent framework, which enables people to efficiently turn a large language model to agent. It also provides some typical tools to enlighten the ablility of LLM, and the whole framework is shown below:
Lagent is an open source LLM agent framework, which enables people to efficiently turn a large language model to agent. It also provides some typical tools to enlighten the ability of LLM, and the whole framework is shown below:
![image](https://github.com/InternLM/lagent/assets/24351120/e104171e-4baf-43b3-8e6d-90cff1b298b6)
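
The overview above describes Lagent only in prose. As a quick orientation for readers of this diff, here is a minimal, hedged sketch of "turning a large language model into an agent", assembled from the APIs exercised in the demo files changed further down this page; the model path is a placeholder and relying on Internlm2Agent's default protocol is an assumption.

```python
# Hedged sketch pieced together from the demo code in this diff; the model
# path is a placeholder and the default protocol wiring is assumed.
from lagent.actions import ActionExecutor, ArxivSearch
from lagent.agents.internlm2_agent import Internlm2Agent
from lagent.llms import HFTransformer
from lagent.llms.meta_template import INTERNLM2_META as META

llm = HFTransformer(
    path='internlm/internlm2-chat-7b',  # placeholder model path
    meta_template=META,
    max_new_tokens=1024,
    stop_words=['<|im_end|>'])
agent = Internlm2Agent(llm=llm, max_turn=7)
# attach a plugin executor the same way the web demo does
agent._action_executor = ActionExecutor(actions=[ArxivSearch()])
for agent_return in agent.stream_chat(
        [dict(role='user', content='Find recent arxiv papers on LLM agents')]):
    print(agent_return.response)
```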


@@ -24,6 +24,7 @@ def main():
model = HFTransformer(
path=args.path,
meta_template=META,
max_new_tokens=1024,
top_p=0.8,
top_k=None,
temperature=0.1,
@@ -69,7 +70,7 @@ def main():
print('\nInternLm2', end='')
current_length = 0
last_status = None
for agent_return in chatbot.stream_chat(history, max_new_tokens=512):
for agent_return in chatbot.stream_chat(history):
status = agent_return.state
if status not in [
AgentStatusCode.STREAM_ING, AgentStatusCode.CODING,


@@ -62,7 +62,8 @@ class StreamlitUI:
def setup_sidebar(self):
"""Setup the sidebar for model and plugin selection."""
model_name = st.sidebar.selectbox('模型选择:', options=['internlm'])
# model_name = st.sidebar.selectbox('模型选择:', options=['internlm'])
model_name = st.sidebar.text_input('模型名称:', value='internlm2-chat-7b')
meta_prompt = st.sidebar.text_area('系统提示词', value=META_CN)
da_prompt = st.sidebar.text_area('数据分析提示词', value=INTERPRETER_CN)
plugin_prompt = st.sidebar.text_area('插件提示词', value=PLUGIN_CN)
@@ -113,19 +114,20 @@ class StreamlitUI:
return model_name, model, plugin_action, uploaded_file, model_ip
def init_model(self, option, ip=None):
"""Initialize the model based on the selected option."""
def init_model(self, model_name, ip=None):
"""Initialize the model based on the input model name."""
model_url = f'http://{ip}'
st.session_state['model_map'][option] = LMDeployClient(
path='internlm2-chat-20b',
st.session_state['model_map'][model_name] = LMDeployClient(
model_name=model_name,
url=model_url,
meta_template=META,
max_new_tokens=1024,
top_p=0.8,
top_k=100,
temperature=0,
repetition_penalty=1.0,
stop_words=['<|im_end|>'])
return st.session_state['model_map'][option]
return st.session_state['model_map'][model_name]
def initialize_chatbot(self, model, plugin_action):
"""Initialize the chatbot with the given model and plugin actions."""
@@ -140,7 +142,7 @@ class StreamlitUI:
belong='assistant',
end='<|action_end|>\n',
), ),
)
max_turn=7)
def render_user(self, prompt: str):
with st.chat_message('user'):
@@ -294,8 +296,7 @@ def main():
st.session_state['ui'].render_action_results(
agent_return.actions[-1])
elif (agent_return.state == AgentStatusCode.STREAM_ING
or agent_return.state == AgentStatusCode.CODING
or agent_return.state == AgentStatusCode.END):
or agent_return.state == AgentStatusCode.CODING):
# st.markdown(agent_return.response)
# 清除占位符的当前内容,并显示新内容
with st.container():


@@ -0,0 +1,332 @@
import copy
import hashlib
import json
import os
import streamlit as st
from lagent.actions import ActionExecutor, ArxivSearch, IPythonInterpreter
from lagent.agents.internlm2_agent import INTERPRETER_CN, META_CN, PLUGIN_CN, Internlm2Agent, Internlm2Protocol
from lagent.llms import HFTransformer
from lagent.llms.meta_template import INTERNLM2_META as META
from lagent.schema import AgentStatusCode
# from streamlit.logger import get_logger
class SessionState:
def init_state(self):
"""Initialize session state variables."""
st.session_state['assistant'] = []
st.session_state['user'] = []
action_list = [
ArxivSearch(),
]
st.session_state['plugin_map'] = {
action.name: action
for action in action_list
}
st.session_state['model_map'] = {}
st.session_state['model_selected'] = None
st.session_state['plugin_actions'] = set()
st.session_state['history'] = []
def clear_state(self):
"""Clear the existing session state."""
st.session_state['assistant'] = []
st.session_state['user'] = []
st.session_state['model_selected'] = None
st.session_state['file'] = set()
if 'chatbot' in st.session_state:
st.session_state['chatbot']._session_history = []
class StreamlitUI:
def __init__(self, session_state: SessionState):
self.init_streamlit()
self.session_state = session_state
def init_streamlit(self):
"""Initialize Streamlit's UI settings."""
st.set_page_config(
layout='wide',
page_title='lagent-web',
page_icon='./docs/imgs/lagent_icon.png')
st.header(':robot_face: :blue[Lagent] Web Demo ', divider='rainbow')
st.sidebar.title('模型控制')
st.session_state['file'] = set()
st.session_state['model_path'] = None
def setup_sidebar(self):
"""Setup the sidebar for model and plugin selection."""
# model_name = st.sidebar.selectbox('模型选择:', options=['internlm'])
model_name = st.sidebar.text_input('模型名称:', value='internlm2-chat-7b')
meta_prompt = st.sidebar.text_area('系统提示词', value=META_CN)
da_prompt = st.sidebar.text_area('数据分析提示词', value=INTERPRETER_CN)
plugin_prompt = st.sidebar.text_area('插件提示词', value=PLUGIN_CN)
model_path = st.sidebar.text_input(
'模型路径:', value='internlm/internlm2-chat-20b')
if model_name != st.session_state['model_selected'] or st.session_state[
'model_path'] != model_path:
st.session_state['model_path'] = model_path
model = self.init_model(model_name, model_path)
self.session_state.clear_state()
st.session_state['model_selected'] = model_name
if 'chatbot' in st.session_state:
del st.session_state['chatbot']
else:
model = st.session_state['model_map'][model_name]
plugin_name = st.sidebar.multiselect(
'插件选择',
options=list(st.session_state['plugin_map'].keys()),
default=[],
)
da_flag = st.sidebar.checkbox(
'数据分析',
value=False,
)
plugin_action = [
st.session_state['plugin_map'][name] for name in plugin_name
]
if 'chatbot' in st.session_state:
if len(plugin_action) > 0:
st.session_state['chatbot']._action_executor = ActionExecutor(
actions=plugin_action)
else:
st.session_state['chatbot']._action_executor = None
if da_flag:
st.session_state[
'chatbot']._interpreter_executor = ActionExecutor(
actions=[IPythonInterpreter()])
else:
st.session_state['chatbot']._interpreter_executor = None
st.session_state['chatbot']._protocol._meta_template = meta_prompt
st.session_state['chatbot']._protocol.plugin_prompt = plugin_prompt
st.session_state[
'chatbot']._protocol.interpreter_prompt = da_prompt
if st.sidebar.button('清空对话', key='clear'):
self.session_state.clear_state()
uploaded_file = st.sidebar.file_uploader('上传文件')
return model_name, model, plugin_action, uploaded_file, model_path
def init_model(self, model_name, path):
"""Initialize the model based on the input model name."""
st.session_state['model_map'][model_name] = HFTransformer(
path=path,
meta_template=META,
max_new_tokens=1024,
top_p=0.8,
top_k=None,
temperature=0.1,
repetition_penalty=1.0,
stop_words=['<|im_end|>'])
return st.session_state['model_map'][model_name]
def initialize_chatbot(self, model, plugin_action):
"""Initialize the chatbot with the given model and plugin actions."""
return Internlm2Agent(
llm=model,
protocol=Internlm2Protocol(
tool=dict(
begin='{start_token}{name}\n',
start_token='<|action_start|>',
name_map=dict(
plugin='<|plugin|>', interpreter='<|interpreter|>'),
belong='assistant',
end='<|action_end|>\n',
), ),
max_turn=7)
def render_user(self, prompt: str):
with st.chat_message('user'):
st.markdown(prompt)
def render_assistant(self, agent_return):
with st.chat_message('assistant'):
for action in agent_return.actions:
if (action) and (action.type != 'FinishAction'):
self.render_action(action)
st.markdown(agent_return.response)
def render_plugin_args(self, action):
action_name = action.type
args = action.args
import json
parameter_dict = dict(name=action_name, parameters=args)
parameter_str = '```json\n' + json.dumps(
parameter_dict, indent=4, ensure_ascii=False) + '\n```'
st.markdown(parameter_str)
def render_interpreter_args(self, action):
st.info(action.type)
st.markdown(action.args['text'])
def render_action(self, action):
st.markdown(action.thought)
if action.type == 'IPythonInterpreter':
self.render_interpreter_args(action)
elif action.type == 'FinishAction':
pass
else:
self.render_plugin_args(action)
self.render_action_results(action)
def render_action_results(self, action):
"""Render the results of action, including text, images, videos, and
audios."""
if (isinstance(action.result, dict)):
if 'text' in action.result:
st.markdown('```\n' + action.result['text'] + '\n```')
if 'image' in action.result:
# image_path = action.result['image']
for image_path in action.result['image']:
image_data = open(image_path, 'rb').read()
st.image(image_data, caption='Generated Image')
if 'video' in action.result:
video_data = action.result['video']
video_data = open(video_data, 'rb').read()
st.video(video_data)
if 'audio' in action.result:
audio_data = action.result['audio']
audio_data = open(audio_data, 'rb').read()
st.audio(audio_data)
elif isinstance(action.result, list):
for item in action.result:
if item['type'] == 'text':
st.markdown('```\n' + item['content'] + '\n```')
elif item['type'] == 'image':
image_data = open(item['content'], 'rb').read()
st.image(image_data, caption='Generated Image')
elif item['type'] == 'video':
video_data = open(item['content'], 'rb').read()
st.video(video_data)
elif item['type'] == 'audio':
audio_data = open(item['content'], 'rb').read()
st.audio(audio_data)
if action.errmsg:
st.error(action.errmsg)
def main():
# logger = get_logger(__name__)
# Initialize Streamlit UI and setup sidebar
if 'ui' not in st.session_state:
session_state = SessionState()
session_state.init_state()
st.session_state['ui'] = StreamlitUI(session_state)
else:
st.set_page_config(
layout='wide',
page_title='lagent-web',
page_icon='./docs/imgs/lagent_icon.png')
st.header(':robot_face: :blue[Lagent] Web Demo ', divider='rainbow')
_, model, plugin_action, uploaded_file, _ = st.session_state[
'ui'].setup_sidebar()
# Initialize chatbot if it is not already initialized
# or if the model has changed
if 'chatbot' not in st.session_state or model != st.session_state[
'chatbot']._llm:
st.session_state['chatbot'] = st.session_state[
'ui'].initialize_chatbot(model, plugin_action)
st.session_state['session_history'] = []
for prompt, agent_return in zip(st.session_state['user'],
st.session_state['assistant']):
st.session_state['ui'].render_user(prompt)
st.session_state['ui'].render_assistant(agent_return)
if user_input := st.chat_input(''):
with st.container():
st.session_state['ui'].render_user(user_input)
st.session_state['user'].append(user_input)
# Add file uploader to sidebar
if (uploaded_file
and uploaded_file.name not in st.session_state['file']):
st.session_state['file'].add(uploaded_file.name)
file_bytes = uploaded_file.read()
file_type = uploaded_file.type
if 'image' in file_type:
st.image(file_bytes, caption='Uploaded Image')
elif 'video' in file_type:
st.video(file_bytes, caption='Uploaded Video')
elif 'audio' in file_type:
st.audio(file_bytes, caption='Uploaded Audio')
# Save the file to a temporary location and get the path
postfix = uploaded_file.name.split('.')[-1]
# prefix = str(uuid.uuid4())
prefix = hashlib.md5(file_bytes).hexdigest()
filename = f'{prefix}.{postfix}'
file_path = os.path.join(root_dir, filename)
with open(file_path, 'wb') as tmpfile:
tmpfile.write(file_bytes)
file_size = os.stat(file_path).st_size / 1024 / 1024
file_size = f'{round(file_size, 2)} MB'
# st.write(f'File saved at: {file_path}')
user_input = [
dict(role='user', content=user_input),
dict(
role='user',
content=json.dumps(dict(path=file_path, size=file_size)),
name='file')
]
if isinstance(user_input, str):
user_input = [dict(role='user', content=user_input)]
st.session_state['last_status'] = AgentStatusCode.SESSION_READY
for agent_return in st.session_state['chatbot'].stream_chat(
st.session_state['session_history'] + user_input):
if agent_return.state == AgentStatusCode.PLUGIN_RETURN:
with st.container():
st.session_state['ui'].render_plugin_args(
agent_return.actions[-1])
st.session_state['ui'].render_action_results(
agent_return.actions[-1])
elif agent_return.state == AgentStatusCode.CODE_RETURN:
with st.container():
st.session_state['ui'].render_action_results(
agent_return.actions[-1])
elif (agent_return.state == AgentStatusCode.STREAM_ING
or agent_return.state == AgentStatusCode.CODING):
# st.markdown(agent_return.response)
# 清除占位符的当前内容,并显示新内容
with st.container():
if agent_return.state != st.session_state['last_status']:
st.session_state['temp'] = ''
placeholder = st.empty()
st.session_state['placeholder'] = placeholder
if isinstance(agent_return.response, dict):
action = f"\n\n {agent_return.response['name']}: \n\n"
action_input = agent_return.response['parameters']
if agent_return.response[
'name'] == 'IPythonInterpreter':
action_input = action_input['command']
response = action + action_input
else:
response = agent_return.response
st.session_state['temp'] = response
st.session_state['placeholder'].markdown(
st.session_state['temp'])
elif agent_return.state == AgentStatusCode.END:
st.session_state['session_history'] += (
user_input + agent_return.inner_steps)
agent_return = copy.deepcopy(agent_return)
agent_return.response = st.session_state['temp']
st.session_state['assistant'].append(
copy.deepcopy(agent_return))
st.session_state['last_status'] = agent_return.state
if __name__ == '__main__':
root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
root_dir = os.path.join(root_dir, 'tmp_dir')
os.makedirs(root_dir, exist_ok=True)
main()


@@ -26,6 +26,7 @@ def main():
model = HFTransformer(
path=args.path,
meta_template=META,
max_new_tokens=1024,
top_p=0.8,
top_k=None,
temperature=0.1,
@@ -51,8 +52,7 @@ def main():
history = [dict(role='user', content=prompt)]
print('\nInternLm2', end='')
current_length = 0
for status, response, _ in model.stream_chat(
history, max_new_tokens=512):
for status, response, _ in model.stream_chat(history):
print(response[current_length:], end='', flush=True)
current_length = len(response)
history.append(dict(role='assistant', content=response))


@@ -1,7 +1,5 @@
from typing import Optional, Type
import arxiv
from lagent.actions.base_action import BaseAction, tool_api
from lagent.actions.parser import BaseParser, JsonParser
from lagent.schema import ActionReturn, ActionStatusCode
@@ -37,6 +35,8 @@ Electrical Engineering, and Economics from scientific articles on arxiv.org.
:class:`dict`: article information
* content (str): a list of 3 arxiv search papers
"""
import arxiv
try:
results = arxiv.Search( # type: ignore
query[:self.max_query_len],
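
This hunk moves the `arxiv` import from module level into the tool method, in line with the lazy-import commit (#146) above. A hedged illustration of the practical effect, using only calls that appear elsewhere in this diff:

```python
# With the lazy import above, constructing the action no longer requires the
# `arxiv` package; it is imported only when the tool method actually runs.
from lagent.actions import ActionExecutor, ArxivSearch

executor = ActionExecutor(actions=[ArxivSearch()])  # no `import arxiv` happens here
```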


@@ -1,3 +1,4 @@
# flake8: noqa: E501
import json
import os
from typing import Optional, Type
@@ -103,9 +104,7 @@ class BINGMap(BaseAction):
latitude: float = 0.0,
longitude: float = 0.0,
radius: int = 5000) -> dict:
"""Search for places nearby a location, within a given radius, and \
return the results into a list. You can use either the places name or
the \\ latitude and longitude.
"""Search for places nearby a location, within a given radius, and return the results into a list. You can use either the places name or the latitude and longitude.
Args:
search_term (:class:`str`): the place name.


@@ -1,8 +1,7 @@
# flake8: noqa: E501
import os
from typing import Optional, Type
from serpapi import GoogleSearch
from lagent.actions.base_action import BaseAction, tool_api
from lagent.schema import ActionReturn, ActionStatusCode
from .parser import BaseParser, JsonParser
@@ -52,8 +51,7 @@ class GoogleScholar(BaseAction):
filter: Optional[str] = None,
as_vis: Optional[str] = None,
) -> dict:
"""Search for scholarly articles based on a query according to the
google scholar.
"""Search for scholarly articles based on a query according to the google scholar.
Args:
query (str): The query to search for.
@@ -78,6 +76,7 @@ class GoogleScholar(BaseAction):
- organic_id: a list of the organic results' ids of the three selected papers
- pub_info: publication information of selected papers
"""
from serpapi import GoogleSearch
params = {
'q': query,
'engine': 'google_scholar',
@@ -133,8 +132,7 @@ class GoogleScholar(BaseAction):
no_cache: Optional[bool] = None,
async_req: Optional[bool] = None,
output: Optional[str] = None) -> dict:
"""Search for an author's information by author's id provided by
get_author_id.
"""Search for an author's information by author's id provided by get_author_id.
Args:
author_id (str): Required. The ID of an author.
@@ -155,6 +153,7 @@ class GoogleScholar(BaseAction):
* articles: at most 3 articles by the author
* website: the author's homepage url
"""
from serpapi import GoogleSearch
params = {
'engine': 'google_scholar_author',
'author_id': author_id,
@@ -192,8 +191,7 @@ class GoogleScholar(BaseAction):
no_cache: Optional[bool] = None,
async_: Optional[bool] = None,
output: Optional[str] = 'json') -> dict:
"""Function to get MLA citation format by an identification of
organic_result's id provided by search_google_scholar.
"""Function to get MLA citation format by an identification of organic_result's id provided by search_google_scholar.
Args:
q (str): ID of an individual Google Scholar organic search result.
@@ -206,6 +204,7 @@ class GoogleScholar(BaseAction):
* authors: the authors of the article
* citation: the citation format of the article
"""
from serpapi import GoogleSearch
params = {
'q': q,
'engine': 'google_scholar_cite',
@@ -233,27 +232,22 @@ class GoogleScholar(BaseAction):
no_cache: Optional[bool] = False,
_async: Optional[bool] = False,
output: Optional[str] = 'json') -> dict:
"""The getAuthorId function is used to get the author's id by his or
her name.
"""The getAuthorId function is used to get the author's id by his or her name.
Args:
mauthors (str): Defines the author you want to search for.
hl (Optional[str]): Defines the language to use for the Google Scholar Profiles search. It's a two-letter \
language code. (e.g., 'en' for English, 'es' for Spanish, or 'fr' for French). Defaults to 'en'.
after_author (Optional[str]): Defines the next page token. It is used for retrieving the next page \
results. The parameter has the precedence over before_author parameter. Defaults to None.
before_author (Optional[str]): Defines the previous page token. It is used for retrieving the previous \
page results. Defaults to None.
no_cache (Optional[bool]): Will force SerpApi to fetch the Google Scholar Profiles results even if a \
cached version is already present. Defaults to False.
_async (Optional[bool]): Defines the way you want to submit your search to SerpApi. Defaults to False.
output (Optional[str]): Defines the final output you want. It can be set to 'json' (default) to get a \
structured JSON of the results, or 'html' to get the raw html retrieved. Defaults to 'json'.
Args:
mauthors (str): Defines the author you want to search for.
hl (Optional[str]): Defines the language to use for the Google Scholar Profiles search. It's a two-letter language code. (e.g., 'en' for English, 'es' for Spanish, or 'fr' for French). Defaults to 'en'.
after_author (Optional[str]): Defines the next page token. It is used for retrieving the next page results. The parameter has the precedence over before_author parameter. Defaults to None.
before_author (Optional[str]): Defines the previous page token. It is used for retrieving the previous page results. Defaults to None.
no_cache (Optional[bool]): Will force SerpApi to fetch the Google Scholar Profiles results even if a cached version is already present. Defaults to False.
_async (Optional[bool]): Defines the way you want to submit your search to SerpApi. Defaults to False.
output (Optional[str]): Defines the final output you want. It can be set to 'json' (default) to get a structured JSON of the results, or 'html' to get the raw html retrieved. Defaults to 'json'.
Returns:
:class:`dict`: author id
* author_id: the author_id of the author
Returns:
:class:`dict`: author id
* author_id: the author_id of the author
"""
from serpapi import GoogleSearch
params = {
'mauthors': mauthors,
'engine': 'google_scholar_profiles',


@@ -3,11 +3,7 @@ from contextlib import redirect_stdout
from dataclasses import dataclass
from enum import Enum
from io import StringIO
from typing import Optional
import json5
from IPython import InteractiveShell
from timeout_decorator import timeout as timer
from typing import Optional, Type
from ..schema import ActionReturn, ActionStatusCode
from .base_action import BaseAction, tool_api
@@ -51,10 +47,11 @@ class IPythonInteractive(BaseAction):
max_out_len: int = 2048,
use_signals: bool = True,
description: Optional[dict] = None,
parser: type[BaseParser] = JsonParser,
parser: Type[BaseParser] = JsonParser,
enable: bool = True,
):
super().__init__(description, parser, enable)
from IPython import InteractiveShell
self.timeout = timeout
self._executor = InteractiveShell()
self._highlighting = re.compile(r'\x1b\[\d{,3}(;\d{,3}){,3}m')
@@ -71,9 +68,10 @@ class IPythonInteractive(BaseAction):
Args:
command (:class:`str`): Python code snippet
timeout (:class:`Optional[int]`): timeout for execution. This
argument only works in the main thread. Defaults to ``None``.
timeout (:class:`Optional[int]`): timeout for execution.
This argument only works in the main thread. Defaults to ``None``.
"""
from timeout_decorator import timeout as timer
tool_return = ActionReturn(args={'text': command}, type=self.name)
ret = (
timer(timeout or self.timeout)(self.exec)(command)
@@ -171,6 +169,8 @@ class IPythonInteractive(BaseAction):
Returns:
:class:`str`: Python code
"""
import json5
# Match triple backtick blocks first
triple_match = re.search(r'```[^\n]*\n(.+?)```', text, re.DOTALL)
# Match single backtick blocks second


@@ -1,3 +1,4 @@
# flake8: noqa: E501
import base64
import io
import logging
@@ -10,10 +11,6 @@ import traceback
import uuid
from typing import Optional, Tuple, Type
import json5
import PIL.Image
from jupyter_client import KernelManager
from lagent.actions.base_action import BaseAction, tool_api
from lagent.actions.parser import BaseParser, JsonParser
from lagent.schema import ActionReturn, ActionStatusCode
@@ -74,6 +71,8 @@ class IPythonInterpreter(BaseAction):
@staticmethod
def start_kernel():
from jupyter_client import KernelManager
# start the kernel and manager
km = KernelManager()
km.start_kernel()
@@ -207,12 +206,7 @@ class IPythonInterpreter(BaseAction):
@tool_api
def run(self, command: str, timeout: Optional[int] = None) -> ActionReturn:
r"""When you send a message containing Python code to python, it will be
\ executed in a stateful Jupyter notebook environment. python will
respond with \ the output of the execution or time out after 60.0
seconds. The drive at '/mnt/data' \ can be used to save and persist
user files. Internet access for this session is \ disabled. Do not make
external web requests or API calls as they will fail.
r"""When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files. Internet access for this session is disabled. Do not make external web requests or API calls as they will fail.
Args:
command (:class:`str`): Python code
@@ -239,6 +233,8 @@ class IPythonInterpreter(BaseAction):
def extract_code(text):
import json5
# Match triple backtick blocks first
triple_match = re.search(r'```[^\n]*\n(.+?)```', text, re.DOTALL)
# Match single backtick blocks second
@@ -262,6 +258,7 @@ def escape_ansi(line):
def publish_image_to_local(image_base64: str, work_dir='./work_dir/tmp_dir'):
import PIL.Image
image_file = str(uuid.uuid4()) + '.png'
local_image_file = os.path.join(work_dir, image_file)


@@ -1,7 +1,5 @@
from typing import Dict, Optional, Type
from pptx import Presentation
from lagent.actions.base_action import BaseAction, tool_api
from lagent.actions.parser import BaseParser, JsonParser
@@ -10,14 +8,13 @@ THEME_MAPPING = {
'template': None,
'title': 'Title Slide',
'single': 'Title and Content',
'two': 'Tow content',
'two': 'Two Content',
}
}
class PPT(BaseAction):
"""Plugin to create ppt slides with text, paragraph, images in good looking
styles."""
"""Plugin to create ppt slides with text, paragraph, images in good looking styles."""
def __init__(self,
theme_mapping: Optional[Dict[str, dict]] = None,
@@ -34,13 +31,14 @@ class PPT(BaseAction):
"""Create a pptx file with specific themes.
Args:
theme (:class:`str`): the theme used
theme (:class:`str`): the theme used. The value should be one of ['Default'].
abs_location (:class:`str`): the ppt file's absolute location
Returns:
:class:`dict`: operation status
* status: the result of the execution
"""
from pptx import Presentation
self.location = abs_location
try:
self.pointer = Presentation(self.theme_mapping[theme]['template'])
@@ -117,6 +115,7 @@ class PPT(BaseAction):
:class:`dict`: operation status
* status: the result of the execution
"""
from PIL import Image
layout_name = self.theme_mapping[self.pointer.slide_master.name]['two']
layout = next(i for i in self.pointer.slide_master.slide_layouts
if i.name == layout_name)
@@ -124,6 +123,7 @@ class PPT(BaseAction):
ph_title, ph_body1, ph_body2 = slide.placeholders
ph_title.text = title
ph = ph_body2
image = Image.open(image)
image_pil = image.to_pil()
left = ph.left
width = ph.width


@@ -1,10 +1,9 @@
# flake8: noqa: E501
import copy
import io
from contextlib import redirect_stdout
from typing import Any, Optional, Type
from func_timeout import FunctionTimedOut, func_set_timeout
from lagent.actions.base_action import BaseAction, tool_api
from lagent.actions.parser import BaseParser, JsonParser
from lagent.schema import ActionReturn, ActionStatusCode
@@ -65,15 +64,26 @@ class PythonInterpreter(BaseAction):
@tool_api
def run(self, command: str) -> ActionReturn:
"""用来执行Python代码。代码必须是一个函数函数名必须得是 'solution',代码对应你的思考过程。代码实例格式如下:
```python # import 依赖包 import xxx def solution(): # 初始化一些变量
variable_names_with_real_meaning = xxx # 步骤一 mid_variable =
func(variable_names_with_real_meaning) # 步骤 x mid_variable =
func(mid_variable) # 最后结果 final_answer = func(mid_variable) return
final_answer ```
```python
# import 依赖包
import xxx
def solution():
# 初始化一些变量
variable_names_with_real_meaning = xxx
# 步骤一
mid_variable = func(variable_names_with_real_meaning)
# 步骤 x
mid_variable = func(mid_variable)
# 最后结果
final_answer = func(mid_variable)
return final_answer
```
Args:
command (:class:`str`): Python code snippet
"""
from func_timeout import FunctionTimedOut, func_set_timeout
self.runtime = GenericRuntime()
try:
tool_return = func_set_timeout(self.timeout)(self._call)(command)
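
The reflowed docstring above spells out the `solution()` convention that PythonInterpreter expects from generated code. A hedged, self-contained illustration of a snippet following that template (the computation itself is invented):

```python
# Hypothetical code matching the documented template: one `solution()`
# function, meaningfully named variables, stepwise computation, single return.
import math


def solution():
    # initialize variables with real meaning
    circle_radius = 3.0
    # step 1: compute the area
    circle_area = math.pi * circle_radius ** 2
    # final result
    final_answer = round(circle_area, 2)
    return final_answer
```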


@@ -1,4 +1,5 @@
from .autogpt import * # noqa: F401, F403
from .base_agent import * # noqa: F401, F403
from .internlm2_agent import * # noqa: F401, F403
from .react import * # noqa: F401, F403
from .rewoo import * # noqa: F401, F403


@@ -3,8 +3,8 @@ import logging
from copy import deepcopy
from typing import Dict, List, Optional, Union
from lagent import BaseAgent
from lagent.actions import ActionExecutor
from lagent.agents.base_agent import BaseAgent
from lagent.llms import BaseAPIModel, BaseModel
from lagent.schema import ActionReturn, ActionStatusCode, AgentReturn, AgentStatusCode, ModelStatusCode # noqa: E501
@@ -141,6 +141,12 @@ class Internlm2Protocol:
tool_name = api_info['name'].split('.')[0]
plugin['description'] = API_PREFIX.format(
tool_name=tool_name, description=plugin['description'])
# only keep required parameters
required_parameters = [
param for param in plugin['parameters']
if param['name'] in plugin['required']
]
plugin['parameters'] = required_parameters
plugin_descriptions.append(plugin)
plugin_prompt = self.plugin_prompt.format(
prompt=json.dumps(
@@ -170,13 +176,13 @@ class Internlm2Protocol:
return 'interpreter', message, dict(
name=interpreter_executor.action_names()[0],
parameters=dict(command=code))
return None, message, None
return None, message.split(self.tool['start_token'])[0], None
def format_response(self, action_return, name) -> dict:
if action_return.state == ActionStatusCode.SUCCESS:
response = action_return.format_result()
else:
response = action_return.errmsg
response = str(action_return.errmsg)
content = self.execute['begin'] + response + self.execute['end']
if self.execute.get('fallback_role'):
return dict(
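
Earlier in this file's diff, `Internlm2Protocol` now trims each plugin description down to its required parameters before building the plugin prompt. A hedged, standalone rendering of that filtering step on a made-up plugin dict:

```python
# The plugin dict below is invented for illustration; only the filtering
# mirrors the "only keep required parameters" change shown above.
plugin = {
    'name': 'ArxivSearch.get_arxiv_article_information',
    'required': ['query'],
    'parameters': [
        {'name': 'query', 'description': 'the search query'},
        {'name': 'max_results', 'description': 'optional result cap'},
    ],
}
plugin['parameters'] = [
    param for param in plugin['parameters']
    if param['name'] in plugin['required']
]
print(plugin['parameters'])  # only the required 'query' parameter is kept
```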


@@ -1,10 +1,13 @@
from lagent.utils import is_module_exist
from .base_api import BaseAPIModel
from .base_llm import BaseModel
from .huggingface import HFTransformer, HFTransformerCasualLM, HFTransformerChat
from .lmdepoly_wrapper import LMDeployClient, LMDeployPipeline, LMDeployServer
from .meta_template import INTERNLM2_META
from .openai import GPTAPI
from .vllm_wrapper import VllmModel
__all__ = ['BaseModel', 'BaseAPIModel', 'GPTAPI']
if is_module_exist('transformers'):
from .huggingface import HFTransformer, HFTransformerCasualLM # noqa: F401
__all__.extend(['HFTransformer', 'HFTransformerCasualLM'])
__all__ = [
'BaseModel', 'BaseAPIModel', 'GPTAPI', 'LMDeployClient',
'LMDeployPipeline', 'LMDeployServer', 'HFTransformer',
'HFTransformerCasualLM', 'INTERNLM2_META', 'HFTransformerChat', 'VllmModel'
]


@@ -118,7 +118,7 @@ class APITemplateParser:
return res
def _role2api_role(self, role_prompt: Dict) -> Tuple[str, bool]:
merged_prompt = self.roles[self.roles[role_prompt['role']]]
merged_prompt = self.roles[role_prompt['role']]
if merged_prompt.get('fallback_role'):
merged_prompt = self.roles[self.roles[
merged_prompt['fallback_role']]]
@@ -152,7 +152,7 @@ class BaseAPIModel(BaseModel):
template_parser: 'APITemplateParser' = APITemplateParser,
meta_template: Optional[Dict] = None,
*,
max_tokens: int = 512,
max_new_tokens: int = 512,
top_p: float = 0.8,
top_k: float = None,
temperature: float = 0.8,
@@ -169,7 +169,7 @@ class BaseAPIModel(BaseModel):
if isinstance(stop_words, str):
stop_words = [stop_words]
self.gen_params = dict(
max_tokens=max_tokens,
max_new_tokens=max_new_tokens,
top_p=top_p,
top_k=top_k,
temperature=temperature,


@@ -99,8 +99,8 @@ class BaseModel:
Args:
path (str): The path to the model.
max_seq_len (int): The maximum sequence length of the model. Defaults
to 2048.
max_new_tokens (int): Maximum length of output expected to be generated by the model. Defaults
to 512.
tokenizer_only (bool): If True, only the tokenizer will be initialized.
Defaults to False.
meta_template (list of dict, optional): The model's meta prompt
@@ -116,7 +116,7 @@ class BaseModel:
template_parser: 'LMTemplateParser' = LMTemplateParser,
meta_template: Optional[List[Dict]] = None,
*,
max_tokens: int = 512,
max_new_tokens: int = 512,
top_p: float = 0.8,
top_k: float = None,
temperature: float = 0.8,
@@ -133,7 +133,7 @@ class BaseModel:
if isinstance(stop_words, str):
stop_words = [stop_words]
self.gen_params = dict(
max_tokens=max_tokens,
max_new_tokens=max_new_tokens,
top_p=top_p,
top_k=top_k,
temperature=temperature,
@@ -183,12 +183,12 @@ class BaseModel:
Returns:
"""
if isinstance(inputs[0], list):
inputs = list()
_inputs = list()
for msg in inputs:
inputs.append(self.template_parser(msg))
_inputs.append(self.template_parser(msg))
else:
inputs = self.template_parser(inputs)
return self.generate(inputs, **gen_params)
_inputs = self.template_parser(inputs)
return self.generate(_inputs, **gen_params)
def generate_from_template(self, inputs: Union[List[dict],
List[List[dict]]],


@@ -1,9 +1,9 @@
import copy
import logging
import warnings
from typing import Dict, List, Optional, Union
from lagent.schema import ModelStatusCode
from .base_api import APITemplateParser
from .base_llm import BaseModel
logger = logging.getLogger(__name__)
@@ -19,8 +19,6 @@ class HFTransformer(BaseModel):
Args:
path (str): The name or path to HuggingFace's model.
max_seq_len (int): The maximum length of the input sequence. Defaults
to 2048.
tokenizer_path (str): The path to the tokenizer. Defaults to None.
tokenizer_kwargs (dict): Keyword arguments for the tokenizer.
Defaults to {}.
@@ -40,12 +38,20 @@ class HFTransformer(BaseModel):
tokenizer_only: bool = False,
model_kwargs: dict = dict(device_map='auto'),
meta_template: Optional[Dict] = None,
stop_words_id: Union[List[int], int] = None,
**kwargs):
super().__init__(
path=path,
tokenizer_only=tokenizer_only,
meta_template=meta_template,
**kwargs)
if isinstance(stop_words_id, int):
stop_words_id = [stop_words_id]
self.gen_params.update(stop_words_id=stop_words_id)
if self.gen_params['stop_words'] is not None and \
self.gen_params['stop_words_id'] is not None:
logger.warning('Both stop_words and stop_words_id are specified,'
'only stop_words_id will be used.')
self._load_tokenizer(
path=path,
@@ -60,7 +66,9 @@ class HFTransformer(BaseModel):
self.prefix_allowed_tokens_fn = None
stop_words_id = []
if self.gen_params.get('stop_words'):
if self.gen_params.get('stop_words_id'):
stop_words_id = self.gen_params.get('stop_words_id')
elif self.gen_params.get('stop_words'):
for sw in self.gen_params.get('stop_words'):
stop_words_id.append(self.tokenizer(sw)['input_ids'][-1])
self.additional_eos_token_id = stop_words_id
@@ -72,8 +80,27 @@ class HFTransformer(BaseModel):
tokenizer_path if tokenizer_path else path,
trust_remote_code=True,
**tokenizer_kwargs)
if self.tokenizer.pad_token_id is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
if self.tokenizer.eos_token is not None:
logger.warning(
f'Using eos_token_id {self.tokenizer.eos_token} '
'as pad_token_id.')
self.tokenizer.pad_token = self.tokenizer.eos_token
else:
from transformers.generation import GenerationConfig
self.gcfg = GenerationConfig.from_pretrained(path)
if self.gcfg.pad_token_id is not None:
logger.warning(
f'Using pad_token_id {self.gcfg.pad_token_id} '
'as pad_token_id.')
self.tokenizer.pad_token_id = self.gcfg.pad_token_id
else:
raise ValueError(
'pad_token_id is not set for this tokenizer. Try to '
'set pad_token_id via passing '
'`pad_token_id={PAD_TOKEN_ID}` in model_cfg.')
def _load_model(self, path: str, model_kwargs: dict):
import torch
@@ -130,7 +157,6 @@ class HFTransformer(BaseModel):
if isinstance(inputs, str):
inputs = [inputs]
batched = False
# import pdb; pdb.set_trace()
inputs = self.tokenizer(
inputs, padding=True, return_tensors='pt', return_length=True)
input_length = inputs['length']
@@ -151,55 +177,20 @@ class HFTransformer(BaseModel):
generation_config.bos_token_id,
generation_config.eos_token_id,
)
if eos_token_id is None:
if self.gcfg.eos_token_id is not None:
eos_token_id = self.gcfg.eos_token_id
else:
eos_token_id = []
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
if self.additional_eos_token_id is not None:
eos_token_id.extend(self.additional_eos_token_id)
eos_token_id_tensor = torch.tensor(eos_token_id).to(
input_ids.device) if eos_token_id is not None else None
has_default_max_length = (
kwargs.get('max_length') is None
and generation_config.max_length is not None)
if (has_default_max_length
and generation_config.max_new_tokens is None):
warnings.warn(
"Using `max_length`'s default"
f'({generation_config.max_length})'
'to control the generation length. '
'This behaviour is deprecated and will be removed'
' from the config in v5 of Transformers -- we'
' recommend using `max_new_tokens` to control the'
' maximum length of the generation.',
UserWarning,
)
elif generation_config.max_new_tokens is not None:
generation_config.max_length = (
generation_config.max_new_tokens + input_ids_seq_length)
if not has_default_max_length:
logger.warn( # pylint: disable=W4902
'Both `max_new_tokens`'
f'(={generation_config.max_new_tokens})'
'and `max_length`'
f'(={generation_config.max_length})'
' seem to have been set.`max_new_tokens`'
' will take precedence. Please refer to'
' the documentation for more information. '
'(https://huggingface.co/docs/transformers/main/en'
'/main_classes/text_generation)',
UserWarning,
)
if input_ids_seq_length >= generation_config.max_length:
input_ids_string = 'input_ids'
logger.warning(
f'Input length of {input_ids_string}'
f' is {input_ids_seq_length},'
' but `max_length` is set to'
f' {generation_config.max_length}.'
'This can lead to unexpected behavior.'
' You should consider increasing `max_new_tokens`.')
# 2. Set generation parameters if not already defined
generation_config.max_length = (
generation_config.max_new_tokens + input_ids_seq_length)
# Set generation parameters if not already defined
logits_processor = self.logits_processor
stopping_criteria = self.stopping_criteria
@@ -310,3 +301,39 @@ class HFTransformerCasualLM(HFTransformer):
self.model = AutoModelForCausalLM.from_pretrained(
path, trust_remote_code=True, **model_kwargs)
self.model.eval()
class HFTransformerChat(HFTransformerCasualLM):
def __init__(self, template_parser=APITemplateParser, **kwargs):
super().__init__(template_parser=template_parser, **kwargs)
def chat(self,
inputs: Union[List[dict], List[List[dict]]],
do_sample: bool = True,
**kwargs):
"""Return the chat completions in stream mode.
Args:
inputs (Union[List[dict], List[List[dict]]]): input messages to be completed.
do_sample (bool): do sampling if enabled
Returns:
the text/chat completion
"""
# handle batch inference with vanilla for loop
if isinstance(inputs[0], list):
resps = []
for input in inputs:
resps.append(self.chat(input, do_sample, **kwargs))
return resps
prompt = self.template_parser(inputs)
query = prompt[-1]['content']
history = prompt[:-1]
try:
response, history = self.model.chat(
self.tokenizer, query, history=history)
except Exception as e:
# handle over-length input error
logger.warning(str(e))
response = ''
return response
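
The changes above add a `stop_words_id` option to `HFTransformer` (which takes precedence over `stop_words`) and make `pad_token_id` resolution explicit. A hedged construction sketch based on the parameters used by the demos in this diff; the concrete token id and model path are assumptions:

```python
# Hedged sketch; 92542 stands in for whatever id '<|im_end|>' maps to and is
# an assumption, as is the placeholder model path.
from lagent.llms import HFTransformer
from lagent.llms.meta_template import INTERNLM2_META as META

model = HFTransformer(
    path='internlm/internlm2-chat-7b',
    meta_template=META,
    max_new_tokens=1024,
    top_p=0.8,
    temperature=0.1,
    repetition_penalty=1.0,
    stop_words=['<|im_end|>'],
    stop_words_id=[92542])  # per the warning above, ids win over stop_words

history = [dict(role='user', content='Hello')]
current_length = 0
for status, response, _ in model.stream_chat(history):
    print(response[current_length:], end='', flush=True)
    current_length = len(response)
```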


@@ -46,9 +46,9 @@ class TritonClient(BaseModel):
inputs: Union[str, List[str]],
session_id: int = 2967,
request_id: str = '',
max_tokens: int = 512,
sequence_start: bool = True,
sequence_end: bool = True,
skip_special_tokens: bool = False,
**kwargs):
"""Start a new round conversation of a session. Return the chat
completions in non-stream mode.
@@ -57,10 +57,10 @@ class TritonClient(BaseModel):
inputs (str, List[str]): user's prompt(s) in this round
session_id (int): the identical id of a session
request_id (str): the identical id of this round conversation
max_tokens (int): the expected generated token numbers
sequence_start (bool): start flag of a session
sequence_end (bool): end flag of a session
skip_special_tokens (bool): Whether or not to remove special tokens
in the decoding. Default to be False.
Returns:
(a list of/batched) text/chat completion
"""
@@ -72,9 +72,12 @@ class TritonClient(BaseModel):
assert isinstance(session_id, int), \
f'INT session id is required, but got {type(session_id)}'
logger = get_logger(log_level=self.chatbot.log_level)
self.chatbot.cfg = self._update_gen_params(**kwargs)
max_new_tokens = self.chatbot.cfg.max_new_tokens
logger = get_logger('service.ft', log_level=self.chatbot.log_level)
logger.info(f'session {session_id}, request_id {request_id}, '
f'max_out_len {max_tokens}')
f'max_out_len {max_new_tokens}')
if self.chatbot._session is None:
sequence_start = True
@@ -88,32 +91,33 @@ class TritonClient(BaseModel):
self.chatbot._session.request_id = request_id
self.chatbot._session.response = ''
self.chatbot.cfg = self._update_gen_params(
max_tokens=max_tokens, **kwargs)
status, res, _ = None, '', 0
for status, res, _ in self.chatbot._stream_infer(
self.chatbot._session, prompt, max_tokens, sequence_start,
sequence_end):
if status.value < 0:
break
if status.value == 0:
self.chatbot._session.histories = (
self.chatbot._session.histories +
self.chatbot._session.prompt + self.chatbot._session.response)
# remove stop_words
res = filter_suffix(res, self.gen_params.get('stop_words'))
return res
else:
return ''
self.chatbot._session,
prompt,
max_new_tokens,
sequence_start,
sequence_end,
skip_special_tokens=skip_special_tokens):
status = self.state_map.get(status)
if status < ModelStatusCode.END:
return ''
elif status == ModelStatusCode.END:
self.chatbot._session.histories = (
self.chatbot._session.histories +
self.chatbot._session.prompt +
self.chatbot._session.response)
# remove stop_words
res = filter_suffix(res, self.gen_params.get('stop_words'))
return res
def stream_chat(self,
inputs: List[dict],
session_id: int = 2967,
request_id: str = '',
max_tokens: int = 512,
sequence_start: bool = True,
sequence_end: bool = True,
skip_special_tokens: bool = False,
**kwargs):
"""Start a new round conversation of a session. Return the chat
completions in stream mode.
@@ -122,21 +126,24 @@ class TritonClient(BaseModel):
session_id (int): the identical id of a session
inputs (List[dict]): user's inputs in this round conversation
request_id (str): the identical id of this round conversation
max_tokens (int): the expected generated token numbers
sequence_start (bool): start flag of a session
sequence_end (bool): end flag of a session
skip_special_tokens (bool): Whether or not to remove special tokens
in the decoding. Default to be False.
Returns:
tuple(Status, str, int): status, text/chat completion,
generated token number
"""
from lmdeploy.serve.turbomind.chatbot import Session, StatusCode, get_logger
from lmdeploy.serve.turbomind.chatbot import Session, get_logger
assert isinstance(session_id, int), \
f'INT session id is required, but got {type(session_id)}'
logger = get_logger(log_level=self.chatbot.log_level)
self.chatbot.cfg = self._update_gen_params(**kwargs)
max_new_tokens = self.chatbot.cfg.max_new_tokens
logger = get_logger('service.ft', log_level=self.chatbot.log_level)
logger.info(f'session {session_id}, request_id {request_id}, '
f'max_out_len {max_tokens}')
f'max_out_len {max_new_tokens}')
if self.chatbot._session is None:
sequence_start = True
@@ -150,27 +157,29 @@ class TritonClient(BaseModel):
self.chatbot._session.request_id = request_id
self.chatbot._session.response = ''
self.chatbot.cfg = self._update_gen_params(
max_tokens=max_tokens, **kwargs)
prompt = self.template_parser(inputs)
status, res, _ = None, '', 0
for status, res, _ in self.chatbot._stream_infer(
self.chatbot._session, prompt, max_tokens, sequence_start,
sequence_end):
if status == StatusCode.TRITON_STREAM_END: # remove stop_words
res = filter_suffix(res, self.gen_params.get('stop_words'))
if status.value < 0:
self.chatbot._session,
prompt,
max_new_tokens,
sequence_start,
sequence_end,
skip_special_tokens=skip_special_tokens):
status = self.state_map.get(status)
# The stop symbol also appears in the output of the last STREAM_ING state.
res = filter_suffix(res, self.gen_params.get('stop_words'))
if status < ModelStatusCode.END:
return status, res, _
elif status == ModelStatusCode.END: # remove stop_words
self.chatbot._session.histories = (
self.chatbot._session.histories +
self.chatbot._session.prompt +
self.chatbot._session.response)
yield status, res, _
break
else:
yield self.state_map.get(status), res, _
if status.value == 0:
self.chatbot._session.histories = (
self.chatbot._session.histories +
self.chatbot._session.prompt + self.chatbot._session.response)
yield self.state_map.get(status), res, _
else:
return self.state_map.get(status), res, _
yield status, res, _
def _update_gen_params(self, **kwargs):
import mmengine
@@ -226,6 +235,7 @@ class LMDeployPipeline(BaseModel):
def generate(self,
inputs: Union[str, List[str]],
do_preprocess: bool = None,
skip_special_tokens: bool = False,
**kwargs):
"""Return the chat completions in non-stream mode.
@@ -233,18 +243,23 @@ class LMDeployPipeline(BaseModel):
inputs (Union[str, List[str]]): input texts to be completed.
do_preprocess (bool): whether pre-process the messages. Default to
True, which means chat_template will be applied.
skip_special_tokens (bool): Whether or not to remove special tokens
in the decoding. Default to be False.
Returns:
(a list of/batched) text/chat completion
"""
from lmdeploy.messages import GenerationConfig
batched = True
if isinstance(inputs, str):
inputs = [inputs]
batched = False
prompt = inputs
gen_params = self.update_gen_params(**kwargs)
gen_config = GenerationConfig(
skip_special_tokens=skip_special_tokens, **gen_params)
response = self.model.batch_infer(
prompt, do_preprocess=do_preprocess, **gen_params)
prompt, gen_config=gen_config, do_preprocess=do_preprocess)
response = [resp.text for resp in response]
# remove stop_words
response = filter_suffix(response, self.gen_params.get('stop_words'))
@@ -308,6 +323,7 @@ class LMDeployServer(BaseModel):
sequence_start: bool = True,
sequence_end: bool = True,
ignore_eos: bool = False,
skip_special_tokens: Optional[bool] = False,
timeout: int = 30,
**kwargs) -> List[str]:
"""Start a new round conversation of a session. Return the chat
@@ -319,6 +335,8 @@ class LMDeployServer(BaseModel):
sequence_start (bool): start flag of a session
sequence_end (bool): end flag of a session
ignore_eos (bool): indicator for ignoring eos
skip_special_tokens (bool): Whether or not to remove special tokens
in the decoding. Default to be False.
timeout (int): max time to wait for response
Returns:
(a list of/batched) text/chat completion
@@ -330,6 +348,8 @@ class LMDeployServer(BaseModel):
batched = False
gen_params = self.update_gen_params(**kwargs)
max_new_tokens = gen_params.pop('max_new_tokens')
gen_params.update(max_tokens=max_new_tokens)
resp = [''] * len(inputs)
for text in self.client.completions_v1(
@@ -340,6 +360,7 @@ class LMDeployServer(BaseModel):
sequence_end=sequence_end,
stream=False,
ignore_eos=ignore_eos,
skip_special_tokens=skip_special_tokens,
timeout=timeout,
**gen_params):
resp = [
@@ -359,6 +380,7 @@ class LMDeployServer(BaseModel):
sequence_end: bool = True,
stream: bool = True,
ignore_eos: bool = False,
skip_special_tokens: Optional[bool] = False,
timeout: int = 30,
**kwargs):
"""Start a new round conversation of a session. Return the chat
@@ -371,12 +393,16 @@ class LMDeployServer(BaseModel):
sequence_end (bool): end flag of a session
stream (bool): return in a streaming format if enabled
ignore_eos (bool): indicator for ignoring eos
skip_special_tokens (bool): Whether or not to remove special tokens
in the decoding. Default to be False.
timeout (int): max time to wait for response
Returns:
tuple(Status, str, int): status, text/chat completion,
generated token number
"""
gen_params = self.update_gen_params(**kwargs)
max_new_tokens = gen_params.pop('max_new_tokens')
gen_params.update(max_tokens=max_new_tokens)
prompt = self.template_parser(inputs)
resp = ''
@@ -390,6 +416,7 @@ class LMDeployServer(BaseModel):
sequence_end=sequence_end,
stream=stream,
ignore_eos=ignore_eos,
skip_special_tokens=skip_special_tokens,
timeout=timeout,
**gen_params):
resp += text['choices'][0]['text']
@@ -411,12 +438,15 @@ class LMDeployClient(LMDeployServer):
"""
Args:
path (str): The path to the model.
url (str): communicating address 'http://<ip>:<port>' of
api_server
model_name (str): needed when model_path is a pytorch model on
huggingface.co, such as "internlm-chat-7b",
"Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on.
"""
def __init__(self, path: str, url: str, **kwargs):
BaseModel.__init__(self, path=path, **kwargs)
def __init__(self, url: str, model_name: str, **kwargs):
BaseModel.__init__(self, path=url, **kwargs)
from lmdeploy.serve.openai.api_client import APIClient
self.client = APIClient(url)
self.model_name = model_name
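
`LMDeployClient` now takes the serving `url` plus a `model_name` instead of a local `path`. A hedged sketch mirroring the call made in the web-demo diff above; the URL and model name are placeholders for a running LMDeploy api_server:

```python
# Hedged sketch of the reworked client signature; URL and model name are
# placeholders, the remaining generation settings mirror the web demo above.
from lagent.llms import LMDeployClient
from lagent.llms.meta_template import INTERNLM2_META as META

model = LMDeployClient(
    model_name='internlm2-chat-20b',
    url='http://0.0.0.0:23333',
    meta_template=META,
    max_new_tokens=1024,
    top_p=0.8,
    top_k=100,
    temperature=0,
    repetition_penalty=1.0,
    stop_words=['<|im_end|>'])
```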


@@ -1,7 +1,7 @@
import json
import os
import time
from concurrent.futures import ThreadPoolExecutor, wait
from concurrent.futures import ThreadPoolExecutor
from logging import getLogger
from threading import Lock
from typing import Dict, List, Optional, Union
@@ -106,15 +106,15 @@ class GPTAPI(BaseAPIModel):
Union[str, List[str]]: generated string(s)
"""
assert isinstance(inputs, list)
if isinstance(inputs[0], dict):
inputs = [inputs]
if 'max_tokens' in gen_params:
raise NotImplementedError('unsupported parameter: max_tokens')
gen_params = {**self.gen_params, **gen_params}
with ThreadPoolExecutor(max_workers=20) as executor:
tasks = [
executor.submit(self._chat, messages, **gen_params)
for messages in inputs
for messages in (
[inputs] if isinstance(inputs[0], dict) else inputs)
]
wait(tasks)
ret = [task.result() for task in tasks]
return ret[0] if isinstance(inputs[0], dict) else ret
@@ -133,7 +133,7 @@ class GPTAPI(BaseAPIModel):
# Hold out 100 tokens due to potential errors in tiktoken calculation
max_tokens = min(
gen_params.pop('max_tokens'),
gen_params.pop('max_new_tokens'),
self.context_window - len(self.tokenize(str(input))) - 100)
if max_tokens <= 0:
return ''
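
With this change `GPTAPI` rejects the old `max_tokens` argument at call time and derives the request budget from `max_new_tokens` instead. A hedged usage sketch; the `model_type` value and the environment-provided API key are assumptions beyond what the hunks show:

```python
# Hedged sketch: `max_new_tokens` replaces `max_tokens` per the hunks above;
# the model_type value and API-key setup are assumptions.
from lagent.llms import GPTAPI

llm = GPTAPI(model_type='gpt-3.5-turbo', max_new_tokens=512)
print(llm.chat([dict(role='user', content='Summarize what Lagent is.')]))
```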


@@ -0,0 +1,71 @@
from typing import List, Union
from lagent.llms.base_llm import BaseModel
from lagent.utils.util import filter_suffix
class VllmModel(BaseModel):
"""
A wrapper of vLLM model.
Args:
path (str): The path to the model.
It could be one of the following options:
- i) A local directory path of a huggingface model.
- ii) The model_id of a model hosted inside a model repo
on huggingface.co, such as "internlm/internlm-chat-7b",
"Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
and so on.
tp (int): tensor parallel
vllm_cfg (dict): Other kwargs for vllm model initialization.
"""
def __init__(self, path: str, tp: int = 1, vllm_cfg=dict(), **kwargs):
super().__init__(path=path, **kwargs)
from vllm import LLM
self.model = LLM(
model=self.path,
trust_remote_code=True,
tensor_parallel_size=tp,
**vllm_cfg)
def generate(self,
inputs: Union[str, List[str]],
do_preprocess: bool = None,
skip_special_tokens: bool = False,
**kwargs):
"""Return the chat completions in non-stream mode.
Args:
inputs (Union[str, List[str]]): input texts to be completed.
do_preprocess (bool): whether pre-process the messages. Default to
True, which means chat_template will be applied.
skip_special_tokens (bool): Whether or not to remove special tokens
in the decoding. Default to be False.
Returns:
(a list of/batched) text/chat completion
"""
from vllm import SamplingParams
batched = True
if isinstance(inputs, str):
inputs = [inputs]
batched = False
prompt = inputs
gen_params = self.update_gen_params(**kwargs)
max_new_tokens = gen_params.pop('max_new_tokens')
stop_words = gen_params.pop('stop_words')
sampling_config = SamplingParams(
skip_special_tokens=skip_special_tokens,
max_tokens=max_new_tokens,
stop=stop_words,
**gen_params)
response = self.model.generate(prompt, sampling_params=sampling_config)
response = [resp.outputs[0].text for resp in response]
# remove stop_words
response = filter_suffix(response, self.gen_params.get('stop_words'))
if batched:
return response
return response[0]
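
The new `VllmModel` wrapper above mirrors the LMDeploy pipeline's `generate` interface. A hedged usage sketch, using only the constructor and `generate` signature defined in this file; the model path and generation values are placeholders:

```python
# Hedged sketch of the new wrapper; requires `vllm` to be installed and uses
# a placeholder model path.
from lagent.llms import VllmModel

model = VllmModel(
    path='internlm/internlm2-chat-7b',
    tp=1,
    max_new_tokens=512,
    top_p=0.8,
    stop_words=['<|im_end|>'])
print(model.generate('Hello, world'))                 # str in -> str out
print(model.generate(['Hi there', 'How are you?']))   # batched in -> list out
```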


@@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
__version__ = '0.2.0'
__version__ = '0.2.2'
def parse_version_info(version_str: str, length: int = 4) -> tuple:


@@ -1,4 +1,8 @@
lmdeploy
streamlit
google-search-results
lmdeploy>=0.2.3
pillow
python-pptx
timeout_decorator
torch
transformers
transformers>=4.34
vllm>=0.3.3


@@ -1,16 +1,13 @@
arxiv
distro
func_timeout
google-search-results
griffe
json5
jsonschema
jupyter
jupyter_client
phx-class-registry
pillow
python-pptx
requests
streamlit
tiktoken
timeout_decorator
typing-extensions


@@ -21,5 +21,4 @@ quiet-level = 3
ignore-words-list = patten,nd,ty,mot,hist,formating,winn,gool,datas,wan,confids,TOOD,tood,ba,warmup,nam,DOTA,dota,conveyer,astroid
[flake8]
per-file-ignores = ftdp/configs/*: F401,F403,F405
max-line-length = 200