diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index b38f2bb..bec5be7 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -638,7 +638,7 @@ class Llama:
                     for token in all_tokens
                 ]
                 all_logprobs = [
-                    [Llama.logit_to_logprob(logit) for logit in row]
+                    Llama.logits_to_logprobs(row)
                     for row in self.eval_logits
                 ]
                 for token, token_str, logprobs_token in zip(
@@ -980,5 +980,15 @@ class Llama:
         return llama_cpp.llama_token_bos()
 
     @staticmethod
-    def logit_to_logprob(x: float) -> float:
-        return math.log(1.0 + math.exp(x))
+    def logits_to_logprobs(logits: List[llama_cpp.c_float]) -> List[llama_cpp.c_float]:
+        """Convert one row of raw logits to log-probabilities (log-softmax).
+
+        The largest logit is subtracted before exponentiating so math.exp
+        cannot overflow and math.log never receives an underflowed zero.
+        """
+        values = [float(x) for x in logits]
+        if not values:
+            return []
+        peak = max(values)
+        log_norm = peak + math.log(sum(math.exp(v - peak) for v in values))
+        return [llama_cpp.c_float(v - log_norm) for v in values]