11 Commits

Author SHA1 Message Date
liujiangning30
581d9fb898 support demo with hf (#179) 2024-03-26 17:59:01 +08:00
RangiLyu
26672731e9 feat: support vllm (#177)
* feat: support vllm

* update requirements
2024-03-21 17:20:38 +08:00
liujiangning30
466d7aa24d Fix errmsg: cast dict to str (#172) 2024-03-20 14:27:04 +08:00
tackhwa
5f55e12736 fix typo "ablility " in overview.md (#175)
fix type "ablility " in overview.md
2024-03-20 14:26:33 +08:00
liujiangning30
e16a6bfc3a Fix bug of ppt and googlescholar (#167)
* fix bug of ppt and googlescholar

* Format required parameters
2024-03-04 13:52:06 +08:00
BraisedPork
605a921878 Fix chat return of GPTAPI (#166)
fix chat return data

Co-authored-by: wangzy <wangziyi@pjlab.org.cn>
2024-02-29 13:56:10 +08:00
liujiangning30
4fd014bebf update version (#161) 2024-02-23 19:50:49 +08:00
liujiangning30
432ffaae8a fix bug caused by static model_name (#156) 2024-02-21 14:47:40 +08:00
liukuikun
a64aa599ce fix batch generate (#158) 2024-02-20 15:11:42 +08:00
liujiangning30
662011783e Feat: no_skip_speicial_token (#148)
* Feat: no_skip_speicial_token

* fix: get_logger of lmdeploy

* update lmdeploy requirement
2024-02-19 16:33:13 +08:00
loveSnowBest
3cf20f5011 support inference for pad_token & chatglm chat (#157)
* update code for chatglm

* update code

* handle batch infer for chat

* update warning for cases
2024-02-19 15:54:48 +08:00
15 changed files with 544 additions and 38 deletions

.gitignore

@@ -160,3 +160,4 @@ cython_debug/
#.idea/
.vscode/
docs/*/_build/
tmp_dir/


@@ -4,7 +4,7 @@ This chapter introduces you to the framework of Lagent, and provides links to de
## What is Lagent
Lagent is an open source LLM agent framework, which enables people to efficiently turn a large language model to agent. It also provides some typical tools to enlighten the ablility of LLM, and the whole framework is shown below:
Lagent is an open source LLM agent framework, which enables people to efficiently turn a large language model to agent. It also provides some typical tools to enlighten the ability of LLM, and the whole framework is shown below:
![image](https://github.com/InternLM/lagent/assets/24351120/e104171e-4baf-43b3-8e6d-90cff1b298b6)


@@ -62,7 +62,8 @@ class StreamlitUI:
def setup_sidebar(self):
"""Setup the sidebar for model and plugin selection."""
model_name = st.sidebar.selectbox('模型选择:', options=['internlm'])
# model_name = st.sidebar.selectbox('模型选择:', options=['internlm'])
model_name = st.sidebar.text_input('模型名称:', value='internlm2-chat-7b')
meta_prompt = st.sidebar.text_area('系统提示词', value=META_CN)
da_prompt = st.sidebar.text_area('数据分析提示词', value=INTERPRETER_CN)
plugin_prompt = st.sidebar.text_area('插件提示词', value=PLUGIN_CN)
@@ -113,11 +114,11 @@ class StreamlitUI:
return model_name, model, plugin_action, uploaded_file, model_ip
def init_model(self, option, ip=None):
"""Initialize the model based on the selected option."""
def init_model(self, model_name, ip=None):
"""Initialize the model based on the input model name."""
model_url = f'http://{ip}'
st.session_state['model_map'][option] = LMDeployClient(
model_name='internlm2-chat-20b',
st.session_state['model_map'][model_name] = LMDeployClient(
model_name=model_name,
url=model_url,
meta_template=META,
max_new_tokens=1024,
@@ -126,7 +127,7 @@ class StreamlitUI:
temperature=0,
repetition_penalty=1.0,
stop_words=['<|im_end|>'])
return st.session_state['model_map'][option]
return st.session_state['model_map'][model_name]
def initialize_chatbot(self, model, plugin_action):
"""Initialize the chatbot with the given model and plugin actions."""
@@ -141,7 +142,7 @@ class StreamlitUI:
belong='assistant',
end='<|action_end|>\n',
), ),
)
max_turn=7)
def render_user(self, prompt: str):
with st.chat_message('user'):
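Note: the hunks above make the served model configurable — the sidebar text input now feeds both the selected name and the LMDeployClient constructor instead of a hard-coded 'internlm2-chat-20b'. A minimal sketch of the same client construction outside Streamlit (URL and model name are placeholders; META is the INTERNLM2_META template the demo imports):

from lagent.llms import LMDeployClient
from lagent.llms.meta_template import INTERNLM2_META as META

client = LMDeployClient(
    model_name='internlm2-chat-7b',        # now taken from user input rather than hard-coded
    url='http://127.0.0.1:23333',          # address of a running lmdeploy api_server (placeholder)
    meta_template=META,
    max_new_tokens=1024,
    temperature=0,
    repetition_penalty=1.0,
    stop_words=['<|im_end|>'])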


@@ -0,0 +1,332 @@
import copy
import hashlib
import json
import os
import streamlit as st
from lagent.actions import ActionExecutor, ArxivSearch, IPythonInterpreter
from lagent.agents.internlm2_agent import INTERPRETER_CN, META_CN, PLUGIN_CN, Internlm2Agent, Internlm2Protocol
from lagent.llms import HFTransformer
from lagent.llms.meta_template import INTERNLM2_META as META
from lagent.schema import AgentStatusCode
# from streamlit.logger import get_logger
class SessionState:
def init_state(self):
"""Initialize session state variables."""
st.session_state['assistant'] = []
st.session_state['user'] = []
action_list = [
ArxivSearch(),
]
st.session_state['plugin_map'] = {
action.name: action
for action in action_list
}
st.session_state['model_map'] = {}
st.session_state['model_selected'] = None
st.session_state['plugin_actions'] = set()
st.session_state['history'] = []
def clear_state(self):
"""Clear the existing session state."""
st.session_state['assistant'] = []
st.session_state['user'] = []
st.session_state['model_selected'] = None
st.session_state['file'] = set()
if 'chatbot' in st.session_state:
st.session_state['chatbot']._session_history = []
class StreamlitUI:
def __init__(self, session_state: SessionState):
self.init_streamlit()
self.session_state = session_state
def init_streamlit(self):
"""Initialize Streamlit's UI settings."""
st.set_page_config(
layout='wide',
page_title='lagent-web',
page_icon='./docs/imgs/lagent_icon.png')
st.header(':robot_face: :blue[Lagent] Web Demo ', divider='rainbow')
st.sidebar.title('模型控制')
st.session_state['file'] = set()
st.session_state['model_path'] = None
def setup_sidebar(self):
"""Setup the sidebar for model and plugin selection."""
# model_name = st.sidebar.selectbox('模型选择:', options=['internlm'])
model_name = st.sidebar.text_input('模型名称:', value='internlm2-chat-7b')
meta_prompt = st.sidebar.text_area('系统提示词', value=META_CN)
da_prompt = st.sidebar.text_area('数据分析提示词', value=INTERPRETER_CN)
plugin_prompt = st.sidebar.text_area('插件提示词', value=PLUGIN_CN)
model_path = st.sidebar.text_input(
'模型路径:', value='internlm/internlm2-chat-20b')
if model_name != st.session_state['model_selected'] or st.session_state[
'model_path'] != model_path:
st.session_state['model_path'] = model_path
model = self.init_model(model_name, model_path)
self.session_state.clear_state()
st.session_state['model_selected'] = model_name
if 'chatbot' in st.session_state:
del st.session_state['chatbot']
else:
model = st.session_state['model_map'][model_name]
plugin_name = st.sidebar.multiselect(
'插件选择',
options=list(st.session_state['plugin_map'].keys()),
default=[],
)
da_flag = st.sidebar.checkbox(
'数据分析',
value=False,
)
plugin_action = [
st.session_state['plugin_map'][name] for name in plugin_name
]
if 'chatbot' in st.session_state:
if len(plugin_action) > 0:
st.session_state['chatbot']._action_executor = ActionExecutor(
actions=plugin_action)
else:
st.session_state['chatbot']._action_executor = None
if da_flag:
st.session_state[
'chatbot']._interpreter_executor = ActionExecutor(
actions=[IPythonInterpreter()])
else:
st.session_state['chatbot']._interpreter_executor = None
st.session_state['chatbot']._protocol._meta_template = meta_prompt
st.session_state['chatbot']._protocol.plugin_prompt = plugin_prompt
st.session_state[
'chatbot']._protocol.interpreter_prompt = da_prompt
if st.sidebar.button('清空对话', key='clear'):
self.session_state.clear_state()
uploaded_file = st.sidebar.file_uploader('上传文件')
return model_name, model, plugin_action, uploaded_file, model_path
def init_model(self, model_name, path):
"""Initialize the model based on the input model name."""
st.session_state['model_map'][model_name] = HFTransformer(
path=path,
meta_template=META,
max_new_tokens=1024,
top_p=0.8,
top_k=None,
temperature=0.1,
repetition_penalty=1.0,
stop_words=['<|im_end|>'])
return st.session_state['model_map'][model_name]
def initialize_chatbot(self, model, plugin_action):
"""Initialize the chatbot with the given model and plugin actions."""
return Internlm2Agent(
llm=model,
protocol=Internlm2Protocol(
tool=dict(
begin='{start_token}{name}\n',
start_token='<|action_start|>',
name_map=dict(
plugin='<|plugin|>', interpreter='<|interpreter|>'),
belong='assistant',
end='<|action_end|>\n',
), ),
max_turn=7)
def render_user(self, prompt: str):
with st.chat_message('user'):
st.markdown(prompt)
def render_assistant(self, agent_return):
with st.chat_message('assistant'):
for action in agent_return.actions:
if (action) and (action.type != 'FinishAction'):
self.render_action(action)
st.markdown(agent_return.response)
def render_plugin_args(self, action):
action_name = action.type
args = action.args
import json
parameter_dict = dict(name=action_name, parameters=args)
parameter_str = '```json\n' + json.dumps(
parameter_dict, indent=4, ensure_ascii=False) + '\n```'
st.markdown(parameter_str)
def render_interpreter_args(self, action):
st.info(action.type)
st.markdown(action.args['text'])
def render_action(self, action):
st.markdown(action.thought)
if action.type == 'IPythonInterpreter':
self.render_interpreter_args(action)
elif action.type == 'FinishAction':
pass
else:
self.render_plugin_args(action)
self.render_action_results(action)
def render_action_results(self, action):
"""Render the results of action, including text, images, videos, and
audios."""
if (isinstance(action.result, dict)):
if 'text' in action.result:
st.markdown('```\n' + action.result['text'] + '\n```')
if 'image' in action.result:
# image_path = action.result['image']
for image_path in action.result['image']:
image_data = open(image_path, 'rb').read()
st.image(image_data, caption='Generated Image')
if 'video' in action.result:
video_data = action.result['video']
video_data = open(video_data, 'rb').read()
st.video(video_data)
if 'audio' in action.result:
audio_data = action.result['audio']
audio_data = open(audio_data, 'rb').read()
st.audio(audio_data)
elif isinstance(action.result, list):
for item in action.result:
if item['type'] == 'text':
st.markdown('```\n' + item['content'] + '\n```')
elif item['type'] == 'image':
image_data = open(item['content'], 'rb').read()
st.image(image_data, caption='Generated Image')
elif item['type'] == 'video':
video_data = open(item['content'], 'rb').read()
st.video(video_data)
elif item['type'] == 'audio':
audio_data = open(item['content'], 'rb').read()
st.audio(audio_data)
if action.errmsg:
st.error(action.errmsg)
def main():
# logger = get_logger(__name__)
# Initialize Streamlit UI and setup sidebar
if 'ui' not in st.session_state:
session_state = SessionState()
session_state.init_state()
st.session_state['ui'] = StreamlitUI(session_state)
else:
st.set_page_config(
layout='wide',
page_title='lagent-web',
page_icon='./docs/imgs/lagent_icon.png')
st.header(':robot_face: :blue[Lagent] Web Demo ', divider='rainbow')
_, model, plugin_action, uploaded_file, _ = st.session_state[
'ui'].setup_sidebar()
# Initialize chatbot if it is not already initialized
# or if the model has changed
if 'chatbot' not in st.session_state or model != st.session_state[
'chatbot']._llm:
st.session_state['chatbot'] = st.session_state[
'ui'].initialize_chatbot(model, plugin_action)
st.session_state['session_history'] = []
for prompt, agent_return in zip(st.session_state['user'],
st.session_state['assistant']):
st.session_state['ui'].render_user(prompt)
st.session_state['ui'].render_assistant(agent_return)
if user_input := st.chat_input(''):
with st.container():
st.session_state['ui'].render_user(user_input)
st.session_state['user'].append(user_input)
# Add file uploader to sidebar
if (uploaded_file
and uploaded_file.name not in st.session_state['file']):
st.session_state['file'].add(uploaded_file.name)
file_bytes = uploaded_file.read()
file_type = uploaded_file.type
if 'image' in file_type:
st.image(file_bytes, caption='Uploaded Image')
elif 'video' in file_type:
st.video(file_bytes, caption='Uploaded Video')
elif 'audio' in file_type:
st.audio(file_bytes, caption='Uploaded Audio')
# Save the file to a temporary location and get the path
postfix = uploaded_file.name.split('.')[-1]
# prefix = str(uuid.uuid4())
prefix = hashlib.md5(file_bytes).hexdigest()
filename = f'{prefix}.{postfix}'
file_path = os.path.join(root_dir, filename)
with open(file_path, 'wb') as tmpfile:
tmpfile.write(file_bytes)
file_size = os.stat(file_path).st_size / 1024 / 1024
file_size = f'{round(file_size, 2)} MB'
# st.write(f'File saved at: {file_path}')
user_input = [
dict(role='user', content=user_input),
dict(
role='user',
content=json.dumps(dict(path=file_path, size=file_size)),
name='file')
]
if isinstance(user_input, str):
user_input = [dict(role='user', content=user_input)]
st.session_state['last_status'] = AgentStatusCode.SESSION_READY
for agent_return in st.session_state['chatbot'].stream_chat(
st.session_state['session_history'] + user_input):
if agent_return.state == AgentStatusCode.PLUGIN_RETURN:
with st.container():
st.session_state['ui'].render_plugin_args(
agent_return.actions[-1])
st.session_state['ui'].render_action_results(
agent_return.actions[-1])
elif agent_return.state == AgentStatusCode.CODE_RETURN:
with st.container():
st.session_state['ui'].render_action_results(
agent_return.actions[-1])
elif (agent_return.state == AgentStatusCode.STREAM_ING
or agent_return.state == AgentStatusCode.CODING):
# st.markdown(agent_return.response)
# 清除占位符的当前内容,并显示新内容
with st.container():
if agent_return.state != st.session_state['last_status']:
st.session_state['temp'] = ''
placeholder = st.empty()
st.session_state['placeholder'] = placeholder
if isinstance(agent_return.response, dict):
action = f"\n\n {agent_return.response['name']}: \n\n"
action_input = agent_return.response['parameters']
if agent_return.response[
'name'] == 'IPythonInterpreter':
action_input = action_input['command']
response = action + action_input
else:
response = agent_return.response
st.session_state['temp'] = response
st.session_state['placeholder'].markdown(
st.session_state['temp'])
elif agent_return.state == AgentStatusCode.END:
st.session_state['session_history'] += (
user_input + agent_return.inner_steps)
agent_return = copy.deepcopy(agent_return)
agent_return.response = st.session_state['temp']
st.session_state['assistant'].append(
copy.deepcopy(agent_return))
st.session_state['last_status'] = agent_return.state
if __name__ == '__main__':
root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
root_dir = os.path.join(root_dir, 'tmp_dir')
os.makedirs(root_dir, exist_ok=True)
main()
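Note: the new file is an HFTransformer-backed twin of the LMDeploy web demo — the sidebar takes a model name and a Hugging Face path, init_model builds an HFTransformer, and the agent wiring is otherwise unchanged. A condensed, UI-free sketch of that wiring (model path and the sample question are placeholders; the plugin is attached the same way the demo does, via the agent's _action_executor attribute):

from lagent.actions import ActionExecutor, ArxivSearch
from lagent.agents.internlm2_agent import Internlm2Agent, Internlm2Protocol
from lagent.llms import HFTransformer
from lagent.llms.meta_template import INTERNLM2_META as META

llm = HFTransformer(
    path='internlm/internlm2-chat-7b',     # any local dir or HF repo id (placeholder)
    meta_template=META,
    max_new_tokens=1024,
    top_p=0.8,
    top_k=None,
    temperature=0.1,
    repetition_penalty=1.0,
    stop_words=['<|im_end|>'])

agent = Internlm2Agent(
    llm=llm,
    protocol=Internlm2Protocol(
        tool=dict(
            begin='{start_token}{name}\n',
            start_token='<|action_start|>',
            name_map=dict(plugin='<|plugin|>', interpreter='<|interpreter|>'),
            belong='assistant',
            end='<|action_end|>\n')),
    max_turn=7)
agent._action_executor = ActionExecutor(actions=[ArxivSearch()])

for agent_return in agent.stream_chat(
        [dict(role='user', content='Find arXiv papers about LLM agents')]):
    pass                                   # stream states as the demo renders them
print(agent_return.response)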


@@ -8,7 +8,7 @@ THEME_MAPPING = {
'template': None,
'title': 'Title Slide',
'single': 'Title and Content',
'two': 'Tow content',
'two': 'Two Content',
}
}
@@ -31,7 +31,7 @@ class PPT(BaseAction):
"""Create a pptx file with specific themes.
Args:
theme (:class:`str`): the theme used
theme (:class:`str`): the theme used. The value should be one of ['Default'].
abs_location (:class:`str`): the ppt file's absolute location
Returns:
@@ -115,6 +115,7 @@ class PPT(BaseAction):
:class:`dict`: operation status
* status: the result of the execution
"""
from PIL import Image
layout_name = self.theme_mapping[self.pointer.slide_master.name]['two']
layout = next(i for i in self.pointer.slide_master.slide_layouts
if i.name == layout_name)
@@ -122,6 +123,7 @@ class PPT(BaseAction):
ph_title, ph_body1, ph_body2 = slide.placeholders
ph_title.text = title
ph = ph_body2
image = Image.open(image)
image_pil = image.to_pil()
left = ph.left
width = ph.width
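Note: besides correcting the 'Two Content' layout name, the fix above loads the image with PIL before it is measured for placement. A hedged sketch of exercising the action's documented entry point (assuming PPT is exported from lagent.actions; the output path is a placeholder):

from lagent.actions import PPT   # assumed export location

ppt = PPT()
# Per the docstring above, theme must currently be one of ['Default'].
status = ppt.create_file(theme='Default', abs_location='/tmp/demo.pptx')
print(status)                    # a dict reporting the operation status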


@@ -141,6 +141,12 @@ class Internlm2Protocol:
tool_name = api_info['name'].split('.')[0]
plugin['description'] = API_PREFIX.format(
tool_name=tool_name, description=plugin['description'])
# only keep required parameters
required_parameters = [
param for param in plugin['parameters']
if param['name'] in plugin['required']
]
plugin['parameters'] = required_parameters
plugin_descriptions.append(plugin)
plugin_prompt = self.plugin_prompt.format(
prompt=json.dumps(
@@ -176,7 +182,7 @@ class Internlm2Protocol:
if action_return.state == ActionStatusCode.SUCCESS:
response = action_return.format_result()
else:
response = action_return.errmsg
response = str(action_return.errmsg)
content = self.execute['begin'] + response + self.execute['end']
if self.execute.get('fallback_role'):
return dict(
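Note: the first hunk keeps only the parameters listed in each plugin's 'required' field before the description is JSON-serialized into the plugin prompt; the second casts errmsg to str so non-string error payloads (e.g. dicts) can be concatenated into the execute message. A toy illustration of the parameter filter (the plugin dict below is made up):

plugin = dict(
    name='ArxivSearch.get_arxiv_article_information',   # illustrative name
    description='Search arXiv for articles.',
    required=['query'],
    parameters=[
        dict(name='query', type='STRING', description='search keywords'),
        dict(name='max_results', type='NUMBER', description='optional cap'),
    ])

# Same filter as the new code above: optional parameters are dropped
# from the prompt that the model sees.
plugin['parameters'] = [
    param for param in plugin['parameters']
    if param['name'] in plugin['required']
]
print(plugin['parameters'])      # only the 'query' entry remains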


@@ -1,12 +1,13 @@
from .base_api import BaseAPIModel
from .base_llm import BaseModel
from .huggingface import HFTransformer, HFTransformerCasualLM
from .huggingface import HFTransformer, HFTransformerCasualLM, HFTransformerChat
from .lmdepoly_wrapper import LMDeployClient, LMDeployPipeline, LMDeployServer
from .meta_template import INTERNLM2_META
from .openai import GPTAPI
from .vllm_wrapper import VllmModel
__all__ = [
'BaseModel', 'BaseAPIModel', 'GPTAPI', 'LMDeployClient',
'LMDeployPipeline', 'LMDeployServer', 'HFTransformer',
'HFTransformerCasualLM', 'INTERNLM2_META'
'HFTransformerCasualLM', 'INTERNLM2_META', 'HFTransformerChat', 'VllmModel'
]


@@ -118,7 +118,7 @@ class APITemplateParser:
return res
def _role2api_role(self, role_prompt: Dict) -> Tuple[str, bool]:
merged_prompt = self.roles[self.roles[role_prompt['role']]]
merged_prompt = self.roles[role_prompt['role']]
if merged_prompt.get('fallback_role'):
merged_prompt = self.roles[self.roles[
merged_prompt['fallback_role']]]
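Note: the one-line fix above removes a double dictionary lookup — role_prompt['role'] is already the key into self.roles, so indexing self.roles again with the looked-up config dict could not work. A toy illustration (standalone, not lagent code):

roles = {
    'user': dict(api_role='user'),
    'tool': dict(api_role='user', fallback_role='user'),
}

role_prompt = dict(role='user', content='hi')
merged = roles[role_prompt['role']]      # fixed: one lookup returns the role config dict
# roles[roles[role_prompt['role']]]      # old code: uses that dict as a key
#                                        # -> TypeError: unhashable type: 'dict'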


@@ -183,12 +183,12 @@ class BaseModel:
Returns:
"""
if isinstance(inputs[0], list):
inputs = list()
_inputs = list()
for msg in inputs:
inputs.append(self.template_parser(msg))
_inputs.append(self.template_parser(msg))
else:
inputs = self.template_parser(inputs)
return self.generate(inputs, **gen_params)
_inputs = self.template_parser(inputs)
return self.generate(_inputs, **gen_params)
def generate_from_template(self, inputs: Union[List[dict],
List[List[dict]]],
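Note: the change above fixes a name-shadowing bug in the batched path — the old code rebound inputs to an empty list and then iterated over that same empty list, so no message ever reached the template parser. A standalone reproduction of the pattern (not lagent code):

def broken(batch):
    batch = list()          # rebinds the name: the original messages are gone
    for msg in batch:       # iterates the new, empty list -> body never runs
        batch.append(msg.upper())
    return batch

def fixed(batch):
    _batch = list()         # accumulate into a fresh name instead
    for msg in batch:
        _batch.append(msg.upper())
    return _batch

assert broken(['hi']) == []         # silently drops the input
assert fixed(['hi']) == ['HI']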


@@ -3,6 +3,7 @@ import logging
from typing import Dict, List, Optional, Union
from lagent.schema import ModelStatusCode
from .base_api import APITemplateParser
from .base_llm import BaseModel
logger = logging.getLogger(__name__)
@@ -37,12 +38,20 @@ class HFTransformer(BaseModel):
tokenizer_only: bool = False,
model_kwargs: dict = dict(device_map='auto'),
meta_template: Optional[Dict] = None,
stop_words_id: Union[List[int], int] = None,
**kwargs):
super().__init__(
path=path,
tokenizer_only=tokenizer_only,
meta_template=meta_template,
**kwargs)
if isinstance(stop_words_id, int):
stop_words_id = [stop_words_id]
self.gen_params.update(stop_words_id=stop_words_id)
if self.gen_params['stop_words'] is not None and \
self.gen_params['stop_words_id'] is not None:
logger.warning('Both stop_words and stop_words_id are specified,'
'only stop_words_id will be used.')
self._load_tokenizer(
path=path,
@@ -57,7 +66,9 @@ class HFTransformer(BaseModel):
self.prefix_allowed_tokens_fn = None
stop_words_id = []
if self.gen_params.get('stop_words'):
if self.gen_params.get('stop_words_id'):
stop_words_id = self.gen_params.get('stop_words_id')
elif self.gen_params.get('stop_words'):
for sw in self.gen_params.get('stop_words'):
stop_words_id.append(self.tokenizer(sw)['input_ids'][-1])
self.additional_eos_token_id = stop_words_id
@@ -69,8 +80,27 @@ class HFTransformer(BaseModel):
tokenizer_path if tokenizer_path else path,
trust_remote_code=True,
**tokenizer_kwargs)
if self.tokenizer.pad_token_id is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
if self.tokenizer.eos_token is not None:
logger.warning(
f'Using eos_token_id {self.tokenizer.eos_token} '
'as pad_token_id.')
self.tokenizer.pad_token = self.tokenizer.eos_token
else:
from transformers.generation import GenerationConfig
self.gcfg = GenerationConfig.from_pretrained(path)
if self.gcfg.pad_token_id is not None:
logger.warning(
f'Using pad_token_id {self.gcfg.pad_token_id} '
'as pad_token_id.')
self.tokenizer.pad_token_id = self.gcfg.pad_token_id
else:
raise ValueError(
'pad_token_id is not set for this tokenizer. Try to '
'set pad_token_id via passing '
'`pad_token_id={PAD_TOKEN_ID}` in model_cfg.')
def _load_model(self, path: str, model_kwargs: dict):
import torch
@@ -127,7 +157,6 @@ class HFTransformer(BaseModel):
if isinstance(inputs, str):
inputs = [inputs]
batched = False
# import pdb; pdb.set_trace()
inputs = self.tokenizer(
inputs, padding=True, return_tensors='pt', return_length=True)
input_length = inputs['length']
@@ -148,6 +177,11 @@ class HFTransformer(BaseModel):
generation_config.bos_token_id,
generation_config.eos_token_id,
)
if eos_token_id is None:
if self.gcfg.eos_token_id is not None:
eos_token_id = self.gcfg.eos_token_id
else:
eos_token_id = []
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
if self.additional_eos_token_id is not None:
@@ -267,3 +301,39 @@ class HFTransformerCasualLM(HFTransformer):
self.model = AutoModelForCausalLM.from_pretrained(
path, trust_remote_code=True, **model_kwargs)
self.model.eval()
class HFTransformerChat(HFTransformerCasualLM):
def __init__(self, template_parser=APITemplateParser, **kwargs):
super().__init__(template_parser=template_parser, **kwargs)
def chat(self,
inputs: Union[List[dict], List[List[dict]]],
do_sample: bool = True,
**kwargs):
"""Return the chat completions in stream mode.
Args:
inputs (Union[List[dict], List[List[dict]]]): input messages to be completed.
do_sample (bool): do sampling if enabled
Returns:
the text/chat completion
"""
# handle batch inference with vanilla for loop
if isinstance(inputs[0], list):
resps = []
for input in inputs:
resps.append(self.chat(input, do_sample, **kwargs))
return resps
prompt = self.template_parser(inputs)
query = prompt[-1]['content']
history = prompt[:-1]
try:
response, history = self.model.chat(
self.tokenizer, query, history=history)
except Exception as e:
# handle over-length input error
logger.warning(str(e))
response = ''
return response
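Note: two additions become usable here — stop words may now be supplied directly as token ids (stop_words_id takes precedence when both are given, as the warning notes), and HFTransformerChat forwards each conversation to the underlying model's .chat() method, which is how ChatGLM-style models expect to be driven. A rough usage sketch (model paths and the token id are placeholders):

from lagent.llms import HFTransformer, HFTransformerChat

# stop_words_id overrides stop_words when both are set; 92542 is only
# an illustrative token id, not a universal constant.
llm = HFTransformer(
    path='internlm/internlm2-chat-7b',
    stop_words_id=[92542])

chat_llm = HFTransformerChat(path='THUDM/chatglm3-6b')   # placeholder chat-style model
# A single conversation ...
reply = chat_llm.chat([dict(role='user', content='Hello')])
# ... or a batch of conversations, handled internally with a plain for loop.
replies = chat_llm.chat([
    [dict(role='user', content='Hello')],
    [dict(role='user', content='What is Lagent?')],
])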


@@ -48,6 +48,7 @@ class TritonClient(BaseModel):
request_id: str = '',
sequence_start: bool = True,
sequence_end: bool = True,
skip_special_tokens: bool = False,
**kwargs):
"""Start a new round conversation of a session. Return the chat
completions in non-stream mode.
@@ -58,7 +59,8 @@ class TritonClient(BaseModel):
request_id (str): the identical id of this round conversation
sequence_start (bool): start flag of a session
sequence_end (bool): end flag of a session
skip_special_tokens (bool): Whether or not to remove special tokens
in the decoding. Default to be False.
Returns:
(a list of/batched) text/chat completion
"""
@@ -73,7 +75,7 @@ class TritonClient(BaseModel):
self.chatbot.cfg = self._update_gen_params(**kwargs)
max_new_tokens = self.chatbot.cfg.max_new_tokens
logger = get_logger(log_level=self.chatbot.log_level)
logger = get_logger('service.ft', log_level=self.chatbot.log_level)
logger.info(f'session {session_id}, request_id {request_id}, '
f'max_out_len {max_new_tokens}')
@@ -91,8 +93,12 @@ class TritonClient(BaseModel):
status, res, _ = None, '', 0
for status, res, _ in self.chatbot._stream_infer(
self.chatbot._session, prompt, max_new_tokens, sequence_start,
sequence_end):
self.chatbot._session,
prompt,
max_new_tokens,
sequence_start,
sequence_end,
skip_special_tokens=skip_special_tokens):
status = self.state_map.get(status)
if status < ModelStatusCode.END:
return ''
@@ -111,6 +117,7 @@ class TritonClient(BaseModel):
request_id: str = '',
sequence_start: bool = True,
sequence_end: bool = True,
skip_special_tokens: bool = False,
**kwargs):
"""Start a new round conversation of a session. Return the chat
completions in stream mode.
@@ -121,7 +128,8 @@ class TritonClient(BaseModel):
request_id (str): the identical id of this round conversation
sequence_start (bool): start flag of a session
sequence_end (bool): end flag of a session
skip_special_tokens (bool): Whether or not to remove special tokens
in the decoding. Default to be False.
Returns:
tuple(Status, str, int): status, text/chat completion,
generated token number
@@ -133,7 +141,7 @@ class TritonClient(BaseModel):
self.chatbot.cfg = self._update_gen_params(**kwargs)
max_new_tokens = self.chatbot.cfg.max_new_tokens
logger = get_logger(log_level=self.chatbot.log_level)
logger = get_logger('service.ft', log_level=self.chatbot.log_level)
logger.info(f'session {session_id}, request_id {request_id}, '
f'max_out_len {max_new_tokens}')
@@ -152,8 +160,12 @@ class TritonClient(BaseModel):
prompt = self.template_parser(inputs)
status, res, _ = None, '', 0
for status, res, _ in self.chatbot._stream_infer(
self.chatbot._session, prompt, max_new_tokens, sequence_start,
sequence_end):
self.chatbot._session,
prompt,
max_new_tokens,
sequence_start,
sequence_end,
skip_special_tokens=skip_special_tokens):
status = self.state_map.get(status)
# The stop symbol also appears in the output of the last STREAM_ING state.
res = filter_suffix(res, self.gen_params.get('stop_words'))
@@ -223,6 +235,7 @@ class LMDeployPipeline(BaseModel):
def generate(self,
inputs: Union[str, List[str]],
do_preprocess: bool = None,
skip_special_tokens: bool = False,
**kwargs):
"""Return the chat completions in non-stream mode.
@@ -230,7 +243,8 @@ class LMDeployPipeline(BaseModel):
inputs (Union[str, List[str]]): input texts to be completed.
do_preprocess (bool): whether pre-process the messages. Default to
True, which means chat_template will be applied.
skip_special_tokens (bool): Whether or not to remove special tokens
in the decoding. Default to be False.
Returns:
(a list of/batched) text/chat completion
"""
@@ -242,7 +256,8 @@ class LMDeployPipeline(BaseModel):
batched = False
prompt = inputs
gen_params = self.update_gen_params(**kwargs)
gen_config = GenerationConfig(**gen_params)
gen_config = GenerationConfig(
skip_special_tokens=skip_special_tokens, **gen_params)
response = self.model.batch_infer(
prompt, gen_config=gen_config, do_preprocess=do_preprocess)
response = [resp.text for resp in response]
@@ -308,6 +323,7 @@ class LMDeployServer(BaseModel):
sequence_start: bool = True,
sequence_end: bool = True,
ignore_eos: bool = False,
skip_special_tokens: Optional[bool] = False,
timeout: int = 30,
**kwargs) -> List[str]:
"""Start a new round conversation of a session. Return the chat
@@ -319,6 +335,8 @@ class LMDeployServer(BaseModel):
sequence_start (bool): start flag of a session
sequence_end (bool): end flag of a session
ignore_eos (bool): indicator for ignoring eos
skip_special_tokens (bool): Whether or not to remove special tokens
in the decoding. Default to be False.
timeout (int): max time to wait for response
Returns:
(a list of/batched) text/chat completion
@@ -342,6 +360,7 @@ class LMDeployServer(BaseModel):
sequence_end=sequence_end,
stream=False,
ignore_eos=ignore_eos,
skip_special_tokens=skip_special_tokens,
timeout=timeout,
**gen_params):
resp = [
@@ -361,6 +380,7 @@ class LMDeployServer(BaseModel):
sequence_end: bool = True,
stream: bool = True,
ignore_eos: bool = False,
skip_special_tokens: Optional[bool] = False,
timeout: int = 30,
**kwargs):
"""Start a new round conversation of a session. Return the chat
@@ -373,6 +393,8 @@ class LMDeployServer(BaseModel):
sequence_end (bool): end flag of a session
stream (bool): return in a streaming format if enabled
ignore_eos (bool): indicator for ignoring eos
skip_special_tokens (bool): Whether or not to remove special tokens
in the decoding. Default to be False.
timeout (int): max time to wait for response
Returns:
tuple(Status, str, int): status, text/chat completion,
@@ -394,6 +416,7 @@ class LMDeployServer(BaseModel):
sequence_end=sequence_end,
stream=stream,
ignore_eos=ignore_eos,
skip_special_tokens=skip_special_tokens,
timeout=timeout,
**gen_params):
resp += text['choices'][0]['text']
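Note: TritonClient, LMDeployPipeline and LMDeployServer now all accept skip_special_tokens and forward it to the backend, defaulting to False, presumably so protocol tokens such as <|action_start|> survive decoding. A hedged sketch with the pipeline wrapper (model path and constructor arguments are assumptions):

from lagent.llms import LMDeployPipeline

llm = LMDeployPipeline(path='internlm/internlm2-chat-7b')   # placeholder path

# Keep special tokens in the output (the default after this change) ...
with_specials = llm.generate('Hi there', skip_special_tokens=False)
# ... or strip them when plain text is all that is needed.
plain = llm.generate('Hi there', skip_special_tokens=True)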


@@ -1,7 +1,7 @@
import json
import os
import time
from concurrent.futures import ThreadPoolExecutor, wait
from concurrent.futures import ThreadPoolExecutor
from logging import getLogger
from threading import Lock
from typing import Dict, List, Optional, Union
@@ -106,17 +106,15 @@ class GPTAPI(BaseAPIModel):
Union[str, List[str]]: generated string(s)
"""
assert isinstance(inputs, list)
if isinstance(inputs[0], dict):
inputs = [inputs]
if 'max_tokens' in gen_params:
raise NotImplementedError('unsupported parameter: max_tokens')
gen_params = {**self.gen_params, **gen_params}
with ThreadPoolExecutor(max_workers=20) as executor:
tasks = [
executor.submit(self._chat, messages, **gen_params)
for messages in inputs
for messages in (
[inputs] if isinstance(inputs[0], dict) else inputs)
]
wait(tasks)
ret = [task.result() for task in tasks]
return ret[0] if isinstance(inputs[0], dict) else ret
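Note: after this fix, GPTAPI.chat accepts either a single conversation or a batch and mirrors that shape in its return value, and the redundant wait() is gone because task.result() already blocks. A usage sketch (constructor arguments are assumptions; supply a real key):

from lagent.llms import GPTAPI

api = GPTAPI(model_type='gpt-3.5-turbo', key='YOUR_OPENAI_KEY')   # assumed parameter names

# A single conversation returns a single string ...
answer = api.chat([dict(role='user', content='Hi')])

# ... and a batch of conversations returns a list of strings.
answers = api.chat([
    [dict(role='user', content='Hi')],
    [dict(role='user', content='Tell me a joke')],
])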


@@ -0,0 +1,71 @@
from typing import List, Union
from lagent.llms.base_llm import BaseModel
from lagent.utils.util import filter_suffix
class VllmModel(BaseModel):
"""
A wrapper of vLLM model.
Args:
path (str): The path to the model.
It could be one of the following options:
- i) A local directory path of a huggingface model.
- ii) The model_id of a model hosted inside a model repo
on huggingface.co, such as "internlm/internlm-chat-7b",
"Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
and so on.
tp (int): tensor parallel
vllm_cfg (dict): Other kwargs for vllm model initialization.
"""
def __init__(self, path: str, tp: int = 1, vllm_cfg=dict(), **kwargs):
super().__init__(path=path, **kwargs)
from vllm import LLM
self.model = LLM(
model=self.path,
trust_remote_code=True,
tensor_parallel_size=tp,
**vllm_cfg)
def generate(self,
inputs: Union[str, List[str]],
do_preprocess: bool = None,
skip_special_tokens: bool = False,
**kwargs):
"""Return the chat completions in non-stream mode.
Args:
inputs (Union[str, List[str]]): input texts to be completed.
do_preprocess (bool): whether pre-process the messages. Default to
True, which means chat_template will be applied.
skip_special_tokens (bool): Whether or not to remove special tokens
in the decoding. Default to be False.
Returns:
(a list of/batched) text/chat completion
"""
from vllm import SamplingParams
batched = True
if isinstance(inputs, str):
inputs = [inputs]
batched = False
prompt = inputs
gen_params = self.update_gen_params(**kwargs)
max_new_tokens = gen_params.pop('max_new_tokens')
stop_words = gen_params.pop('stop_words')
sampling_config = SamplingParams(
skip_special_tokens=skip_special_tokens,
max_tokens=max_new_tokens,
stop=stop_words,
**gen_params)
response = self.model.generate(prompt, sampling_params=sampling_config)
response = [resp.outputs[0].text for resp in response]
# remove stop_words
response = filter_suffix(response, self.gen_params.get('stop_words'))
if batched:
return response
return response[0]
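Note: a quick sketch of driving the new wrapper directly (model path and sampling values are placeholders; generation parameters are forwarded to vLLM's SamplingParams as shown above):

from lagent.llms import VllmModel

llm = VllmModel(
    path='internlm/internlm2-chat-7b',     # local dir or HF repo id, as documented above
    tp=1,                                  # tensor parallel degree
    top_p=0.8,
    top_k=100,
    temperature=0.7,
    max_new_tokens=512,
    stop_words=['<|im_end|>'])

print(llm.generate('Write one sentence about open-source agents.'))
# A list of prompts returns a list of completions.
print(llm.generate(['Hello', 'What is vLLM?']))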


@@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
__version__ = '0.2.1'
__version__ = '0.2.2'
def parse_version_info(version_str: str, length: int = 4) -> tuple:


@@ -1,7 +1,8 @@
google-search-results
lmdeploy>=0.2.2
lmdeploy>=0.2.3
pillow
python-pptx
timeout_decorator
torch
transformers>=4.34
vllm>=0.3.3