Fix decode errors permanently by ignoring invalid UTF-8 bytes when decoding

This commit is contained in:
Mug
2023-04-26 14:37:06 +02:00
parent 1b73a15e62
commit c4a8491d42
3 changed files with 13 additions and 10 deletions

View File

@@ -96,7 +96,7 @@ specified) expect poor results""", file=sys.stderr)
print(file=sys.stderr)
print(f"system_info: n_threads = {self.params.n_threads} / {cpu_count()} \
| {llama_cpp.llama_print_system_info().decode('utf8')}", file=sys.stderr)
| {llama_cpp.llama_print_system_info().decode('utf8', errors='ignore')}", file=sys.stderr)
# determine the required inference memory per token:
if (self.params.mem_test):
@@ -342,7 +342,7 @@ n_keep = {self.params.n_keep}
# return past text
def past(self):
for id in self.last_n_tokens[-self.n_past:]:
yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf-8")
yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf-8", errors="ignore")
# write input
def input(self, prompt: str):
@@ -356,7 +356,10 @@ n_keep = {self.params.n_keep}
def output(self):
self.remaining_tokens = self.params.n_predict
for id in self.generate():
yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf-8")
try:
yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf-8", errors="ignore")
except UnicodeDecodeError:
pass
# read user input
def read_input(self):

View File

@@ -70,7 +70,7 @@ while remaining_tokens > 0:
if not input_noecho:
for id in embd:
print(
llama_cpp.llama_token_to_str(ctx, id).decode("utf-8"),
llama_cpp.llama_token_to_str(ctx, id).decode("utf-8", errors="ignore"),
end="",
flush=True,
)