From b5f3e746275bf231df544c60f30b80f537195af7 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Fri, 5 May 2023 14:22:55 -0400 Subject: [PATCH] Add return type annotations for embeddings and logits --- llama_cpp/llama_cpp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index 3b1ac1e..ccec12c 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -387,7 +387,7 @@ _lib.llama_n_embd.restype = c_int # Can be mutated in order to change the probabilities of the next token # Rows: n_tokens # Cols: n_vocab -def llama_get_logits(ctx: llama_context_p): +def llama_get_logits(ctx: llama_context_p): # type: (...) -> Array[float] # type: ignore return _lib.llama_get_logits(ctx) @@ -397,7 +397,7 @@ _lib.llama_get_logits.restype = c_float_p # Get the embeddings for the input # shape: [n_embd] (1-dimensional) -def llama_get_embeddings(ctx: llama_context_p): +def llama_get_embeddings(ctx: llama_context_p): # type: (...) -> Array[float] # type: ignore return _lib.llama_get_embeddings(ctx)