From 670d3900014730ff80fe1a23284158fb981da543 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Fri, 31 Mar 2023 03:20:15 -0400 Subject: [PATCH] Fix ctypes typing issue for Arrays --- llama_cpp/llama_cpp.py | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index 5980430..fdad289 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -1,14 +1,6 @@ import ctypes -from ctypes import ( - c_int, - c_float, - c_char_p, - c_void_p, - c_bool, - POINTER, - Structure, -) +from ctypes import c_int, c_float, c_char_p, c_void_p, c_bool, POINTER, Structure, Array import pathlib from itertools import chain @@ -116,7 +108,7 @@ _lib.llama_model_quantize.restype = c_int # Returns 0 on success def llama_eval( ctx: llama_context_p, - tokens: ctypes.Array[llama_token], + tokens, # type: Array[llama_token] n_tokens: c_int, n_past: c_int, n_threads: c_int, @@ -136,7 +128,7 @@ _lib.llama_eval.restype = c_int def llama_tokenize( ctx: llama_context_p, text: bytes, - tokens: ctypes.Array[llama_token], + tokens, # type: Array[llama_token] n_max_tokens: c_int, add_bos: c_bool, ) -> c_int: @@ -176,7 +168,7 @@ _lib.llama_n_embd.restype = c_int # Can be mutated in order to change the probabilities of the next token # Rows: n_tokens # Cols: n_vocab -def llama_get_logits(ctx: llama_context_p) -> ctypes.Array[c_float]: +def llama_get_logits(ctx: llama_context_p): return _lib.llama_get_logits(ctx) @@ -186,7 +178,7 @@ _lib.llama_get_logits.restype = POINTER(c_float) # Get the embeddings for the input # shape: [n_embd] (1-dimensional) -def llama_get_embeddings(ctx: llama_context_p) -> ctypes.Array[c_float]: +def llama_get_embeddings(ctx: llama_context_p): return _lib.llama_get_embeddings(ctx) @@ -224,7 +216,7 @@ _lib.llama_token_eos.restype = llama_token # TODO: improve the last_n_tokens interface ? def llama_sample_top_p_top_k( ctx: llama_context_p, - last_n_tokens_data: ctypes.Array[llama_token], + last_n_tokens_data, # type: Array[llama_token] last_n_tokens_size: c_int, top_k: c_int, top_p: c_float,