diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 41e6fd8..7b53112 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -261,7 +261,7 @@ class Llama:
         ]
         self.eval_logits.extend(logits)
 
-    def _sample_top_p_top_k(
+    def _sample(
         self,
         last_n_tokens_data,  # type: llama_cpp.Array[llama_cpp.llama_token]
         last_n_tokens_size: llama_cpp.c_int,
@@ -269,6 +269,8 @@ class Llama:
         top_p: llama_cpp.c_float,
         temp: llama_cpp.c_float,
         repeat_penalty: llama_cpp.c_float,
+        frequency_penalty: llama_cpp.c_float,
+        presence_penalty: llama_cpp.c_float,
     ):
         assert self.ctx is not None
         assert len(self.eval_logits) > 0
@@ -298,6 +300,14 @@ class Llama:
             candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
             penalty=repeat_penalty,
         )
+        llama_cpp.llama_sample_frequency_and_presence_penalties(
+            ctx=self.ctx,
+            candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
+            last_tokens_data=last_n_tokens_data,
+            last_tokens_size=last_n_tokens_size,
+            alpha_frequency=frequency_penalty,
+            alpha_presence=presence_penalty,
+        )
         if float(temp.value) == 0.0:
             return llama_cpp.llama_sample_token_greedy(
                 ctx=self.ctx,
@@ -344,6 +354,8 @@ class Llama:
         top_p: float,
         temp: float,
         repeat_penalty: float,
+        frequency_penalty: float = 0.0,
+        presence_penalty: float = 0.0,
     ):
         """Sample a token from the model.
 
@@ -360,7 +372,7 @@ class Llama:
         last_n_tokens_data = [llama_cpp.llama_token(0)] * max(
             0, self.last_n_tokens_size - len(self.eval_tokens)
         ) + list(self.eval_tokens)[-self.last_n_tokens_size :]
-        return self._sample_top_p_top_k(
+        return self._sample(
             last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)(
                 *last_n_tokens_data
             ),
@@ -369,6 +381,8 @@ class Llama:
             top_p=llama_cpp.c_float(top_p),
             temp=llama_cpp.c_float(temp),
             repeat_penalty=llama_cpp.c_float(repeat_penalty),
+            frequency_penalty=llama_cpp.c_float(frequency_penalty),
+            presence_penalty=llama_cpp.c_float(presence_penalty),
         )
 
     def generate(
@@ -378,6 +392,8 @@ class Llama:
         top_p: float,
         temp: float,
         repeat_penalty: float,
+        frequency_penalty: float = 0.0,
+        presence_penalty: float = 0.0,
         reset: bool = True,
     ) -> Generator[
         llama_cpp.llama_token, Optional[Sequence[llama_cpp.llama_token]], None
@@ -431,6 +447,8 @@ class Llama:
                 top_k=top_k,
                 top_p=top_p,
                 temp=temp,
+                frequency_penalty=frequency_penalty,
+                presence_penalty=presence_penalty,
                 repeat_penalty=repeat_penalty,
             )
             tokens_or_none = yield token
@@ -505,6 +523,8 @@ class Llama:
         logprobs: Optional[int] = None,
         echo: bool = False,
         stop: Optional[List[str]] = [],
+        frequency_penalty: float = 0.0,
+        presence_penalty: float = 0.0,
         repeat_penalty: float = 1.1,
         top_k: int = 40,
         stream: bool = False,
@@ -563,6 +583,8 @@ class Llama:
             top_k=top_k,
             top_p=top_p,
             temp=temperature,
+            frequency_penalty=frequency_penalty,
+            presence_penalty=presence_penalty,
             repeat_penalty=repeat_penalty,
         ):
             if token == llama_cpp.llama_token_eos():
@@ -737,6 +759,8 @@ class Llama:
         logprobs: Optional[int] = None,
         echo: bool = False,
         stop: Optional[List[str]] = [],
+        frequency_penalty: float = 0.0,
+        presence_penalty: float = 0.0,
         repeat_penalty: float = 1.1,
         top_k: int = 40,
         stream: bool = False,
@@ -772,6 +796,8 @@ class Llama:
             logprobs=logprobs,
             echo=echo,
             stop=stop,
+            frequency_penalty=frequency_penalty,
+            presence_penalty=presence_penalty,
             repeat_penalty=repeat_penalty,
             top_k=top_k,
             stream=stream,
@@ -792,6 +818,8 @@ class Llama:
         logprobs: Optional[int] = None,
         echo: bool = False,
         stop: Optional[List[str]] = [],
+        frequency_penalty: float = 0.0,
+        presence_penalty: float = 0.0,
         repeat_penalty: float = 1.1,
         top_k: int = 40,
         stream: bool = False,
@@ -827,6 +855,8 @@ class Llama:
             logprobs=logprobs,
             echo=echo,
             stop=stop,
+            frequency_penalty=frequency_penalty,
+            presence_penalty=presence_penalty,
             repeat_penalty=repeat_penalty,
             top_k=top_k,
             stream=stream,
@@ -899,6 +929,8 @@ class Llama:
         stream: bool = False,
         stop: Optional[List[str]] = [],
         max_tokens: int = 256,
+        presence_penalty: float = 0.0,
+        frequency_penalty: float = 0.0,
         repeat_penalty: float = 1.1,
     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
         """Generate a chat completion from a list of messages.
@@ -932,6 +964,8 @@ class Llama:
             stream=stream,
             max_tokens=max_tokens,
             repeat_penalty=repeat_penalty,
+            presence_penalty=presence_penalty,
+            frequency_penalty=frequency_penalty,
         )
         if stream:
             chunks: Iterator[CompletionChunk] = completion_or_chunks  # type: ignore
diff --git a/llama_cpp/server/app.py b/llama_cpp/server/app.py
index b46914e..c9f2aef 100644
--- a/llama_cpp/server/app.py
+++ b/llama_cpp/server/app.py
@@ -214,8 +214,6 @@ def create_completion(
         exclude={
             "model",
             "n",
-            "frequency_penalty",
-            "presence_penalty",
             "best_of",
             "logit_bias",
             "user",
@@ -315,8 +313,6 @@ def create_chat_completion(
         exclude={
             "model",
             "n",
-            "presence_penalty",
-            "frequency_penalty",
             "logit_bias",
             "user",
         }
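
For context, a minimal usage sketch of the new sampling-penalty parameters exposed by this change. The model path and prompt below are placeholders, and the inline comments describe the intended semantics as implied by the diff (the values are forwarded to `llama_sample_frequency_and_presence_penalties`), not an official example.

```python
# Sketch only: exercises the new frequency_penalty / presence_penalty
# keyword arguments added to the high-level completion API in this diff.
from llama_cpp import Llama

llama = Llama(model_path="./models/ggml-model.bin")  # placeholder path

# Both penalties default to 0.0, so existing callers are unaffected.
completion = llama.create_completion(
    "Q: Name the planets in the solar system. A:",
    max_tokens=48,
    frequency_penalty=0.5,  # forwarded as alpha_frequency to llama.cpp sampling
    presence_penalty=0.5,   # forwarded as alpha_presence to llama.cpp sampling
)
print(completion["choices"][0]["text"])
```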