diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index cd737c5..3ff94a6 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -76,6 +76,7 @@ class Llama:
             maxlen=self.last_n_tokens_size,
         )
         self.tokens_consumed = 0
+        self.tokens: List[llama_cpp.llama_token] = []
         self.n_batch = min(n_ctx, n_batch)
         self.n_tokens = 0
         self.n_past = 0
@@ -140,6 +141,7 @@ class Llama:
             [llama_cpp.llama_token(0)] * self.last_n_tokens_size
         )
         self.tokens_consumed = 0
+        self.tokens.clear()
         self.n_tokens = 0
         self.n_past = 0
         self.all_logits = []
@@ -165,6 +167,7 @@ class Llama:
             )
             if int(return_code) != 0:
                 raise RuntimeError(f"llama_eval returned {return_code}")
+            self.tokens.extend(batch)
             self.last_n_tokens_data.extend(batch)
             self.tokens_consumed += len(batch)
             if self.params.logits_all: