Merge branch 'main' of github.com:abetlen/llama-cpp-python into better-server-params-and-fields
llama_cpp/llama.py

@@ -53,12 +53,14 @@ class LlamaState:
     def __init__(
         self,
         eval_tokens: Deque[llama_cpp.llama_token],
-        eval_logits: Deque[List[float]],
+        eval_logits: Deque[List[llama_cpp.c_float]],
         llama_state,
+        llama_state_size: llama_cpp.c_size_t,
     ):
         self.eval_tokens = eval_tokens
         self.eval_logits = eval_logits
         self.llama_state = llama_state
+        self.llama_state_size = llama_state_size


 class Llama:
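For orientation, here is a minimal, hypothetical construction of the widened LlamaState. The field names come from this hunk; the import path and dummy values are illustrative assumptions, not part of the commit.

# Hypothetical construction of the widened LlamaState; the import path and
# the dummy payload are illustrative assumptions, not part of this commit.
import ctypes
from collections import deque

from llama_cpp.llama import LlamaState  # assumed import path

buf = (ctypes.c_uint8 * 4)(1, 2, 3, 4)   # stand-in for the copied llama state
state = LlamaState(
    eval_tokens=deque([0, 1, 2]),        # Deque[llama_cpp.llama_token]
    eval_logits=deque(),                 # Deque[List[llama_cpp.c_float]]
    llama_state=buf,
    llama_state_size=4,                  # new field: exact byte count of llama_state
)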
@@ -394,7 +396,7 @@ class Llama:
             and tuple(self.eval_tokens) == tuple(tokens[: len(self.eval_tokens)])
         ):
             if self.verbose:
-                print("generate cache hit", file=sys.stderr)
+                print("Llama.generate: cache hit", file=sys.stderr)
             reset = False
             tokens = tokens[len(self.eval_tokens) :]

@@ -516,7 +518,7 @@ class Llama:

         if self.cache and prompt_tokens in self.cache:
             if self.verbose:
-                print("cache hit", file=sys.stderr)
+                print("Llama._create_completion: cache hit", file=sys.stderr)
             self.load_state(self.cache[prompt_tokens])

         finish_reason = "length"
@@ -536,7 +538,7 @@ class Llama:
             if self.cache and len(completion_tokens) == 0:
                 if prompt_tokens not in self.cache:
                     if self.verbose:
-                        print("cache miss", file=sys.stderr)
+                        print("Llama._create_completion: cache miss", file=sys.stderr)
                     self.cache[prompt_tokens] = self.save_state()

             completion_tokens.append(token)
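The two hunks above only rename the log messages, but for context here is a hedged usage sketch of the prompt-cache flow they report on. The model path is a placeholder, and set_cache/LlamaCache are assumed to match the package API at the time of this commit.

# Hedged sketch of the prompt cache flow behind the renamed messages.
# The model path is a placeholder; set_cache/LlamaCache are assumed APIs.
import llama_cpp

llm = llama_cpp.Llama(model_path="./models/ggml-model.bin", verbose=True)
llm.set_cache(llama_cpp.LlamaCache())

prompt = "Q: Name the planets in the solar system. A:"

# First call: expect "Llama._create_completion: cache miss"; the evaluated
# prompt state is stored via save_state().
llm(prompt, max_tokens=8)

# Same prompt again: expect "Llama._create_completion: cache hit";
# load_state() restores the stored state instead of re-evaluating the prompt.
llm(prompt, max_tokens=8)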
@@ -950,19 +952,25 @@ class Llama:
         assert self.ctx is not None
         state_size = llama_cpp.llama_get_state_size(self.ctx)
         llama_state = (llama_cpp.c_uint8 * int(state_size))()
-        if llama_cpp.llama_copy_state_data(self.ctx, llama_state) != state_size:
+        n_bytes = llama_cpp.llama_copy_state_data(self.ctx, llama_state)
+        if int(n_bytes) > int(state_size):
             raise RuntimeError("Failed to copy llama state data")
+        llama_state_compact = (llama_cpp.c_uint8 * int(n_bytes))()
+        llama_cpp.ctypes.memmove(llama_state_compact, llama_state, int(n_bytes))
+        if self.verbose:
+            print(f"Llama.save_state: saving {n_bytes} bytes of llama state", file=sys.stderr)
         return LlamaState(
             eval_tokens=self.eval_tokens.copy(),
             eval_logits=self.eval_logits.copy(),
-            llama_state=llama_state,
+            llama_state=llama_state_compact,
+            llama_state_size=n_bytes,
         )

     def load_state(self, state: LlamaState) -> None:
         assert self.ctx is not None
         self.eval_tokens = state.eval_tokens.copy()
         self.eval_logits = state.eval_logits.copy()
-        state_size = llama_cpp.llama_get_state_size(self.ctx)
+        state_size = state.llama_state_size
         if llama_cpp.llama_set_state_data(self.ctx, state.llama_state) != state_size:
             raise RuntimeError("Failed to set llama state data")

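The substance of this hunk is the compaction step: copy the state into a buffer sized for the worst case, then keep only the bytes actually written. A minimal, self-contained sketch of that pattern with plain ctypes and a dummy payload (not the repo's own API):

# Minimal sketch of the compaction pattern used in save_state(): fill a
# worst-case sized buffer, then memmove only the bytes actually used into
# an exactly-sized buffer. Plain ctypes, dummy payload.
import ctypes

max_size = 1024                                   # analogous to llama_get_state_size()
scratch = (ctypes.c_uint8 * max_size)()
payload = b"rng+logits+embedding+kv_cache"
ctypes.memmove(scratch, payload, len(payload))    # analogous to llama_copy_state_data()

n_bytes = len(payload)                            # actual bytes written
compact = (ctypes.c_uint8 * n_bytes)()            # exactly-sized buffer
ctypes.memmove(compact, scratch, n_bytes)         # keep only what is used

assert bytes(compact) == payload                  # nothing lost, nothing extra kept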
llama_cpp/llama_cpp.py

@@ -71,7 +71,7 @@ LLAMA_FILE_VERSION = ctypes.c_int(1)
 LLAMA_FILE_MAGIC = b"ggjt"
 LLAMA_FILE_MAGIC_UNVERSIONED = b"ggml"
 LLAMA_SESSION_MAGIC = b"ggsn"
-LLAMA_SESSION_VERSION = ctypes.c_int(0)
+LLAMA_SESSION_VERSION = ctypes.c_int(1)

 llama_context_p = c_void_p
@@ -136,9 +136,9 @@ LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = ctypes.c_int(
 ) # tok_embeddings.weight and output.weight are F16
 LLAMA_FTYPE_MOSTLY_Q4_2 = ctypes.c_int(5) # except 1d tensors
 # LLAMA_FTYPE_MOSTYL_Q4_3 = ctypes.c_int(6) # except 1d tensors
-LLAMA_FTYPE_MOSTYL_Q8_0 = ctypes.c_int(7) # except 1d tensors
-LLAMA_FTYPE_MOSTYL_Q5_0 = ctypes.c_int(8) # except 1d tensors
-LLAMA_FTYPE_MOSTYL_Q5_1 = ctypes.c_int(9) # except 1d tensors
+LLAMA_FTYPE_MOSTLY_Q8_0 = ctypes.c_int(7) # except 1d tensors
+LLAMA_FTYPE_MOSTLY_Q5_0 = ctypes.c_int(8) # except 1d tensors
+LLAMA_FTYPE_MOSTLY_Q5_1 = ctypes.c_int(9) # except 1d tensors

 # Functions
@@ -239,7 +239,8 @@ _lib.llama_set_rng_seed.argtypes = [llama_context_p, c_int]
 _lib.llama_set_rng_seed.restype = None


-# Returns the size in bytes of the state (rng, logits, embedding and kv_cache)
+# Returns the maximum size in bytes of the state (rng, logits, embedding
+# and kv_cache) - will often be smaller after compacting tokens
 def llama_get_state_size(ctx: llama_context_p) -> c_size_t:
     return _lib.llama_get_state_size(ctx)
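Read together with the save_state change above, the reworded comment says the return value is an upper bound, not the exact size. A hedged sketch of the intended calling pattern (assumes `ctx` is a valid context created through these bindings):

# Hedged sketch of the intended calling pattern: size the buffer with the
# upper bound, and take the actual byte count from llama_copy_state_data().
# Assumes `ctx` is a valid context created through these bindings.
from llama_cpp import llama_cpp


def snapshot(ctx) -> bytes:
    max_size = llama_cpp.llama_get_state_size(ctx)       # maximum size in bytes
    buf = (llama_cpp.c_uint8 * int(max_size))()
    n_bytes = llama_cpp.llama_copy_state_data(ctx, buf)  # bytes actually written
    return bytes(buf)[: int(n_bytes)]                    # keep only the used prefix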