Add verbose flag. Closes #19

This commit is contained in:
Andrei Betlen
2023-04-04 13:09:24 -04:00
parent 5075c16fcc
commit c137789143

View File

@@ -1,4 +1,5 @@
import os
import sys
import uuid
import time
import multiprocessing
@@ -27,6 +28,7 @@ class Llama:
n_threads: Optional[int] = None,
n_batch: int = 8,
last_n_tokens_size: int = 64,
verbose: bool = True,
):
"""Load a llama.cpp model from `model_path`.
@@ -43,6 +45,7 @@ class Llama:
n_threads: Number of threads to use. If None, the number of threads is automatically determined.
n_batch: Maximum number of prompt tokens to batch together when calling llama_eval.
last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque.
verbose: Print verbose output to stderr.
Raises:
ValueError: If the model path does not exist.
@@ -50,6 +53,7 @@ class Llama:
Returns:
A Llama instance.
"""
self.verbose = verbose
self.model_path = model_path
self.params = llama_cpp.llama_context_default_params()
@@ -79,6 +83,9 @@ class Llama:
self.model_path.encode("utf-8"), self.params
)
if self.verbose:
print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
def tokenize(self, text: bytes) -> List[llama_cpp.llama_token]:
"""Tokenize a string.
@@ -239,6 +246,10 @@ class Llama:
An embedding object.
"""
assert self.ctx is not None
if self.verbose:
llama_cpp.llama_reset_timings(self.ctx)
tokens = self.tokenize(input.encode("utf-8"))
self.reset()
self.eval(tokens)
@@ -246,6 +257,10 @@ class Llama:
embedding = llama_cpp.llama_get_embeddings(self.ctx)[
: llama_cpp.llama_n_embd(self.ctx)
]
if self.verbose:
llama_cpp.llama_print_timings(self.ctx)
return {
"object": "list",
"data": [
@@ -296,6 +311,9 @@ class Llama:
text = b""
returned_characters = 0
if self.verbose:
llama_cpp.llama_reset_timings(self.ctx)
if len(prompt_tokens) + max_tokens > int(llama_cpp.llama_n_ctx(self.ctx)):
raise ValueError(
f"Requested tokens exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}"
@@ -392,6 +410,9 @@ class Llama:
if logprobs is not None:
raise NotImplementedError("logprobs not implemented")
if self.verbose:
llama_cpp.llama_print_timings(self.ctx)
yield {
"id": completion_id,
"object": "text_completion",