mirror of
https://github.com/abetlen/llama-cpp-python.git
synced 2023-09-07 17:34:22 +03:00
Merge branch 'main' of github.com:abetlen/llama_cpp_python into better-server-params-and-fields
This commit is contained in:
@@ -33,12 +33,10 @@ class LlamaCache:
|
||||
return k
|
||||
return None
|
||||
|
||||
def __getitem__(
|
||||
self, key: Sequence[llama_cpp.llama_token]
|
||||
) -> Optional["LlamaState"]:
|
||||
def __getitem__(self, key: Sequence[llama_cpp.llama_token]) -> "LlamaState":
|
||||
_key = self._find_key(tuple(key))
|
||||
if _key is None:
|
||||
return None
|
||||
raise KeyError(f"Key not found: {key}")
|
||||
return self.cache_state[_key]
|
||||
|
||||
def __contains__(self, key: Sequence[llama_cpp.llama_token]) -> bool:
|
||||
@@ -53,8 +51,8 @@ class LlamaState:
|
||||
def __init__(
|
||||
self,
|
||||
eval_tokens: Deque[llama_cpp.llama_token],
|
||||
eval_logits: Deque[List[llama_cpp.c_float]],
|
||||
llama_state,
|
||||
eval_logits: Deque[List[float]],
|
||||
llama_state, # type: llama_cpp.Array[llama_cpp.c_uint8]
|
||||
llama_state_size: llama_cpp.c_size_t,
|
||||
):
|
||||
self.eval_tokens = eval_tokens
|
||||
@@ -129,7 +127,7 @@ class Llama:
|
||||
self.last_n_tokens_size = last_n_tokens_size
|
||||
self.n_batch = min(n_ctx, n_batch)
|
||||
self.eval_tokens: Deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
|
||||
self.eval_logits: Deque[List[llama_cpp.c_float]] = deque(
|
||||
self.eval_logits: Deque[List[float]] = deque(
|
||||
maxlen=n_ctx if logits_all else 1
|
||||
)
|
||||
|
||||
@@ -247,7 +245,7 @@ class Llama:
|
||||
n_vocab = llama_cpp.llama_n_vocab(self.ctx)
|
||||
cols = int(n_vocab)
|
||||
logits_view = llama_cpp.llama_get_logits(self.ctx)
|
||||
logits: List[List[llama_cpp.c_float]] = [
|
||||
logits: List[List[float]] = [
|
||||
[logits_view[i * cols + j] for j in range(cols)] for i in range(rows)
|
||||
]
|
||||
self.eval_logits.extend(logits)
|
||||
@@ -289,7 +287,7 @@ class Llama:
|
||||
candidates=llama_cpp.ctypes.pointer(candidates),
|
||||
penalty=repeat_penalty,
|
||||
)
|
||||
if temp == 0.0:
|
||||
if float(temp.value) == 0.0:
|
||||
return llama_cpp.llama_sample_token_greedy(
|
||||
ctx=self.ctx,
|
||||
candidates=llama_cpp.ctypes.pointer(candidates),
|
||||
@@ -299,21 +297,25 @@ class Llama:
|
||||
ctx=self.ctx,
|
||||
candidates=llama_cpp.ctypes.pointer(candidates),
|
||||
k=top_k,
|
||||
min_keep=llama_cpp.c_size_t(1),
|
||||
)
|
||||
llama_cpp.llama_sample_tail_free(
|
||||
ctx=self.ctx,
|
||||
candidates=llama_cpp.ctypes.pointer(candidates),
|
||||
z=llama_cpp.c_float(1.0),
|
||||
min_keep=llama_cpp.c_size_t(1),
|
||||
)
|
||||
llama_cpp.llama_sample_typical(
|
||||
ctx=self.ctx,
|
||||
candidates=llama_cpp.ctypes.pointer(candidates),
|
||||
p=llama_cpp.c_float(1.0),
|
||||
min_keep=llama_cpp.c_size_t(1),
|
||||
)
|
||||
llama_cpp.llama_sample_top_p(
|
||||
ctx=self.ctx,
|
||||
candidates=llama_cpp.ctypes.pointer(candidates),
|
||||
p=top_p,
|
||||
min_keep=llama_cpp.c_size_t(1),
|
||||
)
|
||||
llama_cpp.llama_sample_temperature(
|
||||
ctx=self.ctx,
|
||||
@@ -390,18 +392,28 @@ class Llama:
|
||||
"""
|
||||
assert self.ctx is not None
|
||||
|
||||
if (
|
||||
reset
|
||||
and len(self.eval_tokens) > 0
|
||||
and tuple(self.eval_tokens) == tuple(tokens[: len(self.eval_tokens)])
|
||||
):
|
||||
if self.verbose:
|
||||
print("Llama.generate: cache hit", file=sys.stderr)
|
||||
reset = False
|
||||
tokens = tokens[len(self.eval_tokens) :]
|
||||
if reset and len(self.eval_tokens) > 0:
|
||||
longest_prefix = 0
|
||||
for a, b in zip(self.eval_tokens, tokens[:-1]):
|
||||
if a == b:
|
||||
longest_prefix += 1
|
||||
else:
|
||||
break
|
||||
if longest_prefix > 0:
|
||||
if self.verbose:
|
||||
print("Llama.generate: prefix-match hit", file=sys.stderr)
|
||||
reset = False
|
||||
tokens = tokens[longest_prefix:]
|
||||
for _ in range(len(self.eval_tokens) - longest_prefix):
|
||||
self.eval_tokens.pop()
|
||||
try:
|
||||
self.eval_logits.pop()
|
||||
except IndexError:
|
||||
pass
|
||||
|
||||
if reset:
|
||||
self.reset()
|
||||
|
||||
while True:
|
||||
self.eval(tokens)
|
||||
token = self.sample(
|
||||
@@ -639,7 +651,10 @@ class Llama:
|
||||
self.detokenize([token]).decode("utf-8", errors="ignore")
|
||||
for token in all_tokens
|
||||
]
|
||||
all_logprobs = [Llama._logits_to_logprobs(row) for row in self.eval_logits]
|
||||
all_logprobs = [
|
||||
Llama.logits_to_logprobs(list(map(float, row)))
|
||||
for row in self.eval_logits
|
||||
]
|
||||
for token, token_str, logprobs_token in zip(
|
||||
all_tokens, all_token_strs, all_logprobs
|
||||
):
|
||||
@@ -958,7 +973,10 @@ class Llama:
|
||||
llama_state_compact = (llama_cpp.c_uint8 * int(n_bytes))()
|
||||
llama_cpp.ctypes.memmove(llama_state_compact, llama_state, int(n_bytes))
|
||||
if self.verbose:
|
||||
print(f"Llama.save_state: saving {n_bytes} bytes of llama state", file=sys.stderr)
|
||||
print(
|
||||
f"Llama.save_state: saving {n_bytes} bytes of llama state",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return LlamaState(
|
||||
eval_tokens=self.eval_tokens.copy(),
|
||||
eval_logits=self.eval_logits.copy(),
|
||||
@@ -985,7 +1003,7 @@ class Llama:
|
||||
return llama_cpp.llama_token_bos()
|
||||
|
||||
@staticmethod
|
||||
def logits_to_logprobs(logits: List[llama_cpp.c_float]) -> List[llama_cpp.c_float]:
|
||||
def logits_to_logprobs(logits: List[float]) -> List[float]:
|
||||
exps = [math.exp(float(x)) for x in logits]
|
||||
sum_exps = sum(exps)
|
||||
return [llama_cpp.c_float(math.log(x / sum_exps)) for x in exps]
|
||||
return [math.log(x / sum_exps) for x in exps]
|
||||
|
||||
@@ -8,6 +8,7 @@ from ctypes import (
|
||||
c_void_p,
|
||||
c_bool,
|
||||
POINTER,
|
||||
_Pointer, # type: ignore
|
||||
Structure,
|
||||
Array,
|
||||
c_uint8,
|
||||
@@ -17,7 +18,7 @@ import pathlib
|
||||
|
||||
|
||||
# Load the library
|
||||
def _load_shared_library(lib_base_name):
|
||||
def _load_shared_library(lib_base_name: str):
|
||||
# Determine the file extension based on the platform
|
||||
if sys.platform.startswith("linux"):
|
||||
lib_ext = ".so"
|
||||
@@ -67,11 +68,11 @@ _lib_base_name = "llama"
|
||||
_lib = _load_shared_library(_lib_base_name)
|
||||
|
||||
# C types
|
||||
LLAMA_FILE_VERSION = ctypes.c_int(1)
|
||||
LLAMA_FILE_VERSION = c_int(1)
|
||||
LLAMA_FILE_MAGIC = b"ggjt"
|
||||
LLAMA_FILE_MAGIC_UNVERSIONED = b"ggml"
|
||||
LLAMA_SESSION_MAGIC = b"ggsn"
|
||||
LLAMA_SESSION_VERSION = ctypes.c_int(1)
|
||||
LLAMA_SESSION_VERSION = c_int(1)
|
||||
|
||||
llama_context_p = c_void_p
|
||||
|
||||
@@ -127,18 +128,23 @@ class llama_context_params(Structure):
|
||||
|
||||
llama_context_params_p = POINTER(llama_context_params)
|
||||
|
||||
LLAMA_FTYPE_ALL_F32 = ctypes.c_int(0)
|
||||
LLAMA_FTYPE_MOSTLY_F16 = ctypes.c_int(1) # except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q4_0 = ctypes.c_int(2) # except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q4_1 = ctypes.c_int(3) # except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = ctypes.c_int(
|
||||
LLAMA_FTYPE_ALL_F32 = c_int(0)
|
||||
LLAMA_FTYPE_MOSTLY_F16 = c_int(1) # except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q4_0 = c_int(2) # except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q4_1 = c_int(3) # except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = c_int(
|
||||
4
|
||||
) # tok_embeddings.weight and output.weight are F16
|
||||
LLAMA_FTYPE_MOSTLY_Q4_2 = ctypes.c_int(5) # except 1d tensors
|
||||
# LLAMA_FTYPE_MOSTYL_Q4_3 = ctypes.c_int(6) # except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q8_0 = ctypes.c_int(7) # except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q5_0 = ctypes.c_int(8) # except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q5_1 = ctypes.c_int(9) # except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q4_2 = c_int(5) # except 1d tensors
|
||||
# LLAMA_FTYPE_MOSTYL_Q4_3 = c_int(6) # except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q8_0 = c_int(7) # except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q5_0 = c_int(8) # except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q5_1 = c_int(9) # except 1d tensors
|
||||
|
||||
# Misc
|
||||
c_float_p = POINTER(c_float)
|
||||
c_uint8_p = POINTER(c_uint8)
|
||||
c_size_t_p = POINTER(c_size_t)
|
||||
|
||||
# Functions
|
||||
|
||||
@@ -210,8 +216,8 @@ _lib.llama_model_quantize.restype = c_int
|
||||
# Returns 0 on success
|
||||
def llama_apply_lora_from_file(
|
||||
ctx: llama_context_p,
|
||||
path_lora: ctypes.c_char_p,
|
||||
path_base_model: ctypes.c_char_p,
|
||||
path_lora: c_char_p,
|
||||
path_base_model: c_char_p,
|
||||
n_threads: c_int,
|
||||
) -> c_int:
|
||||
return _lib.llama_apply_lora_from_file(ctx, path_lora, path_base_model, n_threads)
|
||||
@@ -252,21 +258,25 @@ _lib.llama_get_state_size.restype = c_size_t
|
||||
# Copies the state to the specified destination address.
|
||||
# Destination needs to have allocated enough memory.
|
||||
# Returns the number of bytes copied
|
||||
def llama_copy_state_data(ctx: llama_context_p, dest) -> c_size_t:
|
||||
def llama_copy_state_data(
|
||||
ctx: llama_context_p, dest # type: Array[c_uint8]
|
||||
) -> c_size_t:
|
||||
return _lib.llama_copy_state_data(ctx, dest)
|
||||
|
||||
|
||||
_lib.llama_copy_state_data.argtypes = [llama_context_p, POINTER(c_uint8)]
|
||||
_lib.llama_copy_state_data.argtypes = [llama_context_p, c_uint8_p]
|
||||
_lib.llama_copy_state_data.restype = c_size_t
|
||||
|
||||
|
||||
# Set the state reading from the specified address
|
||||
# Returns the number of bytes read
|
||||
def llama_set_state_data(ctx: llama_context_p, src) -> c_size_t:
|
||||
def llama_set_state_data(
|
||||
ctx: llama_context_p, src # type: Array[c_uint8]
|
||||
) -> c_size_t:
|
||||
return _lib.llama_set_state_data(ctx, src)
|
||||
|
||||
|
||||
_lib.llama_set_state_data.argtypes = [llama_context_p, POINTER(c_uint8)]
|
||||
_lib.llama_set_state_data.argtypes = [llama_context_p, c_uint8_p]
|
||||
_lib.llama_set_state_data.restype = c_size_t
|
||||
|
||||
|
||||
@@ -274,9 +284,9 @@ _lib.llama_set_state_data.restype = c_size_t
|
||||
def llama_load_session_file(
|
||||
ctx: llama_context_p,
|
||||
path_session: bytes,
|
||||
tokens_out,
|
||||
tokens_out, # type: Array[llama_token]
|
||||
n_token_capacity: c_size_t,
|
||||
n_token_count_out,
|
||||
n_token_count_out, # type: _Pointer[c_size_t]
|
||||
) -> c_size_t:
|
||||
return _lib.llama_load_session_file(
|
||||
ctx, path_session, tokens_out, n_token_capacity, n_token_count_out
|
||||
@@ -288,13 +298,16 @@ _lib.llama_load_session_file.argtypes = [
|
||||
c_char_p,
|
||||
llama_token_p,
|
||||
c_size_t,
|
||||
POINTER(c_size_t),
|
||||
c_size_t_p,
|
||||
]
|
||||
_lib.llama_load_session_file.restype = c_size_t
|
||||
|
||||
|
||||
def llama_save_session_file(
|
||||
ctx: llama_context_p, path_session: bytes, tokens, n_token_count: c_size_t
|
||||
ctx: llama_context_p,
|
||||
path_session: bytes,
|
||||
tokens, # type: Array[llama_token]
|
||||
n_token_count: c_size_t,
|
||||
) -> c_size_t:
|
||||
return _lib.llama_save_session_file(ctx, path_session, tokens, n_token_count)
|
||||
|
||||
@@ -374,22 +387,22 @@ _lib.llama_n_embd.restype = c_int
|
||||
# Can be mutated in order to change the probabilities of the next token
|
||||
# Rows: n_tokens
|
||||
# Cols: n_vocab
|
||||
def llama_get_logits(ctx: llama_context_p):
|
||||
def llama_get_logits(ctx: llama_context_p): # type: (...) -> Array[float] # type: ignore
|
||||
return _lib.llama_get_logits(ctx)
|
||||
|
||||
|
||||
_lib.llama_get_logits.argtypes = [llama_context_p]
|
||||
_lib.llama_get_logits.restype = POINTER(c_float)
|
||||
_lib.llama_get_logits.restype = c_float_p
|
||||
|
||||
|
||||
# Get the embeddings for the input
|
||||
# shape: [n_embd] (1-dimensional)
|
||||
def llama_get_embeddings(ctx: llama_context_p):
|
||||
def llama_get_embeddings(ctx: llama_context_p): # type: (...) -> Array[float] # type: ignore
|
||||
return _lib.llama_get_embeddings(ctx)
|
||||
|
||||
|
||||
_lib.llama_get_embeddings.argtypes = [llama_context_p]
|
||||
_lib.llama_get_embeddings.restype = POINTER(c_float)
|
||||
_lib.llama_get_embeddings.restype = c_float_p
|
||||
|
||||
|
||||
# Token Id -> String. Uses the vocabulary in the provided context
|
||||
@@ -433,8 +446,8 @@ _lib.llama_token_nl.restype = llama_token
|
||||
# @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
|
||||
def llama_sample_repetition_penalty(
|
||||
ctx: llama_context_p,
|
||||
candidates,
|
||||
last_tokens_data,
|
||||
candidates, # type: _Pointer[llama_token_data_array]
|
||||
last_tokens_data, # type: Array[llama_token]
|
||||
last_tokens_size: c_int,
|
||||
penalty: c_float,
|
||||
):
|
||||
@@ -456,8 +469,8 @@ _lib.llama_sample_repetition_penalty.restype = None
|
||||
# @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
|
||||
def llama_sample_frequency_and_presence_penalties(
|
||||
ctx: llama_context_p,
|
||||
candidates,
|
||||
last_tokens_data,
|
||||
candidates, # type: _Pointer[llama_token_data_array]
|
||||
last_tokens_data, # type: Array[llama_token]
|
||||
last_tokens_size: c_int,
|
||||
alpha_frequency: c_float,
|
||||
alpha_presence: c_float,
|
||||
@@ -484,7 +497,9 @@ _lib.llama_sample_frequency_and_presence_penalties.restype = None
|
||||
|
||||
|
||||
# @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
|
||||
def llama_sample_softmax(ctx: llama_context_p, candidates):
|
||||
def llama_sample_softmax(
|
||||
ctx: llama_context_p, candidates # type: _Pointer[llama_token_data]
|
||||
):
|
||||
return _lib.llama_sample_softmax(ctx, candidates)
|
||||
|
||||
|
||||
@@ -497,7 +512,10 @@ _lib.llama_sample_softmax.restype = None
|
||||
|
||||
# @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
|
||||
def llama_sample_top_k(
|
||||
ctx: llama_context_p, candidates, k: c_int, min_keep: c_size_t = c_size_t(1)
|
||||
ctx: llama_context_p,
|
||||
candidates, # type: _Pointer[llama_token_data_array]
|
||||
k: c_int,
|
||||
min_keep: c_size_t,
|
||||
):
|
||||
return _lib.llama_sample_top_k(ctx, candidates, k, min_keep)
|
||||
|
||||
@@ -513,7 +531,10 @@ _lib.llama_sample_top_k.restype = None
|
||||
|
||||
# @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
|
||||
def llama_sample_top_p(
|
||||
ctx: llama_context_p, candidates, p: c_float, min_keep: c_size_t = c_size_t(1)
|
||||
ctx: llama_context_p,
|
||||
candidates, # type: _Pointer[llama_token_data_array]
|
||||
p: c_float,
|
||||
min_keep: c_size_t,
|
||||
):
|
||||
return _lib.llama_sample_top_p(ctx, candidates, p, min_keep)
|
||||
|
||||
@@ -529,7 +550,10 @@ _lib.llama_sample_top_p.restype = None
|
||||
|
||||
# @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
|
||||
def llama_sample_tail_free(
|
||||
ctx: llama_context_p, candidates, z: c_float, min_keep: c_size_t = c_size_t(1)
|
||||
ctx: llama_context_p,
|
||||
candidates, # type: _Pointer[llama_token_data_array]
|
||||
z: c_float,
|
||||
min_keep: c_size_t,
|
||||
):
|
||||
return _lib.llama_sample_tail_free(ctx, candidates, z, min_keep)
|
||||
|
||||
@@ -545,7 +569,10 @@ _lib.llama_sample_tail_free.restype = None
|
||||
|
||||
# @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
|
||||
def llama_sample_typical(
|
||||
ctx: llama_context_p, candidates, p: c_float, min_keep: c_size_t = c_size_t(1)
|
||||
ctx: llama_context_p,
|
||||
candidates, # type: _Pointer[llama_token_data_array]
|
||||
p: c_float,
|
||||
min_keep: c_size_t,
|
||||
):
|
||||
return _lib.llama_sample_typical(ctx, candidates, p, min_keep)
|
||||
|
||||
@@ -559,7 +586,11 @@ _lib.llama_sample_typical.argtypes = [
|
||||
_lib.llama_sample_typical.restype = None
|
||||
|
||||
|
||||
def llama_sample_temperature(ctx: llama_context_p, candidates, temp: c_float):
|
||||
def llama_sample_temperature(
|
||||
ctx: llama_context_p,
|
||||
candidates, # type: _Pointer[llama_token_data_array]
|
||||
temp: c_float,
|
||||
):
|
||||
return _lib.llama_sample_temperature(ctx, candidates, temp)
|
||||
|
||||
|
||||
@@ -578,7 +609,12 @@ _lib.llama_sample_temperature.restype = None
|
||||
# @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
|
||||
# @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
|
||||
def llama_sample_token_mirostat(
|
||||
ctx: llama_context_p, candidates, tau: c_float, eta: c_float, m: c_int, mu
|
||||
ctx: llama_context_p,
|
||||
candidates, # type: _Pointer[llama_token_data_array]
|
||||
tau: c_float,
|
||||
eta: c_float,
|
||||
m: c_int,
|
||||
mu, # type: _Pointer[c_float]
|
||||
) -> llama_token:
|
||||
return _lib.llama_sample_token_mirostat(ctx, candidates, tau, eta, m, mu)
|
||||
|
||||
@@ -589,7 +625,7 @@ _lib.llama_sample_token_mirostat.argtypes = [
|
||||
c_float,
|
||||
c_float,
|
||||
c_int,
|
||||
POINTER(c_float),
|
||||
c_float_p,
|
||||
]
|
||||
_lib.llama_sample_token_mirostat.restype = llama_token
|
||||
|
||||
@@ -600,7 +636,11 @@ _lib.llama_sample_token_mirostat.restype = llama_token
|
||||
# @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
|
||||
# @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
|
||||
def llama_sample_token_mirostat_v2(
|
||||
ctx: llama_context_p, candidates, tau: c_float, eta: c_float, mu
|
||||
ctx: llama_context_p,
|
||||
candidates, # type: _Pointer[llama_token_data_array]
|
||||
tau: c_float,
|
||||
eta: c_float,
|
||||
mu, # type: _Pointer[c_float]
|
||||
) -> llama_token:
|
||||
return _lib.llama_sample_token_mirostat_v2(ctx, candidates, tau, eta, mu)
|
||||
|
||||
@@ -610,13 +650,16 @@ _lib.llama_sample_token_mirostat_v2.argtypes = [
|
||||
llama_token_data_array_p,
|
||||
c_float,
|
||||
c_float,
|
||||
POINTER(c_float),
|
||||
c_float_p,
|
||||
]
|
||||
_lib.llama_sample_token_mirostat_v2.restype = llama_token
|
||||
|
||||
|
||||
# @details Selects the token with the highest probability.
|
||||
def llama_sample_token_greedy(ctx: llama_context_p, candidates) -> llama_token:
|
||||
def llama_sample_token_greedy(
|
||||
ctx: llama_context_p,
|
||||
candidates, # type: _Pointer[llama_token_data_array]
|
||||
) -> llama_token:
|
||||
return _lib.llama_sample_token_greedy(ctx, candidates)
|
||||
|
||||
|
||||
@@ -628,7 +671,10 @@ _lib.llama_sample_token_greedy.restype = llama_token
|
||||
|
||||
|
||||
# @details Randomly selects a token from the candidates based on their probabilities.
|
||||
def llama_sample_token(ctx: llama_context_p, candidates) -> llama_token:
|
||||
def llama_sample_token(
|
||||
ctx: llama_context_p,
|
||||
candidates, # type: _Pointer[llama_token_data_array]
|
||||
) -> llama_token:
|
||||
return _lib.llama_sample_token(ctx, candidates)
|
||||
|
||||
|
||||
|
||||
@@ -22,12 +22,26 @@ Then visit http://localhost:8000/docs to see the interactive API docs.
|
||||
|
||||
"""
|
||||
import os
|
||||
import argparse
|
||||
|
||||
import uvicorn
|
||||
|
||||
from llama_cpp.server.app import create_app
|
||||
from llama_cpp.server.app import create_app, Settings
|
||||
|
||||
if __name__ == "__main__":
|
||||
app = create_app()
|
||||
parser = argparse.ArgumentParser()
|
||||
for name, field in Settings.__fields__.items():
|
||||
parser.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=field.type_,
|
||||
default=field.default,
|
||||
help=field.field_info.description,
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
settings = Settings(**vars(args))
|
||||
app = create_app(settings=settings)
|
||||
|
||||
uvicorn.run(
|
||||
app, host=os.getenv("HOST", "localhost"), port=int(os.getenv("PORT", 8000))
|
||||
|
||||
Reference in New Issue
Block a user