From 98bbd1c6a8ea1f86c010583f6b1ab74996a1c751 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Fri, 5 May 2023 14:23:14 -0400 Subject: [PATCH] Fix eval logits type --- llama_cpp/llama.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 6cd65a4..a643f51 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -127,7 +127,7 @@ class Llama: self.last_n_tokens_size = last_n_tokens_size self.n_batch = min(n_ctx, n_batch) self.eval_tokens: Deque[llama_cpp.llama_token] = deque(maxlen=n_ctx) - self.eval_logits: Deque[List[llama_cpp.c_float]] = deque( + self.eval_logits: Deque[List[float]] = deque( maxlen=n_ctx if logits_all else 1 ) @@ -245,7 +245,7 @@ class Llama: n_vocab = llama_cpp.llama_n_vocab(self.ctx) cols = int(n_vocab) logits_view = llama_cpp.llama_get_logits(self.ctx) - logits: List[List[llama_cpp.c_float]] = [ + logits: List[List[float]] = [ [logits_view[i * cols + j] for j in range(cols)] for i in range(rows) ] self.eval_logits.extend(logits) @@ -287,7 +287,7 @@ class Llama: candidates=llama_cpp.ctypes.pointer(candidates), penalty=repeat_penalty, ) - if float(temp) == 0.0: + if float(temp.value) == 0.0: return llama_cpp.llama_sample_token_greedy( ctx=self.ctx, candidates=llama_cpp.ctypes.pointer(candidates),