Fix: model outputs with the same format (#112)
@@ -1,11 +1,11 @@
 exclude: ^(tests/data|scripts|ftdp/protocols|ftdp/template_configs|ftdp/tool_dicts)/
 repos:
   - repo: https://github.com/PyCQA/flake8
-    rev: 5.0.4
+    rev: 7.0.0
     hooks:
       - id: flake8
   - repo: https://github.com/PyCQA/isort
-    rev: 5.11.5
+    rev: 5.13.2
     hooks:
       - id: isort
   - repo: https://github.com/pre-commit/mirrors-yapf
@@ -13,7 +13,7 @@ repos:
     hooks:
       - id: yapf
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.3.0
+    rev: v4.5.0
     hooks:
       - id: trailing-whitespace
       - id: check-yaml
@@ -26,7 +26,7 @@ repos:
       - id: mixed-line-ending
         args: ["--fix=lf"]
   - repo: https://github.com/executablebooks/mdformat
-    rev: 0.7.9
+    rev: 0.7.17
     hooks:
       - id: mdformat
         args: ["--number"]
@@ -35,16 +35,16 @@ repos:
           - mdformat_frontmatter
           - linkify-it-py
   - repo: https://github.com/codespell-project/codespell
-    rev: v2.2.1
+    rev: v2.2.6
     hooks:
       - id: codespell
   - repo: https://github.com/myint/docformatter
-    rev: v1.3.1
+    rev: v1.7.5
     hooks:
       - id: docformatter
         args: ["--in-place", "--wrap-descriptions", "79"]
   - repo: https://github.com/asottile/pyupgrade
-    rev: v3.0.0
+    rev: v3.15.0
     hooks:
       - id: pyupgrade
         args: ["--py36-plus"]

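The hunks above only bump hook versions; the hooks themselves and their arguments are unchanged. As a reminder of what the last hook does, the snippet below is an illustrative Python sketch (not part of the commit) of the kind of rewrite pyupgrade performs when run with --py36-plus as configured above.

# Illustrative sketch, not part of the diff: the kind of rewrite pyupgrade
# applies when invoked with --py36-plus, as configured above.
name = 'world'
old_style = 'hello {}'.format(name)  # pre-3.6 form that pyupgrade flags
new_style = f'hello {name}'          # the 3.6+ form it rewrites to
assert old_style == new_style
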
@@ -5,9 +5,10 @@ import os
 
 import streamlit as st
 
-from lagent.actions import ActionExecutor, ArxivSearch, GoogleScholar, IPythonInterpreter
+from lagent.actions import (ActionExecutor, ArxivSearch, GoogleScholar,
+                            IPythonInterpreter)
 from lagent.agents.internlm2_agent import (INTERPRETER_CN, META_INS, PLUGIN_CN,
-                                           Internlm2Agent, Interlm2Protocol)
+                                           Internlm2Agent, Internlm2Protocol)
 from lagent.llms.lmdepoly_wrapper import LMDeployClient
 from lagent.llms.meta_template import INTERNLM2_META as META
 from lagent.schema import AgentStatusCode
@@ -135,7 +136,7 @@ class StreamlitUI:
         """Initialize the chatbot with the given model and plugin actions."""
         return Internlm2Agent(
             llm=model,
-            protocol=Interlm2Protocol(
+            protocol=Internlm2Protocol(
                 tool=dict(
                     begin='{start_token}{name}\n',
                     start_token='<|action_start|>',
@@ -309,7 +310,8 @@ def main():
                 if isinstance(agent_return.response, dict):
                     action = f"\n\n {agent_return.response['name']}: \n\n"
                     action_input = agent_return.response['parameters']
-                    if agent_return.response['name'] == 'IPythonInterpreter':
+                    if agent_return.response[
+                            'name'] == 'IPythonInterpreter':
                         action_input = action_input['command']
                     response = action + action_input
                 else:
@@ -318,7 +320,8 @@ def main():
                     st.session_state['placeholder'].markdown(
                         st.session_state['temp'])
                 elif agent_return.state == AgentStatusCode.END:
-                    st.session_state['session_history'] += (user_input + agent_return.inner_steps)
+                    st.session_state['session_history'] += (
+                        user_input + agent_return.inner_steps)
                     agent_return = copy.deepcopy(agent_return)
                     agent_return.response = st.session_state['temp']
                     st.session_state['assistant'].append(

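The renamed Internlm2Protocol above is configured with a 'begin' template and an '<|action_start|>' start token. The following is an illustrative, self-contained sketch (not part of the commit, and it only assumes the fields keep the meaning their names suggest) of how those two values compose into the header that precedes a tool call.

# Illustrative sketch, not part of the diff: how the 'begin' template and
# 'start_token' configured above compose into the header placed before a
# plugin or interpreter call in the InternLM2-style prompt format.
tool_cfg = dict(
    begin='{start_token}{name}\n',
    start_token='<|action_start|>',
)


def tool_header(name: str) -> str:
    # e.g. '<|action_start|><|plugin|>\n' for a plugin call
    return tool_cfg['begin'].format(
        start_token=tool_cfg['start_token'], name=name)


print(repr(tool_header('<|plugin|>')))
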
@@ -1,7 +1,7 @@
 import json
 import logging
 from copy import deepcopy
-from typing import Dict, List, Union, Optional
+from typing import Dict, List, Optional, Union
 
 from ilagent.schema import AgentReturn, AgentStatusCode
 
@@ -34,20 +34,20 @@ PLUGIN_CN = (
     '同时注意你可以使用的工具,不要随意捏造!')
 
 
-class Interlm2Protocol:
+class Internlm2Protocol:
 
     def __init__(
             self,
-            meta_prompt: str=META_INS,
-            interpreter_prompt: str=INTERPRETER_CN,
-            plugin_prompt: str=PLUGIN_CN,
-            few_shot: Optional[List]=None,
-            language: Dict=dict(
+            meta_prompt: str = META_INS,
+            interpreter_prompt: str = INTERPRETER_CN,
+            plugin_prompt: str = PLUGIN_CN,
+            few_shot: Optional[List] = None,
+            language: Dict = dict(
                 begin='',
                 end='',
                 belong='assistant',
             ),
-            tool: Dict=dict(
+            tool: Dict = dict(
                 begin='{start_token}{name}\n',
                 start_token='<|action_start|>',
                 name_map=dict(plugin='<|plugin|>', interpreter='<|interpreter|>'),
@@ -199,7 +199,7 @@ class Internlm2Agent(BaseAgent):
                 llm: Union[BaseModel, BaseAPIModel],
                 plugin_executor: ActionExecutor = None,
                 interpreter_executor: ActionExecutor = None,
-                protocol=Interlm2Protocol(),
+                protocol=Internlm2Protocol(),
                 max_turn: int = 3) -> None:
        self.max_turn = max_turn
        self._interpreter_executor = interpreter_executor
@@ -237,8 +237,7 @@ class Internlm2Agent(BaseAgent):
                 try:
                     action = json.loads(action)
                 except Exception as e:
-                    logging.info(
-                        msg=f'Invaild action {e}')
+                    logging.info(msg=f'Invaild action {e}')
                     continue
             elif name == 'interpreter':
                 if self._interpreter_executor:
@@ -265,9 +264,8 @@ class Internlm2Agent(BaseAgent):
                         dict(role='tool', content=action, name=name))
                 inner_history.append(
                     self._protocol.format_response(action_return, name=name))
-            yield agent_return
         agent_return.inner_steps = inner_history[offset:]
-        yield agent_return
+        return agent_return
 
     def stream_chat(self, message: List[dict], **kwargs) -> AgentReturn:
         if isinstance(message, str):
@@ -286,8 +284,7 @@ class Internlm2Agent(BaseAgent):
             interpreter_executor=self._interpreter_executor,
         )
         response = ''
-        for model_state, res, _ in self._llm.stream_chat(
-                prompt, **kwargs):
+        for model_state, res, _ in self._llm.stream_chat(prompt, **kwargs):
             response = res
             if model_state.value < 0:
                 agent_return.state = model_state
@@ -312,8 +309,7 @@ class Internlm2Agent(BaseAgent):
                 try:
                     action = json.loads(action)
                 except Exception as e:
-                    logging.info(
-                        msg=f'Invaild action {e}')
+                    logging.info(msg=f'Invaild action {e}')
                     continue
             elif name == 'interpreter':
                 if self._interpreter_executor:

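Both hunks above collapse the same guard onto one line: a plugin action that is not valid JSON is logged and the turn continues rather than raising. Below is an illustrative, self-contained sketch of that pattern (not code from this commit; the spelling 'Invaild' is kept as it appears in the source).

# Illustrative sketch, not part of the diff: the defensive parse used in both
# chat() and stream_chat() above.
import json
import logging


def parse_action(raw: str):
    try:
        return json.loads(raw)
    except Exception as e:  # same broad guard as in the hunks above
        logging.info(msg=f'Invaild action {e}')  # spelling kept as in source
        return None


print(parse_action('{"command": "print(1)"}'))  # -> {'command': 'print(1)'}
print(parse_action('not json'))                 # -> None (and an info log)
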
@@ -1,4 +1,3 @@
-from abc import abstractclassmethod
 from copy import copy
 from typing import Dict, List, Optional, Tuple, Union
 from warnings import warn
@@ -189,15 +188,13 @@ class BaseModel:
         inputs = self.template_parser(inputs)
         return self.generate(inputs, **gen_params)
 
-    def generate_from_template(
-        self,
-        inputs: Union[List[dict], List[List[dict]]],
-        **gen_params
-    ):
+    def generate_from_template(self, inputs: Union[List[dict],
+                                                   List[List[dict]]],
+                               **gen_params):
         warn(
-            "This function will be deprecated after three months and will be replaced."
-            "Please use `.chat()`",
-            DeprecationWarning, 2)
+            'This function will be deprecated after three months'
+            'and will be replaced. Please use `.chat()`', DeprecationWarning,
+            2)
         return self.chat(inputs, **gen_params)
 
     def stream_chat(self, inputs: List[dict], **gen_params):

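The reflowed method keeps the same behaviour: emit a DeprecationWarning, then forward to .chat(). Here is an illustrative sketch of that shim pattern with a simplified message (not code from this commit).

# Illustrative sketch, not part of the diff: warn once, then delegate to the
# replacement API, as generate_from_template() does above.
from warnings import warn


class Demo:

    def chat(self, inputs, **gen_params):
        return f'chat({inputs!r})'

    def generate_from_template(self, inputs, **gen_params):
        warn('generate_from_template is deprecated, use .chat() instead',
             DeprecationWarning, 2)
        return self.chat(inputs, **gen_params)


print(Demo().generate_from_template([dict(role='user', content='hi')]))
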
@@ -1,13 +1,17 @@
import copy
import warnings
import logging
from typing import Dict, List, Optional, Union
from dataclasses import asdict
import sys
import traceback
import warnings
from typing import Dict, List, Optional

from lagent.schema import AgentStatusCode
from .base_llm import BaseModel

logger = logging.getLogger(__name__)

logger = logging.getLogger(__name__)


class HFTransformer(BaseModel):
    """Model wrapper around HuggingFace general models.
@@ -103,147 +107,177 @@ class HFTransformer(BaseModel):
do_sample=True,
**kwargs,
):
import torch
from torch import nn
with torch.no_grad():
batched = True
if isinstance(inputs, str):
inputs = [inputs]
batched = False
# import pdb; pdb.set_trace()
inputs = self.tokenizer(
inputs, padding=True, return_tensors='pt', return_length=True)
input_length = inputs['length']
for k, v in inputs.items():
inputs[k] = v.cuda()
input_ids = inputs['input_ids']
attention_mask = inputs['attention_mask']
batch_size, input_ids_seq_length = input_ids.shape[
0], input_ids.shape[-1] # noqa: F841 # pylint: disable=W0612
generation_config = self.model.generation_config
generation_config = copy.deepcopy(generation_config)
new_gen_params = self.update_gen_params(**kwargs)
generation_config.update(**new_gen_params)
generation_config.update(**kwargs)
model_kwargs = generation_config.to_dict()
model_kwargs['attention_mask'] = attention_mask
_, eos_token_id = ( # noqa: F841 # pylint: disable=W0612
generation_config.bos_token_id,
generation_config.eos_token_id,
)
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
if self.additional_eos_token_id is not None:
eos_token_id.extend(self.additional_eos_token_id)
eos_token_id_tensor = torch.tensor(eos_token_id).to(
input_ids.device) if eos_token_id is not None else None
has_default_max_length = (
kwargs.get('max_length') is None
and generation_config.max_length is not None)
if has_default_max_length and generation_config.max_new_tokens is None:
warnings.warn(
f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
'This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we'
' recommend using `max_new_tokens` to control the maximum length of the generation.',
UserWarning,
try:
import torch
from torch import nn
with torch.no_grad():
batched = True
if isinstance(inputs, str):
inputs = [inputs]
batched = False
# import pdb; pdb.set_trace()
inputs = self.tokenizer(
inputs,
padding=True,
return_tensors='pt',
return_length=True)
input_length = inputs['length']
for k, v in inputs.items():
inputs[k] = v.cuda()
input_ids = inputs['input_ids']
attention_mask = inputs['attention_mask']
batch_size = input_ids.shape[0]
input_ids_seq_length = input_ids.shape[-1]
generation_config = self.model.generation_config
generation_config = copy.deepcopy(generation_config)
new_gen_params = self.update_gen_params(**kwargs)
generation_config.update(**new_gen_params)
generation_config.update(**kwargs)
model_kwargs = generation_config.to_dict()
model_kwargs['attention_mask'] = attention_mask
_, eos_token_id = ( # noqa: F841 # pylint: disable=W0612
generation_config.bos_token_id,
generation_config.eos_token_id,
)
elif generation_config.max_new_tokens is not None:
generation_config.max_length = (
generation_config.max_new_tokens + input_ids_seq_length)
if not has_default_max_length:
logger.warn( # pylint: disable=W4902
f'Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(='
f'{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. '
'Please refer to the documentation for more information. '
'(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)',
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
if self.additional_eos_token_id is not None:
eos_token_id.extend(self.additional_eos_token_id)
eos_token_id_tensor = torch.tensor(eos_token_id).to(
input_ids.device) if eos_token_id is not None else None
has_default_max_length = (
kwargs.get('max_length') is None
and generation_config.max_length is not None)
if (has_default_max_length
and generation_config.max_new_tokens is None):
warnings.warn(
"Using `max_length`'s default"
f'({generation_config.max_length})'
'to control the generation length. '
'This behaviour is deprecated and will be removed'
' from the config in v5 of Transformers -- we'
' recommend using `max_new_tokens` to control the'
' maximum length of the generation.',
UserWarning,
)
elif generation_config.max_new_tokens is not None:
generation_config.max_length = (
generation_config.max_new_tokens +
input_ids_seq_length)
if not has_default_max_length:
logger.warn( # pylint: disable=W4902
'Both `max_new_tokens`'
f'(={generation_config.max_new_tokens})'
'and `max_length`'
f'(={generation_config.max_length})'
' seem to have been set.`max_new_tokens`'
' will take precedence. Please refer to'
' the documentation for more information. '
'(https://huggingface.co/docs/transformers/main/en'
'/main_classes/text_generation)',
UserWarning,
)

if input_ids_seq_length >= generation_config.max_length:
input_ids_string = 'input_ids'
logger.warning(
f'Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to'
f' {generation_config.max_length}. This can lead to unexpected behavior. You should consider'
' increasing `max_new_tokens`.')
if input_ids_seq_length >= generation_config.max_length:
input_ids_string = 'input_ids'
logger.warning(
f'Input length of {input_ids_string}'
f' is {input_ids_seq_length},'
' but `max_length` is set to'
f' {generation_config.max_length}.'
'This can lead to unexpected behavior.'
' You should consider increasing `max_new_tokens`.')

# 2. Set generation parameters if not already defined
logits_processor = self.logits_processor
stopping_criteria = self.stopping_criteria
# 2. Set generation parameters if not already defined
logits_processor = self.logits_processor
stopping_criteria = self.stopping_criteria

logits_processor = self.model._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_seq_length,
encoder_input_ids=input_ids,
prefix_allowed_tokens_fn=self.prefix_allowed_tokens_fn,
logits_processor=logits_processor,
)

stopping_criteria = self.model._get_stopping_criteria(
generation_config=generation_config,
stopping_criteria=stopping_criteria)
logits_warper = self.model._get_logits_warper(generation_config)

unfinished_sequences = input_ids.new(batch_size).fill_(1)
scores = None
while True:
model_inputs = self.model.prepare_inputs_for_generation(
input_ids, **model_kwargs)
# forward pass to get next token
outputs = self.model(
**model_inputs,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
logits_processor = self.model._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_seq_length,
encoder_input_ids=input_ids,
prefix_allowed_tokens_fn=self.prefix_allowed_tokens_fn,
logits_processor=logits_processor,
)

next_token_logits = outputs.logits[:, -1, :]
stopping_criteria = self.model._get_stopping_criteria(
generation_config=generation_config,
stopping_criteria=stopping_criteria)
logits_warper = self.model._get_logits_warper(
generation_config)

# pre-process distribution
next_token_scores = logits_processor(input_ids,
next_token_logits)
next_token_scores = logits_warper(input_ids, next_token_scores)
unfinished_sequences = input_ids.new(batch_size).fill_(1)
scores = None
while True:
model_inputs = self.model.prepare_inputs_for_generation(
input_ids, **model_kwargs)
# forward pass to get next token
outputs = self.model(
**model_inputs,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
)

# sample
probs = nn.functional.softmax(next_token_scores, dim=-1)
if do_sample:
next_tokens = torch.multinomial(
probs, num_samples=1).squeeze(1)
else:
next_tokens = torch.argmax(probs, dim=-1)
next_token_logits = outputs.logits[:, -1, :]

# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]],
dim=-1)
model_kwargs = self.model._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=False)
unfinished_sequences = unfinished_sequences.mul(
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(
eos_token_id_tensor.unsqueeze(1)).prod(dim=0))
# output_token_ids = input_ids.cpu()[:, input_length:].tolist()
output_token_ids = input_ids.cpu().tolist()
for i in range(len(output_token_ids)):
output_token_ids[i] = output_token_ids[i][:][
input_length[i]:]
# Find the first occurrence of an EOS token in the sequence
first_eos_idx = next(
(idx
for idx, token_id in enumerate(output_token_ids[i])
if token_id in eos_token_id), None)
# If an EOS token is found, only the previous part of it is retained
if first_eos_idx is not None:
output_token_ids[i] = output_token_ids[
i][:first_eos_idx]
# pre-process distribution
next_token_scores = logits_processor(
input_ids, next_token_logits)
next_token_scores = logits_warper(input_ids,
next_token_scores)

response = self.tokenizer.batch_decode(output_token_ids)
# print(response)
if not batched:
yield response[0]
else:
yield response
# stop when each sentence is finished, or if we exceed the maximum length
if (unfinished_sequences.max() == 0
or stopping_criteria(input_ids, scores)):
break
# sample
probs = nn.functional.softmax(next_token_scores, dim=-1)
if do_sample:
next_tokens = torch.multinomial(
probs, num_samples=1).squeeze(1)
else:
next_tokens = torch.argmax(probs, dim=-1)

# update generated ids, model inputs,
# and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]],
dim=-1)
model_kwargs = self.model._update_model_kwargs_for_generation( # noqa: E501
outputs,
model_kwargs,
is_encoder_decoder=False)
unfinished_sequences = unfinished_sequences.mul(
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(
eos_token_id_tensor.unsqueeze(1)).prod(dim=0))
output_token_ids = input_ids.cpu().tolist()
for i in range(len(output_token_ids)):
output_token_ids[i] = output_token_ids[i][:][
input_length[i]:]
# Find the first occurrence of
# an EOS token in the sequence
first_eos_idx = next((
idx
for idx, token_id in enumerate(output_token_ids[i])
if token_id in eos_token_id), None)
# If an EOS token is found, only the previous
# part of it is retained
if first_eos_idx is not None:
output_token_ids[i] = output_token_ids[
i][:first_eos_idx]

response = self.tokenizer.batch_decode(output_token_ids)
# print(response)
if not batched:
response = response[0]
yield AgentStatusCode.STREAM_ING, response, None
# stop when each sentence is finished,
# or if we exceed the maximum length
if (unfinished_sequences.max() == 0
or stopping_criteria(input_ids, scores)):
break
yield AgentStatusCode.END, response, None
except Exception:
response = ''.join(traceback.format_exception(*sys.exc_info()))
if batched:
response = [response]
yield AgentStatusCode.SERVER_ERR, response, None


class HFTransformerCasualLM(HFTransformer):

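The reworked generator above wraps the whole decode loop in try/except and always yields (status, response, extra) triples: AgentStatusCode.STREAM_ING while tokens arrive, AgentStatusCode.END on completion, and AgentStatusCode.SERVER_ERR with a formatted traceback on failure. Below is an illustrative, self-contained sketch of that output contract (not code from this commit; the enum is a stand-in with assumed values, the real one lives in lagent.schema).

# Illustrative sketch, not part of the diff: the (status, text, extra) triple
# contract used by the reworked streaming code above. AgentStatusCode is a
# stand-in with assumed values; the real enum is imported from lagent.schema.
import sys
import traceback
from enum import IntEnum


class AgentStatusCode(IntEnum):
    SERVER_ERR = -1  # assumed value; errors are negative (model_state.value < 0)
    END = 0          # assumed value
    STREAM_ING = 1   # assumed value


def stream(chunks):
    response = ''
    try:
        for chunk in chunks:
            response += chunk
            yield AgentStatusCode.STREAM_ING, response, None
        yield AgentStatusCode.END, response, None
    except Exception:
        response = ''.join(traceback.format_exception(*sys.exc_info()))
        yield AgentStatusCode.SERVER_ERR, response, None


for state, text, _ in stream(['Hel', 'lo', ' world']):
    print(state.name, repr(text))
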
@@ -145,7 +145,7 @@ class TritonClient(BaseModel):
         elif self.chatbot._session.status == 0:
             logger.error(f'session {session_id} has been ended. Please set '
                          f'`sequence_start` be True if you want to restart it')
-            return ''
+            return AgentStatusCode.SESSION_CLOSED, '', 0
 
         self.chatbot._session.status = 1
         self.chatbot._session.request_id = request_id
@@ -171,7 +171,7 @@ class TritonClient(BaseModel):
                     self.chatbot._session.prompt + self.chatbot._session.response)
                 yield self.state_map.get(status), res, _
             else:
-                return ''
+                return self.state_map.get(status), res, _
 
     def _update_gen_params(self, **kwargs):
         import mmengine
@@ -195,7 +195,7 @@ class LMDeployPipeline(BaseModel):
         path (str): The path to the model.
             It could be one of the following options:
             - i) A local directory path of a turbomind model which is
-                converted by `lmdeploy convert` command or download
+                converted by `lmdeploy convert` command or download
                 from ii) and iii).
             - ii) The model_id of a lmdeploy-quantized model hosted
                 inside a model repo on huggingface.co, such as

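With the TritonClient changes above, the early-exit path returns the same (state, response, extra) triple as the streaming path (AgentStatusCode.SESSION_CLOSED, '', 0 instead of a bare ''), and streamed results go through self.state_map.get(status). Here is an illustrative sketch of that mapping step (not code from this commit; the map keys and enum values are assumptions).

# Illustrative sketch, not part of the diff: mapping a backend-specific status
# onto a shared AgentStatusCode so every backend yields the same triple shape.
# Enum values and map keys are assumptions for illustration only.
from enum import IntEnum


class AgentStatusCode(IntEnum):
    SESSION_CLOSED = -2  # assumed value
    END = 0              # assumed value
    STREAM_ING = 1       # assumed value


state_map = {
    'running': AgentStatusCode.STREAM_ING,
    'finished': AgentStatusCode.END,
}


def emit(status, res, extra=None):
    # mirrors `yield self.state_map.get(status), res, _` in the hunk above
    return state_map.get(status), res, extra


print(emit('running', 'partial text'))
print(emit('finished', 'full text'))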