Fix: gen_config in LMDeployPipeline updated by input gen_params (#151)

This commit is contained in:
liujiangning30
2024-02-05 15:53:08 +08:00
committed by GitHub
parent 6a5447663a
commit 90ef5215b6


@@ -238,14 +238,19 @@ class LMDeployPipeline(BaseModel):
         Returns:
             (a list of/batched) text/chat completion
         """
+        from lmdeploy.messages import GenerationConfig
         batched = True
         if isinstance(inputs, str):
             inputs = [inputs]
             batched = False
         prompt = inputs
         gen_params = self.update_gen_params(**kwargs)
+        max_tokens = gen_params.pop('max_tokens')
+        gen_config = GenerationConfig(**gen_params)
+        gen_config.max_new_tokens = max_tokens
         response = self.model.batch_infer(
-            prompt, do_preprocess=do_preprocess, **gen_params)
+            prompt, gen_config=gen_config, do_preprocess=do_preprocess)
         response = [resp.text for resp in response]
         # remove stop_words
         response = filter_suffix(response, self.gen_params.get('stop_words'))
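
A minimal sketch of the parameter translation this commit introduces, factored into a standalone helper for readability. The helper name build_gen_config and the example keys top_p/temperature are illustrative, not part of the commit; it also assumes every remaining key in gen_params corresponds to a GenerationConfig field, since GenerationConfig rejects unknown keyword arguments:

    from lmdeploy.messages import GenerationConfig

    def build_gen_config(gen_params: dict) -> GenerationConfig:
        # Copy so the caller's dict is not mutated (hypothetical helper;
        # the commit does this inline in LMDeployPipeline).
        params = dict(gen_params)
        # lagent calls this setting 'max_tokens'; lmdeploy's GenerationConfig
        # names the same limit 'max_new_tokens', so pop it and reassign.
        max_tokens = params.pop('max_tokens')
        gen_config = GenerationConfig(**params)
        gen_config.max_new_tokens = max_tokens
        return gen_config

    # Example: top_p/temperature pass through to GenerationConfig unchanged,
    # while max_tokens=128 lands on gen_config.max_new_tokens.
    gen_config = build_gen_config(
        dict(max_tokens=128, top_p=0.8, temperature=0.7))

With this in place, per-call kwargs flow through update_gen_params into a proper GenerationConfig instead of being splatted directly into batch_infer, which is what let caller-supplied gen_params silently miss the generation settings before this fix.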