From c784d83131f4c695b57fe2c5a4143432f4106cc1 Mon Sep 17 00:00:00 2001
From: Andrei Betlen
Date: Fri, 24 Mar 2023 14:58:42 -0400
Subject: [PATCH] Update llama.cpp and re-organize low-level api

---
 llama_cpp/llama_cpp.py | 179 ++++++++++++++++++++++++-----------------
 vendor/llama.cpp       |   2 +-
 2 files changed, 107 insertions(+), 74 deletions(-)

diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py
index 0947187..6ae8aa4 100644
--- a/llama_cpp/llama_cpp.py
+++ b/llama_cpp/llama_cpp.py
@@ -19,6 +19,9 @@ lib = ctypes.CDLL(str(libfile))

 # C types
+llama_context_p = c_void_p
+
+
 llama_token = c_int
 llama_token_p = POINTER(llama_token)


@@ -45,98 +48,60 @@ class llama_context_params(Structure):
             c_bool,
         ),  # the llama_eval() call computes all logits, not just the last one
         ("vocab_only", c_bool),  # only load the vocabulary, no weights
+        ("use_mlock", c_bool),  # force system to keep model in RAM
+        ("embedding", c_bool),  # embedding mode only
     ]


 llama_context_params_p = POINTER(llama_context_params)

-llama_context_p = c_void_p
-

-# C functions
-lib.llama_context_default_params.argtypes = []
-lib.llama_context_default_params.restype = llama_context_params
-
-lib.llama_init_from_file.argtypes = [c_char_p, llama_context_params]
-lib.llama_init_from_file.restype = llama_context_p
-
-lib.llama_free.argtypes = [llama_context_p]
-lib.llama_free.restype = None
-
-lib.llama_model_quantize.argtypes = [c_char_p, c_char_p, c_int, c_int]
-lib.llama_model_quantize.restype = c_int
-
-lib.llama_eval.argtypes = [llama_context_p, llama_token_p, c_int, c_int, c_int]
-lib.llama_eval.restype = c_int
-
-lib.llama_tokenize.argtypes = [llama_context_p, c_char_p, llama_token_p, c_int, c_bool]
-lib.llama_tokenize.restype = c_int
-
-lib.llama_n_vocab.argtypes = [llama_context_p]
-lib.llama_n_vocab.restype = c_int
-
-lib.llama_n_ctx.argtypes = [llama_context_p]
-lib.llama_n_ctx.restype = c_int
-
-lib.llama_get_logits.argtypes = [llama_context_p]
-lib.llama_get_logits.restype = POINTER(c_float)
-
-lib.llama_token_to_str.argtypes = [llama_context_p, llama_token]
-lib.llama_token_to_str.restype = c_char_p
-
-lib.llama_token_bos.argtypes = []
-lib.llama_token_bos.restype = llama_token
-
-lib.llama_token_eos.argtypes = []
-lib.llama_token_eos.restype = llama_token
-
-lib.llama_sample_top_p_top_k.argtypes = [
-    llama_context_p,
-    llama_token_p,
-    c_int,
-    c_int,
-    c_double,
-    c_double,
-    c_double,
-]
-lib.llama_sample_top_p_top_k.restype = llama_token
-
-lib.llama_print_timings.argtypes = [llama_context_p]
-lib.llama_print_timings.restype = None
-
-lib.llama_reset_timings.argtypes = [llama_context_p]
-lib.llama_reset_timings.restype = None
-
-lib.llama_print_system_info.argtypes = []
-lib.llama_print_system_info.restype = c_char_p
+# Functions


-# Python functions
 def llama_context_default_params() -> llama_context_params:
     params = lib.llama_context_default_params()
     return params


+lib.llama_context_default_params.argtypes = []
+lib.llama_context_default_params.restype = llama_context_params
+
+# Various functions for loading a ggml llama model.
+# Allocate (almost) all memory needed for the model.
+# Return NULL on failure
 def llama_init_from_file(
     path_model: bytes, params: llama_context_params
 ) -> llama_context_p:
-    """Various functions for loading a ggml llama model.
-    Allocate (almost) all memory needed for the model.
-    Return NULL on failure"""
     return lib.llama_init_from_file(path_model, params)


+lib.llama_init_from_file.argtypes = [c_char_p, llama_context_params]
+lib.llama_init_from_file.restype = llama_context_p
+
+# Frees all allocated memory
 def llama_free(ctx: llama_context_p):
-    """Free all allocated memory"""
     lib.llama_free(ctx)


+lib.llama_free.argtypes = [llama_context_p]
+lib.llama_free.restype = None
+
+# TODO: not great API - very likely to change
+# Returns 0 on success
 def llama_model_quantize(
     fname_inp: bytes, fname_out: bytes, itype: c_int, qk: c_int
 ) -> c_int:
-    """Returns 0 on success"""
     return lib.llama_model_quantize(fname_inp, fname_out, itype, qk)


+lib.llama_model_quantize.argtypes = [c_char_p, c_char_p, c_int, c_int]
+lib.llama_model_quantize.restype = c_int
+
+# Run the llama inference to obtain the logits and probabilities for the next token.
+# tokens + n_tokens is the provided batch of new tokens to process
+# n_past is the number of tokens to use from previous eval calls
+# Returns 0 on success
 def llama_eval(
     ctx: llama_context_p,
     tokens: llama_token_p,
@@ -144,13 +109,18 @@ def llama_eval(
     n_past: c_int,
     n_threads: c_int,
 ) -> c_int:
-    """Run the llama inference to obtain the logits and probabilities for the next token.
-    tokens + n_tokens is the provided batch of new tokens to process
-    n_past is the number of tokens to use from previous eval calls
-    Returns 0 on success"""
     return lib.llama_eval(ctx, tokens, n_tokens, n_past, n_threads)


+lib.llama_eval.argtypes = [llama_context_p, llama_token_p, c_int, c_int, c_int]
+lib.llama_eval.restype = c_int
+
+
+# Convert the provided text into tokens.
+# The tokens pointer must be large enough to hold the resulting tokens.
+# Returns the number of tokens on success, no more than n_max_tokens
+# Returns a negative number on failure - the number of tokens that would have been returned
+# TODO: not sure if correct
 def llama_tokenize(
     ctx: llama_context_p,
     text: bytes,
@@ -166,36 +136,72 @@ def llama_tokenize(
     return lib.llama_tokenize(ctx, text, tokens, n_max_tokens, add_bos)


+lib.llama_tokenize.argtypes = [llama_context_p, c_char_p, llama_token_p, c_int, c_bool]
+lib.llama_tokenize.restype = c_int
+
+
 def llama_n_vocab(ctx: llama_context_p) -> c_int:
     return lib.llama_n_vocab(ctx)


+lib.llama_n_vocab.argtypes = [llama_context_p]
+lib.llama_n_vocab.restype = c_int
+
+
 def llama_n_ctx(ctx: llama_context_p) -> c_int:
     return lib.llama_n_ctx(ctx)


+lib.llama_n_ctx.argtypes = [llama_context_p]
+lib.llama_n_ctx.restype = c_int
+
+# Token logits obtained from the last call to llama_eval()
+# The logits for the last token are stored in the last row
+# Can be mutated in order to change the probabilities of the next token
+# Rows: n_tokens
+# Cols: n_vocab
 def llama_get_logits(ctx: llama_context_p):
-    """Token logits obtained from the last call to llama_eval()
-    The logits for the last token are stored in the last row
-    Can be mutated in order to change the probabilities of the next token
-    Rows: n_tokens
-    Cols: n_vocab"""
     return lib.llama_get_logits(ctx)


+lib.llama_get_logits.argtypes = [llama_context_p]
+lib.llama_get_logits.restype = POINTER(c_float)
+
+# Get the embeddings for the input
+# shape: [n_embd] (1-dimensional)
+def llama_get_embeddings(ctx: llama_context_p):
+    return lib.llama_get_embeddings(ctx)
+
+lib.llama_get_embeddings.argtypes = [llama_context_p]
+lib.llama_get_embeddings.restype = POINTER(c_float)
+
+# Token Id -> String. Uses the vocabulary in the provided context
 def llama_token_to_str(ctx: llama_context_p, token: int) -> bytes:
-    """Token Id -> String. Uses the vocabulary in the provided context"""
     return lib.llama_token_to_str(ctx, token)


+lib.llama_token_to_str.argtypes = [llama_context_p, llama_token]
+lib.llama_token_to_str.restype = c_char_p
+
+# Special tokens
+
 def llama_token_bos() -> llama_token:
     return lib.llama_token_bos()


+lib.llama_token_bos.argtypes = []
+lib.llama_token_bos.restype = llama_token
+
+
 def llama_token_eos() -> llama_token:
     return lib.llama_token_eos()


+lib.llama_token_eos.argtypes = []
+lib.llama_token_eos.restype = llama_token
+
+
+# TODO: improve the last_n_tokens interface ?
 def llama_sample_top_p_top_k(
     ctx: llama_context_p,
     last_n_tokens_data: llama_token_p,
@@ -210,14 +216,41 @@ def llama_sample_top_p_top_k(
     )


+lib.llama_sample_top_p_top_k.argtypes = [
+    llama_context_p,
+    llama_token_p,
+    c_int,
+    c_int,
+    c_double,
+    c_double,
+    c_double,
+]
+lib.llama_sample_top_p_top_k.restype = llama_token
+
+
+# Performance information
+
 def llama_print_timings(ctx: llama_context_p):
     lib.llama_print_timings(ctx)


+lib.llama_print_timings.argtypes = [llama_context_p]
+lib.llama_print_timings.restype = None
+
+
 def llama_reset_timings(ctx: llama_context_p):
     lib.llama_reset_timings(ctx)


+lib.llama_reset_timings.argtypes = [llama_context_p]
+lib.llama_reset_timings.restype = None
+
+
+# Print system information
 def llama_print_system_info() -> bytes:
     """Print system information"""
     return lib.llama_print_system_info()
+
+
+lib.llama_print_system_info.argtypes = []
+lib.llama_print_system_info.restype = c_char_p
diff --git a/vendor/llama.cpp b/vendor/llama.cpp
index 4870e45..31572d9 160000
--- a/vendor/llama.cpp
+++ b/vendor/llama.cpp
@@ -1 +1 @@
-Subproject commit 4870e455b3653f7d7769fa5772b2c90ffad088df
+Subproject commit 31572d966531f7d768eb773322016ab78eb6e835
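
Usage note (not part of the diff): a minimal sketch of how the re-organized
low-level API might be driven end to end, assuming the package layout above
(llama_cpp/llama_cpp.py) and a locally built libllama.so. The model path,
prompt, sampling constants, iteration count, and thread count below are
illustrative placeholders, not values taken from this change; the seven
positional arguments to llama_sample_top_p_top_k simply follow the argtypes
declared in the patch.

    import llama_cpp.llama_cpp as llama

    # Placeholder path: point this at a real ggml model file.
    MODEL_PATH = b"./models/7B/ggml-model-q4_0.bin"

    # Load the model with default context parameters.
    params = llama.llama_context_default_params()
    ctx = llama.llama_init_from_file(MODEL_PATH, params)
    if not ctx:
        raise RuntimeError("llama_init_from_file returned NULL")

    # Tokenize the prompt into a buffer sized to the context window;
    # a ctypes array of llama_token decays to the llama_token_p argument.
    n_ctx = llama.llama_n_ctx(ctx)
    tokens = (llama.llama_token * n_ctx)()
    n_tokens = llama.llama_tokenize(ctx, b" Hello", tokens, n_ctx, True)
    if n_tokens < 0:
        raise RuntimeError("llama_tokenize failed")

    # Evaluate the prompt, then sample a few continuation tokens.
    n_past = 0
    for _ in range(8):
        if llama.llama_eval(ctx, tokens, n_tokens, n_past, 4) != 0:
            raise RuntimeError("llama_eval failed")
        n_past += n_tokens
        # Simplification: reuse the token buffer as the repeat-penalty window.
        token = llama.llama_sample_top_p_top_k(
            ctx, tokens, n_tokens, 40, 0.95, 0.80, 1.30
        )
        if token == llama.llama_token_eos():
            break
        print(llama.llama_token_to_str(ctx, token).decode("utf-8", errors="ignore"), end="")
        tokens[0] = token  # next eval call processes just the sampled token
        n_tokens = 1

    llama.llama_print_timings(ctx)
    llama.llama_free(ctx)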