From 26cc4ee029704976db08a5c67ab812200fcf2c9e Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Fri, 14 Apr 2023 09:59:08 -0400 Subject: [PATCH] Fix signature for stop parameter --- llama_cpp/llama.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index db9a337..ae25137 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -323,7 +323,7 @@ class Llama: top_p: float = 0.95, logprobs: Optional[int] = None, echo: bool = False, - stop: List[str] = [], + stop: Optional[List[str]] = [], repeat_penalty: float = 1.1, top_k: int = 40, stream: bool = False, @@ -336,6 +336,7 @@ class Llama: prompt_tokens = self.tokenize(b" " + prompt.encode("utf-8")) text = b"" returned_characters = 0 + stop = stop if stop is not None else [] if self.verbose: llama_cpp.llama_reset_timings(self.ctx) @@ -537,7 +538,7 @@ class Llama: top_p: float = 0.95, logprobs: Optional[int] = None, echo: bool = False, - stop: List[str] = [], + stop: Optional[List[str]] = [], repeat_penalty: float = 1.1, top_k: int = 40, stream: bool = False, @@ -592,7 +593,7 @@ class Llama: top_p: float = 0.95, logprobs: Optional[int] = None, echo: bool = False, - stop: List[str] = [], + stop: Optional[List[str]] = [], repeat_penalty: float = 1.1, top_k: int = 40, stream: bool = False, @@ -698,7 +699,7 @@ class Llama: top_p: float = 0.95, top_k: int = 40, stream: bool = False, - stop: List[str] = [], + stop: Optional[List[str]] = [], max_tokens: int = 128, repeat_penalty: float = 1.1, ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]: @@ -717,6 +718,7 @@ class Llama: Returns: Generated chat completion or a stream of chat completion chunks. """ + stop = stop if stop is not None else [] instructions = """Complete the following chat conversation between the user and the assistant. System messages should be strictly followed as additional instructions.""" chat_history = "\n".join( f'{message["role"]} {message.get("user", "")}: {message["content"]}'