Update high-level api

Andrei Betlen
2023-04-01 13:01:27 -04:00
parent 3af274cbd4
commit 318eae237e
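
This commit reworks the public surface of `Llama`: `tokenize`/`detokenize` now traffic in `llama_cpp.llama_token`, the private `_eval`/`_sample`/`_generate`/`_call` helpers are replaced by a public `generate` generator plus `create_embedding` and `create_completion`, and `__call__` delegates to `create_completion`. As rough orientation only, a caller of the updated high-level API might look like the sketch below; the model path, prompt, and sampling values are placeholders, not anything prescribed by this commit.

from llama_cpp import Llama

# Placeholder path to any llama.cpp-compatible ggml model file.
llm = Llama(model_path="./models/ggml-model.bin", n_ctx=512, n_batch=8)

# __call__ now routes through create_completion and returns an OpenAI-style
# completion object (or a generator of chunks when stream=True).
output = llm("Q: Name the planets in the solar system. A: ", max_tokens=48, stop=["Q:"])
print(output["choices"][0]["text"])

# Embeddings go through the new create_embedding method; the model has to be
# loaded with embedding=True for llama_get_embeddings to return data.
embedder = Llama(model_path="./models/ggml-model.bin", embedding=True)
print(embedder.create_embedding("Hello, world")["data"][0]["embedding"][:8])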


@@ -2,10 +2,11 @@ import os
 import uuid
 import time
 import multiprocessing
-from typing import List, Optional
+from typing import List, Optional, Union, Generator, Sequence
 from collections import deque
 
 from . import llama_cpp
+from .llama_types import *
 
 
 class Llama:
@@ -14,7 +15,7 @@ class Llama:
     def __init__(
         self,
         model_path: str,
-        # NOTE: The following parameters are likely to change in the future.
+        # NOTE: These parameters are likely to change in the future.
         n_ctx: int = 512,
         n_parts: int = -1,
         seed: int = 1337,
@@ -24,7 +25,9 @@ class Llama:
         use_mlock: bool = False,
         embedding: bool = False,
         n_threads: Optional[int] = None,
-    ) -> "Llama":
+        n_batch: int = 8,
+        last_n_tokens_size: int = 64,
+    ):
         """Load a llama.cpp model from `model_path`.
 
         Args:
@@ -38,6 +41,8 @@ class Llama:
             use_mlock: Force the system to keep the model in RAM.
             embedding: Embedding mode only.
             n_threads: Number of threads to use. If None, the number of threads is automatically determined.
+            n_batch: Maximum number of prompt tokens to batch together when calling llama_eval.
+            last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque.
 
         Raises:
             ValueError: If the model path does not exist.
@@ -57,8 +62,8 @@ class Llama:
         self.params.use_mlock = use_mlock
         self.params.embedding = embedding
 
-        self.last_n = 64
-        self.max_chunk_size = n_ctx
+        self.last_n_tokens_size = last_n_tokens_size
+        self.n_batch = n_batch
 
         self.n_threads = n_threads or multiprocessing.cpu_count()
@@ -69,29 +74,33 @@ class Llama:
             self.model_path.encode("utf-8"), self.params
         )
 
-    def tokenize(self, text: bytes) -> List[int]:
+    def tokenize(self, text: bytes) -> List[llama_cpp.llama_token]:
         """Tokenize a string.
 
         Args:
             text: The utf-8 encoded string to tokenize.
 
+        Raises:
+            RuntimeError: If the tokenization failed.
+
         Returns:
             A list of tokens.
         """
+        assert self.ctx is not None
         n_ctx = llama_cpp.llama_n_ctx(self.ctx)
-        tokens = (llama_cpp.llama_token * n_ctx)()
+        tokens = (llama_cpp.llama_token * int(n_ctx))()
         n_tokens = llama_cpp.llama_tokenize(
             self.ctx,
             text,
             tokens,
             n_ctx,
-            True,
+            llama_cpp.c_bool(True),
         )
-        if n_tokens < 0:
+        if int(n_tokens) < 0:
             raise RuntimeError(f'Failed to tokenize: text="{text}" n_tokens={n_tokens}')
         return list(tokens[:n_tokens])
 
-    def detokenize(self, tokens: List[int]) -> bytes:
+    def detokenize(self, tokens: List[llama_cpp.llama_token]) -> bytes:
         """Detokenize a list of tokens.
 
         Args:
@@ -100,62 +109,98 @@ class Llama:
         Returns:
             The detokenized string.
         """
+        assert self.ctx is not None
         output = b""
         for token in tokens:
             output += llama_cpp.llama_token_to_str(self.ctx, token)
         return output
 
-    def embed(self, text: str):
+    def generate(
+        self,
+        tokens: Sequence[llama_cpp.llama_token],
+        top_k: int,
+        top_p: float,
+        temp: float,
+        repeat_penalty: float,
+    ) -> Generator[
+        llama_cpp.llama_token, Optional[Sequence[llama_cpp.llama_token]], None
+    ]:
+        # Temporary workaround for https://github.com/ggerganov/llama.cpp/issues/684
+        if temp == 0.0:
+            temp = 1.0
+            top_p = 0.0
+            top_k = 1
+        assert self.ctx is not None
+        n_ctx = int(llama_cpp.llama_n_ctx(self.ctx))
+        n_tokens = 0
+        last_n_tokens = deque(
+            [llama_cpp.llama_token(0)] * self.last_n_tokens_size,
+            maxlen=self.last_n_tokens_size,
+        )
+        while True:
+            for i in range(0, len(tokens), self.n_batch):
+                batch = tokens[i : min(len(tokens), i + self.n_batch)]
+                n_past = min(n_ctx - len(batch), n_tokens)
+                return_code = llama_cpp.llama_eval(
+                    ctx=self.ctx,
+                    tokens=(llama_cpp.llama_token * len(batch))(*batch),
+                    n_tokens=llama_cpp.c_int(len(batch)),
+                    n_past=llama_cpp.c_int(n_past),
+                    n_threads=llama_cpp.c_int(self.n_threads),
+                )
+                if int(return_code) != 0:
+                    raise RuntimeError(f"llama_eval returned {return_code}")
+                last_n_tokens.extend(batch)
+                n_tokens += len(batch)
+            token = llama_cpp.llama_sample_top_p_top_k(
+                ctx=self.ctx,
+                last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)(
+                    *last_n_tokens
+                ),
+                last_n_tokens_size=llama_cpp.c_int(self.last_n_tokens_size),
+                top_k=llama_cpp.c_int(top_k),
+                top_p=llama_cpp.c_float(top_p),
+                temp=llama_cpp.c_float(temp),
+                repeat_penalty=llama_cpp.c_float(repeat_penalty),
+            )
+            tokens_or_none = yield token
+            tokens = [token]
+            if tokens_or_none is not None:
+                tokens.extend(tokens_or_none)
+
+    def create_embedding(self, input: str) -> Embedding:
         """Embed a string.
 
         Args:
-            text: The utf-8 encoded string to embed.
+            input: The utf-8 encoded string to embed.
 
         Returns:
-            A list of embeddings.
+            An embedding object.
         """
-        tokens = self.tokenize(text.encode("utf-8"))
-        self._eval(tokens, 0)
-        embeddings = llama_cpp.llama_get_embeddings(self.ctx)
-        return embeddings[:llama_cpp.llama_n_embd(self.ctx)]
-
-    def _eval(self, tokens: List[int], n_past):
-        rc = llama_cpp.llama_eval(
-            self.ctx,
-            (llama_cpp.llama_token * len(tokens))(*tokens),
-            len(tokens),
-            n_past,
-            self.n_threads,
-        )
-        if rc != 0:
-            raise RuntimeError(f"Failed to evaluate: {rc}")
-
-    def _sample(self, last_n_tokens, top_p, top_k, temp, repeat_penalty):
-        return llama_cpp.llama_sample_top_p_top_k(
-            self.ctx,
-            (llama_cpp.llama_token * len(last_n_tokens))(*last_n_tokens),
-            len(last_n_tokens),
-            top_k=top_k,
-            top_p=top_p,
-            temp=temp,
-            repeat_penalty=repeat_penalty,
-        )
-
-    def _generate(self, past_tokens, max_tokens, top_p, top_k, temp, repeat_penalty):
-        last_n_tokens = deque([0] * self.last_n, maxlen=self.last_n)
-        last_n_tokens.extend(past_tokens)
-        for i in range(max_tokens):
-            token = self._sample(
-                last_n_tokens,
-                top_p=top_p,
-                top_k=top_k,
-                temp=temp,
-                repeat_penalty=repeat_penalty,
-            )
-            yield token
-            self._eval([token], len(past_tokens) + i)
-
-    def _call(
+        assert self.ctx is not None
+        tokens = self.tokenize(input.encode("utf-8"))
+        next(self.generate(tokens, top_k=0, top_p=0.0, temp=1.0, repeat_penalty=1.0))
+        n_tokens = len(tokens)
+        embedding = llama_cpp.llama_get_embeddings(self.ctx)[
+            : llama_cpp.llama_n_embd(self.ctx)
+        ]
+        return {
+            "object": "list",
+            "data": [
+                {
+                    "object": "embedding",
+                    "embedding": embedding,
+                    "index": 0,
+                }
+            ],
+            "model": self.model_path,
+            "usage": {
+                "prompt_tokens": n_tokens,
+                "total_tokens": n_tokens,
+            },
+        }
+
+    def _create_completion(
         self,
         prompt: str,
         suffix: Optional[str] = None,
@@ -168,28 +213,35 @@ class Llama:
         repeat_penalty: float = 1.1,
         top_k: int = 40,
         stream: bool = False,
-    ):
+    ) -> Union[
+        Generator[Completion, None, None],
+        Generator[CompletionChunk, None, None],
+    ]:
+        assert self.ctx is not None
         completion_id = f"cmpl-{str(uuid.uuid4())}"
         created = int(time.time())
-        completion_tokens = []
-        prompt_tokens = self.tokenize(prompt.encode("utf-8"))
+        completion_tokens: List[llama_cpp.llama_token] = []
+        # Add blank space to start of prompt to match OG llama tokenizer
+        prompt_tokens = self.tokenize(b" " + prompt.encode("utf-8"))
+        text = b""
 
-        if len(prompt_tokens) + max_tokens > llama_cpp.llama_n_ctx(self.ctx):
+        if len(prompt_tokens) + max_tokens > int(llama_cpp.llama_n_ctx(self.ctx)):
             raise ValueError(
                 f"Requested tokens exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}"
             )
 
-        # Process prompt in chunks to avoid running out of memory
-        for i in range(0, len(prompt_tokens), self.max_chunk_size):
-            chunk = prompt_tokens[i : min(len(prompt_tokens), i + self.max_chunk_size)]
-            self._eval(chunk, n_past=i)
-
-        if stop is not None:
-            stop = [s.encode("utf-8") for s in stop]
+        if stop != []:
+            stop_bytes = [s.encode("utf-8") for s in stop]
+        else:
+            stop_bytes = []
 
         finish_reason = None
-        for token in self._generate(
-            prompt_tokens, max_tokens, top_p, top_k, temperature, repeat_penalty
+        for token in self.generate(
+            prompt_tokens,
+            top_k=top_k,
+            top_p=top_p,
+            temp=temperature,
+            repeat_penalty=repeat_penalty,
         ):
             if token == llama_cpp.llama_token_eos():
                 finish_reason = "stop"
@@ -197,7 +249,7 @@ class Llama:
             completion_tokens.append(token)
 
             text = self.detokenize(completion_tokens)
-            any_stop = [s for s in stop if s in text]
+            any_stop = [s for s in stop_bytes if s in text]
             if len(any_stop) > 0:
                 first_stop = any_stop[0]
                 text = text[: text.index(first_stop)]
@@ -207,7 +259,8 @@ class Llama:
             if stream:
                 start = len(self.detokenize(completion_tokens[:-1]))
                 longest = 0
-                for s in stop:
+                # TODO: Clean up this mess
+                for s in stop_bytes:
                     for i in range(len(s), 0, -1):
                         if s[-i:] == text[-i:]:
                             if i > longest:
@@ -262,9 +315,7 @@ class Llama:
             text = text + suffix
 
         if logprobs is not None:
-            logprobs = llama_cpp.llama_get_logits(
-                self.ctx,
-            )[:logprobs]
+            raise NotImplementedError("logprobs not implemented")
 
         yield {
             "id": completion_id,
@@ -275,7 +326,7 @@ class Llama:
                 {
                     "text": text,
                     "index": 0,
-                    "logprobs": logprobs,
+                    "logprobs": None,
                     "finish_reason": finish_reason,
                 }
             ],
@@ -286,11 +337,66 @@ class Llama:
             },
         }
 
+    def create_completion(
+        self,
+        prompt: str,
+        suffix: Optional[str] = None,
+        max_tokens: int = 128,
+        temperature: float = 0.8,
+        top_p: float = 0.95,
+        logprobs: Optional[int] = None,
+        echo: bool = False,
+        stop: List[str] = [],
+        repeat_penalty: float = 1.1,
+        top_k: int = 40,
+        stream: bool = False,
+    ) -> Union[Completion, Generator[CompletionChunk, None, None]]:
+        """Generate text from a prompt.
+
+        Args:
+            prompt: The prompt to generate text from.
+            suffix: A suffix to append to the generated text. If None, no suffix is appended.
+            max_tokens: The maximum number of tokens to generate.
+            temperature: The temperature to use for sampling.
+            top_p: The top-p value to use for sampling.
+            logprobs: The number of logprobs to return. If None, no logprobs are returned.
+            echo: Whether to echo the prompt.
+            stop: A list of strings to stop generation when encountered.
+            repeat_penalty: The penalty to apply to repeated tokens.
+            top_k: The top-k value to use for sampling.
+            stream: Whether to stream the results.
+
+        Raises:
+            ValueError: If the requested tokens exceed the context window.
+            RuntimeError: If the prompt fails to tokenize or the model fails to evaluate the prompt.
+
+        Returns:
+            Response object containing the generated text.
+        """
+        completion_or_chunks = self._create_completion(
+            prompt=prompt,
+            suffix=suffix,
+            max_tokens=max_tokens,
+            temperature=temperature,
+            top_p=top_p,
+            logprobs=logprobs,
+            echo=echo,
+            stop=stop,
+            repeat_penalty=repeat_penalty,
+            top_k=top_k,
+            stream=stream,
+        )
+        if stream:
+            chunks: Generator[CompletionChunk, None, None] = completion_or_chunks
+            return chunks
+        completion: Completion = next(completion_or_chunks)  # type: ignore
+        return completion
+
     def __call__(
         self,
         prompt: str,
         suffix: Optional[str] = None,
-        max_tokens: int = 16,
+        max_tokens: int = 128,
         temperature: float = 0.8,
         top_p: float = 0.95,
         logprobs: Optional[int] = None,
@@ -322,7 +428,7 @@ class Llama:
         Returns:
             Response object containing the generated text.
         """
-        call = self._call(
+        return self.create_completion(
             prompt=prompt,
             suffix=suffix,
             max_tokens=max_tokens,
@@ -335,9 +441,8 @@ class Llama:
             top_k=top_k,
             stream=stream,
         )
-        if stream:
-            return call
-        return next(call)
 
     def __del__(self):
-        llama_cpp.llama_free(self.ctx)
+        if self.ctx is not None:
+            llama_cpp.llama_free(self.ctx)
+        self.ctx = None
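
For completeness, the new `generate` method is a bare token-level generator: it never terminates on its own and accepts extra tokens via `send`. A minimal driving loop, assuming a model at a placeholder path and an illustrative 32-token cap, might look like this sketch:

from llama_cpp import Llama
from llama_cpp import llama_cpp  # low-level bindings, for llama_token_eos()

llm = Llama(model_path="./models/ggml-model.bin")  # placeholder model path

# Leading space mirrors what _create_completion now prepends to prompts.
tokens = llm.tokenize(b" Q: What is the capital of France? A:")
generated = []
for token in llm.generate(tokens, top_k=40, top_p=0.95, temp=0.8, repeat_penalty=1.1):
    if token == llama_cpp.llama_token_eos():
        break
    generated.append(token)
    if len(generated) >= 32:  # generate() loops forever, so cap the length manually
        break
print(llm.detokenize(generated).decode("utf-8", errors="ignore"))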