diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index dc0f38b..9ae2a30 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -1,3 +1,4 @@ +import os import uuid import time import multiprocessing @@ -35,6 +36,9 @@ class Llama: self.tokens = (llama_cpp.llama_token * self.params.n_ctx)() + if not os.path.exists(model_path): + raise ValueError(f"Model path does not exist: {model_path}") + self.ctx = llama_cpp.llama_init_from_file( self.model_path.encode("utf-8"), self.params ) @@ -66,6 +70,8 @@ class Llama: llama_cpp.llama_n_ctx(self.ctx), True, ) + if prompt_tokens < 0: + raise RuntimeError(f"Failed to tokenize prompt: {prompt_tokens}") if prompt_tokens + max_tokens > self.params.n_ctx: raise ValueError( @@ -115,13 +121,15 @@ class Llama: finish_reason = "stop" break - llama_cpp.llama_eval( + rc = llama_cpp.llama_eval( self.ctx, (llama_cpp.llama_token * 1)(self.tokens[prompt_tokens + i]), 1, prompt_tokens + completion_tokens, self.n_threads, ) + if rc != 0: + raise RuntimeError(f"Failed to evaluate next token: {rc}") text = text.decode("utf-8")