From cbe95bbb75ba72cbb39308ee645d3bf1e5507a86 Mon Sep 17 00:00:00 2001
From: Andrei Betlen
Date: Mon, 24 Apr 2023 19:54:41 -0400
Subject: [PATCH] Add cache implementation using llama state

---
 llama_cpp/llama.py | 64 +++++++++++++++++++---------------------------
 1 file changed, 26 insertions(+), 38 deletions(-)

diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index c2d9d10..0a69b2c 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -12,12 +12,22 @@ from .llama_types import *
 
 
 class LlamaCache:
-    """Cache for a llama.cpp model.
+    """Cache for a llama.cpp model."""
 
-    NOTE: This implementation currently only tells the Llama class to avoid reprocessing bytes and continue from the last
-    completion. It does not actually cache the results."""
+    def __init__(self):
+        self.cache_state: Dict[Sequence[llama_cpp.llama_token], "LlamaState"] = dict()
 
-    pass
+    def __getitem__(
+        self, key: Sequence[llama_cpp.llama_token]
+    ) -> Optional["LlamaState"]:
+        return self.cache_state.get(tuple(key), None)
+
+    def __contains__(self, key: Sequence[llama_cpp.llama_token]) -> bool:
+        return tuple(key) in self.cache_state
+
+    def __setitem__(self, key: Sequence[llama_cpp.llama_token], value: "LlamaState"):
+        self.cache_state = dict()  # NOTE: Currently limit to one cache entry.
+        self.cache_state[tuple(key)] = value
 
 
 class LlamaState:
@@ -100,13 +110,7 @@ class Llama:
         self.eval_tokens: Deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
         self.eval_logits: Deque[List[float]] = deque(maxlen=n_ctx)
 
-        ### HACK: This is a hack to work around the fact that the llama.cpp API does not yet support
-        ### saving and restoring state, this allows us to continue a completion if the last
-        ### completion_bytes is a prefix to the prompt passed in. However this is actually incorrect
-        ### because it does not take into account stop tokens which have been processed by the model.
-        self._completion_bytes: List[bytes] = []
-        self._cache: Optional[LlamaCache] = None
-        ###
+        self.cache: Optional[LlamaCache] = None
 
         self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
 
@@ -182,7 +186,7 @@ class Llama:
         Args:
             cache: The cache to set.
         """
-        self._cache = cache
+        self.cache = cache
 
     def reset(self):
         """Reset the model state."""
@@ -287,10 +291,9 @@ class Llama:
             The generated tokens.
         """
         assert self.ctx is not None
-        ### HACK
+
         if (
             reset
-            and self._cache
             and len(self.eval_tokens) > 0
             and self.eval_tokens == tokens[: len(self.eval_tokens)]
         ):
@@ -298,7 +301,7 @@ class Llama:
                 print("generate cache hit", file=sys.stderr)
             reset = False
             tokens = tokens[len(self.eval_tokens) :]
-            ###
+
         if reset:
             self.reset()
         while True:
@@ -415,20 +418,10 @@ class Llama:
                 "logprobs is not supported for models created with logits_all=False"
             )
 
-        ### HACK
-        reset: bool = True
-        _prompt: bytes = prompt.encode("utf-8")
-        _completion: bytes = b"".join(self._completion_bytes)
-        if len(_completion) and self._cache and _prompt.startswith(_completion):
+        if self.cache and prompt_tokens in self.cache:
             if self.verbose:
-                print("completion cache hit", file=sys.stderr)
-            reset = False
-            _prompt = _prompt[len(_completion) :]
-            prompt_tokens = self.tokenize(b" " + _prompt)
-            self._completion_bytes.append(_prompt)
-        else:
-            self._completion_bytes = [prompt.encode("utf-8")]
-        ###
+                print("cache hit", file=sys.stderr)
+            self.load_state(self.cache[prompt_tokens])
 
         finish_reason = "length"
         for token in self.generate(
@@ -437,12 +430,16 @@ class Llama:
             top_p=top_p,
             temp=temperature,
             repeat_penalty=repeat_penalty,
-            reset=reset,
         ):
             if token == llama_cpp.llama_token_eos():
                 text = self.detokenize(completion_tokens)
                 finish_reason = "stop"
                 break
+
+            if self.cache and len(completion_tokens) == 0:
+                if prompt_tokens not in self.cache:
+                    self.cache[prompt_tokens] = self.save_state()
+
             completion_tokens.append(token)
 
             all_text = self.detokenize(completion_tokens)
@@ -467,9 +464,6 @@ class Llama:
                             break
                 text = all_text[: len(all_text) - longest]
                 returned_characters += len(text[start:])
-                ### HACK
-                self._completion_bytes.append(text[start:])
-                ###
                 yield {
                     "id": completion_id,
                     "object": "text_completion",
@@ -491,9 +485,6 @@ class Llama:
                 break
 
         if stream:
-            ### HACK
-            self._completion_bytes.append(text[returned_characters:])
-            ###
             yield {
                 "id": completion_id,
                 "object": "text_completion",
@@ -510,9 +501,6 @@ class Llama:
             }
             return
 
-        ### HACK
-        self._completion_bytes.append(text)
-        ###
         text_str = text.decode("utf-8")
 
         if echo: