diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 6c4e153..48fde53 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -522,7 +522,7 @@ class Llama:
             if tokens_or_none is not None:
                 tokens.extend(tokens_or_none)
 
-    def create_embedding(self, input: str) -> Embedding:
+    def create_embedding(self, input: str, model: Optional[str] = None) -> Embedding:
         """Embed a string.
 
         Args:
@@ -532,6 +532,7 @@ class Llama:
             An embedding object.
         """
         assert self.ctx is not None
+        _model: str = model if model is not None else self.model_path
 
         if self.params.embedding == False:
             raise RuntimeError(
@@ -561,7 +562,7 @@ class Llama:
                     "index": 0,
                 }
             ],
-            "model": self.model_path,
+            "model": _model,
             "usage": {
                 "prompt_tokens": n_tokens,
                 "total_tokens": n_tokens,
@@ -598,6 +599,7 @@ class Llama:
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        model: Optional[str] = None,
     ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
         assert self.ctx is not None
         completion_id: str = f"cmpl-{str(uuid.uuid4())}"
@@ -610,6 +612,7 @@ class Llama:
         text: bytes = b""
         returned_characters: int = 0
         stop = stop if stop is not None else []
+        _model: str = model if model is not None else self.model_path
 
         if self.verbose:
             llama_cpp.llama_reset_timings(self.ctx)
@@ -708,7 +711,7 @@ class Llama:
                     "id": completion_id,
                     "object": "text_completion",
                     "created": created,
-                    "model": self.model_path,
+                    "model": _model,
                     "choices": [
                         {
                             "text": text[start:].decode("utf-8", errors="ignore"),
@@ -737,7 +740,7 @@ class Llama:
                 "id": completion_id,
                 "object": "text_completion",
                 "created": created,
-                "model": self.model_path,
+                "model": _model,
                 "choices": [
                     {
                         "text": text[returned_characters:].decode(
@@ -807,7 +810,7 @@ class Llama:
             "id": completion_id,
             "object": "text_completion",
             "created": created,
-            "model": self.model_path,
+            "model": _model,
             "choices": [
                 {
                     "text": text_str,
@@ -842,6 +845,7 @@ class Llama:
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        model: Optional[str] = None,
     ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.
 
@@ -883,6 +887,7 @@ class Llama:
             mirostat_mode=mirostat_mode,
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
+            model=model,
         )
         if stream:
             chunks: Iterator[CompletionChunk] = completion_or_chunks
@@ -909,6 +914,7 @@ class Llama:
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        model: Optional[str] = None,
     ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.
 
@@ -950,6 +956,7 @@ class Llama:
             mirostat_mode=mirostat_mode,
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
+            model=model,
         )
 
     def _convert_text_completion_to_chat(
@@ -1026,6 +1033,7 @@ class Llama:
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        model: Optional[str] = None,
    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
         """Generate a chat completion from a list of messages.
 
@@ -1064,6 +1072,7 @@ class Llama:
             mirostat_mode=mirostat_mode,
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
+            model=model,
         )
         if stream:
             chunks: Iterator[CompletionChunk] = completion_or_chunks  # type: ignore
diff --git a/llama_cpp/server/app.py b/llama_cpp/server/app.py
index 8a83674..e8f62e8 100644
--- a/llama_cpp/server/app.py
+++ b/llama_cpp/server/app.py
@@ -16,6 +16,10 @@ class Settings(BaseSettings):
     model: str = Field(
         description="The path to the model to use for generating completions."
     )
+    model_alias: Optional[str] = Field(
+        default=None,
+        description="The alias of the model to use for generating completions.",
+    )
     n_ctx: int = Field(default=2048, ge=1, description="The context size.")
     n_gpu_layers: int = Field(
         default=0,
@@ -64,6 +68,7 @@ class Settings(BaseSettings):
 
 router = APIRouter()
 
+settings: Optional[Settings] = None
 llama: Optional[llama_cpp.Llama] = None
 
 
@@ -101,6 +106,12 @@ def create_app(settings: Optional[Settings] = None):
     if settings.cache:
         cache = llama_cpp.LlamaCache(capacity_bytes=settings.cache_size)
         llama.set_cache(cache)
+
+    def set_settings(_settings: Settings):
+        global settings
+        settings = _settings
+
+    set_settings(settings)
     return app
 
 
@@ -112,6 +123,10 @@ def get_llama():
         yield llama
 
 
+def get_settings():
+    yield settings
+
+
 model_field = Field(description="The model to use for generating completions.")
 
 max_tokens_field = Field(
@@ -236,7 +251,6 @@ def create_completion(
     completion_or_chunks = llama(
         **request.dict(
             exclude={
-                "model",
                 "n",
                 "best_of",
                 "logit_bias",
@@ -274,7 +288,7 @@ CreateEmbeddingResponse = create_model_from_typeddict(llama_cpp.Embedding)
 def create_embedding(
     request: CreateEmbeddingRequest, llama: llama_cpp.Llama = Depends(get_llama)
 ):
-    return llama.create_embedding(**request.dict(exclude={"model", "user"}))
+    return llama.create_embedding(**request.dict(exclude={"user"}))
 
 
 class ChatCompletionRequestMessage(BaseModel):
@@ -335,7 +349,6 @@ def create_chat_completion(
     completion_or_chunks = llama.create_chat_completion(
         **request.dict(
             exclude={
-                "model",
                 "n",
                 "logit_bias",
                 "user",
@@ -378,13 +391,16 @@ GetModelResponse = create_model_from_typeddict(ModelList)
 
 @router.get("/v1/models", response_model=GetModelResponse)
 def get_models(
+    settings: Settings = Depends(get_settings),
     llama: llama_cpp.Llama = Depends(get_llama),
 ) -> ModelList:
     return {
         "object": "list",
         "data": [
             {
-                "id": llama.model_path,
+                "id": settings.model_alias
+                if settings.model_alias is not None
+                else llama.model_path,
                 "object": "model",
                 "owned_by": "me",
                 "permissions": [],
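
Usage note (not part of the patch): a minimal sketch of what the new optional `model` argument does, assuming the patched package is importable as `llama_cpp`; the model path and alias below are made-up placeholders. The argument only relabels the "model" field echoed back in completion, chat, and embedding responses; it does not change which weights are loaded.

    from llama_cpp import Llama

    # Hypothetical local model path; any compatible model file behaves the same way.
    llm = Llama(model_path="./models/ggml-model-q4_0.bin")

    # Pass an alias so OpenAI-compatible clients see a friendly name instead of the file path.
    out = llm.create_completion("Q: 2 + 2 = ", max_tokens=4, model="my-local-model")
    print(out["model"])  # "my-local-model" rather than "./models/ggml-model-q4_0.bin"

On the server side, the new `model_alias` setting plays the same role for GET /v1/models: when set, the endpoint reports the alias as the model id instead of `llama.model_path`. Since `Settings` is a pydantic `BaseSettings`, the field should also be settable through the corresponding environment variable (presumably `MODEL_ALIAS`), like the existing `MODEL` setting.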