diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py
index a6da424..51d237b 100644
--- a/llama_cpp/llama_cpp.py
+++ b/llama_cpp/llama_cpp.py
@@ -1,6 +1,6 @@
 import ctypes
 
-from ctypes import c_int, c_float, c_char_p, c_void_p, c_bool, POINTER, Structure, Array
+from ctypes import c_int, c_float, c_char_p, c_void_p, c_bool, POINTER, Structure, Array, c_uint8, c_size_t
 
 import pathlib
 from itertools import chain
@@ -101,6 +101,36 @@ def llama_model_quantize(
 _lib.llama_model_quantize.argtypes = [c_char_p, c_char_p, c_int, c_int]
 _lib.llama_model_quantize.restype = c_int
 
+# Returns the KV cache that will contain the context for the
+# ongoing prediction with the model.
+def llama_get_kv_cache(ctx: llama_context_p):
+    return _lib.llama_get_kv_cache(ctx)
+
+_lib.llama_get_kv_cache.argtypes = [llama_context_p]
+_lib.llama_get_kv_cache.restype = POINTER(c_uint8)
+
+# Returns the size of the KV cache
+def llama_get_kv_cache_size(ctx: llama_context_p) -> c_size_t:
+    return _lib.llama_get_kv_cache_size(ctx)
+
+_lib.llama_get_kv_cache_size.argtypes = [llama_context_p]
+_lib.llama_get_kv_cache_size.restype = c_size_t
+
+# Returns the number of tokens in the KV cache
+def llama_get_kv_cache_token_count(ctx: llama_context_p) -> c_int:
+    return _lib.llama_get_kv_cache_token_count(ctx)
+
+_lib.llama_get_kv_cache_token_count.argtypes = [llama_context_p]
+_lib.llama_get_kv_cache_token_count.restype = c_int
+
+
+# Sets the KV cache containing the current context for the model
+def llama_set_kv_cache(ctx: llama_context_p, kv_cache, n_size: c_size_t, n_token_count: c_int):
+    return _lib.llama_set_kv_cache(ctx, kv_cache, n_size, n_token_count)
+
+_lib.llama_set_kv_cache.argtypes = [llama_context_p, POINTER(c_uint8), c_size_t, c_int]
+_lib.llama_set_kv_cache.restype = None
+
 
 # Run the llama inference to obtain the logits and probabilities for the next token.
 # tokens + n_tokens is the provided batch of new tokens to process
diff --git a/vendor/llama.cpp b/vendor/llama.cpp
index d0a7f74..d8d4e86 160000
--- a/vendor/llama.cpp
+++ b/vendor/llama.cpp
@@ -1 +1 @@
-Subproject commit d0a7f742e76bb48c0bd852f0b3bf09ec0b75b200
+Subproject commit d8d4e865cd481b18f10508ffee35db903767ef5c
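
Taken together, these four bindings let a caller checkpoint a context and roll it back later, avoiding re-evaluation of a long prompt. Below is a minimal usage sketch, not part of this change: it assumes the pre-existing `llama_context_default_params`, `llama_init_from_file`, and `llama_eval` bindings from the same module, and the model path is a placeholder. Since `llama_get_kv_cache` returns a raw pointer into memory presumably owned by the context, the bytes are copied into Python-owned storage before the context is reused.

```python
import ctypes
import llama_cpp

# Placeholder model path; llama_init_from_file takes a bytes path in these bindings.
params = llama_cpp.llama_context_default_params()
ctx = llama_cpp.llama_init_from_file(b"models/7B/ggml-model-q4_0.bin", params)

# ... evaluate the prompt tokens here with llama_cpp.llama_eval(...) ...

# Snapshot: copy the cache bytes out of the context into Python-owned memory.
# restype c_size_t / c_int means ctypes hands back plain Python ints here.
size = llama_cpp.llama_get_kv_cache_size(ctx)
n_tokens = llama_cpp.llama_get_kv_cache_token_count(ctx)
snapshot = (ctypes.c_uint8 * size)()
ctypes.memmove(snapshot, llama_cpp.llama_get_kv_cache(ctx), size)

# ... generate a speculative continuation that should be discarded ...

# Restore: a c_uint8 array is accepted by ctypes where POINTER(c_uint8)
# is declared, so the saved buffer and token count go straight back in.
llama_cpp.llama_set_kv_cache(ctx, snapshot, size, n_tokens)
```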