diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 4b4fb01..7be51e1 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -268,14 +268,13 @@ class Llama:
         top_k: llama_cpp.c_int,
         top_p: llama_cpp.c_float,
         temp: llama_cpp.c_float,
-        mirostat_mode: llama_cpp.c_int,
-        mirostat_tau: llama_cpp.c_float,
-        mirostat_eta: llama_cpp.c_float,
-        mirostat_mu: llama_cpp.c_float,
-        mirostat_m: llama_cpp.c_int,
+        tfs_z: llama_cpp.c_float,
         repeat_penalty: llama_cpp.c_float,
         frequency_penalty: llama_cpp.c_float,
         presence_penalty: llama_cpp.c_float,
+        mirostat_mode: llama_cpp.c_int,
+        mirostat_tau: llama_cpp.c_float,
+        mirostat_eta: llama_cpp.c_float,
     ):
         assert self.ctx is not None
         assert len(self.eval_logits) > 0
@@ -305,33 +304,6 @@ class Llama:
             candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
             penalty=repeat_penalty,
         )
-        if mirostat_mode.value == 1:
-            llama_cpp.llama_sample_temperature(
-                ctx=self.ctx,
-                candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
-                temp=temp,
-            )
-            llama_cpp.llama_sample_token_mirostat(
-                ctx=self.ctx,
-                candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
-                tau=mirostat_tau,
-                eta=mirostat_eta,
-                mu=llama_cpp.ctypes.byref(mirostat_mu),  # type: ignore
-                m=mirostat_m
-            )
-        elif mirostat_mode.value == 2:
-            llama_cpp.llama_sample_temperature(
-                ctx=self.ctx,
-                candidates=llama_cpp.ctypes.pointer(candidates),
-                temp=temp,
-            )
-            llama_cpp.llama_sample_token_mirostat_v2(
-                ctx=self.ctx,
-                candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
-                tau=mirostat_tau,
-                eta=mirostat_eta,
-                mu=llama_cpp.ctypes.byref(mirostat_mu)  # type: ignore
-            )
         llama_cpp.llama_sample_frequency_and_presence_penalties(
             ctx=self.ctx,
             candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
@@ -340,11 +312,41 @@ class Llama:
             alpha_frequency=frequency_penalty,
             alpha_presence=presence_penalty,
         )
-        if float(temp.value) == 0.0:
+        if temp.value == 0.0:
             return llama_cpp.llama_sample_token_greedy(
                 ctx=self.ctx,
                 candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
             )
+        elif mirostat_mode.value == 1:
+            mirostat_mu = llama_cpp.c_float(2.0 * mirostat_tau.value)
+            mirostat_m = llama_cpp.c_int(100)
+            llama_cpp.llama_sample_temperature(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
+                temp=temp,
+            )
+            return llama_cpp.llama_sample_token_mirostat(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
+                tau=mirostat_tau,
+                eta=mirostat_eta,
+                mu=llama_cpp.ctypes.byref(mirostat_mu),  # type: ignore
+                m=mirostat_m,
+            )
+        elif mirostat_mode.value == 2:
+            mirostat_mu = llama_cpp.c_float(2.0 * mirostat_tau.value)
+            llama_cpp.llama_sample_temperature(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.pointer(candidates),
+                temp=temp,
+            )
+            return llama_cpp.llama_sample_token_mirostat_v2(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
+                tau=mirostat_tau,
+                eta=mirostat_eta,
+                mu=llama_cpp.ctypes.byref(mirostat_mu),  # type: ignore
+            )
         else:
             llama_cpp.llama_sample_top_k(
                 ctx=self.ctx,
@@ -355,7 +357,7 @@ class Llama:
             llama_cpp.llama_sample_tail_free(
                 ctx=self.ctx,
                 candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
-                z=llama_cpp.c_float(1.0),
+                z=tfs_z,
                 min_keep=llama_cpp.c_size_t(1),
             )
             llama_cpp.llama_sample_typical(
@@ -382,17 +384,16 @@ class Llama:
 
     def sample(
         self,
-        top_k: int,
-        top_p: float,
-        temp: float,
-        mirostat_mode: int,
-        mirostat_tau: float,
-        mirostat_eta: float,
-        mirostat_mu: float,
-        mirostat_m: int,
-        repeat_penalty: float,
+        top_k: int = 40,
+        top_p: float = 0.95,
+        temp: float = 0.80,
+        repeat_penalty: float = 1.1,
         frequency_penalty: float = 0.0,
         presence_penalty: float = 0.0,
+        tfs_z: float = 1.0,
+        mirostat_mode: int = 0,
+        mirostat_eta: float = 0.1,
+        mirostat_tau: float = 5.0,
     ):
         """Sample a token from the model.
 
@@ -417,14 +418,13 @@ class Llama:
             top_k=llama_cpp.c_int(top_k),
             top_p=llama_cpp.c_float(top_p),
             temp=llama_cpp.c_float(temp),
-            mirostat_mode=llama_cpp.c_int(mirostat_mode),
-            mirostat_mu=llama_cpp.c_float(mirostat_mu),
-            mirostat_tau=llama_cpp.c_float(mirostat_tau),
-            mirostat_eta=llama_cpp.c_float(mirostat_eta),
-            mirostat_m=llama_cpp.c_int(mirostat_m),
+            tfs_z=llama_cpp.c_float(tfs_z),
             repeat_penalty=llama_cpp.c_float(repeat_penalty),
            frequency_penalty=llama_cpp.c_float(frequency_penalty),
             presence_penalty=llama_cpp.c_float(presence_penalty),
+            mirostat_mode=llama_cpp.c_int(mirostat_mode),
+            mirostat_tau=llama_cpp.c_float(mirostat_tau),
+            mirostat_eta=llama_cpp.c_float(mirostat_eta),
         )
 
     def generate(
         self,
@@ -433,15 +433,13 @@ class Llama:
         top_k: int,
         top_p: float,
         temp: float,
-        mirostat_mode: int,
-        mirostat_tau: float,
-        mirostat_eta: float,
-        mirostat_mu: float,
-        mirostat_m: int,
         repeat_penalty: float,
+        reset: bool = True,
         frequency_penalty: float = 0.0,
         presence_penalty: float = 0.0,
-        reset: bool = True,
+        mirostat_mode: int = 0,
+        mirostat_tau: float = 5.0,
+        mirostat_eta: float = 0.1,
     ) -> Generator[
         llama_cpp.llama_token, Optional[Sequence[llama_cpp.llama_token]], None
     ]:
@@ -494,14 +492,12 @@ class Llama:
                 top_k=top_k,
                 top_p=top_p,
                 temp=temp,
+                repeat_penalty=repeat_penalty,
+                frequency_penalty=frequency_penalty,
+                presence_penalty=presence_penalty,
                 mirostat_mode=mirostat_mode,
                 mirostat_tau=mirostat_tau,
                 mirostat_eta=mirostat_eta,
-                mirostat_mu=mirostat_mu,
-                mirostat_m=mirostat_m,
-                frequency_penalty=frequency_penalty,
-                presence_penalty=presence_penalty,
-                repeat_penalty=repeat_penalty,
             )
             tokens_or_none = yield token
             tokens = [token]
@@ -571,11 +567,6 @@ class Llama:
         suffix: Optional[str] = None,
         max_tokens: int = 16,
         temperature: float = 0.8,
-        mirostat_mode: int = 0,
-        mirostat_tau: float = 5.0,
-        mirostat_eta: float = 0.1,
-        mirostat_mu: float = 10,
-        mirostat_m: int = 100,
         top_p: float = 0.95,
         logprobs: Optional[int] = None,
         echo: bool = False,
@@ -585,6 +576,9 @@ class Llama:
         repeat_penalty: float = 1.1,
         top_k: int = 40,
         stream: bool = False,
+        mirostat_mode: int = 0,
+        mirostat_tau: float = 5.0,
+        mirostat_eta: float = 0.1,
     ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
         assert self.ctx is not None
         completion_id: str = f"cmpl-{str(uuid.uuid4())}"
@@ -643,8 +637,6 @@ class Llama:
             mirostat_mode=mirostat_mode,
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
-            mirostat_mu=mirostat_mu,
-            mirostat_m=mirostat_m,
             frequency_penalty=frequency_penalty,
             presence_penalty=presence_penalty,
             repeat_penalty=repeat_penalty,
@@ -817,11 +809,6 @@ class Llama:
         suffix: Optional[str] = None,
         max_tokens: int = 128,
         temperature: float = 0.8,
-        mirostat_mode: int = 0,
-        mirostat_tau: float = 5.0,
-        mirostat_eta: float = 0.1,
-        mirostat_mu: float = 10,
-        mirostat_m: int = 100,
         top_p: float = 0.95,
         logprobs: Optional[int] = None,
         echo: bool = False,
@@ -831,6 +818,9 @@ class Llama:
         repeat_penalty: float = 1.1,
         top_k: int = 40,
         stream: bool = False,
+        mirostat_mode: int = 0,
+        mirostat_tau: float = 5.0,
+        mirostat_eta: float = 0.1,
     ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.
 
@@ -859,11 +849,6 @@ class Llama:
             suffix=suffix,
             max_tokens=max_tokens,
             temperature=temperature,
-            mirostat_mode=mirostat_mode,
-            mirostat_tau=mirostat_tau,
-            mirostat_eta=mirostat_eta,
-            mirostat_mu=mirostat_mu,
-            mirostat_m=mirostat_m,
             top_p=top_p,
             logprobs=logprobs,
             echo=echo,
@@ -873,6 +858,9 @@ class Llama:
             repeat_penalty=repeat_penalty,
             top_k=top_k,
             stream=stream,
+            mirostat_mode=mirostat_mode,
+            mirostat_tau=mirostat_tau,
+            mirostat_eta=mirostat_eta,
         )
         if stream:
             chunks: Iterator[CompletionChunk] = completion_or_chunks
@@ -886,11 +874,6 @@ class Llama:
         suffix: Optional[str] = None,
         max_tokens: int = 128,
         temperature: float = 0.8,
-        mirostat_mode: int = 0,
-        mirostat_tau: float = 5.0,
-        mirostat_eta: float = 0.1,
-        mirostat_mu: float = 10,
-        mirostat_m: int = 100,
         top_p: float = 0.95,
         logprobs: Optional[int] = None,
         echo: bool = False,
@@ -900,6 +883,9 @@ class Llama:
         repeat_penalty: float = 1.1,
         top_k: int = 40,
         stream: bool = False,
+        mirostat_mode: int = 0,
+        mirostat_tau: float = 5.0,
+        mirostat_eta: float = 0.1,
     ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.
 
@@ -928,11 +914,6 @@ class Llama:
             suffix=suffix,
             max_tokens=max_tokens,
             temperature=temperature,
-            mirostat_mode=mirostat_mode,
-            mirostat_tau=mirostat_tau,
-            mirostat_eta=mirostat_eta,
-            mirostat_mu=mirostat_mu,
-            mirostat_m=mirostat_m,
             top_p=top_p,
             logprobs=logprobs,
             echo=echo,
@@ -942,6 +923,9 @@ class Llama:
             repeat_penalty=repeat_penalty,
             top_k=top_k,
             stream=stream,
+            mirostat_mode=mirostat_mode,
+            mirostat_tau=mirostat_tau,
+            mirostat_eta=mirostat_eta,
        )
 
     def _convert_text_completion_to_chat(
@@ -1014,6 +998,9 @@ class Llama:
         presence_penalty: float = 0.0,
         frequency_penalty: float = 0.0,
         repeat_penalty: float = 1.1,
+        mirostat_mode: int = 0,
+        mirostat_tau: float = 5.0,
+        mirostat_eta: float = 0.1,
     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
         """Generate a chat completion from a list of messages.
 
@@ -1048,6 +1035,9 @@ class Llama:
             repeat_penalty=repeat_penalty,
             presence_penalty=presence_penalty,
             frequency_penalty=frequency_penalty,
+            mirostat_mode=mirostat_mode,
+            mirostat_tau=mirostat_tau,
+            mirostat_eta=mirostat_eta,
         )
         if stream:
             chunks: Iterator[CompletionChunk] = completion_or_chunks  # type: ignore
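Note on the reworked API: `mirostat_mu` and `mirostat_m` are no longer public parameters. `mu` is re-seeded internally as `2.0 * mirostat_tau` on each call to `_sample`, so it does not persist across tokens within a generation, and `m` is fixed at 100 for mirostat v1. The new `tfs_z` parameter is exposed on `Llama.sample` only (default 1.0, which leaves tail-free sampling disabled) and is not plumbed through the completion APIs in this diff. A minimal usage sketch under these assumptions; the model path and prompt below are placeholders, not part of this change:

```python
from llama_cpp import Llama

# Hypothetical model path; substitute a real ggml model file.
llm = Llama(model_path="./models/7B/ggml-model.bin")

# Mirostat v2 sampling: only mode, tau, and eta are exposed now;
# mu is derived internally from tau on every sampled token.
output = llm(
    "Q: Name the planets in the solar system. A: ",
    max_tokens=64,
    mirostat_mode=2,
    mirostat_tau=5.0,
    mirostat_eta=0.1,
)
print(output["choices"][0]["text"])
```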