diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 1d5a5f4..5bcfad8 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -286,6 +286,7 @@ class Llama:
         # Add blank space to start of prompt to match OG llama tokenizer
         prompt_tokens = self.tokenize(b" " + prompt.encode("utf-8"))
         text = b""
+        returned_characters = 0
 
         if len(prompt_tokens) + max_tokens > int(llama_cpp.llama_n_ctx(self.ctx)):
             raise ValueError(
@@ -293,9 +294,9 @@ class Llama:
             )
 
         if stop != []:
-            stop_bytes = [s.encode("utf-8") for s in stop]
+            stop_sequences = [s.encode("utf-8") for s in stop]
         else:
-            stop_bytes = []
+            stop_sequences = []
 
         finish_reason = None
         for token in self.generate(
@@ -306,28 +307,33 @@ class Llama:
             repeat_penalty=repeat_penalty,
         ):
             if token == llama_cpp.llama_token_eos():
+                text = self.detokenize(completion_tokens)
                 finish_reason = "stop"
                 break
             completion_tokens.append(token)
-            text = self.detokenize(completion_tokens)
-            any_stop = [s for s in stop_bytes if s in text]
+            all_text = self.detokenize(completion_tokens)
+            any_stop = [s for s in stop_sequences if s in all_text]
             if len(any_stop) > 0:
                 first_stop = any_stop[0]
-                text = text[: text.index(first_stop)]
+                text = all_text[: all_text.index(first_stop)]
                 finish_reason = "stop"
                 break
 
             if stream:
-                start = len(self.detokenize(completion_tokens[:-1]))
+                start = returned_characters
                 longest = 0
-                # TODO: Clean up this mess
-                for s in stop_bytes:
+                # We want to avoid yielding any characters from
+                # the generated text if they are part of a stop
+                # sequence.
+                for s in stop_sequences:
                     for i in range(len(s), 0, -1):
-                        if s[-i:] == text[-i:]:
+                        if all_text.endswith(s[:i]):
                             if i > longest:
                                 longest = i
                             break
+                text = all_text[: len(all_text) - longest]
+                returned_characters += len(text[start:])
 
                 yield {
                     "id": completion_id,
                     "object": "text_completion",
@@ -335,23 +341,22 @@ class Llama:
                     "model": self.model_path,
                     "choices": [
                         {
-                            "text": text[start : len(text) - longest].decode("utf-8"),
+                            "text": text[start :].decode("utf-8"),
                             "index": 0,
                             "logprobs": None,
                             "finish_reason": None,
                         }
                     ],
                 }
+            if len(completion_tokens) >= max_tokens:
+                text = self.detokenize(completion_tokens)
+                finish_reason = "length"
+                break
 
         if finish_reason is None:
            finish_reason = "length"
 
         if stream:
-            if finish_reason == "stop":
-                start = len(self.detokenize(completion_tokens[:-1]))
-                text = text[start:].decode("utf-8")
-            else:
-                text = ""
             yield {
                 "id": completion_id,
                 "object": "text_completion",
@@ -359,7 +364,7 @@ class Llama:
                 "model": self.model_path,
                 "choices": [
                     {
-                        "text": text,
+                        "text": text[returned_characters:].decode("utf-8"),
                         "index": 0,
                         "logprobs": None,
                         "finish_reason": finish_reason,
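
For reference, a minimal standalone sketch (assumed helper name, not part of the diff) of the partial-match scan the new streaming path runs on each token: it finds the longest prefix of any stop sequence that is also a suffix of the generated bytes, and that many bytes are withheld from the stream in case the next tokens complete the stop sequence. `returned_characters` then advances only past bytes that were actually yielded, so each chunk is `text[start:]`.

```python
# Hypothetical helper mirroring the diff's inner loop; "all_text" and
# "stop_sequences" follow the names used in the patch.
def longest_partial_stop(all_text: bytes, stop_sequences: list[bytes]) -> int:
    """Length of the longest stop-sequence prefix that ends all_text."""
    longest = 0
    for s in stop_sequences:
        # Try prefixes of s from longest to shortest; the first hit is
        # the longest partial match for this stop sequence.
        for i in range(len(s), 0, -1):
            if all_text.endswith(s[:i]):
                longest = max(longest, i)
                break
    return longest

# b"wor" could be the start of b"world", so 3 bytes are withheld.
assert longest_partial_stop(b"Hello, wor", [b"world"]) == 3
assert longest_partial_stop(b"Hello", [b"world"]) == 0
```

This also shows what the replaced check fixed: the old `s[-i:] == text[-i:]` compared a *suffix* of the stop sequence against the tail of the text, so a partially generated stop sequence (which matches a *prefix* of `s`) was only caught when the full sequence was present; `all_text.endswith(s[:i])` catches the partial case.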