Feat: no_skip_special_token (#148)

* Feat: no_skip_special_token

* fix: get_logger of lmdeploy

* update lmdeploy requirement
Author: liujiangning30
Date: 2024-02-19 16:33:13 +08:00
Committed by: GitHub
Parent: 3cf20f5011
Commit: 662011783e
2 changed files with 34 additions and 11 deletions


@@ -48,6 +48,7 @@ class TritonClient(BaseModel):
request_id: str = '',
sequence_start: bool = True,
sequence_end: bool = True,
+ skip_special_tokens: bool = False,
**kwargs):
"""Start a new round conversation of a session. Return the chat
completions in non-stream mode.
@@ -58,7 +59,8 @@ class TritonClient(BaseModel):
request_id (str): the identical id of this round conversation
sequence_start (bool): start flag of a session
sequence_end (bool): end flag of a session
+ skip_special_tokens (bool): Whether or not to remove special tokens in the decoding. Default to be False.
Returns:
(a list of/batched) text/chat completion
"""
@@ -73,7 +75,7 @@ class TritonClient(BaseModel):
self.chatbot.cfg = self._update_gen_params(**kwargs)
max_new_tokens = self.chatbot.cfg.max_new_tokens
- logger = get_logger(log_level=self.chatbot.log_level)
+ logger = get_logger('service.ft', log_level=self.chatbot.log_level)
logger.info(f'session {session_id}, request_id {request_id}, '
f'max_out_len {max_new_tokens}')
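
The logger change above tracks the second item in the commit message: lmdeploy's `get_logger` now wants an explicit logger name, so the wrapper passes `'service.ft'`. A minimal sketch of the assumed call, taking `lmdeploy.utils` as the import location (an assumption; the wrapper may import it from elsewhere):

```python
import logging

from lmdeploy.utils import get_logger  # assumed import location

# Passing the logger name first mirrors the new call in the hunk above.
logger = get_logger('service.ft', log_level=logging.INFO)
logger.info('session 1, request_id demo, max_out_len 512')
```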
@@ -91,8 +93,12 @@ class TritonClient(BaseModel):
status, res, _ = None, '', 0
for status, res, _ in self.chatbot._stream_infer(
- self.chatbot._session, prompt, max_new_tokens, sequence_start,
- sequence_end):
+ self.chatbot._session,
+ prompt,
+ max_new_tokens,
+ sequence_start,
+ sequence_end,
+ skip_special_tokens=skip_special_tokens):
status = self.state_map.get(status)
if status < ModelStatusCode.END:
return ''
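
Taken together, the hunks above thread the new flag from the public `chat()` signature down to `Chatbot._stream_infer`. A usage sketch under assumed constructor arguments (the import path, server address and model name are illustrative, not taken from this diff):

```python
from lagent.llms import TritonClient  # assumed import path

llm = TritonClient(tritonserver_addr='0.0.0.0:33337',  # illustrative address
                   model_name='internlm2-chat-7b')     # illustrative model
# The new default keeps special tokens in the decoded completion:
print(llm.chat('Hello', skip_special_tokens=False))
# Opt in to stripping them when plain text is wanted:
print(llm.chat('Hello', skip_special_tokens=True))
```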
@@ -111,6 +117,7 @@ class TritonClient(BaseModel):
request_id: str = '',
sequence_start: bool = True,
sequence_end: bool = True,
+ skip_special_tokens: bool = False,
**kwargs):
"""Start a new round conversation of a session. Return the chat
completions in stream mode.
@@ -121,7 +128,8 @@ class TritonClient(BaseModel):
request_id (str): the identical id of this round conversation
sequence_start (bool): start flag of a session
sequence_end (bool): end flag of a session
+ skip_special_tokens (bool): Whether or not to remove special tokens in the decoding. Default to be False.
Returns:
tuple(Status, str, int): status, text/chat completion,
generated token number
@@ -133,7 +141,7 @@ class TritonClient(BaseModel):
self.chatbot.cfg = self._update_gen_params(**kwargs)
max_new_tokens = self.chatbot.cfg.max_new_tokens
- logger = get_logger(log_level=self.chatbot.log_level)
+ logger = get_logger('service.ft', log_level=self.chatbot.log_level)
logger.info(f'session {session_id}, request_id {request_id}, '
f'max_out_len {max_new_tokens}')
@@ -152,8 +160,12 @@ class TritonClient(BaseModel):
prompt = self.template_parser(inputs)
status, res, _ = None, '', 0
for status, res, _ in self.chatbot._stream_infer(
- self.chatbot._session, prompt, max_new_tokens, sequence_start,
- sequence_end):
+ self.chatbot._session,
+ prompt,
+ max_new_tokens,
+ sequence_start,
+ sequence_end,
+ skip_special_tokens=skip_special_tokens):
status = self.state_map.get(status)
# The stop symbol also appears in the output of the last STREAM_ING state.
res = filter_suffix(res, self.gen_params.get('stop_words'))
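
The streaming path mirrors the non-stream one: `stream_chat` forwards the same flag to `_stream_infer` and, per the docstring above, yields `(status, text, token_count)` tuples. A hedged sketch of consuming it, reusing `llm` from the previous sketch:

```python
# Assumption: each yielded `text` is the completion generated so far.
final_text = ''
for status, text, _ in llm.stream_chat('Tell me a joke',
                                        skip_special_tokens=False):
    final_text = text
print(final_text)
```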
@@ -223,6 +235,7 @@ class LMDeployPipeline(BaseModel):
def generate(self,
inputs: Union[str, List[str]],
do_preprocess: bool = None,
+ skip_special_tokens: bool = False,
**kwargs):
"""Return the chat completions in non-stream mode.
@@ -230,7 +243,8 @@ class LMDeployPipeline(BaseModel):
inputs (Union[str, List[str]]): input texts to be completed.
do_preprocess (bool): whether pre-process the messages. Default to
True, which means chat_template will be applied.
+ skip_special_tokens (bool): Whether or not to remove special tokens in the decoding. Default to be False.
Returns:
(a list of/batched) text/chat completion
"""
@@ -242,7 +256,8 @@ class LMDeployPipeline(BaseModel):
batched = False
prompt = inputs
gen_params = self.update_gen_params(**kwargs)
- gen_config = GenerationConfig(**gen_params)
+ gen_config = GenerationConfig(
+ skip_special_tokens=skip_special_tokens, **gen_params)
response = self.model.batch_infer(
prompt, gen_config=gen_config, do_preprocess=do_preprocess)
response = [resp.text for resp in response]
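
In the offline pipeline the flag is not threaded through a call chain; it is folded into lmdeploy's `GenerationConfig` together with the wrapper's other sampling parameters and handed to `batch_infer`, as the hunk shows. A wrapper-level sketch (import path and model path are illustrative):

```python
from lagent.llms import LMDeployPipeline  # assumed import path

llm = LMDeployPipeline(path='internlm/internlm2-chat-7b')  # illustrative path
# Internally this becomes GenerationConfig(skip_special_tokens=False, **gen_params)
# and is passed to the pipeline's batch_infer, as in the hunk above.
texts = llm.generate(['Hello there'], skip_special_tokens=False)
print(texts)
```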
@@ -308,6 +323,7 @@ class LMDeployServer(BaseModel):
sequence_start: bool = True,
sequence_end: bool = True,
ignore_eos: bool = False,
+ skip_special_tokens: Optional[bool] = False,
timeout: int = 30,
**kwargs) -> List[str]:
"""Start a new round conversation of a session. Return the chat
@@ -319,6 +335,8 @@ class LMDeployServer(BaseModel):
sequence_start (bool): start flag of a session
sequence_end (bool): end flag of a session
ignore_eos (bool): indicator for ignoring eos
+ skip_special_tokens (bool): Whether or not to remove special tokens
+ in the decoding. Default to be False.
timeout (int): max time to wait for response
Returns:
(a list of/batched) text/chat completion
@@ -342,6 +360,7 @@ class LMDeployServer(BaseModel):
sequence_end=sequence_end,
stream=False,
ignore_eos=ignore_eos,
+ skip_special_tokens=skip_special_tokens,
timeout=timeout,
**gen_params):
resp = [
@@ -361,6 +380,7 @@ class LMDeployServer(BaseModel):
sequence_end: bool = True,
stream: bool = True,
ignore_eos: bool = False,
+ skip_special_tokens: Optional[bool] = False,
timeout: int = 30,
**kwargs):
"""Start a new round conversation of a session. Return the chat
@@ -373,6 +393,8 @@ class LMDeployServer(BaseModel):
sequence_end (bool): end flag of a session
stream (bool): return in a streaming format if enabled
ignore_eos (bool): indicator for ignoring eos
+ skip_special_tokens (bool): Whether or not to remove special tokens
+ in the decoding. Default to be False.
timeout (int): max time to wait for response
Returns:
tuple(Status, str, int): status, text/chat completion,
@@ -394,6 +416,7 @@ class LMDeployServer(BaseModel):
sequence_end=sequence_end,
stream=stream,
ignore_eos=ignore_eos,
+ skip_special_tokens=skip_special_tokens,
timeout=timeout,
**gen_params):
resp += text['choices'][0]['text']
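
Both server-backed entry points simply forward the flag to the api_server client, so special-token handling happens on the serving side. A combined sketch (the import path and constructor arguments are assumptions, not shown in this diff):

```python
from lagent.llms import LMDeployServer  # assumed import path

llm = LMDeployServer(path='internlm/internlm2-chat-7b',  # illustrative
                     server_port=23333)                  # illustrative
# Non-stream: one completion with special tokens removed by the server.
print(llm.generate('Hello', skip_special_tokens=True))
# Stream: the same flag is sent with the streaming request.
final_text = ''
for status, text, _ in llm.stream_chat('Hello', skip_special_tokens=False):
    final_text = text
print(final_text)
```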


@@ -1,5 +1,5 @@
google-search-results
- lmdeploy>=0.2.2
+ lmdeploy>=0.2.3
pillow
python-pptx
timeout_decorator
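
The version floor moves to 0.2.3, presumably to guarantee the `GenerationConfig`/`get_logger` behaviour relied on above. A quick sanity check, assuming lmdeploy exposes `__version__` at the package top level:

```python
# Sketch: verify the installed lmdeploy satisfies the new floor.
import lmdeploy
from packaging.version import Version

assert Version(lmdeploy.__version__) >= Version('0.2.3'), \
    'lagent now requires lmdeploy>=0.2.3'
```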