diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 4840caf..e84d457 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -1,4 +1,5 @@ import os +import sys import uuid import time import multiprocessing @@ -27,6 +28,7 @@ class Llama: n_threads: Optional[int] = None, n_batch: int = 8, last_n_tokens_size: int = 64, + verbose: bool = True, ): """Load a llama.cpp model from `model_path`. @@ -43,6 +45,7 @@ class Llama: n_threads: Number of threads to use. If None, the number of threads is automatically determined. n_batch: Maximum number of prompt tokens to batch together when calling llama_eval. last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque. + verbose: Print verbose output to stderr. Raises: ValueError: If the model path does not exist. @@ -50,6 +53,7 @@ class Llama: Returns: A Llama instance. """ + self.verbose = verbose self.model_path = model_path self.params = llama_cpp.llama_context_default_params() @@ -79,6 +83,9 @@ class Llama: self.model_path.encode("utf-8"), self.params ) + if self.verbose: + print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr) + def tokenize(self, text: bytes) -> List[llama_cpp.llama_token]: """Tokenize a string. @@ -239,6 +246,10 @@ class Llama: An embedding object. """ assert self.ctx is not None + + if self.verbose: + llama_cpp.llama_reset_timings(self.ctx) + tokens = self.tokenize(input.encode("utf-8")) self.reset() self.eval(tokens) @@ -246,6 +257,10 @@ class Llama: embedding = llama_cpp.llama_get_embeddings(self.ctx)[ : llama_cpp.llama_n_embd(self.ctx) ] + + if self.verbose: + llama_cpp.llama_print_timings(self.ctx) + return { "object": "list", "data": [ @@ -296,6 +311,9 @@ class Llama: text = b"" returned_characters = 0 + if self.verbose: + llama_cpp.llama_reset_timings(self.ctx) + if len(prompt_tokens) + max_tokens > int(llama_cpp.llama_n_ctx(self.ctx)): raise ValueError( f"Requested tokens exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}" @@ -392,6 +410,9 @@ class Llama: if logprobs is not None: raise NotImplementedError("logprobs not implemented") + if self.verbose: + llama_cpp.llama_print_timings(self.ctx) + yield { "id": completion_id, "object": "text_completion",