Mirror of https://github.com/abetlen/llama-cpp-python.git (synced 2023-09-07 17:34:22 +03:00)
Add cache implementation using llama state
@@ -12,12 +12,22 @@ from .llama_types import *
 
 
 class LlamaCache:
-    """Cache for a llama.cpp model.
-
-    NOTE: This implementation currently only tells the Llama class to avoid reprocessing bytes and continue from the last
-    completion. It does not actually cache the results."""
-
-    pass
+    """Cache for a llama.cpp model."""
+
+    def __init__(self):
+        self.cache_state: Dict[Sequence[llama_cpp.llama_token], "LlamaState"] = dict()
+
+    def __getitem__(
+        self, key: Sequence[llama_cpp.llama_token]
+    ) -> Optional["LlamaState"]:
+        return self.cache_state.get(tuple(key), None)
+
+    def __contains__(self, key: Sequence[llama_cpp.llama_token]) -> bool:
+        return tuple(key) in self.cache_state
+
+    def __setitem__(self, key: Sequence[llama_cpp.llama_token], value: "LlamaState"):
+        self.cache_state = dict()  # NOTE: Currently limit to one cache entry.
+        self.cache_state[tuple(key)] = value
 
 
 class LlamaState:
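
Note: a minimal standalone sketch of the dict-like semantics added above. The model path is a placeholder, and Llama.save_state()/load_state()/tokenize() are existing methods that this diff relies on elsewhere.

from llama_cpp import Llama, LlamaCache

llama = Llama(model_path="./models/ggml-model.bin")  # placeholder path
cache = LlamaCache()

key = llama.tokenize(b" Hello")     # a token sequence is the cache key
cache[key] = llama.save_state()     # value is a LlamaState snapshot

assert key in cache                 # __contains__ compares tuple(key)
llama.load_state(cache[key])        # __getitem__ returns the saved LlamaState

cache[llama.tokenize(b" Bye")] = llama.save_state()
assert key not in cache             # __setitem__ currently keeps only one entry
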
@@ -100,13 +110,7 @@ class Llama:
         self.eval_tokens: Deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
         self.eval_logits: Deque[List[float]] = deque(maxlen=n_ctx)
 
-        ### HACK: This is a hack to work around the fact that the llama.cpp API does not yet support
-        ### saving and restoring state, this allows us to continue a completion if the last
-        ### completion_bytes is a prefix to the prompt passed in. However this is actually incorrect
-        ### because it does not take into account stop tokens which have been processed by the model.
-        self._completion_bytes: List[bytes] = []
-        self._cache: Optional[LlamaCache] = None
-        ###
+        self.cache: Optional[LlamaCache] = None
 
         self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
 
@@ -182,7 +186,7 @@ class Llama:
         Args:
             cache: The cache to set.
         """
-        self._cache = cache
+        self.cache = cache
 
     def reset(self):
         """Reset the model state."""
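
Note: a usage sketch of the public API after this change (model path and prompt are placeholders, not part of the commit). The cache is attached once via set_cache and is then consulted transparently by the completion calls.

from llama_cpp import Llama, LlamaCache

llama = Llama(model_path="./models/ggml-model.bin", verbose=True)  # placeholder path
llama.set_cache(LlamaCache())   # stored on self.cache

# First call evaluates the prompt and saves the llama state into the cache.
out1 = llama("Q: Name the planets in the solar system. A: ", max_tokens=32)
# A repeat call with the identical prompt prints "cache hit" (verbose) and
# reloads the saved state instead of re-evaluating the prompt tokens.
out2 = llama("Q: Name the planets in the solar system. A: ", max_tokens=32)
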
@@ -287,10 +291,9 @@ class Llama:
             The generated tokens.
         """
         assert self.ctx is not None
-        ### HACK
+
         if (
             reset
-            and self._cache
             and len(self.eval_tokens) > 0
             and self.eval_tokens == tokens[: len(self.eval_tokens)]
         ):
@@ -298,7 +301,7 @@ class Llama:
                 print("generate cache hit", file=sys.stderr)
             reset = False
             tokens = tokens[len(self.eval_tokens) :]
-        ###
+
         if reset:
             self.reset()
         while True:
@@ -415,20 +418,10 @@ class Llama:
                     "logprobs is not supported for models created with logits_all=False"
                 )
 
-        ### HACK
-        reset: bool = True
-        _prompt: bytes = prompt.encode("utf-8")
-        _completion: bytes = b"".join(self._completion_bytes)
-        if len(_completion) and self._cache and _prompt.startswith(_completion):
+        if self.cache and prompt_tokens in self.cache:
             if self.verbose:
-                print("completion cache hit", file=sys.stderr)
-            reset = False
-            _prompt = _prompt[len(_completion) :]
-            prompt_tokens = self.tokenize(b" " + _prompt)
-            self._completion_bytes.append(_prompt)
-        else:
-            self._completion_bytes = [prompt.encode("utf-8")]
-        ###
+                print("cache hit", file=sys.stderr)
+            self.load_state(self.cache[prompt_tokens])
 
         finish_reason = "length"
         for token in self.generate(
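
Note: the old byte-prefix hack is replaced here by an exact-match lookup on the prompt's token sequence. Roughly, as a sketch with an assumed Llama instance named llama, a cache already set, and prompt_tokens already tokenized:

# Cache hit: restore the model state that was saved right after this exact
# prompt was evaluated, so prompt processing is skipped entirely.
if llama.cache and prompt_tokens in llama.cache:
    llama.load_state(llama.cache[prompt_tokens])
# On a miss nothing is restored and the prompt is evaluated as usual.
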
@@ -437,12 +430,16 @@ class Llama:
             top_p=top_p,
             temp=temperature,
             repeat_penalty=repeat_penalty,
-            reset=reset,
         ):
             if token == llama_cpp.llama_token_eos():
                 text = self.detokenize(completion_tokens)
                 finish_reason = "stop"
                 break
 
+            if self.cache and len(completion_tokens) == 0:
+                if prompt_tokens not in self.cache:
+                    self.cache[prompt_tokens] = self.save_state()
+
             completion_tokens.append(token)
 
             all_text = self.detokenize(completion_tokens)
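
Note: this is the write side of the cache. The snapshot is taken just before the first sampled token is appended, i.e. while the state still reflects "prompt fully evaluated, nothing generated yet". A sketch with the same assumed names as above:

# completion_tokens is empty only on the first iteration of the generate loop,
# so save_state() captures the post-prompt state at most once per completion.
if llama.cache and len(completion_tokens) == 0:
    if prompt_tokens not in llama.cache:
        llama.cache[prompt_tokens] = llama.save_state()
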
@@ -467,9 +464,6 @@ class Llama:
                             break
                 text = all_text[: len(all_text) - longest]
                 returned_characters += len(text[start:])
-                ### HACK
-                self._completion_bytes.append(text[start:])
-                ###
                 yield {
                     "id": completion_id,
                     "object": "text_completion",
@@ -491,9 +485,6 @@ class Llama:
                 break
 
         if stream:
-            ### HACK
-            self._completion_bytes.append(text[returned_characters:])
-            ###
             yield {
                 "id": completion_id,
                 "object": "text_completion",
@@ -510,9 +501,6 @@ class Llama:
             }
             return
 
-        ### HACK
-        self._completion_bytes.append(text)
-        ###
         text_str = text.decode("utf-8")
 
         if echo: