diff --git a/examples/low_level_api_llama_cpp.py b/examples/low_level_api_llama_cpp.py
index 4a888c3..2a639aa 100644
--- a/examples/low_level_api_llama_cpp.py
+++ b/examples/low_level_api_llama_cpp.py
@@ -35,7 +35,7 @@
 remaining_tokens = n_predict
 embd = []
 last_n_size = 64
-last_n_tokens = [0] * last_n_size
+last_n_tokens_data = [0] * last_n_size
 n_batch = 24
 
 while remaining_tokens > 0:
@@ -49,21 +49,21 @@ while remaining_tokens > 0:
     if len(embd_inp) <= input_consumed:
         id = llama_cpp.llama_sample_top_p_top_k(
             ctx,
-            (llama_cpp.c_int * len(last_n_tokens))(*last_n_tokens),
-            len(last_n_tokens),
+            (llama_cpp.c_int * len(last_n_tokens_data))(*last_n_tokens_data),
+            len(last_n_tokens_data),
             40,
             0.8,
             0.2,
             1.0 / 0.85,
         )
-        last_n_tokens = last_n_tokens[1:] + [id]
+        last_n_tokens_data = last_n_tokens_data[1:] + [id]
         embd.append(id)
         input_noecho = False
         remaining_tokens -= 1
     else:
         while len(embd_inp) > input_consumed:
             embd.append(embd_inp[input_consumed])
-            last_n_tokens = last_n_tokens[1:] + [embd_inp[input_consumed]]
+            last_n_tokens_data = last_n_tokens_data[1:] + [embd_inp[input_consumed]]
             input_consumed += 1
             if len(embd) >= n_batch:
                 break
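
Note: the rename to `last_n_tokens_data` presumably mirrors the parameter name of `llama_sample_top_p_top_k` in the underlying llama.cpp API; the Python list is converted to a C int array on every call. A minimal sketch of that pattern, assuming a loaded `ctx`; the `sample_next_token` helper name and the parameter labels in the comments are illustrative assumptions, while the values are the ones used in the example above:

    import llama_cpp

    last_n_size = 64
    last_n_tokens_data = [0] * last_n_size  # rolling window of recent token ids

    def sample_next_token(ctx):
        # Build the C int array the binding expects from the Python list,
        # then sample with the same values used in the example above.
        c_last_n = (llama_cpp.c_int * len(last_n_tokens_data))(*last_n_tokens_data)
        return llama_cpp.llama_sample_top_p_top_k(
            ctx,
            c_last_n,
            len(last_n_tokens_data),
            40,          # assumed: top_k
            0.8,         # assumed: top_p
            0.2,         # assumed: temp
            1.0 / 0.85,  # assumed: repeat_penalty
        )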