fix bug caused by static model_name (#156)

This commit is contained in:
liujiangning30
2024-02-21 14:47:40 +08:00
committed by GitHub
parent a64aa599ce
commit 432ffaae8a

View File

@@ -62,7 +62,8 @@ class StreamlitUI:
def setup_sidebar(self):
"""Setup the sidebar for model and plugin selection."""
model_name = st.sidebar.selectbox('模型选择:', options=['internlm'])
# model_name = st.sidebar.selectbox('模型选择:', options=['internlm'])
model_name = st.sidebar.text_input('模型名称:', value='internlm2-chat-7b')
meta_prompt = st.sidebar.text_area('系统提示词', value=META_CN)
da_prompt = st.sidebar.text_area('数据分析提示词', value=INTERPRETER_CN)
plugin_prompt = st.sidebar.text_area('插件提示词', value=PLUGIN_CN)
@@ -113,11 +114,11 @@ class StreamlitUI:
return model_name, model, plugin_action, uploaded_file, model_ip
def init_model(self, option, ip=None):
"""Initialize the model based on the selected option."""
def init_model(self, model_name, ip=None):
"""Initialize the model based on the input model name."""
model_url = f'http://{ip}'
st.session_state['model_map'][option] = LMDeployClient(
model_name='internlm2-chat-20b',
st.session_state['model_map'][model_name] = LMDeployClient(
model_name=model_name,
url=model_url,
meta_template=META,
max_new_tokens=1024,
@@ -126,7 +127,7 @@ class StreamlitUI:
temperature=0,
repetition_penalty=1.0,
stop_words=['<|im_end|>'])
return st.session_state['model_map'][option]
return st.session_state['model_map'][model_name]
def initialize_chatbot(self, model, plugin_action):
"""Initialize the chatbot with the given model and plugin actions."""