From f0ec6e615ecae3d7b5343d78c9063003dfb71e6a Mon Sep 17 00:00:00 2001
From: Andrei Betlen
Date: Thu, 18 May 2023 11:35:59 -0400
Subject: [PATCH] Stream tokens instead of text chunks

---
 llama_cpp/llama.py | 112 +++++++++++++++++++++++++++++++--------------
 1 file changed, 78 insertions(+), 34 deletions(-)

diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index f47f4a4..bf4caf7 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -623,7 +623,7 @@ class Llama:
             b" " + prompt.encode("utf-8")
         )
         text: bytes = b""
-        returned_characters: int = 0
+        returned_tokens: int = 0
         stop = stop if stop is not None else []
         model_name: str = model if model is not None else self.model_path
 
@@ -707,33 +707,42 @@ class Llama:
                 break
 
             if stream:
-                start = returned_characters
-                longest = 0
                 # We want to avoid yielding any characters from
                 # the generated text if they are part of a stop
                 # sequence.
+                longest = 0
                 for s in stop_sequences:
                     for i in range(len(s), 0, -1):
                         if all_text.endswith(s[:i]):
                             if i > longest:
                                 longest = i
                             break
-                text = all_text[: len(all_text) - longest]
-                returned_characters += len(text[start:])
-                yield {
-                    "id": completion_id,
-                    "object": "text_completion",
-                    "created": created,
-                    "model": model_name,
-                    "choices": [
-                        {
-                            "text": text[start:].decode("utf-8", errors="ignore"),
-                            "index": 0,
-                            "logprobs": None,
-                            "finish_reason": None,
-                        }
-                    ],
-                }
+
+                offset = 0
+                remaining_tokens = completion_tokens[returned_tokens:]
+                remaining_length = len(self.detokenize(remaining_tokens))
+                for token in remaining_tokens:
+                    offset += len(self.detokenize([token]))
+                    # Check if stop sequence is not in the token
+                    if offset >= (remaining_length - longest - 1):
+                        break
+                    returned_tokens += 1
+                    yield {
+                        "id": completion_id,
+                        "object": "text_completion",
+                        "created": created,
+                        "model": model_name,
+                        "choices": [
+                            {
+                                "text": self.detokenize([token]).decode(
+                                    "utf-8", errors="ignore"
+                                ),
+                                "index": 0,
+                                "logprobs": None,
+                                "finish_reason": None,
+                            }
+                        ],
+                    }
 
         if len(completion_tokens) >= max_tokens:
             text = self.detokenize(completion_tokens)
@@ -749,22 +758,57 @@ class Llama:
             llama_cpp.llama_print_timings(self.ctx)
 
         if stream:
-            yield {
-                "id": completion_id,
-                "object": "text_completion",
-                "created": created,
-                "model": model_name,
-                "choices": [
-                    {
-                        "text": text[returned_characters:].decode(
-                            "utf-8", errors="ignore"
-                        ),
-                        "index": 0,
-                        "logprobs": None,
-                        "finish_reason": finish_reason,
+            remaining_tokens = completion_tokens[returned_tokens:]
+            all_text = self.detokenize(remaining_tokens)
+            any_stop = [s for s in stop_sequences if s in all_text]
+            if len(any_stop) > 0:
+                end = min(all_text.index(stop) for stop in any_stop)
+            else:
+                end = len(all_text)
+
+            offset = 0
+            for token in remaining_tokens:
+                offset += len(self.detokenize([token]))
+                if offset >= end:
+                    last_text = self.detokenize([token])
+                    if offset == end - 1:
+                        break
+                    yield {
+                        "id": completion_id,
+                        "object": "text_completion",
+                        "created": created,
+                        "model": model_name,
+                        "choices": [
+                            {
+                                "text": last_text[
+                                    : len(last_text) - (offset - end)
+                                ].decode("utf-8", errors="ignore"),
+                                "index": 0,
+                                "logprobs": None,
+                                "finish_reason": finish_reason,
+                            }
+                        ],
                     }
-                ],
-            }
+                    break
+                returned_tokens += 1
+                yield {
+                    "id": completion_id,
+                    "object": "text_completion",
+                    "created": created,
+                    "model": model_name,
+                    "choices": [
+                        {
+                            "text": self.detokenize([token]).decode(
+                                "utf-8", errors="ignore"
+                            ),
+                            "index": 0,
+                            "logprobs": None,
+                            "finish_reason": finish_reason
+                            if returned_tokens == len(completion_tokens)
+                            else None,
+                        }
+                    ],
+                }
             return
 
         text_str = text.decode("utf-8", errors="ignore")
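For context, here is a minimal consumer-side sketch (not part of the patch) of what the change means for a caller: with stream=True, the completion generator now yields one chunk per detokenized token rather than a growing text slice. It assumes the existing public create_completion(..., stream=True) entry point; the model path is a placeholder.

# Sketch only: iterate the per-token chunks yielded by the streaming path in the diff above.
from llama_cpp import Llama

llm = Llama(model_path="./models/7B/ggml-model.bin")  # placeholder path

for chunk in llm.create_completion(
    "Q: Name the planets in the solar system. A:",
    max_tokens=64,
    stop=["Q:"],
    stream=True,
):
    # Each chunk's "text" field holds a single detokenized token, so it can be
    # forwarded to a client as soon as it is produced.
    print(chunk["choices"][0]["text"], end="", flush=True)
print()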