Feat: no_skip_special_token (#148)
* Feat: no_skip_special_token
* fix: get_logger of lmdeploy
* update lmdeploy requirement
@@ -48,6 +48,7 @@ class TritonClient(BaseModel):
              request_id: str = '',
              sequence_start: bool = True,
              sequence_end: bool = True,
+             skip_special_tokens: bool = False,
              **kwargs):
         """Start a new round conversation of a session. Return the chat
         completions in non-stream mode.
@@ -58,7 +59,8 @@ class TritonClient(BaseModel):
             request_id (str): the identical id of this round conversation
             sequence_start (bool): start flag of a session
             sequence_end (bool): end flag of a session
-
+            skip_special_tokens (bool): Whether or not to remove special tokens
+                in the decoding. Default to be False.
         Returns:
             (a list of/batched) text/chat completion
         """
@@ -73,7 +75,7 @@ class TritonClient(BaseModel):
         self.chatbot.cfg = self._update_gen_params(**kwargs)
         max_new_tokens = self.chatbot.cfg.max_new_tokens
 
-        logger = get_logger(log_level=self.chatbot.log_level)
+        logger = get_logger('service.ft', log_level=self.chatbot.log_level)
         logger.info(f'session {session_id}, request_id {request_id}, '
                     f'max_out_len {max_new_tokens}')
 
@@ -91,8 +93,12 @@ class TritonClient(BaseModel):
 
         status, res, _ = None, '', 0
         for status, res, _ in self.chatbot._stream_infer(
-                self.chatbot._session, prompt, max_new_tokens, sequence_start,
-                sequence_end):
+                self.chatbot._session,
+                prompt,
+                max_new_tokens,
+                sequence_start,
+                sequence_end,
+                skip_special_tokens=skip_special_tokens):
             status = self.state_map.get(status)
             if status < ModelStatusCode.END:
                 return ''
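With the flag threaded through, a round-trip call can opt in or out of special-token stripping. A minimal sketch, assuming a reachable Triton endpoint; the address and constructor arguments here are placeholders, not part of this diff:

    from lagent.llms import TritonClient

    # Hypothetical endpoint; point this at your own Triton deployment.
    chatbot = TritonClient(tritonserver_addr='0.0.0.0:33337')
    # The default is now skip_special_tokens=False, so markers such as <eoa>
    # survive decoding; pass True to strip them as before.
    print(chatbot.chat('Hello', session_id=0, skip_special_tokens=True))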
@@ -111,6 +117,7 @@ class TritonClient(BaseModel):
                     request_id: str = '',
                     sequence_start: bool = True,
                     sequence_end: bool = True,
+                    skip_special_tokens: bool = False,
                     **kwargs):
         """Start a new round conversation of a session. Return the chat
         completions in stream mode.
@@ -121,7 +128,8 @@ class TritonClient(BaseModel):
             request_id (str): the identical id of this round conversation
             sequence_start (bool): start flag of a session
             sequence_end (bool): end flag of a session
-
+            skip_special_tokens (bool): Whether or not to remove special tokens
+                in the decoding. Default to be False.
         Returns:
             tuple(Status, str, int): status, text/chat completion,
                 generated token number
@@ -133,7 +141,7 @@ class TritonClient(BaseModel):
         self.chatbot.cfg = self._update_gen_params(**kwargs)
         max_new_tokens = self.chatbot.cfg.max_new_tokens
 
-        logger = get_logger(log_level=self.chatbot.log_level)
+        logger = get_logger('service.ft', log_level=self.chatbot.log_level)
         logger.info(f'session {session_id}, request_id {request_id}, '
                     f'max_out_len {max_new_tokens}')
 
@@ -152,8 +160,12 @@ class TritonClient(BaseModel):
         prompt = self.template_parser(inputs)
         status, res, _ = None, '', 0
         for status, res, _ in self.chatbot._stream_infer(
-                self.chatbot._session, prompt, max_new_tokens, sequence_start,
-                sequence_end):
+                self.chatbot._session,
+                prompt,
+                max_new_tokens,
+                sequence_start,
+                sequence_end,
+                skip_special_tokens=skip_special_tokens):
             status = self.state_map.get(status)
             # The stop symbol also appears in the output of the last STREAM_ING state.
             res = filter_suffix(res, self.gen_params.get('stop_words'))
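The streaming path takes the same flag and yields partial results as they decode. A sketch reusing the hypothetical client from above; the three-value unpacking mirrors the (status, res, _) loop in this hunk, and the message format is an assumption based on lagent's chat-template conventions:

    # Stream a reply; chunk is the completion accumulated so far.
    for status, chunk, _ in chatbot.stream_chat(
            [{'role': 'user', 'content': 'Tell me a joke'}],
            skip_special_tokens=False):
        print(chunk, end='', flush=True)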
@@ -223,6 +235,7 @@ class LMDeployPipeline(BaseModel):
     def generate(self,
                  inputs: Union[str, List[str]],
                  do_preprocess: bool = None,
+                 skip_special_tokens: bool = False,
                  **kwargs):
         """Return the chat completions in non-stream mode.
 
@@ -230,7 +243,8 @@ class LMDeployPipeline(BaseModel):
             inputs (Union[str, List[str]]): input texts to be completed.
             do_preprocess (bool): whether pre-process the messages. Default to
                 True, which means chat_template will be applied.
-
+            skip_special_tokens (bool): Whether or not to remove special tokens
+                in the decoding. Default to be False.
         Returns:
             (a list of/batched) text/chat completion
         """
@@ -242,7 +256,8 @@ class LMDeployPipeline(BaseModel):
             batched = False
             prompt = inputs
         gen_params = self.update_gen_params(**kwargs)
-        gen_config = GenerationConfig(**gen_params)
+        gen_config = GenerationConfig(
+            skip_special_tokens=skip_special_tokens, **gen_params)
         response = self.model.batch_infer(
             prompt, gen_config=gen_config, do_preprocess=do_preprocess)
         response = [resp.text for resp in response]
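Here the flag is not handled in the wrapper at all: it is forwarded into lmdeploy's GenerationConfig, which controls decoding. A minimal usage sketch; the model path is a placeholder and the constructor argument is an assumption, not part of this diff:

    from lagent.llms import LMDeployPipeline

    # Hypothetical model path; any lmdeploy-supported chat model works.
    llm = LMDeployPipeline(path='internlm/internlm2-chat-7b')
    # Keeping special tokens (the new default) lets downstream parsers see
    # tool-call markers instead of silently losing them.
    text = llm.generate('Hello', skip_special_tokens=False)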
@@ -308,6 +323,7 @@ class LMDeployServer(BaseModel):
                  sequence_start: bool = True,
                  sequence_end: bool = True,
                  ignore_eos: bool = False,
+                 skip_special_tokens: Optional[bool] = False,
                  timeout: int = 30,
                  **kwargs) -> List[str]:
         """Start a new round conversation of a session. Return the chat
@@ -319,6 +335,8 @@ class LMDeployServer(BaseModel):
             sequence_start (bool): start flag of a session
             sequence_end (bool): end flag of a session
             ignore_eos (bool): indicator for ignoring eos
+            skip_special_tokens (bool): Whether or not to remove special tokens
+                in the decoding. Default to be False.
             timeout (int): max time to wait for response
         Returns:
             (a list of/batched) text/chat completion
@@ -342,6 +360,7 @@ class LMDeployServer(BaseModel):
                 sequence_end=sequence_end,
                 stream=False,
                 ignore_eos=ignore_eos,
+                skip_special_tokens=skip_special_tokens,
                 timeout=timeout,
                 **gen_params):
             resp = [
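For the server-backed wrapper, the flag is simply forwarded to the completions endpoint. A blocking-call sketch under the same caveats as above (placeholder model path, constructor argument assumed):

    from lagent.llms import LMDeployServer

    # Hypothetical path; LMDeployServer brings up an api_server internally.
    llm = LMDeployServer(path='internlm/internlm2-chat-7b')
    resp = llm.generate('Hello', skip_special_tokens=False, timeout=60)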
@@ -361,6 +380,7 @@ class LMDeployServer(BaseModel):
                     sequence_end: bool = True,
                     stream: bool = True,
                     ignore_eos: bool = False,
+                    skip_special_tokens: Optional[bool] = False,
                     timeout: int = 30,
                     **kwargs):
         """Start a new round conversation of a session. Return the chat
@@ -373,6 +393,8 @@ class LMDeployServer(BaseModel):
             sequence_end (bool): end flag of a session
             stream (bool): return in a streaming format if enabled
             ignore_eos (bool): indicator for ignoring eos
+            skip_special_tokens (bool): Whether or not to remove special tokens
+                in the decoding. Default to be False.
             timeout (int): max time to wait for response
         Returns:
             tuple(Status, str, int): status, text/chat completion,
@@ -394,6 +416,7 @@ class LMDeployServer(BaseModel):
                 sequence_end=sequence_end,
                 stream=stream,
                 ignore_eos=ignore_eos,
+                skip_special_tokens=skip_special_tokens,
                 timeout=timeout,
                 **gen_params):
             resp += text['choices'][0]['text']
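And the streaming variant, which accumulates text from each chunk's choices as in the hunk above. A sketch reusing the server wrapper from the previous example; the yield signature and message format are assumptions based on lagent's conventions:

    # Hypothetical streaming loop over the server wrapper.
    for status, text, _ in llm.stream_chat(
            [{'role': 'user', 'content': 'Hello'}],
            skip_special_tokens=False):
        print(text, end='', flush=True)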
@@ -1,5 +1,5 @@
 google-search-results
-lmdeploy>=0.2.2
+lmdeploy>=0.2.3
 pillow
 python-pptx
 timeout_decorator
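The requirements floor moves to 0.2.3, presumably because skip_special_tokens support on lmdeploy's generation and completion interfaces is only available from that release; upgrading with pip install -U "lmdeploy>=0.2.3" keeps the wrapper and backend in step.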