Merge branch 'main' into server-embedding

Author: Andrei Betlen
Date: 2023-05-21 21:21:38 -04:00
5 changed files with 239 additions and 68 deletions

llama_cpp/llama.py

@@ -127,7 +127,6 @@ class Llama:
         self.params = llama_cpp.llama_context_default_params()
         self.params.n_ctx = n_ctx
-        self.params.n_parts = n_parts
         self.params.n_gpu_layers = n_gpu_layers
         self.params.seed = seed
         self.params.f16_kv = f16_kv
@@ -149,6 +148,10 @@ class Llama:
         self.lora_base = lora_base
         self.lora_path = lora_path
+        ### DEPRECATED ###
+        self.n_parts = n_parts
+        ### DEPRECATED ###
+
         if not os.path.exists(model_path):
             raise ValueError(f"Model path does not exist: {model_path}")
@@ -173,6 +176,30 @@ class Llama:
         if self.verbose:
             print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
+
+        n_vocab = self.n_vocab()
+        n_ctx = self.n_ctx()
+        data = (llama_cpp.llama_token_data * n_vocab)(
+            *[
+                llama_cpp.llama_token_data(
+                    id=llama_cpp.llama_token(i),
+                    logit=llama_cpp.c_float(0.0),
+                    p=llama_cpp.c_float(0.0),
+                )
+                for i in range(n_vocab)
+            ]
+        )
+        size = llama_cpp.c_size_t(n_vocab)
+        sorted = False
+        candidates = llama_cpp.llama_token_data_array(
+            data=data,
+            size=size,
+            sorted=sorted,
+        )
+        self._candidates = candidates
+        self._token_nl = Llama.token_nl()
+        self._token_eos = Llama.token_eos()
+
     def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
         """Tokenize a string.
@@ -293,8 +320,8 @@ class Llama:
     ):
         assert self.ctx is not None
         assert len(self.eval_logits) > 0
-        n_vocab = int(llama_cpp.llama_n_vocab(self.ctx))
-        n_ctx = int(llama_cpp.llama_n_ctx(self.ctx))
+        n_vocab = self.n_vocab()
+        n_ctx = self.n_ctx()
         top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k
         last_n_tokens_size = (
             llama_cpp.c_int(n_ctx)
@@ -302,24 +329,14 @@ class Llama:
             else last_n_tokens_size
         )
         logits = self.eval_logits[-1]
-        nl_logit = logits[int(Llama.token_nl())]
-        data = (llama_cpp.llama_token_data * n_vocab)(
-            *[
-                llama_cpp.llama_token_data(
-                    id=llama_cpp.llama_token(i),
-                    logit=logits[i],
-                    p=llama_cpp.c_float(0.0),
-                )
-                for i in range(n_vocab)
-            ]
-        )
-        size = llama_cpp.c_size_t(n_vocab)
-        sorted = False
-        candidates = llama_cpp.llama_token_data_array(
-            data=data,
-            size=size,
-            sorted=sorted,
-        )
+        nl_logit = logits[self._token_nl]
+        candidates = self._candidates
+        for i, logit in enumerate(logits):
+            candidates.data[i].id = llama_cpp.llama_token(i)
+            candidates.data[i].logit = llama_cpp.c_float(logit)
+            candidates.data[i].p = llama_cpp.c_float(0.0)
+        candidates.sorted = llama_cpp.c_bool(False)
+        candidates.size = llama_cpp.c_size_t(n_vocab)
         llama_cpp.llama_sample_repetition_penalty(
             ctx=self.ctx,
             last_tokens_data=last_n_tokens_data,
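This rewrite is the payoff for the cached array: instead of constructing a fresh n_vocab-element ctypes array on every sampled token, the sampler now overwrites the cached one field by field and resets its size/sorted metadata. A rough micro-benchmark sketch of the two approaches (plain ctypes, illustrative sizes):

```python
import ctypes
import timeit

class TokenData(ctypes.Structure):
    _fields_ = [("id", ctypes.c_int), ("logit", ctypes.c_float), ("p", ctypes.c_float)]

n_vocab = 32_000
logits = [0.0] * n_vocab
cached = (TokenData * n_vocab)()

def rebuild():  # old approach: allocate and construct on every call
    return (TokenData * n_vocab)(
        *[TokenData(id=i, logit=logits[i], p=0.0) for i in range(n_vocab)]
    )

def refill():  # new approach: mutate the preallocated array in place
    for i, logit in enumerate(logits):
        cached[i].id = i
        cached[i].logit = logit
        cached[i].p = 0.0

# refill() skips ~n_vocab temporary TokenData objects per sampled token.
print("rebuild:", timeit.timeit(rebuild, number=20))
print("refill: ", timeit.timeit(refill, number=20))
```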
@@ -336,7 +353,7 @@ class Llama:
             alpha_presence=presence_penalty,
         )
         if not penalize_nl:
-            candidates.data[int(Llama.token_nl())].logit = nl_logit
+            candidates.data[self._token_nl].logit = llama_cpp.c_float(nl_logit)
         if temp.value == 0.0:
             return llama_cpp.llama_sample_token_greedy(
                 ctx=self.ctx,
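The nl_logit bookkeeping is a save-and-restore guard: the newline logit is captured before the repetition and frequency/presence penalties run, then written back when penalize_nl is off, so the penalties never suppress line breaks (the new line also wraps the restored value in llama_cpp.c_float for consistency with the other field writes). A pure-Python sketch of the idea, as a hypothetical helper using llama.cpp's sign-dependent penalty rule:

```python
def penalize(logits, recent_ids, penalty, nl_id, penalize_nl=False):
    """Hypothetical sketch: apply a repetition penalty, optionally
    exempting the newline token via save-and-restore."""
    saved_nl = logits[nl_id]  # capture before penalizing
    for tid in set(recent_ids):
        if logits[tid] > 0:
            logits[tid] /= penalty  # pull positive logits toward zero
        else:
            logits[tid] *= penalty  # push negative logits further down
    if not penalize_nl:
        logits[nl_id] = saved_nl  # newline escapes the penalty
    return logits
```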
@@ -685,7 +702,7 @@ class Llama:
             presence_penalty=presence_penalty,
             repeat_penalty=repeat_penalty,
         ):
-            if token == Llama.token_eos():
+            if token == self._token_eos:
                 text = self.detokenize(completion_tokens)
                 finish_reason = "stop"
                 break
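With _token_eos cached at construction time, the completion loop compares plain ints instead of calling back into the bindings for every generated token. A usage sketch of the same check in a manual generation loop (model path is illustrative, and generate's sampling defaults are assumed):

```python
# Hypothetical usage of the cached EOS id in a manual generation loop.
llm = Llama(model_path="./models/ggml-model.bin")  # illustrative path
tokens = llm.tokenize(b"Q: Name the planets in the solar system. A: ")

completion = []
for token in llm.generate(tokens):
    if token == llm._token_eos:  # int comparison, no per-token ctypes call
        break
    completion.append(token)

print(llm.detokenize(completion).decode("utf-8"))
```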
@@ -1237,7 +1254,6 @@ class Llama:
             verbose=self.verbose,
             model_path=self.model_path,
             n_ctx=self.params.n_ctx,
-            n_parts=self.params.n_parts,
             n_gpu_layers=self.params.n_gpu_layers,
             seed=self.params.seed,
             f16_kv=self.params.f16_kv,
@@ -1251,6 +1267,9 @@ class Llama:
             n_threads=self.n_threads,
             lora_base=self.lora_base,
             lora_path=self.lora_path,
+            ### DEPRECATED ###
+            n_parts=self.n_parts,
+            ### DEPRECATED ###
         )

     def __setstate__(self, state):
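Because __getstate__ now round-trips the deprecated n_parts alongside the live parameters, pickled instances keep constructor-compatible state. A sketch of the resulting behavior (illustrative path; the weights are reloaded from model_path when the object is restored):

```python
import pickle

# Sketch: __getstate__ captures the constructor arguments, so pickling works;
# __setstate__ rebuilds the instance, reloading weights from model_path.
llm = Llama(model_path="./models/ggml-model.bin", n_ctx=512)  # illustrative
restored = pickle.loads(pickle.dumps(llm))
assert restored.n_ctx() == 512  # uses the accessor added below
```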
@@ -1303,6 +1322,21 @@ class Llama:
         if llama_cpp.llama_set_state_data(self.ctx, state.llama_state) != state_size:
             raise RuntimeError("Failed to set llama state data")

+    def n_ctx(self) -> int:
+        """Return the context window size."""
+        assert self.ctx is not None
+        return llama_cpp.llama_n_ctx(self.ctx)
+
+    def n_embd(self) -> int:
+        """Return the embedding size."""
+        assert self.ctx is not None
+        return llama_cpp.llama_n_embd(self.ctx)
+
+    def n_vocab(self) -> int:
+        """Return the vocabulary size."""
+        assert self.ctx is not None
+        return llama_cpp.llama_n_vocab(self.ctx)
+
     @staticmethod
     def token_eos() -> int:
         """Return the end-of-sequence token."""