Add model_alias option to override model_path in completions. Closes #39

Author: Andrei Betlen
Date: 2023-05-16 17:22:00 -04:00
parent 214589e462
commit a3352923c7
2 changed files with 34 additions and 9 deletions

llama_cpp/llama.py

@@ -522,7 +522,7 @@ class Llama:
             if tokens_or_none is not None:
                 tokens.extend(tokens_or_none)
 
-    def create_embedding(self, input: str) -> Embedding:
+    def create_embedding(self, input: str, model: Optional[str] = None) -> Embedding:
         """Embed a string.
 
         Args:
@@ -532,6 +532,7 @@ class Llama:
             An embedding object.
         """
         assert self.ctx is not None
+        _model: str = model if model is not None else self.model_path
 
         if self.params.embedding == False:
             raise RuntimeError(
@@ -561,7 +562,7 @@ class Llama:
                     "index": 0,
                 }
             ],
-            "model": self.model_path,
+            "model": _model,
             "usage": {
                 "prompt_tokens": n_tokens,
                 "total_tokens": n_tokens,
@@ -598,6 +599,7 @@ class Llama:
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        model: Optional[str] = None,
     ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
         assert self.ctx is not None
         completion_id: str = f"cmpl-{str(uuid.uuid4())}"
@@ -610,6 +612,7 @@ class Llama:
         text: bytes = b""
         returned_characters: int = 0
         stop = stop if stop is not None else []
+        _model: str = model if model is not None else self.model_path
 
         if self.verbose:
             llama_cpp.llama_reset_timings(self.ctx)
@@ -708,7 +711,7 @@ class Llama:
                     "id": completion_id,
                     "object": "text_completion",
                     "created": created,
-                    "model": self.model_path,
+                    "model": _model,
                     "choices": [
                         {
                             "text": text[start:].decode("utf-8", errors="ignore"),
@@ -737,7 +740,7 @@ class Llama:
                 "id": completion_id,
                 "object": "text_completion",
                 "created": created,
-                "model": self.model_path,
+                "model": _model,
                 "choices": [
                     {
                         "text": text[returned_characters:].decode(
@@ -807,7 +810,7 @@ class Llama:
             "id": completion_id,
             "object": "text_completion",
             "created": created,
-            "model": self.model_path,
+            "model": _model,
             "choices": [
                 {
                     "text": text_str,
@@ -842,6 +845,7 @@ class Llama:
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        model: Optional[str] = None,
     ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.
 
@@ -883,6 +887,7 @@ class Llama:
             mirostat_mode=mirostat_mode,
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
+            model=model,
         )
         if stream:
             chunks: Iterator[CompletionChunk] = completion_or_chunks
@@ -909,6 +914,7 @@ class Llama:
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        model: Optional[str] = None,
     ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.
 
@@ -950,6 +956,7 @@ class Llama:
             mirostat_mode=mirostat_mode,
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
+            model=model,
         )
 
     def _convert_text_completion_to_chat(
@@ -1026,6 +1033,7 @@ class Llama:
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        model: Optional[str] = None,
     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
         """Generate a chat completion from a list of messages.
 
@@ -1064,6 +1072,7 @@ class Llama:
             mirostat_mode=mirostat_mode,
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
+            model=model,
         )
         if stream:
             chunks: Iterator[CompletionChunk] = completion_or_chunks  # type: ignore
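With this change, the completion and embedding methods take an optional model argument that only overrides the "model" string reported in their responses; when it is omitted they fall back to self.model_path as before. A minimal usage sketch, not part of the commit (the model path and alias below are placeholders):

from llama_cpp import Llama

llm = Llama(model_path="./models/ggml-model-q4_0.bin")  # placeholder path to a local model

completion = llm.create_completion(
    "Q: Name the planets in the solar system. A:",
    max_tokens=32,
    model="my-model-alias",  # echoed back in the response instead of model_path
)
print(completion["model"])  # -> "my-model-alias"

# Omitting the argument keeps the old behaviour:
print(llm("Hello", max_tokens=1)["model"])  # -> "./models/ggml-model-q4_0.bin"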

llama_cpp/server/app.py

@@ -16,6 +16,10 @@ class Settings(BaseSettings):
     model: str = Field(
         description="The path to the model to use for generating completions."
     )
+    model_alias: Optional[str] = Field(
+        default=None,
+        description="The alias of the model to use for generating completions.",
+    )
     n_ctx: int = Field(default=2048, ge=1, description="The context size.")
     n_gpu_layers: int = Field(
         default=0,
@@ -64,6 +68,7 @@ class Settings(BaseSettings):
 router = APIRouter()
 
+settings: Optional[Settings] = None
 llama: Optional[llama_cpp.Llama] = None
@@ -101,6 +106,12 @@ def create_app(settings: Optional[Settings] = None):
     if settings.cache:
         cache = llama_cpp.LlamaCache(capacity_bytes=settings.cache_size)
         llama.set_cache(cache)
+
+    def set_settings(_settings: Settings):
+        global settings
+        settings = _settings
+
+    set_settings(settings)
     return app
@@ -112,6 +123,10 @@ def get_llama():
     yield llama
 
+
+def get_settings():
+    yield settings
+
 model_field = Field(description="The model to use for generating completions.")
 
 max_tokens_field = Field(
@@ -236,7 +251,6 @@ def create_completion(
     completion_or_chunks = llama(
         **request.dict(
             exclude={
-                "model",
                 "n",
                 "best_of",
                 "logit_bias",
@@ -274,7 +288,7 @@ CreateEmbeddingResponse = create_model_from_typeddict(llama_cpp.Embedding)
 def create_embedding(
     request: CreateEmbeddingRequest, llama: llama_cpp.Llama = Depends(get_llama)
 ):
-    return llama.create_embedding(**request.dict(exclude={"model", "user"}))
+    return llama.create_embedding(**request.dict(exclude={"user"}))
 
 
 class ChatCompletionRequestMessage(BaseModel):
@@ -335,7 +349,6 @@ def create_chat_completion(
     completion_or_chunks = llama.create_chat_completion(
         **request.dict(
             exclude={
-                "model",
                 "n",
                 "logit_bias",
                 "user",
@@ -378,13 +391,16 @@ GetModelResponse = create_model_from_typeddict(ModelList)
 @router.get("/v1/models", response_model=GetModelResponse)
 def get_models(
+    settings: Settings = Depends(get_settings),
     llama: llama_cpp.Llama = Depends(get_llama),
 ) -> ModelList:
     return {
         "object": "list",
         "data": [
             {
-                "id": llama.model_path,
+                "id": settings.model_alias
+                if settings.model_alias is not None
+                else llama.model_path,
                 "object": "model",
                 "owned_by": "me",
                 "permissions": [],