Merge pull request #118 from SagsMug/main

Fix UnicodeDecodeError permanently
This commit is contained in:
Andrei
2023-04-29 07:19:01 -04:00
committed by GitHub
4 changed files with 59 additions and 10 deletions

View File

@@ -446,6 +446,7 @@ class Llama:
self.load_state(self.cache[prompt_tokens])
finish_reason = "length"
multibyte_fix = 0
for token in self.generate(
prompt_tokens,
top_k=top_k,
@@ -467,6 +468,20 @@ class Llama:
completion_tokens.append(token)
all_text = self.detokenize(completion_tokens)
# Contains multi-byte UTF8
for k,char in enumerate(all_text[-3:]):
k = 3 - k
for num,pattern in [(2, 192), (3, 224), (4, 240)]:
# Bitwise AND check
if (num > k and pattern & char == pattern):
multibyte_fix = num - k
# Stop incomplete bytes from passing
if (multibyte_fix > 0):
multibyte_fix -= 1
continue
any_stop = [s for s in stop_sequences if s in all_text]
if len(any_stop) > 0:
first_stop = any_stop[0]
@@ -495,7 +510,7 @@ class Llama:
"model": self.model_path,
"choices": [
{
"text": text[start:].decode("utf-8"),
"text": text[start:].decode("utf-8", errors="ignore"),
"index": 0,
"logprobs": None,
"finish_reason": None,
@@ -516,7 +531,7 @@ class Llama:
"model": self.model_path,
"choices": [
{
"text": text[returned_characters:].decode("utf-8"),
"text": text[returned_characters:].decode("utf-8", errors="ignore"),
"index": 0,
"logprobs": None,
"finish_reason": finish_reason,
@@ -525,7 +540,7 @@ class Llama:
}
return
text_str = text.decode("utf-8")
text_str = text.decode("utf-8", errors="ignore")
if echo:
text_str = prompt + text_str
@@ -543,7 +558,7 @@ class Llama:
all_tokens = prompt_tokens + completion_tokens
all_token_strs = [
self.detokenize([token]).decode("utf-8") for token in all_tokens
self.detokenize([token]).decode("utf-8", errors="ignore") for token in all_tokens
]
all_logprobs = [
[Llama.logit_to_logprob(logit) for logit in row]
@@ -562,7 +577,7 @@ class Llama:
)
token_logprobs.append(sorted_logprobs[int(token)][0])
top_logprob = {
self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8"): logprob
self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8", errors="ignore"): logprob
for logprob, i in sorted_logprobs[:logprobs]
}
top_logprob.update({token_str: sorted_logprobs[int(token)][0]})