Fix: outputs with same format (#112)

Make the LLM wrappers return model outputs in the same format, so callers can handle streaming results uniformly across backends.
liujiangning30
2024-01-30 16:36:08 +08:00
committed by GitHub
parent 5581fad8ce
commit d4a71f40b5
6 changed files with 203 additions and 173 deletions

View File

@@ -1,11 +1,11 @@
exclude: ^(tests/data|scripts|ftdp/protocols|ftdp/template_configs|ftdp/tool_dicts)/
repos:
- repo: https://github.com/PyCQA/flake8
rev: 5.0.4
rev: 7.0.0
hooks:
- id: flake8
- repo: https://github.com/PyCQA/isort
rev: 5.11.5
rev: 5.13.2
hooks:
- id: isort
- repo: https://github.com/pre-commit/mirrors-yapf
@@ -13,7 +13,7 @@ repos:
hooks:
- id: yapf
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.3.0
rev: v4.5.0
hooks:
- id: trailing-whitespace
- id: check-yaml
@@ -26,7 +26,7 @@ repos:
- id: mixed-line-ending
args: ["--fix=lf"]
- repo: https://github.com/executablebooks/mdformat
rev: 0.7.9
rev: 0.7.17
hooks:
- id: mdformat
args: ["--number"]
@@ -35,16 +35,16 @@ repos:
- mdformat_frontmatter
- linkify-it-py
- repo: https://github.com/codespell-project/codespell
rev: v2.2.1
rev: v2.2.6
hooks:
- id: codespell
- repo: https://github.com/myint/docformatter
rev: v1.3.1
rev: v1.7.5
hooks:
- id: docformatter
args: ["--in-place", "--wrap-descriptions", "79"]
- repo: https://github.com/asottile/pyupgrade
rev: v3.0.0
rev: v3.15.0
hooks:
- id: pyupgrade
args: ["--py36-plus"]

View File

@@ -5,9 +5,10 @@ import os
import streamlit as st
from lagent.actions import ActionExecutor, ArxivSearch, GoogleScholar, IPythonInterpreter
from lagent.actions import (ActionExecutor, ArxivSearch, GoogleScholar,
IPythonInterpreter)
from lagent.agents.internlm2_agent import (INTERPRETER_CN, META_INS, PLUGIN_CN,
Internlm2Agent, Interlm2Protocol)
Internlm2Agent, Internlm2Protocol)
from lagent.llms.lmdepoly_wrapper import LMDeployClient
from lagent.llms.meta_template import INTERNLM2_META as META
from lagent.schema import AgentStatusCode
@@ -135,7 +136,7 @@ class StreamlitUI:
"""Initialize the chatbot with the given model and plugin actions."""
return Internlm2Agent(
llm=model,
protocol=Interlm2Protocol(
protocol=Internlm2Protocol(
tool=dict(
begin='{start_token}{name}\n',
start_token='<|action_start|>',
@@ -309,7 +310,8 @@ def main():
if isinstance(agent_return.response, dict):
action = f"\n\n {agent_return.response['name']}: \n\n"
action_input = agent_return.response['parameters']
if agent_return.response['name'] == 'IPythonInterpreter':
if agent_return.response[
'name'] == 'IPythonInterpreter':
action_input = action_input['command']
response = action + action_input
else:
@@ -318,7 +320,8 @@ def main():
st.session_state['placeholder'].markdown(
st.session_state['temp'])
elif agent_return.state == AgentStatusCode.END:
st.session_state['session_history'] += (user_input + agent_return.inner_steps)
st.session_state['session_history'] += (
user_input + agent_return.inner_steps)
agent_return = copy.deepcopy(agent_return)
agent_return.response = st.session_state['temp']
st.session_state['assistant'].append(
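
The demo's only functional change is the corrected class name (`Interlm2Protocol` becomes `Internlm2Protocol`); the remaining hunks are line wrapping. A minimal sketch of what the updated initialisation boils down to, with `model` standing in for whichever lagent LLM wrapper the demo builds:

```python
from lagent.agents.internlm2_agent import Internlm2Agent, Internlm2Protocol

def init_agent(model):
    # `model` is a placeholder for the LMDeployClient (or any other lagent
    # LLM wrapper) that the demo constructs elsewhere.
    return Internlm2Agent(
        llm=model,
        # Defaults shown here; the demo passes a customised `tool` dict with
        # the <|action_start|>/<|plugin|> tokens visible in the hunk above.
        protocol=Internlm2Protocol(),
    )
```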

View File

@@ -1,7 +1,7 @@
import json
import logging
from copy import deepcopy
from typing import Dict, List, Union, Optional
from typing import Dict, List, Optional, Union
from lagent.schema import AgentReturn, AgentStatusCode
@@ -34,20 +34,20 @@ PLUGIN_CN = (
'同时注意你可以使用的工具,不要随意捏造!')
class Interlm2Protocol:
class Internlm2Protocol:
def __init__(
self,
meta_prompt: str=META_INS,
interpreter_prompt: str=INTERPRETER_CN,
plugin_prompt: str=PLUGIN_CN,
few_shot: Optional[List]=None,
language: Dict=dict(
meta_prompt: str = META_INS,
interpreter_prompt: str = INTERPRETER_CN,
plugin_prompt: str = PLUGIN_CN,
few_shot: Optional[List] = None,
language: Dict = dict(
begin='',
end='',
belong='assistant',
),
tool: Dict=dict(
tool: Dict = dict(
begin='{start_token}{name}\n',
start_token='<|action_start|>',
name_map=dict(plugin='<|plugin|>', interpreter='<|interpreter|>'),
@@ -199,7 +199,7 @@ class Internlm2Agent(BaseAgent):
llm: Union[BaseModel, BaseAPIModel],
plugin_executor: ActionExecutor = None,
interpreter_executor: ActionExecutor = None,
protocol=Interlm2Protocol(),
protocol=Internlm2Protocol(),
max_turn: int = 3) -> None:
self.max_turn = max_turn
self._interpreter_executor = interpreter_executor
@@ -237,8 +237,7 @@ class Internlm2Agent(BaseAgent):
try:
action = json.loads(action)
except Exception as e:
logging.info(
msg=f'Invaild action {e}')
logging.info(msg=f'Invaild action {e}')
continue
elif name == 'interpreter':
if self._interpreter_executor:
@@ -265,9 +264,8 @@ class Internlm2Agent(BaseAgent):
dict(role='tool', content=action, name=name))
inner_history.append(
self._protocol.format_response(action_return, name=name))
yield agent_return
agent_return.inner_steps = inner_history[offset:]
yield agent_return
return agent_return
def stream_chat(self, message: List[dict], **kwargs) -> AgentReturn:
if isinstance(message, str):
@@ -286,8 +284,7 @@ class Internlm2Agent(BaseAgent):
interpreter_executor=self._interpreter_executor,
)
response = ''
for model_state, res, _ in self._llm.stream_chat(
prompt, **kwargs):
for model_state, res, _ in self._llm.stream_chat(prompt, **kwargs):
response = res
if model_state.value < 0:
agent_return.state = model_state
@@ -312,8 +309,7 @@ class Internlm2Agent(BaseAgent):
try:
action = json.loads(action)
except Exception as e:
logging.info(
msg=f'Invaild action {e}')
logging.info(msg=f'Invaild action {e}')
continue
elif name == 'interpreter':
if self._interpreter_executor:
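
Moving the `inner_steps` assignment ahead of the final `yield` means the last `AgentReturn` emitted by the streaming loop already carries the tool-call history. A consumption sketch (the message format follows the type hints above; `agent` is assumed to be an `Internlm2Agent`):

```python
from lagent.schema import AgentStatusCode

def run(agent, user_input: str):
    messages = [dict(role='user', content=user_input)]
    for agent_return in agent.stream_chat(messages):
        if agent_return.state == AgentStatusCode.END:
            # `inner_steps` is populated before the yield, so the history the
            # web demo appends to `session_history` is complete here.
            return agent_return.inner_steps
```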

View File

@@ -1,4 +1,3 @@
from abc import abstractclassmethod
from copy import copy
from typing import Dict, List, Optional, Tuple, Union
from warnings import warn
@@ -189,15 +188,13 @@ class BaseModel:
inputs = self.template_parser(inputs)
return self.generate(inputs, **gen_params)
def generate_from_template(
self,
inputs: Union[List[dict], List[List[dict]]],
**gen_params
):
def generate_from_template(self, inputs: Union[List[dict],
List[List[dict]]],
**gen_params):
warn(
"This function will be deprecated after three months and will be replaced."
"Please use `.chat()`",
DeprecationWarning, 2)
'This function will be deprecated after three months'
'and will be replaced. Please use `.chat()`', DeprecationWarning,
2)
return self.chat(inputs, **gen_params)
def stream_chat(self, inputs: List[dict], **gen_params):
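
The reworded warning still points callers at `.chat()`, and the migration is a direct swap. A sketch with `llm` standing in for any `BaseModel` subclass and `max_new_tokens` as an assumed generation parameter:

```python
messages = [dict(role='user', content='Hello')]

# Deprecated path, kept for backwards compatibility:
# reply = llm.generate_from_template(messages, max_new_tokens=64)

# Preferred path: `chat` applies the template parser and generates in one call.
reply = llm.chat(messages, max_new_tokens=64)
```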

View File

@@ -1,13 +1,17 @@
import copy
import warnings
import logging
from typing import Dict, List, Optional, Union
from dataclasses import asdict
import sys
import traceback
import warnings
from typing import Dict, List, Optional
from lagent.schema import AgentStatusCode
from .base_llm import BaseModel
logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
class HFTransformer(BaseModel):
"""Model wrapper around HuggingFace general models.
@@ -103,147 +107,177 @@ class HFTransformer(BaseModel):
do_sample=True,
**kwargs,
):
import torch
from torch import nn
with torch.no_grad():
batched = True
if isinstance(inputs, str):
inputs = [inputs]
batched = False
# import pdb; pdb.set_trace()
inputs = self.tokenizer(
inputs, padding=True, return_tensors='pt', return_length=True)
input_length = inputs['length']
for k, v in inputs.items():
inputs[k] = v.cuda()
input_ids = inputs['input_ids']
attention_mask = inputs['attention_mask']
batch_size, input_ids_seq_length = input_ids.shape[
0], input_ids.shape[-1] # noqa: F841 # pylint: disable=W0612
generation_config = self.model.generation_config
generation_config = copy.deepcopy(generation_config)
new_gen_params = self.update_gen_params(**kwargs)
generation_config.update(**new_gen_params)
generation_config.update(**kwargs)
model_kwargs = generation_config.to_dict()
model_kwargs['attention_mask'] = attention_mask
_, eos_token_id = ( # noqa: F841 # pylint: disable=W0612
generation_config.bos_token_id,
generation_config.eos_token_id,
)
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
if self.additional_eos_token_id is not None:
eos_token_id.extend(self.additional_eos_token_id)
eos_token_id_tensor = torch.tensor(eos_token_id).to(
input_ids.device) if eos_token_id is not None else None
has_default_max_length = (
kwargs.get('max_length') is None
and generation_config.max_length is not None)
if has_default_max_length and generation_config.max_new_tokens is None:
warnings.warn(
f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
'This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we'
' recommend using `max_new_tokens` to control the maximum length of the generation.',
UserWarning,
try:
import torch
from torch import nn
with torch.no_grad():
batched = True
if isinstance(inputs, str):
inputs = [inputs]
batched = False
# import pdb; pdb.set_trace()
inputs = self.tokenizer(
inputs,
padding=True,
return_tensors='pt',
return_length=True)
input_length = inputs['length']
for k, v in inputs.items():
inputs[k] = v.cuda()
input_ids = inputs['input_ids']
attention_mask = inputs['attention_mask']
batch_size = input_ids.shape[0]
input_ids_seq_length = input_ids.shape[-1]
generation_config = self.model.generation_config
generation_config = copy.deepcopy(generation_config)
new_gen_params = self.update_gen_params(**kwargs)
generation_config.update(**new_gen_params)
generation_config.update(**kwargs)
model_kwargs = generation_config.to_dict()
model_kwargs['attention_mask'] = attention_mask
_, eos_token_id = ( # noqa: F841 # pylint: disable=W0612
generation_config.bos_token_id,
generation_config.eos_token_id,
)
elif generation_config.max_new_tokens is not None:
generation_config.max_length = (
generation_config.max_new_tokens + input_ids_seq_length)
if not has_default_max_length:
logger.warn( # pylint: disable=W4902
f'Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(='
f'{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. '
'Please refer to the documentation for more information. '
'(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)',
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
if self.additional_eos_token_id is not None:
eos_token_id.extend(self.additional_eos_token_id)
eos_token_id_tensor = torch.tensor(eos_token_id).to(
input_ids.device) if eos_token_id is not None else None
has_default_max_length = (
kwargs.get('max_length') is None
and generation_config.max_length is not None)
if (has_default_max_length
and generation_config.max_new_tokens is None):
warnings.warn(
"Using `max_length`'s default"
f'({generation_config.max_length})'
'to control the generation length. '
'This behaviour is deprecated and will be removed'
' from the config in v5 of Transformers -- we'
' recommend using `max_new_tokens` to control the'
' maximum length of the generation.',
UserWarning,
)
elif generation_config.max_new_tokens is not None:
generation_config.max_length = (
generation_config.max_new_tokens +
input_ids_seq_length)
if not has_default_max_length:
logger.warn( # pylint: disable=W4902
'Both `max_new_tokens`'
f'(={generation_config.max_new_tokens})'
'and `max_length`'
f'(={generation_config.max_length})'
' seem to have been set.`max_new_tokens`'
' will take precedence. Please refer to'
' the documentation for more information. '
'(https://huggingface.co/docs/transformers/main/en'
'/main_classes/text_generation)',
UserWarning,
)
if input_ids_seq_length >= generation_config.max_length:
input_ids_string = 'input_ids'
logger.warning(
f'Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to'
f' {generation_config.max_length}. This can lead to unexpected behavior. You should consider'
' increasing `max_new_tokens`.')
if input_ids_seq_length >= generation_config.max_length:
input_ids_string = 'input_ids'
logger.warning(
f'Input length of {input_ids_string}'
f' is {input_ids_seq_length},'
' but `max_length` is set to'
f' {generation_config.max_length}.'
'This can lead to unexpected behavior.'
' You should consider increasing `max_new_tokens`.')
# 2. Set generation parameters if not already defined
logits_processor = self.logits_processor
stopping_criteria = self.stopping_criteria
# 2. Set generation parameters if not already defined
logits_processor = self.logits_processor
stopping_criteria = self.stopping_criteria
logits_processor = self.model._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_seq_length,
encoder_input_ids=input_ids,
prefix_allowed_tokens_fn=self.prefix_allowed_tokens_fn,
logits_processor=logits_processor,
)
stopping_criteria = self.model._get_stopping_criteria(
generation_config=generation_config,
stopping_criteria=stopping_criteria)
logits_warper = self.model._get_logits_warper(generation_config)
unfinished_sequences = input_ids.new(batch_size).fill_(1)
scores = None
while True:
model_inputs = self.model.prepare_inputs_for_generation(
input_ids, **model_kwargs)
# forward pass to get next token
outputs = self.model(
**model_inputs,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
logits_processor = self.model._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_seq_length,
encoder_input_ids=input_ids,
prefix_allowed_tokens_fn=self.prefix_allowed_tokens_fn,
logits_processor=logits_processor,
)
next_token_logits = outputs.logits[:, -1, :]
stopping_criteria = self.model._get_stopping_criteria(
generation_config=generation_config,
stopping_criteria=stopping_criteria)
logits_warper = self.model._get_logits_warper(
generation_config)
# pre-process distribution
next_token_scores = logits_processor(input_ids,
next_token_logits)
next_token_scores = logits_warper(input_ids, next_token_scores)
unfinished_sequences = input_ids.new(batch_size).fill_(1)
scores = None
while True:
model_inputs = self.model.prepare_inputs_for_generation(
input_ids, **model_kwargs)
# forward pass to get next token
outputs = self.model(
**model_inputs,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
)
# sample
probs = nn.functional.softmax(next_token_scores, dim=-1)
if do_sample:
next_tokens = torch.multinomial(
probs, num_samples=1).squeeze(1)
else:
next_tokens = torch.argmax(probs, dim=-1)
next_token_logits = outputs.logits[:, -1, :]
# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]],
dim=-1)
model_kwargs = self.model._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=False)
unfinished_sequences = unfinished_sequences.mul(
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(
eos_token_id_tensor.unsqueeze(1)).prod(dim=0))
# output_token_ids = input_ids.cpu()[:, input_length:].tolist()
output_token_ids = input_ids.cpu().tolist()
for i in range(len(output_token_ids)):
output_token_ids[i] = output_token_ids[i][:][
input_length[i]:]
# Find the first occurrence of an EOS token in the sequence
first_eos_idx = next(
(idx
for idx, token_id in enumerate(output_token_ids[i])
if token_id in eos_token_id), None)
# If an EOS token is found, only the previous part of it is retained
if first_eos_idx is not None:
output_token_ids[i] = output_token_ids[
i][:first_eos_idx]
# pre-process distribution
next_token_scores = logits_processor(
input_ids, next_token_logits)
next_token_scores = logits_warper(input_ids,
next_token_scores)
response = self.tokenizer.batch_decode(output_token_ids)
# print(response)
if not batched:
yield response[0]
else:
yield response
# stop when each sentence is finished, or if we exceed the maximum length
if (unfinished_sequences.max() == 0
or stopping_criteria(input_ids, scores)):
break
# sample
probs = nn.functional.softmax(next_token_scores, dim=-1)
if do_sample:
next_tokens = torch.multinomial(
probs, num_samples=1).squeeze(1)
else:
next_tokens = torch.argmax(probs, dim=-1)
# update generated ids, model inputs,
# and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]],
dim=-1)
model_kwargs = self.model._update_model_kwargs_for_generation( # noqa: E501
outputs,
model_kwargs,
is_encoder_decoder=False)
unfinished_sequences = unfinished_sequences.mul(
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(
eos_token_id_tensor.unsqueeze(1)).prod(dim=0))
output_token_ids = input_ids.cpu().tolist()
for i in range(len(output_token_ids)):
output_token_ids[i] = output_token_ids[i][:][
input_length[i]:]
# Find the first occurrence of
# an EOS token in the sequence
first_eos_idx = next((
idx
for idx, token_id in enumerate(output_token_ids[i])
if token_id in eos_token_id), None)
# If an EOS token is found, only the previous
# part of it is retained
if first_eos_idx is not None:
output_token_ids[i] = output_token_ids[
i][:first_eos_idx]
response = self.tokenizer.batch_decode(output_token_ids)
# print(response)
if not batched:
response = response[0]
yield AgentStatusCode.STREAM_ING, response, None
# stop when each sentence is finished,
# or if we exceed the maximum length
if (unfinished_sequences.max() == 0
or stopping_criteria(input_ids, scores)):
break
yield AgentStatusCode.END, response, None
except Exception:
response = ''.join(traceback.format_exception(*sys.exc_info()))
if batched:
response = [response]
yield AgentStatusCode.SERVER_ERR, response, None
class HFTransformerCasualLM(HFTransformer):
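
Wrapping the decode loop in `try`/`except` changes the streaming contract: every yield is now an `(AgentStatusCode, response, extra)` tuple, and failures arrive as a `SERVER_ERR` payload carrying the formatted traceback instead of a raised exception. A consumption sketch (the method name `stream_generate` is an assumption, since the hunk header truncates the signature; `model` is an `HFTransformer`-family instance):

```python
from lagent.schema import AgentStatusCode

def stream(model, prompt: str) -> str:
    text = ''
    for status, chunk, _ in model.stream_generate(prompt, do_sample=True):
        if status == AgentStatusCode.SERVER_ERR:
            # The formatted traceback is delivered as the response payload.
            raise RuntimeError(chunk)
        text = chunk  # STREAM_ING carries partial decodes, END the final one
    return text
```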

View File

@@ -145,7 +145,7 @@ class TritonClient(BaseModel):
elif self.chatbot._session.status == 0:
logger.error(f'session {session_id} has been ended. Please set '
f'`sequence_start` be True if you want to restart it')
return ''
return AgentStatusCode.SESSION_CLOSED, '', 0
self.chatbot._session.status = 1
self.chatbot._session.request_id = request_id
@@ -171,7 +171,7 @@ class TritonClient(BaseModel):
self.chatbot._session.prompt + self.chatbot._session.response)
yield self.state_map.get(status), res, _
else:
return ''
return self.state_map.get(status), res, _
def _update_gen_params(self, **kwargs):
import mmengine
@@ -195,7 +195,7 @@ class LMDeployPipeline(BaseModel):
path (str): The path to the model.
It could be one of the following options:
- i) A local directory path of a turbomind model which is
converted by `lmdeploy convert` command or download
converted by `lmdeploy convert` command or download
from ii) and iii).
- ii) The model_id of a lmdeploy-quantized model hosted
inside a model repo on huggingface.co, such as
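
This is the same-format theme of the commit on the LMDeploy side: the Triton paths now report `(state, response, extra)` tuples instead of bare strings, so callers can branch on `AgentStatusCode` uniformly across backends. A hedged sketch of the caller side (`client` stands for a `TritonClient`-style wrapper, and the keyword argument is an assumption):

```python
from lagent.schema import AgentStatusCode

def consume(client, prompt: str, session_id: int = 1):
    for state, response, _ in client.stream_chat(prompt, session_id=session_id):
        if state == AgentStatusCode.SESSION_CLOSED:
            # Previously an empty string; now an explicit status code.
            print('session already ended; restart it with sequence_start=True')
            break
        print(response)
```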