mirror of
https://github.com/abetlen/llama-cpp-python.git
synced 2023-09-07 17:34:22 +03:00
Add experimental cache
This commit is contained in:
@@ -11,6 +11,15 @@ from . import llama_cpp
|
||||
from .llama_types import *
|
||||
|
||||
|
||||
class LlamaCache:
|
||||
"""Cache for a llama.cpp model.
|
||||
|
||||
NOTE: This implementation currently only tells the Llama class to avoid reprocessing bytes and continue from the last
|
||||
completion. It does not actually cache the results."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class Llama:
|
||||
"""High-level Python wrapper for a llama.cpp model."""
|
||||
|
||||
@@ -82,6 +91,14 @@ class Llama:
|
||||
self.n_past = 0
|
||||
self.all_logits: List[List[float]] = [] # TODO: Use an array instead of a list.
|
||||
|
||||
### HACK: This is a hack to work around the fact that the llama.cpp API does not yet support
|
||||
### saving and restoring state, this allows us to continue a completion if the last
|
||||
### completion_bytes is a prefix to the prompt passed in. However this is actually incorrect
|
||||
### because it does not take into account stop tokens which have been processed by the model.
|
||||
self._completion_bytes: List[bytes] = []
|
||||
self._cache: Optional[LlamaCache] = None
|
||||
###
|
||||
|
||||
self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
|
||||
|
||||
if not os.path.exists(model_path):
|
||||
@@ -135,6 +152,14 @@ class Llama:
|
||||
output += llama_cpp.llama_token_to_str(self.ctx, token)
|
||||
return output
|
||||
|
||||
def set_cache(self, cache: Optional[LlamaCache]):
|
||||
"""Set the cache.
|
||||
|
||||
Args:
|
||||
cache: The cache to set.
|
||||
"""
|
||||
self._cache = cache
|
||||
|
||||
def reset(self):
|
||||
"""Reset the model state."""
|
||||
self.last_n_tokens_data.extend(
|
||||
@@ -245,6 +270,17 @@ class Llama:
|
||||
The generated tokens.
|
||||
"""
|
||||
assert self.ctx is not None
|
||||
### HACK
|
||||
if (
|
||||
reset
|
||||
and self._cache
|
||||
and len(self.tokens) > 0
|
||||
and self.tokens == tokens[: len(self.tokens)]
|
||||
):
|
||||
if self.verbose:
|
||||
print("generate cache hit", file=sys.stderr)
|
||||
reset = False
|
||||
###
|
||||
if reset:
|
||||
self.reset()
|
||||
while True:
|
||||
@@ -361,6 +397,21 @@ class Llama:
|
||||
"logprobs is not supported for models created with logits_all=False"
|
||||
)
|
||||
|
||||
### HACK
|
||||
reset: bool = True
|
||||
_prompt: bytes = prompt.encode("utf-8")
|
||||
_completion: bytes = b"".join(self._completion_bytes)
|
||||
if len(_completion) and self._cache and _prompt.startswith(_completion):
|
||||
if self.verbose:
|
||||
print("completion cache hit", file=sys.stderr)
|
||||
reset = False
|
||||
_prompt = _prompt[len(_completion) :]
|
||||
prompt_tokens = self.tokenize(b" " + _prompt)
|
||||
self._completion_bytes.append(_prompt)
|
||||
else:
|
||||
self._completion_bytes = [prompt.encode("utf-8")]
|
||||
###
|
||||
|
||||
finish_reason = "length"
|
||||
for token in self.generate(
|
||||
prompt_tokens,
|
||||
@@ -368,6 +419,7 @@ class Llama:
|
||||
top_p=top_p,
|
||||
temp=temperature,
|
||||
repeat_penalty=repeat_penalty,
|
||||
reset=reset,
|
||||
):
|
||||
if token == llama_cpp.llama_token_eos():
|
||||
text = self.detokenize(completion_tokens)
|
||||
@@ -397,6 +449,9 @@ class Llama:
|
||||
break
|
||||
text = all_text[: len(all_text) - longest]
|
||||
returned_characters += len(text[start:])
|
||||
### HACK
|
||||
self._completion_bytes.append(text[start:])
|
||||
###
|
||||
yield {
|
||||
"id": completion_id,
|
||||
"object": "text_completion",
|
||||
@@ -418,6 +473,9 @@ class Llama:
|
||||
break
|
||||
|
||||
if stream:
|
||||
### HACK
|
||||
self._completion_bytes.append(text[returned_characters:])
|
||||
###
|
||||
yield {
|
||||
"id": completion_id,
|
||||
"object": "text_completion",
|
||||
@@ -434,13 +492,16 @@ class Llama:
|
||||
}
|
||||
return
|
||||
|
||||
text = text.decode("utf-8")
|
||||
### HACK
|
||||
self._completion_bytes.append(text)
|
||||
###
|
||||
text_str = text.decode("utf-8")
|
||||
|
||||
if echo:
|
||||
text = prompt + text
|
||||
text_str = prompt + text_str
|
||||
|
||||
if suffix is not None:
|
||||
text = text + suffix
|
||||
text_str = text_str + suffix
|
||||
|
||||
logprobs_or_none: Optional[CompletionLogprobs] = None
|
||||
if logprobs is not None:
|
||||
@@ -493,7 +554,7 @@ class Llama:
|
||||
"model": self.model_path,
|
||||
"choices": [
|
||||
{
|
||||
"text": text,
|
||||
"text": text_str,
|
||||
"index": 0,
|
||||
"logprobs": logprobs_or_none,
|
||||
"finish_reason": finish_reason,
|
||||
|
||||
Reference in New Issue
Block a user