Mirror of https://github.com/abetlen/llama-cpp-python.git, synced 2023-09-07 17:34:22 +03:00
Add verbose flag. Closes #19
@@ -1,4 +1,5 @@
 import os
+import sys
 import uuid
 import time
 import multiprocessing
@@ -27,6 +28,7 @@ class Llama:
         n_threads: Optional[int] = None,
         n_batch: int = 8,
         last_n_tokens_size: int = 64,
+        verbose: bool = True,
     ):
         """Load a llama.cpp model from `model_path`.
 
@@ -43,6 +45,7 @@ class Llama:
             n_threads: Number of threads to use. If None, the number of threads is automatically determined.
             n_batch: Maximum number of prompt tokens to batch together when calling llama_eval.
             last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque.
+            verbose: Print verbose output to stderr.
 
         Raises:
             ValueError: If the model path does not exist.
@@ -50,6 +53,7 @@ class Llama:
         Returns:
             A Llama instance.
         """
+        self.verbose = verbose
         self.model_path = model_path
 
         self.params = llama_cpp.llama_context_default_params()
@@ -79,6 +83,9 @@ class Llama:
             self.model_path.encode("utf-8"), self.params
         )
 
+        if self.verbose:
+            print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
+
     def tokenize(self, text: bytes) -> List[llama_cpp.llama_token]:
         """Tokenize a string.
 
@@ -239,6 +246,10 @@ class Llama:
             An embedding object.
         """
         assert self.ctx is not None
+
+        if self.verbose:
+            llama_cpp.llama_reset_timings(self.ctx)
+
         tokens = self.tokenize(input.encode("utf-8"))
         self.reset()
         self.eval(tokens)
@@ -246,6 +257,10 @@ class Llama:
         embedding = llama_cpp.llama_get_embeddings(self.ctx)[
             : llama_cpp.llama_n_embd(self.ctx)
         ]
+
+        if self.verbose:
+            llama_cpp.llama_print_timings(self.ctx)
+
         return {
             "object": "list",
             "data": [
@@ -296,6 +311,9 @@ class Llama:
         text = b""
         returned_characters = 0
 
+        if self.verbose:
+            llama_cpp.llama_reset_timings(self.ctx)
+
         if len(prompt_tokens) + max_tokens > int(llama_cpp.llama_n_ctx(self.ctx)):
             raise ValueError(
                 f"Requested tokens exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}"
@@ -392,6 +410,9 @@ class Llama:
         if logprobs is not None:
             raise NotImplementedError("logprobs not implemented")
 
+        if self.verbose:
+            llama_cpp.llama_print_timings(self.ctx)
+
         yield {
             "id": completion_id,
             "object": "text_completion",
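
A minimal usage sketch of the new flag, assuming the top-level `Llama` import from the `llama_cpp` package; the model path below is a placeholder and must point to a real ggml model file, since the constructor raises ValueError for a missing path. With the default `verbose=True`, llama.cpp system info is printed to stderr at load time and per-call timings after embedding and completion calls; `verbose=False` suppresses that output.

import sys
from llama_cpp import Llama

# Placeholder model path; substitute a real ggml model file on disk.
MODEL_PATH = "./models/7B/ggml-model.bin"

# Default after this commit: verbose=True prints llama.cpp system info to stderr
# at load time, and timings after embedding/completion calls.
llm = Llama(model_path=MODEL_PATH)

# Passing verbose=False keeps stderr quiet.
quiet_llm = Llama(model_path=MODEL_PATH, verbose=False)

# The flag is stored on the instance as self.verbose.
print("quiet" if not quiet_llm.verbose else "verbose", file=sys.stderr)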