support demo with hf (#179)
This commit is contained in:
332
examples/internlm2_agent_web_demo_hf.py
Normal file
332
examples/internlm2_agent_web_demo_hf.py
Normal file
@@ -0,0 +1,332 @@
|
||||
import copy
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
|
||||
import streamlit as st
|
||||
|
||||
from lagent.actions import ActionExecutor, ArxivSearch, IPythonInterpreter
|
||||
from lagent.agents.internlm2_agent import INTERPRETER_CN, META_CN, PLUGIN_CN, Internlm2Agent, Internlm2Protocol
|
||||
from lagent.llms import HFTransformer
|
||||
from lagent.llms.meta_template import INTERNLM2_META as META
|
||||
from lagent.schema import AgentStatusCode
|
||||
|
||||
# from streamlit.logger import get_logger
|
||||
|
||||
|
||||
class SessionState:
|
||||
|
||||
def init_state(self):
|
||||
"""Initialize session state variables."""
|
||||
st.session_state['assistant'] = []
|
||||
st.session_state['user'] = []
|
||||
|
||||
action_list = [
|
||||
ArxivSearch(),
|
||||
]
|
||||
st.session_state['plugin_map'] = {
|
||||
action.name: action
|
||||
for action in action_list
|
||||
}
|
||||
st.session_state['model_map'] = {}
|
||||
st.session_state['model_selected'] = None
|
||||
st.session_state['plugin_actions'] = set()
|
||||
st.session_state['history'] = []
|
||||
|
||||
def clear_state(self):
|
||||
"""Clear the existing session state."""
|
||||
st.session_state['assistant'] = []
|
||||
st.session_state['user'] = []
|
||||
st.session_state['model_selected'] = None
|
||||
st.session_state['file'] = set()
|
||||
if 'chatbot' in st.session_state:
|
||||
st.session_state['chatbot']._session_history = []
|
||||
|
||||
|
||||
class StreamlitUI:
|
||||
|
||||
def __init__(self, session_state: SessionState):
|
||||
self.init_streamlit()
|
||||
self.session_state = session_state
|
||||
|
||||
def init_streamlit(self):
|
||||
"""Initialize Streamlit's UI settings."""
|
||||
st.set_page_config(
|
||||
layout='wide',
|
||||
page_title='lagent-web',
|
||||
page_icon='./docs/imgs/lagent_icon.png')
|
||||
st.header(':robot_face: :blue[Lagent] Web Demo ', divider='rainbow')
|
||||
st.sidebar.title('模型控制')
|
||||
st.session_state['file'] = set()
|
||||
st.session_state['model_path'] = None
|
||||
|
||||
def setup_sidebar(self):
|
||||
"""Setup the sidebar for model and plugin selection."""
|
||||
# model_name = st.sidebar.selectbox('模型选择:', options=['internlm'])
|
||||
model_name = st.sidebar.text_input('模型名称:', value='internlm2-chat-7b')
|
||||
meta_prompt = st.sidebar.text_area('系统提示词', value=META_CN)
|
||||
da_prompt = st.sidebar.text_area('数据分析提示词', value=INTERPRETER_CN)
|
||||
plugin_prompt = st.sidebar.text_area('插件提示词', value=PLUGIN_CN)
|
||||
model_path = st.sidebar.text_input(
|
||||
'模型路径:', value='internlm/internlm2-chat-20b')
|
||||
if model_name != st.session_state['model_selected'] or st.session_state[
|
||||
'model_path'] != model_path:
|
||||
st.session_state['model_path'] = model_path
|
||||
model = self.init_model(model_name, model_path)
|
||||
self.session_state.clear_state()
|
||||
st.session_state['model_selected'] = model_name
|
||||
if 'chatbot' in st.session_state:
|
||||
del st.session_state['chatbot']
|
||||
else:
|
||||
model = st.session_state['model_map'][model_name]
|
||||
|
||||
plugin_name = st.sidebar.multiselect(
|
||||
'插件选择',
|
||||
options=list(st.session_state['plugin_map'].keys()),
|
||||
default=[],
|
||||
)
|
||||
da_flag = st.sidebar.checkbox(
|
||||
'数据分析',
|
||||
value=False,
|
||||
)
|
||||
plugin_action = [
|
||||
st.session_state['plugin_map'][name] for name in plugin_name
|
||||
]
|
||||
|
||||
if 'chatbot' in st.session_state:
|
||||
if len(plugin_action) > 0:
|
||||
st.session_state['chatbot']._action_executor = ActionExecutor(
|
||||
actions=plugin_action)
|
||||
else:
|
||||
st.session_state['chatbot']._action_executor = None
|
||||
if da_flag:
|
||||
st.session_state[
|
||||
'chatbot']._interpreter_executor = ActionExecutor(
|
||||
actions=[IPythonInterpreter()])
|
||||
else:
|
||||
st.session_state['chatbot']._interpreter_executor = None
|
||||
st.session_state['chatbot']._protocol._meta_template = meta_prompt
|
||||
st.session_state['chatbot']._protocol.plugin_prompt = plugin_prompt
|
||||
st.session_state[
|
||||
'chatbot']._protocol.interpreter_prompt = da_prompt
|
||||
if st.sidebar.button('清空对话', key='clear'):
|
||||
self.session_state.clear_state()
|
||||
uploaded_file = st.sidebar.file_uploader('上传文件')
|
||||
|
||||
return model_name, model, plugin_action, uploaded_file, model_path
|
||||
|
||||
def init_model(self, model_name, path):
|
||||
"""Initialize the model based on the input model name."""
|
||||
st.session_state['model_map'][model_name] = HFTransformer(
|
||||
path=path,
|
||||
meta_template=META,
|
||||
max_new_tokens=1024,
|
||||
top_p=0.8,
|
||||
top_k=None,
|
||||
temperature=0.1,
|
||||
repetition_penalty=1.0,
|
||||
stop_words=['<|im_end|>'])
|
||||
return st.session_state['model_map'][model_name]
|
||||
|
||||
def initialize_chatbot(self, model, plugin_action):
|
||||
"""Initialize the chatbot with the given model and plugin actions."""
|
||||
return Internlm2Agent(
|
||||
llm=model,
|
||||
protocol=Internlm2Protocol(
|
||||
tool=dict(
|
||||
begin='{start_token}{name}\n',
|
||||
start_token='<|action_start|>',
|
||||
name_map=dict(
|
||||
plugin='<|plugin|>', interpreter='<|interpreter|>'),
|
||||
belong='assistant',
|
||||
end='<|action_end|>\n',
|
||||
), ),
|
||||
max_turn=7)
|
||||
|
||||
def render_user(self, prompt: str):
|
||||
with st.chat_message('user'):
|
||||
st.markdown(prompt)
|
||||
|
||||
def render_assistant(self, agent_return):
|
||||
with st.chat_message('assistant'):
|
||||
for action in agent_return.actions:
|
||||
if (action) and (action.type != 'FinishAction'):
|
||||
self.render_action(action)
|
||||
st.markdown(agent_return.response)
|
||||
|
||||
def render_plugin_args(self, action):
|
||||
action_name = action.type
|
||||
args = action.args
|
||||
import json
|
||||
parameter_dict = dict(name=action_name, parameters=args)
|
||||
parameter_str = '```json\n' + json.dumps(
|
||||
parameter_dict, indent=4, ensure_ascii=False) + '\n```'
|
||||
st.markdown(parameter_str)
|
||||
|
||||
def render_interpreter_args(self, action):
|
||||
st.info(action.type)
|
||||
st.markdown(action.args['text'])
|
||||
|
||||
def render_action(self, action):
|
||||
st.markdown(action.thought)
|
||||
if action.type == 'IPythonInterpreter':
|
||||
self.render_interpreter_args(action)
|
||||
elif action.type == 'FinishAction':
|
||||
pass
|
||||
else:
|
||||
self.render_plugin_args(action)
|
||||
self.render_action_results(action)
|
||||
|
||||
def render_action_results(self, action):
|
||||
"""Render the results of action, including text, images, videos, and
|
||||
audios."""
|
||||
if (isinstance(action.result, dict)):
|
||||
if 'text' in action.result:
|
||||
st.markdown('```\n' + action.result['text'] + '\n```')
|
||||
if 'image' in action.result:
|
||||
# image_path = action.result['image']
|
||||
for image_path in action.result['image']:
|
||||
image_data = open(image_path, 'rb').read()
|
||||
st.image(image_data, caption='Generated Image')
|
||||
if 'video' in action.result:
|
||||
video_data = action.result['video']
|
||||
video_data = open(video_data, 'rb').read()
|
||||
st.video(video_data)
|
||||
if 'audio' in action.result:
|
||||
audio_data = action.result['audio']
|
||||
audio_data = open(audio_data, 'rb').read()
|
||||
st.audio(audio_data)
|
||||
elif isinstance(action.result, list):
|
||||
for item in action.result:
|
||||
if item['type'] == 'text':
|
||||
st.markdown('```\n' + item['content'] + '\n```')
|
||||
elif item['type'] == 'image':
|
||||
image_data = open(item['content'], 'rb').read()
|
||||
st.image(image_data, caption='Generated Image')
|
||||
elif item['type'] == 'video':
|
||||
video_data = open(item['content'], 'rb').read()
|
||||
st.video(video_data)
|
||||
elif item['type'] == 'audio':
|
||||
audio_data = open(item['content'], 'rb').read()
|
||||
st.audio(audio_data)
|
||||
if action.errmsg:
|
||||
st.error(action.errmsg)
|
||||
|
||||
|
||||
def main():
|
||||
# logger = get_logger(__name__)
|
||||
# Initialize Streamlit UI and setup sidebar
|
||||
if 'ui' not in st.session_state:
|
||||
session_state = SessionState()
|
||||
session_state.init_state()
|
||||
st.session_state['ui'] = StreamlitUI(session_state)
|
||||
|
||||
else:
|
||||
st.set_page_config(
|
||||
layout='wide',
|
||||
page_title='lagent-web',
|
||||
page_icon='./docs/imgs/lagent_icon.png')
|
||||
st.header(':robot_face: :blue[Lagent] Web Demo ', divider='rainbow')
|
||||
_, model, plugin_action, uploaded_file, _ = st.session_state[
|
||||
'ui'].setup_sidebar()
|
||||
|
||||
# Initialize chatbot if it is not already initialized
|
||||
# or if the model has changed
|
||||
if 'chatbot' not in st.session_state or model != st.session_state[
|
||||
'chatbot']._llm:
|
||||
st.session_state['chatbot'] = st.session_state[
|
||||
'ui'].initialize_chatbot(model, plugin_action)
|
||||
st.session_state['session_history'] = []
|
||||
|
||||
for prompt, agent_return in zip(st.session_state['user'],
|
||||
st.session_state['assistant']):
|
||||
st.session_state['ui'].render_user(prompt)
|
||||
st.session_state['ui'].render_assistant(agent_return)
|
||||
|
||||
if user_input := st.chat_input(''):
|
||||
with st.container():
|
||||
st.session_state['ui'].render_user(user_input)
|
||||
st.session_state['user'].append(user_input)
|
||||
# Add file uploader to sidebar
|
||||
if (uploaded_file
|
||||
and uploaded_file.name not in st.session_state['file']):
|
||||
|
||||
st.session_state['file'].add(uploaded_file.name)
|
||||
file_bytes = uploaded_file.read()
|
||||
file_type = uploaded_file.type
|
||||
if 'image' in file_type:
|
||||
st.image(file_bytes, caption='Uploaded Image')
|
||||
elif 'video' in file_type:
|
||||
st.video(file_bytes, caption='Uploaded Video')
|
||||
elif 'audio' in file_type:
|
||||
st.audio(file_bytes, caption='Uploaded Audio')
|
||||
# Save the file to a temporary location and get the path
|
||||
|
||||
postfix = uploaded_file.name.split('.')[-1]
|
||||
# prefix = str(uuid.uuid4())
|
||||
prefix = hashlib.md5(file_bytes).hexdigest()
|
||||
filename = f'{prefix}.{postfix}'
|
||||
file_path = os.path.join(root_dir, filename)
|
||||
with open(file_path, 'wb') as tmpfile:
|
||||
tmpfile.write(file_bytes)
|
||||
file_size = os.stat(file_path).st_size / 1024 / 1024
|
||||
file_size = f'{round(file_size, 2)} MB'
|
||||
# st.write(f'File saved at: {file_path}')
|
||||
user_input = [
|
||||
dict(role='user', content=user_input),
|
||||
dict(
|
||||
role='user',
|
||||
content=json.dumps(dict(path=file_path, size=file_size)),
|
||||
name='file')
|
||||
]
|
||||
if isinstance(user_input, str):
|
||||
user_input = [dict(role='user', content=user_input)]
|
||||
st.session_state['last_status'] = AgentStatusCode.SESSION_READY
|
||||
for agent_return in st.session_state['chatbot'].stream_chat(
|
||||
st.session_state['session_history'] + user_input):
|
||||
if agent_return.state == AgentStatusCode.PLUGIN_RETURN:
|
||||
with st.container():
|
||||
st.session_state['ui'].render_plugin_args(
|
||||
agent_return.actions[-1])
|
||||
st.session_state['ui'].render_action_results(
|
||||
agent_return.actions[-1])
|
||||
elif agent_return.state == AgentStatusCode.CODE_RETURN:
|
||||
with st.container():
|
||||
st.session_state['ui'].render_action_results(
|
||||
agent_return.actions[-1])
|
||||
elif (agent_return.state == AgentStatusCode.STREAM_ING
|
||||
or agent_return.state == AgentStatusCode.CODING):
|
||||
# st.markdown(agent_return.response)
|
||||
# 清除占位符的当前内容,并显示新内容
|
||||
with st.container():
|
||||
if agent_return.state != st.session_state['last_status']:
|
||||
st.session_state['temp'] = ''
|
||||
placeholder = st.empty()
|
||||
st.session_state['placeholder'] = placeholder
|
||||
if isinstance(agent_return.response, dict):
|
||||
action = f"\n\n {agent_return.response['name']}: \n\n"
|
||||
action_input = agent_return.response['parameters']
|
||||
if agent_return.response[
|
||||
'name'] == 'IPythonInterpreter':
|
||||
action_input = action_input['command']
|
||||
response = action + action_input
|
||||
else:
|
||||
response = agent_return.response
|
||||
st.session_state['temp'] = response
|
||||
st.session_state['placeholder'].markdown(
|
||||
st.session_state['temp'])
|
||||
elif agent_return.state == AgentStatusCode.END:
|
||||
st.session_state['session_history'] += (
|
||||
user_input + agent_return.inner_steps)
|
||||
agent_return = copy.deepcopy(agent_return)
|
||||
agent_return.response = st.session_state['temp']
|
||||
st.session_state['assistant'].append(
|
||||
copy.deepcopy(agent_return))
|
||||
st.session_state['last_status'] = agent_return.state
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
root_dir = os.path.join(root_dir, 'tmp_dir')
|
||||
os.makedirs(root_dir, exist_ok=True)
|
||||
main()
|
||||
Reference in New Issue
Block a user