Mirror of https://github.com/abetlen/llama-cpp-python.git, synced 2023-09-07 17:34:22 +03:00
Add model_alias option to override model_path in completions. Closes #39
llama_cpp/llama.py
@@ -522,7 +522,7 @@ class Llama:
             if tokens_or_none is not None:
                 tokens.extend(tokens_or_none)
 
-    def create_embedding(self, input: str) -> Embedding:
+    def create_embedding(self, input: str, model: Optional[str] = None) -> Embedding:
         """Embed a string.
 
         Args:
@@ -532,6 +532,7 @@ class Llama:
             An embedding object.
         """
         assert self.ctx is not None
+        _model: str = model if model is not None else self.model_path
 
         if self.params.embedding == False:
             raise RuntimeError(
@@ -561,7 +562,7 @@ class Llama:
                     "index": 0,
                 }
             ],
-            "model": self.model_path,
+            "model": _model,
             "usage": {
                 "prompt_tokens": n_tokens,
                 "total_tokens": n_tokens,
@@ -598,6 +599,7 @@ class Llama:
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        model: Optional[str] = None,
     ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
         assert self.ctx is not None
         completion_id: str = f"cmpl-{str(uuid.uuid4())}"
@@ -610,6 +612,7 @@ class Llama:
         text: bytes = b""
         returned_characters: int = 0
         stop = stop if stop is not None else []
+        _model: str = model if model is not None else self.model_path
 
         if self.verbose:
             llama_cpp.llama_reset_timings(self.ctx)
@@ -708,7 +711,7 @@ class Llama:
                         "id": completion_id,
                         "object": "text_completion",
                         "created": created,
-                        "model": self.model_path,
+                        "model": _model,
                         "choices": [
                             {
                                 "text": text[start:].decode("utf-8", errors="ignore"),
@@ -737,7 +740,7 @@ class Llama:
                 "id": completion_id,
                 "object": "text_completion",
                 "created": created,
-                "model": self.model_path,
+                "model": _model,
                 "choices": [
                     {
                         "text": text[returned_characters:].decode(
@@ -807,7 +810,7 @@ class Llama:
             "id": completion_id,
             "object": "text_completion",
             "created": created,
-            "model": self.model_path,
+            "model": _model,
             "choices": [
                 {
                     "text": text_str,
@@ -842,6 +845,7 @@ class Llama:
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        model: Optional[str] = None,
     ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.
 
@@ -883,6 +887,7 @@ class Llama:
             mirostat_mode=mirostat_mode,
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
+            model=model,
         )
         if stream:
             chunks: Iterator[CompletionChunk] = completion_or_chunks
@@ -909,6 +914,7 @@ class Llama:
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        model: Optional[str] = None,
     ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.
 
@@ -950,6 +956,7 @@ class Llama:
             mirostat_mode=mirostat_mode,
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
+            model=model,
        )
 
     def _convert_text_completion_to_chat(
@@ -1026,6 +1033,7 @@ class Llama:
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        model: Optional[str] = None,
     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
         """Generate a chat completion from a list of messages.
 
@@ -1064,6 +1072,7 @@ class Llama:
             mirostat_mode=mirostat_mode,
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
+            model=model,
         )
         if stream:
             chunks: Iterator[CompletionChunk] = completion_or_chunks  # type: ignore
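In llama.py, the change threads an optional `model` argument through `create_embedding`, `_create_completion`, `create_completion`, `__call__`, and `create_chat_completion`; when supplied, the "model" field of the response reports that name instead of `self.model_path`, while the loaded weights are unchanged. A minimal usage sketch of the new keyword (the model path and the alias below are placeholders, not part of the commit):

import llama_cpp

# Placeholder path; any local GGML model file would do.
llama = llama_cpp.Llama(model_path="./models/ggml-model-q4_0.bin")

# The optional `model` keyword only overrides the name echoed in the response.
completion = llama.create_completion(
    "Q: Name the planets in the solar system. A:",
    max_tokens=32,
    model="my-alias",
)
print(completion["model"])  # -> "my-alias" rather than the on-disk path

# create_chat_completion (and create_embedding, when the model is loaded
# with embedding=True) accepts the same keyword.
chat = llama.create_chat_completion(
    messages=[{"role": "user", "content": "Hello"}],
    model="my-alias",
)
print(chat["model"])  # -> "my-alias"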
llama_cpp/server/app.py
@@ -16,6 +16,10 @@ class Settings(BaseSettings):
     model: str = Field(
         description="The path to the model to use for generating completions."
     )
+    model_alias: Optional[str] = Field(
+        default=None,
+        description="The alias of the model to use for generating completions.",
+    )
     n_ctx: int = Field(default=2048, ge=1, description="The context size.")
     n_gpu_layers: int = Field(
         default=0,
@@ -64,6 +68,7 @@ class Settings(BaseSettings):
 
 router = APIRouter()
 
+settings: Optional[Settings] = None
 llama: Optional[llama_cpp.Llama] = None
 
 
@@ -101,6 +106,12 @@ def create_app(settings: Optional[Settings] = None):
     if settings.cache:
         cache = llama_cpp.LlamaCache(capacity_bytes=settings.cache_size)
         llama.set_cache(cache)
+
+    def set_settings(_settings: Settings):
+        global settings
+        settings = _settings
+
+    set_settings(settings)
     return app
 
 
@@ -112,6 +123,10 @@ def get_llama():
         yield llama
 
 
+def get_settings():
+    yield settings
+
+
 model_field = Field(description="The model to use for generating completions.")
 
 max_tokens_field = Field(
@@ -236,7 +251,6 @@ def create_completion(
     completion_or_chunks = llama(
         **request.dict(
             exclude={
-                "model",
                 "n",
                 "best_of",
                 "logit_bias",
@@ -274,7 +288,7 @@ CreateEmbeddingResponse = create_model_from_typeddict(llama_cpp.Embedding)
 def create_embedding(
     request: CreateEmbeddingRequest, llama: llama_cpp.Llama = Depends(get_llama)
 ):
-    return llama.create_embedding(**request.dict(exclude={"model", "user"}))
+    return llama.create_embedding(**request.dict(exclude={"user"}))
 
 
 class ChatCompletionRequestMessage(BaseModel):
@@ -335,7 +349,6 @@ def create_chat_completion(
     completion_or_chunks = llama.create_chat_completion(
         **request.dict(
             exclude={
-                "model",
                 "n",
                 "logit_bias",
                 "user",
@@ -378,13 +391,16 @@ GetModelResponse = create_model_from_typeddict(ModelList)
 
 @router.get("/v1/models", response_model=GetModelResponse)
 def get_models(
+    settings: Settings = Depends(get_settings),
     llama: llama_cpp.Llama = Depends(get_llama),
 ) -> ModelList:
     return {
         "object": "list",
         "data": [
             {
-                "id": llama.model_path,
+                "id": settings.model_alias
+                if settings.model_alias is not None
+                else llama.model_path,
                 "object": "model",
                 "owned_by": "me",
                 "permissions": [],
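On the server side, `model_alias` becomes a `Settings` field, the per-request "model" value is now forwarded to the underlying `Llama` calls instead of being excluded, and `/v1/models` reports the alias when one is configured. A rough sketch of how a client would see this (host, port, model path, and alias are placeholders, and the `MODEL`/`MODEL_ALIAS` environment variables are assumed to map onto the pydantic `Settings` fields):

# Assumes a server is already running locally, started e.g. with:
#   MODEL=./models/ggml-model-q4_0.bin MODEL_ALIAS=gpt-3.5-turbo python3 -m llama_cpp.server
import requests

# /v1/models now lists the alias instead of the on-disk model path.
models = requests.get("http://localhost:8000/v1/models").json()
print(models["data"][0]["id"])  # -> "gpt-3.5-turbo"

# The request's "model" field is passed through, so the completion echoes it back.
resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={"model": "gpt-3.5-turbo", "prompt": "Q: Name the planets. A:", "max_tokens": 16},
).json()
print(resp["model"])  # -> "gpt-3.5-turbo"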