Fix: gen_config in lmdeploypipeline updated by input gen_params (#151)
This commit is contained in:
@@ -238,14 +238,19 @@ class LMDeployPipeline(BaseModel):
|
||||
Returns:
|
||||
(a list of/batched) text/chat completion
|
||||
"""
|
||||
from lmdeploy.messages import GenerationConfig
|
||||
|
||||
batched = True
|
||||
if isinstance(inputs, str):
|
||||
inputs = [inputs]
|
||||
batched = False
|
||||
prompt = inputs
|
||||
gen_params = self.update_gen_params(**kwargs)
|
||||
max_tokens = gen_params.pop('max_tokens')
|
||||
gen_config = GenerationConfig(**gen_params)
|
||||
gen_config.max_new_tokens = max_tokens
|
||||
response = self.model.batch_infer(
|
||||
prompt, do_preprocess=do_preprocess, **gen_params)
|
||||
prompt, gen_config=gen_config, do_preprocess=do_preprocess)
|
||||
response = [resp.text for resp in response]
|
||||
# remove stop_words
|
||||
response = filter_suffix(response, self.gen_params.get('stop_words'))
|
||||
|
||||
Reference in New Issue
Block a user