Compare commits
20 Commits

| Author | SHA1 | Date |
|---|---|---|
|  | 581d9fb898 |  |
|  | 26672731e9 |  |
|  | 466d7aa24d |  |
|  | 5f55e12736 |  |
|  | e16a6bfc3a |  |
|  | 605a921878 |  |
|  | 4fd014bebf |  |
|  | 432ffaae8a |  |
|  | a64aa599ce |  |
|  | 662011783e |  |
|  | 3cf20f5011 |  |
|  | 7b71988d09 |  |
|  | 90ef5215b6 |  |
|  | 6a5447663a |  |
|  | a2c23ef9dd |  |
|  | 3be9ec042c |  |
|  | aa5a357a34 |  |
|  | 5650a75f3e |  |
|  | eea6e1cb56 |  |
|  | 42c6d265e1 |  |
.gitignore (vendored, 1 change)
@@ -160,3 +160,4 @@ cython_debug/
#.idea/
.vscode/
docs/*/_build/
tmp_dir/
@@ -22,7 +22,7 @@ English | [简体中文](README_zh-CN.md) | [日本語](README_ja_JP.md) | [ह

<div align="center">

[](https://www.youtube.com/watch?v=YAelRLi0Zak)
https://github.com/InternLM/lagent/assets/24622904/3242f9bf-32d2-4907-8815-e16a75a4ac0e

</div>

@@ -18,7 +18,7 @@ English | [简体中文](README_zh-CN.md) | [日本語](README_ja_JP.md) | [ह

<div align="center">

[](https://www.youtube.com/watch?v=YAelRLi0Zak)
https://github.com/InternLM/lagent/assets/24622904/cb851b31-6932-422e-a776-b1aa68f2a64f

</div>

@@ -14,7 +14,7 @@ English | [简体中文](README_zh-CN.md) | [日本語](README_ja_JP.md) | [ह

<div align="center">

[](https://www.youtube.com/watch?v=YAelRLi0Zak)
https://github.com/InternLM/lagent/assets/24622904/cb851b31-6932-422e-a776-b1aa68f2a64f

</div>

@@ -18,7 +18,7 @@ English | [简体中文](README_zh-CN.md) | [日本語](README_ja_JP.md) | [ह

<div align="center">

[](https://www.youtube.com/watch?v=YAelRLi0Zak)
https://github.com/InternLM/lagent/assets/24622904/cb851b31-6932-422e-a776-b1aa68f2a64f

</div>

@@ -18,7 +18,7 @@ English | [简体中文](README_zh-CN.md) | [日本語](README_ja_JP.md) | [ह

<div align="center">

[](https://www.youtube.com/watch?v=YAelRLi0Zak)
https://github.com/InternLM/lagent/assets/24622904/cb851b31-6932-422e-a776-b1aa68f2a64f

</div>

@@ -18,7 +18,7 @@ English | [简体中文](README_zh-CN.md) | [日本語](README_ja_JP.md) | [ह

<div align="center">

[](https://www.youtube.com/watch?v=YAelRLi0Zak)
https://github.com/InternLM/lagent/assets/24622904/cb851b31-6932-422e-a776-b1aa68f2a64f

</div>
@@ -4,7 +4,7 @@ This chapter introduces you to the framework of Lagent, and provides links to de

## What is Lagent

Lagent is an open source LLM agent framework, which enables people to efficiently turn a large language model to agent. It also provides some typical tools to enlighten the ablility of LLM, and the whole framework is shown below:
Lagent is an open source LLM agent framework, which enables people to efficiently turn a large language model to agent. It also provides some typical tools to enlighten the ability of LLM, and the whole framework is shown below:
@@ -24,6 +24,7 @@ def main():
model = HFTransformer(
path=args.path,
meta_template=META,
max_new_tokens=1024,
top_p=0.8,
top_k=None,
temperature=0.1,
@@ -69,7 +70,7 @@ def main():
print('\nInternLm2:', end='')
current_length = 0
last_status = None
for agent_return in chatbot.stream_chat(history, max_new_tokens=512):
for agent_return in chatbot.stream_chat(history):
status = agent_return.state
if status not in [
AgentStatusCode.STREAM_ING, AgentStatusCode.CODING,
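These hunks move generation-length control out of the per-call `stream_chat` and into the model constructor (`max_new_tokens=1024`). A minimal sketch of the resulting CLI loop, assuming `Internlm2Agent`'s other arguments can be left at their defaults (the model path is a placeholder):

```python
from lagent.agents.internlm2_agent import Internlm2Agent
from lagent.llms import HFTransformer
from lagent.llms.meta_template import INTERNLM2_META as META
from lagent.schema import AgentStatusCode

# Generation length is now fixed at construction time ...
model = HFTransformer(
    path='internlm/internlm2-chat-7b',  # placeholder model path
    meta_template=META,
    max_new_tokens=1024,
    top_p=0.8,
    top_k=None,
    temperature=0.1)
chatbot = Internlm2Agent(llm=model)  # protocol/action settings left at defaults (assumption)

# ... so stream_chat is called with the history only.
history = [dict(role='user', content='Hello')]
for agent_return in chatbot.stream_chat(history):
    if agent_return.state == AgentStatusCode.END:
        print(agent_return.response)
```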
@@ -62,7 +62,8 @@ class StreamlitUI:

def setup_sidebar(self):
"""Setup the sidebar for model and plugin selection."""
model_name = st.sidebar.selectbox('模型选择:', options=['internlm'])
# model_name = st.sidebar.selectbox('模型选择:', options=['internlm'])
model_name = st.sidebar.text_input('模型名称:', value='internlm2-chat-7b')
meta_prompt = st.sidebar.text_area('系统提示词', value=META_CN)
da_prompt = st.sidebar.text_area('数据分析提示词', value=INTERPRETER_CN)
plugin_prompt = st.sidebar.text_area('插件提示词', value=PLUGIN_CN)
@@ -113,19 +114,20 @@ class StreamlitUI:

return model_name, model, plugin_action, uploaded_file, model_ip

def init_model(self, option, ip=None):
"""Initialize the model based on the selected option."""
def init_model(self, model_name, ip=None):
"""Initialize the model based on the input model name."""
model_url = f'http://{ip}'
st.session_state['model_map'][option] = LMDeployClient(
path='internlm2-chat-20b',
st.session_state['model_map'][model_name] = LMDeployClient(
model_name=model_name,
url=model_url,
meta_template=META,
max_new_tokens=1024,
top_p=0.8,
top_k=100,
temperature=0,
repetition_penalty=1.0,
stop_words=['<|im_end|>'])
return st.session_state['model_map'][option]
return st.session_state['model_map'][model_name]

def initialize_chatbot(self, model, plugin_action):
"""Initialize the chatbot with the given model and plugin actions."""
@@ -140,7 +142,7 @@ class StreamlitUI:
belong='assistant',
end='<|action_end|>\n',
), ),
)
max_turn=7)

def render_user(self, prompt: str):
with st.chat_message('user'):
@@ -294,8 +296,7 @@ def main():
st.session_state['ui'].render_action_results(
agent_return.actions[-1])
elif (agent_return.state == AgentStatusCode.STREAM_ING
or agent_return.state == AgentStatusCode.CODING
or agent_return.state == AgentStatusCode.END):
or agent_return.state == AgentStatusCode.CODING):
# st.markdown(agent_return.response)
# 清除占位符的当前内容,并显示新内容
with st.container():
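For reference, a minimal sketch of the updated client construction shown in this hunk: `init_model` now receives a model name and builds an `LMDeployClient` from a name plus server URL rather than a hard-coded path (the address below is a placeholder for a running LMDeploy api_server):

```python
from lagent.llms import LMDeployClient
from lagent.llms.meta_template import INTERNLM2_META as META

model = LMDeployClient(
    model_name='internlm2-chat-7b',  # name served by the api_server
    url='http://127.0.0.1:23333',    # placeholder address
    meta_template=META,
    max_new_tokens=1024,
    top_p=0.8,
    top_k=100,
    temperature=0,
    repetition_penalty=1.0,
    stop_words=['<|im_end|>'])
```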
examples/internlm2_agent_web_demo_hf.py (new file, 332 lines)
@@ -0,0 +1,332 @@
|
||||
import copy
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
|
||||
import streamlit as st
|
||||
|
||||
from lagent.actions import ActionExecutor, ArxivSearch, IPythonInterpreter
|
||||
from lagent.agents.internlm2_agent import INTERPRETER_CN, META_CN, PLUGIN_CN, Internlm2Agent, Internlm2Protocol
|
||||
from lagent.llms import HFTransformer
|
||||
from lagent.llms.meta_template import INTERNLM2_META as META
|
||||
from lagent.schema import AgentStatusCode
|
||||
|
||||
# from streamlit.logger import get_logger
|
||||
|
||||
|
||||
class SessionState:
|
||||
|
||||
def init_state(self):
|
||||
"""Initialize session state variables."""
|
||||
st.session_state['assistant'] = []
|
||||
st.session_state['user'] = []
|
||||
|
||||
action_list = [
|
||||
ArxivSearch(),
|
||||
]
|
||||
st.session_state['plugin_map'] = {
|
||||
action.name: action
|
||||
for action in action_list
|
||||
}
|
||||
st.session_state['model_map'] = {}
|
||||
st.session_state['model_selected'] = None
|
||||
st.session_state['plugin_actions'] = set()
|
||||
st.session_state['history'] = []
|
||||
|
||||
def clear_state(self):
|
||||
"""Clear the existing session state."""
|
||||
st.session_state['assistant'] = []
|
||||
st.session_state['user'] = []
|
||||
st.session_state['model_selected'] = None
|
||||
st.session_state['file'] = set()
|
||||
if 'chatbot' in st.session_state:
|
||||
st.session_state['chatbot']._session_history = []
|
||||
|
||||
|
||||
class StreamlitUI:
|
||||
|
||||
def __init__(self, session_state: SessionState):
|
||||
self.init_streamlit()
|
||||
self.session_state = session_state
|
||||
|
||||
def init_streamlit(self):
|
||||
"""Initialize Streamlit's UI settings."""
|
||||
st.set_page_config(
|
||||
layout='wide',
|
||||
page_title='lagent-web',
|
||||
page_icon='./docs/imgs/lagent_icon.png')
|
||||
st.header(':robot_face: :blue[Lagent] Web Demo ', divider='rainbow')
|
||||
st.sidebar.title('模型控制')
|
||||
st.session_state['file'] = set()
|
||||
st.session_state['model_path'] = None
|
||||
|
||||
def setup_sidebar(self):
|
||||
"""Setup the sidebar for model and plugin selection."""
|
||||
# model_name = st.sidebar.selectbox('模型选择:', options=['internlm'])
|
||||
model_name = st.sidebar.text_input('模型名称:', value='internlm2-chat-7b')
|
||||
meta_prompt = st.sidebar.text_area('系统提示词', value=META_CN)
|
||||
da_prompt = st.sidebar.text_area('数据分析提示词', value=INTERPRETER_CN)
|
||||
plugin_prompt = st.sidebar.text_area('插件提示词', value=PLUGIN_CN)
|
||||
model_path = st.sidebar.text_input(
|
||||
'模型路径:', value='internlm/internlm2-chat-20b')
|
||||
if model_name != st.session_state['model_selected'] or st.session_state[
|
||||
'model_path'] != model_path:
|
||||
st.session_state['model_path'] = model_path
|
||||
model = self.init_model(model_name, model_path)
|
||||
self.session_state.clear_state()
|
||||
st.session_state['model_selected'] = model_name
|
||||
if 'chatbot' in st.session_state:
|
||||
del st.session_state['chatbot']
|
||||
else:
|
||||
model = st.session_state['model_map'][model_name]
|
||||
|
||||
plugin_name = st.sidebar.multiselect(
|
||||
'插件选择',
|
||||
options=list(st.session_state['plugin_map'].keys()),
|
||||
default=[],
|
||||
)
|
||||
da_flag = st.sidebar.checkbox(
|
||||
'数据分析',
|
||||
value=False,
|
||||
)
|
||||
plugin_action = [
|
||||
st.session_state['plugin_map'][name] for name in plugin_name
|
||||
]
|
||||
|
||||
if 'chatbot' in st.session_state:
|
||||
if len(plugin_action) > 0:
|
||||
st.session_state['chatbot']._action_executor = ActionExecutor(
|
||||
actions=plugin_action)
|
||||
else:
|
||||
st.session_state['chatbot']._action_executor = None
|
||||
if da_flag:
|
||||
st.session_state[
|
||||
'chatbot']._interpreter_executor = ActionExecutor(
|
||||
actions=[IPythonInterpreter()])
|
||||
else:
|
||||
st.session_state['chatbot']._interpreter_executor = None
|
||||
st.session_state['chatbot']._protocol._meta_template = meta_prompt
|
||||
st.session_state['chatbot']._protocol.plugin_prompt = plugin_prompt
|
||||
st.session_state[
|
||||
'chatbot']._protocol.interpreter_prompt = da_prompt
|
||||
if st.sidebar.button('清空对话', key='clear'):
|
||||
self.session_state.clear_state()
|
||||
uploaded_file = st.sidebar.file_uploader('上传文件')
|
||||
|
||||
return model_name, model, plugin_action, uploaded_file, model_path
|
||||
|
||||
def init_model(self, model_name, path):
|
||||
"""Initialize the model based on the input model name."""
|
||||
st.session_state['model_map'][model_name] = HFTransformer(
|
||||
path=path,
|
||||
meta_template=META,
|
||||
max_new_tokens=1024,
|
||||
top_p=0.8,
|
||||
top_k=None,
|
||||
temperature=0.1,
|
||||
repetition_penalty=1.0,
|
||||
stop_words=['<|im_end|>'])
|
||||
return st.session_state['model_map'][model_name]
|
||||
|
||||
def initialize_chatbot(self, model, plugin_action):
|
||||
"""Initialize the chatbot with the given model and plugin actions."""
|
||||
return Internlm2Agent(
|
||||
llm=model,
|
||||
protocol=Internlm2Protocol(
|
||||
tool=dict(
|
||||
begin='{start_token}{name}\n',
|
||||
start_token='<|action_start|>',
|
||||
name_map=dict(
|
||||
plugin='<|plugin|>', interpreter='<|interpreter|>'),
|
||||
belong='assistant',
|
||||
end='<|action_end|>\n',
|
||||
), ),
|
||||
max_turn=7)
|
||||
|
||||
def render_user(self, prompt: str):
|
||||
with st.chat_message('user'):
|
||||
st.markdown(prompt)
|
||||
|
||||
def render_assistant(self, agent_return):
|
||||
with st.chat_message('assistant'):
|
||||
for action in agent_return.actions:
|
||||
if (action) and (action.type != 'FinishAction'):
|
||||
self.render_action(action)
|
||||
st.markdown(agent_return.response)
|
||||
|
||||
def render_plugin_args(self, action):
|
||||
action_name = action.type
|
||||
args = action.args
|
||||
import json
|
||||
parameter_dict = dict(name=action_name, parameters=args)
|
||||
parameter_str = '```json\n' + json.dumps(
|
||||
parameter_dict, indent=4, ensure_ascii=False) + '\n```'
|
||||
st.markdown(parameter_str)
|
||||
|
||||
def render_interpreter_args(self, action):
|
||||
st.info(action.type)
|
||||
st.markdown(action.args['text'])
|
||||
|
||||
def render_action(self, action):
|
||||
st.markdown(action.thought)
|
||||
if action.type == 'IPythonInterpreter':
|
||||
self.render_interpreter_args(action)
|
||||
elif action.type == 'FinishAction':
|
||||
pass
|
||||
else:
|
||||
self.render_plugin_args(action)
|
||||
self.render_action_results(action)
|
||||
|
||||
def render_action_results(self, action):
|
||||
"""Render the results of action, including text, images, videos, and
|
||||
audios."""
|
||||
if (isinstance(action.result, dict)):
|
||||
if 'text' in action.result:
|
||||
st.markdown('```\n' + action.result['text'] + '\n```')
|
||||
if 'image' in action.result:
|
||||
# image_path = action.result['image']
|
||||
for image_path in action.result['image']:
|
||||
image_data = open(image_path, 'rb').read()
|
||||
st.image(image_data, caption='Generated Image')
|
||||
if 'video' in action.result:
|
||||
video_data = action.result['video']
|
||||
video_data = open(video_data, 'rb').read()
|
||||
st.video(video_data)
|
||||
if 'audio' in action.result:
|
||||
audio_data = action.result['audio']
|
||||
audio_data = open(audio_data, 'rb').read()
|
||||
st.audio(audio_data)
|
||||
elif isinstance(action.result, list):
|
||||
for item in action.result:
|
||||
if item['type'] == 'text':
|
||||
st.markdown('```\n' + item['content'] + '\n```')
|
||||
elif item['type'] == 'image':
|
||||
image_data = open(item['content'], 'rb').read()
|
||||
st.image(image_data, caption='Generated Image')
|
||||
elif item['type'] == 'video':
|
||||
video_data = open(item['content'], 'rb').read()
|
||||
st.video(video_data)
|
||||
elif item['type'] == 'audio':
|
||||
audio_data = open(item['content'], 'rb').read()
|
||||
st.audio(audio_data)
|
||||
if action.errmsg:
|
||||
st.error(action.errmsg)
|
||||
|
||||
|
||||
def main():
|
||||
# logger = get_logger(__name__)
|
||||
# Initialize Streamlit UI and setup sidebar
|
||||
if 'ui' not in st.session_state:
|
||||
session_state = SessionState()
|
||||
session_state.init_state()
|
||||
st.session_state['ui'] = StreamlitUI(session_state)
|
||||
|
||||
else:
|
||||
st.set_page_config(
|
||||
layout='wide',
|
||||
page_title='lagent-web',
|
||||
page_icon='./docs/imgs/lagent_icon.png')
|
||||
st.header(':robot_face: :blue[Lagent] Web Demo ', divider='rainbow')
|
||||
_, model, plugin_action, uploaded_file, _ = st.session_state[
|
||||
'ui'].setup_sidebar()
|
||||
|
||||
# Initialize chatbot if it is not already initialized
|
||||
# or if the model has changed
|
||||
if 'chatbot' not in st.session_state or model != st.session_state[
|
||||
'chatbot']._llm:
|
||||
st.session_state['chatbot'] = st.session_state[
|
||||
'ui'].initialize_chatbot(model, plugin_action)
|
||||
st.session_state['session_history'] = []
|
||||
|
||||
for prompt, agent_return in zip(st.session_state['user'],
|
||||
st.session_state['assistant']):
|
||||
st.session_state['ui'].render_user(prompt)
|
||||
st.session_state['ui'].render_assistant(agent_return)
|
||||
|
||||
if user_input := st.chat_input(''):
|
||||
with st.container():
|
||||
st.session_state['ui'].render_user(user_input)
|
||||
st.session_state['user'].append(user_input)
|
||||
# Add file uploader to sidebar
|
||||
if (uploaded_file
|
||||
and uploaded_file.name not in st.session_state['file']):
|
||||
|
||||
st.session_state['file'].add(uploaded_file.name)
|
||||
file_bytes = uploaded_file.read()
|
||||
file_type = uploaded_file.type
|
||||
if 'image' in file_type:
|
||||
st.image(file_bytes, caption='Uploaded Image')
|
||||
elif 'video' in file_type:
|
||||
st.video(file_bytes, caption='Uploaded Video')
|
||||
elif 'audio' in file_type:
|
||||
st.audio(file_bytes, caption='Uploaded Audio')
|
||||
# Save the file to a temporary location and get the path
|
||||
|
||||
postfix = uploaded_file.name.split('.')[-1]
|
||||
# prefix = str(uuid.uuid4())
|
||||
prefix = hashlib.md5(file_bytes).hexdigest()
|
||||
filename = f'{prefix}.{postfix}'
|
||||
file_path = os.path.join(root_dir, filename)
|
||||
with open(file_path, 'wb') as tmpfile:
|
||||
tmpfile.write(file_bytes)
|
||||
file_size = os.stat(file_path).st_size / 1024 / 1024
|
||||
file_size = f'{round(file_size, 2)} MB'
|
||||
# st.write(f'File saved at: {file_path}')
|
||||
user_input = [
|
||||
dict(role='user', content=user_input),
|
||||
dict(
|
||||
role='user',
|
||||
content=json.dumps(dict(path=file_path, size=file_size)),
|
||||
name='file')
|
||||
]
|
||||
if isinstance(user_input, str):
|
||||
user_input = [dict(role='user', content=user_input)]
|
||||
st.session_state['last_status'] = AgentStatusCode.SESSION_READY
|
||||
for agent_return in st.session_state['chatbot'].stream_chat(
|
||||
st.session_state['session_history'] + user_input):
|
||||
if agent_return.state == AgentStatusCode.PLUGIN_RETURN:
|
||||
with st.container():
|
||||
st.session_state['ui'].render_plugin_args(
|
||||
agent_return.actions[-1])
|
||||
st.session_state['ui'].render_action_results(
|
||||
agent_return.actions[-1])
|
||||
elif agent_return.state == AgentStatusCode.CODE_RETURN:
|
||||
with st.container():
|
||||
st.session_state['ui'].render_action_results(
|
||||
agent_return.actions[-1])
|
||||
elif (agent_return.state == AgentStatusCode.STREAM_ING
|
||||
or agent_return.state == AgentStatusCode.CODING):
|
||||
# st.markdown(agent_return.response)
|
||||
# 清除占位符的当前内容,并显示新内容
|
||||
with st.container():
|
||||
if agent_return.state != st.session_state['last_status']:
|
||||
st.session_state['temp'] = ''
|
||||
placeholder = st.empty()
|
||||
st.session_state['placeholder'] = placeholder
|
||||
if isinstance(agent_return.response, dict):
|
||||
action = f"\n\n {agent_return.response['name']}: \n\n"
|
||||
action_input = agent_return.response['parameters']
|
||||
if agent_return.response[
|
||||
'name'] == 'IPythonInterpreter':
|
||||
action_input = action_input['command']
|
||||
response = action + action_input
|
||||
else:
|
||||
response = agent_return.response
|
||||
st.session_state['temp'] = response
|
||||
st.session_state['placeholder'].markdown(
|
||||
st.session_state['temp'])
|
||||
elif agent_return.state == AgentStatusCode.END:
|
||||
st.session_state['session_history'] += (
|
||||
user_input + agent_return.inner_steps)
|
||||
agent_return = copy.deepcopy(agent_return)
|
||||
agent_return.response = st.session_state['temp']
|
||||
st.session_state['assistant'].append(
|
||||
copy.deepcopy(agent_return))
|
||||
st.session_state['last_status'] = agent_return.state
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
root_dir = os.path.join(root_dir, 'tmp_dir')
|
||||
os.makedirs(root_dir, exist_ok=True)
|
||||
main()
|
||||
@@ -26,6 +26,7 @@ def main():
model = HFTransformer(
path=args.path,
meta_template=META,
max_new_tokens=1024,
top_p=0.8,
top_k=None,
temperature=0.1,
@@ -51,8 +52,7 @@ def main():
history = [dict(role='user', content=prompt)]
print('\nInternLm2:', end='')
current_length = 0
for status, response, _ in model.stream_chat(
history, max_new_tokens=512):
for status, response, _ in model.stream_chat(history):
print(response[current_length:], end='', flush=True)
current_length = len(response)
history.append(dict(role='assistant', content=response))
@@ -1,7 +1,5 @@
from typing import Optional, Type

import arxiv

from lagent.actions.base_action import BaseAction, tool_api
from lagent.actions.parser import BaseParser, JsonParser
from lagent.schema import ActionReturn, ActionStatusCode
@@ -37,6 +35,8 @@ Electrical Engineering, and Economics from scientific articles on arxiv.org.
:class:`dict`: article information
* content (str): a list of 3 arxiv search papers
"""
import arxiv

try:
results = arxiv.Search(  # type: ignore
query[:self.max_query_len],
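This hunk (and the GoogleScholar one below) applies the same deferred-import pattern: the optional dependency is imported inside the tool method rather than at module level, so importing `lagent.actions` no longer requires `arxiv` or `serpapi` to be installed. A generic sketch of the pattern (illustrative, not the exact lagent code):

```python
class LazyDependencyTool:
    """Illustration only: the heavy dependency is imported on first use."""

    def run(self, query: str):
        import arxiv  # deferred import: fails only if the tool is actually invoked
        return arxiv.Search(query, max_results=3)
```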
@@ -2,8 +2,6 @@
import os
from typing import Optional, Type

from serpapi import GoogleSearch

from lagent.actions.base_action import BaseAction, tool_api
from lagent.schema import ActionReturn, ActionStatusCode
from .parser import BaseParser, JsonParser
@@ -78,6 +76,7 @@ class GoogleScholar(BaseAction):
- organic_id: a list of the organic results' ids of the three selected papers
- pub_info: publication information of selected papers
"""
from serpapi import GoogleSearch
params = {
'q': query,
'engine': 'google_scholar',
@@ -154,6 +153,7 @@ class GoogleScholar(BaseAction):
* articles: at most 3 articles by the author
* website: the author's homepage url
"""
from serpapi import GoogleSearch
params = {
'engine': 'google_scholar_author',
'author_id': author_id,
@@ -204,6 +204,7 @@ class GoogleScholar(BaseAction):
* authors: the authors of the article
* citation: the citation format of the article
"""
from serpapi import GoogleSearch
params = {
'q': q,
'engine': 'google_scholar_cite',
@@ -246,6 +247,7 @@ class GoogleScholar(BaseAction):
:class:`dict`: author id
* author_id: the author_id of the author
"""
from serpapi import GoogleSearch
params = {
'mauthors': mauthors,
'engine': 'google_scholar_profiles',
@@ -3,11 +3,7 @@ from contextlib import redirect_stdout
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from io import StringIO
|
||||
from typing import Optional
|
||||
|
||||
import json5
|
||||
from IPython import InteractiveShell
|
||||
from timeout_decorator import timeout as timer
|
||||
from typing import Optional, Type
|
||||
|
||||
from ..schema import ActionReturn, ActionStatusCode
|
||||
from .base_action import BaseAction, tool_api
|
||||
@@ -51,10 +47,11 @@ class IPythonInteractive(BaseAction):
|
||||
max_out_len: int = 2048,
|
||||
use_signals: bool = True,
|
||||
description: Optional[dict] = None,
|
||||
parser: type[BaseParser] = JsonParser,
|
||||
parser: Type[BaseParser] = JsonParser,
|
||||
enable: bool = True,
|
||||
):
|
||||
super().__init__(description, parser, enable)
|
||||
from IPython import InteractiveShell
|
||||
self.timeout = timeout
|
||||
self._executor = InteractiveShell()
|
||||
self._highlighting = re.compile(r'\x1b\[\d{,3}(;\d{,3}){,3}m')
|
||||
@@ -74,6 +71,7 @@ class IPythonInteractive(BaseAction):
|
||||
timeout (:class:`Optional[int]`): timeout for execution.
|
||||
This argument only works in the main thread. Defaults to ``None``.
|
||||
"""
|
||||
from timeout_decorator import timeout as timer
|
||||
tool_return = ActionReturn(args={'text': command}, type=self.name)
|
||||
ret = (
|
||||
timer(timeout or self.timeout)(self.exec)(command)
|
||||
@@ -171,6 +169,8 @@ class IPythonInteractive(BaseAction):
|
||||
Returns:
|
||||
:class:`str`: Python code
|
||||
"""
|
||||
import json5
|
||||
|
||||
# Match triple backtick blocks first
|
||||
triple_match = re.search(r'```[^\n]*\n(.+?)```', text, re.DOTALL)
|
||||
# Match single backtick blocks second
|
||||
|
||||
@@ -11,10 +11,6 @@ import traceback
|
||||
import uuid
|
||||
from typing import Optional, Tuple, Type
|
||||
|
||||
import json5
|
||||
import PIL.Image
|
||||
from jupyter_client import KernelManager
|
||||
|
||||
from lagent.actions.base_action import BaseAction, tool_api
|
||||
from lagent.actions.parser import BaseParser, JsonParser
|
||||
from lagent.schema import ActionReturn, ActionStatusCode
|
||||
@@ -75,6 +71,8 @@ class IPythonInterpreter(BaseAction):
|
||||
|
||||
@staticmethod
|
||||
def start_kernel():
|
||||
from jupyter_client import KernelManager
|
||||
|
||||
# start the kernel and manager
|
||||
km = KernelManager()
|
||||
km.start_kernel()
|
||||
@@ -235,6 +233,8 @@ class IPythonInterpreter(BaseAction):
|
||||
|
||||
|
||||
def extract_code(text):
|
||||
import json5
|
||||
|
||||
# Match triple backtick blocks first
|
||||
triple_match = re.search(r'```[^\n]*\n(.+?)```', text, re.DOTALL)
|
||||
# Match single backtick blocks second
|
||||
@@ -258,6 +258,7 @@ def escape_ansi(line):
|
||||
|
||||
|
||||
def publish_image_to_local(image_base64: str, work_dir='./work_dir/tmp_dir'):
|
||||
import PIL.Image
|
||||
image_file = str(uuid.uuid4()) + '.png'
|
||||
local_image_file = os.path.join(work_dir, image_file)
|
||||
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from typing import Dict, Optional, Type
|
||||
|
||||
from pptx import Presentation
|
||||
|
||||
from lagent.actions.base_action import BaseAction, tool_api
|
||||
from lagent.actions.parser import BaseParser, JsonParser
|
||||
|
||||
@@ -10,7 +8,7 @@ THEME_MAPPING = {
|
||||
'template': None,
|
||||
'title': 'Title Slide',
|
||||
'single': 'Title and Content',
|
||||
'two': 'Tow content',
|
||||
'two': 'Two Content',
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,13 +31,14 @@ class PPT(BaseAction):
|
||||
"""Create a pptx file with specific themes.
|
||||
|
||||
Args:
|
||||
theme (:class:`str`): the theme used
|
||||
theme (:class:`str`): the theme used. The value should be one of ['Default'].
|
||||
abs_location (:class:`str`): the ppt file's absolute location
|
||||
|
||||
Returns:
|
||||
:class:`dict`: operation status
|
||||
* status: the result of the execution
|
||||
"""
|
||||
from pptx import Presentation
|
||||
self.location = abs_location
|
||||
try:
|
||||
self.pointer = Presentation(self.theme_mapping[theme]['template'])
|
||||
@@ -116,6 +115,7 @@ class PPT(BaseAction):
|
||||
:class:`dict`: operation status
|
||||
* status: the result of the execution
|
||||
"""
|
||||
from PIL import Image
|
||||
layout_name = self.theme_mapping[self.pointer.slide_master.name]['two']
|
||||
layout = next(i for i in self.pointer.slide_master.slide_layouts
|
||||
if i.name == layout_name)
|
||||
@@ -123,6 +123,7 @@ class PPT(BaseAction):
|
||||
ph_title, ph_body1, ph_body2 = slide.placeholders
|
||||
ph_title.text = title
|
||||
ph = ph_body2
|
||||
image = Image.open(image)
|
||||
image_pil = image.to_pil()
|
||||
left = ph.left
|
||||
width = ph.width
|
||||
|
||||
@@ -4,8 +4,6 @@ import io
from contextlib import redirect_stdout
from typing import Any, Optional, Type

from func_timeout import FunctionTimedOut, func_set_timeout

from lagent.actions.base_action import BaseAction, tool_api
from lagent.actions.parser import BaseParser, JsonParser
from lagent.schema import ActionReturn, ActionStatusCode
@@ -85,6 +83,7 @@ class PythonInterpreter(BaseAction):
Args:
command (:class:`str`): Python code snippet
"""
from func_timeout import FunctionTimedOut, func_set_timeout
self.runtime = GenericRuntime()
try:
tool_return = func_set_timeout(self.timeout)(self._call)(command)
@@ -141,6 +141,12 @@ class Internlm2Protocol:
tool_name = api_info['name'].split('.')[0]
plugin['description'] = API_PREFIX.format(
tool_name=tool_name, description=plugin['description'])
# only keep required parameters
required_parameters = [
param for param in plugin['parameters']
if param['name'] in plugin['required']
]
plugin['parameters'] = required_parameters
plugin_descriptions.append(plugin)
plugin_prompt = self.plugin_prompt.format(
prompt=json.dumps(
@@ -170,13 +176,13 @@ class Internlm2Protocol:
return 'interpreter', message, dict(
name=interpreter_executor.action_names()[0],
parameters=dict(command=code))
return None, message, None
return None, message.split(self.tool['start_token'])[0], None

def format_response(self, action_return, name) -> dict:
if action_return.state == ActionStatusCode.SUCCESS:
response = action_return.format_result()
else:
response = action_return.errmsg
response = str(action_return.errmsg)
content = self.execute['begin'] + response + self.execute['end']
if self.execute.get('fallback_role'):
return dict(
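The first hunk trims each plugin's schema to its required parameters before it is serialized into the plugin prompt. A distilled sketch of that filtering, assuming the same `parameters`/`required` layout used by the tool descriptions:

```python
def keep_required_parameters(plugin: dict) -> dict:
    # Drop optional parameters so the prompt only advertises the required ones.
    plugin['parameters'] = [
        param for param in plugin['parameters']
        if param['name'] in plugin['required']
    ]
    return plugin
```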
@@ -1,12 +1,13 @@
from .base_api import BaseAPIModel
from .base_llm import BaseModel
from .huggingface import HFTransformer, HFTransformerCasualLM
from .huggingface import HFTransformer, HFTransformerCasualLM, HFTransformerChat
from .lmdepoly_wrapper import LMDeployClient, LMDeployPipeline, LMDeployServer
from .meta_template import INTERNLM2_META
from .openai import GPTAPI
from .vllm_wrapper import VllmModel

__all__ = [
'BaseModel', 'BaseAPIModel', 'GPTAPI', 'LMDeployClient',
'LMDeployPipeline', 'LMDeployServer', 'HFTransformer',
'HFTransformerCasualLM', 'INTERNLM2_META'
'HFTransformerCasualLM', 'INTERNLM2_META', 'HFTransformerChat', 'VllmModel'
]
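With the updated exports, the new wrappers are importable directly from the package (their backends must be installed separately):

```python
from lagent.llms import HFTransformerChat, VllmModel  # new exports in this release
```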
@@ -118,7 +118,7 @@ class APITemplateParser:
return res

def _role2api_role(self, role_prompt: Dict) -> Tuple[str, bool]:
merged_prompt = self.roles[self.roles[role_prompt['role']]]
merged_prompt = self.roles[role_prompt['role']]
if merged_prompt.get('fallback_role'):
merged_prompt = self.roles[self.roles[
merged_prompt['fallback_role']]]
@@ -152,7 +152,7 @@ class BaseAPIModel(BaseModel):
template_parser: 'APITemplateParser' = APITemplateParser,
meta_template: Optional[Dict] = None,
*,
max_tokens: int = 512,
max_new_tokens: int = 512,
top_p: float = 0.8,
top_k: float = None,
temperature: float = 0.8,
@@ -169,7 +169,7 @@ class BaseAPIModel(BaseModel):
if isinstance(stop_words, str):
stop_words = [stop_words]
self.gen_params = dict(
max_tokens=max_tokens,
max_new_tokens=max_new_tokens,
top_p=top_p,
top_k=top_k,
temperature=temperature,
@@ -99,8 +99,8 @@ class BaseModel:

Args:
path (str): The path to the model.
max_seq_len (int): The maximum sequence length of the model. Defaults
to 2048.
max_new_tokens (int): Maximum length of output expected to be generated by the model. Defaults
to 512.
tokenizer_only (bool): If True, only the tokenizer will be initialized.
Defaults to False.
meta_template (list of dict, optional): The model's meta prompt
@@ -116,7 +116,7 @@ class BaseModel:
template_parser: 'LMTemplateParser' = LMTemplateParser,
meta_template: Optional[List[Dict]] = None,
*,
max_tokens: int = 512,
max_new_tokens: int = 512,
top_p: float = 0.8,
top_k: float = None,
temperature: float = 0.8,
@@ -133,7 +133,7 @@ class BaseModel:
if isinstance(stop_words, str):
stop_words = [stop_words]
self.gen_params = dict(
max_tokens=max_tokens,
max_new_tokens=max_new_tokens,
top_p=top_p,
top_k=top_k,
temperature=temperature,
@@ -183,12 +183,12 @@ class BaseModel:
Returns:
"""
if isinstance(inputs[0], list):
inputs = list()
_inputs = list()
for msg in inputs:
inputs.append(self.template_parser(msg))
_inputs.append(self.template_parser(msg))
else:
inputs = self.template_parser(inputs)
return self.generate(inputs, **gen_params)
_inputs = self.template_parser(inputs)
return self.generate(_inputs, **gen_params)

def generate_from_template(self, inputs: Union[List[dict],
List[List[dict]]],
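The last hunk fixes a real bug: the old code rebound `inputs` to an empty list and then iterated over that same (now empty) list, so batched template parsing produced no prompts at all. A distilled before/after, with `parse` standing in for the template parser:

```python
def parse_batch_broken(inputs, parse):
    if isinstance(inputs[0], list):
        inputs = list()               # clobbers the incoming batch
        for msg in inputs:            # iterates over the empty list -> nothing parsed
            inputs.append(parse(msg))
    return inputs

def parse_batch_fixed(inputs, parse):
    if isinstance(inputs[0], list):
        _inputs = [parse(msg) for msg in inputs]  # original batch left intact
    else:
        _inputs = parse(inputs)
    return _inputs
```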
@@ -1,9 +1,9 @@
|
||||
import copy
|
||||
import logging
|
||||
import warnings
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from lagent.schema import ModelStatusCode
|
||||
from .base_api import APITemplateParser
|
||||
from .base_llm import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -19,8 +19,6 @@ class HFTransformer(BaseModel):
|
||||
|
||||
Args:
|
||||
path (str): The name or path to HuggingFace's model.
|
||||
max_seq_len (int): The maximum length of the input sequence. Defaults
|
||||
to 2048.
|
||||
tokenizer_path (str): The path to the tokenizer. Defaults to None.
|
||||
tokenizer_kwargs (dict): Keyword arguments for the tokenizer.
|
||||
Defaults to {}.
|
||||
@@ -40,12 +38,20 @@ class HFTransformer(BaseModel):
|
||||
tokenizer_only: bool = False,
|
||||
model_kwargs: dict = dict(device_map='auto'),
|
||||
meta_template: Optional[Dict] = None,
|
||||
stop_words_id: Union[List[int], int] = None,
|
||||
**kwargs):
|
||||
super().__init__(
|
||||
path=path,
|
||||
tokenizer_only=tokenizer_only,
|
||||
meta_template=meta_template,
|
||||
**kwargs)
|
||||
if isinstance(stop_words_id, int):
|
||||
stop_words_id = [stop_words_id]
|
||||
self.gen_params.update(stop_words_id=stop_words_id)
|
||||
if self.gen_params['stop_words'] is not None and \
|
||||
self.gen_params['stop_words_id'] is not None:
|
||||
logger.warning('Both stop_words and stop_words_id are specified,'
|
||||
'only stop_words_id will be used.')
|
||||
|
||||
self._load_tokenizer(
|
||||
path=path,
|
||||
@@ -60,7 +66,9 @@ class HFTransformer(BaseModel):
|
||||
self.prefix_allowed_tokens_fn = None
|
||||
|
||||
stop_words_id = []
|
||||
if self.gen_params.get('stop_words'):
|
||||
if self.gen_params.get('stop_words_id'):
|
||||
stop_words_id = self.gen_params.get('stop_words_id')
|
||||
elif self.gen_params.get('stop_words'):
|
||||
for sw in self.gen_params.get('stop_words'):
|
||||
stop_words_id.append(self.tokenizer(sw)['input_ids'][-1])
|
||||
self.additional_eos_token_id = stop_words_id
|
||||
@@ -72,8 +80,27 @@ class HFTransformer(BaseModel):
|
||||
tokenizer_path if tokenizer_path else path,
|
||||
trust_remote_code=True,
|
||||
**tokenizer_kwargs)
|
||||
|
||||
if self.tokenizer.pad_token_id is None:
|
||||
self.tokenizer.pad_token = self.tokenizer.eos_token
|
||||
if self.tokenizer.eos_token is not None:
|
||||
logger.warning(
|
||||
f'Using eos_token_id {self.tokenizer.eos_token} '
|
||||
'as pad_token_id.')
|
||||
self.tokenizer.pad_token = self.tokenizer.eos_token
|
||||
else:
|
||||
from transformers.generation import GenerationConfig
|
||||
self.gcfg = GenerationConfig.from_pretrained(path)
|
||||
|
||||
if self.gcfg.pad_token_id is not None:
|
||||
logger.warning(
|
||||
f'Using pad_token_id {self.gcfg.pad_token_id} '
|
||||
'as pad_token_id.')
|
||||
self.tokenizer.pad_token_id = self.gcfg.pad_token_id
|
||||
else:
|
||||
raise ValueError(
|
||||
'pad_token_id is not set for this tokenizer. Try to '
|
||||
'set pad_token_id via passing '
|
||||
'`pad_token_id={PAD_TOKEN_ID}` in model_cfg.')
|
||||
|
||||
def _load_model(self, path: str, model_kwargs: dict):
|
||||
import torch
|
||||
@@ -130,7 +157,6 @@ class HFTransformer(BaseModel):
|
||||
if isinstance(inputs, str):
|
||||
inputs = [inputs]
|
||||
batched = False
|
||||
# import pdb; pdb.set_trace()
|
||||
inputs = self.tokenizer(
|
||||
inputs, padding=True, return_tensors='pt', return_length=True)
|
||||
input_length = inputs['length']
|
||||
@@ -151,55 +177,20 @@ class HFTransformer(BaseModel):
|
||||
generation_config.bos_token_id,
|
||||
generation_config.eos_token_id,
|
||||
)
|
||||
if eos_token_id is None:
|
||||
if self.gcfg.eos_token_id is not None:
|
||||
eos_token_id = self.gcfg.eos_token_id
|
||||
else:
|
||||
eos_token_id = []
|
||||
if isinstance(eos_token_id, int):
|
||||
eos_token_id = [eos_token_id]
|
||||
if self.additional_eos_token_id is not None:
|
||||
eos_token_id.extend(self.additional_eos_token_id)
|
||||
eos_token_id_tensor = torch.tensor(eos_token_id).to(
|
||||
input_ids.device) if eos_token_id is not None else None
|
||||
has_default_max_length = (
|
||||
kwargs.get('max_length') is None
|
||||
and generation_config.max_length is not None)
|
||||
if (has_default_max_length
|
||||
and generation_config.max_new_tokens is None):
|
||||
warnings.warn(
|
||||
"Using `max_length`'s default"
|
||||
f'({generation_config.max_length})'
|
||||
'to control the generation length. '
|
||||
'This behaviour is deprecated and will be removed'
|
||||
' from the config in v5 of Transformers -- we'
|
||||
' recommend using `max_new_tokens` to control the'
|
||||
' maximum length of the generation.',
|
||||
UserWarning,
|
||||
)
|
||||
elif generation_config.max_new_tokens is not None:
|
||||
generation_config.max_length = (
|
||||
generation_config.max_new_tokens + input_ids_seq_length)
|
||||
if not has_default_max_length:
|
||||
logger.warn( # pylint: disable=W4902
|
||||
'Both `max_new_tokens`'
|
||||
f'(={generation_config.max_new_tokens})'
|
||||
'and `max_length`'
|
||||
f'(={generation_config.max_length})'
|
||||
' seem to have been set.`max_new_tokens`'
|
||||
' will take precedence. Please refer to'
|
||||
' the documentation for more information. '
|
||||
'(https://huggingface.co/docs/transformers/main/en'
|
||||
'/main_classes/text_generation)',
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
if input_ids_seq_length >= generation_config.max_length:
|
||||
input_ids_string = 'input_ids'
|
||||
logger.warning(
|
||||
f'Input length of {input_ids_string}'
|
||||
f' is {input_ids_seq_length},'
|
||||
' but `max_length` is set to'
|
||||
f' {generation_config.max_length}.'
|
||||
'This can lead to unexpected behavior.'
|
||||
' You should consider increasing `max_new_tokens`.')
|
||||
|
||||
# 2. Set generation parameters if not already defined
|
||||
generation_config.max_length = (
|
||||
generation_config.max_new_tokens + input_ids_seq_length)
|
||||
# Set generation parameters if not already defined
|
||||
logits_processor = self.logits_processor
|
||||
stopping_criteria = self.stopping_criteria
|
||||
|
||||
@@ -310,3 +301,39 @@ class HFTransformerCasualLM(HFTransformer):
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
path, trust_remote_code=True, **model_kwargs)
|
||||
self.model.eval()
|
||||
|
||||
|
||||
class HFTransformerChat(HFTransformerCasualLM):
|
||||
|
||||
def __init__(self, template_parser=APITemplateParser, **kwargs):
|
||||
super().__init__(template_parser=template_parser, **kwargs)
|
||||
|
||||
def chat(self,
|
||||
inputs: Union[List[dict], List[List[dict]]],
|
||||
do_sample: bool = True,
|
||||
**kwargs):
|
||||
"""Return the chat completions in stream mode.
|
||||
|
||||
Args:
|
||||
inputs (Union[List[dict], List[List[dict]]]): input messages to be completed.
|
||||
do_sample (bool): do sampling if enabled
|
||||
Returns:
|
||||
the text/chat completion
|
||||
"""
|
||||
# handle batch inference with vanilla for loop
|
||||
if isinstance(inputs[0], list):
|
||||
resps = []
|
||||
for input in inputs:
|
||||
resps.append(self.chat(input, do_sample, **kwargs))
|
||||
return resps
|
||||
prompt = self.template_parser(inputs)
|
||||
query = prompt[-1]['content']
|
||||
history = prompt[:-1]
|
||||
try:
|
||||
response, history = self.model.chat(
|
||||
self.tokenizer, query, history=history)
|
||||
except Exception as e:
|
||||
# handle over-length input error
|
||||
logger.warning(str(e))
|
||||
response = ''
|
||||
return response
|
||||
|
||||
@@ -46,9 +46,9 @@ class TritonClient(BaseModel):
|
||||
inputs: Union[str, List[str]],
|
||||
session_id: int = 2967,
|
||||
request_id: str = '',
|
||||
max_tokens: int = 512,
|
||||
sequence_start: bool = True,
|
||||
sequence_end: bool = True,
|
||||
skip_special_tokens: bool = False,
|
||||
**kwargs):
|
||||
"""Start a new round conversation of a session. Return the chat
|
||||
completions in non-stream mode.
|
||||
@@ -57,10 +57,10 @@ class TritonClient(BaseModel):
|
||||
inputs (str, List[str]): user's prompt(s) in this round
|
||||
session_id (int): the identical id of a session
|
||||
request_id (str): the identical id of this round conversation
|
||||
max_tokens (int): the expected generated token numbers
|
||||
sequence_start (bool): start flag of a session
|
||||
sequence_end (bool): end flag of a session
|
||||
|
||||
skip_special_tokens (bool): Whether or not to remove special tokens
|
||||
in the decoding. Default to be False.
|
||||
Returns:
|
||||
(a list of/batched) text/chat completion
|
||||
"""
|
||||
@@ -72,9 +72,12 @@ class TritonClient(BaseModel):
|
||||
assert isinstance(session_id, int), \
|
||||
f'INT session id is required, but got {type(session_id)}'
|
||||
|
||||
logger = get_logger(log_level=self.chatbot.log_level)
|
||||
self.chatbot.cfg = self._update_gen_params(**kwargs)
|
||||
max_new_tokens = self.chatbot.cfg.max_new_tokens
|
||||
|
||||
logger = get_logger('service.ft', log_level=self.chatbot.log_level)
|
||||
logger.info(f'session {session_id}, request_id {request_id}, '
|
||||
f'max_out_len {max_tokens}')
|
||||
f'max_out_len {max_new_tokens}')
|
||||
|
||||
if self.chatbot._session is None:
|
||||
sequence_start = True
|
||||
@@ -88,32 +91,33 @@ class TritonClient(BaseModel):
|
||||
self.chatbot._session.request_id = request_id
|
||||
self.chatbot._session.response = ''
|
||||
|
||||
self.chatbot.cfg = self._update_gen_params(
|
||||
max_tokens=max_tokens, **kwargs)
|
||||
|
||||
status, res, _ = None, '', 0
|
||||
for status, res, _ in self.chatbot._stream_infer(
|
||||
self.chatbot._session, prompt, max_tokens, sequence_start,
|
||||
sequence_end):
|
||||
if status.value < 0:
|
||||
break
|
||||
if status.value == 0:
|
||||
self.chatbot._session.histories = (
|
||||
self.chatbot._session.histories +
|
||||
self.chatbot._session.prompt + self.chatbot._session.response)
|
||||
# remove stop_words
|
||||
res = filter_suffix(res, self.gen_params.get('stop_words'))
|
||||
return res
|
||||
else:
|
||||
return ''
|
||||
self.chatbot._session,
|
||||
prompt,
|
||||
max_new_tokens,
|
||||
sequence_start,
|
||||
sequence_end,
|
||||
skip_special_tokens=skip_special_tokens):
|
||||
status = self.state_map.get(status)
|
||||
if status < ModelStatusCode.END:
|
||||
return ''
|
||||
elif status == ModelStatusCode.END:
|
||||
self.chatbot._session.histories = (
|
||||
self.chatbot._session.histories +
|
||||
self.chatbot._session.prompt +
|
||||
self.chatbot._session.response)
|
||||
# remove stop_words
|
||||
res = filter_suffix(res, self.gen_params.get('stop_words'))
|
||||
return res
|
||||
|
||||
def stream_chat(self,
|
||||
inputs: List[dict],
|
||||
session_id: int = 2967,
|
||||
request_id: str = '',
|
||||
max_tokens: int = 512,
|
||||
sequence_start: bool = True,
|
||||
sequence_end: bool = True,
|
||||
skip_special_tokens: bool = False,
|
||||
**kwargs):
|
||||
"""Start a new round conversation of a session. Return the chat
|
||||
completions in stream mode.
|
||||
@@ -122,21 +126,24 @@ class TritonClient(BaseModel):
|
||||
session_id (int): the identical id of a session
|
||||
inputs (List[dict]): user's inputs in this round conversation
|
||||
request_id (str): the identical id of this round conversation
|
||||
max_tokens (int): the expected generated token numbers
|
||||
sequence_start (bool): start flag of a session
|
||||
sequence_end (bool): end flag of a session
|
||||
|
||||
skip_special_tokens (bool): Whether or not to remove special tokens
|
||||
in the decoding. Default to be False.
|
||||
Returns:
|
||||
tuple(Status, str, int): status, text/chat completion,
|
||||
generated token number
|
||||
"""
|
||||
from lmdeploy.serve.turbomind.chatbot import Session, StatusCode, get_logger
|
||||
from lmdeploy.serve.turbomind.chatbot import Session, get_logger
|
||||
assert isinstance(session_id, int), \
|
||||
f'INT session id is required, but got {type(session_id)}'
|
||||
|
||||
logger = get_logger(log_level=self.chatbot.log_level)
|
||||
self.chatbot.cfg = self._update_gen_params(**kwargs)
|
||||
max_new_tokens = self.chatbot.cfg.max_new_tokens
|
||||
|
||||
logger = get_logger('service.ft', log_level=self.chatbot.log_level)
|
||||
logger.info(f'session {session_id}, request_id {request_id}, '
|
||||
f'max_out_len {max_tokens}')
|
||||
f'max_out_len {max_new_tokens}')
|
||||
|
||||
if self.chatbot._session is None:
|
||||
sequence_start = True
|
||||
@@ -150,27 +157,29 @@ class TritonClient(BaseModel):
|
||||
self.chatbot._session.request_id = request_id
|
||||
self.chatbot._session.response = ''
|
||||
|
||||
self.chatbot.cfg = self._update_gen_params(
|
||||
max_tokens=max_tokens, **kwargs)
|
||||
prompt = self.template_parser(inputs)
|
||||
|
||||
status, res, _ = None, '', 0
|
||||
for status, res, _ in self.chatbot._stream_infer(
|
||||
self.chatbot._session, prompt, max_tokens, sequence_start,
|
||||
sequence_end):
|
||||
if status == StatusCode.TRITON_STREAM_END: # remove stop_words
|
||||
res = filter_suffix(res, self.gen_params.get('stop_words'))
|
||||
if status.value < 0:
|
||||
self.chatbot._session,
|
||||
prompt,
|
||||
max_new_tokens,
|
||||
sequence_start,
|
||||
sequence_end,
|
||||
skip_special_tokens=skip_special_tokens):
|
||||
status = self.state_map.get(status)
|
||||
# The stop symbol also appears in the output of the last STREAM_ING state.
|
||||
res = filter_suffix(res, self.gen_params.get('stop_words'))
|
||||
if status < ModelStatusCode.END:
|
||||
return status, res, _
|
||||
elif status == ModelStatusCode.END: # remove stop_words
|
||||
self.chatbot._session.histories = (
|
||||
self.chatbot._session.histories +
|
||||
self.chatbot._session.prompt +
|
||||
self.chatbot._session.response)
|
||||
yield status, res, _
|
||||
break
|
||||
else:
|
||||
yield self.state_map.get(status), res, _
|
||||
if status.value == 0:
|
||||
self.chatbot._session.histories = (
|
||||
self.chatbot._session.histories +
|
||||
self.chatbot._session.prompt + self.chatbot._session.response)
|
||||
yield self.state_map.get(status), res, _
|
||||
else:
|
||||
return self.state_map.get(status), res, _
|
||||
yield status, res, _
|
||||
|
||||
def _update_gen_params(self, **kwargs):
|
||||
import mmengine
|
||||
@@ -226,6 +235,7 @@ class LMDeployPipeline(BaseModel):
|
||||
def generate(self,
|
||||
inputs: Union[str, List[str]],
|
||||
do_preprocess: bool = None,
|
||||
skip_special_tokens: bool = False,
|
||||
**kwargs):
|
||||
"""Return the chat completions in non-stream mode.
|
||||
|
||||
@@ -233,18 +243,23 @@ class LMDeployPipeline(BaseModel):
|
||||
inputs (Union[str, List[str]]): input texts to be completed.
|
||||
do_preprocess (bool): whether pre-process the messages. Default to
|
||||
True, which means chat_template will be applied.
|
||||
|
||||
skip_special_tokens (bool): Whether or not to remove special tokens
|
||||
in the decoding. Default to be False.
|
||||
Returns:
|
||||
(a list of/batched) text/chat completion
|
||||
"""
|
||||
from lmdeploy.messages import GenerationConfig
|
||||
|
||||
batched = True
|
||||
if isinstance(inputs, str):
|
||||
inputs = [inputs]
|
||||
batched = False
|
||||
prompt = inputs
|
||||
gen_params = self.update_gen_params(**kwargs)
|
||||
gen_config = GenerationConfig(
|
||||
skip_special_tokens=skip_special_tokens, **gen_params)
|
||||
response = self.model.batch_infer(
|
||||
prompt, do_preprocess=do_preprocess, **gen_params)
|
||||
prompt, gen_config=gen_config, do_preprocess=do_preprocess)
|
||||
response = [resp.text for resp in response]
|
||||
# remove stop_words
|
||||
response = filter_suffix(response, self.gen_params.get('stop_words'))
|
||||
@@ -308,6 +323,7 @@ class LMDeployServer(BaseModel):
|
||||
sequence_start: bool = True,
|
||||
sequence_end: bool = True,
|
||||
ignore_eos: bool = False,
|
||||
skip_special_tokens: Optional[bool] = False,
|
||||
timeout: int = 30,
|
||||
**kwargs) -> List[str]:
|
||||
"""Start a new round conversation of a session. Return the chat
|
||||
@@ -319,6 +335,8 @@ class LMDeployServer(BaseModel):
|
||||
sequence_start (bool): start flag of a session
|
||||
sequence_end (bool): end flag of a session
|
||||
ignore_eos (bool): indicator for ignoring eos
|
||||
skip_special_tokens (bool): Whether or not to remove special tokens
|
||||
in the decoding. Default to be False.
|
||||
timeout (int): max time to wait for response
|
||||
Returns:
|
||||
(a list of/batched) text/chat completion
|
||||
@@ -330,6 +348,8 @@ class LMDeployServer(BaseModel):
|
||||
batched = False
|
||||
|
||||
gen_params = self.update_gen_params(**kwargs)
|
||||
max_new_tokens = gen_params.pop('max_new_tokens')
|
||||
gen_params.update(max_tokens=max_new_tokens)
|
||||
|
||||
resp = [''] * len(inputs)
|
||||
for text in self.client.completions_v1(
|
||||
@@ -340,6 +360,7 @@ class LMDeployServer(BaseModel):
|
||||
sequence_end=sequence_end,
|
||||
stream=False,
|
||||
ignore_eos=ignore_eos,
|
||||
skip_special_tokens=skip_special_tokens,
|
||||
timeout=timeout,
|
||||
**gen_params):
|
||||
resp = [
|
||||
@@ -359,6 +380,7 @@ class LMDeployServer(BaseModel):
|
||||
sequence_end: bool = True,
|
||||
stream: bool = True,
|
||||
ignore_eos: bool = False,
|
||||
skip_special_tokens: Optional[bool] = False,
|
||||
timeout: int = 30,
|
||||
**kwargs):
|
||||
"""Start a new round conversation of a session. Return the chat
|
||||
@@ -371,12 +393,16 @@ class LMDeployServer(BaseModel):
|
||||
sequence_end (bool): end flag of a session
|
||||
stream (bool): return in a streaming format if enabled
|
||||
ignore_eos (bool): indicator for ignoring eos
|
||||
skip_special_tokens (bool): Whether or not to remove special tokens
|
||||
in the decoding. Default to be False.
|
||||
timeout (int): max time to wait for response
|
||||
Returns:
|
||||
tuple(Status, str, int): status, text/chat completion,
|
||||
generated token number
|
||||
"""
|
||||
gen_params = self.update_gen_params(**kwargs)
|
||||
max_new_tokens = gen_params.pop('max_new_tokens')
|
||||
gen_params.update(max_tokens=max_new_tokens)
|
||||
prompt = self.template_parser(inputs)
|
||||
|
||||
resp = ''
|
||||
@@ -390,6 +416,7 @@ class LMDeployServer(BaseModel):
|
||||
sequence_end=sequence_end,
|
||||
stream=stream,
|
||||
ignore_eos=ignore_eos,
|
||||
skip_special_tokens=skip_special_tokens,
|
||||
timeout=timeout,
|
||||
**gen_params):
|
||||
resp += text['choices'][0]['text']
|
||||
@@ -411,12 +438,15 @@ class LMDeployClient(LMDeployServer):
|
||||
"""
|
||||
|
||||
Args:
|
||||
path (str): The path to the model.
|
||||
url (str): communicating address 'http://<ip>:<port>' of
|
||||
api_server
|
||||
model_name (str): needed when model_path is a pytorch model on
|
||||
huggingface.co, such as "internlm-chat-7b",
|
||||
"Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on.
|
||||
"""
|
||||
|
||||
def __init__(self, path: str, url: str, **kwargs):
|
||||
BaseModel.__init__(self, path=path, **kwargs)
|
||||
def __init__(self, url: str, model_name: str, **kwargs):
|
||||
BaseModel.__init__(self, path=url, **kwargs)
|
||||
from lmdeploy.serve.openai.api_client import APIClient
|
||||
self.client = APIClient(url)
|
||||
self.model_name = model_name
|
||||
|
||||
@@ -1,7 +1,7 @@
import json
import os
import time
from concurrent.futures import ThreadPoolExecutor, wait
from concurrent.futures import ThreadPoolExecutor
from logging import getLogger
from threading import Lock
from typing import Dict, List, Optional, Union
@@ -106,15 +106,15 @@ class GPTAPI(BaseAPIModel):
Union[str, List[str]]: generated string(s)
"""
assert isinstance(inputs, list)
if isinstance(inputs[0], dict):
inputs = [inputs]
if 'max_tokens' in gen_params:
raise NotImplementedError('unsupported parameter: max_tokens')
gen_params = {**self.gen_params, **gen_params}
with ThreadPoolExecutor(max_workers=20) as executor:
tasks = [
executor.submit(self._chat, messages, **gen_params)
for messages in inputs
for messages in (
[inputs] if isinstance(inputs[0], dict) else inputs)
]
wait(tasks)
ret = [task.result() for task in tasks]
return ret[0] if isinstance(inputs[0], dict) else ret

@@ -133,7 +133,7 @@ class GPTAPI(BaseAPIModel):

# Hold out 100 tokens due to potential errors in tiktoken calculation
max_tokens = min(
gen_params.pop('max_tokens'),
gen_params.pop('max_new_tokens'),
self.context_window - len(self.tokenize(str(input))) - 100)
if max_tokens <= 0:
return ''
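The `chat` hunk replaces the up-front mutation of `inputs` with inline single-vs-batch dispatch and drops the explicit `wait`, leaving the executor's context manager to join the workers. A sketch of the pattern, with `call` standing in for `self._chat`:

```python
from concurrent.futures import ThreadPoolExecutor

def chat(inputs, call):
    with ThreadPoolExecutor(max_workers=20) as executor:
        tasks = [
            executor.submit(call, messages)
            for messages in ([inputs] if isinstance(inputs[0], dict) else inputs)
        ]
    # leaving the with-block waits for all workers, so results are ready here
    ret = [task.result() for task in tasks]
    return ret[0] if isinstance(inputs[0], dict) else ret
```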
lagent/llms/vllm_wrapper.py (new file, 71 lines)
@@ -0,0 +1,71 @@
from typing import List, Union

from lagent.llms.base_llm import BaseModel
from lagent.utils.util import filter_suffix


class VllmModel(BaseModel):
    """
    A wrapper of vLLM model.

    Args:
        path (str): The path to the model.
            It could be one of the following options:
                - i) A local directory path of a huggingface model.
                - ii) The model_id of a model hosted inside a model repo
                    on huggingface.co, such as "internlm/internlm-chat-7b",
                    "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
                    and so on.
        tp (int): tensor parallel
        vllm_cfg (dict): Other kwargs for vllm model initialization.
    """

    def __init__(self, path: str, tp: int = 1, vllm_cfg=dict(), **kwargs):

        super().__init__(path=path, **kwargs)
        from vllm import LLM
        self.model = LLM(
            model=self.path,
            trust_remote_code=True,
            tensor_parallel_size=tp,
            **vllm_cfg)

    def generate(self,
                 inputs: Union[str, List[str]],
                 do_preprocess: bool = None,
                 skip_special_tokens: bool = False,
                 **kwargs):
        """Return the chat completions in non-stream mode.

        Args:
            inputs (Union[str, List[str]]): input texts to be completed.
            do_preprocess (bool): whether pre-process the messages. Default to
                True, which means chat_template will be applied.
            skip_special_tokens (bool): Whether or not to remove special tokens
                in the decoding. Default to be False.
        Returns:
            (a list of/batched) text/chat completion
        """
        from vllm import SamplingParams

        batched = True
        if isinstance(inputs, str):
            inputs = [inputs]
            batched = False
        prompt = inputs
        gen_params = self.update_gen_params(**kwargs)
        max_new_tokens = gen_params.pop('max_new_tokens')
        stop_words = gen_params.pop('stop_words')

        sampling_config = SamplingParams(
            skip_special_tokens=skip_special_tokens,
            max_tokens=max_new_tokens,
            stop=stop_words,
            **gen_params)
        response = self.model.generate(prompt, sampling_params=sampling_config)
        response = [resp.outputs[0].text for resp in response]
        # remove stop_words
        response = filter_suffix(response, self.gen_params.get('stop_words'))
        if batched:
            return response
        return response[0]
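A hypothetical usage sketch for the new wrapper (model id, sampling values, and prompt are placeholders; `vllm` must be installed):

```python
from lagent.llms import VllmModel

llm = VllmModel(
    path='internlm/internlm2-chat-7b',  # placeholder HF model id
    tp=1,
    top_k=50,                 # illustrative sampling settings
    max_new_tokens=256,
    stop_words=['<|im_end|>'])
print(llm.generate('Summarize what an LLM agent is in one sentence.'))
```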
@@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
__version__ = '0.2.1'
__version__ = '0.2.2'


def parse_version_info(version_str: str, length: int = 4) -> tuple:
@@ -1,4 +1,8 @@
lmdeploy>=0.2.2
streamlit
google-search-results
lmdeploy>=0.2.3
pillow
python-pptx
timeout_decorator
torch
transformers>=4.34
vllm>=0.3.3
@@ -1,16 +1,13 @@
arxiv
distro
func_timeout
google-search-results
griffe
json5
jsonschema
jupyter
jupyter_client
phx-class-registry
pillow
python-pptx
requests
streamlit
tiktoken
timeout_decorator
typing-extensions