remove exception catch in hf (#135)

Author: liujiangning30
Date: 2024-01-31 21:15:25 +08:00
Committed by: GitHub
Parent: 94959c16da
Commit: baeed6ed88

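With the blanket try/except removed, errors raised while generating (for example CUDA out-of-memory or tokenizer failures) propagate to the caller as ordinary exceptions instead of being yielded as a ModelStatusCode.SERVER_ERR item carrying a formatted traceback. Callers that relied on that status code may need their own handling. A minimal caller-side sketch follows; the import paths, the constructor argument, and the method name `stream_generate` are assumptions for illustration, not taken from this diff:

    # Hypothetical caller-side handling after this change (sketch only;
    # import paths, constructor args, and the method name are assumptions).
    from lagent.llms import HFTransformer
    from lagent.schema import ModelStatusCode

    model = HFTransformer(path='some/hf-model')  # assumed constructor argument
    try:
        for status, text, _ in model.stream_generate('Hello'):
            if status == ModelStatusCode.STREAM_ING:
                print(text)
    except Exception as exc:
        # Errors are no longer reported as ModelStatusCode.SERVER_ERR items.
        print(f'generation failed: {exc}')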

@@ -1,7 +1,5 @@
import copy
import logging
-import sys
-import traceback
import warnings
from typing import Dict, List, Optional, Union
@@ -125,177 +123,165 @@ class HFTransformer(BaseModel):
            tuple(Status, str, int): status, text/chat completion,
                generated token number
        """
-        try:
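        # The whole decode path runs without gradient tracking: tokenize the
        # prompt(s), pad them into a batch, and move the tensors to the GPU.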
        import torch
        from torch import nn
        with torch.no_grad():
            batched = True
            if isinstance(inputs, str):
                inputs = [inputs]
                batched = False
            # import pdb; pdb.set_trace()
            inputs = self.tokenizer(
                inputs, padding=True, return_tensors='pt', return_length=True)
            input_length = inputs['length']
            for k, v in inputs.items():
                inputs[k] = v.cuda()
            input_ids = inputs['input_ids']
            attention_mask = inputs['attention_mask']
            batch_size = input_ids.shape[0]
            input_ids_seq_length = input_ids.shape[-1]
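            # Merge the model's default generation config with per-call
            # parameters; values passed via kwargs take precedence.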
            generation_config = self.model.generation_config
            generation_config = copy.deepcopy(generation_config)
            new_gen_params = self.update_gen_params(**kwargs)
            generation_config.update(**new_gen_params)
            generation_config.update(**kwargs)
            model_kwargs = generation_config.to_dict()
            model_kwargs['attention_mask'] = attention_mask
            _, eos_token_id = (  # noqa: F841 # pylint: disable=W0612
                generation_config.bos_token_id,
                generation_config.eos_token_id,
            )
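            # Collect every EOS token id (plus any additional ids configured on
            # the wrapper) so finished sequences can be detected while decoding.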
            if isinstance(eos_token_id, int):
                eos_token_id = [eos_token_id]
            if self.additional_eos_token_id is not None:
                eos_token_id.extend(self.additional_eos_token_id)
            eos_token_id_tensor = torch.tensor(eos_token_id).to(
                input_ids.device) if eos_token_id is not None else None
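            # Reconcile `max_length` and `max_new_tokens`, warning when the
            # deprecated default or conflicting limits are used.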
            has_default_max_length = (
                kwargs.get('max_length') is None
                and generation_config.max_length is not None)
            if (has_default_max_length
                    and generation_config.max_new_tokens is None):
                warnings.warn(
                    "Using `max_length`'s default "
                    f'({generation_config.max_length}) '
                    'to control the generation length. '
                    'This behaviour is deprecated and will be removed '
                    'from the config in v5 of Transformers -- we '
                    'recommend using `max_new_tokens` to control the '
                    'maximum length of the generation.',
                    UserWarning,
                )
            elif generation_config.max_new_tokens is not None:
                generation_config.max_length = (
                    generation_config.max_new_tokens + input_ids_seq_length)
                if not has_default_max_length:
                    logger.warn(  # pylint: disable=W4902
                        'Both `max_new_tokens` '
                        f'(={generation_config.max_new_tokens}) '
                        'and `max_length` '
                        f'(={generation_config.max_length}) '
                        'seem to have been set. `max_new_tokens` '
                        'will take precedence. Please refer to '
                        'the documentation for more information. '
                        '(https://huggingface.co/docs/transformers/main/en'
                        '/main_classes/text_generation)',
                        UserWarning,
                    )
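            # Warn if the prompt alone already reaches the generation budget.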
            if input_ids_seq_length >= generation_config.max_length:
                input_ids_string = 'input_ids'
                logger.warning(
                    f'Input length of {input_ids_string} '
                    f'is {input_ids_seq_length}, '
                    'but `max_length` is set to '
                    f'{generation_config.max_length}. '
                    'This can lead to unexpected behavior. '
                    'You should consider increasing `max_new_tokens`.')

            # 2. Set generation parameters if not already defined
            logits_processor = self.logits_processor
            stopping_criteria = self.stopping_criteria
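            # Assemble the logits processors, stopping criteria, and logits
            # warpers via the underlying model's generation helpers.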
            logits_processor = self.model._get_logits_processor(
                generation_config=generation_config,
                input_ids_seq_length=input_ids_seq_length,
                encoder_input_ids=input_ids,
                prefix_allowed_tokens_fn=self.prefix_allowed_tokens_fn,
                logits_processor=logits_processor,
            )
            stopping_criteria = self.model._get_stopping_criteria(
                generation_config=generation_config,
                stopping_criteria=stopping_criteria)
            logits_warper = self.model._get_logits_warper(generation_config)

            unfinished_sequences = input_ids.new(batch_size).fill_(1)
            scores = None
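            # Token-by-token decoding loop: run a forward pass, pick the next
            # token (sampled or greedy), append it, and stream the partial text.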
            while True:
                model_inputs = self.model.prepare_inputs_for_generation(
                    input_ids, **model_kwargs)
                # forward pass to get next token
                outputs = self.model(
                    **model_inputs,
                    return_dict=True,
                    output_attentions=False,
                    output_hidden_states=False,
                )
                next_token_logits = outputs.logits[:, -1, :]
                # pre-process distribution
                next_token_scores = logits_processor(input_ids,
                                                     next_token_logits)
                next_token_scores = logits_warper(input_ids, next_token_scores)
                # sample
                probs = nn.functional.softmax(next_token_scores, dim=-1)
                if do_sample:
                    next_tokens = torch.multinomial(
                        probs, num_samples=1).squeeze(1)
                else:
                    next_tokens = torch.argmax(probs, dim=-1)
                # update generated ids, model inputs,
                # and length for next step
                input_ids = torch.cat([input_ids, next_tokens[:, None]],
                                      dim=-1)
                model_kwargs = self.model._update_model_kwargs_for_generation(  # noqa: E501
                    outputs, model_kwargs, is_encoder_decoder=False)
                unfinished_sequences = unfinished_sequences.mul(
                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(
                        eos_token_id_tensor.unsqueeze(1)).prod(dim=0))
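                # Re-decode the running sequences each step: drop the prompt
                # tokens and truncate at the first EOS before producing text.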
                output_token_ids = input_ids.cpu().tolist()
                for i in range(len(output_token_ids)):
                    output_token_ids[i] = output_token_ids[i][:][
                        input_length[i]:]
                    # Find the first occurrence of an EOS token in the sequence
                    first_eos_idx = next(
                        (idx
                         for idx, token_id in enumerate(output_token_ids[i])
                         if token_id in eos_token_id), None)
                    # If an EOS token is found, keep only the part before it
                    if first_eos_idx is not None:
                        output_token_ids[i] = output_token_ids[
                            i][:first_eos_idx]
                response = self.tokenizer.batch_decode(output_token_ids)
                # print(response)
                if not batched:
                    response = response[0]
                yield ModelStatusCode.STREAM_ING, response, None
                # stop when every sequence is finished
                # or the maximum length is exceeded
                if (unfinished_sequences.max() == 0
                        or stopping_criteria(input_ids, scores)):
                    break
            yield ModelStatusCode.END, response, None
-        except Exception:
-            response = ''.join(traceback.format_exception(*sys.exc_info()))
-            if batched:
-                response = [response]
-            yield ModelStatusCode.SERVER_ERR, response, None
    def stream_chat(
        self,