Fix: update gen_config in LMDeployPipeline from the input gen_params (#151)
This commit is contained in:
		@@ -238,14 +238,19 @@ class LMDeployPipeline(BaseModel):
 | 
			
		||||
        Returns:
 | 
			
		||||
            (a list of/batched) text/chat completion
 | 
			
		||||
        """
 | 
			
		||||
        from lmdeploy.messages import GenerationConfig
 | 
			
		||||
 | 
			
		||||
        batched = True
 | 
			
		||||
        if isinstance(inputs, str):
 | 
			
		||||
            inputs = [inputs]
 | 
			
		||||
            batched = False
 | 
			
		||||
        prompt = inputs
 | 
			
		||||
        gen_params = self.update_gen_params(**kwargs)
 | 
			
		||||
        max_tokens = gen_params.pop('max_tokens')
 | 
			
		||||
        gen_config = GenerationConfig(**gen_params)
 | 
			
		||||
        gen_config.max_new_tokens = max_tokens
 | 
			
		||||
        response = self.model.batch_infer(
 | 
			
		||||
            prompt, do_preprocess=do_preprocess, **gen_params)
 | 
			
		||||
            prompt, gen_config=gen_config, do_preprocess=do_preprocess)
 | 
			
		||||
        response = [resp.text for resp in response]
 | 
			
		||||
        # remove stop_words
 | 
			
		||||
        response = filter_suffix(response, self.gen_params.get('stop_words'))
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user