mirror of
https://github.com/abetlen/llama-cpp-python.git
synced 2023-09-07 17:34:22 +03:00
Detect multi-byte UTF-8 responses and wait for the remaining bytes
This commit is contained in:
@@ -96,7 +96,7 @@ specified) expect poor results""", file=sys.stderr)
|
|||||||
|
|
||||||
print(file=sys.stderr)
|
print(file=sys.stderr)
|
||||||
print(f"system_info: n_threads = {self.params.n_threads} / {cpu_count()} \
|
print(f"system_info: n_threads = {self.params.n_threads} / {cpu_count()} \
|
||||||
| {llama_cpp.llama_print_system_info().decode('utf8', errors='ignore')}", file=sys.stderr)
|
| {llama_cpp.llama_print_system_info().decode('utf8')}", file=sys.stderr)
|
||||||
|
|
||||||
# determine the required inference memory per token:
|
# determine the required inference memory per token:
|
||||||
if (self.params.mem_test):
|
if (self.params.mem_test):
|
||||||
|
|||||||
@@ -159,7 +159,7 @@ class Llama:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
print(llama_cpp.llama_print_system_info().decode("utf-8", errors="ignore"), file=sys.stderr)
|
print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
|
||||||
|
|
||||||
def tokenize(self, text: bytes) -> List[llama_cpp.llama_token]:
|
def tokenize(self, text: bytes) -> List[llama_cpp.llama_token]:
|
||||||
"""Tokenize a string.
|
"""Tokenize a string.
|
||||||
@@ -446,6 +446,7 @@ class Llama:
|
|||||||
self.load_state(self.cache[prompt_tokens])
|
self.load_state(self.cache[prompt_tokens])
|
||||||
|
|
||||||
finish_reason = "length"
|
finish_reason = "length"
|
||||||
|
multibyte_fix = 0
|
||||||
for token in self.generate(
|
for token in self.generate(
|
||||||
prompt_tokens,
|
prompt_tokens,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
@@ -458,6 +459,12 @@ class Llama:
|
|||||||
finish_reason = "stop"
|
finish_reason = "stop"
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# Check whether the token value carries a multi-byte UTF-8 lead-byte bit pattern (2, 3 or 4 bytes)
|
||||||
|
for num,pattern in [(2, 192), (3, 224), (4, 240)]:
|
||||||
|
# Bitwise AND check: matches when every bit of the pattern is set in the token
|
||||||
|
if (pattern & token == pattern):
|
||||||
|
multibyte_fix = num
|
||||||
|
|
||||||
if self.cache and len(completion_tokens) == 0:
|
if self.cache and len(completion_tokens) == 0:
|
||||||
if prompt_tokens not in self.cache:
|
if prompt_tokens not in self.cache:
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
@@ -466,6 +473,11 @@ class Llama:
|
|||||||
|
|
||||||
completion_tokens.append(token)
|
completion_tokens.append(token)
|
||||||
|
|
||||||
|
# While a multi-byte sequence is pending, count this token toward it and skip the rest of the iteration
|
||||||
|
if (multibyte_fix > 0):
|
||||||
|
multibyte_fix -= 1
|
||||||
|
continue
|
||||||
|
|
||||||
all_text = self.detokenize(completion_tokens)
|
all_text = self.detokenize(completion_tokens)
|
||||||
any_stop = [s for s in stop_sequences if s in all_text]
|
any_stop = [s for s in stop_sequences if s in all_text]
|
||||||
if len(any_stop) > 0:
|
if len(any_stop) > 0:
|
||||||
|
|||||||
Reference in New Issue
Block a user