"""Streamlit web demo for a Lagent-powered InternLM2 chat agent."""
import copy
import hashlib
import json
import os

import streamlit as st

from lagent.actions import ActionExecutor, ArxivSearch, IPythonInterpreter
from lagent.agents.internlm2_agent import (INTERPRETER_CN, META_CN, PLUGIN_CN,
                                           Internlm2Agent, Internlm2Protocol)
from lagent.llms import HFTransformer
from lagent.llms.meta_template import INTERNLM2_META as META
from lagent.schema import AgentStatusCode

# from streamlit.logger import get_logger

class SessionState:
    """Owns initialisation and reset of the Streamlit session state."""

    def init_state(self):
        """Create every session-state key the demo relies on."""
        st.session_state['assistant'] = []
        st.session_state['user'] = []

        action_list = [ArxivSearch()]
        st.session_state['plugin_map'] = {
            action.name: action
            for action in action_list
        }
        st.session_state['model_map'] = {}
        st.session_state['model_selected'] = None
        st.session_state['plugin_actions'] = set()
        st.session_state['history'] = []

    def clear_state(self):
        """Reset the conversation-related session state to empty values."""
        for key, empty in (('assistant', []), ('user', []),
                           ('model_selected', None), ('file', set())):
            st.session_state[key] = empty
        # Drop the agent's accumulated dialogue history as well, if one exists.
        chatbot = st.session_state.get('chatbot')
        if chatbot is not None:
            chatbot._session_history = []


class StreamlitUI:
    """Render the sidebar/chat widgets and bridge them to the Lagent agent."""

    def __init__(self, session_state: SessionState):
        self.init_streamlit()
        self.session_state = session_state

    def init_streamlit(self):
        """Initialize Streamlit's page config, header and sidebar scaffolding."""
        st.set_page_config(
            layout='wide',
            page_title='lagent-web',
            page_icon='./docs/imgs/lagent_icon.png')
        st.header(':robot_face: :blue[Lagent] Web Demo ', divider='rainbow')
        st.sidebar.title('模型控制')
        st.session_state['file'] = set()
        st.session_state['model_path'] = None

    def setup_sidebar(self):
        """Setup the sidebar for model and plugin selection.

        Returns:
            tuple: ``(model_name, model, plugin_action, uploaded_file,
            model_path)`` reflecting the current sidebar choices.
        """
        # model_name = st.sidebar.selectbox('模型选择:', options=['internlm'])
        model_name = st.sidebar.text_input('模型名称:', value='internlm2-chat-7b')
        meta_prompt = st.sidebar.text_area('系统提示词', value=META_CN)
        da_prompt = st.sidebar.text_area('数据分析提示词', value=INTERPRETER_CN)
        plugin_prompt = st.sidebar.text_area('插件提示词', value=PLUGIN_CN)
        model_path = st.sidebar.text_input(
            '模型路径:', value='internlm/internlm2-chat-20b')
        # (Re)load the model whenever the model name or checkpoint path
        # changes; this also resets the conversation and drops the old agent.
        if model_name != st.session_state['model_selected'] or st.session_state[
                'model_path'] != model_path:
            st.session_state['model_path'] = model_path
            model = self.init_model(model_name, model_path)
            self.session_state.clear_state()
            st.session_state['model_selected'] = model_name
            if 'chatbot' in st.session_state:
                del st.session_state['chatbot']
        else:
            model = st.session_state['model_map'][model_name]

        plugin_name = st.sidebar.multiselect(
            '插件选择',
            options=list(st.session_state['plugin_map'].keys()),
            default=[],
        )
        da_flag = st.sidebar.checkbox(
            '数据分析',
            value=False,
        )
        plugin_action = [
            st.session_state['plugin_map'][name] for name in plugin_name
        ]

        if 'chatbot' in st.session_state:
            # Keep the live agent in sync with the sidebar selections.
            if len(plugin_action) > 0:
                st.session_state['chatbot']._action_executor = ActionExecutor(
                    actions=plugin_action)
            else:
                st.session_state['chatbot']._action_executor = None
            if da_flag:
                st.session_state[
                    'chatbot']._interpreter_executor = ActionExecutor(
                        actions=[IPythonInterpreter()])
            else:
                st.session_state['chatbot']._interpreter_executor = None
            # NOTE(review): the system prompt is written into the protocol's
            # ``_meta_template`` attribute — confirm this is the intended
            # field, since the same name is used for the LLM meta template.
            st.session_state['chatbot']._protocol._meta_template = meta_prompt
            st.session_state['chatbot']._protocol.plugin_prompt = plugin_prompt
            st.session_state[
                'chatbot']._protocol.interpreter_prompt = da_prompt
        if st.sidebar.button('清空对话', key='clear'):
            self.session_state.clear_state()
        uploaded_file = st.sidebar.file_uploader('上传文件')

        return model_name, model, plugin_action, uploaded_file, model_path

    def init_model(self, model_name, path):
        """Initialize the model based on the input model name.

        The constructed ``HFTransformer`` is cached in
        ``st.session_state['model_map']`` keyed by ``model_name``.
        """
        st.session_state['model_map'][model_name] = HFTransformer(
            path=path,
            meta_template=META,
            max_new_tokens=1024,
            top_p=0.8,
            top_k=None,
            temperature=0.1,
            repetition_penalty=1.0,
            stop_words=['<|im_end|>'])
        return st.session_state['model_map'][model_name]

    def initialize_chatbot(self, model, plugin_action):
        """Initialize the chatbot with the given model and plugin actions."""
        return Internlm2Agent(
            llm=model,
            protocol=Internlm2Protocol(
                tool=dict(
                    begin='{start_token}{name}\n',
                    start_token='<|action_start|>',
                    name_map=dict(
                        plugin='<|plugin|>', interpreter='<|interpreter|>'),
                    belong='assistant',
                    end='<|action_end|>\n',
                ), ),
            max_turn=7)

    @staticmethod
    def _read_bytes(path):
        """Read a file's raw bytes, closing the handle promptly.

        Fixes the previous pattern ``open(path, 'rb').read()`` which leaked
        the file handle.
        """
        with open(path, 'rb') as infile:
            return infile.read()

    def render_user(self, prompt: str):
        """Render one user message bubble."""
        with st.chat_message('user'):
            st.markdown(prompt)

    def render_assistant(self, agent_return):
        """Render an assistant turn: each non-finish action, then the reply."""
        with st.chat_message('assistant'):
            for action in agent_return.actions:
                if (action) and (action.type != 'FinishAction'):
                    self.render_action(action)
            st.markdown(agent_return.response)

    def render_plugin_args(self, action):
        """Pretty-print a plugin call as a fenced JSON snippet."""
        # Uses the module-level ``json`` import; the previous function-local
        # ``import json`` was redundant.
        parameter_dict = dict(name=action.type, parameters=action.args)
        parameter_str = '```json\n' + json.dumps(
            parameter_dict, indent=4, ensure_ascii=False) + '\n```'
        st.markdown(parameter_str)

    def render_interpreter_args(self, action):
        """Show the interpreter action type and the text it was given."""
        st.info(action.type)
        st.markdown(action.args['text'])

    def render_action(self, action):
        """Render one action: its thought, its arguments, then its results."""
        st.markdown(action.thought)
        if action.type == 'IPythonInterpreter':
            self.render_interpreter_args(action)
        elif action.type == 'FinishAction':
            pass
        else:
            self.render_plugin_args(action)
        self.render_action_results(action)

    def render_action_results(self, action):
        """Render the results of action, including text, images, videos, and
        audios.

        ``action.result`` is either a dict keyed by media type or a list of
        ``{'type': ..., 'content': ...}`` items (content is a file path for
        media types).
        """
        if isinstance(action.result, dict):
            if 'text' in action.result:
                st.markdown('```\n' + action.result['text'] + '\n```')
            if 'image' in action.result:
                for image_path in action.result['image']:
                    st.image(
                        self._read_bytes(image_path),
                        caption='Generated Image')
            if 'video' in action.result:
                st.video(self._read_bytes(action.result['video']))
            if 'audio' in action.result:
                st.audio(self._read_bytes(action.result['audio']))
        elif isinstance(action.result, list):
            for item in action.result:
                if item['type'] == 'text':
                    st.markdown('```\n' + item['content'] + '\n```')
                elif item['type'] == 'image':
                    st.image(
                        self._read_bytes(item['content']),
                        caption='Generated Image')
                elif item['type'] == 'video':
                    st.video(self._read_bytes(item['content']))
                elif item['type'] == 'audio':
                    st.audio(self._read_bytes(item['content']))
        if action.errmsg:
            st.error(action.errmsg)


def main():
    """Run one Streamlit rerun of the demo.

    Builds the UI on first run, keeps the agent in sync with the sidebar,
    replays the conversation so far, and — when the user submits a message —
    streams one agent turn into the page.
    """
    # logger = get_logger(__name__)
    # Initialize Streamlit UI and setup sidebar
    if 'ui' not in st.session_state:
        session_state = SessionState()
        session_state.init_state()
        st.session_state['ui'] = StreamlitUI(session_state)
    else:
        # On reruns the page config and header must be re-applied; only the
        # first run goes through StreamlitUI.init_streamlit().
        st.set_page_config(
            layout='wide',
            page_title='lagent-web',
            page_icon='./docs/imgs/lagent_icon.png')
        st.header(':robot_face: :blue[Lagent] Web Demo ', divider='rainbow')
    _, model, plugin_action, uploaded_file, _ = st.session_state[
        'ui'].setup_sidebar()

    # Initialize chatbot if it is not already initialized
    # or if the model has changed
    if 'chatbot' not in st.session_state or model != st.session_state[
            'chatbot']._llm:
        st.session_state['chatbot'] = st.session_state[
            'ui'].initialize_chatbot(model, plugin_action)
        st.session_state['session_history'] = []

    # Replay the conversation so far.
    for prompt, agent_return in zip(st.session_state['user'],
                                    st.session_state['assistant']):
        st.session_state['ui'].render_user(prompt)
        st.session_state['ui'].render_assistant(agent_return)

    if user_input := st.chat_input(''):
        with st.container():
            st.session_state['ui'].render_user(user_input)
        st.session_state['user'].append(user_input)
        # Add file uploader to sidebar
        if (uploaded_file
                and uploaded_file.name not in st.session_state['file']):

            st.session_state['file'].add(uploaded_file.name)
            file_bytes = uploaded_file.read()
            file_type = uploaded_file.type
            if 'image' in file_type:
                st.image(file_bytes, caption='Uploaded Image')
            elif 'video' in file_type:
                # BUG FIX: st.video / st.audio accept no ``caption`` keyword;
                # passing one raised TypeError on every video/audio upload.
                st.video(file_bytes)
            elif 'audio' in file_type:
                st.audio(file_bytes)
            # Save the file to a temporary location and get the path.
            # The md5 prefix content-addresses the file, deduplicating
            # identical uploads.
            postfix = uploaded_file.name.split('.')[-1]
            # prefix = str(uuid.uuid4())
            prefix = hashlib.md5(file_bytes).hexdigest()
            filename = f'{prefix}.{postfix}'
            # NOTE(review): ``root_dir`` is a module global set under the
            # ``__main__`` guard below — confirm it is defined in every
            # launch mode this script supports.
            file_path = os.path.join(root_dir, filename)
            with open(file_path, 'wb') as tmpfile:
                tmpfile.write(file_bytes)
            file_size = os.stat(file_path).st_size / 1024 / 1024
            file_size = f'{round(file_size, 2)} MB'
            # st.write(f'File saved at: {file_path}')
            user_input = [
                dict(role='user', content=user_input),
                dict(
                    role='user',
                    content=json.dumps(dict(path=file_path, size=file_size)),
                    name='file')
            ]
        if isinstance(user_input, str):
            user_input = [dict(role='user', content=user_input)]
        st.session_state['last_status'] = AgentStatusCode.SESSION_READY
        for agent_return in st.session_state['chatbot'].stream_chat(
                st.session_state['session_history'] + user_input):
            if agent_return.state == AgentStatusCode.PLUGIN_RETURN:
                with st.container():
                    st.session_state['ui'].render_plugin_args(
                        agent_return.actions[-1])
                    st.session_state['ui'].render_action_results(
                        agent_return.actions[-1])
            elif agent_return.state == AgentStatusCode.CODE_RETURN:
                with st.container():
                    st.session_state['ui'].render_action_results(
                        agent_return.actions[-1])
            elif (agent_return.state == AgentStatusCode.STREAM_ING
                  or agent_return.state == AgentStatusCode.CODING):
                # st.markdown(agent_return.response)
                # Clear the placeholder's current content and stream the
                # new partial response into it.
                with st.container():
                    if agent_return.state != st.session_state['last_status']:
                        st.session_state['temp'] = ''
                        placeholder = st.empty()
                        st.session_state['placeholder'] = placeholder
                    if isinstance(agent_return.response, dict):
                        action = f"\n\n {agent_return.response['name']}: \n\n"
                        action_input = agent_return.response['parameters']
                        if agent_return.response[
                                'name'] == 'IPythonInterpreter':
                            action_input = action_input['command']
                        response = action + action_input
                    else:
                        response = agent_return.response
                    st.session_state['temp'] = response
                    st.session_state['placeholder'].markdown(
                        st.session_state['temp'])
            elif agent_return.state == AgentStatusCode.END:
                st.session_state['session_history'] += (
                    user_input + agent_return.inner_steps)
                agent_return = copy.deepcopy(agent_return)
                # Show the last streamed text; fall back to the final
                # response if no streaming state preceded END this turn
                # (previously a KeyError on 'temp').
                agent_return.response = st.session_state.get(
                    'temp', agent_return.response)
                st.session_state['assistant'].append(
                    copy.deepcopy(agent_return))
            st.session_state['last_status'] = agent_return.state


if __name__ == '__main__':
    # Uploaded files are persisted under <repo_root>/tmp_dir.
    repo_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    root_dir = os.path.join(repo_root, 'tmp_dir')
    os.makedirs(root_dir, exist_ok=True)
    main()