Merge branch 'Maximilian-Winter/main' into main

This commit is contained in:
Andrei Betlen
2023-05-26 02:56:11 -04:00
7 changed files with 288 additions and 1 deletions

View File

@@ -4,7 +4,17 @@ import uuid
import time
import math
import multiprocessing
from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque, Tuple
from typing import (
List,
Optional,
Union,
Generator,
Sequence,
Iterator,
Deque,
Tuple,
Callable,
)
from collections import deque, OrderedDict
from . import llama_cpp
@@ -72,6 +82,24 @@ class LlamaState:
self.llama_state_size = llama_state_size
LogitsProcessor = Callable[[List[int], List[float]], List[float]]
class LogitsProcessorList(List[LogitsProcessor]):
def __call__(self, input_ids: List[int], scores: List[float]) -> List[float]:
for processor in self:
scores = processor(input_ids, scores)
return scores
StoppingCriteria = Callable[[List[int], List[float]], bool]
class StoppingCriteriaList(List[StoppingCriteria]):
def __call__(self, input_ids: List[int], logits: List[float]) -> bool:
return any([stopping_criteria(input_ids, logits) for stopping_criteria in self])
class Llama:
"""High-level Python wrapper for a llama.cpp model."""
@@ -314,6 +342,7 @@ class Llama:
mirostat_tau: llama_cpp.c_float,
mirostat_eta: llama_cpp.c_float,
penalize_nl: bool = True,
logits_processor: Optional[LogitsProcessorList] = None,
):
assert self.ctx is not None
assert len(self.eval_logits) > 0
@@ -326,6 +355,10 @@ class Llama:
else last_n_tokens_size
)
logits = self.eval_logits[-1]
if logits_processor is not None:
logits = logits_processor(list(self.eval_tokens), logits)
nl_logit = logits[self._token_nl]
candidates = self._candidates
for i, logit in enumerate(logits):
@@ -434,6 +467,7 @@ class Llama:
mirostat_eta: float = 0.1,
mirostat_tau: float = 5.0,
penalize_nl: bool = True,
logits_processor: Optional[LogitsProcessorList] = None,
):
"""Sample a token from the model.
@@ -466,6 +500,7 @@ class Llama:
mirostat_tau=llama_cpp.c_float(mirostat_tau),
mirostat_eta=llama_cpp.c_float(mirostat_eta),
penalize_nl=penalize_nl,
logits_processor=logits_processor,
)
def generate(
@@ -482,6 +517,8 @@ class Llama:
mirostat_mode: int = 0,
mirostat_tau: float = 5.0,
mirostat_eta: float = 0.1,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
) -> Generator[int, Optional[Sequence[int]], None]:
"""Create a generator of tokens from a prompt.
@@ -539,7 +576,12 @@ class Llama:
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
logits_processor=logits_processor,
)
if stopping_criteria is not None and stopping_criteria(
list(self.eval_tokens), self.eval_logits[-1]
):
return
tokens_or_none = yield token
tokens = [token]
if tokens_or_none is not None:
@@ -637,6 +679,7 @@ class Llama:
model: Optional[str] = None,
) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
assert self.ctx is not None
completion_id: str = f"cmpl-{str(uuid.uuid4())}"
created: int = int(time.time())
completion_tokens: List[int] = []
@@ -1334,6 +1377,11 @@ class Llama:
assert self.ctx is not None
return llama_cpp.llama_n_vocab(self.ctx)
def tokenizer(self) -> "LlamaTokenizer":
"""Return the tokenizer for this model."""
assert self.ctx is not None
return LlamaTokenizer(self)
@staticmethod
def token_eos() -> int:
"""Return the end-of-sequence token."""
@@ -1364,3 +1412,18 @@ class Llama:
else:
break
return longest_prefix
class LlamaTokenizer:
def __init__(self, llama: Llama):
self.llama = llama
def encode(self, text: str) -> List[int]:
return self.llama.tokenize(text.encode("utf-8", errors="ignore"))
def decode(self, tokens: List[int]) -> str:
return self.llama.detokenize(tokens).decode("utf-8", errors="ignore")
@classmethod
def from_ggml_file(cls, path: str) -> "LlamaTokenizer":
return cls(Llama(model_path=path, vocab_only=True))