Compare commits

4 Commits

| Author | SHA1 | Date |
|---|---|---|
|  | c04da092ef |  |
|  | 352fbe8dc8 |  |
|  | 5087e59e81 |  |
|  | 4713c9938d |  |
```diff
@@ -1,6 +1,6 @@
 from .base_api import BaseAPIModel
 from .base_llm import BaseModel
-from .huggingface import HFTransformer, HFTransformerCasualLM
+from .huggingface import HFTransformer, HFTransformerCasualLM, HFTransformerChat
 from .lmdepoly_wrapper import LMDeployClient, LMDeployPipeline, LMDeployServer
 from .meta_template import INTERNLM2_META
 from .openai import GPTAPI
@@ -8,5 +8,5 @@ from .openai import GPTAPI
 __all__ = [
     'BaseModel', 'BaseAPIModel', 'GPTAPI', 'LMDeployClient',
     'LMDeployPipeline', 'LMDeployServer', 'HFTransformer',
-    'HFTransformerCasualLM', 'INTERNLM2_META'
+    'HFTransformerCasualLM', 'INTERNLM2_META', 'HFTransformerChat'
 ]
```
```diff
@@ -118,7 +118,7 @@ class APITemplateParser:
         return res

     def _role2api_role(self, role_prompt: Dict) -> Tuple[str, bool]:
-        merged_prompt = self.roles[self.roles[role_prompt['role']]]
+        merged_prompt = self.roles[role_prompt['role']]
         if merged_prompt.get('fallback_role'):
             merged_prompt = self.roles[self.roles[
                 merged_prompt['fallback_role']]]
```
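For context on the `_role2api_role` fix: the `roles` table maps a role name to its config dict (the fallback-role branch reads keys from the looked-up value), so the old double lookup ends up indexing the table with a dict. A minimal sketch, using a hypothetical role table that is not part of this diff:

```python
# Hypothetical role table shaped like an API meta template: name -> config dict.
roles = {'user': {'api_role': 'user'}, 'assistant': {'api_role': 'assistant'}}
role_prompt = {'role': 'user', 'content': 'hi'}

try:
    merged = roles[roles[role_prompt['role']]]  # old code: inner lookup returns a dict,
except TypeError as err:                        # which is then used as a key
    print(err)                                  # "unhashable type: 'dict'"

merged = roles[role_prompt['role']]             # fixed single lookup
print(merged)                                   # {'api_role': 'user'}
```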
```diff
@@ -3,6 +3,7 @@ import logging
 from typing import Dict, List, Optional, Union

 from lagent.schema import ModelStatusCode
+from .base_api import APITemplateParser
 from .base_llm import BaseModel

 logger = logging.getLogger(__name__)
```
```diff
@@ -37,12 +38,20 @@ class HFTransformer(BaseModel):
                  tokenizer_only: bool = False,
                  model_kwargs: dict = dict(device_map='auto'),
                  meta_template: Optional[Dict] = None,
+                 stop_words_id: Union[List[int], int] = None,
                  **kwargs):
         super().__init__(
             path=path,
             tokenizer_only=tokenizer_only,
             meta_template=meta_template,
             **kwargs)
+        if isinstance(stop_words_id, int):
+            stop_words_id = [stop_words_id]
+        self.gen_params.update(stop_words_id=stop_words_id)
+        if self.gen_params['stop_words'] is not None and \
+                self.gen_params['stop_words_id'] is not None:
+            logger.warning("Both stop_words and stop_words_id are specified,"
+                           "only stop_words_id will be used.")

         self._load_tokenizer(
             path=path,
```
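A usage sketch of the new constructor argument (the checkpoint name and token id below are hypothetical, not part of this change): `stop_words_id` accepts a single id or a list, is stored in `gen_params`, and takes precedence over `stop_words` when both are supplied, which triggers the warning added above.

```python
from lagent.llms import HFTransformer

llm = HFTransformer(
    path='internlm/internlm2-chat-7b',  # hypothetical HF checkpoint
    stop_words=['<|im_end|>'],          # string stop words, as before
    stop_words_id=92542,                # hypothetical token id; a list also works
)
# Both are set, so the warning fires and only stop_words_id is used.
```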
```diff
@@ -57,7 +66,9 @@ class HFTransformer(BaseModel):
         self.prefix_allowed_tokens_fn = None

         stop_words_id = []
-        if self.gen_params.get('stop_words'):
+        if self.gen_params.get('stop_words_id'):
+            stop_words_id = self.gen_params.get('stop_words_id')
+        elif self.gen_params.get('stop_words'):
             for sw in self.gen_params.get('stop_words'):
                 stop_words_id.append(self.tokenizer(sw)['input_ids'][-1])
         self.additional_eos_token_id = stop_words_id
```
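For illustration only (the tokenizer checkpoint is an assumption), the fallback branch above maps each stop string to the last id of its encoding; the same mapping as a standalone snippet:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    'internlm/internlm2-chat-7b', trust_remote_code=True)  # assumed checkpoint

stop_words = ['<|im_end|>']
# The last input id of each encoded stop word becomes an extra EOS id.
stop_words_id = [tokenizer(sw)['input_ids'][-1] for sw in stop_words]
print(stop_words_id)  # later exposed as additional_eos_token_id
```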
```diff
@@ -69,9 +80,28 @@ class HFTransformer(BaseModel):
             tokenizer_path if tokenizer_path else path,
             trust_remote_code=True,
             **tokenizer_kwargs)

         if self.tokenizer.pad_token_id is None:
-            self.tokenizer.pad_token = self.tokenizer.eos_token
+            if self.tokenizer.eos_token is not None:
+                logger.warning(
+                    f'Using eos_token_id {self.tokenizer.eos_token} '
+                    'as pad_token_id.')
+                self.tokenizer.pad_token = self.tokenizer.eos_token
+            else:
+                from transformers.generation import GenerationConfig
+                self.gcfg = GenerationConfig.from_pretrained(path)
+
+                if self.gcfg.pad_token_id is not None:
+                    logger.warning(
+                        f'Using pad_token_id {self.gcfg.pad_token_id} '
+                        'as pad_token_id.')
+                    self.tokenizer.pad_token_id = self.gcfg.pad_token_id
+                else:
+                    raise ValueError(
+                        'pad_token_id is not set for this tokenizer. Try to '
+                        'set pad_token_id via passing '
+                        '`pad_token_id={PAD_TOKEN_ID}` in model_cfg.')

     def _load_model(self, path: str, model_kwargs: dict):
         import torch
         from transformers import AutoModel
```
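A standalone sketch of the same pad-token fallback order (the checkpoint name is an assumption): reuse the EOS token if the tokenizer has one, otherwise read `pad_token_id` from the model's `GenerationConfig`, and fail loudly if neither source defines it.

```python
from transformers import AutoTokenizer, GenerationConfig

path = 'internlm/internlm2-chat-7b'  # hypothetical checkpoint
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)

if tokenizer.pad_token_id is None:
    if tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token        # 1. reuse EOS as padding
    else:
        gcfg = GenerationConfig.from_pretrained(path)    # 2. generation config
        if gcfg.pad_token_id is not None:
            tokenizer.pad_token_id = gcfg.pad_token_id
        else:
            raise ValueError('pad_token_id is not set for this tokenizer.')  # 3. give up
```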
```diff
@@ -127,7 +157,6 @@ class HFTransformer(BaseModel):
         if isinstance(inputs, str):
             inputs = [inputs]
             batched = False
-        # import pdb; pdb.set_trace()
         inputs = self.tokenizer(
             inputs, padding=True, return_tensors='pt', return_length=True)
         input_length = inputs['length']
```
```diff
@@ -148,6 +177,11 @@ class HFTransformer(BaseModel):
             generation_config.bos_token_id,
             generation_config.eos_token_id,
         )
+        if eos_token_id is None:
+            if self.gcfg.eos_token_id is not None:
+                eos_token_id = self.gcfg.eos_token_id
+            else:
+                eos_token_id = []
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]
         if self.additional_eos_token_id is not None:
```
```diff
@@ -267,3 +301,38 @@ class HFTransformerCasualLM(HFTransformer):
         self.model = AutoModelForCausalLM.from_pretrained(
             path, trust_remote_code=True, **model_kwargs)
         self.model.eval()
+
+
+class HFTransformerChat(HFTransformerCasualLM):
+
+    def __init__(self,
+                 template_parser=APITemplateParser,
+                 **kwargs):
+        super().__init__(template_parser=template_parser, **kwargs)
+
+    def chat(self, inputs: Union[List[dict], List[List[dict]]], do_sample: bool = True, **kwargs):
+        """Return the chat completions in stream mode.
+
+        Args:
+            inputs (Union[List[dict], List[List[dict]]]): input messages to be completed.
+            do_sample (bool): do sampling if enabled
+        Returns:
+            the text/chat completion
+        """
+        # handle batch inference with vanilla for loop
+        if isinstance(inputs[0], list):
+            resps = []
+            for input in inputs:
+                resps.append(self.chat(input, do_sample, **kwargs))
+            return resps
+        prompt = self.template_parser(inputs)
+        query = prompt[-1]['content']
+        history = prompt[:-1]
+        try:
+            response, history = self.model.chat(self.tokenizer,
+                                                query,
+                                                history=history)
+        except Exception as e:
+            # handle over-length input error
+            logger.warning(str(e))
+            response = ""
+        return response
```
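A minimal usage sketch of the new class (the checkpoint and meta template choice are assumptions, not part of this diff): `HFTransformerChat` formats OpenAI-style message dicts through `APITemplateParser` and delegates to the checkpoint's own `model.chat()` method, so it only works with models that expose one.

```python
from lagent.llms import HFTransformerChat, INTERNLM2_META

llm = HFTransformerChat(
    path='internlm/internlm2-chat-7b',  # hypothetical checkpoint exposing .chat()
    meta_template=INTERNLM2_META,       # assumed meta template choice
)

messages = [{'role': 'user', 'content': 'Briefly introduce yourself.'}]
print(llm.chat(messages))  # single dialog; a list of dialogs is also accepted
```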