Mirror of https://github.com/abetlen/llama-cpp-python.git (synced 2023-09-07 17:34:22 +03:00)
Bugfix: Stop sequences and missing max_tokens check

@@ -286,6 +286,7 @@ class Llama:
         # Add blank space to start of prompt to match OG llama tokenizer
         prompt_tokens = self.tokenize(b" " + prompt.encode("utf-8"))
         text = b""
+        returned_characters = 0
 
         if len(prompt_tokens) + max_tokens > int(llama_cpp.llama_n_ctx(self.ctx)):
             raise ValueError(
@@ -293,9 +294,9 @@ class Llama:
             )
 
         if stop != []:
-            stop_bytes = [s.encode("utf-8") for s in stop]
+            stop_sequences = [s.encode("utf-8") for s in stop]
         else:
-            stop_bytes = []
+            stop_sequences = []
 
         finish_reason = None
         for token in self.generate(
@@ -306,28 +307,33 @@ class Llama:
             repeat_penalty=repeat_penalty,
         ):
             if token == llama_cpp.llama_token_eos():
+                text = self.detokenize(completion_tokens)
                 finish_reason = "stop"
                 break
             completion_tokens.append(token)
 
-            text = self.detokenize(completion_tokens)
-            any_stop = [s for s in stop_bytes if s in text]
+            all_text = self.detokenize(completion_tokens)
+            any_stop = [s for s in stop_sequences if s in all_text]
             if len(any_stop) > 0:
                 first_stop = any_stop[0]
-                text = text[: text.index(first_stop)]
+                text = all_text[: all_text.index(first_stop)]
                 finish_reason = "stop"
                 break
 
             if stream:
-                start = len(self.detokenize(completion_tokens[:-1]))
+                start = returned_characters
                 longest = 0
-                # TODO: Clean up this mess
-                for s in stop_bytes:
+                # We want to avoid yielding any characters from
+                # the generated text if they are part of a stop
+                # sequence.
+                for s in stop_sequences:
                     for i in range(len(s), 0, -1):
-                        if s[-i:] == text[-i:]:
+                        if all_text.endswith(s[:i]):
                             if i > longest:
                                 longest = i
                             break
+                text = all_text[: len(all_text) - longest]
+                returned_characters += len(text[start:])
                 yield {
                     "id": completion_id,
                     "object": "text_completion",
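
The scan above is the heart of the streaming fix: before yielding a chunk, the code holds back the longest suffix of the generated text that matches a prefix of some stop sequence, so stop-sequence fragments are never streamed to the caller. A standalone sketch of the same idea, not part of the diff (the helper name longest_stop_prefix is ours):

def longest_stop_prefix(text: bytes, stop_sequences: list[bytes]) -> int:
    """Length of the longest prefix of any stop sequence that text ends with."""
    longest = 0
    for s in stop_sequences:
        # Try the longest candidate prefix first, then shorter ones.
        for i in range(len(s), 0, -1):
            if text.endswith(s[:i]):
                longest = max(longest, i)
                break
    return longest

# With stop sequence b"###", a text ending in b"##" holds back two bytes:
assert longest_stop_prefix(b"Hello ##", [b"###"]) == 2
assert longest_stop_prefix(b"Hello", [b"###"]) == 0

If the full stop sequence does arrive, the earlier "s in all_text" membership check truncates the text and ends generation; if it never completes, the held-back bytes are flushed by the final chunk.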
@@ -335,23 +341,22 @@ class Llama:
                     "model": self.model_path,
                     "choices": [
                         {
-                            "text": text[start : len(text) - longest].decode("utf-8"),
+                            "text": text[start:].decode("utf-8"),
                             "index": 0,
                             "logprobs": None,
                             "finish_reason": None,
                         }
                     ],
                 }
+            if len(completion_tokens) >= max_tokens:
+                text = self.detokenize(completion_tokens)
+                finish_reason = "length"
+                break
 
         if finish_reason is None:
             finish_reason = "length"
 
         if stream:
-            if finish_reason == "stop":
-                start = len(self.detokenize(completion_tokens[:-1]))
-                text = text[start:].decode("utf-8")
-            else:
-                text = ""
             yield {
                 "id": completion_id,
                 "object": "text_completion",
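
This hunk also supplies the second half of the commit title: nothing in the loop previously enforced max_tokens, so generation only ended on EOS or a stop match. The new guard breaks out once enough tokens have accumulated and records finish_reason = "length". A minimal sketch of that guard, with a stand-in token stream in place of self.generate(...):

from itertools import count

max_tokens = 4
completion_tokens = []
finish_reason = None
for token in count():  # pretend stream of token ids
    completion_tokens.append(token)
    if len(completion_tokens) >= max_tokens:
        finish_reason = "length"  # reported to the caller in the final chunk
        break

assert finish_reason == "length" and len(completion_tokens) == max_tokens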
@@ -359,7 +364,7 @@ class Llama:
                 "model": self.model_path,
                 "choices": [
                     {
-                        "text": text,
+                        "text": text[returned_characters:].decode("utf-8"),
                         "index": 0,
                         "logprobs": None,
                         "finish_reason": finish_reason,
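
Because each streamed chunk advances returned_characters by exactly the bytes it sent, this final chunk can flush text[returned_characters:], i.e. whatever the stop-prefix scan was still holding back, along with the finish_reason. A hypothetical usage sketch against the library's public streaming interface (the model path is a placeholder):

from llama_cpp import Llama

llm = Llama(model_path="./models/7B/ggml-model.bin")  # placeholder path
for chunk in llm(
    "Q: Name the planets in the solar system. A: ",
    max_tokens=48,
    stop=["Q:", "\n"],
    stream=True,
):
    # No chunk ever ends with a fragment of "Q:" or "\n"; the last
    # chunk carries the held-back remainder and the finish reason.
    print(chunk["choices"][0]["text"], end="", flush=True)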