Mirror of https://github.com/abetlen/llama-cpp-python.git (synced 2023-09-07 17:34:22 +03:00)
Add model_alias option to override model_path in completions. Closes #39
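In short, the diff below (against the FastAPI server module, llama_cpp/server/app.py) adds an optional model_alias setting, stores the active Settings in a module-level global exposed through a get_settings dependency, stops stripping the client-supplied model field before calling into llama_cpp.Llama, and makes /v1/models report the alias instead of the model path whenever one is configured.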
@@ -16,6 +16,10 @@ class Settings(BaseSettings):
     model: str = Field(
         description="The path to the model to use for generating completions."
     )
+    model_alias: Optional[str] = Field(
+        default=None,
+        description="The alias of the model to use for generating completions.",
+    )
     n_ctx: int = Field(default=2048, ge=1, description="The context size.")
     n_gpu_layers: int = Field(
         default=0,
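A minimal usage sketch of the new field (not part of the commit), assuming llama-cpp-python is installed and that Settings and create_app are importable from llama_cpp.server.app as in this repo; the model path and alias values are hypothetical. Because Settings extends pydantic's BaseSettings, the same values can also come from MODEL and MODEL_ALIAS environment variables under pydantic's default env handling.

# Sketch: configuring the new model_alias setting programmatically
# (hypothetical paths/values; the create_app call is left commented out
# because it would actually load the model file).
from llama_cpp.server.app import Settings, create_app

settings = Settings(
    model="./models/ggml-model-q4_0.bin",  # on-disk model path (hypothetical)
    model_alias="gpt-3.5-turbo",           # name clients will see via /v1/models
)
# app = create_app(settings=settings)     # build the FastAPI app, then serve with uvicorn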
@@ -64,6 +68,7 @@ class Settings(BaseSettings):
 
 router = APIRouter()
 
+settings: Optional[Settings] = None
 llama: Optional[llama_cpp.Llama] = None
 
 
@@ -101,6 +106,12 @@ def create_app(settings: Optional[Settings] = None):
     if settings.cache:
         cache = llama_cpp.LlamaCache(capacity_bytes=settings.cache_size)
         llama.set_cache(cache)
+
+    def set_settings(_settings: Settings):
+        global settings
+        settings = _settings
+
+    set_settings(settings)
     return app
 
 
@@ -112,6 +123,10 @@ def get_llama():
     yield llama
 
 
+def get_settings():
+    yield settings
+
+
 model_field = Field(description="The model to use for generating completions.")
 
 max_tokens_field = Field(
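The set_settings/get_settings pair above follows a common FastAPI pattern: module-level state is assigned once when the app is created, then handed to request handlers through a generator dependency. A standalone sketch of that pattern (illustrative names, not the repo's code):

# Standalone sketch of the global-state + generator-dependency pattern
# used by set_settings/get_settings (illustrative only).
from typing import Optional

from fastapi import Depends, FastAPI

_state: Optional[dict] = None


def set_state(state: dict) -> None:
    global _state
    _state = state


def get_state():
    yield _state


app = FastAPI()
set_state({"model_alias": "gpt-3.5-turbo"})  # hypothetical value


@app.get("/state")
def read_state(state: dict = Depends(get_state)):
    return state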
@@ -236,7 +251,6 @@ def create_completion(
     completion_or_chunks = llama(
         **request.dict(
             exclude={
-                "model",
                 "n",
                 "best_of",
                 "logit_bias",
@@ -274,7 +288,7 @@ CreateEmbeddingResponse = create_model_from_typeddict(llama_cpp.Embedding)
 def create_embedding(
     request: CreateEmbeddingRequest, llama: llama_cpp.Llama = Depends(get_llama)
 ):
-    return llama.create_embedding(**request.dict(exclude={"model", "user"}))
+    return llama.create_embedding(**request.dict(exclude={"user"}))
 
 
 class ChatCompletionRequestMessage(BaseModel):
@@ -335,7 +349,6 @@ def create_chat_completion(
     completion_or_chunks = llama.create_chat_completion(
         **request.dict(
             exclude={
-                "model",
                 "n",
                 "logit_bias",
                 "user",
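In both handlers above (and in create_embedding), "model" is dropped from the exclude set, so the model name sent by the client is now forwarded to the underlying Llama call rather than discarded. A standalone sketch of the pydantic .dict(exclude=...) behaviour this relies on (illustrative request model, not the repo's):

# Sketch: pydantic's .dict(exclude=...) keeps every field not listed,
# so removing "model" from the set forwards it as a keyword argument.
from pydantic import BaseModel


class ExampleRequest(BaseModel):
    model: str = "gpt-3.5-turbo"
    prompt: str = "Hello"
    n: int = 1


print(ExampleRequest().dict(exclude={"n"}))
# -> {'model': 'gpt-3.5-turbo', 'prompt': 'Hello'}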
@@ -378,13 +391,16 @@ GetModelResponse = create_model_from_typeddict(ModelList)
 
 @router.get("/v1/models", response_model=GetModelResponse)
 def get_models(
+    settings: Settings = Depends(get_settings),
     llama: llama_cpp.Llama = Depends(get_llama),
 ) -> ModelList:
     return {
         "object": "list",
         "data": [
             {
-                "id": llama.model_path,
+                "id": settings.model_alias
+                if settings.model_alias is not None
+                else llama.model_path,
                 "object": "model",
                 "owned_by": "me",
                 "permissions": [],
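With an alias configured, the /v1/models endpoint reports the alias as the model id rather than the on-disk path. A quick way to check against a running server (hypothetical host/port; response shape taken from the handler above, values illustrative):

# Sketch: listing models from a running llama-cpp-python server
# (hypothetical host/port; requires the server to be up).
import json
import urllib.request

with urllib.request.urlopen("http://localhost:8000/v1/models") as resp:
    print(json.dumps(json.load(resp), indent=2))

# Expected shape (values illustrative):
# {
#   "object": "list",
#   "data": [
#     {"id": "gpt-3.5-turbo", "object": "model", "owned_by": "me", "permissions": []}
#   ]
# }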