From 7a536e86c260872c0551e52df37ba8b45317068e Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Fri, 12 May 2023 14:28:22 -0400 Subject: [PATCH] Allow model to tokenize strings longer than context length and set add_bos. Closes #92 --- llama_cpp/llama.py | 20 +++++++++++++++++--- llama_cpp/llama_cpp.py | 2 +- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 47fa543..4295ba7 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -174,7 +174,9 @@ class Llama: if self.verbose: print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr) - def tokenize(self, text: bytes) -> List[llama_cpp.llama_token]: + def tokenize( + self, text: bytes, add_bos: bool = True + ) -> List[llama_cpp.llama_token]: """Tokenize a string. Args: @@ -194,10 +196,22 @@ class Llama: text, tokens, n_ctx, - llama_cpp.c_bool(True), + llama_cpp.c_bool(add_bos), ) if int(n_tokens) < 0: - raise RuntimeError(f'Failed to tokenize: text="{text}" n_tokens={n_tokens}') + n_tokens = abs(n_tokens) + tokens = (llama_cpp.llama_token * int(n_tokens))() + n_tokens = llama_cpp.llama_tokenize( + self.ctx, + text, + tokens, + llama_cpp.c_int(n_tokens), + llama_cpp.c_bool(add_bos), + ) + if n_tokens < 0: + raise RuntimeError( + f'Failed to tokenize: text="{text}" n_tokens={n_tokens}' + ) return list(tokens[:n_tokens]) def detokenize(self, tokens: List[llama_cpp.llama_token]) -> bytes: diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index e60558c..870eced 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -350,7 +350,7 @@ def llama_tokenize( tokens, # type: Array[llama_token] n_max_tokens: c_int, add_bos: c_bool, -) -> c_int: +) -> int: return _lib.llama_tokenize(ctx, text, tokens, n_max_tokens, add_bos)